diff --git a/.bazelrc b/.bazelrc index 1a9c46362e5..d4d7ad61867 100644 --- a/.bazelrc +++ b/.bazelrc @@ -105,9 +105,6 @@ build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include -# Disable MKL-DNN contraction kernels by default. -build --define=tensorflow_mkldnn_contraction_kernel=0 - # Default options should come above this line # Options from ./configure diff --git a/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md b/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md index 34ba4cf9601..d562ced6f3a 100644 --- a/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md +++ b/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md @@ -18,10 +18,11 @@ about: Use this template for reporting a bug or a performance issue. - CUDA/cuDNN version: - GPU model and memory: - -You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) -You can also obtain the TensorFlow version with -python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)" +You can collect some of this information using our environment capture +[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) +You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import +tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c +"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"` **Describe the current behavior** diff --git a/.github/ISSUE_TEMPLATE/20-documentation-issue.md b/.github/ISSUE_TEMPLATE/20-documentation-issue.md index 7123ca6d6c5..7f4a1f1b5b0 100644 --- a/.github/ISSUE_TEMPLATE/20-documentation-issue.md +++ b/.github/ISSUE_TEMPLATE/20-documentation-issue.md @@ -1,17 +1,55 @@ --- name: Documentation Issue -about: Use this template for documentation related issues +about: Use this template for documentation related +labels: 'type:docs' --- -Please make sure that this is a documentation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:doc_template +Thank you for submitting a TensorFlow documentation issue. Per our GitHub +policy, we only address code/doc bugs, performance issues, feature requests, and +build/installation issues on GitHub. +The TensorFlow docs are open source! To get involved, read the documentation +contributor guide: https://www.tensorflow.org/community/contribute/docs -**System information** -- TensorFlow version: -- Doc Link: +## URL(s) with the issue: +Please provide a link to the documentation entry, for example: +https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/MyMethod -**Describe the documentation issue** +## Description of issue (what needs changing): -**We welcome contributions by users. Will you be able to update submit a PR (use the [doc style guide](https://www.tensorflow.org/community/documentation)) to fix the doc Issue?** +### Clear description + +For example, why should someone use this method? How is it useful? + +### Correct links + +Is the link to the source code correct? + +### Parameters defined + +Are all parameters defined and formatted correctly? + +### Returns defined + +Are return values defined? + +### Raises listed and defined + +Are the errors defined? 
For example, +https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/feature_column/categorical_column_with_vocabulary_file#raises + +### Usage example + +Is there a usage example? + +### Request visuals, if applicable + +Are there currently visuals? If not, will it clarify the content? + +### Submit a pull request? + +Are you planning to also submit a pull request to fix the issue? See the docs +contributor guide: https://www.tensorflow.org/community/contribute/docs and the +docs style guide: https://www.tensorflow.org/community/contribute/docs_style diff --git a/.gitignore b/.gitignore index e1d352c238a..99ba9312a92 100644 --- a/.gitignore +++ b/.gitignore @@ -20,15 +20,8 @@ tensorflow/contrib/cmake/_build/ [Bb]uild/ /tensorflow/core/util/version_info.cc /tensorflow/python/framework/fast_tensor_util.cpp -Pods -Podfile.lock -*.pbxproj -*.xcworkspacedata -/tensorflow/lite/tools/make/downloads/** /tensorflow/lite/gen/** -/tensorflow/lite/examples/ios/simple/data/*.txt -/tensorflow/lite/examples/ios/simple/data/*.tflite -xcuserdata/** +/tensorflow/lite/tools/make/downloads/** /api_init_files_list.txt /estimator_api_init_files_list.txt *.whl @@ -39,3 +32,14 @@ xcuserdata/** *.iml local.properties gradleBuild + +# iOS +*.pbxproj +*.xcworkspace +/*.podspec +/tensorflow/lite/**/[ios|objc|swift]*/BUILD +/tensorflow/lite/examples/ios/simple/data/*.tflite +/tensorflow/lite/examples/ios/simple/data/*.txt +Podfile.lock +Pods +xcuserdata diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index b3d84ad8c94..04cd8cb65ef 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -32,7 +32,7 @@ https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh You can obtain the TensorFlow version with: ```bash -python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)" +python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)" ``` ### Describe the problem diff --git a/LICENSE b/LICENSE index 4862420c023..12763eca4c2 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2018 The TensorFlow Authors. All rights reserved. +Copyright 2019 The TensorFlow Authors. All rights reserved. Apache License Version 2.0, January 2004 diff --git a/README.md b/README.md index 96a8ecf4f69..ec5e9af58d8 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
-

+
----------------- @@ -25,7 +25,7 @@ networks research. The system is general enough to be applicable in a wide variety of other domains, as well. TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards -compatible API's for C++, Go, Java, JavaScript and Swift. +compatible API's for C++, Go, Java, JavaScript, and Swift. Keep up to date with release announcements and security updates by subscribing to @@ -50,10 +50,10 @@ instructions, and how to build from source.* People who are a little more adventurous can also try our nightly binaries: -**Nightly pip packages** -* We are pleased to announce that TensorFlow now offers nightly pip packages -under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and -[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on pypi. +**Nightly pip packages** * We are pleased to announce that TensorFlow now offers +nightly pip packages under the +[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and +[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on PyPi. Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean environment to install the nightly TensorFlow build. We support CPU and GPU packages on Linux, Mac, and Windows. @@ -85,7 +85,7 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's uphold this code.** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs, so please see +tracking requests and bugs, please see [TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** @@ -114,15 +114,16 @@ The TensorFlow project strives to abide by generally accepted best practices in ### Community Supported Builds -Build Type | Status | Artifacts ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA -**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) -**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) -**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) -**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | 
[Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
-**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
-**Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.4<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)<br>[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)<br>
[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl) +Build Type | Status | Artifacts +--------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA +**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) +**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) +**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) +**Linux CPU with Intel® MKL-DNN**
**Supports Python 2.7, 3.4, 3.5, and 3.6** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/)
+**Red Hat® Enterprise Linux® 7.6 CPU & GPU**<br>
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 pypi](https://tensorflow.pypi.thoth-station.ninja/index/) ## For more information diff --git a/RELEASE.md b/RELEASE.md index 0a56e690987..c2c50c590ba 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,212 @@ +# Release 1.12.2 + +## Bug Fixes and Other Changes + +* Fixes a potential security vulnerability where carefully crafted GIF images + can produce a null pointer dereference during decoding. + +# Release 1.13.0 + +## Major Features and Improvements + +* TensorFlow Lite has moved from contrib to core. This means that Python modules are under `tf.lite` and source code is now under `tensorflow/lite` rather than `tensorflow/contrib/lite`. +* TensorFlow GPU binaries are now built against CUDA 10 and TensorRT 5.0. +* Support for Python3.7 on all operating systems. +* Moved NCCL to core. + +## Behavioral changes + +* Disallow conversion of python floating types to uint32/64 (matching behavior of other integer types) in `tf.constant`. +* Make the `gain` argument of convolutional orthogonal initializers (`convolutional_delta_orthogonal`, `convolutional_orthogonal_1D`, `convolutional_orthogonal_2D`, `convolutional_orthogonal_3D`) have consistent behavior with the `tf.initializers.orthogonal` initializer, i.e. scale the output l2-norm by `gain` and NOT by `sqrt(gain)`. (Note that these functions are currently in `tf.contrib` which is not guaranteed backward compatible). + +## Bug Fixes and Other Changes + +* Documentation + * Update the doc with the details about the rounding mode used in + quantize_and_dequantize_v2. + * Clarify that tensorflow::port::InitMain() _should_ be called before + using the TensorFlow library. Programs failing to do this are not + portable to all platforms. +* Deprecations and Symbol renames. + * Removing deprecations for the following endpoints: `tf.acos`, + `tf.acosh`, `tf.add`, `tf.as_string`, `tf.asin`, `tf.asinh`, `tf.atan`, + `tf.atan2`, `tf.atanh`, `tf.cos`, `tf.cosh`, `tf.equal`, `tf.exp`, + `tf.floor`, `tf.greater`, `tf.greater_equal`, `tf.less`, + `tf.less_equal`, `tf.log`, `tf.logp1`, `tf.logical_and`, + `tf.logical_not`, `tf.logical_or`, `tf.maximum`, `tf.minimum`, + `tf.not_equal`, `tf.sin`, `tf.sinh`, `tf.tan` + * Deprecate `tf.data.Dataset.shard`. + * Deprecate `saved_model.loader.load` which is replaced by + `saved_model.load` and `saved_model.main_op`, which will be replaced by + `saved_model.main_op` in V2. + * Deprecate tf.QUANTIZED_DTYPES. The official new symbol is + tf.dtypes.QUANTIZED_DTYPES. + * Update sklearn imports for deprecated packages. + * Deprecate `Variable.count_up_to` and `tf.count_up_to` in favor of + `Dataset.range`. + * Export `confusion_matrix` op as `tf.math.confusion_matrix` instead of + `tf.train.confusion_matrix`. + * Add `tf.dtypes.` endpoint for every constant in dtypes.py. Moving + endpoints in versions.py to corresponding endpoints in `tf.sysconfig.` + and `tf.version.`. Moving all constants under `tf.saved_model` + submodules to `tf.saved_model` module. New endpoints are added in V1 and + V2 but existing endpoint removals are only applied in V2. + * Deprecates behavior where device assignment overrides collocation + constraints inside a collocation context manager. +* Keras & Python API + * Add to Keras functionality analogous to + `tf.register_tensor_conversion_function`. 
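The Keras bullet above refers to `tf.register_tensor_conversion_function`; the notes do not spell out the Keras-side API, so the following is a minimal sketch of the existing TF-level hook for context. The `Celsius` wrapper and `_celsius_to_tensor` converter are illustrative names only.

```python
import tensorflow as tf


class Celsius(object):
  """Illustrative wrapper type that is not natively a Tensor."""

  def __init__(self, degrees):
    self.degrees = degrees


def _celsius_to_tensor(value, dtype=None, name=None, as_ref=False):
  # Called whenever TensorFlow needs a Tensor in place of a Celsius object.
  del as_ref  # Only meaningful for variable references.
  return tf.convert_to_tensor(value.degrees, dtype=dtype or tf.float32, name=name)


tf.register_tensor_conversion_function(Celsius, _celsius_to_tensor)

# Celsius values can now be passed to ordinary ops.
doubled = tf.multiply(Celsius(21.0), 2.0)
```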
+ * Subclassed Keras models can now be saved through + `tf.contrib.saved_model.save_keras_model`. + * `LinearOperator.matmul` now returns a new `LinearOperator`. +* New ops and improved op functionality + * Add a Nearest Neighbor Resize op. + * Add an `ignore_unknown` argument to `parse_values` which suppresses + ValueError for unknown hyperparameter types. Such * Add + `tf.linalg.matvec` convenience function. + * `tf.einsum()`raises `ValueError` for unsupported equations like + `"ii->"`. + * Add DCT-I and IDCT-I in `tf.signal.dct` and `tf.signal.idct`. + * Add LU decomposition op. + * Add quantile loss to gradient boosted trees in estimator. + * Add `round_mode` to `QuantizeAndDequantizeV2` op to select rounding + algorithm. + * Add `unicode_encode`, `unicode_decode`, `unicode_decode_with_offsets`, + `unicode_split`, `unicode_split_with_offset`, and `unicode_transcode` + ops. Amongst other things, this Op adds the ability to encode, decode, + and transcode a variety of input text encoding formats into the main + Unicode encodings (UTF-8, UTF-16-BE, UTF-32-BE) + * Add "unit" attribute to the substr op, which allows obtaining the + substring of a string containing unicode characters. + * Broadcasting support for Ragged Tensors. + * `SpaceToDepth` supports uint8 data type. + * Support multi-label quantile regression in estimator. + * We now use "div" as the default partition_strategy in + `tf.nn.safe_embedding_lookup_sparse`, `tf.nn.sampled_softmax` and + `tf.nn.nce_loss`. hyperparameter are ignored. +* Performance + * Improve performance of GPU cumsum/cumprod by up to 300x. + * Added support for weight decay in most TPU embedding optimizers, + including AdamW and MomentumW. +* TensorFlow 2.0 Development + * Add a command line tool to convert to TF2.0, tf_upgrade_v2 + * Merge `tf.spectral` into `tf.signal` for TensorFlow 2.0. + * Change the default recurrent activation function for LSTM from + 'hard_sigmoid' to 'sigmoid' in 2.0. Historically recurrent activation is + 'hard_sigmoid' since it is fast than 'sigmoid'. With new unified backend + between CPU and GPU mode, since the CuDNN kernel is using sigmoid, we + change the default for CPU mode to sigmoid as well. With that, the + default LSTM will be compatible with both CPU and GPU kernel. This will + enable user with GPU to use CuDNN kernel by default and get a 10x + performance boost in training. Note that this is checkpoint breaking + change. If user want to use their 1.x pre-trained checkpoint, please + construct the layer with LSTM(recurrent_activation='hard_sigmoid') to + fallback to 1.x behavior. +* TensorFlow Lite + * Move from `tensorflow/contrib/lite` to `tensorflow/lite`. + * Add experimental Java API for injecting TensorFlow Lite delegates + * Add support for strings in TensorFlow Lite Java API. +* `tf.contrib`: + * Add Apache Ignite Filesystem plugin to support accessing Apache IGFS. + * Dropout now takes `rate` argument, `keep_prob` is deprecated. + * Estimator occurrences references `tf.contrib.estimator` were changed to + `tf.estimator`: + * `tf.contrib.estimator.BaselineEstimator` with + `tf.estimator.BaselineEstimator` + * `tf.contrib.estimator.DNNLinearCombinedEstimator` with + `tf.estimator.DNNLinearCombinedEstimator` + * `tf.contrib.estimator.DNNEstimator` with `tf.estimator.DNNEstimator` + * `tf.contrib.estimator.LinearEstimator` with + `tf.estimator.LinearEstimator` + * `tf.contrib.estimator.InMemoryEvaluatorHook` and + tf.estimator.experimental.InMemoryEvaluatorHook`. 
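To make the `tf.contrib.estimator` to `tf.estimator` renames above concrete, here is a minimal sketch of the evaluator hook under its new name, assuming TF 1.13; the `LinearRegressor` and the toy `input_fn` are illustrative only.

```python
import tensorflow as tf

feature_columns = [tf.feature_column.numeric_column('x')]
estimator = tf.estimator.LinearRegressor(feature_columns=feature_columns)


def input_fn():
  features = {'x': [[1.0], [2.0], [3.0], [4.0]]}
  labels = [[2.0], [4.0], [6.0], [8.0]]
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(2).repeat()


# Previously tf.contrib.estimator.InMemoryEvaluatorHook.
eval_hook = tf.estimator.experimental.InMemoryEvaluatorHook(
    estimator, input_fn, every_n_iter=50)
estimator.train(input_fn, hooks=[eval_hook], max_steps=200)
```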
+ * `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with + `tf.estimator.experimental.make_stop_at_checkpoint_step_hook`. + * Expose `tf.distribute.Strategy as the new name for + tf.contrib.distribute.DistributionStrategy. + * Migrate linear optimizer from contrib to core. + * Move `tf.contrib.signal` to `tf.signal` (preserving aliases in + tf.contrib.signal). + * Users of `tf.contrib.estimator.export_all_saved_models` and related + should switch to + `tf.estimator.Estimator.experimental_export_all_saved_models`. +* tf.data: + * Add `tf.data.experimental.StatsOptions()`, to configure options to + collect statistics from `tf.data.Dataset` pipeline using + `StatsAggregator`. Add nested option, `experimental_stats` (which takes + a `tf.data.experimen tal.StatsOptions` object), to `tf.data.Options`. + Deprecates `tf.data.experimental.set_stats_agregator`. + * Performance optimizations: + * Add `tf.data.experimental.OptimizationOptions()`, to configure options + to enable `tf.data` performance optimizations. Add nested option, + `experimental_optimization` (which takes a + `tf.data.experimental.OptimizationOptions` object), to + `tf.data.Options`. Remove performance optimization options from + `tf.data.Options`, and add them under + `tf.data.experimental.OptimizationOptions` instead. + * Enable `map_and_batch_fusion` and `noop_elimination` optimizations by + default. They can be disabled by configuring + `tf.data.experimental.OptimizationOptions` to set `map_and_batch = + False` or `noop_elimination = False` respectively. To disable all + default optimizations, set `apply_default_optimizations = False`. + * Support parallel map in `map_and_filter_fusion`. + * Disable static optimizations for input pipelines that use non-resource + `tf.Variable`s. + * Add NUMA-aware MapAndBatch dataset. + * Deprecate `tf.data.Dataset.make_one_shot_iterator()` in V1, removed it + from V2, and added tf.compat.v1.data.make_one_shot_iterator()`. + * Deprecate `tf.data.Dataset.make_initializable_iterator()` in V1, removed + it from V2, and added `tf.compat.v1.data.make_initializable_iterator()`. + * Enable nested dataset support in core `tf.data` transformations. + * For `tf.data.Dataset` implementers: Added + `tf.data.Dataset._element_structured property` to replace + `Dataset.output_{types,shapes,classes}`. + * Make `num_parallel_calls` of `tf.data.Dataset.interleave` and + `tf.data.Dataset.map` work in Eager mode. +* Toolchains + * Fixed OpenSSL compatibility by avoiding `EVP_MD_CTX_destroy`. + * Added bounds checking to printing deprecation warnings. + * Upgraded CUDA dependency to 10.0 + * To build with Android NDK r14b, add "#include " to + android-ndk-r14b/platforms/android-14/arch-*/usr/include/linux/futex.h + * Removed `:android_tensorflow_lib_selective_registration*` targets, use + `:android_tensorflow_lib_lite*` targets instead. +* XLA + * Move `RoundToEven` function to xla/client/lib/math.h. + * A new environment variable `TF_XLA_DEBUG_OPTIONS_PASSTHROUGH` set to "1" + or "true" allows the debug options passed within an XRTCompile op to be + passed directly to the XLA compilation backend. If such variable is not + set (service side), only a restricted set will be passed through. + * Allow the XRTCompile op to return the ProgramShape resulted form the XLA + compilation as a second return argument. + * XLA HLO graphs can now be rendered as SVG/HTML. 
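The `tf.data` iterator deprecations above keep a compatibility path; a minimal graph-mode sketch, assuming TF 1.13:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

# dataset.make_one_shot_iterator() is deprecated in V1 and removed in V2;
# the compat helper added in this release keeps graph-mode code working.
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
next_element = iterator.get_next()

with tf.compat.v1.Session() as sess:
  for _ in range(5):
    print(sess.run(next_element))  # 0, 1, 2, 3, 4
```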
+* Estimator + * Replace all occurences of `tf.contrib.estimator.BaselineEstimator` with + `tf.estimator.BaselineEstimator` + * Replace all occurences of + `tf.contrib.estimator.DNNLinearCombinedEstimator` with + `tf.estimator.DNNLinearCombinedEstimator` + * Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with + `tf.estimator.DNNEstimator` + * Replace all occurrences of `tf.contrib.estimator.LinearEstimator` with + `tf.estimator.LinearEstimator` + * Users of `tf.contrib.estimator.export_all_saved_models` and related + should switch to + `tf.estimator.Estimator.experimental_export_all_saved_models`. + * Update `regression_head` to the new Head API for Canned Estimator V2. + * Switch `multi_class_head` to Head API for Canned Estimator V2. + * Replace all occurences of `tf.contrib.estimator.InMemoryEvaluatorHook` + and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with + `tf.estimator.experimental.InMemoryEvaluatorHook` and + `tf.estimator.experimental.make_stop_at_checkpoint_step_hook` + * Migrate linear optimizer from contrib to core. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Abhinav Upadhyay, Ag Ramesh, akikaaa, Alexis Louis, Anders Huss, Andreas Madsen, Andrew Banchich, Andy Craze, Anton Dmitriev, Artem Malykh, Avijit-Nervana, Balint Cristian, Benjamin Tan Wei Hao, Bhavani Subramanian, Brendan Finan, Brian Nemsick, Bryan Cutler, By Shen, Cao Zongyan, Castiel, Chris Antaki, Christian Goll, Cibifang, Clayne Robison, Codrut Grosu, Cong Xu, Dalmo Cirne, Daniel Hunter, Dougal J. Sutherland, Edvard Fagerholm, EFanZh, Erik Smistad, Evgeniy Polyakov, Feiyang Chen, franklin5, Fred Reiss, Gautam, gehring, Geoffrey Irving, George Sterpu, Gitea, Grzegorz George Pawelczak, Guozhong Zhuang, himkt, Hoeseong Kim, Huan Li (李卓桓), HuiyangFei, hyunyoung, Isaac Burbank, jackonan, Jacky Ko, Jason Furmanek, Jason Zaman, Javier Luraschi, Jiang,Zhoulong, joaak, John Lin, Jonathan Wyatt Hoech, josephyearsley, Josh Gordon, Julian Niedermeier, Karl Lessard, Keno Fischer, lanhin, Leon Graser, leondgarse, Li, Guizi, Li, Yiqiang, lxl910915, Mahmoud Abuzaina, manhyuk, Marcela Morales Quispe, margaretmz, Matt Conley, Max Pumperla, mbhuiyan, mdfaijul, Meng, Peng, Michael, Michael Gielda, mrTsjolder, Muhammad Wildan, neargye, Nehal J Wani, NEWPLAN, Niranjan Hasabnis, Nutti, olicht, Pan Daoxin, Pedro Monreal, Peng Yu, pillarpond, Pooya Davoodi, qiezi, Rholais Lii, Richard Yu, Rin Arakaki, Roger Iyengar, sahilbadyal, Sami Kama, Sandip Giri, Scott Leishman, Serge Panev, Seunghoon Park, Shafi Dayatar, shengfuintel, Shimin Guo, Siju, silent567, Stefan Dyulgerov, steven, Tao Wei, Thor Johnsen, Tingbo Lu, tomguluson92, Tongxuan Liu, Trevor Morris, Ubuntu, Vadim Borisov, vanderliang, wangsiyu, Wen Yun, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, Xiaoming (Jason) Cui, Yan Facai (颜发才), Yanbo Liang, Yaniv Blumenfeld, Yash Gaurkar, Yicheng Fan, Yong Tang, Yongjoon Lee, Yuan (Terry) Tang, Yuxin Wu, zldrobit + # Release 1.12.0 ## Major Features and Improvements @@ -38,21 +247,21 @@ * Remove integer types from `tf.nn.softplus` and `tf.nn.softsign` OpDefs. This is a bugfix; these ops were never meant to support integers. * Allow subslicing Tensors with a single dimension. - * Add option to calculate string length in Unicode characters + * Add option to calculate string length in Unicode characters. * Add functionality to SubSlice a tensor. * Add searchsorted (ie lower/upper_bound) op. * Add model explainability to Boosted Trees. 
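The `searchsorted` entry above corresponds to `tf.searchsorted`; a minimal sketch with arbitrary values:

```python
import tensorflow as tf

sorted_sequence = tf.constant([[0., 3., 8., 9., 10.]])
values = tf.constant([[2., 4., 9.]])

# side='left' behaves like lower_bound, side='right' like upper_bound.
lower = tf.searchsorted(sorted_sequence, values, side='left')
upper = tf.searchsorted(sorted_sequence, values, side='right')

with tf.Session() as sess:
  print(sess.run(lower))  # [[1 2 3]]
  print(sess.run(upper))  # [[1 2 4]]
```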
- * Support negative positions for tf.substr + * Support negative positions for tf.substr. * There was previously a bug in the bijector_impl where the _reduce_jacobian_det_over_event does not handle scalar ILDJ implementations properly. - * In tf eager execution, allow re-entering a GradientTape context + * In tf eager execution, allow re-entering a GradientTape context. * Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in, then bazel will build TensorFlow API version 2.0. Note that TensorFlow 2.0 is under active development and has no guarantees at this point. - * Add additional compression options to TfRecordWriter + * Add additional compression options to TfRecordWriter. * Performance improvements for regex full match operations. - * Replace tf.GraphKeys.VARIABLES with `tf.GraphKeys.GLOBAL_VARIABLES` + * Replace tf.GraphKeys.VARIABLES with `tf.GraphKeys.GLOBAL_VARIABLES`. * Remove unused dynamic learning rate support. ## Thanks to our Contributors @@ -75,15 +284,22 @@ Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为 ## Major Features and Improvements -* Nvidia GPU: - * Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN 7.2 and TensorRT 4. See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) -* Google Cloud TPU: - * Experimental tf.data integration for Keras on Google Cloud TPUs. - * Experimental / preview support for eager execution on Google Cloud TPUs. -* DistributionStrategy: - * Add multi-GPU DistributionStrategy support in tf.keras. Users can now use `fit`, `evaluate` and `predict` to distribute their model on multiple GPUs. - * Add multi-worker DistributionStrategy and standalone client support in Estimator. See [README] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute) for more details. -* Add C, C++, and Python functions for querying kernels +* Nvidia GPU: + * Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN + 7.2 and TensorRT 4. See updated install guides: + [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) +* Google Cloud TPU: + * Experimental tf.data integration for Keras on Google Cloud TPUs. + * Experimental / preview support for eager execution on Google Cloud TPUs. +* DistributionStrategy: + * Add multi-GPU DistributionStrategy support in tf.keras. Users can now + use `fit`, `evaluate` and `predict` to distribute their model on + multiple GPUs. + * Add multi-worker DistributionStrategy and standalone client support in + Estimator. See + [README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute) + for more details. +* Add C, C++, and Python functions for querying kernels. ## Breaking Changes @@ -134,18 +350,18 @@ Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为 * Deprecate self.test_session() in favor of self.session() or self.cached_session(). * Directly import tensor.proto.h (the transitive import will be removed - from tensor.h soon) + from tensor.h soon). * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one. * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator. - * Fix toco compilation/execution on Windows + * Fix toco compilation/execution on Windows. 
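Further up, the 1.12.0 notes mention allowing re-entry of a `GradientTape` context; one reading of that note, as a small eager-mode sketch with arbitrary values:

```python
import tensorflow as tf

tf.enable_eager_execution()

x = tf.constant(3.0)
tape = tf.GradientTape()

with tape:
  tape.watch(x)
  y = x * x

# Re-entering the same tape records additional operations onto it.
with tape:
  z = y + x

print(tape.gradient(z, x))  # d(x*x + x)/dx at x = 3 is 7.0
```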
* GoogleZoneProvider class added to detect which Google Cloud Engine zone tensorflow is running in. * It is now safe to call any of the C API's TF_Delete\* functions on - nullptr - * Log some errors on Android to logcat + nullptr. + * Log some errors on Android to logcat. * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models. * Optional bucket location check for the GCS Filesystem. @@ -166,7 +382,7 @@ Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为 the existing zero_state() method. * Update initialization of variables in Keras. * Updates to "constrained_optimization" in tensorflow/contrib. - * boosted trees: adding pruning mode + * boosted trees: adding pruning mode. * tf.train.Checkpoint does not delete old checkpoints by default. * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow diff --git a/WORKSPACE b/WORKSPACE index 9f07b9fd471..868421dc31e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -4,11 +4,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file" http_archive( name = "io_bazel_rules_closure", - sha256 = "43c9b882fa921923bcba764453f4058d102bece35a37c9f6383c713004aacff1", - strip_prefix = "rules_closure-9889e2348259a5aad7e805547c1a0cf311cfcd91", + sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9", + strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", # 2018-12-21 + "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04 ], ) @@ -43,17 +43,37 @@ remote_config_workspace() # Apple and Swift rules. 
http_archive( name = "build_bazel_rules_apple", - sha256 = "73b4980a318d203d3307f850e27e66ec5cc8d223147a3475a6f11597eb6438a5", - strip_prefix = "rules_apple-0.13.0", - urls = ["https://github.com/bazelbuild/rules_apple/archive/0.13.0.tar.gz"], -) + sha256 = "23792cd999f97fc97284d1c44cb1324bfdd0bc54aa68ad513fa3705aca3b1f9e", + urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.15.0/rules_apple.0.15.0.tar.gz"], +) # https://github.com/bazelbuild/rules_apple/releases +http_archive( + name = "build_bazel_apple_support", + sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471", + urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"], +) # https://github.com/bazelbuild/apple_support/releases +http_archive( + name = "bazel_skylib", + sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e", + urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"], +) # https://github.com/bazelbuild/bazel-skylib/releases +http_archive( + name = "build_bazel_rules_swift", + sha256 = "9efe9699e9765e6b4a5e063e4a08f6b163cccaf0443f775d935baf5c3cd6ed0e", + urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.9.0/rules_swift.0.9.0.tar.gz"], +) # https://github.com/bazelbuild/rules_swift/releases +http_archive( + name = "com_github_apple_swift_swift_protobuf", + type = "zip", + strip_prefix = "swift-protobuf-1.5.0/", + urls = ["https://github.com/apple/swift-protobuf/archive/1.5.0.zip"], +) # https://github.com/apple/swift-protobuf/releases http_file( name = "xctestrunner", executable = 1, - urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"], -) -load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") -apple_rules_dependencies() + urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"], +) # https://github.com/google/xctestrunner/releases +# Use `swift_rules_dependencies` to fetch the toolchains. With the +# `git_repository` rules above, the following call will skip redefining them. load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") swift_rules_dependencies() diff --git a/configure.py b/configure.py index 4814143f466..2120a4b27d6 100644 --- a/configure.py +++ b/configure.py @@ -33,13 +33,11 @@ except ImportError: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_DEFAULT_CUDA_VERSION = '10.0' +_DEFAULT_CUDA_VERSION = '10' _DEFAULT_CUDNN_VERSION = '7' +_DEFAULT_TENSORRT_VERSION = '5' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' -_DEFAULT_CUDA_PATH = '/usr/local/cuda' -_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' -_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' - 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) + _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' @@ -50,21 +48,24 @@ _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' _TF_WORKSPACE_ROOT = '' _TF_BAZELRC = '' +_TF_CURRENT_BAZEL_VERSION = None NCCL_LIB_PATHS = [ 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' ] -# List of files to be configured for using Bazel on Apple platforms. +# List of files to configure when building Bazel on Apple platforms. 
APPLE_BAZEL_FILES = [ + 'tensorflow/lite/experimental/ios/BUILD', 'tensorflow/lite/experimental/objc/BUILD', 'tensorflow/lite/experimental/swift/BUILD' ] -if platform.machine() == 'ppc64le': - _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' -else: - _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() +# List of files to move when building for iOS. +IOS_FILES = [ + 'tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec', + 'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec', +] class UserInputError(Exception): @@ -199,9 +200,10 @@ def setup_python(environ_cp): ask_python_bin_path = ('Please specify the location of python. [Default is ' '%s]: ') % default_python_bin_path while True: - python_bin_path = get_from_env_or_user_or_default( - environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path, - default_python_bin_path) + python_bin_path = get_from_env_or_user_or_default(environ_cp, + 'PYTHON_BIN_PATH', + ask_python_bin_path, + default_python_bin_path) # Check if the path is valid if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): break @@ -291,9 +293,9 @@ def get_var(environ_cp, Args: environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". - query_item: string for feature related to the variable, e.g. "Hadoop File - System". + var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". + query_item: string for feature related to the variable, e.g. "CUDA for + Nvidia GPUs". enabled_by_default: boolean for default behavior. question: optional string for how to ask for user input. yes_reply: optional string for reply when feature is enabled. @@ -337,8 +339,8 @@ def get_var(environ_cp, 'Environment variable %s must be set as a boolean indicator.\n' 'The following are accepted as TRUE : %s.\n' 'The following are accepted as FALSE: %s.\n' - 'Current value is %s.' % (var_name, ', '.join(true_strings), - ', '.join(false_strings), var)) + 'Current value is %s.' % + (var_name, ', '.join(true_strings), ', '.join(false_strings), var)) while var is None: user_input_origin = get_input(question) @@ -374,9 +376,9 @@ def set_build_var(environ_cp, Args: environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". - query_item: string for feature related to the variable, e.g. "Hadoop File - System". + var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". + query_item: string for feature related to the variable, e.g. "CUDA for + Nvidia GPUs". option_name: string for option to define in .bazelrc. enabled_by_default: boolean for default behavior. bazel_config_name: Name for Bazel --config argument to enable build feature. @@ -385,14 +387,14 @@ def set_build_var(environ_cp, var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) environ_cp[var_name] = var if var == '1': - write_to_bazelrc( - 'build:%s --define %s=true' % (bazel_config_name, option_name)) + write_to_bazelrc('build:%s --define %s=true' % + (bazel_config_name, option_name)) write_to_bazelrc('build --config=%s' % bazel_config_name) elif bazel_config_name is not None: # TODO(mikecase): Migrate all users of configure.py to use --config Bazel # options and not to set build configs through environment variables. 
- write_to_bazelrc( - 'build:%s --define %s=true' % (bazel_config_name, option_name)) + write_to_bazelrc('build:%s --define %s=true' % + (bazel_config_name, option_name)) def set_action_env_var(environ_cp, @@ -409,9 +411,9 @@ def set_action_env_var(environ_cp, Args: environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". - query_item: string for feature related to the variable, e.g. "Hadoop File - System". + var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". + query_item: string for feature related to the variable, e.g. "CUDA for + Nvidia GPUs". enabled_by_default: boolean for default behavior. question: optional string for how to ask for user input. yes_reply: optional string for reply when feature is enabled. @@ -439,6 +441,9 @@ def convert_version_to_int(version): """ version = version.split('-')[0] version_segments = version.split('.') + # Treat "0.24" as "0.24.0" + if len(version_segments) == 2: + version_segments.append('0') for seg in version_segments: if not seg.isdigit(): return None @@ -451,8 +456,8 @@ def check_bazel_version(min_version, max_version): """Check installed bazel version is between min_version and max_version. Args: - min_version: string for minimum bazel version. - max_version: string for maximum bazel version. + min_version: string for minimum bazel version (must exist!). + max_version: string for maximum bazel version (must exist!). Returns: The bazel version detected. @@ -565,7 +570,7 @@ def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, Args: environ_cp: copy of the os.environ. - var_name: string for name of environment variable, e.g. "TF_NEED_HDFS". + var_name: string for name of environment variable, e.g. "TF_NEED_CUDA". ask_for_var: string for how to ask for user input. var_default: default value string. @@ -658,9 +663,9 @@ def prompt_loop_or_load_from_env(environ_cp, print(error_msg % val) environ_cp[var_name] = '' else: - raise UserInputError( - 'Invalid %s setting was provided %d times in a row. ' - 'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts)) + raise UserInputError('Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' 
% + (var_name, n_ask_attempts)) environ_cp[var_name] = val return val @@ -669,8 +674,8 @@ def prompt_loop_or_load_from_env(environ_cp, def create_android_ndk_rule(environ_cp): """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" if is_windows() or is_cygwin(): - default_ndk_path = cygpath( - '%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA']) + default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % + environ_cp['APPDATA']) elif is_macos(): default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] else: @@ -689,8 +694,9 @@ def create_android_ndk_rule(environ_cp): error_msg=('The path %s or its child file "source.properties" ' 'does not exist.')) write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) - write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', - check_ndk_level(android_ndk_home_path)) + write_action_env_to_bazelrc( + 'ANDROID_NDK_API_LEVEL', + get_ndk_api_level(environ_cp, android_ndk_home_path)) def create_android_sdk_rule(environ_cp): @@ -757,8 +763,10 @@ def create_android_sdk_rule(environ_cp): write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path) -def check_ndk_level(android_ndk_home_path): - """Check the revision number of an Android NDK path.""" +def get_ndk_api_level(environ_cp, android_ndk_home_path): + """Gets the appropriate NDK API level to use for the provided Android NDK path.""" + + # First check to see if we're using a blessed version of the NDK. properties_path = '%s/source.properties' % android_ndk_home_path if is_windows() or is_cygwin(): properties_path = cygpath(properties_path) @@ -767,16 +775,40 @@ def check_ndk_level(android_ndk_home_path): revision = re.search(r'Pkg.Revision = (\d+)', filedata) if revision: - ndk_api_level = revision.group(1) + ndk_version = revision.group(1) else: raise Exception('Unable to parse NDK revision.') - if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: - print('WARNING: The API level of the NDK in %s is %s, which is not ' + if int(ndk_version) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The NDK version in %s is %s, which is not ' 'supported by Bazel (officially supported versions: %s). Please use ' 'another version. Compiling Android targets may result in confusing ' - 'errors.\n' % (android_ndk_home_path, ndk_api_level, + 'errors.\n' % (android_ndk_home_path, ndk_version, _SUPPORTED_ANDROID_NDK_VERSIONS)) - return ndk_api_level + + # Now grab the NDK API level to use. Note that this is different from the + # SDK API level, as the NDK API level is effectively the *min* target SDK + # version. + platforms = os.path.join(android_ndk_home_path, 'platforms') + api_levels = sorted(os.listdir(platforms)) + api_levels = [ + x.replace('android-', '') for x in api_levels if 'android-' in x + ] + + def valid_api_level(api_level): + return os.path.exists( + os.path.join(android_ndk_home_path, 'platforms', + 'android-' + api_level)) + + android_ndk_api_level = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_NDK_API_LEVEL', + var_default='18', # 18 is required for GPU acceleration. + ask_for_var=('Please specify the (min) Android NDK API level to use. 
' + '[Available levels: %s]') % api_levels, + check_success=valid_api_level, + error_msg='Android-%s is not present in the NDK path.') + + return android_ndk_api_level def set_gcc_host_compiler_path(environ_cp): @@ -823,149 +855,39 @@ def reformat_version_sequence(version_str, sequence_count): return '.'.join(v[:sequence_count]) +def set_tf_cuda_paths(environ_cp): + """Set TF_CUDA_PATHS.""" + ask_cuda_paths = ( + 'Please specify the comma-separated list of base paths to look for CUDA ' + 'libraries and headers. [Leave empty to use the default]: ') + tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS', + ask_cuda_paths, '') + if tf_cuda_paths: + environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths + + def set_tf_cuda_version(environ_cp): - """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" + """Set TF_CUDA_VERSION.""" ask_cuda_version = ( 'Please specify the CUDA SDK version you want to use. ' '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - - for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): - # Configure the Cuda SDK version to use. - tf_cuda_version = get_from_env_or_user_or_default( - environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) - tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2) - - # Find out where the CUDA toolkit is installed - default_cuda_path = _DEFAULT_CUDA_PATH - if is_windows() or is_cygwin(): - default_cuda_path = cygpath( - environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN)) - elif is_linux(): - # If the default doesn't exist, try an alternative default. - if (not os.path.exists(default_cuda_path) - ) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX): - default_cuda_path = _DEFAULT_CUDA_PATH_LINUX - ask_cuda_path = ('Please specify the location where CUDA %s toolkit is' - ' installed. Refer to README.md for more details. ' - '[Default is %s]: ') % (tf_cuda_version, default_cuda_path) - cuda_toolkit_path = get_from_env_or_user_or_default( - environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path) - if is_windows() or is_cygwin(): - cuda_toolkit_path = cygpath(cuda_toolkit_path) - - if is_windows(): - cuda_rt_lib_paths = ['lib/x64/cudart.lib'] - elif is_linux(): - cuda_rt_lib_paths = [ - '%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [ - 'lib64', - 'lib/powerpc64le-linux-gnu', - 'lib/x86_64-linux-gnu', - ] - ] - elif is_macos(): - cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version] - - cuda_toolkit_paths_full = [ - os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths - ] - if any(os.path.exists(x) for x in cuda_toolkit_paths_full): - break - - # Reset and retry - print('Invalid path to CUDA %s toolkit. %s cannot be found' % - (tf_cuda_version, cuda_toolkit_paths_full)) - environ_cp['TF_CUDA_VERSION'] = '' - environ_cp['CUDA_TOOLKIT_PATH'] = '' - - else: - raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d ' - 'times in a row. Assuming to be a scripting mistake.' 
% - _DEFAULT_PROMPT_ASK_ATTEMPTS) - - # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION - environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path - write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path) + tf_cuda_version = get_from_env_or_user_or_default(environ_cp, + 'TF_CUDA_VERSION', + ask_cuda_version, + _DEFAULT_CUDA_VERSION) environ_cp['TF_CUDA_VERSION'] = tf_cuda_version - write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version) def set_tf_cudnn_version(environ_cp): - """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION.""" + """Set TF_CUDNN_VERSION.""" ask_cudnn_version = ( 'Please specify the cuDNN version you want to use. ' '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION - - for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): - tf_cudnn_version = get_from_env_or_user_or_default( - environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, - _DEFAULT_CUDNN_VERSION) - tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1) - - default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') - ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' - 'installed. Refer to README.md for more details. [Default' - ' is %s]: ') % (tf_cudnn_version, default_cudnn_path) - cudnn_install_path = get_from_env_or_user_or_default( - environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path) - - # Result returned from "read" will be used unexpanded. That make "~" - # unusable. Going through one more level of expansion to handle that. - cudnn_install_path = os.path.realpath( - os.path.expanduser(cudnn_install_path)) - if is_windows() or is_cygwin(): - cudnn_install_path = cygpath(cudnn_install_path) - - if is_windows(): - cuda_dnn_lib_path = 'lib/x64/cudnn.lib' - cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib' - elif is_linux(): - cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version - cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version - elif is_macos(): - cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version - cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version - - cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path) - cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path, - cuda_dnn_lib_alt_path) - if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists( - cuda_dnn_lib_alt_path_full): - break - - # Try another alternative for Linux - if is_linux(): - ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' - cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) - cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)', - cudnn_path_from_ldconfig) - if cudnn_path_from_ldconfig: - cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1) - if os.path.exists( - '%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)): - cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig) - break - - # Reset and Retry - print( - 'Invalid path to cuDNN %s toolkit. None of the following files can be ' - 'found:' % tf_cudnn_version) - print(cuda_dnn_lib_path_full) - print(cuda_dnn_lib_alt_path_full) - if is_linux(): - print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)) - - environ_cp['TF_CUDNN_VERSION'] = '' - else: - raise UserInputError('Invalid TF_CUDNN setting was provided %d ' - 'times in a row. Assuming to be a scripting mistake.' 
% - _DEFAULT_PROMPT_ASK_ATTEMPTS) - - # Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION - environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path - write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path) + tf_cudnn_version = get_from_env_or_user_or_default(environ_cp, + 'TF_CUDNN_VERSION', + ask_cudnn_version, + _DEFAULT_CUDNN_VERSION) environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version - write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) def is_cuda_compatible(lib, cuda_ver, cudnn_ver): @@ -997,252 +919,38 @@ def is_cuda_compatible(lib, cuda_ver, cudnn_ver): return cudnn_ok and cuda_ok -def set_tf_tensorrt_install_path(environ_cp): - """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. - - Adapted from code contributed by Sami Kama (https://github.com/samikama). - - Args: - environ_cp: copy of the os.environ. - - Raises: - ValueError: if this method was called under non-Linux platform. - UserInputError: if user has provided invalid input multiple times. - """ +def set_tf_tensorrt_version(environ_cp): + """Set TF_TENSORRT_VERSION.""" if not is_linux(): raise ValueError('Currently TensorRT is only supported on Linux platform.') - # Ask user whether to add TensorRT support. - if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', - False))) != '1': + if not int(environ_cp.get('TF_NEED_TENSORRT', False)): return - for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): - ask_tensorrt_path = (r'Please specify the location where TensorRT is ' - 'installed. [Default is %s]:') % ( - _DEFAULT_TENSORRT_PATH_LINUX) - trt_install_path = get_from_env_or_user_or_default( - environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path, - _DEFAULT_TENSORRT_PATH_LINUX) - - # Result returned from "read" will be used unexpanded. That make "~" - # unusable. Going through one more level of expansion to handle that. - trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path)) - - def find_libs(search_path): - """Search for libnvinfer.so in "search_path".""" - fl = set() - if os.path.exists(search_path) and os.path.isdir(search_path): - fl.update([ - os.path.realpath(os.path.join(search_path, x)) - for x in os.listdir(search_path) - if 'libnvinfer.so' in x - ]) - return fl - - possible_files = find_libs(trt_install_path) - possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) - possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) - cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) - cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) - nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') - highest_ver = [0, None, None] - - for lib_file in possible_files: - if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): - matches = nvinfer_pattern.search(lib_file) - if not matches.groups(): - continue - ver_str = matches.group(1) - ver = convert_version_to_int(ver_str) if len(ver_str) else 0 - if ver > highest_ver[0]: - highest_ver = [ver, ver_str, lib_file] - if highest_ver[1] is not None: - trt_install_path = os.path.dirname(highest_ver[2]) - tf_tensorrt_version = highest_ver[1] - break - - # Try another alternative from ldconfig. 
- ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' - ldconfig_output = run_shell([ldconfig_bin, '-p']) - search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)', - ldconfig_output) - if search_result: - libnvinfer_path_from_ldconfig = search_result.group(2) - if os.path.exists(libnvinfer_path_from_ldconfig): - if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver, - cudnn_ver): - trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) - tf_tensorrt_version = search_result.group(1) - break - - # Reset and Retry - if possible_files: - print('TensorRT libraries found in one the following directories', - 'are not compatible with selected cuda and cudnn installations') - print(trt_install_path) - print(os.path.join(trt_install_path, 'lib')) - print(os.path.join(trt_install_path, 'lib64')) - if search_result: - print(libnvinfer_path_from_ldconfig) - else: - print( - 'Invalid path to TensorRT. None of the following files can be found:') - print(trt_install_path) - print(os.path.join(trt_install_path, 'lib')) - print(os.path.join(trt_install_path, 'lib64')) - if search_result: - print(libnvinfer_path_from_ldconfig) - - else: - raise UserInputError('Invalid TF_TENSORRT setting was provided %d ' - 'times in a row. Assuming to be a scripting mistake.' % - _DEFAULT_PROMPT_ASK_ATTEMPTS) - - # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION - environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path - write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path) + ask_tensorrt_version = ( + 'Please specify the TensorRT version you want to use. ' + '[Leave empty to default to TensorRT %s]: ') % _DEFAULT_TENSORRT_VERSION + tf_tensorrt_version = get_from_env_or_user_or_default( + environ_cp, 'TF_TENSORRT_VERSION', ask_tensorrt_version, + _DEFAULT_TENSORRT_VERSION) environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version - write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version) -def set_tf_nccl_install_path(environ_cp): - """Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION. - - Args: - environ_cp: copy of the os.environ. - - Raises: - ValueError: if this method was called under non-Linux platform. - UserInputError: if user has provided invalid input multiple times. - """ +def set_tf_nccl_version(environ_cp): + """Set TF_NCCL_VERSION.""" if not is_linux(): - raise ValueError('Currently NCCL is only supported on Linux platforms.') + raise ValueError('Currently NCCL is only supported on Linux platform.') + + if 'TF_NCCL_VERSION' in environ_cp: + return ask_nccl_version = ( 'Please specify the locally installed NCCL version you want to use. ' - '[Default is to use https://github.com/nvidia/nccl]: ') - - for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): - tf_nccl_version = get_from_env_or_user_or_default( - environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '') - - if not tf_nccl_version: - break # No need to get install path, building the open source code. - - tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) - - # Look with ldconfig first if we can find the library in paths - # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding - # include directory. This is where the NCCL .deb packages install them. - - # First check to see if NCCL is in the ldconfig. - # If its found, use that location. 
- if is_linux(): - ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' - nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) - nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)', - nccl2_path_from_ldconfig) - if nccl2_path_from_ldconfig: - nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1) - if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)): - nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig) - print('NCCL libraries found in ' + nccl2_path_from_ldconfig) - - # Check if this is the main system lib location - if re.search('.*linux-gnu', nccl_install_path): - trunc_nccl_install_path = '/usr' - print('This looks like a system path.') - else: - trunc_nccl_install_path = nccl_install_path + '/..' - - # Look for header - nccl_hdr_path = trunc_nccl_install_path + '/include' - print('Assuming NCCL header path is ' + nccl_hdr_path) - if os.path.exists(nccl_hdr_path + '/nccl.h'): - # Set NCCL_INSTALL_PATH - environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path - write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) - - # Set NCCL_HDR_PATH - environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path - write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path) - break - else: - print( - 'The header for NCCL2 cannot be found. Please install the libnccl-dev package.' - ) - else: - print('NCCL2 is listed by ldconfig but the library is not found. ' - 'Your ldconfig is out of date. Please run sudo ldconfig.') - else: - # NCCL is not found in ldconfig. Ask the user for the location. - default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') - ask_nccl_path = ( - r'Please specify the location where NCCL %s library is ' - 'installed. Refer to README.md for more details. [Default ' - 'is %s]:') % (tf_nccl_version, default_nccl_path) - nccl_install_path = get_from_env_or_user_or_default( - environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) - - # Result returned from "read" will be used unexpanded. That make "~" - # unusable. Going through one more level of expansion to handle that. - nccl_install_path = os.path.realpath( - os.path.expanduser(nccl_install_path)) - if is_windows() or is_cygwin(): - nccl_install_path = cygpath(nccl_install_path) - - nccl_lib_path = '' - if is_windows(): - nccl_lib_path = 'lib/x64/nccl.lib' - elif is_linux(): - nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version - nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename) - if not os.path.exists(nccl_lpath): - for relative_path in NCCL_LIB_PATHS: - path = '%s/%s%s' % (nccl_install_path, relative_path, - nccl_lib_filename) - if os.path.exists(path): - print('NCCL found at ' + path) - nccl_lib_path = path - break - else: - nccl_lib_path = nccl_lpath - elif is_macos(): - nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version - - nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) - nccl_hdr_path = os.path.join( - os.path.dirname(nccl_lib_path), '../include/nccl.h') - print('Assuming NCCL header path is ' + nccl_hdr_path) - if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): - # Set NCCL_INSTALL_PATH - environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path) - write_action_env_to_bazelrc('NCCL_INSTALL_PATH', - os.path.dirname(nccl_lib_path)) - - # Set NCCL_HDR_PATH - environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path) - write_action_env_to_bazelrc('NCCL_HDR_PATH', - os.path.dirname(nccl_hdr_path)) - break - - # Reset and Retry - print( - 'Invalid path to NCCL %s toolkit, %s or %s not found. 
Please use the ' - 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, - nccl_hdr_path)) - - environ_cp['TF_NCCL_VERSION'] = '' - else: - raise UserInputError('Invalid TF_NCCL setting was provided %d ' - 'times in a row. Assuming to be a scripting mistake.' % - _DEFAULT_PROMPT_ASK_ATTEMPTS) - - # Set TF_NCCL_VERSION + '[Leave empty to use http://github.com/nvidia/nccl]: ') + tf_nccl_version = get_from_env_or_user_or_default(environ_cp, + 'TF_NCCL_VERSION', + ask_nccl_version, '') environ_cp['TF_NCCL_VERSION'] = tf_nccl_version - write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) - def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1305,11 +1013,14 @@ def set_tf_cuda_compute_capabilities(environ_cp): all_valid = False else: ver = float(m.group(0)) - if ver < 3.5: - print('ERROR: TensorFlow only supports CUDA compute capabilities 3.5 ' + if ver < 3.0: + print('ERROR: TensorFlow only supports CUDA compute capabilities 3.0 ' 'and higher. Please re-specify the list of compute ' 'capabilities excluding version %s.' % ver) all_valid = False + if ver < 3.5: + print('WARNING: XLA does not support CUDA compute capabilities ' + 'lower than 3.5. Disable XLA when running on older GPUs.') if all_valid: break @@ -1328,10 +1039,8 @@ def set_other_cuda_vars(environ_cp): # If CUDA is enabled, always use GPU during build and test. if environ_cp.get('TF_CUDA_CLANG') == '1': write_to_bazelrc('build --config=cuda_clang') - write_to_bazelrc('test --config=cuda_clang') else: write_to_bazelrc('build --config=cuda') - write_to_bazelrc('test --config=cuda') def set_host_cxx_compiler(environ_cp): @@ -1495,15 +1204,16 @@ def set_other_mpi_vars(environ_cp): 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % (mpi_home, mpi_home, mpi_home)) + def system_specific_test_config(env): - """Add default test flags required for TF tests to bazelrc.""" + """Add default build and test flags required for TF tests to bazelrc.""" write_to_bazelrc('test --flaky_test_attempts=3') write_to_bazelrc('test --test_size_filters=small,medium') write_to_bazelrc( 'test --test_tag_filters=-benchmark-test,-no_oss,-oss_serial') write_to_bazelrc('test --build_tag_filters=-benchmark-test,-no_oss') if is_windows(): - if env.get('TF_NEED_CUDA', None) == 1: + if env.get('TF_NEED_CUDA', None) == '1': write_to_bazelrc( 'test --test_tag_filters=-no_windows,-no_windows_gpu,-no_gpu') write_to_bazelrc( @@ -1515,7 +1225,7 @@ def system_specific_test_config(env): write_to_bazelrc('test --test_tag_filters=-gpu,-nomac,-no_mac') write_to_bazelrc('test --build_tag_filters=-gpu,-nomac,-no_mac') elif is_linux(): - if env.get('TF_NEED_CUDA', None) == 1: + if env.get('TF_NEED_CUDA', None) == '1': write_to_bazelrc('test --test_tag_filters=-no_gpu') write_to_bazelrc('test --build_tag_filters=-no_gpu') write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') @@ -1549,7 +1259,8 @@ def set_windows_build_flags(environ_cp): write_to_bazelrc('build --copt=-w --host_copt=-w') # Fix winsock2.h conflicts write_to_bazelrc( - 'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN') + 'build --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN ' + '--copt=-DNOGDI --host_copt=-DNOGDI') # Output more verbose information when something goes wrong write_to_bazelrc('build --verbose_failures') # The host and target platforms are the same in Windows build. 
So we don't @@ -1575,26 +1286,90 @@ def config_info_line(name, help_text): print('\t--config=%-12s\t# %s' % (name, help_text)) -def configure_apple_bazel_rules(): - """Configures Bazel rules for building on Apple platforms. +def configure_ios(): + """Configures TensorFlow for iOS builds. - Enables analyzing and building Apple Bazel rules on Apple platforms. This - function will only be executed if `is_macos()` is true. + This function will only be executed if `is_macos()` is true. """ if not is_macos(): return - for filepath in APPLE_BAZEL_FILES: + if _TF_CURRENT_BAZEL_VERSION is None or _TF_CURRENT_BAZEL_VERSION < 23000: print( - 'Configuring %s file to analyze and build Bazel rules on Apple platforms.' - % filepath) + 'Building Bazel rules on Apple platforms requires Bazel 0.23 or later.') + for filepath in APPLE_BAZEL_FILES: existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple') renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath) - os.rename(existing_filepath, renamed_filepath) + symlink_force(existing_filepath, renamed_filepath) + for filepath in IOS_FILES: + filename = os.path.basename(filepath) + new_filepath = os.path.join(_TF_WORKSPACE_ROOT, filename) + symlink_force(filepath, new_filepath) + + +def validate_cuda_config(environ_cp): + """Run find_cuda_config.py and return cuda_toolkit_path, or None.""" + + def maybe_encode_env(env): + """Encodes unicode in env to str on Windows python 2.x.""" + if not is_windows() or sys.version_info[0] != 2: + return env + for k, v in env.items(): + if isinstance(k, unicode): + k = k.encode('ascii') + if isinstance(v, unicode): + v = v.encode('ascii') + env[k] = v + return env + + cuda_libraries = ['cuda', 'cudnn'] + if is_linux(): + if int(environ_cp.get('TF_NEED_TENSORRT', False)): + cuda_libraries.append('tensorrt') + if environ_cp.get('TF_NCCL_VERSION', None): + cuda_libraries.append('nccl') + + proc = subprocess.Popen( + [environ_cp['PYTHON_BIN_PATH'], 'third_party/gpus/find_cuda_config.py'] + + cuda_libraries, + stdout=subprocess.PIPE, + env=maybe_encode_env(environ_cp)) + + if proc.wait(): + # Errors from find_cuda_config.py were sent to stderr. + print('Asking for detailed CUDA configuration...\n') + return False + + config = dict( + tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout) + + print('Found CUDA %s in:' % config['cuda_version']) + print(' %s' % config['cuda_library_dir']) + print(' %s' % config['cuda_include_dir']) + + print('Found cuDNN %s in:' % config['cudnn_version']) + print(' %s' % config['cudnn_library_dir']) + print(' %s' % config['cudnn_include_dir']) + + if 'tensorrt_version' in config: + print('Found TensorRT %s in:' % config['tensorrt_version']) + print(' %s' % config['tensorrt_library_dir']) + print(' %s' % config['tensorrt_include_dir']) + + if config.get('nccl_version', None): + print('Found NCCL %s in:' % config['nccl_version']) + print(' %s' % config['nccl_library_dir']) + print(' %s' % config['nccl_include_dir']) + + print('\n') + + environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path'] + return True def main(): global _TF_WORKSPACE_ROOT global _TF_BAZELRC + global _TF_CURRENT_BAZEL_VERSION parser = argparse.ArgumentParser() parser.add_argument( @@ -1611,7 +1386,8 @@ def main(): # environment variables. 
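For orientation on the new validate_cuda_config() above: it runs third_party/gpus/find_cuda_config.py for the selected libraries and folds the script's "key: value" stdout lines into a dict before exporting CUDA_TOOLKIT_PATH. A small, self-contained sketch of just that parsing step, assuming a hypothetical parse_cuda_config helper and hand-written sample lines (neither is part of the patch); the main() hunk continues below.

```python
# Illustrative sketch, not part of the patch: fold "key: value" lines such as
# those printed by find_cuda_config.py into a plain dict.
def parse_cuda_config(lines):
    """Turn lines like 'cuda_version: 10.0' into a {key: value} dict."""
    config = {}
    for line in lines:
        key, sep, value = line.rstrip().partition(': ')
        if sep:                       # skip lines without a 'key: value' shape
            config[key] = value
    return config


if __name__ == '__main__':
    sample = [
        'cuda_version: 10.0\n',
        'cuda_library_dir: /usr/local/cuda/lib64\n',
        'cudnn_version: 7\n',
    ]
    print(parse_cuda_config(sample))
    # {'cuda_version': '10.0', 'cuda_library_dir': '/usr/local/cuda/lib64',
    #  'cudnn_version': '7'}
```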
environ_cp = dict(os.environ) - check_bazel_version('0.19.0', '0.23.0') + current_bazel_version = check_bazel_version('0.24.1', '0.25.2') + _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) reset_tf_configure_bazelrc() @@ -1633,7 +1409,7 @@ def main(): if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' else: - environ_cp['TF_CONFIGURE_APPLE_BAZEL_RULES'] = '0' + environ_cp['TF_CONFIGURE_IOS'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at @@ -1666,11 +1442,43 @@ def main(): set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False) if (environ_cp.get('TF_NEED_CUDA') == '1' and 'TF_CUDA_CONFIG_REPO' not in environ_cp): - set_tf_cuda_version(environ_cp) - set_tf_cudnn_version(environ_cp) - if is_linux(): - set_tf_tensorrt_install_path(environ_cp) - set_tf_nccl_install_path(environ_cp) + + set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False) + + environ_save = dict(environ_cp) + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): + + if validate_cuda_config(environ_cp): + cuda_env_names = [ + 'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION', + 'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS', + # Items below are for backwards compatibility when not using + # TF_CUDA_PATHS. + 'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH', + 'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH' + ] + # Note: set_action_env_var above already writes to bazelrc. + for name in cuda_env_names: + if name in environ_cp: + write_action_env_to_bazelrc(name, environ_cp[name]) + break + + # Restore settings changed below if CUDA config could not be validated. + environ_cp = dict(environ_save) + + set_tf_cuda_version(environ_cp) + set_tf_cudnn_version(environ_cp) + if is_linux(): + set_tf_tensorrt_version(environ_cp) + set_tf_nccl_version(environ_cp) + + set_tf_cuda_paths(environ_cp) + + else: + raise UserInputError( + 'Invalid CUDA setting were provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) set_tf_cuda_compute_capabilities(environ_cp) if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( @@ -1688,7 +1496,6 @@ def main(): else: # Use downloaded LLD for linking. write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld') - write_to_bazelrc('test:cuda_clang --config=download_clang_use_lld') else: # Set up which gcc nvcc should use as the host compiler # No need to set this on Windows @@ -1701,7 +1508,6 @@ def main(): set_tf_download_clang(environ_cp) if environ_cp.get('TF_DOWNLOAD_CLANG') == '1': write_to_bazelrc('build --config=download_clang') - write_to_bazelrc('test --config=download_clang') # SYCL / ROCm / CUDA are mutually exclusive. # At most 1 GPU platform can be configured. @@ -1738,13 +1544,9 @@ def main(): system_specific_test_config(os.environ) - if get_var( - environ_cp, 'TF_CONFIGURE_APPLE_BAZEL_RULES', - 'Configure Bazel rules for Apple platforms', False, - ('Would you like to configure Bazel rules for building on Apple platforms?' - ), 'Configuring Bazel rules for Apple platforms.', - 'Not configuring Bazel rules for Apple platforms.'): - configure_apple_bazel_rules() + set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False) + if environ_cp.get('TF_CONFIGURE_IOS') == '1': + configure_ios() print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. 
See .bazelrc for more ' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 8f05e653ecc..a04ddf9f8a1 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -15,6 +15,7 @@ exports_files([ "leakr_file_type_recipe.ftrcp", ]) +load("//tensorflow:tensorflow.bzl", "VERSION") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl") load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") @@ -163,7 +164,7 @@ config_setting( name = "macos", values = { "apple_platform_type": "macos", - "cpu": "darwin_x86_64", + "cpu": "darwin", }, visibility = ["//visibility:public"], ) @@ -183,6 +184,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_aarch64", + values = {"cpu": "aarch64"}, + visibility = ["//visibility:public"], +) + config_setting( name = "linux_x86_64", values = {"cpu": "k8"}, @@ -325,6 +332,18 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "macos_with_framework_shared_object", + define_values = { + "framework_shared_object": "true", + }, + values = { + "apple_platform_type": "macos", + "cpu": "darwin", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "using_cuda_clang", define_values = { @@ -407,9 +426,15 @@ config_setting( values = {"cpu": "x64_windows"}, ) +# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST! +# Instead, please use public APIs or public build rules TF provides. +# If you need functionality that is not exposed, we will work with you to expand our public APIs. package_group( name = "internal", - packages = ["//tensorflow/..."], + packages = [ + "//tensorflow/...", + "//tensorflow_estimator/python/estimator/...", + ], ) load( @@ -467,7 +492,7 @@ cc_library( # projects building with Bazel and importing TensorFlow as a dependency will not # depend on libtensorflow_framework.so unless they opt in. tf_cc_shared_object( - name = "libtensorflow_framework.so", + name = "tensorflow_framework", framework_so = [], linkopts = select({ "//tensorflow:macos": [], @@ -477,8 +502,11 @@ tf_cc_shared_object( ], }), linkstatic = 1, + per_os_targets = True, + soversion = VERSION, visibility = ["//visibility:public"], deps = [ + "//tensorflow/cc/saved_model:loader_lite_impl", "//tensorflow/core:core_cpu_impl", "//tensorflow/core:framework_internal_impl", "//tensorflow/core:gpu_runtime_impl", @@ -508,7 +536,6 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:macos": [ "-Wl,-exported_symbols_list,$(location //tensorflow/c:exported_symbols.lds)", - "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [ ], @@ -518,6 +545,7 @@ tf_cc_shared_object( ], }), per_os_targets = True, + soversion = VERSION, visibility = ["//visibility:public"], # add win_def_file for tensorflow win_def_file = select({ @@ -548,6 +576,7 @@ tf_cc_shared_object( ], }), per_os_targets = True, + soversion = VERSION, visibility = ["//visibility:public"], # add win_def_file for tensorflow_cc win_def_file = select({ diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 7bd6b722398..feaf805f684 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -12,7 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Bring in all of the public TensorFlow interface into this module.""" +""" +Top-level module of TensorFlow. By convention, we refer to this module as +`tf` instead of `tensorflow`, following the common practice of importing +TensorFlow via the command `import tensorflow as tf`. + +The primary function of this module is to import all of the public TensorFlow +interfaces into a single place. The interfaces themselves are located in +sub-modules, as described below. + +Note that the file `__init__.py` in the TensorFlow source code tree is actually +only a placeholder to enable test cases to run. The TensorFlow build replaces +this file with a file generated from [`api_template.__init__.py`](https://www.github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py) +""" from __future__ import absolute_import as _absolute_import from __future__ import division as _division @@ -20,10 +32,13 @@ from __future__ import print_function as _print_function import distutils as _distutils import inspect as _inspect +import logging as _logging import os as _os import site as _site import sys as _sys +from tensorflow.python.tools import module_util as _module_util + # API IMPORTS PLACEHOLDER # Make sure directory containing top level submodules is in @@ -37,25 +52,29 @@ if not hasattr(_current_module, '__path__'): elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) -# pylint: disable=g-bad-import-order -from tensorflow.python.tools import component_api_helper as _component_api_helper -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=('tensorboard.summary._tf.summary'), - error_msg="Limited tf.summary API due to missing TensorBoard installation") -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=( - 'tensorflow_estimator.python.estimator.api._v2.estimator')) +# Hook external TensorFlow modules. 
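The comment above introduces the try/except hooks that follow: each one imports an externally packaged module (TensorBoard's summary API, the estimator, Keras) and prepends its parent directory to this package's __path__, so that `import tensorflow.estimator` and friends resolve even though the code ships in a separate distribution. A toy sketch of why extending __path__ has that effect, using made-up names ('pkg', 'sub') and a temp directory rather than anything from TensorFlow:

```python
# Illustrative sketch, not part of the patch: graft an external directory onto
# a package's __path__ so one of its submodules can live elsewhere.
import importlib
import os
import sys
import tempfile
import types

# Fake top-level package with no submodules of its own.
pkg = types.ModuleType('pkg')
pkg.__path__ = []                 # acts like an empty package search path
sys.modules['pkg'] = pkg

# An "external" distribution that ships pkg's submodule somewhere else.
external_dir = os.path.join(tempfile.mkdtemp(), 'pkg')
os.makedirs(external_dir)
with open(os.path.join(external_dir, 'sub.py'), 'w') as f:
    f.write('VALUE = 42\n')

# The hook: prepend the external location to the package's search path.
pkg.__path__ = [external_dir] + pkg.__path__
importlib.invalidate_caches()

from pkg import sub               # now resolves to external_dir/sub.py
print(sub.VALUE)                  # 42
```

Guarding each real hook in a try/except ImportError, as the lines below do, keeps the base package importable when the optional distribution is absent.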
+try: + from tensorboard.summary._tf import summary + _current_module.__path__ = ( + [_module_util.get_parent_dir(summary)] + _current_module.__path__) +except ImportError: + _logging.warning( + "Limited tf.summary API due to missing TensorBoard installation.") + +try: + from tensorflow_estimator.python.estimator.api._v2 import estimator + _current_module.__path__ = ( + [_module_util.get_parent_dir(estimator)] + _current_module.__path__) +except ImportError: + pass + +try: + from tensorflow.python.keras.api._v2 import keras + _current_module.__path__ = ( + [_module_util.get_parent_dir(keras)] + _current_module.__path__) +except ImportError: + pass -if not hasattr(_current_module, 'estimator'): - _component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=( - 'tensorflow_estimator.python.estimator.api.estimator')) -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=('tensorflow.python.keras.api._v2.keras')) # Enable TF2 behaviors from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 5eb25a81b7f..a83ff3a16c2 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -26,30 +26,44 @@ import sys as _sys # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +from tensorflow.python.tools import module_util as _module_util # API IMPORTS PLACEHOLDER -from tensorflow.python.tools import component_api_helper as _component_api_helper -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=( - 'tensorflow_estimator.python.estimator.api._v1.estimator')) - +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +# We're using bitwise, but there's nothing special about that. +_API_MODULE = bitwise # pylint: disable=undefined-variable _current_module = _sys.modules[__name__] -if not hasattr(_current_module, 'estimator'): - _component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=( - 'tensorflow_estimator.python.estimator.api.estimator')) -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=('tensorflow.python.keras.api._v1.keras')) +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + +# Hook external TensorFlow modules. +try: + from tensorflow_estimator.python.estimator.api._v1 import estimator + _current_module.__path__ = ( + [_module_util.get_parent_dir(estimator)] + _current_module.__path__) +except ImportError: + pass + +try: + from tensorflow.python.keras.api._v1 import keras + _current_module.__path__ = ( + [_module_util.get_parent_dir(keras)] + _current_module.__path__) +except ImportError: + pass + + from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top _CONTRIB_WARNING = """ -WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0. +The TensorFlow contrib module will not be included in TensorFlow 2.0. 
For more information, please see: * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md * https://github.com/tensorflow/addons + * https://github.com/tensorflow/io (for I/O related ops) If you depend on functionality not listed there, please file an issue. """ contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib', @@ -65,17 +79,6 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at- # The 'app' module will be imported as part of the placeholder section above. app.flags = flags # pylint: disable=undefined-variable -# Also use 'app' module (choice is arbitrary) to derive the API directory below. -_API_MODULE = app # pylint: disable=undefined-variable - -# Make sure directory containing top level submodules is in -# the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) -if not hasattr(_current_module, '__path__'): - __path__ = [_tf_api_dir] -elif _tf_api_dir not in __path__: - __path__.append(_tf_api_dir) - # Load all plugin libraries from site-packages/tensorflow-plugins if we are # running under pip. # TODO(gunan): Enable setting an environment variable to define arbitrary plugin @@ -117,7 +120,11 @@ if _running_from_pip_package(): # pylint: disable=undefined-variable try: del python + if '__all__' in vars(): + vars()['__all__'].remove('python') del core + if '__all__' in vars(): + vars()['__all__'].remove('core') except NameError: # Don't fail if these modules are not available. # For e.g. this file will be originally placed under tensorflow/_api/v1 which @@ -128,6 +135,8 @@ except NameError: # others don't exist. try: del compiler + if '__all__' in vars(): + vars()['__all__'].remove('compiler') except NameError: pass # pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 3c43467b510..f2ca79f57fc 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -21,6 +21,7 @@ filegroup( srcs = [ "c_api.h", "c_api_experimental.h", + "tf_attrtype.h", ], visibility = ["//tensorflow:__subpackages__"], ) @@ -39,14 +40,19 @@ filegroup( "python_api.h", "*test*", ], - ), + ) + [ + "//tensorflow/cc:srcs", + "//tensorflow/core/distributed_runtime:server_lib.h", + ], visibility = ["//visibility:public"], ) tf_cuda_library( name = "c_api_internal", - srcs = ["c_api.h"], - hdrs = ["c_api_internal.h"], + hdrs = [ + "c_api.h", + "c_api_internal.h", + ], visibility = [ "//tensorflow:internal", "//tensorflow/c:__subpackages__", @@ -56,6 +62,7 @@ tf_cuda_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + ":tf_attrtype", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -66,14 +73,24 @@ tf_cuda_library( }), ) +cc_library( + name = "tf_attrtype", + hdrs = ["tf_attrtype.h"], + visibility = ["//visibility:public"], +) + tf_cuda_library( name = "c_api", - hdrs = ["c_api.h"], + hdrs = [ + "c_api.h", + "tf_attrtype.h", + ], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ ":c_api_no_xla", ":c_api_internal", + ":tf_attrtype", ] + select({ "//tensorflow:with_xla_support": [ "//tensorflow/compiler/tf2xla:xla_compiler", @@ -89,16 +106,18 @@ tf_cuda_library( "c_api.cc", "c_api_function.cc", ], - hdrs = [ - "c_api.h", - ], + hdrs = ["c_api.h"], copts = tf_copts(), visibility = ["//tensorflow/c:__subpackages__"], - deps = [":c_api_internal"] + select({ + deps = [ + ":c_api_internal", + ":tf_attrtype", + ] + select({ "//tensorflow:android": [ 
"//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + "@com_google_absl//absl/strings", "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", @@ -140,19 +159,11 @@ tf_cuda_library( "//tensorflow/core:lib_platform", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "@com_google_absl//absl/strings", ], ) -cc_library( - name = "c_api_headers", - hdrs = [ - "c_api.h", - ], - copts = tf_copts(), - visibility = ["//tensorflow:__subpackages__"], -) - exports_files( [ "version_script.lds", @@ -238,6 +249,28 @@ tf_cuda_library( }), ) +tf_cuda_library( + name = "ops", + srcs = [ + "ops.cc", + ], + hdrs = [ + "ops.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":tf_status_helper", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }) + [":c_api_internal"], +) + # ----------------------------------------------------------------------------- # Tests @@ -286,7 +319,6 @@ tf_cuda_cc_test( "//conditions:default": [], }), tags = [ - "no_oss", # http://b/119522529 "noasan", ], # We must ensure that the dependencies can be dynamically linked since @@ -440,6 +472,27 @@ tf_cuda_cc_test( ], ) +tf_cc_test( + name = "ops_test", + size = "small", + srcs = ["ops_test.cc"], + linkopts = select({ + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + # ----------------------------------------------------------------------------- # Python API target diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index af93d91b94c..21d72ac96b5 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -30,8 +30,8 @@ limitations under the License. 
#include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/framework/logging.h" #include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/kernels/logging_ops.h" #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -368,7 +368,7 @@ static Status TF_StringDecode_Impl(const char* src, size_t src_len, size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, size_t* dst_len, TF_Status* status) { status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len); - if (!status->status.ok()) return 0; + if (TF_GetCode(status) != TF_OK) return 0; return static_cast(*dst - src) + *dst_len; } @@ -423,7 +423,7 @@ TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, TF_Status* status) { Session* session; status->status = NewSession(opt->options, &session); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { return new TF_DeprecatedSession({session}); } else { DCHECK_EQ(nullptr, session); @@ -615,7 +615,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, offsets++; const string& s = srcarray(i); size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { status->status = InvalidArgument( "invalid string tensor encoding (string #", i, " of ", srcarray.size(), "): ", status->status.error_message()); @@ -775,7 +775,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { // TODO(nolivia): check this on a subset of the graph instead of all of // it. status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { session->graph->mu.unlock(); return false; } @@ -795,7 +795,7 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { *graph_def.mutable_library() = graph.flib_def().ToProto(); session->graph->mu.unlock(); status->status = session->session->Extend(graph_def); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { // Contract is we always delete input_values[i]. return false; } @@ -825,7 +825,7 @@ static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, const int ninputs = input_pairs->size(); for (int i = 0; i < ninputs; ++i) { status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); - if (!status->status.ok()) return false; + if (TF_GetCode(status) != TF_OK) return false; } return true; } @@ -863,7 +863,7 @@ static void TF_Run_Helper( // Serialize back to upstream client, who now owns the new buffer if (run_metadata != nullptr) { status->status = MessageToBuffer(run_metadata_proto, run_metadata); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; } } else { // NOTE(zongheng): PRun does not support RunOptions yet. 
@@ -883,7 +883,7 @@ static void TF_Run_Helper( continue; } c_outputs[i] = TF_TensorFromTensor(src, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; } } @@ -940,7 +940,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, string new_handle; status->status = s->session->PRunSetup(input_names, output_names, target_oper_names, &new_handle); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; @@ -979,7 +979,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { status->status = tensorflow::LoadLibrary( library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, &lib_handle->op_list.length); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { delete lib_handle; return nullptr; } @@ -1009,7 +1009,7 @@ TF_Buffer* TF_GetAllOpList() { // -------------------------------------------------------------------------- // ListDevices & SessionListDevices API -void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; } +void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; } TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) { TF_DeviceList* response = new TF_DeviceList; @@ -1407,7 +1407,7 @@ void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, TF_Tensor* value, TF_Status* status) { Tensor t; status->status = TF_TensorToTensor(value, &t); - if (status->status.ok()) desc->node_builder.Attr(attr_name, t); + if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t); } void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, @@ -1417,13 +1417,13 @@ void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, std::vector t; t.reserve(num_values); - for (int i = 0; i < num_values && status->status.ok(); ++i) { + for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) { Tensor v; status->status = TF_TensorToTensor(values[i], &v); t.emplace_back(v); } - if (status->status.ok()) desc->node_builder.Attr(attr_name, t); + if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t); } void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, @@ -1471,11 +1471,11 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, } status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { // Run shape inference function for newly added node. status->status = desc->graph->refiner.AddNode(ret); } - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { // Add the node to the name-to-node mapping. 
desc->graph->name_map[ret->name()] = ret; } else if (ret != nullptr) { @@ -1524,10 +1524,10 @@ int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, NameRangeMap name_ranges; status->status = NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); - if (!status->status.ok()) return -1; + if (TF_GetCode(status) != TF_OK) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { - status->status = InvalidArgument("Input arg '", arg_name, "' not found"); + status->status = InvalidArgument("Output arg '", arg_name, "' not found"); return -1; } return iter->second.second - iter->second.first; @@ -1546,7 +1546,7 @@ int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, NameRangeMap name_ranges; status->status = NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); - if (!status->status.ok()) return -1; + if (TF_GetCode(status) != TF_OK) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { status->status = InvalidArgument("Input arg '", arg_name, "' not found"); @@ -1644,7 +1644,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, TF_Status* status) { TF_AttrMetadata metadata; const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return metadata; + if (TF_GetCode(status) != TF_OK) return metadata; switch (attr->value_case()) { #define SINGLE_CASE(kK, attr_type, size_expr) \ case tensorflow::AttrValue::kK: \ @@ -1751,7 +1751,7 @@ void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, void* value, size_t max_length, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; if (attr->value_case() != tensorflow::AttrValue::kS) { status->status = InvalidArgument("Attribute '", attr_name, "' is not a string"); @@ -1769,7 +1769,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, int max_values, void* storage, size_t storage_size, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; if (attr->value_case() != tensorflow::AttrValue::kList) { status->status = InvalidArgument("Value for '", attr_name, "' is not a list"); @@ -1802,7 +1802,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ int max_values, TF_Status* status) { \ const auto* attr = GetAttrValue(oper, attr_name, status); \ - if (!status->status.ok()) return; \ + if (TF_GetCode(status) != TF_OK) return; \ if (attr->value_case() != tensorflow::AttrValue::kList) { \ status->status = \ InvalidArgument("Value for '", attr_name, "' is not a list."); \ @@ -1824,7 +1824,7 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, PartialTensorShape shape; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; auto len = std::min(shape.dims(), num_dims); for (int i = 0; i < len; ++i) { value[i] = shape.dim_size(i); @@ -1832,21 +1832,21 @@ void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, } void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, - int64_t** values, int* num_dims, - int max_values, int64_t* storage, - int storage_size, TF_Status* status) { 
+ int64_t** dims, int* num_dims, int num_shapes, + int64_t* storage, int storage_size, + TF_Status* status) { std::vector shapes; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); - if (!status->status.ok()) return; - auto len = std::min(static_cast(shapes.size()), max_values); + if (TF_GetCode(status) != TF_OK) return; + auto len = std::min(static_cast(shapes.size()), num_shapes); int64_t* p = storage; int storage_left = storage_size; for (int i = 0; i < len; ++i) { // shapes[i].dims() == -1 for shapes with an unknown rank. int64_t n = shapes[i].dims(); num_dims[i] = n; - values[i] = p; + dims[i] = p; if (n < 0) { continue; } @@ -1866,7 +1866,7 @@ void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, const char* attr_name, TF_Buffer* value, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; if (attr->value_case() != tensorflow::AttrValue::kShape) { status->status = InvalidArgument("Value for '", attr_name, "' is not a shape."); @@ -1880,7 +1880,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, TF_Buffer** values, int max_values, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; if (attr->value_case() != tensorflow::AttrValue::kList) { status->status = InvalidArgument("Value for '", attr_name, "' is not a list"); @@ -1890,7 +1890,7 @@ void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, for (int i = 0; i < len; ++i) { values[i] = TF_NewBuffer(); status->status = MessageToBuffer(attr->list().shape(i), values[i]); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { // Delete everything allocated to far, the operation has failed. for (int j = 0; j <= i; ++j) { TF_DeleteBuffer(values[j]); @@ -1905,7 +1905,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, *value = nullptr; Tensor t; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; *value = TF_TensorFromTensor(t, status); } @@ -1914,7 +1914,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, TF_Status* status) { std::vector ts; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; const auto len = std::min(max_values, static_cast(ts.size())); for (int i = 0; i < len; ++i) { values[i] = TF_TensorFromTensor(ts[i], status); @@ -1925,7 +1925,7 @@ void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, TF_Status* status) { const auto* attr = GetAttrValue(oper, attr_name, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; status->status = MessageToBuffer(*attr, output_attr_value); } @@ -1941,7 +1941,10 @@ TF_Graph::TF_Graph() refiner(graph.versions().producer(), graph.op_registry()), delete_requested(false), parent(nullptr), - parent_inputs(nullptr) {} + parent_inputs(nullptr) { + // Tell the shape refiner to also run shape inference on functions. 
+ refiner.set_function_library_for_shape_inference(&graph.flib_def()); +} TF_Graph* TF_NewGraph() { return new TF_Graph; } @@ -2003,7 +2006,7 @@ void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, { mutex_lock l(graph->mu); status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; } status->status = MessageToBuffer(*op_def, output_op_def); } @@ -2121,7 +2124,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, tensorflow::ImportGraphDefResults results; status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, &graph->refiner, &results); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; // Add new nodes to name_map for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { @@ -2175,7 +2178,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( auto results = new TF_ImportGraphDefResults(); mutex_lock l(graph->mu); GraphImportGraphDefLocked(graph, def, options, results, status); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { delete results; return nullptr; } @@ -2233,7 +2236,7 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name, TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input)); // TODO(skyewm): set placeholder shape TF_Operation* oper = TF_FinishOperation(desc, status); - if (!status->status.ok()) return false; + if (TF_GetCode(status) != TF_OK) return false; *input = {oper, 0}; return true; } @@ -2378,7 +2381,7 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs, TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output, body_graph, body_inputs, body_outputs, name}; - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { FreeWhileResources(¶ms); return EmptyWhileParams(); } @@ -2582,7 +2585,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { Session* session; status->status = NewSession(opt->options, &session); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); @@ -2630,7 +2633,7 @@ TF_Session* TF_LoadSessionFromSavedModel( status->status = tensorflow::LoadSavedModel(session_options->options, run_options_proto, export_dir, tag_set, &bundle); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session // extends using GraphDefs. 
The Graph instance is different, but equivalent @@ -2647,7 +2650,7 @@ TF_Session* TF_LoadSessionFromSavedModel( if (meta_graph_def != nullptr) { status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; } TF_Session* session = new TF_Session(bundle.session.release(), graph); @@ -2747,7 +2750,7 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, string new_handle; status->status = session->session->PRunSetup(input_names, output_names, target_names, &new_handle); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; @@ -2809,9 +2812,9 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, tensor, graph->refiner, *graph->graph.op_registry(), graph->graph.versions().producer(), &evaluated, &result_tensor); if (evaluated) { - DCHECK(status->status.ok()); + DCHECK(TF_GetCode(status) == TF_OK); *result = TF_TensorFromTensor(result_tensor, status); - if (!status->status.ok()) evaluated = false; + if (TF_GetCode(status) != TF_OK) evaluated = false; } return evaluated; } @@ -2866,7 +2869,7 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(*api_def, ret); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { TF_DeleteBuffer(ret); return nullptr; } @@ -2878,7 +2881,7 @@ TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(kernel_list, ret); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { TF_DeleteBuffer(ret); return nullptr; } @@ -2890,7 +2893,7 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { tensorflow::GetRegisteredKernelsForOp(name); TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(kernel_list, ret); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { TF_DeleteBuffer(ret); return nullptr; } @@ -2920,7 +2923,7 @@ TF_Server* TF_NewServer(const void* proto, size_t proto_len, std::unique_ptr out_server; status->status = tensorflow::NewServer(server_def, &out_server); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; return new TF_Server(std::move(out_server)); #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 051de3a7dc0..c074e5d3629 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "tensorflow/c/tf_attrtype.h" + // -------------------------------------------------------------------------- // C API for TensorFlow. // @@ -686,19 +688,6 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( TF_Operation* oper, TF_Operation** control_outputs, int max_control_outputs); -// TF_AttrType describes the type of the value of an attribute on an operation. -typedef enum TF_AttrType { - TF_ATTR_STRING = 0, - TF_ATTR_INT = 1, - TF_ATTR_FLOAT = 2, - TF_ATTR_BOOL = 3, - TF_ATTR_TYPE = 4, - TF_ATTR_SHAPE = 5, - TF_ATTR_TENSOR = 6, - TF_ATTR_PLACEHOLDER = 7, - TF_ATTR_FUNC = 8, -} TF_AttrType; - // TF_AttrMetadata describes the value of an attribute on an operation. 
typedef struct TF_AttrMetadata { // A boolean: 1 if the attribute value is a list, 0 otherwise. diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 7ff4084decc..726ce2784ae 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -66,6 +67,24 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { } } +unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) { + tensorflow::BuildXlaOpsPassFlags* flags = + tensorflow::GetBuildXlaOpsPassFlags(); + bool original = flags->tf_xla_enable_lazy_compilation; + flags->tf_xla_enable_lazy_compilation = enable; + return original; +} + +void TF_SetXLaAutoJitMode(const char* mode) { + tensorflow::SetXlaAutoJitFlagFromFlagString(mode); +} + +void TF_SetXlaMinClusterSize(int size) { + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); + flags->tf_xla_min_cluster_size = size; +} + TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth, unsigned int num_cpu_devices) { @@ -177,8269 +196,6 @@ static std::vector CreateFunctionsFromTextProto( return ret; } -// On success, returns a newly created TF_Function instance encoding a dataset -// node stack that returns a sequence of 3 floats, and sets `dataset_name` to -// the created dataset name. The returned function must be deleted by calling -// TF_DeleteFunction. -static UniqueFuncPtr CreateFakeDatasetFunction(std::string* dataset_name, - TF_Status* status) { - const char* func_def = R"PREFIX( -library { - function { - signature { - name: "_make_dataset_d8de2712" - output_arg { - name: "TensorSliceDataset" - type: DT_VARIANT - } - is_stateful: true - } - node_def { - name: "TensorSliceDataset/tensors/component_0" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000(B\000\000,B\000\0000B" - } - } - } - } - node_def { - name: "TensorSliceDataset" - op: "TensorSliceDataset" - input: "TensorSliceDataset/tensors/component_0:output:0" - attr { - key: "Toutput_types" - value { - list { - type: DT_FLOAT - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - } - ret { - key: "TensorSliceDataset" - value: "TensorSliceDataset:handle:0" - } - } -} -)PREFIX"; - - *dataset_name = "_make_dataset_d8de2712"; - auto functions = CreateFunctionsFromTextProto( - func_def, /*mutate_proto_func*/ nullptr, status); - DCHECK_EQ(functions.size(), 1); - return std::move(functions[0]); -} - -#if not defined(PLATFORM_WINDOWS) -// On success, returns a set of TF_Function instances encoding a dataset -// node stack that reads a Imagenet TFRecordFile dataset from `file_path`, and -// sets `dataset_name` to the created dataset name. The returned functions must -// be deleted by calling TF_DeleteFunction. 
-static std::vector CreateImagenetDatasetFunctions( - const char* file_path, std::string* dataset_name, TF_Status* status) { -#if defined(PLATFORM_WINDOWS) - status->status = tensorflow::errors::Unimplemented( - "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " - "is not implemented for Windows"); - return std::vector(); -#else - const char* func_def = R"PREFIX( -library { - function { - signature { - name: "tf_map_func_91295dea" - input_arg { - name: "arg0" - type: DT_STRING - } - output_arg { - name: "FlatMapDataset" - type: DT_VARIANT - } - description: "A wrapper for Defun that facilitates shape inference." - is_stateful: true - } - node_def { - name: "flat_filenames/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } - } - node_def { - name: "flat_filenames" - op: "Reshape" - input: "arg0" - input: "flat_filenames/shape:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "TensorSliceDataset" - op: "TensorSliceDataset" - input: "flat_filenames:output:0" - attr { - key: "Toutput_types" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - } - node_def { - name: "FlatMapDataset" - op: "FlatMapDataset" - input: "TensorSliceDataset:handle:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "f" - value { - func { - name: "tf_map_func_0cc8c35b" - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_STRING - } - } - } - } - ret { - key: "FlatMapDataset" - value: "FlatMapDataset:handle:0" - } - } - function { - signature { - name: "tf_map_func_0cc8c35b" - input_arg { - name: "arg0" - type: DT_STRING - } - output_arg { - name: "TFRecordDataset" - type: DT_VARIANT - } - description: "A wrapper for Defun that facilitates shape inference." - is_stateful: true - } - node_def { - name: "compression_type" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "" - } - } - } - } - node_def { - name: "buffer_size" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 8388608 - } - } - } - } - node_def { - name: "TFRecordDataset" - op: "TFRecordDataset" - input: "arg0" - input: "compression_type:output:0" - input: "buffer_size:output:0" - } - ret { - key: "TFRecordDataset" - value: "TFRecordDataset:handle:0" - } - } - function { - signature { - name: "tf_map_func_74b6b15c" - input_arg { - name: "arg0" - type: DT_STRING - } - output_arg { - name: "Reshape_1" - type: DT_FLOAT - } - output_arg { - name: "sub_1" - type: DT_INT32 - } - description: "A wrapper for Defun that facilitates shape inference." 
- is_stateful: true - } - node_def { - name: "ParseSingleExample/key_image/class/label" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } - } - node_def { - name: "ParseSingleExample/Reshape/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node_def { - name: "ParseSingleExample/Reshape" - op: "Reshape" - input: "ParseSingleExample/key_image/class/label:output:0" - input: "ParseSingleExample/Reshape/shape:output:0" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "ParseSingleExample/key_image/class/text" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "" - } - } - } - } - node_def { - name: "ParseSingleExample/Reshape_1/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node_def { - name: "ParseSingleExample/Reshape_1" - op: "Reshape" - input: "ParseSingleExample/key_image/class/text:output:0" - input: "ParseSingleExample/Reshape_1/shape:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "ParseSingleExample/key_image/encoded" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "" - } - } - } - } - node_def { - name: "ParseSingleExample/Reshape_2/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node_def { - name: "ParseSingleExample/Reshape_2" - op: "Reshape" - input: "ParseSingleExample/key_image/encoded:output:0" - input: "ParseSingleExample/Reshape_2/shape:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "ParseSingleExample/key_image/format" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "jpeg" - } - } - } - } - node_def { - name: "ParseSingleExample/Reshape_3/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node_def { - name: "ParseSingleExample/Reshape_3" - op: "Reshape" - input: "ParseSingleExample/key_image/format:output:0" - input: "ParseSingleExample/Reshape_3/shape:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "ParseSingleExample/ParseSingleExample" - op: "ParseSingleExample" - input: "arg0" - input: "ParseSingleExample/Reshape:output:0" - input: "ParseSingleExample/Reshape_1:output:0" - input: "ParseSingleExample/Reshape_2:output:0" - input: "ParseSingleExample/Reshape_3:output:0" - attr { - key: "Tdense" - value { - list { - 
type: DT_INT64 - type: DT_STRING - type: DT_STRING - type: DT_STRING - } - } - } - attr { - key: "dense_keys" - value { - list { - s: "image/class/label" - s: "image/class/text" - s: "image/encoded" - s: "image/format" - } - } - } - attr { - key: "dense_shapes" - value { - list { - shape { - } - shape { - } - shape { - } - shape { - } - } - } - } - attr { - key: "num_sparse" - value { - i: 5 - } - } - attr { - key: "sparse_keys" - value { - list { - s: "image/object/bbox/xmax" - s: "image/object/bbox/xmin" - s: "image/object/bbox/ymax" - s: "image/object/bbox/ymin" - s: "image/object/class/label" - } - } - } - attr { - key: "sparse_types" - value { - list { - type: DT_FLOAT - type: DT_FLOAT - type: DT_FLOAT - type: DT_FLOAT - type: DT_INT64 - } - } - } - } - node_def { - name: "Reshape/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node_def { - name: "Reshape" - op: "Reshape" - input: "ParseSingleExample/ParseSingleExample:dense_values:2" - input: "Reshape/shape:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "decode_image/Substr/pos" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node_def { - name: "decode_image/Substr/len" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "decode_image/Substr" - op: "Substr" - input: "Reshape:output:0" - input: "decode_image/Substr/pos:output:0" - input: "decode_image/Substr/len:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "decode_image/is_jpeg/Substr/pos" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node_def { - name: "decode_image/is_jpeg/Substr/len" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "decode_image/is_jpeg/Substr" - op: "Substr" - input: "Reshape:output:0" - input: "decode_image/is_jpeg/Substr/pos:output:0" - input: "decode_image/is_jpeg/Substr/len:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "decode_image/is_jpeg/Equal/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "\377\330\377" - } - } - } - } - node_def { - name: "decode_image/is_jpeg/Equal" - op: "Equal" - input: "decode_image/is_jpeg/Substr:output:0" - input: "decode_image/is_jpeg/Equal/y:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - } - node_def { - name: "decode_image/cond_jpeg/Switch" - op: "Switch" - input: "decode_image/is_jpeg/Equal:z:0" - input: "decode_image/is_jpeg/Equal:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/switch_t" - op: "Identity" - input: "decode_image/cond_jpeg/Switch:output_true:0" - attr { - key: "T" - value { - type: DT_BOOL - 
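The removed graph above starts by parsing a serialized `tf.Example`: the `ParseSingleExample` node declares the dense keys `image/class/label`, `image/class/text`, `image/encoded`, and `image/format` (with the surrounding `Const`/`Reshape` nodes supplying their scalar defaults) plus five sparse bounding-box keys. A minimal sketch of an equivalent feature spec is below; it assumes the TF 2.x `tf.io` API, and the helper name `parse_example_proto` is illustrative rather than taken from the removed file.

```python
import tensorflow as tf

# Hypothetical helper mirroring the dense/sparse keys and defaults encoded in
# the ParseSingleExample node above; not the original implementation.
def parse_example_proto(serialized_example):
    feature_map = {
        "image/encoded": tf.io.FixedLenFeature([], tf.string, default_value=""),
        "image/format": tf.io.FixedLenFeature([], tf.string, default_value="jpeg"),
        "image/class/label": tf.io.FixedLenFeature([], tf.int64, default_value=-1),
        "image/class/text": tf.io.FixedLenFeature([], tf.string, default_value=""),
        # Sparse keys listed in the node's sparse_keys/sparse_types attrs.
        "image/object/bbox/xmin": tf.io.VarLenFeature(tf.float32),
        "image/object/bbox/ymin": tf.io.VarLenFeature(tf.float32),
        "image/object/bbox/xmax": tf.io.VarLenFeature(tf.float32),
        "image/object/bbox/ymax": tf.io.VarLenFeature(tf.float32),
        "image/object/class/label": tf.io.VarLenFeature(tf.int64),
    }
    features = tf.io.parse_single_example(serialized_example, feature_map)
    return features["image/encoded"], features["image/class/label"]
```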
} - } - } - node_def { - name: "decode_image/cond_jpeg/switch_f" - op: "Identity" - input: "decode_image/cond_jpeg/Switch:output_false:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/pred_id" - op: "Identity" - input: "decode_image/is_jpeg/Equal:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/check_jpeg_channels/x" - op: "Const" - input: "^decode_image/cond_jpeg/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/check_jpeg_channels/y" - op: "Const" - input: "^decode_image/cond_jpeg/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/check_jpeg_channels" - op: "NotEqual" - input: "decode_image/cond_jpeg/check_jpeg_channels/x:output:0" - input: "decode_image/cond_jpeg/check_jpeg_channels/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "decode_image/cond_jpeg/Assert/Const" - op: "Const" - input: "^decode_image/cond_jpeg/switch_t" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/Assert/Assert/data_0" - op: "Const" - input: "^decode_image/cond_jpeg/switch_t" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/Assert/Assert" - op: "Assert" - input: "decode_image/cond_jpeg/check_jpeg_channels:z:0" - input: "decode_image/cond_jpeg/Assert/Assert/data_0:output:0" - attr { - key: "T" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "summarize" - value { - i: 3 - } - } - } - node_def { - name: "decode_image/cond_jpeg/DecodeJpeg" - op: "DecodeJpeg" - input: "decode_image/cond_jpeg/DecodeJpeg/Switch:output_true:0" - input: "^decode_image/cond_jpeg/Assert/Assert" - attr { - key: "acceptable_fraction" - value { - f: 1.0 - } - } - attr { - key: "channels" - value { - i: 3 - } - } - attr { - key: "dct_method" - value { - s: "" - } - } - attr { - key: "fancy_upscaling" - value { - b: true - } - } - attr { - key: "ratio" - value { - i: 1 - } - } - attr { - key: "try_recover_truncated" - value { - b: false - } - } - } - node_def { - name: "decode_image/cond_jpeg/DecodeJpeg/Switch" - op: "Switch" - input: "Reshape:output:0" - input: "decode_image/cond_jpeg/pred_id:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Reshape" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/is_png/y" - op: "Const" - input: "^decode_image/cond_jpeg/switch_f" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "\211PN" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/is_png" - op: "Equal" - input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" - input: 
"decode_image/cond_jpeg/is_png/y:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - } - node_def { - name: "decode_image/cond_jpeg/is_png/Switch" - op: "Switch" - input: "decode_image/Substr:output:0" - input: "decode_image/cond_jpeg/pred_id:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decode_image/Substr" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/Switch" - op: "Switch" - input: "decode_image/cond_jpeg/is_png:z:0" - input: "decode_image/cond_jpeg/is_png:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/switch_t" - op: "Identity" - input: "decode_image/cond_jpeg/cond_png/Switch:output_true:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/switch_f" - op: "Identity" - input: "decode_image/cond_jpeg/cond_png/Switch:output_false:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/pred_id" - op: "Identity" - input: "decode_image/cond_jpeg/is_png:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/DecodePng" - op: "DecodePng" - input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1:output_true:0" - attr { - key: "channels" - value { - i: 3 - } - } - attr { - key: "dtype" - value { - type: DT_UINT8 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch" - op: "Switch" - input: "Reshape:output:0" - input: "decode_image/cond_jpeg/pred_id:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Reshape" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1" - op: "Switch" - input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" - input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Reshape" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/is_gif/y" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/switch_f" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "GIF" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/is_gif" - op: "Equal" - input: "decode_image/cond_jpeg/cond_png/is_gif/Switch:output_false:0" - input: "decode_image/cond_jpeg/cond_png/is_gif/y:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/is_gif/Switch" - op: "Switch" - input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" - input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@decode_image/Substr" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Switch" - op: "Switch" - input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" - input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_t" - op: "Identity" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_true:0" - attr { - key: "T" - value { - 
type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - op: "Identity" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_false:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id" - op: "Identity" - input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels" - op: "NotEqual" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x:output:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 4 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1" - op: "NotEqual" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x:output:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd" - op: "LogicalAnd" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels:z:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1:z:0" - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Const" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Channels must be in (None, 0, 3) when decoding GIF images" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Channels must be in (None, 0, 3) when decoding GIF images" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" - op: 
"Assert" - input: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd:z:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0:output:0" - attr { - key: "T" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "summarize" - value { - i: 3 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif" - op: "DecodeGif" - input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1:output_true:0" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch" - op: "Switch" - input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" - input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Reshape" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1" - op: "Switch" - input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Reshape" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr" - op: "Substr" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos:output:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch" - op: "Switch" - input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Reshape" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "BM" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp" - op: "Equal" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr:output:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y:output:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Const" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor 
{ - dtype: DT_STRING - tensor_shape { - } - string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" - op: "Assert" - input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp:z:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0:output:0" - attr { - key: "T" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "summarize" - value { - i: 3 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels" - op: "NotEqual" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x:output:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Const" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Channels must be in (None, 0, 3) when decoding BMP images" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0" - op: "Const" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Channels must be in (None, 0, 3) when decoding BMP images" - } - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" - op: "Assert" - input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels:z:0" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0:output:0" - attr { - key: "T" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "summarize" - value { - i: 3 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp" - op: "DecodeBmp" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" - input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" - attr { - key: "channels" - value { - i: 0 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/cond_gif/Merge" - op: "Merge" - input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp:image:0" - input: 
"decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif:image:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_UINT8 - } - } - } - node_def { - name: "decode_image/cond_jpeg/cond_png/Merge" - op: "Merge" - input: "decode_image/cond_jpeg/cond_png/cond_gif/Merge:output:0" - input: "decode_image/cond_jpeg/cond_png/DecodePng:image:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_UINT8 - } - } - } - node_def { - name: "decode_image/cond_jpeg/Merge" - op: "Merge" - input: "decode_image/cond_jpeg/cond_png/Merge:output:0" - input: "decode_image/cond_jpeg/DecodeJpeg:image:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_UINT8 - } - } - } - node_def { - name: "convert_image/Cast" - op: "Cast" - input: "decode_image/cond_jpeg/Merge:output:0" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_UINT8 - } - } - } - node_def { - name: "convert_image/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.00392156885937 - } - } - } - } - node_def { - name: "convert_image" - op: "Mul" - input: "convert_image/Cast:y:0" - input: "convert_image/y:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "Const" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 4 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\200?\000\000\200?" - } - } - } - } - node_def { - name: "distorted_bounding_box_crop/Shape" - op: "Shape" - input: "convert_image:z:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node_def { - name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.10000000149 - } - } - } - } - node_def { - name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2" - op: "SampleDistortedBoundingBoxV2" - input: "distorted_bounding_box_crop/Shape:output:0" - input: "Const:output:0" - input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "area_range" - value { - list { - f: 0.0799999982119 - f: 1.0 - } - } - } - attr { - key: "aspect_ratio_range" - value { - list { - f: 0.75 - f: 1.33333337307 - } - } - } - attr { - key: "max_attempts" - value { - i: 1 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } - attr { - key: "use_image_if_no_bounding_boxes" - value { - b: true - } - } - } - node_def { - name: "distorted_bounding_box_crop/Slice" - op: "Slice" - input: "convert_image:z:0" - input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:begin:0" - input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:size:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - 
node_def { - name: "Shape" - op: "Shape" - input: "convert_image:z:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node_def { - name: "Shape_1" - op: "Shape" - input: "distorted_bounding_box_crop/Slice:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node_def { - name: "Equal" - op: "Equal" - input: "Shape:output:0" - input: "Shape_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "Cast" - op: "Cast" - input: "Equal:z:0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_BOOL - } - } - } - node_def { - name: "Const_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "Sum" - op: "Sum" - input: "Cast:y:0" - input: "Const_1:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "keep_dims" - value { - b: false - } - } - } - node_def { - name: "GreaterEqual/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "GreaterEqual" - op: "GreaterEqual" - input: "Sum:output:0" - input: "GreaterEqual/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/Switch" - op: "Switch" - input: "GreaterEqual:z:0" - input: "GreaterEqual:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "cond/switch_t" - op: "Identity" - input: "cond/Switch:output_true:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "cond/switch_f" - op: "Identity" - input: "cond/Switch:output_false:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "cond/pred_id" - op: "Identity" - input: "GreaterEqual:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "cond/Shape" - op: "Shape" - input: "cond/Shape/Switch:output_true:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/Shape/Switch" - op: "Switch" - input: "convert_image:z:0" - input: "cond/pred_id:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@convert_image" - } - } - } - } - node_def { - name: "cond/Cast" - op: "Cast" - input: "cond/Shape:output:0" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/strided_slice/stack" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "cond/strided_slice/stack_1" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice/stack_2" - op: "Const" - 
input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice" - op: "StridedSlice" - input: "cond/Cast:y:0" - input: "cond/strided_slice/stack:output:0" - input: "cond/strided_slice/stack_1:output:0" - input: "cond/strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/strided_slice_1/stack" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_1/stack_1" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } - } - node_def { - name: "cond/strided_slice_1/stack_2" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_1" - op: "StridedSlice" - input: "cond/Cast:y:0" - input: "cond/strided_slice_1/stack:output:0" - input: "cond/strided_slice_1/stack_1:output:0" - input: "cond/strided_slice_1/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/Greater" - op: "Greater" - input: "cond/strided_slice:output:0" - input: "cond/strided_slice_1:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "cond/cond/Switch" - op: "Switch" - input: "cond/Greater:z:0" - input: "cond/Greater:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "cond/cond/switch_t" - op: "Identity" - input: "cond/cond/Switch:output_true:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "cond/cond/switch_f" - op: "Identity" - input: "cond/cond/Switch:output_false:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "cond/cond/pred_id" - op: "Identity" - input: "cond/Greater:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "cond/cond/strided_slice/stack" - op: "Const" - input: "^cond/cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "cond/cond/strided_slice/stack_1" - op: "Const" - input: "^cond/cond/switch_t" - attr 
{ - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/cond/strided_slice/stack_2" - op: "Const" - input: "^cond/cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/cond/strided_slice" - op: "StridedSlice" - input: "cond/cond/strided_slice/Switch:output_true:0" - input: "cond/cond/strided_slice/stack:output:0" - input: "cond/cond/strided_slice/stack_1:output:0" - input: "cond/cond/strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/cond/strided_slice/Switch" - op: "Switch" - input: "cond/Cast:y:0" - input: "cond/cond/pred_id:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@cond/Cast" - } - } - } - } - node_def { - name: "cond/cond/strided_slice_1/stack" - op: "Const" - input: "^cond/cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/cond/strided_slice_1/stack_1" - op: "Const" - input: "^cond/cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } - } - node_def { - name: "cond/cond/strided_slice_1/stack_2" - op: "Const" - input: "^cond/cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/cond/strided_slice_1" - op: "StridedSlice" - input: "cond/cond/strided_slice/Switch:output_true:0" - input: "cond/cond/strided_slice_1/stack:output:0" - input: "cond/cond/strided_slice_1/stack_1:output:0" - input: "cond/cond/strided_slice_1/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/cond/truediv" - op: "RealDiv" - input: "cond/cond/strided_slice:output:0" - input: "cond/cond/strided_slice_1:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "cond/cond/mul/y" - op: "Const" - input: "^cond/cond/switch_t" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 224.0 - } - } - } - } - node_def { - name: "cond/cond/mul" - op: "Mul" - input: "cond/cond/truediv:z:0" - input: 
"cond/cond/mul/y:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "cond/cond/Cast/x/1" - op: "Const" - input: "^cond/cond/switch_t" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 224.0 - } - } - } - } - node_def { - name: "cond/cond/Cast/x" - op: "Pack" - input: "cond/cond/mul:z:0" - input: "cond/cond/Cast/x/1:output:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node_def { - name: "cond/cond/Cast" - op: "Cast" - input: "cond/cond/Cast/x:output:0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "cond/cond/strided_slice_2/stack" - op: "Const" - input: "^cond/cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/cond/strided_slice_2/stack_1" - op: "Const" - input: "^cond/cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } - } - node_def { - name: "cond/cond/strided_slice_2/stack_2" - op: "Const" - input: "^cond/cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/cond/strided_slice_2" - op: "StridedSlice" - input: "cond/cond/strided_slice_2/Switch:output_false:0" - input: "cond/cond/strided_slice_2/stack:output:0" - input: "cond/cond/strided_slice_2/stack_1:output:0" - input: "cond/cond/strided_slice_2/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/cond/strided_slice_2/Switch" - op: "Switch" - input: "cond/Cast:y:0" - input: "cond/cond/pred_id:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@cond/Cast" - } - } - } - } - node_def { - name: "cond/cond/strided_slice_3/stack" - op: "Const" - input: "^cond/cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "cond/cond/strided_slice_3/stack_1" - op: "Const" - input: "^cond/cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/cond/strided_slice_3/stack_2" - op: "Const" - input: "^cond/cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } 
- } - } - node_def { - name: "cond/cond/strided_slice_3" - op: "StridedSlice" - input: "cond/cond/strided_slice_2/Switch:output_false:0" - input: "cond/cond/strided_slice_3/stack:output:0" - input: "cond/cond/strided_slice_3/stack_1:output:0" - input: "cond/cond/strided_slice_3/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/cond/truediv_1" - op: "RealDiv" - input: "cond/cond/strided_slice_2:output:0" - input: "cond/cond/strided_slice_3:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "cond/cond/mul_1/y" - op: "Const" - input: "^cond/cond/switch_f" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 224.0 - } - } - } - } - node_def { - name: "cond/cond/mul_1" - op: "Mul" - input: "cond/cond/truediv_1:z:0" - input: "cond/cond/mul_1/y:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "cond/cond/Cast_1/x/0" - op: "Const" - input: "^cond/cond/switch_f" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 224.0 - } - } - } - } - node_def { - name: "cond/cond/Cast_1/x" - op: "Pack" - input: "cond/cond/Cast_1/x/0:output:0" - input: "cond/cond/mul_1:z:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node_def { - name: "cond/cond/Cast_1" - op: "Cast" - input: "cond/cond/Cast_1/x:output:0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "cond/cond/Merge" - op: "Merge" - input: "cond/cond/Cast_1:y:0" - input: "cond/cond/Cast:y:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/ResizeBicubic/images" - op: "Pack" - input: "cond/Shape/Switch:output_true:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node_def { - name: "cond/ResizeBicubic" - op: "ResizeBicubic" - input: "cond/ResizeBicubic/images:output:0" - input: "cond/cond/Merge:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "align_corners" - value { - b: false - } - } - } - node_def { - name: "cond/strided_slice_2/stack" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "cond/strided_slice_2/stack_1" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_2/stack_2" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - 
type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_2" - op: "StridedSlice" - input: "cond/ResizeBicubic:resized_images:0" - input: "cond/strided_slice_2/stack:output:0" - input: "cond/strided_slice_2/stack_1:output:0" - input: "cond/strided_slice_2/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/Shape_1" - op: "Shape" - input: "cond/strided_slice_2:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/strided_slice_3/stack" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "cond/strided_slice_3/stack_1" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_3/stack_2" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_3" - op: "StridedSlice" - input: "cond/Shape_1:output:0" - input: "cond/strided_slice_3/stack:output:0" - input: "cond/strided_slice_3/stack_1:output:0" - input: "cond/strided_slice_3/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/Shape_2" - op: "Shape" - input: "cond/strided_slice_2:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/strided_slice_4/stack" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_4/stack_1" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } - } - node_def { - name: "cond/strided_slice_4/stack_2" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - 
} - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_4" - op: "StridedSlice" - input: "cond/Shape_2:output:0" - input: "cond/strided_slice_4/stack:output:0" - input: "cond/strided_slice_4/stack_1:output:0" - input: "cond/strided_slice_4/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/sub/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 224 - } - } - } - } - node_def { - name: "cond/sub" - op: "Sub" - input: "cond/strided_slice_3:output:0" - input: "cond/sub/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/add/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/add" - op: "Add" - input: "cond/sub:z:0" - input: "cond/add/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/truediv/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } - } - node_def { - name: "cond/truediv/Cast" - op: "Cast" - input: "cond/add:z:0" - attr { - key: "DstT" - value { - type: DT_DOUBLE - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/truediv/Cast_1" - op: "Cast" - input: "cond/truediv/y:output:0" - attr { - key: "DstT" - value { - type: DT_DOUBLE - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/truediv" - op: "RealDiv" - input: "cond/truediv/Cast:y:0" - input: "cond/truediv/Cast_1:y:0" - attr { - key: "T" - value { - type: DT_DOUBLE - } - } - } - node_def { - name: "cond/sub_1/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 224 - } - } - } - } - node_def { - name: "cond/sub_1" - op: "Sub" - input: "cond/strided_slice_4:output:0" - input: "cond/sub_1/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/add_1/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/add_1" - op: "Add" - input: "cond/sub_1:z:0" - input: "cond/add_1/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/truediv_1/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } - } - node_def { - name: "cond/truediv_1/Cast" - op: "Cast" - input: "cond/add_1:z:0" - attr { - key: "DstT" - value { - 
type: DT_DOUBLE - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/truediv_1/Cast_1" - op: "Cast" - input: "cond/truediv_1/y:output:0" - attr { - key: "DstT" - value { - type: DT_DOUBLE - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/truediv_1" - op: "RealDiv" - input: "cond/truediv_1/Cast:y:0" - input: "cond/truediv_1/Cast_1:y:0" - attr { - key: "T" - value { - type: DT_DOUBLE - } - } - } - node_def { - name: "cond/Shape_3" - op: "Shape" - input: "cond/strided_slice_2:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/Rank" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "cond/Equal/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 3 - } - } - } - } - node_def { - name: "cond/Equal" - op: "Equal" - input: "cond/Rank:output:0" - input: "cond/Equal/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/Assert/Const" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Rank of image must be equal to 3." - } - } - } - } - node_def { - name: "cond/Assert/Assert/data_0" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Rank of image must be equal to 3." 
- } - } - } - } - node_def { - name: "cond/Assert/Assert" - op: "Assert" - input: "cond/Equal:z:0" - input: "cond/Assert/Assert/data_0:output:0" - attr { - key: "T" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "summarize" - value { - i: 3 - } - } - } - node_def { - name: "cond/strided_slice_5/stack" - op: "Const" - input: "^cond/Assert/Assert" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } - } - node_def { - name: "cond/strided_slice_5/stack_1" - op: "Const" - input: "^cond/Assert/Assert" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } - } - node_def { - name: "cond/strided_slice_5/stack_2" - op: "Const" - input: "^cond/Assert/Assert" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_5" - op: "StridedSlice" - input: "cond/Shape_3:output:0" - input: "cond/strided_slice_5/stack:output:0" - input: "cond/strided_slice_5/stack_1:output:0" - input: "cond/strided_slice_5/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/stack/0" - op: "Const" - input: "^cond/Assert/Assert" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 224 - } - } - } - } - node_def { - name: "cond/stack/1" - op: "Const" - input: "^cond/Assert/Assert" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 224 - } - } - } - } - node_def { - name: "cond/stack" - op: "Pack" - input: "cond/stack/0:output:0" - input: "cond/stack/1:output:0" - input: "cond/strided_slice_5:output:0" - attr { - key: "N" - value { - i: 3 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node_def { - name: "cond/strided_slice_6/stack" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "cond/strided_slice_6/stack_1" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_6/stack_2" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - 
} - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_6" - op: "StridedSlice" - input: "cond/Shape_3:output:0" - input: "cond/strided_slice_6/stack:output:0" - input: "cond/strided_slice_6/stack_1:output:0" - input: "cond/strided_slice_6/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/GreaterEqual/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 224 - } - } - } - } - node_def { - name: "cond/GreaterEqual" - op: "GreaterEqual" - input: "cond/strided_slice_6:output:0" - input: "cond/GreaterEqual/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/strided_slice_7/stack" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_7/stack_1" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 2 - } - } - } - } - node_def { - name: "cond/strided_slice_7/stack_2" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_7" - op: "StridedSlice" - input: "cond/Shape_3:output:0" - input: "cond/strided_slice_7/stack:output:0" - input: "cond/strided_slice_7/stack_1:output:0" - input: "cond/strided_slice_7/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/GreaterEqual_1/y" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 224 - } - } - } - } - node_def { - name: "cond/GreaterEqual_1" - op: "GreaterEqual" - input: "cond/strided_slice_7:output:0" - input: "cond/GreaterEqual_1/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/LogicalAnd" - op: "LogicalAnd" - input: "cond/GreaterEqual:z:0" - input: "cond/GreaterEqual_1:z:0" - } - node_def { - name: "cond/Assert_1/Const" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Crop size greater than the image size." 
- } - } - } - } - node_def { - name: "cond/Assert_1/Assert/data_0" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "Crop size greater than the image size." - } - } - } - } - node_def { - name: "cond/Assert_1/Assert" - op: "Assert" - input: "cond/LogicalAnd:z:0" - input: "cond/Assert_1/Assert/data_0:output:0" - attr { - key: "T" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "summarize" - value { - i: 3 - } - } - } - node_def { - name: "cond/stack_1/2" - op: "Const" - input: "^cond/switch_t" - attr { - key: "dtype" - value { - type: DT_DOUBLE - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_DOUBLE - tensor_shape { - } - double_val: 0.0 - } - } - } - } - node_def { - name: "cond/stack_1" - op: "Pack" - input: "cond/truediv:z:0" - input: "cond/truediv_1:z:0" - input: "cond/stack_1/2:output:0" - attr { - key: "N" - value { - i: 3 - } - } - attr { - key: "T" - value { - type: DT_DOUBLE - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node_def { - name: "cond/ToInt32" - op: "Cast" - input: "cond/stack_1:output:0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_DOUBLE - } - } - } - node_def { - name: "cond/Slice" - op: "Slice" - input: "cond/strided_slice_2:output:0" - input: "cond/ToInt32:y:0" - input: "cond/stack:output:0" - input: "^cond/Assert_1/Assert" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "cond/Reshape" - op: "Reshape" - input: "cond/Slice:output:0" - input: "cond/stack:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "cond/ResizeBicubic_1/images" - op: "Pack" - input: "cond/ResizeBicubic_1/images/Switch:output_false:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "axis" - value { - i: 0 - } - } - } - node_def { - name: "cond/ResizeBicubic_1/images/Switch" - op: "Switch" - input: "distorted_bounding_box_crop/Slice:output:0" - input: "cond/pred_id:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@distorted_bounding_box_crop/Slice" - } - } - } - } - node_def { - name: "cond/ResizeBicubic_1/size" - op: "Const" - input: "^cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\340\000\000\000\340\000\000\000" - } - } - } - } - node_def { - name: "cond/ResizeBicubic_1" - op: "ResizeBicubic" - input: "cond/ResizeBicubic_1/images:output:0" - input: "cond/ResizeBicubic_1/size:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "align_corners" - value { - b: false - } - } - } - node_def { - name: "cond/strided_slice_8/stack" - op: "Const" - input: "^cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "cond/strided_slice_8/stack_1" - op: "Const" - input: "^cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { 
- dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_8/stack_2" - op: "Const" - input: "^cond/switch_f" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "cond/strided_slice_8" - op: "StridedSlice" - input: "cond/ResizeBicubic_1:resized_images:0" - input: "cond/strided_slice_8/stack:output:0" - input: "cond/strided_slice_8/stack_1:output:0" - input: "cond/strided_slice_8/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "cond/Merge" - op: "Merge" - input: "cond/strided_slice_8:output:0" - input: "cond/Reshape:output:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "Const_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 3 - } - } - tensor_content: "\354Q\370>\325x\351>;\337\317>" - } - } - } - } - node_def { - name: "sub" - op: "Sub" - input: "cond/Merge:output:0" - input: "Const_2:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "Const_3" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 1 - } - dim { - size: 1 - } - dim { - size: 3 - } - } - tensor_content: "\372~j>B`e>fff>" - } - } - } - } - node_def { - name: "truediv" - op: "RealDiv" - input: "sub:z:0" - input: "Const_3:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "random_flip_left_right/control_dependency" - op: "Identity" - input: "truediv:z:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@truediv" - } - } - } - } - node_def { - name: "random_flip_left_right/random_uniform/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node_def { - name: "random_flip_left_right/random_uniform/min" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } - } - node_def { - name: "random_flip_left_right/random_uniform/max" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1.0 - } - } - } - } - node_def { - name: "random_flip_left_right/random_uniform/RandomUniform" - op: "RandomUniform" - input: "random_flip_left_right/random_uniform/shape:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { 
- i: 0 - } - } - } - node_def { - name: "random_flip_left_right/random_uniform/sub" - op: "Sub" - input: "random_flip_left_right/random_uniform/max:output:0" - input: "random_flip_left_right/random_uniform/min:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "random_flip_left_right/random_uniform/mul" - op: "Mul" - input: "random_flip_left_right/random_uniform/RandomUniform:output:0" - input: "random_flip_left_right/random_uniform/sub:z:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "random_flip_left_right/random_uniform" - op: "Add" - input: "random_flip_left_right/random_uniform/mul:z:0" - input: "random_flip_left_right/random_uniform/min:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "random_flip_left_right/Less/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.5 - } - } - } - } - node_def { - name: "random_flip_left_right/Less" - op: "Less" - input: "random_flip_left_right/random_uniform:z:0" - input: "random_flip_left_right/Less/y:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - node_def { - name: "random_flip_left_right/Switch" - op: "Switch" - input: "random_flip_left_right/Less:z:0" - input: "random_flip_left_right/Less:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "random_flip_left_right/switch_t" - op: "Identity" - input: "random_flip_left_right/Switch:output_true:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "random_flip_left_right/switch_f" - op: "Identity" - input: "random_flip_left_right/Switch:output_false:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "random_flip_left_right/pred_id" - op: "Identity" - input: "random_flip_left_right/Less:z:0" - attr { - key: "T" - value { - type: DT_BOOL - } - } - } - node_def { - name: "random_flip_left_right/ReverseV2/axis" - op: "Const" - input: "^random_flip_left_right/switch_t" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "random_flip_left_right/ReverseV2" - op: "ReverseV2" - input: "random_flip_left_right/ReverseV2/Switch:output_true:0" - input: "random_flip_left_right/ReverseV2/axis:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - } - node_def { - name: "random_flip_left_right/ReverseV2/Switch" - op: "Switch" - input: "random_flip_left_right/control_dependency:output:0" - input: "random_flip_left_right/pred_id:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@truediv" - } - } - } - } - node_def { - name: "random_flip_left_right/Switch_1" - op: "Switch" - input: "random_flip_left_right/control_dependency:output:0" - input: "random_flip_left_right/pred_id:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@truediv" - } - } - } - } - node_def { - name: "random_flip_left_right/Merge" - op: "Merge" - input: "random_flip_left_right/Switch_1:output_false:0" - input: "random_flip_left_right/ReverseV2:output:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: 
DT_FLOAT - } - } - } - node_def { - name: "Reshape_1/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\340\000\000\000\340\000\000\000\003\000\000\000" - } - } - } - } - node_def { - name: "Reshape_1" - op: "Reshape" - input: "random_flip_left_right/Merge:output:0" - input: "Reshape_1/shape:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "Reshape_2/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node_def { - name: "Reshape_2" - op: "Reshape" - input: "ParseSingleExample/ParseSingleExample:dense_values:0" - input: "Reshape_2/shape:output:0" - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "Cast_1" - op: "Cast" - input: "Reshape_2:output:0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_INT64 - } - } - } - node_def { - name: "sub_1/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } - } - node_def { - name: "sub_1" - op: "Sub" - input: "Cast_1:y:0" - input: "sub_1/y:output:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - } - ret { - key: "Reshape_1" - value: "Reshape_1:output:0" - } - ret { - key: "sub_1" - value: "sub_1:z:0" - } - } - function { - signature { - name: "tf_predicate_7089b845" - input_arg { - name: "arg0" - type: DT_FLOAT - } - input_arg { - name: "arg1" - type: DT_INT32 - } - input_arg { - name: "Equal/Placeholder" - type: DT_INT64 - } - output_arg { - name: "Equal" - type: DT_BOOL - } - description: "A wrapper for Defun that facilitates shape inference." 
- } - node_def { - name: "Shape" - op: "Shape" - input: "arg0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT64 - } - } - } - node_def { - name: "strided_slice/stack" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "strided_slice/stack_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "strided_slice/stack_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "strided_slice" - op: "StridedSlice" - input: "Shape:output:0" - input: "strided_slice/stack:output:0" - input: "strided_slice/stack_1:output:0" - input: "strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "Equal" - op: "Equal" - input: "strided_slice:output:0" - input: "Equal/Placeholder" - attr { - key: "T" - value { - type: DT_INT64 - } - } - } - ret { - key: "Equal" - value: "Equal:z:0" - } - } - function { - signature { - name: "_make_dataset_5fa5e1f4" - output_arg { - name: "PrefetchDataset_1" - type: DT_VARIANT - } - is_stateful: true - } - node_def { - name: "TensorSliceDataset/MatchingFiles/pattern" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "$(DATA_DIR)" - } - } - } - } - node_def { - name: "TensorSliceDataset/MatchingFiles" - op: "MatchingFiles" - input: "TensorSliceDataset/MatchingFiles/pattern:output:0" - } - node_def { - name: "TensorSliceDataset" - op: "TensorSliceDataset" - input: "TensorSliceDataset/MatchingFiles:filenames:0" - attr { - key: "Toutput_types" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - } - node_def { - name: "ShuffleDataset/MatchingFiles/pattern" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "$(DATA_DIR)" - } - } - } - } - node_def { - name: "ShuffleDataset/MatchingFiles" - op: "MatchingFiles" - input: "ShuffleDataset/MatchingFiles/pattern:output:0" - } - node_def { - name: "ShuffleDataset/Shape" - op: "Shape" - input: "ShuffleDataset/MatchingFiles:filenames:0" - attr { - key: "T" - value { - type: DT_STRING - } - } - attr { - key: "out_type" - value { - type: DT_INT64 - } - } - } - node_def { - name: "ShuffleDataset/strided_slice/stack" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: 
"ShuffleDataset/strided_slice/stack_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "ShuffleDataset/strided_slice/stack_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "ShuffleDataset/strided_slice" - op: "StridedSlice" - input: "ShuffleDataset/Shape:output:0" - input: "ShuffleDataset/strided_slice/stack:output:0" - input: "ShuffleDataset/strided_slice/stack_1:output:0" - input: "ShuffleDataset/strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "ShuffleDataset/Maximum/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 1 - } - } - } - } - node_def { - name: "ShuffleDataset/Maximum" - op: "Maximum" - input: "ShuffleDataset/strided_slice:output:0" - input: "ShuffleDataset/Maximum/y:output:0" - attr { - key: "T" - value { - type: DT_INT64 - } - } - } - node_def { - name: "ShuffleDataset/seed" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "ShuffleDataset/seed2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "ShuffleDataset" - op: "ShuffleDataset" - input: "TensorSliceDataset:handle:0" - input: "ShuffleDataset/Maximum:z:0" - input: "ShuffleDataset/seed:output:0" - input: "ShuffleDataset/seed2:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "reshuffle_each_iteration" - value { - b: true - } - } - } - node_def { - name: "ShuffleDataset_1/buffer_size" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 1024 - } - } - } - } - node_def { - name: "ShuffleDataset_1/seed_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "ShuffleDataset_1/seed2_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "ShuffleDataset_1" - op: "ShuffleDataset" - input: "ShuffleDataset:handle:0" - input: "ShuffleDataset_1/buffer_size:output:0" - input: "ShuffleDataset_1/seed_1:output:0" - input: "ShuffleDataset_1/seed2_1:output:0" - attr { - key: "output_shapes" - value { - 
list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "reshuffle_each_iteration" - value { - b: true - } - } - } - node_def { - name: "RepeatDataset/count" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } - } - node_def { - name: "RepeatDataset" - op: "RepeatDataset" - input: "ShuffleDataset_1:handle:0" - input: "RepeatDataset/count:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_STRING - } - } - } - } - node_def { - name: "ExperimentalParallelInterleaveDataset/cycle_length" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 8 - } - } - } - } - node_def { - name: "ExperimentalParallelInterleaveDataset/block_length" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 1 - } - } - } - } - node_def { - name: "ExperimentalParallelInterleaveDataset/sloppy" - op: "Const" - attr { - key: "dtype" - value { - type: DT_BOOL - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_BOOL - tensor_shape { - } - bool_val: true - } - } - } - } - node_def { - name: "ExperimentalParallelInterleaveDataset/buffer_output_elements" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 2 - } - } - } - } - node_def { - name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 16 - } - } - } - } - node_def { - name: "ExperimentalParallelInterleaveDataset" - op: "ExperimentalParallelInterleaveDataset" - input: "RepeatDataset:handle:0" - input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0" - input: "ExperimentalParallelInterleaveDataset/block_length:output:0" - input: "ExperimentalParallelInterleaveDataset/sloppy:output:0" - input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0" - input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "f" - value { - func { - name: "tf_map_func_91295dea" - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_STRING - } - } - } - } - node_def { - name: "ShuffleDataset_2/buffer_size_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 1024 - } - } - } - } - node_def { - name: "ShuffleDataset_2/seed_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "ShuffleDataset_2/seed2_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 
0 - } - } - } - } - node_def { - name: "ShuffleDataset_2" - op: "ShuffleDataset" - input: "ExperimentalParallelInterleaveDataset:handle:0" - input: "ShuffleDataset_2/buffer_size_1:output:0" - input: "ShuffleDataset_2/seed_2:output:0" - input: "ShuffleDataset_2/seed2_2:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_STRING - } - } - } - attr { - key: "reshuffle_each_iteration" - value { - b: true - } - } - } - node_def { - name: "ParallelMapDataset/num_parallel_calls" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 64 - } - } - } - } - node_def { - name: "ParallelMapDataset" - op: "ParallelMapDataset" - input: "ShuffleDataset_2:handle:0" - input: "ParallelMapDataset/num_parallel_calls:output:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "f" - value { - func { - name: "tf_map_func_74b6b15c" - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 224 - } - dim { - size: 224 - } - dim { - size: 3 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - } - node_def { - name: "PrefetchDataset/buffer_size_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 64 - } - } - } - } - node_def { - name: "PrefetchDataset" - op: "PrefetchDataset" - input: "ParallelMapDataset:handle:0" - input: "PrefetchDataset/buffer_size_2:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 224 - } - dim { - size: 224 - } - dim { - size: 3 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - } - node_def { - name: "BatchDataset/batch_size" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 64 - } - } - } - } - node_def { - name: "BatchDataset" - op: "BatchDataset" - input: "PrefetchDataset:handle:0" - input: "BatchDataset/batch_size:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 224 - } - dim { - size: 224 - } - dim { - size: 3 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - } - node_def { - name: "FilterDataset/batch_size_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 64 - } - } - } - } - node_def { - name: "FilterDataset" - op: "FilterDataset" - input: "BatchDataset:handle:0" - input: "FilterDataset/batch_size_1:output:0" - attr { - key: "Targuments" - value { - list { - type: DT_INT64 - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 224 - } - dim { - size: 224 - } - dim { - size: 3 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - attr { - key: "predicate" - value { - func { - name: "tf_predicate_7089b845" - } - } - } - } - node_def { - name: 
"PrefetchDataset_1/buffer_size_3" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 2 - } - } - } - } - node_def { - name: "PrefetchDataset_1" - op: "PrefetchDataset" - input: "FilterDataset:handle:0" - input: "PrefetchDataset_1/buffer_size_3:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 64 - } - dim { - size: 224 - } - dim { - size: 224 - } - dim { - size: 3 - } - } - shape { - dim { - size: 64 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - } - ret { - key: "PrefetchDataset_1" - value: "PrefetchDataset_1:handle:0" - } - } -} -)PREFIX"; - - *dataset_name = "_make_dataset_5fa5e1f4"; - std::function mutate_proto_func = - [dataset_name, file_path](FunctionDef* fdef) { - VLOG(1) << "Processsing function " << fdef->DebugString(); - if (std::string(fdef->signature().name()) != *dataset_name) return; - // Change the input file pattern to `file_path`. - bool found = false; - for (auto& node_def : *fdef->mutable_node_def()) { - if (node_def.name() != "TensorSliceDataset/MatchingFiles/pattern" && - node_def.name() != "ShuffleDataset/MatchingFiles/pattern") - continue; - DCHECK_EQ(node_def.op(), "Const"); - DCHECK_GT(node_def.attr().count("value"), 0); - found = true; - DCHECK_EQ(node_def.attr().at("value").tensor().string_val(0), - "$(DATA_DIR)"); - VLOG(1) << "Setting the value of node_def " - "TensorSliceDataset/MatchingFiles/pattern to " - << file_path; - auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); - tensor->clear_string_val(); - tensor->add_string_val(file_path); - } - VLOG(1) << "Rewrote function to " << fdef->DebugString(); - DCHECK(found); - }; - return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); -#endif -} -#endif - -#if not defined(PLATFORM_WINDOWS) -// On success, returns a set of TF_Function instances encoding a dataset -// node stack that reads an MNIST file dataset from `file_path`, and -// sets `dataset_name` to the created dataset name. The returned functions must -// be deleted by calling TF_DeleteFunction. -static std::vector CreateMNISTDatasetFunctions( - const char* file_path, int batch_size, std::string* dataset_name, - TF_Status* status) { -#if defined(PLATFORM_WINDOWS) - status->status = tensorflow::errors::Unimplemented( - "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " - "is not implemented for Windows"); - return nullptr; -#else - const char* func_def = R"PREFIX( -library { - function { - signature { - name: "tf_map_func_521bfd08" - input_arg { - name: "arg0" - type: DT_STRING - } - output_arg { - name: "truediv" - type: DT_FLOAT - } - description: "A wrapper for Defun that facilitates shape inference." 
- } - node_def { - name: "DecodeRaw" - op: "DecodeRaw" - input: "arg0" - attr { - key: "little_endian" - value { - b: true - } - } - attr { - key: "out_type" - value { - type: DT_UINT8 - } - } - } - node_def { - name: "Cast" - op: "Cast" - input: "DecodeRaw:output:0" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_UINT8 - } - } - } - node_def { - name: "Reshape/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 784 - } - } - } - } - node_def { - name: "Reshape" - op: "Reshape" - input: "Cast:y:0" - input: "Reshape/shape:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "truediv/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 255.0 - } - } - } - } - node_def { - name: "truediv" - op: "RealDiv" - input: "Reshape:output:0" - input: "truediv/y:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - ret { - key: "truediv" - value: "truediv:z:0" - } - } - function { - signature { - name: "tf_map_func_9a08860d" - input_arg { - name: "arg0" - type: DT_STRING - } - output_arg { - name: "ToInt32" - type: DT_INT32 - } - description: "A wrapper for Defun that facilitates shape inference." - } - node_def { - name: "DecodeRaw" - op: "DecodeRaw" - input: "arg0" - attr { - key: "little_endian" - value { - b: true - } - } - attr { - key: "out_type" - value { - type: DT_UINT8 - } - } - } - node_def { - name: "Reshape/shape" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } - } - node_def { - name: "Reshape" - op: "Reshape" - input: "DecodeRaw:output:0" - input: "Reshape/shape:output:0" - attr { - key: "T" - value { - type: DT_UINT8 - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - } - node_def { - name: "ToInt32" - op: "Cast" - input: "Reshape:output:0" - attr { - key: "DstT" - value { - type: DT_INT32 - } - } - attr { - key: "SrcT" - value { - type: DT_UINT8 - } - } - } - ret { - key: "ToInt32" - value: "ToInt32:y:0" - } - } - function { - signature { - name: "tf_predicate_7089b845" - input_arg { - name: "arg0" - type: DT_FLOAT - } - input_arg { - name: "arg1" - type: DT_INT32 - } - input_arg { - name: "Equal/Placeholder" - type: DT_INT64 - } - output_arg { - name: "Equal" - type: DT_BOOL - } - description: "A wrapper for Defun that facilitates shape inference." 
- } - node_def { - name: "Shape" - op: "Shape" - input: "arg0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT64 - } - } - } - node_def { - name: "strided_slice/stack" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } - } - node_def { - name: "strided_slice/stack_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "strided_slice/stack_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } - } - node_def { - name: "strided_slice" - op: "StridedSlice" - input: "Shape:output:0" - input: "strided_slice/stack:output:0" - input: "strided_slice/stack_1:output:0" - input: "strided_slice/stack_2:output:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT64 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } - } - node_def { - name: "Equal" - op: "Equal" - input: "strided_slice:output:0" - input: "Equal/Placeholder" - attr { - key: "T" - value { - type: DT_INT64 - } - } - } - ret { - key: "Equal" - value: "Equal:z:0" - } - } - function { - signature { - name: "_make_dataset_2451e43a" - output_arg { - name: "FilterDataset" - type: DT_VARIANT - } - is_stateful: true - } - node_def { - name: "FixedLengthRecordDataset/filenames" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "$(DATA_DIR)/train-images-idx3-ubyte" - } - } - } - } - node_def { - name: "FixedLengthRecordDataset/header_bytes" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 16 - } - } - } - } - node_def { - name: "FixedLengthRecordDataset/record_bytes" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 784 - } - } - } - } - node_def { - name: "FixedLengthRecordDataset/footer_bytes" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "FixedLengthRecordDataset/buffer_size" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 262144 - } - } - } - } - node_def { - name: "FixedLengthRecordDataset" - op: "FixedLengthRecordDataset" - input: "FixedLengthRecordDataset/filenames:output:0" - input: "FixedLengthRecordDataset/header_bytes:output:0" - input: "FixedLengthRecordDataset/record_bytes:output:0" - input: "FixedLengthRecordDataset/footer_bytes:output:0" - input: 
"FixedLengthRecordDataset/buffer_size:output:0" - } - node_def { - name: "MapDataset" - op: "MapDataset" - input: "FixedLengthRecordDataset:handle:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "f" - value { - func { - name: "tf_map_func_521bfd08" - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 784 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - } - } - } - } - node_def { - name: "FixedLengthRecordDataset_1/filenames_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "$(DATA_DIR)/train-labels-idx1-ubyte" - } - } - } - } - node_def { - name: "FixedLengthRecordDataset_1/header_bytes_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 8 - } - } - } - } - node_def { - name: "FixedLengthRecordDataset_1/record_bytes_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 1 - } - } - } - } - node_def { - name: "FixedLengthRecordDataset_1/footer_bytes_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "FixedLengthRecordDataset_1/buffer_size_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 262144 - } - } - } - } - node_def { - name: "FixedLengthRecordDataset_1" - op: "FixedLengthRecordDataset" - input: "FixedLengthRecordDataset_1/filenames_1:output:0" - input: "FixedLengthRecordDataset_1/header_bytes_1:output:0" - input: "FixedLengthRecordDataset_1/record_bytes_1:output:0" - input: "FixedLengthRecordDataset_1/footer_bytes_1:output:0" - input: "FixedLengthRecordDataset_1/buffer_size_1:output:0" - } - node_def { - name: "MapDataset_1" - op: "MapDataset" - input: "FixedLengthRecordDataset_1:handle:0" - attr { - key: "Targuments" - value { - list { - } - } - } - attr { - key: "f" - value { - func { - name: "tf_map_func_9a08860d" - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_INT32 - } - } - } - } - node_def { - name: "ZipDataset" - op: "ZipDataset" - input: "MapDataset:handle:0" - input: "MapDataset_1:handle:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 784 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - } - node_def { - name: "CacheDataset/filename" - op: "Const" - attr { - key: "dtype" - value { - type: DT_STRING - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_STRING - tensor_shape { - } - string_val: "" - } - } - } - } - node_def { - name: "CacheDataset" - op: "CacheDataset" - input: "ZipDataset:handle:0" - input: "CacheDataset/filename:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 784 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - } - node_def { - 
name: "RepeatDataset/count" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -1 - } - } - } - } - node_def { - name: "RepeatDataset" - op: "RepeatDataset" - input: "CacheDataset:handle:0" - input: "RepeatDataset/count:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 784 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - } - node_def { - name: "ShuffleDataset/buffer_size_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 50000 - } - } - } - } - node_def { - name: "ShuffleDataset/seed" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "ShuffleDataset/seed2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: 0 - } - } - } - } - node_def { - name: "ShuffleDataset" - op: "ShuffleDataset" - input: "RepeatDataset:handle:0" - input: "ShuffleDataset/buffer_size_2:output:0" - input: "ShuffleDataset/seed:output:0" - input: "ShuffleDataset/seed2:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: 784 - } - } - shape { - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - attr { - key: "reshuffle_each_iteration" - value { - b: true - } - } - } - node_def { - name: "BatchDataset/batch_size" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -123 - } - } - } - } - node_def { - name: "BatchDataset" - op: "BatchDataset" - input: "ShuffleDataset:handle:0" - input: "BatchDataset/batch_size:output:0" - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - } - node_def { - name: "FilterDataset/batch_size_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT64 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT64 - tensor_shape { - } - int64_val: -123 - } - } - } - } - node_def { - name: "FilterDataset" - op: "FilterDataset" - input: "BatchDataset:handle:0" - input: "FilterDataset/batch_size_1:output:0" - attr { - key: "Targuments" - value { - list { - type: DT_INT64 - } - } - } - attr { - key: "output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 784 - } - } - shape { - dim { - size: -1 - } - } - } - } - } - attr { - key: "output_types" - value { - list { - type: DT_FLOAT - type: DT_INT32 - } - } - } - attr { - key: "predicate" - value { - func { - name: "tf_predicate_7089b845" - } - } - } - } - ret { - key: "FilterDataset" - value: "FilterDataset:handle:0" - } - } -} -)PREFIX"; - - *dataset_name = "_make_dataset_2451e43a"; - std::function mutate_proto_func = - [dataset_name, file_path, batch_size](FunctionDef* fdef) { - VLOG(1) << "Processsing function " << fdef->DebugString(); - if 
(std::string(fdef->signature().name()) != *dataset_name) return; - // Change the input file pattern to `file_path`. - bool found_file_path = false, found_batch_size = false; - // `node_def` may be mutated. - for (auto& node_def : *fdef->mutable_node_def()) { - if (node_def.name() == "FixedLengthRecordDataset/filenames" || - node_def.name() == "FixedLengthRecordDataset_1/filenames_1") { - DCHECK_EQ(node_def.op(), "Const"); - DCHECK_GT(node_def.attr().count("value"), 0); - found_file_path = true; - // Replace $(DATA_DIR)/foo with /foo - // TODO(hongm): Use StringPiece manipulation for better efficiency. - const std::string cur_value = - node_def.attr().at("value").tensor().string_val(0); - const std::string pattern = "$(DATA_DIR)"; - DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0); - const std::string new_value = - file_path + cur_value.substr(pattern.length()); - VLOG(1) << "Setting the value of node_def " << node_def.name() - << " to " << new_value; - auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); - tensor->clear_string_val(); - tensor->add_string_val(new_value); - } else if (node_def.name() == "BatchDataset/batch_size" || - node_def.name() == "FilterDataset/batch_size_1") { - DCHECK_EQ(node_def.op(), "Const"); - DCHECK_GT(node_def.attr().count("value"), 0); - found_batch_size = true; - // Replace $(BATCH_SIZE) with `batch_size` - DCHECK_EQ(node_def.attr().at("value").tensor().int64_val(0), -123); - VLOG(1) << "Setting the batch size attr value of node_def " - << node_def.name() << " to " << batch_size; - auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); - tensor->clear_int64_val(); - tensor->add_int64_val(batch_size); - } - } - VLOG(1) << "Rewrote function to " << fdef->DebugString(); - DCHECK(found_file_path); - DCHECK(found_batch_size); - }; - return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); -#endif -} -#endif - -// Adds the input functions to `graph`. On success, returns the created -// IteratorGetNext node. -static TF_Operation* AddDatasetFunctionAndIteratorNodesToGraph( - const std::vector& funcs, const std::string& dataset_name, - const std::vector& output_types, - const std::vector& output_shapes, - TF_Graph* graph, TF_Status* status) { - DCHECK(!dataset_name.empty()); - for (auto& func : funcs) { - TF_GraphCopyFunction(graph, func.get(), /*gradient*/ nullptr, status); - if (!status->status.ok()) { - return nullptr; - } - } - - tensorflow::mutex_lock c(graph->mu); - - tensorflow::NameAttrList func; - func.set_name(dataset_name); - // Run the iterator node on CPU. - Node* oneshot_iterator_node; - tensorflow::Status s = NodeBuilder("OneShotIterator", "OneShotIterator") - .Device("/device:CPU:0") - .Attr("container", "") - .Attr("dataset_factory", func) - .Attr("output_types", output_types) - .Attr("output_shapes", output_shapes) - .Attr("shared_name", "") - .Finalize(&graph->graph, &oneshot_iterator_node); - if (!s.ok()) { - status->status = s; - return nullptr; - } - // Run shape inference function for each newly added node, so that more - // subsequent nodes can be added to the graph via C API (TF_NewOperation()). - s = graph->refiner.AddNode(oneshot_iterator_node); - if (!s.ok()) { - status->status = s; - return nullptr; - } - - // Run the iterator node on CPU. 
- Node* getnext_node; - s = NodeBuilder("IteratorGetNext", "IteratorGetNext") - .Input(oneshot_iterator_node) - .Device("/device:CPU:0") - .Attr("output_types", output_types) - .Attr("output_shapes", output_shapes) - .Finalize(&graph->graph, &getnext_node); - if (!s.ok()) { - status->status = s; - return nullptr; - } - // Run shape inference function for each newly added node, so that more - // subsequent nodes can be added to the graph via C API (TF_NewOperation()). - s = graph->refiner.AddNode(getnext_node); - if (!s.ok()) { - status->status = s; - return nullptr; - } - - VLOG(1) << "Output graph: " << graph->graph.ToGraphDefDebug().DebugString(); - return ToTF_Operation(getnext_node); -} - -TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph, - TF_Status* status) { - tensorflow::Status s; - - std::string dataset_name; - UniqueFuncPtr result_func = CreateFakeDatasetFunction(&dataset_name, status); - if (!status->status.ok()) { - return nullptr; - } - - std::vector funcs; - funcs.push_back(std::move(result_func)); - std::vector output_shape_list; - output_shape_list.push_back(tensorflow::TensorShapeProto()); - auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( - funcs, dataset_name, {tensorflow::DT_FLOAT}, output_shape_list, graph, - status); - if (!status->status.ok()) { - return nullptr; - } - - return getnext_node; -} - -TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( - TF_Graph* graph, const char* file_path, int batch_size, - unsigned char is_mnist, TF_Status* status) { -#if defined(PLATFORM_WINDOWS) - // TODO(ashankar): get these functions working on Windows. - status->status = tensorflow::errors::Unimplemented( - "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " - "is not implemented for Windows"); - return nullptr; -#else - tensorflow::Status s; - - std::string dataset_name; - const auto& funcs = - is_mnist - ? 
CreateMNISTDatasetFunctions(file_path, batch_size, &dataset_name, - status) - : CreateImagenetDatasetFunctions(file_path, &dataset_name, status); - if (!status->status.ok()) { - return nullptr; - } - - std::vector output_shape_list; - // batch_size X 224 X 224 X 3 - auto image_shape = tensorflow::TensorShapeProto(); - image_shape.add_dim()->set_size(batch_size); - if (is_mnist) { - image_shape.add_dim()->set_size(784); - } else { - image_shape.add_dim()->set_size(224); - image_shape.add_dim()->set_size(224); - image_shape.add_dim()->set_size(3); - } - output_shape_list.push_back(image_shape); - - // batch_size - auto label_shape = tensorflow::TensorShapeProto(); - label_shape.add_dim()->set_size(batch_size); - output_shape_list.push_back(label_shape); - auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( - funcs, dataset_name, {tensorflow::DT_FLOAT, tensorflow::DT_INT32}, - output_shape_list, graph, status); - if (!status->status.ok()) { - return nullptr; - } - - tensorflow::mutex_lock c(graph->mu); - VLOG(1) << "The extended graph: " - << graph->graph.ToGraphDefDebug().DebugString(); - - return getnext_node; -#endif -} - TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id, TF_Status* status) { assert(session); @@ -8939,7 +695,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); - LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer( + LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer( std::move(server), grpc_server->worker_env()->device_mgr, grpc_server->worker_env()->collective_executor_mgr)); @@ -9062,8 +818,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, const auto& op_type = op->operation.Name(); auto op_name = tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++); - auto* desc = - TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); + std::unique_ptr desc( + TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str())); VLOG(1) << "Adding attrs."; tensorflow::AttrValueMap attrs; @@ -9077,30 +833,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, size_t inputIndex = 0; const tensorflow::OpDef& op_def = desc->node_builder.op_def(); for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) { - // TODO(bgogul): Add support for number attributes. 
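Note: the TODO removed here is what the additions below implement. Ops such as `ConcatV2` declare their `values` input as `N * T`, so the graph builder has to read the integer `N` attribute to know how many eager inputs to wire into a single input list. A minimal caller-side sketch follows; it is illustrative only and assumes the `AddEagerOpToGraphTest` fixture members (`eager_ctx_`, `trace_ctx_`, `status_`) and the `TestMatrixTensorHandle()` / `TestAxisTensorHandle()` helpers used by the tests later in this patch.

```c++
// Sketch only, mirroring the NumberAttributesAreHandledCorrectly test added
// below; fixture members and tensor-handle helpers are assumed to exist.
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();

TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", 2);            // number attribute: two values
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
TFE_OpAddInput(concatv2, matrix, status_);     // values[0]
TFE_OpAddInput(concatv2, matrix, status_);     // values[1]
TFE_OpAddInput(concatv2, axis, status_);       // axis

// With this change, TFE_AddEagerOpToGraph reads N to size the "values"
// input list instead of hitting the old DCHECK.
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TF_Operation* graph_op =
    TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals, &num_retvals, status_);
// graph_op now has the two-element "values" input list plus the axis input.

TFE_DeleteOp(concatv2);
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
```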
- DCHECK(input_arg.number_attr().empty()) - << "Number attributes is not implemented yet."; - if (input_arg.type_list_attr().empty()) { + if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) { auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); if (!status->status.ok()) return nullptr; - TF_AddInput(desc, symbolic_input); + TF_AddInput(desc.get(), symbolic_input); continue; } - const std::string& type_list_attr = input_arg.type_list_attr(); - const auto& attr_value = attrs[type_list_attr]; - DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList) - << "Type list attribute should be a list!"; - std::vector list_inputs(attr_value.list().type_size()); + size_t list_size = 0; + if (!input_arg.type_list_attr().empty()) { + const std::string& type_list_attr = input_arg.type_list_attr(); + const auto& attr_value = attrs[type_list_attr]; + CHECK(attr_value.value_case() == tensorflow::AttrValue::kList) + << "Type list attribute should be a list!"; + list_size = attr_value.list().type_size(); + } else { + CHECK(!input_arg.number_attr().empty()); + const auto& attr_value = attrs[input_arg.number_attr()]; + CHECK(attr_value.value_case() == tensorflow::AttrValue::kI) + << "Number attribute should be int!"; + if (attr_value.i() < 0) { + status->status = tensorflow::errors::Internal( + "Number attribute for length should be >=0!"); + return nullptr; + } + list_size = attr_value.i(); + } + std::vector list_inputs(list_size); for (TF_Output& list_input : list_inputs) { list_input = getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); if (!status->status.ok()) return nullptr; } - TF_AddInputList(desc, list_inputs.data(), list_inputs.size()); + TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size()); } - auto* graph_op = TF_FinishOperation(desc, status); + auto* graph_op = TF_FinishOperation(desc.release(), status); if (!status->status.ok()) return nullptr; VLOG(1) << "Op finalized; setting return tensors."; diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 8d1a8b82fba..795768a1415 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -62,6 +62,20 @@ extern "C" { TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable); +// Set XLA's internal BuildXlaOpsPassFlags.tf_xla_enable_lazy_compilation to the +// value of 'enabled'. Also returns the original value of that flag. +// +// Use in tests to allow XLA to fallback to TF classic. This has global effect. +TF_CAPI_EXPORT unsigned char TF_SetXlaEnableLazyCompilation( + unsigned char enable); + +// Sets XLA's auto jit mode according to the specified string, which is parsed +// as if passed in XLA_FLAGS. This has global effect. +TF_CAPI_EXPORT void TF_SetXLaAutoJitMode(const char* mode); + +// Sets XLA's minimum cluster size. This has global effect. +TF_CAPI_EXPORT void TF_SetXlaMinClusterSize(int size); + // Create a serialized tensorflow.ConfigProto proto, where: // // a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if @@ -93,26 +107,6 @@ TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, TF_CAPI_EXPORT extern char* TF_FunctionDebugString(TF_Function* func, size_t* len); -// Creates a stack of data set + iterator nodes, currently hard-coded to return -// a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, -// returns the IteratorGetNext node, which caller can run or feed into an node. 
-// -// TODO(hongm): Extend the API to allow customization of the nodes created. -TF_CAPI_EXPORT extern TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets( - TF_Graph* graph, TF_Status* status); - -// Similar to the above API, except that the returned iterator reads the -// file based dataset from `file_path`. -// If `is_mnist` is 0, the dataset corresponds to ImageNet. -// The iterators outputs 2 tensors: -// - A float tensor of shape `batch_size` X 784 when `is_mnist` is non-zero, or -// `batch_size` X 224 X 224 X 3 otherwise. -// - An int32 tensor of shape `batch_size` -// TODO(hongm): Extend the API to allow customization of the nodes created. -TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( - TF_Graph* graph, const char* file_path, int batch_size, - unsigned char is_mnist, TF_Status* status); - // On success, dequeues a tensor from a TF-managed FifoQueue given by // `tensor_id`, associated with `session`. There must be a graph node named // "fifo_queue_dequeue_", to be executed by this API call. diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 2c92e38f03a..6eb289107c5 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -27,100 +27,6 @@ limitations under the License. namespace tensorflow { namespace { -void TestFakeIteratorStack() { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - - CSession csession(graph, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - - // Run the graph. - const float base_value = 42.0; - for (int i = 0; i < 3; ++i) { - csession.SetOutputs({get_next}); - csession.Run(s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Tensor* out = csession.output_tensor(0); - ASSERT_TRUE(out != nullptr); - ASSERT_EQ(TF_FLOAT, TF_TensorType(out)); - ASSERT_EQ(0, TF_NumDims(out)); // scalar - ASSERT_EQ(sizeof(float), TF_TensorByteSize(out)); - float* output_contents = static_cast(TF_TensorData(out)); - ASSERT_EQ(base_value + i, *output_contents); - } - - // This should error out since we've exhausted the iterator. - csession.Run(s); - ASSERT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)) << TF_Message(s); - - // Clean up - csession.CloseAndDelete(s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_DeleteGraph(graph); - TF_DeleteStatus(s); -} - -TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); } - -TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - const string file_path = tensorflow::io::JoinPath( - tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record"); - VLOG(1) << "data file path is " << file_path; - const int batch_size = 64; - TF_Operation* get_next = TF_MakeFileBasedIteratorGetNextWithDatasets( - graph, file_path.c_str(), batch_size, /*is_mnist*/ false, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - - CSession csession(graph, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - - // Run the graph. 
- // The two output tensors should look like: - // Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32) - // Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32) - for (int i = 0; i < 3; ++i) { - LOG(INFO) << "Running iter " << i; - csession.SetOutputs({{get_next, 0}, {get_next, 1}}); - csession.Run(s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - - { - TF_Tensor* image = csession.output_tensor(0); - ASSERT_TRUE(image != nullptr); - ASSERT_EQ(TF_FLOAT, TF_TensorType(image)); - // Confirm shape is 224 X 224 X 3 - ASSERT_EQ(4, TF_NumDims(image)); - ASSERT_EQ(batch_size, TF_Dim(image, 0)); - ASSERT_EQ(224, TF_Dim(image, 1)); - ASSERT_EQ(224, TF_Dim(image, 2)); - ASSERT_EQ(3, TF_Dim(image, 3)); - ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3, - TF_TensorByteSize(image)); - } - - { - TF_Tensor* label = csession.output_tensor(1); - ASSERT_TRUE(label != nullptr); - ASSERT_EQ(TF_INT32, TF_TensorType(label)); - ASSERT_EQ(1, TF_NumDims(label)); - ASSERT_EQ(batch_size, TF_Dim(label, 0)); - ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label)); - } - } - - // Clean up - csession.CloseAndDelete(s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_DeleteGraph(graph); - TF_DeleteStatus(s); -} - TEST(CAPI_EXPERIMENTAL, GetServerDefTest) { const string expected_text_proto(R"(cluster { job { @@ -470,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) { TFE_DeleteOp(identityn); } +TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) { + TFE_TensorHandle* matrix = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpSetAttrType(concatv2, "T", TF_FLOAT); + TFE_OpSetAttrInt(concatv2, "N", 2); + TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32); + constexpr size_t kNumInputs = 2; + for (size_t i = 0; i < kNumInputs; ++i) { + TFE_OpAddInput(concatv2, matrix, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + } + TFE_OpAddInput(concatv2, axis, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + AddEagerOpToGraphAndCheck( + concatv2, [this, kNumInputs](TF_Operation* graph_op) { + EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1); + int64_t attrN; + TF_OperationGetAttrInt(graph_op, "N", &attrN, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_EQ(attrN, kNumInputs); + EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteTensorHandle(matrix); + TFE_DeleteOp(concatv2); +} + +TEST_F(AddEagerOpToGraphTest, + GeneratesInternalErrorsForInvalidNumberAttributes) { + TFE_TensorHandle* matrix = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + int num_retvals = 5; + TFE_TensorHandle* retvals[5]; + + TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpSetAttrType(concatv2, "T", TF_FLOAT); + TFE_OpSetAttrInt(concatv2, "N", -1); + TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32); + + TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals, + &num_retvals, status_); + EXPECT_EQ(graph_op, nullptr); + EXPECT_EQ(status_->status.error_message(), + "Number attribute for length should be >=0!"); + + TFE_DeleteOp(concatv2); + 
TFE_DeleteTensorHandle(axis); + TFE_DeleteTensorHandle(matrix); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 03d65ecefd4..5a82cb0c48f 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/c/c_api_internal.h" - #include #include #include +#include "absl/strings/match.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -295,7 +295,8 @@ Status FillFunctionBody( } // Graph to FunctionDef conversion. This code is closely modeled on the Python -// code in tensorflow/python/framework/function.py. +// function graph_to_function_def(), which is located in +// tensorflow/python/framework/graph_to_function_def.py. Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, bool append_hash_to_fn_name, const std::vector& body_nodes, @@ -352,6 +353,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, argdef->set_type(node->output_type(idx)); const string& input_name = node_names.GetInputName(node->name()); argdef->set_name(input_name); + auto& arg_attrs = (*fdef->mutable_arg_attr())[i]; + for (const auto& attr : node->attrs()) { + // Only copy internal attributes. These attributes will be applied to + // _Arg/Placeholder nodes when this FunctionDef is converted to graph, and + // normal attributes for nodes cannot be applied to those _Arg/Placeholder + // nodes. + if (absl::StartsWith(attr.first, "_")) { + arg_attrs.mutable_attr()->insert(attr); + } + } tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name; } @@ -442,12 +453,21 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, } else { signature_name = control_outputs[i]->name(); } + if (signature_name.empty()) { + return errors::InvalidArgument("Control output name must be not empty"); + } if (!control_output_names_set.insert(signature_name).second) { return errors::InvalidArgument("Repeated control output name: ", signature_name); } + const string control_output_node = + node_names.Lookup(control_outputs[i]->name()); + if (control_output_node.empty()) { + return errors::InvalidArgument( + "Control output node name must be not empty"); + } fdef->mutable_signature()->add_control_output(signature_name); - (*fdef->mutable_control_ret())[signature_name] = control_outputs[i]->name(); + (*fdef->mutable_control_ret())[signature_name] = control_output_node; } return Status::OK(); @@ -572,13 +592,13 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( std::unordered_map> input_nodes; status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs, &input_tensors, &input_nodes); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; // Process outputs. std::vector output_tensors; status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs, outputs, &output_tensors); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; // Process output names. 
std::vector output_names_vec; @@ -602,7 +622,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( std::vector body_nodes; status->status = tensorflow::ComputeBodyNodes( fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); - if (!status->status.ok()) return nullptr; + if (TF_GetCode(status) != TF_OK) return nullptr; // Compute body nodes. std::vector control_output_nodes; @@ -617,7 +637,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes, input_tensors, output_tensors, output_names_vec, control_output_nodes, control_output_names_vec, description, &tf_function->fdef); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { TF_DeleteFunction(tf_function); return nullptr; } diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 946f8c4a2c3..760f14cac5b 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1278,6 +1278,46 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) { EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int"); } +void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name, + const char* attr_name, const char* attr_value, + TF_Operation** op) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); + TF_SetAttrType(desc, "dtype", TF_INT32); + TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value)); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Operation* node; + NodeWithAttrHelper(func_graph.get(), s.get(), "node", "_test_attr", "value", + &node); + + TF_Output inputs[] = {{node, 0}}; + TF_Output outputs[] = {}; + func_ = TF_GraphToFunction( + func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1, + /*opers=*/nullptr, 1, inputs, 0, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, /*description=*/nullptr, s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(func_, nullptr); + + // Verify that FunctionDef ArgDef has attributes. + ASSERT_EQ(func_->fdef.arg_attr_size(), 1); + auto arg_attrs = func_->fdef.arg_attr().find(0); + ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end()); + auto iter = arg_attrs->second.attr().find("_test_attr"); + ASSERT_NE(iter, arg_attrs->second.attr().end()); + EXPECT_EQ(iter->second.s(), "value"); +} + TEST_F(CApiFunctionTest, SetGradientAndRun) { // Define the function and its grad DefineFunction(func_name_, &func_); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 9a69c58718b..f02160044c5 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -24,8 +24,10 @@ limitations under the License. 
#include #include +// clang-format off // Required for IS_MOBILE_PLATFORM -#include "tensorflow/core/platform/platform.h" // NO_LINT +#include "tensorflow/core/platform/platform.h" +// clang-format on #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) #include "tensorflow/core/framework/op_gen_lib.h" diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index d3311f0cd06..1a92efb89bc 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -29,8 +29,7 @@ namespace checkpoint { class TensorSliceReader; -CheckpointReader::CheckpointReader(const string& filename, - TF_Status* out_status) +CheckpointReader::CheckpointReader(const string& filename, TF_Status* status) : reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_(nullptr), @@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename, v2_reader_.reset( new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); if (!v2_reader_->status().ok()) { - Set_TF_Status_from_Status(out_status, v2_reader_->status()); + Set_TF_Status_from_Status(status, v2_reader_->status()); return; } auto result = BuildV2VarMaps(); @@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename, } else { reader_.reset(new TensorSliceReader(filename)); if (!reader_->status().ok()) { - Set_TF_Status_from_Status(out_status, reader_->status()); + Set_TF_Status_from_Status(status, reader_->status()); return; } var_to_shape_map_.reset( diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 91654c8d4fb..0e613db7719 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -39,7 +39,7 @@ class TensorSliceReader; // variables. class CheckpointReader { public: - CheckpointReader(const string& filepattern, TF_Status* out_status); + CheckpointReader(const string& filename, TF_Status* status); bool HasTensor(const string& name) const; const string DebugString() const; diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 445b2cd2581..8c2be2af3e0 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -1,4 +1,5 @@ # Experimental extensions to the C API for eager execution of kernels. 
+ licenses(["notice"]) # Apache 2.0 load( @@ -70,6 +71,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/profiler/lib:profiler_eager_lib", "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core:gpu_runtime", ], @@ -110,6 +112,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/profiler/lib:profiler_eager_lib", "//tensorflow/core/profiler/lib:profiler_session", ], ) @@ -200,6 +203,7 @@ tf_cuda_library( "//conditions:default": [], }) + [ "@com_google_absl//absl/memory", + "//tensorflow/c:tf_status_helper", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", @@ -236,7 +240,6 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/profiler:protos_all_cc", "@com_google_absl//absl/strings", ], ) @@ -256,3 +259,22 @@ filegroup( srcs = ["c_api.h"], visibility = ["//tensorflow:__subpackages__"], ) + +# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime +# right now, remove this public rule when no longer needed (it should be +# replaced by TF Lite) +filegroup( + name = "srcs", + srcs = glob( + [ + "*.cc", + "*.h", + ], + exclude = [ + "c_api_experimental.cc", + "c_api_experimental.h", + "*test*", + ], + ), + visibility = ["//visibility:public"], +) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc old mode 100755 new mode 100644 index 9509135e239..9c2d1dd38fd --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,11 +21,18 @@ limitations under the License. #include #include +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +// clang-format on + #include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/platform/host_info.h" +#include "tensorflow/core/platform/platform.h" // NOLINT #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -38,11 +45,15 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/distributed_runtime/remote_device.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" +#endif // !IS_MOBILE_PLATFORM #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -63,6 +74,17 @@ using tensorflow::int64; using tensorflow::string; namespace { + +const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) { + if (op->inference_ctx) { + return op->inference_ctx->op_def; + } + const tensorflow::OpDef* op_def; + status->status = + tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def); + return op_def; +} + bool IsCPU(const tensorflow::Device* d) { return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; } @@ -77,6 +99,7 @@ string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? "cpu:0" : d->name(); } +#if !defined(IS_MOBILE_PLATFORM) tensorflow::Status GetAllRemoteDevices( const std::vector& remote_workers, tensorflow::WorkerCacheInterface* worker_cache, @@ -114,11 +137,12 @@ tensorflow::Status CreateRemoteContexts( const std::vector& remote_workers, int64 rendezvous_id, int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + const tensorflow::eager::CreateContextRequest& base_request, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { const string& remote_worker = remote_workers[i]; - tensorflow::eager::CreateContextRequest request; + tensorflow::eager::CreateContextRequest request(base_request); tensorflow::eager::CreateContextResponse response; request.set_rendezvous_id(rendezvous_id); tensorflow::DeviceNameUtils::ParsedName parsed_name; @@ -132,7 +156,9 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); request.set_keep_alive_secs(keep_alive_secs); - auto* eager_client = remote_eager_workers->GetClient(remote_worker); + tensorflow::eager::EagerClient* eager_client; + TF_RETURN_IF_ERROR( + remote_eager_workers->GetClient(remote_worker, &eager_client)); if (eager_client == nullptr) { return tensorflow::errors::Internal( "Cannot find a client for the given target:", remote_worker); @@ -198,6 +224,23 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( remote_workers, grpc_server->master_env()->worker_cache, &remote_device_mgr)); + std::vector cluster_device_attributes; + remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes); + + std::vector local_device_attributes; + grpc_server->worker_env()->device_mgr->ListDeviceAttributes( + &local_device_attributes); + + // This request make sure that we can create Rendevzous properly between + // Local and Remote context. 
+ tensorflow::eager::CreateContextRequest base_request; + for (const auto& da : cluster_device_attributes) { + *base_request.add_cluster_device_attributes() = da; + } + for (const auto& da : local_device_attributes) { + *base_request.add_cluster_device_attributes() = da; + } + std::shared_ptr channel_cache = grpc_server->channel_cache(); std::unique_ptr remote_eager_workers( @@ -207,14 +250,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( remote_workers, rendezvous_id, keep_alive_secs, server_def, - remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); + remote_eager_workers.get(), ctx->context->Async(), base_request, + &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( - session_name, server_def, true)); + session_name, server_def, base_request.cluster_device_attributes(), + true)); std::shared_ptr worker_session; TF_RETURN_IF_ERROR( @@ -226,14 +271,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - ctx->context.InitializeRemote(std::move(server), - std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts, - r, device_mgr, keep_alive_secs); - - return tensorflow::Status::OK(); + return ctx->context->InitializeRemote( + std::move(server), grpc_server->worker_env(), worker_session, + std::move(remote_eager_workers), std::move(remote_device_mgr), + remote_contexts, r, device_mgr, keep_alive_secs, + worker_session->cluster_flr.get()); #undef LOG_AND_RETURN_IF_ERROR } +#endif // !IS_MOBILE_PLATFORM tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op, TFE_TensorHandle* input) { @@ -330,7 +375,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - status->status = ctx->context.SetAsyncForThread(enable); + status->status = ctx->context->SetAsyncForThread(enable); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -349,7 +394,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { return new TFE_Context(opts->session_options.options, opts->policy, opts->async, device_mgr.release(), - /*device_mgr_owned*/ true, r); + /*device_mgr_owned*/ true, r, + tensorflow::GetDefaultCustomKernelCreator()); } TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, @@ -359,23 +405,24 @@ TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, if (!status->status.ok()) return nullptr; tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr); + return new TFE_Context(opts->session_options.options, opts->policy, - opts->async, device_mgr, /*device_mgr_owned*/ false, - r); + opts->async, device_mgr, /*device_mgr_owned*/ false, r, + tensorflow::GetDefaultCustomKernelCreator()); } void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); - if (ctx->context.remote_device_mgr()) { - ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); + 
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response); + if (ctx->context->remote_device_mgr()) { + ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response); } return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } +void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); } // Set server_def on the context, possibly updating it. TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, @@ -383,6 +430,10 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, const void* proto, size_t proto_len, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) + status->status = tensorflow::errors::Unimplemented( + "TFE_ContextSetServerDef not supported on mobile"); +#else // !defined(IS_MOBILE_PLATFORM) tensorflow::ServerDef server_def; if (!server_def.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( @@ -391,11 +442,12 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, } status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); +#endif // !IS_MOBILE_PLATFORM } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - ctx->context.SetThreadLocalDevicePlacementPolicy( + ctx->context->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -405,19 +457,19 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { return static_cast( - ctx->context.GetDevicePlacementPolicy()); + ctx->context->GetDevicePlacementPolicy()); } void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.AsyncWait(); + status->status = ctx->context->AsyncWait(); } void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context.GetStatus(); + status->status = ctx->context->GetStatus(); } void TFE_ContextAsyncClearError(TFE_Context* ctx) { - ctx->context.ClearAsyncError(); + ctx->context->ClearAsyncError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { @@ -577,7 +629,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, return new TFE_Op(ctx, name, false, types, new TFE_OpInferenceContext(op_def)); } - if (!ctx->context.FindFunctionByName(name)) { + if (!ctx->context->FindFunctionByName(name)) { status->status = tensorflow::errors::NotFound( "'", name, "' is neither a type of a primitive operation nor a name " @@ -807,6 +859,54 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, funcs.get(), num_values)); } +TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op, + const char* input_name, + TF_Status* status) { + const tensorflow::OpDef* op_def = GetOpDef(op, status); + if (!status->status.ok()) { + return -1; + } + tensorflow::AttrValueMap attrs; + op->operation.Attrs().FillAttrValueMap(&attrs); + tensorflow::NameRangeMap name_ranges; + status->status = tensorflow::NameRangesForNode( + tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr); + if (!status->status.ok()) { + return -1; + } + auto iter = name_ranges.find(input_name); + if (iter == name_ranges.end()) { + status->status = tensorflow::errors::InvalidArgument("Input '", input_name, + "' not found"); + return -1; + } + return iter->second.second - iter->second.first; +} + +TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, + const char* output_name, + 
TF_Status* status) { + const tensorflow::OpDef* op_def = GetOpDef(op, status); + if (!status->status.ok()) { + return -1; + } + tensorflow::AttrValueMap attrs; + op->operation.Attrs().FillAttrValueMap(&attrs); + tensorflow::NameRangeMap name_ranges; + status->status = tensorflow::NameRangesForNode( + tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges); + if (!status->status.ok()) { + return -1; + } + auto iter = name_ranges.find(output_name); + if (iter == name_ranges.end()) { + status->status = tensorflow::errors::InvalidArgument( + "Output '", output_name, "' not found"); + return -1; + } + return iter->second.second - iter->second.first; +} + void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { VLOG(1) << "Calling TFE_Execute() on op " << op; @@ -827,7 +927,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, const char* device_name, TF_Status* status) { tensorflow::TensorHandle* handle; - status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, + status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context, device_name, &handle); if (status->status.ok()) { return new TFE_TensorHandle(handle); @@ -844,26 +944,31 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - status->status = ctx->context.AddFunctionDef(function_def); + status->status = ctx->context->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - status->status = ctx->context.AddFunctionDef(function->fdef); + status->status = ctx->context->AddFunctionDef(function->fdef); +} + +void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, + TF_Status* status) { + status->status = ctx->context->RemoveFunction(name); } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { - return ctx->context.FindFunctionDef(name) != nullptr; + return ctx->context->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(true); - ctx->context.SetShouldStoreStepStats(true); + ctx->context->SetShouldStoreGraphs(true); + ctx->context->SetShouldStoreStepStats(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(false); - ctx->context.SetShouldStoreStepStats(false); + ctx->context->SetShouldStoreGraphs(false); + ctx->context->SetShouldStoreStepStats(false); } } // extern "C" @@ -892,9 +997,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); - status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); - ctx->context.ClearRunMetadata(); + tensorflow::mutex_lock ml(*ctx->context->MetadataMu()); + status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf); + ctx->context->ClearRunMetadata(); } namespace { @@ -910,9 +1015,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace -void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); } -void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); } namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, 
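// --- Illustrative sketch (not part of this patch) ---------------------------
// Usage of the new TFE_OpGetInputLength / TFE_OpGetOutputLength helpers
// implemented above. The lengths are derived from the op's attributes via
// NameRangesForNode, so for ConcatV2 with N = 2 the "values" input list has
// length 2 even before any tensors are attached, and "output" has length 1.
// Error handling is reduced to a status check; this mirrors the ConcatV2
// setup used in the tests earlier in this change.
#include "tensorflow/c/eager/c_api.h"

int ConcatValuesLength(TFE_Context* ctx, TF_Status* status) {
  TFE_Op* concat = TFE_NewOp(ctx, "ConcatV2", status);
  if (TF_GetCode(status) != TF_OK) return -1;
  TFE_OpSetAttrType(concat, "T", TF_FLOAT);
  TFE_OpSetAttrType(concat, "Tidx", TF_INT32);
  TFE_OpSetAttrInt(concat, "N", 2);
  // "values" is ConcatV2's list-typed input; with N = 2 this reports 2.
  int num_values = TFE_OpGetInputLength(concat, "values", status);
  // "output" is its single output; this reports 1.
  int num_outputs = TFE_OpGetOutputLength(concat, "output", status);
  (void)num_outputs;
  TFE_DeleteOp(concat);
  return num_values;
}
// -----------------------------------------------------------------------------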
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index ce3da7f9189..d5223e63f13 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -366,6 +366,18 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, const TFE_Op** value, int num_values); +// Returns the length (number of tensors) of the input argument `input_name` +// found in the provided `op`. +TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op, + const char* input_name, + TF_Status* status); + +// Returns the length (number of tensors) of the output argument `output_name` +// found in the provided `op`. +TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, + const char* output_name, + TF_Status* status); + // Execute the operation defined by 'op' and return handles to computed // tensors in `retvals`. // @@ -398,6 +410,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status); +// Removes a function from the context. Once removed, you can no longer +// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any +// other function which calls it as an attribute. +TF_CAPI_EXPORT extern void TFE_ContextRemoveFunction(TFE_Context* ctx, + const char* name, + TF_Status* status); + // Checks whether a function is registered under `name`. TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name); diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index ffcd5ace0b9..b4192716c4f 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -32,13 +32,13 @@ std::vector TensorShapeAsVector(TFE_TensorHandle* handle, TF_Status* status) { std::vector shape; int rank = TFE_TensorHandleNumDims(handle, status); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { return shape; } shape.reserve(rank); for (int i = 0; i < rank; ++i) { shape.push_back(TFE_TensorHandleDim(handle, i, status)); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { return shape; } } @@ -53,7 +53,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( TFE_TensorHandle* handle, TF_Status* status) { const tensorflow::Tensor* tensor; status->status = handle->handle->Tensor(&tensor); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { return nullptr; } @@ -139,7 +139,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( // If the tensor is not an XLA tensor, the device shape is // the same as regular tensor shape. std::vector dev_dims = TensorShapeAsVector(handle, status); - if (!status->status.ok()) { + if (TF_GetCode(status) != TF_OK) { return nullptr; } return new TFE_TensorDebugInfo(dev_dims); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index c6a12247ef1..0c170ead40a 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -17,6 +17,12 @@ limitations under the License. 
#include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" @@ -39,7 +45,7 @@ void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; } void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; string content; status->status = profiler->profiler->SerializeToString(&content); void* data = tensorflow::port::Malloc(content.length()); @@ -57,7 +63,7 @@ TFE_ProfilerContext* TFE_NewProfilerContext() { void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context, TFE_Context* eager_context) { - profiler_context->profiler_context.eager_context = &eager_context->context; + profiler_context->profiler_context.eager_context = eager_context->context; } void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) { @@ -71,11 +77,11 @@ void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) { } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(true); + ctx->context->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { - ctx->context.SetShouldStoreGraphs(false); + ctx->context->SetShouldStoreGraphs(false); } bool TFE_ProfilerClientStartTracing(const char* service_addr, @@ -92,3 +98,423 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr, num_tracing_attempts); return s.ok(); } + +void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell, + int64_t value) { + cell->cell.IncrementBy(value); +} + +int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) { + return cell->cell.value(); +} + +TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name, + TF_Status* status, + const char* description) { + auto* result = new TFE_MonitoringCounter0({name, description}); + Set_TF_Status_from_Status(status, result->counter->GetStatus()); + if (!result->counter->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) { + delete counter; +} + +TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0( + TFE_MonitoringCounter0* counter) { + return static_cast( + static_cast(counter->counter->GetCell())); +} + +TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name, + TF_Status* status, + const char* description, + const char* label1) { + auto* result = new TFE_MonitoringCounter1({name, description, label1}); + Set_TF_Status_from_Status(status, result->counter->GetStatus()); + if (!result->counter->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) { + delete counter; +} + +TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1( + TFE_MonitoringCounter1* counter, const char* label1) { + return static_cast( + static_cast(counter->counter->GetCell(label1))); +} + +TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name, + TF_Status* status, 
+ const char* description, + const char* label1, + const char* label2) { + auto* result = + new TFE_MonitoringCounter2({name, description, label1, label2}); + Set_TF_Status_from_Status(status, result->counter->GetStatus()); + if (!result->counter->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) { + delete counter; +} + +TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2( + TFE_MonitoringCounter2* counter, const char* label1, const char* label2) { + return static_cast( + static_cast(counter->counter->GetCell(label1, label2))); +} + +void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell, + int64_t value) { + cell->cell.Set(value); +} + +int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) { + return cell->cell.value(); +} + +TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name, + TF_Status* status, + const char* description) { + auto* result = new TFE_MonitoringIntGauge0({name, description}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) { + delete gauge; +} + +TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0( + TFE_MonitoringIntGauge0* gauge) { + return static_cast( + static_cast(gauge->gauge->GetCell())); +} + +TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name, + TF_Status* status, + const char* description, + const char* label1) { + auto* result = new TFE_MonitoringIntGauge1({name, description, label1}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) { + delete gauge; +} + +TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1( + TFE_MonitoringIntGauge1* gauge, const char* label1) { + return static_cast( + static_cast(gauge->gauge->GetCell(label1))); +} + +TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name, + TF_Status* status, + const char* description, + const char* label1, + const char* label2) { + auto* result = + new TFE_MonitoringIntGauge2({name, description, label1, label2}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) { + delete gauge; +} + +TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2( + TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) { + return static_cast( + static_cast(gauge->gauge->GetCell(label1, label2))); +} + +void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell, + const char* value) { + cell->cell.Set({value}); +} + +const void TFE_MonitoringStringGaugeCellValue( + TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) { + tensorflow::string value = cell->cell.value(); + void* data = tensorflow::port::Malloc(value.length()); + value.copy(static_cast(data), value.length(), 0); + buf->data = data; + buf->length = value.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; +} + +TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0( + const char* name, TF_Status* status, const char* description) { 
+ auto* result = new TFE_MonitoringStringGauge0({name, description}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) { + delete gauge; +} + +TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0( + TFE_MonitoringStringGauge0* gauge) { + return static_cast( + static_cast(gauge->gauge->GetCell())); +} + +TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1( + const char* name, TF_Status* status, const char* description, + const char* label1) { + auto* result = new TFE_MonitoringStringGauge1({name, description, label1}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) { + delete gauge; +} + +TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1( + TFE_MonitoringStringGauge1* gauge, const char* label1) { + return static_cast( + static_cast(gauge->gauge->GetCell(label1))); +} + +TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2( + const char* name, TF_Status* status, const char* description, + const char* label1, const char* label2) { + auto* result = + new TFE_MonitoringStringGauge2({name, description, label1, label2}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) { + delete gauge; +} + +TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2( + TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) { + return static_cast( + static_cast(gauge->gauge->GetCell(label1, label2))); +} + +void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell, + bool value) { + cell->cell.Set(value); +} + +bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) { + return cell->cell.value(); +} + +TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name, + TF_Status* status, + const char* description) { + auto* result = new TFE_MonitoringBoolGauge0({name, description}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) { + delete gauge; +} + +TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0( + TFE_MonitoringBoolGauge0* gauge) { + return static_cast( + static_cast(gauge->gauge->GetCell())); +} + +TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name, + TF_Status* status, + const char* description, + const char* label1) { + auto* result = new TFE_MonitoringBoolGauge1({name, description, label1}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) { + delete gauge; +} + +TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1( + TFE_MonitoringBoolGauge1* gauge, const char* label1) { + return static_cast( + static_cast(gauge->gauge->GetCell(label1))); +} + +TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name, + 
TF_Status* status, + const char* description, + const char* label1, + const char* label2) { + auto* result = + new TFE_MonitoringBoolGauge2({name, description, label1, label2}); + Set_TF_Status_from_Status(status, result->gauge->GetStatus()); + if (!result->gauge->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) { + delete gauge; +} + +TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2( + TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) { + return static_cast( + static_cast(gauge->gauge->GetCell(label1, label2))); +} + +void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell, + double value) { + cell->cell.Add(value); +} + +void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell, + TF_Buffer* buf) { + string content; + cell->cell.value().SerializeToString(&content); + void* data = tensorflow::port::Malloc(content.length()); + content.copy(static_cast(data), content.length(), 0); + buf->data = data; + buf->length = content.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; +} + +TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale, + double growth_factor, + int bucket_count) { + return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() { + return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor, + bucket_count); + }); +} + +void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) { + delete buckets; +} + +TFE_MonitoringSampler0* TFE_MonitoringNewSampler0( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status, + const char* description) { + auto* result = new TFE_MonitoringSampler0( + {name, buckets->create_buckets(), description}); + Set_TF_Status_from_Status(status, result->sampler->GetStatus()); + if (!result->sampler->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) { + delete sampler; +} + +TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0( + TFE_MonitoringSampler0* sampler) { + return static_cast( + static_cast(sampler->sampler->GetCell())); +} + +TFE_MonitoringSampler1* TFE_MonitoringNewSampler1( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status, + const char* description, const char* label1) { + auto* result = new TFE_MonitoringSampler1( + {name, buckets->create_buckets(), description, label1}); + Set_TF_Status_from_Status(status, result->sampler->GetStatus()); + if (!result->sampler->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) { + delete sampler; +} + +TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1( + TFE_MonitoringSampler1* sampler, const char* label1) { + return static_cast( + static_cast(sampler->sampler->GetCell(label1))); +} + +TFE_MonitoringSampler2* TFE_MonitoringNewSampler2( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status, + const char* description, const char* label1, const char* label2) { + auto* result = new TFE_MonitoringSampler2( + {name, buckets->create_buckets(), description, label1, label2}); + Set_TF_Status_from_Status(status, result->sampler->GetStatus()); + if (!result->sampler->GetStatus().ok()) { + delete result; + return nullptr; + } + return result; +} + +void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) 
{ + delete sampler; +} + +TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( + TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) { + return static_cast( + static_cast(sampler->sampler->GetCell(label1, label2))); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 219b9f40720..4dc57e1eec5 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -87,6 +87,229 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing( const char* service_addr, const char* logdir, const char* worker_list, bool include_dataset_ops, int duration_ms, int num_tracing_attempts); +// TODO(fishx): Move these monitoring APIs into a separate file. +// ----------------------------------------------------------------------------- +// Monitoring Counter APIs. +// These APIs de-templated monitoring Counter for swig. + +typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell; + +// Atomically increments the value of the cell. The value must be non-negative. +TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy( + TFE_MonitoringCounterCell* cell, int64_t value); + +// Retrieves the current value of the cell. +TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue( + TFE_MonitoringCounterCell* cell); + +// APIs for Counter without label. +typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0; +// Returns a new Counter metric object. The caller should manage lifetime of +// the object. Using duplicate metric name will crash the program with fatal +// error. +TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0( + const char* name, TF_Status* status, const char* description); +// Deletes the Counter object. +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0( + TFE_MonitoringCounter0* counter); +// Retrieves the cell from the Counter object. The Counter object will manage +// lifetime of the cell. +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0( + TFE_MonitoringCounter0* counter); + +// APIs for Counter with 1 label. +typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1; +TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1( + const char* name, TF_Status* status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1( + TFE_MonitoringCounter1* counter); +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1( + TFE_MonitoringCounter1* counter, const char* label1); + +// APIs for Counter with 2 labels. +typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2; +TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2( + const char* name, TF_Status* status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2( + TFE_MonitoringCounter2* counter); +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2( + TFE_MonitoringCounter2* counter, const char* label1, const char* label2); + +// ----------------------------------------------------------------------------- +// Monitoring Gauge APIs. +// These APIs de-templated monitoring Gauge for swig. + +typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell; + +// Atomically set the value of the cell. 
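// --- Illustrative sketch (not part of this patch) ---------------------------
// A short usage example for the de-templated monitoring Counter API declared
// here, following the pattern exercised by the tests in
// c_api_experimental_test.cc. The metric name and label value are illustrative
// only; note that registering a duplicate metric name is a fatal error, cells
// are owned by the counter, and the caller owns the counter object.
#include <stdint.h>
#include "tensorflow/c/eager/c_api_experimental.h"

void CountGraphBuilds() {
  TF_Status* status = TF_NewStatus();
  TFE_MonitoringCounter1* counter = TFE_MonitoringNewCounter1(
      "/example/graph_builds", status, "Number of graphs built.", "result");
  if (TF_GetCode(status) == TF_OK) {
    // Fetch (or create) the cell for one label value and bump it.
    TFE_MonitoringCounterCell* ok_cell =
        TFE_MonitoringGetCellCounter1(counter, "ok");
    TFE_MonitoringCounterCellIncrementBy(ok_cell, 1);
    // Reads back 1 here; the collection registry observes the same value.
    int64_t total = TFE_MonitoringCounterCellValue(ok_cell);
    (void)total;
  }
  TFE_MonitoringDeleteCounter1(counter);
  TF_DeleteStatus(status);
}
// -----------------------------------------------------------------------------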
+TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet( + TFE_MonitoringIntGaugeCell* cell, int64_t value); + +// Retrieves the current value of the cell. +TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue( + TFE_MonitoringIntGaugeCell* cell); + +// APIs for Int Gauge without label. +typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0( + TFE_MonitoringIntGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge); + +// APIs for Int Gauge with 1 label. +typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1( + TFE_MonitoringIntGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge, + const char* label1); + +// APIs for Int Gauge with 2 label. +typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2( + TFE_MonitoringIntGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge, + const char* label1, const char* label2); + +typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell; +TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet( + TFE_MonitoringStringGaugeCell* cell, const char* value); +// Retrieves the string value and saves it in buffer. +TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue( + TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf); + +// APIs for String Gauge without label. +typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0( + TFE_MonitoringStringGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge); + +// APIs for String Gauge with 1 label. +typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1( + TFE_MonitoringStringGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge, + const char* label1); + +// APIs for String Gauge with 2 label. 
+typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2( + TFE_MonitoringStringGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge, + const char* label1, const char* label2); + +typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell; +TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet( + TFE_MonitoringBoolGaugeCell* cell, bool value); +TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue( + TFE_MonitoringBoolGaugeCell* cell); + +// APIs for Bool Gauge without label. +typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0( + TFE_MonitoringBoolGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge); + +// APIs for Bool Gauge with 1 label. +typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1( + TFE_MonitoringBoolGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge, + const char* label1); + +// APIs for Bool Gauge with 2 label. +typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2( + TFE_MonitoringBoolGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge, + const char* label1, const char* label2); + +// ----------------------------------------------------------------------------- +// Monitoring Sampler APIs. +// These APIs de-templated monitoring Sampler for swig. + +typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell; + +// Atomically add the value of the cell. +TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd( + TFE_MonitoringSamplerCell* cell, double value); + +// Retrieves the current value of the cell. The return value is a HistogramProto +// saved in buffer. +TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue( + TFE_MonitoringSamplerCell* cell, TF_Buffer* buf); + +// APIs for sampler buckets +typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets; +TF_CAPI_EXPORT extern TFE_MonitoringBuckets* +TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor, + int bucket_count); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets( + TFE_MonitoringBuckets* buckets); + +// APIs for Sampler without label. 
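// --- Illustrative sketch (not part of this patch) ---------------------------
// A brief example of the sampler half of the monitoring API: exponential
// buckets feed a Sampler whose cells accumulate a histogram that can be read
// back as a serialized HistogramProto through a TF_Buffer. The metric name,
// label, and bucket parameters below are illustrative only.
#include "tensorflow/c/eager/c_api_experimental.h"

void RecordStepTimeMs(double step_time_ms) {
  TF_Status* status = TF_NewStatus();
  // Bucket boundaries 1, 2, 4, ... over 20 buckets (scale 1.0, growth 2.0).
  TFE_MonitoringBuckets* buckets =
      TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 20);
  TFE_MonitoringSampler1* sampler = TFE_MonitoringNewSampler1(
      "/example/step_time_ms", buckets, status, "Step time in ms.", "device");
  if (TF_GetCode(status) == TF_OK) {
    TFE_MonitoringSamplerCell* cell =
        TFE_MonitoringGetCellSampler1(sampler, "cpu");
    TFE_MonitoringSamplerCellAdd(cell, step_time_ms);
    // The cell's current histogram can be fetched as a serialized
    // HistogramProto if needed; TF_DeleteBuffer frees the copied bytes.
    TF_Buffer* histogram = TF_NewBuffer();
    TFE_MonitoringSamplerCellValue(cell, histogram);
    TF_DeleteBuffer(histogram);
  }
  TFE_MonitoringDeleteSampler1(sampler);
  TFE_MonitoringDeleteBuckets(buckets);
  TF_DeleteStatus(status);
}
// -----------------------------------------------------------------------------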
+typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0; +TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0( + TFE_MonitoringSampler0* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0( + TFE_MonitoringSampler0* sampler); + +// APIs for Sampler with 1 label. +typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1; +TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description, const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1( + TFE_MonitoringSampler1* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1( + TFE_MonitoringSampler1* sampler, const char* label1); + +// APIs for Sampler with 2 label. +typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2; +TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description, const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2( + TFE_MonitoringSampler2* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( + TFE_MonitoringSampler2* sampler, const char* label1, const char* label2); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index d85048caa7c..4e48a7591a9 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -16,14 +16,16 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" #include + #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/cc/profiler/profiler.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/profiler/trace_events.pb.h" +#include "tensorflow/core/protobuf/trace_events.pb.h" using tensorflow::string; @@ -79,11 +81,15 @@ void ExecuteWithProfiling(bool async) { profiler_result->length})); string profile_proto_str = profile_proto.DebugString(); if (!gpu_device_name.empty()) { - EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0")); + EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0")); // device name with "stream:all" is collected by Device Tracer. EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all")); + // TODO(fishx): move following check out from this if statement. 
+ // This is collected by TraceMe + EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU")); } - EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0")); + EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:CPU:0")); + EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul")); TF_DeleteBuffer(profiler_result); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); @@ -125,5 +131,165 @@ TEST(CAPI, MultipleProfilerSession) { TFE_DeleteProfilerContext(profiler_context); } +TEST(CAPI, MonitoringCounter0) { + TF_Status* status = TF_NewStatus(); + auto* counter = + TFE_MonitoringNewCounter0("test/counter", status, "description"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + auto* cell = TFE_MonitoringGetCellCounter0(counter); + TFE_MonitoringCounterCellIncrementBy(cell, 1); + EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 1); + auto* collection_registry = monitoring::CollectionRegistry::Default(); + monitoring::CollectionRegistry::CollectMetricsOptions options; + std::unique_ptr metrics = + collection_registry->CollectMetrics(options); + + EXPECT_EQ("test/counter", + metrics->point_set_map.at("test/counter")->metric_name); + EXPECT_EQ( + 1, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value); + + TFE_MonitoringCounterCellIncrementBy(cell, 5); + EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 6); + metrics = collection_registry->CollectMetrics(options); + EXPECT_EQ( + 6, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value); + + TFE_MonitoringDeleteCounter0(counter); + metrics = collection_registry->CollectMetrics(options); + EXPECT_EQ(metrics->point_set_map.end(), + metrics->point_set_map.find("test/counter")); +} + +TEST(CAPI, MonitoringCounterMultiple) { + TF_Status* status = TF_NewStatus(); + auto* counter1 = TFE_MonitoringNewCounter1("test/counter1", status, + "description", "label1"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* cell1 = TFE_MonitoringGetCellCounter1(counter1, "test"); + TFE_MonitoringCounterCellIncrementBy(cell1, 1); + EXPECT_EQ(TFE_MonitoringCounterCellValue(cell1), 1); + + auto* counter2 = TFE_MonitoringNewCounter2("test/counter2", status, + "description", "label1", "label2"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + auto* cell2 = TFE_MonitoringGetCellCounter2(counter2, "foo", "bar"); + TFE_MonitoringCounterCellIncrementBy(cell2, 2); + EXPECT_EQ(TFE_MonitoringCounterCellValue(cell2), 2); + + TFE_MonitoringDeleteCounter1(counter1); + TFE_MonitoringDeleteCounter2(counter2); +} + +TEST(CAPI, MonitoringGauge0) { + TF_Status* status = TF_NewStatus(); + auto* gauge = TFE_MonitoringNewIntGauge0("test/gauge", status, "test"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* cell = TFE_MonitoringGetCellIntGauge0(gauge); + TFE_MonitoringIntGaugeCellSet(cell, 1); + EXPECT_EQ(TFE_MonitoringIntGaugeCellValue(cell), 1); + auto* collection_registry = monitoring::CollectionRegistry::Default(); + monitoring::CollectionRegistry::CollectMetricsOptions options; + std::unique_ptr metrics = + collection_registry->CollectMetrics(options); + + EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name); + EXPECT_EQ(1, + metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value); + + TFE_MonitoringIntGaugeCellSet(cell, 5); + metrics = collection_registry->CollectMetrics(options); + EXPECT_EQ(5, + metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value); + TF_DeleteStatus(status); +} + +TEST(CAPI, 
MonitoringMultipleGauge) { + TF_Status* status = TF_NewStatus(); + auto* gauge1 = + TFE_MonitoringNewBoolGauge1("test/gauge1", status, "test", "label1"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* cell1 = TFE_MonitoringGetCellBoolGauge1(gauge1, "foo"); + TFE_MonitoringBoolGaugeCellSet(cell1, true); + EXPECT_TRUE(TFE_MonitoringBoolGaugeCellValue(cell1)); + + auto* gauge2 = TFE_MonitoringNewStringGauge2("test/gauge2", status, "test", + "label1", "label2"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* cell2 = TFE_MonitoringGetCellStringGauge2(gauge2, "foo", "bar"); + TFE_MonitoringStringGaugeCellSet(cell2, "str"); + auto* buf = new TF_Buffer; + TFE_MonitoringStringGaugeCellValue(cell2, buf); + string data(static_cast(buf->data), buf->length); + delete buf; + EXPECT_EQ(data, "str"); + TF_DeleteStatus(status); +} + +TEST(CAPI, MonitoringSampler0) { + TF_Status* status = TF_NewStatus(); + auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2); + auto* sampler = + TFE_MonitoringNewSampler0("test/sampler", buckets, status, "test"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* cell = TFE_MonitoringGetCellSampler0(sampler); + TFE_MonitoringSamplerCellAdd(cell, 1.0); + auto* collection_registry = monitoring::CollectionRegistry::Default(); + monitoring::CollectionRegistry::CollectMetricsOptions options; + std::unique_ptr metrics = + collection_registry->CollectMetrics(options); + + EXPECT_EQ("test/sampler", + metrics->point_set_map.at("test/sampler")->metric_name); + EXPECT_EQ(1.0, metrics->point_set_map.at("test/sampler") + ->points.at(0) + ->histogram_value.sum()); + + TFE_MonitoringSamplerCellAdd(cell, 5.0); + metrics = collection_registry->CollectMetrics(options); + EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler") + ->points.at(0) + ->histogram_value.sum()); + TFE_MonitoringDeleteBuckets(buckets); + TF_DeleteStatus(status); +} + +TEST(CAPI, MonitoringMultipleSampler) { + TF_Status* status = TF_NewStatus(); + auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2); + auto* sampler1 = TFE_MonitoringNewSampler1("test/sampler1", buckets, status, + "test", "label1"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* cell1 = TFE_MonitoringGetCellSampler1(sampler1, "foo"); + TFE_MonitoringSamplerCellAdd(cell1, 1.0); + TFE_MonitoringSamplerCellAdd(cell1, 2.0); + TF_Buffer* result1 = TF_NewBuffer(); + TFE_MonitoringSamplerCellValue(cell1, result1); + tensorflow::HistogramProto hitogram1; + EXPECT_TRUE(hitogram1.ParseFromString( + {reinterpret_cast(result1->data), result1->length})); + EXPECT_EQ(hitogram1.sum(), 3.0); + delete result1; + + auto* sampler2 = TFE_MonitoringNewSampler2("test/sampler2", buckets, status, + "test", "label1", "label2"); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* cell2 = TFE_MonitoringGetCellSampler2(sampler2, "foo", "bar"); + TFE_MonitoringSamplerCellAdd(cell2, 2.0); + TFE_MonitoringSamplerCellAdd(cell2, 3.0); + TF_Buffer* result2 = TF_NewBuffer(); + TFE_MonitoringSamplerCellValue(cell2, result2); + tensorflow::HistogramProto hitogram2; + EXPECT_TRUE(hitogram2.ParseFromString( + {reinterpret_cast(result2->data), result2->length})); + EXPECT_EQ(hitogram2.sum(), 5.0); + delete result2; + + TFE_MonitoringDeleteBuckets(buckets); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 35dafb9a7f1..061b0e5adcd 100644 --- 
a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -15,8 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ -#include "tensorflow/c/eager/c_api.h" - #include #include #include @@ -28,6 +26,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/c_api.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -37,19 +36,14 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/eager/eager_client.h" -#include "tensorflow/core/distributed_runtime/remote_device.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" -#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/lib/profiler_session.h" @@ -66,13 +60,18 @@ struct TFE_Context { TFE_Context(const tensorflow::SessionOptions& opts, TFE_ContextDevicePlacementPolicy default_policy, bool async, const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, - tensorflow::Rendezvous* rendezvous) - : context(opts, - static_cast( - default_policy), - async, device_mgr, device_mgr_owned, rendezvous) {} + tensorflow::Rendezvous* rendezvous, + const tensorflow::CustomKernelCreator* custom_kernel_creator) + : context(new tensorflow::EagerContext( + opts, + static_cast( + default_policy), + async, device_mgr, device_mgr_owned, rendezvous, + custom_kernel_creator)) {} - tensorflow::EagerContext context; + ~TFE_Context() { context->Unref(); } + + tensorflow::EagerContext* context; }; struct TFE_TensorHandle { @@ -112,7 +111,7 @@ struct TFE_Op { TFE_Op(TFE_Context* ctx, const char* op, bool is_function, const tensorflow::AttrTypeMap* t, TFE_OpInferenceContext* inference_ctx) - : operation(&ctx->context, op, is_function, t), + : operation(ctx->context, op, is_function, t), inference_ctx(inference_ctx) {} tensorflow::EagerOperation operation; @@ -131,6 +130,124 @@ struct TFE_Profiler { std::unique_ptr profiler; }; +struct TFE_MonitoringCounterCell { + tensorflow::monitoring::CounterCell cell; +}; + +template +struct TFE_MonitoringCounter { + template + TFE_MonitoringCounter(const char* name, const char* description, + LabelDesc&&... 
label) { + counter = absl::WrapUnique(tensorflow::monitoring::Counter::New( + name, description, label...)); + } + + std::unique_ptr> counter; +}; + +struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; +struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; +struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; + +struct TFE_MonitoringIntGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; +struct TFE_MonitoringStringGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; +struct TFE_MonitoringBoolGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; + +template +struct TFE_MonitoringGauge { + template + TFE_MonitoringGauge(const char* name, const char* description, + LabelDesc&&... label) { + gauge = absl::WrapUnique( + tensorflow::monitoring::Gauge::New( + name, description, label...)); + } + + std::unique_ptr> gauge; +}; + +struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringBuckets { + TFE_MonitoringBuckets( + std::function(void)> + fn) { + create_buckets = fn; + } + + std::function(void)> + create_buckets; +}; + +struct TFE_MonitoringSamplerCell { + tensorflow::monitoring::SamplerCell cell; +}; + +template +struct TFE_MonitoringSampler { + template + TFE_MonitoringSampler( + const char* name, + std::unique_ptr buckets, + const char* description, LabelDesc&&... label) { + sampler = absl::WrapUnique(tensorflow::monitoring::Sampler::New( + {name, description, label...}, std::move(buckets))); + } + + std::unique_ptr> sampler; +}; + +struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; +struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; +struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; + namespace tensorflow { // Set an AttrValue on the op. Doesn't handle the list types. void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index b5e55420016..57aa71d5b3b 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -14,10 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_internal.h" #include + #include "absl/strings/match.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" @@ -297,6 +298,61 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } +void TestRemoteExecuteDeleteTensorAfterContext(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContext(ctx); + + // Delete tensors after context is deleted. + TFE_DeleteTensorHandle(h0_task1); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. 
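// -----------------------------------------------------------------------------
// Editorial example (not part of this patch): a minimal sketch of the
// two-label counter API exercised by the monitoring tests earlier in this
// patch. Only functions appearing in those tests are used; the metric name,
// label keys, and the CountRequest helper are hypothetical.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

void CountRequest(const char* method, const char* outcome) {
  TF_Status* status = TF_NewStatus();
  // "example/requests" and the label keys "method"/"outcome" are hypothetical.
  TFE_MonitoringCounter2* counter = TFE_MonitoringNewCounter2(
      "example/requests", status, "Number of requests.", "method", "outcome");
  if (TF_GetCode(status) == TF_OK) {
    TFE_MonitoringCounterCell* cell =
        TFE_MonitoringGetCellCounter2(counter, method, outcome);
    TFE_MonitoringCounterCellIncrementBy(cell, 1);
    // Deleting the counter unregisters the metric, as the MonitoringCounter0
    // test above demonstrates; long-lived code would keep the counter alive.
    TFE_MonitoringDeleteCounter2(counter);
  }
  TF_DeleteStatus(status);
}
// --- end of editorial example ------------------------------------------------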
+ worker_server.release(); +} + +TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) { + TestRemoteExecuteDeleteTensorAfterContext(false); +} +TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) { + TestRemoteExecuteDeleteTensorAfterContext(true); +} + void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, const std::vector& expected_values) { std::unique_ptr status( @@ -1225,6 +1281,8 @@ TEST(CAPI, Function_ident_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } + TFE_ContextRemoveFunction(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -1295,6 +1353,8 @@ TEST(CAPI, Function_ident_XLA_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } + TFE_ContextRemoveFunction(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -1371,6 +1431,8 @@ void FunctionDefAndExecute(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); + TFE_ContextRemoveFunction(ctx, "MatMulFunction", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -1412,6 +1474,8 @@ void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); + TFE_ContextRemoveFunction(ctx, "MatMulFunction", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -1781,4 +1845,80 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) { TFE_DeleteTensorHandle(dim); TFE_DeleteContext(ctx); } + +TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input1 = TestMatrixTensorHandle(); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(); + TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Try to retrieve lengths before building the attributes (should fail) + EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* inputs[] = {input1, input2}; + TFE_OpAddInputList(identityOp, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Try to retrieve lengths before executing the op (should work) + EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; + TFE_Execute(identityOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Try 
to retrieve lengths after executing the op (should work) + EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + TFE_DeleteOp(identityOp); + TFE_DeleteTensorHandle(input1); + TFE_DeleteTensorHandle(input2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(retvals[1]); + TFE_DeleteContext(ctx); +} + +TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input1 = TestMatrixTensorHandle(); + TFE_TensorHandle* input2 = TestMatrixTensorHandle(); + TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* inputs[] = {input1, input2}; + TFE_OpAddInputList(identityOp, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "cheese", status)); + CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "cheese", status)); + CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + TFE_DeleteOp(identityOp); + TFE_DeleteTensorHandle(input1); + TFE_DeleteTensorHandle(input2); + TFE_DeleteContext(ctx); +} + } // namespace diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 5c11f51e874..1e0112894c5 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -47,11 +47,12 @@ struct OpTapeEntry { // Map from tensor_id to internally-defined operation-id of the operation which // produced this tensor. A value of -1 means that the tensor was directly // watched and not the result of any operation in the tape. -using TensorTape = gtl::FlatMap; +using TensorTape = std::unordered_map; // Map from operation-id to tape entry. template -using OpTape = gtl::FlatMap>; +using OpTape = + std::unordered_map>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as @@ -94,6 +95,7 @@ class VSpace { // Calls the passed-in backward function. virtual Status CallBackwardFunction( BackwardFunction* backward_function, + const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, std::vector* result) const = 0; @@ -143,7 +145,7 @@ class GradientTape { const VSpace& vspace, const gtl::ArraySlice target_tensor_ids, const gtl::ArraySlice source_tensor_ids, - const gtl::FlatMap sources_that_are_targets, + const std::unordered_map& sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result); @@ -156,7 +158,7 @@ class GradientTape { // Map from tensor id to number of remaining usages (i.e. how many entries in // the tape refer to it); to aid in tape garbage collection. - gtl::FlatMap tensor_usage_; + std::unordered_map tensor_usage_; // If false, all activations are deleted in the first call to ComputeGradient. // Else, only when this is destructed. @@ -307,11 +309,11 @@ struct BackpropInitialState { // Map from tensor ID to how many references still exist for this tensor in // the tape. 
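// -----------------------------------------------------------------------------
// Editorial example (not part of this patch): a simplified, standalone
// illustration of the new `unneeded_gradients` argument added to
// CallBackwardFunction above. The indices name op inputs whose gradients the
// tape will never consume, so a backward implementation may leave those slots
// null instead of computing them. FakeGradient and CallBackward are
// hypothetical stand-ins, not TensorFlow types.
#include <cstdint>
#include <unordered_set>
#include <vector>

struct FakeGradient {
  double value;
};

// Returns one (possibly null) gradient per op input; slots listed in
// `unneeded_gradients` are skipped entirely.
std::vector<FakeGradient*> CallBackward(
    int num_inputs, const std::vector<int64_t>& unneeded_gradients,
    const std::vector<FakeGradient*>& output_gradients) {
  const std::unordered_set<int64_t> skip(unneeded_gradients.begin(),
                                         unneeded_gradients.end());
  std::vector<FakeGradient*> input_gradients(num_inputs, nullptr);
  for (int i = 0; i < num_inputs; ++i) {
    if (skip.count(i) > 0) continue;  // The caller never reads this slot.
    double sum = 0.0;
    for (const FakeGradient* g : output_gradients) {
      if (g != nullptr) sum += g->value;
    }
    input_gradients[i] = new FakeGradient{sum};  // Placeholder gradient math.
  }
  return input_gradients;
}
// --- end of editorial example ------------------------------------------------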
- gtl::FlatMap tensor_usage_counts; + std::unordered_map tensor_usage_counts; // Maps from op ID to how many output tensors of this op still need to have // their gradients computed. - gtl::FlatMap op_missing_tensor; + std::unordered_map op_missing_tensor; }; // If `persistent_tape` is true, op_tape is not changed and none of the @@ -323,7 +325,7 @@ template BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, OpTape* op_tape, - const gtl::FlatSet& sources_set, bool persistent_tape) { + const std::unordered_set& sources_set, bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { @@ -383,7 +385,7 @@ BackpropInitialState PrepareBackprop( template std::vector InitialStack( const OpTape& op_tape, - const gtl::FlatMap& op_missing_tensor) { + const std::unordered_map& op_missing_tensor) { std::vector result; for (auto& op_entry : op_tape) { if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { @@ -397,10 +399,10 @@ template Status InitialGradients( const VSpace& vspace, gtl::ArraySlice target_tensor_ids, - gtl::FlatMap sources_that_are_targets, + const std::unordered_map& sources_that_are_targets, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, - gtl::FlatMap>* result) { + std::unordered_map>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; if (output_gradients.empty() || output_gradients[i] == nullptr) { @@ -454,12 +456,14 @@ Status InitialGradients( // corresponding to index 0 is used, and the gradient values at indices 1-4 are // ignored (and hence can be None). The backprop algorithm can then leverage // this by not constructing zeros to pass for those indices. -gtl::FlatMap>* FunctionsAcceptingNoneForIndicesMap() { - static auto* const m = new gtl::FlatMap>({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"SparseSoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); +std::unordered_map>* +FunctionsAcceptingNoneForIndicesMap() { + static auto* const m = + new std::unordered_map>({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"SparseSoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); return m; } @@ -476,16 +480,16 @@ Status GradientTape::ComputeGradient( const VSpace& vspace, const gtl::ArraySlice target_tensor_ids, const gtl::ArraySlice source_tensor_ids, - const gtl::FlatMap sources_that_are_targets, + const std::unordered_map& sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result) { - gtl::FlatSet sources_set(source_tensor_ids.begin(), - source_tensor_ids.end()); + std::unordered_set sources_set(source_tensor_ids.begin(), + source_tensor_ids.end()); BackpropInitialState state = PrepareBackprop( target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); - gtl::FlatMap> gradients; + std::unordered_map> gradients; Status s = InitialGradients(vspace, target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape, &gradients); @@ -501,7 +505,8 @@ Status GradientTape::ComputeGradient( cleanup(); return s; } - gtl::FlatMap gradients_size; + + std::unordered_map gradients_size; // TODO(apassos) multiple threads could be dequeuing from op_stack at the same // time, for better CPU backprop performance. 
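// -----------------------------------------------------------------------------
// Editorial example (not part of this patch): a standalone sketch of the
// lookup that FunctionsAcceptingNoneForIndicesMap above enables. Before
// materializing a zeros tensor for a missing output gradient, the tape can
// check whether the op's gradient function tolerates a null entry at that
// output index. MayPassNullGradient is a hypothetical helper; the map contents
// mirror the ones listed above.
#include <string>
#include <unordered_map>
#include <unordered_set>

bool MayPassNullGradient(const std::string& op_name, int output_index) {
  static const auto* const accepts_none =
      new std::unordered_map<std::string, std::unordered_set<int>>({
          {"SoftmaxCrossEntropyWithLogits", {1}},
          {"SparseSoftmaxCrossEntropyWithLogits", {1}},
          {"FusedBatchNorm", {1, 2, 3, 4}},
      });
  const auto it = accepts_none->find(op_name);
  return it != accepts_none->end() && it->second.count(output_index) > 0;
}
// --- end of editorial example ------------------------------------------------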
VLOG(1) << "Initial stack:"; @@ -524,7 +529,17 @@ Status GradientTape::ComputeGradient( state.op_tape.erase(op_it); std::vector out_gradients; out_gradients.reserve(trace.output_tensor_info.size()); + std::vector unneeded_gradients; + for (int i = 0; i < trace.input_tensor_id.size(); i++) { + const auto& in_tensor_id = trace.input_tensor_id[i]; + if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() && + sources_set.find(in_tensor_id) == sources_set.end()) { + unneeded_gradients.push_back(i); + } + } + bool any_gradient_nonzero = false; + std::vector zero_indices; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); @@ -535,7 +550,8 @@ Status GradientTape::ComputeGradient( func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { - out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i])); + out_gradients.push_back(nullptr); + zero_indices.push_back(i); } } else { any_gradient_nonzero = true; @@ -557,8 +573,13 @@ Status GradientTape::ComputeGradient( } std::vector in_gradients; if (any_gradient_nonzero) { - Status s = vspace.CallBackwardFunction(trace.backward_function, - out_gradients, &in_gradients); + for (const auto i : zero_indices) { + out_gradients[i] = vspace.Zeros(trace.output_tensor_info[i]); + } + Status s; + s = vspace.CallBackwardFunction(trace.backward_function, + unneeded_gradients, out_gradients, + &in_gradients); if (!persistent_) { trace.backward_function_deleter(trace.backward_function); } @@ -634,14 +655,16 @@ Status GradientTape::ComputeGradient( VLOG(1) << "Op " << op_id << " missing " << missing_it->second << " output gradients"; if (missing_it->second == 0) { - op_stack.push_back(op_id); + op_stack.insert(op_stack.begin(), op_id); } } } } - CHECK(state.op_tape.empty()); + if (!state.op_tape.empty()) { + return tensorflow::errors::Internal("Invalid tape state."); + } result->reserve(source_tensor_ids.size()); - gtl::FlatSet used_gradient_ids(source_tensor_ids.size()); + std::unordered_set used_gradient_ids(source_tensor_ids.size()); for (auto is : source_tensor_ids) { auto grad_it = gradients.find(is); if (grad_it == gradients.end()) { diff --git a/tensorflow/c/experimental/BUILD b/tensorflow/c/experimental/BUILD new file mode 100644 index 00000000000..b66969eb3ff --- /dev/null +++ b/tensorflow/c/experimental/BUILD @@ -0,0 +1,122 @@ +# Description: +# Experimental C APIs for TensorFlow. 
+ +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_cuda_library", +) +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") + +tf_cuda_library( + name = "rendezvous_internal", + srcs = [ + "rendezvous.cc", + ], + hdrs = [ + "rendezvous.h", + "rendezvous_internal.h", + ], + copts = tf_copts(), + visibility = ["//tensorflow/c:__subpackages__"], + deps = [ + "//tensorflow/c:c_api_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + ], +) + +tf_cuda_library( + name = "rendezvous", + hdrs = [ + "rendezvous.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":rendezvous_internal", + "//tensorflow/c:c_api", + ], +) + +tf_cuda_library( + name = "network_internal", + srcs = [ + "network.cc", + ], + hdrs = [ + "network.h", + "network_internal.h", + ], + copts = tf_copts(), + visibility = ["//tensorflow/c:__subpackages__"], + deps = [ + ":rendezvous_internal", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + ], +) + +tf_cuda_library( + name = "network", + hdrs = [ + "network.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":network_internal", + ":rendezvous", + "//tensorflow/c:c_api", + ], +) + +# ----------------------------------------------------------------------------- +# Tests + +tf_cuda_cc_test( + name = "network_test", + size = "medium", + srcs = ["network_test.cc"], + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":network", + ":network_internal", + ":rendezvous", + ":rendezvous_internal", + "//tensorflow/c:c_api", + "//tensorflow/c:env", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:session_mgr", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime:worker_session", + "//tensorflow/core/distributed_runtime/rpc:async_service_interface", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) diff --git a/tensorflow/c/experimental/network.cc b/tensorflow/c/experimental/network.cc new file mode 100644 index 00000000000..9dfce1b63f6 --- /dev/null +++ b/tensorflow/c/experimental/network.cc @@ -0,0 +1,166 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/network.h" + +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/experimental/network_internal.h" +#include "tensorflow/c/experimental/rendezvous_internal.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +using tensorflow::ServerFactory; + +namespace tensorflow { + +/* static */ Status CGrpcServer::Create( + const ServerDef& server_def, + void* (*init_function)(const TF_GrpcServer*, TF_Status*), + void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*delete_function)(void*), + TF_RemoteRendezvousBuilder* rendezvous_builder, + std::unique_ptr* out_server) { + auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function, + join_function, delete_function); + + GrpcServerOptions options; + options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) { + return new CRendezvousMgr(env, rendezvous_builder); + }; + TF_RETURN_IF_ERROR(grpc_server->Init(options)); + TF_Status* tf_status = TF_NewStatus(); + grpc_server->SetContext(init_function( + reinterpret_cast(grpc_server), tf_status)); + TF_RETURN_IF_ERROR(tf_status->status); + TF_DeleteStatus(tf_status); + + out_server->reset(grpc_server); + return Status::OK(); +} + +Status CGrpcServer::Start() { + Status status = GrpcServer::Start(); + TF_Status* tf_status = TF_NewStatus(); + (*start_function_)(reinterpret_cast(this), context_, + tf_status); + status.Update(tf_status->status); + TF_DeleteStatus(tf_status); + return status; +} + +Status CGrpcServer::Stop() { + Status status = GrpcServer::Stop(); + TF_Status* tf_status = TF_NewStatus(); + (*stop_function_)(reinterpret_cast(this), context_, + tf_status); + status.Update(tf_status->status); + TF_DeleteStatus(tf_status); + return status; +} + +Status CGrpcServer::Join() { + Status status = GrpcServer::Join(); + TF_Status* tf_status = TF_NewStatus(); + (*join_function_)(reinterpret_cast(this), context_, + tf_status); + status.Update(tf_status->status); + TF_DeleteStatus(tf_status); + return status; +} + +namespace { +// Factory that creates CGrpcServer instances. 
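// -----------------------------------------------------------------------------
// Editorial example (not part of this patch): a sketch of the callback set a
// plugin could supply, matching the function-pointer signatures taken by
// CGrpcServer::Create above; the tests later in this patch follow the same
// pattern. MyServerState and the My* functions are hypothetical.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network.h"

struct MyServerState {
  bool started = false;
};

void* MyInit(const TF_GrpcServer* server, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  return new MyServerState();  // Becomes the `context` passed to later calls.
}
void MyStart(const TF_GrpcServer* server, void* context, TF_Status* status) {
  static_cast<MyServerState*>(context)->started = true;
  TF_SetStatus(status, TF_OK, "");
}
void MyStop(const TF_GrpcServer* server, void* context, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");  // Clean shutdown is not yet wired up.
}
void MyJoin(const TF_GrpcServer* server, void* context, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
}
void MyDelete(void* context) { delete static_cast<MyServerState*>(context); }
// --- end of editorial example ------------------------------------------------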
+class CServerFactory : public ServerFactory { + public: + CServerFactory(bool (*accept_function)(const char*), + void* (*init_function)(const TF_GrpcServer*, TF_Status*), + void (*start_function)(const TF_GrpcServer*, void*, + TF_Status*), + void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*delete_function)(void*), + TF_RemoteRendezvousBuilder* rendezvous_builder) + : accept_function_(accept_function), + init_function_(init_function), + start_function_(start_function), + stop_function_(stop_function), + join_function_(join_function), + delete_function_(delete_function), + rendezvous_builder_(rendezvous_builder) {} + + Status NewServer(const ServerDef& server_def, + std::unique_ptr* out_server) override { + TF_RETURN_IF_ERROR(CGrpcServer::Create( + server_def, init_function_, start_function_, stop_function_, + join_function_, delete_function_, rendezvous_builder_, out_server)); + return Status::OK(); + } + + // Returns true if and only if this factory can create a server + // based on the given `server_def`. + bool AcceptsOptions(const ServerDef& server_def) override { + return (*accept_function_)(server_def.protocol().c_str()); + } + + private: + bool (*accept_function_)(const char* protocol); + void* (*init_function_)(const TF_GrpcServer*, TF_Status*); + void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*); + void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*); + void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*); + void (*delete_function_)(void*); + TF_RemoteRendezvousBuilder* rendezvous_builder_; +}; +} // namespace +} // namespace tensorflow + +// Server factory representation to use in C API. +// Holds CServerFactory pointer. +struct TF_GrpcServerFactory { + ::tensorflow::CServerFactory* factory; +}; + +TF_GrpcServerFactory* TF_NewGrpcServerFactory( + bool (*accept_function)(const char*), + void* (*init_function)(const TF_GrpcServer*, TF_Status*), + void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*delete_function)(void*), + TF_RemoteRendezvousBuilder* rendezvous_builder) { + TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory; + server_factory->factory = new ::tensorflow::CServerFactory( + accept_function, init_function, start_function, stop_function, + join_function, delete_function, rendezvous_builder); + return server_factory; +} + +void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) { + DCHECK_NE(server_factory, nullptr); + delete server_factory; +} + +void TF_RegisterGrpcServerFactory(const char* server_type, + TF_GrpcServerFactory* server_factory) { + ServerFactory::Register(server_type, server_factory->factory); +} diff --git a/tensorflow/c/experimental/network.h b/tensorflow/c/experimental/network.h new file mode 100644 index 00000000000..bd74ec8ffec --- /dev/null +++ b/tensorflow/c/experimental/network.h @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_ +#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/experimental/rendezvous.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// C API for TensorFlow Networking. +// NOTE: This API is unstable and almost certainly will change in the near +// future. +// +// Users wishing to register a custom GrpcServer should call +// TF_NewServerFactory and then TF_RegisterGrpcServerFactory. +// +// Example: +// ```c++ +// auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder( +// rendezvous_init_function, +// receive_from_remote_async_function, +// rendezvous_delete_function); +// +// TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory( +// accept_function, +// init_function, +// start_function, +// stop_function, +// join_function, +// delete_function, +// rendezvous_builder); +// TF_RegisterGrpcServerFactory("customfactory", factory); +// ... +// TF_DeleteGrpcServerFactory(factory); +// ``` + +typedef struct TF_GrpcServerFactory TF_GrpcServerFactory; +typedef struct TF_GrpcServerOptions TF_GrpcServerOptions; +typedef struct TF_GrpcServer TF_GrpcServer; +typedef struct TF_ServerContext { + TF_GrpcServer* const server; + void* context; +} TF_ServerContext; + +// Creates a new TF_GrpcServerFactory instance. Caller takes ownership +// of TF_GrpcServerFactory instance and should deallocate it by calling +// TF_GrpcDeleteServerFactory. +// accept_function should return true if this ServerFactory can create +// server instances for the given protocol name (for e.g. grpc+verbs). +// GRPC servers created by this factory will call provided +// init_function, start_function, stop_function, join_function and +// delete_function. +// +// Note that clean shutdown is currently not implemented for GrpcServer. +// So, stop_function will never be called now but may be in the future +// when stop mechanism is supported. +TF_CAPI_EXPORT extern TF_GrpcServerFactory* TF_NewGrpcServerFactory( + bool (*accept_function)(const char*), + void* (*init_function)(const TF_GrpcServer*, TF_Status*), + void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*delete_function)(void*), + TF_RemoteRendezvousBuilder* rendezvous_builder); + +// Deletes TF_GrpcServerFactory instances. +// Note that this function only deletes TF_GrpcServerFactory wrapper. +// Actual underlying server factory would not be deleted and will +// remain registered. +TF_CAPI_EXPORT extern void TF_DeleteGrpcServerFactory( + TF_GrpcServerFactory* server_factory); + +// Registers provided server_factory for the given server_type. +// server_type must be unique to the server factory. 
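// -----------------------------------------------------------------------------
// Editorial example (not part of this patch): a minimal accept_function as
// described above. The factory is selected when a ServerDef's protocol string
// matches the custom protocol it serves; the protocol name "grpc+custom" and
// the AcceptCustomProtocol helper are hypothetical.
#include <string.h>

bool AcceptCustomProtocol(const char* protocol) {
  return strcmp(protocol, "grpc+custom") == 0;
}
// --- end of editorial example ------------------------------------------------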
+TF_CAPI_EXPORT extern void TF_RegisterGrpcServerFactory( + const char* server_type, TF_GrpcServerFactory* server_factory); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif +#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_ diff --git a/tensorflow/c/experimental/network_internal.h b/tensorflow/c/experimental/network_internal.h new file mode 100644 index 00000000000..c2575296397 --- /dev/null +++ b/tensorflow/c/experimental/network_internal.h @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/experimental/network.h" +#include "tensorflow/c/experimental/rendezvous.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +namespace tensorflow { + +// GrpcServer implementation that forwards calls to callbacks. +class CGrpcServer : public GrpcServer { + protected: + CGrpcServer(const ServerDef& server_def, + void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*delete_function)(void*)) + : GrpcServer(server_def, ::tensorflow::Env::Default()), + start_function_(start_function), + stop_function_(stop_function), + join_function_(join_function), + delete_function_(delete_function), + context_(nullptr) {} + + public: + static Status Create( + const ServerDef& server_def, + void* (*init_function)(const TF_GrpcServer*, TF_Status*), + void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), + void (*delete_function)(void*), + TF_RemoteRendezvousBuilder* rendezvous_builder, + std::unique_ptr* out_server); + + Status Start() override; + Status Stop() override; + Status Join() override; + + ~CGrpcServer() override { delete_function_(context_); } + + protected: + void SetContext(void* context) { context_ = context; } + + private: + void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*); + void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*); + void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*); + void (*delete_function_)(void*); + void* context_; + + friend class NetworksTest; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_ diff --git a/tensorflow/c/experimental/network_test.cc b/tensorflow/c/experimental/network_test.cc new file mode 100644 index 00000000000..39f7e646d28 --- /dev/null +++ 
b/tensorflow/c/experimental/network_test.cc @@ -0,0 +1,256 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/network.h" + +#include +#include +#include + +#include +#include + +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/experimental/network_internal.h" +#include "tensorflow/c/experimental/rendezvous.h" +#include "tensorflow/c/experimental/rendezvous_internal.h" +#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/distributed_runtime/worker_session.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +namespace tensorflow { + +bool accept_functionA(const char* protocol_name) { + return strcmp(protocol_name, "grpc+A") == 0; +} + +bool accept_functionB(const char* protocol_name) { + return strcmp(protocol_name, "grpc+B") == 0; +} + +struct SomeServerData { + bool server_started = false; +}; + +struct SomeRendezvousData { + int test = 0; +}; + +void* init_function(const TF_GrpcServer* server, TF_Status* status) { + SomeServerData* server_data = new SomeServerData(); + TF_SetStatus(status, TF_OK, ""); + return server_data; +} + +void start_function(const TF_GrpcServer* server, void* context, + TF_Status* status) { + auto* server_data = static_cast(context); + server_data->server_started = true; + TF_SetStatus(status, TF_OK, ""); +} + +void stop_function(const TF_GrpcServer* server, void* context, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); +} + +void join_function(const TF_GrpcServer* server, void* context, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); +} + +void delete_function(void* context) { + auto* server_data = static_cast(context); + delete server_data; +} + +void* rendezvous_init_function(void* server_context) { + return new SomeRendezvousData(); +} + +void Deallocator(void* data, size_t, void* arg) { + tensorflow::cpu_allocator()->DeallocateRaw(data); + *reinterpret_cast(arg) = true; +} + +void receive_from_remote_async_function(TF_ParsedKey* key, + TF_RendezvousArgs* args, + TF_RendezvousDoneCallback* callback, + void* context) { + // Create 
dummy tensor + const int num_bytes = 6 * sizeof(float); + float* values = + reinterpret_cast(tensorflow::cpu_allocator()->AllocateRaw( + EIGEN_MAX_ALIGN_BYTES, num_bytes)); + int64_t dims[] = {2, 3}; + bool deallocator_called = false; + auto* tensor = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes, + &Deallocator, &deallocator_called); + callback->tensor = tensor; + auto* tf_status = TF_NewStatus(); + TF_SetStatus(tf_status, TF_OK, ""); + callback->status = tf_status; + TF_RendezvousDone(callback); + TF_DeleteStatus(tf_status); + TF_DeleteTensor(tensor); +} + +void rendezvous_delete_function(void* context) { + auto* rendezvous_data = static_cast(context); + delete rendezvous_data; +} + +tensorflow::ServerDef GetServerDef(const string& protocol, + const string& job_name, int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol(protocol); + server_def.set_job_name(job_name); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name(job_name); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +class NetworksTest : public ::testing::Test { + public: + ~NetworksTest() override {} + + SomeServerData* GetServerData(CGrpcServer* server) { + EXPECT_NE(server->context_, nullptr); + return static_cast(server->context_); + } +}; + +Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, + const string& receiver, const string& name) { + Rendezvous::ParsedKey result; + CHECK( + Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver, + name, FrameAndIter(0, 0)), + &result) + .ok()); + return result; +} + +void InitializeRendezvous(GrpcServer* grpc_server, ServerDef* server_def, + RemoteRendezvous* remote_rendezvous) { + int rendezvous_id = 0; + auto session_name = tensorflow::strings::StrCat("test_", rendezvous_id); + TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->CreateSession( + session_name, *server_def, true)); + + std::shared_ptr worker_session; + TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->WorkerSessionForSession( + session_name, &worker_session)); + + TF_EXPECT_OK(remote_rendezvous->Initialize(worker_session.get())); +} + +TEST_F(NetworksTest, TestStartServer) { + auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder( + rendezvous_init_function, receive_from_remote_async_function, + rendezvous_delete_function); + + TF_Status* tf_status = TF_NewStatus(); + TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory( + accept_functionA, init_function, start_function, stop_function, + join_function, delete_function, rendezvous_builder); + TF_RegisterGrpcServerFactory("testfactoryA", factory); + + ServerDef server_def = GetServerDef("grpc+A", "localhost", 1); + std::unique_ptr server; + TF_EXPECT_OK(NewServer(server_def, &server)); + auto* grpc_server = static_cast(server.get()); + auto* server_data = GetServerData(grpc_server); + ASSERT_FALSE(server_data->server_started); + + TF_EXPECT_OK(server->Start()); + ASSERT_TRUE(server_data->server_started); + + TF_DeleteStatus(tf_status); + TF_DeleteGrpcServerFactory(factory); + TF_DeleteRemoteRendezvousBuilder(rendezvous_builder); + // TODO(annarev): find a clean way to shutdown server. 
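// -----------------------------------------------------------------------------
// Editorial example (not part of this patch): the single-task ServerDef
// assembled by GetServerDef above corresponds to a text proto like the one
// below. MakeExampleServerDef, the protocol "grpc+A", and port 12345 are
// illustrative values only.
#include <cassert>

#include "google/protobuf/text_format.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

tensorflow::ServerDef MakeExampleServerDef() {
  tensorflow::ServerDef server_def;
  const bool ok = google::protobuf::TextFormat::ParseFromString(
      R"pb(
        protocol: "grpc+A"
        job_name: "localhost"
        task_index: 0
        cluster {
          job {
            name: "localhost"
            tasks { key: 0 value: "localhost:12345" }
          }
        }
      )pb",
      &server_def);
  assert(ok && "text proto literal should always parse");
  return server_def;
}
// --- end of editorial example ------------------------------------------------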
+ server.release(); +} + +TEST_F(NetworksTest, TestReceiveData) { + auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder( + rendezvous_init_function, receive_from_remote_async_function, + rendezvous_delete_function); + + TF_Status* tf_status = TF_NewStatus(); + TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory( + accept_functionB, init_function, start_function, stop_function, + join_function, delete_function, rendezvous_builder); + TF_RegisterGrpcServerFactory("testfactoryB", factory); + + ServerDef server_def = GetServerDef("grpc+B", "localhost", 1); + std::unique_ptr server; + TF_EXPECT_OK(NewServer(server_def, &server)); + auto* grpc_server = static_cast(server.get()); + + TF_EXPECT_OK(server->Start()); + auto* rendezvous_mgr = grpc_server->worker_env()->rendezvous_mgr; + auto* remote_rendezvous = rendezvous_mgr->Find(0); + + auto key = Key("/job:localhost/replica:1/task:2/device:CPU:0", 1, + "/job:localhost/replica:0/task:0/device:CPU:0", "test"); + Rendezvous::Args args; + bool done_callback_called = false; + auto* done_callback_called_ptr = &done_callback_called; + absl::Notification notification; + auto* notification_ptr = ¬ification; + + InitializeRendezvous(grpc_server, &server_def, remote_rendezvous); + remote_rendezvous->RecvAsync( + key, args, + [done_callback_called_ptr, notification_ptr]( + const Status&, const Rendezvous::Args&, const Rendezvous::Args&, + const Tensor&, const bool) mutable { + *done_callback_called_ptr = true; + notification_ptr->Notify(); + }); + notification.WaitForNotificationWithTimeout(absl::Seconds(10)); + ASSERT_EQ(done_callback_called, true); + + TF_DeleteStatus(tf_status); + TF_DeleteGrpcServerFactory(factory); + TF_DeleteRemoteRendezvousBuilder(rendezvous_builder); + // Server doesn't have a clean shutdown. + server.release(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/rendezvous.cc b/tensorflow/c/experimental/rendezvous.cc new file mode 100644 index 00000000000..0ee4907b7a4 --- /dev/null +++ b/tensorflow/c/experimental/rendezvous.cc @@ -0,0 +1,124 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/rendezvous.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/experimental/rendezvous_internal.h" +#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id, + void (*receive_from_remote_async_function)( + TF_ParsedKey*, TF_RendezvousArgs*, + TF_RendezvousDoneCallback*, + void* context), + void (*delete_function)(void* context), + void* server_context) + : BaseRemoteRendezvous(env, step_id), + receive_from_remote_async_function_(receive_from_remote_async_function), + delete_function_(delete_function), + context_(nullptr) {} + +void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, + DoneCallback done) { + TF_ParsedKey key; + key.src_device = parsed.src_device.data(); + key.src_device_len = parsed.src_device.size(); + key.dst_device = parsed.dst_device.data(); + key.dst_device_len = parsed.dst_device.size(); + key.full_key = parsed.FullKey().data(); + key.full_key_len = parsed.FullKey().size(); + + TF_DeviceContext* device_context = new TF_DeviceContext(); + device_context->context = args.device_context; + + TF_AllocatorAttributes* alloc_attrs = new TF_AllocatorAttributes(); + alloc_attrs->value = args.alloc_attrs.value; + alloc_attrs->scope_id = args.alloc_attrs.scope_id; + alloc_attrs->on_host = args.alloc_attrs.on_host(); + alloc_attrs->nic_compatible = args.alloc_attrs.nic_compatible(); + + TF_RendezvousArgs* cargs = new TF_RendezvousArgs(); + cargs->device_context = device_context; + cargs->alloc_attrs = alloc_attrs; + + TF_RendezvousDoneCallback* done_callback = new TF_RendezvousDoneCallback(); + done_callback->done_callback = done; + done_callback->recv_args = cargs; + + receive_from_remote_async_function_(&key, cargs, done_callback, context_); +} + +CRemoteRendezvous::~CRemoteRendezvous() { delete_function_(context_); } +} // namespace tensorflow + +TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder( + void* (*init_function)(void* server_context), + void (*receive_from_remote_async_function)(TF_ParsedKey*, + TF_RendezvousArgs*, + TF_RendezvousDoneCallback*, + void* context), + void (*delete_function)(void* context)) { + TF_RemoteRendezvousBuilder* builder = new TF_RemoteRendezvousBuilder(); + builder->init_function = init_function; + builder->delete_function = delete_function; + builder->receive_from_remote_async_function = + receive_from_remote_async_function; + return builder; +} + +void TF_DeleteRemoteRendezvousBuilder( + TF_RemoteRendezvousBuilder* rendezvous_builder) { + DCHECK_NE(rendezvous_builder, nullptr); + delete rendezvous_builder; +} + +TF_CAPI_EXPORT extern void TF_RendezvousDone( + TF_RendezvousDoneCallback* callback) { + DCHECK_NE(callback, nullptr); + ::tensorflow::Tensor tensor; + TF_CHECK_OK(TF_TensorToTensor(callback->tensor, &tensor)); + ::tensorflow::Rendezvous::Args recv_args; + recv_args.alloc_attrs.value = callback->recv_args->alloc_attrs->value; + 
 recv_args.alloc_attrs.scope_id = callback->recv_args->alloc_attrs->scope_id;
+ recv_args.device_context = callback->recv_args->device_context->context;
+ ::tensorflow::Rendezvous::Args sent_args;
+
+ callback->done_callback(callback->status->status, sent_args, recv_args,
+ tensor, callback->dead);
+
+ if (callback->recv_args) {
+ DCHECK_NE(callback->recv_args, nullptr);
+ DCHECK_NE(callback->recv_args->alloc_attrs, nullptr);
+ DCHECK_NE(callback->recv_args->device_context, nullptr);
+ delete callback->recv_args->alloc_attrs;
+ delete callback->recv_args->device_context;
+ delete callback->recv_args;
+ }
+ delete callback;
+ callback = nullptr;
+}
diff --git a/tensorflow/c/experimental/rendezvous.h b/tensorflow/c/experimental/rendezvous.h
new file mode 100644
index 00000000000..5b007d52429
--- /dev/null
+++ b/tensorflow/c/experimental/rendezvous.h
@@ -0,0 +1,67 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
+#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// --------------------------------------------------------------------------
+// C API for Rendezvous.
+// NOTE: This API is unstable and almost certainly will change in the near
+// future.
+//
+// Custom rendezvous allows for custom implementations of the Recv call.
+//
+// Users wishing to create custom rendezvous objects should call
+// TF_NewRemoteRendezvousBuilder and pass the returned
+// TF_RemoteRendezvousBuilder to TF_NewServerFactory.
+
+typedef struct TF_RemoteRendezvousBuilder TF_RemoteRendezvousBuilder;
+typedef struct TF_ParsedKey TF_ParsedKey;
+typedef struct TF_RendezvousArgs TF_RendezvousArgs;
+typedef struct TF_RendezvousDoneCallback TF_RendezvousDoneCallback;
+
+// Creates a new TF_RemoteRendezvousBuilder instance.
+// Rendezvous instances will forward calls to init_function,
+// receive_from_remote_async_function and delete_function passed here.
+//
+// Note that the receive_from_remote_async_function implementation must call
+// TF_RendezvousDone with the TF_RendezvousDoneCallback passed as an argument.
+TF_CAPI_EXPORT extern TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
+ void* (*init_function)(void* server_context),
+ void (*receive_from_remote_async_function)(TF_ParsedKey*,
+ TF_RendezvousArgs*,
+ TF_RendezvousDoneCallback*,
+ void* context),
+ void (*delete_function)(void* context));
+
+// Deletes TF_RemoteRendezvousBuilder instances.
+TF_CAPI_EXPORT extern void TF_DeleteRemoteRendezvousBuilder(
+ TF_RemoteRendezvousBuilder* rendezvous_builder);
+
+// Calls the done callback and destroys the callback instance and its members
+// except `tensor` and `status`. The caller is responsible for deleting
+// `tensor` and `status` after TF_RendezvousDone returns.
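+//
+// A minimal sketch of a conforming receive_from_remote_async_function (the
+// tensor lookup itself is elided, and `MyRecvAsync` is only an illustrative
+// name, not part of this API):
+//
+//   void MyRecvAsync(TF_ParsedKey* key, TF_RendezvousArgs* args,
+//                    TF_RendezvousDoneCallback* callback, void* context) {
+//     // Populate the callback's result (tensor/status) from the data
+//     // received for `key`, using the server state in `context` -- elided
+//     // here -- then notify TensorFlow:
+//     TF_RendezvousDone(callback);
+//   }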
+TF_CAPI_EXPORT extern void TF_RendezvousDone( + TF_RendezvousDoneCallback* callback); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif +#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_ diff --git a/tensorflow/c/experimental/rendezvous_internal.h b/tensorflow/c/experimental/rendezvous_internal.h new file mode 100644 index 00000000000..f06686023e6 --- /dev/null +++ b/tensorflow/c/experimental/rendezvous_internal.h @@ -0,0 +1,135 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/experimental/rendezvous.h" +#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/platform/macros.h" + +struct TF_ParsedKey { + // char* members might not be null-terminated. + const char* src_device; + size_t src_device_len; + const char* dst_device; + size_t dst_device_len; + const char* full_key; + size_t full_key_len; +}; + +struct TF_AllocatorAttributes { + bool on_host; + bool nic_compatible; + // NOTE: The upper 8 bits of the value are reserved for + // device-specific uses. Implementors of a device can interpret these + // upper 8 bits in device-specific ways, and ops implemented for those + // devices are responsible for setting those 8 bits appropriately. + tensorflow::uint32 value = 0; + // EXPERIMENTAL: If this is greater than zero, then allocation is delegated to + // a named special-purpose allocator on the same device. + tensorflow::int32 scope_id = 0; +}; + +struct TF_DeviceContext { + ::tensorflow::DeviceContext* context; +}; + +struct TF_RendezvousArgs { + const TF_DeviceContext* device_context; + const TF_AllocatorAttributes* alloc_attrs; +}; + +struct TF_RendezvousDoneCallback { + ::tensorflow::Rendezvous::DoneCallback done_callback; + + // TODO(annarev): figure out if we should also support sent_args. 
+ const TF_RendezvousArgs* recv_args; + TF_Tensor* tensor = nullptr; + TF_Status* status; + bool dead; +}; + +struct TF_RemoteRendezvousBuilder { + void* (*init_function)(void* server_context); + void (*receive_from_remote_async_function)(TF_ParsedKey*, TF_RendezvousArgs*, + TF_RendezvousDoneCallback*, + void* context); + void (*delete_function)(void* context); + void* server_context; +}; + +namespace tensorflow { + +class CRemoteRendezvous : public BaseRemoteRendezvous { + public: + CRemoteRendezvous(const WorkerEnv* env, int64 step_id, + void (*receive_from_remote_async_function)( + TF_ParsedKey*, TF_RendezvousArgs*, + TF_RendezvousDoneCallback*, void* context), + void (*delete_function)(void* context), + void* server_context); + + void SetContext(void* context) { context_ = context; } + + protected: + void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, + DoneCallback done) override; + + private: + ~CRemoteRendezvous() override; + + void (*receive_from_remote_async_function_)(TF_ParsedKey*, TF_RendezvousArgs*, + TF_RendezvousDoneCallback*, + void* context); + void (*delete_function_)(void* context); + void* context_; + TF_DISALLOW_COPY_AND_ASSIGN(CRemoteRendezvous); +}; + +class CRendezvousMgr : public BaseRendezvousMgr { + public: + CRendezvousMgr(const WorkerEnv* env, + const TF_RemoteRendezvousBuilder* rendezvous_builder) + : BaseRendezvousMgr(env), rendezvous_builder_(rendezvous_builder) {} + + protected: + BaseRemoteRendezvous* Create(int64 step_id, + const WorkerEnv* worker_env) override { + auto* rendezvous = new CRemoteRendezvous( + worker_env, step_id, + rendezvous_builder_->receive_from_remote_async_function, + rendezvous_builder_->delete_function, + rendezvous_builder_->server_context); + + rendezvous->SetContext(rendezvous_builder_->init_function( + rendezvous_builder_->server_context)); + return rendezvous; + } + + private: + const TF_RemoteRendezvousBuilder* rendezvous_builder_; + TF_DISALLOW_COPY_AND_ASSIGN(CRendezvousMgr); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_ diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD new file mode 100644 index 00000000000..597182ab016 --- /dev/null +++ b/tensorflow/c/kernels/BUILD @@ -0,0 +1,44 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_kernel_library", +) + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +tf_kernel_library( + name = "bitcast_op", + prefix = "bitcast_op", + deps = [ + "//tensorflow/c:kernels", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + ], +) + +tf_cc_test( + name = "bitcast_op_test", + srcs = ["bitcast_op_test.cc"], + deps = [ + ":bitcast_op", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# Changes to the Android srcs here should be replicated in +# tensorflow/contrib/makefile/tf_op_files.txt +# LINT.IfChange +filegroup( + name = "android_all_ops", + srcs = [ + "bitcast_op.cc", + ], +) +# LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt) diff --git a/tensorflow/c/kernels/bitcast_op.cc b/tensorflow/c/kernels/bitcast_op.cc new file mode 100644 index 00000000000..f2f313af386 --- /dev/null +++ b/tensorflow/c/kernels/bitcast_op.cc @@ -0,0 +1,171 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/kernels.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.h" + +// BitcastOp implements a bitcast kernel, creating an output tensor that shares +// the same data buffer as the input but with a different shape and/or data +// type. Its inputs are: +// +// * the input tensor +// * an attribute named "T" containing the TF_DataType of the input tensor +// * an attribute named "type" containing the TF_DataType of the output tensor +// +// Given an input tensor of shape [...], if the input DataType "T" is larger +// than the output DataType "type", then the shape changes from [...] +// to [..., sizeof(T)/sizeof(type)]. +// +// If "T" is smaller than "type", the operator requires that the rightmost +// dimension be equal to sizeof(type)/sizeof(T). The shape then goes from +// [..., sizeof(type)/sizeof(T)] to [...]. +// +// Bitcast is implemented as a low-level cast, so machines with different endian +// orderings will give different results. 
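+//
+// For example (a sketch of the rules above): a uint8 tensor of shape [2, 8]
+// bitcast to uint64 yields shape [2], since sizeof(uint64)/sizeof(uint8) == 8,
+// while a uint64 tensor of shape [2] bitcast to uint8 yields shape [2, 8].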
+typedef struct BitcastOp { + TF_DataType input_data_type; + TF_DataType output_data_type; + size_t in_size; + size_t out_size; +} BitcastOp; + +static void* BitcastOp_Create(TF_OpKernelConstruction* ctx) { + auto* kernel = new BitcastOp; + + TF_Status* s = TF_NewStatus(); + TF_OpKernelConstruction_GetAttrType(ctx, "T", &kernel->input_data_type, s); + + if (TF_GetCode(s) == TF_OK) { + TF_OpKernelConstruction_GetAttrType(ctx, "type", &kernel->output_data_type, + s); + } + + if (TF_GetCode(s) == TF_OK) { + kernel->in_size = TF_DataTypeSize(kernel->input_data_type); + kernel->out_size = TF_DataTypeSize(kernel->output_data_type); + + size_t check_size = std::max(kernel->in_size, kernel->out_size) % + std::min(kernel->in_size, kernel->out_size); + if (check_size != 0) { + std::ostringstream err; + err << "cannot convert between datatype " << kernel->input_data_type + << " and " << kernel->output_data_type; + TF_SetStatus(s, TF_INVALID_ARGUMENT, err.str().c_str()); + } + } + + if (TF_GetCode(s) != TF_OK) { + TF_OpKernelConstruction_Failure(ctx, s); + delete kernel; + kernel = nullptr; + } + + TF_DeleteStatus(s); + return kernel; +} + +static void BitcastOp_Delete(void* kernel) { + delete static_cast(kernel); +} + +static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) { + auto* k = static_cast(kernel); + int dim_count = 0; + + TF_Tensor* tensor; + TF_Status* status = TF_NewStatus(); + TF_GetInput(ctx, 0, &tensor, status); + if (TF_GetCode(status) == TF_OK) { + dim_count = TF_NumDims(tensor); + if (!(k->in_size >= k->out_size || + (dim_count > 0 && + TF_Dim(tensor, dim_count - 1) == k->out_size / k->in_size))) { + std::ostringstream err; + err << "Cannot bitcast from " << k->input_data_type << " to " + << k->output_data_type; + TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str()); + } + } + + if (TF_GetCode(status) == TF_OK) { + auto* dims = new int64_t[dim_count + 1]; + int new_dim_count = dim_count; + for (int dim = 0; dim < dim_count; ++dim) { + dims[dim] = TF_Dim(tensor, dim); + } + if (k->out_size < k->in_size) { + dims[new_dim_count++] = static_cast(k->in_size / k->out_size); + } else if (k->out_size > k->in_size) { + --new_dim_count; + } + + TF_Tensor* output = TF_AllocateTensor(k->output_data_type, dims, 0, + TF_DataTypeSize(k->output_data_type)); + TF_TensorBitcastFrom(tensor, k->output_data_type, output, dims, + new_dim_count, status); + if (TF_GetCode(status) == TF_OK) { + TF_SetOutput(ctx, 0, output, status); + } + delete[] dims; + TF_DeleteTensor(output); + } + + if (TF_GetCode(status) != TF_OK) { + TF_OpKernelContext_Failure(ctx, status); + } + TF_DeleteStatus(status); + TF_DeleteTensor(tensor); +} + +static void RegisterBitcastOp() { + TF_Status* status = TF_NewStatus(); + + { + auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU, + &BitcastOp_Create, &BitcastOp_Compute, + &BitcastOp_Delete); + TF_RegisterKernelBuilder("BitcastOp", builder, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) + << "Error while registering bitcast kernel"; + } + +#if GOOGLE_CUDA + { + auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU, + &BitcastOp_Create, &BitcastOp_Compute, + &BitcastOp_Delete); + TF_RegisterKernelBuilder("BitcastOp", builder, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) + << "Error while registering CUDA bitcast kernel"; + } +#endif + + TF_DeleteStatus(status); +} + +// A dummy static variable initialized by a lambda whose side-effect is to +// register the bitcast kernel. 
+static bool BitcastOpIsRegistered = []() { + if (SHOULD_REGISTER_OP_KERNEL("BitcastOp")) { + RegisterBitcastOp(); + } + return true; +}(); diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc new file mode 100644 index 00000000000..06ffcca19da --- /dev/null +++ b/tensorflow/c/kernels/bitcast_op_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +void TestBitcastOp(Tensor* input_tensor, DataType out_type, + TensorShape expected_shape, error::Code expected_code) { + Status status; + NodeDef def; + def.set_op("Bitcast"); + def.set_device(DEVICE_CPU); + + AttrValue typeAttr; + SetAttrValue(input_tensor->dtype(), &typeAttr); + + AttrValue outTypeAttr; + SetAttrValue(out_type, &outTypeAttr); + + (*def.mutable_attr())["T"] = typeAttr; + (*def.mutable_attr())["type"] = outTypeAttr; + + def.add_input( + strings::StrCat("input1: ", DataTypeString(input_tensor->dtype()))); + + std::unique_ptr kernel = + CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status); + ASSERT_TRUE(status.ok()) << status.ToString(); + + OpKernelContext::Params params; + DummyDevice dummy_device(nullptr, false); + params.device = &dummy_device; + params.op_kernel = kernel.get(); + gtl::InlinedVector inputs; + inputs.emplace_back(input_tensor); + params.inputs = &inputs; + + OpKernelContext ctx(¶ms); + kernel->Compute(&ctx); + ASSERT_EQ(expected_code, ctx.status().code()); + if (expected_code == error::OK) { + ASSERT_EQ(expected_shape, ctx.mutable_output(0)->shape()) + << ctx.mutable_output(0)->shape().DebugString(); + } +} + +TEST(BitcastOpTest, TestUpcast) { + Tensor int8_input(DT_UINT8, {8}); + for (int i = 0; i < 8; i++) { + int8_input.vec()(i) = static_cast(1); + } + TestBitcastOp(&int8_input, DT_UINT64, TensorShape(), error::OK); +} + +TEST(BitcastOpTest, TestDowncast) { + Tensor int64_input(static_cast(1)); + TestBitcastOp(&int64_input, DT_UINT8, TensorShape({8}), error::OK); +} + +TEST(BitcastOpTest, TestCastToSameSize) { + Tensor int32_input(DT_UINT32, {4, 6}); + TestBitcastOp(&int32_input, DT_UINT8, TensorShape({4, 6, 4}), error::OK); +} + +TEST(BitcastOpTest, TestImpossibleCast) { + Tensor int8_input(DT_UINT8, {1}); + TestBitcastOp(&int8_input, DT_UINT32, TensorShape(), error::INVALID_ARGUMENT); +} + +} 
// namespace +} // namespace tensorflow diff --git a/tensorflow/c/ops.cc b/tensorflow/c/ops.cc new file mode 100644 index 00000000000..b175d262c01 --- /dev/null +++ b/tensorflow/c/ops.cc @@ -0,0 +1,326 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/ops.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" + +using ::tensorflow::DataType; +using ::tensorflow::OpDef; +using ::tensorflow::OpDeprecation; +using ::tensorflow::OpShapeInferenceFn; +using ::tensorflow::Set_TF_Status_from_Status; +using ::tensorflow::Status; +using ::tensorflow::shape_inference::DimensionHandle; +using ::tensorflow::shape_inference::InferenceContext; +using ::tensorflow::shape_inference::ShapeHandle; + +typedef struct TF_OpDefinitionBuilder { + // The op definition proto representing the op. + tensorflow::OpDef op_def; + + // The shape inference function, or nullptr if none is provided for this op. + OpShapeInferenceFn shape_inference_func; +} TF_OpDefinitionBuilder; + +TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) { + auto* result = new TF_OpDefinitionBuilder; + result->op_def.set_name(op_name); + return result; +} + +void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) { + delete builder; +} + +static void PopulateArg(OpDef::ArgDef* arg, const char* name, + TF_DataType type) { + arg->set_name(name); + arg->set_type(static_cast(type)); +} + +void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder, + const char* name, TF_DataType type) { + PopulateArg(builder->op_def.add_input_arg(), name, type); +} + +void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder, + const char* name, TF_DataType type) { + PopulateArg(builder->op_def.add_output_arg(), name, type); +} + +#define DEFINE_BUILDER_BOOL_SETTER(func_name, builder_setter_name, arg_name) \ + void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \ + bool arg_name) { \ + builder->op_def.builder_setter_name(arg_name); \ + } + +DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative, set_is_commutative, is_commutative) +DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate, set_is_aggregate, is_aggregate) +DEFINE_BUILDER_BOOL_SETTER(SetIsStateful, set_is_stateful, is_stateful) +DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput, + set_allows_uninitialized_input, + allows_unintialized_input) + +static OpDef::AttrDef* AddAttribute(TF_OpDefinitionBuilder* builder, + const char* name, const char* type_name) { + OpDef::AttrDef* attr = builder->op_def.add_attr(); + attr->set_name(name); + attr->set_type(type_name); + return attr; +} + +#define DEFINE_ATTR_SETTER(attr_type, type_name, field_c_type, field_name) \ + void TF_OpDefinitionBuilderAdd##attr_type##Attr( \ + 
TF_OpDefinitionBuilder* builder, const char* name) { \ + AddAttribute(builder, name, type_name); \ + } \ + \ + void TF_OpDefinitionBuilderAdd##attr_type##AttrWithDefaultValue( \ + TF_OpDefinitionBuilder* builder, const char* name, \ + field_c_type field_name) { \ + OpDef::AttrDef* attr = AddAttribute(builder, name, type_name); \ + attr->mutable_default_value()->set_##field_name(field_name); \ + } \ + \ + void TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \ + TF_OpDefinitionBuilder* builder, const char* name, \ + field_c_type field_name[], size_t n) { \ + OpDef::AttrDef* attr = AddAttribute(builder, name, "list(" type_name ")"); \ + for (int _i = 0; _i < n; ++_i) { \ + attr->mutable_default_value()->mutable_list()->add_##field_name( \ + field_name[_i]); \ + } \ + } \ + \ + void TF_OpDefinitionBuilderAdd##attr_type##ListAttr( \ + TF_OpDefinitionBuilder* builder, const char* name) { \ + TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \ + builder, name, NULL, 0); \ + } + +DEFINE_ATTR_SETTER(String, "string", const char*, s) +DEFINE_ATTR_SETTER(Int, "int", int64_t, i) +DEFINE_ATTR_SETTER(Float, "float", float, f) +DEFINE_ATTR_SETTER(Bool, "bool", bool, b) + +void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder, + int version, const char* explanation) { + OpDeprecation* dep = builder->op_def.mutable_deprecation(); + dep->set_version(version); + dep->set_explanation(explanation); +} + +void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::OpRegistry::Global()->Register( + [builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status { + op_reg_data->op_def.Clear(); + op_reg_data->op_def.MergeFrom(builder->op_def); + op_reg_data->shape_inference_fn = builder->shape_inference_func; + return Status::OK(); + }); + + // Calling ProcessRegistrations ensures that the cc_builder's finalize method + // is called and that the builder can be deleted. 
+ Set_TF_Status_from_Status(
+ status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations());
+
+ delete builder;
+}
+
+void TF_OpDefinitionBuilderSetShapeInferenceFunction(
+ TF_OpDefinitionBuilder* builder,
+ void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
+ TF_Status* status)) {
+ builder->shape_inference_func =
+ [shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
+ TF_Status* c_status = TF_NewStatus();
+ auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
+ shape_inference_func(c_ctx, c_status);
+ tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
+ TF_DeleteStatus(c_status);
+ return result;
+ };
+}
+
+TF_ShapeHandle* TF_NewShapeHandle() {
+ return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
+}
+
+TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
+ TF_ShapeInferenceContext* ctx, size_t size) {
+ auto* handle = new ShapeHandle;
+ *handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
+ return reinterpret_cast<TF_ShapeHandle*>(handle);
+}
+
+void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
+ TF_ShapeHandle* first,
+ TF_ShapeHandle* second,
+ TF_ShapeHandle* result,
+ TF_Status* status) {
+ auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
+ Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
+ *reinterpret_cast<ShapeHandle*>(second),
+ reinterpret_cast<ShapeHandle*>(result));
+ Set_TF_Status_from_Status(status, s);
+}
+
+TF_DimensionHandle* TF_NewDimensionHandle() {
+ return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
+}
+
+int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
+ auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
+ return cc_ctx->num_inputs();
+}
+
+void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
+ TF_ShapeHandle* handle,
+ TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
+ if (i < 0 || i >= cc_ctx->num_inputs()) {
+ TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
+ }
+ if (TF_GetCode(status) == TF_OK) {
+ auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
+ *cc_result = cc_ctx->input(i);
+ }
+}
+
+int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
+ TF_ShapeHandle* handle) {
+ auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
+ return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
+}
+
+void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
+ TF_ShapeHandle* handle,
+ TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
+ if (i < 0 || i >= cc_ctx->num_outputs()) {
+ TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
+ }
+ if (TF_GetCode(status) == TF_OK) {
+ cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
+ }
+}
+
+void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
+ if (handle == nullptr) {
+ return;
+ }
+
+ delete reinterpret_cast<ShapeHandle*>(handle);
+}
+
+void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
+ if (handle == nullptr) {
+ return;
+ }
+
+ delete reinterpret_cast<DimensionHandle*>(handle);
+}
+
+#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
+ void TF_ShapeInferenceContext_GetAttr##func( \
+ TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
+ TF_Status* status) { \
+ TF_SetStatus(status, TF_OK, ""); \
+ cc_type v; \
+ auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
+ Status s = cc_ctx->GetAttr(attr_name, &v); \
+ Set_TF_Status_from_Status(status, s); \
+ if (s.ok()) { \
+ *val = static_cast<c_type>(v); \
+ } \
+ }
+
+DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
+
+#define DEFINE_RANK_FUNC(func_name) \
+ void TF_ShapeInferenceContext##func_name( \
+ 
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \ + TF_ShapeHandle* result, TF_Status* status) { \ + auto* cc_ctx = reinterpret_cast(ctx); \ + auto* cc_handle = reinterpret_cast(handle); \ + auto* cc_result = reinterpret_cast(result); \ + Status s = cc_ctx->func_name(*cc_handle, rank, cc_result); \ + Set_TF_Status_from_Status(status, s); \ + } + +DEFINE_RANK_FUNC(WithRank) +DEFINE_RANK_FUNC(WithRankAtLeast) +DEFINE_RANK_FUNC(WithRankAtMost) + +int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx, + TF_ShapeHandle* handle) { + return reinterpret_cast(ctx)->Rank( + *reinterpret_cast(handle)); +} + +void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx, + TF_ShapeHandle* shape_handle, int64_t i, + TF_DimensionHandle* result) { + int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle); + auto* cc_result = reinterpret_cast(result); + + if (i < -rank || i >= rank) { + *cc_result = DimensionHandle(); + return; + } + + auto* cc_ctx = reinterpret_cast(ctx); + auto* cc_shape_handle = reinterpret_cast(shape_handle); + *cc_result = cc_ctx->Dim(*cc_shape_handle, i); +} + +int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) { + return InferenceContext::ValueKnown( + *reinterpret_cast(dim_handle)); +} + +void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx, + TF_Status* status) { + Status s = ::tensorflow::shape_inference::UnknownShape( + reinterpret_cast(ctx)); + Set_TF_Status_from_Status(status, s); +} + +void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx, + TF_ShapeHandle* shape_handle, + int64_t start, int64_t end, + TF_ShapeHandle* result, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + auto* cc_ctx = reinterpret_cast(ctx); + auto* cc_result = reinterpret_cast(result); + Status s = cc_ctx->Subshape(*reinterpret_cast(shape_handle), + start, end, cc_result); + Set_TF_Status_from_Status(status, s); +} + +int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) { + return InferenceContext::Value( + *reinterpret_cast(dim_handle)); +} diff --git a/tensorflow/c/ops.h b/tensorflow/c/ops.h new file mode 100644 index 00000000000..7e2e95084ea --- /dev/null +++ b/tensorflow/c/ops.h @@ -0,0 +1,407 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Routines for registering new ops and for implementing op shape inference +// functions. +// +// This API is alpha software and is subject to change. +// +// REGISTRATION +// ------------ +// +// In order to register a new op, create a new TF_OpDefinitionBuilder: +// +// TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("OpName"); +// +// Inputs, outputs and attributes can be added to the builder with the +// corresponding functions, e.g. 
+// +// TF_OpDefinitionBuilderAddInput(builder, "input1: int32"); +// TF_OpDefinitionBuilderAddOutput(builder, "output1: int64"); +// TF_OpDefinitionBuilderAddAttr(builder, "attr: int32"); +// +// The builder may then be registered with TensorFlow using the +// TF_RegisterOpDefinition function. E.g. +// +// TF_Status* status = TF_NewStatus(); +// TF_RegisterOpDefinition(builder, &status); +// if (TF_GetCode(status) != TF_OK) { +// // handle error +// } +// +// SHAPE INFERENCE +// --------------- +// +// You can provide a shape inference function that TensorFlow will call when it +// wants to understand the shape of outputs that the op will produce. Use the +// TF_OpDefinitionBuilderSetShapeInferenceFunction function to register a shape +// inference function pointer with TensorFlow. The following is an example of a +// very simple shape inference function: +// +// void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) { +// TF_ShapeHandle* input = TF_NewShapeHandle(); +// TF_ShapeInferenceContextGetInput(ctx, 0, input, status); +// if (TF_GetCode(status) == TF_OK) { +// TF_ShapeInferenceContextSetOutput(ctx, 0, input, status); +// } +// TF_DeleteShapeHandle(input); +// } +// +// The following code registers the inference function with TensorFlow: +// +// TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn); +// +// For more details about shape inference, see the documentation for +// TF_OpDefinitionBuilderSetShapeInferenceFunction. + +#ifndef TENSORFLOW_C_OPS_H_ +#define TENSORFLOW_C_OPS_H_ + +#include +#include +#include + +#include "tensorflow/c/c_api.h" + +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +struct TF_DimensionHandle; +struct TF_OpDefinitionBuilder; +struct TF_ShapeHandle; +struct TF_ShapeInferenceContext; + +// Returns a newly allocated op definition builder for the given op name. The +// returned builder may be customized with the `TF_OpDefinitionBuilder...` +// functions and then registered with TensorFlow with TF_RegisterOpDefinition. +// +// The returned pointer is either freed by a call to TF_RegisterOpDefinition, or +// can be manually deleted by TF_DeleteOpDefinitionBuilder if it is never +// registered. +TF_CAPI_EXPORT extern TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder( + const char* op_name); + +// Registers the given op builder with TensorFlow. Indicates success or +// otherwise in the given status. +// +// `builder` is freed whether the op was successfully registered or not. You +// must call either this function or TF_DeleteOpDefinitionBuilder to free the +// builder, but never both. +TF_CAPI_EXPORT extern void TF_RegisterOpDefinition( + TF_OpDefinitionBuilder* builder, TF_Status* status); + +// Frees the given op definition builder. You must call either this function or +// TF_RegisterOpDefinition to free the builder, but never both. +TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder( + TF_OpDefinitionBuilder* builder); + +//---------------------------------------------------- +// Attribute functions. + +// Adds a string attribute with the given name to the builder. 
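+//
+// For example, an op carrying a free-form label could declare (a sketch;
+// "label" is an arbitrary attribute name):
+//   TF_OpDefinitionBuilderAddStringAttr(builder, "label");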
+TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a string attribute with the given name and default value to the builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttrWithDefaultValue( + TF_OpDefinitionBuilder* builder, const char* name, const char* value); + +// Adds a string list attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringListAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a string list attribute with the given default values to the builder. +// `values` must contain at least `n` elements. +TF_CAPI_EXPORT extern void +TF_OpDefinitionBuilderAddStringListAttrWithDefaultValues( + TF_OpDefinitionBuilder* builder, const char* name, const char* values[], + size_t n); + +// Adds an integer attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds an integer attribute with the given name and default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttrWithDefaultValue( + TF_OpDefinitionBuilder* builder, const char* name, int64_t value); + +// Adds an integer list attribute with the given name and no default value to +// the builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntListAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds an integer list attribute with the given name and default values to the +// builder. `values` must contain at least `n` elements. +TF_CAPI_EXPORT extern void +TF_OpDefinitionBuilderAddIntListAttrWithDefaultValues( + TF_OpDefinitionBuilder* builder, const char* name, int64_t values[], + size_t n); + +// Adds a float attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a float attribute with the given name and default value to the builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttrWithDefaultValue( + TF_OpDefinitionBuilder* builder, const char* name, float value); + +// Adds a float list attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatListAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a float list attribute with the given name and default values to the +// builder. `values` must contain at least `n` elements. +TF_CAPI_EXPORT extern void +TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues( + TF_OpDefinitionBuilder* builder, const char* name, float values[], + size_t n); + +// Adds a boolean attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a boolean attribute with the given name and default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttrWithDefaultValue( + TF_OpDefinitionBuilder* builder, const char* name, bool value); + +// Adds a boolean list attribute with the given name and no default value to the +// builder. 
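+//
+// For example, the WithDefaultValues variant declared just below could be
+// used as (a sketch; "flags" is an arbitrary attribute name):
+//   bool defaults[] = {true, false};
+//   TF_OpDefinitionBuilderAddBoolListAttrWithDefaultValues(builder, "flags",
+//                                                          defaults, 2);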
+TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolListAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a boolean list attribute with the given name and default values to the +// builder. `values` must contain at least `n` elements. +TF_CAPI_EXPORT extern void +TF_OpDefinitionBuilderAddBoolListAttrWithDefaultValues( + TF_OpDefinitionBuilder* builder, const char* name, bool values[], size_t n); + +// Adds the input with the given name and type to the op. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput( + TF_OpDefinitionBuilder* builder, const char* name, TF_DataType type); + +// Adds the output with the given name and type to the op. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput( + TF_OpDefinitionBuilder* builder, const char* output, TF_DataType type); + +// Sets the commutative property for the op built by the given builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative( + TF_OpDefinitionBuilder* builder, bool is_commutative); + +// Sets the is_aggregate property of the builder to the given value. +// +// If is_aggregate is true, then the operation produced by this builder accepts +// N >= 2 inputs and produces 1 output all of the same type. Should be +// associative and commutative, and produce output with the same shape as the +// input. The optimizer may replace an aggregate op taking input from multiple +// devices with a tree of aggregate ops that aggregate locally within each +// device (and possibly within groups of nearby devices) before communicating. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsAggregate( + TF_OpDefinitionBuilder* builder, bool is_aggregate); + +// Sets the is_stateful property of the builder to the given value. +// +// The op built by this builder is stateful if its behavior depends on some +// state beyond its input tensors (e.g. variable reading op) or if it has a +// side-effect (e.g. printing or asserting ops). Equivalently, stateless ops +// must always produce the same output for the same input and have no +// side-effects. +// +// By default Ops may be moved between devices. Stateful ops should either not +// be moved, or should only be moved if that state can also be moved (e.g. via +// some sort of save / restore). Stateful ops are guaranteed to never be +// optimized away by Common Subexpression Elimination (CSE). +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsStateful( + TF_OpDefinitionBuilder* builder, bool is_stateful); + +// Sets the allows_uninitialized_input property of the operation built by this +// builder. +// +// By default, all inputs to an Op must be initialized Tensors. Ops that may +// initialize tensors for the first time should set this field to true, to allow +// the Op to take an uninitialized Tensor as input. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetAllowsUninitializedInput( + TF_OpDefinitionBuilder* builder, bool allows_uninitialized_input); + +// Adds a deprecation warning for the given op. This indicates to the user that +// `version` is the first TensorFlow GraphDef version for which the operation is +// deprecated. `explanation` should contain the reason for the deprecation and +// what to use instead. +// +// This function is only an indicator that the operation may disappear in a +// version of TensorFlow after `version`. It does not affect op registration. 
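+//
+// For example (a sketch; the GraphDef version and message are arbitrary):
+//   TF_OpDefinitionBuilderDeprecated(builder, 9, "Use SomeOpV2 instead.");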
+TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderDeprecated( + TF_OpDefinitionBuilder* builder, int version, const char* explanation); + +// Sets the shape inference function for the op. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetShapeInferenceFunction( + TF_OpDefinitionBuilder* builder, + void (*shape_inference_func)(TF_ShapeInferenceContext* ctx, + TF_Status* status)); + +//---------------------------------------------------- +// Functions for TF_ShapeInferenceContext. +// +// Functions for implementing shape inference functions. TensorFlow uses these +// functions to determine the shape of tensors produced by an operation without +// having to actually run the operation. If an operation chooses to provide a +// shape inference function, it will be invoked by TensorFlow as needed. +// +// When invoked by TensorFlow, the shape inference function is provided with a +// TF_ShapeInferenceContext pointer. The function's implementation will use the +// accessor and mutator functions with names beginning with +// TF_ShapeInferenceContext to examine the input state and determine the output +// shape. + +// Returns the number of inputs in the given shape inference context. +TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextNumInputs( + TF_ShapeInferenceContext* ctx); + +// Returns a newly allocated shape handle. The shapes represented by these +// handles may be queried or mutated with the corresponding +// TF_ShapeInferenceContext... functions. +TF_CAPI_EXPORT extern TF_ShapeHandle* TF_NewShapeHandle(); + +// Places the ith input of the given shape inference context into the given +// shape handle, or returns a status other than TF_OK indicating why the input +// could not be retrieved +// (for example, if i < 0 || i >= TF_ShapeInferenceContextNumInputs(ctx)). +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextGetInput( + TF_ShapeInferenceContext* ctx, int i, TF_ShapeHandle* handle, + TF_Status* status); + +// Places the given shape handle into the `i`th output position of the given +// context. Internally, the shape handle is copied; the caller may subsequently +// delete `handle`. +TF_CAPI_EXPORT +extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, + int i, TF_ShapeHandle* handle, + TF_Status* status); + +// Returns a newly-allocate shape handle representing a vector of the given +// size. The returned handle should be freed with TF_DeleteShapeHandle. +TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize( + TF_ShapeInferenceContext* ctx, size_t size); + +// Returns a newly allocated dimension handle. It must be freed with +// TF_DeleteDimensionHandle. +TF_CAPI_EXPORT extern TF_DimensionHandle* TF_NewDimensionHandle(); + +// Interprets the named shape inference context attribute as a TF_DataType and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// TF_DataType, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_ShapeInferenceContext_GetAttrType( + TF_ShapeInferenceContext* ctx, const char* attr_name, TF_DataType* val, + TF_Status* status); + +// Returns the rank of the shape represented by the given handle. +TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextRank( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle); + +// Returns 1 if `handle` has a known rank, 0 otherwise. 
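+//
+// For example, a shape function might fall back to an unknown output shape
+// when the input rank is not yet known (a sketch; `ctx` and `status` are the
+// shape function's arguments and `handle` holds an input shape):
+//   if (!TF_ShapeInferenceContextRankKnown(ctx, handle)) {
+//     TF_ShapeInferenceContextSetUnknownShape(ctx, status);
+//     return;
+//   }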
+TF_CAPI_EXPORT extern int TF_ShapeInferenceContextRankKnown(
+ TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle);
+
+// If `handle` has rank `rank`, or its rank is unknown, return OK and return
+// the shape with asserted rank in `*result`. Otherwise an error is placed into
+// `status`.
+TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRank(
+ TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
+ TF_ShapeHandle* result, TF_Status* status);
+
+// If `handle` has rank at least `rank`, or its rank is unknown, return OK and
+// return the shape with asserted rank in `*result`. Otherwise an error is
+// placed into `status`.
+TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtLeast(
+ TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
+ TF_ShapeHandle* result, TF_Status* status);
+
+// If `handle` has rank at most `rank`, or its rank is unknown, return OK and
+// return the shape with asserted rank in `*result`. Otherwise an error is
+// placed into `status`.
+TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtMost(
+ TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank,
+ TF_ShapeHandle* result, TF_Status* status);
+
+// Places a handle to the ith dimension of the given shape into *result.
+TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim(
+ TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i,
+ TF_DimensionHandle* result);
+
+// Returns 1 if the given handle represents a known dimension.
+TF_CAPI_EXPORT extern int TF_ShapeInferenceContextDimValueKnown(
+ TF_ShapeInferenceContext* ctx, TF_DimensionHandle* handle);
+
+// Returns in `*result` a sub-shape of `shape_handle`, with dimensions
+// [start:end]. `start` and `end` can be negative, to index from the end of the
+// shape. `start` and `end` are set to the rank of `shape_handle` if they are
+// greater than its rank.
+TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSubshape(
+ TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t start,
+ int64_t end, TF_ShapeHandle* result, TF_Status* status);
+
+// Places an unknown shape in all outputs for the given inference context. Used
+// for shape inference functions with ops whose output shapes are unknown.
+TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSetUnknownShape(
+ TF_ShapeInferenceContext* ctx, TF_Status* status);
+
+// Returns whether the given handle represents a known dimension.
+TF_CAPI_EXPORT extern int TF_DimensionHandleValueKnown(
+ TF_DimensionHandle* dim_handle);
+
+// Returns the value of the given dimension.
+TF_CAPI_EXPORT extern int64_t TF_DimensionHandleValue(
+ TF_DimensionHandle* dim_handle);
+
+// Returns in `*result` the result of appending the dimensions of `second` to
+// those of `first`.
+TF_CAPI_EXPORT extern void TF_ShapeInferenceContextConcatenateShapes(
+ TF_ShapeInferenceContext* ctx, TF_ShapeHandle* first,
+ TF_ShapeHandle* second, TF_ShapeHandle* result, TF_Status* status);
+
+// Frees the given shape handle.
+TF_CAPI_EXPORT extern void TF_DeleteShapeHandle(TF_ShapeHandle* handle);
+
+// Frees the given dimension handle.
+TF_CAPI_EXPORT extern void TF_DeleteDimensionHandle(TF_DimensionHandle* handle);
+
+#ifdef __cplusplus
+} /* end extern "C" */
+#endif
+
+#endif // TENSORFLOW_C_OPS_H_
diff --git a/tensorflow/c/ops_test.cc b/tensorflow/c/ops_test.cc
new file mode 100644
index 00000000000..2b40f96157e
--- /dev/null
+++ b/tensorflow/c/ops_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/ops.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(OpsTest, TestBasicOpRegistration) { + TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp"); + TF_OpDefinitionBuilderAddStringAttr(builder, "attr1"); + TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8); + TF_OpDefinitionBuilderAddInput(builder, "input2", TF_UINT16); + TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT32); + TF_Status* status = TF_NewStatus(); + TF_RegisterOpDefinition(builder, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_Buffer* op_list_buffer = TF_GetAllOpList(); + ::tensorflow::OpList op_list; + op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length); + bool found = false; + for (const auto& op : op_list.op()) { + if (op.name() == "SomeOp") { + ASSERT_EQ(2, op.input_arg_size()); + ASSERT_EQ("input1", op.input_arg(0).name()); + ASSERT_EQ(::tensorflow::DT_UINT8, op.input_arg(0).type()); + ASSERT_EQ(1, op.attr_size()); + ASSERT_EQ("string", op.attr(0).type()); + found = true; + } + } + EXPECT_TRUE(found); + TF_DeleteStatus(status); + TF_DeleteBuffer(op_list_buffer); +} + +void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) { + TF_ShapeHandle* handle = TF_NewShapeHandle(); + TF_ShapeInferenceContextGetInput(ctx, 0, handle, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + TF_ShapeInferenceContextSetOutput(ctx, 0, handle, status); + TF_DeleteShapeHandle(handle); +} + +TEST(OpsTest, TestShapeInference_IdentityFunction) { + ShapeInferenceTestOp op("SomeTestOp"); + + TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp"); + TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8); + TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8); + TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn); + TF_Status* status = TF_NewStatus(); + TF_RegisterOpDefinition(builder, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_ASSERT_OK( + shape_inference::ShapeInferenceTestutil::InferShapes(op, "[1,2]", "in0")); + TF_DeleteStatus(status); +} + +// Creates an output whose shape is a vector of length +// TF_ShapeInferenceContextRank. 
+void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) { + TF_ShapeHandle* handle = TF_NewShapeHandle(); + TF_ShapeInferenceContextGetInput(ctx, 0, handle, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + TF_ShapeHandle* new_shape = TF_ShapeInferenceContextVectorFromSize( + ctx, TF_ShapeInferenceContextRank(ctx, handle)); + TF_ShapeInferenceContextSetOutput(ctx, 0, new_shape, status); + TF_DeleteShapeHandle(handle); + TF_DeleteShapeHandle(new_shape); +} + +TEST(OpsTest, TestShapeInference_VectorizeFunction) { + ShapeInferenceTestOp op("VectorizeTestOp"); + + TF_OpDefinitionBuilder* builder = + TF_NewOpDefinitionBuilder("VectorizeTestOp"); + TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8); + TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8); + TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn); + TF_Status* status = TF_NewStatus(); + TF_RegisterOpDefinition(builder, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes( + op, "[4,5,9]", "[3]")); + TF_DeleteStatus(status); +} + +TEST(OpsTest, AttributeAccessors) { + TF_OpDefinitionBuilder* builder = + TF_NewOpDefinitionBuilder("AttributeAccesorsOp"); + float values[] = {1, 2, 3, 4}; + TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues( + builder, "foo1", values, sizeof(values)); + TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(builder, "foo2", + "my string"); + TF_OpDefinitionBuilderSetIsCommutative(builder, true); + TF_OpDefinitionBuilderSetIsAggregate(builder, true); + TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true); + std::string deprecation_msg = "use something else instead"; + TF_OpDefinitionBuilderDeprecated(builder, 4, deprecation_msg.c_str()); + + TF_Status* status = TF_NewStatus(); + TF_RegisterOpDefinition(builder, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + + TF_Buffer* op_list_buffer = TF_GetAllOpList(); + ::tensorflow::OpList op_list; + op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length); + bool found = false; + for (const auto& op : op_list.op()) { + if (op.name() == "AttributeAccesorsOp") { + ASSERT_TRUE(op.is_commutative()); + ASSERT_TRUE(op.is_aggregate()); + ASSERT_TRUE(op.allows_uninitialized_input()); + ASSERT_EQ(4, op.deprecation().version()); + ASSERT_EQ(deprecation_msg, op.deprecation().explanation()); + ASSERT_EQ(2, op.attr_size()); + ASSERT_EQ("list(float)", op.attr(0).type()); + AttrValue::ListValue l = op.attr(0).default_value().list(); + ASSERT_EQ(1, l.f(0)); + ASSERT_EQ(2, l.f(1)); + ASSERT_EQ(3, l.f(2)); + ASSERT_EQ(4, l.f(3)); + + ASSERT_EQ("string", op.attr(1).type()); + ASSERT_EQ("my string", op.attr(1).default_value().s()); + found = true; + } + } + ASSERT_TRUE(found); + TF_DeleteStatus(status); + TF_DeleteBuffer(op_list_buffer); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 6449e7f44f7..2c9d9f3a15b 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -89,7 +89,7 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { // This modification only updates the destination node for // the purposes of running this graph in a session. Thus, we don't // record the source node as being modified. 
@@ -163,7 +163,7 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, tensorflow::shape_inference::ShapeHandle shape; status->status = ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); - if (!status->status.ok()) return; + if (TF_GetCode(status) != TF_OK) return; shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); } ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); @@ -174,7 +174,7 @@ void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, mutex_lock l(graph->mu); status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, new_src.index, &dst->node); - if (status->status.ok()) { + if (TF_GetCode(status) == TF_OK) { // This modification only updates the destination node for // the purposes of running this graph in a session. Thus, we don't // record the source node as being modified. diff --git a/tensorflow/c/tf_attrtype.h b/tensorflow/c/tf_attrtype.h new file mode 100644 index 00000000000..0c1545db232 --- /dev/null +++ b/tensorflow/c/tf_attrtype.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_TF_ATTRTYPE_H_ +#define TENSORFLOW_C_TF_ATTRTYPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// TF_AttrType describes the type of the value of an attribute on an operation. 
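+//
+// For example, an attribute declared in an op definition as "dtype: type" is
+// reported as TF_ATTR_TYPE, and one declared as "N: int" as TF_ATTR_INT (the
+// attribute names here are arbitrary).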
+typedef enum TF_AttrType { + TF_ATTR_STRING = 0, + TF_ATTR_INT = 1, + TF_ATTR_FLOAT = 2, + TF_ATTR_BOOL = 3, + TF_ATTR_TYPE = 4, + TF_ATTR_SHAPE = 5, + TF_ATTR_TENSOR = 6, + TF_ATTR_PLACEHOLDER = 7, + TF_ATTR_FUNC = 8, +} TF_AttrType; + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_ATTRTYPE_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index fca9416fdca..bd741249cf2 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -8,6 +8,19 @@ package( licenses(["notice"]) # Apache 2.0 +filegroup( + name = "srcs", + srcs = [ + "framework/gradients.h", + "framework/ops.h", + "framework/scope.h", + "framework/scope_internal.h", + "ops/array_ops.h", + "ops/while_loop.h", + "//tensorflow/cc/saved_model:loader.h", + ], +) + load( "//tensorflow:tensorflow.bzl", "cc_library_with_android_deps", @@ -190,6 +203,7 @@ tf_cc_test( deps = [ ":ops", ":scope", + "//tensorflow/cc:cc_ops", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -606,16 +620,13 @@ tf_gen_op_wrappers_cc( visibility = ["//tensorflow:internal"], ) -cc_library_with_android_deps( +cc_library( name = "cc_op_gen_main", srcs = [ "framework/cc_op_gen.cc", "framework/cc_op_gen.h", "framework/cc_op_gen_main.cc", ], - android_deps = [ - "//tensorflow/core:android_tensorflow_lib", - ], copts = tf_copts(), data = [ "//tensorflow/core/api_def:base_api_def", diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 43a33cbea6e..0605a62b83a 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -42,14 +42,19 @@ namespace { const int kRightMargin = 79; // Converts: -// bazel-out/.../genfiles/(external/YYY/)?XX +// bazel-out/.../(bin|genfiles)/(external/YYY/)?XX // to: XX. string GetPath(const string& dot_h_fname) { - auto pos = dot_h_fname.find("/genfiles/"); + auto pos = dot_h_fname.find("/bin/"); string result = dot_h_fname; if (pos != string::npos) { // - 1 account for the terminating null character (\0) in "/genfiles/". 
- result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1); + result = dot_h_fname.substr(pos + sizeof("/bin/") - 1); + } else { + pos = dot_h_fname.find("/genfiles/"); + if (pos != string::npos) { + result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1); + } } if (result.size() > sizeof("external/") && result.compare(0, sizeof("external/") - 1, "external/") == 0) { diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 134d64af140..e74ba009083 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -531,4 +531,23 @@ Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) { return InternalScope::NewScope(graph, status, refiner); } +Status CreateOutputWithScope(string op_name, + absl::Span inputs, + const Scope& scope, Output* output) { + TF_RETURN_IF_ERROR(scope.status()); + const auto unique_name = scope.GetUniqueNameForOp(op_name); + auto builder = ::tensorflow::NodeBuilder(unique_name, op_name); + for (auto input : inputs) { + TF_RETURN_IF_ERROR(scope.status()); + builder = builder.Input(input.node()); + } + ::tensorflow::Node* ret; + scope.UpdateBuilder(&builder); + TF_RETURN_IF_ERROR(scope.status()); + scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); + TF_RETURN_IF_ERROR(scope.status()); + *output = Output(ret, 0); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 1e17b74bc8f..ef2daff1357 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -255,6 +255,12 @@ struct CompositeOpScopes { Scope last; }; +// Creates a node of the given operation, with the given inputs, and assigns the +// result to output. This does not support the ability to add additional +// attributes. +Status CreateOutputWithScope(string op_name, + absl::Span inputs, + const Scope& scope, Output* output); /// @} } // namespace tensorflow diff --git a/tensorflow/cc/framework/scope_test.cc b/tensorflow/cc/framework/scope_test.cc index b40b345eb84..0b410b7f544 100644 --- a/tensorflow/cc/framework/scope_test.cc +++ b/tensorflow/cc/framework/scope_test.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/cc/framework/scope.h" + +#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -145,4 +147,14 @@ TEST(ScopeTest, ControlDeps) { EXPECT_EQ(c_c.control_deps().size(), 3); } +TEST(ScopeTest, CreateOutput) { + Scope root = Scope::NewRootScope(); + Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT); + Output add; + ASSERT_TRUE( + CreateOutputWithScope("Add", {a, a}, root.WithOpName("add"), &add).ok()); + EXPECT_EQ(add.node()->name(), "add"); + EXPECT_EQ(add.node()->type_string(), "Add"); +} + } // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 7d0f63efbcc..056eea7eb5a 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -88,15 +88,19 @@ Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, string kernel_type; TF_RETURN_IF_ERROR( GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type)); + bool antialias; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "antialias", &antialias)); grad_outputs->push_back(internal::ScaleAndTranslateGrad( scope, grad_inputs[0], op.input(0), op.input(2), op.input(3), - internal::ScaleAndTranslateGrad::KernelType(kernel_type))); + internal::ScaleAndTranslateGrad::KernelType(kernel_type) + .Antialias(antialias))); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); return scope.status(); } + REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper); Status CropAndResizeGradHelper(const Scope& scope, const Operation& op, diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc index 3bd52c80bd9..d50f4f5750a 100644 --- a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -196,29 +196,106 @@ class ScaleAndTranslateGradTest : public ::testing::Test { } template - void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x, - Output* y) { + void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale, + Input translation, const string& kernel_type, bool antialias, + Output* x, Output* y) { *x = Const(scope_, x_data); - *y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f}); + *y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation, + ScaleAndTranslate::KernelType(kernel_type) + .Antialias(antialias) + .Antialias(antialias)); TF_ASSERT_OK(scope_.status()); } template - void TestResize() { - TensorShape x_shape({1, 2, 3, 1}); + void TestScaleAndTranslate(const TensorShape x_shape, const int out_height, + const int out_width, Input scale, + Input translation, const string& kernel_type, + bool antialias) { Tensor x_data = MakeData(x_shape); Output x, y; - MakeOp(x_data, {4, 6}, &x, &y); + MakeOp(x_data, {out_height, out_width}, scale, translation, + kernel_type, antialias, &x, &y); JAC_T max_error; TF_ASSERT_OK((ComputeGradientError( - scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); - EXPECT_LT(max_error, 1e-3); + scope_, x, x_data, y, {1, out_height, out_width, 1}, &max_error))); + EXPECT_LT(max_error, 2e-3); } + const std::vector kScales = {Input{1.0f, 1.0f}, Input{0.37f, 0.47f}, + Input{2.1f, 2.1f}}; + const std::vector kTranslations = { + Input{0.0f, 0.0f}, Input{3.14f, 1.19f}, Input{2.1f, 3.1f}, + Input{100.0f, 200.0f}}; Scope scope_; }; 
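Stepping back to the `CreateOutputWithScope` helper added to `tensorflow/cc/framework/scope.{h,cc}` above: as its header comment notes, it builds a node purely from an op name and inputs and cannot set attributes, so ops whose attributes matter still go through the generated wrappers (as `ops::Placeholder` does for `dtype`). A small usage sketch under that assumption, mirroring the new `ScopeTest.CreateOutput` test; the node names are illustrative:

```cpp
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"

namespace tensorflow {

// Builds x + x through the attribute-less helper.
Status BuildDouble(const Scope& root, Output* twice) {
  // Placeholder needs its dtype attribute, so it uses the generated wrapper.
  Output x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
  // "Add" takes no extra attributes, so the new helper is enough here.
  return CreateOutputWithScope("Add", {x, x}, root.WithOpName("twice"), twice);
}

}  // namespace tensorflow
```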
-TEST_F(ScaleAndTranslateGradTest, Works) { TestResize(); } +TEST_F(ScaleAndTranslateGradTest, TestGrads) { + const std::vector kKernelTypes = {"lanczos1", "lanczos3", + "lanczos5", "gaussian"}; + constexpr int kOutHeight = 4; + constexpr int kOutWidth = 6; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithoutAntialias) { + constexpr int kOutHeight = 4; + constexpr int kOutWidth = 6; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + TestScaleAndTranslate(kXShape, kOutHeight, kOutWidth, + scale, translation, "lanczos3", + false); + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithSameShape) { + const std::vector kKernelTypes = {"lanczos3", "gaussian"}; + + constexpr int kOutHeight = 2; + constexpr int kOutWidth = 3; + + const TensorShape kXShape = TensorShape({1, 2, 3, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} + +TEST_F(ScaleAndTranslateGradTest, TestGradsWithSmallerShape) { + const std::vector kKernelTypes = {"lanczos3", "gaussian"}; + constexpr int kOutHeight = 2; + constexpr int kOutWidth = 3; + + const TensorShape kXShape = TensorShape({1, 4, 6, 1}); + for (const Input scale : kScales) { + for (const Input translation : kTranslations) { + for (const std::string& kernel_type : kKernelTypes) { + TestScaleAndTranslate( + kXShape, kOutHeight, kOutWidth, scale, translation, kernel_type, + true); + } + } + } +} class CropAndResizeGradTest : public ::testing::Test { protected: @@ -237,9 +314,9 @@ class CropAndResizeGradTest : public ::testing::Test { template void MakeOp(const Tensor& x_data, const Input& boxes, const Input& box_ind, - const Input& crop_szie, Output* x, Output* y) { + const Input& crop_size, Output* x, Output* y) { *x = Const(scope_, x_data); - *y = CropAndResize(scope_, *x, boxes, box_ind, crop_szie, + *y = CropAndResize(scope_, *x, boxes, box_ind, crop_size, CropAndResize::Method("bilinear")); TF_ASSERT_OK(scope_.status()); } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index dedd55f16af..13bc88f7cd3 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -17,6 +17,11 @@ load( "if_not_mobile", "tf_cc_test", ) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", + "if_static_and_not_mobile", +) cc_library( name = "constants", @@ -78,12 +83,13 @@ cc_library( hdrs = ["loader.h"], deps = [ ":loader_lite", - ] + if_not_mobile([ + ] + if_static_and_not_mobile([ + "//tensorflow/core:tensorflow", + ]) + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", ]) + if_android([ "//tensorflow/core:android_tensorflow_lib", ]), @@ -91,6 +97,19 @@ cc_library( cc_library( name = "loader_lite", + hdrs = ["loader.h"], + deps = if_static([ + ":loader_lite_impl", + ]) + if_not_mobile([ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + 
"//tensorflow/core:protos_all_cc", + # mobile not supported yet + ]), +) + +cc_library( + name = "loader_lite_impl", srcs = ["loader.cc"], hdrs = ["loader.h"], deps = [ @@ -121,6 +140,7 @@ tf_cc_test( ":tag_constants", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 66260fcf4a9..70f362cfeae 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -148,7 +148,8 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir, const std::vector& asset_file_defs, Session* session, const string& init_op_name) { if (!init_op_name.empty()) { - LOG(INFO) << "Running initialization op on SavedModel bundle."; + LOG(INFO) << "Running initialization op on SavedModel bundle at path: " + << export_dir; std::vector> inputs; AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index 6f1c8735407..c173569a095 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -36,6 +36,7 @@ tf_cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:framework_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py index 49cb74f19ef..ad2443a0c32 100644 --- a/tensorflow/compat_template.__init__.py +++ b/tensorflow/compat_template.__init__.py @@ -18,27 +18,41 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import logging as _logging import os as _os import sys as _sys +from tensorflow.python.tools import module_util as _module_util + # pylint: disable=g-bad-import-order # API IMPORTS PLACEHOLDER -from tensorflow.python.tools import component_api_helper as _component_api_helper -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=('tensorboard.summary._tf.summary'), - error_msg=( - "Limited tf.compat.v2.summary API due to missing TensorBoard " - "installation")) -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=( - 'tensorflow_estimator.python.estimator.api._v2.estimator')) -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=('tensorflow.python.keras.api._v2.keras')) +# Hook external TensorFlow modules. 
+_current_module = _sys.modules[__name__] +try: + from tensorboard.summary._tf import summary + _current_module.__path__ = ( + [_module_util.get_parent_dir(summary)] + _current_module.__path__) +except ImportError: + _logging.warning( + "Limited tf.compat.v2.summary API due to missing TensorBoard " + "installation.") + +try: + from tensorflow_estimator.python.estimator.api._v2 import estimator + _current_module.__path__ = ( + [_module_util.get_parent_dir(estimator)] + _current_module.__path__) +except ImportError: + pass + +try: + from tensorflow.python.keras.api._v2 import keras + _current_module.__path__ = ( + [_module_util.get_parent_dir(keras)] + _current_module.__path__) +except ImportError: + pass + # We would like the following to work for fully enabling 2.0 in a 1.0 install: # diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index 9549a71c41a..23c722edef7 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -19,19 +19,30 @@ from __future__ import division as _division from __future__ import print_function as _print_function import os as _os +import sys as _sys + +from tensorflow.python.tools import module_util as _module_util # pylint: disable=g-bad-import-order -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import # API IMPORTS PLACEHOLDER -from tensorflow.python.tools import component_api_helper as _component_api_helper -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=( - 'tensorflow_estimator.python.estimator.api._v1.estimator')) -_component_api_helper.package_hook( - parent_package_str=__name__, - child_package_str=('tensorflow.python.keras.api._v1.keras')) +# Hook external TensorFlow modules. 
+_current_module = _sys.modules[__name__] +try: + from tensorflow_estimator.python.estimator.api._v1 import estimator + _current_module.__path__ = ( + [_module_util.get_parent_dir(estimator)] + _current_module.__path__) +except ImportError: + pass + +try: + from tensorflow.python.keras.api._v1 import keras + _current_module.__path__ = ( + [_module_util.get_parent_dir(keras)] + _current_module.__path__) +except ImportError: + pass + + from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index af016bf80e7..a2ae086c41e 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -33,13 +33,13 @@ cc_library( ":aot_only_var_handle_op", ":embedded_protocol_buffers", "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -70,6 +70,7 @@ tf_cc_test( ], deps = [ ":tfcompile_lib", + "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 2355fad8802..2f063d7dd47 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -25,8 +25,8 @@ limitations under the License. 
#include "absl/strings/str_replace.h" #include "absl/types/span.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" -#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -38,7 +38,7 @@ namespace tfcompile { namespace { -using BufferInfo = cpu_function_runtime::BufferInfo; +using BufferInfo = xla::cpu_function_runtime::BufferInfo; bool IsAlpha(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); @@ -213,7 +213,11 @@ Status GenResultMethods(const tf2xla::Config& config, return errors::Internal("codegen requires the XLA result to be a tuple"); } size_t num_results = ps.result().tuple_shapes_size(); - if (config.fetch_size() + config.variable_size() != num_results) { + int readonly_variables = absl::c_count_if( + config.variable(), + [](const tf2xla::Variable& var) { return var.readonly(); }); + if (config.fetch_size() + config.variable_size() - readonly_variables != + num_results) { return errors::InvalidArgument("mismatch between fetch_size(", config.fetch_size(), ")+variable_size(", config.variable_size(), ") and tuple_size(", @@ -256,36 +260,26 @@ Status GenVariableMethods(const tf2xla::Config& config, TF_RETURN_IF_ERROR( AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); const string code = R"( - void set_var_{{NAME}}_data({{TYPE}}* data) { + void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) { set_arg_data({{I}}, data); } -)"; - const tf2xla::Variable& var = config.variable(i - config.feed_size()); - *methods += RewriteWithName( - var.name().empty() ? var.node_name() : var.name(), code, rewrites); + {{MAYBE_CONST}}{{TYPE}}* var_{{NAME}}_data() { + return static_cast<{{MAYBE_CONST}}{{TYPE}}*>(arg_data({{I}})); } - size_t num_results = ps.result().tuple_shapes_size(); - for (int i = config.fetch_size(); i < num_results; ++i) { - std::vector> rewrites; - TF_RETURN_IF_ERROR(AddRewritesForShape( - i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); - string code = R"( - {{TYPE}}* var_{{NAME}}_data() { - return static_cast<{{TYPE}}*>(result_data({{I}})); - } - {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) { - return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - result_data({{I}}))){{INDICES}}; + {{MAYBE_CONST}}{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) { + return (*static_cast<{{MAYBE_CONST}}{{TYPE}}(*){{DIM_SIZES}}>( + arg_data({{I}}))){{INDICES}}; } const {{TYPE}}* var_{{NAME}}_data() const { - return static_cast(result_data({{I}})); + return static_cast(arg_data({{I}})); } const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const { return (*static_cast( - result_data({{I}}))){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } )"; - const tf2xla::Variable& var = config.variable(i - config.fetch_size()); + const tf2xla::Variable& var = config.variable(i - config.feed_size()); + rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : ""); *methods += RewriteWithName( var.name().empty() ? var.node_name() : var.name(), code, rewrites); } @@ -363,7 +357,7 @@ std::vector BufferInfosToCppExpression( ? 
"~0ULL" : absl::StrCat(encoded.second, "ULL"); return absl::StrCat( - "::tensorflow::cpu_function_runtime::BufferInfo({", + "::xla::cpu_function_runtime::BufferInfo({", encoded.first, "ULL, ", encoded_second_as_str, "})"); }); return buffer_infos_as_strings; @@ -398,13 +392,15 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable)); - const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( - buffer_infos_for_args.data(), buffer_infos_for_args.size(), - /*allocate_entry_params=*/true); + const size_t arg_bytes_aligned = + xla::cpu_function_runtime::AlignedBufferBytes( + buffer_infos_for_args.data(), buffer_infos_for_args.size(), + /*allocate_entry_params=*/true); const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args); - const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( - buffer_infos_for_temps.data(), buffer_infos_for_temps.size(), - /*allocate_entry_params=*/true); + const size_t temp_bytes_aligned = + xla::cpu_function_runtime::AlignedBufferBytes( + buffer_infos_for_temps.data(), buffer_infos_for_temps.size(), + /*allocate_entry_params=*/true); const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps); // Create rewrite strings for namespace start and end. @@ -538,7 +534,8 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) + {{CLASS}}(AllocMode alloc_mode = + AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} {{CLASS}}(const {{CLASS}}&) = delete; @@ -579,27 +576,37 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // buffers are managed internally, and may change after each call to Run. {{METHODS_RESULT}} - // Methods for managing variable buffers. Buffers are in row-major order. The - // input and output buffers may or may not be identical. + // Methods for managing variable buffers. Buffers are in row-major order. + // + // For read-write variables we generate the following methods: // // void set_var_X_data(T* data) - // Sets the buffer for variable X. + // Sets the buffer for variable X. Must be called before Run if the + // allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY. // // T* var_X_data() - // Returns the buffer of type T for variable X. + // Returns the buffer of type T for variable X. If the allocation mode is + // RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the + // buffer passed to set_var_X_data. // // T& var_X(...dim indices...) // Returns a reference to the value of type T for variable X, // with dim indices specifying which value. No bounds checking is performed // on dim indices. + // + // For readonly variables we generate the same set of methods, except that we + // use `const T` instead of `T`. We use `const T` to avoid erasing the + // constness of the buffer passed to `set_var_X_data` but the underlying + // buffer is not const (and thus the const can be safely const-cast'ed away) + // unless `set_var_X_data` is called with a pointer to constant storage. {{METHODS_VARIABLE}} private: // Number of buffers for the compiled computation. 
static constexpr size_t kNumBuffers = {{NUM_BUFFERS}}; - static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() { - static const ::tensorflow::cpu_function_runtime::BufferInfo + static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() { + static const ::xla::cpu_function_runtime::BufferInfo kBufferInfos[kNumBuffers] = { {{BUFFER_INFOS_AS_STRING}} }; diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 5580e55b691..73be43c1d0c 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" @@ -34,7 +35,7 @@ namespace tensorflow { namespace tfcompile { namespace { -using ::tensorflow::cpu_function_runtime::BufferInfo; +using ::xla::cpu_function_runtime::BufferInfo; void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); @@ -175,14 +176,19 @@ TEST(CodegenTest, Golden) { fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); tf2xla::Variable* variable = config.add_variable(); - variable->set_node_name("myvar"); + variable->set_node_name("myvar_readonly"); variable->mutable_shape()->add_dim()->set_size(1); variable->set_type(DT_FLOAT); + variable->set_readonly(true); tf2xla::Variable* variable2 = config.add_variable(); - variable2->set_node_name("my/var"); - variable2->set_name("myvar2"); - variable2->mutable_shape()->add_dim()->set_size(5); - variable2->set_type(DT_INT32); + variable2->set_node_name("myvar"); + variable2->mutable_shape()->add_dim()->set_size(1); + variable2->set_type(DT_FLOAT); + tf2xla::Variable* variable3 = config.add_variable(); + variable3->set_node_name("my/var"); + variable3->set_name("myvar2"); + variable3->mutable_shape()->add_dim()->set_size(5); + variable3->set_type(DT_INT32); CompileResult compile_result; compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( {}, @@ -198,6 +204,7 @@ TEST(CodegenTest, Golden) { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), xla::ShapeUtil::MakeShape(xla::F32, {1}), + xla::ShapeUtil::MakeShape(xla::F32, {1}), xla::ShapeUtil::MakeShape(xla::S32, {5}), }, xla::ShapeUtil::MakeTupleShape({ diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 8591df53877..702582b968a 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -52,7 +52,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. 
// // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) // // Memory stats: // arg bytes total: 104 @@ -91,7 +91,8 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) + MyClass(AllocMode alloc_mode = + AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} MyClass(const MyClass&) = delete; @@ -214,71 +215,97 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { result_data(0)))[dim0][dim1]; } - // Methods for managing variable buffers. Buffers are in row-major order. The - // input and output buffers may or may not be identical. + // Methods for managing variable buffers. Buffers are in row-major order. + // + // For read-write variables we generate the following methods: // // void set_var_X_data(T* data) - // Sets the buffer for variable X. + // Sets the buffer for variable X. Must be called before Run if the + // allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY. // // T* var_X_data() - // Returns the buffer of type T for variable X. + // Returns the buffer of type T for variable X. If the allocation mode is + // RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the + // buffer passed to set_var_X_data. // // T& var_X(...dim indices...) // Returns a reference to the value of type T for variable X, // with dim indices specifying which value. No bounds checking is performed // on dim indices. + // + // For readonly variables we generate the same set of methods, except that we + // use `const T` instead of `T`. We use `const T` to avoid erasing the + // constness of the buffer passed to `set_var_X_data` but the underlying + // buffer is not const (and thus the const can be safely const-cast'ed away) + // unless `set_var_X_data` is called with a pointer to constant storage. 
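The comment block above (emitted into the generated header) describes the variable-buffer contract; the concrete accessors it documents follow below. A hedged sketch of how a client might drive the golden `foo::bar::MyClass` under the `RESULTS_PROFILES_AND_TEMPS_ONLY` allocation mode, based only on the accessors visible in this golden file (argument and result handling elided):

```cpp
// The caller owns the variable buffers when they are not auto-allocated.
foo::bar::MyClass computation(
    tensorflow::XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);

float myvar = 0.0f;
tensorflow::int32 myvar2[5] = {0};
const float myvar_readonly = 1.0f;  // read-only variables take const T*

computation.set_var_myvar_data(&myvar);
computation.set_var_myvar2_data(myvar2);
computation.set_var_myvar_readonly_data(&myvar_readonly);

// ... set_argN_data(...) for the regular arguments, then:
computation.Run();
float updated = computation.var_myvar();  // reads back through the same buffer
```

Note how the read-only setter accepts `const float*`; that is exactly the `{{MAYBE_CONST}}` rewrite introduced in codegen.cc above.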
- void set_var_myvar_data(float* data) { + void set_var_myvar_readonly_data(const float* data) { set_arg_data(2, data); } - - void set_var_myvar2_data(tensorflow::int32* data) { - set_arg_data(3, data); + const float* var_myvar_readonly_data() { + return static_cast(arg_data(2)); + } + const float& var_myvar_readonly() { + return (*static_cast( + arg_data(2)))[0]; + } + const float* var_myvar_readonly_data() const { + return static_cast(arg_data(2)); + } + const float& var_myvar_readonly() const { + return (*static_cast( + arg_data(2)))[0]; } + void set_var_myvar_data(float* data) { + set_arg_data(3, data); + } float* var_myvar_data() { - return static_cast(result_data(1)); + return static_cast(arg_data(3)); } float& var_myvar() { return (*static_cast( - result_data(1)))[0]; + arg_data(3)))[0]; } const float* var_myvar_data() const { - return static_cast(result_data(1)); + return static_cast(arg_data(3)); } const float& var_myvar() const { return (*static_cast( - result_data(1)))[0]; + arg_data(3)))[0]; } + void set_var_myvar2_data(tensorflow::int32* data) { + set_arg_data(4, data); + } tensorflow::int32* var_myvar2_data() { - return static_cast(result_data(2)); + return static_cast(arg_data(4)); } tensorflow::int32& var_myvar2(size_t dim0) { return (*static_cast( - result_data(2)))[dim0]; + arg_data(4)))[dim0]; } const tensorflow::int32* var_myvar2_data() const { - return static_cast(result_data(2)); + return static_cast(arg_data(4)); } const tensorflow::int32& var_myvar2(size_t dim0) const { return (*static_cast( - result_data(2)))[dim0]; + arg_data(4)))[dim0]; } private: // Number of buffers for the compiled computation. static constexpr size_t kNumBuffers = 6; - static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() { - static const ::tensorflow::cpu_function_runtime::BufferInfo + static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() { + static const ::xla::cpu_function_runtime::BufferInfo kBufferInfos[kNumBuffers] = { -::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), -::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}), -::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}), -::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}), -::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}), -::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL}) +::xla::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL}), +::xla::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({386ULL, 1ULL}), +::xla::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}), +::xla::cpu_function_runtime::BufferInfo({481ULL, ~0ULL}) }; return kBufferInfos; } @@ -309,7 +336,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShapeProto* StaticProgramShape() { static const xla::ProgramShapeProto* kShape = []() { xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 132); + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 149); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index 2884597abcf..38c75d1fb60 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git 
a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index ce8dae42629..6362470abef 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -36,6 +36,7 @@ py_binary( name = "make_test_graphs", testonly = 1, srcs = ["make_test_graphs.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 7f5e907e263..739cb016643 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -159,10 +159,11 @@ def tfvariable(_): def tfvariable_sequential_updates(_): x = variables.Variable(1.0, name='x') + y = variables.Variable(1.0, name='y') updates = control_flow_ops.no_op() for _ in range(3): with ops.control_dependencies([updates]): - x_val = x.read_value() + 1.0 + x_val = x.read_value() + y updates = x.assign_sub(0.1 * x_val) array_ops.identity(updates, name='result') diff --git a/tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.config.pbtxt index 7312c40baf6..eb2ae56cd10 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.config.pbtxt @@ -7,3 +7,9 @@ variable { node_name: "x" type: DT_FLOAT } + +variable { + node_name: "y" + type: DT_FLOAT + readonly: true +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 5bee7f2540a..c55f3f946db 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -83,7 +83,8 @@ TEST(TFCompileTest, Add) { // Run tests that use set_argN_data separately, to avoid accidentally re-using // non-existent buffers. TEST(TFCompileTest, Add_SetArg) { - AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); + AddComp add( + XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); int32 arg_x = 10; int32 arg_y = 32; @@ -296,7 +297,7 @@ TEST(TFCompileTest, MatMul2_SetArg) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); foo::bar::MatMulComp matmul( - foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); + XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); matmul.set_thread_pool(&device); // Test using the set_argN_data() methods. @@ -502,20 +503,50 @@ TEST(TFCompileTest, VariableSequentialUpdates) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); // This implements the recursion: - // x[0] = 1.0 - // x[n+1] = x[n] - 0.1*(x[n-1] + 1.0) + // x[0] = 2.0 + // x[n+1] = x[n] - 0.1*(x[n-1] + y) VariableSequentialUpdatesComp fn; - float x = 1; - fn.set_var_x_data(&x); + fn.var_x() = 2; + *const_cast(fn.var_y_data()) = 1; fn.set_thread_pool(&device); // First calculate x[3] fn.Run(); - EXPECT_NEAR(x, 0.458f, 1e-6); + EXPECT_NEAR(fn.var_x(), 1.187f, 1e-6); + + const float y = 1; + fn.set_var_y_data(&y); + + // Now const_cast(fn.var_y_data()) is not longer legal since we've set + // the buffer to point to a constant location. 
// Then calculate x[6] fn.Run(); - EXPECT_NEAR(x, 0.062882f, 1e-6); + EXPECT_NEAR(fn.var_x(), 0.594322f, 1e-6); +} + +TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + // This implements the recursion: + // x[0] = 2.0 + // x[n+1] = x[n] - 0.1*(x[n-1] + 1.0) + VariableSequentialUpdatesComp fn( + XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); + float x = 2; + float y = 1; + fn.set_var_x_data(&x); + fn.set_var_y_data(&y); + + fn.set_thread_pool(&device); + // First calculate x[3] + fn.Run(); + EXPECT_NEAR(x, 1.187f, 1e-6); + + // Then calculate x[6] + fn.Run(); + EXPECT_NEAR(x, 0.594322f, 1e-6); } TEST(TFCompileTest, AssertEqAndReturnDiff) { diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index fd701ab7166..e7f3c0aebdd 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -163,7 +163,10 @@ def tf_library( header_file = name + ".h" metadata_object_file = name + "_tfcompile_metadata.o" function_object_file = name + "_tfcompile_function.o" - ep = ("__" + native.package_name() + "__" + name).replace("/", "_") + + # The XLA backends morph kernal name prefix __ that is not in the form of + # __xla_. + ep = ("__xla_" + native.package_name() + "__" + name).replace("/", "_") if type(tfcompile_flags) == type(""): flags = tfcompile_flags else: @@ -171,6 +174,20 @@ def tf_library( "'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or []) ]) + + # Do this before we append the `select` into `flags`, because doing so + # transforms `flags` into a variable of type `select`, and we can't call + # `find` on such an object. + need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1 + + # Pass --target_cpu=haswell to tfcompile if compiling for Haswell (bazel + # build --cpu=haswell). We put it at the beginning of the flags list so + # that tfcompile_flags can override if if desired. + flags = select({ + "//tools/target_cpu:haswell": "--target_cpu=haswell ", + "//conditions:default": "", + }) + flags + if enable_xla_hlo_profiling: profiling_flag = "--xla_hlo_profile" else: @@ -248,7 +265,6 @@ def tf_library( # The cc_library rule packaging up the header and object file, and needed # kernel implementations. - need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) native.cc_library( name = name, srcs = [function_object_file, metadata_object_file], diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index b846ad789e5..4b3726b8475 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -17,15 +17,14 @@ package_group( package( default_visibility = [ ":internal", + # BEGIN-GOOGLE-INTERNAL + "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", + # END-GOOGLE-INTERNAL ], ) -# NB! Removing the cc_header_only_library import breaks the OSS build since -# copybara injects some build rules that use it. -load("//tensorflow:tensorflow.bzl", "cc_header_only_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") # Target that bundles up the XLA CPU and GPU JIT devices. 
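Returning to the `VariableSequentialUpdates` tests above: the expected constants follow from unrolling the updated graph in make_test_graphs.py, where each `Run()` applies `x <- x - 0.1 * (x + y)` three times starting from `x = 2`, `y = 1`. A standalone check of that arithmetic (plain C++, not part of the test itself):

```cpp
#include <cstdio>

int main() {
  float x = 2.0f;
  const float y = 1.0f;
  for (int run = 1; run <= 2; ++run) {
    for (int step = 0; step < 3; ++step) x -= 0.1f * (x + y);
    std::printf("after Run() %d: x = %f\n", run, x);
  }
  // Prints x ~= 1.187 after the first Run() and x ~= 0.594322 after the
  // second, matching the EXPECT_NEAR values in tfcompile_test.cc.
  return 0;
}
```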
@@ -78,10 +77,10 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ - ":create_xla_launch_op", # buildcleaner: keep ":flags", ":jit_compilation_passes", ":xla_device", + ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -98,9 +97,9 @@ cc_library( srcs = ["xla_gpu_device.cc"], visibility = [":friends"], deps = [ - ":create_xla_launch_op", # buildcleaner: keep ":jit_compilation_passes", ":xla_device", + ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -168,7 +167,6 @@ cc_library( ":xla_tensor", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -199,6 +197,7 @@ cc_library( "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:host_constant_op", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", @@ -212,6 +211,8 @@ cc_library( "//tensorflow/core/kernels/data:iterator_ops", "//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor/platform", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", @@ -222,6 +223,7 @@ cc_library( name = "shape_inference_helpers", srcs = ["shape_inference_helpers.cc"], hdrs = ["shape_inference_helpers.h"], + visibility = [":friends"], deps = ["//tensorflow/core:graph"], ) @@ -236,6 +238,7 @@ cc_library( "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -254,6 +257,11 @@ cc_library( name = "xla_launch_util", srcs = ["xla_launch_util.cc"], hdrs = ["xla_launch_util.h"], + # TODO(skyewm): remove this once XlaAllocator is factored out. 
+ visibility = [ + ":internal", + "//tensorflow/compiler/xla/python:__pkg__", + ], deps = [ ":common", ":xla_compilation_cache", @@ -263,7 +271,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -271,6 +278,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -283,7 +291,6 @@ cc_library( hdrs = ["xla_compilation_cache.h"], deps = [ "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", @@ -326,10 +333,10 @@ cc_library( ) cc_library( - name = "create_xla_launch_op", + name = "xla_kernel_creator", srcs = [ - "create_xla_launch_op.cc", - "create_xla_launch_op.h", + "xla_kernel_creator.cc", + "xla_kernel_creator.h", ], deps = [ ":common", @@ -346,13 +353,13 @@ cc_library( ) tf_cc_test( - name = "create_xla_launch_op_test", + name = "xla_kernel_creator_test", srcs = [ - "create_xla_launch_op.h", - "create_xla_launch_op_test.cc", + "xla_kernel_creator.h", + "xla_kernel_creator_test.cc", ], deps = [ - ":create_xla_launch_op", + ":xla_kernel_creator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -370,6 +377,7 @@ cc_library( srcs = ["resource_operation_safety_analysis.cc"], hdrs = ["resource_operation_safety_analysis.h"], deps = [ + ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/core:framework", @@ -417,7 +425,6 @@ cc_library( hdrs = ["shape_inference.h"], deps = [ ":shape_inference_helpers", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -467,6 +474,9 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -498,6 +508,7 @@ cc_library( "encapsulate_xla_computations_pass.cc", "extract_outside_compilation_pass.cc", "increase_dynamism_for_auto_jit_pass.cc", + "introduce_floating_point_jitter_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", @@ -510,24 +521,28 @@ cc_library( "encapsulate_xla_computations_pass.h", "extract_outside_compilation_pass.h", "increase_dynamism_for_auto_jit_pass.h", + "introduce_floating_point_jitter_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", ], deps = [ + "compilability_check_util", ":common", + ":device_util", ":encapsulate_util", ":flags", + ":resource_operation_safety_analysis", ":shape_inference_helpers", ":union_find", ":xla_cluster_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:scope", 
"//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/ops:xla_ops", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -535,6 +550,7 @@ cc_library( "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -561,19 +577,49 @@ cc_library( hdrs = ["xla_cluster_util.h"], deps = [ ":flags", - ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_bounds_check", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "device_util", + srcs = ["device_util.cc"], + hdrs = ["device_util.h"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:framework", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "device_util_test", + srcs = ["device_util_test.cc"], + deps = [ + ":device_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -631,13 +677,15 @@ tf_cc_test( srcs = [ "build_xla_ops_pass_test.cc", "clone_constants_for_better_clustering_test.cc", - "compilation_passes_test_main.cc", "encapsulate_subgraphs_pass_test.cc", "encapsulate_xla_computations_pass_test.cc", "extract_outside_compilation_pass_test.cc", "increase_dynamism_for_auto_jit_pass_test.cc", + "introduce_floating_point_jitter_pass_internal.h", + "introduce_floating_point_jitter_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", + "rearrange_function_argument_pass_test.cc", ], deps = [ ":common", @@ -658,6 +706,7 @@ tf_cc_test( "//tensorflow/cc:scope", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla:rearrange_function_argument", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -677,6 +726,7 @@ tf_cc_test( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -696,6 +746,7 @@ tf_cc_test( "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -708,43 +759,6 @@ tf_cc_test( ], ) -cc_library( - name = "xla_fusion_optimizer", 
- srcs = ["xla_fusion_optimizer.cc"], - hdrs = ["xla_fusion_optimizer.h"], - visibility = ["//visibility:public"], - deps = [ - ":common", - ":compilation_passes", - ":union_find", - ":xla_cluster_util", - "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "@com_google_absl//absl/strings", - ], -) - -tf_cuda_cc_test( - name = "xla_fusion_optimizer_test", - srcs = ["xla_fusion_optimizer_test.cc"], - deps = [ - ":common", - ":xla_cluster_util", - ":xla_fusion_optimizer", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:resource_variable_ops", - "//tensorflow/core:graph", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/grappler/utils:grappler_test", - ], -) - cc_library( name = "node_matchers", testonly = True, @@ -776,6 +790,34 @@ tf_cc_test( ], ) +cc_library( + name = "compilability_check_util", + srcs = ["compilability_check_util.cc"], + hdrs = ["compilability_check_util.h"], + deps = [ + ":common", + ":device_util", + ":flags", + ":resource_operation_safety_analysis", + ":union_find", + ":xla_cluster_util", + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) + tf_custom_op_py_library( name = "xla_ops_py", kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 80cca24c827..47b3c6611f3 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/build_xla_ops_pass.h" + #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -23,12 +24,13 @@ limitations under the License. #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/control_flow_ops.h" #include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/logging_ops.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" @@ -42,6 +44,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { @@ -74,7 +77,8 @@ Operation DataToControl(const Scope& scope, Output data) { // Replaces each outgoing edge from `old_node` with a merge node that merges in // the corresponding output from `new_node`. -void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node) { +void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node, + bool insert_print_nodes) { if (!s.status().ok()) { return; } @@ -91,7 +95,21 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node) { if (merged_output.node() == nullptr) { ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)), {Output(old_node, oidx), Output(new_node, oidx)}); - merged_output = merged_outputs[oidx] = merge_op.output; + if (insert_print_nodes) { + string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx)) + .WithDevice(cpu_device) + .WithAssignedDevice(cpu_device), + merge_op.output, {merge_op.output}, + ops::Print::Attrs{} + .Message(absl::StrCat("output ", oidx, " from ", + old_node->name(), " is ")) + .FirstN(1000) + .Summarize(-1)); + merged_output = merged_outputs[oidx] = print_op; + } else { + merged_output = merged_outputs[oidx] = merge_op.output; + } } Node* dst = e->dst(); @@ -215,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) { } // Returns true (into `result`) if a node placed on `device` must be compiled. -Status DeviceRequiresCompilation(const string& device, bool* result) { - DeviceType device_type(""); - TF_RETURN_IF_ERROR(DeviceToDeviceType(device, &device_type)); - const XlaOpRegistry::DeviceRegistration* registration = nullptr; - if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { - return errors::Internal("Could not find compilation device ", - device_type.type()); - } +Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache, + jit::DeviceId device, bool* result) { + const XlaOpRegistry::DeviceRegistration* registration = + device_info_cache.GetCompilationDevice(device); *result = registration->autoclustering_policy == XlaOpRegistry::AutoclusteringPolicy::kAlways; return Status::OK(); @@ -275,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall( return Status::OK(); } -Status InferDeviceForCluster(Node* n, const string& function_name, - const FunctionLibraryDefinition& flib_def, - string* result) { +xla::StatusOr InferDeviceForCluster( + jit::DeviceInfoCache* device_info_cache, Node* n, + const string& function_name, const FunctionLibraryDefinition& flib_def) { const FunctionDef* func_def = flib_def.Find(function_name); TF_RET_CHECK(func_def) << "Could not find " << function_name; - std::set device_names; + jit::DeviceSet device_set; + for (const NodeDef& ndef : func_def->node_def()) { VLOG(3) << ndef.DebugString(); if (!ndef.device().empty()) { - device_names.insert(ndef.device()); + TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, + device_info_cache->GetIdFor(ndef.device())); + device_set.Insert(device_id); } } @@ -293,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name, // TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device // assignment when constant folding. We should fix EncapsulateSubgraphsPass // instead. 
- device_names.insert(n->assigned_device_name()); + TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, + device_info_cache->GetIdFor(n->assigned_device_name())); + device_set.Insert(device_id); } - std::vector device_names_vector; - absl::c_copy(device_names, std::back_inserter(device_names_vector)); - - Status s = PickDeviceForXla(device_names_vector, true, result); - if (s.ok()) { - VLOG(2) << "For " << function_name << " PickDeviceForXla(" - << absl::StrJoin(device_names_vector, ", ") << ") -> " << *result; - } - return s; + TF_ASSIGN_OR_RETURN(jit::DeviceId result, + PickDeviceForXla(*device_info_cache, device_set, + /*allow_mixing_unknown_and_cpu=*/true)); + VLOG(2) << "For " << function_name << " PickDeviceForXla(" + << device_info_cache->DebugString(device_set) << ") -> " + << device_info_cache->GetNameFor(result); + return result; } Status ReplaceNodeWithXlaCompileAndXlaRun( + jit::DeviceInfoCache* device_info_cache, const GraphOptimizationPassOptions& options, const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled, - Graph* g, Node* n) { + bool insert_print_nodes, Graph* g, Node* n) { XlaClusterInfo cluster_info; TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); - string device; - TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(), - flib_def, &device)); + TF_ASSIGN_OR_RETURN( + jit::DeviceId device, + InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(), + flib_def)); + bool requires_compilation; - TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation)); + TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device, + &requires_compilation)); if (!lazy_compilation_enabled) { requires_compilation = true; } + string device_name_str = string(device_info_cache->GetNameFor(device)); + Status status; Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) .NewSubScope(n->name()) .WithDevice(n->requested_device()) - .WithAssignedDevice(device); + .WithAssignedDevice(device_name_str); ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"), /*constants=*/cluster_info.constant_inputs, @@ -378,7 +401,8 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( /*new_node=*/xla_run.operation.node()); MergeOutgoingDataEdges(root, /*old_node=*/n, - /*new_node=*/xla_run.operation.node()); + /*new_node=*/xla_run.operation.node(), + insert_print_nodes); TF_RETURN_IF_ERROR(root.status()); @@ -418,15 +442,20 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { bool lazy_compilation_enabled = enable_lazy_compilation_ ? 
*enable_lazy_compilation_ - : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation; + : GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation; + bool insert_print_nodes = + GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs; + + jit::DeviceInfoCache device_info_cache; for (Node* n : xla_compiled_kernels) { TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( - options, *options.flib_def, lazy_compilation_enabled, graph, n)); + &device_info_cache, options, *options.flib_def, + lazy_compilation_enabled, insert_print_nodes, graph, n)); } if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def); + DumpGraphToFile("build_xla_ops", *graph, options.flib_def); } return Status::OK(); diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc index 848a6362a4a..6df4aa2380e 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc @@ -122,7 +122,7 @@ Status CloneConstantsForBetterClusteringPass::CloneSmallHostConstantInputs( Status CloneConstantsForBetterClusteringPass::Run( const GraphOptimizationPassOptions& options) { - if (GetGlobalJitLevel(options) == OptimizerOptions::OFF) { + if (GetGlobalJitLevelForGraph(options) == OptimizerOptions::OFF) { return Status::OK(); } diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc new file mode 100644 index 00000000000..91e85970cc0 --- /dev/null +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -0,0 +1,277 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/compilability_check_util.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/device_util.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { + +namespace { +bool HasResourceInput(const Node& node) { + return absl::c_count(node.input_types(), DT_RESOURCE) != 0; +} +} // anonymous namespace + +bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) { + // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient + // is really a kind of function call and will be handled by + // IsCompilableCall(). + if (node.type_string() == "SymbolicGradient") return false; + if (node.type_string() == "Const") { + // Skip Const op with type DT_STRING, since XLA doesn't support it, but the + // registered Const KernelDef says that it does, to support no-op Assert for + // tfcompile. + const AttrValue* attr = node.attrs().Find("dtype"); + if (attr != nullptr && attr->type() == DT_STRING) { + return false; + } + } + + // XLA does not offer guaranteed aliasing between the input and output of the + // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave + // such nodes out of XLA clusters. + if (HasForwardedRefInput(node)) { + VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast."; + return false; + } + + return FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr).ok(); +} + +// Tests whether 'while_node' is a completely compilable loop. +// Every operator in the condition and body functions must be compilable for a +// while loop to be compilable. 
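// For illustration (the function names here are placeholders): a functional
// While node carries its condition and body as function-valued attributes,
// roughly
//
//   node { name: "while"  op: "While"
//          attr { key: "cond" value { func { name: "loop_cond" } } }
//          attr { key: "body" value { func { name: "loop_body" } } } }
//
// The check below reads those NameAttrList attributes and recurses into the
// named functions via IsCompilableCall.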
+bool RecursiveCompilabilityChecker::IsCompilableWhile( + const Node& while_node, int depth, FunctionLibraryRuntime* lib_runtime) { + const NameAttrList* name_attr; + NodeDef call; + Status status; + status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); + if (!status.ok()) { + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'cond' attribute on While node."; + return false; + } + const string cond_func = name_attr->name(); + call.set_name("while_cond"); + call.set_op(cond_func); + *call.mutable_attr() = name_attr->attr(); + if (!IsCompilableCall(call, depth + 1, lib_runtime)) { + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop condition: " << cond_func; + return false; + } + status = GetNodeAttr(while_node.attrs(), "body", &name_attr); + if (!status.ok()) { + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'body' attribute on While node."; + return false; + } + const string body_func = name_attr->name(); + call.set_name("while_body"); + call.set_op(body_func); + *call.mutable_attr() = name_attr->attr(); + if (!IsCompilableCall(call, depth + 1, lib_runtime)) { + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop body: " << body_func; + return false; + } + return true; +} + +// Tests whether 'call_def' is a call to a completely compilable function. +// Every operator in the function must be compilable for a function to be +// compilable. +bool RecursiveCompilabilityChecker::IsCompilableCall( + const NodeDef& call_def, int depth, FunctionLibraryRuntime* lib_runtime) { + if (depth > kMaxRecursionDepth) { + VLOG(2) << "Rejecting " << call_def.op() + << ": function depth limit exceeded."; + return false; + } + + FunctionLibraryRuntime::Handle handle; + Status status = InstantiateFunctionCall(call_def, lib_runtime, &handle); + if (!status.ok()) { + VLOG(2) << "Rejecting " << call_def.DebugString() + << ": could not instantiate: " << status; + return false; + } + + auto release_handle_on_return = gtl::MakeCleanup( + [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); }); + + const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); + for (Node* node : fbody->graph->op_nodes()) { + if (!IsCompilableNode(*node, depth + 1, lib_runtime)) { + return false; + } + } + + return true; +} + +bool LogNotCompilableAndReturn(const Node& node, + absl::string_view reason = "") { + VLOG(3) << "Not clustering " << node.name() << " (op " << node.type_string() + << ")" << (reason.empty() ? "" : ": ") << reason; + return false; +} + +bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) { + // b/127344411: SelfAdjointEigV2 and Svd precision issues. + return node.type_string() == "SelfAdjointEigV2" || + node.type_string() == "Svd"; +} + +bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) { + // b/128001705: SelfAdjointEigV2 and Svd performance issues. + return node.type_string() == "SelfAdjointEigV2" || + node.type_string() == "Svd" || node.type_string() == "Qr"; +} + +bool RecursiveCompilabilityChecker::IsCompilableNode( + const Node& node, int depth, FunctionLibraryRuntime* lib_runtime) { + if (node.IsSource() || node.IsSink()) { + return LogNotCompilableAndReturn(node, "source or sink node"); + } + + // _Arg nodes in a top-level function represent feeds and _Retval nodes in a + // top-level function represent fetches. 
+ if (depth == 0 && + (node.type_string() == "_Arg" || node.type_string() == "_Retval")) { + return LogNotCompilableAndReturn(node, "depth is 0"); + } + + if (node.attrs().Find("_scoped_allocator") || + node.attrs().Find("_forward_from")) { + // TODO(b/128858118): XLA does not support _scoped_allocator and + // _forward_from. + return LogNotCompilableAndReturn( + node, "_scoped_allocator or _forward_from attribute"); + } + + if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) { + if (!IsCompilableCall(node.def(), depth + 1, lib_runtime)) { + return LogNotCompilableAndReturn(node, "unsupported function"); + } + } else if (!HasXLAKernel(node)) { + return LogNotCompilableAndReturn(node, "unsupported op"); + } + + if (node.type_string() == "While" && + !IsCompilableWhile(node, depth + 1, lib_runtime)) { + return LogNotCompilableAndReturn(node, "unsupported while"); + } + + if (!op_filter_.allow_stateful_rng_ops && + IsStatefulRandomOp(node.type_string())) { + return LogNotCompilableAndReturn(node, "stateful random op"); + } + + if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) { + return LogNotCompilableAndReturn(node); + } + + if (!op_filter_.allow_eliding_assert_and_checknumerics_ops && + IsAssertOrCheckNumerics(node.type_string())) { + return LogNotCompilableAndReturn(node, "Assert or CheckNumerics"); + } + + if (!op_filter_.allow_ops_producing_or_consuming_variant && + OpProducesOrConsumesVariant(node)) { + return LogNotCompilableAndReturn(node, "DT_VARIANT producer/consumer"); + } + + if (!op_filter_.allow_stack_ops && IsStackOp(node)) { + return LogNotCompilableAndReturn(node, "Stack op"); + } + + if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) { + return LogNotCompilableAndReturn(node, "TensorArray op"); + } + + if (!op_filter_.allow_resource_ops_in_called_functions && depth > 0 && + HasResourceInput(node)) { + return LogNotCompilableAndReturn(node, + "resource variable op in called function"); + } + + if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsInaccurate(node)) { + return LogNotCompilableAndReturn(node, "operation with correctness issues"); + } + + if (!op_filter_.allow_slow_and_inaccurate_ops && OpIsSlow(node)) { + return LogNotCompilableAndReturn(node, "slow operation"); + } + + return true; +} + +RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( + const XlaOpRegistry::DeviceRegistration& registration) { + RecursiveCompilabilityChecker::OperationFilter op_filter; + op_filter.allow_resource_ops_in_called_functions = + registration.cluster_resource_variable_ops_unsafely; + op_filter.allow_stack_ops = registration.cluster_stack_ops; + op_filter.allow_tensor_array_ops = registration.cluster_tensor_array_ops; + op_filter.allow_stateful_rng_ops = registration.cluster_stateful_rng_ops; + op_filter.allow_control_trigger = registration.cluster_control_trigger; + op_filter.allow_eliding_assert_and_checknumerics_ops = + registration.elide_assert_and_checknumerics; + op_filter.allow_ops_producing_or_consuming_variant = + registration.cluster_variant_ops; + op_filter.allow_slow_and_inaccurate_ops = + registration.cluster_slow_and_inaccurate_ops; + return op_filter; +} + + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h new file mode 100644 index 00000000000..4be8050f7da --- /dev/null +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -0,0 +1,175 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/device_util.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { +// Checks whether a TF node can be compiled or not. "Recursive" as in for call +// and functional while nodes it recursively checks whether the callee functions +// can be compiled. +class RecursiveCompilabilityChecker { + public: + // Aggregates information about what kinds of ops are allowed. + struct OperationFilter { // TODO(lzr): Add AllowEverything() helper. + // Whether resource variable ops are allowed are allowed in callees. We do + // not allow resource variable ops in called functions (either as direct TF + // calls or as higher order control flow ops) because we do not yet model + // their memory effects in jit/resource_variable_safety_analysis. + bool allow_resource_ops_in_called_functions; + + // Whether Stack operations are allowed. We avoid auto-clustering Stack + // operations in general because we do not support snapshotting them. + // + // TODO(b/112837194): This restriction can be lifted with some work. + bool allow_stack_ops; + + // Whether TensorArray operations are allowed. We avoid auto-clustering + // TensorArray operations in general because we do not support snapshotting + // them. + // + // TODO(b/112837194): This restriction can be lifted with some work. 
+ bool allow_tensor_array_ops; + + // Whether stateful RNG ops are allowed. XLA's RNG does not have the same + // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid + // auto-clustering stateful RNG ops. + bool allow_stateful_rng_ops; + + // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound + // to cluster ControlTrigger because of how we use deadness analysis. + bool allow_control_trigger; + + // Whether it is okay to "cluster" Assert and CheckNumerics by simply + // removing them (they're not removed during clustering, but their + // XlaOpKernel is a no-op kernel). We avoid auto-clustering these ops so + // that the user is not surprised when XLA is implicitly enabled. If the + // user explicitly specifies to use XLA, it is fine to resort to a dummy + // implementation. Currently Assert and CheckNumerics ops have dummy XLA + // implementations. + bool allow_eliding_assert_and_checknumerics_ops; + + // Whether ops that produce or consume DT_VARIANT values are allowed. We + // don't auto-cluster these ops because we don't yet support live-in or + // live-out DT_VARIANT values. + bool allow_ops_producing_or_consuming_variant; + + // Whether ops known to be slow or to have correctness issues should be + // auto-clustered. + bool allow_slow_and_inaccurate_ops; + }; + + RecursiveCompilabilityChecker(const OperationFilter* op_filter, + const DeviceType* jit_device_type) + : op_filter_(*op_filter), jit_device_type_(*jit_device_type) {} + + // Returns true if `node` can be compiled by XLA. + bool IsCompilableNode(const Node& node, FunctionLibraryRuntime* lib_runtime) { + return IsCompilableNode(node, /*depth=*/0, lib_runtime); + } + + // Returns true if `call_def` can be compiled by XLA. It is assumed that + // `call_def` is a call operation. + bool IsCompilableCall(const NodeDef& call_def, + FunctionLibraryRuntime* lib_runtime) { + return IsCompilableCall(call_def, /*depth=*/0, lib_runtime); + } + + // Returns true if XLA supports this Op, but we don't want to cluster it (ie: + // due to performance or correctness concerns). 
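  //
  // A minimal usage sketch for this class (assuming the caller already holds
  // an XlaOpRegistry::DeviceRegistration `registration`, a DeviceType
  // `jit_device_type`, a Node `node` and a FunctionLibraryRuntime*
  // `lib_runtime` -- all placeholders here):
  //
  //   RecursiveCompilabilityChecker::OperationFilter filter =
  //       CreateOperationFilter(registration);
  //   RecursiveCompilabilityChecker checker(&filter, &jit_device_type);
  //   if (checker.IsCompilableNode(node, lib_runtime)) {
  //     // `node`, and any function it calls, can be compiled by XLA under
  //     // the constraints expressed by `filter`.
  //   }
  //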
+ bool OpIsInaccurate(const Node& node); + bool OpIsSlow(const Node& node); + + private: + bool IsCompilableNode(const Node& node, int depth, + FunctionLibraryRuntime* lib_runtime); + bool IsCompilableCall(const NodeDef& call_def, int depth, + FunctionLibraryRuntime* lib_runtime); + bool IsCompilableWhile(const Node& while_node, int depth, + FunctionLibraryRuntime* lib_runtime); + + bool IsStackOp(const Node& node) { + const XlaResourceOpInfo* op_info = + GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() == XlaResourceKind::kStack; + } + + bool IsTensorArrayOp(const Node& node) { + const XlaResourceOpInfo* op_info = + GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray; + } + + bool IsAssertOrCheckNumerics(absl::string_view op_name) { + return op_name == "Assert" || op_name == "CheckNumerics"; + } + + bool IsStatefulRandomOp(absl::string_view op_name) { + return op_name == "RandomUniform" || op_name == "RandomShuffle" || + op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || + op_name == "TruncatedNormal" || op_name == "Multinomial"; + } + + bool OpProducesOrConsumesVariant(const Node& node) { + auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; }; + return absl::c_any_of(node.input_types(), is_variant) || + absl::c_any_of(node.output_types(), is_variant); + } + + bool HasXLAKernel(const Node& node); + + // Make sure we don't recurse infinitely on recursive functions. + const int kMaxRecursionDepth = 10; + + const OperationFilter& op_filter_; + const DeviceType& jit_device_type_; +}; + +RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( + const XlaOpRegistry::DeviceRegistration& registration); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ diff --git a/tensorflow/compiler/jit/compilation_passes_test_main.cc b/tensorflow/compiler/jit/compilation_passes_test_main.cc index 4b5c26faeaf..c73702fa642 100644 --- a/tensorflow/compiler/jit/compilation_passes_test_main.cc +++ b/tensorflow/compiler/jit/compilation_passes_test_main.cc @@ -38,10 +38,13 @@ GTEST_API_ int main(int real_argc, char** real_argv) { void operator()(char* ptr) { free(ptr); } }; - std::unique_ptr allocated_arg( + std::unique_ptr enable_global_jit_arg( strdup("--tf_xla_cpu_global_jit=true")); + args.push_back(enable_global_jit_arg.get()); - args.push_back(allocated_arg.get()); + std::unique_ptr reduce_min_cluster_size_arg( + strdup("--tf_xla_min_cluster_size=2")); + args.push_back(reduce_min_cluster_size_arg.get()); int argc = args.size(); diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 4397eea9af2..d2501b9ef1e 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -106,6 +106,8 @@ namespace tensorflow { namespace { +using se::port::StatusOr; + // Represents a logical predicate, used as described in the algorithm overview // above. 
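// For a concrete feel of the algebra: in the simple case the two outputs of a
// Switch driven by a boolean tensor p get the predicates `p` (true branch)
// and `~p` (false branch), and a Merge of those two branches gets `p | ~p`,
// which simplifies to `#true`. The predicate_map expectations in
// deadness_analysis_test.cc (e.g. "*cond_ref:0" / "~*cond_ref:0") spell this
// out.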
class Predicate { @@ -369,7 +371,8 @@ class PredicateFactory { Predicate** predicate) { TensorId tensor_id(node->name(), output_idx); - bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL; + bool is_boolean_tensor = + BaseType(node->output_type(tensor_id.index())) == DT_BOOL; TF_RET_CHECK(!must_be_true || is_boolean_tensor); if (node->type_string() == "Const" && must_be_true) { @@ -698,7 +701,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status Populate(); Status PopulateWithReversePostOrder(absl::Span rpo); - bool HasInputsWithMismatchingDeadness(const Node& node) override; + StatusOr GetPredicateFor( + Node* n, int oidx) const override; void Print() const override; absl::flat_hash_map PredicateMapAsString() const; @@ -768,7 +772,8 @@ Status DeadnessAnalysisImpl::GetInputPreds( auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); if (it == predicate_map_.end()) { GraphCycles graph_cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + TF_RETURN_IF_ERROR( + CreateCycleDetectionGraph(&graph_, &graph_cycles).status()); // If we didn't return with an error above then the graph is probably // fine and we have a bug in deadness analysis. @@ -1112,42 +1117,13 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( return Status::OK(); } -bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) { - CHECK(!node.IsMerge()); - - if (vlog_) { - VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")"; - } - - Predicate* pred = nullptr; - for (const Edge* edge : node.in_edges()) { - auto it = predicate_map_.find(InputEdgeToTensorId(edge)); - CHECK(it != predicate_map_.end()); - if (vlog_) { - VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": " - << it->second->ToString(); - } - - // Today we just compare the predicates for equality (with some - // canonicalization/simplification happening before) but we could be more - // sophisticated here if need be. Comparing pointers is sufficient because - // we intern Predicate instances by their content. - if (pred != nullptr && pred != it->second) { - if (vlog_) { - VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() - << ") -> true"; - } - return true; - } - pred = it->second; - } - - if (vlog_) { - VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() - << ") -> false"; - } - - return false; +StatusOr +DeadnessAnalysisImpl::GetPredicateFor(Node* n, int oidx) const { + auto it = predicate_map_.find(TensorId(n->name(), oidx)); + TF_RET_CHECK(it != predicate_map_.end()) + << "could not find " << TensorId(n->name(), oidx).ToString() + << " in predicate map"; + return MakeDeadnessPredicate(it->second); } void DeadnessAnalysisImpl::Print() const { @@ -1212,4 +1188,8 @@ Status ComputePredicates(const Graph& graph, } } // namespace deadness_analysis_internal +string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const { + return static_cast(predicate.pred_)->ToString(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h index 6e7ab411619..c8527de503d 100644 --- a/tensorflow/compiler/jit/deadness_analysis.h +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ #include "tensorflow/core/graph/graph.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -43,24 +44,55 @@ namespace tensorflow { // "liveness" already has other connotations. class DeadnessAnalysis { public: - // Returns true if `node` may have some live inputs and some dead inputs. - // - // This is a conservatively correct routine -- if it returns false then `node` - // is guaranteed to not have inputs with mismatching liveness, but not the - // converse. - // - // REQUIRES: node is not a Merge operation. - virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0; + // An opaque representation of a predicate. DeadnessPredicate + // instances that compare equal via operator== represent predicates + // that always evaluate to the same value. + struct DeadnessPredicate { + public: + DeadnessPredicate(const DeadnessPredicate&) = default; + DeadnessPredicate(DeadnessPredicate&&) = default; + + DeadnessPredicate& operator=(const DeadnessPredicate&) = default; + DeadnessPredicate& operator=(DeadnessPredicate&&) = default; + + bool operator==(const DeadnessPredicate& other) const { + return other.pred_ == pred_; + } + + bool operator!=(const DeadnessPredicate& other) const { + return other.pred_ != pred_; + } + + private: + explicit DeadnessPredicate(void* pred) : pred_(pred) {} + + // This is really a Predicate*, but we don't want to expose that + // implementation detail to our clients. `pred_` has pointer equality so we + // can just compare the pointer in operator== and operator!=. + void* pred_; + + friend class DeadnessAnalysis; + }; + + virtual se::port::StatusOr GetPredicateFor( + Node* n, int oidx) const = 0; // Prints out the internal state of this instance. For debugging purposes // only. virtual void Print() const = 0; virtual ~DeadnessAnalysis(); + string DebugString(DeadnessPredicate predicate) const; + // Run the deadness analysis over `graph` and returns an error or a populated // instance of DeadnessAnalysis in `result`. static Status Run(const Graph& graph, std::unique_ptr* result); + + protected: + static DeadnessPredicate MakeDeadnessPredicate(void* pred) { + return DeadnessPredicate(pred); + } }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 38a5118d9a7..3a44eb7db75 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -37,6 +37,22 @@ limitations under the License. 
namespace tensorflow { namespace { +se::port::StatusOr HasInputsWithMismatchingDeadness( + const DeadnessAnalysis& deadness_analysis, const Node& n) { + absl::optional pred; + for (const Edge* edge : n.in_edges()) { + TF_ASSIGN_OR_RETURN( + DeadnessAnalysis::DeadnessPredicate this_pred, + deadness_analysis.GetPredicateFor(edge->src(), edge->src_output())); + if (pred && *pred != this_pred) { + return true; + } + pred = this_pred; + } + + return false; +} + using deadness_analysis_internal::ComputePredicates; using deadness_analysis_internal::PredicateMapTy; @@ -219,7 +235,10 @@ TEST(DeadnessAnalysisTest, BasicPositive) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, BasicNegative) { @@ -232,7 +251,10 @@ TEST(DeadnessAnalysisTest, BasicNegative) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, AndIsCommutative) { @@ -260,11 +282,27 @@ TEST(DeadnessAnalysisTest, AndIsCommutative) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node())); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node())); + bool has_inputs_with_mismatching_deadness; - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node())); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node())); + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *live0.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); + + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *live1.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); + + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *halfdead0.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); + + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *halfdead1.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, AndIsAssociative) { @@ -287,7 +325,10 @@ TEST(DeadnessAnalysisTest, AndIsAssociative) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, OrIsCommutative) { @@ -312,11 +353,27 @@ TEST(DeadnessAnalysisTest, OrIsCommutative) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node())); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node())); + bool has_inputs_with_mismatching_deadness; - 
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node())); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node())); + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *live0.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); + + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *live1.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); + + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *halfdead0.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); + + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *halfdead1.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, OrIsAssociative) { @@ -336,7 +393,10 @@ TEST(DeadnessAnalysisTest, OrIsAssociative) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, AndOfOr) { @@ -358,7 +418,10 @@ TEST(DeadnessAnalysisTest, AndOfOr) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add2.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, OrOfAnd) { @@ -382,7 +445,10 @@ TEST(DeadnessAnalysisTest, OrOfAnd) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add2.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) { @@ -430,7 +496,10 @@ TEST(DeadnessAnalysisTest, AndOrDistributive) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add3.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, Ternary) { @@ -454,7 +523,10 @@ TEST(DeadnessAnalysisTest, Ternary) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, Recv) { @@ -469,7 +541,10 @@ TEST(DeadnessAnalysisTest, Recv) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } 
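// The new predicate API composes like this (n and m are placeholder Node*
// values): callers fetch predicates and compare or print them, much as the
// HasInputsWithMismatchingDeadness helper above does.
//
//   TF_ASSERT_OK_AND_ASSIGN(DeadnessAnalysis::DeadnessPredicate p,
//                           result->GetPredicateFor(n, /*oidx=*/0));
//   TF_ASSERT_OK_AND_ASSIGN(DeadnessAnalysis::DeadnessPredicate q,
//                           result->GetPredicateFor(m, /*oidx=*/0));
//   if (p != q) { /* the two outputs may differ in deadness */ }
//   VLOG(2) << result->DebugString(p);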
TEST(DeadnessAnalysisTest, HostRecv) { @@ -484,7 +559,10 @@ TEST(DeadnessAnalysisTest, HostRecv) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, Loop) { @@ -505,8 +583,17 @@ TEST(DeadnessAnalysisTest, Loop) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); + bool has_inputs_with_mismatching_deadness; + + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add0.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); + + TF_ASSERT_OK_AND_ASSIGN( + has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add1.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } { PredicateMapTy predicate_map; @@ -544,7 +631,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add0.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } { PredicateMapTy predicate_map; @@ -634,7 +724,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add0.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } { PredicateMapTy predicate_map; @@ -693,7 +786,10 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add0.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } { @@ -792,7 +888,10 @@ TEST(DeadnessAnalysisTest, ControlInputs) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, ControlTrigger) { @@ -819,7 +918,10 @@ TEST(DeadnessAnalysisTest, ControlTrigger) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, ControlInputsToMerge) { @@ -840,7 +942,10 @@ TEST(DeadnessAnalysisTest, ControlInputsToMerge) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), 
&result)); - EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add.node())); + EXPECT_FALSE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, RecvVsSwitch) { @@ -857,7 +962,10 @@ TEST(DeadnessAnalysisTest, RecvVsSwitch) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node())); + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *logical_and.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); } TEST(DeadnessAnalysisTest, RecvVsSwitchText) { @@ -959,5 +1067,25 @@ TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) { EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false"); } +TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output condition_ref_var = + ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, condition_ref_var); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "~*cond_ref:0"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "*cond_ref:0"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_util.cc b/tensorflow/compiler/jit/device_util.cc new file mode 100644 index 00000000000..200e795a2e8 --- /dev/null +++ b/tensorflow/compiler/jit/device_util.cc @@ -0,0 +1,206 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/device_util.h" + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { +namespace jit { +using xla::StatusOr; + +void DeviceSet::Insert(DeviceId device_id) { + int word_index = device_id.id() / kWordSize; + int bit_index = device_id.id() % kWordSize; + + if (word_index >= storage_.size()) { + storage_.resize(word_index + 1, 0); + } + + storage_[word_index] |= (1ull << bit_index); +} + +void DeviceSet::UnionWith(const DeviceSet& other) { + if (other.storage_.size() > storage_.size()) { + storage_.resize(other.storage_.size(), 0); + } + + for (int i = 0; i < other.storage_.size(); i++) { + storage_[i] |= other.storage_[i]; + } +} + +bool DeviceSet::IsEmpty() const { + return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; }); +} + +xla::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { + TF_RET_CHECK(!name.empty()); + + auto it = name_to_id_.find(name); + if (it != name_to_id_.end()) { + return it->second; + } + + int new_id = names_.size(); + names_.push_back(string(name)); + id_to_device_type_.push_back(absl::make_unique("")); + DeviceType* device_type = id_to_device_type_.back().get(); + TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type)); + + is_cpu_.push_back(device_type->type_string() == DEVICE_CPU); + is_gpu_.push_back(device_type->type_string() == DEVICE_GPU); + + name_to_id_.emplace(string(name), DeviceId(new_id)); + + const XlaOpRegistry::DeviceRegistration* compilation_device; + if (!XlaOpRegistry::GetCompilationDevice(device_type->type(), + &compilation_device)) { + compilation_device = nullptr; + } + id_to_compilation_device_.push_back(compilation_device); + + return DeviceId(new_id); +} + +string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { + std::vector names; + device_set.ForEach([&](DeviceId device_id) { + names.push_back(string(GetNameFor(device_id))); + return false; + }); + + return absl::StrCat("[", absl::StrJoin(names, ","), "]"); +} +} // namespace jit + +Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device, &parsed)) { + return errors::Internal("Malformed assigned device '", device, "'"); + } + *device_type = DeviceType(parsed.type); + return Status::OK(); +} + +xla::StatusOr> PickDeviceForXlaImpl( + const jit::DeviceInfoCache& device_info_cache, + const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu, + bool failure_to_pick_is_error) { +#define FAILED_TO_PICK_DEVICE(failing_status) \ + do { \ + if (failure_to_pick_is_error) { \ + return failing_status; \ + } else { \ + return {absl::nullopt}; \ + } \ + } while (false) + + absl::optional maybe_gpu_device; + absl::optional maybe_cpu_device; + absl::optional maybe_unknown_device; + + bool multiple_cpu_devices = false; + bool multiple_gpu_devices = false; + bool multiple_unknown_devices = false; + + devices.ForEach([&](jit::DeviceId device) { + if (device_info_cache.IsGpu(device)) { + if (maybe_gpu_device) { + multiple_gpu_devices = true; + return false; + } + maybe_gpu_device = device; + } else if (device_info_cache.IsCpu(device)) { + if (maybe_cpu_device) { + multiple_cpu_devices = true; + return false; + } + maybe_cpu_device = device; + } else { + if (maybe_unknown_device) { + multiple_unknown_devices = true; + return 
false; + } + maybe_unknown_device = device; + } + + return true; + }); + + if (multiple_cpu_devices) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Multiple CPU devices ", device_info_cache.DebugString(devices))); + } + + if (multiple_gpu_devices) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Multiple GPU devices ", device_info_cache.DebugString(devices))); + } + + if (multiple_unknown_devices) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Multiple unknown devices ", device_info_cache.DebugString(devices))); + } + + if (maybe_unknown_device && maybe_gpu_device) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Found both unknown and GPU devices: ", + device_info_cache.GetNameFor(*maybe_unknown_device), ", ", + device_info_cache.GetNameFor(*maybe_gpu_device))); + } + + if (!allow_mixing_unknown_and_cpu) { + if (maybe_unknown_device && maybe_cpu_device) { + FAILED_TO_PICK_DEVICE(errors::Internal( + "Found both unknown and CPU devices: ", + device_info_cache.GetNameFor(*maybe_unknown_device), ", ", + device_info_cache.GetNameFor(*maybe_cpu_device))); + } + } + + if (maybe_gpu_device) { + return {*maybe_gpu_device}; + } else if (maybe_unknown_device) { + return {*maybe_unknown_device}; + } else if (maybe_cpu_device) { + return {*maybe_cpu_device}; + } + + FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!")); + +#undef FAILED_TO_PICK_DEVICE +} + +xla::StatusOr PickDeviceForXla( + const jit::DeviceInfoCache& device_info_cache, + const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) { + TF_ASSIGN_OR_RETURN(absl::optional device_id, + PickDeviceForXlaImpl(device_info_cache, devices, + allow_mixing_unknown_and_cpu, + /*failure_to_pick_is_error=*/true)); + return *device_id; +} + +xla::StatusOr> MaybePickDeviceForXla( + const jit::DeviceInfoCache& device_info_cache, + const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) { + return PickDeviceForXlaImpl(device_info_cache, devices, + allow_mixing_unknown_and_cpu, + /*failure_to_pick_is_error=*/false); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h new file mode 100644 index 00000000000..f26a565ff12 --- /dev/null +++ b/tensorflow/compiler/jit/device_util.h @@ -0,0 +1,211 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace jit { +// Instances of DeviceId represent TensorFlow devices as integers. 
+// +// This helps avoid having to manipulate device names as strings when +// auto-clustering. +class DeviceId { + public: + DeviceId(DeviceId&&) = default; + DeviceId(const DeviceId&) = default; + DeviceId& operator=(const DeviceId&) = default; + + bool operator==(const DeviceId& other) const { return id() == other.id(); } + bool operator!=(const DeviceId& other) const { return !(*this == other); } + + private: + int id_; + + explicit DeviceId(int id) : id_(id) {} + + int id() const { return id_; } + + friend class DeviceInfoCache; + friend class DeviceSet; +}; + +// A set of DeviceIds, represented as a bitmap. +class DeviceSet { + public: + void Insert(DeviceId device_id); + void UnionWith(const DeviceSet& other); + bool IsEmpty() const; + + // Calls `func` on each DeviceId in the set. Stops iterating early if `func` + // return false. + // + // TODO(sanjoy): Change this to take a typed std::function if that's + // performance neutral. + template + void ForEach(FnTy func) const { + // This is really a poor man's iterator, we should consider writing a proper + // iterator if this ends up being used widely. + for (int word_index = 0; word_index < storage_.size(); word_index++) { + uint64 word = storage_[word_index]; + while (word != 0) { + uint64 only_lowest_bit_set = word & -word; + // The number of trailing zeros in a non-zero word is the index of the + // least significant 1. + int bit_index = ctz_uint64(word); + if (!func(DeviceId(word_index * kWordSize + bit_index))) { + return; + } + word ^= only_lowest_bit_set; + } + } + } + + private: + static int ctz_uint64(uint64 x) { + DCHECK_NE(x, 0); +#ifdef __GNUC__ + return __builtin_ctzl(x); +#else + int result = 0u; + while ((x & 1u) == 0u) { + x >>= 1; + ++result; + } + return result; +#endif + } + + absl::InlinedVector storage_; + + const int kWordSize = 64; +}; + +// Caches some miscellaneous information about TF devices. Thread compatible. +class DeviceInfoCache { + public: + bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; } + bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; } + + absl::string_view GetNameFor(DeviceId device) const { + return names_[device.id()]; + } + + xla::StatusOr GetIdFor(absl::string_view name); + + using DeviceRegistration = const XlaOpRegistry::DeviceRegistration; + + DeviceRegistration* GetCompilationDevice(DeviceId device) const { + return id_to_compilation_device_[device.id()]; + } + + xla::StatusOr GetCompilationDevice( + absl::string_view name) { + TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name)); + return GetCompilationDevice(device_id); + } + + const DeviceType& GetDeviceTypeFor(DeviceId device) const { + return *id_to_device_type_[device.id()]; + } + + using DeviceTypeConstRef = std::reference_wrapper; + + xla::StatusOr GetDeviceTypeFor( + absl::string_view device_name) { + TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name)); + return std::cref(*id_to_device_type_[device_id.id()]); + } + + string DebugString(const DeviceSet& device_set) const; + + private: + absl::flat_hash_map name_to_id_; + + // These fields are populated for a device in GetIdFor, *before* we give out a + // DeviceId. + std::vector + id_to_compilation_device_; + std::vector> id_to_device_type_; + std::vector names_; + std::vector is_cpu_; + std::vector is_gpu_; +}; + +} // namespace jit + +// Returns the DeviceType corresponding to 'device'. 
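// A short usage sketch for the jit::DeviceSet / jit::DeviceInfoCache pair
// above (the device name is only an example, and the surrounding function is
// assumed to return a Status):
//
//   jit::DeviceInfoCache cache;
//   jit::DeviceSet devices;
//   TF_ASSIGN_OR_RETURN(
//       jit::DeviceId gpu0,
//       cache.GetIdFor("/job:localhost/replica:0/task:0/device:GPU:0"));
//   devices.Insert(gpu0);
//   devices.ForEach([&](jit::DeviceId id) {
//     VLOG(1) << cache.GetNameFor(id);
//     return true;  // keep iterating
//   });
//
// DeviceSet stores ids as bits: id 70 occupies word 1, bit 6, and ForEach
// reconstructs it as 1 * 64 + 6 using the count-trailing-zeros loop above.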
+Status DeviceNameToDeviceType(const string& device, DeviceType* device_type); + +// Picks the device for which XLA should compile a cluster that contains +// operations placed in devices in `devices`. For instance a cluster that +// contains operations solely placed on the CPU will be compiled into a CPU +// executable by XLA, whereas a cluster that contains operations placed on the +// CPU and also operations placed on the GPU will be compiled into a GPU +// executable. +// +// Returns a non-OK Status if no unambiguous choice of device exists. +// +// We choose the device using the following rules: +// +// - It is an error for `device_names` to contain more than one device of the +// same type. +// - GPU is preferred over CPU. +// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are +// preferred over CPU. +// - XLA devices count as "unrecognized devices". +// +// This set of rules above implicitly assume that XLA:GPU can compile all +// operations in the cluster that XLA:CPU can compile, and if +// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile +// all operations in the cluster that XLA:CPU can compile. +// +// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of +// the following things: +// +// - Let MarkForCompilationPass not inject CPU-placed operations into clusters +// that will run on unknown devices (because the unknown XLA backend may not +// support every operation supported by CPU). +// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster +// that contains nodes placed on both the CPU and on unknown devices. In this +// case it is the responsibility of the optimization pass that injected the +// CPU nodes into the cluster to ensure that these nodes can be compiled by +// the unknown XLA backend. +xla::StatusOr PickDeviceForXla( + const jit::DeviceInfoCache& device_info_cache, + const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); + +// This is like `PickDeviceForXla` except that it returns nullopt (instead of a +// non-OK Status) if no unambiguous choice of device exists. +// +// We return a failing Status for errors unrelated to the device choice +// algorithm itself. +xla::StatusOr> MaybePickDeviceForXla( + const jit::DeviceInfoCache& device_info_cache, + const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_ diff --git a/tensorflow/compiler/jit/device_util_test.cc b/tensorflow/compiler/jit/device_util_test.cc new file mode 100644 index 00000000000..9396c49d52e --- /dev/null +++ b/tensorflow/compiler/jit/device_util_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/device_util.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu, + absl::Span device_names, + string* result) { + jit::DeviceInfoCache cache; + jit::DeviceSet device_set; + for (absl::string_view name : device_names) { + TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name)); + device_set.Insert(device_id); + } + + TF_ASSIGN_OR_RETURN( + jit::DeviceId result_id, + PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu)); + *result = string(cache.GetNameFor(result_id)); + return Status::OK(); +} + +void CheckPickDeviceResult(absl::string_view expected_result, + bool allow_mixing_unknown_and_cpu, + absl::Span inputs) { + string result; + TF_ASSERT_OK(PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result)) + << "inputs = [" << absl::StrJoin(inputs, ", ") + << "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu + << ", expected_result=" << expected_result; + EXPECT_EQ(result, expected_result); +} + +void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu, + absl::Span inputs) { + string result; + EXPECT_FALSE( + PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result).ok()); +} + +const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0"; +const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0"; +const char* kXPU0 = "/job:localhost/replica:0/task:0/device:XPU:0"; +const char* kYPU0 = "/job:localhost/replica:0/task:0/device:YPU:0"; + +const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1"; +const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1"; +const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1"; + +TEST(PickDeviceForXla, UniqueDevice) { + CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0}); +} + +TEST(PickDeviceForXla, DeviceOrder) { + CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0}); + CheckPickDeviceResult(kGPU0, false, {kCPU0, kGPU0}); + CheckPickDeviceResult(kXPU0, true, {kXPU0, kCPU0}); +} + +TEST(PickDeviceForXla, MultipleUnknownDevices) { + CheckPickDeviceHasError(false, {kXPU0, kYPU0}); +} + +TEST(PickDeviceForXla, GpuAndUnknown) { + CheckPickDeviceHasError(false, {kGPU0, kXPU1}); +} + +TEST(PickDeviceForXla, UnknownAndCpu) { + CheckPickDeviceHasError(false, {kXPU0, kCPU1}); +} + +TEST(PickDeviceForXla, MultipleDevicesOfSameType) { + CheckPickDeviceHasError(true, {kCPU0, kCPU1}); + CheckPickDeviceHasError(false, {kCPU0, kCPU1}); + CheckPickDeviceHasError(false, {kGPU0, kGPU1}); + CheckPickDeviceHasError(false, {kXPU0, kXPU1}); + CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0}); +} + +void SimpleRoundTripTestForDeviceSet(int num_devices) { + jit::DeviceSet device_set; + jit::DeviceInfoCache device_info_cache; + + std::vector expected_devices, actual_devices; + + for (int i = 0; i < num_devices; i++) { + string device_name = + absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i); + TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id, + device_info_cache.GetIdFor(device_name)); + device_set.Insert(device_id); + expected_devices.push_back(device_name); + } + + device_set.ForEach([&](jit::DeviceId device_id) { + actual_devices.push_back(string(device_info_cache.GetNameFor(device_id))); + return true; + }); + + EXPECT_EQ(expected_devices, actual_devices); +} + +TEST(DeviceSetTest, 
SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); } + +TEST(DeviceSetTest, SimpleRoundTrip_Small) { + SimpleRoundTripTestForDeviceSet(8); +} + +TEST(DeviceSetTest, SimpleRoundTrip_Large) { + SimpleRoundTripTestForDeviceSet(800); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 8b6bffa267d..b6d97434eb0 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -25,12 +25,13 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/shape_refiner.h" @@ -50,6 +51,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -108,14 +110,14 @@ void MarkGuaranteedConstants( for (const auto& src_arg : src_arg_pairs) { srcs.push_back(src_arg.first); } - ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, - /*leave=*/[&guaranteed_const_nodes](const Node* n) { - // TODO(vinuraja): Doesn't work in the presence of loops. - if (AreAllParentsGuaranteedConst(*n, - guaranteed_const_nodes)) { - guaranteed_const_nodes.insert(n); - } - }); + ReverseDFSFrom( + graph, srcs, /*enter=*/nullptr, + /*leave=*/[&guaranteed_const_nodes](const Node* n) { + // TODO(vinuraja): Doesn't work in the presence of loops. + if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) { + guaranteed_const_nodes.insert(n); + } + }); for (auto& src_arg : src_arg_pairs) { if (guaranteed_const_nodes.count(src_arg.first) != 0) { @@ -307,6 +309,13 @@ class Encapsulator { const std::unordered_map& node_images, std::vector>* src_arg_pairs); + // Records the src of the given edge as a control result of the graph. + // Used during graph to function conversion to tie control results to + // the function signature. + Status RecordControlResult( + const Edge* edge, + const std::unordered_map& node_images); + // Creates a _Retval node for the src node of edge, and add it to results_, // if none exists yet. If a new _Retval node is created, also adds the edge // within the subgraph from the src to the _Retval node. @@ -484,6 +493,11 @@ class Encapsulator { // Map from source tensor in the input graph to result #. std::unordered_map results_; + // Set of node names that are the source of a control output of the + // subgraph. We store strings here so that we can tolerate nodes being + // removed from the graph. + absl::flat_hash_set control_output_nodes_; + // The outside_compilation clusters in this subgraph. 
std::unordered_map outside_compilation_subgraphs_; @@ -801,6 +815,15 @@ Status Encapsulator::Subgraph::RecordArg( return Status::OK(); } +Status Encapsulator::Subgraph::RecordControlResult( + const Edge* edge, + const std::unordered_map& node_images) { + Node* src_node = edge->src(); + Node* src_image = node_images.at(src_node); + control_output_nodes_.insert(src_image->name()); + return Status::OK(); +} + Status Encapsulator::Subgraph::RecordResult( const Edge* edge, const std::unordered_map& node_images) { @@ -1117,17 +1140,22 @@ Status Encapsulator::Subgraph::BuildFunctionDef( function_def_name_ = name; FunctionDef fdef; + auto lookup = [this](const Node* node) -> absl::optional { + if (control_output_nodes_.contains(node->name())) { + return absl::make_optional(node->name()); + } + return absl::nullopt; + }; // Verify that the graph has well-formed control flow structure. std::vector dummy; TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy)); - TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); + TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef)); if (VLOG_IS_ON(1)) { VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), - *graph_, library); - dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), - fdef); + DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_, + library); + DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef); } const FunctionDef* original_fdef = library->Find(name); @@ -1190,11 +1218,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Replace function def " << name; - dump_graph::DumpGraphToFile( - absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, - library); - dump_graph::DumpFunctionDefToFile( - absl::StrCat("replace_encapsulate_fdef_", name), fdef); + DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name), + *graph_, library); + DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name), + fdef); } TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); @@ -1479,9 +1506,10 @@ Status Encapsulator::CopySubgraphEdges( src_subgraph.RecordOutsideCompilationInputOrControl( dst_outside_compilation_id, edge); } else { - // Ignore control edges leaving the subgraph. We will lift them onto the - // enclosing call operators in BuildOutputGraph(). - if (!edge->IsControlEdge()) { + if (edge->IsControlEdge()) { + TF_RETURN_IF_ERROR( + src_subgraph.RecordControlResult(edge, node_images)); + } else { TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); } } @@ -1556,7 +1584,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { if (VLOG_IS_ON(1)) { // Dump subgraphs. 
for (auto& entry : subgraphs_) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first), *entry.second.GetGraph(), library); } @@ -2320,16 +2348,15 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( return errors::Internal("Failed to find function ", node->type_string(), " in function library."); } - FunctionBody* fbody = nullptr; + std::unique_ptr fbody; TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fdef, node->attrs(), library, - [library](const string& op, const OpDef** sig) { - return library->LookUpOpDef(op, sig); - }, - &fbody)); - TF_RETURN_IF_ERROR( - InlineFunctionBody(*library, pruned_graph->get(), node, fbody)); - delete fbody; + FunctionDefToBodyHelper(*fdef, node->attrs(), library, &fbody)); + + InlineFunctionBodyOptions inline_opts; + inline_opts.override_device = false; + + TF_RETURN_IF_ERROR(InlineFunctionBody(*library, pruned_graph->get(), node, + fbody.get(), inline_opts)); } return Status::OK(); @@ -2394,8 +2421,7 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( &node_images, library)); if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference", - *pruned_graph, library); + DumpGraphToFile("pruned_graph_for_shape_inference", *pruned_graph, library); } for (auto& subgraph_entry : subgraphs_) { @@ -2471,8 +2497,6 @@ Status EncapsulateSubgraphsInFunctions( const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { - Status s; - Encapsulator encapsulator(std::move(group_attribute), std::move(outside_compilation_attribute), &graph_in); @@ -2526,19 +2550,49 @@ Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph, - options.flib_def); + DumpGraphToFile("encapsulate_subgraphs_before", **options.graph, + options.flib_def); } std::unique_ptr graph_out; FunctionLibraryDefinition* const library = options.flib_def; + // Constant folding below might need to run part of the function to compute + // constants. Create an FunctionLibraryRuntime with a single CPU device + // that can run the part of the function. + // NOTE: If this turns out to be slow, we can cache the FLRs keyed by + // `options`. + SessionOptions session_options; + auto* device_count = session_options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + std::vector> devices; + + DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU"); + if (!cpu_factory) { + return errors::NotFound( + "CPU Factory not registered. 
Can't run EncapsulateSubgraphsPass"); + } + TF_RETURN_IF_ERROR(cpu_factory->CreateDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + if (devices.empty()) { + return errors::NotFound( + "Failed to create a CPU device for EncapsulateSubgraphsPass"); + } + + std::unique_ptr device_mgr = + absl::make_unique(std::move(devices)); OptimizerOptions opts; std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env, + new ProcessFunctionLibraryRuntime(device_mgr.get(), + options.session_options->env, TF_GRAPH_DEF_VERSION, library, opts)); FunctionLibraryRuntime* flr = - pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0"); + if (flr == nullptr) { + return errors::Internal( + "Failed to create and retrieve function library runtime to run " + "constant folding"); + } auto rewrite_subgraph = [flr](const std::vector& arg_source_tensors, @@ -2637,8 +2691,8 @@ Status EncapsulateSubgraphsPass::Run( "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out, - options.flib_def); + DumpGraphToFile("encapsulate_subgraphs_after", *graph_out, + options.flib_def); } *options.graph = std::move(graph_out); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 261519de347..22a12a540ce 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -537,8 +537,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, XlaClusterInfo{func, func_name_attrs, xla_computation_node, std::map{}}); } + bool modified; s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters, - graph_out.get(), flr, lib_def.get()); + graph_out.get(), flr, lib_def.get(), &modified); if (!s.ok()) return s; GraphDef graphdef_out; @@ -1105,8 +1106,10 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}}, - {"F"}}, + absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}}, + {"F", "outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, @@ -1985,7 +1988,10 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}}}, + absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}}, + {"outside_compilation_O1_host_compute"}}, }, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); @@ -2110,7 +2116,10 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}}}, + absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}}, + {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2258,8 +2267,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}}, - {}}, + absl::Span( + {"_xla_token_arg_node", 
"outside_compilation_O1_host_compute"})}}, + {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2271,8 +2281,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}}, - {}}}, + absl::Span({"_xla_token_arg_node", + "outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}}, + {"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"}}}, {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, {"h_0_retval_retval", "H:o:0"}}); diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 2264806d6bd..ae0912c3f23 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -14,9 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/encapsulate_util.h" + #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/shape_inference.h" @@ -24,6 +27,9 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +using stream_executor::port::StatusOr; namespace tensorflow { @@ -333,6 +339,43 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { return Status::OK(); } +StatusOr>>> +OutsideCompilationClusterDependencies( + const Graph* g, const string& outside_compilation_attr_name) { + auto cluster_deps = absl::make_unique< + absl::flat_hash_map>>(); + + for (const Edge* e : g->edges()) { + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation && + *src_outside_compilation != *dst_outside_compilation) { + auto dst_deps_it = cluster_deps->find(*dst_outside_compilation); + if (dst_deps_it == cluster_deps->end()) { + cluster_deps->insert(std::make_pair( + *dst_outside_compilation, + absl::flat_hash_set({*src_outside_compilation}))); + } else { + dst_deps_it->second.insert(*src_outside_compilation); + } + } + } + + auto cluster_deps_ordered = + absl::make_unique>>(); + + for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) { + std::vector ordered_deps(it->second.begin(), it->second.end()); + std::sort(ordered_deps.begin(), ordered_deps.end()); + cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps)); + } + + return std::move(cluster_deps_ordered); +} + Status PreprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index c9f16d14168..c873c2a888c 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -19,7 +19,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -89,6 +91,15 @@ struct XlaClusterInfo { const std::map host_compute_core; }; +// Finds dependencies between outside compilation clusters, including both data +// dependencies and control dependencies. cluster_deps maps the name name of an +// outside compilation cluster to a set of names of outside compilation clusters +// that it depends on. +stream_executor::port::StatusOr< + std::unique_ptr>>> +OutsideCompilationClusterDependencies( + const Graph* g, const string& outside_compilation_attr_name); + // Preprocesses edges within the same XLA cluster. It will perform the following // operations in order: // diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index f0c9d573451..4e65971191a 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/types.h" @@ -30,6 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -372,8 +372,8 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, Status EncapsulateXlaComputationsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateXlaComputations(): " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", - **options.graph, options.flib_def); + << DumpGraphToFile("encapsulate_xla_computations_before", + **options.graph, options.flib_def); const char* additional_help = IsCpuGpuCompile(options.graph->get()) @@ -383,14 +383,14 @@ Status EncapsulateXlaComputationsPass::Run( TF_RETURN_WITH_CONTEXT_IF_ERROR(Encapsulate(options.graph, options.flib_def), additional_help); VLOG(1) << "EncapsulateXlaComputations() half-way: " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", - **options.graph, options.flib_def); + << DumpGraphToFile("encapsulate_xla_computations_halfway", + **options.graph, options.flib_def); TF_RETURN_WITH_CONTEXT_IF_ERROR(BuildXlaLaunchOps(options.graph->get()), additional_help); VLOG(1) << "EncapsulateXlaComputations() finished: " - << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", - **options.graph, options.flib_def); + << DumpGraphToFile("encapsulate_xla_computations_after", + **options.graph, options.flib_def); return Status::OK(); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 2a770c527b2..1df4d8b5d44 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -15,13 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_util.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -287,15 +289,20 @@ absl::optional> GetInferredInputShapes( return results; } +string host_compute_node_name(const string& original_oc_name) { + return absl::StrCat("outside_compilation_", original_oc_name, + "_host_compute"); +} + // Builds XlaHostCompute NodeDef from the outside compilation call node. xla::StatusOr BuildXlaHostComputeNodeDef( - const Node* call_node, const std::map& host_compute_core) { + const Node* call_node, const std::map& host_compute_core, + const absl::flat_hash_map>& cluster_deps) { string original_oc_name; TF_RETURN_IF_ERROR(GetNodeAttr( call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name)); - NodeDefBuilder host_compute_builder( - absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"), - "XlaHostCompute"); + NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name), + "XlaHostCompute"); // Copy all attributes. for (auto attr : call_node->attrs()) { @@ -309,9 +316,25 @@ xla::StatusOr BuildXlaHostComputeNodeDef( host_compute_builder.Attr("tpu_core", core); } - // Set input tokens. - host_compute_builder.Attr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + // Set input tokens and other outside compilation clusters that current + // cluster depends in `kXlaTokenArgNodeName`. This is needed because when + // outside compilation subgraphs are encapsulated and moved to host graph, + // control/data edges between them will only be reflected in host graph. + // From XLA's perspective, two originally dependent clusters are no longer + // connected, which makes them look like they can be scheduled for execution + // in arbitrary order even though in fact they must be executed in order + // according to their host-side graph dependency. This can cause deadlock. + // Therefore, we hint XLA what the correct ordering of these clusters should + // be to avoid deadlocks. + std::vector xla_token_input_nodes; + xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName); + auto cluster_deps_it = cluster_deps.find(original_oc_name); + if (cluster_deps_it != cluster_deps.end()) { + for (auto dep : cluster_deps_it->second) { + xla_token_input_nodes.emplace_back(host_compute_node_name(dep)); + } + } + host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes); // Populate inputs. std::vector input_dtypes; @@ -370,8 +393,9 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) { // Replace outside compilation function call node with XlaHostCompute node. // If the function call node has no input/output edges, we will just remove it // and not create a XlaHostCompute node. 
-Status ReplaceOrRemoveOutsideCompilationCallNode( - Graph* g, Node* call_node, const std::map& host_compute_core) { +xla::StatusOr ReplaceOrRemoveOutsideCompilationCallNode( + Graph* g, Node* call_node, const std::map& host_compute_core, + const absl::flat_hash_map>& cluster_deps) { // If the function call node has no input/output edges, just remove it. bool has_edge = false; for (auto e : call_node->in_edges()) { @@ -389,17 +413,18 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( if (!has_edge) { VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString(); g->RemoveNode(call_node); - return Status::OK(); + return nullptr; } // Build XlaHostCompute NodeDef. - TF_ASSIGN_OR_RETURN(NodeDef node_def, - BuildXlaHostComputeNodeDef(call_node, host_compute_core)); + TF_ASSIGN_OR_RETURN( + NodeDef node_def, + BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps)); TF_ASSIGN_OR_RETURN(Node * host_compute_node, ReplaceNode(g, call_node, node_def)); VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString(); - return Status::OK(); + return host_compute_node; } // Resets "device_ordinal" attr to placeholder value for related nodes @@ -493,14 +518,9 @@ Status ConstructHostGraph( device_ordinal_attr.set_i(0); protobuf::Map attrs; attrs["device_ordinal"] = device_ordinal_attr; - FunctionBody* host_fbody = nullptr; + std::unique_ptr host_fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(host_func), AttrSlice(&attrs), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &host_fbody)); - std::unique_ptr host_fbody_deleter(host_fbody); + *fld->Find(host_func), AttrSlice(&attrs), fld, &host_fbody)); // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse // reachable from sink node so all nodes will be copied. @@ -581,10 +601,9 @@ Status ConstructHostGraph( &host_graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("extract_outside_compilation_host_graph_for_", - xla_cluster_name), - host_graph, fld); + DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_", + xla_cluster_name), + host_graph, fld); } FunctionDef host_graph_fdef; @@ -605,7 +624,8 @@ Status ConstructHostGraph( Status ExpandHostGraphIntoMainGraph(Graph* main_graph, FunctionLibraryDefinition* fld, const string& host_graph_func_name, - Node* xla_computation_node) { + Node* xla_computation_node, + Node* pivot_node) { // Temporarily use "0" as "device_ordinal". It will be rewritten with the // correct value in a later pass. We cannot just use placeholder value here // because FunctionDef instantiation does not allow placeholder value for @@ -614,14 +634,9 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, device_ordinal_attr.set_i(0); protobuf::Map attrs; attrs["device_ordinal"] = device_ordinal_attr; - FunctionBody* fbody = nullptr; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); + std::unique_ptr fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(host_graph_func_name), + AttrSlice(&attrs), fld, &fbody)); Graph* host_graph = fbody->graph; // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse @@ -631,7 +646,11 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, // Copy all nodes. 
std::map node_map; - node_map[host_graph->source_node()] = main_graph->source_node(); + if (pivot_node) { + node_map[host_graph->source_node()] = pivot_node; + } else { + node_map[host_graph->source_node()] = main_graph->source_node(); + } node_map[host_graph->sink_node()] = main_graph->sink_node(); Status s = Status::OK(); auto copy_node_fn = [&](const Node* n) { @@ -684,21 +703,16 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, // 2) Remove control edges. // 3) Prune nodes that are not useful for shape inference. Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, - Graph* host_graph, + Graph* host_graph, Node* pivot_node, FunctionLibraryDefinition* fld) { // Use "0" as "device_ordinal". It does not matter for shape inference. AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); protobuf::Map attrs; attrs["device_ordinal"] = device_ordinal_attr; - FunctionBody* fbody = nullptr; + std::unique_ptr fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); + *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, &fbody)); Graph* g = fbody->graph; // Find SendFromHost node. @@ -733,41 +747,45 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, for (Node* n : nodes) { g->RemoveNode(n); } - - std::map node_map; - node_map[host_graph->source_node()] = g->source_node(); - Status s; - auto copy_node_fn = [&](const Node* n) { - if (!s.ok()) { - return; - } - - if (node_map.find(n) != node_map.end()) { - return; - } - - NodeDef copy_def = n->def(); - Node* copy = g->AddNode(copy_def, &s); - if (!s.ok()) { - return; - } - for (auto e : n->in_edges()) { - if (node_map.find(e->src()) == node_map.end()) { - s = errors::Internal("Cannot find node image for ", - e->src()->DebugString()); - return; - } - g->AddEdge(node_map[e->src()], e->src_output(), copy, e->dst_input()); - } - - node_map[n] = copy; + Node* start_node = pivot_node ? pivot_node : host_graph->source_node(); + // Reverse DFS from send_from_host_main_graph, and stop at start_node. + struct Visit { + Node* n; + bool is_exiting; }; - // TODO(b/77601805): consolidate copy graph functions. 
- ReverseDFSFrom(*host_graph, - std::vector{send_from_host_main_graph}, - /*enter=*/nullptr, copy_node_fn, NodeComparatorID()); - if (!s.ok()) { - return s; + std::vector stack{{send_from_host_main_graph, false}}; + std::map node_map; + node_map[host_graph->source_node()] = g->source_node(); + while (!stack.empty()) { + Visit& curr = stack.back(); + if (curr.is_exiting) { + if (node_map.find(curr.n) == node_map.end()) { + Node* copy = g->CopyNode(curr.n); + if (curr.n != start_node) { + for (const Edge* e : curr.n->in_edges()) { + auto node_iter = node_map.find(e->src()); + if (node_iter == node_map.end()) { + return errors::Internal("Cannot find node image for ", + e->src()->DebugString()); + } + g->AddEdge(node_iter->second, e->src_output(), copy, + e->dst_input()); + } + } + node_map[curr.n] = copy; + } + stack.pop_back(); + } else { + curr.is_exiting = true; + if (curr.n != start_node) { + for (const Edge* e : curr.n->in_edges()) { + if (node_map.find(e->src()) != node_map.end()) { + continue; + } + stack.push_back({e->src(), false}); + } + } + } } send_from_host = node_map[send_from_host_main_graph]; @@ -789,7 +807,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, std::unordered_set{send_from_host}); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile(shape_inference_graph_name, *g, fld); + DumpGraphToFile(shape_inference_graph_name, *g, fld); } // Replace original shape inference graph. @@ -831,14 +849,9 @@ Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, device_ordinal_attr.set_i(0); protobuf::Map attrs; attrs["device_ordinal"] = device_ordinal_attr; - FunctionBody* fbody = nullptr; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(func_name), AttrSlice(&attrs), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); + std::unique_ptr fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(func_name), + AttrSlice(&attrs), fld, &fbody)); Graph* g = fbody->graph; // Find or create the key placeholder node. @@ -962,14 +975,10 @@ Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld, const string& while_node_name, const string& host_transfer_key) { // Instantiate the loop cond function. - FunctionBody* fbody = nullptr; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(loop_cond_func.name()), AttrSlice(&loop_cond_func.attr()), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); + std::unique_ptr fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(loop_cond_func.name()), + AttrSlice(&loop_cond_func.attr()), + fld, &fbody)); Graph* g = fbody->graph; // Find the _Retval node and the loop cond node. 
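The RewriteShapeInferenceGraph change above replaces the recursive ReverseDFSFrom helper with an explicit stack whose entries carry an is_exiting flag, so a node is copied only after every node feeding its in-edges has been copied. The following is a minimal standalone sketch of that traversal pattern under simplifying assumptions: it uses an invented ToyGraph type instead of TensorFlow's Graph/Node classes, omits the pivot/start-node cutoff, and assumes the graph is acyclic.

// Standalone illustration only; not TensorFlow code.
#include <cstdio>
#include <map>
#include <vector>

struct ToyGraph {
  // in_edges[n] lists the predecessors (in-edge sources) of node n.
  std::map<int, std::vector<int>> in_edges;
};

// Returns the nodes reachable from `start` via in-edges, ordered so that every
// node appears only after all of its recorded predecessors (post-order).
std::vector<int> CopyReachable(const ToyGraph& g, int start) {
  struct Visit {
    int node;
    bool is_exiting;  // false = first visit, true = all predecessors handled
  };
  std::vector<Visit> stack{{start, false}};
  std::map<int, bool> done;
  std::vector<int> order;
  while (!stack.empty()) {
    Visit& curr = stack.back();
    if (curr.is_exiting) {
      if (!done[curr.node]) {
        done[curr.node] = true;
        order.push_back(curr.node);  // emitted after its predecessors
      }
      stack.pop_back();
    } else {
      curr.is_exiting = true;  // revisit this entry once predecessors are done
      auto it = g.in_edges.find(curr.node);
      if (it != g.in_edges.end()) {
        for (int pred : it->second) {
          if (!done[pred]) stack.push_back({pred, false});
        }
      }
    }
  }
  return order;
}

int main() {
  ToyGraph g;
  g.in_edges[3] = {1, 2};
  g.in_edges[1] = {0};
  g.in_edges[2] = {0};
  for (int n : CopyReachable(g, 3)) std::printf("%d ", n);  // prints: 0 2 1 3
  std::printf("\n");
  return 0;
}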
@@ -1033,14 +1042,9 @@ Status RewriteHostWhileLoopCond( device_ordinal_temp_value.set_i(0); protobuf::Map attrs; attrs["device_ordinal"] = device_ordinal_temp_value; - FunctionBody* cond_fbody = nullptr; + std::unique_ptr cond_fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &cond_fbody)); - std::unique_ptr cond_fbody_deleter(cond_fbody); + *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld, &cond_fbody)); Graph* cond_graph = cond_fbody->graph; Node* key_arg = nullptr; for (Node* n : cond_graph->nodes()) { @@ -1113,14 +1117,9 @@ Status RewriteHostWhileLoopBody( device_ordinal_temp_value.set_i(0); protobuf::Map attrs; attrs["device_ordinal"] = device_ordinal_temp_value; - FunctionBody* body_fbody = nullptr; + std::unique_ptr body_fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &body_fbody)); - std::unique_ptr body_fbody_deleter(body_fbody); + *fld->Find(body_host_func_name), AttrSlice(&attrs), fld, &body_fbody)); Graph* body_graph = body_fbody->graph; Node* key_arg = nullptr; for (Node* n : body_graph->nodes()) { @@ -1615,12 +1614,17 @@ Status ExtractOutsideCompilationForFunction( // We cannot early return here, because we might have outside compilation in // If/While function body. + // Find dependencies between outside compilation clusters. + TF_ASSIGN_OR_RETURN(auto cluster_deps, + OutsideCompilationClusterDependencies( + fbody->graph, outside_compilation_attr_name)); + // Preprocess edges between different outside compilations. They will be // restored in `ConstructHostGraph()`. TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( fbody->graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("extract_outside_compilation_for_func_before_", func_name), *fbody->graph, fld); } @@ -1666,10 +1670,35 @@ Status ExtractOutsideCompilationForFunction( } } } + std::map host_compute_nodes; for (Node* n : outside_compilation_nodes) { TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); - TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( - graph_out.get(), n, host_compute_core)); + auto host_compute_node_or = ReplaceOrRemoveOutsideCompilationCallNode( + graph_out.get(), n, host_compute_core, *cluster_deps); + TF_RETURN_IF_ERROR(host_compute_node_or.status()); + Node* host_compute_node = host_compute_node_or.ValueOrDie(); + if (host_compute_node) { + host_compute_nodes[host_compute_node->name()] = host_compute_node; + } + } + // For XlaHostCompute nodes with dependencies, add control edges between them + // so XlaCompiler can handle them in correct order. + for (auto iter : host_compute_nodes) { + Node* host_compute_node = iter.second; + std::vector token_input_node_names; + TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(), + kXlaTokenInputNodesAttrName, + &token_input_node_names)); + for (const string& node_name : token_input_node_names) { + if (node_name == kXlaTokenArgNodeName) { + continue; + } + + auto iter = host_compute_nodes.find(node_name); + if (iter != host_compute_nodes.end()) { + graph_out->AddControlEdge(iter->second, host_compute_node); + } + } } // Handle nodes with associated functions. 
@@ -1705,7 +1734,7 @@ Status ExtractOutsideCompilationForFunction( TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); } if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("extract_outside_compilation_for_func_after_", func_name), *graph_out, fld); } @@ -1717,18 +1746,21 @@ Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) { + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + bool* modified) { if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_before", *g, fld); + DumpGraphToFile("extract_outside_compilation_before", *g, fld); } - std::vector shape_inference_graphs; + *modified = false; + auto node_name_index = g->BuildNodeNameIndex(); for (auto& iter : clusters) { string xla_cluster_name = iter.first; Node* n = iter.second.node; auto const& func_name_attrs = iter.second.func_name_attrs; auto const& host_compute_core = iter.second.host_compute_core; + std::vector shape_inference_graphs; bool has_outside_compilation; string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name()); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( @@ -1736,18 +1768,23 @@ Status ExtractOutsideCompilation( func_name_attrs, func_name_attrs.name(), host_graph_func_name, host_compute_core, flr, fld, &shape_inference_graphs, &has_outside_compilation)); - TF_RETURN_IF_ERROR( - ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); - TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); - } + *modified |= has_outside_compilation; - for (auto shape_inference_graph_name : shape_inference_graphs) { - TF_RETURN_IF_ERROR( - RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld)); + string pivot_name = absl::StrCat(xla_cluster_name, "/pivot"); + Node* pivot_node = node_name_index[pivot_name]; + TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph( + g, fld, host_graph_func_name, n, pivot_node)); + + TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); + + for (auto shape_inference_graph_name : shape_inference_graphs) { + TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph_name, + g, pivot_node, fld)); + } } if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_after", *g, fld); + DumpGraphToFile("extract_outside_compilation_after", *g, fld); } return Status::OK(); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index d64cc2a103e..0a29fdaa5c8 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -101,7 +101,8 @@ Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld); + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + bool* modified); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index 7c3a24feff8..2717487c78e 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -300,14 +300,9 @@ 
TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { &has_outside_compilation)); // Get rewritten XLA computation function. - FunctionBody *xla_fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("cluster_rewritten"), AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &xla_fbody)); - std::unique_ptr xla_fbody_deleter(xla_fbody); + std::unique_ptr xla_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), + AttrSlice(), &fld, &xla_fbody)); auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); // Check XlaHostCompute nodes. @@ -343,18 +338,13 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { EXPECT_EQ(shape_inference_graphs.size(), 0); // Check host graph: verify we have key placeholder and sequencer. - FunctionBody *host_fbody = nullptr; + std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); protobuf::Map host_func_attrs; host_func_attrs["device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &host_fbody)); - std::unique_ptr host_fbody_deleter(host_fbody); + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody)); Graph *host_graph = host_fbody->graph; Node *key_placeholder = nullptr, *sequencer = nullptr; for (Node *n : host_graph->nodes()) { @@ -428,18 +418,13 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { &has_outside_compilation)); // Check host graph is empty. - FunctionBody *host_fbody = nullptr; + std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); protobuf::Map host_func_attrs; host_func_attrs["device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &host_fbody)); - std::unique_ptr host_fbody_deleter(host_fbody); + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody)); Graph *host_graph = host_fbody->graph; EXPECT_EQ(host_graph->num_nodes(), 2); } @@ -476,31 +461,21 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { &has_outside_compilation)); // Check rewritten XLA graph: verify that we have no XlaHostCompute. - FunctionBody *xla_fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("cluster_rewritten"), AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &xla_fbody)); - std::unique_ptr xla_fbody_deleter(xla_fbody); + std::unique_ptr xla_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), + AttrSlice(), &fld, &xla_fbody)); for (Node *n : xla_fbody->graph->nodes()) { EXPECT_NE(n->type_string(), "XlaHostCompute"); } // Check host graph: verify we have no placeholder, but we have "const1". 
- FunctionBody *host_fbody = nullptr; + std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); protobuf::Map host_func_attrs; host_func_attrs["device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &host_fbody)); - std::unique_ptr host_fbody_deleter(host_fbody); + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody)); Graph *host_graph = host_fbody->graph; int num_key_placeholders = 0; for (Node *n : host_graph->nodes()) { @@ -600,18 +575,14 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { // Check host graph. { - FunctionBody *host_fbody = nullptr; + std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); protobuf::Map host_func_attrs; host_func_attrs["device_ordinal"] = device_ordinal_temp_value; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &host_fbody)); - std::unique_ptr host_fbody_deleter(host_fbody); + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), + AttrSlice(&host_func_attrs), &fld, + &host_fbody)); Graph *host_graph = host_fbody->graph; auto node_name_index = host_graph->BuildNodeNameIndex(); @@ -654,14 +625,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { // Check XLA graph. { - FunctionBody *xla_fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("cluster_rewritten"), AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &xla_fbody)); - std::unique_ptr xla_fbody_deleter(xla_fbody); + std::unique_ptr xla_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), + AttrSlice(), &fld, &xla_fbody)); Graph *xla_graph = xla_fbody->graph; auto node_name_index = xla_graph->BuildNodeNameIndex(); @@ -759,18 +725,14 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { // Check host graph. { - FunctionBody *host_fbody = nullptr; + std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); protobuf::Map host_func_attrs; host_func_attrs["device_ordinal"] = device_ordinal_temp_value; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &host_fbody)); - std::unique_ptr host_fbody_deleter(host_fbody); + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), + AttrSlice(&host_func_attrs), &fld, + &host_fbody)); Graph *host_graph = host_fbody->graph; auto node_name_index = host_graph->BuildNodeNameIndex(); @@ -899,18 +861,14 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { // Check host graph. 
{ - FunctionBody *host_fbody = nullptr; + std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); protobuf::Map host_func_attrs; host_func_attrs["device_ordinal"] = device_ordinal_temp_value; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &host_fbody)); - std::unique_ptr host_fbody_deleter(host_fbody); + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), + AttrSlice(&host_func_attrs), &fld, + &host_fbody)); Graph *host_graph = host_fbody->graph; auto node_name_index = host_graph->BuildNodeNameIndex(); @@ -918,14 +876,10 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { Node *call_node = node_name_index["oc_call_fn"]; EXPECT_NE(call_node, nullptr); - FunctionBody *call_fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("oc_func_call_host_fn"), AttrSlice(&host_func_attrs), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &call_fbody)); - std::unique_ptr call_fbody_deleter(call_fbody); + std::unique_ptr call_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("oc_func_call_host_fn"), + AttrSlice(&host_func_attrs), &fld, + &call_fbody)); // Verify we have _XlaRecvAtHost and _XlaSendFromHost nodes. bool has_recv = false, has_send = false; @@ -942,14 +896,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { // Check XLA graph. { - FunctionBody *xla_fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("cluster_rewritten"), AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &xla_fbody)); - std::unique_ptr xla_fbody_deleter(xla_fbody); + std::unique_ptr xla_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), + AttrSlice(), &fld, &xla_fbody)); Graph *xla_graph = xla_fbody->graph; auto node_name_index = xla_graph->BuildNodeNameIndex(); @@ -958,14 +907,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { EXPECT_NE(fn_node, nullptr); EXPECT_EQ(fn_node->type_string(), "fn_oc"); - FunctionBody *call_fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper( - *fld.Find("fn_oc"), AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &call_fbody)); - std::unique_ptr call_fbody_deleter(call_fbody); + std::unique_ptr call_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("fn_oc"), AttrSlice(), &fld, + &call_fbody)); // Verify we have XlaHostCompute nodes. bool has_hc = false; @@ -978,4 +922,165 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { } } +TEST_F(ExtractOutsideCompilationForFunctionTest, + OutsideCompilationClusterDataDependency) { + // Build the XLA computation func. 
+ // "const0" + // "identity0" = "const0" (outside compilation cluster "0") + // "identity1" = "identity0" (outside compilation cluster "1") + // "identity2" = "identity1" + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output const0 = ops::Const(s.WithOpName("const0"), 1, {2}); + Output identity0 = ops::Identity(s.WithOpName("identity0"), const0); + Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); + Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString() + << std::endl; + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity0"]->AddAttr("_oc", "0"); + node_name_image["identity1"]->AddAttr("_oc", "1"); + + PartialTensorShape shape({2}); + node_name_image["identity1"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Get rewritten XLA computation function. + std::unique_ptr xla_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), + AttrSlice(), &fld, &xla_fbody)); + auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); + + // Check XlaHostCompute nodes. + Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"]; + EXPECT_NE(host_compute_0, nullptr); + Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"]; + EXPECT_NE(host_compute_1, nullptr); + + // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr. + std::vector token_input_nodes; + TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()), + "_xla_token_input_nodes", &token_input_nodes)); + + std::vector expected_token_input_nodes_0({"_xla_token_arg_node"}); + EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0); + token_input_nodes.clear(); + std::vector expected_token_input_nodes_1( + {"_xla_token_arg_node", "outside_compilation_0_host_compute"}); + TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), + "_xla_token_input_nodes", &token_input_nodes)); + EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1); + + // Check there is a control edge from host_compute_0 to host_compute_1. + bool has_control_edge = false; + for (const Edge *e : host_compute_1->in_edges()) { + if (e->IsControlEdge() && e->src() == host_compute_0) { + has_control_edge = true; + break; + } + } + EXPECT_TRUE(has_control_edge); +} + +TEST_F(ExtractOutsideCompilationForFunctionTest, + OutsideCompilationClusterControlDependency) { + // Build the XLA computation func. 
+ // "const0" + // "identity0" = "const0" (outside compilation cluster "0") + // "identity1" = "const0" "^identity0" (outside compilation cluster "1", + // control depdent on cluster "0") + // "identity2" = "identity1" + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output const0 = ops::Const(s.WithOpName("const0"), 1, {2}); + Output identity0 = ops::Identity(s.WithOpName("identity0"), const0); + Output identity1 = ops::Identity( + s.WithOpName("identity1").WithControlDependencies(identity0), const0); + Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + std::cout << "Graph is " << (*g).ToGraphDefDebug().DebugString() + << std::endl; + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity0"]->AddAttr("_oc", "0"); + node_name_image["identity1"]->AddAttr("_oc", "1"); + + PartialTensorShape shape({2}); + node_name_image["identity1"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Get rewritten XLA computation function. + std::unique_ptr xla_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), + AttrSlice(), &fld, &xla_fbody)); + auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); + + // Check XlaHostCompute nodes. + Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"]; + EXPECT_NE(host_compute_0, nullptr); + Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"]; + EXPECT_NE(host_compute_1, nullptr); + + // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr. + std::vector token_input_nodes; + TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()), + "_xla_token_input_nodes", &token_input_nodes)); + + std::vector expected_token_input_nodes_0({"_xla_token_arg_node"}); + EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0); + token_input_nodes.clear(); + std::vector expected_token_input_nodes_1( + {"_xla_token_arg_node", "outside_compilation_0_host_compute"}); + TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), + "_xla_token_input_nodes", &token_input_nodes)); + EXPECT_EQ(token_input_nodes, expected_token_input_nodes_1); + + // Check there is a control edge from host_compute_0 to host_compute_1. + bool has_control_edge = false; + for (const Edge *e : host_compute_1->in_edges()) { + if (e->IsControlEdge() && e->src() == host_compute_0) { + has_control_edge = true; + break; + } + } + EXPECT_TRUE(has_control_edge); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index fba69dfccc3..3ee3c5e48d4 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -15,6 +15,9 @@ limitations under the License. 
#include <mutex>  // NOLINT +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/strip.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/core/util/command_line_flags.h" @@ -23,30 +26,50 @@ namespace tensorflow { namespace { BuildXlaOpsPassFlags* build_ops_flags; -DumpGraphFlags* dump_graph_flags; MarkForCompilationPassFlags* mark_for_compilation_flags; XlaDeviceFlags* device_flags; XlaOpsCommonFlags* ops_flags; +IntroduceFloatingPointJitterPassFlags* jitter_flags; std::vector<Flag>* flag_list; std::once_flag flags_init; -void AppendDumpGraphFlagsInternal(std::vector<Flag>* flag_list) { - std::vector<Flag> new_flags = { - Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix, - "Path prefix to which graphs dumped during debugging should be " - "written."), - }; - flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +bool SetterForXlaAutoJitFlag(const string& value) { + int32 opt_level; + // We need to use the mark_for_compilation_flags directly here instead of + // going via GetMarkForCompilationPassFlags() to avoid infinite recursion. The + // latter will try to setup and parse flags, which would bring us back to this + // setter. + if (absl::SimpleAtoi(value, &opt_level)) { + mark_for_compilation_flags->xla_auto_jit_flag + .optimization_level_single_gpu = opt_level; + mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = + opt_level; + return true; + } + + absl::string_view value_sv(value); + if (!absl::ConsumePrefix(&value_sv, "single-gpu(") || + !absl::ConsumeSuffix(&value_sv, ")") || + !absl::SimpleAtoi(value_sv, &opt_level)) { + return false; + } + + mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu = + opt_level; + return true; } void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) { std::vector<Flag> new_flags = { - Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit, + Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0", "Control compilation of operators into XLA computations on CPU and " "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " "things very likely to be improved; 2 = on for everything. " - "Experimental."), + "If set to single-gpu(<N>) then this resolves to <N> for single-GPU " + "graphs (graphs that have at least one node placed on a GPU and no " + "more than one GPU is in use through the entire graph) and 0 " + "otherwise. Experimental."), Flag("tf_xla_min_cluster_size", &mark_for_compilation_flags->tf_xla_min_cluster_size, "Minimum number of operators in an XLA compilation. 
Ignored for " @@ -65,10 +88,6 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { &mark_for_compilation_flags->tf_xla_clustering_fuel, "Places an artificial limit on the number of ops marked as " "eligible for clustering."), - Flag("tf_xla_fusion_only", - &mark_for_compilation_flags->tf_xla_fusion_only, - "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*."), Flag("tf_xla_disable_deadness_safety_checks_for_debugging", &mark_for_compilation_flags ->tf_xla_disable_deadness_safety_checks_for_debugging, @@ -80,20 +99,19 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { void AllocateAndParseFlags() { build_ops_flags = new BuildXlaOpsPassFlags; build_ops_flags->tf_xla_enable_lazy_compilation = true; - - dump_graph_flags = new DumpGraphFlags; - dump_graph_flags->tf_dump_graph_prefix = "/tmp/"; + build_ops_flags->tf_xla_print_cluster_outputs = false; mark_for_compilation_flags = new MarkForCompilationPassFlags; - mark_for_compilation_flags->tf_xla_auto_jit = 0; - mark_for_compilation_flags->tf_xla_min_cluster_size = 2; + mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu = + 0; + mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0; + mark_for_compilation_flags->tf_xla_min_cluster_size = 4; mark_for_compilation_flags->tf_xla_max_cluster_size = std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_clustering_debug = false; mark_for_compilation_flags->tf_xla_cpu_global_jit = false; mark_for_compilation_flags->tf_xla_clustering_fuel = std::numeric_limits::max(); - mark_for_compilation_flags->tf_xla_fusion_only = false; mark_for_compilation_flags ->tf_xla_disable_deadness_safety_checks_for_debugging = false; @@ -103,32 +121,52 @@ void AllocateAndParseFlags() { ops_flags = new XlaOpsCommonFlags; ops_flags->tf_xla_always_defer_compilation = false; - flag_list = new std::vector({ - Flag("tf_xla_enable_lazy_compilation", - &build_ops_flags->tf_xla_enable_lazy_compilation, ""), + jitter_flags = new IntroduceFloatingPointJitterPassFlags; + jitter_flags->jitter_amount = 1e-5; - Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand, - "Switch a device into 'on-demand' mode, where instead of " - "autoclustering ops are compiled one by one just-in-time."), + auto setter_for_jitter_tensor_names = [](string sequence) { + jitter_flags->tensor_names = absl::StrSplit(sequence, ','); + return true; + }; + + flag_list = new std::vector( + {Flag("tf_xla_enable_lazy_compilation", + &build_ops_flags->tf_xla_enable_lazy_compilation, ""), + Flag("tf_xla_print_cluster_outputs", + &build_ops_flags->tf_xla_print_cluster_outputs, + "If true then insert Print nodes to print out values produced by " + "XLA clusters."), + + Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), + + Flag("tf_xla_always_defer_compilation", + &ops_flags->tf_xla_always_defer_compilation, ""), + + Flag("tf_introduce_floating_point_jitter_to_tensors", + setter_for_jitter_tensor_names, "", + "The Tensors to add the jitter to. The tensors are named in the " + "TensorId format of :."), + Flag("tf_introduce_floating_point_jitter_amount", + &jitter_flags->jitter_amount, + "The amount of jitter to introduce. 
This amount is added to each " + "element in the tensors named in `tensor_names.")}); - Flag("tf_xla_always_defer_compilation", - &ops_flags->tf_xla_always_defer_compilation, ""), - }); - AppendDumpGraphFlagsInternal(flag_list); AppendMarkForCompilationPassFlagsInternal(flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); } } // namespace -const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { +bool SetXlaAutoJitFlagFromFlagString(const string& value) { std::call_once(flags_init, &AllocateAndParseFlags); - return *build_ops_flags; + return SetterForXlaAutoJitFlag(value); } -DumpGraphFlags* GetDumpGraphFlags() { +BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() { std::call_once(flags_init, &AllocateAndParseFlags); - return dump_graph_flags; + return build_ops_flags; } MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { @@ -146,14 +184,14 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { return *ops_flags; } +const IntroduceFloatingPointJitterPassFlags& +GetIntroduceFloatingPointJitterPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *jitter_flags; +} + void AppendMarkForCompilationPassFlags(std::vector* flag_list) { std::call_once(flags_init, &AllocateAndParseFlags); AppendMarkForCompilationPassFlagsInternal(flag_list); } - -void AppendDumpGraphFlags(std::vector* flag_list) { - std::call_once(flags_init, &AllocateAndParseFlags); - AppendDumpGraphFlagsInternal(flag_list); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index ed7810fcfd8..42608d1c145 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -23,14 +23,30 @@ limitations under the License. namespace tensorflow { -// Flags associated with the XLA bridge's mark_for_compilation_pass module. -struct MarkForCompilationPassFlags { +struct XlaAutoJitFlag { // Control compilation of operators into XLA computations on CPU and GPU // devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very // likely to be improved; 2 = on for everything. // + // If all non-CPU ops in the graph being optimized are placed on a single GPU + // and there is at least one node placed on that GPU then + // `optimization_level_single_gpu` applies. Otherwise + // `optimization_level_general` applies. + // // Experimental. - int32 tf_xla_auto_jit; + int32 optimization_level_single_gpu; + int32 optimization_level_general; +}; + +// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax +// is: +// : sets general and single_gpu setting to the provided number. +// single-gpu(): sets the single_gpu setting to the provided number. +bool SetXlaAutoJitFlagFromFlagString(const string& value); + +// Flags associated with the XLA bridge's mark_for_compilation_pass module. +struct MarkForCompilationPassFlags { + XlaAutoJitFlag xla_auto_jit_flag; // Minimum number of operators in an XLA compilation. Ignored for operators // placed on an XLA device or operators explicitly marked for compilation. @@ -49,11 +65,6 @@ struct MarkForCompilationPassFlags { // eligible for clustering. int64 tf_xla_clustering_fuel; - // tf_xla_fusion_only is effective only when global_jit_level is set to ON* - // and overrides its behavior. If true, enable fusion of element-wise - // operations only using XLA. - bool tf_xla_fusion_only; - // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then // we do not do deadness related safety checks. 
This is unsound in general, // but can be used as a debugging aid. @@ -81,12 +92,21 @@ struct BuildXlaOpsPassFlags { // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. // Defaults to true. bool tf_xla_enable_lazy_compilation; + + // If true then insert Print nodes to print out values produced by XLA + // clusters. Useful for debugging. + bool tf_xla_print_cluster_outputs; }; -// Flags for the XLA bridge's dump_graph module. -struct DumpGraphFlags { - // Path prefix to which graphs dumped during debugging should be written. - string tf_dump_graph_prefix; +// Flags for the IntroduceFloatingPointJitter pass. +struct IntroduceFloatingPointJitterPassFlags { + // The amount of jitter to introduce. This amount is added to each element in + // the tensors named in `tensor_names. + float jitter_amount; + + // The Tensors to add the jitter to. The tensors are named in the TensorId + // format of :. + std::vector tensor_names; }; // Return a pointer to the DumpGraphFlags struct; @@ -97,10 +117,12 @@ struct DumpGraphFlags { // parses TF_XLA_FLAGS for all of them. Those functions which return a pointer // always return the same pointer. MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); -const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); +BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags(); XlaDeviceFlags* GetXlaDeviceFlags(); const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); -DumpGraphFlags* GetDumpGraphFlags(); + +const IntroduceFloatingPointJitterPassFlags& +GetIntroduceFloatingPointJitterPassFlags(); // Appends the flag definitions associated with // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. @@ -108,8 +130,6 @@ DumpGraphFlags* GetDumpGraphFlags(); // Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. void AppendMarkForCompilationPassFlags( std::vector* flag_list); -void AppendDumpGraphFlags(std::vector* flag_list); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 8212956adfe..f9be7c45743 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -13,8 +13,23 @@ cc_library( srcs = ["graphcycles.cc"], hdrs = ["graphcycles.h"], deps = [ + ":ordered_set", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "ordered_set", + hdrs = ["ordered_set.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", ], ) @@ -28,3 +43,14 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "ordered_set_test", + srcs = ["ordered_set_test.cc"], + deps = [ + ":ordered_set", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 756377bd950..6ec9b5a477a 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -34,14 +34,20 @@ limitations under the License. 
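Editorial note on the flag changes above: the new `--tf_xla_auto_jit` syntax accepts either a bare integer or `single-gpu(<N>)`, and `SetterForXlaAutoJitFlag` (exported as `SetXlaAutoJitFlagFromFlagString`) parses it with the absl string helpers added at the top of flags.cc. The sketch below is not part of the patch; it mirrors that parsing in a standalone program with a local struct, just to make the two accepted forms concrete.

```c++
// Standalone illustration of the flag-string syntax; names are local to the
// example, not TensorFlow's.
#include <cstdio>

#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

struct AutoJitLevels {
  int single_gpu = 0;
  int general = 0;
};

// "N" sets both levels; "single-gpu(N)" sets only the single-GPU level.
bool ParseAutoJit(absl::string_view value, AutoJitLevels* out) {
  int opt_level;
  if (absl::SimpleAtoi(value, &opt_level)) {
    out->single_gpu = opt_level;
    out->general = opt_level;
    return true;
  }
  if (!absl::ConsumePrefix(&value, "single-gpu(") ||
      !absl::ConsumeSuffix(&value, ")") ||
      !absl::SimpleAtoi(value, &opt_level)) {
    return false;
  }
  out->single_gpu = opt_level;
  return true;
}

int main() {
  AutoJitLevels levels;
  // e.g. TF_XLA_FLAGS="--tf_xla_auto_jit=single-gpu(2)"
  if (ParseAutoJit("single-gpu(2)", &levels)) {
    std::printf("single_gpu=%d general=%d\n", levels.single_gpu,
                levels.general);
  }
  return 0;
}
```

So `single-gpu(2)` applies level 2 only to single-GPU graphs and leaves the general setting at its default, while a plain `2` sets both levels.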
#include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/graphcycles/ordered_set.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { namespace { -typedef std::unordered_set NodeSet; +using NodeSet = absl::flat_hash_set; +using OrderedNodeSet = OrderedSet; + template struct VecStruct { typedef absl::InlinedVector type; @@ -50,13 +56,11 @@ template using Vec = typename VecStruct::type; struct Node { - Node() : in(4), out(4) {} // Small hashtables for in/out edges - int32 rank; // rank number assigned by Pearce-Kelly algorithm bool visited; // Temporary marker used by depth-first-search void* data; // User-supplied data - NodeSet in; // List of immediate predecessor nodes in graph - NodeSet out; // List of immediate successor nodes in graph + OrderedNodeSet in; // List of immediate predecessor nodes in graph + OrderedNodeSet out; // List of immediate successor nodes in graph }; } // namespace @@ -93,7 +97,7 @@ bool GraphCycles::CheckInvariants() const { if (!ranks.insert(nx->rank).second) { LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank; } - for (auto y : nx->out) { + for (int32 y : nx->out.GetSequence()) { Node* ny = r->nodes_[y]; if (nx->rank >= ny->rank) { LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment " @@ -124,14 +128,14 @@ int32 GraphCycles::NewNode() { void GraphCycles::RemoveNode(int32 node) { Node* x = rep_->nodes_[node]; - for (auto y : x->out) { - rep_->nodes_[y]->in.erase(node); + for (int32 y : x->out.GetSequence()) { + rep_->nodes_[y]->in.Erase(node); } - for (auto y : x->in) { - rep_->nodes_[y]->out.erase(node); + for (int32 y : x->in.GetSequence()) { + rep_->nodes_[y]->out.Erase(node); } - x->in.clear(); - x->out.clear(); + x->in.Clear(); + x->out.Clear(); rep_->free_nodes_.push_back(node); } @@ -144,12 +148,12 @@ void GraphCycles::SetNodeData(int32 node, void* data) { } bool GraphCycles::HasEdge(int32 x, int32 y) const { - return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end(); + return rep_->nodes_[x]->out.Contains(y); } void GraphCycles::RemoveEdge(int32 x, int32 y) { - rep_->nodes_[x]->out.erase(y); - rep_->nodes_[y]->in.erase(x); + rep_->nodes_[x]->out.Erase(y); + rep_->nodes_[y]->in.Erase(x); // No need to update the rank assignment since a previous valid // rank assignment remains valid after an edge deletion. } @@ -165,13 +169,13 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) { if (x == y) return false; Rep* r = rep_; Node* nx = r->nodes_[x]; - if (!nx->out.insert(y).second) { + if (!nx->out.Insert(y)) { // Edge already exists. return true; } Node* ny = r->nodes_[y]; - ny->in.insert(x); + ny->in.Insert(x); if (nx->rank <= ny->rank) { // New edge is consistent with existing rank assignment. @@ -182,8 +186,8 @@ bool GraphCycles::InsertEdge(int32 x, int32 y) { // We only need to consider nodes that fall in the range [ny->rank,nx->rank]. if (!ForwardDFS(r, y, nx->rank)) { // Found a cycle. Undo the insertion and tell caller. - nx->out.erase(y); - ny->in.erase(x); + nx->out.Erase(y); + ny->in.Erase(x); // Since we do not call Reorder() on this path, clear any visited // markers left by ForwardDFS. 
ClearVisitedBits(r, r->deltaf_); @@ -209,7 +213,7 @@ static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) { nn->visited = true; r->deltaf_.push_back(n); - for (auto w : nn->out) { + for (auto w : nn->out.GetSequence()) { Node* nw = r->nodes_[w]; if (nw->rank == upper_bound) { return false; // Cycle @@ -235,7 +239,7 @@ static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) { nn->visited = true; r->deltab_.push_back(n); - for (auto w : nn->in) { + for (auto w : nn->in.GetSequence()) { Node* nw = r->nodes_[w]; if (!nw->visited && lower_bound < nw->rank) { r->stack_.push_back(w); @@ -321,7 +325,7 @@ int GraphCycles::FindPath(int32 x, int32 y, int max_path_len, return path_len; } - for (auto w : r->nodes_[n]->out) { + for (auto w : r->nodes_[n]->out.GetSequence()) { if (seen.insert(w).second) { r->stack_.push_back(w); } @@ -375,31 +379,94 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) { } Node* nb = rep_->nodes_[b]; - std::unordered_set out = std::move(nb->out); - std::unordered_set in = std::move(nb->in); - for (auto y : out) { - rep_->nodes_[y]->in.erase(b); + OrderedNodeSet out = std::move(nb->out); + OrderedNodeSet in = std::move(nb->in); + for (int32 y : out.GetSequence()) { + rep_->nodes_[y]->in.Erase(b); } - for (auto y : in) { - rep_->nodes_[y]->out.erase(b); + for (int32 y : in.GetSequence()) { + rep_->nodes_[y]->out.Erase(b); } rep_->free_nodes_.push_back(b); - for (auto y : out) { + rep_->nodes_[a]->out.Reserve(rep_->nodes_[a]->out.Size() + out.Size()); + for (int32 y : out.GetSequence()) { InsertEdge(a, y); } - for (auto y : in) { + + rep_->nodes_[a]->in.Reserve(rep_->nodes_[a]->in.Size() + in.Size()); + for (int32 y : in.GetSequence()) { InsertEdge(y, a); } + return true; } -std::unordered_set GraphCycles::Successors(int32 node) { - return rep_->nodes_[node]->out; +absl::Span GraphCycles::Successors(int32 node) const { + return rep_->nodes_[node]->out.GetSequence(); } -std::unordered_set GraphCycles::Predecessors(int32 node) { - return rep_->nodes_[node]->in; +absl::Span GraphCycles::Predecessors(int32 node) const { + return rep_->nodes_[node]->in.GetSequence(); +} + +std::vector GraphCycles::SuccessorsCopy(int32 node) const { + absl::Span successors = Successors(node); + return std::vector(successors.begin(), successors.end()); +} + +std::vector GraphCycles::PredecessorsCopy(int32 node) const { + absl::Span predecessors = Predecessors(node); + return std::vector(predecessors.begin(), predecessors.end()); +} + +namespace { +void SortInPostOrder(absl::Span nodes, + std::vector* to_sort) { + absl::c_sort(*to_sort, [&](int32 a, int32 b) { + DCHECK(a == b || nodes[a]->rank != nodes[b]->rank); + return nodes[a]->rank > nodes[b]->rank; + }); +} +} // namespace + +std::vector GraphCycles::AllNodesInPostOrder() const { + absl::flat_hash_set free_nodes_set; + absl::c_copy(rep_->free_nodes_, + std::inserter(free_nodes_set, free_nodes_set.begin())); + + std::vector all_nodes; + all_nodes.reserve(rep_->nodes_.size() - free_nodes_set.size()); + for (int64 i = 0, e = rep_->nodes_.size(); i < e; i++) { + if (!free_nodes_set.contains(i)) { + all_nodes.push_back(i); + } + } + + SortInPostOrder(rep_->nodes_, &all_nodes); + return all_nodes; +} + +string GraphCycles::DebugString() const { + absl::flat_hash_set free_nodes_set; + for (int32 free_node : rep_->free_nodes_) { + free_nodes_set.insert(free_node); + } + + string result = "digraph {\n"; + for (int i = 0; i < rep_->nodes_.size(); i++) { + if (free_nodes_set.contains(i)) { + continue; + } + + for (int32 
succ : rep_->nodes_[i]->out.GetSequence()) { + absl::StrAppend(&result, " \"", i, "\" -> \"", succ, "\"\n"); + } + } + + absl::StrAppend(&result, "}\n"); + + return result; } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h index 44448fa3d78..ce171a2ead0 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.h +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ #define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_ +#include + // GraphCycles detects the introduction of a cycle into a directed // graph that is being built up incrementally. // @@ -38,8 +40,7 @@ limitations under the License. // FindPath() is linear in the size of the graph. // The current implementation uses O(|V|+|E|) space. -#include - +#include "absl/types/span.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -117,8 +118,26 @@ class GraphCycles { // Expensive: should only be called from graphcycles_test.cc. bool CheckInvariants() const; - std::unordered_set Successors(int32 node); - std::unordered_set Predecessors(int32 node); + // Warning: Do not use these if iterating over the span and modifying the + // GraphCycles at the same time. Instead use SuccessorsCopy/PredecessorsCopy. + absl::Span Successors(int32 node) const; + absl::Span Predecessors(int32 node) const; + + // Return a copy of the sucessors set. This is needed for code using the + // collection while modifying the GraphCycles. + std::vector SuccessorsCopy(int32 node) const; + // Return a copy of the predecessors set. This is needed for code using the + // collection while modifying the GraphCycles. + std::vector PredecessorsCopy(int32 node) const; + + // Returns all nodes in post order. + // + // If there is a path from X to Y then X appears after Y in the + // returned vector. + std::vector AllNodesInPostOrder() const; + + // Returns the graph in graphviz format. + string DebugString() const; // ---------------------------------------------------- struct Rep; diff --git a/tensorflow/compiler/jit/graphcycles/ordered_set.h b/tensorflow/compiler/jit/graphcycles/ordered_set.h new file mode 100644 index 00000000000..0417782b984 --- /dev/null +++ b/tensorflow/compiler/jit/graphcycles/ordered_set.h @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_ +#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +// This is a set data structure that provides a deterministic iteration order. 
+// The iteration order of elements only depends on the sequence of +// inserts/deletes, so as long as the inserts/deletes happen in the same +// sequence, the set will have the same iteration order. +// +// Assumes that T can be cheaply copied for simplicity. +template +class OrderedSet { + public: + // Inserts `value` into the ordered set. Returns true if the value was not + // present in the set before the insertion. + bool Insert(T value) { + bool new_insertion = + value_to_index_.insert({value, value_sequence_.size()}).second; + if (new_insertion) { + value_sequence_.push_back(value); + } + return new_insertion; + } + + // Removes `value` from the set. Assumes `value` is already present in the + // set. + void Erase(T value) { + auto it = value_to_index_.find(value); + DCHECK(it != value_to_index_.end()); + + // Since we don't want to move values around in `value_sequence_` we swap + // the value in the last position and with value to be deleted and then + // pop_back. + value_to_index_[value_sequence_.back()] = it->second; + std::swap(value_sequence_[it->second], value_sequence_.back()); + value_sequence_.pop_back(); + value_to_index_.erase(it); + } + + void Reserve(size_t new_size) { + value_to_index_.reserve(new_size); + value_sequence_.reserve(new_size); + } + + void Clear() { + value_to_index_.clear(); + value_sequence_.clear(); + } + + bool Contains(T value) const { return value_to_index_.contains(value); } + size_t Size() const { return value_sequence_.size(); } + + absl::Span GetSequence() const { return value_sequence_; } + + private: + // The stable order that we maintain through insertions and deletions. + std::vector value_sequence_; + + // Maps values to their indices in `value_sequence_`. + absl::flat_hash_map value_to_index_; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_ORDERED_SET_H_ diff --git a/tensorflow/compiler/jit/graphcycles/ordered_set_test.cc b/tensorflow/compiler/jit/graphcycles/ordered_set_test.cc new file mode 100644 index 00000000000..38ac1cfe9b6 --- /dev/null +++ b/tensorflow/compiler/jit/graphcycles/ordered_set_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/graphcycles/ordered_set.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { +TEST(OrderedSetTest, Insert) { + OrderedSet ordered_set; + EXPECT_TRUE(ordered_set.Insert(90)); + EXPECT_TRUE(ordered_set.Insert(100)); + EXPECT_TRUE(ordered_set.Insert(80)); + + EXPECT_FALSE(ordered_set.Insert(100)); + + EXPECT_EQ(ordered_set.Size(), 3); + + EXPECT_TRUE(ordered_set.Contains(90)); + EXPECT_TRUE(ordered_set.Contains(100)); + EXPECT_TRUE(ordered_set.Contains(80)); + + EXPECT_FALSE(ordered_set.Contains(40)); + + std::array expected_sequence = {90, 100, 80}; + EXPECT_EQ(ordered_set.GetSequence(), expected_sequence); +} + +TEST(OrderedSetTest, Erase) { + OrderedSet ordered_set; + EXPECT_TRUE(ordered_set.Insert(90)); + EXPECT_TRUE(ordered_set.Insert(100)); + EXPECT_TRUE(ordered_set.Insert(80)); + + ordered_set.Erase(100); + + EXPECT_EQ(ordered_set.Size(), 2); + + EXPECT_TRUE(ordered_set.Contains(90)); + EXPECT_FALSE(ordered_set.Contains(100)); + EXPECT_TRUE(ordered_set.Contains(80)); + + std::array expected_sequence_0 = {90, 80}; + EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_0); + + ordered_set.Erase(80); + + EXPECT_EQ(ordered_set.Size(), 1); + + EXPECT_TRUE(ordered_set.Contains(90)); + EXPECT_FALSE(ordered_set.Contains(100)); + EXPECT_FALSE(ordered_set.Contains(80)); + + std::array expected_sequence_1 = {90}; + EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_1); + + ordered_set.Erase(90); + + EXPECT_EQ(ordered_set.Size(), 0); + + EXPECT_FALSE(ordered_set.Contains(90)); + EXPECT_FALSE(ordered_set.Contains(100)); + EXPECT_FALSE(ordered_set.Contains(80)); + + std::array expected_sequence_2 = {}; + EXPECT_EQ(ordered_set.GetSequence(), expected_sequence_2); +} + +TEST(OrderedSetTest, Clear) { + OrderedSet ordered_set; + EXPECT_TRUE(ordered_set.Insert(90)); + EXPECT_TRUE(ordered_set.Insert(100)); + EXPECT_TRUE(ordered_set.Insert(80)); + + ordered_set.Clear(); + + EXPECT_EQ(ordered_set.Size(), 0); + + EXPECT_FALSE(ordered_set.Contains(90)); + EXPECT_FALSE(ordered_set.Contains(100)); + EXPECT_FALSE(ordered_set.Contains(80)); + + std::array expected_sequence = {}; + EXPECT_EQ(ordered_set.GetSequence(), expected_sequence); +} + +TEST(OrderedSetTest, LargeInsertions) { + const int kSize = 50 * 9000; + + OrderedSet ordered_set; + + for (int i = 0; i < kSize; i++) { + EXPECT_TRUE(ordered_set.Insert(i + 500)); + } + + for (int i = 0; i < kSize; i++) { + EXPECT_EQ(ordered_set.GetSequence()[i], i + 500); + } +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 5287fd175df..23931a0d7cd 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -27,12 +27,12 @@ limitations under the License. 
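Tying the pieces of this hunk together: `GraphCycles` now keeps each node's in/out edges in the `OrderedSet` exercised by the tests above, and exposes them through span-returning accessors (`Successors`/`Predecessors`), copying variants for callers that mutate the graph while iterating, and the new `AllNodesInPostOrder`. A hypothetical standalone use, assuming the `graphcycles` target from this patch is linked in:

```c++
#include <cstdio>

#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"

int main() {
  tensorflow::GraphCycles g;
  tensorflow::int32 a = g.NewNode();
  tensorflow::int32 b = g.NewNode();
  tensorflow::int32 c = g.NewNode();
  g.InsertEdge(a, b);
  g.InsertEdge(b, c);

  // Read-only iteration: Successors returns a view into internal storage.
  for (tensorflow::int32 succ : g.Successors(a)) {
    std::printf("a -> %d\n", static_cast<int>(succ));
  }

  // Mutating while iterating requires the copying accessor.
  for (tensorflow::int32 succ : g.SuccessorsCopy(b)) {
    g.RemoveEdge(b, succ);
  }

  // Post order: if there is a path X -> Y, X appears after Y.
  for (tensorflow::int32 n : g.AllNodesInPostOrder()) {
    std::printf("post-order node %d\n", static_cast<int>(n));
  }

  // Graphviz dump of the remaining edges.
  std::printf("%s", g.DebugString().c_str());
  return 0;
}
```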
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { @@ -375,15 +375,15 @@ Status IncreaseDynamismForAutoJitPass::Run( const GraphOptimizationPassOptions& options) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", - **options.graph, options.flib_def); + DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", + **options.graph, options.flib_def); } bool changed; TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed)); if (changed && flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass", - **options.graph, options.flib_def); + DumpGraphToFile("increase_dynamism_for_auto_jit_pass", **options.graph, + options.flib_def); } return Status::OK(); diff --git a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.cc b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.cc new file mode 100644 index 00000000000..ff0fa8710a8 --- /dev/null +++ b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.cc @@ -0,0 +1,153 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/graph/tensor_id.h" + +namespace tensorflow { +namespace { +std::vector>> GetNodesToModify( + const Graph& g, absl::Span tensor_names) { + absl::flat_hash_map name_to_node; + for (Node* n : g.op_nodes()) { + name_to_node[n->name()] = n; + } + + absl::flat_hash_map> nodes_to_modify_map; + + for (const string& tensor_name : tensor_names) { + TensorId tensor_id = ParseTensorName(tensor_name); + auto it = name_to_node.find(tensor_id.node()); + DCHECK(it != name_to_node.end()); + nodes_to_modify_map[it->second].push_back(tensor_id.index()); + } + + std::vector>> nodes_to_modify; + absl::c_copy(nodes_to_modify_map, std::back_inserter(nodes_to_modify)); + + absl::c_sort(nodes_to_modify, + [](const std::pair>& a, + const std::pair>& b) { + return a.first->id() < b.first->id(); + }); + + for (auto& p : nodes_to_modify) { + absl::c_sort(p.second); + p.second.erase(std::unique(p.second.begin(), p.second.end()), + p.second.end()); + } + + return nodes_to_modify; +} + +Status IntroduceJitterToTensor( + Graph* g, Node* n, int oidx, float jitter_amount, + absl::flat_hash_map, Output>* + node_to_jitter_constant) { + std::vector edges_to_update; + absl::c_copy_if(n->out_edges(), std::back_inserter(edges_to_update), + [&](const Edge* e) { return e->src_output() == oidx; }); + + if (edges_to_update.empty()) { + VLOG(1) << "No users for " << TensorId(n->name(), oidx).ToString(); + return Status::OK(); + } + + VLOG(1) << "Updating " << edges_to_update.size() << " users for " + << TensorId(n->name(), oidx).ToString(); + + Status status; + Scope s = NewInternalScope(g, &status, /*refiner=*/nullptr) + .NewSubScope(absl::StrCat(n->name(), "/jitter")); + + Output node_out(n, oidx); + Output jitter_constant; + DataType dtype = n->output_type(oidx); + auto it = node_to_jitter_constant->find({dtype, n}); + if (it == node_to_jitter_constant->end()) { + Tensor constant_tensor; + if (dtype == DT_FLOAT) { + constant_tensor = Tensor(static_cast(jitter_amount)); + } else if (dtype == DT_HALF) { + constant_tensor = Tensor(Eigen::half(jitter_amount)); + } else { + return errors::Unimplemented("Only float and half are supported"); + } + + jitter_constant = + ops::Const(s.WithOpName("jitter_amount"), constant_tensor); + (*node_to_jitter_constant)[{dtype, n}] = jitter_constant; + } else { + jitter_constant = it->second; + } + + Output jittered_output = + ops::Add(s.NewSubScope(absl::StrCat(oidx)).WithOpName("jittered_output"), + jitter_constant, node_out); + + TF_RETURN_IF_ERROR(status); + + for (const Edge* e : edges_to_update) { + VLOG(3) << "Updating " << e->dst()->name(); + TF_RETURN_IF_ERROR( + g->UpdateEdge(jittered_output.node(), 0, e->dst(), e->dst_input())); + } + + // Add a control edge to make sure that the two inputs to jittered_output are + // from the same frame. 
+ g->AddControlEdge(n, jitter_constant.node()); + + return Status::OK(); +} +} // namespace + +Status IntroduceFloatingPointJitter(Graph* graph, + absl::Span tensor_names, + float jitter_amount) { + if (tensor_names.empty()) { + VLOG(3) << "Nothing to do"; + return Status::OK(); + } + + std::vector>> nodes_to_modify = + GetNodesToModify(*graph, tensor_names); + + absl::flat_hash_map, Output> + node_to_jitter_constant; + for (const auto& p : nodes_to_modify) { + for (int oidx : p.second) { + TF_RETURN_IF_ERROR(IntroduceJitterToTensor( + graph, p.first, oidx, jitter_amount, &node_to_jitter_constant)); + } + } + + return Status::OK(); +} + +Status IntroduceFloatingPointJitterPass::Run( + const GraphOptimizationPassOptions& options) { + const IntroduceFloatingPointJitterPassFlags& flags = + GetIntroduceFloatingPointJitterPassFlags(); + + return IntroduceFloatingPointJitter(options.graph->get(), flags.tensor_names, + flags.jitter_amount); +} +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_if_while.h b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h similarity index 52% rename from tensorflow/core/common_runtime/lower_if_while.h rename to tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h index efa3945bca4..115f72a6eea 100644 --- a/tensorflow/core/common_runtime/lower_if_while.h +++ b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,27 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_WHILE_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_WHILE_H_ +#ifndef TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_ #include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/lib/core/status.h" namespace tensorflow { - -// Rewrite If and While ops to use lower level control flow primitives instead. -class LowerIfWhilePass : public GraphOptimizationPass { +// A debug-only pass that introduces error into outputs of specific TF nodes. +// This can be used to check the sensitivity of a TF graph to floating point +// rounding differences. +// +// This pass is controlled by TF_XLA_FLAGS. Please see +// IntroduceFloatingPointJitterPassFlags for information on how to use this. 
+class IntroduceFloatingPointJitterPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; -#if defined(_MSC_VER) - static constexpr char* kLowerUsingSwitchMergeAttr = -#else - static constexpr char kLowerUsingSwitchMergeAttr[] = -#endif - "_lower_using_switch_merge"; -}; + IntroduceFloatingPointJitterPass() = default; + Status Run(const GraphOptimizationPassOptions& options) override; +}; } // namespace tensorflow -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_WHILE_H_ +#endif // TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_ diff --git a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h new file mode 100644 index 00000000000..ea7261bc872 --- /dev/null +++ b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h @@ -0,0 +1,27 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_ +#define TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_ + +#include "absl/types/span.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { +Status IntroduceFloatingPointJitter(Graph* graph, + absl::Span tensor_names, + float jitter_amount); +} + +#endif // TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_ diff --git a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_test.cc b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_test.cc new file mode 100644 index 00000000000..96ddfcbd025 --- /dev/null +++ b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
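As the pass comment above notes, this is driven entirely by `TF_XLA_FLAGS`; the two relevant flags were registered earlier in this patch. A hedged sketch of switching it on from a C++ driver before TensorFlow parses its flags (POSIX `setenv`; the tensor name is a made-up example):

```c++
#include <cstdlib>

int main() {
  // Must be set before the first TF_XLA_FLAGS parse in the process.
  setenv("TF_XLA_FLAGS",
         "--tf_introduce_floating_point_jitter_to_tensors=sigmoid_a:0 "
         "--tf_introduce_floating_point_jitter_amount=0.01",
         /*overwrite=*/1);
  // ... build and run the TensorFlow graph from here ...
  return 0;
}
```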
+==============================================================================*/ + +#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/linalg_ops.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using testing::matchers::Const; +using testing::matchers::Inputs; +using testing::matchers::Name; +using testing::matchers::NodeWith; +using testing::matchers::Op; +using testing::matchers::Out; + +TEST(IntroduceFloatingPointJitterTest, SingleOutputFP32) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT); + Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT); + + Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a); + Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b); + + Output tanh_a = ops::Tanh(root.WithOpName("tanh_a"), sigmoid_a); + Output tanh_b = ops::Tanh(root.WithOpName("tanh_b"), sigmoid_b); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + std::vector tensor_names; + tensor_names.push_back("sigmoid_a"); + tensor_names.push_back("sigmoid_b"); + + TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f)); + VLOG(1) << graph->ToGraphDefDebug().DebugString(); + + auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a"))); + auto m_sigmoid_a_with_jitter = + NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a)); + auto m_tanh_a = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_a_with_jitter))); + + auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b"))); + auto m_sigmoid_b_with_jitter = + NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_b)); + auto m_tanh_b = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_b_with_jitter))); + + Node* tanh_a_transformed = testing::FindNodeByName(graph.get(), "tanh_a"); + Node* tanh_b_transformed = testing::FindNodeByName(graph.get(), "tanh_b"); + + ASSERT_NE(tanh_a_transformed, nullptr); + ASSERT_NE(tanh_b_transformed, nullptr); + + EXPECT_THAT(tanh_a_transformed, m_tanh_a); + EXPECT_THAT(tanh_b_transformed, m_tanh_b); +} + +TEST(IntroduceFloatingPointJitterTest, TwoNodesOneUser) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT); + Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT); + + Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a); + Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b); + + Output add = ops::Add(root.WithOpName("add"), sigmoid_a, sigmoid_b); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + std::vector tensor_names; + tensor_names.push_back("sigmoid_a"); + tensor_names.push_back("sigmoid_b"); + + TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f)); + VLOG(1) << graph->ToGraphDefDebug().DebugString(); + + auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a"))); + auto m_sigmoid_a_with_jitter = + NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a)); + + auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b"))); + auto m_sigmoid_b_with_jitter = + NodeWith(Op("Add"), 
Inputs(Const(0.01f), m_sigmoid_b)); + + auto m_add = NodeWith(Op("Add"), Inputs(Out(m_sigmoid_a_with_jitter), + Out(m_sigmoid_b_with_jitter))); + + Node* add_transformed = testing::FindNodeByName(graph.get(), "add"); + + ASSERT_NE(add_transformed, nullptr); + + EXPECT_THAT(add_transformed, m_add); +} + +TEST(IntroduceFloatingPointJitterTest, NotFP32) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF); + + Output sigmoid = ops::Sigmoid(root.WithOpName("sigmoid"), input); + + Output tanh = ops::Tanh(root.WithOpName("tanh"), sigmoid); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + std::vector tensor_names; + tensor_names.push_back("sigmoid"); + + TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f)); + VLOG(1) << graph->ToGraphDefDebug().DebugString(); + + auto m_sigmoid = Out(NodeWith(Name("sigmoid"))); + auto m_sigmoid_with_jitter = + NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_sigmoid)); + auto m_tanh = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_with_jitter))); + + Node* tanh_transformed = testing::FindNodeByName(graph.get(), "tanh"); + + ASSERT_NE(tanh_transformed, nullptr); + + EXPECT_THAT(tanh_transformed, m_tanh); +} + +TEST(IntroduceFloatingPointJitterTest, MultiOutput) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF); + + ops::Svd svd(root.WithOpName("svd"), input); + + Output tanh_s = ops::Tanh(root.WithOpName("tanh_s"), svd.s); + Output tanh_u = ops::Tanh(root.WithOpName("tanh_u"), svd.u); + Output tanh_v = ops::Tanh(root.WithOpName("tanh_v"), svd.v); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + std::vector tensor_names; + tensor_names.push_back("svd:0"); + tensor_names.push_back("svd:2"); + + TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f)); + VLOG(1) << graph->ToGraphDefDebug().DebugString(); + + auto m_svd_s = Out(0, NodeWith(Name("svd"))); + auto m_svd_s_with_jitter = Out( + NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_s))); + + auto m_svd_u = Out(1, NodeWith(Name("svd"))); + + auto m_svd_v = Out(2, NodeWith(Name("svd"))); + auto m_svd_v_with_jitter = Out( + NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_v))); + + auto m_tanh_s = NodeWith(Op("Tanh"), Inputs(m_svd_s_with_jitter)); + auto m_tanh_u = NodeWith(Op("Tanh"), Inputs(m_svd_u)); + auto m_tanh_v = NodeWith(Op("Tanh"), Inputs(m_svd_v_with_jitter)); + + Node* tanh_s_transformed = testing::FindNodeByName(graph.get(), "tanh_s"); + ASSERT_NE(tanh_s_transformed, nullptr); + + Node* tanh_u_transformed = testing::FindNodeByName(graph.get(), "tanh_u"); + ASSERT_NE(tanh_u_transformed, nullptr); + + Node* tanh_v_transformed = testing::FindNodeByName(graph.get(), "tanh_v"); + ASSERT_NE(tanh_v_transformed, nullptr); + + EXPECT_THAT(tanh_s_transformed, m_tanh_s); + EXPECT_THAT(tanh_u_transformed, m_tanh_u); + EXPECT_THAT(tanh_v_transformed, m_tanh_v); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 7326b6c222b..69186da38f2 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -18,6 +18,7 @@ limitations under the License. 
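The tensor names used by these tests and by the jitter flags ("sigmoid_a", "svd:0", "svd:2") follow the `TensorId` convention of `<node name>:<output index>`, with the index defaulting to 0 when omitted. A small sketch, assuming a TensorFlow build, of how a comma-separated flag value decomposes:

```c++
#include <cstdio>
#include <string>
#include <vector>

#include "absl/strings/str_split.h"
#include "tensorflow/core/graph/tensor_id.h"

int main() {
  // Hypothetical flag value; the node names are just examples.
  const std::string flag_value = "sigmoid_a,svd:2";
  std::vector<std::string> names = absl::StrSplit(flag_value, ',');
  for (const std::string& name : names) {
    tensorflow::TensorId id = tensorflow::ParseTensorName(name);
    // Prints "node=sigmoid_a output=0" and "node=svd output=2".
    std::printf("node=%s output=%d\n", std::string(id.node()).c_str(),
                id.index());
  }
  return 0;
}
```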
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" +#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" @@ -31,6 +32,9 @@ namespace tensorflow { REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, EncapsulateXlaComputationsPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25, + IntroduceFloatingPointJitterPass); + // from // third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc // FunctionalizeControlFlowPass: 27 diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 88d00f7f8e1..6df0991e354 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -62,7 +62,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; std::unique_ptr xla_allocator; - xla::DeviceMemoryAllocator* device_allocator = nullptr; + se::DeviceMemoryAllocator* device_allocator = nullptr; if (ctx->device_type() == DeviceType(DEVICE_CPU)) { platform_id = se::host::kHostPlatformId; diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index 7b4d4b5b473..eaa686780e4 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -40,7 +40,7 @@ class XlaPlatformInfo { se::Platform::Id platform_id, const XlaDevice::Metadata* xla_device_metadata, std::unique_ptr xla_allocator, - xla::DeviceMemoryAllocator* device_allocator) + se::DeviceMemoryAllocator* device_allocator) : device_type_(device_type), platform_id_(platform_id), xla_device_metadata_(xla_device_metadata), @@ -55,7 +55,7 @@ class XlaPlatformInfo { return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); } - xla::DeviceMemoryAllocator* allocator() const { + se::DeviceMemoryAllocator* allocator() const { return device_allocator_ ? device_allocator_ : xla_allocator_.get(); } DeviceType device_type() const { return device_type_; } @@ -86,7 +86,7 @@ class XlaPlatformInfo { // then device_allocator_ is null and xla_allocator_ points to an appropriate // XlaAllocator instance. std::unique_ptr xla_allocator_; - xla::DeviceMemoryAllocator* device_allocator_; + se::DeviceMemoryAllocator* device_allocator_; TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); }; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 11a710b2a4e..3d3497c5c36 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -24,16 +24,19 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/compilability_check_util.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/bounds_check.h" @@ -48,10 +51,16 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { +using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate; +using jit::DeviceId; +using jit::DeviceSet; +using xla::StatusOr; + // The clusters we create here are eventually lowered into an // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the // PartitionedCall op to execute the cluster in the regular graph executor if @@ -68,497 +77,915 @@ namespace { // cluster. const char* kXlaAlreadyClustered = "_XlaAlreadyClustered"; -// Aggregates information about what kinds of ops are allowed. -struct OperationFilter { - // Whether resource variable ops are allowed. We do not allow resource - // variable ops in called functions (either as direct TF calls or as higher - // order control flow ops) because we do not yet model their memory effects in - // jit/resource_variable_safety_analysis. - bool allow_resource_ops; +class MarkForCompilationPassImpl { + public: + struct DebugOptions { + // If true, do not respect the results of deadness analysis. + bool ignore_deadness_checks; - // Whether stateful RNG ops are allowed. XLA's RNG does not have the same - // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid - // auto-clustering stateful RNG ops. - bool allow_stateful_rng_ops; + // If true, do not respect the _XlaCompile=false attribute. + bool ignore_xla_compile_attr; - // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound - // to cluster ControlTrigger because of how we use deadness analysis. - bool allow_control_trigger; + int max_cluster_size; + int min_cluster_size; - // Whether ops with dummy implementations are allowed. We avoid - // auto-clustering these ops so that the user is not surprised when XLA is - // implicitly enabled. If the user explicitly specifies to use XLA, it is fine - // to resort to a dummy implementation. Currently Assert and CheckNumerics ops - // have dummy XLA implementations. - bool allow_dummy_ops; + // Compiler fuel for the auto-clustering algorithm. + // + // We decrement this value by one on every time we choose a compilation + // candidate and we stop clustering when it hits zero. 
This means the + // initial value for this variable (via --tf_xla_clustering_fuel=N) + // effectively acts as a "cap" for how much we cluster and we can bisect + // over this initial value to discover clustering decisions that cause a + // miscompile or a performance regression. + std::atomic* fuel; - // Whether ops that produce or consume DT_VARIANT values are allowed. We - // don't auto-cluster these ops because we don't yet support live-in or - // live-out DT_VARIANT values. - bool allow_ops_producing_or_consuming_variant; + bool dump_graphs; + }; + + MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph, + FunctionLibraryDefinition* flib_def, Env* env, + OptimizerOptions::GlobalJitLevel global_jit_level) + : debug_options_(debug_options), + graph_(graph), + flib_def_(flib_def), + env_(env), + global_jit_level_(global_jit_level) {} + + Status Run(); + + private: + // Represents a "cluster" or a connected subgraph of a TensorFlow graph. + class Cluster { + public: + // Constructs a trivial cluster representing a single TF node. + Cluster(int tf_graph_node_id, int effective_cluster_size, + bool has_functional_control_flow, DeviceSet devices, + absl::optional resource_op_device, + absl::optional resource_var_operation_node_id, + absl::optional deadness_predicate, + bool is_xla_compile_attr_true, absl::optional xla_scope) + : cycles_graph_node_id_(tf_graph_node_id), + effective_cluster_size_(effective_cluster_size), + has_functional_control_flow_(has_functional_control_flow), + devices_(std::move(devices)), + resource_op_device_(resource_op_device), + deadness_predicate_(deadness_predicate), + is_xla_compile_attr_true_(is_xla_compile_attr_true), + xla_scope_(std::move(xla_scope)) { + if (resource_var_operation_node_id.has_value()) { + resource_var_operation_node_ids_.push_back( + *resource_var_operation_node_id); + } + } + + // Merges `other` into this cluster, and clears `other`. This method is + // closely tied with the implementation of `MarkForCompilationPassImpl`. + void Merge(Cluster* other); + + // If this is a trivial cluster containing only one node then return the ID + // of that node. May not be called otherwise. + int GetIdOfOnlyNode() const { + DCHECK_EQ(cluster_size(), 1); + return cycles_graph_node_id(); + } + + // The number of TF nodes in this cluster. + int cluster_size() const { return cluster_size_; } + + // The ID of the cluster as represented in `cycles_graph_`. + int cycles_graph_node_id() const { return cycles_graph_node_id_; } + + // The size of the cluster excluding constant and identity nodes. + int effective_cluster_size() const { return effective_cluster_size_; } + + // True if the cluster has functional control flow like `If` and `While`. + bool has_functional_control_flow() const { + return has_functional_control_flow_; + } + + // The set of devices nodes in the cluster are placed on. + const DeviceSet& devices() const { return devices_; } + + // If the cluster has a resource operation then the device the resource + // operation is placed on. A cluster may have resource ops placed only on a + // single device. + const absl::optional& resource_op_device() const { + return resource_op_device_; + } + + // If not nullopt the a predicate that is true iff the cluster is alive. + // Otherwise the user has (unsafely) disabled deadness analysis. If this is + // unset on a single Cluster instance then it is unset on all Cluster + // instances. 
+ const absl::optional& deadness_predicate() const { + return deadness_predicate_; + } + + // If true then the cluster has a XlaCompile=true attribute on one of its + // nodes. + bool is_xla_compile_attr_true() const { return is_xla_compile_attr_true_; } + + // If not nullopt then the all nodes in the cluster either do not have the + // XlaScope attribute set or have it set to the value returned. + const absl::optional& xla_scope() const { return xla_scope_; } + + // Returns the TF graph node IDs for the resource variable operations in + // this cluster. + absl::Span resource_var_operation_node_ids() const { + return resource_var_operation_node_ids_; + } + + string DebugString(const Graph& graph) const { + Node* node = graph.FindNodeId(cycles_graph_node_id()); + if (!node) { + // This should never happen but we try to be resilient because this is a + // debugging aid. + return absl::StrCat("NULL NODE IN #", cycles_graph_node_id()); + } + + return absl::StrCat("<", node->name(), " + ", cluster_size(), " others #", + cycles_graph_node_id(), ">"); + } + + private: + int cluster_size_ = 1; + int cycles_graph_node_id_; + int effective_cluster_size_; + bool has_functional_control_flow_; + DeviceSet devices_; + absl::optional resource_op_device_; + absl::optional deadness_predicate_; + bool is_xla_compile_attr_true_; + absl::optional xla_scope_; + std::vector resource_var_operation_node_ids_; + + TF_DISALLOW_COPY_AND_ASSIGN(Cluster); + }; + + // --------------------------------------------------------------------------- + // The pass proceeds in four steps, out of which `RunEdgeContractionLoop` and + // `CreateClusters` do most of the heavy lifting. + + // Initialize some internal data structures. + Status Initialize(); + + // Runs through all the nodes in `cycles_graph_` and tries to create clusters. + // Returns true if any new clusters were created. + StatusOr RunEdgeContractionLoopInPostOrderOnce(); + + // Runs through all the nodes in `cycles_graph_` and tries to contract high + // priority edges for clusters. Returns true if any new clusters were created. + // + // There are potentially many maximal clustering results, but they will not + // all be equally performant. Some clustering decision are likely to improve + // performance much more than others, and we cannot order contractions on this + // cost function, nor can we look at global information while deciding on + // individual edges to contract. Instead, we will make decisions on these + // important edges then make decisions on all other edges, causing the highest + // chance of all most important edges to be contracted. + // + // An example of where this might occur is with a digraph: + // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is + // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B + // should be clustered with A because it will prevent a potentially large + // tensor from A being computed and copied. + // + // This pass will ensure that contraction happens, which cannot be enforced in + // a single pass with the current algorithm. + // graph and prevent B->C from being clusterd in anticipation of a later A->B + // cluster. + StatusOr ContractPreferredEdges(); + + // Contracts as many edges as possible to create XLA clusters. After this + // finishes the clustering decisions made are implicitly stored in + // `clusters_`. + Status RunEdgeContractionLoop(); + + // Manifests the clustering decisions into the TF graph by tagging nodes with + // an `_XlaCluster` attribute. 
Also some basic filter logic, like + // tf_xla_min_cluster_size, are applied here. + Status CreateClusters(); + + Status DumpDebugInfo(); + + bool IsCompilationCandidate(Node* n) const { + return compilation_candidates_.find(n) != compilation_candidates_.end(); + } + + // Tries to contract the edge from cluster `from` to cluster `to`. Returns + // true if successful. + StatusOr TryToContractEdge(Cluster* from, Cluster* to); + + // Tries to contract each edge from `cluster_from`. Returns true if any edges + // were contracted, false otherwise. + StatusOr TryToContractEdgesFrom(Cluster* cluster_from); + + // Nodes that XLA can compile are put in `compilation_candidates_`. + Status FindCompilationCandidates(); + + bool CompilationDisallowedByXlaCompileAttr(Node* node); + + // Populates `clusters_`. + Status BuildInitialClusterSet(); + + StatusOr ShouldCompileClusterImpl(const Cluster& cluster); + + StatusOr ShouldCompileCluster(const Cluster& cluster); + + StatusOr ClusteringWillIntroduceInterDeviceDependency( + const Cluster& from, const Cluster& to); + + // Returns true if the devices in `cluster_a` and `cluster_b` are compatible + // and therefore not a hindrance for combining the two clusters into a larger + // cluster. + StatusOr AreDevicesCompatible(const Cluster& cluster_a, + const Cluster& cluster_b); + + void DumpPostClusteringGraphs(); + void VLogClusteringSummary(); + + Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size, + bool has_functional_control_flow, + const DeviceSet& device_set, + absl::optional resource_op_device, + absl::optional resource_var_operation_node_id, + absl::optional deadness_predicate, + bool is_xla_compile_attr_true, + absl::optional xla_scope) { + cluster_storage_.push_back(absl::make_unique( + cycles_graph_node_id, effective_cluster_size, + has_functional_control_flow, device_set, resource_op_device, + resource_var_operation_node_id, deadness_predicate, + is_xla_compile_attr_true, xla_scope)); + return cluster_storage_.back().get(); + } + + absl::optional GetXlaScope(Node* n); + + // Returns the cluster for node `n`. If two nodes, N1 and N2, are placed in + // the same cluster by the clustering algorithm then this function will return + // the same Cluster instance for N1 and N2. + // + // Returns nullptr if `n` is not a compilation candidate. + Cluster* GetClusterForNode(Node* n) { + return cluster_for_node_[n->id()].Get(); + } + + // Returns the cluster for a node in `cycles_graph_`. This uses the same + // underlying map because of how we set things up, but we can do an additional + // CHECK in this accessor. + // + // Returns nullptr if `node_id` is not a compilation candidate. + Cluster* GetClusterForCyclesGraphNode(int node_id) { + // We have to check `graph_->FindNodeId(node) == nullptr` because we add all + // nodes in [0, graph_->num_node_ids()) to the cycle detection graph but the + // TF graph may be missing some node ids. + if (node_id >= graph_->num_node_ids() || + graph_->FindNodeId(node_id) == nullptr) { + return nullptr; + } + Cluster* cluster = cluster_for_node_[node_id].Get(); + if (cluster) { + DCHECK_EQ(cluster->cycles_graph_node_id(), node_id); + } + return cluster; + } + + bool LogNotContractableAndReturnFalse(Cluster* from, Cluster* to, + absl::string_view reason); + + // Finds a path in `cycles_graph_` from `from` to `to` that is not a direct + // edge from `from` to `to`. + // + // Tries to find a path that contains at least one unclusterable node. 
+ std::vector FindAlternatePathForDebugging(int from, int to); + + // Returns a string representing `cycles_graph_node_id`. If the node is + // unclusterable (either it is a phatom "frame" node or is not a compilation + // candidate) then set `*found_unclustered` to true. + string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered); + + // We could not contract the edge from `from` to `to`. Return a string + // describing an alternate path from `from` to `to` (besides the direct edge + // from `from` to `to`) which would have created a cycle had we contracted the + // edge. + // + // Tries (if possible) to find a path that contains at least one unclusterable + // node as it is surprising to the user if we print "A->B could not be + // contracted because of the path [P,Q,R]" where P, Q and R are all clusters + // since in that case a natural question is why we could not form a {A, P, Q, + // R, B} cluster. + string DescribePotentialCycle(int from, int to); + + // Merge the clusters `cluster_from` and `cluster_to`. After this step the + // larger combined cluster is represented by `cluster_from`'s ID in + // `cycles_graph_`. + bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) { + int from = cluster_from->cycles_graph_node_id(); + int to = cluster_to->cycles_graph_node_id(); + + if (!cycles_graph_.ContractEdge(from, to)) { + VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_) + << " -> " << cluster_to->DebugString(*graph_) + << " because contracting the edge would create a cycle via " + << DescribePotentialCycle(from, to) << "."; + return false; + } + + // Merge the clusters. + cluster_from->Merge(cluster_to); + + // Merge the UnionFind. + cluster_for_node_[from].Merge(&cluster_for_node_[to]); + + return true; + } + + string EdgeContractionFailureMsg(Cluster* from, Cluster* to, + absl::string_view reason) { + return absl::StrCat("Could not contract ", from->DebugString(*graph_), + " -> ", to->DebugString(*graph_), " because ", reason, + "."); + } + + DebugOptions debug_options_; + Graph* graph_; + FunctionLibraryDefinition* flib_def_; + Env* env_; + OptimizerOptions::GlobalJitLevel global_jit_level_; + absl::flat_hash_map should_compile_cluster_cache_; + jit::DeviceInfoCache device_info_cache_; + + bool initialized_ = false; + bool edges_contracted_ = false; + bool clusters_created_ = false; + + std::vector> cluster_storage_; + std::vector> cluster_for_node_; + GraphCycles cycles_graph_; + OrderedNodeSet compilation_candidates_; + std::unique_ptr deadness_analysis_; + int64 iteration_count_ = 0; + absl::flat_hash_set> unsafe_resource_deps_; }; -bool IsDummyImplOp(absl::string_view op_name) { - return op_name == "Assert" || op_name == "CheckNumerics"; -} +std::vector MarkForCompilationPassImpl::FindAlternatePathForDebugging( + int from, int to) { + std::vector rpo = cycles_graph_.AllNodesInPostOrder(); + absl::c_reverse(rpo); -bool IsStatefulRandomOp(absl::string_view op_name) { - return op_name == "RandomUniform" || op_name == "RandomShuffle" || - op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || - op_name == "TruncatedNormal" || op_name == "Multinomial"; -} + // best_pred_for_node[n] contains a predecessor of `n` that has an + // unclusterable node in some path from `from` to itself. + // best_pred_for_node[n] is unpopulated for nodes that are not reachable from + // `from`. We build this table up inductively by traversing the cycles graph + // in RPO. 
+ absl::flat_hash_map best_pred_for_node; + best_pred_for_node[from] = -1; -bool OpProducesOrConsumesVariant(const Node& node) { - auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; }; - return absl::c_any_of(node.input_types(), is_variant) || - absl::c_any_of(node.output_types(), is_variant); -} + int rpo_index = 0, current_rpo_node; + do { + current_rpo_node = rpo[rpo_index++]; + absl::optional some_pred, preferred_pred; + for (int pred : cycles_graph_.Predecessors(current_rpo_node)) { + if (!best_pred_for_node.contains(pred)) { + continue; + } -bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { - // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient - // is really a kind of function call and will be handled by - // IsCompilableCall(). - if (node.type_string() == "SymbolicGradient") return false; - if (node.type_string() == "Const") { - // Skip Const op with type DT_STRING, since XLA doesn't support it, but the - // registered Const KernelDef says that it does, to support no-op Assert for - // tfcompile. - const AttrValue* attr = node.attrs().Find("dtype"); - if (attr != nullptr && attr->type() == DT_STRING) { - return false; + // Ignore the from->to edge since we're trying to find an alternate path. + if (current_rpo_node == to && pred == from) { + continue; + } + + some_pred = pred; + if (GetClusterForCyclesGraphNode(pred) == nullptr) { + preferred_pred = pred; + } } + + if (some_pred || preferred_pred) { + best_pred_for_node[current_rpo_node] = + preferred_pred.has_value() ? *preferred_pred : *some_pred; + } + } while (current_rpo_node != to); + + auto get_best_pred = [&](int n) { + auto it = best_pred_for_node.find(n); + CHECK(it != best_pred_for_node.end()); + return it->second; + }; + + std::vector path; + int current_path_node = get_best_pred(to); + while (current_path_node != from) { + path.push_back(current_path_node); + current_path_node = get_best_pred(current_path_node); } - // XLA does not offer guaranteed aliasing between the input and output of the - // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave - // such nodes out of XLA clusters. 
- if (HasForwardedRefInput(node)) { - VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast."; - return false; - } - - return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); + absl::c_reverse(path); + return path; } -bool HasResourceOutput(const Node& node) { - return std::find(node.output_types().begin(), node.output_types().end(), - DT_RESOURCE) != node.output_types().end(); +string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode( + int cycles_graph_node_id, bool* found_unclustered) { + Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id); + if (cluster) { + return cluster->DebugString(*graph_); + } + + *found_unclustered = true; + if (cycles_graph_node_id >= graph_->num_node_ids()) { + return absl::StrCat(""); + } + + Node* node = graph_->FindNodeId(cycles_graph_node_id); + if (!node) { + return absl::StrCat(""); + } + + return node->name(); } -bool HasResourceInput(const Node& node) { - return std::find(node.input_types().begin(), node.input_types().end(), - DT_RESOURCE) != node.input_types().end(); +string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) { + std::vector path_str; + bool found_unclustered = false; + absl::c_transform(FindAlternatePathForDebugging(from, to), + std::back_inserter(path_str), [&](int node_id) { + return DebugStringForCyclesGraphNode(node_id, + &found_unclustered); + }); + return absl::StrCat(!found_unclustered ? "(all clusters) " : "", "[", + absl::StrJoin(path_str, ","), "]"); } -// Returns true if `node` is a resource operation recognized by tf2xla that -// operates on something other than resource variables. -bool IsNonResourceVarResourceOp(const Node& node) { - // TODO(b/112837194): We can't cluster these because we only support - // snapshotting resource variables (and we can't e.g. snapshot stacks). This - // limitation may be fixable with some work. - const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string()); - return op_info && op_info->resource_kind() != XlaResourceKind::kVariable; +void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) { + // We keep our own cycles_graph_node_id_ to mirror what GraphCycles does. + + // Clearing out data structures in `other` is just a memory saving + // optimization and not needed for correctness. + + cluster_size_ += other->cluster_size_; + effective_cluster_size_ += other->effective_cluster_size_; + has_functional_control_flow_ |= other->has_functional_control_flow_; + + devices_.UnionWith(other->devices_); + + DCHECK(!(resource_op_device_.has_value() && + other->resource_op_device_.has_value()) || + *resource_op_device_ == *other->resource_op_device_) + << "AreDevicesCompatible should have returned false otherwise!"; + + if (!resource_op_device_.has_value()) { + resource_op_device_ = other->resource_op_device_; + } + + is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_; + + if (!xla_scope_.has_value()) { + xla_scope_ = std::move(other->xla_scope_); + } + + resource_var_operation_node_ids_.reserve( + resource_var_operation_node_ids_.size() + + other->resource_var_operation_node_ids_.size()); + absl::c_copy(other->resource_var_operation_node_ids_, + std::back_inserter(resource_var_operation_node_ids_)); + other->resource_var_operation_node_ids_.clear(); } -// Make sure we don't recurse infinitely on recursive functions. 
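Editorial note: the `Cluster::Merge` shown above folds one cluster's payload into another, and `MergeClusters` earlier in the patch additionally calls `cluster_for_node_[from].Merge(&cluster_for_node_[to])`. Taken together, the bookkeeping is essentially a union-find whose representative carries the merged payload. A minimal, self-contained sketch of that idea follows; the `UnionFindNode` name and the union-by-rank details are illustrative assumptions, not the pass's real `UnionFind` type.

```c++
#include <utility>

// Illustrative only: a union-find whose root carries a payload, mirroring how
// cluster_for_node_ maps node ids to a representative Cluster*.
template <typename T>
class UnionFindNode {
 public:
  // Returns the payload stored at the representative of this set.
  T& Get() { return FindRoot()->payload_; }

  // Unions the two sets; afterwards both share one representative/payload.
  void Merge(UnionFindNode* other) {
    UnionFindNode* a = FindRoot();
    UnionFindNode* b = other->FindRoot();
    if (a == b) return;
    if (a->rank_ < b->rank_) std::swap(a, b);
    b->parent_ = a;
    if (a->rank_ == b->rank_) ++a->rank_;
  }

 private:
  UnionFindNode* FindRoot() {
    if (parent_ == nullptr) return this;
    parent_ = parent_->FindRoot();  // Path compression.
    return parent_;
  }

  UnionFindNode* parent_ = nullptr;
  int rank_ = 0;
  T payload_{};
};
```

Which element ends up as the representative is immaterial for the sketch, since the pass merges the Cluster payload explicitly (via `Cluster::Merge`) before taking the union.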
-const int kMaxRecursionDepth = 10; +Status IgnoreResourceOpForSafetyAnalysis( + jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) { + // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then + // ignore it during resource operation safety analysis. We need this hack + // because of two reasons: + // + // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. + // 2. We don't support live-out values of type DT_RESOURCE and live-in values + // of type DT_RESOURCE that are not resource variables. + // + // Together these imply we cannot let resource variable safety analysis + // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different + // clusters: both of them will have to be clustered because of (1) and we + // won't be able to keep the edge between the two as neither the input to the + // second XLA cluster nor the output from the first XLA cluster are supported + // because of (2). + // + // TODO(b/113100872): This can be fixed if the TensorFlow representation for + // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then + // (2) would no longer hold. -bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, - const OperationFilter& op_filter, int depth, - FunctionLibraryRuntime* lib_runtime); + if (n.assigned_device_name().empty()) { + *ignore = false; + return Status::OK(); + } -// Tests whether 'while_node' is a completely compilable loop. -// Every operator in the condition and body functions must be compilable for a -// while loop to be compilable. -bool IsCompilableWhile(const Node& while_node, - const DeviceType& jit_device_type, - const OperationFilter& op_filter, int depth, - FunctionLibraryRuntime* lib_runtime) { - const NameAttrList* name_attr; - NodeDef call; - Status status; - status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); - if (!status.ok()) { - VLOG(2) << "Rejecting While " << while_node.name() - << ": missing 'cond' attribute on While node."; - return false; + TF_ASSIGN_OR_RETURN( + const XlaOpRegistry::DeviceRegistration* registration, + device_info_cache->GetCompilationDevice(n.assigned_device_name())); + + if (!registration) { + *ignore = true; + } else { + *ignore = registration->cluster_resource_variable_ops_unsafely; } - const string cond_func = name_attr->name(); - call.set_name("while_cond"); - call.set_op(cond_func); - *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1, - lib_runtime)) { - VLOG(2) << "Rejecting While " << while_node.name() - << ": can't compile loop condition: " << cond_func; - return false; - } - status = GetNodeAttr(while_node.attrs(), "body", &name_attr); - if (!status.ok()) { - VLOG(2) << "Rejecting While " << while_node.name() - << ": missing 'body' attribute on While node."; - return false; - } - const string body_func = name_attr->name(); - call.set_name("while_body"); - call.set_op(body_func); - *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1, - lib_runtime)) { - VLOG(2) << "Rejecting While " << while_node.name() - << ": can't compile loop body: " << body_func; - return false; - } - return true; + return Status::OK(); } -// Tests whether 'call_def' is a call to a completely compilable function. -// Every operator in the function must be compilable for a function to be -// compilable. 
-bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, - const OperationFilter& op_filter, int depth, - FunctionLibraryRuntime* lib_runtime) { - if (depth > kMaxRecursionDepth) { - VLOG(2) << "Rejecting " << call_def.op() - << ": function depth limit exceeded."; - return false; +Status MarkForCompilationPassImpl::Initialize() { + TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_); + initialized_ = true; + + TF_RETURN_IF_ERROR(FindCompilationCandidates()); + + if (compilation_candidates_.empty()) { + VLOG(2) << "No compilable candidates"; + return Status::OK(); } - FunctionLibraryRuntime::Handle handle; - Status status = - lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); - if (!status.ok()) { - VLOG(2) << "Rejecting " << call_def.op() - << ": could not instantiate: " << status; - return false; + TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, + CreateCycleDetectionGraph(graph_, &cycles_graph_)); + if (!cycle_detection_graph_ok) { + return Status::OK(); } - auto release_handle_on_return = gtl::MakeCleanup( - [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); }); - - const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); - CHECK(fbody); - const FunctionDef& fdef = fbody->fdef; - bool noinline = false; - if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() && - noinline) { - // The underlying mechanism that calls non-inlined functions uses - // LocalExecutor, which interacts poorly with the LocalExecutor used by - // tf2xla to translate the TF graph into XLA. So we avoid this for now. - // - // TODO(b/36139787): Create a mechanism to set inlining hints. - VLOG(2) << "Rejecting " << call_def.op() - << ": can't compile noinline function."; - return false; + if (!debug_options_.ignore_deadness_checks) { + XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); + TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_)); } - for (Node* node : fbody->graph->op_nodes()) { - if (node->type_string() == "_Arg" || node->type_string() == "_Retval") + // Each compilation candidate belongs to a cluster. The cluster's + // representative names the node in the 'cycles' graph that represents the + // cluster. + return BuildInitialClusterSet(); +} + +StatusOr MarkForCompilationPassImpl::ContractPreferredEdges() { + bool changed = false; + for (int32 node : cycles_graph_.AllNodesInPostOrder()) { + Cluster* cluster_from = GetClusterForCyclesGraphNode(node); + if (!cluster_from) { continue; - if (node->type_string() == "While") { - // Handle functional While loop. 
- return IsCompilableWhile(*node, jit_device_type, op_filter, depth + 1, - lib_runtime); } - if (!op_filter.allow_resource_ops && - (HasResourceInput(*node) || HasResourceOutput(*node))) { - return false; - } - if (!op_filter.allow_stateful_rng_ops && - IsStatefulRandomOp(node->type_string())) { - return false; - } - if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { - return false; - } - if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { - return false; - } - if (!op_filter.allow_ops_producing_or_consuming_variant && - OpProducesOrConsumesVariant(*node)) { - return false; - } - if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, - lib_runtime)) { - VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " - << node->name() << ": " << node->def().ShortDebugString(); - return false; + + // Make a copy of the set of successors because we may modify the graph in + // TryToContractEdge. + std::vector successors_copy = + cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id()); + + for (int to : successors_copy) { + iteration_count_++; + + Cluster* cluster_to = GetClusterForCyclesGraphNode(to); + if (!cluster_to) { + continue; + } + + if (cluster_to->cluster_size() == 1) { + Node* n = graph_->FindNodeId(cluster_to->GetIdOfOnlyNode()); + + // Shape consuming operations are desirable to cluster with their + // operands because they return a small set of scalar values after + // consuming a large amount of data. For example, given a graph X -> Y + // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or + // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the + // output of size will be a small tensor while Y is a potentially large + // tensor that must be computed and possible transposed/copied before + // the second cluster executes. + if (IsShapeConsumerOp(*n)) { + TF_ASSIGN_OR_RETURN(bool contracted_edge, + TryToContractEdge(cluster_from, cluster_to)); + changed |= contracted_edge; + } + } } } - return true; + + return changed; } -// Returns true if the op can be decomposed into XLA ops for which -// there are fusable elemental implementations. -// -// TODO(hpucha): Remove this code since this functionality is subsumed by -// Grappler XlaFusionOptimizer. 
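Editorial note: the shape-consumer preference spelled out in `ContractPreferredEdges` above reduces to a two-phase ordering of contraction attempts: first try edges whose consumer is a Shape/Rank/Size-like op, then try everything with no restriction. A toy sketch of that ordering, where `ToyEdge` and the `try_contract` callback are hypothetical stand-ins for the real cycles-graph API:

```c++
#include <functional>
#include <string>
#include <vector>

struct ToyEdge {
  std::string producer_op;
  std::string consumer_op;
};

bool IsShapeConsumer(const std::string& op) {
  return op == "Shape" || op == "Rank" || op == "Size";
}

void ContractInTwoPhases(
    const std::vector<ToyEdge>& edges,
    const std::function<void(const ToyEdge&)>& try_contract) {
  // Phase 1: only edges whose consumer returns a tiny shape-like result, so a
  // Size gets attached to its (potentially large) operand first.
  for (const ToyEdge& e : edges) {
    if (IsShapeConsumer(e.consumer_op)) try_contract(e);
  }
  // Phase 2: everything else, with no restriction.
  for (const ToyEdge& e : edges) {
    try_contract(e);
  }
}
```

In the digraph {A -> B, B -> C, A -> X, X -> C} from the earlier comment, phase 1 contracts A -> B, and the unrestricted phase can then no longer cluster B with C without creating a cycle through X.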
-bool IsXlaFusable(const NodeDef& node) { - static const std::unordered_set* elementwise_ops = - new std::unordered_set( - {// tf2xla/kernels/aggregate_ops.cc - "AddN", - // tf2xla/kernels/batchtospace_op.cc - "BatchToSpace", "BatchToSpaceND", - // tf2xla/kernels/bcast_ops.cc - "BroadcastArgs", "BroadcastGradientArgs", - // tf2xla/kernels/bias_ops.cc - "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, - // tf2xla/kernels/binary_ops.cc - "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", - "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", - "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", - "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", - "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", - "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", - "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", - // tf2xla/kernels/cast_op.cc - "Cast", - // tf2xla/kernels/categorical_op.cc - "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/, - // tf2xla/kernels/concat_op.cc - "Concat", "ConcatV2", "ConcatOffset", - // tf2xla/kernels/const_op.cc - "Const", - // tf2xla/kernels/cross_op.cc - "Cross", - // tf2xla/kernels/depthtospace_op.cc - "DepthToSpace", - // tf2xla/kernels/diag_op.cc - "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart", - // tf2xla/kernels/dynamic_stitch_op.cc - "DynamicStitch", "ParallelDynamicStitch", - // tf2xla/kernels/elu_op.cc - "Elu", "EluGrad", "Selu", "SeluGrad", - // tf2xla/kernels/fake_quantize_ops.cc - "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient", - "FakeQuantWithMinMaxVars", - "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/, - // tf2xla/kernels/fill_op.cc - "Fill", - // tf2xla/kernels/gather_op.cc - "Gather", "GatherV2", "GatherNd", - // tf2xla/kernels/identity_op.cc - "Identity", "IdentityN", "PreventGradient", "StopGradient", - "Snapshot", - // tf2xla/kernels/image_ops.cc - "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/, - "AdjustSaturation", "AdjustHue", - // tf2xla/kernels/index_ops.cc - "ArgMax", "ArgMin", - // tf2xla/kernels/l2loss_op.cc - "L2Loss" /*(Reduce)*/, - // tf2xla/kernels/lrn_ops.cc (ReduceWindow) - "LRN", "LRNGrad", - // tf2xla/kernels/matrix_band_part_op.cc - "MatrixBandPart", - // tf2xla/kernels/matrix_set_diag_op.cc - "MatrixSetDiag", - // tf2xla/kernels/mirror_pad_op.cc - "MirrorPad", - // tf2xla/kernels/no_op.cc - "NoOp", "ControlTrigger", - // tf2xla/kernels/one_hot_op.cc - "OneHot", - // tf2xla/kernels/pack_op.cc - "Pack", - // tf2xla/kernels/pad_op.cc - "Pad", "PadV2", - // tf2xla/kernels/pooling_ops.cc - "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool", - "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/ - "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad", - "AvgPool3DGrad", - // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce) - "QuantizeAndDequantizeV2", - // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend - // currently) - "RandomUniform", "RandomUniformInt", "RandomStandardNormal", - "TruncatedNormal", - // tf2xla/kernels/reduction_ops.cc (Reduce) - "Sum", "Prod", "Min", "Max", "Mean", "All", "Any", - // tf2xla/kernels/relu_op.cc - "Relu", "Relu6", "ReluGrad", "Relu6Grad", - // tf2xla/kernels/reshape_op.cc - "Reshape", - // tf2xla/kernels/reverse_op.cc - "Reverse", "ReverseV2", - // tf2xla/kernels/reverse_sequence_op.cc - "ReverseSequence", - // tf2xla/kernels/scan_ops.cc (ReduceWindow) - "Cumsum", "Cumprod", - // tf2xla/kernels/scatter_nd_op.cc (Reduce) - "ScatterNd", - // 
tf2xla/kernels/segment_reduction_ops.cc (Reduce) - "UnsortedSegmentSum", - // tf2xla/kernels/select_op.cc - "Select", - // tf2xla/kernels/sequence_ops.cc - "Range", "LinSpace", - // tf2xla/kernels/shape_op.cc - "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", - "ZerosLike", "OnesLike", - // tf2xla/kernels/slice_op.cc - "Slice", - // tf2xla/kernels/softmax_op.cc (Reduce) - "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits", - "SparseSoftmaxCrossEntropyWithLogits", - // tf2xla/kernels/spacetobatch_op.cc - "SpaceToBatchND", "SpaceToBatch", - // tf2xla/kernels/spacetodepth_op.cc - "SpaceToDepth", - // tf2xla/kernels/split_op.cc - "Split", "SplitV", - // tf2xla/kernels/stack_ops.cc - "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2", - // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on - // GPU - // backend currently) - "StatelessRandomUniform", - "StatelessRandomNormal" - // tf2xla/kernels/strided_slice_op.cc - "StridedSlice", - "StridedSliceGrad", "ResourceStridedSliceAssign", - // tf2xla/kernels/tile_ops.cc - "Tile", - // tf2xla/kernels/training_ops.cc - "ResourceApplyGradientDescent", "ResourceApplyMomentum", - "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp", - "ResourceApplyFtrl", "ResourceApplyFtrlV2", - // tf2xla/kernels/transpose_op.cc - "Transpose", "InvertPermutation", - // tf2xla/kernels/unary_ops.cc - "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", - "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", - "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", - "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", - "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", - "Square", "Tan", "Tanh", "Real", "Imag", - // tf2xla/kernels/unpack_op.cc - "Unpack"}); +StatusOr +MarkForCompilationPassImpl::RunEdgeContractionLoopInPostOrderOnce() { + bool changed = false; + // Iterating over the graph once in post-order is sufficient to produce a + // maximal clustering: + // + // A. We visit a cluster only after maximally clustering all its children. + // B. By the time we're done with `node` (in `TryToContractEdgesFrom`) all of + // its children that could have been absorbed into `node` have been + // absorbed. + // C. We have an invariant that making a cluster larger does not make edges + // leaving it more contractable. That is, if we have + // digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible + // to contract Y->Z if Y->Z was not contractible originally. + for (int32 node : cycles_graph_.AllNodesInPostOrder()) { + Cluster* cluster_from = GetClusterForCyclesGraphNode(node); + if (!cluster_from) { + continue; + } - return elementwise_ops->count(node.op()) > 0; + TF_ASSIGN_OR_RETURN(bool contracted_one_edge, + TryToContractEdgesFrom(cluster_from)); + changed |= contracted_one_edge; + } + + return changed; } -// Nodes that XLA can compile are put in `candidates`. Nodes put in -// `isolated_nodes` must either be unclustered or be put in trivial single-node -// clusters. 
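Editorial note: both contraction loops above iterate over `cycles_graph_.AllNodesInPostOrder()`. As a reference for that ordering, here is a self-contained sketch of computing a post-order over a toy adjacency-list DAG (not the real GraphCycles structure): every successor is emitted before the node itself, which is what lets the single-pass maximality argument (A)-(C) go through.

```c++
#include <cstddef>
#include <utility>
#include <vector>

// Iterative DFS post-order over a DAG given as adjacency lists.
std::vector<int> PostOrder(const std::vector<std::vector<int>>& successors) {
  const int n = static_cast<int>(successors.size());
  std::vector<int> order;
  std::vector<char> state(n, 0);  // 0 = unvisited, 1 = in progress, 2 = done.
  for (int root = 0; root < n; ++root) {
    if (state[root] != 0) continue;
    std::vector<std::pair<int, std::size_t>> stack = {{root, 0}};
    state[root] = 1;
    while (!stack.empty()) {
      auto& frame = stack.back();  // (node, index of next successor to visit)
      if (frame.second < successors[frame.first].size()) {
        const int child = successors[frame.first][frame.second++];
        if (state[child] == 0) {
          state[child] = 1;
          // `frame` may dangle after this push; it is re-read next iteration.
          stack.push_back({child, 0});
        }
      } else {
        order.push_back(frame.first);  // All successors already emitted.
        state[frame.first] = 2;
        stack.pop_back();
      }
    }
  }
  return order;
}
```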
-Status FindCompilationCandidates( - const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, - const std::function& is_compilable_fn, - OrderedNodeSet* candidates, absl::flat_hash_set* isolated_nodes) { +Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { + TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_); + edges_contracted_ = true; + + // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for + // example, from the Grappler fusion pass). + + // Run twice, first only targeted at contracting very beneficial edges then + // without restrictions. This helps to minimize data output from clusters (and + // possible transpose operations before outputs) that might occur if a + // ShapeConsumingOp is on the edge of 2 clusters due to cycle considerations. + TF_ASSIGN_OR_RETURN(bool changed, ContractPreferredEdges()); + + TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce()); + + // Check that RunEdgeContractionLoopInPostOrderOnce is idempotent. Once the + // linear time post-order scheme has been battle tested we can move this to + // happen only in debug builds. + TF_ASSIGN_OR_RETURN(changed, RunEdgeContractionLoopInPostOrderOnce()); + TF_RET_CHECK(!changed); + + return Status::OK(); +} + +std::atomic cluster_sequence_num; + +int64 GetNextClusterSequenceNumber() { return cluster_sequence_num++; } + +Status MarkForCompilationPassImpl::CreateClusters() { + TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_); + clusters_created_ = true; + + // Names for each cluster. + std::unordered_map cluster_names; + + if (debug_options_.dump_graphs) { + DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_); + } + + // Mark clusters for compilation that: + // * are placed on a device that requires compilation (an XlaDevice), + // * are explicitly marked for compilation (_XlaCompile=true), or + // * have more than debug_options_.xla_min_cluster_size elements (applicable + // only if compilation is enabled, otherwise there will be no such + // candidates). + for (Node* n : compilation_candidates_) { + Cluster* cluster = GetClusterForNode(n); + TF_ASSIGN_OR_RETURN(bool should_compile_cluster, + ShouldCompileCluster(*cluster)); + if (!should_compile_cluster) { + continue; + } + + // We assume that functional If and While nodes have at least + // min_cluster_size non-trivial nodes in them. It would be more principled + // to (recursively) verify this fact, but that's probably not worth the + // trouble. + + if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size || + cluster->has_functional_control_flow() || + cluster->is_xla_compile_attr_true()) { + string& name = cluster_names[cluster->cycles_graph_node_id()]; + + if (name.empty()) { + name = absl::StrCat("cluster_", GetNextClusterSequenceNumber()); + } + + n->AddAttr(kXlaClusterAttr, name); + n->AddAttr(kXlaAlreadyClustered, true); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; + } + } + + return Status::OK(); +} + +Status MarkForCompilationPassImpl::DumpDebugInfo() { + TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_); + + if (debug_options_.dump_graphs) { + DumpPostClusteringGraphs(); + } + + VLogClusteringSummary(); + + return Status::OK(); +} + +StatusOr +MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency( + const Cluster& cluster_from, const Cluster& cluster_to) { + // If any of the consumer's producers are on a different device, do not + // cluster these nodes. 
This prevents other work on this device from being + // delayed by work on other devices. We consider predecessors of the entire + // cluster rather than just the inputs to the node to prevent the cluster + // still being combined in cases where the 'to' cluster has multiple + // dependencies on the 'from' cluster and another dependency leads to a + // merging of the clusters. + // + // TODO(b/117085735): We probably want to handle the reciprocal of this case + // where a cluster is producing data for multiple devices. + for (const auto& in_id : + cycles_graph_.Predecessors(cluster_to.cycles_graph_node_id())) { + const Cluster* cluster_in = GetClusterForCyclesGraphNode(in_id); + if (cluster_in) { + TF_ASSIGN_OR_RETURN(bool devices_compatible, + AreDevicesCompatible(cluster_to, *cluster_in)); + if (!devices_compatible) { + return true; + } + TF_ASSIGN_OR_RETURN(devices_compatible, + AreDevicesCompatible(cluster_from, *cluster_in)); + if (!devices_compatible) { + return true; + } + } + } + + return false; +} + +absl::optional MarkForCompilationPassImpl::GetXlaScope(Node* node) { + // Look for an _XlaScope on both nodes. If both nodes have a scope and the + // scopes do not match, do not cluster along this edge. This restriction is + // overridden if the global_jit_level_ is ON. If even one of the nodes lacks + // an _XlaScope attribute, then it is treated as a "bridge" and a cluster may + // be created along it. We may want to restrict this behavior to require all + // nodes marked with _XlaCompile=true to also have a _XlaScope property set + // (and raise an error otherwise); but for now we don't do this. + if (global_jit_level_ != OptimizerOptions::OFF) { + return absl::nullopt; + } + + string scope; + if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) { + return scope; + } + + return absl::nullopt; +} + +Status MarkForCompilationPassImpl::BuildInitialClusterSet() { + auto ignore_resource_ops = [&](const Node& n, bool* ignore) { + return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore); + }; + + std::vector> unsafe_resource_deps_vect; + TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs( + *graph_, flib_def_, ignore_resource_ops, &unsafe_resource_deps_vect)); + absl::c_copy( + unsafe_resource_deps_vect, + std::inserter(unsafe_resource_deps_, unsafe_resource_deps_.begin())); + + cluster_for_node_.resize(graph_->num_node_ids()); + for (Node* node : graph_->nodes()) { + if (!IsCompilationCandidate(node)) { + cluster_for_node_[node->id()].Get() = nullptr; + continue; + } + + // We want clusters to be big enough that the benefit from XLA's + // optimizations offsets XLA related overhead (for instance we add some + // Switch/Merge nodes into the graph to implement lazy compilation). To + // this end, we don't count Identity and Constant nodes because they do not + // enable interesting optimizations by themselves. + int effective_cluster_size = + (node->IsIdentity() || node->IsConstant()) ? 0 : 1; + + bool has_functional_control_flow = + node->type_string() == "While" || node->type_string() == "If"; + + absl::optional deadness_predicate; + if (deadness_analysis_) { + TF_ASSIGN_OR_RETURN( + deadness_predicate, + deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot)); + } + + const string& device_name_str = !node->assigned_device_name().empty() + ? 
node->assigned_device_name() + : node->requested_device(); + TF_ASSIGN_OR_RETURN(DeviceId device, + device_info_cache_.GetIdFor(device_name_str)); + + bool is_resource_op = HasResourceInputOrOutput(*node); + absl::optional resource_op_device; + if (is_resource_op) { + resource_op_device = device; + } + + absl::optional resource_var_operation_node_id; + if (is_resource_op || MayCallFunction(*node, flib_def_)) { + resource_var_operation_node_id = node->id(); + } + + bool is_xla_compile_attr_true = false; + + bool xla_compile_attr; + if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) { + is_xla_compile_attr_true |= xla_compile_attr; + } + + if (flib_def_->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr).ok()) { + is_xla_compile_attr_true |= xla_compile_attr; + } + + DeviceSet devices; + devices.Insert(device); + + Cluster* new_cluster = MakeNewCluster( + /*cycles_graph_node_id=*/node->id(), + /*effective_cluster_size=*/effective_cluster_size, + /*has_functional_control_flow=*/has_functional_control_flow, devices, + resource_op_device, resource_var_operation_node_id, deadness_predicate, + /*is_xla_compile_attr_true=*/is_xla_compile_attr_true, + GetXlaScope(node)); + + cluster_for_node_[node->id()].Get() = new_cluster; + } + + return Status::OK(); +} + +Status MarkForCompilationPassImpl::FindCompilationCandidates() { OptimizerOptions opts; std::unique_ptr pflr( - new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, - flib_def, opts)); + new ProcessFunctionLibraryRuntime(nullptr, env_, TF_GRAPH_DEF_VERSION, + flib_def_, opts)); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - std::vector compile_time_const_nodes(graph.num_node_ids(), false); - TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, - &compile_time_const_nodes, lib_runtime)); - - int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + std::vector compile_time_const_nodes(graph_->num_node_ids(), false); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *graph_, /*compile_time_const_arg_indices=*/nullptr, + &compile_time_const_nodes, lib_runtime)); // Iterate over nodes in sorted order so that compiler fuel is deterministic. // We can't simply pass op_nodes().begin() and op_nodes().end to the // std::vector constructor because they're not proper iterators, with // iterator_traits defined and so on. std::vector sorted_nodes; - for (Node* node : graph.op_nodes()) { + for (Node* node : graph_->op_nodes()) { sorted_nodes.push_back(node); } std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); - if (fuel >= std::numeric_limits::max() / 2) { + if (*debug_options_.fuel >= std::numeric_limits::max() / 2) { // The assumption is that if fuel started out as INT64_MAX, it will forever // stay greater than INT64_MAX / 2. 
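Editorial note: the clustering "fuel" threaded through this loop is a debugging budget. Candidates are visited in a deterministic order and stop being admitted once the shared budget runs out, so a bad clustering decision can be bisected by varying the fuel. A minimal sketch of the same pattern with hypothetical names (`ToyNode`, `PickCandidates`):

```c++
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

struct ToyNode {
  int id;
  std::string name;
};

std::vector<ToyNode> PickCandidates(std::vector<ToyNode> nodes, int64_t fuel) {
  // Sort by id so the budget is spent in the same order on every run.
  std::sort(nodes.begin(), nodes.end(),
            [](const ToyNode& a, const ToyNode& b) { return a.id < b.id; });
  std::vector<ToyNode> picked;
  for (const ToyNode& node : nodes) {
    if (fuel <= 0) break;  // Remaining nodes stay unclustered.
    picked.push_back(node);
    --fuel;
  }
  return picked;
}
```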
VLOG(2) << "Starting fuel: infinity"; } else { - VLOG(2) << "Starting fuel: " << fuel; + VLOG(2) << "Starting fuel: " << *debug_options_.fuel; } + VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size(); + for (Node* node : sorted_nodes) { - if (fuel <= 0) { + if (*debug_options_.fuel <= 0) { VLOG(1) << "Hit fuel limit; not marking any remaining ops as clusterable."; break; } - DeviceType device_type(""); - TF_RETURN_IF_ERROR( - DeviceToDeviceType(node->assigned_device_name(), &device_type)); + TF_ASSIGN_OR_RETURN( + const DeviceType& device_type, + device_info_cache_.GetDeviceTypeFor(node->assigned_device_name())); VLOG(4) << "Device type for " << node->name() << ": " << device_type.type_string(); - if (is_compilable_fn && !is_compilable_fn(node, device_type)) { - // is_compilable_fn has already logged the reason if it returned false. + if (CompilationDisallowedByXlaCompileAttr(node)) { + VLOG(2) << "Not clustering " << node->name() + << ": disallowed by _XlaCompile attribute"; continue; } const XlaOpRegistry::DeviceRegistration* registration; - CHECK( - XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), + ®istration)) { + VLOG(2) << "Rejecting " << node->name() + << ": could not find JIT device for " << device_type.type(); + continue; + } + DeviceType jit_device_type(registration->compilation_device_name); - bool always_auto_cluster = registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways; + RecursiveCompilabilityChecker::OperationFilter op_filter = + CreateOperationFilter(*registration); - OperationFilter op_filter; - op_filter.allow_resource_ops = registration->compile_resource_ops; - op_filter.allow_stateful_rng_ops = always_auto_cluster; - op_filter.allow_control_trigger = always_auto_cluster; - op_filter.allow_dummy_ops = always_auto_cluster; - op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster; - - if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, op_filter, 0, - lib_runtime)) { - VLOG(2) << "Rejecting " << node->name() << ": unsupported op " - << node->type_string(); - continue; - } - - if (!op_filter.allow_stateful_rng_ops && - IsStatefulRandomOp(node->type_string())) { - VLOG(2) << "Rejecting " << node->name() << ": stateful random operation"; - continue; - } - if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { - VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op"; - continue; - } - if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { - VLOG(2) << "Rejecting " << node->name() << ": dummy op (" - << node->type_string() << ")"; - continue; - } - if (!op_filter.allow_ops_producing_or_consuming_variant && - OpProducesOrConsumesVariant(*node)) { - VLOG(2) << "Rejecting " << node->name() - << ": produces or consumes DT_VARIANT"; - continue; - } - - if (!op_filter.allow_resource_ops && - (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { - // We don't have a way of returning values of type DT_RESOURCE from XLA - // computations so we avoid auto-clustering nodes producing DT_RESOURCE. - // XlaLaunchOp also cannot snapshot resources that are not resource - // variables so we avoid clustering resource operations that operate on - // non-resource variables. 
- VLOG(2) << "Rejecting: " << node->name() << ": resource output " - << node->type_string(); + if (!RecursiveCompilabilityChecker{&op_filter, &jit_device_type} + .IsCompilableNode(*node, lib_runtime)) { continue; } if (compile_time_const_nodes[node->id()]) { const OpDef* op_def; TF_RETURN_IF_ERROR( - graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); + graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { // It is easiest to demonstrate the problem we're trying to solve with // an example. Say we have this graph: @@ -602,175 +1029,198 @@ Status FindCompilationCandidates( if (!is_tensor_array_or_stack_op) { VLOG(2) << "Isolating " << node->name() << ": must-be-constant stateful op"; - isolated_nodes->insert(node); - // Keep going and execute all the other checks. + continue; } } } - // We don't auto-cluster functional control flow nodes containing resource - // operations because safety checks are trickier in this case. - // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not - // for CPU/GPU. - if (node->type_string() == "While" && - !IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) { - continue; - } - // _Arg nodes in a top-level function represent feeds. - // Do not compile them. - if (node->type_string() == "_Arg") { - continue; - } - // _Retval nodes in a top-level function represent fetches. - // Do not compile them. - if (node->type_string() == "_Retval") { - continue; - } - candidates->insert(node); - --fuel; + + compilation_candidates_.insert(node); + --(*debug_options_.fuel); } - VLOG(2) << "candidates->size() = " << candidates->size(); + + VLOG(2) << "compilation_candidates_.size() = " + << compilation_candidates_.size(); return Status::OK(); } -struct Cluster { - // Identifies the node that represents this cluster in the cycle detection - // graph. - int representative = -1; +bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( + Node* node) { + if (debug_options_.ignore_xla_compile_attr) { + return false; + } - // The set of devices the nodes in this cluster are placed on. - absl::flat_hash_set devices; + // If there is a _XlaCompile annotation, use its value. + bool compile = false; + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") is false."; + } + return !compile; + } - // If there are resource operation in the cluster then this is the device that - // resource operations are placed on. All resource operations in a cluster - // must be placed on the same device. - string resource_op_device; + status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile); + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") on callee is false."; + } + return !compile; + } - // True if any node in the cluster has an _XlaCompile attribute set to true. 
- bool has_xla_compile_attr; -}; - -} // anonymous namespace - -bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { - Device* device = flr->device(); - const XlaOpRegistry::DeviceRegistration* registration; - CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), - ®istration)); - DeviceType jit_device_type(registration->compilation_device_name); - - // We can always *compile* resource operations, stateful RNGs and dummy ops, - // even if we are sometimes unable to auto-cluster them. - OperationFilter op_filter; - op_filter.allow_resource_ops = true; - op_filter.allow_stateful_rng_ops = true; - op_filter.allow_control_trigger = true; - op_filter.allow_dummy_ops = true; - op_filter.allow_ops_producing_or_consuming_variant = true; - - return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); + return false; } -Status MarkForCompilationPass::Run( - const GraphOptimizationPassOptions& options) { - // TODO(phawkins): precompute the "GetCompilationDevice" properties of each - // device ahead of time. - OptimizerOptions::GlobalJitLevel global_jit_level = - GetGlobalJitLevel(options); - MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - bool fusion_only = flags->tf_xla_fusion_only; - - VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; - VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; - const FunctionLibraryDefinition* fld = options.flib_def; - - // Deadness analysis expects a graph with source and sink edges properly - // connected but sometimes the incoming graph does not follow this invariant. - // So fix up the source and sink edges before calling into deadness analysis. - FixupSourceAndSinkEdges(options.graph->get()); - - // See explanation on `kXlaAlreadyClustered`. - for (Node* n : options.graph->get()->nodes()) { - if (n->attrs().Find(kXlaAlreadyClustered)) { - return Status::OK(); - } - } - - std::unique_ptr deadness; - { - XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); - TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); - } - - bool deadness_analysis_disabled = - GetMarkForCompilationPassFlags() - ->tf_xla_disable_deadness_safety_checks_for_debugging; - - if (deadness_analysis_disabled) { - LOG(WARNING) << "Deadness analysis was manually disabled via " - "--tf_xla_disable_deadness_safety_checks_for_debugging; " - "auto-clustering " - "is unsound!"; - } - - auto is_compilable = [&](const Node* node, const DeviceType& device_type) { - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), - ®istration)) { - VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device."; - return false; - } - - // If there is a _XlaCompile annotation, use its value. - bool compile = false; - Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); - if (status.ok()) { - if (!compile) { - VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" - << kXlaCompileAttr << ") is false."; - } - return compile; - } - - status = fld->GetAttr(*node, kXlaCompileAttr, &compile); - if (status.ok()) { - if (!compile) { - VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" - << kXlaCompileAttr << ") on callee is false."; - } - return compile; - } - - // If inputs to `node` can have conflicting deadness (i.e. some are alive - // and some are dead) then don't compile it. 
XLA cannot represent the - // deadness semantics of these nodes correctly and auto-clustering these - // nodes can cause deadness to propagate to nodes that should be live. - if (!deadness_analysis_disabled) { - if (node->IsMerge() || - deadness->HasInputsWithMismatchingDeadness(*node)) { - VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; - return false; - } - } - - // Check for fusable ops only if requested. - if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { - VLOG(2) << "Rejecting " << node->name() - << ": not fusable op but fusion_only enabled."; - return false; - } - - return true; - }; - - return RunImpl(options, is_compilable); +bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( + Cluster* from, Cluster* to, absl::string_view reason) { + VLOG(3) << EdgeContractionFailureMsg(from, to, reason); + return false; } -static string RatioToString(int numerator, int denominator) { +StatusOr MarkForCompilationPassImpl::TryToContractEdge(Cluster* from, + Cluster* to) { + DCHECK(from->deadness_predicate().has_value() == + to->deadness_predicate().has_value()); + if (from->deadness_predicate() != to->deadness_predicate()) { + VLOG(3) << EdgeContractionFailureMsg( + from, to, + absl::StrCat( + "the two nodes have mismatching deadness: ", + deadness_analysis_->DebugString(*from->deadness_predicate()), + " and ", + deadness_analysis_->DebugString(*to->deadness_predicate()))); + return false; + } + + TF_ASSIGN_OR_RETURN(bool devices_compatible, + AreDevicesCompatible(*from, *to)); + if (!devices_compatible) { + return LogNotContractableAndReturnFalse( + from, to, "the two nodes have incompatible devices"); + } + + if (from->xla_scope().has_value() && to->xla_scope().has_value() && + *from->xla_scope() != *to->xla_scope()) { + return LogNotContractableAndReturnFalse( + from, to, "the two nodes have mismatching XLA scopes"); + } + + // Don't exceed the maximum cluster size. + if (from->cluster_size() + to->cluster_size() > + debug_options_.max_cluster_size) { + return LogNotContractableAndReturnFalse( + from, to, "the new cluster will be larger than the max cluster size"); + } + + TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency, + ClusteringWillIntroduceInterDeviceDependency(*from, *to)); + + if (will_introduce_cross_device_dependency) { + return LogNotContractableAndReturnFalse( + from, to, "the new cluster will introduce a cross device dependency"); + } + + // Check if contracting this edge will break the resource variable concurrency + // semantics. In theory this is quadratic in the number of nodes, but seems + // to not be a problem in practice so far. + for (int resource_var_from : from->resource_var_operation_node_ids()) { + for (int resource_var_to : to->resource_var_operation_node_ids()) { + // If unsafe_resource_deps_ contains {A, B} then + // + // a. A and B are resource operations. + // b. A and B cannot be placed in the same cluster. + // c. There is no path from B to A in the cycles graph (but there may be + // a path from A to B). + // + // So check the legality of the edge contraction by checking if any of the + // n^2 pairs of resource variable operations are forbidden. 
+ if (unsafe_resource_deps_.contains( + {resource_var_from, resource_var_to})) { + return LogNotContractableAndReturnFalse( + from, to, + "the new cluster would break resource variable semantics"); + } + } + } + + return MergeClusters(from, to); +} + +StatusOr MarkForCompilationPassImpl::TryToContractEdgesFrom( + Cluster* cluster_from) { + bool changed = false; + + // Make a copy of the set of successors because we may modify the graph in + // TryToContractEdge. + std::vector successors_copy = + cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id()); + + for (int to : successors_copy) { + iteration_count_++; + + Cluster* cluster_to = GetClusterForCyclesGraphNode(to); + if (!cluster_to) { + continue; + } + + TF_ASSIGN_OR_RETURN(bool contracted_edge, + TryToContractEdge(cluster_from, cluster_to)); + + changed |= contracted_edge; + } + + return changed; +} + +Status MarkForCompilationPassImpl::Run() { + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + + // Start the timer after XlaOpRegistry::RegisterCompilationKernels which does + // some one-time work. + XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1); + + TF_RETURN_IF_ERROR(Initialize()); + TF_RETURN_IF_ERROR(RunEdgeContractionLoop()); + TF_RETURN_IF_ERROR(CreateClusters()); + TF_RETURN_IF_ERROR(DumpDebugInfo()); + + return Status::OK(); +} + +void MarkForCompilationPassImpl::DumpPostClusteringGraphs() { + DumpGraphToFile("mark_for_compilation", *graph_, flib_def_); + + // We also dump out an annotated version of the TF graph where the node + // names are prefixed with the cluster names. This can help visualize the + // clustering decisions on TensorBoard. + Graph new_graph(graph_->op_registry()); + CopyGraph(*graph_, &new_graph); + + for (Node* n : new_graph.nodes()) { + if (absl::optional cluster_name = + GetXlaClusterForNode(*n)) { + n->set_name(absl::StrCat(*cluster_name, "/", n->name())); + } else if (n->type_string() == "VarHandleOp") { + n->set_name(absl::StrCat("varhandle/", n->name())); + } else { + // There is room for improvement here. In particular, it may help to + // split these unclustered nodes into classes where every node in a + // specific class has edges to and from the same set of clusters.
+ n->set_name(absl::StrCat("unclustered/", n->name())); + } + } + + DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_); +} + +string RatioToString(int numerator, int denominator) { return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, (100.0 * numerator) / denominator); } -static void VLogClusteringSummary(const Graph& g) { +void MarkForCompilationPassImpl::VLogClusteringSummary() { if (!VLOG_IS_ON(2)) { return; } @@ -781,7 +1231,7 @@ static void VLogClusteringSummary(const Graph& g) { std::map unclustered_op_histogram; int clustered_node_count = 0; - for (Node* n : g.nodes()) { + for (Node* n : graph_->nodes()) { absl::optional cluster_name = GetXlaClusterForNode(*n); if (cluster_name) { clustered_node_count++; @@ -792,17 +1242,17 @@ static void VLogClusteringSummary(const Graph& g) { } } - int unclustered_node_count = g.num_nodes() - clustered_node_count; + int unclustered_node_count = graph_->num_nodes() - clustered_node_count; - VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes(); + VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes(); VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size " - << RatioToString(clustered_node_count, g.num_nodes()); + << RatioToString(clustered_node_count, graph_->num_nodes()); for (const auto& cluster_name_size_pair : cluster_name_to_size) { absl::string_view cluster_name = cluster_name_size_pair.first; int size = cluster_name_size_pair.second; VLOG(2) << " " << cluster_name << " " - << RatioToString(size, g.num_nodes()); + << RatioToString(size, graph_->num_nodes()); for (const auto& op_count_pair : cluster_name_to_op_histogram[cluster_name]) { VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second @@ -812,7 +1262,7 @@ static void VLogClusteringSummary(const Graph& g) { if (!unclustered_op_histogram.empty()) { VLOG(2) << " Unclustered nodes: " - << RatioToString(unclustered_node_count, g.num_nodes()); + << RatioToString(unclustered_node_count, graph_->num_nodes()); for (const auto& pair : unclustered_op_histogram) { VLOG(3) << " " << pair.first << ": " << pair.second << " instances"; } @@ -843,7 +1293,7 @@ static void VLogClusteringSummary(const Graph& g) { std::set cluster_names_to_print; - for (const Edge* e : g.edges()) { + for (const Edge* e : graph_->edges()) { const Node* from = e->src(); absl::optional from_cluster_name = GetXlaClusterForNode(*from); @@ -898,82 +1348,20 @@ static void VLogClusteringSummary(const Graph& g) { } } -// Is 'node' an operator that consumes only the shape of its input, not the -// data itself? -static bool IsShapeConsumerOp(const Node& node) { - return node.type_string() == "Shape" || node.type_string() == "Rank" || - node.type_string() == "Size"; -} +StatusOr MarkForCompilationPassImpl::AreDevicesCompatible( + const Cluster& cluster_a, const Cluster& cluster_b) { + DeviceSet devices = cluster_a.devices(); + devices.UnionWith(cluster_b.devices()); -static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { - // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then - // ignore it during resource operation safety analysis. We need this hack - // because of two reasons: - // - // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. - // 2. We don't support live-out values of type DT_RESOURCE and live-in values - // of type DT_RESOURCE that are not resource variables. 
- // - // Together these imply we cannot let resource variable safety analysis - // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different - // clusters: both of them will have to be clustered because of (1) and we - // won't be able to keep the edge between the two as neither the input to the - // second XLA cluster nor the output from the first XLA cluster are supported - // because of (2). - // - // TODO(b/113100872): This can be fixed if the TensorFlow representation for - // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then - // (2) would no longer hold. - - if (n.assigned_device_name().empty()) { - *ignore = false; - return Status::OK(); - } - DeviceType device_type(""); - TF_RETURN_IF_ERROR( - DeviceToDeviceType(n.assigned_device_name(), &device_type)); - - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { - *ignore = true; - } else { - *ignore = registration->compile_resource_ops; - } - return Status::OK(); -} - -// Sequence number generator to ensure clusters have unique names. -static std::atomic cluster_sequence_num; - -// Returns true if the devices in `cluster_a` and `cluster_b` are compatible and -// therefore not a hindrance for combining the two clusters into a larger -// cluster. -static Status AreDevicesCompatible( - const Cluster& cluster_a, const Cluster& cluster_b, - OptimizerOptions::GlobalJitLevel global_jit_level, bool* result) { - std::vector devices; - absl::c_remove_copy(cluster_a.devices, std::back_inserter(devices), ""); - absl::c_remove_copy(cluster_b.devices, std::back_inserter(devices), ""); - absl::c_sort(devices); - - if (devices.empty()) { - *result = false; - return Status::OK(); + TF_ASSIGN_OR_RETURN( + absl::optional maybe_chosen_device, + MaybePickDeviceForXla(device_info_cache_, devices, + /*allow_mixing_unknown_and_cpu=*/false)); + if (!maybe_chosen_device.has_value()) { + return false; } - // First check if we will even be able to pick a device for the larger - // combined cluster. - bool can_pick_device; - TF_RETURN_IF_ERROR(CanPickDeviceForXla( - devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device)); - if (!can_pick_device) { - *result = false; - return Status::OK(); - } - - string chosen_device; - TF_RETURN_IF_ERROR(PickDeviceForXla( - devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device)); + jit::DeviceId chosen_device = *maybe_chosen_device; // If we are able to pick a device `chosen_device` for the larger cluster, the // resource operations in `cluster_a` and `cluster_b` must be placed on the @@ -981,381 +1369,175 @@ static Status AreDevicesCompatible( // _XlaRun kernels are going to run on and therefore try to access the // resource variables from `chosen_device`, which will be an error if the // resource variables are placed on some other device. - auto resource_op_device_ok = [&](const string& resource_op_device) { - return resource_op_device.empty() || resource_op_device == chosen_device; - }; + auto resource_op_device_ok = + [&](absl::optional resource_op_device) { + return !resource_op_device.has_value() || + *resource_op_device == chosen_device; + }; - *result = resource_op_device_ok(cluster_a.resource_op_device) && - resource_op_device_ok(cluster_b.resource_op_device); - if (!*result) { - return Status::OK(); - } - - // We will check this again later, but here we prune out clusters that would - // never have been sent to XLA to save compile time. 
Without this change we - // will e.g. create a CPU cluster only to later notice that the user did not - // enable the CPU JIT via --tf_xla_cpu_global_jit. With this change we avoid - // creating the cluster to begin with. - // - // TODO(b/126629785): It is possible that this is just papering over O(n^2) - // behavior in our clustering algorithm. - const XlaOpRegistry::DeviceRegistration* registration; - DeviceType device_type(""); - TF_RETURN_IF_ERROR(DeviceToDeviceType(chosen_device, &device_type)); - TF_RET_CHECK( - XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) - << "chosen device = " << chosen_device - << "; device type = " << device_type.type() << "; devices (" - << devices.size() << ") = " << absl::StrJoin(devices, ", "); - - *result = cluster_a.has_xla_compile_attr || cluster_b.has_xla_compile_attr || - registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways || - (registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && - global_jit_level != OptimizerOptions::OFF); - - return Status::OK(); + return resource_op_device_ok(cluster_a.resource_op_device()) && + resource_op_device_ok(cluster_b.resource_op_device()); } // Returns `true` iff we should compile `cluster`. -static Status ShouldCompileClusterImpl( - const Cluster& cluster, OptimizerOptions::GlobalJitLevel global_jit_level, - bool* should_compile, string* device) { - std::vector devices; - absl::c_remove_copy(cluster.devices, std::back_inserter(devices), ""); - absl::c_sort(devices); +StatusOr MarkForCompilationPassImpl::ShouldCompileClusterImpl( + const Cluster& cluster) { + TF_ASSIGN_OR_RETURN(DeviceId chosen_device, + PickDeviceForXla(device_info_cache_, cluster.devices(), + /*allow_mixing_unknown_and_cpu=*/false)); - string chosen_device; - TF_RETURN_IF_ERROR(PickDeviceForXla( - devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device)); - - const XlaOpRegistry::DeviceRegistration* registration; - DeviceType device_type(""); - TF_RETURN_IF_ERROR(DeviceToDeviceType(chosen_device, &device_type)); - TF_RET_CHECK( - XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) - << "chosen device = " << chosen_device + const DeviceType& device_type = + device_info_cache_.GetDeviceTypeFor(chosen_device); + const XlaOpRegistry::DeviceRegistration* registration = + device_info_cache_.GetCompilationDevice(chosen_device); + TF_RET_CHECK(registration) + << "chosen device = " << device_info_cache_.GetNameFor(chosen_device) << "; device type = " << device_type.type() << "; devices (" - << devices.size() << ") = " << absl::StrJoin(devices, ", "); + << device_info_cache_.DebugString(cluster.devices()); - *should_compile = - cluster.has_xla_compile_attr || + bool should_compile = + cluster.is_xla_compile_attr_true() || registration->autoclustering_policy == XlaOpRegistry::AutoclusteringPolicy::kAlways || (registration->autoclustering_policy == XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && - global_jit_level != OptimizerOptions::OFF); + global_jit_level_ != OptimizerOptions::OFF); - VLOG(3) << (*should_compile ? 
"Compiling" : "Not compiling") - << " cluster with device " << chosen_device; - - *device = std::move(chosen_device); - return Status::OK(); -} - -static Status ShouldCompileCluster( - absl::flat_hash_map>* cache, - OptimizerOptions::GlobalJitLevel global_jit_level, const Cluster& cluster, - bool* should_compile, string* device) { - auto it = cache->find(cluster.representative); - if (it != cache->end()) { - *should_compile = it->second.first; - *device = it->second.second; - return Status::OK(); + if (!should_compile && + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested && + device_type.type_string() == DEVICE_CPU) { + static std::once_flag once; + std::call_once(once, [] { + LOG(WARNING) + << "(One-time warning): Not using XLA:CPU for cluster because envvar " + "TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want " + "XLA:CPU, either set that envvar, or use experimental_jit_scope " + "to enable XLA:CPU. To confirm that XLA is active, pass " + "--vmodule=xla_compilation_cache=1 (as a proper command-line " + "flag, not via TF_XLA_FLAGS) or set the envvar " + "XLA_FLAGS=--xla_hlo_profile."; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_cpu_global_jit) { + LOG(WARNING) + << "(Although the tf_xla_cpu_global_jit flag is currently enabled, " + "perhaps it wasn't enabled at process startup?)"; + } + }); } - string device_s; - TF_RETURN_IF_ERROR(ShouldCompileClusterImpl(cluster, global_jit_level, - should_compile, &device_s)); - cache->insert({cluster.representative, {*should_compile, device_s}}); - *device = std::move(device_s); - return Status::OK(); + VLOG(3) << (should_compile ? "Compiling" : "Not compiling") + << " cluster with device " + << device_info_cache_.GetNameFor(chosen_device); + + return should_compile; } -Status MarkForCompilationPass::RunImpl( +StatusOr MarkForCompilationPassImpl::ShouldCompileCluster( + const Cluster& cluster) { + auto it = should_compile_cluster_cache_.find(&cluster); + if (it != should_compile_cluster_cache_.end()) { + return it->second; + } + + TF_ASSIGN_OR_RETURN(bool should_compile, ShouldCompileClusterImpl(cluster)); + should_compile_cluster_cache_.insert({&cluster, should_compile}); + return should_compile; +} + +Status MarkForCompilation( const GraphOptimizationPassOptions& options, - const std::function& - is_compilable_fn) { - VLOG(1) << "MarkForCompilationPass::Run"; - - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - + const MarkForCompilationPassImpl::DebugOptions& debug_options) { Graph* graph = options.graph->get(); + FunctionLibraryDefinition* flib_def = options.flib_def; - OrderedNodeSet compilation_candidates; - absl::flat_hash_set isolated_nodes; - TF_RETURN_IF_ERROR(FindCompilationCandidates( - *graph, options.flib_def, - (options.session_options != nullptr) ? options.session_options->env - : Env::Default(), - is_compilable_fn, &compilation_candidates, &isolated_nodes)); + // Deadness analysis expects a graph with source and sink edges properly + // connected but sometimes the incoming graph does not follow this invariant. + // So fix up the source and sink edges before calling into deadness analysis. + FixupSourceAndSinkEdges(graph); - if (compilation_candidates.empty()) { - VLOG(2) << "No compilable candidates"; - return Status::OK(); + // See explanation on `kXlaAlreadyClustered`. 
+ for (Node* n : graph->nodes()) { + if (n->attrs().Find(kXlaAlreadyClustered)) { + return Status::OK(); + } } - GraphCycles cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); - TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( - graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); + return MarkForCompilationPassImpl{debug_options, graph, flib_def, + options.session_options != nullptr + ? options.session_options->env + : Env::Default(), + GetGlobalJitLevelForGraph(options)} + .Run(); +} - // Each compilation candidate belongs to a cluster. The cluster's - // representative - // names the node in the 'cycles' graph that represents the cluster. - std::vector> clusters(graph->num_node_ids()); - std::deque*> worklist; - for (Node* node : compilation_candidates) { - Cluster& cluster = clusters[node->id()].Get(); - cluster.representative = node->id(); - const string& device = !node->assigned_device_name().empty() - ? node->assigned_device_name() - : node->requested_device(); - if (HasResourceInput(*node) || HasResourceOutput(*node)) { - cluster.resource_op_device = device; - } - cluster.has_xla_compile_attr = false; - bool xla_compile_attr; - if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) { - cluster.has_xla_compile_attr |= xla_compile_attr; - } - if (options.flib_def->GetAttr(*node, kXlaCompileAttr, &xla_compile_attr) - .ok()) { - cluster.has_xla_compile_attr |= xla_compile_attr; - } +std::atomic* GetPointerToFuel(int64 initial_value) { + static std::atomic* fuel = [&]() { + std::atomic* fuel = new std::atomic; + *fuel = initial_value; + return fuel; + }(); - cluster.devices.insert(device); - worklist.push_back(&clusters[node->id()]); - } + return fuel; +} +} // anonymous namespace - OptimizerOptions::GlobalJitLevel global_jit_level = - GetGlobalJitLevel(options); +bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { + Device* device = flr->device(); + const XlaOpRegistry::DeviceRegistration* registration; + CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), + ®istration)); + DeviceType jit_device_type(registration->compilation_device_name); + + // We can always *compile* resource operations, stateful RNGs and dummy ops, + // even if we are sometimes unable to auto-cluster them. + RecursiveCompilabilityChecker::OperationFilter op_filter; + op_filter.allow_resource_ops_in_called_functions = true; + op_filter.allow_stack_ops = true; + op_filter.allow_tensor_array_ops = true; + op_filter.allow_stateful_rng_ops = true; + op_filter.allow_control_trigger = true; + op_filter.allow_eliding_assert_and_checknumerics_ops = true; + op_filter.allow_ops_producing_or_consuming_variant = true; + op_filter.allow_slow_and_inaccurate_ops = true; + + return RecursiveCompilabilityChecker{&op_filter, &jit_device_type} + .IsCompilableCall(ndef, flr); +} + +Status MarkForCompilationPass::Run( + const GraphOptimizationPassOptions& options) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - // Repeatedly contract edges between clusters that are on the same device, - // provided the contraction would not create a cycle. - // - // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for - // example, from the Grappler fusion pass). 
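The removed comment above describes the heart of the old clustering loop: walk a worklist and repeatedly contract edges between compatible clusters, but only when the contraction cannot introduce a cycle. As a purely illustrative, standalone sketch of that idea (a toy DAG with explicit successor sets standing in for GraphCycles; none of these names are TensorFlow APIs):

```c++
#include <iostream>
#include <set>
#include <vector>

struct Dag {
  std::vector<std::set<int>> succ;  // successor set per node id

  explicit Dag(int n) : succ(n) {}
  void AddEdge(int u, int v) { succ[u].insert(v); }

  // Returns true if `target` is reachable from `from` without taking the
  // direct edge from->target.
  bool ReachableViaOtherPath(int from, int target) const {
    std::vector<int> stack;
    std::vector<bool> seen(succ.size(), false);
    for (int s : succ[from]) {
      if (s != target) stack.push_back(s);
    }
    while (!stack.empty()) {
      int n = stack.back();
      stack.pop_back();
      if (n == target) return true;
      if (seen[n]) continue;
      seen[n] = true;
      for (int s : succ[n]) stack.push_back(s);
    }
    return false;
  }

  // Merges v into u only if that cannot introduce a cycle; returns false
  // otherwise. (For brevity this toy only rewrites v's outgoing edges; a real
  // implementation would also redirect edges that point into v.)
  bool ContractEdge(int u, int v) {
    if (ReachableViaOtherPath(u, v)) return false;  // e.g. u->x->v exists
    succ[u].erase(v);
    for (int s : succ[v]) {
      if (s != u) succ[u].insert(s);
    }
    succ[v].clear();
    return true;
  }
};

int main() {
  Dag dag(3);
  dag.AddEdge(0, 1);
  dag.AddEdge(1, 2);
  dag.AddEdge(0, 2);
  std::cout << dag.ContractEdge(0, 2) << "\n";  // 0: path 0->1->2 still exists
  std::cout << dag.ContractEdge(0, 1) << "\n";  // 1: safe to merge
  std::cout << dag.ContractEdge(0, 2) << "\n";  // 1: now safe as well
}
```

The three calls mirror the 1->2, 2->3, 1->3 example in the comments further down: the transitive edge only becomes contractible after the shorter edge has been merged.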
- while (!worklist.empty()) { - Cluster* cluster_from = &worklist.front()->Get(); - int from = cluster_from->representative; - worklist.pop_front(); + MarkForCompilationPassImpl::DebugOptions debug_options; + debug_options.ignore_deadness_checks = + flags->tf_xla_disable_deadness_safety_checks_for_debugging; + debug_options.ignore_xla_compile_attr = false; + debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; + debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; + debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); + debug_options.dump_graphs = flags->tf_xla_clustering_debug; - Node* node_from = graph->FindNodeId(from); - if (node_from->IsControlFlow()) { - // Control flow nodes aren't compilation candidates and should never - // appear. - return errors::Internal( - "Found control flow node in clustering worklist: ", - node_from->type_string()); - } - - if (isolated_nodes.count(node_from)) { - continue; - } - - string from_scope; - string to_scope; - for (int to : cycles.Successors(from)) { - if (to >= graph->num_node_ids()) { - // Node is a fictitious node that is present only in the cycle detection - // graph. No clustering is possible. - continue; - } - - const Cluster& cluster_to = clusters[to].Get(); - Node* node_to = graph->FindNodeId(to); - if (compilation_candidates.find(node_to) == - compilation_candidates.cend()) { - continue; - } - bool devices_compatible; - TF_RETURN_IF_ERROR(AreDevicesCompatible( - *cluster_from, cluster_to, global_jit_level, &devices_compatible)); - if (!devices_compatible) { - continue; - } - if (isolated_nodes.count(node_to)) { - continue; - } - // Look for an _XlaScope on both nodes. If both nodes have a - // scope and the scopes do not match, do not cluster along this - // edge. This restriction is overridden if the global_jit_level is ON. If - // even one of the nodes lacks an _XlaScope attribute, - // then it is treated as a "bridge" and a cluster may be created - // along it. We may want to restrict this behavior to require - // all nodes marked with _XlaCompile=true to also have a - // _XlaScope property set (and raise an error otherwise); but - // for now we don't do this. - if (global_jit_level == OptimizerOptions::OFF && - GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && - GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && - from_scope != to_scope) { - continue; - } - - // Ops that consume shapes cannot be the root of a cluster. This is an - // optimization. - if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { - continue; - } - - // Don't exceed the maximum cluster size. - if (clusters[from].Size() + clusters[to].Size() > - flags->tf_xla_max_cluster_size) { - continue; - } - - // If any of the consumer's producers are on a different device, do not - // cluster these nodes. This prevents other work on this device from being - // delayed by work on other devices. We consider predecessors of the - // entire cluster rather than just the inputs to the node to prevent the - // cluster still being combined in cases where the 'to' cluster has - // multiple dependencies on the 'from' cluster and another dependency - // leads to a merging of the clusters. - // - // TODO(b/117085735): We probably want to handle the reciprocal of this - // case where a cluster is producing data for multiple devices. 
- bool found_split = false; - for (const auto& in_id : cycles.Predecessors(to)) { - if (in_id >= graph->num_node_ids()) continue; - - Node* in = graph->FindNodeId(in_id); - const Cluster& cluster_in = clusters[in_id].Get(); - if (compilation_candidates.find(in) != compilation_candidates.cend()) { - bool devices_compatible; - TF_RETURN_IF_ERROR(AreDevicesCompatible( - cluster_to, cluster_in, global_jit_level, &devices_compatible)); - if (!devices_compatible) { - found_split = true; - } - } - } - if (found_split) continue; - - // If contracting the edge would create a cycle, bail out. - // However, just because we can't merge the clusters now does not mean - // we won't be able to merge them in the future. - // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge - // 1->3. But if we first contract 1->2 then we can later contract 1->3. - if (!cycles.ContractEdge(from, to)) continue; - - // Merge the clusters. ContractEdge uses 'from' as the number of the - // merged node, so make sure 'from' is the chosen representative. - cluster_from->devices.insert(cluster_to.devices.begin(), - cluster_to.devices.end()); - if (!cluster_to.resource_op_device.empty()) { - cluster_from->resource_op_device = cluster_to.resource_op_device; - } - cluster_from->has_xla_compile_attr |= cluster_to.has_xla_compile_attr; - clusters[from].Merge(&clusters[to]); - - worklist.push_back(&clusters[from]); - break; - } - } - - // Count the number of non-trivial elements in each cluster. - std::vector effective_cluster_sizes(graph->num_node_ids()); - for (const Node* n : compilation_candidates) { - int cluster = clusters[n->id()].Get().representative; - // Identity nodes will be removed if the node gets marked for compilation. - // Therefore we don't want to count them towards the effective cluster size. - if (n->def().op() != "Identity") { - effective_cluster_sizes[cluster]++; - } - } - - // Names for each cluster. - std::unordered_map cluster_names; - - if (flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph, - options.flib_def); - } - - absl::flat_hash_map> - should_compile_cluster_cache; - - // Mark clusters for compilation that: - // * are placed on a device that requires compilation (an XlaDevice), - // * are explicitly marked for compilation (_XlaCompile=true), or - // * have more than flags->tf_xla_min_cluster_size elements (applicable only - // if compilation is enabled, otherwise there will be no such candidates). - const int min_cluster_size = flags->tf_xla_min_cluster_size; - for (Node* n : compilation_candidates) { - const Cluster& cluster = clusters[n->id()].Get(); - bool should_compile; - string device; - TF_RETURN_IF_ERROR(ShouldCompileCluster(&should_compile_cluster_cache, - global_jit_level, cluster, - &should_compile, &device)); - if (!should_compile) { - continue; - } - - int cluster_repr = cluster.representative; - - // Compile if the user marked this node _XlaCompile=true - bool compile_attr = false; - bool marked_for_compilation = false; - if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) { - marked_for_compilation = compile_attr; - } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr) - .ok()) { - marked_for_compilation = compile_attr; - } - - // Compile if this is a cluster of >= min_cluster_size compilable operators. - // Also, always compile if it contains at least one op that is marked for - // compilation that is not an Identity op. 
- if (effective_cluster_sizes[cluster_repr] >= min_cluster_size || - (effective_cluster_sizes[cluster_repr] > 0 && marked_for_compilation)) { - string& name = cluster_names[cluster_repr]; - - if (name.empty()) { - name = absl::StrCat("cluster_", cluster_sequence_num++); - } - n->AddAttr(kXlaClusterAttr, name); - n->AddAttr(kXlaAlreadyClustered, true); - VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; - } - } - - if (flags->tf_xla_clustering_debug) { - dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, - options.flib_def); - - // We also dump out an annoated version of the TF graph where the nodes - // names are prefixed with the cluster names. This can help visualizing the - // clustering decisions on TensorBoard. - Graph new_graph((*options.graph)->op_registry()); - CopyGraph(**options.graph, &new_graph); - - for (Node* n : new_graph.nodes()) { - if (absl::optional cluster_name = - GetXlaClusterForNode(*n)) { - n->set_name(absl::StrCat(*cluster_name, "/", n->name())); - } else if (n->type_string() == "VarHandleOp") { - n->set_name(absl::StrCat("varhandle/", n->name())); - } else { - // There is room for improvement here. In particular, it may help to - // split these unclustered nodes into classes where every node in a - // specific class has edges to and from the same set of clusters. - n->set_name(absl::StrCat("unclustered/", n->name())); - } - } - - dump_graph::DumpGraphToFile("mark_for_compilation_annotated", new_graph, - options.flib_def); - } - - VLogClusteringSummary(*graph); - - return Status::OK(); + return MarkForCompilation(options, debug_options); } +Status MarkForCompilationPass::RunForTest( + const GraphOptimizationPassOptions& options, + bool disable_deadness_analysis) { + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + + MarkForCompilationPassImpl::DebugOptions debug_options; + debug_options.ignore_deadness_checks = disable_deadness_analysis; + debug_options.ignore_xla_compile_attr = true; + debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; + debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; + debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); + debug_options.dump_graphs = flags->tf_xla_clustering_debug; + + return MarkForCompilation(options, debug_options); +} + +namespace testing { +void ResetClusterSequenceNumber() { cluster_sequence_num = 0; } +} // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index f1137af3c1e..2eee144e645 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -41,9 +41,8 @@ class MarkForCompilationPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; private: - Status RunImpl(const GraphOptimizationPassOptions& options, - const std::function& - is_compilable_fn = {}); + Status RunForTest(const GraphOptimizationPassOptions& options, + bool disable_deadness_analysis); friend class MarkForCompilationPassTestHelper; }; @@ -52,6 +51,13 @@ class MarkForCompilationPass : public GraphOptimizationPass { // function is compilable iff every operator in the function body is // compilable. bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); + +namespace testing { +// DO NOT USE IN PRODUCTION. +// +// Resets some internal state to let us write reliable unit tests. 
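As a standalone illustration of the test-only reset hook declared below: cluster names come from a process-wide atomic counter, and resetting it between test cases keeps generated names such as cluster_0, cluster_1 deterministic. This is a simplified sketch under those assumptions, not the actual TensorFlow implementation:

```c++
#include <atomic>
#include <cstdint>
#include <iostream>
#include <string>

// Process-wide sequence number used to generate unique cluster names.
static std::atomic<int64_t> cluster_sequence_num{0};

std::string NextClusterName() {
  return "cluster_" + std::to_string(cluster_sequence_num++);
}

namespace testing {
// Test-only: reset the counter so each test case sees names starting at
// cluster_0 regardless of what ran before it.
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
}  // namespace testing

int main() {
  std::cout << NextClusterName() << "\n";  // cluster_0
  std::cout << NextClusterName() << "\n";  // cluster_1
  testing::ResetClusterSequenceNumber();
  std::cout << NextClusterName() << "\n";  // cluster_0 again
}
```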
+void ResetClusterSequenceNumber(); +} // namespace testing } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index f91ce59ad2b..b8937de4db5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/list_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" @@ -195,33 +196,43 @@ TEST(XlaCompilationTest, HalfSupported) { EXPECT_FALSE(clusters.empty()); } -TEST(XlaCompilationTest, ConcatWithConstArg) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - GraphDef graphdef; - { - Tensor t(DT_INT32, TensorShape()); - t.scalar()() = 0; - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* dim = ops::SourceOp("Const", builder.opts() - .WithName("Dim") - .WithAttr("dtype", DT_INT32) - .WithAttr("value", t)); - Node* a = ops::SourceOp("Const", builder.opts() - .WithName("A") - .WithAttr("dtype", DT_FLOAT) - .WithAttr("value", t)); +// Tests that PartitionedCalls are only marked for compilation if every node +// inside the function can be compiled. +TEST(XlaCompilationTest, PartitionedCallUnsupported) { + FunctionDef compilable = FunctionDefHelper::Define( + "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {}, + {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}}); + FunctionDef uncompilable = + FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"}, + {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}}); - NodeBuilder concat_builder("Concat", "Concat", - builder.opts().op_registry()); - concat_builder.Input(dim).Input({a, a}).Attr("N", 2); - builder.opts().FinalizeBuilder(&concat_builder); + FunctionDefLibrary flib; + *flib.add_function() = compilable; + *flib.add_function() = uncompilable; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - } + std::unique_ptr graph(new Graph(&flib_def)); + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT); - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + NameAttrList b_name_attr; + b_name_attr.set_name("CompilableFn"); + ops::PartitionedCall b(root.WithOpName("B"), {a, a}, {DT_FLOAT}, b_name_attr); + NameAttrList c_name_attr; + c_name_attr.set_name("UncompilableFn"); + + ops::PartitionedCall c(root.WithOpName("C"), {a}, {DT_FLOAT}, c_name_attr); + Output d = ops::Add(root.WithOpName("D"), b.output.front(), c.output.front()); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); auto clusters = GetClusters(*graph); - EXPECT_EQ(3, clusters.size()); // Everything should be compiled. 
+ + EXPECT_EQ(2, clusters.size()); + EXPECT_FALSE(clusters["B"].empty()); + EXPECT_TRUE(clusters["C"].empty()); + EXPECT_EQ(clusters["B"], clusters["D"]); } TEST(XlaCompilationTest, FunctionCalls) { @@ -259,36 +270,66 @@ TEST(XlaCompilationTest, FunctionCalls) { auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); - EXPECT_FALSE(clusters["B"].empty()); - EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_FALSE(clusters["C"].empty()); + EXPECT_EQ(clusters["C"], clusters["E"]); EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("B") == clusters.cend()); EXPECT_TRUE(clusters.find("D") == clusters.cend()); - EXPECT_TRUE(clusters.find("E") == clusters.cend()); } -// Metadata-only operators such as Shape/Rank/Size may not be the root of a -// cluster. This is partially to work around b/26800664, and partially because -// we should probably prefer to compile metadata operators with their producers -// wherever possible, rather than their consumers. -TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); +TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) { + FunctionDef compilable = FunctionDefHelper::Define( + "FnWithResourceOp", {"var:resource", "val:float"}, {"retval:float"}, {}, + {{{"assign_op"}, + "AssignVariableOp", + {"var", "val"}, + {{"dtype", DT_FLOAT}}}, + {{"retval"}, "Identity", {"val"}, {{"T", DT_FLOAT}}, {"assign_op"}}}); + + FunctionDefLibrary flib; + *flib.add_function() = compilable; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + std::unique_ptr graph(new Graph(&flib_def)); GraphDef graphdef; { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = - ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); - // While all of the following ops are notionally compilable, none is - // permitted - // to start a cluster. So nothing should be compiled. 
- Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B")); - Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C")); - Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D")); - ops::UnaryOp("Shape", d, builder.opts().WithName("E")); + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def); + Node* resource = + ops::SourceOp("VarHandleOp", builder.opts() + .WithName("varhandle") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("shape", TensorShape({}))); + + Tensor const_tensor(DT_FLOAT, TensorShape({})); + const_tensor.scalar()() = 42.0f; + Node* value = ops::SourceOp("Const", builder.opts() + .WithName("const") + .WithAttr("value", const_tensor) + .WithAttr("dtype", DT_FLOAT)); + + Node* call = ops::BinaryOp("FnWithResourceOp", resource, value, + builder.opts().WithName("A")); + Node* tanh0 = ops::UnaryOp("Tanh", call, builder.opts().WithName("tanh0")); + Node* tanh1 = ops::UnaryOp("Tanh", tanh0, builder.opts().WithName("tanh1")); + ops::UnaryOp("Tanh", tanh1, builder.opts().WithName("tanh2")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + testing::FindNodeByName(graph.get(), "A") + ->set_assigned_device_name(xla_cpu_device); + testing::FindNodeByName(graph.get(), "tanh0") + ->set_assigned_device_name(xla_cpu_device); + testing::FindNodeByName(graph.get(), "tanh1") + ->set_assigned_device_name(xla_cpu_device); + testing::FindNodeByName(graph.get(), "tanh2") + ->set_assigned_device_name(xla_cpu_device); + + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); auto clusters = GetClusters(*graph); - EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. 
+ + EXPECT_NE(clusters["A"], ""); } static Status GradForUnaryCwise(FunctionDef* g, @@ -459,8 +500,8 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK( - MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit())); auto clusters = GetClusters(*graph); // The computation is: C = A + relu(A) @@ -498,8 +539,8 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK( - MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit())); auto clusters = GetClusters(*graph); // The computation is: D = relu(A) + (A @ relu(A)) @@ -532,8 +573,8 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK( - MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit())); auto clusters = GetClusters(*graph); // The computation is: C = A @ relu(A) @@ -544,6 +585,77 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { EXPECT_EQ(clusters["B"], clusters["C"]); } +TEST(XlaCompilationTest, DontClusterNodesWithMismatchingDeadness) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL); + Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL); + + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + + ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a); + ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b); + + Output tanh_a0 = ops::Tanh(root.WithOpName("tan_a0"), switch_a.output_true); + Output tanh_a1 = ops::Tanh(root.WithOpName("tan_a1"), tanh_a0); + + Output tanh_b0 = ops::Tanh(root.WithOpName("tan_b0"), switch_b.output_true); + Output tanh_b1 = ops::Tanh(root.WithOpName("tan_b1"), tanh_b0); + + Output add = ops::Add(root.WithOpName("add"), tanh_a1, tanh_b1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, + MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis())); + auto clusters = GetClusters(*graph); + + EXPECT_NE(clusters["tan_a0"], ""); + EXPECT_NE(clusters["tan_a1"], ""); + EXPECT_NE(clusters["tan_b0"], ""); + EXPECT_NE(clusters["tan_b1"], ""); + + EXPECT_EQ(clusters["tan_a0"], clusters["tan_a1"]); + EXPECT_EQ(clusters["tan_b0"], clusters["tan_b1"]); + + EXPECT_NE(clusters["tan_a0"], clusters["tan_b0"]); +} + +TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL); + Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL); + + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + + ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a); + ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b); + + Output add_a = ops::Add(root.WithOpName("add_a"), 
switch_a.output_true, + switch_b.output_true); + Output add_b = ops::Add(root.WithOpName("add_b"), switch_a.output_true, + switch_b.output_true); + Output add = ops::Add(root.WithOpName("add_c"), add_a, add_b); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, + MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis())); + auto clusters = GetClusters(*graph); + + EXPECT_NE(clusters["add_a"], ""); + EXPECT_NE(clusters["add_b"], ""); + EXPECT_NE(clusters["add_c"], ""); + + EXPECT_EQ(clusters["add_a"], clusters["add_b"]); + EXPECT_EQ(clusters["add_b"], clusters["add_c"]); +} + namespace { Node* MakeRead(const Scope& scope, const string& id, Node** var_handle_op = nullptr) { @@ -606,10 +718,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); - ASSERT_EQ(cluster_sets.size(), 1); - std::vector expected_clustered_nodes = {"AssignmentW", - "ValueToAssignW"}; - ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); + ASSERT_EQ(cluster_sets.size(), 0); } TEST(XlaCompilationTest, ChainOfOps) { @@ -637,15 +746,11 @@ TEST(XlaCompilationTest, ChainOfOps) { absl::flat_hash_map> cluster_sets = GetClusterSets(*graph, &cluster_names); - ASSERT_EQ(cluster_sets.size(), 2); + ASSERT_EQ(cluster_sets.size(), 1); - std::vector expected_clustered_nodes_a = {"AssignmentW0", "ConstN0", - "ValueToAssignW0"}; + std::vector expected_clustered_nodes_a = { + "AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"}; ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); - - std::vector expected_clustered_nodes_b = { - "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"}; - ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b); } TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { @@ -704,9 +809,7 @@ TEST(XlaCompilationTest, Retval) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); - EXPECT_EQ(2, clusters.size()); - EXPECT_TRUE(clusters.find("R") == clusters.cend()); - EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_TRUE(clusters.empty()); } TEST(XlaCompilationTest, DontCountIdentityOps) { @@ -725,22 +828,6 @@ TEST(XlaCompilationTest, DontCountIdentityOps) { EXPECT_TRUE(clusters.empty()); } -TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - Scope root = Scope::NewRootScope().ExitOnError(); - { - auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); - auto b = ops::Identity(root.WithOpName("B"), a); - b.node()->AddAttr(kXlaCompileAttr, true); - auto r = ops::_Retval(root.WithOpName("R"), b, 0); - } - TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - - EXPECT_TRUE(clusters.empty()); -} - TEST(XlaCompilationTest, ConstOp) { // valid data type { @@ -996,8 +1083,10 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) { absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:XLA_GPU:1"; std::unique_ptr graph(new Graph(OpRegistry::Global())); - Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}); - Output b = ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}); + Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"), + 
ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2})); + Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"), + ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2})); Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a); Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b); @@ -1023,6 +1112,45 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) { EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]); } +TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) { + // This is similar to the 'DontClusterMergingNodes' above, except + // MatMulCombined is placed on the CPU. + Scope root = Scope::NewRootScope().ExitOnError(); + absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0"; + absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1"; + absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0"; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"), + ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2})); + Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"), + ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2})); + Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a); + Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b); + + Output combined = + ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) { + n->set_assigned_device_name(string(xla_cpu_dev0)); + } else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { + n->set_assigned_device_name(string(xla_gpu_dev0)); + } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { + n->set_assigned_device_name(string(xla_gpu_dev1)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + // Each of the MatMuls should be in a separate cluster. + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); + EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]); + EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]); + EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]); + EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]); +} + // TODO(b/117085735): This form of clustering should be prevented. 
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) { // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed @@ -1366,5 +1494,113 @@ TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) { EXPECT_EQ(clusters[resource_read_name], ""); } +TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + + Output x = ops::Add(root.WithOpName("test/x"), a, b); + Output y = ops::MatMul(root.WithOpName("test/y"), a, b); + Output z = ops::Add(root.WithOpName("test/z"), x, y); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0); + + std::vector scoped_allocator_value; + scoped_allocator_value.push_back(0); + scoped_allocator_value.push_back(155); + FindNodeByName(graph.get(), "test/z") + ->AddAttr("_scoped_allocator", scoped_allocator_value); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_EQ(clusters["test/z"], ""); +} + +TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + + Output x = ops::Add(root.WithOpName("test/x"), a, b); + Output y = ops::MatMul(root.WithOpName("test/y"), a, b); + Output z = ops::Add(root.WithOpName("test/z"), x, y); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0); + + FindNodeByName(graph.get(), "test/z")->AddAttr("_forward_from", 0); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_EQ(clusters["test/z"], ""); +} + +// Note, this relies on other implementation details to test the +// specific heuristic we care about here, so other changes might be at fault if +// this CL breaks. What we care about is that if a ShapeConsumingOp can be +// connected with a producer or consumer and cannot be clustered with both, it +// should be clustered with the producer. +TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + + Output x = ops::MatMul(root.WithOpName("test/x"), a, b); + Output y = ops::Size(root.WithOpName("test/y"), x); + Output z = ops::Add(root.WithOpName("test/z"), y, y); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + // Ensure that the "Size" op can only be clustered with either the producer or + // consumer by putting them on different devices. 
+ FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0); + FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0); + FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU1); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_NE(clusters["test/y"], ""); + EXPECT_EQ(clusters["test/x"], clusters["test/y"]); + EXPECT_NE(clusters["test/z"], clusters["test/y"]); +} + +// Test that ShapeConsuming ops are still fully clustered whenever possible. +TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + + Output x = ops::MatMul(root.WithOpName("test/x"), a, b); + Output y = ops::Size(root.WithOpName("test/y"), x); + Output z = ops::Add(root.WithOpName("test/z"), y, y); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + EXPECT_NE(clusters["test/y"], ""); + EXPECT_EQ(clusters["test/y"], clusters["test/x"]); + EXPECT_EQ(clusters["test/y"], clusters["test/z"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index 5f0ebe150fa..fa5abdfe508 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, - bool enable_global_jit) { + MarkForCompilationPassTestHelper::Options options) { // Assign all unassigned nodes to the CPU device. 
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { @@ -31,7 +31,7 @@ namespace tensorflow { } SessionOptions session_options; - if (enable_global_jit) { + if (options.enable_global_jit) { session_options.config.mutable_graph_options() ->mutable_optimizer_options() ->set_global_jit_level(OptimizerOptions::ON_2); @@ -49,13 +49,16 @@ namespace tensorflow { opt_options.session_options = &session_options; opt_options.flib_def = flib_def; MarkForCompilationPass pass; - return pass.RunImpl(opt_options); + return pass.RunForTest( + opt_options, + /*disable_deadness_analysis=*/options.disable_deadness_analysis); } /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( - std::unique_ptr* graph, bool enable_global_jit) { + std::unique_ptr* graph, + MarkForCompilationPassTestHelper::Options options) { FunctionDefLibrary flib; FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); - return MarkForCompilation(graph, &flib_def, enable_global_jit); + return MarkForCompilation(graph, &flib_def, options); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h index df751978562..b81fca43c80 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -21,16 +21,35 @@ limitations under the License. namespace tensorflow { class MarkForCompilationPassTestHelper { public: + struct Options { + bool enable_global_jit; + bool disable_deadness_analysis; + + Options() : enable_global_jit(true), disable_deadness_analysis(true) {} + + Options WithNoGlobalJit() { + Options copy = *this; + copy.enable_global_jit = false; + return copy; + } + + Options WithDeadnessAnalysis() { + Options copy = *this; + copy.disable_deadness_analysis = false; + return copy; + } + }; + // Runs the MarkForCompilation pass on `graph` after assigning all nodes in // `graph` to the CPU device. To make testing easier, ignores device - // registration, _XlaCompile attributes and input deadness. + // registration and _XlaCompile attributes. static Status MarkForCompilation(std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, - bool enable_global_jit = true); + Options options = Options()); // Like `MarkForCompilation` but creates `flib_def` from the op registry. static Status MarkForCompilation(std::unique_ptr* graph, - bool enable_global_jit = true); + Options options = Options()); }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc index c788091724e..b878f05e1df 100644 --- a/tensorflow/compiler/jit/node_matchers.cc +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -77,6 +77,8 @@ bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, } switch (tensor.dtype()) { + case DT_HALF: + return CompareTensor(tensor, expected_tensor, listener); case DT_FLOAT: return CompareTensor(tensor, expected_tensor, listener); case DT_DOUBLE: diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index ffc5d0edbcc..30ba5a56efd 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -14,9 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/partially_decluster_pass.h" + #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -49,6 +51,15 @@ Status FindNodesToDecluster(const Graph& graph, continue; } + // Assume the benefit of not outputting a larger tensor outweighs the + // benefit of this check. + // TODO(tpopp): Only apply this if the value being consumed is not output + // from the cluster to another consumer. + // TODO(tpopp): See if XlaRun can be modified to avoid this issue + // completely. + if (IsShapeConsumerOp(*n)) { + continue; + } // We assume the only XLA-auto-clusterable operations with side effects are // resource variable updates. We can't execute these twice. if (HasResourceInputOrOutput(*n)) { @@ -57,7 +68,7 @@ Status FindNodesToDecluster(const Graph& graph, DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceToDeviceType(n->assigned_device_name(), &device_type)); + DeviceNameToDeviceType(n->assigned_device_name(), &device_type)); TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, n->def(), &input_mtypes, &output_mtypes)); @@ -77,8 +88,8 @@ Status FindNodesToDecluster(const Graph& graph, } else { MemoryTypeVector dst_input_mtypes, dst_output_mtypes; DeviceType dst_device_type(""); - TF_RETURN_IF_ERROR( - DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type)); + TF_RETURN_IF_ERROR(DeviceNameToDeviceType(dst->assigned_device_name(), + &dst_device_type)); TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, dst->def(), &dst_input_mtypes, &dst_output_mtypes)); @@ -237,7 +248,7 @@ bool IsMustCompileDevice(const DeviceType& device_type) { Status MustCompileNode(const Node* n, bool* must_compile) { DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceToDeviceType(n->assigned_device_name(), &device_type)); + DeviceNameToDeviceType(n->assigned_device_name(), &device_type)); if (IsMustCompileDevice(device_type)) { *must_compile = true; @@ -340,6 +351,40 @@ Status PartiallyDeclusterGraph(Graph* graph, return Status::OK(); } } // namespace reduce_recompilation + +namespace decluster_root_shape_consumers { + +Status PartiallyDeclusterGraph(Graph* graph) { + std::vector reverse_post_order; + GetReversePostOrder(*graph, &reverse_post_order, + /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); + + for (Node* n : reverse_post_order) { + if (!IsShapeConsumerOp(*n)) { + continue; + } + + absl::optional cluster = GetXlaClusterForNode(*n); + if (!cluster.has_value()) { + continue; + } + + auto input_belongs_to_same_cluster = [&](const Edge* e) { + return cluster == GetXlaClusterForNode(*e->src()); + }; + + if (absl::c_any_of(n->in_edges(), input_belongs_to_same_cluster)) { + continue; + } + + VLOG(2) << "Declustering " << n->name() + << " because it is a root shape consumer"; + RemoveFromXlaCluster(n); + } + return Status::OK(); +} +} // namespace decluster_root_shape_consumers } // namespace Status PartiallyDeclusterPass::Run( @@ -367,6 +412,9 @@ Status PartiallyDeclusterPass::Run( TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph( graph, options.flib_def, options.session_options->env)); + TF_RETURN_IF_ERROR( + 
decluster_root_shape_consumers::PartiallyDeclusterGraph(graph)); + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 3494d0ee7ef..a9c44fb1cb7 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -40,20 +40,20 @@ limitations under the License. namespace tensorflow { namespace { -REGISTER_OP("FakeNullary").Output("out: float"); +REGISTER_OP("FakeNullary").Output("out: int32"); REGISTER_OP("FakeBinary") - .Input("host_in: float") - .Input("device_in: float") - .Output("host_out: float") - .Output("device_out: float"); + .Input("host_in: int32") + .Input("device_in: int32") + .Output("host_out: int32") + .Output("device_out: int32"); REGISTER_OP("FakeResourceVar").Output("out: resource"); REGISTER_OP("FakeResourceUpdate") .Input("in: resource") .Output("out: resource") - .Output("something_else: float"); + .Output("something_else: int32"); class FakeBinaryOp : public OpKernel { public: @@ -467,5 +467,61 @@ TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) { EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr); } +TEST(PartiallyDeclusterPassTest, MetadataOpsDontStartClusters) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0"); + + Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT); + Output b = ops::Shape(in_cluster_and.WithOpName("b"), a); + Output c = ops::Rank(in_cluster_and.WithOpName("c"), b); + Output d = ops::Size(in_cluster_and.WithOpName("d"), c); + (void)ops::Shape(in_cluster_and.WithOpName("e"), d); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + Node* n_b = FindNodeByName(*graph, "b"); + ASSERT_NE(n_b, nullptr); + EXPECT_EQ(GetXlaClusterForNode(*n_b), absl::nullopt); + + Node* n_c = FindNodeByName(*graph, "c"); + ASSERT_NE(n_c, nullptr); + EXPECT_EQ(GetXlaClusterForNode(*n_c), absl::nullopt); + + Node* n_d = FindNodeByName(*graph, "d"); + ASSERT_NE(n_d, nullptr); + EXPECT_EQ(GetXlaClusterForNode(*n_d), absl::nullopt); + + Node* n_e = FindNodeByName(*graph, "e"); + ASSERT_NE(n_e, nullptr); + EXPECT_EQ(GetXlaClusterForNode(*n_e), absl::nullopt); +} + +TEST(PartiallyDeclusterPassTest, MetaConsumersArentDeclustered) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0"); + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT); + Output b = ops::Add(in_cluster_and.WithOpName("b"), a, a); + Output c = ops::Rank(in_cluster_and.WithOpName("c"), b); + + Output e; + TF_ASSERT_OK( + CreateOutputWithScope("FakeBinary", {c, c}, root.WithOpName("e"), &e)); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + Node* n_b = FindNodeByName(*graph, "b"); + ASSERT_NE(n_b, nullptr); + EXPECT_EQ(GetXlaClusterForNode(*n_b), "cluster_0"); + + Node* n_c = FindNodeByName(*graph, "c"); + ASSERT_NE(n_c, nullptr); + EXPECT_EQ(GetXlaClusterForNode(*n_c), "cluster_0"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc new file mode 100644 index 
00000000000..fb56ff2ddf5 --- /dev/null +++ b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/match.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +TEST(RearrangeFunctionArgumentForFunctionTest, Basic) { + FunctionDefLibrary fdl; + { + // Function for StatefulPartitionedCall's "f", If's + // "then_branch"/"else_branch". + // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL) + // "ret0" = "arg1" + // "ret1" = "arg0" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1); + auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0); + auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef)); + } + { + // Function for While's "body". + // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL) + // "ret0" = "arg0" + // "ret1" = "arg1" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1); + auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg0, 0); + auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef)); + } + { + // Function for While's "cond". 
+ // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_BOOL) + // "ret0" = "arg1" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1); + auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "f3", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + // Build the XLA computation graph. + // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_INT32) + // "arg0", "arg1" -> "if" (If) -> "ret0", "ret1" + // "arg0", "arg1" -> "while" (While) -> "ret2", "ret3" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_BOOL, 1); + NameAttrList f; + f.set_name("f1"); + auto if_op = ops::If(s.WithOpName("if"), arg1, + std::initializer_list{arg0, arg1}, + {DT_BOOL, DT_RESOURCE}, f, f); + auto ret0 = ops::_Retval(s.WithOpName("ret0"), if_op.output[0], 0); + auto ret1 = ops::_Retval(s.WithOpName("ret1"), if_op.output[1], 1); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("f3"); + body_fn.set_name("f2"); + auto while_op = + ops::While(s.WithOpName("while"), + std::initializer_list{arg0, arg1}, cond_fn, body_fn); + auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2); + auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + std::vector> fbodies; + TF_CHECK_OK(RearrangeFunctionArguments( + [&](const NameAttrList &function, const FunctionBody **fbody) { + std::unique_ptr new_fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()), + AttrSlice(&function.attr()), + &fld, &new_fbody)); + *fbody = new_fbody.get(); + fbodies.push_back(std::move(new_fbody)); + return Status::OK(); + }, + g.get(), &fld)); + + // Check function f1_rearrange_0, input types should be {DT_BOOL, DT_RESOURCE} + // and output types should be {DT_BOOL}. + const FunctionDef *f1_rewritten = fld.Find("f1_rearrange_0"); + CHECK_NE(f1_rewritten, nullptr); + ASSERT_EQ(f1_rewritten->signature().input_arg_size(), 2); + EXPECT_EQ(f1_rewritten->signature().input_arg(0).type(), DT_BOOL); + EXPECT_EQ(f1_rewritten->signature().input_arg(1).type(), DT_RESOURCE); + ASSERT_EQ(f1_rewritten->signature().output_arg_size(), 1); + EXPECT_EQ(f1_rewritten->signature().output_arg(0).type(), DT_BOOL); + + // Check node "if" input and output edges. + auto node_name_index = g->BuildNodeNameIndex(); + const Node *if_node = node_name_index.at("if"); + ASSERT_NE(if_node, nullptr); + const Node *input_node; + TF_CHECK_OK(if_node->input_node(1, &input_node)); + EXPECT_EQ(input_node->name(), "arg1"); + TF_CHECK_OK(if_node->input_node(2, &input_node)); + EXPECT_EQ(input_node->name(), "arg0"); + const Node *ret0_node = node_name_index.at("ret0"); + ASSERT_NE(ret0_node, nullptr); + TF_CHECK_OK(ret0_node->input_node(0, &input_node)); + EXPECT_EQ(input_node->name(), "if"); + const Node *ret1_node = node_name_index.at("ret1"); + ASSERT_NE(ret1_node, nullptr); + TF_CHECK_OK(ret1_node->input_node(0, &input_node)); + EXPECT_EQ(input_node->name(), "arg0"); + + // Check node "while" input and output edges. 
+ const Node *while_node = node_name_index.at("while"); + ASSERT_NE(while_node, nullptr); + TF_CHECK_OK(while_node->input_node(0, &input_node)); + EXPECT_EQ(input_node->name(), "arg1"); + TF_CHECK_OK(while_node->input_node(1, &input_node)); + EXPECT_EQ(input_node->name(), "arg0"); + const Node *ret2_node = node_name_index.at("ret2"); + ASSERT_NE(ret2_node, nullptr); + TF_CHECK_OK(ret2_node->input_node(0, &input_node)); + EXPECT_EQ(input_node->name(), "arg0"); + const Node *ret3_node = node_name_index.at("ret3"); + ASSERT_NE(ret3_node, nullptr); + TF_CHECK_OK(ret3_node->input_node(0, &input_node)); + EXPECT_EQ(input_node->name(), "while"); +} + +TEST(RearrangeFunctionArgumentForFunctionTest, + WhileResourceRetvalFromDifferentArgUnimplemented) { + FunctionDefLibrary fdl; + { + // Function for While's "body". + // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32) + // "ret0" = "arg1" + // "ret1" = "arg0" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1); + Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2); + auto ret0 = ops::_Retval(s.WithOpName("ret0"), arg1, 0); + auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg0, 1); + auto ret2 = ops::_Retval(s.WithOpName("ret2"), arg2, 2); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "f2", xla_fdef)); + } + { + // Function for While's "cond". + // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32) + // "ret0" = true + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1); + Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2); + Output cond = ops::Const(s.WithOpName("const"), true, TensorShape({})); + auto ret0 = ops::_Retval(s.WithOpName("ret0"), cond, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "f1", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + // Build the XLA computation graph. 
+ // "arg0" (T=DT_RESOURCE), "arg1" (T=DT_RESOURCE), "arg2" (T=DT_INT32) + // "arg0", "arg1" -> "while" (While) + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_RESOURCE, 0); + Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_RESOURCE, 1); + Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("f1"); + body_fn.set_name("f2"); + auto while_op = ops::While(s.WithOpName("while"), + std::initializer_list{arg0, arg1, arg2}, + cond_fn, body_fn); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + std::vector> fbodies; + Status status = RearrangeFunctionArguments( + [&](const NameAttrList &function, const FunctionBody **fbody) { + std::unique_ptr new_fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()), + AttrSlice(&function.attr()), + &fld, &new_fbody)); + *fbody = new_fbody.get(); + fbodies.push_back(std::move(new_fbody)); + return Status::OK(); + }, + g.get(), &fld); + EXPECT_EQ(status.code(), error::UNIMPLEMENTED); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index c0897217bcb..fc2f69e2ad3 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -84,6 +84,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" @@ -93,22 +94,6 @@ limitations under the License. namespace tensorflow { namespace { -// Returns true if `n` may call a function. -Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def, - bool* out_result) { - if (flib_def->Contains(n.type_string())) { - *out_result = true; - } else { - *out_result = - std::any_of(n.def().attr().begin(), n.def().attr().end(), - [](const std::pair& name_attr_pair) { - return name_attr_pair.second.has_func(); - }); - } - - return Status::OK(); -} - // Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is // not a resource operation recognized by XLA then sets `out_resource_op_kind` // to nullopt. @@ -134,9 +119,7 @@ Status XlaResourceOpKindForNode( // We conservatively assume that functions will both read and write resource // variables. In the future we may consider doing some form of // inter-procedural analysis. - bool may_call_function; - TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function)); - if (may_call_function) { + if (MayCallFunction(n, flib_def)) { *out_resource_op_kind = XlaResourceOpKind::kReadWrite; } else { *out_resource_op_kind = absl::nullopt; diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index a27e0d9f2a6..a9c53a943be 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -16,10 +16,10 @@ limitations under the License. 
#include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index eaa7015768c..063bb9c26a3 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -21,10 +21,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -84,15 +84,6 @@ bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); } } // namespace -Status DeviceToDeviceType(const string& device, DeviceType* device_type) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device, &parsed)) { - return errors::Internal("Malformed assigned device '", device, "'"); - } - *device_type = DeviceType(parsed.type); - return Status::OK(); -} - bool HasForwardedRefInput(const Node& node) { if (AlwaysForwardsRefInput(node)) { for (const Edge* incoming_edge : node.in_edges()) { @@ -111,7 +102,8 @@ bool HasForwardedRefInput(const Node& node) { return false; } -Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { +xla::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* cycles) { for (int i = 0; i < graph->num_node_ids(); ++i) { // We rely on the node IDs in the cycle detection graph being consecutive // integers starting from 0. @@ -174,9 +166,11 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { } if (!cycles->InsertEdge(src, dst)) { - return errors::Internal( - "Cycle detected when adding ", src_type, "->", dst_type, - " edge: ", DescribeCycle(cycles, *graph, src, dst)); + // TODO(b/127521408): We can probably handle this situation with a more + // sophisticated SCC based algorithm, but for now we bail out. + VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type + << " edge: " << DescribeCycle(cycles, *graph, src, dst); + return false; } // Drop the original edge. 
continue; @@ -194,7 +188,8 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); } } - return Status::OK(); + + return true; } absl::optional GetXlaClusterForNode(const Node& node) { @@ -222,148 +217,105 @@ void RemoveFromXlaCluster(NodeDef* node_def) { void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); } -Status AdjustCycleDetectionGraphForResourceOps( - const Graph* graph, const FunctionLibraryDefinition* flib_def, - const std::function& resource_ops_to_ignore, - GraphCycles* cycles) { - std::vector> unsafe_deps; - TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs( - *graph, flib_def, resource_ops_to_ignore, &unsafe_deps)); +namespace { +struct XlaGlobalJitLevel { + OptimizerOptions::GlobalJitLevel single_gpu; + OptimizerOptions::GlobalJitLevel general; +}; - // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are - // operations that interact with resource variables, must not be put in the - // same cluster. We enforce this constraint by creating a phantom node, X, - // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P - // and Q together since that would create a cycle with X. - - for (std::pair unsafe_dep : unsafe_deps) { - int phantom_node_id = cycles->NewNode(); - CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id)); - CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second)); - } - return Status::OK(); -} - -Status PickDeviceForXlaImpl(absl::Span device_names, - bool allow_mixing_unknown_and_cpu, - bool* out_can_pick_device, - string* out_device_picked) { - if (out_can_pick_device) { - *out_can_pick_device = true; - } - -#define FAILED_TO_PICK_DEVICE(failing_status) \ - do { \ - if (out_can_pick_device) { \ - *out_can_pick_device = false; \ - return Status::OK(); \ - } else { \ - return failing_status; \ - } \ - } while (false) - - TF_RET_CHECK(!device_names.empty()) << "No devices to choose from"; - DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr); - - absl::flat_hash_set device_names_set; - for (absl::string_view device_name : device_names) { - if (!device_name.empty()) { - device_names_set.insert(device_name); - } - } - - absl::optional maybe_gpu_device; - absl::optional maybe_cpu_device; - absl::optional maybe_unknown_device; - - for (absl::string_view device_name : device_names_set) { - DeviceNameUtils::ParsedName parsed_name; - TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name)) - << device_name; - if (parsed_name.type == "GPU") { - if (maybe_gpu_device) { - FAILED_TO_PICK_DEVICE(errors::Internal( - "Multiple GPU devices ", absl::StrJoin(device_names, ", "))); - } - maybe_gpu_device = device_name; - } else if (parsed_name.type == "CPU") { - if (maybe_cpu_device) { - FAILED_TO_PICK_DEVICE(errors::Internal( - "Multiple CPU devices ", absl::StrJoin(device_names, ", "))); - } - maybe_cpu_device = device_name; - } else { - if (maybe_unknown_device) { - FAILED_TO_PICK_DEVICE(errors::Internal( - "Multiple unknown devices ", absl::StrJoin(device_names, ", "))); - } - maybe_unknown_device = device_name; - } - } - - if (maybe_unknown_device && maybe_gpu_device) { - FAILED_TO_PICK_DEVICE(errors::Internal( - "Found both unknown and GPU devices: ", *maybe_unknown_device, ", ", - *maybe_gpu_device)); - } - - if (!allow_mixing_unknown_and_cpu) { - if (maybe_unknown_device && maybe_cpu_device) { - FAILED_TO_PICK_DEVICE(errors::Internal( - "Found both unknown and CPU 
devices: ", *maybe_unknown_device, ", ", - *maybe_cpu_device)); - } - } - - if (out_device_picked) { - if (maybe_gpu_device) { - *out_device_picked = string(*maybe_gpu_device); - } else if (maybe_unknown_device) { - *out_device_picked = string(*maybe_unknown_device); - } else { - *out_device_picked = string(*maybe_cpu_device); - } - } - - return Status::OK(); - -#undef FAILED_TO_PICK_DEVICE -} - -Status PickDeviceForXla(absl::Span device_names, - bool allow_mixing_unknown_and_cpu, - string* out_device_picked) { - return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu, - /*out_can_pick_device=*/nullptr, - out_device_picked); -} - -Status CanPickDeviceForXla(absl::Span device_names, - bool allow_mixing_unknown_and_cpu, - bool* out_can_pick_device) { - return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu, - out_can_pick_device, - /*out_device_picked=*/nullptr); -} - -OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( +XlaGlobalJitLevel GetXlaGlobalJitLevel( const GraphOptimizationPassOptions& options) { - OptimizerOptions::GlobalJitLevel global_jit_level = + XlaGlobalJitLevel result; + + OptimizerOptions::GlobalJitLevel jit_level_in_session_opts = options.session_options->config.graph_options() .optimizer_options() .global_jit_level(); - if (global_jit_level == OptimizerOptions::DEFAULT) { + if (jit_level_in_session_opts == OptimizerOptions::DEFAULT) { // To set compilation to be on by default, change the following line. - global_jit_level = OptimizerOptions::OFF; + result.single_gpu = result.general = OptimizerOptions::OFF; + } else { + result.single_gpu = result.general = jit_level_in_session_opts; } + + // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides + // the setting in ConfigProto. MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - if (flags->tf_xla_auto_jit != OptimizerOptions::DEFAULT) { - // If the flag tf_xla_auto_jit is a valid, non-DEFAULT setting, it overrides - // the setting in ConfigProto. - global_jit_level = - static_cast(flags->tf_xla_auto_jit); + if (flags->xla_auto_jit_flag.optimization_level_single_gpu != + OptimizerOptions::DEFAULT) { + result.single_gpu = static_cast( + flags->xla_auto_jit_flag.optimization_level_single_gpu); } - return global_jit_level; + if (flags->xla_auto_jit_flag.optimization_level_general != + OptimizerOptions::DEFAULT) { + result.general = static_cast( + flags->xla_auto_jit_flag.optimization_level_general); + } + + return result; } +int GetGpuNumber(const string& device_name) { + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) { + return -1; + } + + return parsed_name.type == DEVICE_GPU ? 
parsed_name.id : -1; +} +} // namespace + +bool IsSingleGpuGraph(const Graph& g) { + int gpus_seen = 0; + absl::flat_hash_set devices_seen; + + for (Node* n : g.op_nodes()) { + if (devices_seen.contains(n->assigned_device_name())) { + continue; + } + + int gpu_number = GetGpuNumber(n->assigned_device_name()); + if (gpu_number != -1) { + if (++gpus_seen > 1) { + return false; + } + } + + devices_seen.insert(n->assigned_device_name()); + } + + return gpus_seen == 1; +} + +OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph( + const GraphOptimizationPassOptions& options) { + XlaGlobalJitLevel xla_global_jit_level = GetXlaGlobalJitLevel(options); + if (xla_global_jit_level.single_gpu == xla_global_jit_level.general) { + VLOG(4) << "GetGlobalJitLevelForGraph returning " + << xla_global_jit_level.single_gpu; + return xla_global_jit_level.single_gpu; + } + OptimizerOptions::GlobalJitLevel result = + IsSingleGpuGraph(**options.graph) ? xla_global_jit_level.single_gpu + : xla_global_jit_level.general; + VLOG(4) << "GetGlobalJitLevelForGraph returning " << result; + return result; +} + +bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) { + if (flib_def->Contains(n.type_string())) { + return true; + } + + // This is a conservative check: there may be nodes with a `func` + // attribute that do not make function calls. + return absl::c_any_of(n.def().attr(), + [](const std::pair& name_attr_pair) { + return name_attr_pair.second.has_func(); + }); +} +bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "Rank" || + node.type_string() == "Size"; +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index ddca0aaeabb..657075caf4d 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,8 +20,10 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -44,16 +46,17 @@ extern const char* const kXlaCompileTimeConstantInputsAttr; using OrderedNodeSet = std::set; -// Returns the DeviceType corresponding to 'device'. -Status DeviceToDeviceType(const string& device, DeviceType* device_type); - // Returns true if `node` has a ref tensor input that it forwards to its output. bool HasForwardedRefInput(const Node& node); // Creates a graph representation to enable cycle detection when clustering. // This representation handles loops in graph by disconnecting each loop from // the enclosing graph. -Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// +// Returns true for success and false for valid graphs that we can't handle yet +// (b/127521408). +xla::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. @@ -68,64 +71,22 @@ void RemoveFromXlaCluster(Node* node); // Returns true if `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node); -// Adds edges to `cycles` to prevent clustering resource operations that cannot -// be legally clustered. 
-Status AdjustCycleDetectionGraphForResourceOps( - const Graph* graph, const FunctionLibraryDefinition* flib_def, - const std::function& resource_ops_to_ignore, - GraphCycles* cycles); - -// Picks the device for which XLA should compile a cluster that contains -// operations placed in devices in `device_names`. For instance a cluster that -// contains operations solely placed on the CPU will be compiled into a CPU -// executable by XLA, whereas a cluster that contains operations placed on the -// CPU and also operations placed on the GPU will be compiled into a GPU -// executable. -// -// Returns a non-OK Status if no unambiguous choice of device exists. -// -// We choose the device using the following rules: -// -// - It is an error for `device_names` to contain more than one device of the -// same type. -// - GPU is preferred over CPU. -// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are -// preferred over CPU. -// - XLA devices count as "unrecognized devices". -// -// This set of rules above implicitly assume that XLA:GPU can compile all -// operations in the cluster that XLA:CPU can compile, and if -// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile -// all operations in the cluster that XLA:CPU can compile. -// -// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of -// the following things: -// -// - Let MarkForCompilationPass not inject CPU-placed operations into clusters -// that will run on unknown devices (because the unknown XLA backend may not -// support every operation supported by CPU). -// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster -// that contains nodes placed on both the CPU and on unknown devices. In this -// case it is the responsibility of the optimization pass that injected the -// CPU nodes into the cluster to ensure that these nodes can be compiled by -// the unknown XLA backend. -Status PickDeviceForXla(absl::Span device_names, - bool allow_mixing_unknown_and_cpu, - string* out_device_picked); - -// This is like `PickDeviceForXla` except that it returns false (instead of a -// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device -// exists. -Status CanPickDeviceForXla(absl::Span device_names, - bool allow_mixing_unknown_and_cpu, - bool* out_can_pick_device); - -// Determine the global jit level which is ON if either the -// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag -// is true. -OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( +// Determines the global jit level based on GraphOptimizationPassOptions, +// --tf_xla_auto_jit and whether the graph is a single GPU graph. +OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph( const GraphOptimizationPassOptions& options); +// Returns true if `g` is a single-GPU graph. A single-GPU graph uses exactly +// one GPU (and any number of CPUs). +bool IsSingleGpuGraph(const Graph& g); + +// Returns true if it is possible (but not guaranteed) that `n` calls a +// function. +bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def); + +// Returns true if `node` is an operator that consumes only the shape of its input, +// not the data itself.
+bool IsShapeConsumerOp(const Node& node); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 68fb4da134e..571d247c39b 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/algorithm.h" @@ -44,7 +45,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) { FixupSourceAndSinkEdges(root.graph()); GraphCycles cycles; - TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); } @@ -63,70 +64,71 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { FixupSourceAndSinkEdges(root.graph()); GraphCycles cycles; - TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); } -void CheckPickDeviceResult(absl::string_view expected_result, - bool allow_mixing_unknown_and_cpu, - absl::Span inputs) { - std::vector inputs_string; - absl::c_transform(inputs, std::back_inserter(inputs_string), - [](absl::string_view sv) { return string(sv); }); - string result; - TF_ASSERT_OK( - PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, &result)) - << "inputs = [" << absl::StrJoin(inputs, ", ") - << "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu - << ", expected_result=" << expected_result; - EXPECT_EQ(result, expected_result); -} +TEST(CreateCycleDetectionGraph, ReachingEnterExit) { + // TODO(b/127521408): We can lift this limitation with some work. 
+ Scope root = Scope::NewRootScope().ExitOnError(); -void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu, - absl::Span inputs) { - std::vector inputs_string; - absl::c_transform(inputs, std::back_inserter(inputs_string), - [](absl::string_view sv) { return string(sv); }); - string result; - EXPECT_FALSE( - PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, &result) - .ok()); + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter_0 = + ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0"); + Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0); + + Output add = ops::Add(root.WithOpName("add"), exit_0, exit_0); + + Output enter_1 = + ops::internal::Enter(root.WithOpName("enter_1"), add, "frame_0"); + Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK_AND_ASSIGN(bool ok, + CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(ok); } const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0"; const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0"; -const char* kXPU0 = "/job:localhost/replica:0/task:0/device:XPU:0"; - -const char* kCPU1 = "/job:localhost/replica:0/task:0/device:CPU:1"; const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1"; -const char* kXPU1 = "/job:localhost/replica:0/task:0/device:XPU:1"; -TEST(PickDeviceForXla, UniqueDevice) { - CheckPickDeviceResult(kGPU0, false, {kGPU0, kGPU0}); +TEST(IsSingleGpuGraph, ReturnsTrue) { + Scope root = Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output b = ops::Add(root.WithOpName("b"), a, a); + Output c = ops::Add(root.WithOpName("c"), b, b); + + FixupSourceAndSinkEdges(root.graph()); + + EXPECT_TRUE(IsSingleGpuGraph(*root.graph())); } -TEST(PickDeviceForXla, DeviceOrder) { - CheckPickDeviceResult(kGPU0, false, {kGPU0, kCPU0}); - CheckPickDeviceResult(kXPU0, true, {kXPU0, kCPU0}); +TEST(IsSingleGpuGraph, ReturnsFalseForCpuGraph) { + Scope root = Scope::NewRootScope().WithAssignedDevice(kCPU0).ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output b = ops::Add(root.WithOpName("b"), a, a); + Output c = ops::Add(root.WithOpName("c"), b, b); + + FixupSourceAndSinkEdges(root.graph()); + + EXPECT_FALSE(IsSingleGpuGraph(*root.graph())); } -TEST(PickDeviceForXla, MultipleUnknownDevices) { - CheckPickDeviceHasError(false, {kXPU0, kXPU1}); -} +TEST(IsSingleGpuGraph, ReturnsFalseForMultiGpuGraph) { + Scope root = Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError(); -TEST(PickDeviceForXla, GpuAndUnknown) { - CheckPickDeviceHasError(false, {kGPU0, kXPU1}); -} + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output b = ops::Add(root.WithOpName("b").WithAssignedDevice(kGPU1), a, a); + Output c = ops::Add(root.WithOpName("c"), b, b); -TEST(PickDeviceForXla, UnknownAndCpu) { - CheckPickDeviceHasError(false, {kXPU0, kCPU1}); -} + FixupSourceAndSinkEdges(root.graph()); -TEST(PickDeviceForXla, MultipleDevicesOfSameType) { - CheckPickDeviceHasError(false, {kCPU0, kCPU1}); - CheckPickDeviceHasError(false, {kGPU0, kGPU1}); - CheckPickDeviceHasError(false, {kXPU0, kXPU1}); - CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0}); + EXPECT_FALSE(IsSingleGpuGraph(*root.graph())); } } // namespace } // namespace tensorflow diff --git 
a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 9f958463076..f53a1e5d403 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -35,6 +34,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -205,6 +205,10 @@ Status XlaCompilationCache::CompileSingleOp( NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); + // Remove the "_class" attribute from the attribute set used to create the + // compilation cache key. This attribute is information for the colocator + // and causes false uniqueness between nodes. + name.mutable_attr()->erase("_class"); auto compile_op = [&](XlaCompiler* compiler, XlaCompiler::CompilationResult* result) { std::vector result_dtypes(ctx->num_outputs()); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 345e87a5735..19e3793f29b 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -30,10 +30,17 @@ namespace tensorflow { class XlaCpuDeviceFactory : public DeviceFactory { public: + Status ListPhysicalDevices(std::vector* devices) override; Status CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector>* devices) override; }; +Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { + devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0")); + + return Status::OK(); +} + Status XlaCpuDeviceFactory::CreateDevices( const SessionOptions& session_options, const string& name_prefix, std::vector>* devices) { @@ -46,7 +53,14 @@ Status XlaCpuDeviceFactory::CreateDevices( compile_on_demand ? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested : XlaOpRegistry::AutoclusteringPolicy::kAlways; - registration.compile_resource_ops = true; + registration.cluster_resource_variable_ops_unsafely = true; + registration.cluster_stack_ops = false; + registration.cluster_tensor_array_ops = true; + registration.cluster_stateful_rng_ops = true; + registration.cluster_control_trigger = true; + registration.elide_assert_and_checknumerics = true; + registration.cluster_variant_ops = true; + registration.cluster_slow_and_inaccurate_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); static XlaDeviceOpRegistrations* registrations = diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 56c4220f12b..a697246d1c7 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include + #include #include "absl/memory/memory.h" @@ -23,7 +24,6 @@ limitations under the License. 
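The CompileSingleOp change above drops the "_class" (colocation) attribute before the node's attributes are folded into the compilation cache key, since two otherwise identical nodes that differ only in colocation hints should hit the same cache entry. A simplified, hypothetical key builder illustrating the effect (not the real XlaCompilationCache key):

```cpp
// Standalone sketch: why "_class" is excluded from the cache signature.
#include <cstdio>
#include <map>
#include <string>

std::string MakeCompilationKey(const std::string& op,
                               std::map<std::string, std::string> attrs) {
  attrs.erase("_class");  // Colocation info; irrelevant to the compiled result.
  std::string key = op;
  for (const auto& kv : attrs) key += "|" + kv.first + "=" + kv.second;
  return key;
}

int main() {
  std::string a =
      MakeCompilationKey("MatMul", {{"T", "float"}, {"_class", "loc:@v"}});
  std::string b = MakeCompilationKey("MatMul", {{"T", "float"}});
  std::printf("%d\n", a == b);  // 1: colocation hints no longer split the cache.
}
```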
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -48,9 +48,11 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" #include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/stream_executor_util.h" @@ -290,17 +292,17 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", &host_to_device_stream_, &need_new_device_context)); - TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", - &device_to_host_stream_, - &need_new_device_context)); for (std::shared_ptr& stream : device_to_device_streams_) { TF_RETURN_IF_ERROR( EnsureStreamOkLocked(backend, "device_to_device_stream", &stream, &need_new_device_context)); } host_to_device_stream = host_to_device_stream_; - device_to_host_stream = device_to_host_stream_; device_to_device_streams = device_to_device_streams_; + // The data transfer requests from device to host could arrive out of order, + // so a single stream would cause deadlock. For this case, + // xla_device_context would borrow a stream for each transfer request. + device_to_host_stream = nullptr; } else { host_to_device_stream = stream_; device_to_host_stream = stream_; @@ -380,14 +382,17 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); - tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); + profiler::TraceMe activity( + [&] { + return absl::StrCat(op_kernel->name(), ":", op_kernel->type_string()); + }, + profiler::GetTFTraceMeLevel(op_kernel->IsExpensive())); op_kernel->ComputeAsync(context, done); } Status XlaDevice::Sync() { VLOG(1) << "XlaDevice::Sync"; - tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true); + profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo); std::shared_ptr stream; { mutex_lock lock(mu_); @@ -428,13 +433,12 @@ void XlaDevice::Sync(const DoneCallback& done) { // that everything enqueued onto the stream (i.e., the device) at this very // moment--when ThenEnqueueOnBackgroundThread is called--will have finished. // This achieves a device-wide sync. - stream->ThenEnqueueOnBackgroundThread( - [stream, done](se::StreamExecutor*) { - tracing::ScopedActivity activity("XlaDevice::Sync::Callback", - /*is_expensive=*/true); - done(stream->ok() ? Status::OK() - : errors::Internal("XlaDevice::Sync() failed.")); - }); + stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) { + profiler::TraceMe activity("XlaDevice::Sync::Callback", + profiler::TraceMeLevel::kInfo); + done(stream->ok() ? 
Status::OK() + : errors::Internal("XlaDevice::Sync() failed.")); + }); } Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, @@ -458,11 +462,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Allocator* allocator = GetAllocatorLocked(alloc_attrs); Tensor copy(allocator, parsed.dtype(), parsed.shape()); Notification n; - device_context->CopyCPUTensorToDevice(&parsed, this, ©, - [&n, &status](const Status& s) { - status = s; - n.Notify(); - }); + device_context->CopyCPUTensorToDevice( + &parsed, this, ©, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }, + true /*sync_dst_compute*/); n.WaitForNotification(); *tensor = copy; } @@ -519,7 +525,7 @@ Status XlaDevice::RefreshStatus() { XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter - // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes + // gets executed by an XlaCompileOnDemandOp, which compiles it and executes // it just-in-time. OpKernel* (*factory)(OpKernelConstruction*) = [](OpKernelConstruction* context) -> OpKernel* { diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 977f5f5cf15..51910c6fabc 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -212,14 +212,12 @@ class XlaDevice : public LocalDevice { std::shared_ptr stream_ GUARDED_BY(mu_); // If false, only stream_ is valid and all computation and transfers use // stream_. If true, computation is performed by stream_ and transfers are - // performed by host_to_device/device_to_host_stream. + // performed by host_to_device/device_to_device stream or borrowing a stream + // for each device to host transfer. const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. std::shared_ptr host_to_device_stream_ GUARDED_BY(mu_); - // If use_multiple_streams_, device to host transfers are performed using this - // stream. - std::shared_ptr device_to_host_stream_ GUARDED_BY(mu_); // If use_multiple_streams_, transfers between different devices are performed // using these streams. std::vector> device_to_device_streams_ diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 05b9c511866..ea784e72137 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -25,6 +25,7 @@ limitations under the License. 
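ComputeAsync now builds its trace name inside a lambda handed to profiler::TraceMe, so the string concatenation is only paid for when tracing at that level is actually enabled. A minimal standalone model of that deferred-name pattern (not the real TraceMe API; the level convention here is an assumption):

```cpp
// Simplified model of lazy trace-name construction plus RAII scope timing.
#include <chrono>
#include <cstdio>
#include <functional>
#include <string>

class ScopedTraceModel {
 public:
  ScopedTraceModel(std::function<std::string()> name_fn, int level) {
    if (level <= enabled_level_) {       // Assumed: levels <= 1 are recorded.
      name_ = name_fn();                 // Build the name only when tracing.
      start_ = std::chrono::steady_clock::now();
      active_ = true;
    }
  }
  ~ScopedTraceModel() {
    if (!active_) return;
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - start_)
                  .count();
    std::printf("%s took %lld us\n", name_.c_str(), static_cast<long long>(us));
  }

 private:
  static constexpr int enabled_level_ = 1;
  bool active_ = false;
  std::string name_;
  std::chrono::steady_clock::time_point start_;
};

int main() {
  ScopedTraceModel trace([] { return std::string("MyKernel:MatMul"); }, 1);
  // ... kernel work would happen here ...
}
```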
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/stream_executor/platform/port.h" namespace tensorflow { @@ -64,6 +65,9 @@ absl::optional XlaDeviceAllocator::GetStats() { tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use; tf_stats.largest_alloc_size = se_stats->largest_alloc_size; tf_stats.bytes_limit = se_stats->bytes_limit; + tf_stats.bytes_reserved = se_stats->bytes_reserved; + tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved; + tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit; return tf_stats; } @@ -84,7 +88,6 @@ XlaDeviceContext::XlaDeviceContext( shape_representation_fn_(std::move(shape_representation_fn)), thread_pool_(thread_pool) { CHECK(host_to_device_stream_ != nullptr); - CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); if (!shape_representation_fn_) { shape_representation_fn_ = [](const TensorShape& shape, @@ -106,7 +109,8 @@ void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, - StatusCallback done) const { + StatusCallback done, + bool sync_dst_compute) const { if (cpu_tensor->NumElements() == 0) { VLOG(2) << "CopyCPUTensorToDevice empty tensor"; done(Status::OK()); @@ -213,8 +217,23 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, << cpu_tensor->shape().DebugString() << " " << device_tensor->shape().DebugString(); + std::shared_ptr device_to_host_stream; + if (device_to_host_stream_) { + device_to_host_stream = device_to_host_stream_; + } else { + stream_executor::port::StatusOr ptr_or_status = + client_->mutable_backend()->BorrowStream( + stream_->parent()->device_ordinal()); + if (!ptr_or_status.status().ok()) { + done(ptr_or_status.status()); + return; + } + device_to_host_stream = + std::shared_ptr(std::move(ptr_or_status.ValueOrDie())); + } + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); - xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get()); + xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream.get()); // Transfer manager requires the shape of the shaped buffer to be the same as // literal shape except for the layout. Set the literal to use xla_tensor's @@ -227,14 +246,25 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, cpu_tensor, &literal)); TensorReference ref(*device_tensor); + const bool device_allows_sync_on_completion = + device->AllowsSyncOnCompletion(); + // Explicitly capture device_to_host_stream to make sure the stream is alive + // before the transfer finishes. transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal, - [ref, xla_tensor, done](xla::Status status) { - done([&]() -> Status { - VLOG(2) << "Transfer from device as literal: " - << xla_tensor->shaped_buffer().ToString(); - return status; - }()); + device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal, + [ref, xla_tensor, done, device_to_host_stream, + device_allows_sync_on_completion](xla::Status status) { + Status done_status = status; + VLOG(2) << "Transfer from device as literal: " + << xla_tensor->shaped_buffer().ToString(); + // For devices don't allow sync on completion, the device execution is + // deferred. 
We check the execution stream status here to avoid wrong + // results from a failed stream being propogated to following + // host-side ops. + if (!device_allows_sync_on_completion) { + done_status.Update(xla_tensor->RefreshStatusOfStreams()); + } + done(done_status); ref.Unref(); }); } diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1ce64ad323b..3b9c4160b95 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -61,8 +61,8 @@ class XlaDeviceContext : public DeviceContext { thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, - Tensor* device_tensor, - StatusCallback done) const override; + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; @@ -75,9 +75,6 @@ class XlaDeviceContext : public DeviceContext { se::Stream* host_to_device_stream() const { return host_to_device_stream_.get(); } - se::Stream* device_to_host_stream() const { - return device_to_host_stream_.get(); - } se::Stream* device_to_device_stream(int index) const { return device_to_device_streams_.at(index).get(); } @@ -99,7 +96,8 @@ class XlaDeviceContext : public DeviceContext { // idential to stream_, but must not be nullptr. std::shared_ptr host_to_device_stream_; // The stream to use for transferring data from device to host. Can be - // idential to stream_, but must not be nullptr. + // idential to stream_. If nullptr, borrow a stream from backend for each + // transfer request to support out-of-order requests. std::shared_ptr device_to_host_stream_; // Streams to use for transferring data directly between different devices, // e.g., over NVLINK. diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 09e04d22def..293ea3997cc 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -29,6 +29,7 @@ limitations under the License. 
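In the CopyDeviceTensorToCPU change above, the borrowed device-to-host stream is deliberately captured by value in the TransferLiteralFromDevice callback so it cannot be returned to the pool before the transfer completes. A small standalone sketch of that lifetime trick, with a fake stream type in place of se::Stream:

```cpp
// Standalone sketch: keeping a borrowed resource alive by capturing its
// shared_ptr in the completion callback (simplified model, not the TF types).
#include <cstdio>
#include <functional>
#include <memory>

struct FakeStream {
  ~FakeStream() { std::printf("stream returned to pool\n"); }
};

std::function<void()> MakeDoneCallback(std::shared_ptr<FakeStream> stream) {
  // Capturing `stream` by value bumps the refcount; the stream cannot be
  // destroyed (returned to the pool) before the callback itself goes away.
  return [stream]() { std::printf("transfer complete\n"); };
}

int main() {
  std::function<void()> done;
  {
    auto stream = std::make_shared<FakeStream>();
    done = MakeDoneCallback(stream);
  }  // The local shared_ptr is gone, but the callback still owns the stream.
  done();  // Prints "transfer complete".
  // "stream returned to pool" prints only when `done` (and its captured copy
  // of the shared_ptr) is destroyed at the end of main.
}
```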
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/function_ops.h" +#include "tensorflow/core/kernels/host_constant_op.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" @@ -93,11 +94,22 @@ class XlaAssignVariableOp : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \ ConstantOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("HostConst").Device(DEVICE).HostMemory("output"), _HostConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", DT_STRING), \ IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Identity").Device(DEVICE).TypeConstraint("T"), \ + IdentityOp); \ + REGISTER_KERNEL_BUILDER(Name("Identity") \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input") \ + .HostMemory("output"), \ + IdentityOp); \ REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ @@ -196,9 +208,7 @@ class XlaAssignVariableOp : public OpKernel { Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \ \ REGISTER_KERNEL_BUILDER( \ - Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \ - TYPES), \ - ArgOp); \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp); \ REGISTER_KERNEL_BUILDER(Name(kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ @@ -207,11 +217,8 @@ class XlaAssignVariableOp : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ \ - REGISTER_KERNEL_BUILDER(Name(kRetOp) \ - .Device(DEVICE) \ - .TypeConstraint("T", TYPES) \ - .HostMemory("input"), \ - RetvalOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp); \ REGISTER_KERNEL_BUILDER(Name(kRetOp) \ .Device(DEVICE) \ .TypeConstraint("T") \ @@ -240,6 +247,9 @@ class XlaAssignVariableOp : public OpKernel { data::MakeIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ data::AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \ + data::AnonymousIteratorHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ data::IteratorGetNextOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc deleted file mode 100644 index bc0db558d8d..00000000000 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ /dev/null @@ -1,343 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" - -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/jit/deadness_analysis.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/core/common_runtime/shape_refiner.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" - -namespace tensorflow { - -// Is 'node' an operator that consumes only the shape of its input, not the -// data itself? -static bool IsShapeConsumerOp(const Node& node) { - return node.type_string() == "Shape" || node.type_string() == "ShapeN" || - node.type_string() == "Rank" || node.type_string() == "Size"; -} - -// Returns true if the op can be decomposed into XLA ops for which -// there are fusible elemental implementations. -static bool IsXlaFusible(const NodeDef& node) { - static const std::unordered_set* elementwise_ops = - new std::unordered_set( - {// tf2xla/kernels/aggregate_ops.cc - "AddN", - // tf2xla/kernels/binary_ops.cc - "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", - "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", - "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", - "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", - "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", - "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", - "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", - // tf2xla/kernels/unary_ops.cc - "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", - "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", - "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", - "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", - "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", - "Square", "Tan", "Tanh", "Real", "Imag", - // tf2xla/kernels/bcast_ops.cc - "BroadcastArgs", "BroadcastGradientArgs", - // tf2xla/kernels/bias_ops.cc - "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, - // tf2xla/kernels/cast_op.cc - "Cast", - // tf2xla/kernels/concat_op.cc - "Concat", "ConcatV2", "ConcatOffset", - // tf2xla/kernels/const_op.cc - "Const", - // tf2xla/kernels/elu_op.cc - "Elu", "EluGrad", "Selu", "SeluGrad", - // tf2xla/kernels/fill_op.cc - "Fill", - // tf2xla/kernels/identity_op.cc - "Identity", "IdentityN", "PreventGradient", - "StopGradient", /*"Snapshot",*/ - // tf2xla/kernels/index_ops.cc - "ArgMax", "ArgMin", - // tf2xla/kernels/mirror_pad_op.cc - "MirrorPad", - // tf2xla/kernels/one_hot_op.cc - "OneHot", - // tf2xla/kernels/pack_op.cc - "Pack", - // tf2xla/kernels/pad_op.cc - "Pad", "PadV2", - // tf2xla/kernels/relu_op.cc - "Relu", "Relu6", "ReluGrad", "Relu6Grad", - // tf2xla/kernels/reshape_op.cc - "Reshape", - // tf2xla/kernels/reverse_op.cc - "Reverse", "ReverseV2", - // tf2xla/kernels/reverse_sequence_op.cc - "ReverseSequence", - // tf2xla/kernels/shape_op.cc - "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", - "ZerosLike", 
"OnesLike", - // tf2xla/kernels/slice_op.cc - "Slice", - // tf2xla/kernels/split_op.cc - "Split", "SplitV", - // tf2xla/kernels/strided_slice_op.cc - "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", - // tf2xla/kernels/tile_ops.cc - "Tile", - // tf2xla/kernels/transpose_op.cc - "Transpose", "InvertPermutation", - // tf2xla/kernels/unpack_op.cc - "Unpack"}); - - return elementwise_ops->count(node.op()) > 0; -} - -Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, - const grappler::GrapplerItem& item, - GraphDef* output) { - VLOG(2) << "Here at fusion optimizer"; - - // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op. - // Once that happens, the expected interaction between this optimizer and when - // the global_jit_level is set is as follows: Fusion optimizer will replace - // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can - // be further compiled where possible via mark_for_compilation_pass. Note that - // this might lead to inefficient clustering, and it is best to use either the - // fusion optimizer or the global_jit flag, and not combine the two. - - // Create a Graph out of GraphDef. This is required currently because the - // helpers around clustering, encapsulation etc work on graphs. - FunctionLibraryDefinition function_library(OpRegistry::Global(), - item.graph.library()); - Graph graph(function_library); - ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); - shape_refiner.set_require_shape_inference_fns(false); - shape_refiner.set_disable_constant_propagation(true); - ImportGraphDefOptions options; - // Graph optimization happens at the late stage of graph execution, when - // colocation constraints are already validated previously and the device - // placement of nodes has also completed, so there is no need to validate - // colocation constraints again. - options.validate_colocation_constraints = false; - options.validate_shape = false; - TF_RETURN_IF_ERROR( - ImportGraphDef(options, item.graph, &graph, &shape_refiner)); - - std::unique_ptr deadness; - TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness)); - - // Collect nodes that can be fused via XLA, while ignoring those that - // explicitly ask for XLA: (*) nodes that are marked to be compiled - // explicitly. (*) nodes assigned to XLA device. - OrderedNodeSet compilation_candidates; - for (Node* node : graph.op_nodes()) { - // If there is a _XlaCompile annotation, ignore the node if it is - // true. Nodes are marked with this attr via experimental_jit_scope, and - // will be handled by the mark_for_compilation pass. - bool compile = false; - Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); - if (status.ok() && compile) { - continue; - } - // If there is already a _XlaCluster annotation, ignore the node. Nodes are - // marked with this attr to indicate they are already part of a cluster and - // hence ignored. - status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile); - if (status.ok()) { - continue; - } - - // If there is an explicit XLA device placement, ignore the node. - DeviceType device_type(""); - TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); - if (device_type.type_string().find("XLA") != string::npos) continue; - - // Assume all fusible ops are registered. - // TODO(hpucha): Check for registration if possible. 
- if (!IsXlaFusible(node->def())) { - continue; - } - - // XLA does not offer guaranteed aliasing between the input and output of - // the XLA cluster so it can't implement the forward-tensor-ref semantic. - // Leave such nodes out of XLA clusters. - if (HasForwardedRefInput(*node)) { - continue; - } - - // If inputs to `node` can have conflicting deadness (i.e. some are alive - // and some are dead) then don't compile it. XLA cannot represent the - // deadness semantics of these nodes correctly and auto-clustering these - // nodes can cause deadness to propagate to nodes that should be live. - if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { - continue; - } - - compilation_candidates.insert(node); - } - - if (compilation_candidates.empty()) { - VLOG(2) << "No compilable candidates"; - *output = item.graph; - return Status::OK(); - } - - GraphCycles cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); - TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( - &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles)); - - // TODO(hpucha): Make clustering more robust. There are two known issues that - // we need to mitigate: (a) Non-resource variables can cause deadlocks - // when clustering changes order of execution. See b/77263461 for a specific - // example. (b) Queue operations can also cause deadlocks. See b/77261498 for - // example. - - struct Cluster { - // Identifies the node that represents this cluster in the cycle detection - // graph. - int representative = -1; - }; - - // Each compilation candidate belongs to a cluster. The cluster's - // representative names the node in the 'cycles' graph that represents the - // cluster. - std::vector> clusters(graph.num_node_ids()); - std::deque*> worklist; - for (Node* node : compilation_candidates) { - Cluster& cluster = clusters[node->id()].Get(); - cluster.representative = node->id(); - worklist.push_back(&clusters[node->id()]); - } - - // Repeatedly contract edges between clusters that are on the same device, - // provided the contraction would not create a cycle. This is a simplified - // version of the clustering in mark_for_compilation_pass that also deals with - // nodes that are explicitly tagged to be compiled/clustered. - while (!worklist.empty()) { - int from = worklist.front()->Get().representative; - worklist.pop_front(); - - Node* node_from = graph.FindNodeId(from); - if (node_from->IsControlFlow()) { - // Control flow nodes aren't compilation candidates and should never - // appear. - return errors::Internal( - "Found control flow node in clustering worklist: ", - node_from->type_string()); - } - for (int to : cycles.Successors(from)) { - if (to >= graph.num_node_ids()) { - // Node is a "frame" node that is present only in the cycle detection - // graph. No clustering is possible. - continue; - } - Node* node_to = graph.FindNodeId(to); - if (compilation_candidates.find(node_to) == - compilation_candidates.cend()) { - continue; - } - - // Do not cluster across devices. - if (node_from->def().device() != node_to->def().device()) { - VLOG(2) << "Devices " << node_from->def().device() << " " - << node_to->def().device(); - VLOG(2) << "Device names " << node_from->assigned_device_name() << " " - << node_to->assigned_device_name(); - continue; - } - - // Ops that consume shapes cannot be the root of a cluster. This is an - // optimization. 
- if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { - continue; - } - - // If contracting the edge would create a cycle, bail out. - // However, just because we can't merge the clusters now does not mean - // we won't be able to merge them in the future. - // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge - // 1->3. But if we first contract 1->2 then we can later contract 1->3. - if (!cycles.ContractEdge(from, to)) continue; - - // Merge the clusters. ContractEdge uses 'from' as the number of the - // merged node, so make sure 'from' is the chosen representative. - clusters[from].Merge(&clusters[to]); - - worklist.push_back(&clusters[from]); - break; - } - } - - // Count the number of non-trivial elements in each cluster. - std::vector effective_cluster_sizes(graph.num_node_ids()); - for (const Node* n : compilation_candidates) { - int cluster = clusters[n->id()].Get().representative; - // Identity nodes will be removed if the node gets marked for compilation. - // Therefore we don't want to count them towards the effective cluster size. - if (n->def().op() != "Identity") { - effective_cluster_sizes[cluster]++; - } - } - - const int min_cluster_size = 2; - int num_clusters = 0; - for (auto size : effective_cluster_sizes) { - if (size >= min_cluster_size) { - VLOG(3) << "Cluster " << num_clusters << " " << size; - num_clusters++; - } - } - - // Names for each cluster. - std::unordered_map cluster_names; - // Sequence number generator to ensure clusters have unique names. - static std::atomic cluster_sequence_num; - - for (Node* n : compilation_candidates) { - int cluster = clusters[n->id()].Get().representative; - - // Compile if this is a cluster of >= min_cluster_size compilable operators. - if (effective_cluster_sizes[cluster] >= min_cluster_size) { - string& name = cluster_names[cluster]; - - if (name.empty()) { - name = absl::StrCat("cluster_", cluster_sequence_num++); - } - n->AddAttr(kXlaClusterAttr, name); - VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; - } - } - - graph.ToGraphDef(output); - return Status::OK(); -} - -REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h deleted file mode 100644 index 3d2309e782d..00000000000 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ -#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ - -#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" - -namespace tensorflow { - -// Optimizes graphs by fusing ops where possible, resulting in more efficient -// execution. 
-class XlaFusionOptimizer : public grappler::CustomGraphOptimizer { - public: - XlaFusionOptimizer() {} - ~XlaFusionOptimizer() override {} - - Status Init( - const RewriterConfig_CustomGraphOptimizer* config = nullptr) override { - return Status::OK(); - } - - string name() const override { return "xla-fusion"; }; - - Status Optimize(grappler::Cluster* cluster, - const grappler::GrapplerItem& item, - GraphDef* output) override; - - void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item, - const GraphDef& optimize_output, double result) override { - // Nothing to do for XlaFusionOptimizer. - } -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc deleted file mode 100644 index 68e19c8a135..00000000000 --- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" -#include "tensorflow/cc/ops/resource_variable_ops.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/graph/graph_def_builder_util.h" -#include "tensorflow/core/grappler/utils/grappler_test.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace tensorflow { -namespace { - -REGISTER_OP("UncompilableNullary").Output("o: float"); -REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); - -class XlaFusionOptimizerTest : public grappler::GrapplerTest { - protected: - std::unordered_map GetClusters(const GraphDef& graph) { - std::unordered_map ids; - for (const NodeDef& node : graph.node()) { - string cluster; - if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) { - CHECK(!cluster.empty()); - ids[node.name()] = cluster; - } - } - return ids; - } -}; - -TEST_F(XlaFusionOptimizerTest, Chains) { - GraphDef graph; - { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = - ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); - Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); - Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); - Node* d = - ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); - Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); - ops::UnaryOp("Relu", e, builder.opts().WithName("F")); - TF_ASSERT_OK(builder.ToGraphDef(&graph)); - } - grappler::GrapplerItem item; - item.graph = graph; - - XlaFusionOptimizer optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - auto clusters = GetClusters(output); - EXPECT_EQ(4, clusters.size()); - 
EXPECT_EQ(clusters["B"], clusters["C"]); - EXPECT_EQ(clusters["E"], clusters["F"]); - EXPECT_NE(clusters["B"], clusters["E"]); - EXPECT_TRUE(clusters.find("A") == clusters.cend()); - EXPECT_TRUE(clusters.find("D") == clusters.cend()); -} - -TEST_F(XlaFusionOptimizerTest, FusibleOps) { - GraphDef graph; - { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = ops::SourceOp( - "Placeholder", - builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); - Node* b = ops::SourceOp( - "Placeholder", - builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); - - Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C")); - ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); - ops::UnaryOp("Abs", c, builder.opts().WithName("E")); - - TF_ASSERT_OK(builder.ToGraphDef(&graph)); - } - grappler::GrapplerItem item; - item.graph = graph; - - XlaFusionOptimizer optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - auto clusters = GetClusters(output); - EXPECT_EQ(2, clusters.size()); - EXPECT_EQ(clusters["C"], clusters["E"]); - EXPECT_TRUE(clusters.find("D") == clusters.cend()); -} - -TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) { - GraphDef graph; - { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = ops::SourceOp( - "Placeholder", - builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); - Node* b = ops::SourceOp( - "Placeholder", - builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); - - Node* c = ops::BinaryOp( - "Add", a, b, - builder.opts().WithName("C").WithDevice("/device:XLA_CPU")); - ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); - Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E")); - ops::UnaryOp("Cos", e, - builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true)); - - TF_ASSERT_OK(builder.ToGraphDef(&graph)); - } - grappler::GrapplerItem item; - item.graph = graph; - - XlaFusionOptimizer optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - auto clusters = GetClusters(output); - EXPECT_TRUE(clusters.empty()); -} - -TEST_F(XlaFusionOptimizerTest, UncompilableCycles) { - GraphDef graph; - { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = ops::SourceOp("Const", builder.opts() - .WithName("A") - .WithAttr("dtype", DT_FLOAT) - .WithAttr("value", Tensor())); - Node* b = - ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); - ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); - - TF_ASSERT_OK(builder.ToGraphDef(&graph)); - } - grappler::GrapplerItem item; - item.graph = graph; - - XlaFusionOptimizer optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - auto clusters = GetClusters(output); - EXPECT_TRUE(clusters.empty()); -} - -TEST_F(XlaFusionOptimizerTest, CompilableCycles) { - GraphDef graph; - { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = ops::SourceOp("Const", builder.opts() - .WithName("A") - .WithAttr("dtype", DT_FLOAT) - .WithAttr("value", Tensor())); - Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); - ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); - TF_ASSERT_OK(builder.ToGraphDef(&graph)); - } - grappler::GrapplerItem item; - item.graph = graph; - - XlaFusionOptimizer optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - auto clusters = GetClusters(output); - EXPECT_EQ(3, 
clusters.size()); - EXPECT_EQ(clusters["A"], clusters["B"]); - EXPECT_EQ(clusters["A"], clusters["C"]); -} - -TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) { - Scope root = Scope::NewRootScope().ExitOnError(); - Output var_handle = - ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({})); - Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f); - Output begin = ops::Const(root.WithOpName("begin"), 0); - Output end = ops::Const(root.WithOpName("end"), 1); - Output strides = ops::Const(root.WithOpName("strides"), 1); - ops::ResourceStridedSliceAssign assign_1( - root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign); - ops::ResourceStridedSliceAssign assign_2( - root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign); - root.graph()->AddControlEdge(assign_1.operation.node(), - assign_2.operation.node()); - grappler::GrapplerItem item; - root.graph()->ToGraphDef(&item.graph); - - XlaFusionOptimizer optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - auto clusters = GetClusters(output); - EXPECT_NE(clusters["assign_1"], clusters["assign_2"]); -} -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index b29f6a009b9..913612f9a6c 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -55,10 +55,32 @@ static xla::StatusOr>> ParseVisibleDeviceList( class XlaGpuDeviceFactory : public DeviceFactory { public: + Status ListPhysicalDevices(std::vector* devices) override; Status CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector>* devices) override; }; +Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { + auto platform = se::MultiPlatformManager::PlatformWithName("CUDA"); + if (!platform.ok()) { + // Treat failures as non-fatal; there might not be a GPU in the machine. 
+ VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); + return Status::OK(); + } + + int device_count = platform.ValueOrDie()->VisibleDeviceCount(); + if (device_count <= 0) { + return Status::OK(); + } + + for (int i = 0; i < device_count; ++i) { + devices->push_back( + absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i)); + } + + return Status::OK(); +} + Status XlaGpuDeviceFactory::CreateDevices( const SessionOptions& session_options, const string& name_prefix, std::vector>* devices) { @@ -66,7 +88,14 @@ Status XlaGpuDeviceFactory::CreateDevices( registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.autoclustering_policy = XlaOpRegistry::AutoclusteringPolicy::kAlways; - registration.compile_resource_ops = true; + registration.cluster_resource_variable_ops_unsafely = true; + registration.cluster_stack_ops = false; + registration.cluster_tensor_array_ops = true; + registration.cluster_stateful_rng_ops = true; + registration.cluster_control_trigger = true; + registration.elide_assert_and_checknumerics = true; + registration.cluster_variant_ops = true; + registration.cluster_slow_and_inaccurate_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); static XlaDeviceOpRegistrations* registrations = diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index e1a58240615..4252e2e24ac 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -32,10 +32,19 @@ constexpr std::array kExecAllTypes = { class XlaInterpreterDeviceFactory : public DeviceFactory { public: + Status ListPhysicalDevices(std::vector* devices) override; Status CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector>* devices) override; }; +Status XlaInterpreterDeviceFactory::ListPhysicalDevices( + std::vector* devices) { + devices->push_back( + absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0")); + + return Status::OK(); +} + Status XlaInterpreterDeviceFactory::CreateDevices( const SessionOptions& session_options, const string& name_prefix, std::vector>* devices) { @@ -47,7 +56,14 @@ Status XlaInterpreterDeviceFactory::CreateDevices( registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; registration.autoclustering_policy = XlaOpRegistry::AutoclusteringPolicy::kAlways; - registration.compile_resource_ops = true; + registration.cluster_resource_variable_ops_unsafely = true; + registration.cluster_stack_ops = false; + registration.cluster_tensor_array_ops = true; + registration.cluster_stateful_rng_ops = true; + registration.cluster_control_trigger = true; + registration.elide_assert_and_checknumerics = true; + registration.cluster_variant_ops = true; + registration.cluster_slow_and_inaccurate_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, registration); diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc similarity index 90% rename from tensorflow/compiler/jit/create_xla_launch_op.cc rename to tensorflow/compiler/jit/xla_kernel_creator.cc index 7e4c8466d88..a5e59e932a2 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "tensorflow/compiler/jit/xla_kernel_creator.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" @@ -64,26 +64,23 @@ class SinglePassSearch { int current_index_; const std::vector* values_; }; +} // namespace -Status CompilationRequested(const FunctionLibraryRuntime& flr, - const NodeDef& node_def) { +bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr, + const NodeDef& node_def) const { const FunctionDef* function_def = flr.GetFunctionLibraryDefinition()->Find(node_def.name()); if (function_def == nullptr) { // The node def is not calling a function. Individual ops can be // run directly using on-demand mode, no need to create XlaLaunch // kernel for them. - // TODO(b/110359382): Make custom kernel creation return a bool instead of - // status. - // We don't set error messages here to avoid unnecessary string copy. - // Similarly below. - return Status(error::INVALID_ARGUMENT, ""); + return false; } // If kXlaCompileAttr is set on the node_def, use its value. const auto& it = node_def.attr().find(kXlaCompileAttr); if (it != node_def.attr().end()) { - return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, ""); + return it->second.b(); } // kXlaCompileAttr is not set on node_def, check if it is set on @@ -100,9 +97,9 @@ Status CompilationRequested(const FunctionLibraryRuntime& flr, VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; } } - return Status(error::INVALID_ARGUMENT, ""); + return false; } - return Status::OK(); + return true; } // Given a FunctionLibraryRuntime and a NodeDef calling a function in the @@ -148,17 +145,21 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, return Status::OK(); } -} // namespace +Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + std::unique_ptr* kernel) const { + if (!CanCreateKernel(*flr, node_def)) { + return errors::Internal("Invalid node: ", node_def.ShortDebugString()); + } -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, - std::unique_ptr* kernel) { - TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def)); - - VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString(); + VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString(); // Make sure that kernels have been registered on the JIT device. XlaOpRegistry::RegisterCompilationKernels(); if (!IsCompilable(flr, node_def)) { + VLOG(1) << "Not creating XlaLaunchOp because function invoked by the " + "following node is not compilable: " + << node_def.DebugString(); // node_def is calling a function that XLA can't compile. 
return errors::InvalidArgument("Not compilable: ", node_def.ShortDebugString()); @@ -239,7 +240,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, namespace { bool RegisterLaunchOpCreator() { - RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp); + XlaKernelCreator* xla_kernel_creator = new XlaKernelCreator(); + RegisterDefaultCustomKernelCreator(xla_kernel_creator); return true; } diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/xla_kernel_creator.h similarity index 61% rename from tensorflow/compiler/jit/create_xla_launch_op.h rename to tensorflow/compiler/jit/xla_kernel_creator.h index 98a22e35153..739cf02d877 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.h +++ b/tensorflow/compiler/jit/xla_kernel_creator.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ #define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -23,12 +24,18 @@ namespace tensorflow { class FunctionLibraryRuntime; class OpKernel; -// Given a NodeDef 'node_def' and the function library runtime 'flr', if -// 'node_def' is a call to a compilable function defined in 'flr', returns OK -// and fills in 'kernel' with a XlaLaunchOp kernel which computes the -// node. Otherwise, returns a non-OK. -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, - std::unique_ptr* kernel); +class XlaKernelCreator : public CustomKernelCreator { + public: + // Given a NodeDef 'node_def' and the function library runtime 'flr', returns + // true if 'node_def' is a call to a compilable function defined in 'flr', + // with the kXlaCompileAttr set. + bool CanCreateKernel(const FunctionLibraryRuntime& flr, + const NodeDef& node_def) const override; + + // Given a supported NodeDef, returns a XlaLaunchOp that computes the node. + Status CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel) const override; +}; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc similarity index 71% rename from tensorflow/compiler/jit/create_xla_launch_op_test.cc rename to tensorflow/compiler/jit/xla_kernel_creator_test.cc index 0f872a480f4..de930bb2ad4 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "tensorflow/compiler/jit/xla_kernel_creator.h" #include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" @@ -53,7 +54,7 @@ FunctionDef XTimesY() { }); } -class CreateXlaLaunchOpTest : public ::testing::Test { +class XlaKernelCreatorTest : public ::testing::Test { protected: void Init(const std::vector& flib) { SessionOptions options; @@ -91,15 +92,17 @@ AttrValue BoolAttr(bool b) { return v; } -TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) { +TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) { FunctionDef fdef = XTimesY(); (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); Init({fdef}); + XlaKernelCreator xla_kernel_creator; - Status status = CreateXlaLaunchOp( + Status status = xla_kernel_creator.CreateKernel( flr_, ToNodeDef(R"pb( name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' - )pb"), &kernel_); + )pb"), + &kernel_); ASSERT_TRUE(status.ok()) << status.ToString(); EXPECT_EQ("XTimesY", kernel_->name()); @@ -116,31 +119,35 @@ TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) { EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]); } -TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) { +TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) { FunctionDef fdef = XTimesY(); Init({fdef}); + XlaKernelCreator xla_kernel_creator; - Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), &kernel_); - EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); + Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), + &kernel_); + EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); } -TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) { +TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { FunctionDef fdef = XTimesY(); (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); Init({fdef}); + XlaKernelCreator xla_kernel_creator; - Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), &kernel_); - EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); + Status status = xla_kernel_creator.CreateKernel(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), + &kernel_); + EXPECT_TRUE(errors::IsInternal(status)) << status.ToString(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index c64981053fa..3bb698b33d6 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { @@ -132,7 +133,8 @@ Status LockVariables(absl::Span variables) { // cluster because we would not handle variable updates correctly. Any // locks we have already acquired will be released when the VariableInfo // objects are destroyed. 
- return errors::Internal("Duplicate variable passed to XLA cluster"); + // TODO(b/128495870) Add support for passing aliased resource variables. + return errors::Unimplemented("Duplicate variable passed to XLA cluster"); } VLOG(4) << "Acquiring lock for variable " << reinterpret_cast(variable); @@ -166,11 +168,11 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, } XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped) - : xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {} + : se::DeviceMemoryAllocator(platform), wrapped_(wrapped) {} XlaAllocator::~XlaAllocator() {} -xla::StatusOr XlaAllocator::Allocate( +xla::StatusOr XlaAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { AllocationAttributes attrs; attrs.no_retry_on_failure = !retry_on_failure; @@ -182,8 +184,8 @@ xla::StatusOr XlaAllocator::Allocate( "Out of memory while trying to allocate ", size, " bytes."); } } - return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size), - device_ordinal, this); + return se::OwningDeviceMemory(se::DeviceMemoryBase(data, size), + device_ordinal, this); } Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { @@ -192,7 +194,7 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { } XlaComputationLaunchContext::XlaComputationLaunchContext( - xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, + xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors, bool use_multiple_streams) : client_(client), xla_allocator_(xla_allocator), @@ -242,7 +244,8 @@ void XlaComputationLaunchContext::PopulateInputs( CHECK(xla_tensor && xla_tensor->has_shaped_buffer()); arg_ptrs_[i] = const_cast(&xla_tensor->shaped_buffer()); } else { - CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, + on_device_shape)) << "On-device shape " << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) << " not the same as on-host shape " @@ -347,9 +350,11 @@ Status XlaComputationLaunchContext::PopulateOutputs( VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " << DataTypeString(type); if (type == DT_RESOURCE) { - TF_RET_CHECK(kernel->outputs[i].input_index >= 0) - << "Invalid input for outputs " << i; - ctx->set_output(i, ctx->input(kernel->outputs[i].input_index)); + int input_index = + kernel->outputs[i].input_index - missing_ctx_input_prefix; + TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs()) + << "Invalid input for outputs " << i << ": " << input_index; + ctx->set_output(i, ctx->input(input_index)); } else { se::DeviceMemoryBase buffer = output.buffer({output_num}); if (allocate_xla_tensors_) { @@ -369,7 +374,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + output.set_buffer(se::OwningDeviceMemory(), {output_num}); ctx->set_output(i, output_tensor); } ++output_num; @@ -430,7 +435,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( *variable_infos[i].var()->tensor() = output_tensor; } else { se::DeviceMemoryBase buffer = output.buffer({output_num}); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + output.set_buffer(se::OwningDeviceMemory(), {output_num}); Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer, allocator); 
*variable_infos[i].var()->tensor() = output_tensor; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index c915b7118d0..4cb020ffe34 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -23,14 +23,13 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace tensorflow { class XlaAllocator; @@ -108,11 +107,11 @@ Status LockVariables(absl::Span variables) // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: // see comment on `AllowsAsynchronousDeallocation()`. -class XlaAllocator : public xla::DeviceMemoryAllocator { +class XlaAllocator : public se::DeviceMemoryAllocator { public: XlaAllocator(const se::Platform* platform, Allocator* wrapped); ~XlaAllocator() override; - xla::StatusOr Allocate( + xla::StatusOr Allocate( int device_ordinal, uint64 size, bool retry_on_failure) override; Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; @@ -129,6 +128,50 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { Allocator* wrapped_; }; +// Adapter class that wraps per-device TF allocators as an XLA allocator. +// Assumes that the Tensorflow allocator permits asynchronous deallocation; +// see comment on `AllowsAsynchronousDeallocation()`. +class MultiDeviceAdapter : public se::DeviceMemoryAllocator { + public: + MultiDeviceAdapter( + const se::Platform* platform, + std::vector> tf_allocators) + : DeviceMemoryAllocator(platform), + tf_allocators_(std::move(tf_allocators)) { + for (const auto& tf_allocator : tf_allocators_) { + per_device_allocators_.emplace_back(platform, tf_allocator.get()); + } + } + + xla::StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override { + CHECK_LT(device_ordinal, per_device_allocators_.size()); + return per_device_allocators_[device_ordinal].Allocate(device_ordinal, size, + retry_on_failure); + } + + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override { + CHECK_LT(device_ordinal, per_device_allocators_.size()); + return per_device_allocators_[device_ordinal].Deallocate(device_ordinal, + mem); + } + + // The Tensorflow BFC allocator used on GPU allows host-side deallocation + // before GPU execution takes place. Tensorflow uses the ordering of the main + // compute stream to enforce a happens-before relationship between a memory + // allocation and code that reuses the same memory. If Tensorflow adds + // support for multiple GPU streams or allocators with different ordering + // requirements, this code may need to change. + // (This attribute has no effect on CPU.) 
+ bool AllowsAsynchronousDeallocation() const override { return true; } + + private: + std::vector per_device_allocators_; + // The wrapped TF allocators backing per_device_allocators_ (XlaAllocator does + // not take ownership of its underlying Allocator). + std::vector> tf_allocators_; +}; + // Helper class to perform the marshalling of TensorFlow inputs and outputs to // ShapedBuffers suitable for passing to an XLA computation. class XlaComputationLaunchContext { @@ -142,7 +185,7 @@ class XlaComputationLaunchContext { // because we track inter-stream dependencies through events inside XlaTensor // objects. XlaComputationLaunchContext(xla::LocalClient* client, - xla::DeviceMemoryAllocator* xla_allocator, + se::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors, bool use_multiple_streams); @@ -186,7 +229,7 @@ class XlaComputationLaunchContext { private: xla::LocalClient* client_; - xla::DeviceMemoryAllocator* xla_allocator_; + se::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; bool use_multiple_streams_; std::vector> arg_buffers_; diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index d1f7f754c83..c211deaf87c 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -59,11 +59,11 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = client->backend().transfer_manager()->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer, + TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer, client->backend().memory_allocator()->Allocate( device_ordinal, size, /*retry_on_failure=*/false)); // Move our buffer into shaped_buffer, which takes ownership of it. - index_to_buffer.second = buffer.Forget(); + index_to_buffer.second = buffer.Release(); } VLOG(4) << shaped_buffer.ToString(); @@ -97,6 +97,15 @@ void XlaTensor::ResetDefinitionEvent(std::shared_ptr event, streams_defined_on_ = {stream}; } +Status XlaTensor::RefreshStatusOfStreams() { + mutex_lock lock(mu_); + Status status; + for (se::Stream* stream : streams_defined_on_) { + status.Update(stream->RefreshStatus()); + } + return status; +} + // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from // device-side tensors, which are either CPU or GPU memory pointers. This works // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 77e80aa2527..8a4eb7493be 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -102,6 +102,10 @@ class XlaTensor { void ResetDefinitionEvent(std::shared_ptr event, se::Stream* stream); + // Refresh the status of streams_defined_on_. Return the first not-OK stream's + // status or OK. + Status RefreshStatusOfStreams(); + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. static XlaTensor* FromOpaquePointer(void* ptr); // Convert to a raw pointer from an XlaTensor, adding the pointer tag. 
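Note on the MultiDeviceAdapter added to xla_launch_util.h above: it owns one wrapped allocator per device ordinal and forwards every Allocate/Deallocate call to the allocator at that ordinal. The standalone sketch below illustrates only that dispatch pattern; the ToyDeviceAllocator and ToyMultiDeviceAdapter names and their simplified interfaces are invented for illustration and are not the real se::DeviceMemoryAllocator API.

```cpp
// Minimal sketch of "route allocation calls by device ordinal".
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-in for a per-device allocator.
class ToyDeviceAllocator {
 public:
  explicit ToyDeviceAllocator(int ordinal) : ordinal_(ordinal) {}

  void* Allocate(std::size_t size) {
    std::cout << "device " << ordinal_ << ": allocating " << size << " bytes\n";
    return std::malloc(size);
  }

  void Deallocate(void* ptr) {
    std::cout << "device " << ordinal_ << ": freeing buffer\n";
    std::free(ptr);
  }

 private:
  int ordinal_;
};

// Adapter that keeps one allocator per device ordinal and forwards each call,
// mirroring the per_device_allocators_ vector in MultiDeviceAdapter above.
class ToyMultiDeviceAdapter {
 public:
  explicit ToyMultiDeviceAdapter(int num_devices) {
    for (int i = 0; i < num_devices; ++i) {
      per_device_allocators_.push_back(std::make_unique<ToyDeviceAllocator>(i));
    }
  }

  void* Allocate(int device_ordinal, std::size_t size) {
    assert(device_ordinal >= 0 &&
           device_ordinal < static_cast<int>(per_device_allocators_.size()));
    return per_device_allocators_[device_ordinal]->Allocate(size);
  }

  void Deallocate(int device_ordinal, void* ptr) {
    assert(device_ordinal >= 0 &&
           device_ordinal < static_cast<int>(per_device_allocators_.size()));
    per_device_allocators_[device_ordinal]->Deallocate(ptr);
  }

 private:
  std::vector<std::unique_ptr<ToyDeviceAllocator>> per_device_allocators_;
};

int main() {
  ToyMultiDeviceAdapter adapter(/*num_devices=*/2);
  void* buf = adapter.Allocate(/*device_ordinal=*/1, /*size=*/256);
  adapter.Deallocate(/*device_ordinal=*/1, buf);
  return 0;
}
```

As in the real adapter, which bounds-checks the ordinal with CHECK_LT before indexing the allocator vector, validating the ordinal up front catches requests for devices that were never registered.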
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 65f5fba269c..fbb60d17316 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -65,6 +65,7 @@ py_test( name = "xla_test_test", size = "small", srcs = ["xla_test_test.py"], + python_version = "PY2", deps = [ ":xla_test", ], @@ -138,6 +139,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "add_n_test", + size = "small", + srcs = ["add_n_test.py"], + # TensorList ops are not implemented in the on-demand compilation model yet. + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:list_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "addsign_test", size = "small", @@ -243,14 +260,47 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "cond_test", + size = "small", + srcs = ["cond_test.py"], + disabled_backends = ["cpu_ondemand"], # b/129021699 + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:function", + ], +) + tf_xla_py_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], - # TODO(kuny): remove it after b/124377352 is fixed. + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:map_fn", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", + ], +) + +tf_xla_py_test( + name = "svd_op_test", + size = "medium", + srcs = ["svd_op_test.py"], disabled_backends = [ + # TODO(b/129396575): Fails on CPU. "cpu", - "gpu", "cpu_ondemand", ], tags = ["optonly"], @@ -409,10 +459,6 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], - tags = [ - "manual", - "notap", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -455,7 +501,7 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - shard_count = 3, + shard_count = 6, tags = ["optonly"], deps = [ ":xla_test", @@ -826,6 +872,8 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + tags = ["config-cuda-only"], + use_xla_device = False, deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -846,6 +894,7 @@ tf_xla_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:standard_ops", "//tensorflow/python:stateful_random_ops", + "//tensorflow/python/kernel_tests/random:util", ], ) @@ -860,15 +909,18 @@ tf_xla_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:standard_ops", "//tensorflow/python:stateless_random_ops", + "//tensorflow/python/kernel_tests/random:util", ], ) tf_xla_py_test( name = "tensor_array_ops_test", - size = "small", + size = "medium", srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. 
disabled_backends = ["cpu_ondemand"], + tags = ["config-cuda-only"], + use_xla_device = False, deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1022,7 +1074,7 @@ tf_xla_py_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], - shard_count = 5, + shard_count = 1, # Times out in fastbuild mode. tags = ["optonly"], deps = [ diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py index 6cf16cc07ff..548dbe53f2a 100644 --- a/tensorflow/compiler/tests/adadelta_test.py +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -41,7 +41,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase): all_lr = [1.0, 0.5, 0.1] for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): for grad in all_grad: for lr in all_lr: var0_init = [1.0, 2.0] diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index e9c2d363aca..369d0097a0f 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithoutRegularizationBasic1(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) @@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAwithoutRegularizationBasic2(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) @@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithL1(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) @@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithL1_L2(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index e26483303c3..844e5dfd831 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -59,7 +59,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = 
resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -87,7 +87,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 8bcff9d379d..bf22b756074 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -99,7 +99,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -142,7 +142,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py index 961b46375c9..e50b5594a62 100644 --- a/tensorflow/compiler/tests/adamax_test.py +++ b/tensorflow/compiler/tests/adamax_test.py @@ -49,7 +49,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): def testBasic(self): for i, dtype in enumerate(self.float_types): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 @@ -103,7 +103,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 diff --git a/tensorflow/compiler/tests/add_n_test.py b/tensorflow/compiler/tests/add_n_test.py new file mode 100644 index 00000000000..40e6bea0cc5 --- /dev/null +++ b/tensorflow/compiler/tests/add_n_test.py @@ -0,0 +1,84 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AddN.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class XlaAddNTest(xla_test.XLATestCase): + + def testAddTensorLists(self): + with self.session(), self.test_scope(): + l1 = list_ops.tensor_list_reserve( + element_shape=[], element_dtype=dtypes.float32, num_elements=3) + l2 = list_ops.tensor_list_reserve( + element_shape=[], element_dtype=dtypes.float32, num_elements=3) + l1 = list_ops.tensor_list_set_item(l1, 0, 5.) + l2 = list_ops.tensor_list_set_item(l2, 2, 10.) + + l = math_ops.add_n([l1, l2]) + self.assertAllEqual( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32), + [5.0, 0.0, 10.0]) + + def testAddTensorListsFailsIfLeadingDimsMismatch(self): + with self.session(), self.test_scope(): + l1 = list_ops.tensor_list_reserve( + element_shape=[], element_dtype=dtypes.float32, num_elements=2) + l2 = list_ops.tensor_list_reserve( + element_shape=[], element_dtype=dtypes.float32, num_elements=3) + l = math_ops.add_n([l1, l2]) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "TensorList arguments to AddN must all have the same shape"): + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32).eval() + + def testAddTensorListsFailsIfElementShapesMismatch(self): + with self.session() as session, self.test_scope(): + # Use placeholders instead of constant values for shapes to prevent TF's + # shape inference from catching this early. + l1_element_shape = array_ops.placeholder(dtype=dtypes.int32) + l2_element_shape = array_ops.placeholder(dtype=dtypes.int32) + l1 = list_ops.tensor_list_reserve( + element_shape=l1_element_shape, + element_dtype=dtypes.float32, + num_elements=3) + l2 = list_ops.tensor_list_reserve( + element_shape=l2_element_shape, + element_dtype=dtypes.float32, + num_elements=3) + l = math_ops.add_n([l1, l2]) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "TensorList arguments to AddN must all have the same shape"): + session.run( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32), { + l1_element_shape: [], + l2_element_shape: [2] + }) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py index a37c97e6d37..f55ab75745a 100644 --- a/tensorflow/compiler/tests/addsign_test.py +++ b/tensorflow/compiler/tests/addsign_test.py @@ -63,7 +63,7 @@ class AddSignTest(xla_test.XLATestCase): alpha=1.0, beta=0.9): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): # Initialize variables for numpy implementation. 
m0, m1 = 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 68f52e796c2..6fc7114a900 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -40,7 +40,7 @@ class ArgMinMaxTest(xla_test.XLATestCase): op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. """ - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): pinp = array_ops.placeholder( dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index c829c50b551..1f5ef5586f4 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -23,6 +23,7 @@ import itertools import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -38,7 +39,7 @@ class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") @@ -1041,6 +1042,62 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([2], dtype=np.int64), expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype)) + def testBatchMatMulBroadcast(self): + """Tests broadcasting behavior of BatchMatMul.""" + with compat.forward_compatibility_horizon(2019, 4, 26): + # [2, 3] @ [1, 3, 4] -> [1, 2, 4] + self._testBinary( + math_ops.matmul, + np.array([[10, 20, 30], [11, 21, 31]], dtype=np.float32), + np.array([[[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]], + dtype=np.float32), + expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]], + dtype=np.float32)) + # [1, 2, 3] @ [3, 4] -> [1, 2, 4] + self._testBinary( + math_ops.matmul, + np.array([[[10, 20, 30], [11, 21, 31]]], dtype=np.float32), + np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]], + dtype=np.float32), + expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]], + dtype=np.float32)) + # [2, 1, 3] @ [3, 1] -> [2, 1, 1] + self._testBinary( + math_ops.matmul, + np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32), + np.array([[1], [2], [3]], dtype=np.float32), + expected=np.array([[[140]], [[146]]], dtype=np.float32)) + # [2, 1, 3] @ [1, 3] -> [2, 1, 1] (adjoint_b) + self._testBinary( + lambda x, y: math_ops.matmul(x, y, adjoint_b=True), + np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32), + np.array([[1, 2, 3]], dtype=np.float32), + expected=np.array([[[140]], [[146]]], dtype=np.float32)) + # [2, 3, 1] @ [3, 1] -> [2, 1, 1] (adjoint_a) + self._testBinary( + lambda x, y: math_ops.matmul(x, y, adjoint_a=True), + np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32), + np.array([[1], [2], [3]], dtype=np.float32), + expected=np.array([[[140]], [[146]]], dtype=np.float32)) + # [2, 3, 1] @ [1, 3] -> [2, 1, 1] (adjoint_a and adjoint_b) + self._testBinary( + lambda x, y: math_ops.matmul(x, y, adjoint_a=True, adjoint_b=True), + np.array([[[10], [20], [30]], [[11], [21], [31]]], 
dtype=np.float32), + np.array([[1, 2, 3]], dtype=np.float32), + expected=np.array([[[140]], [[146]]], dtype=np.float32)) + # [5, 1, 2, 3] @ [1, 7, 3, 4] -> [5, 7, 2, 4] + self._testBinary( + math_ops.matmul, + np.ones([5, 1, 2, 3], dtype=np.float32), + np.ones([1, 7, 3, 4], dtype=np.float32), + expected=np.full([5, 7, 2, 4], 3, dtype=np.float32)) + # [4, 5, 1, 2, 3] @ [1, 1, 3, 5] -> [4, 5, 1, 2, 5] + self._testBinary( + math_ops.matmul, + np.full([4, 5, 1, 2, 3], 2., dtype=np.float32), + np.full([1, 1, 3, 5], 3., dtype=np.float32), + expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32)) + def testPad(self): for dtype, pad_type in itertools.product( self.numeric_types, [np.int32, np.int64]): diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py index 5c24db539bc..75d06706a2d 100644 --- a/tensorflow/compiler/tests/bucketize_op_test.py +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class BucketizationOpTest(xla_test.XLATestCase): def testInt(self): - with self.cached_session() as sess: + with self.session() as sess: p = array_ops.placeholder(dtypes.int32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) @@ -38,7 +38,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) def testFloat(self): - with self.cached_session() as sess: + with self.session() as sess: p = array_ops.placeholder(dtypes.float32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) @@ -48,7 +48,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) def test2DInput(self): - with self.cached_session() as sess: + with self.session() as sess: p = array_ops.placeholder(dtypes.float32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) @@ -58,7 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase): {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) def testInvalidBoundariesOrder(self): - with self.cached_session() as sess: + with self.session() as sess: p = array_ops.placeholder(dtypes.int32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) @@ -67,7 +67,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5, 0]}) def testBoundariesNotList(self): - with self.cached_session(): + with self.session(): with self.assertRaisesRegexp(TypeError, "Expected list.*"): p = array_ops.placeholder(dtypes.int32) with self.test_scope(): diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index ed580f95b6c..3e81f850100 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -23,6 +23,7 @@ def tf_xla_py_test( data = [], main = None, disabled_backends = None, + use_xla_device = True, **kwargs): """Generates py_test targets, one per XLA backend. @@ -47,6 +48,9 @@ def tf_xla_py_test( main: Same as py_test's main attribute. disabled_backends: A list of backends that should not be tested. Supported values include "cpu" and "gpu". If not specified, defaults to None. + use_xla_device: If true then the --test_device argument is set to XLA_CPU + and XLA_GPU for the CPU and GPU tests. Otherwise it is set to CPU and + GPU. **kwargs: keyword arguments passed onto the generated py_test() rules. 
""" if disabled_backends == None: @@ -56,6 +60,14 @@ def tf_xla_py_test( enabled_backends = [b for b in all_backends() if b not in disabled_backends] test_names = [] + + if use_xla_device: + cpu_xla_device = "XLA_CPU" + gpu_xla_device = "XLA_GPU" + else: + cpu_xla_device = "CPU" + gpu_xla_device = "GPU" + for backend in enabled_backends: test_name = "{}_{}".format(name, backend) backend_tags = ["tf_xla_{}".format(backend)] @@ -64,12 +76,12 @@ def tf_xla_py_test( backend_data = [] if backend == "cpu": backend_args += [ - "--test_device=XLA_CPU", + "--test_device=" + cpu_xla_device, "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128", ] elif backend == "gpu": backend_args += [ - "--test_device=XLA_GPU", + "--test_device=" + gpu_xla_device, "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", ] backend_tags += tf_cuda_tests_tags() diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index eec69ea7d2d..ef6df1f0879 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -57,7 +57,7 @@ class CategoricalTest(xla_test.XLATestCase): Returns: Frequencies from sampled classes; shape [batch_size, num_classes]. """ - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): random_seed.set_random_seed(1618) op = random_ops.multinomial(logits, num_samples, output_dtype=dtypes.int32) @@ -80,7 +80,7 @@ class CategoricalTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. 
- with self.cached_session(): + with self.session(): with self.test_scope(): x = rng(dtype, output_dtype) @@ -108,7 +108,7 @@ class CategoricalTest(xla_test.XLATestCase): def testCategoricalIsInRange(self): for dtype in self.float_types: for output_dtype in self.output_dtypes(): - with self.cached_session(): + with self.session(): with self.test_scope(): x = random_ops.multinomial( array_ops.ones(shape=[1, 20], dtype=dtype), 1000, @@ -140,9 +140,10 @@ class CategoricalTest(xla_test.XLATestCase): self.assertLess(chi2, 1e-3) def testStatelessMultinomialIsInRange(self): - for dtype in self.float_types: + for dtype in self.float_types.intersection( + [dtypes.float32, dtypes.bfloat16]): for output_dtype in self.output_dtypes(): - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless_random_ops.stateless_multinomial( @@ -157,7 +158,7 @@ class CategoricalTest(xla_test.XLATestCase): def testDeterminismMultinomial(self): # Stateless values should be equal iff the seeds are equal (roughly) num_samples = 10 - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seeds = [(x, y) for x in range(5) for y in range(5)] * 3 for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], @@ -170,7 +171,7 @@ class CategoricalTest(xla_test.XLATestCase): self.assertEqual(s0 == s1, np.all(v0 == v1)) def testEmpty(self): - with self.cached_session(): + with self.session(): with self.test_scope(): x = random_ops.multinomial( array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32) @@ -178,7 +179,7 @@ class CategoricalTest(xla_test.XLATestCase): self.assertEqual(y.shape, (42, 0)) def testEmptyStateless(self): - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless_random_ops.stateless_multinomial( diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index d1896a50f70..0a739df1163 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -54,7 +54,7 @@ class CholeskyOpTest(xla_test.XLATestCase): def _verifyCholesky(self, x, atol=1e-6): # Verify that LL^T == x. 
- with self.cached_session() as sess: + with self.session() as sess: placeholder = array_ops.placeholder( dtypes.as_dtype(x.dtype), shape=x.shape) with self.test_scope(): diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index ef2d7af69de..64d587268c1 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -38,7 +38,7 @@ class ClusteringTest(xla_test.XLATestCase): val1 = np.array([4, 3, 2, 1], dtype=np.float32) val2 = np.array([5, 6, 7, 8], dtype=np.float32) expected = val1 + val2 - with self.cached_session(): + with self.session(): with self.test_scope(): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") @@ -50,7 +50,7 @@ class ClusteringTest(xla_test.XLATestCase): val1 = np.array([4, 3, 2, 1]).astype(np.float32) val2 = np.array([5, 6, 7, 8]).astype(np.float32) expected = val1 + val2 - with self.cached_session(): + with self.session(): with ops.device(CPU_DEVICE): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") @@ -68,7 +68,7 @@ class ClusteringTest(xla_test.XLATestCase): # where x and z are placed on the CPU and y and w are placed on the XLA # device. If y and w are clustered for compilation, then the graph will # deadlock since the clustered graph will contain a self-loop. - with self.cached_session() as sess: + with self.session() as sess: with ops.device(CPU_DEVICE): x = array_ops.placeholder(dtypes.float32, [2]) with self.test_scope(): @@ -81,7 +81,7 @@ class ClusteringTest(xla_test.XLATestCase): self.assertAllClose(result, [12., 2.], rtol=1e-3) def testHostMemory(self): - with self.cached_session() as sess: + with self.session() as sess: x = array_ops.placeholder(dtypes.int32) with self.test_scope(): y = x + 1 diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 76750decd29..10dd2d6542c 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest class ConcatTest(xla_test.XLATestCase): def testHStack(self): - with self.cached_session(): + with self.session(): p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) with self.test_scope(): @@ -49,7 +49,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(result[4:, :], params[p2]) def testVStack(self): - with self.cached_session(): + with self.session(): p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) with self.test_scope(): @@ -65,7 +65,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(result[:, 4:], params[p2]) def testInt32(self): - with self.cached_session(): + with self.session(): p1 = np.random.rand(2, 3).astype("i") p2 = np.random.rand(2, 3).astype("i") x1 = constant_op.constant(p1) @@ -88,7 +88,7 @@ class ConcatTest(xla_test.XLATestCase): dtype_feed = dtypes.float32 else: dtype_feed = dtype - with self.cached_session(): + with self.session(): p = [] for i in np.arange(num_tensors): input_shape = shape @@ -130,7 +130,7 @@ class ConcatTest(xla_test.XLATestCase): self._testRandom(dtypes.int32) def _testGradientsSimple(self): - with self.cached_session(): + with self.session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -157,7 +157,7 @@ class ConcatTest(xla_test.XLATestCase): 
self._testGradientsSimple() def _testGradientsFirstDim(self): - with self.cached_session(): + with self.session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -185,7 +185,7 @@ class ConcatTest(xla_test.XLATestCase): self._testGradientsFirstDim() def _testGradientsLastDim(self): - with self.cached_session(): + with self.session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -220,7 +220,7 @@ class ConcatTest(xla_test.XLATestCase): # Random dim to concat on concat_dim = np.random.randint(5) concat_dim_sizes = np.random.randint(1, 5, size=num_tensors) - with self.cached_session(): + with self.session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase): def DISABLED_testZeroSize(self): # Verify that concat doesn't crash and burn for zero size inputs np.random.seed(7) - with self.cached_session(): + with self.session(): with self.test_scope(): for shape0 in (), (2,): axis = len(shape0) @@ -276,14 +276,14 @@ class ConcatTest(xla_test.XLATestCase): def testConcatTuple(self): c1 = np.random.rand(4, 4).astype(np.float32) c2 = np.random.rand(4, 4).astype(np.float32) - with self.cached_session(): + with self.session(): with self.test_scope(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t)) def testConcatNoScalars(self): - with self.cached_session(): + with self.session(): with self.test_scope(): scalar = constant_op.constant(7) dim = array_ops.placeholder(dtypes.int32) @@ -297,7 +297,7 @@ class ConcatTest(xla_test.XLATestCase): if "CPU" in self.device: self.skipTest("This test can time out on CPU, so we will just allow " "other backends to catch this specific error.") - with self.cached_session(): + with self.session(): with self.test_scope(): for concat_dim in range(2): params = {} @@ -333,7 +333,7 @@ class ConcatTest(xla_test.XLATestCase): class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): - with self.cached_session(): + with self.session(): with self.test_scope(): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) @@ -347,7 +347,7 @@ class ConcatOffsetTest(xla_test.XLATestCase): class PackTest(xla_test.XLATestCase): def testBasic(self): - with self.cached_session(): + with self.session(): with self.test_scope(): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) @@ -357,7 +357,7 @@ class PackTest(xla_test.XLATestCase): self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) def testScalars(self): - with self.cached_session(): + with self.session(): with self.test_scope(): s0 = constant_op.constant(2, dtypes.int32) s1 = constant_op.constant(3, dtypes.int32) @@ -367,7 +367,7 @@ class PackTest(xla_test.XLATestCase): self.assertAllEqual(ans, [2, 3, 5]) def testEmpty(self): - with self.cached_session(): + with self.session(): with self.test_scope(): s0 = constant_op.constant([[]], dtypes.int32) s1 = constant_op.constant([[]], dtypes.int32) diff --git a/tensorflow/compiler/tests/cond_test.py b/tensorflow/compiler/tests/cond_test.py new file mode 100644 index 00000000000..5963020bbb7 --- /dev/null +++ b/tensorflow/compiler/tests/cond_test.py @@ -0,0 +1,256 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.cond in XLA.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.compiler.xla import xla +from tensorflow.python.eager import function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +@test_util.with_control_flow_v2 +class CondTest(xla_test.XLATestCase): + + def testCondAndTensorArrayInDefun(self): + # TODO(b/132430685): Make test more useful. Also b/129396295, b/127846988 + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + @function.defun + def f(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) + output = control_flow_ops.cond( + constant_op.constant( + True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) + + return output.stack() + + output_t = f() + self.assertAllEqual([5.], self.evaluate(output_t)) + + xla_context.Exit() + + def testCondConstPropagation(self): + with self.session() as sess, self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + x = array_ops.placeholder(dtypes.float32) + p = array_ops.placeholder(dtypes.int32) + + # TODO(b/129021699): Wrapping this in a tf.function does not work. + def if_true(): + # This emits a StridedSlice op which expects the index to be a + # compile-time const. + return x[p] + + def if_false(): + return 5. + + output = control_flow_ops.cond( + constant_op.constant(True), if_true, if_false) + + self.assertAllEqual(1., + sess.run(output, feed_dict={ + x: [0., 1., 2.], + p: 1 + })) + + xla_context.Exit() + + def testCondConstPropagation_xlaCompile(self): + self.skipTest("b/132430685") + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + x = array_ops.placeholder_with_default([0., 1., 2.], shape=[3]) + p = constant_op.constant(1) + + def f(): + # TODO(b/129021699): Wrapping this in a tf.function does not work. + def if_true(): + # This emits a StridedSlice op which expects the index to be a + # compile-time const. + return x[p] + + def if_false(): + return 5. 
+ + return control_flow_ops.cond( + constant_op.constant(True), if_true, if_false) + + output = xla.compile(f) + + self.assertAllEqual(1., self.evaluate(output)) + + xla_context.Exit() + + def testCondConstPropagation_errorMsg(self): + self.skipTest("b/132430685") + with self.session() as sess, self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + x = array_ops.placeholder(dtypes.float32) + p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32) + + # TODO(b/129021699): Wrapping this in a tf.function does not work. + def if_true(): + # This emits a StridedSlice op which expects the index to be a + # compile-time const. + return x[:p] + + def if_false(): + return array_ops.fill([p], 5.) + + output = control_flow_ops.cond( + constant_op.constant(True), if_true, if_false) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "must be a compile-time constant"): + sess.run( + output, feed_dict={ + x: [0., 1., 2.], + }) + + xla_context.Exit() + + def testCondConstPropagation_errorMsg_xlaCompile(self): + with self.session() as sess, self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + x = array_ops.placeholder(dtypes.float32) + p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32) + condition = math_ops.cast( + random_ops.random_uniform([], minval=0, maxval=2, dtype=dtypes.int32), + dtypes.bool) + + def f(): + # TODO(b/129021699): Wrapping this in a tf.function does not work. + def if_true(): + # This emits a StridedSlice op which expects the index to be a + # compile-time const. + return x[:p] + + def if_false(): + return array_ops.fill([p], 5.) + + return control_flow_ops.cond(condition, if_true, if_false) + + output = xla.compile(f) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "must be a compile-time constant"): + sess.run( + output, feed_dict={ + x: [0., 1., 2.], + }) + + xla_context.Exit() + + def testSwitchCaseAndTensorArrayInDefun(self): + self.skipTest("b/127846988") + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + @function.defun + def f(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) + output = control_flow_ops.switch_case( + constant_op.constant(1), { + 0: lambda: ta.write(0, 5.), + 1: lambda: ta.write(0, 10.), + 2: lambda: ta.write(0, 15.), + }) + + return output.stack() + + output_t = f() + self.assertAllEqual([10.], self.evaluate(output_t)) + + xla_context.Exit() + + def testSwitchCaseConstPropagation(self): + self.skipTest("b/127846988") + with self.session() as sess, self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + x = array_ops.placeholder(dtypes.float32) + p = array_ops.placeholder(dtypes.int32) + + def branch0(): + return 5. + + def branch1(): + return 15. + + # TODO(b/129021699): Wrapping this in a tf.function does not work. + def branch2(): + # This emits a StridedSlice op which expects the index to be a + # compile-time const. 
+ return x[p] + + output = control_flow_ops.switch_case( + constant_op.constant(2), { + 0: branch0, + 1: branch1, + 2: branch2, + }) + + self.assertAllEqual(7., + sess.run(output, feed_dict={ + x: [0., 1., 7.], + p: 2, + })) + + xla_context.Exit() + + def testCondNoInputs(self): + """Verifies against `Failed precondition: Expected one input shape`.""" + + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + + for pred in True, False: + cond_out = control_flow_ops.cond( + array_ops.placeholder_with_default(pred, []), + lambda: constant_op.constant(2.), + lambda: constant_op.constant(1.)) + self.assertEqual(int(pred) + 1., self.evaluate(cond_out)) + + xla_context.Exit() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index af00ff287d4..e18e6784317 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -87,7 +87,7 @@ class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.cached_session() as sess: + with self.session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) with self.test_scope(): @@ -288,7 +288,7 @@ class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.cached_session() as sess: + with self.session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): @@ -586,7 +586,7 @@ class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.cached_session() as sess: + with self.session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 01cc1b63928..155b513e0d6 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): for padding in ["SAME", "VALID"]: for stride in [1, 2]: np.random.seed(1) @@ -69,7 +69,7 @@ class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): strides = [1, 1, 1, 1, 1] # Input, output: [batch, depth, height, width, channel] @@ -119,7 +119,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): self.assertAllClose(target, value[n, d, h, w, k]) def testConv3DTransposeSame(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): strides = [1, 2, 2, 2, 1] # Input, output: [batch, depth, height, width, depth] @@ -157,7 +157,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 
self.assertAllClose(target, value[n, d, h, w, k]) def testConv3DTransposeValid(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): strides = [1, 2, 2, 2, 1] # Input, output: [batch, depth, height, width, depth] @@ -217,7 +217,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): np.random.seed(1) # Make it reproducible. x_val = np.random.random_sample(x_shape).astype(np.float64) f_val = np.random.random_sample(f_shape).astype(np.float64) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index b7d08df9f7d..74f16292334 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -92,7 +92,7 @@ class DenseLayerTest(test.TestCase): XlaCompile/XlaRun op pair by XLA. """ - with self.cached_session() as sess: + with self.session() as sess: x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) with jit_scope(): y = layers.dense(x, 3) @@ -113,15 +113,9 @@ class DenseLayerTest(test.TestCase): def testDenseLayerJitScopeUndefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. - - Dense layer uses shape op to get shape of input tensor if its shape is not - fully defined. XLA does not cluster shape op with other operators. But in - experimental_jit_scope, XLA is forced to compile shape op into its own - cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op - pairs. """ - with self.cached_session() as sess: + with self.session() as sess: x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) with jit_scope(): y = layers.dense(x, 3) @@ -136,7 +130,7 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(2, self.countXlaOps(labels)) + self.assertEqual(1, self.countXlaOps(labels)) self.assertFalse(InLabels(labels, "MatMult")) diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 90146e6b27c..c55bc23cf47 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -151,7 +151,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): dtype=data_type).reshape(tensor_in_sizes) x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], dtype=data_type).reshape(filter_in_sizes) - with self.cached_session() as sess: + with self.session() as sess: if data_type == np.float32: tolerance = 1e-4 else: @@ -247,7 +247,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): dtype=np.float32).reshape(tensor_in_sizes) x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], dtype=np.float32).reshape(filter_in_sizes) - with self.cached_session() as sess: + with self.session() as sess: t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32) t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32) with self.test_scope(): @@ -321,7 +321,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(use_xla): - with self.cached_session(): + with self.session(): t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)]) t1 = 
array_ops.placeholder(np.float32, shape=filter_sizes) t2 = array_ops.placeholder(np.float32, shape=output_sizes) @@ -361,7 +361,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(use_xla): - with self.cached_session(): + with self.session(): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 5f01e128f0b..93bc2dd0bf1 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index e89cf975f5d..0e07c2f741e 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import googletest class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): - with self.cached_session() as session: + with self.session() as session: index_placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices ] diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 632eccbb097..d2c459bf1ec 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -104,7 +104,7 @@ class EagerTest(xla_test.XLATestCase): self.assertAllEqual(15, product) # Run some ops graphly - with context.graph_mode(), self.cached_session(): + with context.graph_mode(), self.session(): with self.test_scope(): three = constant_op.constant(3) five = constant_op.constant(5) @@ -341,6 +341,57 @@ class EagerFunctionTest(xla_test.XLATestCase): var = f() self.assertEqual(1.0, var.numpy()) + def testResourceVariableNoInlineReadWrite(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + w = resource_variable_ops.ResourceVariable(0.0) + + @function.defun_with_attributes(attributes={'_noinline': True}) + def g(x): + w.assign(w.read_value() + x) + return v.read_value() + x * w.read_value() + + @function.defun_with_attributes(attributes={'_noinline': True}) + def f(): + return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0) + + # 1 + 1*1 + 1 + 2*3 + 1 + 3*6 + 1 + 4*10 + 1 + 5*15 + self.assertEqual(145.0, f().numpy()) + self.assertEqual(15.0, w.read_value().numpy()) + + def testResourceVariableNoInlineReadOnly(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(10.0) + + @function.defun_with_attributes(attributes={'_noinline': True}) + def g(): + return v.read_value() + + @function.defun_with_attributes(attributes={'_noinline': True}) + def f(): + return g() + g() + g() + g() + g() + + self.assertEqual(50.0, f().numpy()) + + def testResourceVariableNoInlineWriteOnly(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(0.0) + + 
@function.defun_with_attributes(attributes={'_noinline': True}) + def g(x): + v.assign(x) + + @function.defun_with_attributes(attributes={'_noinline': True}) + def f(): + g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0) + + f() + self.assertEqual(5.0, v.read_value().numpy()) + def testUpdateVariable(self): with self.test_scope(): v = resource_variable_ops.ResourceVariable(1.0) @@ -623,6 +674,50 @@ class EagerFunctionTest(xla_test.XLATestCase): r = f(elems) self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) + def testFeedDeviceMemoryToOpExpectingHostMemory(self): + @function.defun + def f(dims, value): + return array_ops.fill(dims, value) + + with self.test_scope(): + x = constant_op.constant([4], dtype=dtypes.int64) + + y = f(x, 3) + self.assertAllEqual([3, 3, 3, 3], y) + + def testRequestNotToCompile(self): + with self.test_scope(): + def f(x): + with ops.device('device:CPU:0'): + y = 2.0 * x + return x, y + + wholly_compiled_f = def_function.function(f) + op_by_op_f = function.defun_with_attributes( + f, attributes={'_XlaCompile': False}) + + x = constant_op.constant([0.0, 2.0], name='data') + + # When function is wholly compiled, all outputs will be on the + # device on which it is run. + r_x, r_y = wholly_compiled_f(x) + self.assertAllEqual([0.0, 2.0], r_x) + self.assertAllEqual([0.0, 4.0], r_y) + if context.executing_eagerly(): + # backing_device is only available for eager tensors. + self.assertRegexpMatches(r_x.backing_device, self.device) + self.assertRegexpMatches(r_y.backing_device, self.device) + + # When function is executed op-by-op, requested devices will be + # respected. + r_x, r_y = op_by_op_f(x) + self.assertAllEqual([0.0, 2.0], r_x) + self.assertAllEqual([0.0, 4.0], r_y) + if context.executing_eagerly(): + # backing_device is only available for eager tensors. + self.assertRegexpMatches(r_x.backing_device, self.device) + self.assertRegexpMatches(r_y.backing_device, 'device:CPU:0') + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. 
diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 37061e91d16..9e9b1f367e2 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -44,7 +44,7 @@ class ExtractImagePatches(xla_test.XLATestCase): strides = [1] + strides + [1] rates = [1] + rates + [1] - with self.cached_session(): + with self.session(): image_placeholder = array_ops.placeholder(dtypes.float32) with self.test_scope(): out_tensor = array_ops.extract_image_patches( diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index 2178c445560..dce5234ae44 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -107,7 +107,7 @@ class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): ], dtype=np.float32) - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): input_placeholder = array_ops.placeholder( dtypes.float32, inputs.shape, name="inputs") @@ -198,7 +198,7 @@ class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], dtype=np.float32) - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): gradient_placeholder = array_ops.placeholder( dtypes.float32, gradients.shape, name="gradients") @@ -306,7 +306,7 @@ class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): ], dtype=np.float32) - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): input_placeholder = array_ops.placeholder( dtypes.float32, inputs.shape, name="inputs") @@ -406,7 +406,7 @@ class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): expected_backprops_wrt_min = 1.0 + 2.0 expected_backprops_wrt_max = 10.0 + 11.0 - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): gradient_placeholder = array_ops.placeholder( dtypes.float32, gradients.shape, name="gradients") diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index 0edd0c35aa2..35df97c5222 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -70,7 +70,7 @@ class FFTTest(xla_test.XLATestCase): data = np.reshape(data.astype(np.float32).view(np.complex64), shape) data = to_32bit(complex_to_input(data)) expected = to_32bit(input_to_expected(data)) - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) @@ -92,7 +92,7 @@ class FFTTest(xla_test.XLATestCase): data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2] expected = np.swapaxes(expected, -1, -2) expected *= window.sum() # scipy divides by window sum - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 91d77d2f791..ba80fa0a0b2 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -31,13 +31,13 @@ from tensorflow.python.platform import test class FIFOQueueTest(xla_test.XLATestCase): def testEnqueue(self): - with self.cached_session(), 
self.test_scope(): + with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) enqueue_op = q.enqueue((10.0,)) enqueue_op.run() def testEnqueueWithShape(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) enqueue_correct_op.run() @@ -46,7 +46,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual(1, q.size().eval()) def testMultipleDequeues(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue([1])) self.evaluate(q.enqueue([2])) @@ -55,7 +55,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) def testQueuesDontShare(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue(1)) q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) @@ -64,13 +64,13 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertAllEqual(self.evaluate(q.dequeue()), 1) def testEnqueueDictWithoutNames(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) with self.assertRaisesRegexp(ValueError, "must have names"): q.enqueue({"a": 12.0}) def testParallelEnqueue(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -95,7 +95,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertItemsEqual(elems, results) def testParallelDequeue(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -119,7 +119,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertItemsEqual(elems, results) def testDequeue(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -133,7 +133,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) elems = [10.0, 20.0, 30.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -163,7 +163,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([elem], result) def testMultiEnqueueAndDequeue(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) elems = [(5, 10.0), (10, 20.0), (15, 30.0)] enqueue_ops = [q.enqueue((x, y)) for x, y in elems] @@ -179,12 +179,12 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([y], y_val) def testQueueSizeEmpty(self): - with 
self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) self.assertEqual([0], q.size().eval()) def testQueueSizeAfterEnqueueAndDequeue(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) enqueue_op = q.enqueue((10.0,)) dequeued_t = q.dequeue() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index b078053cdbd..a2efb413a57 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -111,7 +111,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlwithoutRegularization(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -145,7 +145,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlwithoutRegularization2(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -178,7 +178,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlWithL1(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -212,7 +212,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlWithL1_L2(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -250,7 +250,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): weights will tend to have smaller magnitudes with this parameter set. 
""" for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -284,7 +284,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -331,9 +331,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testEquivAdagradwithoutRegularization(self): steps = 5 for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2) @@ -342,9 +342,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testEquivGradientDescentwithoutRegularization(self): steps = 5 for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): val2, val3 = self.equivGradientDescentTest_GradientDescentPart( steps, dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index a61827c2ae4..585a67eecba 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.cached_session(): + with self.session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.cached_session(): + with self.session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = Func(aval, bval) - with self.cached_session(): + with self.session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -105,7 +105,7 @@ class FunctionTest(xla_test.XLATestCase): def testCompileTimeConstantsInDefun(self): """Tests that XLA handles compile-time constants in defuns.""" - with self.cached_session() as sess: + with self.session() as sess: @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32) def Foo(a, c, d): @@ -140,7 +140,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = aval + bval * 2 - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): a = array_ops.placeholder(dtypes.float32, name="a") b = 
array_ops.placeholder(dtypes.float32, name="b") diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 56a8e1b1667..ad8368a2bfb 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -46,8 +46,10 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): element_count = np.size(x) / int(np.shape(x)[-1]) mean = x_sum / element_count var = x_square_sum / element_count - mean * mean + factor = element_count / max(element_count - 1, 1) + corrected_var = var * factor normalized = (x - mean) / np.sqrt(var + epsilon) - return (normalized * scale + offset), mean, var + return (normalized * scale + offset), mean, var, corrected_var def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format): # Use the following formulas to calculate gradients: @@ -80,10 +82,10 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): offset_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 data_format_src = "NHWC" - y_ref, mean_ref, var_ref = self._reference_training( + y_ref, mean_ref, var_ref, _ = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format_src) - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): # To avoid constant folding x_val_converted = test_utils.ConvertBetweenDataFormats( x_val, data_format_src, data_format) @@ -123,10 +125,12 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): var_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 data_format_src = "NHWC" - y_ref, mean_ref, var_ref = self._reference_training( + # When in training mode, fused_batchnorm applies an implicit Bessel's + # correction. So we have to use the corrected variance here, as well. 
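# ---------------------------------------------------------------------------
# Illustrative aside (not part of the diff): what the Bessel's correction
# mentioned above amounts to. The reference helper computes the biased
# (population) variance and then rescales it by N / (N - 1), which matches
# NumPy's sample variance with ddof=1. Plain-NumPy sketch; values and names
# are illustrative only.
import numpy as np

x = np.array([1.0, 2.0, 4.0, 7.0])
n = x.size
biased_var = np.mean((x - x.mean()) ** 2)       # same as np.var(x, ddof=0)
corrected_var = biased_var * n / max(n - 1, 1)  # Bessel's correction
assert np.isclose(corrected_var, np.var(x, ddof=1))
# ---------------------------------------------------------------------------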
+ y_ref, mean_ref, _, var_ref_corr = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format_src) - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): # To avoid constant folding x_val_converted = test_utils.ConvertBetweenDataFormats( x_val, data_format_src, data_format) @@ -168,7 +172,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): }) self.assertAllClose(mean_val, mean_ref, atol=1e-3) self.assertAllClose(y_val, y_ref_converted, atol=1e-3) - self.assertAllClose(var_val, var_ref, atol=1e-3) + self.assertAllClose(var_val, var_ref_corr, atol=1e-3) @parameterized.named_parameters(*DATA_FORMATS) def testLearning(self, data_format): @@ -209,7 +213,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val, data_format_src, data_format) x_val_converted = test_utils.ConvertBetweenDataFormats( @@ -262,7 +266,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): var_val = np.random.random_sample(scale_shape).astype(np.float32) data_format_src = "NHWC" - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val, data_format_src, data_format) x_val_converted = test_utils.ConvertBetweenDataFormats( diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 7161f4ab339..d1f72b89e83 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): - with self.cached_session(): + with self.session(): paramsp = array_ops.placeholder(params.dtype) indicesp = array_ops.placeholder(indices.dtype) with self.test_scope(): @@ -46,7 +46,7 @@ class GatherNdTest(xla_test.XLATestCase): np.array([[4], [4], [0]], np.int32))) def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): - with self.cached_session(): + with self.session(): params = np.ones((3, 3), dtype=np.float32) indices_empty = np.empty((0, 2), dtype=np.int32) diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index a38e1edafe8..5c2fe3a37d8 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -42,7 +42,7 @@ class GatherTest(xla_test.XLATestCase): return data def testScalar1D(self): - with self.cached_session() as session, self.test_scope(): + with self.session() as session, self.test_scope(): data = np.array([0, 1, 2, 3, 7, 5]) for dtype in self.all_tf_types: for indices in 4, [4], [1, 2, 2, 4, 5]: @@ -55,7 +55,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(np_val, gather_val) def testScalar2D(self): - with self.cached_session() as session, self.test_scope(): + with self.session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in self.all_tf_types: @@ -70,7 +70,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(expected, gather_val) def 
testSimpleTwoD32(self): - with self.cached_session() as session, self.test_scope(): + with self.session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in self.all_tf_types: @@ -89,7 +89,7 @@ class GatherTest(xla_test.XLATestCase): if np.int64 not in self.int_types: return - with self.cached_session() as session, self.test_scope(): + with self.session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) # The indices must be in bounds for any axis. @@ -117,7 +117,7 @@ class GatherTest(xla_test.XLATestCase): for axis in 0, 1, 2, 3, -1, -2: params = self._buildParams(np.random.randn(*shape), dtype) indices = np.random.randint(shape[axis], size=indices_shape) - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): tf_params = array_ops.placeholder(dtype=dtype) tf_indices = constant_op.constant(indices, dtype=dtypes.int32) gather = array_ops.gather(tf_params, tf_indices, axis=axis) @@ -127,7 +127,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(gather_np, gather_value) def testIndicesWithDifferentDimensions(self): - with self.cached_session(): + with self.session(): for dtype in self.numeric_tf_types: params = array_ops.placeholder(dtype=dtype) indices = array_ops.placeholder(dtype=np.int32) @@ -141,7 +141,7 @@ class GatherTest(xla_test.XLATestCase): [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) def testGatherPrecision(self): - with self.cached_session() as session, self.test_scope(): + with self.session() as session, self.test_scope(): data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) indices = np.array([1, 2, 3, 1]) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 42e688174fc..425483c81a5 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -53,7 +53,7 @@ class RGBToHSVTest(xla_test.XLATestCase): inp = GenerateNumpyRandomRGB(shape).astype(nptype) # Convert to HSV and back, as a batch and individually - with self.cached_session() as sess: + with self.session() as sess: batch0 = array_ops.placeholder(nptype, shape=shape) with self.test_scope(): batch1 = image_ops.rgb_to_hsv(batch0) @@ -68,8 +68,8 @@ class RGBToHSVTest(xla_test.XLATestCase): {batch0: inp}) # Verify that processing batch elements together is the same as separate - self.assertAllClose(batch1, join1) - self.assertAllClose(batch2, join2) + self.assertAllCloseAccordingToType(batch1, join1, half_rtol=0.000002) + self.assertAllCloseAccordingToType(batch2, join2, half_rtol=0.000002) self.assertAllCloseAccordingToType( batch2, inp, bfloat16_atol=0.03, half_rtol=0.02) @@ -77,7 +77,7 @@ class RGBToHSVTest(xla_test.XLATestCase): data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] for nptype in self.float_types: rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255. 
- with self.cached_session(): + with self.session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv = image_ops.rgb_to_hsv(placeholder) @@ -96,7 +96,7 @@ class RGBToHSVTest(xla_test.XLATestCase): for r, g, b in rgb_flat ]) hsv_np = hsv_np.reshape(4, 4, 4, 3) - with self.cached_session(): + with self.session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv_op = image_ops.rgb_to_hsv(placeholder) @@ -107,7 +107,7 @@ class RGBToHSVTest(xla_test.XLATestCase): class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): - with self.cached_session(): + with self.session(): x = array_ops.placeholder(x_np.dtype, shape=x_np.shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -145,7 +145,7 @@ class AdjustContrastTest(xla_test.XLATestCase): return y_np def _adjustContrastTf(self, x_np, contrast_factor): - with self.cached_session(): + with self.session(): x = array_ops.placeholder(np.float32) with self.test_scope(): y = image_ops.adjust_contrast(x, contrast_factor) @@ -179,7 +179,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.cached_session(): + with self.session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -197,7 +197,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.cached_session(): + with self.session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -215,7 +215,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.cached_session(): + with self.session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -243,7 +243,7 @@ class AdjustHueTest(xla_test.XLATestCase): return y_v.reshape(x_np.shape) def _adjustHueTf(self, x_np, delta_h): - with self.cached_session(): + with self.session(): x = array_ops.placeholder(dtypes.float32) with self.test_scope(): y = gen_image_ops.adjust_hue(x, delta_h) @@ -323,7 +323,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128] y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape) - with self.cached_session(): + with self.session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) y = self._adjust_saturation(x, saturation_factor) y_tf = y.eval({x: x_np}) @@ -338,7 +338,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.cached_session(): + with self.session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) y = self._adjust_saturation(x, saturation_factor) y_tf = y.eval({x: x_np}) @@ -377,7 +377,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): "gb_same", "rgb_same", ] - with self.cached_session(): + with self.session(): for x_shape in x_shapes: for test_style in test_styles: x_np = np.random.rand(*x_shape) * 255. 
@@ -416,7 +416,7 @@ class ResizeNearestNeighborTest(xla_test.XLATestCase): align_corners=True): if expected is None: self.fail("expected must be specified") - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): image = array_ops.placeholder(image_np.dtype) resized = gen_image_ops.resize_nearest_neighbor( image, target_shape, align_corners=align_corners) @@ -524,7 +524,7 @@ class ResizeBilinearTest(xla_test.XLATestCase): align_corners=True): if expected is None: self.fail("expected must be specified") - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): image = array_ops.placeholder(image_np.dtype) resized = gen_image_ops.resize_bilinear( image, target_shape, align_corners=align_corners) @@ -544,7 +544,7 @@ class ResizeBilinearTest(xla_test.XLATestCase): self.fail("input_shape must be specified") if expected is None: self.fail("expected must be specified") - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): dtype = dtype or np.float32 grads = array_ops.placeholder(np.float32) resized = gen_image_ops.resize_bilinear_grad( @@ -722,7 +722,7 @@ class ResizeBilinearTest(xla_test.XLATestCase): for dtype in self.float_types: input_image = np.array(input_data, dtype=dtype) expected = np.array(expected_data, dtype=dtype) - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): image = array_ops.placeholder(input_image.dtype) resized = gen_image_ops.resize_bilinear( image, [6, 4], align_corners=False) @@ -741,7 +741,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32) - with self.cached_session() as sess: + with self.session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -779,7 +779,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32) - with self.cached_session() as sess: + with self.session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -821,7 +821,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.4, dtype=np.float32) - with self.cached_session() as sess: + with self.session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -864,7 +864,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.4, dtype=np.float32) - with self.cached_session() as sess: + with self.session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -905,7 +905,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = 
np.array(0.1, dtype=np.float32) - with self.cached_session() as sess: + with self.session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index dbea9849e21..29444c19014 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -95,7 +95,12 @@ class JitLaunchTest(test.TestCase): # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun # ops to be constant-folded away, so the check is optional. - def _compare(self, fn, args, require_kernel_launch=True, noinline=None): + def _compare(self, + fn, + args, + require_kernel_launch=True, + name=None, + noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: placeholders = [] feeds = {} @@ -105,7 +110,8 @@ class JitLaunchTest(test.TestCase): placeholders.append(placeholder) feeds[placeholder] = arg - compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline) + compiled_op = CompiledKernel( + fn, *placeholders, name=name, noinline=noinline) direct_op = fn(*placeholders) run_metadata = config_pb2.RunMetadata() @@ -155,17 +161,16 @@ class JitLaunchTest(test.TestCase): # to symbolically execute Bar correctly regardless of whether Bar is inlined # or not. - # TODO(b/36139787): Re-enable this test when noinline works again. # Tests compiled=True and noinline=True. - # self._compare( - # AddOnceReturnTwice, [np.array( - # [[[0.5, -1.0]]], dtype=np.float32)], - # noinline=True) + self._compare( + AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)], + name="AddOnceReturnTwice_inline", + noinline=True) # Tests compiled=True and noinline=False. 
self._compare( - AddOnceReturnTwice, [np.array( - [[[0.5, -1.0]]], dtype=np.float32)], + AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)], + name="AddOnceReturnTwice_noinline", noinline=False) def testOneConstOutput(self): @@ -510,22 +515,6 @@ class ElementWiseFusionTest(test.TestCase): return output, xla_run_count - def testElementWiseClustering(self): - arg0 = np.random.rand(2, 2).astype(np.float32) - arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ( - "--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) - tf_op, tf_count = self.simpleTest(arg0, arg1, - config_pb2.OptimizerOptions.OFF) - self.assertEqual(0, tf_count) - - tfef_op, tfef_count = self.simpleTest(arg0, arg1, - config_pb2.OptimizerOptions.ON_1) - self.assertEqual(2, tfef_count) - - self.assertAllClose(tf_op, tfef_op, rtol=1e-1) - class LazyCompilationTest(test.TestCase): diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py index 0210201fa71..b061ac3c817 100644 --- a/tensorflow/compiler/tests/listdiff_op_test.py +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -33,7 +33,7 @@ class ListDiffTest(xla_test.XLATestCase): def _testListDiff(self, x, y, out, idx): for dtype in [dtypes.int32, dtypes.int64]: for index_dtype in [dtypes.int32, dtypes.int64]: - with self.cached_session(): + with self.session(): x_tensor = ops.convert_to_tensor(x, dtype=dtype) y_tensor = ops.convert_to_tensor(y, dtype=dtype) with self.test_scope(): diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 5dddf6ae4e8..309db2f2f3a 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -58,7 +58,7 @@ class LRNTest(xla_test.XLATestCase): return output def _RunAndVerify(self, dtype): - with self.cached_session(): + with self.session(): # random shape shape = np.random.randint(1, 16, size=4) # Make depth at least 2 to make it meaningful @@ -110,7 +110,7 @@ class LRNTest(xla_test.XLATestCase): alpha = 1.0 * np.random.rand() beta = 1.0 * np.random.rand() - with self.cached_session(): + with self.session(): in_image = constant_op.constant(in_image_vals, shape=shape) out_image = constant_op.constant(out_image_vals, shape=shape) out_grads = constant_op.constant(out_grads_vals, shape=shape) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 776ed899e68..05a32c839e4 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -73,7 +73,7 @@ class LSTMTest(test.TestCase): def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar, pad_scalar): - with self.cached_session() as sess: + with self.session() as sess: num_inputs = 1 num_nodes = 1 @@ -156,7 +156,7 @@ class LSTMTest(test.TestCase): def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar, pad_scalar): - with self.cached_session() as sess: + with self.session() as sess: num_inputs = 1 num_nodes = 1 seq_length = 3 diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 0eec070a906..b216aae6891 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -173,7 +173,7 @@ class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase): ]: pass for dtype in self.float_types: - with self.cached_session(): + with 
self.session(): mat = np.ones(batch_shape + [rows, cols]).astype(dtype) batch_mat = np.tile(mat, batch_shape + [1, 1]) for lower in -1, 0, 1, rows - 1: diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 94cd3eeb317..b348af97c51 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -54,7 +54,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): clean_a = np.tril(a) if lower else np.triu(a) - with self.cached_session() as sess: + with self.session() as sess: placeholder_a = MakePlaceholder(a) placeholder_ca = MakePlaceholder(clean_a) placeholder_b = MakePlaceholder(b) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index 3416f7dbd6b..dc4ccd52624 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -41,7 +41,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -101,7 +101,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testNesterovMomentum(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype) var0_np = np.array([0.1, 0.2], dtype=dtype) @@ -126,7 +126,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index a1c07fce732..8210dff17c0 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -126,7 +126,7 @@ class NAryOpsTest(xla_test.XLATestCase): [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32)) def testOneHot(self): - with self.cached_session() as session, self.test_scope(): + with self.session() as session, self.test_scope(): indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32)) op = array_ops.one_hot(indices, np.int32(4), @@ -148,7 +148,7 @@ class NAryOpsTest(xla_test.XLATestCase): self.assertAllEqual(output, expected) def testSplitV(self): - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): output = session.run( array_ops.split(np.array([[1, 2, 3, 
4], [5, 6, 7, 8], [9, 0, 1, 2]], diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 38cb2f83efc..c3e1f42bd9d 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -29,14 +29,14 @@ from tensorflow.python.platform import googletest class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): output = op() result = session.run(output) self.assertAllClose(result, expected, rtol=1e-3) def testNoOp(self): - with self.cached_session(): + with self.session(): with self.test_scope(): output = control_flow_ops.no_op() # This should not crash. diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py index e2f6de821b5..b7fe5def7b6 100644 --- a/tensorflow/compiler/tests/permute_test.py +++ b/tensorflow/compiler/tests/permute_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class XlaPermuteOpTest(xla_test.XLATestCase): def _runPermuteAndCompare(self, x, src_format, dst_format, expected): - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape) param = {placeholder: x} diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 9671ae0ae97..675ac047a3c 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import googletest class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(4.0) ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 @@ -36,7 +36,7 @@ class PlaceholderTest(xla_test.XLATestCase): self.assertEqual(8.0, self.evaluate(out)) def test_placeholder_with_default_fed(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(4.0) ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index b6cdd38345b..9a008940fa2 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -62,7 +62,7 @@ class Pooling3DTest(xla_test.XLATestCase): # numbers from 1. x = np.arange(1.0, total_size + 1, dtype=np.float32) x = x.reshape(input_sizes) - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): inputs = array_ops.placeholder(dtypes.float32) t = pool_func( inputs, @@ -210,7 +210,7 @@ class Pooling3DTest(xla_test.XLATestCase): strides = [1] + strides + [1] total_size = np.prod(input_sizes) x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) - with self.cached_session() as sess: + with self.session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). 
with ops.device("CPU"): diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index d03bd4fdbb7..bcc5ce77ec6 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -89,7 +89,7 @@ class PoolingTest(xla_test.XLATestCase): # numbers from 1. x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32) x = x.reshape(input_sizes) - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): inputs = array_ops.placeholder(dtypes.float32) t = inputs @@ -324,7 +324,7 @@ class PoolGradTest(xla_test.XLATestCase): # TODO(b/74222344): Fix nan handling for max pool grad. # x[np.random.choice(total_size)] = np.nan x = x.reshape(input_sizes) - with self.cached_session() as sess: + with self.session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). with ops.device(self.CPU_DEVICE): @@ -454,7 +454,7 @@ class PoolGradTest(xla_test.XLATestCase): """Verifies the output values of the pooling function. Args: - pool_func: Pooling function to be called, e.g., tf.nn.max_pool + pool_func: Pooling function to be called, e.g., tf.nn.max_pool2d pool_grad_func: Corresponding pooling gradient function. input_sizes: Input tensor dimensions. ksize: The kernel size dimensions diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py index 5b35c200277..119b15b6b2c 100644 --- a/tensorflow/compiler/tests/powersign_test.py +++ b/tensorflow/compiler/tests/powersign_test.py @@ -64,7 +64,7 @@ class PowerSignTest(xla_test.XLATestCase): base=math.e, beta=0.9): for dtype in self.float_types: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): # Initialize variables for numpy implementation. 
m0, m1 = 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py index 63cc51a4701..1993d4ecb19 100644 --- a/tensorflow/compiler/tests/proximal_adagrad_test.py +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_adagrad class ProximalAdagradOptimizerTest(xla_test.XLATestCase): def testResourceProximalAdagradwithoutRegularization(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -62,7 +62,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertEqual(2, len(opt_vars)) def testProximalAdagradwithoutRegularization2(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -86,7 +86,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.715679, 2.433051]), self.evaluate(var1)) def testProximalAdagradWithL1(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -110,7 +110,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([2.959304, 1.029232]), self.evaluate(var1)) def testProximalAdagradWithL1_L2(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -153,7 +153,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): return self.evaluate(var0), self.evaluate(var1) def testEquivAdagradwithoutRegularization(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): val0, val1 = self.applyOptimizer( proximal_adagrad.ProximalAdagradOptimizer( 3.0, @@ -161,7 +161,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): val2, val3 = self.applyOptimizer( adagrad.AdagradOptimizer( 3.0, initial_accumulator_value=0.1)) diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py index 5aec433be76..ce97fd1a5ba 100644 --- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_gradient_descent class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): def testResourceProximalGradientDescentwithoutRegularization(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -53,7 +53,7 @@ class 
ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([-0.09, -0.18]), self.evaluate(var1)) def testProximalGradientDescentwithoutRegularization2(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -75,7 +75,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.91, 2.82]), self.evaluate(var1)) def testProximalGradientDescentWithL1(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -97,7 +97,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.67, 2.37]), self.evaluate(var1)) def testProximalGradientDescentWithL1_L2(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -137,14 +137,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): return self.evaluate(var0), self.evaluate(var1) def testEquivGradientDescentwithoutRegularization(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): val0, val1 = self.applyOptimizer( proximal_gradient_descent.ProximalGradientDescentOptimizer( 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): val2, val3 = self.applyOptimizer( gradient_descent.GradientDescentOptimizer(3.0)) diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index b4d4193e35f..5fcf254db82 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -71,7 +71,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): x_np = np.random.uniform( low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) - with self.cached_session() as sess: + with self.session() as sess: x_tf = array_ops.placeholder(dtype) with self.test_scope(): q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py index cd9b728ab31..9a1d29c0092 100644 --- a/tensorflow/compiler/tests/quantized_ops_test.py +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -36,7 +36,7 @@ class QuantizedOpsTest(xla_test.XLATestCase): # Verify that quantized types can be clustered by XLA. 
def testQuantizedTypeRoundtrip(self): - with self.cached_session() as session: + with self.session() as session: for dtype in self.quantized_tf_types: in_values = np.array([1, 2, 3, 4, 5, 6]) expected = [[1, 2], [3, 4], [5, 6]] @@ -82,7 +82,7 @@ class DeuantizedOpsTest(xla_test.XLATestCase): num_rows = 100 num_columns = 3547 random_input = np.random.normal(128.0, 10.0, [num_rows, num_columns]) - with self.cached_session() as session: + with self.session() as session: with ops.device("CPU"): test_input = ops.convert_to_tensor(random_input, dtype=dtypes.float32) transposed_input = array_ops.transpose(test_input, [1, 0]) @@ -95,7 +95,7 @@ class DeuantizedOpsTest(xla_test.XLATestCase): quantized_output = array_ops.slice(transposed_quantized_output, [0, 0], [num_rows, num_columns]) - value = session.run(quantized_output) + value = session.run(quantized_output) self.assertAllClose(value, random_input, 1.0) diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 34f2465ba63..4ac6a82145d 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -36,11 +36,11 @@ class RandomOpsTest(xla_test.XLATestCase): def _random_types(self): return set(self.numeric_types) - set( - self.complex_types) - {np.uint8, np.int8} + self.complex_types) - {np.uint64, np.int64, np.uint8, np.int8} def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): x = rng(dtype) @@ -72,6 +72,30 @@ class RandomOpsTest(xla_test.XLATestCase): for dtype in self._random_types() & self.float_types: self._testRngIsNotConstant(rng, dtype) + def testRandomNormalMean(self): + for dtype in self._random_types() & self.float_types: + with self.session(): + with self.test_scope(): + normal = random_ops.random_normal([1024], + dtype=dtype, + mean=1.4, + stddev=1.2) + mean = math_ops.reduce_mean(normal) + x = self.evaluate(mean) + self.assertAllClose(x, 1.4, rtol=1e-1, atol=1e-1) + + def testRandomNormalVariance(self): + for dtype in self._random_types() & self.float_types: + with self.session(): + with self.test_scope(): + normal = random_ops.random_normal([1024], + dtype=dtype, + mean=2.3, + stddev=2.0) + variance = math_ops.reduce_variance(normal) + x = self.evaluate(variance) + self.assertAllClose(x, 4.0, rtol=1e-1, atol=1e-1) + def testRandomUniformIsInRange(self): for dtype in self._random_types(): # TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is @@ -79,7 +103,7 @@ class RandomOpsTest(xla_test.XLATestCase): if (self.device in ["XLA_GPU", "XLA_CPU" ]) and (dtype in [dtypes.bfloat16, dtypes.half]): continue - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): x = random_ops.random_uniform( shape=[1000], dtype=dtype, minval=-2, maxval=33) @@ -92,14 +116,13 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) - for dtype in self._random_types() & self.float_types: - self._testRngIsNotConstant(rng, dtype) + self._testRngIsNotConstant(rng, dtypes.float32) def testTruncatedNormalIsInRange(self): count = 10000000 # TODO(b/34339814): make this test work with 16 bit float types. 
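A rough sanity check, in plain NumPy and not part of the change, of the tolerances used by the new testRandomNormalMean case above: with 1024 samples the standard error of the sample mean is stddev / sqrt(n), so the combined atol + rtol tolerance is several standard errors wide.

```python
import numpy as np

n, mean, stddev = 1024, 1.4, 1.2           # parameters from testRandomNormalMean
se_mean = stddev / np.sqrt(n)              # ~0.0375
tolerance = 0.1 + 0.1 * abs(mean)          # assertAllClose: atol + rtol * |target|

samples = np.random.normal(loc=mean, scale=stddev, size=n)
print(abs(samples.mean() - mean))          # typically a fraction of the tolerance
print(tolerance / se_mean)                 # ~6.4 standard errors of headroom
```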
for dtype in self._random_types() & {dtypes.float32, dtypes.float64}: - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) y = self.evaluate(x) @@ -144,7 +167,7 @@ class RandomOpsTest(xla_test.XLATestCase): self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3) def testShuffle1d(self): - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) @@ -155,7 +178,7 @@ class RandomOpsTest(xla_test.XLATestCase): self.assertAllEqual(set(result), set(expected)) def testShuffle2d(self): - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): x = array_ops.diag(math_ops.range(20)) shuffle = random_ops.random_shuffle(x) diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index e8fc81bbb54..a39f633858a 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -45,7 +45,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" for test_input in test_inputs: - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): a = array_ops.placeholder(dtype) index = array_ops.placeholder(index_dtype) @@ -190,7 +190,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase): """ for test_input in test_inputs: - with self.cached_session() as sess: + with self.session() as sess: with self.test_scope(): a = array_ops.placeholder(dtype) index = array_ops.placeholder(dtypes.int32) diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index ff20ea3f428..04dc65fb5e8 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -32,7 +32,7 @@ class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, reducer, **kwargs): - with self.cached_session(): + with self.session(): placeholder = array_ops.placeholder(operand.dtype) with self.test_scope(): output = xla.reduce_window(placeholder, init, reducer, **kwargs) diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py index 96e0b074754..4960666396e 100644 --- a/tensorflow/compiler/tests/reshape_op_test.py +++ b/tensorflow/compiler/tests/reshape_op_test.py @@ -33,7 +33,7 @@ class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): ('64_bit_index', dtypes.int64)) def testBasic(self, index_dtype): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): shape = constant_op.constant([3, 2], dtype=index_dtype) diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index 392290fd92d..7dc323b0ab5 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -51,7 +51,7 @@ class ReverseOpsTest(xla_test.XLATestCase): def _AssertReverseEqual(self, revdims, shape): np.random.seed(120) pval = np.random.randint(0, 100, size=shape).astype(float) - with self.cached_session(): + with self.session(): with self.test_scope(): p = array_ops.placeholder(dtypes.int32, shape=shape) axis = 
constant_op.constant( diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index abc822ef363..80c332b1ce3 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -35,7 +35,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): seq_lengths, truth, expected_err_re=None): - with self.cached_session(): + with self.session(): p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) with self.test_scope(): diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index dc3e90b4afa..961103e83f2 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -55,7 +55,7 @@ class RmspropTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: for centered in [False, True]: - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): # Initialize variables for numpy implementation. var0_np = np.array([1.0, 2.0], dtype=dtype) grads0_np = np.array([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 17639bd8a75..7c36f8b13ca 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -78,7 +78,7 @@ class CumsumTest(xla_test.XLATestCase): def _compare(self, x, axis, exclusive, reverse): np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): p = array_ops.placeholder(x.dtype) tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( feed_dict={p: x}) @@ -100,7 +100,7 @@ class CumsumTest(xla_test.XLATestCase): for dtype in self.valid_dtypes: x = np.arange(1, 6).reshape([5]).astype(dtype) for axis_dtype in self.axis_dtypes(): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): p = array_ops.placeholder(x.dtype) axis = constant_op.constant(0, axis_dtype) math_ops.cumsum(p, axis).eval(feed_dict={p: x}) @@ -131,7 +131,7 @@ class CumsumTest(xla_test.XLATestCase): def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): input_tensor = ops.convert_to_tensor(x) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, @@ -156,7 +156,7 @@ class CumprodTest(xla_test.XLATestCase): def _compare(self, x, axis, exclusive, reverse): np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): p = array_ops.placeholder(x.dtype) prod = math_ops.cumprod(p, axis, exclusive, reverse) tf_out = prod.eval(feed_dict={p: x}) @@ -178,7 +178,7 @@ class CumprodTest(xla_test.XLATestCase): for dtype in self.valid_dtypes: x = np.arange(1, 6).reshape([5]).astype(dtype) for axis_dtype in self.axis_dtypes(): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): p = array_ops.placeholder(x.dtype) axis = constant_op.constant(0, axis_dtype) math_ops.cumprod(x, axis).eval(feed_dict={p: x}) @@ -209,7 +209,7 @@ class CumprodTest(xla_test.XLATestCase): def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) - with self.cached_session(), self.test_scope(): + 
with self.session(), self.test_scope(): input_tensor = ops.convert_to_tensor(x) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index a9a87b8fb31..f1559a236d2 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -119,7 +119,7 @@ class ScatterNdTest(xla_test.XLATestCase): self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) def _runScatterNd(self, indices, updates, shape): - with self.cached_session(): + with self.session(): updates_placeholder = array_ops.placeholder(updates.dtype) indices_placeholder = array_ops.placeholder(indices.dtype) with self.test_scope(): diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 287bb0d84e2..500617bc38b 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -32,7 +32,7 @@ class SegmentReductionOpsTest(xla_test.XLATestCase): """Test cases for segment reduction ops.""" def _segmentReduction(self, op, data, indices, num_segments): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): d = array_ops.placeholder(data.dtype, shape=data.shape) if isinstance(indices, int): i = array_ops.placeholder(np.int32, shape=[]) diff --git a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py index cfb5c82b22e..0c1a1d145d4 100644 --- a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py +++ b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py @@ -38,7 +38,7 @@ class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase): n = shape[-1] e_np, _ = np.linalg.eigh(x_np) - with self.cached_session() as sess: + with self.session() as sess: x_tf = array_ops.placeholder(dtype) with self.test_scope(): e, v = linalg_ops.self_adjoint_eig(x_tf) diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 2c611a959e1..b7784062e82 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -29,7 +29,7 @@ class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.slice(i, [2], [4]) @@ -42,7 +42,7 @@ class SliceTest(xla_test.XLATestCase): def testZeroSlice(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[2]) with self.test_scope(): o = array_ops.slice(i, [0], [0]) @@ -55,7 +55,7 @@ class SliceTest(xla_test.XLATestCase): def test3D(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) with self.test_scope(): o = array_ops.slice(i, [1, 2, 2], [1, 1, 4]) @@ -77,7 +77,7 @@ class SliceTest(xla_test.XLATestCase): def test3DWithDynamicBegin(self): """Tests a slice where the start offset is not known at compile time.""" for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) begin = array_ops.placeholder(dtypes.int32, shape=[3]) with self.test_scope(): @@ -101,7 +101,7 @@ class 
SliceTest(xla_test.XLATestCase): def test3DWithDynamicBeginAndNegativeSize(self): """Tests a slice where `begin` is fed dynamically and `size` contains -1.""" for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) begin = array_ops.placeholder(dtypes.int32, shape=[3]) with self.test_scope(): @@ -127,7 +127,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.strided_slice(i, [2], [6], [2]) @@ -140,7 +140,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test1DNegativeStride(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.strided_slice(i, [6], [2], [-2]) @@ -153,7 +153,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test2DDegenerate(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): o = array_ops.strided_slice(i, [-1, 0], [0, 3]) @@ -167,7 +167,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test2DDegenerateNegativeStride(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1]) @@ -181,7 +181,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test3D(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) with self.test_scope(): o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2]) @@ -202,7 +202,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test3DNegativeStride(self): for dtype in self.numeric_types: - with self.cached_session(): + with self.session(): i = array_ops.placeholder(dtype, shape=[3, 4, 10]) with self.test_scope(): o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2]) diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 3e499c2fb17..d50fdec7c63 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import test class XlaSortOpTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -134,7 +134,7 @@ class XlaSortOpTest(xla_test.XLATestCase): if bfloat16 not in self.numeric_types: return - with self.cached_session() as sess: + with self.session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=4) @@ -152,7 +152,7 @@ class XlaSortOpTest(xla_test.XLATestCase): if bfloat16 not in self.numeric_types: return - with self.cached_session() as sess: + with self.session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=6) @@ -166,6 +166,28 @@ class XlaSortOpTest(xla_test.XLATestCase): dtype=bfloat16), results[0]) self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) + def testInTopK(self): + supported_types = 
set([np.int32, np.int64]) + for dtype in supported_types.intersection(self.numeric_types): + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + batch = 16 + for x in [np.arange(batch * array_size)]: + np.random.shuffle(x) + x = np.reshape(x, [batch, array_size]) + y = np.random.randint(0, array_size, size=batch) + for k in k_options: + indices = x.argsort(axis=1)[::, -1:-k - 1:-1] + expected = [y[i] in indices[i] for i in range(batch)] + + def in_topk(predictions, targets, k=k): + return nn_ops.in_top_k(predictions, targets, k) + + self._assertOpOutputMatchesExpected( + in_topk, + [x.astype(np.float32), y.astype(dtype)], + expected=[expected]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index 33b84cec718..74f5f7bb48f 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -72,7 +72,7 @@ class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): for dtype in self.float_types: # outputs = space_to_batch(inputs) placeholder = array_ops.placeholder(dtype) @@ -155,7 +155,7 @@ class SpaceToBatchNDTest(xla_test.XLATestCase): def _testPad(self, inputs, block_shape, paddings, outputs): block_shape = np.array(block_shape) paddings = np.array(paddings).reshape((len(block_shape), 2)) - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): for dtype in self.float_types: # TODO(b/68813416): Skip bfloat16's as the input type for direct is # float32 and results in a mismatch, while making testDirect provide the diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py index 07afd1ab3fb..dbfdc3b7247 100644 --- a/tensorflow/compiler/tests/sparse_to_dense_op_test.py +++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py @@ -45,32 +45,32 @@ def _SparseToDense(sparse_indices, class SparseToDenseTest(xla_test.XLATestCase): def testInt(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1, 0) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def testFloat(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) self.assertAllClose(np_ans, tf_ans) def testSetValue(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1) np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def testSetSingleValue(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1, -1) np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def test2d(self): # pylint: disable=bad-whitespace - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1) np_ans = np.array([[-1, -1, -1, -1], [-1, -1, -1, 1], @@ -78,38 +78,44 
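The expected values in the new testInTopK case above are built by slicing an argsort; a standalone NumPy illustration of the same in_top_k semantics (the small arrays here are made up for the example):

```python
import numpy as np

predictions = np.array([[0.1, 0.9, 0.3],
                        [0.5, 0.2, 0.4]], dtype=np.float32)
targets = np.array([1, 2], dtype=np.int32)
k = 2

# targets[i] is "in the top k" iff it is among the k largest predictions of row i.
top_k_indices = predictions.argsort(axis=1)[:, -k:]
expected = [targets[i] in top_k_indices[i] for i in range(len(targets))]
print(expected)  # [True, True]
```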
@@ class SparseToDenseTest(xla_test.XLATestCase): self.assertAllClose(np_ans, tf_ans) def testZeroDefault(self): - with self.cached_session(): + with self.session(): x = sparse_ops.sparse_to_dense(2, [4], 7).eval() self.assertAllEqual(x, [0, 0, 7, 0]) def test3d(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1) np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 np_ans[1, 3, 0] = 1 np_ans[2, 0, 1] = 1 self.assertAllClose(np_ans, tf_ans) + def testDegenerateIndexMatrix(self): + with self.session(), self.test_scope(): + tf_ans = _SparseToDense([[2], [3], [4], [5], [6], [7], [8], [9]], [10], + [1, 2, 3, 4, 5, 6, 7, 8], -1) + self.assertAllClose([-1, -1, 1, 2, 3, 4, 5, 6, 7, 8], tf_ans) + def testBadShape(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): _SparseToDense([1, 3], [[5], [3]], 1, -1) def testBadValue(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): with self.assertRaisesOpError( r"sparse_values has incorrect shape \[2,1\], " r"should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [[5], [3]], -1) def testBadNumValues(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): with self.assertRaisesOpError( r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [1, 2, 3], -1) def testBadDefault(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): with self.assertRaisesOpError("default_value should be a scalar"): _SparseToDense([1, 3], [5], [1, 2], [0]) diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 720595a159e..0c13d632997 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.compiler.xla import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -31,73 +32,93 @@ from tensorflow.python.platform import test class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): - with self.cached_session(), self.test_scope(): - size = array_ops.placeholder(dtypes.int32) + with self.session(), self.test_scope(): + v = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops.stack_push_v2(h, v) - with ops.control_dependencies([c]): - c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) - self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) + + def fn(): + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, v) + with ops.control_dependencies([c]): + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) + return c1 + + self.assertAllClose([[4.0, 5.0]], + xla.compile(fn)[0].eval({v: [[4.0, 5.0]]})) def testStackPushPopSwap(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): a = np.arange(2000) x = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops.stack_push_v2(h, x, 
swap_memory=True) - with ops.control_dependencies([c]): - c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) - self.assertAllClose(a, c1.eval({x: a})) + + def fn(): + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, x, swap_memory=True) + with ops.control_dependencies([c]): + return gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) + + self.assertAllClose(a, xla.compile(fn)[0].eval({x: a})) def testMultiStack(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): v = array_ops.placeholder(dtypes.float32) - h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops.stack_push_v2(h1, v) - with ops.control_dependencies([c1]): - c1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) - h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="bar") - c2 = gen_data_flow_ops.stack_push_v2(h2, 5.0) - with ops.control_dependencies([c2]): - c2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) - r = c1 + c2 - self.assertAllClose(9.0, r.eval({v: 4.0})) + + def fn(): + h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops.stack_push_v2(h1, v) + with ops.control_dependencies([c1]): + c1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="bar") + c2 = gen_data_flow_ops.stack_push_v2(h2, 5.0) + with ops.control_dependencies([c2]): + c2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) + return c1 + c2 + + self.assertAllClose(9.0, xla.compile(fn)[0].eval({v: 4.0})) def testSameNameStacks(self): """Different stacks with the same name do not interfere.""" - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v1 = array_ops.placeholder(dtypes.float32) v2 = array_ops.placeholder(dtypes.float32) - h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops.stack_push_v2(h1, v1) - with ops.control_dependencies([c1]): - c2 = gen_data_flow_ops.stack_push_v2(h2, v2) - with ops.control_dependencies([c2]): - pop1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) - pop2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) + def fn(): + h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - out1, out2 = sess.run([pop1, pop2], {v1: 4.0, v2: 5.0}) + c1 = gen_data_flow_ops.stack_push_v2(h1, v1) + with ops.control_dependencies([c1]): + c2 = gen_data_flow_ops.stack_push_v2(h2, v2) + with ops.control_dependencies([c2]): + pop1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + pop2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) + return [pop1, pop2] + + [pop1_compiled, pop2_compiled] = xla.compile(fn) + out1, out2 = sess.run([pop1_compiled, pop2_compiled], {v1: 4.0, v2: 5.0}) self.assertAllClose(out1, 4.0) self.assertAllClose(out2, 5.0) def testCloseStack(self): - with self.cached_session() as sess, self.test_scope(): - size = array_ops.placeholder(dtypes.int32) - h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops.stack_close_v2(h) - sess.run(c1, {size: 5}) + with self.session() as sess, self.test_scope(): + + def fn(): + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + gen_data_flow_ops.stack_close_v2(h) + + sess.run(xla.compile(fn)) def 
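A standalone sketch of the xla.compile pattern these stack tests now follow: the stack ops are built inside a local fn and compiled as one XLA computation, and the compiled output is evaluated with a feed. The class, test, and stack names below are placeholders; everything else mirrors the hunks above.

```python
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops


class StackSketchTest(xla_test.XLATestCase):

  def testPushThenPop(self):
    with self.session(), self.test_scope():
      v = array_ops.placeholder(dtypes.float32)

      def fn():
        # Build every stack op inside fn so xla.compile clusters them together.
        h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="sketch")
        c = gen_data_flow_ops.stack_push_v2(h, v)
        with ops.control_dependencies([c]):
          return gen_data_flow_ops.stack_pop_v2(h, dtypes.float32)

      self.assertAllClose([[1.0, 2.0]],
                          xla.compile(fn)[0].eval({v: [[1.0, 2.0]]}))
```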
testPushCloseStack(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops.stack_push_v2(h, v) - with ops.control_dependencies([c]): - c1 = gen_data_flow_ops.stack_close_v2(h) - sess.run(c1, {v: [[4.0, 5.0]]}) + + def fn(): + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, v) + with ops.control_dependencies([c]): + gen_data_flow_ops.stack_close_v2(h) + + sess.run(xla.compile(fn), {v: [[4.0, 5.0]]}) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py index f0535579bf2..205bb501084 100644 --- a/tensorflow/compiler/tests/stateful_random_ops_test.py +++ b/tensorflow/compiler/tests/stateful_random_ops_test.py @@ -18,8 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -29,6 +28,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.kernel_tests.random import util as \ +random_test_util from tensorflow.python.ops import gen_stateful_random_ops from tensorflow.python.ops import stateful_random_ops as \ random @@ -50,25 +51,33 @@ def xla_device_name(): return str(name) -class StatefulRandomOpsTest(xla_test.XLATestCase): +ALGS = [random.RNG_ALG_PHILOX, random.RNG_ALG_THREEFRY] +INTS = [dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64] + + +# TODO(wangpeng): use parametrized tests to test both ThreeFry and Philox +class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): """Test cases for stateful random-number generator operators.""" + _ints = INTS + _floats = [dtypes.bfloat16, dtypes.float32] + + @parameterized.parameters(ALGS) @test_util.run_v2_only - def testSimple(self): - """A simple test. - """ + def testSimple(self, alg): + """A simple test.""" with ops.device(xla_device_name()): - gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) + gen = random.Generator.from_seed(seed=0, alg=alg) gen.normal(shape=(3,)) gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32) gen.uniform_full_int(shape=(3,)) + @parameterized.parameters(ALGS) @test_util.run_v2_only - def testDefun(self): - """Test for defun. 
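The stateful random tests above switch from the Generator(seed=..., algorithm=...) constructor to Generator.from_seed(seed=..., alg=...) and run each case under both algorithms. A small eager-mode sketch of that API, assuming a TF 2.x runtime as the tests do (run_v2_only):

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import stateful_random_ops as random

for alg in [random.RNG_ALG_PHILOX, random.RNG_ALG_THREEFRY]:
  gen = random.Generator.from_seed(seed=1234, alg=alg)
  print(gen.normal(shape=(3,)))                                             # floats
  print(gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32))  # ints
```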
- """ + def testDefun(self, alg): + """Test for defun.""" with ops.device(xla_device_name()): - gen = random.Generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) + gen = random.Generator.from_seed(seed=0, alg=alg) @def_function.function def f(): x = gen.normal(shape=(3,)) @@ -77,6 +86,26 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): return (x, y, z) f() + def _compareToKnownOutputs(self, g, counter, key, expect): + """Compares against known outputs for specific counter and key inputs.""" + def uint32s_to_uint64(a, b): + return b << 32 | a + + def uint32s_to_uint64s(ls): + return [uint32s_to_uint64(ls[2 * i], ls[2 * i + 1]) + for i in range(len(ls) // 2)] + + ctr_len = len(counter) + counter = uint32s_to_uint64s(counter) + key = uint32s_to_uint64s(key) + state = counter + key + g.reset(state) + got = g.uniform_full_int(shape=(ctr_len,), dtype=dtypes.uint32) + self.assertAllEqual(expect, got) + g.reset(state) + got = g.uniform_full_int(shape=(ctr_len // 2,), dtype=dtypes.uint64) + self.assertAllEqual(uint32s_to_uint64s(expect), got) + @test_util.run_v2_only def testThreefry2x32(self): """Tests ThreeFry2x32 conforms to known results. @@ -86,48 +115,108 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): # which is in turn based on # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32 - def uint32s_to_uint64(a, b): - return b << 32 | a - - def verify(counter1, counter2, key1, key2, expect1, expect2): - counter = uint32s_to_uint64(counter1, counter2) - key = uint32s_to_uint64(key1, key2) - random.get_global_generator().reset([counter, key]) - got = random.get_global_generator().uniform_full_int( - shape=(2,), dtype=dtypes.uint32) - expect = [expect1, expect2] - self.assertAllEqual(expect, got) - random.get_global_generator().reset([counter, key]) - got = random.get_global_generator().uniform_full_int( - shape=(), dtype=dtypes.uint64) - self.assertAllEqual(uint32s_to_uint64(*expect), got) - with ops.device(xla_device_name()): - random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY) - verify(0x00000000, 0x00000000, 0x00000000, 0x00000000, - 0x6b200159, 0x99ba4efe) - verify(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, - 0x1cb996fc, 0xbb002be7) - verify(0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344, - 0xc4923a9c, 0x483df7a0) + g = random.Generator.from_seed(seed=0, alg=random.RNG_ALG_THREEFRY) + self._compareToKnownOutputs( + g, + [0x00000000, 0x00000000], [0x00000000, 0x00000000], + [0x6b200159, 0x99ba4efe]) + self._compareToKnownOutputs( + g, + [0xffffffff, 0xffffffff], [0xffffffff, 0xffffffff], + [0x1cb996fc, 0xbb002be7]) + self._compareToKnownOutputs( + g, + [0x243f6a88, 0x85a308d3], [0x13198a2e, 0x03707344], + [0xc4923a9c, 0x483df7a0]) @test_util.run_v2_only - def testNewState(self): - """Tests that the new state is correct. + def testPhilox4x32(self): + """Tests Philox4x32 conforms to known results. 
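The uint32s_to_uint64 helper in _compareToKnownOutputs packs a low and a high 32-bit word into one 64-bit state word; a quick standalone check of that packing against the first ThreeFry known-answer pair above:

```python
def uint32s_to_uint64(a, b):
  # Low word in a, high word in b, matching the helper above.
  return b << 32 | a


# counter = key = (0, 0) maps to the outputs (0x6b200159, 0x99ba4efe); packed
# as a single uint64 draw they read as 0x99ba4efe6b200159.
assert uint32s_to_uint64(0x6b200159, 0x99ba4efe) == 0x99ba4efe6b200159
```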
+ """ + # Based on + # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_philox.cpp#L50-L52 + + with ops.device(xla_device_name()): + g = random.Generator.from_seed(seed=0, alg=random.RNG_ALG_PHILOX) + self._compareToKnownOutputs( + g, + [0x00000000, 0x00000000, 0x00000000, 0x00000000], + [0x00000000, 0x00000000], + [0x6627e8d5, 0xe169c58d, 0xbc57ac4c, 0x9b00dbd8]) + self._compareToKnownOutputs( + g, + [0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff], + [0xffffffff, 0xffffffff], + [0x408f276d, 0x41c83b0e, 0xa20bc7c6, 0x6d5451fd]) + self._compareToKnownOutputs( + g, + [0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344], + [0xa4093822, 0x299f31d0], + [0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1]) + + @test_util.run_v2_only + def testNewStateThreeFry(self): + """Tests that the new state is correct (for ThreeFry). """ with ops.device(xla_device_name()): counter = 57 key = 0x1234 size = 46 - seed = [counter, key] - gen = random.Generator( - seed=seed, algorithm=random.RNG_ALG_THREEFRY) + state = [counter, key] + gen = random.Generator(state=state, alg=random.RNG_ALG_THREEFRY) gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32) self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value()) - gen.reset(seed=seed) + gen.reset(state) gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64) self.assertAllEqual([counter+size, key], gen.state.read_value()) + @test_util.run_v2_only + def testNewStatePhilox(self): + """Tests that the new state is correct (for Philox). + """ + with ops.device(xla_device_name()): + counter_low = 57 + counter_high = 283 + key = 0x1234 + size = 47 + state = [counter_low, counter_high, key] + gen = random.Generator(state=state, alg=random.RNG_ALG_PHILOX) + gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32) + self.assertAllEqual([counter_low+(size+3)//4, counter_high, key], + gen.state.read_value()) + gen.reset(state) + gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64) + self.assertAllEqual([counter_low+(size+1)//2, counter_high, key], + gen.state.read_value()) + # Tests that large counter_low will correctly overflows to counter_high + counter_low = -1 # same as 0xffffffffffffffff + counter_high = 283 + size = 47 + state = [counter_low, counter_high, key] + gen = random.Generator(state=state, alg=random.RNG_ALG_PHILOX) + gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32) + self.assertAllEqual([(size+3)//4-1, counter_high+1, key], + gen.state.read_value()) + gen.reset(state) + gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64) + self.assertAllEqual([(size+1)//2-1, counter_high+1, key], + gen.state.read_value()) + + @parameterized.parameters(INTS) + @test_util.run_v2_only + def testXLAEqualsCPU(self, dtype): + """Tests that XLA and CPU kernels generate the same integers.""" + seed = 1234 + shape = [315, 49] + with ops.device("/device:CPU:0"): + cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX) + .uniform_full_int(shape=shape, dtype=dtype)) + with ops.device(xla_device_name()): + xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX) + .uniform_full_int(shape=shape, dtype=dtype)) + self.assertAllEqual(cpu, xla) + def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. 
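The overflow expectations in testNewStatePhilox follow from Philox producing four uint32 outputs per counter increment; the arithmetic, restated as a standalone check:

```python
size = 47                      # uint32 draws requested by the test
steps = (size + 3) // 4        # Philox yields four uint32s per counter step -> 12
counter_low = 2 ** 64 - 1      # written as -1 in the test

new_low = (counter_low + steps) % 2 ** 64
carry = (counter_low + steps) // 2 ** 64
# Matches the expected state [(size + 3) // 4 - 1, counter_high + 1, key].
assert (new_low, carry) == (steps - 1, 1)
```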
# The random-number generator, if working correctly, should produce the @@ -136,10 +225,11 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): y = rng(dtype).numpy() self.assertFalse(np.array_equal(x, y)) + @parameterized.parameters(ALGS) @test_util.run_v2_only - def testUniformIsNotConstant(self): + def testUniformIsNotConstant(self, alg): with ops.device(xla_device_name()): - gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + gen = random.Generator.from_seed(seed=1234, alg=alg) def rng(dtype): maxval = dtype.max # Workaround for b/125364959 @@ -147,56 +237,52 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): maxval = 10000000 return gen.uniform(shape=[2], dtype=dtype, maxval=maxval) - for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + for dtype in self._ints + self._floats: self._testRngIsNotConstant(rng, dtype) + @parameterized.parameters(ALGS) @test_util.run_v2_only - def testNormalIsNotConstant(self): + def testNormalIsNotConstant(self, alg): with ops.device(xla_device_name()): - gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + gen = random.Generator.from_seed(seed=1234, alg=alg) def rng(dtype): return gen.normal(shape=[2], dtype=dtype) - for dtype in {dtypes.float32}: + for dtype in self._floats: self._testRngIsNotConstant(rng, dtype) + @parameterized.parameters(ALGS) @test_util.run_v2_only - def testUniformIntIsInRange(self): + def testUniformIsInRange(self, alg): minval = 2 maxval = 33 size = 1000 with ops.device(xla_device_name()): - gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) - for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: + for dtype in self._ints + self._floats: + gen = random.Generator.from_seed(seed=1234, alg=alg) x = gen.uniform( shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy() self.assertTrue(np.all(x >= minval)) - self.assertTrue(np.all(x < maxval)) + self.assertTrue(np.all(x <= maxval)) + @parameterized.parameters(ALGS) @test_util.run_v2_only - def testNormalIsFinite(self): + def testNormalIsFinite(self, alg): with ops.device(xla_device_name()): - gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) - for dtype in {dtypes.float32}: + gen = random.Generator.from_seed(seed=1234, alg=alg) + for dtype in self._floats: x = gen.normal(shape=[10000], dtype=dtype).numpy() self.assertTrue(np.all(np.isfinite(x))) - def _chi_squared(self, x, bins): - """Pearson's Chi-squared test.""" - x = np.ravel(x) - n = len(x) - histogram, _ = np.histogram(x, bins=bins, range=(0, 1)) - expected = n / float(bins) - return np.sum(np.square(histogram - expected) / expected) - + @parameterized.parameters(ALGS) @test_util.run_v2_only - def testDistributionOfUniform(self): + def testDistributionOfUniform(self, alg): """Use Pearson's Chi-squared test to test for uniformity.""" with ops.device(xla_device_name()): n = 1000 seed = 12 - for dtype in {dtypes.int32, dtypes.uint32, dtypes.int64, dtypes.uint64}: - gen = random.Generator(seed=seed, algorithm=random.RNG_ALG_THREEFRY) + for dtype in self._ints + self._floats: + gen = random.Generator.from_seed(seed=seed, alg=alg) maxval = 1 if dtype.is_integer: maxval = 100 @@ -208,34 +294,34 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with # p=0.05. This test is probabilistic and would be flaky if the random # seed were not fixed. 
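The _chi_squared helper removed above now comes from random_test_util.chi_squared; its logic, restated standalone for reference (assuming the shared utility keeps the same Pearson statistic shown in the removed lines):

```python
import numpy as np


def chi_squared(x, bins):
  """Pearson's Chi-squared statistic over a histogram of values in [0, 1)."""
  x = np.ravel(x)
  n = len(x)
  histogram, _ = np.histogram(x, bins=bins, range=(0, 1))
  expected = n / float(bins)
  return np.sum(np.square(histogram - expected) / expected)


# For 1000 uniform samples and 10 bins this is usually well below 16.92, the
# 5% critical value for 9 degrees of freedom cited in the test comment.
print(chi_squared(np.random.uniform(size=1000), 10))
```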
- val = self._chi_squared(x, 10) + val = random_test_util.chi_squared(x, 10) self.assertLess(val, 16.92) - def _normal_cdf(self, x): - """Cumulative distribution function for a standard normal distribution.""" - return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) - - def _anderson_darling(self, x): - """Anderson-Darling test for a standard normal distribution.""" - x = np.sort(np.ravel(x)) - n = len(x) - i = np.linspace(1, n, n) - z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + - (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) - return -n - z / n - + @parameterized.parameters(ALGS) @test_util.run_v2_only - def testDistributionOfNormal(self): + def testDistributionOfNormal(self, alg): """Use Anderson-Darling test to test distribution appears normal.""" with ops.device(xla_device_name()): n = 1000 - for dtype in {dtypes.float32}: - gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + for dtype in self._floats: + gen = random.Generator.from_seed(seed=1234, alg=alg) x = gen.normal(shape=[n], dtype=dtype).numpy() # The constant 2.492 is the 5% critical value for the Anderson-Darling # test where the mean and variance are known. This test is probabilistic # so to avoid flakiness the seed is fixed. - self.assertLess(self._anderson_darling(x.astype(float)), 2.492) + self.assertLess( + random_test_util.anderson_darling(x.astype(float)), 2.492) + + @parameterized.parameters(ALGS) + @test_util.run_v2_only + def testTruncatedNormal(self, alg): + with ops.device(xla_device_name()): + for dtype in self._floats: + gen = random.Generator.from_seed(seed=123, alg=alg) + n = 10000000 + y = gen.truncated_normal(shape=[n], dtype=dtype).numpy() + random_test_util.test_truncated_normal( + self.assertEqual, self.assertAllClose, dtype, n, y) @test_util.run_v2_only def testErrors(self): @@ -243,7 +329,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): """ shape = [2, 3] with ops.device(xla_device_name()): - gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY) + gen = random.Generator.from_seed(seed=1234, alg=random.RNG_ALG_THREEFRY) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, r"algorithm must be of shape \[\], not"): @@ -273,9 +359,15 @@ class StatefulRandomOpsTest(xla_test.XLATestCase): var = variables.Variable([0], dtype=dtypes.int64) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, - "For the ThreeFry algorithm, the size of state must be at least"): + "The size of the state must be at least"): gen_stateful_random_ops.stateful_standard_normal_v2( var.handle, random.RNG_ALG_THREEFRY, shape) + var = variables.Variable([0, 0], dtype=dtypes.int64) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + "The size of the state must be at least"): + gen_stateful_random_ops.stateful_standard_normal_v2( + var.handle, random.RNG_ALG_PHILOX, shape) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index df5914a518e..edc14729d2e 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -18,15 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes +from tensorflow.python.kernel_tests.random import util as \ 
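Likewise, the removed _normal_cdf and _anderson_darling helpers are replaced by random_test_util.anderson_darling; a standalone restatement of the removed implementation, for reference (2.492 is the 5% critical value the tests keep using):

```python
import math

import numpy as np


def normal_cdf(x):
  """CDF of the standard normal distribution."""
  return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))


def anderson_darling(x):
  """Anderson-Darling statistic against a standard normal distribution."""
  x = np.sort(np.ravel(x))
  n = len(x)
  i = np.linspace(1, n, n)
  z = np.sum((2 * i - 1) * np.log(normal_cdf(x)) +
             (2 * (n - i) + 1) * np.log(1 - normal_cdf(x)))
  return -n - z / n


# Usually below 2.492 for 1000 standard normal samples.
print(anderson_darling(np.random.normal(size=1000)))
```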
+random_test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import stateless_random_ops as stateless -from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test @@ -34,16 +33,16 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self, include_int=False): - allowed_types = {dtypes.float32, dtypes.float64, dtypes.bfloat16} + allowed_types = {dtypes.float32, dtypes.bfloat16} if include_int: allowed_types.update({dtypes.int32, dtypes.int64}) return self.all_tf_types & allowed_types def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension for stateless_op in [ stateless.stateless_random_uniform, stateless.stateless_random_normal ]: @@ -63,7 +62,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertEqual(s0 == s1, np.all(v0 == v1)) def testRandomUniformIsInRange(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): for dtype in self._random_types(include_int=True): maxval = 1 if dtype.is_integer: @@ -75,17 +74,9 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(np.all(y >= 0)) self.assertTrue(np.all(y < maxval)) - def _chi_squared(self, x, bins): - """Pearson's Chi-squared test.""" - x = np.ravel(x) - n = len(x) - histogram, _ = np.histogram(x, bins=bins, range=(0, 1)) - expected = n / float(bins) - return np.sum(np.square(histogram - expected) / expected) - def testDistributionOfStatelessRandomUniform(self): """Use Pearson's Chi-squared test to test for uniformity.""" - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): for dtype in self._random_types(include_int=True): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 @@ -102,10 +93,10 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with # p=0.05. This test is probabilistic and would be flaky if the random # seed were not fixed. 
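The determinism property that testDeterminism checks can be seen in isolation: stateless ops are pure functions of shape and seed, so equal seeds give identical outputs. A small sketch, assuming eager execution outside the XLA test harness:

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import stateless_random_ops as stateless

seed = [0x12345678, 0xabcdef12]
a = stateless.stateless_random_uniform(shape=[3], seed=seed, dtype=dtypes.float32)
b = stateless.stateless_random_uniform(shape=[3], seed=seed, dtype=dtypes.float32)
# a and b are bit-identical; a different seed would give different values.
print(a, b)
```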
- self.assertTrue(self._chi_squared(y, 10) < 16.92) + self.assertLess(random_test_util.chi_squared(y, 10), 16.92) def testRandomNormalIsFinite(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_normal( @@ -113,22 +104,9 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(np.isfinite(y))) - def _normal_cdf(self, x): - """Cumulative distribution function for a standard normal distribution.""" - return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) - - def _anderson_darling(self, x): - """Anderson-Darling test for a standard normal distribution.""" - x = np.sort(np.ravel(x)) - n = len(x) - i = np.linspace(1, n, n) - z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + - (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) - return -n - z / n - def testDistributionOfStatelessRandomNormal(self): """Use Anderson-Darling test to test distribution appears normal.""" - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 @@ -138,57 +116,19 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # The constant 2.492 is the 5% critical value for the Anderson-Darling # test where the mean and variance are known. This test is probabilistic # so to avoid flakiness the seed is fixed. - self.assertTrue(self._anderson_darling(y.astype(float)) < 2.492) + self.assertLess( + random_test_util.anderson_darling(y.astype(float)), 2.492) - def testTruncatedNormalIsInRange(self): + def testTruncatedNormal(self): for dtype in self._random_types(): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 x = stateless.stateless_truncated_normal( shape=[n], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) - - def normal_cdf(x): - return .5 * math.erfc(-x / math.sqrt(2)) - - def normal_pdf(x): - return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) - - def probit(x, sess=sess): - return self.evaluate(special_math.ndtri(x)) - - a = -2. - b = 2. - mu = 0. - sigma = 1. - - alpha = (a - mu) / sigma - beta = (b - mu) / sigma - z = normal_cdf(beta) - normal_cdf(alpha) - - self.assertEqual((y >= a).sum(), n) - self.assertEqual((y <= b).sum(), n) - - # For more information on these calculations, see: - # Burkardt, John. "The Truncated Normal Distribution". - # Department of Scientific Computing website. Florida State University. - expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma - y = y.astype(float) - actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=5e-4) - - expected_median = mu + probit( - (normal_cdf(alpha) + normal_cdf(beta)) / 2.) 
* sigma - actual_median = np.median(y) - self.assertAllClose(actual_median, expected_median, atol=8e-4) - - expected_variance = sigma**2 * (1 + ( - (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( - (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) - actual_variance = np.var(y) - self.assertAllClose(actual_variance, expected_variance, - rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3) + random_test_util.test_truncated_normal( + self.assertEqual, self.assertAllClose, dtype, n, y) if __name__ == '__main__': diff --git a/tensorflow/compiler/tests/svd_op_test.py b/tensorflow/compiler/tests/svd_op_test.py new file mode 100644 index 00000000000..7791b409a37 --- /dev/null +++ b/tensorflow/compiler/tests/svd_op_test.py @@ -0,0 +1,93 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.svd.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_linalg_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.platform import test + + +class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def _compute_usvt(self, s, u, v): + m = u.shape[-1] + n = v.shape[-1] + if m <= n: + v = v[..., :m] + else: + u = u[..., :n] + + return np.matmul(u * s[..., None, :], np.swapaxes(v, -1, -2)) + + def _testSvdCorrectness(self, dtype, shape): + np.random.seed(1) + x_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype) + m, n = shape[-2], shape[-1] + _, s_np, _ = np.linalg.svd(x_np) + with self.session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + s, u, v = linalg_ops.svd(x_tf, full_matrices=True) + s_val, u_val, v_val = sess.run([s, u, v], feed_dict={x_tf: x_np}) + u_diff = np.matmul(u_val, np.swapaxes(u_val, -1, -2)) - np.eye(m) + v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n) + # Check u_val and v_val are orthogonal matrices. + self.assertLess(np.linalg.norm(u_diff), 1e-2) + self.assertLess(np.linalg.norm(v_diff), 1e-2) + # Check that the singular values are correct, i.e., close to the ones from + # numpy.lingal.svd. + self.assertLess(np.linalg.norm(s_val - s_np), 1e-2) + # The tolerance is set based on our tests on numpy's svd. As our tests + # have batch dimensions and all our operations are on float32, we set the + # tolerance a bit larger. Numpy's svd calls LAPACK's svd, which operates + # on double precision. + self.assertLess( + np.linalg.norm(self._compute_usvt(s_val, u_val, v_val) - x_np), 2e-2) + + # Check behavior with compute_uv=False. We expect to still see 3 outputs, + # with a sentinel scalar 0 in the last two outputs. 
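Before the compute_uv=False path continues, a numpy-only sketch of the reconstruction check above; it mirrors _compute_usvt, relies only on np.linalg.svd, and the names are illustrative.

import numpy as np

def reconstruct(s, u, v):
  # Keep only the leading min(m, n) columns of the wider factor, then form
  # U * diag(s) * V^T, as _compute_usvt does above.
  m, n = u.shape[-1], v.shape[-1]
  if m <= n:
    v = v[..., :m]
  else:
    u = u[..., :n]
  return np.matmul(u * s[..., None, :], np.swapaxes(v, -1, -2))

x = np.random.uniform(-1.0, 1.0, size=(3, 6, 4)).astype(np.float32)
u, s, vt = np.linalg.svd(x, full_matrices=True)
# np.linalg.svd returns V^T while linalg_ops.svd returns V, so transpose back
# before reusing the same reconstruction helper.
error = np.linalg.norm(reconstruct(s, u, np.swapaxes(vt, -1, -2)) - x)
# `error` should be tiny, well under the 2e-2 bound the test uses for the
# XLA result.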
+ with self.test_scope(): + no_uv_s, no_uv_u, no_uv_v = gen_linalg_ops.svd( + x_tf, full_matrices=True, compute_uv=False) + no_uv_s_val, no_uv_u_val, no_uv_v_val = sess.run( + [no_uv_s, no_uv_u, no_uv_v], feed_dict={x_tf: x_np}) + self.assertAllClose(no_uv_s_val, s_val, atol=1e-4, rtol=1e-4) + self.assertEqual(no_uv_u_val, 0.0) + self.assertEqual(no_uv_v_val, 0.0) + + SIZES = [1, 2, 5, 10, 32, 64] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testSvd(self, n, dtype): + for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10): + self._testSvdCorrectness(dtype, batch_dims + (n, n)) + self._testSvdCorrectness(dtype, batch_dims + (2 * n, n)) + self._testSvdCorrectness(dtype, batch_dims + (n, 2 * n)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index d7e26d79c4c..99847e84c28 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -21,13 +21,17 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.compiler.xla import xla from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops @@ -40,131 +44,152 @@ def _make_converter(dtype): return np.asarray(x).astype(dtype.as_numpy_dtype) return _converter +# This lets me define `fn` repeatedly to pass to xla.compile. 
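As a sketch of that pattern, here is a hypothetical test written the way the converted tests below are, assuming the xla_test.XLATestCase base class and the self.session()/self.test_scope() helpers this file already uses: the graph is built inside a nullary fn and the compiled outputs are evaluated, rather than running the ops directly.

from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import tensor_array_ops

def testWriteReadSketch(self):  # hypothetical example, not part of the patch
  with self.session(), self.test_scope():

    def fn():
      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
      w = ta.write(0, [1.0]).write(1, [2.0])
      return [w.read(0), w.read(1)]

    r0, r1 = self.evaluate(xla.compile(fn))
    self.assertAllEqual([1.0], r0)
    self.assertAllEqual([2.0], r1)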
+# +# pylint: disable=function-redefined + +@test_util.with_control_flow_v2 class TensorArrayTest(xla_test.XLATestCase): + @test_util.disable_control_flow_v2("Tries to evaluate flow") def testTensorArrayWriteRead(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=3) + with self.session() as session, self.test_scope(): - w0 = ta.write(0, [[4.0, 5.0]]) - w1 = w0.write(1, [[1.0, 3.0]]) - w2 = w1.write(2, [[7.0, -8.5]]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) - r0 = w2.read(0) - r1 = w2.read(1) - r2 = w2.read(2) - flow = w2.flow + w0 = ta.write(0, [[4.0, 5.0]]) + w1 = w0.write(1, [[1.0, 3.0]]) + w2 = w1.write(2, [[7.0, -8.5]]) - d0, d1, d2, flow_val = session.run([r0, r1, r2, flow]) + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + flow = w2.flow + return [r0, r1, r2, flow] + + d0, d1, d2, flow_val = self.evaluate(xla.compile(fn)) self.assertAllEqual([[4.0, 5.0]], d0) self.assertAllEqual([[1.0, 3.0]], d1) self.assertAllEqual([[7.0, -8.5]], d2) self.assertAllEqual([], flow_val.shape) def _testTensorArrayWritePack(self, tf_dtype): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, tensor_array_name="foo", size=3) - + with self.session(), self.test_scope(): convert = _make_converter(tf_dtype) - w0 = ta.write(0, convert([[4.0, 5.0]])) - w1 = w0.write(1, convert([[6.0, 7.0]])) - w2 = w1.write(2, convert([[8.0, 9.0]])) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) - c0 = w2.stack() + w0 = ta.write(0, convert([[4.0, 5.0]])) + w1 = w0.write(1, convert([[6.0, 7.0]])) + w2 = w1.write(2, convert([[8.0, 9.0]])) + + return w2.stack() self.assertAllEqual( convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), - self.evaluate(c0)) + self.evaluate(xla.compile(fn)[0])) def testTensorArrayWritePack(self): for dtype in self.numeric_tf_types: self._testTensorArrayWritePack(dtype) def testEmptyTensorArrayPack(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3) + with self.session(), self.test_scope(): - empty_element = np.zeros((0, 1), dtype=np.float32) - w0 = ta.write(0, empty_element) - w1 = w0.write(1, empty_element) - w2 = w1.write(2, empty_element) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) - c0 = w2.stack() + empty_element = np.zeros((0, 1), dtype=np.float32) + w0 = ta.write(0, empty_element) + w1 = w0.write(1, empty_element) + w2 = w1.write(2, empty_element) - self.assertAllEqual([3, 0, 1], self.evaluate(c0).shape) + return w2.stack() + + self.assertAllEqual([3, 0, 1], self.evaluate(xla.compile(fn)[0]).shape) def _testTensorArrayWriteConcat(self, tf_dtype): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, tensor_array_name="foo", size=3) - + with self.session(), self.test_scope(): convert = _make_converter(tf_dtype) - w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0]])) - w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 107.0]])) - w2 = w1.write(2, convert([[8.0, 9.0], [204.0, 205.0]])) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) - c0 = w2.concat() + w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0]])) + w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 
107.0]])) + w2 = w1.write(2, convert([[8.0, 9.0], [204.0, 205.0]])) + + return w2.concat() self.assertAllEqual( convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], [106.0, 107.0], - [8.0, 9.0], [204.0, 205.0]]), self.evaluate(c0)) + [8.0, 9.0], [204.0, 205.0]]), + self.evaluate(xla.compile(fn)[0])) + @test_util.disable_control_flow_v2("b/122315751 (concat)") def testTensorArrayWriteConcat(self): for dtype in self.numeric_tf_types: self._testTensorArrayWriteConcat(dtype) def _testTensorArrayUnpackRead(self, tf_dtype): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, tensor_array_name="foo", size=3) - + with self.session() as session, self.test_scope(): convert = _make_converter(tf_dtype) - # Unpack a vector into scalars - w0 = ta.unstack(convert([1.0, 2.0, 3.0])) - r0 = w0.read(0) - r1 = w0.read(1) - r2 = w0.read(2) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) - d0, d1, d2 = session.run([r0, r1, r2]) + # Unpack a vector into scalars + w0 = ta.unstack(convert([1.0, 2.0, 3.0])) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + + return [r0, r1, r2] + + d0, d1, d2 = self.evaluate(xla.compile(fn)) self.assertAllEqual(convert(1.0), d0) self.assertAllEqual(convert(2.0), d1) self.assertAllEqual(convert(3.0), d2) - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, tensor_array_name="foo", size=3) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) - # Unpack a matrix into vectors. - w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])) - r0 = w1.read(0) - r1 = w1.read(1) - r2 = w1.read(2) + # Unpack a matrix into vectors. + w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])) + r0 = w1.read(0) + r1 = w1.read(1) + r2 = w1.read(2) + return [r0, r1, r2] + + d0, d1, d2 = self.evaluate(xla.compile(fn)) - d0, d1, d2 = session.run([r0, r1, r2]) self.assertAllEqual(convert([1.0, 1.1]), d0) self.assertAllEqual(convert([2.0, 2.1]), d1) self.assertAllEqual(convert([3.0, 3.1]), d2) - # Reset ta because we're going to change the shape, else shape - # inference will throw an error. - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, tensor_array_name="foo", size=3) + def fn(): + # Reset ta because we're going to change the shape, else shape + # inference will throw an error. + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) - # Try unpacking an empty matrix, which should not cause an error. - w2 = ta.unstack(convert([[], [], []])) - r0 = w2.read(0) - r1 = w2.read(1) - r2 = w2.read(2) + # Try unpacking an empty matrix, which should not cause an error. + w2 = ta.unstack(convert([[], [], []])) + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) + return [r0, r1, r2] - d0, d1, d2 = session.run([r0, r1, r2]) + d0, d1, d2 = self.evaluate(xla.compile(fn)) self.assertAllEqual(convert([]), d0) self.assertAllEqual(convert([]), d1) self.assertAllEqual(convert([]), d2) @@ -177,83 +202,96 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayUnpackReadMaybeLegacy() def _testTensorArraySplitRead(self, tf_dtype): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, tensor_array_name="foo", size=3) - + with self.session() as session, self.test_scope(): convert = _make_converter(tf_dtype) - # Split an empty vector. 
- lengths = constant_op.constant([0, 0, 0]) - w0 = ta.split(convert([]), lengths=lengths) - r0 = w0.read(0) - r1 = w0.read(1) - r2 = w0.read(2) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + + # Split an empty vector. + lengths = constant_op.constant([0, 0, 0]) + w0 = ta.split(convert([]), lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + return [r0, r1, r2] + + d0, d1, d2 = self.evaluate(xla.compile(fn)) - d0, d1, d2 = session.run([r0, r1, r2]) self.assertAllEqual(convert([]), d0) self.assertAllEqual(convert([]), d1) self.assertAllEqual(convert([]), d2) - # Split a vector. - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, tensor_array_name="foo", size=3) - lengths = constant_op.constant([1, 1, 1]) - w0 = ta.split(convert([1.0, 2.0, 3.0]), lengths=lengths) - r0 = w0.read(0) - r1 = w0.read(1) - r2 = w0.read(2) + def fn(): + # Split a vector. + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + lengths = constant_op.constant([1, 1, 1]) + w0 = ta.split(convert([1.0, 2.0, 3.0]), lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + return [r0, r1, r2] + + d0, d1, d2 = self.evaluate(xla.compile(fn)) - d0, d1, d2 = session.run([r0, r1, r2]) self.assertAllEqual(convert([1.0]), d0) self.assertAllEqual(convert([2.0]), d1) self.assertAllEqual(convert([3.0]), d2) - # Split a matrix. - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, tensor_array_name="foo", size=3) - lengths = constant_op.constant([1, 1, 1]) - w0 = ta.split( - convert([[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]]), lengths=lengths) - r0 = w0.read(0) - r1 = w0.read(1) - r2 = w0.read(2) + def fn(): + # Split a matrix. + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=3) + lengths = constant_op.constant([1, 1, 1]) + w0 = ta.split( + convert([[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]]), + lengths=lengths) + r0 = w0.read(0) + r1 = w0.read(1) + r2 = w0.read(2) + return [r0, r1, r2] - d0, d1, d2 = session.run([r0, r1, r2]) + d0, d1, d2 = self.evaluate(xla.compile(fn)) self.assertAllEqual(convert([[1.0, 101.0]]), d0) self.assertAllEqual(convert([[2.0, 201.0]]), d1) self.assertAllEqual(convert([[3.0, 301.0]]), d2) + @test_util.disable_control_flow_v2("b/122315872 (split)") def testTensorArraySplitRead(self): for dtype in self.numeric_tf_types: self._testTensorArraySplitRead(dtype) + @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") def testTensorGradArrayWriteRead(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=3) + with self.session() as session, self.test_scope(): - w0 = ta.write(0, [[4.0]]) - w1 = w0.write(1, [[1.0]]) - w2 = w1.write(2, [[-3.0]]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) - g_ta = w2.grad("grad") + w0 = ta.write(0, [[4.0]]) + w1 = w0.write(1, [[1.0]]) + w2 = w1.write(2, [[-3.0]]) - g_w0 = g_ta.write(0, [[5.0]]) - g_w1 = g_w0.write(1, [[2.0]]) - g_w2 = g_w1.write(2, [[-2.0]]) + g_ta = w2.grad("grad") - r0 = w2.read(0) - r1 = w2.read(1) - r2 = w2.read(2) + g_w0 = g_ta.write(0, [[5.0]]) + g_w1 = g_w0.write(1, [[2.0]]) + g_w2 = g_w1.write(2, [[-2.0]]) - g_r0 = g_w2.read(0) - g_r1 = g_w2.read(1) - g_r2 = g_w2.read(2) + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) - d0, d1, d2, g_d0, g_d1, g_d2 = session.run([r0, r1, r2, g_r0, g_r1, g_r2]) + g_r0 = 
g_w2.read(0) + g_r1 = g_w2.read(1) + g_r2 = g_w2.read(2) + + return [r0, r1, r2, g_r0, g_r1, g_r2] + + d0, d1, d2, g_d0, g_d1, g_d2 = self.evaluate(xla.compile(fn)) self.assertAllEqual([[4.0]], d0) self.assertAllEqual([[1.0]], d1) self.assertAllEqual([[-3.0]], d2) @@ -261,36 +299,38 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[2.0]], g_d1) self.assertAllEqual([[-2.0]], g_d2) + @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") def testTensorGradArrayDynamicWriteRead(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=3) + with self.session() as session, self.test_scope(): - w0 = ta.write(0, [[4.0]]) - w1 = w0.write(1, [[1.0]]) - w2 = w1.write(2, [[-3.0]]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) - g_ta = w2.grad("grad") # Get gradient array here so we know the shape + w0 = ta.write(0, [[4.0]]) + w1 = w0.write(1, [[1.0]]) + w2 = w1.write(2, [[-3.0]]) - s = w2.size() - g_s = g_ta.size() + g_ta = w2.grad("grad") # Get gradient array here so we know the shape - g_w0 = g_ta.write(0, [[5.0]]) - g_w1 = g_w0.write(1, [[2.0]]) - g_w2 = g_w1.write(2, [[-2.0]]) + s = w2.size() + g_s = g_ta.size() - r0 = w2.read(0) - r1 = w2.read(1) - r2 = w2.read(2) + g_w0 = g_ta.write(0, [[5.0]]) + g_w1 = g_w0.write(1, [[2.0]]) + g_w2 = g_w1.write(2, [[-2.0]]) - g_r0 = g_w2.read(0) - g_r1 = g_w2.read(1) - g_r2 = g_w2.read(2) + r0 = w2.read(0) + r1 = w2.read(1) + r2 = w2.read(2) - d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = session.run( - [r0, r1, r2, g_r0, g_r1, g_r2, s, g_s]) + g_r0 = g_w2.read(0) + g_r1 = g_w2.read(1) + g_r2 = g_w2.read(2) + + return [r0, r1, r2, g_r0, g_r1, g_r2, s, g_s] + + d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = self.evaluate(xla.compile(fn)) self.assertAllEqual([[4.0]], d0) self.assertAllEqual([[1.0]], d1) self.assertAllEqual([[-3.0]], d2) @@ -300,174 +340,265 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(3, vs) self.assertAllEqual(3, g_vs) + @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") def testTensorGradAccessTwiceReceiveSameObject(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3, - element_shape=[1, 2]) - g_ta_0 = ta.grad("grad") - g_ta_1 = ta.grad("grad") + with self.session() as session, self.test_scope(): + ta_out = {} - with ops.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]): - # Write with one gradient handle, read with another copy of it - r1_0 = g_ta_1.read(0) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + element_shape=[1, 2]) - t_g_ta_0, t_g_ta_1, d_r1_0 = session.run( - [g_ta_0.handle.op, g_ta_1.handle.op, r1_0]) - self.assertAllEqual(t_g_ta_0, t_g_ta_1) + g_ta_0 = ta.grad("grad") + g_ta_1 = ta.grad("grad") + + ta_out[0] = g_ta_0.handle + ta_out[1] = g_ta_1.handle + + with ops.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]): + # Write with one gradient handle, read with another copy of it + r1_0 = g_ta_1.read(0) + + with ops.control_dependencies([g_ta_0.handle.op, g_ta_1.handle.op]): + return [r1_0] + + [d_r1_0] = self.evaluate(xla.compile(fn)) self.assertAllEqual([[4.0, 5.0]], d_r1_0) + # Can't assert this because adding a side output like we have here fails + # as follows: + # + # ValueError: Operation 
u'TensorArrayGrad/TensorArrayGradV3' has been + # marked as not fetchable. + # + # On the other hand, legitimately returning the handle from the + # xla.compile function fails because we don't support DT_RESOURCE outputs + # from XLA clusters. + # + # self.assertAllEqual(ta_out[0], ta_out[1]) + + @test_util.disable_control_flow_v2("b/124334470") def testTensorArrayWriteWrongIndexOrDataTypeFails(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3) + with self.session(), self.test_scope(): + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + return ta.write(-1, constant_op.constant(7)).flow # Test writing the wrong datatype. - with self.assertRaisesOpError( - "TensorArray dtype is float but op has dtype int32"): - ta.write(-1, np.int32(7)).flow.eval() + # TODO(b/129870929): Remove InvalidArgumentError/second regexp after all + # callers provide proper init dtype. + with self.assertRaisesRegexp( + (ValueError, errors.InvalidArgumentError), + r"(" + r"conversion requested dtype float32 for Tensor with dtype int32" + r"|" + r"TensorArray dtype is float but op has dtype int32" + r")"): + xla.compile(fn)[0].eval() + @test_util.disable_control_flow_v2("b/124334096 verify dtype") def testTensorArrayReadWrongIndexOrDataTypeFails(self): # Find two different floating point types, create an array of # the first type, but try to read the other type. if len(self.float_types) > 1: dtype1, dtype2 = list(self.float_types)[:2] - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtype1, tensor_array_name="foo", size=3) + with self.session(), self.test_scope(): - w0 = ta.write(0, [[4.0, 5.0]]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtype1, tensor_array_name="foo", size=3) + + w0 = ta.write(0, math_ops.cast([[4.0, 5.0]], dtype1)) + + # Test reading wrong datatype. + return gen_data_flow_ops.tensor_array_read_v3( + handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) - # Test reading wrong datatype. 
- r0_bad = gen_data_flow_ops.tensor_array_read_v3( - handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) with self.assertRaisesOpError("TensorArray dtype is "): - self.evaluate(r0_bad) + self.evaluate(xla.compile(fn)) - # Test reading from a different index than the one we wrote to - w0.read(1) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtype1, tensor_array_name="foo", size=3) + w0 = ta.write(0, math_ops.cast([[4.0, 5.0]], dtype1)) + + # Test reading from a different index than the one we wrote to + with ops.control_dependencies([w0.read(1)]): + return 1.0 + + xla.compile(fn)[0].eval() + + @test_util.disable_control_flow_v2("b/122315872 (split)") def testTensorArraySplitIncompatibleShapesFails(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=3, - infer_shape=False) + with self.session(), self.test_scope(): - with self.assertRaisesOpError( - r"value is not 1D"): - lengths = array_ops.placeholder(dtypes.int64) - ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1}) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + return ta.split([1.0, 2.0, 3.0], 1).flow + + with self.assertRaisesWithPredicateMatch( + ValueError, r"Shape must be rank 1 but is rank 0"): + xla.compile(fn)[0].eval() + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + return ta.split([1.0, 2.0, 3.0], [1, 2, 3]).flow with self.assertRaisesOpError( r"lengths must be equal: 1 vs. 2"): - ta.split([1.0, 2.0, 3.0], [1, 2, 3]).flow.eval() + xla.compile(fn)[0].eval() + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + return ta.split(1.0, [1]).flow with self.assertRaisesOpError( r"value must have rank >= 1"): - ta.split(1.0, [1]).flow.eval() + xla.compile(fn)[0].eval() - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=2, - infer_shape=False) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + infer_shape=False) + + return ta.split([1.0], [1]).flow with self.assertRaisesOpError( r"TensorArray's size is not equal to the size of lengths " r"\(1 vs. 
2\)"): - ta.split([1.0], [1]).flow.eval() + xla.compile(fn)[0].eval() def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) - + with self.session(), self.test_scope(): c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype) - w0 = ta.write(2, c(3.0)) - w1 = w0.write(2, c(4.0)) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) - ta_grad = w1.grad("grad") + w0 = ta.write(2, c(3.0)) + w1 = w0.write(2, c(4.0)) - w0_grad = ta_grad.write(2, c(3.0)) - w1_grad = w0_grad.write(2, c(4.0)) - w2_grad = w1_grad.write(2, c(5.0)) + ta_grad = w1.grad("grad") + + w0_grad = ta_grad.write(2, c(3.0)) + w1_grad = w0_grad.write(2, c(4.0)) + w2_grad = w1_grad.write(2, c(5.0)) + + return w2_grad.read(2) # Assert that aggregation works correctly - self.assertAllEqual(c(12.00), w2_grad.read(2).eval()) + self.assertAllEqual(c(12.00), xla.compile(fn)[0].eval()) - # Using differing shapes causes an exception - wb0_grad = ta_grad.write(1, c(1.0)) - wb1_grad = wb0_grad.write(1, c([1.0])) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) + + w0 = ta.write(2, c(3.0)) + w1 = w0.write(2, c(4.0)) + + ta_grad = w1.grad("grad") + # Using differing shapes causes an exception + wb0_grad = ta_grad.write(1, c(1.0)) + wb1_grad = wb0_grad.write(1, c([1.0])) + + return wb1_grad.flow with self.assertRaisesOpError( r"Mismatched TensorArray sizes"): - wb1_grad.flow.eval() + xla.compile(fn)[0].eval() + @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2") def testTensorArrayWriteGradientAddMultipleAdds(self): for dtype in self.numeric_tf_types: self._testTensorArrayWriteGradientAddMultipleAdds(dtype) def testMultiTensorArray(self): - with self.cached_session(), self.test_scope(): - h1 = tensor_array_ops.TensorArray( - size=1, dtype=dtypes.float32, tensor_array_name="foo") - w1 = h1.write(0, 4.0) - r1 = w1.read(0) + with self.session(), self.test_scope(): - h2 = tensor_array_ops.TensorArray( - size=1, dtype=dtypes.float32, tensor_array_name="bar") + def fn(): + h1 = tensor_array_ops.TensorArray( + size=1, dtype=dtypes.float32, tensor_array_name="foo") + w1 = h1.write(0, 4.0) + r1 = w1.read(0) - w2 = h2.write(0, 5.0) - r2 = w2.read(0) - r = r1 + r2 - self.assertAllClose(9.0, self.evaluate(r)) + h2 = tensor_array_ops.TensorArray( + size=1, dtype=dtypes.float32, tensor_array_name="bar") + + w2 = h2.write(0, 5.0) + r2 = w2.read(0) + return r1 + r2 + + self.assertAllClose(9.0, self.evaluate(xla.compile(fn)[0])) def _testTensorArrayGradientWriteReadType(self, dtype): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.as_dtype(dtype), - tensor_array_name="foo", - size=3, - infer_shape=False) - + with self.session() as session, self.test_scope(): c = lambda x: np.array(x, dtype=dtype) - value_0 = constant_op.constant(c([[4.0, 5.0]])) - value_1 = constant_op.constant(c([[3.0, 3.5]])) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.as_dtype(dtype), + tensor_array_name="foo", + size=3, + infer_shape=False) - w0 = ta.write(0, value_0) - w1 = w0.write(1, value_1) - r0 = w1.read(0) - r1 = w1.read(1) - r0_2 = w1.read(0) + value_0 = constant_op.constant(c([[4.0, 5.0]])) + value_1 = constant_op.constant(c([[3.0, 3.5]])) + + w0 = ta.write(0, value_0) + w1 = w0.write(1, 
value_1) + r0 = w1.read(0) + r1 = w1.read(1) + r0_2 = w1.read(0) + + # Test individual components' gradients + grad_just_r0 = gradients_impl.gradients( + ys=[r0], xs=[value_0], grad_ys=[c([[2.0, 3.0]])]) + grad_r0_r0_2 = gradients_impl.gradients( + ys=[r0, r0_2], + xs=[value_0], + grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]])]) + grad_just_r1 = gradients_impl.gradients( + ys=[r1], xs=[value_1], grad_ys=[c([[-2.0, -4.0]])]) + # Test combined gradients + grad = gradients_impl.gradients( + ys=[r0, r0_2, r1], + xs=[value_0, value_1], + grad_ys=[c([[2.0, 3.0]]), + c([[1.0, -1.0]]), + c([[-2.0, -10.0]])]) + + return [grad_just_r0, grad_r0_r0_2, grad_just_r1, grad] + + [grad_just_r0_vals, grad_r0_r0_2_vals, grad_just_r1_vals, + grad_vals] = self.evaluate(xla.compile(fn)) - # Test individual components' gradients - grad_just_r0 = gradients_impl.gradients( - ys=[r0], xs=[value_0], grad_ys=[c([[2.0, 3.0]])]) - grad_just_r0_vals = session.run(grad_just_r0) self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0_vals[0]) - grad_r0_r0_2 = gradients_impl.gradients( - ys=[r0, r0_2], - xs=[value_0], - grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]])]) - grad_r0_r0_2_vals = session.run(grad_r0_r0_2) self.assertAllEqual(c([[3.0, 2.0]]), grad_r0_r0_2_vals[0]) - grad_just_r1 = gradients_impl.gradients( - ys=[r1], xs=[value_1], grad_ys=[c([[-2.0, -4.0]])]) - grad_just_r1_vals = session.run(grad_just_r1) self.assertAllEqual(c([[-2.0, -4.0]]), grad_just_r1_vals[0]) - # Test combined gradients - grad = gradients_impl.gradients( - ys=[r0, r0_2, r1], - xs=[value_0, value_1], - grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]]), c([[-2.0, -10.0]])]) - grad_vals = session.run(grad) self.assertEqual(len(grad_vals), 2) self.assertAllEqual(c([[3.0, 2.0]]), grad_vals[0]) self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1]) @@ -479,77 +610,88 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientWriteReadType(dtype) def _testTensorArrayGradientWritePackConcatAndRead(self): - with self.cached_session() as sess, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=2, - clear_after_read=False) + with self.session() as sess, self.test_scope(): - value_0 = constant_op.constant([-1.0, 1.0]) - value_1 = constant_op.constant([-10.0, 10.0]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) - w0 = ta.write(0, value_0) - w1 = w0.write(1, value_1) - p0 = w1.stack() - r0 = w1.read(0) - s0 = w1.concat() + value_0 = constant_op.constant([-1.0, 1.0]) + value_1 = constant_op.constant([-10.0, 10.0]) - # Test gradient accumulation between read(0), pack(), and concat(). - with ops.control_dependencies([p0, r0, s0]): - grad_r = gradients_impl.gradients( - ys=[p0, r0, s0], - xs=[value_0, value_1], - grad_ys=[ - [[2.0, 3.0], [4.0, 5.0]], # stack gradient - [-0.5, 1.5], # read(0) gradient - [20.0, 30.0, 40.0, 50.0], # concat gradient - ]) - grad_vals = self.evaluate(grad_r) # 2 + 2 entries + w0 = ta.write(0, value_0) + w1 = w0.write(1, value_1) + p0 = w1.stack() + r0 = w1.read(0) + s0 = w1.concat() + + # Test gradient accumulation between read(0), pack(), and concat(). 
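As a plain-arithmetic cross-check of the accumulation exercised here, using the grad_ys and expected values that appear just below (illustrative only, no TensorFlow involved):

import numpy as np

stack_grad = np.array([[2.0, 3.0], [4.0, 5.0]])    # upstream grad for stack()
read0_grad = np.array([-0.5, 1.5])                 # upstream grad for read(0)
concat_grad = np.array([20.0, 30.0, 40.0, 50.0])   # upstream grad for concat()

# value_0 collects stack row 0, the read(0) gradient, and the first half of
# the concat gradient; value_1 collects the remaining pieces.
grad_value_0 = stack_grad[0] + read0_grad + concat_grad[:2]  # [21.5, 34.5]
grad_value_1 = stack_grad[1] + concat_grad[2:]               # [44.0, 55.0]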
+ with ops.control_dependencies([p0, r0, s0]): + return gradients_impl.gradients( + ys=[p0, r0, s0], + xs=[value_0, value_1], + grad_ys=[ + [[2.0, 3.0], [4.0, 5.0]], # stack gradient + [-0.5, 1.5], # read(0) gradient + [20.0, 30.0, 40.0, 50.0], # concat gradient + ]) + + grad_vals = self.evaluate(xla.compile(fn)) # 2 + 2 entries self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0]) self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1]) + @test_util.disable_control_flow_v2("b/122315751 (concat)") def testTensorArrayGradientWritePackConcatAndRead(self): self._testTensorArrayGradientWritePackConcatAndRead() def testTensorArrayReadTwice(self): - with self.cached_session(), self.test_scope(): - value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + with self.session(), self.test_scope(): - ta_readtwice = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=2, - clear_after_read=False) - w_readtwice = ta_readtwice.unstack(value) - r0_readtwice = w_readtwice.read(0) - with ops.control_dependencies([r0_readtwice]): - r1_readtwice = w_readtwice.read(0) + def fn(): + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) - self.assertAllEqual([1.0, -1.0], self.evaluate(r1_readtwice)) + ta_readtwice = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) + w_readtwice = ta_readtwice.unstack(value) + r0_readtwice = w_readtwice.read(0) + with ops.control_dependencies([r0_readtwice]): + r1_readtwice = w_readtwice.read(0) + + return [r0_readtwice, r1_readtwice] + + self.assertAllEqual([1.0, -1.0], self.evaluate(xla.compile(fn))[0]) def _testTensorArrayGradientUnpackRead(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=2, - clear_after_read=False) + with self.session() as session, self.test_scope(): - value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=2, + clear_after_read=False) - w = ta.unstack(value) - r0 = w.read(0) - r0_1 = w.read(0) - r1 = w.read(1) + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) - # Test combined gradients + aggregation of read(0). - grad = gradients_impl.gradients( - ys=[r0, r0_1, r1], - xs=[value], - grad_ys=[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]]) - grad_vals = session.run(grad) + w = ta.unstack(value) + r0 = w.read(0) + r0_1 = w.read(0) + r1 = w.read(1) + + # Test combined gradients + aggregation of read(0). 
+ return gradients_impl.gradients( + ys=[r0, r0_1, r1], + xs=[value], + grad_ys=[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]]) + + grad_vals = self.evaluate(xla.compile(fn)) self.assertEqual(len(grad_vals), 1) self.assertAllEqual([[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]], grad_vals[0]) @@ -557,24 +699,28 @@ class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayGradientUnpackRead(self): self._testTensorArrayGradientUnpackRead() + @test_util.disable_control_flow_v2("b/122315751(concat), b/122315872(split)") def testTensorArrayGradientSplitConcat(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=2) + with self.session() as session, self.test_scope(): - value = constant_op.constant( - [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0], [1000.0, -1000.0]]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=2) - w = ta.split(value, [2, 2]) - r = w.concat() + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0], + [100.0, -100.0], [1000.0, -1000.0]]) - # Test combined gradients - grad = gradients_impl.gradients( - ys=[r], - xs=[value], - grad_ys=[[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0], - [2000.0, -2000.0]]]) - grad_vals = session.run(grad) + w = ta.split(value, [2, 2]) + r = w.concat() + + # Test combined gradients + return gradients_impl.gradients( + ys=[r], + xs=[value], + grad_ys=[[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0], + [2000.0, -2000.0]]]) + + grad_vals = self.evaluate(xla.compile(fn)) self.assertEqual(len(grad_vals), 1) self.assertAllEqual([[2.0, -2.0], [20.0, -20.0], [200.0, -200.0], @@ -582,34 +728,46 @@ class TensorArrayTest(xla_test.XLATestCase): grad_vals[0]) def testCloseTensorArray(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3) - c1 = ta.close() - session.run(c1) + with self.session() as session, self.test_scope(): + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + with ops.control_dependencies([ta.close()]): + return 1.0 + + self.evaluate(xla.compile(fn)[0]) def testSizeTensorArray(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3) - s = ta.size() - self.assertAllEqual(3, self.evaluate(s)) + with self.session(), self.test_scope(): + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + return ta.size() + + self.assertAllEqual(3, self.evaluate(xla.compile(fn))[0]) def testWriteCloseTensorArray(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=3, - infer_shape=False) - w0 = ta.write(0, [[4.0, 5.0]]) - w1 = w0.write(1, [3.0]) - w1.close().run() # Expected to run without problems + with self.session(), self.test_scope(): + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + w0 = ta.write(0, [[4.0, 5.0]]) + w1 = w0.write(1, [[3.0, 1.0]]) + with ops.control_dependencies([w1.close()]): + return 1.0 + + self.evaluate(xla.compile(fn)) # TODO(phawkins): implement while loops. 
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): # np_dtype = dtype.as_numpy_dtype - # with self.cached_session() as session, self.test_scope(): + # with self.session() as session, self.test_scope(): # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)) # var = variables.Variable(np.arange(100, 105, dtype=np_dtype)) # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype)) @@ -645,9 +803,9 @@ class TensorArrayTest(xla_test.XLATestCase): # variables.global_variables_initializer().run() # state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = ( - # session.run([state0, var, v0, vout, v0_grad, var_grad, state0_grad]) + # self.evaluate([state0, var, v0, vout, v0_grad, var_grad, state0_grad]) # ) - # just_v0_grad_t, = session.run([v0_grad]) + # just_v0_grad_t, = self.evaluate([v0_grad]) # # state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ] # # vout = [ v0[0] + var + state[0] | @@ -693,7 +851,7 @@ class TensorArrayTest(xla_test.XLATestCase): # dynamic_size=True, dtype=dtypes.float32) # def testGradSerialTwoLoops(self): - # with self.cached_session(), self.test_scope(): + # with self.session(), self.test_scope(): # num_steps = 100 # acc = tensor_array_ops.TensorArray( # dtype=dtypes.float32, @@ -726,250 +884,201 @@ class TensorArrayTest(xla_test.XLATestCase): # self.assertAllClose(31.0, self.evaluate(grad)) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): - with self.cached_session() as session, self.test_scope(): - a = array_ops.identity( - np.arange( - 3 * 5, dtype=np.float32).reshape(3, 5) + 1) - b = array_ops.identity( - np.arange( - 3 * 5, dtype=np.float32).reshape(3, 5) + 1 + 3 * 5) - ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) - ta = ta.write(0, a, name="write_a") - ta = ta.write(1, b, name="write_b") - c = ( - ta.read( - 0, name="read_a_0") + # a + b - ta.read( - 1, name="read_b_0")) + with self.session() as session, self.test_scope(): g0 = -(np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1) - grad_a = gradients_impl.gradients([c], [a], [g0])[0] # d(a+b)/da = 1 - grad_b = gradients_impl.gradients([c], [b], [g0])[0] # d(a+b)/db = 1 + + def fn(): + a = array_ops.identity( + np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1) + b = array_ops.identity( + np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1 + 3 * 5) + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) + ta = ta.write(0, a, name="write_a") + ta = ta.write(1, b, name="write_b") + c = ( + ta.read(0, name="read_a_0") + # a + b + ta.read(1, name="read_b_0")) + grad_a = gradients_impl.gradients([c], [a], [g0])[0] # d(a+b)/da = 1 + grad_b = gradients_impl.gradients([c], [b], [g0])[0] # d(a+b)/db = 1 + + return [grad_a, grad_b] + + grad_a, grad_b = xla.compile(fn) # Test gradients calculated individually - grad_a_t, = session.run([grad_a]) + grad_a_t, = self.evaluate([grad_a]) self.assertAllEqual(grad_a_t, g0) - grad_b_t, = session.run([grad_b]) + grad_b_t, = self.evaluate([grad_b]) self.assertAllEqual(grad_b_t, g0) # Test gradients calculated jointly. 
- joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b]) + joint_grad_a_t, joint_grad_b_t = self.evaluate([grad_a, grad_b]) self.assertAllEqual(joint_grad_a_t, g0) self.assertAllEqual(joint_grad_b_t, g0) def testWriteShape(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3) - c0 = constant_op.constant([4.0, 5.0]) - w0 = ta.write(0, c0) - r0 = w0.read(0) + with self.session(), self.test_scope(): + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c0 = constant_op.constant([4.0, 5.0]) + w0 = ta.write(0, c0) + r0 = w0.read(0) + + return [c0, r0] + + c0, r0 = xla.compile(fn) + self.assertAllEqual(c0.get_shape(), r0.get_shape()) - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3) - c1 = constant_op.constant([6.0, 7.0]) - w1 = w0.write(1, c1) - r0 = w1.read(0) - r1 = w1.read(1) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c1 = constant_op.constant([6.0, 7.0]) + w0 = ta.write(0, c0) + w1 = w0.write(1, c1) + r0 = w1.read(0) + r1 = w1.read(1) + + return [r0, c1, r1] + + [r0, c1, r1] = xla.compile(fn) + self.assertAllEqual(c0.get_shape(), r0.get_shape()) self.assertAllEqual(c1.get_shape(), r1.get_shape()) - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=3) - c2 = constant_op.constant([4.0, 5.0, 6.0]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + w0 = ta.write(0, c0) + c2 = constant_op.constant([4.0, 5.0, 6.0]) + return w0.write(0, c2).flow + with self.assertRaises(ValueError): - w0.write(0, c2) - - def testPartlyUnknownShape(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=6) - - c0 = array_ops.placeholder(dtypes.float32, [None, None, None, 3]) - w0 = ta.write(0, c0) - r0 = w0.read(0) - self.assertAllEqual([None, None, None, 3], r0.get_shape().as_list()) - - c1 = array_ops.placeholder(dtypes.float32, [None, None, None, 3]) - w1 = w0.write(1, c1) - r1 = w1.read(0) - self.assertAllEqual([None, None, None, 3], r1.get_shape().as_list()) - - # Writing less specific shape (doesn't change type.) - c2 = array_ops.placeholder(dtypes.float32, [None, None, None, None]) - w2 = w1.write(2, c2) - r2 = w2.read(0) - self.assertAllEqual([None, None, None, 3], r2.get_shape().as_list()) - - # Writing more specific shape in one dimension and less specific in - # another. - c3 = array_ops.placeholder(dtypes.float32, [None, None, 2, None]) - w3 = w2.write(3, c3) - r3 = w3.read(0) - self.assertAllEqual([None, None, 2, 3], r3.get_shape().as_list()) - - # Writing partly defined shape using TensorArray.scatter. - c4 = array_ops.placeholder(dtypes.float32, [2, None, 4, 2, 3]) - w4 = w3.scatter([4, 5], c4) - r4 = w4.read(0) - self.assertAllEqual([None, 4, 2, 3], r4.get_shape().as_list()) - - # Writing fully defined shape using TensorArray.split. 
- c5 = array_ops.placeholder(dtypes.float32, [10, 4, 2, 3]) - w5 = w4.split(c5, constant_op.constant([5, 5])) - r5 = w5.read(0) - self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) - - def _testUnpackShape(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=0, - infer_shape=True) - value = constant_op.constant( - [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]]) - w0 = ta.unstack(value) - r0 = w0.read(0) - self.assertAllEqual((2,), r0.get_shape()) - - c1 = constant_op.constant([4.0, 5.0]) - w1 = w0.write(3, c1) - r1 = w1.read(0) - self.assertAllEqual(c1.get_shape(), r1.get_shape()) - - c2 = constant_op.constant([4.0, 5.0, 6.0]) - with self.assertRaises(ValueError): - w1.write(4, c2) - - def testUnpackShape(self): - self._testUnpackShape() - - def testSplitShape(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=0, - infer_shape=True) - value = constant_op.constant([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]]) - w0 = ta.split(value, [1, 1, 1]) - r0 = w0.read(0) - self.assertAllEqual((1, 2), r0.get_shape()) - - ta1 = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo1", - size=0, - infer_shape=True) - w0 = ta1.split(value, [1, 2]) - r0 = w0.read(0) - self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) - - def testWriteUnknownShape(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=3, - infer_shape=True) - c0 = array_ops.placeholder(dtypes.float32) - w0 = ta.write(0, c0) - r0 = w0.read(0) - self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) + self.evaluate(xla.compile(fn)) def _testGradientWhenNotAllComponentsRead(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) - x = constant_op.constant([2.0, 3.0]) - w = ta.unstack(x) - r0 = w.read(0) - # Calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0). - grad_r0 = gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0]) - grad_r0_vals = session.run(grad_r0)[0] + with self.session() as session, self.test_scope(): + + def fn(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) + x = constant_op.constant([2.0, 3.0]) + w = ta.unstack(x) + r0 = w.read(0) + # Calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0). + return gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0]) + + grad_r0_vals = self.evaluate(xla.compile(fn))[0] self.assertAllEqual(grad_r0_vals, [1.0, 0.0]) def testGradientWhenNotAllComponentsRead(self): self._testGradientWhenNotAllComponentsRead() def _testTensorArrayEvalEmpty(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, size=0, infer_shape=False) - with self.assertRaisesOpError( - "TensorArray has size zero, but element shape is not fully " - "defined. 
Currently only static shapes are supported when packing " - "zero-size TensorArrays."): - ta.stack().eval() + with self.session(), self.test_scope(): + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=0, infer_shape=False) + return ta.stack() + + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, "Uninitialized TensorArray passed to " + "TensorArrayStack/TensorArrayGatherV3"): + xla.compile(fn)[0].eval() + + @test_util.disable_control_flow_v2("b/124335246") def testTensorArrayEvalEmpty(self): self._testTensorArrayEvalEmpty() def _testTensorArrayEvalEmptyWithDefault(self): - with self.cached_session(), self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, size=0, infer_shape=True) - self.assertEqual(0, ta.size().eval()) - ta = ta.unstack(array_ops.zeros([0, 3, 5])) - packed = ta.stack() - self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape) + with self.session(), self.test_scope(): + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=0, infer_shape=True) + size = ta.size() + ta = ta.unstack(array_ops.zeros([0, 3, 5])) + return [size, ta.stack()] + + [size, stack] = self.evaluate(xla.compile(fn)) + self.assertEqual(0, size) + self.assertAllEqual([0, 3, 5], stack.shape) # Concatenating zero tensors along their first dimension gives a # first dimension of zero - self.assertAllEqual([0, 5], ta.concat().eval().shape) + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: + + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=0, infer_shape=True) + ta = ta.unstack(array_ops.zeros([0, 3, 5])) + return ta.concat() + + # TODO(b/122315751): Enable this. + self.assertAllEqual([0, 5], self.evaluate(xla.compile(fn))[0].shape) def testTensorArrayEvalEmptyWithDefault(self): self._testTensorArrayEvalEmptyWithDefault() def _testTensorArrayScatterRead(self, tf_dtype): - with self.cached_session() as session, self.test_scope(): + with self.session() as session, self.test_scope(): convert = _make_converter(tf_dtype) - - ta = tensor_array_ops.TensorArray( - dtype=tf_dtype, - tensor_array_name="foo", - size=10) - - indices = constant_op.constant([1, 8]) - value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]])) id0 = array_ops.placeholder(dtypes.int32) id1 = array_ops.placeholder(dtypes.int32) - w = ta.scatter(indices, value) - r0 = w.read(id0) - r1 = w.read(id1) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, tensor_array_name="foo", size=10) + + indices = constant_op.constant([1, 8]) + value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]])) + + w = ta.scatter(indices, value) + r0 = w.read(id0) + r1 = w.read(id1) + + return [r0, r1] # Test aggregation of read - read_vals = session.run([r0, r1], feed_dict={id0: 1, id1: 8}) + read_vals = session.run(xla.compile(fn), feed_dict={id0: 1, id1: 8}) self.assertAllEqual(convert([1.0, -1.0]), read_vals[0]) self.assertAllEqual(convert([10.0, -10.0]), read_vals[1]) + @test_util.disable_control_flow_v2("b/122315734 (scatter)") def testTensorArrayScatterRead(self): for dtype in self.numeric_tf_types: self._testTensorArrayScatterRead(dtype) self._testTensorArrayScatterRead(dtypes.bool) + @test_util.disable_control_flow_v2("b/122315734 (scatter)") def testTensorArrayScatterReadAndGradients(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=10) - - indices = constant_op.constant([1, 8]) - value = 
constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + with self.session() as session, self.test_scope(): id0 = array_ops.placeholder(dtypes.int32) id1 = array_ops.placeholder(dtypes.int32) - w = ta.scatter(indices, value) - r0 = w.read(id0) - r1 = w.read(id1) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=10) - # Test combined gradients + aggregation of read(0). - grad = gradients_impl.gradients( - ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) - read_vals, grad_vals = session.run([[r0, r1], grad], - feed_dict={id0: 1, id1: 8}) + indices = constant_op.constant([1, 8]) + value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + + w = ta.scatter(indices, value) + r0 = w.read(id0) + r1 = w.read(id1) + + # Test combined gradients + aggregation of read(0). + grad = gradients_impl.gradients( + ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) + return [[r0, r1], grad] + + read_vals, grad_vals = session.run( + xla.compile(fn), feed_dict={ + id0: 1, + id1: 8 + }) self.assertEqual(len(read_vals), 2) self.assertEqual(len(grad_vals), 1) @@ -977,23 +1086,26 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([10.0, -10.0], read_vals[1]) self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) + @test_util.disable_control_flow_v2("b/122315378 (gather)") def testTensorArrayWriteGatherAndGradients(self): - with self.cached_session() as session, self.test_scope(): - ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, - tensor_array_name="foo", - size=10) + with self.session() as session, self.test_scope(): - values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)]) - indices = constant_op.constant([1, 8]) + def fn(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=10) - w = ta.unstack(values) - g = w.gather(indices) + values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)]) + indices = constant_op.constant([1, 8]) - # Test combined gradients + aggregation of read(0). - grad = gradients_impl.gradients( - ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]]) - g_vals, grad_vals = session.run([[g], grad]) + w = ta.unstack(values) + g = w.gather(indices) + + # Test combined gradients + aggregation of read(0). + grad = gradients_impl.gradients( + ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]]) + return [[g], grad] + + g_vals, grad_vals = self.evaluate(xla.compile(fn)) # Gradients for 8 of the 10 unread components are zero. expected_grad = np.zeros((10, 2)) @@ -1006,44 +1118,50 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(expected_grad, grad_vals[0]) def testTensorArrayIdentity(self): - with self.cached_session() as session, self.test_scope(): - ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2, - infer_shape=False) - ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4, - infer_shape=True) + with self.session() as session, self.test_scope(): + tensor_arrays = {} - ta0 = ta0.write(0, 0.) 
- ta1 = ta1.write(0, 1) + v0 = resource_variable_ops.ResourceVariable(0.0) + v1 = resource_variable_ops.ResourceVariable(0.0) - v0 = resource_variable_ops.ResourceVariable(0) - v1 = resource_variable_ops.ResourceVariable(0) + def fn(): + ta0 = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=2, infer_shape=False) + ta1 = tensor_array_ops.TensorArray( + dtype=dtypes.int32, size=4, infer_shape=True) - with ops.control_dependencies([v0.assign_add(1)]): - ta0 = ta0.identity() + ta0 = ta0.write(0, 0.) + ta1 = ta1.write(0, 1) - with ops.control_dependencies([v1.assign_add(1)]): - ta1 = ta1.identity() + with ops.control_dependencies([v0.assign_add(1.0)]): + ta0 = ta0.identity() - read0 = ta0.read(0) - read1 = ta1.read(0) + with ops.control_dependencies([v1.assign_add(1.0)]): + ta1 = ta1.identity() - size0 = ta0.size() - size1 = ta1.size() + read0 = ta0.read(0) + read1 = ta1.read(0) - # Tests correct properties on new TensorArrays. - self.assertEqual(dtypes.float32, ta0.dtype) - self.assertEqual(dtypes.int32, ta1.dtype) - self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape()) - self.assertEqual(tensor_shape.scalar(), read1.get_shape()) + size0 = ta0.size() + size1 = ta1.size() + + tensor_arrays[0] = ta0 + tensor_arrays[1] = ta1 + + return [read0, read1, size0, size1, v0, v1] variables.global_variables_initializer().run() - read0_v, read1_v, size0_v, size1_v = session.run( - (read0, read1, size0, size1)) + read0_v, read1_v, size0_v, size1_v, v0, v1 = self.evaluate( + xla.compile(fn)) + + # Tests correct properties on new TensorArrays. + self.assertEqual(dtypes.float32, tensor_arrays[0].dtype) + self.assertEqual(dtypes.int32, tensor_arrays[1].dtype) # Tests that the control dependencies was added and executed. - self.assertEqual(1, self.evaluate(v0)) - self.assertEqual(1, self.evaluate(v1)) + self.assertEqual(1.0, v0) + self.assertEqual(1.0, v1) # Tests correct TensorArray. 
self.assertEqual(read0_v, 0) diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 3c0c36d0c4d..b24e807b034 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op @@ -31,7 +32,7 @@ from tensorflow.python.platform import test class ListOpsTest(xla_test.XLATestCase): def testElementShape(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): dim = array_ops.placeholder(dtypes.int32) l = list_ops.empty_tensor_list( element_shape=(dim, 15), @@ -43,7 +44,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15)) def testPushPop(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): l = list_ops.empty_tensor_list( element_shape=(7, 15), element_dtype=dtypes.float32, @@ -58,7 +59,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15))) def testDoNotConstantFoldVariants(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): val = array_ops.placeholder(dtype=dtypes.float32) l = list_ops.empty_tensor_list( element_shape=(7, 15), @@ -77,7 +78,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15))) def testPushPopSeparateLists(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): l = list_ops.empty_tensor_list( element_shape=[], element_dtype=dtypes.float32, @@ -94,7 +95,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) def testEmptyTensorListNoMax(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): l = list_ops.empty_tensor_list( element_shape=(7, 15), element_dtype=dtypes.float32) l = list_ops.tensor_list_push_back( @@ -105,7 +106,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15))) def testEmptyTensorListMax(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): l = list_ops.empty_tensor_list( element_shape=(10, 15), element_dtype=dtypes.float32, max_num_elements=2) @@ -115,7 +116,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15))) def testListFromTensor(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(t, element_shape=[]) e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) @@ -124,10 +125,10 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(e0, 2.0) l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) self.assertAllEqual(e1, 1.0) - self.assertAllEqual(list_ops.tensor_list_length(l), 0) + self.assertAllEqual(list_ops.tensor_list_length(l), 2) def testGetSet(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): t = constant_op.constant([1.0, 2.0]) l = 
list_ops.tensor_list_from_tensor(t, element_shape=[]) e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) @@ -137,7 +138,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(t, [3.0, 2.0]) def testSetDoesNotUpdatePushIndex(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): l = list_ops.empty_tensor_list( element_shape=[], element_dtype=dtypes.float32, max_num_elements=2) # SetItem should not change the push index. @@ -148,7 +149,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(t, [5., 7.]) def testGetSetReserved(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): l = list_ops.tensor_list_reserve( element_dtype=dtypes.float32, element_shape=[], num_elements=2) e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) @@ -158,7 +159,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(t, [3.0, 0.0]) def testSetStackReservedUnknownElementShape(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): l = list_ops.tensor_list_reserve( element_dtype=dtypes.float32, element_shape=None, num_elements=2) l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0]) @@ -166,7 +167,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]]) def testPushInEmptyListWithUnknownElementShape(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): l = list_ops.empty_tensor_list( element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) l = list_ops.tensor_list_push_back(l, [3.0, 4.0]) @@ -177,7 +178,7 @@ class ListOpsTest(xla_test.XLATestCase): list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) def testGetSetReservedNonScalar(self): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): l = list_ops.tensor_list_reserve( element_dtype=dtypes.float32, element_shape=(7, 15), @@ -190,7 +191,7 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(sess.run(e2), np.zeros((7, 15))) def testStack(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): l = list_ops.empty_tensor_list( element_dtype=dtypes.float32, element_shape=[], @@ -204,11 +205,25 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(t, [1.0, 2.0]) def testStackWithUninitializedTensors(self): - with self.cached_session(), self.test_scope(): + with self.session(), self.test_scope(): l = list_ops.tensor_list_reserve( element_dtype=dtypes.float32, element_shape=[], num_elements=3) t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(t, [0., 0., 0.]) + def testZerosLikeForTensorList(self): + with self.session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=[], + max_num_elements=2) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + z = array_ops.zeros_like(l) + z = list_ops.tensor_list_stack(z, element_dtype=dtypes.float32) + self.assertAllEqual(z.shape.as_list(), [None]) + self.assertAllEqual(z, [0.0, 0.0]) + if __name__ == "__main__": + os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' + + os.environ.get('TF_XLA_FLAGS', '')) test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 98a07709c61..7e8edc5f0b1 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ 
b/tensorflow/compiler/tests/ternary_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index f2e0eac2d99..7e0a16d7ac4 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -65,7 +65,7 @@ class UnaryOpsTest(xla_test.XLATestCase): rtol: relative tolerance for equality test. atol: absolute tolerance for equality test. """ - with self.cached_session() as session: + with self.session() as session: with self.test_scope(): pinp = array_ops.placeholder( dtypes.as_dtype(inp.dtype), inp.shape, name="a") @@ -74,7 +74,7 @@ class UnaryOpsTest(xla_test.XLATestCase): if equality_test is None: self.assertEqual(output.dtype, expected.dtype) self.assertAllCloseAccordingToType( - result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) else: equality_test(result, expected, rtol=rtol, atol=atol) @@ -200,7 +200,7 @@ class UnaryOpsTest(xla_test.XLATestCase): # Disable float16 testing for now if dtype != np.float16: x = np.arange(-10, 10, 1).astype(dtype) - with self.cached_session() as session: + with self.session() as session: erf_x = session.run(math_ops.erf(x)) erfc_x = session.run(math_ops.erfc(x)) @@ -956,6 +956,15 @@ class UnaryOpsTest(xla_test.XLATestCase): [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], dtype=dtype), data_format)) + self._assertOpOutputMatchesExpected( + make_op("NCHW_VECT_C"), + np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)), + expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]], + [[[2, 3], [10, 11]], [[18, 19], [26, 27]]], + [[[4, 5], [12, 13]], [[20, 21], [28, 29]]], + [[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]], + dtype=dtype)) + def testSpaceToDepth(self): def make_op(data_format): @@ -999,6 +1008,15 @@ class UnaryOpsTest(xla_test.XLATestCase): [13, 14, 15, 16]]]], dtype=dtype), data_format)) + self._assertOpOutputMatchesExpected( + make_op("NCHW_VECT_C"), + np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)), + expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]], + [[[4, 5, 6, 7, 20, 21, 22, 23]]], + [[[8, 9, 10, 11, 24, 25, 26, 27]]], + [[[12, 13, 14, 15, 28, 29, 30, 31]]]]], + dtype=dtype)) + def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 18c5870e0de..fbc7ef49700 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -44,7 +44,7 @@ class VariableOpsTest(xla_test.XLATestCase): # Verifies that we can pass an uninitialized variable with an empty shape, # assign it a value, and successfully return it. 
for dtype in self.numeric_types: - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): zeros = np.zeros([3, 0], dtype=dtype) v = resource_variable_ops.ResourceVariable(zeros) p = array_ops.placeholder(dtype) @@ -58,7 +58,7 @@ class VariableOpsTest(xla_test.XLATestCase): # output and one variable update were mishandled. for dtype in self.numeric_types: init = np.array([[1, 2j], [3, 4]]).astype(dtype) - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) p = array_ops.placeholder(dtype) @@ -72,7 +72,7 @@ class VariableOpsTest(xla_test.XLATestCase): for dtype in self.numeric_types: init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10, 11]]).astype(dtype) - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read(2) @@ -83,7 +83,7 @@ class VariableOpsTest(xla_test.XLATestCase): for dtype in self.numeric_types: init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10, 11]]).astype(dtype) - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([2, 1]) @@ -95,7 +95,7 @@ class VariableOpsTest(xla_test.XLATestCase): for dtype in self.numeric_types: init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10, 11]]).astype(dtype) - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([[2, 1], [0, 2]]) @@ -109,7 +109,7 @@ class VariableOpsTest(xla_test.XLATestCase): init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]], [[20, 21, 22], [23, 24j, 25]], [[30, 31, 32], [33, 34, 35]]]).astype(dtype) - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([[2, 1], [3, 0]]) @@ -122,7 +122,7 @@ class VariableOpsTest(xla_test.XLATestCase): def testShape(self): for dtype in self.numeric_types: init = np.ones([2, 3]).astype(dtype) - with self.test_session() as session, self.test_scope(): + with self.session() as session, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) session.run(variables.variables_initializer([v])) h = v.handle @@ -138,7 +138,7 @@ class VariableOpsTest(xla_test.XLATestCase): def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" for dtype in self.numeric_types: - with self.test_session() as session: + with self.session() as session: with self.test_scope(): with variable_scope.variable_scope("ascope", use_resource=True): x = variable_scope.get_variable( @@ -166,7 +166,7 @@ class VariableOpsTest(xla_test.XLATestCase): def testTraining(self): """Tests a gradient descent step for a simple model.""" - with self.test_session() as session: + with self.session() as session: with self.test_scope(): with variable_scope.variable_scope("ascope", use_resource=True): w = variable_scope.get_variable( @@ -203,7 +203,7 @@ class VariableOpsTest(xla_test.XLATestCase): for dtype in self.numeric_types: init = 
np.array([[1, 2j], [3, 4]]).astype(dtype) update = np.array([[7, 1j], [2, 11]]).astype(dtype) - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) p = array_ops.placeholder(dtype) @@ -219,7 +219,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertAllClose(update, result[2]) def testScatterAdd(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[2, 1]) sess.run( @@ -232,7 +232,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertAllEqual(self.evaluate(read), [[3], [7]]) def testScatterSub(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[2, 1]) sess.run( @@ -245,7 +245,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertAllEqual(self.evaluate(read), [[4], [-1]]) def testScatterMul(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -258,7 +258,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[5]]) def testScatterDiv(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -271,7 +271,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertAllEqual(self.evaluate(read), [[2]]) def testScatterMin(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -284,7 +284,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[3]]) def testScatterMax(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -297,7 +297,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[6]]) def testScatterUpdate(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -310,7 +310,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[3]]) def testScatterAddScalar(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -323,7 +323,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[3]]) def testScatterSubScalar(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -336,7 +336,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[-1]]) def testScatterMulScalar(self): - with self.test_session() as sess, self.test_scope(): + with 
self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -349,7 +349,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[5]]) def testScatterDivScalar(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -362,7 +362,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[2]]) def testScatterMinScalar(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -375,7 +375,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[3]]) def testScatterMaxScalar(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) sess.run( @@ -388,7 +388,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertEqual(self.evaluate(read), [[6]]) def testScatterNdAddOps(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.float32, shape=[8]) sess.run( @@ -403,7 +403,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertAllClose(expected, self.evaluate(read)) def testScatterNdUpdateAddOps(self): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.float32, shape=[8]) sess.run( @@ -433,7 +433,7 @@ class StridedSliceAssignChecker(object): self.which_mode = 1 - self.which_mode value = np.array(value).astype(self.dtype) - with self.test.test_session() as sess, self.test.test_scope(): + with self.test.session() as sess, self.test.test_scope(): x = constant_op.constant(self.x_np, dtype=self.dtype) var = resource_variable_ops.ResourceVariable(x) sess.run(variables.variables_initializer([var])) @@ -487,7 +487,7 @@ class SliceAssignTest(xla_test.XLATestCase): def testUninitialized(self): with self.assertRaisesRegexp(errors.FailedPreconditionError, "uninitialized variable"): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable([1, 2]) sess.run(v[:].assign([1, 2])) diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index 55d1f853700..3ef12ced704 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -30,6 +30,8 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import map_fn +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -47,7 +49,7 @@ class WhileTest(xla_test.XLATestCase): def loop_cond(step): return step < 10 - with self.cached_session() as sess: + with self.session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) with self.test_scope(): loop_outputs = xla.while_loop([init_index], loop_cond, loop_body) @@ -69,7 +71,7 @@ class WhileTest(xla_test.XLATestCase): del rsum 
return step < 10 - with self.cached_session() as sess: + with self.session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) init_sum = array_ops.placeholder(dtypes.float32, []) with self.test_scope(): @@ -95,7 +97,7 @@ class WhileTest(xla_test.XLATestCase): del rsum return step < 10 - with self.cached_session() as sess: + with self.session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) init_sum = array_ops.placeholder(dtypes.complex64, []) with self.test_scope(): @@ -121,7 +123,7 @@ class WhileTest(xla_test.XLATestCase): del x return step < 10 - with self.cached_session() as sess: + with self.session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) with self.test_scope(): loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body) @@ -132,7 +134,7 @@ class WhileTest(xla_test.XLATestCase): def _testMaxItersSimple(self): if is_compile_on_demand(): self.skipTest("list_ops are not supported in cpu_ondemand") - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() v = constant_op.constant(1.0) @@ -166,7 +168,7 @@ class WhileTest(xla_test.XLATestCase): def _testNestedWhileLoopWithMaxItersFromOuterContext(self): if is_compile_on_demand(): self.skipTest("list_ops are not supported in cpu_ondemand") - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() v = constant_op.constant(1.0) @@ -223,6 +225,20 @@ class WhileTest(xla_test.XLATestCase): def testNestedWhileLoopWithMaxItersFromOuterContextV2(self): self._testNestedWhileLoopWithMaxItersFromOuterContext() + @test_util.enable_control_flow_v2 + def testMap(self): + if is_compile_on_demand(): + self.skipTest("list_ops are not supported in cpu_ondemand") + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, name="data") + r = map_fn.map_fn(lambda x: math_ops.multiply(math_ops.add(x, 3), 2), + elems) + self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums])) + xla_context.Exit() + def is_compile_on_demand(): return ("TF_XLA_FLAGS" in os.environ and @@ -230,4 +246,6 @@ def is_compile_on_demand(): if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index ef55292b1be..271bf66f40a 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -37,7 +37,7 @@ class XlaDeviceTest(xla_test.XLATestCase): [16384, 1], [1, 16384], [1, 20000, 1, 1]] for dtype in self.numeric_types: for shape in shapes: - with self.cached_session() as sess: + with self.session() as sess: with ops.device("CPU"): x = array_ops.placeholder(dtype, shape) with self.test_scope(): @@ -58,7 +58,7 @@ class XlaDeviceTest(xla_test.XLATestCase): ]) shape = (10, 10) for unsupported_dtype in test_types - self.all_types: - with self.cached_session() as sess: + with self.session() as sess: with ops.device("CPU"): x = array_ops.placeholder(unsupported_dtype, shape) with self.test_scope(): @@ -78,7 +78,7 @@ class XlaDeviceTest(xla_test.XLATestCase): pass def testControlTrigger(self): - with self.cached_session() as sess: + 
with self.session() as sess: with self.test_scope(): x = gen_control_flow_ops.control_trigger() self.evaluate(x) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 28274ff799d..b6a522bdeff 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -35,7 +35,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): def _assertOpOutputMatchesExpected(self, op, args, expected, equality_fn=None): - with self.test_session() as session: + with self.session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -310,7 +310,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dtype=dtype)) def testDynamicSliceWithIncorrectStartIndicesShape(self): - with self.test_session() as session: + with self.session() as session: with self.test_scope(): output = xla.dynamic_slice( np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), @@ -323,7 +323,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): r'but input rank is 3 and start_indices has shape \[2\].*')) def testDynamicSliceWithIncorrectSizeIndicesShape(self): - with self.test_session() as session: + with self.session() as session: with self.test_scope(): output = xla.dynamic_slice( np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 98a41981cf3..7fe22ad94cd 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -26,6 +26,7 @@ import re import numpy as np from tensorflow.contrib.compiler import jit +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -66,7 +67,7 @@ def parse_disabled_manifest(manifest_content): raise ValueError('Bad entry in manifest file.') disabled_regex = '|'.join(disabled_tests) - method_types_filter = dict() + method_types_filter = {} for method, types in disabled_method_types: method_types_filter[method] = set([ dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype @@ -199,10 +200,10 @@ class XLATestCase(test.TestCase): logging.info('End test case: %s', self._testMethodName) @contextlib.contextmanager - def test_session(self): - """Custom implementation of test_session() for XLA tests. + def session(self): + """Custom implementation of session() for XLA tests. - We override the standard Tensorflow test_session() since it is too + We override the standard Tensorflow session() since it is too specific to CPU and GPU tests. In particular, we want to disable soft placement and explicitly assign ops to devices under test. @@ -210,9 +211,25 @@ class XLATestCase(test.TestCase): A session to use when running a test case. """ graph = ops.Graph() - with session.Session(graph=graph) as sess, graph.as_default(): + config = config_pb2.ConfigProto() + + # Grappler can constant fold TensorListFromTensor ops into DT_VARIANT + # constants which XLA does not understand. So disable constant folding in + # these tests. 
+ config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + with session.Session( + graph=graph, config=config) as sess, graph.as_default(): yield sess + def cached_session(self): + raise NotImplementedError( + 'cached_session not supported on XLATestCase, please use session') + + def test_session(self): + raise NotImplementedError( + 'test_session not supported on XLATestCase, please use session') + @contextlib.contextmanager def test_scope(self): """Test scope that runs tests on a Tensorflow/XLA device. @@ -268,6 +285,7 @@ def Benchmark(tf_bench, for fetch in fetches: targets.append(array_ops.identity(fetch).op) + # TODO(b/132430685): Should we allow soft placement here? config = config_pb2.ConfigProto(allow_soft_placement=True) with session.Session(config=config) as sess: sess.run(variables.global_variables_initializer()) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index d2fe38ce921..eb10b021349 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -11,20 +11,24 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", + "tf_cc_shared_object", "tf_cc_test", "tf_copts", "tf_cuda_library", - "tf_custom_op_library", "tf_custom_op_library_additional_deps", "tf_gen_op_libs", "tf_gen_op_wrapper_py", ) load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( - "@local_config_tensorrt//:build_defs.bzl", - "if_tensorrt", + "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", ) +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +# Placeholder for Google-internal load statements. 
tf_cuda_cc_test( name = "tensorrt_test_cc", @@ -46,19 +50,6 @@ tf_cuda_cc_test( ]), ) -tf_custom_op_library( - name = "python/ops/_trt_ops.so", - srcs = [ - "ops/get_serialized_resource_op.cc", - "ops/trt_engine_op.cc", - ], - deps = [ - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:tensorrt", - ]), -) - cc_library( name = "trt_op_kernels", srcs = [ @@ -86,6 +77,73 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "trt_engine_resource_op_kernels", + srcs = ["kernels/trt_engine_resource_ops.cc"], + copts = tf_copts(), + visibility = ["//visibility:private"], + deps = [ + ":trt_allocator", + ":trt_engine_instance_proto_cc", + ":trt_logging", + ":trt_plugins", + ":trt_resources", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), + alwayslink = 1, +) + +tf_cuda_cc_test( + name = "trt_engine_resource_ops_test", + size = "small", + srcs = ["kernels/trt_engine_resource_ops_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_engine_instance_proto_cc", + ":trt_engine_resource_op_kernels", + ":trt_engine_resource_ops_op_lib", + ":trt_logging", + ":trt_resources", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:resource_variable_ops", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_shared_object( + name = "python/ops/libtftrt.so", + copts = tf_copts(is_external = True), + linkopts = ["-lm"], + deps = [ + ":trt_op_kernels", + ":trt_engine_resource_op_kernels", + ":trt_op_libs", + ":trt_engine_resource_ops_op_lib", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), +) + tf_cuda_cc_test( name = "get_serialized_resource_op_test", size = "small", @@ -111,10 +169,40 @@ tf_cuda_cc_test( ], ) +tf_cuda_cc_test( + name = "trt_engine_op_test", + size = "small", + srcs = ["kernels/trt_engine_op_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_op_kernels", + ":trt_op_libs", + ":trt_resources", + "@com_google_googletest//:gtest", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + ]), +) + tf_gen_op_libs( op_lib_names = [ "trt_engine_op", "get_serialized_resource_op", + "trt_engine_resource_ops", ], ) @@ -141,6 +229,7 @@ tf_cuda_library( tf_gen_op_wrapper_py( name = "trt_ops", deps = [ + ":trt_engine_resource_ops_op_lib", ":trt_op_libs", ], ) @@ -149,16 +238,20 @@ tf_custom_op_py_library( name = "trt_ops_loader", srcs = ["python/ops/trt_ops.py"], dso = [ - "python/ops/_trt_ops.so", + "python/ops/libtftrt.so", ] + if_tensorrt([ "@local_config_tensorrt//:tensorrt", ]), kernels = [ ":trt_op_kernels", + 
":trt_engine_resource_op_kernels", ":trt_op_libs", + ":trt_engine_resource_ops_op_lib", ], srcs_version = "PY2AND3", deps = [ + ":trt_ops", + ":wrap_py_utils", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", @@ -170,6 +263,7 @@ tf_cuda_library( name = "trt_resources", srcs = [ "utils/trt_int8_calibrator.cc", + "utils/trt_lru_cache.cc", "utils/trt_resources.cc", ], hdrs = [ @@ -271,6 +365,7 @@ tf_cuda_library( ] + if_tensorrt([ "@local_config_tensorrt//:tensorrt", ]) + tf_custom_op_library_additional_deps(), + alwayslink = 1, ) tf_cuda_cc_test( @@ -283,10 +378,13 @@ tf_cuda_cc_test( "nomac", ], deps = [ + ":trt_op_kernels", + ":trt_op_libs", ":trt_conversion", "@com_google_googletest//:gtest", "@com_google_absl//absl/strings", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/clusters:cluster", @@ -321,6 +419,7 @@ tf_cuda_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/core/grappler/costs:graph_properties", @@ -431,3 +530,30 @@ cc_library( "//tensorflow/core:lib_proto_parsing", ], ) + +tf_proto_library( + name = "trt_engine_instance_proto", + srcs = ["utils/trt_engine_instance.proto"], + cc_api_version = 2, + protodeps = tf_additional_all_protos(), +) + +cc_library( + name = "py_utils", + srcs = ["utils/py_utils.cc"], + hdrs = ["utils/py_utils.h"], + copts = tf_copts(), + deps = if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_py_wrap_cc( + name = "wrap_py_utils", + srcs = ["utils/py_utils.i"], + copts = tf_copts(), + deps = [ + "//tensorflow/compiler/tf2tensorrt:py_utils", + "//third_party/python_runtime:headers", + ], +) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 9bc94d55047..723eba7eb96 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -57,29 +57,14 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { namespace convert { using absl::StrAppend; using absl::StrCat; -// Returns compiled TRT version information {Maj, Min, Patch} -std::vector GetLinkedTensorRTVersion() { - return {NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH}; -} - -// Returns loaded TRT library version {Maj, Min, Patch} -std::vector GetLoadedTensorRTVersion() { - int ver = getInferLibVersion(); - int ver_major = ver / 1000; - ver = ver - ver_major * 1000; - int ver_minor = ver / 100; - int ver_patch = ver - ver_minor * 100; - return {ver_major, ver_minor, ver_patch}; -} - TrtCandidateSelector::TrtCandidateSelector( const grappler::GraphProperties& graph_properties, TrtPrecisionMode precision_mode) @@ -113,93 +98,6 @@ Status BuildNodeMap(const Graph& graph, } // namespace -Status ConvertGraphDefToTensorRT( - const GraphDef& graph_def, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - GraphDef* new_graph_def, TrtPrecisionMode precision_mode, - int minimum_segment_size, bool is_dyn_op, int max_cached_engines, - std::vector cached_engine_batches, bool use_calibration) { - // Create GrapplerItem. - grappler::GrapplerItem item; - item.fetch = output_names; - item.graph = graph_def; - -// TODO(aaroey): we should have used single machine cluster like the -// following, but the problem is then wrap_conversion will depend on -// direct_session and cause double linking problems. To fix this we need to -// fix or get rid of the swig dependency. Here we use VirtualCluster -// as a work around, and we need to create a session to initialize the -// underlying device before calling this method. -#if 0 - // Create single machine cluster. Note that this will create a session and - // initialize the gpu devices. - const int num_cpu_cores = - grappler::GetNumAvailableLogicalCPUCores(); - const int num_gpus = grappler::GetNumAvailableGPUs(); - VLOG(2) << "cpu_cores: " << num_cpu_cores; - VLOG(2) << "gpus: " << num_gpus; - const int timeout_s = 60 * 10; - std::unique_ptr cluster( - new grappler::SingleMachine( - timeout_s, num_cpu_cores, num_gpus)); - // These settings are the defaults in tensorflow/python/grappler/cluster.py. - cluster->DisableDetailedStats(true); - cluster->AllowSoftPlacement(true); - cluster->SetNumWarmupSteps(10); - TF_RETURN_IF_ERROR(cluster->Provision()); -#else - // Create virtual cluster. Grappler requires a virtual cluster with a proper - // GPU device in order to calculate flops>0 or fails with FATAL in dbg mode. - // We add numbers from a Pascal card here to have flops>0. - DeviceProperties device_properties; - device_properties.set_type("GPU"); - device_properties.mutable_environment()->insert({"architecture", "6"}); - device_properties.set_num_cores(3584); - device_properties.set_frequency(1531); - std::unique_ptr cluster( - new grappler::VirtualCluster({{"/GPU:0", device_properties}})); -#endif - - // Create RewriterConfig. - ConfigProto config_proto; - auto& rw_cfg = - *config_proto.mutable_graph_options()->mutable_rewrite_options(); - // TODO(aaroey): use only const folding and layout for the time being since - // new optimizers break the graph for trt. 
- rw_cfg.add_optimizers("constfold"); - rw_cfg.add_optimizers("layout"); - auto optimizer = rw_cfg.add_custom_optimizers(); - optimizer->set_name("TensorRTOptimizer"); - auto& parameters = *(optimizer->mutable_parameter_map()); - parameters["minimum_segment_size"].set_i(minimum_segment_size); - parameters["max_batch_size"].set_i(max_batch_size); - parameters["is_dynamic_op"].set_b(is_dyn_op); - parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes); - TF_RETURN_IF_ERROR(TrtPrecisionModeToName( - precision_mode, parameters["precision_mode"].mutable_s())); - parameters["maximum_cached_engines"].set_i(max_cached_engines); - if (!cached_engine_batches.empty()) { - auto list = parameters["cached_engine_batches"].mutable_list(); - for (const int batch : cached_engine_batches) { - list->add_i(batch); - } - } - parameters["use_calibration"].set_b(use_calibration); - - // Run optimizer. - grappler::MetaOptimizer meta_opt(nullptr, config_proto); - TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def)); - - if (VLOG_IS_ON(5)) { - std::fstream f; - f.open("TRTConversionInput.pb", - std::fstream::out | std::fstream::binary | std::fstream::trunc); - f << new_graph_def->SerializeAsString(); - f.close(); - } - return Status::OK(); -} - struct EdgePtrCompare { bool operator()(const Edge* lhs, const Edge* rhs) const { return lhs->id() < rhs->id(); @@ -346,9 +244,13 @@ Status GetEngineInfo(const Graph* g, // Construct the const nodes first. subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(), added_const_nodes.end()); + string scope_name; TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( g, graph_properties, subgraph_nodes, &info->connections, - &info->segment_graph_def, &info->engine_name)); + &info->segment_graph_def, &scope_name)); + info->engine_name = StrCat(scope_name, info->engine_name); + VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name + << "' to a GraphDef"; // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -507,20 +409,15 @@ Status CreateTRTNode(const ConversionParams& params, // these segments. if (inputs.empty()) { return errors::Internal( - "Segment has no inputs (possible " - "constfold failure)"); + "Segment has no inputs (possible constfold failure)"); } const bool calibrate_int8 = (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration); // Build the engine and get its serialized representation. string segment_string; - if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) { - // Create static engine for fp32/fp16 mode, and test validity of the engine - // for int8 calibration mode. We don't want engine to fail at the - // calibration time. So we are constructing a FP32 engine here to check its - // validity, and if it is a valid engine then we put the serialized graphdef - // to the op. Otherwise we skip node creation for this engine. + if (info.engine_type == EngineInfo::EngineType::TRTStatic) { + // Create static engine for fp32/fp16 mode. Logger trt_logger; TrtUniquePtrType engine; // TODO(sami): What happens if 1st dim is not batch? @@ -534,10 +431,6 @@ Status CreateTRTNode(const ConversionParams& params, TrtUniquePtrType engine_data(engine->serialize()); segment_string = string(static_cast(engine_data->data()), engine_data->size()); - if (calibrate_int8) { - // See above comment about why not putting this inside the 'else' branch. 
- segment_string = info.segment_graph_def.SerializeAsString(); - } } else { segment_string = info.segment_graph_def.SerializeAsString(); } @@ -623,17 +516,18 @@ Status CreateTRTNode(const ConversionParams& params, UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false, conn.outside_node_name, &output_node, &port); } - VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number - << " to " << output_node->name() << ":" << port; if (conn.is_control_edge()) { + VLOG(1) << "Updating control edge from " << engine_node->name() << " to " + << output_node->name(); QCHECK_EQ(Graph::kControlSlot, port); graph->AddControlEdge(engine_node, output_node); } else { - auto new_edge = - graph->AddEdge(engine_node, conn.port_number, output_node, port); - QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name() - << ":" << conn.port_number << " -> " - << output_node->name() << ":" << conn.outside_port; + VLOG(1) << "Updating data edge from " << engine_node->name() << ":" + << conn.port_number << " to " << output_node->name() << ":" + << port; + // Use UpdateEdge() to avoid adding the same edge multiple times. + TF_CHECK_OK( + graph->UpdateEdge(engine_node, conn.port_number, output_node, port)); } } return Status::OK(); @@ -854,6 +748,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { for (size_t t = 0; t < initial_segments.size(); t++) { auto& curr_segment = initial_segments.at(t); EngineInfo curr_engine; + curr_engine.engine_name = StrCat("TRTEngineOp_", t); Status status = GetEngineInfo(&graph, *params.graph_properties, curr_segment.first, node_map, reverse_topo_order, &curr_engine); @@ -869,7 +764,6 @@ Status ConvertAfterShapes(const ConversionParams& params) { curr_engine.use_calibration = params.use_calibration; curr_engine.cached_engine_batches = params.cached_engine_batches; curr_engine.maximum_cached_engines = params.max_cached_engines; - StrAppend(&curr_engine.engine_name, "TRTEngineOp_", t); if (params.use_function_backup) { status = RegisterSegmentFunctionToFunctionLibrary( &graph, curr_engine.segment_graph_def, curr_engine.engine_name); @@ -914,6 +808,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { (engine_bytes_size.at(i) / total_engine_bytes_size + converted_segments.at(i).first.size() / total_num_nodes_in_segments) / 2.0; + VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to " + << engine.engine_name; // The allocator is used to build the engine. The build and the built engine // will be destroyed after we get the serialized engine string, so it's fine // to use unique_ptr here. diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index b9600126624..e43bffe69ab 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -76,28 +76,9 @@ struct ConversionParams { bool use_function_backup = true; }; -// - max_batch_size: maximum batch size which can be used for inference for -// optimization targets inference run with max batch size. -// - max_workspace_size_bytes: The upper bound of memory allowance for engine -// building. 
-Status ConvertGraphDefToTensorRT( - const GraphDef& graph_def, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - GraphDef* new_graph_def, - TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32, - int minimum_segment_size = 3, bool is_dyn_op = false, - int max_cached_engines = 1, std::vector cached_engine_batches = {}, - bool use_calibration = true); - // Method to call from optimization pass Status ConvertAfterShapes(const ConversionParams& params); -// Return compile time TensorRT library version information. -std::vector GetLinkedTensorRTVersion(); - -// Return runtime time TensorRT library version information. -std::vector GetLoadedTensorRTVersion(); - // Helper method for the conversion, expose for testing. std::pair GetDeviceAndAllocator(const ConversionParams& params, const EngineInfo& engine); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 1a754181deb..d8db0ffac7e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" @@ -105,8 +106,8 @@ TEST(TrtCandidateSelector, Basics) { ExpectStatus( selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), error::INVALID_ARGUMENT, - "transpose_a is not supported for TensorRT FullyConnected " - "(op: MatMul), at: incompatible_matmul"); + "Cannot transpose first input if it is a tensor with fewer than 2 " + "non-batch dimensions."); ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), error::UNIMPLEMENTED, "Op type Erf is not supported"); ExpectStatus( @@ -222,6 +223,76 @@ TEST(ConvertGraphTest, GetDeviceAndAllocator) { } } +class ConvertAfterShapesTest : public ::testing::Test { + public: + Status RunConvertAfterShape(Scope s, GraphDef* output_graph_def) { + // Create GraphProperties. + grappler::GrapplerItem item; + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + grappler::GraphProperties graph_properties(item); + TF_EXPECT_OK(graph_properties.InferStatically(true)); + + // Construct ConversionParams. + const std::vector output_names{"output"}; + ConversionParams params; + params.input_graph_def = &item.graph; + params.output_names = &output_names; + params.max_workspace_size_bytes = 8 << 20; + params.output_graph_def = output_graph_def; + params.minimum_segment_size = 2; + params.graph_properties = &graph_properties; + params.use_calibration = false; + + return ConvertAfterShapes(params); + } +}; + +TEST_F(ConvertAfterShapesTest, DirectlyConnectedEngines) { + // Create the graph. There will be two TRTEngineOps after the conversion, and + // the upstream TRTEngineOp will have two output connections from the same + // node:port inside the op to the downstream TRTEngineOp. Then, if it adds the + // downstream TRTEngineOp first, when adding the upstream op it'll need to + // update the same output connection twice. This test ensures the correctness + // of the conversion under such condition. 
+ Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape({2, 1})); + // We purposefully choose the name of the root node of each segment, so it'll + // process the segment in the downstream first, then, when it tries to update + // the edge between the two TRTEngineOps, it'll try to add the same edge + // multiple times. + auto segment_root_1 = ops::Identity(s.WithOpName("segment_root_b"), input); + auto add1 = ops::Add(s.WithOpName("add1"), segment_root_1, segment_root_1); + // Add incompatible reshapes that change the batch dimension. + auto incompatible = + ops::Reshape(s.WithOpName("reshape1"), add1, Input({1, 2})); + incompatible = + ops::Reshape(s.WithOpName("reshape2"), incompatible, Input({2, 1})); + + auto add2 = ops::Add(s.WithOpName("add2"), incompatible, add1); + auto segment_root_2 = ops::Identity(s.WithOpName("segment_root_a"), add1); + auto add3 = ops::Add(s.WithOpName("add3"), add2, segment_root_2); + ops::Identity(s.WithOpName("output"), add3); + + GraphDef output_graph_def; + TF_EXPECT_OK(RunConvertAfterShape(s, &output_graph_def)); + + int num_trt_ops = 0; + for (const NodeDef& node : output_graph_def.node()) { + if (node.name() == "TRTEngineOp_1") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("input", node.input(0)); + ++num_trt_ops; + } else if (node.name() == "TRTEngineOp_0") { + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("TRTEngineOp_1", node.input(0)); + EXPECT_EQ("reshape2", node.input(1)); + ++num_trt_ops; + } + } + EXPECT_EQ(2, num_trt_ops); +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 65c8b7744e5..c0f664ddd93 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include +#include #include #include #include @@ -33,7 +34,8 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/tensor.pb.h" // NOLINT +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" @@ -45,13 +47,16 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/strided_slice_op.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" +#include "third_party/tensorrt/NvInferPlugin.h" // Check if the types are equal. Cast to int first so that failure log message // would work! @@ -59,7 +64,7 @@ limitations under the License. 
#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ do { \ - return errors::Internal("TFTRT::", __FUNCTION__, \ + return errors::Internal("TFTRT::", __FUNCTION__, ":", __LINE__, \ " failed to add TRT layer, at: ", node); \ } while (0) @@ -94,15 +99,12 @@ namespace convert { using absl::StrAppend; using absl::StrCat; -inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) { +inline Status TfDataTypeToTrt(DataType tf_dtype, + nvinfer1::DataType* trt_dtype) { switch (tf_dtype) { case DataType::DT_FLOAT: *trt_dtype = nvinfer1::DataType::kFLOAT; break; - // TODO(aaroey): this should be DT_QINT8 which is not a well supported type. - case DataType::DT_INT8: - *trt_dtype = nvinfer1::DataType::kINT8; - break; case DataType::DT_HALF: *trt_dtype = nvinfer1::DataType::kHALF; break; @@ -116,6 +118,107 @@ inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) { return Status::OK(); } +inline Status TrtDataTypeToTf(nvinfer1::DataType trt_dtype, + DataType* tf_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + *tf_dtype = DataType::DT_FLOAT; + break; + case nvinfer1::DataType::kHALF: + *tf_dtype = DataType::DT_HALF; + break; + case nvinfer1::DataType::kINT32: + *tf_dtype = DataType::DT_INT32; + break; + default: + return errors::InvalidArgument("Unsupported data type ", + DebugString(trt_dtype)); + } + return Status::OK(); +} + +class TFAttrs { + public: + explicit TFAttrs(const NodeDef& tf_node) { + for (const auto& attr : tf_node.attr()) { + attrs_.insert({attr.first, &attr.second}); + } + } + + bool count(const string& key) const { return attrs_.count(key); } + + AttrValue const* at(const string& key) const { + if (!attrs_.count(key)) { + LOG(FATAL) << "Attribute not found: " << key; + } + return attrs_.at(key); + } + + template + T get(const string& key) const; + + template + T get(const string& key, const T& default_value) const { + return attrs_.count(key) ? 
this->get(key) : default_value; + } + + std::vector GetAllAttrKeys() const { + std::vector attr_list; + for (const auto& attr_item : attrs_) { + attr_list.emplace_back(attr_item.first); + } + return attr_list; + } + + private: + typedef std::map AttrMap; + AttrMap attrs_; +}; + +template <> +string TFAttrs::get(const string& key) const { + return this->at(key)->s(); +} + +template <> +std::vector TFAttrs::get>(const string& key) const { + auto attr = this->at(key)->list().i(); + return std::vector(attr.begin(), attr.end()); +} + +template <> +std::vector TFAttrs::get>(const string& key) const { + auto attr = this->at(key)->list().f(); + return std::vector(attr.begin(), attr.end()); +} + +template <> +nvinfer1::DataType TFAttrs::get(const string& key) const { + nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); + TF_CHECK_OK(TfDataTypeToTrt(this->at(key)->type(), &trt_dtype)); + return trt_dtype; +} + +template <> +DataType TFAttrs::get(const string& key) const { + return this->at(key)->type(); +} + +template <> +float TFAttrs::get(const string& key) const { + return this->at(key)->f(); +} + +template <> +bool TFAttrs::get(const string& key) const { + return this->at(key)->b(); +} + +template <> +int64 TFAttrs::get(const string& key) const { + return this->at(key)->i(); +} + template inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, bool ignore_first_dim) { @@ -128,8 +231,8 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, return trt_dims; } -Status TensorShapeArrayToTrtDims(const std::vector& shape, - nvinfer1::Dims* out, +template +Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out, bool ignore_first_dim = false) { PartialTensorShape tensor_shape; TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape)); @@ -182,7 +285,7 @@ Status ValidateTensorProperties(const string& producer_node_type, nvinfer1::DataType* trt_dtype, nvinfer1::Dims* trt_dims, int* batch_size) { // Convert data type. - TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); + TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, trt_dtype)); // Convert shape. 
if (shape.dims() < 0) { @@ -192,9 +295,9 @@ Status ValidateTensorProperties(const string& producer_node_type, return errors::OutOfRange("Input tensor rank is greater than ", nvinfer1::Dims::MAX_DIMS + 1); } - if (producer_node_type != "Const" && shape.dims() < 2) { + if (producer_node_type != "Const" && shape.dims() < 1) { return errors::InvalidArgument( - "Input tensor with rank<2 is not supported since the first dimension " + "Scalar input tensor is not supported since the first dimension " "is treated as batch dimension by TRT"); } *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true); @@ -255,7 +358,12 @@ string DebugString(const nvinfer1::DataType trt_dtype) { string DebugString(const nvinfer1::Dims& dims) { string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d="); for (int i = 0; i < dims.nbDims; ++i) { - StrAppend(&out, dims.d[i], "[", DebugString(dims.type[i]), "],"); + StrAppend(&out, dims.d[i]); + if (VLOG_IS_ON(2)) { + StrAppend(&out, "[", DebugString(dims.type[i]), "],"); + } else { + StrAppend(&out, ","); + } } StrAppend(&out, ")"); return out; @@ -311,31 +419,31 @@ Status Converter::GetTrtBroadcastShape( } const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; - auto compute_output_dims = - [](const TRT_TensorOrWeights& input, int broadcast_num_dims, - int* output_dims_array, nvinfer1::Dims* output_dims) { - const nvinfer1::Dims input_dims = input.GetTrtDims(); - std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); - std::copy(input_dims.d, input_dims.d + input_dims.nbDims, - output_dims_array + broadcast_num_dims - input_dims.nbDims); - if (input.is_tensor()) { - const int true_input_dims = input_dims.nbDims + 1; - if (true_input_dims < broadcast_num_dims) { - return errors::InvalidArgument( - "Broadcasting beyond batch dimension is not supported ", - "(tensor #dims ", true_input_dims, " vs broadcast #dims ", - broadcast_num_dims, ")"); - } - // Set the batch dimension to -1, since batch size is not supposed to - // be broadcasted. - output_dims_array[0] = -1; - } - // Copy to output dimensions (stripping the batch dimension). - output_dims->nbDims = broadcast_num_dims - 1; - std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, - output_dims->d); - return Status::OK(); - }; + auto compute_output_dims = [](const TRT_TensorOrWeights& input, + int broadcast_num_dims, int* output_dims_array, + nvinfer1::Dims* output_dims) { + const nvinfer1::Dims input_dims = input.GetTrtDims(); + std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); + std::copy(input_dims.d, input_dims.d + input_dims.nbDims, + output_dims_array + broadcast_num_dims - input_dims.nbDims); + if (input.is_tensor()) { + const int true_input_dims = input_dims.nbDims + 1; + if (true_input_dims < broadcast_num_dims) { + return errors::InvalidArgument( + "Broadcasting beyond batch dimension is not supported ", + "(tensor #dims ", true_input_dims, " vs broadcast #dims ", + broadcast_num_dims, ")"); + } + // Set the batch dimension to -1, since batch size is not supposed to + // be broadcasted. + output_dims_array[0] = -1; + } + // Copy to output dimensions (stripping the batch dimension). + output_dims->nbDims = broadcast_num_dims - 1; + std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, + output_dims->d); + return Status::OK(); + }; // Compute the output dimensions. 
const int broadcast_num_dims = @@ -365,33 +473,51 @@ nvinfer1::ITensor* Converter::CreateConstantLayer( nvinfer1::Weights trt_weights = weights.GetTrtWeights(); nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights); if (!layer) return nullptr; - const nvinfer1::DataType trt_dtype = trt_weights.type; nvinfer1::ITensor* trt_tensor = layer->getOutput(0); +#if !IS_TRT_VERSION_GE(5, 1, 3, 0) // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set // the data type below, it will always be kFLOAT regardless what the data type // of the weights is. Once NVIDIA fixes this bug, we should remove the data // type setting logic below and test should still pass. - trt_tensor->setType(trt_dtype); + trt_tensor->setType(trt_weights.type); +#endif return trt_tensor; } Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value, const nvinfer1::Dims& dims, - const nvinfer1::ITensor** tensor) { + nvinfer1::ITensor** tensor, + const char* dtype_attr_name = "T") { + nvinfer1::DataType trt_dtype = + nvinfer1::DataType::kFLOAT; // Default to FP32. + TFAttrs attrs(params->node_def); + if (attrs.count(dtype_attr_name)) { + DataType dtype = attrs.get(dtype_attr_name); + TF_RETURN_IF_ERROR(TfDataTypeToTrt(dtype, &trt_dtype)); + } + // In order to be broadcastable, the number of dims has to match. nvinfer1::Dims broadcastable_dims(dims); for (int i = 0; i < broadcastable_dims.nbDims; i++) { broadcastable_dims.d[i] = 1; } - TRT_ShapedWeights weights = params->weight_store->GetTempWeights( - DataType::DT_FLOAT, broadcastable_dims); - auto weights_ptr = - static_cast(const_cast(weights.GetValues())); - weights_ptr[0] = value; + TRT_ShapedWeights weights = + params->weight_store->GetTempWeights(trt_dtype, broadcastable_dims); + void* raw_ptr = weights.GetValues(); + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + static_cast(raw_ptr)[0] = value; + break; + case nvinfer1::DataType::kHALF: + static_cast(raw_ptr)[0] = Eigen::half(value); + break; + default: + return errors::InvalidArgument("Unsupported data type ", + DebugString(trt_dtype)); + } *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims); TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name()); - params->converter->ProvideQuantizationRange( - const_cast(*tensor), value, value); + params->converter->ProvideQuantizationRange(*tensor, value, value); return Status::OK(); } @@ -460,11 +586,7 @@ inline bool HasStaticShape(const nvinfer1::Dims& dims) { return true; } -// Returns total number of elements in dims. Returning 0 means either some dim -// is 0 or the number of dims is 0. -// Note that for TF scalar constant, we always convert to dims [1]. -int64_t TrtDimsNumElements(const nvinfer1::Dims& dims) { - if (dims.nbDims == 0) return 0; +int64_t Prod(const nvinfer1::Dims& dims) { int64_t count = 1; for (int d = 0; d < dims.nbDims; ++d) { count *= dims.d[d]; @@ -472,6 +594,46 @@ int64_t TrtDimsNumElements(const nvinfer1::Dims& dims) { return count; } +// Returns total number of elements in a TensorRT weights dimensions. +// Returning 0 means either some dim is 0 or the number of dims is 0 (TensorRT +// doesn't allow scalar weights). +// Note that for TF scalar constant, we always convert to dims [1]. +int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims) { + if (dims.nbDims == 0) return 0; + return Prod(dims); +} + +// Returns total number of elements in an ITensor dimension. 
+// Returns 1 if the number of dims is 0 (the total number is fully determined by +// the batch size). +// Returns -1 if any dimension is known. +int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims) { + if (!HasStaticShape(dims)) return -1; + return Prod(dims); +} + +bool DimsHaveSameSize(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs, + bool is_tensor) { + if (is_tensor) { + return TrtTensorDimsNumElements(lhs) == TrtTensorDimsNumElements(rhs); + } + return TrtWeightDimsNumElements(lhs) == TrtWeightDimsNumElements(rhs); +} + +// Returns whether both dimensions are fully specified and the total number of +// elements equals. +bool AreDimsStaticWithSameSize(const nvinfer1::Dims& lhs, + const nvinfer1::Dims& rhs, bool is_tensor) { + if (!HasStaticShape(lhs) || !HasStaticShape(rhs)) return false; + return DimsHaveSameSize(lhs, rhs, is_tensor); +} + +bool AreDimsStaticWithDifferentSize(const nvinfer1::Dims& lhs, + const nvinfer1::Dims& rhs, bool is_tensor) { + if (!HasStaticShape(lhs) || !HasStaticShape(rhs)) return false; + return !DimsHaveSameSize(lhs, rhs, is_tensor); +} + static std::vector> CreateSamePadding( const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel, const std::vector& input_dims) { @@ -506,32 +668,69 @@ string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { return op_name_a.substr(0, last_scope_separator); } -TRT_ShapedWeights::TRT_ShapedWeights(DataType type) : type_(type) { +// Verifies that shapes of the given inputs match after masking the specified +// dimension. +Status VerifyShapesMatch(absl::Span inputs, + int masked_dim, absl::string_view node_name) { + size_t num_inputs = inputs.size(); + if (num_inputs <= 1) return Status::OK(); + + const nvinfer1::Dims dims_0 = inputs.at(0).GetTrtDims(); + for (size_t i = 1; i < num_inputs; ++i) { + const nvinfer1::Dims dim_i = inputs.at(i).GetTrtDims(); + if (dim_i.nbDims != dims_0.nbDims) { + return errors::InvalidArgument( + "Received inputs with inconsistent rank, at ", node_name); + } + for (size_t j = 0; j < dims_0.nbDims; ++j) { + if (dim_i.d[j] != dims_0.d[j] && j != masked_dim) { + return errors::InvalidArgument( + "Received inputs with inconsistent shape, at ", node_name); + } + } + } + return Status::OK(); +} + +TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type) : type_(type) { shape_.nbDims = 0; } -TRT_ShapedWeights::TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, - Tensor tensor) +TRT_ShapedWeights::TRT_ShapedWeights(nvinfer1::DataType type, + nvinfer1::Dims dims, Tensor tensor) : shape_(dims), type_(type), tensor_(tensor) {} TRT_ShapedWeights::TRT_ShapedWeights(const TRT_ShapedWeights& rhs) : shape_(rhs.shape_), type_(rhs.type_), tensor_(rhs.tensor_) {} -int64_t TRT_ShapedWeights::count() const { return TrtDimsNumElements(shape_); } +int64_t TRT_ShapedWeights::count() const { + return TrtWeightDimsNumElements(shape_); +} nvinfer1::Weights TRT_ShapedWeights::GetTrtWeights() const { - nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT); - TF_CHECK_OK(ConvertDType(type_, &trt_type)); - return nvinfer1::Weights{trt_type, GetValues(), count()}; + return nvinfer1::Weights{type_, GetValues(), count()}; } size_t TRT_ShapedWeights::size_bytes() const { - return this->count() * DataTypeSize(this->type_); + size_t data_type_size = -1; + switch (type_) { + case nvinfer1::DataType::kFLOAT: + case nvinfer1::DataType::kINT32: + data_type_size = 4; + break; + case nvinfer1::DataType::kHALF: + data_type_size = 2; + break; + case nvinfer1::DataType::kINT8: + 
data_type_size = 1; + break; + } + return this->count() * data_type_size; } string TRT_ShapedWeights::DebugString() const { return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_), - ", type=", DataTypeString(type_), + ", type=", convert::DebugString(type_), ", values=", reinterpret_cast(GetValues()), ")"); } @@ -574,13 +773,13 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { void setLocation(nvinfer1::TensorLocation location) override {} -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0, 0) bool setDynamicRange(float min, float max) override { return true; } float getDynamicRange() const override { return 0; } #endif -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0, 0) bool dynamicRangeIsSet() const override { return true; } void resetDynamicRange() override {} @@ -590,6 +789,14 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { float getDynamicRangeMax() const override { return 0.f; } #endif +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + void setAllowedFormats(nvinfer1::TensorFormats formats) override {} + + nvinfer1::TensorFormats getAllowedFormats() const override { return 1; } + + bool isShape() const override { return false; } +#endif + private: nvinfer1::DataType trt_dtype_; nvinfer1::Dims trt_dims_; @@ -630,12 +837,7 @@ void TRT_TensorOrWeights::operator=(const TRT_TensorOrWeights& rhs) { is_tensor_ = rhs.is_tensor_; } -nvinfer1::ITensor* TRT_TensorOrWeights::tensor() { - CHECK(is_tensor()); - return tensor_ == nullptr ? simple_itensor_.get() : tensor_; -} - -const nvinfer1::ITensor* TRT_TensorOrWeights::tensor() const { +nvinfer1::ITensor* TRT_TensorOrWeights::tensor() const { CHECK(is_tensor()); return tensor_ == nullptr ? simple_itensor_.get() : tensor_; } @@ -660,88 +862,6 @@ string TRT_TensorOrWeights::DebugString() const { return output; } -class TFAttrs { - public: - explicit TFAttrs(const NodeDef& tf_node) { - for (const auto& attr : tf_node.attr()) { - attrs_.insert({attr.first, &attr.second}); - } - } - - bool count(const string& key) const { return attrs_.count(key); } - - AttrValue const* at(const string& key) const { - if (!attrs_.count(key)) { - LOG(FATAL) << "Attribute not found: " << key; - } - return attrs_.at(key); - } - - template - T get(const string& key) const; - - template - T get(const string& key, const T& default_value) const { - return attrs_.count(key) ? 
this->get(key) : default_value; - } - - std::vector GetAllAttrKeys() const { - std::vector attr_list; - for (const auto& attr_item : attrs_) { - attr_list.emplace_back(attr_item.first); - } - return attr_list; - } - - private: - typedef std::map AttrMap; - AttrMap attrs_; -}; - -template <> -string TFAttrs::get(const string& key) const { - return this->at(key)->s(); -} - -template <> -std::vector TFAttrs::get>(const string& key) const { - auto attr = this->at(key)->list().i(); - return std::vector(attr.begin(), attr.end()); -} - -template <> -std::vector TFAttrs::get>(const string& key) const { - auto attr = this->at(key)->list().f(); - return std::vector(attr.begin(), attr.end()); -} - -template <> -nvinfer1::DataType TFAttrs::get(const string& key) const { - nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); - TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); - return trt_dtype; -} - -template <> -DataType TFAttrs::get(const string& key) const { - return this->at(key)->type(); -} - -template <> -float TFAttrs::get(const string& key) const { - return this->at(key)->f(); -} - -template <> -bool TFAttrs::get(const string& key) const { - return this->at(key)->b(); -} - -template <> -int64 TFAttrs::get(const string& key) const { - return this->at(key)->i(); -} - // TODO(jie): reorder4 & reorder2 should be merged? // TODO(aaroey): fix the order of parameters. template @@ -782,32 +902,27 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, oweights->shape_.d[1] = c; const nvinfer1::DimsHW istrides = {1, k}; const nvinfer1::DimsHW ostrides = {c, 1}; - switch (iweights.type_) { - case DataType::DT_FLOAT: { + switch (iweights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { Reorder2({k, c}, static_cast(iweights.GetValues()), - istrides, - // TODO(aaroey): get rid of all the const_cast like this. 
- static_cast(const_cast(oweights->GetValues())), - ostrides); + istrides, static_cast(oweights->GetValues()), ostrides); break; } - case DataType::DT_HALF: { - Reorder2( - {k, c}, static_cast(iweights.GetValues()), - istrides, - static_cast(const_cast(oweights->GetValues())), - ostrides); + case nvinfer1::DataType::kHALF: { + Reorder2({k, c}, static_cast(iweights.GetValues()), + istrides, static_cast(oweights->GetValues()), + ostrides); break; } default: LOG(FATAL) << "Unsupported type in reorder expected fp32 or fp16 but got " - << DataTypeString(iweights.type_); + << DebugString(iweights.TrtDType()); } } void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, const int num_groups) { - CHECK_EQ(iweights.type_, oweights->type_); + CHECK(iweights.TrtDType() == oweights->TrtDType()); CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); // K indexes over output channels, C over input channels, and R and S over the // height and width of the convolution @@ -826,37 +941,35 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, oweights->shape_.d[3] = s; const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; - switch (iweights.type_) { - case DataType::DT_FLOAT: { + switch (iweights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { Reorder4({k, c, r, s}, static_cast(iweights.GetValues()), - istrides, - static_cast(const_cast(oweights->GetValues())), - ostrides); + istrides, static_cast(oweights->GetValues()), ostrides); break; } - case DataType::DT_HALF: { - Reorder4( - {k, c, r, s}, static_cast(iweights.GetValues()), - istrides, - static_cast(const_cast(oweights->GetValues())), - ostrides); + case nvinfer1::DataType::kHALF: { + Reorder4({k, c, r, s}, + static_cast(iweights.GetValues()), istrides, + static_cast(oweights->GetValues()), ostrides); break; } default: LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got " - << DataTypeString(iweights.type_); + << DebugString(iweights.TrtDType()); } } -TRT_ShapedWeights TrtWeightStore::GetTempWeights(DataType type, +TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& dims) { TensorShape shape; + DataType tf_dtype; // TODO(laigd): make it return a status. TF_CHECK_OK(TensorShapeUtils::MakeShape(dims.d, dims.nbDims, &shape)); + TF_CHECK_OK(TrtDataTypeToTf(trt_dtype, &tf_dtype)); // TODO(jie): check weights size_bytes. 
0 means type error - Tensor tensor(type, shape); - TRT_ShapedWeights weights(type, dims, tensor); + Tensor tensor(tf_dtype, shape); + TRT_ShapedWeights weights(trt_dtype, dims, tensor); store_.emplace_back(std::move(tensor)); return weights; } @@ -970,11 +1083,45 @@ Status TrtNodeValidator::ConvertConstToWeights( return status; } +static void InitializeTrtPlugins() { + static mutex plugin_mutex(LINKER_INITIALIZED); + static bool plugin_initialized = false; + static Logger trt_logger; + mutex_lock lock(plugin_mutex); + if (plugin_initialized) return; + + plugin_initialized = initLibNvInferPlugins(&trt_logger, ""); + if (!plugin_initialized) { + LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may " + "fail later."; + } + + int num_trt_plugins = 0; + nvinfer1::IPluginCreator* const* trt_plugin_creator_list = + getPluginRegistry()->getPluginCreatorList(&num_trt_plugins); + if (!trt_plugin_creator_list) { + LOG(WARNING) << "Can not find any TensorRT plugins in registry."; + } else { + VLOG(1) << "Found the following " << num_trt_plugins + << " TensorRT plugins in registry:"; + for (int i = 0; i < num_trt_plugins; ++i) { + if (!trt_plugin_creator_list[i]) { + LOG(WARNING) << "TensorRT plugin at index " << i + << " is not accessible (null pointer returned by " + "getPluginCreatorList for this plugin)"; + } else { + VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName(); + } + } + } +} + Converter::Converter(nvinfer1::INetworkDefinition* trt_network, TrtPrecisionMode precision_mode, bool use_calibration) : trt_network_(trt_network), precision_mode_(precision_mode), use_calibration_(use_calibration) { + InitializeTrtPlugins(); this->RegisterOpConverters(); } @@ -1133,7 +1280,7 @@ Status Converter::GetTensorOrWeights(const string& name, Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, const std::vector& order_with_batch_dim, - const nvinfer1::ITensor** output_tensor) { + nvinfer1::ITensor** output_tensor) { const auto dims = input_tensor->getDimensions(); if (order_with_batch_dim.size() - 1 != size_t(dims.nbDims)) { @@ -1172,22 +1319,22 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, float* out_max) const { - switch (weights.type_) { - case DataType::DT_FLOAT: { + switch (weights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { auto inp = static_cast(weights.GetValues()); auto result = std::minmax_element(inp, inp + weights.count()); *out_min = *result.first; *out_max = *result.second; break; } - case DataType::DT_HALF: { + case nvinfer1::DataType::kHALF: { auto inp = static_cast(weights.GetValues()); auto result = std::minmax_element(inp, inp + weights.count()); *out_min = Eigen::half_impl::half_to_float(*result.first); *out_max = Eigen::half_impl::half_to_float(*result.second); break; } - case DataType::DT_INT32: { + case nvinfer1::DataType::kINT32: { auto inp = static_cast(weights.GetValues()); auto result = std::minmax_element(inp, inp + weights.count()); *out_min = static_cast(*result.first); @@ -1197,7 +1344,7 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, default: return errors::Unimplemented( "Data type not supported for GetWeightRange: ", - DataTypeString(weights.type_)); + DebugString(weights.TrtDType())); } return Status::OK(); } @@ -1205,21 +1352,24 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& 
dims, const bool validation_only, - const nvinfer1::ITensor** tensor) { - // If -1 is not used for one of the dims, we can check if the shapes are - // compatible. - bool can_check_shapes = true; - for (int i = 0; i < dims.nbDims; i++) { - if (dims.d[i] == -1) { - can_check_shapes = false; - break; - } + nvinfer1::ITensor** tensor) { + const nvinfer1::Dims input_dims = input.GetTrtDims(); + // If one of input_dims and dims doesn't have static shape, it means some of + // the dims are unknown or need to be inferred. And we don't do further checks + // but rely on the caller to not make mistakes. + // Otherwise we do simple check to make sure the total sizes are the same. + // If an input is a weight, it is going to become a tensor via + // CreateConstantLayer. So we can treat it as a tensor for + // AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors. + if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) { + return errors::InvalidArgument( + "Incompatible shapes: ", DebugString(input_dims), " vs. ", + DebugString(dims)); } - if (can_check_shapes && - TrtDimsNumElements(input.GetTrtDims()) != TrtDimsNumElements(dims)) { - return errors::InvalidArgument("Reshape shapes are not compatible (", - DebugString(input.GetTrtDims()), " vs ", - DebugString(dims), ")"); + // ConstantLayer requires static shapes (cannot infer -1). + if (input.is_weights() && !HasStaticShape(dims)) { + return errors::InvalidArgument("Shape is not fully defined: ", + DebugString(dims)); } if (validation_only) { *tensor = nullptr; @@ -1227,15 +1377,14 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, } if (input.is_tensor()) { - if (DimsEqual(input.GetTrtDims(), dims)) { + if (DimsEqual(input_dims, dims)) { *tensor = input.tensor(); } else { - nvinfer1::IShuffleLayer* layer = this->network()->addShuffle( - *const_cast(input.tensor())); + nvinfer1::IShuffleLayer* layer = + this->network()->addShuffle(*input.tensor()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); layer->setReshapeDimensions(dims); - MarkQuantizationRangesAsInferrable( - const_cast(input.tensor()), layer->getOutput(0)); + MarkQuantizationRangesAsInferrable(input.tensor(), layer->getOutput(0)); *tensor = layer->getOutput(0); } } else { @@ -1256,8 +1405,7 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, min_range = -127.0f; max_range = 127.0f; } - ProvideQuantizationRange(const_cast(*tensor), - min_range, max_range); + ProvideQuantizationRange(*tensor, min_range, max_range); } } return Status::OK(); @@ -1281,7 +1429,7 @@ void Converter::MaybeApplyQuantizationRanges() { // Infer ranges across marked ops. PropagateQuantizationRanges(); // Apply ranges. -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0, 0) for (auto pair : quantization_ranges_) { nvinfer1::ITensor* tensor = pair.first; const float range = pair.second; @@ -1433,43 +1581,30 @@ Status CheckInputsWeights( } Status AllowDataTypes(const OpConverterParams& params, - const std::set& allowed_dtypes) { + const std::set& allowed_dtypes, + const char* dtype_attr_name = "T") { const auto& node_def = params.node_def; - TFAttrs attrs(params.node_def); - if (attrs.count("T")) { - const auto op_dtype = attrs.get("T"); - if (!allowed_dtypes.count(op_dtype)) { - // Build string list of allowed types. 
- std::ostringstream ss; - for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) { - if (it != allowed_dtypes.begin()) ss << ", "; - ss << DataTypeString(*it); - } - return errors::Unimplemented("Data type ", DataTypeString(op_dtype), - " is not supported for ", node_def.op(), - ", must be one of [", ss.str(), "], at ", - node_def.name()); + TFAttrs attrs(node_def); + if (!attrs.count(dtype_attr_name)) { + return errors::InvalidArgument("Attribute with name ", dtype_attr_name, + " not found."); + } + const auto op_dtype = attrs.get(dtype_attr_name); + if (!allowed_dtypes.count(op_dtype)) { + // Build string list of allowed types. + std::ostringstream ss; + for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) { + if (it != allowed_dtypes.begin()) ss << ", "; + ss << DataTypeString(*it); } + return errors::Unimplemented("Data type ", DataTypeString(op_dtype), + " is not supported for ", node_def.op(), + ", must be one of [", ss.str(), "], at ", + node_def.name()); } - // If there is no T attribute, we can't determine the type of the op. We will - // allow it to convert for now. return Status::OK(); } -TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, - const TRT_ShapedWeights& weights_src) { - auto dtype_new = DataType::DT_HALF; - TRT_ShapedWeights weights = - store->GetTempWeights(dtype_new, weights_src.shape_); - const float* src = static_cast(weights_src.GetValues()); - Eigen::half* dst = const_cast( - static_cast(weights.GetValues())); - for (int64_t i = 0; i < weights_src.count(); i++) { - dst[i] = Eigen::half_impl::float_to_half_rtne(src[i]); - } - return weights; -} - // **************************************************************************** // Constant folding functions for weights. // TODO(laigd): we should probably use eigen directly. 
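For reference, the ConvertFP32ToFP16 helper deleted above boiled down to an element-wise fp32-to-fp16 conversion; a self-contained equivalent with the weight-store plumbing removed (FloatToHalf is an illustrative name):

#include <cstdint>
#include <vector>
#include "unsupported/Eigen/CXX11/Tensor"  // Eigen::half; this include path is an assumption

// Element-wise fp32 -> fp16, equivalent to what the removed helper did
// (it called Eigen::half_impl::float_to_half_rtne directly on each element).
std::vector<Eigen::half> FloatToHalf(const float* src, int64_t count) {
  std::vector<Eigen::half> dst(static_cast<size_t>(count));
  for (int64_t i = 0; i < count; ++i) dst[i] = Eigen::half(src[i]);
  return dst;
}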
@@ -1483,7 +1618,7 @@ struct LambdaFactory { switch (op) { case OP_CATEGORY::RSQRT: { VLOG(2) << "RSQRT GETS DONE"; - return [](T t) -> T { return 1.0 / sqrt(t); }; + return [](T t) -> T { return 1.0 / std::sqrt(t); }; } case OP_CATEGORY::NEG: return [](T t) -> T { return -t; }; @@ -1502,7 +1637,7 @@ std::function LambdaFactory::unary() { case OP_CATEGORY::RSQRT: { VLOG(2) << "RSQRT GETS DONE"; return [](Eigen::half t) { - return Eigen::half(1.0 / sqrt(static_cast(t))); + return Eigen::half(1.0 / std::sqrt(static_cast(t))); }; } case OP_CATEGORY::NEG: @@ -1519,25 +1654,24 @@ std::function LambdaFactory::unary() { Status UnaryCompute(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, LambdaFactory unary_op) { - CHECK_EQ(iweights.type_, oweights->type_); - switch (iweights.type_) { - case DataType::DT_FLOAT: { + CHECK(iweights.TrtDType() == oweights->TrtDType()); + switch (iweights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { auto inp = static_cast(iweights.GetValues()); - auto oup = static_cast(const_cast(oweights->GetValues())); + auto oup = static_cast(oweights->GetValues()); std::transform(inp, inp + iweights.count(), oup, unary_op.unary()); break; } - case DataType::DT_HALF: { + case nvinfer1::DataType::kHALF: { auto inp = static_cast(iweights.GetValues()); - auto oup = - static_cast(const_cast(oweights->GetValues())); + auto oup = static_cast(oweights->GetValues()); std::transform(inp, inp + iweights.count(), oup, unary_op.unary()); break; } default: - return errors::Unimplemented("Data type not supported: " + - DataTypeString(iweights.type_)); + return errors::Unimplemented("Data type not supported: ", + DebugString(iweights.TrtDType())); } return Status::OK(); } @@ -1548,7 +1682,7 @@ Status UnaryCompute(const TRT_ShapedWeights& iweights, // TODO(jie): broadcast is needed yet not implemented. // Only implemented channel wise for the time being. Status BinaryTensorOpWeight(OpConverterParams* params, - const nvinfer1::ITensor* tensor, + nvinfer1::ITensor* tensor, TRT_ShapedWeights weights, bool swapped_inputs) { static const std::unordered_set supported_ops = {"Sub", "Add", "Mul", "Div", "RealDiv"}; @@ -1558,10 +1692,6 @@ Status BinaryTensorOpWeight(OpConverterParams* params, node_def.name()); } - // Check type consistency. - nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype)); - // Check scale mode. 
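The scale mode being checked here is nvinfer1::ScaleMode, which tells IScaleLayer whether the weights apply uniformly, per channel, or per element. A rough, illustrative reduction of the dimension checks that follow (PickScaleMode and its parameters are assumptions, not code from this file):

#include <cstdint>
#include "NvInfer.h"  // nvinfer1::ScaleMode

// Pick how IScaleLayer should apply the weights, given the weight element
// count and the (batch-less) channel count / volume of the tensor. Returns
// false where the real converter gives up and falls back to the
// elementwise-op path.
bool PickScaleMode(int64_t weight_count, int64_t channels, int64_t volume,
                   nvinfer1::ScaleMode* mode) {
  if (weight_count == 1) { *mode = nvinfer1::ScaleMode::kUNIFORM; return true; }
  if (weight_count == channels) { *mode = nvinfer1::ScaleMode::kCHANNEL; return true; }
  if (weight_count == volume) { *mode = nvinfer1::ScaleMode::kELEMENTWISE; return true; }
  return false;
}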
auto dims_w = weights.shape_; const auto dims_t = tensor->getDimensions(); @@ -1642,30 +1772,25 @@ Status BinaryTensorOpWeight(OpConverterParams* params, } permutation[1] = dims_t.nbDims; permutation[dims_t.nbDims] = 1; - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(tensor), permutation, &tensor)); - } - - if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { - weights = ConvertFP32ToFP16(params->weight_store, weights); + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, permutation, &tensor)); } // Prepare weights - TRT_ShapedWeights shift_weights(weights.type_); - TRT_ShapedWeights scale_weights(weights.type_); - TRT_ShapedWeights power_weights(weights.type_); + TRT_ShapedWeights shift_weights(weights.TrtDType()); + TRT_ShapedWeights scale_weights(weights.TrtDType()); + TRT_ShapedWeights power_weights(weights.TrtDType()); if (node_def.op() == "Sub") { if (swapped_inputs) { shift_weights = weights; nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( - *const_cast(tensor), - nvinfer1::UnaryOperation::kNEG); + *tensor, nvinfer1::UnaryOperation::kNEG); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); // Since quantization ranges are symmetric, the same range as the input // will work for the negation of the input. params->converter->MarkQuantizationRangesAsInferrable( - const_cast(tensor), layer->getOutput(0)); + tensor, layer->getOutput(0)); tensor = layer->getOutput(0); } else { TRT_ShapedWeights neg_weights = @@ -1698,8 +1823,7 @@ Status BinaryTensorOpWeight(OpConverterParams* params, } scale_weights = weights; nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( - *const_cast(tensor), - nvinfer1::UnaryOperation::kRECIP); + *tensor, nvinfer1::UnaryOperation::kRECIP); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); tensor = layer->getOutput(0); } else { @@ -1720,22 +1844,19 @@ Status BinaryTensorOpWeight(OpConverterParams* params, } nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( - *const_cast(tensor), scale_mode, - shift_weights.GetTrtWeights(), scale_weights.GetTrtWeights(), - power_weights.GetTrtWeights()); + *tensor, scale_mode, shift_weights.GetTrtWeights(), + scale_weights.GetTrtWeights(), power_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - const nvinfer1::ITensor* output_tensor = layer->getOutput(0); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); // Transpose back dimension if (need_to_permute) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(output_tensor), permutation, - &output_tensor)); + output_tensor, permutation, &output_tensor)); } // Pass the output - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -1744,7 +1865,7 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, const auto& inputs = params->inputs; const auto& node_def = params->node_def; TRT_TensorOrWeights backprop_output_size; - const nvinfer1::ITensor* tensor = nullptr; + nvinfer1::ITensor* tensor = nullptr; if (is_conv2d_backprop_input) { // In the case when Conv2dBackpropInput is used for conv2d_transpose, these // inputs correspond to: output size, filter, and input. @@ -1805,8 +1926,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, // Transpose to NCHW (NCHW is required for IConvLayer). 
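The {0, 3, 1, 2} and {0, 2, 3, 1} orders passed to TransposeTensor in this hunk are the usual NHWC/NCHW permutations with the batch axis included; a small sketch to sanity-check them (Permute is an illustrative helper, not part of the converter):

#include <array>

// Permutation semantics used here: output axis i takes input axis perm[i].
// {0,3,1,2} therefore turns NHWC into NCHW, and {0,2,3,1} turns it back.
std::array<char, 4> Permute(const std::array<char, 4>& in,
                            const std::array<int, 4>& perm) {
  std::array<char, 4> out;
  for (int i = 0; i < 4; ++i) out[i] = in[perm[i]];
  return out;
}

// Permute({'N','H','W','C'}, {0,3,1,2}) == {'N','C','H','W'}
// Permute({'N','C','H','W'}, {0,2,3,1}) == {'N','H','W','C'}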
const bool need_transpose = (data_format == "NHWC"); if (need_transpose) { - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(tensor), {0, 3, 1, 2}, &tensor)); + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, {0, 3, 1, 2}, &tensor)); } // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); @@ -1816,16 +1937,13 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, // num_groups will be 1. const int num_groups = (group == 0) ? tensor_dim.d[0] : group; - if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { - weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck); - } // For conv, TF weights are RSCK, and TRT expects KCRS. // For backprop, TF weights are RSKC, and TRT expects CKRS. // Therefore, this reorder will work for both cases. TRT_ShapedWeights weights = params->weight_store->GetTempWeights(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); - TRT_ShapedWeights biases(weights.type_); + TRT_ShapedWeights biases(weights.TrtDType()); const int output_axis = is_conv2d_backprop_input ? 1 : 0; const int noutput = weights.shape_.d[output_axis] * num_groups; nvinfer1::DimsHW kernel_size; @@ -1845,8 +1963,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, // context of Conv2DBackpropInput). // We use h_index and w_index instead of 1 and 2 because we havent // transposed backprop_output_size along with the input. - auto output_size_weights = static_cast( - const_cast(backprop_output_size.weights().GetValues())); + auto output_size_weights = + static_cast(backprop_output_size.weights().GetValues()); input_dims = {output_size_weights[h_index], output_size_weights[w_index]}; } else { // Use 1 and 2 because tensor_dim has the dimensions of the transposed @@ -1858,56 +1976,87 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, } else { padding = {{0, 0}, {0, 0}}; } + +// TensorRT 5.1 added support for asymmetric padding. Due to a bug in 5.1.2, we +// can only use asymmetric padding in convolutions with 5.1.3+. +#if !IS_TRT_VERSION_GE(5, 1, 3, 0) if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { // Handle asymmetric padding. auto pad_layer = params->converter->network()->addPadding( - *const_cast(tensor), - nvinfer1::DimsHW(padding[0].first, padding[1].first), + *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); params->converter->MarkQuantizationRangesAsInferrable( - const_cast(tensor), pad_layer->getOutput(0)); + tensor, pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); } +#endif // Add convolution. nvinfer1::ILayer* conv_layer = nullptr; if (is_conv2d_backprop_input) { nvinfer1::IDeconvolutionLayer* layer = params->converter->network()->addDeconvolution( - *const_cast(tensor), noutput, kernel_size, - weights.GetTrtWeights(), biases.GetTrtWeights()); + *tensor, noutput, kernel_size, weights.GetTrtWeights(), + biases.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setStride(stride); - layer->setPadding({padding[0].first, padding[1].first}); +// TensorRT 5.1.3 added support for padding modes. +#if IS_TRT_VERSION_GE(5, 1, 3, 0) + if (attrs.get("padding") == "SAME") { + VLOG(2) << "Using SAME padding"; + // SAME_UPPER means that post padding is preferred. 
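The explicit pre/post values still come from TensorFlow's SAME-padding arithmetic (CreateSamePadding earlier in this file). A minimal sketch of that rule for one spatial dimension, assuming positive stride and the convention that the extra pixel goes at the end (SamePadding is an illustrative name):

#include <algorithm>
#include <utility>

// TensorFlow-style SAME padding for one spatial dimension: pad just enough
// that ceil(input / stride) output positions fit, splitting the total so the
// extra pixel (if any) lands at the end, which matches kSAME_UPPER.
std::pair<int, int> SamePadding(int input, int kernel, int stride) {
  const int output = (input + stride - 1) / stride;  // ceil(input / stride)
  const int total = std::max(0, (output - 1) * stride + kernel - input);
  const int pre = total / 2;                         // smaller half first
  return {pre, total - pre};
}

// Example: SamePadding(/*input=*/7, /*kernel=*/3, /*stride=*/2) -> {1, 1}.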
+ layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } + // For VALID padding, we need to manually set the padding. + layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); + layer->setPostPadding( + nvinfer1::DimsHW{padding[0].second, padding[1].second}); + VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding()) + << " and post-padding to: " << DebugString(layer->getPostPadding()); +#else + layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); + VLOG(2) << "Set padding to: " << DebugString(layer->getPadding()); +#endif layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); conv_layer = layer; } else { nvinfer1::IConvolutionLayer* layer = params->converter->network()->addConvolution( - *const_cast(tensor), noutput, kernel_size, - weights.GetTrtWeights(), biases.GetTrtWeights()); + *tensor, noutput, kernel_size, weights.GetTrtWeights(), + biases.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setStride(stride); - layer->setPadding({padding[0].first, padding[1].first}); +#if IS_TRT_VERSION_GE(5, 1, 3, 0) + if (attrs.get("padding") == "SAME") { + VLOG(2) << "Using SAME padding"; + layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } + layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); + layer->setPostPadding( + nvinfer1::DimsHW{padding[0].second, padding[1].second}); + VLOG(2) << "Set pre-padding to: " << DebugString(layer->getPrePadding()) + << " and post-padding to: " << DebugString(layer->getPostPadding()); +#else + layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); + VLOG(2) << "Set padding to: " << DebugString(layer->getPadding()); +#endif layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); layer->setDilation(dilation); conv_layer = layer; } - const nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); + nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); // Restore transpose. if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(output_tensor), {0, 2, 3, 1}, - &output_tensor)); + output_tensor, {0, 2, 3, 1}, &output_tensor)); } - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -1923,6 +2072,7 @@ Status BinaryTensorOpTensor(OpConverterParams* params, {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, + {"Pow", nvinfer1::ElementWiseOperation::kPOW}, }; auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) { @@ -1947,8 +2097,8 @@ Status BinaryTensorOpTensor(OpConverterParams* params, } if (params->validation_only) return Status::OK(); - const nvinfer1::ITensor* tensor_l = nullptr; - const nvinfer1::ITensor* tensor_r = nullptr; + nvinfer1::ITensor* tensor_l = nullptr; + nvinfer1::ITensor* tensor_r = nullptr; status = params->converter->PrepareTensorForShape( operand_l, broadcasted_dims_l, /*validation_only=*/false, &tensor_l); if (status.ok()) { @@ -1968,9 +2118,8 @@ Status BinaryTensorOpTensor(OpConverterParams* params, // Add ElementWise layer. 
nvinfer1::IElementWiseLayer* layer = - params->converter->network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), op_pair->second); + params->converter->network()->addElementWise(*tensor_l, *tensor_r, + op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -1985,8 +2134,8 @@ Status ConvertPlugin(OpConverterParams* params) { // prepare input std::vector all_inputs; all_inputs.reserve(inputs.size()); - for (auto input : inputs) { - all_inputs.emplace_back(const_cast(input.tensor())); + for (const auto& input : inputs) { + all_inputs.emplace_back(input.tensor()); } // plugin is owned by PluginFactory @@ -2026,13 +2175,11 @@ Status ConvertTranspose(OpConverterParams* params) { *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); // Get the permutation from weights. TRT_ShapedWeights weights = inputs.at(1).weights(); - const int* weights_ptr = - static_cast(const_cast(weights.GetValues())); + const int* weights_ptr = static_cast(weights.GetValues()); std::vector perm(weights_ptr, weights_ptr + weights.count()); // Verify the permutation. - nvinfer1::ITensor* input_tensor = - const_cast(inputs.at(0).tensor()); + nvinfer1::ITensor* input_tensor = inputs.at(0).tensor(); if (perm.size() - 1 != size_t(input_tensor->getDimensions().nbDims)) { return errors::InvalidArgument( "Rank of perm for transpose does not match with that of the input."); @@ -2045,11 +2192,10 @@ Status ConvertTranspose(OpConverterParams* params) { if (params->validation_only) return Status::OK(); // Start conversion. - const nvinfer1::ITensor* output_tensor = nullptr; + nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR( params->converter->TransposeTensor(input_tensor, perm, &output_tensor)); - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -2060,15 +2206,14 @@ Status ConvertReshape(OpConverterParams* params) { CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}})); TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); - TRT_TensorOrWeights input_tensor = inputs.at(0); + const TRT_TensorOrWeights& input_tensor = inputs.at(0); TRT_ShapedWeights weights = inputs.at(1).weights(); if (weights.count() == 0) { return errors::Unimplemented("Reshape to shape=[] is not supported, at ", node_def.name()); } - const int* weights_ptr = - static_cast(const_cast(weights.GetValues())); + const int* weights_ptr = static_cast(weights.GetValues()); // Check that it doesn't change the batch dimension. This check is // conservative, for example, when the first dim of the shape is -1 and input @@ -2101,6 +2246,18 @@ Status ConvertReshape(OpConverterParams* params) { // not ok // else: // not ok + // + // Note that the following is ok no matter whether reshape_batch_dim is fixed + // or not: + // + // ``` + // input_batch_dim is not fixed && + // reshape_dims are fixed && + // prod(input_dims) == prod(reshape_dims), + // ``` + // + // because the non-batch dims of the new and old shapes match, and TF runtime + // should make sure the batch dim is not changed. const int input_batch_dim = input_tensor.batch_size(); const int reshape_batch_dim = weights_ptr[0]; @@ -2117,19 +2274,18 @@ Status ConvertReshape(OpConverterParams* params) { bool reshape_may_change_batch_dim = false; if (input_batch_dim > 0) { // Batch size is fixed. 
if (reshape_batch_dim == -1) { // Other dims of the shape must be fixed. - if (!HasStaticShape(input_dims) || - TrtDimsNumElements(reshape_dims) != TrtDimsNumElements(input_dims)) { + if (!AreDimsStaticWithSameSize(input_dims, reshape_dims, + /*is_tensor=*/true)) { reshape_may_change_batch_dim = true; } } else if (reshape_batch_dim != input_batch_dim) { reshape_may_change_batch_dim = true; + } else { + // This means (input_batch_dim>0 && input_batch_dim==reshape_batch_dim), + // and TF runtime should make sure non-batch dims are matched. } - } else if (HasStaticShape(input_dims)) { - if (!HasStaticShape(reshape_dims) || - TrtDimsNumElements(reshape_dims) != TrtDimsNumElements(input_dims)) { - reshape_may_change_batch_dim = true; - } - } else { + } else if (!AreDimsStaticWithSameSize(input_dims, reshape_dims, + /*is_tensor=*/true)) { reshape_may_change_batch_dim = true; } VLOG(1) << "input_batch_dim=" << input_batch_dim @@ -2141,14 +2297,14 @@ Status ConvertReshape(OpConverterParams* params) { "Reshape on batch dimension is not supported, at ", node_def.name()); return errors::Unimplemented(msg); } - if (params->validation_only) return Status::OK(); // Start conversion. - const nvinfer1::ITensor* output_tensor = nullptr; + nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, reshape_dims, /*validation_only=*/false, &output_tensor)); - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + input_tensor, reshape_dims, params->validation_only, &output_tensor)); + if (params->validation_only) return Status::OK(); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -2160,48 +2316,31 @@ Status ConvertExpandDims(OpConverterParams* params) { TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); // Get input shape as vector. - TRT_TensorOrWeights input_tensor = inputs.at(0); + const TRT_TensorOrWeights& input_tensor = inputs.at(0); const nvinfer1::Dims dims = input_tensor.GetTrtDims(); std::vector input_dims(dims.d, dims.d + dims.nbDims); - // Add batch dim back. - input_dims.insert(input_dims.begin(), -1); - const int input_rank = input_dims.size(); // Get axis to expand on. - TRT_ShapedWeights weights = inputs.at(1).weights(); - if (weights.count() != 1) { + auto axis = inputs.at(1).weights().GetSpan(); + if (axis.size() != 1) { return errors::InvalidArgument("ExpandDims axis must be a scalar, at ", node_def.name()); } - const int* weights_ptr = - static_cast(const_cast(weights.GetValues())); - int axis = weights_ptr[0]; - // Make sure axis is valid. - if ((axis < (-input_rank - 1)) || (axis > input_rank)) { - return errors::InvalidArgument( - "Axis for ExpandDims is invalid, must be in the range " - "[-rank(input) - 1, rank(input)], at ", - node_def.name()); - } - // Convert negative axis to corresponding positive axis. - if (axis < 0) axis += input_rank + 1; - if (axis == 0) { - return errors::Unimplemented( - "Modifying batch dimension is not supported for ExpandDims, at ", - node_def.name()); - } + // Use rank = nbDims + 1 for ConvertAxis's bounds checking to account for + // ExpandDim's ability to add an axis at end of the shape. + int trt_axis; + TF_RETURN_IF_ERROR( + ConvertAxis(axis[0], dims.nbDims + 1, node_def.name(), &trt_axis)); if (params->validation_only) return Status::OK(); // ExpandDims: Insert new dim of size 1. 
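ConvertAxis itself is defined elsewhere in the file; a minimal sketch of the mapping it performs, assuming a TF axis that counts the batch dimension and may be negative, with error handling reduced to a bool (ConvertAxisSketch is an illustrative name):

// tf_axis counts the batch dimension and may be negative; trt_rank is the
// rank *without* the batch dimension. Returns false for out-of-range axes and
// for the batch dimension itself, which TF-TRT cannot modify.
bool ConvertAxisSketch(int tf_axis, int trt_rank, int* trt_axis) {
  const int tf_rank = trt_rank + 1;                 // rank including batch
  if (tf_axis < -tf_rank || tf_axis >= tf_rank) return false;
  if (tf_axis < 0) tf_axis += tf_rank;              // wrap negative axes
  if (tf_axis == 0) return false;                   // batch dim not allowed
  *trt_axis = tf_axis - 1;                          // strip the batch dim
  return true;
}

// e.g. for a 4-D NHWC tensor (trt_rank == 3): tf_axis -1 maps to trt_axis 2.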
- input_dims.insert(input_dims.begin() + axis, 1); + input_dims.insert(input_dims.begin() + trt_axis, 1); // Reshape tensor. nvinfer1::Dims new_dims; - TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, - /*ignore_first_dim=*/true)); - const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims)); + nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( input_tensor, new_dims, /*validation_only=*/false, &output_tensor)); - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -2212,12 +2351,9 @@ Status ConvertSqueeze(OpConverterParams* params) { TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); // Get input shape. - TRT_TensorOrWeights input_tensor = inputs.at(0); + const TRT_TensorOrWeights& input_tensor = inputs.at(0); const nvinfer1::Dims dims = input_tensor.GetTrtDims(); std::vector input_dims(dims.d, dims.d + dims.nbDims); - // Add batch dim back. - input_dims.insert(input_dims.begin(), -1); - const int input_rank = input_dims.size(); // Mark axes to remove by setting them to 0. TFAttrs attrs(node_def); auto squeeze_dims = attrs.get>("squeeze_dims"); @@ -2225,29 +2361,20 @@ Status ConvertSqueeze(OpConverterParams* params) { return errors::Unimplemented( "Squeeze is only implemented for explicit dims, at ", node_def.name()); } - for (int axis : squeeze_dims) { + for (int tf_axis : squeeze_dims) { // Make sure axis is valid. - if ((axis < -input_rank) || (axis >= input_rank)) { - return errors::InvalidArgument( - "Axis for Squeeze is invalid, must be in the range " - "[-rank(input), rank(input)), at ", - node_def.name()); - } - // Convert negative axis to corresponding positive axis. - if (axis < 0) axis += input_rank; - // Don't squeeze batch dim. - if (axis == 0) { - return errors::Unimplemented("Cannot squeeze batch dimension, at ", - node_def.name()); - } + int trt_axis; + TF_RETURN_IF_ERROR( + ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); // Make sure target dimension is size 1. - if (input_dims[axis] != 1) { + if (input_dims[trt_axis] != 1) { return errors::InvalidArgument( - "Cannot squeeze a dimension which isn't size 1, at ", + "Dimension ", tf_axis, " with size ", input_dims[trt_axis], + " cannot be squeezed because it must be size 1, at ", node_def.name()); } // Mark dim for removal by setting to 0. - input_dims[axis] = 0; + input_dims[trt_axis] = 0; } if (params->validation_only) return Status::OK(); @@ -2256,20 +2383,19 @@ Status ConvertSqueeze(OpConverterParams* params) { input_dims.end()); // Reshape tensor. 
nvinfer1::Dims new_dims; - TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, - /*ignore_first_dim=*/true)); - const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims)); + nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( input_tensor, new_dims, /*validation_only=*/false, &output_tensor)); - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } +template Status ConvertStridedSliceHelper(OpConverterParams* params, const TRT_TensorOrWeights& input, - std::vector begin, std::vector size, - const std::vector& stride) { + Container begin, Container size, + const Container& stride) { const auto& node_def = params->node_def; // Get input dims. nvinfer1::Dims dims = input.GetTrtDims(); @@ -2294,10 +2420,9 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, node_def.op(), ", at ", node_def.name()); } } -// TRT 5.1 adds a slice layer. For older versions, we attempt to use the +// TRT 5.1 adds ISliceLayer. For older versions, we attempt to use the // padding layer with negative padding. -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) - // Use ISliceLayer. +#if IS_TRT_VERSION_GE(5, 1, 3, 1) nvinfer1::Dims begin_dims, size_dims, stride_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims, /*ignore_first_dim=*/true)); @@ -2308,8 +2433,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, if (params->validation_only) return Status::OK(); nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice( - *const_cast(input.tensor()), begin_dims, size_dims, - stride_dims); + *input.tensor(), begin_dims, size_dims, stride_dims); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); #else @@ -2326,9 +2450,8 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, // Rank must be 2, 3 or 4. if (input_dims.size() > 4) { return errors::Unimplemented(node_def.op(), - " for tensors with rank > 4 is " - "not supported in this version of " - "TRT, at ", + " for tensors with rank > 4 is not supported " + "in this version of TRT, at ", node_def.name()); } // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. @@ -2361,8 +2484,8 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, // TODO(tmorris): Allow empty engines in the unit tests and return the input // as output here. if (params->validation_only) return Status::OK(); - nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle( - *const_cast(input.tensor())); + nvinfer1::IShuffleLayer* layer = + params->converter->network()->addShuffle(*input.tensor()); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); } else if (pad_dims.size() == 1) { @@ -2408,32 +2531,26 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, if (params->validation_only) return Status::OK(); // Start conversion. 
- nvinfer1::ITensor* tensor = const_cast(input.tensor()); + nvinfer1::ITensor* tensor = input.tensor(); if (need_reshape) { - const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, reshape_dims, /*validation_only=*/false, &output_tensor)); - tensor = const_cast(output_tensor); + input, reshape_dims, /*validation_only=*/false, &tensor)); } if (need_transpose) { - const nvinfer1::ITensor* output_tensor = nullptr; - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, transpose_order, &output_tensor)); - tensor = const_cast(output_tensor); + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, transpose_order, &tensor)); } // Add padding layer nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( - *const_cast(tensor), pre_padding, post_padding); + *tensor, pre_padding, post_padding); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); params->converter->MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); // Restore transpose if (need_transpose) { - const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, inv_transpose_order, &output_tensor)); - tensor = const_cast(output_tensor); + tensor, inv_transpose_order, &tensor)); } // Restore reshape if (need_reshape) { @@ -2455,15 +2572,12 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, nvinfer1::Dims new_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, /*ignore_first_dim=*/true)); - const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor), new_dims, /*validation_only=*/false, - &output_tensor)); - tensor = const_cast(output_tensor); + &tensor)); } - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(tensor))); + params->outputs->push_back(TRT_TensorOrWeights(tensor)); return Status::OK(); #endif } @@ -2521,83 +2635,93 @@ Status ConvertStridedSlice(OpConverterParams* params) { {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}})); TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); - // Get input dims. - nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); - std::vector input_dims(dims.d, dims.d + dims.nbDims); - // Add batch dimension so that indexes line up properly. - input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); - // Get begin and end bounds per axis. - std::vector begin = inputs.at(1).weights().ToVector(); - std::vector end = inputs.at(2).weights().ToVector(); - std::vector stride = inputs.at(3).weights().ToVector(); - if (!AllLengthsEqual({input_dims, begin, end, stride})) { - return errors::InvalidArgument( - "Length of begin, end, and stride arguments must equal rank of input " - "for StridedSlice, at ", - node_def.name()); - } - // Unsupported mask options. + TFAttrs attrs(node_def); - for (const string& attr : - {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + // Unsupported mask options. + for (const string& attr : {"new_axis_mask", "shrink_axis_mask"}) { int attr_val = attrs.get(attr); if (attr_val != 0) { return errors::Unimplemented( attr, " is not supported for StridedSlice, at ", node_def.name()); } } - const int begin_mask = attrs.get("begin_mask"); - const int end_mask = attrs.get("end_mask"); - // Check that batch dimension is unmodified. 
- const bool begin_is_modified = !(begin_mask & 1) && begin[0] != 0; - const bool stride_is_modified = stride[0] != 1; - // If the batch size is -1 and the end mask is not set, we can only know if - // the batch dimension is unmodified when the batch size is defined. When the - // batch size is undefined, we don't convert to be safe. - const bool batch_size_is_defined = input_dims[0] > 0; - const bool end_is_modified = - !(end_mask & 1) && (!batch_size_is_defined || - (batch_size_is_defined && end[0] != input_dims[0])); - if (begin_is_modified || stride_is_modified || end_is_modified) { - return errors::Unimplemented( - "TensorRT does not allow modifications to the batch dimension, at ", - node_def.name()); - } - // Standarize begin and end bounds by applying masks, making negative values - // positive, and correcting out of bounds ranges (StridedSlice does this - // silently). - for (int i = 1; i < input_dims.size(); i++) { - // Begin - if ((1 << i) & begin_mask) { - begin[i] = 0; - } else if (begin[i] < 0) { - begin[i] += input_dims[i]; - } - begin[i] = std::max(0, std::min(begin[i], input_dims[i])); - // End - if ((1 << i) & end_mask) { - end[i] = input_dims[i]; - } else if (end[i] < 0) { - end[i] += input_dims[i]; - } - end[i] = std::max(0, std::min(end[i], input_dims[i])); + const int32 begin_mask = attrs.get("begin_mask"); + const int32 end_mask = attrs.get("end_mask"); + const int32 ellipsis_mask = attrs.get("ellipsis_mask"); + + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. Set it to -1 if it's + // unknown, so ValidateStridedSliceOp() can handle it correctly below. + input_dims.insert(input_dims.begin(), + std::max(-1, inputs.at(0).batch_size())); + + const TRT_ShapedWeights& begin_weights = inputs.at(1).weights(); + const TRT_ShapedWeights& end_weights = inputs.at(2).weights(); + const TRT_ShapedWeights& stride_weights = inputs.at(3).weights(); + if (!AllLengthsEqual({begin_weights.ToVector(), + end_weights.ToVector(), + stride_weights.ToVector()})) { + return errors::InvalidArgument( + "Length of begin, end, and stride must be equal, at ", node_def.name()); } + + PartialTensorShape input_shape(input_dims); + PartialTensorShape processing_shape; + PartialTensorShape final_shape; + bool is_identity; + bool is_simple_slice; + bool slice_dim0; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; + TF_RETURN_IF_ERROR(ValidateStridedSliceOp( + &begin_weights.GetTensor(), &end_weights.GetTensor(), + stride_weights.GetTensor(), input_shape, begin_mask, end_mask, + ellipsis_mask, /*new_axis_mask=*/0, + /*shrink_axis_mask=*/0, &processing_shape, &final_shape, &is_identity, + &is_simple_slice, &slice_dim0, &begin, &end, &strides)); + // Negative or zero strides currently not supported. - for (int i = 0; i < input_dims.size(); i++) { - if (stride[i] <= 0) { + for (int stride : strides) { + if (stride <= 0) { return errors::Unimplemented( "Negative or zero stride values are not supported for StridedSlice, " "at ", node_def.name()); } } + + // If batch dimension is covered by the ellipsis mask, it means it's left + // untouched. Otherwise we check whether it modifies the batch dimension here. + if (!(ellipsis_mask & 1) || + begin_weights.shape_.nbDims >= input_dims.size()) { + // Check that batch dimension is unmodified. 
We need to use the expanded + // begin/end/strides array since the original array may be incorrect when + // (ellipsis_mask&1)==1. + const bool begin_is_modified = !(begin_mask & 1) && (begin[0] != 0); + const bool stride_is_modified = (strides[0] != 1); + // If the batch size is -1 and the end mask is not set, we can only know if + // the batch dimension is unmodified when the batch size is defined. When + // the batch size is undefined, we don't convert to be safe. + const bool batch_size_is_defined = (input_dims[0] > 0); + const bool end_is_modified = + !(end_mask & 1) && (!batch_size_is_defined || + (batch_size_is_defined && end[0] != input_dims[0])); + if (begin_is_modified || stride_is_modified || end_is_modified) { + return errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + } // TRT Slice layer uses (begin, size) instead of (begin, end) - std::vector size(input_dims.size()); + absl::InlinedVector size(input_dims.size()); for (int i = 0; i < input_dims.size(); i++) { // Divide by stride (round up) - size[i] = (end[i] - begin[i] + stride[i] - 1) / stride[i]; + size[i] = (end[i] - begin[i] + strides[i] - 1) / strides[i]; } - return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, strides); } Status ConvertConv2D(OpConverterParams* params) { @@ -2635,15 +2759,15 @@ Status ConvertPool(OpConverterParams* params) { } if (params->validation_only) return Status::OK(); - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int h_index = 2; int w_index = 3; const auto data_format = attrs.get("data_format"); if (data_format == "NHWC") { h_index = 1; w_index = 2; - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(tensor), {0, 3, 1, 2}, &tensor)); + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, {0, 3, 1, 2}, &tensor)); } const auto tf_stride = attrs.get>("strides"); @@ -2665,56 +2789,81 @@ Status ConvertPool(OpConverterParams* params) { padding = {{0, 0}, {0, 0}}; } +// TensorRT 5.1 added support for asymmetric padding. +#if !IS_TRT_VERSION_GE(5, 1, 0, 0) if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second << padding[1].first << padding[1].second; auto pad_layer = params->converter->network()->addPadding( - *const_cast(tensor), - nvinfer1::DimsHW(padding[0].first, padding[1].first), + *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); params->converter->MarkQuantizationRangesAsInferrable( - const_cast(tensor), pad_layer->getOutput(0)); + tensor, pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); } +#endif - nvinfer1::IPoolingLayer* layer = params->converter->network()->addPooling( - *const_cast(tensor), type, ksize); + nvinfer1::IPoolingLayer* layer = + params->converter->network()->addPooling(*tensor, type, ksize); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); // TODO(tmorris): Average pooling may not be entirely safe to infer // quantization range through (at least forwards - backwards should be fine). // Max pooling is okay. 
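The (begin, end, strides) to (begin, size) conversion used for the slice above is plain ceiling division; in isolation (SliceSize is an illustrative name):

// Number of elements selected by begin:end:stride with stride > 0, i.e.
// ceil((end - begin) / stride), written with the same integer arithmetic as
// in ConvertStridedSlice above.
int SliceSize(int begin, int end, int stride) {
  return (end - begin + stride - 1) / stride;
}

// Examples: SliceSize(0, 5, 2) == 3 (indices 0, 2, 4);
//           SliceSize(1, 4, 3) == 1 (index 1).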
- params->converter->MarkQuantizationRangesAsInferrable( - const_cast(tensor), layer->getOutput(0)); + params->converter->MarkQuantizationRangesAsInferrable(tensor, + layer->getOutput(0)); layer->setStride(stride); - layer->setPadding({padding[0].first, padding[1].first}); +// TensorRT 5.1.3 added support for padding modes. +#if IS_TRT_VERSION_GE(5, 1, 3, 0) + if (attrs.get("padding") == "SAME") { + // SAME_UPPER means that post padding is preferred. + layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } +#endif +// TensorRT 5.1 has support for asymmetric padding. +#if IS_TRT_VERSION_GE(5, 1, 0, 0) + // If padding mode is not SAME, then these values will be used instead. + layer->setPrePadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); + layer->setPostPadding(nvinfer1::DimsHW{padding[0].second, padding[1].second}); +#else + layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); +#endif layer->setName(node_def.name().c_str()); - const nvinfer1::ITensor* output_tensor = layer->getOutput(0); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NHWC") { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(output_tensor), {0, 2, 3, 1}, - &output_tensor)); + output_tensor, {0, 2, 3, 1}, &output_tensor)); } - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } -// TODO(tmorris): Use ActivationType::kLEAKY_RELU in TRT 5.1+ once perf -// improves. Status ConvertLeakyRelu(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); TF_RETURN_IF_ERROR( AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); - TFAttrs attrs(node_def); const float alpha = attrs.get("alpha"); + +#if IS_TRT_VERSION_GE(5, 1, 2, 0) + // Use IActivationLayer when available. + if (params->validation_only) return Status::OK(); + + nvinfer1::IActivationLayer* layer = + params->converter->network()->addActivation( + *inputs.at(0).tensor(), nvinfer1::ActivationType::kLEAKY_RELU); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setAlpha(alpha); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +#else + // Use elementwise ops when IActivationLayer is not available. if (alpha < 0.0f || alpha > 1.0f) { return errors::Unimplemented( "Alpha value for LeakyRelu must be between 0 and 1, at ", @@ -2722,32 +2871,89 @@ Status ConvertLeakyRelu(OpConverterParams* params) { } if (params->validation_only) return Status::OK(); - // Input Tensor - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); // Create const for alpha. 
- const nvinfer1::ITensor* const_alpha_tensor = nullptr; + nvinfer1::ITensor* const_alpha_tensor = nullptr; TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( params, alpha, tensor->getDimensions(), &const_alpha_tensor)); // alpha * x nvinfer1::IElementWiseLayer* mul_layer = params->converter->network()->addElementWise( - *const_cast(tensor), - *const_cast(const_alpha_tensor), - nvinfer1::ElementWiseOperation::kPROD); + *tensor, *const_alpha_tensor, nvinfer1::ElementWiseOperation::kPROD); TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name()); // max(x, alpha * x) nvinfer1::IElementWiseLayer* max_layer = params->converter->network()->addElementWise( - *const_cast(tensor), - *const_cast(mul_layer->getOutput(0)), + *tensor, *mul_layer->getOutput(0), nvinfer1::ElementWiseOperation::kMAX); TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name()); nvinfer1::ITensor* output_tensor = max_layer->getOutput(0); params->converter->MarkQuantizationRangesAsInferrable( - output_tensor, const_cast(mul_layer->getOutput(0))); + output_tensor, mul_layer->getOutput(0)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); +#endif +} + +#if IS_TRT_VERSION_GE(5, 1, 2, 0) +Status ConvertClipByValue(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + // TODO(tmorris): We can also allow the case where min and max are tensors by + // using elementwise min and max layers. + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"t", false}, {"clip_value_min", true}, {"clip_value_max", true}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + if (params->validation_only) return Status::OK(); + + TFAttrs attrs(node_def); + const DataType dtype = attrs.get("T"); + float clip_value_min = 0.0f; + float clip_value_max = 0.0f; + // TODO(tmorris): Add a templated helper function to get scalar weights of + // InType casted to OutType. 
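
The TODO above asks for a templated helper that reads a scalar weight of one type and returns it cast to another. A minimal standalone sketch of that idea, written against plain pointers rather than TRT_ShapedWeights (the function name and interface are assumptions for illustration, not the converter's API):

```cpp
// Illustrative sketch of a scalar-weight cast helper.
#include <cstdint>
#include <cstdio>

template <typename OutType, typename InType>
OutType GetScalarAs(const InType* values) {
  // A static_cast covers float and integer inputs; a specialization could
  // call Eigen::half_impl::half_to_float() for DT_HALF, as the code above does.
  return static_cast<OutType>(values[0]);
}

int main() {
  const float min_w[] = {-1.5f};
  const int32_t max_w[] = {6};
  std::printf("min=%f max=%f\n", GetScalarAs<float>(min_w),
              GetScalarAs<float>(max_w));
  return 0;
}
```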
+ if (dtype == DataType::DT_FLOAT) { + clip_value_min = inputs.at(1).weights().GetSpan()[0]; + clip_value_max = inputs.at(2).weights().GetSpan()[0]; + } else if (dtype == DataType::DT_HALF) { + clip_value_min = Eigen::half_impl::half_to_float( + inputs.at(1).weights().GetSpan()[0]); + clip_value_max = Eigen::half_impl::half_to_float( + inputs.at(2).weights().GetSpan()[0]); + } + + nvinfer1::IActivationLayer* layer = + params->converter->network()->addActivation( + *inputs.at(0).tensor(), nvinfer1::ActivationType::kCLIP); + layer->setAlpha(clip_value_min); + layer->setBeta(clip_value_max); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + params->converter->ProvideQuantizationRange(output_tensor, clip_value_min, + clip_value_max); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} +#endif + +const std::unordered_map* +ActivationTypeMap() { + static auto* const m = + new std::unordered_map({ + {"Relu", nvinfer1::ActivationType::kRELU}, + {"Sigmoid", nvinfer1::ActivationType::kSIGMOID}, + {"Tanh", nvinfer1::ActivationType::kTANH}, +#if IS_TRT_VERSION_GE(5, 1, 2, 0) + {"Elu", nvinfer1::ActivationType::kELU}, + {"Selu", nvinfer1::ActivationType::kSELU}, + {"Softsign", nvinfer1::ActivationType::kSOFTSIGN}, + {"Softplus", nvinfer1::ActivationType::kSOFTPLUS}, +#endif + }); + return m; } Status ConvertActivation(OpConverterParams* params) { @@ -2756,30 +2962,39 @@ Status ConvertActivation(OpConverterParams* params) { TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); TF_RETURN_IF_ERROR( AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); - static const std::unordered_map ops{ - {"Relu", nvinfer1::ActivationType::kRELU}, - {"Sigmoid", nvinfer1::ActivationType::kSIGMOID}, - {"Tanh", nvinfer1::ActivationType::kTANH}, - }; - auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) { + auto op_pair = ActivationTypeMap()->find(node_def.op()); + if (op_pair == ActivationTypeMap()->end()) { return errors::Unimplemented("Activation op: ", node_def.op(), " not supported at: ", node_def.name()); } if (params->validation_only) return Status::OK(); // Start conversion. - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); nvinfer1::IActivationLayer* layer = - params->converter->network()->addActivation( - *const_cast(tensor), op_pair->second); + params->converter->network()->addActivation(*inputs.at(0).tensor(), + op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Set parameters. +#if IS_TRT_VERSION_GE(5, 1, 2, 0) + if (node_def.op() == "Elu") { + layer->setAlpha(1.0f); + } else if (node_def.op() == "Selu") { + // From tensorflow/core/kernels/relu_op_functor.h + layer->setAlpha(1.7580993408473768599402175208123f); + layer->setBeta(1.0507009873554804934193349852946f); + } else if (node_def.op() == "Softplus") { + layer->setAlpha(1.0f); + layer->setBeta(1.0f); + } +#endif nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // Set quantization range for output of Sigmoid, Tanh. + // Set quantization range for output when known. 
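
The kCLIP-based conversions above set alpha to the lower bound and beta to the upper bound of the clamp, and the Relu6 conversion later in this patch reuses the same layer with alpha = 0, beta = 6. A self-contained sketch of the clamp being relied on:

```cpp
// Illustrative sketch: clip(x, alpha, beta) = min(max(x, alpha), beta).
#include <algorithm>
#include <cstdio>

float Clip(float x, float alpha, float beta) {
  return std::min(std::max(x, alpha), beta);
}

int main() {
  // Relu6 is the special case alpha = 0, beta = 6.
  std::printf("%f %f %f\n",
              Clip(-2.0f, 0.0f, 6.0f),   // 0
              Clip(3.0f, 0.0f, 6.0f),    // 3
              Clip(9.0f, 0.0f, 6.0f));   // 6
  return 0;
}
```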
if (node_def.op() == "Sigmoid") { params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f); } else if (node_def.op() == "Tanh") { params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f); + } else if (node_def.op() == "Softsign") { + params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -2818,8 +3033,8 @@ Status ConvertQuantize(OpConverterParams* params) { node_def.op() == "QuantizeAndDequantizeV3") { // Get ranges via inputs. auto get_weights_value = [&inputs](int index) { - auto raw_weights = static_cast( - const_cast(inputs.at(index).weights().GetValues())); + auto raw_weights = + static_cast(inputs.at(index).weights().GetValues()); return raw_weights[0]; }; min_range = get_weights_value(1); @@ -2831,9 +3046,8 @@ Status ConvertQuantize(OpConverterParams* params) { if (params->validation_only) return Status::OK(); // Store ranges for tensor - params->converter->ProvideQuantizationRange( - const_cast(inputs.at(0).tensor()), min_range, - max_range); + params->converter->ProvideQuantizationRange(inputs.at(0).tensor(), min_range, + max_range); // Sometimes, TRT may not quantize a tensor, either because it chooses to // execute a higher precision kernel or because of op fusion. In these cases, // accuracy will suffer if the model was trained to expect quantization at @@ -2847,7 +3061,6 @@ Status ConvertQuantize(OpConverterParams* params) { return Status::OK(); } -// TODO(tmorris): Use ActivationType::kCLIP in TRT 5.1+ once perf improves. Status ConvertRelu6(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -2855,19 +3068,28 @@ Status ConvertRelu6(OpConverterParams* params) { TF_RETURN_IF_ERROR( AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); if (params->validation_only) return Status::OK(); - // *************************************************************************** - // TensorRT does not implement Relu6 natively. This function converts Relu6 op - // to available TensorRT ops: Relu6(x) = min(Relu(x), 6) - // *************************************************************************** +#if IS_TRT_VERSION_GE(5, 1, 2, 0) + // Use IActivationLayer for TRT >= 5.1 + nvinfer1::IActivationLayer* layer = + params->converter->network()->addActivation( + *inputs.at(0).tensor(), nvinfer1::ActivationType::kCLIP); + layer->setAlpha(0.0f); + layer->setBeta(6.0f); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +#else + // Convert using min(Relu(x), 6) before TRT 5.1 // Input Tensor - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); // Relu operation i.e. Relu(x) = max(0, x) nvinfer1::IActivationLayer* relu_layer = params->converter->network()->addActivation( - *const_cast(tensor), - nvinfer1::ActivationType::kRELU); + *tensor, nvinfer1::ActivationType::kRELU); TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name()); // Large range of relu is problematic during quantization in INT8 precision @@ -2878,7 +3100,7 @@ Status ConvertRelu6(OpConverterParams* params) { 6.0f); // Create a constant layer to store the floating point weight i.e. 
6.0f - const nvinfer1::ITensor* const6_tensor = nullptr; + nvinfer1::ITensor* const6_tensor = nullptr; TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( params, 6.0f, relu_layer->getOutput(0)->getDimensions(), &const6_tensor)); @@ -2887,8 +3109,7 @@ Status ConvertRelu6(OpConverterParams* params) { // to this layer will only have values in range [0.f, 6.0f]. nvinfer1::IElementWiseLayer* relu6_layer = params->converter->network()->addElementWise( - *const_cast(relu_layer->getOutput(0)), - *const_cast(const6_tensor), + *relu_layer->getOutput(0), *const6_tensor, nvinfer1::ElementWiseOperation::kMIN); TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); @@ -2896,6 +3117,7 @@ Status ConvertRelu6(OpConverterParams* params) { params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); +#endif } Status ConvertBiasAdd(OpConverterParams* params) { @@ -2907,8 +3129,7 @@ Status ConvertBiasAdd(OpConverterParams* params) { AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); if (params->validation_only) return Status::OK(); - nvinfer1::ITensor* tensor = - const_cast(inputs.at(0).tensor()); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); const nvinfer1::Dims original_dims = tensor->getDimensions(); TFAttrs attrs(node_def); const string data_format = attrs.get("data_format"); @@ -2956,15 +3177,12 @@ Status ConvertBiasAdd(OpConverterParams* params) { } TRT_ShapedWeights weights = inputs.at(1).weights(); - if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { - weights = ConvertFP32ToFP16(params->weight_store, weights); - } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; if (weights.shape_.d[0] == 1) { mode = nvinfer1::ScaleMode::kUNIFORM; } - TRT_ShapedWeights empty_weights(weights.type_); + TRT_ShapedWeights empty_weights(weights.TrtDType()); nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(), empty_weights.GetTrtWeights()); @@ -3014,55 +3232,67 @@ void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) { } } +template +void CopyToTrtInt32Array(const Tensor& tensor, int32* dst) { + typedef typename EnumToDataType::Type CType; + const CType* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); +} + Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, TRT_ShapedWeights* weights) { const DataType dtype = tensor.dtype(); - // We always convert the integer constants to INT32, since TRT INT8 is for - // quantized inference. + // We always convert the integer constants to INT32. // // TODO(aaroey): FP16 will remain in half format and is not converted to // FP32, but the converter currently uses all float weights as FP32. Fix // this. - const DataType converted_dtype = - (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 - : dtype); + DataType converted_dtype = dtype; + if (dtype == DataType::DT_INT8 || dtype == DataType::DT_UINT8 || + dtype == DataType::DT_INT16 || dtype == DataType::DT_UINT16) { + converted_dtype = DT_INT32; + } // Verify that the dtype is supported by TensorRT. Otherwise, return an error. nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); + TF_RETURN_IF_ERROR(TfDataTypeToTrt(converted_dtype, &trt_dtype)); if (tensor.NumElements() == 0) { - // Return empty weights having converted dtype. 
- *weights = TRT_ShapedWeights(converted_dtype); + // Return empty weights. + *weights = TRT_ShapedWeights(trt_dtype); return Status::OK(); } nvinfer1::Dims weight_dims; GetTensorDimsWithProtoShape(tensor, &weight_dims); - *weights = weight_store->GetTempWeights(converted_dtype, weight_dims); + *weights = weight_store->GetTempWeights(trt_dtype, weight_dims); // Copy the tensor directly if the tensor does not require cast to the // supported type. if (converted_dtype == dtype) { - char* dst = static_cast(const_cast(weights->GetValues())); + char* dst = static_cast(weights->GetValues()); memcpy(dst, tensor.tensor_data().data(), tensor.TotalBytes()); return Status::OK(); } // Copy tensor elements after casting them to the converted DataType. - int32* dst = static_cast(const_cast(weights->GetValues())); - if (dtype == DT_INT16) { - const int16* src = tensor.flat().data(); - std::copy(src, src + tensor.NumElements(), dst); - } else if (dtype == DT_INT8) { - const int8* src = tensor.flat().data(); - std::copy(src, src + tensor.NumElements(), dst); - } else { - // dtype can only be DT_UINT8 at this point. - TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8); - const uint8* src = tensor.flat().data(); - std::copy(src, src + tensor.NumElements(), dst); + int32* dst = static_cast(weights->GetValues()); + switch (dtype) { + case DT_INT8: + CopyToTrtInt32Array(tensor, dst); + break; + case DT_UINT8: + CopyToTrtInt32Array(tensor, dst); + break; + case DT_INT16: + CopyToTrtInt32Array(tensor, dst); + break; + case DT_UINT16: + CopyToTrtInt32Array(tensor, dst); + break; + default: + return errors::Internal("Unexpected DataType: ", DataTypeString(dtype)); } return Status::OK(); } @@ -3192,10 +3422,10 @@ Status ConvertRsqrt(OpConverterParams* params) { node_def.name()); } // Start conversion. - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); // Sqrt nvinfer1::IUnaryLayer* sqrt_layer = params->converter->network()->addUnary( - *const_cast(tensor), nvinfer1::UnaryOperation::kSQRT); + *tensor, nvinfer1::UnaryOperation::kSQRT); TFTRT_RETURN_ERROR_IF_NULLPTR(sqrt_layer, node_def.name()); // Recip nvinfer1::IUnaryLayer* recip_layer = params->converter->network()->addUnary( @@ -3215,7 +3445,7 @@ UnaryOperationMap() { {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, {"Abs", nvinfer1::UnaryOperation::kABS}, {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0, 0) {"Sin", nvinfer1::UnaryOperation::kSIN}, {"Cos", nvinfer1::UnaryOperation::kCOS}, {"Tan", nvinfer1::UnaryOperation::kTAN}, @@ -3248,9 +3478,9 @@ Status ConvertUnary(OpConverterParams* params) { if (params->validation_only) return Status::OK(); // Start conversion. - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( - *const_cast(tensor), op_pair->second); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::IUnaryLayer* layer = + params->converter->network()->addUnary(*tensor, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -3265,11 +3495,10 @@ Status ConvertUnary(OpConverterParams* params) { // Neg and Abs will have same range as input since TRT uses symmetric // quantization. // TODO(tmorris): Should we infer ranges for Ceil and Floor as well? 
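
ConvertRsqrt above lowers Rsqrt as two unary layers because TRT has no single rsqrt operation. A standalone sketch of the identity it relies on:

```cpp
// Illustrative sketch: rsqrt(x) = 1 / sqrt(x), i.e. a kSQRT layer followed by
// a kRECIP layer.
#include <cassert>
#include <cmath>

float RsqrtViaTwoSteps(float x) {
  const float s = std::sqrt(x);  // kSQRT
  return 1.0f / s;               // kRECIP
}

int main() {
  assert(std::fabs(RsqrtViaTwoSteps(4.0f) - 0.5f) < 1e-6f);
  assert(std::fabs(RsqrtViaTwoSteps(16.0f) - 0.25f) < 1e-6f);
  return 0;
}
```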
- params->converter->MarkQuantizationRangesAsInferrable( - const_cast(tensor), output_tensor); + params->converter->MarkQuantizationRangesAsInferrable(tensor, + output_tensor); } - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -3282,15 +3511,14 @@ Status ConvertSquare(OpConverterParams* params) { if (params->validation_only) return Status::OK(); // Constant 2 with same rank as input - const nvinfer1::ITensor* const2_tensor = nullptr; + nvinfer1::ITensor* const2_tensor = nullptr; TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( params, 2.0f, inputs.at(0).GetTrtDims(), &const2_tensor)); // ElementWise Pow Operation nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( - *const_cast(inputs.at(0).tensor()), - *const_cast(const2_tensor), + *inputs.at(0).tensor(), *const2_tensor, nvinfer1::ElementWiseOperation::kPOW); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -3307,8 +3535,8 @@ Status ConvertReduce(OpConverterParams* params) { TF_RETURN_IF_ERROR( AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - TRT_ShapedWeights index_list = inputs.at(1).weights(); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + auto tf_axes_list = inputs.at(1).weights().GetSpan(); TFAttrs attrs(node_def); // Only expect to handle INT32 as attributes for now @@ -3317,22 +3545,17 @@ Status ConvertReduce(OpConverterParams* params) { } int axes = 0; - if (index_list.count() == 0) { + if (tf_axes_list.size() == 0) { return errors::InvalidArgument( "TRT cannot support reduce on all (batch) dimensions, at", node_def.name()); - } else { - auto index_list_data = - static_cast(const_cast(index_list.GetValues())); - for (int i = 0; i < index_list.count(); i++) { - int axis = index_list_data[i]; - if (axis < 0) axis += tensor->getDimensions().nbDims + 1; - if (axis == 0) { - return errors::InvalidArgument( - "TRT cannot reduce at batch dimension, at", node_def.name()); - } - axes |= (1 << (axis - 1)); - } + } + for (int i = 0; i < tf_axes_list.size(); i++) { + int trt_axis; + TF_RETURN_IF_ERROR(ConvertAxis(tf_axes_list[i], + tensor->getDimensions().nbDims, + node_def.name(), &trt_axis)); + axes |= (1 << trt_axis); } nvinfer1::ReduceOperation reduce_operation; @@ -3354,14 +3577,87 @@ Status ConvertReduce(OpConverterParams* params) { const auto keep_dims = attrs.get("keep_dims"); nvinfer1::ILayer* layer = params->converter->network()->addReduce( - *const_cast(tensor), reduce_operation, axes, - keep_dims); + *tensor, reduce_operation, axes, keep_dims); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); } +// TensorRT does not support the Pack op natively. Therefore, Pack op is +// converted by first expanding input tensors by adding a new dimension of size +// one at the specified axis and then concatenating the tensors at the same +// axis. +Status ConvertPack(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + + TFAttrs attrs(node_def); + const int num_inputs = attrs.get("N"); + if (num_inputs != inputs.size()) { + return errors::InvalidArgument( + "Number of inputs for Pack is inconsistent with N attribute, at ", + node_def.name()); + } + + // Validate inputs. 
Values must be tensors for now. + std::vector> inputs_is_weight; + for (int i = 0; i < num_inputs; ++i) { + inputs_is_weight.push_back({StrCat("values_", i), false}); + } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, inputs_is_weight)); + + // TODO(hinsu): Enable INT32 with TensorRT version 5.1.3 after testing. + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + + if (num_inputs > 1) { + // Verify that inputs are compatible for concatenation after the expansion. + TF_RETURN_IF_ERROR( + VerifyShapesMatch(inputs, /*masked_dim=*/-1, node_def.name())); + } + + // Convert axis from the TensorFlow format to TensorRT format. + const nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + const int64 tf_axis = attrs.get("axis"); + int trt_axis; + TF_RETURN_IF_ERROR( + ConvertAxis(tf_axis, dims.nbDims + 1, node_def.name(), &trt_axis)); + + // Compute expanded dimensions and then reshape input tensors. + std::vector tensor_dims(dims.d, dims.d + dims.nbDims); + tensor_dims.insert(tensor_dims.begin() + trt_axis, 1); + nvinfer1::Dims expanded_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(tensor_dims, &expanded_dims)); + std::vector expanded_tensors; + for (const TRT_TensorOrWeights& tensor : inputs) { + nvinfer1::ITensor* expanded_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + tensor, expanded_dims, params->validation_only, &expanded_tensor)); + if (!params->validation_only) { + expanded_tensors.push_back(expanded_tensor); + } + } + if (params->validation_only) return Status::OK(); + + // If there is only one tensor in the input, return the expanded tensor. + if (num_inputs == 1) { + params->outputs->push_back(TRT_TensorOrWeights(expanded_tensors[0])); + return Status::OK(); + } + + // Otherwise, concatenate expanded tensors. + nvinfer1::IConcatenationLayer* layer = + params->converter->network()->addConcatenation( + const_cast(expanded_tensors.data()), + expanded_tensors.size()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Note that trt_axis stays the same even after expanding tensors at the axis. 
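
The Pack conversion above adds a size-1 dimension at the converted axis of every input and then concatenates along that same axis. A standalone shape-level sketch (plain std::vector shapes, implicit batch dimension excluded):

```cpp
// Illustrative sketch: packing 4 tensors of TRT shape [3, 5] at trt_axis = 0.
// Each input is expanded to [1, 3, 5]; concatenating them on axis 0 yields
// [4, 3, 5], matching tf.stack once the batch dim is re-attached.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> dims = {3, 5};
  const int trt_axis = 0;
  const int num_inputs = 4;

  std::vector<int> expanded = dims;
  expanded.insert(expanded.begin() + trt_axis, 1);  // [1, 3, 5]

  std::vector<int> packed = expanded;
  packed[trt_axis] *= num_inputs;                   // [4, 3, 5]

  std::printf("expanded: %d %d %d\n", expanded[0], expanded[1], expanded[2]);
  std::printf("packed:   %d %d %d\n", packed[0], packed[1], packed[2]);
  return 0;
}
```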
+ layer->setAxis(trt_axis); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + Status ConvertPad(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -3371,7 +3667,7 @@ Status ConvertPad(OpConverterParams* params) { AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); // Implement tensor binaryOp weight [channel wise] for now; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); const auto dims = tensor->getDimensions(); // Restore implicit batch dimension const int nb_dims = dims.nbDims + 1; @@ -3394,7 +3690,7 @@ Status ConvertPad(OpConverterParams* params) { if (padding_type != DataType::DT_INT32) { return errors::Unimplemented("Tpaddings supports only DT_INT32"); } - auto pad_data = static_cast(const_cast(pads.GetValues())); + auto pad_data = static_cast(pads.GetValues()); std::vector pad_index; for (int i = 0; i < nb_dims; i++) { @@ -3436,8 +3732,8 @@ Status ConvertPad(OpConverterParams* params) { std::vector permuted_pad_index(pad_index); if (pad_index[0] == 1) { legit_pad = false; - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(tensor), {0, 3, 2, 1}, &tensor)); + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, {0, 3, 2, 1}, &tensor)); permuted_pad_index[0] = 3; } @@ -3453,99 +3749,175 @@ Status ConvertPad(OpConverterParams* params) { } nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( - *const_cast(tensor), pre_padding, post_padding); + *tensor, pre_padding, post_padding); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - const nvinfer1::ITensor* output_tensor = layer->getOutput(0); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (!legit_pad) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(output_tensor), {0, 3, 2, 1}, - &output_tensor)); + output_tensor, {0, 3, 2, 1}, &output_tensor)); } - params->outputs->push_back( - TRT_TensorOrWeights(const_cast(output_tensor))); + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } +Status ConvertSplitHelper(OpConverterParams* params, + const TRT_TensorOrWeights& input, int tf_axis, + int num_splits, bool squeeze_after) { + const auto& node_def = params->node_def; + const nvinfer1::Dims dims = input.GetTrtDims(); + // Convert axis. + int trt_axis; + TF_RETURN_IF_ERROR( + ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); + // Dimension must equal num_splits for Unstack (when squeeze_after is true) + if (squeeze_after && dims.d[trt_axis] != num_splits) { + return errors::InvalidArgument( + "Dimension ", tf_axis, " has size ", dims.d[trt_axis], + " which is not equal to num of ", num_splits, ", at ", node_def.name()); + } + // Dimension must be evenly divisible by num_splits. + if (dims.d[trt_axis] % num_splits != 0) { + return errors::InvalidArgument( + "Dimension ", tf_axis, " of size ", dims.d[trt_axis], + " is not evenly divisble by ", num_splits, ", at ", node_def.name()); + } + + // Create parameters for StridedSliceHelper. + // Slice will begin on zero for all dims, except the one being split which + // will change. + std::vector begin(dims.nbDims, 0); + // Determine size of split. Slice will get the full length of all dims, except + // the one being split. 
+ std::vector size(dims.d, dims.d + dims.nbDims); + const int split_size_on_axis = dims.d[trt_axis] / num_splits; + size[trt_axis] = split_size_on_axis; + // Stride will always be 1 + std::vector stride(dims.nbDims, 1); + // Add dummy batch dimension + begin.insert(begin.begin(), 0); + size.insert(size.begin(), 1); + stride.insert(stride.begin(), 1); + + // Slice the input. ConvertStridedSliceHelper will push the outputs onto + // params->outputs. + for (int i = 0; i < num_splits; ++i) { + begin[trt_axis + 1] = i * split_size_on_axis; + TF_RETURN_IF_ERROR( + ConvertStridedSliceHelper(params, input, begin, size, stride)); + } + if (params->validation_only) return Status::OK(); + + // For Unpack/Unstack, remove axis that we split upon. + if (squeeze_after) { + // Create the new shape. + size.erase(size.begin() + trt_axis + 1); + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR( + TensorShapeArrayToTrtDims(size, &new_dims, /*ignore_frst_dim=*/true)); + // Reshape each slice. + for (int i = 0; i < params->outputs->size(); i++) { + nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + params->outputs->at(i), new_dims, /*validation_only=*/false, + &output_tensor)); + (*params->outputs)[i] = TRT_TensorOrWeights(output_tensor); + } + } + return Status::OK(); +} + +Status ConvertSplit(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"axis", true}, {"value", false}})); + TF_RETURN_IF_ERROR(AllowDataTypes(*params, { + DataType::DT_FLOAT, DataType::DT_HALF, +#if IS_TRT_VERSION_GE(5, 1, 3, 1) + DataType::DT_INT32, +#endif + })); + int tf_axis = inputs.at(0).weights().GetSpan()[0]; + TFAttrs attrs(node_def); + const int num_split = attrs.get("num_split"); + + return ConvertSplitHelper(params, inputs.at(1), tf_axis, num_split, false); +} + +Status ConvertUnpack(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"value", false}})); + TF_RETURN_IF_ERROR(AllowDataTypes(*params, { + DataType::DT_FLOAT, DataType::DT_HALF, +#if IS_TRT_VERSION_GE(5, 1, 3, 1) + DataType::DT_INT32, +#endif + })); + // Input must be rank 1 or higher, since we can't unpack on axis 0. + if (inputs.at(0).GetTrtDims().nbDims == 0) { + return errors::Unimplemented( + "Input \"value\" for Unpack must be rank 2 or greater, at ", + node_def.name()); + } + TFAttrs attrs(node_def); + const int tf_axis = attrs.get("axis"); + const int num = attrs.get("num"); + + return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true); +} + Status ConvertConcat(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + TFAttrs attrs(node_def); + // Get number of tensor inputs. + const int num_inputs = attrs.get("N"); + if (num_inputs != static_cast(inputs.size()) - 1) { + return errors::InvalidArgument( + "Number of inputs for ConcatV2 is inconsistent with N attribute, at ", + node_def.name()); + } + // Validate inputs. Values must be tensors for now. + std::vector> inputs_is_weight; + for (int i = 0; i < num_inputs; ++i) { + inputs_is_weight.push_back({StrCat("values_", i), false}); + } + inputs_is_weight.push_back({"axis", true}); + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, inputs_is_weight)); // TODO(tmorris): There is a bug with Concat and INT32 in TRT - it is supposed // to be supported. 
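
ConvertSplitHelper above maps Split/Unpack onto repeated strided slices: every slice keeps the full extent of all dimensions except the split axis, where it takes split_size elements starting at i * split_size. A standalone sketch of that begin/size bookkeeping (dummy batch dimension omitted):

```cpp
// Illustrative sketch: splitting a [6, 8] tensor into num_splits = 3 pieces
// along trt_axis = 0 produces slices with begin {0,0}, {2,0}, {4,0} and size
// {2, 8} each.
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> dims = {6, 8};
  const int trt_axis = 0;
  const int num_splits = 3;
  const int split_size = dims[trt_axis] / num_splits;  // 2

  for (int i = 0; i < num_splits; ++i) {
    std::vector<int> begin(dims.size(), 0);
    std::vector<int> size = dims;
    begin[trt_axis] = i * split_size;
    size[trt_axis] = split_size;
    std::printf("slice %d: begin={%d,%d} size={%d,%d}\n", i, begin[0],
                begin[1], size[0], size[1]);
  }
  return 0;
}
```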
TF_RETURN_IF_ERROR( AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); - // not including the last input (axis) here - int input_size = static_cast(inputs.size()) - 1; - - if (!inputs.at(0).is_tensor()) { - return errors::InvalidArgument( - "Concat in TRT support only Tensor input, at ", node_def.name()); - } - - // We are retrieving the axis - TRT_ShapedWeights axis = inputs.at(input_size).weights(); - - TFAttrs attrs(node_def); - auto index_type = attrs.get("Tidx"); - - // TODO(jie): handle data type - // Only expect to handle INT32 as index attributes for now - if (index_type != DataType::DT_INT32) - return errors::Unimplemented("Tidx supports only DT_INT32, at ", - node_def.name()); - - int index = *(static_cast(const_cast(axis.GetValues()))); - - // TODO(jie): early termination with no-op (attr_size==1) - - auto dim = inputs.at(0).tensor()->getDimensions(); - // dimension check - if (index > dim.nbDims + 1) { - return errors::InvalidArgument( - "Concatenate on axis out of dimension range, at ", node_def.name()); - } - if (index == 0) { - return errors::InvalidArgument( - "Concatenate on batch dimension not supported, at ", node_def.name()); - } - if (index < 0) { - index = dim.nbDims + index + 1; - } - - std::vector inputs_vec; - // Shap chack (all input tensor should have same shape) - // starting from 0 since we are probably also doing transpose here; - for (int i = 0; i < input_size; i++) { - auto tensor_i = inputs.at(i).tensor(); - auto dim_i = tensor_i->getDimensions(); - if (dim_i.nbDims != dim.nbDims) { - return errors::InvalidArgument( - "Concatenate receives inputs with inconsistent dimensions, at ", - node_def.name()); - } - for (int j = 0; j < dim.nbDims; j++) { - // check dimension consistency on non-concatenate axis - if (j != index - 1 && dim_i.d[j] != dim.d[j]) { - return errors::InvalidArgument( - "Concatenate receives inputs with inconsistent shape, at", - node_def.name()); - } - } - - inputs_vec.push_back(tensor_i); + const auto axis = inputs.at(num_inputs).weights().GetSpan(); + if (axis.size() != 1) { + return errors::InvalidArgument("Axis for ConcatV2 must be a scalar, at ", + node_def.name()); } + int trt_axis = 0; + const auto dim = inputs.at(0).GetTrtDims(); + TF_RETURN_IF_ERROR( + ConvertAxis(axis[0], dim.nbDims, node_def.name(), &trt_axis)); + // Check that dimensions match on non-concatenate axis. 
+ TF_RETURN_IF_ERROR(VerifyShapesMatch( + absl::Span(inputs).first(num_inputs), trt_axis, + node_def.name())); if (params->validation_only) return Status::OK(); - // nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + // Gather inputs as tensors + std::vector input_tensors; + for (int i = 0; i < num_inputs; i++) { + input_tensors.push_back(inputs.at(i).tensor()); + } nvinfer1::IConcatenationLayer* layer = params->converter->network()->addConcatenation( - const_cast(inputs_vec.data()), - inputs_vec.size()); + const_cast(input_tensors.data()), + input_tensors.size()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setAxis(index - 1); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + layer->setAxis(trt_axis); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); } @@ -3579,18 +3951,18 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { " only supports is_training=false, at ", node_def.name()); } - nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); // Check parameter types - auto parameter_type = inputs.at(1).weights().type_; - if ((parameter_type != DataType::DT_FLOAT) && - (parameter_type != DataType::DT_HALF)) { + auto parameter_type = inputs.at(1).weights().TrtDType(); + if ((parameter_type != nvinfer1::DataType::kFLOAT) && + (parameter_type != nvinfer1::DataType::kHALF)) { return errors::Unimplemented( - "only float32 or float16 weight data type is supported, for node " + - node_def.name() + " got " + DataTypeString(parameter_type)); + "Only float32 or float16 weight data type is supported, for node ", + node_def.name(), " got ", DebugString(parameter_type)); } for (int i = 1; i < 5; i++) { - if (inputs.at(i).weights().type_ != parameter_type) { + if (inputs.at(i).weights().TrtDType() != parameter_type) { return errors::Unimplemented( "Inconsistent parameter type for batchnorm is not supported, at: " + node_def.name()); @@ -3602,11 +3974,10 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { for (int i = 1; i < 5; i++) { nweight = std::max(nweight, inputs.at(i).weights().count()); } - TRT_ShapedWeights* ptr_shape_weights = nullptr; + const TRT_ShapedWeights* ptr_shape_weights = nullptr; for (int i = 1; i < 5; i++) { if (inputs.at(i).weights().count() == nweight) { - ptr_shape_weights = - const_cast(&(inputs.at(i).weights())); + ptr_shape_weights = &(inputs.at(i).weights()); } else if (inputs.at(i).weights().count() != 1) { return errors::InvalidArgument( "Inconsistent batchnorm parameter count, at: " + node_def.name()); @@ -3629,29 +4000,29 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { vals_array[j] = static_cast(inputs.at(j + 1).weights().GetValues()); } - Eigen::half* cast_combined_scale_vals = const_cast( - static_cast(combined_scale_weights.GetValues())); - Eigen::half* cast_combined_offset_vals = const_cast( - static_cast(combined_offset_weights.GetValues())); - float* combined_scale_vals = const_cast( - static_cast(combined_scale_weights.GetValues())); - float* combined_offset_vals = const_cast( - static_cast(combined_offset_weights.GetValues())); + Eigen::half* cast_combined_scale_vals = + static_cast(combined_scale_weights.GetValues()); + Eigen::half* cast_combined_offset_vals = + static_cast(combined_offset_weights.GetValues()); + float* combined_scale_vals = + static_cast(combined_scale_weights.GetValues()); + float* combined_offset_vals = + 
static_cast(combined_offset_weights.GetValues()); for (size_t i = 0; i < nweight; ++i) { float batchnorm_data[4]; for (int j = 0; j < 4; j++) { if (inputs.at(j + 1).weights().count() != 1) { - if (parameter_type == DT_FLOAT) { + if (parameter_type == nvinfer1::DataType::kFLOAT) { batchnorm_data[j] = vals_array[j][i]; - } else if (parameter_type == DT_HALF) { + } else if (parameter_type == nvinfer1::DataType::kHALF) { batchnorm_data[j] = Eigen::half_impl::half_to_float(cast_vals_array[j][i]); } } else { - if (parameter_type == DT_FLOAT) { + if (parameter_type == nvinfer1::DataType::kFLOAT) { batchnorm_data[j] = vals_array[j][0]; - } else if (parameter_type == DT_HALF) { + } else if (parameter_type == nvinfer1::DataType::kHALF) { batchnorm_data[j] = Eigen::half_impl::half_to_float(cast_vals_array[j][0]); } @@ -3663,10 +4034,10 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { float variance = batchnorm_data[3]; float combined_scale_val = scale / sqrtf(variance + epsilon); float combined_offset_val = offset - mean * combined_scale_val; - if (parameter_type == DT_FLOAT) { + if (parameter_type == nvinfer1::DataType::kFLOAT) { combined_scale_vals[i] = combined_scale_val; combined_offset_vals[i] = combined_offset_val; - } else if (parameter_type == DT_HALF) { + } else if (parameter_type == nvinfer1::DataType::kHALF) { cast_combined_scale_vals[i] = Eigen::half(combined_scale_val); cast_combined_offset_vals[i] = Eigen::half(combined_offset_val); } @@ -3675,8 +4046,7 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { nvinfer1::ScaleMode mode = nweight == 1 ? nvinfer1::ScaleMode::kUNIFORM : nvinfer1::ScaleMode::kCHANNEL; nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( - *const_cast(tensor), mode, - combined_offset_weights.GetTrtWeights(), + *tensor, mode, combined_offset_weights.GetTrtWeights(), combined_scale_weights.GetTrtWeights(), dummy_power_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); @@ -3691,7 +4061,8 @@ Status ConvertGather(OpConverterParams* params) { TF_RETURN_IF_ERROR(CheckInputsWeights( *params, {{"params", false}, {"indices", false}, {"axis", true}})); TF_RETURN_IF_ERROR(AllowDataTypes( - *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}, + /*dtype_attr_name=*/"Tparams")); absl::Span axis = inputs.at(2).weights().GetSpan(); if (axis.size() != 1) { return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ", @@ -3700,58 +4071,184 @@ Status ConvertGather(OpConverterParams* params) { int trt_axis = 0; TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, node_def.name(), &trt_axis)); + const TRT_TensorOrWeights& params_tensor = inputs.at(0); + const TRT_TensorOrWeights& indices_tensor = inputs.at(1); + if (indices_tensor.batch_size() != 1) { + return errors::InvalidArgument("Only indices with batch 1 are supported."); + } + // Both input are tensors, and the TF gather result will have rank: + // (params.nbDims + 1) + (indices.nbDims + 1) - 1, + // where "+ 1" adds the batch dim. 
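
As a worked instance of the rank formula in the comment above: with params of TRT rank 3 (TF rank 4 including batch) and indices of TRT rank 1 (TF rank 2 including batch), the TF GatherV2 result has rank (3 + 1) + (1 + 1) - 1 = 5, which the code below then checks against nvinfer1::Dims::MAX_DIMS + 1. A minimal sketch of that arithmetic:

```cpp
// Illustrative sketch of the gather output-rank bookkeeping.
#include <cstdio>

int main() {
  const int params_trt_rank = 3;   // TF rank 4 once the batch dim is counted
  const int indices_trt_rank = 1;  // TF rank 2 once the batch dim is counted
  const int tf_gather_output_rank =
      (params_trt_rank + 1) + (indices_trt_rank + 1) - 1;  // 5
  std::printf("tf_gather_output_rank = %d\n", tf_gather_output_rank);
  return 0;
}
```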
+ const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims + + indices_tensor.GetTrtDims().nbDims + 1; + if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) { + return errors::InvalidArgument( + "Result of gather has dimension greater than ", + nvinfer1::Dims::MAX_DIMS + 1); + } if (params->validation_only) return Status::OK(); + // Note on how IGatherLayer works: if both the data and indices tensors have + // a batch size dimension of size N, it performs: + // for batchid in xrange(N): + // output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = ( + // data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn]) nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( - *const_cast(inputs.at(0).tensor()), - *const_cast(inputs.at(1).tensor()), trt_axis); + *params_tensor.tensor(), *indices_tensor.tensor(), trt_axis); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + + nvinfer1::ITensor* gather_output = layer->getOutput(0); + nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions(); + // Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT, + // and the other is for the output dimension that is squeezed by IGatherLayer + // because of the implicit batch dim in the indices (see the above note). + if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) { + return errors::Internal( + "Get unexpected output dimensions of IGatherLayer. Expect nbDims: ", + tf_gather_output_rank - 2, + ", actual nbDims: ", trt_gather_output_dims.nbDims); + } + // Reshape the output so after adding the implicit batch dim it'll match the + // output shape of TF GatherV2. + for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) { + trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1]; + } + trt_gather_output_dims.d[trt_axis] = 1; + ++trt_gather_output_dims.nbDims; + + nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(gather_output), trt_gather_output_dims, + /*validation_only=*/false, &output_tensor)); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } -Status ConvertMatMulHelper(OpConverterParams* params, - TRT_TensorOrWeights tensor_input, - TRT_ShapedWeights weights_raw, bool transpose_weight, - string node_name) { - nvinfer1::ITensor* output_tensor; - if (!tensor_input.is_tensor()) { - return errors::InvalidArgument("Input 0 expects tensor"); - } - const nvinfer1::ITensor* tensor = tensor_input.tensor(); - - TRT_ShapedWeights weights(weights_raw.type_); - if (transpose_weight) { - weights = weights_raw; - } else { - weights = params->weight_store->GetTempWeights(weights_raw); - ReorderCKtoKC(weights_raw, &weights); - } - TRT_ShapedWeights biases(weights.type_); - - int noutput = weights.shape_.d[0]; - - auto input_dim = tensor->getDimensions(); - while (input_dim.nbDims != 3) { +Status ConvertFullyConnectedHelper(OpConverterParams* params, + nvinfer1::ITensor* tensor_a, + TRT_ShapedWeights weights_b, + bool transpose_b, const string& node_name) { + // Reshape input to 3D - this will be a no-op unless using int8 precision. 
+ auto input_dim = tensor_a->getDimensions(); + while (input_dim.nbDims < 3) { input_dim.d[input_dim.nbDims++] = 1; } TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - tensor_input, input_dim, /*validation_only=*/false, &tensor)); + TRT_TensorOrWeights(tensor_a), input_dim, /*validation_only=*/false, + &tensor_a)); + // FC layer will transpose weights, so we need to pre-transpose. + TRT_ShapedWeights weights(weights_b.TrtDType()); + if (!transpose_b) { + weights = params->weight_store->GetTempWeights(weights_b); + ReorderCKtoKC(weights_b, &weights); + } else { + weights = weights_b; + } + TRT_ShapedWeights biases(weights.TrtDType()); + const int noutput = weights.shape_.d[0]; nvinfer1::IFullyConnectedLayer* layer = params->converter->network()->addFullyConnected( - *const_cast(tensor), noutput, - weights.GetTrtWeights(), biases.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); - output_tensor = layer->getOutput(0); + *tensor_a, noutput, weights.GetTrtWeights(), biases.GetTrtWeights()); - const nvinfer1::ITensor* temp_tensor = nullptr; + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // Reshape output to 1D - this will be a no-op unless using int8 precision. auto output_dim = output_tensor->getDimensions(); output_dim.nbDims = 1; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(output_tensor), output_dim, /*validation_only=*/false, - &temp_tensor)); - output_tensor = const_cast(temp_tensor); + &output_tensor)); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} + +Status ConvertMatMulHelper(OpConverterParams* params, + TRT_TensorOrWeights input_a, + TRT_TensorOrWeights input_b, bool transpose_a, + bool transpose_b, string node_name) { + // TODO: ReorderCKtoKC is currently not general enough to transpose weights + // that are not 2D. + if ((transpose_a && input_a.is_weights() && + input_a.GetTrtDims().nbDims != 2) || + (transpose_b && input_b.is_weights() && + input_b.GetTrtDims().nbDims != 2)) { + return errors::InvalidArgument( + "Cannot currently transpose constant input if it is not 2 dimensional"); + } + + // If A is a tensor, we can only transpose if it is at least 3D in TF, + // or TRT will not do the correct transposition. + if (transpose_a && input_a.is_tensor() && input_a.GetTrtDims().nbDims < 2) { + return errors::InvalidArgument( + "Cannot transpose first input if it is a tensor with fewer than 2 " + "non-batch dimensions."); + } + + // If B is a tensor, then it must be at least 3D in TF, + // or TRT won't be able to handle the multiply correctly. + if (input_b.is_tensor() && input_b.GetTrtDims().nbDims < 2) { + return errors::InvalidArgument( + "Second input must either be a constant, or contain at least 2 " + "non-batch dimensions."); + } + if (params->validation_only) return Status::OK(); + + // If an FC layer can be used and would be faster, use that instead. + const bool should_use_fc = + !transpose_a && input_a.is_tensor() && input_b.is_weights() && + input_a.GetTrtDims().nbDims >= 3 && input_b.GetTrtDims().nbDims == 2; + // If int8 is specified, FC must be used, as MM does not support int8 at this + // time. 
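
ConvertFullyConnectedHelper above pre-transposes the constant operand with ReorderCKtoKC when transpose_b is false, since the fully connected layer expects its kernel laid out as [K, C] while the untransposed TF MatMul weight is [C, K]. A standalone sketch of that reorder on a plain row-major buffer (the function name below is illustrative only):

```cpp
// Illustrative sketch: transpose a row-major C x K matrix into K x C.
#include <cstdio>
#include <vector>

std::vector<float> TransposeCKtoKC(const std::vector<float>& ck, int c, int k) {
  std::vector<float> kc(ck.size());
  for (int i = 0; i < c; ++i) {
    for (int j = 0; j < k; ++j) {
      kc[j * c + i] = ck[i * k + j];
    }
  }
  return kc;
}

int main() {
  // C = 2, K = 3:  [[1, 2, 3], [4, 5, 6]]  ->  [[1, 4], [2, 5], [3, 6]]
  const std::vector<float> ck = {1, 2, 3, 4, 5, 6};
  const std::vector<float> kc = TransposeCKtoKC(ck, /*c=*/2, /*k=*/3);
  for (float v : kc) std::printf("%g ", v);
  std::printf("\n");  // 1 4 2 5 3 6
  return 0;
}
```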
+ if (should_use_fc || + params->converter->precision_mode() == TrtPrecisionMode::INT8) { + return ConvertFullyConnectedHelper( + params, input_a.tensor(), input_b.weights(), transpose_b, node_name); + } + + constexpr auto get_matrix_op = + [](nvinfer1::ITensor* in, bool transpose) -> nvinfer1::MatrixOperation { + return (in->getDimensions().nbDims < 2) + ? nvinfer1::MatrixOperation::kVECTOR + : (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE + : nvinfer1::MatrixOperation::kNONE; + }; + + // If the MatMul operand is a constant, applies transposes at conversion-time + // as necessary. If the operand is a tensor, does nothing. If required + // transposes were applied, sets transpose to false. + const auto prepare_matmul_operand = + [¶ms](TRT_TensorOrWeights operand, + bool* transpose) -> nvinfer1::ITensor* { + if (operand.is_tensor()) { + return operand.tensor(); + } else { + TRT_ShapedWeights weights(operand.weights().TrtDType()); + if (*transpose) { + weights = params->weight_store->GetTempWeights(operand.weights()); + ReorderCKtoKC(operand.weights(), &weights); + // Weights have been transposed, can set transpose to false + *transpose = false; + } else { + weights = operand.weights(); + } + return params->converter->CreateConstantLayer(weights, weights.shape_); + } + }; + + nvinfer1::ITensor* tensor_a = prepare_matmul_operand(input_a, &transpose_a); + nvinfer1::ITensor* tensor_b = prepare_matmul_operand(input_b, &transpose_b); + + nvinfer1::IMatrixMultiplyLayer* layer = + params->converter->network()->addMatrixMultiply( + *tensor_a, get_matrix_op(tensor_a, transpose_a), *tensor_b, + get_matrix_op(tensor_b, transpose_b)); + + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -3760,7 +4257,11 @@ Status ConvertMatMulHelper(OpConverterParams* params, Status ConvertMatMul(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"a", false}, {"b", true}})); + if (inputs.size() != 2) { + return errors::InvalidArgument(node_def.op(), " got ", inputs.size(), + " inputs but expected 2, at ", + node_def.name()); + } TF_RETURN_IF_ERROR( AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); @@ -3768,84 +4269,68 @@ Status ConvertMatMul(OpConverterParams* params) { bool transpose_a = attrs.get("transpose_a"); bool transpose_b = attrs.get("transpose_b"); - // FullyConnected: - if (transpose_a) { - return errors::InvalidArgument( - "transpose_a is not supported for TensorRT FullyConnected (op: ", - node_def.op(), "), at: ", node_def.name()); - } - if (params->validation_only) return Status::OK(); - return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(), + return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1), transpose_a, transpose_b, node_def.name()); } Status ConvertBatchMatMul(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - // TODO(tmorris): Enable once false is updated to mean either tensor or weight - // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", - // false}})); - TF_RETURN_IF_ERROR( - AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); if (inputs.size() != 2) { return errors::InvalidArgument(node_def.op(), " got ", inputs.size(), " inputs but expected 2, at ", node_def.name()); } + // TODO(tmorris): Enable once 
false is updated to mean either tensor or weight + // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", + // false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); if (inputs[0].is_weights() && inputs[1].is_weights()) { return errors::InvalidArgument( "All inputs are weights, but Grappler is expected to fold them."); } + TFAttrs attrs(node_def); const bool transpose_a = attrs.get("adj_x"); const bool transpose_b = attrs.get("adj_y"); - const auto dims = inputs.at(0).GetTrtDims(); - if (dims.nbDims == 1) { // NC * CK is only supported through fully connected - if (transpose_a == false && inputs.at(0).is_tensor() && - inputs.at(1).is_weights()) { - return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1).weights(), - transpose_b, node_def.name()); - } else { - return errors::InvalidArgument("Invalid configuration for MatMul, at: ", - node_def.name()); - } + + // Removes the batch dimension from weights. + const auto remove_weights_batch_dim = + [¶ms](const TRT_TensorOrWeights& input, TRT_TensorOrWeights* tensor) { + auto dims = input.GetTrtDims(); + if (input.is_weights()) { + // The other operand must be a tensor, this is ensured by earlier + // checks. Checks that the batch dimension is not changed by + // broadcasting. + if (dims.d[0] != 1) { + return errors::InvalidArgument( + "Input weight attempts to broadcast across batch dimension for " + "BatchMatMul, at ", + params->node_def.name()); + } + // Remove the batch dimension from the weights. + TF_RETURN_IF_ERROR(RemoveBatchDimension(&dims)); + } + // Create tensor and reshape if necessary. + nvinfer1::ITensor* t{nullptr}; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input, dims, params->validation_only, &t)); + *tensor = TRT_TensorOrWeights{t}; + return Status::OK(); + }; + + TRT_TensorOrWeights tensor_l{nullptr}; + TRT_TensorOrWeights tensor_r{nullptr}; + TF_RETURN_IF_ERROR(remove_weights_batch_dim(inputs.at(0), &tensor_l)); + TF_RETURN_IF_ERROR(remove_weights_batch_dim(inputs.at(1), &tensor_r)); + + if (params->validation_only) { + return Status::OK(); } - auto get_tensor_with_proper_dims = [params]( - const TRT_TensorOrWeights& input, - const nvinfer1::ITensor** tensor) { - auto dims = input.GetTrtDims(); - if (input.is_weights()) { - // The other operand must be a tensor, this is ensured by earlier checks. - // Checks that the batch dimension is not changed by broadcasting. - if (dims.d[0] != 1) { - return errors::InvalidArgument( - "Input weight attempts to broadcast across batch dimension for " - "BatchMatMul, at ", - params->node_def.name()); - } - // Remove the batch dimension from the weights. - TF_RETURN_IF_ERROR(RemoveBatchDimension(&dims)); - } - // Create tensor and reshape if necessary. 
- TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, dims, params->validation_only, tensor)); - return Status::OK(); - }; - const nvinfer1::ITensor* tensor_l; - const nvinfer1::ITensor* tensor_r; - TF_RETURN_IF_ERROR(get_tensor_with_proper_dims(inputs.at(0), &tensor_l)); - TF_RETURN_IF_ERROR(get_tensor_with_proper_dims(inputs.at(1), &tensor_r)); - if (params->validation_only) return Status::OK(); - - nvinfer1::IMatrixMultiplyLayer* layer = - params->converter->network()->addMatrixMultiply( - *const_cast(tensor_l), transpose_a, - *const_cast(tensor_r), transpose_b); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return Status::OK(); + return ConvertMatMulHelper(params, tensor_l, tensor_r, transpose_a, + transpose_b, node_def.name()); } Status ConvertSoftmax(OpConverterParams* params) { @@ -3854,7 +4339,7 @@ Status ConvertSoftmax(OpConverterParams* params) { TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}})); TF_RETURN_IF_ERROR( AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int nbDims = tensor->getDimensions().nbDims; if (nbDims == 0) { @@ -3864,8 +4349,8 @@ Status ConvertSoftmax(OpConverterParams* params) { } if (params->validation_only) return Status::OK(); - nvinfer1::ISoftMaxLayer* layer = params->converter->network()->addSoftMax( - *const_cast(tensor)); + nvinfer1::ISoftMaxLayer* layer = + params->converter->network()->addSoftMax(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); // Tensorflow SoftMax assumes applying softmax on the last dimension. layer->setAxes(1 << (nbDims - 1)); @@ -3877,14 +4362,64 @@ Status ConvertSoftmax(OpConverterParams* params) { return Status::OK(); } +Status ConvertArgMinMax(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"dimension", true}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + // INT64 outputs are not supported by TRT. + TFAttrs attrs(node_def); + DataType output_dtype = attrs.get("output_type"); + if (output_dtype != DataType::DT_INT32) { + return errors::Unimplemented("Output type ", DataTypeString(output_dtype), + " is not supported, at ", node_def.name()); + } + int tf_axis = inputs.at(1).weights().GetSpan()[0]; + int trt_axis; + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + TF_RETURN_IF_ERROR( + ConvertAxis(tf_axis, dims.nbDims, node_def.name(), &trt_axis)); + nvinfer1::TopKOperation topk_op; + if (node_def.op() == "ArgMin") { + topk_op = nvinfer1::TopKOperation::kMIN; + } else if (node_def.op() == "ArgMax") { + topk_op = nvinfer1::TopKOperation::kMAX; + } else { + return errors::InvalidArgument("Unsupported ArgMin/Max operation"); + } + if (params->validation_only) return Status::OK(); + + // Use TopK with k = 1. Only indices output is needed (output 1). + const uint32_t reduce_axes = 1 << trt_axis; + nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK( + *inputs.at(0).tensor(), topk_op, 1, reduce_axes); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); + + // Squeeze on axis. 
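
ConvertArgMinMax above reuses the TopK layer with k = 1 and keeps only its index output (output 1), then squeezes away the size-1 axis that TopK leaves behind. A self-contained sketch of the same idea on a plain array:

```cpp
// Illustrative sketch: ArgMax as "top-1, keep the index"; ties resolve to the
// first (lowest-index) maximum, matching TF's ArgMax behavior.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> values = {0.3f, 2.5f, -1.0f, 2.5f, 0.9f};
  const auto it = std::max_element(values.begin(), values.end());
  const int argmax = static_cast<int>(it - values.begin());
  std::printf("argmax = %d\n", argmax);  // 1
  return 0;
}
```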
+ std::vector size(dims.d, dims.d + dims.nbDims); + size.erase(size.begin() + trt_axis); + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &new_dims)); + nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(output_indices_tensor), new_dims, + /*validation_only=*/false, &output_tensor)); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} + Status ConvertTopK(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; TF_RETURN_IF_ERROR( CheckInputsWeights(*params, {{"input", false}, {"k", true}})); - TF_RETURN_IF_ERROR(AllowDataTypes( - *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); const int num_dims = tensor->getDimensions().nbDims; if (num_dims == 0) { return errors::InvalidArgument( @@ -3901,10 +4436,10 @@ Status ConvertTopK(OpConverterParams* params) { if (params->validation_only) return Status::OK(); const nvinfer1::TopKOperation op = nvinfer1::TopKOperation::kMAX; - const int k = *(static_cast(const_cast(k_w.GetValues()))); + const int k = *(static_cast(k_w.GetValues())); const uint32_t reduce_axes = 1 << (num_dims - 1); - nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK( - *const_cast(tensor), op, k, reduce_axes); + nvinfer1::ITopKLayer* layer = + params->converter->network()->addTopK(*tensor, op, k, reduce_axes); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); @@ -3914,40 +4449,415 @@ Status ConvertTopK(OpConverterParams* params) { return Status::OK(); } +Status ConvertDepthSpaceShuffle(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); + TFAttrs attrs(node_def); + const int block_size = attrs.get("block_size"); + if (block_size < 2) { + return errors::InvalidArgument("Block size must be 2 or greater, at ", + node_def.name()); + } + const string data_format = attrs.get("data_format"); + if (data_format != "NCHW" && data_format != "NHWC") { + return errors::Unimplemented("Data format ", data_format, + " is not supported, at ", node_def.name()); + } + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + if (dims.nbDims != 3) { + return errors::InvalidArgument("The input to ", node_def.op(), + " must be rank 4, at ", node_def.name()); + } + const int num_channels = data_format == "NCHW" ? dims.d[0] : dims.d[2]; + const int h = data_format == "NCHW" ? dims.d[1] : dims.d[0]; + const int w = data_format == "NCHW" ? dims.d[2] : dims.d[1]; + // Get shuffle parameters. 
+ nvinfer1::Dims first_shuffle_shape; + nvinfer1::Permutation transpose_perm; + nvinfer1::Dims second_shuffle_shape; + if (node_def.op() == "DepthToSpace") { + if (num_channels % (block_size * block_size) != 0) { + return errors::InvalidArgument( + "Number of channels must be divisible by block_size*block_size, at ", + node_def.name()); + } + // First Reshape [C, H, W] - > [r, r, C/(r*r), H, W] + first_shuffle_shape = { + /*nbDims=*/5, + /*d=*/{block_size, block_size, num_channels / (block_size * block_size), + h, w}}; + // Transpose [r, r, C/(r*r), H, W] -> [C/(r*r), H, r, W, r] + transpose_perm = {2, 3, 0, 4, 1}; + // Second Reshape [C/(r*r), H, r, W, r] -> [C/(r*r), H * r, W * r] + second_shuffle_shape = + nvinfer1::DimsCHW(num_channels / (block_size * block_size), + h * block_size, w * block_size); + } else if (node_def.op() == "SpaceToDepth") { + if (h % block_size != 0 || w % block_size != 0) { + return errors::InvalidArgument( + "Width and height must be divisible by block_size, at ", + node_def.name()); + } + // First Reshape [C, H, W] -> [C, H/r, r, W/r, r] + first_shuffle_shape = {/*nbDims=*/5, + /*d=*/{num_channels, h / block_size, block_size, + w / block_size, block_size}}; + // Transpose [C, H/r, r, W/r, r] -> [r, r, C, H/r, W/r] + transpose_perm = {2, 4, 0, 1, 3}; + // Second Reshape [r, r, C, H/r, W/r] -> [C*r*r, H/r, W/r] + second_shuffle_shape = nvinfer1::DimsCHW( + num_channels * block_size * block_size, h / block_size, w / block_size); + } + if (params->validation_only) return Status::OK(); + + nvinfer1::IShuffleLayer* first_shuffle = + params->converter->network()->addShuffle(*inputs.at(0).tensor()); + TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name()); + if (data_format == "NHWC") { + first_shuffle->setFirstTranspose({2, 0, 1}); + } + first_shuffle->setReshapeDimensions(first_shuffle_shape); + first_shuffle->setSecondTranspose(transpose_perm); + + nvinfer1::IShuffleLayer* second_shuffle = + params->converter->network()->addShuffle(*first_shuffle->getOutput(0)); + TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name()); + second_shuffle->setReshapeDimensions(second_shuffle_shape); + if (data_format == "NHWC") { + second_shuffle->setSecondTranspose({1, 2, 0}); + } + + params->converter->MarkQuantizationRangesAsInferrable( + inputs.at(0).tensor(), first_shuffle->getOutput(0)); + params->converter->MarkQuantizationRangesAsInferrable( + first_shuffle->getOutput(0), second_shuffle->getOutput(0)); + params->outputs->push_back(TRT_TensorOrWeights(second_shuffle->getOutput(0))); + return Status::OK(); +} + +Status ConvertSquaredDifference(OpConverterParams* params) { + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", false}})); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + // Broadcast inputs. + nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; + TF_RETURN_IF_ERROR(params->converter->GetTrtBroadcastShape( + inputs.at(0), inputs.at(1), &broadcasted_dims_l, &broadcasted_dims_r)); + nvinfer1::ITensor* tensor_l = nullptr; + nvinfer1::ITensor* tensor_r = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l)); + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r)); + if (params->validation_only) return Status::OK(); + + // Subtract x - y. 
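+  // TRT has no native SquaredDifference, so it is composed from two
+  // elementwise layers: z = x - y, then z * z. For example, x = 3 and y = 1
+  // give (3 - 1)^2 = 4.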
+ nvinfer1::IElementWiseLayer* sub = + params->converter->network()->addElementWise( + *tensor_l, *tensor_r, nvinfer1::ElementWiseOperation::kSUB); + TFTRT_RETURN_ERROR_IF_NULLPTR(sub, node_def.name()); + // Multiply (x - y) * (x - y). + nvinfer1::IElementWiseLayer* mul = + params->converter->network()->addElementWise( + *sub->getOutput(0), *sub->getOutput(0), + nvinfer1::ElementWiseOperation::kPROD); + TFTRT_RETURN_ERROR_IF_NULLPTR(mul, node_def.name()); + + params->outputs->push_back(TRT_TensorOrWeights(mul->getOutput(0))); + return Status::OK(); +} + +#if IS_TRT_VERSION_GE(5, 1, 0, 0) +Status ConvertCombinedNMS(OpConverterParams* params) { + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"boxes", false}, + {"scores", false}, + {"max_output_size_per_class", true}, + {"max_total_size", true}, + {"iou_threshold", true}, + {"score_threshold", true}})); + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + + nvinfer1::ITensor* boxes_tensor = inputs.at(0).tensor(); + nvinfer1::ITensor* scores_tensor = inputs.at(1).tensor(); + TRT_ShapedWeights output_size_per_class = inputs.at(2).weights(); + TRT_ShapedWeights total_size = inputs.at(3).weights(); + TRT_ShapedWeights iou_threshold = inputs.at(4).weights(); + TRT_ShapedWeights score_threshold = inputs.at(5).weights(); + + // Validate tensors and weights (also set some of the needed plugin fields) + const auto boxes_dims = boxes_tensor->getDimensions(); + const auto scores_dims = scores_tensor->getDimensions(); + if (boxes_dims.nbDims != 3) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin input boxes must be 3-D excluding batch ", + node_def.name()); + } + const int num_classes = scores_dims.d[1]; + bool box_check = boxes_dims.d[1] == 1 || boxes_dims.d[1] == num_classes; + if (!box_check) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin third dimension of boxes must be either 1 " + "or num_classes ", + node_def.name()); + } + if (output_size_per_class.shape_.nbDims != 1) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin max_output_size_per_class must be 0-D ", + node_def.name()); + } + int max_size_per_class = + *(static_cast(output_size_per_class.GetValues())); + if (max_size_per_class <= 0) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin max_output_size_per_class should be > 0", + node_def.name()); + } + if (total_size.shape_.nbDims != 1) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin max_total_size must be 0-D ", + node_def.name()); + } + int max_total_size = *(static_cast(total_size.GetValues())); + if (max_total_size <= 0) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin max_total_size should be > 0", + node_def.name()); + } + if (iou_threshold.shape_.nbDims != 1) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin iou_threshold must be 0-D ", + node_def.name()); + } + float iou_thresh = *(static_cast(iou_threshold.GetValues())); + if (iou_thresh < 0.0 || iou_thresh > 1.0) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin iou_threshold must be in [0, 1]", + node_def.name()); + } + if (score_threshold.shape_.nbDims != 1) { + return errors::InvalidArgument( + "TensorRT BatchedNMS Plugin score_threshold must be 0-D ", + node_def.name()); + } + + if (params->validation_only) return Status::OK(); + + // TF op CombinedNonMaxSuppression doesn't have the option of + // not normalizing coordinates. 
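+  // The plugin fields below mirror the TF attributes and weight inputs. As a
+  // rough example (example values only): with max_size_per_class = 10,
+  // num_classes = 3 and max_total_size = 20, pad_per_class = true yields
+  // topK = keepTopK = min(10 * 3, 20) = 20, while pad_per_class = false
+  // simply uses max_total_size = 20.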
+ const bool is_normalized = true; + // Set plugin fields and the field collection + TFAttrs attrs(node_def); + bool share_location = (boxes_dims.d[1] == 1); + const bool pad_per_class = attrs.get("pad_per_class"); + int top_k; + if (pad_per_class) { + top_k = std::min(max_size_per_class * num_classes, max_total_size); + } else { + top_k = max_total_size; + } + const int keep_top_k = top_k; + float score_thresh = *(static_cast(score_threshold.GetValues())); + const int background_id = -1; + nvinfer1::PluginField fields[8] = { + nvinfer1::PluginField{"shareLocation", &share_location, + nvinfer1::PluginFieldType::kINT32, 1}, + nvinfer1::PluginField{"backgroundLabelId", &background_id, + nvinfer1::PluginFieldType::kINT32, 1}, + nvinfer1::PluginField{"numClasses", &num_classes, + nvinfer1::PluginFieldType::kINT32, 1}, + nvinfer1::PluginField{"topK", &top_k, nvinfer1::PluginFieldType::kINT32, + 1}, + nvinfer1::PluginField{"keepTopK", &keep_top_k, + nvinfer1::PluginFieldType::kINT32, 1}, + nvinfer1::PluginField{"scoreThreshold", &score_thresh, + nvinfer1::PluginFieldType::kFLOAT32, 1}, + nvinfer1::PluginField{"iouThreshold", &iou_thresh, + nvinfer1::PluginFieldType::kFLOAT32, 1}, + nvinfer1::PluginField{"isNormalized", &is_normalized, + nvinfer1::PluginFieldType::kINT32, 1}, + }; + nvinfer1::PluginFieldCollection fc{8, fields}; + + // Get plugin creator + auto creator = + getPluginRegistry()->getPluginCreator("BatchedNMS_TRT", "1", ""); + TFTRT_RETURN_ERROR_IF_NULLPTR(creator, node_def.name()); + + // Create plugin + TrtUniquePtrType plugin( + creator->createPlugin(node_def.name().c_str(), &fc)); + TFTRT_RETURN_ERROR_IF_NULLPTR(plugin, node_def.name()); + + // Set plugin inputs + std::vector plugin_inputs; + plugin_inputs.push_back(boxes_tensor); + plugin_inputs.push_back(scores_tensor); + + // Add plugin to network + nvinfer1::IPluginV2Layer* layer = params->converter->network()->addPluginV2( + &plugin_inputs[0], static_cast(plugin_inputs.size()), *plugin); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + auto shrink_last_dim = [params](nvinfer1::ITensor* in_tensor, + nvinfer1::ITensor** out_tensor) { + nvinfer1::Dims dims = in_tensor->getDimensions(); + if (dims.d[dims.nbDims - 1] != 1) { + return errors::Internal("Expect last dims to be 1, for tensor ", + DebugString(*in_tensor)); + } + --dims.nbDims; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(in_tensor), dims, + /*validation_only=*/false, out_tensor)); + return Status::OK(); + }; + + // Set plugin outputs + nvinfer1::ITensor* output_nmsed_boxes = layer->getOutput(1); + nvinfer1::ITensor* output_nmsed_scores = nullptr; + nvinfer1::ITensor* output_nmsed_classes = nullptr; + nvinfer1::ITensor* output_num_detections = nullptr; + TF_RETURN_IF_ERROR( + shrink_last_dim(layer->getOutput(2), &output_nmsed_scores)); + TF_RETURN_IF_ERROR( + shrink_last_dim(layer->getOutput(3), &output_nmsed_classes)); + TF_RETURN_IF_ERROR( + shrink_last_dim(layer->getOutput(0), &output_num_detections)); + + params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_boxes)); + params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_scores)); + params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_classes)); + params->outputs->push_back(TRT_TensorOrWeights(output_num_detections)); + + return Status::OK(); +} +#endif // CombinedNonMaxSuppression + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +Status ConvertResize(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = 
params->node_def; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"size", true}})); + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); + + // Get input tensor. Transpose it from NHWC to NCHW. + nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, params->node_def.name()); + + // Get output size. It must constain two values i.e. [H_out, W_out] + TRT_ShapedWeights weights = inputs.at(1).weights(); + if (weights.count() != 2) { + return errors::Unimplemented("Resize to shape=[] is not supported, at ", + node_def.name()); + } + const int* weights_ptr = static_cast(weights.GetValues()); + + // Verify and consume node attributes. + TFAttrs attrs(node_def); + bool align_corners = attrs.get("align_corners"); + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + + // Verify resize mode. Initialize resize mode if supported. + nvinfer1::ResizeMode resize_mode; + if (node_def.op() == "ResizeBilinear") { + resize_mode = nvinfer1::ResizeMode::kLINEAR; + } else if (node_def.op() == "ResizeNearestNeighbor") { + resize_mode = nvinfer1::ResizeMode::kNEAREST; + } else { + return errors::Unimplemented(node_def.op(), " is not yet implemented at ", + node_def.name()); + } + + // return after validation if only validation is requested. + if (params->validation_only) return Status::OK(); + + // Tranpose tensor from NHWC to NCHW format. + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, {0, 3, 1, 2}, &tensor)); + + // Calculate output dimensions. + // Given input dimensions [N, C, H, W] and output size [H_out, W_out], + // output dimensions equals [N, C, H_out, W_out] + nvinfer1::Dims output_dimensions; + output_dimensions.nbDims = tensor->getDimensions().nbDims; + for (int i = 0; i < output_dimensions.nbDims; ++i) { + output_dimensions.d[i] = tensor->getDimensions().d[i]; + } + output_dimensions.d[output_dimensions.nbDims - 2] = weights_ptr[0]; + output_dimensions.d[output_dimensions.nbDims - 1] = weights_ptr[1]; + + // Add resize layer. + nvinfer1::IResizeLayer* layer = + params->converter->network()->addResize(*tensor); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + // Set layer parameters. + layer->setResizeMode(resize_mode); + layer->setOutputDimensions(output_dimensions); + layer->setAlignCorners(align_corners); + + // Get output tensor. Transpose it from NCHW to NHWC. 
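+  // Rough shape walk-through (example values only): NHWC non-batch dims
+  // [8, 8, 3] with size = [16, 16] are transposed to [3, 8, 8], resized by
+  // the layer above to [3, 16, 16], and transposed back below to the NHWC
+  // result [16, 16, 3].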
+ nvinfer1::ITensor* output = layer->getOutput(0); + + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(output, {0, 2, 3, 1}, &output)); + params->outputs->push_back(TRT_TensorOrWeights(output)); + // Success + return Status::OK(); +} // ConvertResize +#endif // IS_TRT_VERSION_GE(6, 0, 0, 0) + static void RegisterValidatableOpConverters( std::unordered_map* registration) { + (*registration)["BatchMatMul"] = ConvertBatchMatMul; (*registration)["BiasAdd"] = ConvertBiasAdd; +#if IS_TRT_VERSION_GE(5, 1, 2, 0) + (*registration)["ClipByValue"] = ConvertClipByValue; +#endif +#if IS_TRT_VERSION_GE(5, 1, 0, 0) + (*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS; +#endif (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; (*registration)["Conv2D"] = ConvertConv2D; (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput; + (*registration)["DepthToSpace"] = ConvertDepthSpaceShuffle; (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["GatherV2"] = ConvertGather; (*registration)["LeakyRelu"] = ConvertLeakyRelu; (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Pack"] = ConvertPack; (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; (*registration)["Reshape"] = ConvertReshape; +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + for (auto resize_mode : {"ResizeBilinear", "ResizeNearestNeighbor"}) { + (*registration)[resize_mode] = ConvertResize; + } +#endif (*registration)["Rsqrt"] = ConvertRsqrt; (*registration)["Slice"] = ConvertSlice; + (*registration)["Softmax"] = ConvertSoftmax; + (*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle; + (*registration)["Split"] = ConvertSplit; (*registration)["Square"] = ConvertSquare; + (*registration)["SquaredDifference"] = ConvertSquaredDifference; (*registration)["Squeeze"] = ConvertSqueeze; (*registration)["StridedSlice"] = ConvertStridedSlice; - (*registration)["Transpose"] = ConvertTranspose; (*registration)["TopKV2"] = ConvertTopK; - - // TODO(ben,jie): this is a temp hack. 
- (*registration)["Identity"] = ConvertIdentity; // Identity should be removed - (*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed - - (*registration)["Sum"] = ConvertReduce; - (*registration)["Prod"] = ConvertReduce; - (*registration)["Max"] = ConvertReduce; - (*registration)["Min"] = ConvertReduce; - (*registration)["Mean"] = ConvertReduce; - (*registration)["Softmax"] = ConvertSoftmax; - (*registration)["BatchMatMul"] = ConvertBatchMatMul; + (*registration)["Transpose"] = ConvertTranspose; + (*registration)["Unpack"] = ConvertUnpack; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", @@ -3955,11 +4865,11 @@ static void RegisterValidatableOpConverters( (*registration)[quantization_op_type] = ConvertQuantize; } for (auto binary_op_type : - {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum"}) { + {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum", "Pow"}) { (*registration)[binary_op_type] = ConvertBinary; } - for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { - (*registration)[activation_op_type] = ConvertActivation; + for (auto activation_op_pair : *ActivationTypeMap()) { + (*registration)[activation_op_pair.first] = ConvertActivation; } for (auto pool_op_type : {"AvgPool", "MaxPool"}) { (*registration)[pool_op_type] = ConvertPool; @@ -3970,6 +4880,17 @@ static void RegisterValidatableOpConverters( for (auto unary_op_pair : *UnaryOperationMap()) { (*registration)[unary_op_pair.first] = ConvertUnary; } + for (auto reduce_op_type : {"Sum", "Prod", "Max", "Min", "Mean"}) { + (*registration)[reduce_op_type] = ConvertReduce; + } + for (auto arg_minmax_type : {"ArgMin", "ArgMax"}) { + (*registration)[arg_minmax_type] = ConvertArgMinMax; + } + // The following are no-ops during inference and will not be mapped to any TRT + // layer. + for (auto identity_op_type : {"Identity", "Snapshot", "StopGradient"}) { + (*registration)[identity_op_type] = ConvertIdentity; + } } void TrtNodeValidator::RegisterOpValidators() { @@ -4066,7 +4987,7 @@ Status ConvertGraphDefToEngine( TFAttrs attrs(node_def); DataType tf_dtype = attrs.get("T"); nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); + TF_RETURN_IF_ERROR(TfDataTypeToTrt(tf_dtype, &trt_dtype)); if (output_tensors.size() <= slot_number) { output_tensors.resize(slot_number + 1); } @@ -4098,7 +5019,7 @@ Status ConvertSegmentToGraphDef( const Graph* graph, const grappler::GraphProperties& graph_properties, const std::vector& subgraph_nodes, // In topological order std::vector* connections, GraphDef* segment_def, - string* common_scope) { + string* scope_name) { std::set marker_nodes; // Update connection shapes/data types and add corresponding input/output // nodes in the segment graphdef. @@ -4231,9 +5152,7 @@ Status ConvertSegmentToGraphDef( snode->mutable_input()->RemoveLast(); } } - *common_scope = local_scope; - VLOG(1) << "Converted TensorRT candidate segment @scope '" << local_scope - << "' to a GraphDef"; + *scope_name = local_scope; return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 6333d9130a8..763b28b7402 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -34,7 +34,7 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { @@ -43,6 +43,14 @@ extern const char* const kOutputPHName; namespace convert { +#define IS_TRT_VERSION_GE(major, minor, patch, build) \ + ((NV_TENSORRT_MAJOR > major) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ + NV_TENSORRT_PATCH > patch) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ + NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build)) + struct EngineConnection { // Constructs a non-control edge. EngineConnection(const string& outside, int out_id, int out_port, @@ -123,13 +131,14 @@ struct EngineInfo { // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be // sorted in topological order. +// - scope_name: the name of the scope where the TRTEngineOp will be placed. // // TODO(aaroey): add tests to validate these properties. Status ConvertSegmentToGraphDef( const Graph* graph, const grappler::GraphProperties& graph_properties, const std::vector& subgraph_nodes, std::vector* connections, GraphDef* segment_def, - string* common_scope); + string* scope_name); // Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff // 'builder' successfully build the engine. If the result is not ok, 'engine' @@ -161,12 +170,14 @@ string DebugString(const nvinfer1::DataType trt_dtype); string DebugString(const nvinfer1::Dims& dims); string DebugString(const nvinfer1::Permutation& permutation, int len); string DebugString(const nvinfer1::ITensor& tensor); -int64_t TrtDimsNumElements(const nvinfer1::Dims& dims); +int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims); +int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims); // Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight. class TRT_ShapedWeights { public: - explicit TRT_ShapedWeights(DataType type = DT_FLOAT); + explicit TRT_ShapedWeights( + nvinfer1::DataType type = nvinfer1::DataType::kFLOAT); // Copy from another weights. // @@ -176,6 +187,8 @@ class TRT_ShapedWeights { nvinfer1::Weights GetTrtWeights() const; + const Tensor& GetTensor() const { return tensor_; } + // Returns the raw pointer to the underlying buffer which holds the weights // value. void* GetValues() const { @@ -199,14 +212,18 @@ class TRT_ShapedWeights { return std::vector(span.data(), span.data() + span.size()); } + nvinfer1::DataType TrtDType() const { return type_; } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; // Note: shape.type[] is not used. - DataType type_; private: // This constructor is only used by TrtWeightStore, which creates the // underlying buffer. - TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor); + TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims, + Tensor tensor); + + nvinfer1::DataType type_; // All weights should be stored inside TrtWeightStore to make sure lifetime of // all the underlying tensors are available until the engine is built. For @@ -227,12 +244,13 @@ class TRT_ShapedWeights { class TrtWeightStore { public: // Get a TRT_ShapedWeights with 'type' and 'dims'. - TRT_ShapedWeights GetTempWeights(DataType type, const nvinfer1::Dims& dims); + TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type, + const nvinfer1::Dims& dims); // Get a TRT_ShapedWeights with the same data type and dimensions as // 'weights'. 
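+  // Illustrative use (hypothetical variable names): for an existing
+  // TRT_ShapedWeights `w`, `store.GetTempWeights(w)` forwards to the
+  // (TrtDType, shape) overload above, so the scratch buffer matches `w`'s
+  // dtype and dims and stays alive until the engine is built.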
TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) { - return GetTempWeights(weights.type_, weights.shape_); + return GetTempWeights(weights.TrtDType(), weights.shape_); } private: @@ -272,9 +290,7 @@ class TRT_TensorOrWeights { bool is_tensor() const { return initialized_ && is_tensor_; } bool is_weights() const { return initialized_ && !is_tensor_; } - nvinfer1::ITensor* tensor(); - - const nvinfer1::ITensor* tensor() const; + nvinfer1::ITensor* tensor() const; TRT_ShapedWeights& weights() { CHECK(is_weights()); @@ -483,9 +499,10 @@ class Converter { // dimension which should always be 0. Status TransposeTensor(nvinfer1::ITensor* input_tensor, const std::vector& order_with_batch_dim, - const nvinfer1::ITensor** output_tensor); + nvinfer1::ITensor** output_tensor); - // Converts 'input' into 'tensor' with shape specified by 'dims'. + // Converts 'input' into 'tensor' with shape specified by 'dims' (which + // doesn't contain the batch dimension). // // If validation_only is true, it doesn't do the conversion but only do some // minimum validation for the eligibility of the conversion, and *tensor will @@ -493,7 +510,7 @@ class Converter { Status PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, const bool validation_only, - const nvinfer1::ITensor** tensor); + nvinfer1::ITensor** tensor); // Return OK if the broadcast scheme is supported and compute the shapes after // broadcasting. @@ -577,6 +594,8 @@ class Converter { // Map of all supported UnaryOperations const std::unordered_map* UnaryOperationMap(); +// Map of all supported ActivationTypes +const std::unordered_map* ActivationTypeMap(); } // namespace convert } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index e89b31759f2..2106ca95e8c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" @@ -49,9 +50,9 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "cuda/include/cuda.h" -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { @@ -109,13 +110,17 @@ DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { } NodeDef MakeNodeDef(const string& name, const string& op, - const std::vector& inputs) { + const std::vector& inputs, + const std::map attrs = {}) { NodeDef node_def; node_def.set_name(name); node_def.set_op(op); for (const string& input : inputs) { node_def.add_input(input); } + for (const auto& attr : attrs) { + (*node_def.mutable_attr())[attr.first] = attr.second; + } return node_def; } @@ -179,10 +184,33 @@ void ExpectArrayNear(const std::vector& lhs, } } +template +void ExpectArrayAlmostEqual(const std::vector& lhs, absl::Span rhs, + T tolerance) { + ASSERT_EQ(lhs.size(), rhs.size()); + for (int i = 0; i < lhs.size(); i++) { + EXPECT_NEAR(lhs[i], rhs[i], tolerance); + } +} + +// Eigen::half cannot implicitly convert to float which is required for +// EXPECT_NEAR. +template <> +void ExpectArrayAlmostEqual(const std::vector& lhs, + absl::Span rhs, + Eigen::half tolerance) { + ASSERT_EQ(lhs.size(), rhs.size()); + for (int i = 0; i < lhs.size(); i++) { + EXPECT_NEAR(Eigen::half_impl::half_to_float(lhs[i]), + Eigen::half_impl::half_to_float(rhs[i]), + Eigen::half_impl::half_to_float(tolerance)); + } +} + bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, const TRT_ShapedWeights& rhs) { - return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ && - lhs.GetValues() == rhs.GetValues(); + return TrtDimsEquals(lhs.shape_, rhs.shape_) && + lhs.TrtDType() == rhs.TrtDType() && lhs.GetValues() == rhs.GetValues(); } template @@ -197,6 +225,29 @@ void ValidateWeights(const TRT_ShapedWeights& weights, } } +template +std::vector InitTestVector(int size, CType start_value = CType(0)) { + std::vector res; + res.reserve(size); + for (int i = 0; i < size; ++i) { + res.push_back(start_value + CType(i)); + } + return res; +} + +template +struct StaticCaster { + OutCType operator()(InCType in) const { return static_cast(in); } +}; + +template +std::vector CastTestVector(const std::vector& vals) { + std::vector res(vals.size()); + std::transform(vals.begin(), vals.end(), res.begin(), + StaticCaster()); + return res; +} + // Fake ITensor implementation for testing purposes. 
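+// A minimal stub: it implements just enough of nvinfer1::ITensor for the
+// Converter unit tests, and the IS_TRT_VERSION_GE blocks below override the
+// virtual methods added in newer TensorRT releases (e.g. setAllowedFormats in
+// TRT 6) so the class stays instantiable across versions.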
class FakeITensor : public nvinfer1::ITensor { public: @@ -233,7 +284,7 @@ class FakeITensor : public nvinfer1::ITensor { location_ = location; } -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0, 0) bool setDynamicRange(float min, float max) override { dynamic_range_ = std::max(std::abs(min), std::abs(max)); return true; @@ -242,7 +293,7 @@ class FakeITensor : public nvinfer1::ITensor { float getDynamicRange() const override { return dynamic_range_; } #endif -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +#if IS_TRT_VERSION_GE(5, 1, 0, 0) bool dynamicRangeIsSet() const override { return true; } void resetDynamicRange() override {} @@ -252,6 +303,14 @@ class FakeITensor : public nvinfer1::ITensor { float getDynamicRangeMax() const override { return 0.f; } #endif +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + void setAllowedFormats(nvinfer1::TensorFormats formats) override {} + + nvinfer1::TensorFormats getAllowedFormats() const override { return 1; } + + bool isShape() const override { return false; } +#endif + private: string name_; nvinfer1::Dims dims_; @@ -278,7 +337,7 @@ TEST(TRT_ShapedWeights_Test, Basic) { } // Test constructor with DataType argument. { - TRT_ShapedWeights weights(DT_FLOAT); + TRT_ShapedWeights weights(nvinfer1::DataType::kFLOAT); TRT_ShapedWeights copy(weights); for (auto ptr : {&weights, ©}) { nvinfer1::Weights trt_weights = ptr->GetTrtWeights(); @@ -295,7 +354,7 @@ TEST(TRT_ShapedWeights_Test, Basic) { { TrtWeightStore store; TRT_ShapedWeights weights = - store.GetTempWeights(DT_FLOAT, GetTestDims({2, 5})); + store.GetTempWeights(nvinfer1::DataType::kFLOAT, GetTestDims({2, 5})); TRT_ShapedWeights copy(weights); for (auto ptr : {&weights, ©}) { nvinfer1::Weights trt_weights = ptr->GetTrtWeights(); @@ -341,7 +400,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { assigned = *original_ptr; for (auto ptr : {original_ptr, ©, &assigned}) { - EXPECT_EQ(true, ptr->is_tensor()); + ASSERT_TRUE(ptr->is_tensor()); EXPECT_EQ(false, ptr->is_weights()); if (original_ptr == &tw) { EXPECT_EQ(-1, ptr->batch_size()); @@ -364,7 +423,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { assigned = tw; for (auto ptr : {&tw, ©, &assigned}) { - EXPECT_EQ(true, ptr->is_tensor()); + ASSERT_TRUE(ptr->is_tensor()); EXPECT_EQ(false, ptr->is_weights()); EXPECT_EQ(1, ptr->batch_size()); EXPECT_NE(nullptr, ptr->tensor()); @@ -449,12 +508,12 @@ TEST_F(ValidatorTest, ConvertToTensorOrWeights) { std::vector(nvinfer1::Dims::MAX_DIMS + 2, 1), &output), error::OUT_OF_RANGE, "Input tensor rank is greater than 9"); } - // Convert non-Const with #dims < 2. + // Convert non-Const with #dims < 1. { TRT_TensorOrWeights output; ExpectStatus( - convert_to_tensor_or_weights({1}, &output), error::INVALID_ARGUMENT, - "Input tensor with rank<2 is not supported since the first dimension " + convert_to_tensor_or_weights({}, &output), error::INVALID_ARGUMENT, + "Scalar input tensor is not supported since the first dimension " "is treated as batch dimension by TRT"); } // Convert non-Const. 
We test the case where the non-batch dimemsion is @@ -464,7 +523,7 @@ TEST_F(ValidatorTest, ConvertToTensorOrWeights) { TRT_TensorOrWeights output; ExpectStatus( convert_to_tensor_or_weights({batch_size, non_batch_dim}, &output)); - EXPECT_EQ(true, output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); EXPECT_EQ(batch_size, output.batch_size()); EXPECT_NE(nullptr, output.tensor()); ExpectTrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims()); @@ -513,7 +572,9 @@ TEST_F(ValidatorTest, ValidateNode) { class ConverterTest : public ::testing::Test { public: - ConverterTest() { + ConverterTest() { Reset(); } + + void Reset() { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32, @@ -646,8 +707,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { perm.order[0] = 1; perm.order[1] = 0; for (int i = 0; i < 2; ++i) { - nvinfer1::ITensor* input_tensor = - const_cast(params->inputs[0].tensor()); + nvinfer1::ITensor* input_tensor = params->inputs[0].tensor(); nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle(*input_tensor); layer->setFirstTranspose(perm); @@ -655,7 +715,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { params->outputs->emplace_back(output_tensor); output_tensors.push_back(output_tensor); } - TRT_ShapedWeights output_weights(DT_FLOAT); + TRT_ShapedWeights output_weights(nvinfer1::DataType::kFLOAT); params->outputs->emplace_back(output_weights); return Status::OK(); }; @@ -686,7 +746,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { TEST_F(ConverterTest, TransposeTensor) { nvinfer1::ITensor* input_tensor = converter_->network()->addInput( "", nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5})); - const nvinfer1::ITensor* output_tensor = nullptr; + nvinfer1::ITensor* output_tensor = nullptr; // Rank doesn't match. ExpectStatus( @@ -705,56 +765,78 @@ TEST_F(ConverterTest, TransposeTensor) { ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions()); } -TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { - nvinfer1::ITensor* input_tensor = converter_->network()->addInput( - "", nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5})); - TRT_TensorOrWeights tw(input_tensor); - const nvinfer1::ITensor* output_tensor = nullptr; +void TestPrepareTensorForShape( + const std::vector& input_dims, const std::vector& reshape_dims, + const std::vector& expected_tensor_dims, bool input_is_tensor, + Converter* converter, TrtWeightStore* weight_store, + error::Code expected_code = error::OK, + const char* expected_error_msg_substr = nullptr) { + TRT_TensorOrWeights input; + if (input_is_tensor) { + input = TRT_TensorOrWeights(converter->network()->addInput( + "", nvinfer1::DataType::kFLOAT, GetTestDims(input_dims))); + } else { + input = TRT_TensorOrWeights(weight_store->GetTempWeights( + nvinfer1::DataType::kFLOAT, GetTestDims(input_dims))); + } + nvinfer1::ITensor* output_tensor = nullptr; for (bool validation_only : {false, true}) { - // Shape size doesn't match. - ExpectStatus( - converter_->PrepareTensorForShape(tw, GetTestDims({2, 3, 6}), - validation_only, &output_tensor), - error::INVALID_ARGUMENT, "Reshape shapes are not compatible"); - - // TODO(aaroey): we should check the case where uninferred dimensions are - // not an exact divisor of input dim ensions, e.g. for dims {-1, 7}. - - // Infer shape, ok. 
- TF_EXPECT_OK(converter_->PrepareTensorForShape( - tw, GetTestDims({-1, 2}), validation_only, &output_tensor)); - if (validation_only) { - EXPECT_EQ(nullptr, output_tensor); + const Status status = converter->PrepareTensorForShape( + input, GetTestDims(reshape_dims), validation_only, &output_tensor); + if (expected_code == error::OK) { + TF_EXPECT_OK(status); + if (validation_only) { + EXPECT_EQ(nullptr, output_tensor); + } else { + ExpectTrtDimsEqualsArray(expected_tensor_dims, + output_tensor->getDimensions()); + } } else { - ExpectTrtDimsEqualsArray({15, 2}, output_tensor->getDimensions()); - } - - // Regular shape. - TF_EXPECT_OK(converter_->PrepareTensorForShape( - tw, GetTestDims({10, 3}), validation_only, &output_tensor)); - if (validation_only) { - EXPECT_EQ(nullptr, output_tensor); - } else { - ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); + ExpectStatus(status, expected_code, expected_error_msg_substr); } } } -TEST_F(ConverterTest, PrepareTensorForShape_Weights) { - TRT_ShapedWeights weights = - weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3, 5})); - TRT_TensorOrWeights tw(weights); - const nvinfer1::ITensor* output_tensor = nullptr; - for (bool validation_only : {false, true}) { - TF_EXPECT_OK(converter_->PrepareTensorForShape( - tw, GetTestDims({10, 3}), validation_only, &output_tensor)); - if (validation_only) { - EXPECT_EQ(nullptr, output_tensor); - } else { - ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); - } +TEST_F(ConverterTest, PrepareTensorForShape) { + for (bool input_is_tensor : {true, false}) { + // Shape size doesn't match. + Reset(); + TestPrepareTensorForShape({2, 3, 5}, {2, 3, 6}, {}, input_is_tensor, + converter_.get(), weight_store_, + error::INVALID_ARGUMENT, "Incompatible shapes"); + + // Regular shape. + Reset(); + TestPrepareTensorForShape({2, 3, 5}, {10, 3}, {10, 3}, input_is_tensor, + converter_.get(), weight_store_); + + // Reshape to zero rank. + Reset(); + TestPrepareTensorForShape({1, 1}, {}, {}, input_is_tensor, converter_.get(), + weight_store_); } + + // Tensor input with zero rank. + Reset(); + TestPrepareTensorForShape({}, {1, 1}, {1, 1}, /*input_is_tensor=*/true, + converter_.get(), weight_store_); + + // TODO(aaroey): we should check the case where uninferred dimensions are + // not an exact divisor of input dim ensions, e.g. for dims {-1, 7}. + + // Infer tensor shape, ok. + Reset(); + TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2}, + /*input_is_tensor=*/true, converter_.get(), + weight_store_); + + // Infer weight shape, should fail. 
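+  // The weight case is expected to fail, presumably because weights are
+  // materialized through a constant layer with an explicit shape, so a -1
+  // wildcard cannot be left for TRT to infer.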
+ Reset(); + TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2}, + /*input_is_tensor=*/false, converter_.get(), + weight_store_, error::INVALID_ARGUMENT, + "Shape is not fully defined"); } TEST_F(ConverterTest, MaybeUpdateBatchSize) { @@ -796,11 +878,10 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) { template void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) { - TRT_ShapedWeights weights = - weight_store->GetTempWeights(DataTypeToEnum::v(), GetTestDims({2, 3})); + TRT_ShapedWeights weights = weight_store->GetTempWeights( + TfDataTypeToTrt(DataTypeToEnum::v()), GetTestDims({2, 3})); const std::vector values = {T(3), T(1), T(2), T(6), T(5), T(4)}; - memcpy(const_cast(weights.GetValues()), values.data(), - weights.size_bytes()); + memcpy(weights.GetValues(), values.data(), weights.size_bytes()); float out_min = 0.0f; float out_max = 0.0f; @@ -845,7 +926,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { // Input range should be inferred along the chain and applied to tensors. int8_converter.MaybeApplyQuantizationRanges(); -#if NV_TENSORRT_MAJOR >= 5 +#if IS_TRT_VERSION_GE(5, 0, 0, 0) EXPECT_EQ(input.getDynamicRange(), 5.0f); EXPECT_EQ(infer_1.getDynamicRange(), 5.0f); EXPECT_EQ(infer_2.getDynamicRange(), 5.0f); @@ -967,14 +1048,14 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { } TEST_F(ConverterTest, CreateConstantLayer) { - for (auto dtype : {DT_FLOAT, DT_INT32}) { + for (auto dtype : {nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT32}) { TRT_ShapedWeights weights = weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5})); nvinfer1::ITensor* tensor = converter_->CreateConstantLayer(weights, GetTestDims({3, 10})); ASSERT_NE(nullptr, tensor); - EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType()) - << "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. actual " + EXPECT_EQ(dtype, tensor->getType()) + << "Expected " << DebugString(dtype) << " vs. actual " << DebugString(tensor->getType()); ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions()); } @@ -1039,7 +1120,7 @@ struct InputOutputData { size_t TotalBytes() const { return tensor.TotalBytes(); } - const char* name; + string name; Tensor tensor; }; @@ -1081,7 +1162,6 @@ class OpConverterTest : public ::testing::Test { network_.reset(nullptr); builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); - builder_->setMaxBatchSize(1); builder_->setMaxWorkspaceSize(1 << 26); // Reset the validator and converter. @@ -1094,8 +1174,23 @@ class OpConverterTest : public ::testing::Test { validator_inputs_.clear(); } - // TODO(laigd): test fp16 and int8 support. - void BuildAndRun(const DataVec& input_data, DataVec* output_data) { + void CheckDataTypeMatches(const DataVec& datas) { + for (const auto& data : datas) { + const int input_index = engine_->getBindingIndex(data.name.c_str()); + ASSERT_NE(-1, input_index); + const nvinfer1::DataType trt_dtype = + engine_->getBindingDataType(input_index); + const DataType tf_dtype = TrtDataTypeToTf(trt_dtype); + ASSERT_EQ(data.tensor.dtype(), tf_dtype) + << DataTypeString(data.tensor.dtype()) << " vs. " + << DataTypeString(tf_dtype); + } + } + + // TODO(laigd): test fp16 and int8 support for more converters. + void BuildAndRun(const DataVec& input_data, DataVec* output_data, + TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32, + const int batch_size = 1) { // Mark the output tensor as TRT engine output. 
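+    // Remaining steps: mark the named outputs as engine outputs, build the
+    // engine at the requested precision and batch size, copy inputs to the
+    // device, enqueue, and copy results back into `output_data`. Typical
+    // calls in the tests below (illustrative only):
+    //   BuildAndRun(input_data, &output_data);  // FP32, batch 1
+    //   BuildAndRun(input_data, &output_data, TrtPrecisionMode::FP16);
+    //   BuildAndRun(input_data, &output_data, TrtPrecisionMode::FP32, batch_size);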
std::vector output_info; for (const auto& data : *output_data) { @@ -1105,16 +1200,29 @@ class OpConverterTest : public ::testing::Test { TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info)); // Build the TRT engine. + if (precision_mode == TrtPrecisionMode::FP16) { + builder_->setFp16Mode(true); + } else if (precision_mode == TrtPrecisionMode::INT8) { + // Setting FP16 mode as well allows TRT to also consider FP16 kernels and + // use them in situations where they are faster than INT8 or where INT8 is + // not supported for a given layer. + builder_->setFp16Mode(true); + builder_->setInt8Mode(true); + } ASSERT_EQ(nullptr, engine_.get()); + builder_->setMaxBatchSize(batch_size); engine_.reset(builder_->buildCudaEngine(*converter_->network())); CHECK_NOTNULL(engine_.get()); + CheckDataTypeMatches(input_data); + CheckDataTypeMatches(*output_data); // Execute the TRT engine. const int num_bindings = input_data.size() + output_data->size(); std::vector buffers(num_bindings); for (const auto& data : input_data) { - const int input_index = engine_->getBindingIndex(data.name); + const int input_index = engine_->getBindingIndex(data.name.c_str()); + ASSERT_NE(-1, input_index); ASSERT_EQ(0, cudaMalloc(&buffers[input_index], data.TotalBytes())); ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], data.Buffer(), data.TotalBytes(), cudaMemcpyHostToDevice, @@ -1128,7 +1236,8 @@ class OpConverterTest : public ::testing::Test { }; std::vector output_infos; for (const auto& data : *output_data) { - const int output_index = engine_->getBindingIndex(data.name); + const int output_index = engine_->getBindingIndex(data.name.c_str()); + ASSERT_NE(-1, output_index); output_infos.emplace_back(data.TotalBytes(), output_index); ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes())); } @@ -1136,8 +1245,7 @@ class OpConverterTest : public ::testing::Test { ASSERT_EQ(engine_->getNbBindings(), num_bindings); TrtUniquePtrType execution_context( engine_->createExecutionContext()); - execution_context->enqueue(/*batchSize=*/1, buffers.data(), stream_, - nullptr); + execution_context->enqueue(batch_size, buffers.data(), stream_, nullptr); for (int i = 0; i < output_infos.size(); ++i) { const auto& output_info = output_infos[i]; @@ -1162,7 +1270,7 @@ class OpConverterTest : public ::testing::Test { // Add ITensor for both validation and conversion. void AddTestTensor( - const char* name, const std::vector& dims, int batch_size = 1, + const string& name, const std::vector& dims, int batch_size = 1, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { DataType tf_dtype = TrtDataTypeToTf(trt_dtype); ops::Placeholder::Attrs attrs; @@ -1182,11 +1290,11 @@ class OpConverterTest : public ::testing::Test { // Add weights for both validation and conversion. 
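+  // Illustrative use (mirrors calls in the tests below):
+  //   AddTestWeights<int32>("weights", {2, 2}, {0, 1, 2, 3});
+  // This creates a TRT_ShapedWeights in the weight store for conversion and a
+  // matching Const NodeDef for validation.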
template - void AddTestWeights(const char* name, const std::vector& dims, + void AddTestWeights(const string& name, const std::vector& dims, const std::vector& values) { - const DataType dtype = DataTypeToEnum::v(); + const nvinfer1::DataType dtype = TfDataTypeToTrt(DataTypeToEnum::v()); const nvinfer1::Dims trt_dims = GetTestDims(dims); - const int64_t num_elements = TrtDimsNumElements(trt_dims); + const int64_t num_elements = TrtWeightDimsNumElements(trt_dims); QCHECK_EQ(num_elements, values.size()) << num_elements << " vs " << values.size(); TRT_ShapedWeights weights(dtype); @@ -1194,8 +1302,7 @@ class OpConverterTest : public ::testing::Test { weights = converter_->weight_store_.GetTempWeights(dtype, trt_dims); QCHECK_EQ(weights.size_bytes(), sizeof(T) * values.size()) << weights.size_bytes() << " vs " << sizeof(T) * values.size(); - memcpy(const_cast(weights.GetValues()), values.data(), - weights.size_bytes()); + memcpy(weights.GetValues(), values.data(), weights.size_bytes()); } // Add weights for validation. TensorShape shape; @@ -1244,6 +1351,10 @@ class OpConverterTest : public ::testing::Test { } } + void TestMatMulHelper( + const std::function& get_matmul, + const std::string& op_name); + // Expose quantization_ranges_ for tests std::unordered_map& quantization_ranges() { return converter_->quantization_ranges_; @@ -1391,6 +1502,9 @@ TEST_F(OpConverterTest, ConvertConst) { TestConvertConst(this); TestConvertConst(this); + TestConvertConst(this); + TestConvertConst(this); + TestConvertConst(this); TestConvertConst(this); } @@ -1444,7 +1558,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); const DataVec input_data{ @@ -1490,6 +1604,15 @@ TEST_F(OpConverterTest, ConvertReshape) { node_def, error::UNIMPLEMENTED, "Reshape to shape=[] is not supported, at my_reshape"); } + { + // Reshape tensor with zero rank to empty tensor, should fail. + Reset(); + AddTestTensor("input", {}); + AddTestWeights("weights", {1, 0, 1}, {}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Reshape to shape=[] is not supported, at my_reshape"); + } struct TestParams { int batch_size; @@ -1518,28 +1641,139 @@ TEST_F(OpConverterTest, ConvertReshape) { } // Reshape on non batch dimensions, ok. 
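+  // In these cases the first entry of `shape` is the (unchanged) batch size
+  // and the remaining entries become the new non-batch dims. Worked example
+  // from the table below: batch_size 2 with tensor_dims {1, 2, 3} reshaped by
+  // {2, 1, 3, 2} preserves 2 * 6 = 12 elements and yields TRT dims {1, 3, 2}.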
- const int kReshapeOKCases = 3; + const int kReshapeOKCases = 8; TestParams ok_params[kReshapeOKCases] = { TestParams{-1, {1, 2, 3}, {-1, 1, 3, 2}}, TestParams{1, {1, 2, 3}, {-1, 1, 3, 2}}, TestParams{1, {1, 2, 3}, {1, 1, 3, 2}}, + TestParams{2, {1, 2, 3}, {2, 1, 3, 2}}, + TestParams{1, {1, 1}, {1}}, + TestParams{1, {}, {1, 1}}, + TestParams{2, {1, 1}, {2}}, + TestParams{2, {}, {2, 1}}, }; for (int i = 0; i < kReshapeOKCases; ++i) { + const int batch_size = std::max(1, ok_params[i].batch_size); + const auto& shape = ok_params[i].shape; Reset(); - AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size); - AddTestWeights("weights", {4}, ok_params[i].shape); + AddTestTensor("input", ok_params[i].tensor_dims, batch_size); + AddTestWeights("weights", {static_cast(shape.size())}, shape); RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output)); - EXPECT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions()); + ASSERT_TRUE(output.is_tensor()); + const std::vector expected_output_dims(shape.begin() + 1, shape.end()); + const nvinfer1::Dims actual_output_dims = output.tensor()->getDimensions(); + ExpectTrtDimsEqualsArray(expected_output_dims, actual_output_dims); - const DataVec input_data{ - {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; - DataVec output_data{{"my_reshape", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); + std::vector input_vec(TrtTensorDimsNumElements(actual_output_dims) * + batch_size); + std::iota(input_vec.begin(), input_vec.end(), 1); + const DataVec input_data{{"input", test::AsTensor(input_vec)}}; + DataVec output_data{ + {"my_reshape", ConstructTensor(input_vec.size())}}; + BuildAndRun(input_data, &output_data, TrtPrecisionMode::FP32, batch_size); EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(1, 2, 3, 4, 5, 6)); + ElementsAreArray(input_vec)); + } +} + +// Helper function for testing MatMul and BatchMatMul +// get_matmul corresponds to the function used to generate the node. It should +// accept (DataType, transpose_a, transpose_b) as parameters. +void OpConverterTest::TestMatMulHelper( + const std::function& get_matmul, + const std::string& op_name) { + // HACK: This needs to be done in a better way. + const bool is_batch_matmul = op_name == "BatchMatMul"; + { + // Unsupported data type. + Reset(); + NodeDef node_def = get_matmul(DT_INT32, false, false); + AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32); + AddTestWeights("weights", {2, 1}, {3, 5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + ("Data type int32 is not supported for " + op_name + + ", " + "must be one of [float, half], at my_matmul") + .c_str()); + } + // OK. 
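+  // Worked expectation (example values only): input [0, 1] times weights
+  // [[0, 1], [2, 3]] gives [0*0 + 1*2, 0*1 + 1*3] = [2, 3]; with transpose_b
+  // the weights become [[0, 2], [1, 3]], giving [1, 3], matching the
+  // ElementsAre checks below.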
+ for (bool transpose_a : {false, true}) { + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + if (is_batch_matmul) { + if (transpose_a || transpose_b) { + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Input weight attempts to broadcast across batch dimension for " + "BatchMatMul, at my_matmul"); + } else { + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Input weight attempts to broadcast across batch dimension"); + } + continue; + } else if (transpose_a) { + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Cannot transpose first input if it is a tensor with fewer than 2 " + "non-batch dimensions"); + continue; + } + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + ASSERT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); + + const DataVec input_data{{"input", test::AsTensor({0, 1})}}; + DataVec output_data{{"my_matmul", ConstructTensor(2)}}; + BuildAndRun(input_data, &output_data); + if (transpose_b) { + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); + } else { + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(2, 3)); + } + } + } + // OK, 3D inputs + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + if (is_batch_matmul) { + if (transpose_b) { + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Input weight attempts to broadcast across batch dimension for " + "BatchMatMul, at my_matmul"); + } else { + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Input weight attempts to broadcast across batch dimension"); + } + continue; + } + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + ASSERT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); + const DataVec input_data{{"input", test::AsTensor({0, 1})}}; + DataVec output_data{{"my_matmul", ConstructTensor(2)}}; + BuildAndRun(input_data, &output_data); + if (transpose_b) { + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); + } else { + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(2, 3)); + } } } @@ -1565,49 +1799,97 @@ TEST_F(OpConverterTest, ConvertMatMul) { return matmul.operation.node()->def(); }; + // Additional test cases specific to MatMul { - // Unsupported data type. + // Can only transpose A if it is 2D in TRT Reset(); - NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false); - AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32); - AddTestWeights("weights", {2, 1}, {3, 5}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "Data type int32 is not supported for MatMul, " - "must be one of [float, half], at my_matmul"); - } - // transpose_a is set. 
- for (bool transpose_b : {false, true}) { - Reset(); - NodeDef node_def = - get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); + NodeDef node_def = get_matmul_nodedef(DT_FLOAT, true, false); AddTestTensor("input", {2}, /*batch_size=*/1); AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "transpose_a is not supported for TensorRT FullyConnected"); + "Cannot transpose first input if it is a tensor with fewer than 2 " + "non-batch dimensions."); } - // OK. - for (bool transpose_b : {false, true}) { + { + // B must always have 2 non-batch dimensions Reset(); - NodeDef node_def = - get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); + NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false); AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); - EXPECT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); + AddTestTensor("weights", {2}, /*batch_size=*/1); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Second input must either be a constant, or contain at least 2 " + "non-batch dimensions."); + } + { + // We can never transpose weights that are not 2D. + Reset(); + NodeDef node_def = get_matmul_nodedef(DT_FLOAT, true, false); + AddTestWeights("input", {1, 1, 2}, {0, 1}); + AddTestTensor("weights", {2, 2}, /*batch_size=*/1); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Cannot currently transpose constant input if it is not 2 dimensional"); + } + TestMatMulHelper(get_matmul_nodedef, "MatMul"); +} - const DataVec input_data{{"input", test::AsTensor({0, 1})}}; - DataVec output_data{{"my_matmul", ConstructTensor(2)}}; - BuildAndRun(input_data, &output_data); - if (transpose_b) { - EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); - } else { - EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(2, 3)); +TEST_F(OpConverterTest, ConvertBatchMatMul) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_matmul", "BatchMatMul", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "BatchMatMul got 0 inputs but expected 2, at my_matmul"); + } + + // Get the NodeDef for BatchMatMul. 
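+  // Note: the C++ ops::BatchMatMul wrapper exposes the transpose flags as the
+  // adjoint attributes AdjX/AdjY, which for these real-valued test matrices
+  // are equivalent to plain transposes. The expectations further down follow
+  // from ordinary 2x2 multiplication, e.g. [[0, 1], [2, 3]] * [[1, 2], [3, 4]]
+  // = [[3, 4], [11, 16]] for the no-transpose case.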
+ auto get_batch_matmul_nodedef = [](DataType dtype, bool transpose_a, + bool transpose_b) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); + const auto matmul_attrs = + ops::BatchMatMul::AdjX(transpose_a).AdjY(transpose_b); + auto matmul = ops::BatchMatMul(s.WithOpName("my_matmul"), input, weights, + matmul_attrs); + return matmul.operation.node()->def(); + }; + + for (bool transpose_a : {false, true}) { + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_batch_matmul_nodedef(DT_FLOAT, transpose_a, transpose_b); + AddTestTensor("input", {2, 2}, /*batch_size=*/1); + AddTestWeights("weights", {1, 2, 2}, {1, 2, 3, 4}); + + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + ASSERT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); + const DataVec input_data{{"input", test::AsTensor({0, 1, 2, 3})}}; + DataVec output_data{{"my_matmul", ConstructTensor(4)}}; + BuildAndRun(input_data, &output_data); + if (!transpose_a && !transpose_b) { + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(3, 4, 11, 16)); + } else if (transpose_a && transpose_b) { + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(4, 8, 7, 15)); + } else if (transpose_a) { + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(6, 8, 10, 14)); + } else if (transpose_b) { + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(2, 4, 8, 18)); + } } } + + TestMatMulHelper(get_batch_matmul_nodedef, "BatchMatMul"); } template @@ -1652,11 +1934,11 @@ void TestConvertBiasAdd(OpConverterTest* test) { test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_biasadd", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(dims_array, output.tensor()->getDimensions()); // Build and run the engine. - const int num_input = TrtDimsNumElements(GetTestDims(dims_array)); + const int num_input = TrtTensorDimsNumElements(GetTestDims(dims_array)); ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2), num_input); @@ -1755,13 +2037,15 @@ void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { // Check the dims of the output ITensor. TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); const DataVec input_data{ {"input", test::AsTensor(swap_inputs ? operand2 : operand1)}}; DataVec output_data{{"my_binary", ConstructTensor(2)}}; - test->BuildAndRun(input_data, &output_data); + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); if (node_def.op() == "Add") { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(5), CType(10.5))); @@ -1808,7 +2092,7 @@ void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { // Check the dims of the output ITensor. 
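+  // With channel-wise broadcasting the weights contribute one value per
+  // channel, so the output keeps the input's non-batch dims ({2, 1, 2} in the
+  // check below) rather than the weights' shape.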
TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); const DataVec input_data{{"input", test::AsTensor(input)}}; @@ -1843,7 +2127,7 @@ void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { // Check the dims of the output ITensor. TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); const DataVec input_data{{"input", test::AsTensor(input)}}; @@ -1862,8 +2146,9 @@ void TestBinaryTensorOpWeightFallback(OpConverterTest* test, const int input_batch_size = 1) { const DataType dtype = DT_FLOAT; typedef typename EnumToDataType::Type CType; - const size_t num_inputs = TrtDimsNumElements(GetTestDims(input_dims)); - const size_t num_weights = TrtDimsNumElements(GetTestDims(weights_dims)); + const size_t num_inputs = TrtTensorDimsNumElements(GetTestDims(input_dims)); + const size_t num_weights = + TrtWeightDimsNumElements(GetTestDims(weights_dims)); test->Reset(); const NodeDef node_def = @@ -1881,7 +2166,7 @@ void TestBinaryTensorOpWeightFallback(OpConverterTest* test, TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); // Check the dims of the output ITensor. std::vector expected_output_dims = input_dims; @@ -1896,7 +2181,7 @@ void TestBinaryTensorOpWeightFallback(OpConverterTest* test, // Check the result of running the engine. const int expected_num_outputs = - TrtDimsNumElements(GetTestDims(expected_output_dims)); + TrtTensorDimsNumElements(GetTestDims(expected_output_dims)); const DataVec input_data{ {"input", ConstructTensor(num_inputs, CType(2))}}; DataVec output_data{ @@ -1933,7 +2218,7 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) { // Check output dims. TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); const DataVec input_data{ @@ -1942,7 +2227,9 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) { DataVec output_data{{"my_binary", ConstructTensor(4)}}; // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. - test->BuildAndRun(input_data, &output_data); + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); if (node_def.op() == "Add") { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(5), CType(8), CType(6), CType(9))); @@ -1964,16 +2251,23 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) { } else if (node_def.op() == "Maximum") { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(3), CType(6), CType(3), CType(6))); + } else if (node_def.op() == "Pow") { + ExpectArrayNear( + std::vector{CType(9), CType(36), CType(27), CType(216)}, + GetSpanForData(output_data[0])); } else { ASSERT_TRUE(false); } } TEST_F(OpConverterTest, ConvertBinary) { + AttrValue dtype; + dtype.set_type(DT_FLOAT); // Input size doesn't match, should fail. 
for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) { Reset(); - NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}); + NodeDef node_def = + MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}}); AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, StrCat("Add got ", std::to_string(num_inputs), @@ -1983,7 +2277,8 @@ TEST_F(OpConverterTest, ConvertBinary) { { // Both inputs are weights. Reset(); - NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"}); + NodeDef node_def = + MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}}); AddTestWeights("weights1", {1}, {1}); AddTestWeights("weights2", {1}, {1}); RunValidationAndConversion( @@ -1998,15 +2293,12 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); -#if 0 - // TODO(b/119560144): it doesn't support FP16 constants and the following test - // will fail. + TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); TestBinaryTensorOpWeightNoBroadcast(this); -#endif // Test BinaryTensorOpWeight() with channel-wise broadcasting. TestBinaryTensorOpWeightWithChannelWiseBroadcast(this); @@ -2037,6 +2329,7 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); @@ -2045,6 +2338,7 @@ TEST_F(OpConverterTest, ConvertBinary) { TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); } TEST_F(OpConverterTest, ConvertQuantize) { @@ -2086,7 +2380,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); auto ranges = quantization_ranges(); EXPECT_EQ(1, ranges.count(output.tensor())); EXPECT_EQ(6.0f, ranges[output.tensor()]); @@ -2107,7 +2401,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); auto ranges = quantization_ranges(); EXPECT_EQ(1, ranges.count(output.tensor())); EXPECT_EQ(6.0f, ranges[output.tensor()]); @@ -2128,7 +2422,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); auto ranges = quantization_ranges(); EXPECT_EQ(1, ranges.count(output.tensor())); EXPECT_EQ(6.0f, ranges[output.tensor()]); @@ -2169,7 +2463,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); auto ranges = quantization_ranges(); EXPECT_EQ(1, ranges.count(output.tensor())); EXPECT_EQ(6.0f, ranges[output.tensor()]); @@ -2186,24 +2480,29 @@ void TestConvertSquare(OpConverterTest* test) { auto square = 
ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); - test->AddTestTensor("input", {1, 20}); + test->AddTestTensor("input", {1, 20}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions()); const int num_inputs = 20; std::vector inputs(num_inputs); std::vector expected_outputs(num_inputs); - for (int i = 0; i < 20; i++) { + for (int i = 0; i < num_inputs; ++i) { const CType value = CType(i - 9); inputs[i] = value; expected_outputs[i] = value * value; } const DataVec input_data{{"input", test::AsTensor(inputs)}}; + // Engine outputs are converted to FP16 automatically if we set FP16 mode in + // the builder. DataVec output_data{{"my_square", ConstructTensor(num_inputs)}}; - test->BuildAndRun(input_data, &output_data); + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); ExpectArrayNear(expected_outputs, GetSpanForData(output_data[0])); } @@ -2231,11 +2530,117 @@ TEST_F(OpConverterTest, ConvertSquare) { // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't // test DT_INT32 type here. TestConvertSquare(this); - // TODO(tmorris): Looks like there may be a bug with this layer for FP16 - // inputs. Disabling for now. - // TestConvertSquare(this); + TestConvertSquare(this); } +#if IS_TRT_VERSION_GE(5, 1, 0, 0) +TEST_F(OpConverterTest, ConvertCombinedNMS) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_nms", "CombinedNonMaxSuppression", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "CombinedNonMaxSuppression got 0 inputs but expected 6, at my_nms"); + } + // Get the NodeDef for CombinedNMS. + auto get_nms_nodedef = []() -> NodeDef { + Scope s = Scope::NewRootScope(); + auto boxes_tensor = ops::Placeholder(s.WithOpName("boxes"), DT_FLOAT); + auto scores_tensor = ops::Placeholder(s.WithOpName("scores"), DT_FLOAT); + auto max_output_size_per_class = + ops::Placeholder(s.WithOpName("max_output_size_per_class"), DT_INT32); + auto max_total_size = + ops::Placeholder(s.WithOpName("max_total_size"), DT_INT32); + auto iou_threshold = + ops::Placeholder(s.WithOpName("iou_threshold"), DT_FLOAT); + auto score_threshold = + ops::Placeholder(s.WithOpName("score_threshold"), DT_FLOAT); + auto nms_attrs = ops::CombinedNonMaxSuppression::Attrs().PadPerClass(false); + + auto nms_op = ops::CombinedNonMaxSuppression( + s.WithOpName("my_nms"), boxes_tensor, scores_tensor, + max_output_size_per_class, max_total_size, iou_threshold, + score_threshold, nms_attrs); + return nms_op.operation.node()->def(); + }; + + struct TestParams { + const std::vector boxes_tensor_dims; + const std::vector scores_tensor_dims; + const int32 max_output_size_per_class; + const int32 max_total_size; + const float iou_threshold; + const float score_threshold; + const std::vector expected_nmsed_boxes_dims; + const std::vector expected_nmsed_scores_dims; + const std::vector expected_nmsed_classes_dims; + }; + + // Ok. 
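TestConvertSquare above fills its input with i - 9 for i in [0, 20) and expects the elementwise square back, in FP32 or FP16 depending on the precision mode passed to BuildAndRun. A trivial standalone spot-check of those expected values:

```cpp
// Spot-check of the TestConvertSquare expectation above: output[i] = input[i]^2
// for inputs i - 9, i in [0, 20), covering negative, zero and positive values.
#include <cassert>
#include <vector>

int main() {
  std::vector<float> inputs(20), squares(20);
  for (int i = 0; i < 20; ++i) {
    inputs[i] = static_cast<float>(i - 9);  // -9, -8, ..., 10
    squares[i] = inputs[i] * inputs[i];
  }
  assert(squares.front() == 81.0f);  // (-9)^2
  assert(squares[9] == 0.0f);        // 0^2
  assert(squares.back() == 100.0f);  // 10^2, still exactly representable in FP16
  return 0;
}
```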
+ const int kCombinedNMSOKCases = 1; + TestParams ok_params[kCombinedNMSOKCases] = { + // TODO(aaroey): there is a bug in TRT's CombinedNonMaxSuppression + // implementation that, the extra output classes that are outside of the + // range specified by valid_detections[i] are not zeros but -1s. + TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}}; + + for (int i = 0; i < kCombinedNMSOKCases; ++i) { + Reset(); + + AddTestTensor("boxes", ok_params[i].boxes_tensor_dims); + AddTestTensor("scores", ok_params[i].scores_tensor_dims); + AddTestWeights("max_output_size_per_class", {1}, + {ok_params[i].max_output_size_per_class}); + AddTestWeights("max_total_size", {1}, {ok_params[i].max_total_size}); + AddTestWeights("iou_threshold", {1}, {ok_params[i].iou_threshold}); + AddTestWeights("score_threshold", {1}, + {ok_params[i].score_threshold}); + + RunValidationAndConversion(get_nms_nodedef()); + + TRT_TensorOrWeights nmsed_boxes; + TRT_TensorOrWeights nmsed_scores; + TRT_TensorOrWeights nmsed_classes; + TRT_TensorOrWeights valid_detections; + + TF_EXPECT_OK(GetTensorOrWeights("my_nms", &nmsed_boxes)); + TF_EXPECT_OK(GetTensorOrWeights("my_nms:1", &nmsed_scores)); + TF_EXPECT_OK(GetTensorOrWeights("my_nms:2", &nmsed_classes)); + TF_EXPECT_OK(GetTensorOrWeights("my_nms:3", &valid_detections)); + + ASSERT_TRUE(nmsed_boxes.is_tensor()); + ASSERT_TRUE(nmsed_scores.is_tensor()); + ASSERT_TRUE(nmsed_classes.is_tensor()); + ASSERT_TRUE(valid_detections.is_tensor()); + + ExpectTrtDimsEqualsArray(ok_params[i].expected_nmsed_boxes_dims, + nmsed_boxes.tensor()->getDimensions()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_nmsed_scores_dims, + nmsed_scores.tensor()->getDimensions()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_nmsed_classes_dims, + nmsed_classes.tensor()->getDimensions()); + ExpectTrtDimsEqualsArray({}, valid_detections.tensor()->getDimensions()); + + DataVec output_data{ + {"my_nms", ConstructTensor(8)}, + {"my_nms:1", ConstructTensor(2)}, + {"my_nms:2", ConstructTensor(2)}, + {"my_nms:3", ConstructTensor(1)}, + }; + const DataVec input_data{ + {"boxes", test::AsTensor({0, 0, 0.3, 0.4})}, + {"scores", test::AsTensor({0.4, 0.7, 0.3})}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4)); + EXPECT_THAT(GetSpanForData(output_data[1]), ElementsAre(0.7, 0.4)); + EXPECT_THAT(GetSpanForData(output_data[2]), ElementsAre(1, 0)); + EXPECT_THAT(GetSpanForData(output_data[3]), ElementsAre(2)); + } +} + +#endif // CombinedNonMaxSuppression + TEST_F(OpConverterTest, ConvertActivation) { { // Input list is empty, should fail. @@ -2256,17 +2661,19 @@ TEST_F(OpConverterTest, ConvertActivation) { "The input \"input\" for Relu must be a tensor, at my_act"); } - constexpr float kAlpha = 0.2f; + constexpr float kLeakyReluAlpha = 0.2f; + constexpr float kSeluAlpha = 1.7580993408473768599402175208123f; + constexpr float kSeluScale = 1.0507009873554804934193349852946f; // Get nodedef for activation layer. 
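With a single box scored against three classes and a score threshold of 0, the CombinedNMS case above cannot trigger any IoU suppression, so the expected outputs reduce to the top max_total_size (score, class) pairs. A simplified standalone sketch of that reduction (not TensorRT's plugin):

```cpp
// Simplified sketch of where the CombinedNonMaxSuppression expectations above
// come from: one box shared by three classes, score_threshold = 0, so the
// result is just the top `max_total_size` (score, class) pairs by score.
// This is NOT the TensorRT plugin, only an illustration of the expected values.
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

int main() {
  const std::vector<float> scores = {0.4f, 0.7f, 0.3f};  // one score per class
  const int max_total_size = 2;

  std::vector<std::pair<float, int>> ranked;  // (score, class)
  for (int c = 0; c < static_cast<int>(scores.size()); ++c) {
    ranked.emplace_back(scores[c], c);
  }
  std::sort(ranked.begin(), ranked.end(),
            [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
              return a.first > b.first;
            });
  ranked.resize(max_total_size);

  assert(ranked[0].first == 0.7f && ranked[0].second == 1);  // scores {0.7, .}, classes {1, .}
  assert(ranked[1].first == 0.4f && ranked[1].second == 0);  // scores {., 0.4}, classes {., 0}
  assert(static_cast<int>(ranked.size()) == 2);              // valid_detections == 2
  return 0;
}
```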
auto get_act_nodedef = [](string op_name) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); if (op_name == "LeakyRelu") { - // LeakyRelu does not have a C++ API - NodeDef node_def = MakeNodeDef("my_act", "LeakyRelu", {"input"}); - (*node_def.mutable_attr())["alpha"].set_f(kAlpha); - return node_def; + auto act = ops::internal::LeakyRelu( + s.WithOpName("my_act"), input, + ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha)); + return act.operation.node()->def(); } else if (op_name == "Relu") { auto act = ops::Relu(s.WithOpName("my_act"), input); return act.operation.node()->def(); @@ -2279,6 +2686,18 @@ TEST_F(OpConverterTest, ConvertActivation) { } else if (op_name == "Tanh") { auto act = ops::Tanh(s.WithOpName("my_act"), input); return act.operation.node()->def(); + } else if (op_name == "Elu") { + auto act = ops::Elu(s.WithOpName("my_act"), input); + return act.operation.node()->def(); + } else if (op_name == "Selu") { + auto act = ops::Selu(s.WithOpName("my_act"), input); + return act.operation.node()->def(); + } else if (op_name == "Softsign") { + auto act = ops::Softsign(s.WithOpName("my_act"), input); + return act.operation.node()->def(); + } else if (op_name == "Softplus") { + auto act = ops::Softplus(s.WithOpName("my_act"), input); + return act.operation.node()->def(); } EXPECT_TRUE(false); return NodeDef(); @@ -2286,7 +2705,7 @@ TEST_F(OpConverterTest, ConvertActivation) { // Get expected output for activation layer. auto get_act_output = [](string op_name, float input) -> float { if (op_name == "LeakyRelu") { - return (input > 0.0f) ? input : input * kAlpha; + return (input > 0.0f) ? input : input * kLeakyReluAlpha; } else if (op_name == "Relu") { return (input > 0.0f) ? input : 0.0f; } else if (op_name == "Relu6") { @@ -2295,29 +2714,53 @@ TEST_F(OpConverterTest, ConvertActivation) { return 1.0f / (1.0f + std::exp(-input)); } else if (op_name == "Tanh") { return std::tanh(input); + } else if (op_name == "Elu") { + return (input > 0.0f) ? input : std::exp(input) - 1; + } else if (op_name == "Selu") { + return (input > 0.0f) ? kSeluScale * input + : kSeluScale * kSeluAlpha * (std::exp(input) - 1); + } else if (op_name == "Softsign") { + return input / (std::abs(input) + 1); + } else if (op_name == "Softplus") { + return std::log(std::exp(input) + 1); } EXPECT_TRUE(false); return 0; }; + // Get list of ops to test. + std::vector ops_to_test; + // Add all ops supported by ConvertUnary. + auto* map = ActivationTypeMap(); + ops_to_test.reserve(map->size()); + for (auto& pair : *map) { + ops_to_test.push_back(pair.first); + } + // Add other activation ops to test. + ops_to_test.push_back("Relu6"); + ops_to_test.push_back("LeakyRelu"); // Ok. - for (const string& op_name : - {"LeakyRelu", "Relu", "Relu6", "Sigmoid", "Tanh"}) { + for (const string& op_name : ops_to_test) { Reset(); NodeDef node_def = get_act_nodedef(op_name); AddTestTensor("input", {1, 2, 3}); RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_act", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + + // Certain activations should set quantization range automatically. + auto ranges = quantization_ranges(); if (op_name == "Relu6") { - // Relu6 should set quantization range automatically. 
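The change of the largest activation test input from 100 to 88 above is tied to single-precision overflow in the Softplus reference, std::log(std::exp(x) + 1): exp(88) is roughly 1.65e38 and still fits in a float, while exp(89) exceeds FLT_MAX and becomes infinity. The 1.0f quantization range expected for Sigmoid, Tanh, and Softsign likewise follows from those activations being bounded by 1 in magnitude. A standalone check of both facts:

```cpp
// Why the activation inputs above stop at 88: the float Softplus reference,
// log(exp(x) + 1), overflows once exp(x) exceeds FLT_MAX (~3.4e38).
#include <cassert>
#include <cmath>

int main() {
  assert(std::isfinite(std::exp(88.0f)));  // ~1.65e38 < FLT_MAX
  assert(std::isinf(std::exp(89.0f)));     // overflows to +inf
  assert(std::isfinite(std::log(std::exp(88.0f) + 1.0f)));  // softplus(88) ~= 88

  // Sigmoid, tanh and softsign stay within [-1, 1], which is why the test
  // expects a quantization range of 1.0f for them.
  for (float x : {-100.0f, -1.0f, 0.0f, 1.0f, 88.0f}) {
    assert(std::fabs(std::tanh(x)) <= 1.0f);
    assert(std::fabs(x / (std::fabs(x) + 1.0f)) <= 1.0f);  // softsign
    assert(1.0f / (1.0f + std::exp(-x)) <= 1.0f);          // sigmoid
  }
  return 0;
}
```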
- auto ranges = quantization_ranges(); EXPECT_EQ(ranges[output.tensor()], 6.0f); + } else if (op_name == "Sigmoid" || op_name == "Tanh" || + op_name == "Softsign") { + EXPECT_EQ(ranges[output.tensor()], 1.0f); } - const std::vector input = {-100, -2, -1, 0, 1, 100}; + // std::exp in Softplus will overflow for input > 88 + const std::vector input = {-100, -2, -1, 0, 1, 88}; const DataVec input_data{{"input", test::AsTensor(input)}}; DataVec output_data{{"my_act", ConstructTensor(6)}}; BuildAndRun(input_data, &output_data); @@ -2370,7 +2813,7 @@ TEST_F(OpConverterTest, ConvertExpandDims) { AddTestWeights("weights", {1}, {0}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Modifying batch dimension is not supported for ExpandDims, at " + "TensorRT does not allow manipulation of the batch dimension, at " "my_expanddims"); } { @@ -2381,7 +2824,7 @@ TEST_F(OpConverterTest, ConvertExpandDims) { AddTestWeights("weights", {1}, {-5}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Modifying batch dimension is not supported for ExpandDims, at " + "TensorRT does not allow manipulation of the batch dimension, at " "my_expanddims"); } { @@ -2392,8 +2835,8 @@ TEST_F(OpConverterTest, ConvertExpandDims) { AddTestWeights("weights", {1}, {5}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Axis for ExpandDims is invalid, must be in the range " - "[-rank(input) - 1, rank(input)], at my_expanddims"); + "Axis value of 5 is out of bounds, must be in range [-5, 5), at " + "my_expanddims"); } { // Axis < -rank(input)-1, should fail. @@ -2403,8 +2846,8 @@ TEST_F(OpConverterTest, ConvertExpandDims) { AddTestWeights("weights", {1}, {-6}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Axis for ExpandDims is invalid, must be in the range " - "[-rank(input) - 1, rank(input)], at my_expanddims"); + "Axis value of -6 is out of bounds, must be in range [-5, 5), at " + "my_expanddims"); } struct TestParams { @@ -2428,7 +2871,7 @@ TEST_F(OpConverterTest, ConvertExpandDims) { RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); @@ -2488,7 +2931,8 @@ TEST_F(OpConverterTest, ConvertSqueeze) { NodeDef node_def = get_squeeze_nodedef({0}); AddTestTensor("input", {1, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "Cannot squeeze batch dimension, at my_squeeze"); + "TensorRT does not allow manipulation of the " + "batch dimension, at my_squeeze"); } { // Squeeze batch dim via negative axis, should fail. @@ -2496,7 +2940,8 @@ TEST_F(OpConverterTest, ConvertSqueeze) { NodeDef node_def = get_squeeze_nodedef({-4}); AddTestTensor("input", {1, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "Cannot squeeze batch dimension, at my_squeeze"); + "TensorRT does not allow manipulation of the " + "batch dimension, at my_squeeze"); } { // Squeeze >= rank(input), should fail. @@ -2505,8 +2950,8 @@ TEST_F(OpConverterTest, ConvertSqueeze) { AddTestTensor("input", {1, 2, 3}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Axis for Squeeze is invalid, must be in the range " - "[-rank(input), rank(input)), at my_squeeze"); + "Axis value of 4 is out of bounds, must be in range [-4, 4), at " + "my_squeeze"); } { // Squeeze < -rank(input), should fail. 
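The reworded ExpandDims and Squeeze errors above (and the matching messages for Concat, Split, Unpack, and Pack later in this file) all follow one convention: with TensorRT's implicit batch dimension a TF input of shape {1, 2, 3} has rank 4, an axis must lie in [-rank, rank) (one extra position, [-5, 5), for ExpandDims), negative axes wrap by adding the rank, and a resulting axis of 0, the batch dimension, is rejected. A sketch using a hypothetical NormalizeAxis helper that only mirrors those messages, not the converter's real code:

```cpp
// Hypothetical sketch of the axis-validation convention behind the error
// messages above ("Axis value ... must be in range [-rank, rank)",
// "TensorRT does not allow manipulation of the batch dimension").
#include <cassert>

// Returns the normalized axis, or -1 if the axis is invalid under the
// implicit-batch convention the tests assume. Not the converter's real helper.
int NormalizeAxis(int axis, int rank) {
  if (axis < -rank || axis >= rank) return -1;  // out of bounds
  if (axis < 0) axis += rank;                   // wrap negative axes
  if (axis == 0) return -1;                     // batch dim may not be touched
  return axis;
}

int main() {
  // TF input {1, 2, 3} plus the implicit batch dim -> rank 4.
  // ExpandDims may insert at rank + 1 = 5 positions, hence the [-5, 5) message.
  assert(NormalizeAxis(5, /*rank=*/5) == -1);   // ExpandDims axis 5: rejected
  assert(NormalizeAxis(-6, /*rank=*/5) == -1);  // ExpandDims axis -6: rejected
  assert(NormalizeAxis(4, /*rank=*/4) == -1);   // Squeeze axis 4: rejected
  assert(NormalizeAxis(-5, /*rank=*/4) == -1);  // Squeeze axis -5: rejected
  assert(NormalizeAxis(0, /*rank=*/4) == -1);   // batch dimension: rejected
  assert(NormalizeAxis(-4, /*rank=*/4) == -1);  // -4 wraps to the batch dim
  assert(NormalizeAxis(-1, /*rank=*/4) == 3);   // innermost axis is accepted
  return 0;
}
```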
@@ -2515,8 +2960,18 @@ TEST_F(OpConverterTest, ConvertSqueeze) { AddTestTensor("input", {1, 2, 3}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Axis for Squeeze is invalid, must be in the range " - "[-rank(input), rank(input)), at my_squeeze"); + "Axis value of -5 is out of bounds, must be in range [-4, 4), at " + "my_squeeze"); + } + { + // Squeeze an axis with size != 1, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({2}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Dimension 2 with size 2 cannot be squeezed because it must be size 1, " + "at my_squeeze"); } struct TestParams { @@ -2546,7 +3001,7 @@ TEST_F(OpConverterTest, ConvertSqueeze) { RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); @@ -2613,21 +3068,6 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { "The input \"begin\" for StridedSlice must be a constant, at " "my_strided_slice"); } - { - // Non-zero ellipsis_mask, should fail. - Reset(); - NodeDef node_def = get_strided_slice_nodedef( - /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2, - /*new_axis_mask=*/0, /*shrink_axis_mask=*/0); - AddTestTensor("input", {1, 2, 3}); - AddTestWeights("begin", {4}, {0, 0, 0, 0}); - AddTestWeights("end", {4}, {1, 1, 2, 3}); - AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "ellipsis_mask is not supported for StridedSlice, at " - "my_strided_slice"); - } { // Modify batch dim, should fail. Reset(); @@ -2665,8 +3105,8 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestWeights("strides", {4}, {1, 1, 1, 1}); RunValidationAndConversion(node_def); } -// TRT 5.1+ supports strides -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +// TRT 5.1+ supports strides (disabled until 5.1.3.1 due to bugs) +#if IS_TRT_VERSION_GE(5, 1, 3, 1) { // Negative strides, should fail. Reset(); @@ -2714,6 +3154,7 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { std::vector strides; int begin_mask; int end_mask; + int ellipsis_mask; std::vector expected_output_dims; std::vector expected_output; }; @@ -2729,162 +3170,340 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { // Same input is used for all tests. const std::vector ok_input = {1, 2, 3, 4, 5, 6}; -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) - const int kStridedSliceOKCases = 23; +#if IS_TRT_VERSION_GE(5, 1, 3, 1) + const int kStridedSliceOKCases = 28; #else - const int kStridedSliceOKCases = 19; + const int kStridedSliceOKCases = 24; #endif // Ok. TestParams ok_params[kStridedSliceOKCases] = { // 2D Crop. 
- TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, 0, 0}, - /*end=*/{0, 0, 1, 2}, /*strides=*/{1, 1, 1, 1}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 0, 0}), - /*expected_output_dims=*/{1, 1, 2}, /*expected_output=*/{1, 2}}, TestParams{ /*input_dims=*/{1, 2, 3}, - /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 0, 1, 2}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 2}, - /*expected_output=*/{5, 6}}, + /*end_mask=*/get_mask({1, 1, 0, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{1, 2}, + }, TestParams{ /*input_dims=*/{1, 2, 3}, - /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 0, 1, 1}, + /*end=*/{0, 0, 0, 0}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, - /*expected_output=*/{5, 6}}, + /*end_mask=*/get_mask({1, 1, 1, 1}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}, + }, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, + /*end=*/{0, 1, 2, 3}, + /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}, + }, // 2D Crop, with transpose. TestParams{ /*input_dims=*/{2, 3, 1}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 1, 2, 1}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, - /*expected_output=*/{1, 2}}, + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{1, 2}, + }, TestParams{ /*input_dims=*/{2, 3, 1}, - /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 1, 1, 0}, + /*end=*/{0, 2, 3, 1}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, - /*expected_output=*/{5, 6}}, + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{5, 6}, + }, TestParams{ /*input_dims=*/{2, 1, 3}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 1, 1, 2}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, - /*expected_output=*/{1, 2}}, + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{1, 2}, + }, TestParams{ /*input_dims=*/{2, 1, 3}, - /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 1, 0, 1}, + /*end=*/{0, 2, 1, 3}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, - /*expected_output=*/{5, 6}}, + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}, + }, // 2D Crop, with reshape. 
- TestParams{/*input_dims=*/{2, 3}, - /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, /*strides=*/{1, 1, 1}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0}), - /*expected_output_dims=*/{1, 2}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 3}, - /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, /*strides=*/{1, 1, 1}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1}), - /*expected_output_dims=*/{1, 2}, - /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{2, 3}, + /*begin=*/{0, 0, 0}, + /*end=*/{0, 1, 2}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{1, 2}, + }, + TestParams{ + /*input_dims=*/{2, 3}, + /*begin=*/{0, 1, 1}, + /*end=*/{0, 0, 0}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{5, 6}, + }, // 1D Crop. TestParams{ /*input_dims=*/{1, 2, 3}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 0, 0, 2}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1, 0}), /*expected_output_dims=*/{1, 2, 2}, - /*expected_output=*/{1, 2, 4, 5}}, + /*end_mask=*/get_mask({1, 1, 1, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 2, 4, 5}, + }, TestParams{ /*input_dims=*/{1, 2, 3}, - /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 0, 1, 0}, + /*end=*/{0, 0, 0, 0}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 3}, - /*expected_output=*/{4, 5, 6}}, + /*end_mask=*/get_mask({1, 1, 1, 1}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 1, 3}, + /*expected_output=*/{4, 5, 6}, + }, // 1D Crop, with transpose. TestParams{ /*input_dims=*/{2, 3, 1}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 1, 0, 0}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, - /*expected_output=*/{1, 2, 3}}, + /*end_mask=*/get_mask({1, 0, 1, 1}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{1, 2, 3}, + }, TestParams{ /*input_dims=*/{2, 3, 1}, - /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin=*/{0, 1, 0, 0}, + /*end=*/{0, 0, 0, 0}, + /*strides=*/{1, 1, 1, 1}, /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, - /*expected_output=*/{4, 5, 6}}, + /*end_mask=*/get_mask({1, 1, 1, 1}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{4, 5, 6}, + }, // 1D Crop, with reshape. 
- TestParams{/*input_dims=*/{6}, - /*begin=*/{0, 0}, /*end=*/{0, 3}, /*strides=*/{1, 1}, - /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), - /*expected_output_dims=*/{3}, - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{1, 6}, - /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, /*strides=*/{1, 1, 1}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 1, 0}), - /*expected_output_dims=*/{1, 3}, - /*expected_output=*/{3, 4, 5}}, - TestParams{/*input_dims=*/{6, 1}, - /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, /*strides=*/{1, 1, 1}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 1}), - /*expected_output_dims=*/{3, 1}, - /*expected_output=*/{3, 4, 5}}, + TestParams{ + /*input_dims=*/{6}, + /*begin=*/{0, 0}, + /*end=*/{0, 3}, + /*strides=*/{1, 1}, + /*begin_mask=*/get_mask({0, 0}), + /*end_mask=*/get_mask({1, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 2, 3}, + }, + TestParams{ + /*input_dims=*/{1, 6}, + /*begin=*/{0, 0, 2}, + /*end=*/{0, 0, 5}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 3}, + /*expected_output=*/{3, 4, 5}, + }, + TestParams{ + /*input_dims=*/{6, 1}, + /*begin=*/{0, 2, 0}, + /*end=*/{0, 5, 0}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{3, 4, 5}, + }, // Negative axis. - TestParams{/*input_dims=*/{6, 1}, - /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, /*strides=*/{1, 1, 1}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 1}), - /*expected_output_dims=*/{3, 1}, - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{6, 1}, - /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, /*strides=*/{1, 1, 1}, - /*begin_mask=*/get_mask({0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 1}), - /*expected_output_dims=*/{5, 1}, - /*expected_output=*/{1, 2, 3, 4, 5}}, + TestParams{ + /*input_dims=*/{6, 1}, + /*begin=*/{0, -6, 0}, + /*end=*/{0, -3, 0}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{1, 2, 3}, + }, + TestParams{ + /*input_dims=*/{6, 1}, + /*begin=*/{0, 0, 0}, + /*end=*/{0, -1, 0}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{5, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}, + }, // Clamp out of bounds begin and end. 
- TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, -9999, -9}, - /*end=*/{0, 1, 1000, 4}, /*strides=*/{1, 1, 1, 1}, - /*begin_mask=*/get_mask({0, 0, 0, 0}), - /*end_mask=*/get_mask({1, 0, 0, 0}), - /*expected_output_dims=*/{1, 2, 3}, - /*expected_output=*/{1, 2, 3, 4, 5, 6}}, -#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, -9999, -9}, + /*end=*/{0, 1, 1000, 4}, + /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}, + }, +#if IS_TRT_VERSION_GE(5, 1, 3, 1) // Strides - TestParams{/*input_dims=*/{6}, - /*begin=*/{0, 0}, /*end=*/{0, 5}, /*strides=*/{1, 2}, - /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), - /*expected_output_dims=*/{3}, - /*expected_output=*/{1, 3, 5}}, - TestParams{/*input_dims=*/{6}, - /*begin=*/{0, 0}, /*end=*/{0, 6}, /*strides=*/{1, 2}, - /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), - /*expected_output_dims=*/{3}, - /*expected_output=*/{1, 3, 5}}, - TestParams{/*input_dims=*/{6}, - /*begin=*/{0, 1}, /*end=*/{0, 6}, /*strides=*/{1, 2}, - /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), - /*expected_output_dims=*/{3}, - /*expected_output=*/{2, 4, 6}}, - TestParams{/*input_dims=*/{6}, - /*begin=*/{0, 2}, /*end=*/{0, 6}, /*strides=*/{1, 3}, - /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), - /*expected_output_dims=*/{2}, - /*expected_output=*/{3, 6}}, + TestParams{ + /*input_dims=*/{6}, + /*begin=*/{0, 0}, + /*end=*/{0, 5}, + /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), + /*end_mask=*/get_mask({1, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}, + }, + TestParams{ + /*input_dims=*/{6}, + /*begin=*/{0, 0}, + /*end=*/{0, 6}, + /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), + /*end_mask=*/get_mask({1, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}, + }, + TestParams{ + /*input_dims=*/{6}, + /*begin=*/{0, 1}, + /*end=*/{0, 6}, + /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), + /*end_mask=*/get_mask({1, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{3}, + /*expected_output=*/{2, 4, 6}, + }, + TestParams{ + /*input_dims=*/{6}, + /*begin=*/{0, 2}, + /*end=*/{0, 6}, + /*strides=*/{1, 3}, + /*begin_mask=*/get_mask({0, 0}), + /*end_mask=*/get_mask({1, 0}), + /*ellipsis_mask=*/0, + /*expected_output_dims=*/{2}, + /*expected_output=*/{3, 6}, + }, #endif + // ellipsis_mask + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 1}, + /*end=*/{0, 2}, + /*strides=*/{1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({0, 0, 0, 0}), + /*ellipsis_mask=*/get_mask({1, 0, 0, 0}), + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{2, 5}, + }, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1}, + /*end=*/{0, 0, 2}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({1, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*ellipsis_mask=*/get_mask({0, 1, 0, 0}), + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{2, 5}, + }, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 0, 1}, + /*end=*/{0, 1, 2, 2}, + /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({0, 0, 0, 0}), + /*ellipsis_mask=*/get_mask({1, 0, 0, 0}), + /*expected_output_dims=*/{1, 2, 1}, + 
/*expected_output=*/{2, 5}, + }, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 0, 1}, + /*end=*/{1, 1, 2, 2}, + /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({0, 0, 0, 0}), + /*ellipsis_mask=*/get_mask({0, 1, 0, 0}), + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{2, 5}, + }, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 0, 0, 1}, + /*end=*/{0, 1, 1, 2, 2}, + /*strides=*/{1, 1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({0, 0, 0, 0}), + /*ellipsis_mask=*/get_mask({1, 0, 0, 0}), + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{2, 5}, + }, }; for (int i = 0; i < kStridedSliceOKCases; i++) { Reset(); NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask, - ok_params[i].end_mask); + ok_params[i].end_mask, + ok_params[i].ellipsis_mask); AddTestTensor("input", ok_params[i].input_dims); AddTestWeights("begin", {static_cast(ok_params[i].begin.size())}, @@ -2898,7 +3517,7 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); @@ -3040,7 +3659,7 @@ TEST_F(OpConverterTest, ConvertSlice) { TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_slice", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); @@ -3309,7 +3928,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); @@ -3333,7 +3952,9 @@ TEST_F(OpConverterTest, ConvertTopK) { "TopKV2 got 0 inputs but expected 2, at my_topk"); } - for (const auto dtype : {DT_FLOAT, DT_INT32}) { + // TODO(tmorris): This test isn't setting the input dtype properly. TopK with + // int32 is unsupported by TRT. + for (const auto dtype : {DT_FLOAT}) { // Get the NodeDef for TopKV2. Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), dtype); @@ -3360,7 +3981,7 @@ TEST_F(OpConverterTest, ConvertTopK) { TF_EXPECT_OK(GetTensorOrWeights("my_topk", &outputs[0])); TF_EXPECT_OK(GetTensorOrWeights("my_topk:1", &outputs[1])); for (auto& output : outputs) { - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 2, 2}, output.tensor()->getDimensions()); } @@ -3399,15 +4020,25 @@ void TestConvertGather(OpConverterTest* test) { }; // Input is the same {1, 2, 3, 4, 5, 6} for all cases. - const int kGatherOKCases = 5; + const int kGatherOKCases = 7; + const std::vector params_input = {CType(1), CType(2), CType(3), + CType(4), CType(5), CType(6)}; TestParams ok_params[kGatherOKCases] = { - // Vector indices (output is rank(params)). - TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}}, - TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}}, - TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}}, - TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}}, - // Higher rank indices (output is rank(params) + rank(indices) - 1). 
- TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}}, + // Vector indices, and output rank is rank(params). + TestParams{{1, 2, 3}, {}, {0}, 3, {1, 2, 1}, {1, 4}}, + TestParams{{1, 2, 3}, {}, {1}, 2, {1, 1, 3}, {4, 5, 6}}, + // Indices with rank>1, and output rank is rank(params)+rank(indices)-1. + TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}}, + TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}}, + TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}}, + TestParams{ + {1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}}, + TestParams{{3, 2}, + {2, 2}, + {0, 0, 1, 0}, + 2, + {3, 1, 2, 2}, + {1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}}, }; // Ok. @@ -3421,19 +4052,17 @@ void TestConvertGather(OpConverterTest* test) { test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); // Create input in CType and convert expected output to CType. - std::vector inputs = {CType(1), CType(2), CType(3), - CType(4), CType(5), CType(6)}; std::vector converted_expected_output( ok_params[i].expected_output.begin(), ok_params[i].expected_output.end()); const DataVec input_data{ - {"params", test::AsTensor(inputs)}, + {"params", test::AsTensor(params_input)}, {"indices", test::AsTensor(ok_params[i].indices)}}; DataVec output_data{ {"my_gather", @@ -3643,14 +4272,14 @@ TEST_F(OpConverterTest, ConvertUnary) { // Add other unary ops to test. ops_to_test.push_back("Rsqrt"); // Ok. - for (string op_name : ops_to_test) { + for (const string& op_name : ops_to_test) { Reset(); NodeDef node_def = get_unary_nodedef(op_name); AddTestTensor("input", {1, 2, 3}); RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); - EXPECT_TRUE(output.is_tensor()); + ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); const std::vector input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; @@ -3665,6 +4294,1476 @@ TEST_F(OpConverterTest, ConvertUnary) { } } +// Get the NodeDef for ConcatV2. +// TODO(hinsu): Consider switching this to static function. +auto get_concat_nodedef = [](DataType dtype, int num_inputs) -> NodeDef { + Scope s = Scope::NewRootScope(); + std::vector values; + for (int i = 0; i < num_inputs; ++i) { + const string input_name = StrCat("values_", i); + values.push_back(ops::Placeholder(s.WithOpName(input_name), dtype)); + } + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto concat = ops::Concat(s.WithOpName("my_concat"), + absl::Span(values), axis); + return concat.operation.node()->def(); +}; + +template +void TestConvertConcat(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + struct TestParams { + std::vector> input_shapes; + std::vector> input_values; + int axis; + std::vector expected_output_dims; + std::vector expected_output; + }; + + const std::vector> common_input{ + InitTestVector(6), + InitTestVector(6, /*start_value=*/CType(6))}; + // TODO(hinsu): Use std::vector instead of an array to avoid use of explicit + // size. 
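The ConcatV2 OK cases that follow concatenate two row-major {1, 2, 3} blocks holding 0..5 and 6..11 (the axis values in the test count the implicit batch dimension). Along the outer non-batch axes the flattened result is simply 0..11; along the innermost axis (axis 3 in the test) the length-3 rows of the two blocks interleave. A standalone sketch of that innermost-axis case:

```cpp
// Sketch of the innermost-axis ConcatV2 case below: concatenating two
// row-major [1, 2, 3] blocks holding 0..5 and 6..11 along the last axis
// gives shape [1, 2, 6], i.e. each row of the first block followed by the
// matching row of the second block.
#include <cassert>
#include <vector>

int main() {
  const std::vector<float> a = {0, 1, 2, 3, 4, 5};    // shape [1, 2, 3]
  const std::vector<float> b = {6, 7, 8, 9, 10, 11};  // shape [1, 2, 3]
  const int rows = 2, cols = 3;

  std::vector<float> out;  // shape [1, 2, 6]
  for (int r = 0; r < rows; ++r) {
    out.insert(out.end(), a.begin() + r * cols, a.begin() + (r + 1) * cols);
    out.insert(out.end(), b.begin() + r * cols, b.begin() + (r + 1) * cols);
  }
  assert((out == std::vector<float>{0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11}));
  return 0;
}
```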
+ const int kConcatOKCases = 4; + TestParams ok_params[kConcatOKCases] = { + { + /*input_shapes=*/{{1, 2, 3}, {1, 2, 3}}, + /*input_values=*/common_input, + /*axis=*/1, + /*expected_output_dims=*/{2, 2, 3}, + /*expected_output=*/InitTestVector(12), + }, + { + /*input_shapes=*/{{1, 2, 3}, {1, 2, 3}}, + /*input_values=*/common_input, + /*axis=*/2, + /*expected_output_dims=*/{1, 4, 3}, + /*expected_output=*/InitTestVector(12), + }, + { + /*input_shapes=*/{{1, 2, 3}, {1, 2, 3}}, + /*input_values=*/common_input, + /*axis=*/3, + /*expected_output_dims=*/{1, 2, 6}, + /*expected_output=*/ + {CType(0), CType(1), CType(2), CType(6), CType(7), CType(8), CType(3), + CType(4), CType(5), CType(9), CType(10), CType(11)}, + }, + { + /*input_shapes=*/{{1}, {2}, {3}, {1}, {1}, {2}}, + /*input_values=*/ + {{CType(1)}, + {CType(2), CType(3)}, + {CType(4), CType(5), CType(6)}, + {CType(7)}, + {CType(8)}, + {CType(9), CType(10)}}, + /*axis=*/1, + /*expected_output_dims=*/{10}, + /*expected_output=*/ + InitTestVector(10, /*start_value=*/CType(1)), + }, + }; + + for (int i = 0; i < kConcatOKCases; ++i) { + test->Reset(); + const int num_inputs = ok_params[i].input_shapes.size(); + EXPECT_EQ(num_inputs, ok_params[i].input_values.size()); + NodeDef node_def = get_concat_nodedef(dtype, num_inputs); + // Create inputs. + for (int j = 0; j < num_inputs; ++j) { + test->AddTestTensor(StrCat("values_", j), ok_params[i].input_shapes[j], 1, + TfDataTypeToTrt(dtype)); + } + test->AddTestWeights("axis", {1}, {ok_params[i].axis}); + test->RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_concat", &output)); + ASSERT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + // Create input data for tensors. + DataVec input_data; + for (int j = 0; j < num_inputs; ++j) { + input_data.push_back( + {StrCat("values_", j), + test::AsTensor(ok_params[i].input_values[j])}); + } + DataVec output_data{ + {"my_concat", + ConstructTensor(ok_params[i].expected_output.size())}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertConcat) { + { + // Axis is a tensor, should fail. + Reset(); + NodeDef node_def = get_concat_nodedef(DT_FLOAT, 2); + AddTestTensor("values_0", {1, 2, 3}); + AddTestTensor("values_1", {1, 2, 3}); + AddTestTensor("axis", {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for ConcatV2 must be a constant, at my_concat"); + } + { + // Axis is out of bounds, should fail. + Reset(); + NodeDef node_def = get_concat_nodedef(DT_FLOAT, 2); + AddTestTensor("values_0", {1, 2, 3}); + AddTestTensor("values_1", {1, 2, 3}); + AddTestWeights("axis", {1}, {4}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in " + "range [-4, 4), at my_concat"); + } + { + // Axis is batch dimension, should fail. + Reset(); + NodeDef node_def = get_concat_nodedef(DT_FLOAT, 2); + AddTestTensor("values_0", {1, 2, 3}); + AddTestTensor("values_1", {1, 2, 3}); + AddTestWeights("axis", {1}, {0}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_concat"); + } + { + // Inputs have inconsistent rank, should fail. 
+ Reset(); + NodeDef node_def = get_concat_nodedef(DT_FLOAT, 2); + AddTestTensor("values_0", {1, 2, 3}); + AddTestTensor("values_1", {1, 6}); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Received inputs with inconsistent rank, at my_concat"); + } + { + // An input is a weight, should fail. + Reset(); + NodeDef node_def = get_concat_nodedef(DT_FLOAT, 2); + AddTestTensor("values_0", {1, 2, 3}); + AddTestWeights("values_1", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"values_1\" for ConcatV2 must be a tensor, at my_concat"); + } + { + // Inputs have inconsistent non-axis shapes, should fail. + Reset(); + NodeDef node_def = get_concat_nodedef(DT_FLOAT, 2); + AddTestTensor("values_0", {1, 2, 3}); + AddTestTensor("values_1", {1, 3, 2}); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Received inputs with inconsistent shape, at my_concat"); + } + + TestConvertConcat(this); + TestConvertConcat(this); + // TODO(tmorris): Enable once TRT adds support. + // TestConvertConcat(this); +} + +// Get the NodeDef for Split. +auto get_split_nodedef = [](DataType dtype, int num_split) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto value = ops::Placeholder(s.WithOpName("value"), dtype); + auto split = ops::Split(s.WithOpName("my_split"), axis, value, num_split); + return split.operation.node()->def(); +}; + +template +void TestConvertSplit(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + struct TestParams { + std::vector input_shape; + std::vector value; + int axis; + int num_split; + std::vector expected_output_dims; + std::vector> expected_outputs; + }; + + const std::vector common_input = InitTestVector(6); + const int kSplitOKCases = 4; + TestParams ok_params[kSplitOKCases] = { + // Identity (num_split = 1) + {/*input_shape=*/{1, 2, 3}, /*value=*/common_input, /*axis=*/1, + /*num_split=*/1, /*expected_output_dims=*/{1, 2, 3}, + /*expected_outputs=*/{InitTestVector(6)}}, + {/*input_shape=*/{1, 2, 3}, + /*value=*/common_input, + /*axis=*/3, + /*num_split=*/3, + /*expected_output_dims=*/{1, 2, 1}, + /*expected_outputs=*/ + {{CType(0), CType(3)}, {CType(1), CType(4)}, {CType(2), CType(5)}}}, + {/*input_shape=*/{1, 6}, + /*value=*/common_input, + /*axis=*/2, + /*num_split=*/6, + /*expected_output_dims=*/{1, 1}, + /*expected_outputs=*/ + {{CType(0)}, + {CType(1)}, + {CType(2)}, + {CType(3)}, + {CType(4)}, + {CType(5)}}}, + {/*input_shape=*/{1, 6}, + /*value=*/common_input, + /*axis=*/-1, + /*num_split=*/2, + /*expected_output_dims=*/{1, 3}, + /*expected_outputs=*/ + {InitTestVector(3), InitTestVector(3, CType(3))}}, + }; + + for (int i = 0; i < kSplitOKCases; ++i) { + test->Reset(); + NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split); + // Create inputs. + test->AddTestWeights("axis", {1}, {ok_params[i].axis}); + test->AddTestTensor("value", ok_params[i].input_shape, 1, + TfDataTypeToTrt(dtype)); + // Convert. + test->RunValidationAndConversion(node_def); + + // Get output tensors and verify output dims. + EXPECT_EQ(ok_params[i].expected_outputs.size(), ok_params[i].num_split); + std::vector outputs(ok_params[i].num_split); + DataVec output_data; + for (int j = 0; j < outputs.size(); ++j) { + const string name = j == 0 ? 
StrCat("my_split") : StrCat("my_split:", j); + TF_EXPECT_OK(test->GetTensorOrWeights(name, &outputs[j])); + EXPECT_TRUE(outputs[j].is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + outputs[j].tensor()->getDimensions()); + // Create buffer to store output. + output_data.push_back( + {name, + ConstructTensor(ok_params[i].expected_outputs[j].size())}); + } + + // Verify output values are correct. + const DataVec input_data{ + {"value", test::AsTensor(ok_params[i].value)}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + for (int j = 0; j < outputs.size(); ++j) { + EXPECT_THAT(GetSpanForData(output_data[j]), + ElementsAreArray(ok_params[i].expected_outputs[j])); + } + } +} + +TEST_F(OpConverterTest, ConvertSplit) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_split", "Split", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Split got 0 inputs but expected 2, at my_split"); + } + { + // Axis is a tensor, should fail. + Reset(); + NodeDef node_def = get_split_nodedef(DT_FLOAT, 1); + AddTestTensor("axis", {1}); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for Split must be a constant, at my_split"); + } + { + // Axis is out of bounds, should fail. + Reset(); + NodeDef node_def = get_split_nodedef(DT_FLOAT, 1); + AddTestWeights("axis", {1}, {4}); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in " + "range [-4, 4), at my_split"); + } + { + // Axis is out of bounds (negative), should fail. + Reset(); + NodeDef node_def = get_split_nodedef(DT_FLOAT, 1); + AddTestWeights("axis", {1}, {-5}); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of -5 is out of bounds, must be in " + "range [-4, 4), at my_split"); + } + { + // Axis is batch dimension, should fail. + Reset(); + NodeDef node_def = get_split_nodedef(DT_FLOAT, 1); + AddTestWeights("axis", {1}, {0}); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_split"); + } + { + // Value is a weight, should fail. + Reset(); + NodeDef node_def = get_split_nodedef(DT_FLOAT, 1); + AddTestWeights("axis", {1}, {1}); + AddTestWeights("value", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"value\" for Split must be a tensor, at my_split"); + } + { + // Dim is not evenly divisibly by num_split, should fail. + Reset(); + NodeDef node_def = get_split_nodedef(DT_FLOAT, 2); + AddTestWeights("axis", {1}, {3}); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Dimension 3 of size 3 is not evenly divisble by 2, at my_split"); + } + { + // num_split > dim size, should fail. + Reset(); + NodeDef node_def = get_split_nodedef(DT_FLOAT, 4); + AddTestWeights("axis", {1}, {3}); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Dimension 3 of size 3 is not evenly divisble by 4, at my_split"); + } + + TestConvertSplit(this); + TestConvertSplit(this); +#if IS_TRT_VERSION_GE(5, 1, 3, 1) + TestConvertSplit(this); +#endif +} + +// Get the NodeDef for Unpack (Unstack in TF API). 
+auto get_unpack_nodedef = [](DataType dtype, int num, int axis) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto value = ops::Placeholder(s.WithOpName("value"), dtype); + auto unstack_attrs = ops::Unstack::Axis(axis); + auto unstack = + ops::Unstack(s.WithOpName("my_unpack"), value, num, unstack_attrs); + return unstack.operation.node()->def(); +}; + +template +void TestConvertUnpack(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + struct TestParams { + std::vector input_shape; + std::vector value; + int axis; + int num; + std::vector expected_output_dims; + std::vector> expected_outputs; + }; + + const std::vector common_input = InitTestVector(6); + const int kUnpackOKCases = 4; + TestParams ok_params[kUnpackOKCases] = { + {/*input_shape=*/{1, 2, 3}, /*value=*/common_input, /*axis=*/1, + /*num=*/1, /*expected_output_dims=*/{2, 3}, + /*expected_outputs=*/{InitTestVector(6)}}, + {/*input_shape=*/{1, 2, 3}, + /*value=*/common_input, + /*axis=*/3, + /*num=*/3, + /*expected_output_dims=*/{1, 2}, + /*expected_outputs=*/ + {{CType(0), CType(3)}, {CType(1), CType(4)}, {CType(2), CType(5)}}}, + {/*input_shape=*/{6, 1}, + /*value=*/common_input, + /*axis=*/-2, + /*num=*/6, + /*expected_output_dims=*/{1}, + /*expected_outputs=*/ + {{CType(0)}, + {CType(1)}, + {CType(2)}, + {CType(3)}, + {CType(4)}, + {CType(5)}}}, + {/*input_shape=*/{6}, + /*value=*/common_input, + /*axis=*/1, + /*num=*/6, + /*expected_output_dims=*/{}, + /*expected_outputs=*/ + {{CType(0)}, + {CType(1)}, + {CType(2)}, + {CType(3)}, + {CType(4)}, + {CType(5)}}}, + }; + + for (int i = 0; i < kUnpackOKCases; ++i) { + test->Reset(); + NodeDef node_def = + get_unpack_nodedef(dtype, ok_params[i].num, ok_params[i].axis); + // Create inputs. + test->AddTestTensor("value", ok_params[i].input_shape, 1, + TfDataTypeToTrt(dtype)); + // Convert. + test->RunValidationAndConversion(node_def); + + // Get output tensors and verify output dims. + EXPECT_EQ(ok_params[i].expected_outputs.size(), ok_params[i].num); + std::vector outputs(ok_params[i].num); + DataVec output_data; + for (int j = 0; j < outputs.size(); ++j) { + const string name = j == 0 ? "my_unpack" : StrCat("my_unpack:", j); + TF_EXPECT_OK(test->GetTensorOrWeights(name, &outputs[j])); + EXPECT_TRUE(outputs[j].is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + outputs[j].tensor()->getDimensions()); + // Create buffer to store output. + output_data.push_back( + {name, + ConstructTensor(ok_params[i].expected_outputs[j].size())}); + } + + // Verify output values are correct. + const DataVec input_data{ + {"value", test::AsTensor(ok_params[i].value)}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + for (int j = 0; j < outputs.size(); ++j) { + EXPECT_THAT(GetSpanForData(output_data[j]), + ElementsAreArray(ok_params[i].expected_outputs[j])); + } + } +} + +TEST_F(OpConverterTest, ConvertUnpack) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_unpack", "Unpack", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Unpack got 0 inputs but expected 1, at my_unpack"); + } + { + // Value is weights, should fail. 
+ Reset(); + NodeDef node_def = get_unpack_nodedef(DT_FLOAT, /*num=*/3, /*axis=*/3); + AddTestWeights("value", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"value\" for Unpack must be a tensor, at my_unpack"); + } + { + // Axis is out of bounds, should fail. + Reset(); + NodeDef node_def = get_unpack_nodedef(DT_FLOAT, /*num=*/1, /*axis=*/4); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in " + "range [-4, 4), at my_unpack"); + } + { + // Axis is out of bounds (negative), should fail. + Reset(); + NodeDef node_def = get_unpack_nodedef(DT_FLOAT, /*num=*/1, /*axis=*/-5); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of -5 is out of bounds, must be in " + "range [-4, 4), at my_unpack"); + } + { + // Axis is batch dimension, should fail. + Reset(); + NodeDef node_def = get_unpack_nodedef(DT_FLOAT, /*num=*/1, /*axis=*/0); + AddTestTensor("value", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_unpack"); + } + { + // Dim size does not match num, should fail. + Reset(); + NodeDef node_def = get_unpack_nodedef(DT_FLOAT, /*num=*/5, /*axis=*/2); + AddTestTensor("value", {1, 6}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Dimension 2 has size 6 which is not equal to num of 5, at my_unpack"); + } + { + // Output would be TF scalar, should fail. + Reset(); + NodeDef node_def = get_unpack_nodedef(DT_FLOAT, /*num=*/1, /*axis=*/0); + AddTestTensor("value", {}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Input \"value\" for Unpack must be rank 2 or greater, at my_unpack"); + } + + TestConvertUnpack(this); + TestConvertUnpack(this); +#if IS_TRT_VERSION_GE(5, 1, 3, 1) + TestConvertUnpack(this); +#endif +} + +// Get the NodeDef for Pack. +NodeDef GetPackNodeDef(DataType dtype, int num_inputs, int axis) { + Scope s = Scope::NewRootScope(); + std::vector values; + for (int i = 0; i < num_inputs; ++i) { + const string input_name = StrCat("values_", i); + values.push_back(ops::Placeholder(s.WithOpName(input_name), dtype)); + } + // Pack op is renamed to Stack in APIs. 
+ auto pack = + ops::Stack(s.WithOpName("my_pack"), absl::Span(values), + ops::Stack::Axis(axis)); + return pack.operation.node()->def(); +} + +template +void TestConvertPack(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + struct TestParams { + std::vector> input_shapes; + std::vector> input_values; + int axis; + std::vector expected_output_dims; + std::vector expected_output; + }; + + const std::vector> common_input{ + InitTestVector(6), + InitTestVector(6, /*start_value=*/CType(6))}; + std::vector params = { + { + /*input_shapes=*/{{2, 3}, {2, 3}}, + /*input_values=*/common_input, + /*axis=*/1, + /*expected_output_dims=*/{2, 2, 3}, + /*expected_output=*/InitTestVector(12), + }, + { + /*input_shapes=*/{{2, 3}, {2, 3}}, + /*input_values=*/common_input, + /*axis=*/2, + /*expected_output_dims=*/{2, 2, 3}, + /*expected_output=*/ + {CType(0), CType(1), CType(2), CType(6), CType(7), CType(8), CType(3), + CType(4), CType(5), CType(9), CType(10), CType(11)}, + }, + { + /*input_shapes=*/{{2, 3}, {2, 3}}, + /*input_values=*/common_input, + /*axis=*/3, + /*expected_output_dims=*/{2, 3, 2}, + /*expected_output=*/ + {CType(0), CType(6), CType(1), CType(7), CType(2), CType(8), CType(3), + CType(9), CType(4), CType(10), CType(5), CType(11)}, + }, + { + /*input_shapes=*/{{2, 3}}, + /*input_values=*/{InitTestVector(6)}, + /*axis=*/1, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/InitTestVector(6), + }, + { + /*input_shapes=*/{{2, 3}}, + /*input_values=*/{InitTestVector(6)}, + /*axis=*/2, + /*expected_output_dims=*/{2, 1, 3}, + /*expected_output=*/InitTestVector(6), + }, + }; + + for (int i = 0; i < params.size(); ++i) { + test->Reset(); + const int num_inputs = params[i].input_shapes.size(); + EXPECT_EQ(num_inputs, params[i].input_values.size()); + + NodeDef node_def = GetPackNodeDef(dtype, num_inputs, params[i].axis); + // Create inputs. + for (int j = 0; j < num_inputs; ++j) { + test->AddTestTensor(StrCat("values_", j), params[i].input_shapes[j], 1, + TfDataTypeToTrt(dtype)); + } + test->RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_pack", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(params[i].expected_output_dims, + output.tensor()->getDimensions()); + // Create input data for tensors. + DataVec input_data; + for (int j = 0; j < num_inputs; ++j) { + input_data.push_back({StrCat("values_", j), + test::AsTensor(params[i].input_values[j])}); + } + DataVec output_data{ + {"my_pack", ConstructTensor(params[i].expected_output.size())}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertPack) { + { + // An input is a weight, should fail. + Reset(); + NodeDef node_def = GetPackNodeDef(DT_FLOAT, 2, /*axis=*/1); + AddTestTensor("values_0", {1, 2, 3}); + AddTestWeights("values_1", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"values_1\" for Pack must be a tensor, at my_pack"); + } + { + // Axis is out of bounds, should fail. 
+ Reset(); + NodeDef node_def = GetPackNodeDef(DT_FLOAT, 2, /*axis=*/-5); + AddTestTensor("values_0", {2, 3}); + AddTestTensor("values_1", {2, 3}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of -5 is out of bounds, must be in " + "range [-4, 4), at my_pack"); + } + { + // Axis is batch dimension, should fail. + Reset(); + NodeDef node_def = GetPackNodeDef(DT_FLOAT, 2, /*axis=*/-4); + AddTestTensor("values_0", {2, 3}); + AddTestTensor("values_1", {2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_pack"); + } + { + // Inputs have inconsistent rank, should fail. + Reset(); + NodeDef node_def = GetPackNodeDef(DT_FLOAT, 2, /*axis=*/1); + AddTestTensor("values_0", {1, 2, 3}); + AddTestTensor("values_1", {1, 6}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Received inputs with inconsistent rank, at my_pack"); + } + { + // Inputs have inconsistent shapes, should fail. + Reset(); + NodeDef node_def = GetPackNodeDef(DT_FLOAT, 2, /*axis=*/1); + AddTestTensor("values_0", {1, 2}); + AddTestTensor("values_1", {2, 2}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Received inputs with inconsistent shape, at my_pack"); + } + + TestConvertPack(this); + TestConvertPack(this); + + // TODO(hinsu): Enable INT32 with TensorRT version 5.1.3 after testing. + // TestConvertPack(this); +} + +// Get the NodeDef for ArgMin or ArgMax. +template +NodeDef GetArgMinMaxNodeDef(DataType input_dtype, DataType output_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), input_dtype); + auto dimension = ops::Placeholder(s.WithOpName("dimension"), DT_INT32); + auto attrs = OpType::OutputType(output_dtype); + auto arg = OpType(s.WithOpName("my_arg"), input, dimension, attrs); + return arg.operation.node()->def(); +} + +template +void TestConvertArgMinMax(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + struct TestParams { + std::vector input_shape; + std::vector input_value; + int axis; + std::vector expected_output_dims; + std::vector expected_argmax_output; + std::vector expected_argmin_output; + }; + + const std::vector common_input = InitTestVector(6); + std::vector params = { + { + /*input_shape=*/{2, 3}, + /*input_value=*/common_input, + /*axis=*/2, + /*expected_output_dims=*/{2}, + /*expected_argmax_output=*/{2, 2}, + /*expected_argmin_output=*/{0, 0}, + }, + { + /*input_shape=*/{2, 3}, + /*input_value=*/common_input, + /*axis=*/-2, + /*expected_output_dims=*/{3}, + /*expected_argmax_output=*/{1, 1, 1}, + /*expected_argmin_output=*/{0, 0, 0}, + }, + { + /*input_shape=*/{6}, + /*input_value=*/common_input, + /*axis=*/1, + /*expected_output_dims=*/{}, + /*expected_argmax_output=*/{5}, + /*expected_argmin_output=*/{0}, + }, + { + /*input_shape=*/{10}, + /*input_value=*/ + {CType(-5), CType(3), CType(5), CType(1), CType(6), CType(-9), + CType(7), CType(1), CType(0), CType(-1)}, + /*axis=*/-1, + /*expected_output_dims=*/{}, + /*expected_argmax_output=*/{6}, + /*expected_argmin_output=*/{5}, + }, + }; + + for (int i = 0; i < params.size(); ++i) { + test->Reset(); + + NodeDef node_def = GetArgMinMaxNodeDef(dtype, DT_INT32); + // Create inputs. 
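// --- Illustrative sketch (hypothetical helper, not the converter's code) of
// the lowering these ArgMin/ArgMax tests assume: TensorRT has no arg-reduce
// layer, so the op can be expressed with ITopKLayer and k = 1, taking the
// layer's second output (the INT32 indices) and squeezing the reduced axis
// afterwards. The TopK layer works on floating-point inputs only, which is
// why the INT32 variants further below stay disabled, and its indices are
// INT32, which is why an int64 output type is rejected.
nvinfer1::ITensor* AddArgMax(nvinfer1::INetworkDefinition* network,
                             nvinfer1::ITensor* input, int trt_axis) {
  const uint32_t reduce_axes = 1u << trt_axis;  // Bitmask of reduced axes.
  nvinfer1::ITopKLayer* topk = network->addTopK(
      *input, nvinfer1::TopKOperation::kMAX, /*k=*/1, reduce_axes);
  // Output 0 holds the max values; output 1 holds the argmax indices.
  return topk->getOutput(1);
}
// --- End of sketch; the patch continues below.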
+ test->AddTestTensor("input", params[i].input_shape, /*batch_size=*/1, + /*trt_dtype=*/TfDataTypeToTrt(dtype)); + test->AddTestWeights("dimension", {1}, {params[i].axis}); + test->RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_arg", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(params[i].expected_output_dims, + output.tensor()->getDimensions()); + // Create input data for tensors. + const DataVec input_data{ + {"input", test::AsTensor(params[i].input_value)}}; + DataVec output_data{ + {"my_arg", + ConstructTensor(params[i].expected_argmax_output.size())}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + + if (node_def.op() == "ArgMax") { + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(params[i].expected_argmax_output)); + } else if (node_def.op() == "ArgMin") { + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(params[i].expected_argmin_output)); + } else { + ASSERT_TRUE(false); + } + } +} + +TEST_F(OpConverterTest, ConvertArgMinMax) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_argmax", "ArgMax", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "ArgMax got 0 inputs but expected 2, at my_argmax"); + } + { + // Dimension is a tensor, should fail. + Reset(); + NodeDef node_def = GetArgMinMaxNodeDef(DT_FLOAT, DT_INT32); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("dimension", {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"dimension\" for ArgMax must be a constant, at my_arg"); + } + { + // Output type is INT64, should fail. + Reset(); + NodeDef node_def = GetArgMinMaxNodeDef(DT_FLOAT, DT_INT64); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("dimension", {1}, {3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Output type int64 is not supported, at my_arg"); + } + { + // Axis is batch dimension, should fail + Reset(); + NodeDef node_def = GetArgMinMaxNodeDef(DT_FLOAT, DT_INT32); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("dimension", {1}, {0}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the batch dimension, at " + "my_arg"); + } + + TestConvertArgMinMax(this); + TestConvertArgMinMax(this); + TestConvertArgMinMax(this); + TestConvertArgMinMax(this); + // TRT does not support int32 for TopK layer which is used to implement ArgMin + // and ArgMax. + // TestConvertArgMinMax(this); + // TestConvertArgMinMax(this); +} + +// Get the NodeDef for DepthToSpace or SpaceToSpace. 
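// --- Illustrative sketch (hypothetical, simplified to NCHW and the DCR
// layout) of how a DepthToSpace with block size r can be lowered onto
// TensorRT shuffle layers, which is what the expected_output values in the
// tests below encode: the per-example tensor [r*r*C, H, W] is reshaped to
// [r, r, C, H, W], permuted to [C, H, r, W, r], and reshaped again to
// [C, H*r, W*r]; SpaceToDepth runs the same pipeline in reverse.
nvinfer1::ITensor* DepthToSpaceNCHW(nvinfer1::INetworkDefinition* network,
                                    nvinfer1::ITensor* input, int r, int c,
                                    int h, int w) {
  // First shuffle: split channels into [r, r, c] and move the two block
  // factors next to the spatial dimensions, giving [c, h, r, w, r].
  nvinfer1::IShuffleLayer* split = network->addShuffle(*input);
  nvinfer1::Dims split_dims;
  split_dims.nbDims = 5;
  split_dims.d[0] = r;
  split_dims.d[1] = r;
  split_dims.d[2] = c;
  split_dims.d[3] = h;
  split_dims.d[4] = w;
  split->setReshapeDimensions(split_dims);
  split->setSecondTranspose(nvinfer1::Permutation{{2, 3, 0, 4, 1}});
  // Second shuffle: fold the block factors into height and width.
  nvinfer1::IShuffleLayer* merge = network->addShuffle(*split->getOutput(0));
  nvinfer1::Dims merge_dims;
  merge_dims.nbDims = 3;
  merge_dims.d[0] = c;
  merge_dims.d[1] = h * r;
  merge_dims.d[2] = w * r;
  merge->setReshapeDimensions(merge_dims);
  return merge->getOutput(0);
}
// --- End of sketch; the patch continues below.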
+template +NodeDef GetDepthSpaceShuffleNodeDef(DataType dtype, int block_size, + string data_format) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto attrs = OpType::DataFormat(data_format); + auto shuffle = OpType(s.WithOpName("my_shuffle"), input, block_size, attrs); + return shuffle.operation.node()->def(); +} + +template +struct DepthSpaceShuffleTestParams { + std::vector input_dims; + std::vector input_value; + int block_size; + string data_format; + std::vector expected_output_dims; + std::vector expected_output; +}; + +template +void TestConvertDepthSpaceShuffle( + OpConverterTest* test, + const std::vector>& params) { + for (int i = 0; i < params.size(); ++i) { + test->Reset(); + + NodeDef node_def = GetDepthSpaceShuffleNodeDef( + dtype, params[i].block_size, params[i].data_format); + test->AddTestTensor("input", params[i].input_dims, 1, + TfDataTypeToTrt(dtype)); + test->RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_shuffle", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(params[i].expected_output_dims, + output.tensor()->getDimensions()); + + DataVec input_data{{"input", test::AsTensor(params[i].input_value)}}; + DataVec output_data{{"my_shuffle", ConstructTensor( + params[i].expected_output.size())}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(params[i].expected_output)); + } +} + +template +void TestConvertDepthToSpace(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + const std::vector common_input = InitTestVector(16); + std::vector> params = { + { + /*input_shape=*/{4, 2, 2}, + /*input_value=*/common_input, + /*block_size=*/2, + /*data_format=*/"NCHW", + /*expected_output_dims=*/{1, 4, 4}, + /*expected_output=*/ + CastTestVector( + {0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}), + }, + { + /*input_shape=*/{2, 2, 4}, + /*input_value=*/common_input, + /*block_size=*/2, + /*data_format=*/"NHWC", + /*expected_output_dims=*/{4, 4, 1}, + /*expected_output=*/ + CastTestVector( + {0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}), + }, + { + /*input_shape=*/{16, 1, 1}, + /*input_value=*/common_input, + /*block_size=*/4, + /*data_format=*/"NCHW", + /*expected_output_dims=*/{1, 4, 4}, + /*expected_output=*/InitTestVector(16), + }, + { + /*input_shape=*/{2, 2, 8}, + /*input_value=*/InitTestVector(32), + /*block_size=*/2, + /*data_format=*/"NHWC", + /*expected_output_dims=*/{4, 4, 2}, + /*expected_output=*/CastTestVector({0, 1, 2, 3, 8, + 9, 10, 11, 4, 5, + 6, 7, 12, 13, 14, + 15, 16, 17, 18, 19, + 24, 25, 26, 27, 20, + 21, 22, 23, 28, 29, + 30, 31}), + }, + }; + + TestConvertDepthSpaceShuffle(test, params); +} + +TEST_F(OpConverterTest, ConvertDepthToSpace) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_shuffle", "DepthToSpace", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "DepthToSpace got 0 inputs but expected 1, at my_shuffle"); + } + { + // Input is a weight, should fail. 
+ Reset(); + NodeDef node_def = + GetDepthSpaceShuffleNodeDef(DT_FLOAT, 2, "NCHW"); + AddTestWeights("input", {4, 1, 1}, {1, 2, 3, 4}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"input\" for DepthToSpace must be a " + "tensor, at my_shuffle"); + } + { + // Input rank != 4 + Reset(); + NodeDef node_def = + GetDepthSpaceShuffleNodeDef(DT_FLOAT, 2, "NCHW"); + AddTestTensor("input", {16, 32}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "The input to DepthToSpace must be rank 4, at my_shuffle"); + } + { + // Channels not divisible by block_size, should fail. + Reset(); + NodeDef node_def = + GetDepthSpaceShuffleNodeDef(DT_FLOAT, 3, "NCHW"); + AddTestTensor("input", {16, 32, 32}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Number of channels must be divisible by " + "block_size*block_size, at my_shuffle"); + } + { + // Unsupported format, should fail. + Reset(); + NodeDef node_def = GetDepthSpaceShuffleNodeDef( + DT_FLOAT, 2, "NCHW_VECT_C"); + AddTestTensor("input", {16, 32, 32}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Data format NCHW_VECT_C is not supported, at my_shuffle"); + } + + TestConvertDepthToSpace(this); + TestConvertDepthToSpace(this); + TestConvertDepthToSpace(this); +} + +template +void TestConvertSpaceToDepth(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + const std::vector common_input = InitTestVector(16); + std::vector> params = { + { + /*input_shape=*/{1, 4, 4}, + /*input_value=*/common_input, + /*block_size=*/2, + /*data_format=*/"NCHW", + /*expected_output_dims=*/{4, 2, 2}, + /*expected_output=*/ + CastTestVector( + {0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15}), + }, + { + /*input_shape=*/{4, 4, 1}, + /*input_value=*/common_input, + /*block_size=*/2, + /*data_format=*/"NHWC", + /*expected_output_dims=*/{2, 2, 4}, + /*expected_output=*/ + CastTestVector( + {0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}), + }, + { + /*input_shape=*/{1, 4, 4}, + /*input_value=*/common_input, + /*block_size=*/4, + /*data_format=*/"NCHW", + /*expected_output_dims=*/{16, 1, 1}, + /*expected_output=*/InitTestVector(16), + }, + { + /*input_shape=*/{4, 4, 2}, + /*input_value=*/InitTestVector(32), + /*block_size=*/2, + /*data_format=*/"NHWC", + /*expected_output_dims=*/{2, 2, 8}, + /*expected_output=*/CastTestVector({0, 1, 2, 3, 8, + 9, 10, 11, 4, 5, + 6, 7, 12, 13, 14, + 15, 16, 17, 18, 19, + 24, 25, 26, 27, 20, + 21, 22, 23, 28, 29, + 30, 31}), + }, + }; + + TestConvertDepthSpaceShuffle(test, params); +} + +TEST_F(OpConverterTest, ConvertSpaceToDepth) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_shuffle", "SpaceToDepth", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "SpaceToDepth got 0 inputs but expected 1, at my_shuffle"); + } + { + // Input is a weight, should fail. + Reset(); + NodeDef node_def = + GetDepthSpaceShuffleNodeDef(DT_FLOAT, 2, "NCHW"); + AddTestWeights("input", {4, 1, 1}, {1, 2, 3, 4}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"input\" for SpaceToDepth must be a " + "tensor, at my_shuffle"); + } + { + // Input rank != 4 + Reset(); + NodeDef node_def = + GetDepthSpaceShuffleNodeDef(DT_FLOAT, 2, "NCHW"); + AddTestTensor("input", {16, 32}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "The input to SpaceToDepth must be rank 4, at my_shuffle"); + } + { + // Width not divisble by block_size, should fail. 
+ Reset(); + NodeDef node_def = + GetDepthSpaceShuffleNodeDef(DT_FLOAT, 3, "NCHW"); + AddTestTensor("input", {16, 9, 32}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Width and height must be divisible by " + "block_size, at my_shuffle"); + } + { + // Height not divisble by block_size, should fail. + Reset(); + NodeDef node_def = + GetDepthSpaceShuffleNodeDef(DT_FLOAT, 3, "NCHW"); + AddTestTensor("input", {16, 32, 9}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Width and height must be divisible by " + "block_size, at my_shuffle"); + } + { + // Unsupported format, should fail. + Reset(); + NodeDef node_def = GetDepthSpaceShuffleNodeDef( + DT_FLOAT, 2, "NCHW_VECT_C"); + AddTestTensor("input", {16, 32, 32}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Data format NCHW_VECT_C is not supported, at my_shuffle"); + } + + TestConvertSpaceToDepth(this); + TestConvertSpaceToDepth(this); + TestConvertSpaceToDepth(this); +} + +#if IS_TRT_VERSION_GE(5, 1, 2, 0) +// Get the NodeDef for ClipByValue. +NodeDef GetClipByValueNodeDef(DataType dtype) { + Scope s = Scope::NewRootScope(); + auto t = ops::Placeholder(s.WithOpName("t"), dtype); + auto clip_value_min = ops::Placeholder(s.WithOpName("clip_value_min"), dtype); + auto clip_value_max = ops::Placeholder(s.WithOpName("clip_value_max"), dtype); + auto clip = ops::ClipByValue(s.WithOpName("my_clip"), t, clip_value_min, + clip_value_max); + return clip.operation.node()->def(); +} + +template +void TestConvertClipByValue(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + struct TestParams { + std::vector dims; + std::vector input_value; + CType clip_value_min; + CType clip_value_max; + std::vector expected_output; + }; + + const std::vector common_input = InitTestVector(6); + std::vector params = { + { + /*dims=*/{1, 2, 3}, + /*input_value=*/common_input, + /*clip_value_min=*/CType(2), + /*clip_value_max=*/CType(5), + /*expected_output=*/ + {CType(2), CType(2), CType(2), CType(3), CType(4), CType(5)}, + }, + { + /*dims=*/{2, 1, 3}, + /*input_value=*/common_input, + /*clip_value_min=*/CType(-1), + /*clip_value_max=*/CType(8), + /*expected_output=*/common_input, + }, + }; + + for (int i = 0; i < params.size(); ++i) { + test->Reset(); + + NodeDef node_def = GetClipByValueNodeDef(dtype); + test->AddTestTensor("t", params[i].dims, 1, TfDataTypeToTrt(dtype)); + test->AddTestWeights("clip_value_min", {1}, + {params[i].clip_value_min}); + test->AddTestWeights("clip_value_max", {1}, + {params[i].clip_value_max}); + test->RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_clip", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(params[i].dims, output.tensor()->getDimensions()); + + DataVec input_data{{"t", test::AsTensor(params[i].input_value)}}; + DataVec output_data{ + {"my_clip", ConstructTensor(params[i].expected_output.size())}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertClipByValue) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_clip", "ClipByValue", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "ClipByValue got 0 inputs but expected 3, at my_clip"); + } + { + // Input is a weight, should fail. 
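// --- Illustrative sketch (hypothetical helper) of why the ClipByValue tests
// above sit behind IS_TRT_VERSION_GE(5, 1, 2, 0): TensorRT 5.1.2 added the
// kCLIP activation, so a clip can become a single activation layer with the
// constant clip bounds folded into alpha/beta, which is also why the
// clip_value_min/clip_value_max inputs have to be weights rather than
// tensors.
nvinfer1::ITensor* AddClip(nvinfer1::INetworkDefinition* network,
                           nvinfer1::ITensor* input, float clip_value_min,
                           float clip_value_max) {
  nvinfer1::IActivationLayer* act =
      network->addActivation(*input, nvinfer1::ActivationType::kCLIP);
  act->setAlpha(clip_value_min);  // Lower clipping bound.
  act->setBeta(clip_value_max);   // Upper clipping bound.
  return act->getOutput(0);
}
// --- End of sketch; the patch continues below.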
+ Reset(); + NodeDef node_def = GetClipByValueNodeDef(DT_FLOAT); + AddTestWeights("t", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("clip_value_min", {1}, {1}); + AddTestWeights("clip_value_max", {1}, {5}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"t\" for ClipByValue must be a " + "tensor, at my_clip"); + } + { + // Clip min is a tensor, should fail. + Reset(); + NodeDef node_def = GetClipByValueNodeDef(DT_FLOAT); + AddTestTensor("t", {1, 2, 3}); + AddTestTensor("clip_value_min", {1}); + AddTestWeights("clip_value_max", {1}, {1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"clip_value_min\" for ClipByValue " + "must be a constant, at my_clip"); + } + { + // Clip max is a tensor, should fail. + Reset(); + NodeDef node_def = GetClipByValueNodeDef(DT_FLOAT); + AddTestTensor("t", {1, 2, 3}); + AddTestWeights("clip_value_min", {1}, {1}); + AddTestTensor("clip_value_max", {1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"clip_value_max\" for ClipByValue " + "must be a constant, at my_clip"); + } + + TestConvertClipByValue(this); + TestConvertClipByValue(this); +} +#endif // IS_TRT_VERSION_GE(5, 1, 2, 0) + +// Get the NodeDef for SquaredDifference. +NodeDef GetSquaredDifferenceNodeDef(DataType dtype) { + Scope s = Scope::NewRootScope(); + auto x = ops::Placeholder(s.WithOpName("x"), dtype); + auto y = ops::Placeholder(s.WithOpName("y"), dtype); + auto squared_diff = + ops::SquaredDifference(s.WithOpName("my_squared_diff"), x, y); + return squared_diff.operation.node()->def(); +} + +template +void TestConvertSquaredDifference(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + struct TestParams { + std::vector dims_x; + std::vector dims_y; + std::vector value_x; + std::vector value_y; + std::vector expected_output_dims; + std::vector expected_output; + }; + + const std::vector common_input = InitTestVector(6); + std::vector params = { + { + /*dims_x=*/{1, 2, 3}, + /*dims_y=*/{1, 2, 3}, + /*value_x=*/common_input, + /*value_y=*/CastTestVector({0, -1, 3, 0, 10, -7}), + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/CastTestVector({0, 4, 1, 9, 36, 144}), + }, + { + /*dims_x=*/{1, 2, 3}, + /*dims_y=*/{1, 1, 3}, + /*value_x=*/common_input, + /*value_y=*/CastTestVector({0, 1, 2}), + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/CastTestVector({0, 0, 0, 9, 9, 9}), + }, + }; + + for (int i = 0; i < params.size(); ++i) { + test->Reset(); + + NodeDef node_def = GetSquaredDifferenceNodeDef(dtype); + test->AddTestTensor("x", params[i].dims_x, 1, TfDataTypeToTrt(dtype)); + test->AddTestTensor("y", params[i].dims_y, 1, TfDataTypeToTrt(dtype)); + test->RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_squared_diff", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(params[i].expected_output_dims, + output.tensor()->getDimensions()); + + DataVec input_data{{"x", test::AsTensor(params[i].value_x)}, + {"y", test::AsTensor(params[i].value_y)}}; + DataVec output_data{ + {"my_squared_diff", + ConstructTensor(params[i].expected_output.size())}}; + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertSquaredDifference) { + { + // Input list is empty, should fail. 
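// --- Illustrative sketch (hypothetical helper) of the decomposition the
// SquaredDifference tests above encode: there is no dedicated TensorRT layer
// for it, so (x - y)^2 is built from two elementwise layers, with operand
// broadcasting worked out before the layers are added, which is where a
// shape pair that cannot be broadcast gets rejected.
nvinfer1::ITensor* AddSquaredDifference(nvinfer1::INetworkDefinition* network,
                                        nvinfer1::ITensor* x,
                                        nvinfer1::ITensor* y) {
  nvinfer1::IElementWiseLayer* sub =
      network->addElementWise(*x, *y, nvinfer1::ElementWiseOperation::kSUB);
  nvinfer1::ITensor* diff = sub->getOutput(0);
  nvinfer1::IElementWiseLayer* square = network->addElementWise(
      *diff, *diff, nvinfer1::ElementWiseOperation::kPROD);
  return square->getOutput(0);
}
// --- End of sketch; the patch continues below.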
+ NodeDef node_def = MakeNodeDef("my_squared_diff", "SquaredDifference", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "SquaredDifference got 0 inputs but expected 2, at my_squared_diff"); + } + { + // Input is a weight, should fail. + Reset(); + NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT); + AddTestWeights("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestTensor("y", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"x\" for SquaredDifference must be " + "a tensor, at my_squared_diff"); + } + { + // Shapes are not broadcastable, should fail. + Reset(); + NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT); + AddTestTensor("x", {2, 3}); + AddTestTensor("y", {7, 5}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Infeasible broadcast scheme"); + } + + TestConvertSquaredDifference(this); + TestConvertSquaredDifference(this); +} + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +template +NodeDef MakeResizeNodeDef(std::string name, DataType dtype, + bool align_corners) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32); + auto attrs = typename OpType::Attrs().AlignCorners(align_corners); + auto resize = OpType(s.WithOpName(name), input, size, attrs); + return resize.operation.node()->def(); +} + +template +struct ResizeTestParams { + std::vector input_dims; + std::vector output_resize_dims; + std::vector input_values; + bool align_corners; + std::vector expected_output_dims; + std::vector expected_nearest_output_values; + std::vector expected_bilinear_output_values; +}; + +template +void TestConvertResize(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + std::vector> params{ + { + /*input_dims=*/{1, 2, 1}, // H, W, C + /*output_resize_dims=*/{2, 3}, // H_out, W_out + /*input_values=*/CastTestVector({2.0f, -1.0f}), + /*align_corners=*/false, + /*expected_output_dims=*/{2, 3, 1}, // H, W, C + /*expected_nearest_output_values=*/ + CastTestVector({2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f}), + /*expected_bilinear_output_values=*/ + CastTestVector({2.0f, 0.f, -1.0f, 2.0f, 0.f, -1.0f}), + }, + { + /*input_dims=*/{1, 2, 1}, // H, W, C + /*output_resize_dims=*/{2, 3}, // H_out, W_out + /*input_values=*/CastTestVector({2.0f, -1.0f}), + /*align_corners=*/true, + /*expected_output_dims=*/{2, 3, 1}, // H, W, C + /*expected_nearest_output_values=*/ + CastTestVector({2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f}), + /*expected_bilinear_output_values=*/ + CastTestVector({2.0f, 0.5f, -1.0f, 2.0f, 0.5f, -1.0f}), + }}; + + for (int i = 0; i < params.size(); ++i) { + test->Reset(); + // Create resize node. + NodeDef node_def = + MakeResizeNodeDef("my_resize", dtype, params[i].align_corners); + // Create input tensor + test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1, + /*trt_dtype=*/TfDataTypeToTrt(dtype)); + // Create output size. + test->AddTestWeights("size", {2}, params[i].output_resize_dims); + + test->RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_resize", &output)); + + // Create input data for tensors. + const DataVec input_data{ + {"input", test::AsTensor(params[i].input_values)}}; + DataVec output_data{ + {"my_resize", ConstructTensor( + params[i].expected_nearest_output_values.size())}}; + + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? 
TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + + if (node_def.op() == "ResizeBilinear") { + ExpectArrayAlmostEqual(params[i].expected_bilinear_output_values, + GetSpanForData(output_data[0]), + CType(1e-3)); + } else if (node_def.op() == "ResizeNearestNeighbor") { + ExpectArrayAlmostEqual(params[i].expected_nearest_output_values, + GetSpanForData(output_data[0]), + CType(1e-3)); + } + } +} + +TEST_F(OpConverterTest, ConvertResize) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_resize", "ResizeBilinear", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "ResizeBilinear got 0 inputs but expected 2, at my_resize"); + } + { + // First input is weight, should fail. + Reset(); + NodeDef node_def = + MakeResizeNodeDef("my_resize", DT_FLOAT, false); + AddTestWeights("input", {1, 2}, {1, 2}); + AddTestWeights("size", {1, 2}, {1, 2}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"input\" for ResizeBilinear must be a " + "tensor, at my_resize"); + } + { + // output dimension is a tensor, should fail. + Reset(); + NodeDef node_def = + MakeResizeNodeDef("my_resize", DT_FLOAT, false); + AddTestTensor("input", {1, 2}); + AddTestTensor("size", {1, 2}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"size\" for ResizeBilinear must be a " + "constant, at my_resize"); + } + TestConvertResize(this); + TestConvertResize(this); + TestConvertResize(this); + TestConvertResize(this); +} +#endif // IS_TRT_VERSION_GE(6, 0, 0, 0) + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index 0ca3a5a4a58..ca21c193d63 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -21,19 +21,6 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -bool IsGoogleTensorRTEnabled() { - // TODO(laigd): consider also checking if tensorrt shared libraries are - // accessible. We can then direct users to this function to make sure they can - // safely write code that uses tensorrt conditionally. E.g. if it does not - // check for for tensorrt, and user mistakenly uses tensorrt, they will just - // crash and burn. -#if GOOGLE_CUDA && GOOGLE_TENSORRT - return true; -#else - return false; -#endif -} - Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name) { switch (mode) { case TrtPrecisionMode::FP32: diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 0aa602dda2f..91c8c660f85 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -33,8 +33,6 @@ struct TrtDestroyer { template using TrtUniquePtrType = std::unique_ptr>; -bool IsGoogleTensorRTEnabled(); - enum class TrtPrecisionMode { FP32, FP16, INT8 }; Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 0f800d7cf26..51ac9528864 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" @@ -40,8 +41,8 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { @@ -54,11 +55,20 @@ using ::nvinfer1::IRuntime; // Helps simultaneous execution of native and TRT engines. class AsyncHelper : public core::RefCounted { public: - AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; } - ~AsyncHelper() override { done_(); } + AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {} + + ~AsyncHelper() override { this->operator()(); } + + void operator()() { + if (!called_) { + done_(); + called_ = true; + } + } private: AsyncOpKernel::DoneCallback done_; + bool called_ = false; // Has `done_` been called? }; // This OP can construct TRTEngine on the fly and if construction of engine @@ -170,17 +180,11 @@ Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) { " can't be found in function library"); } FunctionLibraryRuntime::InstantiateOptions inst_ops; - inst_ops.overlay_lib = nullptr; inst_ops.state_handle = ""; inst_ops.target = ctx->device()->name(); native_func_ = 0; - auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), - inst_ops, &native_func_); - if (!status.ok()) { - LOG(ERROR) << " Instantiating native function " << funcdef_name_ - << " failed!"; - } - return status; + return lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), inst_ops, + &native_func_); } TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) @@ -241,25 +245,16 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper) { - if (funcdef_name_.empty()) { - const string err_msg = StrCat("Fallback path is disabled, for ", name()); - LOG(WARNING) << err_msg; - ctx->SetStatus(errors::Internal(err_msg)); - return; - } + OP_REQUIRES_ASYNC(ctx, !funcdef_name_.empty(), + errors::Internal("Fallback path is disabled, for ", name()), + *helper); std::vector inputs; std::vector* outputs = new std::vector(); if (native_func_ == kInvalidHandle) { - auto status = ConstructFunctionHandle(ctx); - if (!status.ok()) { - LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_; - ctx->SetStatus(status); - return; - } + OP_REQUIRES_OK_ASYNC(ctx, ConstructFunctionHandle(ctx), *helper); } auto lib = ctx->function_library(); FunctionLibraryRuntime::Options opts; - opts.step_id = ctx->step_id(); opts.rendezvous = ctx->rendezvous(); opts.cancellation_manager = ctx->cancellation_manager(); opts.runner = ctx->runner(); @@ -272,12 +267,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, lib->Run(opts, native_func_, inputs, outputs, [this, ctx, outputs, helper](const Status& s) { core::ScopedUnref sc(helper); - if (!s.ok()) { - LOG(ERROR) << "Failed to execute native segment " << this->name() - << ": " << s; - ctx->SetStatus(s); - return; - } + OP_REQUIRES_OK_ASYNC(ctx, s, *helper); VLOG(1) << "Native Segment completed"; for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); @@ -291,17 +281,19 @@ void 
TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, VLOG(1) << "Executing TRT calibration: " << name(); helper->Ref(); core::ScopedUnref sc(helper); - auto res_mgr = ctx->resource_manager(); TRTCalibrationResource* calib_res = nullptr; - OP_REQUIRES_OK(ctx, - res_mgr->LookupOrCreate( - "TF_TRT_Calibration", name(), - reinterpret_cast(&calib_res), - {[ctx, this](SerializableResourceBase** cr) -> Status { - return this->AllocateCalibrationResources(ctx, cr); - }})); + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->resource_manager()->LookupOrCreate( + "TF-TRT-Calibration", name(), + reinterpret_cast(&calib_res), + {[ctx, this](SerializableResourceBase** cr) -> Status { + return this->AllocateCalibrationResources(ctx, cr); + }}), + *helper); core::ScopedUnref calib_sc(calib_res); int num_inputs = ctx->num_inputs(); + // TODO(laigd): need to check that input shape matches. // Pass input data to calibrator std::unordered_map input_data; for (int i = 0; i < num_inputs; i++) { @@ -325,7 +317,17 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, ->stream() ->implementation() ->GpuStreamMemberHack())); - calib_res->calibrator_->setBatch(input_data, *stream); + // If calibrator is terminated before, it means an error has occurred. + // + // Note: setBatch() will wait until TRTInt8Calibrator::getBatch() is called + // the first time before proceeding, so if buildCudaEngine() returns an error, + // it means getBatch() is never called, and the setBatch() here will hang + // until setDone() is called later by the calibration thread in + // AllocateCalibrationResources(). In that case, this setBatch() will always + // be able to detect the error and return false. + OP_REQUIRES_ASYNC(ctx, calib_res->calibrator_->setBatch(input_data, *stream), + errors::Internal("Failed to feed calibration data"), + *helper); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); } @@ -376,9 +378,9 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } EngineContext* engine_context = GetEngine(input_shapes, ctx); if (!engine_context->cuda_engine) { - LOG(WARNING) << "Engine retrieval for input shapes: " - << TensorShapeUtils::ShapeListString(input_shapes) - << " failed. Running native segment for " << name(); + VLOG(1) << "Engine retrieval for input shapes: " + << TensorShapeUtils::ShapeListString(input_shapes) + << " failed. Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } @@ -426,8 +428,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, const_cast(input_tensor.flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(ERROR) << "FP16 inputs are not supported yet!"; - return kRetry; + buffers[binding_index] = + const_cast(input_tensor.flat().data()); + break; case nvinfer1::DataType::kINT8: LOG(ERROR) << "INT8 inputs are not supported yet!"; return kRetry; @@ -481,8 +484,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, const_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(WARNING) << "half size is not supported yet!"; - return kRetry; + buffers[binding_index] = + const_cast(output_tensor->flat().data()); + break; case nvinfer1::DataType::kINT8: LOG(WARNING) << "int8 is not supported yet!"; return kRetry; @@ -523,18 +527,33 @@ EngineContext* TRTEngineOp::GetEngine( // TODO(tmorris): using first input to get batch size - is this reliable? const int batch_size = input_shapes[0].dim_size(0); - // Get engine cache + // Canonicalize the op name by removing the scopes if any. 
This is mainly + // because in TFv2, the function graph can be instantiated in various ways and + // it'll insert scope names to the name of the TRTEngineOps, which will result + // in many different engine caches if we use the instantiated op name + // directly, but we still want all of them share the same cache (if they were + // representing the same subgraph). + absl::string_view resource_name = name(); + size_t last_slash = resource_name.find_last_of('/'); + if (last_slash != absl::string_view::npos) { + resource_name.remove_prefix(last_slash + 1); + } + + // Get engine cache. TRTEngineCacheResource* cache_res = nullptr; auto status = ctx->resource_manager()->LookupOrCreate( - "TRTEngineCache", name(), &cache_res, + "TF-TRT-Engine-Cache", string(resource_name), &cache_res, {[this, ctx](TRTEngineCacheResource** cr) -> Status { *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_); return Status::OK(); }}); if (!status.ok()) { - ctx->SetStatus(status); + LOG(WARNING) << "Not able to find or create engine cache for " << name() + << ". The native segment will be used instead. " + << "Reason: " << status; return &empty_context; } + core::ScopedUnref sc(cache_res); auto& cache = cache_res->cache_; auto allocator = cache_res->allocator_.get(); @@ -625,22 +644,21 @@ EngineContext* TRTEngineOp::GetEngine( partial_shapes, &logger, allocator, calibrator_.get(), &engine, use_calibration_, &convert_successfully); if (!status.ok()) { - if (convert_successfully) { - // This means it fail to build the engine even when the network is built - // successfully, probably due to internal issues. In this case we don't - // retry in the future. - cache.emplace(engine_input_shapes, absl::make_unique()); - } - LOG(WARNING) << "Engine creation for batch size " << batch_size - << " failed " << status; + LOG(WARNING) << "Engine creation for " << name() << " failed. " + << "The native segment will be used instead. " + << "Reason: " << status; + // Store an empty engine in the cache for these input shapes so we don't + // try to build the same failing engine again. + cache.emplace(engine_input_shapes, absl::make_unique()); return &empty_context; } - VLOG(1) << "Conversion is done"; TrtUniquePtrType exec_context( engine->createExecutionContext()); cache.emplace(engine_input_shapes, absl::make_unique(std::move(engine), std::move(exec_context))); + VLOG(1) << "Added new engine to cache of " << name() + << ". Cache size: " << cache.size(); } return cache.at(engine_input_shapes).get(); } diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc new file mode 100644 index 00000000000..b62fdc5dc4b --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -0,0 +1,106 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include +#include + +#include +#include +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" + +namespace tensorflow { +namespace tensorrt { +using ::testing::ElementsAre; + +template +class TRTEngineOpTest : public OpsTestBase {}; + +using TypeList = ::testing::Types; +TYPED_TEST_SUITE(TRTEngineOpTest, TypeList); + +TYPED_TEST(TRTEngineOpTest, Basic) { + DataType dtype = DataTypeToEnum::v(); + // Create the GPU device. + std::unique_ptr device( + DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0")); + + // Create simple TF graph. + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype, + ops::Placeholder::Shape({1, 2})); + auto add = ops::Add(s.WithOpName("add"), feed, feed); + ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add); + + // Serialize the graph. TRTEngineOp will convert it using dynamic mode. + GraphDef graph_def; + TF_ASSERT_OK(s.ToGraphDef(&graph_def)); + TensorShapeProto shape; + TensorShape({1, 2}).AsProto(&shape); + + // Create the op. + OpsTestBase::SetDevice(DEVICE_GPU, std::move(device)); + TF_ASSERT_OK(NodeDefBuilder("op", "TRTEngineOp") + .Input(FakeInput(1, dtype)) + .Attr("input_shapes", {shape}) + .Attr("output_shapes", {shape}) + .Attr("static_engine", false) + .Attr("segment_funcdef_name", "") // no native fallback + .Attr("serialized_segment", graph_def.SerializeAsString()) + .Attr("calibration_data", "") + .Attr("max_cached_engines_count", 1) + .Attr("workspace_size_bytes", 1 << 20) + .Attr("precision_mode", "FP32") + .Attr("use_calibration", false) + .Attr("OutT", {dtype}) + .Finalize(OpsTestBase::node_def())); + TF_ASSERT_OK(OpsTestBase::InitOp()); + + // Execute the op. + OpsTestBase::AddInputFromArray(TensorShape({1, 2}), + {TypeParam(0.0f), TypeParam(1.0f)}); + TF_ASSERT_OK(OpsTestBase::RunOpKernel()); + + // Verify the result. + // TODO(laigd): OpsTestBase::GetOutput() doesn't work. + Tensor* output = OpsTestBase::context_->mutable_output(0); + const auto& tensor_map = output->flat(); + std::vector output_data(tensor_map.size()); + ASSERT_EQ(0, cudaDeviceSynchronize()); + ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(), + sizeof(TypeParam) * tensor_map.size(), + cudaMemcpyDeviceToHost)); + EXPECT_THAT(absl::Span(output_data), + ElementsAre(TypeParam(0.0f), TypeParam(2.0f))); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc new file mode 100644 index 00000000000..a41a8a2c1c4 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -0,0 +1,223 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +using ::nvinfer1::IRuntime; + +class CreateTRTEngineCache : public OpKernel { + public: + explicit CreateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_)); + } + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "Creating TRT engine cache resource in container " << container_ + << " for op " << resource_name_ << " on device " + << ctx->device()->name(); + OP_REQUIRES_OK(ctx, + ctx->resource_manager()->Create( + container_, resource_name_, + new TRTEngineCacheResource(ctx, max_cached_engines_))); + + Tensor* handle; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); + handle->scalar()() = + MakeResourceHandle(ctx, container_, + resource_name_); + } + + private: + string container_; + string resource_name_; + + // Maximum number of cached engines + int max_cached_engines_; + + TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCache); +}; + +REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCache") + .Device(DEVICE_GPU) + .HostMemory("engine_cache_handle"), + CreateTRTEngineCache); + +class PopulateTRTEngineCache : public OpKernel { + public: + explicit PopulateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + ResourceHandle handle = HandleFromInput(ctx, 0); + TRTEngineCacheResource* resource = nullptr; + OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource)); + core::ScopedUnref unref_me(resource); + + auto allocator = resource->allocator_.get(); + OP_REQUIRES(ctx, allocator != nullptr, + errors::Internal("Not able to initialize TRT engine cache when " + "GPU allocator is empty.")); + OP_REQUIRES(ctx, resource->cache_.size() == 0, + errors::Internal("Expect engine cache to be empty, but got ", + resource->cache_.size(), " entries.")); + + // Get the file name. 
+ const string& filename = ctx->input(1).scalar()(); + OP_REQUIRES(ctx, !filename.empty(), + errors::InvalidArgument("filename cannot be empty.")); + + // Parse the serialized engines and add them to the cache. + std::unique_ptr file; + OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file)); + auto reader = absl::make_unique(file.get()); + + uint64 offset = 0; + int num_loaded_engine = 0; + do { + string record; + Status status = reader->ReadRecord(&offset, &record); + if (errors::IsOutOfRange(status)) break; + + TRTEngineInstance engine_instance; + engine_instance.ParseFromString(record); + std::vector engine_input_shapes; + for (const TensorShapeProto& shape : engine_instance.input_shapes()) { + engine_input_shapes.emplace_back(shape); + } + + TrtUniquePtrType infer( + nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger())); + infer->setGpuAllocator(allocator); + TrtUniquePtrType engine( + infer->deserializeCudaEngine( + engine_instance.serialized_engine().c_str(), + engine_instance.serialized_engine().size(), + PluginFactoryTensorRT::GetInstance())); + auto raw_engine = engine.get(); + resource->cache_.emplace( + engine_input_shapes, + absl::make_unique( + std::move(engine), TrtUniquePtrType( + raw_engine->createExecutionContext()))); + ++num_loaded_engine; + } while (1); + VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines to container " + << handle.container() << " for op " << handle.name() + << " on device " << ctx->device()->name() << " from file " + << filename; + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(PopulateTRTEngineCache); +}; + +REGISTER_KERNEL_BUILDER(Name("PopulateTRTEngineCache") + .Device(DEVICE_GPU) + .HostMemory("engine_cache_handle"), + PopulateTRTEngineCache); + +class DumpTRTEngineCache : public OpKernel { + public: + explicit DumpTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump", + &delete_cache_after_dump_)); + } + + void Compute(OpKernelContext* ctx) override { + const string& container = ctx->input(0).scalar()(); + const string& resource_name = ctx->input(1).scalar()(); + const string& filename = ctx->input(2).scalar()(); + OP_REQUIRES(ctx, !filename.empty(), + errors::InvalidArgument("filename cannot be empty.")); + + TRTEngineCacheResource* resource = nullptr; + OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup( + container, resource_name, &resource)); + core::ScopedUnref unref_me(resource); + + // Serialize the engines and write them to file. + std::unique_ptr file; + OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file)); + auto writer = absl::make_unique(file.get()); + + for (const auto& pair : resource->cache_) { + TRTEngineInstance engine_instance; + // Add input shapes. + const std::vector& engine_input_shapes = pair.first; + for (const TensorShape& shape : engine_input_shapes) { + shape.AsProto(engine_instance.add_input_shapes()); + } + // Add the serialized engine. 
+ const std::unique_ptr& engine = pair.second; + TrtUniquePtrType engine_data( + engine->cuda_engine->serialize()); + engine_instance.set_serialized_engine(engine_data->data(), + engine_data->size()); + + OP_REQUIRES_OK(ctx, + writer->WriteRecord(engine_instance.SerializeAsString())); + } + VLOG(1) << "Serialized " << resource->cache_.size() + << " TRT engines in container " << container << " for op " + << resource_name << " on device " << ctx->device()->name() + << " to file " << filename; + + if (delete_cache_after_dump_) { + VLOG(1) << "Destroying TRT engine cache resource in container " + << container << " for op " << resource_name << " on device " + << ctx->device()->name(); + OP_REQUIRES_OK(ctx, + ctx->resource_manager()->Delete( + container, resource_name)); + } + } + + private: + bool delete_cache_after_dump_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(DumpTRTEngineCache); +}; + +REGISTER_KERNEL_BUILDER(Name("DumpTRTEngineCache").Device(DEVICE_GPU), + DumpTRTEngineCache); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc new file mode 100644 index 00000000000..5281433ffc4 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc @@ -0,0 +1,205 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class TRTEngineResourceOpsTest : public OpsTestBase { + protected: + void Reset() { + inputs_.clear(); + gtl::STLDeleteElements(&tensors_); + gtl::STLDeleteElements(&managed_outputs_); + } + + TrtUniquePtrType CreateTRTEngine() { + Logger logger; + TrtUniquePtrType builder( + nvinfer1::createInferBuilder(logger)); + TrtUniquePtrType network( + builder->createNetwork()); + + // Add the input. 
+ nvinfer1::Dims dims; + dims.nbDims = 1; + dims.d[0] = 1; + nvinfer1::ITensor* input = + network->addInput("input", nvinfer1::DataType::kFLOAT, dims); + EXPECT_NE(nullptr, input); + + // Add a unary layer. + nvinfer1::IUnaryLayer* layer = + network->addUnary(*input, nvinfer1::UnaryOperation::kEXP); + EXPECT_NE(nullptr, layer); + + // Mark the output. + nvinfer1::ITensor* output = layer->getOutput(0); + output->setName("output"); + network->markOutput(*output); + + // Build the engine + builder->setMaxBatchSize(1); + builder->setMaxWorkspaceSize(1 << 10); + TrtUniquePtrType engine( + builder->buildCudaEngine(*network)); + EXPECT_NE(nullptr, engine); + return engine; + } +}; + +TEST_F(TRTEngineResourceOpsTest, Basic) { + // Create the GPU device. + std::unique_ptr device( + DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0")); + ResourceMgr* rm = device->resource_manager(); + SetDevice(DEVICE_GPU, std::move(device)); + + // Create the resource. + const string container = "mycontainer"; + const string resource_name = "myresource"; + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache") + .Attr("container", container) + .Attr("resource_name", resource_name) + .Attr("max_cached_engines_count", 1) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + TF_ASSERT_OK(RunOpKernel()); + ResourceHandle handle = + context_->mutable_output(0)->scalar()(); + + TRTEngineCacheResource* resource = nullptr; + EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok()); + + // Create a serialized TRT engine file. + TrtUniquePtrType engine = CreateTRTEngine(); + TrtUniquePtrType context( + engine->createExecutionContext()); + resource->cache_.emplace( + std::vector{TensorShape({1, 1})}, + absl::make_unique(std::move(engine), std::move(context))); + resource->Unref(); + + // Serialize the engine using DumpTRTEngineCache op. + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "DumpTRTEngineCache") + .Attr("delete_cache_after_dump", true) + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({}), {container}); + AddInputFromArray(TensorShape({}), {resource_name}); + const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file"); + AddInputFromArray(TensorShape({}), {filename}); + TF_ASSERT_OK(RunOpKernel()); + + // Make sure the cache is deleted. + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp") + .Attr("ignore_lookup_error", false) + .Input(FakeInput(DT_RESOURCE)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({}), {handle}); + EXPECT_TRUE(errors::IsNotFound(RunOpKernel())); + + // Verify the serialized engine file. + Env* env = Env::Default(); + std::unique_ptr file; + TF_ASSERT_OK(env->NewRandomAccessFile(filename, &file)); + auto reader = absl::make_unique(file.get()); + uint64 offset = 0; + string record; + TF_ASSERT_OK(reader->ReadRecord(&offset, &record)); + TRTEngineInstance engine_instance; + engine_instance.ParseFromString(record); + EXPECT_EQ(1, engine_instance.input_shapes_size()); + EXPECT_EQ(2, engine_instance.input_shapes(0).dim_size()); + EXPECT_EQ(1, engine_instance.input_shapes(0).dim(0).size()); + EXPECT_EQ(1, engine_instance.input_shapes(0).dim(1).size()); + EXPECT_TRUE(errors::IsOutOfRange(reader->ReadRecord(&offset, &record))); + + // Recreate the cache resource. 
+ Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache") + .Attr("container", container) + .Attr("resource_name", resource_name) + .Attr("max_cached_engines_count", 1) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + TF_ASSERT_OK(RunOpKernel()); + handle = context_->mutable_output(0)->scalar()(); + EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok()); + EXPECT_EQ(0, resource->cache_.size()); + resource->Unref(); + + // Deserialize the engine using PopulateTRTEngineCache op. + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache") + .Input(FakeInput(DT_RESOURCE)) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({}), {handle}); + AddInputFromArray(TensorShape({}), {filename}); + TF_ASSERT_OK(RunOpKernel()); + EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok()); + EXPECT_EQ(1, resource->cache_.size()); + resource->Unref(); + + // Destroy the engine cache again. + Reset(); + TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp") + .Attr("ignore_lookup_error", false) + .Input(FakeInput(DT_RESOURCE)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({}), {handle}); + TF_ASSERT_OK(RunOpKernel()); + EXPECT_TRUE(errors::IsNotFound(RunOpKernel())); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index b84d2fe0b8c..791ddc41b4f 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -24,12 +24,9 @@ limitations under the License. namespace tensorflow { -namespace shape_inference { -extern Status TRTEngineOpShapeInference(InferenceContext* c); -} - -// NOTE: please try NOT to add/modify/remove attributes or inputs/outputs to the -// list below, this will break backward compatibility! +// NOTE: when making changes please follow +// https://www.tensorflow.org/guide/extend/op#backwards_compatibility to not +// break backward compatibility. // // TODO(laigd): consider making this op stateful. The only problem is it uses TF // function which has to be stateless, but we can use function library as the @@ -41,8 +38,6 @@ REGISTER_OP("TRTEngineOp") .Attr("segment_funcdef_name: string") .Attr("InT: list({int8,float16,float32,int32})") .Attr("OutT: list({int8,float16,float32,int32})") - .Attr("static_engine: bool = true") - .Attr("fixed_input_size: bool = true") .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("max_cached_engines_count: int = 1") .Attr("workspace_size_bytes: int") @@ -57,8 +52,10 @@ REGISTER_OP("TRTEngineOp") // implementation, we do require all input tensor to carry the same batch // size, but this could change in the future). Hence we disable shape // inference function as a workaround. - // .SetShapeFn(shape_inference::TRTEngineOpShapeInference); - .SetShapeFn(shape_inference::UnknownShape); + .SetShapeFn(shape_inference::UnknownShape) + // Deprecated attributes. 
+ .Attr("fixed_input_size: bool = true") + .Attr("static_engine: bool = true"); } // namespace tensorflow #endif // GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc new file mode 100644 index 00000000000..cf1909a0b47 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +REGISTER_OP("CreateTRTEngineCache") + .Attr("container: string") + .Attr("resource_name: string") + .Attr("max_cached_engines_count: int = 1") + .Output("engine_cache_handle: resource") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("PopulateTRTEngineCache") + .Input("engine_cache_handle: resource") + .Input("filename: string") + .SetIsStateful() + .SetShapeFn(shape_inference::NoOutputs); + +REGISTER_OP("DumpTRTEngineCache") + .Attr("delete_cache_after_dump: bool = false") + .Input("container: string") + .Input("resource_name: string") + .Input("filename: string") + .SetIsStateful() + .SetShapeFn(shape_inference::NoOutputs); + +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h index f495d857037..0a55aadb7df 100644 --- a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h @@ -24,7 +24,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h index cce4f52d9f1..b445eb9d107 100644 --- a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h @@ -27,13 +27,20 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: + // TODO(b/131313301): Delete this when IPluginFactory is fixed upstream. + // IPluginFactory defines virtual methods and no virtual destructor. To avoid + // a non-virtual-dtor error, we need to add a virtual destructor here. 
Do not + // use a pointer to IPluginFactory because deleting through such a pointer + // results in undefined behavior. + virtual ~PluginFactoryTensorRT() {} + // TODO(aaroey): this static method has to be inlined to make the singleton a // unique global symbol. Find a way to fix it. static PluginFactoryTensorRT* GetInstance() { diff --git a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc index 7d9c465c22b..99a144b1737 100644 --- a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc @@ -23,7 +23,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h index e5eff15c196..256f8dcedb2 100644 --- a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h @@ -23,7 +23,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py index 25fb3a13db9..019ebdc0d7b 100644 --- a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -21,26 +21,32 @@ from __future__ import print_function import threading import platform +from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled from tensorflow.python.framework import errors -_trt_ops_so = None +_tf_trt_so = None _module_lock = threading.Lock() def load_trt_ops(): """Load TF-TRT op libraries so if it hasn't been loaded already.""" - global _trt_ops_so + global _tf_trt_so + + if not is_tensorrt_enabled(): + return if platform.system() == "Windows": raise RuntimeError("Windows platforms are not supported") with _module_lock: - if _trt_ops_so: + if _tf_trt_so: return try: # pylint: disable=g-import-not-at-top,unused-variable - # This registers the TRT ops, it doesn't require loading TRT library. + # This will call register_op_list() in + # tensorflow/python/framework/op_def_registry.py, but it doesn't register + # the op or the op kernel in C++ runtime. from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op # pylint: enable=g-import-not-at-top,unused-variable except ImportError as e: @@ -48,16 +54,16 @@ def load_trt_ops(): "not built with CUDA or TensorRT enabled. ****") raise e - # TODO(laigd): we should load TF-TRT kernels here as well after removing the - # swig binding. try: # pylint: disable=g-import-not-at-top from tensorflow.python.framework import load_library from tensorflow.python.platform import resource_loader # pylint: enable=g-import-not-at-top - _trt_ops_so = load_library.load_op_library( - resource_loader.get_path_to_datafile("_trt_ops.so")) + # Loading the shared object will cause registration of the op and the op + # kernel if we link TF-TRT dynamically. + _tf_trt_so = load_library.load_op_library( + resource_loader.get_path_to_datafile("libtftrt.so")) except errors.NotFoundError as e: no_trt_message = ( "**** Failed to initialize TensorRT. 
This is either because the " diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 9cab9d70129..5d9a1b25210 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -443,6 +443,10 @@ Status SegmentGraph(const Graph* tf_graph, unsupported_ops.emplace(node->tf_node()->type_string()); num_unsupported_ops++; node = nullptr; + } else { + VLOG(2) << "Accepted as a TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name(); } } node_segments.emplace_back(node); @@ -455,7 +459,7 @@ Status SegmentGraph(const Graph* tf_graph, } LOG(INFO) << msg << "(For more information see " << "https://docs.nvidia.com/deeplearning" - << "/dgx/integrate-tf-trt/index.html#support-ops)."; + << "/dgx/tf-trt-user-guide/index.html#supported-ops)."; // The segmentation algorithm below visits nodes in reverse topological order // and attempts to merge nodes along output edges. That means that subgraphs @@ -668,10 +672,13 @@ Status SegmentGraph(const Graph* tf_graph, const string& segment_root = itr.first; // Return format does not require set comparator. std::set segment_nodes(itr.second.begin(), itr.second.end()); - if (VLOG_IS_ON(1)) { - string s = "parent=" + segment_root + ":"; - for (auto node : segment_nodes) s += " " + node->name(); - VLOG(1) << "Segment " << segments->size() << ": " << s; + if (VLOG_IS_ON(1) && !segment_nodes.empty()) { + string s; + for (auto node : segment_nodes) { + StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name()); + } + VLOG(1) << "Nodes in segment " << segments->size() + << " with parent=" << segment_root << ":" << s; } // Don't use small segments. diff --git a/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc index 769982c6456..510591bfe00 100644 --- a/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc +++ b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc @@ -20,9 +20,9 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "cuda/include/cuda.h" -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc new file mode 100644 index 00000000000..008cabb9cb4 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc @@ -0,0 +1,65 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h"
+
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+#include "third_party/tensorrt/NvInfer.h"
+#endif
+
+namespace tensorflow {
+namespace tensorrt {
+
+bool IsGoogleTensorRTEnabled() {
+  // TODO(laigd): consider also checking if tensorrt shared libraries are
+  // accessible. We can then direct users to this function to make sure they
+  // can safely write code that uses tensorrt conditionally. E.g. if it does
+  // not check for tensorrt, and the user mistakenly uses tensorrt, they will
+  // just crash and burn.
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+  return true;
+#else
+  return false;
+#endif
+}
+
+void GetLinkedTensorRTVersion(int* major, int* minor, int* patch) {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+  *major = NV_TENSORRT_MAJOR;
+  *minor = NV_TENSORRT_MINOR;
+  *patch = NV_TENSORRT_PATCH;
+#else
+  *major = 0;
+  *minor = 0;
+  *patch = 0;
+#endif
+}
+
+void GetLoadedTensorRTVersion(int* major, int* minor, int* patch) {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+  int ver = getInferLibVersion();
+  *major = ver / 1000;
+  ver = ver - *major * 1000;
+  *minor = ver / 100;
+  *patch = ver - *minor * 100;
+#else
+  *major = 0;
+  *minor = 0;
+  *patch = 0;
+#endif
+}
+
+}  // namespace tensorrt
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/default/fingerprint.h b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h
similarity index 51%
rename from tensorflow/core/platform/default/fingerprint.h
rename to tensorflow/compiler/tf2tensorrt/utils/py_utils.h
index f901befc16b..f52bb6f1bad 100644
--- a/tensorflow/core/platform/default/fingerprint.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h
@@ -1,4 +1,4 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,25 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_FINGERPRINT_H_
-#define TENSORFLOW_CORE_PLATFORM_DEFAULT_FINGERPRINT_H_
-
-#include
-
-#include "tensorflow/core/lib/core/stringpiece.h"
+#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_
+#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_
 
 namespace tensorflow {
+namespace tensorrt {
 
-inline uint64 Fingerprint64(StringPiece s) {
-  return ::util::Fingerprint64(s.data(), s.size());
-}
+bool IsGoogleTensorRTEnabled();
 
-inline Fprint128 Fingerprint128(StringPiece s) {
-  const auto fingerprint = ::util::Fingerprint128(s.data(), s.size());
-  return {::util::Uint128Low64(fingerprint),
-          ::util::Uint128High64(fingerprint)};
-}
+// Return compile time TensorRT library version information {Maj, Min, Patch}.
+void GetLinkedTensorRTVersion(int* major, int* minor, int* patch);
+// Return runtime TensorRT library version information {Maj, Min, Patch}.
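+// For example, a getInferLibVersion() value of 5105 is reported as {5, 1, 5};
+// all three values are 0 when TensorRT support is not compiled in.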
+void GetLoadedTensorRTVersion(int* major, int* minor, int* patch); + +} // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_FINGERPRINT_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_ diff --git a/tensorflow/python/compiler/tensorrt/trt_conversion.i b/tensorflow/compiler/tf2tensorrt/utils/py_utils.i similarity index 62% rename from tensorflow/python/compiler/tensorrt/trt_conversion.i rename to tensorflow/compiler/tf2tensorrt/utils/py_utils.i index 4d187c7988f..d6e8eac5836 100644 --- a/tensorflow/python/compiler/tensorrt/trt_conversion.i +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.i @@ -17,8 +17,6 @@ limitations under the License. %{ #define SWIG_FILE_WITH_INIT %} -%include "std_string.i" -%include "tensorflow/python/platform/base.i" %{ struct version_struct{ @@ -40,22 +38,8 @@ PyObject* version_helper(version_struct* in) { return tuple; } -/* Define converters for vector */ -template<> -bool _PyObjAs(PyObject *pyobj, int* dest) { - *dest = PyLong_AsLong(pyobj); - return true; -} - -template<> -PyObject *_PyObjFrom(const int& src) { - return PyLong_FromLong(src); -} - %} -_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); - %typemap(out) version_struct { PyObject *tuple = version_helper(&$1); if (!tuple) SWIG_fail; @@ -63,39 +47,29 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); } %{ -#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" -#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" %} -%ignoreall -%unignore tensorflow; -%unignore get_linked_tensorrt_version; -%unignore get_loaded_tensorrt_version; -%unignore is_tensorrt_enabled; +%ignore ""; +%rename("%s") get_linked_tensorrt_version; +%rename("%s") get_loaded_tensorrt_version; +%rename("%s") is_tensorrt_enabled; %{ version_struct get_linked_tensorrt_version() { // Return the version at the link time. version_struct s; -#if GOOGLE_CUDA && GOOGLE_TENSORRT - const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion(); - s.vmajor = lv[0]; - s.vminor = lv[1]; - s.vpatch = lv[2]; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + tensorflow::tensorrt::GetLinkedTensorRTVersion( + &s.vmajor, &s.vminor, &s.vpatch); return s; } version_struct get_loaded_tensorrt_version() { // Return the version from the loaded library. version_struct s; -#if GOOGLE_CUDA && GOOGLE_TENSORRT - const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion(); - s.vmajor = lv[0]; - s.vminor = lv[1]; - s.vpatch = lv[2]; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + tensorflow::tensorrt::GetLoadedTensorRTVersion( + &s.vmajor, &s.vminor, &s.vpatch); return s; } @@ -109,4 +83,4 @@ version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); bool is_tensorrt_enabled(); -%unignoreall +%rename("%s") ""; diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc index a18f758a551..8d2ae49a0d0 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc @@ -19,7 +19,7 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h index 8ec06d7456c..baab5aed35c 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h @@ -22,7 +22,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.proto b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.proto new file mode 100644 index 00000000000..e8394974478 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package tensorflow.tensorrt; + +import "tensorflow/core/framework/tensor_shape.proto"; + +// Containing information for a serialized TensorRT engine. +message TRTEngineInstance { + // The input shapes of the TRT engine. + repeated TensorShapeProto input_shapes = 1; + + // The serialized TRT engine. + // + // TODO(laigd): consider using a more efficient in-memory representation + // instead of string which is the default here. + bytes serialized_engine = 2; + + // TODO(laigd): consider adding calibration stats, precision_modes, etc. +} diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc index 33a5c719ba9..51aa7be07db 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc @@ -22,7 +22,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h index d34e244f6c7..76f3a0392f8 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h @@ -25,8 +25,8 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { @@ -57,6 +57,8 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { bool getBatch(void* bindings[], const char* names[], int num_bindings) override; + // Feed calibration data to the calibrator, and return true if the data is + // accepted. Return false if the calibrator has been terminated. bool setBatch(const std::unordered_map& data, const cudaStream_t stream); diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc index 6bc842ed5ca..a552f5160c6 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc @@ -48,7 +48,8 @@ void Logger::log(Severity severity, const char* msg) { // This is useless for now. But would catch it in future if enum changes. 
It // is always good to have default case! default: { - LOG(FATAL) << name_ << "Got unknown severity level from TRT " << msg; + LOG(FATAL) << name_ << "Got unknown severity level " << int(severity) + << " from TensorRT: " << msg; break; } } diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h index 22f4de970a8..2db9923d7dc 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h @@ -20,7 +20,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc new file mode 100644 index 00000000000..43dcd52b5a2 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" + +#include + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/mutex.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +Logger& TRTEngineCacheResource::GetLogger() { + static Logger* logger = new Logger(); + return *logger; +} + +TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx, + size_t capacity) + : cache_(capacity) { + auto device = ctx->device(); + auto alloc = device->GetAllocator(AllocatorAttributes()); + if (!alloc) { + LOG(ERROR) << "Can't find device allocator for gpu device " + << device->name(); + allocator_ = nullptr; + } else { + allocator_.reset(new TRTDeviceAllocator(alloc)); + } +} + +TRTEngineCacheResource::~TRTEngineCacheResource() { + VLOG(1) << "Destroying TRTEngineCacheResource..."; +} + +string TRTEngineCacheResource::DebugString() const { + std::stringstream oss; + using std::dec; + using std::endl; + using std::hex; + oss << "TRTEngineCacheResource: "; + oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", "; + oss << "LRUCache = " << hex << &cache_ << dec << endl; + oss << "Containing " << cache_.size() << " entries: " << endl; + for (const auto& item : cache_) { + mutex_lock lock(item.second->mu); + oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex + << "ICudaEngine: " << item.second->cuda_engine.get() << ", " + << "IExecutionContext: " << item.second->execution_context.get() << dec + << endl; + } + return oss.str(); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // 
GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h index 8ece326446d..442e0bcfb53 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -21,11 +21,12 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/errors.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" #endif // GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { @@ -100,17 +101,14 @@ class LRUCache { } // Creates n free positions in cache - Status DiscardOld(size_t n = 0) { - if (n > capacity_) { - return errors::Internal("Insufficient capacity in cache (capacity = ", - capacity_, ", requested ", n, ")"); - } + void DiscardOld(size_t n = 0) { + DCHECK(capacity_ >= n) << "Insufficient capacity in cache (capacity = " + << capacity_ << ", requested " << n << ")"; while (objects_.size() > (capacity_ - n)) { key_type discard_key = keys_.back(); keys_.pop_back(); objects_.erase(discard_key); } - return Status::OK(); } }; @@ -141,36 +139,18 @@ struct EngineContext { class TRTEngineCacheResource : public ResourceBase { public: - TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity) - : cache_(capacity) { - auto device = ctx->device(); - auto alloc = device->GetAllocator(AllocatorAttributes()); - if (!alloc) { - LOG(ERROR) << "Can't find device allocator for gpu device " - << device->name(); - allocator_ = nullptr; - } else { - allocator_.reset(new TRTDeviceAllocator(alloc)); - } - } + // According to the TensorRT API, the logger is considered a singleton by the + // TensorRT library, and multiple instances of IRuntime and/or IBuilder must + // all use the same logger. So here we make it a singleton. + // + // TODO(laigd): use this logger in all places where conversion happens. + static Logger& GetLogger(); - string DebugString() const override { - std::stringstream oss; - using std::dec; - using std::endl; - using std::hex; - oss << "TRTEngineCacheResource: "; - oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", "; - oss << "LRUCache = " << hex << &cache_ << dec << endl; - oss << "Containing " << cache_.size() << " entries: " << endl; - for (const auto& item : cache_) { - oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex - << "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", " - << "IExecutionContext: " << item.second.get()->execution_context.get() - << dec << endl; - } - return oss.str(); - } + TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity); + + ~TRTEngineCacheResource() override; + + string DebugString() const override; // Keep device allocator for TRT. std::unique_ptr allocator_; diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h index abfed2c1816..697cef5d788 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h @@ -32,7 +32,7 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 7d9e7b9fc1f..dcce43cbe70 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -10,20 +10,27 @@ package_group( "//tensorflow/compiler/tests/...", "//tensorflow/compiler/tf2xla/...", "//tensorflow/contrib/compiler/...", + "//tensorflow/python/compiler/...", ], ) package_group( name = "friends", includes = [":internal"], - packages = ["//tensorflow/..."], + packages = [ + "//learning/brain/tools/tf_replay/...", + "//tensorflow/...", + ], ) package( default_visibility = [":internal"], ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load( + "//tensorflow/core:platform/default/cuda_build_defs.bzl", + "if_cuda_is_configured", +) load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") cc_library( @@ -84,7 +91,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":common", - ":dump_graph", ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", @@ -103,22 +109,6 @@ cc_library( ], ) -cc_library( - name = "cpu_function_runtime", - srcs = ["cpu_function_runtime.cc"], - hdrs = ["cpu_function_runtime.h"], - visibility = [ - "//tensorflow/compiler/aot:__pkg__", - "//tensorflow/compiler/xla/service/cpu:__pkg__", - ], - deps = [ - # Keep dependencies to a minimum here; this library is used in every AOT - # binary produced by tfcompile. - "//tensorflow/compiler/xla:executable_run_options", - "//tensorflow/core:framework_lite", - ], -) - cc_library( name = "xla_compiled_cpu_function", srcs = ["xla_compiled_cpu_function.cc"], @@ -127,7 +117,7 @@ cc_library( deps = [ # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. 
- ":cpu_function_runtime", + "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], @@ -137,7 +127,7 @@ tf_cc_test( name = "cpu_function_runtime_test", srcs = ["cpu_function_runtime_test.cc"], deps = [ - ":cpu_function_runtime", + "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -153,6 +143,7 @@ cc_library( ":tf2xla", ":tf2xla_proto", ":xla_compiled_cpu_function", + "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", @@ -199,13 +190,13 @@ cc_library( visibility = [":friends"], deps = [ ":common", - ":dump_graph", ":host_compute_metadata_proto", ":sharding_util", ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/tf2xla:rearrange_function_argument", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -226,6 +217,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", @@ -301,6 +293,7 @@ cc_library( name = "tf2xla_util", srcs = ["tf2xla_util.cc"], hdrs = ["tf2xla_util.h"], + visibility = [":friends"], deps = [ ":sharding_util", ":tf2xla_proto", @@ -392,6 +385,7 @@ tf_cc_test( ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -406,6 +400,7 @@ tf_cc_test( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -449,23 +444,6 @@ tf_cc_test( ], ) -cc_library( - name = "dump_graph", - srcs = [ - "dump_graph.cc", - ], - hdrs = [ - "dump_graph.h", - ], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/jit:flags", - "//tensorflow/core:framework", - "//tensorflow/core:graph", - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "functionalize_control_flow_util", srcs = [ @@ -497,7 +475,6 @@ cc_library( ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", @@ -525,7 +502,28 @@ cc_library( ":functionalize_while", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", - "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "rearrange_function_argument", + srcs = [ + "rearrange_function_argument.cc", + ], + hdrs = [ + "rearrange_function_argument.h", + ], + deps = [ + "//tensorflow/compiler/tf2xla:tf2xla_util", 
"//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", @@ -562,7 +560,6 @@ cc_library( ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", - "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 6aff436da4f..1c94f38e06d 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -125,61 +125,125 @@ Status BackwardsConstAnalysis(const Graph& g, return status; } -Status GetCompileTimeConstInputs(const Node* node, - std::vector* const_input_idxs, - FunctionLibraryRuntime* flib_runtime) { - if (node->type_string() != "While") { - return XlaOpRegistry::CompileTimeConstantInputs(node->def(), node->op_def(), - const_input_idxs); - } - // For While nodes, recurse into the body and cond graphs. - // TODO(b/124403063): Implement similar functionality for cond nodes and other - // functional ops. - NameAttrList cond_function; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "cond", &cond_function)); - NameAttrList body_function; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "body", &body_function)); - FunctionLibraryRuntime::Handle cond_handle; - FunctionLibraryRuntime::Handle body_handle; - TF_RETURN_IF_ERROR(flib_runtime->Instantiate( - cond_function.name(), AttrSlice(&cond_function.attr()), &cond_handle)); - TF_RETURN_IF_ERROR(flib_runtime->Instantiate( - body_function.name(), AttrSlice(&body_function.attr()), &body_handle)); - const FunctionBody* fcond = flib_runtime->GetFunctionBody(cond_handle); - const FunctionBody* fbody = flib_runtime->GetFunctionBody(body_handle); - TF_RET_CHECK(fcond); - TF_RET_CHECK(fbody); - int num_inputs = fbody->fdef.signature().input_arg_size(); +namespace { - // Stores which of the loop inputs are expected to be compile time constants. +Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, const Node* node, + StringPiece func_attr_name, const FunctionBody** fbody) { + NameAttrList name_attr_list; + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), func_attr_name, &name_attr_list)); + FunctionLibraryRuntime::Handle func_handle; + TF_RETURN_IF_ERROR(flib_runtime->Instantiate( + name_attr_list.name(), AttrSlice(&name_attr_list.attr()), &func_handle)); + *fbody = flib_runtime->GetFunctionBody(func_handle); + return Status::OK(); +} + +Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, const Node* node, + StringPiece func_list_attr_name, + std::vector* fbodies) { + std::vector name_attr_lists; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->def(), func_list_attr_name, &name_attr_lists)); + for (const NameAttrList& name_attr_list : name_attr_lists) { + FunctionLibraryRuntime::Handle func_handle; + TF_RETURN_IF_ERROR(flib_runtime->Instantiate( + name_attr_list.name(), AttrSlice(&name_attr_list.attr()), + &func_handle)); + fbodies->push_back(flib_runtime->GetFunctionBody(func_handle)); + } + return Status::OK(); +} + +Status CondConstInputIndices( + absl::Span branch_bodies, + std::vector* const_input_idxs, FunctionLibraryRuntime* flib_runtime) { + TF_RET_CHECK(!branch_bodies.empty()); + TF_RET_CHECK(branch_bodies[0] != nullptr); + int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size(); + // Stores indices of the "branch function" inputs that are expected to be + // compile time constants. 
std::vector compile_time_const_arg_indices(num_inputs); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis( - *(fcond->graph), &compile_time_const_arg_indices, - /*compile_time_const_nodes=*/nullptr, flib_runtime)); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis( - *(fbody->graph), &compile_time_const_arg_indices, - /*compile_time_const_nodes=*/nullptr, flib_runtime)); - for (int i = 0; i < num_inputs; i++) { + for (auto fbody : branch_bodies) { + TF_RET_CHECK(fbody != nullptr); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *(fbody->graph), &compile_time_const_arg_indices, + /*compile_time_const_nodes=*/nullptr, flib_runtime)); + } + for (int i = 0; i < compile_time_const_arg_indices.size(); i++) { if (compile_time_const_arg_indices[i]) { - // Check that this input is actually a loop invariant. - // NOTE(srbs): Ideally this should raise an error if the loop body - // requires the input at this index to be a compile time const but it is - // not a loop invariant. However, that causes problems because const - // analysis is performed for the entire graph (in the - // MarkForCompilationPass for example) and not just for the ops - // that will actually be run using XLA kernels. So we silently return here - // and let the error be raised during the actual compilation of the - // XLA graph. - Node* arg_i = fbody->arg_nodes[i]; - Node* ret_i = fbody->ret_nodes[i]; - const Node* ret_i_input_0; - TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0)); - if (ret_i_input_0->id() == arg_i->id()) { - const_input_idxs->push_back(i); - } + // The 0th input is the pred or branch index, which is not passed to the + // branches. So the i'th input of a branch function corresponds to the + // i + 1'th input of the If/Case op. + const_input_idxs->push_back(i + 1); } } return Status::OK(); } +} // namespace + +Status GetCompileTimeConstInputs(const Node* node, + std::vector* const_input_idxs, + FunctionLibraryRuntime* flib_runtime) { + // TODO(b/124403063): Implement similar functionality for function call nodes. + if (node->type_string() == "While") { + // For While nodes, recurse into the body and cond graphs. + const FunctionBody* fcond = nullptr; + const FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "cond", &fcond)); + TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "body", &fbody)); + TF_RET_CHECK(fcond); + TF_RET_CHECK(fbody); + int num_inputs = fbody->fdef.signature().input_arg_size(); + + // Stores which of the loop inputs are expected to be compile time + // constants. + std::vector compile_time_const_arg_indices(num_inputs); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *(fcond->graph), &compile_time_const_arg_indices, + /*compile_time_const_nodes=*/nullptr, flib_runtime)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *(fbody->graph), &compile_time_const_arg_indices, + /*compile_time_const_nodes=*/nullptr, flib_runtime)); + for (int i = 0; i < num_inputs; i++) { + if (compile_time_const_arg_indices[i]) { + // Check that this input is actually a loop invariant. + // NOTE(srbs): Ideally this should raise an error if the loop body + // requires the input at this index to be a compile time const but it is + // not a loop invariant. However, that causes problems because const + // analysis is performed for the entire graph (in the + // MarkForCompilationPass for example) and not just for the ops + // that will actually be run using XLA kernels. So we silently return + // here and let the error be raised during the actual compilation of the + // XLA graph. 
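+ // The input is treated as a loop invariant when the body's i'th return value is fed directly by its i'th argument.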
+ Node* arg_i = fbody->arg_nodes[i]; + Node* ret_i = fbody->ret_nodes[i]; + const Node* ret_i_input_0; + TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0)); + if (ret_i_input_0->id() == arg_i->id()) { + const_input_idxs->push_back(i); + } + } + } + return Status::OK(); + } else if (node->type_string() == "If") { + const FunctionBody* fthen = nullptr; + const FunctionBody* felse = nullptr; + TF_RETURN_IF_ERROR( + GetFunctionBody(flib_runtime, node, "then_branch", &fthen)); + TF_RETURN_IF_ERROR( + GetFunctionBody(flib_runtime, node, "else_branch", &felse)); + return CondConstInputIndices({fthen, felse}, const_input_idxs, + flib_runtime); + } else if (node->type_string() == "Case") { + std::vector branch_bodies; + TF_RETURN_IF_ERROR( + GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies)); + return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime); + } else { + return XlaOpRegistry::CompileTimeConstantInputs(node->def(), node->op_def(), + const_input_idxs); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc index 8ca628c4eb6..f06665dad56 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc @@ -13,15 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" - +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -using cpu_function_runtime::BufferInfo; +using ::xla::cpu_function_runtime::BufferInfo; TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { // We've chosen 64 byte alignment for the tfcompile runtime to mimic the @@ -29,7 +28,7 @@ TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 // So any value that we choose must abide by that constraint as well. - EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); + EXPECT_EQ(xla::cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); } std::vector SizesToBufferInfos(const intptr_t* sizes, size_t n) { @@ -91,7 +90,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test empty sizes. void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false); EXPECT_EQ(base, nullptr); - cpu_function_runtime::FreeContiguous(base); + xla::cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with 0 sum. static constexpr intptr_t sizesA[1] = {-1}; @@ -99,7 +98,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { base = MallocContiguousBuffersFromSizes(sizesA, 1, bufA, false); EXPECT_EQ(base, nullptr); EXPECT_EQ(bufA[0], nullptr); - cpu_function_runtime::FreeContiguous(base); + xla::cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum. 
static constexpr intptr_t sizesB[1] = {3}; @@ -111,7 +110,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { bufB0_bytes[0] = 'A'; bufB0_bytes[1] = 'B'; bufB0_bytes[2] = 'C'; - cpu_function_runtime::FreeContiguous(base); + xla::cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum, and annotate_initialized. static constexpr intptr_t sizesC[1] = {3}; @@ -123,7 +122,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { bufC0_bytes[0] = 'A'; bufC0_bytes[1] = 'B'; bufC0_bytes[2] = 'C'; - cpu_function_runtime::FreeContiguous(base); + xla::cpu_function_runtime::FreeContiguous(base); // Test mixed sizes. static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; @@ -146,7 +145,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { } } } - cpu_function_runtime::FreeContiguous(base); + xla::cpu_function_runtime::FreeContiguous(base); } void CheckRoundTripIsOk(const BufferInfo& buffer_info) { diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc deleted file mode 100644 index 64fdbbebc65..00000000000 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for -// debugging. - -#include "tensorflow/compiler/tf2xla/dump_graph.h" - -#include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/core/util/dump_graph.h" - -namespace tensorflow { -namespace dump_graph { - -string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - return tensorflow::DumpGraphDefToFile( - name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix); -} - -string DumpGraphToFile(const string& name, Graph const& graph, - const FunctionLibraryDefinition* flib_def) { - return tensorflow::DumpGraphToFile(name, graph, flib_def, - GetDumpGraphFlags()->tf_dump_graph_prefix); -} - -string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - return tensorflow::DumpFunctionDefToFile( - name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix); -} - -} // namespace dump_graph -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.h b/tensorflow/compiler/tf2xla/dump_graph.h deleted file mode 100644 index bbf01eb90db..00000000000 --- a/tensorflow/compiler/tf2xla/dump_graph.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for -// debugging. - -#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_ -#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_ - -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/graph/graph.h" - -namespace tensorflow { -namespace dump_graph { - -// Dumps 'graph_def' to a file, as a GraphDef text proto. Returns the file name -// chosen. -// -// Automatically picks a file name. Prefixes 'name' with the value of the -// --tf_dump_graph_prefix flag and suffixes it with ".pbtxt" to form a name. -// If a graph has already been dumped by this process with the same name, -// suffixes with "_n.pbtxt", where 'n' is a sequence number. -string DumpGraphDefToFile(const string& name, GraphDef const& graph_def); - -// Similar to DumpGraphDefToFile, but builds the GraphDef to dump from a 'graph' -// and an optional function library 'flib_def'. Returns the file name chosen. -string DumpGraphToFile(const string& name, Graph const& graph, - const FunctionLibraryDefinition* flib_def = nullptr); - -// Similar to DumpGraphDefToFile, but dumps a function as a FunctionDef text -// proto. Returns the file name chosen. -string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef); - -} // namespace dump_graph -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index d6379460291..6e093400e47 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -25,7 +25,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" @@ -37,6 +36,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/dump_graph.h" using xla::StatusOr; @@ -735,7 +735,7 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index] << "): " - << dump_graph::DumpGraphToFile( + << DumpGraphToFile( "functionalize_cond_body_" + branch_name[branch_index], *bodies_[branch_index], nullptr); @@ -1516,9 +1516,8 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " - << dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_cond_", name), *graph_, - library_); + << DumpGraphToFile(absl::StrCat("functionalize_cond_", name), + *graph_, library_); } void FunctionalizeCond::AddSwitchId(int switch_id) { diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 3dfd3f854c8..89d5a860179 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/functionalize_while.h" @@ -43,6 +42,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -50,8 +50,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " - << dump_graph::DumpGraphToFile("functionalize_initial", *graph, - library); + << DumpGraphToFile("functionalize_initial", *graph, library); // Functionalize and remove while loops from graph. TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library)); @@ -62,8 +61,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library)); VLOG(2) << "FunctionalizeControlFlow (final): " - << dump_graph::DumpGraphToFile("functionalize_final", *graph, - library); + << DumpGraphToFile("functionalize_final", *graph, library); return Status::OK(); } @@ -200,13 +198,13 @@ Status FunctionalizeControlFlowForFunction( // Functionalize the function body. 
if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("functionalize_control_flow_before_fdef_", func_name), *g, fld); } TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( + DumpGraphToFile( absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, fld); } @@ -234,8 +232,8 @@ Status FunctionalizeControlFlowPass::Run( const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, - options.flib_def); + DumpGraphToFile("functionalize_control_flow_before", *graph, + options.flib_def); } std::unique_ptr pflr( new ProcessFunctionLibraryRuntime( @@ -255,6 +253,7 @@ Status FunctionalizeControlFlowPass::Run( {"XlaLaunch", "function"}, }; std::map> canonicalized_name_to_new_name; + bool fld_modified = false; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); if (it == kNodeTypeToFunctionAttrMapping->end()) { @@ -275,12 +274,19 @@ Status FunctionalizeControlFlowPass::Run( n->ClearAttr(func_attr); func.set_name(new_func_name); n->AddAttr(func_attr, func); + + fld_modified = true; } } + if (fld_modified) { + TF_RETURN_IF_ERROR( + PruneUnreachableFunctionsFromGraph(*graph, options.flib_def)); + } + if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, - options.flib_def); + DumpGraphToFile("functionalize_control_flow_after", *graph, + options.flib_def); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 9784985af83..6c6b6cd1a77 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -1010,13 +1010,14 @@ TEST(FunctionalizeControlFlow, Complex) { ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + auto retval3 = ops::_Retval(scope.WithOpName("_retval3_RetVal"), arg3, 3); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), result.ret_types); TF_EXPECT_GRAPH_EQ(expected, result.gdef); } @@ -1083,6 +1084,7 @@ TEST(FunctionalizeControlFlow, Complex) { auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + auto retval3 = ops::_Retval(scope.WithOpName("_retval3_RetVal"), arg3, 3); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -1093,7 +1095,7 @@ TEST(FunctionalizeControlFlow, Complex) { EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), result.ret_types); TF_EXPECT_GRAPH_EQ(expected, result.gdef); } diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index d87436a7b4a..fbab2803f5c 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ 
b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -36,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { @@ -200,33 +200,28 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, arg_types->push_back(dtype); TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); - - if (dtype == DT_RESOURCE) { - // The convention of the XLA bridge is that resource variable arguments - // are only inputs to the loop body and have no corresponding output. - // TODO(b/37741920): change the convention so that DT_RESOURCE variables - // are both inputs and outputs, and then remove this case. - TF_RET_CHECK(arg.is_loop_invariant); + TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, dtype, i)); + if (arg.is_loop_invariant) { + // Argument is loop-invariant. Forward it from the Arg to the Retval. node_map[arg.enter->id()] = arg_node; + output->AddEdge(arg_node, 0, retval_node, 0); } else { - TF_ASSIGN_OR_RETURN(Node * retval_node, - BuildRetvalNode(output, dtype, i)); - - if (arg.is_loop_invariant) { - // Argument is loop-invariant. Forward it from the Arg to the Retval. - node_map[arg.enter->id()] = arg_node; - output->AddEdge(arg_node, 0, retval_node, 0); - } else { - // Argument is loop-varying. - node_map[arg.switch_node->id()] = arg_node; - // The Switch node has two outputs, but _Arg only has one. This tells - // the CopySubgraph function to rewrite the output number of edges from - // the _Arg node to be 0 rather than copying the output number from the - // Switch node. - squash_src_outputs[arg.switch_node->id()] = true; - node_map[arg.next_iteration->id()] = retval_node; - next_iterations.push_back(arg.next_iteration); + // Argument is loop-varying. + if (dtype == DT_RESOURCE) { + // DT_RESOURCE arguments should always be loop-invariant in the graphs + // generated from TF. + return errors::Unimplemented("Loop-varying DT_RESOURCE Enter node ", + arg.enter->name(), " is currently not", + " supported."); } + node_map[arg.switch_node->id()] = arg_node; + // The Switch node has two outputs, but _Arg only has one. This tells + // the CopySubgraph function to rewrite the output number of edges from + // the _Arg node to be 0 rather than copying the output number from the + // Switch node. + squash_src_outputs[arg.switch_node->id()] = true; + node_map[arg.next_iteration->id()] = retval_node; + next_iterations.push_back(arg.next_iteration); } } @@ -293,8 +288,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph, - library); + << DumpGraphToFile("functionalize_before", *graph, library); // Split loop-varying Enter nodes with multiple successors. 
If the same // Tensor is fed as input to multiple loop arguments, we may end up with a @@ -490,8 +484,8 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) - << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); + << DumpGraphToFile("loop_condition", *cond_graph, library) + << " body: " << DumpGraphToFile("loop_body", *body_graph); static std::atomic sequence_num(0LL); int64 id = ++sequence_num; @@ -585,8 +579,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph, - library); + << DumpGraphToFile("functionalize_after", *graph, library); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index b3cb23003ec..a431abd26e0 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" @@ -34,7 +33,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" @@ -46,6 +47,7 @@ limitations under the License. 
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -86,12 +88,24 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.kind = XlaCompiler::Argument::kParameter; } break; - case XlaExpression::Kind::kResource: - return errors::Unimplemented( - "Resource as function argument is not yet implemented."); - case XlaExpression::Kind::kTensorList: - return errors::Unimplemented( - "TensorList as function argument is not yet implemented."); + case XlaExpression::Kind::kResource: { + XlaResource* resource = expressions[i]->resource(); + + arg.initialized = resource->initialized(); + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = resource->kind(); + arg.type = resource->type(); + arg.shape = resource->shape(); + arg.max_array_size = resource->max_array_size(); + arg.name = resource->name(); + break; + } + case XlaExpression::Kind::kTensorList: { + arg.kind = XlaCompiler::Argument::kTensorList; + const xla::XlaOp& tensor_list = expressions[i]->handle(); + arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie(); + break; + } case XlaExpression::Kind::kInvalid: return errors::InvalidArgument("Invalid function argument"); } @@ -124,6 +138,8 @@ Status GraphCompiler::Compile() { for (Node* n : topo_sorted_nodes) { OpKernel* op_kernel_raw = nullptr; + // The kernel is not actually run for functional ops, we just need it + // for metadata. Status s = flib_->CreateKernel(n->def(), &op_kernel_raw); // Transfer ownership of the kernel to a local smart pointer. std::unique_ptr op_kernel(op_kernel_raw); @@ -157,7 +173,7 @@ Status GraphCompiler::Compile() { OpKernelContext op_context(¶ms, n->num_outputs()); VLOG(3) << "Translating " << params.op_kernel->name(); - if (IsFunctional(n)) { + if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) { TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); } else { device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context); @@ -182,15 +198,37 @@ Status GraphCompiler::Compile() { return Status::OK(); } -bool GraphCompiler::IsFunctional(Node* n) { - return n->type_string() == FunctionLibraryDefinition::kGradientOp || - (flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) != - nullptr); +namespace { + +Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, + const Node& node, NameAttrList* func) { + if (node.IsPartitionedCall()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value)); + if (!attr_value->has_func()) { + return errors::InvalidArgument( + "The attribute value for attribute 'f' in node ", node.DebugString(), + " does not have 'func' field set"); + } + *func = attr_value->func(); + return Status::OK(); + } + + if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) { + func->set_name(node.type_string()); + } else { + func->set_name(FunctionLibraryDefinition::kGradientOp); + } + *func->mutable_attr() = node.def().attr(); + return Status::OK(); } +} // namespace + Status GraphCompiler::CompileFunctionalNode(Node* n, OpKernelContext* op_context) { - TF_RET_CHECK(IsFunctional(n)); + TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)); // For functional nodes, compile them using compiler from the context and call // into the functions. 
XlaOpKernelContext xla_op_context(op_context); @@ -201,12 +239,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, XlaCompiler* compiler = xla_op_context.compiler(); NameAttrList func; - if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) { - func.set_name(n->def().op()); - } else { - func.set_name(FunctionLibraryDefinition::kGradientOp); - } - *func.mutable_attr() = n->def().attr(); + TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func)); std::vector expressions; @@ -227,7 +260,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); bool add_token_input_output = - HasNodeAttr(n->def(), kXlaTokenInputNodesAttrName); + func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end(); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = false; @@ -243,12 +276,17 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, if (arguments[i].kind == XlaCompiler::Argument::kConstant) { continue; } - handles.push_back(expressions[i]->handle()); + if (arguments[i].kind == XlaCompiler::Argument::kResource) { + handles.push_back(expressions[i]->resource()->value()); + } else { + handles.push_back(expressions[i]->handle()); + } } if (add_token_input_output) { std::vector token_input_nodes; - TF_RETURN_IF_ERROR( - GetNodeAttr(n->def(), kXlaTokenInputNodesAttrName, &token_input_nodes)); + TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()), + kXlaTokenInputNodesAttrName, + &token_input_nodes)); std::vector token_inputs; for (const string& node_name : token_input_nodes) { auto token_or = compiler->GetNodeToken(node_name); @@ -267,11 +305,27 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, if (result.outputs[i].is_constant) { xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { - xla_op_context.SetOutput( - i, xla::GetTupleElement(output_handle, computation_output)); + if (result.outputs[i].is_tensor_list) { + xla_op_context.SetTensorListOutput( + i, xla::GetTupleElement(output_handle, computation_output)); + } else { + xla_op_context.SetOutput( + i, xla::GetTupleElement(output_handle, computation_output)); + } ++computation_output; } } + + for (int64 i = 0; i < result.resource_updates.size(); i++) { + if (result.resource_updates[i].modified) { + XlaResource* resource = + expressions[result.resource_updates[i].input_index]->resource(); + xla::XlaOp updated_value = + xla::GetTupleElement(output_handle, i + n->num_outputs()); + TF_RETURN_IF_ERROR(resource->SetValue(updated_value)); + } + } + if (add_token_input_output) { TF_RETURN_IF_ERROR(compiler->SetNodeToken( n->name(), xla::GetTupleElement(output_handle, computation_output))); diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index e9f02201cf6..eb02534e7fb 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -73,10 +73,6 @@ class GraphCompiler { // across multiple nodes visit. void PartiallySetupParams(OpKernelContext::Params* params); - // Tests if a node is a functional node. A functional node represents a - // defined computation and should be compiled using `compiler_`. - bool IsFunctional(Node* n); - // Compiles a functional node and writes result to OpkernelContext. A // functional node represents a defined computation and should be compiled // using `compiler_`. 
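In the output-assignment loop above, constant outputs are materialized directly and do not consume an element of the computation's result tuple, which is why a separate `computation_output` counter is kept. A small standalone illustration of that bookkeeping (the shapes of the example data are made up):

```cpp
#include <iostream>
#include <vector>

int main() {
  // true = the output was folded to a compile-time constant.
  std::vector<bool> is_constant = {false, true, false, false};

  int computation_output = 0;
  for (size_t i = 0; i < is_constant.size(); ++i) {
    if (is_constant[i]) {
      std::cout << "output " << i << ": constant, no tuple element\n";
    } else {
      std::cout << "output " << i << ": tuple element "
                << computation_output << "\n";
      ++computation_output;
    }
  }
  // Resource updates (if any) are read separately; the hunk above fetches
  // update i at tuple index i + n->num_outputs().
}
```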
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index b4d4b4433eb..d6dfa39e658 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -33,6 +33,7 @@ tf_kernel_library( "diag_op.cc", "dynamic_slice_ops.cc", "dynamic_stitch_op.cc", + "einsum_op.cc", "elu_op.cc", "empty_op.cc", "extract_image_patches_op.cc", @@ -46,6 +47,7 @@ tf_kernel_library( "identity_op.cc", "image_ops.cc", "image_resize_ops.cc", + "in_topk_op.cc", "index_ops.cc", "l2loss_op.cc", "listdiff_op.cc", @@ -70,6 +72,7 @@ tf_kernel_library( "reduction_ops.h", "reduction_ops_common.cc", "relu_op.cc", + "replica_id_op.cc", "reshape_op.cc", "retval_op.cc", "reverse_op.cc", @@ -110,6 +113,7 @@ tf_kernel_library( "xla_reduce_op.cc", "xla_select_and_scatter_op.cc", "xla_self_adjoint_eig_op.cc", + "xla_svd_op.cc", ], hdrs = [ "index_ops.h", @@ -125,6 +129,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:broadcast", + "//tensorflow/compiler/tf2xla/lib:data_format", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:util", @@ -140,6 +145,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:comparators", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", @@ -149,7 +155,9 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:quantize", "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", + "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:svd", "//tensorflow/core:bitwise_ops_op_lib", "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:data_flow_ops_op_lib", @@ -246,26 +254,35 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) +cc_library( + name = "if_while_utils", + srcs = ["if_while_utils.cc"], + hdrs = ["if_while_utils.h"], +) + tf_kernel_library( name = "while_op", srcs = ["while_op.cc"], hdrs = ["while_op.h"], deps = [ + ":if_while_utils", + ":tensor_list_utils", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", ], @@ -276,6 +293,7 @@ tf_kernel_library( srcs = ["if_op.cc"], hdrs = ["if_op.h"], deps = [ + ":if_while_utils", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -293,6 +311,7 @@ tf_kernel_library( srcs = ["case_op.cc"], hdrs = ["case_op.h"], deps = [ + ":if_while_utils", "//tensorflow/compiler/tf2xla:common", 
"//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -349,7 +368,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -362,7 +381,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 41a453da80d..f34b2ff11df 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { namespace { @@ -30,19 +32,66 @@ class AddNOp : public XlaOpKernel { OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("AddN requires at least one argument")); - xla::XlaOp sum = ctx->Input(0); - for (int i = 1; i < ctx->num_inputs(); ++i) { - sum = xla::Add(sum, ctx->Input(i)); - } + XlaExpression::Kind kind = ctx->InputExpression(0).kind(); + xla::XlaOp sum; + switch (kind) { + case XlaExpression::Kind::kTensorList: { + // Check that all TensorLists are initialized. + for (int i = 1; i < ctx->num_inputs(); ++i) { + xla::XlaOp list = ctx->Input(i); + bool is_initialized; + OP_REQUIRES_OK(ctx, IsTensorListInitialized(list, &is_initialized)); + OP_REQUIRES( + ctx, is_initialized, + errors::InvalidArgument("TensorList input #", i, + " for AddN op is an uninitialized list")); + } + // Nested TensorList is not supported. + bool is_nested_list; + OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested_list)); + OP_REQUIRES(ctx, !is_nested_list, + errors::Unimplemented( + "Nested TensorList is not supported for AddN op")); - ctx->SetOutput(0, sum); + OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &sum)); + xla::Shape sum_shape; + OP_REQUIRES_OK(ctx, + GetTensorListBufferShape(ctx->Input(0), &sum_shape)); + for (int i = 1; i < ctx->num_inputs(); ++i) { + xla::XlaOp operand; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(i), &operand)); + // Check that the shapes match. 
+ xla::Shape operand_shape; + OP_REQUIRES_OK( + ctx, GetTensorListBufferShape(ctx->Input(i), &operand_shape)); + OP_REQUIRES( + ctx, sum_shape.dimensions() == operand_shape.dimensions(), + errors::InvalidArgument( + "TensorList arguments to AddN must all have the same ", + "shape.\n", "Expected: ", sum_shape.DebugString(), "\n", + "Found: ", operand_shape.DebugString())); + sum = xla::Add(sum, operand); + } + xla::XlaOp push_index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &push_index)); + OP_REQUIRES_OK(ctx, BuildNonNestedTensorList(sum, push_index, &sum)); + ctx->SetTensorListOutput(0, sum); + break; + } + default: + sum = ctx->Input(0); + for (int i = 1; i < ctx->num_inputs(); ++i) { + sum = xla::Add(sum, ctx->Input(i)); + } + ctx->SetOutput(0, sum); + } } private: TF_DISALLOW_COPY_AND_ASSIGN(AddNOp); }; -REGISTER_XLA_OP(Name("AddN"), AddNOp); +REGISTER_XLA_OP(Name("AddN").AllowVariantTypes(), AddNOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 5554d7a377d..3d9aceae8ec 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -41,10 +41,15 @@ class XlaArgOp : public XlaOpKernel { if (frame != nullptr) { Tensor val; OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); - OP_REQUIRES(ctx, val.dtype() == dtype_, - errors::InvalidArgument( - "Type mismatch: actual ", DataTypeString(val.dtype()), - " vs. expect ", DataTypeString(dtype_))); + // Types that cannot be copied using memcpy (like DT_STRING) are wrapped + // in a DT_UINT8 and hence the type mismatches. Skip the test in such + // cases. See XlaOpKernelContext::SetOutputExpression for details. + if (DataTypeCanUseMemcpy(dtype_)) { + OP_REQUIRES(ctx, val.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(val.dtype()), + " vs. expect ", DataTypeString(dtype_))); + } // Forwards the argument from the frame. 
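A scalar model of the TensorList branch of `AddN` shown above: every input list contributes a flat buffer, all buffers must agree in shape, and the result is rebuilt as a list carrying the first input's push index. The `FakeTensorList` struct below is a hypothetical stand-in, not the XLA TensorList representation.

```cpp
#include <iostream>
#include <stdexcept>
#include <vector>

struct FakeTensorList {
  std::vector<int> shape;
  std::vector<double> buffer;
  int push_index = 0;
};

FakeTensorList AddNLists(const std::vector<FakeTensorList>& inputs) {
  FakeTensorList sum = inputs[0];
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (inputs[i].shape != sum.shape)
      throw std::invalid_argument(
          "TensorList arguments to AddN must all have the same shape");
    for (size_t k = 0; k < sum.buffer.size(); ++k)
      sum.buffer[k] += inputs[i].buffer[k];
  }
  return sum;  // keeps inputs[0]'s push_index, as in the kernel
}

int main() {
  FakeTensorList a{{2, 2}, {1, 2, 3, 4}, 2};
  FakeTensorList b{{2, 2}, {10, 20, 30, 40}, 2};
  for (double v : AddNLists({a, b}).buffer) std::cout << v << " ";
  std::cout << "\n";  // 11 22 33 44
}
```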
ctx->op_kernel_context()->set_output(0, val); return; diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 1b254e328a8..f60509b3746 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -30,11 +30,8 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = - xla::BatchDot(MaybeTransposeInMinorDims( - MaybeConjugate(ctx->Input(0), adj_x_), adj_x_), - MaybeTransposeInMinorDims( - MaybeConjugate(ctx->Input(1), adj_y_), adj_y_)); + auto result = xla::BatchDot(MaybeConjugate(ctx->Input(0), adj_x_), adj_x_, + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_); ctx->SetOutput(0, result); } @@ -44,6 +41,7 @@ class BatchMatMulOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("BatchMatMul"), BatchMatMulOp); +REGISTER_XLA_OP(Name("BatchMatMulV2"), BatchMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index f1d78c87527..84eda80fc25 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -67,7 +67,18 @@ class FusedBatchNormOp : public XlaOpKernel { ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0), input_type)); ctx->SetOutput(1, xla::GetTupleElement(output, 1)); - ctx->SetOutput(2, xla::GetTupleElement(output, 2)); + xla::XlaOp variance = xla::GetTupleElement(output, 2); + // Apply Bessel's correction. + int total_input_size = ctx->InputShape(0).num_elements(); + int total_scale_size = ctx->InputShape(1).num_elements(); + int sample_size = total_input_size / total_scale_size; + int sample_size_minus_one = std::max(1, sample_size - 1); + + xla::XlaOp factor = + xla::Div(xla::ScalarLike(variance, sample_size), + xla::ScalarLike(variance, sample_size_minus_one)); + xla::XlaOp corrected = xla::Mul(variance, factor); + ctx->SetOutput(2, corrected); // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and @@ -80,11 +91,10 @@ class FusedBatchNormOp : public XlaOpKernel { // behavior of the op: // output 3 is the mean // output 4 is rsqrt(variance + epsilon) - xla::XlaOp variance = xla::GetTupleElement(output, 2); ctx->SetOutput(4, xla::Rsqrt(xla::Add( variance, xla::ScalarLike(variance, epsilon_)))); } else { - ctx->SetOutput(4, xla::GetTupleElement(output, 2)); + ctx->SetOutput(4, variance); } } else { xla::XlaOp output = xla::BatchNormInference( diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 66446106d3a..59f6041a608 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -79,6 +79,24 @@ static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, XLA_MAKE_BINARY(DivNoNan, DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +// Implementation of MulNoNan. 
Pseudo-code: +// if (y == 0) { +// return 0 +// } else { +// return x * y; +// } +static xla::XlaOp MulNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto y_equals_0 = xla::Eq(y, zero); + auto zeros = xla::ZerosLike(x); + auto result = xla::Select(y_equals_0, zeros, xla::Mul(x, y)); + return result; +} +XLA_MAKE_BINARY(MulNoNan, + MulNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorDiv. // // For floating-point values, simply returns floor(x / y). For integers, does: @@ -94,7 +112,17 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); if (DataTypeIsFloating(dtype)) { - return xla::Floor(xla::Div(x, y)); + if (dtype == DataType::DT_BFLOAT16) { + // The result of a BF16 division may produce the Ceil of what was + // computed by F32 division, so avoid end user confusion by doing the + // intermediate divide in F32. + return xla::ConvertElementType( + xla::Floor(xla::Div(xla::ConvertElementType(x, xla::F32), + xla::ConvertElementType(y, xla::F32))), + xla::BF16); + } else { + return xla::Floor(xla::Div(x, y)); + } } if (DataTypeIsUnsigned(dtype)) { return xla::Div(x, y); @@ -130,14 +158,17 @@ XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper)); // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); -// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); +// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y +// : trunc_mod; static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); - auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); auto trunc_mod = xla::Rem(x, y); - return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y)); + auto trunc_mod_not_zero = xla::Ne(trunc_mod, zero); + auto do_plus = xla::And(xla::Ne(xla::Lt(trunc_mod, zero), xla::Lt(y, zero)), + trunc_mod_not_zero); + return xla::Select(do_plus, xla::Add(trunc_mod, y), trunc_mod); } XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); @@ -161,7 +192,7 @@ XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs)))); XLA_MAKE_BINARY( RsqrtGrad, - xla::Mul(xla::Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), + xla::Mul((lhs * lhs) * lhs, xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), extend_dimensions)); XLA_MAKE_BINARY( diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index 24623768f38..5ba844e10bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -15,6 +15,7 @@ limitations under the License. 
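A scalar reference of the two element-wise semantics spelled out in the comments above (`MulNoNan` and the revised `FloorMod`); this mirrors the pseudo-code, not the XLA lowering itself:

```cpp
#include <cmath>
#include <iostream>

// MulNoNan: returns 0 whenever y == 0, even if x is inf or NaN.
double MulNoNan(double x, double y) { return y == 0.0 ? 0.0 : x * y; }

// FloorMod: truncated remainder adjusted so the result has the sign of y.
double FloorMod(double x, double y) {
  double trunc_mod = std::fmod(x, y);
  bool do_plus = trunc_mod != 0.0 && ((y < 0.0) != (trunc_mod < 0.0));
  return do_plus ? trunc_mod + y : trunc_mod;
}

int main() {
  std::cout << MulNoNan(INFINITY, 0.0) << "\n";  // 0, not NaN
  std::cout << FloorMod(7.0, -3.0) << "\n";      // -2
  std::cout << FloorMod(-7.0, 3.0) << "\n";      // 2
  std::cout << FloorMod(6.0, 3.0) << "\n";       // 0 (no adjustment)
}
```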
#include "tensorflow/compiler/tf2xla/kernels/case_op.h" +#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -34,10 +35,41 @@ XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } else { has_token_input_output_ = !token_input_nodes_.empty(); } + if (ctx->HasAttr(kPropagateCompileTimeConsts)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts, + &propagate_compile_time_consts_)); + } } +namespace { + +Status ConvertCompileTimeConstArgumentsToConst( + XlaOpKernelContext* ctx, std::vector* args) { + for (int i = 0; i < args->size(); i++) { + XlaCompiler::Argument& arg = (*args)[i]; + const XlaExpression& expression = ctx->InputExpression(i + 1); + // If the input tensor is a compile time constant build a kConstant type + // argument. + if (arg.kind == XlaCompiler::Argument::kParameter) { + // NOTE: We can not simply check that this is Kind::kConstant because + // this could be the output of a MetadataOnly op e.g. Size. + xla::StatusOr> maybe_constant = + expression.ResolveConstant(ctx->compiler()->client()); + if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = expression.dtype(); + arg.constant_value = std::move(maybe_constant.ValueOrDie().value()); + arg.shape = expression.GetShape().ValueOrDie(); + } + } + } + return Status::OK(); +} + +} // namespace + // TODO(b/35949885): There is duplication here with the handling of the -// while_op. Refactor the common code out/rework. +// while_op/if_op. Refactor the common code out/rework. void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { xla::XlaBuilder* b = ctx->builder(); int num_branches = branches_.size(); @@ -84,12 +116,30 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; - arg.shape = ctx->InputShape(i + 1); + // Use the xla::Shape for the input instead of ctx->InputShape. This is + // necessary for forwarding shapes of DT_VARIANTs, e.g. TensorLists. + auto shape_or = ctx->builder()->GetShape(ctx->Input(i + 1)); + OP_REQUIRES_OK(ctx, shape_or.status()); + arg.shape = shape_or.ValueOrDie(); VLOG(2) << "Arg type: " << DataTypeString(arg.type) << " shape: " << arg.HumanString(); } } + if (propagate_compile_time_consts_) { + // Replaces `kParameter` type args in `arguments` with `kConstant` if + // the op input corresponding to that arg is a compile-time const. This + // is necessary to propagate compile time consts to ops in the branch + // functions. + // Note: Propagating "all" compile-time constants may not be necessary. We + // should ideally only propagate consts which are required to be compile + // time constants in the branch functions. But that would require calling + // BackwardsConstAnalysis here which would be expensive. However, if we + // start hitting memory issues we should revisit this. + OP_REQUIRES_OK(ctx, + ConvertCompileTimeConstArgumentsToConst(ctx, &arguments)); + } + // Compile each branch of the conditional. 
XlaCompiler::CompileOptions options; options.use_tuple_arg = true; @@ -158,8 +208,6 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { } OP_REQUIRES(ctx, branch_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); - OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1, - errors::FailedPrecondition("Expected one input shape")); OP_REQUIRES( ctx, xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape), @@ -227,7 +275,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); } else { - inputs[i] = ctx->Input(i + 1); + inputs[i] = ctx->Input(input_num); } } auto input_tuple = xla::Tuple(b, inputs); @@ -292,6 +340,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building Case"; } -REGISTER_XLA_OP(Name("Case").AllowResourceTypes(), XlaCaseOp); +REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(), + XlaCaseOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h index ea14b18149c..4a61707864e 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.h +++ b/tensorflow/compiler/tf2xla/kernels/case_op.h @@ -55,6 +55,10 @@ class XlaCaseOp : public XlaOpKernel { DataTypeVector output_types_; bool has_token_input_output_; std::vector token_input_nodes_; + // Whether to propagate compile time consts into the cond branches. + // This is not supported by default now since it may cause HBM memory + // overheads. + bool propagate_compile_time_consts_ = false; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index a99c6ee4431..a64ce55b36b 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA implementations of Categorical op. 
+#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -35,7 +36,9 @@ namespace { class CategoricalOp : public XlaOpKernel { public: - explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit CategoricalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx), + is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {} void Compile(XlaOpKernelContext* ctx) override { // Get the logits @@ -100,8 +103,15 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType xla_output_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_type(0), &xla_output_type)); - xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type, - /*axis=*/class_dimension); + xla::XlaOp argmax; + if (is_gpu_) { + argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type, + /*axis=*/class_dimension); + } else { + argmax = xla::ArgMax(softmax_entries, xla_output_type, + /*axis=*/class_dimension); + } + if (num_samples == 1) { argmax = xla::Reshape(argmax, {batch_size, 1}); } @@ -123,6 +133,7 @@ class CategoricalOp : public XlaOpKernel { } private: + bool is_gpu_; TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp); }; @@ -133,15 +144,14 @@ REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstantInput("num_samples"), class StatelessCategoricalOp : public CategoricalOp { public: explicit StatelessCategoricalOp(OpKernelConstruction* ctx) - : CategoricalOp(ctx) { + : CategoricalOp(ctx), + device_type_string_(ctx->device_type().type_string()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); } xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type, XlaOpKernelContext* ctx) override { xla::XlaOp seed = ctx->Input(2); - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); xla::XlaBuilder* builder = ctx->builder(); if (uniform_shape.element_type() == xla::BF16) { @@ -150,8 +160,8 @@ class StatelessCategoricalOp : public CategoricalOp { // We want a number in (0, 1) rather than [0, 1) or (0, 1]: // * log(-log(0)) is ∞. // * log(-log(1)) is -∞. - auto uniforms = xla::StatelessRngUniform( - {seed0, seed1}, uniform_shape, + xla::XlaOp uniforms = StatelessRngUniform( + device_type_string_, seed, uniform_shape, xla::MinPositiveNormalValue(builder, uniform_shape.element_type()), xla::One(builder, uniform_shape.element_type())); return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); @@ -167,6 +177,7 @@ class StatelessCategoricalOp : public CategoricalOp { private: DataType dtype_; + string device_type_string_; TF_DISALLOW_COPY_AND_ASSIGN(StatelessCategoricalOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 6512ba25ce6..36a4422527f 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Ops for 2D convolution. #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" + #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -53,6 +54,57 @@ xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { return expanded_shape; } +// Returns the transposed filter for use in BackpropInput of group convolution. 
+xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput( + const xla::XlaOp& filter, const xla::Shape& filter_shape, int64 num_groups, + int num_spatial_dims) { + // 1. Reshape from [H, W, ..., filter_in_depth, out_depth] to [H, W, ..., + // filter_in_depth, G, out_depth / G] + int num_dims = filter_shape.dimensions_size(); + CHECK_GE(num_dims, 2); // Crash OK + xla::Shape new_shape = filter_shape; + new_shape.set_dimensions(num_dims - 1, num_groups); + new_shape.add_dimensions(filter_shape.dimensions(num_dims - 1) / num_groups); + xla::XlaOp result = xla::Reshape(filter, new_shape.dimensions()); + + // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G] + std::vector transpose_dims(num_dims + 1); + std::iota(transpose_dims.begin(), transpose_dims.end(), 0); + std::swap(transpose_dims[num_spatial_dims], + transpose_dims[num_spatial_dims + 1]); + result = xla::Transpose(result, transpose_dims); + + // 3. Reshape to [H, W, ..., in_depth, out_depth / G] + result = xla::Collapse(result, {num_spatial_dims, num_spatial_dims + 1}); + return result; +} + +// Returns the transposed input for use in BackpropFilter of group convolution. +xla::XlaOp TransposeInputForGroupConvolutionBackpropFilter( + const xla::XlaOp& input, const xla::Shape& input_shape, int64 num_groups, + int batch_dim, int depth_dim) { + // 1. Reshape the depth_dim C into [G, C/G] + int num_dims = input_shape.dimensions_size(); + std::vector reshape_dims = input_shape.dimensions(); + reshape_dims[depth_dim] = reshape_dims[depth_dim] / num_groups; + reshape_dims.insert(reshape_dims.begin() + depth_dim, num_groups); + xla::XlaOp result = xla::Reshape(input, reshape_dims); + + // 2. Transpose G to the axis before N, e.g.: [G, N, H, W, C/G] + std::vector transpose_dims(num_dims + 1); + std::iota(transpose_dims.begin(), transpose_dims.end(), + 0); // e.g.: [0, 1, 2, 3, 4] -> [N, H, W, G, C/G] + transpose_dims.erase(transpose_dims.begin() + depth_dim); + transpose_dims.insert( + transpose_dims.begin() + batch_dim, + depth_dim); // e.g.: [3, 0, 1, 2, 4] -> [G, N, H, W, C/G] + result = xla::Transpose(result, transpose_dims); + + // 3. Merge [G, N] to [G*N] + result = xla::Collapse(result, {batch_dim, batch_dim + 1}); + return result; +} + // Create a mask for depthwise convolution that will make a normal convolution // produce the same results as a depthwise convolution. For a [2, 2, 3, 2] // depthwise filter this returns a [2, 2, 3, 6] tensor @@ -242,10 +294,9 @@ xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, return attrs; } -xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, - xla::XlaOp conv_input, - xla::XlaOp filter, - const ConvOpAttrs& attrs) { +xla::StatusOr MakeXlaForwardConvOp( + StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter, + const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); auto* builder = conv_input.builder(); @@ -269,13 +320,21 @@ xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); - int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims); - // The 'C' dimension for input is in_depth. It must be the same as - // the filter's in_depth. 
- if (in_depth != input_shape.dimensions(feature_dim)) { + int64 filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims), + out_depth = filter_shape.dimensions(attrs.num_spatial_dims + 1), + in_depth = input_shape.dimensions(feature_dim); + // The 'C' dimension for input is in_depth. + // It must be a multiple of the filter's in_depth. + if (in_depth % filter_in_depth != 0) { return errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, " vs ", - input_shape.dimensions(feature_dim)); + "Depth of input must be a multiple of depth of filter: ", in_depth, + " vs ", filter_in_depth); + } + int64 feature_group_count = in_depth / filter_in_depth; + if (out_depth % feature_group_count != 0) { + return errors::InvalidArgument( + "Depth of output must be a multiple of the number of groups: ", + out_depth, " vs ", feature_group_count); } if (attrs.depthwise) { @@ -317,12 +376,15 @@ xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, return xla::ConvGeneralDilated( conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, - dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1); + dims, + /*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count, + /*batch_group_count=*/1, precision_config); } xla::StatusOr MakeXlaBackpropInputConvOp( StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, - xla::XlaOp out_backprop, const ConvOpAttrs& attrs) { + xla::XlaOp out_backprop, const ConvOpAttrs& attrs, + const xla::PrecisionConfig* precision_config) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); int num_dims = attrs.num_spatial_dims + 2; @@ -334,6 +396,10 @@ xla::StatusOr MakeXlaBackpropInputConvOp( TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, builder->GetShape(out_backprop)); + int64 in_depth = input_shape.dimensions(feature_dim), + filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims), + feature_group_count = in_depth / filter_in_depth; + xla::Shape expanded_filter_shape = attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) : filter_shape; @@ -377,24 +443,29 @@ xla::StatusOr MakeXlaBackpropInputConvOp( rhs_dilation[i] = attrs.dilations[dim]; } + if (feature_group_count != 1 && !attrs.depthwise) { + filter = TransposeFilterForGroupConvolutionBackpropInput( + filter, filter_shape, feature_group_count, attrs.num_spatial_dims); + } // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); + filter = xla::Rev(filter, kernel_spatial_dims); // activation gradients // = gradients (with padding and dilation) mirrored_weights return xla::ConvGeneralDilated( - out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, rhs_dilation, dnums, + out_backprop, filter, /*window_strides=*/ones, padding, lhs_dilation, + rhs_dilation, dnums, /*feature_group_count=*/ attrs.depthwise ? 
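The relaxed depth check above enables grouped convolutions: the input depth must be a multiple of the filter's input depth, and the output depth must be a multiple of the resulting group count. A small standalone version of that validation (example channel counts are made up):

```cpp
#include <iostream>
#include <stdexcept>

int ComputeFeatureGroupCount(int in_depth, int filter_in_depth, int out_depth) {
  if (in_depth % filter_in_depth != 0)
    throw std::invalid_argument(
        "Depth of input must be a multiple of depth of filter");
  int feature_group_count = in_depth / filter_in_depth;
  if (out_depth % feature_group_count != 0)
    throw std::invalid_argument(
        "Depth of output must be a multiple of the number of groups");
  return feature_group_count;
}

int main() {
  // 32 input channels, filters that each see 8 channels, 64 output channels
  // -> 4 groups of 16 output channels each.
  std::cout << ComputeFeatureGroupCount(32, 8, 64) << "\n";  // 4
}
```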
out_backprop_shape.dimensions(feature_dim) / filter_shape.dimensions(attrs.num_spatial_dims + 1) - : 1); + : feature_group_count, + /*batch_group_count=*/1, precision_config); } xla::StatusOr MakeXlaBackpropFilterConvOp( StringPiece type_string, xla::XlaOp activations, const xla::Shape& filter_shape, xla::XlaOp gradients, - const ConvOpAttrs& attrs) { + const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) { TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); auto* builder = activations.builder(); @@ -427,17 +498,28 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings)); - // The activations (inputs) form the LHS of the convolution. - // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we flip the roles of the batch and - // feature dimensions. - // Each spatial entry has size in_depth * batch - + // Obtain some useful dimensions: // The last two dimensions of the filter are the input and output shapes. int num_dims = attrs.num_spatial_dims + 2; int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + int64 in_depth = input_shape.dimensions(c_dim), + filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims), + feature_group_count = in_depth / filter_in_depth; + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we need to: + // 1. In the case of group convolution, move the num_groups dimension before + // the batch dimension + // 2. Swap the roles of the batch and feature dimensions. + if (feature_group_count != 1 && !attrs.depthwise) { + activations = TransposeInputForGroupConvolutionBackpropFilter( + activations, input_shape, feature_group_count, n_dim, c_dim); + } + + // In the case of depthwise convolution with no multiplier, + // the computation can be done by the batch_group_count parameter. bool use_batch_group_count = filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise; @@ -532,8 +614,9 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( filter_backprop = xla::ConvGeneralDilated( activations, gradients, window_strides, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums, - /*feature_group_count=*/1, - /*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1); + /*feature_group_count=*/feature_group_count, + /*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1, + precision_config); if (!use_batch_group_count && attrs.depthwise) { filter_backprop = ContractFilterForDepthwiseBackprop( diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index d893eca7f9b..927857a2661 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -53,17 +53,19 @@ struct ConvOpAttrs { // Creates a new XLA forward or backward convolution with the given inputs and // attributes. 
-xla::StatusOr MakeXlaForwardConvOp(StringPiece type_string, - xla::XlaOp conv_input, - xla::XlaOp filter, - const ConvOpAttrs& attrs); +xla::StatusOr MakeXlaForwardConvOp( + StringPiece type_string, xla::XlaOp conv_input, xla::XlaOp filter, + const ConvOpAttrs& attrs, + const xla::PrecisionConfig* precision_config = nullptr); xla::StatusOr MakeXlaBackpropInputConvOp( StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, - xla::XlaOp out_backprop, const ConvOpAttrs& attrs); + xla::XlaOp out_backprop, const ConvOpAttrs& attrs, + const xla::PrecisionConfig* precision_config = nullptr); xla::StatusOr MakeXlaBackpropFilterConvOp( StringPiece type_string, xla::XlaOp activations, const xla::Shape& filter_shape, xla::XlaOp gradients, - const ConvOpAttrs& attrs); + const ConvOpAttrs& attrs, + const xla::PrecisionConfig* precision_config = nullptr); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 234f7b4a019..a709a20c28b 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -48,7 +48,6 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { rhs_shape.DebugString())); return; } - TensorShape bcast_shape = BCast::ToShape(bcast.output_shape()); // Fetch the expressions containing the input tensors. auto lhs_handle = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index e96a1adce43..9fe91d16d77 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -30,11 +31,6 @@ class DepthToSpaceOp : public XlaOpKernel { OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, - errors::InvalidArgument("Unsupported data format ", - ToString(data_format_), - "; expected formats NHWC or NCHW")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -42,19 +38,36 @@ class DepthToSpaceOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_tensor_shape = ctx->InputShape(0); - int input_rank = input_tensor_shape.dims(); + xla::XlaOp input = ctx->Input(0); + + TensorFormat data_format = data_format_; + // If the data is in a vectorized format, reformat it into a non-vectorized + // version first. We'll undo the transformation later. 
+ if (data_format == FORMAT_NCHW_VECT_C) { + data_format = FORMAT_NCHW; + auto input_reshaped = NCHW_VECT_CToNCHW(input); + OP_REQUIRES_OK(ctx, input_reshaped.status()); + input = input_reshaped.ValueOrDie(); + } + + OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_))); + + xla::XlaBuilder* builder = input.builder(); + auto input_xla_shape = builder->GetShape(input); + OP_REQUIRES_OK(ctx, input_xla_shape.status()); + const std::vector& input_shape = + input_xla_shape.ValueOrDie().dimensions(); + int input_rank = input_shape.size(); + static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got: ", input_rank)); - const absl::InlinedVector input_shape = - input_tensor_shape.dim_sizes(); - xla::XlaOp input = ctx->Input(0); - - int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); - int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format); std::vector reshaped_shape; std::vector transpose_order; @@ -62,7 +75,7 @@ class DepthToSpaceOp : public XlaOpKernel { reshaped_shape.reserve(input_rank); transpose_order.reserve(input_rank); output_shape.reserve(input_rank); - if (data_format_ == FORMAT_NHWC) { + if (data_format == FORMAT_NHWC) { reshaped_shape.push_back(input_shape[0]); for (int i = 0; i < num_spatial_dims; ++i) { reshaped_shape.push_back(input_shape[1 + i]); @@ -153,6 +166,14 @@ class DepthToSpaceOp : public XlaOpKernel { // xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); + // If this used to be a vectorized format turn it back now. + if (data_format != data_format_) { + DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C); + auto output_reshaped = NCHWToNCHW_VECT_C(output); + OP_REQUIRES_OK(ctx, output_reshaped.status()); + output = output_reshaped.ValueOrDie(); + } + ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ee79cbc70da..747ec133983 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -29,8 +29,7 @@ namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, - absl::Span other_dims, - xla::PrimitiveType element_type) { + absl::Span other_dims) { xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: // @@ -58,22 +57,17 @@ xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, // select( [f, t, f] , [4, 4, 4] , [0, 0, 0] ) = [0, 4, 0] // [f, f, t]] [9, 9, 9]] [0, 0, 0]] [0, 0, 9]] // - // Broadcasting the input is less-than-trivial, since we need to broadcast - // into a "middle" dimension. We can do this with a reshape + implicit - // broadcast. - // TODO(b/30112114): Replace with in-dim broadcast when those are supported. 
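For reference, the NHWC DepthToSpace semantics that the reshape/transpose sequence above implements can be written element-wise as `output[n][h*B+i][w*B+j][c] = input[n][h][w][(i*B+j)*Cout + c]` with `Cout = Cin / (B*B)`; this is my reading of the op's contract, expressed as a plain loop rather than TensorFlow code:

```cpp
#include <iostream>
#include <vector>

int main() {
  const int N = 1, H = 1, W = 1, B = 2, Cin = 8, Cout = Cin / (B * B);
  std::vector<int> input(N * H * W * Cin);
  for (size_t k = 0; k < input.size(); ++k) input[k] = static_cast<int>(k);

  std::vector<int> output(N * (H * B) * (W * B) * Cout);
  for (int n = 0; n < N; ++n)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int i = 0; i < B; ++i)
          for (int j = 0; j < B; ++j)
            for (int c = 0; c < Cout; ++c) {
              int in_idx = ((n * H + h) * W + w) * Cin + (i * B + j) * Cout + c;
              int oh = h * B + i, ow = w * B + j;
              int out_idx = ((n * H * B + oh) * (W * B) + ow) * Cout + c;
              output[out_idx] = input[in_idx];
            }

  // Each depth chunk of size Cout lands at one spatial position of the
  // B x B block; here this prints 0 1 2 ... 7.
  for (int v : output) std::cout << v << " ";
  std::cout << "\n";
}
```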
- std::vector broadcast_dims(other_dims.begin(), other_dims.end()); - broadcast_dims.push_back(1LL); - broadcast_dims.push_back(last_dim_size); - xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims); + std::vector out_dim_sizes(other_dims.begin(), other_dims.end()); + out_dim_sizes.push_back(last_dim_size); + out_dim_sizes.push_back(last_dim_size); - broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; - auto broadcast_shape = - xla::ShapeUtil::MakeShape(element_type, broadcast_dims); - xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape); - - input_broadcast = xla::Add(input_broadcast, zeros); - return xla::Select(mask, input_broadcast, zeros); + // Broadcast into the second to last dimension. + std::vector broadcast_dimensions(other_dims.size() + 1); + absl::c_iota(broadcast_dimensions, 0); + ++broadcast_dimensions.back(); + xla::XlaOp input_broadcast = + xla::BroadcastInDim(input, out_dim_sizes, broadcast_dimensions); + return xla::Select(mask, input_broadcast, xla::ZerosLike(input_broadcast)); } class DiagOp : public XlaOpKernel { @@ -103,8 +97,7 @@ class DiagOp : public XlaOpKernel { input = xla::Reshape(input, {size}); // Create an R2 with the R1 diagonal. - xla::XlaOp diag = - CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0)); + xla::XlaOp diag = CreateDiagonal(input, size, /*other_dims=*/{}); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); @@ -181,8 +174,7 @@ class MatrixDiagOp : public XlaOpKernel { other_dims.remove_suffix(1); xla::XlaOp input = ctx->Input(0); - xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, - ctx->input_xla_type(0)); + xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims); ctx->SetOutput(0, diag); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc new file mode 100644 index 00000000000..6b3334dc1de --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
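The rewritten `CreateDiagonal` above broadcasts the input into the second-to-last dimension and keeps it only where the row and column iotas agree. A scalar illustration of that mask-and-select construction (this mirrors the idea, not the XLA API):

```cpp
#include <iostream>
#include <vector>

int main() {
  std::vector<double> v = {1.0, 4.0, 9.0};
  const int n = static_cast<int>(v.size());

  std::vector<std::vector<double>> diag(n, std::vector<double>(n, 0.0));
  for (int row = 0; row < n; ++row)
    for (int col = 0; col < n; ++col)
      diag[row][col] = (row == col) ? v[col] : 0.0;  // Select(mask, broadcast, 0)

  for (const auto& row : diag) {
    for (double x : row) std::cout << x << " ";
    std::cout << "\n";
  }
}
```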
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +constexpr std::array kEinsumTypes = {{DT_BFLOAT16, DT_FLOAT}}; + +class EinsumOp : public XlaOpKernel { + public: + explicit EinsumOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("equation", &equation_)); + } + + ~EinsumOp() override = default; + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp lhs = ctx->Input(0); + xla::XlaOp rhs = ctx->Input(1); + const TensorShape a_shape = ctx->InputShape(0); + const TensorShape b_shape = ctx->InputShape(1); + ctx->SetOutput(0, xla::Einsum(lhs, rhs, equation_)); + } + + private: + string equation_; + TF_DISALLOW_COPY_AND_ASSIGN(EinsumOp); +}; + +REGISTER_XLA_OP(Name("XlaEinsum").TypeConstraint("T", kEinsumTypes), EinsumOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 29687c7b82f..d801d560040 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -99,23 +101,22 @@ class ExtractImagePatchesOp : public XlaOpKernel { // The following code is equivalent to: // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) int64 kernel_size = 1; - std::vector lhs_shape(num_dims, 1); + std::vector kernel_shape(num_dims, 1); for (int i = 0; i < num_spatial_dims; ++i) { int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); - lhs_shape[i] = ksizes_[input_dim]; + kernel_shape[i] = ksizes_[input_dim]; kernel_size *= ksizes_[input_dim]; } - lhs_shape[num_spatial_dims] = depth; - lhs_shape[num_spatial_dims + 1] = 1; - - // Builds an identity matrix as a broadcast equality of iotas. 
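The new `XlaEinsum` kernel above takes the contraction as an `equation` string attribute. As a plain illustration of what an equation such as `"ij,jk->ik"` denotes (an ordinary matrix product; a semantic sketch, not the `xla::Einsum` implementation):

```cpp
#include <iostream>
#include <vector>

// out[i][k] = sum_j lhs[i][j] * rhs[j][k]  -- the contraction "ij,jk->ik".
int main() {
  std::vector<std::vector<double>> lhs = {{1, 2}, {3, 4}};         // 2x2
  std::vector<std::vector<double>> rhs = {{5, 6, 7}, {8, 9, 10}};  // 2x3

  std::vector<std::vector<double>> out(2, std::vector<double>(3, 0.0));
  for (int i = 0; i < 2; ++i)
    for (int k = 0; k < 3; ++k)
      for (int j = 0; j < 2; ++j)
        out[i][k] += lhs[i][j] * rhs[j][k];

  for (const auto& row : out) {
    for (double x : row) std::cout << x << " ";
    std::cout << "\n";
  }
}
```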
- // iota = np.arange(np.prod(ksize), depth) - // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth); - - auto lhs = xla::Reshape(iota, lhs_shape); - auto filter = xla::ConvertElementType( - xla::Eq(lhs, iota, {num_spatial_dims + 1}), type); + kernel_shape[num_spatial_dims] = 1; + kernel_shape[num_spatial_dims + 1] = kernel_size * depth; + xla::Shape iota_kernel_shape = + xla::ShapeUtil::MakeShape(xla::S32, {kernel_size, depth, kernel_size}); + xla::XlaOp filter = + xla::Reshape(xla::ConvertElementType( + xla::Eq(xla::Iota(builder, iota_kernel_shape, 0), + xla::Iota(builder, iota_kernel_shape, 2)), + type), + kernel_shape); xla::ConvolutionDimensionNumbers dims; std::vector window_strides(num_spatial_dims); @@ -148,7 +149,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { xla::XlaOp conv = xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, - lhs_dilation, rhs_dilation, dims); + lhs_dilation, rhs_dilation, dims, depth); ctx->SetOutput(0, conv); } diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 142be030f73..96f066d117c 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -82,33 +82,71 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min); } +// Builds a custom_call to a method named 'fake_quant_with_min_max_vars'. +// The method will be provided the input, the min/max range from the original +// TensorFlow op, and the num_bits and narrow_range attributes. +xla::StatusOr BuildFakeQuantCustomCall( + xla::XlaBuilder* b, xla::XlaOp input, xla::XlaOp input_min, + xla::XlaOp input_max, int num_bits, bool narrow_range) { + xla::XlaOp num_bits_arg = + XlaHelpers::IntegerLiteral(b, DataType::DT_INT32, num_bits); + xla::XlaOp narrow_range_arg = narrow_range + ? XlaHelpers::One(b, DataType::DT_BOOL) + : XlaHelpers::Zero(b, DataType::DT_BOOL); + + std::vector args = {input, input_min, input_max, num_bits_arg, + narrow_range_arg}; + std::vector arg_shapes; + for (const xla::XlaOp& arg : args) { + TF_ASSIGN_OR_RETURN(xla::Shape arg_shape, b->GetShape(arg)); + *arg_shape.mutable_layout() = + xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank()); + arg_shapes.push_back(std::move(arg_shape)); + } + + // Input and output shapes match exactly. + TF_ASSIGN_OR_RETURN(xla::Shape output_shape, b->GetShape(input)); + + return xla::CustomCallWithLayout(b, "fake_quant_with_min_max_vars", args, + output_shape, arg_shapes); +} + class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { public: explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - int num_bits; - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); + OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16, errors::InvalidArgument("num_bits is out of range, expected " "between 2 and 16, was: ", - num_bits)); - bool narrow_range; - OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); - quant_min_ = narrow_range ? 1 : 0; - quant_max_ = (1 << num_bits) - 1; + num_bits_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_)); + quant_min_ = narrow_range_ ? 
1 : 0; + quant_max_ = (1 << num_bits_) - 1; - float input_min, input_max; - OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max)); - CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_, + OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max_)); + CpuNudge(input_min_, input_max_, quant_min_, quant_max_, &nudged_input_min_, &nudged_input_max_, &input_scale_); } void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - xla::XlaBuilder* b = ctx->builder(); + if (ctx->compiler()->options().allow_cpu_custom_calls && + ctx->compiler()->options().custom_fake_quant_op_calls) { + xla::XlaOp custom_call_output = + b->ReportErrorOrReturn(BuildFakeQuantCustomCall( + b, input, + XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_min_), + XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_max_), + num_bits_, narrow_range_)); + ctx->SetOutput(0, custom_call_output); + return; + } + xla::XlaOp nudged_input_min = XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); xla::XlaOp nudged_input_max = @@ -121,6 +159,10 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { } private: + int num_bits_; + bool narrow_range_; + float input_min_; + float input_max_; float quant_min_; float quant_max_; float nudged_input_min_; @@ -184,25 +226,32 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel { public: explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - int num_bits; - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); - OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); + OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16, errors::InvalidArgument("num_bits is out of range, expected " "between 2 and 16, was: ", - num_bits)); - bool narrow_range; - OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); - quant_min_ = narrow_range ? 1 : 0; - quant_max_ = (1 << num_bits) - 1; + num_bits_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_)); + quant_min_ = narrow_range_ ? 
1 : 0; + quant_max_ = (1 << num_bits_) - 1; } void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); xla::XlaOp input_min = ctx->Input(1); xla::XlaOp input_max = ctx->Input(2); - xla::XlaBuilder* b = ctx->builder(); + if (ctx->compiler()->options().allow_cpu_custom_calls && + ctx->compiler()->options().custom_fake_quant_op_calls) { + xla::XlaOp custom_call_output = + b->ReportErrorOrReturn(BuildFakeQuantCustomCall( + b, input, input_min, input_max, num_bits_, narrow_range_)); + ctx->SetOutput(0, custom_call_output); + return; + } + xla::XlaOp nudged_input_min, nudged_input_max, input_scale; XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); @@ -213,6 +262,8 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel { } private: + int num_bits_; + bool narrow_range_; float quant_min_; float quant_max_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 2c430e3e55f..5ac288d8a34 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -67,6 +67,13 @@ class GenericFftOp : public XlaOpKernel { } for (int i = 0; i < fft_rank_; i++) { int index = input_shape.dims() - fft_rank_ + i; + OP_REQUIRES( + ctx, + input_shape.dim_size(index) == 0 || + input_shape.dim_size(index) >= expected_sizes[i], + errors::InvalidArgument( + "Input dimension ", index, " must have length of at least ", + expected_sizes[i], " but got: ", input_shape.dim_size(index))); if (input_shape.dim_size(index) > expected_sizes[i]) { slice_sizes[index] = expected_sizes[i]; } else { diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index af1085d5b35..516e3aeaa88 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -51,62 +51,35 @@ class PassOn : public XlaOpKernel { REGISTER_XLA_OP(Name("_ListToArray"), PassOn); REGISTER_XLA_OP(Name("_ArrayToList"), PassOn); -// TODO(phawkins): this is an almost exact copy of the SymbolicGradientOp -// implementation from regular Tensorflow. Once XLA has been open sourced -// merge the two implementations. (Note: this implementation propagates the -// step_resource_manager). 
-class SymbolicGradientOp : public AsyncOpKernel { +class AlwaysFailOp : public OpKernel { public: - explicit SymbolicGradientOp(OpKernelConstruction* ctx) - : AsyncOpKernel(ctx), handle_(-1) {} + explicit AlwaysFailOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - ~SymbolicGradientOp() override {} + ~AlwaysFailOp() override {} - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - FunctionLibraryRuntime* lib = ctx->function_library(); - OP_REQUIRES_ASYNC(ctx, lib != nullptr, - errors::Internal("No function library is provided."), - done); - - OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_), - done); - - FunctionLibraryRuntime::Options opts; - opts.step_id = ctx->step_id(); - opts.runner = ctx->runner(); - opts.step_container = ctx->step_container(); - std::vector args; - args.reserve(ctx->num_inputs()); - for (int i = 0; i < ctx->num_inputs(); ++i) { - args.push_back(ctx->input(i)); - } - std::vector* rets = new std::vector; - lib->Run( - opts, handle_, args, rets, [ctx, done, rets](const Status& status) { - if (!status.ok()) { - ctx->SetStatus(status); - } else if (rets->size() != ctx->num_outputs()) { - ctx->SetStatus(errors::InvalidArgument( - "SymGrad expects to return ", ctx->num_outputs(), - " tensor(s), but get ", rets->size(), " tensor(s) instead.")); - } else { - for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); - } - } - delete rets; - done(); - }); + void Compute(OpKernelContext* ctx) override { + ctx->CtxFailure(errors::FailedPrecondition( + "Unexpected attempt to compile ", name(), " which is a ", type_string(), + ". These nodes should always be handled by the graph compiler")); } - - private: - FunctionLibraryRuntime::Handle handle_; - - TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp); }; -REGISTER_XLA_OP(Name(kGradientOp), SymbolicGradientOp); +// These operations are handled specially in the TF/XLA bridge so their +// OpKernel's should never be called. We still register a dummy kernel so that +// they show up as "supported" when we are deciding whether a graph containing +// them is compilable with XLA. + +REGISTER_XLA_OP(Name(kGradientOp), AlwaysFailOp); +REGISTER_XLA_OP(Name("PartitionedCall") + .AllowResourceTypes() + .AllowVariantTypes() + .AllowStringType(), + AlwaysFailOp); +REGISTER_XLA_OP(Name("StatefulPartitionedCall") + .AllowResourceTypes() + .AllowVariantTypes() + .AllowStringType(), + AlwaysFailOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index aa5637e2669..4422af7d15f 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/if_op.h" +#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" @@ -39,6 +40,33 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } else { has_token_input_output_ = !token_input_nodes_.empty(); } + if (ctx->HasAttr(kPropagateCompileTimeConsts)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts, + &propagate_compile_time_consts_)); + } +} + +Status ConvertCompileTimeConstArgumentsToConst( + XlaOpKernelContext* ctx, std::vector* args) { + for (int i = 0; i < args->size(); i++) { + XlaCompiler::Argument& arg = (*args)[i]; + const XlaExpression& expression = ctx->InputExpression(i + 1); + // If the input tensor is a compile time constant build a kConstant type + // argument. + if (arg.kind == XlaCompiler::Argument::kParameter) { + // NOTE: We can not simply check that this is Kind::kConstant because + // this could be the output of a MetadataOnly op e.g. Size. + xla::StatusOr> maybe_constant = + expression.ResolveConstant(ctx->compiler()->client()); + if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = expression.dtype(); + arg.constant_value = std::move(maybe_constant.ValueOrDie().value()); + arg.shape = expression.GetShape().ValueOrDie(); + } + } + } + return Status::OK(); } // TODO(b/35949885): There is duplication here with the handling of the @@ -87,12 +115,30 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; - arg.shape = ctx->InputShape(i + 1); + // Use the xla::Shape for the input instead of ctx->InputShape. This is + // necessary for forwarding shapes of DT_VARIANTs, e.g. TensorLists. + auto shape_or = ctx->builder()->GetShape(ctx->Input(i + 1)); + OP_REQUIRES_OK(ctx, shape_or.status()); + arg.shape = shape_or.ValueOrDie(); VLOG(2) << "Arg type: " << DataTypeString(arg.type) << " shape: " << arg.HumanString(); } } + if (propagate_compile_time_consts_) { + // Replaces `kParameter` type args in `arguments` with `kConstant` if + // the op input corresponding to that arg is a compile-time const. This + // is necessary to propagate compile time consts to ops in the branch + // functions. + // Note: Propagating "all" compile-time constants may not be necessary. We + // should ideally only propagate consts which are required to be compile + // time constants in the branch functions. But that would require calling + // BackwardsConstAnalysis here which would be expensive. However, if we + // start hitting memory issues we should revisit this. + OP_REQUIRES_OK(ctx, + ConvertCompileTimeConstArgumentsToConst(ctx, &arguments)); + } + // Compile both branches of the conditional. 
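  // Illustrative case for the constant propagation above (an editorial sketch,
  // not taken from the patch itself): if a branch function contains an op whose
  // operand must be a compile-time constant, for example a Reshape whose
  // "shape" input is forwarded from one of the If op's inputs, rewriting that
  // argument from kParameter to kConstant lets the branch resolve the operand
  // at compile time instead of failing compilation.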
XlaCompiler::CompileOptions options; options.use_tuple_arg = true; @@ -215,7 +261,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); } else { - inputs[i] = ctx->Input(i + 1); + inputs[i] = ctx->Input(input_num); } } @@ -280,8 +326,10 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building If"; } -REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); -REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp); -REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); +REGISTER_XLA_OP(Name("If").AllowResourceTypes().AllowVariantTypes(), XlaIfOp); +REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes().AllowVariantTypes(), + XlaIfOp); +REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes().AllowVariantTypes(), + XlaIfOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index 7783e13a8a5..3ac1b344ef8 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -54,6 +54,10 @@ class XlaIfOp : public XlaOpKernel { DataTypeVector output_types_; bool has_token_input_output_; std::vector token_input_nodes_; + // Whether to propagate compile time consts into the cond branches. + // This is not supported by default now since it may cause HBM memory + // overheads. + bool propagate_compile_time_consts_ = false; }; } // namespace tensorflow diff --git a/tensorflow/core/kernels/string_view_variant_wrapper.cc b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc similarity index 83% rename from tensorflow/core/kernels/string_view_variant_wrapper.cc rename to tensorflow/compiler/tf2xla/kernels/if_while_utils.cc index b576eb4a3e6..0011aa29ae2 100644 --- a/tensorflow/core/kernels/string_view_variant_wrapper.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/string_view_variant_wrapper.h" +#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" namespace tensorflow { -constexpr const char StringViewVariantWrapper::kTypeName[]; +const char kPropagateCompileTimeConsts[] = "_xla_propagate_compile_time_consts"; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h new file mode 100644 index 00000000000..4bf76d4da5c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h @@ -0,0 +1,25 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ + +namespace tensorflow { + +extern const char kPropagateCompileTimeConsts[]; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 92b20fe0ba5..dcd523e711d 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -20,11 +20,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -427,7 +429,8 @@ class NonMaxSuppressionOp : public XlaOpKernel { errors::InvalidArgument("XLA compilation requires number of " "boxes to be <= kint32max, got ", num_boxes)); - + xla::PrimitiveType boxes_xla_type = context->InputXlaType("boxes"); + xla::PrimitiveType scores_xla_type = context->InputXlaType("scores"); const xla::XlaOp boxes_input = context->Input("boxes"); const xla::XlaOp scores_input = context->Input("scores"); int64 output_size; @@ -445,15 +448,18 @@ class NonMaxSuppressionOp : public XlaOpKernel { // Choose a more convenient layout. const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0}); const xla::XlaOp boxes_sorted = xla::GetTupleElement( - xla::Sort(/*keys=*/-xla::Broadcast(scores_input, {4}), - /*values=*/{boxes}, + xla::Sort({xla::Broadcast(scores_input, {4}), boxes}, + xla::CreateScalarGtComputation( + {scores_xla_type, boxes_xla_type}, builder), /*dimension=*/1), 1); // Track the mapping of indices into sorted domain. const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes); - const xla::XlaOp indices_sort = xla::Sort(-scores_input, {iota_indices}); + const xla::XlaOp indices_sort = xla::Sort( + {scores_input, iota_indices}, + xla::CreateScalarGtComputation({scores_xla_type, xla::S32}, builder)); const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1); - const xla::XlaOp scores = xla::Neg(xla::GetTupleElement(indices_sort, 0)); + const xla::XlaOp scores = xla::GetTupleElement(indices_sort, 0); // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0. 
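  // Illustrative note on the comparator-based Sort above: rather than negating
  // the scores and sorting ascending, the scores are now sorted directly with a
  // greater-than comparator, e.g. scores [0.2, 0.9, 0.5] come out as
  // [0.9, 0.5, 0.2], and the boxes and iota indices passed as extra operands
  // are permuted along with them.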
const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted, diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index d19d48e5dd9..6d447d9f7c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -555,6 +555,12 @@ class ResizeNearestNeighborOp : public XlaOpKernel { ctx, align_corners_ == true, errors::Unimplemented("ResizeNearestNeighbor with align_corners=False " "is not yet implemented")); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("half_pixel_centers", &half_pixel_centers_)); + OP_REQUIRES(ctx, half_pixel_centers_ == false, + errors::Unimplemented( + "ResizeNearestNeighbor with half_pixel_centers=True is " + "not yet implemented")); } void Compile(XlaOpKernelContext* ctx) override { @@ -563,6 +569,7 @@ class ResizeNearestNeighborOp : public XlaOpKernel { private: bool align_corners_ = true; + bool half_pixel_centers_ = true; bool is_kernel_bilinear_ = false; }; @@ -573,6 +580,12 @@ class ResizeBilinearOp : public XlaOpKernel { public: explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("half_pixel_centers", &half_pixel_centers_)); + OP_REQUIRES( + ctx, half_pixel_centers_ == false, + errors::Unimplemented("ResizeBilinear with half_pixel_centers=True is " + "not yet implemented")); } void Compile(XlaOpKernelContext* ctx) override { @@ -581,6 +594,7 @@ class ResizeBilinearOp : public XlaOpKernel { private: bool align_corners_ = true; + bool half_pixel_centers_ = true; bool is_kernel_bilinear_ = true; }; @@ -591,10 +605,16 @@ class ResizeBilinearGradOp : public XlaOpKernel { public: explicit ResizeBilinearGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("half_pixel_centers", &half_pixel_centers_)); OP_REQUIRES( ctx, align_corners_ == true, errors::Unimplemented("ResizeBilinearGrad with align_corners=False is " "not yet implemented")); + OP_REQUIRES(ctx, half_pixel_centers_ == false, + errors::Unimplemented( + "ResizeBilinearGrad with half_pixel_centers=True is " + "not yet implemented")); DataType output_dtype; OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); @@ -676,6 +696,7 @@ class ResizeBilinearGradOp : public XlaOpKernel { private: bool align_corners_; + bool half_pixel_centers_ = true; xla::PrimitiveType output_type_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc new file mode 100644 index 00000000000..9c6fcf429d4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace { + +class InTopKOp : public XlaOpKernel { + public: + explicit InTopKOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("T", &targets_dtype_)); + OP_REQUIRES_OK(context, + DataTypeToPrimitiveType(targets_dtype_, &targets_type_)); + } + + void Compile(XlaOpKernelContext* context) override { + int64 k; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &k)); + OP_REQUIRES(context, k >= 0, + errors::InvalidArgument("Need k >= 0, got ", k)); + const TensorShape predictions_shape = context->InputShape(0); + OP_REQUIRES( + context, predictions_shape.dims() == 2, + errors::InvalidArgument("predictions must be == 2-D, got shape ", + predictions_shape.DebugString())); + const TensorShape targets_shape = context->InputShape(1); + OP_REQUIRES(context, targets_shape.dims() == 1, + errors::InvalidArgument("targets must be == 1-D, got shape ", + targets_shape.DebugString())); + + int64 batch_size = predictions_shape.dim_size(0); + OP_REQUIRES(context, batch_size == targets_shape.dim_size(0), + errors::InvalidArgument( + "targets must have same elements as predictions rows. Had ", + targets_shape.dim_size(0), ", needed ", batch_size)); + + // Given `predictions` with shape batch_size*num_classes and `target` with + // shape num_classes, we generate `targets_values_r1` with shape num_classes + // which the elements are the corresponding values of `targets` in + // `predictions` for each example. This step can be done using xla::Gather + // as well. + xla::XlaOp predictions_r2 = context->Input(0); + xla::XlaOp targets_r1 = context->Input(1); + + xla::XlaBuilder* xla_builder = context->builder(); + xla::XlaOp iota_r1 = + xla::Iota(xla_builder, targets_type_, predictions_shape.dim_size(1)); + xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {batch_size}); + + xla::XlaOp eq_r2 = xla::Eq(targets_r1, iota_r2, {0}); + xla::XlaOp zero_r0_f32 = xla::Zero(xla_builder, xla::F32); + xla::XlaOp zero_r2_f32 = xla::ZerosLike(predictions_r2); + xla::XlaOp select_r2 = xla::Select(eq_r2, predictions_r2, zero_r2_f32); + xla::XlaOp targets_values_r1 = xla::Reduce( + select_r2, zero_r0_f32, + xla::CreateScalarAddComputation(xla::F32, xla_builder), {1}); + + // Calculate in each row of `predictions`, how many values are larger than + // the value of target class. Then return the result whether the count <= k, + // which indicates the target is in topk. 
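    // Illustrative example of this counting scheme: for a predictions row
    // [0.1, 0.8, 0.3, 0.6] with target class 2, the target value is 0.3; three
    // entries (0.8, 0.3 and 0.6) are >= 0.3, so num_ge is 3 and the row counts
    // as "in top k" only when k >= 3.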
+ xla::XlaOp ge_r2 = xla::Ge(predictions_r2, targets_values_r1, {0}); + xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32); + xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes()); + xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32); + xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes()); + xla::XlaOp one_hot_r2 = xla::Select(ge_r2, one_r2, zero_r2); + xla::XlaOp num_ge_r1 = xla::Reduce( + one_hot_r2, zero_r0, + xla::CreateScalarAddComputation(xla::S32, xla_builder), {1}); + + xla::XlaOp result = + xla::Le(num_ge_r1, xla::ConstantR0(xla_builder, k)); + + context->SetOutput(0, result); + } + + protected: + DataType targets_dtype_; + xla::PrimitiveType targets_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(InTopKOp); +}; + +REGISTER_XLA_OP(Name("InTopKV2") + .CompileTimeConstantInput("k") + .TypeConstraint("T", {DT_INT32, DT_INT64}), + InTopKOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index c1539f48d4f..219dc738eaa 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -31,7 +31,9 @@ limitations under the License. namespace tensorflow { XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min) - : XlaOpKernel(ctx), is_min_(is_min) {} + : XlaOpKernel(ctx), + is_min_(is_min), + is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {} void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { const TensorShape input_shape = ctx->InputShape(0); @@ -64,10 +66,19 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp input = ctx->Input(0); xla::XlaOp output; + // One pass ArgMin/ArgMax is slow on GPUs. if (is_min_) { - output = xla::ArgMin(input, index_xla_type, axis); + if (is_gpu_) { + output = xla::ArgMinTwoPass(input, index_xla_type, axis); + } else { + output = xla::ArgMin(input, index_xla_type, axis); + } } else { - output = xla::ArgMax(input, index_xla_type, axis); + if (is_gpu_) { + output = xla::ArgMaxTwoPass(input, index_xla_type, axis); + } else { + output = xla::ArgMax(input, index_xla_type, axis); + } } ctx->SetOutput(0, output); @@ -76,7 +87,6 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/false) {} REGISTER_XLA_OP(Name("ArgMax") - .Device(DEVICE_GPU_XLA_JIT) .CompileTimeConstantInput("dimension"), XlaArgMaxOp); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.h b/tensorflow/compiler/tf2xla/kernels/index_ops.h index ef2b9e6b6eb..4089a2071e4 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.h @@ -30,6 +30,7 @@ class XlaArgMinMaxOp : public XlaOpKernel { private: const bool is_min_; // Are we computing ArgMin (true) or ArgMax (false)? + const bool is_gpu_; }; class XlaArgMaxOp : public XlaArgMinMaxOp { diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 39d96e748b3..19ec222e2e8 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -16,7 +16,7 @@ limitations under the License. 
#define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/macros.h" @@ -46,4 +46,4 @@ extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } -REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index 9b83392d8fb..6e1c1226321 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -16,7 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/macros.h" @@ -51,4 +51,4 @@ extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } -REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc index 71920bf5c1e..94db561ee65 100644 --- a/tensorflow/compiler/tf2xla/kernels/permute_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/tensor_format.h" @@ -77,7 +78,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel { if (input_rank == 2) { keys = xla::BroadcastInDim(keys, {4, 2}, {0}); } - auto sorted = xla::Sort(keys, {ctx->Input(0)}, 0); + auto sorted = xla::Sort({keys, ctx->Input(0)}, + xla::CreateScalarLtComputation( + {xla::S32, ctx->input_xla_type(0)}, builder), + 0); auto output = xla::GetTupleElement(sorted, 1); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 85223795aa8..507bc8d7a3b 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA specific pooling ops. 
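// Illustrative note on the MaxPoolGradGradOp change later in this file: the
// bf16 round-trip is replaced by ReducePrecision(exponent_bits=8,
// mantissa_bits=7), which keeps only the 16 high bits of each f32 value
// (sign, 8 exponent bits, 7 mantissa bits). For example, the f32 bit pattern
// 0x3F9D70A4 (about 1.23f) becomes 0x3F9D0000, matching what a cast to
// bfloat16 and back would produce, but without relying on a narrowing cast
// that the compiler may elide when xla_allow_excess_precision is set.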
+#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -183,8 +184,7 @@ class MaxPoolOp : public PoolingOp { class MaxPool2DOp : public MaxPoolOp { public: explicit MaxPool2DOp(OpKernelConstruction* ctx) - : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { - } + : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {} }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); REGISTER_XLA_OP(Name("MaxPoolV2") @@ -245,8 +245,7 @@ class AvgPoolOp : public PoolingOp { class AvgPool2DOp : public AvgPoolOp { public: explicit AvgPool2DOp(OpKernelConstruction* ctx) - : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { - } + : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {} }; REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); @@ -329,6 +328,20 @@ class MaxPoolGradOp : public XlaOpKernel { xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + // Create a MaxPool operation to check the expected resulting shape, and + // then throw away the operation because we don't actually neeed it here. + TensorShape expected_out_shape; + auto pooling = + xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding, + XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2)); + auto status_or_shape = pooling.builder()->GetShape(pooling); + OP_REQUIRES_OK(ctx, status_or_shape.status()); + OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(), + &expected_out_shape)); + OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape, + errors::Unimplemented("The output dimensions do not match the " + "other input values.")); + xla::PrimitiveType element_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2)); @@ -454,8 +467,7 @@ class AvgPoolGradOp : public XlaOpKernel { class AvgPool2DGradOp : public AvgPoolGradOp { public: explicit AvgPool2DGradOp(OpKernelConstruction* ctx) - : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { - } + : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {} }; REGISTER_XLA_OP( Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"), @@ -558,10 +570,13 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto b = ctx->builder(); auto sixteen = xla::ConstantR0(b, 16); - // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 + // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32. + // + // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back + // to F32 since the XLA compiler may ignore narrowing casts to floating + // point types if the debug option xla_allow_excess_precision is set. auto in_hi = xla::BitcastConvertType( - xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16), - xla::F32), + xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7), xla::U32); auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32); auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index d6c70d4af1c..0b54c88fae9 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -135,7 +136,9 @@ class RandomShuffleOp : public XlaOpKernel { xla::XlaOp curr = input; for (int i = 0; i < rounds; ++i) { xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape); - xla::XlaOp sorted = xla::Sort(keys, {curr}); + xla::XlaOp sorted = xla::Sort( + {keys, curr}, xla::CreateScalarLtComputation( + {xla::U32, ctx->input_xla_type(0)}, builder)); curr = xla::GetTupleElement(sorted, 1); } diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h index d107be6f13c..9a6dc37e2c9 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -22,6 +22,14 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" namespace tensorflow { +// Returns a tensor containing 'shape' random values uniformly distributed in +// the range [minval, maxval). The raw random bits are generated by the given +// `bit_generator` and converted to the requested data type and range. This +// routine requires 2 32-bit integer seeds and currently only supports 'shape's +// of type F32, S32 and S64. +xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, + xla::XlaOp seeds, const xla::Shape& shape, + xla::XlaOp minval, xla::XlaOp maxval); // Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise. // It masks the last 16 bit. With normal rounding, values near "maxval" would be diff --git a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc new file mode 100644 index 00000000000..46585a26769 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace tensorflow { +namespace { + +class XlaReplicaIdOp : public XlaOpKernel { + public: + explicit XlaReplicaIdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaReplicaIdOp); +}; + +void XlaReplicaIdOp::Compile(XlaOpKernelContext* ctx) { + ctx->SetOutput(0, xla::ReplicaId(ctx->builder())); +} + +REGISTER_XLA_OP(Name("XlaReplicaId"), XlaReplicaIdOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index fa1b6b91710..6cf4c01f1b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel { TensorShape shape; int64 product = 1; int unknown_index = -1; + bool shape_has_zero_dim = false; for (int d = 0; d < num_dims; ++d) { const int32 size = shape_input[d]; if (size == -1) { @@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel { unknown_index, " and ", d)); unknown_index = d; shape.AddDim(1); + } else if (size == 0) { + // We don't include zero-sized dimension in product, so that we can + // still calculate number of elements for non-zero-sized dimensions and + // therefore infer their shapes. + shape.AddDim(size); + shape_has_zero_dim = true; } else { OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument( @@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel { } } if (unknown_index != -1) { - OP_REQUIRES( - ctx, product > 0, - errors::InvalidArgument("Reshape cannot infer the missing input size " - "for an empty tensor unless all specified " - "input sizes are non-zero")); - const int64 missing = input_shape.num_elements() / product; - OP_REQUIRES( - ctx, product * missing == input_shape.num_elements(), - errors::InvalidArgument( - "Input to reshape is a tensor with ", input_shape.num_elements(), - " values, but the requested shape requires a multiple of ", - product)); + int64 input_num_elements = 1; + bool input_has_zero_dim = false; + for (int dim = 0; dim < input_shape.dims(); dim++) { + // For zero dimension, we don't count it into `input_num_elements` + // unless `sizes` has no zero dimension, so we are still able to + // infer shapes for other dimensions. 
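        // Illustrative example: reshaping a [0, 4] tensor to [0, -1, 2] gives
        // product = 2 (the zero-sized dimension is excluded), and
        // input_num_elements = 4 (the input's zero dim is skipped as well), so
        // the missing dimension is inferred as 4 / 2 = 2 and the output shape
        // becomes [0, 2, 2].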
+ if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) { + input_num_elements *= input_shape.dim_size(dim); + } else { + input_has_zero_dim = true; + } + } + + const int64 missing = input_num_elements / product; + if (!input_has_zero_dim) { + OP_REQUIRES( + ctx, product * missing == input_num_elements, + errors::InvalidArgument( + "Input to reshape is a tensor with ", input_num_elements, + " values, but the requested shape requires a multiple of ", + product)); + } shape.set_dim(unknown_index, missing); } OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 1f417037284..058938a46db 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -37,9 +37,11 @@ class RetvalOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const Tensor& input = ctx->op_kernel_context()->input(0); - // DT_VARIANT types represent Tensor Lists and are wrapped in a DT_UINT8 - // tensor so we skip the check here. - if (dtype_ != DT_VARIANT) { + // Types that cannot be copied using memcpy (like DT_VARIANT types that + // represent Tensor Lists) are wrapped in a DT_UINT8 and hence the type + // mismatches. Skip the test in such cases. See + // XlaOpKernelContext::SetOutputExpression for details. + if (DataTypeCanUseMemcpy(dtype_)) { OP_REQUIRES(ctx, input.dtype() == dtype_, errors::InvalidArgument( "Type mismatch: actual ", DataTypeString(input.dtype()), diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index aaf8c6075dd..ed303ba2774 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -62,18 +62,11 @@ class SelectOp : public XlaOpKernel { then_shape.dim_size(0), " vs. ", cond_shape.num_elements())); - // TODO(phawkins): broadcasting on the right seems pretty awkward in - // XLA. It seems we have to broadcast on the left and then Reshape - // to get the dimensions in the right order. - const auto dim_sizes = then_shape.dim_sizes(); - absl::Span bdims = dim_sizes; - bdims.remove_prefix(1); - cond_handle = xla::Broadcast(cond_handle, bdims); - - std::vector dim_order(then_shape.dims()); - dim_order[0] = then_shape.dims() - 1; - std::iota(dim_order.begin() + 1, dim_order.end(), 0); - cond_handle = xla::Transpose(cond_handle, dim_order); + // Broadcast into the dimensions on the right. + std::vector broadcast_dimensions(cond_shape.dims()); + absl::c_iota(broadcast_dimensions, 0); + cond_handle = xla::BroadcastInDim(cond_handle, then_shape.dim_sizes(), + broadcast_dimensions); } ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); } diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 280b68383c2..265e7e784a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -16,6 +16,8 @@ limitations under the License. // XLA-specific Shape Ops. #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -24,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -223,14 +226,42 @@ class ZerosLikeOp : public XlaOpKernel { explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); + if (IsTensorListInput(ctx, 0)) { + // Input is a TensorList. - auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes())); + // Check the TensorList input is initialized. + xla::XlaOp list = ctx->Input(0); + bool is_initialized; + OP_REQUIRES_OK(ctx, IsTensorListInitialized(list, &is_initialized)); + OP_REQUIRES( + ctx, is_initialized, + errors::InvalidArgument( + "TensorList input for ZerosLike op is an uninitialized list")); + + auto list_shape_or = ctx->builder()->GetShape(list); + OP_REQUIRES_OK(ctx, list_shape_or.status()); + xla::XlaOp new_list; + OP_REQUIRES_OK( + ctx, CreateZerosTensorListWithShape( + ctx->builder(), list_shape_or.ValueOrDie(), &new_list)); + + xla::XlaOp push_index; + OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index)); + + xla::XlaOp result; + OP_REQUIRES_OK(ctx, + SetTensorListPushIndex(new_list, push_index, &result)); + ctx->SetTensorListOutput(0, result); + } else { + const TensorShape input_shape = ctx->InputShape(0); + + auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); + ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes())); + } } }; -REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp); +REGISTER_XLA_OP(Name("ZerosLike").AllowVariantTypes(), ZerosLikeOp); class OnesLikeOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index 6cfdf4a5ae4..8cfd9850519 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { @@ -25,7 +26,10 @@ class XlaSortOp : public XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - context->SetOutput(0, xla::Sort(context->Input("input"))); + context->SetOutput(0, xla::Sort({context->Input("input")}, + xla::CreateScalarLtComputation( + {context->InputXlaType("input")}, + context->builder()))); } }; @@ -37,8 +41,11 @@ class XlaKeyValueSortOp : public XlaOpKernel { : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - xla::XlaOp result = - xla::Sort(context->Input("keys"), {context->Input("values")}); + xla::XlaOp result = xla::Sort( + {context->Input("keys"), context->Input("values")}, + xla::CreateScalarLtComputation( + {context->InputXlaType("keys"), context->InputXlaType("values")}, + context->builder())); context->SetOutput(0, xla::GetTupleElement(result, 0)); context->SetOutput(1, xla::GetTupleElement(result, 1)); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 3293c13b21b..96863d6d1ba 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -30,11 +31,6 @@ class SpaceToDepthOp : public XlaOpKernel { OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, - errors::InvalidArgument("Unsupported data format ", - ToString(data_format_), - "; expected formats NHWC or NCHW")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -42,19 +38,36 @@ class SpaceToDepthOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_tensor_shape = ctx->InputShape(0); - int input_rank = input_tensor_shape.dims(); + xla::XlaOp input = ctx->Input(0); + + TensorFormat data_format = data_format_; + // If the data is in a vectorized format, reformat it into a non-vectorized + // version first. We'll undo the transformation later. 
+ if (data_format == FORMAT_NCHW_VECT_C) { + data_format = FORMAT_NCHW; + auto input_reshaped = NCHW_VECT_CToNCHW(input); + OP_REQUIRES_OK(ctx, input_reshaped.status()); + input = input_reshaped.ValueOrDie(); + } + + OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_))); + + xla::XlaBuilder* builder = input.builder(); + auto input_xla_shape = builder->GetShape(input); + OP_REQUIRES_OK(ctx, input_xla_shape.status()); + const std::vector& input_shape = + input_xla_shape.ValueOrDie().dimensions(); + int input_rank = input_shape.size(); + static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got ", input_rank)); - const absl::InlinedVector input_shape = - input_tensor_shape.dim_sizes(); - xla::XlaOp input = ctx->Input(0); - - int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); - int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format); std::vector reshaped_shape; std::vector transpose_order; @@ -62,7 +75,7 @@ class SpaceToDepthOp : public XlaOpKernel { reshaped_shape.reserve(input_rank); transpose_order.reserve(input_rank); output_shape.reserve(input_rank); - if (data_format_ == FORMAT_NHWC) { + if (data_format == FORMAT_NHWC) { int64 block_elems = 1; for (int i = 0; i < num_spatial_dims; ++i) { OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, @@ -157,6 +170,14 @@ class SpaceToDepthOp : public XlaOpKernel { // xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); + // If this used to be a vectorized format turn it back now. + if (data_format != data_format_) { + DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C); + auto output_reshaped = NCHWToNCHW_VECT_C(output); + OP_REQUIRES_OK(ctx, output_reshaped.status()); + output = output_reshaped.ValueOrDie(); + } + ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index def3c147bf3..ff7f0ac6255 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -74,7 +74,7 @@ class SparseToDenseOp : public XlaOpKernel { auto buffer = Broadcast(default_value, output_shape.dim_sizes()); auto result = XlaScatter(buffer, sparse_values, indices, - /*indices_are_vectors=*/num_dims > 1, + /*indices_are_vectors=*/indices_shape.dims() > 1, /*combiner=*/{}, builder); context->SetOutput(0, builder->ReportErrorOrReturn(result)); } diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index f1d68835e12..7e210f57303 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/stateful_random_ops.h" + #include #include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" @@ -29,161 +31,130 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/stateful_random_ops.h" #include "tensorflow/core/lib/math/math_util.h" namespace tensorflow { namespace { -std::pair GetInputsFromCounter( - xla::XlaOp counter, const int64 size) { - auto builder = counter.builder(); - auto input_u64 = Iota(builder, xla::U64, size); - input_u64 = input_u64 + counter; - counter = counter + xla::ConstantR0(builder, size); - return std::make_pair(xla::Uint64ToUint32s(input_u64), counter); -} - -// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too -// wastefully, only able to generate 2^32*2 int32 numbers for each key, while -// the real capacity is 2^64*2. Counter-space efficiency is important for -// stateful ops, hence the following 2 new functions. -std::pair StatefulRngUniformU32( - xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { - auto builder = key.builder(); - const int64 size = xla::ShapeUtil::ElementsIn(shape); - const int64 half_size = xla::CeilOfRatio(size, 2); - const bool size_is_odd = (half_size * 2 != size); - auto inputs_counter = GetInputsFromCounter(counter, half_size); - auto inputs = inputs_counter.first; - counter = inputs_counter.second; - auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key)); - if (size_is_odd) { - outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); +xla::BitGeneratorTy BitGenerator(Algorithm alg) { + if (alg == RNG_ALG_PHILOX) { + return [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { + return xla::PhiloxBitGenerator(key, state, shape, /*scramble=*/false); + }; } - auto result = ConcatInDim(builder, outputs, 0); - return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())), - counter); + return xla::ThreeFryBitGenerator; } -std::pair StatefulRngUniformU64( - xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { - const int64 size = xla::ShapeUtil::ElementsIn(shape); - auto inputs_counter = GetInputsFromCounter(counter, size); - auto inputs = inputs_counter.first; - counter = inputs_counter.second; - auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); - auto result = Uint32sToUint64(outputs); - return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())), - counter); -} - -std::pair StatefulRngUniform(xla::XlaOp key, - xla::XlaOp counter, - const xla::Shape& shape, - xla::XlaOp minval, - xla::XlaOp maxval) { - auto builder = key.builder(); +xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key, + xla::XlaOp initial_state, + const xla::Shape& shape, xla::XlaOp minval, + xla::XlaOp maxval) { xla::PrimitiveType type = shape.element_type(); switch (type) { - case xla::F32: { - auto bits_counter = StatefulRngUniformU32(key, counter, shape); - auto bits = bits_counter.first; - counter = bits_counter.second; - return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval), - counter); - } - case xla::U32: // fall through - case xla::S32: { - auto bits_counter = StatefulRngUniformU32(key, counter, shape); - auto bits = bits_counter.first; - counter = bits_counter.second; - return std::make_pair( - xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32), - counter); - } - case xla::U64: // fall through - case xla::S64: { - auto bits_counter = StatefulRngUniformU64(key, counter, shape); - auto bits = bits_counter.first; - counter = bits_counter.second; - return std::make_pair( - 
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64), - counter); - } + case xla::F32: + return xla::UniformF32Distribution(key, initial_state, BitGenerator(alg), + minval, maxval, shape); + case xla::U32: + case xla::S32: + case xla::U64: + case xla::S64: + return UniformIntDistribution(key, initial_state, BitGenerator(alg), + minval, maxval, shape); default: - return std::make_pair(builder->ReportError(xla::Unimplemented( - "Types other than F32, U32, S32, U64 and S64 " - "are not implemented by " - "StatefulRngUniform.")), - counter); + return {key.builder()->ReportError(xla::Unimplemented( + "Types other than F32, U32, S32, U64 and S64 " + "are not implemented by " + "StatefulRngUniform; got %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; } } -template -std::pair map_first(std::function f, std::pair p) { - return std::make_pair(f(p.first), p.second); -} - -std::pair StatefulRngUniformFullInt( - xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { +xla::RngOutput StatefulRngUniformFullInt(Algorithm alg, xla::XlaOp key, + xla::XlaOp initial_state, + const xla::Shape& shape) { xla::PrimitiveType type = shape.element_type(); + xla::RngOutput output = BitGenerator(alg)(key, initial_state, shape); switch (type) { case xla::U32: - return StatefulRngUniformU32(key, counter, shape); - case xla::S32: { - // Needs explicit function type because of type-inference failure. - std::function f = [](xla::XlaOp x) { - return BitcastConvertType(x, xla::S32); - }; - return map_first(f, StatefulRngUniformU32(key, counter, shape)); - } case xla::U64: - return StatefulRngUniformU64(key, counter, shape); - case xla::S64: { - std::function f = [](xla::XlaOp x) { - return BitcastConvertType(x, xla::S64); - }; - return map_first(f, StatefulRngUniformU64(key, counter, shape)); - } + return output; + case xla::S32: + case xla::S64: + output.value = BitcastConvertType(output.value, type); + return output; default: - auto builder = key.builder(); - return std::make_pair( - builder->ReportError(xla::Unimplemented( + return { + key.builder()->ReportError(xla::Unimplemented( "Types other than U32, S32, U64 and S64 are not implemented by " "StatefulRngUniformFullInt; got: %s", xla::primitive_util::LowercasePrimitiveTypeName(type))), - counter); + initial_state}; } } -template -ListB Map(F f, ListA const& list_a) { - ListB list_b; - for (auto a : list_a) { - list_b.push_back(f(a)); +using SamplerReturnType = xla::StatusOr; + +int64 GetMinStateSize(Algorithm alg) { + if (alg == RNG_ALG_PHILOX) { + return PHILOX_MIN_STATE_SIZE; } - return list_b; + return THREEFRY_MIN_STATE_SIZE; } -xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, - absl::Span scalars) { - return ConcatInDim( - builder, - Map>( - [](xla::XlaOp x) { return xla::Reshape(x, {1}); }, scalars), - 0); +Status CheckStateShape(Algorithm alg, const TensorShape& shape) { + if (shape.dims() != 1) { + return errors::InvalidArgument( + "RNG state must have one and only one dimension, not ", shape.dims()); + } + auto state_size = shape.dim_size(0); + auto min_state_size = GetMinStateSize(alg); + if (state_size < min_state_size) { + return errors::InvalidArgument("The size of the state must be at least ", + min_state_size, "; got ", state_size); + } + return Status::OK(); } -using sampler_return_type = xla::StatusOr>; +std::pair StateAndKeyFromVariable(Algorithm alg, + xla::XlaOp var) { + if (alg == RNG_ALG_THREEFRY) { + static constexpr int kStateSize = 1; + auto state = BitcastConvertType( + 
xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64); + auto key = BitcastConvertType( + xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}), + xla::U64); + return std::make_pair(state, key); + } else { + static constexpr int kStateSize = 2; + auto state = + BitcastConvertType(xla::Slice(var, {0}, {kStateSize}, {1}), xla::U64); + auto key = xla::Reshape( + BitcastConvertType(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), + xla::U64), + {}); + return std::make_pair(state, key); + } +} + +xla::XlaOp StateAndKeyToVariable(Algorithm alg, xla::XlaOp state, + xla::XlaOp key) { + auto builder = state.builder(); + if (alg == RNG_ALG_THREEFRY) { + return ConcatScalars(builder, {state, key}); + } else { + return ConcatInDim(builder, {state, xla::Reshape(key, {1})}, 0); + } +} // A helper function containing the common part of several kernels below. // Precondition: 'algorithm' and 'shape' are compile-time constants. -Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx, - int alg_input_idx, int shape_input_idx, - std::function const& - sample_with_threefry) { +Status CompileImpl( + XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx, + int shape_input_idx, + std::function const& sampler) { auto alg_shape = ctx->InputShape(alg_input_idx); if (alg_shape.dims() != 0) { return errors::InvalidArgument("algorithm must be of shape [], not ", @@ -192,57 +163,77 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx, xla::Literal alg_literal; TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal)); auto alg = alg_literal.Get({}); - - if (alg == RNG_ALG_THREEFRY) { - xla::XlaOp var; - TensorShape var_shape; - TF_RETURN_IF_ERROR(ctx->ReadVariableInput( - state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, &var)); - if (var_shape.dims() != 1) { - return errors::InvalidArgument( - "RNG state must have one and only one dimension, not ", - var_shape.dims()); - } - auto state_size = var_shape.dim_size(0); - if (state_size < THREEFRY_MIN_STATE_SIZE) { - return errors::InvalidArgument( - "For the ThreeFry algorithm, the size of state" - " must be at least ", - THREEFRY_MIN_STATE_SIZE, "; got ", state_size); - } - TensorShape shape; - TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape)); - - static constexpr int COUNTER_SIZE = 1; - auto counter = BitcastConvertType( - xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64); - auto key = BitcastConvertType( - xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}), - {}), - xla::U64); - - auto status_or_value = sample_with_threefry(counter, key, shape); - if (!status_or_value.ok()) { - return status_or_value.status(); - } - auto output_counter = status_or_value.ConsumeValueOrDie(); - auto output = output_counter.first; - counter = output_counter.second; - ctx->SetOutput(0, output); - auto builder = ctx->builder(); - var = ConcatScalars(builder, {counter, key}); - xla::PrimitiveType state_element_type; - TF_RETURN_IF_ERROR( - DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type)); - var = BitcastConvertType(var, state_element_type); - TF_RETURN_IF_ERROR( - ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var)); - return Status::OK(); - } else { + if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) { return errors::InvalidArgument("Unsupported algorithm id: ", alg); } + + xla::XlaOp var; + TensorShape var_shape; + TF_RETURN_IF_ERROR(ctx->ReadVariableInput( + state_input_idx, STATE_ELEMENT_DTYPE, &var_shape, 
&var)); + TF_RETURN_IF_ERROR(CheckStateShape(alg, var_shape)); + TensorShape shape; + TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape)); + xla::XlaOp state; + xla::XlaOp key; + std::tie(state, key) = StateAndKeyFromVariable(alg, var); + auto status_or_value = sampler(alg, state, key, shape); + if (!status_or_value.ok()) { + return status_or_value.status(); + } + xla::RngOutput value_state = status_or_value.ConsumeValueOrDie(); + state = value_state.state; + ctx->SetOutput(0, value_state.value); + var = StateAndKeyToVariable(alg, state, key); + xla::PrimitiveType state_element_type; + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type)); + var = BitcastConvertType(var, state_element_type); + TF_RETURN_IF_ERROR( + ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var)); + return Status::OK(); } +class StatefulUniformOp : public XlaOpKernel { + public: + explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + auto sampler = [builder, this](Algorithm alg, xla::XlaOp state, + xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + xla::RngOutput uniform_state = StatefulRngUniform( + alg, key, state, xla_shape, xla::ConstantR0(builder, 0.0), + xla::ConstantR0(builder, 1.0)); + xla::XlaOp uniform = uniform_state.value; + state = uniform_state.state; + uniform = MaybeConvertF32ToBF16(uniform, dtype_); + return {{uniform, state}}; + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sampler)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulUniformOp); +}; + +// TODO(wangpeng): Support plain float16 and float64 to get rid of the +// `TypeConstraint`. +REGISTER_XLA_OP(Name("StatefulUniform") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}), + StatefulUniformOp); + class StatefulStandardNormalOp : public XlaOpKernel { public: explicit StatefulStandardNormalOp(OpKernelConstruction* ctx) @@ -251,30 +242,20 @@ class StatefulStandardNormalOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto builder = ctx->builder(); - auto sample_with_threefry = + auto sampler = // Needs explicit lambda return type because it fails to be inferred. 
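A note on the state layout used by `CompileImpl` and `StateAndKeyFromVariable` above: the resource variable is a rank-1 int64 tensor holding the counter words followed by the key (one counter word for ThreeFry, two for Philox, hence the different minimum state sizes). The plain C++ below is a minimal model of that split-and-recombine round trip; the struct, names, and the example increment are editorial illustrations, not XLA code.

```cpp
// Illustrative model of how the stateful kernels split the RNG variable into
// (counter state, key) and write it back. Sizes mirror the kernels above:
// ThreeFry stores [counter, key], Philox stores [counter_lo, counter_hi, key].
#include <cassert>
#include <cstdint>
#include <vector>

enum Algorithm { kThreeFry, kPhilox };

struct SplitVar {
  std::vector<uint64_t> state;  // counter words
  uint64_t key;
};

SplitVar StateAndKeyFromVariable(Algorithm alg,
                                 const std::vector<uint64_t>& var) {
  const int state_size = (alg == kThreeFry) ? 1 : 2;
  SplitVar out;
  out.state.assign(var.begin(), var.begin() + state_size);
  out.key = var[state_size];
  return out;
}

std::vector<uint64_t> StateAndKeyToVariable(const SplitVar& sv) {
  std::vector<uint64_t> var = sv.state;
  var.push_back(sv.key);
  return var;
}

int main() {
  std::vector<uint64_t> philox_var = {0, 0, 42};  // counter_lo, counter_hi, key
  SplitVar sv = StateAndKeyFromVariable(kPhilox, philox_var);
  sv.state[0] += 128;  // the sampler advances the counter (illustrative amount)
  assert(StateAndKeyToVariable(sv) == (std::vector<uint64_t>{128, 0, 42}));
  return 0;
}
```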
- [builder, this](xla::XlaOp counter, xla::XlaOp key, - TensorShape shape) -> sampler_return_type { + [this](Algorithm alg, xla::XlaOp state, xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - - auto uniform_counter = StatefulRngUniform( - key, counter, xla_shape, - xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), - xla::ConstantR0(builder, 1.0)); - auto uniform = uniform_counter.first; - counter = uniform_counter.second; - // Convert uniform distribution to normal distribution by computing - // sqrt(2) * erfinv(x) - auto normal = - xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); - normal = MaybeConvertF32ToBF16(normal, dtype_); - return {{normal, counter}}; + xla::RngOutput value_state = + xla::NormalF32Distribution(key, state, BitGenerator(alg), xla_shape); + xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_); + return {{normal, value_state.state}}; }; OP_REQUIRES_OK(ctx, CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, - /*shape_input_idx=*/2, sample_with_threefry)); + /*shape_input_idx=*/2, sampler)); } private: @@ -291,6 +272,51 @@ REGISTER_XLA_OP(Name("StatefulStandardNormalV2") .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}), StatefulStandardNormalOp); +class StatefulTruncatedNormalOp : public XlaOpKernel { + public: + explicit StatefulTruncatedNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* builder = ctx->builder(); + auto sampler = + // Needs explicit lambda return type because it fails to be inferred. + [builder, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + xla::RngOutput uniform_result = StatefulRngUniform( + alg, key, state, xla_shape, + xla::MinPositiveNormalValue(builder, xla_shape.element_type()), + xla::One(builder, xla_shape.element_type())); + xla::XlaOp uniform = uniform_result.value; + state = uniform_result.state; + xla::XlaOp truncated_normal = TruncatedNormal(uniform); + truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_); + return {{truncated_normal, state}}; + }; + OP_REQUIRES_OK(ctx, + CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, + /*shape_input_idx=*/2, sampler)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatefulTruncatedNormalOp); +}; + +// TODO(wangpeng): Support plain float16 and float64 to get rid of the +// `TypeConstraint`. 
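The `{DT_FLOAT, DT_BFLOAT16}` constraints above rely on `MaybeConvertF32ToBF16`, which (as defined later in stateless_random_ops.cc) produces bfloat16 values by masking off the low 16 bits of the float32 bit pattern. A minimal standalone C++ sketch of that truncation, using plain `uint32_t` bit twiddling in place of XLA's `BitcastConvertType`:

```cpp
// Plain-C++ illustration of the float32 -> bfloat16 truncation used by
// MaybeConvertF32ToBF16: keep the upper 16 bits (sign, exponent, top of the
// mantissa) and drop the rest. Editorial sketch only; XLA performs the same
// masking with BitcastConvertType and a U32 AND.
#include <cstdint>
#include <cstdio>
#include <cstring>

float TruncateToBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // bitcast f32 -> u32
  bits &= 0xFFFF0000u;                   // zero the low 16 mantissa bits
  float truncated;
  std::memcpy(&truncated, &bits, sizeof(truncated));
  return truncated;                      // value representable in bfloat16
}

int main() {
  float u = 0.7231884f;
  std::printf("%.9g -> %.9g\n", u, TruncateToBF16(u));
  return 0;
}
```

Note that this is truncation toward the bfloat16 grid rather than round-to-nearest, which matches the `0xFFFF0000` mask in the diff.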
+REGISTER_XLA_OP(Name("StatefulTruncatedNormal") + .CompileTimeConstantInput("algorithm") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}), + StatefulTruncatedNormalOp); + class StatefulUniformIntOp : public XlaOpKernel { public: explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -300,12 +326,12 @@ class StatefulUniformIntOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp minval = ctx->Input(3); xla::XlaOp maxval = ctx->Input(4); - auto sample_with_threefry = [minval, maxval, this]( - xla::XlaOp counter, xla::XlaOp key, - TensorShape shape) -> sampler_return_type { + auto sample_with_threefry = + [minval, maxval, this](Algorithm alg, xla::XlaOp state, xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape)); - return StatefulRngUniform(key, counter, xla_shape, minval, maxval); + return StatefulRngUniform(alg, key, state, xla_shape, minval, maxval); }; OP_REQUIRES_OK(ctx, CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, @@ -333,12 +359,12 @@ class StatefulUniformFullIntOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto sample_with_threefry = [this]( - xla::XlaOp counter, xla::XlaOp key, - TensorShape shape) -> sampler_return_type { + auto sample_with_threefry = [this](Algorithm alg, xla::XlaOp state, + xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape)); - return StatefulRngUniformFullInt(key, counter, xla_shape); + return StatefulRngUniformFullInt(alg, key, state, xla_shape); }; OP_REQUIRES_OK(ctx, CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 43255452cc3..648181eef04 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -33,11 +33,29 @@ limitations under the License. namespace tensorflow { +namespace { + +xla::BitGeneratorTy GetBitGeneratorForDevice( + absl::string_view device_type_string) { + // The Philox algorithm may cause performance regression on other devices. + // Turn on the Philox algorithm for the CPU and GPU backends only. 
+ if (device_type_string == DEVICE_GPU_XLA_JIT || + device_type_string == DEVICE_CPU_XLA_JIT) { + return [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { + return xla::PhiloxBitGenerator(key, state, shape, /*scramble=*/true); + }; + } + + return xla::ThreeFryBitGenerator; +} + +} // namespace + xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { if (dtype == DT_BFLOAT16) { xla::XlaBuilder* builder = input.builder(); - auto output = xla::BitcastConvertType(input, xla::U32) & - xla::ConstantR0(builder, 0xFFFF0000); + xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) & + xla::ConstantR0(builder, 0xFFFF0000); return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32), xla::BF16); } else { @@ -45,12 +63,48 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { } } +xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, + xla::XlaOp seeds, const xla::Shape& shape, + xla::XlaOp minval, xla::XlaOp maxval) { + xla::XlaBuilder* builder = seeds.builder(); + + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {}); + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {}); + xla::XlaOp key = ConvertElementType(seed0, xla::U64) | + ShiftLeft(ConvertElementType(seed1, xla::U64), + ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); + xla::PrimitiveType type = shape.element_type(); + switch (type) { + case xla::F32: + return xla::UniformF32Distribution( + key, initial_state, + GetBitGeneratorForDevice(device_type_string), minval, maxval, + shape) + .value; + case xla::S32: // fall through + case xla::S64: + return UniformIntDistribution( + key, initial_state, + GetBitGeneratorForDevice(device_type_string), minval, maxval, + shape) + .value; + break; + default: + return builder->ReportError(xla::Unimplemented( + "Types other than F32, S32 and S64 are not implemented by " + "StatelessRngUniform; got %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))); + } +} + namespace { class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx) { + : XlaOpKernel(ctx), + device_type_string_(ctx->device_type().type_string()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); } @@ -68,19 +122,17 @@ class StatelessRandomUniformOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - - auto uniform = xla::StatelessRngUniform( - {seed0, seed1}, xla_shape, xla::ConstantR0(builder, 0.0), - xla::ConstantR0(builder, 1.0)); + xla::XlaOp uniform = + StatelessRngUniform(device_type_string_, seed, xla_shape, + xla::ConstantR0(builder, 0.0), + xla::ConstantR0(builder, 1.0)); uniform = MaybeConvertF32ToBF16(uniform, dtype_); ctx->SetOutput(0, uniform); } private: DataType dtype_; + string device_type_string_; TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp); }; @@ -95,7 +147,8 @@ REGISTER_XLA_OP(Name("StatelessRandomUniform") class StatelessRandomUniformIntOp : public XlaOpKernel { public: explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx) { + : XlaOpKernel(ctx), + device_type_string_(ctx->device_type().type_string()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); } @@ -122,17 +175,15 @@ class 
StatelessRandomUniformIntOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + xla::XlaOp uniform = StatelessRngUniform(device_type_string_, seed, + xla_shape, minval, maxval); - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - - auto uniform = - xla::StatelessRngUniform({seed0, seed1}, xla_shape, minval, maxval); ctx->SetOutput(0, uniform); } private: DataType dtype_; + string device_type_string_; TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp); }; @@ -147,7 +198,8 @@ REGISTER_XLA_OP(Name("StatelessRandomUniformInt") class StatelessRandomNormalOp : public XlaOpKernel { public: explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx) { + : XlaOpKernel(ctx), + device_type_string_(ctx->device_type().type_string()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); } @@ -160,27 +212,28 @@ class StatelessRandomNormalOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - xla::XlaBuilder* builder = ctx->builder(); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - - auto uniform = xla::StatelessRngUniform( - {seed0, seed1}, xla_shape, - xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), - xla::ConstantR0(builder, 1.0)); - // Convert uniform distribution to normal distribution by computing - // sqrt(2) * erfinv(x) - auto normal = - xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); + xla::XlaBuilder* builder = seed.builder(); + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); + xla::XlaOp key = ConvertElementType(seed0, xla::U64) | + ShiftLeft(ConvertElementType(seed1, xla::U64), + ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp normal = + xla::NormalF32Distribution( + key, initial_state, GetBitGeneratorForDevice(device_type_string_), + xla_shape) + .value; normal = MaybeConvertF32ToBF16(normal, dtype_); ctx->SetOutput(0, normal); } private: DataType dtype_; + string device_type_string_; TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp); }; @@ -195,7 +248,8 @@ REGISTER_XLA_OP(Name("StatelessRandomNormal") class StatelessTruncatedNormalOp : public XlaOpKernel { public: explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx) { + : XlaOpKernel(ctx), + device_type_string_(ctx->device_type().type_string()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); } @@ -210,22 +264,20 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - auto uniform = xla::StatelessRngUniform( - {seed0, seed1}, xla_shape, + xla::XlaOp uniform = StatelessRngUniform( + device_type_string_, seed, xla_shape, xla::MinPositiveNormalValue(builder, xla_shape.element_type()), xla::One(builder, xla_shape.element_type())); - auto output = 
TruncatedNormal(uniform); - output = MaybeConvertF32ToBF16(output, dtype_); - ctx->SetOutput(0, output); + xla::XlaOp truncated_normal = TruncatedNormal(uniform); + truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_); + ctx->SetOutput(0, truncated_normal); } private: DataType dtype_; + string device_type_string_; TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 67a291d7ead..ac3d2c22d65 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -47,51 +48,44 @@ class TensorListLengthOp : public XlaOpKernel { explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp index; - OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index)); - ctx->SetOutput(0, index); + int64 leading_dim; + OP_REQUIRES_OK(ctx, + GetLeadingDimForTensorList(ctx->Input(0), &leading_dim)); + Tensor length_tensor(DT_INT32, {}); + length_tensor.scalar()() = static_cast(leading_dim); + ctx->SetConstantOutput(0, length_tensor); } private: TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp); }; -REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp); +REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp); -// Creates an empty list with size (leading_dim, *element_shape) if -// element_shape is known at compile time. Otherwise creates one with size -// (leading_dim, 0) which gets initialized later in `GetInitializedList`. -Status CreateZerosList(XlaOpKernelContext* ctx, int element_shape_index, - int64 leading_dim, DataType dtype, xla::XlaOp* list) { - TensorShape list_shape; - list_shape.AddDim(leading_dim); - xla::XlaOp element_shape_handle = ctx->Input(element_shape_index); - TF_ASSIGN_OR_RETURN( - bool is_element_shape_compile_time_const, - element_shape_handle.builder()->IsConstant(element_shape_handle)); - PartialTensorShape partial_element_shape; - if (is_element_shape_compile_time_const) { - TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape( - element_shape_index, &partial_element_shape)); +// "input" is the shape input for EmptyTensorList/TensorListReserve ops. +// If "input" is a compile time constant and not "unknown rank" (-1), return +// its value in "*shape". 
+Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input, + xla::PrimitiveType dtype, bool* got_shape, + xla::Shape* shape) { + auto is_compile_time_constant_or = input.builder()->IsConstant(input); + TF_RETURN_IF_ERROR(is_compile_time_constant_or.status()); + + bool is_compile_time_constant = is_compile_time_constant_or.ValueOrDie(); + if (!is_compile_time_constant) { + *got_shape = false; + return Status::OK(); } - if (is_element_shape_compile_time_const && - partial_element_shape.IsFullyDefined()) { - TensorShape element_shape; - partial_element_shape.AsTensorShape(&element_shape); - list_shape.AppendShape(element_shape); - } else { - // If element_shape is not a compile time constant or if it is not fully - // defined we will have to wait for the first write call to fully allocate - // the array. - // TODO(srbs): We are using element_shape of [0] as a proxy to denote an - // uninitialized list. A better implementation may be to represent the - // list as a 3-tuple containining an explicit "initialized" flag. However, - // we would still need to create a dummy tensor for the first tuple - // element. - list_shape.AddDim(0); + + PartialTensorShape partial_shape; + TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(0, &partial_shape)); + if (!partial_shape.IsFullyDefined()) { + *got_shape = false; + return Status::OK(); } - *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), - list_shape.dim_sizes()); + + *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes()); + *got_shape = true; return Status::OK(); } @@ -99,21 +93,53 @@ class TensorListReserveOp : public XlaOpKernel { public: explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + // Only non-nested TensorList is supported for now. + OP_REQUIRES( + ctx, dtype_ != DT_VARIANT, + errors::Unimplemented( + "Only non-nested TensorList is supported for TensorListReserve.")); } void Compile(XlaOpKernelContext* ctx) override { int64 num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); + OP_REQUIRES( + ctx, num_elements >= 0, + errors::InvalidArgument("XLA compilation requires a fixed tensor list " + "size. Set the number of elements.")); - xla::XlaOp buffer; - OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &buffer)); + // If element shape is compile time constant and it's not "unknown rank" + // shape (-1), create an initialized TensorList. Otherwise create an + // uninitialized TensorList. 
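As the comment above says, the reserve path only materializes a buffer when the element shape is a fully defined compile-time constant. The sketch below models the two resulting representations in plain C++ (a zero-filled buffer plus push index versus a placeholder that only records the leading dimension); it is a conceptual model assuming C++17 and editor-chosen types, not the actual tuple/S32-array encoding used by XLA.

```cpp
#include <cstdint>
#include <variant>
#include <vector>

// Initialized list: a buffer of shape [leading_dim, element...] plus a push index.
struct InitializedList {
  std::vector<float> buffer;  // flattened leading_dim x element_size zeros
  int32_t push_index;
};
// Uninitialized list: only the leading dimension is known; element shape is not.
struct UninitializedList {
  int64_t leading_dim;
};
using TensorListModel = std::variant<InitializedList, UninitializedList>;

TensorListModel Reserve(int64_t num_elements,
                        const std::vector<int64_t>* element_shape) {
  if (element_shape != nullptr) {  // element shape known at compile time
    int64_t element_size = 1;
    for (int64_t d : *element_shape) element_size *= d;
    InitializedList list;
    list.buffer.assign(num_elements * element_size, 0.0f);  // zero-filled buffer
    list.push_index = static_cast<int32_t>(num_elements);   // Reserve sets index to n
    return list;
  }
  return UninitializedList{num_elements};  // defer allocation to the first write
}

int main() {
  std::vector<int64_t> shape = {2, 3};
  TensorListModel a = Reserve(4, &shape);   // initialized, 4 x (2x3) zeros
  TensorListModel b = Reserve(4, nullptr);  // uninitialized placeholder
  return std::holds_alternative<InitializedList>(a) &&
                 std::holds_alternative<UninitializedList>(b)
             ? 0
             : 1;
}
```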
+ xla::XlaOp element_shape_handle = ctx->Input(0); + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type)); + bool got_shape; + xla::Shape element_shape; + OP_REQUIRES_OK(ctx, + TryGetElementShapeFromInput(ctx, element_shape_handle, type, + &got_shape, &element_shape)); + if (got_shape) { + xla::Shape list_shape; + OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape( + element_shape, num_elements, &list_shape)); - xla::XlaOp output_list; - OP_REQUIRES_OK( - ctx, BuildTensorList( - buffer, xla::ConstantR0(ctx->builder(), num_elements), - &output_list)); - ctx->SetTensorListOutput(0, output_list); + xla::XlaOp new_list; + OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape( + ctx->builder(), list_shape, &new_list)); + xla::XlaOp result; + OP_REQUIRES_OK( + ctx, + SetTensorListPushIndex( + new_list, xla::ConstantR0(ctx->builder(), num_elements), + &result)); + ctx->SetTensorListOutput(0, result); + return; + } + + xla::XlaOp result = + BuildUninitializedTensorList(ctx->builder(), num_elements); + ctx->SetTensorListOutput(0, result); } private: @@ -141,15 +167,37 @@ class EmptyTensorListOp : public XlaOpKernel { errors::InvalidArgument("XLA compilation requires a fixed tensor list " "size. Set the max number of elements.")); - xla::XlaOp buffer; - OP_REQUIRES_OK(ctx, - CreateZerosList(ctx, 0, max_num_elements, dtype_, &buffer)); + if (dtype_ != DT_VARIANT) { + // We are creating a non-nested TensorList. + // If element shape is compile time constant and it's not "unknown rank" + // shape (-1), create an initialized TensorList. Otherwise create an + // uninitialized TensorList. + xla::XlaOp element_shape_handle = ctx->Input(0); + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type)); + bool got_shape; + xla::Shape element_shape; + OP_REQUIRES_OK( + ctx, TryGetElementShapeFromInput(ctx, element_shape_handle, type, + &got_shape, &element_shape)); + if (got_shape) { + xla::Shape list_shape; + OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape( + element_shape, max_num_elements, &list_shape)); - xla::XlaOp output_list; - OP_REQUIRES_OK( - ctx, BuildTensorList(buffer, xla::ConstantR0(ctx->builder(), 0), - &output_list)); - ctx->SetTensorListOutput(0, output_list); + xla::XlaOp result; + OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape( + ctx->builder(), list_shape, &result)); + ctx->SetTensorListOutput(0, result); + return; + } + } + + // We are creating a nested TensorList or a non-nested TensorList with + // unknown shape. Just create an uninitialized TensorList. + xla::XlaOp result = + BuildUninitializedTensorList(ctx->builder(), max_num_elements); + ctx->SetTensorListOutput(0, result); } private: @@ -160,7 +208,8 @@ class EmptyTensorListOp : public XlaOpKernel { REGISTER_XLA_OP(Name("EmptyTensorList") .CompileTimeConstantInput("element_shape") - .CompileTimeConstantInput("max_num_elements"), + .CompileTimeConstantInput("max_num_elements") + .AllowVariantTypes(), EmptyTensorListOp); class TensorListElementShapeOp : public XlaOpKernel { @@ -171,18 +220,34 @@ class TensorListElementShapeOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { + // Check that the TensorList is initialized. + bool is_initialized; + OP_REQUIRES_OK(ctx, + (IsTensorListInitialized(ctx->Input(0), &is_initialized))); + OP_REQUIRES(ctx, is_initialized, + errors::InvalidArgument("TensorList is not initialized")); + + // Only non-nested TensorList is supported for now. 
+ bool is_nested; + OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested)); + OP_REQUIRES(ctx, !is_nested, + errors::Unimplemented("Only non-nested TensorList is supported " + "for TensorListElementShape.")); + + // For non-nested TensorList, element shape is the buffer shape without + // the first dimension. xla::XlaBuilder* b = ctx->builder(); - TensorShape shape; - OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); - shape.RemoveDim(0); + xla::Shape list_shape; + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &list_shape)); + list_shape.DeleteDimension(0); switch (shape_type_) { case DT_INT64: - ctx->SetOutput(0, xla::ConstantR1(b, shape.dim_sizes())); + ctx->SetOutput(0, xla::ConstantR1(b, list_shape.dimensions())); break; case DT_INT32: { std::vector size; - for (int64 s : shape.dim_sizes()) { + for (int64 s : list_shape.dimensions()) { size.push_back(s); } ctx->SetOutput(0, xla::ConstantR1(b, size)); @@ -201,7 +266,8 @@ class TensorListElementShapeOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp); }; -REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); +REGISTER_XLA_OP(Name("TensorListElementShape").IsMetadataOp(), + TensorListElementShapeOp); class TensorListGetItemOp : public XlaOpKernel { public: @@ -210,28 +276,27 @@ class TensorListGetItemOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp state = ctx->Input(0); + // Check that the TensorList is initialized. + bool is_initialized; + OP_REQUIRES_OK(ctx, + (IsTensorListInitialized(ctx->Input(0), &is_initialized))); + OP_REQUIRES(ctx, is_initialized, + errors::InvalidArgument("TensorList is not initialized")); - TensorShape shape; - OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); + // Only non-nested TensorList is supported for now. + bool is_nested; + OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested)); + OP_REQUIRES(ctx, !is_nested, + errors::Unimplemented("Only non-nested TensorList is supported " + "for TensorListGetItem.")); - xla::XlaOp buffer; - OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &buffer)); + xla::XlaOp list = ctx->Input(0); xla::XlaOp index = ctx->Input(1); - // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - std::vector start_indices(shape.dims(), - xla::ConstantR0(b, 0)); - start_indices[0] = index; - auto slice_shape = shape.dim_sizes(); - slice_shape[0] = 1LL; + xla::XlaOp result; + OP_REQUIRES_OK(ctx, ExecuteTensorListGetItem(list, index, &result)); - xla::XlaOp read = xla::DynamicSlice(buffer, start_indices, slice_shape); - // Remove the leading '1' dimension. - std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - - ctx->SetOutput(0, xla::Reshape(read, value_shape)); + ctx->SetOutput(0, result); } private: @@ -244,19 +309,29 @@ REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp); class TensorListStackOp : public XlaOpKernel { public: - explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); - } + explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + // Check that the TensorList is initialized. 
+ bool is_initialized; + OP_REQUIRES_OK(ctx, + (IsTensorListInitialized(ctx->Input(0), &is_initialized))); + OP_REQUIRES(ctx, is_initialized, + errors::InvalidArgument("TensorList is not initialized")); + + // Only non-nested TensorList is supported for now. + bool is_nested; + OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested)); + OP_REQUIRES(ctx, !is_nested, + errors::Unimplemented("Only non-nested TensorList is supported " + "for TensorListGetItem.")); + xla::XlaOp buffer; OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer)); ctx->SetOutput(0, buffer); } private: - DataType dtype_; - TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp); }; @@ -265,34 +340,20 @@ REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp); class TensorListFromTensorOp : public XlaOpKernel { public: explicit TensorListFromTensorOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); - } + : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &element_shape)); - - const TensorShape tensor_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, tensor_shape.dims() > 0, - errors::InvalidArgument("Input value must be at least a " - "vector but received shape: ", - tensor_shape.DebugString())); - const int num_elements = tensor_shape.dim_size(0); - - xla::XlaBuilder* b = ctx->builder(); + const TensorShape& tensor_shape = ctx->InputShape(0); + int num_elements = tensor_shape.dim_size(0); const xla::XlaOp tensor = ctx->Input(0); - - xla::XlaOp output_list; - OP_REQUIRES_OK( - ctx, BuildTensorList(tensor, xla::ConstantR0(b, num_elements), - &output_list)); - ctx->SetTensorListOutput(0, output_list); + xla::XlaOp result; + OP_REQUIRES_OK(ctx, + ExecuteTensorListFromTensor(num_elements, tensor, &result)); + auto list_shape_or = ctx->builder()->GetShape(result); + ctx->SetTensorListOutput(0, result); } private: - DataType dtype_; - TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp); }; @@ -300,75 +361,34 @@ REGISTER_XLA_OP( Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"), TensorListFromTensorOp); -// Returns the 0'th element of `tuple` containing the list tensor if it has been -// initialized already else creates one lazily. This allows lazy initialization -// of the list on the first call to SetItem or PushBack. 
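The lazy-initialization idea in the removed comment above carries over to the new helper `GetInitializedTensorListForElement` (added in tensor_list_utils.cc further down): on the first write the buffer shape is derived from the element shape and the leading dimension, and later writes only check consistency. A minimal plain-C++ sketch of that flow, assuming C++17 and illustrative names:

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

struct ListModel {
  std::vector<int64_t> buffer_shape;  // [leading_dim, element dims...]
};

// Mirrors the shape of GetInitializedTensorListForElement: on the first write,
// build the buffer shape from the element shape; afterwards, check consistency.
ListModel GetInitializedListForElement(std::optional<ListModel> list,
                                       int64_t leading_dim,
                                       const std::vector<int64_t>& element_shape) {
  std::vector<int64_t> expected = {leading_dim};
  expected.insert(expected.end(), element_shape.begin(), element_shape.end());
  if (list.has_value()) {
    if (list->buffer_shape != expected)
      throw std::invalid_argument("TensorList shape mismatch");
    return *list;
  }
  return ListModel{expected};  // lazily "allocate" on the first SetItem/PushBack
}

int main() {
  ListModel list =
      GetInitializedListForElement(std::nullopt, /*leading_dim=*/8, {2, 3});
  // A later push with the same element shape reuses the existing list.
  GetInitializedListForElement(list, 8, {2, 3});
  return 0;
}
```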
-Status GetInitializedList(const xla::XlaOp& input_list, - const TensorShape& element_shape, DataType dtype, - xla::XlaOp* output_list_buffer) { - bool is_already_initialized; - TF_RETURN_IF_ERROR( - IsTensorListInitialized(input_list, &is_already_initialized)); - TensorShape input_list_shape; - TF_RETURN_IF_ERROR(GetTensorListBufferShape(input_list, &input_list_shape)); - TensorShape input_list_element_shape = input_list_shape; - input_list_element_shape.RemoveDim(0); - - if (is_already_initialized) { - TF_RET_CHECK(element_shape == input_list_element_shape); - TF_RETURN_IF_ERROR(GetTensorListBuffer(input_list, output_list_buffer)); - return Status::OK(); - } - - int64 leading_dim = input_list_shape.dim_size(0); - TensorShape output_list_shape = element_shape; - output_list_shape.InsertDim(0, leading_dim); - - xla::XlaOp output_list; - TF_RETURN_IF_ERROR( - InitializeTensorList(input_list, output_list_shape, &output_list)); - TF_RETURN_IF_ERROR(GetTensorListBuffer(output_list, output_list_buffer)); - return Status::OK(); -} - class TensorListSetItemOp : public XlaOpKernel { public: - explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); - } + explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp tl = ctx->Input(0); - TensorShape elem_shape = ctx->InputShape(2); - - xla::XlaOp buffer; - OP_REQUIRES_OK(ctx, GetInitializedList(tl, elem_shape, dtype_, &buffer)); - xla::XlaOp push_index; - OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tl, &push_index)); - + xla::XlaOp list = ctx->Input(0); xla::XlaOp index = ctx->Input(1); - xla::XlaOp value = ctx->Input(2); + xla::XlaOp element = ctx->Input(2); + xla::XlaOp initialized_list; + OP_REQUIRES_OK(ctx, GetInitializedTensorListForElement( + list, element, /*element_is_tensor_list=*/false, + &initialized_list)); - // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - std::vector start_indices(elem_shape.dims() + 1, - xla::ConstantR0(b, 0)); - start_indices[0] = index; + // Only non-nested TensorList is supported for now. 
+ bool is_nested; + OP_REQUIRES_OK(ctx, IsNestedTensorList(initialized_list, &is_nested)); + OP_REQUIRES(ctx, !is_nested, + errors::Unimplemented("Only non-nested TensorList is supported " + "for TensorListSetItem.")); - TensorShape slice_shape = elem_shape; - slice_shape.InsertDim(0, 1LL); - auto update = xla::Reshape(value, slice_shape.dim_sizes()); + xla::XlaOp result; + OP_REQUIRES_OK(ctx, ExecuteTensorListSetItem(initialized_list, index, + element, &result)); - xla::XlaOp output_list; - OP_REQUIRES_OK(ctx, BuildTensorList(xla::DynamicUpdateSlice(buffer, update, - start_indices), - push_index, &output_list)); - ctx->SetTensorListOutput(0, output_list); + ctx->SetTensorListOutput(0, result); } private: - DataType dtype_; - TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp); }; @@ -376,83 +396,57 @@ REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp); class TensorListPushBackOp : public XlaOpKernel { public: - explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); - } + explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp list_tuple = ctx->Input(0); - TensorShape elem_shape = ctx->InputShape(1); - - xla::XlaOp buffer; - OP_REQUIRES_OK(ctx, - GetInitializedList(list_tuple, elem_shape, dtype_, &buffer)); - - xla::XlaOp index; - OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list_tuple, &index)); - xla::XlaOp value = ctx->Input(1); - - // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - std::vector start_indices(elem_shape.dims() + 1, - xla::ConstantR0(b, 0)); - start_indices[0] = index; - - TensorShape slice_shape = elem_shape; - slice_shape.InsertDim(0, 1LL); - auto update = xla::Reshape(value, slice_shape.dim_sizes()); - - xla::XlaOp output_list; + xla::XlaOp list = ctx->Input(0); + xla::XlaOp element = ctx->Input(1); + bool element_is_tensor_list = IsTensorListInput(ctx, 1); + xla::XlaOp initialized_list; OP_REQUIRES_OK( - ctx, - BuildTensorList(xla::DynamicUpdateSlice(buffer, update, start_indices), - index + xla::ConstantR0(b, 1), &output_list)); - ctx->SetTensorListOutput(0, output_list); + ctx, GetInitializedTensorListForElement( + list, element, element_is_tensor_list, &initialized_list)); + + xla::XlaOp result; + OP_REQUIRES_OK(ctx, + ExecuteTensorListPushBack(initialized_list, element, + element_is_tensor_list, &result)); + + ctx->SetTensorListOutput(0, result); } private: - DataType dtype_; - TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp); }; -REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp); +REGISTER_XLA_OP(Name("TensorListPushBack").AllowVariantTypes(), + TensorListPushBackOp); class TensorListPopBackOp : public XlaOpKernel { public: - explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); - } + explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp state = ctx->Input(0); + // Check that the TensorList is initialized. 
+ bool is_initialized; + OP_REQUIRES_OK(ctx, + (IsTensorListInitialized(ctx->Input(0), &is_initialized))); + OP_REQUIRES(ctx, is_initialized, + errors::InvalidArgument("TensorList is not initialized")); - TensorShape shape; - OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &shape)); + xla::XlaOp list = ctx->Input(0); + xla::XlaOp list_result, element_result; + bool element_is_tensor_list; + OP_REQUIRES_OK(ctx, + ExecuteTensorListPopBack(list, &list_result, &element_result, + &element_is_tensor_list)); - xla::XlaOp ta; - OP_REQUIRES_OK(ctx, GetTensorListBuffer(state, &ta)); - xla::XlaOp index; - OP_REQUIRES_OK(ctx, GetTensorListPushIndex(state, &index)); - - index = index - xla::ConstantR0(b, 1); - - // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - std::vector start_indices(shape.dims(), - xla::ConstantR0(b, 0)); - start_indices[0] = index; - auto slice_shape = shape.dim_sizes(); - slice_shape[0] = 1LL; - - xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); - // Remove the leading '1' dimension. - std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - - xla::XlaOp output_list; - OP_REQUIRES_OK(ctx, BuildTensorList(ta, index, &output_list)); - ctx->SetTensorListOutput(0, output_list); - ctx->SetOutput(1, xla::Reshape(read, value_shape)); + ctx->SetTensorListOutput(0, list_result); + if (element_is_tensor_list) { + ctx->SetTensorListOutput(1, element_result); + } else { + ctx->SetOutput(1, element_result); + } } private: @@ -461,7 +455,8 @@ class TensorListPopBackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp); }; -REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp); +REGISTER_XLA_OP(Name("TensorListPopBack").AllowVariantTypes(), + TensorListPopBackOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index aa6ee2ac35e..579c9ac33b0 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -14,87 +14,481 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" + #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/errors.h" +// TensorList is represented by a tuple. +// - The first part of the tuple is a buffer containing all the tensors, +// - The following parts are push indices for all nested levels of +// TensorLists. The last part is push index for the outermost TensorList. +// +// TensorList, as it name suggests, is conceptually a list of tensors. In actual +// representation of a non-nested TensorList, the buffer shape is +// [tensor_list_size, element shape]. We will call tensor_list_size "leading +// dimension" below. Notice that the leading dimension must be a compile time +// constant, since it's part of the buffer shape. +// +// Example: consider a 3-level nested TensorList whose element type is scalar. 
+// Assume inner TensorList has leading dimension 4, middle TensorList has 3, +// and outer TensorList has 3. +// Assume that lower cased letter means there is data in that position, and "." +// means there is no data in that position. +// First element of outer TensorList: +// [ a . . . ] +// [ b c . . ] +// [ d e f . ] +// Second element of outer TensorList: +// [ g h i . ] +// [ j k . . ] +// [ . . . . ] +// Third element: not pushed yet. +// +// The first part of the tuple is an array of shape [3, 3, 4] containing data. +// The second part is an array of shape [3, 3], each element is push index +// for the inner TensorList. In this case, its values are: +// [ 1 2 3 ] +// [ 3 2 . ] +// [ . . . ] +// The third part is an array of shape [3], each element is push index for +// the middle TensorList. In this case, its values are: +// [ 3 ] +// [ 2 ] +// [ . ] +// The forth (and last) part is a scalar. It's the push index for the outer +// TensorList. In this case, its values is 2. +// +// Now imagine we need to push the following element to the outer TensorList: +// [ l . . . ] +// [ m n . . ] +// [ . . . . ] +// This element is represented by a tuple of 3 parts: +// First part is all data. +// Second part is push indices for the inner TensorList, which is [ 1 2 . ]. +// Third part is push index for the middle TensorList, which is 2. +// Now let's do the push. +// First, we append its data to outer TensorList's data. +// Then we start to deal with push indices. Similar to data, we append push +// indices for each level of TensorList. +// For the inner TensorList: append push indices for the pushed element. +// [ 1 2 3 ] [ 1 2 3 ] +// [ 3 2 . ] + = [ 3 2 . ] +// [ . . . ] [ 1 2 . ] [ 1 2 . ] +// For the middle TensorList: append push indices for the pushed element. +// [ 3 ] [ 3 ] +// [ 2 ] + = [ 2 ] +// [ . ] [ 2 ] [ 2 ] +// For the outer TensorList: just add 1. +// 2 + 1 = 3 +// +// Popping an element from the outer TensorList also follows a similar process. +// First part is data. We get data by slicing data with push index for outer +// TensorList (which is 3). +// Second part is push indices for inner TensorList. We get it by slicing +// push indices for inner TensorList with push index for outer TensorList (which +// is 3). +// [ 1 2 3 ] +// [ 3 2 . ] +// [ 1 2 . ] ===> This is what we want +// Third part is push index for middle TensorList. We get it by slicing +// push indices for middle TensorList with push index for outer TensorList +// (which is 3). 
+// [ 3 ] +// [ 2 ] +// [ 2 ] ===> This is what we want + namespace tensorflow { bool IsTensorListInput(XlaOpKernelContext* ctx, int index) { return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList; } -Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index, - xla::XlaOp* output_list) { +Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) { + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); + *is_initialized = list_shape.IsTuple(); + return Status::OK(); +} + +Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); + } + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); + *is_nested_list = (xla::ShapeUtil::TupleElementCount(list_shape) > 2); + return Status::OK(); +} + +Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, + xla::XlaOp* output_list) { TF_RET_CHECK(buffer.builder()); *output_list = xla::Tuple(buffer.builder(), {buffer, push_index}); return Status::OK(); } -Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) { - TF_RET_CHECK(op.builder()); - *buffer = xla::GetTupleElement(op, 0); - return Status::OK(); -} - -Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index) { - TF_RET_CHECK(op.builder()); - *push_index = xla::GetTupleElement(op, 1); - return Status::OK(); -} - -Status GetTensorListBufferShape(const xla::XlaOp& op, - TensorShape* buffer_shape) { - TF_RET_CHECK(op.builder()); - TensorShape shape; - TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape, - op.builder()->GetShape(op)); - return GetTensorListBufferShape(list_tuple_shape, buffer_shape); -} - -Status GetTensorListBufferShape(const xla::Shape& list_shape, - TensorShape* buffer_shape) { - TF_RET_CHECK(list_shape.IsTuple()); - TF_RETURN_IF_ERROR(XLAShapeToTensorShape( - xla::ShapeUtil::GetTupleElementShape(list_shape, 0), buffer_shape)); - return Status::OK(); -} - -Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized) { - TensorShape list_shape; - TF_RETURN_IF_ERROR(GetTensorListBufferShape(op, &list_shape)); - *is_initialized = !(list_shape.dims() == 2 && list_shape.dim_size(1) == 0); - return Status::OK(); -} - -Status InitializeTensorList(const xla::XlaOp& uninitialized_list, - const TensorShape& buffer_shape, - xla::XlaOp* output_list) { - TensorShape input_buffer_shape; - TF_RETURN_IF_ERROR( - GetTensorListBufferShape(uninitialized_list, &input_buffer_shape)); - if (input_buffer_shape.dim_size(0) != buffer_shape.dim_size(0)) { - return errors::InvalidArgument( - "Number of elements in input list does not match buffer size. 
", - "input list size: ", input_buffer_shape.dim_size(0), - "buffer size: ", buffer_shape.dim_size(0)); +Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); } - xla::XlaBuilder* builder = uninitialized_list.builder(); - xla::XlaOp input_buffer; - TF_RETURN_IF_ERROR(GetTensorListBuffer(uninitialized_list, &input_buffer)); - TF_ASSIGN_OR_RETURN(const xla::Shape& input_buffer_xla_shape, - builder->GetShape(input_buffer)); - auto new_buffer = xla::Broadcast( - xla::ConstantLiteral(builder, xla::LiteralUtil::Zero( - input_buffer_xla_shape.element_type())), - buffer_shape.dim_sizes()); - xla::XlaOp push_index; - TF_RETURN_IF_ERROR(GetTensorListPushIndex(uninitialized_list, &push_index)); - return BuildTensorList(new_buffer, push_index, output_list); + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); + *buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); + return Status::OK(); +} + +Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); + } + *buffer = xla::GetTupleElement(list, 0); + return Status::OK(); +} + +Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); + } + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); + int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); + *push_index = xla::GetTupleElement(list, tuple_size - 1); + return Status::OK(); +} + +Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, + xla::XlaOp* result) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); + } + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); + int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); + std::vector result_parts; + result_parts.reserve(tuple_size); + for (int i = 0; i < tuple_size - 1; i++) { + result_parts.push_back(xla::GetTupleElement(list, i)); + } + result_parts.push_back(push_index); + *result = xla::Tuple(list.builder(), result_parts); + return Status::OK(); +} + +xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, + int64 leading_dimension) { + auto zero = + xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32)); + return xla::Broadcast(zero, std::vector{leading_dimension}); +} + +Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); + if (is_initialized) { + auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); + *leading_dim = buffer_shape.dimensions(0); + } else { + *leading_dim = list_shape.dimensions(0); + } + return Status::OK(); +} + +Status GetTensorListShapeFromElementTensorListShape( + const xla::Shape& element_tensor_list_shape, int64 leading_dim, + xla::Shape* 
tensor_list_shape) { + std::vector shapes; + int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape); + for (int i = 0; i < tuple_size; i++) { + const xla::Shape& shape = + xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i); + std::vector dimensions = shape.dimensions(); + dimensions.insert(dimensions.begin(), leading_dim); + shapes.push_back( + xla::ShapeUtil::MakeShape(shape.element_type(), dimensions)); + } + shapes.push_back( + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); + *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes); + return Status::OK(); +} + +Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, + int64 leading_dim, + xla::Shape* tensor_list_shape) { + if (!element_shape.IsArray()) { + return errors::InvalidArgument( + "GetTensorListShapeFromElementShape() only supports normal tensor " + "shape. But element shape is ", + element_shape.DebugString()); + } + + std::vector shapes; + std::vector dimensions = element_shape.dimensions(); + dimensions.insert(dimensions.begin(), leading_dim); + shapes.push_back( + xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions)); + shapes.push_back( + xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector{})); + *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes); + return Status::OK(); +} + +Status CreateZerosTensorListWithShape(xla::XlaBuilder* b, + const xla::Shape& list_shape, + xla::XlaOp* list) { + int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); + std::vector elements; + for (int i = 0; i < tuple_size; i++) { + const xla::Shape& shape = + xla::ShapeUtil::GetTupleElementShape(list_shape, i); + xla::XlaOp zero = + xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type())); + xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions()); + elements.push_back(zeros); + } + *list = xla::Tuple(b, elements); + return Status::OK(); +} + +Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* initialized_list) { + int64 leading_dim; + TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(list, &leading_dim)); + + xla::XlaBuilder* b = list.builder(); + xla::Shape list_shape; + if (element_is_tensor_list) { + TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element)); + TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape( + element_shape, leading_dim, &list_shape)); + } else { + TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element)); + TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape( + element_shape, leading_dim, &list_shape)); + } + + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (is_initialized) { + // Check shape of initialized list is correct. 
+ TF_ASSIGN_OR_RETURN(xla::Shape original_list_shape, b->GetShape(list)); + if (!xla::ShapeUtil::Equal(original_list_shape, list_shape)) { + return errors::Internal( + "Invalid TensorList shape: ", original_list_shape.DebugString(), + ", expected: ", list_shape.DebugString()); + } + *initialized_list = list; + return Status::OK(); + } else { + return CreateZerosTensorListWithShape(b, list_shape, initialized_list); + } +} + +Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* result) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); + } + + xla::XlaBuilder* b = list.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list)); + int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); + xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1); + + std::vector result_parts; + + if (element_is_tensor_list) { + TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element)); + int element_tuple_size = xla::ShapeUtil::TupleElementCount(element_shape); + for (int i = 0; i < element_tuple_size; i++) { + const xla::Shape& element_part_shape = + xla::ShapeUtil::GetTupleElementShape(element_shape, i); + xla::XlaOp element_part = xla::GetTupleElement(element, i); + std::vector element_part_dims = element_part_shape.dimensions(); + element_part_dims.insert(element_part_dims.begin(), 1); + element_part = xla::Reshape(element_part, element_part_dims); + + std::vector start_indices( + element_part_shape.dimensions_size() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = push_index; + + xla::XlaOp list_part = xla::GetTupleElement(list, i); + xla::XlaOp updated_list_part = + xla::DynamicUpdateSlice(list_part, element_part, start_indices); + result_parts.push_back(updated_list_part); + } + } else { + TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element)); + std::vector element_dims = element_shape.dimensions(); + element_dims.insert(element_dims.begin(), 1); + xla::XlaOp update = xla::Reshape(element, element_dims); + + std::vector start_indices(element_shape.dimensions_size() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = push_index; + + xla::XlaOp list_part = xla::GetTupleElement(list, 0); + xla::XlaOp updated_list_part = + xla::DynamicUpdateSlice(list_part, update, start_indices); + result_parts.push_back(updated_list_part); + } + + xla::XlaOp updated_push_index = push_index + xla::ConstantR0(b, 1); + result_parts.push_back(updated_push_index); + + *result = xla::Tuple(b, result_parts); + return Status::OK(); +} + +Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, + xla::XlaOp* element_result, + bool* element_is_tensor_list) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); + } + + // If the TensorList is a nested TensorList, element will be TensorList. 
+ TF_RETURN_IF_ERROR(IsNestedTensorList(list, element_is_tensor_list)); + + xla::XlaBuilder* b = list.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list)); + int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); + xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1); + push_index = push_index - xla::ConstantR0(b, 1); + + std::vector list_result_parts, element_result_parts; + for (int i = 0; i < list_tuple_size - 1; i++) { + const xla::Shape& list_part_shape = + xla::ShapeUtil::GetTupleElementShape(list_shape, i); + std::vector start_indices(list_part_shape.dimensions_size(), + xla::ConstantR0(b, 0)); + start_indices[0] = push_index; + + std::vector slice_shape = list_part_shape.dimensions(); + slice_shape[0] = 1LL; + + xla::XlaOp list_part = xla::GetTupleElement(list, i); + xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); + + slice_shape.erase(slice_shape.begin()); + element_result_parts.push_back(xla::Reshape(read, slice_shape)); + list_result_parts.push_back(list_part); + } + list_result_parts.push_back(push_index); + + *list_result = xla::Tuple(b, list_result_parts); + if (*element_is_tensor_list) { + *element_result = xla::Tuple(b, element_result_parts); + } else { + *element_result = element_result_parts[0]; + } + + return Status::OK(); +} + +Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp element, xla::XlaOp* result) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); + } + bool is_nested; + TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested)); + if (is_nested) { + return errors::Unimplemented( + "ExecuteTensorListSetItem() only supports non-nested TensorList"); + } + + xla::XlaBuilder* b = list.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element)); + std::vector element_dims = element_shape.dimensions(); + element_dims.insert(element_dims.begin(), 1); + xla::XlaOp update = xla::Reshape(element, element_dims); + + std::vector start_indices(element_shape.dimensions_size() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; + + xla::XlaOp list_part = xla::GetTupleElement(list, 0); + xla::XlaOp updated_list_part = + xla::DynamicUpdateSlice(list_part, update, start_indices); + + std::vector result_parts; + result_parts.push_back(updated_list_part); + result_parts.push_back(xla::GetTupleElement(list, 1)); + *result = xla::Tuple(b, result_parts); + return Status::OK(); +} + +Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp* result) { + bool is_initialized; + TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); + if (!is_initialized) { + return errors::InvalidArgument("TensorList is not initialized"); + } + bool is_nested; + TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested)); + if (is_nested) { + return errors::Unimplemented( + "ExecuteTensorListGetItem() only supports non-nested TensorList"); + } + + xla::XlaBuilder* b = list.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list)); + const xla::Shape& buffer_shape = + xla::ShapeUtil::GetTupleElementShape(list_shape, 0); + std::vector start_indices(buffer_shape.dimensions_size(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; + + std::vector slice_shape = buffer_shape.dimensions(); + slice_shape[0] = 1LL; + + xla::XlaOp list_part = xla::GetTupleElement(list, 0); + xla::XlaOp read = 
xla::DynamicSlice(list_part, start_indices, slice_shape); + + slice_shape.erase(slice_shape.begin()); + *result = xla::Reshape(read, slice_shape); + return Status::OK(); +} + +Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, + xla::XlaOp* result) { + xla::XlaBuilder* b = tensor.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape shape, b->GetShape(tensor)); + if (!shape.IsArray()) { + return errors::InvalidArgument( + "ExecuteTensorListFromTensor() only supports normal tensor. But input " + "shape is ", + shape.DebugString()); + } + + std::vector result_parts{tensor, + xla::ConstantR0(b, push_index)}; + *result = xla::Tuple(b, result_parts); + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h index 937af6f8d77..7fac2d9dbab 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -16,12 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ -// TensorList utilities. -// -// Tensor lists are represented as tuple consisting of a pre-allocated buffer -// consisting of the tensors (and where dim 0 is the list index), along with a -// scalar telling us the next index to push a value at. - #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -31,36 +25,97 @@ namespace tensorflow { // Whether the input expression at `index` corresponds to a TensorList. bool IsTensorListInput(XlaOpKernelContext* ctx, int index); -// Builds a TensorList from its constituents, `buffer` and `push_index`. -Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index, - xla::XlaOp* output_list); +// Whether the TensorList is initialized (has known data type and shape). +Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized); -// Returns the buffer for the TensorList. -Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer); +// Whether the TensorList is a nested TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list); -// Returns the push_index for the TensorList. -Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index); +// Builds a non-nested TensorList from `buffer` and `push_index`. +Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, + xla::XlaOp* output_list); -// Returns the shape of the TensorList buffer. -Status GetTensorListBufferShape(const xla::XlaOp& op, - TensorShape* buffer_shape); +// Returns buffer shape for the TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape); -// Inputs the TensorList shape and returns the buffer shape. -Status GetTensorListBufferShape(const xla::Shape& list_shape, - TensorShape* buffer_shape); +// Returns buffer for the TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer); -// Returns whether the TensorList has been initialized. 
-// -// A TensorList is considered initialized if its element_shape is completely -// known. -Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized); +// Returns push index for the TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index); -// Inputs an uninitialized list and a buffer_shape and returns an initialized -// list. The initialized list uses the dtype and push index of the uninitialized -// list and is filled with zeros. -Status InitializeTensorList(const xla::XlaOp& uninitialized_list, - const TensorShape& buffer_shape, - xla::XlaOp* output_list); +// Returns a new TensorList with given push_index. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, + xla::XlaOp* result); + +// Returns an uninitialized TensorList. +xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, + int64 leading_dimension); + +// Returns leading dimension for the TensorList. +// Input can be initialized or uninitialized TensorList. +// Non-nested and nested TensorLists are both supported. +Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim); + +// Returns TensorList shape for the element shape. +// Element shape must be a normal tensor shape. +Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, + int64 leading_dim, + xla::Shape* tensor_list_shape); + +// Returns a TensorList filled by zeros with the given shape. +Status CreateZerosTensorListWithShape(xla::XlaBuilder* b, + const xla::Shape& list_shape, + xla::XlaOp* list); + +// If the TensorList is initialized, check that its shape matches element shape; +// If the TensorList is uninitialized, initialize it with the element shape. +// Input can be initialized or uninitialized TensorList. +// "element" can be normal tensor or TensorList. +Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* initialized_list); + +// Executes TensorListPushBack with given TensorList and element. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* result); + +// Executes TensorListPopBack with given TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, + xla::XlaOp* element_result, + bool* element_is_tensor_list); + +// Executes TensorListSetItem with given TensorList, index and element. +// Input must be an initialized TensorList. +// Only non-nested TensorList is supported. +Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp element, xla::XlaOp* result); + +// Executes TensorListGetItem with given TensorList and index. +// Input must be an initialized TensorList. +// Only non-nested TensorList is supported. +Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp* result); + +// Executes TensorListPushBack with given tensor and push index. +// "tensor" must be a normal tensor. 
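As a quick mental model for the helpers declared above, the sketch below (illustrative only; ToyTensorList and its helper functions are invented names in plain standard C++, not the XLA implementation) mirrors the (buffer, push_index) layout: dim 0 of the buffer indexes list elements, push writes at push_index and increments it, and pop decrements and then reads, which is what the DynamicUpdateSlice/DynamicSlice calls do on the XLA side.

```cpp
#include <cassert>
#include <vector>

struct ToyTensorList {
  std::vector<std::vector<float>> buffer;  // Dim 0 is the list index.
  int push_index = 0;                      // Next slot to push into.
};

// Mirrors ExecuteTensorListPushBack: write at push_index, then increment it.
void PushBack(ToyTensorList* list, const std::vector<float>& element) {
  assert(list->push_index < static_cast<int>(list->buffer.size()));
  list->buffer[list->push_index] = element;  // DynamicUpdateSlice on dim 0.
  ++list->push_index;
}

// Mirrors ExecuteTensorListPopBack: decrement push_index, then read there.
std::vector<float> PopBack(ToyTensorList* list) {
  assert(list->push_index > 0);
  --list->push_index;
  return list->buffer[list->push_index];  // DynamicSlice on dim 0.
}

int main() {
  ToyTensorList list;
  list.buffer.assign(/*leading_dim=*/4, std::vector<float>(2, 0.f));
  PushBack(&list, {1.f, 2.f});
  PushBack(&list, {3.f, 4.f});
  assert((PopBack(&list) == std::vector<float>{3.f, 4.f}));
  return 0;
}
```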
+Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, + xla::XlaOp* result); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 7c4176eb839..37c28ef4173 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -82,12 +82,7 @@ XLAJIT_MAKE_UNARY(Round, xla::RoundToEven(x)); XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); -// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. -xla::XlaOp Sigmoid(xla::XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - return half + half * xla::Tanh(half * x); -} -XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); +XLAJIT_MAKE_UNARY(Sigmoid, xla::Logistic(x)); // Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. XLAJIT_MAKE_UNARY(Sign, diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index dfa09b16081..7b4125ab76e 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -58,7 +58,7 @@ class VariableShapeOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); +REGISTER_XLA_OP(Name("VariableShape").IsMetadataOp(), VariableShapeOp); class ReadVariableOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 58637302d4a..f8d33a423dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/while_op.h" #include "absl/strings/str_split.h" +#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -26,24 +28,25 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { -const char kPropagateCompileTimeConsts[] = "_xla_propagate_compile_time_consts"; - namespace { // Builds XlaCompiler argument descriptions `args` from `ctx`. 
Status MakeXlaCompilerArgumentsFromInputs( XlaOpKernelContext* ctx, std::vector* args, - bool* has_uninitialized_vars, bool* has_tensor_arrays) { + bool* has_uninitialized_vars, bool* has_tensor_arrays, + bool* has_uninitialized_tensor_lists) { VLOG(2) << "Num inputs " << ctx->num_inputs(); args->resize(ctx->num_inputs()); *has_uninitialized_vars = false; *has_tensor_arrays = false; + *has_uninitialized_tensor_lists = false; for (int i = 0; i < ctx->num_inputs(); ++i) { VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); @@ -79,15 +82,19 @@ Status MakeXlaCompilerArgumentsFromInputs( } else { arg.kind = XlaCompiler::Argument::kParameter; - arg.type = ctx->input_type(i); - - xla::XlaBuilder* builder = ctx->builder(); - xla::XlaOp handle = ctx->Input(i); - auto shape_or_status = builder->GetShape(handle); - if (!shape_or_status.ok()) { - return shape_or_status.status(); + arg.type = type; + TF_ASSIGN_OR_RETURN(arg.shape, ctx->builder()->GetShape(ctx->Input(i))); + if (IsTensorListInput(ctx, i)) { + // arg.initialized == false means that the element_shape of the list + // was not available at the time of building the list so an empty list + // was created instead. If so, the body function of While is run once + // to infer the shape of the list before actually building the While op. + TF_RETURN_IF_ERROR( + IsTensorListInitialized(ctx->Input(i), &arg.initialized)); + if (!arg.initialized) { + *has_uninitialized_tensor_lists = true; + } } - arg.shape = shape_or_status.ValueOrDie(); } } return Status::OK(); @@ -266,9 +273,10 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector arguments; bool has_uninitialized_vars; bool has_tensor_arrays; - OP_REQUIRES_OK( - ctx, MakeXlaCompilerArgumentsFromInputs( - ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays)); + bool has_uninitialized_tensor_lists; + OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs( + ctx, &arguments, &has_uninitialized_vars, + &has_tensor_arrays, &has_uninitialized_tensor_lists)); xla::XlaBuilder* builder = ctx->builder(); XlaCompiler* compiler = ctx->compiler(); @@ -327,10 +335,13 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Hence we can use the output shapes and TensorArray gradients of each // resource as the "true" shapes. // 2) again with the "correct" resource information determined by (1). - if (has_uninitialized_vars || has_tensor_arrays) { + if (has_uninitialized_vars || has_tensor_arrays || + has_uninitialized_tensor_lists) { VLOG(2) << "Recompiling loop body: has_uninitialized_vars: " << has_uninitialized_vars - << " has_tensor_arrays: " << has_tensor_arrays; + << " has_tensor_arrays: " << has_tensor_arrays + << " has_uninitialized_tensor_lists: " + << has_uninitialized_tensor_lists; // Initializes any uninitialized resource with zero values of the // shape determined by the first compilation. for (int i = 0; i < body.resource_updates.size(); ++i) { @@ -367,6 +378,23 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { arg.tensor_array_gradients.insert(gradient.first); } } + + // Set the shape of any uninitialized TensorLists to the shape determined by + // the first compilation. Note that, unlike resources, we do not initialize + // the input list with zeros here, that is done later. 
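The two-pass handling of uninitialized TensorLists described in the comments above can be summarized by this standalone sketch (illustrative only; ToyArg and the string-valued shape are stand-ins for XlaCompiler::Argument and xla::Shape): after the first compilation, each still-uninitialized list argument adopts the shape of the matching element of the body's output tuple.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct ToyArg {
  bool is_tensor_list = false;
  bool initialized = false;
  std::string shape;  // Stand-in for xla::Shape.
};

// After the first compilation, every argument that was an uninitialized
// TensorList takes its shape from the corresponding body output.
void PatchUninitializedListShapes(const std::vector<std::string>& body_output_tuple,
                                  std::vector<ToyArg>* args) {
  for (size_t i = 0; i < args->size(); ++i) {
    ToyArg& arg = (*args)[i];
    if (arg.initialized || !arg.is_tensor_list) continue;
    arg.shape = body_output_tuple[i];  // tuple_shapes(i) in the real code.
    arg.initialized = true;
  }
}

int main() {
  std::vector<ToyArg> args(2);
  args[1].is_tensor_list = true;  // Uninitialized TensorList argument.
  PatchUninitializedListShapes({"f32[10]", "f32[5,2]"}, &args);
  assert(args[1].initialized && args[1].shape == "f32[5,2]");
  return 0;
}
```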
+ xla::Shape body_output_shape = body.xla_output_shape; + OP_REQUIRES(ctx, body_output_shape.IsTuple(), + errors::FailedPrecondition( + "xla_output_shape of while body must be a tuple.")); + for (int i = 0; i < arguments.size(); i++) { + XlaCompiler::Argument& arg = arguments[i]; + if (arg.initialized || !IsTensorListInput(ctx, i)) { + continue; + } + arg.shape = body_output_shape.tuple_shapes(i); + arg.initialized = true; + } + // Recompile the body with the "correct" resource shapes. VLOG(1) << "Recompiling body with corrected resource shapes"; body = {}; @@ -450,6 +478,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); + } else if (IsTensorListInput(ctx, input_num)) { + xla::XlaOp input = ctx->Input(input_num); + auto input_shape_or = ctx->builder()->GetShape(input); + OP_REQUIRES_OK(ctx, input_shape_or.status()); + xla::Shape input_shape = input_shape_or.ValueOrDie(); + const xla::Shape& list_shape = body_input_shape.tuple_shapes(i); + // Shape/datatype of the input list may differ from shape/datatype of the + // body/cond input if the list's shape/datatype was inferred after the + // first compilation and the body/cond was recompiled with the updated + // shape/datatype of the list. + if (input_shape != list_shape) { + OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape( + ctx->builder(), list_shape, &inputs[i])); + } else { + inputs[i] = ctx->Input(input_num); + } } else { inputs[i] = ctx->Input(input_num); } @@ -481,7 +525,11 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { int resource_index = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { - ctx->SetOutput(i, xla::GetTupleElement(while_result, i)); + if (IsTensorListInput(ctx, i)) { + ctx->SetTensorListOutput(i, xla::GetTupleElement(while_result, i)); + } else { + ctx->SetOutput(i, xla::GetTupleElement(while_result, i)); + } ++resource_index; } else { break; diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 16ec8d0e520..bae187ca3ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -21,8 +21,6 @@ limitations under the License. namespace tensorflow { -extern const char kPropagateCompileTimeConsts[]; - // This TensorFlow op provides a functional iteration primitive. 
// // The inputs and outputs of the loop body must agree on the number, types, and diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index b20adc592a0..0b5b66ae52f 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -35,10 +35,9 @@ class XlaConvOp : public XlaOpKernel { string precision_config_attr; OP_REQUIRES_OK( context, context->GetAttr("precision_config", &precision_config_attr)); - OP_REQUIRES( - context, - precision_config_.ParsePartialFromString(precision_config_attr), - errors::InvalidArgument("Error parsing convolution dimension numbers")); + OP_REQUIRES(context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing precison config.")); } void Compile(XlaOpKernelContext* context) override { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc new file mode 100644 index 00000000000..a28ecd660ab --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/lib/svd.h" + +namespace tensorflow { +namespace { + +class XlaSvdOp : public XlaOpKernel { + public: + explicit XlaSvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + string precision_config_attr; + OP_REQUIRES_OK(ctx, + ctx->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES(ctx, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing precison config.")); + if (precision_config_.operand_precision_size() == 0) { + precision_config_.mutable_operand_precision()->Add( + xla::PrecisionConfig::HIGHEST); + } + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = xla::SVD(ctx->Input(0), max_iter_, epsilon_, + precision_config_.operand_precision(0)); + ctx->SetOutput(0, result.d); + ctx->SetOutput(1, result.u); + ctx->SetOutput(2, result.v); + } + + private: + int32 max_iter_; + float epsilon_; + xla::PrecisionConfig precision_config_; +}; + +class SvdOp : public XlaOpKernel { + public: + explicit SvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("compute_uv", &compute_uv_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); + } + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + int m = 
input_shape.dim_size(input_shape.dims() - 2); + int n = input_shape.dim_size(input_shape.dims() - 1); + // This is based on heuristics that approx log(n) sweep updates are needed. + // Note: the heuristics provides no theoretical guarantee, max_iter=100 and + // epsilon should be used to determine exit condition. + int max_iter = 2 * tensorflow::Log2Ceiling(std::max(m, n)); + auto result = xla::SVD(ctx->Input(0), max_iter, 1e-6); + ctx->SetOutput(0, result.d); + if (compute_uv_) { + int p = std::min(m, n); + if (!full_matrices_) { + if (p < m) { + result.u = xla::SliceInMinorDims(result.u, {0, 0}, {m, p}); + } + if (p < n) { + result.v = xla::SliceInMinorDims(result.v, {0, 0}, {n, p}); + } + } + ctx->SetOutput(1, result.u); + ctx->SetOutput(2, result.v); + } else { + ctx->SetOutput(1, xla::ScalarLike(ctx->Input(0), 0.0)); + ctx->SetOutput(2, xla::ScalarLike(ctx->Input(0), 0.0)); + } + } + + private: + bool compute_uv_; + bool full_matrices_; +}; + +REGISTER_XLA_OP(Name("XlaSvd").TypeConstraint("T", kFloatTypes), XlaSvdOp); +REGISTER_XLA_OP(Name("Svd").TypeConstraint("T", kFloatTypes), SvdOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 3d7b0bc959f..f9ce50be6e3 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -82,3 +82,15 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "data_format", + srcs = ["data_format.cc"], + hdrs = ["data_format.h"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc new file mode 100644 index 00000000000..0253bcdc5f9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -0,0 +1,87 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/data_format.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +xla::StatusOr Contract(xla::XlaOp input, int64 dim) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + + if (input_shape.dimensions().back() != 4) { + return errors::InvalidArgument("Expected last dimension to be 4; got ", + input_shape.dimensions().back()); + } + + // Transpose the input so C is directly followed by VECT_C. + std::vector permutation; + for (int64 i = 0; i != input_shape.rank() - 1; ++i) { + permutation.push_back(i); + if (i == dim) { + permutation.push_back(input_shape.rank() - 1); + } + } + + // Now merge the adjacent dimensions with a reshape. 
+ std::vector contracted_shape(input_shape.dimensions().begin(), + input_shape.dimensions().end() - 1); + contracted_shape[dim] *= 4; + + return xla::Reshape(xla::Transpose(input, permutation), contracted_shape); +} + +xla::StatusOr Expand(xla::XlaOp input, int64 dim) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + + if (input_shape.dimensions(dim) % 4 != 0) { + return errors::InvalidArgument( + "Expected vectorized dimension to be evenly divisible by 4; got ", + input_shape.dimensions(dim)); + } + + // Split the `dim` into two dimensions with a reshape. The size of the new + // dimension is always 4. + std::vector expanded_shape(input_shape.dimensions()); + expanded_shape[dim] /= 4; + expanded_shape.insert(expanded_shape.begin() + dim, 4); + + // Move the newly created dimension to the end with a transpose. + std::vector permutation; + for (int64 i = 0; i != expanded_shape.size(); ++i) { + permutation.push_back(i); + if (i == dim) { + ++i; + } + } + permutation.push_back(dim + 1); + + return xla::Transpose(xla::Reshape(input, expanded_shape), permutation); +} + +} // namespace + +xla::StatusOr NCHW_VECT_CToNCHW(xla::XlaOp input) { + return Contract(input, 1); +} + +xla::StatusOr NCHWToNCHW_VECT_C(xla::XlaOp input) { + return Expand(input, 1); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/data_format.h b/tensorflow/compiler/tf2xla/lib/data_format.h new file mode 100644 index 00000000000..839723b0ea8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/data_format.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +// Reformat from NCHW_VECT_C to NCHW. +// +// Prerequisites: the last dimension of the input must be of size 4. +xla::StatusOr NCHW_VECT_CToNCHW(xla::XlaOp input); + +// Reformat from NCHW to NCHW_VECT_C. +// +// Prerequisites: the vectorized dimension `C` must be a multiple of 4. 
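To make the layout arithmetic concrete, here is a standalone sketch (illustrative only; it operates on plain integer shape vectors rather than xla::Shape) of the permutation and reshape performed when contracting an NCHW_VECT_C shape [2, 16, 28, 28, 4] back to NCHW; the expansion goes the other way.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // NCHW_VECT_C: [N, C/4, H, W, 4], with the vectorized channel dim last.
  const std::vector<int64_t> vect_c = {2, 16, 28, 28, 4};
  const int64_t dim = 1;  // The dimension the trailing 4 belongs to (C).

  // Contract: move the trailing 4 so it directly follows `dim` ...
  std::vector<int64_t> permutation;
  for (int64_t i = 0; i + 1 < static_cast<int64_t>(vect_c.size()); ++i) {
    permutation.push_back(i);
    if (i == dim) permutation.push_back(vect_c.size() - 1);
  }
  assert((permutation == std::vector<int64_t>{0, 1, 4, 2, 3}));

  // ... then merge the two now-adjacent dims with a reshape: [N, C, H, W].
  std::vector<int64_t> nchw(vect_c.begin(), vect_c.end() - 1);
  nchw[dim] *= 4;
  assert((nchw == std::vector<int64_t>{2, 64, 28, 28}));
  return 0;
}
```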
+xla::StatusOr NCHWToNCHW_VECT_C(xla::XlaOp input); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 06eda416118..d348d2b41dd 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -111,7 +111,7 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; - case xla::OPAQUE: + case xla::OPAQUE_TYPE: LOG(FATAL) << "opaque element type is not integral"; default: LOG(FATAL) << "unhandled element type " << type; diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 7140b6a1227..4f1f3d7c326 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -1,5 +1,5 @@ package( - default_visibility = ["//tensorflow/compiler/tf2xla:internal"], + default_visibility = ["//tensorflow:internal"], ) licenses(["notice"]) # Apache 2.0 @@ -17,6 +17,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index ccd58071d35..c9bf15ac5f1 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -91,6 +93,40 @@ v: The column v[..., :, i] is the normalized eigenvector corresponding to the eigenvalue w[..., i]. )doc"); +REGISTER_OP("XlaSvd") + .Input("a: T") + .Attr("max_iter: int") + .Attr("epsilon: float") + .Attr("precision_config: string") + .Output("s: T") + .Output("u: T") + .Output("v: T") + .SetShapeFn(shape_inference::UnknownShape) + .Attr("T: numbertype") + .Doc(R"doc( +Computes the eigen decomposition of a batch of self-adjoint matrices +(Note: Only real inputs are supported). + +Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in +tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]). + +a: the input tensor. + +max_iter: maximum number of sweep update, i.e., the whole lower triangular + part or upper triangular part based on parameter lower. Heuristically, it has + been argued that approximatly log(min (M, N)) sweeps are needed in practice + (Ref: Golub & van Loan "Matrix Computation"). + +epsilon: the tolerance ratio. + +precision_config: a serialized xla::PrecisionConfig proto. + +s: Singular values. The values are sorted in reverse order of magnitude, so + s[..., 0] is the largest value, s[..., 1] is the second largest, etc. +u: Left singular vectors. +v: Right singular vectors. +)doc"); + REGISTER_OP("XlaConv") .Input("lhs: T") .Input("rhs: T") @@ -472,5 +508,94 @@ transpose_output: Boolean to determine if output is transposed. transpose_output is faster when input is large and rank of input is higher than 1. 
)doc"); +REGISTER_OP("XlaEinsum") + .Input("a: T") + .Input("b: T") + .Output("product: T") + .Attr("equation: string") + .Attr("T: {bfloat16, float}") + .SetShapeFn([](shape_inference::InferenceContext* context) { + shape_inference::ShapeHandle input_a = context->input(0); + shape_inference::ShapeHandle input_b = context->input(1); + + int64 rank_a, rank_b; + if (context->RankKnown(input_a)) { + rank_a = context->Rank(input_a); + } else { + return errors::InvalidArgument("input 0's rank is unknown."); + } + if (context->RankKnown(input_b)) { + rank_b = context->Rank(input_b); + } else { + return errors::InvalidArgument("input 1's rank is unknown."); + } + string equation; + TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation)); + + std::map left_map; + std::map right_map; + std::vector dims; + + std::vector equation_split = absl::StrSplit(equation, "->"); + + if (equation_split.size() != 2) { + return errors::InvalidArgument("Expected one \"->\" in equation. Got: ", + equation); + } + + std::vector lhs_rhs_split = + absl::StrSplit(equation_split[0], ','); + if (lhs_rhs_split.size() != 2) { + return errors::InvalidArgument("Expected one \",\" in equation. Got: ", + equation); + } + + if (rank_a != lhs_rhs_split[0].size()) { + return errors::InvalidArgument(absl::StrCat( + "Expected equation[0] with size: ", rank_a, " Got '", + lhs_rhs_split[0], "'", " with size: ", lhs_rhs_split[0].size())); + } + + if (rank_b != lhs_rhs_split[1].size()) { + return errors::InvalidArgument(absl::StrCat( + "Expected equation[1] with size: ", rank_b, " Got '", + lhs_rhs_split[1], "'", " with size: ", lhs_rhs_split[1].size())); + } + + for (int i = 0; i < lhs_rhs_split[0].size(); ++i) { + left_map[lhs_rhs_split[0][i]] = context->Dim(input_a, i); + } + for (int i = 0; i < lhs_rhs_split[1].size(); ++i) { + right_map[lhs_rhs_split[1][i]] = context->Dim(input_b, i); + } + + for (const char& c : equation_split[1]) { + if (left_map.count(c)) { + dims.push_back(left_map[c]); + } else if (right_map.count(c)) { + dims.push_back(right_map[c]); + } else { + return errors::InvalidArgument("Invalid equation: ", equation); + } + } + + context->set_output(0, context->MakeShape(dims)); + return Status::OK(); + }) + .Doc(R"doc( +An op which supports basic einsum op with 2 inputs and 1 output. + +This op has better TPU performnce since it doesn't have explicitly reshape and +transpose operations as tf.einsum does. 
+)doc"); + +REGISTER_OP("XlaReplicaId") + .Output("id: int32") + .SetShapeFn([](shape_inference::InferenceContext* context) { + context->set_output(0, context->MakeShape({})); + return Status::OK(); + }) + .Doc("Replica ID."); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 9abdb04d773..c6f57b386eb 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -1,13 +1,8 @@ licenses(["notice"]) # Apache 2.0 -package_group( - name = "friends", - includes = ["//tensorflow:internal"], -) - package( default_visibility = [ - ":friends", + "//visibility:public", ], ) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index de4710d03a3..bedb9a6eb0e 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -295,8 +295,16 @@ def self_adjoint_eig(a, lower, max_iter, epsilon): return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon) +def svd(a, max_iter, epsilon, precision_config=None): + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto) + + dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice +einsum = gen_xla_ops.xla_einsum # TODO(phawkins): generalize tf.pad to support interior padding, and then remove # the XLA-specific pad operator. @@ -364,6 +372,9 @@ def reduce_window(operand, name=name) +replica_id = gen_xla_ops.xla_replica_id + + def reshape(x, new_sizes, dimensions=None, name=None): if dimensions is not None: x = array_ops.transpose(x, dimensions) diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc new file mode 100644 index 00000000000..bea9a4c4123 --- /dev/null +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -0,0 +1,544 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" + +#include + +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { + +namespace { + +// Given original input types and argument index mapping, return the new input +// types. +std::vector ShuffleInputDataTypeAttribute( + const std::vector& in_types, + const std::vector& index_mapping) { + std::vector result(index_mapping.size()); + for (int i = 0; i < in_types.size(); i++) { + result[index_mapping.at(i)] = in_types[i]; + } + return result; +} + +// Given original input types, check if we need to rewrite the function (by +// checking if all DT_RESOURCE inputs are in the end). If the function needs to +// be rewritten, `resource_input_count` will be set to number of DT_RESOURCE +// inputs, and `index_mapping` will hold a mapping for original input index to +// rearranged input index. +Status InputTypesNeedsRearrange(const std::vector& in_types, + bool* need_rewrite, int* resource_input_count, + std::vector* index_mapping) { + int first_resource_index = -1; + for (int i = 0; i < in_types.size(); i++) { + DataType type = in_types[i]; + if (type == DT_RESOURCE) { + first_resource_index = i; + break; + } + } + if (first_resource_index == -1) { + // No resource input. No need to rewrite. + *need_rewrite = false; + return Status::OK(); + } + + *need_rewrite = false; + for (int i = first_resource_index + 1; i < in_types.size(); i++) { + if (in_types[i] != DT_RESOURCE) { + *need_rewrite = true; + break; + } + } + if (!*need_rewrite) { + return Status::OK(); + } + + *resource_input_count = 0; + for (int i = 0; i < in_types.size(); i++) { + DataType type = in_types[i]; + if (type == DT_RESOURCE) { + ++(*resource_input_count); + } + } + int non_resource_index = 0, + resource_index = in_types.size() - *resource_input_count; + index_mapping->resize(in_types.size()); + for (int i = 0; i < in_types.size(); i++) { + if (in_types[i] != DT_RESOURCE) { + (*index_mapping)[i] = non_resource_index; + non_resource_index++; + } else { + (*index_mapping)[i] = resource_index; + resource_index++; + } + } + + return Status::OK(); +} + +// Given mapping between original input index and rearranged input index, +// reorder input edges for the node. 
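The old-index-to-new-index convention used by InputTypesNeedsRearrange() and ShuffleInputDataTypeAttribute() is easier to see on a concrete example. The sketch below (illustrative only; it uses one bool per input instead of DataType) computes the same mapping for the type list [float, resource, int32, resource].

```cpp
#include <cassert>
#include <vector>

// Returns index_mapping where non-resource inputs keep their relative order at
// the front and resource inputs are packed at the end.
std::vector<int> ResourcesLastIndexMapping(const std::vector<bool>& is_resource) {
  int resource_count = 0;
  for (bool r : is_resource) resource_count += r ? 1 : 0;
  int non_resource_index = 0;
  int resource_index = static_cast<int>(is_resource.size()) - resource_count;
  std::vector<int> index_mapping(is_resource.size());
  for (size_t i = 0; i < is_resource.size(); ++i) {
    index_mapping[i] = is_resource[i] ? resource_index++ : non_resource_index++;
  }
  return index_mapping;
}

int main() {
  // [float, resource, int32, resource] -> [float, int32, resource, resource],
  // i.e. old input i moves to new position index_mapping[i].
  const std::vector<bool> is_resource = {false, true, false, true};
  assert((ResourcesLastIndexMapping(is_resource) ==
          std::vector<int>{0, 2, 1, 3}));
  return 0;
}
```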
+Status ReorderInputEdges(Graph* g, Node* n, + const std::vector& index_mapping) { + std::vector input_edges; + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + input_edges.push_back(e); + } + for (const Edge* e : input_edges) { + Node* src = e->src(); + int src_output = e->src_output(); + int dst_input = e->dst_input(); + int new_dst_input = index_mapping.at(dst_input); + g->RemoveEdge(e); + g->AddEdge(src, src_output, n, new_dst_input)->DebugString(); + } + return Status::OK(); +} + +// For While node, given mapping between original input index and rearranged +// input index, reorder output edges for the node. DT_RESOURCE outputs are +// removed from the node and we will use the node's corresponding input for the +// edge. +Status ReorderOutputEdges(Graph* g, Node* n, int input_count, + int resource_input_count, + const std::vector& index_mapping) { + std::vector output_edges; + for (const Edge* e : n->out_edges()) { + if (e->IsControlEdge()) { + continue; + } + output_edges.push_back(e); + } + for (const Edge* e : output_edges) { + int src_output = e->src_output(); + int new_src_output = index_mapping.at(src_output); + Node* dst = e->dst(); + int dst_input = e->dst_input(); + g->RemoveEdge(e); + + if (new_src_output < input_count - resource_input_count) { + g->AddEdge(n, new_src_output, dst, dst_input); + } else { + const Edge* input_edge; + TF_RETURN_IF_ERROR(n->input_edge(new_src_output, &input_edge)); + g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input); + } + } + return Status::OK(); +} + +// Given mapping between original input index and rearranged input index, change +// "index" attribute for _Arg nodes. +void RearrangeArgNodes( + const gtl::InlinedVector* arg_nodes, // non-absl ok + const std::vector& index_mapping) { + for (int i = 0; i < arg_nodes->size(); i++) { + Node* n = (*arg_nodes)[i]; + int new_index = index_mapping.at(i); + n->ClearAttr("index"); + n->AddAttr("index", new_index); + } +} + +// Given all _Retval nodes in the function, return if we need to rewrite the +// function (by checking if we have DT_RESOURCE return values). If we need to +// rewrite the function, `retval_index_mapping` will hold the mapping from +// original _Retval to rearranged _Retval, and `resource_retval_to_arg` will +// hold mapping from DT_RESOURCE _Retval index to its input _Arg index. Here we +// assume that all DT_RESOURCE _Retval nodes come from _Arg nodes directly. +Status CalculateRetvalRearrange( + const gtl::InlinedVector& ret_nodes, // non-absl ok + std::map* retval_index_mapping, + std::map* resource_retval_to_arg) { + for (int i = 0; i < ret_nodes.size(); i++) { + Node* n = ret_nodes[i]; + DataType t; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &t)); + if (t != DT_RESOURCE) { + int new_retval_index = retval_index_mapping->size(); + retval_index_mapping->insert(std::make_pair(i, new_retval_index)); + continue; + } + + const Edge* e; + TF_RETURN_IF_ERROR(n->input_edge(0, &e)); + if (!e->src()->IsArg()) { + return errors::Unimplemented( + "Resource _Retval node's input does not come from _Arg " + "directly: ", + e->DebugString()); + } + Node* arg = e->src(); + int src_index; + TF_RETURN_IF_ERROR(GetNodeAttr(arg->def(), "index", &src_index)); + resource_retval_to_arg->insert(std::make_pair(i, src_index)); + } + return Status::OK(); +} + +// Given original output types and return value index mapping, return the new +// output types. Notice that DT_RESOURCE will be removed. 
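Return values are handled analogously, except that resource retvals disappear from the outputs and are instead forwarded from the corresponding input. Below is a minimal sketch of the bookkeeping done by CalculateRetvalRearrange() (illustrative only; ToyRetval is an invented stand-in for a _Retval whose resource input is assumed to come straight from an _Arg).

```cpp
#include <cassert>
#include <map>
#include <vector>

struct ToyRetval {
  bool is_resource;
  int source_arg;  // For resource retvals: the _Arg index they forward.
};

// Non-resource retvals are renumbered compactly; resource retvals are dropped
// and remembered as "use this input instead".
void RearrangeRetvals(const std::vector<ToyRetval>& retvals,
                      std::map<int, int>* retval_index_mapping,
                      std::map<int, int>* resource_retval_to_arg) {
  for (size_t i = 0; i < retvals.size(); ++i) {
    if (retvals[i].is_resource) {
      (*resource_retval_to_arg)[i] = retvals[i].source_arg;
    } else {
      int new_index = static_cast<int>(retval_index_mapping->size());
      (*retval_index_mapping)[i] = new_index;
    }
  }
}

int main() {
  // Outputs: [float, resource (forwarded from arg 3), int32].
  std::map<int, int> retval_index_mapping, resource_retval_to_arg;
  RearrangeRetvals({{false, -1}, {true, 3}, {false, -1}},
                   &retval_index_mapping, &resource_retval_to_arg);
  assert((retval_index_mapping == std::map<int, int>{{0, 0}, {2, 1}}));
  assert((resource_retval_to_arg == std::map<int, int>{{1, 3}}));
  return 0;
}
```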
+std::vector ShuffleOutputDataTypeAttribute( + const std::vector& out_types, + const std::map& index_mapping) { + std::vector result(index_mapping.size()); + for (int i = 0; i < out_types.size(); i++) { + auto iter = index_mapping.find(i); + if (iter != index_mapping.end()) { + result[iter->second] = out_types[i]; + } + } + return result; +} + +// For StatefulPartitionedCall node, given mapping between original input index +// and rearranged input index, reorder output edges for the node. DT_RESOURCE +// outputs are removed from the node and we will use the node's corresponding +// input for the edge. +Status RearrangeOutputEdges(Node* n, Graph* g, + const std::map& retval_index_mapping, + const std::map& resource_retval_to_arg) { + std::vector out_edges; + for (const Edge* e : n->out_edges()) { + if (!e->IsControlEdge()) { + out_edges.push_back(e); + } + } + for (const Edge* e : out_edges) { + Node* dst = e->dst(); + int dst_input = e->dst_input(); + int src_output = e->src_output(); + auto iter = retval_index_mapping.find(src_output); + if (iter == retval_index_mapping.end()) { + TF_RET_CHECK(resource_retval_to_arg.find(src_output) != + resource_retval_to_arg.end()); + g->RemoveEdge(e); + const Edge* input_edge; + TF_RETURN_IF_ERROR( + n->input_edge(resource_retval_to_arg.at(src_output), &input_edge)); + g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input); + } else { + g->RemoveEdge(e); + g->AddEdge(n, iter->second, dst, dst_input); + } + } + return Status::OK(); +} + +// Given mapping between original output index and rearranged output index, +// change "index" attribute for _Retval nodes. Notice that DT_RESOURCE _Retval +// nodes will be removed. +void RearrangeRetvalNodes( + const gtl::InlinedVector& ret_nodes, // non-absl ok + Graph* g, const std::map& retval_index_mapping) { + for (int i = 0; i < ret_nodes.size(); i++) { + Node* n = ret_nodes[i]; + auto iter = retval_index_mapping.find(i); + if (iter == retval_index_mapping.end()) { + g->RemoveNode(n); + } else { + n->ClearAttr("index"); + n->AddAttr("index", iter->second); + } + } +} + +Status MaybeRewriteWhileNode( + std::function + get_function_body_fn, + Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) { + // Check if this While node needs rewrite. + std::vector types; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &types)); + bool input_need_rearrange; + int resource_input_count; + std::vector index_mapping; + TF_RETURN_IF_ERROR(InputTypesNeedsRearrange( + types, &input_need_rearrange, &resource_input_count, &index_mapping)); + if (!input_need_rearrange) { + *node_rewritten = false; + return Status::OK(); + } + + *node_rewritten = true; + + // Modify "T" attribute for this While node. + std::vector new_types = + ShuffleInputDataTypeAttribute(types, index_mapping); + n->ClearAttr("T"); + n->AddAttr("T", new_types); + + // Reorder input and output edges. + TF_RETURN_IF_ERROR(ReorderInputEdges(g, n, index_mapping)); + TF_RETURN_IF_ERROR(ReorderOutputEdges(g, n, types.size(), + resource_input_count, index_mapping)); + + // Modify cond and body functions. 
+ for (auto const& attr_name : std::vector{"cond", "body"}) { + NameAttrList attr_value; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value)); + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(get_function_body_fn(attr_value, &fbody)); + + // Check that resource _Arg nodes for While node are always returned with + // the same index, and we don't have cases like this: + // tf.while_loop( + // cond, + // lambda resource_var1, resource_var2: [resource_var2, resource_var1], + // [resource_var1, resource_var2]) + if (attr_name == "body") { + for (int i = 0; i < fbody->ret_nodes.size(); i++) { + Node* n = fbody->ret_nodes[i]; + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype)); + if (dtype != DT_RESOURCE) { + continue; + } + + Node* input_node; + TF_RETURN_IF_ERROR(n->input_node(0, &input_node)); + while (input_node->IsIdentity()) { + TF_RETURN_IF_ERROR(input_node->input_node(0, &input_node)); + } + if (input_node->IsArg()) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index)); + if (index != i) { + return errors::Unimplemented("While node ", n->DebugString(), + " has resource _Retval[", i, + "] coming from _Arg[", index, "]"); + } + } else { + return errors::Unimplemented("Encountered node ", + input_node->DebugString(), + " while tracing _Arg node for _Retval[", + i, "] of while node ", n->DebugString()); + } + } + } + + RearrangeArgNodes(&fbody->arg_nodes, index_mapping); + if (attr_name == "body") { + for (int i = 0; i < fbody->ret_nodes.size(); i++) { + Node* n = fbody->ret_nodes[i]; + int new_index = index_mapping.at(i); + if (new_index < types.size() - resource_input_count) { + n->ClearAttr("index"); + n->AddAttr("index", new_index); + } else { + fbody->graph->RemoveNode(n); + } + } + } + + // Save the new FunctionDef. + FunctionDef new_fdef; + string new_name = + fld->UniqueFunctionName(absl::StrCat(attr_value.name(), "_rearrange_")); + TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef)); + + // Change node to use rewritten function. + attr_value.set_name(new_name); + n->ClearAttr(attr_name); + n->AddAttr(attr_name, attr_value); + } + return Status::OK(); +} + +Status MaybeRewriteIfNode( + std::function + get_function_body_fn, + Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) { + // This node needs rewrite when either of these is true: + // 1) Tin has DT_RESOURCE which requires rearrange; + // 2) Tout has DT_RESOURCE. + std::vector in_types; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &in_types)); + bool input_need_rearrange; + int resource_input_count; + std::vector index_mapping; + TF_RETURN_IF_ERROR(InputTypesNeedsRearrange( + in_types, &input_need_rearrange, &resource_input_count, &index_mapping)); + std::vector out_types; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tout", &out_types)); + bool has_resource_output = std::find(out_types.begin(), out_types.end(), + DT_RESOURCE) != out_types.end(); + if (!input_need_rearrange && !has_resource_output) { + *node_rewritten = false; + return Status::OK(); + } + + *node_rewritten = true; + + if (input_need_rearrange) { + // Reorder input edges. 
+ std::vector input_edges; + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge() || e->dst_input() == 0) { + continue; + } + input_edges.push_back(e); + } + for (const Edge* e : input_edges) { + Node* src = e->src(); + int src_output = e->src_output(); + int dst_input = e->dst_input(); + int new_dst_input = index_mapping.at(dst_input - 1) + 1; + g->RemoveEdge(e); + g->AddEdge(src, src_output, n, new_dst_input)->DebugString(); + } + + // Change Tin attribute. + std::vector new_in_types = + ShuffleInputDataTypeAttribute(in_types, index_mapping); + n->ClearAttr("Tin"); + n->AddAttr("Tin", new_in_types); + } + + std::map resource_retval_to_arg, retval_index_mapping; + for (auto const& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList f; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f)); + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(get_function_body_fn(f, &fbody)); + + if (input_need_rearrange) { + // Change _Arg node index. + RearrangeArgNodes(&fbody->arg_nodes, index_mapping); + } + + if (has_resource_output) { + // Resource _Retval must come from resource _Arg directly, or we do + // not support it. + TF_RETURN_IF_ERROR(CalculateRetvalRearrange( + fbody->ret_nodes, &retval_index_mapping, &resource_retval_to_arg)); + + // Change index for _Retval nodes. + RearrangeRetvalNodes(fbody->ret_nodes, fbody->graph, + retval_index_mapping); + } + + // Save the new FunctionDef. + FunctionDef new_fdef; + string new_name = + fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_")); + TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef)); + + // Change node to use rewritten function. + f.set_name(new_name); + n->ClearAttr(attr_name); + n->AddAttr(attr_name, f); + } + + if (has_resource_output) { + // Rearrange output edges. + std::vector out_edges; + for (const Edge* e : n->out_edges()) { + if (!e->IsControlEdge()) { + out_edges.push_back(e); + } + } + for (const Edge* e : out_edges) { + Node* dst = e->dst(); + int dst_input = e->dst_input(); + int src_output = e->src_output(); + auto iter = retval_index_mapping.find(src_output); + if (iter == retval_index_mapping.end()) { + TF_RET_CHECK(resource_retval_to_arg.find(src_output) != + resource_retval_to_arg.end()); + g->RemoveEdge(e); + const Edge* input_edge; + TF_RETURN_IF_ERROR(n->input_edge( + resource_retval_to_arg.at(src_output) + 1, &input_edge)); + g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input); + } else { + g->RemoveEdge(e); + g->AddEdge(n, iter->second, dst, dst_input); + } + } + + // Change Tout attribute for the node. + std::vector new_out_types = + ShuffleOutputDataTypeAttribute(out_types, retval_index_mapping); + n->ClearAttr("Tout"); + n->AddAttr("Tout", new_out_types); + } + return Status::OK(); +} + +} // namespace + +Status RearrangeFunctionArguments( + std::function + get_function_body_fn, + Graph* g, FunctionLibraryDefinition* fld) { + // Inline StatefulPartitionedCall nodes. 
+ std::vector call_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "StatefulPartitionedCall") { + call_nodes.push_back(n); + } + } + for (Node* n : call_nodes) { + NameAttrList func_name_attrs; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func_name_attrs)); + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(get_function_body_fn(func_name_attrs, &fbody)); + InlineFunctionBodyOptions opts; + Status s = InlineFunctionBody(*fld, g, n, fbody, opts); + // Inlining might fail because the function is marked with attribute + // _noinline. + s.IgnoreError(); + FixupSourceAndSinkEdges(g); + } + + // Rewrite If/While nodes. + for (Node* n : g->nodes()) { + if (n->type_string() == "While") { + bool node_rewritten; + TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld, + &node_rewritten)); + } else if (n->type_string() == "If") { + bool node_rewritten; + TF_RETURN_IF_ERROR( + MaybeRewriteIfNode(get_function_body_fn, g, n, fld, &node_rewritten)); + } + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.h b/tensorflow/compiler/tf2xla/rearrange_function_argument.h new file mode 100644 index 00000000000..c553d8b6e41 --- /dev/null +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// For the given graph `g`: +// 1. Rewrite If/While node functions to rearrange arguments and return values, +// so that all resource arguments/return values are placed in the end (as +// required by XlaCompiler), +// 2. Inline StatefulPartitionedCall nodes so we do not need to rearrange +// arguments and return values. +// `get_function_body_fn` is used to instantiate FunctionDef. +// `fld` is used to store rewritten functions. 
+Status RearrangeFunctionArguments( + std::function + get_function_body_fn, + Graph* g, FunctionLibraryDefinition* fld); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 29ebf46e4bf..1243e31a047 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -83,6 +83,8 @@ CreateResourceOpInfoMap() { add("ResourceScatterUpdate" , kReadWrite, kVariable); add("ResourceStridedSliceAssign" , kReadWrite, kVariable); add("StatefulStandardNormalV2" , kReadWrite, kVariable); + add("StatefulTruncatedNormal" , kReadWrite, kVariable); + add("StatefulUniform" , kReadWrite, kVariable); add("StatefulUniformFullInt" , kReadWrite, kVariable); add("StatefulUniformInt" , kReadWrite, kVariable); add("VarIsInitializedOp" , kRead, kVariable); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 6db027e2acd..3e4188f3c6d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -45,6 +44,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -164,12 +164,10 @@ Status RewriteAndPruneGraph( std::unordered_set retval_nodes; TF_RETURN_IF_ERROR( AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); - VLOG(2) << "Post rewrite: " - << dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph); - PruneForReverseReachability(graph, retval_nodes); + VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph); + PruneForReverseReachability(graph, std::move(retval_nodes)); FixupSourceAndSinkEdges(graph); - VLOG(2) << "Post prune: " - << dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph); + VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph); // Sanity-check, to make sure the feeds and fetches still exist post-pruning. std::set missing_feeds, missing_fetches; for (const tf2xla::Feed& feed : config.feed()) { @@ -280,12 +278,14 @@ Status ConvertGraphToXla(std::unique_ptr graph, arg.initialized = true; xla_args.push_back(std::move(arg)); - // We want to alias the input and output of the variable, so the updates are - // carried out in-place. - xla_aliases.push_back({/*output_index=*/{output_num}, - /*param_number=*/input_num, /*param_index=*/{}}); + if (!variable.readonly()) { + // We want to alias the input and output of the variable, so the updates + // are carried out in-place. + xla_aliases.push_back({/*output_index=*/{output_num}, + /*param_number=*/input_num, /*param_index=*/{}}); + ++output_num; + } ++input_num; - ++output_num; } // Compile the graph into an XLA computation. 
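The aliasing change in the hunk above only aliases variables that the computation may write; read-only variables neither produce an updated-value output nor consume an output slot. A minimal standalone sketch of that bookkeeping (illustrative only; VariableAliases and its parameter conventions are invented for this example):

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Returns (param_number, output_index) pairs, assuming variable parameters
// start at `first_variable_param` and updated variables are the first outputs
// after `num_fetch_outputs` fetch results.
std::vector<std::pair<int, int>> VariableAliases(
    const std::vector<bool>& readonly, int first_variable_param,
    int num_fetch_outputs) {
  std::vector<std::pair<int, int>> aliases;
  int input_num = first_variable_param;
  int output_num = num_fetch_outputs;
  for (bool ro : readonly) {
    // Only writable variables get an in-place input/output alias.
    if (!ro) aliases.push_back({input_num, output_num++});
    ++input_num;
  }
  return aliases;
}

int main() {
  // Three variables after two feeds; the middle one is readonly. With one
  // fetch output, only variables 0 and 2 are updated in place.
  const auto aliases = VariableAliases({false, true, false},
                                       /*first_variable_param=*/2,
                                       /*num_fetch_outputs=*/1);
  assert((aliases == std::vector<std::pair<int, int>>{{2, 1}, {4, 2}}));
  return 0;
}
```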
@@ -295,6 +295,8 @@ Status ConvertGraphToXla(std::unique_ptr graph, compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; + compiler_options.custom_fake_quant_op_calls = + config.conversion_options().custom_fake_quant_op_calls(); XlaCompiler compiler(compiler_options); XlaCompiler::CompilationResult result; @@ -326,6 +328,24 @@ Status ConvertGraphToXla(std::unique_ptr graph, " constant results. The configuration of " "the output args (i.e. fetch ids) is probably wrong."); } + { + // Verify that the readonly bits on variables are set correctly by the user. + std::vector updated_inputs(xla_args.size()); + for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) { + updated_inputs[update.input_index] = true; + } + int64 input_index = xla_args.size() - config.variable_size(); + for (const tf2xla::Variable& variable : config.variable()) { + if (variable.readonly() == updated_inputs[input_index]) { + return errors::InvalidArgument( + "Variable \"", variable.node_name(), "\" is marked as ", + variable.readonly() ? "" : "not ", "readonly, but is ", + updated_inputs[input_index] ? "" : "not ", + "modified by the computation."); + } + ++input_index; + } + } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.proto b/tensorflow/compiler/tf2xla/tf2xla.proto index 5627af7452b..3093a0b1d8d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.proto +++ b/tensorflow/compiler/tf2xla/tf2xla.proto @@ -1,14 +1,15 @@ syntax = "proto3"; package tensorflow.tf2xla; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + option cc_enable_arenas = true; option java_outer_classname = "Tf2XlaProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.tf2xla"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; - // TensorId identifies a tensor in a TensorFlow graph, by specifying the output // index of a particular node in the graph. If the output of the named node // feeds into other node(s), this corresponds to one or more edges. Otherwise @@ -16,7 +17,7 @@ import "tensorflow/core/framework/types.proto"; message TensorId { string node_name = 1; int64 output_index = 2; -}; +} // Feed represents a single feed tensor in the graph, which corresponds to an // input argument for the generated computation. @@ -30,14 +31,18 @@ message Feed { // not linked into the binary, then the type cannot be inferred from the node; // in this case, the type should be set here. DataType type = 4; -}; +} // Fetch represents a single fetch tensor in the graph, which corresponds to an // output argument for the generated computation. message Fetch { TensorId id = 1; string name = 2; // Optional name for generated code. -}; + + // Optional shape and data type. If specified, may be used for validation. + TensorShapeProto shape = 3; + DataType type = 4; +} // Variable represents a resource variable with the given name, shape and type. message Variable { @@ -46,6 +51,18 @@ message Variable { 2; // Optional name for generated code. If empty, node_name will be used. TensorShapeProto shape = 3; DataType type = 4; + + // Flag for variables that are never assigned. Assigments to a read-only + // variable or unassigned variables that are not read-only are invalid. + bool readonly = 5; +} + +// Options used during the conversion and compilation process. 
+message ConversionOptions { + // When true tf.fake_quant_* ops will be emitted as custom calls to a + // 'fake_quant_with_min_max_vars' function accepting the input, min, max, + // num_bits, and narrow_range values as runtime arguments. + bool custom_fake_quant_op_calls = 1; } // Config represents configuration information for tf2xla conversion. @@ -58,4 +75,6 @@ message Config { repeated Fetch fetch = 2; // Each variable is a named input and output of the generated computation. repeated Variable variable = 3; -}; + // Optional conversion options. + ConversionOptions conversion_options = 4; +} diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 93a5d9d7bab..8cae193fa30 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -114,7 +114,7 @@ Status ReplaceArgUsageWithConstNode( // Collect all _Arg nodes. std::unordered_map arg_nodes; for (Node* n : g->op_nodes()) { - if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + if (n->IsArg()) { int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); arg_nodes[index] = n; @@ -184,14 +184,9 @@ Status PropagateConstIntoFuncAttr( return errors::Internal("Cannot find function ", func_attr.name(), " for node ", n->name()); } - FunctionBody* fbody; + std::unique_ptr fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fdef, AttrSlice(&func_attr.attr()), lookup_fld, - [lookup_fld](const string& op, const OpDef** sig) { - return lookup_fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); + *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody)); // Rewrite _Arg usages with Const node. Graph* func_graph = fbody->graph; @@ -778,4 +773,157 @@ Status PropagateConstIntoFunctionalNodes( return Status::OK(); } +Status PruneUnreachableFunctionsFromGraph(const Graph& g, + FunctionLibraryDefinition* fld) { + GraphDef graph_def; + g.ToGraphDef(&graph_def); + FunctionLibraryDefinition reachable_functions = + fld->ReachableDefinitions(graph_def); + for (const string& func_name : fld->ListFunctionNames()) { + if (!reachable_functions.Find(func_name)) { + TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name)); + } + } + return Status::OK(); +} + +Status RewriteTensorListWithConstElement(Graph* g, + FunctionLibraryDefinition* fld) { + for (Node* n : g->nodes()) { + if (n->type_string() != "EmptyTensorList") { + continue; + } + + // Find the forward While op. + std::vector fwd_while_edges; + for (const Edge* e : n->out_edges()) { + if (!e->IsControlEdge() && e->dst()->type_string() == "While") { + fwd_while_edges.push_back(e); + } + } + if (fwd_while_edges.size() != 1) { + // No forward While op found, or multiple forward While ops. + continue; + } + + // Find the backward While op. + Node* fwd_while = fwd_while_edges[0]->dst(); + int fwd_while_dst_input = fwd_while_edges[0]->dst_input(); + std::vector bwd_while_edges; + for (const Edge* e : fwd_while->out_edges()) { + if (e->src_output() == fwd_while_dst_input && + e->dst()->type_string() == "While") { + bwd_while_edges.push_back(e); + } + } + if (bwd_while_edges.size() != 1) { + // No backward While op found, or multiple backward While ops. + continue; + } + + Node* bwd_while = bwd_while_edges[0]->dst(); + int bwd_while_dst_input = bwd_while_edges[0]->dst_input(); + + // Look into forward While body function and check if TensorListPushBack op + // has a Const input. 
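RewriteTensorListWithConstElement above repeatedly applies the same matching step: collect the out-edges whose destination has a given op type and only proceed when exactly one matches. Below is a standalone sketch of that check, using hypothetical `Node`/`Edge` structs rather than the TensorFlow graph classes.

```c++
#include <string>
#include <vector>

struct Node;

struct Edge {
  Node* dst;
  bool is_control;
};

struct Node {
  std::string type;
  std::vector<Edge> out_edges;
};

// Returns the unique non-control successor of `n` whose type matches `type`,
// or nullptr when there are zero or several such successors.
Node* UniqueSuccessorOfType(const Node& n, const std::string& type) {
  Node* match = nullptr;
  for (const Edge& e : n.out_edges) {
    if (e.is_control || e.dst->type != type) continue;
    if (match != nullptr) return nullptr;  // ambiguous: more than one match
    match = e.dst;
  }
  return match;
}

int main() {
  Node while_node{"While", {}};
  Node tl{"EmptyTensorList", {{&while_node, /*is_control=*/false}}};
  return UniqueSuccessorOfType(tl, "While") == &while_node ? 0 : 1;
}
```

The same shape of check covers the forward While, backward While, TensorListPushBack, and TensorListPopBack lookups in the pass.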
+ NameAttrList fwd_body_attr; + TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr)); + const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name()); + if (!fwd_body) { + return errors::InvalidArgument("Cannot find function ", + fwd_body_attr.name(), " for While node ", + fwd_while->DebugString()); + } + std::unique_ptr fwd_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody)); + + // Find the TensorListPushBack node; it's one of fwd_arg's successors. + Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input]; + std::vector tl_push_nodes; + for (const Edge* out_edge : fwd_arg->out_edges()) { + if (out_edge->dst()->type_string() == "TensorListPushBack") { + tl_push_nodes.push_back(out_edge->dst()); + } + } + if (tl_push_nodes.size() != 1) { + // No TensorListPushBack found, or multiple TensorListPushBack. + continue; + } + + // Get input for the TensorListPushBack node. + Node* input_node; + TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node)); + if (input_node->type_string() != "Const") { + // Input for the TensorList is not Const node. + continue; + } + + NodeDef const_input_nodedef = input_node->def(); + + // Rewrite backward While body function, replace usages of + // TensorListPopBack with a Const node. + NameAttrList bwd_body_attr; + TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr)); + const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name()); + if (!bwd_body) { + return errors::InvalidArgument("Cannot find function ", + bwd_body_attr.name(), " for While node ", + bwd_while->DebugString()); + } + std::unique_ptr bwd_fbody; + TF_CHECK_OK(FunctionDefToBodyHelper( + *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody)); + + // Find the TensorListPopBack node; it's one of bwd_arg's successors. + Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input]; + std::vector tl_pop_nodes; + for (const Edge* out_edge : bwd_arg->out_edges()) { + if (out_edge->dst()->type_string() == "TensorListPopBack") { + tl_pop_nodes.push_back(out_edge->dst()); + } + } + if (tl_pop_nodes.size() != 1) { + // No TensorListPopBack found, or multiple TensorListPopBack. + continue; + } + + // Replace TensorListPopBack usages with Const node. + std::vector edges_to_replace; + for (const Edge* e : tl_pop_nodes[0]->out_edges()) { + if (e->src_output() == 1) { + edges_to_replace.push_back(e); + } + } + if (edges_to_replace.empty()) { + continue; + } + Status s; + const_input_nodedef.set_name( + bwd_fbody->graph->NewName(const_input_nodedef.name())); + Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s); + TF_RETURN_IF_ERROR(s); + for (const Edge* e : edges_to_replace) { + Node* dst = e->dst(); + int dst_input = e->dst_input(); + bwd_fbody->graph->RemoveEdge(e); + bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input); + } + + // Add rewritten backward While body function. + FunctionDef new_fdef; + string new_name = fld->UniqueFunctionName( + absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_")); + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef)); + + // Change backward While op to use the new body function. 
+ bwd_body_attr.set_name(new_name); + bwd_while->ClearAttr("body"); + bwd_while->AddAttr("body", bwd_body_attr); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index cf3aa2f847c..c9d73450425 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -197,6 +198,20 @@ Status PropagateConstIntoFunctionalNodes( Graph* g, const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld); +// Prunes unreachable FunctionDefs from FunctionLibraryDefinition. +Status PruneUnreachableFunctionsFromGraph(const Graph& g, + FunctionLibraryDefinition* fld); + +// Finds the following pattern in the graph: +// 1) EmptyTensorList -> forward While op -> backward While op, +// 2) in forward While op, a Const node is pushed, +// 3) in backward While op, data is popped from the tensor list. +// And rewrites backward While op to use Const node instead of TensorListPopBack +// result. +// TODO(b/128633174) remove the TensorList and related TensorList ops. +Status RewriteTensorListWithConstElement(Graph* g, + FunctionLibraryDefinition* fld); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 28b4744470e..0fde45c2696 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -22,8 +22,10 @@ limitations under the License. 
#include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/list_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" @@ -416,5 +418,86 @@ TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) { EXPECT_EQ(const_def->second.op(), "Const"); } +TEST(PropagateConstIntoFunctionalNodes, RewriteTensorListWithConstMember) { + FunctionLibraryDefinition fld(OpRegistry::Global(), {}); + { + // Cond graph + Scope scope = Scope::NewRootScope().ExitOnError(); + auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + auto result = + ops::Const(scope.WithOpName("result"), false, TensorShape({})); + auto ret = ops::_Retval(scope.WithOpName("ret"), result, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(fdef)); + } + { + // Forward body graph + Scope scope = Scope::NewRootScope().ExitOnError(); + auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + auto element = ops::Const(scope.WithOpName("element"), 0, TensorShape({})); + auto push = + ops::TensorListPushBack(scope.WithOpName("push"), input, element); + auto ret = ops::_Retval(scope.WithOpName("ret"), push.output_handle, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "fwd_body", &fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(fdef)); + } + { + // Backward body graph + Scope scope = Scope::NewRootScope().ExitOnError(); + auto input = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({})); + auto pop = + ops::TensorListPopBack(scope.WithOpName("pop"), input, shape, DT_INT32); + auto identity = ops::Identity(scope.WithOpName("identity"), pop.tensor); + auto ret = ops::_Retval(scope.WithOpName("ret"), pop.output_handle, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "bwd_body", &fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(fdef)); + } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto shape = ops::Const(scope.WithOpName("element"), -1, TensorShape({})); + auto max_num_elements = + ops::Const(scope.WithOpName("max_num_elements"), 10, TensorShape({})); + auto tl = ops::EmptyTensorList(scope.WithOpName("tl"), shape, + max_num_elements, DT_INT32); + NameAttrList cond_fn, fwd_body_fn, bwd_body_fn; + cond_fn.set_name("cond"); + fwd_body_fn.set_name("fwd_body"); + bwd_body_fn.set_name("bwd_body"); + auto fwd_while_op = + ops::While(scope.WithOpName("fwd_while"), + std::initializer_list{tl}, cond_fn, fwd_body_fn); + auto bwd_while_op = + ops::While(scope.WithOpName("bwd_while"), + std::initializer_list{fwd_while_op.output[0]}, cond_fn, + bwd_body_fn); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + TF_EXPECT_OK(RewriteTensorListWithConstElement(&graph, &fld)); + + // Check that in rewritten backward While body function, the Identity node now + // has Const node as input. 
+ const FunctionDef* bwd_body = fld.Find("bwd_body_tl_rewrite_0"); + ASSERT_NE(bwd_body, nullptr); + std::unique_ptr bwd_fbody; + TF_CHECK_OK( + FunctionDefToBodyHelper(*bwd_body, AttrSlice(), &fld, &bwd_fbody)); + auto node_name_index = bwd_fbody->graph->BuildNodeNameIndex(); + const Node* identity = node_name_index.at("identity"); + ASSERT_NE(identity, nullptr); + const Node* input; + TF_ASSERT_OK(identity->input_node(0, &input)); + EXPECT_EQ(input->type_string(), "Const"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index f98d07d196e..c14519c3ade 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -58,18 +58,13 @@ class XlaCompilationAllocator : public Allocator { // Make sure that even tensors with 0 elements have allocated // buffers, so they get ids to track. - bool ShouldAllocateEmptyTensors() override { return true; } - - private: - // Don't run any constructors or destructors for complex objects, - // since there is no backing store for the tensor to run them - // on. strings are the only complex objects currently stored in - // Tensors. If others are added, this set of overrides must be - // extended to include them. - void RunStringCtor(string* p, size_t n) override {} - void RunStringDtor(string* p, size_t n) override {} - void RunResourceCtor(ResourceHandle* p, size_t n) override {} - void RunResourceDtor(ResourceHandle* p, size_t n) override {} + // + // NOTE: It is the caller's responsibility to track whether an allocated + // object is a buffer or an opaque handle. In particular, when this allocator + // is used, the caller must not run any constructors or destructors for + // complex objects, since there is no backing store for the tensor in which to + // place their outputs. + bool AllocatesOpaqueHandle() const override { return true; } }; XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 1f0f240135d..5420cf3e04f 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include +#include "tensorflow/compiler/xla/cpu_function_runtime.h" namespace tensorflow { @@ -32,9 +33,9 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, program_shape_(static_data.program_shape_), hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) { bool allocate_entry_params = - alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS; + alloc_mode == AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS; // Allocate arg and temp buffers. 
- alloc_buffer_table_ = cpu_function_runtime::MallocContiguousBuffers( + alloc_buffer_table_ = xla::cpu_function_runtime::MallocContiguousBuffers( static_data.buffer_infos_, static_data.num_buffers_, /*allocate_entry_params=*/allocate_entry_params, buffer_table_, /*annotate_initialized=*/true); @@ -55,7 +56,7 @@ bool XlaCompiledCpuFunction::Run() { } XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { - cpu_function_runtime::FreeContiguous(alloc_buffer_table_); + xla::cpu_function_runtime::FreeContiguous(alloc_buffer_table_); delete[] buffer_table_; delete[] profile_counters_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index de2e485a47c..5e452b50e71 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -66,7 +66,7 @@ class XlaCompiledCpuFunction { RawFunction raw_function_; // Contains information about the buffers used by the XLA computation. - const cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; + const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; size_t num_buffers_ = 0; // Entry parameter i is described by @@ -105,7 +105,7 @@ class XlaCompiledCpuFunction { // AllocMode controls the buffer allocation mode. enum class AllocMode { // Allocate all buffers - args, results, profile and temps. - ARGS_RESULTS_PROFILES_AND_TEMPS, + ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS, // Only allocate result, profile and temp buffers. // Use set_arg_data to set argument buffers before Run is called. @@ -114,7 +114,8 @@ class XlaCompiledCpuFunction { explicit XlaCompiledCpuFunction( const StaticData& static_data, - AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); + AllocMode alloc_mode = + AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; @@ -166,7 +167,8 @@ class XlaCompiledCpuFunction { // // Allocated memory must be aligned to the size specified by // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in - // tensorflow/compiler/aot/runtime.h to ensure correct alignment. + // tensorflow/compiler/tf2xla/cpu_function_runtime.h to ensure correct + // alignment. // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. @@ -259,7 +261,7 @@ class XlaCompiledCpuFunction { static void set_static_data_buffer_infos( StaticData* static_data, - const cpu_function_runtime::BufferInfo* buffer_infos) { + const xla::cpu_function_runtime::BufferInfo* buffer_infos) { static_data->buffer_infos_ = buffer_infos; } @@ -323,7 +325,7 @@ class XlaCompiledCpuFunction { void** const buffer_table_; // Describes the buffers used by the XLA computation. - const cpu_function_runtime::BufferInfo* const buffer_infos_; + const xla::cpu_function_runtime::BufferInfo* const buffer_infos_; // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] // for XLA generated code to be able to find it. 
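For readers unfamiliar with the contiguous buffer table referenced above, here is a minimal standalone sketch of the idea behind `MallocContiguousBuffers`/`FreeContiguous`: one aligned block is carved into per-buffer slices, and entry-parameter slices are skipped unless the allocation mode asks for them. The 64-byte alignment and the `BufferInfo` fields here are illustrative assumptions, not the exact `xla::cpu_function_runtime` contract.

```c++
#include <cstddef>
#include <cstdlib>
#include <vector>

constexpr std::size_t kAlign = 64;

std::size_t RoundUpToAlign(std::size_t n) {
  return (n + kAlign - 1) & ~(kAlign - 1);
}

struct BufferInfo {
  std::size_t size = 0;
  bool is_entry_parameter = false;
};

// Allocates one aligned block and points each requested buffer slice into it.
// Entry-parameter slices stay null unless allocate_entry_params is true,
// mirroring the two AllocMode variants above. The caller releases everything
// at once by passing the returned pointer to std::free.
void* MallocContiguous(const std::vector<BufferInfo>& infos,
                       bool allocate_entry_params, std::vector<void*>* table) {
  table->assign(infos.size(), nullptr);
  std::size_t total = 0;
  for (const BufferInfo& b : infos) {
    if (b.is_entry_parameter && !allocate_entry_params) continue;
    total += RoundUpToAlign(b.size);
  }
  if (total == 0) return nullptr;
  char* block = static_cast<char*>(std::aligned_alloc(kAlign, total));
  if (block == nullptr) return nullptr;
  std::size_t offset = 0;
  for (std::size_t i = 0; i < infos.size(); ++i) {
    if (infos[i].is_entry_parameter && !allocate_entry_params) continue;
    (*table)[i] = block + offset;
    offset += RoundUpToAlign(infos[i].size);
  }
  return block;
}

int main() {
  std::vector<BufferInfo> infos = {{128, /*is_entry_parameter=*/true},
                                   {32, false},
                                   {200, false}};
  std::vector<void*> table;
  void* block =
      MallocContiguous(infos, /*allocate_entry_params=*/false, &table);
  // table[0] stays null; table[1] and table[2] point into the same block.
  std::free(block);
  return 0;
}
```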
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index f4a0f456ed5..b8eda1de94a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -19,8 +19,9 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" +#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" @@ -48,6 +49,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace { @@ -92,20 +94,19 @@ ComputeArgAndRetvalCores(const Graph& graph) { std::map arg_cores; std::map retval_cores; for (const Node* n : graph.nodes()) { - if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + if (n->IsArg()) { TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); if (core < 0) continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0) << "Negative _Arg index"; arg_cores[index] = core; - } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) { + } else if (n->IsRetval()) { TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); if (core < 0) continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0) << "Negative _Retval index"; - TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n)); retval_cores[index] = core; } } @@ -200,8 +201,13 @@ Status BuildComputation( output.shape = output.constant_value.shape(); break; - case XlaExpression::Kind::kTensorList: - TF_FALLTHROUGH_INTENDED; + case XlaExpression::Kind::kTensorList: { + output.is_tensor_list = true; + xla::XlaOp value = retval.handle(); + elems.push_back(value); + break; + } + case XlaExpression::Kind::kXlaOp: { output.is_constant = false; TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); @@ -228,6 +234,11 @@ Status BuildComputation( } case XlaExpression::Kind::kResource: + // Resources are pushed into elems later when processing resource + // arguments. This is correct as long as the input and output resources + // are in the same order. In the case of functionalized while body, + // this property is guaranteed since a corresponding output is always + // created for a DT_RESOURCE input in a corresponding location. 
output.is_constant = false; output.input_index = retval.resource()->arg_num(); output.shape = retval.resource()->shape(); @@ -367,6 +378,9 @@ bool XlaCompiler::Argument::operator==( if (constant_value.shape() != other.constant_value.shape()) { return false; } + if (is_same_data_across_replicas != other.is_same_data_across_replicas) { + return false; + } return constant_value.tensor_data() == other.constant_value.tensor_data(); } @@ -377,6 +391,8 @@ string XlaCompiler::Argument::HumanString() const { } absl::StrAppend(&common, " type=", DataTypeString(type), " shape=", ShapeHumanString()); + absl::StrAppend( + &common, " is_same_data_across_replicas=", is_same_data_across_replicas); switch (kind) { case kInvalid: return "invalid"; @@ -398,6 +414,8 @@ string XlaCompiler::Argument::HumanString() const { } case kParameter: return absl::StrCat("kind=parameter", common); + case kTensorList: + return absl::StrCat("kind=tensorlist", common); case kToken: return absl::StrCat("token", common); } @@ -436,11 +454,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) FunctionDefLibrary{})); local_pflr_.reset(new ProcessFunctionLibraryRuntime( &device_mgr_, Env::Default(), options.graph_def_version, - local_flib_def_.get(), OptimizerOptions(), - nullptr /* custom_kernel_creator */)); + local_flib_def_.get(), OptimizerOptions())); pflr_.reset(new ProcessFunctionLibraryRuntime( &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def, - OptimizerOptions(), nullptr /* custom_kernel_creator */)); + OptimizerOptions())); local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); @@ -540,11 +557,12 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { } Status XlaCompiler::CompileFunction( - const XlaCompiler::CompileOptions& options, const NameAttrList& function, + const XlaCompiler::CompileOptions& options, + const NameAttrList& fn_name_attrs, absl::Span args, XlaCompiler::CompilationResult* result) { const string function_id = - Canonicalize(function.name(), AttrSlice(&function.attr())); + Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; const std::vector arg_vector(args.begin(), args.end()); @@ -555,11 +573,11 @@ Status XlaCompiler::CompileFunction( } const FunctionBody* fbody; - TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody)); + TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody)); TF_RETURN_WITH_CONTEXT_IF_ERROR( CheckSignature(fbody->arg_types, args), - "Signature check failure while compiling: ", function.name()); + "Signature check failure while compiling: ", fn_name_attrs.name()); std::unique_ptr graph = GetGraph(fbody); @@ -581,23 +599,21 @@ Status XlaCompiler::CompileFunction( // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == - FunctionLibraryDefinition::kArgOp) { + if (n->IsArg()) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). 
for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == - FunctionLibraryDefinition::kRetOp) { + if (n->IsRetval()) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileFunction: " - << dump_graph::DumpGraphToFile( + << DumpGraphToFile( absl::StrCat("xla_compile_function_", function_id), *graph); } @@ -638,6 +654,11 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, } return Status::OK(); } + case XlaCompiler::Argument::kTensorList: { + TF_RET_CHECK(absl::holds_alternative(arg.shape)); + *xla_shape = absl::get(arg.shape); + return Status::OK(); + } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); @@ -741,6 +762,7 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kTensorList: case XlaCompiler::Argument::kToken: { input_to_args->push_back(i); break; @@ -754,7 +776,7 @@ Status XlaCompiler::BuildArguments( } } - if (input_to_args->empty()) { + if (input_to_args->empty() && !use_tuple_arg) { return Status::OK(); } @@ -799,9 +821,19 @@ Status XlaCompiler::BuildArguments( *tuple_sharding.add_tuple_shardings() = xla::sharding_builder::AssignDevice(core); } - xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, - tuple_sharding); - tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); + std::vector is_same_across_replicas; + for (int i = 0; i < input_to_args->size(); ++i) { + // Add an entry to is_same_across_replicas for every leaf buffer. + is_same_across_replicas.insert( + is_same_across_replicas.end(), + xla::ShapeUtil::GetLeafCount(arg_shapes[i]), + args[input_to_args->at(i)].is_same_data_across_replicas); + } + xla::XlaScopedShardingAssignment assign_tuple_sharding( + builder, input_to_args->empty() ? absl::optional() + : tuple_sharding); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple", + is_same_across_replicas); } else { tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } @@ -832,8 +864,18 @@ Status XlaCompiler::BuildArguments( xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], - absl::StrCat("arg", i)); + if (is_entry_computation) { + // Add an entry to is_same_across_replicas for every leaf buffer. + std::vector is_same_across_replicas( + xla::ShapeUtil::GetLeafCount((*input_shapes)[i]), + args[input_to_args->at(i)].is_same_data_across_replicas); + arg_handles[i] = + xla::Parameter(builder, i, (*input_shapes)[i], + absl::StrCat("arg", i), is_same_across_replicas); + } else { + arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], + absl::StrCat("arg", i)); + } } for (int i = 0; i < input_to_args->size(); ++i) { @@ -880,6 +922,10 @@ Status XlaCompiler::BuildArguments( arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } break; + case XlaCompiler::Argument::kTensorList: { + arg_expression = XlaExpression::TensorList(arg_handles[i]); + break; + } case XlaCompiler::Argument::kToken: { arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); break; @@ -956,6 +1002,28 @@ Status ValidateFunctionDef(const FunctionDef* fdef, return Status::OK(); } +// If node is PartitionedCall or StatefulPartitionedCall, returns the +// name from the "f" attr, else returns node.def().op(). 
+// Returned pointer points to the internal string either in node's attributes +// or in its NodeDef. This pointer is valid as long as the node has not been +// modified. +Status GetPotentialFunctionName(const Node& node, const string** name) { + if (node.IsPartitionedCall()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value)); + if (!attr_value->has_func()) { + return errors::InvalidArgument( + "The attribute value for attribute 'f' in node ", node.DebugString(), + " does not have 'func' field set"); + } + *name = &attr_value->func().name(); + return Status::OK(); + } + *name = &node.type_string(); + return Status::OK(); +} + // Check that the graph doesn't have any invalid nodes (e.g. incompatible with // given device_type, invalid data type, missing attributes...) Status ValidateGraph(const Graph* graph, @@ -975,7 +1043,9 @@ Status ValidateGraph(const Graph* graph, if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { continue; } - const FunctionDef* fdef = flib_def.Find(node->def().op()); + const string* function_name; + TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name)); + const FunctionDef* fdef = flib_def.Find(*function_name); Status s; if (fdef) { s = ValidateFunctionDef(fdef, flib_def); @@ -1029,11 +1099,15 @@ Status XlaCompiler::CompileGraph( TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes( graph.get(), options_.flib_def, local_flib_def_.get())); + TF_RETURN_IF_ERROR(RearrangeFunctionArguments( + [this](const NameAttrList& function, const FunctionBody** fbody) { + return FindFunctionBody(function, fbody); + }, + graph.get(), local_flib_def_.get())); if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " - << dump_graph::DumpGraphToFile( - absl::StrCat("xla_compile_graph_", name), *graph, - flib_runtime_->GetFunctionLibraryDefinition()); + << DumpGraphToFile(absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); } // Report the error here if initialization failed. diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 0b0908e9d69..1cc5d8d4728 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -116,6 +116,9 @@ class XlaCompiler { // Argument is an XLA token. kToken, + + // Argument is a TensorList. + kTensorList, }; Kind kind = kInvalid; @@ -163,6 +166,9 @@ class XlaCompiler { std::map dynamic_dim_to_arg_num_map; bool is_pad_arg = false; + // Whether this argument will receive the same data across all replicas. + bool is_same_data_across_replicas = false; + bool operator==(const Argument& other) const; // Returns a human-readable summary of the argument. @@ -223,6 +229,9 @@ class XlaCompiler { // When this output is a resource, i.e. `type == DT_RESOURCE`, this is // the index of the input that contains the resource. int input_index; + + // Whether this output is a TensorList. + bool is_tensor_list = false; }; // Describes a variable write side effect of the computation. @@ -302,6 +311,12 @@ class XlaCompiler { // for CPU. bool allow_cpu_custom_calls = false; + // If both this and 'allow_cpu_custom_calls' are true then tf.fake_quant_* + // ops will be emitted as custom calls to a 'fake_quant_with_min_max_vars' + // function accepting the input, min, max, num_bits, and narrow_range values + // as runtime arguments. 
+ bool custom_fake_quant_op_calls = false; + // If set, the XLA representation of variables represented to XLA as the // shape given by this shape function. Variables are reshaped to this shape // on write, and reshaped to their original shape on read. @@ -324,7 +339,7 @@ class XlaCompiler { // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly // allocate most or all available memory on the device, leaving none for the // compiler to access, unless it can use TensorFlow's allocator. - xla::DeviceMemoryAllocator* device_allocator = nullptr; + se::DeviceMemoryAllocator* device_allocator = nullptr; }; explicit XlaCompiler(Options options); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 1818d429032..16f18c0cc88 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -14,10 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiler.h" + #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/list_ops.h" +#include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -35,7 +40,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" @@ -1498,5 +1505,191 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { } } +TEST_F(XlaCompilerTest, OpsWithTensorListInput) { + FunctionDefLibrary fdef_lib; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + // Build cond fn for While. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + auto result = ops::Const(scope, {true}, {}); + ops::_Retval(scope.WithOpName("ret"), result, 0); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef)); + TF_ASSERT_OK(flib_def.AddFunctionDef(fdef)); + } + // Build body fn for While. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + ops::_Retval(scope.WithOpName("ret"), arg, 0); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef)); + TF_ASSERT_OK(flib_def.AddFunctionDef(fdef)); + } + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto element_shape = ops::Const(scope, {1}, {1}); + auto max_elements = ops::Const(scope, {10}, {}); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0); + std::initializer_list out = {arg, arg}; + auto add_n = ops::AddN(scope, out); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("cond"); + body_fn.set_name("body"); + auto while_op = + ops::While(scope, std::initializer_list{arg}, cond_fn, body_fn); + auto ret0 = ops::_Retval(scope.WithOpName("ret0"), add_n, 0); + auto ret1 = ops::_Retval(scope.WithOpName("ret1"), while_op.output[0], 1); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kTensorList; + xla::Shape tensor_list_element_shape; + TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{1}, + &tensor_list_element_shape)); + xla::Shape index_shape; + TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{}, &index_shape)); + std::vector shapes{tensor_list_element_shape, index_shape}; + xla::Shape arg_shape = xla::ShapeUtil::MakeTupleShape(shapes); + args[0].shape = arg_shape; + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.flib_def = &flib_def; + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, + /*user_aliases=*/{}, &result)); + ASSERT_EQ(result.outputs.size(), 2); + const XlaCompiler::OutputDescription& output0 = result.outputs[0]; + ASSERT_TRUE(output0.is_tensor_list); + const XlaCompiler::OutputDescription& output1 = result.outputs[1]; + ASSERT_TRUE(output1.is_tensor_list); +} + +// Test the compiler supports WhileOp with a loop body where DT_RESOURCE +// variables are both inputs and outputs. +TEST_F(XlaCompilerTest, WhileWithResources) { + FunctionDefLibrary fdef_lib; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + // Build cond fn for While. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1); + auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2); + auto less = ops::Less(scope, arg0, ops::Const(scope, 10)); + (void)ops::_Retval(scope.WithOpName("ret"), less, 0); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef)); + TF_ASSERT_OK(flib_def.AddFunctionDef(fdef)); + } + // Build body fn for While. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1); + auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2); + auto read1 = ops::ReadVariableOp(scope.WithOpName("read1"), arg1, DT_INT32); + auto plus_read1 = ops::Add(scope, arg0, read1); + auto read2 = ops::ReadVariableOp(scope.WithOpName("read2"), arg2, DT_INT32); + auto minus_read2 = ops::Sub(scope, plus_read1, read2); + (void)ops::_Retval(scope.WithOpName("ret0"), minus_read2, 0); + (void)ops::_Retval(scope.WithOpName("ret1"), arg1, 1); + (void)ops::_Retval(scope.WithOpName("ret2"), arg2, 2); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + FunctionDef fdef; + TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef)); + TF_ASSERT_OK(flib_def.AddFunctionDef(fdef)); + } + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1); + auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2); + + NameAttrList cond_fn, body_fn; + cond_fn.set_name("cond"); + body_fn.set_name("body"); + auto while_op = ops::While( + scope, std::initializer_list{arg0, arg1, arg2}, cond_fn, body_fn); + + (void)ops::_Retval(scope.WithOpName("ret0"), while_op.output[0], 0); + (void)ops::_Retval(scope.WithOpName("ret1"), while_op.output[1], 1); + (void)ops::_Retval(scope.WithOpName("ret2"), while_op.output[2], 2); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(3); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({}); + args[2].kind = XlaCompiler::Argument::kResource; + args[2].resource_kind = XlaResource::kVariable; + args[2].initialized = true; + args[2].type = DT_INT32; + args[2].shape = TensorShape({}); + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.flib_def = &flib_def; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options = XlaCompiler::CompileOptions(); + compile_options.return_updated_values_for_all_resources = true; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "tested_while_with_vars", + std::move(graph), args, + /*user_aliases=*/{}, &result)); + ASSERT_EQ(result.outputs.size(), 3); + const XlaCompiler::OutputDescription& output1 = result.outputs[1]; + ASSERT_EQ(output1.input_index, 1); + const XlaCompiler::OutputDescription& output2 = result.outputs[2]; + ASSERT_EQ(output2.input_index, 2); + + // Tests that the generated computation works. 
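As a plain-arithmetic cross-check of the expectations below: the counter starts at 0, each body iteration adds the first variable's value (2) and subtracts the second's (1), and the condition keeps iterating while the counter is below 10, so the first output is 10 and both variables keep their initial values. A trivial host-side sketch of the same computation:

```c++
#include <cassert>

int main() {
  int x = 0;     // corresponds to literal0 / arg0
  int var1 = 2;  // corresponds to literal1, the variable read by read1
  int var2 = 1;  // corresponds to literal2, the variable read by read2
  while (x < 10) {
    x = x + var1 - var2;  // body: (arg0 + read1) - read2
  }
  assert(x == 10);
  assert(var1 == 2 && var2 == 1);  // the loop never writes the variables
  return 0;
}
```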
+ xla::Literal literal0 = xla::LiteralUtil::CreateR0(0); + xla::Literal literal1 = xla::LiteralUtil::CreateR0(2); + xla::Literal literal2 = xla::LiteralUtil::CreateR0(1); + std::unique_ptr data0 = + client_->TransferToServer(literal0).ConsumeValueOrDie(); + std::unique_ptr data1 = + client_->TransferToServer(literal1).ConsumeValueOrDie(); + std::unique_ptr data2 = + client_->TransferToServer(literal2).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, + {data0.get(), data1.get(), data2.get()}) + .ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR0(10); + xla::Literal expected1 = xla::LiteralUtil::CreateR0(2); + xla::Literal expected2 = xla::LiteralUtil::CreateR0(1); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 884dc45cb11..625809ae083 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -120,7 +121,7 @@ XlaJitCompiledCpuFunction::Compile( cpu_executable->buffer_assignment(); // Compute buffer infos and the result index, needed to run the raw function. - std::vector buffer_infos = + std::vector buffer_infos = xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment); std::vector arg_index_table = xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index a5392057177..11fc4571189 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/types.h" @@ -67,7 +68,7 @@ class XlaJitCompiledCpuFunction { XlaCompiledCpuFunction::StaticData static_data_; // The backing array for buffer infos. - std::vector buffer_infos_; + std::vector buffer_infos_; // The backing array for the arg index table. std::vector arg_index_table_; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index ee11f3a3de6..6996e39ba16 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -93,11 +94,26 @@ TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { } DataType XlaOpKernelContext::input_type(int index) const { - return context_->input_dtype(index); + DataType type = context_->input_dtype(index); + if (type == DT_UINT8) { + // Masqueraded XlaExpression could have different type. See + // XlaOpKernelContext::SetOutputExpression for details. + auto expression = CastExpressionFromTensor(context_->input(index)); + type = expression->dtype(); + } + return type; } DataType XlaOpKernelContext::InputType(absl::string_view name) { - return GetInputTensorByName(name).dtype(); + const Tensor& tensor = GetInputTensorByName(name); + DataType type = tensor.dtype(); + if (type == DT_UINT8) { + // Masqueraded XlaExpression could have different type. See + // XlaOpKernelContext::SetOutputExpression for details. + auto expression = CastExpressionFromTensor(tensor); + type = expression->dtype(); + } + return type; } xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { @@ -110,6 +126,16 @@ xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { return type; } +xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(InputType(name), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + Status XlaOpKernelContext::ConstantInput(int index, xla::Literal* constant_literal) { return ConstantInputReshaped( @@ -151,9 +177,9 @@ Status XlaOpKernelContext::ConstantInputReshaped( absl::optional constant = constant_or_status.ValueOrDie(); if (!constant.has_value()) { return errors::InvalidArgument( - "Input ", index, " to ", context_->op_kernel().type_string(), - " operator must be a compile-time constant.\n" - "\n" + "Input ", index, " to node `", context_->op_kernel().name(), + "` with op ", context_->op_kernel().type_string(), + " must be a compile-time constant.\n\n" "XLA compilation requires that operator arguments that represent " "shapes or dimensions be evaluated to concrete values at compile time. " "This error means that a shape or dimension argument could not be " @@ -439,28 +465,27 @@ void XlaOpKernelContext::SetOutputExpression(int index, // The step's default allocator is the dummy XlaCompilationAllocator which // simply allocates a metadata buffer to hold the expression to which it // corresponds. - Tensor* output = nullptr; - // Provides a special behavior for DT_VARIANT: a variant is treated as - // DT_UINT8 scalar as the type to allow mapping for variant to more generic - // types. - if (expression.dtype() == DT_VARIANT) { - // tensor_data() is not supported for variant Tensor (i.e., - // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the - // XlaExpression inside the Tensor's tensor_data() does not work for - // variant. Instead construct a uint8 tensor and store the expression in - // its value. + // Provides a special behavior for DT_VARIANT and other types that are not + // trivially copyable. In those cases, allocate a tensor of type DT_UINT8. + if (!DataTypeCanUseMemcpy(expression.dtype())) { + // tensor_data() is not supported for tensors that cannot be copied via + // memcpy, as the copy logic might try to inspect the stored data (e.g. + // a std::string). 
This is likely to fail, as the data is invalid given + // that it actually encodes an XlaExpression. Using a uint8 tensor is + // always safe, so simply do that. // TODO(jpienaar): This should be refactored to stop masquerading // XlaExpressions as Tensors. - output = new Tensor(); + Tensor output; TensorShape tensor_shape; TF_RETURN_IF_ERROR( - context_->allocate_temp(DT_UINT8, tensor_shape, output)); - context_->set_output(index, *output); + context_->allocate_temp(DT_UINT8, tensor_shape, &output)); + context_->set_output(index, output); } else { + Tensor* output = nullptr; TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); } - AssignExpressionToTensor(output, expression); + AssignExpressionToTensor(context_->mutable_output(index), expression); return Status::OK(); }(); if (!status.ok()) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index cc2d5e8de3e..7794786f905 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -81,6 +81,11 @@ class XlaOpKernelContext { // xla::PRIMITIVE_TYPE_INVALID. xla::PrimitiveType input_xla_type(int index); + // Returns the type of input `name` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType InputXlaType(absl::string_view name); + // Returns the shape of input `index`. TensorShape InputShape(int index); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 9470c7e334c..1f298ee4bc4 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -78,6 +78,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_variant_types settings."; return false; } + if (x.allow_string_type != y.allow_string_type) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible allow_string_type settings."; + return false; + } if (!x.has_device_whitelist && !y.has_device_whitelist) { LOG(WARNING) << "Duplicate registrations of " << x.name << "with no device whitelists."; @@ -148,7 +153,6 @@ XlaOpRegistry::~XlaOpRegistry() = default; cpu_global_jit ? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested; - registration.compile_resource_ops = false; } if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = @@ -156,7 +160,6 @@ XlaOpRegistry::~XlaOpRegistry() = default; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.autoclustering_policy = XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally; - registration.compile_resource_ops = false; } return nullptr; }(); @@ -298,6 +301,9 @@ void XlaOpRegistry::RegisterCompilationKernels() { if (op_registration->allow_variant_types) { allowed_values->add_type(DT_VARIANT); } + if (op_registration->allow_string_type) { + allowed_values->add_type(DT_STRING); + } // Don't build KernelDefs that have unsatisfiable type constraints. 
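Stepping back to the DT_UINT8 handling in xla_op_kernel.cc above: when an expression's logical dtype cannot be copied via memcpy, a uint8 placeholder tensor is emitted and the logical dtype is recovered from the stored XlaExpression. The sketch below shows only the lookup shape; the placeholder struct and the side table keyed by an id are hypothetical simplifications, not the real Tensor/XlaExpression machinery.

```c++
#include <iostream>
#include <unordered_map>

enum class DataType { DT_UINT8, DT_FLOAT, DT_VARIANT };

// Hypothetical stand-ins: the real code keeps an XlaExpression behind the
// tensor; a side table keyed by an id keeps this sketch self-contained.
struct Expression { DataType logical_dtype; };
struct PlaceholderTensor { DataType physical_dtype; int id; };

std::unordered_map<int, Expression>& SideTable() {
  static std::unordered_map<int, Expression> table;
  return table;
}

DataType LogicalInputType(const PlaceholderTensor& t) {
  if (t.physical_dtype == DataType::DT_UINT8) {
    // A masqueraded expression: report its logical type, not the placeholder's.
    return SideTable().at(t.id).logical_dtype;
  }
  return t.physical_dtype;
}

int main() {
  SideTable()[7] = {DataType::DT_VARIANT};
  PlaceholderTensor t{DataType::DT_UINT8, /*id=*/7};
  std::cout << (LogicalInputType(t) == DataType::DT_VARIANT) << "\n";  // 1
  return 0;
}
```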
if (allowed_values->type().empty()) { unsatisfiable_type_constraint = true; @@ -499,6 +505,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowVariantTypes() { return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowStringType() { + registration_->allow_string_type = true; + return *this; +} + XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( absl::string_view attr_name, DataType allowed) { std::set& types = diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 80d022b592c..95d1bf25150 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -88,8 +88,37 @@ class XlaOpRegistry { // When should we autocluster operators assigned to this device? AutoclusteringPolicy autoclustering_policy; - // Enable compilation of operators that use DT_RESOURCE types? - bool compile_resource_ops = false; + // If we should ignore the resource variable memory model when clustering + // resource variable reads and writes placed on this device. + bool cluster_resource_variable_ops_unsafely = false; + + // If we should auto-cluster Stack operations placed on this device. + bool cluster_stack_ops = false; + + // If we should auto-cluster TensorArray operations placed on this device. + bool cluster_tensor_array_ops = false; + + // If we should auto-cluster stateful RNG operations placed on this device. + // Stateful RNG semantics are not properly supported by XLA so it is not + // necessarily correct to auto-cluster stateful RNG ops in general. + bool cluster_stateful_rng_ops = false; + + // If we should auto-cluster ControlTrigger operations placed on this + // device. ControlTrigger operations are not necessarily safe to cluster + // since they affect deadness (a dead ControlTrigger produces a live + // output). + bool cluster_control_trigger = false; + + // If we should cluster Assert and CheckNumerics by eliding them (XLA does + // not natively support Assert or CheckNumerics). + bool elide_assert_and_checknumerics = false; + + // If we should cluster operations returning DT_VARIANT. + bool cluster_variant_ops = false; + + // Whether ops known to be slow or to have correctness issues should be + // auto-clustered. + bool cluster_slow_and_inaccurate_ops = false; }; // Registers an XLA backend. `compilation_device_name` is the name of the @@ -216,6 +245,10 @@ class XlaOpRegistry { // allow TensorList which is of type DT_VARIANT. bool allow_variant_types = false; + // Should we allow string type for type attributes? Used by PartitionedCall + // to allow DT_STRING. + bool allow_string_type = false; + // Mapping from attribute name to a list of supported types. std::unordered_map> type_constraints; @@ -300,6 +333,9 @@ class XlaOpRegistrationBuilder { // Allow DT_VARIANT types for type parameters. XlaOpRegistrationBuilder& AllowVariantTypes(); + // Allow DT_STRING type for type parameters. + XlaOpRegistrationBuilder& AllowStringType(); + // Mark 'input_name' as an argument whose value must be known at compile-time. 
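The registration builder extended above (AllowVariantTypes, and now AllowStringType) follows a fluent-builder style in which each setter mutates the pending registration and returns `*this`. A generic sketch of that shape follows, with hypothetical `OpRegistration`/`OpRegistrationBuilder` types standing in for the real XLA registry classes.

```c++
#include <set>
#include <string>

struct OpRegistration {
  bool allow_variant_types = false;
  bool allow_string_type = false;
  std::set<std::string> compile_time_constant_inputs;
};

// Each setter mutates the registration being built and returns *this so the
// calls can be chained, matching the AllowVariantTypes()/AllowStringType()
// style above.
class OpRegistrationBuilder {
 public:
  OpRegistrationBuilder& AllowVariantTypes() {
    reg_.allow_variant_types = true;
    return *this;
  }
  OpRegistrationBuilder& AllowStringType() {
    reg_.allow_string_type = true;
    return *this;
  }
  OpRegistrationBuilder& CompileTimeConstantInput(const std::string& name) {
    reg_.compile_time_constant_inputs.insert(name);
    return *this;
  }
  OpRegistration Build() const { return reg_; }

 private:
  OpRegistration reg_;
};

int main() {
  OpRegistration r = OpRegistrationBuilder()
                         .AllowVariantTypes()
                         .AllowStringType()
                         .CompileTimeConstantInput("shape")
                         .Build();
  return r.allow_string_type ? 0 : 1;
}
```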
XlaOpRegistrationBuilder& CompileTimeConstantInput( absl::string_view input_name); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index ee6f7d5956e..91f33ff914e 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -18,8 +18,7 @@ package_group( ], ) -load("//tensorflow:tensorflow.bzl", "cc_header_only_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load( "//tensorflow/core:platform/default/build_config.bzl", @@ -57,6 +56,24 @@ xla_proto_library( ], ) +cc_library( + name = "comparison_util", + srcs = [ + "comparison_util.cc", + ], + hdrs = [ + "comparison_util.h", + ], + visibility = [":friends"], + deps = [ + ":statusor", + ":types", + ":util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "execution_options_util", srcs = [ @@ -67,8 +84,8 @@ cc_library( ], visibility = [":friends"], deps = [ + ":debug_options_flags", ":xla_proto", - "//tensorflow/compiler/xla:debug_options_flags", ], ) @@ -172,6 +189,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -269,6 +287,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "@com_google_absl//absl/hash:hash_testing", "@com_google_absl//absl/strings", ], ) @@ -379,7 +398,6 @@ tf_cc_test( ":shape_util", ":test", ":types", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -794,7 +812,7 @@ cc_library( hdrs = ["parse_flags_from_env.h"], deps = [ - "//tensorflow/compiler/xla:types", + ":types", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", @@ -809,7 +827,7 @@ tf_cc_test( deps = [ ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", + ":types", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -828,13 +846,22 @@ cc_library( deps = [ ":parse_flags_from_env", - "//tensorflow/compiler/xla:xla_proto", + ":status", + ":xla_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", "@com_google_absl//absl/strings", ], ) +cc_library( + name = "cpu_function_runtime", + srcs = ["cpu_function_runtime.cc"], + hdrs = ["cpu_function_runtime.h"], + visibility = [":friends"], + deps = ["//tensorflow/core:framework_lite"], +) + tf_cc_test( name = "debug_options_parsers_test", size = "small", @@ -844,7 +871,7 @@ tf_cc_test( ], deps = [ - "//tensorflow/compiler/xla:xla_proto", + ":xla_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index f5d56e8a9e1..b800229bd90 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -96,7 +96,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/strings", 
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -117,7 +117,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:local_service", @@ -125,6 +125,7 @@ cc_library( "//tensorflow/compiler/xla/service:source_map_util", "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", "@llvm//:support", @@ -164,11 +165,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compile_only_service", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], @@ -212,6 +213,7 @@ cc_library( ":padding", ":sharding_builder", ":xla_computation", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 62d225c6c29..33d1de370de 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -39,6 +38,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index ec0e0897592..d5de53a7941 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -22,12 +22,12 @@ limitations under the License. 
namespace xla { ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator( - DeviceMemoryAllocator* allocator) { + se::DeviceMemoryAllocator* allocator) { device_allocator_ = allocator; return *this; } -DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const { +se::DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const { return device_allocator_; } @@ -71,9 +71,8 @@ string ExecutableBuildOptions::ToString() const { } return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " - "generate_hlo_graph=%s, num_replicas=%d}", - device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph(), - num_replicas_); + "num_replicas=%d}", + device_ordinal_, result_layout, num_replicas_); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 1d85fb34304..e2e231981bf 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -18,11 +18,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { @@ -57,11 +57,11 @@ class ExecutableBuildOptions { // want to run various algorithms on the device and pick the fastest one -- it // might allocate buffers for use by these algorithms using this allocator. // - // This does not need to be the same as the DeviceMemoryAllocator passed when - // running the executable. + // This does not need to be the same as the se::DeviceMemoryAllocator passed + // when running the executable. ExecutableBuildOptions& set_device_allocator( - DeviceMemoryAllocator* allocator); - DeviceMemoryAllocator* device_allocator() const; + se::DeviceMemoryAllocator* allocator); + se::DeviceMemoryAllocator* device_allocator() const; // Returns a string representation of the build options, suitable for // debugging. @@ -77,7 +77,7 @@ class ExecutableBuildOptions { Shape result_layout_; bool result_layout_set_ = false; absl::optional debug_options_; - DeviceMemoryAllocator* device_allocator_ = nullptr; + se::DeviceMemoryAllocator* device_allocator_ = nullptr; int num_replicas_ = 1; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 1ddd3c2a455..4a99debbe70 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -296,6 +296,7 @@ cc_library( hdrs = ["slicing.h"], deps = [ "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:xla_builder", "@com_google_absl//absl/types:span", ], @@ -471,6 +472,7 @@ cc_library( xla_test( name = "svd_test", srcs = ["svd_test.cc"], + # Blacklisted because the tests are flaky. 
blacklisted_backends = [ "cpu", "gpu", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 3b875135af2..d34ecaf99c8 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -125,8 +125,60 @@ XlaOp Any(XlaOp predicates) { namespace { +XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder, + PrimitiveType value_type, + PrimitiveType index_type, bool is_min) { + auto sub_builder = outer_builder->CreateSubBuilder("minmax_func"); + XlaBuilder* b = sub_builder.get(); + XlaOp lhs_value = + Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value"); + XlaOp lhs_index = + Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index"); + XlaOp rhs_value = + Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value"); + XlaOp rhs_index = + Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index"); + + auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value); + XlaOp max = Select(cmp, lhs_value, rhs_value); + XlaOp arg_max = Select(cmp, lhs_index, rhs_index); + Tuple(b, {max, arg_max}); + return b->Build().ConsumeValueOrDie(); +} + XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + XlaOp value_init_value; + if (is_min) { + value_init_value = MaxValue(builder, input_shape.element_type()); + } else { + value_init_value = MinValue(builder, input_shape.element_type()); + } + int64 dimension_size = input_shape.dimensions(axis); + auto index_type = dimension_size <= INT32_MAX ? S32 : output_type; + XlaOp index_init_value = Zero(builder, index_type); + auto iota_shape = input_shape; + iota_shape.set_element_type(index_type); + XlaOp iota = Iota(builder, iota_shape, axis); + + XlaComputation reducer = CreateMinMaxComputation( + builder, input_shape.element_type(), index_type, is_min); + XlaOp max_argmax = Reduce(builder, {input, iota}, + {value_init_value, index_init_value}, reducer, + /*dimensions_to_reduce=*/{axis}); + XlaOp argmax = GetTupleElement(max_argmax, 1); + if (index_type != output_type) { + argmax = ConvertElementType(argmax, output_type); + } + return argmax; + }); +} + +XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis, + bool is_min) { + XlaBuilder* builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); XlaOp init_value; @@ -172,7 +224,6 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { /*dimensions_to_reduce=*/{axis}); }); } - } // namespace XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) { @@ -183,4 +234,11 @@ XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) { return ArgMinMax(input, output_type, axis, /*is_min=*/true); } +XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false); +} + +XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true); +} } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index d4a7812c441..6f64d587fa8 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ 
b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -60,10 +60,12 @@ XlaOp Any(XlaOp predicates); // Returns the argmax of `input` along `axis`. `output_type` is the type to // use for the output. XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); +XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis); // Returns the argmin of `input` along `axis`. `output_type` is the type to // use for the output. XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis); +XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc index c620c9841a5..11a79a262ef 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.cc +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -32,8 +32,7 @@ limitations under the License. namespace xla { namespace { -using XlaOpGenerator = XlaOp (*)(const XlaOp&, const XlaOp&, - absl::Span); +using XlaOpGenerator = XlaOp (*)(XlaOp, XlaOp, absl::Span); XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, int64 bit_width) { diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 20d3c0fc549..3d15101ea66 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -528,28 +528,149 @@ XlaOp Asin(XlaOp x) { XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); } -XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); } +XlaOp Tan(XlaOp x) { + return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); }); +} // Hyperbolic trigonometric functions. -// acosh(x) = log(x + sqrt(x^2 - 1)) +// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 // = log(x + sqrt((x+1)*(x-1))) +// acosh(x) = nan if x < -1 +// +// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as +// log(2*x) = log(2) + log(x). (Note this works because negative x never +// overflows; x < -1 simply yields nan. This is quite different than asinh!) XlaOp Acosh(XlaOp x) { - return Log(x + Sqrt((x + ScalarLike(x, 1.0)) * (x - ScalarLike(x, 1.0)))); + XlaBuilder* b = x.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); + + auto one = ScalarLike(x, 1); + auto neg_one = ScalarLike(x, -1); + auto nan = FullLike(x, std::numeric_limits::quiet_NaN()); + + // return + // + // nan if x < -1 + // log(x) + log(2) if x >= sqrt_max_value + // log(x + sqrt((x+1)*(x-1))) otherwise + // + // TODO(jlebar): For now, we ignore the question of overflow if x is a + // complex type, because we don't yet have exhaustive tests for complex trig + // functions. + auto naive_result = Log(x + Sqrt((x + one) * (x - one))); + if (primitive_util::IsComplexType(shape.element_type())) { + return naive_result; + } + auto overflow_result = Log(x) + Log(ScalarLike(x, 2)); + + auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type())); + return Select(Lt(x, neg_one), nan, + Select(Ge(x, sqrt_max_value), overflow_result, naive_result)); + }); } // asinh(x) = log(x + sqrt(x^2 + 1)) -XlaOp Asinh(XlaOp x) { return Log(x + Sqrt(x * x + ScalarLike(x, 1.0))); } +// +// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1) +// as 2*x and return log(2) + log(x). +// +// If x is negative, the above would give us some trouble; we can't approximate +// the result as x + abs(x) = 0! But we're saved by the fact that asinh(-x) = +// -asinh(x). 
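The overflow handling described in the acosh and asinh comments above is easiest to see in scalar form. The sketch below is plain standalone C++ (an assumption for illustration, not the XLA builder code), using the same sqrt(max-finite)/log(2) fallback.

```cpp
// Plain-C++ sketch of the overflow-safe acosh/asinh formulations above:
// once the argument is large enough that squaring it would overflow,
// sqrt(x^2 +/- 1) ~= x, so the result degrades gracefully to log(x) + log(2).
#include <cmath>
#include <cstdio>
#include <limits>

float SafeAcosh(float x) {
  const float sqrt_max = std::sqrt(std::numeric_limits<float>::max());
  if (x >= sqrt_max) return std::log(x) + std::log(2.0f);
  return std::log(x + std::sqrt((x + 1.0f) * (x - 1.0f)));  // NaN for |x| < 1
}

float SafeAsinh(float x) {
  const float a = std::fabs(x);
  const float sqrt_max = std::sqrt(std::numeric_limits<float>::max());
  const float y = (a >= sqrt_max) ? std::log(a) + std::log(2.0f)
                                  : std::log(a + std::sqrt(a * a + 1.0f));
  return std::copysign(y, x);  // asinh(-x) == -asinh(x)
}

int main() {
  // The naive formulas overflow here because 1e30f * 1e30f is inf in float.
  std::printf("acosh(1e30)  ~ %g\n", SafeAcosh(1e30f));    // ~69.77
  std::printf("asinh(-1e30) ~ %g\n", SafeAsinh(-1e30f));   // ~-69.77
  return 0;
}
```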
+XlaOp Asinh(XlaOp x) { + XlaBuilder* b = x.builder(); + auto do_it = [&](XlaOp x) -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); + auto one = ScalarLike(x, 1); -// atanh(x) = 0.5 * log((1 + x) / (1 - x)) -XlaOp Atanh(XlaOp x) { - return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) * - ScalarLike(x, 0.5); + // Let a = abs(x). Compute + // + // y = log(a + sqrt(a*a + 1)) if a < sqrt_max_value, or + // y = log(a) + log(2) otherwise + // + // and then return + // + // y * sign(x). + // + // TODO(jlebar): For now, we ignore the question of overflow if x is a + // complex type, because we don't yet have exhaustive tests for complex trig + // functions. + if (primitive_util::IsComplexType(shape.element_type())) { + return Log(x + Sqrt(x * x + one)); + } + auto a = Abs(x); + auto naive_result = Log(a + Sqrt(a * a + one)); + auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2)); + auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type())); + return Sign(x) * + Select(Ge(a, sqrt_max_value), overflow_result, naive_result); + }; + // These upcasts are not strictly necessary on all platforms to get within our + // error tolerances, so we could relax this if it ever mattered. + return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) { + return b->ReportErrorOrReturn(do_it(x)); + }); } -XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } +// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 +// atanh(x) = nan otherwise +XlaOp Atanh(XlaOp x) { + XlaBuilder* b = x.builder(); + auto do_it = [&](XlaOp x) -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); + auto naive_result = + Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) * + ScalarLike(x, 0.5); -XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } + // TODO(jlebar): For now, we ignore the nan edge case for complex inputs, + // because we don't yet have exhaustive tests for complex trig functions. + if (primitive_util::IsComplexType(shape.element_type())) { + return naive_result; + } + + auto nan = FullLike(x, std::numeric_limits::quiet_NaN()); + return Select(Gt(Abs(x), ScalarLike(x, 1)), nan, naive_result); + }; + return DoWithUpcastToF32(x, {BF16}, [&](XlaOp x) { // + return b->ReportErrorOrReturn(do_it(x)); + }); +} + +// Cosh(x) = (e^x + e^-x) / 2 +// = e^(x + log(1/2)) + e^(-x + log(1/2)). +// +// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not +// inf. +// +// This incorrectly overflows to inf for two f32 input values, namely +// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The +// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so +// we deem this acceptable. +XlaOp Cosh(XlaOp x) { + return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { + auto log_one_half = Log(ScalarLike(x, 0.5)); + return Exp(x + log_one_half) + Exp(-x + log_one_half); + }); +} + +// Sinh(x) = (e^x - e^-x) / 2 +// = e^(x + log(1/2)) - e^(-x + log(1/2)). +// +// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not +// inf. +// +// This incorrectly overflows to +/-inf for two f32 input values, namely +// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The +// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so +// we deem this acceptable. 
+XlaOp Sinh(XlaOp x) { + return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { + auto log_one_half = Log(ScalarLike(x, 0.5)); + return Exp(x + log_one_half) - Exp(-x + log_one_half); + }); +} XlaOp MaybeConjugate(XlaOp x, bool conjugate) { XlaBuilder* builder = x.builder(); @@ -639,4 +760,9 @@ XlaOp NextAfter(XlaOp from, XlaOp to) { }); } +XlaOp Logistic(XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + return half + half * xla::Tanh(half * x); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 71a3acedcec..89a58aa3970 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -101,6 +101,9 @@ XlaOp Sinh(XlaOp x); // is true, otherwise returns its argument. xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); +// Computes the logistic function: logistic(x) = 0.5 + 0.5 * tanh(0.5 * x). +XlaOp Logistic(XlaOp x); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 50613ce5025..d0429990f87 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -180,7 +180,7 @@ XLA_TEST_F(MathTest, RealFpOnlyOps) { shape = ShapeUtil::MakeShape(ty, {42}); } else if (ty == PrimitiveType::TUPLE) { shape = ShapeUtil::MakeTupleShape({}); - } else if (ty == PrimitiveType::OPAQUE) { + } else if (ty == PrimitiveType::OPAQUE_TYPE) { shape = ShapeUtil::MakeOpaqueShape(); } else if (ty == PrimitiveType::TOKEN) { shape = ShapeUtil::MakeTokenShape(); diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index a055a8e625c..93f3d3ab131 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -253,28 +253,86 @@ xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, } XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + return BatchDot(x, false, y, false, precision); +} + +XlaOp BatchDot(XlaOp x, bool transpose_x, XlaOp y, bool transpose_y, + PrecisionConfig::Precision precision) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - const int ndims = x_shape.rank(); - batch_dimension_numbers.reserve(ndims - 2); + // The batch dimensions must be broadcast-compatible and the matrix + // dimensions must be valid. + std::vector x_config; + std::vector y_config; + std::vector output_config; + + std::vector x_implicit_broadcast; + std::vector y_implicit_broadcast; + + const int64 ndims = std::max(y_shape.rank(), x_shape.rank()); + // If X and Y have unequal ranks, the major dimensions of the higher rank + // shape are broadcasted. + // + // A dimension of size 1 can be implicitly broadcasted to any other + // dimension. 
+ const int64 x_offset = std::max(0, y_shape.rank() - x_shape.rank()); + const int64 y_offset = std::max(0, x_shape.rank() - y_shape.rank()); for (int i = 0; i < ndims - 2; ++i) { - batch_dimension_numbers.push_back(i); + const int64 x_dim = i - x_offset; + const int64 y_dim = i - y_offset; + output_config.push_back(i); + if (x_dim < 0) { + y_config.push_back(i); + } else if (y_dim < 0) { + x_config.push_back(i); + } else if (x_shape.dimensions(x_dim) == y_shape.dimensions(y_dim)) { + y_config.push_back(i); + x_config.push_back(i); + } else if (x_shape.dimensions(x_dim) == 1) { + y_config.push_back(i); + x_implicit_broadcast.push_back(x_dim); + } else if (y_shape.dimensions(y_dim) == 1) { + x_config.push_back(i); + y_implicit_broadcast.push_back(y_dim); + } else { + return InvalidArgument("Expected batch dot dimension to be equal or 1"); + } + } + if (transpose_x) { + x_config.push_back(ndims); + x_config.push_back(ndims - 2); + } else { + x_config.push_back(ndims - 2); + x_config.push_back(ndims); + } + if (transpose_y) { + y_config.push_back(ndims - 1); + y_config.push_back(ndims); + } else { + y_config.push_back(ndims); + y_config.push_back(ndims - 1); } - std::vector x_config = batch_dimension_numbers; - x_config.push_back(ndims - 2); - x_config.push_back(ndims); - std::vector y_config = batch_dimension_numbers; - y_config.push_back(ndims); - y_config.push_back(ndims - 1); - std::vector output_config = batch_dimension_numbers; output_config.push_back(ndims - 2); output_config.push_back(ndims - 1); + if (!x_implicit_broadcast.empty()) { + x_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(x_implicit_broadcast, dim); + }, + x_shape); + x = Reshape(x, x_shape.dimensions()); + } + if (!y_implicit_broadcast.empty()) { + y_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(y_implicit_broadcast, dim); + }, + y_shape); + y = Reshape(y, y_shape.dimensions()); + } return Einsum(x, x_config, y, y_config, output_config, precision); }); } diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 60c41ec45a0..5f1ca964a41 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ #include + #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -73,6 +74,9 @@ XlaOp LowerTriangle(XlaOp x); xla::XlaOp BatchDot( xla::XlaOp x, xla::XlaOp y, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); +xla::XlaOp BatchDot( + xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); // Parse an einsum string into dimension numbers: // "ab,cb->ac" diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 63b3b07ddc2..77ebb75b051 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -24,6 +24,15 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" namespace xla { + +xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, + absl::Span scalars) { + std::vector vectors; + absl::c_transform(scalars, std::back_inserter(vectors), + [](xla::XlaOp x) { return xla::Reshape(x, {1}); }); + return ConcatInDim(builder, vectors, 0); +} + namespace { // Rotates a 32-bit integer 'v' left by 'distance' bits. @@ -32,8 +41,12 @@ XlaOp RotateLeftU32(XlaOp v, int distance) { ShiftRightLogical(v, ConstantR0(v.builder(), 32 - distance)); } -} // namespace +// The internal state of the Three Fry implementation. +using ThreeFry2x32State = std::array; +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { XlaBuilder* builder = input[0].builder(); key[0] = BitcastConvertType(key[0], U32); @@ -104,56 +117,277 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { return x; } -// Returns the inputs with unique counter values for ThreeFry2x32. -ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) { - ThreeFry2x32State inputs; - inputs[0] = Iota(builder, U32, size); - inputs[1] = inputs[0] + ConstantR0(builder, size); - return inputs; -} - -XlaOp StatelessRngUniformU32(std::array key, const Shape& shape) { - XlaBuilder* builder = key[0].builder(); - const int64 size = ShapeUtil::ElementsIn(shape); - const int64 half_size = CeilOfRatio(size, 2); - const bool size_is_odd = (half_size * 2 != size); - ThreeFry2x32State inputs = GetInputs(half_size, builder); - ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); - if (size_is_odd) { - outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); - } - auto result = ConcatInDim(builder, outputs, 0); - return Reshape(result, AsInt64Slice(shape.dimensions())); -} - -ThreeFry2x32State Uint64ToUint32s(XlaOp u64) { - auto builder = u64.builder(); - auto const32 = ConstantR0WithType(builder, U64, 32); - auto fst = ConvertElementType(u64, U32); - auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32); +// Converts a uint64 to two uint32s. +std::array Uint64ToUint32s(XlaOp u64) { + XlaBuilder* builder = u64.builder(); + XlaOp const32 = ConstantR0WithType(builder, U64, 32); + XlaOp fst = ConvertElementType(u64, U32); + XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32); return {fst, snd}; } -XlaOp Uint32sToUint64(ThreeFry2x32State u32s) { - auto builder = u32s[0].builder(); +// Converts two uint32s to a uint64. +XlaOp Uint32sToUint64(std::array u32s) { + XlaBuilder* builder = u32s[0].builder(); return ConvertElementType(u32s[0], U64) | ShiftLeft(ConvertElementType(u32s[1], U64), ConstantR0WithType(builder, U64, 32)); } -XlaOp StatelessRngUniformU64(std::array key, const Shape& shape) { - XlaBuilder* builder = key[0].builder(); - const int64 size = ShapeUtil::ElementsIn(shape); - ThreeFry2x32State inputs = GetInputs(size, builder); - ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); - // low 32 bit: outputs[0], high 32 bit: outputs[1] - auto result = Uint32sToUint64(outputs); - return Reshape(result, AsInt64Slice(shape.dimensions())); +// Given the initial state and the request number of random numbers to be +// generated, returns the input for the random number generator and a new state. 
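The counter/state bookkeeping used by `GetThreeFryInputsAndUpdatedState` just below, together with the 64/32-bit packing helpers above, reduces to a few lines of scalar arithmetic. The sketch is plain standalone C++ for illustration.

```cpp
// Plain-C++ sketch: ThreeFry consumes 64-bit counters state + [0 .. n-1],
// splits each counter into a (low, high) 32-bit pair, and advances the state
// by n so the next request continues the sequence.
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

std::array<uint32_t, 2> Uint64ToUint32s(uint64_t u64) {
  return {static_cast<uint32_t>(u64), static_cast<uint32_t>(u64 >> 32)};
}

uint64_t Uint32sToUint64(const std::array<uint32_t, 2>& u32s) {
  return static_cast<uint64_t>(u32s[0]) |
         (static_cast<uint64_t>(u32s[1]) << 32);
}

int main() {
  uint64_t state = 0xFFFFFFFEull;  // arbitrary example state
  const int64_t n = 3;

  std::vector<std::array<uint32_t, 2>> counters;
  for (int64_t i = 0; i < n; ++i) counters.push_back(Uint64ToUint32s(state + i));
  const uint64_t new_state = state + n;

  for (const auto& c : counters)
    std::printf("counter lo=%08x hi=%08x\n", c[0], c[1]);  // note the carry
  std::printf("new_state=%llx round_trip_ok=%d\n",
              static_cast<unsigned long long>(new_state),
              Uint32sToUint64(Uint64ToUint32s(state)) == state);
  return 0;
}
```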
+std::pair GetThreeFryInputsAndUpdatedState( + XlaOp initial_state, const int64 size) { + XlaBuilder* builder = initial_state.builder(); + XlaOp input_u64 = Iota(builder, U64, size); + input_u64 = input_u64 + initial_state; + XlaOp new_state = initial_state + ConstantR0(builder, size); + return std::make_pair(Uint64ToUint32s(input_u64), new_state); } -XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) { - XlaBuilder* builder = bits.builder(); +// Generates random 32bits with the given shape using the Three Fry +// implementation. Returns the random bits and the new state. +RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) { + XlaBuilder* builder = key.builder(); + const int64 size = ShapeUtil::ElementsIn(shape); + const int64 half_size = CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + std::pair inputs_state = + GetThreeFryInputsAndUpdatedState(initial_state, half_size); + ThreeFry2x32State inputs = inputs_state.first; + ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); + if (size_is_odd) { + outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + XlaOp result = ConcatInDim(builder, outputs, 0); + return {Reshape(result, AsInt64Slice(shape.dimensions())), + inputs_state.second}; +} +// Generates random 64bits with the given shape using the Three Fry +// implementation. Returns the random bits and the new state. +RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) { + const int64 size = ShapeUtil::ElementsIn(shape); + std::pair inputs_state = + GetThreeFryInputsAndUpdatedState(initial_state, size); + ThreeFry2x32State inputs = inputs_state.first; + ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); + XlaOp result = Uint32sToUint64(outputs); + return {Reshape(result, AsInt64Slice(shape.dimensions())), + inputs_state.second}; +} + +// The key of the Philox random number generator. +using Philox4x32Key = std::array; +// The internal state of the Philox random number generator. +using Philox4x32State = std::array; + +// Computes the Philox4x32 algorithm using 10 rounds. +Philox4x32State Philox4x32(Philox4x32State state, Philox4x32Key key) { + // Constants specified by the Philox algorithm. + static const uint32 kPhiloxW32A = 0x9E3779B9; + static const uint32 kPhiloxW32B = 0xBB67AE85; + static const uint32 kPhiloxM4x32A = 0xD2511F53; + static const uint32 kPhiloxM4x32B = 0xCD9E8D57; + + struct HighLowPair { + XlaOp high; + XlaOp low; + }; + + // Compute the high and low words from multiplying two 32-bit integers. + auto mul_hi_low = [](XlaOp x, uint32 k) { + auto product = + ConvertElementType(x, U64) * ConstantR0(x.builder(), k); + auto low = ConvertElementType(product, U32); + auto high = + ConvertElementType(product >> ConstantR0(x.builder(), 32), U32); + return HighLowPair{high, low}; + }; + + // Perform a single round of the Philox algorithm. + auto philox_round = [&](Philox4x32State x, Philox4x32Key key) { + auto product0 = mul_hi_low(x[0], kPhiloxM4x32A); + auto product1 = mul_hi_low(x[2], kPhiloxM4x32B); + return Philox4x32State{product1.high ^ x[1] ^ key[0], product1.low, + product0.high ^ x[3] ^ key[1], product0.low}; + }; + + // Update the key after a round of Philox algorithm. 
+ auto raise_key = [](Philox4x32Key key) { + XlaBuilder* builder = key[0].builder(); + return Philox4x32Key{key[0] + ConstantR0(builder, kPhiloxW32A), + key[1] + ConstantR0(builder, kPhiloxW32B)}; + }; + + static const int kNumRounds = 10; + for (int round = 0; round < kNumRounds; ++round, key = raise_key(key)) { + state = philox_round(state, key); + } + return state; +} + +// Scrambles the input key so that users don't need to worry about which part +// of the key needs to be strong. +std::pair ScramblePhiloxKey(Philox4x32Key key) { + XlaBuilder* builder = key[0].builder(); + XlaOp key0 = ConvertElementType(key[0], U64); + XlaOp key1 = ConvertElementType(key[1], U64); + + Philox4x32State state = { + ConvertElementType(key0, U32), + ConvertElementType(key0 >> ScalarLike(key0, 32), U32), + ConvertElementType(key1, U32), + ConvertElementType(key1 >> ScalarLike(key1, 32), U32), + }; + key = {ConstantR0(builder, 0x3ec8f720), + ConstantR0(builder, 0x02461e29)}; + state = Philox4x32(state, key); + XlaOp zero = ConstantR0(builder, 0); + return {Philox4x32State{zero, zero, state[2], state[3]}, + Philox4x32Key{state[0], state[1]}}; +} + +// Adds an U128 tensor with an U64 tensor. The U128 tensor is represented as two +// U64s with the low 64bits in the front. This routine supports explicit +// broadcasting of the U128 tensor, with `broadcast_sizes` representing the +// dimensions prepended to its shape. +std::array Uint128AddUint64( + const std::array& u128, XlaOp u64, + absl::Span broadcast_sizes = {}) { + auto u128_low = u128[0]; + auto u128_high = u128[1]; + XlaOp new_u128_low = u128_low + u64; + XlaOp one = ConstantR0(u128[0].builder(), 1); + XlaOp new_u128_high = Select(Lt(new_u128_low, u128_low), + Broadcast(u128_high + one, broadcast_sizes), + Broadcast(u128_high, broadcast_sizes)); + return {new_u128_low, new_u128_high}; +} + +std::array Uint32sToUint128(const std::array& u32s) { + return {Uint32sToUint64({u32s[0], u32s[1]}), + Uint32sToUint64({u32s[2], u32s[3]})}; +} + +std::array Uint128ToUint32s(const std::array& u128) { + std::array u128_low_32s = Uint64ToUint32s(u128[0]); + std::array u128_high_32s = Uint64ToUint32s(u128[1]); + return {u128_low_32s[0], u128_low_32s[1], u128_high_32s[0], u128_high_32s[1]}; +} + +std::array Uint128FromOp(XlaOp op) { + auto u128_low = xla::Reshape(xla::Slice(op, {0}, {1}, {1}), {}); + auto u128_high = xla::Reshape(xla::Slice(op, {1}, {2}, {1}), {}); + return {u128_low, u128_high}; +} + +XlaOp Uint128ToOp(std::array u128) { + return ConcatScalars(u128[0].builder(), {u128[0], u128[1]}); +} + +// Returns the pair (state + [0, 1, ..., n-1], state + n), which should be used +// as the inputs fed to `Philox4x32` and the updated state. `state` is an U128 +// represented as 4 U32s in the order from the least significant one to the most +// significant one. +std::pair GetPhiloxInputsAndUpdatedState( + const Philox4x32State& state, int64 n) { + XlaBuilder* builder = state[0].builder(); + XlaOp iota = Iota(builder, U64, n); + auto state_u128 = Uint32sToUint128(state); + auto inputs = Uint128ToUint32s(Uint128AddUint64(state_u128, iota, {n})); + XlaOp new_state = + Uint128ToOp(Uint128AddUint64(state_u128, ConstantR0(builder, n))); + return std::make_pair(inputs, new_state); +} + +// Generates CeilOfRatio(num_elems, 4)*4 32bit Philox random numbers, as Philox +// numbers are generated in the unit of 128bits. 
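For reference, the Philox4x32 round structure defined above (constants `kPhiloxM4x32A/B` and `kPhiloxW32A/B`, ten rounds, key raised between rounds) looks like this as plain standalone C++ on scalars; this illustrates the arithmetic only, not the XLA lowering.

```cpp
// Plain-C++ scalar reference of the Philox4x32-10 rounds above.
#include <array>
#include <cstdint>
#include <cstdio>

using State4 = std::array<uint32_t, 4>;
using Key2 = std::array<uint32_t, 2>;

State4 Philox4x32_10(State4 x, Key2 key) {
  constexpr uint32_t kW32A = 0x9E3779B9, kW32B = 0xBB67AE85;
  constexpr uint32_t kM4x32A = 0xD2511F53, kM4x32B = 0xCD9E8D57;
  for (int round = 0; round < 10; ++round) {
    // mul_hi_low: full 64-bit products, split into high/low 32-bit words.
    const uint64_t p0 = static_cast<uint64_t>(kM4x32A) * x[0];
    const uint64_t p1 = static_cast<uint64_t>(kM4x32B) * x[2];
    // philox_round: mix the halves with the other counter words and the key.
    x = {static_cast<uint32_t>(p1 >> 32) ^ x[1] ^ key[0],
         static_cast<uint32_t>(p1),
         static_cast<uint32_t>(p0 >> 32) ^ x[3] ^ key[1],
         static_cast<uint32_t>(p0)};
    // raise_key between rounds.
    key = {key[0] + kW32A, key[1] + kW32B};
  }
  return x;
}

int main() {
  const State4 out = Philox4x32_10({0, 0, 0, 0}, {0, 0});
  std::printf("%08x %08x %08x %08x\n", out[0], out[1], out[2], out[3]);
  return 0;
}
```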
+std::pair GeneratePhiloxBits(int64 num_elems, + XlaOp initial_state, + Philox4x32Key key, + bool scramble) { + Philox4x32State state; + if (scramble) { + // When `scramble` is true, `initial_state` is not used. This is because + // scramble is true only when this function is called by stateless random + // ops, for which `initial_state` is always zero. + std::tie(state, key) = ScramblePhiloxKey(key); + } else { + state = Uint128ToUint32s(Uint128FromOp(initial_state)); + } + const int64 num_vector4 = CeilOfRatio(num_elems, 4); + Philox4x32State inputs; + XlaOp new_state; + std::tie(inputs, new_state) = + GetPhiloxInputsAndUpdatedState(state, num_vector4); + auto outputs = Philox4x32(inputs, key); + return std::make_pair(outputs, new_state); +} + +// Generates an array of primitive type U32 with the given shape containing +// random bits generated by the Philox algorithm. Returns the array and the new +// state of the random number generator. +RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, const Shape& shape, + bool scramble) { + XlaBuilder* builder = op_key.builder(); + const int64 num_elems = ShapeUtil::ElementsIn(shape); + + Philox4x32Key key = Uint64ToUint32s(op_key); + Philox4x32State bits; + XlaOp new_state; + std::tie(bits, new_state) = + GeneratePhiloxBits(num_elems, initial_state, key, scramble); + // Combining bits[i] in a round-robin fashion, to align with non-XLA + // implementations + int64 bits_len = (num_elems + 3) / 4; + for (auto i = 0; i < 4; ++i) { + bits[i] = Reshape(bits[i], {bits_len, 1}); + } + XlaOp numbers = ConcatInDim(builder, {bits[0], bits[1], bits[2], bits[3]}, + /*dimension=*/1); + numbers = Reshape(numbers, {bits_len * 4}); + numbers = Slice(numbers, /*start_indices=*/{0}, + /*limit_indices=*/{num_elems}, + /*strides=*/{1}); + return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state}; +} + +// Generates an array of primitive type U64 with the given shape containing +// random bits generated by the Philox algorithm. Returns the array and the new +// state of the random number generator. +RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, const Shape& shape, + bool scramble) { + XlaBuilder* builder = op_key.builder(); + const int64 num_elems = ShapeUtil::ElementsIn(shape); + + Philox4x32Key key = Uint64ToUint32s(op_key); + Philox4x32State bits32; + XlaOp new_state; + std::tie(bits32, new_state) = + GeneratePhiloxBits(num_elems * 2, initial_state, key, scramble); + + std::array bits64; + bits64[0] = Uint32sToUint64({bits32[0], bits32[1]}); + bits64[1] = Uint32sToUint64({bits32[2], bits32[3]}); + + // Combining bits64[i] in a round-robin fashion, to align with non-XLA + // implementations + int64 bits64_len = (num_elems + 1) / 2; + for (auto i = 0; i < 2; ++i) { + bits64[i] = Reshape(bits64[i], {bits64_len, 1}); + } + XlaOp numbers = ConcatInDim(builder, {bits64[0], bits64[1]}, + /*dimension=*/1); + numbers = Reshape(numbers, {bits64_len * 2}); + numbers = Slice(numbers, /*start_indices=*/{0}, + /*limit_indices=*/{num_elems}, + /*strides=*/{1}); + return {Reshape(numbers, AsInt64Slice(shape.dimensions())), new_state}; +} + +XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) { + XlaBuilder* builder = bits.builder(); // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit // forces the random bits into the mantissa. 
constexpr int kFloatBits = 32; @@ -161,50 +395,139 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) { bits = ShiftRightLogical( bits, ConstantR0(builder, kFloatBits - kMantissaBits)) | ConstantR0(builder, absl::bit_cast(1.0f)); - auto floats = BitcastConvertType(bits, F32); + XlaOp values = BitcastConvertType(bits, F32); // We have a floating point number in the range [1.0, 2.0). // Subtract 1.0f to shift to the range [0.0, 1.0) - floats = floats - ConstantR0(builder, 1.0f); + values = values - ConstantR0(builder, 1.0f); // Multiply and add to shift to the range [minval, maxval). - return floats * (maxval - minval) + minval; + return values * (maxval - minval) + minval; } -XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, - PrimitiveType type, PrimitiveType unsigned_type) { +XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, + PrimitiveType type, + PrimitiveType unsigned_type) { XlaBuilder* builder = bits.builder(); - auto range = BitcastConvertType(maxval, unsigned_type) - - BitcastConvertType(minval, unsigned_type); - auto dist = Rem(bits, range); - auto dist_div_2 = + XlaOp range = BitcastConvertType(maxval, unsigned_type) - + BitcastConvertType(minval, unsigned_type); + XlaOp dist = Rem(bits, range); + XlaOp dist_div_2 = ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1)); return minval + BitcastConvertType(dist_div_2, type) + BitcastConvertType(dist - dist_div_2, type); } -XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, - XlaOp minval, XlaOp maxval) { - XlaBuilder* builder = seeds[0].builder(); +// Implements the Box-Muller transform, which converts random floats in the +// range of [0, 1] from uniform distribution to normal distribution with mean 0 +// and variance 1. For more detail on the Box-Muller transform, see +// http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form +std::pair BoxMullerTransform(XlaOp x0, XlaOp x1) { + // Do not send a really small number to log(). 
+ XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f)); + + XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1; + XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1)); + return {Sin(v1) * u2, Cos(v1) * u2}; +} + +} // namespace + +RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, + const Shape& shape) { PrimitiveType type = shape.element_type(); switch (type) { - case F32: { - auto bits = StatelessRngUniformU32(seeds, shape); - return StatelessRngUniformF32(bits, minval, maxval); - } - case S32: { - auto bits = StatelessRngUniformU32(seeds, shape); - return StatelessRngUniformInt(bits, minval, maxval, type, U32); - } - case S64: { - auto bits = StatelessRngUniformU64(seeds, shape); - return StatelessRngUniformInt(bits, minval, maxval, type, U64); - } + case F32: + case U32: + case S32: + return ThreeFryRngBit32(key, initial_state, shape); + case U64: + case S64: + return ThreeFryRngBit64(key, initial_state, shape); default: - return builder->ReportError(Unimplemented( - "Types other than F32, S32 and S64 are not implemented by " - "StatelessRngUniform.")); + return {key.builder()->ReportError(Unimplemented( + "Types other than F32, U32, S32, U64 and S64 " + "are not implemented by ThreeFryBitGenerator; got %s", + primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; } } +RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape, + bool scramble) { + PrimitiveType type = shape.element_type(); + switch (type) { + case F32: + case U32: + case S32: + return PhiloxRngBit32(key, initial_state, shape, scramble); + case U64: + case S64: + return PhiloxRngBit64(key, initial_state, shape, scramble); + default: + return {key.builder()->ReportError(Unimplemented( + "Types other than F32, U32, S32, U64 and S64 " + "are not implemented by ThreeFryBitGenerator; got %s", + primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; + } +} + +RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const Shape& shape) { + DCHECK_EQ(shape.element_type(), F32); + RngOutput bits_state = bit_generator(key, initial_state, shape); + XlaOp bits = bits_state.value; + XlaOp new_state = bits_state.state; + return {ConvertRandomBitsToUniformF32(bits, minval, maxval), new_state}; +} + +RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const Shape& shape) { + RngOutput bits_state = bit_generator(key, initial_state, shape); + XlaOp bits = bits_state.value; + XlaOp new_state = bits_state.state; + PrimitiveType type = shape.element_type(); + PrimitiveType unsigned_type; + if (type == U32 || type == S32) { + unsigned_type = U32; + } else { + DCHECK(type == U64 || type == S64); + unsigned_type = U64; + } + return { + ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type), + new_state}; +} + +RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, + const Shape& shape) { + DCHECK_EQ(shape.element_type(), F32); + XlaBuilder* builder = key.builder(); + const int64 num_elems = ShapeUtil::ElementsIn(shape); + const int64 num_pairs = CeilOfRatio(num_elems, 2); + RngOutput bits_state = UniformF32Distribution( + key, initial_state, bit_generator, ConstantR0(builder, 0.0), + ConstantR0(builder, 1.0), + ShapeUtil::MakeShape(F32, {num_pairs * 2})); + + // Separate the bits into two groups to perform the Box-Muller transform. 
+ XlaOp bits_0 = Slice(bits_state.value, {0}, {num_pairs}, {1}); + XlaOp bits_1 = Slice(bits_state.value, {num_pairs}, {2 * num_pairs}, {1}); + std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1); + + // Put the numbers in the two groups back to form the requested shape. + XlaOp normal = ConcatInDim(builder, {bits_0, bits_1}, /*dimension=*/0); + if (num_elems != num_pairs * 2) { + normal = Slice(normal, /*start_indices=*/{0}, /*limit_indices=*/{num_elems}, + /*strides=*/{1}); + } + normal = Reshape(normal, shape.dimensions()); + + return {normal, bits_state.state}; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index 7b0b4c2439e..fcd1dbfb919 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -23,37 +23,69 @@ limitations under the License. namespace xla { +// Records the bits and state generated by a random number generator. +struct RngOutput { + XlaOp value; + XlaOp state; +}; + +// A BitGenerator returns random bits and updated random bit generator state. +// +// key: is a value input to a random number generator that can affect the +// sequence of number it will generate. A random number generator constructs +// its seed using the key and the initial state. The tf2xla bridge passes the +// seed operand of a tensorflow random operation as a key to the random bit +// generator, for example. +// initial_state: initial_state is the initial state of the current random +// number generation. It could be 0 for a stateless random operation, and +// the returned state from a previous execution for a stateful random +// operation. +// shape: the shape of the random bits. +using BitGeneratorTy = std::function; + // Implements the ThreeFry counter-based PRNG algorithm. // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -using ThreeFry2x32State = std::array; -ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key); +RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, + const xla::Shape& shape); -// Returns a tensor containing 'shape' random values uniformly distributed in -// the range [minval, maxval). Requires 2 32-bit integer seeds. -// Currently only 'shape's of type F32, S32 and S64 are implemented. -XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, - XlaOp minval, XlaOp maxval); +// Implements the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +// +// The paper presents a few variants of the Philox algorithm, we picked the +// 4x32_10 version of the algorithm for the following reasons: +// . 4x32 uses 32-bit multiplication which is fast on GPUs. +// . The authors recommend the 10-round variant, and TensorFlow also uses it. +// 'scramble` controls whether to scramble 'key' and 'initial_state' to form +// the actual key and state fed to the Philox algorithm. +RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape, + bool scramble); -// Converts a 32-bit (signed or unsigned) integer random number `bits` into a -// float32 in the range [minval, maxval). 
-XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval); +// Uses the given bit generator to generate random bits and then converts the +// random bits to random numbers of uniform distribution in the given range. +// Returns the random numbers and the state of the random number generator. +// This function is for shape with float element type. +RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const xla::Shape& shape); -// Converts an integer random number 'bits' of type 'type' to a random number -// in the range [minval, maxval), of the same type. 'unsigned_type' is the -// unsigned version of 'type' (could be the same) with the same bit width. -// The algorithm is the same one that TF uses right now, but it's -// uniform only when maxval - minval is a divisor of the range that bits is -// generated from. -// TODO(b/72573764): Generate real uniform integer distribution. -XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, - PrimitiveType type, PrimitiveType unsigned_type); +// Similar to UniformF32Distribution but for shape with integer element types. +RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const xla::Shape& shape); -// The following 2 functions, for converting between one uint64 and two uint32s, -// use the contract "lower 32 bits for the first uint32, higher 32 bits for the -// second". -ThreeFry2x32State Uint64ToUint32s(XlaOp u64); -XlaOp Uint32sToUint64(ThreeFry2x32State u32s); +// Uses the given bit generator to generate random bits and then converts the +// random bits to random numbers of normal distribution. +// Returns the random numbers and the state of the random number generator. +RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, + const xla::Shape& shape); + +// Concatenates scalars into a vector. 
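A hypothetical call site for the new interface declared above might look like the sketch below; the function name, seed value, and shape are invented for illustration, while the signatures are the ones from this header.

```cpp
// Hypothetical usage sketch: draw uniforms with a ThreeFry-backed generator
// and thread the returned state into the next draw.
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

void BuildRngExample(XlaBuilder* b) {
  XlaOp key = ConstantR0<uint64>(b, 42);    // e.g. a user-provided seed
  XlaOp state = ConstantR0<uint64>(b, 0);   // fresh counter state
  const Shape shape = ShapeUtil::MakeShape(F32, {128});

  RngOutput uniform = UniformF32Distribution(
      key, state, ThreeFryBitGenerator, ConstantR0<float>(b, 0.0f),
      ConstantR0<float>(b, 1.0f), shape);
  // uniform.value holds the samples; uniform.state seeds the next draw.
  RngOutput normal =
      NormalF32Distribution(key, uniform.state, ThreeFryBitGenerator, shape);
  (void)normal;
}

}  // namespace xla
```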
+xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, + absl::Span scalars); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc index 640412ec8bc..5a7c826c389 100644 --- a/tensorflow/compiler/xla/client/lib/qr.cc +++ b/tensorflow/compiler/xla/client/lib/qr.cc @@ -101,7 +101,7 @@ Status House(XlaOp x, XlaOp k, absl::Span batch_dims, auto sigma_is_zero = Eq(sigma, zero); - *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu); + *beta = Select(sigma_is_zero, alpha, Select(Lt(alpha, zero), one, -one) * mu); *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims), (*beta - alpha) / *beta); auto divisor = @@ -192,7 +192,7 @@ StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) auto vva = BatchDot(v_broadcast, a, precision); - vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); + vva = BatchDot(v_broadcast, true, vva, false, precision); a = a - Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -271,7 +271,7 @@ StatusOr ComputeWYRepresentation(PrimitiveType type, auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(TransposeInMinorDims(y), v, precision); + auto yv = BatchDot(y, true, v, false, precision); // wyv has shape [..., m, 1] auto wyv = BatchDot(w, yv, precision); @@ -365,7 +365,7 @@ StatusOr QRDecomposition( // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = BatchDot(TransposeInMinorDims(w), a_panel, precision); + auto a_update = BatchDot(w, true, a_panel, false, precision); a_update = BatchDot(y, a_update, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); @@ -373,7 +373,7 @@ StatusOr QRDecomposition( // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); auto q_update = BatchDot(q_panel, w, precision); - q_update = BatchDot(q_update, TransposeInMinorDims(y), precision); + q_update = BatchDot(q_update, false, y, true, precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc index b27d364b624..a61f243e126 100644 --- a/tensorflow/compiler/xla/client/lib/qr_test.cc +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -60,6 +60,33 @@ XLA_TEST_F(QrTest, Simple) { xla::ErrorSpec(1e-4, 1e-4)); } +XLA_TEST_F(QrTest, ZeroDiagonal) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {0, 1, 1}, + {1, 0, 1}, + {1, 1, 0}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/8)); + + // Verifies that the decomposition composes back to the original matrix. + // + // This isn't a terribly demanding test, (e.g., we should verify that Q is + // orthonormal and R is upper-triangular) but it's awkward to write such tests + // without more linear algebra libraries. It's easier to test the numerics + // from Python, anyway, where we have access to numpy and scipy. 
+ xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR2(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + XLA_TEST_F(QrTest, SimpleBatched) { xla::XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc index c8875dff7bf..99bec8a9ab5 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -125,7 +125,9 @@ class SelfAdjointEigTest : public ClientLibraryTestBase { Array2D GenerateRandomSymmetricMatrix(int size) { Array2D result{size, size, 0.0}; - result.FillRandom(10 /* stddev */, 2 /* mean */); + // TODO(b/128001705): This seed should not be needed but makes the test + // avoid inputs which trigger numerical instability. + result.FillRandom(10 /* stddev */, 2 /* mean */, 12346 /* seed */); for (int i = 0; i < size; ++i) { for (int j = 0; j < i; ++j) { result({j, i}) = result({i, j}); diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 0878cbeaf9a..32de252ba1d 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/slicing.h" + #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -161,22 +163,45 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) { }); } -XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim) { +XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) { XlaBuilder* builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); + if (dim < batch_dims) { + return InvalidArgument( + "Gather dim must be greater than or equal to the number of batch " + "dims"); + } std::vector slice_sizes = input_shape.dimensions(); - slice_sizes[dim] = 1; GatherDimensionNumbers gather_dnums; + gather_dnums.set_index_vector_dim(index_shape.rank()); + if (batch_dims > 0) { + ShapeUtil::AppendMajorDimension(1, &index_shape); + std::vector to_concat; + to_concat.reserve(batch_dims + 1); + for (int64 batch_dim = 0; batch_dim < batch_dims; ++batch_dim) { + to_concat.push_back(Iota(builder, index_shape, batch_dim)); + } + to_concat.push_back(Reshape(index, index_shape.dimensions())); + index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim()); + } for (int64 i = 0; i < input_shape.rank(); ++i) { - if (i != dim) { - gather_dnums.add_offset_dims(i); + if (i < batch_dims || i == dim) { + if (slice_sizes[i] != 0) { + slice_sizes[i] = 1; + gather_dnums.add_collapsed_slice_dims(i); + } + gather_dnums.add_start_index_map(i); + } else { + if (i < dim) { + gather_dnums.add_offset_dims(i); + } else { + gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() - + (1 + batch_dims)); + } } } - gather_dnums.set_index_vector_dim(index_shape.rank()); - gather_dnums.add_collapsed_slice_dims(dim); - gather_dnums.add_start_index_map(dim); return Gather(input, index, gather_dnums, slice_sizes); }); } diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 
bb6191df7c4..89ec1fe510e 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -63,7 +63,11 @@ XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim); // The returned tensor has the same number of dimensions as the original tensor // (input). The dimth dimension has the same size as the length of index; other // dimensions have the same size as in the original tensor. -XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim); +// +// This operation supports 0 or more major batch dimensions that act like a +// multidimensional loop over both the input and the index. +XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, + int64 batch_dims = 0); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 408a82ca3c6..04d3f96b6a5 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -146,6 +146,7 @@ XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) { 0, "input", &builder, &input); auto index_data = CreateR1Parameter({0, 2}, 1, "index", &builder, &index); + TorchIndexSelect(input, index, 1); ComputeAndCompareR2( @@ -153,5 +154,35 @@ XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) { {input_data.get(), index_data.get()}); } +XLA_TEST_F(SlicingTest, EmptyIndexSelect) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR2Parameter({{0}, {0}, {0}}, 0, "input", &builder, &input); + auto index_data = CreateR1Parameter({}, 1, "index", &builder, &index); + TorchIndexSelect(input, index, 1); + ComputeAndCompareR2(&builder, {{}, {}, {}}, + {input_data.get(), index_data.get()}); +} + +XLA_TEST_F(SlicingTest, BatchTorchIndexSelectOn0) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR3Parameter({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}}, + {{3, 2, 1, 0}, {7, 6, 5, 4}, {11, 10, 9, 8}}}, + 0, "input", &builder, &input); + auto index_data = + CreateR2Parameter({{0, 2}, {1, 2}}, 1, "index", &builder, &index); + TorchIndexSelect(input, index, 1, 1); + + ComputeAndCompareR3( + &builder, + {{{0, 1, 2, 3}, {8, 9, 10, 11}}, {{7, 6, 5, 4}, {11, 10, 9, 8}}}, + {input_data.get(), index_data.get()}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index ddc39f4d874..49b3a4f109a 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -31,11 +31,9 @@ XlaOp TopK(XlaOp input, int64 k) { ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); auto input_dims = input_shape.dimensions(); - // TODO(b/122298745): Get rid of Neg() and use CreateScalarGtComputation - // once the TPU backend supports the comparison computations. 
XlaOp sort_result = - Sort({Neg(input), iota_s32}, - CreateScalarLtComputation({input_shape.element_type(), S32}, + Sort({input, iota_s32}, + CreateScalarGtComputation({input_shape.element_type(), S32}, iota_s32.builder()), last_dim, /*is_stable=*/true); std::vector start_indices(input_shape.dimensions_size(), 0); @@ -43,8 +41,8 @@ XlaOp TopK(XlaOp input, int64 k) { limit_indices[last_dim] = k; std::vector strides(input_shape.dimensions_size(), 1); - XlaOp values = Neg(Slice(GetTupleElement(sort_result, 0), start_indices, - limit_indices, strides)); + XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides); XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, limit_indices, strides); return Tuple(builder, {values, indices}); diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index 0fbd138aca1..3bba84d90d4 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -44,8 +44,7 @@ XLA_TEST_F(SortingTest, TopK3From8Indices) { ComputeAndCompareR1(&builder, {0, 1, 2}, {}); } -// TODO(b/119930279): enable this test. -XLA_TEST_F(SortingTest, DISABLED_TopKFullSortMinInt) { +XLA_TEST_F(SortingTest, TopKFullSortMinInt) { XlaBuilder builder(TestName()); auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), std::numeric_limits::min() + 1, @@ -54,18 +53,6 @@ XLA_TEST_F(SortingTest, DISABLED_TopKFullSortMinInt) { ComputeAndCompareR1(&builder, {2, 1, 0}, {}); } -XLA_TEST_F(SortingTest, NOT_TopKFullSortMinInt) { - XlaBuilder builder(TestName()); - auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), - std::numeric_limits::min() + 1, - std::numeric_limits::max()}); - xla::GetTupleElement(xla::TopK(x_rev, 3), 1); - // TopK currently negates the keys, which doesn't work correctly for - // std::numeric_limits::min(). Therefore, it will sort this key to the - // front instead of to the back. 
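With the negation workaround removed, `TopK` sorts the original keys directly with a greater-than comparator, so extreme keys such as `INT_MIN` are ordered correctly, which is what the re-enabled `TopKFullSortMinInt` test checks. A small usage sketch, with the builder setup assumed:

```c++
#include <limits>

#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

void TopKExample() {
  xla::XlaBuilder builder("topk");
  xla::XlaOp x = xla::ConstantR1<xla::int32>(
      &builder, {std::numeric_limits<xla::int32>::min(),
                 std::numeric_limits<xla::int32>::min() + 1,
                 std::numeric_limits<xla::int32>::max()});
  // TopK returns a (values, indices) tuple; both slices are in descending
  // key order, so INT_MAX comes first and INT_MIN last.
  xla::XlaOp topk = xla::TopK(x, /*k=*/3);
  xla::XlaOp values = xla::GetTupleElement(topk, 0);
  xla::XlaOp indices = xla::GetTupleElement(topk, 1);
  (void)values;
  (void)indices;
}
```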
- ComputeAndCompareR1(&builder, {0, 2, 1}, {}); -} - XLA_TEST_F(SortingTest, TopKFullSort) { XlaBuilder builder(TestName()); const int kSize = 16; diff --git a/tensorflow/compiler/xla/client/lib/svd.cc b/tensorflow/compiler/xla/client/lib/svd.cc index 8dad4cab515..53a23872709 100644 --- a/tensorflow/compiler/xla/client/lib/svd.cc +++ b/tensorflow/compiler/xla/client/lib/svd.cc @@ -163,9 +163,8 @@ StatusOr HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, HouseHolderResult result; result.v = v; result.beta = beta; - result.a = - Sub(a, Mul(beta, BatchDot(BatchDot(a, TransposeInMinorDims(v), precision), - v, precision))); + result.a = Sub(a, Mul(beta, BatchDot(BatchDot(a, false, v, true, precision), + v, precision))); return result; } @@ -231,8 +230,8 @@ StatusOr HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, result.v = v; result.beta = beta; result.a = Sub( - a, Mul(beta, BatchDot(v, BatchDot(TransposeInMinorDims(v), a, precision), - precision))); + a, Mul(beta, BatchDot(v, false, BatchDot(v, true, a, false, precision), + false, precision))); return result; } @@ -290,18 +289,16 @@ StatusOr HouseHolderBidiagonalization( TF_ASSIGN_OR_RETURN(HouseHolderResult house_col, HouseCol(a, i, i, eps, precision)); - u = Sub(u, Mul(house_col.beta, - BatchDot(BatchDot(u, house_col.v, precision), - TransposeInMinorDims(house_col.v), precision))); + u = Sub(u, + Mul(house_col.beta, BatchDot(BatchDot(u, house_col.v, precision), + false, house_col.v, true, precision))); a = house_col.a; TF_ASSIGN_OR_RETURN(HouseHolderResult house_row, HouseRow(a, i, i + one, eps, precision)); - v = Sub( - v, - Mul(house_row.beta, - BatchDot(BatchDot(v, TransposeInMinorDims(house_row.v), precision), - house_row.v, precision))); + v = Sub(v, Mul(house_row.beta, + BatchDot(BatchDot(v, false, house_row.v, true, precision), + house_row.v, precision))); a = house_row.a; std::vector updated_values; @@ -331,11 +328,10 @@ StatusOr HouseHolderBidiagonalization( XlaOp index = ScalarLike(values[0], n - k); TF_ASSIGN_OR_RETURN(HouseHolderResult house_col, HouseCol(values[3], index, index, eps, precision)); - values[1] = - Sub(values[1], - Mul(house_col.beta, - BatchDot(BatchDot(values[1], house_col.v, precision), - TransposeInMinorDims(house_col.v), precision))); + values[1] = Sub(values[1], + Mul(house_col.beta, + BatchDot(BatchDot(values[1], house_col.v, precision), + false, house_col.v, true, precision))); values[3] = house_col.a; } } @@ -750,25 +746,21 @@ StatusOr SortBySingularValuesAndPostProcessing(SVDResult result) { result.v = Mul(result.v, sign, broadcast_dims); d = BroadcastInDim(d, dimensions, broadcast_dims); - auto zero = Zero(builder, S32); - // As m >= n, only first m columns vectors are needed to be permuted, and the - // rest of m - n vectors are appended after the sorting is done. + // As m >= n, only first n column vectors need to be permuted, and the rest of + // m - n vectors are appended after the sorting is done. XlaOp sort_u_result = - Sort({-d, DynamicSliceInMinorDims(result.u, {zero, zero}, {m, n})}, - CreateScalarLtComputation( + Sort({d, SliceInMinorDims(result.u, {0, 0}, {m, n})}, + CreateScalarGtComputation( {shape.element_type(), shape.element_type()}, builder), num_dims - 1); - // TODO(kuny): using CreateScalarGtComputation after b/124862300 is fixed. 
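The SVD code now asks `BatchDot` to transpose its operands via boolean flags instead of materializing `TransposeInMinorDims` first. A sketch of the two equivalent spellings; the header path and the flag names `transpose_x`/`transpose_y` are assumptions based on the calls in this diff:

```c++
// Header path assumed; BatchDot and TransposeInMinorDims are part of the
// XLA client matrix helpers.
#include "tensorflow/compiler/xla/client/lib/matrix.h"

// Both functions compute a^T * v at HIGHEST precision; the second form
// avoids materializing the transposed operand.
xla::XlaOp AtV(xla::XlaOp a, xla::XlaOp v) {
  return xla::BatchDot(xla::TransposeInMinorDims(a), v,
                       xla::PrecisionConfig::HIGHEST);
}

xla::XlaOp AtVWithFlags(xla::XlaOp a, xla::XlaOp v) {
  return xla::BatchDot(a, /*transpose_x=*/true, v, /*transpose_y=*/false,
                       xla::PrecisionConfig::HIGHEST);
}
```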
XlaOp sort_v_result = - Sort({DynamicSliceInMinorDims(-d, {zero, zero}, {n, n}), result.v}, - CreateScalarLtComputation( + Sort({SliceInMinorDims(d, {0, 0}, {n, n}), result.v}, + CreateScalarGtComputation( {shape.element_type(), shape.element_type()}, builder), num_dims - 1); - // Make sure all the signular values are non-negative. - result.d = Max(-GetMatrixDiagonal(GetTupleElement(sort_v_result, 0)), - ScalarLike(d, 0.0)); + result.d = GetMatrixDiagonal(GetTupleElement(sort_v_result, 0)); result.v = GetTupleElement(sort_v_result, 1); result.v = Mul( @@ -779,12 +771,10 @@ StatusOr SortBySingularValuesAndPostProcessing(SVDResult result) { broadcast_dims); // Append the rest of m - n vectors. - result.u = - ConcatInDim(builder, - {GetTupleElement(sort_u_result, 1), - DynamicSliceInMinorDims( - result.u, {zero, ScalarLike(zero, n)}, {m, m - n})}, - num_dims - 1); + result.u = ConcatInDim(builder, + {GetTupleElement(sort_u_result, 1), + SliceInMinorDims(result.u, {0, n}, {m, m})}, + num_dims - 1); result.u = Mul( result.u, Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0), diff --git a/tensorflow/compiler/xla/client/lib/svd_test.cc b/tensorflow/compiler/xla/client/lib/svd_test.cc index c3c6ae93d81..a987f7fcaf6 100644 --- a/tensorflow/compiler/xla/client/lib/svd_test.cc +++ b/tensorflow/compiler/xla/client/lib/svd_test.cc @@ -77,11 +77,10 @@ class SVDTest : public ClientLibraryTestBase { auto u = result.u; auto d = result.d; - auto zero = Zero(builder, S32); if (m > n) { - u = DynamicSliceInMinorDims(u, {zero, zero}, {m, n}); + u = SliceInMinorDims(u, {0, 0}, {m, n}); } else if (m < n) { - v = DynamicSliceInMinorDims(v, {zero, zero}, {n, m}); + v = SliceInMinorDims(v, {0, 0}, {n, m}); } int num_dims = u_shape.rank(); @@ -92,25 +91,6 @@ class SVDTest : public ClientLibraryTestBase { PrecisionConfig::HIGHEST); } - Array3D ExtractTriangularMatrix(const Array3D& matrix, - bool lower) { - Array3D result(matrix); - for (int i = 0; i < result.n1(); ++i) { - for (int j = 0; j < result.n2(); ++j) { - if (lower) { - for (int k = j + 1; k < result.n3(); ++k) { - result({i, j, k}) = 0.0; - } - } else { - for (int k = 0; k < j; ++k) { - result({i, j, k}) = 0.0; - } - } - } - } - return result; - } - XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { Shape shape = builder->GetShape(m1).ValueOrDie(); int64 size = 1; @@ -268,7 +248,7 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) { Array2D a_val = GenerateRandomMatrix(512, 512); XlaOp a; auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 48b5f94538f..1bd9d7b7228 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -21,6 +21,7 @@ limitations under the License. 
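For reference, the relaxed tolerance in the 512x512 test is passed straight to the public SVD entry point. A hedged sketch of that call; the header path and the parameter names `max_iter` and `epsilon` are descriptive guesses:

```c++
// Header path assumed for the client-side SVD helper.
#include "tensorflow/compiler/xla/client/lib/svd.h"

// Decompose a (possibly batched) matrix `a`; result.u, result.d, and result.v
// hold U, the singular values, and V, as consumed by ComputeMatmulUDVT above.
void RunSvd(xla::XlaOp a) {
  auto result = xla::SVD(a, /*max_iter=*/100, /*epsilon=*/1e-4);
  (void)result.u;
  (void)result.d;
  (void)result.v;
}
```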
#include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/stream_pool.h" @@ -139,7 +140,8 @@ Status LocalExecutable::ValidateExecutionOptions( return Status::OK(); } -StatusOr LocalExecutable::Run( +StatusOr> +LocalExecutable::RunHelper( const absl::Span arguments, ExecutableRunOptions run_options) { TF_RETURN_IF_ERROR( @@ -148,7 +150,7 @@ StatusOr LocalExecutable::Run( StreamPool::Ptr stream; if (run_options.stream() == nullptr) { // NB! The lifetime of `stream` needs to match the lifetime of - // `actual_options` (otherwise we will end up using a returned stream in + // `service_options` (otherwise we will end up using a returned stream in // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if" // scope. TF_ASSIGN_OR_RETURN( @@ -166,12 +168,29 @@ StatusOr LocalExecutable::Run( // backend_->eigen_intra_op_thread_pool(). ServiceExecutableRunOptions service_options(run_options, backend_->StreamBorrower()); + return std::make_pair(service_options, std::move(stream)); +} + +StatusOr LocalExecutable::Run( + const absl::Span arguments, + ExecutableRunOptions run_options) { + TF_ASSIGN_OR_RETURN(auto options_and_stream, + RunHelper(arguments, run_options)); if (executable_->dumping_snapshot()) { - return ExecuteAndDump(&service_options, arguments); + return ExecuteAndDump(&options_and_stream.first, arguments); } return executable_->ExecuteOnStreamWrapper( - &service_options, run_options.execution_profile(), arguments); + &options_and_stream.first, run_options.execution_profile(), arguments); +} + +StatusOr LocalExecutable::RunAsync( + const absl::Span arguments, + ExecutableRunOptions run_options) { + TF_ASSIGN_OR_RETURN(auto options_and_stream, + RunHelper(arguments, run_options)); + return executable_->ExecuteAsyncOnStream(&options_and_stream.first, + arguments); } StatusOr LocalExecutable::ExecuteAndDump( @@ -185,7 +204,7 @@ StatusOr LocalExecutable::ExecuteAndDump( executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); - TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot()); + DumpHloSnapshotIfEnabled(executable_->module(), *executable_->hlo_snapshot()); return std::move(result); } @@ -259,8 +278,8 @@ StatusOr> LocalClient::Compile( } StatusOr LocalClient::LiteralToShapedBuffer( - const Literal& literal, int device_ordinal, - DeviceMemoryAllocator* allocator) { + const LiteralSlice& literal, int device_ordinal, + se::DeviceMemoryAllocator* allocator) { if (allocator == nullptr) { allocator = backend().memory_allocator(); } @@ -287,7 +306,7 @@ StatusOr LocalClient::GlobalDataToShapedBuffer( return local_service_->GlobalDataToShapedBuffer(data, replica_number); } -Status LocalClient::TransferToInfeedLocal(const Literal& literal, +Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 4f4fc8df31c..1e7c97d6f06 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ 
-24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/local_service.h" @@ -32,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { @@ -43,6 +43,12 @@ class LocalExecutable { const absl::Span arguments, ExecutableRunOptions run_options); + // Similar to Run(), but need not block the host waiting for the computation + // to complete before returning. + StatusOr RunAsync( + const absl::Span arguments, + ExecutableRunOptions run_options); + // Return the options used to build the executable. const ExecutableBuildOptions& build_options() const { return build_options_; } @@ -67,10 +73,10 @@ class LocalExecutable { const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to - // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. + // invoke it, and the result. Enabled by flag: --xla_dump_hlo_snapshots. // - // The given ServiceExecutableRunOptions override any values from TF_XLA_FLAGS - // environment variable. + // The given ServiceExecutableRunOptions override any values from the + // XLA_FLAGS environment variable. StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const absl::Span arguments); @@ -86,6 +92,10 @@ class LocalExecutable { // Returns a literal containing the contents of the given ShapedBuffer. StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); + StatusOr> RunHelper( + const absl::Span arguments, + ExecutableRunOptions run_options); + // The ordinal of the device which this executable was compiled for. The // executable can run on all equivalent devices (as determined by // Backend::devices_equivalent). @@ -126,8 +136,8 @@ class LocalClient : public Client { // device memory allocation. If null, the default memory allocator for the // device is used. StatusOr LiteralToShapedBuffer( - const Literal& literal, int device_ordinal, - DeviceMemoryAllocator* allocator = nullptr); + const LiteralSlice& literal, int device_ordinal, + se::DeviceMemoryAllocator* allocator = nullptr); // Transfer the BorrowingLiteral to the device with the given ordinal. StatusOr TransferToLocalServer( @@ -146,7 +156,7 @@ class LocalClient : public Client { // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with // Client::TransferToInfeed. - Status TransferToInfeedLocal(const Literal& literal, int device_ordinal); + Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal); // Transfer and return a value of the given shape from the outfeed of the // given device. diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 6c6d1a9bd3a..1fa52a1fa22 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -35,6 +35,7 @@ limitations under the License. 
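The new `RunAsync` entry point shares the argument validation and stream setup of `Run` via `RunHelper`, but need not block the host. A sketch of how a caller might choose between the two; the surrounding setup (arguments, run options, and stream synchronization before reading the result) is assumed to be handled by the caller:

```c++
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"

// Illustrative dispatcher only. For RunAsync the caller must keep the
// arguments alive and synchronize the stream in `options` before reading
// the returned buffer.
xla::StatusOr<xla::ScopedShapedBuffer> Dispatch(
    xla::LocalExecutable* executable,
    absl::Span<const xla::ShapedBuffer* const> arguments,
    const xla::ExecutableRunOptions& options, bool async) {
  if (async) {
    return executable->RunAsync(arguments, options);
  }
  return executable->Run(arguments, options);
}
```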
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -480,7 +481,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { } XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { + absl::Span broadcast_dimensions, + absl::optional direction) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -489,6 +491,17 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); *instr.mutable_shape() = shape.ToProto(); + if (binop == HloOpcode::kCompare) { + if (!direction.has_value()) { + return InvalidArgument( + "kCompare expects a ComparisonDirection, but none provided."); + } + instr.set_comparison_direction(ComparisonDirectionToString(*direction)); + } else if (direction.has_value()) { + return InvalidArgument( + "A comparison direction is provided for a non-compare opcode: %s.", + HloOpcodeString(binop)); + } const int64 lhs_rank = lhs_shape.rank(); const int64 rhs_rank = rhs_shape.rank(); @@ -542,33 +555,50 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); - TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferTernaryOpShape(triop, lhs_shape, - rhs_shape, ehs_shape)); - *instr.mutable_shape() = shape.ToProto(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; - if (!shape.IsTuple()) { - if (!lhs_shape.IsTuple() && - !ShapeUtil::SameDimensions(shape, lhs_shape)) { - // lhs is being implicitly broadcasted. Change to explicit. - TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs)); + // The client API supports implicit broadcast for kSelect and kClamp, but + // XLA does not support implicit broadcast. Make implicit broadcast explicit + // and update the operands. + if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) { + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); + + absl::optional non_scalar_shape; + for (const Shape& shape : {lhs_shape, rhs_shape, ehs_shape}) { + if (shape.IsArray() && shape.rank() != 0) { + non_scalar_shape = shape; + } } - if (!rhs_shape.IsTuple() && - !ShapeUtil::SameDimensions(shape, rhs_shape)) { - // rhs is being implicitly broadcasted. Change to explicit. - TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs)); - } - if (!ehs_shape.IsTuple() && - !ShapeUtil::SameDimensions(shape, ehs_shape)) { - // ehs is being implicitly broadcasted. Change to explicit. 
- TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs)); + if (non_scalar_shape.has_value()) { + if (ShapeUtil::IsScalar(lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(*non_scalar_shape, lhs)); + } + if (ShapeUtil::IsScalar(rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(*non_scalar_shape, rhs)); + } + if (ShapeUtil::IsScalar(ehs_shape)) { + TF_ASSIGN_OR_RETURN(updated_ehs, + AddBroadcastSequence(*non_scalar_shape, ehs)); + } } } + + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(updated_lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(updated_rhs)); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(updated_ehs)); + StatusOr status_or_shape = ShapeInference::InferTernaryOpShape( + triop, lhs_shape, rhs_shape, ehs_shape); + if (!status_or_shape.status().ok()) { + return InvalidArgument( + "%s Input scalar shapes may have been changed to non-scalar shapes.", + status_or_shape.status().error_message()); + } + *instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto(); return AddInstruction(std::move(instr), triop, {updated_lhs, updated_rhs, updated_ehs}); }); @@ -617,8 +647,9 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, }); } -XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, - const string& name) { +XlaOp XlaBuilder::Parameter( + int64 parameter_number, const Shape& shape, const string& name, + const std::vector& replicated_at_leaf_buffers) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { @@ -628,6 +659,12 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, instr.set_parameter_number(parameter_number); instr.set_name(name); *instr.mutable_shape() = shape.ToProto(); + if (!replicated_at_leaf_buffers.empty()) { + auto replication = instr.mutable_parameter_replication(); + for (bool replicated : replicated_at_leaf_buffers) { + replication->add_replicated_at_leaf_buffers(replicated); + } + } return AddInstruction(std::move(instr), HloOpcode::kParameter); }); } @@ -1015,18 +1052,6 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); - // If one operand is a scalar, just multiply the two operands. 
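With the ternary-op change, scalar operands of `Select` and `Clamp` are broadcast explicitly to the shape of the non-scalar operand before shape inference runs. A small sketch of the pattern this supports; the shapes are chosen for illustration:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"

void SelectWithScalarExample() {
  xla::XlaBuilder builder("select_scalar");
  xla::XlaOp pred = xla::ConstantR1<bool>(&builder, {true, false, true});
  xla::XlaOp on_true = xla::ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
  // The scalar 0.0f is broadcast to f32[3] by the builder.
  xla::Select(pred, on_true, xla::ConstantR0<float>(&builder, 0.0f));
  // Clamp gets the same treatment for scalar min/max bounds.
  xla::Clamp(xla::ConstantR0<float>(&builder, 0.0f), on_true,
             xla::ConstantR0<float>(&builder, 2.0f));
}
```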
- if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { - if (dimension_numbers.rhs_batch_dimensions_size() != 0 || - dimension_numbers.lhs_batch_dimensions_size() != 0 || - dimension_numbers.rhs_contracting_dimensions_size() != 0 || - dimension_numbers.lhs_contracting_dimensions_size() != 0) { - return InvalidArgument( - "Dots with scalar operands must have no contracting or batch " - "dimensions"); - } - return xla::Mul(lhs, rhs); - } TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); @@ -1492,7 +1517,7 @@ XlaOp XlaBuilder::CustomCall( } *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); - instr.set_custom_call_opaque(opaque); + instr.set_backend_config(opaque); if (operand_shapes_with_layout.has_value()) { if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument( @@ -1552,122 +1577,6 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -namespace { -// Switch from a floating point value to a integer value in such a way that when -// using the integer value to compare, we get the same result for normal values, -// and -Nan is treated as the smallest value, and Nan is treated as the largest -// value. -// If f is a float, and -// x = bit_cast(f); -// y = x < 0 ? numeric_limits::max() - x : x; -// then y is ordered as an int32 such that finite values have the obvious order, -// -0 is ordered before 0, and -NaN and NaN appear at the beginning and end of -// the ordering. -// Note that in order to avoid -x to overflow, we calculate -// numeric_limits::max() - x as unsigned, and then convert back to -// signed. -XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, - int64 bit_width) { - PrimitiveType signed_type; - PrimitiveType unsigned_type; - XlaOp max_value; - switch (bit_width) { - case 16: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S16; - unsigned_type = U16; - break; - case 32: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S32; - unsigned_type = U32; - break; - case 64: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S64; - unsigned_type = U64; - break; - default: - return value.builder()->ReportError( - InvalidArgument("Invalid bit width %lld for Comparator floating " - "point parameter.", - bit_width)); - } - auto signed_value = BitcastConvertType(value, signed_type); - auto unsigned_value = BitcastConvertType(value, unsigned_type); - auto flipped_value = - BitcastConvertType(Sub(max_value, unsigned_value), signed_type); - auto is_negative = - Lt(signed_value, - ConstantLiteral(value.builder(), LiteralUtil::Zero(signed_type))); - return Select(is_negative, flipped_value, signed_value); -} -} // namespace - -XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, - int64 dimension) { - return ReportErrorOrReturn([&]() -> StatusOr { - std::vector operands{keys}; - for (const XlaOp& value : values) { - operands.push_back(value); - } - // Build the default less-than comparator (copied from lib/comparators.cc). - // TODO(b/122298745): Remove the deprecated API method so that this code - // duplication can be deleted. 
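Since the deprecated keys-and-values overload of `Sort` is being removed, callers pass an explicit comparator computation, as the remaining code in this file already does. A sketch of a stable descending sort over a key/value pair; the `comparators.h` header path is assumed from the `lib/comparators.cc` reference above:

```c++
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

void SortWithComparatorExample() {
  xla::XlaBuilder builder("sort_with_comparator");
  xla::XlaOp keys = xla::ConstantR1<float>(&builder, {2.0f, 0.5f, 1.0f});
  xla::XlaOp values = xla::ConstantR1<xla::int32>(&builder, {0, 1, 2});
  // Stable, descending sort of `keys`, carrying `values` along.
  xla::Sort({keys, values},
            xla::CreateScalarGtComputation({xla::F32, xla::S32}, &builder),
            /*dimension=*/-1, /*is_stable=*/true);
}
```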
- auto b = this->CreateSubBuilder("comparator"); - std::vector operand_types; - for (const XlaOp& operand : operands) { - TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand)); - operand_types.push_back(operand_shape.element_type()); - } - - int64 parameter_count = 0; - XlaOp first_lhs_param; - XlaOp first_rhs_param; - - for (auto operand_type : operand_types) { - auto scalar_shape = ShapeUtil::MakeShape(operand_type, {}); - auto lhs_param = - b->Parameter(parameter_count * 2, scalar_shape, - absl::StrCat("p.", parameter_count, ".lhs")); - auto rhs_param = - b->Parameter(parameter_count * 2 + 1, scalar_shape, - absl::StrCat("p.", parameter_count, ".rhs")); - if (parameter_count == 0) { - first_lhs_param = lhs_param; - first_rhs_param = rhs_param; - } - ++parameter_count; - } - if (primitive_util::IsFloatingPointType(operand_types[0])) { - PrimitiveType compare_type = operand_types[0]; - // Special-case handling for BF16. We currently do not support direct - // comparisons with BF16, so we convert to F32 and then use the F32 - // comparison logic. - if (compare_type == BF16) { - compare_type = F32; - first_lhs_param = b->ConvertElementType(first_lhs_param, F32); - first_rhs_param = b->ConvertElementType(first_rhs_param, F32); - } - int64 bit_width = primitive_util::BitWidth(compare_type); - first_lhs_param = - BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width); - first_rhs_param = - BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width); - } - Lt(first_lhs_param, first_rhs_param); - - TF_ASSIGN_OR_RETURN(auto comparator, b->Build()); - return Sort(operands, comparator, dimension, /*is_stable=*/false); - }); -} - XlaOp XlaBuilder::Sort(absl::Span operands, const XlaComputation& comparator, int64 dimension, bool is_stable) { @@ -1880,16 +1789,42 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation) { - // The index of true_computation must be 0 and that of false computation - // must be 1. - return Conditional(predicate, {&true_computation, &false_computation}, - {true_operand, false_operand}); + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, GetShape(predicate)); + + if (!ShapeUtil::IsScalar(shape) || shape.element_type() != PRED) { + return InvalidArgument( + "Argument to predicated-Conditional is not a scalar of PRED type " + "(%s).", + ShapeUtil::HumanString(shape)); + } + // The index of true_computation must be 0 and that of false computation + // must be 1. 
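Both `Conditional` overloads now validate their selector up front: the two-branch form requires a scalar of PRED type and the N-branch form a scalar of S32 type, before delegating to `ConditionalImpl`. A short sketch of the two-branch form; the branch computations are illustrative:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void ConditionalExample() {
  xla::XlaBuilder builder("cond");
  const xla::Shape scalar_f32 = xla::ShapeUtil::MakeShape(xla::F32, {});

  // Each branch takes one f32[] operand.
  xla::XlaBuilder true_b("true_branch");
  xla::Mul(xla::Parameter(&true_b, 0, scalar_f32, "x"),
           xla::ConstantR0<float>(&true_b, 2.0f));
  xla::XlaComputation true_comp = true_b.Build().ConsumeValueOrDie();

  xla::XlaBuilder false_b("false_branch");
  xla::Add(xla::Parameter(&false_b, 0, scalar_f32, "x"),
           xla::ConstantR0<float>(&false_b, 1.0f));
  xla::XlaComputation false_comp = false_b.Build().ConsumeValueOrDie();

  // The predicate must be a scalar of PRED type; anything else now makes the
  // builder report an InvalidArgument error.
  xla::XlaOp pred = xla::ConstantR0<bool>(&builder, true);
  xla::XlaOp operand = xla::ConstantR0<float>(&builder, 3.0f);
  xla::Conditional(pred, operand, true_comp, operand, false_comp);
}
```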
+ return ConditionalImpl(predicate, {&true_computation, &false_computation}, + {true_operand, false_operand}); + }); } XlaOp XlaBuilder::Conditional( const XlaOp& branch_index, absl::Span branch_computations, absl::Span branch_operands) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, GetShape(branch_index)); + + if (!ShapeUtil::IsScalar(shape) || shape.element_type() != S32) { + return InvalidArgument( + "Argument to indexed-Conditional is not a scalar of S32 type (%s).", + ShapeUtil::HumanString(shape)); + } + return ConditionalImpl(branch_index, branch_computations, branch_operands); + }); +} + +XlaOp XlaBuilder::ConditionalImpl( + const XlaOp& branch_index, + absl::Span branch_computations, + absl::Span branch_operands) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -2815,7 +2750,15 @@ StatusOr XlaBuilder::LookUpInstructionByHandle( // passed to the computation. XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, const string& name) { - return builder->Parameter(parameter_number, shape, name); + std::vector empty_bools; + return Parameter(builder, parameter_number, shape, name, empty_bools); +} + +XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, + const string& name, + const std::vector& replicated_at_leaf_buffers) { + return builder->Parameter(parameter_number, shape, name, + replicated_at_leaf_buffers); } // Enqueues a constant with the value of the given literal onto the @@ -2824,63 +2767,63 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { return builder->ConstantLiteral(literal); } -XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes) { +XlaOp Broadcast(const XlaOp operand, absl::Span broadcast_sizes) { return operand.builder()->Broadcast(operand, broadcast_sizes); } -XlaOp BroadcastInDim(const XlaOp& operand, +XlaOp BroadcastInDim(const XlaOp operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions) { return operand.builder()->BroadcastInDim(operand, out_dim_size, broadcast_dimensions); } -XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, +XlaOp Pad(const XlaOp operand, const XlaOp padding_value, const PaddingConfig& padding_config) { return operand.builder()->Pad(operand, padding_value, padding_config); } -XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, +XlaOp Reshape(const XlaOp operand, absl::Span dimensions, absl::Span new_sizes) { return operand.builder()->Reshape(operand, dimensions, new_sizes); } -XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes) { +XlaOp Reshape(const XlaOp operand, absl::Span new_sizes) { return operand.builder()->Reshape(operand, new_sizes); } -XlaOp Collapse(const XlaOp& operand, absl::Span dimensions) { +XlaOp Collapse(const XlaOp operand, absl::Span dimensions) { return operand.builder()->Collapse(operand, dimensions); } -XlaOp Slice(const XlaOp& operand, absl::Span start_indices, +XlaOp Slice(const XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides) { return operand.builder()->Slice(operand, start_indices, limit_indices, strides); } -XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, +XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { return operand.builder()->SliceInDim(operand, start_index, limit_index, stride, dimno); } -XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, +XlaOp 
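The new `Parameter` overload lets callers mark individual leaf buffers of a parameter as replicated across replicas; an empty vector (or the old overload) leaves the annotation unset. A sketch with a two-element tuple parameter; the replication pattern shown is arbitrary:

```c++
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

void ReplicatedParameterExample() {
  xla::XlaBuilder builder("replicated_param");
  const xla::Shape f32_vec = xla::ShapeUtil::MakeShape(xla::F32, {8});
  const xla::Shape tuple_shape =
      xla::ShapeUtil::MakeTupleShape({f32_vec, f32_vec});
  // One bool per leaf buffer: the first tuple element is marked as replicated
  // across replicas, the second is not.
  xla::Parameter(&builder, 0, tuple_shape, "p0",
                 std::vector<bool>{true, false});
}
```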
DynamicSlice(const XlaOp operand, const XlaOp start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } -XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, +XlaOp DynamicSlice(const XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } -XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, - const XlaOp& start_indices) { +XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, + const XlaOp start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } -XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, +XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, absl::Span start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } @@ -2890,11 +2833,11 @@ XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, return builder->ConcatInDim(operands, dimension); } -void Trace(const string& tag, const XlaOp& operand) { +void Trace(const string& tag, const XlaOp operand) { return operand.builder()->Trace(tag, operand); } -XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { +XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) { return pred.builder()->Select(pred, on_true, on_false); } @@ -2902,59 +2845,60 @@ XlaOp Tuple(XlaBuilder* builder, absl::Span elements) { return builder->Tuple(elements); } -XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { +XlaOp GetTupleElement(const XlaOp tuple_data, int64 index) { return tuple_data.builder()->GetTupleElement(tuple_data, index); } -XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Eq(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kEq, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq); } -XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Ne(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kNe, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe); } -XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Ge(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kGe, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe); } -XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Gt(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kGt, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt); } -XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Le(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kLe, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe); } -XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Lt(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kLt, lhs, rhs, - broadcast_dimensions); + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt); } -XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Compare(const XlaOp lhs, const 
XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction) { + return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs, + broadcast_dimensions, direction); +} + +XlaOp Dot(const XlaOp lhs, const XlaOp rhs, const PrecisionConfig* precision_config) { return lhs.builder()->Dot(lhs, rhs, precision_config); } -XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, +XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config) { return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, precision_config); } -XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Conv(const XlaOp lhs, const XlaOp rhs, absl::Span window_strides, Padding padding, int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { @@ -2963,7 +2907,7 @@ XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, precision_config); } -XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs, absl::Span window_strides, absl::Span> padding, int64 feature_group_count, int64 batch_group_count, @@ -2974,7 +2918,7 @@ XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, } XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + const XlaOp lhs, const XlaOp rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { @@ -2983,7 +2927,7 @@ XlaOp ConvWithGeneralDimensions( batch_group_count, precision_config); } -XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, @@ -2994,7 +2938,7 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, batch_group_count, precision_config); } -XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs, absl::Span window_strides, absl::Span> padding, absl::Span lhs_dilation, @@ -3008,7 +2952,7 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, precision_config); } -XlaOp Fft(const XlaOp& operand, FftType fft_type, +XlaOp Fft(const XlaOp operand, FftType fft_type, absl::Span fft_length) { return operand.builder()->Fft(operand, fft_type, fft_length); } @@ -3055,7 +2999,7 @@ XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) { return builder->Infeed(shape, config); } -void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, +void Outfeed(const XlaOp operand, const Shape& shape_with_layout, const string& outfeed_config) { return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config); } @@ -3080,99 +3024,103 @@ XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, operand_shapes_with_layout); } -XlaOp Complex(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Complex(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs, broadcast_dimensions); } -XlaOp Conj(const XlaOp& operand) { +XlaOp Conj(const XlaOp operand) { return Complex(Real(operand), Neg(Imag(operand))); } -XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Add(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs, 
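All of the comparison wrappers are now thin aliases for the new `Compare` entry point, which carries an explicit `ComparisonDirection` (declared in `comparison_util.h`, newly included by the builder header). A sketch showing that the two spellings below build the same kCompare instruction:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"

void CompareExample() {
  xla::XlaBuilder builder("compare");
  xla::XlaOp lhs = xla::ConstantR1<float>(&builder, {1.0f, 2.0f});
  xla::XlaOp rhs = xla::ConstantR1<float>(&builder, {2.0f, 2.0f});
  // Equivalent spellings of an element-wise >= comparison.
  xla::Ge(lhs, rhs);
  xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{},
               xla::ComparisonDirection::kGe);
}
```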
broadcast_dimensions); } -XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Sub(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); } -XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Mul(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } -XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Div(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); } -XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Rem(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); } -XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Max(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); } -XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Min(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); } -XlaOp And(const XlaOp& lhs, const XlaOp& rhs, +XlaOp And(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); } -XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Or(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); } -XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Xor(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); } -XlaOp Not(const XlaOp& operand) { +XlaOp Not(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kNot, operand); } -XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, +XlaOp PopulationCount(const XlaOp operand) { + return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand); +} + +XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, broadcast_dimensions); } -XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, +XlaOp Reduce(const XlaOp operand, const XlaOp init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce) { return operand.builder()->Reduce(operand, init_value, computation, @@ -3189,12 +3137,12 @@ XlaOp Reduce(XlaBuilder* builder, absl::Span operands, dimensions_to_reduce); } -XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, +XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value, const XlaComputation& computation) { return operand.builder()->ReduceAll(operand, init_value, computation); } -XlaOp 
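`PopulationCount` is a new element-wise unary wrapper for HloOpcode::kPopulationCount, counting the set bits of each element of an integer-typed operand. A one-line usage sketch:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"

void PopulationCountExample() {
  xla::XlaBuilder builder("popcnt");
  // Yields {0, 1, 8} for the u32 inputs below.
  xla::PopulationCount(xla::ConstantR1<xla::uint32>(&builder, {0u, 1u, 255u}));
}
```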
ReduceWindow(const XlaOp& operand, const XlaOp& init_value, +XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, Padding padding) { @@ -3204,7 +3152,7 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, } XlaOp ReduceWindowWithGeneralPadding( - const XlaOp& operand, const XlaOp& init_value, + const XlaOp operand, const XlaOp init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, @@ -3216,19 +3164,19 @@ XlaOp ReduceWindowWithGeneralPadding( base_dilations, window_dilations, padding); } -XlaOp CrossReplicaSum(const XlaOp& operand, +XlaOp CrossReplicaSum(const XlaOp operand, absl::Span replica_groups) { return operand.builder()->CrossReplicaSum(operand, replica_groups); } -XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, +XlaOp CrossReplicaSum(const XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, const absl::optional& channel_id) { return operand.builder()->CrossReplicaSum(operand, computation, replica_groups, channel_id); } -XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, +XlaOp AllToAll(const XlaOp operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups) { return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, @@ -3236,17 +3184,17 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, } XlaOp CollectivePermute( - const XlaOp& operand, + const XlaOp operand, const std::vector>& source_target_pairs) { return operand.builder()->CollectivePermute(operand, source_target_pairs); } XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); } -XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, +XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, + const XlaOp source, const XlaOp init_value, const XlaComputation& scatter) { return operand.builder()->SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, @@ -3254,116 +3202,112 @@ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, } XlaOp SelectAndScatterWithGeneralPadding( - const XlaOp& operand, const XlaComputation& select, + const XlaOp operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, - absl::Span> padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter) { + absl::Span> padding, const XlaOp source, + const XlaOp init_value, const XlaComputation& scatter) { return operand.builder()->SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); } -XlaOp Abs(const XlaOp& operand) { +XlaOp Abs(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kAbs, operand); } -XlaOp Atan2(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Atan2(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs, broadcast_dimensions); } -XlaOp Exp(const XlaOp& operand) { +XlaOp Exp(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kExp, operand); } -XlaOp Expm1(const XlaOp& operand) { +XlaOp Expm1(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand); } -XlaOp 
Floor(const XlaOp& operand) { +XlaOp Floor(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kFloor, operand); } -XlaOp Ceil(const XlaOp& operand) { +XlaOp Ceil(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kCeil, operand); } -XlaOp Round(const XlaOp& operand) { +XlaOp Round(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand); } -XlaOp Log(const XlaOp& operand) { +XlaOp Log(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kLog, operand); } -XlaOp Log1p(const XlaOp& operand) { +XlaOp Log1p(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand); } -XlaOp Sign(const XlaOp& operand) { +XlaOp Sign(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kSign, operand); } -XlaOp Clz(const XlaOp& operand) { +XlaOp Clz(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kClz, operand); } -XlaOp Cos(const XlaOp& operand) { +XlaOp Cos(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kCos, operand); } -XlaOp Sin(const XlaOp& operand) { +XlaOp Sin(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kSin, operand); } -XlaOp Tanh(const XlaOp& operand) { +XlaOp Tanh(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kTanh, operand); } -XlaOp Real(const XlaOp& operand) { +XlaOp Real(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kReal, operand); } -XlaOp Imag(const XlaOp& operand) { +XlaOp Imag(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kImag, operand); } -XlaOp Sqrt(const XlaOp& operand) { +XlaOp Sqrt(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand); } -XlaOp Rsqrt(const XlaOp& operand) { +XlaOp Rsqrt(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand); } -XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Pow(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); } -XlaOp IsFinite(const XlaOp& operand) { +XlaOp IsFinite(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand); } -XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { +XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) { return operand.builder()->ConvertElementType(operand, new_element_type); } -XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { +XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) { return operand.builder()->BitcastConvertType(operand, new_element_type); } -XlaOp Neg(const XlaOp& operand) { +XlaOp Neg(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kNegate, operand); } -XlaOp Transpose(const XlaOp& operand, absl::Span permutation) { +XlaOp Transpose(const XlaOp operand, absl::Span permutation) { return operand.builder()->Transpose(operand, permutation); } -XlaOp Rev(const XlaOp& operand, absl::Span dimensions) { +XlaOp Rev(const XlaOp operand, absl::Span dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { - return keys.builder()->Sort(keys, values, dimension); -} - XlaOp Sort(absl::Span operands, const XlaComputation& comparator, int64 dimension, bool is_stable) { return operands[0].builder()->Sort(operands, comparator, dimension, is_stable); } -XlaOp Clamp(const 
XlaOp& min, const XlaOp& operand, const XlaOp& max) { +XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) { return min.builder()->Clamp(min, operand, max); } @@ -3373,56 +3317,56 @@ XlaOp Map(XlaBuilder* builder, absl::Span operands, return builder->Map(operands, computation, dimensions, static_operands); } -XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) { +XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) { return mu.builder()->RngNormal(mu, sigma, shape); } -XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) { +XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) { return a.builder()->RngUniform(a, b, shape); } XlaOp While(const XlaComputation& condition, const XlaComputation& body, - const XlaOp& init) { + const XlaOp init) { return init.builder()->While(condition, body, init); } -XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, +XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand, const XlaComputation& true_computation, - const XlaOp& false_operand, + const XlaOp false_operand, const XlaComputation& false_computation) { return predicate.builder()->Conditional(predicate, true_operand, true_computation, false_operand, false_computation); } -XlaOp Conditional(const XlaOp& branch_index, +XlaOp Conditional(const XlaOp branch_index, absl::Span branch_computations, absl::Span branch_operands) { return branch_index.builder()->Conditional(branch_index, branch_computations, branch_operands); } -XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, +XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits, const int mantissa_bits) { return operand.builder()->ReducePrecision(operand, exponent_bits, mantissa_bits); } -XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, +XlaOp Gather(const XlaOp input, const XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes) { return input.builder()->Gather(input, start_indices, dimension_numbers, slice_sizes); } -XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, - const XlaOp& updates, const XlaComputation& update_computation, +XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices, + const XlaOp updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers) { return input.builder()->Scatter(input, scatter_indices, updates, update_computation, dimension_numbers); } -void Send(const XlaOp& operand, const ChannelHandle& handle) { +void Send(const XlaOp operand, const ChannelHandle& handle) { return operand.builder()->Send(operand, handle); } @@ -3431,33 +3375,33 @@ XlaOp Recv(XlaBuilder* builder, const Shape& shape, return builder->Recv(shape, handle); } -XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, +XlaOp SendWithToken(const XlaOp operand, const XlaOp token, const ChannelHandle& handle) { return operand.builder()->SendWithToken(operand, token, handle); } -XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, +XlaOp RecvWithToken(const XlaOp token, const Shape& shape, const ChannelHandle& handle) { return token.builder()->RecvWithToken(token, shape, handle); } -XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, +XlaOp SendToHost(const XlaOp operand, const XlaOp token, const Shape& shape_with_layout, const ChannelHandle& handle) { return operand.builder()->SendToHost(operand, token, shape_with_layout, handle); } -XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, +XlaOp 
RecvFromHost(const XlaOp token, const Shape& shape, const ChannelHandle& handle) { return token.builder()->RecvFromHost(token, shape, handle); } -XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, +XlaOp InfeedWithToken(const XlaOp token, const Shape& shape, const string& config) { return token.builder()->InfeedWithToken(token, shape, config); } -XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, +XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token, const Shape& shape_with_layout, const string& outfeed_config) { return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout, @@ -3470,24 +3414,24 @@ XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens) { return builder->AfterAll(tokens); } -XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, - const XlaOp& offset, float epsilon, +XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale, + const XlaOp offset, float epsilon, int64 feature_index) { return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon, feature_index); } -XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, - const XlaOp& offset, const XlaOp& mean, - const XlaOp& variance, float epsilon, +XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale, + const XlaOp offset, const XlaOp mean, + const XlaOp variance, float epsilon, int64 feature_index) { return operand.builder()->BatchNormInference( operand, scale, offset, mean, variance, epsilon, feature_index); } -XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, - const XlaOp& batch_mean, const XlaOp& batch_var, - const XlaOp& grad_output, float epsilon, +XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale, + const XlaOp batch_mean, const XlaOp batch_var, + const XlaOp grad_output, float epsilon, int64 feature_index) { return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var, grad_output, epsilon, feature_index); @@ -3501,7 +3445,7 @@ XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { return builder->Iota(shape, iota_dimension); } -XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension) { +XlaOp GetDimensionSize(const XlaOp operand, int64 dimension) { return operand.builder()->GetDimensionSize(operand, dimension); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 56e85e394c5..508f16a945f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" @@ -321,7 +322,13 @@ class XlaBuilder { // functions section in this file. 
XlaOp Parameter(int64 parameter_number, const Shape& shape, - const string& name); + const string& name, + const std::vector& replicated_at_leaf_buffers); + XlaOp Parameter(int64 parameter_number, const Shape& shape, + const string& name) { + std::vector empty_bools; + return Parameter(parameter_number, shape, name, empty_bools); + } XlaOp ConstantLiteral(const LiteralSlice& literal); @@ -508,9 +515,6 @@ class XlaBuilder { XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - ABSL_DEPRECATED("Use form with comparator computation instead") - XlaOp Sort(const XlaOp& keys, absl::Span values = {}, - int64 dimension = -1); XlaOp Sort(absl::Span operands, const XlaComputation& comparator, int64 dimension = -1, bool is_stable = false); @@ -596,9 +600,11 @@ class XlaBuilder { // Internal helper method that does the building for an arbitrary binary op. // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. + // when the operation is between tensors of different ranks. The direction is + // only used if opcode is kCompare. XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions); + absl::Span broadcast_dimensions, + absl::optional direction = absl::nullopt); // Internal helper method that does the building for an arbitrary ternary op. XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, @@ -705,105 +711,103 @@ class XlaBuilder { XlaBuilder* parent_builder_{nullptr}; friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, - const Shape& shape, const string& name); + const Shape& shape, const string& name, + const std::vector& replicated_at_leaf_buffers); friend XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); - friend XlaOp Broadcast(const XlaOp& operand, + friend XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); friend XlaOp BroadcastInDim( - const XlaOp& operand, const absl::Span out_dim_size, + XlaOp operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions); - friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + friend XlaOp Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config); - friend XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + friend XlaOp Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes); - friend XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); + friend XlaOp Reshape(XlaOp operand, absl::Span new_sizes); - friend XlaOp Collapse(const XlaOp& operand, - absl::Span dimensions); + friend XlaOp Collapse(XlaOp operand, absl::Span dimensions); - friend XlaOp Slice(const XlaOp& operand, - absl::Span start_indices, + friend XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); - friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, - int64 limit_index, int64 stride, int64 dimno); + friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); - friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + friend XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, absl::Span slice_sizes); - friend XlaOp DynamicSlice(const XlaOp& operand, + friend XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); - friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, - const XlaOp& start_indices); - friend XlaOp DynamicUpdateSlice(const XlaOp& 
operand, const XlaOp& update, + friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, + XlaOp start_indices); + friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); friend XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension); - friend void Trace(const string& tag, const XlaOp& operand); + friend void Trace(const string& tag, XlaOp operand); - friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, - const XlaOp& on_false); + friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); - friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index); + friend XlaOp Eq(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Ne(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Ge(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Gt(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Lt(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Le(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); + friend XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config); - friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_number, const PrecisionConfig* precision_config); - friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, Padding padding, int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, + XlaOp lhs, XlaOp rhs, absl::Span window_strides, absl::Span> padding, int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers, + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); - friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, + XlaOp lhs, XlaOp rhs, absl::Span window_strides, absl::Span> padding, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* 
precision_config); - friend XlaOp Fft(const XlaOp& operand, FftType fft_type, + friend XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span fft_length); friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, bool unit_diagonal, @@ -811,7 +815,7 @@ class XlaBuilder { friend XlaOp Cholesky(XlaOp a, bool lower); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config); - friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + friend void Outfeed(XlaOp operand, const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span operands); @@ -822,182 +826,180 @@ class XlaBuilder { XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, absl::Span operand_shapes_with_layout, const string& opaque); - friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, + friend XlaOp Complex(XlaOp real, XlaOp imag, absl::Span broadcast_dimensions); - friend XlaOp Conj(const XlaOp& operand); - friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Conj(XlaOp operand); + friend XlaOp Add(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Sub(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Mul(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Div(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Rem(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Max(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Min(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp And(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Or(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Xor(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Not(const XlaOp& operand); - friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Not(XlaOp operand); + friend XlaOp PopulationCount(XlaOp operand); + friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); friend XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions); - friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); + friend XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + friend XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, absl::Span init_values, const XlaComputation& computation, absl::Span dimensions_to_reduce); - friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value, const XlaComputation& computation); - friend XlaOp 
ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, Padding padding); friend XlaOp ReduceWindowWithGeneralPadding( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, + XlaOp operand, XlaOp init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, absl::Span window_dilations, absl::Span> padding); - friend XlaOp CrossReplicaSum(const XlaOp& operand, + friend XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups); - friend XlaOp CrossReplicaSum(const XlaOp& operand, - const XlaComputation& computation, + friend XlaOp CrossReplicaSum(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, const absl::optional& channel_id); - friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + friend XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); friend XlaOp CollectivePermute( - const XlaOp& operand, + XlaOp operand, const std::vector>& source_target_pairs); friend XlaOp ReplicaId(XlaBuilder* builder); - friend XlaOp SelectAndScatter(const XlaOp& operand, - const XlaComputation& select, + friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, + Padding padding, XlaOp source, XlaOp init_value, const XlaComputation& scatter); friend XlaOp SelectAndScatterWithGeneralPadding( - const XlaOp& operand, const XlaComputation& select, + XlaOp operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, - absl::Span> padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter); - friend XlaOp Abs(const XlaOp& operand); - friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + friend XlaOp Abs(XlaOp operand); + friend XlaOp Atan2(XlaOp y, XlaOp x, absl::Span broadcast_dimensions); - friend XlaOp Exp(const XlaOp& operand); - friend XlaOp Expm1(const XlaOp& operand); - friend XlaOp Floor(const XlaOp& operand); - friend XlaOp Ceil(const XlaOp& operand); - friend XlaOp Round(const XlaOp& operand); - friend XlaOp Log(const XlaOp& operand); - friend XlaOp Log1p(const XlaOp& operand); - friend XlaOp Sign(const XlaOp& operand); - friend XlaOp Clz(const XlaOp& operand); - friend XlaOp Cos(const XlaOp& operand); - friend XlaOp Sin(const XlaOp& operand); - friend XlaOp Tanh(const XlaOp& operand); - friend XlaOp Real(const XlaOp& operand); - friend XlaOp Imag(const XlaOp& operand); - friend XlaOp Sqrt(const XlaOp& operand); - friend XlaOp Rsqrt(const XlaOp& operand); - friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + friend XlaOp Exp(XlaOp operand); + friend XlaOp Expm1(XlaOp operand); + friend XlaOp Floor(XlaOp operand); + friend XlaOp Ceil(XlaOp operand); + friend XlaOp Round(XlaOp operand); + friend XlaOp Log(XlaOp operand); + friend XlaOp Log1p(XlaOp operand); + friend XlaOp Sign(XlaOp operand); + friend XlaOp Clz(XlaOp operand); + friend XlaOp Cos(XlaOp operand); + friend XlaOp Sin(XlaOp operand); + friend XlaOp Tanh(XlaOp operand); + friend XlaOp Real(XlaOp operand); + friend XlaOp Imag(XlaOp operand); + friend XlaOp Sqrt(XlaOp operand); + 
friend XlaOp Rsqrt(XlaOp operand); + friend XlaOp Pow(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp IsFinite(const XlaOp& operand); + friend XlaOp IsFinite(XlaOp operand); friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); - friend XlaOp ConvertElementType(const XlaOp& operand, + friend XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); - friend XlaOp BitcastConvertType(const XlaOp& operand, + friend XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); - friend XlaOp Neg(const XlaOp& operand); - friend XlaOp Transpose(const XlaOp& operand, - absl::Span permutation); - friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - friend XlaOp Sort(const XlaOp& keys, absl::Span values, - int64 dimension); + friend XlaOp Neg(XlaOp operand); + friend XlaOp Transpose(XlaOp operand, absl::Span permutation); + friend XlaOp Rev(XlaOp operand, absl::Span dimensions); friend XlaOp Sort(absl::Span operands, const XlaComputation& comparator, int64 dimension, bool is_stable); - friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, absl::Span dimensions, absl::Span static_operands); - friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, - const Shape& shape); - friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); + friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); friend XlaOp While(const XlaComputation& condition, - const XlaComputation& body, const XlaOp& init); - friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& body, XlaOp init); + friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand, const XlaComputation& true_computation, - const XlaOp& false_operand, + XlaOp false_operand, const XlaComputation& false_computation); friend XlaOp Conditional( + XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + friend XlaOp ConditionalImpl( const XlaOp& branch_index, absl::Span branch_computations, absl::Span branch_operands); - friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits, const int mantissa_bits); - friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, + friend XlaOp Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes); - friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, - const XlaOp& updates, + friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - friend void Send(const XlaOp& operand, const ChannelHandle& handle); + friend void Send(XlaOp operand, const ChannelHandle& handle); friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); - friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, - const XlaOp& offset, float epsilon, - int64 feature_index); - friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, - const XlaOp& offset, const XlaOp& mean, - const XlaOp& variance, float epsilon, + 
friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, + float epsilon, int64 feature_index); + friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, + XlaOp mean, XlaOp variance, float epsilon, int64 feature_index); - friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, - const XlaOp& batch_mean, const XlaOp& batch_var, - const XlaOp& grad_output, float epsilon, + friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, + XlaOp batch_var, XlaOp grad_output, float epsilon, int64 feature_index); - friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + friend XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); - friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + friend XlaOp RecvWithToken(XlaOp token, const Shape& shape, const ChannelHandle& handle); - friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + friend XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const ChannelHandle& handle); - friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + friend XlaOp RecvFromHost(XlaOp token, const Shape& shape, const ChannelHandle& handle); - friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config); - friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); - friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension); + + private: + XlaOp ConditionalImpl( + const XlaOp& branch_index, + absl::Span branch_computations, + absl::Span branch_operands); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1039,6 +1041,11 @@ class XlaScopedShardingAssignment { XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, const string& name); +// Same as above, but with leaf buffer replication annotation. +XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, + const string& name, + const std::vector& replicated_at_leaf_buffers); + // Enqueues a constant with the value of the given literal onto the // computation. XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); @@ -1111,7 +1118,7 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); // The new dimensions index into copies of the operand, i.e. // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); +XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); // This op broadcasts the `operand` to an output with the given `shape`. // `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the @@ -1128,14 +1135,13 @@ XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); // will generate output // {{1 , 1}, // {2 , 2}} -XlaOp BroadcastInDim(const XlaOp& operand, - const absl::Span out_dim_size, +XlaOp BroadcastInDim(XlaOp operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. 
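The new free-function `Parameter` overload above carries a per-leaf replication annotation. A sketch of its intended use follows; the element type of `replicated_at_leaf_buffers` is stripped in this rendering of the patch and is assumed here to be `bool`, one flag per leaf buffer of the (possibly tuple-shaped) parameter:

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Hypothetical usage: declare a tuple parameter whose first leaf is marked as
// replicated across replicas and whose second leaf is not. The bool element
// type of replicated_at_leaf_buffers is an assumption, not taken from the
// patch text.
xla::XlaOp TupleParamWithReplication(xla::XlaBuilder* b) {
  const xla::Shape leaf = xla::ShapeUtil::MakeShape(xla::F32, {128});
  const xla::Shape tuple = xla::ShapeUtil::MakeTupleShape({leaf, leaf});
  return xla::Parameter(b, /*parameter_number=*/0, tuple, "p0",
                        /*replicated_at_leaf_buffers=*/{true, false});
}
```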
padding_config // specifies the padding amount for each dimension. -XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, +XlaOp Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config); // Enqueues an operation onto the computation that flattens the operand based @@ -1143,13 +1149,13 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". -XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, +XlaOp Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. Conceptually, this is a limited form of "shape casting". -XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); +XlaOp Reshape(XlaOp operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -1169,7 +1175,7 @@ XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. -XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); +XlaOp Collapse(XlaOp operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -1182,7 +1188,7 @@ XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. // The strides parameter determines the stride over the slice -XlaOp Slice(const XlaOp& operand, absl::Span start_indices, +XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); @@ -1192,7 +1198,7 @@ XlaOp Slice(const XlaOp& operand, absl::Span start_indices, // for: // // array[:, 2:4:1, :] -XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, +XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); // Enqueues a slice operation onto the computation that slices the 'operand' @@ -1205,11 +1211,11 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // have the same shape. // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. -XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, +XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); ABSL_DEPRECATED("Use span-of-indices form instead") -XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, +XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which @@ -1229,12 +1235,11 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, // have the same shape. // Slice index calculations are computed modulo update dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. 
-XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, +XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); ABSL_DEPRECATED("Use span-of-indices form instead") -XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, - const XlaOp& start_indices); +XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices); // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. @@ -1243,61 +1248,66 @@ XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. -void Trace(const string& tag, const XlaOp& operand); +void Trace(const string& tag, XlaOp operand); // Enqueues a conditional-move-like select operation onto the computation; // predicated on pred, selects between on_true and on_false. -XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); +XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); // Enqueues a tuple-creation instruction onto the computation. XlaOp Tuple(XlaBuilder* builder, absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. -XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); +XlaOp GetTupleElement(XlaOp tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. -XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Eq(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. -XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Ne(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. -XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Ge(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. -XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Gt(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. -XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Lt(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. -XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Le(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +// Enqueues a comparison instruction onto the computation. +XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); + // Enqueues a dot instruction onto the computation. -XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. -XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, +XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. 
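Alongside the named comparison helpers, the patch introduces the generic `Compare` entry point, with directions defined in the new `comparison_util.h`. A short sketch; the intent is that `Lt(lhs, rhs)` and `Compare(lhs, rhs, {}, ComparisonDirection::kLt)` build the same kCompare HLO:

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"

// Expresses a less-than comparison through the generic Compare entry point.
// broadcast_dimensions has no default in this overload, so an empty span is
// passed explicitly.
xla::XlaOp LessThanExplicit(xla::XlaOp lhs, xla::XlaOp rhs) {
  return xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{},
                      xla::ComparisonDirection::kLt);
}
```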
-XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, Padding padding, - int64 feature_group_count = 1, int64 batch_group_count = 1, +XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs, absl::Span window_strides, absl::Span> padding, int64 feature_group_count = 1, @@ -1307,15 +1317,14 @@ XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + XlaOp lhs, XlaOp rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. -XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, +XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, int64 batch_group_count = 1, @@ -1323,7 +1332,7 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs, absl::Span window_strides, absl::Span> padding, absl::Span lhs_dilation, @@ -1335,8 +1344,7 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. -XlaOp Fft(const XlaOp& operand, FftType fft_type, - absl::Span fft_length); +XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span fft_length); // Solves systems of linear equations with lower or upper triangular coefficient // matrices by forward- or back-substitution. Broadcasting along leading @@ -1386,7 +1394,7 @@ XlaOp Infeed(XlaBuilder* builder, const Shape& shape, // two-element tuple containing the data value and a token-shaped value. // Tokens are used for ordering side-effecting operations. // TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, +XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config = ""); // Enqueues an outfeed instruction onto the computation. This instruction @@ -1395,13 +1403,13 @@ XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, // shape_with_layout communicates the laid out shape that we want to outfeed // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error // will occur. 
-void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, +void Outfeed(XlaOp operand, const Shape& shape_with_layout, const string& outfeed_config); // Variant of Outfeed which takes a token-shaped operand and produces a // token-shaped value. Tokens are used for ordering side-effecting operations. // TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, +XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const string& outfeed_config); @@ -1438,87 +1446,86 @@ XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, // (see g3doc for more details). // Enqueues a complex compose instruction onto the computation. -XlaOp Complex(const XlaOp& real, const XlaOp& imag, +XlaOp Complex(XlaOp real, XlaOp imag, absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. -XlaOp Conj(const XlaOp& operand); +XlaOp Conj(XlaOp operand); // Enqueues an add instruction onto the computation. -XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Add(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. -XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Sub(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. -XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Mul(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. -XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Div(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. -XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Rem(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. -XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Max(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. -XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Min(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Element-wise logical operators -XlaOp And(const XlaOp& lhs, const XlaOp& rhs, +XlaOp And(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Overload to call And with 3 or more operands. We need the following somewhat // convoluted overload set to disambiguate with the overload that takes the // `broadcast_dimensions` optional param. -inline XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) { +inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) { return And(op1, And(op2, op3)); } template -XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3, - const XlaOpTs&... operands) { +XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { return And(op1, And(op2, And(op3, operands...))); } -XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Or(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Overload to call Or with 3 or more operands. As with `And`, we need the // following complicated overload set to handle the default arg in the `Or` // overload above. -inline XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) { +inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) { return Or(op1, Or(op2, op3)); } template -XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3, - const XlaOpTs&... 
operands) { +XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { return Or(op1, Or(op2, Or(op3, operands...))); } -XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Xor(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); -XlaOp Not(const XlaOp& operand); +XlaOp Not(XlaOp operand); -XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, +XlaOp PopulationCount(XlaOp operand); + +XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); -XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); -XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, +XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. -XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, +XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); // Reduces several arrays simultaneously among the provided dimensions, given @@ -1530,11 +1537,11 @@ XlaOp Reduce(XlaBuilder* builder, absl::Span operands, // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. -XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, +XlaOp ReduceAll(XlaOp operand, XlaOp init_value, const XlaComputation& computation); // Enqueues a windowed reduce instruction onto the computation. -XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, +XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, Padding padding); @@ -1542,8 +1549,7 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, + XlaOp operand, XlaOp init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -1553,7 +1559,7 @@ XlaOp ReduceWindowWithGeneralPadding( // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. -XlaOp CrossReplicaSum(const XlaOp& operand, +XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here @@ -1574,13 +1580,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, // // TODO(b/117564385): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( - const XlaOp& operand, const XlaComputation& computation, + XlaOp operand, const XlaComputation& computation, absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. -XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, - int64 concat_dimension, int64 split_count, +XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension, + int64 split_count, const std::vector& replica_groups = {}); // Enqueues an collective operation that sends and receives data cross replicas. 
@@ -1592,7 +1598,7 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, // is not a target in any pair, then the output on that replica is a tensor // consists of 0(s) with the same shape as the input. XlaOp CollectivePermute( - const XlaOp& operand, + XlaOp operand, const std::vector>& source_target_pairs); // Enqueues an operation that returns the replica ID. @@ -1600,79 +1606,79 @@ XlaOp ReplicaId(XlaBuilder* builder); // Enqueues an operation that scatters the `source` array to the selected // indices of each window. -XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, +XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, + XlaOp source, XlaOp init_value, const XlaComputation& scatter); // As SelectAndScatter(), but the padding is given in the format // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( - const XlaOp& operand, const XlaComputation& select, + XlaOp operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, - absl::Span> padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter); + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. -XlaOp Abs(const XlaOp& operand); +XlaOp Abs(XlaOp operand); // Enqueues a atan2 instruction onto the computation. -XlaOp Atan2(const XlaOp& y, const XlaOp& x, +XlaOp Atan2(XlaOp y, XlaOp x, absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. -XlaOp Exp(const XlaOp& operand); +XlaOp Exp(XlaOp operand); // Enqueues an expm1 instruction onto the computation. -XlaOp Expm1(const XlaOp& operand); +XlaOp Expm1(XlaOp operand); // Enqueues a floor instruction onto the computation. -XlaOp Floor(const XlaOp& operand); +XlaOp Floor(XlaOp operand); // Enqueues a ceil instruction onto the computation. -XlaOp Ceil(const XlaOp& operand); +XlaOp Ceil(XlaOp operand); // Enqueues a round instruction onto the computation, rounding to nearest even // with half-way cases rounding away from zero. -XlaOp Round(const XlaOp& operand); +XlaOp Round(XlaOp operand); // Enqueues an log instruction (natural logarithm) onto the computation. -XlaOp Log(const XlaOp& operand); +XlaOp Log(XlaOp operand); // Enqueues an log1p instruction (log(x+1)) onto the computation. -XlaOp Log1p(const XlaOp& operand); +XlaOp Log1p(XlaOp operand); // Enqueues a sign instruction onto the computation. -XlaOp Sign(const XlaOp& operand); +XlaOp Sign(XlaOp operand); // Enqueues a count leading zeros instruction onto the computation. -XlaOp Clz(const XlaOp& operand); +XlaOp Clz(XlaOp operand); // Enqueues a cosine instruction onto the computation. -XlaOp Cos(const XlaOp& operand); +XlaOp Cos(XlaOp operand); // Enqueues a sine instruction onto the computation. -XlaOp Sin(const XlaOp& operand); +XlaOp Sin(XlaOp operand); // Enqueues a tanh instruction onto the computation. -XlaOp Tanh(const XlaOp& operand); +XlaOp Tanh(XlaOp operand); // Enqueues a real-part instruction onto the computation. -XlaOp Real(const XlaOp& operand); +XlaOp Real(XlaOp operand); // Enqueues an imaginary-part instruction onto the computation. -XlaOp Imag(const XlaOp& operand); +XlaOp Imag(XlaOp operand); // Enqueues a sqrt computation onto the computation. 
-XlaOp Sqrt(const XlaOp& operand); +XlaOp Sqrt(XlaOp operand); // Enqueues a rsqrt computation onto the computation. -XlaOp Rsqrt(const XlaOp& operand); +XlaOp Rsqrt(XlaOp operand); // Enqueues a lhs^rhs computation onto the computation. -XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Pow(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., not @@ -1683,7 +1689,7 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, // an error for other types. // // See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h. -XlaOp IsFinite(const XlaOp& operand); +XlaOp IsFinite(XlaOp operand); // Enqueues an iota operation onto the computation. XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); @@ -1693,44 +1699,24 @@ XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. -XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); +XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); // Enqueues a no-op instruction onto the computation that changes // the element type of the operand array to primitive_type. The // bit-widths of the source and destination element types must be // identical. -XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); +XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); // Enqueues a negate instruction onto the computation. -XlaOp Neg(const XlaOp& operand); +XlaOp Neg(XlaOp operand); // Enqueues a transpose instruction onto the computation. -XlaOp Transpose(const XlaOp& operand, absl::Span permutation); +XlaOp Transpose(XlaOp operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). -XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - -// Enqueues a sort (as increasing order) instruction onto the computation. -// If only keys are provided: -// * If the keys are an rank-1 tensor (an array), the result is a sorted array -// of keys, in ascending order. -// * If the keys have higher rank, the keys are sorted along the provided -// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension -// value of 0 will independently sort every column, and a dimension value of 1 -// will independently sort each row. If no dimension number is provided, then -// the last dimension is chosen by default. -// -// If both keys and values are provided: -// * The keys and all values must be tensors with the same dimensions. The -// element types of the tensors may be different. -// * The result is a tuple that consists of a sorted tensor of keys (along the -// provided dimension, as above) as the first element, and tensors with their -// corresponding values as the other elements. -ABSL_DEPRECATED("Use form with comparator computation instead") -XlaOp Sort(const XlaOp& keys, absl::Span values = {}, - int64 dimension = -1); +XlaOp Rev(XlaOp operand, absl::Span dimensions); // Enqueues a sort instruction onto the computation, using 'comparator' for // comparisons. 'comparator' needs to define a strict weak order. 
'is_stable' @@ -1762,7 +1748,7 @@ XlaOp Sort(absl::Span operands, const XlaComputation& comparator, int64 dimension = -1, bool is_stable = false); // Enqueues a clamp instruction onto the computation. -XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); +XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); // Enqueues a map instruction onto the computation. XlaOp Map(XlaBuilder* builder, absl::Span operands, @@ -1771,20 +1757,19 @@ XlaOp Map(XlaBuilder* builder, absl::Span operands, // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. -XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); +XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); // Enqueues a U(a, b) random number generation instruction onto the // computation. Returns values in the semi-open interval [a, b). -XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); +XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); // Enqueues a while node onto the computation. XlaOp While(const XlaComputation& condition, const XlaComputation& body, - const XlaOp& init); + XlaOp init); // Enqueues a conditional node onto the computation. -XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, - const XlaComputation& true_computation, - const XlaOp& false_operand, +XlaOp Conditional(XlaOp predicate, XlaOp true_operand, + const XlaComputation& true_computation, XlaOp false_operand, const XlaComputation& false_computation); // Enqueues either a predicated (if/else) or indexed (switch/case/default) @@ -1792,35 +1777,34 @@ XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, // branch_operands are matched by index. branch_index selects the branch that // will be executed. Out of range branch_index uses the N-1'th // branch_computation as default. -XlaOp Conditional(const XlaOp& branch_index, +XlaOp Conditional(XlaOp branch_index, absl::Span branch_computations, absl::Span branch_operands); // Enqueues a ReducePrecision node onto the computation. -XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, +XlaOp ReducePrecision(XlaOp operand, const int exponent_bits, const int mantissa_bits); // Enqueues a Gather node onto the computation. -XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, +XlaOp Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. -XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, - const XlaOp& updates, const XlaComputation& update_computation, +XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers); // Enqueues a Send node onto the computation for device-to-device // communication. This operation sends the given operand to // a Recv instruction in a different computation that shares the same channel // handle. -void Send(const XlaOp& operand, const ChannelHandle& handle); +void Send(XlaOp operand, const ChannelHandle& handle); // Variant of Send which takes a token-shaped operand and produces a // token-shaped value. Tokens are used for ordering side-effecting operations. // TODO(b/110532604): Replace all uses of the non-token form with this variant. 
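Since the deprecated `Sort(keys, values, dimension)` overload is removed by this patch, callers migrate to the comparator form declared above. A minimal sketch of an ascending sort of a single f32 operand:

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Comparator computation: two scalar parameters per sorted operand, returning
// a PRED indicating whether lhs orders before rhs.
xla::XlaComputation MakeF32LtComparator() {
  xla::XlaBuilder b("f32_lt");
  const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  xla::XlaOp lhs = xla::Parameter(&b, 0, scalar, "lhs");
  xla::XlaOp rhs = xla::Parameter(&b, 1, scalar, "rhs");
  xla::Lt(lhs, rhs);
  return b.Build().ValueOrDie();
}

// Sorts along the last dimension; is_stable keeps its default (false).
xla::XlaOp SortAscending(xla::XlaOp operand) {
  return xla::Sort({operand}, MakeF32LtComparator(), /*dimension=*/-1);
}
```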
-XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, - const ChannelHandle& handle); +XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); // Enqueues a Recv node onto the computation for device-to-device // communication. The data comes from a Send instruction in a different @@ -1833,7 +1817,7 @@ XlaOp Recv(XlaBuilder* builder, const Shape& shape, // tuple containing the data value and a token-shaped value. Tokens are used // for ordering side-effecting operations. // TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, +XlaOp RecvWithToken(XlaOp token, const Shape& shape, const ChannelHandle& handle); // Enqueues a Send node which transfers data from the device to the host. The @@ -1841,13 +1825,13 @@ XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, // shape must be compatible with the shape of the operand. The operand must be // array-shaped. // TODO(b/111544877): Support tuple shapes. -XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, - const Shape& shape_with_layout, const ChannelHandle& handle); +XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const ChannelHandle& handle); // Enqueues a Recv node which transfers data from the host to the device. The // given shape must contain a layout and must be an array. // TODO(b/111544877): Support tuple shapes. -XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, +XlaOp RecvFromHost(XlaOp token, const Shape& shape, const ChannelHandle& handle); // Enqueues an operation (AfterAll) with no operands that produces a @@ -1868,8 +1852,7 @@ XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` // is the normalized result and batch_mean and batch_var are the mean and // variance, respectively, across batch for the operand. -XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, - const XlaOp& offset, float epsilon, +XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon, int64 feature_index); // Normalizes operand across spatial and batch dimensions for each feature. @@ -1882,10 +1865,8 @@ XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, // // The output has the same shape as `operand`, and contains the normalized // values for each batch. -XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, - const XlaOp& offset, const XlaOp& mean, - const XlaOp& variance, float epsilon, - int64 feature_index); +XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, + XlaOp variance, float epsilon, int64 feature_index); // Calculates the gradients of a batch norm op. // @@ -1896,14 +1877,13 @@ XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, // - grad_operand: Gradient with respect to input `operand` // - grad_offset: Gradient with respect to input `offset` // - grad_scale: Gradient with respect to input `scale` -XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, - const XlaOp& batch_mean, const XlaOp& batch_var, - const XlaOp& grad_output, float epsilon, +XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, + XlaOp batch_var, XlaOp grad_output, float epsilon, int64 feature_index); // Returns the size of the given dimension of the operand. The operand must be // array shaped. 
-XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); +XlaOp GetDimensionSize(XlaOp operand, int64 dimension); // Implementation details below this point. // @@ -1917,7 +1897,11 @@ XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { template XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { - return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); + BorrowingLiteral literal( + reinterpret_cast(values.begin()), + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + {static_cast(values.size())})); + return ConstantLiteral(builder, literal); } template diff --git a/tensorflow/compiler/xla/comparison_util.cc b/tensorflow/compiler/xla/comparison_util.cc new file mode 100644 index 00000000000..de34ad678e7 --- /dev/null +++ b/tensorflow/compiler/xla/comparison_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/comparison_util.h" +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +std::string ComparisonDirectionToString(ComparisonDirection direction) { + switch (direction) { + case ComparisonDirection::kEq: + return "EQ"; + case ComparisonDirection::kNe: + return "NE"; + case ComparisonDirection::kGe: + return "GE"; + case ComparisonDirection::kGt: + return "GT"; + case ComparisonDirection::kLe: + return "LE"; + case ComparisonDirection::kLt: + return "LT"; + } +} + +StatusOr StringToComparisonDirection( + absl::string_view direction_name) { + static auto* direction_map = + new absl::flat_hash_map({ + {"EQ", ComparisonDirection::kEq}, + {"NE", ComparisonDirection::kNe}, + {"GE", ComparisonDirection::kGe}, + {"GT", ComparisonDirection::kGt}, + {"LE", ComparisonDirection::kLe}, + {"LT", ComparisonDirection::kLt}, + }); + auto it = direction_map->find(direction_name); + if (it == direction_map->end()) { + return InvalidArgument("Unknown comparison direction: %s", direction_name); + } + return it->second; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h new file mode 100644 index 00000000000..8b150c3cfad --- /dev/null +++ b/tensorflow/compiler/xla/comparison_util.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ + +#include "absl/base/macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// Represents different comparison operations. +enum class ComparisonDirection : uint8 { + kEq, + kNe, + kGe, + kGt, + kLe, + kLt, +}; + +string ComparisonDirectionToString(ComparisonDirection direction); + +StatusOr StringToComparisonDirection( + absl::string_view direction_name); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.cc b/tensorflow/compiler/xla/cpu_function_runtime.cc similarity index 97% rename from tensorflow/compiler/tf2xla/cpu_function_runtime.cc rename to tensorflow/compiler/xla/cpu_function_runtime.cc index fcc4095e396..517b30a8251 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime.cc +++ b/tensorflow/compiler/xla/cpu_function_runtime.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/core/platform/dynamic_annotations.h" -namespace tensorflow { +namespace xla { namespace { // Inline memory allocation routines here, because depending on '//base' brings // in libraries which use c++ streams, which adds considerable code size on @@ -105,4 +105,4 @@ void FreeContiguous(void* contiguous) { } } } // namespace cpu_function_runtime -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.h b/tensorflow/compiler/xla/cpu_function_runtime.h similarity index 78% rename from tensorflow/compiler/tf2xla/cpu_function_runtime.h rename to tensorflow/compiler/xla/cpu_function_runtime.h index 78970fb39ba..281ca5b2203 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime.h +++ b/tensorflow/compiler/xla/cpu_function_runtime.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ -#define TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CPU_FUNCTION_RUNTIME_H_ +#define TENSORFLOW_COMPILER_XLA_CPU_FUNCTION_RUNTIME_H_ #include "tensorflow/core/platform/types.h" #include -namespace tensorflow { +namespace xla { namespace cpu_function_runtime { // Stores information about one buffer used by an XLA:CPU compiled function. // These buffers are used for holding inputs to the computation, outputs from @@ -28,10 +28,11 @@ namespace cpu_function_runtime { class BufferInfo { public: // Creates a BufferInfo from a serialized encoding generated by `Encode`. - explicit BufferInfo(std::pair encoding) + explicit BufferInfo( + std::pair encoding) : entry_param_number_(encoding.second) { Kind kind; - uint64 size; + tensorflow::uint64 size; Unpack(encoding.first, &kind, &size); kind_ = kind; size_ = size; @@ -47,7 +48,7 @@ class BufferInfo { bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; } // Returns the entry parameter number of this buffer. 
- uint64 entry_parameter_number() const { + tensorflow::uint64 entry_parameter_number() const { assert(is_entry_parameter()); return entry_param_number_; } @@ -61,16 +62,16 @@ class BufferInfo { bool is_on_stack_buffer() const { return kind() == Kind::kOnStackBuffer; } // Returns the size for this buffer. - uint64 size() const { return size_; } + tensorflow::uint64 size() const { return size_; } // Encodes this BufferInfo into two 64 bit integers that can be used to // reconstruct the BufferInfo later using the constructor. We need this // because we use BufferInfo in places where using protocol buffers would // negatively impact binary size. - std::pair Encode() const { + std::pair Encode() const { static_assert(sizeof(*this) == 16, ""); - uint64 upper = Pack(kind(), size_); - uint64 lower = entry_param_number_; + tensorflow::uint64 upper = Pack(kind(), size_); + tensorflow::uint64 lower = entry_param_number_; return {upper, lower}; } @@ -84,19 +85,20 @@ class BufferInfo { // Factory methods: - static BufferInfo MakeTempBuffer(uint64 size) { + static BufferInfo MakeTempBuffer(tensorflow::uint64 size) { return BufferInfo(Kind::kTempBuffer, /*size=*/size, /*entry_param_number=*/-1); } - static BufferInfo MakeConstant(uint64 size) { + static BufferInfo MakeConstant(tensorflow::uint64 size) { return BufferInfo(Kind::kConstant, /*size=*/size, /*entry_param_number=*/-1); } - static BufferInfo MakeEntryParameter(uint64 size, uint64 param_number) { + static BufferInfo MakeEntryParameter(tensorflow::uint64 size, + tensorflow::uint64 param_number) { return BufferInfo(Kind::kEntryParameter, /*size=*/size, /*entry_param_number=*/param_number); } - static BufferInfo MakeOnStackBuffer(uint64 size) { + static BufferInfo MakeOnStackBuffer(tensorflow::uint64 size) { return BufferInfo(Kind::kOnStackBuffer, /*size=*/size, /*entry_param_number=*/-1); } @@ -104,7 +106,7 @@ class BufferInfo { private: BufferInfo() = default; - enum class Kind : uint64 { + enum class Kind : tensorflow::uint64 { kConstant, kTempBuffer, kEntryParameter, @@ -113,21 +115,24 @@ class BufferInfo { Kind kind() const { return static_cast(kind_); } - explicit BufferInfo(Kind kind, uint64 size, uint64 entry_param_number) + explicit BufferInfo(Kind kind, tensorflow::uint64 size, + tensorflow::uint64 entry_param_number) : kind_(kind), size_(size), entry_param_number_(entry_param_number) {} - static uint64 Pack(Kind kind, uint64 size) { - return (static_cast(size) << 2) | static_cast(kind); + static tensorflow::uint64 Pack(Kind kind, tensorflow::uint64 size) { + return (static_cast(size) << 2) | + static_cast(kind); } - static void Unpack(uint64 packed, Kind* kind, uint64* size) { + static void Unpack(tensorflow::uint64 packed, Kind* kind, + tensorflow::uint64* size) { *size = packed >> 2; *kind = static_cast((packed << 62) >> 62); } Kind kind_ : 2; - uint64 size_ : 62; - int64 entry_param_number_; + tensorflow::uint64 size_ : 62; + tensorflow::int64 entry_param_number_; }; // Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. @@ -160,6 +165,6 @@ void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n, // MallocContiguousBuffers. 
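A sketch of the `BufferInfo` serialization round trip after its move into the `xla` namespace, with the integer types spelled as `tensorflow::uint64` per this patch. `Encode()` packs `(kind, size)` into the first word as `(size << 2) | kind` and stores the entry-parameter number in the second word, as documented by `Pack`/`Unpack` above:

```cpp
#include <cassert>
#include <utility>

#include "tensorflow/compiler/xla/cpu_function_runtime.h"

void BufferInfoRoundTrip() {
  using xla::cpu_function_runtime::BufferInfo;
  // A 64-byte buffer backing entry parameter #3 of the compiled function.
  const BufferInfo original =
      BufferInfo::MakeEntryParameter(/*size=*/64, /*param_number=*/3);
  const std::pair<tensorflow::uint64, tensorflow::uint64> words =
      original.Encode();
  // The two encoded words reconstruct an equivalent BufferInfo.
  const BufferInfo decoded(words);
  assert(decoded.is_entry_parameter());
  assert(decoded.size() == 64);
  assert(decoded.entry_parameter_number() == 3);
}
```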
void FreeContiguous(void* contiguous); } // namespace cpu_function_runtime -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ +#endif // TENSORFLOW_COMPILER_XLA_CPU_FUNCTION_RUNTIME_H_ diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 43d9ee0d9a5..35ee92aa829 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -33,7 +33,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); opts.set_xla_eliminate_hlo_implicit_broadcast(true); - opts.set_xla_hlo_dump_as_html(false); + opts.set_xla_dump_hlo_as_html(false); #ifdef INTEL_MKL opts.set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL @@ -53,6 +53,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_enable_fast_math(true); opts.set_xla_gpu_enable_fast_min_max(true); + opts.set_xla_allow_excess_precision(true); opts.set_xla_force_host_platform_device_count(1); return opts; } @@ -84,6 +85,14 @@ static void AllocateFlags() { }; }; + auto string_setter_for = + [](void (DebugOptions::*member_setter)(const string& value)) { + return [member_setter](const string& value) { + (flag_values->*member_setter)(value); + return true; + }; + }; + // Custom "sub-parser" lambda for xla_disable_hlo_passes. auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { std::vector disabled_passes = @@ -114,44 +123,26 @@ static void AllocateFlags() { }; flag_objects = new std::vector({ - tensorflow::Flag( - "xla_generate_hlo_graph", - flag_values->mutable_xla_generate_hlo_graph(), - "HLO modules matching this regex will be dumped to a .dot file " - "throughout various stages in compilation."), - tensorflow::Flag( - "xla_hlo_graph_addresses", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), - flag_values->xla_hlo_graph_addresses(), - "With xla_generate_hlo_graph, show addresses of HLO ops in " - "graph dump."), - tensorflow::Flag( - "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), - "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag("xla_hlo_dump_as_html", - bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), - flag_values->xla_hlo_dump_as_html(), - "Dump HLO graphs as an HTML (DOT rendered into SVG " - "inlined in HTML)."), - tensorflow::Flag( - "xla_hlo_graph_sharding_color", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), - flag_values->xla_hlo_graph_sharding_color(), - "Assign colors based on sharding assignments when generating the " - "HLO graphs."), - tensorflow::Flag( - "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), - "HLO modules matching this regex will be dumped to LOG(INFO)."), - tensorflow::Flag( - "xla_generate_hlo_text_to", - flag_values->mutable_xla_generate_hlo_text_to(), - "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( "xla_cpu_enable_fast_math", bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), flag_values->xla_cpu_enable_fast_math(), "Enable unsafe fast-math optimizations in the CPU compiler; " "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_cpu_fast_math_honor_nans", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans), + flag_values->xla_cpu_fast_math_honor_nans(), + "When xla_cpu_enable_fast_math is true then this controls whether 
we " + "allow operations to produce NaNs. Ignored when " + "xla_cpu_enable_fast_math is false."), + tensorflow::Flag( + "xla_cpu_fast_math_honor_infs", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs), + flag_values->xla_cpu_fast_math_honor_infs(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "allow operations to produce infinites. Ignored when " + "xla_cpu_enable_fast_math is false."), tensorflow::Flag( "xla_gpu_enable_fast_min_max", bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), @@ -210,9 +201,6 @@ static void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), flag_values->xla_embed_ir_in_executable(), "Embed the compiler IR as a string in the executable."), - tensorflow::Flag( - "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), - "Dump the compiler IR into this directory as individual files."), tensorflow::Flag( "xla_eliminate_hlo_implicit_broadcast", bool_setter_for( @@ -247,20 +235,6 @@ static void AllocateFlags() { int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), flag_values->xla_gpu_max_kernel_unroll_factor(), "Specify the maximum kernel unroll factor for the GPU backend."), - tensorflow::Flag( - "xla_dump_optimized_hlo_proto_to", - flag_values->mutable_xla_dump_optimized_hlo_proto_to(), - "Dump Hlo after all hlo passes are executed as proto binary into " - "this directory."), - tensorflow::Flag( - "xla_dump_unoptimized_hlo_proto_to", - flag_values->mutable_xla_dump_unoptimized_hlo_proto_to(), - "Dump HLO before any hlo passes are executed as proto binary into " - "this directory."), - tensorflow::Flag("xla_dump_per_pass_hlo_proto_to", - flag_values->mutable_xla_dump_per_pass_hlo_proto_to(), - "Dump HLO after each pass as an HloProto in binary file " - "format into this directory."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), @@ -283,14 +257,6 @@ static void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_profile), flag_values->xla_hlo_profile(), "Instrument the computation to collect per-HLO cycle counts"), - tensorflow::Flag("xla_dump_computations_to", - flag_values->mutable_xla_dump_computations_to(), - "Dump computations that XLA executes into the provided " - "directory path"), - tensorflow::Flag("xla_dump_executions_to", - flag_values->mutable_xla_dump_executions_to(), - "Dump parameters and results of computations that XLA " - "executes into the provided directory path"), tensorflow::Flag("xla_backend_extra_options", setter_for_xla_backend_extra_options, "", "Extra options to pass to a backend; " @@ -326,6 +292,11 @@ static void AllocateFlags() { flag_values->xla_gpu_crash_on_verification_failures(), "Crashes the program on extra verification failures, e.g. cuDNN " "cross checking failures"), + tensorflow::Flag( + "xla_gpu_disable_autotune", + bool_setter_for(&DebugOptions::set_xla_gpu_disable_autotune), + flag_values->xla_gpu_disable_autotune(), + "Disable GEMM and Convolution auto-tuning."), tensorflow::Flag( "xla_force_host_platform_device_count", int32_setter_for( @@ -343,6 +314,84 @@ static void AllocateFlags() { &DebugOptions::set_xla_gpu_disable_ptxas_optimizations), flag_values->xla_gpu_disable_ptxas_optimizations(), "In XLA:GPU run ptxas in -O0 (default is -O3)."), + + tensorflow::Flag( + "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to), + flag_values->xla_dump_to(), + "Directory into which debugging data is written. 
If not specified " + "but another dumping flag is passed, data will be written to stdout. " + " To explicitly write to stdout, set this to \"-\". The values " + "\"sponge\" and \"test_undeclared_outputs_dir\" have a special " + "meaning: They cause us to dump into the directory specified by the " + "environment variable TEST_UNDECLARED_OUTPUTS_DIR."), + tensorflow::Flag( + "xla_dump_hlo_as_text", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text), + flag_values->xla_dump_hlo_as_text(), + "Dumps HLO modules as text before and after optimizations. Results " + "are written to the --xla_dump_to dir, or, if no dir is specified, " + "to stdout."), + tensorflow::Flag( + "xla_dump_hlo_as_proto", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto), + flag_values->xla_dump_hlo_as_proto(), + "Dumps HLO modules as HloProtos to the directory specified by " + "--xla_dump_to."), + tensorflow::Flag( + "xla_dump_hlo_as_dot", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot), + flag_values->xla_dump_hlo_as_dot(), + "Dumps HLO modules rendered as dot files to the directory " + "specified by --xla_dump_to."), + tensorflow::Flag("xla_dump_hlo_as_html", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html), + flag_values->xla_dump_hlo_as_html(), + "Dumps HLO modules rendered as HTML files to the " + "directory specified by --xla_dump_to."), + tensorflow::Flag( + "xla_dump_hlo_as_url", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url), + flag_values->xla_dump_hlo_as_url(), + "Tries to dump HLO modules rendered as URLs to stdout (and also to " + "the directory specified by --xla_dump_to). This is not implemented " + "by default; you need to add a plugin which calls " + "RegisterGraphToURLRenderer()."), + tensorflow::Flag( + "xla_dump_hlo_snapshots", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots), + flag_values->xla_dump_hlo_snapshots(), + "Every time an HLO module is run, dumps an HloSnapshot to the " + "directory specified by --xla_dump_to."), + tensorflow::Flag( + "xla_dump_hlo_module_re", + string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re), + flag_values->xla_dump_hlo_module_re(), + "Limits dumping only to modules which match this regular expression. 
" + " Default is to dump all modules."), + tensorflow::Flag( + "xla_dump_hlo_pass_re", + string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re), + flag_values->xla_dump_hlo_pass_re(), + "If specified, dumps HLO before and after optimization passes which " + "match this regular expression, in addition to dumping at the very " + "beginning and end of compilation."), + tensorflow::Flag( + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), + "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays " + "the address in memory of each HloInstruction object."), + tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the " + "HLO graphs."), + tensorflow::Flag( + "xla_allow_excess_precision", + bool_setter_for(&DebugOptions::set_xla_allow_excess_precision), + flag_values->xla_allow_excess_precision(), + "Allow xla to increase the output precision of an instruction."), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 230f3b202a4..39c90b60a09 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -26,12 +26,13 @@ ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal( int ExecutableRunOptions::device_ordinal() const { return device_ordinal_; } ExecutableRunOptions& ExecutableRunOptions::set_allocator( - DeviceMemoryAllocator* allocator) { + stream_executor::DeviceMemoryAllocator* allocator) { allocator_ = allocator; return *this; } -DeviceMemoryAllocator* ExecutableRunOptions::allocator() const { +stream_executor::DeviceMemoryAllocator* ExecutableRunOptions::allocator() + const { return allocator_; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 1e744953bd3..84629593953 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -23,6 +23,7 @@ limitations under the License. namespace stream_executor { class Stream; class Platform; +class DeviceMemoryAllocator; } // namespace stream_executor namespace Eigen { @@ -31,7 +32,6 @@ struct ThreadPoolDevice; namespace xla { -class DeviceMemoryAllocator; class DeviceAssignment; class ExecutionProfile; @@ -39,8 +39,9 @@ class ExecutionProfile; class ExecutableRunOptions { public: // Specifies the allocator to use during execution. - ExecutableRunOptions& set_allocator(DeviceMemoryAllocator* allocator); - DeviceMemoryAllocator* allocator() const; + ExecutableRunOptions& set_allocator( + stream_executor::DeviceMemoryAllocator* allocator); + stream_executor::DeviceMemoryAllocator* allocator() const; // If set, this is the device to run the computation on. Valid device_ordinal // values are: 0 to # of devices - 1. These values are identical to the device @@ -64,6 +65,12 @@ class ExecutableRunOptions { stream_executor::Stream* host_to_device_stream() const; // Sets the thread pool device on which to run Eigen subcomputations. + // + // This field must be set for XLA:CPU models that call Eigen routines, but may + // be null otherwise. 
Routines that use this field should always CHECK (or + // TF_RET_CHECK) that it's not null before dereferencing it, so that users get + // a clean crash rather than a segfault. + // // Does not take ownership. ExecutableRunOptions& set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool); @@ -81,7 +88,7 @@ class ExecutableRunOptions { int rng_seed() const; private: - DeviceMemoryAllocator* allocator_ = nullptr; + stream_executor::DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; const DeviceAssignment* device_assignment_ = nullptr; stream_executor::Stream* stream_ = nullptr; diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index c34e84efc80..7c458844a93 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -120,9 +120,14 @@ class Sharding(object): tile_assignment_dimensions=tile_assignment_dims, tile_assignment_devices=range(num_devices))) - def apply_to_tensor(self, tensor): - """Applies this Sharding attribute to `tensor`.""" - if len(tensor.op.outputs) > 1: + def apply_to_tensor(self, tensor, assign_tuple_sharding=False): + """Applies this Sharding attribute to `tensor`. + + Args: + tensor: A tf.Tensor to split. + assign_tuple_sharding: If the sharding type should be a tuple. + """ + if len(tensor.op.outputs) > 1 or assign_tuple_sharding: proto = self._get_or_create_tuple_proto(tensor.op) # We can't mutate an element of old_proto.tuple_shardings, so create # a new proto. @@ -166,21 +171,30 @@ class Sharding(object): # tensor = xla_sharding.replicate(tensor) -def replicate(tensor): - Sharding.replicate().apply_to_tensor(tensor) +def replicate(tensor, assign_tuple_sharding=False): + Sharding.replicate().apply_to_tensor( + tensor, + assign_tuple_sharding=assign_tuple_sharding) return tensor -def assign_device(tensor, device): - Sharding.assign_device(device).apply_to_tensor(tensor) +def assign_device(tensor, device, assign_tuple_sharding=False): + Sharding.assign_device(device).apply_to_tensor( + tensor, + assign_tuple_sharding=assign_tuple_sharding) return tensor -def tile(tensor, tile_assignment): - Sharding.tile(tile_assignment).apply_to_tensor(tensor) +def tile(tensor, tile_assignment, assign_tuple_sharding=False): + Sharding.tile(tile_assignment).apply_to_tensor( + tensor, + assign_tuple_sharding=assign_tuple_sharding + ) return tensor -def split(tensor, split_dimension, num_devices): - Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor) +def split(tensor, split_dimension, num_devices, assign_tuple_sharding=False): + Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor( + tensor, + assign_tuple_sharding=assign_tuple_sharding) return tensor diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index d756cd74c98..dafc3345555 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -29,6 +29,8 @@ upper_tabs: path: /xla/tiled_layout - title: Using AOT compilation path: /xla/tfcompile + - title: Writing custom calls + path: /xla/custom_call - heading: Tutorials - title: XLA compile API path: /xla/tutorials/xla_compile diff --git a/tensorflow/compiler/xla/g3doc/custom_call.md b/tensorflow/compiler/xla/g3doc/custom_call.md new file mode 100644 index 00000000000..acc2c9a92f5 --- /dev/null +++ 
b/tensorflow/compiler/xla/g3doc/custom_call.md @@ -0,0 +1,329 @@ +# XLA Custom Calls + +This document describes how to write and use XLA "custom calls". Custom calls +let you invoke code written in a programming language like C++ or CUDA from an +XLA program. + +Warning: Custom calls are a low-level power-user feature. It is easy to break +your program in difficult-to-debug (and even difficult-to-notice) ways using +custom-calls. You shouldn't use custom calls unless you're prepared to debug XLA +yourself when something goes wrong, and you should expect relatively less +assistance from XLA developers if you run into trouble. + +Warning: The custom-call API/ABI is not currently stable. We don't intend to +change it capriciously, but it may change. Some possible future changes are +described below. + +## Custom-call on CPU + +You can create an HLO instruction which represents a custom-call via XLA's +client API. This is not exposed via TensorFlow as of writing. + +For example, the following code uses a custom-call to compute +`A[i] = B[i % 128] + C[i]` on the CPU. (Of course you could -- and should! -- do +this with regular HLO.) + +```c++ +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" + +void do_it() { + xla::XlaBuilder b("do_it"); + xla::XlaOp param0 = + xla::Parameter(0, xla::ShapeUtil::CreateShape(F32, {128}), "p0"); + xla::XlaOp param1 = + xla::Parameter(1, xla::ShapeUtil::CreateShape(F32, {2048}), "p1"); + xla::XlaOp custom_call = + xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1}, + /*output_shape=*/ShapeUtil::CreateShape(F32, {2048})); +} + +void do_custom_call(void* out, const void** in) { + float* out_buf = reinterpret_cast<float*>(out); + const float* in0 = reinterpret_cast<const float*>(in[0]); + const float* in1 = reinterpret_cast<const float*>(in[1]); + for (int i = 0; i < 2048; ++i) { + out_buf[i] = in0[i % 128] + in1[i]; + } +} +XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host"); +``` + +Notice that the function `do_custom_call` needs to know the dimensions of the +buffers it operates over. In this example we hardcode the sizes 128 and 2048. If +you don't want to do this, you can pass the dimensions in as parameters to the +call. + +## Custom-call on GPU + +The GPU custom call framework is somewhat different from that on the CPU. Here +is a CUDA example that does the same `A[i] = B[i % 128] + C[i]` computation as +the CPU code above. + +```c++ +void do_it() { /* same implementation as above */ } + +__global__ void custom_call_kernel(const float* in0, const float* in1, float* out) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + out[idx] = in0[idx % 128] + in1[idx]; +} + +void do_custom_call(CUstream stream, void** buffers, + const char* opaque, size_t opaque_len) { + const float* in0 = reinterpret_cast<const float*>(buffers[0]); + const float* in1 = reinterpret_cast<const float*>(buffers[1]); + float* out = reinterpret_cast<float*>(buffers[2]); + + const int64 block_dim = 64; + const int64 grid_dim = 2048 / block_dim; + custom_call_kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out); +} +XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA"); +``` + +Notice first that the GPU custom call function *is still a function executed on +the CPU*. Our `do_custom_call` CPU function is responsible for enqueueing work +on the GPU. Here it launches a CUDA kernel, but it could also do something else, +like call cublas. + +`buffers` is an array of pointers which lives on the host, and each element it +contains points to device (i.e. GPU) memory.
The parameters come first, followed +by the output value. This is notably different from the CPU calling convention, +which has two params, `ins` and `out`. The main reason we diverge is to make it +possible to handle tuple-shaped inputs/outputs efficiently; see the section +below. + +As in the CPU example, we've hardcoded the input and output buffer sizes into +our custom call. However unlike in the CPU case, passing the buffer sizes in as +operands to the custom call would not work well. Usually we need the buffer +sizes available to us on the CPU; e.g. when launching a kernel, we need to know +the block/grid dimensions to use. But if we were to pass the buffer sizes as +operands to our custom call, their values would live in GPU memory. We'd then +have to do an expensive synchronous device-to-host memcpy at the start of our +operation just to read the sizes. + +To let you work around this, we provide the `opaque` parameter. You can set this +to an arbitrary string of bytes when you create the custom call: + +```c++ +std::string opaque = "..."; +xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1}, + /*output_shape=*/ShapeUtil::CreateShape(F32, {2048}), + opaque); +``` + +Since `xla::Shape` has a protocol buffer representation, you could store this +serialized proto inside of `opaque` and deserialize it within your GPU +custom-call. Note however that although `xla::ShapeProto` does not change +frequently, it *does* change. Check the git log to see how it has changed in the +past. + +## Passing tuples to custom-calls + +Consider the following custom-call. + +```c++ +using xla::ShapeUtil; +Shape p0_shape = ShapeUtil::MakeTuple({ + ShapeUtil::MakeShape(F32, {32}), + ShapeUtil::MakeTuple({ + ShapeUtil::MakeTuple(F32, {64}), + ShapeUtil::MakeTuple(F32, {128}), + }), + ShapeUtil::MakeShape(F32, {256}), +}); +xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0"); + +Shape out_shape = ShapeUtil::MakeTuple({ + ShapeUtil::MakeShape(F32, {512}), + ShapeUtil::MakeShape(F32, {1024}), +}); +xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape); +``` + +On both CPU and GPU, a tuple is represented in memory as an array of pointers. +In C++-pseudocode, parameter 0 above is laid out as follows. + +```c++ +// In-memory layout of parameter 0 from custom-call above. True on both CPU +// and GPU. +float* subbuf0 = new float[32]; +float* subbuf1 = new float[64]; +float* subbuf2 = new float[128] +float* subbuf3 = new float[256]; + +void* subtuple = new void*[2]; +(*subtuple)[0] = subbuf1; +(*subtuple)[1] = subbuf2; + +void* p0 = new void*[3]; +(*p0)[0] = subbuf0; +(*p0)[1] = subtuple; +(*p0)[2] = subbuf3; +``` + +Although the in-memory representation of tuples is the same in CPU and GPU, they +are handled differently in the CPU and GPU custom-call calling conventions. + +### Tuple outputs as temp buffers + +Tuple inputs to custom-calls are a convenience, but they aren't strictly +necessary. If we didn't support tuple inputs to custom calls, you could always +unpack the tuples using get-tuple-element before passing them to the custom +call. + +On the other hand, tuple *outputs* do let you do things you couldn't otherwise. + +The obvious reason to have tuple outputs is, that's how a custom call (or any +other XLA op) returns multiple independent arrays. + +But less obviously, a tuple output is also a way to give your custom call temp +memory. Yes, an *output* can represent a temp buffer. 
Consider, an output buffer +has the property that the op can write to it, and it can read from it after it's +been written to. That's exactly what you want from a temp buffer. + +In the example above, suppose we wanted to use the `F32[1024]` as a temp buffer. +Then we'd write the HLO just as above, and we'd simply never read tuple index 1 +of the custom call's output. + +### Tuples in CPU custom-calls + +In CPU code, we have a function `do_custom_call(const void** ins, void* out)`. +`ins` is an array with just one element, which points to `param0`. The +subbuffers of `param0` are accessible by dereferencing that pointer, and the +subbuffers of `output_tuple` are accessible by dereferencing `out`. + +### Tuples in GPU custom-calls + +In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In +this case `buffers` is a host array of *nine* device pointers, one for each +nested buffer. To generate the flat list, we iterate over the parameters and +output, and then do preorder traversal of their shapes. Concretely: + +```c++ +// Layout of `buffers` parameter to GPU custom call function for custom-call +// above. +buffers[0] == param0 +buffers[1] == subbuf0 or null +buffers[2] == subtuple or null +buffers[3] == subbuf1 or null +buffers[4] == subbuf2 or null +buffers[5] == subbuf3 or null +buffers[6] == output_tuple +buffers[7] == output_subbuf0 +buffers[8] == output_subbuf1 +``` + +The `or null` part is significant. A sub-buffer of an input tuple will be +non-null in the `buffers` list if XLA is able to statically analyze the program +and figure out the address of the sub-buffer. This is usually the case, but may +not be in programs with control flow and/or `select` ops over tuples. + +A correct custom-call implementation that accepts a tuple as input must always +handle null input sub-buffers, by dereferencing the root tuple. + +The rule is reversed for output buffers. The output sub-buffers will always be +populated, but it's up to the custom call to populate the root tuple at the end. + +See the following code. Note that we leave out CUDA error handling for clarity, +but you'll be thankful if you do it, because otherwise it can be hard to tell +when a stream encounters an error. + +```c++ +void do_custom_call(CUstream stream, void** buffers, const char* opaque, + size_t opaque_len) { + bool needs_sync = false; + const float* subbuf0 = reinterpret_cast(buffers[1]); + if (subbuf0 == nullptr) { + needs_sync = true; + cudaMemcpyAsync(&subbuf0, buffers[0], sizeof(void*), + cudaMemcpyDeviceToHost, stream); + } + const void** subtuple = reinterpret_cast(buffers[2]); + if (subtuple == nullptr) { + needs_sync = true; + cudaMemcpyAsync(&subtuple, buffers[2], ...); + } + + // ... similarly for other params ... + + // Wait for copies enqueued above to complete. + if (needs_sync) { + cudaStreamSynchronize(stream); + } + needs_sync = false; + + // Now that we have `subtuple`, we can get subbuf1 and subbuf2. + float* subbuf1 = buffers[3]; + if (subbuf1 == nullptr) { + needs_sync = true; + cudaMemcpyAsync(&subbuf1, subtuple, ...); + } + float* subbuf2 = buffers[4]; + if (subbuf2 == nullptr) { + needs_sync = true; + cudaMemcpyAsync(&subbuf2, subtuple + 1, ...); + } + + // Wait for copies enqueued above to complete. + if (needs_sync) { + cudaStreamSynchronize(stream); + } + + // ... actually run the kernel ... + + // Fill the output tuple. 
+ void* outputs[2] = {buffers[7], buffers[8]}; + cudaMemcpyAsync(buffers[6], outputs, sizeof(outputs), cudaMemcpyHostToDevice, + stream); + + // Necessary to force the cudaMemcpyAsync above to complete before `outputs` + // goes out of scope. A sync is only necessary in the tuple output case, and + // see below for a way to avoid this. + cudaStreamSynchronize(stream); +} +``` + +The `cudaStreamSynchronize` at the end of the function is unfortunate, as it's +not required in the non-tuple-output case, and it can be expensive. One way to +get around this would be to make `outputs` into a global variable and ensure +that the previous cudaMemcpyAsync completed before overwriting the global and +enqueueing another one. This is sketched below. + +``` +void do_custom_call(CUstream stream, void** buffers, const char* opaque, + size_t opaque_len) { + + // ... Beginning of function is the same as above ... + + // ... actually run the kernel ... + + static std::atomic first_time{true}; + static CUevent event; + static void* outputs[2]; + if (first_time.fetch_and(false)) { + // First time running this function. Initialize `event`. + cuEventCreate(&event, CU_EVENT_DISABLE_TIMING); + } else { + // Not first time running this function. Wait for previous event to + // complete before touching `outputs`. + cuEventSynchronize(event); + } + + // Fill the output tuple. + outputs[0] = buffers[7]; + outputs[1] = buffers[8]; + cudaMemcpyAsync(buffers[6], outputs, sizeof(outputs), cudaMemcpyHostToDevice, + stream); + + // Unblock `event` after the memcpy completes. + cuEventRecord(event, stream); +} +``` + +This simple implementation would limit parallelism if you want to run this op on +multiple GPUs concurrently (or on one GPU with multiple streams); in that case +you might need multiple events and globals. We have seen one implementation of +this algorithm which keeps a pool of globals and events and periodically polls +them (perhaps on each call to the op) to garbage collect. diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md index 85fa16ccc7f..d7ce5ee1ba6 100644 --- a/tensorflow/compiler/xla/g3doc/jit.md +++ b/tensorflow/compiler/xla/g3doc/jit.md @@ -144,7 +144,8 @@ Execute the python script to train the model with XLA and turn on a debugging feature of XLA via an environmental variable that outputs the XLA graph. ```shell -XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py +XLA_FLAGS="--xla_hlo_profile --xla_dump_to=/tmp/foo --xla_dump_hlo_as_text" +python mnist_softmax_xla.py ``` Open the timeline file created (`timeline.ctf.json`). The rendered timeline @@ -153,28 +154,10 @@ should look similar to the picture below with one long bar labeled `XlaLaunch`. -To understand what is happening in `XlaLaunch`, look at the console output for -statements similar to the following: +To understand what is happening in `XlaLaunch`, look at the console output. Each +XLA cluster that's launched will have a corresponding profile (from +`--xla_hlo_profile`) showing how long each HLO took to run. -```shell -computation cluster_0[_XlaCompiledKernel=true,_XlaNumConstantArgs=1].v82 [CPU: -pipeline start, before inline]: /tmp/hlo_graph_0.dot - -``` - -The console statements point to the location of `hlo_graph_xx.dot` files that -contain information about the graph created by XLA. The process that XLA takes -to fuse Ops is visible by starting at `hlo_graph_0.dot` and viewing each diagram -in succession. 
- -To Render the .dot file into a png, install -[GraphViz](https://www.graphviz.org/download/) and run: - -```shell -dot -Tpng hlo_graph_80.dot -o hlo_graph_80.png -``` - -The result will look like the following: -
- -
+`/tmp/foo` will contain the HLO before and after optimizations for each HLO +module that's run. You can read this as-is, or you can visualize it using +`tensorflow/compiler/xla/tools:interactive_graphviz`. diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index b3fdd36b113..9dbe8148c27 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -303,6 +303,37 @@ For example, if `operand` is a scalar `f32` with value `2.0f`, and `broadcast_sizes` is `{2, 3}`, then the result will be an array with shape `f32[2, 3]` and all the values in the result will be `2.0f`. +## BroadcastInDim + +See also +[`XlaBuilder::BroadcastInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Expands the size and rank of an array by duplicating the data in the array. + + `BroadcastInDim(operand, out_dim_size, broadcast_dimensions)` + +| Arguments | Type | Semantics | +| ---------------------- | ------------------- | ----------------------------- | +| `operand` | `XlaOp` | The array to duplicate | +| `out_dim_size` | `ArraySlice` | The sizes of the dimensions | +: : : of the target shape : +| `broadcast_dimensions` | `ArraySlice` | Which dimension in the target | +: : : shape each dimension of the : +: : : operand shape corresponds to : + +Similar to Broadcast, but allows adding dimensions anywhere and expanding +existing dimensions with size 1. + +The `operand` is broadcast to the shape described by `out_dim_size`. +`broadcast_dimensions` maps the dimensions of `operand` to the dimensions of the +target shape, i.e. the i'th dimension of the operand is mapped to the +broadcast_dimension\[i\]'th dimension of the output shape. The dimensions of +`operand` must have size 1 or be the same size as the dimension in in the output +shape they are mapped to. The remaining dimensions are filled with dimensions of +size 1. Degenerate-dimension broadcasting then broadcasts along these degenerate +dimensions to reach the output shape. The semantics are described in detail on +the [broadcasting page](broadcasting.md). + ## Call See also @@ -564,8 +595,7 @@ executed depending on the value of `pred`. | Arguments | Type | Semantics | | --------------------- | --------------------- | ---------------------------- | -| `branch_index` | `XlaOp` | Scalar of type `PRED` or | -: : : `S32` : +| `branch_index` | `XlaOp` | Scalar of type `S32` | | `branch_computations` | sequence of N | XlaComputations of type $$ | : : `XlaComputation` : T_0 \to S , T_1 \to S , ..., : : : : T_{N-1} \to S $$ : @@ -573,9 +603,8 @@ executed depending on the value of `pred`. : : : T_1 , ..., T_{N-1} $$ : Executes `branch_computations[branch_index]`, and returns the result. If -`branch_index` is a `PRED`, then the `true` branch is in position 0 and the -`false` branch is in position 1. If `branch_index` is an `S32` which is < 0 -or >= N, then `branch_computations[N-1]` is executed as the default branch. +`branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]` +is executed as the default branch. Each `branch_computations[b]` must take in a single argument of type `T_b` and will be invoked with `branch_operands[b]` which must be of the same type. 
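As a sketch of what the N-way form could look like through the client builder (not taken from this change): it assumes an `xla::Conditional(branch_index, branch_computations, branch_operands)` overload mirroring the table above, and the helper `BuildUnaryF32Computation` below is made up purely for illustration.

```c++
#include <functional>
#include <string>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Assumed helper (not an XLA API): builds an (f32[]) -> f32[] computation
// whose root is produced by `body`.
xla::XlaComputation BuildUnaryF32Computation(
    const std::string& name,
    const std::function<xla::XlaOp(xla::XlaOp)>& body) {
  xla::XlaBuilder b(name);
  body(xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x"));
  return b.Build().ValueOrDie();
}

xla::XlaOp NWayConditionalSketch(xla::XlaOp branch_index, xla::XlaOp operand) {
  xla::XlaComputation negate = BuildUnaryF32Computation(
      "negate", [](xla::XlaOp x) { return xla::Neg(x); });
  xla::XlaComputation square = BuildUnaryF32Computation(
      "square", [](xla::XlaOp x) { return xla::Mul(x, x); });
  // branch_index is an S32 scalar; values < 0 or >= 2 take the last branch.
  return xla::Conditional(branch_index, {&negate, &square},
                          {operand, operand});
}
```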
The @@ -897,11 +926,11 @@ The exact semantics of this operation depend on the ranks of the operands: | matrix [m x k] `dot` | matrix [m x n] | matrix-matrix | : matrix [k x n] : : multiplication : -The operation performs sum of products over the last dimension of `lhs` and the -one-before-last dimension of `rhs`. These are the "contracted" dimensions. The -contracted dimensions of `lhs` and `rhs` must be of the same size. In practice, -it can be used to perform dot products between vectors, vector/matrix -multiplications or matrix/matrix multiplications. +The operation performs sum of products over the second dimension of `lhs` (or +the first if it has rank 1) and the first dimension of `rhs`. These are the +"contracted" dimensions. The contracted dimensions of `lhs` and `rhs` must be of +the same size. In practice, it can be used to perform dot products between +vectors, vector/matrix multiplications or matrix/matrix multiplications. ## DotGeneral @@ -1237,6 +1266,9 @@ if and only if the corresponding input element is finite. `LogicalNot(operand)` Element-wise logical not `x -> !(x)`. +`PopulationCount(operand)` Computes the number of bits set in each +element of `operand`. + `Neg(operand)` Element-wise negation `x -> -x`. `Sign(operand)` Element-wise sign operation `x -> sgn(x)` where @@ -1255,6 +1287,59 @@ Arguments | Type | Semantics The function is applied to each element in the `operand` array, resulting in an array with the same shape. It is allowed for `operand` to be a scalar (rank 0). +## Fft + +The XLA FFT operation implements the forward and inverse Fourier Transforms for +real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are +supported, except on TPU, where only a single axis is supported (please file a +github issue if you require higher order). + +See also +[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +| Arguments | Type | Semantics | +| ------------ | ------------------- | ------------------------ | +| `operand` | `XlaOp` | The array we are Fourier | +: : : transforming. : +| `fft_type` | `FftType` | See the table below. | +| `fft_length` | `ArraySlice` | The time-domain lengths | +: : : of the axes being : +: : : transformed. This is : +: : : needed in particular for : +: : : IRFFT to right-size the : +: : : innermost axis, since : +: : : `RFFT(fft_length=[16])` : +: : : has the same output : +: : : shape as : +: : : `RFFT(fft_length=[17])`. : + +| `FftType` | Semantics | +| --------- | ---------------------------------------------------------------- | +| `FFT` | Forward complex-to-complex FFT. Shape is unchanged. | +| `IFFT` | Inverse complex-to-complex FFT. Shape is unchanged. | +| `RFFT` | Forward real-to-complex FFT. Shape of the innermost axis is | +: : reduced to `fft_length[-1] // 2 + 1` if `fft_length[-1]` is a : +: : non-zero value, omitting the reversed conjugate part of the : +: : transformed signal beyond the Nyquist frequency. : +| `IRFFT` | Inverse real-to-complex FFT (i.e. takes complex, returns real). | +: : Shape of the innermost axis is expanded to `fft_length[-1]` if : +: : `fft_length[-1]` is a non-zero value, inferring the part of the : +: : transformed signal beyond the Nyquist frequency from the reverse : +: : conjugate of the `1` to `fft_length[-1] // 2 + 1` entries. : + +#### Multidimensional FFT + +When more than 1 `fft_length` is provided, this is equivalent to applying a +cascade of FFT operations to each of the innermost axes. 
Note that for the +real->complex and complex->real cases, the innermost axis transform is +(effectively) performed first (RFFT; last for IRFFT), which is why the innermost +axis is the one which changes size. Other axis transforms will then be +complex->complex. + +#### Implementation details + +CPU FFT is backed by Eigen's TensorFFT. GPU FFT uses cuFFT. + ## Gather The XLA gather operation stitches together several slices (each slice at a @@ -1396,8 +1481,8 @@ The element in the output array at index [`G`,`O``0`,`O``1`] is then the element in the input array at index [`X`+`O``0`,`Y`+`O``1`]. -`slice_sizes` is `[8,6]`, which decides the range of W`0` and -W`1`, and this in turn decides the bounds of the slice. +`slice_sizes` is `[8,6]`, which decides the range of O`0` and +O`1`, and this in turn decides the bounds of the slice. This gather operation acts as a batch dynamic slice with `G` as the batch dimension. @@ -1663,15 +1748,15 @@ Applies a reduction function to one or more arrays in parallel. `Reduce(operands..., init_values..., computation, dimensions)` -Arguments | Type | Semantics -------------- | --------------------- | --------------------------------------- -`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. -`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`. -`computation` | `XlaComputation` | computation of type - : : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)` -`dimensions` | `int64` array | unordered array of dimensions to reduce +| Arguments | Type | Semantics | +| ------------- | --------------------- | ------------------------------------ | +| `operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. | +| `init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`. | +| `computation` | `XlaComputation` | computation of type `T_0, ..., T_N, T_0, ..., T_N ->` `Collate(T_0, ..., T_N)`. | +| `dimensions` | `int64` array | unordered array of dimensions to reduce. | Where: + * N is required to be greater or equal to 1. * All input arrays must have the same dimensions. * If `N = 1`, `Collate(T)` is `T`. @@ -1681,10 +1766,10 @@ The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type `T_i`, the dimensions of which are described below. This operation reduces one or more dimensions of each input array into scalars. -The rank of each returned array is `rank(operand) - len(dimensions)`. -`init_value` is the initial value used for every reduction and may be inserted +The rank of each returned array is `rank(operand) - len(dimensions)`. The +initial value used for every reduction is `init_value`, and it may be inserted anywhere during computation by the back-end. In most cases, `init_value` is an -identity of the reduction function (for example, 0 for addition). The applied +identity of the reduction function (for example, `0` for addition). The applied `computation` is always passed the `init_value` on the left-hand side. The evaluation order of the reduction function is arbitrary and may be @@ -1695,10 +1780,10 @@ Some reduction functions like addition are not strictly associative for floats. However, if the range of the data is limited, floating-point addition is close enough to being associative for most practical uses. It is possible to conceive of some completely non-associative reductions, however, and these will produce -incorrect or unpredictable results in XLA reductions. +incorrect or unpredictable results in XLA. 
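To see why evaluation order matters at all, here is a standalone C++ snippet, independent of XLA, showing that ordinary 32-bit floating-point addition already gives different answers under different groupings:

```c++
#include <cstdio>

int main() {
  // Floating-point addition is not associative: regrouping changes the result.
  float a = 1e20f, b = -1e20f, c = 1.0f;
  std::printf("%g\n", (a + b) + c);  // prints 1: (1e20 + -1e20) + 1
  std::printf("%g\n", a + (b + c));  // prints 0: the 1 is absorbed into -1e20
  return 0;
}
```

XLA's reduce may pick any such grouping, which is why small numeric differences can appear across backends or runs.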
As an example, when reducing across one dimension in a single 1D array with -values [10, 11, 12, 13], with reduction function `f` (this is `computation`) +values `[10, 11, 12, 13]`, with reduction function `f` (this is `computation`) then that could be computed as `f(10, f(11, f(12, f(init_value, 13)))` @@ -1777,16 +1862,27 @@ preserved in the output, but some dimensions may get assigned new numbers (since the rank changes). We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces -the 1D array `| 20 28 36 |`. +the 1D array `[20, 28, 36]`. Reducing the 3D array over all its dimensions produces the scalar `84`. -When `N > 1`, reduce function application is slightly more complex, as it is -applied simultaneously to all inputs. For example, consider the following -reduction function, which can be used to compute the max and the argmax of a a -1-D array in parallel: +### Variadic Reduce -``` +When `N > 1`, reduce function application is slightly more complex, as it is +applied simultaneously to all inputs. The operands are supplied to the +computation in the following order: + +* Running reduced value for the first operand +* ... +* Running reduced value for the N'th operand +* Input value for the first operand +* ... +* Input value for the N'th operand + +For example, consider the following reduction function, which can be used to +compute the max and the argmax of a 1-D array in parallel: + +```python f: (Float, Int, Float, Int) -> Float, Int f(max, argmax, value, index): if value >= argmax: @@ -1798,6 +1894,7 @@ f(max, argmax, value, index): For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values `I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only input dimension is equivalent to the following recursive application: + ``` f_0 = f(I_V, I_K, V_0, K_0) f_1 = f(f_0.first, f_0.second, V_1, K_1) @@ -2104,7 +2201,7 @@ Arguments | Type | Semantics `operand` | `XlaOp` | Array to be scattered into. `scatter_indices` | `XlaOp` | Array containing the starting indices of the slices that must be scattered to. `updates` | `XlaOp` | Array containing the values that must be used for scattering. -`update_computation` | `XlaComputation` | Computation to be used for combining the existing values in the input array and the updates during scatter. This computation should be of type `T, T -> T`. +`update_computation` | `XlaComputation` | Computation to be used for combining the existing values in the input array and the updates during scatter. This computation should be of type `(T, T) -> T`. `index_vector_dim` | `int64` | The dimension in `scatter_indices` that contains the starting indices. `update_window_dims` | `ArraySlice` | The set of dimensions in `updates` shape that are _window dimensions_. `inserted_window_dims` | `ArraySlice` | The set of _window dimensions_ that must be inserted into `updates` shape. @@ -2438,43 +2535,58 @@ Slice(b, {2, 1}, {4, 3}) produces: See also [`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). -There are two versions of the Sort instruction: a single-operand and a -multi-operand version. +`Sort(operands, comparator, dimension, is_stable)` -`Sort(operand, dimension)` +Arguments | Type | Semantics +------------ | ------------------- | -------------------- +`operands` | `ArraySlice` | The operands to sort. +`comparator` | `XlaComputation` | The comparator computation to use. +`dimension` | `int64` | The dimension along which to sort. 
+`is_stable` | `bool` | Whether stable sorting should be used. -Arguments | Type | Semantics ------------ | ------- | -------------------- -`operand` | `XlaOp` | The operand to sort. -`dimension` | `int64` | The dimension along which to sort. +If only one operand is provided: -Sorts the elements in the operand in ascending order along the provided -dimension. For example, for a rank-2 (matrix) operand, a `dimension` value of 0 -will sort each column independently, and a `dimension` value of 1 will sort each -row independently. If the operand's elements have floating point type, and the -operand contains NaN elements, the order of elements in the output is -implementation-defined. +* If the operand is a rank-1 tensor (an array), the result is a sorted array. + If you want to sort the array into ascending order, the comparator should + perform a less-than comparison. Formally, after the array is sorted, it holds + for all index positions `i, j` with `i < j` that either + `comparator(value[i], value[j]) = comparator(value[j], value[i]) = false` or + `comparator(value[i], value[j]) = true`. -`Sort(keys, values, ... values, dimension)` +* If the operand has higher rank, the operand is sorted along the provided + dimension. For example, for a rank-2 tensor (a matrix), a dimension value of + `0` will independently sort every column, and a dimension value of `1` will + independently sort each row. If no dimension number is provided, then the last + dimension is chosen by default. For the dimension which is sorted, the same + sorting order applies as in the rank-1 case. -Sorts both the key and one or more value operands. The keys are sorted as in the -single-operand version. Each of the values inputs is sorted according to the -order of the corresponding keys. For example, if the three inputs are `keys = -[3, 1]`, `values0 = [42, 50]`, `values1 = [-3.0, 1.1]`, then the output of the -sort is the tuple `{[1, 3], [50, 42], [1.1, -3.0]}`. +If `n > 1` operands are provided: -The sort is not guaranteed to be stable, that is, if the keys array contains -duplicates, the order of values corresponding to these keys may not be -preserved. +* All `n` operands must be tensors with the same dimensions. The element types + of the tensors may be different. -Arguments | Type | Semantics ------------ | ---------------------- | ---------------------------------- -`keys` | `XlaOp` | The sort keys. -`values` | Sequence of N `XlaOp`s | The values to sort. -`dimension` | `int64` | The dimension along which to sort. +* All operands are sorted together, not individually. Conceptually the operands + are treated as a tuple. When checking whether the elements of each operand at + index positions `i` and `j` need to be swapped, the comparator is called with + `2 * n` scalar parameters, where parameter `2 * k` corresponds to the value at + position `i` from the `k-th` operand, and parameter `2 * k + 1` corresponds to + the value at position `j` from the `k-th` operand. Usually, the comparator + would thus compare parameters `2 * k` and `2 * k + 1` with each other and + possibly use other parameter pairs as tie breakers. -The `keys` and each of the `values` inputs must have the same dimensions, but -may have different element types. +* The result is a tuple that consists of the operands in sorted order (along + the provided dimension, as above). The `i-th` operand of the tuple corresponds + to the `i-th` operand of Sort. 
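As a rough sketch of how such a two-operand sort might be assembled through the client API, assuming the free-function form `xla::Sort(operands, comparator, dimension, is_stable)` matches the operation described above (the shapes and names below are illustrative, not part of this change):

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Illustrative sketch only: sort an s32 key operand and an f32 value operand
// together, ordering by the keys in ascending order.
xla::XlaComputation BuildTwoOperandSort() {
  // The comparator sees scalars: (key_i, key_j, value_i, value_j).
  xla::XlaBuilder cmp_builder("sort_comparator");
  xla::XlaOp key_lhs = xla::Parameter(
      &cmp_builder, 0, xla::ShapeUtil::MakeShape(xla::S32, {}), "key_lhs");
  xla::XlaOp key_rhs = xla::Parameter(
      &cmp_builder, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "key_rhs");
  // Parameters 2 and 3 carry the values operand; this comparator ignores them.
  xla::Parameter(&cmp_builder, 2, xla::ShapeUtil::MakeShape(xla::F32, {}),
                 "val_lhs");
  xla::Parameter(&cmp_builder, 3, xla::ShapeUtil::MakeShape(xla::F32, {}),
                 "val_rhs");
  xla::Lt(key_lhs, key_rhs);  // Root: "does element i sort before element j?"
  xla::XlaComputation comparator = cmp_builder.Build().ValueOrDie();

  xla::XlaBuilder b("sort_example");
  xla::XlaOp keys =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::S32, {4}), "keys");
  xla::XlaOp values =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {4}), "values");
  xla::Sort({keys, values}, comparator, /*dimension=*/0, /*is_stable=*/false);
  return b.Build().ValueOrDie();
}
```

Passing `is_stable=true` instead would preserve the relative order of elements whose keys compare equal.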
+ +For example, if there are three operands `operand0 = [3, 1]`, +`operand1 = [42, 50]`, `operand2 = [-3.0, 1.1]`, and the comparator compares +only the values of `operand0` with less-than, then the output of the sort is the +tuple `([1, 3], [50, 42], [1.1, -3.0])`. + +If `is_stable` is set to true, the sort is guaranteed to be stable, that is, if +there are elements which are considered to be equal by the comparator, the +relative order of the equal values is preserved. By default, `is_stable` is set +to false. ## Transpose diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc index 000c4fdc405..7052ec09f35 100644 --- a/tensorflow/compiler/xla/layout.cc +++ b/tensorflow/compiler/xla/layout.cc @@ -96,8 +96,13 @@ string Layout::ToString() const { } bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { - if (lhs.format() != rhs.format() || - lhs.minor_to_major() != rhs.minor_to_major() || + if (lhs.format() != rhs.format()) { + return false; + } + if (lhs.format() == DENSE && lhs.minor_to_major() != rhs.minor_to_major()) { + return false; + } + if (lhs.format() == SPARSE && lhs.max_sparse_elements() != rhs.max_sparse_elements()) { return false; } diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index acc449b781b..4721c9fcaa1 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -69,6 +69,11 @@ class Tile { // combined with the next minor dimension before tiling is applied. static constexpr int64 kCombineDimension = std::numeric_limits::min(); + template + friend H AbslHashValue(H h, const Tile& t) { + return H::combine(std::move(h), t.dimensions_); + } + private: // The bounds of the tile. std::vector dimensions_; @@ -127,6 +132,12 @@ class Layout { return *this; } + Equal& MinorToMajorOnly() { + ignore_tiles_ = true; + ignore_element_size_ = true; + return *this; + } + private: bool ignore_tiles_ = false; bool ignore_element_size_ = false; @@ -206,6 +217,13 @@ class Layout { element_size_in_bits_ = 0; } + template + friend H AbslHashValue(H h, const Layout& l) { + return H::combine(std::move(h), l.format_, l.minor_to_major_, + l.max_sparse_elements_, l.tiles_, + l.element_size_in_bits_); + } + private: // The format of this layout. 
Format format_ = INVALID_FORMAT; diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 5cd738d0f77..23eaf318223 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -293,8 +293,9 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return InvalidArgument("LiteralProto has no shape"); } Shape shape(proto.shape()); - if (ShapeUtil::HasPrimitiveType(shape, OPAQUE)) { - return InvalidArgument("Literal shape cannot include OPAQUE sub-shape"); + if (ShapeUtil::HasPrimitiveType(shape, OPAQUE_TYPE)) { + return InvalidArgument( + "Literal shape cannot include OPAQUE_TYPE sub-shape"); } if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("LiteralProto has no layout"); @@ -912,6 +913,24 @@ StatusOr LiteralBase::GetIntegralAsS64( } } +StatusOr LiteralBase::GetAsDouble( + absl::Span multi_index) const { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case F16: + return static_cast(Get(multi_index)); + case F32: + return static_cast(Get(multi_index)); + case F64: + return Get(multi_index); + case BF16: + return static_cast(Get(multi_index)); + default: + return FailedPrecondition("Array element type is not floating: %s", + PrimitiveType_Name(shape().element_type())); + } +} + size_t LiteralBase::Hash() const { using tensorflow::Hash64; using tensorflow::Hash64Combine; @@ -962,6 +981,29 @@ Status MutableLiteralBase::SetIntegralAsS64(absl::Span multi_index, return Status::OK(); } +Status MutableLiteralBase::SetFromDouble(absl::Span multi_index, + double value) { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case F16: + Set(multi_index, Eigen::half(value)); + break; + case F32: + Set(multi_index, value); + break; + case F64: + Set(multi_index, value); + break; + case BF16: + Set(multi_index, static_cast(value)); + break; + default: + return FailedPrecondition("Array element type is not floating: %s", + PrimitiveType_Name(shape().element_type())); + } + return Status::OK(); +} + absl::Span LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index c418be895d6..c810ae9cbae 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -134,6 +134,10 @@ class LiteralBase { // int64. This literal must be an array. StatusOr GetIntegralAsS64(absl::Span multi_index) const; + // As Get(), but determines the correct type, and converts the value into + // double. This literal must be an array. + StatusOr GetAsDouble(absl::Span multi_index) const; + // Returns the multi-index of the element in a sparse literal at the given // sparse element number. The sparse element number is the position with in // the sparse array's list of (index, value) pairs, and is checked against the @@ -637,6 +641,10 @@ class MutableLiteralBase : public LiteralBase { // This literal must be an array. Status SetIntegralAsS64(absl::Span multi_index, int64 value); + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetFromDouble(absl::Span multi_index, double value); + // Populate this literal with the given values. Examples: // // // Populate with floats. 
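A small usage sketch for the new `GetAsDouble`/`SetFromDouble` accessors, assuming the existing `LiteralUtil` factory helpers and abbreviating error handling with `ValueOrDie`/`TF_CHECK_OK`:

```c++
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

// Illustrative sketch: read and write a bfloat16 literal through the new
// double-based accessors without spelling out the element type at the call
// site.
void SketchLiteralDoubleAccessors() {
  xla::Literal lit = xla::LiteralUtil::CreateR1<xla::bfloat16>(
      {xla::bfloat16(1.5f), xla::bfloat16(2.5f)});
  TF_CHECK_OK(lit.SetFromDouble({0}, 3.25));     // stored as bfloat16
  double v = lit.GetAsDouble({1}).ValueOrDie();  // 2.5
  (void)v;
}
```

On a non-floating-point literal both calls return a `FailedPrecondition` error rather than crashing.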
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 9b3de75dd4e..0431bb3d54a 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -238,21 +238,20 @@ string FpValueToString(complex128 value) { return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } -// Returns the absolute value of the given floating point value. This function -// is used instead of std::abs directly in order to allow type-dependent -// implementations for NearComparator. +// A wrapper of std::abs to include data types that are not supported by +// std::abs, in particular, bfloat16 and half. template -float FpAbsoluteValue(NativeT value) { +double FpAbsoluteValue(NativeT value) { return std::abs(value); } template <> -float FpAbsoluteValue(bfloat16 value) { +double FpAbsoluteValue(bfloat16 value) { return FpAbsoluteValue(static_cast(value)); } template <> -float FpAbsoluteValue(half value) { +double FpAbsoluteValue(half value) { return FpAbsoluteValue(static_cast(value)); } @@ -278,8 +277,8 @@ class NearComparator { struct Mismatch { NativeT actual; NativeT expected; - float rel_error; - float abs_error; + double rel_error; + double abs_error; // The linear index of the failure within the shape. This linear index is // from the 'actual' literal. @@ -340,7 +339,7 @@ class NearComparator { void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { // Adjust the bucket containing the absolute values of the 'actual' // elements. - const float abs_value = FpAbsoluteValue(value); + const double abs_value = FpAbsoluteValue(value); for (int i = 0; i < abs_value_buckets_.size(); ++i) { if (i == abs_value_buckets_.size() - 1 || (abs_value >= kAbsValueBucketBounds[i] && @@ -370,8 +369,8 @@ class NearComparator { // the given literal_index and keeps track of various mismatch statistics. template void CompareValues(T expected, T actual, int64 linear_index) { - float abs_error; - float rel_error; + double abs_error; + double rel_error; if (CompareEqual(expected, actual, {linear_index})) { abs_error = 0; rel_error = 0; @@ -459,48 +458,30 @@ class NearComparator { // For complex types, we compare real and imaginary parts individually. void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { - bool mismatch = false; + const auto both_parts_mismatch = num_mismatches_ + 2; CompareValues(expected.real(), actual.real(), linear_index); - if (mismatches_.data()[linear_index] == true) { - mismatch = true; - // Delay the mismatch count increase for real part, instead increase - // mismatch by 1 for the entire complex number. - num_mismatches_--; - } CompareValues(expected.imag(), actual.imag(), linear_index); - if (mismatches_.data()[linear_index] == true) { - mismatch = true; - // Delay the mismatch count increase for imag part, instead increase - // mismatch by 1 for the entire complex number. + if (num_mismatches_ == both_parts_mismatch) { + // The mismatch counter had been incremented by each CompareValues() call, + // which means that both real and imaginary parts of the passed-in complex + // values are different. However, the counter should reflect a single + // mismatch between these complex values. 
num_mismatches_--; } - if (mismatch == true) { - num_mismatches_++; - } - mismatches_.data()[linear_index] = mismatch; } void CompareValues(complex128 expected, complex128 actual, int64 linear_index) { - bool mismatch = false; + const auto both_parts_mismatch = num_mismatches_ + 2; CompareValues(expected.real(), actual.real(), linear_index); - if (mismatches_.data()[linear_index] == true) { - mismatch = true; - // Delay the mismatch count increase for real part, instead increase - // mismatch by 1 for the entire complex number. - num_mismatches_--; - } CompareValues(expected.imag(), actual.imag(), linear_index); - if (mismatches_.data()[linear_index] == true) { - mismatch = true; - // Delay the mismatch count increase for imag part, instead increase - // mismatch by 1 for the entire complex number. + if (num_mismatches_ == both_parts_mismatch) { + // The mismatch counter had been incremented by each CompareValues() call, + // which means that both real and imaginary parts of the passed-in complex + // values are different. However, the counter should reflect a single + // mismatch between these complex values. num_mismatches_--; } - if (mismatch == true) { - num_mismatches_++; - } - mismatches_.data()[linear_index] = mismatch; } // Compares the two literals elementwise. diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index b54a71ae682..8d46d30b4cf 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 26b029c8d0c..323481455d6 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -136,7 +136,7 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(false); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; - case OPAQUE: + case OPAQUE_TYPE: LOG(FATAL) << "opaque element type cannot take on value of 0"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; @@ -176,7 +176,7 @@ Literal ConvertType(LiteralSlice literal) { LOG(FATAL) << "u16/s16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; - case OPAQUE: + case OPAQUE_TYPE: LOG(FATAL) << "opaque element type cannot take on value of 1"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; @@ -220,7 +220,7 @@ Literal ConvertType(LiteralSlice literal) { static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; - case OPAQUE: + case OPAQUE_TYPE: LOG(FATAL) << "opaque element type has no minimum value"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; @@ -260,7 +260,7 @@ Literal ConvertType(LiteralSlice literal) { static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; - case OPAQUE: + case OPAQUE_TYPE: LOG(FATAL) << "opaque element type has no maximum value"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; diff --git a/tensorflow/compiler/xla/primitive_util.cc 
b/tensorflow/compiler/xla/primitive_util.cc index 1eedddf72c1..2143d1dfbe7 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -89,8 +89,8 @@ int BitWidth(PrimitiveType type) { case TUPLE: LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; - case OPAQUE: - LOG(FATAL) << "OPAQUE is an invalid type for BitWidth"; + case OPAQUE_TYPE: + LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth"; default: LOG(FATAL) << "Unhandled primitive type " << type; @@ -126,17 +126,22 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) { bool IsArrayType(PrimitiveType primitive_type) { return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && - primitive_type != OPAQUE && primitive_type != TOKEN; + primitive_type != OPAQUE_TYPE && primitive_type != TOKEN; } // Class to memoize the computation of // absl::AsciiStrToLower(PrimitiveType_Name(p)) // for all PrimitiveType values "p" +// +// xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason +// it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro. class PrimitiveTypeNameGenerator { public: PrimitiveTypeNameGenerator() { for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { + if (i == static_cast(OPAQUE_TYPE)) { + lowercase_name_[i] = "opaque"; + } else if (PrimitiveType_IsValid(i)) { lowercase_name_[i] = absl::AsciiStrToLower( PrimitiveType_Name(static_cast(i))); } @@ -158,6 +163,9 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { namespace { // Returns a map from lower-case primitive type name to primitive type. +// +// Due to Postel's Law considerations, both "opaque" and "opaque_type" map to +// the xla::OPAQUE_TYPE enumerator. const std::unordered_map& GetPrimitiveTypeStringMap() { static std::unordered_map* name_to_type = [] { static auto* map = new std::unordered_map; @@ -167,6 +175,7 @@ const std::unordered_map& GetPrimitiveTypeStringMap() { (*map)[LowercasePrimitiveTypeName(value)] = value; } } + (*map)["opaque"] = OPAQUE_TYPE; return map; }(); return *name_to_type; diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index ac342bf40fb..e476015f94f 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -38,42 +38,14 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, return (serialized1 == serialized2); } -namespace { - -std::pair>*> -GetDirectoryExpanders() { - static auto* mutex = new tensorflow::mutex; - static auto* singleton = new std::vector>; - return {mutex, singleton}; -} - -// Runs all the directory expanders over x and returns the result. 
-string Expand(string x) { - auto pair = GetDirectoryExpanders(); - tensorflow::mutex_lock lock(*pair.first); - for (const auto& f : *pair.second) { - x = f(x); - } - return x; -} - -} // namespace - Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); - string expanded_dir = Expand(directory); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir)); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name); + const string path = tensorflow::io::JoinPath(directory, safe_file_name); return tensorflow::WriteBinaryProto(env, path, message); } -void RegisterDirectoryExpander(const std::function& expander) { - auto pair = GetDirectoryExpanders(); - tensorflow::mutex_lock lock(*pair.first); - pair.second->push_back(expander); -} - } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 4a88a48f285..e20a7e95a63 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ -#include "google/protobuf/duration.pb.h" #include "absl/time/time.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -45,20 +44,6 @@ Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, // dirpath along as-is. void RegisterDirectoryExpander(const std::function& expander); -// Converts an absl::Duration to a google::protobuf::Duration. -inline google::protobuf::Duration ToDurationProto(absl::Duration duration) { - google::protobuf::Duration proto; - proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); - proto.set_nanos( - absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); - return proto; -} - -// Converts a google::protobuf::Duration to an absl::Duration. 
-inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { - return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); -} - } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 0aed81f1024..45a3a264fd6 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -2,18 +2,20 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins") +load("//tensorflow:tensorflow.bzl", "tf_pybind_extension") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") py_library( name = "xla_client", - srcs = ["xla_client.py"], + srcs = [ + "xla_client.py", + "xrt.py", + ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ - ":pywrap_xla", - ], + deps = [":xla_extension"], ) pyx_library( @@ -26,6 +28,7 @@ py_test( name = "xla_client_test", srcs = ["xla_client_test.py"], main = "xla_client_test.py", + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_oss"], deps = [ @@ -37,49 +40,38 @@ py_test( ) cc_library( - name = "numpy_bridge", - srcs = ["numpy_bridge.cc"], - hdrs = ["numpy_bridge.h"], + name = "worker_thread", + srcs = ["worker_thread.cc"], + hdrs = ["worker_thread.h"], deps = [ - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/python:numpy_lib", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", + "@com_google_absl//absl/synchronization", ], ) cc_library( - name = "local_computation_builder", - srcs = ["local_computation_builder.cc"], - hdrs = ["local_computation_builder.h"], + name = "types", + srcs = ["types.cc"], + hdrs = ["types.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + "-Wno-c++98-c++11-compat", + ], + features = ["-use_header_modules"], deps = [ - "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:qr", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:lib", - "//third_party/python_runtime:headers", # buildcleaner: keep - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", + "//tensorflow/stream_executor:device_memory_allocator", + 
"@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:optional", + "@pybind11", ], ) @@ -87,72 +79,171 @@ cc_library( name = "xrt", srcs = ["xrt.cc"], hdrs = ["xrt.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + "-Wno-c++98-c++11-compat", + ], + features = ["-use_header_modules"], deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", + ":types", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xrt/client:xrt_client", + "//tensorflow/compiler/xrt/client:xrt_grpc_eager_client", + "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + "@pybind11", + ], +) + +cc_library( + name = "shared_device_buffer", + srcs = ["shared_device_buffer.cc"], + hdrs = ["shared_device_buffer.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/stream_executor:device_memory_allocator", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "shared_device_buffer_test", + srcs = ["shared_device_buffer_test.cc"], + deps = [ + ":shared_device_buffer", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "local_client", + srcs = ["local_client.cc"], + hdrs = ["local_client.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + "-Wno-c++98-c++11-compat", + ], + features = ["-use_header_modules"], + deps = [ + ":shared_device_buffer", + ":types", + ":worker_thread", + "//tensorflow/compiler/jit:xla_launch_util", + "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xrt:xrt_proto", - "//tensorflow/compiler/xrt/cc:xrt_ops", - "//tensorflow/core:framework", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:bfc_allocator", + "//tensorflow/core:gpu_mem_allocator", "//tensorflow/core:lib", + "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@pybind11", ], ) -tf_py_wrap_cc( - name = 
"pywrap_xla", +tf_pybind_extension( + name = "xla_extension", srcs = [ - "xla.i", + "xla.cc", ], - swig_includes = [ - "local_computation_builder.i", - "xla_data.i", - "//tensorflow/python:platform/base.i", + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + "-Wno-c++98-c++11-compat", ], - version_script = select({ - "//tensorflow:macos": "pywrap_xla_exported_symbols.lds", - "//tensorflow:windows": None, - "//conditions:default": "pywrap_xla_version_script.lds", - }), + features = ["-use_header_modules"], + module_name = "xla_extension", deps = [ - ":local_computation_builder", - ":numpy_bridge", + ":local_client", + ":shared_device_buffer", + ":types", + ":worker_thread", + ":xrt", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@pybind11", + "//third_party/python_runtime:headers", # buildcleaner: keep "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", + "//tensorflow/compiler/xla/client/lib:svd", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:lib", + # Do NOT remove this dependency. The XLA Python extension must not + # depend on any part of TensorFlow at runtime, **including** + # libtensorflow_framework.so. The XLA module is deployed self-contained + # without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does + # not require Tensorflow. + "//tensorflow/core:lib_internal_impl", # buildcleaner: keep + "//tensorflow/stream_executor:device_memory_allocator", ] + xla_python_default_plugins(), ) -tf_py_wrap_cc( - name = "pywrap_xrt", - srcs = [ - "xrt.i", - ], - swig_includes = [ - "xla_data.i", - "//tensorflow/python:platform/base.i", - ], - version_script = select({ - "//tensorflow:macos": "pywrap_xla_exported_symbols.lds", - "//tensorflow:windows": None, - "//conditions:default": "pywrap_xla_version_script.lds", - }), - visibility = ["//visibility:public"], - deps = [ - ":numpy_bridge", - ":xrt", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_data_proto", - ], -) +# TODO(phawkins): enable this test. 
+# py_test( +# name = "xrt_test", +# srcs = ["xrt_test.py"], +# deps = [ +# ":xla_client", +# "//third_party/py/numpy", +# "//tensorflow/compiler/jit:xla_cpu_device", +# "//tensorflow/compiler/xrt:xrt_server", +# "//tensorflow/python:client_testlib", +# ], +# ) diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc new file mode 100644 index 00000000000..facc61d515d --- /dev/null +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -0,0 +1,789 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation notes: +// +// Asynchronous execution: +// ----------------------- +// +// If 'asynchronous' is set when constructing the client, computations and +// host-to-device transfers do not block the host waiting for the operation to +// complete but instead return control to the host immediately. This allows +// Python logic to overlap with device-side computation. +// +// For a good user experience, we must be careful only to enqueue operations +// that are unlikely to fail; as a rule error checking must be done eagerly +// before returning control to the client. +// +// Multi-stream execution: +// ----------------------- +// +// On certain platforms (e.g., TPU), we use a multistream execution design, +// where different Streams are used for host-to-device transfers, +// device-to-host transfers, and compute. This allows us to overlap transfers on +// and off the device with computation. +// +// Synchronization between streams occurs via BufferDefinitionEvents that +// describe when the contents of a logical buffer are known to be valid on +// a particular stream. +// +// Synchronous vs asynchronous deallocation: +// ----------------------------------------- +// +// In asynchronous deallocation mode (currently only enabled on TPU), the client +// need only keep buffers alive from its perspective until all operations that +// touch those buffers have been enqueued. +// The allocator and lower-level runtime is responsible for keeping buffers +// alive (if that is needed) from the perspective of the device until any +// device-side work actually completes. The client's use of the device allocator +// thereby corresponds to a view of the tail of the compute stream instead of +// its head. +// +// In synchronous deallocation mode the client is responsible for keeping +// buffers alive until all device-side activity that consumes those buffers has +// ceased. This is the case for CPU since HostExecutor performs allocation +// and deallocation eagerly. In this mode, the client's use of the device +// allocator is logically synchronized to the head of the compute stream, not +// the tail. 
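The multi-stream design described in the notes above hinges on definition events: consumers on other streams wait until the stream that produced a buffer has recorded that its contents are valid. Below is a minimal, self-contained sketch of that idea using only standard-library primitives; the `DefinitionEvent` class and the producer/consumer threads are illustrative stand-ins, not the XLA `BufferDefinitionEvent`/`se::Stream` types.

```c++
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Stand-in for a BufferDefinitionEvent: "recorded" once the buffer is valid.
class DefinitionEvent {
 public:
  // Called by the producing "stream" once the buffer contents are valid.
  void Record() {
    std::lock_guard<std::mutex> lock(mu_);
    recorded_ = true;
    cv_.notify_all();
  }
  // Called by a consuming "stream" before it reads the buffer.
  void WaitForEvent() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return recorded_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool recorded_ = false;
};

int main() {
  std::vector<float> buffer(4);
  DefinitionEvent event;

  // Host-to-device "stream": fills the buffer, then records the event.
  std::thread h2d([&] {
    for (std::size_t i = 0; i < buffer.size(); ++i) {
      buffer[i] = static_cast<float>(i);
    }
    event.Record();
  });

  // Compute "stream": waits for the definition event before consuming.
  std::thread compute([&] {
    event.WaitForEvent();
    float sum = 0;
    for (float v : buffer) sum += v;
    std::cout << "sum = " << sum << "\n";  // prints 6
  });

  h2d.join();
  compute.join();
  return 0;
}
```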
+ +#include "tensorflow/compiler/xla/python/local_client.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/blocking_counter.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "include/pybind11/pybind11.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/bfc_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +namespace xla { + +namespace py = pybind11; + +// Registers a 'fn_capsule' as a CPU custom call target. +// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name +// "xla._CPU_CUSTOM_CALL_TARGET". +Status RegisterCpuCustomCallTarget(const std::string& fn_name, + py::capsule capsule) { + static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET"; + if (absl::string_view(capsule.name()) != kName) { + return InvalidArgument( + "Argument to RegisterCpuCustomCallTargetRegistry was not a " + "xla._CPU_CUSTOM_CALL_TARGET capsule."); + } + CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule), "Host"); + return Status::OK(); +} + +PythonRefManager::ManagedPyObjects::ManagedPyObjects( + PythonRefManager* manager, absl::Span objects) + : manager_(manager) { + objects_.reserve(objects.size()); + for (pybind11::object& object : objects) { + objects_.push_back(std::move(object)); + } +} + +PythonRefManager::ManagedPyObjects::~ManagedPyObjects() { + if (manager_) { + absl::MutexLock lock(&manager_->mu_); + for (pybind11::object& object : objects_) { + manager_->python_garbage_.push_back(std::move(object)); + } + } +} + +PythonRefManager::ManagedPyObjects PythonRefManager::ManageReferences( + absl::Span objects) { + return ManagedPyObjects(this, objects); +} + +void PythonRefManager::CollectGarbage() { + // TODO(phawkins): ideally we would assert that the GIL is held, but there is + // no API to do this across all Python versions. 
+ absl::MutexLock lock(&mu_); + python_garbage_.clear(); +} + +Device::Device(se::StreamExecutor* executor, bool use_multiple_streams, + bool synchronous_deallocation, bool asynchronous) + : use_multiple_streams_(use_multiple_streams), + synchronous_deallocation_(synchronous_deallocation), + asynchronous_(asynchronous) { + compute_stream_ = std::make_shared(executor); + compute_stream_->Init(); + if (use_multiple_streams) { + host_to_device_stream_ = std::make_shared(executor); + device_to_host_stream_ = std::make_shared(executor); + callback_stream_ = std::make_shared(executor); + host_to_device_stream_->Init(); + device_to_host_stream_->Init(); + callback_stream_->Init(); + } else { + callback_stream_ = host_to_device_stream_ = device_to_host_stream_ = + compute_stream_; + } + worker_thread_ = absl::make_unique(tensorflow::Env::Default(), + "py_xla_execute"); +} + +Device::~Device() { + bool ok = compute_stream_->parent()->SynchronizeAllActivity(); + if (!ok) { + LOG(ERROR) << "SynchronizeAllActivity failed when destroying Device."; + } +} + +void Device::ThenExecuteOnWorkerThread(se::Stream* stream, + std::function callback) const { + stream->ThenDoHostCallback( + [this, callback]() { worker_thread_->Schedule(std::move(callback)); }); +} + +static StatusOr> +CreateBFCAllocator(se::Platform* platform, LocalClient* client, + double memory_fraction) { + CHECK_GT(client->backend().device_count(), 0); + std::vector> allocators; + for (se::StreamExecutor* executor : client->backend().stream_executors()) { + int device_ordinal = executor->device_ordinal(); + tensorflow::GPUMemAllocator* sub_allocator = + new tensorflow::GPUMemAllocator( + executor, tensorflow::PlatformGpuId(device_ordinal), + /*use_unified_memory=*/false, /*alloc_visitors=*/{}, + /*free_visitors=*/{}); + + int64 free_memory; + int64 total_memory; + if (!executor->DeviceMemoryUsage(&free_memory, &total_memory)) { + return Unavailable("Failed to query available memory from device %i", + device_ordinal); + } + size_t allocator_memory = free_memory * memory_fraction; + LOG(INFO) << "XLA backend reserving " << allocator_memory << " out of " + << total_memory << " bytes on device " << device_ordinal + << " for BFCAllocator."; + + tensorflow::BFCAllocator* gpu_bfc_allocator = new tensorflow::BFCAllocator( + sub_allocator, allocator_memory, /*allow_growth=*/false, + absl::StrCat("GPU_", device_ordinal, "_bfc")); + allocators.emplace_back(gpu_bfc_allocator); + } + return absl::make_unique( + platform, std::move(allocators)); +} + +StatusOr> PyLocalClient::Get( + const std::string& platform_name, const std::string& xla_platform_name, + bool asynchronous, const AllocatorConfig& allocator_config) { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(xla_platform_name)); + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("Platform %s (%s) has no visible devices.", + platform_name, xla_platform_name); + } + LocalClientOptions options; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + std::unique_ptr allocator; + if (allocator_config.kind == AllocatorConfig::Kind::kBFC || + (platform_name == "gpu" && + allocator_config.kind == AllocatorConfig::Kind::kDefault)) { + if (platform_name != "gpu") { + return Unimplemented("BFCAllocator only available for GPU."); + } + TF_ASSIGN_OR_RETURN( + auto bfc_allocator, + CreateBFCAllocator(platform, client, allocator_config.memory_fraction)); + allocator = std::move(bfc_allocator); + } 
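CreateBFCAllocator above sizes each per-device BFC arena as a fixed fraction of the memory the device currently reports as free. A standalone sketch of that arithmetic follows; the numbers are made up for illustration and stand in for the values the real code obtains from DeviceMemoryUsage() and AllocatorConfig.

```c++
#include <cstddef>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t free_memory = 8ll * 1024 * 1024 * 1024;    // 8 GiB reported free
  const int64_t total_memory = 16ll * 1024 * 1024 * 1024;  // 16 GiB on the device
  const double memory_fraction = 0.9;                      // AllocatorConfig default

  // Reserve a fraction of the free memory for the BFC arena.
  const std::size_t allocator_memory =
      static_cast<std::size_t>(free_memory * memory_fraction);
  std::cout << "Reserving " << allocator_memory << " out of " << total_memory
            << " bytes for the BFC arena\n";
  return 0;
}
```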
+ return std::make_shared(platform_name, client, + std::move(allocator), asynchronous); +} + +PyLocalClient::PyLocalClient( + std::string platform_name, LocalClient* client, + std::unique_ptr allocator, bool asynchronous) + : platform_name_(std::move(platform_name)), + client_(client), + owned_allocator_(std::move(allocator)), + h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer", + client->device_count()) { + if (owned_allocator_ != nullptr) { + allocator_ = owned_allocator_.get(); + } else { + allocator_ = client_->backend().memory_allocator(); + } + devices_.reserve(client->device_count()); + // TODO(phawkins): enable multistream mode on GPU too. + bool use_multiple_streams = (platform_name == "tpu"); + bool synchronous_deallocation = !use_multiple_streams; + for (int i = 0; i < client->device_count(); ++i) { + se::StreamExecutor* executor = + client_->backend().stream_executor(i).ValueOrDie(); + devices_.push_back(absl::make_unique(executor, use_multiple_streams, + synchronous_deallocation, + asynchronous)); + } +} + +Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal, + int device_ordinal) { + py_ref_manager().CollectGarbage(); + py::gil_scoped_release gil_release; + return client_->TransferToInfeedLocal(literal, device_ordinal); +} + +StatusOr PyLocalClient::TransferFromOutfeed( + const Shape& shape, int device_ordinal) { + py_ref_manager().CollectGarbage(); + Literal literal; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + literal, client_->TransferFromOutfeedLocal(shape, device_ordinal)); + } + return LiteralToPython(absl::make_unique(std::move(literal))); +} + +static StatusOr TransferHostToDeviceAsync( + const PythonBufferTree& tree, int device_ordinal, + std::shared_ptr client, const Device& device) { + se::DeviceMemoryAllocator* allocator = client->allocator(); + TransferManager* transfer_manager = + client->client()->backend().transfer_manager(); + TF_ASSIGN_OR_RETURN( + Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + transfer_manager->AllocateScopedShapedBuffer( + shape, allocator, device_ordinal)); + TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( + device.host_to_device_stream(), buffer)); + + auto it = tree.leaves.begin(); + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(shape)) { + TF_RET_CHECK(it != tree.leaves.end()); + ShapedBuffer leaf( + indexed_shape.shape, + transfer_manager->HostShapeToDeviceShape(indexed_shape.shape), + client->client()->platform(), device_ordinal); + leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {}); + if (device.use_multiple_streams() && + !transfer_manager->CanShapedBufferBeAccessedNow( + device.host_to_device_stream()->parent(), leaf)) { + device.host_to_device_stream()->ThenWaitFor(device.compute_stream()); + } + TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync( + device.host_to_device_stream(), *it, leaf)); + ++it; + } + std::shared_ptr definition_event; + if (device.use_multiple_streams()) { + definition_event = std::make_shared( + device.host_to_device_stream()->parent()); + definition_event->RecordOnStream(device.host_to_device_stream()); + } + std::shared_ptr device_buffer = + PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(buffer), + definition_event); + if (device.synchronous_deallocation()) { + device.ThenReleaseOnWorkerThread(device.host_to_device_stream(), + device_buffer); + } + return PyLocalBuffer(shape, 
std::move(device_buffer), std::move(client)); +} + +/* static */ +StatusOr PyLocalBuffer::FromPython( + const py::object& argument, std::shared_ptr client, + int device_ordinal) { + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython"); + TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument)); + + client->py_ref_manager().CollectGarbage(); + + // Take a reference to the buffer to ensure that the inputs in host memory + // remain live until the transfer is complete. + auto py_buffer_ref = + client->py_ref_manager().ManageReferences(absl::MakeSpan(tree.arrays)); + + // We are done manipulating Python objects; release the GIL. + py::gil_scoped_release gil_release; + VLOG(1) << "PyLocalBuffer::FromPython: shape: " << tree.shape.ToString() + << " device ordinal: " << device_ordinal; + + const Device& device = client->device(device_ordinal); + TF_ASSIGN_OR_RETURN(PyLocalBuffer buffer, + TransferHostToDeviceAsync(tree, device_ordinal, + std::move(client), device)); + + device.ThenRelease(device.host_to_device_stream(), std::move(py_buffer_ref)); + return buffer; +} + +/*static */ StatusOr> +PyLocalBuffer::FromPythonValues( + const std::vector>& arguments, + std::shared_ptr client) { + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPythonValues"); + int num_arguments = static_cast(arguments.size()); + std::vector outputs(num_arguments); + if (num_arguments == 0) { + return outputs; + } + + struct H2DTransfer { + PythonBufferTree tree; + StatusOr buffer; + PythonRefManager::ManagedPyObjects py_buffer_refs; + }; + + std::vector transfers(num_arguments); + for (int i = 0; i < num_arguments; ++i) { + TF_ASSIGN_OR_RETURN(transfers[i].tree, + GetPythonBufferTree(arguments[i].first)); + transfers[i].py_buffer_refs = client->py_ref_manager().ManageReferences( + absl::MakeSpan(transfers[i].tree.arrays)); + } + client->py_ref_manager().CollectGarbage(); + // We are done manipulating Python objects; release the GIL. + py::gil_scoped_release gil_release; + + auto transfer_h2d = [&](int i) -> StatusOr { + int device_ordinal = arguments[i].second; + return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client, + client->device(device_ordinal)); + }; + + // We perform the transfers on a thread pool in case XLA needs to do any + // host-side preprocessing of the input data. + if (num_arguments == 1) { + transfers[0].buffer = transfer_h2d(0); + } else { + absl::BlockingCounter counter(num_arguments); + for (int i = 0; i < num_arguments; ++i) { + client->h2d_transfer_pool()->Schedule([&, i]() { + transfers[i].buffer = transfer_h2d(i); + counter.DecrementCount(); + }); + } + counter.Wait(); + } + + // Release our references once the transfers have completed. 
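FromPythonValues above fans the per-argument transfers out to a thread pool and blocks until every transfer_h2d call has run, using absl::BlockingCounter. The self-contained sketch below shows that fan-out-and-wait pattern; the tiny BlockingCounter class is an illustrative stand-in for the absl type, and squaring an integer stands in for one host-to-device transfer.

```c++
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Minimal stand-in for absl::BlockingCounter.
class BlockingCounter {
 public:
  explicit BlockingCounter(int n) : count_(n) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

int main() {
  const int num_arguments = 4;
  std::vector<int> results(num_arguments, 0);
  BlockingCounter counter(num_arguments);

  // Schedule one job per argument, as the transfer pool does.
  std::vector<std::thread> pool;
  for (int i = 0; i < num_arguments; ++i) {
    pool.emplace_back([&, i] {
      results[i] = i * i;        // stand-in for one host-to-device transfer
      counter.DecrementCount();
    });
  }
  counter.Wait();                // block until every job has finished
  for (std::thread& t : pool) t.join();

  for (int r : results) std::cout << r << " ";  // prints 0 1 4 9
  std::cout << "\n";
  return 0;
}
```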
+ for (int i = 0; i < num_arguments; ++i) { + int device_ordinal = arguments[i].second; + const Device& device = client->device(device_ordinal); + device.ThenRelease(device.host_to_device_stream(), + std::move(transfers[i].py_buffer_refs)); + } + + for (int i = 0; i < num_arguments; ++i) { + TF_ASSIGN_OR_RETURN(outputs[i], std::move(transfers[i].buffer)); + } + return outputs; +} + +/* static */ StatusOr PyLocalBuffer::MakeTuple( + const std::vector buffers, + std::shared_ptr client, int device_ordinal) { + std::vector host_shapes; + std::vector> device_buffers; + host_shapes.reserve(buffers.size()); + device_buffers.reserve(buffers.size()); + for (const PyLocalBuffer& buffer : buffers) { + TF_RET_CHECK(buffer.device_buffer()->device_memory().device_ordinal() == + device_ordinal); + host_shapes.push_back(buffer.on_host_shape()); + device_buffers.push_back(buffer.device_buffer()); + } + se::DeviceMemoryAllocator* allocator = client->allocator(); + TransferManager* transfer_manager = + client->client()->backend().transfer_manager(); + const Device& device = client->device(device_ordinal); + std::shared_ptr definition_event; + if (device.use_multiple_streams()) { + definition_event = std::make_shared( + device.host_to_device_stream()->parent()); + } + TF_ASSIGN_OR_RETURN(std::shared_ptr tuple_buffer, + PySharedDeviceBuffer::MakeTuple( + device_buffers, transfer_manager, allocator, + device_ordinal, definition_event)); + PyLocalBuffer buffer(ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, + std::move(client)); + + // TODO(phawkins): extend TransferManager so we do not need to form a full + // ShapedBuffer just to write the root tuple index table. + ShapedBuffer shaped_buffer = buffer.AsShapedBuffer(); + if (device.use_multiple_streams() && + !transfer_manager->CanShapedBufferBeAccessedNow( + device.host_to_device_stream()->parent(), shaped_buffer)) { + // Wait for the compute stream so that memory allocations are synchronized. 
+ device.host_to_device_stream()->ThenWaitFor(device.compute_stream()); + } + TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable( + device.host_to_device_stream(), shaped_buffer)); + if (definition_event) { + definition_event->RecordOnStream(device.host_to_device_stream()); + } + + if (device.synchronous_deallocation()) { + device.ThenReleaseOnWorkerThread(device.host_to_device_stream(), + std::move(tuple_buffer)); + } + return buffer; +} + +PyLocalBuffer::PyLocalBuffer( + Shape on_host_shape, std::shared_ptr device_buffer, + std::shared_ptr client) + : client_(std::move(client)), + on_host_shape_(std::move(on_host_shape)), + device_buffer_(std::move(device_buffer)) {} + +StatusOr PyLocalBuffer::ToPython() const { + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython"); + auto literal = absl::make_unique(on_host_shape()); + client_->py_ref_manager().CollectGarbage(); + { + py::gil_scoped_release gil_release; + se::Stream* stream = client_->device(device_buffer_->device_ordinal()) + .device_to_host_stream(); + WaitForBufferDefinitionEventsOnStream(*device_buffer_, stream); + absl::Notification done; + Status status; + client_->client()->backend().transfer_manager()->TransferLiteralFromDevice( + stream, AsShapedBuffer(), *literal, [&](Status done_status) { + status = done_status; + done.Notify(); + }); + done.WaitForNotification(); + } + return LiteralToPython(std::move(literal)); +} + +ShapedBuffer PyLocalBuffer::AsShapedBuffer() const { + return device_buffer_->AsShapedBuffer(on_host_shape_); +} + +StatusOr> PyLocalBuffer::DestructureTuple() { + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::DestructureTuple"); + if (!on_host_shape().IsTuple()) { + return InvalidArgument( + "Attemped to destructure a PyLocalBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(on_host_shape())); + } + int num_children = ShapeUtil::TupleElementCount(on_host_shape()); + std::vector results; + results.reserve(num_children); + for (int64 i = 0; i < num_children; ++i) { + results.push_back(PyLocalBuffer(on_host_shape().tuple_shapes(i), + device_buffer_->children().at(i), client_)); + } + return results; +} + +PyLocalExecutable::PyLocalExecutable( + std::shared_ptr executable, + DeviceAssignment device_assignment, std::shared_ptr client) + : client_(std::move(client)), + executable_(std::move(executable)), + device_assignment_(std::move(device_assignment)) {} + +std::vector PyLocalExecutable::DeviceOrdinals() const { + int num_replicas = device_assignment_.replica_count(); + std::vector device_ordinals; + device_ordinals.reserve(num_replicas); + for (int i = 0; i < num_replicas; ++i) { + device_ordinals.push_back(device_assignment_(i, 0)); + } + return device_ordinals; +} + +StatusOr PyLocalExecutable::ExecuteHelper( + absl::Span argument_handles, int replica) { + const int device_ordinal = device_assignment_(replica, 0); + tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " << device_ordinal; + + absl::flat_hash_set events; + std::vector argument_buffers; + std::vector argument_buffer_ptrs; + argument_buffers.reserve(argument_handles.size()); + argument_buffer_ptrs.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + if (handle->device_buffer() == nullptr) { + return InvalidArgument( + "Deleted buffer passed to Execute() as argument " + "%d to replica %d", + argument_buffers.size(), replica); + } + if 
(handle->device_buffer()->device_ordinal() != device_ordinal) { + return InvalidArgument( + "Buffer passed to Execute() as argument %d to replica %d is on " + "device %d, but replica is assigned to device %d.", + argument_buffers.size(), replica, + handle->device_buffer()->device_ordinal(), device_ordinal); + } + argument_buffers.push_back(handle->AsShapedBuffer()); + argument_buffer_ptrs.push_back(&argument_buffers.back()); + GetDeviceBufferDefinitionEvents(*handle->device_buffer(), &events); + VLOG(4) << "Argument " << argument_buffers.size() - 1 + << " buffer: " << argument_buffers.back().ToString(); + } + + const Device& device = client_->device(device_ordinal); + // The choice of where we wait in "synchronous" mode is arbitrary; the reason + // for the wait is pacing to avoid problems such as memory fragmentation, not + // for correctness. + if (!device.asynchronous()) { + TF_RETURN_IF_ERROR(device.compute_stream()->BlockHostUntilDone()); + } + + for (BufferDefinitionEvent* event : events) { + event->WaitForEventOnStream(device.compute_stream()); + } + + ExecutableRunOptions options; + options.set_stream(device.compute_stream()); + options.set_host_to_device_stream(device.host_to_device_stream()); + options.set_allocator(client_->allocator()); + options.set_intra_op_thread_pool( + client_->client()->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); + + StatusOr result_buffer = + executable_->RunAsync(argument_buffer_ptrs, options); + + VLOG(1) << "Replica " << replica << " completed; ok=" << result_buffer.ok(); + if (!result_buffer.ok()) { + LOG(ERROR) << "Execution of replica " << replica + << " failed: " << result_buffer.status(); + return result_buffer.status(); + } + + std::shared_ptr definition_event; + if (device.use_multiple_streams()) { + definition_event = std::make_shared( + device.compute_stream()->parent()); + definition_event->RecordOnStream(device.compute_stream()); + } + Shape on_host_shape = result_buffer.ValueOrDie().on_host_shape(); + std::shared_ptr out_buffer = + PySharedDeviceBuffer::FromScopedShapedBuffer( + std::move(result_buffer.ValueOrDie()), definition_event); + + if (device.synchronous_deallocation()) { + std::vector> buffers; + buffers.reserve(argument_handles.size() + 1); + for (auto& handle : argument_handles) { + buffers.push_back(handle->device_buffer()); + } + buffers.push_back(out_buffer); + device.ThenReleaseOnWorkerThread(device.compute_stream(), + std::move(buffers)); + device.ThenReleaseOnWorkerThread(device.compute_stream(), executable_); + } + return PyLocalBuffer(on_host_shape, std::move(out_buffer), client_); +} + +StatusOr PyLocalExecutable::Execute( + absl::Span argument_handles) { + if (num_replicas() != 1) { + return InvalidArgument( + "Attempted to execute computation with %d replicas using Execute()", + num_replicas()); + } + return ExecuteHelper(argument_handles, /*replica=*/0); +} + +StatusOr> PyLocalExecutable::ExecutePerReplica( + absl::Span> argument_handles) { + tensorflow::profiler::TraceMe traceme("LocalExecutable::ExecutePerReplica"); + const int num_devices = client_->device_count(); + + if (argument_handles.size() != num_replicas()) { + return InvalidArgument( + "Attempted to execute with %d replicas when replica count is %d", + argument_handles.size(), num_devices); + } + if (argument_handles.size() > num_devices) { + return InvalidArgument( + "Attempted to execute with %d replicas when device count is %d", + argument_handles.size(), num_devices); + } + + VLOG(1) << 
"Executing replicated computation; num_replicas=" + << num_replicas(); + std::vector> results(num_replicas()); + if (num_replicas() == 1) { + // Fast-path if there is only one replica — run the computation on the + // current thread. + results[0] = ExecuteHelper(argument_handles[0], /*replica=*/0); + } else { + absl::Mutex mu; + int running GUARDED_BY(mu) = num_replicas(); + int failed GUARDED_BY(mu) = 0; + Status first_failure_status GUARDED_BY(mu); + + for (int replica = 0; replica < num_replicas(); ++replica) { + const int device_ordinal = device_assignment_(replica, 0); + const Device& device = client_->device(device_ordinal); + device.worker_thread()->Schedule([&, replica] { + results[replica] = ExecuteHelper(argument_handles[replica], replica); + + absl::MutexLock lock(&mu); + --running; + if (!results[replica].ok()) { + if (failed == 0) { + first_failure_status = results[replica].status(); + } + ++failed; + } + }); + } + + auto done_running_or_failed = [&]() { + mu.AssertHeld(); + return running == 0 || failed > 0; + }; + absl::MutexLock lock(&mu); + mu.Await(absl::Condition(&done_running_or_failed)); + if (failed > 0) { + auto done_running = [&]() { + mu.AssertHeld(); + return running == 0; + }; + // If execution does not terminate within a reasonable amount of time, we + // may be stuck at a cross-replica barrier on-device. Terminate the + // process since that's the only way we can escape this situation at the + // moment (b/130629719). + if (!mu.AwaitWithTimeout(absl::Condition(&done_running), + absl::Seconds(10))) { + LOG(FATAL) + << "Replicated computation launch failed, but not all replicas " + "terminated. Aborting process to work around deadlock. Failure " + "message (there may have been multiple failures, see the " + "error log for all failures): \n\n" + << first_failure_status.error_message(); + } + } + } + VLOG(1) << "Replicated execution complete."; + + std::vector wrapped_results(num_replicas()); + for (int replica = 0; replica < num_replicas(); ++replica) { + auto& statusor = results[replica]; + if (!statusor.ok()) { + return AppendStatus( + statusor.status(), + absl::StrFormat( + "while running replica %d of a replicated computation (other " + "replicas may have failed as well).", + replica)); + } + wrapped_results[replica] = std::move(statusor.ValueOrDie()); + } + return wrapped_results; +} + +/*static*/ StatusOr> +PyLocalExecutable::Compile(const XlaComputation& computation, + std::vector argument_layouts, + const ExecutableBuildOptions* build_options, + std::shared_ptr client) { + tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile"); + std::vector argument_layout_pointers; + argument_layout_pointers.reserve(argument_layouts.size()); + + // Assign a default layout to any array subshapes that are missing layouts. 
+ auto assign_layouts = [client](Shape* shape) { + return ShapeUtil::ForEachMutableSubshapeWithStatus( + shape, [&](Shape* subshape, const ShapeIndex&) { + if (subshape->IsArray() && !subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + TF_ASSIGN_OR_RETURN(*subshape, + client->client() + ->backend() + .transfer_manager() + ->ChooseCompactLayoutForShape(*subshape)); + } + return Status::OK(); + }); + }; + + for (Shape& layout : argument_layouts) { + argument_layout_pointers.push_back(&layout); + TF_RETURN_IF_ERROR(assign_layouts(&layout)); + } + + ExecutableBuildOptions options; + if (build_options != nullptr) { + options = *build_options; + } + + Shape result_layout; + if (options.result_layout()) { + result_layout = *options.result_layout(); + } else { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + result_layout = program_shape.result(); + LayoutUtil::ClearLayout(&result_layout); + } + TF_RETURN_IF_ERROR(assign_layouts(&result_layout)); + options.set_result_layout(result_layout); + + TF_ASSIGN_OR_RETURN(std::unique_ptr local_executable, + client->client()->Compile( + computation, argument_layout_pointers, options)); + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client->client()->backend().computation_placer()->AssignDevices( + options.num_replicas(), /*computation_count=*/1)); + + return absl::make_unique( + std::shared_ptr(std::move(local_executable)), + std::move(device_assignment), std::move(client)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h new file mode 100644 index 00000000000..1ad0f933007 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_client.h @@ -0,0 +1,337 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "include/pybind11/pybind11.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include "tensorflow/compiler/xla/python/worker_thread.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Registers a 'fn_capsule' as a CPU custom call target. +// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name +// "xla._CPU_CUSTOM_CALL_TARGET". 
+Status RegisterCpuCustomCallTarget(const std::string& fn_name, + pybind11::capsule capsule); + +// Class that manages destruction of Python objects. +// +// We must not destroy Python objects without holding the GIL. However, we +// frequently want to hold references to Python objects for the duration of +// an asynchronous transfer on a Stream, and release our reference when the +// transfer completes. +// +// This class holds references to Python objects outside a GIL scope, that can +// be collected later when the GIL is held by calling CollectGarbage(). +class PythonRefManager { + public: + PythonRefManager() = default; + + // Holds references to a set of pybind11::objects, adding the references to + // the PythonRefManager on destruction. + class ManagedPyObjects { + public: + ManagedPyObjects() = default; + ManagedPyObjects(PythonRefManager* manager, + absl::Span objects); + + ~ManagedPyObjects(); + + ManagedPyObjects(const ManagedPyObjects& other) = default; + ManagedPyObjects(ManagedPyObjects&& other) = default; + ManagedPyObjects& operator=(const ManagedPyObjects& other) = default; + ManagedPyObjects& operator=(ManagedPyObjects&& other) = default; + + private: + PythonRefManager* manager_ = nullptr; + absl::InlinedVector objects_; + }; + + // Creates a managed std::shared_ptr to an object. When the shared_ptr is + // destroyed, the reference to 'object' will be added to python_garbage_, + // and collected next time CollectGarbage() is called. + ManagedPyObjects ManageReferences(absl::Span objects); + + // Releases the contents of python_garbage_. Requires that the GIL is held. + // The client calls this method during API entry points where the GIL is held + // to free any garbage that has accumulated. + void CollectGarbage(); + + private: + absl::Mutex mu_; + std::deque python_garbage_ GUARDED_BY(mu_); +}; + +// Class that encapsulates state relating to a device (e.g., a GPU) on which we +// can perform computation and transfers. +class Device { + public: + // If use_multiple_streams is true, we allocate separate streams for compute + // and transfers. If it is false, we share a single stream for compute and + // transfers. The CPU device does not support multiple streams, and this is + // a workaround until it does. + // + // If synchronous_deallocation is true, the host must not free buffers until + // compute/transfers that use those buffers have completed. For example, this + // typically is the case for the "platform" where compute/transfers are + // operations that take place on another thread. + // + // If asynchronous is false, the host will synchronize to the device after + // each execution or transfer. This is intended for debugging only. + Device(se::StreamExecutor* executor, bool use_multiple_streams, + bool synchronous_deallocation, bool asynchronous); + ~Device(); + + bool use_multiple_streams() const { return use_multiple_streams_; } + bool synchronous_deallocation() const { return synchronous_deallocation_; } + bool asynchronous() const { return asynchronous_; } + se::Stream* compute_stream() const { return compute_stream_.get(); } + se::Stream* host_to_device_stream() const { + return host_to_device_stream_.get(); + } + se::Stream* device_to_host_stream() const { + return device_to_host_stream_.get(); + } + + // A worker thread, used for replicated computation launches and callbacks. + WorkerThread* worker_thread() const { return worker_thread_.get(); } + + // Enqueues a host callback on 'stream', to be executed by worker_thread_. 
+ // ThenDoHostCallback is often constrained in what it can do, in particular, + // on GPU the callback runs on a thread belonging to the GPU runtime and + // cannot perform GPU operations itself. + void ThenExecuteOnWorkerThread(se::Stream* stream, + std::function callback) const; + + // Helper for releasing values from a callback at the tail of a stream. + // This is only permitted if object's destructor will not free any device + // objects, since the callback may be called from a device thread pool on + // GPU. + template + void ThenRelease(se::Stream* stream, T object) const { + if (callback_stream_.get() != stream) { + callback_stream_->ThenWaitFor(stream); + } + callback_stream_->ThenDoHostCallback( + std::bind([](T& object) { /* releases object */ }, std::move(object))); + } + + // Helpers for releasing values on a worker thread at the tail of a stream on + // a worker thread. + template + void ThenReleaseOnWorkerThread(se::Stream* stream, + std::shared_ptr object) const { + // We use a non-smart pointer here because we want to ensure that the worker + // thread is the only callee of the shared_ptr destructor, and if we passed + // object by lambda capture we have a race where the worker thread might + // run and release its reference first. + auto* ref = new std::shared_ptr(std::move(object)); + if (callback_stream_.get() != stream) { + callback_stream_->ThenWaitFor(stream); + } + ThenExecuteOnWorkerThread(callback_stream_.get(), [ref]() { delete ref; }); + } + template + void ThenReleaseOnWorkerThread(se::Stream* stream, + std::vector> object) const { + auto* ref = new std::vector>(std::move(object)); + if (callback_stream_.get() != stream) { + callback_stream_->ThenWaitFor(stream); + } + ThenExecuteOnWorkerThread(callback_stream_.get(), [ref]() { delete ref; }); + } + + private: + bool use_multiple_streams_; + bool synchronous_deallocation_; + bool asynchronous_; + std::shared_ptr compute_stream_; + std::shared_ptr host_to_device_stream_; + std::shared_ptr device_to_host_stream_; + + // Callback stream is used for running short host-side callbacks after device + // side events, without preventing the device-side stream from doing useful + // work. + std::shared_ptr callback_stream_; + + std::unique_ptr worker_thread_; +}; + +struct AllocatorConfig { + enum class Kind { + kDefault, // Client picks the best option for the platform. + kPlatform, // The platform's default. + kBFC, // Allocator using a "Best-Fit with Coalescing" algorithm. Currently + // only available for GPU. + }; + Kind kind = Kind::kDefault; + + // Only used if kind == kBFC. Fraction of available memory to allocate. + double memory_fraction = .9; +}; + +// Encapsulates the state of Python session with XLA. +class PyLocalClient { + public: + // Initializes a local XLA client for `platform_name`. Returns an error if no + // such platform exists, or if the platform has no visible devices. + static StatusOr> Get( + const std::string& platform_name, const std::string& xla_platform_name, + bool asynchronous, const AllocatorConfig& allocator_config); + + // `allocator` may null, in which case the platform default allocator is used. 
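The ThenReleaseOnWorkerThread helpers above deliberately move the shared_ptr into a heap-allocated copy so that the callback thread, not the enqueuing thread, drops the last reference. A minimal sketch of that mechanism, using a plain std::thread as a stand-in for the stream callback; none of the names below are XLA APIs.

```c++
#include <iostream>
#include <memory>
#include <thread>

struct DeviceBuffer {
  ~DeviceBuffer() { std::cout << "buffer released\n"; }
};

int main() {
  auto buffer = std::make_shared<DeviceBuffer>();

  // Move the caller's reference into a heap-allocated shared_ptr. Whichever
  // thread deletes `ref` drops the last reference and frees the buffer, so
  // deleting it only from the worker guarantees where the release happens.
  auto* ref = new std::shared_ptr<DeviceBuffer>(std::move(buffer));

  // Stand-in for a host callback enqueued at the tail of a stream.
  std::thread worker([ref] { delete ref; });
  worker.join();
  return 0;
}
```

In the real code the deletion happens inside a host callback on callback_stream_, which is why the comments above stress that the object's destructor must not itself free device objects.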
+ explicit PyLocalClient(std::string platform_name, LocalClient* client, + std::unique_ptr allocator, + bool asynchronous); + virtual ~PyLocalClient() = default; + + Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal); + StatusOr TransferFromOutfeed(const Shape& shape, + int device_ordinal); + + int device_count() const { return client_->device_count(); } + const Device& device(int device_ordinal) const { + return *devices_.at(device_ordinal); + } + LocalClient* client() const { return client_; } + se::DeviceMemoryAllocator* allocator() const { return allocator_; } + + tensorflow::thread::ThreadPool* h2d_transfer_pool() { + return &h2d_transfer_pool_; + } + + PythonRefManager& py_ref_manager() { return py_ref_manager_; } + + protected: + std::string platform_name_; + LocalClient* client_; + std::vector> devices_; + se::DeviceMemoryAllocator* allocator_; + std::unique_ptr owned_allocator_; + + tensorflow::thread::ThreadPool h2d_transfer_pool_; + + PythonRefManager py_ref_manager_; +}; + +// Holds a reference from Python to one or more device buffers. +class PyLocalBuffer { + public: + static StatusOr FromPython( + const pybind11::object& argument, std::shared_ptr client, + int device_ordinal); + + // Converts multiple (python object, device ordinal) pairs into + // PyLocalBuffers in parallel. + static StatusOr> FromPythonValues( + const std::vector>& argument, + std::shared_ptr client); + + static StatusOr MakeTuple( + const std::vector buffers, + std::shared_ptr client, int device_ordinal); + + PyLocalBuffer() = default; + PyLocalBuffer(Shape on_host_shape, + std::shared_ptr device_buffer, + std::shared_ptr client); + StatusOr ToPython() const; + const Shape& on_host_shape() const { return on_host_shape_; } + const std::shared_ptr& device_buffer() const { + return device_buffer_; + } + int device_ordinal() const { return device_buffer_->device_ordinal(); } + + void Delete() { + device_buffer_ = nullptr; + client_ = nullptr; + } + + // Returns a view of the PyLocalBuffer DAG as a ShapedBuffer. The + // PyLocalBuffer retains ownership of the device buffers. + ShapedBuffer AsShapedBuffer() const; + + // Destructures a tuple-valued PyLocalBuffer into its constituent elements. + StatusOr> DestructureTuple(); + + private: + std::shared_ptr client_ = nullptr; + Shape on_host_shape_; + std::shared_ptr device_buffer_; +}; + +// Represents a compiled computation that can be executed given handles to +// device-allocated literals. Wraps an XLA LocalExecutable. +class PyLocalExecutable { + public: + // Compiles a computation to an executable. + static StatusOr> Compile( + const XlaComputation& computation, std::vector argument_layouts, + const ExecutableBuildOptions* build_options, + std::shared_ptr client); + + PyLocalExecutable(std::shared_ptr executable, + DeviceAssignment device_assignment, + std::shared_ptr client); + + int num_replicas() const { + return executable_->build_options().num_replicas(); + } + + // Returns the device ordinals to which each replica is assigned. + std::vector DeviceOrdinals() const; + + const DeviceAssignment& device_assignment() const { + return device_assignment_; + } + + StatusOr Execute( + absl::Span argument_handles); + + // Execute on many replicas. Takes a sequence of argument lists (one argument + // list per replica) and returns a tuple of results (one result per replica). + // The number of argument lists must be equal to the replica count. 
+ StatusOr> ExecutePerReplica( + absl::Span> argument_handles); + + void Delete() { executable_ = nullptr; } + + private: + StatusOr ExecuteHelper( + absl::Span argument_handles, int replica); + + std::shared_ptr const client_; + std::shared_ptr executable_; + const DeviceAssignment device_assignment_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc deleted file mode 100644 index 5cfbb2c20df..00000000000 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ /dev/null @@ -1,855 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/python/local_computation_builder.h" - -#include -#include -#include - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/qr.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/computation_placer.h" -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" -#include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace swig { - -Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { - const char* name = "xla._CPU_CUSTOM_CALL_TARGET"; - if (!PyCapsule_IsValid(capsule, name)) { - return InvalidArgument( - "Argument to RegisterCpuCustomCallTargetRegistry was not a " - "xla._CPU_CUSTOM_CALL_TARGET capsule."); - } - void* fn_ptr = PyCapsule_GetPointer(capsule, name); - CHECK(fn_ptr != nullptr); - cpu::CustomCallTargetRegistry::Global()->Register( - std::string(fn_name.begin(), fn_name.end()), fn_ptr); - return Status::OK(); -} - -LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {} - -/* static */ StatusOr LocalClient::Get( - const string& platform_name) { - TF_ASSIGN_OR_RETURN(se::Platform * platform, - PlatformUtil::GetPlatform(platform_name)); - if (platform->VisibleDeviceCount() <= 0) { - return InvalidArgument("Platform %s has no visible devices.", - platform_name); - } - LocalClientOptions options; - options.set_platform(platform); - 
TF_ASSIGN_OR_RETURN(xla::LocalClient * client, - ClientLibrary::GetOrCreateLocalClient(options)); - CHECK(client != nullptr); - return LocalClient(client); -} - -// Returns the number of devices known to the XLA client. -int LocalClient::DeviceCount() const { return client_->device_count(); } - -Status LocalClient::TransferToInfeed(const Literal& literal, - int device_ordinal) { - VLOG(1) << "Infeeding literal to device " << device_ordinal - << "; shape: " << literal.shape(); - return client_->TransferToInfeed(literal, device_ordinal); -} - -StatusOr LocalClient::TransferFromOutfeed(const Shape& shape, - int device_ordinal) { - VLOG(1) << "Outfeeding literal from device " << device_ordinal - << "; shape: " << shape; - return client_->TransferFromOutfeed(&shape, device_ordinal); -} - -/* static */ -StatusOr LocalShapedBuffer::FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout, - const LocalClient& client, int device_ordinal) { - VLOG(1) << "Creating shaped buffer from literal on device ordinal: " - << device_ordinal; - auto literal_to_buffer = [&](const Literal& arg) { - return client.client()->LiteralToShapedBuffer( - arg, device_ordinal, client.client()->backend().memory_allocator()); - }; - - StatusOr buf = [&] { - if (shape_with_layout) { - Literal relaid = argument.Relayout(shape_with_layout.value()); - return literal_to_buffer(relaid); - } - return literal_to_buffer(argument); - }(); - TF_RETURN_IF_ERROR(buf.status()); - return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client()); -} - -LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, - xla::LocalClient* client) - : shaped_buffer_(std::move(shaped_buffer)), client_(client) {} - -const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { - return &shaped_buffer_; -} - -ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); } - -const Shape& LocalShapedBuffer::shape() const { - return shaped_buffer()->on_device_shape(); -} - -StatusOr LocalShapedBuffer::ToLiteral() const { - return client_->ShapedBufferToLiteral(*shaped_buffer()); -} - -LocalShapedBufferTuple::LocalShapedBufferTuple( - std::vector elements) - : elements_(std::move(elements)) { - for (auto* element : elements_) { - CHECK(element != nullptr); - } -} - -LocalShapedBufferTuple::~LocalShapedBufferTuple() { - for (LocalShapedBuffer* element : elements_) { - if (element != nullptr) { - delete element; - } - } -} - -StatusOr LocalShapedBufferTuple::Release(int i) { - LocalShapedBuffer* element = elements_[i]; - if (element == nullptr) { - return InvalidArgument("Attempted to release already-released element %d.", - i); - } - elements_[i] = nullptr; - return element; -} - -int64 LocalShapedBufferTuple::size() const { return elements_.size(); } - -StatusOr LocalShapedBuffer::DestructureTuple() { - const Shape tuple_shape = shape(); - - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator(); - ShapedBuffer tuple_buffer = Release(); - - // Extract some metadata we use to construct scoped buffers. 
- const se::Platform* platform = tuple_buffer.platform(); - int device_ordinal = tuple_buffer.device_ordinal(); - - ShapeTree& shape_tree = tuple_buffer.buffers(); - std::vector results; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - // Create a shaped buffer for this destructured tuple element. - const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); - VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; - ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - - ShapeUtil::ForEachSubshape( - subshape, [&](const Shape& s, const ShapeIndex& index) { - ShapeIndex original(index); - original.push_front(i); - se::DeviceMemoryBase* device_memory = - shape_tree.mutable_element(original); - shaped_buffer.set_buffer(*device_memory, index); - *device_memory = se::DeviceMemoryBase(); - }); - - VLOG(3) << "Completed tuple element: " << i; - results.push_back(new LocalShapedBuffer( - ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_)); - } - // Deallocate the root buffer. - se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); - TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); - return new LocalShapedBufferTuple(std::move(results)); -} - -LocalExecutable::LocalExecutable( - std::unique_ptr executable, - xla::DeviceAssignment device_assignment, xla::LocalClient* client) - : executable_(std::move(executable)), - device_assignment_(std::move(device_assignment)), - client_(client) {} - -std::vector LocalExecutable::DeviceOrdinals() const { - int num_replicas = device_assignment_.replica_count(); - std::vector device_ordinals; - device_ordinals.reserve(num_replicas); - for (int i = 0; i < num_replicas; ++i) { - device_ordinals.push_back(device_assignment_(i, 0)); - } - return device_ordinals; -} - -StatusOr LocalExecutable::Execute( - absl::Span argument_handles) { - if (num_replicas() != 1) { - return InvalidArgument( - "Attempted to execute computation with %d replicas using Execute()", - num_replicas()); - } - StatusOr result_buffer_status; - const int device_ordinal = device_assignment_(0, 0); - VLOG(3) << "Replica 0 mapped to device ordinal for execution: " - << device_ordinal; - - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client_->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client_->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment_); - - result_buffer_status = executable_->Run(argument_buffers, options); - - if (!result_buffer_status.ok()) { - return InternalError( - "Failed running replica 0 (other replicas may have failed as well): " - "%s.", - result_buffer_status.status().ToString()); - } - return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(), - client_); -} - -StatusOr LocalExecutable::ExecutePerReplica( - absl::Span> argument_handles) { - const int num_devices = client_->device_count(); - - if (argument_handles.size() != num_replicas()) { - return InvalidArgument( - "Attempted to execute with %d replicas when replica count is %d", - argument_handles.size(), num_devices); - } - if (argument_handles.size() > num_devices) { - return InvalidArgument( - "Attempted to execute with %d replicas when device count is %d", - 
argument_handles.size(), num_devices); - } - - VLOG(1) << "Executing with " << num_replicas() << " replicas."; - - std::vector> results(num_replicas()); - auto execute = [this, &argument_handles, &results](int replica) { - const int device_ordinal = device_assignment_(replica, 0); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " << device_ordinal; - - std::vector argument_buffers; - argument_buffers.reserve(argument_handles[replica].size()); - for (auto& handle : argument_handles[replica]) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client_->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client_->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment_); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - - results[replica] = std::move(result_buffer_status); - }; - - if (num_replicas() == 1) { - // Fast-path if there is only one replica — run the computation on the - // current thread. - execute(0); - } else { - // TODO(phawkins): don't recreate the threadpool for each execution. - tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - num_replicas() - 1); - - for (int replica = 0; replica < num_replicas() - 1; ++replica) { - pool.Schedule([&execute, replica] { execute(replica); }); - } - execute(num_replicas() - 1); - } - - std::vector wrapped_results(num_replicas()); - for (int replica = 0; replica < num_replicas(); ++replica) { - auto& statusor = results[replica]; - if (!statusor.ok()) { - return InternalError( - "Failed running replica %d (other replicas may have failed as well): " - "%s.", - replica, statusor.status().ToString()); - } - wrapped_results[replica] = - new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_); - } - - return new LocalShapedBufferTuple(std::move(wrapped_results)); -} - -Computation::Computation(XlaComputation computation) - : computation_(std::move(computation)) {} - -StatusOr Computation::Compile( - const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options, const LocalClient& client) { - std::vector argument_shape_pointers; - argument_shape_pointers.reserve(argument_shapes.size()); - for (auto& argument_shape : argument_shapes) { - argument_shape_pointers.push_back(&argument_shape); - } - - ExecutableBuildOptions options; - if (build_options != nullptr) { - options = *build_options; - } - TF_ASSIGN_OR_RETURN( - auto local_executable, - client.client()->Compile(computation_, argument_shape_pointers, options)); - TF_ASSIGN_OR_RETURN( - DeviceAssignment device_assignment, - client.client()->backend().computation_placer()->AssignDevices( - options.num_replicas(), /*computation_count=*/1)); - - return new LocalExecutable(std::move(local_executable), - std::move(device_assignment), client.client()); -} - -const XlaComputation& Computation::computation() const { return computation_; } - -string Computation::GetSerializedProto() const { - string result; - if (!computation_.proto().SerializeToString(&result)) { - LOG(ERROR) << "Failed to serialize the HloModuleProto."; - return ""; - } - return result; -} - -StatusOr Computation::GetHloText() const { - TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, - HloModule::CreateModuleConfigFromProto( - computation_.proto(), GetDebugOptionsFromFlags())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - 
HloModule::CreateFromProto(computation_.proto(), module_config)); - HloPrintOptions options; - options = HloPrintOptions::ShortParsable(); - options.set_print_large_constants(false); - return hlo_module->ToString(options); -} - -StatusOr Computation::GetHloDotGraph() const { - TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, - HloModule::CreateModuleConfigFromProto( - computation_.proto(), GetDebugOptionsFromFlags())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - HloModule::CreateFromProto(computation_.proto(), module_config)); - hlo_graph_dumper::DotGraphOptions options; - options.debug_options = &hlo_module->config().debug_options(); - return hlo_graph_dumper::HloComputationToDotGraph( - *hlo_module->entry_computation(), options); -} - -StatusOr Computation::GetProgramShape() const { - return computation_.GetProgramShape(); -} - -StatusOr Computation::GetReturnValueShape() const { - TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape()); - return std::move(*shape.mutable_result()); -} - -LocalOp::LocalOp(const XlaOp& op) : op_(op) {} - -const XlaOp& LocalOp::op() const { return op_; } - -ComputationBuilder::ComputationBuilder(const string& computation_name) - : builder_(computation_name) {} - -void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { - builder_.SetOpMetadata(metadata); -} - -void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } - -StatusOr ComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); - return new Computation(std::move(computation)); -} - -LocalOp ComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, const string& name) { - return xla::Parameter(&builder_, parameter_number, shape, name); -} - -StatusOr ComputationBuilder::BuildWithRoot(const LocalOp& root) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op())); - return new Computation(std::move(computation)); -} - -StatusOr ComputationBuilder::GetShape(const LocalOp& operand) { - return builder_.GetShape(operand.op()); -} - -StatusOr ComputationBuilder::GetReturnValueShape() { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); - return program_shape.result(); -} - -LocalOp ComputationBuilder::ReplicaId() { return xla::ReplicaId(&builder_); } - -LocalOp ComputationBuilder::Infeed(const Shape& shape) { - return xla::Infeed(&builder_, shape); -} - -void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, - const string& outfeed_config) { - xla::Outfeed(operand.op(), shape, outfeed_config); -} - -LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) { - return xla::ConstantLiteral(&builder_, literal); -} - -LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) { - return xla::Iota(&builder_, element_type, size); -} - -LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape, - int64 dimension) { - return xla::Iota(&builder_, shape, dimension); -} - -LocalOp ComputationBuilder::Broadcast(const LocalOp& operand, - absl::Span broadcast_sizes) { - return xla::Broadcast(operand.op(), broadcast_sizes); -} - -LocalOp ComputationBuilder::BroadcastInDim( - const LocalOp& operand, absl::Span out_dim_sizes, - absl::Span broadcast_dimensions) { - return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); -} - -LocalOp ComputationBuilder::Pad(const LocalOp& operand, - const LocalOp& padding_value, - const PaddingConfig& padding_config) { - return xla::Pad(operand.op(), 
padding_value.op(), padding_config); -} - -LocalOp ComputationBuilder::Reshape(const LocalOp& operand, - absl::Span dimensions, - absl::Span new_sizes) { - return xla::Reshape(operand.op(), dimensions, new_sizes); -} - -LocalOp ComputationBuilder::Collapse(const LocalOp& operand, - absl::Span dimensions) { - return xla::Collapse(operand.op(), dimensions); -} - -LocalOp ComputationBuilder::AllToAll( - const LocalOp& operand, int64 split_dimension, int64 concat_dimension, - int64 split_count, absl::Span replica_groups) { - std::vector rg; - rg.reserve(replica_groups.size()); - for (int i = 0; i < replica_groups.size(); ++i) { - rg.push_back(replica_groups[i]); - } - return xla::AllToAll(operand.op(), split_dimension, concat_dimension, - split_count, rg); -} - -LocalOp ComputationBuilder::CrossReplicaSum( - const LocalOp& operand, absl::Span replica_groups) { - return xla::CrossReplicaSum(operand.op(), replica_groups); -} - -LocalOp ComputationBuilder::Slice(const LocalOp& operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides) { - return xla::Slice(operand.op(), start_indices, limit_indices, strides); -} - -LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); -} - -LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand, - const LocalOp& start_indices, - absl::Span slice_sizes) { - return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); -} - -LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand, - const LocalOp& update, - const LocalOp& start_indices) { - return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); -} - -LocalOp ComputationBuilder::ConcatInDim(absl::Span operands, - int64 dimension) { - std::vector xla_ops; - xla_ops.reserve(operands.size()); - for (const auto& op : operands) { - xla_ops.push_back(op.op()); - } - return xla::ConcatInDim(&builder_, xla_ops, dimension); -} - -LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const Computation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, const LocalOp& source, - const LocalOp& init_value, const Computation& scatter) { - return xla::SelectAndScatterWithGeneralPadding( - operand.op(), select.computation(), window_dimensions, window_strides, - padding, source.op(), init_value.op(), scatter.computation()); -} - -LocalOp ComputationBuilder::Tuple(absl::Span elements) { - std::vector xla_ops; - xla_ops.reserve(elements.size()); - for (const auto& op : elements) { - xla_ops.push_back(op.op()); - } - - return xla::Tuple(&builder_, xla_ops); -} - -LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data, - int64 index) { - return xla::GetTupleElement(tuple_data.op(), index); -} - -LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { - return xla::Dot(lhs.op(), rhs.op()); -} - -LocalOp ComputationBuilder::DotGeneral( - const LocalOp& lhs, const LocalOp& rhs, - const DotDimensionNumbers& dimension_numbers) { - return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); -} - -LocalOp ComputationBuilder::ConvGeneralDilated( - const LocalOp& lhs, const LocalOp& rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { - return 
xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, - lhs_dilation, rhs_dilation, dimension_numbers, - feature_group_count); -} - -LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand, - PrimitiveType new_element_type) { - return xla::ConvertElementType(operand.op(), new_element_type); -} - -LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand, - PrimitiveType new_element_type) { - return xla::BitcastConvertType(operand.op(), new_element_type); -} - -LocalOp ComputationBuilder::Call(const Computation& local_computation, - absl::Span operands) { - std::vector xla_ops; - xla_ops.reserve(operands.size()); - for (const auto& op : operands) { - xla_ops.push_back(op.op()); - } - return xla::Call(&builder_, local_computation.computation(), xla_ops); -} - -LocalOp ComputationBuilder::CustomCall( - const string& call_target_name, absl::Span operands, - const Shape& shape_with_layout, - const std::vector& operand_shapes_with_layout, - const string& opaque) { - std::vector xla_ops; - xla_ops.reserve(operands.size()); - for (const auto& op : operands) { - xla_ops.push_back(op.op()); - } - return xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops, - shape_with_layout, - operand_shapes_with_layout, opaque); -} - -LocalOp ComputationBuilder::Transpose(const LocalOp& operand, - absl::Span permutation) { - return xla::Transpose(operand.op(), permutation); -} - -LocalOp ComputationBuilder::Rev(const LocalOp& operand, - absl::Span dimensions) { - return xla::Rev(operand.op(), dimensions); -} - -LocalOp ComputationBuilder::Map(absl::Span operands, - const Computation& local_computation, - absl::Span dimensions) { - std::vector xla_ops; - xla_ops.reserve(operands.size()); - for (const auto& op : operands) { - xla_ops.push_back(op.op()); - } - - return xla::Map(&builder_, xla_ops, local_computation.computation(), - dimensions); -} - -LocalOp ComputationBuilder::Reduce( - const LocalOp& operand, const LocalOp& init_value, - const Computation& local_computation, - absl::Span dimensions_to_reduce) { - return xla::Reduce(operand.op(), init_value.op(), - local_computation.computation(), dimensions_to_reduce); -} - -LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding( - const LocalOp& operand, const LocalOp& init_value, - const Computation& local_computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding) { - return xla::ReduceWindowWithGeneralPadding( - operand.op(), init_value.op(), local_computation.computation(), - window_dimensions, window_strides, base_dilations, window_dilations, - padding); -} - -LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, - const Shape& shape) { - return xla::RngNormal(mu.op(), sigma.op(), shape); -} - -LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, - const Shape& shape) { - return xla::RngUniform(a.op(), b.op(), shape); -} - -LocalOp ComputationBuilder::While(const Computation& condition, - const Computation& body, - const LocalOp& init) { - return xla::While(condition.computation(), body.computation(), init.op()); -} - -LocalOp ComputationBuilder::Conditional(const LocalOp& predicate, - const LocalOp& true_operand, - const Computation& true_computation, - const LocalOp& false_operand, - const Computation& false_computation) { - return xla::Conditional(predicate.op(), true_operand.op(), - true_computation.computation(), false_operand.op(), - 
false_computation.computation()); -} - -StatusOr ComputationBuilder::IsConstant(const LocalOp& operand) { - return builder_.IsConstant(operand.op()); -} - -LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { - return xla::Sort(operand.op(), {}, dimension); -} - -LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys, - const LocalOp& values, int64 dimension) { - return xla::Sort(keys.op(), {values.op()}, dimension); -} - -LocalOp ComputationBuilder::Cholesky(const LocalOp& a, bool lower) { - return xla::Cholesky(a.op(), lower); -} - -LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) { - XlaBuilder* builder = a.op().builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); - return xla::Tuple(builder, {qr.q, qr.r}); - }); -} - -LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b, - bool left_side, bool lower, - bool unit_diagonal, - int transpose_a) { - return xla::TriangularSolve( - a.op(), b.op(), left_side, lower, unit_diagonal, - xla::TriangularSolveOptions::Transpose(transpose_a)); -} - -LocalOp ComputationBuilder::Gather( - const LocalOp& input, const LocalOp& start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes) { - return xla::Gather(input.op(), start_indices.op(), dimension_numbers, - slice_sizes); -} - -LocalOp ComputationBuilder::Scatter( - const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, const Computation& update_computation, - const ScatterDimensionNumbers& dimension_numbers) { - return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), - update_computation.computation(), dimension_numbers); -} - -StatusOr ComputationBuilder::BuildConstantSubGraph( - const LocalOp& operand) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, - builder_.BuildConstantSubGraph(operand.op())); - return new Computation(std::move(computation)); -} - -#define _FORWARD(method_name, return_sig, args_sig, args) \ - return_sig ComputationBuilder::method_name args_sig { \ - return xla::method_name args; \ - } - -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) - -#define _FORWARD_BINOP(method_name) \ - _FORWARD(method_name, LocalOp, \ - (const LocalOp& lhs, const LocalOp& rhs, \ - absl::Span broadcast_dimensions), \ - (lhs.op(), rhs.op(), broadcast_dimensions)) - -#define _FORWARD_TRIOP(method_name) \ - _FORWARD(method_name, LocalOp, \ - (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ - (lhs.op(), rhs.op(), ehs.op())) - -_FORWARD_TRIOP(Select) -_FORWARD_TRIOP(Clamp) -_FORWARD_BINOP(Eq) -_FORWARD_BINOP(Ne) -_FORWARD_BINOP(Ge) -_FORWARD_BINOP(Gt) -_FORWARD_BINOP(Lt) -_FORWARD_BINOP(Le) -_FORWARD_BINOP(Add) -_FORWARD_BINOP(Sub) -_FORWARD_BINOP(Mul) -_FORWARD_BINOP(Div) -_FORWARD_BINOP(Rem) -_FORWARD_BINOP(Max) -_FORWARD_BINOP(Min) -_FORWARD_BINOP(And) -_FORWARD_BINOP(Or) -_FORWARD_BINOP(Xor) -_FORWARD_BINOP(ShiftLeft) -_FORWARD_BINOP(ShiftRightArithmetic) -_FORWARD_BINOP(ShiftRightLogical) -_FORWARD_BINOP(Atan2) -_FORWARD_BINOP(Pow) -_FORWARD_BINOP(Complex) -_FORWARD_UNOP(Not) -_FORWARD_UNOP(Clz) -_FORWARD_UNOP(Abs) -_FORWARD_UNOP(Exp) -_FORWARD_UNOP(Expm1) -_FORWARD_UNOP(Floor) -_FORWARD_UNOP(Ceil) -_FORWARD_UNOP(Round) -_FORWARD_UNOP(Log) -_FORWARD_UNOP(Log1p) -_FORWARD_UNOP(Sign) -_FORWARD_UNOP(Cos) -_FORWARD_UNOP(Sin) -_FORWARD_UNOP(Tanh) -_FORWARD_UNOP(IsFinite) -_FORWARD_UNOP(Neg) 
-_FORWARD_UNOP(Sqrt) -_FORWARD_UNOP(Rsqrt) -_FORWARD_UNOP(Square) -_FORWARD_UNOP(Reciprocal) -_FORWARD_UNOP(Erfc) -_FORWARD_UNOP(Erf) -_FORWARD_UNOP(ErfInv) -_FORWARD_UNOP(Lgamma) -_FORWARD_UNOP(Digamma) -_FORWARD_UNOP(Acos) -_FORWARD_UNOP(Asin) -_FORWARD_UNOP(Atan) -_FORWARD_UNOP(Tan) -_FORWARD_UNOP(Acosh) -_FORWARD_UNOP(Asinh) -_FORWARD_UNOP(Atanh) -_FORWARD_UNOP(Cosh) -_FORWARD_UNOP(Sinh) -_FORWARD_UNOP(Real) -_FORWARD_UNOP(Imag) -_FORWARD_UNOP(Conj) - -#undef _FORWARD -#undef _FORWARD_UNOP -#undef _FORWARD_BINOP -#undef _FORWARD_TRIOP - -void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { - delete local_shaped_buffer; -} - -void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; } - -void DeleteComputation(Computation* computation) { delete computation; } - -} // namespace swig -} // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h deleted file mode 100644 index fa878501aba..00000000000 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ /dev/null @@ -1,473 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ - -#include -#include - -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { -namespace swig { - -// Registers a 'fn_capsule' as a CPU custom call target. -// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name -// "xla._CPU_CUSTOM_CALL_TARGET". -Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); - -// Wrapper around an xla::LocalClient. -class LocalClient { - public: - // Initializes a local XLA client for `platform_name`. Returns an error if no - /// such platform exists, or if the platform has no visible devices. - static StatusOr Get(const string& platform_name); - - // Copyable and moveable; the class is just a wrapper around a - // xla::LocalClient pointer for convenient SWIG wrapping. - - // Returns the number of devices known to the XLA client. - int DeviceCount() const; - - // Wraps the local client's infeed-transfer function. - // - // The default device ordinal (0) is used. - Status TransferToInfeed(const Literal& literal, int device_ordinal); - - // Transfers a literal of the given shape from the outfeed of the given - // replica. 
- StatusOr TransferFromOutfeed(const Shape& shape, int device_ordinal); - - xla::LocalClient* client() const { return client_; } - - private: - LocalClient(xla::LocalClient* client); - - xla::LocalClient* client_; -}; - -class LocalShapedBufferTuple; - -// Represents a reference to literals that live in a device-allocated buffer via -// XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a -// literal to device via the local client. -class LocalShapedBuffer { - public: - static StatusOr FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout, - const LocalClient& client, int device_ordinal); - - LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client); - StatusOr ToLiteral() const; - const Shape& shape() const; - const ScopedShapedBuffer* shaped_buffer() const; - - // Transfers ownership of the encapsulated ShapedBuffer to the caller, - // analogous to std::unique_ptr::release(). - ShapedBuffer Release(); - - // Destructures a tuple-valued LocalShapedBuffer into its constituent - // elements in LocalShapedBufferTuple form. - StatusOr DestructureTuple(); - - private: - ScopedShapedBuffer shaped_buffer_; - xla::LocalClient* client_; -}; - -// Result of a tuple destructuring operation on a LocalShapedBuffer -- this -// appears to be a simpler mechanism for the time being than an alternative like -// using SWIG to transform std::vectors into Python lists of SWIG objects -// directly. -class LocalShapedBufferTuple { - public: - // Note: any LocalShapedBuffer elements that are not Release()'d will be - // deallocated in the destructor. - explicit LocalShapedBufferTuple(std::vector elements); - - ~LocalShapedBufferTuple(); - - // Releases the ith element to the caller. Further attempts to release the ith - // element will return an invalid argument error. - StatusOr Release(int i); - - // Returns the number of elements in the destructured tuple. - int64 size() const; - - private: - std::vector elements_; -}; - -// Represents a compiled computation that can be executed given handles to -// device-allocated literals. Specifically, wraps an XLA LocalExecutable. -class LocalExecutable { - public: - LocalExecutable(std::unique_ptr executable, - xla::DeviceAssignment device_assignment, - xla::LocalClient* client); - - int num_replicas() const { - return executable_->build_options().num_replicas(); - } - - // Returns the device ordinals to which each replica is assigned. - std::vector DeviceOrdinals() const; - - StatusOr Execute( - absl::Span argument_handles); - - // Execute on many replicas. Takes a sequence of argument lists (one argument - // list per replica) and returns a tuple of results (one result per replica). - // The number of argument lists must be equal to the replica count. - StatusOr ExecutePerReplica( - absl::Span > argument_handles); - - private: - const std::unique_ptr executable_; - const xla::DeviceAssignment device_assignment_; - xla::LocalClient* const client_; -}; - -// Wraps a XlaComputation produced by a ComputationBuilder. The -// Compile method compiles the computation to a (local) executable via -// the client library's local client. This class is intended to be -// made available to Python via SWIG. 
-class Computation { - public: - Computation(XlaComputation computation); - - StatusOr Compile( - const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options, const LocalClient& client); - - const XlaComputation& computation() const; - - // Returns the HloModuleProto contained in the XlaComputation in the - // serialized binary format. Logs an internal error and returns an empty - // string on failure. - string GetSerializedProto() const; - - // Returns the computation in human-readable HLO text format. - StatusOr GetHloText() const; - - // Returns the computation in graphviz dot format. - StatusOr GetHloDotGraph() const; - - // Returns the program shape for this computation. - StatusOr GetProgramShape() const; - - // Returns the return-value shape for this computation. - StatusOr GetReturnValueShape() const; - - private: - XlaComputation computation_; -}; - -// Wraps a XlaOp produced by a ComputationBuilder. This class is intended -// to be made available to Python via SWIG. -class LocalOp { - public: - LocalOp(const XlaOp& op); - - const XlaOp& op() const; - - private: - XlaOp op_; -}; - -// Wraps the ComputationBuilder API in order to: -// - Support consumption by SWIG in order to be made available to -// Python. -// - Set up the underlying builder to use the client library's -// LocalClient. -// - Wrap Computations in Computations for Python access. -// - Correspondingly unwrap incoming Computations. -class ComputationBuilder { - public: - ComputationBuilder(const string& computation_name); - - void SetOpMetadata(const OpMetadata& metadata); - void ClearOpMetadata(); - - // Returns an owned Computation to the caller on success. - StatusOr Build(); - - // Returns an owned Computation to the caller on success with given root. - StatusOr BuildWithRoot(const LocalOp& root); - - LocalOp Parameter(int64 parameter_number, const Shape& shape, - const string& name); - - StatusOr GetShape(const LocalOp& operand); - - // Returns the shape of the current return value for the computation. 
- StatusOr GetReturnValueShape(); - - LocalOp ReplicaId(); - - LocalOp Infeed(const Shape& shape); - - void Outfeed(const LocalOp& operand, const Shape& shape, - const string& outfeed_config); - - LocalOp ConstantLiteral(const Literal& literal); - - LocalOp Iota(PrimitiveType element_type, int64 size); - - LocalOp BroadcastedIota(const Shape& shape, int64 dimension); - - LocalOp Broadcast(const LocalOp& operand, - absl::Span broadcast_sizes); - - LocalOp BroadcastInDim(const LocalOp& operand, - absl::Span out_dim_sizes, - absl::Span broadcast_dimensions); - - LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, - const PaddingConfig& padding_config); - - LocalOp Reshape(const LocalOp& operand, absl::Span dimensions, - absl::Span new_sizes); - - LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); - - LocalOp AllToAll(const LocalOp& operand, int64 split_dimension, - int64 concat_dimension, int64 split_count, - absl::Span replica_groups); - - LocalOp CrossReplicaSum(const LocalOp& operand, - absl::Span replica_groups); - - LocalOp Slice(const LocalOp& operand, absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - - LocalOp SliceInDim(const LocalOp& operand, int64 start_index, - int64 limit_index, int64 stride, int64 dimno); - - LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, - absl::Span slice_sizes); - - LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, - const LocalOp& start_indices); - - LocalOp ConcatInDim(absl::Span operands, int64 dimension); - - LocalOp SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const Computation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span > padding, const LocalOp& source, - const LocalOp& init_value, const Computation& scatter); - - LocalOp Tuple(absl::Span elements); - - LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - - LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - - LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, - const DotDimensionNumbers& dimension_numbers); - - LocalOp ConvGeneralDilated( - const LocalOp& lhs, const LocalOp& rhs, - absl::Span window_strides, - absl::Span > padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count); - - LocalOp ConvertElementType(const LocalOp& operand, - PrimitiveType new_element_type); - - LocalOp BitcastConvertType(const LocalOp& operand, - PrimitiveType new_element_type); - - LocalOp Call(const Computation& local_computation, - absl::Span operands); - - LocalOp CustomCall(const string& call_target_name, - absl::Span operands, - const Shape& shape_with_layout, - const std::vector& operand_shapes_with_layout, - const string& opaque); - - LocalOp Transpose(const LocalOp& operand, - absl::Span permutation); - - LocalOp Rev(const LocalOp& operand, absl::Span dimensions); - - LocalOp Map(absl::Span operands, - const Computation& local_computation, - absl::Span dimensions); - - LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, - const Computation& local_computation, - absl::Span dimensions_to_reduce); - - LocalOp ReduceWindowWithGeneralPadding( - const LocalOp& operand, const LocalOp& init_value, - const Computation& local_computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span > padding); - - LocalOp RngNormal(const LocalOp& mu, const LocalOp& 
sigma, - const Shape& shape); - - LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - - LocalOp While(const Computation& condition, const Computation& body, - const LocalOp& init); - - LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, - const Computation& true_computation, - const LocalOp& false_operand, - const Computation& false_computation); - - StatusOr IsConstant(const LocalOp& operand); - - LocalOp Sort(const LocalOp& operand, int64 dimension); - - LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, - int64 dimension); - - LocalOp QR(const LocalOp& a, bool full_matrices); - - LocalOp Cholesky(const LocalOp& a, bool lower); - - // `transpose_a` is the integer value of a TriangularSolveOptions::Transpose - // enum. We use an integer here so we don't have to teach SWIG about the - // enum. - LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, - bool lower, bool unit_diagonal, int transpose_a); - - LocalOp Gather(const LocalOp& input, const LocalOp& start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes); - - LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, const Computation& update_computation, - const ScatterDimensionNumbers& dimension_numbers); - - StatusOr BuildConstantSubGraph(const LocalOp& operand); - -#define _FORWARD(method_name, return_sig, args_sig) \ - return_sig method_name args_sig; - -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, LocalOp, (const LocalOp& operand)) - -#define _FORWARD_BINOP(method_name) \ - _FORWARD(method_name, LocalOp, \ - (const LocalOp& lhs, const LocalOp& rhs, \ - absl::Span broadcast_dimensions)) - -#define _FORWARD_TRIOP(method_name) \ - _FORWARD(method_name, LocalOp, \ - (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) - - _FORWARD_TRIOP(Select) - _FORWARD_TRIOP(Clamp) - _FORWARD_BINOP(Eq) - _FORWARD_BINOP(Ne) - _FORWARD_BINOP(Ge) - _FORWARD_BINOP(Gt) - _FORWARD_BINOP(Lt) - _FORWARD_BINOP(Le) - _FORWARD_BINOP(Add) - _FORWARD_BINOP(Sub) - _FORWARD_BINOP(Mul) - _FORWARD_BINOP(Div) - _FORWARD_BINOP(Rem) - _FORWARD_BINOP(Max) - _FORWARD_BINOP(Min) - _FORWARD_BINOP(And) - _FORWARD_BINOP(Or) - _FORWARD_BINOP(Xor) - _FORWARD_BINOP(ShiftLeft) - _FORWARD_BINOP(ShiftRightArithmetic) - _FORWARD_BINOP(ShiftRightLogical) - _FORWARD_BINOP(Atan2) - _FORWARD_BINOP(Pow) - _FORWARD_BINOP(Complex) - _FORWARD_UNOP(Not) - _FORWARD_UNOP(Clz) - _FORWARD_UNOP(Abs) - _FORWARD_UNOP(Exp) - _FORWARD_UNOP(Expm1) - _FORWARD_UNOP(Floor) - _FORWARD_UNOP(Ceil) - _FORWARD_UNOP(Round) - _FORWARD_UNOP(Log) - _FORWARD_UNOP(Log1p) - _FORWARD_UNOP(Sign) - _FORWARD_UNOP(Cos) - _FORWARD_UNOP(Sin) - _FORWARD_UNOP(Tanh) - _FORWARD_UNOP(IsFinite) - _FORWARD_UNOP(Neg) - _FORWARD_UNOP(Sqrt) - _FORWARD_UNOP(Rsqrt) - _FORWARD_UNOP(Square) - _FORWARD_UNOP(Reciprocal) - _FORWARD_UNOP(Erfc) - _FORWARD_UNOP(Erf) - _FORWARD_UNOP(ErfInv) - _FORWARD_UNOP(Lgamma) - _FORWARD_UNOP(Digamma) - _FORWARD_UNOP(Acos) - _FORWARD_UNOP(Asin) - _FORWARD_UNOP(Atan) - _FORWARD_UNOP(Tan) - _FORWARD_UNOP(Acosh) - _FORWARD_UNOP(Asinh) - _FORWARD_UNOP(Atanh) - _FORWARD_UNOP(Cosh) - _FORWARD_UNOP(Sinh) - _FORWARD_UNOP(Real) - _FORWARD_UNOP(Imag) - _FORWARD_UNOP(Conj) - -#undef _FORWARD -#undef _FORWARD_UNOP -#undef _FORWARD_BINOP -#undef _FORWARD_TRIOP - - private: - XlaBuilder builder_; -}; - -// Functions for freeing resources from the Python side. 
-void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); -void DeleteLocalExecutable(LocalExecutable* computation); -void DeleteComputation(Computation* computation); - -} // namespace swig -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i deleted file mode 100644 index 9fcb4822c7f..00000000000 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ /dev/null @@ -1,412 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// SWIG typemaps and declarations for building, compiling, and -// executing XLA computations, wrapping most of what is declared in -// local_computation_builder.h. - -%module(threads="1") local_computation_builder - -// Keep the GIL except where explicitly specified. -%nothread; - -%include "tensorflow/python/platform/base.i" -%include "tensorflow/compiler/xla/python/xla_data.i" - -%{ -// Must be included first -#include "tensorflow/python/lib/core/numpy.h" - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/python/numpy_bridge.h" -#include "tensorflow/compiler/xla/python/local_computation_builder.h" - -using namespace xla; -using namespace xla::swig; - -%} - -// Required to use PyArray_* functions. 
-%init %{ -tensorflow::ImportNumpy(); -%} - -// Computation builder types - -%typemap(in) absl::Span( - std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - LocalOp* op; - if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), - SWIG_POINTER_EXCEPTION)) == -1) { - SWIG_fail; - } - temps.push_back(*op); - Py_DECREF(o); - } - $1 = temps; -} - -// Computation and buffer/allocation types - -%typemap(out) StatusOr { - if ($1.ok()) { - xla::swig::LocalClient value = $1.ValueOrDie(); - { - auto $1 = value; - $typemap(out, xla::swig::LocalClient) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::LocalExecutable*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::LocalShapedBuffer*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::LocalShapedBufferTuple*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::Computation*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - LocalShapedBuffer* lsbp; - if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), - SWIG_POINTER_EXCEPTION)) == -1) { - SWIG_fail; - } - temps.push_back(lsbp); - Py_DECREF(o); - } - $1 = temps; -} - -%typemap(in) absl::Span > - (std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - std::vector vec; - const int vec_size = PySequence_Size(o); - vec.reserve(vec_size); - for (int j = 0; j < vec_size; ++j) { - PyObject* vec_elt = PySequence_GetItem(o, j); - LocalShapedBuffer* lsbp; - if ((SWIG_ConvertPtr(vec_elt, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), - SWIG_POINTER_EXCEPTION)) == -1) { - Py_DECREF(vec_elt); - Py_DECREF(o); - SWIG_fail; - } - vec.push_back(lsbp); - Py_DECREF(vec_elt); - } - temps.push_back(vec); - Py_DECREF(o); - } - $1 = temps; -} - -// ExecutableBuildOptions - -%typemap(in) const ExecutableBuildOptions* - (ExecutableBuildOptions build_options) { - if ($input == Py_None) { - $1 = NULL; - } else { - if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { - 
build_options.mutable_debug_options()->set_xla_generate_hlo_graph(std::move(s)); - })) { - return nullptr; - } - if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { - build_options.mutable_debug_options()->set_xla_dump_optimized_hlo_proto_to(std::move(s)); - })) { - return nullptr; - } - if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { - build_options.mutable_debug_options()->set_xla_dump_unoptimized_hlo_proto_to(std::move(s)); - })) { - return nullptr; - } - if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { - build_options.mutable_debug_options()->set_xla_dump_per_pass_hlo_proto_to(std::move(s)); - })) { - return nullptr; - } - - PyObject* o = PyObject_GetAttrString($input, "hlo_profile"); - if (o == NULL) { - SWIG_fail; - } - if (o != Py_None) { - if (!PyBool_Check(o)) { - PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); - SWIG_fail; - } - build_options.mutable_debug_options()->set_xla_hlo_profile(o == Py_True); - } - Py_DECREF(o); - - o = PyObject_GetAttrString($input, "result_shape"); - if (o == nullptr) { - return nullptr; - } - if (o != Py_None) { - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_TypeError, absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); - Py_DECREF(o); - SWIG_fail; - } - build_options.set_result_layout(statusor.ValueOrDie()); - } - Py_DECREF(o); - - int64 num_replicas; - if (!GetIntAttr($input, "num_replicas", &num_replicas)) { - SWIG_fail; - } - build_options.set_num_replicas(num_replicas); - - $1 = &build_options; - } -} - -%ignoreall -%unignore xla; -%unignore xla::swig; -%unignore xla::swig::RegisterCpuCustomCallTarget; -%unignore xla::swig::LocalClient; -%unignore xla::swig::LocalClient::Get; -%unignore xla::swig::LocalClient::DeviceCount; -%unignore xla::swig::LocalClient::TransferToInfeed; -%unignore xla::swig::LocalClient::TransferFromOutfeed; -%unignore xla::swig::LocalShapedBuffer; -%unignore xla::swig::LocalShapedBuffer::FromLiteral; -%unignore xla::swig::LocalShapedBuffer::ToLiteral; -%unignore xla::swig::LocalShapedBuffer::shape; -%unignore xla::swig::LocalShapedBuffer::DestructureTuple; -%unignore xla::swig::LocalShapedBufferTuple; -%unignore xla::swig::LocalShapedBufferTuple::Release; -%unignore xla::swig::LocalShapedBufferTuple::size; -%unignore xla::swig::LocalExecutable; -%unignore xla::swig::LocalExecutable::DeviceOrdinals; -%unignore xla::swig::LocalExecutable::Execute; -%unignore xla::swig::LocalExecutable::ExecutePerReplica; -%unignore xla::swig::Computation; -%unignore xla::swig::Computation::Compile; -%unignore xla::swig::Computation::GetProgramShape; -%unignore xla::swig::Computation::GetReturnValueShape; -%unignore xla::swig::Computation::GetSerializedProto; -%unignore xla::swig::Computation::GetHloText; -%unignore xla::swig::Computation::GetHloDotGraph; -%unignore xla::swig::LocalOp; -%unignore xla::swig::ComputationBuilder; -%unignore xla::swig::ComputationBuilder::ComputationBuilder; -%unignore xla::swig::ComputationBuilder::Build; -%unignore xla::swig::ComputationBuilder::BuildWithRoot; -%unignore xla::swig::ComputationBuilder::SetOpMetadata; -%unignore xla::swig::ComputationBuilder::ClearOpMetadata; -%unignore xla::swig::ComputationBuilder::Parameter; -%unignore xla::swig::ComputationBuilder::GetShape; -%unignore 
xla::swig::ComputationBuilder::GetReturnValueShape; -%unignore xla::swig::ComputationBuilder::ReplicaId; -%unignore xla::swig::ComputationBuilder::Infeed; -%unignore xla::swig::ComputationBuilder::Outfeed; -%unignore xla::swig::ComputationBuilder::ConstantLiteral; -%unignore xla::swig::ComputationBuilder::ConstantR0; -%unignore xla::swig::ComputationBuilder::Iota; -%unignore xla::swig::ComputationBuilder::BroadcastedIota; -%unignore xla::swig::ComputationBuilder::Broadcast; -%unignore xla::swig::ComputationBuilder::BroadcastInDim; -%unignore xla::swig::ComputationBuilder::Pad; -%unignore xla::swig::ComputationBuilder::Reshape; -%unignore xla::swig::ComputationBuilder::Collapse; -%unignore xla::swig::ComputationBuilder::AllToAll; -%unignore xla::swig::ComputationBuilder::CrossReplicaSum; -%unignore xla::swig::ComputationBuilder::Slice; -%unignore xla::swig::ComputationBuilder::SliceInDim; -%unignore xla::swig::ComputationBuilder::DynamicSlice; -%unignore xla::swig::ComputationBuilder::DynamicUpdateSlice; -%unignore xla::swig::ComputationBuilder::ConcatInDim; -%unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding; -%unignore xla::swig::ComputationBuilder::Select; -%unignore xla::swig::ComputationBuilder::Tuple; -%unignore xla::swig::ComputationBuilder::GetTupleElement; -%unignore xla::swig::ComputationBuilder::ConvertElementType; -%unignore xla::swig::ComputationBuilder::BitcastConvertType; -%unignore xla::swig::ComputationBuilder::Call; -%unignore xla::swig::ComputationBuilder::Transpose; -%unignore xla::swig::ComputationBuilder::Rev; -%unignore xla::swig::ComputationBuilder::Clamp; -%unignore xla::swig::ComputationBuilder::Map; -%unignore xla::swig::ComputationBuilder::Reduce; -%unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding; -%unignore xla::swig::ComputationBuilder::RngNormal; -%unignore xla::swig::ComputationBuilder::RngUniform; -%unignore xla::swig::ComputationBuilder::RngBernoulli; -%unignore xla::swig::ComputationBuilder::While; -%unignore xla::swig::ComputationBuilder::Conditional; -%unignore xla::swig::ComputationBuilder::IsConstant; -%unignore xla::swig::ComputationBuilder::Eq; -%unignore xla::swig::ComputationBuilder::Ne; -%unignore xla::swig::ComputationBuilder::Ge; -%unignore xla::swig::ComputationBuilder::Gt; -%unignore xla::swig::ComputationBuilder::Lt; -%unignore xla::swig::ComputationBuilder::Le; -%unignore xla::swig::ComputationBuilder::Dot; -%unignore xla::swig::ComputationBuilder::DotGeneral; -%unignore xla::swig::ComputationBuilder::ConvGeneralDilated; -%unignore xla::swig::ComputationBuilder::Add; -%unignore xla::swig::ComputationBuilder::Sub; -%unignore xla::swig::ComputationBuilder::Mul; -%unignore xla::swig::ComputationBuilder::Div; -%unignore xla::swig::ComputationBuilder::Rem; -%unignore xla::swig::ComputationBuilder::Max; -%unignore xla::swig::ComputationBuilder::Min; -%unignore xla::swig::ComputationBuilder::And; -%unignore xla::swig::ComputationBuilder::Or; -%unignore xla::swig::ComputationBuilder::Xor; -%unignore xla::swig::ComputationBuilder::ShiftLeft; -%unignore xla::swig::ComputationBuilder::ShiftRightArithmetic; -%unignore xla::swig::ComputationBuilder::ShiftRightLogical; -%unignore xla::swig::ComputationBuilder::Not; -%unignore xla::swig::ComputationBuilder::Clz; -%unignore xla::swig::ComputationBuilder::Abs; -%unignore xla::swig::ComputationBuilder::Exp; -%unignore xla::swig::ComputationBuilder::Expm1; -%unignore xla::swig::ComputationBuilder::Floor; -%unignore xla::swig::ComputationBuilder::Ceil; -%unignore 
xla::swig::ComputationBuilder::Round; -%unignore xla::swig::ComputationBuilder::Log; -%unignore xla::swig::ComputationBuilder::Log1p; -%unignore xla::swig::ComputationBuilder::Sign; -%unignore xla::swig::ComputationBuilder::Cos; -%unignore xla::swig::ComputationBuilder::Sin; -%unignore xla::swig::ComputationBuilder::Tanh; -%unignore xla::swig::ComputationBuilder::Atan2; -%unignore xla::swig::ComputationBuilder::IsFinite; -%unignore xla::swig::ComputationBuilder::Pow; -%unignore xla::swig::ComputationBuilder::Neg; -%unignore xla::swig::ComputationBuilder::Sort; -%unignore xla::swig::ComputationBuilder::SortKeyVal; -%unignore xla::swig::ComputationBuilder::Sqrt; -%unignore xla::swig::ComputationBuilder::Rsqrt; -%unignore xla::swig::ComputationBuilder::Square; -%unignore xla::swig::ComputationBuilder::Reciprocal; -%unignore xla::swig::ComputationBuilder::Erfc; -%unignore xla::swig::ComputationBuilder::Erf; -%unignore xla::swig::ComputationBuilder::ErfInv; -%unignore xla::swig::ComputationBuilder::Lgamma; -%unignore xla::swig::ComputationBuilder::Digamma; -%unignore xla::swig::ComputationBuilder::Acos; -%unignore xla::swig::ComputationBuilder::Asin; -%unignore xla::swig::ComputationBuilder::Atan; -%unignore xla::swig::ComputationBuilder::Tan; -%unignore xla::swig::ComputationBuilder::Acosh; -%unignore xla::swig::ComputationBuilder::Asinh; -%unignore xla::swig::ComputationBuilder::Atanh; -%unignore xla::swig::ComputationBuilder::Cosh; -%unignore xla::swig::ComputationBuilder::Sinh; -%unignore xla::swig::ComputationBuilder::Real; -%unignore xla::swig::ComputationBuilder::Imag; -%unignore xla::swig::ComputationBuilder::Conj; -%unignore xla::swig::ComputationBuilder::Complex; -%unignore xla::swig::ComputationBuilder::Cholesky; -%unignore xla::swig::ComputationBuilder::QR; -%unignore xla::swig::ComputationBuilder::TriangularSolve; -%unignore xla::swig::ComputationBuilder::CustomCall; -%unignore xla::swig::ComputationBuilder::Gather; -%unignore xla::swig::ComputationBuilder::Scatter; -%unignore xla::swig::DeleteComputation; -%unignore xla::swig::DeleteLocalShapedBuffer; -%unignore xla::swig::DeleteLocalExecutable; - -%thread; -%include "tensorflow/compiler/xla/python/local_computation_builder.h" -%nothread; - -%unignoreall diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc deleted file mode 100644 index 74f45b7cdcf..00000000000 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ /dev/null @@ -1,658 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/python/numpy_bridge.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -namespace swig { - -namespace numpy { - -Safe_PyObjectPtr make_safe(PyObject* object) { - return Safe_PyObjectPtr(object); -} - -int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { - switch (primitive_type) { - case PRED: - return NPY_BOOL; - case S8: - return NPY_INT8; - case S16: - return NPY_INT16; - case S32: - return NPY_INT32; - case S64: - return NPY_INT64; - case U8: - return NPY_UINT8; - case U16: - return NPY_UINT16; - case U32: - return NPY_UINT32; - case U64: - return NPY_UINT64; - case F16: - return NPY_FLOAT16; - case F32: - return NPY_FLOAT32; - case F64: - return NPY_FLOAT64; - case C64: - return NPY_COMPLEX64; - case C128: - return NPY_COMPLEX128; - case TUPLE: - return NPY_OBJECT; - default: - LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type; - } -} - -PrimitiveType NumpyTypeToPrimitiveType(int np_type) { - switch (np_type) { - case NPY_BOOL: - return PRED; - case NPY_INT8: - return S8; - case NPY_INT16: - return S16; - case NPY_INT32: - return S32; - case NPY_INT64: - return S64; - case NPY_UINT8: - return U8; - case NPY_UINT16: - return U16; - case NPY_UINT32: - return U32; - case NPY_UINT64: - return U64; - case NPY_FLOAT16: - return F16; - case NPY_FLOAT32: - return F32; - case NPY_FLOAT64: - return F64; - case NPY_COMPLEX64: - return C64; - case NPY_COMPLEX128: - return C128; - case NPY_OBJECT: - return TUPLE; - default: - LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type; - } -} - -bool NumpyTypeIsValid(int np_type) { - switch (np_type) { - case NPY_BOOL: - case NPY_INT8: - case NPY_INT16: - case NPY_INT32: - case NPY_INT64: - case NPY_UINT8: - case NPY_UINT16: - case NPY_UINT32: - case NPY_UINT64: - case NPY_FLOAT16: - case NPY_FLOAT32: - case NPY_FLOAT64: - case NPY_COMPLEX64: - case NPY_COMPLEX128: - case NPY_OBJECT: - return true; - default: - return false; - } -} - -Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) { - int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); - PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); - - Safe_PyObjectPtr dimensions; - if (shape.IsTuple()) { - int num_elements = ShapeUtil::TupleElementCount(shape); - dimensions = make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape))); - for (int i = 0; i < num_elements; ++i) { - PyTuple_SET_ITEM( - dimensions.get(), i, - PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)) - .release()); - } - } else { - int rank = shape.rank(); - dimensions = make_safe(PyTuple_New(rank)); - for (int i = 0; i < rank; ++i) { - PyTuple_SET_ITEM(dimensions.get(), i, - LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); - } - } - return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release())); -} - -Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( - const ProgramShape& shape) { - Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size())); - for (int i = 0; i < shape.parameters_size(); ++i) { - PyTuple_SET_ITEM(arg_shapes.get(), i, - PyShapeInfoFromXlaShape(shape.parameters(i)).release()); - } - - Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result()); - return make_safe( - 
PyTuple_Pack(2, arg_shapes.release(), result_shape.release())); -} - -// Precondition: o->ob_type == &PyArrayDescr_Type -static int NumpyTypenum(PyObject* o) { - return reinterpret_cast(o)->type_num; -} - -// Extracts the string held inside r and returns it as a C++ string. -// -// NOTE: this is an internal helper for conversion to a C++, and so decrefs r. -static string ExtractStringAndDecref(PyObject* r) { - auto error = [r] { return absl::StrFormat("", r); }; - if (r == nullptr) { - return error(); - } -#if PY_MAJOR_VERSION < 3 - string result = PyString_AsString(r); -#else - PyObject* bytes = PyUnicode_AsEncodedString(r, 0, 0); - if (bytes == nullptr) { - return error(); - } - CHECK(PyBytes_Check(bytes)); - string result = PyBytes_AsString(bytes); - Py_DECREF(bytes); -#endif - Py_DECREF(r); - return result; -} - -// Safely returns a str of the given Python object o as a C++ string. -static string PyObjectCppStr(PyObject* o) { - PyObject* s = PyObject_Str(o); - return ExtractStringAndDecref(s); -} - -string PyObjectCppRepr(PyObject* o) { - PyObject* r = PyObject_Repr(o); - return ExtractStringAndDecref(r); -} - -StatusOr XlaShapeFromPyShape(PyObject* o) { - auto error = [o](const string& prefix) { - return InvalidArgument("%s; got %s", prefix.c_str(), - PyObjectCppRepr(o).c_str()); - }; - - auto call_method = [o, &error](const string& method) -> StatusOr { - PyObject* result = - PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); - if (result == nullptr) { - return error( - absl::StrCat("Failed to call method of shape object:", method)); - } - return result; - }; - - PyObject* np_type; - TF_ASSIGN_OR_RETURN(np_type, call_method("numpy_dtype")); - if (np_type->ob_type != &PyArrayDescr_Type) { - return error( - "Return value of shape method numpy_dtype " - "is not an integer numpy dtype"); - } - if (!NumpyTypeIsValid(NumpyTypenum(np_type))) { - return error( - "Return value of shape method numpy_dtype " - "is not a valid integer numpy dtype"); - } - const PrimitiveType element_type = - NumpyTypeToPrimitiveType(NumpyTypenum(np_type)); - Py_DECREF(np_type); - - if (element_type == TUPLE) { - PyObject* py_subshapes; - TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes")); - if (!PyTuple_Check(py_subshapes)) { - return error( - "Return value of Shape method tuple_shapes() is not a tuple"); - } - const int length = PyTuple_Size(py_subshapes); - std::vector subshapes; - subshapes.reserve(length); - for (int i = 0; i < length; i++) { - TF_ASSIGN_OR_RETURN( - const Shape& subshape, - XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i))); - subshapes.push_back(subshape); - } - Py_DECREF(py_subshapes); - return ShapeUtil::MakeTupleShape(subshapes); - } else { - PyObject* py_dimensions; - PyObject* py_minor_to_major; - TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions")); - TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major")); - if (!PyTuple_Check(py_dimensions)) { - return error("Return value of Shape method dimensions() is not a tuple"); - } - if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) { - return error( - "Return value of Shape method minor_to_major() is neither a tuple " - "nor None"); - } - const int length = PyTuple_Size(py_dimensions); - if (py_minor_to_major != Py_None && - length != PyTuple_Size(py_minor_to_major)) { - return error( - "Shape methods dimensions() and minor_to_major() return " - "different-length tuples"); - } - std::vector dimensions(length); - std::vector minor_to_major(length); - for (int i = 0; i < 
length; i++) { - dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i)); - if (dimensions[i] == -1 && PyErr_Occurred()) { - return error("Dimension is not an int"); - } - - if (py_minor_to_major != Py_None) { - minor_to_major[i] = - PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i)); - if (minor_to_major[i] == -1 && PyErr_Occurred()) { - return error("Minor-to-major value is not an int"); - } - } - } - bool with_layout = py_minor_to_major != Py_None; - Py_DECREF(py_dimensions); - Py_DECREF(py_minor_to_major); - if (with_layout) { - return ShapeUtil::MakeShapeWithLayout(element_type, dimensions, - minor_to_major); - } else { - return ShapeUtil::MakeShape(element_type, dimensions); - } - } -} - -// Helper that retrieves the member with attr_name, stringifies it if is not -// None, and returns it as a C++ string. -static absl::optional GetAttrAsString(PyObject* o, - const string& attr_name) { - if (!PyObject_HasAttrString(o, attr_name.c_str())) { - return absl::nullopt; - } - PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); - if (attr == Py_None) { - Py_DECREF(attr); - return absl::nullopt; - } - string result = PyObjectCppStr(attr); - Py_DECREF(attr); - return result; -} - -// Helper that retrieves the member with attr_name, checks that it is an integer -// if it is not None, and returns it as an int32 value. -static absl::optional GetAttrAsInt32(PyObject* o, - const string& attr_name) { - if (!PyObject_HasAttrString(o, attr_name.c_str())) { - return absl::nullopt; - } - PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); - if (attr == Py_None) { - Py_DECREF(attr); - return absl::nullopt; - } - if (!CheckPyIntOrLong(attr)) { - Py_DECREF(attr); - return absl::nullopt; - } - long value = PyIntOrPyLongToLong(attr); // NOLINT - Py_DECREF(attr); - if (value == -1 && PyErr_Occurred() != nullptr) { - return absl::nullopt; - } - if (static_cast(value) != value) { - return absl::nullopt; - } - return value; -} - -StatusOr OpMetadataFromPyObject(PyObject* o) { - OpMetadata result; - absl::optional op_type = GetAttrAsString(o, "op_type"); - if (op_type.has_value()) { - result.set_op_type(op_type.value()); - } - absl::optional op_name = GetAttrAsString(o, "op_name"); - if (op_name.has_value()) { - result.set_op_name(op_name.value()); - } - absl::optional source_file = GetAttrAsString(o, "source_file"); - if (source_file.has_value()) { - result.set_source_file(source_file.value()); - } - absl::optional source_line = GetAttrAsInt32(o, "source_line"); - if (source_line.has_value()) { - result.set_source_line(source_line.value()); - } - return result; -} - -StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal) { - if (literal.shape().IsTuple()) { - int num_elements = ShapeUtil::TupleElementCount(literal.shape()); - std::vector elems(num_elements); - for (int i = 0; i < num_elements; i++) { - TF_ASSIGN_OR_RETURN(elems[i], - PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); - } - Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements)); - for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM(tuple.get(), i, elems[i].release()); - } - return tuple; - } else { - int rank = literal.shape().rank(); - std::vector dimensions(rank); // NOLINT - PyArray requires a long* - for (int i = 0; i < rank; i++) { - dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); - } - int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); - Safe_PyObjectPtr array = make_safe( - PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0)); - 
TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray( - np_type, literal, reinterpret_cast(array.get()))); - return array; - } -} - -StatusOr XlaLiteralFromPyObject(PyObject* o) { - if (PyTuple_Check(o)) { - int num_elements = PyTuple_Size(o); - std::vector elements; - elements.reserve(num_elements); - for (int i = 0; i < num_elements; i++) { - PyObject* element = PyTuple_GetItem(o, i); - TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element)); - elements.push_back(std::move(literal)); - } - return LiteralUtil::MakeTupleOwned(std::move(elements)); - } else if (PyArray_Check(o)) { - PyArrayObject* py_array = reinterpret_cast(o); - int rank = PyArray_NDIM(py_array); - std::vector dimensions(rank); - for (int i = 0; i < rank; i++) { - dimensions[i] = PyArray_DIM(py_array, i); - } - int np_type = PyArray_TYPE(py_array); - auto literal = LiteralUtil::CreateFromDimensions( - NumpyTypeToPrimitiveType(np_type), dimensions); - TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal)); - return std::move(literal); - } else { - return InvalidArgument( - "Non-tuple or Numpy array encountered in conversion to XLA literal."); - } -} - -Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, - Literal* literal) { - switch (np_type) { - case NPY_BOOL: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_INT8: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_INT16: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_INT32: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_INT64: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_UINT8: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_UINT16: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_UINT32: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_UINT64: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_FLOAT16: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_FLOAT32: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_FLOAT64: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_COMPLEX64: - CopyNumpyArrayToLiteral(py_array, literal); - break; - case NPY_COMPLEX128: - CopyNumpyArrayToLiteral(py_array, literal); - break; - default: - return InvalidArgument( - "No XLA literal container for Numpy type number: %d", np_type); - } - return Status::OK(); -} - -Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array) { - switch (np_type) { - case NPY_BOOL: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_INT8: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_INT16: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_INT32: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_INT64: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_UINT8: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_UINT16: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_UINT32: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_UINT64: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_FLOAT16: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_FLOAT32: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_FLOAT64: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_COMPLEX64: - CopyLiteralToNumpyArray(literal, py_array); - break; - case NPY_COMPLEX128: - 
CopyLiteralToNumpyArray(literal, py_array); - break; - default: - return InvalidArgument( - "No XLA literal container for Numpy type number: %d", np_type); - } - return Status::OK(); -} - -PyObject* LongToPyIntOrPyLong(long x) { // NOLINT -#if PY_MAJOR_VERSION < 3 - return PyInt_FromLong(x); -#else - return PyLong_FromLong(x); -#endif -} - -long PyIntOrPyLongToLong(PyObject* o) { // NOLINT -#if PY_MAJOR_VERSION < 3 - return PyInt_AsLong(o); -#else - return PyLong_AsLong(o); -#endif -} - -bool CheckPyIntOrLong(PyObject* o) { -#if PY_MAJOR_VERSION < 3 - return PyInt_Check(o); -#else - if (!PyLong_Check(o)) { - return false; - } - int overflow = 0; - PyLong_AsLongAndOverflow(o, &overflow); - return (overflow == 0); -#endif -} - -PyObject* PyNumberToPyInt(PyObject* o) { -#if PY_MAJOR_VERSION < 3 - return PyNumber_Int(o); -#else - return PyNumber_Long(o); -#endif -} - -} // namespace numpy - -bool GetIntAttr(PyObject* o, const char* field, int64* result) { - PyObject* fo = PyObject_GetAttrString(o, field); - if (!fo) { - return false; - } - const int64 value = numpy::PyIntOrPyLongToLong(fo); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(fo); - return false; - } - Py_DECREF(fo); - *result = value; - return true; -} - -// Returns "ok"; true if there is no error, false if there was an error. -bool HandleStringAttribute(PyObject* o, const char* attr_name, - std::function f) { - if (!PyObject_HasAttrString(o, attr_name)) { - return true; // It's ok for the object to not have the attribute. - } - PyObject* attr = PyObject_GetAttrString(o, attr_name); - if (attr == nullptr) { - return false; // An error occurred getting the attribute. - } - if (attr == Py_None) { - Py_DECREF(attr); - return true; // The attribute is None, which we consider ok. - } -#if PY_MAJOR_VERSION < 3 - if (!PyString_Check(attr)) { - string message = absl::StrFormat("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr)); - PyErr_SetString(PyExc_TypeError, message.c_str()); - Py_DECREF(attr); - return false; // Type error, not ok. - } - f(PyString_AsString(attr)); -#else - if (!PyBytes_Check(attr)) { - string message = absl::StrFormat("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr)); - PyErr_SetString(PyExc_TypeError, message.c_str()); - Py_DECREF(attr); - return false; // Type error, not ok. - } - f(PyBytes_AsString(attr)); -#endif - - Py_DECREF(attr); - return true; // Handled string attribute, ok! -} - -bool HandleRepeatedInt64Attribute( - PyObject* o, const char* attr_name, - tensorflow::protobuf::RepeatedField* field) { - PyObject* seq = PyObject_GetAttrString(o, attr_name); - if (!seq) { - return false; - } - - int length = PySequence_Size(seq); - if (length == -1) { - Py_DECREF(seq); - return false; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(seq, i); - if (!item) { - Py_DECREF(seq); - return false; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(seq); - return false; - } - *field->Add() = dimension; - Py_DECREF(item); - } - Py_DECREF(seq); - return true; -} - -} // namespace swig - -} // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h deleted file mode 100644 index eff8cda334f..00000000000 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// These functions transform Python/Numpy data structures to XLA data -// structures and vice versa, performing copies where -// appropriate. Python tuples and Numpy ndarrays translate to XLA -// tuples and XLA literals, respectively, and Numpy shape/dtype -// information is translated to XLA shape information. - -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/python/lib/core/numpy.h" - -namespace xla { - -namespace swig { - -namespace numpy { - -struct PyDecrefDeleter { - void operator()(PyObject* p) const { Py_DECREF(p); } -}; - -// Safe container for an owned PyObject. On destruction, the reference count of -// the contained object will be decremented. -using Safe_PyObjectPtr = std::unique_ptr; - -Safe_PyObjectPtr make_safe(PyObject* object); - -// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy -// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and -// vice versa. -int PrimitiveTypeToNumpyType(PrimitiveType primitive_type); -PrimitiveType NumpyTypeToPrimitiveType(int np_type); - -// Determines whether an integer-encoded Numpy dtype is valid, -// i.e. has a supported conversion to an XLA PrimitiveType. -bool NumpyTypeIsValid(int np_type); - -// Converts XLA shape information into a Python pair of the form -// (numpy dtype, dimensions). If the XLA shape represents a tuple, -// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a -// Python tuple of shape-description pairs, created -// recursively. Otherwise, `dimensions` is a Python tuple-of-integers -// providing the array dimensions. -// -// The return value is a new reference. -Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape); - -// Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple -// of argument shapes and result_shape is the result shape. Each shape is as -// described in in PyShapeInfoFromXlaShape's comment. -Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( - const ProgramShape& shape); - -// Converts a Python object with a method interface mathing that of -// xla_client.Shape into an XLA Shape object. -// -// The return value is a new reference. -StatusOr XlaShapeFromPyShape(PyObject* o); - -// Converts a PyObject that represents operation metadata into protocol buffer -// form. -StatusOr OpMetadataFromPyObject(PyObject* o); - -// Converts an XLA literal to a Python object, either a Numpy ndarray -// or a nested Python tuple thereof. -// -// To avoid transferring ownership of the data buffers that underlie -// PyArrays and XLA literals, this function makes deep copies of all -// array data. -// -// The return value is a new reference. 
-StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal); - -// Converts a Numpy ndarray or a nested Python tuple thereof to a -// corresponding XLA literal. -// -// To avoid transferring ownership of the data buffers that underlie -// PyArrays and XLA literals, this function makes deep copies of all -// array data. -StatusOr XlaLiteralFromPyObject(PyObject* o); - -// The following functions copy array data from the buffers underlying Numpy -// ndarrays into those underlying XLA literals, and vice versa. - -Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, - Literal* literal); - -Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array); - -template -void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { - NativeT* source = static_cast(PyArray_DATA(py_array)); - auto dest = literal->data(); - std::copy(source, source + PyArray_SIZE(py_array), dest.data()); -} - -template -void CopyLiteralToNumpyArray(const LiteralSlice& literal, - PyArrayObject* py_array) { - NativeT* dest = static_cast(PyArray_DATA(py_array)); - auto source = literal.data(); - std::copy(source.begin(), source.end(), dest); -} - -// Safely returns a repr of the given Python object o as a C++ string. -string PyObjectCppRepr(PyObject* o); - -// Workarounds for Python 2 and 3 interop - -PyObject* LongToPyIntOrPyLong(long x); // NOLINT -long PyIntOrPyLongToLong(PyObject* o); // NOLINT -bool CheckPyIntOrLong(PyObject* o); -PyObject* PyNumberToPyInt(PyObject* o); - -} // namespace numpy - -// Miscellaneous swig helpers that don't have a better home. - -bool GetIntAttr(PyObject* o, const char* field, int64* result); - -// Returns "ok"; true if there is no error, false if there was an error. -bool HandleStringAttribute(PyObject* o, const char* attr_name, - std::function f); - -bool HandleRepeatedInt64Attribute( - PyObject* o, const char* attr_name, - tensorflow::protobuf::RepeatedField* field); - -} // namespace swig - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ diff --git a/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds deleted file mode 100644 index ef77ed3d958..00000000000 --- a/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds +++ /dev/null @@ -1,2 +0,0 @@ -_PyInit__pywrap_xla -_init_pywrap_xla diff --git a/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds deleted file mode 100644 index d31cfce7be7..00000000000 --- a/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds +++ /dev/null @@ -1,6 +0,0 @@ -xla { - global: - PyInit_*; - local: - *; -}; diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.cc b/tensorflow/compiler/xla/python/shared_device_buffer.cc new file mode 100644 index 00000000000..8d7ce0088a4 --- /dev/null +++ b/tensorflow/compiler/xla/python/shared_device_buffer.cc @@ -0,0 +1,184 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/shared_device_buffer.h" + +#include "tensorflow/stream_executor/device_memory_allocator.h" + +namespace xla { + +BufferDefinitionEvent::BufferDefinitionEvent(se::StreamExecutor* executor) + : event_(executor) {} + +void BufferDefinitionEvent::RecordOnStream(se::Stream* stream) { + absl::MutexLock lock(&mu_); + CHECK(streams_defined_on_.empty()); + stream->ThenRecordEvent(&event_); + streams_defined_on_.push_back(stream); +} + +void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) { + absl::MutexLock lock(&mu_); + + // The set of defined streams is expected to be very small indeed (usually + // 1-2), so a simple linear scan should be fast enough. + if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(), + stream) != streams_defined_on_.end()) { + // stream is in streams_defined_on_; it doesn't need to be waited on. + return; + } + + stream->ThenWaitFor(&event_); + streams_defined_on_.push_back(stream); +} + +static std::shared_ptr +BufferFromScopedShapedBufferIterator( + const Shape& on_device_shape, int device_ordinal, + se::DeviceMemoryAllocator* allocator, + ShapeTree::iterator* iterator, + const ShapeTree::iterator& end, + const std::shared_ptr& definition_event) { + CHECK(*iterator != end); + + se::OwningDeviceMemory device_memory((*iterator)->second, device_ordinal, + allocator); + (*iterator)->second = se::DeviceMemoryBase(); + ++*iterator; + + std::vector> children; + if (on_device_shape.IsTuple()) { + int num_children = ShapeUtil::TupleElementCount(on_device_shape); + children.reserve(num_children); + for (int i = 0; i < num_children; ++i) { + children.push_back(BufferFromScopedShapedBufferIterator( + on_device_shape.tuple_shapes(i), device_ordinal, allocator, iterator, + end, definition_event)); + } + } + return std::make_shared( + on_device_shape, std::move(device_memory), children, definition_event); +} + +/* static */ std::shared_ptr +PySharedDeviceBuffer::FromScopedShapedBuffer( + ScopedShapedBuffer shaped_buffer, + const std::shared_ptr& definition_event) { + ShapeTree::iterator iterator = + shaped_buffer.buffers().begin(); + std::shared_ptr output = + BufferFromScopedShapedBufferIterator( + shaped_buffer.on_device_shape(), shaped_buffer.device_ordinal(), + shaped_buffer.memory_allocator(), &iterator, + shaped_buffer.buffers().end(), definition_event); + CHECK(iterator == shaped_buffer.buffers().end()); + return output; +} + +/* static */ StatusOr> +PySharedDeviceBuffer::MakeTuple( + std::vector> children, + TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, + int device_ordinal, + std::shared_ptr definition_event) { + std::vector child_shapes; + child_shapes.reserve(children.size()); + for (const auto& child : children) { + TF_RET_CHECK(child->device_memory().device_ordinal() == device_ordinal); + child_shapes.push_back(child->on_device_shape()); + } + + Shape shape = ShapeUtil::MakeTupleShape(child_shapes); + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory device_memory, + allocator->Allocate(device_ordinal, + transfer_manager->GetByteSizeRequirement(shape))); + return std::make_shared( + std::move(shape), std::move(device_memory), std::move(children), + std::move(definition_event)); +} + +/* static */ StatusOr> +PySharedDeviceBuffer::MakeArray( + Shape on_device_shape, TransferManager* 
transfer_manager, + se::DeviceMemoryAllocator* allocator, int device_ordinal, + std::shared_ptr definition_event) { + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory device_memory, + allocator->Allocate( + device_ordinal, + transfer_manager->GetByteSizeRequirement(on_device_shape))); + return std::make_shared( + std::move(on_device_shape), std::move(device_memory), + /*children=*/std::vector>{}, + std::move(definition_event)); +} + +// Populates a buffer tree from a ShapeTree iterator. +static void PopulateShapedBufferFromBuffer( + const PySharedDeviceBuffer& buffer, + ShapeTree::iterator* iterator, + const ShapeTree::iterator& end) { + CHECK(*iterator != end); + (*iterator)->second = *buffer.device_memory(); + ++*iterator; + for (const auto& child : buffer.children()) { + PopulateShapedBufferFromBuffer(*child, iterator, end); + } +} + +ShapedBuffer PySharedDeviceBuffer::AsShapedBuffer( + const Shape& on_host_shape) const { + ShapedBuffer shaped_buffer(on_host_shape, on_device_shape_, + device_memory_.allocator()->platform(), + device_memory_.device_ordinal()); + ShapeTree::iterator iterator = + shaped_buffer.buffers().begin(); + PopulateShapedBufferFromBuffer(*this, &iterator, + shaped_buffer.buffers().end()); + CHECK(iterator == shaped_buffer.buffers().end()); + return shaped_buffer; +} + +PySharedDeviceBuffer::PySharedDeviceBuffer( + Shape on_device_shape, se::OwningDeviceMemory device_memory, + std::vector> children, + std::shared_ptr definition_event) + : on_device_shape_(std::move(on_device_shape)), + device_memory_(std::move(device_memory)), + children_(std::move(children)), + definition_event_(std::move(definition_event)) {} + +void GetDeviceBufferDefinitionEvents( + const PySharedDeviceBuffer& buffer, + absl::flat_hash_set* events) { + if (buffer.definition_event()) { + events->insert(buffer.definition_event().get()); + } + for (const auto& child : buffer.children()) { + GetDeviceBufferDefinitionEvents(*child, events); + } +} + +void WaitForBufferDefinitionEventsOnStream(const PySharedDeviceBuffer& buffer, + se::Stream* stream) { + absl::flat_hash_set events; + GetDeviceBufferDefinitionEvents(buffer, &events); + for (BufferDefinitionEvent* event : events) { + event->WaitForEventOnStream(stream); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.h b/tensorflow/compiler/xla/python/shared_device_buffer.h new file mode 100644 index 00000000000..31cab5ade45 --- /dev/null +++ b/tensorflow/compiler/xla/python/shared_device_buffer.h @@ -0,0 +1,154 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" + +namespace xla { + +// A BufferDefinitionEvent describes whether a buffer is valid from the +// viewpoint of each of stream that may access it. +// +// Each logical buffer in an XLA computation may be defined (i.e., written to) +// at most once, although the same physical piece of memory may be reused for +// multiple logical buffers. We call the operation that writes the buffer's +// value on some stream (e.g., a transfer or compute kernel) the buffer's +// definition event. +// +// After the operation that populates the value of a buffer has been enqueued on +// 'stream', RecordOnStream(stream) should also be called to trigger the +// definition event after the operation has completed. +// +// Since different streams are not necessarily synchronized with one another, +// if we wish to consume the value of the buffer on a different stream, we +// should first call WaitForEventOnStream(stream), which add a cross-stream +// from 'stream' to the buffer's definition event, causing 'stream' to pause +// until the definition event has been triggered, if needed. Operations on +// 'stream' may then assume that the buffer is valid and its contents correspond +// to the desired buffer. +// +// The dependency logic caches the set of streams at the tail of which the +// definition event is known to have occurred; waiting for the same event on the +// same stream causes no additional waiting. +class BufferDefinitionEvent { + public: + // Creates a new definition event whose event has not yet been triggered. + explicit BufferDefinitionEvent(se::StreamExecutor* executor); + + // Records the definition event on the tail of 'stream'. + void RecordOnStream(se::Stream* stream); + + // Adds synchronization events to 'stream' that wait for this event to be + // defined on 'stream'. Does nothing if the event is already known to have + // occurred by the tail of 'stream'. + void WaitForEventOnStream(se::Stream* stream); + + private: + // An event that is triggered when the content of one or more buffers is + // ready. If this event is nullptr, it is assumed that the buffer's content is + // always defined. + se::Event event_; + + absl::Mutex mu_; + + // A list of all streams for which the buffer's content is known to be defined + // at the tail of the queue, i.e., for any newly enqueued command. + absl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); +}; + +// Class that represents a node in a reference-counted DAG of device buffers. +// Unlike a ShapedBuffer, which owns none of its buffers, and +// ScopedShapedBuffer, which owns an entire buffer tree, the reference counting +// in a PySharedDeviceBuffer DAG is done at the level of individual device +// buffers. Reference counting buffer individually is more convenient when +// manipulating on-device tuples where a tuple and its elements may have +// different lifetimes. +class PySharedDeviceBuffer { + public: + // Converts a ScopedShapedBuffer into a Buffer tree. Takes ownership of the + // contents of the shaped_buffer. 
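+  // The tree is built by walking shaped_buffer's ShapeTree in order: each
+  // node takes ownership of its device allocation (the source entry is
+  // cleared), and every node shares the same `definition_event`.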
+ static std::shared_ptr FromScopedShapedBuffer( + ScopedShapedBuffer shaped_buffer, + const std::shared_ptr& definition_event); + + // Makes a tuple buffer. Does not initialize the tuple table. + static StatusOr> MakeTuple( + std::vector> children, + TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, + int device_ordinal, + std::shared_ptr definition_event); + + // Makes an uninitialized array buffer. + static StatusOr> MakeArray( + Shape on_device_shape, TransferManager* transfer_manager, + se::DeviceMemoryAllocator* allocator, int device_ordinal, + std::shared_ptr definition_event); + + // Builds a ShapedBuffer view onto the buffers of 'tree'. Since + // PySharedDeviceBuffer does not maintain the on-host shape, the caller must + // provide it. We require but do not verify that + // TransferManager::HostShapeToDeviceShape(on_host_shape) == on_device_shape() + ShapedBuffer AsShapedBuffer(const Shape& on_host_shape) const; + + const Shape& on_device_shape() const { return on_device_shape_; } + const std::vector>& children() const { + return children_; + } + const se::OwningDeviceMemory& device_memory() const { return device_memory_; } + int device_ordinal() const { return device_memory_.device_ordinal(); } + const std::shared_ptr definition_event() const { + return definition_event_; + } + + PySharedDeviceBuffer() = default; + PySharedDeviceBuffer( + Shape on_device_shape, se::OwningDeviceMemory device_memory, + std::vector> children, + std::shared_ptr definition_event); + + private: + // We only represent the on-device shape. The on-host shape may not be + // one-to-one with the tree of device buffers, so to avoid representational + // awkwardness we maintain on-host shapes separately. + Shape on_device_shape_; + se::OwningDeviceMemory device_memory_; + std::vector> children_; + + // An event that is triggered when the content of one or more buffers is + // ready during multistream execution. May be nullptr, which is used in the + // single-stream execution case where events are not necessary for buffer + // event sequencing. + std::shared_ptr definition_event_; +}; + +// Populates 'events' with the set of buffer definition events for all buffers +// in the buffer DAG rooted at 'buffer'. +void GetDeviceBufferDefinitionEvents( + const PySharedDeviceBuffer& buffer, + absl::flat_hash_set* events); + +// Waits for all of the buffer definition events in a buffer DAG on 'stream'. +void WaitForBufferDefinitionEventsOnStream(const PySharedDeviceBuffer& buffer, + se::Stream* stream); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ diff --git a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc new file mode 100644 index 00000000000..79f9ecd7ed9 --- /dev/null +++ b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/python/shared_device_buffer.h" + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(PySharedDeviceBufferTest, MakeArray) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + + Shape shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); + TF_ASSERT_OK_AND_ASSIGN( + auto buffer, PySharedDeviceBuffer::MakeArray( + shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + EXPECT_EQ( + buffer->on_device_shape(), + client->backend().transfer_manager()->HostShapeToDeviceShape(shape)); + EXPECT_EQ(buffer->children().size(), 0); + EXPECT_EQ(buffer->device_memory().device_ordinal(), 0); + EXPECT_EQ(buffer->device_memory().allocator(), + client->backend().memory_allocator()); + EXPECT_FALSE(buffer->device_memory().is_null()); +} + +TEST(PySharedDeviceBufferTest, MakeTuple) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + + Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); + Shape b_shape = ShapeUtil::MakeShape(S8, {77}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape}); + TF_ASSERT_OK_AND_ASSIGN( + auto a_buffer, PySharedDeviceBuffer::MakeArray( + a_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN( + auto b_buffer, PySharedDeviceBuffer::MakeArray( + b_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN( + auto tuple_buffer, + PySharedDeviceBuffer::MakeTuple( + {a_buffer, b_buffer}, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + EXPECT_EQ(tuple_buffer->on_device_shape(), + client->backend().transfer_manager()->HostShapeToDeviceShape( + tuple_shape)); + ASSERT_EQ(tuple_buffer->children().size(), 2); + EXPECT_EQ(tuple_buffer->children()[0], a_buffer); + EXPECT_EQ(tuple_buffer->children()[1], b_buffer); + EXPECT_EQ(tuple_buffer->device_memory().device_ordinal(), 0); + EXPECT_EQ(tuple_buffer->device_memory().allocator(), + client->backend().memory_allocator()); + EXPECT_FALSE(tuple_buffer->device_memory().is_null()); +} + +TEST(PySharedDeviceBufferTest, AsShapedBuffer) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + + Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); + Shape b_shape = ShapeUtil::MakeShape(S8, {77}); + Shape ab_tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape}); + Shape c_shape = ShapeUtil::MakeShape(S64, {}); + Shape abc_tuple_shape = ShapeUtil::MakeTupleShape({c_shape, ab_tuple_shape}); + TF_ASSERT_OK_AND_ASSIGN( + auto a_buffer, PySharedDeviceBuffer::MakeArray( + a_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN( + auto b_buffer, PySharedDeviceBuffer::MakeArray( + b_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN( + auto ab_tuple_buffer, + PySharedDeviceBuffer::MakeTuple( + {a_buffer, b_buffer}, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN( + auto c_buffer, PySharedDeviceBuffer::MakeArray( + c_shape, 
client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN( + auto abc_tuple_buffer, + PySharedDeviceBuffer::MakeTuple( + {c_buffer, ab_tuple_buffer}, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, nullptr)); + EXPECT_EQ(abc_tuple_buffer->on_device_shape(), + client->backend().transfer_manager()->HostShapeToDeviceShape( + abc_tuple_shape)); + + ShapedBuffer shaped_buffer = + abc_tuple_buffer->AsShapedBuffer(abc_tuple_shape); + EXPECT_EQ(shaped_buffer.on_host_shape(), abc_tuple_shape); + EXPECT_EQ(shaped_buffer.on_device_shape(), + abc_tuple_buffer->on_device_shape()); + + std::vector expected_buffer_sequence = { + *abc_tuple_buffer->device_memory(), *c_buffer->device_memory(), + *ab_tuple_buffer->device_memory(), *a_buffer->device_memory(), + *b_buffer->device_memory(), + }; + auto it = shaped_buffer.buffers().begin(); + auto expected_it = expected_buffer_sequence.begin(); + while (it != shaped_buffer.buffers().end()) { + ASSERT_TRUE(expected_it != expected_buffer_sequence.end()); + EXPECT_TRUE(expected_it->IsSameAs(it->second)); + ++it; + ++expected_it; + } + EXPECT_TRUE(expected_it == expected_buffer_sequence.end()); +} + +TEST(PySharedDeviceBufferTest, FromScopedShapedBuffer) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + + Literal literal = LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateFullWithDescendingLayout({10, 3, 7}, 33.4f), + LiteralUtil::One(S64)); + + TF_ASSERT_OK_AND_ASSIGN( + ScopedShapedBuffer shaped_buffer, + client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); + std::shared_ptr device_buffer = + PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(shaped_buffer), + nullptr); + + EXPECT_EQ(device_buffer->on_device_shape(), + client->backend().transfer_manager()->HostShapeToDeviceShape( + literal.shape())); + ASSERT_EQ(device_buffer->children().size(), 2); + EXPECT_EQ(device_buffer->children()[0]->on_device_shape(), + client->backend().transfer_manager()->HostShapeToDeviceShape( + ShapeUtil::MakeShape(F32, {10, 3, 7}))); + EXPECT_EQ(device_buffer->children()[1]->on_device_shape(), + client->backend().transfer_manager()->HostShapeToDeviceShape( + ShapeUtil::MakeShape(S64, {}))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/python/types.cc b/tensorflow/compiler/xla/python/types.cc new file mode 100644 index 00000000000..f3c83e48e00 --- /dev/null +++ b/tensorflow/compiler/xla/python/types.cc @@ -0,0 +1,217 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/python/types.h" + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +namespace py = pybind11; + +xla::StatusOr DtypeToPrimitiveType(const py::dtype& np_type) { + static auto* types = + new absl::flat_hash_map, PrimitiveType>({ + {{'b', 1}, PRED}, + {{'i', 1}, S8}, + {{'i', 2}, S16}, + {{'i', 4}, S32}, + {{'i', 8}, S64}, + {{'u', 1}, U8}, + {{'u', 2}, U16}, + {{'u', 4}, U32}, + {{'u', 8}, U64}, + {{'f', 2}, F16}, + {{'f', 4}, F32}, + {{'f', 8}, F64}, + {{'c', 8}, C64}, + {{'c', 16}, C128}, + }); + auto it = types->find({np_type.kind(), np_type.itemsize()}); + if (it == types->end()) { + return InvalidArgument("Unknown NumPy type %c size %d", np_type.kind(), + np_type.itemsize()); + } + return it->second; +} + +xla::StatusOr PrimitiveTypeToDtype(PrimitiveType type) { + switch (type) { + case PRED: + return py::dtype::of(); + case S8: + return py::dtype::of(); + case S16: + return py::dtype::of(); + case S32: + return py::dtype::of(); + case S64: + return py::dtype::of(); + case U8: + return py::dtype::of(); + case U16: + return py::dtype::of(); + case U32: + return py::dtype::of(); + case U64: + return py::dtype::of(); + case F16: + return py::dtype("e"); + case F32: + return py::dtype::of(); + case F64: + return py::dtype::of(); + case C64: + return py::dtype::of>(); + case C128: + return py::dtype::of>(); + default: + return Unimplemented("Unimplemented primitive type %s", + PrimitiveType_Name(type)); + } +} + +// Returns a numpy-style format descriptor string for `type`. +StatusOr FormatDescriptorForPrimitiveType(PrimitiveType type) { + switch (type) { + case PRED: + return py::format_descriptor::format(); + case S8: + return py::format_descriptor::format(); + case S16: + return py::format_descriptor::format(); + case S32: + return py::format_descriptor::format(); + case S64: + return py::format_descriptor::format(); + case U8: + return py::format_descriptor::format(); + case U16: + return py::format_descriptor::format(); + case U32: + return py::format_descriptor::format(); + case U64: + return py::format_descriptor::format(); + case F16: + return std::string("e"); + case F32: + return py::format_descriptor::format(); + case F64: + return py::format_descriptor::format(); + case C64: + return py::format_descriptor>::format(); + case C128: + return py::format_descriptor>::format(); + default: + return Unimplemented("Unimplemented primitive type %s", + PrimitiveType_Name(type)); + } +} + +// Returns the strides for `shape`. 
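+// For example, an F32[3, 101, 4] array with the default (descending) layout
+// has minor_to_major = {2, 1, 0}, giving byte strides
+// {101 * 4 * 4, 4 * 4, 4} = {1616, 16, 4}.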
+std::vector StridesForShape(const Shape& shape) { + std::vector strides; + CHECK(shape.IsArray()); + CHECK(shape.has_layout()); + + strides.resize(shape.dimensions_size()); + ssize_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); + for (int i : shape.layout().minor_to_major()) { + strides.at(i) = stride; + stride *= shape.dimensions(i); + } + return strides; +} + +StatusOr LiteralToPython(std::unique_ptr literal) { + xla::Literal& m = *literal; + if (m.shape().IsTuple()) { + std::vector elems = m.DecomposeTuple(); + std::vector arrays(elems.size()); + for (int i = 0; i < elems.size(); ++i) { + TF_ASSIGN_OR_RETURN( + arrays[i], + LiteralToPython(absl::make_unique(std::move(elems[i])))); + } + py::tuple result(elems.size()); + for (int i = 0; i < elems.size(); ++i) { + PyTuple_SET_ITEM(result.ptr(), i, arrays[i].release().ptr()); + } + return result; + } + TF_RET_CHECK(m.shape().IsArray()); + + auto capsule = py::capsule(literal.release(), [](void* ptr) { + delete reinterpret_cast(ptr); + }); + TF_ASSIGN_OR_RETURN(std::string format, FormatDescriptorForPrimitiveType( + m.shape().element_type())); + py::buffer_info info( + m.untyped_data(), // Pointer to buffer + xla::ShapeUtil::ByteSizeOfPrimitiveType( + m.shape().element_type()), // Size of one scalar + format, // Python struct-style format descriptor + m.shape().dimensions_size(), // Number of dimensions + m.shape().dimensions(), // Buffer dimensions + StridesForShape(m.shape()) // Strides (in bytes) for each index + ); + return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, + capsule); +} + +StatusOr GetPythonBufferTree(const py::object& argument) { + PythonBufferTree tree; + if (py::isinstance(argument)) { + py::tuple tuple = py::reinterpret_borrow(argument); + std::vector host_shapes(tuple.size()); + for (int i = 0; i < host_shapes.size(); ++i) { + TF_ASSIGN_OR_RETURN(PythonBufferTree subtree, + GetPythonBufferTree(tuple[i])); + tree.leaves.reserve(tree.leaves.size() + subtree.leaves.size()); + std::move(subtree.leaves.begin(), subtree.leaves.end(), + std::back_inserter(tree.leaves)); + host_shapes[i] = std::move(subtree.shape); + } + tree.shape = ShapeUtil::MakeTupleShape(host_shapes); + } else { + pybind11::detail::type_caster caster; + if (!caster.load(argument, /*convert=*/true)) { + return InvalidArgument("Invalid array value."); + } + tree.arrays.push_back(std::move(caster.array)); + tree.leaves.push_back(std::move(*caster)); + tree.shape = tree.leaves.front().shape(); + } + return tree; +} + +py::tuple IntSpanToTuple(absl::Span xs) { + py::tuple out(xs.size()); + for (int i = 0; i < xs.size(); ++i) { + out[i] = py::int_(xs[i]); + } + return out; +} + +std::vector IntSequenceToVector(const py::object& sequence) { + std::vector output; + for (auto item : sequence) { + output.push_back(item.cast()); + } + return output; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h new file mode 100644 index 00000000000..d3c867e8304 --- /dev/null +++ b/tensorflow/compiler/xla/python/types.h @@ -0,0 +1,429 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_ + +#include + +#include "absl/types/optional.h" +#include "include/pybind11/numpy.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/stl.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace xla { + +// Helper that converts a failing StatusOr to an exception. +// For use only inside pybind11 code. +template +T ValueOrThrow(StatusOr v) { + if (!v.ok()) { + throw std::runtime_error(v.status().ToString()); + } + return v.ConsumeValueOrDie(); +} + +// Converts a NumPy dtype to a PrimitiveType. +StatusOr DtypeToPrimitiveType(const pybind11::dtype& np_type); + +// Converts a PrimitiveType to a Numpy dtype. +StatusOr PrimitiveTypeToDtype(PrimitiveType type); + +// Converts a literal to (possibly-nested tuples of) NumPy arrays. +// The literal's leaf arrays are not copied; instead the NumPy arrays share +// buffers with the literals. Takes ownership of `literal` and keeps the +// necessary pieces alive using Python reference counting. +// Requires the GIL. +StatusOr LiteralToPython(std::unique_ptr literal); + +// Converts a Python object into an XLA shape and a vector of leaf buffers. +// The leaf buffers correspond to a depth-first, left-to-right traversal of +// the Python value. +// Requires the GIL. +struct PythonBufferTree { + // Holds a reference to the arrays pointed to by `leaves`, since we may + // need to make a copy if the array is not in a C-style layout. + absl::InlinedVector arrays; + absl::InlinedVector leaves; + Shape shape; +}; +StatusOr GetPythonBufferTree( + const pybind11::object& argument); + +// Converts a sequence of int64s to a Python tuple of ints. +// Pybind11 by default converts a std::vector to a Python list; for +// shapes we frequently want a tuple instead. +pybind11::tuple IntSpanToTuple(absl::Span xs); + +// Converts a Python sequence of integers to a std::vector +std::vector IntSequenceToVector(const pybind11::object& sequence); + +} // namespace xla + +// This namespace is a documented pybind11 extension point. +// Caution: Unusually for Google code, this code uses C++ exceptions because +// they are the only mechanism for reporting cast failures to pybind11. However, +// the exceptions are local to the binding code. +namespace pybind11 { +namespace detail { + +// When absl::optional is an alias for std::optional, the type_caster +// specializations are provided by pybind11. 
+#ifndef ABSL_HAVE_STD_OPTIONAL +// absl::optional +template +struct type_caster> : optional_caster> {}; + +template <> +struct type_caster : public void_caster {}; +#endif + +// absl::Span +template +struct type_caster> { + using value_conv = make_caster; + + PYBIND11_TYPE_CASTER(absl::Span, + _("Span[") + value_conv::name() + _("]")); + + // absl::Span doesn't hold ownership. We therefore need a temporary array. + // Pybind appears to keep type_casters alive until the callee has run. + std::vector storage_; + + bool load(handle src, bool convert) { + if (!isinstance(src)) { + return false; + } + auto seq = reinterpret_borrow(src); + storage_.clear(); + storage_.reserve(seq.size()); + for (auto it : seq) { + value_conv conv; + if (!conv.load(it, convert)) { + return false; + } + storage_.push_back(cast_op(std::move(conv))); + } + value = absl::Span(storage_); + return true; + } +}; + +// Status, StatusOr. Failing statuses become Python exceptions; Status::OK() +// becomes None. +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::Status, _("Status")); + + static handle cast(xla::Status src, return_value_policy /* policy */, + handle /* parent */) { + if (!src.ok()) { + throw std::runtime_error(src.ToString()); + } + return none().inc_ref(); + } +}; + +template +struct type_caster> { + public: + using value_conv = make_caster; + + PYBIND11_TYPE_CASTER(xla::StatusOr, + _("StatusOr[") + value_conv::name() + _("]")); + + static handle cast(xla::StatusOr src, return_value_policy policy, + handle parent) { + if (!src.ok()) { + throw std::runtime_error(src.status().ToString()); + } + return value_conv::cast(std::forward>(src).ValueOrDie(), + policy, parent); + } +}; + +// Literals. +// Literal data can be passed to XLA as a NumPy array; its value can be +// cast to an xla::BorrowingLiteral or xla::LiteralSlice in a zero-copy way. +// We don't have any literal -> numpy conversions here, since all the methods +// that want to return arrays build Python objects directly. + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::BorrowingLiteral, _("xla::BorrowingLiteral")); + + // Pybind appears to keep type_casters alive until the callee has run. + pybind11::array array; + + bool load(handle handle, bool) { + array = pybind11::array::ensure( + handle, pybind11::array::c_style | + pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_); + if (!array) return false; + pybind11::buffer_info buffer_info = array.request(); + + absl::InlinedVector dims(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims[i] = array.shape(i); + } + auto type = xla::DtypeToPrimitiveType(array.dtype()); + if (!type.ok()) { + throw std::runtime_error(type.status().ToString()); + } + xla::Shape shape = xla::ShapeUtil::MakeShape(type.ValueOrDie(), dims); + if (buffer_info.size * buffer_info.itemsize != + xla::ShapeUtil::ByteSizeOf(shape)) { + throw std::runtime_error(absl::StrCat( + "Size mismatch for buffer: ", buffer_info.size * buffer_info.itemsize, + " vs. ", xla::ShapeUtil::ByteSizeOf(shape))); + } + value = + xla::BorrowingLiteral(static_cast(buffer_info.ptr), shape); + return true; + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::LiteralSlice, _("xla::LiteralSlice")); + + // Pybind appears to keep type_casters alive until the callee has run. 
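+  // The resulting LiteralSlice aliases storage owned by the nested
+  // BorrowingLiteral caster (which in turn borrows the NumPy array), so
+  // holding that caster as a member keeps the storage alive for the call.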
+ type_caster literal_caster; + + bool load(handle handle, bool convert) { + if (!literal_caster.load(handle, convert)) { + return false; + } + value = static_cast(literal_caster); + return true; + } +}; + +// XLA protocol buffers +// We don't actually care that these are the protocol buffers, we merely want +// objects that duck type as protocol buffers. The client code currently avoids +// depending on Python protocol buffers to avoid conflicting definitions from +// different modules that both include XLA. + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::ConvolutionDimensionNumbers, + _("xla::ConvolutionDimensionNumbers")); + + // PyObject -> C++ conversion. + bool load(handle handle, bool) { + value.set_input_batch_dimension( + getattr(handle, "input_batch_dimension").cast()); + value.set_input_feature_dimension( + getattr(handle, "input_feature_dimension").cast()); + value.set_output_batch_dimension( + getattr(handle, "output_batch_dimension").cast()); + value.set_output_feature_dimension( + getattr(handle, "output_feature_dimension").cast()); + value.set_kernel_input_feature_dimension( + getattr(handle, "kernel_input_feature_dimension").cast()); + value.set_kernel_output_feature_dimension( + getattr(handle, "kernel_output_feature_dimension").cast()); + std::vector dims; + dims = getattr(handle, "input_spatial_dimensions") + .cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_input_spatial_dimensions())); + dims = getattr(handle, "kernel_spatial_dimensions") + .cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_kernel_spatial_dimensions())); + dims = getattr(handle, "output_spatial_dimensions") + .cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_output_spatial_dimensions())); + return true; + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::DotDimensionNumbers, _("xla::DotDimensionNumbers")); + + // PyObject -> C++ conversion. + bool load(handle handle, bool) { + std::vector dims; + dims = getattr(handle, "lhs_contracting_dimensions") + .cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_lhs_contracting_dimensions())); + dims = getattr(handle, "rhs_contracting_dimensions") + .cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_rhs_contracting_dimensions())); + dims = + getattr(handle, "lhs_batch_dimensions").cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_lhs_batch_dimensions())); + dims = + getattr(handle, "rhs_batch_dimensions").cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_rhs_batch_dimensions())); + return true; + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::GatherDimensionNumbers, + _("xla::GatherDimensionNumbers")); + + // PyObject -> C++ conversion. 
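+  // The repeated fields (offset_dims, collapsed_slice_dims, start_index_map)
+  // are copied element by element from the corresponding Python sequences;
+  // index_vector_dim is read as a scalar attribute.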
+ bool load(handle handle, bool) { + std::vector dims; + dims = getattr(handle, "offset_dims").cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_offset_dims())); + dims = + getattr(handle, "collapsed_slice_dims").cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_collapsed_slice_dims())); + dims = getattr(handle, "start_index_map").cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_start_index_map())); + value.set_index_vector_dim( + getattr(handle, "index_vector_dim").cast()); + return true; + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::ScatterDimensionNumbers, + _("xla::ScatterDimensionNumbers")); + + // PyObject -> C++ conversion. + bool load(handle handle, bool) { + std::vector dims; + dims = + getattr(handle, "update_window_dims").cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_update_window_dims())); + dims = + getattr(handle, "inserted_window_dims").cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_inserted_window_dims())); + dims = getattr(handle, "scatter_dims_to_operand_dims") + .cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_scatter_dims_to_operand_dims())); + value.set_index_vector_dim( + getattr(handle, "index_vector_dim").cast()); + return true; + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::ReplicaGroup, _("xla::ReplicaGroup")); + + // PyObject -> C++ conversion. + bool load(handle handle, bool) { + std::vector dims; + dims = getattr(handle, "replica_ids").cast>(); + std::copy(dims.begin(), dims.end(), + tensorflow::protobuf::RepeatedFieldBackInserter( + value.mutable_replica_ids())); + return true; + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::PaddingConfig, _("xla::PaddingConfig")); + + // PyObject -> C++ conversion. + bool load(handle handle, bool) { + sequence dimensions = + reinterpret_borrow(getattr(handle, "dimensions")); + + for (auto dimension : dimensions) { + xla::PaddingConfig::PaddingConfigDimension* config_dim = + value.add_dimensions(); + config_dim->set_edge_padding_low( + getattr(dimension, "edge_padding_low").cast()); + config_dim->set_edge_padding_high( + getattr(dimension, "edge_padding_high").cast()); + config_dim->set_interior_padding( + getattr(dimension, "interior_padding").cast()); + } + return true; + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(xla::OpMetadata, _("xla::OpMetadata")); + + // PyObject -> C++ conversion. 
+  bool load(handle handle, bool) {
+    pybind11::handle op_type = getattr(handle, "op_type");
+    if (!op_type.is_none()) {
+      value.set_op_type(op_type.cast<std::string>());
+    }
+    pybind11::handle op_name = getattr(handle, "op_name");
+    if (!op_name.is_none()) {
+      value.set_op_name(op_name.cast<std::string>());
+    }
+    pybind11::handle source_file = getattr(handle, "source_file");
+    if (!source_file.is_none()) {
+      value.set_source_file(source_file.cast<std::string>());
+    }
+    pybind11::handle source_line = getattr(handle, "source_line");
+    if (!source_line.is_none()) {
+      value.set_source_line(source_line.cast<xla::int32>());
+    }
+    return true;
+  }
+};
+}  // namespace detail
+}  // namespace pybind11
+
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
diff --git a/tensorflow/compiler/xla/python/worker_thread.cc b/tensorflow/compiler/xla/python/worker_thread.cc
new file mode 100644
index 00000000000..d3fb02023a5
--- /dev/null
+++ b/tensorflow/compiler/xla/python/worker_thread.cc
@@ -0,0 +1,54 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/worker_thread.h"
+
+namespace xla {
+
+WorkerThread::WorkerThread(tensorflow::Env* env, const std::string& name) {
+  thread_.reset(env->StartThread(tensorflow::ThreadOptions(), name,
+                                 [this]() { WorkLoop(); }));
+}
+
+WorkerThread::~WorkerThread() {
+  absl::MutexLock lock(&mu_);
+  work_queue_.push(nullptr);
+}
+
+void WorkerThread::Schedule(std::function<void()> fn) {
+  CHECK(fn != nullptr);
+  absl::MutexLock lock(&mu_);
+  work_queue_.push(std::move(fn));
+}
+
+bool WorkerThread::WorkAvailable() { return !work_queue_.empty(); }
+
+void WorkerThread::WorkLoop() {
+  while (true) {
+    std::function<void()> fn;
+    {
+      absl::MutexLock lock(&mu_);
+      mu_.Await(absl::Condition(this, &WorkerThread::WorkAvailable));
+      fn = std::move(work_queue_.front());
+      work_queue_.pop();
+    }
+    if (!fn) {
+      return;
+    }
+    fn();
+  }
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/python/worker_thread.h b/tensorflow/compiler/xla/python/worker_thread.h
new file mode 100644
index 00000000000..bc7dd396f88
--- /dev/null
+++ b/tensorflow/compiler/xla/python/worker_thread.h
@@ -0,0 +1,54 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
+
+#include <functional>
+#include <memory>
+#include <queue>
+#include <string>
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace xla {
+
+// A worker thread that runs a sequence of closures. Equivalent to a thread
+// pool of size 1.
+class WorkerThread {
+ public:
+  // 'name' is a name for the thread for debugging purposes.
+  WorkerThread(tensorflow::Env* env, const std::string& name);
+
+  // Blocks until all enqueued closures have completed.
+  ~WorkerThread();
+
+  // Adds 'fn' to the queue of closures to be executed by the worker thread.
+  void Schedule(std::function<void()> fn);
+
+ private:
+  bool WorkAvailable() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  void WorkLoop();
+
+  absl::Mutex mu_;
+  std::queue<std::function<void()>> work_queue_ GUARDED_BY(mu_);
+
+  std::unique_ptr<tensorflow::Thread> thread_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
new file mode 100644
index 00000000000..a592b0823be
--- /dev/null
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -0,0 +1,597 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
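To make the WorkerThread contract above concrete: Schedule() appends a closure to a FIFO, the single background thread drains it in order, and destruction enqueues a sentinel so the thread exits only after everything scheduled earlier has run. A rough Python analogue of the same design, illustrative only (the PyWorkerThread name and API below are hypothetical, not part of this change):

```python
# Hypothetical sketch of the WorkerThread design: a thread pool of size 1
# that runs closures in FIFO order and shuts down via a sentinel.
import queue
import threading


class PyWorkerThread(object):
  """Runs scheduled closures on a single background thread."""

  def __init__(self, name):
    self._queue = queue.Queue()
    self._thread = threading.Thread(target=self._work_loop, name=name)
    self._thread.start()

  def schedule(self, fn):
    assert fn is not None
    self._queue.put(fn)

  def close(self):
    # Like the C++ destructor: enqueue a sentinel, then wait for the thread.
    self._queue.put(None)
    self._thread.join()

  def _work_loop(self):
    while True:
      fn = self._queue.get()
      if fn is None:  # Sentinel: stop after all earlier closures have run.
        return
      fn()


# Usage: closures run in the order they were scheduled.
w = PyWorkerThread('example')
w.schedule(lambda: print('first'))
w.schedule(lambda: print('second'))
w.close()
```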
+==============================================================================*/ + +#include +#include + +#include "absl/hash/hash.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "include/pybind11/numpy.h" +#include "include/pybind11/pybind11.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/compiler/xla/client/lib/svd.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/python/xrt.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +namespace py = pybind11; + +namespace { + +struct Uniquer { + absl::Mutex mu; + NameUniquer name_uniquer GUARDED_BY(mu); +}; + +Uniquer* GetUniquer() { + static Uniquer* uniquer = new Uniquer; + return uniquer; +} + +static string UniquifyName(const string& name) { + Uniquer* uniquer = GetUniquer(); + absl::MutexLock lock(&uniquer->mu); + return uniquer->name_uniquer.GetUniqueName(name); +} + +// Converts a computation to a serialized HloModuleProto. +StatusOr GetComputationSerializedProto( + const XlaComputation& computation) { + std::string result; + if (!computation.proto().SerializeToString(&result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return py::bytes(result); +} + +// Converts a computation to textual HLO form. +StatusOr GetComputationHloText(const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation.proto(), module_config)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(false); + return hlo_module->ToString(options); +} + +// Converts a computation to HLO dot graph form. 
+StatusOr GetComputationHloDotGraph( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation.proto(), module_config)); + return RenderGraph(*hlo_module->entry_computation(), /*label=*/"", + hlo_module->config().debug_options(), + RenderedGraphFormat::kDot); +} + +} // namespace + +PYBIND11_MODULE(xla_extension, m) { + // Types + py::enum_(m, "PrimitiveType") + .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) + .value("PRED", PRED) + .value("S8", S8) + .value("S16", S16) + .value("S32", S32) + .value("S64", S64) + .value("U8", U8) + .value("U16", U16) + .value("U32", U32) + .value("U64", U64) + .value("F16", F16) + .value("BF16", BF16) + .value("F32", F32) + .value("F64", F64) + .value("C64", C64) + .value("C128", C128) + .value("TUPLE", TUPLE) + .value("OPAQUE_TYPE", OPAQUE_TYPE) + .value("TOKEN", TOKEN); + + // Shapes + py::class_ shape_class(m, "Shape"); + shape_class + .def_static( + "tuple_shape", + [](std::vector shapes) -> Shape { + return ShapeUtil::MakeTupleShape(shapes); + }, + "Constructs a tuple shape.") + .def_static( + "array_shape", + [](PrimitiveType type, py::object dims_seq, + absl::optional layout_seq) -> Shape { + std::vector dims = IntSequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = IntSequenceToVector(*layout_seq); + return ShapeUtil::MakeShapeWithLayout(type, dims, layout); + } else { + Shape shape = ShapeUtil::MakeShape(type, dims); + shape.clear_layout(); + return shape; + } + }, + "Constructs an array shape.", py::arg("type"), py::arg("dims"), + py::arg("layout") = absl::nullopt) + .def_static( + "array_shape", + [](py::dtype dtype, py::object dims_seq, + absl::optional layout_seq) -> Shape { + PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); + std::vector dims = IntSequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = IntSequenceToVector(*layout_seq); + return ShapeUtil::MakeShapeWithLayout(type, dims, layout); + } else { + Shape shape = ShapeUtil::MakeShape(type, dims); + shape.clear_layout(); + return shape; + } + }, + "Constructs an array shape.", py::arg("type"), py::arg("dims"), + py::arg("layout") = absl::nullopt) + .def("dimensions", + [](const Shape& shape) -> py::tuple { + return IntSpanToTuple(shape.dimensions()); + }) + .def("xla_element_type", &Shape::element_type) + .def("element_type", + [](const Shape& shape) { + return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type())); + }) + .def("numpy_dtype", + [](const Shape& shape) { + if (shape.IsTuple()) { + return py::dtype("O"); + } + return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type())); + }) + .def("is_tuple", &Shape::IsTuple) + .def("is_array", &Shape::IsArray) + .def("rank", &Shape::rank) + .def("to_serialized_proto", + [](const Shape& shape) { + ShapeProto proto = shape.ToProto(); + return py::bytes(proto.SerializeAsString()); + }) + .def("tuple_shapes", + [](const Shape& shape) { + return std::vector(shape.tuple_shapes()); + }) + .def( + "with_major_to_minor_layout_if_absent", + [](const Shape& shape) { + Shape out = shape; + ShapeUtil::ForEachMutableSubshape( + &out, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + return out; + }, + "Returns a copy of a shape with missing layouts set to " + "major-to-minor.") + 
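As an illustration of how the Shape bindings above are intended to be used from Python, here is a sketch that assumes the xla_extension module builds and is reachable through xla_client as wired up later in this change; the values in the comments are what the bindings suggest, not verified output:

```python
# Illustrative sketch only: assumes the extension is importable via xla_client.
import numpy as np
from tensorflow.compiler.xla.python import xla_client

# An array shape built from a dtype and dimensions; the layout is optional.
f32_2x3 = xla_client.Shape.array_shape(np.dtype('float32'), (2, 3))
print(f32_2x3.dimensions())    # (2, 3)
print(f32_2x3.element_type())  # float32
print(f32_2x3.is_array())      # True

# A tuple shape wrapping other shapes.
tup = xla_client.Shape.tuple_shape((f32_2x3, f32_2x3))
print(tup.is_tuple())          # True

# Missing layouts can be defaulted to major-to-minor before compilation.
laid_out = f32_2x3.with_major_to_minor_layout_if_absent()
```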
.def("__eq__", [](const Shape& shape, + const Shape& other) { return shape == other; }) + .def("__ne__", [](const Shape& shape, + const Shape& other) { return shape != other; }) + .def("__hash__", + [](const Shape& shape) { return absl::Hash()(shape); }) + .def("__repr__", [](const Shape& shape) { + return shape.ToString(/*print_layouts=*/true); + }); + + py::class_(m, "ProgramShape") + .def(py::init( + [](absl::Span params, Shape result) -> ProgramShape { + ProgramShape program_shape; + for (const Shape& param : params) { + *program_shape.add_parameters() = param; + } + *program_shape.mutable_result() = result; + return program_shape; + })) + .def("parameter_shapes", + static_cast& (ProgramShape::*)() const>( + &ProgramShape::parameters)) + .def("result_shape", &ProgramShape::result) + .def("__repr__", &ProgramShape::ToString); + + // Literals + py::class_(m, "Literal").def("__repr__", &Literal::ToString); + py::class_(m, "LiteralSlice"); + py::implicitly_convertible(); + py::implicitly_convertible(); + + // Device assignments + py::class_(m, "DeviceAssignment") + .def_static("Create", + [](py::array_t array) -> StatusOr { + if (array.ndim() != 2) { + return InvalidArgument( + "Argument to DeviceAssignment constructor must be a " + "2D array, " + "received an %dD array.", + array.ndim()); + } + DeviceAssignment result(array.shape(0), array.shape(1)); + for (int i = 0; i < array.shape(0); ++i) { + for (int j = 0; j < array.shape(1); ++j) { + result(i, j) = array.at(i, j); + } + } + return result; + }) + .def("replica_count", &DeviceAssignment::replica_count) + .def("computation_count", &DeviceAssignment::computation_count) + .def("__repr__", &DeviceAssignment::ToString); + + // Local XLA client methods. + + // CPU custom-call targets. + m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget); + + py::class_ alloc_config(m, "AllocatorConfig"); + alloc_config.def(py::init<>()) + .def_readwrite("kind", &AllocatorConfig::kind) + .def_readwrite("memory_fraction", &AllocatorConfig::memory_fraction); + py::enum_(alloc_config, "Kind") + .value("DEFAULT", AllocatorConfig::Kind::kDefault) + .value("PLATFORM", AllocatorConfig::Kind::kPlatform) + .value("BFC", AllocatorConfig::Kind::kBFC); + + py::class_>(m, "LocalClient") + .def_static("Get", &PyLocalClient::Get, py::arg("platform"), + py::arg("xla_platform_id"), py::arg("asynchronous"), + py::arg("allocator_config") = AllocatorConfig()) + .def("DeviceCount", &PyLocalClient::device_count) + .def("TransferToInfeed", &PyLocalClient::TransferToInfeed) + .def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed); + + py::class_(m, "PyLocalBuffer") + .def_static("from_python", &PyLocalBuffer::FromPython) + .def_static("from_python_values", &PyLocalBuffer::FromPythonValues) + .def_static("make_tuple", &PyLocalBuffer::MakeTuple) + .def("delete", &PyLocalBuffer::Delete) + .def("destructure", &PyLocalBuffer::DestructureTuple) + .def("to_py", &PyLocalBuffer::ToPython) + .def("shape", &PyLocalBuffer::on_host_shape) + .def("device", &PyLocalBuffer::device_ordinal) + .def("is_deleted", [](const PyLocalBuffer& buffer) { + return buffer.device_buffer() == nullptr; + }); + + py::class_(m, "LocalExecutable") + .def_static("Compile", &PyLocalExecutable::Compile, + py::call_guard()) + .def("DeviceOrdinals", &PyLocalExecutable::DeviceOrdinals) + .def("Delete", &PyLocalExecutable::Delete) + .def("Execute", &PyLocalExecutable::Execute, + py::call_guard(), py::arg("arguments")) + .def("ExecutePerReplica", &PyLocalExecutable::ExecutePerReplica, + 
py::call_guard(), py::arg("arguments")); + + py::class_(m, "DebugOptions") + .def_property("xla_cpu_enable_fast_math", + &DebugOptions::xla_cpu_enable_fast_math, + &DebugOptions::set_xla_cpu_enable_fast_math) + .def_property("xla_cpu_fast_math_honor_infs", + &DebugOptions::xla_cpu_fast_math_honor_infs, + &DebugOptions::set_xla_cpu_fast_math_honor_infs) + .def_property("xla_cpu_fast_math_honor_nans", + &DebugOptions::xla_cpu_fast_math_honor_nans, + &DebugOptions::set_xla_cpu_fast_math_honor_nans); + + py::class_(m, "ExecutableBuildOptions") + .def(py::init<>()) + .def_property( + "result_layout", + [](const ExecutableBuildOptions& options) -> absl::optional { + return options.result_layout() + ? absl::optional(*options.result_layout()) + : absl::nullopt; + }, + &ExecutableBuildOptions::set_result_layout) + .def_property("num_replicas", &ExecutableBuildOptions::num_replicas, + &ExecutableBuildOptions::set_num_replicas) + .def_property_readonly( + "debug_options", &ExecutableBuildOptions::mutable_debug_options, + py::return_value_policy::reference, py::keep_alive<1, 0>()); + + py::class_(m, "XlaComputation") + .def("GetProgramShape", &XlaComputation::GetProgramShape) + .def("GetSerializedProto", &GetComputationSerializedProto) + .def("GetHloText", &GetComputationHloText) + .def("GetHloDotGraph", &GetComputationHloDotGraph); + + py::class_(m, "XlaOp"); + + py::class_(m, "XlaBuilder") + .def(py::init([](const std::string& name) -> std::unique_ptr { + return absl::make_unique(UniquifyName(name)); + })) + .def( + "Build", + [](XlaBuilder& builder, absl::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }, + "Builds a computation from the contents of the builder.", + py::arg("root") = absl::nullopt) + .def("ClearOpMetadata", &XlaBuilder::ClearOpMetadata) + .def("GetShape", &XlaBuilder::GetShape) + .def( + "GetProgramShape", + [](const XlaBuilder& builder, + absl::optional root) -> StatusOr { + return root ? builder.GetProgramShape(*root) + : builder.GetProgramShape(); + }, + py::arg("root") = absl::nullopt) + .def("IsConstant", &XlaBuilder::IsConstant) + .def("SetOpMetadata", &XlaBuilder::SetOpMetadata); + + // ops submodule, containing free functions that add operators to an + // XlaBuilder. 
+ py::module ops = m.def_submodule("ops", "XLA operations"); + + ops.def("AllReduce", + static_cast, + const absl::optional&)>(&CrossReplicaSum)); + ops.def("AllToAll", &AllToAll); + ops.def("CollectivePermute", &CollectivePermute); + ops.def("CrossReplicaSum", + static_cast)>( + &CrossReplicaSum)); + ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), + py::arg("new_element_type")); + ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); + ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), + py::arg("shape"), py::arg("broadcast_dimensions")); + ops.def("Call", &Call); + ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); + ops.def("Clamp", &Clamp); + ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); + ops.def("ConcatInDim", &ConcatInDim); + ops.def("Conditional", + static_cast, + absl::Span)>(&Conditional)); + ops.def("Conditional", + static_cast(&Conditional)); + ops.def("ConstantLiteral", &ConstantLiteral); + ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), + py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), + py::arg("lhs_dilation"), py::arg("rhs_dilation"), + py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, + py::arg("batch_group_count") = 1, + py::arg("precision_config") = nullptr); + ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), + py::arg("new_element_type")); + ops.def("CustomCall", &CustomCallWithLayout); + ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), + py::arg("precision_config") = nullptr); + ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), + py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); + ops.def("DynamicSlice", + static_cast, + absl::Span)>(&DynamicSlice)); + ops.def("DynamicUpdateSlice", + static_cast)>( + &DynamicUpdateSlice)); + + ops.def("Fft", &Fft); + py::enum_(m, "FftType") + .value("FFT", FftType::FFT) + .value("IFFT", FftType::IFFT) + .value("RFFT", FftType::RFFT) + .value("IRFFT", FftType::IRFFT); + + ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), + py::arg("dimension_numbers"), py::arg("slice_sizes")); + ops.def("GetTupleElement", &GetTupleElement); + ops.def("Infeed", &Infeed, py::arg("builder"), py::arg("shape"), + py::arg("config") = ""); + ops.def("Iota", + static_cast(&Iota)); + ops.def("Iota", + static_cast(&Iota)); + ops.def("Map", &Map); + ops.def("Outfeed", &Outfeed, py::arg("operand"), py::arg("shape_with_layout"), + py::arg("outfeed_config") = ""); + ops.def("Pad", &Pad); + ops.def( + "Parameter", + static_cast( + &Parameter)); + ops.def("QR", + [](XlaOp a, bool full_matrices) -> StatusOr> { + TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); + return std::make_pair(qr.q, qr.r); + }); + ops.def( + "Eigh", + [](XlaOp a, bool lower, int64 max_iter, + float epsilon) -> std::pair { + auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon); + return std::make_pair(eigh.v, eigh.w); + }, + py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100, + py::arg("epsilon") = 1e-6); + ops.def( + "SVD", + [](XlaOp a, int64 max_iter, + float epsilon) -> std::tuple { + auto svd = SVD(a, max_iter, epsilon); + return std::make_tuple(svd.u, svd.d, svd.v); + }, + py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); + ops.def("Reduce", + static_cast, + absl::Span, const XlaComputation&, + absl::Span)>(&Reduce)); + ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding); + 
ops.def("ReplicaId", &ReplicaId); + ops.def("Reshape", static_cast, + absl::Span)>(&Reshape)); + ops.def("Reshape", + static_cast)>(&Reshape)); + ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); + ops.def("RngNormal", &RngNormal); + ops.def("RngUniform", &RngUniform); + ops.def("Scatter", &Scatter); + ops.def("Select", &Select); + ops.def("SelectAndScatterWithGeneralPadding", + &SelectAndScatterWithGeneralPadding); + ops.def("Slice", &Slice); + ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), + py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); + ops.def( + "Sort", + [](XlaBuilder* builder, absl::Span operands, + int64 dimension) -> XlaOp { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + std::vector operand_types; + for (const auto& operand : operands) { + TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); + operand_types.push_back(operand_shape.element_type()); + } + return Sort(operands, + CreateScalarLtComputation(operand_types, builder), + dimension); + }); + }, + py::arg("builder"), py::arg("operands"), py::arg("dimension") = -1); + ops.def("Transpose", &Transpose); + ops.def("TriangularSolve", &TriangularSolve); + ops.def("Tuple", &Tuple); + ops.def("While", &While); + +#define BINARY_OP(op) \ + ops.def( \ + #op, \ + [](XlaOp a, XlaOp b, absl::optional> dims) { \ + return dims ? op(a, b, *dims) : op(a, b); \ + }, \ + py::arg("lhs"), py::arg("rhs"), \ + py::arg("broadcast_dimensions") = absl::nullopt) + BINARY_OP(Eq); + BINARY_OP(Ne); + BINARY_OP(Ge); + BINARY_OP(Gt); + BINARY_OP(Lt); + BINARY_OP(Le); + BINARY_OP(Add); + BINARY_OP(Sub); + BINARY_OP(Mul); + BINARY_OP(Div); + BINARY_OP(Rem); + BINARY_OP(Max); + BINARY_OP(Min); + BINARY_OP(And); + BINARY_OP(Or); + BINARY_OP(Xor); + BINARY_OP(ShiftLeft); + BINARY_OP(ShiftRightArithmetic); + BINARY_OP(ShiftRightLogical); + BINARY_OP(Atan2); + BINARY_OP(Pow); + BINARY_OP(Complex); +#undef BINARY_OP + +#define UNARY_OP(op) ops.def(#op, &op) + UNARY_OP(Not); + UNARY_OP(Clz); + UNARY_OP(Abs); + UNARY_OP(Exp); + UNARY_OP(Expm1); + UNARY_OP(Floor); + UNARY_OP(Ceil); + UNARY_OP(Round); + UNARY_OP(Log); + UNARY_OP(Log1p); + UNARY_OP(Sign); + UNARY_OP(Cos); + UNARY_OP(Sin); + UNARY_OP(Tanh); + UNARY_OP(IsFinite); + UNARY_OP(Neg); + UNARY_OP(Sqrt); + UNARY_OP(Rsqrt); + UNARY_OP(Square); + UNARY_OP(Reciprocal); + UNARY_OP(Erfc); + UNARY_OP(Erf); + UNARY_OP(ErfInv); + UNARY_OP(Lgamma); + UNARY_OP(Digamma); + UNARY_OP(Acos); + UNARY_OP(Asin); + UNARY_OP(Atan); + UNARY_OP(Tan); + UNARY_OP(Acosh); + UNARY_OP(Asinh); + UNARY_OP(Atanh); + UNARY_OP(Cosh); + UNARY_OP(Sinh); + UNARY_OP(Real); + UNARY_OP(Imag); + UNARY_OP(Conj); +#undef UNARY_OP + + py::enum_( + m, "TriangularSolveOptions_Transpose") + .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) + .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) + .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) + .value("ADJOINT", TriangularSolveOptions::ADJOINT); + + // TODO(phawkins): improve bindings for these types. 
+ py::class_(m, "ChannelHandle"); + py::class_(m, "PrecisionConfig"); + + tensorflow::AddXrtSubmodule(&m); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 38458d7f090..e208cacc19c 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An XLA client in Python, supporting AOT compilation.""" +"""An XLA client in Python.""" from __future__ import absolute_import from __future__ import division @@ -28,22 +28,14 @@ import os import numpy as np import six -from six.moves import xrange # Note this module does *not* depend on any Python protocol buffers. The XLA # Python bindings are currently packaged both as part of jaxlib and as part # of TensorFlow. If we use protocol buffers here, then importing both jaxlib # and TensorFlow may fail with duplicate protocol buffer message definitions. -from tensorflow.compiler.xla.python import pywrap_xla as c_api - -# Import the XRT backend, if available. -try: - # pylint: disable=g-import-not-at-top - from tensorflow.compiler.xla.python import pywrap_xrt as xrt_api -except ImportError: - xrt_api = None - +from tensorflow.compiler.xla.python import xla_extension as _xla +from tensorflow.compiler.xla.python.xla_extension import ops # Most functions are snake_case for consistency with other modules, whereas # method names of ComputationBuilder and Computation are CamelCase for @@ -51,31 +43,18 @@ except ImportError: # pylint: disable=invalid-name -# Version of the XLA Python client. -# -# JAX packages the XLA python plugin as a binary pip module (jaxlib) that is -# packaged separately from the Python code that consumes it (jax). -# -# We occasionally need to make backwards-incompatible changes to jaxlib, in -# which case we need to be able to detect when incompatible versions are -# installed. -def version(): - return (0, 1, 8) - - -_OP_METADATA_FIELDS = [ - 'op_type', - 'op_name', - 'source_file', - 'source_line', -] -OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) - - @six.add_metaclass(abc.ABCMeta) class Backend(object): """Abstract base class for XLA backends.""" + def __init__(self, platform): + """Creates a new Backend. + + Args: + platform: A string naming the platform; for example 'gpu'. + """ + self.platform = platform + @abc.abstractmethod def device_count(self): """Returns the number of devices known to the backend.""" @@ -84,151 +63,155 @@ class Backend(object): def buffer_from_pyval(self, pyval, device=0): """Allocates a fresh buffer and populates it with `pyval`.""" - @abc.abstractmethod - def delete_buffer(self, c_buffer): - """Deletes buffer `c_buffer`.""" + def buffers_from_pyvals(self, pyvals_and_devices): + """Allocates buffers and populates them with `pyvals`.""" + return [ + self.buffer_from_pyval(pyval, device) + for pyval, device in pyvals_and_devices + ] @abc.abstractmethod - def destructure_tuple(self, c_buffer): - """Destructures a tuple buffer into a sequence of buffers.""" + def make_tuple(self, c_buffers, device_ordinal): + """Makes a tuple from a sequence of backend buffer objects.""" @abc.abstractmethod - def compile(self, computation, argument_shapes, result_shape, - compile_options): + def compile(self, computation, compile_options): """Compiles a computation. 
Returns an executable.""" - @abc.abstractmethod - def delete_executable(self, executable): - """Deletes an executable.""" - @abc.abstractmethod - def execute(self, executable, args): - """Runs an executable without replication.""" - - @abc.abstractmethod - def execute_replicated(self, executable, per_replica_args): - """Runs an executable in a replicated manner.""" - - -def _maybe_encode_string(s): - if six.PY3: - return s.encode('utf-8') - else: - return s - - -class XlaLocalBackend(Backend): +class LocalBackend(Backend): """XLA backend implemented using the in-process xla::LocalClient API.""" - def __init__(self, platform=None): - platform = platform or _get_default_platform_name() - self.client = c_api.LocalClient.Get(_maybe_encode_string(platform)) - self._delete_buffer = c_api.DeleteLocalShapedBuffer - self._delete_executable = c_api.DeleteLocalExecutable + def __init__(self, platform, client): + """Creates a new LocalBackend. + + Args: + platform: A string; the user-visible platform name, e.g. 'gpu'. + client: An _xla.PyLocalClient object. + """ + super(LocalBackend, self).__init__(platform) + self.client = client def device_count(self): return self.client.DeviceCount() def buffer_from_pyval(self, pyval, device=0): - return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device) + return _xla.PyLocalBuffer.from_python(pyval, self.client, device) - def delete_buffer(self, c_buffer): - self._delete_buffer(c_buffer) + def buffers_from_pyvals(self, pyvals_and_devices): + return _xla.PyLocalBuffer.from_python_values(pyvals_and_devices, + self.client) - def destructure_tuple(self, c_buffer): - result = c_buffer.DestructureTuple() - return [result.Release(i) for i in xrange(result.size())] + def make_tuple(self, c_buffers, device_ordinal): + return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device_ordinal) - def compile(self, c_computation, argument_shapes, result_shape, - compile_options): - return c_computation.Compile(argument_shapes, compile_options, self.client) - - def delete_executable(self, executable): - self._delete_executable(executable) - - def execute(self, executable, args): - return executable.Execute(args) - - def execute_replicated(self, executable, per_replica_args): - output_buffer_tup = executable.ExecutePerReplica(per_replica_args) - size = output_buffer_tup.size() - return [output_buffer_tup.Release(i) for i in xrange(size)] + def compile(self, c_computation, compile_options): + options = _xla.ExecutableBuildOptions() + options.num_replicas = compile_options.num_replicas + if compile_options.argument_layouts: + argument_layouts = compile_options.argument_layouts + else: + argument_layouts = c_computation.GetProgramShape().parameter_shapes() + if compile_options.result_layout: + options.result_layout = compile_options.result_layout + options.debug_options.xla_cpu_fast_math_honor_infs = True + options.debug_options.xla_cpu_fast_math_honor_nans = True + return _xla.LocalExecutable.Compile(c_computation, argument_layouts, + options, self.client) -class XrtBackend(Backend): - """XLA backend implemented using XRT.""" - - def __init__(self, target): - self.target = target - self._delete_buffer = xrt_api.DeleteXrtAllocation - self._delete_executable = xrt_api.DeleteXrtExecutable - - def device_count(self): - return 1 # Multidevice execution not implemented. 
- - def buffer_from_pyval(self, pyval, device=0): - if device != 0: - raise NotImplementedError( - 'Multi-replica execution is not yet supported via the XRT backend.') - return xrt_api.XrtAllocation.FromLiteral(pyval, - _maybe_encode_string(self.target)) - - def delete_buffer(self, c_buffer): - self._delete_buffer(c_buffer) - - def destructure_tuple(self, c_buffer): - result = xrt_api.DestructureXrtAllocationTuple( - c_buffer, _maybe_encode_string(self.target)) - return [result.Release(i) for i in xrange(result.size())] - - def compile(self, c_computation, argument_shapes, result_shape, - compile_options): - return xrt_api.XrtExecutable.CompileForXrt( - c_computation.GetSerializedProto(), argument_shapes, result_shape, - _maybe_encode_string(self.target)) - - def delete_executable(self, executable): - self._delete_executable(executable) - - def execute(self, executable, args): - return executable.Execute(args) - - def execute_replicated(self, executable, per_replica_args): - if len(per_replica_args) != 1: - raise NotImplementedError( - 'Multi-replica execution is not yet supported via the XRT backend.') - return [executable.Execute(per_replica_args[0])] +def _cpu_backend_factory(): + client = _xla.LocalClient.Get( + platform='cpu', xla_platform_id='Host', asynchronous=True) + return LocalBackend(platform='cpu', client=client) -_default_platform_name = 'Host' -_default_backend = None +def _gpu_backend_factory(): + """Returns a GPU backend. BFC allocator is used by default.""" + allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() + memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION') + if allocator not in ('default', 'platform', 'bfc'): + raise ValueError( + 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", or ' + '"bfc", got "%s"' % allocator) + config = _xla.AllocatorConfig() + if allocator == 'default': + config.kind = _xla.AllocatorConfig.Kind.DEFAULT + if allocator == 'platform': + config.kind = _xla.AllocatorConfig.Kind.PLATFORM + if allocator == 'bfc': + config.kind = _xla.AllocatorConfig.Kind.BFC + if memory_fraction: + config.memory_fraction = float(memory_fraction) + + client = _xla.LocalClient.Get( + platform='gpu', xla_platform_id='CUDA', asynchronous=False, + allocator_config=config) + return LocalBackend(platform='gpu', client=client) -def _get_default_platform_name(): - return _default_platform_name +# Backend factories, keyed by user-visible name, in increasing priority order. +_local_backend_factories = collections.OrderedDict([ + ('cpu', _cpu_backend_factory), + ('gpu', _gpu_backend_factory), +]) -def _get_default_local_backend(): - global _default_backend - global _default_platform_name - if _default_backend is None: - _default_backend = XlaLocalBackend(_default_platform_name) - return _default_backend +def register_local_backend_factory(name, factory): + _local_backend_factories[name] = factory -class BackendType(enum.Enum): - XLA_LOCAL = 1 - XRT = 2 +_local_backends = None -def BackendSpec(backend, target): - """Compatibility wrapper to support older clients. 
Do not use in new code.""" - if backend == BackendType.XLA_LOCAL: - return _get_default_local_backend() - elif backend == BackendType.XRT: - return XrtBackend(target) - else: - raise ValueError('Unknown backend {}'.format(backend)) +def _get_local_backends(): + """Instantiates all known local backends.""" + global _local_backends + if _local_backends is not None: + return _local_backends + + _local_backends = collections.OrderedDict() + for name, factory in _local_backend_factories.items(): + try: + backend = factory() + except RuntimeError: + # If the backend isn't built into the binary, or if it has no devices, we + # expect a RuntimeError. + continue + _local_backends[name] = backend + return _local_backends + + +def get_local_backend(name=None): + """Returns a local backend. + + Args: + name: the backend name. If `None`, a default local backend is returned, + typically `gpu` if one is present, or `cpu` if not. If a string, the named + backend is returned or an exception raised. + + Returns: + A LocalBackend object. + """ + backends = _get_local_backends() + if name is not None: + try: + return backends[name] + except KeyError: + raise RuntimeError('Unknown backend {}'.format(name)) + + return list(backends.values())[-1] + + +class OpMetadata(object): + """Python representation of a xla.OpMetadata protobuf.""" + __slots__ = ('op_type', 'op_name', 'source_file', 'source_line') + + def __init__(self, op_type='', op_name='', source_file='', source_line=0): + self.op_type = op_type + self.op_name = op_name + self.source_file = source_file + self.source_line = source_line def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): @@ -242,6 +225,350 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) +PrimitiveType = _xla.PrimitiveType + +XLA_ELEMENT_TYPE_TO_DTYPE = { + PrimitiveType.PRED: np.dtype('bool'), + PrimitiveType.S8: np.dtype('int8'), + PrimitiveType.S16: np.dtype('int16'), + PrimitiveType.S32: np.dtype('int32'), + PrimitiveType.S64: np.dtype('int64'), + PrimitiveType.U8: np.dtype('uint8'), + PrimitiveType.U16: np.dtype('uint16'), + PrimitiveType.U32: np.dtype('uint32'), + PrimitiveType.U64: np.dtype('uint64'), + PrimitiveType.F16: np.dtype('float16'), + PrimitiveType.F32: np.dtype('float32'), + PrimitiveType.F64: np.dtype('float64'), + PrimitiveType.C64: np.dtype('complex64'), + PrimitiveType.C128: np.dtype('complex128'), + PrimitiveType.TUPLE: np.dtype(np.object), +} + +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +DTYPE_TO_XLA_ELEMENT_TYPE = { + str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() +} + + +def dtype_to_etype(dtype): + """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" + return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + + +Shape = _xla.Shape +Shape.__doc__ = """ +A Shape is an object defined in C++ that duck types like the following class: + +class Shape(object): + '''Represents an XLA shape. + + A shape is either an array shape, having rank-many integer + dimensions and an element type (represented by a Numpy dtype), or it + is a tuple shape, having a shape for every tuple component: + + type shape = + TupleShape of shape list + | ArrayShape of { dimensions: int list; element_type: dtype } + ''' + + @staticmethod + def tuple_shape(tuple_shapes) -> Shape: + "Construct a tuple shape." 
+ + @staticmethod + def array_shape(element_type, dimensions, minor_to_major=None) -> Shape: + + @staticmethod + def from_pyval(pyval) -> Shape: + "Returns a Shape that describes a tuple-tree of Numpy arrays." + + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): + def is_tuple(self) -> bool: + def is_array(self) -> bool: + def tuple_shapes(self) -> [Shape]: + def numpy_dtype(self) -> np.dtype: + "Like element_type(), but returns dtype('O') for a tuple shape." + def xla_element_type(self) -> PrimitiveType: + def element_type(self) -> np.dtype: + def dimensions(self) -> (int, int, ...): + def rank(self) -> int: + def minor_to_major(self) -> [int]: + def with_major_to_minor_layout_if_absent(self) -> Shape: + "Returns a copy with missing layouts set to major-to-minor." + + def to_serialized_proto(self) -> bytes: + "Returns 'shape' as a serialized proto." +""" + +ProgramShape = _xla.ProgramShape +ProgramShape.__doc__ = """ +A ProgramShape is a C++ object that duck types like the following class. + +class ProgramShape(object): + def __init__(self, parameter_shapes, result_shape): + def parameter_shapes(self) -> [Shape]: + def result_shape(self) -> Shape: + def __repr__(self): +""" + + +class Buffer(object): + """Represents a handle to data owned by XLA. + + The referent is ready for use in executing a local, compiled + Computation. On XLA platforms involving a device (e.g. GPU), this + means the referent is in device memory. + """ + + @staticmethod + def from_pyval(pyval, device=0, backend=None): + """Copies the `pyval` to a freshly allocated on-device buffer.""" + backend = backend or get_local_backend() + return backend.buffer_from_pyval(pyval, device) + + @staticmethod + def from_pyvals(pyvals_and_devices, backend=None): + """Copies multiple Python values to freshly allocated on-device buffers. + + Arguments: + pyvals_and_devices: a list of `(pyval, device)` pairs, where `pyval` is a + Python value to copy (e.g., a NumPy array), and `device` is an integer + device ordinal. + backend: a Backend object, or `None` to use the default local backend. + + Returns: + A list of `Buffer` objects corresponding to `pyvals_and_devices`. + """ + backend = backend or get_local_backend() + return backend.buffers_from_pyvals(pyvals_and_devices) + + @staticmethod + def make_tuple(buffers, backend=None, device=0): + backend = backend or get_local_backend() + return backend.make_tuple(buffers, device_ordinal=device) + + # Buffer is not an instantiable type and exists only for its static methods. + # The underlying buffer objects are C++ object with the following + # API: + # def to_py(self): + # def shape(self) -> Shape: + # def device(self) -> int: + # def delete(self): + # def destructure(self) -> [Buffer] + # def is_deleted(self) -> bool: + # + # TODO(phawkins): remove Buffer and its static methods completely, have + # clients call methods on Backend to create buffers. + + +# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops +# compatibility with Jaxlib versions older than 0.1.13. 
+LocalBuffer = Buffer + + +def shape_from_pyval(pyval): + """Returns a Shape that describes a tuple-tree of Numpy arrays.""" + + def convert(pyval): + if isinstance(pyval, tuple): + return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) + else: + return Shape.array_shape(pyval.dtype, np.shape(pyval)) + + return convert(pyval) + + +def transfer_to_infeed(value, device_ordinal=0): + """Transfers the given value into the XLA infeed queue. + + XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with + a totally ordered stream of values. This is dequeued from XLA computations via + the Infeed() operation. + + Args: + value: the value that the caller would like to enqueue into the XLA infeed + queue + device_ordinal: the device to infeed the value to. Each device has a + distinct infeed queue. + """ + # TODO(phawkins): support non-default backends. + backend = get_local_backend() + backend.client.TransferToInfeed(value, device_ordinal) + + +def transfer_from_outfeed(shape, device_ordinal=0): + """Transfers a literal of the given shape from `device_ordinal`'s outfeed. + + Args: + shape: The shape of the value to transfer from outfeed. + device_ordinal: The device ordinal to transfer the outfeed value from. Each + device has a distinct outfeed queue.. + + Returns: + The literal value that is produced from the outfeed queue. + """ + # TODO(phawkins): support non-default backends. + backend = get_local_backend() + return backend.client.TransferFromOutfeed( + shape.with_major_to_minor_layout_if_absent(), device_ordinal) + + +class CompileOptions(object): + """Python object for XLA compile options. + + These options can be passed to the 'compile' step when using a local XLA + client. + """ + + def __init__(self): + self.xla_dump_to = None + self.dump_hlo_pass_re = None + self.dump_hlo_module_re = None + self.dump_hlo_as_text = None + self.dump_hlo_as_proto = None + self.hlo_profile = None + self.num_replicas = 1 + self.argument_layouts = None + self.result_layout = None + + +class Computation(object): + """Python wrapper for an XLA Computation. + + A Computation can be compiled to form an Executable, or used as a + subcomputation in ComputationBuilder methods. + """ + + def __init__(self, c_computation, backend=None): + self._c_computation = c_computation + # The backend argument is deprecated. Pass a backend to Compile() instead. + self._backend = backend + + @property + def computation(self): + return self._c_computation + + def GetSerializedProto(self): + """Gets the serialized HloModuleProto proto object in this computation. + + Returns: + A string containing a serialized HloModuleProto proto containing the + computation and its dependencies. + """ + return self.computation.GetSerializedProto() + + def GetHloText(self): + """Get the textual HLO representation of this computation. + + Returns: + A string containing the textual HLO. + """ + return self.computation.GetHloText() + + def GetHloDotGraph(self): + """Get a Graphviz Dot representation of this computation. + + Returns: + A string containing the graphviz dot graph. + """ + return self.computation.GetHloDotGraph() + + def Compile(self, argument_shapes=None, compile_options=None, backend=None): + """Compiles a computation. + + Computations are the result of a "ComputationBuild'ing" process. + + Arguments: + argument_shapes: Deprecated. Use compile_options.argument_layouts instead. + compile_options: options to use for compilation, includes an optional laid + out result shape for the computation. 
+ backend: a `Backend` for which an executable should be generated. + + Returns: + A Executable instance. + """ + backend = backend or self._backend or get_local_backend() + + compile_options = compile_options or CompileOptions() + if argument_shapes: + compile_options.argument_layouts = argument_shapes + return backend.compile(self.computation, compile_options) + + def GetProgramShape(self): + return self._c_computation.GetProgramShape() + + def GetReturnValueShape(self): + return self._c_computation.GetProgramShape().result_shape() + + +# An Executable is a C++ class that duck types with the following API: +# class Executable(object): +# def DeviceOrdinals(self) -> [int]: +# def Execute(self, arguments : [Buffer]) -> Buffer: +# """Execute on one replica with Buffer arguments and return value.""" +# +# def ExecutePerReplica(self, arguments: [[Buffer]]) -> [Buffer]: +# """Execute on many replicas with Buffer arguments and return value. +# +# Args: +# arguments: A sequence of sequences of Buffers. The i'th inner sequence +# comprises the arguments for execution on the i'th replica. +# +# Returns: +# A list of the computation's outputs for each replica, as a Buffer. If +# a shallow sequence of arguments was passed in for `arguments`, then the +# sole, zero'th replica's output is returned instead, as a Buffer. +# """ +# +# There are different implementations of Executable for the Local and XRT +# backends. + + +def execute_with_python_values(executable, arguments=(), backend=None): + """Execute on one replica with Python values as arguments and output.""" + + backend = backend or get_local_backend() + + def put(arg): + return Buffer.from_pyval( + arg, device=executable.DeviceOrdinals()[0], backend=backend) + + arguments = [put(arg) for arg in arguments] + return executable.Execute(arguments).to_py() + + +def execute_with_python_values_replicated(executable, arguments, backend=None): + """Execute on many replicas with Python values as arguments and output. + + Arguments: + executable: the program to run. + arguments: a list of lists of Python values indexed by + `[replica][arg_num]` to pass as inputs. + backend: the backend we are targeting. + + Returns: + A list of python values, one per replica. 
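A sketch of the `[replica][arg_num]` calling convention described above, assuming `computation` is a Computation with a single f32[3] parameter and that only one replica (one local device) is available:

```python
# Illustrative sketch: `arguments` is nested as [replica][arg_num].
import numpy as np
from tensorflow.compiler.xla.python import xla_client

options = xla_client.CompileOptions()
options.num_replicas = 1  # one replica per available device
executable = computation.Compile(compile_options=options)

per_replica_args = [
    [np.array([1., 2., 3.], dtype=np.float32)],  # arguments for replica 0
]
outputs = xla_client.execute_with_python_values_replicated(
    executable, per_replica_args)
print(outputs[0])  # output of replica 0
```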
+ """ + backend = backend or get_local_backend() + device_ordinals = executable.DeviceOrdinals() + # pylint: disable=g-complex-comprehension + flat_args = [(arg, device_ordinals[replica]) + for replica, replica_args in enumerate(arguments) + for arg in replica_args] + flat_arg_buffers = Buffer.from_pyvals(flat_args, backend=backend) + arg_buffers = [] + for replica_args in arguments: + arg_buffers.append(flat_arg_buffers[:len(replica_args)]) + flat_arg_buffers = flat_arg_buffers[len(replica_args):] + return [out.to_py() for out in executable.ExecutePerReplica(arg_buffers)] + + class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -267,15 +594,837 @@ def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, return [(0, 0)] * len(window_strides) elif padding_type == PaddingType.SAME: out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) - pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) - for out_size, stride, filter_size, in_size - in zip(out_shape, window_strides, rhs_dims, lhs_dims)] + pad_sizes = [ + max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size in zip( + out_shape, window_strides, rhs_dims, lhs_dims) + ] return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] else: msg = 'Unexpected PaddingType value: {}' raise ValueError(msg.format(padding_type)) +class ComputationBuilder(object): + """XLA computation builder. + + Enqueues XLA ops in sequence and in order to build a + Computation, which in turn can be compiled into a + LocalExecutable, which in turn can be locally executed. + """ + + # The methods of this class map 1-to-1 onto the XLA C++ + # computation builder API. Therefore, there's no need to laboriously list + # arguments and return values for every method, especially where it's obvious. + # + # pylint: disable=g-doc-return-or-yield + # pylint: disable=g-doc-args + + def __init__(self, name): + self._builder = _xla.XlaBuilder(name) + self._parameter_numbering = itertools.count() + + def Build(self, root=None, backend=None): + """Builds a `Computation` from the contents of the builder. + + Args: + root: if not None, the operator containing the return value of the + computation. + + Returns: + A `Computation`. + """ + if root is not None: + return Computation(self._builder.Build(root), backend=backend) + else: + return Computation(self._builder.Build(), backend=backend) + + def GetShape(self, operand): + return self._builder.GetShape(operand) + + def SetOpMetadata(self, op_metadata): + """Set metadata for operations that are about to be enqueued.""" + self._builder.SetOpMetadata(op_metadata) + + def ClearOpMetadata(self): + """Clear metadata for operations that are about to be enqueued.""" + self._builder.ClearOpMetadata() + + def Infeed(self, shape): + """Enqueues an infeed op onto the computation. + + Infeed operations dequeue data of the given shape from the device's infeed + queue for subsequent use in the computation. + + Returns: + An XlaOp. + """ + return ops.Infeed(self._builder, + shape.with_major_to_minor_layout_if_absent()) + + def Outfeed(self, operand): + """Enqueues an outfeed op onto the computation. + + Outfeed operations enqueue data, using the given operand, onto the XLA + outfeed queue for subsequent dequeue via the client API. + """ + return ops.Outfeed(operand, self._builder.GetShape(operand), '') + + def Constant(self, value): + """Enqueues a constant op onto the computation. 
+ + Args: + value: value for the constant, as a np.array with an explicit dtype set to + one of the supported types. + + Returns: + An XlaOp. + """ + return ops.ConstantLiteral(self._builder, value) + + def ConstantF32Scalar(self, value): + """Convenience method to enqueue a scalar F32 constant op. + + Args: + value: a floating-point number. + + Returns: + An XlaOp. + """ + return self.Constant(np.array(value, dtype=np.float32)) + + def ConstantF64Scalar(self, value): + """Convenience method to enqueue a scalar F32 constant op. + + Args: + value: a floating-point number. + + Returns: + An XlaOp. + """ + return self.Constant(np.array(value, dtype=np.float64)) + + def ConstantS32Scalar(self, value): + """Convenience method to enqueue a scalar S32 constant op. + + Args: + value: a floating-point number. + + Returns: + An XlaOp. + """ + return self.Constant(np.array(value, dtype=np.int32)) + + def ConstantS64Scalar(self, value): + """Convenience method to enqueue a scalar S64 constant op. + + Args: + value: a floating-point number. + + Returns: + An XlaOp. + """ + return self.Constant(np.array(value, dtype=np.int64)) + + def ConstantPredScalar(self, value): + """Convenience method to enqueue a scalar PRED constant op. + + Args: + value: a boolean value. + + Returns: + An XlaOp. + """ + return self.Constant(np.array(value, dtype=np.bool)) + + def ParameterWithShape(self, shape, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation, given a shape. + + Args: + shape: the parameter's shape as a Shape object. + name: optional string name for the parameter. + parameter_num: parameter number in the computation function. If None, the + next linear parameter number is used. The default value capability can + be used for auto-numbering. If you're using auto-numbering for some + parameters, use it for *all* parameters to avoid clashes. + + Returns: + An XlaOp. + """ + if name is None: + name = '' + if parameter_num is None: + parameter_num = next(self._parameter_numbering) + + return ops.Parameter(self._builder, parameter_num, + shape.with_major_to_minor_layout_if_absent(), + name.encode('utf8')) + + def ParameterFromNumpy(self, value, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation. + + Args: + value: a Numpy array, or a nested tuple thereof, from which the shape is + inferred. + name: as in ParameterWithShape. + parameter_num: as in ParameterWithShape. + + Returns: + An XlaOp. + """ + return self.ParameterWithShape( + shape_from_pyval(value), name=name, parameter_num=parameter_num) + + def Iota(self, dtype, size): + """Enqueues an iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + size: integer, the number of elements in the array. + + Returns: + An XlaOp representing the added iota constant. + """ + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + return ops.Iota(self._builder, element_type, size) + + def BroadcastedIota(self, dtype, shape, dimension): + """Enqueues a broadcasted iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + shape: tuple of integers, the expected output shape (dimensions). + dimension: positive integer, dimension along which to increment values. + + Returns: + An XlaOp representing the added broadcasted iota constant. 
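Putting the parameter helpers above together: parameters are auto-numbered in creation order, so positional arguments at execution time follow the same order. A sketch, assuming a local CPU backend, that computes a*x + y:

```python
# Illustrative sketch: parameters are numbered in the order they are created.
import numpy as np
from tensorflow.compiler.xla.python import xla_client
from tensorflow.compiler.xla.python.xla_extension import ops

builder = xla_client.ComputationBuilder('axpy')
a = builder.ParameterFromNumpy(np.float32(0.))           # parameter 0
x = builder.ParameterFromNumpy(np.zeros(3, np.float32))  # parameter 1
y = builder.ParameterFromNumpy(np.zeros(3, np.float32))  # parameter 2
out = ops.Add(ops.Mul(x, a), y)  # scalar `a` broadcasts implicitly

executable = builder.Build(out).Compile()
print(xla_client.execute_with_python_values(
    executable,
    (np.float32(2.), np.arange(3, dtype=np.float32),
     np.ones(3, dtype=np.float32))))
# Expected output: [1. 3. 5.]
```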
+ """ + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + xla_shape = _xla.Shape.array_shape(element_type, shape, None) + return ops.Iota(self._builder, xla_shape, dimension) + + def Concatenate(self, operands, dimension): + """Enqueues a concatenate operation onto the computation. + + Args: + operands: the operands to concatenate. + dimension: the dimension in which to perform the concatenation. + + Returns: + An XlaOp representing the added concatenate op. + """ + return ops.ConcatInDim(self._builder, list(operands), dimension) + + def ReplicaId(self): + """Enqueues a ReplicaId operation onto the computation. + + Returns: + A LocalOp representing the replica id. + """ + return _xla.ops.ReplicaId(self._builder) + + def Pad(self, operand, padding_value, padding_config): + """Enqueues a Pad operation onto the computation. + + Args: + operand: XlaOp representing the array to pad. + padding_value: XlaOp representing the scalar pad value. + padding_config: either a PaddingConfig or a list of integer triples + (edge_padding_low, edge_padding_high, interior_padding) representing the + configuration of the padding operation. + + Returns: + An XlaOp representing the added Pad op. + """ + if isinstance(padding_config, tuple) or isinstance(padding_config, list): + padding_config = GetPaddingConfigFromTriples(padding_config) + return ops.Pad(operand, padding_value, padding_config) + + def Reshape(self, operand, dimensions, new_sizes): + """Enqueues a reshape op onto the computation. + + Args: + operand: XlaOp representing the array to be reshaped. + dimensions: sequence of integers encoding the order in which dimensions + are collapsed or None, in which case dimensions are flattened in order. + new_sizes: sequence of integers encoding the new dimension sizes (shape). + + Returns: + An XlaOp representing the added Reshape op. + """ + if dimensions is None: + ndim = len(self.GetShape(operand).dimensions()) + dimensions = tuple(range(ndim)) + return ops.Reshape(operand, dimensions, new_sizes) + + def AllReduce(self, operand, computation, replica_groups=None): + """AllReduce op. + + Args: + operand: XlaOp representing the input array + computation: a Computation object - binary reduction function. + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the all-to-all is performed. If not supplied or None (the + default), all replicas belong to the same group. + + Returns: + An XlaOp that represents the all-reduced result. + """ + replica_groups_protos = _get_replica_groups_protos(replica_groups) + return ops.AllReduce(operand, computation.computation, + replica_groups_protos, None) + + def AllToAll(self, + operand, + split_dimension, + concat_dimension, + replica_groups=None): + """AllToAll op. + + Args: + operand: XlaOp representing the input array + split_dimension: the dimension along which the operand is split + concat_dimension: the dimension along which the split blocks are + concatenated + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the all-to-all is performed. If not supplied or None (the + default), all replicas belong to the same group. + + Returns: + An XlaOp that represents the all-to-all concatenation. 
+ """ + replica_groups_protos = _get_replica_groups_protos(replica_groups) + if not replica_groups: + split_count = 1 + else: + split_count = len(replica_groups[0]) + if not all(split_count == len(g) for g in replica_groups): + raise ValueError('Replica groups must be equally sized') + return ops.AllToAll(operand, split_dimension, concat_dimension, split_count, + replica_groups_protos) + + def CrossReplicaSum(self, operand, replica_groups=None): + """CrossReplicaSum op. + + Args: + operand: the operand to sum across replica instances. + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the cross-replica sum is performed. If not supplied or None + (the default), all replicas belong to the same group. + + Returns: + An XlaOp that represents on each replica the sum of its group's values. + """ + replica_groups_protos = _get_replica_groups_protos(replica_groups) + return ops.CrossReplicaSum(operand, replica_groups_protos) + + def Trans(self, operand): + """Specialized matrix transpose op.""" + return ops.Transpose(operand, [1, 0]) + + def Transpose(self, operand, permutation): + """Transpose op.""" + return ops.Transpose(operand, permutation) + + def SelectAndScatter(self, operand, select, window_dimensions, window_strides, + padding, source, init_value, scatter): + """Select and scatter op, used by the gradient of ReduceWindow. + + Args: + operand: XlaOp for array of dimension N and type T over which the windows + slide. + select: Computation of type (T, T) -> Pred to apply to the elements of + each window to indicate which element is selected. + window_dimensions: sequence of N integers for dimensions of the window. + window_strides: sequence of N integers for the strides of the window. + padding: PaddingType representing either 'SAME' or 'VALID ' padding. + source: XlaOp for array of type T with values to scatter. + init_value: XlaOp of scalar type T for initial out value. + scatter: Computation of type (T, T) -> T to apply to each scatter source + element with its destination element. + + Returns: + An XlaOp representing the added SelectAndScatter op. + """ + pads = _convert_padding_type_to_pad_values( + padding, + self.GetShape(operand).dimensions(), window_dimensions, window_strides) + return ops.SelectAndScatterWithGeneralPadding(operand, select.computation, + window_dimensions, + window_strides, pads, source, + init_value, + scatter.computation) + + def Slice(self, operand, start_indices, limit_indices, strides=None): + """Enqueues a slice operation onto the computation. + + Args: + operand: XlaOp for the N dimensional array to be sliced. + start_indices: iterable of N integers containing the starting indices of + the slice for each dimension. + limit_indices: iterable of N integers containing the ending indices + (exclusive) of the slice for each dimension. + strides: optional iterable of N integers containing the stride sizes for + each dimension. + + Returns: + An XlaOp representing the added Slice op. + """ + if strides is None: + start_indices = list(start_indices) + strides = [1] * len(start_indices) + return ops.Slice(operand, start_indices, limit_indices, strides) + + def DynamicSlice(self, operand, start_indices, slice_sizes): + """Enqueues a slice op with dynamic start indices onto the computation. + + Args: + operand: XlaOp for the N dimensional array to be sliced. + start_indices: XlaOp for the 1D array of N integers containing the + starting indices of the slice. 
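The collective ops above (AllReduce, AllToAll, CrossReplicaSum) share the same optional `replica_groups` partition argument. A hedged sketch of the two common cases, again assuming the `ComputationBuilder` entry point from elsewhere in this file:

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client

# Sketch only: builder name/constructor assumed as elsewhere in this file.
b = xla_client.ComputationBuilder('collective_example')
x = b.ParameterFromNumpy(np.zeros((8,), dtype=np.float32))

# Default: every replica belongs to one group, so this sums across all replicas.
total = b.CrossReplicaSum(x)

# Explicit groups: replicas {0, 1} and {2, 3} each reduce among themselves.
pairwise = b.CrossReplicaSum(x, replica_groups=[[0, 1], [2, 3]])
```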
+ slice_sizes: iterable of N integers containing the slice sizes in each + dimension. + + Returns: + An XlaOp representing the added DynamicSlice op. + """ + slice_sizes = list(slice_sizes) + if isinstance(start_indices, _xla.XlaOp): + start_indices = [ + ops.Reshape(ops.Slice(start_indices, [i], [i + 1], [1]), []) + for i in range(len(slice_sizes)) + ] + return ops.DynamicSlice(operand, list(start_indices), slice_sizes) + + def DynamicUpdateSlice(self, operand, update, start_indices): + """Enqueues a dynamic update slice operation onto the computation. + + Args: + operand: XlaOp for the N dimensional array to be updated. + update: N dimensional array comprising the slice update. + start_indices: Rank-1 array of N integers comprising the starting indices + of the slice along each dimension. + + Returns: + An XlaOp representing the added DynamicUpdateSlice op. + """ + if isinstance(start_indices, _xla.XlaOp): + ndims = self._builder.GetShape(start_indices).dimensions()[0] + start_indices = [ + ops.Reshape(ops.Slice(start_indices, [i], [i + 1], [1]), []) + for i in range(ndims) + ] + return ops.DynamicUpdateSlice(operand, update, list(start_indices)) + + def Tuple(self, *elems): + """Enqueues a tuple operation onto the computation. + + Args: + elems: a sequence of tuple operands (each a XlaOp). + + Returns: + An XlaOp representing the added Tuple op. + """ + return ops.Tuple(self._builder, list(elems)) + + def Call(self, computation_to_apply, operands): + """Enqueues a call operation onto the computation. + + Args: + computation_to_apply: a Computation object. + operands: an iterable of XlaOp. The number and types of operands must + match the arity of computation_to_apply. + + Returns: + An XlaOp representing the added call op. + """ + return ops.Call(self._builder, computation_to_apply.computation, + list(operands)) + + def CustomCall(self, + call_target_name, + operands, + shape_with_layout, + operand_shapes_with_layout, + opaque=None): + """Enqueues a custom call operation onto the computation. + + Args: + call_target_name: the name of the function to call. + operands: an iterable of XlaOp. The number and types of operands must + match the arity of `operand_shapes_with_layout`. + shape_with_layout: the shape of the operator's output, with layout. + operand_shapes_with_layout: the shapes of `operands`, including the + expected layouts. + opaque: an opaque string passed to the backend. + + Returns: + An XlaOp representing the added custom call op. + """ + opaque = opaque or b'' + return ops.CustomCall(self._builder, call_target_name, + list(operands), shape_with_layout, + list(operand_shapes_with_layout), opaque) + + def Map(self, operands, computation_to_apply, dimensions): + """Enqueues a map operation onto the computation. + + Args: + operands: an iterable of XlaOp. + computation_to_apply: a Computation object. + dimensions: dimensions over which to apply map the function. + + Returns: + An XlaOp representing the added Map op. + """ + return ops.Map(self._builder, list(operands), + computation_to_apply.computation, dimensions, []) + + def Reduce(self, operand, init_value, computation_to_apply, dimensions): + """Enqueues a reduction operation onto the computation. + + Args: + operand: reduction operand (XlaOp). + init_value: reduction initial value (XlaOp). + computation_to_apply: a Computation object - binary reduction function. + dimensions: sequence of dimensions (integers) to reduce on. + + Returns: + An XlaOp representing the added Reduce op. 
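The static and dynamic slicing ops above differ mainly in whether the start indices are Python integers or computed XlaOp values. A sketch (builder entry point assumed as before):

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client

# Sketch only: builder name/constructor assumed as elsewhere in this file.
b = xla_client.ComputationBuilder('slicing_example')
x = b.ParameterFromNumpy(np.zeros((4, 6), dtype=np.float32))

# Static slice: rows 1:3 and every other column -> shape (2, 3).
s = b.Slice(x, [1, 0], [3, 6], [1, 2])

# Dynamic slice: a 2x2 window whose origin is itself an XlaOp.
start = b.Constant(np.array([0, 2], dtype=np.int32))
window = b.DynamicSlice(x, start, [2, 2])

# Write the window back at a different origin and return both results.
new_start = b.Constant(np.array([2, 0], dtype=np.int32))
updated = b.DynamicUpdateSlice(x, window, new_start)
out = b.Tuple(s, updated)
```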
+ """ + return ops.Reduce(self._builder, [operand], [init_value], + computation_to_apply.computation, dimensions) + + def ReduceWindow(self, operand, init_value, computation_to_apply, + window_dimensions, window_strides, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (XlaOp). + init_value: reduction initial value (XlaOp). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers). + padding: PaddingType representing either 'SAME' or 'VALID' padding. + + Returns: + An XlaOp representing the added ReduceWindow op. + """ + pads = _convert_padding_type_to_pad_values( + padding, + self.GetShape(operand).dimensions(), window_dimensions, window_strides) + return ops.ReduceWindowWithGeneralPadding(operand, init_value, + computation_to_apply.computation, + window_dimensions, window_strides, + (), (), pads) + + def ReduceWindowWithGeneralPadding(self, operand, init_value, + computation_to_apply, window_dimensions, + window_strides, base_dilations, + window_dilations, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (XlaOp). + init_value: reduction initial value (XlaOp). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers). + base_dilations: dilations for the base (sequence of integers). + window_dilations: dilations for window (sequence of integers). + padding: length-N array-like of pairs of integers of (low, high) padding. + + Returns: + An XlaOp representing the added ReduceWindow op. + """ + return ops.ReduceWindowWithGeneralPadding(operand, init_value, + computation_to_apply.computation, + window_dimensions, window_strides, + base_dilations, window_dilations, + padding) + + def RngNormal(self, mu, sigma, dims): + """Enqueues an RngNormal operation onto the computation. + + Args: + mu: An XlaOp to an F32 scalar specifying the mean. + sigma: An XlaOp to an F32 scalar specifying the standard deviation. + dims: A 1D array-like of nonnegative integers specifying the dimensions. + Returns: a XlaOp to the generated array of F32 values. + """ + shape = _xla.Shape.array_shape(self.GetShape(mu).xla_element_type(), dims) + return ops.RngNormal(mu, sigma, shape) + + def RngUniform(self, a, b, dims): + """Enqueues an RngUniform operation onto the computation. + + Args: + a: a XlaOp to an F32, S32, or U32 scalar (consistent with the type of b) + specifying the low end of the interval [a, b) over which values are + generated. + b: a XlaOp to an F32, S32, or U32 scalar (consistent with the type of a) + specifying the high end of the interval [a, b) over which values are + generated. + dims: A 1D array-like of nonnegative integers specifying the dimensions. + Returns: a XlaOp to the generated array of values with the same numeric type + (F32, S32, or U32) as the arguments a and b. + """ + shape = _xla.Shape.array_shape(self.GetShape(a).xla_element_type(), dims) + return ops.RngUniform(a, b, shape) + + def While(self, cond, body, init): + """Enqueues a While operation onto the computation. 
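Reduce (and its windowed variants) takes a separately built binary Computation as the reduction function. The sketch below makes two loudly flagged assumptions beyond this hunk: that the builder still exposes a `Build()` method returning a `Computation`, as the implementation being replaced did, and that elementwise ops such as `Add` are still forwarded from the `_BINARY_OPS` table.

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client

# Assumption: Build() and the forwarded Add op behave as in the replaced code.
add = xla_client.ComputationBuilder('add_f32')
add.Add(add.ParameterFromNumpy(np.float32(0)),
        add.ParameterFromNumpy(np.float32(0)))
add_computation = add.Build()

b = xla_client.ComputationBuilder('row_sum_example')
x = b.ParameterFromNumpy(np.zeros((3, 5), dtype=np.float32))

# Sum away dimension 1: result shape (3,).
row_sums = b.Reduce(x, b.ConstantF32Scalar(0.0), add_computation, dimensions=[1])

# RngUniform draws F32 values in [0, 1) with the requested dimensions.
noise = b.RngUniform(b.ConstantF32Scalar(0.0), b.ConstantF32Scalar(1.0), dims=(3,))
```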
+
+ Args:
+ cond: a Computation for the loop condition, which has type T -> PRED
+ body: a Computation for the loop body, which has type T -> T
+ init: a XlaOp for the initial parameter, which has type T
+ Returns: a XlaOp representing the While operation.
+ """
+ return ops.While(cond.computation, body.computation, init)
+
+ def Conditional(self, pred, true_operand, true_computation, false_operand,
+ false_computation):
+ """Enqueues a Conditional operation onto the computation.
+
+ Args:
+ pred: a XlaOp to test, which has scalar type PRED
+ true_operand: a XlaOp of type T_0
+ true_computation: a Computation to apply to true_operand, type T_0 -> S
+ false_operand: a XlaOp of type T_1
+ false_computation: a Computation to apply to false_operand, type T_1 -> S
+ Returns: a XlaOp representing the Conditional operation.
+ """
+ return ops.Conditional(pred, true_operand, true_computation.computation,
+ false_operand, false_computation.computation)
+
+ def IsConstant(self, operand):
+ """Checks whether the given operand is a compile-time constant.
+
+ Args:
+ operand: a XlaOp to test.
+ Returns: bool indicating whether `operand` is a compile-time constant,
+ meaning its value does not depend on any parameters, or on stateful
+ operators such as `RngNormal` or `Infeed`.
+ """
+ return self._builder.IsConstant(operand)
+
+ def BuildConstantSubGraph(self, operand):
+ """Builds a constant sub graph.
+
+ Args:
+ operand: a XlaOp to test.
+ Returns: a Computation that is rooted on the given `operand` which is a
+ compile-time constant.
+ """
+ return ops.BuildConstantSubGraph(operand)
+
+ def DotGeneral(self, lhs, rhs, dimension_numbers):
+ """Enqueues a general dot operation onto the computation.
+
+ Args:
+ lhs: XlaOp for the left-hand-side array.
+ rhs: XlaOp for the right-hand-side array.
+ dimension_numbers: either a DotDimensionNumbers or a nested tuple
+ ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
+ integers representing the dimensions to treat as contracting dimensions
+ and batch dimensions on each input operand.
+ Returns: a XlaOp representing the DotGeneral operation.
+ """
+ if isinstance(dimension_numbers, tuple):
+ dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
+ return ops.DotGeneral(lhs, rhs, dimension_numbers)
+
+ def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
+ """Enqueues a Conv operation onto the computation.
+
+ Args:
+ lhs: XlaOp for the rank N+2 array of inputs.
+ rhs: XlaOp for the rank N+2 array of kernel weights.
+ window_strides: length-N array-like of integer kernel strides.
+ padding: PaddingType representing either 'SAME' or 'VALID' padding.
+ feature_group_count: number of feature groups for grouped convolution.
+ Returns: a XlaOp representing the Conv operation.
+ """
+ pads = _convert_padding_type_to_pad_values(
+ padding,
+ self.GetShape(lhs).dimensions()[2:],
+ self.GetShape(rhs).dimensions()[2:], window_strides)
+ return self.ConvGeneralDilated(
+ lhs,
+ rhs,
+ window_strides,
+ pads, [], [],
+ dimension_numbers=None,
+ feature_group_count=feature_group_count)
+
+ def ConvWithGeneralPadding(self,
+ lhs,
+ rhs,
+ window_strides,
+ padding,
+ lhs_dilation,
+ rhs_dilation,
+ feature_group_count=1):
+ """Enqueues a ConvWithGeneralPadding operation onto the computation.
+
+ Args:
+ lhs: XlaOp for the rank N+2 array of inputs.
+ rhs: XlaOp for the rank N+2 array of kernel weights.
+ window_strides: length-N array-like of kernel strides.
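The nested-tuple form of `dimension_numbers` accepted by DotGeneral above is easiest to see on a plain matrix multiply. A sketch (builder entry point assumed as before):

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client

# Sketch only: builder name/constructor assumed as elsewhere in this file.
b = xla_client.ComputationBuilder('matmul_example')
lhs = b.ParameterFromNumpy(np.zeros((4, 3), dtype=np.float32))
rhs = b.ParameterFromNumpy(np.zeros((3, 5), dtype=np.float32))

# ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)):
# contract lhs dim 1 against rhs dim 0, no batch dims -> shape (4, 5).
product = b.DotGeneral(lhs, rhs, (([1], [0]), ([], [])))
```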
+ padding: length-N array-like of pairs of integers of (low, high) padding.
+ lhs_dilation: length-N array-like of dilation factors.
+ rhs_dilation: length-N array-like of dilation factors.
+ feature_group_count: number of feature groups for grouped convolution.
+
+ Returns:
+ An XlaOp representing the added ConvWithGeneralPadding op.
+ """
+ return self.ConvGeneralDilated(
+ lhs,
+ rhs,
+ list(window_strides),
+ list(padding),
+ list(lhs_dilation),
+ list(rhs_dilation),
+ dimension_numbers=None,
+ feature_group_count=feature_group_count)
+
+ def _GetConvDimensionNumbers(self, num_spatial_dims):
+ """Create ConvolutionDimensionNumbers proto for convolutions."""
+ nd = num_spatial_dims
+ dimension_numbers = ConvolutionDimensionNumbers()
+ dimension_numbers.input_batch_dimension = 0
+ dimension_numbers.input_feature_dimension = 1
+ dimension_numbers.output_batch_dimension = 0
+ dimension_numbers.output_feature_dimension = 1
+ dimension_numbers.kernel_output_feature_dimension = 0
+ dimension_numbers.kernel_input_feature_dimension = 1
+ dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
+ dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
+ dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
+ return dimension_numbers
+
+ def ConvGeneralDilated(self,
+ lhs,
+ rhs,
+ window_strides,
+ padding,
+ lhs_dilation,
+ rhs_dilation,
+ dimension_numbers=None,
+ feature_group_count=1):
+ """Enqueues a ConvGeneralDilated operation onto the computation.
+
+ Args:
+ lhs: XlaOp for the rank N+2 array of inputs.
+ rhs: XlaOp for the rank N+2 array of kernel weights.
+ window_strides: length-N array-like of integer kernel strides.
+ padding: length-N array-like of pairs of integers of (low, high) padding.
+ lhs_dilation: length-N array-like of integer dilation factors.
+ rhs_dilation: length-N array-like of integer dilation factors.
+ dimension_numbers: optional, either a ConvolutionDimensionNumbers object
+ or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
+ length N+2 identifying by position: (1) batch dimensions in lhs, rhs,
+ and the output with the character 'N', (2) feature dimensions in lhs
+ and the output with the character 'C', (3) input and output feature
+ dimensions in rhs with the characters 'I' and 'O' respectively, and
+ (4) spatial dimension correspondences between lhs, rhs, and the output
+ using any distinct characters. For example, to indicate dimension
+ numbers consistent with the Conv operation with two spatial
+ dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another
+ example, to indicate dimension numbers consistent with the TensorFlow
+ Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using
+ the latter form of convolution dimension specification, window strides
+ are associated with spatial dimension character labels according to
+ the order in which the labels appear in the rhs_spec string, so that
+ window_strides[0] is matched with the dimension corresponding to the
+ first character appearing in rhs_spec that is not 'I' or 'O'. By
+ default, use the same dimension numbering as Conv and
+ ConvWithGeneralPadding.
+ feature_group_count: number of feature groups for grouped convolution.
+ Returns: a XlaOp representing the ConvGeneralDilated operation.
+ """ + if dimension_numbers is None: + dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) + elif isinstance(dimension_numbers, tuple): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + dimension_numbers.input_spatial_dimensions.extend( + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) + dimension_numbers.output_spatial_dimensions.extend( + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) + return ops.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, dimension_numbers, + feature_group_count) + + def Sort(self, operand, dimension=-1): + """Enqueues a sort operation onto the computation.""" + return ops.Sort(self._builder, [operand], dimension) + + def SortKeyVal(self, keys, values, dimension=-1): + """Enqueues a key-value sort operation onto the computation.""" + return ops.Sort(self._builder, [keys, values], dimension) + + def QR(self, a, full_matrices=True): + """Enqueues a QR decomposition onto the computation.""" + return self.Tuple(*ops.QR(a, full_matrices)) + + def TriangularSolve(self, + a, + b, + left_side=False, + lower=False, + transpose_a=False, + conjugate_a=False, + unit_diagonal=False): + """Enqueues a triangular-solve operation onto the computation.""" + if not transpose_a: + transpose = _xla.TriangularSolveOptions_Transpose.NO_TRANSPOSE + if conjugate_a: + a = self.Conj(a) + else: + transpose = ( + _xla.TriangularSolveOptions_Transpose.ADJOINT + if conjugate_a else _xla.TriangularSolveOptions_Transpose.TRANSPOSE) + return ops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose) + + def Eigh(self, a, full_matrices=True): + """Enqueues a symmetric/Hermitian eigendecomposition.""" + return self.Tuple(*ops.Eigh(a, full_matrices)) + + def SVD(self, a): + """Enqueues a singular value decomposition.""" + return self.Tuple(*ops.SVD(a)) + + def Scatter(self, a, scatter_indices, updates, update_computation, + dimension_numbers): + """Enqueues a Scatter operation onto the computation.""" + return ops.Scatter(a, scatter_indices, updates, + update_computation.computation, dimension_numbers) + + def Fft(self, operand, fft_type, fft_lengths): + """Enqueues a FFT operation onto the computation.""" + return ops.Fft(operand, fft_type, fft_lengths) + + +FftType = _xla.FftType + _UNARY_OPS = [ 'Not', 'Clz', @@ -341,1517 +1490,48 @@ _BINARY_OPS = [ 'Complex', ] - -class PrimitiveType(enum.IntEnum): - """Python copy of the XLA PrimitiveType enum. - - Must match the corresponding protocol buffer. 
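The string-triple form of `dimension_numbers` documented for ConvGeneralDilated above maps directly onto the familiar TensorFlow NHWC/HWIO layout. A sketch (builder entry point assumed as before):

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client

# Sketch only: builder name/constructor assumed as elsewhere in this file.
b = xla_client.ComputationBuilder('conv2d_nhwc_example')
images = b.ParameterFromNumpy(np.zeros((1, 8, 8, 3), dtype=np.float32))    # NHWC
filters = b.ParameterFromNumpy(np.zeros((2, 2, 3, 16), dtype=np.float32))  # HWIO

conv = b.ConvGeneralDilated(
    images,
    filters,
    window_strides=(1, 1),
    padding=((0, 0), (0, 0)),   # explicit (low, high) padding per spatial dim
    lhs_dilation=(1, 1),
    rhs_dilation=(1, 1),
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
```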
- """ - PRIMITIVE_TYPE_INVALID = 0 - PRED = 1 - S8 = 2 - S16 = 3 - S32 = 4 - S64 = 5 - U8 = 6 - U16 = 7 - U32 = 8 - U64 = 9 - BF16 = 16 - F16 = 10 - F32 = 11 - F64 = 12 - C64 = 15 - C128 = 18 - TUPLE = 13 - OPAQUE = 14 - TOKEN = 17 - - -XLA_ELEMENT_TYPE_TO_DTYPE = { - PrimitiveType.PRED: np.dtype('bool'), - PrimitiveType.S8: np.dtype('int8'), - PrimitiveType.S16: np.dtype('int16'), - PrimitiveType.S32: np.dtype('int32'), - PrimitiveType.S64: np.dtype('int64'), - PrimitiveType.U8: np.dtype('uint8'), - PrimitiveType.U16: np.dtype('uint16'), - PrimitiveType.U32: np.dtype('uint32'), - PrimitiveType.U64: np.dtype('uint64'), - PrimitiveType.F16: np.dtype('float16'), - PrimitiveType.F32: np.dtype('float32'), - PrimitiveType.F64: np.dtype('float64'), - PrimitiveType.C64: np.dtype('complex64'), - PrimitiveType.C128: np.dtype('complex128'), - PrimitiveType.TUPLE: np.dtype(np.object), -} - -# Note the conversion on the key. Numpy has a known issue wherein dtype hashing -# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, -# when keying by dtype in this dict, we use the string form of dtypes. -DTYPE_TO_XLA_ELEMENT_TYPE = { - str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} - - -def dtype_to_etype(dtype): - """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" - return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] - - -class LocalBuffer(object): - """Represents a handle to data owned by XLA. - - The referent is ready for use in executing a local, compiled - Computation. On XLA platforms involving a device (e.g. GPU), this - means the referent is in device memory. - """ - - def __init__(self, c_buffer, backend, device): - self.c_buffer = c_buffer - self._backend = backend - self._device = device - - @staticmethod - def from_pyval(pyval, device=0, backend=None): - """Allocate and copy to XLA the given python value.""" - backend = backend or _get_default_local_backend() - pyval = require_numpy_array_layout(pyval) - cbuf = backend.buffer_from_pyval(pyval, device) - return LocalBuffer(cbuf, backend, device) - - def to_py(self): - return self.c_buffer.ToLiteral() - - def shape(self): - return _wrap_shape(self.c_buffer.shape()) - - def device(self): - return self._device - - def delete(self): - if self.c_buffer is not None: - self._backend.delete_buffer(self.c_buffer) - self.c_buffer = None - - def destructure(self): - """Assuming a tuple buffer, unpack it into constituent tuple elements.""" - assert self.c_buffer is not None - result = self._backend.destructure_tuple(self.c_buffer) - self.delete() - return tuple( - LocalBuffer(sub_buffer, device=self._device, backend=self._backend) - for sub_buffer in result) - - def is_deleted(self): - return self.c_buffer is None - - def __del__(self): - self.delete() - - -class Format(enum.IntEnum): - """Python copy of the Format protocol buffer enum.""" - INVALID_FORMAT = 0 - DENSE = 1 - SPARSE = 2 - - -class Shape(object): - """Represents an XLA shape. - - A shape is either an array shape, having rank-many integer - dimensions and an element type (represented by a Numpy dtype), or it - is a tuple shape, having a shape for every tuple component: - - type shape = - TupleShape of shape list - | ArrayShape of { dimensions: int list; element_type: dtype } - - Callers are expected to instantiate this class only via the static - constructors: tuple_shape, array_shape, and from_pyval. 
- """ - - @staticmethod - def tuple_shape(tuple_shapes): - """Construct a tuple shape.""" - if (not isinstance(tuple_shapes, (tuple, list)) or - not all(isinstance(t, Shape) for t in tuple_shapes)): - raise TypeError('tuple_shapes must be a tuple of Shapes') - return Shape(tuple_shapes, tuple) - - @staticmethod - def array_shape(element_type, dimensions, minor_to_major=None): - """Construct an array shape.""" - if (not isinstance(dimensions, tuple) or - not all(isinstance(i, int) for i in dimensions)): - dimensions = tuple(int(i) for i in dimensions) - return Shape( - dimensions, np.dtype(element_type), minor_to_major=minor_to_major) - - @staticmethod - def from_pyval(pyval): - def convert(pyval): - if isinstance(pyval, tuple): - return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) - else: - pyval = require_numpy_array_layout(pyval) - return Shape.array_shape(pyval.dtype, np.shape(pyval)) - return convert(pyval) - - def __init__(self, dimensions, dtype, minor_to_major=None): - assert isinstance(dimensions, tuple) - self._dimensions = dimensions - self._dtype = dtype - self._is_tuple = dtype == tuple - self._minor_to_major = minor_to_major - self._check_minor_to_major() - - def __eq__(self, other): - # pylint: disable=protected-access - return (self._dtype == other._dtype and - self._dimensions == other._dimensions and - self._minor_to_major == other._minor_to_major) - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash((self._dtype, self._dimensions, self._minor_to_major)) - - def __repr__(self): - return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' - '_is_tuple={!r}, _minor_to_major={!r})').format( - self._dtype, self._dimensions, self._is_tuple, - self._minor_to_major) - - def is_tuple(self): - return self._is_tuple - - def is_array(self): - return not self._is_tuple - - def tuple_shapes(self): - if not self.is_tuple(): - raise ValueError('not a tuple shape') - return self._dimensions - - def numpy_dtype(self): - """Like element_type(), but returns dtype('O') in case of a tuple shape.""" - if self.is_tuple(): - return np.dtype(np.object) - else: - return self.element_type() - - def xla_element_type(self): - return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())] - - def element_type(self): - if not self.is_array(): - raise ValueError('not an array shape') - return self._dtype - - def dimensions(self): - if not self.is_array(): - raise ValueError('not an array shape') - return self._dimensions - - def rank(self): - return len(self.dimensions()) - - def minor_to_major(self): - return self._minor_to_major - - def map_leaves(self, f): - """Map f over each leaf-level array subshape. - - Args: - f: The function to apply. Whenever f returns None, the identity is applied - instead. - - Returns: - A new Shape with the mapped leaves. 
- """ - if self.is_tuple(): - children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) - return Shape.tuple_shape(children) - else: - mapped = f(self) - return self if mapped is None else mapped - - def _check_minor_to_major(self): - mtm = self._minor_to_major - if self.is_tuple(): - assert mtm is None, self - if mtm is not None: - assert self.rank() == len(mtm), self - assert sorted(mtm) == list(range(len(mtm))), self - - def update_minor_to_major(self, minor_to_major): - if not self.is_array(): - raise ValueError('not an array shape') - if not isinstance(minor_to_major, tuple): - raise TypeError('minor_to_major must be a tuple') - updated = Shape.array_shape(self.element_type(), self.dimensions(), - minor_to_major) - updated._check_minor_to_major() # pylint: disable=protected-access - return updated - - def with_major_to_minor_layout_if_absent(self): - """Returns a copy of a shape with missing layouts set to major-to-minor.""" - - def f(a): - if a.minor_to_major(): - return None - return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1))) - - return self.map_leaves(f) - - def serialize(self, proto): - """Serializes 'shape' into proto.""" - if self.is_tuple(): - proto.element_type = PrimitiveType.TUPLE - for shape in self.tuple_shapes(): - shape.serialize(proto.tuple_shapes.add()) - else: - proto.element_type = dtype_to_etype(self.element_type()) - proto.dimensions.extend(self.dimensions()) - proto.is_dynamic_dimension.extend([False for _ in self.dimensions()]) - if self.minor_to_major(): - proto.layout.format = Format.DENSE - proto.layout.minor_to_major.extend(self.minor_to_major()) - - -ProgramShape = collections.namedtuple('ProgramShape', - ('parameter_shapes', 'result_shape')) - - -def _wrap_shape(shape_info): - dtype, dims = shape_info - element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] - if element_type == PrimitiveType.TUPLE: - shapes = tuple(_wrap_shape(subshape_info) for subshape_info in dims) - return Shape.tuple_shape(shapes) - else: - return Shape.array_shape(dtype, dims) - - -def _wrap_program_shape(shape_info): - arg_shapes, result_shape = shape_info - return ProgramShape([_wrap_shape(arg) for arg in arg_shapes], - _wrap_shape(result_shape)) - - -def require_numpy_array_layout(value): - if isinstance(value, tuple): - return tuple(require_numpy_array_layout(x) for x in value) - else: - return np.require(value, requirements=['C', 'A']) - - -class CompileOptions(object): - """Python object for XLA compile options. - - These options can be passed to the 'compile' step when using a local XLA - client. - """ - - def __init__(self): - self.generate_hlo_graph = None - self.dump_optimized_hlo_proto_to = None - self.dump_unoptimized_hlo_proto_to = None - self.dump_per_pass_hlo_proto_to = None - self.hlo_profile = False - self.num_replicas = get_replica_count() - - -def transfer_to_infeed(value, device_ordinal=0): - """Transfers the given value into the XLA infeed queue. - - XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with - a totally ordered stream of values. This is dequeued from XLA computations via - the Infeed() operation. - - Args: - value: the value that the caller would like to enqueue into the XLA infeed - queue - device_ordinal: the device to infeed the value to. Each device has a - distinct infeed queue. - """ - # TODO(phawkins): support non-default backends. 
- backend = _get_default_local_backend() - backend.client.TransferToInfeed( - require_numpy_array_layout(value), device_ordinal) - - -def transfer_from_outfeed(shape, device_ordinal=0): - """Transfers a literal of the given shape from `device_ordinal`'s outfeed. - - Args: - shape: The shape of the value to transfer from outfeed. - device_ordinal: The device ordinal to transfer the outfeed value from. Each - device has a distinct outfeed queue.. - - Returns: - The literal value that is produced from the outfeed queue. - """ - # TODO(phawkins): support non-default backends. - backend = _get_default_local_backend() - return backend.client.TransferFromOutfeed(shape, device_ordinal) - - -class Computation(object): - """Python wrapper for an XLA Computation. - - A Computation can be compiled to form an Executable, or used as a - subcomputation in ComputationBuilder methods. - """ - - def __init__(self, c_computation, backend=None): - self._c_computation = c_computation - # The backend argument is deprecated. Pass a backend to Compile() instead. - self._backend = backend - self._delete_computation = c_api.DeleteComputation - - @property - def computation(self): - return self._c_computation - - def GetSerializedProto(self): - """Gets the serialized HloModuleProto proto object in this computation. - - Returns: - A string containing a serialized HloModuleProto proto containing the - computation and its dependencies. - """ - return self.computation.GetSerializedProto() - - def GetHloText(self): - """Get the textual HLO representation of this computation. - - Returns: - A string containing the textual HLO. - """ - return self.computation.GetHloText() - - def GetHloDotGraph(self): - """Get a Graphviz Dot representation of this computation. - - Returns: - A string containing the graphviz dot graph. - """ - return self.computation.GetHloDotGraph() - - def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None, - backend=None): - """Compiles a computation. - - Computations are the result of a "ComputationBuild'ing" process. - - Arguments: - argument_shapes: parameter shapes -- they are first laid out by layout_fn - if layout_fn is provided. Otherwise, the default layout for those shapes - will be used. - compile_options: options to use for compilation, includes an optional laid - out result shape for the computation. - layout_fn: lambda that is used to lay out the argument/result shapes. - backend: a `Backend` for which an executable should be generated. - - Returns: - A Executable instance. 
- """ - backend = backend or self._backend or _get_default_local_backend() - result_shape = _wrap_shape(self.computation.GetReturnValueShape()) - - if layout_fn: - argument_shapes = [ - shape.map_leaves(layout_fn) for shape in argument_shapes - ] - result_shape = result_shape.map_leaves(layout_fn) - - argument_shapes = list(argument_shapes) - - compile_options = compile_options or CompileOptions() - compile_options.result_shape = result_shape - c = backend.compile(self.computation, argument_shapes, result_shape, - compile_options) - return Executable(c, backend=backend) - - def CompileWithExampleArguments(self, - arguments=(), - compile_options=None, - layout_fn=None, - backend=None): - return self.Compile( - argument_shapes=[Shape.from_pyval(arg) for arg in arguments], - compile_options=compile_options, - layout_fn=layout_fn, - backend=backend) - - def GetProgramShape(self): - return _wrap_program_shape(self._c_computation.GetProgramShape()) - - def GetReturnValueShape(self): - return _wrap_shape(self._c_computation.GetReturnValueShape()) - - def __del__(self): - if self._c_computation: - self._delete_computation(self._c_computation) - - -class Executable(object): - """Python wrapper for an XLA Executable.""" - - def __init__(self, c_executable, backend=None): - self._c_executable = c_executable - self._device_ordinals = c_executable.DeviceOrdinals() - self._backend = backend - - def DeviceOrdinals(self): - """Returns a list containing the device ordinals for each replica.""" - return self._device_ordinals - - def Execute(self, arguments=(), check_for_deleted_args=True): - """Execute on one replica with LocalBuffer arguments and return value.""" - if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): - raise ValueError('Executing with deleted local buffer argument') - raw_args = [arg.c_buffer for arg in arguments] - output_buffer = self._backend.execute(self._c_executable, raw_args) - return LocalBuffer( - output_buffer, backend=self._backend, device=self._device_ordinals[0]) - - def ExecutePerReplica(self, arguments=None): - """Execute on many replicas with LocalBuffer arguments and return value. - - Args: - arguments: A sequence of sequences of LocalBuffers. The i'th inner - sequence comprises the arguments for execution on the i'th replica. - - Returns: - A list of the computation's outputs for each replica, as a LocalBuffer. If - a shallow sequence of arguments was passed in for `arguments`, then the - sole, zero'th replica's output is returned instead, as a LocalBuffer. 
- """ - if arguments is None: - arguments = ((),) * len(self._device_ordinals) - else: - arguments = [list(replica_args) for replica_args in arguments] - - # Check arguments - for replica, replica_args in enumerate(arguments): - for arg in replica_args: - if arg.is_deleted(): - raise ValueError('Executing with deleted local buffer argument') - if arg.device() != self._device_ordinals[replica]: - raise ValueError( - 'Executing on device {} with argument from device {}'.format( - self._device_ordinals[replica], arg.device())) - - # Pull out argument buffer handles - # pylint: disable=g-complex-comprehension - stripped_args = [ - [arg.c_buffer for arg in replica_args] for replica_args in arguments - ] - - # Execute - output_buffers = self._backend.execute_replicated(self._c_executable, - stripped_args) - - # Wrap output handles in LocalBuffer instances - return tuple( - LocalBuffer( - output_buffer, - backend=self._backend, - device=self._device_ordinals[replica]) - for replica, output_buffer in enumerate(output_buffers)) - - def ExecuteWithPythonValues(self, arguments=()): - """Execute on one replica with Python values as arguments and output.""" - - def put(arg): - return LocalBuffer.from_pyval( - arg, device=self._device_ordinals[0], backend=self._backend) - - arguments = [put(arg) for arg in arguments] - return self.Execute(arguments).to_py() - - def ExecuteWithPythonValuesPerReplica(self, arguments): - """Execute on many replicas with Python values as arguments and output.""" - - def put(arg, device): - return LocalBuffer.from_pyval(arg, device, backend=self._backend) - - # pylint: disable=g-complex-comprehension - arguments = [[ - put(arg, self._device_ordinals[replica]) for arg in replica_args - ] for replica, replica_args in enumerate(arguments)] - return [out.to_py() for out in self.ExecutePerReplica(arguments)] - - def __del__(self): - # Python may have freed c_api first. - if c_api and self._c_executable: - self._backend.delete_executable(self._c_executable) - - -class ComputationBuilder(object): - """XLA computation builder. - - Enqueues XLA ops in sequence and in order to build a - Computation, which in turn can be compiled into a - LocalExecutable, which in turn can be locally executed. - """ - - # The methods of this class map 1-to-1 onto the XLA C++ - # computation builder API. Therefore, there's no need to laboriously list - # arguments and return values for every method, especially where it's obvious. - # - # pylint: disable=g-doc-return-or-yield - # pylint: disable=g-doc-args - - def __init__(self, name): - self._client = c_api.ComputationBuilder(name.encode('utf8')) - self._parameter_numbering = itertools.count() - - def Build(self, root=None, backend=None): - """Builds a `Computation` from the contents of the builder. - - Args: - root: if not None, the operator containing the return value of the - computation. - backend: deprecated. Pass a `backend` to `Computation.Compile` instead. - - Returns: - A `Computation`. - """ - if root is not None: - return Computation(self._client.BuildWithRoot(root), backend=backend) - else: - return Computation(self._client.Build(), backend=backend) - - def SetOpMetadata(self, op_metadata): - """Set metadata for operations that are about to be enqueued.""" - self._client.SetOpMetadata(op_metadata) - - def ClearOpMetadata(self): - """Clear metadata for operations that are about to be enqueued.""" - self._client.ClearOpMetadata() - - def Infeed(self, shape): - """Enqueues an infeed op onto the computation. 
- - Infeed operations dequeue data of the given shape from the device's infeed - queue for subsequent use in the computation. - - Returns: - A LocalOp. - """ - return self._client.Infeed(shape) - - def Outfeed(self, operand): - """Enqueues an outfeed op onto the computation. - - Outfeed operations enqueue data, using the given operand, onto the XLA - outfeed queue for subsequent dequeue via the client API. - """ - self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) - - def Constant(self, value): - """Enqueues a constant op onto the computation. - - Args: - value: value for the constant, as a np.array with an explicit dtype set to - one of the supported types. - - Returns: - A LocalOp. - """ - value = require_numpy_array_layout(value) - return self._client.ConstantLiteral(value) - - def ConstantF32Scalar(self, value): - """Convenience method to enqueue a scalar F32 constant op. - - Args: - value: a floating-point number. - - Returns: - A LocalOp. - """ - return self.Constant(np.array(value, dtype=np.float32)) - - def ConstantF64Scalar(self, value): - """Convenience method to enqueue a scalar F32 constant op. - - Args: - value: a floating-point number. - - Returns: - A LocalOp. - """ - return self.Constant(np.array(value, dtype=np.float64)) - - def ConstantS32Scalar(self, value): - """Convenience method to enqueue a scalar S32 constant op. - - Args: - value: a floating-point number. - - Returns: - A LocalOp. - """ - return self.Constant(np.array(value, dtype=np.int32)) - - def ConstantS64Scalar(self, value): - """Convenience method to enqueue a scalar S64 constant op. - - Args: - value: a floating-point number. - - Returns: - A LocalOp. - """ - return self.Constant(np.array(value, dtype=np.int64)) - - def ConstantPredScalar(self, value): - """Convenience method to enqueue a scalar PRED constant op. - - Args: - value: a boolean value. - - Returns: - A LocalOp. - """ - return self.Constant(np.array(value, dtype=np.bool)) - - def ParameterWithShape(self, shape, name=None, parameter_num=None): - """Enqueues a Parameter op onto the computation, given a shape. - - Args: - shape: the parameter's shape as a Shape object. - name: optional string name for the parameter. - parameter_num: parameter number in the computation function. If None, the - next linear parameter number is used. The default value capability can - be used for auto-numbering. If you're using auto-numbering for some - parameters, use it for *all* parameters to avoid clashes. - - Returns: - A LocalOp. - """ - if name is None: - name = '' - if parameter_num is None: - parameter_num = next(self._parameter_numbering) - - return self._client.Parameter(parameter_num, shape, name.encode('utf8')) - - def ParameterFromNumpy(self, value, name=None, parameter_num=None): - """Enqueues a Parameter op onto the computation. - - Args: - value: a Numpy array, or a nested tuple thereof, from which the shape is - inferred. - name: as in ParameterWithShape. - parameter_num: as in ParameterWithShape. - - Returns: - A LocalOp. - """ - return self.ParameterWithShape( - Shape.from_pyval(value), name=name, parameter_num=parameter_num) - - def Iota(self, dtype, size): - """Enqueues an iota constant onto the computation. - - Args: - dtype: expected numpy dtype of the output. - size: integer, the number of elements in the array. - - Returns: - A LocalOp representing the added iota constant. 
- """ - element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] - return self._client.Iota(element_type, size) - - def BroadcastedIota(self, dtype, shape, dimension): - """Enqueues a broadcasted iota constant onto the computation. - - Args: - dtype: expected numpy dtype of the output. - shape: tuple of integers, the expected output shape (dimensions). - dimension: positive integer, dimension along which to increment values. - - Returns: - A LocalOp representing the added broadcasted iota constant. - """ - xla_shape = Shape.array_shape(dtype, shape) - return self._client.BroadcastedIota(xla_shape, dimension) - - def Broadcast(self, operand, sizes): - """Enqueues a broadcast operation onto the computation. - - Args: - operand: the operand LocalOp to broadcast. - sizes: an iterable of broadcast sizes. - - Returns: - A LocalOp representing the added broadcast op. - """ - return self._client.Broadcast(operand, sizes) - - def BroadcastInDim(self, operand, shape, broadcast_dimensions): - """Enqueues a broadcast-in-dimensions operation onto the computation. - - Args: - operand: the operand LocalOp to broadcast. - shape: tuple of integers, the expected output shape. - broadcast_dimensions: tuple of integers identifying which dimensions of - the output are to be broadcast into. - - Returns: - A LocalOp representing the added broadcast-in-dimensions op. - """ - return self._client.BroadcastInDim(operand, shape, broadcast_dimensions) - - def Concatenate(self, operands, dimension): - """Enqueues a concatenate operation onto the computation. - - Args: - operands: the operands to concatenate. - dimension: the dimension in which to perform the concatenation. - - Returns: - A LocalOp representing the added concatenate op. - """ - return self._client.ConcatInDim(operands, dimension) - - def ConvertElementType(self, operand, new_element_type): - """Enqueues an element type conversion operation onto the computation. - - Args: - operand: the operand to convert. - new_element_type: the target primitive type. - - Returns: - A LocalOp representing the added conversion op. - """ - return self._client.ConvertElementType(operand, new_element_type) - - def BitcastConvertType(self, operand, new_element_type): - """Enqueues a bitcast type conversion operation onto the computation. - - Args: - operand: the operand to convert. - new_element_type: the target primitive type. - - Returns: - A LocalOp representing the added conversion op. - """ - return self._client.BitcastConvertType(operand, new_element_type) - - def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(operand)) - - def GetReturnValueShape(self): - return _wrap_shape(self._client.GetReturnValueShape()) - - def GetComputationStats(self): - raise NotImplementedError() - - def ReplicaId(self): - """Enqueues a ReplicaId operation onto the computation. - - Returns: - A LocalOp representing the replica id. - """ - return self._client.ReplicaId() - - def Pad(self, operand, padding_value, padding_config): - """Enqueues a Pad operation onto the computation. - - Args: - operand: LocalOp representing the array to pad. - padding_value: LocalOp representing the scalar pad value. - padding_config: either a PaddingConfig or a list of integer triples - (edge_padding_low, edge_padding_high, interior_padding) representing the - configuration of the padding operation. - - Returns: - A LocalOp representing the added Pad op. 
- """ - if isinstance(padding_config, tuple) or isinstance(padding_config, list): - padding_config = GetPaddingConfigFromTriples(padding_config) - return self._client.Pad(operand, padding_value, padding_config) - - def Reshape(self, operand, dimensions, new_sizes): - """Enqueues a reshape op onto the computation. - - Args: - operand: LocalOp representing the array to be reshaped. - dimensions: sequence of integers encoding the order in which dimensions - are collapsed or None, in which case dimensions are flattened in order. - new_sizes: sequence of integers encoding the new dimension sizes (shape). - - Returns: - A LocalOp representing the added Reshape op. - """ - if dimensions is None: - ndim = len(self.GetShape(operand).dimensions()) - dimensions = tuple(range(ndim)) - return self._client.Reshape(operand, dimensions, new_sizes) - - def AllToAll(self, - operand, - split_dimension, - concat_dimension, - replica_groups=None): - """AllToAll op. - - Args: - operand: LocalOp representing the input array - split_dimension: the dimension along which the operand is split - concat_dimension: the dimension along which the split blocks are - concatenated - replica_groups: optional, list of lists of ints encoding a partition of - the set {0, 1, ..., num_replicas} into equally-sized replica groups - within which the all-to-all is performed. If not supplied or None (the - default), all replicas belong to the same group. - - Returns: - A LocalOp that represents the all-to-all concatenation. - """ - if replica_groups is None: - replica_groups_protos = [] # special value for XLA API - else: - replica_groups = list(replica_groups) - replica_groups_protos = [ - _make_replica_group_proto(group) for group in replica_groups - ] - if not replica_groups: - split_count = get_replica_count() - else: - split_count = len(replica_groups[0]) - if not all(split_count == len(g) for g in replica_groups): - raise ValueError('Replica groups must be equally sized') - return self._client.AllToAll(operand, split_dimension, concat_dimension, - split_count, replica_groups_protos) - - def CrossReplicaSum(self, operand, replica_groups=None): - """CrossReplicaSum op. - - Args: - operand: the operand to sum across replica instances. - replica_groups: optional, list of lists of ints encoding a partition of - the set {0, 1, ..., num_replicas} into equally-sized replica groups - within which the cross-replica sum is performed. If not supplied or None - (the default), all replicas belong to the same group. - - Returns: - A LocalOp that represents on each replica the sum of its group's values. 
- """ - if replica_groups is None: - replica_groups = [] # special value for XLA API - else: - replica_groups = [ - _make_replica_group_proto(group) for group in replica_groups - ] - return self._client.CrossReplicaSum(operand, replica_groups) - - def Collapse(self, operand, dimensions): - """Collapse op.""" - return self._client.Collapse(operand, dimensions) - - def Trans(self, operand): - """Specialized matrix transpose op.""" - return self._client.Transpose(operand, [1, 0]) - - def Transpose(self, operand, permutation): - """Transpose op.""" - return self._client.Transpose(operand, permutation) - - def Rev(self, operand, dimensions): - """Rev op.""" - return self._client.Rev(operand, dimensions) - - def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin - """Clamp op.""" - return self._client.Clamp(min, operand, max) - - def SelectAndScatter(self, operand, select, window_dimensions, window_strides, - padding, source, init_value, scatter): - """Select and scatter op, used by the gradient of ReduceWindow. - - Args: - operand: LocalOp for array of dimension N and type T over which the - windows slide. - select: Computation of type (T, T) -> Pred to apply to the elements of - each window to indicate which element is selected. - window_dimensions: sequence of N integers for dimensions of the window. - window_strides: sequence of N integers for the strides of the window. - padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: LocalOp for array of type T with values to scatter. - init_value: LocalOp of scalar type T for initial out value. - scatter: Computation of type (T, T) -> T to apply to each scatter source - element with its destination element. - - Returns: - A LocalOp representing the added SelectAndScatter op. - """ - pads = _convert_padding_type_to_pad_values( - padding, self.GetShape(operand).dimensions(), window_dimensions, - window_strides) - return self._client.SelectAndScatterWithGeneralPadding( - operand, select.computation, window_dimensions, window_strides, pads, - source, init_value, scatter.computation) - - def Select(self, pred, on_true, on_false): - """Element-wise selection op. - - Constructs an output array from elements of two input arrays, based on the - values of a predicate array. - """ - return self._client.Select(pred, on_true, on_false) - - def Slice(self, operand, start_indices, limit_indices, strides=None): - """Enqueues a slice operation onto the computation. - - Args: - operand: LocalOp for the N dimensional array to be sliced. - start_indices: iterable of N integers containing the starting indices of - the slice for each dimension. - limit_indices: iterable of N integers containing the ending indices - (exclusive) of the slice for each dimension. - strides: optional iterable of N integers containing the stride sizes for - each dimension. - - Returns: - A LocalOp representing the added Slice op. - """ - if strides is None: - start_indices = list(start_indices) - strides = [1] * len(start_indices) - return self._client.Slice(operand, start_indices, limit_indices, strides) - - def SliceInDim(self, operand, start_index, limit_index, stride, dimno): - """Enqueues a slice-in-dimension operation onto the computation. - - Args: - operand: LocalOp for the N dimensional array to be sliced. - start_index: an integer containing the start index of the slice. - limit_index: an integer containing the end index of the slice. - stride: an integer containing the stride size for the slice. 
- dimno: an integer indicating the dimension along which to slice. - - Returns: - A LocalOp representing the added Slice op. - """ - return self._client.SliceInDim(operand, start_index, limit_index, stride, - dimno) - - def DynamicSlice(self, operand, start_indices, slice_sizes): - """Enqueues a slice op with dynamic start indices onto the computation. - - Args: - operand: LocalOp for the N dimensional array to be sliced. - start_indices: LocalOp for the 1D array of N integers containing the - starting indices of the slice. - slice_sizes: iterable of N integers containing the slice sizes in each - dimension. - - Returns: - A LocalOp representing the added DynamicSlice op. - """ - return self._client.DynamicSlice(operand, start_indices, slice_sizes) - - def DynamicUpdateSlice(self, operand, update, start_indices): - """Enqueues a dynamic update slice operation onto the computation. - - Args: - operand: LocalOp for the N dimensional array to be updated. - update: N dimensional array comprising the slice update. - start_indices: Rank-1 array of N integers comprising the starting indices - of the slice along each dimension. - - Returns: - A LocalOp representing the added DynamicUpdateSlice op. - """ - return self._client.DynamicUpdateSlice(operand, update, start_indices) - - def Tuple(self, *ops): - """Enqueues a tuple operation onto the computation. - - Args: - ops: a sequence of tuple operands (each a LocalOp). - - Returns: - A LocalOp representing the added Tuple op. - """ - return self._client.Tuple(ops) - - def GetTupleElement(self, tup, index): - """Enqueues a 'get tuple element' operation onto the computation. - - Args: - tup: the tuple operand (a LocalOp). - index: numeric index to select from the tuple. - - Returns: - A LocalOp representing the added GetTupleElement op. - """ - return self._client.GetTupleElement(tup, index) - - def Call(self, computation_to_apply, operands): - """Enqueues a call operation onto the computation. - - Args: - computation_to_apply: a Computation object. - operands: an iterable of LocalOp. The number and types of operands must - match the arity of computation_to_apply. - - Returns: - A LocalOp representing the added call op. - """ - return self._client.Call(computation_to_apply.computation, operands) - - def CustomCall(self, - call_target_name, - operands, - shape_with_layout, - operand_shapes_with_layout, - opaque=None): - """Enqueues a custom call operation onto the computation. - - Args: - call_target_name: the name of the function to call. - operands: an iterable of LocalOp. The number and types of operands must - match the arity of `operand_shapes_with_layout`. - shape_with_layout: the shape of the operator's output, with layout. - operand_shapes_with_layout: the shapes of `operands`, including the - expected layouts. - opaque: an opaque string passed to the backend. - - Returns: - A LocalOp representing the added custom call op. - """ - opaque = opaque or b'' - return self._client.CustomCall(call_target_name, operands, - shape_with_layout, - operand_shapes_with_layout, opaque) - - def Map(self, operands, computation_to_apply, dimensions): - """Enqueues a map operation onto the computation. - - Args: - operands: an iterable of LocalOp. - computation_to_apply: a Computation object. - dimensions: dimensions over which to apply map the function. - - Returns: - A LocalOp representing the added Map op. 
- """ - return self._client.Map(operands, computation_to_apply.computation, - dimensions) - - def Reduce(self, operand, init_value, computation_to_apply, dimensions): - """Enqueues a reduction operation onto the computation. - - Args: - operand: reduction operand (LocalOp). - init_value: reduction initial value (LocalOp). - computation_to_apply: a Computation object - binary reduction function. - dimensions: sequence of dimensions (integers) to reduce on. - - Returns: - A LocalOp representing the added Reduce op. - """ - return self._client.Reduce(operand, init_value, - computation_to_apply.computation, dimensions) - - def ReduceWindow(self, operand, init_value, computation_to_apply, - window_dimensions, window_strides, padding): - """Enqueues a windowed reduction operation onto the computation. - - Args: - operand: reduction operand (LocalOp). - init_value: reduction initial value (LocalOp). - computation_to_apply: a binary reduction function (Computation). - window_dimensions: dimensions of window (sequence of integers). - window_strides: strides for window (sequence of integers). - padding: PaddingType representing either 'SAME' or 'VALID' padding. - - Returns: - A LocalOp representing the added ReduceWindow op. - """ - pads = _convert_padding_type_to_pad_values( - padding, - self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return self._client.ReduceWindowWithGeneralPadding( - operand, init_value, computation_to_apply.computation, - window_dimensions, window_strides, (), (), pads) - - def ReduceWindowWithGeneralPadding( - self, operand, init_value, computation_to_apply, window_dimensions, - window_strides, base_dilations, window_dilations, padding): - """Enqueues a windowed reduction operation onto the computation. - - Args: - operand: reduction operand (LocalOp). - init_value: reduction initial value (LocalOp). - computation_to_apply: a binary reduction function (Computation). - window_dimensions: dimensions of window (sequence of integers). - window_strides: strides for window (sequence of integers). - base_dilations: dilations for the base (sequence of integers). - window_dilations: dilations for window (sequence of integers). - padding: length-N array-like of pairs of integers of (low, high) padding. - - Returns: - A LocalOp representing the added ReduceWindow op. - """ - return self._client.ReduceWindowWithGeneralPadding( - operand, init_value, computation_to_apply.computation, - window_dimensions, window_strides, base_dilations, window_dilations, - padding) - - def RngNormal(self, mu, sigma, dims): - """Enqueues an RngNormal operation onto the computation. - - Args: - mu: A LocalOp to an F32 scalar specifying the mean. - sigma: A LocalOp to an F32 scalar specifying the standard deviation. - dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a LocalOp to the generated array of F32 values. - """ - shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) - return self._client.RngNormal(mu, sigma, shape) - - def RngUniform(self, a, b, dims): - """Enqueues an RngUniform operation onto the computation. - - Args: - a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) - specifying the low end of the interval [a, b) over which values are - generated. - b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) - specifying the high end of the interval [a, b) over which values are - generated. - dims: A 1D array-like of nonnegative integers specifying the dimensions. 
- Returns: a LocalOp to the generated array of values with the same numeric - type (F32, S32, or U32) as the arguments a and b. - """ - shape = Shape.array_shape(self.GetShape(a).element_type(), dims) - return self._client.RngUniform(a, b, shape) - - def While(self, cond, body, init): - """Enqueues a While operation onto the computation. - - Args: - cond: a Computation for the loop condition, which has type T -> PRED - body: a Computation for the loop body, which has type T -> T - init: a LocalOp for the initial parameter, which has type T - Returns: a LocalOp representing the While operation. - """ - return self._client.While(cond.computation, body.computation, init) - - def Conditional(self, pred, true_operand, true_computation, false_operand, - false_computation): - """Enqueues a Conditional operation onto the computation. - - Args: - predicate: a LocalOp to test, which has scalar type PRED - true_operand: a LocalOp of type T_0 - true_computation: a Computation to apply to true_operand, type T_0 -> S - false_operand: a ComputationDatahandle of type T_1 - false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a LocalOp representing the Conditional operation. - """ - return self._client.Conditional(pred, true_operand, - true_computation.computation, false_operand, - false_computation.computation) - - def IsConstant(self, operand): - """Checks whether the given operand is a compile-time constant. - - Args: - operand: a ComputationDataHandle to test. - Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on any parametersor, or on stateful - operators such as `RngNormal` or `Infeed`. - """ - return self._client.IsConstant(operand) - - def BuildConstantSubGraph(self, operand): - """Builds a constant sub graph. - - Args: - operand: a LocalOp to test. - Returns: a Computation that is rooted on the given `operand` which is a - compile-time constant. - """ - return self._client.BuildConstantSubGraph(operand) - - def Dot(self, lhs, rhs): - """Enqueues a dot operation onto the computation. - - Args: - lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. - rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a LocalOp representing the Dot operation. - """ - return self._client.Dot(lhs, rhs) - - def DotGeneral(self, lhs, rhs, dimension_numbers): - """Enqueues a general dot operation onto the computation. - - Args: - lhs: LocalOp for the left-hand-side array. - rhs: LocalOp for the right-hand-side array. - dimension_numbers: either a DotDimensionNumbers or a nested tuple - ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of - integers representing the dimensions to treat as contracting dimensions - and batch dimensions on each input operand. - Returns: a LocalOp representing the DotGeneral operation. - """ - if isinstance(dimension_numbers, tuple): - dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return self._client.DotGeneral(lhs, rhs, dimension_numbers) - - def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1): - """Enqueues a Conv operation onto the computation. - - Args: - lhs: LocalOp for the rank N+2 array of inputs. - rhs: LocalOp for the rank N+2 array of kernel weights. - window_strides: length-N array-like of integer kernel strides. - padding: PaddingType representing either 'SAME' or 'VALID' padding. - feature_group_count: number of feature groups for grouped convolution. - Returns: a LocalOp representing the Conv operation. 
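# Illustrative sketch: the Conv signature documented above, mirroring
# testConvF32Same from the test changes below; a hedged aside, not new API
# documentation.
import numpy as np
from tensorflow.compiler.xla.python import xla_client

def a(*dims):
  # Same helper the tests use: consecutive floats reshaped to `dims`.
  return np.arange(np.prod(dims)).reshape(dims).astype("float32")

c = xla_client.ComputationBuilder("conv_example")
lhs = a(1, 2, 3, 4)         # rank N+2 inputs: (batch, features, height, width)
rhs = a(1, 2, 1, 2) * 10    # rank N+2 kernel: (out_features, in_features, h, w)
c.Conv(c.Constant(lhs), c.Constant(rhs), [1, 1], xla_client.PaddingType.SAME)
out = xla_client.execute_with_python_values(c.Build().Compile())
# Per testConvF32Same, `out` is
# [[[[ 640.,  700.,  760.,  300.],
#    [ 880.,  940., 1000.,  380.],
#    [1120., 1180., 1240.,  460.]]]]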
- """ - pads = _convert_padding_type_to_pad_values( - padding, - self.GetShape(lhs).dimensions()[2:], - self.GetShape(rhs).dimensions()[2:], window_strides) - return self.ConvGeneralDilated( - lhs, rhs, window_strides, pads, (), (), dimension_numbers=None, - feature_group_count=feature_group_count) - - def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, feature_group_count=1): - """Enqueues a ConvWithGeneralPadding operation onto the computation. - - Args: - lhs: LocalOp for the rank N+2 array of inputs. - rhs: LocalOp for the rank N+2 array of kernel weights. - window_strides: length-N array-like of kernel strides. - padding: length-N array-like of pairs of integers of (low, high) padding. - lhs_dilation: length-N array-like of dilation factors. - rhs_dilation: length-N array-like of dilation factors. - feature_group_count: number of feature groups for grouped convolution. - - Returns: - A ComputationdataHandle representing the added ConvWithGeneralPadding op. - """ - return self.ConvGeneralDilated( - lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers=None, feature_group_count=feature_group_count) - - def _GetConvDimensionNumbers(self, num_spatial_dims): - """Create ConvolutionDimensionNumbers proto for convolutions.""" - nd = num_spatial_dims - dimension_numbers = ConvolutionDimensionNumbers() - dimension_numbers.input_batch_dimension = 0 - dimension_numbers.input_feature_dimension = 1 - dimension_numbers.output_batch_dimension = 0 - dimension_numbers.output_feature_dimension = 1 - dimension_numbers.kernel_output_feature_dimension = 0 - dimension_numbers.kernel_input_feature_dimension = 1 - dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) - dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) - dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) - return dimension_numbers - - def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers=None, - feature_group_count=1): - """Enqueues a ConvGeneralDilated operation onto the computation. - - Args: - lhs: LocalOp for the rank N+2 array of inputs. - rhs: LocalOp for the rank N+2 array of kernel weights. - window_strides: length-N array-like of integer kernel strides. - padding: length-N array-like of pairs of integers of (low, high) padding. - lhs_dilation: length-N array-like of integer dilation factors. - rhs_dilation: length-N array-like of integer dilation factors. - dimension_numbers: optional, either a ConvolutionDimensionNumbers object - or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of - length N+2 identifying by position: (1) batch dimensions in lhs, rhs, - and the output with the character 'N', (2) feature dimensions in lhs - and the output with the character 'C', (3) input and output feature - dimensions in rhs with the characters 'I' and 'O' respectively, and - (4) spatial dimension correspondences between lhs, rhs, and the output - using any distinct characters. For example, to indicate dimension - numbers consistent with the Conv operation with two spatial - dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another - example, to indicate dimension numbers consistent with the TensorFlow - Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). 
When using - the latter form of convolution dimension specification, window strides - are associated with spatial dimension character labels according to - the order in which the labels appear in the rhs_spec string, so that - window_strides[0] is matched with the dimension corresponding to the - first character appearing in rhs_spec that is not 'I' or 'O'. By - default, use the same dimension numbering as Conv and - ConvWithGeneralPadding. - feature_group_count: number of feature groups for grouped convolution. - Returns: a LocalOp representing the ConvGenralDilated operation. - """ - if dimension_numbers is None: - dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - elif isinstance(dimension_numbers, tuple): - lhs_spec, rhs_spec, out_spec = dimension_numbers - dimension_numbers = ConvolutionDimensionNumbers() - - dimension_numbers.input_batch_dimension = lhs_spec.index('N') - dimension_numbers.input_feature_dimension = lhs_spec.index('C') - dimension_numbers.output_batch_dimension = out_spec.index('N') - dimension_numbers.output_feature_dimension = out_spec.index('C') - dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') - dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') - - dimension_numbers.kernel_spatial_dimensions.extend( - i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) - dimension_numbers.input_spatial_dimensions.extend( - sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(lhs_spec[i]))) - dimension_numbers.output_spatial_dimensions.extend( - sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(out_spec[i]))) - return self._client.ConvGeneralDilated( - lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count) - - def Sort(self, operand, dimension=-1): - """Enqueues a sort operation onto the computation.""" - return self._client.Sort(operand, dimension) - - def SortKeyVal(self, keys, values, dimension=-1): - """Enqueues a key-value sort operation onto the computation.""" - return self._client.SortKeyVal(keys, values, dimension) - - def Cholesky(self, a, lower=True): - """Enqueues a Cholesky decomposition onto the computation.""" - return self._client.Cholesky(a, lower) - - def QR(self, a, full_matrices=True): - """Enqueues a QR decomposition onto the computation.""" - return self._client.QR(a, full_matrices) - - def TriangularSolve(self, - a, - b, - left_side=False, - lower=False, - transpose_a=False, - conjugate_a=False, - unit_diagonal=False): - """Enqueues a triangular-solve operation onto the computation.""" - if not transpose_a: - transpose = 1 - if conjugate_a: - a = self.Conj(a) - else: - transpose = 3 if conjugate_a else 2 - return self._client.TriangularSolve(a, b, left_side, lower, unit_diagonal, - transpose) - - def Gather(self, a, start_indices, dimension_numbers, slice_sizes): - """Enqueues a Gather operation onto the computation.""" - return self._client.Gather(a, start_indices, dimension_numbers, slice_sizes) - - def Scatter(self, a, scatter_indices, updates, update_computation, - dimension_numbers): - """Enqueues a Scatter operation onto the computation.""" - return self._client.Scatter( - a, scatter_indices, updates, update_computation.computation, - dimension_numbers) +_OTHER_OPS = [ + 'BitcastConvertType', + 'Broadcast', + 'BroadcastInDim', + 'Cholesky', + 'Clamp', + 'Collapse', + 'CollectivePermute', + 'ConvertElementType', + 'Dot', + 'Gather', + 
'GetTupleElement', + 'Rev', + 'Select', + 'SliceInDim', +] def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. - Set up methods, corresponding to unary and binary XLA operations, + Set up methods, corresponding to XLA operations, whose calls are forwarded in a boilerplate manner to the underlying - ComputationBuilder C-extension API. + _xla.ops API. """ - def forward_to_local_builder_with_handles(target_method, is_binop=False): - """Generate a forwarding method that wraps/unwraps data handles.""" + def forward_op(target_method): - def forward(self, *args, **kwargs): - arg_list = list(args) - - if is_binop and len(arg_list) < 3: - arg_list.append(kwargs.get('broadcast_dimensions', ())) - - return target_method( - self._client, # pylint: disable=protected-access - *arg_list) + def forward(builder, *args, **kwargs): + del builder + return target_method(*args, **kwargs) return forward - for method_name in _UNARY_OPS: - forward = forward_to_local_builder_with_handles( - getattr(c_api.ComputationBuilder, method_name)) - forward.__name__ = method_name - setattr(ComputationBuilder, method_name, forward) - - for method_name in _BINARY_OPS: - forward = forward_to_local_builder_with_handles( - getattr(c_api.ComputationBuilder, method_name), is_binop=True) + for method_name in itertools.chain(_UNARY_OPS, _BINARY_OPS, _OTHER_OPS): + forward = forward_op(getattr(ops, method_name)) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) _forward_methods_to_local_builder() -_default_replica_count = 1 - - -def initialize_replica_count(replica_count): - """Initializes the default replica count to use. - - Deprecated; pass `num_replicas` as an option to `Computation.Compile()` - instead. - - Args: - replica_count: number of replicas that are desired for set up during XLA - initialization. - - Raises: - A runtime exception if the XLA service has already been initialized. - """ - global _default_replica_count - _default_replica_count = replica_count - - -def get_replica_count(): - """Returns the default replica count. - - Deprecated; pass `num_replicas` as an option to `Computation.Compile()` - instead. - """ - return _default_replica_count - - -def initialize_platform_name(platform_name): - """Initializes the default platform name to use for XLA. - - Args: - platform_name: string name of platform. - """ - global _default_platform_name - _default_platform_name = platform_name - - # Make sure the platform is valid by trying to instantiate it. - _get_default_local_backend() - def register_cpu_custom_call_target(name, fn): """Registers a CPU custom call target. @@ -1860,7 +1540,7 @@ def register_cpu_custom_call_target(name, fn): name: bytes containing the name of the function. fn: a PyCapsule object containing the function pointer. 
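# Illustrative sketch: pairing register_cpu_custom_call_target with
# ComputationBuilder.CustomCall, using the same arguments as the CustomCall
# test further down. `capsule` is a stand-in for a PyCapsule produced by a
# helper such as the custom_call_for_test module the tests import; it is not
# defined in this sketch, so the registration line is left commented out.
import numpy as np
from tensorflow.compiler.xla.python import xla_client

# xla_client.register_cpu_custom_call_target(b"test_subtract_f32", capsule)

c = xla_client.ComputationBuilder("custom_call_example")
c.CustomCall(
    b"test_subtract_f32",
    operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
    shape_with_layout=xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
    operand_shapes_with_layout=(
        xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
        xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
    ))
# Once the target is registered, compiling and executing this computation
# yields 1.25 - 0.5 = 0.75, which is what the test asserts.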
""" - c_api.RegisterCpuCustomCallTarget(name, fn) + _xla.RegisterCpuCustomCallTarget(name, fn) class PaddingConfigDimension(object): @@ -1868,9 +1548,9 @@ class PaddingConfigDimension(object): __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') def __init__(self): - self.edge_padding_low = [] - self.edge_padding_high = [] - self.interior_padding = [] + self.edge_padding_low = 0 + self.edge_padding_high = 0 + self.interior_padding = 0 class PaddingConfig(object): @@ -1971,3 +1651,14 @@ def _make_replica_group_proto(replica_group): replica_group_proto = ReplicaGroup() replica_group_proto.replica_ids.extend(replica_group) return replica_group_proto + + +def _get_replica_groups_protos(replica_groups): + if replica_groups is None: + replica_groups_protos = [] # special value for XLA API + else: + replica_groups = list(replica_groups) + replica_groups_protos = [ + _make_replica_group_proto(group) for group in replica_groups + ] + return replica_groups_protos diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 65594f5669d..682a6c099a6 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -24,24 +24,11 @@ import threading import numpy as np -from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import custom_call_for_test from tensorflow.compiler.xla.python import xla_client import unittest -class EnumTest(unittest.TestCase): - """Verifies Python enumerations match their protocol buffer equivalents.""" - - def testPrimitiveType(self): - for name, value in xla_client.PrimitiveType.__members__.items(): - self.assertEqual(value, getattr(xla_data_pb2, name)) - - def testFormat(self): - for name, value in xla_client.Format.__members__.items(): - self.assertEqual(value, getattr(xla_data_pb2, name)) - - class ComputationTest(unittest.TestCase): """Base class for running an XLA Computation through the local client.""" @@ -51,8 +38,8 @@ class ComputationTest(unittest.TestCase): return xla_client.ComputationBuilder(name) def _Execute(self, c, arguments): - compiled_c = c.Build().CompileWithExampleArguments(arguments) - return compiled_c.ExecuteWithPythonValues(arguments) + compiled_c = c.Build().Compile() + return xla_client.execute_with_python_values(compiled_c, arguments) def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): assert expected is not None @@ -66,11 +53,15 @@ class ComputationTest(unittest.TestCase): def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) - def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7, + def _ExecuteAndCompareClose(self, + c, + arguments=(), + expected=None, + rtol=1e-7, atol=0): self._ExecuteAndAssertWith( - functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), - c, arguments, expected) + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), c, + arguments, expected) def NumpyArrayF32(*args, **kwargs): @@ -123,14 +114,12 @@ class ComputationsWithConstantsTest(ComputationTest): def testConstantScalarSumS8(self): c = self._NewComputation() - root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) - self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) + c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) self._ExecuteAndCompareExact(c, expected=np.int8(3)) def testConstantScalarSumF32(self): c = 
self._NewComputation() - root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) - self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) + c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) self._ExecuteAndCompareClose(c, expected=4.25) def testConstantScalarSumF64(self): @@ -148,6 +137,14 @@ class ComputationsWithConstantsTest(ComputationTest): c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) self._ExecuteAndCompareClose(c, expected=3) + def testConstantVectorMulF16(self): + c = self._NewComputation() + c.Mul( + c.Constant(np.array([2.5, 3.3, -1.2, 0.7], np.float16)), + c.Constant(np.array([-1.2, 2, -2, -3], np.float16))) + self._ExecuteAndCompareClose( + c, expected=np.array([-3, 6.6, 2.4, -2.1], np.float16), rtol=2e-3) + def testConstantVectorMulF32(self): c = self._NewComputation() c.Mul( @@ -227,20 +224,19 @@ class ComputationsWithConstantsTest(ComputationTest): def testShiftLeft(self): c = self._NewComputation() - c.ShiftLeft(c.Constant(NumpyArrayS32([3])), - c.Constant(NumpyArrayS32([2]))) + c.ShiftLeft(c.Constant(NumpyArrayS32([3])), c.Constant(NumpyArrayS32([2]))) self._ExecuteAndCompareClose(c, expected=[12]) def testShiftRightArithmetic(self): c = self._NewComputation() - c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])), - c.Constant(NumpyArrayS32([1]))) + c.ShiftRightArithmetic( + c.Constant(NumpyArrayS32([-2])), c.Constant(NumpyArrayS32([1]))) self._ExecuteAndCompareClose(c, expected=[-1]) def testShiftRightLogical(self): c = self._NewComputation() - c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])), - c.Constant(NumpyArrayS32([1]))) + c.ShiftRightLogical( + c.Constant(NumpyArrayS32([-1])), c.Constant(NumpyArrayS32([1]))) self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) def testSum2DF64(self): @@ -319,10 +315,11 @@ class ComputationsWithConstantsTest(ComputationTest): c.CustomCall( b"test_subtract_f32", operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), - shape_with_layout=xla_client.Shape.array_shape(np.float32, (), ()), + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), ()), operand_shapes_with_layout=( - xla_client.Shape.array_shape(np.float32, (), ()), - xla_client.Shape.array_shape(np.float32, (), ()), + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), )) self._ExecuteAndCompareClose(c, expected=0.75) @@ -407,12 +404,12 @@ class ParametersTest(ComputationTest): expected=[-4.3, 1.3, -6.3, 3.3]) -class LocalBufferTest(ComputationTest): - """Tests focusing on execution with LocalBuffers.""" +class BufferTest(ComputationTest): + """Tests focusing on execution with Buffers.""" def _Execute(self, c, arguments): - compiled_c = c.Build().CompileWithExampleArguments(arguments) - arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments] + compiled_c = c.Build().Compile() + arg_buffers = [xla_client.Buffer.from_pyval(arg) for arg in arguments] result_buffer = compiled_c.Execute(arg_buffers) return result_buffer.to_py() @@ -425,41 +422,39 @@ class LocalBufferTest(ComputationTest): c = self._NewComputation() c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) self._ExecuteAndCompareClose( - c, - arguments=[NumpyArrayF32(1.11)], - expected=4.25) + c, arguments=[NumpyArrayF32(1.11)], expected=4.25) def testTwoParameterSum(self): c = self._NewComputation() - c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), - c.ParameterFromNumpy(NumpyArrayF32(0.))) + c.Add( + 
c.ParameterFromNumpy(NumpyArrayF32(0.)), + c.ParameterFromNumpy(NumpyArrayF32(0.))) self._ExecuteAndCompareClose( - c, - arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)], - expected=4.25) + c, arguments=[NumpyArrayF32(1.11), + NumpyArrayF32(3.14)], expected=4.25) def testCannotCallWithDeletedBuffers(self): c = self._NewComputation() c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) arg = NumpyArrayF32(1.11) - compiled_c = c.Build().CompileWithExampleArguments([arg]) - arg_buffer = xla_client.LocalBuffer.from_pyval(arg) + compiled_c = c.Build().Compile() + arg_buffer = xla_client.Buffer.from_pyval(arg) arg_buffer.delete() - with self.assertRaises(ValueError): + with self.assertRaises(RuntimeError): compiled_c.Execute([arg_buffer]) def testDestructureTupleEmpty(self): t = () - local_buffer = xla_client.LocalBuffer.from_pyval(t) + local_buffer = xla_client.Buffer.from_pyval(t) pieces = local_buffer.destructure() - self.assertTrue(local_buffer.is_deleted()) + self.assertFalse(local_buffer.is_deleted()) self.assertEqual(len(pieces), 0) def testDestructureTupleOneArrayElement(self): t = (np.array([1, 2, 3, 4], dtype=np.int32),) - local_buffer = xla_client.LocalBuffer.from_pyval(t) + local_buffer = xla_client.Buffer.from_pyval(t) pieces = local_buffer.destructure() - self.assertTrue(local_buffer.is_deleted()) + self.assertFalse(local_buffer.is_deleted()) self.assertEqual(len(pieces), 1) array = pieces[0] got = array.to_py() @@ -467,25 +462,30 @@ class LocalBufferTest(ComputationTest): np.testing.assert_equal(want, got) def testDestructureTupleTwoArrayElementDifferentType(self): - t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), - np.array([2, 3, 4, 5], dtype=np.int32)) - local_buffer = xla_client.LocalBuffer.from_pyval(t) - pieces = local_buffer.destructure() - self.assertTrue(local_buffer.is_deleted()) - self.assertEqual(len(pieces), 2) - array0, array1 = pieces - got = array0.to_py() - want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) - np.testing.assert_equal(want, got) - got = array1.to_py() - want = NumpyArrayS32([2, 3, 4, 5]) - np.testing.assert_equal(want, got) + t = ( + np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + np.array([2, 3, 4, 5], dtype=np.int32), + ) + local_buffer = xla_client.Buffer.from_pyval(t) + # Run the test twice to verify that the original tuple buffer remains valid + # even after destructuring. 
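# Illustrative sketch: the Buffer semantics these updated tests encode after
# the LocalBuffer -> Buffer rename. Destructuring a tuple buffer no longer
# invalidates it (is_deleted() stays False), and buffers can be recombined
# with Buffer.make_tuple; a hedged aside distilled from the assertions below.
import numpy as np
from tensorflow.compiler.xla.python import xla_client

b0 = xla_client.Buffer.from_pyval(np.array([1.0, 2.0], dtype=np.float32))
b1 = xla_client.Buffer.from_pyval(np.array([3, 4], dtype=np.int32))
tup = xla_client.Buffer.make_tuple([b0, b1], device=0)
pieces = tup.destructure()      # one Buffer per tuple element
assert not tup.is_deleted()     # the parent tuple buffer remains valid
assert len(pieces) == 2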
+ for _ in range(2): + pieces = local_buffer.destructure() + self.assertFalse(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + array0, array1 = pieces + got = array0.to_py() + want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) + np.testing.assert_equal(want, got) + got = array1.to_py() + want = NumpyArrayS32([2, 3, 4, 5]) + np.testing.assert_equal(want, got) def testDestructureTupleNested(self): t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5])) - local_buffer = xla_client.LocalBuffer.from_pyval(t) + local_buffer = xla_client.Buffer.from_pyval(t) pieces = local_buffer.destructure() - self.assertTrue(local_buffer.is_deleted()) + self.assertFalse(local_buffer.is_deleted()) self.assertEqual(len(pieces), 2) tuple0, array1 = pieces got = array1.to_py() @@ -497,11 +497,27 @@ class LocalBufferTest(ComputationTest): np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) + def testMakeTuple(self): + t = ( + np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + np.array([2, 3, 4, 5], dtype=np.int32), + ) + b0 = xla_client.Buffer.from_pyval(t[0]) + b1 = xla_client.Buffer.from_pyval(t[1]) + btup = xla_client.Buffer.make_tuple([b0, b1], device=0) + pieces = btup.destructure() + self.assertEqual(len(pieces), 2) + array0, array1 = pieces + np.testing.assert_equal( + np.array([1, 2, 3, 4], dtype=np.float32), array0.to_py()) + np.testing.assert_equal( + np.array([2, 3, 4, 5], dtype=np.int32), array1.to_py()) + def testShape(self): pyval = np.array([[1., 2.]], np.float32) - local_buffer = xla_client.LocalBuffer.from_pyval(pyval) + local_buffer = xla_client.Buffer.from_pyval(pyval) xla_shape = local_buffer.shape() - self.assertEqual(xla_shape.dimensions(), (1, 2,)) + self.assertEqual(xla_shape.dimensions(), (1, 2)) self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) @@ -515,18 +531,20 @@ class SingleOpTest(ComputationTest): def testConcatenateF32(self): c = self._NewComputation() - c.Concatenate( - (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), - c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))), - dimension=0) + args = ( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])), + ) + c.Concatenate(args, dimension=0) self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) def testConcatenateF64(self): c = self._NewComputation() - c.Concatenate( - (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), - c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))), - dimension=0) + args = ( + c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])), + ) + c.Concatenate(args, dimension=0) self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) def testConvertElementType(self): @@ -543,7 +561,7 @@ class SingleOpTest(ComputationTest): x = c.Constant(np.array(template, dtype=src_dtype)) c.ConvertElementType(x, xla_types[dst_dtype]) - result = c.Build().Compile().ExecuteWithPythonValues() + result = xla_client.execute_with_python_values(c.Build().Compile()) expected = np.array(template, dtype=dst_dtype) self.assertEqual(result.shape, expected.shape) @@ -570,7 +588,7 @@ class SingleOpTest(ComputationTest): x = c.Constant(np.array(template, dtype=src_dtype)) c.BitcastConvertType(x, dst_etype) - result = c.Build().Compile().ExecuteWithPythonValues() + result = xla_client.execute_with_python_values(c.Build().Compile()) expected = np.array(template, src_dtype).view(dst_dtype) self.assertEqual(result.shape, expected.shape) @@ -680,11 +698,13 @@ 
class SingleOpTest(ComputationTest): a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") lhs = a(1, 2, 3, 4) rhs = a(1, 2, 1, 2) * 10 - c.Conv(c.Constant(lhs), c.Constant(rhs), - [1, 1], xla_client.PaddingType.SAME) - result = np.array([[[[640., 700., 760., 300.], - [880., 940., 1000., 380.], - [1120., 1180., 1240., 460.]]]]) + c.Conv( + c.Constant(lhs), c.Constant(rhs), [1, 1], xla_client.PaddingType.SAME) + result = np.array([[[ + [640., 700., 760., 300.], + [880., 940., 1000., 380.], + [1120., 1180., 1240., 460.], + ]]]) self._ExecuteAndCompareClose(c, expected=result) def testConvF32Valid(self): @@ -692,10 +712,12 @@ class SingleOpTest(ComputationTest): a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") lhs = a(1, 2, 3, 4) rhs = a(1, 2, 1, 2) * 10 - c.Conv(c.Constant(lhs), c.Constant(rhs), - [2, 1], xla_client.PaddingType.VALID) - result = np.array([[[[640., 700., 760.], - [1120., 1180., 1240.]]]]) + c.Conv( + c.Constant(lhs), c.Constant(rhs), [2, 1], xla_client.PaddingType.VALID) + result = np.array([[[ + [640., 700., 760.], + [1120., 1180., 1240.], + ]]]) self._ExecuteAndCompareClose(c, expected=result) def testConvWithGeneralPaddingF32(self): @@ -707,12 +729,15 @@ class SingleOpTest(ComputationTest): pads = [(1, 0), (0, 1)] lhs_dilation = (2, 1) rhs_dilation = (1, 1) - c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs), - strides, pads, lhs_dilation, rhs_dilation) - result = np.array([[[[0., 0., 0.], - [10., 20., 0.], - [0., 0., 0.], - [40., 50., 0.]]]]) + c.ConvWithGeneralPadding( + c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, + rhs_dilation) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) self._ExecuteAndCompareClose(c, expected=result) def testConvGeneralDilatedF32(self): @@ -725,13 +750,15 @@ class SingleOpTest(ComputationTest): lhs_dilation = (2, 1) rhs_dilation = (1, 1) dimension_numbers = ("NCHW", "OIHW", "NCHW") - c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), - strides, pads, lhs_dilation, rhs_dilation, - dimension_numbers) - result = np.array([[[[0., 0., 0.], - [10., 20., 0.], - [0., 0., 0.], - [40., 50., 0.]]]]) + c.ConvGeneralDilated( + c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, + rhs_dilation, dimension_numbers) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) self._ExecuteAndCompareClose(c, expected=result) def testConvGeneralDilatedPermutedF32(self): @@ -745,13 +772,10 @@ class SingleOpTest(ComputationTest): rhs_dilation = (1, 1) dimension_numbers = ("NHWC", "OIHW", "CWNH") - c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))), - c.Constant(rhs), - strides, pads, lhs_dilation, rhs_dilation, - dimension_numbers) - result = np.array([[[[0., 0., 0.], - [10., 20., 0.], - [0., 0., 0.], + c.ConvGeneralDilated( + c.Constant(np.transpose(lhs, (0, 2, 3, 1))), c.Constant(rhs), strides, + pads, lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) @@ -766,17 +790,20 @@ class SingleOpTest(ComputationTest): rhs_dilation = (1, 1) dimension_numbers = ("NCHW", "OIHW", "NCHW") feature_group_count = 2 - c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), - strides, pads, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count) - result = np.array([[[[0., 0., 0.], - [10., 20., 0.], - [0., 0., 0.], - [40., 
50., 0.]], - [[0., 0., 0.], - [330., 380., 160.], - [0., 0., 0.], - [480., 530., 220.]]]]) + c.ConvGeneralDilated( + c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, + rhs_dilation, dimension_numbers, feature_group_count) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ], [ + [0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.], + ]]]) self._ExecuteAndCompareClose(c, expected=result) def testBooleanNot(self): @@ -967,14 +994,11 @@ class SingleOpTest(ComputationTest): c = self._NewComputation() c.Pad( c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), - c.Constant(NumpyArrayF32(0.0)), - [(1, 2, 1), (0, 1, 0)]) - self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], - [1.0, 2.0, 0.0], - [0.0, 0.0, 0.0], - [3.0, 4.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0]]) + c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)]) + self._ExecuteAndCompareClose( + c, + expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) def testPadWithPaddingConfig(self): c = self._NewComputation() @@ -987,14 +1011,11 @@ class SingleOpTest(ComputationTest): padding_config.dimensions.append(dimension) c.Pad( c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), - c.Constant(NumpyArrayF32(0.0)), - padding_config) - self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], - [1.0, 2.0, 0.0], - [0.0, 0.0, 0.0], - [3.0, 4.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0]]) + c.Constant(NumpyArrayF32(0.0)), padding_config) + self._ExecuteAndCompareClose( + c, + expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) def testReshape(self): c = self._NewComputation() @@ -1087,7 +1108,7 @@ class SingleOpTest(ComputationTest): c.Tuple( c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), c.Constant(NumpyArrayBool([True, False, False, True]))) - result = c.Build().Compile().ExecuteWithPythonValues() + result = xla_client.execute_with_python_values(c.Build().Compile()) self.assertIsInstance(result, tuple) np.testing.assert_equal(result[0], 42) np.testing.assert_allclose(result[1], [1.0, 2.0]) @@ -1117,9 +1138,11 @@ class SingleOpTest(ComputationTest): def testRngNormal(self): shape = (2, 3) c = self._NewComputation() - c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)), - dims=shape) - result = c.Build().Compile().ExecuteWithPythonValues() + c.RngNormal( + c.Constant(NumpyArrayF32(0.)), + c.Constant(NumpyArrayF32(1.)), + dims=shape) + result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape and uniqueness self.assertEqual(result.shape, shape) self.assertEqual(len(np.unique(result)), np.prod(shape)) @@ -1128,9 +1151,11 @@ class SingleOpTest(ComputationTest): lo, hi = 2., 4. 
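# Illustrative sketch: building a PaddingConfig by hand, combining the updated
# PaddingConfigDimension fields (now plain ints) with the Pad call and padding
# triples used by the Pad tests above. The construction of PaddingConfig and
# PaddingConfigDimension shown here is inferred from those tests, so treat it
# as a hedged aside.
import numpy as np
from tensorflow.compiler.xla.python import xla_client

padding_config = xla_client.PaddingConfig()
for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
  dim = xla_client.PaddingConfigDimension()
  dim.edge_padding_low = lo
  dim.edge_padding_high = hi
  dim.interior_padding = interior
  padding_config.dimensions.append(dim)

c = xla_client.ComputationBuilder("pad_example")
c.Pad(c.Constant(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)),
      c.Constant(np.float32(0.0)),
      padding_config)
# Equivalent to passing the raw triples [(1, 2, 1), (0, 1, 0)] as in testPad.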
shape = (2, 3) c = self._NewComputation() - c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)), - dims=shape) - result = c.Build().Compile().ExecuteWithPythonValues() + c.RngUniform( + c.Constant(NumpyArrayF32(lo)), + c.Constant(NumpyArrayF32(hi)), + dims=shape) + result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape, uniqueness, and range self.assertEqual(result.shape, shape) self.assertEqual(len(np.unique(result)), np.prod(shape)) @@ -1141,9 +1166,11 @@ class SingleOpTest(ComputationTest): lo, hi = 2, 4 shape = (2, 3) c = self._NewComputation() - c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)), - dims=shape) - result = c.Build().Compile().ExecuteWithPythonValues() + c.RngUniform( + c.Constant(NumpyArrayS32(lo)), + c.Constant(NumpyArrayS32(hi)), + dims=shape) + result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape, integrality, and range self.assertEqual(result.shape, shape) self.assertEqual(result.dtype, np.int32) @@ -1157,6 +1184,23 @@ class SingleOpTest(ComputationTest): c.Cholesky(c.Constant(np.dot(l, l.T))) self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) + def testSort(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + c = self._NewComputation() + c.Sort(c.Constant(keys)) + self._ExecuteAndCompareClose( + c, expected=np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)) + + def testSortKeyVal(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + c.SortKeyVal(c.Constant(keys), c.Constant(values), dimension=0) + result = xla_client.execute_with_python_values(c.Build().Compile()) + self.assertIsInstance(result, tuple) + np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) + np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) + def testQR(self): a = np.array( [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], @@ -1166,6 +1210,27 @@ class SingleOpTest(ComputationTest): q, r = self._Execute(c, ()) np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + def testEigh(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + a = (a + a.T) / 2 + + c = self._NewComputation() + c.Eigh(c.Constant(a), full_matrices=True) + # TODO(b/129396575): Turn this test back on when it passes without fastmath. 
+ # v, w = self._Execute(c, ()) + # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) + + def testSVD(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + c.SVD(c.Constant(a)) + u, d, v = self._Execute(c, ()) + self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) + def testTriangularSolve(self): a_vals = np.array( [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], @@ -1174,13 +1239,21 @@ class SingleOpTest(ComputationTest): dtype=np.float32) c = self._NewComputation() - c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False, - lower=True, transpose_a=True) - self._ExecuteAndCompareClose(c, expected=np.array([ - [0.5, 0.08333334, 0.04629629, 0.03367003], - [2.5, -0.25, -0.1388889, -0.1010101], - [4.5, -0.58333331, -0.32407406, -0.23569024], - ], dtype=np.float32), rtol=1e-4) + c.TriangularSolve( + c.Constant(a_vals), + c.Constant(b_vals), + left_side=False, + lower=True, + transpose_a=True) + self._ExecuteAndCompareClose( + c, + expected=np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], + dtype=np.float32), + rtol=1e-4) def testIsConstant(self): c = self._NewComputation() @@ -1208,6 +1281,33 @@ class SingleOpTest(ComputationTest): expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) np.testing.assert_allclose(g, expected, rtol=1e-4) + def testFft(self): + shape = [2, 3, 4, 5] + rng = np.random.RandomState(0) + a = rng.randn(*shape) + 1.0j * rng.randn(*shape) + a = a.astype(np.complex64) + # FFT + c = self._NewComputation() + c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:]) + self._ExecuteAndCompareClose(c, expected=np.fft.fftn(a, axes=(1, 2, 3)), + rtol=1e-4) + # IFFT + c = self._NewComputation() + c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:]) + self._ExecuteAndCompareClose(c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), + rtol=1e-4) + # RFFT + b = rng.randn(*shape).astype(np.float32) + c = self._NewComputation() + c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:]) + self._ExecuteAndCompareClose(c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), + rtol=1e-4) + # IRFFT + c = self._NewComputation() + c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8]) + self._ExecuteAndCompareClose(c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), + rtol=1e-4) + class EmbeddedComputationsTest(ComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" @@ -1255,8 +1355,9 @@ class EmbeddedComputationsTest(ComputationTest): def _CreateMulF32ByParamComputation(self): """Computation (f32) -> f32 that multiplies one parameter by the other.""" c = self._NewComputation("mul_f32_by_param") - c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), - c.ParameterFromNumpy(NumpyArrayF32(0))) + c.Mul( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) return c.Build() def _CreateMulF64By2Computation(self): @@ -1320,15 +1421,17 @@ class EmbeddedComputationsTest(ComputationTest): def _CreateBinaryGeF32Computation(self): """Computation (f32, f32) -> bool that tests first_param >= second_param.""" c = self._NewComputation("param0_lt_param1") - c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)), - c.ParameterFromNumpy(NumpyArrayF32(0))) + c.Ge( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) return c.Build() def _CreateBinaryGeF64Computation(self): """Computation (f64, f64) -> 
bool that tests first_param >= second_param.""" c = self._NewComputation("param0_lt_param1") - c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)), - c.ParameterFromNumpy(NumpyArrayF64(0))) + c.Ge( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) return c.Build() def _MakeSample3DArrayF32(self): @@ -1409,26 +1512,28 @@ class EmbeddedComputationsTest(ComputationTest): def testSelectAndScatterF32(self): c = self._NewComputation() - c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), - select=self._CreateBinaryGeF32Computation(), - window_dimensions=(2, 1), - window_strides=(1, 2), - padding=xla_client.PaddingType.VALID, - source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), - init_value=c.Constant(NumpyArrayF32(1)), - scatter=self._CreateBinaryAddF32Computation()) + c.SelectAndScatter( + c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), + select=self._CreateBinaryGeF32Computation(), + window_dimensions=(2, 1), + window_strides=(1, 2), + padding=xla_client.PaddingType.VALID, + source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), + init_value=c.Constant(NumpyArrayF32(1)), + scatter=self._CreateBinaryAddF32Computation()) self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) def testSelectAndScatterF64(self): c = self._NewComputation() - c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])), - select=self._CreateBinaryGeF64Computation(), - window_dimensions=(2, 1), - window_strides=(1, 2), - padding=xla_client.PaddingType.VALID, - source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), - init_value=c.Constant(NumpyArrayF64(1)), - scatter=self._CreateBinaryAddF64Computation()) + c.SelectAndScatter( + c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])), + select=self._CreateBinaryGeF64Computation(), + window_dimensions=(2, 1), + window_strides=(1, 2), + padding=xla_client.PaddingType.VALID, + source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), + init_value=c.Constant(NumpyArrayF64(1)), + scatter=self._CreateBinaryAddF64Computation()) self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) def testReduce1DtoScalarF32(self): @@ -1531,61 +1636,73 @@ class EmbeddedComputationsTest(ComputationTest): def testReduceWindowValidUnitStridesF32(self): input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) c = self._NewComputation() - c.ReduceWindow(operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - window_dimensions=(2, 1), window_strides=(1, 1), - padding=xla_client.PaddingType.VALID) + c.ReduceWindow( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), + window_strides=(1, 1), + padding=xla_client.PaddingType.VALID) self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) def testReduceWindowSameUnitStridesF32(self): input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) c = self._NewComputation() - c.ReduceWindow(operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - window_dimensions=(2, 1), window_strides=(1, 1), - padding=xla_client.PaddingType.SAME) + c.ReduceWindow( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), + window_strides=(1, 1), + padding=xla_client.PaddingType.SAME) self._ExecuteAndCompareClose(c, 
expected=[[5., 7., 9.], [4., 5., 6.]]) def testReduceWindowValidGeneralStridesF32(self): input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) c = self._NewComputation() - c.ReduceWindow(operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - window_dimensions=(2, 1), window_strides=(1, 2), - padding=xla_client.PaddingType.VALID) + c.ReduceWindow( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), + window_strides=(1, 2), + padding=xla_client.PaddingType.VALID) self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) def testReduceWindowValidUnitStridesF64(self): input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) c = self._NewComputation() - c.ReduceWindow(operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - window_dimensions=(2, 1), window_strides=(1, 1), - padding=xla_client.PaddingType.VALID) + c.ReduceWindow( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), + window_strides=(1, 1), + padding=xla_client.PaddingType.VALID) self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) def testReduceWindowSameUnitStridesF64(self): input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) c = self._NewComputation() - c.ReduceWindow(operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - window_dimensions=(2, 1), window_strides=(1, 1), - padding=xla_client.PaddingType.SAME) + c.ReduceWindow( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), + window_strides=(1, 1), + padding=xla_client.PaddingType.SAME) self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) def testReduceWindowValidGeneralStridesF64(self): input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) c = self._NewComputation() - c.ReduceWindow(operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - window_dimensions=(2, 1), window_strides=(1, 2), - padding=xla_client.PaddingType.VALID) + c.ReduceWindow( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), + window_strides=(1, 2), + padding=xla_client.PaddingType.VALID) self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) def testWhileF32(self): @@ -1629,29 +1746,29 @@ class EmbeddedComputationsTest(ComputationTest): def testInfeedS32Values(self): to_infeed = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() - c.Infeed(xla_client.Shape.from_pyval(to_infeed[0])) - compiled_c = c.Build().CompileWithExampleArguments() + c.Infeed(xla_client.shape_from_pyval(to_infeed[0])) + compiled_c = c.Build().Compile() for item in to_infeed: xla_client.transfer_to_infeed(item) for item in to_infeed: - result = compiled_c.ExecuteWithPythonValues() + result = xla_client.execute_with_python_values(compiled_c) self.assertEqual(result, item) def testInfeedThenOutfeedS32(self): to_round_trip = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() - x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0])) + 
x = c.Infeed(xla_client.shape_from_pyval(to_round_trip[0])) c.Outfeed(x) - compiled_c = c.Build().CompileWithExampleArguments() + compiled_c = c.Build().Compile() for want in to_round_trip: - execution = threading.Thread(target=compiled_c.Execute) + execution = threading.Thread(target=lambda: compiled_c.Execute([])) execution.start() xla_client.transfer_to_infeed(want) got = xla_client.transfer_from_outfeed( - xla_client.Shape.from_pyval(to_round_trip[0])) + xla_client.shape_from_pyval(to_round_trip[0])) execution.join() self.assertEqual(want, got) @@ -1667,8 +1784,9 @@ class EmbeddedComputationsTest(ComputationTest): dnums.index_vector_dim = 1 c = self._NewComputation() - c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), - self._CreateBinaryAddS32Computation(), dnums) + c.Scatter( + c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), + self._CreateBinaryAddS32Computation(), dnums) expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) self._ExecuteAndCompareClose(c, expected=expected) @@ -1679,15 +1797,37 @@ class ErrorTest(ComputationTest): self.f32_scalar_2 = NumpyArrayF32(2.0) self.s32_scalar_2 = NumpyArrayS32(2) + def testCompileWithWrongElementTypeInLayout(self): + c = self._NewComputation() + c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) + c.ParameterFromNumpy(self.s32_scalar_2) + c.ClearOpMetadata() + + options = xla_client.CompileOptions() + options.argument_layouts = [ + xla_client.Shape.array_shape(np.dtype(np.float32), []) + ] + + def TestFun(): + return c.Build().Compile(compile_options=options) + + self.assertRaisesRegexp( + RuntimeError, r".*Invalid argument shape.*" + r"expected s32\[\], got f32\[\].*", TestFun) + def testInvokeWithWrongElementType(self): c = self._NewComputation() c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) c.ParameterFromNumpy(self.s32_scalar_2) c.ClearOpMetadata() + + def TestFun(): + return xla_client.execute_with_python_values(c.Build().Compile(), + [self.f32_scalar_2]) + self.assertRaisesRegexp( - RuntimeError, r"Invalid argument shape.*xla_client_test.py.*" - r"expected s32\[\], got f32\[\]", - lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) + RuntimeError, r"Invalid argument: Argument does not match.*" + r"want s32\[\], got f32\[\].*", TestFun) class ComputationRootTest(ComputationTest): @@ -1700,8 +1840,8 @@ class ComputationRootTest(ComputationTest): extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable arg = NumpyArrayF32(1.0) - compiled_c = c.Build(result).CompileWithExampleArguments([arg]) - ans = compiled_c.ExecuteWithPythonValues([arg]) + compiled_c = c.Build(result).Compile() + ans = xla_client.execute_with_python_values(compiled_c, [arg]) np.testing.assert_allclose(ans, 4.14) diff --git a/tensorflow/compiler/xla/python/xla_data.i b/tensorflow/compiler/xla/python/xla_data.i deleted file mode 100644 index b18583c64d4..00000000000 --- a/tensorflow/compiler/xla/python/xla_data.i +++ /dev/null @@ -1,654 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
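# Illustrative sketch: the compile-options flow the updated ErrorTest
# exercises, now that Compile() takes a CompileOptions object instead of
# example arguments. Here the argument layout matches the s32 parameter, so
# compilation succeeds; the test deliberately passes an f32 layout to trigger
# the error. A hedged aside, not part of the deleted file that follows.
import numpy as np
from tensorflow.compiler.xla.python import xla_client

c = xla_client.ComputationBuilder("compile_options_example")
c.ParameterFromNumpy(np.array(2, dtype=np.int32))

options = xla_client.CompileOptions()
options.argument_layouts = [xla_client.Shape.array_shape(np.dtype(np.int32), [])]
compiled = c.Build().Compile(compile_options=options)
result = xla_client.execute_with_python_values(
    compiled, [np.array(2, dtype=np.int32)])    # -> 2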
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// SWIG typemaps for building, compiling, and executing XLA computations. -// -// The typemaps below implement/assert the following correspondences -// (with elaborations below): -// -// C++ Python -// -------------------------------------+--------------------------------------- -// Span <- sequence of int -// vector -> sequence of int -// Span <- sequence of LocalOp -// Literal <-> (nested tuple of) numpy ndarray -// std::vector <- sequence of (nested tuple of) ndarray -// Shape -> pair holding (dtype, dimensions) -// <- object duck-typed as xla_client.Shape -// ProgramShape -> pair of ([arg_shapes], ret_shape) -// std::vector <- sequence of xla_client.Shape objects -// PrimitiveType <- int -// Span> <- sequence of int pairs -// PaddingConfig proto <- ducktyped Python proto -// ConvolutionDimensionNumbers proto <- ducktyped Python proto -// DotDimensionNumbers proto <- ducktyped Python proto -// GatherDimensionNumbers proto <- ducktyped Python proto -// ScatterDimensionNumbers proto <- ducktyped Python proto -// Span <- sequence of ReplicaGroup Python proto -// -// Arrows indicate whether a conversion only ever occurs in one -// direction, or whether it is maintained bidirectionally. -// -// The Python objects corresponding to C++ Literals have the type: -// -// T = ndarray | (T, ...) -// -// where a terminal numpy ndarray translates to a Literal with a -// non-tuple Shape, an XLA primitive element type corresponding to the -// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates -// to a tuple-shaped Literal whose tuple components are translated -// recursively. For example, if x is a numpy ndarray in Python, with -// shape (2, 3) and dtype of dtype('float32'), then x translates to a -// Literal with rank 2, dimension 2 and 3, and XLA primitive type -// F32. Meanwhile, -// -// (x, (x, x), (x,)), -// -// translates to a tuple-shaped XLA Literal, whose component subshapes -// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. -// -// Shapes output by C++ become Python objects with the type: -// -// T = (dtype, S) -// S = DIMENSIONS | TUPLE_SHAPES -// DIMENSIONS = (int, ...) -// TUPLE_SHAPES = (T, ...) -// -// In the pair described by the T rule, the terminal dtype determines -// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is -// dtype('O'), numpy's object dtype, the structure represents a tuple -// shape and the expansion of the non-terminal S is -// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type -// and S expands into DIMENSIONS giving dimension sizes. For example: -// -// (dtype('float32'), (3, 5, 7)) -// -// describes a 3x5x7 array of F32s, and -// -// (dtype('O'), ((dtype('float32'), (2, 3)), -// (dtype('float64'), (4, 5)))) -// -// describes a tuple shape with two subshapes: the first a 2x3 F32, -// and the other a 4x5 F64. -// -// The Python int corresponding to a PrimitiveType enum must be valid -// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). -// -// The SWIG object wrappers generated by this file are not intended -// for end use, but rather for internal use in the Python XLA client, -// xla_client.py. -// -// One central reason for the Python-side indirection is that the -// Python-side objects produced by the typemaps in this file are -// further packaged up by xla_client before being passed on. 
For -// instance, the Python pair produced for a C++ Shape is further -// wrapped in a Python class (xla_client.Shape) so as not to expose -// the raw pair externally. -// -// Other SWIG object wrappers (e.g. of Computation) are further -// wrapped by xla_client in order to set up a custom destructor that -// triggers memory deallocation on the C++ side. -// - - -%module(threads="1") xla_data - -// Keep the GIL except where explicitly specified. -%nothread; - -%include "tensorflow/python/platform/base.i" - -%{ -// Must be included first -#include "tensorflow/python/lib/core/numpy.h" - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/python/numpy_bridge.h" - -using namespace xla; -using namespace xla::swig; - -%} - -// Basic types - - -%typemap(out) std::vector { - PyObject* out = PyList_New($1.size()); - for (int i = 0; i < $1.size(); ++i) { - PyList_SET_ITEM(out, i, PyInt_FromLong($1[i])); - } - $result = out; -} - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = PyBool_FromLong($1.ConsumeValueOrDie()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = PyString_FromString($1.ConsumeValueOrDie().c_str()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) Status { - if (!$1.ok()) { - PyErr_SetString( - PyExc_RuntimeError, $1.ToString().c_str()); - SWIG_fail; - } - Py_INCREF(Py_None); - $result = Py_None; -} - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.resize(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - temps[i] = numpy::PyIntOrPyLongToLong(py_int); - if (temps[i] == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); - SWIG_fail; - } - Py_DECREF(py_int); - Py_DECREF(o); - } - $1 = temps; -} - -// Literal - -%typemap(in) const Literal& (StatusOr literal_status) { - literal_status = numpy::XlaLiteralFromPyObject($input); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - SWIG_fail; - } - $1 = &literal_status.ValueOrDie(); -} - -%typemap(out) Literal (StatusOr obj_status) { - obj_status = numpy::PyObjectFromXlaLiteral(*$1); - if (!obj_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); - SWIG_fail; - } - $result = obj_status.ValueOrDie().release(); -} - -%typemap(out) StatusOr (StatusOr obj_status) { - if (!$1.ok()) { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } - obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); - if (!obj_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); - SWIG_fail; - } - $result = obj_status.ValueOrDie().release(); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, 
"Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - Py_DECREF(o); - SWIG_fail; - } - temps.push_back(literal_status.ConsumeValueOrDie()); - Py_DECREF(o); - } - $1 = &temps; -} - -// OpMetadata - -%typemap(in) const OpMetadata& (OpMetadata temp) { - StatusOr statusor = numpy::OpMetadataFromPyObject($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -// Shape - -%typemap(out) const Shape& { - $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); -} - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release(); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = numpy::PyProgramShapeInfoFromXlaProgramShape( - $1.ConsumeValueOrDie()).release(); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - - -%typemap(in) const Shape& (Shape temp) { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -%typemap(in) const absl::optional& ( - absl::optional temp) { - if ($input == Py_None) { - temp = absl::nullopt; - $1 = &temp; - } else { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; - } -} - -%typemap(out) std::unique_ptr { - $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - $1 = &temps; -} - -%typemap(in) const std::vector >& ( - std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (o == Py_None) { - temps.push_back(absl::nullopt); - } else { - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - } - $1 = &temps; -} - -// PrimitiveType - -%typemap(in) PrimitiveType { - PyObject* py_int = numpy::PyNumberToPyInt($input); - if (!py_int) { - PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - SWIG_fail; - } - const long value = 
numpy::PyIntOrPyLongToLong(py_int); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - SWIG_fail; - } - if (!PrimitiveType_IsValid(value)) { - PyErr_SetString( - PyExc_TypeError, "Argument not valid for PrimitiveType enum"); - Py_DECREF(py_int); - SWIG_fail; - } - $1 = static_cast(value); -} - -// Span> - -%typemap(in) absl::Span > - (std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (!o) { - SWIG_fail; - } - PyObject* first = PyTuple_GetItem(o, 0); - if (!first) { - Py_DECREF(o); - SWIG_fail; - } - PyObject* first_pyint = numpy::PyNumberToPyInt(first); - if (!first_pyint) { - PyErr_SetString( - PyExc_TypeError, - "First pair item cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - PyObject* second = PyTuple_GetItem(o, 1); - if (!second) { - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - PyObject* second_pyint = numpy::PyNumberToPyInt(second); - if (!second_pyint) { - PyErr_SetString( - PyExc_TypeError, - "Second pair item cannot be converted to int"); - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); - if (first_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); - if (second_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - temps.push_back(std::make_pair(first_value, second_value)); - Py_DECREF(o); - } - $1 = temps; -} - -// DotDimensionNumbers - -%typemap(in) const DotDimensionNumbers& - (DotDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "lhs_contracting_dimensions", - dimension_numbers.mutable_lhs_contracting_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "rhs_contracting_dimensions", - dimension_numbers.mutable_rhs_contracting_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "lhs_batch_dimensions", - dimension_numbers.mutable_lhs_batch_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "rhs_batch_dimensions", - dimension_numbers.mutable_rhs_batch_dimensions())) { - SWIG_fail; - } - - $1 = &dimension_numbers; -} - -// PaddingConfig - -%typemap(in) const PaddingConfig& - (PaddingConfig padding_config) { - PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); - if (!dimensions) { - SWIG_fail; - } - - int length = PySequence_Size(dimensions); - if (length == -1) { - Py_DECREF(dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(dimensions, i); - if (!item) { - Py_DECREF(dimensions); - SWIG_fail; - } - int64 edge_padding_low, edge_padding_high, interior_padding; - if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) - || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) - || !GetIntAttr(item, "interior_padding", &interior_padding)) { - Py_DECREF(item); - Py_DECREF(dimensions); - SWIG_fail; - } - Py_DECREF(item); - - PaddingConfig::PaddingConfigDimension* dimension = - padding_config.add_dimensions(); - dimension->set_edge_padding_low(edge_padding_low); - dimension->set_edge_padding_high(edge_padding_high); - 
dimension->set_interior_padding(interior_padding); - } - Py_DECREF(dimensions); - - $1 = &padding_config; -} - -// ConvolutionDimensionNumbers - -%typemap(in) const ConvolutionDimensionNumbers& - (ConvolutionDimensionNumbers dimension_numbers) { - int64 value; - - if (!GetIntAttr($input, "input_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_batch_dimension(value); - - if (!GetIntAttr($input, "input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_feature_dimension(value); - - if (!GetIntAttr($input, "output_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_batch_dimension(value); - - if (!GetIntAttr($input, "output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_input_feature_dimension(value); - - if (!HandleRepeatedInt64Attribute( - $input, "input_spatial_dimensions", - dimension_numbers.mutable_input_spatial_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "kernel_spatial_dimensions", - dimension_numbers.mutable_kernel_spatial_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "output_spatial_dimensions", - dimension_numbers.mutable_output_spatial_dimensions())) { - SWIG_fail; - } - - $1 = &dimension_numbers; -} - -// GatherDimensionNumbers - -%typemap(in) const GatherDimensionNumbers& - (GatherDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "offset_dims", - dimension_numbers.mutable_offset_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "collapsed_slice_dims", - dimension_numbers.mutable_collapsed_slice_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "start_index_map", - dimension_numbers.mutable_start_index_map())) { - SWIG_fail; - } - - int64 value; - if (!GetIntAttr($input, "index_vector_dim", &value)) { - SWIG_fail; - } - dimension_numbers.set_index_vector_dim(value); - - $1 = &dimension_numbers; -} - -// ScatterDimensionNumbers - -%typemap(in) const ScatterDimensionNumbers& - (ScatterDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "update_window_dims", - dimension_numbers.mutable_update_window_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "inserted_window_dims", - dimension_numbers.mutable_inserted_window_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "scatter_dims_to_operand_dims", - dimension_numbers.mutable_scatter_dims_to_operand_dims())) { - SWIG_fail; - } - - int64 value; - if (!GetIntAttr($input, "index_vector_dim", &value)) { - SWIG_fail; - } - dimension_numbers.set_index_vector_dim(value); - - $1 = &dimension_numbers; -} - -// Span - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - ReplicaGroup rgrp; - if (!HandleRepeatedInt64Attribute( - o, "replica_ids", - rgrp.mutable_replica_ids())) { - SWIG_fail; - } - temps.push_back(rgrp); - Py_DECREF(o); - } - $1 = 
temps; -} diff --git a/tensorflow/compiler/xla/python/xrt.cc b/tensorflow/compiler/xla/python/xrt.cc index 2c55abc17f8..2390de567f7 100644 --- a/tensorflow/compiler/xla/python/xrt.cc +++ b/tensorflow/compiler/xla/python/xrt.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,285 +13,197 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/xrt.h" - #include #include -#include #include "absl/memory/memory.h" -#include "tensorflow/cc/client/client_session.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/shape_util.h" +#include "absl/types/optional.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/stl.h" +#include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/compiler/xrt/client/xrt_client.h" +#include "tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" -namespace xla { -namespace swig { +namespace tensorflow { +namespace { -XrtAllocation::XrtAllocation(int64 handle, Shape shape, - const string& session_target) - : handle_(handle), shape_(shape), session_target_(session_target) {} +namespace py = pybind11; -XrtAllocation::~XrtAllocation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; +xla::StatusOr> GetTfClient(const string& address, + const string& worker) { + ClusterDef cluster_def; + JobDef* job = cluster_def.add_job(); + job->set_name(worker); + (*job->mutable_tasks())[0] = address; + ChannelCreationFunction channel_func = + ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + TF_ASSIGN_OR_RETURN(std::shared_ptr channel_cache, + GetGrpcChannelCache(cluster_def, channel_func)); + return std::make_shared(cluster_def, channel_cache); +} + +// TODO(phawkins): This function won't produce a 
particularly good device +// assignment since it knows nothing about the hardware or its topology. +// It's here mostly as a placeholder until we do something smarter. +xla::StatusOr AssignDevices(int num_replicas, + int num_computations) { + return xla::ComputationPlacer().AssignDevices(num_replicas, num_computations); +} + +xla::StatusOr>>> +ExecuteReplicated( + XrtExecutable* executable, + absl::Span>> const> + pyargs) { + const xla::DeviceAssignment& device_assignment = + executable->device_assignment(); + if (pyargs.size() != device_assignment.computation_count()) { + return xla::InvalidArgument( + "Outermost argument list must have one entry per " + "computation; " + "got %d args, device assignment has %d computations.", + pyargs.size(), device_assignment.computation_count()); } - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; - } -} - -/* static */ -StatusOr XrtAllocation::FromLiteral( - const Literal& argument, const string& session_target) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = argument.ToProto(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto literal_string = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); - TF_RETURN_IF_ERROR(root.status()); - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({literal_string, alloc.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); - - int64 handle = outputs[0].scalar()(); - return new XrtAllocation(handle, argument.shape(), session_target); -} - -const int64 XrtAllocation::handle() const { return handle_; } - -const Shape& XrtAllocation::shape() const { return shape_; } - -StatusOr XrtAllocation::ToLiteral() const { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); - TF_RETURN_IF_ERROR(root.status()); - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); - - xla::LiteralProto response; - TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); - return Literal::CreateFromProto(response); -} - -XrtAllocationTuple::XrtAllocationTuple(std::vector elements) - : elements_(std::move(elements)) { - for (auto* element : elements_) { - CHECK(element != nullptr); - } -} - -XrtAllocationTuple::~XrtAllocationTuple() { - for (XrtAllocation* element : elements_) { - if (element != nullptr) { - delete element; + std::vector>> args(pyargs.size()); + for (int i = 0; i < pyargs.size(); ++i) { + if (pyargs[i].size() != device_assignment.replica_count() || + pyargs[i].empty()) { + return xla::InvalidArgument( + "Mismatch in number of replicas; got %d arguments, but " + "device assignment has %d replicas.", + pyargs[i].size(), device_assignment.replica_count()); } - } -} -StatusOr XrtAllocationTuple::Release(int i) { - XrtAllocation* element = elements_[i]; - if (element == nullptr) { - 
return InvalidArgument("Attempted to release already-released element %d.", - i); - } - elements_[i] = nullptr; - return element; -} - -int64 XrtAllocationTuple::size() const { return elements_.size(); } - -StatusOr XrtExecutable::CompileForXrt( - const string& hlo_module_proto, const std::vector& argument_shapes, - const Shape& result_shape, const string& session_target) { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto compile = tensorflow::ops::XRTCompile(root, program); - TF_RETURN_IF_ERROR(root.status()); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - ProgramShape program_shape; - for (auto& shape : argument_shapes) { - *program_shape.add_parameters() = shape; - } - *program_shape.mutable_result() = result_shape; - - LayoutUtil::SetToDefaultLayout(&program_shape); - *config->mutable_program_shape() = program_shape.ToProto(); - c.mutable_hlo_snapshot() - ->mutable_hlo() - ->mutable_hlo_module() - ->ParsePartialFromString(hlo_module_proto); - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({program, c.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); - - int64 handle = outputs[0].scalar()(); - return new XrtExecutable(program_shape, handle, session_target); -} - -XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle, - const string& session_target) - : program_shape_(program_shape), - handle_(handle), - session_target_(session_target) {} - -XrtExecutable::~XrtExecutable() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({computation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; - } -} - -StatusOr XrtExecutable::Execute( - absl::Span argument_handles) { - const int num_expected_arguments = program_shape().parameters().size(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - std::vector arguments; - arguments.reserve(num_expected_arguments); - for (int i = 0; i < num_expected_arguments; ++i) { - arguments.push_back( - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); - } - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto execution_config = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto execute = tensorflow::ops::XRTExecute(root, computation_handle, - execution_config, arguments); - TF_RETURN_IF_ERROR(root.status()); - - TF_RET_CHECK(argument_handles.size() == arguments.size()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(false); - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - for (int i = 0; i < arguments.size(); ++i) { - inputs.insert({arguments[i], argument_handles[i]->handle()}); - } - inputs.insert({computation_handle, handle()}); - inputs.insert({execution_config, e.SerializeAsString()}); - std::vector 
outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); - - int64 output = outputs[0].scalar()(); - return new XrtAllocation(output, program_shape().result(), session_target_); -} - -const ProgramShape& XrtExecutable::program_shape() const { - return program_shape_; -} - -int64 XrtExecutable::handle() const { return handle_; } - -void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } - -void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; } - -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target) { - const Shape& tuple_shape = allocation->shape(); - - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); - auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); - TF_RETURN_IF_ERROR(root.status()); - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - std::vector results; - for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - inputs.clear(); - inputs.insert({base_handle, allocation->handle()}); - inputs.insert({shape_index, {i}}); - std::vector outputs; - auto status = session.Run(inputs, {subtuple}, &outputs); - if (!status.ok()) { - // Clean up before returning non-ok status. - for (int j = 0; j < results.size(); ++j) { - delete results[j]; + int arg_count = pyargs[i][0].size(); + args[i] = xla::Array2D>( + device_assignment.replica_count(), arg_count); + for (int j = 0; j < pyargs[i].size(); ++j) { + if (pyargs[i][j].size() != arg_count) { + return xla::InvalidArgument( + "Mismatched number of arguments to computation %d for " + "different replicas; %d vs %d arguments.", + i, arg_count, pyargs[i][j].size()); + } + for (int k = 0; k < arg_count; ++k) { + args[i](j, k) = pyargs[i][j][k]; } - return status; } - const int64 subtuple_handle = outputs[0].scalar()(); - const Shape& subtuple_shape = - ShapeUtil::GetTupleElementShape(tuple_shape, i); - results.push_back( - new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); } - return new XrtAllocationTuple(std::move(results)); + + TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteReplicated(args)); + std::vector>> pyresult(result.n1()); + for (int i = 0; i < result.n1(); ++i) { + pyresult[i].resize(result.n2()); + for (int j = 0; j < result.n2(); ++j) { + pyresult[i][j] = result(i, j); + } + } + return pyresult; } -} // namespace swig -} // namespace xla +} // namespace + +void AddXrtSubmodule(py::module* module) { + py::module m = module->def_submodule("xrt", "XRT backend"); + + m.def("AssignDevices", &AssignDevices, + "Computes a default device assignment."); + + py::class_> xrt_tf_client( + m, "XrtTfClient"); + m.def("GetTfClient", &GetTfClient, "Returns a TensorFlow client."); + + py::class_(m, "XrtTfContextOptions") + .def(py::init<>()) + .def_readwrite("async", &XrtTfContext::Options::async) + .def_readwrite("max_queue_size", &XrtTfContext::Options::max_queue_size); + + py::class_>(m, "XrtTfContext") + .def_static("Create", &XrtTfContext::Create); + + py::class_>(m, "XrtContext") + .def_static("Create", &XrtContext::Create) + .def("DeviceCount", 
&XrtContext::device_count) + .def_property_readonly("tf_device_ids", &XrtContext::tf_device_ids); + + py::class_>(m, "XrtBuffer") + .def_static("from_literal", &XrtBuffer::FromLiteral) + .def_static("make_tuple", &XrtBuffer::MakeTuple) + .def("to_py", + [](std::shared_ptr buffer) -> xla::StatusOr { + auto literal = absl::make_unique(); + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(*literal, buffer->ToLiteral()); + } + return xla::LiteralToPython(std::move(literal)); + }) + .def("delete", &XrtBuffer::Delete) + .def("destructure", &XrtBuffer::DestructureTuple) + .def("device", &XrtBuffer::xrt_device_ordinal) + .def("shape", &XrtBuffer::shape) + .def("is_deleted", + [](const XrtBuffer& buffer) { return !buffer.handle().valid(); }); + + py::class_>(m, "XrtExecutable") + .def_static("Compile", + [](std::shared_ptr context, + const std::string& hlo_module_proto_serialized, + const std::vector& argument_shapes, + const xla::Shape& result_shape, + const xla::DeviceAssignment& device_assignment) { + xla::HloModuleProto hlo_module_proto; + hlo_module_proto.ParsePartialFromString( + hlo_module_proto_serialized); + return XrtExecutable::Compile(context, hlo_module_proto, + argument_shapes, result_shape, + device_assignment); + }) + .def("Execute", &XrtExecutable::Execute) + .def("ExecuteReplicated", + [](XrtExecutable& executable, + std::vector>>> + pyargs) + -> xla::StatusOr< + std::vector>>> { + return ExecuteReplicated(&executable, pyargs); + }) + // Simplified API for compatibility with the local ExecutePerReplica, + // that only accepts one computation per replica. + // TODO(phawkins): support multiple computations per replica everywhere + // and remove this entry point. + .def("ExecutePerReplica", + [](XrtExecutable& executable, + std::vector>> pyargs) + -> xla::StatusOr>> { + const xla::DeviceAssignment& device_assignment = + executable.device_assignment(); + if (device_assignment.computation_count() != 1) { + return xla::InvalidArgument( + "ExecutePerReplica requires one computation per replica, " + "got %d.", + device_assignment.computation_count()); + } + TF_ASSIGN_OR_RETURN(auto result, + ExecuteReplicated(&executable, {pyargs})); + TF_RET_CHECK(result.size() == 1); + return result[0]; + }) + .def("Delete", &XrtExecutable::Delete) + .def("DeviceOrdinals", [](const XrtExecutable& executable) { + return std::vector(executable.device_assignment().begin(), + executable.device_assignment().end()); + }); + + m.doc() = "XRT backend plugin"; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xla/python/xrt.h b/tensorflow/compiler/xla/python/xrt.h index 710c3af3fa6..4263cfad1ce 100644 --- a/tensorflow/compiler/xla/python/xrt.h +++ b/tensorflow/compiler/xla/python/xrt.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,103 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ -#include -#include +#include "include/pybind11/pybind11.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/shape.h" +namespace tensorflow { -namespace xla { -namespace swig { +void AddXrtSubmodule(pybind11::module* module); -// Represents a reference to literals that live in a device-allocated buffer via -// XRT. 
Specifically, wraps an int64 handle produced by running the allocation -// graph, and an XLA shape to track the referent's shape. -class XrtAllocation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which allocation and deallocation - // graphs are run. - static StatusOr FromLiteral(const Literal& argument, - const string& session_target); - - XrtAllocation(int64 handle, Shape shape, const string& session_target); - ~XrtAllocation(); - StatusOr ToLiteral() const; - const Shape& shape() const; - const int64 handle() const; - - private: - const int64 handle_; - const Shape shape_; - const string session_target_; -}; - -// Result of a tuple destructuring operation on an XrtAllocation. -class XrtAllocationTuple { - public: - // Note: any XrtAllocation elements that are not Release()'d will be - // deallocated in the destructor. - explicit XrtAllocationTuple(std::vector elements); - - ~XrtAllocationTuple(); - - // Releases the ith element to the caller. Further attempts to release the ith - // element will return an invalid argument error. - StatusOr Release(int i); - - // Returns the number of elements in the destructured tuple. - int64 size() const; - - private: - std::vector elements_; -}; - -// Destructures a tuple-valued XrtAllocation into its constituent elements -// in XrtAllocationTuple form. -// -// Accepts a `session_target` argument, used in constructing the -// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, -// and passed along in constructing each constituent XrtAllocation. -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target); - -// Represents a compiled computation that can be executed given handles to -// device-allocated literals. Specifically, wraps an XRT computation handle. -class XrtExecutable { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the compilation graph is run. - static StatusOr CompileForXrt( - const string& hlo_module_proto, const std::vector& argument_shapes, - const Shape& result_shape, const string& session_target); - - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the execution graph is run. - XrtExecutable(const ProgramShape& program_shape, int64 handle, - const string& session_target); - ~XrtExecutable(); - - std::vector DeviceOrdinals() const { return {0}; } - - StatusOr Execute( - absl::Span argument_handles); - - const ProgramShape& program_shape() const; - int64 handle() const; - - private: - const ProgramShape program_shape_; - const int64 handle_; - const string session_target_; -}; - -// Functions for freeing resources from the Python side. -void DeleteXrtAllocation(XrtAllocation* allocation); -void DeleteXrtExecutable(XrtExecutable* computation); - -} // namespace swig -} // namespace xla +} // namespace tensorflow #endif // TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ diff --git a/tensorflow/compiler/xla/python/xrt.i b/tensorflow/compiler/xla/python/xrt.i deleted file mode 100644 index 456dd7be86e..00000000000 --- a/tensorflow/compiler/xla/python/xrt.i +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Wrappers for XRT ops. - -%module(threads="1") xrt - -// Keep the GIL except where explicitly specified. -%nothread; - -%include "tensorflow/python/platform/base.i" -%include "tensorflow/compiler/xla/python/xla_data.i" - -%{ -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/python/xrt.h" - -using namespace xla; -using namespace xla::swig; - -%} - -// Computation and buffer/allocation types - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtExecutable*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtAllocation*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtAllocationTuple*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - XrtAllocation* xrta; - if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), - SWIG_POINTER_EXCEPTION)) == -1) { - SWIG_fail; - } - temps.push_back(xrta); - Py_DECREF(o); - } - $1 = temps; -} - - -%ignoreall -%unignore xla; -%unignore xla::swig; -%unignore xla::swig::XrtAllocation; -%unignore xla::swig::XrtAllocation::FromLiteral; -%unignore xla::swig::XrtAllocation::ToLiteral; -%unignore xla::swig::XrtAllocation::shape; -%unignore xla::swig::XrtAllocationTuple; -%unignore xla::swig::XrtAllocationTuple::Release; -%unignore xla::swig::XrtAllocationTuple::size; -%unignore xla::swig::XrtExecutable; -%unignore xla::swig::XrtExecutable::CompileForXrt; -%unignore xla::swig::XrtExecutable::DeviceOrdinals; -%unignore xla::swig::XrtExecutable::Execute; -%unignore xla::swig::DestructureXrtAllocationTuple; -%unignore xla::swig::DeleteXrtAllocation; -%unignore xla::swig::DeleteXrtExecutable; - -%thread; -%include "tensorflow/compiler/xla/python/xrt.h" -%nothread; - -%unignoreall diff --git a/tensorflow/compiler/xla/python/xrt.py b/tensorflow/compiler/xla/python/xrt.py new file mode 100644 index 00000000000..76a99f20481 --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.py @@ -0,0 +1,84 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""XLA backend that runs XRT operators via TensorFlow remote eager.
+
+This module implements the Python XLA client's `Backend` abstraction using XRT,
+which embeds XLA's compiler/runtime operations as TensorFlow
+operations. The module uses TensorFlow's remote eager RPC API to invoke XRT
+operations.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.compiler.xla.python import xla_client
+from tensorflow.compiler.xla.python import xla_extension as _xla
+# pylint: enable=g-direct-tensorflow-import
+
+
+def get_tf_context(target, worker):
+  """Returns a TensorFlow remote eager context for talking to an XRT server.
+
+  Args:
+    target: string; a host:port pair (e.g., '10.0.101.1:8470') naming an XRT
+      server.
+    worker: string; the task name of the remote TensorFlow worker.
+  """
+  client = _xla.xrt.GetTfClient(target, worker)
+  options = _xla.xrt.XrtTfContextOptions()
+  options.max_queue_size = 10000
+  return _xla.xrt.XrtTfContext.Create(options, client, worker, 0)
+
+
+class XrtBackend(xla_client.Backend):
+  """XLA backend using XRT.
+
+  Args:
+    tf_context: an XrtTfContext object.
+    tf_device_type: the type of TensorFlow device to use for XRT (e.g. `"TPU"`).
+  """
+
+  def __init__(self, tf_context, tf_device_type, platform="tpu"):
+    super(XrtBackend, self).__init__(platform)
+    self.tf_device_type = tf_device_type
+
+    self.context = _xla.xrt.XrtContext.Create(tf_context, tf_device_type)
+
+  def device_count(self):
+    return self.context.DeviceCount()
+
+  def buffer_from_pyval(self, pyval, device=0):
+    return _xla.xrt.XrtBuffer.from_literal(self.context, device, pyval)
+
+  def make_tuple(self, buffers, device_ordinal):
+    return _xla.xrt.XrtBuffer.make_tuple(self.context, buffers)
+
+  def compile(self, computation, compile_options):
+    # pylint: disable=protected-access
+    program_shape = computation.GetProgramShape()
+    # pylint: enable=protected-access
+    proto = computation.GetSerializedProto()
+    # TODO(phawkins): use the layouts in compile_options.
+    arg_shapes = [
+        shape.with_major_to_minor_layout_if_absent()
+        for shape in program_shape.parameter_shapes()
+    ]
+    result_shape = (
+        program_shape.result_shape().with_major_to_minor_layout_if_absent())
+    device_assignment = _xla.xrt.AssignDevices(compile_options.num_replicas, 1)
+    return _xla.xrt.XrtExecutable.Compile(self.context, proto, arg_shapes,
+                                          result_shape, device_assignment)
diff --git a/tensorflow/compiler/xla/python/xrt_test.py b/tensorflow/compiler/xla/python/xrt_test.py
new file mode 100644
index 00000000000..29257c11d68
--- /dev/null
+++ b/tensorflow/compiler/xla/python/xrt_test.py
@@ -0,0 +1,77 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the XRT client."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.xla.python import xla_client
+from tensorflow.compiler.xla.python import xrt
+from tensorflow.python.platform import test
+
+
+def BuildAddAndScaleComputation(shape1, shape2):
+  """Builds the computation (a + b) * 3."""
+  b = xla_client.ComputationBuilder("add-and-scale")
+  x = b.ParameterWithShape(shape1)
+  y = b.ParameterWithShape(shape2)
+  dtype = shape1.numpy_dtype().type
+  b.Mul(b.Add(x, y), b.Constant(dtype(3)))
+  return b.Build()
+
+
+# TODO(phawkins): add more tests, beyond a simple "hello world" example.
+class XrtBackendTest(test.TestCase):
+
+  def testBasics(self):
+    (worker,), _ = test.create_local_cluster(num_workers=1, num_ps=0)
+    self.assertTrue(worker.target.startswith("grpc://"))
+    tf_context = xrt.get_tf_context(worker.target[len("grpc://"):], "worker")
+    backend = xrt.XrtBackend(tf_context, "XLA_CPU")
+
+    a = np.arange(10)
+    b = np.arange(10)
+
+    c = BuildAddAndScaleComputation(
+        xla_client.shape_from_pyval(a), xla_client.shape_from_pyval(b))
+
+    executable = c.Compile(backend=backend)
+    output = xla_client.execute_with_python_values(
+        executable, (a, b), backend=backend)
+    self.assertAllEqual(output, (a + b) * 3)
+
+  def testTuples(self):
+    (worker,), _ = test.create_local_cluster(num_workers=1, num_ps=0)
+    self.assertTrue(worker.target.startswith("grpc://"))
+    tf_context = xrt.get_tf_context(worker.target[len("grpc://"):], "worker")
+    backend = xrt.XrtBackend(tf_context, "XLA_CPU")
+
+    a = np.random.randn(10)
+    b = np.random.randn(15, 3)
+    pieces = [
+        xla_client.Buffer.from_pyval(a, backend=backend),
+        xla_client.Buffer.from_pyval(b, backend=backend)
+    ]
+    t = xla_client.Buffer.make_tuple(pieces, backend=backend)
+    a_out, b_out = t.destructure()
+    self.assertAllEqual(a, a_out.to_py())
+    self.assertAllEqual(b, b_out.to_py())
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index a1b0f4045ff..3c28e4be554 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -94,7 +94,7 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
 }
 
 TEST_F(ReferenceUtilTest, MapArray2D) {
-  auto identity = [](float value) { return log(exp(value)); };
+  auto identity = [](float value) { return std::log(std::exp(value)); };
   auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
   auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
   LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index a7934463a54..1e7a924e350 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -7,7 +7,7 @@ load(
     "//tensorflow/core:platform/default/build_config.bzl",
    "tf_proto_library_py",
 )
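The `ExecuteReplicated` binding added to `xrt.cc` earlier in this patch expects its Python arguments nested as [computation][replica][argument] and checks those sizes against the executable's device assignment. The following is a minimal sketch of that contract only, not part of the patch; the buffer names are placeholders for what would really be `_xla.xrt.XrtBuffer` handles.

```python
# Illustrative sketch, assuming a device assignment of 1 computation x 2 replicas.
num_computations = 1  # must equal device_assignment.computation_count()
num_replicas = 2      # must equal device_assignment.replica_count()

args = [                         # one entry per computation
    [                            # one entry per replica of that computation
        ["r0_arg0", "r0_arg1"],  # argument list for replica 0
        ["r1_arg0", "r1_arg1"],  # argument list for replica 1
    ],
]

assert len(args) == num_computations
for per_computation in args:
  assert len(per_computation) == num_replicas
  # Every replica of a given computation must get the same number of arguments.
  assert len({len(arg_list) for arg_list in per_computation}) == 1
```

Results come back with the same outer nesting, one output buffer per computation and replica, which is what lets the simplified `ExecutePerReplica` entry point return a single row of results.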
-load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") licenses(["notice"]) # Apache 2.0 @@ -18,6 +18,9 @@ package_group( includes = [ "//tensorflow/compiler/xla:friends", ], + packages = [ + "//learning/brain/experimental/tf_runtime/...", + ], ) xla_proto_library( @@ -167,6 +170,23 @@ tf_cc_test( ], ) +cc_library( + name = "dump", + srcs = ["dump.cc"], + hdrs = ["dump.h"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_proto_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], @@ -202,6 +222,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -334,6 +355,7 @@ cc_library( ":hlo_proto", ":name_uniquer", "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -415,10 +437,10 @@ tf_cc_test( srcs = ["pattern_matcher_test.cc"], deps = [ ":hlo", + ":hlo_parser", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -486,8 +508,8 @@ cc_library( hdrs = ["hlo_matchers.h"], deps = [ ":hlo", + ":hlo_parser", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -530,13 +552,13 @@ tf_cc_test( srcs = ["hlo_sharding_test.cc"], deps = [ ":hlo", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -564,6 +586,7 @@ tf_cc_test( srcs = ["call_graph_test.cc"], deps = [ ":call_graph", + ":hlo", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -571,7 +594,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -634,6 +656,7 @@ tf_cc_test( deps = [ ":call_graph", ":flatten_call_graph", + ":hlo", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -641,7 +664,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", 
"//tensorflow/core:test", @@ -672,7 +694,6 @@ cc_library( deps = [ ":compiler", ":computation_placer", - ":device_memory_allocator", ":platform_util", ":stream_pool", ":transfer_manager", @@ -682,6 +703,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", "//third_party/eigen3", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -701,7 +723,8 @@ cc_library( ":compilation_cache", ":compiler", ":computation_layout", - ":device_memory_allocator", + ":computation_placer", + ":dump", ":dynamic_dimension_inference", ":executable", ":execution_tracker", @@ -730,6 +753,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -746,7 +770,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":device_memory_allocator", ":executable", ":hlo", ":hlo_execution_profile", @@ -766,6 +789,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -781,6 +805,7 @@ cc_library( ":backend", ":compiler", ":computation_layout", + ":dump", ":platform_util", ":service", "//tensorflow/compiler/xla:debug_options_flags", @@ -810,8 +835,8 @@ cc_library( name = "gpu_plugin", deps = [ ":service", - "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", + "//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", ], @@ -833,7 +858,6 @@ cc_library( srcs = ["shaped_buffer.cc"], hdrs = ["shaped_buffer.h"], deps = [ - ":device_memory_allocator", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -843,6 +867,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -856,7 +881,6 @@ tf_cc_test( srcs = ["shaped_buffer_test.cc"], deps = [ ":cpu_plugin", - ":device_memory_allocator", ":platform_util", ":shaped_buffer", "//tensorflow/compiler/xla:shape_util", @@ -866,6 +890,7 @@ tf_cc_test( "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", ], ) @@ -879,7 +904,7 @@ cc_library( ], deps = [ ":computation_layout", - ":device_memory_allocator", + ":dump", ":hlo", ":hlo_execution_profile", ":hlo_graph_dumper", @@ -899,6 +924,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -912,6 +938,7 @@ cc_library( hdrs = 
["compiler.h"], deps = [ ":buffer_value", + ":computation_placer", ":executable", ":hlo", ":hlo_module_config", @@ -964,7 +991,6 @@ cc_library( hdrs = ["allocation_tracker.h"], deps = [ ":backend", - ":device_memory_allocator", ":transfer_manager", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -973,6 +999,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1075,6 +1102,7 @@ tf_cc_test( ":buffer_liveness", ":hlo", ":hlo_dataflow_analysis", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -1131,6 +1159,7 @@ tf_cc_test( ":hlo", ":hlo_memory_scheduler", ":hlo_ordering", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1138,7 +1167,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1180,10 +1208,10 @@ tf_cc_test( ":hlo_dataflow_analysis", ":hlo_memory_scheduler", ":hlo_ordering", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -1430,8 +1458,8 @@ tf_cc_test( srcs = ["instruction_fusion_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", ":instruction_fusion", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -1442,11 +1470,11 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":hlo", + ":hlo_pass", ":hlo_reachability", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1638,6 +1666,7 @@ cc_library( ":hlo_pass", ":hlo_query", ":pattern_matcher", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1686,6 +1715,44 @@ tf_cc_test( ], ) +cc_library( + name = "all_reduce_simplifier", + srcs = ["all_reduce_simplifier.cc"], + hdrs = ["all_reduce_simplifier.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_replication_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + ], +) + +tf_cc_test( + name = "all_reduce_simplifier_test", + srcs = ["all_reduce_simplifier_test.cc"], + deps = [ + ":all_reduce_simplifier", + ":hlo", + ":hlo_parser", + ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + 
"//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "batch_dot_simplification", srcs = ["batch_dot_simplification.cc"], @@ -1725,8 +1792,8 @@ tf_cc_test( srcs = ["gather_expander_test.cc"], deps = [ ":gather_expander", + ":hlo_parser", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -1737,10 +1804,12 @@ cc_library( srcs = ["conditional_simplifier.cc"], hdrs = ["conditional_simplifier.h"], deps = [ + ":call_graph", ":call_inliner", ":hlo", ":hlo_pass", "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1757,6 +1826,7 @@ tf_cc_test( ":conditional_simplifier", ":hlo", ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1821,9 +1891,9 @@ tf_cc_test( name = "while_loop_analysis_test", srcs = ["while_loop_analysis_test.cc"], deps = [ + ":hlo_parser", ":while_loop_analysis", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -1930,34 +2000,6 @@ tf_cc_test( ], ) -cc_library( - name = "implicit_broadcast_remover", - srcs = ["implicit_broadcast_remover.cc"], - hdrs = ["implicit_broadcast_remover.h"], - deps = [ - ":hlo", - ":hlo_dce", - ":hlo_pass", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "implicit_broadcast_remover_test", - srcs = ["implicit_broadcast_remover_test.cc"], - deps = [ - ":hlo_matchers", - ":implicit_broadcast_remover", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_test_base", - ], -) - cc_library( name = "dot_decomposer", srcs = ["dot_decomposer.cc"], @@ -1974,6 +2016,18 @@ cc_library( ], ) +tf_cc_test( + name = "dot_decomposer_test", + srcs = ["dot_decomposer_test.cc"], + deps = [ + ":dot_decomposer", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -2026,9 +2080,11 @@ cc_library( deps = [ ":hlo", ":while_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", @@ -2060,7 +2116,10 @@ tf_cc_test( srcs = ["dynamic_padder_test.cc"], deps = [ ":dynamic_padder", + ":hlo", + ":hlo_matchers", ":hlo_parser", + ":hlo_runner", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", 
"//tensorflow/compiler/xla:shape_util", @@ -2069,9 +2128,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], @@ -2082,6 +2138,9 @@ tf_cc_test( srcs = ["dynamic_dimension_inference_test.cc"], deps = [ ":dynamic_dimension_inference", + ":hlo", + ":hlo_matchers", + ":hlo_runner", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2089,9 +2148,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], @@ -2187,6 +2243,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", ], ) @@ -2241,7 +2298,7 @@ tf_cc_test( ":cpu_plugin", ":hlo_cost_analysis", ":hlo_execution_profile", - "//tensorflow/compiler/xla/service:hlo_parser", + ":hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2254,14 +2311,14 @@ tf_cc_test( srcs = ["hlo_computation_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_map", @@ -2269,24 +2326,6 @@ tf_cc_test( ], ) -tf_cc_binary( - name = "graphviz_example", - srcs = ["graphviz_example.cc"], - deps = [ - ":hlo", - ":hlo_graph_dumper", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - tf_cc_test( name = "hlo_module_test", srcs = ["hlo_module_test.cc"], @@ -2425,6 +2464,38 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_replication_analysis", + srcs = ["hlo_replication_analysis.cc"], + hdrs = ["hlo_replication_analysis.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "hlo_replication_analysis_test", + srcs = ["hlo_replication_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_parser", + ":hlo_replication_analysis", + "//tensorflow/compiler/xla:shape_util", + 
"//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/memory", + ], +) + cc_library( name = "hlo_liveness_analysis", srcs = ["hlo_liveness_analysis.cc"], @@ -2452,13 +2523,13 @@ tf_cc_test( deps = [ ":hlo", ":hlo_liveness_analysis", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2562,6 +2633,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -2618,6 +2690,7 @@ cc_library( "layout_assignment.h", ], deps = [ + ":call_graph", ":computation_layout", ":hlo", ":hlo_casting_utils", @@ -2635,6 +2708,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -2650,6 +2724,7 @@ cc_library( hdrs = ["copy_insertion.h"], deps = [ ":buffer_liveness", + ":dump", ":hlo", ":hlo_alias_analysis", ":hlo_dce", @@ -2838,12 +2913,12 @@ tf_cc_test( deps = [ ":hlo", ":hlo_module_dce", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -2860,6 +2935,7 @@ tf_cc_test( ":algebraic_simplifier", ":computation_layout", ":hlo", + ":hlo_parser", ":layout_assignment", ":pattern_matcher", ":pattern_matcher_gmock", @@ -2870,7 +2946,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2905,6 +2980,8 @@ cc_library( "hlo_pass_pipeline.h", ], deps = [ + ":compilation_stats", + ":dump", ":hlo", ":hlo_graph_dumper", ":hlo_pass", @@ -2967,12 +3044,12 @@ tf_cc_test( ":hlo", ":hlo_cse", ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -3156,27 +3233,6 @@ tf_cc_test( ], ) -cc_library( - name = "device_memory_allocator", - srcs = [ - "device_memory_allocator.cc", - "owning_device_memory.cc", - 
], - hdrs = [ - "device_memory_allocator.h", - "owning_device_memory.h", - ], - deps = [ - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "maybe_owning_device_memory", srcs = [ @@ -3186,7 +3242,7 @@ cc_library( "maybe_owning_device_memory.h", ], deps = [ - ":device_memory_allocator", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], @@ -3229,10 +3285,10 @@ xla_test( "gpu", ], deps = [ + ":hlo_parser", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -3245,6 +3301,7 @@ cc_library( hdrs = ["hlo_module_config.h"], deps = [ ":computation_layout", + ":computation_placer", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -3293,10 +3350,7 @@ tf_cc_test( cc_library( name = "hlo_graph_dumper", - srcs = [ - "hlo_graph_dumper.cc", - "hlo_graph_html_renderer.cc", - ], + srcs = ["hlo_graph_dumper.cc"], hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", @@ -3306,6 +3360,7 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", @@ -3328,6 +3383,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -3355,6 +3411,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_matchers", + ":hlo_parser", ":shape_inference", ":transpose_folding", "//tensorflow/compiler/xla:literal", @@ -3363,7 +3420,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -3606,10 +3662,10 @@ tf_cc_test( name = "tuple_util_test", srcs = ["tuple_util_test.cc"], deps = [ + ":hlo_matchers", + ":hlo_parser", ":tuple_util", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -3635,11 +3691,11 @@ tf_cc_test( name = "while_util_test", srcs = ["while_util_test.cc"], deps = [ + ":hlo_matchers", + ":hlo_parser", ":while_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", ], @@ -3670,9 +3726,9 @@ 
tf_cc_test( srcs = ["while_loop_invariant_code_motion_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], @@ -3698,9 +3754,9 @@ tf_cc_test( srcs = ["while_loop_constant_sinking_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], @@ -3717,7 +3773,6 @@ cc_library( ":hlo_memory_scheduler", ":hlo_pass", ":hlo_pass_pipeline", - ":implicit_broadcast_remover", "//tensorflow/compiler/xla:statusor", ], ) @@ -3806,6 +3861,7 @@ tf_cc_test( ":pattern_matcher_gmock", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", # fixdeps: keep @@ -3900,6 +3956,8 @@ cc_library( hdrs = ["ar_crs_combiner.h"], deps = [ ":call_graph", + ":hlo", + ":hlo_pass", ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -3907,23 +3965,34 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_pass", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) +cc_library( + name = "compilation_stats", + srcs = ["compilation_stats.cc"], + hdrs = ["compilation_stats.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "dynamic_index_splitter", srcs = ["dynamic_index_splitter.cc"], hdrs = ["dynamic_index_splitter.h"], deps = [ + ":hlo", ":hlo_casting_utils", + ":hlo_pass", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_pass", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -3989,3 +4058,81 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "conditional_to_select", + srcs = ["conditional_to_select.cc"], + hdrs = ["conditional_to_select.h"], + deps = [ + ":call_graph", + ":call_inliner", + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "conditional_to_select_test", + srcs = ["conditional_to_select_test.cc"], + deps = [ + ":conditional_to_select", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "slice_sinker", + srcs = ["slice_sinker.cc"], + hdrs = ["slice_sinker.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + 
"//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "custom_call_target_registry", + srcs = ["custom_call_target_registry.cc"], + hdrs = ["custom_call_target_registry.h"], + visibility = ["//visibility:public"], +) + +tf_cc_test( + name = "slice_sinker_test", + srcs = ["slice_sinker_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_parser", + ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", + ":slice_sinker", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 0ed5963521c..53afc598813 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -183,6 +184,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleCompare(HloInstruction* compare) override; + Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConstant(HloInstruction* constant) override; @@ -234,6 +237,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; @@ -250,35 +254,17 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, - const AlgebraicSimplifierOptions& options); + const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier); private: explicit AlgebraicSimplifierVisitor(HloComputation* computation, - const AlgebraicSimplifierOptions& options) - : computation_(computation), options_(options) {} + const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier) + : computation_(computation), options_(options), simplifier_(simplifier) {} - // Transforms Dots where at least one input is a vector or has a degenerate - // dimension and converts it into a multiply and reduce. This should enable - // more fusion than leaving the nodes as Dot operations. - StatusOr HandleDotStrengthReduction(HloInstruction* dot); - - // Removes dimension dim from hlo. - HloInstruction* StripDim(HloInstruction* hlo, int64 dim) { - CHECK_EQ(hlo->shape().dimensions(dim), 1); - return computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo)); - } - - // Reshapes an instruction to rank 1 if it is not already rank 1. 
- HloInstruction* Flatten(HloInstruction* hlo) { - if (hlo->shape().rank() == 1) { - return hlo; - } - return computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(hlo->shape().element_type(), - {ShapeUtil::ElementsIn(hlo->shape())}), - hlo)); - } + // Removes degenerate dimension from dot. + StatusOr RemoveDegenerateDimensionFromDot(HloInstruction* dot); // Converts to primitive type if the input hlo is not that type, otherwise // returns the original hlo. @@ -287,11 +273,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { if (hlo->shape().element_type() == element_type) { return hlo; } - return computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + Shape changed_shape = + ShapeUtil::ChangeElementType(hlo->shape(), element_type); + simplifier_->UpdateLayout(&changed_shape); + return computation_->AddInstruction( + HloInstruction::CreateConvert(changed_shape, hlo)); } - // Transposes a dot operand such that the batch dimensions are the msot major, + // Transposes a dot operand such that the batch dimensions are the most major, // and the contracting dimensions are most minor. StatusOr NormalizeDotOperandToBatchMajorAndContractingMinor( HloInstruction* dot_operand, absl::Span batch_dimensions, @@ -307,26 +296,26 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { transpose_dimensions.insert(transpose_dimensions.end(), contracting_dimensions.begin(), contracting_dimensions.end()); + if (absl::c_is_sorted(transpose_dimensions)) { + return dot_operand; + } return MakeTransposeHlo(dot_operand, transpose_dimensions); } // Helper method to perform and add reduction on a list of dimensions. HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims) { - HloInstruction* zero = - computation_->AddInstruction(HloInstruction::CreateConstant( + HloInstruction* zero = computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::FilterDimensions( [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, hlo->shape()); + simplifier_->UpdateLayout(&shape); return computation_->AddInstruction(HloInstruction::CreateReduce( shape, hlo, zero, dims, AddReduce_computation)); } - HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { - return AddReduce(hlo, std::vector{dim}); - } - // Convenience method for replacing an instruction with a bitcast. If operand // is not null, then the bitcast will use the specified operand instead of the // operand of the instruction. @@ -396,6 +385,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr OptimizeDotOfGather(HloInstruction* dot); + StatusOr OptimizeDotOfReorderContractingDims( + HloInstruction* dot); + HloComputation* GetOrCreateScalarAddComputation() { if (scalar_add_computation_) { return scalar_add_computation_; @@ -403,6 +395,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloComputation::Builder b("scalar_add_computation"); Shape shape = ShapeUtil::MakeShape(F32, {}); + simplifier_->UpdateLayout(&shape); auto scalar_lhs = b.AddInstruction( HloInstruction::CreateParameter(0, shape, "scalar_lhs")); auto scalar_rhs = b.AddInstruction( @@ -440,13 +433,16 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Cached computation for adding two scalar F32. 
HloComputation* scalar_add_computation_ = nullptr; + + AlgebraicSimplifier* simplifier_ = nullptr; }; } // namespace -bool AlgebraicSimplifierVisitor::Run( - HloComputation* computation, const AlgebraicSimplifierOptions& options) { - AlgebraicSimplifierVisitor visitor(computation, options); +bool AlgebraicSimplifierVisitor::Run(HloComputation* computation, + const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier) { + AlgebraicSimplifierVisitor visitor(computation, options, simplifier); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -518,9 +514,17 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; HloInstruction *a, *c1, *c2; if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)), - m::Constant(&c2)))) { + m::Constant(&c2))) || + Match(add, m::Add(m::Add(m::NonConstant(&a), + m::Broadcast(m::ConstantScalar(&c1))), + m::Broadcast(m::ConstantScalar(&c2))))) { TF_ASSIGN_OR_RETURN(auto* sum_of_constants, MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); + if (ShapeUtil::IsScalar(sum_of_constants->shape()) && + !ShapeUtil::IsScalar(add->shape())) { + sum_of_constants = computation_->AddInstruction( + HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {})); + } return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a, sum_of_constants)); @@ -581,20 +585,20 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) { return Status::OK(); } + } - // A && False => False - VLOG(10) << "trying transform [A && False => False]: " - << logical_and->ToString(); - if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) { - return Status::OK(); - } + // A && False => False or A & 0 => 0 + VLOG(10) << "trying transform [A && False => False]: " + << logical_and->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } - // False && A => False - VLOG(10) << "trying transform [False && A => False]: " - << logical_and->ToString(); - if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) { - return Status::OK(); - } + // False && A => False or A & 0 => 0 + VLOG(10) << "trying transform [False && A => False]: " + << logical_and->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); } return Status::OK(); @@ -636,8 +640,18 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { if (HloInstruction* bitcast_operand = BitcastingOperandOfReshapeOrCopyChain(copy, options_)) { ReplaceWithBitcast(copy, bitcast_operand); + return Status::OK(); } + // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical bitcast. 
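The HandleAdd hunk above generalizes the (A + C1) + C2 => A + (C1 + C2) rewrite so it also fires when the two constants appear as broadcast scalars, broadcasting the folded sum back out to the add's shape. Below is a minimal standalone sketch of the same reassociate-and-fold idea on a toy expression type; it is plain C++ with illustrative names (Expr, ReassociateConstants), not XLA code.

```cpp
// Standalone sketch (not XLA): reassociate (a + c1) + c2 into a + (c1 + c2)
// by folding the two constants, mirroring the HandleAdd rewrite above.
#include <iostream>
#include <memory>
#include <string>

struct Expr {
  enum Kind { kParam, kConstant, kAdd } kind;
  double value = 0.0;              // used when kind == kConstant
  std::string name;                // used when kind == kParam
  std::shared_ptr<Expr> lhs, rhs;  // used when kind == kAdd
};

std::shared_ptr<Expr> Param(std::string name) {
  return std::make_shared<Expr>(Expr{Expr::kParam, 0.0, std::move(name), nullptr, nullptr});
}
std::shared_ptr<Expr> Constant(double v) {
  return std::make_shared<Expr>(Expr{Expr::kConstant, v, "", nullptr, nullptr});
}
std::shared_ptr<Expr> Add(std::shared_ptr<Expr> l, std::shared_ptr<Expr> r) {
  return std::make_shared<Expr>(Expr{Expr::kAdd, 0.0, "", std::move(l), std::move(r)});
}

// Rewrites (a + c1) + c2  =>  a + (c1 + c2) when both constants are present.
std::shared_ptr<Expr> ReassociateConstants(const std::shared_ptr<Expr>& e) {
  if (e->kind == Expr::kAdd && e->rhs->kind == Expr::kConstant &&
      e->lhs->kind == Expr::kAdd && e->lhs->rhs->kind == Expr::kConstant) {
    double folded = e->lhs->rhs->value + e->rhs->value;  // constant-fold c1 + c2
    return Add(e->lhs->lhs, Constant(folded));
  }
  return e;  // no match: leave the expression unchanged
}

std::string ToString(const std::shared_ptr<Expr>& e) {
  switch (e->kind) {
    case Expr::kParam:    return e->name;
    case Expr::kConstant: return std::to_string(e->value);
    case Expr::kAdd:      return "(" + ToString(e->lhs) + " + " + ToString(e->rhs) + ")";
  }
  return "";
}

int main() {
  auto expr = Add(Add(Param("a"), Constant(1.0)), Constant(2.0));
  std::cout << ToString(expr) << " => " << ToString(ReassociateConstants(expr)) << "\n";
  // Prints: ((a + 1.000000) + 2.000000) => (a + 3.000000)
}
```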
+ if (copy->operand(0)->opcode() == HloOpcode::kReshape && + copy->operand(0)->user_count() == 1 && + ShapeUtil::ReshapeIsBitcast(copy->operand(0)->shape(), copy->shape())) { + return ReplaceWithNewInstruction( + copy, + copy->operand(0)->CloneWithNewOperands( + copy->shape(), {copy->mutable_operand(0)->mutable_operand(0)})); + } return Status::OK(); } @@ -713,6 +727,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( new_slice_shape.set_dimensions( concatenate_dimension, slice_end - operands[i]->slice_starts(concatenate_dimension)); + simplifier_->UpdateLayout(&new_slice_shape); auto new_limit_indices = operands[i]->slice_limits(); new_limit_indices[concatenate_dimension] = slice_end; auto new_slice_op = @@ -775,18 +790,19 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } static HloInstruction* BuildTupleConstant(HloComputation* computation, - const LiteralSlice& literal) { + const LiteralSlice& literal, + AlgebraicSimplifier* simplifier) { if (literal.shape().IsTuple()) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { - elems.push_back( - BuildTupleConstant(computation, LiteralSlice(literal, {i}))); + elems.push_back(BuildTupleConstant( + computation, LiteralSlice(literal, {i}), simplifier)); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(literal.Clone())); + simplifier->CreateConstantWithLayoutUpdated(literal.Clone())); } } @@ -795,7 +811,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // explicit Tuple instructions. if (constant->shape().IsTuple()) { return ReplaceInstruction( - constant, BuildTupleConstant(computation_, constant->literal())); + constant, + BuildTupleConstant(computation_, constant->literal(), simplifier_)); } if (constant->shape().element_type() == TOKEN) { @@ -808,7 +825,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { Literal unique_scalar( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(unique_scalar))); + simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar))); return ReplaceWithNewInstruction( constant, HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); @@ -835,9 +852,17 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { // Canonicalize subtraction of a constant to addition. 
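The HandleConstant hunk above keeps the rewrite that turns a non-scalar constant whose elements are all identical into a broadcast of a single scalar, now routing the scalar through CreateConstantWithLayoutUpdated. A standalone sketch of that detection, assuming a flat element buffer instead of an XLA literal (ScalarBroadcast and TryFoldToBroadcast are illustrative names):

```cpp
// Standalone sketch (not XLA): represent an array constant whose elements are
// all identical as a (scalar value, element count) broadcast, mirroring the
// HandleConstant rewrite above.
#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

struct ScalarBroadcast {
  float scalar;       // the unique value
  std::size_t count;  // how many times it is repeated
};

// Returns a broadcast description if every element equals the first one,
// otherwise std::nullopt (the constant stays materialized as-is).
std::optional<ScalarBroadcast> TryFoldToBroadcast(const std::vector<float>& elements) {
  if (elements.empty()) return std::nullopt;
  for (float v : elements) {
    if (v != elements.front()) return std::nullopt;
  }
  return ScalarBroadcast{elements.front(), elements.size()};
}

int main() {
  std::vector<float> constant(1024, 0.5f);  // all elements identical
  if (auto folded = TryFoldToBroadcast(constant)) {
    std::cout << "broadcast(" << folded->scalar << ") of " << folded->count
              << " elements\n";
  }
}
```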
VLOG(10) << "trying transform [A - Const => A + (-Const)]"; - if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) { + if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) || + Match(sub, m::Subtract(m::NonConstant(&lhs), + m::Broadcast(m::Constant(&rhs))))) { HloInstruction* negative_const = computation_->AddInstruction( HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); + if (const HloInstruction* broadcast = + DynCast(sub->operand(1))) { + negative_const = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + broadcast->shape(), negative_const, broadcast->dimensions())); + } return ReplaceWithNewInstruction( sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs, negative_const)); @@ -854,8 +879,9 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) { } template -std::unique_ptr TryDivideToShift(HloInstruction* divide, - HloComputation* computation) { +std::unique_ptr TryDivideToShift( + HloInstruction* divide, HloComputation* computation, + AlgebraicSimplifier* simplifier) { HloInstruction *a, *b, *c; CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); @@ -872,10 +898,11 @@ std::unique_ptr TryDivideToShift(HloInstruction* divide, HloInstruction* zero_like_a = BroadcastZeros( computation, a->shape().element_type(), a->shape().dimensions()); + Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED); + simplifier->UpdateLayout(&changed_shape); auto* dividend_is_negative = - computation->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, - zero_like_a)); + computation->AddInstruction(HloInstruction::CreateCompare( + changed_shape, a, zero_like_a, ComparisonDirection::kLt)); auto* negated_dividend = computation->AddInstruction( HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); @@ -887,8 +914,8 @@ std::unique_ptr TryDivideToShift(HloInstruction* divide, int log2_abs_b_value = tensorflow::Log2Floor64(b_value); - auto* shift_amount = - computation->AddInstruction(HloInstruction::CreateConstant( + auto* shift_amount = computation->AddInstruction( + simplifier->CreateConstantWithLayoutUpdated( LiteralUtil::CreateR0(log2_abs_b_value))); if (!ShapeUtil::IsScalar(b->shape())) { shift_amount = computation->AddInstruction( @@ -911,8 +938,8 @@ std::unique_ptr TryDivideToShift(HloInstruction* divide, uint64 b_value = c->literal().GetFirstElement(); if (IsPowerOfTwo(b_value)) { int log2_abs_b_value = tensorflow::Log2Floor64(b_value); - HloInstruction* shift_amount = - computation->AddInstruction(HloInstruction::CreateConstant( + HloInstruction* shift_amount = computation->AddInstruction( + simplifier->CreateConstantWithLayoutUpdated( LiteralUtil::CreateR0(log2_abs_b_value))); if (!ShapeUtil::IsScalar(b->shape())) { shift_amount = computation->AddInstruction( @@ -940,49 +967,49 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { switch (divide->shape().element_type()) { case S8: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case S16: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case S32: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, 
computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case S64: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case U8: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case U16: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case U32: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; case U64: if (std::unique_ptr shift = - TryDivideToShift(divide, computation_)) { + TryDivideToShift(divide, computation_, simplifier_)) { return ReplaceWithNewInstruction(divide, std::move(shift)); } break; @@ -1058,33 +1085,38 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // // (Backends can do this transformation, but generally only if the constant is // a scalar.) - if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - Shape result_shape = b->literal().shape(); + if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) && + (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) { + Shape result_shape = c->literal().shape(); Literal new_literal(result_shape); switch (result_shape.element_type()) { case F16: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case F32: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case BF16: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case F64: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case C64: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case C128: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; default: return Status::OK(); } auto inverse = computation_->AddInstruction( - HloInstruction::CreateConstant((new_literal.Clone()))); + simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone())); + if (b != c) { + inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast( + b->shape(), inverse, b->dimensions())); + } TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); @@ -1124,240 +1156,81 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( +StatusOr AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( HloInstruction* dot) { - HloInstruction *lhs, *rhs; - CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - - const auto kept_dim = [](int64 rank, int64 contracting_dimension, - absl::Span batch_dimensions) -> int64 { - for (int64 i = 0; i < rank; ++i) { - if (i != contracting_dimension && - !absl::c_linear_search(batch_dimensions, 
i)) { - return i; - } + const Shape& lhs_shape = dot->operand(0)->shape(); + int64 num_degenerate_lhs_dims = 0; + std::vector lhs_dimension_map(lhs_shape.rank(), -1); + for (int64 i = 0; i < lhs_shape.rank(); ++i) { + if (lhs_shape.dimensions(i) == 1) { + ++num_degenerate_lhs_dims; + } else { + lhs_dimension_map[i] = i - num_degenerate_lhs_dims; } - return -1; - }; - - const int64 dot_rank = dot->shape().rank(); - const int64 rhs_rank = rhs->shape().rank(); - const int64 lhs_rank = lhs->shape().rank(); - const auto& dnums = dot->dot_dimension_numbers(); - if (dnums.rhs_contracting_dimensions_size() != 1) { - return false; - } - if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { - return false; - } - int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0); - int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim, - AsInt64Slice(dnums.lhs_batch_dimensions())); - // If there is no non-contracting dimension in rank 2, do not strength reduce. - if (lhs_kept_dim == -1 && lhs_rank > 1) { - return false; - } - if (lhs->IsRank2Transpose()) { - lhs = lhs->mutable_operand(0); - std::swap(lhs_collapsing_dim, lhs_kept_dim); } - int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0); - int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim, - AsInt64Slice(dnums.rhs_batch_dimensions())); - // If there is no non-contracting dimension in rank 2, do not strength reduce. - if (rhs_kept_dim == -1 && rhs_rank > 1) { - return false; - } - if (rhs->IsRank2Transpose()) { - rhs = rhs->mutable_operand(0); - std::swap(rhs_collapsing_dim, rhs_kept_dim); - } - - auto reshape_if_necessary = [&](HloInstruction* hlo) { - hlo = AsType(hlo, dot->shape().element_type()); - if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { - hlo = computation_->AddInstruction( - HloInstruction::CreateReshape(dot->shape(), hlo)); + const Shape& rhs_shape = dot->operand(1)->shape(); + int64 num_degenerate_rhs_dims = 0; + std::vector rhs_dimension_map(rhs_shape.rank(), -1); + for (int64 i = 0; i < rhs_shape.rank(); ++i) { + if (rhs_shape.dimensions(i) == 1) { + ++num_degenerate_rhs_dims; + } else { + rhs_dimension_map[i] = i - num_degenerate_rhs_dims; } - return hlo; - }; - - auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { - return AddReduce(AsType(hlo, F32), dim); - }; - - auto broadcast = [&](HloInstruction* hlo, const Shape& shape, - absl::Span dims) { - return computation_->AddInstruction( - HloInstruction::CreateBroadcast(shape, hlo, dims)); - }; - - auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, - int64 dim) { - return broadcast(hlo, shape, {dim}); - }; - - auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { - return computation_->AddInstruction(HloInstruction::CreateBinary( - local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs)); - }; - - // Strength reduce dot(a[K] , b[K]) = - // reshape(result.shape, - // reduce_sum(multiply(a, b), {0})) - if (rhs_rank == 1 && lhs_rank == 1) { - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0)))); - return true; } - - if (ShapeUtil::IsEffectiveScalar(rhs->shape()) && - ShapeUtil::IsEffectiveScalar(lhs->shape())) { - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs))))); - return true; - } - - // Simplify outer product into multiply with implicit broadcasting. 
- // - // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) { - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), - broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); - return true; - } - - // Strength reduce dot(a[1, K], b) = - // reshape(result.shape, - // reduce_sum( - // multiply(broadcast(reshape(a, [K]), {0}), b), - // {0}) - // ) - // ) - if (lhs_rank == 1 || - (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { - if (rhs->shape().rank() == 1) { - TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( - multiply(Flatten(lhs), rhs), 0)))); - return true; - } - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(add_reduce_in_f32( - multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), - rhs_collapsing_dim), - rhs), - rhs_collapsing_dim)))); - return true; - } - - // Strength reduce dot(a, b[K, 1]) = - // reshape(result.shape, - // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) - // ) - if (rhs_rank == 1 || - (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(add_reduce_in_f32( - multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), - lhs_collapsing_dim)), - lhs_collapsing_dim)))); - return true; - } - - // Only consider kDot with batch dimension. - if (dot_rank <= 2) { + if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) { return false; } - - CHECK_EQ(rhs_rank, lhs_rank); - CHECK_EQ(dot_rank, lhs_rank); - // If there is more than one non-contracting dimension or the batch dimensions - // are not equal, bail out since transposes may be required to do a strength - // reduction. - if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank || - !absl::c_equal(dnums.lhs_batch_dimensions(), - dnums.rhs_batch_dimensions())) { - return false; - } - - auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) { - absl::InlinedVector dims; - for (int64 i = 0; i < rank; ++i) { - if (i != non_broadcast_dim) { - dims.push_back(i); - } + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + DotDimensionNumbers new_dnums; + for (int64 dim : dnums.lhs_batch_dimensions()) { + int64 new_dim = lhs_dimension_map[dim]; + if (new_dim != -1) { + new_dnums.add_lhs_batch_dimensions(new_dim); } - return dims; - }; - - // If the contracting dimension is 1, remove the degnerate dimnensions from - // the lhs and rhs, broadcast each to the result shape and multiply. - if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && - (rhs_kept_dim == rhs_rank - 1 || - (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { - CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1); - const int64 lhs_kept_dim_in_output = - lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim; - absl::InlinedVector lhs_broadcast_dims; - for (const int64 dim : dnums.lhs_batch_dimensions()) { - lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? 
(dim - 1) : dim); + } + for (int64 dim : dnums.lhs_contracting_dimensions()) { + int64 new_dim = lhs_dimension_map[dim]; + if (new_dim != -1) { + new_dnums.add_lhs_contracting_dimensions(new_dim); } - absl::InlinedVector rhs_broadcast_dims = lhs_broadcast_dims; - lhs_broadcast_dims.push_back(lhs_kept_dim_in_output); - absl::c_sort(lhs_broadcast_dims); - rhs_broadcast_dims.push_back(dot_rank - 1); - absl::c_sort(rhs_broadcast_dims); - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary( - multiply(broadcast(StripDim(lhs, lhs_collapsing_dim), - dot->shape(), lhs_broadcast_dims), - broadcast(StripDim(rhs, rhs_collapsing_dim), - dot->shape(), rhs_broadcast_dims))))); - return true; } - // If the lhs and rhs non-contracting dimensions are both one, strip each one, - // multiply and then reduce the collapsing dimension - if (lhs->shape().dimensions(lhs_kept_dim) == 1 && - rhs->shape().dimensions(rhs_kept_dim) == 1 && - lhs_kept_dim == rhs_kept_dim) { - auto new_lhs = StripDim(lhs, lhs_kept_dim); - auto new_rhs = StripDim(rhs, rhs_kept_dim); - const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim - ? (rhs_collapsing_dim - 1) - : rhs_collapsing_dim; - TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( - multiply(new_lhs, new_rhs), reduce_dim)))); - return true; + for (int64 dim : dnums.rhs_batch_dimensions()) { + int64 new_dim = rhs_dimension_map[dim]; + if (new_dim != -1) { + new_dnums.add_rhs_batch_dimensions(new_dim); + } + } + for (int64 dim : dnums.rhs_contracting_dimensions()) { + int64 new_dim = rhs_dimension_map[dim]; + if (new_dim != -1) { + new_dnums.add_rhs_contracting_dimensions(new_dim); + } } - // If the lhs non-contracting dimensions is one, strip the one, brodcast to - // the rhs shape, multiply and then reduce the collapsing dimension - if (lhs->shape().dimensions(lhs_kept_dim) == 1) { - auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(), - broadcast_dims(rhs_rank, rhs_kept_dim)); - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs), - rhs_collapsing_dim)))); - return true; + HloInstruction* new_lhs = + num_degenerate_lhs_dims > 0 + ? dot->parent()->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::DropDegenerateDimensions(lhs_shape), + dot->mutable_operand(0))) + : dot->mutable_operand(0); + HloInstruction* new_rhs = + num_degenerate_rhs_dims > 0 + ? 
dot->parent()->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::DropDegenerateDimensions(rhs_shape), + dot->mutable_operand(1))) + : dot->mutable_operand(1); + TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums, + dot->precision_config())); + if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) { + TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot)); + } else { + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + dot, HloInstruction::CreateReshape(dot->shape(), new_dot))); } - - // If the rhs non-contracting dimensions is one, strip the one, brodcast to - // the lhs shape, multiply and then reduce the collapsing dimension - if (rhs->shape().dimensions(rhs_kept_dim) == 1) { - auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(), - broadcast_dims(lhs_rank, lhs_kept_dim)); - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs), - lhs_collapsing_dim)))); - return true; - } - - return false; + return true; } StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( @@ -1456,6 +1329,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); Shape rhs_slice_shape(rhs->shape()); rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); + simplifier_->UpdateLayout(&rhs_slice_shape); std::array start_indices; start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; @@ -1591,6 +1465,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( right_operand->shape().dimensions(1 - rhs_contracting_dimension); auto memoized_shape = ShapeUtil::MakeShape(dot->shape().element_type(), {m, n}); + simplifier_->UpdateLayout(&memoized_shape); auto* memoized_inst = computation_->AddInstruction( HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, dnums, dot->precision_config())); @@ -1605,7 +1480,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( // Slice out start and 0 components and reorder if necessary. auto indices_type = dynamic_slice->operand(1)->shape().element_type(); Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); + simplifier_->UpdateLayout(&s_shape); Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); + simplifier_->UpdateLayout(&d_shape); HloInstruction* non_zero_start = dynamic_slice->mutable_operand(1 + index_of_non_zero_start); HloInstruction* zero_start = @@ -1628,6 +1505,221 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( return memoized_lookup; } +// This function tries to transform +// dot(reshape(transpose(A)), Const) to +// dot(reshape(A), reshape(transpose(reshape(Const)))), +// so that the reshape and transpose on the Const side can be constant folded. +// +// The basic idea is that since the accumulation in the dot operation is +// associative, so as long as we permute the elements of the contracting +// dimensions on both sides of the dot in the same way, the result of the +// dot is not affected. +StatusOr +AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( + HloInstruction* dot) { + // This transformation assumes layout is not assigned yet. + if (options_.is_layout_sensitive()) { + return nullptr; + } + + // Canonicalize dot(, rhs) to dot(rhs, ) to make the + // remainder of this function easier. 
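The docstring above justifies OptimizeDotOfReorderContractingDims by the fact that a dot's accumulation does not depend on the order in which the contracting dimension is traversed: permuting the contracting elements of both operands in the same way leaves every product sum unchanged. A standalone sketch of that identity for a single contraction, using integer values so there is no floating-point reassociation caveat (Contract and Permute are illustrative names, not XLA helpers):

```cpp
// Standalone sketch (not XLA): contracting two operands is invariant under
// applying the same permutation to both operands' contracting elements, which
// is the property the transform above relies on.
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

int Contract(const std::vector<int>& a, const std::vector<int>& b) {
  assert(a.size() == b.size());
  int acc = 0;
  for (std::size_t k = 0; k < a.size(); ++k) acc += a[k] * b[k];
  return acc;
}

std::vector<int> Permute(const std::vector<int>& x, const std::vector<int>& perm) {
  std::vector<int> out(x.size());
  for (std::size_t i = 0; i < perm.size(); ++i) out[i] = x[perm[i]];
  return out;
}

int main() {
  std::vector<int> a = {1, 2, 3, 4};
  std::vector<int> b = {5, 6, 7, 8};
  std::vector<int> perm = {2, 0, 3, 1};  // some reordering of the contracting dim

  int original = Contract(a, b);
  int reordered = Contract(Permute(a, perm), Permute(b, perm));
  std::cout << original << " == " << reordered << "\n";  // both print 70
  assert(original == reordered);
}
```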
+ auto dnums = dot->dot_dimension_numbers(); + auto lhs_contracting_dims = dnums.lhs_contracting_dimensions(); + auto rhs_contracting_dims = dnums.rhs_contracting_dimensions(); + auto* lhs = dot->mutable_operand(0); + auto* rhs = dot->mutable_operand(1); + if (dot->operand(0)->IsConstant()) { + std::swap(lhs, rhs); + std::swap(lhs_contracting_dims, rhs_contracting_dims); + } + + // Require single contracting dim to make the implementation easier to + // track contracting dims. + if (dnums.lhs_contracting_dimensions_size() != 1) { + return nullptr; + } + + // Pattern match Dot(reshape(transpose(input), constant)) + HloInstruction* reshape; + HloInstruction* transpose; + HloInstruction* input; + HloInstruction* constant; + if (!Match(lhs, + m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) || + !Match(rhs, m::Constant(&constant))) { + return nullptr; + } + + // Check that reshape squishes some dims into one dim and that this one + // dim is the dot's lhs contracting dim. The size of unmodified_dims should + // be N - 1, where N is the rank of the reshape output. This means that the + // reshape squishes some dims into one dim. lhs contracting dim should not + // be in unmodified_dims. This means that the squishing target dim is the + // lhs contracting dim. + auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape( + reshape->operand(0)->shape(), reshape->shape()); + CHECK_EQ(lhs_contracting_dims.size(), 1); + if ((unmodified_dims.size() != reshape->shape().rank() - 1) || + absl::c_any_of(unmodified_dims, [&](const std::pair& p) { + return p.second == lhs_contracting_dims[0]; + })) { + return nullptr; + } + + // Virtually pull the reshape into the dot so the dot operates on the + // transpose, with "unsquished" lhs contracting dims. The new contracting + // dims are all of the dims that are modified by the reshape -- that is, every + // dimension that's not in `unmodified_dims[i].first`. + // + // (We don't need to actually create a new dot instruction. We can just keep + // track of lhs and lhs_contracting_dims.) + absl::flat_hash_set unmodified_transpose_dims; + for (const auto& pair : unmodified_dims) { + unmodified_transpose_dims.insert(pair.first); + } + lhs_contracting_dims.Clear(); + for (int64 i = 0; i < transpose->shape().dimensions_size(); ++i) { + if (!unmodified_transpose_dims.contains(i)) { + lhs_contracting_dims.Add(i); + } + } + // We require the "unsquished" lhs contracting dims to be consecutive. + auto is_iota = [](absl::Span dims) { + return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) { + return (b != a + 1); + }) == dims.end(); + }; + if (!is_iota(AsInt64Slice(lhs_contracting_dims))) { + return nullptr; + } + lhs = lhs->mutable_operand(0); + + // Check that the transpose only permutes the contracting dims. + const auto& transpose_dims = transpose->dimensions(); + for (int64 i = 0; i < transpose_dims.size(); ++i) { + if (transpose_dims[i] != i && + !absl::c_linear_search(lhs_contracting_dims, i)) { + return nullptr; + } + } + // Virtually pull the transpose into the dot. Now the dot is equivalent to + // a new dot with "permuted" lhs contracting dims. 
+ std::vector permutation; + for (auto dim : lhs_contracting_dims) { + permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]); + } + CHECK(IsPermutation(permutation, permutation.size())); + auto new_lhs_contracting_dims = + ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation); + lhs_contracting_dims.Clear(); + for (auto dim : new_lhs_contracting_dims) { + lhs_contracting_dims.Add(dim); + } + lhs = lhs->mutable_operand(0); + + // All checks are passed at this point. + // + // Transform lhs. Remove the transpose and reshape by sorting the lhs + // contracting dims and squishing them into a single one. We don't actually + // squish the lhs_contracting_dims here because we still need the unsquished + // contracting dims to invert reshape and transpose. + absl::c_sort(lhs_contracting_dims); + lhs = computation_->AddInstruction( + HloInstruction::CreateReshape(reshape->shape(), lhs)); + + // Transform rhs. Say the input HLO is: + // + // t0 = f32[2, 2, 3] parameter(0) + // t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1} + // t2 = f32[2, 6] reshape(t1) + // t3 = f32[6, 2] constant(...) + // dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1}, + // rhs_contracting_dims={0} + // + // At this point in the function, we have decided that the second and third + // dims of t0 can be switched to remove the transpose, and we have + // "virtually decomposed" the input HLO to: + // + // t0 = f32[2, 2, 3] parameter(0) + // t2' = f32[2, 6] reshape(t0) + // t3' = f32[6, 2] ops-to-be-filled ... + // dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1}, + // rhs_contracting_dims={0} + // + // The rest of this function is to fill in the ops of t3'. To do this, we + // unsquish the contracting dimensions in t3 and then apply the inverse of + // the transpose from t1. + + // Invert reshape. + CHECK_EQ(rhs_contracting_dims.size(), 1); + auto rhs_unsquished_shape_dims = constant->shape().dimensions(); + auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() + + rhs_contracting_dims[0]); + for (auto dim : lhs_contracting_dims) { + it = rhs_unsquished_shape_dims.insert(it, + transpose->shape().dimensions(dim)); + ++it; + } + HloInstruction* rhs_reshape = + computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(constant->shape().element_type(), + rhs_unsquished_shape_dims), + constant)); + rhs = rhs_reshape; + + // Rhs reshape "unsquishes" the single rhs contracting dim into multiple dims. + rhs_contracting_dims.Resize(lhs_contracting_dims.size(), + rhs_contracting_dims[0]); + absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]); + + // Invert transpose. First compute the shape. + auto rhs_transpose_shape_dims = rhs_reshape->shape().dimensions(); + it = rhs_transpose_shape_dims.erase( + rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0], + rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] + + rhs_contracting_dims.size()); + for (auto dim : lhs_contracting_dims) { + it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim)); + ++it; + } + // Then compute the transpose dims. 
+ std::vector rhs_transpose_dims(rhs_reshape->shape().rank()); + absl::c_iota(rhs_transpose_dims, 0); + it = rhs_transpose_dims.erase( + rhs_transpose_dims.begin() + rhs_contracting_dims[0], + rhs_transpose_dims.begin() + rhs_contracting_dims[0] + + rhs_contracting_dims.size()); + auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims); + for (auto dim : lhs_contracting_dims) { + it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] - + lhs_contracting_dims[0] + + rhs_contracting_dims[0]); + ++it; + } + HloInstruction* rhs_transpose = + computation_->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(constant->shape().element_type(), + rhs_transpose_shape_dims), + rhs_reshape, rhs_transpose_dims)); + rhs = rhs_transpose; + + // Squish the multiple rhs contracting dims into a single one. + rhs = computation_->AddInstruction( + HloInstruction::CreateReshape(constant->shape(), rhs)); + + // If we virtually swapped lhs and rhs, we need to swap it back before + // creating new dot. + if (dot->operand(0)->IsConstant()) { + std::swap(lhs, rhs); + } + + HloInstruction* new_dot = + computation_->AddInstruction(HloInstruction::CreateDot( + dot->shape(), lhs, rhs, dnums, dot->precision_config())); + return new_dot; +} + Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); @@ -1638,14 +1730,14 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { - auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(dot->shape().element_type()))); + auto zero = computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( + LiteralUtil::Zero(dot->shape().element_type()))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } - // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are - // rank 2 or below. + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs. if (dot->shape().element_type() != F32 && dot->shape().element_type() != BF16) { return Status::OK(); @@ -1761,13 +1853,16 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, new_dot); } - if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || - dot->shape().rank() > 2) { - if (options_.enable_dot_strength_reduction() && - !options_.is_layout_sensitive()) { - TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); - } - return Status::OK(); + // Simplify dot(reshape(transpose(A)), Const) to: + // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape + // and transpose on the Const side can be constant folded. 
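RemoveDegenerateDimensionFromDot, introduced earlier in this file, drops size-1 dimensions from the dot operands and, when the shapes no longer match, reshapes the new dot's result back to the original shape. A standalone sketch of why that is value-preserving in the simplest case, an lhs of shape [1, K] against an rhs of shape [K, N]; plain C++ with illustrative names, not XLA:

```cpp
// Standalone sketch (not XLA): dropping a degenerate (size-1) dimension from a
// dot operand does not change the computed values; only the result's shape
// changes, which a final reshape restores.
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// dot(lhs[1 x K], rhs[K x N]) -> flat row-major result of shape [1 x N].
std::vector<int> DotDegenerate(const Matrix& lhs, const Matrix& rhs) {
  std::size_t k = rhs.size(), n = rhs[0].size();
  std::vector<int> out(n, 0);
  for (std::size_t j = 0; j < n; ++j)
    for (std::size_t i = 0; i < k; ++i) out[j] += lhs[0][i] * rhs[i][j];
  return out;
}

// dot(lhs[K], rhs[K x N]) -> result of shape [N]; lhs already "squeezed".
std::vector<int> DotSqueezed(const std::vector<int>& lhs, const Matrix& rhs) {
  std::size_t k = rhs.size(), n = rhs[0].size();
  std::vector<int> out(n, 0);
  for (std::size_t j = 0; j < n; ++j)
    for (std::size_t i = 0; i < k; ++i) out[j] += lhs[i] * rhs[i][j];
  return out;
}

int main() {
  Matrix lhs = {{1, 2, 3}};               // shape [1, 3]
  Matrix rhs = {{1, 4}, {2, 5}, {3, 6}};  // shape [3, 2]
  // The [1, 2] result and the [2] result hold the same numbers; reshaping the
  // squeezed dot back to [1, 2] reproduces the original output exactly.
  assert(DotDegenerate(lhs, rhs) == DotSqueezed(lhs[0], rhs));
  std::cout << "degenerate and squeezed dots agree\n";
}
```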
+ TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized, + OptimizeDotOfReorderContractingDims(dot)); + if (dot_of_reorder_optimized) { + VLOG(10) << " Replaced dot " << dot->ToString() + << " with new dot operation: " + << dot_of_reorder_optimized->ToString(); + return ReplaceInstruction(dot, dot_of_reorder_optimized); } TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, @@ -1789,11 +1884,10 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_gather_optimized); } - if (options_.enable_dot_strength_reduction() && - !options_.is_layout_sensitive()) { - TF_ASSIGN_OR_RETURN(bool did_strength_reduction, - HandleDotStrengthReduction(dot)); - if (did_strength_reduction) { + if (options_.enable_dot_strength_reduction()) { + TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions, + RemoveDegenerateDimensionFromDot(dot)); + if (removed_degenerate_dimensions) { return Status::OK(); } } @@ -1845,6 +1939,29 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } + VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]"; + HloInstruction *a, *c1, *c2; + if (Match(multiply, + m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)), + m::Constant(&c2))) || + Match(multiply, + m::Multiply( + m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))), + m::Broadcast(m::ConstantScalar(&c2))))) { + TF_ASSIGN_OR_RETURN(auto* product_of_constants, + MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); + if (ShapeUtil::IsScalar(product_of_constants->shape()) && + !ShapeUtil::IsScalar(multiply->shape())) { + product_of_constants = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + multiply->shape(), product_of_constants, {})); + } + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a, + product_of_constants)); + } + // exp(A) * exp(B) => exp(A+B) if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -1895,20 +2012,18 @@ Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) { if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) { return Status::OK(); } + } - // A || False => A - VLOG(10) << "trying transform [A || False => A]: " - << logical_or->ToString(); - if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) { - return Status::OK(); - } + // A || False => A and A | 0 => A + VLOG(10) << "trying transform [A || False => A]: " << logical_or->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } - // False || A => A - VLOG(10) << "trying transform [False || A => A]: " - << logical_or->ToString(); - if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) { - return Status::OK(); - } + // False || A => A and 0 | A => A + VLOG(10) << "trying transform [False || A => A]: " << logical_or->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); } return Status::OK(); @@ -1954,40 +2069,11 @@ Status AlgebraicSimplifierVisitor::HandleGetTupleElement( namespace { -// Return whether the given reshape instruction leaves the dimensions at the -// given input indices unmodified, and returns their output indices. 
-// -// Example: -// input_dim_indices = {2, 3} -// input shape = T[a, b, x, y, cd] -// output shape = T[ab, x, 1, y, c, d] -// return value = {1, 3} -// -// Precondition: input_dim_indices is sorted. absl::optional> ReshapeLeavesDimensionsUnmodified( const HloInstruction* hlo, absl::Span input_dim_indices) { - CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); - CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); - - std::vector output_dim_indices; - std::vector> unmodified_dims = - ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(), - hlo->shape()); - size_t i = 0; // index to unmodified_dims - for (int64 input_dim_index : input_dim_indices) { - // Search unmodified_dims for input_dim_index. We can search from the last - // matching position because input_dim_indices is guaranteed to be sorted. - while (i < unmodified_dims.size() && - unmodified_dims[i].first < input_dim_index) { - ++i; - } - if (i >= unmodified_dims.size() || - unmodified_dims[i].first != input_dim_index) { - return absl::nullopt; - } - output_dim_indices.push_back(unmodified_dims[i].second); - } - return output_dim_indices; + CHECK_EQ(hlo->opcode(), HloOpcode::kReshape); + return ShapeUtil::ReshapeLeavesDimensionsUnmodified( + hlo->operand(0)->shape(), hlo->shape(), input_dim_indices); } // Returns true if the output of "instruction" is a permutation of the @@ -2138,6 +2224,49 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) { + HloInstruction* lhs; + HloInstruction* rhs; + CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs)))); + + auto replace_with_pred_broadcast = [&](bool value) { + return ReplaceWithNewInstruction( + compare, + HloInstruction::CreateBroadcast( + compare->shape(), + computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(value))), + {})); + }; + if (compare->comparison_direction() == ComparisonDirection::kLt && + lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { + return replace_with_pred_broadcast(false); + } else if (compare->comparison_direction() == ComparisonDirection::kGt && + IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) { + return replace_with_pred_broadcast(false); + } else if (compare->comparison_direction() == ComparisonDirection::kGe && + lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { + return replace_with_pred_broadcast(true); + } else if (compare->comparison_direction() == ComparisonDirection::kLe && + IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) { + return replace_with_pred_broadcast(true); + } + if (lhs == rhs && + primitive_util::IsIntegralType(lhs->shape().element_type())) { + switch (compare->comparison_direction()) { + case ComparisonDirection::kGt: + case ComparisonDirection::kLt: + case ComparisonDirection::kNe: + return replace_with_pred_broadcast(false); + case ComparisonDirection::kEq: + case ComparisonDirection::kGe: + case ComparisonDirection::kLe: + return replace_with_pred_broadcast(true); + } + } + return Status::OK(); +} + // A conversion to the same element type as the operand is a nop and can be // removed. A conversion of a constant can be simplified by making a new // constant. @@ -2183,8 +2312,9 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { // zero. 
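The new HandleCompare hunk above folds comparisons whose outcome is statically known: an iota, whose values are non-negative, compared against zero, and any integral operand compared with itself. A standalone sketch of the self-comparison half of that table; the enum below mirrors the direction names used above, and floating-point operands are deliberately excluded because NaN compares unequal to itself:

```cpp
// Standalone sketch (not XLA): fold "x <cmp> x" for integral x to a constant
// predicate, as in the HandleCompare rewrite above.
#include <iostream>

enum class ComparisonDirection { kEq, kNe, kLt, kLe, kGt, kGe };

// Result of comparing any integral value with itself.
bool FoldSelfComparison(ComparisonDirection direction) {
  switch (direction) {
    case ComparisonDirection::kEq:
    case ComparisonDirection::kLe:
    case ComparisonDirection::kGe:
      return true;   // x == x, x <= x, x >= x always hold
    case ComparisonDirection::kNe:
    case ComparisonDirection::kLt:
    case ComparisonDirection::kGt:
      return false;  // x != x, x < x, x > x never hold
  }
  return false;  // unreachable; keeps compilers quiet
}

int main() {
  std::cout << std::boolalpha
            << FoldSelfComparison(ComparisonDirection::kGe) << " "    // true
            << FoldSelfComparison(ComparisonDirection::kNe) << "\n";  // false
}
```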
auto* iota = Cast(instruction); if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { - auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(iota->shape().element_type()).Clone())); + auto zero = computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( + LiteralUtil::Zero(iota->shape().element_type()).Clone())); return ReplaceWithNewInstruction( iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); } @@ -2307,7 +2437,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { HloInstruction *lhs, *rhs; CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { - auto one = HloInstruction::CreateConstant( + auto one = simplifier_->CreateConstantWithLayoutUpdated( LiteralUtil::One(power->shape().element_type()).Clone()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { @@ -2342,8 +2472,9 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { - auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::One(rhs->shape().element_type()).Clone())); + auto* one = computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( + LiteralUtil::One(rhs->shape().element_type()).Clone())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -2422,14 +2553,16 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( std::vector new_operands; new_operands.reserve(user->operand_count()); + Shape changed_shape; for (HloInstruction* user_operand : user->operands()) { if (user_operand->opcode() == HloOpcode::kBroadcast && ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + changed_shape = ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); new_operands.push_back( computation_->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::ChangeElementType( - operand->shape(), user_operand->shape().element_type()), - user_operand->mutable_operand(0), {}))); + changed_shape, user_operand->mutable_operand(0), {}))); } else { CHECK_EQ(broadcast, user_operand); new_operands.push_back(operand); @@ -2438,11 +2571,11 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( VLOG(4) << "Sinking broadcast after user:"; VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old user: " << user->ToString(); - HloInstruction* new_user = - computation_->AddInstruction(user->CloneWithNewOperands( - ShapeUtil::ChangeElementType(operand->shape(), - user->shape().element_type()), - new_operands)); + changed_shape = ShapeUtil::ChangeElementType(operand->shape(), + user->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); + HloInstruction* new_user = computation_->AddInstruction( + user->CloneWithNewOperands(changed_shape, new_operands)); VLOG(4) << " new user: " << new_user->ToString(); HloInstruction* new_broadcast = computation_->AddInstruction(HloInstruction::CreateBroadcast( @@ -2456,8 +2589,9 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( namespace { template -std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, - HloComputation* computation) { +std::unique_ptr TryRemainderToAnd( + HloInstruction* remainder, HloComputation* 
computation, + AlgebraicSimplifier* simplifier) { HloInstruction *a, *b, *c; CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); @@ -2475,9 +2609,9 @@ std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, computation, a->shape().element_type(), a->shape().dimensions()); auto* dividend_is_negative = - computation->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, - zero_like_a)); + computation->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a, + ComparisonDirection::kLt)); auto* negated_dividend = computation->AddInstruction( HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); @@ -2487,8 +2621,8 @@ std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, a->shape(), HloOpcode::kSelect, dividend_is_negative, negated_dividend, a)); - auto* mask_amount = - computation->AddInstruction(HloInstruction::CreateConstant( + auto* mask_amount = computation->AddInstruction( + simplifier->CreateConstantWithLayoutUpdated( LiteralUtil::CreateR0(b_value - 1))); if (!ShapeUtil::IsScalar(b->shape())) { mask_amount = computation->AddInstruction( @@ -2509,8 +2643,8 @@ std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, } else { uint64 b_value = c->literal().GetFirstElement(); if (IsPowerOfTwo(b_value)) { - HloInstruction* mask_amount = - computation->AddInstruction(HloInstruction::CreateConstant( + HloInstruction* mask_amount = computation->AddInstruction( + simplifier->CreateConstantWithLayoutUpdated( LiteralUtil::CreateR0(b_value - 1))); if (!ShapeUtil::IsScalar(b->shape())) { mask_amount = computation->AddInstruction( @@ -2532,49 +2666,49 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { switch (remainder->shape().element_type()) { case S8: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case S16: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case S32: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case S64: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case U8: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case U16: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case U32: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return ReplaceWithNewInstruction(remainder, std::move(shift)); } break; case U64: if (std::unique_ptr shift = - TryRemainderToAnd(remainder, computation_)) { + TryRemainderToAnd(remainder, computation_, simplifier_)) { return 
ReplaceWithNewInstruction(remainder, std::move(shift)); } break; @@ -2597,7 +2731,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { if (!LayoutUtil::HasLayout(reshaped_shape)) { LayoutUtil::SetToDefaultLayout(&reshaped_shape); } - auto empty_constant = HloInstruction::CreateConstant( + auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated( Literal::CreateFromShape(reshaped_shape)); return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); @@ -2810,6 +2944,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( new_slice_limits), new_slice_operand, new_slice_starts, new_slice_limits, new_slice_stides)); + simplifier_->UpdateLayout(new_slice->mutable_shape()); TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); return true; @@ -2914,8 +3049,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // representative. auto arg = reduce->inputs()[0]; auto init_value = reduce->init_values()[0]; - const Shape& reduce_result_shape = - multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape(); + Shape& reduce_result_shape = const_cast( + multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape()); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); @@ -2960,7 +3095,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } - // TODO(b/112040122): Most of those optimizations below can be done for + // TODO(b/131122694): Most of those optimizations below can be done for // multi-output reduces. if (multi_output_reduce) { return Status::OK(); @@ -3136,6 +3271,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return !absl::c_linear_search(effective_reduce_dims, dim); }, reduce_window->shape()); + simplifier_->UpdateLayout(&reduce_shape); HloInstruction* reduce = computation_->AddInstruction(HloInstruction::CreateReduce( /*shape=*/reduce_shape, @@ -3261,17 +3397,17 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* new_reduce_window_operand; if (convert != nullptr) { - new_reduce_window_operand = - computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(pad_operand->shape(), - convert->shape().element_type()), - pad_operand)); + Shape changed_shape = ShapeUtil::ChangeElementType( + pad_operand->shape(), convert->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); + new_reduce_window_operand = computation_->AddInstruction( + HloInstruction::CreateConvert(changed_shape, pad_operand)); } else { new_reduce_window_operand = pad_operand; } if (is_effective_broadcast()) { - VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; + VLOG(10) << "Replacing pad/reduce-window with broadcast."; auto fadd = [this](std::unique_ptr x) { return computation_->AddInstruction(std::move(x)); }; @@ -3321,6 +3457,22 @@ Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleScatter(HloInstruction* scatter) { + if (ShapeUtil::IsZeroElementArray(scatter->operand(2)->shape()) && + ReplaceInstructionIfSameShape(scatter, scatter->mutable_operand(0))) { + return Status::OK(); + } + if (ShapeUtil::IsZeroElementArray(scatter->operand(1)->shape()) && + SameShape(scatter, scatter->operand(0)) && + SameShape(scatter, scatter->operand(2))) { + return ReplaceWithNewInstruction( + scatter, HloInstruction::CreateMap( + 
scatter->shape(), + {scatter->mutable_operand(0), scatter->mutable_operand(2)}, + scatter->to_apply())); + } + return Status::OK(); +} Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { auto operand = sort->mutable_operand(0); int64 dimension_to_sort = sort->dimensions(0); @@ -3337,8 +3489,8 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { } namespace { -bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape, - absl::Span perm) { +bool OnlyPermutesDegenerateDims(const Shape& shape, + absl::Span perm) { std::vector new_permutation; int64 degenerate_count = 0; for (int64 i = 0; i < perm.size(); ++i) { @@ -3348,7 +3500,7 @@ bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape, ++degenerate_count; } } - return degenerate_count > 1 && absl::c_is_sorted(new_permutation); + return degenerate_count > 0 && absl::c_is_sorted(new_permutation); } } // namespace @@ -3370,8 +3522,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { // Replace transpose with a reshape if more than one degenerate method is // permuted. - if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(), - transpose->dimensions())) { + if (OnlyPermutesDegenerateDims(transpose->shape(), transpose->dimensions())) { return ReplaceWithNewInstruction( transpose, HloInstruction::CreateReshape( transpose->shape(), transpose->mutable_operand(0))); @@ -3614,15 +3765,18 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( // We already checked feature_dimension is most minor, so data in input_shape // and row-major {conv_width,input_channels} are bitwise identical. - const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( + Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( input_shape.element_type(), {conv_width, input_channels}); + simplifier_->UpdateLayout(&new_input_shape); // We already checked input_feature_dimension is more major than // output_feature_dimension, so data in filter_shape and row-major // {input_channels,output_channels} are bitwise identical. 
- const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( + Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( filter_shape.element_type(), {input_channels, output_channels}); - const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + simplifier_->UpdateLayout(&new_filter_shape); + Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( convolution_shape.element_type(), {conv_width, output_channels}); + simplifier_->UpdateLayout(&dot_output_shape); auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); @@ -3647,8 +3801,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( convolution, HloInstruction::CreateBroadcast( convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()))), + computation_->AddInstruction( + simplifier_->CreateConstantWithLayoutUpdated( + LiteralUtil::Zero(convolution->shape().element_type()))), {})); } @@ -3731,7 +3886,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (AlgebraicSimplifierVisitor::Run(comp, options_)) { + if (AlgebraicSimplifierVisitor::Run(comp, options_, this)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index df5a8c2ec14..1768f725b20 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -105,6 +105,15 @@ class AlgebraicSimplifier : public HloModulePass { // computation was changed. StatusOr Run(HloModule* module) override; + // Create constant from literal with tiles and element size updated in the + // constant's layout. 
+ std::unique_ptr CreateConstantWithLayoutUpdated( + Literal literal) { + auto constant = HloInstruction::CreateConstant(std::move(literal)); + UpdateLayout(constant->mutable_shape()); + return constant; + } + private: AlgebraicSimplifierOptions options_; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 7f399ce0f11..e37b69c5cba 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -295,6 +295,70 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { EXPECT_EQ(computation->root_instruction(), zero); } +TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeConstants) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + c0 = f32[] constant(2.0) + c1 = f32[] constant(3.0) + multiply0 = f32[] multiply(p0, c0) + ROOT multiply1 = f32[] multiply(multiply0, c1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), + m::Multiply(m::ConstantScalar(2.0), + m::ConstantScalar(3.0))))); +} + +TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c0 = f32[] constant(2.0) + c1 = f32[] constant(3.0) + b0 = f32[4] broadcast(c0), dimensions={} + b1 = f32[4] broadcast(c1), dimensions={} + multiply0 = f32[4] multiply(p0, b0) + ROOT multiply1 = f32[4] multiply(multiply0, b1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply( + m::Parameter(0), m::Broadcast(m::Multiply(m::ConstantScalar(2.0), + m::ConstantScalar(3.0)))))); +} + +TEST_F(AlgebraicSimplifierTest, + MultiplyReassociateMultiplyOfConstantAndBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + c0 = f32[4] constant({2.0, 3.0, 4.0, 5.0}) + c1 = f32[] constant(3.0) + c2 = f32[] constant(4.0) + b0 = f32[4] broadcast(c1), dimensions={} + b1 = f32[4] broadcast(c2), dimensions={} + multiply0 = f32[4] multiply(c0, b0) + ROOT multiply1 = f32[4] multiply(multiply0, b1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply( + m::Constant(), m::Broadcast(m::Multiply(m::ConstantScalar(3.0), + m::ConstantScalar(4.0)))))); +} + // Test that select(true, a, b) is simplified to a TEST_F(AlgebraicSimplifierTest, SelectTrue) { Shape r0s32 = ShapeUtil::MakeShape(S32, {}); @@ -446,6 +510,27 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { m::Add(m::Op().Is(constant1), m::Op().Is(constant2))))); } +TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c0 = f32[] constant(1.0) + c1 = f32[] constant(2.0) + b0 = f32[4] broadcast(c0), dimensions={} + b1 = f32[4] broadcast(c1), dimensions={} + add0 = f32[4] add(p0, b0) + ROOT add1 = f32[4] add(add0, b1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, 
ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Add(m::ConstantScalar(1.0), + m::ConstantScalar(2.0)))))); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); @@ -640,6 +725,25 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { m::Negate(m::Op().Is(constant))))); } +// Test that A - Broadcast(Const) is canonicalized to A + Broadcast(-Const). +TEST_F(AlgebraicSimplifierTest, SubBroadcastConstCanonicalization) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c = f32[] constant(0.125) + b = f32[4] broadcast(c), dimensions={} + ROOT sub = f32[4] subtract(p0, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Negate(m::ConstantScalar(0.125)))))); +} + // Test that (A/B)/C is simplified to A/(B*C). TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { auto m = CreateNewVerifiedModule(); @@ -853,6 +957,26 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { GmockMatch(m::Multiply(m::Parameter(0), m::Constant()))); } +// A / Broadcast(Const) => A * Broadcast(InvertedConst) +TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) { + const char* kModuleStr = R"( + HloModule m + test { + p = f32[4] parameter(0) + c = f32[] constant(256.0) + b = f32[4] broadcast(c), dimensions={} + ROOT d = f32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply( + m::Parameter(0), + m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f))))); +} + // pow(pow(A, X), Y) => pow(A, X*Y) TEST_F(AlgebraicSimplifierTest, PowerOfPower) { auto m = CreateNewVerifiedModule(); @@ -4317,7 +4441,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { int m, k, n; PrimitiveType element_type; std::tie(m, k, n, element_type) = GetParam(); - std::vector lhs_dims = {1, 3, 5}; + std::vector lhs_dims = {2, 3, 5}; std::vector rhs_dims = lhs_dims; std::vector output_dims = lhs_dims; if (m > 0) { @@ -4360,6 +4484,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1; EXPECT_EQ(changed, dot_should_be_transformed); + TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get())); bool has_no_dot = true; for (const auto& hlo : computation->instructions()) { if (hlo->opcode() == HloOpcode::kDot) { @@ -4414,11 +4539,17 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); + // First pass of algebraic simplifier will remove degenerate dimensions + // and optimize dot(transpose(x),transpose(y)) TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified 
= dot_should_be_transformed || (transpose_lhs && transpose_rhs);
 EXPECT_EQ(changed, computation_should_be_modified);
+  // The second pass of the algebraic simplifier will remove dots without
+  // non-contracting dimensions or contracting dimensions.
+  TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
+  EXPECT_EQ(changed, dot_should_be_transformed);
 bool has_no_dot = true;
 for (const auto& hlo : computation->instructions()) {
 if (hlo->opcode() == HloOpcode::kDot) {
@@ -4966,5 +5097,327 @@ TEST_F(AlgebraicSimplifierTest, RecipRsqrt) {
 m::Sqrt(m::Parameter(0)))));
 }
+TEST_F(AlgebraicSimplifierTest, CopyReshape) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      p0 = f32[168,168,48,48]{3,2,1,0} parameter(0)
+      r0 = f32[1,168,168,2304]{3,2,1,0} reshape(p0)
+      ROOT c0 = f32[1,168,168,2304]{3,0,2,1} copy(r0)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  Shape result_shape = m->entry_computation()->root_instruction()->shape();
+  AlgebraicSimplifierOptions options(
+      [](const Shape&, const Shape&) { return false; });
+  options.set_is_layout_sensitive(true);
+  ASSERT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(m::Reshape(m::Parameter(0)).WithShapeEqualTo(&result_shape)));
+}
+
+TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RL) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
+      t0 = f32[2, 2, 3] parameter(0)
+      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
+      lhs = f32[2, 6] reshape(t1)
+      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6});
+  auto shape2 = ShapeUtil::MakeShape(F32, {3, 2, 2});
+  auto shape3 = ShapeUtil::MakeShape(F32, {2, 3, 2});
+  // The transformation of moving transpose and reshape to the constant side
+  // is layout insensitive. We ignore layout when checking shapes.
+ const HloInstruction* transpose; + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Dot( + m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1), + m::Reshape(m::Transpose(&transpose, + m::Reshape(m::Constant()) + .WithShapeCompatibleTo(&shape2)) + .WithShapeCompatibleTo(&shape3))))); + EXPECT_THAT(transpose->dimensions(), ElementsAre(1, 0, 2)); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RR) { + const char* kModuleStr = R"( + HloModule m + test { + rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6}, + {1, 1, 1, 1, 1, 1}}) + t0 = f32[2, 2, 3] parameter(0) + t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1} + lhs = f32[2, 6] reshape(t1) + ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto shape1 = ShapeUtil::MakeShape(F32, {2, 6}); + auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2}); + auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3}); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Dot( + m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1), + m::Reshape(m::Transpose(m::Reshape(m::Constant()) + .WithShapeCompatibleTo(&shape2)) + .WithShapeCompatibleTo(&shape3))))); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR) { + const char* kModuleStr = R"( + HloModule m + test { + rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6}, + {1, 1, 1, 1, 1, 1}}) + t0 = f32[2, 3, 2] parameter(0) + t1 = f32[3, 2, 2] transpose(t0), dimensions={1, 0, 2} + lhs = f32[6, 2] reshape(t1) + ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto shape1 = ShapeUtil::MakeShape(F32, {6, 2}); + auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2}); + auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3}); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Dot( + m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1), + m::Reshape(m::Transpose(m::Reshape(m::Constant()) + .WithShapeCompatibleTo(&shape2)) + .WithShapeCompatibleTo(&shape3))))); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR2) { + const char* kModuleStr = R"( + HloModule m + test { + rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{7, 7},{8, 8}}) + t0 = f32[2, 2, 2, 2] parameter(0) + t1 = f32[2, 2, 2, 2] transpose(t0), dimensions={0, 2, 3, 1} + lhs = f32[2, 8] reshape(t1) + ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto shape1 = ShapeUtil::MakeShape(F32, {2, 8}); + auto shape2 = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); + const HloInstruction* transpose; + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Dot( + m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1), + m::Reshape(m::Transpose( + &transpose, + m::Reshape(m::Constant()).WithShapeCompatibleTo(&shape2)))))); + EXPECT_THAT(transpose->dimensions(), ElementsAre(2, 0, 1, 3)); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_MM) { + const char* kModuleStr = R"( + HloModule m + test { + rhs = f32[2, 
6, 2] constant({{{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}, + {{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}}) + t0 = f32[2, 2, 3, 2] parameter(0) + t1 = f32[2, 3, 2, 2] transpose(t0), dimensions={0, 2, 1, 3} + lhs = f32[2, 6, 2] reshape(t1) + ROOT dot.5 = f32[2, 2, 2] dot(lhs, rhs), lhs_batch_dims={0}, lhs_contracting_dims={1}, + rhs_batch_dims={0}, rhs_contracting_dims={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto shape1 = ShapeUtil::MakeShape(F32, {2, 6, 2}); + auto shape2 = ShapeUtil::MakeShape(F32, {2, 3, 2, 2}); + auto shape3 = ShapeUtil::MakeShape(F32, {2, 2, 3, 2}); + const HloInstruction* transpose; + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Dot( + m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1), + m::Reshape(m::Transpose(&transpose, + m::Reshape(m::Constant()) + .WithShapeCompatibleTo(&shape2)) + .WithShapeCompatibleTo(&shape3))))); + EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3)); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegTranspose) { + const char* kModuleStr = R"( + HloModule m + test { + rhs = f32[12, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}) + t0 = f32[3, 4, 2] parameter(0) + t1 = f32[2, 3, 4] transpose(t0), dimensions={2, 0, 1} + lhs = f32[2, 12] reshape(t1) + ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + // Transpose affects non-contracting dimension. The transpose and reshape + // should not be moved to the constant side. + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegReshape) { + const char* kModuleStr = R"( + HloModule m + test { + rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{1, 1},{2, 2},{3, 3},{4, 4}}) + t0 = f32[2, 4, 3] parameter(0) + t1 = f32[2, 3, 4] transpose(t0), dimensions={0, 2, 1} + lhs = f32[3, 8] reshape(t1) + ROOT dot.5 = f32[3, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + // Reshape affects non-contracting dimensions. The transpose and reshape + // should not be moved to the constant side. + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegConstant) { + const char* kModuleStr = R"( + HloModule m + test { + t0 = f32[2, 3, 4] parameter(0) + t1 = f32[2, 4, 3] transpose(t0), dimensions={0, 2, 1} + lhs = f32[2, 12] reshape(t1) + rhs = f32[12, 2] parameter(1) + ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + // Both operands are non-constant, so the optimization should not happen. 
+ ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegLayout) { + const char* kModuleStr = R"( + HloModule m + test { + rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}}) + t0 = f32[2, 2, 3] parameter(0) + t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1} + lhs = f32[2, 6] reshape(t1) + ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + // We disable converting reshape to bitcast to make sure algsimp pass does + // not catch the reshape in this test, then we can simply check if algsimp + // pass does not make any change. + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_is_layout_sensitive(true); + // The transformation of moving transpose and reshape to the constant side is + // layout insensitive. It should not happen if AlgebraicSimplifier is set up + // to be layout sensitive. + ASSERT_FALSE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDimsNoChange) { + // This isn't transformed (notice that the relative order of the `2` and `3` + // dims doesn't change, so there's no opportunity here), but it's nonetheless + // an interesting testcase because of the presence of the size-1 dimensions. + const char* kModuleStr = R"( + HloModule m + test { + param = f32[1,2,5,3] parameter(0) + transpose = f32[1,5,2,3] transpose(param), dimensions={0,2,1,3} + reshape = f32[5,6] reshape(transpose) + constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + ROOT dot = f32[5,4] dot(reshape, constant), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) { + const char* kModuleStr = R"( + HloModule m + test { + param = f32[1,2,3,5] parameter(0) + transpose = f32[1,3,2,5] transpose(param), dimensions={0,2,1,3} + reshape = f32[6,5] reshape(transpose) + constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + ROOT dot = f32[5,4] dot(reshape, constant), + lhs_contracting_dims={0}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto shape1 = ShapeUtil::MakeShape(F32, {6, 5}); + auto shape2 = ShapeUtil::MakeShape(F32, {1, 3, 2, 4}); + auto shape3 = ShapeUtil::MakeShape(F32, {1, 2, 3, 4}); + const HloInstruction* transpose; + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Dot( + m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1), + m::Reshape(m::Transpose(&transpose, + m::Reshape(m::Constant()) + .WithShapeCompatibleTo(&shape2)) + .WithShapeCompatibleTo(&shape3))))); + EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3)); +} + +TEST_F(AlgebraicSimplifierTest, + DotContractingReorder_NoChangeInContractingDimsOrder) { + // No optimization opportunity here because the transpose does not reorder the + // contracting dims. 
+ const char* kModuleStr = R"( + HloModule m + test { + param = f32[2,5,1,3] parameter(0) + transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3} + reshape = f32[5,6] reshape(transpose) + constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + ROOT dot = f32[5,4] dot(reshape, constant), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, CompareIota) { + const char* kModuleStr = R"( + HloModule m + test { + zero = s32[] constant(0) + iota = s32[128] iota(), iota_dimension=0 + broad = s32[128] broadcast(zero), dimensions={} + ROOT compare = pred[128] compare(iota, broad), direction=LT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::ConstantScalar(false)))); +} + +TEST_F(AlgebraicSimplifierTest, CompareSame) { + const char* kModuleStr = R"( + HloModule m + test { + param = s32[123] parameter(0) + ROOT compare = pred[123] compare(param, param), direction=GE + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::ConstantScalar(true)))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc new file mode 100644 index 00000000000..e541bfea11f --- /dev/null +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier.cc @@ -0,0 +1,121 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_reduce_simplifier.h" + +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +StatusOr AllReduceSimplifier::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(auto replication, HloReplicationAnalysis::Run(module)); + std::vector all_reduces_to_replace; + for (auto computation : module->computations()) { + for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { + if (!inst->shape().IsArray()) { + // We currently do not change tuple-shaped all-reduce. 
+ continue; + } + if (inst->IsCrossReplicaAllReduce() && + replication->HloInstructionIsReplicatedAt(inst->operand(0), {})) { + all_reduces_to_replace.push_back(inst); + } + } + } + + bool changed = false; + if (all_reduces_to_replace.empty()) { + return changed; + } + + // Returns the size of a replica group if all groups have the same size, or -1 + // if they have different sizes. + auto get_replica_group_size = + [this](const HloInstruction* all_reduce) -> int64 { + if (all_reduce->replica_groups().empty()) { + return replica_count_; + } + int64 replica_group_size = -1; + for (const auto& group : all_reduce->replica_groups()) { + if (replica_group_size == -1) { + replica_group_size = group.replica_ids_size(); + } else if (replica_group_size != group.replica_ids_size()) { + return -1; + } + } + return replica_group_size; + }; + + for (auto all_reduce : all_reduces_to_replace) { + if (all_reduce->to_apply()->instruction_count() != 3 || + all_reduce->to_apply()->num_parameters() != 2) { + continue; + } + HloInstruction* replacement; + switch (all_reduce->to_apply()->root_instruction()->opcode()) { + case HloOpcode::kAdd: { + int64 replica_group_size = get_replica_group_size(all_reduce); + if (replica_group_size == -1) { + continue; + } + // Create the multiplier: + // broadcast(convert_to_matching_type(s32 group size)) + auto multiplier = + all_reduce->parent()->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(replica_group_size))); + if (all_reduce->shape().element_type() != S32) { + multiplier = all_reduce->parent()->AddInstruction( + HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType( + multiplier->shape(), all_reduce->shape().element_type()), + multiplier)); + } + if (all_reduce->shape().rank() > 0) { + multiplier = all_reduce->parent()->AddInstruction( + HloInstruction::CreateBroadcast(all_reduce->shape(), multiplier, + {})); + } + replacement = + all_reduce->parent()->AddInstruction(HloInstruction::CreateBinary( + all_reduce->shape(), HloOpcode::kMultiply, + all_reduce->mutable_operand(0), multiplier)); + break; + } + case HloOpcode::kMinimum: + case HloOpcode::kMaximum: + case HloOpcode::kOr: + case HloOpcode::kAnd: + replacement = all_reduce->mutable_operand(0); + break; + default: + continue; + } + VLOG(2) << "Replacing " << all_reduce->ToString() << " with " + << replacement->ToString(); + TF_RETURN_IF_ERROR(all_reduce->ReplaceAllUsesWith(replacement)); + changed = true; + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier.h b/tensorflow/compiler/xla/service/all_reduce_simplifier.h new file mode 100644 index 00000000000..f2d2294bd6d --- /dev/null +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// A pass that detects all-reduces whose inputs are already the same across +// replicas using the replication analysis, then replaces those all-reduces with +// local computations. E.g., a sum all-reduce on replicated input will be +// replaced by a multiply with the replica count. +class AllReduceSimplifier : public HloModulePass { + public: + explicit AllReduceSimplifier(int64 replica_count) + : replica_count_(replica_count) {} + ~AllReduceSimplifier() override = default; + absl::string_view name() const override { return "all-reduce-simp"; } + + // Run all-reduce simplification on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + int64 replica_count_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc new file mode 100644 index 00000000000..2e03e67c59c --- /dev/null +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc @@ -0,0 +1,171 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_reduce_simplifier.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using AllReduceSimplifierTest = HloTestBase; + +TEST_F(AllReduceSimplifierTest, ReplicatedParameters) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +max { + a.1 = f32[] parameter(0) + b.1 = f32[] parameter(1) + ROOT max = f32[] maximum(a.1, b.1) +} + +min { + a.2 = f32[] parameter(0) + b.2 = f32[] parameter(1) + ROOT min = f32[] minimum(a.2, b.2) +} + +sum.1 { + a.3 = f32[] parameter(0) + b.3 = f32[] parameter(1) + ROOT add.1 = f32[] add(a.3, b.3) +} + +test { + p0 = f32[8,16] parameter(0), parameter_replication={true} + p1 = f32[8,16] parameter(1), parameter_replication={false} + p2 = f32[] parameter(2), parameter_replication={true} + all-reduce = f32[8,16] all-reduce(p0), replica_groups={}, to_apply=sum + all-reduce.1 = f32[8,16] all-reduce(p0), replica_groups={}, to_apply=max + all-reduce.2 = f32[8,16] all-reduce(p1), replica_groups={}, to_apply=min + all-reduce.3 = f32[] all-reduce(p2), replica_groups={}, to_apply=sum.1 + ROOT tuple = (f32[8,16], f32[8,16], f32[8,16], f32[]) tuple(all-reduce, all-reduce.1, all-reduce.2, all-reduce.3) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + AllReduceSimplifier simplifier(/*replica_count=*/8); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::MultiplyAnyOrder(m::Parameter(0), + m::Broadcast(m::Convert(m::ConstantScalar(8)))), + m::Parameter(0), m::AllReduce(m::Parameter(1)), + m::MultiplyAnyOrder(m::Parameter(2), + m::Convert(m::ConstantScalar(8)))))); +} + +TEST_F(AllReduceSimplifierTest, AllReduceAfterAllReduce) { + const char* kModuleStr = R"( +HloModule m + +max { + a.1 = f32[] parameter(0) + b.1 = f32[] parameter(1) + ROOT max = f32[] maximum(a.1, b.1) +} + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + all-reduce = f32[8,16] all-reduce(p0), replica_groups={}, to_apply=max + ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), replica_groups={}, to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + AllReduceSimplifier simplifier(/*replica_count=*/8); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AllReduce(m::Parameter(0)), + m::Broadcast(m::Convert(m::ConstantScalar(8)))))); +} + 
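// Illustrative sketch, not part of the patch: the SubgroupAllReduce test
// below expects the scale factor to come from the replica group size (4)
// rather than the overall replica count (8). A minimal standalone version of
// that size computation, with hypothetical names and plain std types instead
// of the XLA helpers used by the pass:
#include <cstdint>
#include <vector>

// Returns the common group size, the full replica count when no groups are
// given, or -1 when the groups have different sizes (in which case a sum
// all-reduce of replicated data cannot be folded into a single multiply).
int64_t UniformReplicaGroupSize(
    const std::vector<std::vector<int64_t>>& replica_groups,
    int64_t replica_count) {
  if (replica_groups.empty()) return replica_count;
  int64_t size = -1;
  for (const std::vector<int64_t>& group : replica_groups) {
    if (size == -1) {
      size = static_cast<int64_t>(group.size());
    } else if (size != static_cast<int64_t>(group.size())) {
      return -1;
    }
  }
  return size;
}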
+TEST_F(AllReduceSimplifierTest, SubgroupAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +max { + a.1 = f32[] parameter(0) + b.1 = f32[] parameter(1) + ROOT max = f32[] maximum(a.1, b.1) +} + +min { + a.2 = f32[] parameter(0) + b.2 = f32[] parameter(1) + ROOT min = f32[] minimum(a.2, b.2) +} + +test { + p0 = f32[8,16] parameter(0), parameter_replication={true} + p1 = f32[8,16] parameter(1), parameter_replication={false} + all-reduce = f32[8,16] all-reduce(p0), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=sum + all-reduce.1 = f32[8,16] all-reduce(p0), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=max + all-reduce.2 = f32[8,16] all-reduce(p1), replica_groups={{0,1,2,3},{4,5,6,7}}, to_apply=min + ROOT tuple = (f32[8,16], f32[8,16], f32[8,16]) tuple(all-reduce, all-reduce.1, all-reduce.2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + AllReduceSimplifier simplifier(/*replica_count=*/8); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::MultiplyAnyOrder(m::Parameter(0), + m::Broadcast(m::Convert(m::ConstantScalar(4)))), + m::Parameter(0), m::AllReduce(m::Parameter(1))))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 6cb0e985e57..ea56c75b2f2 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -20,13 +20,13 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { @@ -221,8 +221,8 @@ void AllocationTracker::AddAllocationOrIncrementRefCount( auto it = allocation_map.find(device_memory.opaque()); if (it == allocation_map.end()) { allocation_map[device_memory.opaque()] = { - OwningDeviceMemory(device_memory, device_ordinal, - backend_->memory_allocator()), + se::OwningDeviceMemory(device_memory, device_ordinal, + backend_->memory_allocator()), /*ref_count=*/1}; } else { it->second.ref_count++; diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 98d1a302a9f..6e7f9fdfc13 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -77,7 +77,7 @@ class AllocationTracker { // Data structure encapsulating single memory allocation on the device. struct Allocation { // The pointer to this allocation. - OwningDeviceMemory device_memory; + se::OwningDeviceMemory device_memory; // This is the number of times this memory allocation is referred to by // registered data handles. 
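// Illustrative sketch of the {owning buffer, ref_count} bookkeeping that
// AddAllocationOrIncrementRefCount above maintains per device pointer. The
// types here are hypothetical stand-ins, not the real se::OwningDeviceMemory
// or AllocationTracker API.
#include <cstdint>
#include <unordered_map>

struct FakeAllocation {
  int device_ordinal;   // stands in for the owning device memory handle
  int64_t ref_count;    // number of registered data handles referring to it
};

void AddOrIncrement(std::unordered_map<const void*, FakeAllocation>& map,
                    const void* opaque, int device_ordinal) {
  auto it = map.find(opaque);
  if (it == map.end()) {
    map[opaque] = FakeAllocation{device_ordinal, /*ref_count=*/1};
  } else {
    ++it->second.ref_count;
  }
}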
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 52d6982c70f..dbd89911d92 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -36,8 +36,7 @@ namespace m = match; // Checks if the argument instruction is an AllReduce, followed by a certain // sequence of instructions and then a CRS. It must be possible to move -// the AR past each instruction in the sequence. Returns the CRS, which is the -// last instruction in the sequence. +// the AR past each instruction in the sequence. absl::optional ArCrsCombiner::MatchesArCrsPattern( HloInstruction* instruction) { auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { @@ -86,7 +85,9 @@ absl::optional ArCrsCombiner::MatchesArCrsPattern( } if (!Cast(next)->IsNoop() && computation_is_addition(next->called_computations()[0])) { - return absl::optional(ArCrsPair(instruction, next, distance)); + ArCrsPair pair(instruction, next, distance); + VLOG(2) << "ArCrsPair matching pattern: " << pair.ToString(); + return pair; } else { return absl::nullopt; } @@ -106,54 +107,124 @@ absl::optional ArCrsCombiner::WhileFromBodyParameter( return absl::nullopt; } -std::vector ArCrsCombiner::GetAllTuples( +absl::optional ArCrsCombiner::ConditionalFromBodyParameter( HloInstruction* instruction) { - if (instruction->opcode() == HloOpcode::kTuple) { - return {instruction}; - } - if (instruction->opcode() == HloOpcode::kDomain) { - return GetAllTuples(instruction->operands()[0]); - } - if (instruction->opcode() == HloOpcode::kParameter) { - auto maybe_while = WhileFromBodyParameter(instruction); - if (!maybe_while) { - return {}; + CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); + HloComputation* computation = instruction->parent(); + auto caller_instructions = call_graph_->GetComputationCallers(computation); + if (caller_instructions.size() == 1) { + auto caller_instruction = caller_instructions[0]; + if (caller_instruction->opcode() == HloOpcode::kConditional) { + return caller_instruction; } - auto while_instr = *maybe_while; - auto init_tuples = GetAllTuples(while_instr->while_init()); - auto body_tuples = - GetAllTuples(while_instr->while_body()->root_instruction()); - if (init_tuples.empty() || body_tuples.empty()) { - return {}; - } - init_tuples.insert(init_tuples.end(), body_tuples.begin(), - body_tuples.end()); - return init_tuples; } - if (instruction->opcode() == HloOpcode::kGetTupleElement) { - std::vector result_tuples; - for (auto tuple : GetAllTuples(instruction->operands()[0])) { - auto tmp_tuples = - GetAllTuples(tuple->mutable_operand(instruction->tuple_index())); - if (tmp_tuples.empty()) { - return {}; + return absl::nullopt; +} + +absl::optional> ArCrsCombiner::GetAllTuples( + HloInstruction* instruction, + absl::flat_hash_set* visited) { + if (visited->find(instruction) != visited->end()) { + return std::vector(); + } + visited->insert(instruction); + + switch (instruction->opcode()) { + case HloOpcode::kTuple: { + return std::vector({instruction}); + } + case HloOpcode::kDomain: { + return GetAllTuples(instruction->operands()[0], visited); + } + case HloOpcode::kParameter: { + auto maybe_while = WhileFromBodyParameter(instruction); + if (maybe_while) { + auto while_instr = *maybe_while; + auto init_tuples = GetAllTuples(while_instr->while_init(), visited); + auto body_tuples = GetAllTuples( + while_instr->while_body()->root_instruction(), visited); + if (!init_tuples || 
!body_tuples) { + return absl::nullopt; + } + auto result = *init_tuples; + result.insert(result.end(), body_tuples->begin(), body_tuples->end()); + return result; } - result_tuples.insert(result_tuples.end(), tmp_tuples.begin(), - tmp_tuples.end()); + auto maybe_conditional = ConditionalFromBodyParameter(instruction); + if (maybe_conditional) { + auto cond_instr = *maybe_conditional; + std::vector tuples; + for (int64 i = 0; i < cond_instr->branch_computations().size(); ++i) { + if (cond_instr->branch_computation(i)->parameter_instruction(0) == + instruction) { + // If the same computation is used for more than one branch of the + // conditional, we collect the arguments that flow to the + // computation from all branches. + auto branch_tuples = + GetAllTuples(cond_instr->mutable_operand(i + 1), visited); + if (!branch_tuples) { + return absl::nullopt; + } + tuples.insert(tuples.end(), branch_tuples->begin(), + branch_tuples->end()); + } + } + return tuples; + } + return absl::nullopt; } - return result_tuples; + case HloOpcode::kGetTupleElement: { + std::vector result_tuples; + auto tuples = GetAllTuples(instruction->operands()[0], visited); + if (!tuples) { + return absl::nullopt; + } + for (auto tuple : *tuples) { + auto tmp_tuples = GetAllTuples( + tuple->mutable_operand(instruction->tuple_index()), visited); + if (!tmp_tuples) { + return absl::nullopt; + } + result_tuples.insert(result_tuples.end(), tmp_tuples->begin(), + tmp_tuples->end()); + } + return result_tuples; + } + case HloOpcode::kConditional: { + std::vector result_tuples; + for (HloComputation* body : instruction->branch_computations()) { + if (body->root_instruction()->opcode() != HloOpcode::kTuple) { + return absl::nullopt; + } + result_tuples.push_back(body->root_instruction()); + } + return result_tuples; + } + case HloOpcode::kWhile: { + auto init_tuples = GetAllTuples(instruction->while_init(), visited); + auto body_tuples = + GetAllTuples(instruction->while_body()->root_instruction(), visited); + if (!init_tuples || !body_tuples) { + return absl::nullopt; + } + auto result = *init_tuples; + result.insert(result.end(), body_tuples->begin(), body_tuples->end()); + return result; + } + default: + return absl::nullopt; } - return {}; } bool ArCrsCombiner::TupleElementsComputeSameValue( HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2, absl::flat_hash_map* visited_pairs) { - auto tuples = GetAllTuples(tuple_shaped_instruction); - if (tuples.empty()) { + absl::flat_hash_set visited; + auto tuples = GetAllTuples(tuple_shaped_instruction, &visited); + if (!tuples) { return false; } - for (auto tuple : tuples) { + for (auto tuple : *tuples) { CHECK_EQ(tuple->opcode(), HloOpcode::kTuple); if (!InstructionsComputeSameValue(tuple->mutable_operand(i1), tuple->mutable_operand(i2), @@ -263,6 +334,8 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { if (prev_distance < pair.distance) { // The current AR's distance to CRS is longer than the previously // tracked AR, so we discard the previous AR. + VLOG(2) << "Replacing ArCrsPair: " << prev_pair.ToString() + << " with ArCrsPair: " << pair.ToString(); all_reduce_map_.erase(prev_ar_id); discarded_ar_ids.insert(prev_ar_id); all_reduce_map_[ar_id].push_back(pair); @@ -291,6 +364,8 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { auto all_reduce_id = it.first; + VLOG(2) << "KeepProvablyEqualInstructionGroups. 
Checking ar_id: " + << all_reduce_id << "\n"; auto pairs_vec = it.second; CHECK_EQ(pairs_vec.size(), num_spatial_partitions_); auto instr_0 = pairs_vec[0].ar; @@ -302,6 +377,8 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { while (true) { if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { all_reduce_map_.erase(all_reduce_id); + VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased ar_id: " + << all_reduce_id << "\n"; break; } if (next_0->IsCrossReplicaAllReduce()) { diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index f503e1d5f2b..4d17d5d8a31 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -93,8 +93,21 @@ class ArCrsCombiner : public HloModulePass { : ar(all_reduce), crs(cross_replica_sum), distance(dist) {} string ToString() { - return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(), - ", distance: ", distance, ")"); + std::vector pieces; + pieces.push_back("("); + HloInstruction* instruction = ar; + while (instruction != crs) { + pieces.push_back(instruction->name()); + pieces.push_back(","); + instruction = instruction->users()[0]; + } + pieces.push_back(instruction->name()); + pieces.push_back(")[id:"); + pieces.push_back(std::to_string(*(ar->all_reduce_id()))); + pieces.push_back(",dist:"); + pieces.push_back(std::to_string(distance)); + pieces.push_back("]"); + return absl::StrJoin(pieces, ""); } }; @@ -106,10 +119,19 @@ class ArCrsCombiner : public HloModulePass { absl::optional WhileFromBodyParameter( HloInstruction* instruction); + // If the passed instruction is a parameter in one of the branch computations, + // and the branch body is only called by a single instruction, return the + // conditional instruction. + absl::optional ConditionalFromBodyParameter( + HloInstruction* instruction); + // Returns a vector of tuple instructions. // If all instructions that flow to "instruction" are tuples, return them. - // Otherwise, return an empty vector. - std::vector GetAllTuples(HloInstruction* instruction); + // Otherwise, return absl::nullopt. Returns an empty vector if the instruction + // is already in the visited set. + absl::optional> GetAllTuples( + HloInstruction* instruction, + absl::flat_hash_set* visited); // Checks whether two different elements in the same tuple compute the same // value. 
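// Illustrative sketch of the tri-state contract documented for GetAllTuples
// above, on a toy graph type rather than the HLO classes (names hypothetical):
// absl::nullopt means "could not prove anything", an empty vector means the
// node was already visited and contributes nothing new (this is what cuts
// cycles through while bodies), and a non-empty vector lists the producing
// tuples found so far.
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"

struct ToyNode {
  bool is_tuple = false;
  std::vector<ToyNode*> inputs;
};

absl::optional<std::vector<ToyNode*>> CollectTuples(
    ToyNode* node, absl::flat_hash_set<ToyNode*>* visited) {
  if (!visited->insert(node).second) {
    return std::vector<ToyNode*>();  // already seen: terminate the cycle
  }
  if (node->is_tuple) {
    return std::vector<ToyNode*>({node});
  }
  if (node->inputs.empty()) {
    return absl::nullopt;  // opaque producer: give up
  }
  std::vector<ToyNode*> result;
  for (ToyNode* input : node->inputs) {
    auto sub = CollectTuples(input, visited);
    if (!sub) {
      return absl::nullopt;  // one unprovable path poisons the whole query
    }
    result.insert(result.end(), sub->begin(), sub->end());
  }
  return result;
}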
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 9c9db74fd2f..0ea26f63b95 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -221,7 +221,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { @@ -258,7 +258,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { @@ -296,7 +296,7 @@ HloModule foobar %x = (f32[2,2], f32[2,2]) parameter(0) %constant.0 = s32[] constant(0) %constant.1 = s32[] constant(1) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT } %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { @@ -326,6 +326,55 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestNestedWhile) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + ROOT %t = pred[] constant(true) +} + +%body_inner (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) + %gte.1 = f32[2,2] get-tuple-element(%x), index=0 + %gte.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%gte.1, %constant.f32) + %add.2 = f32[2,2] add(%gte.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +%body_outer (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %gte.1 = f32[2,2] get-tuple-element(%x), index=0 + %gte.2 = f32[2,2] get-tuple-element(%x), index=1 + %init = (f32[2,2], f32[2,2]) tuple(%gte.1, %gte.2) + ROOT %while.1 = (f32[2,2], f32[2,2]) while(%init), condition=%condition, + body=%body_inner +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, + body=%body_outer +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + + auto root_while = module->entry_computation()->root_instruction(); + auto inner_while = root_while->while_body()->root_instruction(); + auto i1 = inner_while->while_body()->root_instruction()->operands()[0]; + auto i2 = inner_while->while_body()->root_instruction()->operands()[1]; + // They are the same because the same constant {{3, 4}, {5, 6}} flows to both, + // and we add the same number {{1, 2}, {3, 4}} to both in each iteration + // of the inner while. 
+ EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + void CompareReplicaGroups(const std::vector& groups_before, const std::vector& groups_after) { ASSERT_EQ(groups_before.size(), groups_after.size()); @@ -1173,5 +1222,47 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { EXPECT_FALSE(changed); } +TEST_F(ArCrsCombinerTest, SameValueTestConditional) { + const char* module_str = R"( +HloModule foobar + +branch_true { + pt = (f32[2,4], f32[2,4]) parameter(0) + gte.0 = f32[2,4] get-tuple-element(pt), index=0 + gte.1 = f32[2,4] get-tuple-element(pt), index=1 + ROOT tuple.t = (f32[2,4], f32[2,4]) tuple(gte.1, gte.0) +} + +branch_false { + pf = (f32[2,4], f32[2,4]) parameter(0) + gte.0 = f32[2,4] get-tuple-element(pf), index=0 + gte.1 = f32[2,4] get-tuple-element(pf), index=1 + add = f32[2,4] add(gte.1, gte.1) + ROOT tuple.f = (f32[2,4], f32[2,4]) tuple(gte.0, add) +} + +ENTRY Parameters1.v4 { + constant = pred[] constant(true) + p = f32[2,4] parameter(0) + tuple = (f32[2,4], f32[2,4]) tuple(p, p) + ROOT conditional = (f32[2,4], f32[2,4]) conditional(constant, tuple, tuple), true_computation=branch_true, false_computation=branch_false +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto cond = module->entry_computation()->root_instruction(); + + auto branch_true = cond->branch_computation(0)->root_instruction(); + auto t0 = branch_true->mutable_operand(0); + auto t1 = branch_true->mutable_operand(1); + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(t0, t1)); + + auto branch_false = cond->branch_computation(1)->root_instruction(); + auto f0 = branch_false->mutable_operand(0); + auto f1 = branch_false->mutable_operand(1); + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(f0, f1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index d016d3e03d5..d859f647ea0 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -66,38 +66,16 @@ const absl::optional>& BackendOptions::allowed_devices() const { return allowed_devices_; } -namespace { - -class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { - public: - explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool) - : pool_(pool) {} - ~EigenThreadPoolWrapper() override {} - - void Schedule(std::function fn) override { - pool_->Schedule(std::move(fn)); - } - int NumThreads() const override { return pool_->NumThreads(); } - int CurrentThreadId() const override { return pool_->CurrentThreadId(); } - - private: - tensorflow::thread::ThreadPool* pool_ = nullptr; -}; - -} // namespace - // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct Backend::IntraOpThreadPool { explicit IntraOpThreadPool(const int num_threads) : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(), "XLAEigen", num_threads)), - wrapper(new EigenThreadPoolWrapper(pool.get())), - device(new Eigen::ThreadPoolDevice(wrapper.get(), - wrapper->NumThreads())) {} + device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(), + pool->NumThreads())) {} std::unique_ptr pool; - std::unique_ptr wrapper; std::unique_ptr device; }; @@ -156,7 +134,7 @@ Backend::Backend(se::Platform* platform, Compiler* compiler, } } // Create a memory allocator for the valid stream executors. 
- memory_allocator_ = absl::make_unique( + memory_allocator_ = absl::make_unique( platform, stream_executors); CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index e7f29a044b9..79fdeb2b0bc 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -27,7 +27,6 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" @@ -35,6 +34,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace Eigen { struct ThreadPoolDevice; @@ -88,7 +88,7 @@ class Backend { // Accessors for the various objects. se::Platform* platform() const { return platform_; } Compiler* compiler() const { return compiler_; } - DeviceMemoryAllocator* memory_allocator() const { + se::DeviceMemoryAllocator* memory_allocator() const { return memory_allocator_.get(); } TransferManager* transfer_manager() const { return transfer_manager_; } @@ -179,7 +179,7 @@ class Backend { stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. - std::unique_ptr memory_allocator_; + std::unique_ptr memory_allocator_; // For the CPU backend, an Eigen threadpool device for use by Eigen code. struct IntraOpThreadPool; diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index dbabd82dd55..72112585cb3 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -23,6 +23,24 @@ namespace xla { StatusOr BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( HloInstruction* batch_dot) { + // This pass assumes the lhs and rhs batch dimensions are equal and strictly + // ascending. 
+ const auto& is_iota = [](absl::Span dims) { + for (int64 i = 0; i < dims.size(); ++i) { + if (dims[i] != i) { + return false; + } + } + return true; + }; + if (!absl::c_equal( + batch_dot->dot_dimension_numbers().lhs_batch_dimensions(), + batch_dot->dot_dimension_numbers().rhs_batch_dimensions()) || + !is_iota(AsInt64Slice( + batch_dot->dot_dimension_numbers().lhs_batch_dimensions()))) { + return false; + } + const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers(); HloInstruction *lhs = batch_dot->mutable_operand(0), *rhs = batch_dot->mutable_operand(1); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index e62d72b323b..0a8e8dc2a8b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -29,8 +29,11 @@ namespace xla { class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { public: explicit BFloat16ConversionFoldingVisitor( - HloComputation* computation, const BFloat16Support* bfloat16_support) - : computation_(computation), bfloat16_support_(bfloat16_support) {} + HloComputation* computation, const BFloat16Support* bfloat16_support, + BFloat16ConversionFolding* bfloat16_conversion_folding) + : computation_(computation), + bfloat16_support_(bfloat16_support), + bfloat16_conversion_folding_(bfloat16_conversion_folding) {} Status DefaultAction(HloInstruction* hlo) override; @@ -38,8 +41,10 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status HandleAllReduce(HloInstruction* crs) override; static bool Run(HloComputation* computation, - const BFloat16Support* bfloat16_support) { - BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); + const BFloat16Support* bfloat16_support, + BFloat16ConversionFolding* bfloat16_conversion_folding) { + BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support, + bfloat16_conversion_folding); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -61,6 +66,7 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { HloComputation* computation_; const BFloat16Support* bfloat16_support_; + BFloat16ConversionFolding* bfloat16_conversion_folding_; bool changed_ = false; }; @@ -68,6 +74,7 @@ Status BFloat16ConversionFoldingVisitor::FoldOutputConversions( HloInstruction* hlo) { std::vector materialized_users = hlo->users(); hlo->mutable_shape()->set_element_type(BF16); + bfloat16_conversion_folding_->UpdateLayout(hlo->mutable_shape()); for (auto user : materialized_users) { CHECK_EQ(user->opcode(), HloOpcode::kConvert); TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); @@ -228,6 +235,8 @@ Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}) ->set_element_type(BF16); + bfloat16_conversion_folding_->UpdateLayout( + ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})); for (auto gte : per_tuple_element_gtes[i]) { TF_RETURN_IF_ERROR(FoldOutputConversions(gte)); } @@ -241,7 +250,7 @@ StatusOr BFloat16ConversionFolding::Run(HloModule* module) { 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) { + if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_, this)) { changed = true; } } 
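The early-out added to ElideDegenerateBatchDimensionFromBatchDot above only lets the simplification proceed when the lhs and rhs batch dimensions are identical and form the sequence 0, 1, 2, .... A standalone sketch of those two checks with plain std containers (the real code uses absl::Span and absl::c_equal):

```cpp
#include <cstdint>
#include <vector>

// True if dims is exactly 0, 1, 2, ..., dims.size() - 1.
bool IsIota(const std::vector<int64_t>& dims) {
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] != static_cast<int64_t>(i)) return false;
  }
  return true;
}

// Mirrors the bail-out: only simplify when both operands agree on the batch
// dimensions and those dimensions are strictly ascending from zero.
bool BatchDimsAreSimplifiable(const std::vector<int64_t>& lhs_batch_dims,
                              const std::vector<int64_t>& rhs_batch_dims) {
  return lhs_batch_dims == rhs_batch_dims && IsIota(lhs_batch_dims);
}
```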
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 72459961485..dc6ed897a6b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -29,18 +30,16 @@ namespace xla { class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { public: - explicit BFloat16NormalizationVisitor(HloComputation* computation, - const BFloat16Support* bfloat16_support) - : computation_(computation), bfloat16_support_(bfloat16_support) {} + explicit BFloat16NormalizationVisitor( + const BFloat16Support* bfloat16_support, + BFloat16Normalization* bfloat16_normalization) + : computation_(nullptr), + bfloat16_support_(bfloat16_support), + bfloat16_normalization_(bfloat16_normalization) {} + bool changed() const { return changed_; } Status DefaultAction(HloInstruction* hlo) override; - - static bool Run(HloComputation* computation, - const BFloat16Support* bfloat16_support) { - BFloat16NormalizationVisitor visitor(computation, bfloat16_support); - TF_CHECK_OK(computation->Accept(&visitor)); - return visitor.changed_; - } + Status Preprocess(HloInstruction* hlo) override; private: // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts @@ -73,6 +72,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { HloComputation* computation_; const BFloat16Support* bfloat16_support_; + BFloat16Normalization* bfloat16_normalization_; bool changed_ = false; }; @@ -95,6 +95,7 @@ Status BFloat16NormalizationVisitor::InsertConvertAfterOutput( computation->set_root_instruction(convert); } convert->mutable_shape()->set_element_type(to); + bfloat16_normalization_->UpdateLayout(convert->mutable_shape()); changed_ = true; return Status::OK(); } @@ -103,6 +104,7 @@ Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { auto original_type = hlo->shape().element_type(); hlo->mutable_shape()->set_element_type(to); + bfloat16_normalization_->UpdateLayout(hlo->mutable_shape()); return InsertConvertAfterOutput(hlo, original_type, computation); } @@ -110,8 +112,10 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( HloInstruction* hlo, int64 operand_idx, PrimitiveType to, HloComputation* computation) { auto operand = hlo->mutable_operand(operand_idx); - auto convert = computation->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(operand->shape(), to), operand)); + auto shape = ShapeUtil::ChangeElementType(operand->shape(), to); + bfloat16_normalization_->UpdateLayout(&shape); + auto convert = computation->AddInstruction( + HloInstruction::CreateConvert(shape, operand)); TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert)); changed_ = true; return Status::OK(); @@ -243,11 +247,13 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( continue; } subshape->set_element_type(F32); + bfloat16_normalization_->UpdateLayout(subshape); auto gte = computation_->AddInstruction( HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); + auto shape = 
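The normalization visitor above loses its per-computation static Run helper and instead gains a Preprocess hook, so a single visitor instance can be walked over every computation while always knowing which computation owns the instruction being handled. A rough sketch of that shape, with hypothetical Instruction and Computation types standing in for the HLO classes:

```cpp
#include <vector>

struct Computation;
struct Instruction {
  Computation* parent = nullptr;  // computation that contains this instruction
};
struct Computation {
  std::vector<Instruction*> instructions;
};

class Visitor {
 public:
  // The DFS visitor calls this before dispatching each instruction, so the
  // handlers can rely on computation_ instead of taking it as a parameter.
  void Preprocess(Instruction* inst) { computation_ = inst->parent; }

  void DefaultAction(Instruction* /*inst*/) {
    // A real handler would inspect the instruction, possibly add converts to
    // computation_, and set changed_ = true whenever it rewrites something.
  }

  bool changed() const { return changed_; }

 private:
  Computation* computation_ = nullptr;
  bool changed_ = false;
};

// One visitor instance accumulates changed_ across all computations, which is
// what the reworked BFloat16Normalization::Run returns at the end.
bool RunOverModule(const std::vector<Computation*>& computations) {
  Visitor visitor;
  for (Computation* comp : computations) {
    for (Instruction* inst : comp->instructions) {
      visitor.Preprocess(inst);
      visitor.DefaultAction(inst);
    }
  }
  return visitor.changed();
}
```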
ShapeUtil::ChangeElementType(*subshape, BF16); + bfloat16_normalization_->UpdateLayout(&shape); output_elements[i] = - computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(*subshape, BF16), gte)); + computation_->AddInstruction(HloInstruction::CreateConvert(shape, gte)); } auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); @@ -396,18 +402,21 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { return HandleInstruction(hlo); } +Status BFloat16NormalizationVisitor::Preprocess(HloInstruction* hlo) { + computation_ = hlo->parent(); + return Status::OK(); +} + StatusOr BFloat16Normalization::Run(HloModule* module) { XLA_VLOG_LINES( 2, "BFloat16Normalization::Run(), before:\n" + module->ToString()); - bool changed = false; + BFloat16NormalizationVisitor visitor(bfloat16_support_, this); for (auto* comp : module->MakeComputationPostOrder()) { - if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) { - changed = true; - } + TF_RETURN_IF_ERROR(comp->Accept(&visitor)); } XLA_VLOG_LINES(2, "BFloat16Normalization::Run(), after:\n" + module->ToString()); - return changed; + return visitor.changed(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index bab63f66d83..d314065c752 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -674,11 +674,13 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { if (hlo->opcode() != HloOpcode::kConstant) { continue; } - if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { + if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(), + hlo->shape())) { TF_ASSIGN_OR_RETURN(auto converted_literal, hlo->literal().ConvertToShape(hlo->shape())); auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); + UpdateLayout(new_constant->mutable_shape()); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); } } @@ -797,6 +799,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { auto subshape = entry.first; CHECK_EQ(subshape->element_type(), F32); subshape->set_element_type(BF16); + UpdateLayout(subshape); changed_ = true; } } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index a9b5d9916e4..357d38a5548 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -109,8 +109,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); - HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b)); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq)); HloInstruction* sel = builder.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); HloInstruction* xpose = @@ -574,8 +574,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { HloInstruction::CreateParameter(0, shape, "cond_param")); auto cond_dot = 
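bfloat16_propagation above swaps ShapeUtil::Equal for Shape::Equal().MinorToMajorOnlyInLayout()(a, b): an equality functor whose strictness is relaxed by chaining option setters before it is applied. A small standalone sketch of that builder-style comparator, using a toy shape type rather than the real Shape class:

```cpp
#include <cstdint>
#include <vector>

// A toy "shape": an element type plus a layout (minor-to-major order and a
// padding flag that we may want to ignore while comparing).
struct ToyShape {
  int element_type = 0;
  std::vector<int64_t> minor_to_major;
  bool padded = false;
};

class ToyShapeEqual {
 public:
  // Chained option: compare only the minor-to-major part of the layout and
  // ignore the padding flag.
  ToyShapeEqual& MinorToMajorOnlyInLayout() {
    ignore_padding_ = true;
    return *this;
  }

  bool operator()(const ToyShape& a, const ToyShape& b) const {
    if (a.element_type != b.element_type) return false;
    if (a.minor_to_major != b.minor_to_major) return false;
    return ignore_padding_ || a.padded == b.padded;
  }

 private:
  bool ignore_padding_ = false;
};

// Usage mirrors the call site in the diff:
//   bool same = ToyShapeEqual().MinorToMajorOnlyInLayout()(lhs, rhs);
```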
builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); - auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction( @@ -583,9 +583,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { cond_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -631,8 +632,8 @@ TEST_F(BFloat16PropagationTest, auto builder_cond = HloComputation::Builder("cond"); auto cond_param = builder_cond.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); - builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction(HloInstruction::CreateSlice( @@ -642,7 +643,8 @@ TEST_F(BFloat16PropagationTest, ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, - {1, 1})))))); + {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -705,8 +707,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); auto cond_dot = builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); - builder_cond.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond.AddInstruction( @@ -714,9 +716,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { cond_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -800,8 +803,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); auto cond0_dot = builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); - builder_cond0.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), 
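These test updates (and the ones in the neighboring files) all perform the same migration: separate kGt/kLt/kEq binary opcodes become one compare instruction carrying a ComparisonDirection. A standalone sketch of why a single direction-parameterized comparison is convenient; only kEq, kGt, and kLt appear in the diff, and the remaining enum values here are assumed just to make the switch complete:

```cpp
#include <cassert>

// Mirrors the kEq/kGt/kLt names seen in the updated tests; the other values
// are assumptions for completeness.
enum class ComparisonDirection { kEq, kNe, kGe, kGt, kLe, kLt };

// One comparison parameterized by direction replaces a family of per-opcode
// helpers, so code that builds or rewrites comparisons preserves (or flips)
// a single attribute instead of mapping between many opcodes.
bool Compare(double lhs, double rhs, ComparisonDirection dir) {
  switch (dir) {
    case ComparisonDirection::kEq: return lhs == rhs;
    case ComparisonDirection::kNe: return lhs != rhs;
    case ComparisonDirection::kGe: return lhs >= rhs;
    case ComparisonDirection::kGt: return lhs > rhs;
    case ComparisonDirection::kLe: return lhs <= rhs;
    case ComparisonDirection::kLt: return lhs < rhs;
  }
  return false;  // unreachable; keeps -Wreturn-type quiet
}

int main() {
  assert(Compare(1.0, 0.0, ComparisonDirection::kGt));   // was kGt / greater-than
  assert(!Compare(1.0, 0.0, ComparisonDirection::kLt));  // was kLt / less-than
  return 0;
}
```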
HloOpcode::kGt, + builder_cond0.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond0.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond0.AddInstruction( @@ -809,9 +812,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { cond0_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond0.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond0.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond0_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); // Condition computation for the second while. @@ -828,8 +832,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); auto cond1_dot = builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); - builder_cond1.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond1.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), builder_cond1.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), builder_cond1.AddInstruction( @@ -837,9 +841,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { cond1_dot, {0, 0}, {1, 1}, {1, 1})))), builder_cond1.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {}), - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2}, - {1, 1})))))); + builder_cond1.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond1_dot, {1, 1}, {2, 2}, {1, 1})))), + ComparisonDirection::kGt)); auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); // Body computation shared by both whiles. diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 2b9502f63a8..abb695fa486 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -105,6 +105,8 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( return operand_index == 0; case HloOpcode::kDynamicUpdateSlice: return operand_index == 0 || operand_index == 1; + case HloOpcode::kGather: + return operand_index == 0; case HloOpcode::kSelect: case HloOpcode::kTupleSelect: return operand_index == 1 || operand_index == 2; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index cb682f49a5c..aa57f28448e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -737,9 +738,11 @@ StatusOr> BufferAssigner::Run( LogicalBuffer::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing, bool allocate_buffers_for_constants, - BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) { + BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker, + ReuseColocatedAllocationForTempChecker reuse_colocated_checker) { BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer), - std::move(reuse_checker)); + std::move(reuse_checker), + std::move(reuse_colocated_checker)); return assigner.CreateAssignment(module, std::move(hlo_ordering), std::move(buffer_size), std::move(color_alignment)); @@ -977,8 +980,36 @@ Status BufferAssigner::AssignBuffersForComputation( }); // BufferAllocations are necessarily created in decreasing size order. Keep - // indices of previously created BufferAllocations in allocation_indices. - std::vector allocation_indices; + // indices of previously created BufferAllocations in new_allocation_indices. + std::vector new_allocation_indices; + + // A sorted multimap from size to indices of colocated allocations. + std::multimap + colocated_allocation_size_to_indices; + { + std::priority_queue sorted_colocated_indices; + for (auto index : colocated_allocations) { + bool consider_reusing = true; + // Output tuple table may be allocated at run-time, so make sure we don't + // overwrite them. + for (const auto& buffer_offset_size : + assignment->GetAllocation(index).assigned_buffers()) { + if (buffer_offset_size.first->shape().IsTuple()) { + consider_reusing = false; + break; + } + } + if (consider_reusing) { + sorted_colocated_indices.push(index); + } + } + while (!sorted_colocated_indices.empty()) { + auto index = sorted_colocated_indices.top(); + sorted_colocated_indices.pop(); + colocated_allocation_size_to_indices.emplace( + assignment->GetAllocation(index).size(), index); + } + } for (const LogicalBuffer* buffer : sorted_buffers) { VLOG(3) << "Assigning allocation to: " << *buffer; if (colocated_buffers.contains(buffer)) { @@ -1074,25 +1105,47 @@ Status BufferAssigner::AssignBuffersForComputation( } } + if (reuse_colocated_checker_ != nullptr && + reuse_colocated_checker_(*buffer, buffer_size) && + !assignment->HasAllocation(*buffer)) { + // Find the smallest buffer which can be reused iterating from the lower + // bound of the buffer size in colocated_allocation_size_to_indices. + auto it = colocated_allocation_size_to_indices.lower_bound(buffer_size); + while (it != colocated_allocation_size_to_indices.end()) { + CHECK_GE(it->first, buffer_size); + BufferAllocation* allocation = + assignment->GetMutableAllocation(it->second); + if (MaybeAssignBuffer(allocation, *buffer, assignment)) { + VLOG(3) << "Reusing allocation #" << allocation->index() + << " for: " << *buffer; + // We remove the assigned allocation from + // colocated_allocation_size_to_indices to prevent putting too many + // buffers into collocated allocations, and to reduce the search space + // for subsequent buffers. This is to avoid excessive pairwise checks + // for interference that may slow down compilation. The heap simulator + // is more efficient in live range checks. 
+ // + // Another benefit of removing the allocation is that the reused + // allocation will be less likely to contain interferences that + // prevent operand-output reuse, which is important for in-place + // dynamic update slices. + colocated_allocation_size_to_indices.erase(it); + break; + } + ++it; + } + } + if (!assignment->HasAllocation(*buffer)) { // Find the smallest buffer which can be reused iterating from end of - // allocation_indices (smallest) to beginning (largest). - for (int allocation_index = allocation_indices.size() - 1; + // new_allocation_indices (smallest) to beginning (largest). + for (int allocation_index = new_allocation_indices.size() - 1; allocation_index >= 0; allocation_index--) { BufferAllocation* allocation = assignment->GetMutableAllocation( - allocation_indices[allocation_index]); + new_allocation_indices[allocation_index]); // Instructions are iterated in increasing buffer size, so any // previously create allocation must be large enough to hold this - // instruction's output (with the exception of colocated buffers). - if (!colocated_allocations.contains(allocation->index())) { - // TODO(b/32491382) Colocated buffers are currently assigned in an - // earlier pass, and so can break the "increasing allocation size" - // invariant in this function (causing this CHECK to fail). However, - // the call to MaybeAssignBuffer is safe as it returns false if - // allocation.size < buffer.size. - CHECK_GE(allocation->size(), buffer_size); - } - + // instruction's output. if (MaybeAssignBuffer(allocation, *buffer, assignment)) { VLOG(3) << "Reusing allocation #" << allocation->index() << " for: " << *buffer; @@ -1121,7 +1174,7 @@ Status BufferAssigner::AssignBuffersForComputation( if (!assignment->HasAllocation(*buffer)) { BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size); - allocation_indices.push_back(allocation->index()); + new_allocation_indices.push_back(allocation->index()); VLOG(3) << "New allocation #" << allocation->index() << " for: " << *buffer; } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 448dec3b1aa..41adf1b80a5 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -540,6 +540,11 @@ class BufferAssigner { const BufferAssignment& assignment, const BufferAllocation& alloc, const LogicalBuffer& buffer)>; + // Returns whether a logical buffer can be considered reusing memory for + // colocated buffers. + using ReuseColocatedAllocationForTempChecker = + std::function; + // Build and return a BufferAssignment for the given module. The given // HloOrdering is used to determine buffer liveness. 
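The reuse path added above keeps candidate colocated allocations in a multimap keyed by size, uses lower_bound to find the smallest one that is at least as large as the buffer being placed, and erases an entry once it has been reused. A standalone sketch of that search with integers standing in for allocations; unlike the real pass, which keeps scanning when MaybeAssignBuffer rejects a candidate, this sketch simply takes the first fit:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>

// Maps allocation size -> allocation index, sorted by size ascending.
using SizeToIndex = std::multimap<int64_t, int>;

// Returns the index of the smallest allocation with size >= wanted_size and
// removes it from the map (each colocated allocation is reused at most once,
// mirroring the erase in the diff), or std::nullopt if nothing fits.
std::optional<int> TakeSmallestFit(SizeToIndex* candidates,
                                   int64_t wanted_size) {
  auto it = candidates->lower_bound(wanted_size);
  if (it == candidates->end()) return std::nullopt;
  int index = it->second;
  candidates->erase(it);
  return index;
}

int main() {
  SizeToIndex candidates = {{256, 0}, {1024, 1}, {4096, 2}};
  if (auto idx = TakeSmallestFit(&candidates, 512)) {
    std::cout << "reusing allocation #" << *idx << "\n";  // prints #1
  }
  return 0;
}
```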
buffer_size and // color_alignment are functions which returns the size and alignment of a @@ -552,15 +557,18 @@ class BufferAssigner { bool allow_input_output_aliasing = false, bool allocate_buffers_for_constants = false, BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer(), - ReuseAllocationFunction reuse_checker = nullptr); + ReuseAllocationFunction reuse_checker = nullptr, + ReuseColocatedAllocationForTempChecker reuse_colocated_checker = nullptr); private: BufferAssigner(bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer, - ReuseAllocationFunction reuse_checker) + ReuseAllocationFunction reuse_checker, + ReuseColocatedAllocationForTempChecker reuse_colocated_checker) : allocate_buffers_for_constants_(allocate_buffers_for_constants), - colorer_(colorer), - reuse_checker_(reuse_checker) {} + colorer_(std::move(colorer)), + reuse_checker_(std::move(reuse_checker)), + reuse_colocated_checker_(std::move(reuse_colocated_checker)) {} virtual ~BufferAssigner() = default; // Create a buffer assignment. @@ -657,6 +665,9 @@ class BufferAssigner { // Functor to check if a buffer can reuse an allocation. ReuseAllocationFunction reuse_checker_; + // Functor to check if a temp buffer can reuse a colocated allocation. + ReuseColocatedAllocationForTempChecker reuse_colocated_checker_; + TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 580bc2f4338..acdf5d25e1d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -151,6 +151,24 @@ class BufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr + RunBufferAssignmentWithReusingColocatedBuffersForTemp(HloModule* module, + int64 alignment = 1) { + return BufferAssigner::Run( + module, absl::make_unique(module), + backend().compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true, + /*colorer=*/BufferLiveness::DefaultColorer(), + /*reuse_checker=*/nullptr, + /*reuse_colocated_checker=*/ + [](const LogicalBuffer& buffer, int64 byte_size) { + return true; + }) + .ConsumeValueOrDie(); + } + // Builds an x+1.0 computation to use in a Map. 
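The new ReuseColocatedAllocationForTempChecker is plumbed through Run and the constructor as a std::function that defaults to nullptr, so existing callers are untouched and the test above simply passes a lambda that always returns true. A minimal sketch of that optional-predicate plumbing with hypothetical Buffer and Assigner names:

```cpp
#include <cstdint>
#include <functional>
#include <utility>

struct Buffer {
  int64_t size = 0;
};

// Mirrors the typedef added to buffer_assignment.h: a predicate that says
// whether a temp buffer may reuse a colocated allocation.
using ReuseColocatedChecker = std::function<bool(const Buffer&, int64_t)>;

class Assigner {
 public:
  explicit Assigner(ReuseColocatedChecker checker = nullptr)
      : reuse_colocated_checker_(std::move(checker)) {}

  bool MayReuseColocated(const Buffer& buffer) const {
    // A null checker means the feature is off, just like the nullptr default
    // on BufferAssigner::Run.
    return reuse_colocated_checker_ != nullptr &&
           reuse_colocated_checker_(buffer, buffer.size);
  }

 private:
  ReuseColocatedChecker reuse_colocated_checker_;
};

// Usage, as in the new test: enable reuse for every buffer.
//   Assigner assigner([](const Buffer&, int64_t) { return true; });
```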
std::unique_ptr BuildMapComputationPlus1(const string& name) { auto builder = HloComputation::Builder(name); @@ -190,8 +208,9 @@ class BufferAssignmentTest : public HloTestBase { HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index, + const4, ComparisonDirection::kLt)); return builder.Build(); } @@ -499,6 +518,75 @@ TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) { EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index()); } +TEST_F(BufferAssignmentTest, ReuseColocatedBuffersForTemp) { + const char* const hlo_string = R"( +HloModule test + +sum (a: f32[], b: f32[]) -> f32[] { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +while_body { + state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) + get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1 + get-tuple-element.3 = s32[] get-tuple-element(state), index=0 + constant.2 = s32[] constant(128) + add.5 = s32[] add(get-tuple-element.3, constant.2) + broadcast = f32[2,1280,1,128]{3,2,1,0} broadcast(get-tuple-element.4), dimensions={1,2,3} + constant.3 = s32[] constant(0) + reduce = f32[1280,1,128]{2,1,0} reduce(broadcast, constant.3), dimensions={3}, to_apply=sum + ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, reduce) +} + +while_condition { + state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) + get-tuple-element = s32[] get-tuple-element(state), index=0 + get-tuple-element.1 = s32[] constant(3) + ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT +} + +sum.1 (a: f32[], b: f32[]) -> f32[] { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry_computation { + parameter = f32[2,1280,1,128]{3,2,1,0} parameter(0) + constant.6 = f32[] constant(0) + reduce.1 = f32[1280,1,128]{2,1,0} reduce(parameter, constant.6), dimensions={3}, to_apply=sum.1 + constant.7 = s32[] constant(0) + tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(constant.7, reduce.1) + while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body + get-tuple-element.1 = f32[1280,1,128] get-tuple-element(while.0), index=1 + ROOT broadcast.1 = f32[2,1280,1,128]{3,2,1,0} broadcast(get-tuple-element.1), dimensions={1,2,3} +} + +)"; + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module = module_or_status.ConsumeValueOrDie(); + + TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias( + {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); + + auto assignment = + RunBufferAssignmentWithReusingColocatedBuffersForTemp(module.get()); + // Get BufferAllocation for root instruction. 
+ auto broadcast = FindInstruction(module.get(), "broadcast"); + auto broadcast_alloc_slice = + assignment->GetUniqueTopLevelSlice(broadcast).ConsumeValueOrDie(); + auto parameter = FindInstruction(module.get(), "parameter"); + auto parameter_alloc_slice = + assignment->GetUniqueTopLevelSlice(parameter).ConsumeValueOrDie(); + + EXPECT_EQ(broadcast_alloc_slice.allocation(), + parameter_alloc_slice.allocation()); + EXPECT_EQ(broadcast_alloc_slice, parameter_alloc_slice); +} + TEST_F(BufferAssignmentTest, AddCannotReuse) { // Pass in a special rule to indicate that "add" cannot reuse any buffer. // @@ -1863,8 +1951,8 @@ class WhileBufferAssignmentTest : public HloTestBase { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto ten = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt)); return builder.Build(); } @@ -2135,8 +2223,9 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param, + const4, ComparisonDirection::kLt)); return builder.Build(); }; @@ -2530,7 +2619,7 @@ while_condition { state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) get-tuple-element = s32[] get-tuple-element(state), index=0 get-tuple-element.1 = s32[] constant(3) - ROOT less-than.339.338 = pred[] less-than(get-tuple-element, get-tuple-element.1) + ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT } ENTRY entry_computation { diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 9b2783a214a..3adf129a22d 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -114,12 +114,13 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // If the root instruction aliases the buffer 'a', the live range of 'a' is // until the end of the computation and can never be strictly before another - // buffer defined in the same computation. This is needed to prevent the - // root instruction's buffers from being reused by later instructions even - // when the root is not the last instruction in the schedule. + // buffer nested in the same computation. This is needed to prevent the root + // instruction's buffers from being reused by later instructions even when + // the root is not the last instruction in the schedule. if (alias.instruction()->parent()->root_instruction() == alias.instruction() && - alias.instruction()->parent() == b.instruction()->parent()) { + hlo_ordering_->call_graph().InstructionIsNestedIn( + b.instruction(), alias.instruction()->parent())) { return false; } } @@ -147,15 +148,20 @@ bool IsEntryParameter(const HloInstruction* instruction) { bool BufferLiveness::MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const { - // Entry parameters live at the entry of the execution, thus always interfere - // with all other instructions executing before them in the ordering. 
+ // Parameters live at the entry of the computation, thus always interfere with + // all other instructions inside the computation executing before them in the + // ordering. const HloInstruction* a_instruction = a.instruction(); const HloInstruction* b_instruction = b.instruction(); - if (IsEntryParameter(a_instruction) && + if (a_instruction->opcode() == HloOpcode::kParameter && + hlo_ordering_->call_graph().InstructionIsNestedIn( + b_instruction, a_instruction->parent()) && hlo_ordering_->ExecutesBefore(b_instruction, a_instruction)) { return true; } - if (IsEntryParameter(b_instruction) && + if (b_instruction->opcode() == HloOpcode::kParameter && + hlo_ordering_->call_graph().InstructionIsNestedIn( + a_instruction, b_instruction->parent()) && hlo_ordering_->ExecutesBefore(a_instruction, b_instruction)) { return true; } diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 23b9af0281b..79812923911 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -196,6 +197,80 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); } +TEST_F(BufferLivenessTest, EmbeddedComputationParameters) { + absl::string_view hlo_string = R"( +HloModule EmbeddedComputationParameters, is_scheduled=true + +%EmbeddedComputationParameters_embedded (embedded_param0: f32[42], embedded_param1: f32[42]) -> (f32[42], f32[42]) { + %embedded_param0 = f32[42]{0} parameter(0) + %log = f32[42]{0} log(f32[42]{0} %embedded_param0) + %add = f32[42]{0} add(f32[42]{0} %log, f32[42]{0} %log) + %embedded_param1 = f32[42]{0} parameter(1) + ROOT %tuple = (f32[42]{0}, f32[42]{0}) tuple(f32[42]{0} %add, f32[42]{0} %embedded_param1) +} + +ENTRY %EmbeddedComputationParameters (param0: f32[42], param1: f32[42]) -> (f32[42], f32[42]) { + %param0 = f32[42]{0} parameter(0) + %param1 = f32[42]{0} parameter(1) + ROOT %call = (f32[42]{0}, f32[42]{0}) call(f32[42]{0} %param0, f32[42]{0} %param1), to_apply=%EmbeddedComputationParameters_embedded +} +)"; + HloModuleConfig hlo_config; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string, hlo_config)); + auto liveness = + BufferLiveness::Run( + module.get(), + absl::make_unique(module->schedule())) + .ConsumeValueOrDie(); + + auto embedded_log = FindInstruction(module.get(), "log"); + auto embedded_param0 = FindInstruction(module.get(), "embedded_param0"); + auto embedded_param1 = FindInstruction(module.get(), "embedded_param1"); + auto param0 = FindInstruction(module.get(), "param0"); + auto param1 = FindInstruction(module.get(), "param1"); + + // Parameters should interfere with other instructions inside the computation. 
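The liveness change above generalizes the old entry-parameter special case: any parameter now interferes with instructions nested inside its computation that are scheduled before it, a question the pass delegates to CallGraph::InstructionIsNestedIn. A standalone sketch of that nesting test, under the simplifying assumption that every computation has a single caller (the real call graph does not require this):

```cpp
struct Computation;
struct Instruction {
  Computation* parent = nullptr;  // computation that contains it
};
struct Computation {
  // The computation that calls this one, or nullptr for the entry.
  Computation* caller = nullptr;
};

// True if `inst` lives in `computation` itself or in any computation called
// (directly or transitively) from it, modelled here by walking caller links
// upward from the instruction's own computation.
bool InstructionIsNestedIn(const Instruction* inst,
                           const Computation* computation) {
  for (const Computation* c = inst->parent; c != nullptr; c = c->caller) {
    if (c == computation) return true;
  }
  return false;
}
```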
+ EXPECT_TRUE( + InstructionsMayInterfere(*liveness, embedded_log, embedded_param1)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, param0)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, param1)); + EXPECT_TRUE( + InstructionsMayInterfere(*liveness, embedded_param0, embedded_param1)); +} + +TEST_F(BufferLivenessTest, InterferenceWithOuterRoot) { + absl::string_view hlo_string = R"( +HloModule InterferenceWithOuterRoot, is_scheduled=true + +Emmbedded (embedded_param: f32[42]) -> f32[42] { + embedded_param = f32[42]{0} parameter(0) + multiply = f32[42]{0} multiply(embedded_param, embedded_param) + ROOT log = f32[42]{0} log(multiply) +} + +ENTRY InterferenceWithOuterRoot { + param = f32[4096,4096]{1,0} parameter(0) + ROOT add = f32[4096,4096]{1,0} add(param, param) + call = f32[42]{0} call(param), to_apply=Emmbedded +} + +)"; + HloModuleConfig hlo_config; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string, hlo_config)); + auto liveness = + BufferLiveness::Run( + module.get(), + absl::make_unique(module->schedule())) + .ConsumeValueOrDie(); + + auto multiply = FindInstruction(module.get(), "multiply"); + auto add = FindInstruction(module.get(), "add"); + + EXPECT_TRUE(InstructionsMayInterfere(*liveness, multiply, add)); +} + TEST_F(BufferLivenessTest, NonElementwiseOperand) { // A chain of operations with two elementwise and one non-elementwise. The // elementwise op should not interfere with its operand, while the diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 98304757cae..e13e6040ff2 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -277,8 +277,8 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { // Constructor for CallGraph is private so absl::make_unique can't be used. auto call_graph = absl::WrapUnique(new CallGraph(module)); - VLOG(2) << "Building call graph for:"; - XLA_VLOG_LINES(2, module->ToString()); + VLOG(3) << "Building call graph for:"; + XLA_VLOG_LINES(3, module->ToString()); // Construct nodes of the call graph and populate the callsites. 
for (HloComputation* computation : module->computations()) { @@ -309,7 +309,7 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { call_graph->SetCallContexts(); call_graph->SetNodeDepths(); - XLA_VLOG_LINES(1, call_graph->ToString()); + XLA_VLOG_LINES(2, call_graph->ToString()); return call_graph; } diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 5de724f8924..458aef14999 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -83,8 +83,9 @@ class CallGraphTest : public HloTestBase { HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + zero, ComparisonDirection::kGt)); return builder.Build(); } diff --git a/tensorflow/compiler/xla/service/cholesky_expander.cc b/tensorflow/compiler/xla/service/cholesky_expander.cc index 1c39cf9bc0a..c4979ad5d4c 100644 --- a/tensorflow/compiler/xla/service/cholesky_expander.cc +++ b/tensorflow/compiler/xla/service/cholesky_expander.cc @@ -99,7 +99,7 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) - auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); + auto diag_dot = BatchDot(row, false, row, true, precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = Sqrt(a_ii - diag_dot); @@ -114,7 +114,7 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { // The columns in [i, n] are zeroed out in `row`, so we just have to // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], // r.T) - auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); + auto dot = BatchDot(body_l, false, row, true, precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot); @@ -178,7 +178,7 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size, // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision); + auto delta = BatchDot(lhs, false, rhs, true, precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } diff --git a/tensorflow/compiler/xla/service/compilation_stats.cc b/tensorflow/compiler/xla/service/compilation_stats.cc new file mode 100644 index 00000000000..a800e92bd50 --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_stats.cc @@ -0,0 +1,132 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
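The cholesky_expander edits above change BatchDot(row, TransposeInMinorDims(row), precision) into BatchDot(row, false, row, true, precision), folding the transpose into boolean flags instead of materializing a transposed value first. A standalone 2-D sketch of a dot routine with such flags, purely to illustrate the calling convention rather than the real XLA builder:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Multiplies x (optionally transposed) by y (optionally transposed). Folding
// the transposes into flags avoids building transposed copies up front.
Matrix Dot(const Matrix& x, bool transpose_x, const Matrix& y,
           bool transpose_y) {
  auto at = [](const Matrix& m, bool t, std::size_t i, std::size_t j) {
    return t ? m[j][i] : m[i][j];
  };
  const std::size_t rows = transpose_x ? x[0].size() : x.size();
  const std::size_t inner = transpose_x ? x.size() : x[0].size();
  const std::size_t cols = transpose_y ? y.size() : y[0].size();
  assert(inner == (transpose_y ? y[0].size() : y.size()));

  Matrix out(rows, std::vector<double>(cols, 0.0));
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      for (std::size_t k = 0; k < inner; ++k) {
        out[i][j] += at(x, transpose_x, i, k) * at(y, transpose_y, k, j);
      }
    }
  }
  return out;
}

// Dot(row, /*transpose_x=*/false, row, /*transpose_y=*/true) corresponds to
// the BatchDot(row, false, row, true, precision) calls in the diff.
```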
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compilation_stats.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/env.h" + +namespace xla { + +class NoopStats : public CompilationStats { + public: + NoopStats() = default; + + void StartPass(absl::string_view pass_name) override {} + + void EndPass(absl::string_view pass_name) override {} + + void CompilationReport() override {} +}; + +class Stats : public CompilationStats { + public: + Stats() = default; + + void StartPass(absl::string_view pass_name) override; + + void EndPass(absl::string_view pass_name) override; + + void CompilationReport() override; + + private: + struct PassInfo { + PassInfo(absl::string_view name, double duration) + : name(name), duration_ms(duration) {} + + absl::string_view name; + int num_runs = 1; + double duration_ms; + }; + + // Info about the passes that have been run so far. + std::vector passes_; + // Used to avoid nested calls to StartPass. + bool pass_running_ = false; + absl::string_view current_pass_; + // The start time of the currently running pass. + uint64 start_micros_; +}; + +/* static */ +std::unique_ptr CompilationStats::MakeNoopStats() { + return absl::make_unique(); +} + +/* static */ +std::unique_ptr CompilationStats::MakeStats() { + return absl::make_unique(); +} + +void Stats::StartPass(absl::string_view pass_name) { + CHECK(!pass_running_) << "Can't start " << pass_name << " while running " + << current_pass_; + pass_running_ = true; + current_pass_ = pass_name; + start_micros_ = tensorflow::Env::Default()->NowMicros(); +} + +void Stats::EndPass(absl::string_view pass_name) { + CHECK(pass_running_); + CHECK_EQ(current_pass_, pass_name); + pass_running_ = false; + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); + double duration_ms = (end_micros - start_micros_) / 1000.0; + passes_.push_back(PassInfo(current_pass_, duration_ms)); +} + +void Stats::CompilationReport() { + CHECK(!pass_running_) << "EndPass never called for " << current_pass_; + absl::flat_hash_map summary; + double total_duration = 0; + + for (auto& pass_run : passes_) { + auto pass_name = pass_run.name; + total_duration += pass_run.duration_ms; + auto it = summary.find(pass_name); + if (it == summary.end()) { + summary.insert(std::make_pair(pass_name, pass_run)); + } else { + ++summary.at(pass_name).num_runs; + summary.at(pass_name).duration_ms += pass_run.duration_ms; + } + } + + std::vector sorted_summary; + sorted_summary.reserve(summary.size()); + for (auto& it : summary) { + sorted_summary.push_back(it.second); + } + absl::c_sort(sorted_summary, [](const PassInfo& a, const PassInfo& b) { + // Sort passes that take the longest first, break ties using pass names. 
+ return std::make_pair(b.duration_ms, a.name) < + std::make_pair(a.duration_ms, b.name); + }); + LOG(INFO) << "Total runtime (ms) of HLO passes: " << total_duration; + LOG(INFO) << "Pass name, num runs, time (ms)"; + for (auto& pass_info : sorted_summary) { + LOG(INFO) << pass_info.name << ", " << pass_info.num_runs << ", " + << pass_info.duration_ms; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_stats.h b/tensorflow/compiler/xla/service/compilation_stats.h new file mode 100644 index 00000000000..9b748d0c7fe --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_stats.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_ + +#include +#include + +#include "absl/strings/str_format.h" + +namespace xla { + +// This class is used to collect information about HLO passes and print some +// statistics at the end of compilation. From HloPassPipeline, we call StartPass +// before the execution of a pass, and EndPass after. Currently, we only collect +// timing information and how many times each pass was run. In the future, we +// can add more things, such as the size of the HLO graph after each pass. +class CompilationStats { + public: + virtual ~CompilationStats() = default; + + static std::unique_ptr MakeNoopStats(); + + static std::unique_ptr MakeStats(); + + virtual void StartPass(absl::string_view pass_name) = 0; + + virtual void EndPass(absl::string_view pass_name) = 0; + + virtual void CompilationReport() = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_STATS_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 1965925fa7f..a4758c2b9db 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -70,26 +71,14 @@ CompileOnlyService::CompileAheadOfTime( TF_RET_CHECK(instance.computation.has_host_program_shape()); const DebugOptions& debug_options = options.debug_options(); - - // Dump computation proto if flag is set. 
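The new compilation_stats files above pair StartPass/EndPass timing with a report that aggregates run counts and total time per pass, printed slowest first with the pass name breaking ties. A compact standalone sketch of the same bookkeeping, using std::chrono where the real class reads tensorflow::Env::Default()->NowMicros():

```cpp
#include <algorithm>
#include <chrono>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

class PassTimer {
 public:
  void StartPass(const std::string& name) {
    // The real class also CHECKs that passes do not nest; omitted here.
    current_ = name;
    start_ = std::chrono::steady_clock::now();
  }

  void EndPass(const std::string& name) {
    auto end = std::chrono::steady_clock::now();
    double ms =
        std::chrono::duration<double, std::milli>(end - start_).count();
    auto& entry = summary_[name];
    ++entry.num_runs;
    entry.total_ms += ms;
  }

  // Prints "pass, runs, total ms", slowest first, name breaking ties.
  void Report() const {
    struct Row {
      std::string name;
      int num_runs;
      double total_ms;
    };
    std::vector<Row> rows;
    for (const auto& kv : summary_) {
      rows.push_back({kv.first, kv.second.num_runs, kv.second.total_ms});
    }
    std::sort(rows.begin(), rows.end(), [](const Row& a, const Row& b) {
      if (a.total_ms != b.total_ms) return a.total_ms > b.total_ms;
      return a.name < b.name;
    });
    for (const auto& row : rows) {
      std::cout << row.name << ", " << row.num_runs << ", " << row.total_ms
                << "\n";
    }
  }

 private:
  struct Entry {
    int num_runs = 0;
    double total_ms = 0;
  };
  std::unordered_map<std::string, Entry> summary_;
  std::string current_;
  std::chrono::steady_clock::time_point start_;
};
```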
- const string& directory_path = debug_options.xla_dump_computations_to(); - if (!directory_path.empty()) { - HloSnapshot hlo_snapshot; - *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; - string filename = - absl::StrCat("computation_", instance.computation.id(), "__", - instance.computation.entry_computation_name()); - const string& per_host_path = tensorflow::io::JoinPath( - directory_path, tensorflow::port::Hostname()); - - TF_RETURN_IF_ERROR( - Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); - } - ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; *execution_options.mutable_shape_with_output_layout() = instance.result_layout->ToProto(); + if (options.has_static_device_assignment()) { + TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize( + execution_options.mutable_device_assignment())); + } TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig( @@ -99,7 +88,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); + DumpHloModuleIfEnabled(*hlo_module, "before_optimizations"); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index d4db95da8eb..631a7dd7e6a 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -74,20 +75,34 @@ class AotCompilationOptions { // Optional allocator that may be used for allocating temp space on the device // during compilation. - DeviceMemoryAllocator* device_allocator() const { return device_allocator_; } - void set_device_allocator(DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator() const { + return device_allocator_; + } + void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) { device_allocator_ = device_allocator; } const DebugOptions& debug_options() const { return debug_options_; } DebugOptions* mutable_debug_options() { return &debug_options_; } + bool has_static_device_assignment() const { + return static_device_assignment_.has_value(); + } + const DeviceAssignment& static_device_assignment() const { + CHECK(static_device_assignment_.has_value()); + return *static_device_assignment_; + } + void set_static_device_assignment(const DeviceAssignment& device_assignment) { + static_device_assignment_ = device_assignment; + } + protected: AotCompilationOptions(); private: - DeviceMemoryAllocator* device_allocator_ = nullptr; + se::DeviceMemoryAllocator* device_allocator_ = nullptr; DebugOptions debug_options_; + absl::optional static_device_assignment_; }; // Abstract superclass describing metadata produced during ahead-of-time @@ -134,14 +149,14 @@ class Compiler { // allocated should be deallocated before this function returns. 
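AotCompilationOptions above gains a static device assignment held in an absl::optional behind the usual has_x()/x()/set_x() accessor trio, and CompileAheadOfTime copies it into the ExecutionOptions only when it is present. A minimal sketch of that accessor pattern with std::optional and a hypothetical Plan type; the assert stands in for the CHECK in the real getter:

```cpp
#include <cassert>
#include <optional>

struct Plan {
  int num_replicas = 1;
};

class Options {
 public:
  bool has_static_plan() const { return static_plan_.has_value(); }

  const Plan& static_plan() const {
    assert(static_plan_.has_value());  // the real accessor CHECKs this
    return *static_plan_;
  }

  void set_static_plan(const Plan& plan) { static_plan_ = plan; }

 private:
  std::optional<Plan> static_plan_;
};

// Caller side, mirroring CompileAheadOfTime: only consume the value when set.
//   if (options.has_static_plan()) { Serialize(options.static_plan()); }
```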
virtual StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* executor, - DeviceMemoryAllocator* device_allocator) = 0; + se::DeviceMemoryAllocator* device_allocator) = 0; // Optimizes a HLO module group, a set of module which runs concurrently on // multiple devices potentially communicating data between the modules. virtual Status RunHloPassesOnModuleGroup( HloModuleGroup* module_group, absl::Span executors, - DeviceMemoryAllocator* device_allocator) = 0; + se::DeviceMemoryAllocator* device_allocator) = 0; // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are @@ -155,7 +170,7 @@ class Compiler { // device_allocator is optional; see RunHloPasses. virtual StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, - DeviceMemoryAllocator* device_allocator) = 0; + se::DeviceMemoryAllocator* device_allocator) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules. @@ -163,7 +178,7 @@ class Compiler { RunBackendOnModuleGroup( std::unique_ptr module_group, std::vector> stream_exec, - DeviceMemoryAllocator* device_allocator) = 0; + se::DeviceMemoryAllocator* device_allocator) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding @@ -176,7 +191,7 @@ class Compiler { virtual StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, - DeviceMemoryAllocator* device_allocator) = 0; + se::DeviceMemoryAllocator* device_allocator) = 0; // Returns the backend configurations that the backend will consider for the // given HLO. Returns no configurations if the backend does not support diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index f1d0ca44f08..301ac9cc3d4 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -21,24 +21,28 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { +namespace { // Tries to replace a conditional with a call operation of the corresponding // computation. If the given conditional has a constant branch_index, tries to // replace it with a call to its corresponding branch computation and then // inline that computation. // // Returns true if it made a change to the graph. -static StatusOr TryRemoveConditional(HloInstruction* conditional) { +StatusOr TryRemoveConditional(HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); // Do not remove conditionals that contain side-effecting instructions or // have control predecessors/successors in either true/false computation. 
@@ -80,6 +84,99 @@ static StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
   return true;
 }
 
+StatusOr<bool> TryRemoveUnusedConditionalOperands(
+    HloInstruction* conditional,
+    std::map<HloComputation*, std::set<int64>>* changed_computations) {
+  // Avoid dealing with sharding.
+  if (conditional->has_sharding()) {
+    return false;
+  }
+  std::vector<std::set<int64>> tuple_indices_to_keep(
+      conditional->branch_count());
+  bool will_change = false;
+  for (int64 i = 0; i < conditional->branch_count(); ++i) {
+    HloComputation* computation = conditional->branch_computation(i);
+    if (changed_computations->count(computation) > 0) {
+      will_change = true;
+      break;
+    }
+    HloInstruction* param = computation->parameter_instruction(0);
+    // Do not remove the root instruction.
+    if (param == computation->root_instruction()) {
+      return false;
+    }
+    // There is nothing to be removed for non-tuple operands.
+    if (!param->shape().IsTuple()) {
+      return false;
+    }
+    for (HloInstruction* user : param->users()) {
+      // If the user is not a get tuple element, assume it is unsafe to remove
+      // elements from the tuple.
+      if (user->opcode() != HloOpcode::kGetTupleElement) {
+        return false;
+      }
+      tuple_indices_to_keep[i].insert(user->tuple_index());
+    }
+    // If not all tuple elements are used in this conditional branch, some can
+    // be removed from the computation.
+    if (tuple_indices_to_keep[i].size() !=
+        ShapeUtil::TupleElementCount(param->shape())) {
+      will_change = true;
+    }
+  }
+
+  if (!will_change) {
+    return false;
+  }
+
+  for (int64 branch = 0; branch < conditional->branch_count(); ++branch) {
+    const Shape& old_shape = conditional->operand(branch + 1)->shape();
+    int64 old_tuple_element_count = ShapeUtil::TupleElementCount(old_shape);
+    // Clone the computation in case it is called by another instruction.
+    HloComputation* computation = conditional->branch_computation(branch);
+    if (changed_computations
+            ->insert({computation, tuple_indices_to_keep[branch]})
+            .second) {
+      HloInstruction* param = computation->parameter_instruction(0);
+
+      // Create a new tuple shape based on the indices actually used by this
+      // branch.
+      std::vector<Shape> new_tuple_shapes;
+      new_tuple_shapes.reserve(tuple_indices_to_keep[branch].size());
+      std::vector<int64> map(old_tuple_element_count, -1);
+      for (int64 i : tuple_indices_to_keep[branch]) {
+        map[i] = new_tuple_shapes.size();
+        new_tuple_shapes.push_back(old_shape.tuple_shapes(i));
+      }
+      Shape tuple_shape = ShapeUtil::MakeTupleShape(new_tuple_shapes);
+      // Reset the parameter shape of the computation.
+      *param->mutable_shape() = tuple_shape;
+
+      // Reroute the GTE instructions to new tuple indices.
+      for (HloInstruction* user : param->users()) {
+        user->set_tuple_index(map[user->tuple_index()]);
+      }
+    }
+
+    // Reroute the operand tuple through a tuple of gte instructions of the
+    // original operand tuple.
+ const auto& to_keep = (*changed_computations)[computation]; + std::vector new_tuple_operands; + new_tuple_operands.reserve(to_keep.size()); + for (int64 i : to_keep) { + new_tuple_operands.push_back(conditional->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + old_shape.tuple_shapes(i), + conditional->mutable_operand(branch + 1), i))); + } + HloInstruction* new_tuple = conditional->parent()->AddInstruction( + HloInstruction::CreateTuple(new_tuple_operands)); + TF_RETURN_IF_ERROR( + conditional->ReplaceOperandWithDifferentShape(branch + 1, new_tuple)); + } + return true; +} +} // namespace StatusOr ConditionalSimplifier::Run(HloModule* module) { XLA_VLOG_LINES( @@ -98,8 +195,13 @@ StatusOr ConditionalSimplifier::Run(HloModule* module) { } } + std::map> changed_computations; for (HloInstruction* conditional_op : conditional_ops) { TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); + if (!result) { + TF_ASSIGN_OR_RETURN(result, TryRemoveUnusedConditionalOperands( + conditional_op, &changed_computations)); + } changed |= result; } diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 289eb6d9023..9759526c6e0 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -157,5 +158,60 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } +TEST_F(ConditionalSimplifierTest, TrivalOperandsRemoved) { + absl::string_view hlo_string = + R"( +HloModule UnusedTupleOperands +on_false { + t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) parameter(0) + lhs = f32[20,40] get-tuple-element(t), index=0 + rhs = f32[40,40] get-tuple-element(t), index=1 + dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT result = (f32[20,40]) tuple(dot) +} + +on_true { + t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) parameter(0) + lhs = f32[20,40] get-tuple-element(t), index=2 + rhs = f32[40,40] get-tuple-element(t), index=3 + dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT result = (f32[20,40]) tuple(dot) +} + +ENTRY main { + c0_0 = f32[20,40] parameter(0) + c0_1 = f32[40,40] parameter(1) + c1_0 = f32[20,40] parameter(2) + c1_1 = f32[40,40] parameter(3) + p = pred[] parameter(4) + t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) tuple(c0_0, c0_1, c1_0, c1_1) + ROOT result = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true +} +)"; + auto status = ParseHloString(hlo_string); + TF_ASSERT_OK(status.status()); + HloVerifier v(false, false); + TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); + EXPECT_TRUE( + ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie()); + TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); + EXPECT_EQ(status.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->operand(1) + ->shape() + .tuple_shapes() + 
.size(), + 2); + EXPECT_EQ(status.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->operand(2) + ->shape() + .tuple_shapes() + .size(), + 2); +} } // namespace + } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_to_select.cc b/tensorflow/compiler/xla/service/conditional_to_select.cc new file mode 100644 index 00000000000..d9b246bd628 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_to_select.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_to_select.h" + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +static StatusOr DoConditionalToSelect(HloInstruction* conditional) { + // Only allow conditional to select if the called computations + // do not have side effects. 
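The side-effect check that follows exists because, after conditional-to-select, both branch computations are called unconditionally and the predicate merely selects between their results. A toy standalone analogue (plain C++, not XLA) of why that is only safe for pure branches:

```cpp
#include <iostream>

int side_effects = 0;

// Stand-ins for the two branch computations.
int true_branch(int x) { ++side_effects; return x * 2; }
int false_branch(int x) { ++side_effects; return x + 1; }

// Conditional-style evaluation: only the chosen branch runs.
int conditional(bool pred, int x) {
  return pred ? true_branch(x) : false_branch(x);
}

// Select-style evaluation: both branches run, then one result is kept.
int select(bool pred, int x) {
  int if_result = true_branch(x);
  int else_result = false_branch(x);
  return pred ? if_result : else_result;
}

int main() {
  side_effects = 0;
  std::cout << conditional(true, 3) << " effects=" << side_effects << "\n";  // 6 effects=1
  side_effects = 0;
  std::cout << select(true, 3) << " effects=" << side_effects << "\n";       // 6 effects=2
}
```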
+ if (conditional->true_computation()->HasSideEffect() || + conditional->false_computation()->HasSideEffect()) { + VLOG(1) << "Not transforming conditional; branches have side effects:" + << conditional->ToString(); + return false; + } + + auto computation = conditional->parent(); + + // Create new instructions + HloInstruction* if_call_op = + computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(1)}, + conditional->true_computation())); + conditional->SetupDerivedInstruction(if_call_op); + HloInstruction* else_call_op = + computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(2)}, + conditional->false_computation())); + conditional->SetupDerivedInstruction(else_call_op); + HloInstruction* condition = conditional->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * select_op, + MakeSelectHlo(condition, if_call_op, else_call_op, conditional)); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, select_op)); + TF_RETURN_IF_ERROR(CallInliner::Inline(if_call_op).status()); + TF_RETURN_IF_ERROR(CallInliner::Inline(else_call_op).status()); + return true; +} + +StatusOr ConditionalToSelect::Run(HloModule* module) { + std::unique_ptr call_graph = CallGraph::Build(module); + bool did_mutate = false; + VLOG(1) << "Running conditional-to-select pass"; + TF_RETURN_IF_ERROR( + call_graph->VisitNodes([&](const CallGraphNode& node) -> Status { + std::vector ToInline; + if (node.context() != CallContext::kParallel) { + return Status::OK(); + } + for (const CallSite& callsite : node.callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + VLOG(1) << "Visiting conditional: " << callsite.ToString(); + HloInstruction* conditional = callsite.instruction(); + TF_ASSIGN_OR_RETURN(bool result, + DoConditionalToSelect(conditional)); + did_mutate |= result; + } + } + return Status::OK(); + })); + return did_mutate; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/conditional_to_select.h similarity index 58% rename from tensorflow/compiler/xla/service/implicit_broadcast_remover.h rename to tensorflow/compiler/xla/service/conditional_to_select.h index 9c48b7db613..3b99e8192e8 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/conditional_to_select.h @@ -13,30 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ - -#include +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_TO_SELECT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_TO_SELECT_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { -// Pass which replaces all implicit broadcasts with their equivalent sequence of -// explicit broadcast and reshape instructions. -class ImplicitBroadcastRemover : public HloModulePass { +// A pass which transforms conditionals to selects in places where conditionals +// are legal, but not currently supported by the backends (e.g. 
inside kMap) +class ConditionalToSelect : public HloModulePass { public: - ImplicitBroadcastRemover() {} - ~ImplicitBroadcastRemover() override {} - - absl::string_view name() const override { - return "implicit-broadcast-remover"; - } + ~ConditionalToSelect() override = default; + absl::string_view name() const override { return "conditional-to-select"; } + // Run conditional to select on the given computation. Returns whether the + // computation was changed. StatusOr Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_TO_SELECT_H_ diff --git a/tensorflow/compiler/xla/service/conditional_to_select_test.cc b/tensorflow/compiler/xla/service/conditional_to_select_test.cc new file mode 100644 index 00000000000..fe9c6addfc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_to_select_test.cc @@ -0,0 +1,188 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_to_select.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using ConditionalToSelectTest = HloTestBase; +using ::testing::_; + +// Test that a conditional of simple constants is transformed to a select +TEST_F(ConditionalToSelectTest, MapConditionalConstants) { + const string hlo_text = R"( +HloModule MapConditionalConstants + +if { + %pif = () parameter(0) + ROOT %cif = f32[] constant(0) +} + +else { + %pelse = () parameter(0) + ROOT %celse = f32[] constant(1) +} + +mapped { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + %lt = pred[] compare(%a, %b), direction=LT + %t = () tuple() + ROOT %conditional = f32[] conditional(%lt, %t, %t), true_computation=if, false_computation=else +} + +ENTRY comp { + %p1 = f32[1000]{0} parameter(0) + %p2 = f32[1000]{0} parameter(1) + ROOT %mapped = f32[1000]{0} map(%p1, %p2), dimensions={0}, to_apply=mapped +} +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie(); + ConditionalToSelect pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_EQ(root->opcode(), HloOpcode::kMap); + HloComputation* mapped = root->called_computations()[0]; + EXPECT_THAT(mapped->root_instruction(), + op::Select(op::Lt(op::Parameter(0), op::Parameter(1)), + op::Constant(), op::Constant())); +} + +// Test that the condition gets 
broadcasted for feeding into +// select when the output is non-scalar. +TEST_F(ConditionalToSelectTest, MapConditionalNonScalar) { + const string hlo_text = R"( +HloModule MapConditionalNonScalar + +if { + %pif = () parameter(0) + %zero = f32[] constant(0) + ROOT %zero_broadcasted = f32[2,2]{1,0} broadcast(%zero), dimensions={} +} + +else { + %pelse = () parameter(0) + %one = f32[] constant(0) + ROOT %one_broadcasted = f32[2,2]{1,0} broadcast(%one), dimensions={} +} + +add { + %add_lhs = f32[] parameter(0) + %add_rhs = f32[] parameter(1) + ROOT %add = f32[] add(%add_lhs, %add_rhs) +} + +mapped { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + %lt = pred[] compare(%a, %b), direction=LT + %t = () tuple() + %conditional = f32[2,2]{1,0} conditional(%lt, %t, %t), true_computation=if, false_computation=else + %zero = f32[] constant(0) + ROOT %reduced = f32[] reduce(%conditional, %zero), dimensions={0,1}, to_apply=add +} + +ENTRY comp { + %p1 = f32[1000]{0} parameter(0) + %p2 = f32[1000]{0} parameter(1) + ROOT %mapped = f32[1000]{0} map(%p1, %p2), dimensions={0}, to_apply=mapped +} +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie(); + ConditionalToSelect pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_EQ(root->opcode(), HloOpcode::kMap); + HloComputation* mapped = root->called_computations()[0]; + EXPECT_THAT( + mapped->root_instruction(), + op::Reduce( + op::Select(op::Broadcast(op::Lt(op::Parameter(0), op::Parameter(1))), + _, _), + _)); +} + +// Test that conditionals of tuple type get turned into kTupleSelect +TEST_F(ConditionalToSelectTest, MapConditionalTuples) { + const string hlo_text = R"( +HloModule MapConditionalTuples + +if { + %pif = () parameter(0) + %zero = f32[] constant(0) + ROOT %tup = (f32[],f32[]) tuple(%zero, %zero) +} + +else { + %pelse = () parameter(0) + %one = f32[] constant(0) + ROOT %tup = (f32[],f32[]) tuple(%one, %one) +} + +add { + %add_lhs = f32[] parameter(0) + %add_rhs = f32[] parameter(1) + ROOT %add = f32[] add(%add_lhs, %add_rhs) +} + +mapped { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + %lt = pred[] compare(%a, %b), direction=LT + %t = () tuple() + %conditional = (f32[], f32[]) conditional(%lt, %t, %t), true_computation=if, false_computation=else + %el1 = f32[] get-tuple-element(%conditional), index=0 + %el2 = f32[] get-tuple-element(%conditional), index=1 + %reduced = f32[] add(%el1, %el2) +} + +ENTRY comp { + %p1 = f32[1000]{0} parameter(0) + %p2 = f32[1000]{0} parameter(1) + ROOT %mapped = f32[1000]{0} map(%p1, %p2), dimensions={0}, to_apply=mapped +} +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie(); + ConditionalToSelect pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_EQ(root->opcode(), HloOpcode::kMap); + HloComputation* mapped = root->called_computations()[0]; + EXPECT_THAT(mapped->root_instruction(), + op::Add(op::GetTupleElement(op::TupleSelect(_, _, _)), _)); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index f11f9e5fc29..434bbe9ffd5 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -191,8 +191,9 @@ HloInstruction* GetExpandedFilterMask( // linspace to create a diagonal 
predicate. Shape predicate_shape = ShapeUtil::MakeShape( PRED, AsInt64Slice(expanded_filter_shape.dimensions())); - return add_instruction(HloInstruction::CreateBinary( - predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); + return add_instruction(HloInstruction::CreateCompare( + predicate_shape, broadcasted_mask1, broadcasted_mask2, + ComparisonDirection::kEq)); } // This function handles batch_group_counts which are relevant only for diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 79b010e2f1b..4b8d20f53e5 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -956,14 +957,6 @@ class CopyRemover { absl::flat_hash_map copy_map_; }; -void MaybeDumpModule(const string& message, const HloModule& module) { - if (VLOG_IS_ON(3)) { - VLOG(3) << message; - XLA_VLOG_LINES(3, module.ToString()); - hlo_graph_dumper::MaybeDumpHloModule(module, message); - } -} - } // namespace // Add kCopy instructions to the given module to guarantee there is no @@ -1095,18 +1088,8 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, return Status::OK(); } -Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering, - HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); - TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); - return Status::OK(); -} - Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module) { - MaybeDumpModule("after adding copies to resolve interference", *module); - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); @@ -1130,8 +1113,6 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, } } } - MaybeDumpModule("after removing unnecessary copies", *module); - return Status::OK(); } @@ -1160,8 +1141,6 @@ StatusOr CopyInsertion::Run(HloModule* module) { // interference. If all copies were added in step (1) then copy removal would // also have to reason about things like constants and parameters live out of // the computation. 
- MaybeDumpModule("before copy insertion", *module); - std::unique_ptr call_graph = CallGraph::Build(module); if (!call_graph->IsFlattened()) { return FailedPrecondition( @@ -1190,22 +1169,20 @@ StatusOr CopyInsertion::Run(HloModule* module) { HloDCE dce; TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); + DumpHloModuleDuringPassIfEnabled( + name(), "after adding copies to resolve interference", *module); - DependencyHloOrdering dep_ordering(module); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module)); - - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module)); + TF_RETURN_IF_ERROR( + RemoveUnnecessaryCopies(DependencyHloOrdering(module), module)); + DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies", + *module); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); - - MaybeDumpModule("after adding special-case copies", *module); + DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies", + *module); TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK( - VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module)); - - MaybeDumpModule("after copy insertion", *module); if (VLOG_IS_ON(1)) { int64 num_total_copies = 0; diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 8866b5050bf..f7e19970feb 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -89,11 +89,6 @@ class CopyInsertion : public HloModulePass { // Status AddSpecialCaseCopies(HloModule* module); - // Verifies that no HLO values have interfering live ranges using the given - // ordering. - Status VerifyNoLiveRangeInterference(const HloOrdering& ordering, - HloModule* module); - protected: // Override which requires the caller to pass in a call graph. 
virtual Status AddSpecialCaseCopies(const CallGraph& call_graph, diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 4391bdcba53..6fa3161e578 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -420,9 +420,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, - induction_variable, limit_const)); + builder.AddInstruction(HloInstruction::CreateCompare( + condition_result_shape_, induction_variable, limit_const, + ComparisonDirection::kLt)); return builder.Build(); } @@ -1842,7 +1842,7 @@ HloModule TokensShouldNotBeCopied %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %TokensShouldNotBeCopied () -> s32[] { @@ -2060,7 +2060,7 @@ if-condition.v4 { p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 constant.4 = s32[] constant(0) - ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) + ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ } _functionalize_body_1__.v28 { @@ -2070,7 +2070,7 @@ _functionalize_body_1__.v28 { add.4 = s32[] add(get-tuple-element.68, constant.7) get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 - less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70) + less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE constant.8 = s32[] constant(0) select = s32[] select(less-than-or-equal-to, constant.8, constant.7) get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 @@ -2087,7 +2087,7 @@ cond_wrapper.v3.1 { inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 constant.11 = s32[] constant(7) - ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11) + ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT } _functionalize_body_2__.v25 { @@ -2110,7 +2110,7 @@ cond_wrapper.v3.2 { inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 constant.13 = s32[] constant(5) - ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13) + ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT } ENTRY TestComputation { @@ -2142,7 +2142,7 @@ if-condition.v4 { p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 constant.4 = s32[] constant(0) - ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) + ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ } if-body.v5.1 { @@ -2159,7 +2159,7 @@ if-condition.v4.1 { p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 
constant.6 = s32[] constant(1) - ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6) + ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ } _functionalize_body_1__.v28 { @@ -2169,7 +2169,7 @@ _functionalize_body_1__.v28 { add.4 = s32[] add(get-tuple-element.72, constant.7) get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 - less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74) + less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE constant.8 = s32[] constant(0) select = s32[] select(less-than-or-equal-to, constant.8, constant.7) get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 @@ -2187,7 +2187,7 @@ cond_wrapper.v3.1 { inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0 constant.11 = s32[] constant(7) - ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11) + ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT } _functionalize_body_2__.v25 { @@ -2210,7 +2210,7 @@ cond_wrapper.v3.2 { inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 constant.13 = s32[] constant(5) - ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13) + ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT } ENTRY TestComputation { diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index c8fef147b85..09f5c859af4 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -62,7 +62,7 @@ cc_library( srcs = ["buffer_info_util.cc"], hdrs = ["buffer_info_util.h"], deps = [ - "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/types:span", @@ -93,11 +93,14 @@ cc_library( "@com_google_absl//absl/strings", ":target_machine_features", "@com_google_absl//absl/types:span", - "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", + "//tensorflow/compiler/xla/service:conditional_to_select", "//tensorflow/compiler/xla/service:scatter_expander", + "//tensorflow/compiler/xla/service:slice_sinker", + "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", @@ -179,7 +182,6 @@ cc_library( deps = [ ":compiler_functor", ":cpu_runtime", - ":custom_call_target_registry", ":disassembler", ":orc_jit_memory_mapper", ":runtime_fp16", @@ -200,6 +202,7 @@ cc_library( "@llvm//:orc_jit", "@llvm//:support", "@llvm//:target", # fixdeps: keep + "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", @@ -207,6 +210,12 @@ cc_library( ] + ORC_JIT_MEMORY_MAPPER_TARGETS, ) +cc_library( + name = "runtime_lightweight_check", + hdrs = ["runtime_lightweight_check.h"], + copts = runtime_copts(), +) + cc_library( name = 
"runtime_fp16", srcs = [ @@ -236,7 +245,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:computation_layout", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", @@ -245,6 +253,8 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor/host:host_stream", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -338,15 +348,15 @@ cc_library( srcs = ["ir_function.cc"], hdrs = ["ir_function.h"], deps = [ + ":cpu_runtime", ":ir_emission_utils", ":shape_partition", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -377,6 +387,7 @@ cc_library( ":vector_support_library", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", @@ -499,6 +510,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", @@ -534,6 +546,7 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":runtime_lightweight_check", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", "//tensorflow/core/kernels:eigen_helpers", @@ -569,6 +582,7 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":runtime_lightweight_check", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:framework_lite", @@ -593,6 +607,7 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":runtime_lightweight_check", ":runtime_matvec", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", @@ -624,6 +639,7 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ + ":runtime_lightweight_check", "//tensorflow/core:framework_lite", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", @@ -748,6 +764,7 @@ cc_library( ":ir_emission_utils", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", ], ) @@ -929,17 +946,6 @@ cc_library( ], ) -cc_library( - name = "custom_call_target_registry", - srcs = [ - "custom_call_target_registry.cc", - ], - hdrs = [ - "custom_call_target_registry.h", - ], - visibility = ["//visibility:public"], 
-) - cc_library( name = "orc_jit_memory_mapper", srcs = ["orc_jit_memory_mapper.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc index 1942ea1a2af..efc6ac06509 100644 --- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc @@ -14,11 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" namespace xla { namespace cpu { -using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo; +using BufferInfo = cpu_function_runtime::BufferInfo; std::vector CreateBufferInfosFromBufferAssignment( const BufferAssignment& buffer_assignment) { diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h index e9ee928ab29..dd0bd9e9a0b 100644 --- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h @@ -17,14 +17,14 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ #include "absl/types/span.h" -#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" namespace xla { namespace cpu { // Creates and returns a list of BufferInfo instances containing relevant // information from `buffer_assignment`. -std::vector<::tensorflow::cpu_function_runtime::BufferInfo> +std::vector CreateBufferInfosFromBufferAssignment( const BufferAssignment& buffer_assignment); @@ -34,8 +34,7 @@ CreateBufferInfosFromBufferAssignment( // If this function returns V then entry parameter i has buffer allocation index // V[i]. std::vector CreateArgIndexTableFromBufferInfos( - absl::Span - buffer_infos); + absl::Span buffer_infos); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 414eacddfc7..7dab505c724 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -66,14 +66,9 @@ class FilteredPassManager : public llvm::legacy::PassManager { explicit FilteredPassManager(bool disable_expensive_passes) : disable_expensive_passes_(disable_expensive_passes) {} void add(llvm::Pass* p) override { - llvm::StringRef PassName = p->getPassName(); - if (PassName.contains("Warn about non-applied transformations")) { - delete p; - return; - } if (disable_expensive_passes_) { + llvm::StringRef PassName = p->getPassName(); if (PassName.contains("Unroll loops")) { - delete p; return; } } @@ -94,7 +89,7 @@ std::unique_ptr CompilerFunctor::operator()( XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); if (pre_optimization_hook_) { - TF_CHECK_OK(pre_optimization_hook_(module)); + pre_optimization_hook_(module); } // Add the appropriate TargetLibraryInfo and TargetTransformInfo. 
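The FilteredPassManager change above keeps the name-based filtering approach: the wrapper drops the loop-unrolling pass when expensive passes are disabled and forwards everything else to the base pass manager. A toy standalone analogue of that decorator pattern (strings instead of llvm::Pass, illustrative only):

```cpp
#include <iostream>
#include <string>
#include <vector>

class PassManager {
 public:
  virtual ~PassManager() = default;
  virtual void add(const std::string& pass_name) { passes_.push_back(pass_name); }
  const std::vector<std::string>& passes() const { return passes_; }

 private:
  std::vector<std::string> passes_;
};

class FilteredPassManager : public PassManager {
 public:
  explicit FilteredPassManager(bool disable_expensive_passes)
      : disable_expensive_passes_(disable_expensive_passes) {}
  void add(const std::string& pass_name) override {
    if (disable_expensive_passes_ &&
        pass_name.find("Unroll loops") != std::string::npos) {
      return;  // Drop the expensive pass instead of scheduling it.
    }
    PassManager::add(pass_name);
  }

 private:
  bool disable_expensive_passes_;
};

int main() {
  FilteredPassManager pm(/*disable_expensive_passes=*/true);
  pm.add("Combine redundant instructions");
  pm.add("Unroll loops");
  std::cout << pm.passes().size() << "\n";  // 1
}
```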
@@ -128,7 +123,10 @@ std::unique_ptr CompilerFunctor::operator()( CHECK(!llvm::verifyModule(module, &llvm::dbgs())); - runtime::RewriteIRRuntimeFunctions(&module, enable_fast_math_); + const auto& opts = target_machine_->Options; + bool fast_math_enabled = opts.UnsafeFPMath && opts.NoInfsFPMath && + opts.NoNaNsFPMath && opts.NoSignedZerosFPMath; + runtime::RewriteIRRuntimeFunctions(&module, fast_math_enabled); // Buffer for holding machine code prior to constructing the ObjectFile. llvm::SmallVector stream_buffer; @@ -138,7 +136,7 @@ std::unique_ptr CompilerFunctor::operator()( XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); if (post_optimization_hook_) { - TF_CHECK_OK(post_optimization_hook_(module)); + post_optimization_hook_(module); } // Generate code. @@ -150,17 +148,11 @@ std::unique_ptr CompilerFunctor::operator()( std::unique_ptr memory_buffer( new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); - if (VLOG_IS_ON(2)) { + if (post_codegen_hook_) { llvm::Expected> obj_file = llvm::object::ObjectFile::createObjectFile(*memory_buffer); if (obj_file) { - StatusOr disasm_result = - disassembler_->DisassembleObjectFile(*obj_file.get()); - if (disasm_result.ok()) { - XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text); - } else { - LOG(WARNING) << "Could not disassemble object file!"; - } + post_codegen_hook_(*obj_file.get()); } else { LOG(WARNING) << "Could convert memory buffer to object file!"; } @@ -220,7 +212,6 @@ void CompilerFunctor::AddOptimizationPasses( builder.Inliner = llvm::createAlwaysInlinerLegacyPass(); } - builder.DisableUnitAtATime = false; builder.DisableUnrollLoops = opt_level == 0; builder.LoopVectorize = opt_level > 0 && size_level == 0; builder.SLPVectorize = opt_level > 1 && size_level == 0; diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index c38b896c501..fdaba451c19 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -20,7 +20,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Target/TargetMachine.h" -#include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/core/platform/logging.h" @@ -32,19 +31,19 @@ namespace cpu { class CompilerFunctor { public: explicit CompilerFunctor( - llvm::TargetMachine* target_machine, const Disassembler* disassembler, - int opt_level, bool optimize_for_size, bool enable_fast_math, - bool disable_expensive_passes, + llvm::TargetMachine* target_machine, int opt_level, + bool optimize_for_size, bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook = nullptr, - LLVMCompiler::ModuleHook post_optimization_hook = nullptr) + LLVMCompiler::ModuleHook post_optimization_hook = nullptr, + std::function post_codegen_hook = + nullptr) : target_machine_(target_machine), - disassembler_(CHECK_NOTNULL(disassembler)), opt_level_(opt_level), optimize_for_size_(optimize_for_size), - enable_fast_math_(enable_fast_math), disable_expensive_passes_(disable_expensive_passes), - pre_optimization_hook_(pre_optimization_hook), - post_optimization_hook_(post_optimization_hook) {} + pre_optimization_hook_(std::move(pre_optimization_hook)), + post_optimization_hook_(std::move(post_optimization_hook)), + post_codegen_hook_(std::move(post_codegen_hook)) {} // Compile a Module to an ObjectFile. 
std::unique_ptr operator()( @@ -61,13 +60,12 @@ class CompilerFunctor { unsigned opt_level, unsigned size_level) const; llvm::TargetMachine* target_machine_; - const Disassembler* disassembler_; const unsigned opt_level_; const bool optimize_for_size_; - const bool enable_fast_math_; const bool disable_expensive_passes_; LLVMCompiler::ModuleHook pre_optimization_hook_; LLVMCompiler::ModuleHook post_optimization_hook_; + std::function post_codegen_hook_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index eb5d843fe8b..06ea1e2f8bd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include + #include #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include @@ -41,6 +42,7 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -52,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/cholesky_expander.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/conditional_to_select.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" @@ -70,6 +73,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -94,6 +98,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" +#include "tensorflow/compiler/xla/service/slice_sinker.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/triangular_solve_expander.h" @@ -111,7 +116,7 @@ limitations under the License. namespace xla { namespace cpu { -using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo; +using BufferInfo = cpu_function_runtime::BufferInfo; CpuAotCompilationOptions::CpuAotCompilationOptions( string triple, string cpu_name, string features, string entry_point_name, @@ -249,6 +254,11 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + + // Remove zero-sized HLO from the input so that other passes don't have to + // handle it. 
+ pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); @@ -256,6 +266,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -265,7 +276,10 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( // pass. pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(/*decompose_batch_dot=*/false); + pipeline.AddPass(); + // After canonicalization, there may be more batch dots that can be + // simplified. + pipeline.AddPass(); auto cost_model = [](HloInstruction* conv) { // We need a cost model for CPUs. Currently, do nothing. return false; @@ -302,6 +316,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pass.AddPass(); pass.AddPass(); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -318,6 +333,11 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( }, TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); + + pipeline.AddPass( + module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, target_machine_features); + pipeline.AddPass(); pipeline.AddPass(); @@ -325,10 +345,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); - - pipeline.AddPass( - module->mutable_entry_computation_layout(), - LayoutAssignment::InstructionCanChangeLayout, target_machine_features); return pipeline.Run(module).status(); } @@ -348,13 +364,10 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( { auto& pass = pipeline.AddPass>( "simplification after layout assignement"); - // TODO(b/117156505): When the bug is fixed, the CPU backend should not - // produce layout changing elementwise operations. We will then pass - // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to - // enable stricter verification. pass.AddInvariantChecker( /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); options.set_enable_dot_strength_reduction(false); @@ -410,10 +423,20 @@ auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; }; llvm::TargetOptions CompilerTargetOptions( const HloModuleConfig& module_config) { llvm::TargetOptions target_options; - llvm_ir::SetTargetOptions( - /*fast_math_enabled=*/module_config.debug_options() - .xla_cpu_enable_fast_math(), - &target_options); + // In LLVM backend flags, UnsafeFPMath does not explicitly imply NoInfs, etc. 
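The branch that follows maps XLA's fast-math debug options onto LLVM's individual TargetOptions flags (UnsafeFPMath, NoInfsFPMath, NoNaNsFPMath, NoSignedZerosFPMath). A standalone sketch of that mapping using plain structs (illustrative stand-ins, not the real DebugOptions/TargetOptions types):

```cpp
#include <iostream>

struct ToyDebugOptions {
  bool xla_cpu_enable_fast_math = true;
  bool xla_cpu_fast_math_honor_infs = false;
  bool xla_cpu_fast_math_honor_nans = false;
};

struct ToyTargetOptions {
  bool UnsafeFPMath = false;
  bool NoInfsFPMath = false;
  bool NoNaNsFPMath = false;
  bool NoSignedZerosFPMath = false;
};

ToyTargetOptions CompilerTargetOptions(const ToyDebugOptions& debug) {
  ToyTargetOptions target;
  if (debug.xla_cpu_enable_fast_math) {
    target.UnsafeFPMath = true;
    // "Honor infs/nans" means the backend may NOT assume no-infs/no-nans.
    target.NoInfsFPMath = !debug.xla_cpu_fast_math_honor_infs;
    target.NoNaNsFPMath = !debug.xla_cpu_fast_math_honor_nans;
    target.NoSignedZerosFPMath = true;
  }  // Otherwise all flags stay false, as in the else branch of the hunk.
  return target;
}

int main() {
  ToyDebugOptions debug;
  debug.xla_cpu_fast_math_honor_nans = true;  // keep NaN semantics
  ToyTargetOptions t = CompilerTargetOptions(debug);
  std::cout << t.UnsafeFPMath << t.NoInfsFPMath << t.NoNaNsFPMath
            << t.NoSignedZerosFPMath << "\n";  // 1101
}
```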
+ if (module_config.debug_options().xla_cpu_enable_fast_math()) { + target_options.UnsafeFPMath = true; + target_options.NoInfsFPMath = + !module_config.debug_options().xla_cpu_fast_math_honor_infs(); + target_options.NoNaNsFPMath = + !module_config.debug_options().xla_cpu_fast_math_honor_nans(); + target_options.NoSignedZerosFPMath = true; + } else { + target_options.UnsafeFPMath = false; + target_options.NoInfsFPMath = false; + target_options.NoNaNsFPMath = false; + target_options.NoSignedZerosFPMath = false; + } return target_options; } @@ -432,53 +455,32 @@ llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) { } } -Status InitializeModuleHooks( +std::pair GetIRModuleHooks( const HloModule& hlo_module, const LLVMCompiler::ModuleHook& user_pre_optimization_hook, - const LLVMCompiler::ModuleHook& user_post_optimization_hook, - LLVMCompiler::ModuleHook* pre_optimization_ir_hook, - LLVMCompiler::ModuleHook* post_optimization_ir_hook) { - const string& ir_dump_directory = - hlo_module.config().debug_options().xla_dump_ir_to(); - if (ir_dump_directory.empty()) { - *pre_optimization_ir_hook = user_pre_optimization_hook; - *post_optimization_ir_hook = user_post_optimization_hook; - return Status::OK(); - } - - const string& hlo_module_name = hlo_module.name(); - + const LLVMCompiler::ModuleHook& user_post_optimization_hook) { // Create the IR hooks. If applicable, each IR hook does the following: // // * Calls the user supplied module hook. // * Writes out the IR to a file in the output directory designated by - // --xla_dump_ir_to - - *pre_optimization_ir_hook = - [user_pre_optimization_hook, ir_dump_directory, - hlo_module_name](const llvm::Module& llvm_module) { - if (user_pre_optimization_hook) { - TF_RETURN_IF_ERROR(user_pre_optimization_hook(llvm_module)); - } - return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, - /*hlo_module_name=*/hlo_module_name, - llvm_module, - /*optimized=*/false); - }; - - *post_optimization_ir_hook = - [user_post_optimization_hook, ir_dump_directory, - hlo_module_name](const llvm::Module& llvm_module) { - if (user_post_optimization_hook) { - TF_RETURN_IF_ERROR(user_post_optimization_hook(llvm_module)); - } - return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, - /*hlo_module_name=*/hlo_module_name, - llvm_module, - /*optimized=*/true); - }; - - return Status::OK(); + // --xla_dump_to + const HloModule* hlo_module_ptr = &hlo_module; + auto hook = [user_pre_optimization_hook, user_post_optimization_hook, + hlo_module_ptr](bool optimized, + const llvm::Module& llvm_module) { + const auto& user_hook = + !optimized ? user_pre_optimization_hook : user_post_optimization_hook; + if (user_hook) { + user_hook(llvm_module); + } + llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized); + }; + return {[hook](const llvm::Module& llvm_module) { + return hook(/*optimized=*/false, llvm_module); + }, + [hook](const llvm::Module& llvm_module) { + return hook(/*optimized=*/true, llvm_module); + }}; } Status VerifyLlvmModule(const llvm::Module& llvm_module) { @@ -492,7 +494,7 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) { << "Invalid LLVM IR before optimizations:\n" << err_stream.str() << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " - "Rerun with --xla_dump_ir_to to get the IR. "; + "Rerun with --xla_dump_to to get the IR. 
"; return Status::OK(); } @@ -535,10 +537,7 @@ Status CreateHloProfilingArtifacts( StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* /*stream_exec*/, - DeviceMemoryAllocator* /*device_allocator*/) { - VLOG(2) << "Before optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - + se::DeviceMemoryAllocator* /*device_allocator*/) { std::unique_ptr jit_target_machine = SimpleOrcJIT::InferTargetMachineForJIT( CompilerTargetOptions(module->config()), @@ -546,29 +545,72 @@ StatusOr> CpuCompiler::RunHloPasses( TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, jit_target_machine.get())); - - VLOG(2) << "After optimization:"; - XLA_VLOG_LINES(2, module->ToString()); return std::move(module); } +namespace { + +// Post-compilation callback functor for use by SimpleOrcJIT. +// +// Dumps disassembled machine code if dumping is enabled for the module. +struct OrcJITPostCompilationHook { + // Gets an std::function that implements this hook. + static std::function Create( + const HloModule* module) { + // This struct is not copyable, but std::functions must be. So to create an + // std::function out of this struct, we have to wrap it in a shared_ptr. + auto wrapped = std::make_shared(module); + return [wrapped](const llvm::object::ObjectFile& obj_file) { + (*wrapped)(obj_file); + }; + } + + // Constructor can't be private because we want to call it from + // std::make_shared, but users should call Create() instead. + explicit OrcJITPostCompilationHook(const HloModule* module) + : module(module), + target_machine(SimpleOrcJIT::InferTargetMachineForJIT( + CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config()))), + disassembler(*target_machine) {} + + private: + void operator()(const llvm::object::ObjectFile& obj_file) { + if (!DumpingEnabledForHloModule(*module)) { + return; + } + StatusOr disasm_or = + disassembler.DisassembleObjectFile(obj_file); + string text = disasm_or.ok() ? std::move(disasm_or).ValueOrDie().text + : absl::StrCat("Error disassembling: ", + disasm_or.status().ToString()); + DumpToFileInDirOrStdout(*module, /*file_suffix=*/"s", text); + } + + const HloModule* module; + // disassembler keeps references to data inside of target_machine. + std::unique_ptr target_machine; + Disassembler disassembler; +}; + +} // namespace + StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* /*device_allocator*/) { - const string timer_message = - "Compiling [" + module->name() + "] for CPU using JIT"; - XLA_SCOPED_LOGGING_TIMER(timer_message); - + se::DeviceMemoryAllocator* /*device_allocator*/) { VLOG(1) << "Compiling: " << module->name(); + XLA_SCOPED_LOGGING_TIMER( + absl::StrFormat("Compiling [%s] for CPU using JIT", module->name())); + TF_RET_CHECK(stream_exec != nullptr); std::call_once(llvm_command_line_options_initialized, &llvm_ir::InitializeLLVMCommandLineOptions, module->config()); ModuleHook pre_optimization_ir_hook; ModuleHook post_optimization_ir_hook; - TF_RETURN_IF_ERROR(InitializeModuleHooks( - *module, user_pre_optimization_hook_, user_post_optimization_hook_, - &pre_optimization_ir_hook, &post_optimization_ir_hook)); + std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) = + GetIRModuleHooks(*module, user_pre_optimization_hook_, + user_post_optimization_hook_); // Compile must be thread-safe so create a new LLVM context for the module. 
auto llvm_context = absl::make_unique(); @@ -579,9 +621,9 @@ StatusOr> CpuCompiler::RunBackend( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), - pre_optimization_ir_hook, post_optimization_ir_hook); + pre_optimization_ir_hook, post_optimization_ir_hook, + OrcJITPostCompilationHook::Create(module.get())); llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); @@ -602,8 +644,6 @@ StatusOr> CpuCompiler::RunBackend( // ownership is std::moved. const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - const string xla_dump_optimized_hlo_proto_to = - module->config().debug_options().xla_dump_optimized_hlo_proto_to(); // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis @@ -620,15 +660,7 @@ StatusOr> CpuCompiler::RunBackend( BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. - XLA_VLOG_LINES(2, assignment->ToString()); - - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); - } + DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations"); // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from @@ -683,7 +715,6 @@ StatusOr> CpuCompiler::RunBackend( ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } - XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. @@ -714,15 +745,29 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, // We can pass just one llvm::TargetOptions when we compile the LLVM module, // so we bail if the configs have conflicting flags. At the moment, the only - // flag that needs to be consistent is fast-math. - const bool fast_math_enabled = - modules[0]->config().debug_options().xla_cpu_enable_fast_math(); - for (const auto& module : modules) { - if (module->config().debug_options().xla_cpu_enable_fast_math() != - fast_math_enabled) { - return InvalidArgument( - "All HLO module configs must have the same value for " - "xla_enable_fast_math."); + // flags that need to be consistent are for fast-math. + for (const auto& fn_and_name : + {std::make_pair(&DebugOptions::xla_cpu_enable_fast_math, + "xla_cpu_enable_fast_math"), + std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_infs, + "xla_cpu_fast_math_honor_infs"), + std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_nans, + "xla_cpu_fast_math_honor_nans")}) { + // This only works because each of the method pointers above returns a bool. + // Otherwise we'd have to do some template magic. 
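The loop body that follows reads each fast-math flag through a pointer to member function paired with its name, so one loop can report exactly which flag differs between module configs. A standalone sketch of that pointer-to-member idiom with a toy config type (illustrative names only, not the real DebugOptions):

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct ToyDebugOptions {
  bool xla_cpu_enable_fast_math() const { return fast_math; }
  bool xla_cpu_fast_math_honor_infs() const { return honor_infs; }
  bool xla_cpu_fast_math_honor_nans() const { return honor_nans; }
  bool fast_math = true, honor_infs = false, honor_nans = false;
};

// Returns an error message, or "" if all configs agree on all listed flags.
std::string CheckConsistentFlags(const std::vector<ToyDebugOptions>& configs) {
  using FlagGetter = bool (ToyDebugOptions::*)() const;
  const std::pair<FlagGetter, const char*> flags[] = {
      {&ToyDebugOptions::xla_cpu_enable_fast_math, "xla_cpu_enable_fast_math"},
      {&ToyDebugOptions::xla_cpu_fast_math_honor_infs, "xla_cpu_fast_math_honor_infs"},
      {&ToyDebugOptions::xla_cpu_fast_math_honor_nans, "xla_cpu_fast_math_honor_nans"},
  };
  for (const auto& flag : flags) {
    bool first_val = (configs[0].*flag.first)();
    for (std::size_t i = 1; i < configs.size(); ++i) {
      if ((configs[i].*flag.first)() != first_val) {
        return std::string("configs disagree on ") + flag.second;
      }
    }
  }
  return "";
}

int main() {
  std::vector<ToyDebugOptions> configs(2);
  configs[1].honor_nans = true;
  std::cout << CheckConsistentFlags(configs) << "\n";
  // prints: configs disagree on xla_cpu_fast_math_honor_nans
}
```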
+ const auto& field_method_ptr = fn_and_name.first; + const auto& field_name = fn_and_name.second; + bool first_module_val = + (modules[0]->config().debug_options().*field_method_ptr)(); + for (int64 i = 0; i < modules.size(); ++i) { + bool cur_module_val = + (modules[i]->config().debug_options().*field_method_ptr)(); + if (first_module_val != cur_module_val) { + return InvalidArgument( + "All HLO module configs must have the same value for %s, but " + "module 0 and %d have different values (%d vs %d).", + field_name, i, first_module_val, cur_module_val); + } } } @@ -731,8 +776,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, } const CpuAotCompilationOptions& options = static_cast(aot_options); - llvm::StringRef target_triple = llvm_ir::AsStringRef(options.triple()); - llvm::Triple triple(llvm::Triple::normalize(target_triple)); + llvm::Triple triple(llvm::Triple::normalize(options.triple())); std::string error; const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); @@ -770,13 +814,12 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, pie_level = llvm::PIELevel::Large; break; } - llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); - llvm::StringRef features = llvm_ir::AsStringRef(options.features()); llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); - std::unique_ptr target_machine = absl::WrapUnique( - target->createTargetMachine(triple.getTriple(), cpu_name, features, - CompilerTargetOptions(modules[0]->config()), - reloc_model, llvm::None, opt_level)); + std::unique_ptr target_machine = + absl::WrapUnique(target->createTargetMachine( + triple.getTriple(), options.cpu_name(), options.features(), + CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None, + opt_level)); // Compile must be thread-safe so create a new LLVM context for the module. llvm::LLVMContext llvm_context; @@ -795,15 +838,9 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, HloModule* module = modules[i].get(); VLOG(1) << "Compiling ahead-of-time: " << module->name(); - VLOG(2) << "Before optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR( RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get())); - VLOG(2) << "After optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN(HloSchedule schedule, ScheduleModule(module, BufferSizeBytesFunction())); @@ -818,15 +855,11 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. 
- XLA_VLOG_LINES(2, assignment->ToString()); - - const string xla_dump_optimized_hlo_proto_to = - module->config().debug_options().xla_dump_optimized_hlo_proto_to(); - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "buffer_assignment", + assignment->ToString()); } + DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations"); std::unordered_map instruction_to_profile_idx; std::unordered_map computation_to_profile_idx; @@ -870,33 +903,42 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, /*is_top_level_computation=*/true, schedule.sequence(computation).instructions())); - CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); + CHECK(entry_function->getName() == entry_point_name); - ModuleHook pre_optimization_ir_dump_hook; - ModuleHook post_optimization_ir_dump_hook; - TF_RETURN_IF_ERROR(InitializeModuleHooks( - *module, user_pre_optimization_hook_, user_post_optimization_hook_, - &pre_optimization_ir_dump_hook, &post_optimization_ir_dump_hook)); + ModuleHook pre_optimization_ir_hook; + ModuleHook post_optimization_ir_hook; + std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) = + GetIRModuleHooks(*module, user_pre_optimization_hook_, + user_post_optimization_hook_); // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run the // pre-optimization IR dump hook before returning. { Status verify_status = VerifyLlvmModule(llvm_module); - if (!verify_status.ok() && pre_optimization_ir_dump_hook) { - pre_optimization_ir_dump_hook(llvm_module).IgnoreError(); + if (!verify_status.ok() && pre_optimization_ir_hook) { + pre_optimization_ir_hook(llvm_module); } TF_RETURN_IF_ERROR(verify_status); } - XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(llvm_module)); + auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) { + if (!DumpingEnabledForHloModule(*module)) { + return; + } + StatusOr disasm_or = + Disassembler(*target_machine).DisassembleObjectFile(obj_file); + string text = disasm_or.ok() + ? std::move(disasm_or).ValueOrDie().text + : absl::StrCat("Error disassembling: ", + disasm_or.status().ToString()); + DumpToFileInDirOrStdout(*module, /*file_suffix=*/"s", text); + }; - Disassembler disassembler(*target_machine); CompilerFunctor compiler_functor( - target_machine.get(), &disassembler, opt_level, + target_machine.get(), opt_level, options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), - pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); + pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook); std::unique_ptr object_file = compiler_functor(llvm_module); ObjectFileData object_file_data(object_file->getBufferStart(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index c67307548dd..dd15891f175 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -20,7 +20,7 @@ limitations under the License. 
#include "absl/types/span.h" #include "llvm/Target/TargetMachine.h" -#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -81,7 +81,7 @@ class CpuAotCompilationResult : public AotCompilationResult { public: CpuAotCompilationResult( ObjectFileData object_file_data, - std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos, + std::vector buffer_infos, int64 result_buffer_index, std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult(); @@ -91,8 +91,7 @@ class CpuAotCompilationResult : public AotCompilationResult { } const ObjectFileData& object_file_data() const { return object_file_data_; } - const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>& - buffer_infos() const { + const std::vector& buffer_infos() const { return buffer_infos_; } int64 result_buffer_index() const { return result_buffer_index_; } @@ -103,8 +102,7 @@ class CpuAotCompilationResult : public AotCompilationResult { // A list of BufferInfo objects describing the buffers used by the XLA // computation. - const std::vector<::tensorflow::cpu_function_runtime::BufferInfo> - buffer_infos_; + const std::vector buffer_infos_; // Contains which buffer index into |buffer_sizes| was designated to the // result of the computation. This buffer should be passed into the output @@ -135,11 +133,11 @@ class CpuCompiler : public LLVMCompiler { StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 23d0af34233..fffd1d0175f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -73,13 +73,13 @@ CpuExecutable::CpuExecutable( } StatusOr, - std::vector>> + std::vector>> CpuExecutable::CreateBufferTable( - DeviceMemoryAllocator* memory_allocator, int device_ordinal, + se::DeviceMemoryAllocator* memory_allocator, int device_ordinal, absl::Span arguments) { std::vector unowning_buffers( assignment_->Allocations().size()); - std::vector owning_buffers( + std::vector owning_buffers( assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); @@ -113,17 +113,17 @@ CpuExecutable::CreateBufferTable( } else { TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate( device_ordinal, buffer_size)); - unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase(); + unowning_buffers[i] = *owning_buffers[i]; VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << owning_buffers[i].opaque() << "]"; + << owning_buffers[i]->opaque() << "]"; } // Since the output buffer and all the temporary buffers were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. 
- TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i]->opaque(), buffer_size); } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, @@ -207,7 +207,7 @@ Status CpuExecutable::ExecuteComputeFunction( StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - absl::Span buffers) { + absl::Span buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( /*on_host_shape=*/result_shape(), @@ -216,7 +216,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( const HloInputOutputAliasConfig& input_output_alias = module().input_output_alias_config(); - // Move OwningDeviceMemory values which contain the array(s) of the result + // Move se::OwningDeviceMemory values which contain the array(s) of the result // into the respective location in ScopedShapedBuffer which is returned to the // caller. TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus( @@ -235,7 +235,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( const BufferAllocation::Slice slice, this->assignment_->GetUniqueSlice(src, buffer_source->index())); const BufferAllocation::Index buffer_index = slice.index(); - OwningDeviceMemory& buffer = buffers[buffer_index]; + se::OwningDeviceMemory& buffer = buffers[buffer_index]; if (!slice.allocation()->is_entry_computation_parameter()) { // If the buffer coming out of the result is from a parameter, the // owning buffer will be null, and that means the caller aliased some @@ -247,7 +247,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( // ownership, and hence a buffer coming from there cannot be part of // the new ScopedShapedBuffer we create for the result (which assumes // ownership). - *device_memory = buffer.Forget(); + *device_memory = buffer.Release(); } else { auto output_alias = input_output_alias.GetAliasedOutput( slice.allocation()->parameter_number(), @@ -297,8 +297,8 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( auto* host_stream = dynamic_cast( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector owning_buffers; + se::DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector owning_buffers; std::vector unowning_buffers; TF_ASSIGN_OR_RETURN( std::tie(unowning_buffers, owning_buffers), @@ -326,7 +326,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( CpuExecutable* executable; ServiceExecutableRunOptions run_options; std::vector unowning_buffers; - std::shared_ptr> buffers; + std::shared_ptr> buffers; HloExecutionProfile* hlo_execution_profile; void operator()() { @@ -338,7 +338,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( }; host_stream->EnqueueTask( AsyncRunTask{this, *run_options, std::move(unowning_buffers), - std::make_shared>( + std::make_shared>( std::move(owning_buffers)), hlo_execution_profile}); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 3b91b15ba9b..735a20749b9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -25,7 +25,6 @@ limitations under the License. 
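AsyncRunTask above moves the owning buffers into a shared_ptr so the enqueued task stays copyable while keeping the allocations alive until the host stream runs it. A minimal sketch of that lifetime trick, with std::thread standing in for the host stream and plain heap arrays standing in for device memory:

```c++
#include <cstdio>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

struct RunTask {
  // Shared ownership keeps the buffers alive until every copy of the task is
  // destroyed, no matter when the worker actually runs it.
  std::shared_ptr<std::vector<std::unique_ptr<int[]>>> buffers;

  void operator()() const {
    std::printf("running with %zu live buffers\n", buffers->size());
  }
};

int main() {
  std::vector<std::unique_ptr<int[]>> owning;
  owning.push_back(std::make_unique<int[]>(128));
  owning.push_back(std::make_unique<int[]>(256));

  // Move the owning vector into the task; `owning` is empty afterwards.
  RunTask task{std::make_shared<std::vector<std::unique_ptr<int[]>>>(std::move(owning))};
  std::thread worker(task);  // std::thread stands in for the host stream queue
  worker.join();
}
```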
#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -37,6 +36,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace cpu { @@ -111,8 +111,9 @@ class CpuExecutable : public Executable { // storage and the live-out buffer into which the computation writes it // result. StatusOr, - std::vector>> - CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + std::vector>> + CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator, + int device_ordinal, absl::Span arguments); // Calls the generated function performing the computation with the given @@ -126,7 +127,7 @@ class CpuExecutable : public Executable { // The addresses are set according to buffer assignment. StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - absl::Span buffers); + absl::Span buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 2e22cdaca52..3d7c06cfa27 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" namespace xla { namespace cpu { @@ -30,6 +31,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. return hlo.IsElementwise() || // + hlo.opcode() == HloOpcode::kBitcast || hlo.opcode() == HloOpcode::kBroadcast || hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || @@ -117,6 +119,14 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Don't fuse if fusing would cause too much code duplication because of + // inefficiencies in the fusion emitter. + // TODO(b/119692968): Remove this once the fusion emitter can handle + // arbitrary fusion nodes. + if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) { + return false; + } + if (consumer->opcode() == HloOpcode::kDot) { // In the general case we call out to optimized "black box" GEMM routines // for Dot, which precludes fusion. 
However, in very specific cases, we try @@ -144,8 +154,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } } - if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (consumer->IsLoopFusion()) { VLOG(2) << "Fusing: consumer is a fusion node."; return true; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 952950308be..b35026a41cd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -84,7 +84,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); } -TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { +TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); @@ -101,7 +101,8 @@ TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); - EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Fusion()); } TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { @@ -627,61 +628,6 @@ TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); } -// Tests that we do not fuse instructions in cases where instructions in the -// fusion would reuse elements from its operand due to an implicit broadcast. -TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastUnary) { - Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4}); - Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4}); - - HloComputation::Builder builder(TestName()); - - HloInstruction* small_param = - builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, small_shape, "param")); - HloInstruction* small_exp = builder.AddInstruction( - HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param)); - builder.AddInstruction( - HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp)); - - std::unique_ptr module = CreateNewUnverifiedModule(); - module->AddEntryComputation(builder.Build()); - - auto did_fusion = CpuInstructionFusion().Run(module.get()); - ASSERT_TRUE(did_fusion.ok()); - EXPECT_FALSE(did_fusion.ValueOrDie()); - ASSERT_THAT(module->entry_computation()->root_instruction(), - Not(op::Fusion())); -} - -// Like ReuseViaImplicitBroadcastUnary but with a binary operation. 
-TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { - Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4}); - Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4}); - - HloComputation::Builder builder(TestName()); - - HloInstruction* small_param = - builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, small_shape, "param")); - HloInstruction* large_param = - builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, large_shape, "param")); - HloInstruction* small_exp = builder.AddInstruction( - HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param)); - - builder.AddInstruction(HloInstruction::CreateBinary( - large_shape, HloOpcode::kAdd, small_exp, large_param)); - - std::unique_ptr module = CreateNewUnverifiedModule(); - module->AddEntryComputation(builder.Build()); - - auto did_fusion = CpuInstructionFusion().Run(module.get()); - ASSERT_TRUE(did_fusion.ok()); - EXPECT_FALSE(did_fusion.ValueOrDie()); - ASSERT_THAT(module->entry_computation()->root_instruction(), - Not(op::Fusion())); -} - void CreateComputationForDotAddOutputFusionTest(const string& test_name, HloModule* module, int m, int k, int n, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index d8878e622c0..b09fead49b7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/stream_executor/stream_executor.h" namespace xla { @@ -86,7 +87,11 @@ extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; extern const char* const kKeyValueSortSymbolName = "__xla_cpu_runtime_KeyValueSort"; +extern const char* const kTracingStartSymbolName = + "__xla_cpu_runtime_TracingStart"; +extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; + } // namespace runtime } // namespace cpu } // namespace xla @@ -104,6 +109,24 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace +extern "C" { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY xla::int64 __xla_cpu_runtime_TracingStart( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + const char* name) { + VLOG(3) << "TracingStart " << name; + return tensorflow::profiler::TraceMe::ActivityStart(name); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + xla::int64 id) { + VLOG(3) << "TracingEnd " << id; + tensorflow::profiler::TraceMe::ActivityEnd(id); +} + +} // extern "C" + TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 3a2b44d8c1a..684cb92c217 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -66,6 +66,9 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; extern const char* const 
kKeyValueSortSymbolName; +extern const char* const kTracingStartSymbolName; +extern const char* const kTracingEndSymbolName; + // All symbol names for XLA CPU runtime functions need to start with this // prefix. extern const char* const kXlaCpuRuntimeSymbolNamePrefix; @@ -80,6 +83,13 @@ XfeedManager* GetXfeedManager(int device_ordinal); extern "C" { +extern xla::int64 __xla_cpu_runtime_TracingStart( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + const char* name); +extern void __xla_cpu_runtime_TracingEnd( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + xla::int64 id); + // Some things common to all of the runtime entry points below: // // * The shape pointer and shape_length reflect values that can be deserialized diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 4e8c9867830..7dabe28c2af 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -101,8 +100,7 @@ std::unique_ptr> EigenMatrixMultiply(const Array2D& a, } else { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", 2); - tensorflow::EigenThreadPoolWrapper tp(&pool); - Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads()); ExecutableRunOptions run_options; run_options.set_intra_op_thread_pool(&device); diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h deleted file mode 100644 index 664125ecc95..00000000000 --- a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ - -// This file is depended on by kernels that have to build for mobile devices. -// For this reason, we avoid relying on TensorFlow and instead only use the -// standard C++ library. - -#include // NOLINT -#include -#include - -namespace xla { -namespace cpu { - -// The CPU JIT compiler uses this registry to resolve symbolic CustomCall -// targets; so when using the CPU JIT, CustomCall targets need to be registered -// here with the symbol name used in the CustomCall. 
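The __xla_cpu_runtime_TracingStart / __xla_cpu_runtime_TracingEnd entry points added above bracket a traced region with an activity id returned by the start call. A small sketch of the same start/end pairing behind an RAII guard; TraceStart and TraceEnd here are stand-ins, not the TensorFlow profiler API:

```c++
#include <cstdint>
#include <cstdio>

// Stand-ins for a profiler backend: the start call returns an activity id and
// the end call closes the interval for that id.
int64_t TraceStart(const char* name) {
  static int64_t next_id = 0;
  std::printf("start %s\n", name);
  return ++next_id;
}

void TraceEnd(int64_t id) {
  std::printf("end activity %lld\n", static_cast<long long>(id));
}

// RAII guard so the end call is never skipped on early returns.
class ScopedTrace {
 public:
  explicit ScopedTrace(const char* name) : id_(TraceStart(name)) {}
  ~ScopedTrace() { TraceEnd(id_); }
  ScopedTrace(const ScopedTrace&) = delete;
  ScopedTrace& operator=(const ScopedTrace&) = delete;

 private:
  int64_t id_;
};

int main() {
  ScopedTrace trace("compiled_computation");
  std::printf("...traced work...\n");
}
```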
-// -// The XLA AOT compiler links using a standard offline linker; so when compiling -// in AOT mode, you *also* need to make sure the name of the callee (presumably -// implemented in C++) matches up with the symbolic name used in the CustomCall. -// -// We maintain the registry in both the JIT and the AOT cases for simplicity, -// but we only use it when running in JIT mode. -class CustomCallTargetRegistry { - public: - static CustomCallTargetRegistry* Global(); - - void Register(const std::string& symbol, void* address); - void* Lookup(const std::string& symbol) const; - - private: - std::unordered_map registered_symbols_; - mutable std::mutex mu_; -}; - -class RegisterCustomCallTarget { - public: - explicit RegisterCustomCallTarget(const std::string& name, void* address) { - CustomCallTargetRegistry::Global()->Register(name, address); - } -}; - -#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b - -#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \ - static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \ - custom_call_target_register, counter)(symbol, \ - reinterpret_cast(address)) - -#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ - REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__) - -#define REGISTER_CUSTOM_CALL_TARGET(function) \ - REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function) - -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index c3c6847b7b7..e95f29fc889 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -89,13 +89,14 @@ StatusOr Disassembler::DisassembleObjectFile( }); // Construct ArrayRef pointing to section contents. - llvm::StringRef section_content_string; - if (section.getContents(section_content_string)) { + llvm::Expected section_content_string = + section.getContents(); + if (!section_content_string) { continue; } llvm::ArrayRef section_content_bytes( - reinterpret_cast(section_content_string.data()), - section_content_string.size()); + reinterpret_cast(section_content_string->data()), + section_content_string->size()); // Use int types from LLVM (eg, uint64_t) for values passed to and returned // from the LLVM API. 
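The custom_call_target_registry.h header deleted above is a compact example of the name-to-address registry pattern with static registration objects. A condensed, standard-library-only sketch of that pattern (simplified from the removed header; SymbolRegistry and MyCustomCall are illustrative names):

```c++
#include <cstdio>
#include <mutex>
#include <string>
#include <unordered_map>

class SymbolRegistry {
 public:
  static SymbolRegistry* Global() {
    static SymbolRegistry* registry = new SymbolRegistry;
    return registry;
  }
  void Register(const std::string& symbol, void* address) {
    std::lock_guard<std::mutex> lock(mu_);
    symbols_[symbol] = address;
  }
  void* Lookup(const std::string& symbol) const {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = symbols_.find(symbol);
    return it == symbols_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, void*> symbols_;
  mutable std::mutex mu_;
};

// Constructing one of these at namespace scope registers the symbol before
// main() runs, which is what the REGISTER_* macros expand to.
struct RegisterSymbol {
  RegisterSymbol(const std::string& name, void* address) {
    SymbolRegistry::Global()->Register(name, address);
  }
};

void MyCustomCall() { std::printf("custom call invoked\n"); }
static RegisterSymbol my_custom_call_registration(
    "MyCustomCall", reinterpret_cast<void*>(&MyCustomCall));

int main() {
  auto* fn = reinterpret_cast<void (*)()>(
      SymbolRegistry::Global()->Lookup("MyCustomCall"));
  if (fn != nullptr) fn();
}
```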
These values map to different types in LLVM and diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 2bf22ec6e43..4645f73f1e5 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -250,11 +250,6 @@ void DotOpEmitter::EmitTiledLlvmIrGemm() { std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - const bool enable_fast_math = - hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); - const bool optimize_for_size = - options::OptimizeForSizeRequested(hlo_module_config_); - EmitSmallGemm( /*scalar_type=*/primitive_type, /*m=*/m, /*k=*/k, /*n=*/n, @@ -262,9 +257,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemm() { /*max_vector_count=*/tile_size_n_in_vector_width, /*min_vectorization_width=*/std::min(4, max_target_vector_width), /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs, - /*rhs=*/rhs, /*result=*/target, b_, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size); + /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_); } void DotOpEmitter::EmitTiledLlvmIrGemv() { @@ -323,11 +316,6 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() { llvm::Value* rhs_op = swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); - const bool enable_fast_math = - hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); - const bool optimize_for_size = - options::OptimizeForSizeRequested(hlo_module_config_); - const int target_vector_register_element_size = target_machine_features_.vector_register_num_elements( *b_->GetInsertBlock()->getParent(), primitive_type); @@ -349,9 +337,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() { /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, - /*result=*/result_op, b_, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size); + /*result=*/result_op, b_, hlo_module_config_); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; @@ -361,9 +347,7 @@ void DotOpEmitter::EmitTiledLlvmIrGemv() { /*tile_cols=*/vector_register_element_size, /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, - /*result=*/result_op, b_, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size); + /*result=*/result_op, b_, hlo_module_config_); } } @@ -445,10 +429,12 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_); - llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( - lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); - llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( - rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); + std::vector lhs_multi_index = + loop_nest.EmitOperandArrayLoopNest( + lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + std::vector rhs_multi_index = + loop_nest.EmitOperandArrayLoopNest( + rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); // Create the loop which does the sum of products reduction. 
// @@ -468,8 +454,12 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. - lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); - rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape, + b_->getInt64Ty()); + rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape, + b_->getInt64Ty()); // For computing the sum of products we alloca a single location to store the // dot product result as we accumulate it within the reduction loop. After the @@ -532,18 +522,20 @@ void DotOpEmitter::EmitNaiveLlvmIrGemm() { // the rhs and lhs indexes with the reduction dimensions removed. The terms // from the rhs index are the lower dimensions in the index so we add them // first. - llvm_ir::IrArray::Index target_index(lhs_index.GetType()); + std::vector target_multi_index; for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { - target_index.push_back(lhs_index[dimension]); + target_multi_index.push_back(lhs_index[dimension]); } } for (int dimension = 0; dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { - target_index.push_back(rhs_index[dimension]); + target_multi_index.push_back(rhs_index[dimension]); } } + llvm_ir::IrArray::Index target_index( + target_multi_index, target_array_.GetShape(), lhs_index.GetType()); target_array_.EmitWriteArrayElement(target_index, result, b_); // Set the IR builder insert point to the exit basic block of the outer most @@ -720,8 +712,7 @@ absl::optional ProfitableToMakeDotOperandColumnMajor( return {}; } - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) { + if (hlo.IsOutputFusion()) { auto* fusion_root = hlo.fused_instructions_computation()->root_instruction(); if (fusion_root->opcode() != HloOpcode::kAdd) { @@ -921,11 +912,11 @@ llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array, llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); Shape inner_shape = DropFirstDim(outer_array.GetShape()); - llvm_ir::IrArray::Index slice_index(b->getInt64Ty()); - slice_index.push_back(batch_index); - slice_index.InsertAt( - /*index=*/1, outer_array.GetShape().dimensions_size() - 1, - b->getInt64(0)); + std::vector multidim_index(inner_shape.rank() + 1, + b->getInt64(0)); + multidim_index[0] = batch_index; + llvm_ir::IrArray::Index slice_index(multidim_index, outer_array.GetShape(), + batch_index->getType()); llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b); llvm::Type* slice_ptr_type = llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo(); @@ -1016,11 +1007,8 @@ bool DotImplementationCanHandleTranspose( GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), DotInfo(dot_instr), target_machine_features); - // TODO(sanjoy): This is not quite right, it should be `impl_strategy == - // kEigen || impl_strategy == kTiledLlvmIrGemv || impl_strategy == - // kNaiveLlvmIr` but I'll fix this in a later CL in the interest of keeping - // the CL adding this comment NFC. 
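EmitNaiveLlvmIrGemm above emits a loop nest over the non-reduction dimensions with an innermost sum-of-products loop over the shared reduction dimension, accumulating into a single location before writing the target element. The emitted IR corresponds to an ordinary triple loop like the following sketch (plain C++, not IR emission):

```c++
#include <cstdio>
#include <vector>

// result[i][j] = sum over r of lhs[i][r] * rhs[r][j]
std::vector<float> NaiveGemm(const std::vector<float>& lhs,
                             const std::vector<float>& rhs, int m, int k, int n) {
  std::vector<float> result(m * n, 0.0f);
  for (int i = 0; i < m; ++i) {      // LHS non-reduction dimension
    for (int j = 0; j < n; ++j) {    // RHS non-reduction dimension
      float accumulator = 0.0f;      // single accumulation slot
      for (int r = 0; r < k; ++r) {  // reduction dimension shared by both operands
        accumulator += lhs[i * k + r] * rhs[r * n + j];
      }
      result[i * n + j] = accumulator;  // write the target element once
    }
  }
  return result;
}

int main() {
  std::vector<float> lhs = {1, 2, 3, 4};  // 2x2, row major
  std::vector<float> rhs = {5, 6, 7, 8};  // 2x2, row major
  std::vector<float> out = NaiveGemm(lhs, rhs, 2, 2, 2);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 19 22 43 50
}
```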
- return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr || + impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv || impl_strategy == DotImplementationStrategy::kEigen; } diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index 368ae8bffc5..e21ca01c803 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -55,8 +56,8 @@ StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, // Create a function declaration. llvm::Function* function = llvm::dyn_cast( module_ - ->getOrInsertFunction(llvm_ir::AsStringRef(function_name), - lhs->getType(), lhs->getType(), rhs->getType()) + ->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(), + rhs->getType()) .getCallee()); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); @@ -90,8 +91,8 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, // Create a function declaration. llvm::Function* function = llvm::dyn_cast( module_ - ->getOrInsertFunction(llvm_ir::AsStringRef(function_name), - value->getType(), value->getType()) + ->getOrInsertFunction(function_name, value->getType(), + value->getType()) .getCallee()); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); @@ -114,8 +115,7 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( std::vector operands; for (int i = 0; i < hlo->operand_count(); i++) { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(i))( - ElementwiseSourceIndex(index, *hlo, i))); + operand_to_generator.at(hlo->operand(i))(index)); operands.push_back(operand_value); } return ir_emitter_->EmitElementalMap(*Cast(hlo), @@ -136,10 +136,19 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( }; case HloOpcode::kReduce: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + auto reduce_instr = Cast(hlo); + std::vector input_generators; + for (const HloInstruction* instr : reduce_instr->inputs()) { + input_generators.push_back(operand_to_generator.at(instr)); + } + + std::vector initial_value_generators; + for (const HloInstruction* instr : reduce_instr->init_values()) { + initial_value_generators.push_back(operand_to_generator.at(instr)); + } return ir_emitter_->EmitElementalReduce( - Cast(hlo), - operand_to_generator.at(hlo->operand(0)), - operand_to_generator.at(hlo->operand(1)), index); + reduce_instr, std::move(input_generators), + std::move(initial_value_generators), index); }; default: return ElementalIrEmitter::MakeElementGenerator(hlo, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 43e4bb43933..06ea62d552c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -58,8 +58,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -74,7 +76,6 @@ limitations under the License. namespace xla { namespace { -using llvm_ir::AsStringRef; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; } // namespace @@ -99,9 +100,7 @@ IrEmitter::IrEmitter( is_top_level_computation_(false), target_machine_features_(*target_machine_features), emit_code_for_msan_(emit_code_for_msan) { - b_.setFastMathFlags(llvm_ir::GetFastMathFlags( - /*fast_math_enabled=*/hlo_module_config_.debug_options() - .xla_cpu_enable_fast_math())); + b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_)); Status s = GatherComputationsByAllocationType( &hlo_module, &thread_local_computations_, &global_computations_); absl::c_sort(thread_local_computations_); @@ -109,6 +108,42 @@ IrEmitter::IrEmitter( TF_CHECK_OK(s) << "Should have failed buffer assignment."; } +void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) { + llvm::Argument* out_parameter = compute_function_->result_arg(); + llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction()); + const Shape& return_shape = computation->root_instruction()->shape(); + + if (ShapeUtil::IsScalar(return_shape)) { + llvm::Value* ret_value = + Load(root_value.GetBasePointer(), "load_ret_value"); + Store(ret_value, + BitCast(out_parameter, root_value.GetBasePointer()->getType())); + } else { + CHECK(return_shape.IsTuple()); + + llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_); + llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo(); + llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue); + + for (int i = 0; i < return_shape.tuple_shapes_size(); i++) { + const Shape& element_shape = return_shape.tuple_shapes(i); + llvm::Value* destination = llvm_ir::EmitGetTupleElement( + element_shape, + /*index=*/i, + /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue, + &b_); + + llvm::Value* source = llvm_ir::EmitGetTupleElement( + element_shape, + /*index=*/i, + /*alignment=*/MinimumAlignmentForShape(element_shape), + root_value.GetBasePointer(), &b_); + + Store(Load(source), destination); + } + } +} + StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, @@ -141,11 +176,28 @@ StatusOr IrEmitter::EmitComputation( bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(use_rdtscp); + + bool emit_tracing = + hlo_module_config_.hlo_profiling_enabled() && + hlo_module_config_.debug_options().xla_backend_extra_options().count( + "xla_hlo_trace"); + tracing_state_.set_enabled(emit_tracing); + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); llvm::Function* ir_function = compute_function_->function(); InsertOrDie(&emitted_functions_, computation, ir_function); // Delete 'compute_function', 
finalizing 'ir_function' and restoring caller // IR insert point. + + // Function epilogue: copying the value over to either the return register, + // or values pointing from the return register. + const BufferAllocation* root_allocation = + computation_root_allocation_.allocation(); + if (root_allocation && root_allocation->is_thread_local()) { + EmitThreadLocalFunctionEpilogue(computation); + } + + // Destructor for compute_function_ emits the "ret void" instruction. compute_function_.reset(); computation_root_allocation_ = BufferAllocation::Slice(); computation_parameter_allocations_.clear(); @@ -160,11 +212,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; // Create and initialize new IrFunction. - compute_function_.reset(new IrFunction( - function_name, linkage, - options::OptimizeForSizeRequested(hlo_module_config_), - hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_, - &b_, num_dynamic_loop_bounds_)); + compute_function_.reset(new IrFunction(function_name, linkage, + hlo_module_config_, module_, &b_, + num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -173,8 +223,7 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = BitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast)); return Status::OK(); } @@ -304,7 +353,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Shape& shape = get_tuple_element->shape(); emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), - GetEmittedValueFor(operand), &b_, module_); + GetEmittedValueFor(operand), &b_); return Status::OK(); } @@ -324,7 +373,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select)); llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), - GetEmittedValueFor(on_false), &b_, module_); + GetEmittedValueFor(on_false), &b_); return Status::OK(); } @@ -347,8 +396,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { assignment_.GetUniqueSlice(infeed, {1})); llvm::Value* token_address = EmitBufferPointer( token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); - llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, - module_); + llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_); if (data_shape.IsTuple()) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); @@ -379,7 +427,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { } llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape), - tuple_element_addresses, &b_, module_); + tuple_element_addresses, &b_); } else { TF_RETURN_IF_ERROR( EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address)); @@ -500,7 +548,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { ShapeUtil::GetTupleElementShape(operand_shape, i); llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), - value, &b_, module_); + value, &b_); TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, tuple_element_shape, 
tuple_element)); } @@ -623,8 +671,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { GetProfileCountersArgument(), less_than_function}); if (sort->values_count() > 0) { - llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, - module_); + llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_); } return Status::OK(); } @@ -635,14 +682,15 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { for (auto operand : tuple->operands()) { base_ptrs.push_back(GetEmittedValueFor(operand)); } - llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_); return Status::OK(); } llvm::Value* IrEmitter::EmitElementalMap( const HloMapInstruction& map_instr, absl::Span elemental_operands, absl::string_view name) { - return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); + return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(), + elemental_operands, name); } StatusOr IrEmitter::EmitElementalReduceWindow( @@ -673,21 +721,22 @@ StatusOr IrEmitter::EmitElementalReduceWindow( SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), index.size()); + std::vector input_multi_index(index.size()); llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = NSWSub( + input_multi_index[i] = NSWSub( NSWAdd(strided_index, NSWMul(window_index[i], b_.getInt64(window.dimensions(i).window_dilation()))), b_.getInt64(window.dimensions(i).padding_low())); // We need to verify that we are not in the dilated base area. - llvm::Value* dilation_condition = ICmpEQ( - SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())), - b_.getInt64(0)); + llvm::Value* dilation_condition = + ICmpEQ(SRem(input_multi_index[i], + b_.getInt64(window.dimensions(i).base_dilation())), + b_.getInt64(0)); if (in_bounds_condition == nullptr) { in_bounds_condition = dilation_condition; } else { @@ -695,15 +744,16 @@ StatusOr IrEmitter::EmitElementalReduceWindow( } // Apply base dilation to the index. - input_index[i] = - SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())); + input_multi_index[i] = + SDiv(input_multi_index[i], + b_.getInt64(window.dimensions(i).base_dilation())); - // We need to check if 0 <= input_index[i] < bound, as otherwise we are in - // the padding so that we can skip the computation. That is equivalent to - // input_index[i] < bound as an *unsigned* comparison, since a negative - // value will wrap to a large positive value. + // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we + // are in the padding so that we can skip the computation. That is + // equivalent to input_multi_index[i] < bound as an *unsigned* comparison, + // since a negative value will wrap to a large positive value. llvm::Value* index_condition = - ICmpULT(input_index[i], + ICmpULT(input_multi_index[i], b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); if (in_bounds_condition == nullptr) { in_bounds_condition = index_condition; @@ -718,9 +768,11 @@ StatusOr IrEmitter::EmitElementalReduceWindow( SetToFirstInsertPoint(if_data.true_block, &b_); // We are not in the padding, so carry out the computation. 
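The comment above folds the check 0 <= index < bound into a single unsigned comparison, relying on a negative signed index wrapping to a huge unsigned value. A tiny sketch of that trick:

```c++
#include <cstdint>
#include <cstdio>

// Equivalent to (index >= 0 && index < bound) for non-negative bounds: a
// negative index wraps to a huge unsigned value and fails the comparison.
bool InBounds(int64_t index, int64_t bound) {
  return static_cast<uint64_t>(index) < static_cast<uint64_t>(bound);
}

int main() {
  std::printf("%d %d %d\n", InBounds(3, 8), InBounds(-1, 8), InBounds(8, 8));  // 1 0 0
}
```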
+ llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(), + b_.getInt64Ty()); TF_ASSIGN_OR_RETURN(llvm::Value* const input_value, input_generator(input_index)); - llvm::Value* result = EmitThreadLocalCall( + llvm::Value* result = EmitScalarReturningThreadLocalCall( *reduce_window->to_apply(), {Load(accumulator_address), input_value}, "reducer_function"); Store(result, accumulator_address); @@ -823,15 +875,16 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); + std::vector operand_multi_index(source_index.size()); llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride())); - operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + operand_multi_index[i] = + NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); llvm::Value* index_condition = - ICmpULT(operand_index[i], + ICmpULT(operand_multi_index[i], b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); in_bounds_condition = And(in_bounds_condition, index_condition); } @@ -857,6 +910,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); + llvm_ir::IrArray::Index operand_index( + operand_multi_index, operand_array.GetShape(), b_.getInt64Ty()); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); Store(operand_data, selected_value_address); @@ -869,7 +924,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); llvm::Value* operand_element = Load(operand_address); - llvm::Value* result = EmitThreadLocalCall( + llvm::Value* result = EmitScalarReturningThreadLocalCall( *select_and_scatter->select(), {Load(selected_value_address), operand_element}, "select_function"); @@ -890,21 +945,23 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // location is computed by calling the `scatter` function with the source // value and the current output value. 
SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_); - llvm_ir::IrArray::Index selected_index(source_index.GetType()); + std::vector selected_multi_index; for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = InBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(Load(selected_index_address_slot)); + selected_multi_index.push_back(Load(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value = source_array.EmitReadArrayElement(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); + llvm_ir::IrArray::Index selected_index( + selected_multi_index, output_array.GetShape(), source_index.GetType()); llvm::Value* output_value = output_array.EmitReadArrayElement(selected_index, &b_); - llvm::Value* scatter_value = - EmitThreadLocalCall(*select_and_scatter->scatter(), - {output_value, source_value}, "scatter_function"); + llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall( + *select_and_scatter->scatter(), {output_value, source_value}, + "scatter_function"); output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_); SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_); @@ -1054,27 +1111,31 @@ StatusOr IrEmitter::EmitElementalConvolution( // We are not in the padding, so carry out the computation. int num_dims = num_spatial_dims + 2; - llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), num_dims); + std::vector input_multi_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; + input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; } - input_index[dnums.input_feature_dimension()] = input_feature; - input_index[dnums.input_batch_dimension()] = batch; + input_multi_index[dnums.input_feature_dimension()] = input_feature; + input_multi_index[dnums.input_batch_dimension()] = batch; - llvm_ir::IrArray::Index kernel_index(b_.getInt64Ty(), num_dims); + std::vector kernel_multi_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - kernel_index[dnums.kernel_spatial_dimensions(i)] = + kernel_multi_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() ? 
NSWSub(b_.getInt64(window.dimensions(i).size() - 1), kernel_spatial[i]) : kernel_spatial[i]; } - kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; - kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; + kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature; + kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature; + llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(), + b_.getInt64Ty()); TF_ASSIGN_OR_RETURN(llvm::Value* const input_value, input_generator(input_index)); + llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(), + b_.getInt64Ty()); TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value, kernel_generator(kernel_index)); llvm::Value* product = FMul(input_value, kernel_value); @@ -1338,7 +1399,7 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) { MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } - llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_); return Status::OK(); } @@ -1587,22 +1648,23 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), &b_); - llvm_ir::IrArray::Index reduced_dims_index = + std::vector input_multi_index = reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_); llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); - llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = output_index.begin(); - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = *it++; + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; } } CHECK(output_index.end() == it); + llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + b_.getInt64Ty()); llvm::Value* input_address = BitCast( arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); @@ -1659,6 +1721,11 @@ StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, absl::Span dimensions, HloComputation* function, string* failure_reason) { + if (!reduce->shape().IsArray()) { + *failure_reason = "vectorization of variadic reduce not implemented"; + return false; + } + if (!ReductionPreservesLayout(*reduce)) { return false; } @@ -1714,8 +1781,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( // } llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_); - llvm_ir::IrArray::Index array_index(b_.getInt64Ty(), - reduce->shape().dimensions_size()); + std::vector array_multi_index( + reduce->shape().dimensions_size()); for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; --i) { int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); @@ -1723,7 +1790,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 end_index = reduce->shape().dimensions(dimension); std::unique_ptr loop = loop_nest.AddLoop( start_index, end_index, absl::StrFormat("dim.%d", dimension)); - array_index[dimension] = loop->GetIndVarValue(); + array_multi_index[dimension] = loop->GetIndVarValue(); } int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0); @@ -1744,12 +1811,14 @@ StatusOr IrEmitter::EmitVectorizedReduce( std::unique_ptr loop = loop_nest.AddLoop(start_index, end_index, vectorization_factor, 
absl::StrFormat("dim.%d", innermost_dimension)); - array_index[innermost_dimension] = loop->GetIndVarValue(); + array_multi_index[innermost_dimension] = loop->GetIndVarValue(); SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); ShardedVectorType vector_type = CreateShardedVectorType( reduce->shape().element_type(), vectorization_factor); + llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(), + b_.getInt64Ty()); TF_ASSIGN_OR_RETURN(std::vector accumulator, EmitInnerLoopForVectorizedReduction( reduction_generator, array_index, vector_type, @@ -1775,13 +1844,15 @@ StatusOr IrEmitter::EmitVectorizedReduce( // in the following case: if (innermost_dimension_size % vectorization_factor) { // TODO(b/63775531): Consider using a scalar loop here to save on code size. - array_index[innermost_dimension] = + array_multi_index[innermost_dimension] = b_.getInt64(innermost_dimension_size - (innermost_dimension_size % vectorization_factor)); ShardedVectorType vector_type = CreateShardedVectorType( reduce->shape().element_type(), innermost_dimension_size % vectorization_factor); + llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(), + b_.getInt64Ty()); TF_ASSIGN_OR_RETURN(std::vector accumulator, EmitInnerLoopForVectorizedReduction( reduction_generator, array_index, vector_type, @@ -1803,21 +1874,39 @@ StatusOr IrEmitter::EmitVectorizedReduce( StatusOr IrEmitter::EmitElementalReduce( const HloReduceInstruction* reduce, - const llvm_ir::ElementGenerator& input_generator, - const llvm_ir::ElementGenerator& initial_value_generator, + std::vector input_generators, + std::vector initial_value_generators, const llvm_ir::IrArray::Index& index) { - const HloInstruction* arg = reduce->operand(0); - absl::Span dimensions(reduce->dimensions()); + const Shape& out_shape = reduce->shape(); + bool is_variadic = !out_shape.IsArray(); + int accumulators_count = 1; + if (is_variadic) { + CHECK(out_shape.IsTuple()); + accumulators_count = out_shape.tuple_shapes_size(); + } - // Initialize an accumulator with init_value. - PrimitiveType accumulator_type = reduce->shape().element_type(); - llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", - &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_value, - initial_value_generator(llvm_ir::IrArray::Index(index.GetType()))); - Store(init_value, accumulator_addr); + absl::Span reduced_dimensions(reduce->dimensions()); + + std::vector accumulator_addrs; + std::vector accumulator_types; + for (int i = 0; i < accumulators_count; i++) { + const Shape& element_shape = + is_variadic ? out_shape.tuple_shapes(i) : out_shape; + PrimitiveType accumulator_type = element_shape.element_type(); + llvm::Type* accumulator_llvm_type = + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); + accumulator_types.push_back(accumulator_llvm_type); + + // Initialize an accumulator with init_value. + llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( + accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_, + MinimumAlignmentForPrimitiveType(accumulator_type)); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_value, + initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType()))); + Store(init_value, accumulator_addr); + accumulator_addrs.push_back(accumulator_addr); + } // The enclosing loops go over all the target elements. 
Now we have to compute // the actual target element. For this, we build a new loop nest to iterate @@ -1825,44 +1914,67 @@ StatusOr IrEmitter::EmitElementalReduce( // AddLoopsForShapeOnDimensions will return an Index where induction Value*s // are placed for each dimension in dimensions, and all the rest are nullptrs. llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - const llvm_ir::IrArray::Index reduced_dims_index = - loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, + const HloInstruction* arg = reduce->operand(0); + std::vector input_multi_index = + loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions, "reduction_dim"); SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - // Build a full index for the input argument, using reduced_dims_index as the - // base. In reduced_dims_index only the reduction dimensions are filled in. We + // Build a full index for the input argument, using input_multi_index as the + // base. In input_multi_index only the reduction dimensions are filled in. We // fill in the rest of the dimensions with induction Value*s taken from // 'index' which iterates over the target array. See the high-level // description in the XLA documentation for details. - llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = *it++; + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; } } CHECK(index.end() == it); + llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + b_.getInt64Ty()); - // Apply the reduction function to the loaded value. - TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, - input_generator(input_index)); - llvm::Value* result = EmitThreadLocalCall( - *reduce->to_apply(), {Load(accumulator_addr), input_element}, - "reduce_function"); - Store(result, accumulator_addr); + std::vector reduction_operands; + for (llvm::Value* accum : accumulator_addrs) { + llvm::Value* accum_value = Load(accum); + reduction_operands.push_back(accum_value); + } + for (int i = 0; i < accumulators_count; i++) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, + input_generators[i](input_index)); + reduction_operands.push_back(input_element); + } + + std::vector results = EmitThreadLocalCall( + *reduce->to_apply(), reduction_operands, "reduce_function"); + + CHECK(results.size() == accumulators_count); + for (int i = 0; i < accumulators_count; i++) { + Store(results[i], accumulator_addrs[i]); + } SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return Load(accumulator_addr); + + if (is_variadic) { + // Emit a structure, as that what the LoopEmitter expects. + llvm::Value* returned_structure = llvm::UndefValue::get( + llvm::StructType::get(b_.getContext(), accumulator_types)); + for (int i = 0; i < accumulators_count; i++) { + llvm::Value* accumulator_value = Load(accumulator_addrs[i]); + returned_structure = + b_.CreateInsertValue(returned_structure, accumulator_value, i); + } + return returned_structure; + } else { + CHECK_EQ(accumulator_addrs.size(), 1); + return Load(accumulator_addrs[0]); + } } Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/118333695): Support variadic reduce. 
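The variadic path added to EmitElementalReduce above keeps one accumulator per result and feeds every accumulator plus every current element to a single reducer call, then stores the results back. A standalone C++ sketch of those semantics, using an argmax-style reduction over a value array and an index array (all names here are illustrative, not XLA APIs):

#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

// Reducer with two accumulators and two inputs, producing two results: it
// keeps the larger value together with the position where it occurred.
static std::pair<float, int64_t> MaxAndArgMax(float acc_val, int64_t acc_idx,
                                              float elem_val, int64_t elem_idx) {
  return elem_val > acc_val ? std::make_pair(elem_val, elem_idx)
                            : std::make_pair(acc_val, acc_idx);
}

// Variadic reduce over two parallel operands with two init values: one
// accumulator per operand, one reducer call per element, one result per
// accumulator (the analogue of the tuple-shaped reduce output).
static std::pair<float, int64_t> ReduceValueAndIndex(
    const std::vector<float>& values, const std::vector<int64_t>& indices,
    float init_value, int64_t init_index) {
  float acc0 = init_value;
  int64_t acc1 = init_index;
  for (size_t i = 0; i < values.size(); ++i) {
    std::tie(acc0, acc1) = MaxAndArgMax(acc0, acc1, values[i], indices[i]);
  }
  return {acc0, acc1};
}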
- if (!reduce->shape().IsArray()) { - return Unimplemented("Variadic reduce is not supported on CPU"); - } auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); absl::Span dimensions(reduce->dimensions()); @@ -1994,15 +2106,17 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { const int64 num_outer_loops = outer_dims.size(); llvm_ir::ForLoopNest loops(IrName(slice), &b_); - llvm_ir::IrArray::Index target_index = + std::vector target_multi_index = loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice"); // Only the indices for the outer dimensions have been initialized in // target_index. The rest of the indices should get initialized to 0, since // for the rest of the dimensions the copy writes to the full dimension. - std::replace(target_index.begin(), target_index.end(), + std::replace(target_multi_index.begin(), target_multi_index.end(), static_cast(nullptr), static_cast(b_.getInt64(0))); + llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(), + b_.getInt64Ty()); if (num_outer_loops > 0) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); @@ -2010,7 +2124,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { llvm_ir::IrArray source_array = GetIrArrayFor(operand); const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( - /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), + /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(), /*strides=*/slice->slice_strides(), /*builder=*/&b_); llvm::Value* memcpy_dest = @@ -2113,18 +2227,20 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { // Compute the output index the operand element should be assigned to. // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); - llvm_ir::IrArray::Index output_index(operand_index.GetType()); + std::vector output_multi_index; for (size_t i = 0; i < operand_index.size(); ++i) { llvm::Value* offset = Mul(operand_index[i], b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); llvm::Value* index = Add( offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); - output_index.push_back(index); + output_multi_index.push_back(index); } // Store the operand element to the computed output location. 
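The index arithmetic emitted above is the standard HLO pad mapping, output_index = edge_padding_low + operand_index * (interior_padding + 1). A 1-D scalar reference sketch (the helper name and the use of std::vector are illustrative only):

#include <cstdint>
#include <vector>

std::vector<float> PadReference(const std::vector<float>& operand,
                                float pad_value, int64_t edge_low,
                                int64_t edge_high, int64_t interior) {
  const int64_t n = static_cast<int64_t>(operand.size());
  const int64_t interior_total = n > 0 ? (n - 1) * interior : 0;
  std::vector<float> out(edge_low + n + interior_total + edge_high, pad_value);
  for (int64_t i = 0; i < n; ++i) {
    // Same mapping as the emitted loop: skip the low edge padding, then
    // stride by interior_padding + 1.
    out[edge_low + i * (interior + 1)] = operand[i];
  }
  return out;
}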
llvm_ir::IrArray output_array(GetIrArrayFor(pad)); + llvm_ir::IrArray::Index output_index( + output_multi_index, output_array.GetShape(), operand_index.GetType()); output_array.EmitWriteArrayElement(output_index, operand_data, &b_); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); @@ -2141,7 +2257,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion), &elemental_emitter, &b_); - } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { + } else if (fusion->IsLoopFusion()) { VLOG(3) << "HandleFusion kLoop"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); auto operands = GetIrArraysForOperandsOf(fusion); @@ -2150,7 +2266,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); - } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) { + } else if (fusion->IsOutputFusion()) { VLOG(3) << "HandleFusion kOutput"; int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1; const HloInstruction* dot = root->operand(dot_op_index); @@ -2213,7 +2329,6 @@ Status IrEmitter::HandleCall(HloInstruction* call) { Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { absl::Span operands(custom_call->operands()); - absl::string_view custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -2248,7 +2363,7 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { auto* custom_call_ir_function = llvm::dyn_cast( module_ ->getOrInsertFunction( - AsStringRef(custom_call_target), + custom_call->custom_call_target(), llvm::FunctionType::get( /*Result=*/b_.getVoidTy(), /*Params=*/{i8_ptr_type, operands_alloca->getType()}, @@ -2269,7 +2384,7 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { llvm::Value* addr = EmitBufferPointer(slice, elem_shape); base_ptrs.push_back(addr); } - llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_); } auto* output_address_arg = PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); @@ -2331,7 +2446,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( - module_->getContext(), AsStringRef(IrName(xla_while, "header")), + module_->getContext(), IrName(xla_while, "header"), compute_function_->function()); Br(header_bb); b_.SetInsertPoint(header_bb); @@ -2344,11 +2459,11 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. 
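The while lowering here uses the conventional header/body/exit block layout; at the source level the emitted control flow has the same shape as this small sketch (State, condition and body are placeholders for the condition and body computations):

struct State { /* loop-carried values */ };

State RunWhile(State state, bool (*condition)(const State&),
               State (*body)(const State&)) {
  while (condition(state)) {  // "header": load the predicate and branch on it
    state = body(state);      // "body": run the body, then jump back to header
  }
  return state;               // "exit": fall through once the predicate is false
}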
- llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( - module_->getContext(), AsStringRef(IrName(xla_while, "body")), - compute_function_->function()); + llvm::BasicBlock* body_bb = + llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"), + compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( - module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); + module_->getContext(), IrName(xla_while, "exit")); CondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. @@ -2403,11 +2518,13 @@ StatusOr IrEmitter::EmitFastConcatenate( llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); llvm_ir::ForLoopNest loops(IrName(concatenate), &b_); - llvm_ir::IrArray::Index outer_dims_index = + std::vector target_multi_index = loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat"); - std::replace(outer_dims_index.begin(), outer_dims_index.end(), + std::replace(target_multi_index.begin(), target_multi_index.end(), static_cast(nullptr), static_cast(b_.getInt64(0))); + llvm_ir::IrArray::Index target_index(target_multi_index, output_shape, + b_.getInt64Ty()); if (!outer_dims.empty()) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); @@ -2419,10 +2536,9 @@ StatusOr IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. - llvm::Value* target_region_begin = - BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_, - "target_region"), - i8_ptr_type); + llvm::Value* target_region_begin = BitCast( + target_array.EmitArrayElementAddress(target_index, &b_, "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2436,8 +2552,10 @@ StatusOr IrEmitter::EmitFastConcatenate( for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); + llvm_ir::IrArray::Index source_index(target_multi_index, operand->shape(), + b_.getInt64Ty()); llvm::Value* copy_source_address = BitCast( - source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), + source_array.EmitArrayElementAddress(source_index, &b_, "src_addr"), i8_ptr_type); llvm::Value* copy_target_address = @@ -2691,7 +2809,19 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx), - AsStringRef(counter_name)); + counter_name); +} + +llvm::Value* IrEmitter::GetProfileCounterFor( + const HloInstruction& instruction) { + return GetProfileCounterCommon(instruction, + instruction_to_profile_idx_); +} + +llvm::Value* IrEmitter::GetProfileCounterFor( + const HloComputation& computation) { + return GetProfileCounterCommon(computation, + computation_to_profile_idx_); } void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, @@ -2735,7 +2865,7 @@ llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) { void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo) { auto* cycle_start = ReadCycleCounter(b); - cycle_start->setName(AsStringRef(IrName(hlo, "cycle_start"))); + cycle_start->setName(IrName(hlo, "cycle_start")); cycle_starts_[hlo] = cycle_start; if (first_read_cycle_start_ == nullptr) { first_read_cycle_start_ = cycle_start; @@ 
-2746,7 +2876,7 @@ void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo, llvm::Value* prof_counter) { auto* cycle_end = ReadCycleCounter(b); - cycle_end->setName(AsStringRef(IrName(hlo, "cycle_end"))); + cycle_end->setName(IrName(hlo, "cycle_end")); auto* cycle_start = cycle_starts_[hlo]; UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start); last_read_cycle_end_ = cycle_end; @@ -2760,9 +2890,70 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( } } +void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b, + HloInstruction* hlo, + llvm::Value* run_options) { + if (!enabled_) { + return; + } + + llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo(); + llvm::Type* void_ptr_type = b->getVoidTy()->getPointerTo(); + llvm::FunctionType* fn_type = + llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type}, + /*isVarArg=*/false); + + llvm::Function* function = b->GetInsertBlock()->getParent(); + llvm::Module* module = function->getParent(); + const char* fn_name = runtime::kTracingStartSymbolName; + llvm::FunctionCallee trace_func = + module->getOrInsertFunction(fn_name, fn_type); + if (auto* fn = llvm::dyn_cast(trace_func.getCallee())) { + fn->setCallingConv(llvm::CallingConv::C); + fn->setDoesNotThrow(); + fn->setOnlyAccessesArgMemory(); + } + auto* hlo_name = b->CreateGlobalStringPtr(hlo->name()); + auto* activity_id = + b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type), + b->CreateBitCast(hlo_name, int8_ptr_type)}); + activity_id->setName(IrName(hlo, "activity_id")); + activity_ids_[hlo] = activity_id; +} + +void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b, + HloInstruction* hlo, + llvm::Value* run_options) { + if (!enabled_) { + return; + } + + llvm::Type* void_ptr_type = b->getVoidTy()->getPointerTo(); + llvm::FunctionType* fn_type = + llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()}, + /*isVarArg=*/false); + + llvm::Function* function = b->GetInsertBlock()->getParent(); + llvm::Module* module = function->getParent(); + const char* fn_name = runtime::kTracingEndSymbolName; + llvm::FunctionCallee trace_func = + module->getOrInsertFunction(fn_name, fn_type); + if (auto* fn = llvm::dyn_cast(trace_func.getCallee())) { + fn->setCallingConv(llvm::CallingConv::C); + fn->setDoesNotThrow(); + fn->setOnlyAccessesArgMemory(); + } + auto* activity_id = activity_ids_.at(hlo); + b->CreateCall(trace_func, + {b->CreateBitCast(run_options, void_ptr_type), activity_id}); +} + Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); if (instruction_to_profile_idx_.count(hlo)) { + // Only trace the same HLOs that the profiler does. + tracing_state_.EmitTracingStart(&b_, hlo, + GetExecutableRunOptionsArgument()); profiling_state_.RecordCycleStart(&b_, hlo); } return Status::OK(); @@ -2772,6 +2963,10 @@ Status IrEmitter::Postprocess(HloInstruction* hlo) { if (auto* prof_counter = GetProfileCounterFor(*hlo)) { profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter); } + // Only trace the same HLOs that the profiler does. 
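These Preprocess/Postprocess hooks bracket each profiled HLO with a tracing-start call that returns an activity id and a tracing-end call that consumes it (the runtime symbols are named by kTracingStartSymbolName and kTracingEndSymbolName). A host-side sketch of that pairing; the stub functions and RunProfiledHlo are hypothetical stand-ins, not the real runtime entry points:

#include <cstdint>

// Stubs mirroring the emitted call signatures: start takes the run options and
// the HLO name and returns an activity id; end takes the run options and that id.
static int64_t HypotheticalTracingStart(const void* /*run_options*/,
                                        const char* /*hlo_name*/) {
  return 0;  // a real implementation would hand back a fresh activity id
}
static void HypotheticalTracingEnd(const void* /*run_options*/,
                                   int64_t /*activity_id*/) {}

static void RunProfiledHlo(const void* run_options) {
  const int64_t activity_id =
      HypotheticalTracingStart(run_options, "convolution.3");
  // ... the code emitted for the HLO would execute here ...
  HypotheticalTracingEnd(run_options, activity_id);
}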
+ if (instruction_to_profile_idx_.count(hlo)) { + tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument()); + } return Status::OK(); } @@ -2821,15 +3016,6 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address = [&]() -> llvm::Value* { - if (slice == computation_root_allocation_) { - llvm::Argument* retval = compute_function_->result_arg(); - llvm::AttrBuilder attr_builder; - attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); - attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttrs(attr_builder); - return retval; - } - auto param_it = computation_parameter_allocations_.find(slice.allocation()->index()); if (param_it != computation_parameter_allocations_.end()) { @@ -2919,7 +3105,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueTopLevelSlice(op)); llvm::Value* addr = EmitBufferPointer(slice, target_shape); - addr->setName(AsStringRef(IrName(op))); + addr->setName(IrName(op)); emitted_value_[op] = addr; return Status::OK(); } @@ -2939,7 +3125,8 @@ Status IrEmitter::EmitTargetElementLoop( TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); llvm_ir::IrArray target_array = GetIrArrayFor(target_op); - if (target_op->IsMultiOutputFusion()) { + if (target_shape.IsTuple() && (target_op->IsMultiOutputFusion() || + target_op->opcode() == HloOpcode::kReduce)) { // For multiple outputs fusion, we need to emit each operand and the root. TF_RET_CHECK(num_dynamic_loop_bounds_ == 0); std::vector output_arrays; @@ -2959,7 +3146,7 @@ Status IrEmitter::EmitTargetElementLoop( for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_); } else { if (ShouldEmitParallelLoopFor(*target_op)) { @@ -3021,19 +3208,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } -llvm::Value* IrEmitter::EmitThreadLocalCall( +llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name) { + std::vector return_value = + EmitThreadLocalCall(callee, parameters, name); + CHECK_EQ(return_value.size(), 1); + return return_value[0]; +} + +std::vector IrEmitter::EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, absl::string_view name) { CHECK(absl::c_binary_search(thread_local_computations_, &callee)); - const Shape& return_shape = callee.root_instruction()->shape(); - - // Lifting this restriction to allow "small" arrays should be easy. Allowing - // larger arrays is difficult because we allocate the buffer for this return - // value on the stack. 
- CHECK(ShapeUtil::IsScalar(return_shape)); - - PrimitiveType return_type = return_shape.element_type(); + bool is_scalar_return = ShapeUtil::IsScalar(return_shape); + bool is_tuple_of_scalars_return = + return_shape.IsTuple() && + absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) { + return ShapeUtil::IsScalar(shape); + }); + CHECK(is_scalar_return || is_tuple_of_scalars_return); std::vector parameter_addrs; for (llvm::Value* parameter : parameters) { @@ -3044,10 +3239,30 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( parameter_addrs.push_back(parameter_addr); } + llvm::Type* return_value_buffer_type = + llvm_ir::ShapeToIrType(return_shape, module_); + std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr"); + int retval_alignment = + is_scalar_return + ? MinimumAlignmentForPrimitiveType(return_shape.element_type()) + : 0; llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(return_type, module_), - absl::StrCat(name, "_retval_addr"), &b_, - MinimumAlignmentForPrimitiveType(return_type)); + return_value_buffer_type, retval_alloca_name, &b_, retval_alignment); + + std::vector allocas_for_returned_scalars; + if (is_scalar_return) { + allocas_for_returned_scalars.push_back(return_value_buffer); + } else { + constexpr int max_tuple_size = 1000; + CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size) + << "Multivalue function can not return more than 1000 elements to avoid" + << " stack smashing"; + allocas_for_returned_scalars = + llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_); + llvm_ir::IrArray tuple_array(return_value_buffer, return_shape); + + EmitTuple(tuple_array, allocas_for_returned_scalars, &b_); + } Call(FindOrDie(emitted_functions_, &callee), GetArrayFunctionCallArguments( @@ -3058,7 +3273,12 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), /*profile_counters_arg=*/GetProfileCountersArgument())); - return Load(return_value_buffer); + std::vector returned_scalars; + returned_scalars.reserve(allocas_for_returned_scalars.size()); + for (llvm::Value* addr : allocas_for_returned_scalars) { + returned_scalars.push_back(Load(addr)); + } + return returned_scalars; } void IrEmitter::EmitGlobalCall(const HloComputation& callee, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index e183ae01070..44d660fbb1f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -132,8 +132,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emit code to emit the element at `index` for a reduce instruction. StatusOr EmitElementalReduce( const HloReduceInstruction* reduce, - const llvm_ir::ElementGenerator& input_generator, - const llvm_ir::ElementGenerator& initial_value_generator, + std::vector input_generators, + std::vector initial_value_generator, const llvm_ir::IrArray::Index& index); protected: @@ -197,23 +197,25 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Private helper to initialize an IR function for the computation. void InitializeIrFunction(const string& function_name); - template - llvm::Value* GetProfileCounterCommon( - const T& hlo, - const std::unordered_map& profile_index_map); + // Emits the copying epilogue for the function, + // where it copies the returned value to the reserved alloca. + // This is only necessary for thread-local functions. 
+ // Note that since the call graph is flattened, if the same function is + // called in both thread-local and non-thread-local it would be codegen'd + // twice, and we would know whether it's thread-local at codegen time. + void EmitThreadLocalFunctionEpilogue(HloComputation* computation); // Convenience functions to generate a GEP into the profile counter parameter // which would correspond to the index for a given HLO instruction or // computation. - llvm::Value* GetProfileCounterFor(const HloInstruction& instruction) { - return GetProfileCounterCommon(instruction, - instruction_to_profile_idx_); - } + llvm::Value* GetProfileCounterFor(const HloInstruction& instruction); + llvm::Value* GetProfileCounterFor(const HloComputation& computation); - llvm::Value* GetProfileCounterFor(const HloComputation& computation) { - return GetProfileCounterCommon(computation, - computation_to_profile_idx_); - } + // Helper function template for the implementation of the above two functions. + template + llvm::Value* GetProfileCounterCommon( + const T& hlo, + const std::unordered_map& profile_index_map); // Gets the IR Value emitted previously for the given hlo. // @@ -273,12 +275,18 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emits a call to a thread local function (e.g. to the computation nested // within a reduce or a map). Thread local callees (by definition) only write // to and read from thread local allocations. + // Supports only functions returning scalars or tuples of scalars. // // `parameters` holds the *scalar values* that need to be passed to the // callee. The return value is the scalar returned by the callee. - llvm::Value* EmitThreadLocalCall(const HloComputation& callee, - absl::Span parameters, - absl::string_view name); + std::vector EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name); + + // Similar to EmitThreadLocal, yet assumes that the function returns a scalar. + llvm::Value* EmitScalarReturningThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name); // Emits a call to a "global" function (e.g. to the computation nested within // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to @@ -523,6 +531,22 @@ class IrEmitter : public DfsHloVisitorWithDefault, ProfilingState profiling_state_; + class TracingState { + public: + TracingState() : enabled_(false) {} + void set_enabled(bool value) { enabled_ = value; } + void EmitTracingStart(llvm::IRBuilder<>* b, HloInstruction* hlo, + llvm::Value* run_options); + void EmitTracingEnd(llvm::IRBuilder<>* b, HloInstruction* hlo, + llvm::Value* run_options); + + private: + bool enabled_; + // Maps from HLO to the activity id returned by xprof::TraceMe. + std::unordered_map activity_ids_; + }; + TracingState tracing_state_; + // Given a load instruction and a shape or buffer size, annotate the load's // result with the alignment required by the shape or size. void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape); diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 84a5b058cfb..42acd72f966 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -24,11 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" namespace xla { - -namespace { -using llvm_ir::AsStringRef; -} // namespace - namespace cpu { static std::vector GetComputeFunctionParams( @@ -48,15 +43,14 @@ static std::vector GetComputeFunctionParams( IrFunction::IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, - const bool optimize_for_size_requested, - const bool enable_fast_math, llvm::Module* llvm_module, - llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds) + const HloModuleConfig& module_config, + llvm::Module* llvm_module, llvm::IRBuilder<>* b, + int64 num_dynamic_loop_bounds) : b_(b), llvm_module_(llvm_module), caller_insert_point_guard_(*b), num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { - Initialize(function_name, linkage, optimize_for_size_requested, - enable_fast_math); + Initialize(function_name, linkage, module_config); } IrFunction::~IrFunction() { @@ -75,8 +69,7 @@ DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { void IrFunction::Initialize(const string& function_name, llvm::Function::LinkageTypes linkage, - const bool optimize_for_size_requested, - const bool enable_fast_math) { + const HloModuleConfig& module_config) { // The function signature is: // void function(i8* retval, i8* run_options, i8** params, i8** // buffer_table, @@ -147,11 +140,8 @@ void IrFunction::Initialize(const string& function_name, // Functions with local linkage get an inlining bonus. Because we know // a-priori that embedded functions (non-entry functions) will not have its // name resolved, give it local linkage. - function_ = - llvm_ir::CreateFunction(function_type, linkage, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size_requested, - function_name, llvm_module_); + function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config, + function_name, llvm_module_); // Set meaningful names for the function's arguments: useful for debugging. llvm::Function::arg_iterator arg_iter = function_->arg_begin(); @@ -193,7 +183,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); string name = absl::StrCat("dynamic_loop_bound_", offset); return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), - b_->getInt64(offset), AsStringRef(name))); + b_->getInt64(offset), name)); } // Emits code to allocate an array of parameter address pointers, and store @@ -216,10 +206,9 @@ std::vector GetArrayFunctionCallArguments( absl::StrCat(name, "_parameter_addresses"), b); for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = - b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(absl::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); + llvm::Value* parameter_as_i8ptr = b->CreateBitCast( + parameter_addresses[i], b->getInt8PtrTy(), + absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr")); llvm::Value* slot_in_param_addresses = b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); @@ -324,7 +313,7 @@ Status EmitCallToParallelForkJoin( /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/partitions_array, /*Name=*/ - AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions"))); + absl::StrCat(name, "_parallel_dimension_partitions")); // Add argument specifying parallel dimension partitions. 
fork_join_arguments.push_back( diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 623a5f185fa..02bcec9dfc7 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,8 +53,7 @@ namespace cpu { class IrFunction { public: IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, - const bool optimize_for_size_requested, - const bool enable_fast_math, llvm::Module* llvm_module, + const HloModuleConfig& module_config, llvm::Module* llvm_module, llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds); ~IrFunction(); @@ -92,7 +92,7 @@ class IrFunction { // Initialize an llvm::Function with standard signature based on arguments. void Initialize(const string& function_name, llvm::Function::LinkageTypes linkage, - bool optimize_for_size_requested, bool enable_fast_math); + const HloModuleConfig& module_config); // Emit ir to read and return the ir value for the dynamic loop bound at // 'offset' from the "dynamic_loop_bounds" argument of this function. diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 93ef51754d2..a4bb5f72297 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -119,13 +119,9 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, int32 vector_width) { VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32"); - // This implements the same polynomial approximation as implemented in Eigen3. - + // This implements the same polynomial approximation as implemented in Cephes. const llvm::APFloat half = GetIeeeF32(0.5); - const llvm::APFloat one = GetIeeeF32(1.0); - - const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950); - const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949); + const llvm::APFloat one = GetIeeeF32(1); const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341); const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375); @@ -138,39 +134,79 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); - llvm::Value* input_clamped = - vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); - llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); - llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx); - llvm::Value* z = vsl.Mul(cephes_exp_C2, fx); - llvm::Value* x = vsl.Sub(input_clamped, tmp); - x = vsl.Sub(x, z); - z = vsl.Mul(x, x); + // To compute e^input, we re-express it as + // + // e^input = e^(a + b) + // = e^(a + n log(2)) + // = e^a * 2^n. + // + // We choose n = floor(a * log(2) + 0.5), restricting the value of `a` to + // (-0.5, 0.5). We then use a polynomial to compute e^a. 
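A scalar reference version of the scheme described above may make the vector code easier to follow: clamp the input, pick n = floor(x * log2(e) + 0.5) so that x = a + n*ln(2) with a near zero, clamp n to the representable exponent range, evaluate the polynomial for e^a, and finally scale by 2^n built directly from the exponent bits. This is a sketch of the same math in plain C++, not the XLA code path; the polynomial coefficients are the standard Cephes expf constants:

#include <cmath>
#include <cstdint>
#include <cstring>

float ExpApproxSketch(float x) {
  const float kLog2e = 1.44269504088896341f;  // log2(e)
  const float kLn2Hi = 0.693359375f;          // cephes_exp_C1
  const float kLn2Lo = -2.12194440e-4f;       // cephes_exp_C2
  const float kP0 = 1.9875691500e-4f;
  const float kP1 = 1.3981999507e-3f;
  const float kP2 = 8.3334519073e-3f;
  const float kP3 = 4.1665795894e-2f;
  const float kP4 = 1.6666665459e-1f;
  const float kP5 = 5.0000001201e-1f;

  // Clamp to a range where the result either is representable or saturates.
  x = std::fmin(std::fmax(x, -104.0f), 88.8f);

  // n ~= round(x / ln(2)); clamp so that 2^n stays a normal float.
  float n = std::floor(x * kLog2e + 0.5f);
  n = std::fmin(std::fmax(n, -126.0f), 127.0f);

  // a = x - n * ln(2), subtracted in two pieces for accuracy.
  float a = x - n * kLn2Hi;
  a -= n * kLn2Lo;

  // Polynomial approximation of e^a, accurate for a roughly in (-0.5, 0.5).
  float z = kP0 * a + kP1;
  z = z * a + kP2;
  z = z * a + kP3;
  z = z * a + kP4;
  z = z * a + kP5;
  z = z * (a * a) + a;
  z += 1.0f;

  // Build 2^n by placing (n + 127) into the exponent field of a float.
  const uint32_t bits = static_cast<uint32_t>(static_cast<int32_t>(n) + 127)
                        << 23;
  float pow2n;
  std::memcpy(&pow2n, &bits, sizeof(pow2n));

  return z * pow2n;
}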
- llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1); - y = vsl.MulAdd(y, x, cephes_exp_p2); - y = vsl.MulAdd(y, x, cephes_exp_p3); - y = vsl.MulAdd(y, x, cephes_exp_p4); - y = vsl.MulAdd(y, x, cephes_exp_p5); - y = vsl.MulAdd(y, z, x); - y = vsl.Add(one, y); + // Restrict input to a small range, including some values that evaluate to + // +/- inf. Our computations below aren't particularly sensitive to the exact + // choices here, so we choose values a bit larger/smaller than + // + // log(F32_MAX) = 88.723... + // log(F32_EPSILON) = -103.279.... + // + input = vsl.Clamp(input, GetIeeeF32(-104), GetIeeeF32(88.8)); - // VectorSupportLibrary (intentionally) can't juggle more than one type at a - // time so drop down to IRBuilder for this bit. - llvm::Value* vector_constant_0x7f = - b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); - llvm::Value* vector_constant_23 = - b->CreateVectorSplat(vector_width, b->getInt32(23)); - llvm::Type* i32_vector_type = - llvm::VectorType::get(b->getInt32Ty(), vector_width); - // fx is clamped so we don't have to worry about it being out of range for - // i32. - llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type); - emm0 = b->CreateAdd(emm0, vector_constant_0x7f); - emm0 = b->CreateShl(emm0, vector_constant_23); - llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type()); + llvm::Value* x = input; + llvm::Value* n = vsl.Floor(vsl.MulAdd(input, cephes_LOG2EF, half)); - return vsl.Max(vsl.Mul(y, emm0_f32), input); + // When we eventually do the multiplication in e^a * 2^n, we need to handle + // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1 + // (so e^a * 2^n != inf). There's a similar problem for n < -126, the + // smallest fp32 exponent. + // + // A straightforward solution would be to detect n out of range and split it + // up, doing + // + // e^a * 2^n = e^a * 2^(n1 + n2) + // = (2^n1 * e^a) * 2^n2. + // + // But it turns out this approach is quite slow. It's not clear why; our + // hypothesis is that the integer operations on the exponent `n` have nonlocal + // effects on the pipeline. + // + // The approach we use instead is to clamp n to [-126, 127] so 2^n doesn't + // over/underflow. This causes `a` to be outside the range (-0.5, 0.5), which + // means that our polynomial for e^a will give a less-accurate result. In + // practice this seems to work well enough; it passes our exhaustive tests, + // breaking only one result, and by one ulp (we return exp(88.7228394) = + // max-float but we should return inf). + n = vsl.Clamp(n, GetIeeeF32(-126), GetIeeeF32(127)); + + // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5). + x = vsl.Sub(x, vsl.Mul(cephes_exp_C1, n)); + x = vsl.Sub(x, vsl.Mul(cephes_exp_C2, n)); + llvm::Value* z = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1); + z = vsl.MulAdd(z, x, cephes_exp_p2); + z = vsl.MulAdd(z, x, cephes_exp_p3); + z = vsl.MulAdd(z, x, cephes_exp_p4); + z = vsl.MulAdd(z, x, cephes_exp_p5); + z = vsl.MulAdd(z, vsl.Mul(x, x), x); + z = vsl.Add(one, z); + + // Convert n to an i32. This is safe because we clamped it above. + llvm::Value* n_i32 = + b->CreateFPToSI(n, llvm::VectorType::get(b->getInt32Ty(), vector_width)); + + // Create 2^n as an fp32. This works because -126 <= n <= 127 means that n is + // within the bounds for an fp32 exponent. 
+ auto splat_i32 = [&](int32 v) { + return b->CreateVectorSplat(vector_width, b->getInt32(v)); + }; + const int32 kF32SignificandBits = 23; + llvm::Value* exp_bias = splat_i32(0x7f); + llvm::Value* pow2 = + b->CreateBitCast(b->CreateShl(b->CreateAdd(n_i32, exp_bias), + splat_i32(kF32SignificandBits)), + vsl.vector_type()); + + // Return z * 2^n. + return vsl.Mul(z, pow2); } llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index a6f4273a5a7..ffbd0d68ce9 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -39,7 +39,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm_ir::ForLoopNest loop_nest(loop_name, b_); const int64 num_dims = shape_.dimensions_size(); - llvm_ir::IrArray::Index array_index(index_type, num_dims); + std::vector array_multi_index(num_dims); // Add loops from outer-most to inner-most dimensions. for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { @@ -54,14 +54,14 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::unique_ptr loop = loop_nest.AddLoop( /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, end_index); - array_index[dimension] = loop->GetIndVarValue(); + array_multi_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), /*suffix=*/absl::StrFormat("dim.%d", dimension)); - array_index[dimension] = loop->GetIndVarValue(); + array_multi_index[dimension] = loop->GetIndVarValue(); } } // Point IR builder at inner loop BB. @@ -71,6 +71,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); CHECK(exit_bb_ != nullptr); + llvm_ir::IrArray::Index array_index(array_multi_index, shape_, index_type); return {array_index}; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 6121d1ca9a5..234fa91fe3e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -146,8 +146,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || - (opcode == HloOpcode::kFusion && - instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || + (opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) || instruction->shape().IsTuple()) { return 1; } @@ -239,10 +238,7 @@ void ParallelTaskAssigner::ComputeTargetParallelTasks( &target_machine_features_); // Compute parallel task counts for all instructions in 'module'. - for (auto* computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } + for (auto* computation : module->MakeNonfusionComputations()) { for (auto* instruction : computation->instructions()) { // Query ParallelTaskAssignment for target parallel task count. 
const int64 target_parallel_task_count = diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc index 3905e7ff2a1..84cb41a8f17 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" @@ -34,6 +35,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32( int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); + XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); tensorflow::xla::EigenConvImpl( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, @@ -53,6 +55,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF16( int64 rhs_col_dilation) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); + XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); tensorflow::xla::EigenConvImpl( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h index 85af63bb032..193c25f2a4b 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h @@ -26,15 +26,17 @@ namespace xla { template void EigenConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs, - ScalarType* rhs, int64 input_batch, int64 input_rows, - int64 input_cols, int64 input_channels, int64 kernel_rows, - int64 kernel_cols, int64 kernel_channels, - int64 kernel_filters, int64 output_rows, int64 output_cols, - int64 row_stride, int64 col_stride, int64 padding_top, - int64 padding_bottom, int64 padding_left, - int64 padding_right, int64 lhs_row_dilation, - int64 lhs_col_dilation, int64 rhs_row_dilation, - int64 rhs_col_dilation) { + ScalarType* rhs, Eigen::Index input_batch, + Eigen::Index input_rows, Eigen::Index input_cols, + Eigen::Index input_channels, Eigen::Index kernel_rows, + Eigen::Index kernel_cols, Eigen::Index kernel_channels, + Eigen::Index kernel_filters, Eigen::Index output_rows, + Eigen::Index output_cols, Eigen::Index row_stride, + Eigen::Index col_stride, Eigen::Index padding_top, + Eigen::Index padding_bottom, Eigen::Index padding_left, + Eigen::Index padding_right, Eigen::Index lhs_row_dilation, + Eigen::Index lhs_col_dilation, Eigen::Index rhs_row_dilation, + Eigen::Index rhs_col_dilation) { const Eigen::TensorMap, Eigen::Aligned> input(lhs, input_batch, input_rows, input_cols, input_channels); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc index 848d2d22414..10026f3a371 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" @@ -31,6 +32,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft( int64 fft_length2) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); + XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); tensorflow::xla::EigenFftImpl(*run_options->intra_op_thread_pool(), out, operand, fft_type, fft_rank, input_batch, fft_length0, fft_length1, fft_length2); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h b/tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h new file mode 100644 index 00000000000..4a662864728 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_LIGHTWEIGHT_CHECK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_LIGHTWEIGHT_CHECK_H_ + +#include +#include + +// Aborts the program if the condition is false. +// +// This is like QCHECK, except it doesn't pull in the TF/XLA logging framework. +// This makes it suitable for use from within the XLA:CPU runtime files, which +// need to be lightweight. +#define XLA_LIGHTWEIGHT_CHECK(cond) \ + do { \ + if (!(cond)) { \ + std::cerr << __FILE__ << ":" << __LINE__ \ + << " Failed XLA_LIGHTWEIGHT_QCHECK " << #cond << std::endl; \ + std::abort(); \ + } \ + } while (0) + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_LIGHTWEIGHT_CHECK_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index fe7e87a197b..844e9a24de9 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -19,6 +19,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" @@ -69,6 +70,7 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension // 0 of the rhs. 
+ XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims); } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index f7b64738b7b..bf55e9e22cf 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include + #include #include #include @@ -28,7 +29,6 @@ limitations under the License. #include "llvm/Support/CodeGen.h" #include "llvm/Support/Host.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -91,14 +92,14 @@ SimpleOrcJIT::InferTargetMachineForJIT( return target_machine; } -SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, - llvm::CodeGenOpt::Level opt_level, - bool optimize_for_size, bool enable_fast_math, - bool disable_expensive_passes, - LLVMCompiler::ModuleHook pre_optimization_hook, - LLVMCompiler::ModuleHook post_optimization_hook) +SimpleOrcJIT::SimpleOrcJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, + bool disable_expensive_passes, + LLVMCompiler::ModuleHook pre_optimization_hook, + LLVMCompiler::ModuleHook post_optimization_hook, + std::function post_codegen_hook) : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), - disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), symbol_resolver_(llvm::orc::createLegacyLookupResolver( execution_session_, @@ -128,12 +129,12 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, [this](VModuleKeyT, const llvm::object::ObjectFile& object) { this->NotifyObjectFreed(object); }), - compile_layer_(object_layer_, - CompilerFunctor(target_machine_.get(), &disassembler_, - opt_level, optimize_for_size, - enable_fast_math, disable_expensive_passes, - std::move(pre_optimization_hook), - std::move(post_optimization_hook))), + compile_layer_( + object_layer_, + CompilerFunctor( + target_machine_.get(), opt_level, optimize_for_size, + disable_expensive_passes, std::move(pre_optimization_hook), + std::move(post_optimization_hook), std::move(post_codegen_hook))), gdb_jit_event_listener_( llvm::JITEventListener::createGDBRegistrationListener()) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() @@ -146,13 +147,18 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { // On Mac OS X, 'name' may have a leading underscore prefix, even though the // registered name may not. 
std::string stripped_name(name.begin() + 1, name.end()); - func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name); + func_addr = + xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host"); } else { - func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host"); } if (func_addr == nullptr) { - LOG(ERROR) << "Unable to resolve runtime symbol: " << name; + LOG(ERROR) + << "Unable to resolve runtime symbol: `" << name + << "'. Hint: if the symbol a custom call target, make sure you've " + "registered it with the JIT using " + "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET."; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), @@ -206,14 +212,15 @@ llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) { namespace { // Register some known symbols with the CustomCallTargetRegistry. bool RegisterKnownJITSymbols() { - CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); + xla::CustomCallTargetRegistry* registry = + xla::CustomCallTargetRegistry::Global(); #define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ do { \ auto* function_address = \ reinterpret_cast(__xla_cpu_runtime_##base_name); \ registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ - function_address); \ + function_address, "Host"); \ CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ "__xla_cpu_runtime_" #base_name); \ } while (false) @@ -241,9 +248,13 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); + REGISTER_CPU_RUNTIME_SYMBOL(TracingStart); + REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd); - registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); - registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); + registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee), + "Host"); + registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee), + "Host"); #undef REGISTER_CPU_RUNTIME_SYMBOL @@ -251,11 +262,12 @@ bool RegisterKnownJITSymbols() { // Unfortunately the double versions are overloaded on some systems, e.g. // Mac so we need an explicit cast. This requires passing the function signature // for that case. 
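Concretely, the cast in the macro below is the usual trick for taking the address of one member of an overload set: a static_cast to the desired signature selects the overload, and the result can then be stored as an untyped address. A self-contained illustration with a made-up overload pair (half_of is not a real libm function):

#include <iostream>

// Two overloads, much like libm's double/float pairs on macOS.
double half_of(double x) { return x / 2.0; }
float half_of(float x) { return x / 2.0f; }

int main() {
  // '&half_of' alone is ambiguous here; the static_cast picks the double
  // overload, mirroring what REGISTER_LIBM_SYMBOL does before registering.
  void* addr =
      reinterpret_cast<void*>(static_cast<double (*)(double)>(half_of));
  std::cout << (addr != nullptr) << "\n";
  return 0;
}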
-#define REGISTER_LIBM_SYMBOL(name, double_sig) \ - do { \ - registry->Register(#name "f", reinterpret_cast(name##f)); \ - registry->Register( \ - #name, reinterpret_cast(static_cast(name))); \ +#define REGISTER_LIBM_SYMBOL(name, double_sig) \ + do { \ + registry->Register(#name "f", reinterpret_cast(name##f), "Host"); \ + registry->Register(#name, \ + reinterpret_cast(static_cast(name)), \ + "Host"); \ } while (false) REGISTER_LIBM_SYMBOL(acos, double (*)(double)); @@ -313,8 +325,9 @@ bool RegisterKnownJITSymbols() { #ifdef __APPLE__ REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); registry->Register("__sincosf_stret", - reinterpret_cast(__sincosf_stret)); - registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret)); + reinterpret_cast(__sincosf_stret), "Host"); + registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret), + "Host"); #else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); #endif @@ -327,19 +340,19 @@ bool RegisterKnownJITSymbols() { #undef REGISTER_LIBM_SYMBOL - registry->Register("memcpy", reinterpret_cast(memcpy)); - registry->Register("memmove", reinterpret_cast(memmove)); - registry->Register("memset", reinterpret_cast(memset)); + registry->Register("memcpy", reinterpret_cast(memcpy), "Host"); + registry->Register("memmove", reinterpret_cast(memmove), "Host"); + registry->Register("memset", reinterpret_cast(memset), "Host"); #ifdef __APPLE__ - registry->Register("__bzero", reinterpret_cast(bzero)); + registry->Register("__bzero", reinterpret_cast(bzero), "Host"); registry->Register("memset_pattern16", - reinterpret_cast(memset_pattern16)); + reinterpret_cast(memset_pattern16), "Host"); #endif #ifdef MEMORY_SANITIZER registry->Register("__msan_unpoison", - reinterpret_cast(__msan_unpoison)); + reinterpret_cast(__msan_unpoison), "Host"); #endif return true; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 3307c2f93d7..f9e845bd282 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" -#include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -51,29 +50,20 @@ class SimpleOrcJIT { using VModuleKeyT = llvm::orc::VModuleKey; // Create a new JIT, targeting the host architecture. - // The |target_options| parameter allows customization of certain code - // generation properties of the TargetMachine (whether or not float point math - // can be reassociated, etc.). - // The |opt_level| parameter controls the optimization level of the code - // generator. - // The |optimize_for_size| parameter specifies that the code generator should - // optimize to reduce code size, potentially at the cost of performance. - // The |disable_expensive_passes| parameter will disable certain optimization - // passes - // The |pre_optimization_hook| is invoked on the module before any IR - // level optimizations are applied. - // The |post_optimization_hook| is invoked on the module after all IR - // level optimizations are applied. 
- SimpleOrcJIT(const llvm::TargetOptions& target_options, - llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, - bool enable_fast_math, bool disable_expensive_passes, - LLVMCompiler::ModuleHook pre_optimization_hook, - LLVMCompiler::ModuleHook post_optimization_hook); + // + // {pre,post}_optimization_hook is invoked on the module before/after all + // LLVM IR-level optimizations. post_codegen_hook is invoked after + // compiling to machine code. + SimpleOrcJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, + bool disable_expensive_passes, + LLVMCompiler::ModuleHook pre_optimization_hook, + LLVMCompiler::ModuleHook post_optimization_hook, + std::function post_codegen_hook); - // Data layout this JIT was created with. const llvm::DataLayout& data_layout() const { return data_layout_; } - // Target triple (host) this JIT was created with. const llvm::Triple& target_triple() const { return target_machine_->getTargetTriple(); } @@ -107,7 +97,6 @@ class SimpleOrcJIT { std::vector module_keys_; std::unique_ptr target_machine_; - const Disassembler disassembler_; const llvm::DataLayout data_layout_; llvm::orc::ExecutionSession execution_session_; std::shared_ptr symbol_resolver_; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 3a26fea9116..a72ebe2beea 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -39,6 +39,13 @@ class CpuFusionTest : public HloTestBase { CpuFusionTest() {} ErrorSpec error_spec_{0.0001, 1e-5}; + + private: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + return debug_options; + } }; TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc index 762ee67db9a..e07ac9edc89 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -29,7 +29,7 @@ HloModule KeyValueSort compare { p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY main { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 030bd41c2fc..951098eb104 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -53,6 +53,8 @@ TEST_F(CpuNoAliasTest, Concat) { HloInstruction* concat2 = builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {2, 6}), {concat1, param_x}, 1)); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {2, 6}), HloOpcode::kAdd, concat2, concat2)); std::unique_ptr computation = builder.Build(); @@ -81,7 +83,6 @@ TEST_F(CpuNoAliasTest, Concat) { llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); llvm::IRBuilder<> b(bb); auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); - llvm_ir::IrArray::Index zero2D({zero, zero}); llvm::ArrayType* array2d_type = 
llvm::ArrayType::get( llvm::ArrayType::get(llvm::Type::getFloatTy(context), 100), 100); @@ -91,7 +92,8 @@ TEST_F(CpuNoAliasTest, Concat) { ir_module.getOrInsertGlobal("param_x", array2d_type); llvm_ir::IrArray param_x_array(param_x_val, param_shape); aa.AddAliasingInformationToIrArray(*param_x, &param_x_array); - param_x_array.EmitReadArrayElement(zero2D, &b) + llvm_ir::IrArray::Index zero_2d({zero, zero}, param_shape, zero->getType()); + param_x_array.EmitReadArrayElement(zero_2d, &b) ->setName("read_param_x_array"); } @@ -101,7 +103,8 @@ TEST_F(CpuNoAliasTest, Concat) { auto shape = ShapeUtil::MakeShape(F32, {2, 4}); llvm_ir::IrArray concat1_array(concat1_val, shape); aa.AddAliasingInformationToIrArray(*concat1, &concat1_array); - concat1_array.EmitReadArrayElement(zero2D, &b) + llvm_ir::IrArray::Index zero_2d({zero, zero}, shape, zero->getType()); + concat1_array.EmitReadArrayElement(zero_2d, &b) ->setName("read_concat1_array"); } @@ -111,15 +114,26 @@ TEST_F(CpuNoAliasTest, Concat) { auto shape = ShapeUtil::MakeShape(F32, {2, 6}); llvm_ir::IrArray concat2_array(concat2_val, shape); aa.AddAliasingInformationToIrArray(*concat2, &concat2_array); - concat2_array.EmitReadArrayElement(zero2D, &b) + llvm_ir::IrArray::Index zero_2d({zero, zero}, shape, zero->getType()); + concat2_array.EmitReadArrayElement(zero_2d, &b) ->setName("read_concat2_array"); } + { + llvm::Value* concat2_val = ir_module.getOrInsertGlobal("add", array2d_type); + auto shape = ShapeUtil::MakeShape(F32, {2, 6}); + llvm_ir::IrArray add_array(concat2_val, shape); + aa.AddAliasingInformationToIrArray(*add, &add_array); + llvm_ir::IrArray::Index zero_2d({zero, zero}, shape, zero->getType()); + add_array.EmitReadArrayElement(zero_2d, &b)->setName("read_add_array"); + } + // Check the AA info in the loads.
const char* filecheck_pattern = R"( CHECK: %read_param_x_array = load {{.*}} !noalias [[param_x_noalias:![0-9]+]] CHECK: %read_concat1_array = load {{.*}} !alias.scope [[concat1_scope:![0-9]+]], !noalias [[concat1_noalias:![0-9]+]] CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] + CHECK: %read_add_array = load {{.*}} !alias.scope [[concat1_noalias]]{{$}} CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]} diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc index 9fc472ff767..7668f364bad 100644 --- a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -991,7 +991,7 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* b, - bool enable_fast_math, bool optimize_for_size) { + const HloModuleConfig& module_config) { RowMajorMatrixVectorProductEmitter::Config config( /*scalar_type=*/scalar_type, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, @@ -1001,8 +1001,7 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); KernelSupportLibrary::EmitAndCallOutlinedKernel( - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), + module_config, b, config.GetCacheKey(), canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, canonical_inputs.addend_canonicalized, canonical_inputs.result_canonicalized, @@ -1019,7 +1018,7 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* b, - bool enable_fast_math, bool optimize_for_size) { + const HloModuleConfig& module_config) { ColumnMajorMatrixVectorProductEmitter::Config config( /*scalar_type=*/scalar_type, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, @@ -1029,8 +1028,7 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); KernelSupportLibrary::EmitAndCallOutlinedKernel( - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), + module_config, b, config.GetCacheKey(), canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, canonical_inputs.addend_canonicalized, canonical_inputs.result_canonicalized, @@ -1048,7 +1046,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n, int64 min_vectorization_width, int64 tile_size_m, int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, llvm::IRBuilder<>* b, - bool enable_fast_math, bool optimize_for_size) { + const HloModuleConfig& module_config) { TiledSmallGemmEmitter::Config config( /*scalar_type=*/scalar_type, TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, @@ -1058,9 +1056,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, int64 m, int64 k, int64 n, /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); KernelSupportLibrary::EmitAndCallOutlinedKernel( - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, 
config.GetCacheKey(), lhs, - rhs, result, + module_config, b, config.GetCacheKey(), lhs, rhs, result, [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) { TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, /*rhs=*/rhs, diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h index 0a82326cc37..77581a53cfb 100644 --- a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ #include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -29,15 +30,15 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, tensorflow::int64 tile_cols, tensorflow::int64 m, tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, - llvm::IRBuilder<>* b, bool enable_fast_math, - bool optimize_for_size); + llvm::IRBuilder<>* b, + const HloModuleConfig& module_config); void EmitColumnMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, tensorflow::int64 tile_cols, tensorflow::int64 m, tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* b, - bool enable_fast_math, bool optimize_for_size); + const HloModuleConfig& module_config); void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m, tensorflow::int64 k, tensorflow::int64 n, @@ -46,8 +47,7 @@ void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m, tensorflow::int64 min_vectorization_width, tensorflow::int64 tile_size_m, tensorflow::int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b, bool enable_fast_math, - bool optimize_for_size); + llvm::IRBuilder<>* b, const HloModuleConfig& module_config); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 1bd4b59dd60..b15ad1e162d 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -107,13 +107,19 @@ llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) { llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, const llvm::APFloat& low, const llvm::APFloat& high) { + CHECK(!low.isNaN()); + CHECK(!high.isNaN()); + CHECK(low.compare(high) == llvm::APFloat::cmpLessThan); + AssertCorrectTypes({a}); llvm::Type* type = a->getType(); - CHECK(low.compare(high) == llvm::APFloat::cmpLessThan); CHECK(scalar_type_->isFloatingPointTy()); - return llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_), - GetConstantFloat(type, high), b_); + + llvm::Value* low_value = GetConstantFloat(type, low); + llvm::Value* high_value = GetConstantFloat(type, high); + a = b_->CreateSelect(b_->CreateFCmpUGE(a, low_value), a, low_value); + a = b_->CreateSelect(b_->CreateFCmpULE(a, high_value), a, high_value); + return a; } llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs, diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index c444fd7d4aa..2f8be8c111b 100644 --- 
a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -100,8 +100,10 @@ class VectorSupportLibrary { llvm::Value* Floor(llvm::Value* a); + // Precondition: Neither `low` nor `high` is nan. llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low, const llvm::APFloat& high); + llvm::Value* SplatFloat(const llvm::APFloat& d) { return GetConstantFloat(vector_type(), d); } diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc b/tensorflow/compiler/xla/service/custom_call_target_registry.cc similarity index 73% rename from tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc rename to tensorflow/compiler/xla/service/custom_call_target_registry.cc index 5f5803874b7..e6a70211f25 100644 --- a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc +++ b/tensorflow/compiler/xla/service/custom_call_target_registry.cc @@ -13,10 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" namespace xla { -namespace cpu { CustomCallTargetRegistry* CustomCallTargetRegistry::Global() { static auto* registry = new CustomCallTargetRegistry; @@ -24,16 +23,17 @@ CustomCallTargetRegistry* CustomCallTargetRegistry::Global() { } void CustomCallTargetRegistry::Register(const std::string& symbol, - void* address) { + void* address, + const std::string& platform) { std::lock_guard lock(mu_); - registered_symbols_[symbol] = address; + registered_symbols_[std::make_pair(symbol, platform)] = address; } -void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const { +void* CustomCallTargetRegistry::Lookup(const std::string& symbol, + const std::string& platform) const { std::lock_guard lock(mu_); - auto it = registered_symbols_.find(symbol); + auto it = registered_symbols_.find(std::make_pair(symbol, platform)); return it == registered_symbols_.end() ? nullptr : it->second; } -} // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/custom_call_target_registry.h b/tensorflow/compiler/xla/service/custom_call_target_registry.h new file mode 100644 index 00000000000..06239689e15 --- /dev/null +++ b/tensorflow/compiler/xla/service/custom_call_target_registry.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CUSTOM_CALL_TARGET_REGISTRY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CUSTOM_CALL_TARGET_REGISTRY_H_ + +// This file is depended on by kernels that have to build for mobile devices. +// For this reason, we avoid relying on TensorFlow and instead only use the +// standard C++ library. 
+ +#include +#include // NOLINT +#include + +namespace xla { + +// XLA JIT compilers use this registry to resolve symbolic CustomCall targets; +// so when using XLA as a JIT, CustomCall targets need to be registered here +// with the symbol name used in the CustomCall. +// +// The XLA:CPU ahead-of-time (AOT) compiler links using a standard offline +// linker; so when compiling in CPU AOT mode, you *also* need to make sure the +// name of the callee (presumably implemented in C++) matches up with the +// symbolic name used in the CustomCall. +// +// We maintain the registry in both the JIT and the AOT cases for simplicity, +// but we only use it when running in JIT mode. +class CustomCallTargetRegistry { + public: + static CustomCallTargetRegistry* Global(); + + void Register(const std::string& symbol, void* address, + const std::string& platform); + void* Lookup(const std::string& symbol, const std::string& platform) const; + + private: + // Maps the pair (symbol, platform) to a C function implementing a custom call + // named `symbol` for StreamExecutor platform `platform`. + // + // Different platforms have different ABIs. TODO(jlebar): Describe them! + // + // (We std::map rather than std::unordered_map because the STL doesn't provide + // a default hasher for pair, and we want to avoid pulling in + // dependencies that might define this.) + std::map, void*> registered_symbols_; + mutable std::mutex mu_; +}; + +class RegisterCustomCallTarget { + public: + explicit RegisterCustomCallTarget(const std::string& name, void* address, + const std::string& platform) { + CustomCallTargetRegistry::Global()->Register(name, address, platform); + } +}; + +#define XLA_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b + +#define XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \ + platform, counter) \ + static ::xla::RegisterCustomCallTarget XLA_REGISTER_CUSTOM_CALL_CONCAT( \ + custom_call_target_register, counter)( \ + symbol, reinterpret_cast(address), platform) + +#define XLA_REGISTER_CUSTOM_CALL_TARGET(function, platform) \ + XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function, platform) + +#define XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address, platform) \ + XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, platform, \ + __COUNTER__) + +// Convenience overloads for registering custom-call targets on the CPU. +#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \ + XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function, "Host") + +#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ + XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address, "Host") + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CUSTOM_CALL_TARGET_REGISTRY_H_ diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index ed37099a542..490e057fcbc 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -18,7 +18,6 @@ limitations under the License. 
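Tying the new registry header above back to the hint added to the JIT error message in simple_orc_jit.cc: a symbolic custom-call target only resolves if it has been registered for the "Host" platform. A hypothetical registration might look like the following; the callee name and its signature are illustrative assumptions, not part of this change.

```c++
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

// Hypothetical callee. The real signature must follow the XLA:CPU custom-call
// calling convention for the shapes involved, which this header does not
// define.
extern "C" void MyCustomCallTarget(void* out, const void** in) {
  // ... fill `out` from `in` ...
}

// Registers the symbol "MyCustomCallTarget" for the "Host" platform, so the
// CPU JIT's runtime symbol resolution can find it. Spelled out, this is
// XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("MyCustomCallTarget",
//                                          MyCustomCallTarget, "Host").
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(MyCustomCallTarget);
```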
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/defuser.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" -#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" namespace xla { @@ -49,7 +48,6 @@ Despecializer::Despecializer() : pipeline_("despecializer") { pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); - pipeline_.AddPass(); pipeline_.AddPass(); } diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index 46dcc3a438c..b6afaa17aa2 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -28,8 +28,8 @@ namespace xla { // optimized for one specific platform on a different platform (undoing platform // specific passes) with matching numerics for comparison. // -// Current despecialization passes are Defuser, ImplicitBroadcastRemover, -// and BFloat16MixedPrecisionRemoval. +// Current despecialization passes are HloDescheduler, ControlDepRemover, +// Defuser and BFloat16MixedPrecisionRemoval. class Despecializer : public HloModulePass { public: Despecializer(); diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc deleted file mode 100644 index e1e3b156fb3..00000000000 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" - -#include - -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/numbers.h" - -namespace xla { - -StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( - const se::Platform* platform, - absl::Span stream_executors) - : DeviceMemoryAllocator(platform), - stream_executors_(stream_executors.begin(), stream_executors.end()) {} - -StatusOr StreamExecutorMemoryAllocator::Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) { - TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, - GetStreamExecutor(device_ordinal)); - se::DeviceMemoryBase result = stream_executor->AllocateArray(size); - if (size > 0 && result == nullptr) { - return ResourceExhausted( - "Failed to allocate request for %s (%uB) on device ordinal %d", - tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal); - } - VLOG(3) << absl::StreamFormat( - "Allocated %s (%uB) on device ordinal %d: %p", - tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal, - result.opaque()); - return OwningDeviceMemory(result, device_ordinal, this); -} - -Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, - se::DeviceMemoryBase mem) { - if (!mem.is_null()) { - TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, - GetStreamExecutor(device_ordinal)); - VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d", - mem.opaque(), device_ordinal); - stream_executor->Deallocate(&mem); - } - return Status::OK(); -} - -StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( - int device_ordinal) { - if (device_ordinal < 0) { - return InvalidArgument("device ordinal value (%d) must be non-negative", - device_ordinal); - } - if (device_ordinal >= stream_executors_.size()) { - return InvalidArgument( - "device ordinal value (%d) >= number of devices (%u)", device_ordinal, - stream_executors_.size()); - } - if (stream_executors_[device_ordinal] == nullptr) { - return NotFound("Device %s:%d present but not supported", - platform()->Name(), device_ordinal); - } - return stream_executors_[device_ordinal]; -} - -bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const { - return false; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h deleted file mode 100644 index a2308ee7a41..00000000000 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_ - -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/service/owning_device_memory.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Interface for device memory allocators used within the XLA service. An -// allocator is responsible for allocating memory on all devices of a particular -// platform. -class DeviceMemoryAllocator { - public: - // Parameter platform indicates which platform the allocator allocates memory - // on. Must be non-null. - explicit DeviceMemoryAllocator(const se::Platform* platform) - : platform_(platform) {} - virtual ~DeviceMemoryAllocator() {} - - // Allocates memory on the device. - // - // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory - // must not be null. If size == 0, must return a null OwningDeviceMemory. - // - // 'retry_on_failure': If false, and the first attempt to allocate the memory - // fails, the allocation should return immediately without retrying. An - // example use case is optional scratch spaces where a failure has only - // performance impact. - virtual StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) = 0; - - // Two-arg version of Allocate(), which sets retry-on-failure to true. - // - // (We don't simply use a default argument on the virtual Allocate function - // because default args on virtual functions are disallowed by the Google - // style guide.) - StatusOr Allocate(int device_ordinal, uint64 size) { - return Allocate(device_ordinal, size, /*retry_on_failure=*/true); - } - - // Must be a nop for null pointers. - virtual Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) = 0; - - // Return the platform that the allocator allocates memory on. - const se::Platform* platform() const { return platform_; } - - // Can we call Deallocate() as soon as a computation has been scheduled on - // a stream, or do we have to wait for the computation to complete first? - virtual bool AllowsAsynchronousDeallocation() const = 0; - - protected: - friend class OwningDeviceMemory; - const se::Platform* platform_; -}; - -// Default memory allocator for a platform which uses -// StreamExecutor::Allocate/Deallocate. -class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { - public: - StreamExecutorMemoryAllocator( - const se::Platform* platform, - absl::Span stream_executors); - - StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - - // Pull in two-arg overload that sets retry_on_failure to true. - using DeviceMemoryAllocator::Allocate; - - Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; - - bool AllowsAsynchronousDeallocation() const override; - - private: - StatusOr GetStreamExecutor(int device_ordinal); - - // A vector indexed by device ordinal of StreamExecutors for each device of - // the allocator's platform type. If an element is nullptr, then the device - // with the respective device ordinal is not supported by XLA. 
- std::vector stream_executors_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_ diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 246f2af09b5..f45cda806c8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -117,6 +117,7 @@ class DfsHloVisitorBase { virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; + virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); @@ -199,6 +200,9 @@ class DfsHloVisitorBase { virtual Status HandleXor(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } + virtual Status HandlePopulationCount(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleShiftLeft(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -288,6 +292,7 @@ class DfsHloVisitorBase { // This call is purely a performance hint and can be omitted without // affecting correctness. void ReserveVisitStates(int num) { visit_state_.reserve(num); } + size_t VisitStateCapacity() const { return visit_state_.capacity(); } // Useful when we want to visit the same computation more than once with the // same visitor. diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 79ce3f82e8c..756ba9025f0 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -109,6 +109,9 @@ class DfsHloVisitorWithDefaultBase Status HandleReplicaId(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandlePartitionId(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc index 70173d43d79..bd638917ccf 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -75,7 +75,7 @@ ENTRY TestComputation { broadcast = f32[42] broadcast(add), dimensions={} slice = f32[1] slice(broadcast), slice={[1:2]} copy = f32[] copy(arg) - eq = pred[] equal-to(arg, gte) + eq = pred[] compare(arg, gte), direction=EQ neg = f32[] negate(arg) ROOT convert = f64[] convert(f32[] arg) })"; diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 559b9c1f2c9..353a7f5cebc 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -29,135 +29,6 @@ namespace xla { namespace { -// TODO(b/69062148) Remove this code when all backends support BatchDot -// natively. 
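Back in the DfsHloVisitor changes above: HandlePartitionId and HandlePopulationCount follow the existing handler pattern, so a visitor built on DfsHloVisitorWithDefault can override just the opcodes it cares about. A small hypothetical sketch, assuming only the APIs shown in those headers:

```c++
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"

namespace xla {

// Hypothetical visitor that reacts only to the two newly routed opcodes and
// ignores everything else via DefaultAction.
class NewOpcodeCounter : public DfsHloVisitorWithDefault {
 public:
  Status DefaultAction(HloInstruction* /*hlo*/) override {
    return Status::OK();
  }
  Status HandlePopulationCount(HloInstruction* /*hlo*/) override {
    ++num_popcnt_;
    return Status::OK();
  }
  Status HandlePartitionId(HloInstruction* /*hlo*/) override {
    ++num_partition_id_;
    return Status::OK();
  }

 private:
  int64 num_popcnt_ = 0;
  int64 num_partition_id_ = 0;
};

}  // namespace xla
```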
-Status DecomposeBatchDot(HloInstruction* dot) { - auto computation = dot->parent(); - const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); - const Shape& lhs_shape = lhs->shape(); - const Shape& rhs_shape = rhs->shape(); - const Shape& dot_shape = dot->shape(); - - // ShapeInference should guarantee that lhs/rhs batch dimensions match. - CHECK_EQ(dnums.lhs_batch_dimensions_size(), - dnums.rhs_batch_dimensions_size()); - const int64 num_batch_dims = dnums.lhs_batch_dimensions_size(); - // Calculate total batch size (note that ShapeInference requires that - // the batch dimensions are most-major). - int64 batch_size = 1; - for (int i = 0; i < num_batch_dims; ++i) { - CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)), - rhs_shape.dimensions(dnums.rhs_batch_dimensions(i))); - batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)); - } - - // Set lhs/rhs_transpose. - CHECK_EQ(1, dnums.lhs_contracting_dimensions_size()); - const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0); - const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0; - - CHECK_EQ(1, dnums.rhs_contracting_dimensions_size()); - const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0); - const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1; - - // Compute R3 and R3 shapes for lhs. - PrimitiveType lhs_type = lhs_shape.element_type(); - const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0); - const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1); - Shape lhs_shape_r3 = - ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols}); - Shape lhs_slice_shape_r3 = - ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols}); - Shape lhs_slice_shape_r2 = - ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols}); - - // Compute R3 and R3 shapes for rhs. - PrimitiveType rhs_type = rhs_shape.element_type(); - const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0); - const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1); - Shape rhs_shape_r3 = - ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols}); - Shape rhs_slice_shape_r3 = - ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols}); - Shape rhs_slice_shape_r2 = - ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols}); - - // Compute R3 and R3 shapes for dot output. - PrimitiveType dot_type = dot_shape.element_type(); - const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0); - const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1); - Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols}); - Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols}); - Shape concat_shape_r3 = - ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols}); - - // Reshape lhs/rhs into R3. - auto lhs_r3 = computation->AddInstruction( - HloInstruction::CreateReshape(lhs_shape_r3, lhs)); - auto rhs_r3 = computation->AddInstruction( - HloInstruction::CreateReshape(rhs_shape_r3, rhs)); - - // Loop through batch size, slicing out required lhs/rhs to compute each Dot. - std::vector output_slices(batch_size); - for (int64 i = 0; i < batch_size; ++i) { - // Slice R3 shape from 'lhs' and reshape to R2. 
- auto lhs_slice_r3 = computation->AddInstruction( - HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0}, - {i + 1, lhs_rows, lhs_cols}, {1, 1, 1})); - auto lhs_slice_r2 = computation->AddInstruction( - HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3)); - - // Slice R3 shape from 'rhs' and reshape to R2. - auto rhs_slice_r3 = computation->AddInstruction( - HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0}, - {i + 1, rhs_rows, rhs_cols}, {1, 1, 1})); - auto rhs_slice_r2 = computation->AddInstruction( - HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3)); - - // Transpose lhs/rhs (if needed). - if (lhs_transpose) { - Shape lhs_slice_shape_r2_transpose = - ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows}); - lhs_slice_r2 = - computation->AddInstruction(HloInstruction::CreateTranspose( - lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0})); - } - if (rhs_transpose) { - Shape rhs_slice_shape_r2_transpose = - ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows}); - rhs_slice_r2 = - computation->AddInstruction(HloInstruction::CreateTranspose( - rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0})); - } - - // Compute Dot of lhs/rhs R2 slices. - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_r2 = computation->AddInstruction( - HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2, - dot_dnums, dot->precision_config())); - - // Reshape Dot to R3 so we can concat along batch dimension. - auto dot_r3 = computation->AddInstruction( - HloInstruction::CreateReshape(dot_shape_r3, dot_r2)); - - output_slices[i] = dot_r3; - } - - // Concatenate slices from 'output_slices' along batch dimension. - auto concat = computation->AddInstruction( - HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0)); - // Reshape output 'new_dot' to original dimensions. - auto new_dot = computation->AddInstruction( - HloInstruction::CreateReshape(dot_shape, concat)); - - // Replace all uses of 'dot' in 'computation' with 'new_dot'. - return computation->ReplaceInstruction(dot, new_dot); -} - // Convert a dot into a canonical form where non-contracting and contracting // dimensions are reshaped together and batch dimensions are the most major // dimensions. The requires transposing and reshapes the lhs and rhs and @@ -301,6 +172,15 @@ StatusOr DotDecomposer::Run(HloModule* module) { non_canonical_dots.push_back(instruction); continue; } + // A dot is not canonical if it has more than one non-contracting + // dimension. 
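As a concrete illustration of the non-canonical check implemented just below, using the shapes from the new dot_decomposer_test.cc later in this change: for f32[64,63,512] dot f32[512,512] with lhs_contracting_dims={2} and no batch dimensions, the lhs has rank 3 while 0 batch dimensions + 2 = 2, so the dot is flagged as non-canonical. CanonicalizeDot then collapses the two lhs non-contracting dimensions (64 and 63) into a single dimension of size 4032, emits f32[4032,512] dot f32[512,512], and reshapes the result back to f32[64,63,512], which is exactly what the test's matchers check.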
+ if (dnums.lhs_batch_dimensions_size() + 2 != + instruction->operand(0)->shape().rank() || + dnums.rhs_batch_dimensions_size() + 2 != + instruction->operand(1)->shape().rank()) { + non_canonical_dots.push_back(instruction); + continue; + } if (dnums.lhs_batch_dimensions().empty() && dnums.lhs_contracting_dimensions().empty()) { non_canonical_dots.push_back(instruction); @@ -323,27 +203,6 @@ StatusOr DotDecomposer::Run(HloModule* module) { TF_RETURN_IF_ERROR(CanonicalizeDot(dot)); changed = true; } - - if (decompose_batch_dot_) { - std::vector batch_dots; - for (auto* computation : module->MakeNonfusionComputations()) { - for (auto* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kDot) { - continue; - } - const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); - if (!dnums.lhs_batch_dimensions().empty()) { - batch_dots.push_back(instruction); - } - } - } - // Decompose each batch Dot in 'batch_dots'. - - for (auto* dot : batch_dots) { - TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); - changed = true; - } - } XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index 40e7a3b4c25..dcf92c8cc97 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -21,22 +21,16 @@ limitations under the License. namespace xla { -// DotDecomposer is a pass which decomposes batch Dot operations into a -// sequence of smaller (R2) Dot operations. +// DotDecomposer is a pass which converts dots into a canonical form where +// non-contracting and contracting dimensions are reshaped together and batch +// dimensions are the most major dimensions. class DotDecomposer : public HloModulePass { public: - // Decomposes batch Dot operations when 'decompose_batch_dot' is true. - DotDecomposer(bool decompose_batch_dot = true) - : decompose_batch_dot_(decompose_batch_dot) {} - ~DotDecomposer() = default; absl::string_view name() const override { return "dot_decomposer"; } // Run DotDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. StatusOr Run(HloModule* module) override; - - private: - bool decompose_batch_dot_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_decomposer_test.cc b/tensorflow/compiler/xla/service/dot_decomposer_test.cc new file mode 100644 index 00000000000..67fff50eaf6 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dot_decomposer.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using DotDecomposerTest = HloTestBase; + +TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,63,512]{2,1,0} parameter(0) + p1 = f32[512,512]{1,0} parameter(1) + ROOT dot = f32[64,63,512]{2,1,0} dot(p0, p1), lhs_contracting_dims={2}, + rhs_contracting_dims={0} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/0), + op::Shape("f32[4032,512]")))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc new file mode 100644 index 00000000000..d251c828bcd --- /dev/null +++ b/tensorflow/compiler/xla/service/dump.cc @@ -0,0 +1,412 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dump.h" + +#include "absl/strings/ascii.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { + +namespace { + +using absl::StrCat; +using absl::StrFormat; +using absl::string_view; + +struct CanonicalDebugOptions { + explicit CanonicalDebugOptions(const DebugOptions& opts) + : dump_to(opts.xla_dump_to()), + dump_as_text(opts.xla_dump_hlo_as_text()), + dump_as_proto(opts.xla_dump_hlo_as_proto()), + dump_as_dot(opts.xla_dump_hlo_as_dot()), + dump_as_html(opts.xla_dump_hlo_as_html()), + dump_as_url(opts.xla_dump_hlo_as_url()), + dump_snapshots(opts.xla_dump_hlo_snapshots()) { + // This constructor examines the values in `opts` and turns on other flags + // based on what we think is the user's intent. To reduce confusion about + // what was a user-specified value versus an extrapolated value, within this + // function we treat this struct's members as write-only, and read only from + // `opts`. 
+ + // If dump_to is empty, default to dumping to stdout. + if (opts.xla_dump_to().empty()) { + dump_to = "-"; + } + + // Did the user specifiy an explicit format for dumping? + bool output_format_specified = + opts.xla_dump_hlo_as_text() || opts.xla_dump_hlo_as_proto() || + opts.xla_dump_hlo_as_dot() || opts.xla_dump_hlo_as_html() || + opts.xla_dump_hlo_as_url() || opts.xla_dump_hlo_snapshots(); + + // If we haven't specified an output format, default to dumping as text. + if (!output_format_specified) { + dump_as_text = true; + } + + // If we specified a regular expression restricting which modules to dump, + // respect that. + // + // If we didn't specify which modules to dump but we passed some other flag + // which implies dumping modules, dump all modules. + // + // Otherwise, don't dump any HLO modules. + if (!opts.xla_dump_hlo_module_re().empty()) { + // RE2 object is not copyable, and we can't capture "by move", so we + // resort to this hack. + string pattern = opts.xla_dump_hlo_module_re(); + should_dump_module = [pattern](string_view module_name) { + return RE2::PartialMatch(string(module_name), pattern); + }; + } else if (!opts.xla_dump_hlo_pass_re().empty() || + !opts.xla_dump_to().empty() || output_format_specified) { + should_dump_module = [](string_view) { return true; }; + } else { + should_dump_module = [](string_view) { return false; }; + } + + // Initialize should_dump_pass. This one is easy: We only dump per-pass + // data if the user asked for it explicitly. + if (!opts.xla_dump_hlo_pass_re().empty()) { + string pattern = opts.xla_dump_hlo_pass_re(); + should_dump_pass = [pattern](string_view pass_name) { + return RE2::PartialMatch(string(pass_name), pattern); + }; + } else { + should_dump_pass = [](string_view) { return false; }; + } + + // Output dirs "sponge" and "test_undeclared_outputs_dir" (case-insensitive) + // have a special meaning: Dump into the directory specified by the + // environment variable TEST_UNDECLARED_OUTPUTS_DIR. + string dump_to_lower = absl::AsciiStrToLower(opts.xla_dump_to()); + if (dump_to_lower == "sponge" || + dump_to_lower == "test_undeclared_outputs_dir") { + const char* dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); + if (dir != nullptr) { + dump_to = dir; + } else { + LOG(ERROR) << "--xla_dump_to=" << opts.xla_dump_to() + << ", but environment variable TEST_UNDECLARED_OUTPUTS_DIR " + "is not set, so cannot dump anywhere."; + should_dump_module = [](string_view) { return false; }; + should_dump_pass = [](string_view) { return false; }; + } + } + } + + bool dumping_to_stdout() const { return dump_to == "-"; } + + string dump_to; + std::function should_dump_module; + std::function should_dump_pass; + + // dump_ir isn't present here because this file is mostly concerned with + // dumping HLO. + bool dump_as_text; + bool dump_as_proto; + bool dump_as_dot; + bool dump_as_html; + bool dump_as_url; + bool dump_snapshots; +}; + +string FilenameFor(const HloModule& module, string_view suffix) { + return StrFormat("module_%04d.%s", module.unique_id(), suffix); +} + +void DumpToFileInDirImpl(string_view filename, string_view contents, + const CanonicalDebugOptions& opts) { + if (opts.dumping_to_stdout()) { + LOG(ERROR) << "Refusing to write " << filename + << " to stdout. 
Pass --xla_dump_to= to write to a file."; + return; + } + + const string& dir = opts.dump_to; + VLOG(1) << "Dumping " << filename << " to " << dir; + + tensorflow::Env* env = tensorflow::Env::Default(); + // Two threads can race to observe the absence of the dump directory and + // simultaneously try to create it, causing the "losing" thread to get a + // "directory already exists" error. We can work around this by checking + // again whether the dir exists. + if (!env->IsDirectory(dir).ok()) { + auto status = env->RecursivelyCreateDir(dir); + if (!status.ok() && !env->IsDirectory(dir).ok()) { + LOG(ERROR) << "Could not create directory " << dir + << " for dumping XLA debug data: " << status; + return; + } + } + + string file_path = + tensorflow::io::JoinPath(dir, SanitizeFileName(string(filename))); + auto status = tensorflow::WriteStringToFile(env, file_path, contents); + if (!status.ok()) { + LOG(ERROR) << "Could not write XLA debug data to " << file_path << ": " + << status; + } +} + +void DumpToFileInDirOrStdoutImpl(string_view filename, string_view contents, + const CanonicalDebugOptions& opts) { + // Dump to stdout if that's called for. + if (opts.dumping_to_stdout()) { + std::cout << "*** Begin " << filename << " ***\n" + << contents << "\n*** End " << filename << " ***" << std::endl; + return; + } + + // Otherwise, dump to a file. + DumpToFileInDirImpl(filename, contents, opts); +} + +void DumpHloModuleImpl(const HloModule& module, + const BufferAssignment* buffer_assn, + const HloExecutionProfile* profile, string_view suffix, + const CanonicalDebugOptions& opts) { + string filename = FilenameFor(module, suffix); + + if (opts.dump_as_text) { + DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"), module.ToString(), + opts); + if (buffer_assn) { + DumpToFileInDirOrStdoutImpl(StrCat(filename, "-buffer-assignment.txt"), + buffer_assn->ToString(), opts); + } + } + + if (opts.dump_as_proto) { + HloProto module_proto = + buffer_assn ? MakeHloProto(module, *buffer_assn) : MakeHloProto(module); + string pb; + if (!tensorflow::SerializeToStringDeterministic(module_proto, &pb)) { + pb = "Failed to serialize HLO module proto."; + } + DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts); + } + + auto render_graph = [&](RenderedGraphFormat format) { + StatusOr rendered_graph = RenderGraph( + *module.entry_computation(), + /*label=*/filename, module.config().debug_options(), format, profile); + if (rendered_graph.ok()) { + return std::move(rendered_graph).ValueOrDie(); + } + return StrFormat("Error rendering graph: %s", + rendered_graph.status().ToString()); + }; + + if (opts.dump_as_dot) { + DumpToFileInDirImpl(StrFormat("%s.dot", filename), + render_graph(RenderedGraphFormat::kDot), opts); + } + + if (opts.dump_as_html) { + DumpToFileInDirImpl(StrFormat("%s.html", filename), + render_graph(RenderedGraphFormat::kHtml), opts); + } + + // Special case for rendering graphs as URLs. We'll dump them to a file + // because why not, but we always log them to stdout as well. + if (opts.dump_as_url) { + string url = render_graph(RenderedGraphFormat::kUrl); + std::cout << filename << " --> " << url << std::endl; + if (!opts.dumping_to_stdout()) { + DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts); + } + } +} + +static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + +// Maps a module's unique ID to a counter indicating how many times we've dumped +// this module during the compilation pipeline. This lets us keep the filenames +// ordered nicely. 
+// +// Entries added here leak forever; we have no way to GC them when a module +// dies. But we only add an entry if dumping is enabled for this module, and +// dumping a module leaks buffer space in stdout or bytes on disk *way* faster +// than this hashtable leaks memory. +static auto& module_id_to_step_number GUARDED_BY(mu) = + *new absl::flat_hash_map(); + +} // namespace + +void DumpToFileInDir(const HloModule& module, string_view suffix, + string_view contents) { + DumpToFileInDirImpl(FilenameFor(module, suffix), contents, + CanonicalDebugOptions(module.config().debug_options())); +} + +void DumpToFileInDirOrStdout(const HloModule& module, string_view suffix, + string_view contents) { + DumpToFileInDirOrStdoutImpl( + FilenameFor(module, suffix), contents, + CanonicalDebugOptions(module.config().debug_options())); +} + +void DumpHloModuleIfEnabled(const HloModule& module, string_view name) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (opts.should_dump_module(module.name())) { + DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr, + name, opts); + } +} +void DumpHloModuleIfEnabled(const HloModule& module, + const BufferAssignment& buffer_assn, + string_view name) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (opts.should_dump_module(module.name())) { + DumpHloModuleImpl(module, &buffer_assn, /*profile=*/nullptr, name, opts); + } +} + +void DumpHloModuleIfEnabled(const HloModule& module, + const HloExecutionProfile& profile, + string_view name) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (opts.should_dump_module(module.name())) { + DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, &profile, name, opts); + } +} + +bool DumpingEnabledForHloModule(string_view hlo_module_name, + const DebugOptions& opts) { + return CanonicalDebugOptions(opts).should_dump_module(hlo_module_name); +} + +bool DumpingToStdout(const DebugOptions& opts) { + return CanonicalDebugOptions(opts).dumping_to_stdout(); +} + +void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name, + string_view before_pass_name, + string_view after_pass_name, + const HloModule& module) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (!opts.should_dump_module(module.name())) { + return; + } + + if (!opts.should_dump_pass(before_pass_name) && + !opts.should_dump_pass(after_pass_name)) { + return; + } + + int64 step_number; + { + tensorflow::mutex_lock lock(mu); + step_number = module_id_to_step_number[module.unique_id()]++; + } + + string filename_suffix = + StrFormat("%04d.%s.after_%s.before_%s", step_number, pipeline_name, + after_pass_name, before_pass_name); + DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr, + filename_suffix, opts); +} + +void DumpHloModuleDuringPassIfEnabled(string_view pass_name, + string_view step_name, + const HloModule& module) { + CanonicalDebugOptions opts(module.config().debug_options()); + if (!opts.should_dump_module(module.name()) || + !opts.should_dump_pass(pass_name)) { + return; + } + + int64 step_number; + { + tensorflow::mutex_lock lock(mu); + step_number = module_id_to_step_number[module.unique_id()]++; + } + + string filename_suffix = + StrFormat("%04d.%s.%s", step_number, pass_name, step_name); + DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr, + filename_suffix, opts); +} + +void DumpHloSnapshotIfEnabled(const HloModule& module, + const HloSnapshot& snapshot) { + CanonicalDebugOptions opts(module.config().debug_options()); + 
if (!opts.should_dump_module(module.name()) || !opts.dump_snapshots) { + return; + } + int64 execution_count; + { + static auto& module_id_to_execution_count GUARDED_BY(mu) = + *new absl::flat_hash_map(); + tensorflow::mutex_lock lock(mu); + execution_count = module_id_to_execution_count[module.unique_id()]++; + } + string filename = + StrCat(FilenameFor(module, StrFormat("execution_%04d", execution_count)), + ".hlo_snapshot.pb"); + if (opts.dumping_to_stdout()) { + LOG(ERROR) << "Refusing to write HLO snapshot proto for " << filename + << " to stdout. Pass --xla_dump_to= to write to a file."; + return; + } + string pb; + if (!tensorflow::SerializeToStringDeterministic(snapshot, &pb)) { + LOG(ERROR) << "Failed to serialize HLO snapshot proto " << filename; + } + DumpToFileInDirImpl(filename, pb, opts); +} + +void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, + const DebugOptions& opts) { + CanonicalDebugOptions canonical_opts(opts); + string name = snapshot.hlo().hlo_module().name(); + if (!canonical_opts.should_dump_module(name) || + !canonical_opts.dump_snapshots) { + return; + } + + // We don't have a unique id for an HloSnapshot, so in this overload we just + // have to use its name. + int64 execution_count; + { + static auto& module_name_to_execution_count GUARDED_BY(mu) = + *new absl::flat_hash_map(); + tensorflow::mutex_lock lock(mu); + execution_count = module_name_to_execution_count[name]++; + } + string filename = StrFormat("module_%s.execution_%04d.hlo_snapshot.pb", name, + execution_count); + if (canonical_opts.dumping_to_stdout()) { + LOG(ERROR) << "Refusing to write HLO snapshot proto for " << filename + << " to stdout. Pass --xla_dump_to= to write to a file."; + return; + } + string pb; + if (!tensorflow::SerializeToStringDeterministic(snapshot, &pb)) { + LOG(ERROR) << "Failed to serialize HLO snapshot proto " << filename; + } + DumpToFileInDirImpl(filename, pb, canonical_opts); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dump.h b/tensorflow/compiler/xla/service/dump.h new file mode 100644 index 00000000000..6edc9b28dde --- /dev/null +++ b/tensorflow/compiler/xla/service/dump.h @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DUMP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DUMP_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/xla.pb.h" + +// Consolidated utilities for logging information during compilation, usually +// based on the options specified in the DebugOptions proto. +// +// Most functions here take an HloModule and read the DebugOptions from the +// module's config. 
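A hypothetical call site for this API, sketched from the declarations that follow and from the CanonicalDebugOptions defaulting in dump.cc above; the pass name, file suffix, and helper function are made up for illustration.

```c++
#include <string>

#include "tensorflow/compiler/xla/service/dump.h"

namespace xla {

void ExampleDumpCallSite(const HloModule& module,
                         const std::string& backend_ir_text) {
  // Skip everything if dumping is disabled for this module; this lets callers
  // avoid building expensive dump strings in the first place.
  if (!DumpingEnabledForHloModule(module)) {
    return;
  }

  // Writes module_<unique_id>.my-backend-ir.txt under --xla_dump_to, or
  // prints it when dumping to stdout ("-").
  DumpToFileInDirOrStdout(module, /*file_suffix=*/"my-backend-ir.txt",
                          backend_ir_text);

  // Dumps the HLO itself, in whatever formats the module's DebugOptions
  // request.
  DumpHloModuleIfEnabled(module, /*name=*/"after-my-backend-pass");

  // Per-pass dump; only emitted when --xla_dump_hlo_pass_re matches
  // "my-pass".
  DumpHloModuleDuringPassIfEnabled("my-pass", "before-rewriting", module);
}

}  // namespace xla
```

With the defaulting in dump.cc above, passing only --xla_dump_to=/some/dir dumps every module as text into that directory; --xla_dump_hlo_module_re restricts which modules are dumped, and --xla_dump_hlo_pass_re additionally enables the per-pass dumps for matching passes.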
+ +namespace xla { + +class BufferAssignment; +class HloExecutionProfile; +class HloSnapshot; + +// Writes the given string to a file in the xla_dump_to directory specified by +// module's DebugOptions. +// +// If module doesn't have an xla_dump_to directory, does nothing. +void DumpToFileInDir(const HloModule& module, absl::string_view file_suffix, + absl::string_view contents); + +// Like DumpToFileInDir, except if module doesn't have an xla_dump_to directory +// specified, or if that directory is equal to "-", writes to stdout instead. +void DumpToFileInDirOrStdout(const HloModule& module, + absl::string_view file_suffix, + absl::string_view contents); + +// Dumps the given HLO module if dumping is enabled for the module. Exactly +// where and in what formats it's dumped is determined by the module's config. +// +// If you pass an HloExecutionProfile, note that currently only DOT-based output +// formats (i.e. --xla_dump_as_{dot,html,url}) are able to incorporate it into +// their output. Other formats will just ignore the profile. +void DumpHloModuleIfEnabled(const HloModule& module, absl::string_view name); +void DumpHloModuleIfEnabled(const HloModule& module, + const BufferAssignment& buffer_assn, + absl::string_view name); +void DumpHloModuleIfEnabled(const HloModule& module, + const HloExecutionProfile& profile, + absl::string_view name); + +// Dumps the given HLO module after running one HLO pass and before running +// another, if that's enabled. +void DumpHloModuleBetweenPassesIfEnabled(absl::string_view pipeline_name, + absl::string_view before_pass_name, + absl::string_view after_pass_name, + const HloModule& module); + +// Dumps the given HLO module during the given HLO pass, if that's enabled. +// +// "step" is a human-readable description of where we are in the middle of this +// pass. For example, "before-assigning-layouts". +void DumpHloModuleDuringPassIfEnabled(absl::string_view pass_name, + absl::string_view step, + const HloModule& module); + +// Dumps the given HloSnapshot to the module's xla_dump_dir, if this is enabled. +// +// Prefer the first overload below, as this will give filenames that are +// consistent with the other methods here. The second overload (which doesn't +// take an HloModule) is useful in the cases when you're dumping an HloSnapshot +// and simply don't have an HloModule. +void DumpHloSnapshotIfEnabled(const HloModule& module, + const HloSnapshot& snapshot); +void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, + const DebugOptions& opts); + +// Returns true if we should dump data for an HloModule. This is useful if you +// want to check if DumpToFileInDir{,OrStdout} will do anything before +// generating an expensive string. +bool DumpingEnabledForHloModule(absl::string_view hlo_module_name, + const DebugOptions& opts); +inline bool DumpingEnabledForHloModule(const HloModule& module) { + return DumpingEnabledForHloModule(module.name(), + module.config().debug_options()); +} + +// Returns true if DumpToFileInDirOrStdout and DumpHloModuleIfEnabled will write +// to stdout, rather than to a file on disk. +// +// This is useful if you want to do something different when writing to stdout. +// For example, maybe you have (almost-)duplicate data that you wouldn't mind +// writing to two files, but you don't want to print twice. 
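+//
+// For example (a sketch; `module` and `contents` stand for whatever the
+// caller already has in hand):
+//
+//   if (!DumpingToStdout(module.config().debug_options())) {
+//     // A second copy on disk is fine, but don't print the data twice.
+//     DumpToFileInDir(module, /*file_suffix=*/"extra_copy.txt", contents);
+//   }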
+bool DumpingToStdout(const DebugOptions& opts); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DUMP_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index a4963427ec2..b2563f9949e 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -14,24 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" + +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { -namespace { -bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { - return window_dimension.size() == 1 && window_dimension.stride() == 1 && - window_dimension.padding_low() == 0 && - window_dimension.padding_high() == 0 && - window_dimension.window_dilation() == 1 && - window_dimension.base_dilation() == 1; -} -} // namespace - class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { public: explicit DynamicDimensionInferenceVisitor( @@ -84,6 +77,14 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleSlice(HloInstruction* hlo) override; + Status HandleDynamicSlice(HloInstruction* hlo) override; + + Status HandleDynamicUpdateSlice(HloInstruction* hlo) override; + + Status HandleGather(HloInstruction* hlo) override; + + Status HandleScatter(HloInstruction* hlo) override; + private: using OperandDynamicDimensionFn = std::functionoperand_count(); + bool is_variadic_reduce = operand_count > 2; CHECK_EQ(operand_count % 2, 0); if (operand_index >= operand_count / 2) { // Init values doesn't have dynamic size. return Status::OK(); } if ((absl::c_count(reduce->dimensions(), dimension) != 0)) { - // Dimension is to be reduce, stop tracing. + // Dimension is to be reduced, stop tracing. return Status::OK(); } @@ -192,8 +194,21 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { int64 dimensions_not_reduced_count = 0; for (int i = 0; i < operand->shape().rank(); ++i) { if (dimension == i) { - parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, - dynamic_size); + ShapeIndex result_index = {}; + + if (is_variadic_reduce) { + // The dimensions of all data operands of a variadic reduce have + // to be the same. This means that if one operand of variadic + // reduce has a dynamic dimension, we set all outputs to use the + // same dynamic size in corresponding dimensions. + for (int64 i = 0; i < operand_count / 2; ++i) { + parent_->SetDynamicSize( + reduce, {i}, dimensions_not_reduced_count, dynamic_size); + } + } else { + parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, + dynamic_size); + } return Status::OK(); } @@ -351,6 +366,39 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { return Status::OK(); } } + // A dimension modifying reshape is supported as long as it is the most + // major one and it is combining with other non-dynamic dimensions. 
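+  // For example, reshaping an f32[32, 10, 4] operand with a dynamic
+  // dimension 0 into f32[320, 4] multiplies the dynamic size of dimension 0
+  // by 10, while reshaping f32[320, 4] back into f32[32, 10, 4] divides it
+  // by 10.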
+ const int64 output_most_major = reshape->shape().dimensions(0); + const int64 input_most_major = operand->shape().dimensions(0); + if (dimension == 0) { + if (output_most_major > input_most_major) { + const int64 multiplier = + reshape->shape().dimensions(0) / operand->shape().dimensions(0); + HloInstruction* multiplier_hlo = + hlo->parent()->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(multiplier))); + + HloInstruction* new_dynamic_size = + hlo->parent()->AddInstruction(HloInstruction::CreateBinary( + dynamic_size->shape(), HloOpcode::kMultiply, dynamic_size, + multiplier_hlo)); + parent_->SetDynamicSize(reshape, {}, 0, new_dynamic_size); + return Status::OK(); + } else if (output_most_major < input_most_major) { + const int64 divisor = + operand->shape().dimensions(0) / reshape->shape().dimensions(0); + HloInstruction* divisor_hlo = + hlo->parent()->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(divisor))); + + HloInstruction* new_dynamic_size = + hlo->parent()->AddInstruction(HloInstruction::CreateBinary( + dynamic_size->shape(), HloOpcode::kDivide, dynamic_size, + divisor_hlo)); + parent_->SetDynamicSize(reshape, {}, 0, new_dynamic_size); + return Status::OK(); + } + } return Unimplemented( "Dynamic Reshape on modified dimensions is yet not supported: %s", reshape->ToString()); @@ -366,7 +414,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow( const WindowDimension& window_dimension = reduce_window->window().dimensions(dimension); - if (!IsTrivialWindowDimension(window_dimension)) { + if (!window_util::IsTrivialWindowDimension(window_dimension)) { return Unimplemented( "Dynamic Spatial reduce window is not supported: %s", reduce_window->ToString()); @@ -387,7 +435,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( const WindowDimension& window_dimension = select_and_scatter->window().dimensions(dimension); - if (!IsTrivialWindowDimension(window_dimension)) { + if (!window_util::IsTrivialWindowDimension(window_dimension)) { return Unimplemented( "Dynamic Spatial select and scatter is not supported: %s", select_and_scatter->ToString()); @@ -420,6 +468,123 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) { }); } +Status DynamicDimensionInferenceVisitor::HandleDynamicSlice( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64 dimension, + int64 /*operand_index*/, HloInstruction* dynamic_size) { + if (hlo->shape().dimensions(dimension) != + hlo->operand(0)->shape().dimensions(dimension)) { + // Slicing a single element out kills the dynamic dimension. 
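+          // For example, dynamic-slicing a [5, 7] operand whose dimension 0
+          // is dynamic down to a [1, 1] result produces a fully static
+          // output (see DynamicSliceSingleElementTest).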
+ if (hlo->shape().dimensions(dimension) == 1) { + return Status::OK(); + } + return Unimplemented( + "Dynamic dimension propagation on DynamicSlice where a partial " + "dimension is selected %s", + hlo->ToString()); + } + + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, + [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension, + int64 /*operand_index*/, HloInstruction* dynamic_size) { + if (hlo->shape().dimensions(dimension) != + hlo->operand(0)->shape().dimensions(dimension)) { + return Unimplemented( + "Dynamic dimension propagation on DynamicSlice where a partial " + "dimension is selected %s", + hlo->ToString()); + } + + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, + int64 input_dynamic_dimension, int64 operand_index, + HloInstruction* dynamic_size) { + if (operand_index != 1) { + return Unimplemented( + "Detects a dynamic dimension on the data input of gather, which " + "is not suported: %s", + hlo->ToString()); + } + // A mapping from output to input batch dim number. -1 means not a batch + // dimension. + int64 indices_rank = hlo->operand(1)->shape().rank(); + int64 output_rank = hlo->shape().rank(); + const GatherDimensionNumbers& gather_dims = + hlo->gather_dimension_numbers(); + // indices_dim is an iterator over indices dimensions. + int64 indices_dim = 0; + // Find the corresponding batch dimension in the output. + for (int64 output_dim = 0; output_dim < output_rank; ++output_dim) { + if (!absl::c_linear_search(gather_dims.offset_dims(), output_dim)) { + // Skips index vector dimension. + if (indices_dim == gather_dims.index_vector_dim()) { + indices_dim++; + } + if (indices_dim++ == input_dynamic_dimension) { + parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size); + return Status::OK(); + } + } + } + CHECK(indices_dim == indices_rank); + + return Unimplemented( + "Detects a non-batch dynamic dimension of gather, " + "which is not supported: %s", + hlo->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, + [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension, + int64 operand_index, HloInstruction* operand_dynamic_size) { + if (operand_index == 0) { + return Unimplemented( + "Detects a dynamic dimension on the data input of scatter, which " + "is not suported: %s", + hlo->ToString()); + } + + const ScatterDimensionNumbers& scatter_dims = + hlo->scatter_dimension_numbers(); + if (operand_index == 1) { + parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size); + return Status::OK(); + } + + if (operand_index == 2 && + absl::c_linear_search(scatter_dims.update_window_dims(), + dimension)) { + return Unimplemented( + "Dynamic dimension of update window dims is not supported " + "is not suported: %s", + hlo->ToString()); + } + // The dynamic dimension is collapsed and won't show up in the output. + // Do nothing here. + return Status::OK(); + }); +} + Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { // While loop is handled by passing dynamic size hlos as parameters into the // hlo while loop. 
This is done by replacing the original while with a new diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index d0f2998328f..a77aacaaa96 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -83,6 +83,12 @@ class DynamicDimensionInference { // by a scalar instruction `size`. void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, HloInstruction* size) { + Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index); + CHECK(!subshape.IsTuple()) + << "Can't set a tuple shape to dynamic dimension"; + CHECK(dim < subshape.rank() && dim >= 0) + << "Asked to set invalid dynamic dimension. Shape: " + << subshape.ToString() << ", Dimension: " << dim; dynamic_mapping_.try_emplace(DynamicDimension{inst, index, dim}, size); auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); iter.first->second.emplace(DynamicDimension{inst, index, dim}); diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index b3d93737427..a18c0176153 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -43,7 +43,6 @@ class DynamicDimensionInferenceTest : public HloTestBase { } Status RunInference() { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, DynamicDimensionInference::Run(module_.get())); @@ -62,20 +61,40 @@ class DynamicDimensionInferenceTest : public HloTestBase { return module_->AddEmbeddedComputation(embedded_builder.Build()); } + HloComputation* GetAddTuple() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto lhs_1 = + embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "lhs.1")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {}), "rhs")); + auto rhs_1 = + embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 3, ShapeUtil::MakeShape(F32, {}), "rhs.1")); + auto add = embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + auto add_1 = embedded_builder.AddInstruction(HloInstruction::CreateBinary( + lhs->shape(), HloOpcode::kAdd, lhs_1, rhs_1)); + embedded_builder.AddInstruction(HloInstruction::CreateTuple({add, add_1})); + return module_->AddEmbeddedComputation(embedded_builder.Build()); + } + HloComputation* GetGe() { auto embedded_builder = HloComputation::Builder("ge"); auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "lhs")); auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( 1, ShapeUtil::MakeShape(F32, {}), "rhs")); - embedded_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs)); + embedded_builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe)); return module_->AddEmbeddedComputation(embedded_builder.Build()); } std::unique_ptr module_; std::unique_ptr inference_; - const Shape scalar_shape_ = 
ShapeUtil::MakeShape(S32, {}); + const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {}); }; TEST_F(DynamicDimensionInferenceTest, ParamTest) { @@ -88,6 +107,8 @@ TEST_F(DynamicDimensionInferenceTest, ParamTest) { HloInstruction::CreateParameter(1, scalar_shape_, "param")); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + // Set up dynamic parameter binding. TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( DynamicParameterBinding::DynamicParameter{1, {}}, @@ -112,6 +133,7 @@ TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) { DynamicParameterBinding::DynamicParameter{0, {1}}, DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), op::GetTupleElement(param, 1)); @@ -137,6 +159,7 @@ TEST_F(DynamicDimensionInferenceTest, GetTupleElement) { DynamicParameterBinding::DynamicParameter{0, {1}}, DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), op::GetTupleElement(param, 1)); @@ -167,6 +190,7 @@ TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(negate, {}, 1), size_param); } @@ -197,6 +221,7 @@ TEST_F(DynamicDimensionInferenceTest, ReduceTestI) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size_param); } @@ -228,11 +253,53 @@ TEST_F(DynamicDimensionInferenceTest, ReduceTestII) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 2})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size_param); EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr); } +TEST_F(DynamicDimensionInferenceTest, VariadicReduce) { + // Handle variadic reduce where output is a tuple. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto data_param_dynamic = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto data_param_static = builder.AddInstruction( + HloInstruction::CreateParameter(1, input_shape, "data_param.2")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(2, scalar_shape_, "size_param")); + + // Set up dynamic parameter binding. 
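+  // Parameter 2 (a scalar) holds the runtime size of dimension 2 of
+  // parameter 0, the dynamic data operand.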
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + auto dynamic_negate = builder.AddInstruction(HloInstruction::CreateUnary( + input_shape, HloOpcode::kNegate, data_param_dynamic)); + + auto static_negate = builder.AddInstruction(HloInstruction::CreateUnary( + input_shape, HloOpcode::kNegate, data_param_static)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeTupleShape({reduce_shape, reduce_shape}), + {dynamic_negate, static_negate}, {init, init}, {1}, GetAddTuple())); + + module_->AddEntryComputation(builder.Build()); + + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {0}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {1}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {0}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {1}, 0), nullptr); +} + TEST_F(DynamicDimensionInferenceTest, DotTest) { auto builder = HloComputation::Builder(TestName()); constexpr int xdim = 3; @@ -271,6 +338,7 @@ TEST_F(DynamicDimensionInferenceTest, DotTest) { DynamicParameterBinding::DynamicParameter{2, {}}, DynamicParameterBinding::DynamicDimension{1, {}, 0})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param); EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); @@ -319,6 +387,7 @@ TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { DynamicParameterBinding::DynamicParameter{2, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param); EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr); @@ -356,6 +425,7 @@ TEST_F(DynamicDimensionInferenceTest, TransposeTest) { DynamicParameterBinding::DynamicParameter{3, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 2})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3); EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2); @@ -386,6 +456,7 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 3})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr); EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), size_param); @@ -395,6 +466,63 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 5), nullptr); } +TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) { + // Test the ability to trace dimension combining. 
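+  // Reshapes a [32, 10, 4] input with a dynamic dimension 0 into [320, 4];
+  // the inferred dynamic size of output dimension 0 should be the bound
+  // size parameter multiplied by 10.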
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {32, 10, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {320, 4}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + SCOPED_TRACE(module_->ToString()); + Status status = RunInference(); + EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr); + const Literal& multiplier = + inference_->GetDynamicSize(reshape, {}, 0)->operand(1)->literal(); + LiteralTestUtil::ExpectR0Equal(10, multiplier); +} + +TEST_F(DynamicDimensionInferenceTest, GatherTest) { + const string hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[20,10]{1,0} parameter(0) + indices = s32[32,20] parameter(1) + dynamic_size = s32[] parameter(2) + ROOT gather = f32[32,10,10]{2,1,0} gather(%operand, %indices), + offset_dims={2}, + collapsed_slice_dims={0}, + start_index_map={0}, + index_vector_dim=2, + slice_sizes={1,10} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize( + module_->entry_computation()->root_instruction(), {}, 0), + module_->entry_computation()->parameter_instruction(2)); +} + TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) { // Test the ability to trace unmodified reshape dimensions. 
auto builder = HloComputation::Builder(TestName()); @@ -415,6 +543,7 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 1})); + SCOPED_TRACE(module_->ToString()); Status status = RunInference(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); } @@ -439,6 +568,7 @@ TEST_F(DynamicDimensionInferenceTest, BroadcastTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 0})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr); EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), size_param); @@ -580,6 +710,7 @@ TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {}, 0})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); } @@ -633,6 +764,7 @@ TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{2, {}, 0})); + SCOPED_TRACE(module_->ToString()); TF_ASSERT_OK(RunInference()); EXPECT_EQ(inference_->GetDynamicSize(sns, {}, 0), size_param); } @@ -659,5 +791,63 @@ TEST_F(DynamicDimensionInferenceTest, SliceTest) { EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 1), size_param); } +TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) { + auto builder = HloComputation::Builder(TestName()); + + auto data_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + std::vector params; + for (int i = 0; i < 2; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + } + + auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(F32, {5, 1}), data_param, params, + /*slice_sizes=*/{5, 1})); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) { + // Slicing out a single element from a dynamic dimension terminates the + // dynamic dimension. + auto builder = HloComputation::Builder(TestName()); + + auto data_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + std::vector params; + for (int i = 0; i < 2; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + } + + auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(F32, {1, 1}), data_param, params, + /*slice_sizes=*/{1, 1})); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. 
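+  // Parameter 1 holds the runtime size of dimension 0 of parameter 0; since
+  // the slice output is [1, 1], no dynamic dimension should be inferred.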
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), nullptr); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index a982dad95c7..95405cd6600 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" - #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -28,18 +27,19 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" - #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace { -// ChooseIdentityValue looks at the instruction and returns a identity value -// which, when padded, doesn't change the result of the instruction. +// ChooseIdentityValue looks at the instruction's operand, returns a +// identity value which, when padded, doesn't change the result of the +// instruction. // // nullopt is returned if padding doesn't need to be reset. -StatusOr ChooseIdentityValue(HloInstruction* inst) { +StatusOr ChooseIdentityValue(HloInstruction* inst, + int64 operand_number) { HloComputation* comp = inst->parent(); // Padding on elementwise operation doesn't affect the result of the effective // data. @@ -48,7 +48,14 @@ StatusOr ChooseIdentityValue(HloInstruction* inst) { } switch (inst->opcode()) { - case HloOpcode::kReduce: + case HloOpcode::kReduce: { + TF_RET_CHECK(operand_number < inst->operand_count() / 2) + << "Only data operand with dynamic dimension is valid."; + // Variadic reduce has different init value for different operand, given a + // data operand number, find the init value index. + int64 init_value_index = inst->operand_count() / 2 + operand_number; + return inst->mutable_operand(init_value_index); + } case HloOpcode::kReduceWindow: { // Because of the way we do reduce, we already require the `init` operand // of hlo reduce instruction to be identity value. 
Here we reuse the @@ -72,6 +79,10 @@ StatusOr ChooseIdentityValue(HloInstruction* inst) { return inst->mutable_operand(2); } case HloOpcode::kParameter: + case HloOpcode::kGather: + case HloOpcode::kScatter: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kGetDimensionSize: case HloOpcode::kReshape: case HloOpcode::kTuple: @@ -81,7 +92,7 @@ StatusOr ChooseIdentityValue(HloInstruction* inst) { case HloOpcode::kSlice: return nullptr; default: - return UnimplementedStrCat("Unimplimented padding for instruction: ", + return UnimplementedStrCat("Unimplemented padding for instruction: ", inst->ToString()); } } @@ -133,7 +144,7 @@ StatusOr DynamicPadder::Run(HloModule* module) { } TF_ASSIGN_OR_RETURN(HloInstruction * identity_value, - ChooseIdentityValue(inst)); + ChooseIdentityValue(inst, operand_num)); if (identity_value == nullptr) { continue; } @@ -160,9 +171,10 @@ StatusOr DynamicPadder::Run(HloModule* module) { HloInstruction* broadcasted_effective_size = computation->AddInstruction(HloInstruction::CreateBroadcast( mask_shape, dynamic_size, {})); - HloInstruction* pred = computation->AddInstruction( - HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota, - broadcasted_effective_size)); + HloInstruction* pred = + computation->AddInstruction(HloInstruction::CreateCompare( + pred_shape, iota, broadcasted_effective_size, + ComparisonDirection::kLt)); HloInstruction* broadcasted_identity_value = computation->AddInstruction(HloInstruction::CreateBroadcast( diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index fda806bbf81..2963deaa317 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -43,10 +43,7 @@ class DynamicPadderTest : public HloTestBase { DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } StatusOr RunPadder() { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before padder"); - DynamicPadder padder; - return padder.Run(module_.get()); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index abd8ead52ab..664fdcaebb0 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -49,7 +49,6 @@ limitations under the License. 
namespace xla { using absl::StrCat; -using llvm_ir::AsStringRef; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; @@ -208,10 +207,8 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) { - if (op->opcode() == HloOpcode::kCopy) { - return operand_value; - } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || - op->operand(0)->shape().element_type() == PRED) { + if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + op->operand(0)->shape().element_type() == PRED) { return EmitIntegerUnaryOp(op, operand_value); } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { return EmitComplexUnaryOp(op, operand_value); @@ -329,6 +326,11 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } return Unimplemented("unary op Not is not defined for type '%d'", type); } + case HloOpcode::kPopulationCount: { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop, + {operand_value}, + {operand_value->getType()}, b_); + } default: return Unimplemented("unary integer op '%s'", HloOpcodeString(op->opcode())); @@ -440,7 +442,9 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( {operand_value}, {operand_value->getType()}, b_); case HloOpcode::kRoundNearestAfz: - return EmitRoundNearestAfz(op->shape().element_type(), operand_value); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kSign: { auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); @@ -721,25 +725,28 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( // We use ordered comparisons for everything except kNe, where we use an // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. 
- case HloOpcode::kEq: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, - rhs_value, b_); - case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, - rhs_value, b_); - case HloOpcode::kLt: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, - rhs_value, b_); - case HloOpcode::kGt: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, - rhs_value, b_); - case HloOpcode::kLe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, - rhs_value, b_); - case HloOpcode::kGe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, - rhs_value, b_); - + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, + rhs_value, b_); + case ComparisonDirection::kNe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLt: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, + rhs_value, b_); + case ComparisonDirection::kGt: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kGe: + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, + rhs_value, b_); + } + } case HloOpcode::kMaximum: return EmitFloatMax(lhs_value, rhs_value); case HloOpcode::kMinimum: @@ -841,21 +848,28 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // We use ordered comparisons for everything except kNe, where we use an // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. 
- case HloOpcode::kEq: - return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); - case HloOpcode::kNe: - return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); - + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); + case ComparisonDirection::kNe: + return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); + default: + return Unimplemented( + "complex comparison '%s'", + ComparisonDirectionToString(op->comparison_direction())); + } + } case HloOpcode::kPower: { auto a = EmitExtractReal(lhs_value); auto b = EmitExtractImag(lhs_value); @@ -1127,12 +1141,6 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, return Select(x_is_small, for_small_x, for_large_x); } -StatusOr ElementalIrEmitter::EmitRoundNearestAfz( - PrimitiveType /*prim_type*/, llvm::Value* value) { - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round, {value}, - {value->getType()}, b_); -} - StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) { @@ -1280,28 +1288,32 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: return EmitIntegerRemainder(lhs_value, rhs_value, is_signed); - case HloOpcode::kEq: - return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, - rhs_value, b_); - case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, - rhs_value, b_); - case HloOpcode::kLt: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, - lhs_value, rhs_value, b_); - case HloOpcode::kGt: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, - lhs_value, rhs_value, b_); - case HloOpcode::kLe: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, - lhs_value, rhs_value, b_); - case HloOpcode::kGe: - return llvm_ir::EmitComparison( - is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, - lhs_value, rhs_value, b_); + case HloOpcode::kCompare: { + switch (op->comparison_direction()) { + case ComparisonDirection::kEq: + return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, + rhs_value, b_); + case ComparisonDirection::kNe: + return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, + rhs_value, b_); + case ComparisonDirection::kLt: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, + lhs_value, rhs_value, b_); + case ComparisonDirection::kGt: + return llvm_ir::EmitComparison( + is_signed ? 
llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, + lhs_value, rhs_value, b_); + case ComparisonDirection::kLe: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, + lhs_value, rhs_value, b_); + case ComparisonDirection::kGe: + return llvm_ir::EmitComparison( + is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, + lhs_value, rhs_value, b_); + } + } case HloOpcode::kMinimum: return EmitIntegralMin(lhs_value, rhs_value, is_signed); case HloOpcode::kMaximum: @@ -1354,46 +1366,6 @@ llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, lhs_value, rhs_value); } -llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( - const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) { - CHECK(hlo.IsElementwise()) - << "HLO " << hlo.ToString() << " is not elementwise."; - - const Shape& operand_shape = hlo.operand(operand_no)->shape(); - // If the operand is scalar, the source index is always {}. - if (ShapeUtil::IsScalar(operand_shape)) { - return llvm_ir::IrArray::Index(target_index.GetType()); - } - - // If no implicit broadcast is needed for this operand, returns the target - // index as the source index. - // - // `IrArray::Index` may contain a physical linear which we can propagate to - // our operand only if our layouts match. "only if" is a bit strong since - // e.g. we can still forward the linear index if the operand shape is - // [5,1,1,5]{3,2,1,0} and the HLO shape is[5,1,1,5]{3,1,2,0}, but those cases - // are probably not worth handling here for now. - if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) && - LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) { - return target_index; - } - - // If implicit broadcast is needed, the source dimensions that are broadcast - // have index 0. 
- CHECK_EQ(operand_shape.rank(), hlo.shape().rank()); - llvm_ir::IrArray::Index source_index(target_index.GetType()); - for (int64 i = 0; i < hlo.shape().rank(); ++i) { - if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { - source_index.push_back(target_index[i]); - } else { - CHECK_EQ(1, operand_shape.dimensions(i)); - source_index.push_back(target_index.GetConstantWithIndexType(0)); - } - } - return source_index; -} - StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, @@ -1494,11 +1466,16 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( } } case RNG_NORMAL: { + // Convert uniform x in (0, 1] to normal using formula: + // Normal(x, mu, sigma) = mu + sqrt(2)*sigma*ErfcInv(2x) + // = mu + sqrt(2)*sigma*ErfInv(1-2x) TF_ASSIGN_OR_RETURN( llvm::Value * r, EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), elem_value))); - return FAdd(FMul(r, b_or_sigma), a_or_mean); + return FAdd(FMul(llvm::ConstantFP::get(r->getType(), std::sqrt(2.0)), + FMul(r, b_or_sigma)), + a_or_mean); } default: return InvalidArgument( @@ -1699,14 +1676,11 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); + operand_to_generator.at(hlo->operand(1))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); + operand_to_generator.at(hlo->operand(2))(index)); return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value, on_false_value); } @@ -1716,14 +1690,11 @@ StatusOr ElementalIrEmitter::EmitElementalClamp( const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * min_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); + operand_to_generator.at(hlo->operand(1))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * max_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); + operand_to_generator.at(hlo->operand(2))(index)); PrimitiveType prim_type = hlo->shape().element_type(); if (primitive_util::IsFloatingPointType(prim_type)) { return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); @@ -1756,8 +1727,8 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( exit_block = llvm_ir::CreateBasicBlock( /*insert_before=*/nullptr, IrName(hlo, "merge"), b_); } else { - exit_block = init_block->splitBasicBlock(b_->GetInsertPoint(), - AsStringRef(IrName(hlo, "merge"))); + exit_block = + init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge")); init_block->getTerminator()->eraseFromParent(); } @@ -1803,37 +1774,40 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_); source_index_phis[operand_id] = 
PHI(source_index.GetType(), operand_usage_count[operand_id]); - auto operand_index = source_index; - operand_index[concat_dim] = source_index_phis[operand_id]; + std::vector operand_multi_index = source_index.multidim(); + operand_multi_index[concat_dim] = source_index_phis[operand_id]; // Create the terminator of the block before calling operand generators, // because they require non-degenerate basic blocks. b_->SetInsertPoint(llvm::BranchInst::Create( exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id])); + llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(), + source_index.GetType()); TF_ASSIGN_OR_RETURN(llvm::Value * value, operand_to_generator.at(operand)(operand_index)); output->addIncoming(value, b_->GetInsertBlock()); b_->SetInsertPoint(init_block, saved_insert_point); } + std::vector source_multi_index = source_index.multidim(); for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); auto false_block = llvm_ir::CreateBasicBlock( exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); - auto concat_dim_size = - llvm::ConstantInt::get(source_index[concat_dim]->getType(), - operand->shape().dimensions(concat_dim)); + auto concat_dim_size = source_index.GetConstantWithIndexType( + operand->shape().dimensions(concat_dim)); int64 operand_id = to_unique_operand_id[operand]; - source_index_phis[operand_id]->addIncoming(source_index[concat_dim], + source_index_phis[operand_id]->addIncoming(source_multi_index[concat_dim], b_->GetInsertBlock()); - CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), + CondBr(ICmpULT(source_multi_index[concat_dim], concat_dim_size), emit_operand_blocks[operand_id], false_block); // Subtract the size of the concat dimension of the current operand // from the source index. b_->SetInsertPoint(false_block); - source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size); + source_multi_index[concat_dim] = + Sub(source_multi_index[concat_dim], concat_dim_size); } Unreachable(); @@ -1850,7 +1824,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const int64 rank = input_hlo->shape().rank(); // Use the same index type for all tensor accesses in the same kernel. 
llvm::Type* index_type = index.GetType(); - llvm_ir::IrArray::Index slice_start_index(index_type, rank); + std::vector slice_start_multi_index(rank); for (int64 i = 0; i < rank; ++i) { auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); @@ -1873,17 +1847,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( EmitIntegralMax(index_typed_const(0), start_index_value, is_signed), is_signed); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = start_index_value; + start_index_value->setName(IrName(hlo, StrCat("start_idx", i))); + slice_start_multi_index[i] = start_index_value; } - llvm_ir::IrArray::Index input_index(index_type, rank); + std::vector input_multi_index(rank); for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = Add(slice_start_index[i], index[i]); + input_multi_index[i] = Add(slice_start_multi_index[i], index[i]); } + llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(), + index_type); return operand_to_generator.at(input_hlo)(input_index); } @@ -1905,7 +1880,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( llvm::Type* index_type = index.GetType(); // This is the index into `operand` that holds the element we want to // generate. - IrArray::Index operand_index(index_type); + std::vector operand_multi_index; // First copy in the window indices to operand_index. Also collect a mapping // from operand dimension to output window dimension. Elided window dimensions @@ -1914,26 +1889,29 @@ StatusOr ElementalIrEmitter::EmitElementalGather( for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { - operand_index.push_back(index.GetConstantWithIndexType(0)); + operand_multi_index.push_back(index.GetConstantWithIndexType(0)); } else { int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); operand_to_output_dim[i] = output_window_dim; - operand_index.push_back(index[output_window_dim]); + operand_multi_index.push_back(index[output_window_dim]); } } // This is the index of the index vector in the start_indices tensor. 
- IrArray::Index gather_index_index(index_type); + std::vector gather_index_index_components; { - std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { - gather_index_index.push_back(index[i]); + gather_index_index_components.push_back(index[i]); } } - if (gather_index_index.size() != indices_shape.dimensions_size()) { - gather_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + if (gather_index_index_components.size() != + indices_shape.dimensions_size()) { + gather_index_index_components.insert( + gather_index_index_components.begin() + + dim_numbers.index_vector_dim(), + nullptr); } } @@ -1961,11 +1939,14 @@ StatusOr ElementalIrEmitter::EmitElementalGather( gather_dim_component_extended, is_signed), is_signed); - operand_index[operand_dim] = - Add(operand_index[operand_dim], gather_dim_component_extended_inbound); + operand_multi_index[operand_dim] = + Add(operand_multi_index[operand_dim], + gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { + IrArray::Index gather_index_index(gather_index_index_components, + indices_shape, index_type); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); add_to_operand_index(gather_dim_component, 0); @@ -1973,13 +1954,16 @@ StatusOr ElementalIrEmitter::EmitElementalGather( int64 index_vector_size = indices_shape.dimensions(dim_numbers.index_vector_dim()); for (int64 i = 0; i < index_vector_size; i++) { - gather_index_index[dim_numbers.index_vector_dim()] = + gather_index_index_components[dim_numbers.index_vector_dim()] = index.GetConstantWithIndexType(i); + IrArray::Index gather_index_index(gather_index_index_components, + indices_shape, index_type); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); add_to_operand_index(gather_dim_component, i); } } + IrArray::Index operand_index(operand_multi_index, operand_shape, index_type); return operand_generator(operand_index); } @@ -1992,8 +1976,8 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. 
const int64 rank = input_hlo->shape().rank(); - llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); - llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); + std::vector slice_start_multi_index(rank); + std::vector slice_limit_multi_index(rank); // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' llvm::Value* slice_intersection = b_->getTrue(); @@ -2024,16 +2008,16 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( EmitIntegralMax(index_typed_const(0), start_index_value, is_signed), is_signed); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = start_index_value; - slice_limit_index[i] = Add(slice_start_index[i], update_dim_size); + start_index_value->setName(IrName(hlo, StrCat("start_idx", i))); + slice_start_multi_index[i] = start_index_value; + slice_limit_multi_index[i] = + Add(slice_start_multi_index[i], update_dim_size); slice_intersection = - And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]), + And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]), "slice_intersection"); slice_intersection = - And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]), + And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]), "slice_intersection"); } @@ -2049,10 +2033,12 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Handle true BB (return data from 'update') SetToFirstInsertPoint(if_data.true_block, b_); // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(index.GetType(), rank); + std::vector update_multi_index(rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = Sub(index[i], slice_start_index[i]); + update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]); } + llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(), + index.GetType()); TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); Store(true_value, ret_value_addr); @@ -2071,27 +2057,28 @@ StatusOr ElementalIrEmitter::EmitElementalPad( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& padded_index) { - auto index = padded_index; + std::vector multi_index = padded_index.multidim(); llvm::Value* in_bounds = b_->getTrue(); - for (size_t i = 0; i < index.size(); ++i) { + for (size_t i = 0; i < multi_index.size(); ++i) { auto index_typed_const = [=](int64 n) { - return llvm::ConstantInt::get(index[i]->getType(), n); + return padded_index.GetConstantWithIndexType(n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = - And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds"); - in_bounds = And( - in_bounds, - ICmpEQ( - index_typed_const(0), - URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = - SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1)); + multi_index[i] = + Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)), + "in_bounds"); in_bounds = And(in_bounds, - ICmpSLT(index[i], + ICmpEQ(index_typed_const(0), + URem(multi_index[i], + index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + multi_index[i] = + SDiv(multi_index[i], 
index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = + And(in_bounds, + ICmpSLT(multi_index[i], index_typed_const(hlo->operand(0)->shape().dimensions(i))), "in_bounds"); } @@ -2107,6 +2094,8 @@ StatusOr ElementalIrEmitter::EmitElementalPad( llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); SetToFirstInsertPoint(if_data.true_block, b_); + llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(), + padded_index.GetType()); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); Store(operand_value, ret_value_addr); @@ -2166,21 +2155,27 @@ StatusOr ElementalIrEmitter::EmitElementalDot( // Given an output index [a,b,c,d,e] in the result, we compute: // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - IrArray::Index lhs_index(index_type), rhs_index(index_type); - + std::vector lhs_multi_index, rhs_multi_index; for (int64 i = 0; i < lhs_dims - 1; i++) { - lhs_index.push_back(dot_result_index[i]); + lhs_multi_index.push_back(dot_result_index[i]); } - lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue()); + lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim, + inner_loop->GetIndVarValue()); + IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(), + index_type); int64 num_batch_dims = dim_numbers.rhs_batch_dimensions_size(); for (int64 i = 0; i < num_batch_dims; i++) { - rhs_index.push_back(dot_result_index[dim_numbers.rhs_batch_dimensions(i)]); + rhs_multi_index.push_back( + dot_result_index[dim_numbers.rhs_batch_dimensions(i)]); } for (int64 i = 0; i < rhs_dims - 1 - num_batch_dims; i++) { - rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); + rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]); } - rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); + rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim, + inner_loop->GetIndVarValue()); + IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(), + index_type); llvm::Value* current_accumulator = Load(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); @@ -2220,7 +2215,6 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kClz: case HloOpcode::kConvert: case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kExpm1: @@ -2231,6 +2225,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kLog1p: case HloOpcode::kNegate: case HloOpcode::kNot: + case HloOpcode::kPopulationCount: case HloOpcode::kReal: case HloOpcode::kRsqrt: case HloOpcode::kSign: @@ -2240,24 +2235,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(hlo->operand(0))(index)); return EmitUnaryOp(hlo, operand_value); }; case HloOpcode::kAdd: case HloOpcode::kAnd: case HloOpcode::kAtan2: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kOr: case HloOpcode::kXor: case HloOpcode::kPower: @@ 
-2271,11 +2260,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* lhs = hlo->operand(0); const HloInstruction* rhs = hlo->operand(1); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, - operand_to_generator.at(lhs)( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(lhs)(index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, - operand_to_generator.at(rhs)( - ElementwiseSourceIndex(index, *hlo, 1))); + operand_to_generator.at(rhs)(index)); return EmitBinaryOp(hlo, lhs_value, rhs_value); }; case HloOpcode::kSelect: @@ -2292,8 +2279,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); + operand_to_generator.at(hlo->operand(0))(index)); return EmitReducePrecision(hlo, operand_value); }; case HloOpcode::kConcatenate: @@ -2306,13 +2292,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return [this, hlo, &operand_to_generator]( const IrArray::Index& target_index) -> StatusOr { const HloInstruction* operand = hlo->operand(0); - auto source_index = target_index; + std::vector source_multi_index = target_index.multidim(); for (int64 dim : hlo->dimensions()) { - source_index[dim] = - Sub(llvm::ConstantInt::get(target_index[dim]->getType(), - hlo->shape().dimensions(dim) - 1), - target_index[dim]); + source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType( + hlo->shape().dimensions(dim) - 1), + target_index[dim]); } + llvm_ir::IrArray::Index source_index( + source_multi_index, operand->shape(), target_index.GetType()); return operand_to_generator.at(operand)(source_index); }; case HloOpcode::kBroadcast: @@ -2386,7 +2373,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { IrArray::Index sliced_index = index.SourceIndexOfSlice( - /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(), + /*operand_shape=*/hlo->operand(0)->shape(), + /*starts=*/hlo->slice_starts(), /*strides=*/hlo->slice_strides(), /*builder=*/b_); return operand_to_generator.at(hlo->operand(0))(sliced_index); }; @@ -2423,6 +2411,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return operand_to_generator.at(operand)( index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_)); }; + case HloOpcode::kCopy: + return [hlo, &operand_to_generator]( + const IrArray::Index& target_index) -> StatusOr { + IrArray::Index source_index(target_index.multidim(), + hlo->operand(0)->shape(), + target_index.GetType()); + TF_ASSIGN_OR_RETURN( + llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(source_index)); + return operand_value; + }; case HloOpcode::kTranspose: return [this, hlo, &operand_to_generator](const IrArray::Index& target_index) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 7afecbbd318..6b3844c3f0e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -146,9 +146,6 @@ class ElementalIrEmitter : public IrBuilderMixin { virtual StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value); - virtual StatusOr EmitRoundNearestAfz(PrimitiveType prim_type, - llvm::Value* value); - virtual StatusOr 
EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x); @@ -159,15 +156,6 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, llvm::Value* imag); - // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and - // the target array index, computes the source array index of its - // `operand_no`-th operand. - // - // Precondition: `hlo` is an elementwise op. - llvm_ir::IrArray::Index ElementwiseSourceIndex( - const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no); - // Identifier of the thread unique among all threads on the device virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 9e9d3daf25e..ac18346faa1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -32,7 +32,7 @@ class ElementalIrEmitterExecutionTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text, config)); + ParseAndReturnVerifiedModule(hlo_text, config)); EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); } }; @@ -83,7 +83,15 @@ ENTRY resampler_Resampler.49 { } )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{4e-3, 4e-3})); + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + // Disable the layout assignment pass because it would throw away the layouts + // in the fusion computation, but not recreate them. + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{4e-3, 4e-3})); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 1518d83083b..7b60c983b30 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -18,6 +18,7 @@ limitations under the License. 
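The dynamic-update-slice hunk above builds its window test one dimension at a time: the clamped start index gives a lower bound, start plus the update's dimension size gives the limit, and the per-dimension comparisons are ANDed into `slice_intersection`. A minimal host-side sketch of that predicate, using plain C++ containers and an illustrative function name rather than the emitter's LLVM values, is:

```c++
// Sketch only: InUpdateWindow is an illustrative name, not XLA code. An
// output element is read from `update` iff, for every dimension i,
//   start[i] <= index[i] < start[i] + update_dims[i],
// where start[i] is clamped so the update window stays inside the operand
// (the hunk above shows the lower-bound clamp via EmitIntegralMax).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

bool InUpdateWindow(const std::vector<int64_t>& index,
                    std::vector<int64_t> start,
                    const std::vector<int64_t>& operand_dims,
                    const std::vector<int64_t>& update_dims) {
  bool in_window = true;  // plays the role of `slice_intersection`
  for (std::size_t i = 0; i < index.size(); ++i) {
    // Clamp the start index to [0, operand_dim - update_dim].
    start[i] = std::max<int64_t>(
        int64_t{0}, std::min(start[i], operand_dims[i] - update_dims[i]));
    const int64_t limit = start[i] + update_dims[i];  // slice limit
    in_window = in_window && index[i] >= start[i] && index[i] < limit;
  }
  return in_window;
}
```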
#include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -137,8 +138,6 @@ StatusOr Executable::ExecuteOnStreamWrapper( XLA_LOG_LINES( tensorflow::INFO, profile_ptr->ToString(stream->parent()->GetDeviceDescription())); - hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", - profile_ptr.get()); } return return_value; @@ -146,39 +145,4 @@ StatusOr Executable::ExecuteOnStreamWrapper( int64 Executable::SizeInBytes() { return -1; } -Status Executable::DumpHloSnapshot() { - TF_RET_CHECK(dumping_snapshot()); - TF_RET_CHECK(hlo_snapshot_->has_hlo() && - hlo_snapshot_->hlo().has_hlo_module()); - const string& directory_path = - module_config().debug_options().xla_dump_executions_to(); - const auto& module = hlo_snapshot_->hlo().hlo_module(); - string filename = - absl::StrFormat("computation_%d__%s__execution_%d", module.id(), - module.entry_computation_name(), ++execution_count_); - return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); -} - -/* static */ Status Executable::DumpToDirectory( - const string& directory_path, string filename, - const HloSnapshot& hlo_session) { - tensorflow::Env* env = tensorflow::Env::Default(); - if (!env->IsDirectory(directory_path).ok()) { - // NB! CreateDir does not work reliably with multiple XLA threads -- two - // threads can race to observe the absence of the dump directory and - // simultaneously try to create it, causing the "losing" thread to get a - // "directory already exists" error. - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); - } - filename = SanitizeFileName(std::move(filename)); - string file_path = tensorflow::io::JoinPath(directory_path, filename); - const size_t size = hlo_session.ByteSizeLong(); - auto serialized = absl::make_unique(size); - TF_RET_CHECK(tensorflow::SerializeToBufferDeterministic( - hlo_session, serialized.get(), size)); - return tensorflow::WriteStringToFile( - tensorflow::Env::Default(), file_path, - absl::string_view(serialized.get(), size)); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index b34bca55a48..5caead15d50 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -24,13 +24,11 @@ limitations under the License. #include "absl/types/variant.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" -#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -40,20 +38,66 @@ limitations under the License. 
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { // ExecutionOutput encapsulates the output buffers of a execution and the // leftover buffers to be released by the caller. -struct ExecutionOutput { +class ExecutionOutput { + public: ExecutionOutput(ScopedShapedBuffer result, - std::vector to_be_released) - : result(std::move(result)), to_be_released(std::move(to_be_released)) {} - ScopedShapedBuffer result; + std::vector to_be_released, + std::vector aliased_indices) + : result_(std::move(result)), + to_be_released_(std::move(to_be_released)), + aliased_indices_(std::move(aliased_indices)) {} + ExecutionOutput(ExecutionOutput&&) = default; + ExecutionOutput& operator=(ExecutionOutput&&) = default; + + ~ExecutionOutput() { + // If the ExecutionOutput has not been committed, and if there are aliased + // indices, clear them off the ScopedShapedBuffer to prevent them to be + // released. + for (auto& index : aliased_indices_) { + result_.set_buffer(se::OwningDeviceMemory(), index); + } + } + + // Should be called once it is known that the execute operation succeeded, + // before returning the ExecutionOutput to the caller. + ExecutionOutput& Commit() { + aliased_indices_.clear(); + return *this; + } + + const ScopedShapedBuffer& Result() const { return result_; } + + ScopedShapedBuffer ConsumeResult() { + aliased_indices_.clear(); + return std::move(result_); + } + + const std::vector& ToBeReleased() const { + return to_be_released_; + } + + std::vector ConsumeToBeReleased() { + return std::move(to_be_released_); + } + + private: + ScopedShapedBuffer result_; // Leftover buffers for the caller to release. Elements in this list are // donated input memory buffers that are not reused by XLA as outputs. - std::vector to_be_released; + std::vector to_be_released_; + + // These are the indices in result_ which have been aliased from the caller. + // If the execution operation fails, the caller should maintain ownership of + // the buffer, so we track the indices here, and unless the ExecutionOutput is + // committed, we remove them from the result_ before destruction. + std::vector aliased_indices_; }; // A given platform's compiler will produce an Executable -- this is a uniform @@ -184,11 +228,6 @@ class Executable { } bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; } HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } - Status DumpHloSnapshot(); - - // Dump hlo snapshot to directory_path/filename. 
- static Status DumpToDirectory(const string& directory_path, string filename, - const HloSnapshot& hlo_session); protected: mutable tensorflow::mutex mutex_; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8eeb930b481..ef35311b08b 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -81,8 +81,9 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + zero, ComparisonDirection::kGt)); return builder.Build(); } @@ -158,9 +159,9 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kEq, param0, false_constant)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), param0, false_constant, + ComparisonDirection::kEq)); cond_computation = module->AddEmbeddedComputation(builder.Build()); } diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index a58ac39dffa..1838f65e6ea 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -112,6 +112,14 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( int64 operand_rank) { HloComputation* computation = index_vector->parent(); const Shape& index_shape = index_vector->shape(); + + if (operand_rank == 0) { + // This is Gather from a scalar. So, the index vector in operand space must + // be a zero-sized vector. 
+ return computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); + } + HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index cb43c27be96..d6a7ec90b59 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -57,7 +57,8 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( void GenericTransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - MutableBorrowingLiteral literal, std::function done) { + MutableBorrowingLiteral literal, std::function done, + const TransferMetadata* /*transfer_metadata*/) { Status status = stream->BlockHostUntilDone(); if (!status.ok()) { return done(status); @@ -97,7 +98,8 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal( Status GenericTransferManager::TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) { + const ShapedBuffer& device_buffer, + const TransferMetadata* /*transfer_metadata*/) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " << ShapeUtil::HumanString(shape) diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 86c8b1c145a..acfd8dd64c1 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -40,14 +40,15 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - void TransferLiteralFromDevice(se::Stream* stream, - const ShapedBuffer& device_buffer, - MutableBorrowingLiteral literal, - std::function done) override; + void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, std::function done, + const TransferMetadata* transfer_metadata) override; Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) override; + const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 3bc0daf9e70..2ffc6c8fb63 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,13 +1,14 @@ # Description: # GPU-specific components in XLA service implementation. 
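The `ExecutionOutput` class introduced in executable.h above follows a commit-or-roll-back pattern: buffer indices aliased from the caller are remembered, and unless `Commit()` or `ConsumeResult()` is called, the destructor strips those entries so a failed execution leaves ownership with the caller. A stand-alone sketch of the same idiom, with placeholder types (`Buffer` and `Output` are illustrative, not XLA types):

```c++
// Minimal sketch of the commit-or-roll-back ownership idiom.
#include <cstddef>
#include <utility>
#include <vector>

struct Buffer { void* ptr = nullptr; };  // placeholder for a device buffer

class Output {
 public:
  Output(std::vector<Buffer> result, std::vector<std::size_t> aliased_indices)
      : result_(std::move(result)),
        aliased_indices_(std::move(aliased_indices)) {}

  ~Output() {
    // Not committed: clear aliased slots so they are not released here and
    // the caller keeps ownership after a failed execution.
    for (std::size_t i : aliased_indices_) result_[i] = Buffer{};
  }

  Output& Commit() {             // call only after the execution succeeded
    aliased_indices_.clear();
    return *this;
  }

  std::vector<Buffer> ConsumeResult() {
    aliased_indices_.clear();    // ownership passes to the caller
    return std::move(result_);
  }

 private:
  std::vector<Buffer> result_;
  std::vector<std::size_t> aliased_indices_;
};
```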
-load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load( "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", "tf_cuda_tests_tags", ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library", "if_cuda") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") licenses(["notice"]) # Apache 2.0 @@ -84,6 +85,24 @@ cc_library( # ], #) +tf_cc_test( + name = "custom_call_test", + srcs = ["custom_call_test.cc"], + tags = ["requires-gpu-sm35"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/service:custom_call_target_registry", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + "@local_config_cuda//cuda:cuda_headers", + ], +) + cc_library( name = "stream_assignment", srcs = ["stream_assignment.cc"], @@ -142,6 +161,22 @@ cc_library( ], ) +cc_library( + name = "target_util", + srcs = ["target_util.cc"], + hdrs = ["target_util.h"], + deps = [ + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@com_google_absl//absl/base", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm//:core", + "@llvm//:support", + ], +) + cc_library( name = "ir_emitter", srcs = [ @@ -166,6 +201,7 @@ cc_library( ":nccl_all_reduce_thunk", ":parallel_loop_emitter", ":partition_assignment", + ":target_util", ":thunk", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -176,6 +212,7 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", @@ -264,10 +301,10 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", @@ -310,6 +347,7 @@ tf_cuda_library( ":buffer_allocations", ":hlo_execution_profiler", ":thunk", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/synchronization", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", @@ -331,6 +369,7 @@ cc_library( "convolution_thunk.cc", "copy_thunk.cc", "cudnn_batchnorm_thunk.cc", + "custom_call_thunk.cc", "fft_thunk.cc", "for_thunk.cc", "gemm_thunk.cc", @@ -351,6 +390,7 @@ cc_library( "convolution_thunk.h", "copy_thunk.h", "cudnn_batchnorm_thunk.h", + "custom_call_thunk.h", "fft_thunk.h", "for_thunk.h", "gemm_thunk.h", @@ -372,10 +412,11 @@ cc_library( ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", - 
":nccl_all_reduce_thunk", + ":nccl_all_reduce_thunk", # fixdeps: keep ":outfeed_manager", ":partition_assignment", ":stream_assignment", + ":stream_executor_util", ":thunk", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -388,7 +429,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", @@ -404,9 +444,15 @@ cc_library( "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep + "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor", "//tensorflow/stream_executor:blas", "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", + "//tensorflow/stream_executor:kernel", + "//tensorflow/stream_executor/cuda:cuda_stream", + "//tensorflow/stream_executor/gpu:gpu_stream", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -415,6 +461,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", ], ) @@ -424,6 +471,7 @@ cc_library( hdrs = ["ir_emission_utils.h"], deps = [ ":backend_configs", + ":target_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -440,23 +488,28 @@ cc_library( srcs = ["cudnn_conv_algorithm_picker.cc"], hdrs = ["cudnn_conv_algorithm_picker.h"], deps = [ - ":autotuning_proto", ":backend_configs", ":buffer_comparator", ":cudnn_conv_runner", + ":gpu_autotuning_proto", ":gpu_executable", ":ir_emission_utils", - ":scratch_allocator", + ":redzone_allocator", + ":stream_executor_util", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:logger", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/util/proto:proto_utils", + "//tensorflow/stream_executor:device_memory_allocator", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", @@ -471,8 +524,53 @@ cc_library( deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", + ], +) + +cc_library( + name = "redzone_allocator", + srcs = ["redzone_allocator.cc"], + hdrs = ["redzone_allocator.h"], + deps = [ + ":gpu_constants", + ":partition_assignment", + ":stream_executor_util", + 
"//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", + "//tensorflow/stream_executor:stream_executor_headers", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "redzone_allocator_test", + srcs = ["redzone_allocator_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":redzone_allocator", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + "//tensorflow/core/platform/default/build_config:stream_executor_cuda", + "//tensorflow/stream_executor:device_memory_allocator", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:kernel", + "//tensorflow/stream_executor/cuda:cuda_activation", + "//tensorflow/stream_executor/cuda:cuda_gpu_executor", ], ) @@ -537,15 +635,17 @@ cc_library( srcs = ["cusolver_context.cc"], hdrs = ["cusolver_context.h"], deps = [ + "@local_config_cuda//cuda:cuda_headers", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor:blas", - "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cusolver", - ], + ] + if_static( + ["@local_config_cuda//cuda:cusolver"], + ["//tensorflow/stream_executor/cuda:cusolver_stub"], + ), ) cc_library( @@ -559,12 +659,12 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor:blas", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/types:optional", ], ) @@ -581,6 +681,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -702,6 +803,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", @@ -775,7 +877,8 @@ cc_library( srcs = ["gpu_transfer_manager.cc"], hdrs = ["gpu_transfer_manager.h"], deps = [ - ":gpu_compiler", + ":infeed_manager", + ":nvptx_compiler", ":outfeed_manager", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -788,7 +891,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", 
"//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", @@ -798,7 +900,7 @@ cc_library( ) cc_library( - name = "gpu_compiler", + name = "nvptx_compiler", srcs = ["nvptx_compiler.cc"], hdrs = ["nvptx_compiler.h"], deps = [ @@ -838,6 +940,7 @@ cc_library( "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -856,6 +959,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:slice_sinker", "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:stable_sort_expander", "//tensorflow/compiler/xla/service:transpose_folding", @@ -871,6 +975,8 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/cuda:cuda_diagnostics", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", @@ -1044,7 +1150,17 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:kernel_spec", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1068,28 +1184,29 @@ cc_library( hdrs = ["buffer_comparator.h"], deps = [ ":gpu_executable", + ":partition_assignment", + ":stream_executor_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/service:device_memory_allocator", - "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/strings", ], ) -xla_test( +tf_cc_test( name = "buffer_comparator_test", srcs = ["buffer_comparator_test.cc"], - backends = [ - "cpu", - "gpu", - ], + tags = tf_cuda_tests_tags(), deps = [ ":buffer_comparator", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:backend", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/stream_executor:device_memory", ], ) @@ -1099,6 +1216,7 @@ cc_library( hdrs = ["gpu_fusible.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", ], ) @@ -1137,8 +1255,8 @@ tf_cc_test( srcs = ["cudnn_fused_conv_rewriter_test.cc"], tags = tf_cuda_tests_tags(), 
deps = [ + ":ir_emission_utils", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", @@ -1183,10 +1301,11 @@ tf_cc_test( ) xla_proto_library( - name = "autotuning_proto", - srcs = ["autotuning.proto"], + name = "gpu_autotuning_proto", + srcs = ["gpu_autotuning.proto"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:autotuning_proto_cc", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/autotuning.proto b/tensorflow/compiler/xla/service/gpu/autotuning.proto deleted file mode 100644 index b4a08963b4f..00000000000 --- a/tensorflow/compiler/xla/service/gpu/autotuning.proto +++ /dev/null @@ -1,81 +0,0 @@ -// This file defines protos that store the results of autotuning XLA:GPU -// operations. -// -// They are in proto format because we want to log them structured. They offer -// tremendous statistical, testing, and debugging value. -syntax = "proto3"; - -package xla.gpu; - -import "google/protobuf/duration.proto"; -import "tensorflow/compiler/xla/xla_data.proto"; -import "tensorflow/compiler/xla/service/hlo.proto"; - -message CudnnVersion { - int32 major = 1; - int32 minor = 2; - int32 patch = 3; -} - -message ComputeCapability { - int32 major = 1; - int32 minor = 2; -} - -message AutotuneResult { - message SuccessResult { - int64 scratch_bytes = 1; - google.protobuf.Duration run_time = 2; - } - - message ConvKey { - int64 algorithm = 1; - bool tensor_ops_enabled = 2; - } - - // If the conv runs successfully, success will be populated with the - // autotuning result. Otherwise, the error message is propagated. - oneof result { - SuccessResult success = 3; - string error_string = 4; - } - - oneof key { - ConvKey conv = 5; - } - - // Sometimes we run a correctness checker during autotuning. It compares the - // result buffer content between two algorithms, say, "reference" and "test" - // algorithms. The "test" algorithm is the one associated with this - // AutotuneResult. - // - // This field records the reference algorithm used. Notice that naming it - // "reference" doesn't mean it's always correct. However, empirically it's - // more correct, as it's "algo 0", less fancy than the compared one. - // - // Notice that the checker_failure may exist even in the success case. - // This is because the error string in `result` comes from the underlying - // implementation like cuDNN, which isn't aware that it produced an incorrect - // result. And even if the checker detects an incorrect result, we can still - // retrieve scratch_bytes and runtime_ms. - oneof checker_failure { - ConvKey reference_conv = 6; - } -} - -message AutotuneLog { - message Instruction { - xla.HloInstructionProto instruction = 1; - repeated xla.ShapeProto operand_shapes = 2; - } - - oneof instr_oneof { - Instruction instr = 1; - } - - // Records all auto-tuning results per algorithm. 
- repeated AutotuneResult results = 3; - - CudnnVersion cudnn_version = 4; - ComputeCapability compute_capability = 5; -} diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index eb59ee5a1d4..ae84563881a 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -39,7 +39,7 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index, StatusOr> BufferAllocations::Builder::Build( const BufferAssignment* buffer_assignment, int device_ordinal, - DeviceMemoryAllocator* memory_allocator) { + se::DeviceMemoryAllocator* memory_allocator) { const int64 num_buffers = buffer_assignment->Allocations().size(); auto buffer_allocations = absl::WrapUnique(new BufferAllocations( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); @@ -77,20 +77,21 @@ StatusOr> BufferAllocations::Builder::Build( const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; if (buffer_size > 0) { - OwningDeviceMemory buffer; + se::OwningDeviceMemory buffer; TF_ASSIGN_OR_RETURN( buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); - if (reinterpret_cast(buffer.opaque()) % expected_alignment != + if (reinterpret_cast(buffer->opaque()) % + expected_alignment != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " "multiple of 0x%x, but was %p", - kXlaAllocatedBufferAlignBytes, buffer.opaque()); + kXlaAllocatedBufferAlignBytes, buffer->opaque()); } // We do manual memory management within BufferAllocations. Be sure not // to do a TF_RETURN_IF_ERROR between this line and the // buffer_allocations->SetBuffer(buffer_address) call below! - buffer_address = buffer.Forget(); + buffer_address = buffer.Release(); } buffer_allocations->SetBuffer(i, buffer_address); @@ -164,7 +165,7 @@ se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( CHECK_LE(buffer_slice.offset() + buffer_slice.size(), base.size()); return se::DeviceMemoryBase( static_cast(base.opaque()) + buffer_slice.offset(), - buffer_slice.size(), /*is_sub_buffer=*/true); + buffer_slice.size()); } void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index, diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 9413ac2cff7..cf78b92fe5b 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -23,9 +23,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace gpu { @@ -50,7 +50,7 @@ class BufferAllocations { // memory on. 
StatusOr> Build( const BufferAssignment* buffer_assignment, int device_ordinal, - DeviceMemoryAllocator* memory_allocator); + se::DeviceMemoryAllocator* memory_allocator); private: absl::flat_hash_map @@ -62,7 +62,9 @@ class BufferAllocations { BufferAllocations(const BufferAllocations&) = delete; BufferAllocations& operator=(const BufferAllocations&) = delete; - DeviceMemoryAllocator* memory_allocator() const { return memory_allocator_; } + se::DeviceMemoryAllocator* memory_allocator() const { + return memory_allocator_; + } int device_ordinal() const { return device_ordinal_; } // Returns the device address of buffer `buffer_index`. `buffer_index` must be @@ -84,7 +86,7 @@ class BufferAllocations { private: BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal, - DeviceMemoryAllocator* memory_allocator, + se::DeviceMemoryAllocator* memory_allocator, const BufferAssignment* buffer_assignment) : buffers_(buffer_count), device_ordinal_(device_ordinal), @@ -104,7 +106,7 @@ class BufferAllocations { se::DeviceMemoryBase temp_buffer_base_; int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; + se::DeviceMemoryAllocator* memory_allocator_; const BufferAssignment* buffer_assignment_; bool torn_down_ = false; }; diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 13c83c9199f..5f3b3b48ef2 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -15,190 +15,466 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" +#include #include + #include "absl/strings/str_replace.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/kernel.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" namespace xla { namespace gpu { -static constexpr float kTolerance = 0.1f; +static constexpr double kTolerance = 0.1f; -static string GetCompHloText(size_t num_elements) { - // Implements the textual format of the comparison routine, as it's more - // readable. - static constexpr char kF16CompHloText[] = R"( -HloModule CompareF16 +// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length +// buffer_length where the relative error does not exceed the passed +// rel_error_threshold. Write the number of mismatches into out parameter +// mismatch_count. +// +// NaN's are considered equal, and for half's we clamp all numbers to largest +// and smallest numbers representable to avoid miscomparisons due to overflows. +// +// The PTX below is compiled from the following CUDA code: +// +// #include +// extern "C" { // avoid name mangling +// __device__ float canonicalize(float input) { +// // All fp16 infinities are treated as 65505 or -65505, in order to avoid +// // differences due to overflows. +// return isnan(input) ? 
input : max(-65505.0f, min(input, 65505.0f)); +// } +// +// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b, +// float rel_error_threshold, +// unsigned long long buffer_length, +// int* mismatch_count) { +// int idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) return; +// float elem_a = __half2float(buffer_a[idx]); +// float elem_b = __half2float(buffer_b[idx]); +// elem_a = canonicalize(elem_a); +// elem_b = canonicalize(elem_b); +// if (isnan(elem_a) && isnan(elem_b)) return; +// float rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// if (rel_error > rel_error_threshold || isnan(rel_error)) +// atomicAdd(mismatch_count, 1); +// } +// +// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b, +// float rel_error_threshold, +// unsigned long long buffer_length, +// int* mismatch_count) { +// int idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) return; +// float elem_a = buffer_a[idx]; +// float elem_b = buffer_b[idx]; +// if (isnan(elem_a) && isnan(elem_b)) return; +// if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) +// return; +// float rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// if (rel_error > rel_error_threshold || isnan(rel_error)) +// atomicAdd(mismatch_count, 1); +// } +// +// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b, +// float rel_error_threshold, +// unsigned long long buffer_length, +// int* mismatch_count) { +// int idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) return; +// double elem_a = buffer_a[idx]; +// double elem_b = buffer_b[idx]; +// if (isnan(elem_a) && isnan(elem_b)) return; +// if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) +// return; +// double rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// if (rel_error > rel_error_threshold || isnan(rel_error)) +// atomicAdd(mismatch_count, 1); +// } +// } // end extern declaration. 
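Restated on the host, the check described in the comment block above clamps fp16 values to the range [-65505, 65505] so overflow does not cause spurious mismatches, treats a pair of NaNs as equal, and otherwise requires the relative error |a - b| / (max(|a|, |b|) + 1) to stay at or below the threshold (0.1 in this file); the fp32/fp64 kernels additionally accept two infinities of the same sign. A plain C++ sketch of the fp16-style per-element check (illustrative names, not the kernel itself):

```c++
#include <algorithm>
#include <cmath>

// Mirrors the kernel's canonicalize(): NaN passes through, infinities and
// out-of-range values are clamped to the largest magnitude used above.
float Canonicalize(float x) {
  return std::isnan(x) ? x : std::max(-65505.0f, std::min(x, 65505.0f));
}

bool ElementsMatch(float a, float b, float rel_error_threshold) {
  a = Canonicalize(a);
  b = Canonicalize(b);
  if (std::isnan(a) && std::isnan(b)) return true;
  const float rel_error =
      std::abs(a - b) / (std::max(std::abs(a), std::abs(b)) + 1.0f);
  // As in the kernel, a NaN relative error also counts as a mismatch.
  return rel_error <= rel_error_threshold && !std::isnan(rel_error);
}
```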
+static const char* buffer_compare_ptx = R"( +.version 4.2 +.target sm_30 +.address_size 64 + +.visible .entry __xla_fp16_comparison( + .param .u64 __xla_fp16_comparison_param_0, + .param .u64 __xla_fp16_comparison_param_1, + .param .f32 __xla_fp16_comparison_param_2, + .param .u64 __xla_fp16_comparison_param_3, + .param .u64 __xla_fp16_comparison_param_4 +) +{ + .reg .pred %p<10>; + .reg .b16 %rs<3>; + .reg .f32 %f<20>; + .reg .b32 %r<6>; + .reg .b64 %rd<12>; + ld.param.u64 %rd8, [__xla_fp16_comparison_param_3]; + mov.u32 %r1, %tid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %ntid.x; + mad.lo.s32 %r4, %r3, %r2, %r1; + cvt.s64.s32 %rd4, %r4; + setp.ge.u64 %p1, %rd4, %rd8; + @%p1 bra LBB7_4; + ld.param.u64 %rd5, [__xla_fp16_comparison_param_0]; + ld.param.u64 %rd7, [__xla_fp16_comparison_param_1]; + cvta.to.global.u64 %rd2, %rd7; + cvta.to.global.u64 %rd3, %rd5; + shl.b64 %rd9, %rd4, 1; + add.s64 %rd10, %rd3, %rd9; + ld.global.u16 %rs1, [%rd10]; + // begin inline asm + { cvt.f32.f16 %f6, %rs1;} + + // end inline asm + add.s64 %rd11, %rd2, %rd9; + ld.global.u16 %rs2, [%rd11]; + // begin inline asm + { cvt.f32.f16 %f7, %rs2;} + + // end inline asm + abs.f32 %f8, %f6; + setp.gtu.f32 %p2, %f8, 0f7F800000; + min.f32 %f9, %f6, 0f477FE100; + max.f32 %f10, %f9, 0fC77FE100; + selp.f32 %f1, %f6, %f10, %p2; + abs.f32 %f11, %f7; + setp.gtu.f32 %p3, %f11, 0f7F800000; + min.f32 %f12, %f7, 0f477FE100; + max.f32 %f13, %f12, 0fC77FE100; + selp.f32 %f2, %f7, %f13, %p3; + abs.f32 %f3, %f1; + setp.gtu.f32 %p4, %f3, 0f7F800000; + abs.f32 %f4, %f2; + setp.gtu.f32 %p5, %f4, 0f7F800000; + and.pred %p6, %p4, %p5; + @%p6 bra LBB7_4; + ld.param.f32 %f5, [__xla_fp16_comparison_param_2]; + sub.f32 %f14, %f1, %f2; + abs.f32 %f15, %f14; + max.f32 %f16, %f3, %f4; + add.f32 %f17, %f16, 0f3F800000; + div.rn.f32 %f18, %f15, %f17; + setp.leu.f32 %p7, %f18, %f5; + abs.f32 %f19, %f18; + setp.le.f32 %p8, %f19, 0f7F800000; + and.pred %p9, %p7, %p8; + @%p9 bra LBB7_4; + ld.param.u64 %rd6, [__xla_fp16_comparison_param_4]; + cvta.to.global.u64 %rd1, %rd6; + atom.global.add.u32 %r5, [%rd1], 1; +LBB7_4: + ret; -MaxF32 { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %max = f32[] maximum(%lhs, %rhs) } + // .globl __xla_fp32_comparison +.visible .entry __xla_fp32_comparison( + .param .u64 __xla_fp32_comparison_param_0, + .param .u64 __xla_fp32_comparison_param_1, + .param .f32 __xla_fp32_comparison_param_2, + .param .u64 __xla_fp32_comparison_param_3, + .param .u64 __xla_fp32_comparison_param_4 +) +{ + .reg .pred %p<12>; + .reg .f32 %f<12>; + .reg .b32 %r<9>; + .reg .b64 %rd<12>; -Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] { - %min_constant = f32[] constant(-65505) - %max_constant = f32[] constant(65505) - %large_constant = f32[] constant(1048576) - %min_values = f32[SIZE] broadcast(%min_constant), dimensions={} - %max_values = f32[SIZE] broadcast(%max_constant), dimensions={} - %large_values = f32[SIZE] broadcast(%large_constant), dimensions={} + ld.param.u64 %rd8, [__xla_fp32_comparison_param_3]; + mov.u32 %r1, %tid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %ntid.x; + mad.lo.s32 %r4, %r3, %r2, %r1; + cvt.s64.s32 %rd4, %r4; + setp.ge.u64 %p1, %rd4, %rd8; + @%p1 bra LBB8_6; + ld.param.u64 %rd5, [__xla_fp32_comparison_param_0]; + ld.param.u64 %rd7, [__xla_fp32_comparison_param_1]; + cvta.to.global.u64 %rd2, %rd7; + cvta.to.global.u64 %rd3, %rd5; + shl.b64 %rd9, %rd4, 2; + add.s64 %rd10, %rd3, %rd9; + ld.global.f32 %f1, [%rd10]; + add.s64 %rd11, %rd2, %rd9; + ld.global.f32 %f2, [%rd11]; + abs.f32 %f3, %f1; + 
setp.gtu.f32 %p2, %f3, 0f7F800000; + abs.f32 %f4, %f2; + setp.gtu.f32 %p3, %f4, 0f7F800000; + and.pred %p4, %p2, %p3; + @%p4 bra LBB8_6; + setp.neu.f32 %p5, %f3, 0f7F800000; + setp.neu.f32 %p6, %f4, 0f7F800000; + or.pred %p7, %p5, %p6; + @%p7 bra LBB8_4; + mov.b32 %r5, %f1; + mov.b32 %r6, %f2; + xor.b32 %r7, %r6, %r5; + setp.gt.s32 %p8, %r7, -1; + @%p8 bra LBB8_6; +LBB8_4: + ld.param.f32 %f5, [__xla_fp32_comparison_param_2]; + sub.f32 %f6, %f1, %f2; + abs.f32 %f7, %f6; + max.f32 %f8, %f3, %f4; + add.f32 %f9, %f8, 0f3F800000; + div.rn.f32 %f10, %f7, %f9; + setp.leu.f32 %p9, %f10, %f5; + abs.f32 %f11, %f10; + setp.le.f32 %p10, %f11, 0f7F800000; + and.pred %p11, %p9, %p10; + @%p11 bra LBB8_6; + ld.param.u64 %rd6, [__xla_fp32_comparison_param_4]; + cvta.to.global.u64 %rd1, %rd6; + atom.global.add.u32 %r8, [%rd1], 1; +LBB8_6: + ret; - %a = f16[SIZE] parameter(0) - %converted = f32[SIZE] convert(%a) - %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values) - - // Since the clamp() above already took care of infs, only NaNs will cause - // is-finite() to return false. - %is_finite = pred[SIZE] is-finite(%clamped) - ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values) } + // .globl __xla_fp64_comparison +.visible .entry __xla_fp64_comparison( + .param .u64 __xla_fp64_comparison_param_0, + .param .u64 __xla_fp64_comparison_param_1, + .param .f32 __xla_fp64_comparison_param_2, + .param .u64 __xla_fp64_comparison_param_3, + .param .u64 __xla_fp64_comparison_param_4 +) +{ + .reg .pred %p<16>; + .reg .f32 %f<2>; + .reg .b32 %r<13>; + .reg .f64 %fd<12>; + .reg .b64 %rd<12>; -ENTRY MaxDifference { - %one_constant = f32[] constant(1.0) - %zero_constant = f32[] constant(0.0) - - %ones = f32[SIZE] broadcast(%one_constant), dimensions={} - - %lhs = f16[SIZE] parameter(0) - %rhs = f16[SIZE] parameter(1) - %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize - %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize - %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical) - %sub_abs = f32[SIZE] abs(%sub) - %lhs_abs = f32[SIZE] abs(%lhs_canonical) - %rhs_abs = f32[SIZE] abs(%rhs_canonical) - %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs) - %denominator = f32[SIZE] add(%max, %ones) - %error = f32[SIZE] divide(%sub_abs, %denominator) - ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 -})"; - return absl::StrReplaceAll(kF16CompHloText, - {{"SIZE", absl::StrCat(num_elements)}}); + ld.param.u64 %rd8, [__xla_fp64_comparison_param_3]; + mov.u32 %r2, %tid.x; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mad.lo.s32 %r5, %r4, %r3, %r2; + cvt.s64.s32 %rd4, %r5; + setp.ge.u64 %p1, %rd4, %rd8; + @%p1 bra LBB9_6; + ld.param.u64 %rd5, [__xla_fp64_comparison_param_0]; + ld.param.u64 %rd7, [__xla_fp64_comparison_param_1]; + cvta.to.global.u64 %rd2, %rd7; + cvta.to.global.u64 %rd3, %rd5; + shl.b64 %rd9, %rd4, 3; + add.s64 %rd10, %rd3, %rd9; + ld.global.f64 %fd1, [%rd10]; + add.s64 %rd11, %rd2, %rd9; + ld.global.f64 %fd2, [%rd11]; + abs.f64 %fd3, %fd1; + setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000; + abs.f64 %fd4, %fd2; + setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000; + and.pred %p4, %p2, %p3; + @%p4 bra LBB9_6; + { + .reg .b32 %temp; + mov.b64 {%r6, %temp}, %fd1; + } + { + .reg .b32 %temp; + mov.b64 {%temp, %r1}, %fd1; + } + and.b32 %r7, %r1, 2147483647; + setp.ne.s32 %p5, %r7, 2146435072; + setp.ne.s32 %p6, %r6, 0; + or.pred %p7, %p6, %p5; + @%p7 bra LBB9_4; + { + .reg .b32 %temp; + mov.b64 {%r8, %temp}, %fd2; + } + { + .reg .b32 %temp; + mov.b64 
{%temp, %r9}, %fd2; + } + and.b32 %r10, %r9, 2147483647; + setp.eq.s32 %p8, %r10, 2146435072; + setp.eq.s32 %p9, %r8, 0; + and.pred %p10, %p8, %p9; + xor.b32 %r11, %r9, %r1; + setp.gt.s32 %p11, %r11, -1; + and.pred %p12, %p11, %p10; + @%p12 bra LBB9_6; +LBB9_4: + ld.param.f32 %f1, [__xla_fp64_comparison_param_2]; + sub.f64 %fd5, %fd1, %fd2; + abs.f64 %fd6, %fd5; + max.f64 %fd7, %fd3, %fd4; + add.f64 %fd8, %fd7, 0d3FF0000000000000; + div.rn.f64 %fd9, %fd6, %fd8; + cvt.f64.f32 %fd10, %f1; + setp.leu.f64 %p13, %fd9, %fd10; + abs.f64 %fd11, %fd9; + setp.le.f64 %p14, %fd11, 0d7FF0000000000000; + and.pred %p15, %p13, %p14; + @%p15 bra LBB9_6; + ld.param.u64 %rd6, [__xla_fp64_comparison_param_4]; + cvta.to.global.u64 %rd1, %rd6; + atom.global.add.u32 %r12, [%rd1], 1; +LBB9_6: + ret; } +)"; -StatusOr F16BufferComparator::Create( - se::DeviceMemory ref_buffer, Compiler* compiler, - DeviceMemoryAllocator* allocator, se::Stream* stream) { - auto stream_exec = stream->parent(); - int64 num_elements = ref_buffer.ElementCount(); +template +using ComparisonKernelT = + se::TypedKernel, se::DeviceMemory, + float, uint64, se::DeviceMemory>; - // One may consider using hlo_runner to do all the compilation and execution. - // However, as of the time hlo_runner doesn't support injection for Compiler*, - // Stream*, or even the allocator. We may revisit this in the future if it - // proves to be a maintenance burden. - TF_ASSIGN_OR_RETURN( - auto exec, ([&]() -> StatusOr> { - HloModuleConfig config; - DebugOptions debug_options; - debug_options.set_xla_backend_optimization_level(2); - config.set_debug_options(debug_options); - TF_ASSIGN_OR_RETURN( - auto module, ParseHloString(GetCompHloText(num_elements), config)); - TF_ASSIGN_OR_RETURN( - module, - compiler->RunHloPasses(std::move(module), stream_exec, nullptr)); - return compiler->RunBackend(std::move(module), stream_exec, nullptr); - }())); +// Compares two buffers on the GPU. +// +// Returns `true` if two buffers are equal, `false` otherwise. +template +static StatusOr DeviceCompare(se::Stream* stream, + se::DeviceMemoryBase lhs, + se::DeviceMemoryBase rhs, + const Shape& buffer_shape, + const HloModuleConfig& config, + absl::string_view kernel_name) { + se::StreamExecutor* executor = stream->parent(); - TF_ASSIGN_OR_RETURN( - auto shaped_buffer, ([&]() -> StatusOr { - auto device_ordinal = stream_exec->device_ordinal(); - TF_ASSIGN_OR_RETURN( - auto owning_buffer, - allocator->Allocate(device_ordinal, ref_buffer.size())); - se::DeviceMemory buffer( - owning_buffer.AsDeviceMemoryBase()); - stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size()); - Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); - ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal); - ret.set_buffer(std::move(owning_buffer), {}); - return std::move(ret); - }())); + se::ScopedDeviceMemory out_param = + executor->AllocateOwnedScalar(); - return F16BufferComparator(stream, allocator, std::move(exec), - std::move(shaped_buffer)); -} - -StatusOr F16BufferComparator::CompareEqualImpl( - se::DeviceMemory test_buffer) { - if (ref_buffer_.root_buffer().size() != test_buffer.size()) { - return InternalError("Mismatched buffer size: %d vs %d", - ref_buffer_.root_buffer().size(), test_buffer.size()); + stream->ThenMemZero(out_param.ptr(), sizeof(uint64)); + if (lhs.size() != rhs.size()) { + return InternalError("Mismatched buffer size: %d bytes vs. 
%d bytes", + lhs.size(), rhs.size()); } - int64 num_elements = test_buffer.ElementCount(); + se::DeviceMemory lhs_typed(lhs); + se::DeviceMemory rhs_typed(rhs); + uint64 buffer_size = lhs_typed.ElementCount(); + + PtxCompilationOptions opts(config); + TF_ASSIGN_OR_RETURN( + absl::Span compiled_ptx, + CompilePtxOrGetCached(executor, buffer_compare_ptx, opts)); TF_ASSIGN_OR_RETURN( - auto result_buffer, ([&]() -> StatusOr { - auto stream_exec = stream_->parent(); - Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); - auto device_ordinal = stream_exec->device_ordinal(); - ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(), - device_ordinal); - shaped_test_buffer.set_buffer(test_buffer, {}); - ExecutableRunOptions run_options; - run_options.set_device_ordinal(stream_exec->device_ordinal()); - run_options.set_stream(stream_); - run_options.set_allocator(allocator_); - ServiceExecutableRunOptions service_run_options(run_options); - return exec_->ExecuteOnStream( - &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr); - }())); + std::unique_ptr> comparison_kernel, + (CreateTypedKernel, se::DeviceMemory, + float, uint64, se::DeviceMemory>( + kernel_name, buffer_compare_ptx, compiled_ptx, executor))); - float result; - CHECK(result_buffer.root_buffer().size() == sizeof(result)); - stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result)); - TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); - return result < kTolerance; + LaunchDimensions dim = + CalculateLaunchDimensions(buffer_shape, executor->GetDeviceDescription()); + + stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()), + se::BlockDim(dim.block_count()), *comparison_kernel, + lhs_typed, rhs_typed, static_cast(kTolerance), + buffer_size, out_param.cref()); + + uint64 result = -1; + CHECK_EQ(out_param->size(), sizeof(result)); + stream->ThenMemcpy(&result, *out_param, sizeof(result)); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + return result == 0; } -StatusOr F16BufferComparator::CompareEqual( - se::DeviceMemory test_buffer) { - TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer)); - if (result) { - return true; - } - // Host side code that does the same thing, but report some of the - // differences as well. - int64 n = test_buffer.ElementCount(); - std::vector host_ref_buffer(n), host_test_buffer(n); - stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(), - ref_buffer_.root_buffer().size()); - stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size()); - TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); +// Host side comparison code that does the same thing, but reports some of the +// differences as well. It only print logs for debugging. +// +// Returns true if no differences were seen, false otherwise. 
+template +StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs, + se::DeviceMemoryBase rhs) { + int64 n = lhs.size() / sizeof(ElementType); + std::vector host_lhs(n), host_rhs(n); + stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size()); + stream->ThenMemcpy(host_rhs.data(), rhs, rhs.size()); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - const auto canonicalize = [](float a) -> float { - constexpr float kBigNumer = 1048576.; - constexpr float kMaxFp16Value = 65504.; - if (std::isnan(a)) { - return kBigNumer; - } - if (std::isinf(a)) { - if (a < 0) { - return -(kMaxFp16Value + 1); + const auto canonicalize = [](ComparisonType a) -> ComparisonType { + if (std::is_same::value && a) { + constexpr ComparisonType kMaxFp16Value = 65505.; + if (std::isnan(a)) { + return a; } - return kMaxFp16Value + 1; + return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value)); } return a; }; int differences_seen = 0; for (int64 i = 0; i < n && differences_seen < 10; i++) { - float original_ref = static_cast(host_ref_buffer[i]); - float original_test = static_cast(host_test_buffer[i]); - float ref = canonicalize(original_ref); - float test = canonicalize(original_test); - if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) < + auto original_lhs = static_cast(host_lhs[i]); + auto original_rhs = static_cast(host_rhs[i]); + ComparisonType lhs = canonicalize(original_lhs); + ComparisonType rhs = canonicalize(original_rhs); + if (std::isnan(lhs) && std::isnan(rhs)) { + continue; + } + if (std::isinf(lhs) && std::isinf(rhs) && lhs == rhs) { + continue; + } + if (std::isfinite(lhs) != std::isfinite(rhs) || + !(std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1) < kTolerance)) { differences_seen++; - LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs " - << original_test; + LOG(ERROR) << "Difference at " << i << ": " << original_lhs << " vs " + << original_rhs; } } + return differences_seen == 0; +} + +template +static StatusOr CompareEqualParameterized(se::Stream* stream, + se::DeviceMemoryBase lhs, + se::DeviceMemoryBase rhs, + const Shape& shape, + const HloModuleConfig& config, + absl::string_view kernel_name) { + XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual"); + TF_ASSIGN_OR_RETURN( + bool result, + DeviceCompare(stream, lhs, rhs, shape, config, kernel_name)); + + if (result) { + return true; + } + + TF_ASSIGN_OR_RETURN(bool host_return, + (HostCompare(stream, lhs, rhs))); + CHECK(host_return == result) << "Different comparison result on GPU vs host"; return false; } +StatusOr BufferComparator::CompareEqual(se::Stream* stream, + se::DeviceMemoryBase lhs, + se::DeviceMemoryBase rhs) { + switch (shape_.element_type()) { + case xla::F16: + return CompareEqualParameterized( + stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison"); + case xla::F32: + return CompareEqualParameterized( + stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison"); + case xla::F64: + return CompareEqualParameterized( + stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison"); + default: + return Unimplemented("Unimplemented element type"); + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h index bf2ba78ceac..e77dfe02a15 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h @@ -16,53 +16,39 @@ limitations under the License. 
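`BufferComparator::CompareEqual` above dispatches on the element type of `shape_`, handing each case to a routine templated on a storage type and a comparison type; the fp16 path promotes to `float` for the arithmetic, as the `canonicalize` helper suggests. A simplified sketch of that dispatch using non-XLA placeholder types (`ElementKind`, raw host pointers instead of device memory, and only the f32/f64 cases):

```c++
#include <algorithm>
#include <cmath>
#include <cstddef>

enum class ElementKind { kF32, kF64 };

// One templated routine does the per-element work; StorageT is how the data
// is stored, ComparisonT is the precision the comparison runs in.
template <typename StorageT, typename ComparisonT>
bool CompareBuffers(const StorageT* lhs, const StorageT* rhs, std::size_t n,
                    ComparisonT tolerance) {
  for (std::size_t i = 0; i < n; ++i) {
    const ComparisonT a = static_cast<ComparisonT>(lhs[i]);
    const ComparisonT b = static_cast<ComparisonT>(rhs[i]);
    if (std::isnan(a) && std::isnan(b)) continue;  // NaN pairs compare equal
    const ComparisonT rel =
        std::abs(a - b) / (std::max(std::abs(a), std::abs(b)) + 1);
    if (!(rel <= tolerance)) return false;
  }
  return true;
}

bool CompareEqual(ElementKind kind, const void* lhs, const void* rhs,
                  std::size_t n) {
  switch (kind) {
    case ElementKind::kF32:
      return CompareBuffers<float, float>(static_cast<const float*>(lhs),
                                          static_cast<const float*>(rhs), n,
                                          0.1f);
    case ElementKind::kF64:
      return CompareBuffers<double, double>(static_cast<const double*>(lhs),
                                            static_cast<const double*>(rhs), n,
                                            0.1);
  }
  return false;
}
```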
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ -#include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -// A fp16 comparator that internally keeps a reference buffer, and compares it -// against other test buffers. -class F16BufferComparator { +// A device-side comparator that compares buffers. +class BufferComparator { public: - F16BufferComparator(const F16BufferComparator&) = delete; - F16BufferComparator(F16BufferComparator&&) = default; + BufferComparator(const BufferComparator&) = delete; + BufferComparator(BufferComparator&&) = default; - // Creates a new comparator. It internally allocates a buffer initialized by - // ref_buffer. - static StatusOr Create( - se::DeviceMemory ref_buffer, Compiler* compiler, - DeviceMemoryAllocator* allocator, se::Stream* stream); + BufferComparator(const Shape& shape, const HloModuleConfig& config) + : shape_(shape), config_(config) {} - // Returns true if the internally allocated buffer "compares equal" to - // test_buffer. The definition of "equal" is: + // Returns true if the two buffers compare equal. The definition of "equal" + // is: // * All NaNs equal. - // * All infs are treated as 65505 or -65505, so that this checker is tolerant - // to fp16 overflows. + // * All fp16 infs are treated as 65505 or -65505. Otherwise, + // infs and negative infs compare equal. // * With NaNs and infs taken care of, a and b compare equal iff: // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance // // See the implementation for the tolerance value. - StatusOr CompareEqual(se::DeviceMemory test_buffer); + StatusOr CompareEqual(se::Stream* stream, + se::DeviceMemoryBase lhs, + se::DeviceMemoryBase rhs); private: - F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator, - std::unique_ptr exec, - ScopedShapedBuffer ref_buffer) - : stream_(stream), - allocator_(allocator), - exec_(std::move(exec)), - ref_buffer_(std::move(ref_buffer)) {} - - StatusOr CompareEqualImpl(se::DeviceMemory test_buffer); - - se::Stream* stream_; - DeviceMemoryAllocator* allocator_; - std::unique_ptr exec_; - ScopedShapedBuffer ref_buffer_; + Shape shape_; + HloModuleConfig config_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc index 33761d1bd88..4bca6e7dfd2 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -16,9 +16,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include -#include "tensorflow/compiler/xla/service/backend.h" + +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/device_memory.h" namespace xla { namespace gpu { @@ -27,97 +29,170 @@ namespace { class BufferComparatorTest : public testing::Test { protected: BufferComparatorTest() - : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()), - stream_exec_(backend_->default_stream_executor()), - allocator_(stream_exec_->platform(), {stream_exec_}), - compiler_(Compiler::GetForPlatform(stream_exec_->platform()) - .ConsumeValueOrDie()) {} + : platform_( + se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie()), + stream_exec_(platform_->ExecutorForDevice(0).ValueOrDie()) {} - // Take floats only for convenience. Still uses half internally. + // Take floats only for convenience. Still uses ElementType internally. + template bool CompareEqualFloatBuffers(const std::vector& lhs_float, const std::vector& rhs_float) { - std::vector lhs(lhs_float.begin(), lhs_float.end()); - std::vector rhs(rhs_float.begin(), rhs_float.end()); + std::vector lhs(lhs_float.begin(), lhs_float.end()); + std::vector rhs(rhs_float.begin(), rhs_float.end()); se::Stream stream(stream_exec_); stream.Init(); - auto owning_lhs_buffer = - allocator_ - .Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half)) - .ConsumeValueOrDie(); - - auto owning_rhs_buffer = - allocator_ - .Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half)) - .ConsumeValueOrDie(); - - auto lhs_buffer = - se::DeviceMemory(owning_lhs_buffer.AsDeviceMemoryBase()); - auto rhs_buffer = - se::DeviceMemory(owning_rhs_buffer.AsDeviceMemoryBase()); - - stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size()); - stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size()); + se::ScopedDeviceMemory lhs_buffer = + stream_exec_->AllocateOwnedArray(lhs.size()); + se::ScopedDeviceMemory rhs_buffer = + stream_exec_->AllocateOwnedArray(lhs.size()); + stream.ThenMemcpy(lhs_buffer.ptr(), lhs.data(), lhs_buffer->size()); + stream.ThenMemcpy(rhs_buffer.ptr(), rhs.data(), rhs_buffer->size()); TF_CHECK_OK(stream.BlockHostUntilDone()); - return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_, - &stream) - .ConsumeValueOrDie() - .CompareEqual(rhs_buffer) + BufferComparator comparator( + ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), + {static_cast(lhs_buffer->ElementCount())}), + HloModuleConfig()); + return comparator.CompareEqual(&stream, *lhs_buffer, *rhs_buffer) .ConsumeValueOrDie(); } - std::unique_ptr backend_; + se::Platform* platform_; se::StreamExecutor* stream_exec_; - StreamExecutorMemoryAllocator allocator_; - Compiler* compiler_; }; TEST_F(BufferComparatorTest, TestNaNs) { - EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); + EXPECT_TRUE( + CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); // NaN values with different bit patterns should compare equal. 
- EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")})); - EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, + {std::nanf("1234")})); + EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); + + EXPECT_TRUE( + CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); + // NaN values with different bit patterns should compare equal. + EXPECT_TRUE( + CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")})); + EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); + + EXPECT_TRUE( + CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); + // NaN values with different bit patterns should compare equal. + EXPECT_TRUE( + CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")})); + EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); } TEST_F(BufferComparatorTest, TestInfs) { const auto inf = std::numeric_limits::infinity(); - EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); - EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); - EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); - EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); - EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); - EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); - EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); - EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); - EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); - EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); + + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); } TEST_F(BufferComparatorTest, TestNumbers) { - EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); - EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); - EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); - EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); - 
EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); + + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); + + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); } TEST_F(BufferComparatorTest, TestMultiple) { - EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60}, - {20.1, 30.1, 40.1, 50.1, 60.1})); - std::vector lhs(200); - std::vector rhs(200); - for (int i = 0; i < 200; i++) { - EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) - << "should be the same at index " << i; - lhs[i] = 3; - rhs[i] = 5; - EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) - << "should be the different at index " << i; - lhs[i] = 0; - rhs[i] = 0; + { + EXPECT_TRUE(CompareEqualFloatBuffers( + {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } + } + + { + EXPECT_TRUE(CompareEqualFloatBuffers( + {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } + } + + { + EXPECT_TRUE(CompareEqualFloatBuffers( + {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } } } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 0c4980f6549..9ef5f07d857 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -14,21 +14,29 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" + +#include "google/protobuf/any.pb.h" +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/proto/proto_utils.h" namespace xla { namespace gpu { @@ -37,6 +45,7 @@ namespace { using absl::optional; using se::DeviceMemoryBase; using se::dnn::AlgorithmDesc; +using tensorflow::AutotuneResult; std::vector GetAlgorithms(CudnnConvKind kind, se::StreamExecutor* stream_exec) { @@ -72,30 +81,8 @@ string NumBytesToString(int64 bytes) { bytes, "B)"); } -// Acquires a process-global lock on the device pointed to by the given -// StreamExecutor. -// -// This is used to prevent other XLA instances from trying to autotune on this -// device while we're using it. -tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - // se::Platform*s are global singletons guaranteed to live forever. 
- static auto* mutexes = - new std::map, - tensorflow::mutex>(); - - tensorflow::mutex_lock global_lock(mu); - auto it = mutexes - ->emplace(std::piecewise_construct, - std::make_tuple(stream_exec->platform(), - stream_exec->device_ordinal()), - std::make_tuple()) - .first; - return tensorflow::mutex_lock{it->second}; -} - -xla::gpu::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { - xla::gpu::CudnnVersion cudnn_version; +tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { + tensorflow::CudnnVersion cudnn_version; if (auto* dnn = stream_executor->AsDnn()) { StatusOr version_or = dnn->GetVersion(); if (version_or.ok()) { @@ -108,9 +95,9 @@ xla::gpu::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { return cudnn_version; } -xla::gpu::ComputeCapability GetComputeCapability( +tensorflow::ComputeCapability GetComputeCapability( se::StreamExecutor* stream_executor) { - xla::gpu::ComputeCapability cc; + tensorflow::ComputeCapability cc; int cc_major, cc_minor; stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); @@ -119,33 +106,151 @@ xla::gpu::ComputeCapability GetComputeCapability( return cc; } +void PrintPlatformInfo(const se::Stream* stream) { + auto* se = stream->parent(); + const auto& desc = se->GetDeviceDescription(); + LOG(ERROR) << "Device: " << desc.name(); + LOG(ERROR) << "Platform: " << desc.platform_version(); + LOG(ERROR) << "Driver: " << desc.driver_version(); + LOG(ERROR) << "Runtime: " << desc.runtime_version(); + + auto* dnn = se->AsDnn(); + if (dnn) { + auto dnn_version = dnn->GetVersion(); + if (dnn_version.ok()) { + auto v = dnn_version.ValueOrDie(); + LOG(ERROR) << "cudnn version: " << v.major_version() << "." + << v.minor_version() << "." << v.patch(); + } + } +} + +// Returns true if the redzones in `allocator`'s allocations are unmodified. +// +// If the redzones are modified, logs an error, sets the appropriate failure +// bits on `result`, and returns false. +// +// Returns a status if an unexpected error has occurred, and the stream +// has been poisoned. +// +// `name` is a user-friendly name for the set of redzones being checked, e.g. +// "input/output" or "scratch". +StatusOr CheckRedzones(const RedzoneAllocator& allocator, + se::Stream* stream, absl::string_view name, + const HloInstruction* instr, + AutotuneResult* result) { + XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones", + 2); + using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; + + TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check, + allocator.CheckRedzones(stream)); + + if (redzone_check.ok()) { + return true; + } + + auto* fail = result->mutable_failure(); + fail->set_kind(AutotuneResult::REDZONE_MODIFIED); + *fail->mutable_msg() = redzone_check.redzone_failure_msg; + + LOG(ERROR) << absl::StreamFormat( + "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a " + "cudnn bug. We will skip this algorithm in the future, but your GPU " + "state may already be corrupted, leading to incorrect results. Within " + "Google, no action is needed on your part. Outside of Google, please " + "ensure you're running the latest version of cudnn. 
If that doesn't fix " + "the problem, please file a bug with this full error message and we'll " + "contact nvidia.", + name); + LOG(ERROR) << redzone_check.redzone_failure_msg; + LOG(ERROR) << "HloInstruction " << instr->ToString(); + PrintPlatformInfo(stream); + return false; +} + +using ConvCacheKey = + std::tuple, std::string, std::string, int64>; + +struct ConvCacheStats { + int64 cache_hits = 0; + int64 cache_misses = 0; + + void LogStats() { + VLOG(1) << "Cache hits: " << cache_hits; + VLOG(1) << "Cache misses: " << cache_misses; + } +}; + +StatusOr AutotuneCacheKeyfromInstruction( + const HloCustomCallInstruction* conv, se::StreamExecutor* se) { + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + conv->backend_config()); + std::vector operand_shapes; + absl::c_transform(conv->operands(), std::back_inserter(operand_shapes), + [&](const HloInstruction* op) { return op->shape(); }); + + return std::make_tuple( + se, backend_config.SerializeAsString(), conv->custom_call_target(), + conv->shape(), std::move(operand_shapes), + conv->window().SerializeAsString(), + conv->convolution_dimension_numbers().SerializeAsString(), + conv->feature_group_count()); +} + +tensorflow::mutex autotune_cache_lock(tensorflow::LINKER_INITIALIZED); +auto& autotune_cache GUARDED_BY(autotune_cache_lock) = + *new absl::flat_hash_map(); +auto& autotune_cache_stats GUARDED_BY(autotune_cache_lock) = + *new ConvCacheStats(); } // anonymous namespace -// We could have caching here so that we don't redo this work for two identical -// convolutions. Unfortunately our cache key would have to be a tuple -// containing the protos passed to this function, and we have no utility for -// hashing protos. We could write our own hash functions, but they'd silently -// break if we ever added a field to one of the protos. Perhaps we could hack -// using the binary-encoded proto as the hash key, on the assumption that two -// protos being binary-equal is a sufficient, if not necessary, condition for -// proper equality. But that would still leave us open to having unnecessary -// cache misses and doing extra work. Overall, caching doesn't seem worth the -// trouble, but we may want to revisit this if we ever find a model where -// caching would speed up compilation a lot. StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( const HloCustomCallInstruction* instr) { - // TODO(timshen): for now only check fp16. It can be expanded to other types, - // with some work on the HLO routines. - const bool cross_check_enabled = - instr->shape().tuple_shapes(0).element_type() == xla::F16; - // Don't run this function concurrently on the same GPU. // // This is a bit of a hack and doesn't protect us against arbitrary concurrent // use of a GPU, but it's sufficient to let us compile two HLO modules // concurrently and then run them sequentially. + // + // Putting the lock in here rather than in PickBestAlgorithmNoCache lets us + // avoid ever doing duplicate work. If we have a cache miss, only one thread + // will run PickBestAlgorithmImpl for a particular device. tensorflow::mutex_lock lock = LockGpu(stream_exec_); + // We cache the autotuning results to avoid doing the duplicate work, + // which can greatly improve both stability (deterministic numeric results + // within a process for a given input) and performance (2x speedup on some + // models). 
+ TF_ASSIGN_OR_RETURN(ConvCacheKey key, + AutotuneCacheKeyfromInstruction(instr, stream_exec_)); + { + tensorflow::mutex_lock lock(autotune_cache_lock); + auto it = autotune_cache.find(key); + if (it != autotune_cache.end()) { + autotune_cache_stats.cache_hits++; + return it->second; + } + autotune_cache_stats.cache_misses++; + } + + StatusOr result_or = PickBestAlgorithmNoCache(instr); + if (result_or.ok()) { + tensorflow::mutex_lock lock(autotune_cache_lock); + CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second); + } + return result_or; +} + +StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( + const HloCustomCallInstruction* instr) { + XLA_SCOPED_LOGGING_TIMER( + absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithmImpl for ", + instr->ToString())); + + const Shape& result_shape = instr->shape().tuple_shapes(0); + // Make sure any previous activity on this executor is done. We don't want to // interfere with programs that are still running on the GPU. if (!stream_exec_->SynchronizeAllActivity()) { @@ -158,9 +263,9 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( const auto device_ordinal = stream_exec_->device_ordinal(); // allocator either points to this->allocator_ or, if that's null, to a - // StreamExecutorMemoryAllocator for stream_exec_. - DeviceMemoryAllocator* allocator; - optional se_allocator; + // se::StreamExecutorMemoryAllocator for stream_exec_. + se::DeviceMemoryAllocator* allocator; + optional se_allocator; if (allocator_ != nullptr) { allocator = allocator_; } else { @@ -169,42 +274,50 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( allocator = &*se_allocator; } - const auto initialize_buffer = [&stream, cross_check_enabled]( - DeviceMemoryBase buffer) { - if (cross_check_enabled) { - // Broadcast a constant to the buffer, instead of zeroing the buffer. A - // non-zero constant is useful for the cross checking, because zero-inputs - // may not always reveal the bugs. - CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); - size_t left_over_bytes = buffer.size() % 4; - CHECK_EQ(0, left_over_bytes % 2); + const auto initialize_buffer = [&stream, + &result_shape](DeviceMemoryBase buffer) { + constexpr float kBroadcastedConstant = 0.1f; + switch (result_shape.element_type()) { + case xla::F16: { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because + // zero-inputs may not always reveal the bugs. + CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); + size_t left_over_bytes = buffer.size() % 4; + CHECK_EQ(0, left_over_bytes % 2); - constexpr float kBroadcastedConstant = 0.1f; - static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), - Eigen::half(kBroadcastedConstant)}; - uint32 bits; - static_assert(sizeof(bits) == sizeof(halfs), ""); - memcpy(&bits, halfs, sizeof(bits)); + static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), + Eigen::half(kBroadcastedConstant)}; + uint32 bits; + static_assert(sizeof(bits) == sizeof(halfs), ""); + memcpy(&bits, halfs, sizeof(bits)); - size_t aligned_size = buffer.size() / 4 * 4; - stream.ThenMemset32(&buffer, bits, aligned_size); + size_t aligned_size = buffer.size() / 4 * 4; + stream.ThenMemset32(&buffer, bits, aligned_size); - DeviceMemoryBase left_over( - static_cast(buffer.opaque()) + aligned_size, left_over_bytes); - stream.ThenMemcpy(&left_over, halfs, left_over_bytes); - } else { - // Although we don't have evidence this matters, zero out the buffers - // before autotuning. 
It's conceivable that using uninitialized memory as - // the inputs might affect performance if e.g. the inputs contain - // denormals, and this is easy enough. - stream.ThenMemZero(&buffer, buffer.size()); + DeviceMemoryBase left_over( + static_cast<char*>(buffer.opaque()) + aligned_size, + left_over_bytes); + stream.ThenMemcpy(&left_over, halfs, left_over_bytes); + break; + } + case xla::F32: { + uint32 bits; + memcpy(&bits, &kBroadcastedConstant, sizeof(bits)); + stream.ThenMemset32(&buffer, bits, buffer.size()); + break; + } + // TODO(timshen): populate non-zero data for f64. + default: + stream.ThenMemZero(&buffer, buffer.size()); } }; - // Allocate space for the input, filter, and output of the convolution. We - // use a ScratchAllocator for this instead of calling allocator_ directly so - // that our allocations don't leak. - ScratchAllocator input_output_allocator(device_ordinal, allocator); + const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); + + // Allocate space for the input, filter, and output of the convolution. + RedzoneAllocator input_output_allocator(device_ordinal, allocator, + hlo_module_config); std::vector<se::DeviceMemoryBase> operand_buffers; for (const auto* operand : instr->operands()) { TF_ASSIGN_OR_RETURN(auto buffer, @@ -213,24 +326,38 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm( initialize_buffer(buffer); operand_buffers.push_back(buffer); } - TF_ASSIGN_OR_RETURN( - auto result_buffer, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); + TF_ASSIGN_OR_RETURN(auto result_buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(result_shape))); initialize_buffer(result_buffer); TF_ASSIGN_OR_RETURN(auto backend_config, instr->backend_config()); - optional<F16BufferComparator> comparator; + optional<BufferComparator> comparator; // Use the first algorithm that's supported as reference. There isn't a // particular reason to use it, as any algorithm suffices. It doesn't make // this algorithm considered correct, though.
- optional<AlgorithmDesc> first_algorithm; + se::DeviceMemoryBase reference_result_buffer; + AlgorithmDesc first_algorithm; + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); std::vector<AutotuneResult> profile_results; + + const DebugOptions& debug_options = + instr->GetModule()->config().debug_options(); + + const bool crash_on_checking_failure = + debug_options.xla_gpu_crash_on_verification_failures(); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { - ScratchAllocator scratch_allocator(device_ordinal, allocator); + XLA_SCOPED_LOGGING_TIMER_LEVEL( + absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ", + AlgorithmToString(alg)), + 2); + + RedzoneAllocator scratch_allocator(device_ordinal, allocator, + hlo_module_config); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); @@ -243,103 +370,142 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm( RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, &scratch_allocator, &stream, options); + if (!launch_status.ok()) { + continue; + } + + if (!profile_result.is_valid()) { + continue; + } + profile_results.emplace_back(); AutotuneResult& result = profile_results.back(); result.mutable_conv()->set_algorithm(alg.algo_id()); result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled()); - if (!launch_status.ok()) { - result.set_error_string(launch_status.error_message()); + int64 scratch_bytes_used = + scratch_allocator.TotalAllocatedBytesExcludingRedzones(); + result.set_scratch_bytes(scratch_bytes_used); + *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); + + // Check for writes to redzones. + TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear, + CheckRedzones(input_output_allocator, &stream, + "input/output", instr, &result)); + + TF_ASSIGN_OR_RETURN( + bool scratch_allocator_redzone_clear, + CheckRedzones(scratch_allocator, &stream, "scratch", instr, &result)); + + if (!input_output_allocator_redzone_clear || + !scratch_allocator_redzone_clear) { continue; } - if (!profile_result.is_valid()) { - result.set_error_string("Invalid profile result"); - continue; - } - - int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); - result.mutable_success()->set_scratch_bytes(scratch_bytes_used); - *result.mutable_success()->mutable_run_time() = - protobuf_util::ToDurationProto( - absl::Milliseconds(profile_result.elapsed_time_in_ms())); - - const bool crash_on_checking_failure = - instr->GetModule() - ->config() - .debug_options() - .xla_gpu_crash_on_verification_failures(); - if (comparator.has_value()) { + XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2); StatusOr<bool> compare_result = comparator->CompareEqual( - se::DeviceMemory<Eigen::half>(result_buffer)); + &stream, reference_result_buffer, result_buffer); if (!compare_result.ok()) { - LOG(ERROR) << "Unable to compare " - << AlgorithmToString(*first_algorithm) << " against " - << AlgorithmToString(alg) << " for " << instr->ToString() - << ": " << compare_result.status(); + LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm) + << " against " << AlgorithmToString(alg) << " for " + << instr->ToString() << ": " << compare_result.status(); + if (compare_result.status().code() == + tensorflow::error::RESOURCE_EXHAUSTED) { + // Possibly OOM. Propagate the error.
+ return compare_result.status(); + } CHECK(!crash_on_checking_failure); } else if (!compare_result.ValueOrDie()) { - LOG(ERROR) << "Results mismatch between different convolution " - "algorithms. This is likely a bug in convolution, or " - "an excessive loss of precision in convolution. " - << instr->ToString() << " for " - << AlgorithmToString(*first_algorithm) << " vs " - << AlgorithmToString(alg); - CHECK(!crash_on_checking_failure); - auto* failure = result.mutable_reference_conv(); - failure->set_algorithm(first_algorithm->algo_id()); - failure->set_tensor_ops_enabled(first_algorithm->tensor_ops_enabled()); - } - } else if (cross_check_enabled) { - auto comp = F16BufferComparator::Create( - se::DeviceMemory(result_buffer), compiler_, allocator, - &stream); - if (comp.ok()) { - comparator.emplace(comp.ConsumeValueOrDie()); - first_algorithm.emplace(alg); - } else { - LOG(ERROR) << "Fail to initialize buffer comparator: " << comp.status() - << ", instruction: " << instr->ToString(); - CHECK(!crash_on_checking_failure); + LOG(ERROR) + << "Results mismatch between different convolution algorithms. " + "This is likely a bug/unexpected loss of precision in cudnn.\n" + << instr->ToString() << " for " + << AlgorithmToString(first_algorithm) << " vs " + << AlgorithmToString(alg); + PrintPlatformInfo(&stream); + auto* fail = result.mutable_failure(); + fail->set_kind(AutotuneResult::WRONG_RESULT); + auto* reference_conv = fail->mutable_reference_conv(); + reference_conv->set_algorithm(first_algorithm.algo_id()); + reference_conv->set_tensor_ops_enabled( + first_algorithm.tensor_ops_enabled()); } + } else { + XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::Create", 2); + comparator.emplace(result_shape, hlo_module_config); + reference_result_buffer = result_buffer; + TF_ASSIGN_OR_RETURN(result_buffer, + input_output_allocator.AllocateBytes( + &stream, reference_result_buffer.size())); + initialize_buffer(result_buffer); + first_algorithm = alg; } } // Log the autotuning result. { - AutotuneLog log; - *log.mutable_instr()->mutable_instruction() = instr->ToProto(); - for (const auto* op : instr->operands()) { - *log.mutable_instr()->add_operand_shapes() = op->shape().ToProto(); + tensorflow::AutotuningLog log; + { + ConvInstructionLog instr_log; + *instr_log.mutable_instruction() = instr->ToProto(); + for (const auto* op : instr->operands()) { + *instr_log.add_operand_shapes() = op->shape().ToProto(); + } + log.mutable_instr()->PackFrom(instr_log); } for (const auto& profile : profile_results) { *log.add_results() = profile; } *log.mutable_compute_capability() = GetComputeCapability(stream_exec_); *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_); - VLOG(2) << "Autotuning result:\n" << log.DebugString(); - tensorflow::Logger::Singleton()->LogProto(log); + log.set_device_pci_bus_id( + stream_exec_->GetDeviceDescription().pci_bus_id()); + // If we crash on checking failure, we are in a testing/benchmark mode, thus + // print more information instead of logging to the logger. + if (crash_on_checking_failure) { + LOG(INFO) << "Autotuning result: " << log.ShortDebugString(); + } else { + VLOG(2) << "Autotuning result:\n" << log.DebugString(); + tensorflow::Logger::Singleton()->LogProto(log); + } } - auto* profile_results_end = profile_results.data() + profile_results.size(); + // Crash on miscompares and redzone violations if desired. Do this after + // logging the autotuning results, otherwise we won't get any data! 
+ for (const auto& result : profile_results) { + if (result.has_failure()) { + CHECK(!crash_on_checking_failure); + } + } - const AutotuneResult* best_result = std::min_element( - profile_results.data(), profile_results_end, - [](const AutotuneResult& lhs, const AutotuneResult& rhs) { - // The successful one should have a smaller key, since we are doing - // min_element. If they are both unsuccessful, keep the earlier one in - // the vector by comparing pointers. - return std::make_tuple( - !lhs.has_success(), - protobuf_util::FromDurationProto(lhs.success().run_time()), - &lhs) < std::make_tuple(!rhs.has_success(), - protobuf_util::FromDurationProto( - rhs.success().run_time()), - &rhs); + // Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED + // error. + // + // For now, we ignore WRONG_RESULT failures because false-positives are + // possible (e.g. perhaps the reference algorithm is the one that's + // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're + // quite severe and can be detected with high accuracy. + // + // TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs + // in the output of the conv and skip those. + // + // The successful one should have a smaller key, since we are doing + // min_element. If they are both unsuccessful, keep the earlier one in + // the vector by comparing pointers. + auto result_comparison_key = [](const AutotuneResult& r) { + return std::make_tuple( + r.has_failure() && r.failure().kind() != AutotuneResult::WRONG_RESULT, + tensorflow::proto_utils::FromDurationProto(r.run_time())); + }; + const auto& best_result = absl::c_min_element( + profile_results, + [&](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return result_comparison_key(lhs) < result_comparison_key(rhs); }); - if (best_result != profile_results_end && best_result->has_success()) { + if (best_result != profile_results.end() && !best_result->has_failure()) { return *best_result; } @@ -363,7 +529,7 @@ StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( auto best_algo = std::move(best_algo_or).ValueOrDie(); VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.conv().algorithm() << " and " - << NumBytesToString(best_algo.success().scratch_bytes()) + << NumBytesToString(best_algo.scratch_bytes()) << " of scratch memory: " << instr->ToString() << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled(); @@ -372,7 +538,7 @@ StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( HloComputation* computation = instr->parent(); Shape new_call_shape = ShapeUtil::MakeTupleShape( {instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {best_algo.success().scratch_bytes()})}); + ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()})}); TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, instr->backend_config()); @@ -418,11 +584,25 @@ StatusOr CudnnConvAlgorithmPicker::RunOnComputation( } StatusOr CudnnConvAlgorithmPicker::Run(HloModule* module) { + XLA_SCOPED_LOGGING_TIMER("CudnnConvAlgorithmPicker"); + + if (module->config().debug_options().xla_gpu_disable_autotune()) { + VLOG(2) << "Convolution auto-tuning disabled, CudnnConvAlgorithmPicker " + "returning early."; + return false; + } + bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); changed |= result; } + + { + tensorflow::mutex_lock lock(autotune_cache_lock); + autotune_cache_stats.LogStats(); + } + return changed; } diff --git 
a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index 2e34ba96723..9e8a797739a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -19,13 +19,13 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/gpu/autotuning.pb.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/protobuf/autotuning.pb.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace gpu { @@ -38,7 +38,8 @@ class CudnnConvAlgorithmPicker : public HloModulePass { // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. CudnnConvAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator, Compiler* compiler) + se::DeviceMemoryAllocator* allocator, + Compiler* compiler) : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} absl::string_view name() const override { @@ -50,11 +51,13 @@ class CudnnConvAlgorithmPicker : public HloModulePass { private: StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr PickBestAlgorithm( + StatusOr PickBestAlgorithm( + const HloCustomCallInstruction* instr); + StatusOr PickBestAlgorithmNoCache( const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null - DeviceMemoryAllocator* allocator_; // may be null + se::DeviceMemoryAllocator* allocator_; // may be null Compiler* compiler_; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index b628f27f4b2..cd0198e2cb9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -364,7 +364,6 @@ StatusOr GetCudnnConvParams( params.output_buf = operand_buffers[1]; break; case CudnnConvKind::kForwardActivation: { - params.kind = CudnnConvKind::kForwardActivation; params.input_shape = &lhs_shape; params.filter_shape = &rhs_shape; params.output_shape = &conv_result_shape; diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_context.h b/tensorflow/compiler/xla/service/gpu/cusolver_context.h index fdd89c3a8d5..68b5fb14c6b 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_context.h +++ b/tensorflow/compiler/xla/service/gpu/cusolver_context.h @@ -18,8 +18,8 @@ limitations under the License. 
#include -#include "cuda/include/cublas_v2.h" -#include "cuda/include/cusolverDn.h" +#include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cusolverDn.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc index 7861eb1ef04..2ba6e8fc3c5 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.cc @@ -174,9 +174,9 @@ StatusOr CusolverRewriter::RunOnComputation(HloComputation* computation) { const auto device_ordinal = stream_exec_->device_ordinal(); // allocator either points to this->allocator_ or, if that's null, to a - // StreamExecutorMemoryAllocator for stream_exec_. - DeviceMemoryAllocator* allocator; - absl::optional se_allocator; + // se::StreamExecutorMemoryAllocator for stream_exec_. + se::DeviceMemoryAllocator* allocator; + absl::optional se_allocator; if (allocator_ != nullptr) { allocator = allocator_; } else { @@ -200,7 +200,7 @@ StatusOr CusolverRewriter::RunOnComputation(HloComputation* computation) { } CusolverRewriter::CusolverRewriter(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) + se::DeviceMemoryAllocator* allocator) : stream_exec_(stream_exec), allocator_(allocator) {} StatusOr CusolverRewriter::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h index c82233188f7..d8c2cc55872 100644 --- a/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cusolver_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace gpu { @@ -30,7 +30,7 @@ namespace gpu { class CusolverRewriter : public HloModulePass { public: CusolverRewriter(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator); + se::DeviceMemoryAllocator* allocator); absl::string_view name() const override { return "cusolver-rewriter"; } StatusOr Run(HloModule* module) override; @@ -39,7 +39,7 @@ class CusolverRewriter : public HloModulePass { StatusOr RunOnComputation(HloComputation* computation); se::StreamExecutor* stream_exec_; // never null - DeviceMemoryAllocator* allocator_; // may be null + se::DeviceMemoryAllocator* allocator_; // may be null }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_test.cc b/tensorflow/compiler/xla/service/gpu/custom_call_test.cc new file mode 100644 index 00000000000..c04f6fb7bf5 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/custom_call_test.cc @@ -0,0 +1,189 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/includes/cuda_headers/third_party/gpus/cuda/include/driver_types.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class CustomCallTest : public ClientLibraryTestBase {}; + +bool is_invoked_called = false; +void Callback_IsInvoked(CUstream /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/) { + is_invoked_called = true; +} +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, "CUDA"); + +TEST_F(CustomCallTest, IsInvoked) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_IsInvoked", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), + /*opaque=*/""); + EXPECT_FALSE(is_invoked_called); + TF_ASSERT_OK(Execute(&b, {}).status()); + EXPECT_TRUE(is_invoked_called); +} + +TEST_F(CustomCallTest, UnknownTarget) { + XlaBuilder b(TestName()); + CustomCall(&b, "UknownTarget", /*operands=*/{}, ShapeUtil::MakeShape(F32, {}), + /*opaque=*/""); + ASSERT_FALSE(Execute(&b, {}).ok()); +} + +void Callback_Memcpy(CUstream stream, void** buffers, const char* /*opaque*/, + size_t /*opaque_len*/) { + void* src = buffers[0]; + void* dst = buffers[1]; + auto err = cudaMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128, + cudaMemcpyDeviceToDevice, stream); + ASSERT_EQ(err, cudaSuccess); +} +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, "CUDA"); +TEST_F(CustomCallTest, Memcpy) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_Memcpy", + /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})}, + ShapeUtil::MakeShape(F32, {128}), /*opaque=*/""); + TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {})); + EXPECT_THAT(result.data(), ::testing::Each(42)); +} + +// Check that opaque handles nulls within the string. +std::string& kExpectedOpaque = *new std::string("abc\0def", 7); +void Callback_Opaque(CUstream /*stream*/, void** /*buffers*/, + const char* opaque, size_t opaque_len) { + std::string opaque_str(opaque, opaque_len); + ASSERT_EQ(opaque_str, kExpectedOpaque); +} +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, "CUDA"); +TEST_F(CustomCallTest, Opaque) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_Opaque", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {}), kExpectedOpaque); + TF_ASSERT_OK(Execute(&b, {}).status()); +} + +void Callback_SubBuffers(CUstream stream, void** buffers, + const char* /*opaque*/, size_t /*opaque_len*/) { + // `buffers` is a flat array containing device pointers to the following. 
+ // + // 0: root tuple of param 0 + // 1: param 0 at tuple index {0}, shape f32[128] + // 2: param 0 at tuple index {1}, shape f32[256] + // 3: root tuple of param 1 + // 4: param 1 at tuple index {0}, shape f32[1024] + // 5: param 1 at tuple index {1}, shape f32[8] + // 6: root tuple of custom-call result + // 7: result at tuple index {0}, shape f32[8] + // 8: result at tuple index {1}, shape (f32[128], f32[256]) + // 9: result at tuple index {1, 0}, shape f32[128] + // 10: result at tuple index {1, 1}, shape f32[256] + // 11: result at tuple index {2}, shape f32[1024] + // + // It's the contract of custom-call that the non-root pointers (i.e. + // everything other than indices 0, 3, and 6) may be null, if XLA is unable to + // analyze the program well enough to determine for sure what's in those + // buffers. For this simple example, all of the buffers should be non-null. + + // Check the param 0 tuple, namely that + // + // (*buffers[0])[0] == buffers[1] and + // (*buffers[0])[1] == buffers[2]. + // + // because buffers contains pointers to device memory, we have to retrieve + // these values via cudaMemcpy. + void* p0[2]; + cudaMemcpy(p0, buffers[0], 2 * sizeof(void*), cudaMemcpyDeviceToHost); + ASSERT_EQ(p0[0], buffers[1]); + ASSERT_EQ(p0[1], buffers[2]); + + // Check the param 1 tuple, namely that + // + // (*buffers[3])[0] == buffers[4] + // (*buffers[3])[1] == buffers[5]. + void* p1[2]; + cudaMemcpy(p1, buffers[3], 2 * sizeof(void*), cudaMemcpyDeviceToHost); + ASSERT_EQ(p1[0], buffers[4]); + ASSERT_EQ(p1[1], buffers[5]); + + // We don't have an equivalent check for the output tuple (i.e. we don't check + // (*buffers[6])[0] == buffers[7]) because it's up to us to set the tuple + // as part of this custom-call. + + // Write the results. First set the root tuple output buffer to {b7, b8, + // b11}. + void* root[3] = {buffers[7], buffers[8], buffers[11]}; + cudaMemcpy(buffers[6], root, 3 * sizeof(void*), cudaMemcpyHostToDevice); + + // Now set the sub-tuple output buffer at index 8 to {b9, b10}. + void* sub_tuple[2] = {buffers[9], buffers[10]}; + cudaMemcpy(buffers[8], sub_tuple, 2 * sizeof(void*), cudaMemcpyDeviceToHost); + + // Now set output leaf buffers 7, 9, 10, and 11, copying data from the + // corresponding same-sized inputs. 
+ cudaMemcpyAsync(buffers[7], buffers[5], 8 * sizeof(float), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buffers[9], buffers[1], 128 * sizeof(float), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buffers[10], buffers[2], 256 * sizeof(float), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buffers[11], buffers[4], 1024 * sizeof(float), + cudaMemcpyDeviceToDevice, stream); +} +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, "CUDA"); +TEST_F(CustomCallTest, SubBuffers) { + XlaBuilder b(TestName()); + CustomCall(&b, "Callback_SubBuffers", /*operands=*/ + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 1), {128}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + }, + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + }), + /*opaque=*/""); + TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {})); + EXPECT_THAT(result.data({0}), ::testing::Each(4)); + EXPECT_THAT(result.data({1, 0}), ::testing::Each(1)); + EXPECT_THAT(result.data({1, 1}), ::testing::Each(2)); + EXPECT_THAT(result.data({2}), ::testing::Each(3)); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc new file mode 100644 index 00000000000..f0f3152ac98 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -0,0 +1,81 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h" + +#include "absl/strings/str_format.h" +#include "tensorflow/stream_executor/cuda/cuda_stream.h" +#include "tensorflow/stream_executor/gpu/gpu_stream.h" + +namespace xla { +namespace gpu { + +CustomCallThunk::CustomCallThunk( + void* call_target, + std::vector> operand_slices, + ShapeTree result_slices, std::string opaque, + const HloInstruction* instr) + : Thunk(Thunk::kCustomCall, instr), + call_target_(call_target), + operand_slices_(std::move(operand_slices)), + result_slices_(std::move(result_slices)), + opaque_(std::move(opaque)) { + CHECK_EQ(instr->operand_count(), operand_slices_.size()); + for (int64 i = 0; i < instr->operand_count(); ++i) { + const auto& s1 = operand_slices_[i].shape(); + const auto& s2 = instr->operand(i)->shape(); + CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat( + "Shape mismatch between instr->operand(%d) and " + "operand_slices[%d].shape(): %s vs %s", + i, i, s1.ToString(), s2.ToString()); + } + CHECK(ShapeUtil::Equal(instr->shape(), result_slices.shape())) + << absl::StreamFormat( + "Shape mismatch between instr->shape() and result_slices.shape(): " + "%s vs %s.", + instr->shape().ToString(), result_slices.shape().ToString()); +} + +Status CustomCallThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + // gpu_stream is CUstream or e.g. the equivalent type in ROCm. + auto gpu_stream = se::gpu::AsGpuStreamValue(stream); + auto typed_call_target = + reinterpret_cast( + call_target_); + + std::vector buffers; + auto append_buffers = [&](const ShapeTree& slices) { + slices.ForEachElement([&](const ShapeIndex& /*index*/, + const BufferAllocation::Slice& slice) { + if (slice.allocation() == nullptr) { + buffers.push_back(nullptr); + } + buffers.push_back(buffer_allocations.GetDeviceAddress(slice).opaque()); + }); + }; + for (const auto& slices : operand_slices_) { + append_buffers(slices); + } + append_buffers(result_slices_); + + typed_call_target(gpu_stream, buffers.data(), opaque_.data(), opaque_.size()); + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h new file mode 100644 index 00000000000..9011fa26ffa --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_ + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" + +namespace xla { +namespace gpu { + +// Thunk to run a GPU custom call. +// +// This thunk's `ExecuteOnStream` implementation executes a host function +// `call_target` which is expected to enqueue operations onto the GPU. +// +// For information about the calling convention, see xla/g3doc/custom_call.md +// +// Note that not all kCustomCall HLOs in XLA:GPU end up being run by this thunk. +// XLA itself creates kCustomCall instructions when lowering kConvolution HLOs +// into calls to cudnn. These internally-created custom-calls are run using +// ConvolutionThunk, not CustomCallThunk. There's no ambiguity because they +// have special call target names (e.g. "__cudnn$convForward") that only the +// compiler is allowed to create. +class CustomCallThunk : public Thunk { + public: + CustomCallThunk( + void* call_target, + std::vector> operand_slices, + ShapeTree result_slices, std::string opaque, + const HloInstruction* instr); + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream, + HloExecutionProfiler* profiler) override; + + private: + void* call_target_; + std::vector> operand_slices_; + ShapeTree result_slices_; + std::string opaque_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index b024bbe4b5e..ffa60da6f16 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -71,7 +71,6 @@ GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, llvm::IRBuilder<>* b, NestedComputer compute_nested) : ElementalIrEmitter(hlo_module_config, module, b), - hlo_module_config_(hlo_module_config), compute_nested_(std::move(compute_nested)) {} StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( @@ -271,16 +270,6 @@ StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, return FPCast(fast_tanh, value->getType()); } -StatusOr GpuElementalIrEmitter::EmitRoundNearestAfz( - PrimitiveType prim_type, llvm::Value* value) { - // Use libdevice __nv_round instead of llvm.round. This is to workaround a - // bug in the PTX backend, which implements llvm.round with PTX cvt.rni. - // When the llvm.round is fixed, we may still want to use __nv_round here as - // expanding the non-trivial implementation early while inlining allows better - // optimizations. 
- return EmitLibdeviceMathCall("__nv_round", {value}, {prim_type}, prim_type); -} - llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( const string& callee_name, absl::Span operands, absl::Span input_types, PrimitiveType output_type, @@ -299,7 +288,7 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( llvm::Function* callee = llvm::dyn_cast( b_->GetInsertBlock() ->getModule() - ->getOrInsertFunction(llvm_ir::AsStringRef(callee_name), callee_type) + ->getOrInsertFunction(callee_name, callee_type) .getCallee()); for (auto attribute : attributes) { @@ -385,12 +374,12 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); - IrArray::Index input_index(index_type, index.size()); + std::vector input_multi_index(index.size()); llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = NSWSub( + input_multi_index[i] = NSWSub( NSWAdd(stridden_index, NSWMul(window_index[i], index_typed_const( @@ -399,24 +388,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( // We need to verify that we are not in the dilated base area. llvm::Value* dilation_condition = ICmpEQ( - SRem(input_index[i], + SRem(input_multi_index[i], index_typed_const(window.dimensions(i).base_dilation())), index_typed_const(0)); in_bounds = And(in_bounds, dilation_condition); // Apply base dilation to the index. - input_index[i] = - SDiv(input_index[i], + input_multi_index[i] = + SDiv(input_multi_index[i], index_typed_const(window.dimensions(i).base_dilation())); - // We must check whether 0 ≤ input_index[i] < bound, as otherwise - // we are in the pad and so can skip the computation. This + // We must check whether 0 ≤ input_multi_index[i] < bound, as + // otherwise we are in the pad and so can skip the computation. This // comparison is equivalent to the unsigned comparison - // input_index[i] < bound, as a negative value wraps to a large + // input_multi_index[i] < bound, as a negative value wraps to a large // positive value. in_bounds = And(in_bounds, - ICmpULT(input_index[i], + ICmpULT(input_multi_index[i], index_typed_const(operand->shape().dimensions(i)))); } @@ -425,6 +414,8 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(if_data.true_block, b_); // We are not in pad, so do the computation. + IrArray::Index input_index(input_multi_index, operand->shape(), + index_type); TF_ASSIGN_OR_RETURN(llvm::Value * input_value, operand_to_generator.at(operand)(input_index)); TF_ASSIGN_OR_RETURN( @@ -451,19 +442,22 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( b()->CreateStore(init_value, accum_ptr); llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); - IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( - operand->shape(), hlo->dimensions(), "reduction_dim"); + std::vector input_multi_index = + loops.AddLoopsForShapeOnDimensions( + operand->shape(), hlo->dimensions(), "reduction_dim"); if (!ShapeUtil::IsScalar(hlo->shape())) { - // Here only input_index[hlo->dimensions()] are non-null, so we must - // set the rest. + // Here only input_multi_index[hlo->dimensions()] are non-null, so we + // must set the rest. 
size_t j = 0; - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = output_index[j++]; + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = output_index[j++]; } } CHECK_EQ(output_index.size(), j); } + llvm_ir::IrArray::Index input_index( + input_multi_index, hlo->operand(0)->shape(), index_type); SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index e9d08177ad9..466543a2f92 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -91,9 +91,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; - StatusOr EmitRoundNearestAfz(PrimitiveType prim_type, - llvm::Value* value) override; - llvm::Value* EmitThreadId() override; private: @@ -129,7 +126,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const string& callee_name, absl::Span operands, absl::Span input_types, PrimitiveType output_type); - const HloModuleConfig& hlo_module_config_; NestedComputer compute_nested_; }; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index ca4a605af5d..1609f0d60c4 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -29,7 +29,7 @@ namespace xla { namespace gpu { FftScratchAllocator::FftScratchAllocator( - int device_ordinal, DeviceMemoryAllocator* memory_allocator) + int device_ordinal, se::DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { @@ -48,12 +48,12 @@ StatusOr> FftScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer, memory_allocator_->Allocate(device_ordinal_, byte_size, /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + se::DeviceMemoryBase buffer_addr = *allocated_buffer; allocated_buffers_.push_back(std::move(allocated_buffer)); return se::DeviceMemory(buffer_addr); } diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 2be50e08bd2..f653e4f12fe 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -38,7 +38,7 @@ namespace gpu { class FftScratchAllocator : public se::ScratchAllocator { public: FftScratchAllocator(int device_ordinal, - DeviceMemoryAllocator* memory_allocator); + se::DeviceMemoryAllocator* memory_allocator); int64 GetMemoryLimitInBytes(se::Stream* stream) override; @@ -49,8 +49,8 @@ class FftScratchAllocator : public se::ScratchAllocator { private: const int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + se::DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 853a09213b1..4103605df99 100644 --- 
a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -95,27 +96,6 @@ double CalculateBytesReadByFusionInstruction(HloInstruction* fusion) { return bytes; } -// Returns the flops to bytes transferred ratio of instruction 'fusion'. -double CalculateFlopsToBytesRatio(HloInstruction* fusion) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - // Calculate total bytes transferred in/out. - double bytes = CalculateBytesReadByFusionInstruction(fusion); - // Add bytes written to root instructions buffer. - if (fusion->IsMultiOutputFusion()) { - for (auto& operand : fusion->fused_expression_root()->operands()) { - bytes += ShapeUtil::ByteSizeOf(operand->shape()); - } - } else { - bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); - } - // Calculate flops for all fused instructions. Use a null shape size function - // because we don't care about bytes accessed by the ops. - HloCostAnalysis analysis([](const Shape& shape) { return 0; }); - TF_CHECK_OK(fusion->fused_expression_root()->Accept(&analysis)); - // Return flops / bytes. - return bytes > 0.0 ? analysis.flop_count() / bytes : analysis.flop_count(); -} - // Returns bytes transferred by instruction 'fusion', including the bytes // that would be read by all users. double GetCurrentBytesTransferred(HloInstruction* fusion) { @@ -169,8 +149,8 @@ class FusionInstructionMerger { int num_fail_not_loop_fusion_ = 0; int num_fail_merge_all_users_ = 0; int num_fail_expensive_fused_instruction_ = 0; - int num_fail_flops_to_byte_ratio_ = 0; int num_fail_net_bytes_transferred_ratio_ = 0; + int num_fail_inefficient_fusion_emitter_ = 0; TF_DISALLOW_COPY_AND_ASSIGN(FusionInstructionMerger); }; @@ -190,15 +170,13 @@ Status FusionInstructionMerger::Run() { << " not_loop_fusion: " << num_fail_not_loop_fusion_ << " merge_all_users: " << num_fail_merge_all_users_ << " expensive_instruction: " << num_fail_expensive_fused_instruction_ - << " flops_to_byte_ratio: " << num_fail_flops_to_byte_ratio_ << " net_bytes_transferred: " << num_fail_net_bytes_transferred_ratio_ - << " }"; + << " inefficient_fusion_emitter: " + << num_fail_inefficient_fusion_emitter_ << " }"; return Status::OK(); } Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { - VLOG(3) << "FusionInstructionMerger ENTRY fusion: " << fusion->name() - << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion); ++total_visited_; // Skip 'fusion' instruction if there are no users into which we can merge. if (fusion->users().empty()) { @@ -211,7 +189,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // instructions match specific patterns, so they shouldn't be further fused. // Input fusion instructions need to be rooted at a particular HLO (e.g. // kReduce), so they shouldn't be further fused either. 
- if (fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { + if (!fusion->IsLoopFusion()) { VLOG(3) << "Not merging " << fusion->name() << ": Is not loop fusion."; ++num_fail_not_loop_fusion_; return Status::OK(); @@ -228,7 +206,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // computation. if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && - (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + (user->IsLoopFusion() || (IsReduceInputFusion(*user) && LayoutsAreReduceInputFusionFriendly(*fusion, *user))); })) { @@ -256,15 +234,6 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - // Skip 'fusion' instruction if its flops to bytes transferred ratio - // exceeds the threshold value. - if (CalculateFlopsToBytesRatio(fusion) > - FusionMerger::GetThresholdFlopsToBytesRatio()) { - VLOG(3) << "Not merging " << fusion->name() - << ": flops-to-bytes ratio is not favorable."; - ++num_fail_flops_to_byte_ratio_; - return Status::OK(); - } // Skip 'fusion' instruction if merging it into all users would result in a // net increase in bytes transferred (currently allowing the net bytes // transferred to be exceeded up to ~10% in exhange for eliminating the @@ -280,6 +249,23 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { ++num_fail_net_bytes_transferred_ratio_; return Status::OK(); } + + // Skip 'fusion' instruction if merging it into at least one of the users + // would cause too much code duplication because of inefficiencies in the + // fusion emitter. + // TODO(b/119692968): Remove this once the fusion emitter can handle arbitrary + // fusion nodes. + if (absl::c_any_of(fusion->users(), [fusion](const HloInstruction* user) { + return FusedIrEmitter::IsFusedIrEmitterInefficient(/*consumer=*/user, + /*producer=*/fusion); + })) { + VLOG(3) << "Not merging " << fusion->name() + << ": Contains one or more users where fusing would cause " + "inefficiencies in the fusion emitter."; + ++num_fail_inefficient_fusion_emitter_; + return Status::OK(); + } + // Merge fused instructions from 'fusion' into each user. 
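To see why the new `IsFusedIrEmitterInefficient` guard above matters: when a producer fusion is merged into its users, the fused IR emitter re-emits the producer's computation once per use, so a chain in which every level adds two overlapping slices of the previous level roughly doubles the emitted work per level. A purely illustrative back-of-the-envelope sketch of that growth (not part of the pass):

```c++
#include <cstdint>
#include <iostream>

// Illustrative only: if each level of a fused expression re-emits both of its
// slice operands with no reuse, the number of emitted element computations
// grows roughly as 2^levels -- the blow-up the guard above is meant to reject.
int main() {
  std::uint64_t emitted = 1;  // one read of the fusion parameter
  for (int level = 1; level <= 10; ++level) {
    emitted *= 2;  // two uses of the previous level's value, no reuse
    std::cout << "levels=" << level << " ~element computations=" << emitted
              << "\n";
  }
  return 0;
}
```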
std::vector users = fusion->users(); for (HloInstruction* user : users) { @@ -288,7 +274,6 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { } ++total_merged_; VLOG(2) << "Merged fusion instruction: " << fusion->name() - << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion) << " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio << " into users { " << absl::StrJoin(users, ", ", diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index f19996edfe3..a49d68002f8 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -37,8 +37,6 @@ class FusionMerger : public HloModulePass { absl::string_view name() const override { return "fusion merger"; } StatusOr Run(HloModule* module) override; - - static double GetThresholdFlopsToBytesRatio() { return 1.0; } }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 7cc869ed9e8..1d937d51316 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -99,62 +99,6 @@ ENTRY MergeSharedFusionInstruction.Computation0 { EXPECT_EQ(7, operand2->fused_instruction_count()); } -// Tests that we do not merge a fusion instruction that above flops to bytes -// threshold. -// -// Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. -TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - auto module = ParseHloString(R"( -HloModule FlopsToBytesRatioThresholdExceeded - -comp.2 { - state.param_1.1 = (f32[4]{0}, f32[4]{0}) parameter(0) - get-tuple-element.3 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 - get-tuple-element.4 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 - multiply.29 = f32[4]{0} multiply(get-tuple-element.3, get-tuple-element.4) - multiply.30 = f32[4]{0} multiply(get-tuple-element.3, multiply.29) - multiply.31 = f32[4]{0} multiply(get-tuple-element.3, multiply.30) - multiply.32 = f32[4]{0} multiply(get-tuple-element.3, multiply.31) - multiply.33 = f32[4]{0} multiply(get-tuple-element.3, multiply.32) - multiply.34 = f32[4]{0} multiply(get-tuple-element.3, multiply.33) - multiply.35 = f32[4]{0} multiply(get-tuple-element.3, multiply.34) - multiply.36 = f32[4]{0} multiply(get-tuple-element.3, multiply.35) - multiply.37 = f32[4]{0} multiply(get-tuple-element.3, multiply.36) - multiply.38 = f32[4]{0} multiply(get-tuple-element.3, multiply.37) - multiply.39 = f32[4]{0} multiply(get-tuple-element.3, multiply.38) - multiply.40 = f32[4]{0} multiply(get-tuple-element.3, multiply.39) - ROOT multiply.41 = f32[4]{0} multiply(get-tuple-element.3, multiply.40) -} - -comp.1 { - multiply.12.param_1.1 = f32[4]{0} parameter(1) - constant.param_1.3 = f32[4]{0} parameter(0) - add.3 = f32[4]{0} add(multiply.12.param_1.1, constant.param_1.3) - ROOT multiply.16 = f32[4]{0} multiply(add.3, constant.param_1.3) -} - -comp { - multiply.12.param_1 = f32[4]{0} parameter(1) - constant.param_1.1 = f32[4]{0} parameter(0) - multiply.15 = f32[4]{0} multiply(multiply.12.param_1, constant.param_1.1) - ROOT add.2 = f32[4]{0} add(multiply.15, constant.param_1.1) -} - -ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { - constant = f32[4]{0} constant({1, 1, 1, 1}) - state = (f32[4]{0}, f32[4]{0}) parameter(0) - fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 - 
fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 - fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp - ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) -})") - .ValueOrDie(); - // Run fusion merger pass, which should detect that the flops/bytes of the - // shared fusion instruction exceeds the threshold ratio, and therefore - // cannot be merged with other fusion instructions. - EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); -} - // Tests that threshold for bytes transferred if merged is exceeded. // // Fusion2 is not merged because it exceeds the threshold bytes transferred. @@ -257,8 +201,8 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { HloModule m f1_computation { - f1_p0 = f32[10]{0} parameter(0) - ROOT f1_root = f32[10]{0} add(f1_p0, f1_p0) + f1_p0 = f32[32]{0} parameter(0) + ROOT f1_root = f32[32]{0} add(f1_p0, f1_p0) } add_computation { @@ -268,16 +212,16 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { } f2_computation { - f2_p0 = f32[10]{0} parameter(0) - f2_mul = f32[10]{0} multiply(f2_p0, f2_p0) + f2_p0 = f32[32]{0} parameter(0) + f2_mul = f32[32]{0} multiply(f2_p0, f2_p0) f2_zero = f32[] constant(0) ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0}, to_apply=add_computation } ENTRY entry { - p0 = f32[10]{0} parameter(0) - f1 = f32[10]{0} fusion(p0), kind=kLoop, calls=f1_computation + p0 = f32[32]{0} parameter(0) + f1 = f32[32]{0} fusion(p0), kind=kLoop, calls=f1_computation ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation })") .ValueOrDie(); @@ -319,6 +263,62 @@ TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) { EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); } +// TODO(b/119692968): Remove this test once fusion emitter is fixed. 
+TEST_F(FusionMergerTest, WillNotMergeIfFusionEmitterIsInefficient) { + auto module = ParseHloString(R"( + HloModule m + + %fused_computation (param_0.10: f32[6]) -> f32[1] { + %param_0.10 = f32[6]{0} parameter(0) + %add.7 = f32[6]{0} add(%param_0.10, %param_0.10) + %slice.21 = f32[5]{0} slice(%add.7), slice={[0:5]} + %slice.18 = f32[5]{0} slice(%add.7), slice={[1:6]} + %add.5 = f32[5]{0} add(%slice.21, %slice.18) + %slice.15 = f32[4]{0} slice(%add.5), slice={[0:4]} + %slice.12 = f32[4]{0} slice(%add.5), slice={[1:5]} + %add.4 = f32[4]{0} add(%slice.15, %slice.12) + %slice.9 = f32[3]{0} slice(%add.4), slice={[0:3]} + %slice.6 = f32[3]{0} slice(%add.4), slice={[1:4]} + %add.2 = f32[3]{0} add(%slice.9, %slice.6) + %slice.3 = f32[2]{0} slice(%add.2), slice={[0:2]} + %slice.2 = f32[2]{0} slice(%add.2), slice={[1:3]} + %add.1 = f32[2]{0} add(%slice.3, %slice.2) + %slice.1 = f32[1]{0} slice(%add.1), slice={[0:1]} + %slice.0 = f32[1]{0} slice(%add.1), slice={[1:2]} + ROOT %add.0 = f32[1]{0} add(%slice.1, %slice.0) + } + + %fused_computation.1 (param_0.21: f32[11], param_1.21: f32[11]) -> f32[6] { + %param_0.21 = f32[11]{0} parameter(0) + %param_1.21 = f32[11]{0} parameter(1) + %add.16 = f32[11]{0} add(%param_0.21, %param_1.21) + %slice.51 = f32[10]{0} slice(%add.16), slice={[0:10]} + %slice.48 = f32[10]{0} slice(%add.16), slice={[1:11]} + %add.14 = f32[10]{0} add(%slice.51, %slice.48) + %slice.45 = f32[9]{0} slice(%add.14), slice={[0:9]} + %slice.42 = f32[9]{0} slice(%add.14), slice={[1:10]} + %add.13 = f32[9]{0} add(%slice.45, %slice.42) + %slice.39 = f32[8]{0} slice(%add.13), slice={[0:8]} + %slice.36 = f32[8]{0} slice(%add.13), slice={[1:9]} + %add.11 = f32[8]{0} add(%slice.39, %slice.36) + %slice.33 = f32[7]{0} slice(%add.11), slice={[0:7]} + %slice.30 = f32[7]{0} slice(%add.11), slice={[1:8]} + %add.10 = f32[7]{0} add(%slice.33, %slice.30) + %slice.27 = f32[6]{0} slice(%add.10), slice={[0:6]} + %slice.24 = f32[6]{0} slice(%add.10), slice={[1:7]} + ROOT %add.8 = f32[6]{0} add(%slice.27, %slice.24) + } + + ENTRY entry { + p0 = f32[11]{0} parameter(0) + p1 = f32[11]{0} parameter(1) + f1 = f32[6]{0} fusion(p0, p1), kind=kLoop, calls=%fused_computation.1 + ROOT f2 = f32[1] fusion(f1), kind=kLoop, calls=%fused_computation + })") + .ValueOrDie(); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index a7053e6a013..9bbe1ab5a38 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -17,12 +17,14 @@ limitations under the License. #include -#include "absl/strings/str_cat.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory.h" namespace xla { namespace gpu { @@ -32,71 +34,25 @@ namespace { // This struct contains the metadata of a matrix, e.g., its base address and // dimensions. 
struct MatrixDescriptor { - MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose, - int64 matrix_num_rows, int64 matrix_num_cols, - int64 matrix_batch_size) - : data(matrix_data), - transpose(needs_transpose), - num_rows(matrix_num_rows), - num_cols(matrix_num_cols), - batch_size(matrix_batch_size) {} - se::DeviceMemoryBase data; bool transpose; // Whether this matrix needs to be transposed. int64 num_rows; int64 num_cols; - int64 batch_size; }; -// Performs a gemm call without an explicit algorithm on lhs_matrix and -// rhs_matrix, and stores the result to output_matrix. -template -bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, double alpha, double beta, - se::Stream* stream) { - DCHECK(!output_matrix.transpose); +using GemmCacheKey = + std::tuple; - const int64 batch_size = lhs_matrix.batch_size; - CHECK_EQ(batch_size, rhs_matrix.batch_size); - CHECK_EQ(batch_size, output_matrix.batch_size); - se::DeviceMemory lhs_data(lhs_matrix.data); - se::DeviceMemory rhs_data(rhs_matrix.data); - se::DeviceMemory output_data(output_matrix.data); +tensorflow::mutex autotune_cache_mu(tensorflow::LINKER_INITIALIZED); +auto& autotune_cache GUARDED_BY(autotune_cache_mu) = *new absl::flat_hash_map< + GemmCacheKey, absl::optional>(); +int64 cache_hits GUARDED_BY(autotune_cache_mu) = 0; +int64 cache_misses GUARDED_BY(autotune_cache_mu) = 0; - auto lhs_transpose = lhs_matrix.transpose ? se::blas::Transpose::kTranspose - : se::blas::Transpose::kNoTranspose; - auto rhs_transpose = rhs_matrix.transpose ? se::blas::Transpose::kTranspose - : se::blas::Transpose::kNoTranspose; - auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; - - if (batch_size == 1) { - return stream - ->ThenBlasGemm( - lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, - lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta, - &output_data, /*leading dim of output=*/output_matrix.num_rows) - .ok(); - } - - int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols; - int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols; - int64 output_stride = output_matrix.num_rows * output_matrix.num_cols; - return stream - ->ThenBlasGemmStridedBatched( - lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, - /*alpha=*/alpha, lhs_data, - /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, - /*beta=*/beta, &output_data, - /*leading dim of output=*/output_matrix.num_rows, output_stride, - batch_size) - .ok(); -} - -// Like DoGemm, but takes an explicit computation type and algorithm. +// Performs a gemm call on lhs_matrix and rhs_matrix, and stores the result +// to output_matrix. +// // computation_type specifies the type of intermediate values generated during // the matmul (e.g. your input/output matricies could be f16s but you could do // computations with f32s). algorithm is an opaque identifier which functions @@ -111,19 +67,16 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, // the Stream was valid to begin with); check the is_valid property of the // ProfileResult to see whether the call actually succeeded. 
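The autotune cache declared earlier in this hunk is keyed on everything that determines which algorithm wins: element type, the transpose flags and dimensions of both operands, alpha, beta, the computation type, and the StreamExecutor (device), so identical gemms are profiled only once per process. A simplified, self-contained stand-in for that lookup-or-autotune pattern (types deliberately simplified; the real code uses tensorflow::mutex, absl::flat_hash_map, and absl::optional):

```c++
#include <cstdint>
#include <functional>
#include <map>
#include <mutex>
#include <optional>
#include <tuple>

// Simplified stand-in for the cache used in DoGemmAutotune below.
using AlgorithmType = std::int64_t;
using CacheKey = std::tuple<int /*dtype*/, bool, std::int64_t, std::int64_t,  // lhs
                            bool, std::int64_t, std::int64_t,                 // rhs
                            double /*alpha*/, double /*beta*/,
                            int /*computation type*/, const void* /*device*/>;

std::mutex cache_mu;
std::map<CacheKey, std::optional<AlgorithmType>> cache;

std::optional<AlgorithmType> LookupOrAutotune(
    const CacheKey& key,
    const std::function<std::optional<AlgorithmType>()>& autotune) {
  std::lock_guard<std::mutex> lock(cache_mu);
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;  // hit: reuse the earlier result
  auto result = autotune();                  // miss: profile candidates once
  cache.emplace(key, result);                // an empty optional means "generic"
  return result;
}
```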
template -bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, +bool DoGemmWithAlgorithm(int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix, double alpha, double beta, se::blas::ComputationType computation_type, - se::blas::AlgorithmType algorithm, se::Stream* stream, + se::Stream* stream, + absl::optional algorithm, se::blas::ProfileResult* output_profile_result) { DCHECK(!output_matrix.transpose); - CHECK_EQ(1, lhs_matrix.batch_size); - CHECK_EQ(1, rhs_matrix.batch_size); - CHECK_EQ(1, output_matrix.batch_size); - se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -134,17 +87,45 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, : se::blas::Transpose::kNoTranspose; auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; - return stream - ->ThenBlasGemmWithAlgorithm( - lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, - /*alpha=*/static_cast(alpha), lhs_data, - /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, - /*beta=*/static_cast(beta), &output_data, - /*leading dim of output=*/output_matrix.num_rows, computation_type, - algorithm, output_profile_result) - .ok(); + if (algorithm) { + // Autotuning is disabled for batch_size != 1. + CHECK_EQ(1, batch_size); + return stream + ->ThenBlasGemmWithAlgorithm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/static_cast(alpha), lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, + /*beta=*/static_cast(beta), &output_data, + /*leading dim of output=*/output_matrix.num_rows, computation_type, + *algorithm, output_profile_result) + .ok(); + } else if (batch_size == 1) { + return stream + ->ThenBlasGemm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta, + &output_data, /*leading dim of output=*/output_matrix.num_rows) + .ok(); + } else { + int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols; + int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols; + int64 output_stride = output_matrix.num_rows * output_matrix.num_cols; + return stream + ->ThenBlasGemmStridedBatched( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/alpha, lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, + /*beta=*/beta, &output_data, + /*leading dim of output=*/output_matrix.num_rows, output_stride, + batch_size) + .ok(); + } } // Experimentally tries to pick the best algorithm for the given gemm. @@ -154,10 +135,43 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, // than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at // all. 
template -StatusOr DoGemmAutotune( - MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, double alpha, double beta, - se::blas::ComputationType computation_type, se::Stream* stream) { +absl::optional DoUncachedGemmAutotune( + PrimitiveType type, se::blas::ComputationType computation_type, + int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream, + const Shape output_shape, double alpha, double beta, + absl::string_view instr_descr) { + if (!stream->BlockHostUntilDone().ok()) { + VLOG(2) << "Failed to synchronize GPU for autotuning"; + return absl::nullopt; + } + + VLOG(3) << "Starting autotune of GemmThunk " << instr_descr; + + // If the output buffer already contains a bias then autotune into a + // scratch buffer. This avoids overwriting the bias buffer. The scratch + // buffer may contain arbitrary garbage values. + se::DeviceMemoryBase scratch_data = output_matrix.data; + absl::optional> allocated_memory; + if (beta != 0.0) { + se::DeviceMemory out = stream->parent()->AllocateArray( + ShapeUtil::ByteSizeOf(output_shape)); + + if (out.is_null()) { + VLOG(1) << "Allocation failed, using generic algorthm"; + return absl::nullopt; + } + + // Destructor ensures deallocation at the end of the scope. + allocated_memory.emplace(stream->parent(), out); + scratch_data = out; + } + + const MatrixDescriptor scratch_descriptor{scratch_data, + /*needs_transpose=*/false, + output_matrix.num_rows, + output_matrix.num_cols}; + std::vector algorithms; CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms)); @@ -168,9 +182,9 @@ StatusOr DoGemmAutotune( // for all algorithms if we're targeting < sm_50. But because we pass a // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. - CHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, - alpha, beta, computation_type, algorithm, - stream, &profile_result)); + CHECK(DoGemmWithAlgorithm( + /*batch_size=*/1, lhs_matrix, rhs_matrix, scratch_descriptor, alpha, + beta, computation_type, stream, algorithm, &profile_result)); if (profile_result.is_valid()) { VLOG(3) << "cublas gemm algorithm " << algorithm << " took " @@ -185,65 +199,14 @@ StatusOr DoGemmAutotune( } if (best_result.is_valid()) { + VLOG(2) << "Autotune on GemmThunk " << instr_descr + << " successful; best algorithm is " << best_result.algorithm(); return best_result.algorithm(); } - return InternalError( - "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms " - "ran successfully", - stream, algorithms.size()); -} - -// Helper functions to go from a PrimitiveType to a templated version of -// DoGemm/DoGemmWithAlgorithm/DoGemmAutotune. 
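The scratch-buffer logic above exists because gemm computes output = alpha * (lhs x rhs) + beta * output: when beta != 0 the output buffer's existing contents (the bias) are an input, so running several autotuning trials directly on the real output would keep re-accumulating into the bias. A tiny scalar illustration of that effect, purely for intuition (not XLA code):

```c++
#include <cassert>

// gemm semantics: out = alpha * (lhs x rhs) + beta * out. With beta != 0 the
// output is also an input, so repeated trial runs on the same buffer compound.
int main() {
  const double alpha = 1.0, beta = 1.0;
  const double lhs = 2.0, rhs = 3.0;  // stand-ins for 1x1 matrices
  double out = 1.0;                   // pre-existing "bias" in the output buffer

  for (int trial = 0; trial < 3; ++trial) {
    out = alpha * lhs * rhs + beta * out;  // each autotune trial re-runs this
  }
  // One correct run would give 7; three in-place trials give 19 instead,
  // which is why autotuning writes into a scratch buffer when beta != 0.
  assert(out == 19.0);
  return 0;
}
```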
-auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { - switch (type) { - case F16: - return &DoGemm; - case F32: - return &DoGemm; - case F64: - return &DoGemm; - case C64: - return &DoGemm>; - case C128: - return &DoGemm>; - default: - LOG(FATAL) << "Unsupported type."; - } -} -auto GetGemmWithAlgorithmFn(PrimitiveType type) - -> decltype(&DoGemmWithAlgorithm) { - switch (type) { - case F16: - return &DoGemmWithAlgorithm; - case F32: - return &DoGemmWithAlgorithm; - case F64: - return &DoGemmWithAlgorithm; - case C64: - return &DoGemmWithAlgorithm>; - case C128: - return &DoGemmWithAlgorithm>; - default: - LOG(FATAL) << "Unsupported type."; - } -} -auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { - switch (type) { - case F16: - return &DoGemmAutotune; - case F32: - return &DoGemmAutotune; - case F64: - return &DoGemmAutotune; - case C64: - return &DoGemmAutotune>; - case C128: - return &DoGemmAutotune>; - default: - LOG(FATAL) << "Unsupported type."; - } + VLOG(1) << "Unable to autotune cuBLAS gemm on stream " << stream + << " none of the " << algorithms.size() << " ran successfully"; + return absl::nullopt; } // Converts from an XLA PrimitiveType to a blas::ComputationType, which is used @@ -268,6 +231,48 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { } } +template +absl::optional DoGemmAutotune( + int64 batch_size, MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream, + const Shape output_shape, double alpha, double beta, + absl::string_view instr_descr) { + PrimitiveType type = output_shape.element_type(); + se::blas::ComputationType computation_type = GetBlasComputationType(type); + + tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent()); + + GemmCacheKey key = std::make_tuple( + type, lhs_matrix.transpose, lhs_matrix.num_rows, lhs_matrix.num_cols, + rhs_matrix.transpose, rhs_matrix.num_rows, rhs_matrix.num_cols, alpha, + beta, computation_type, stream->parent()); + + tensorflow::mutex_lock cache_lock(autotune_cache_mu); + auto it = autotune_cache.find(key); + int64 autotuning_requests = cache_hits + cache_misses; + if (autotuning_requests && autotuning_requests % 10 == 0) { + VLOG(2) << "Autotuning cache hits/(hits + misses): " << cache_hits << "/" + << autotuning_requests; + } + + if (it != autotune_cache.end()) { + cache_hits++; + VLOG(4) + << "Autotuning cache hit, using algorithm (-1 stands for 'generic'): " + << it->second.value_or(-1); + return it->second; + } + cache_misses++; + VLOG(4) << "Autotuning cache miss"; + + auto result = DoUncachedGemmAutotune( + type, computation_type, batch_size, lhs_matrix, rhs_matrix, output_matrix, + stream, output_shape, alpha, beta, instr_descr); + + CHECK(autotune_cache.emplace(key, result).second); + return result; +} + DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { if (hlo_instruction.opcode() == HloOpcode::kDot) { return hlo_instruction.dot_dimension_numbers(); @@ -288,6 +293,138 @@ DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { return dot->dot_dimension_numbers(); } +template +Status ExecuteOnStreamParameterized( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler, const BufferAllocation::Slice lhs_buffer, + const BufferAllocation::Slice rhs_buffer, + const BufferAllocation::Slice output_buffer, const Shape lhs_shape, + const Shape rhs_shape, const Shape output_shape, + bool implements_whole_instruction, 
const HloInstruction* hlo_instruction, + double alpha, double beta, bool xla_gpu_disable_autotune) { + VLOG(2) << "Executing a GemmThunk"; + + se::DeviceMemoryBase lhs_data = + buffer_allocations.GetDeviceAddress(lhs_buffer); + se::DeviceMemoryBase rhs_data = + buffer_allocations.GetDeviceAddress(rhs_buffer); + se::DeviceMemoryBase output_data = + buffer_allocations.GetDeviceAddress(output_buffer); + + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape.rank()); + + int64 row_dim = dim_nums.lhs_batch_dimensions_size(); + int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; + int64 batch_size = std::accumulate(output_shape.dimensions().begin(), + output_shape.dimensions().end() - 2, 1, + std::multiplies()); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_NE(row_dim, batch_dim); + CHECK_NE(col_dim, batch_dim); + } + + // Verify that the non-batch dimensions are minor-most. This is required for + // efficient access. + for (const auto* shape : {&lhs_shape, &rhs_shape, &output_shape}) { + CHECK_LT(shape->layout().minor_to_major(row_dim), 2); + CHECK_LT(shape->layout().minor_to_major(col_dim), 2); + } + + // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between + // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of + // their layout. Therefore, we should treat dimension 0 as row and dimension 1 + // as column when mapping a matrix Dot to BLAS gemm. + int64 output_num_rows = output_shape.dimensions(row_dim); + int64 output_num_cols = output_shape.dimensions(col_dim); + + // BLAS gemm expects the inputs and the output are in column-major order. + // Therefore, we need to convert dot between row-major matrices to that + // between column-major matrices. The key insight for the conversion is that, + // in linear storage, matrix M in column-major order is identical to the + // transpose of M in row-major order. In other words, + // + // column-major(M) = row-major(M^T). + // + // Leveraging this insight, we can perform dot between row-major matrices as + // follows. + // + // row-major(C) + // = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T) + // = gemm(column-major(B^T), column-major(A^T)) + // = gemm(row-major(B), row-major(A)) + // + // Although we do not modify the content of A and B in linear memory, we + // should use the dimensions of B^T and A^T when calling gemm. For example, + // the leading dimension of the LHS matrix of gemm is the number of rows in + // B^T and thus the number of columns in B. 
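The row-major/column-major argument in the comment above can be checked concretely: feeding the unchanged row-major buffers of B and A (in that order) to a column-major gemm yields a buffer that is exactly row-major C. A small self-contained check, with a naive loop standing in for cuBLAS (helper and variable names are illustrative):

```c++
#include <cassert>
#include <vector>

// Naive column-major gemm standing in for cuBLAS (no transposes, alpha = 1,
// beta = 0): out(m x n) = lhs(m x k) * rhs(k x n); element (i, j) of a
// column-major matrix with `rows` rows lives at data[i + j * rows].
void ColumnMajorGemm(int m, int n, int k, const double* lhs, const double* rhs,
                     double* out) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      double sum = 0.0;
      for (int p = 0; p < k; ++p) sum += lhs[i + p * m] * rhs[p + j * k];
      out[i + j * m] = sum;
    }
  }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C = A * B, a 2x2 matrix.
  std::vector<double> a = {1, 2, 3, 4, 5, 6};     // [[1,2,3],[4,5,6]]
  std::vector<double> b = {7, 8, 9, 10, 11, 12};  // [[7,8],[9,10],[11,12]]
  std::vector<double> c(4);

  // Reinterpreted as column-major, the same buffers are B^T (2x3) and A^T
  // (3x2), so gemm(B, A) produces column-major C^T -- whose memory layout is
  // bit-identical to row-major C. This is the operand swap performed for
  // row-major layouts in the code above.
  ColumnMajorGemm(/*m=*/2, /*n=*/2, /*k=*/3, b.data(), a.data(), c.data());

  // Row-major [[58,64],[139,154]].
  assert((c == std::vector<double>{58, 64, 139, 154}));
  return 0;
}
```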
+ auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape, + bool transpose) -> MatrixDescriptor { + bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0; + bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) != + LayoutUtil::Minor(output_shape.layout(), row_dim); + return MatrixDescriptor{ + data, static_cast(transpose ^ layout_mismatch), + shape.dimensions(row_dim + static_cast(is_row_major)), + shape.dimensions(row_dim + static_cast(!is_row_major))}; + }; + + MatrixDescriptor lhs_matrix = make_descriptor( + lhs_data, lhs_shape, dim_nums.lhs_contracting_dimensions(0) == row_dim); + MatrixDescriptor rhs_matrix = make_descriptor( + rhs_data, rhs_shape, dim_nums.rhs_contracting_dimensions(0) == col_dim); + auto op_profiler = profiler->MakeScopedInstructionProfiler( + implements_whole_instruction ? hlo_instruction : nullptr); + + if (LayoutUtil::Minor(output_shape.layout(), row_dim) != 0) { + std::swap(lhs_matrix, rhs_matrix); + std::swap(output_num_cols, output_num_rows); + } + + const MatrixDescriptor output_matrix{output_data, /*needs_transpose=*/false, + output_num_rows, output_num_cols}; + + // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts + // to autotune this gemm to figure out the best algorithm. + PrimitiveType element_type = output_shape.element_type(); + se::blas::ComputationType computation_type = + GetBlasComputationType(element_type); + + std::string instr_descr = + hlo_instruction != nullptr ? hlo_instruction->ToString() : ""; + + // Try finding the best algorithm by autotuning, or use older Gemm API + // if autotuning is disabled or has failed. + absl::optional best_algorithm; + if (xla_gpu_disable_autotune) { + VLOG(2) << "Autotuning disabled, using generic algorithm"; + } else if (batch_size != 1) { + // TODO(b/112111608): Implement auto tune for batched gemm. + VLOG(2) << "Batch size is non-singular, using generic algorithm"; + } else { + // Autotune may fail for various reasons (e.g. when when CUDA 8 and GPU + // sm_50 or older are used). In that case the returned best_algorithm + // will be an empty optional. 
+ best_algorithm = DoGemmAutotune( + batch_size, lhs_matrix, rhs_matrix, output_matrix, stream, output_shape, + alpha, beta, instr_descr); + } + + bool launch_ok = DoGemmWithAlgorithm( + batch_size, lhs_matrix, rhs_matrix, output_matrix, alpha, beta, + computation_type, stream, best_algorithm, + /*output_profile_result=*/nullptr); + + if (!launch_ok) { + return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); + } + return Status::OK(); +} + } // namespace GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, @@ -311,179 +448,27 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - VLOG(2) << "Executing a GemmThunk"; - - se::DeviceMemoryBase lhs_data = - buffer_allocations.GetDeviceAddress(lhs_buffer_); - se::DeviceMemoryBase rhs_data = - buffer_allocations.GetDeviceAddress(rhs_buffer_); - se::DeviceMemoryBase output_data = - buffer_allocations.GetDeviceAddress(output_buffer_); - - DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); - CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), - dim_nums.rhs_batch_dimensions_size()); - CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank()); - - int64 row_dim = dim_nums.lhs_batch_dimensions_size(); - int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; - int64 batch_size = std::accumulate(output_shape_.dimensions().begin(), - output_shape_.dimensions().end() - 2, 1, - std::multiplies()); - - // Check that the batch dims don't cover the last two dims. - for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { - CHECK_NE(row_dim, batch_dim); - CHECK_NE(col_dim, batch_dim); - } - - // Verify that the non-batch dimensions are minor-most. This is required for - // efficient access. - for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) { - CHECK_LT(shape->layout().minor_to_major(row_dim), 2); - CHECK_LT(shape->layout().minor_to_major(col_dim), 2); - } - - // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between - // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of - // their layout. Therefore, we should treat dimension 0 as row and dimension 1 - // as column when mapping a matrix Dot to BLAS gemm. - int64 output_num_rows = output_shape_.dimensions(row_dim); - int64 output_num_cols = output_shape_.dimensions(col_dim); - - // BLAS gemm expects the inputs and the output are in column-major order. - // Therefore, we need to convert dot between row-major matrices to that - // between column-major matrices. The key insight for the conversion is that, - // in linear storage, matrix M in column-major order is identical to the - // transpose of M in row-major order. In other words, - // - // column-major(M) = row-major(M^T). - // - // Leveraging this insight, we can perform dot between row-major matrices as - // follows. - // - // row-major(C) - // = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T) - // = gemm(column-major(B^T), column-major(A^T)) - // = gemm(row-major(B), row-major(A)) - // - // Although we do not modify the content of A and B in linear memory, we - // should use the dimensions of B^T and A^T when calling gemm. For example, - // the leading dimension of the LHS matrix of gemm is the number of rows in - // B^T and thus the number of columns in B. 
- - auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape, - bool transpose) -> MatrixDescriptor { - bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0; - bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) != - LayoutUtil::Minor(output_shape_.layout(), row_dim); - return MatrixDescriptor( - data, transpose ^ layout_mismatch, - shape.dimensions(row_dim + static_cast(is_row_major)), - shape.dimensions(row_dim + static_cast(!is_row_major)), - batch_size); - }; - - const MatrixDescriptor lhs_descriptor = make_descriptor( - lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim); - const MatrixDescriptor rhs_descriptor = make_descriptor( - rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim); - - // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to - // autotune this gemm to figure out the best algorithm. - auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { - PrimitiveType element_type = output_shape_.element_type(); - se::blas::ComputationType computation_type = - GetBlasComputationType(element_type); - - // TODO(b/112111608): Implement auto tune for batched gemm. - if (batch_size != 1) { - return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - alpha_, beta_, stream); + auto fn = [&]() { + switch (output_shape_.element_type()) { + case F16: + return &ExecuteOnStreamParameterized; + case F32: + return &ExecuteOnStreamParameterized; + case F64: + return &ExecuteOnStreamParameterized; + case C64: + return &ExecuteOnStreamParameterized>; + case C128: + return &ExecuteOnStreamParameterized>; + default: + LOG(FATAL) << "Unsupported type."; } + }(); - auto thunk_name = [&] { - return hlo_instruction() != nullptr ? hlo_instruction()->ToString() - : ""; - }; - - const string& device_name = stream->parent()->GetDeviceDescription().name(); - auto autotune_it = autotune_results_.find(device_name); - if (autotune_it == autotune_results_.end()) { - VLOG(3) << "Starting autotune of GemmThunk " << thunk_name(); - - // If the output buffer already contains a bias then autotune into a - // scratch buffer. This avoids overwriting the bias buffer. The scratch - // buffer may contain arbitrary garbage values. - se::DeviceMemoryBase scratch_data = output_data; - std::unique_ptr> scratch_mem; - if (beta_ != 0.0) { - auto temp_status = stream->AllocateTemporaryArray( - ShapeUtil::ByteSizeOf(output_shape_)); - if (!temp_status.ok()) { - return false; - } - scratch_mem = std::move(temp_status).ValueOrDie(); - scratch_data = scratch_mem->device_memory(); - } - const MatrixDescriptor scratch_descriptor( - scratch_data, false, output_matrix.num_rows, output_matrix.num_cols, - batch_size); - - StatusOr best_algorithm = GetGemmAutotuneFn( - element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_, - beta_, computation_type, stream); - autotune_it = - autotune_results_.insert({device_name, best_algorithm}).first; - - if (autotune_it->second.ok()) { - VLOG(2) << "Autotune on GemmThunk " << thunk_name() - << " successful; best algorithm is " - << best_algorithm.ValueOrDie(); - } else { - VLOG(2) << "Autotune on GemmThunk " << thunk_name() - << " unsuccessful. 
Will use generic gemm."; - } - } - - const StatusOr& best_algorithm = - autotune_it->second; - if (best_algorithm.ok()) { - auto algorithm = best_algorithm.ValueOrDie(); - VLOG(2) << "Using algorithm " << algorithm - << " chosen by autotuning on GemmThunk " << thunk_name(); - return GetGemmWithAlgorithmFn(element_type)( - lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_, - computation_type, algorithm, stream, - /*output_profile_result=*/nullptr); - } - - // Autotune will fail when CUDA 8 and GPU sm_50 or older are used. - // Use the older Gemm API in this case. - return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - alpha_, beta_, stream); - }; - - auto op_profiler = profiler->MakeScopedInstructionProfiler( - implements_whole_instruction_ ? hlo_instruction() : nullptr); - bool launch_ok; - if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) { - launch_ok = launch(lhs_descriptor, rhs_descriptor, - MatrixDescriptor(output_data, false, output_num_rows, - output_num_cols, batch_size), - stream); - } else { - launch_ok = launch(rhs_descriptor, lhs_descriptor, - MatrixDescriptor(output_data, false, output_num_cols, - output_num_rows, batch_size), - stream); - } - - if (!launch_ok) { - return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); - } - return Status::OK(); + return fn(buffer_allocations, stream, profiler, lhs_buffer_, rhs_buffer_, + output_buffer_, lhs_shape_, rhs_shape_, output_shape_, + implements_whole_instruction_, hlo_instruction(), alpha_, beta_, + GetModuleConfig().debug_options().xla_gpu_disable_autotune()); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index cc2d12a39c0..e4f07d04820 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -53,14 +53,6 @@ class GemmThunk : public Thunk { se::Stream* stream, HloExecutionProfiler* profiler) override; - bool WillAutotuneKernel(se::Stream* stream) override { - // We will autotune this kernel if we don't already have a autotune result - // for the stream device. - return autotune_results_.find( - stream->parent()->GetDeviceDescription().name()) == - autotune_results_.end(); - } - private: const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; @@ -74,15 +66,6 @@ class GemmThunk : public Thunk { const double beta_; const bool implements_whole_instruction_; - - // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune - // results. The map's value is the best algorithm we've found for this thunk - // on this device, or an error if none of the algorithms worked and we should - // use the regular gemm without an algorithm. - // - // TODO(b/112415150): Make this thread safe. - std::unordered_map> - autotune_results_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto new file mode 100644 index 00000000000..ec4f6e9c913 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto @@ -0,0 +1,13 @@ +// This is used for convolution logging. 
Also see +// tensorflow/core/protobuf/autotuing.h +syntax = "proto3"; + +package xla.gpu; + +import "tensorflow/compiler/xla/service/hlo.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + +message ConvInstructionLog { + xla.HloInstructionProto instruction = 1; + repeated xla.ShapeProto operand_shapes = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 434060ad89d..dec40c5e49c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" namespace xla { namespace gpu { @@ -98,14 +99,12 @@ Status GpuExecutable::ExecuteThunks( sub_streams, hlo_module_->entry_computation()); uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // This top-level trace serves two purposes: - // 1) It marks the scope of the whole XLA module. - // 2) It tells us whether tracing is enabled. We use this to avoid the - // expensive HloInstruction::ToString() calls inside the loop below if - // tracing is disabled. - ScopedAnnotation top_level_annotation(hlo_module_->name(), "XLA GPU module"); + tensorflow::profiler::TraceMe hlo_module_activity( + [&] { return absl::StrCat(hlo_module_->name(), ":XLA GPU module"); }, + tensorflow::profiler::TraceMeLevel::kInfo); std::map> thunk_to_finish_event; + bool scoped_annotation_enabled = ScopedAnnotation::IsEnabled(); for (Thunk* thunk : thunk_schedule_->TotalOrder()) { // Annotate execution of this op if tracing was enabled when we started // running this module. If tracing is enabled *while* we're running the @@ -114,12 +113,13 @@ Status GpuExecutable::ExecuteThunks( // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), // since we expect it to be an expensive call? absl::optional op_annotation; - if (top_level_annotation.IsEnabled()) { + CHECK(thunk->hlo_instruction()); + if (scoped_annotation_enabled) { + auto hlo = thunk->hlo_instruction(); op_annotation.emplace( - thunk->hlo_instruction() != nullptr - ? thunk->hlo_instruction()->ToString(HloPrintOptions::Canonical()) - : "", - "XLA op"); + thunk->hlo_instruction()->ToString(HloPrintOptions::Canonical()), + absl::StrCat("#tf_op=", hlo->metadata().op_name(), + ",hlo_op=", hlo->name(), "#")); } TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); @@ -132,13 +132,6 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - // If this thunk is about to autotune then wait for all currently executing - // thunks to finish. This reduces noise and thus the probability of - // choosing a suboptimal algorithm. 
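The ConvInstructionLog message added in gpu_autotuning.proto earlier in this diff pairs an HLO instruction with its operand shapes for convolution autotuning logs. A hypothetical sketch of filling one in, assuming the generated proto header path follows the .proto file name and using the existing HloInstruction::ToProto() and Shape::ToProto() helpers:

```c++
// Hypothetical: build a ConvInstructionLog entry for a convolution custom
// call `instr`. The generated header path is assumed from the .proto name.
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {
namespace gpu {

ConvInstructionLog MakeConvInstructionLog(const HloInstruction* instr) {
  ConvInstructionLog log;
  *log.mutable_instruction() = instr->ToProto();
  for (const HloInstruction* operand : instr->operands()) {
    *log.add_operand_shapes() = operand->shape().ToProto();
  }
  return log;
}

}  // namespace gpu
}  // namespace xla
```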
- if (thunk->WillAutotuneKernel(stream)) { - TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); - } - VLOG(2) << "Executing the thunk for " << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; @@ -233,11 +226,11 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { return &module_globals_.emplace(executor, std::move(globals)).first->second; } -StatusOr GpuExecutable::ExecuteOnStream( +StatusOr GpuExecutable::Execute( const ServiceExecutableRunOptions* run_options, absl::Span arguments, - HloExecutionProfile* hlo_execution_profile) { - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + HloExecutionProfile* hlo_execution_profile, bool block_host_until_done) { + se::DeviceMemoryAllocator* memory_allocator = run_options->allocator(); if (GetRootPointsToSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); @@ -279,8 +272,6 @@ StatusOr GpuExecutable::ExecuteOnStream( buffer_allocations_builder.Build( assignment_.get(), executor->device_ordinal(), memory_allocator)); - bool block_host_until_done = - !memory_allocator->AllowsAsynchronousDeallocation(); TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations, block_host_until_done, hlo_execution_profile)); @@ -346,12 +337,22 @@ StatusOr GpuExecutable::ExecuteOnStream( return std::move(shaped_buffer); } +StatusOr GpuExecutable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + absl::Span arguments, + HloExecutionProfile* hlo_execution_profile) { + return Execute(run_options, arguments, hlo_execution_profile, + /*block_host_until_done=*/true); +} + StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments) { - // TODO(b/30671675): Implement asynchronous execution mode. - return Unimplemented( - "Asynchronous execution on stream is not yet supported on GPU."); + se::DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + // Force synchronous execution if the allocator requires it. + bool block_host_until_done = + !memory_allocator->AllowsAsynchronousDeallocation(); + return Execute(run_options, arguments, nullptr, block_host_until_done); } const PointsToSet& GpuExecutable::GetRootPointsToSet() const { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 2b3c77f5b82..b1f63bc672e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -24,7 +24,6 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" @@ -38,6 +37,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace gpu { @@ -86,6 +86,11 @@ class GpuExecutable : public Executable { absl::Span arguments) override; private: + StatusOr Execute( + const ServiceExecutableRunOptions* run_options, + absl::Span arguments, + HloExecutionProfile* hlo_execution_profile, bool block_host_until_done); + // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for // clients, such as Tensorflow, that use a single stream of execution for diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 842ba2fdcd3..d5b351f69e3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -15,7 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include +#include + #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" namespace xla { namespace gpu { @@ -57,8 +64,8 @@ bool IsReduceInputFusion(const HloInstruction& instr) { if (instr.IsMultiOutputFusion()) { for (const HloInstruction* operand : instr.fused_expression_root()->operands()) { - if (IsReductionToVector(*operand)) { - CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + if (IsReductionFromOrToContiguousDimensions(*operand)) { + CHECK(instr.IsInputFusion()) << " Multi-output fusion rooted at reduction-to-vector ops must be " "of kind kInput: " << instr.ToString(); @@ -66,8 +73,9 @@ bool IsReduceInputFusion(const HloInstruction& instr) { } } } else if (instr.opcode() == HloOpcode::kFusion && - IsReductionToVector(*instr.fused_expression_root())) { - CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + IsReductionFromOrToContiguousDimensions( + *instr.fused_expression_root())) { + CHECK(instr.IsInputFusion()) << " Fusion rooted at reduction-to-vector op must be of kind kInput: " << instr.ToString(); return true; @@ -76,7 +84,13 @@ bool IsReduceInputFusion(const HloInstruction& instr) { } bool IsInputFusibleReduction(const HloInstruction& instr) { - return IsReduceInputFusion(instr) || IsReductionToVector(instr); + // TODO(b/129089333): Don't fuse variadic reduce. + if (instr.opcode() == HloOpcode::kReduce && instr.shape().IsTuple()) { + return false; + } + + return IsReduceInputFusion(instr) || + IsReductionFromOrToContiguousDimensions(instr); } bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, @@ -91,7 +105,7 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, // If possible, we want to pick a reduction-to-vector operand of the // fusion root, because it has the most constraints. 
for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionToVector(*inst)) { + if (IsReductionFromOrToContiguousDimensions(*inst)) { return inst; } } @@ -107,7 +121,7 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, auto get_loop_shape = [&](const HloInstruction* element_instr) { // Special-case reduction-to-vector ops: The loop dimensions are determined // by the shape of the first operand. - if (IsReductionToVector(*element_instr)) { + if (IsReductionFromOrToContiguousDimensions(*element_instr)) { return element_instr->operand(0)->shape(); } return element_instr->shape(); @@ -120,7 +134,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, auto* instr_1 = get_real_hero(&instr1); auto* instr_2 = get_real_hero(&instr2); // TODO(tjoerg): Relax the shape constraint. The datatype does not matter. - if (IsReductionToVector(*instr_1) && IsReductionToVector(*instr_2) && + if (IsReductionFromOrToContiguousDimensions(*instr_1) && + IsReductionFromOrToContiguousDimensions(*instr_2) && (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) || instr_1->dimensions() != instr_2->dimensions())) { return false; @@ -131,5 +146,65 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, get_loop_shape(instr_2)); } +bool IsInputFusibleScatter(const HloInstruction& instr) { + if (instr.opcode() == HloOpcode::kScatter || + (instr.opcode() == HloOpcode::kFusion && + instr.fusion_kind() == HloInstruction::FusionKind::kInput && + instr.fused_expression_root()->opcode() == HloOpcode::kScatter)) { + return true; + } + return false; +} + +bool IsInputFusible(const HloInstruction& instr) { + // Input fusion only handles non-elemental reduction and scatter operations. + return instr.IsFusible() && + (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr)); +} + +bool IsLoopFusible(const HloInstruction& instr) { + // Don't fuse get-tuple-element on GPU: We can, but it's slower than not + // fusing. We never generate kernels for unfused GTEs. Instead, if an + // unfused GTE is an input to a kernel (including a fusion kernel), we + // compute the address of the GTE at the top of the kernel. Often we know the + // address of the GTE result statically, so we can do this without chasing any + // pointers. + return instr.IsFusible() && + ((instr.IsElementwise() && instr.operand_count() > 0) || + instr.opcode() == HloOpcode::kBitcast || + instr.opcode() == HloOpcode::kBroadcast || + instr.opcode() == HloOpcode::kConcatenate || + instr.opcode() == HloOpcode::kDynamicSlice || + instr.opcode() == HloOpcode::kDynamicUpdateSlice || + (instr.opcode() == HloOpcode::kFusion && + instr.fusion_kind() == HloInstruction::FusionKind::kLoop) || + instr.opcode() == HloOpcode::kGather || + instr.opcode() == HloOpcode::kIota || + instr.opcode() == HloOpcode::kPad || + (instr.opcode() == HloOpcode::kReduce && + !IsReductionFromOrToContiguousDimensions(instr)) || + instr.opcode() == HloOpcode::kReduceWindow || + instr.opcode() == HloOpcode::kReshape || + instr.opcode() == HloOpcode::kReverse || + instr.opcode() == HloOpcode::kSlice || + instr.opcode() == HloOpcode::kTranspose); +} + +bool IsFusible(const HloInstruction& instr) { + return IsInputFusible(instr) || IsLoopFusible(instr); +} + +bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) { + // We can fuse reduces and loop fusions. Elementwise instructions can be fused + // with any other instruction. 
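The hunk above introduces the fusibility predicates the GPU fusion passes now share: IsInputFusible covers fusions rooted at a contiguous-dimension reduction or a scatter, IsLoopFusible covers elementwise ops plus an explicit allow-list of data-movement ops, and IsFusible is simply their disjunction. A compact standalone sketch of that structure, with a hypothetical Op enum standing in for HloOpcode (the real predicates inspect the HloInstruction itself, e.g. its fusion kind and elementwise-ness):

// Hypothetical opcode enum; illustrative only.
enum class Op { kAdd, kBroadcast, kReshape, kTranspose, kReduce, kScatter, kDot };

// "Input" fusions are rooted at a non-elemental reduction or a scatter.
bool IsInputFusibleSketch(Op op) { return op == Op::kReduce || op == Op::kScatter; }

// "Loop" fusions accept elementwise ops and a fixed allow-list of
// data-movement ops; anything not listed (e.g. kDot) stays unfused.
bool IsLoopFusibleSketch(Op op) {
  switch (op) {
    case Op::kAdd:        // elementwise
    case Op::kBroadcast:
    case Op::kReshape:
    case Op::kTranspose:
      return true;
    default:
      return false;
  }
}

// Mirrors IsFusible(): a candidate is fusible if either kind of fusion can
// absorb it.
bool IsFusibleSketch(Op op) {
  return IsInputFusibleSketch(op) || IsLoopFusibleSketch(op);
}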
+ // Note that scatter cannot be the root of a multi-output fusion because + // its emitter doesn't support it. + + return instr.IsFusible() && + (IsInputFusibleReduction(instr) || + instr.IsLoopFusion() || // TODO(b/130013493): Use IsLoopFusible here. + instr.IsElementwise()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index 9f0de3f794d..a4501fd31dc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -24,6 +24,15 @@ limitations under the License. namespace xla { namespace gpu { +// Whether 'instr' can occur inside fusions, i.e. whether it is a candidate +// for being fused. Note that further restrictions apply, e.g. Scatter must +// be the root of an input fusion. +bool IsFusible(const HloInstruction& instr); + +bool IsInputFusible(const HloInstruction& instr); + +bool IsLoopFusible(const HloInstruction& instr); + // The code emitted for reduce-rooted input fusions (EmitReductionToVector) // suffers from poor data locality if the layouts of input parameters differ. In // such situtations it is better not to fuse. Only input params with @@ -46,6 +55,10 @@ bool IsReduceInputFusion(const HloInstruction& instr); // is either an unfused reduction-to-vector op or a reduce input fusion. bool IsInputFusibleReduction(const HloInstruction& instr); +// Whether `instr` is fusible as root of a scatter input fusions, i.e. `instr` +// is either an unfused scatter op or a scatter input fusion. +bool IsInputFusibleScatter(const HloInstruction& instr); + // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. // This function works for both, sibling and producer-consumer multi-output @@ -56,6 +69,10 @@ bool IsInputFusibleReduction(const HloInstruction& instr); bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, const HloInstruction& instr2); +// Whether `instr` is a candidate for sibling fusion or as a consumer in +// a producer-consumer multi-output fusion. 
+bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index 15d4ee206ce..cee678e2902 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -62,7 +62,7 @@ TEST_F(GpuFusibleTest, copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) } fused_reduce { @@ -122,7 +122,7 @@ TEST_F(GpuFusibleTest, p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast) + greater-than = pred[128,1024,32,32]{3,2,1,0} compare(p1.1, broadcast), direction=GT select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast) ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select) } @@ -201,7 +201,8 @@ TEST_F(GpuFusibleTest, IsReduceInputFusion_ElementalReduction) { c0 = f32[] parameter(0) p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1) // Reduction lowered by GpuElementalIrEmitter. - ROOT reduce = f32[8,512,5,1,1]{4,3,2,1,0} reduce(p1, c0), dimensions={3}, to_apply=scalar_add + ROOT reduce = f32[512,5,1,1]{3,2,1,0} reduce(p1, c0), dimensions={3,0}, + to_apply=scalar_add })")) .ValueOrDie(); SCOPED_TRACE(module->ToString()); @@ -468,11 +469,12 @@ TEST_F(GpuFusibleTest, TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = f32[32,32,32]{2,1,0} parameter(0) c0 = f32[] constant(0) - exp = f32[2,2,2]{2,1,0} exponential(p0) - reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add - ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + exp = f32[32,32,32]{2,1,0} exponential(p0) + reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2}, + to_apply=scalar_add + ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp) })")) .ValueOrDie(); const HloInstruction* reduce = @@ -507,7 +509,7 @@ TEST_F(GpuFusibleTest, p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT p0.1 = f32[2,2,2]{2,1,0} parameter(0) ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) } @@ -572,24 +574,28 @@ TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_DifferentReduceDimensions) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduce_1 { - p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p0.1 = f32[32,32,32]{2,1,0} parameter(0) c0 = f32[] constant(0) - ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add + ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32]{2,1,0} p0.1, f32[] c0), + dimensions={0}, 
to_apply=scalar_add } fused_reduce_2 { - p0.2 = f32[2,2,2]{2,1,0} parameter(0) - mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + p0.2 = f32[32,32,32]{2,1,0} parameter(0) + mul = f32[32,32,32]{2,1,0} multiply(f32[32,32,32]{2,1,0} p0.2, + f32[32,32,32]{2,1,0} p0.2) c1 = f32[] constant(0) - ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={2}, to_apply=scalar_add + ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32]{2,1,0} mul, f32[] c1), + dimensions={2}, to_apply=scalar_add } ENTRY reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) - p1 = f32[2,2,2]{2,1,0} parameter(1) - reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1 - reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2 - ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2) + p0 = f32[32,32,32]{2,1,0} parameter(0) + p1 = f32[32,32,32]{2,1,0} parameter(1) + reduce_1 = f32[32,32]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1 + reduce_2 = f32[32,32]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2 + ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) + tuple(reduce_1, reduce_2) })")) .ValueOrDie(); const HloInstruction* fusion_1 = @@ -603,25 +609,31 @@ TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_NoReductionToVector) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_element_wise { - p0.1 = f32[2,2,2]{2,1,0} parameter(0) - p1.1 = f32[2,2,2]{2,1,0} parameter(1) - ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + p0.1 = f32[32,32,32]{2,1,0} parameter(0) + p1.1 = f32[32,32,32]{2,1,0} parameter(1) + ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1) } fused_reduce { - p0.2 = f32[2,2,2]{2,1,0} parameter(0) - mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + p0.2 = f32[32,32,32]{2,1,0} parameter(0) + mul = f32[32,32,32]{2,1,0} multiply(f32[32,32,32]{2,1,0} p0.2, + f32[32,32,32]{2,1,0} p0.2) + broadcast = f32[32,32,32,32]{3,2,1,0} broadcast(mul), dimensions={3,2,1} c1 = f32[] constant(0) // Note that reduce is not a reduction-to-vector. - ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add + ROOT reduce = f32[32,32]{1,0} reduce(f32[32,32,32,32]{3,2,1,0} broadcast, + f32[] c1), dimensions={1,3}, to_apply=scalar_add } ENTRY reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) - p1 = f32[2,2,2]{2,1,0} parameter(1) - element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise - fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce - ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise) + p0 = f32[32,32,32]{2,1,0} parameter(0) + p1 = f32[32,32,32]{2,1,0} parameter(1) + element_wise = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, + calls=fused_element_wise + fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(element_wise), + kind=kLoop, calls=fused_reduce + ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) + tuple(fusion, element_wise) })")) .ValueOrDie(); const HloInstruction* fusion_1 = diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index 9c64b4d10c9..8b19769a781 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -20,7 +20,7 @@ limitations under the License. 
namespace xla { -// his pass should run early in the HLO pipeline and checks for HLO constructs +// This pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the GPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). class GpuHloSupportChecker : public HloModulePass { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index a6d80f0b6dd..0dac5d734bc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -57,9 +57,11 @@ HeuristicLayoutAssignment(const HloInstruction* instr, std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput, DataLayout::kBatchYXDepth); - // If we're not Volta or not fp16, the decision is easy: Use NCHW. - if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && - IsVoltaOrLater(*stream_executor))) { + // If we're not Volta or not fp16, or not conv2D, the decision is easy: Use + // NCHW. + if (instr->operand(0)->shape().element_type() != xla::PrimitiveType::F16 || + !IsVoltaOrLater(*stream_executor) || + instr->shape().tuple_shapes(0).dimensions_size() != 4) { return kAllNCHW; } @@ -214,6 +216,16 @@ Status GpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(op1_shape, instruction, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); + } else if (instruction->opcode() == HloOpcode::kFft) { + // cuFFT requires a dim0 major layout. + Shape op0_shape = instruction->operand(0)->shape(); + LayoutUtil::SetToDefaultLayout(&op0_shape); + Shape output_shape = instruction->shape(); + LayoutUtil::SetToDefaultLayout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); } else if (instruction->opcode() == HloOpcode::kSort && instruction->operand(0)->shape().rank() > 1) { // Make sure that all the operands and the output(s) have the same layout. 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 391029e5746..d9453aff69c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -374,7 +374,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) { p.0.rhs = f32[] parameter(1) p.1.lhs = f32[] parameter(2) p.1.rhs = f32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort { @@ -402,6 +402,35 @@ TEST_F(LayoutAssignmentTest, SortLayout) { op::ShapeWithLayout(expected_shape))); } +TEST_F(LayoutAssignmentTest, FftLayout) { + const char* hlo_text = R"( + HloModule Fft_module + + ENTRY Fft { + input = c64[8,32]{0,1} parameter(0) + fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32} + ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape expected_shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::Transpose(op::ShapeWithLayout(expected_shape)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Copy(op::Transpose(op::Fft(op::ShapeWithLayout(expected_shape))))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 69aaaceca11..6e414bd7a4d 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -38,17 +38,21 @@ using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( absl::Span io_hlos, absl::Span non_io_hlos) { - // I/O HLOs are bound to the arguments of the current IR function. I.e., + // I/O HLOs are bound to the arguments of the current IR function, + // *excluding* the output argument, which is added to non-I/O HLOs. + // I.e., // - // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { + // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg, temp_buffer_base) { llvm::Function* function = b_->GetInsertBlock()->getParent(); - CHECK_EQ(io_hlos.size() + 1, function->arg_size()); + CHECK_EQ(io_hlos.size() + 2, function->arg_size()); // An HLO can have duplicated operands. This data structure remembers which // operand HLOs are already bound to avoid rebinding the same HLO. absl::flat_hash_set already_bound_for_this_function; auto arg_iter = function->arg_begin(); for (const HloInstruction* io_hlo : io_hlos) { + CHECK(!absl::c_count(non_io_hlos, io_hlo)) + << "IO HLOs and non-IO HLOs should be disjoint"; if (!already_bound_for_this_function.contains(io_hlo)) { if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter)); @@ -60,6 +64,10 @@ void HloToIrBindings::EmitBasePointersForHlos( ++arg_iter; } + // Name and skip the output parameter. 
+ arg_iter->setName("output_arg"); + ++arg_iter; + temp_buffer_base_ = &*arg_iter; temp_buffer_base_->setName("temp_buffer"); @@ -113,10 +121,9 @@ void HloToIrBindings::EmitBasePointersForHlos( BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), index); } else if (slice.allocation()->is_constant()) { - llvm::Value* global_for_constant = - module_->getGlobalVariable(llvm_ir::AsStringRef( - llvm_ir::ConstantBufferAllocationToGlobalName( - *slice.allocation()))); + llvm::Value* global_for_constant = module_->getGlobalVariable( + llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation())); BindHloToIrValue(*non_io_hlo, global_for_constant); } else { const int64 offset = slice.offset(); @@ -136,11 +143,11 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_); + EmitGetTupleElement(gte->operand(0), base_ptr), b_); } // Returns true if `value` has a name that should not be changed. @@ -166,11 +173,10 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo()); } if (!HasMeaningfulName(ir_value)) { - ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); + ir_value->setName(llvm_ir::IrName(&hlo, "raw")); } if (!HasMeaningfulName(typed_ir_value)) { - typed_ir_value->setName( - llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); + typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed")); } return typed_ir_value; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index f07141029cb..54cab21ab4c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -28,31 +29,6 @@ namespace gpu { namespace { -bool IsFusible(const HloInstruction& hlo) { - // Don't fuse get-tuple-element on GPU: We can, but it's slower than not - // fusing. We never generate kernels for unfused GTEs. Instead, if an - // unfused GTE is an input to a kernel (including a fusion kernel), we - // compute the address of the GTE at the top of the kernel. Often we know the - // address of the GTE result statically, so we can do this without chasing any - // pointers. 
- return (hlo.IsElementwise() && hlo.operand_count() > 0) || - hlo.opcode() == HloOpcode::kBitcast || - hlo.opcode() == HloOpcode::kBroadcast || - hlo.opcode() == HloOpcode::kConcatenate || - hlo.opcode() == HloOpcode::kDynamicSlice || - hlo.opcode() == HloOpcode::kDynamicUpdateSlice || - hlo.opcode() == HloOpcode::kFusion || - hlo.opcode() == HloOpcode::kGather || - hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || - hlo.opcode() == HloOpcode::kReduce || - hlo.opcode() == HloOpcode::kReduceWindow || - hlo.opcode() == HloOpcode::kReshape || - hlo.opcode() == HloOpcode::kReverse || - hlo.opcode() == HloOpcode::kScatter || - hlo.opcode() == HloOpcode::kSlice || - hlo.opcode() == HloOpcode::kTranspose; -} - bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { if (constant->opcode() != HloOpcode::kConstant || !ShapeUtil::IsScalar(constant->shape())) { @@ -138,20 +114,16 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion; } -bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, - int64 operand_index) { +bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, + int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (producer->opcode() == HloOpcode::kDot || - (producer->opcode() == HloOpcode::kFusion && - producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { + if (producer->opcode() == HloOpcode::kDot && ImplementedAsGemm(*producer)) { int64 other_operand_index = 1 - operand_index; HloInstruction* op1 = nullptr; HloInstruction* op2 = nullptr; - if (consumer->operand_count() == 1 && - consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && + if (consumer->operand_count() == 1 && consumer->IsLoopFusion() && Match(consumer->fused_expression_root(), match::Op() .WithOpcode(HloOpcode::kMultiply) @@ -190,9 +162,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // Only allow fusing transpose or broadcast into an output fusion that is // implemented as a Gemm call. - if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kOutput && - ImplementedAsGemm(*consumer)) { + if (consumer->IsOutputFusion() && ImplementedAsGemm(*consumer)) { auto producer_operand_index = consumer->operand_index(producer); auto fused_parameter = consumer->fused_parameter(producer_operand_index); const std::vector& fused_parameter_users = @@ -226,7 +196,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // Do not fuse to-vector reduction into other consumers. They should be // unfused or the root of a kInput fusion. - if (IsReductionToVector(*producer)) { + if (IsReductionFromOrToContiguousDimensions(*producer)) { return false; } @@ -275,9 +245,29 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, !InstructionFusion::ShouldFuse(consumer, operand_index)) { return false; } + return true; +} - // We put this check last because it's potentially expensive. - return !FusionWouldBeTooLarge(consumer, producer); +bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, + int64 operand_index) { + if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) { + return false; + } + auto producer = consumer->operand(operand_index); + + // TODO(b/129089333): Don't fuse variadic reduce. 
+ if (consumer->opcode() == HloOpcode::kReduce && consumer->shape().IsTuple()) { + return false; + } + // The following checks are potentially expensive. + if (FusionWouldBeTooLarge(consumer, producer)) { + return false; + } + // Also check that our emitter can handle the fusion node. We currently can + // have exponential time/memory requirements for emitting certain fusion + // kernels, in which case we don't want to fuse. + // TODO(b/119692968): Remove this once we have fixed our fusion emitter. + return !FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer); } bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, @@ -287,7 +277,7 @@ bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { - if (IsReductionToVector(*consumer) || + if (IsReductionFromOrToContiguousDimensions(*consumer) || consumer->opcode() == HloOpcode::kScatter) { return HloInstruction::FusionKind::kInput; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index c91f6343a69..2f8f40b4b5e 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -49,6 +49,12 @@ class GpuInstructionFusion : public InstructionFusion { HloInstruction::FusionKind ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) override; + + private: + // This method is called by ShouldFuse() to do all the computationally + // inexpensive checks whether we should fuse the operand into 'consumer'. + bool ShouldFuseInexpensiveChecks(HloInstruction* consumer, + int64 operand_index); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index a05ab86cf77..edb6ecf6247 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -381,6 +381,25 @@ ENTRY main { Not(op::Fusion())); } +TEST_F(InstructionFusionTest, DotOutputFusion_DontUnlessImplementedAsGemm) { + auto module = ParseHloString(R"( + HloModule dot + + ENTRY entry_computation { + p0 = f32[64,63,512]{2,1,0} parameter(0) + p1 = f32[512,512]{1,0} parameter(1) + p2 = f32[64,63,512]{2,1,0} parameter(2) + dot.3525 = f32[64,63,512]{2,1,0} dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT add.3529 = f32[64,63,512]{2,1,0} add(dot.3525, p2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6b9cbdd94b3..957a2f00723 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -81,6 +82,45 @@ bool DotImplementedAsGemm(const HloInstruction& dot) { } return false; } + +// Given a shape and a group of contiguous dimensions in the shape, returns +// a tuple of three values (major, middle, minor), where major is the size of +// the dimensions more major then the given dimensions, minor is the size of +// dimensions more minor then the given dimensions, and middle is the size of +// the given dimensions. +std::tuple PartitionShapeByMiddleDimensions( + const Shape& shape, DimensionVector dims_middle) { + CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle)); + + absl::Span minor_to_major = LayoutUtil::MinorToMajor(shape); + int64 values[3] = {1, 1, 1}; + enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 }; + Segment cur_segment = kMinor; + + // Iterate through the dimensions for the three segments in the order of + // minor, middle and major to accumulate the size of each segment. + absl::c_for_each(minor_to_major, [&](int64 cur_dim) { + if (cur_segment != kMajor) { + // Handle change of segments. + bool cur_dim_in_middle = absl::c_any_of( + dims_middle, [&](int64 dim) { return dim == cur_dim; }); + if (cur_segment == kMinor) { + if (cur_dim_in_middle) { + cur_segment = kMiddle; + } + } else if (cur_segment == kMiddle) { + if (!cur_dim_in_middle) { + cur_segment = kMajor; + } + } + } + + values[cur_segment] *= shape.dimensions(cur_dim); + }); + + return std::make_tuple(values[kMajor], values[kMiddle], values[kMinor]); +} + } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { @@ -89,8 +129,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { return DotImplementedAsGemm(hlo); } - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && + if (hlo.IsOutputFusion() && (hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply || hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) { // Try to find the dot inside the output fusion node. @@ -157,10 +196,16 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) { IsCustomCallToDnnConvolution(hlo); } -bool IsReductionToVector(const HloInstruction& reduce) { +bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { if (HloOpcode::kReduce != reduce.opcode()) { return false; } + + // TODO(b/129698548): Remove this check after fixing the bug. 
+ if (reduce.shape().element_type() == C128) { + return false; + } + const HloInstruction* input = reduce.operand(0); std::vector dims_to_keep; for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) { @@ -168,13 +213,74 @@ bool IsReductionToVector(const HloInstruction& reduce) { dims_to_keep.push_back(dim); } } - return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), - dims_to_keep) && - ShapeUtil::Equal( - reduce.shape(), - ShapeUtil::FilterDimensions( - [&](int64 dim) { return absl::c_count(dims_to_keep, dim); }, - input->shape())); + if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), + dims_to_keep) && + !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), + reduce.dimensions())) { + return false; + } + + bool is_row_reduction; + DimensionVector dims_in_elem; + std::tie(is_row_reduction, dims_in_elem) = + GetReductionKindAndContiguousComponents(input->shape(), + reduce.dimensions()); + + if (is_row_reduction) { + // For row reduction, the tile block is 1 x tile_size_x, and we are reducing + // along tile_size_x which needs to be large enough to make the tiling + // implementation efficient. + return dims_in_elem[2] >= kWarpSize; + } + + // For column reduction, the tile block is tize_size_y x tile_size_x, and we + // are reducing along tile_size_y. Both tile_size_x and tile_size_y need to be + // large enough to make the tiling implementation efficient. + return dims_in_elem[2] >= kWarpSize && dims_in_elem[1] >= kWarpSize; +} + +std::pair GetReductionKindAndContiguousComponents( + const Shape& input_shape, absl::Span dims_to_reduce) { + DimensionVector dims_to_keep; + for (int64 dim = 0; dim < input_shape.rank(); ++dim) { + if (!absl::c_linear_search(dims_to_reduce, dim)) { + dims_to_keep.push_back(dim); + } + } + + if (dims_to_keep.empty()) { + return std::make_pair( + true, DimensionVector{1, 1, ShapeUtil::ElementsIn(input_shape)}); + } + + if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(), + dims_to_keep)) { + int64 num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1; + std::tie(num_reduced_major, num_kept, num_reduced_minor) = + PartitionShapeByMiddleDimensions(input_shape, dims_to_keep); + if (num_kept == 1) { + return std::make_pair( + true, DimensionVector{1, 1, num_reduced_minor * num_reduced_major}); + } + if (num_reduced_minor == 1) { + return std::make_pair(false, + DimensionVector{1, num_reduced_major, num_kept}); + } + return std::make_pair( + true, DimensionVector{num_reduced_major, num_kept, num_reduced_minor}); + } + + int64 num_kept_major = 1, num_reduced = 1, num_kept_minor = 1; + std::tie(num_kept_major, num_reduced, num_kept_minor) = + PartitionShapeByMiddleDimensions( + input_shape, + DimensionVector(dims_to_reduce.begin(), dims_to_reduce.end())); + if (num_kept_minor == 1) { + return std::make_pair(true, + DimensionVector{1, num_kept_major, num_reduced}); + } + return std::make_pair( + false, DimensionVector{num_kept_major, num_reduced, num_kept_minor}); } // This emits a device-side call to @@ -213,8 +319,8 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, // Special case for efficiency if (value->getType()->isFloatTy() && bit_width == 32) { - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_shfl_sync_down_f32, + return EmitCallToTargetIntrinsic( + TargetIntrinsicID::kShflDownF32, {all_warps_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {}, builder); } @@ -230,8 +336,8 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* 
offset, for (int i = 0; i < num_segments; ++i) { x = builder->CreateInsertElement( x, - llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_shfl_sync_down_i32, + EmitCallToTargetIntrinsic( + TargetIntrinsicID::kShflDownI32, {all_warps_mask, builder->CreateExtractElement(x, i), offset, builder->getInt32(kWarpSize - 1)}, {}, builder), @@ -279,12 +385,10 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { return b->CreateAnd( b->CreateICmpEQ( b->getInt32(0), - llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)), + EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)), b->CreateICmpEQ( b->getInt32(0), - llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b))); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b))); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index f1a7aabb4db..4e0e828eee2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -148,7 +148,29 @@ extern const char* const kCusolverCholeskyCallTarget; // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); -bool IsReductionToVector(const HloInstruction& reduce); +// Returns true if either the dimensions being reduced or the dimensions being +// kept are contiguous in the input of the reduce instruction. +bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce); + +// Given the input shape and dimensions to reduce for a reduction, returns +// : +// is_row_reduction: indicates whether the reduction is a row reduction or a +// column reduction. +// DimensionVector: contains the size of the three contiguous components for the +// reduction [depth, height, width]. For row reduction, height is the size of +// the dimensions to keep, depth is the size of the dimensions to reduce that +// are more major than the dimensions to keep, and width is the size of the +// dimensions to reduce that are more minor than the dimensions to keep. For +// column reduction, height is the size of dimensions to reduce, depth is the +// the size of the dimensions to keep that are more major than the dimensions +// to reduce, and width is the size of the dimensions to keep that are more +// minor than the dimensions to reduce. +// +// Prerequisite: the reduction instruction passes the check +// IsReductionFromOrToContiguousDimensions, which guarantees either the +// dimensions to reduce or the dimensions to keep are consecutive. +std::pair GetReductionKindAndContiguousComponents( + const Shape& input_shape, absl::Span dims_to_reduce); // Emits call to "vprintf" with given format and arguments. llvm::Value* EmitPrintf(absl::string_view fmt, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 8f010ab27a6..a3fb1ce7307 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -32,7 +33,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -115,7 +118,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { get_tuple_element->shape(), get_tuple_element->tuple_index(), // TODO(b/26344050): tighten the alignment here // based on the real element type. - /*alignment=*/1, GetBasePointer(*operand), &b_, module_)); + /*alignment=*/1, GetBasePointer(*operand), &b_)); return Status::OK(); } @@ -144,7 +147,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_); return Status::OK(); } @@ -157,8 +160,7 @@ Status IrEmitter::EmitCallToNestedComputation( if (emitted_function == nullptr) { IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation, ir_emitter_context_); - TF_RETURN_IF_ERROR( - nested_computation.root_instruction()->Accept(&ir_emitter_nested)); + TF_RETURN_IF_ERROR(ir_emitter_nested.CodegenNestedComputation()); emitted_function = ir_emitter_nested.GetEmittedFunction(); } @@ -434,7 +436,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select), GetIrArray(*pred, *tuple_select), GetBasePointer(*on_true), GetBasePointer(*on_false), - &b_, module_); + &b_); return Status::OK(); } @@ -528,16 +530,18 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_); - llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( - lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); - llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( - rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); + std::vector lhs_multi_index = + loop_nest.EmitOperandArrayLoopNest( + lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + std::vector rhs_multi_index = + loop_nest.EmitOperandArrayLoopNest( + rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); // We don't have to iterate over the batch dimensions in both arrays, simplify // the loop nest of the rhs. for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i)); - rhs_index[i] = lhs_index[i]; + rhs_multi_index[i] = lhs_multi_index[i]; } // Create the reduction loop which does the sum of products reduction. @@ -548,8 +552,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // The final entry in the rhs and lhs indexes is the indvar of the reduction // loop. 
- lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); - rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); + rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); // For computing the sum of products we alloca a single location to store the // dot product result as we accumulate it within the reduction loop. After the @@ -574,7 +578,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty()); b_.SetInsertPoint( &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); + llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(), + b_.getInt64Ty()); llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); + llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(), + b_.getInt64Ty()); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); llvm::Value* accum = Load(accum_address); llvm::Value* updated_accum; @@ -600,20 +608,22 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // address. The index into the target address is the concatenation of the rhs // and lhs indexes with the reduction dimensions removed. The terms from the // rhs index are the lower dimensions in the index so we add them first. - llvm_ir::IrArray::Index target_index(index_type); + std::vector target_multi_index; for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { - target_index.push_back(lhs_index[dimension]); + target_multi_index.push_back(lhs_index[dimension]); } } // Skip over the batch dimensions to not have them in the index twice. for (size_t dimension = dnums.lhs_batch_dimensions_size(); dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { - target_index.push_back(rhs_index[dimension]); + target_multi_index.push_back(rhs_index[dimension]); } } SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); + llvm_ir::IrArray::Index target_index(target_multi_index, + target_array.GetShape(), index_type); target_array.EmitWriteArrayElement( target_index, Load(accum_address), // The value written to the target array. @@ -653,23 +663,38 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } -Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/118332391): Support variadic reduce. - if (!reduce->shape().IsArray()) { - return Unimplemented("Variadic reduce is not supported on GPU"); +Status IrEmitter::HandleReduce(HloInstruction* instr) { + const HloReduceInstruction* reduce = Cast(instr); + const Shape& out_shape = reduce->shape(); + bool returns_tuple = !out_shape.IsArray(); + int accumulators_count = 1; + if (returns_tuple) { + CHECK(out_shape.IsTuple()); + accumulators_count = out_shape.tuple_shapes_size(); } + auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); return EmitTargetElementLoop( *reduce, [=](const llvm_ir::IrArray::Index& index) -> StatusOr { - // Initialize an accumulator with init_value. 
- llvm::AllocaInst* accumulator_addr = - Alloca(llvm_ir::PrimitiveTypeToIrType( - reduce->shape().element_type(), module_)); - Store(Load(GetBasePointer(*init_value)), accumulator_addr); + std::vector accumulator_addrs; + std::vector accumulator_types; + + // Initialize accumulators with initial values. + for (int i = 0; i < accumulators_count; i++) { + auto init_value = reduce->init_values()[i]; + const Shape& element_shape = + returns_tuple ? out_shape.tuple_shapes(i) : out_shape; + PrimitiveType accumulator_type = element_shape.element_type(); + llvm::Type* accumulator_llvm_type = + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); + llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type); + Store(Load(GetBasePointer(*init_value)), accumulator_addr); + accumulator_addrs.push_back(accumulator_addr); + accumulator_types.push_back(accumulator_llvm_type); + } // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -678,7 +703,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Value*s are placed for each dimension in dimensions, and all the rest // are nullptrs. llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - const llvm_ir::IrArray::Index reduced_dims_index = + std::vector input_multi_index = loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); @@ -689,24 +714,61 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // filled in. We fill in the rest of the dimensions with induction // Value*s taken from 'index' which iterates over the target array. // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = *it++; + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; } } CHECK(index.end() == it); // Apply the reduction function to the loaded value. - llvm::Value* input_address = - GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_); + llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + b_.getInt64Ty()); + std::vector reduction_operands(accumulator_addrs.begin(), + accumulator_addrs.end()); + for (int i = 0; i < accumulators_count; i++) { + llvm::Value* input_address = + GetIrArray(*reduce->operand(i), *reduce) + .EmitArrayElementAddress(input_index, &b_); + reduction_operands.push_back(input_address); + } + + llvm::Value* ret_argument; + if (!returns_tuple) { + CHECK_EQ(accumulator_addrs.size(), 1); + ret_argument = accumulator_addrs[0]; + } else { + const Shape& return_shape = function->root_instruction()->shape(); + + llvm::Type* return_value_buffer_type = + llvm_ir::ShapeToIrType(return_shape, module_); + ret_argument = Alloca(return_value_buffer_type); + llvm_ir::IrArray tuple_array(ret_argument, return_shape); + EmitTuple(tuple_array, accumulator_addrs, &b_); + } + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *function, {accumulator_addr, input_address}, accumulator_addr)); + *function, reduction_operands, ret_argument)); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return Load(accumulator_addr); + + if (!returns_tuple) { + CHECK_EQ(accumulator_addrs.size(), 1); + return Load(accumulator_addrs[0]); + } else { + // Emit a struct for the LoopEmitter dealing with multi-output + // fusion. 
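The rewritten HandleReduce above generalizes the single accumulator to one accumulator per reduced operand, so a variadic reduce can feed all of its inputs and init values through one call to the nested combiner. A simplified standalone analogue of that shape, reducing two inputs in lockstep with a pair of accumulators (the function names here are hypothetical):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// The combiner plays the role of the nested reduction computation: it takes
// the current accumulators plus one element from each input and returns the
// updated accumulators (a 2-tuple here, matching a 2-ary variadic reduce).
std::pair<int64_t, int64_t> Combine(std::pair<int64_t, int64_t> acc,
                                    int64_t a, int64_t b) {
  return {acc.first + a, std::max(acc.second, b)};
}

// Reduce two equally sized inputs at once, keeping one accumulator per input,
// just as the emitter keeps one alloca per init value.
std::pair<int64_t, int64_t> VariadicReduce(const std::vector<int64_t>& a,
                                           const std::vector<int64_t>& b,
                                           std::pair<int64_t, int64_t> init) {
  std::pair<int64_t, int64_t> acc = init;
  for (std::size_t i = 0; i < a.size(); ++i) {
    acc = Combine(acc, a[i], b[i]);
  }
  return acc;
}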
+ llvm::Value* returned_structure = llvm::UndefValue::get( + llvm::StructType::get(b_.getContext(), accumulator_types)); + for (int i = 0; i < accumulators_count; i++) { + llvm::Value* accumulator_value = Load(accumulator_addrs[i]); + returned_structure = + b_.CreateInsertValue(returned_structure, accumulator_value, i); + } + return returned_structure; + } }); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index a78b4ff8307..b9d944b5dc1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -32,10 +32,12 @@ class IrEmitterContext { public: IrEmitterContext(const HloModule* hlo_module, const BufferAssignment* buffer_assignment, + const se::Platform* platform, const se::DeviceDescription* device_desc, llvm::Module* llvm_module) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), + platform_(platform), device_desc_(device_desc), llvm_module_(llvm_module) {} // Disallow copy and assign. @@ -47,6 +49,7 @@ class IrEmitterContext { const BufferAssignment& buffer_assignment() const { return *buffer_assignment_; } + const se::Platform* platform() const { return platform_; } const se::DeviceDescription& device_description() const { return *device_desc_; } @@ -56,6 +59,7 @@ class IrEmitterContext { private: const HloModule* hlo_module_; const BufferAssignment* buffer_assignment_; + const se::Platform* platform_; const se::DeviceDescription* device_desc_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 66c65f69758..72f48c49096 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -38,20 +38,18 @@ namespace gpu { IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config, const HloComputation& nested_computation, IrEmitterContext* ir_emitter_context) - : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) { - std::vector io_hlos; - emitted_function_ = - EmitBasePointersForNestedComputation(nested_computation, &io_hlos); -} + : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true), + nested_computation_(nested_computation) {} -llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( - const HloComputation& nested_computation, - std::vector* io_hlos) { +// Nested function serves the same purpose on GPU as a thread-local function on +// a CPU. 
+Status IrEmitterNested::CodegenNestedComputation() { + std::vector io_hlos; std::vector argument_types; std::vector argument_dereferenceable_bytes; for (const HloInstruction* param : - nested_computation.parameter_instructions()) { - io_hlos->push_back(param); + nested_computation_.parameter_instructions()) { + io_hlos.push_back(param); const Shape& param_shape = param->shape(); argument_types.push_back( llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo()); @@ -59,9 +57,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout()); argument_dereferenceable_bytes.push_back(param_size); } + + const HloInstruction* root = nested_computation_.root_instruction(); { - const HloInstruction* root = nested_computation.root_instruction(); - io_hlos->push_back(root); const Shape& root_shape = root->shape(); argument_types.push_back( llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo()); @@ -77,9 +75,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( llvm::Function* function = llvm::Function::Create( function_type, // The function type. llvm::GlobalValue::InternalLinkage, // The linkage type. - llvm_ir::AsStringRef(ir_emitter_context_->name_uniquer()->GetUniqueName( + ir_emitter_context_->name_uniquer()->GetUniqueName( llvm_ir::SanitizeFunctionName( - nested_computation.name()))), // The name of the function. + nested_computation_.name())), // The name of the function. ir_emitter_context_->llvm_module()); // The parent LLVM module. for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size(); ++arg_no) { @@ -96,17 +94,61 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( llvm::BasicBlock::Create(function->getContext(), "entry", function); // Emit a "return void" at entry_bb's end, and sets the insert point before // that return instruction. - b_.SetInsertPoint(llvm::ReturnInst::Create(function->getContext(), entry_bb)); + llvm::ReturnInst* ret_instr = + llvm::ReturnInst::Create(function->getContext(), entry_bb); + b_.SetInsertPoint(ret_instr); std::vector non_io_hlos; - for (const auto* hlo : nested_computation.instructions()) { + non_io_hlos.push_back(root); + for (const auto* hlo : nested_computation_.instructions()) { if (hlo->opcode() != HloOpcode::kParameter && - hlo != nested_computation.root_instruction()) { + hlo != nested_computation_.root_instruction()) { non_io_hlos.push_back(hlo); } } - bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos); - return function; + bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos); + + TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this)); + b_.SetInsertPoint(ret_instr); + + // Function epilogue: copy the output value back. + { + // TODO(cheshire) Duplication vs. EmitThreadLocalFunctionEpilogue + const HloInstruction* root_instruction = + nested_computation_.root_instruction(); + llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction); + const Shape& return_shape = root_instruction->shape(); + + // Second last argument is the out parameter. 
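CodegenNestedComputation gives every nested computation the calling convention sketched in the surrounding hunk: one pointer argument per parameter, then an output argument, then the temp-buffer base, with an epilogue that copies the root value into the output argument (element by element for tuple-shaped roots). A hypothetical plain-C++ analogue of that convention for a scalar root:

// Illustrative only: parameters arrive as pointers, the result is written
// through the output argument instead of being returned, and the last
// argument is the (unused here) temp-buffer base.
void NestedAdd(const float* x, const float* y, float* out_arg,
               void* /*temp_buffer_base*/) {
  float root = *x + *y;  // body: compute the root value
  *out_arg = root;       // epilogue: copy the result into the out parameter
}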
+ llvm::Argument* out_parameter = std::prev(function->arg_end(), 2); + + if (ShapeUtil::IsScalar(return_shape)) { + llvm::Value* ret_value = Load(root_value, "load_ret_value"); + Store(ret_value, + BitCast(out_parameter, root_value->getType(), "bitcast_ret_value")); + } else { + CHECK(return_shape.IsTuple()); + llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_); + llvm::Type* tuple_type_ptr = tuple_type->getPointerTo(); + llvm::Value* tuple_ptr = BitCast(out_parameter, tuple_type_ptr); + + for (int i = 0; i < return_shape.tuple_shapes_size(); i++) { + const Shape& element_shape = return_shape.tuple_shapes(i); + llvm::Value* destination = + llvm_ir::EmitGetTupleElement(element_shape, + /*index=*/i, + /*alignment=*/1, tuple_ptr, &b_); + llvm::Value* source = + llvm_ir::EmitGetTupleElement(element_shape, + /*index=*/i, + /*alignment=*/1, root_value, &b_); + Store(Load(source), destination); + } + } + } + b_.SetInsertPoint(ret_instr); + emitted_function_ = function; + return Status::OK(); } Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { @@ -118,12 +160,12 @@ Status IrEmitterNested::EmitTargetElementLoop( const llvm_ir::ElementGenerator& element_generator) { // For MOF we give the loop emitter an array for every output it should // generate. - if (hlo.IsMultiOutputFusion()) { + if (hlo.shape().IsTuple()) { std::vector target_arrays = ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_); return Status::OK(); } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h index ca11cf2c182..ce825851bcc 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h @@ -58,11 +58,11 @@ class IrEmitterNested : public IrEmitter { const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) override; - private: - llvm::Function* EmitBasePointersForNestedComputation( - const HloComputation& nested_computation, - std::vector* io_hlos); + // Generate the code for the computation passed in the constructor. + Status CodegenNestedComputation(); + private: + const HloComputation& nested_computation_; llvm::Function* emitted_function_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 61aa981d779..774c2b8682f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" + #include #include #include @@ -20,8 +22,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" - #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -37,6 +37,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h" @@ -45,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" +#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" @@ -60,6 +62,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" @@ -231,7 +234,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( if (alloc->IsPreallocatedTempBuffer()) { fn_arg->setName("temp_buf"); } else { - fn_arg->setName(llvm_ir::AsStringRef(StrCat("alloc", alloc->index()))); + fn_arg->setName(StrCat("alloc", alloc->index())); } } @@ -526,7 +529,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return Status::OK(); } - return IrEmitter::HandleCustomCall(custom_call); + if (void* call_target = CustomCallTargetRegistry::Global()->Lookup( + custom_call->custom_call_target(), + ir_emitter_context_->platform()->Name())) { + const auto& assn = ir_emitter_context_->buffer_assignment(); + auto get_slices_for_instr = [&](const HloInstruction* instr) { + ShapeTree slices(instr->shape()); + slices.ForEachMutableElement([&](const ShapeIndex& index, + BufferAllocation::Slice* slice) { + StatusOr s = assn.GetUniqueSlice(instr, index); + if (s.ok()) { + *slice = s.ValueOrDie(); + } + }); + return slices; + }; + std::vector> operand_slices; + for (const auto* operand : custom_call->operands()) { + operand_slices.push_back(get_slices_for_instr(operand)); + } + ShapeTree result_slices = + get_slices_for_instr(custom_call); + AddThunkToThunkSequence(absl::make_unique( + call_target, std::move(operand_slices), std::move(result_slices), + Cast(custom_call)->opaque(), custom_call)); + return Status::OK(); + } + + return Unimplemented("No registered implementation for custom call to \"%s\"", + custom_call->custom_call_target()); } Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { @@ -574,7 +605,7 @@ Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) { Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); - if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { + if (fusion->IsInputFusion()) { switch (root->opcode()) { case HloOpcode::kScatter: { std::vector> thunks; @@ -631,10 +662,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // a 1D array. 
The specialized version requires a initializer thunk that // initializes the output array to the initial value of the reduce. if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) { - // TODO(b/118332391): Support variadic reduce. - return Unimplemented("Variadic reduce is not supported on GPU"); + // TODO(b/129089333): Support tiled vectorized variadic reduce. + return Unimplemented( + "Vectorized variadic reduce is not supported on GPU"); } - return EmitReductionToVector(fusion); + return EmitReductionFromOrToContiguousDimensions(fusion); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -677,7 +709,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop); + CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop) + << ": " << fusion->ToString(); if (CheckAndEmitHloWithTile021(fusion)) { return Status::OK(); @@ -705,13 +738,14 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { Status IrEmitterUnnested::EmitExtraOutputsForReduce( const HloInstruction* unnested_hlo, const IrArray::Index& index, + bool use_linear_index, absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { llvm::Value* extra_output_address = GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second) - .EmitArrayElementAddress(index, &b_, - "extra_output_element_address"); + .EmitArrayElementAddress(index, &b_, "extra_output_element_address", + use_linear_index); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); Store(extra_output_ir_value, extra_output_address); @@ -720,12 +754,9 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( } Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { - // TODO(b/118332391): Support multi-output reduce. - if (!reduce->shape().IsArray()) { - return Unimplemented("Multi-output reduce is not supported on GPU"); - } - if (IsReductionToVector(*reduce)) { - return EmitReductionToVector(reduce); + if (IsReductionFromOrToContiguousDimensions(*reduce) && + reduce->shape().IsArray()) { + return EmitReductionFromOrToContiguousDimensions(reduce); } return IrEmitter::HandleReduce(reduce); @@ -863,16 +894,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. 
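// Worked example with illustrative values (not taken from a real model): for
// dimension i with stride 2 and padding_low 1, a source index of 3 and a
// window index of 0 give operand index 3 * 2 + 0 - 1 = 5; a source index of 0
// with window index 0 gives -1, which wraps around as an unsigned value and
// therefore fails the ICmpULT bound check below.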
- IrArray::Index operand_index(index_type, source_index.size()); + std::vector operand_multi_index(source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = NSWMul( source_index[i], index_typed_constant(window.dimensions(i).stride())); - operand_index[i] = + operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), index_typed_constant(window.dimensions(i).padding_low())); llvm::Value* index_condition = ICmpULT( - operand_index[i], + operand_multi_index[i], index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); in_bounds_condition = And(in_bounds_condition, index_condition); } @@ -897,6 +928,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( } }; IrArray operand_array = GetIrArray(*operand, *select_and_scatter); + IrArray::Index operand_index(operand_multi_index, operand->shape(), + index_type); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); Store(operand_data, selected_value_address); @@ -907,7 +940,6 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // potentially update the selected value and index with the currently // visiting operand. llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_); - const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( @@ -939,15 +971,18 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // value and the current output value. llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_); - IrArray::Index selected_index(operand_index.GetType()); + std::vector selected_multi_index; for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = InBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(Load(selected_index_address_slot)); + selected_multi_index.push_back(Load(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) .EmitArrayElementAddress(source_index, &b_); + IrArray::Index selected_index(selected_multi_index, + select_and_scatter->shape(), + operand_index.GetType()); llvm::Value* output_value_address = GetIrArray(*select_and_scatter, *select_and_scatter) .EmitArrayElementAddress(selected_index, &b_); @@ -1125,16 +1160,20 @@ Status IrEmitterUnnested::EmitScatter( // Now load the indices corresponding to the current window from // scatter_indices. - llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim, - index.GetType()); - raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + std::vector raw_scatter_index_multidim = + input_scatter_multidim; + raw_scatter_index_multidim.insert( + raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(), + nullptr); llvm::Value* is_in_bounds = b_.getTrue(); for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size(); i != e; ++i) { // Our index is stored along index_vector_dim, insert that into the lookup // index into scatter_indices. 
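// Illustrative shapes (assumed for this example only): with scatter_indices of
// shape s32[N,3] and index_vector_dim == 1, input_scatter_multidim == {n}
// becomes the lookup index {n, i} for component i, i.e. the i-th coordinate of
// the n-th scatter index.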
- raw_scatter_index_index[dim_numbers.index_vector_dim()] = - raw_scatter_index_index.GetConstantWithIndexType(i); + raw_scatter_index_multidim[dim_numbers.index_vector_dim()] = + index.GetConstantWithIndexType(i); + llvm_ir::IrArray::Index raw_scatter_index_index( + raw_scatter_index_multidim, scatter_indices_shape, index.GetType()); int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i); TF_ASSIGN_OR_RETURN( @@ -1164,10 +1203,10 @@ Status IrEmitterUnnested::EmitScatter( llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_); // All done, now just read from the calculated input from the window, and do // an atomic store to the calculated location in the output. - llvm_ir::IrArray::Index input_window_index(input_window_multidim, - index.GetType()); HloInstruction* output_hlo = scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter; + llvm_ir::IrArray::Index input_window_index( + input_window_multidim, output_hlo->shape(), index.GetType()); llvm::Value* output_address = GetIrArray(*output_hlo, *output_hlo) .EmitArrayElementAddress(input_window_index, &b_); @@ -1659,8 +1698,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::Value* loc; if (slice.allocation()->is_constant()) { loc = ir_emitter_context_->llvm_module()->getGlobalVariable( - llvm_ir::AsStringRef(llvm_ir::ConstantBufferAllocationToGlobalName( - *slice.allocation()))); + llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation())); CHECK_NE(loc, nullptr); } else { loc = InBoundsGEP(kernel_args.at(slice.allocation()), @@ -1689,7 +1727,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( } return absl::make_unique( - non_constant_buffers, llvm_ir::AsString(kernel->getName()), + non_constant_buffers, kernel->getName(), implements_whole_instruction ? inst : nullptr, unroll_factor); } @@ -2170,9 +2208,10 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( int unroll_factor = thunk->unroll_factor(); VLOG(3) << bindings_.ToString(); - const Shape& element_shape = hlo.IsMultiOutputFusion() - ? ShapeUtil::GetSubshape(hlo.shape(), {0}) - : hlo.shape(); + bool multi_output = hlo.shape().IsTuple(); + + const Shape& element_shape = + multi_output ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape(); VLOG(3) << "EmitTargetElementLoopInThunk " << ShapeUtil::HumanStringWithLayout(hlo.shape()) << " for unroll_factor " << unroll_factor; @@ -2180,7 +2219,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( element_shape, ir_emitter_context_->device_description(), unroll_factor); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); - if (!hlo.IsMultiOutputFusion()) { + if (!multi_output) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &b_, unroll_factor) .EmitLoop( @@ -2194,7 +2233,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_); }); // For multioutput fusion, we need to emit each operand and the root. @@ -2573,10 +2612,22 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { : llvm_ir::KernelMappingScheme::DimY; } - // Return the dimension that is being ketp between DimX and DimY. - int GetKeptDimensionEnum() const { - return IsRowReduction() ? 
llvm_ir::KernelMappingScheme::DimY - : llvm_ir::KernelMappingScheme::DimX; + // Given the IrArray index of a reduction input, return the linear address of + // the reduction output as if the reduction were going the keep the input + // shape with the dimensions being reduced moved. + llvm::Value* GetUntransposedOutputLinearAddress( + llvm::IRBuilder<>* b, const llvm_ir::IrArray::Index& index) const { + if (IsRowReduction()) { + return index[llvm_ir::KernelMappingScheme::DimY]; + } + absl::Span dims_in_elem = + GetKernelMappingScheme()->GetDimensionsInElements(); + llvm::Value* x_dim_size = index.GetConstantWithIndexType( + dims_in_elem[llvm_ir::KernelMappingScheme::DimX]); + llvm::Value* x_block_offset = + b->CreateMul(index[llvm_ir::KernelMappingScheme::DimZ], x_dim_size); + return b->CreateAdd(x_block_offset, + index[llvm_ir::KernelMappingScheme::DimX]); } int GetNumberOfPartialResults() const { @@ -2601,6 +2652,9 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { AddressVector reduction_input_addresses_; InlinedVector reducers_; InlinedVector reduction_output_shape_indices_; + // The address of the memory that stores the linear index of the current + // output, assuming that the output doesn't change the layout of the kept + // elements in the reduction input. llvm::AllocaInst* current_output_linear_index_address_; llvm::AllocaInst* current_output_inbound_address_; bool is_row_reduction_; @@ -2624,7 +2678,7 @@ const HloInstruction* GetFirstReduceInstruction( absl::Span instructions) { auto first_reduce_iter = absl::c_find_if(instructions, [](const HloInstruction* inst) { - return inst->opcode() == HloOpcode::kReduce; + return IsReductionFromOrToContiguousDimensions(*inst); }); CHECK_NE(first_reduce_iter, instructions.end()); return *first_reduce_iter; @@ -2641,7 +2695,7 @@ void IrEmitterUnnested::EmitPrologueForOneReduction( InlinedVector* reducers = reduction_info->GetMutableReducers(); - CHECK(IsReductionToVector(*reduce_inst)); + CHECK(IsReductionFromOrToContiguousDimensions(*reduce_inst)); reducers->push_back(reduce_inst->to_apply()); InlinedVector* reduction_output_shape_indices = @@ -2703,7 +2757,7 @@ void IrEmitterUnnested::EmitPrologueForReduction( &b_, GetNestedComputer()); const HloInstruction* first_reduce = nullptr; for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { continue; } HloInstruction* reduce_inst = output_instructions[i]; @@ -2799,26 +2853,56 @@ void IrEmitterUnnested::EmitEpilogueForReduction( llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); } + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + std::vector reduce_instructions; + absl::c_for_each(GetOutputInstructions(&reduce_or_tuple), + [&](const HloInstruction* instr) { + if (IsReductionFromOrToContiguousDimensions(*instr)) { + reduce_instructions.push_back(instr); + } + }); int num_partial_results = reduction_info->GetNumberOfPartialResults(); // Emit an atomic operation that accumulates the partial reduction to the // output element. For row reduction, this is only for lane 0 due to the // if-statement emitted above. 
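// Worked example for the untransposed linear index used below (illustrative
// values only): a row reduction simply stores index[DimY]; a column reduction
// with dims_in_elem[DimX] == 256 and a thread positioned at z == 3, x == 5
// stores 3 * 256 + 5 == 773, i.e. the offset into the kept elements as if the
// output preserved the input layout of those elements.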
for (int i = 0; i != num_reduces; ++i) { + const HloInstruction* reduce_hlo = reduce_instructions[i]; + Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(reduce_hlo->dimensions(), dim); + }, + reduce_hlo->operand(0)->shape()); for (int j = 0; j < num_partial_results; ++j) { + // A reduction is allowed to transpose its output. For example, suppose + // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are + // allowed to produce as output either f32[10,30]{1,0} (no transpose) or + // f32[10,30]{0,1} (transposing the two output dims). + // + // At this point in the function we have a "partial sum" of input elements + // (stored in partial_result_addresses), and we need to accumulate it into + // the correct output element. + // + // *reduction_info->GetCurrentOutputLinearIndexAddress() stores the linear + // index in the output into which we would need to accumulate *if the + // output layout matched the input layout*. This is why we use + // `reduction_kept_element_shape` rather than `unnested_hlo->shape()` when + // computing `element_index` below. + auto output_array = GetIrArray(*unnested_hlo, *unnested_hlo, + reduction_output_shape_indices[i]); IrArray::Index element_index( /*linear=*/Load( InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), {b_.getInt32(j)}), - "output_linear_addr"), - ShapeUtil::GetSubshape(unnested_hlo->shape(), - reduction_output_shape_indices[i]), - &b_); - llvm::Value* output_address = - GetIrArray(*unnested_hlo, *unnested_hlo, - reduction_output_shape_indices[i]) - .EmitArrayElementAddress(element_index, &b_, - "output_element_address"); + "untransposed_output_linear_addr"), + reduction_kept_element_shape, &b_); + IrArray::Index output_index(element_index.multidim(), + output_array.GetShape(), + element_index.GetType()); + llvm::Value* output_address = output_array.EmitArrayElementAddress( + output_index, &b_, "output_element_address"); // Do not emit atomic operations if each element in the reduction result // is computed by one block, that is the dimension being reduced has only // one block. @@ -2855,14 +2939,14 @@ void IrEmitterUnnested::EmitTileElementForReduction( tiled_param_info->set_y(y_loc); tiled_param_info->set_x(x_loc); - // Record the linear address for the current reduction. + // Record the untransposed output linear address for the reduction. const ReductionCodegenInfo* reduction_info = dynamic_cast(kernel_info); int partial_result_index = reduction_info->IsRowReduction() ? 
0 : x_iter_num; - - Store(index[reduction_info->GetKeptDimensionEnum()], + Store(reduction_info->GetUntransposedOutputLinearAddress(&b_, index), InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), {b_.getInt32(partial_result_index)})); + if (!reduction_info->IsRowReduction()) { llvm::Type* bool_ty = b_.getInt1Ty(); llvm::AllocaInst* output_inbound_addr = @@ -2891,7 +2975,7 @@ void IrEmitterUnnested::EmitTileElementForReduction( if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { output_shape_index = {i}; } - if (inst->opcode() == HloOpcode::kReduce) { + if (IsReductionFromOrToContiguousDimensions(*inst)) { input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); } else { extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), @@ -2905,17 +2989,18 @@ void IrEmitterUnnested::EmitTileElementForReduction( }); } + Shape reduction_operand_shape = + GetFirstReduceInstruction(output_instructions)->operand(0)->shape(); IrArray::Index input_index = reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( - index, - GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); + index, reduction_operand_shape); + // Clear the linear index field of the IrArray::Index to enable the use of + // GetElementPointer with array types. This enables the vectorization of + // the computation for different partial results. Use this index if + // 'num_partial_results > 1'. int num_partial_results = reduction_info->GetNumberOfPartialResults(); - if (num_partial_results > 1) { - // Clear the linear index field of the IrArray::Index to enable the use of - // GetElementPointer with array types. This enables the vectorization of - // the computation for different partial results. - input_index.ClearLinearIndex(); - } + auto index_without_linear = IrArray::Index( + input_index.multidim(), reduction_operand_shape, input_index.GetType()); absl::Span partial_reduction_result_addresses = reduction_info->GetPartialResultAddresses(); absl::Span reduction_input_addresses = @@ -2926,7 +3011,10 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Emit code to generate the input and perform the reduction computation for // each reduction instruction. for (int i = 0; i != reducers.size(); ++i) { - llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); + llvm::Value* const input_ir_value = + input_gens[i](num_partial_results > 1 ? index_without_linear + : input_index) + .ValueOrDie(); Store(input_ir_value, reduction_input_addresses[i]); llvm::Value* partial_result_address = InBoundsGEP(partial_reduction_result_addresses[i], @@ -2938,8 +3026,9 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Emit code to generate the output for the non-reduction instructions in the // fusion, if any. - TF_CHECK_OK( - EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens)); + TF_CHECK_OK(EmitExtraOutputsForReduce( + unnested_hlo, input_index, + /*use_linear_index=*/num_partial_results == 1, extra_output_gens)); } // Emits a kernel for the hlo instruction using the given tiling scheme. 
@@ -3096,8 +3185,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), - ConstructIrArrayForOutputs(*unnested_hlo), &b_, - module_); + ConstructIrArrayForOutputs(*unnested_hlo), &b_); }); } @@ -3139,7 +3227,9 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( bool block_contains_multi_tiles) { // Calculate the input tile origin from the output tile origin. const IrArray::Index input_tile_origin( - Permute({0, 2, 1}, output_tile_origin.multidim())); + Permute({0, 2, 1}, output_tile_origin.multidim()), + Permute({0, 2, 1}, output_tile_origin.dims()), + output_tile_origin.GetType()); // If shared memory transpose is needed, wait for all threads to reach this // point, lest we copy a value from tile to output before the other thread @@ -3167,7 +3257,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( }); // Wait for all threads to reach this point using `__syncthreads` in CUDA. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); } llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); @@ -3189,7 +3279,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // buffer for the current tile before we move on to process the next tile // and overwrite the shared memory buffers. if (block_contains_multi_tiles && !tiled_param_ids.empty()) { - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); } }; @@ -3365,10 +3455,8 @@ std::vector FilterInputsForShmemTranspose(const HloInstruction* fusion, bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { HloOpcode opcode = hlo->opcode(); - CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy); - CHECK(opcode != HloOpcode::kFusion || - hlo->fusion_kind() == HloInstruction::FusionKind::kLoop) - << "Only loop fusions are supported."; + + CHECK(hlo->IsLoopFusion() || opcode == HloOpcode::kCopy); const Shape& output_shape = hlo->IsMultiOutputFusion() ? ShapeUtil::GetSubshape(hlo->shape(), {0}) @@ -3466,7 +3554,7 @@ Status AreFusedReductionOutputsConsistent( absl::Span output_instructions, const HloInstruction* first_reduce) { for (const HloInstruction* inst : output_instructions) { - if (inst->opcode() == HloOpcode::kReduce) { + if (IsReductionFromOrToContiguousDimensions(*inst)) { // Shapes, layouts and dimensions must be the same for all reduces // inside of this fusion. TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); @@ -3488,72 +3576,6 @@ Status AreFusedReductionOutputsConsistent( return Status::OK(); } -// Finds the dimensions to keep for the reduction, sorts and returns the -// dimensions from minor to major. -DimensionVector GetDimensionsToKeepMinorToMajor( - const Shape& input_shape, absl::Span dims_to_reduce) { - DimensionVector input_dims(input_shape.rank(), 0); - absl::c_iota(input_dims, 0); - DimensionVector input_dims_to_keep; - for (int input_dim : input_dims) { - auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) { - return dim_to_reduce == input_dim; - }); - if (it == dims_to_reduce.end()) { - input_dims_to_keep.push_back(input_dim); - } - } - - // Sort the dimensions to keep from minor to major. 
- absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < - PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); - }); - - VLOG(10) << "dims to keep minor to major" - << absl::StrJoin(input_dims_to_keep, ","); - return input_dims_to_keep; -} - -// Given the input shape and dimensions to reduce for the reduction to vector, -// returns : -// num_kept: the number of elements in the contiguous dimensions to keep. -// num_reduced_major: the number of elements in the dimensions to reduce that -// are more major than the dimensions to keep. -// num_reduced_minor: the number of elements in the dimensions to reduce that -// are more minor than the dimensions to kept. -std::tuple GetReductionToVectorDimensions( - const Shape& input_shape, absl::Span dims_to_reduce) { - DimensionVector input_dims_to_keep_minor_to_major = - GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce); - CHECK(LayoutUtil::AreDimensionsConsecutive( - input_shape.layout(), input_dims_to_keep_minor_to_major)); - int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1; - if (input_dims_to_keep_minor_to_major.empty()) { - return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); - } - DimensionVector input_dims(input_shape.rank(), 0); - absl::c_iota(input_dims, 0); - absl::Span minor_to_major = - LayoutUtil::MinorToMajor(input_shape); - for (int input_dim : input_dims) { - int64 curr_dim_size = input_shape.dimensions(input_dim); - if (PositionInContainer(minor_to_major, input_dim) > - PositionInContainer(minor_to_major, - input_dims_to_keep_minor_to_major.back())) { - num_reduced_major *= curr_dim_size; - } else if (PositionInContainer(minor_to_major, input_dim) < - PositionInContainer(minor_to_major, - input_dims_to_keep_minor_to_major.front())) { - num_reduced_minor *= curr_dim_size; - } else { - num_kept *= curr_dim_size; - } - } - - return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); -} - // Returns true if all the transitive users of hlo before hitting users in // use_chain_endings are elementwise operations. bool AreUsersElementwise(const HloInstruction* hlo, @@ -3597,14 +3619,14 @@ int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo, // unrolling is beneficial for the given kInput fusion. bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo, const Shape& input_shape, - int64 num_kept) { + int64 num_kept_minor) { // TODO(b/122468062): Need further investigate to see whether we can // remove the constraint on IsPowerOfTwo. - if (!IsPowerOfTwo(static_cast(num_kept))) { + if (!IsPowerOfTwo(static_cast(num_kept_minor))) { return false; } - if (unnested_hlo->opcode() == HloOpcode::kReduce) { + if (IsReductionFromOrToContiguousDimensions(*unnested_hlo)) { return true; } @@ -3613,14 +3635,14 @@ bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo, int64 cannot_be_vectorized = 0; const HloInstruction* fused_root = unnested_hlo->fused_expression_root(); ConstHloInstructionSet use_chain_endings; - if (fused_root->opcode() == HloOpcode::kReduce) { + if (IsReductionFromOrToContiguousDimensions(*fused_root)) { use_chain_endings.insert(fused_root); // Atomic.add of the reduction result can't be vectorized. 
cannot_be_vectorized++; } else { CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple); for (const HloInstruction* instr : fused_root->operands()) { - if (instr->opcode() == HloOpcode::kReduce) { + if (IsReductionFromOrToContiguousDimensions(*instr)) { // Atomic.add of the reduction result can't be vectorized. cannot_be_vectorized++; } else { @@ -3649,35 +3671,41 @@ bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo, std::tuple IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) { - int64 depth = 1; - int64 height = 1; - int64 width = 1; - bool is_row_reduction = true; + const Shape& input_shape = first_reduce->operand(0)->shape(); + bool is_row_reduction; + DimensionVector dims_in_elem; + std::tie(is_row_reduction, dims_in_elem) = + GetReductionKindAndContiguousComponents(input_shape, + first_reduce->dimensions()); + VLOG(10) << "is_row_reduction " << is_row_reduction << " " << dims_in_elem[0] + << " " << dims_in_elem[1] << " " << dims_in_elem[2]; + int64 tile_size_x = 1; int64 tile_size_y = 1; int64 block_size_z = 1; int64 num_threads_x = 1; int64 num_threads_y = 1; - const Shape& input_shape = first_reduce->operand(0)->shape(); - int64 num_input_elems = ShapeUtil::ElementsIn(input_shape); - int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape()); - int64 num_reduced_major, num_kept, num_reduced_minor; - std::tie(num_reduced_major, num_kept, num_reduced_minor) = - GetReductionToVectorDimensions(input_shape, first_reduce->dimensions()); - CHECK_EQ(num_output_elems, num_kept); bool dilated_x = true; - - if (num_kept == 1) { - // Scalar reduction is a special row reduction with depth = height = 1. - width = num_input_elems; - tile_size_x = kWarpSize * 16; - num_threads_x = kWarpSize; - } else if (num_reduced_minor == 1) { - // Column reduction reduces inputs with dimension [height, width], where - // width is the minor dimension, to dimension [width]. - height = num_reduced_major; - width = num_kept; - is_row_reduction = false; + if (is_row_reduction) { + if (dims_in_elem[1] == 1) { + // Scalar reduction is handled differently than the other kind of row + // reduction. + CHECK_EQ(dims_in_elem[0], 1); + tile_size_x = kWarpSize * 16; + num_threads_x = kWarpSize; + } else { + num_threads_x = kWarpSize; + if (dims_in_elem[2] % (kWarpSize * 64) == 0) { + tile_size_x = kWarpSize * 64; + } else { + tile_size_x = kWarpSize * 8; + block_size_z = 8; + while (dims_in_elem[0] % block_size_z != 0) { + block_size_z -= 1; + } + } + } + } else { // Column reduction without transpose doesn't require communication among // threads processing elements in the same tile. 
The current implementation // only support the use of one hardware thread block to process one block of @@ -3687,38 +3715,18 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( int64 hw_threads_per_block_limit = ThreadsPerBlockLimit(ir_emitter_context_->device_description()); if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, - num_kept)) { - tile_size_x = std::min(2 * hw_threads_per_block_limit, num_kept); + dims_in_elem[2])) { + tile_size_x = std::min(2 * hw_threads_per_block_limit, dims_in_elem[2]); num_threads_x = tile_size_x / 2; dilated_x = false; } else { - tile_size_x = std::min(hw_threads_per_block_limit, num_kept); + tile_size_x = std::min(hw_threads_per_block_limit, dims_in_elem[2]); num_threads_x = tile_size_x; } int64 kNumElementsPerPartialSum = 128; tile_size_y = kNumElementsPerPartialSum; - } else { - // Row reduction reduces inputs with dimension [depth, height, width], - // where width is the most minor dimension, to dimension [height] . - depth = num_reduced_major; - height = num_kept; - width = num_reduced_minor; - num_threads_x = kWarpSize; - if (width % (kWarpSize * 64) == 0) { - tile_size_x = kWarpSize * 64; - } else { - tile_size_x = kWarpSize * 8; - block_size_z = 8; - while (depth % block_size_z != 0) { - block_size_z -= 1; - } - } } - DCHECK_EQ(depth * height * width, num_input_elems); - VLOG(10) << "is_row_reduction " << is_row_reduction << depth << " " << height - << " " << width; - DimensionVector dims_in_elem{depth, height, width}; DimensionVector req_block_sizes{block_size_z, 1, 1}; llvm_ir::KernelMappingScheme mapping_scheme( dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, @@ -3727,7 +3735,8 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( return std::make_tuple(mapping_scheme, is_row_reduction); } -Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { +Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( + HloInstruction* unnested_hlo) { VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion @@ -3746,7 +3755,7 @@ Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { // Build an initializer thunk to initialize each reduction output. std::vector> thunks; for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (output_instructions[i]->opcode() != HloOpcode::kReduce) { + if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { continue; } TF_ASSIGN_OR_RETURN( @@ -3836,8 +3845,7 @@ Status IrEmitterUnnested::EmitConstantGlobals() { global_type, /*isConstant=*/should_emit_initializer, llvm::GlobalValue::ExternalLinkage, /*Initializer=*/initializer, - llvm_ir::AsStringRef( - llvm_ir::ConstantBufferAllocationToGlobalName(allocation))); + llvm_ir::ConstantBufferAllocationToGlobalName(allocation)); global_for_const->setAlignment(kConstantBufferAlignBytes); ir_emitter_context_->llvm_module()->getGlobalList().push_back( global_for_const); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 9890ce122df..d627ca9ef02 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -209,13 +209,15 @@ class IrEmitterUnnested : public IrEmitter { // Helper for writing extra outputs from inside a reduce kernel. 
Status EmitExtraOutputsForReduce( const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, + bool use_linear_index, absl::Span> extra_output_gens); // Generates code for reduction to contiguous dimensions. // - // Prerequisite: `IsReductionToVector(*unnested_hlo)` - Status EmitReductionToVector(HloInstruction* unnested_hlo); + // Prerequisite: `IsReductionFromOrToContiguousDimensions(*unnested_hlo)` + Status EmitReductionFromOrToContiguousDimensions( + HloInstruction* unnested_hlo); // Computes the KernelMappingScheme for the reduce HLO and indicates whether // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index e09b8fbd3ba..fbe22e3a18e 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -16,13 +16,19 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/kernel.h" namespace xla { namespace gpu { @@ -39,16 +45,6 @@ KernelThunk::KernelThunk(absl::Span args, Status KernelThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { tensorflow::mutex_lock lock(mutex_); - if (!loader_spec_) { - loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - loader_spec_->AddCudaPtxInMemory(executable.ptx(), kernel_name_); - - if (!executable.cubin().empty()) { - loader_spec_->AddCudaCubinInMemory( - reinterpret_cast(executable.cubin().data()), - kernel_name_); - } - } // Load the kernel into the device if necessary. // @@ -57,10 +53,12 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, // profiles. auto it = kernel_cache_.find(executor); if (kernel_cache_.end() == it) { - it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; - if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_); - } + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + CreateKernel(kernel_name_, args_.size(), executable.ptx(), + executable.cubin(), executor)); + + kernel_cache_.emplace(executor, std::move(kernel)); } return Status::OK(); @@ -85,27 +83,22 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, CHECK(it != kernel_cache_.end()) << "Initialize() not called for StreamExecutor " << executor; launch_dimensions = launch_dimensions_; - kernel = &it->second; + kernel = it->second.get(); } VLOG(3) << "Launching " << kernel->name(); - // Launch the kernel with potentially multiple blocks and threads. 
- static constexpr int kKernelArgsLimit = 1024; - auto kernel_args = absl::make_unique>(); + absl::InlinedVector buffer_args; for (const BufferAllocation* arg : args_) { - const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); - kernel_args->add_device_memory_argument(buf); - VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " (" + se::DeviceMemoryBase buf = + buffer_allocations.GetDeviceAddress(arg->index()); + VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " (" << buf.size() << "B)"; + buffer_args.push_back(buf); } auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - if (!stream->parent()->Launch( - stream, se::ThreadDim(launch_dimensions.threads_per_block()), - se::BlockDim(launch_dimensions.block_count()), *kernel, - *kernel_args)) { - return InternalError("Unable to launch kernel %s", kernel_name_); - } - return Status::OK(); + return ExecuteKernelOnStream(*kernel, buffer_args, + launch_dimensions.threads_per_block(), + launch_dimensions.block_count(), stream); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index f63db5c3696..2cea89e4e2a 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -84,12 +84,11 @@ class KernelThunk : public Thunk { // Describes how to load this kernel. ExecuteOnStream reuses this loader // specification for all executions. mutable tensorflow::mutex mutex_; - std::unique_ptr loader_spec_ GUARDED_BY(mutex_); // Loaded kernels for each `StreamExecutor`. Requires pointer stability of // values. - std::unordered_map kernel_cache_ - GUARDED_BY(mutex_); + std::unordered_map> + kernel_cache_ GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 698d2d51cc8..ca42807edd1 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 153aab97d9e..34966b197e7 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -21,12 +21,6 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" - #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" @@ -55,11 +49,17 @@ limitations under the License. 
#include "llvm/Transforms/IPO/Internalize.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Scalar.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/profiler/lib/traceme.h" namespace xla { namespace gpu { @@ -140,10 +140,9 @@ static string GetSmName(std::pair compute_capability) { // Convenience function for producing a name of a temporary compilation product // from the input filename. -string MakeNameForTempProduct(const std::string& input_filename, +string MakeNameForTempProduct(absl::string_view input_filename, absl::string_view extension) { - return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename( - llvm_ir::AsString(input_filename))), + return ReplaceFilenameExtension(tensorflow::io::Basename(input_filename), extension); } @@ -220,7 +219,6 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level, builder.Inliner = llvm::createAlwaysInlinerLegacyPass(); } - builder.DisableUnitAtATime = false; builder.DisableUnrollLoops = opt_level == 0; builder.LoopVectorize = opt_level > 0; builder.SLPVectorize = opt_level > 1 && size_level < 2; @@ -254,11 +252,8 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { llvm::buffer_ostream pstream(stream); // The extension is stripped by IrDumpingPassManager, so we need to // get creative to add a suffix. - string module_id(llvm_ir::AsString(module->getModuleIdentifier())); IrDumpingPassManager codegen_passes( - ReplaceFilenameExtension( - absl::string_view(tensorflow::io::Basename(module_id)), - "-nvptx.dummy"), + MakeNameForTempProduct(module->getModuleIdentifier(), "-nvptx.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -336,7 +331,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // If the module has no functions or globals, there's nothing to compile. Just // return an empty string. if (module->empty() && module->global_empty()) { - VLOG(2) << "Module '" << llvm_ir::AsString(module->getName()) + VLOG(2) << "Module '" << module->getName().str() << "' is empty. 
Skipping compilation."; return string(); } @@ -492,11 +487,10 @@ StatusOr CompileToPtx(llvm::Module* module, string ptx; { - tensorflow::tracing::ScopedActivity activity( - "Compiling IR", llvm_ir::AsString(module->getName()), - /*is_expensive=*/true); - XLA_SCOPED_LOGGING_TIMER("Compile module " + - llvm_ir::AsString(module->getName())); + tensorflow::profiler::TraceMe activity( + [&] { return absl::StrCat("Compiling IR:", module->getName().str()); }, + tensorflow::profiler::TraceMeLevel::kInfo); + XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); TF_ASSIGN_OR_RETURN( ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config, libdevice_dir_path)); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 02e1207f377..a00900fabab 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -45,15 +45,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, } bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { - // We can fuse reduces and loop fusions. Elementwise instructions can be fused - // with any other instruction. - // TODO(b/112957171): This should use the same isFusible logic as - // instruction_fusion. - return instr->IsFusible() && - (IsInputFusibleReduction(*instr) || - (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop) || - instr->IsElementwise()); + return IsFusibleAsMultiOutputFusionRoot(*instr); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -92,8 +84,8 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, CHECK(instr1->opcode() == HloOpcode::kFusion); if ((instr2->opcode() == HloOpcode::kFusion && instr1->fusion_kind() != instr2->fusion_kind()) || - (IsReductionToVector(*instr2) && - instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { + (IsReductionFromOrToContiguousDimensions(*instr2) && + instr1->IsLoopFusion())) { return false; } @@ -155,10 +147,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " is a constant."; continue; } - const bool is_loop_fusion = - producer->opcode() == HloOpcode::kFusion && - producer->fusion_kind() == HloInstruction::FusionKind::kLoop; - if (!producer->IsElementwise() && !is_loop_fusion) { + if (!producer->IsElementwise() && !producer->IsLoopFusion()) { VLOG(3) << producer->name() << " is not a loop fusion."; continue; } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 40b87b16a19..2aa61b8951a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -303,21 +303,21 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { - p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) - ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0) + ROOT mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1) } fused_computation_2 { - p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0) const.2 = f32[] constant(0) - ROOT 
reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation + ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2), dimensions={0,3}, to_apply=scalar_add_computation } ENTRY entry { - p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) - fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1 - fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 - ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2) + p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0) + fusion.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0}) tuple(fusion.1, fusion.2) })")) .ValueOrDie(); ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); @@ -329,7 +329,8 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) - ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, + f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) } fused_computation_2 { @@ -340,11 +341,16 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { ENTRY entry { p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) - fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 - fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, + f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, + calls=fused_computation_1 + fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, + calls=fused_computation_2 gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 - ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, + f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) + tuple(gte0, gte1, fusion.2) })")) .ValueOrDie(); ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); @@ -360,25 +366,32 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { - p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) - mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) - exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) - ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,2]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, + f32[8,1,5,16,1,2]{5,4,3,2,1,0}) tuple(mul, exp) } fused_computation_2 { - p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0) const.2 = f32[] constant(0) - ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, 
const.2), dimensions={3}, to_apply=scalar_add_computation + ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2), + dimensions={0,3}, to_apply=scalar_add_computation } ENTRY entry { - p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) - fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 - fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 - gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 - gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 - ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, + f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, + calls=fused_computation_1 + fusion.2 = f32[8,1,5,1,2]{4,3,2,1,0} fusion(p0), kind=kLoop, + calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, + f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) + tuple(gte0, gte1, fusion.2) })")) .ValueOrDie(); ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); @@ -387,11 +400,12 @@ TEST_F(MultiOutputFusionTest, TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = f32[32,32,32]{2,1,0} parameter(0) c0 = f32[] constant(0) - exp = f32[2,2,2]{2,1,0} exponential(p0) - reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation - ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + exp = f32[32,32,32]{2,1,0} exponential(p0) + reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2}, + to_apply=scalar_add_computation + ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp) })")) .ValueOrDie(); ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); @@ -407,18 +421,19 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_add { - p0.1 = f32[2,2,2]{2,1,0} parameter(0) - p1.1 = f32[2,2,2]{2,1,0} parameter(1) - ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + p0.1 = f32[32,32,32]{2,1,0} parameter(0) + p1.1 = f32[32,32,32]{2,1,0} parameter(1) + ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1) } ENTRY reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) - p1 = f32[2,2,2]{2,1,0} parameter(1) + p0 = f32[32,32,32]{2,1,0} parameter(0) + p1 = f32[32,32,32]{2,1,0} parameter(1) c0 = f32[] constant(0) - add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add - reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation - ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add) + add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add + reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2}, + to_apply=scalar_add_computation + ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add) })")) .ValueOrDie(); ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); @@ -434,31 +449,37 @@ TEST_F(MultiOutputFusionTest, 
ProducerConsumerFusionLoopFusionAndReduce) { TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { - p1.1 = f32[2,2,2]{2,1,0} parameter(1) + p1.1 = f32[32,32,32]{2,1,0} parameter(1) c0 = f32[] constant(0) - broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) - p0.1 = f32[2,2,2]{2,1,0} parameter(0) - ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={} + greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1, + f32[32,32,32]{2,1,0} broadcast), direction=GT + p0.1 = f32[32,32,32]{2,1,0} parameter(0) + ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0} + greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast) } fused_reduce { - p0.2 = f32[2,2,2]{2,1,0} parameter(0) + p0.2 = f32[32,32,32]{2,1,0} parameter(0) c1 = f32[] constant(0) - r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation - mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) - r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation - ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2}, + to_apply=scalar_add_computation + mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2}, + to_apply=scalar_add_computation + ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2) } ENTRY reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) - p1 = f32[2,2,2]{2,1,0} parameter(1) - select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select - fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce - gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 - gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 - ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + p0 = f32[32,32,32]{2,1,0} parameter(0) + p1 = f32[32,32,32]{2,1,0} parameter(1) + select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput, + calls=fused_reduce + gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) + tuple(gte1, gte1, select) })")) .ValueOrDie(); ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); @@ -482,9 +503,12 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { fused_reduce { p0.2 = f32[2,2,2]{2,1,0} parameter(0) - mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, + f32[2,2,2]{2,1,0} p0.2) + broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1} c1 = f32[] constant(0) - ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add_computation + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast, + f32[] c1), dimensions={1,3}, to_apply=scalar_add_computation } ENTRY reduce { @@ -502,30 +526,36 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionFp16LoopFusionAndReduceFusion) { auto module = 
ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { - p1.1 = f16[2,2,2]{2,1,0} parameter(1) + p1.1 = f16[32,32,32]{2,1,0} parameter(1) c0 = f16[] constant(0) - broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast) - p0.1 = f16[2,2,2]{2,1,0} parameter(0) - ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast) + broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={} + greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1, + f16[32,32,32]{2,1,0} broadcast), direction=GT + p0.1 = f16[32,32,32]{2,1,0} parameter(0) + ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0} + greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast) } fused_reduce { - p0.2 = f16[2,2,2]{2,1,0} parameter(0) - convert = f32[2,2,2]{2,1,0} convert(p0.2) + p0.2 = f16[32,32,32]{2,1,0} parameter(0) + convert = f32[32,32,32]{2,1,0} convert(p0.2) c1 = f32[] constant(0) - r1 = f32[2,2]{1,0} reduce(convert, c1), dimensions={2}, to_apply=scalar_add_computation - mul = f32[2,2,2]{2,1,0} multiply(convert, convert) - r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation - ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2}, + to_apply=scalar_add_computation + mul = f32[32,32,32]{2,1,0} multiply(convert, convert) + r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2}, + to_apply=scalar_add_computation + ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2) } ENTRY reduce { - p0 = f16[2,2,2]{2,1,0} parameter(0) - p1 = f16[2,2,2]{2,1,0} parameter(1) - select = f16[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select - fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce - gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 - gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 - ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + p0 = f16[32,32,32]{2,1,0} parameter(0) + p1 = f16[32,32,32]{2,1,0} parameter(1) + select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[32,32]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, + calls=fused_reduce + gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0}) + tuple(gte1, gte1, select) })")) .ValueOrDie(); ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); @@ -548,7 +578,7 @@ TEST_F(MultiOutputFusionTest, copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) c0 = f16[] constant(0) broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} - greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) } fused_reduce { diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 3051db3af4a..c00edae9540 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #if GOOGLE_CUDA +#include "absl/container/flat_hash_set.h" #include "absl/synchronization/blocking_counter.h" #include "third_party/nccl/nccl.h" #include "tensorflow/core/lib/core/blocking_counter.h" @@ -76,6 +77,42 @@ struct ParticipantData { // This manager is responsible for establishing communication channels and // ultimately enqueueing the NCCL library operation onto the participating // streams. +// +// Implementation note: We make an effort to avoid initializing nccl +// communciation channels too often, as this is expensive. +// +// Ideally, we'd set up a nccl channel between each pair of devices that needs +// to communicate, and close each channel when the GPUs won't be communicating +// again "for a long time" (because channels hold memory on the GPU). As a +// simplification to this ideal, we adopt the following policy. +// +// - We maintain a set of GPUs that are "actively participating" in +// cross-device communications. That set of GPUs is always connected as a +// clique, using ncclCommInitAll. +// +// - When a NcclAllReduceThunk touches a new GPU, we tear down the old clique +// and build a new, bigger one. +// +// - All GPUs ever touched by a thunk are considered "actively in use" by that +// thunk until the thunk is destroyed. Destroying the thunk decrements the +// refcount of the GPUs it's touched, and if that refcount goes to 0 +// (meaning, some GPUs are no longer in use by any thunk), we tear down the +// clique and build a new, smaller one. +// +// This approximation is justified because: +// +// - Currently the only collective operation we support is AllReduce, which +// requires a clique. When we support point-to-point operations, we may not +// want to build a communication clique. +// +// - Tearing down and creating a new thunk is tantamount to running the whole +// XLA:GPU compiler. This is expensive, so shouldn't happen "too often" to +// cause thrashing here. +// +// - XLA executables already keep resources on the GPU tied to the lifetime of +// the executable (e.g. constants stored in GPU memory), so tying the +// lifetime of the nccl communication channels to the lifetime of the +// executable is consistent. class GlobalRendezvousManager { public: // The GpuExecutable-executing threads call this in order to a) establish the @@ -98,18 +135,38 @@ class GlobalRendezvousManager { return current_generation_; } - private: - // Called by the primary thread to set up the communication links. + // Increments the refcount of a GPU in our accounting of which devices are + // "actively participating" in cross-device operations. // - // TODO(b/125951860): This performs lots of (presumably) unnecessary host-side - // synchronization so that we can be paranoid about semantics in the earliest - // implementation. In the limit we should only need to synchronize host - // replica threads when the "number of replicas" or "participating device - // ordinals" change, to set up a new NCCL "communication" context, at which - // point we can enqueue onto device streams without host synchronization in - // our code -- this will likely be helpful for "lots of little AllReduce" - // cases. - Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + // This doesn't actually do anything other than increment the refcount. If + // the GPU added here is novel, we'll rebuild the nccl communication clique + // when we actually go do the communication. 
+ void AddrefParticipatingDevice(int device_ordinal); + + // Decrements the refcount of a set of GPUs in our accounting of which devices + // are "actively participating" in cross-device operations. + // + // If one or more GPUs' refcounts to go 0, we immediately destroy the whole + // nccl communication clique. We'll rebuild a new, smaller clique the next + // time it's used. + void DecrefParticipatingDevices(absl::Span device_ordinals); + + // Gets the set of devices that have a NCCL channel currently open. This is + // primarily for testing. + absl::flat_hash_set DevicesWithOpenNcclChannels() const { + absl::flat_hash_set devices; + tensorflow::mutex_lock lock(mutex_); + for (const auto& kv : comms_) { + devices.insert(kv.first); + } + return devices; + } + + private: + // Destroys the current nccl communication clique and builds a new one + // connecting the given devices. + Status ReinitializeNcclClique(const absl::flat_hash_set& device_ordinals) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Called when all necessary participants are present, the functionality // that's implemented by all executing threads lives in here. @@ -118,28 +175,51 @@ class GlobalRendezvousManager { // Puts all state back into a "reset" state for the next generation of // AllReduce requests. void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) { - for (ncclComm_t& comm : comms_) { - ncclCommDestroy(comm); - } - comms_.clear(); participants_.clear(); current_generation_++; initialized_ = false; done_ = absl::nullopt; } - tensorflow::mutex mutex_; + mutable tensorflow::mutex mutex_; tensorflow::condition_variable all_participants_present_; tensorflow::condition_variable deinitialized_; - // Communication handles that correspond to the participants below. - std::vector comms_ GUARDED_BY(mutex_); - Status initialize_status_ GUARDED_BY(mutex_); std::vector participants_ GUARDED_BY(mutex_); int64 current_generation_ GUARDED_BY(mutex_) = 0; bool initialized_ GUARDED_BY(mutex_) = false; + struct Comm { + explicit Comm(ncclComm_t nccl_comm) : nccl_comm(nccl_comm) {} + + // Movable, but not copyable. + Comm(Comm&& c) : nccl_comm(c.nccl_comm) { c.nccl_comm.reset(); } + Comm& operator=(Comm&& c) { + nccl_comm = c.nccl_comm; + c.nccl_comm.reset(); + return *this; + } + Comm(const Comm&) = delete; + Comm& operator=(const Comm&) = delete; + + absl::optional nccl_comm; + + ~Comm() { + if (nccl_comm.has_value()) { + VLOG(3) << absl::StreamFormat("Destroying comm %p", *nccl_comm); + ncclCommDestroy(*nccl_comm); + } + } + }; + // Communication handles for our NCCL clique. Key is device ordinal. + absl::flat_hash_map comms_ GUARDED_BY(mutex_); + + // Refcounts of which devices are "actively participating" in all-reduces. + // These devices don't necessarily have an open comm, but the next time we run + // an operation, we'll create a NCCL clique between all of them. + absl::flat_hash_map device_refcounts_ GUARDED_BY(mutex_); + // The participating threads wait for this to count down in order to know we // can begin the teardown process. absl::optional done_; @@ -151,11 +231,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) { return participants_.size() >= participant.replica_count; }; - // We remember the participant index at which we are inserted and use that - // same index for referring to auxiliary metadata (e.g. the ncclComm_t handle - // index) below. 
- int64 index; - { tensorflow::mutex_lock lock(mutex_); @@ -171,7 +246,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) { "participants; existing: %s; submitted: %s)", participants_.back().ToString(), participant.ToString()); } - index = participants_.size(); participants_.push_back(participant); if (all_participants_present()) { @@ -205,11 +279,35 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) { VLOG(3) << "Primary initializing accounting data."; initialized_ = true; done_.emplace(participant.replica_count); - initialize_status_ = InitializeCommunicationChannels(); - VLOG(3) << "Done initializing communication channels; status: " - << initialize_status_; - if (!initialize_status_.ok()) { - DeinitializeGeneration(); + + // Check if all participants_ are in comms_. If not, we will rebuild the + // clique to include them. (This can't be spelled using absl::c_any_of + // because it needs to touch comms_ and tensorflow::mutex lacks an + // AssertHeld() function that would let us assert that the lambda is run + // while holding the lock.) + bool new_devices_found = false; + for (const auto& p : participants_) { + if (!comms_.contains(p.device_ordinal)) { + new_devices_found = true; + break; + } + } + + if (new_devices_found) { + absl::flat_hash_set new_clique_device_ordinals; + for (const auto& kv : comms_) { + new_clique_device_ordinals.insert(kv.first); + } + for (const auto& p : participants_) { + new_clique_device_ordinals.insert(p.device_ordinal); + } + + initialize_status_ = ReinitializeNcclClique(new_clique_device_ordinals); + VLOG(3) << "Done initializing communication channels; status: " + << initialize_status_; + if (!initialize_status_.ok()) { + DeinitializeGeneration(); + } } } @@ -218,7 +316,7 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) { return initialize_status_; } - comm = comms_[index]; + comm = *comms_.at(participant.device_ordinal).nccl_comm; // Drop the lock at the end of scope so other participants may enter. 
} @@ -259,22 +357,30 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) { return all_reduce_status; } -Status GlobalRendezvousManager::InitializeCommunicationChannels() { - std::vector ordinals; - for (ParticipantData& data : participants_) { - ordinals.push_back(data.device_ordinal); - } - comms_.resize(ordinals.size()); - VLOG(3) << "Participants: " << participants_.size() - << "; initializing comms."; - ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(), - /*devlist=*/ordinals.data()); +Status GlobalRendezvousManager::ReinitializeNcclClique( + const absl::flat_hash_set& device_ordinals) { + comms_.clear(); + + std::vector ordinals_vec(device_ordinals.begin(), device_ordinals.end()); + std::vector comm_vec; + comm_vec.resize(device_ordinals.size()); + + VLOG(3) << absl::StreamFormat( + "Initializing nccl comms for participant devices {%s}", + absl::StrJoin(ordinals_vec, ", ")); + ncclResult_t result = ncclCommInitAll(comm_vec.data(), comm_vec.size(), + /*devlist=*/ordinals_vec.data()); if (result != ncclSuccess) { - comms_.clear(); return InternalError( "Failed to initialize NCCL communication channels for %d participants: " "%s", - participants_.size(), ncclGetErrorString(result)); + ordinals_vec.size(), ncclGetErrorString(result)); + } + + for (int64 i = 0; i < ordinals_vec.size(); ++i) { + VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p", + ordinals_vec[i], comm_vec[i]); + CHECK(comms_.emplace(ordinals_vec[i], Comm{comm_vec[i]}).second); } return Status::OK(); } @@ -289,6 +395,11 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant, << " on device: " << participant.device_ordinal; void* send_buffer = participant.source_data.opaque(); void* recv_buffer = participant.destination_data.opaque(); + VLOG(3) << absl::StreamFormat( + "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, " + "datatype=ncclFloat, op=ncclSum, comm=%p, stream=%p)", + send_buffer, recv_buffer, participant.element_count, + static_cast(comm), cu_stream); ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer, /*count=*/participant.element_count, /*datatype=*/ncclFloat, @@ -304,6 +415,36 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant, return Status::OK(); } +void GlobalRendezvousManager::AddrefParticipatingDevice(int device_ordinal) { + // Addref'ing a device doesn't do anything other than increment its refcount. + // We'll update our nccl clique if necessary during the next call to + // SubmitParticipant. + tensorflow::mutex_lock lock(mutex_); + device_refcounts_[device_ordinal]++; +} + +void GlobalRendezvousManager::DecrefParticipatingDevices( + absl::Span device_ordinals) { + // Decref'ing devices causes us to destroy the nccl clique if any devices were + // removed due to having refcount 0. We'll rebuild the new, smaller clique + // during the next call to SubmitParticipant. 
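For the step where a freshly initialized batch of communicators is keyed by device ordinal, the following compilable sketch may help. FakeComm and FakeCommInitAll are placeholders standing in for the opaque NCCL handle and its bulk-init call; they are not real NCCL APIs, and error handling is reduced to returning an empty map.

```cpp
#include <cstdio>
#include <map>
#include <set>
#include <vector>

using FakeComm = int;  // placeholder for an opaque ncclComm_t-like handle

// Placeholder bulk "init all" call: returns one handle per device, in the
// same order as the devices passed in. Returns 0 on success.
static int FakeCommInitAll(FakeComm* comms, int n, const int* devices) {
  for (int i = 0; i < n; ++i) comms[i] = 1000 + devices[i];
  return 0;
}

// Rebuilds the ordinal -> handle map from scratch for the given devices.
std::map<int, FakeComm> ReinitClique(const std::set<int>& device_ordinals) {
  std::vector<int> ordinals(device_ordinals.begin(), device_ordinals.end());
  std::vector<FakeComm> comms(ordinals.size());
  if (FakeCommInitAll(comms.data(), static_cast<int>(comms.size()),
                      ordinals.data()) != 0) {
    return {};  // the real code surfaces an error Status instead
  }
  std::map<int, FakeComm> by_ordinal;
  for (size_t i = 0; i < ordinals.size(); ++i) {
    std::printf("device %d -> comm %d\n", ordinals[i], comms[i]);
    by_ordinal[ordinals[i]] = comms[i];
  }
  return by_ordinal;
}
```

The important property illustrated here is that the handles come back in the same order as the ordinals passed to the bulk-init call, so the map can be built with a simple parallel iteration.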
+ tensorflow::mutex_lock lock(mutex_); + bool removed_device = false; + for (int device_ordinal : device_ordinals) { + auto it = device_refcounts_.find(device_ordinal); + CHECK(it != device_refcounts_.end()); + it->second--; + if (it->second == 0) { + device_refcounts_.erase(it); + removed_device = true; + } + } + + if (removed_device) { + comms_.clear(); + } +} + static GlobalRendezvousManager* GetGlobalRendezvous() { static auto* manager = new GlobalRendezvousManager; return manager; @@ -311,6 +452,11 @@ static GlobalRendezvousManager* GetGlobalRendezvous() { } // namespace +/*static*/ absl::flat_hash_set +NcclAllReduceThunk::DevicesWithOpenNcclChannels() { + return GetGlobalRendezvous()->DevicesWithOpenNcclChannels(); +} + Status NcclAllReduceThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { @@ -327,8 +473,32 @@ Status NcclAllReduceThunk::ExecuteOnStream( participant.stream = stream; participant.originator = this; + // We currently say that that all GPUs this thunk has ever touched are + // "actively participating" in cross-device operations, until the thunk itself + // is destroyed. + // + // This policy is an attempt to avoid thrashing the GPU (ncclCommInitAll is + // very expensive) while also freeing resources on the GPUs when we can. The + // idea is, creating new thunks is tantamount to running the whole XLA:GPU + // compiler stack, so that shouldn't happen terribly often. + bool new_device; + { + tensorflow::mutex_lock lock(mu_); + new_device = devices_seen_.insert(participant.device_ordinal).second; + } + if (new_device) { + GetGlobalRendezvous()->AddrefParticipatingDevice( + participant.device_ordinal); + } + return GetGlobalRendezvous()->SubmitParticipant(std::move(participant)); } + +NcclAllReduceThunk::~NcclAllReduceThunk() { + GetGlobalRendezvous()->DecrefParticipatingDevices( + std::vector(devices_seen_.begin(), devices_seen_.end())); +} + #else Status NcclAllReduceThunk::ExecuteOnStream( @@ -339,6 +509,13 @@ Status NcclAllReduceThunk::ExecuteOnStream( "compiler, which is necessary to build the NCCL source library."); } +NcclAllReduceThunk::~NcclAllReduceThunk() = default; + +/*static*/ absl::flat_hash_set +NcclAllReduceThunk::DevicesWithOpenNcclChannels() { + return {}; +} + #endif // GOOGLE_CUDA NcclAllReduceThunk::NcclAllReduceThunk( diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index 1a8d1356c00..9ff4fb187af 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -16,11 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -38,12 +40,21 @@ class NcclAllReduceThunk : public Thunk { // error. static bool NcclIsEnabled(); + // Gets the set of devices that have a NCCL channel open. 
This is primarily + // for testing. + // + // (Indeed, because the NCCL channels are a global variable, in the real + // world, the value returned here is stale as soon as you read it, so it's not + // clear how you *could* use it for anything other than tests.) + static absl::flat_hash_set DevicesWithOpenNcclChannels(); + // TODO(b/125951860): Plumb more datatypes / reduction operators. Initial // implementation is simply F32 summation. NcclAllReduceThunk(int64 replica_count, int64 element_count, const BufferAllocation::Slice& source_buffer, const BufferAllocation::Slice& destination_buffer, const HloInstruction* all_reduce); + ~NcclAllReduceThunk() override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, @@ -54,6 +65,10 @@ class NcclAllReduceThunk : public Thunk { const int64 element_count_; const BufferAllocation::Slice source_buffer_; const BufferAllocation::Slice destination_buffer_; + + tensorflow::mutex mu_; + // Set of GPUs that ExecuteOnStream has been called on. + absl::flat_hash_set devices_seen_ GUARDED_BY(mu_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index ad75e2dd434..d8249e99d42 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" #include + #include #include #include // NOLINT(build/c++11): only using std::call_once, not mutex. @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" @@ -82,6 +84,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/slice_sinker.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/stable_sort_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" @@ -103,6 +106,7 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" namespace xla { @@ -116,30 +120,18 @@ namespace { namespace tracing = tensorflow::tracing; -// Returns a vector of potential locations of the CUDA root directory. -std::vector GetCudaRootCandidates( - const HloModuleConfig& hlo_module_config) { - std::vector potential_cuda_roots = tensorflow::CandidateCudaRoots(); - - // "." is our last resort, even though it probably won't work. - potential_cuda_roots.push_back("."); - - // CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has - // highest priority. 
- string xla_gpu_cuda_data_dir = - hlo_module_config.debug_options().xla_gpu_cuda_data_dir(); - if (!xla_gpu_cuda_data_dir.empty()) { - potential_cuda_roots.insert(potential_cuda_roots.begin(), - xla_gpu_cuda_data_dir); - } - return potential_cuda_roots; +static std::vector CandidateCudaRoots( + const HloModuleConfig& config) { + return tensorflow::CandidateCudaRoots( + config.debug_options().xla_gpu_cuda_data_dir()); } void PrintCantFindCudaMessage(absl::string_view msg, const HloModuleConfig& hlo_module_config) { LOG(WARNING) << msg; - LOG(WARNING) << "Searched in the following directories:"; - for (const auto& dir : GetCudaRootCandidates(hlo_module_config)) { + LOG(WARNING) << "Searched for CUDA in the following directories:"; + + for (const auto& dir : CandidateCudaRoots(hlo_module_config)) { LOG(WARNING) << " " << dir; } LOG(WARNING) @@ -150,8 +142,7 @@ void PrintCantFindCudaMessage(absl::string_view msg, // Returns the directory containing nvvm libdevice files. string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { - const auto& candidate_dirs = GetCudaRootCandidates(hlo_module_config); - for (const string& cuda_root : candidate_dirs) { + for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) { string libdevice_dir = tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); VLOG(2) << "Looking for libdevice at " << libdevice_dir; @@ -161,9 +152,9 @@ string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { } } PrintCantFindCudaMessage( - "Can't find directory containing CUDA libevice. This may result in " - "compilation or runtime failures, if the program we try to run uses " - "routines from libdevice.", + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may " + "result in compilation or runtime failures, if the program we try to run " + "uses routines from libdevice.", hlo_module_config); // GetCudaRotCandidates always inclues ".", but but if everything fails, we @@ -176,12 +167,16 @@ string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { // It takes a compiler pointer, as passes may compile and execute HLOs on the // fly for cuDNN verification or other purposes. Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator, + se::DeviceMemoryAllocator* device_allocator, Compiler* compiler) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + // Remove zero-sized HLO from the input so that other passes don't have to + // handle it. + pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( @@ -194,7 +189,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // We need a cost model for GPUs. Currently, do nothing. return false; }; - pipeline.AddPass(false); + pipeline.AddPass(); pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/true); @@ -234,6 +229,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pass.AddPass(); pass.AddPass(); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -421,78 +417,6 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { return pipeline.Run(hlo_module).status(); } -// Prints a warning if the ptxas at ptxas_path has known bugs. -// -// Only prints a warning the first time it's called for a particular value of -// ptxas_path. 
-void WarnIfBadPtxasVersion(const string& ptxas_path) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static std::unordered_set* seen_ptxas_paths GUARDED_BY(mu) = - new std::unordered_set(); - - tensorflow::mutex_lock lock(mu); - if (!seen_ptxas_paths->insert(ptxas_path).second) { - // Already checked this ptx binary, nothing to do. - return; - } - - tensorflow::SubProcess ptxas; - ptxas.SetProgram(ptxas_path, {ptxas_path, "--version"}); - ptxas.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); - if (!ptxas.Start()) { - LOG(WARNING) << "Couldn't invoke " << ptxas_path << " --version"; - return; - } - - string out; - int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out, - /*stderr_output=*/nullptr); - if (exit_code != 0) { - LOG(WARNING) << "Running " << ptxas_path << " --version returned " - << exit_code; - return; - } - - int64 vmaj, vmin, vdot; - string vmaj_str, vmin_str, vdot_str; - if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, - &vmin_str, &vdot_str) || - !absl::SimpleAtoi(vmaj_str, &vmaj) || - !absl::SimpleAtoi(vmin_str, &vmin) || - !absl::SimpleAtoi(vdot_str, &vdot)) { - LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path - << " --version:\n" - << out; - return; - } - - // We need ptxas >= 9.0 as a hard requirement, because we compile targeting - // PTX 6.0. An older ptxas will just fail to compile any of our code. - // - // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some - // address calculations with large offsets (e.g. "load ptr + large_constant"), - // b/70245379. - // - // ptxas 9.1.121 miscompiles some large multioutput fusions, again in a way - // that appears related to address calculations, b/111107644. ptxas 9.2.88 - // appears to work, as far as we can tell. - if (vmaj < 9) { - LOG(ERROR) - << "You are using ptxas 8.x, but XLA requires ptxas 9.x (and strongly " - "prefers >= 9.2.88). Compilation of XLA kernels below will likely " - "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " - "binary is sufficient."; - } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) { - LOG(WARNING) - << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." - << vdot - << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to " - "miscompile XLA code, leading to incorrect results or " - "invalid-address errors.\n\nYou do not need to update to CUDA " - "9.2.88; cherry-picking the ptxas binary is sufficient."; - } -} - // Prints a warning if the ptx->sass JIT in the driver has known bugs. // // Using such a driver only a problem if we fail to use ptxas to compile our ptx @@ -533,80 +457,6 @@ void WarnIfBadDriverJITVersion() { }); } -// Compiles the given PTX string using ptxas and returns the resulting machine -// code (i.e. a cubin) as a byte array. 
-StatusOr> CompilePtx( - const string& ptx, int cc_major, int cc_minor, - const HloModuleConfig& hlo_module_config) { - tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); - auto env = tensorflow::Env::Default(); - string ptxas_path; - for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { - ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); - VLOG(2) << "Looking for ptxas at " << ptxas_path; - if (env->FileExists(ptxas_path).ok()) { - break; - } - } - TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); - VLOG(2) << "Using ptxas at " << ptxas_path; - - WarnIfBadPtxasVersion(ptxas_path); - - // Write ptx into a temporary file. - string ptx_path; - if (!env->LocalTempFilename(&ptx_path)) { - return InternalError("couldn't get temp PTX file name"); - } - auto ptx_cleaner = tensorflow::gtl::MakeCleanup([&ptx_path] { - TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(ptx_path)); - }); - - TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_path, ptx)); - VLOG(2) << "ptx written to: " << ptx_path; - - // Invoke ptxas and collect its output. - string cubin_path; - if (!env->LocalTempFilename(&cubin_path)) { - return InternalError("couldn't get temp CUBIN file name"); - } - auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] { - // CUBIN file may never be created, so the failure to delete it should not - // produce TF error. - tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); - }); - tensorflow::SubProcess ptxas_info_dumper; - std::vector ptxas_args = { - ptxas_path, ptx_path, "-o", cubin_path, - absl::StrCat("-arch=sm_", cc_major, cc_minor)}; - if (VLOG_IS_ON(2)) { - ptxas_args.push_back("-v"); - } - if (hlo_module_config.debug_options().xla_gpu_disable_ptxas_optimizations()) { - ptxas_args.push_back("-O0"); - } - ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); - ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, - tensorflow::ACTION_PIPE); - if (!ptxas_info_dumper.Start()) { - return InternalError("Failed to launch ptxas"); - } - string stderr_output; - int exit_status = ptxas_info_dumper.Communicate( - /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); - XLA_LOG_LINES(tensorflow::INFO, stderr_output); - if (exit_status != 0) { - return InternalError("ptxas exited with non-zero error code %d", - exit_status); - } - - // Read in the result of compilation and return it as a byte vector. - string cubin; - TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), - cubin_path, &cubin)); - std::vector cubin_vector(cubin.begin(), cubin.end()); - return cubin_vector; -} } // namespace @@ -616,14 +466,12 @@ NVPTXCompiler::NVPTXCompiler() StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { // We dump the post-optimization HLO in RunBackend so no need to dump it here. 
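The HLO pass pipeline is now scoped in a profiler annotation built from a lazily evaluated name. A minimal usage sketch of that pattern is shown below; the helper function and module name are illustrative, but the TraceMe constructor and level enum are the ones used in the change itself.

```cpp
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/profiler/lib/traceme.h"

// Hypothetical helper: wraps one compilation step in a profiler annotation.
// The event name is produced by a lambda, so the string is only built when
// the profiler is actually recording at this level.
void RunTracedStep(const std::string& module_name) {
  tensorflow::profiler::TraceMe activity(
      [&] { return absl::StrCat("HLO Transforms:", module_name); },
      tensorflow::profiler::TraceMeLevel::kInfo);
  // ... work attributed to this trace event while `activity` is in scope ...
}
```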
- VLOG(3) << "*** HLO Before Optimization"; - XLA_VLOG_LINES(3, module->ToString()); - XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); - tracing::ScopedActivity activity("HLO Transforms", module->name(), - /*is_expensive=*/true); + tensorflow::profiler::TraceMe activity( + [&] { return absl::StrCat("HLO Transforms:", module->name()); }, + tensorflow::profiler::TraceMeLevel::kInfo); TF_RETURN_IF_ERROR( OptimizeHloModule(module.get(), stream_exec, device_allocator, this)); @@ -634,7 +482,7 @@ StatusOr> NVPTXCompiler::RunHloPasses( StatusOr> NVPTXCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend"); TF_RET_CHECK(stream_exec != nullptr); @@ -674,23 +522,11 @@ StatusOr> NVPTXCompiler::RunBackend( [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); - // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() - // include headers, so no need for us to print them ourselves. - XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); - XLA_VLOG_LINES(2, buffer_assignment->ToString()); - VLOG(3) << "*** HLO After Optimization"; - XLA_VLOG_LINES(3, module->ToString()); - const string xla_dump_optimized_hlo_proto_to = - module->config().debug_options().xla_dump_optimized_hlo_proto_to(); - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *buffer_assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); - } + DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); - IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), - &stream_exec->GetDeviceDescription(), - &llvm_module); + IrEmitterContext ir_emitter_context( + module.get(), buffer_assignment.get(), stream_exec->platform(), + &stream_exec->GetDeviceDescription(), &llvm_module); HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, @@ -704,26 +540,16 @@ StatusOr> NVPTXCompiler::RunBackend( } if (user_pre_optimization_hook_) { - TF_CHECK_OK(user_pre_optimization_hook_(llvm_module)); + user_pre_optimization_hook_(llvm_module); } string ir_module_string_before_opt; const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - if (VLOG_IS_ON(3) || embed_ir_in_executable) { + if (embed_ir_in_executable) { ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); - VLOG(3) << "LLVM module before optimizations:"; - XLA_VLOG_LINES(3, ir_module_string_before_opt); } - const string& ir_dump_directory = - module->config().debug_options().xla_dump_ir_to(); - - if (!ir_dump_directory.empty()) { - TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( - /*directory_name=*/ir_dump_directory, - /*hlo_module_name=*/module->name(), llvm_module, - /*optimized=*/false)); - } + llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/false); { XLA_SCOPED_LOGGING_TIMER( @@ -737,7 +563,7 @@ StatusOr> NVPTXCompiler::RunBackend( << "Invalid LLVM IR before optimizations:\n" << err_stream.str() << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " - "Rerun with --xla_dump_ir_to to get the IR. "; + "Rerun with --xla_dump_to to get the IR. 
"; } string libdevice_dir; @@ -770,45 +596,26 @@ StatusOr> NVPTXCompiler::RunBackend( module->config(), libdevice_dir)); } - if (!ir_dump_directory.empty()) { - TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( - /*directory_name=*/ir_dump_directory, - /*hlo_module_name=*/module->name(), llvm_module, - /*optimized=*/true)); - } + llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/true); if (user_post_optimization_hook_) { - TF_CHECK_OK(user_post_optimization_hook_(llvm_module)); + user_post_optimization_hook_(llvm_module); } - VLOG(3) << "LLVM module after optimizations:"; - XLA_VLOG_LINES(3, llvm_ir::DumpModuleToString(llvm_module)); - VLOG(3) << "PTX:"; - XLA_VLOG_LINES(3, ptx); - // Write PTX to IR dump directory, if IR dumping was requested. - if (!ir_dump_directory.empty()) { - const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, absl::StrCat(module->name(), ".ptx")); - auto status = [&] { - auto* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); - TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_outfile, ptx)); - return Status::OK(); - }(); - if (!status.ok()) { - LOG(WARNING) << "Couldn't dump PTX for module " << module->name() - << " to " << ptx_outfile << ": " << status; - } + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "ptx", ptx); } - const std::vector cubin = - CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor, module->config()); + const std::vector cubin = CompilePtxOrGetCachedResult( + stream_exec, ptx, cc_major, cc_minor, module->config()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); - VLOG(3) << "Printing the thunk schedule..."; - XLA_VLOG_LINES(3, thunk_schedule->ToString()); + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "thunk_schedule", + thunk_schedule->ToString()); + } std::unique_ptr profile_index_map; std::unique_ptr profile_printer; @@ -840,10 +647,11 @@ StatusOr> NVPTXCompiler::RunBackend( } std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( - const string& ptx, int cc_major, int cc_minor, - const HloModuleConfig& hlo_module_config) { + se::StreamExecutor* stream_exec, const string& ptx, int cc_major, + int cc_minor, const HloModuleConfig& hlo_module_config) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); - tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); + tensorflow::profiler::TraceMe activity( + "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo); bool inserted; decltype(compilation_cache_.begin()) iter; // Pointers into compilation_cache_ where the ptx and (optional) cubin are @@ -869,8 +677,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = - CompilePtx(*cache_ptx, cc_major, cc_minor, hlo_module_config); + StatusOr> maybe_cubin = CompilePtx( + stream_exec, *cache_ptx, PtxCompilationOptions(hlo_module_config)); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() @@ -890,9 +698,10 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( } if (log_warning) { PrintCantFindCudaMessage( - "Can't find ptxas binary. Will back to the GPU driver " - "for PTX -> sass compilation. 
This is OK so long as you don't " - "see a warning below about an out-of-date driver version.", + "Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to the " + "GPU driver for PTX -> sass compilation. This is OK so long " + "as you don't see a warning below about an out-of-date driver " + "version.", hlo_module_config); } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index b2077f42fd0..25e4b9427c0 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" namespace xla { namespace gpu { @@ -52,11 +53,11 @@ class NVPTXCompiler : public LLVMCompiler { StatusOr> RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, @@ -98,8 +99,8 @@ class NVPTXCompiler : public LLVMCompiler { // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. std::vector CompilePtxOrGetCachedResult( - const string& ptx, int cc_major, int cc_minor, - const HloModuleConfig& hlo_module_config); + se::StreamExecutor* stream_exec, const string& ptx, int cc_major, + int cc_minor, const HloModuleConfig& hlo_module_config); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for diff --git a/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc b/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc new file mode 100644 index 00000000000..9427a44a90c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc @@ -0,0 +1,309 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h" + +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/kernel.h" +#include "tensorflow/stream_executor/kernel_spec.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" + +namespace xla { +namespace gpu { + +// The size of the redzone at the end of the user buffer is rounded up to a +// multiple of kRhsRedzoneAlign. This simplifies the implementation a bit. +constexpr int64 kRhsRedzoneAlign = 4; + +using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; + +StatusOr> RedzoneAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + int64 rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size; + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, + byte_size + 2 * redzone_size_ + rhs_slop, + /*retry_on_failure=*/false)); + allocated_bytes_excluding_redzones_ += byte_size; + + static_assert(sizeof(uint8) == 1, "Unexpected size"); + se::DeviceMemory allocated_buffer_memory(*allocated_buffer); + + se::DeviceMemory lhs_redzone = stream->parent()->GetSubBuffer( + &allocated_buffer_memory, 0, redzone_size_); + + se::DeviceMemory data_chunk = stream->parent()->GetSubBuffer( + &allocated_buffer_memory, redzone_size_, byte_size); + + // Split up the RHS redzone into two pieces: + // - 0 to kRhsRedzoneAlign bytes adjacent to the user buffer, followed by + // - redzone_size_ bytes. + // We do this because Stream::ThenMemset32 requires the buffer address and + // size to be aligned to 4 bytes. 
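The resulting memory layout of a single allocation can be summarized with a small self-contained sketch; the names below are illustrative, not the XLA API, but the arithmetic mirrors the code above: a left redzone, the user chunk, a slop region padding the user size up to a 4-byte multiple, and a right redzone.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative only: offsets used by one redzone-wrapped allocation.
struct RedzoneLayout {
  int64_t lhs_redzone_offset;  // always 0
  int64_t user_offset;         // where the user-visible chunk starts
  int64_t rhs_slop;            // pads the user size up to a multiple of 4
  int64_t total_size;          // what is actually allocated
};

RedzoneLayout ComputeLayout(int64_t user_size, int64_t redzone_size) {
  constexpr int64_t kRhsRedzoneAlign = 4;  // ThenMemset32 works in 4-byte units
  const int64_t slop =
      (kRhsRedzoneAlign - user_size % kRhsRedzoneAlign) % kRhsRedzoneAlign;
  return {0, redzone_size, slop,
          redzone_size + user_size + slop + redzone_size};
}

int main() {
  // Sizes matching the WriteToRedzone test later in this change: 8 MiB
  // redzones around a deliberately misaligned 32 MiB + 1 byte user buffer.
  const RedzoneLayout layout =
      ComputeLayout(/*user_size=*/(int64_t{1} << 25) + 1,
                    /*redzone_size=*/int64_t{1} << 23);
  std::printf("user data starts at byte %lld, total allocation %lld bytes\n",
              static_cast<long long>(layout.user_offset),
              static_cast<long long>(layout.total_size));
  return 0;
}
```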
+ se::DeviceMemory rhs_redzone_slop = stream->parent()->GetSubBuffer( + &allocated_buffer_memory, redzone_size_ + byte_size, rhs_slop); + + se::DeviceMemory rhs_redzone_nonslop = stream->parent()->GetSubBuffer( + &allocated_buffer_memory, redzone_size_ + byte_size + rhs_slop, + redzone_size_); + + uint8 pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_, + redzone_pattern_}; + uint32 pattern32; + std::memcpy(&pattern32, pattern_arr, sizeof(pattern32)); + stream->ThenMemset32(&lhs_redzone, pattern32, redzone_size_); + if (rhs_slop != 0) { + stream->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop); + } + stream->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_); + + allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size); + return data_chunk; +} + +// PTX blob for the function which checks that every byte in +// input_buffer (length is buffer_length) is equal to redzone_pattern. +// +// On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr. +// +// Generated from: +// __global__ void redzone_checker(unsigned char* input_buffer, +// unsigned char redzone_pattern, +// unsigned long long buffer_length, +// int* out_mismatched_ptr) { +// unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) return; +// if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); +// } +// +// Code must compile for the oldest GPU XLA may be compiled for. +static const char* redzone_checker_ptx = R"( +.version 4.2 +.target sm_30 +.address_size 64 + +.visible .entry redzone_checker( + .param .u64 input_buffer, + .param .u8 redzone_pattern, + .param .u64 buffer_length, + .param .u64 out_mismatch_cnt_ptr +) +{ + .reg .pred %p<3>; + .reg .b16 %rs<3>; + .reg .b32 %r<6>; + .reg .b64 %rd<8>; + + ld.param.u64 %rd6, [buffer_length]; + mov.u32 %r1, %tid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %ntid.x; + mad.lo.s32 %r4, %r3, %r2, %r1; + cvt.u64.u32 %rd3, %r4; + setp.ge.u64 %p1, %rd3, %rd6; + @%p1 bra LBB6_3; + ld.param.u8 %rs1, [redzone_pattern]; + ld.param.u64 %rd4, [input_buffer]; + cvta.to.global.u64 %rd2, %rd4; + add.s64 %rd7, %rd2, %rd3; + ld.global.u8 %rs2, [%rd7]; + setp.eq.s16 %p2, %rs2, %rs1; + @%p2 bra LBB6_3; + ld.param.u64 %rd5, [out_mismatch_cnt_ptr]; + cvta.to.global.u64 %rd1, %rd5; + atom.global.add.u32 %r5, [%rd1], 1; +LBB6_3: + ret; +} +)"; + +// The PTX in redzone_checker_ptx has to be launched with specified types +// in the specified order. +using ComparisonKernelT = se::TypedKernel, uint8, + uint64, se::DeviceMemory>; + +// Check that redzones weren't overwritten on a host. +// +// Slower, but gives a more useful error message. 
+static StatusOr CheckRedzoneHost( + se::DeviceMemoryBase redzone, se::DeviceMemoryBase user_allocation, + absl::string_view name, se::Stream* stream, uint8 redzone_pattern, + int64 redzone_size) { + uint64 size = redzone.size(); + auto redzone_data = absl::make_unique(size); + TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size) + .BlockHostUntilDone()); + XLA_SCOPED_LOGGING_TIMER("RedzoneAllocator::CheckBufferRedzones CPU loop."); + + std::array pattern_arr; + pattern_arr.fill(redzone_pattern); + uint64 pattern64; + std::memcpy(&pattern64, pattern_arr.data(), sizeof(uint64)); + + int64 i; + for (i = 0; i + 7 < size; i += sizeof(uint64)) { + uint64 rz_value = *reinterpret_cast(&redzone_data[i]); + if (rz_value != pattern64) { + return RedzoneCheckStatus::WithFailureMsg(absl::StrFormat( + "Redzone mismatch in %s redzone of buffer %p at offset %d; " + "expected %08x but was %08x.", + name, user_allocation.opaque(), i, pattern64, rz_value)); + } + } + for (; i < size; ++i) { + uint8 rz_value = redzone_data[i]; + if (rz_value != redzone_pattern) { + return RedzoneCheckStatus::WithFailureMsg(absl::StrFormat( + "Redzone mismatch in %s redzone of buffer %p at offset %d; " + "expected %08x but was %08x.", + name, user_allocation.opaque(), i, redzone_pattern, rz_value)); + } + } + return RedzoneCheckStatus::OK(); +} + +// Run the redzone checker on the provided buffer redzone. +// +// Increment out_param if mismatch occurs. +static void RunRedzoneChecker(se::Stream* stream, + const se::DeviceMemory& redzone, + uint8 redzone_pattern, + const se::DeviceMemory& out_param, + const ComparisonKernelT& comparison_kernel) { + se::StreamExecutor* executor = stream->parent(); + Shape redzone_shape = ShapeUtil::MakeShape( + PrimitiveType::U8, {static_cast(redzone.size())}); + LaunchDimensions dim = CalculateLaunchDimensions( + redzone_shape, executor->GetDeviceDescription()); + + stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()), + se::BlockDim(dim.block_count()), comparison_kernel, + redzone, redzone_pattern, redzone.size(), out_param); +} + +// Check redzones around the user allocation. +// +// Precondition: the memory pointed out by out_param is zeroed. 
+static StatusOr CheckRedzonesForBuffer( + se::Stream* stream, se::DeviceMemoryBase memory, + const se::DeviceMemory& out_param, + const ComparisonKernelT& comparison_kernel, int64 user_allocation_size, + uint64 redzone_size, uint8 redzone_pattern) { + se::StreamExecutor* executor = stream->parent(); + int64 rhs_slop = + RoundUpToNearest(user_allocation_size, kRhsRedzoneAlign) - + user_allocation_size; + CHECK_EQ(memory.size(), user_allocation_size + rhs_slop + 2 * redzone_size); + + se::DeviceMemory buffer_uint8(memory); + se::DeviceMemory lhs_redzone = + executor->GetSubBuffer(&buffer_uint8, 0, redzone_size); + se::DeviceMemory user_allocation = + executor->GetSubBuffer(&buffer_uint8, redzone_size, user_allocation_size); + se::DeviceMemory rhs_redzone = + executor->GetSubBuffer(&buffer_uint8, redzone_size + user_allocation_size, + redzone_size + rhs_slop); + + RunRedzoneChecker(stream, lhs_redzone, redzone_pattern, out_param, + comparison_kernel); + RunRedzoneChecker(stream, rhs_redzone, redzone_pattern, out_param, + comparison_kernel); + int64 result; + CHECK_EQ(out_param.size(), sizeof(result)); + stream->ThenMemcpy(&result, out_param, sizeof(result)); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + + if (result != 0) { + TF_ASSIGN_OR_RETURN( + RedzoneCheckStatus lhs_check, + CheckRedzoneHost(lhs_redzone, user_allocation, "LHS", stream, + redzone_pattern, redzone_size)); + TF_ASSIGN_OR_RETURN( + RedzoneCheckStatus rhs_check, + CheckRedzoneHost(rhs_redzone, user_allocation, "RHS", stream, + redzone_pattern, redzone_size)); + + CHECK(!lhs_check.ok() || !rhs_check.ok()) + << "Mismatched results with host and device comparison"; + return !lhs_check.ok() ? lhs_check : rhs_check; + } + + return RedzoneCheckStatus::OK(); +} + +StatusOr RedzoneAllocator::CheckRedzones( + se::Stream* stream) const { + XLA_SCOPED_LOGGING_TIMER("Redzone checking"); + + se::StreamExecutor* executor = stream->parent(); + + absl::Span compiled_ptx = {}; + StatusOr> compiled_ptx_or = CompilePtxOrGetCached( + executor, redzone_checker_ptx, PtxCompilationOptions(hlo_module_config_)); + if (compiled_ptx_or.ok()) { + compiled_ptx = compiled_ptx_or.ValueOrDie(); + } else { + LOG(WARNING) << compiled_ptx_or.status().ToString() + << "\nRelying on driver to perform ptx compilation"; + } + + se::ScopedDeviceMemory out_param = + executor->AllocateOwnedScalar(); + stream->ThenMemZero(out_param.ptr(), sizeof(uint64)); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr comparison_kernel, + (CreateTypedKernel, uint8, uint64, + se::DeviceMemory>( + "redzone_checker", redzone_checker_ptx, compiled_ptx, executor))); + + for (const auto& buf_and_size : allocated_buffers_) { + TF_ASSIGN_OR_RETURN( + RedzoneCheckStatus redzone_status, + CheckRedzonesForBuffer(stream, *buf_and_size.first, out_param.cref(), + *comparison_kernel, buf_and_size.second, + redzone_size_, redzone_pattern_)); + if (!redzone_status.ok()) { + return redzone_status; + } + } + + return RedzoneCheckStatus::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/redzone_allocator.h b/tensorflow/compiler/xla/service/gpu/redzone_allocator.h new file mode 100644 index 00000000000..e6eff32eec9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/redzone_allocator.h @@ -0,0 +1,113 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" + +namespace xla { +namespace gpu { + +// An allocator that allocates a bit of extra memory around the beginning/end of +// every allocation and can check that this memory is unmodified. +// +// This can be used to check for out-of-bounds writes, and, if the redzone is +// filled with a sufficiently "ugly" pattern, may also be able to check for +// out-of-bounds reads. The default fill pattern of -1 is an unusual NaN +// pattern when interpreted as a floating-point number, so hopefully works for +// out-of-bounds reads and writes in those cases. +// +// This class implements se::ScratchAllocator, so can be used to allocate temp +// memory for cudnn convolutions. +class RedzoneAllocator : public se::ScratchAllocator { + public: + RedzoneAllocator(int device_ordinal, + se::DeviceMemoryAllocator* memory_allocator, + const HloModuleConfig& hlo_module_config, + int64 redzone_size = 1 << 23, // 8MiB per side, 16MiB total + uint8 redzone_pattern = -1) + : device_ordinal_(device_ordinal), + redzone_size_( + RoundUpToNearest(redzone_size, kXlaAllocatedBufferAlignBytes)), + redzone_pattern_(redzone_pattern), + memory_allocator_(memory_allocator), + hlo_module_config_(hlo_module_config) {} + + // Redzones don't count towards the memory limit. + int64 GetMemoryLimitInBytes(se::Stream* stream) override { + return 1LL << 32; // 4GB. TODO(jlebar): Tune this? + } + int64 TotalAllocatedBytesExcludingRedzones() const { + return allocated_bytes_excluding_redzones_; + } + + StatusOr> AllocateBytes(se::Stream* stream, + int64 byte_size) override; + + // Non-empty redzone check status implies that there was a write into a + // redzone, with a string communicating the location of the write. + struct RedzoneCheckStatus { + std::string redzone_failure_msg; + + static RedzoneCheckStatus OK() { return {}; } + + static RedzoneCheckStatus WithFailureMsg(std::string msg) { return {msg}; } + + bool ok() { return redzone_failure_msg.empty(); } + }; + + // Determines whether redzones around all allocated buffers are unmodified. + // + // Returns: + // + // - RedzoneCheckStatus::OK() if everything went well. + // - RedzoneCheckStatus with a non-empty error message iff a write into a + // redzone has been detected. + // - A stream error, if loading or launching the kernel has failed. + StatusOr CheckRedzones(se::Stream* stream) const; + + private: + const int device_ordinal_; + + // Redzone size on *one side* of allocation. + // + // Must be a multiple of kXlaAllocatedBufferAlignBytes, otherwise the buffers + // returned to users will be misaligned. 
+ const int64 redzone_size_; + + const uint8 redzone_pattern_; + se::DeviceMemoryAllocator* memory_allocator_; + const HloModuleConfig& hlo_module_config_; + + // The second element of the pair is the size of the user allocation. This + // isn't necessarily just first.size() - 2 * redzone_size_ because when the + // user allocation size is not a multiple of 4 bytes, we round up the size of + // the RHS redzone. + std::vector> allocated_buffers_; + + int64 allocated_bytes_excluding_redzones_ = 0; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/redzone_allocator_test.cc b/tensorflow/compiler/xla/service/gpu/redzone_allocator_test.cc new file mode 100644 index 00000000000..6344836d237 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/redzone_allocator_test.cc @@ -0,0 +1,144 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h" + +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/platform.h" + +namespace xla { +namespace gpu { +namespace { + +using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus; + +static void EXPECT_REDZONE_OK(StatusOr status) { + EXPECT_TRUE(status.ok()); + EXPECT_TRUE(status.ValueOrDie().ok()); +} + +static void EXPECT_REDZONE_VIOLATION(StatusOr status) { + EXPECT_TRUE(status.ok()); + EXPECT_FALSE(status.ValueOrDie().ok()); +} + +TEST(RedzoneAllocatorTest, WriteToRedzone) { + constexpr int64 kRedzoneSize = 1 << 23; // 8MiB redzone on each side + // Redzone pattern should not be equal to zero; otherwise modify_redzone will + // break. 
+ constexpr uint8 kRedzonePattern = 0x7e; + + // Allocate 32MiB + 1 byte (to make things misaligned) + constexpr int64 kAllocSize = (1 << 25) + 1; + + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); + se::StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie(); + HloModuleConfig config; + se::StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); + RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, config, + kRedzoneSize, kRedzonePattern); + + se::Stream stream(stream_exec); + stream.Init(); + TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemory buf, + allocator.AllocateBytes(&stream, + /*byte_size=*/kAllocSize)); + EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream)); + + char* buf_addr = reinterpret_cast(buf.opaque()); + se::DeviceMemoryBase lhs_redzone(buf_addr - kRedzoneSize, kRedzoneSize); + se::DeviceMemoryBase rhs_redzone(buf_addr + kAllocSize, kRedzoneSize); + + // Check that the redzones are in fact filled with kRedzonePattern. + auto check_redzone = [&](se::DeviceMemoryBase redzone, + absl::string_view name) { + std::vector host_buf(kRedzoneSize); + TF_ASSERT_OK(stream.ThenMemcpy(host_buf.data(), redzone, kRedzoneSize) + .BlockHostUntilDone()); + const int64 kMaxMismatches = 16; + int64 mismatches = 0; + for (int64 i = 0; i < host_buf.size(); ++i) { + if (mismatches == kMaxMismatches) { + ADD_FAILURE() << "Hit max number of mismatches; skipping others."; + break; + } + if (host_buf[i] != kRedzonePattern) { + ++mismatches; + EXPECT_EQ(host_buf[i], kRedzonePattern) + << "at index " << i << " of " << name << " redzone"; + } + } + }; + check_redzone(lhs_redzone, "lhs"); + check_redzone(rhs_redzone, "rhs"); + + // Modifies a redzone, checks that RedzonesAreUnmodified returns false, then + // reverts it back to its original value and checks that RedzonesAreUnmodified + // returns true. + auto modify_redzone = [&](se::DeviceMemoryBase redzone, int64 offset, + absl::string_view name) { + SCOPED_TRACE(absl::StrCat(name, ", offset=", offset)); + se::DeviceMemoryBase redzone_at_offset( + reinterpret_cast(redzone.opaque()) + offset, 1); + char old_redzone_value = 0; + { + XLA_SCOPED_LOGGING_TIMER("Checking redzones"); + EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream)); + } + stream.ThenMemcpy(&old_redzone_value, redzone_at_offset, 1) + .ThenMemZero(&redzone_at_offset, 1); + EXPECT_REDZONE_VIOLATION(allocator.CheckRedzones(&stream)); + stream.ThenMemcpy(&redzone_at_offset, &old_redzone_value, 1); + EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream)); + }; + + modify_redzone(lhs_redzone, /*offset=*/0, "lhs"); + modify_redzone(lhs_redzone, /*offset=*/kRedzoneSize - 1, "lhs"); + modify_redzone(rhs_redzone, /*offset=*/0, "rhs"); + modify_redzone(rhs_redzone, /*offset=*/kRedzoneSize - 1, "rhs"); +} + +// Older CUDA compute capabilities (<= 2.0) have a limitation that grid +// dimension X cannot be larger than 65535. +// +// Make sure we can launch kernels on sizes larger than that, given that the +// maximum number of threads per block is 1024. +TEST(RedzoneAllocatorTest, VeryLargeRedzone) { + // Make sure the redzone size would require grid dimension > 65535. 
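+  // (Illustrative arithmetic: with the maximum of 1024 threads per block and
+  // one redzone byte checked per thread, 65535 * 1024 + 1 bytes would need
+  // ceil((65535 * 1024 + 1) / 1024) = 65536 blocks, one more than the 65535
+  // grid-dimension limit, so the check kernel must handle the overflow.)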
+ constexpr int64 kRedzoneSize = 65535 * 1024 + 1; + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); + se::StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie(); + HloModuleConfig config; + se::StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); + RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, config, + kRedzoneSize, /*redzone_pattern=*/-1); + se::Stream stream(stream_exec); + stream.Init(); + (void)allocator.AllocateBytes(&stream, /*byte_size=*/1); + EXPECT_REDZONE_OK(allocator.CheckRedzones(&stream)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc b/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc index 197367e8168..5793051771f 100644 --- a/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc +++ b/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc @@ -29,12 +29,12 @@ StatusOr> ScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer, memory_allocator_->Allocate(device_ordinal_, byte_size, /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + se::DeviceMemoryBase buffer_addr = *allocated_buffer; allocated_buffers_.push_back(std::move(allocated_buffer)); return se::DeviceMemory(buffer_addr); } diff --git a/tensorflow/compiler/xla/service/gpu/scratch_allocator.h b/tensorflow/compiler/xla/service/gpu/scratch_allocator.h index 620c7e78912..9654237956a 100644 --- a/tensorflow/compiler/xla/service/gpu/scratch_allocator.h +++ b/tensorflow/compiler/xla/service/gpu/scratch_allocator.h @@ -18,18 +18,18 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace gpu { class ScratchAllocator : public se::ScratchAllocator { public: - ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) + ScratchAllocator(int device_ordinal, + se::DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} int64 GetMemoryLimitInBytes(se::Stream* stream) override { @@ -50,8 +50,8 @@ class ScratchAllocator : public se::ScratchAllocator { private: const int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + se::DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 08ff52211af..ca409fff67b 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -15,9 +15,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/stream_executor/kernel_spec.h" namespace xla { namespace gpu { @@ -162,5 +171,242 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, return std::make_tuple(input_layout, filter_layout, output_layout); } + +tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + // se::Platform*s are global singletons guaranteed to live forever. + static auto* mutexes = + new std::map, + tensorflow::mutex>(); + + tensorflow::mutex_lock global_lock(mu); + auto it = mutexes + ->emplace(std::piecewise_construct, + std::make_tuple(stream_exec->platform(), + stream_exec->device_ordinal()), + std::make_tuple()) + .first; + return tensorflow::mutex_lock{it->second}; +} + +StatusOr> CreateKernel( + absl::string_view kernel_name, uint64 num_args, absl::string_view ptx, + absl::Span cubin_data, se::StreamExecutor* stream_exec) { + se::MultiKernelLoaderSpec loader_spec(num_args); + loader_spec.AddCudaPtxInMemory(ptx, kernel_name); + + if (!cubin_data.empty()) { + loader_spec.AddCudaCubinInMemory( + reinterpret_cast(cubin_data.data()), kernel_name); + } + + auto kernel_base = absl::make_unique(stream_exec); + if (!stream_exec->GetKernel(loader_spec, kernel_base.get())) { + return InternalError("Unable to load kernel '%s'", kernel_name); + } + + return std::move(kernel_base); +} + +Status ExecuteKernelOnStream(const se::KernelBase& kernel, + absl::Span args, + int64 threads_per_block, int64 block_count, + se::Stream* stream) { + static constexpr int kKernelArgsLimit = 1024; + auto kernel_args = absl::make_unique>(); + for (const se::DeviceMemoryBase& buf : args) { + kernel_args->add_device_memory_argument(buf); + } + + if (!stream->parent()->Launch(stream, se::ThreadDim(threads_per_block), + se::BlockDim(block_count), kernel, + *kernel_args)) { + return InternalError("Unable to launch kernel"); + } + return Status::OK(); +} + +// Prints a warning if the ptxas at ptxas_path has known bugs. +// +// Only prints a warning the first time it's called for a particular value of +// ptxas_path. +// +// Locks on entry. +void WarnIfBadPtxasVersion(const string& ptxas_path) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static std::unordered_set* seen_ptxas_paths GUARDED_BY(mu) = + new std::unordered_set(); + + tensorflow::mutex_lock lock(mu); + if (!seen_ptxas_paths->insert(ptxas_path).second) { + // Already checked this ptx binary, nothing to do. 
+ return; + } + + tensorflow::SubProcess ptxas; + ptxas.SetProgram(ptxas_path, {ptxas_path, "--version"}); + ptxas.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); + if (!ptxas.Start()) { + LOG(WARNING) << "Couldn't invoke " << ptxas_path << " --version"; + return; + } + + string out; + int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out, + /*stderr_output=*/nullptr); + if (exit_code != 0) { + LOG(WARNING) << "Running " << ptxas_path << " --version returned " + << exit_code; + return; + } + + int64 vmaj, vmin, vdot; + string vmaj_str, vmin_str, vdot_str; + if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, + &vmin_str, &vdot_str) || + !absl::SimpleAtoi(vmaj_str, &vmaj) || + !absl::SimpleAtoi(vmin_str, &vmin) || + !absl::SimpleAtoi(vdot_str, &vdot)) { + LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path + << " --version:\n" + << out; + return; + } + + // We need ptxas >= 9.0 as a hard requirement, because we compile targeting + // PTX 6.0. An older ptxas will just fail to compile any of our code. + // + // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some + // address calculations with large offsets (e.g. "load ptr + large_constant"), + // b/70245379. + // + // ptxas 9.1.121 miscompiles some large multioutput fusions, again in a way + // that appears related to address calculations, b/111107644. ptxas 9.2.88 + // appears to work, as far as we can tell. + if (vmaj < 9) { + LOG(ERROR) + << "You are using ptxas 8.x, but XLA requires ptxas 9.x (and strongly " + "prefers >= 9.2.88). Compilation of XLA kernels below will likely " + "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " + "binary is sufficient."; + } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) { + LOG(WARNING) + << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." + << vdot + << ", which is older than 9.2.88. 
ptxas 9.x before 9.2.88 is known to " + "miscompile XLA code, leading to incorrect results or " + "invalid-address errors.\n\nYou do not need to update to CUDA " + "9.2.88; cherry-picking the ptxas binary is sufficient."; + } +} + +StatusOr> CompilePtxOrGetCached( + se::StreamExecutor* executor, absl::string_view ptx, + PtxCompilationOptions compilation_options) { + using PtxCacheKey = std::tuple; + static tensorflow::mutex ptx_cache_mutex(tensorflow::LINKER_INITIALIZED); + static auto& ptx_cache GUARDED_BY(ptx_cache_mutex) = + *new absl::flat_hash_map>(); + + tensorflow::mutex_lock lock(ptx_cache_mutex); + PtxCacheKey cache_key{executor, std::string(ptx), + compilation_options.ToTuple()}; + auto it = ptx_cache.find(cache_key); + if (it == ptx_cache.end()) { + TF_ASSIGN_OR_RETURN(std::vector compiled, + CompilePtx(executor, ptx, compilation_options)); + it = ptx_cache.emplace(cache_key, std::move(compiled)).first; + } + + CHECK(it != ptx_cache.end()); + const std::vector& compiled = it->second; + return absl::MakeSpan(compiled); +} + +StatusOr> CompilePtx( + se::StreamExecutor* stream_exec, absl::string_view ptx, + PtxCompilationOptions compile_ptx_options) { + int cc_major, cc_minor; + if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor)) { + LOG(WARNING) + << "Couldn't get compute capability for device; assuming sm_20."; + cc_major = 2; + cc_minor = 0; + } + + tensorflow::profiler::TraceMe activity( + "Compile PTX", tensorflow::profiler::TraceMeLevel::kInfo); + auto env = tensorflow::Env::Default(); + string ptxas_path; + for (const string& cuda_root : tensorflow::CandidateCudaRoots( + /*preferred_location=*/compile_ptx_options.xla_gpu_cuda_data_dir)) { + ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); + VLOG(2) << "Looking for ptxas at " << ptxas_path; + if (env->FileExists(ptxas_path).ok()) { + break; + } + } + TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); + VLOG(2) << "Using ptxas at " << ptxas_path; + + WarnIfBadPtxasVersion(ptxas_path); + + // Write ptx into a temporary file. + string ptx_path; + if (!env->LocalTempFilename(&ptx_path)) { + return InternalError("couldn't get temp PTX file name"); + } + auto ptx_cleaner = tensorflow::gtl::MakeCleanup([&ptx_path] { + TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(ptx_path)); + }); + + TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_path, ptx)); + VLOG(2) << "ptx written to: " << ptx_path; + + // Invoke ptxas and collect its output. + string cubin_path; + if (!env->LocalTempFilename(&cubin_path)) { + return InternalError("couldn't get temp CUBIN file name"); + } + auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] { + // CUBIN file may never be created, so the failure to delete it should not + // produce TF error. 
+ tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); + }); + tensorflow::SubProcess ptxas_info_dumper; + std::vector ptxas_args = { + ptxas_path, ptx_path, "-o", cubin_path, + absl::StrCat("-arch=sm_", cc_major, cc_minor)}; + if (VLOG_IS_ON(2)) { + ptxas_args.push_back("-v"); + } + if (compile_ptx_options.xla_gpu_disable_ptxas_optimizations) { + ptxas_args.push_back("-O0"); + } + ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); + ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, + tensorflow::ACTION_PIPE); + if (!ptxas_info_dumper.Start()) { + return InternalError("Failed to launch ptxas"); + } + string stderr_output; + int exit_status = ptxas_info_dumper.Communicate( + /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); + XLA_LOG_LINES(tensorflow::INFO, stderr_output); + if (exit_status != 0) { + return InternalError("ptxas exited with non-zero error code %d", + exit_status); + } + + // Read in the result of compilation and return it as a byte vector. + string cubin; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), + cubin_path, &cubin)); + std::vector cubin_vector(cubin.begin(), cubin.end()); + return cubin_vector; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 92e4d6dbbc1..06ac7dca634 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,11 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/kernel_spec.h" // Helper functions for interacting with StreamExecutor. @@ -45,6 +49,102 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, const Layout& input, const Layout& filter, const Layout& output); +// Generates and returns a unique lock per each provided executor. +// Guarantees that blocks of code both holding a lock for the same provided +// executor (as given by this function) will not be running concurrently. +// +// This is used to prevent other XLA instances from trying to autotune on a +// device while another thread is using it. +tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec); + +// Creates a kernel which can be launched with stream.ThenLaunch, such that +// the types of the arguments provided for launch would have to match +// types of the arguments provided at creation time. +// +// The kernel has a name kernel_name, and is based from provided PTX in ptx, +// and (optional) compiled PTX in cubin_data. +// The canonical storage for both ptx and cubin_data should outlive the +// lifetime of the kernel. +// +// This is a preferred API since it provides type safety for kernel launches. 
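+//
+// Illustrative usage sketch (not part of this change; assumes a PTX string
+// `ptx` for a hypothetical kernel "add_one" that takes a single float buffer,
+// plus an existing `stream_exec`, `stream`, and device buffer `data`):
+//
+//   TF_ASSIGN_OR_RETURN(auto kernel,
+//                       CreateTypedKernel<se::DeviceMemory<float>>(
+//                           "add_one", ptx, /*cubin_data=*/{}, stream_exec));
+//   stream.ThenLaunch(se::ThreadDim(128), se::BlockDim(1), *kernel, data);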
+template +StatusOr>> CreateTypedKernel( + absl::string_view kernel_name, absl::string_view ptx, + absl::Span cubin_data, se::StreamExecutor* stream_exec) { + auto kernel_base = absl::make_unique>(stream_exec); + se::MultiKernelLoaderSpec loader_spec(kernel_base->kNumberOfParameters); + loader_spec.AddCudaPtxInMemory(ptx, kernel_name); + + if (!cubin_data.empty()) { + loader_spec.AddCudaCubinInMemory( + reinterpret_cast(cubin_data.data()), kernel_name); + } + + if (!stream_exec->GetKernel(loader_spec, kernel_base.get())) { + return InternalError("Unable to load kernel '%s'", kernel_name); + } + + return std::move(kernel_base); +} + +// Creates a kernel with a provided name, based from provided PTX in ptx. +// The kernel should be executed using the provided executor. +// The argument cubin_data represents compiled PTX and may be left empty. +// +// The canonical storage for both ptx and cubin_data should outlive +// the lifetime of the kernel. +StatusOr> CreateKernel( + absl::string_view kernel_name, uint64 num_args, absl::string_view ptx, + absl::Span cubin_data, se::StreamExecutor* stream_exec); + +// Runs loaded kernel on the stream with the provided arguments. +Status ExecuteKernelOnStream(const se::KernelBase& kernel, + absl::Span args, + int64 threads_per_block, int64 block_count, + se::Stream* stream); + +// Options for compiling with PTX. +struct PtxCompilationOptions { + bool xla_gpu_disable_ptxas_optimizations; + std::string xla_gpu_cuda_data_dir; + + using PtxOptionsTuple = std::tuple; + + explicit PtxCompilationOptions(const HloModuleConfig& hlo_module_config) + : xla_gpu_disable_ptxas_optimizations( + hlo_module_config.debug_options() + .xla_gpu_disable_ptxas_optimizations()), + xla_gpu_cuda_data_dir( + hlo_module_config.debug_options().xla_gpu_cuda_data_dir()) {} + + // For comparison and hashing. + PtxOptionsTuple ToTuple() { + return std::make_tuple(xla_gpu_disable_ptxas_optimizations, + xla_gpu_cuda_data_dir); + } +}; + +// Compiles the given PTX string using ptxas and returns the resulting machine +// code (i.e. a cubin) as a byte array. +// +// Queries stream executor stream_exec to get CUDA compute capability from the +// device. +// +// compile_ptx_options is used to query for the CUDA location in case it is +// customized in a passed flag, and for controlling ptxas optimizations. +// It can be constructed from HloModuleConfig. +StatusOr> CompilePtx( + se::StreamExecutor* stream_exec, absl::string_view ptx, + PtxCompilationOptions compile_ptx_options); + +// Same as CompilePtx, but caches the result, and returns unowned view of +// the compiled binary. +// +// A copy of the string provided in ptx will be made. +StatusOr> CompilePtxOrGetCached( + se::StreamExecutor* executor, absl::string_view ptx, + PtxCompilationOptions compilation_options); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/target_util.cc b/tensorflow/compiler/xla/service/gpu/target_util.cc new file mode 100644 index 00000000000..8225cd79a66 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/target_util.cc @@ -0,0 +1,102 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Provide helper routine for obtaining gpu target information useful +// for llvm IR contruction. + +#include "tensorflow/compiler/xla/service/gpu/target_util.h" + +#include "llvm/IR/MDBuilder.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { +namespace { +// Utility functions to obtain NVPTX/AMDGPU specific information. + +// Wrapper structure for carrying llvm intrinsic ids for NVPTX/AMDGPU platforms. +struct TargetIntrinsics { + llvm::Intrinsic::ID nvptx_intrinsic; + llvm::Intrinsic::ID amdgpu_intrinsic; +}; + +// Gets the llvm intrinsic ids on different platforms (NVPTX, AMDGPU) +// corresponding to the give TargetIntrinsicID. +struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) { + switch (intrin) { + case TargetIntrinsicID::kShflDownF32: { + return {llvm::Intrinsic::nvvm_shfl_sync_down_f32, + llvm::Intrinsic::not_intrinsic}; + } + case TargetIntrinsicID::kShflDownI32: { + return {llvm::Intrinsic::nvvm_shfl_sync_down_i32, + llvm::Intrinsic::not_intrinsic}; + } + case TargetIntrinsicID::kThreadIdx: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, + llvm::Intrinsic::amdgcn_workitem_id_x}; + } + case TargetIntrinsicID::kThreadIdy: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y, + llvm::Intrinsic::amdgcn_workitem_id_y}; + } + case TargetIntrinsicID::kThreadIdz: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z, + llvm::Intrinsic::amdgcn_workitem_id_z}; + } + case TargetIntrinsicID::kBlockIdx: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, + llvm::Intrinsic::amdgcn_workgroup_id_x}; + } + case TargetIntrinsicID::kBlockIdy: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y, + llvm::Intrinsic::amdgcn_workgroup_id_y}; + } + case TargetIntrinsicID::kBlockIdz: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z, + llvm::Intrinsic::amdgcn_workgroup_id_z}; + } + case TargetIntrinsicID::kBarrierId: { + return {llvm::Intrinsic::nvvm_barrier0, + llvm::Intrinsic::amdgcn_s_barrier}; + } + } +} +} // namespace + +llvm::CallInst* EmitCallToTargetIntrinsic( + TargetIntrinsicID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getModule(); + struct TargetIntrinsics gpu_intrinsic_id = GetIntrinsic(intrinsic_id); + llvm::Triple target_triple = llvm::Triple(module->getTargetTriple()); + llvm::Intrinsic::ID llvm_intrinsic_id = llvm::Intrinsic::not_intrinsic; + + if ((target_triple.getArch() == llvm::Triple::nvptx) || + (target_triple.getArch() == llvm::Triple::nvptx64)) { + llvm_intrinsic_id = gpu_intrinsic_id.nvptx_intrinsic; + } else if (target_triple.getArch() == llvm::Triple::amdgcn) { + llvm_intrinsic_id = gpu_intrinsic_id.amdgpu_intrinsic; + } else { + LOG(FATAL) << "Invalid triple " << target_triple.str(); + } + + llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + module, llvm_intrinsic_id, llvm_ir::AsArrayRef(overloaded_types)); + return b->CreateCall(intrinsic, 
                          llvm_ir::AsArrayRef(operands));
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/target_util.h b/tensorflow/compiler/xla/service/gpu/target_util.h
new file mode 100644
index 00000000000..b8f796c7259
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/target_util.h
@@ -0,0 +1,57 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TARGET_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TARGET_UTIL_H_
+
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Triple.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+
+namespace xla {
+namespace gpu {
+
+// Enumeration to get target specific intrinsics.
+enum class TargetIntrinsicID {
+  kShflDownF32 = 0,
+  kShflDownI32,
+  kThreadIdx,
+  kThreadIdy,
+  kThreadIdz,
+  kBlockIdx,
+  kBlockIdy,
+  kBlockIdz,
+  kBarrierId,
+};
+
+// Emits a call to the specified target intrinsic with the given operands.
+
+// Overloaded intrinsics (for example, "minnum") must include a type
+// in overloaded_types for each overloaded type. Typically, overloaded
+// intrinsics have only a single overloaded type.
+llvm::CallInst* EmitCallToTargetIntrinsic(
+    TargetIntrinsicID intrinsic_id, absl::Span<llvm::Value* const> operands,
+    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b);
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TARGET_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index d798b316437..b6ce15bb384 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -200,8 +200,8 @@ tf_cc_test(
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
+        "//tensorflow/compiler/xla/service:custom_call_target_registry",
         "//tensorflow/compiler/xla/service:gpu_plugin",
-        "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
         "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
         "//tensorflow/compiler/xla/tests:filecheck",
         "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc
index 672c68e59b5..914b81c632f 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include #include -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/tests/filecheck.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index 6814be779e0..963716e7050 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -48,8 +48,9 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { HloInstruction::CreateParameter(0, param_shape, "x")); HloInstruction* param_y = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "y")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {5, 7, 2}), param_x, param_y, + ComparisonDirection::kGe)); auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); @@ -73,7 +74,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { x = f32[5,7,2]{2,1,0} parameter(0) y = f32[5,14]{1,0} parameter(1) reshape = f32[5,7,2]{2,1,0} reshape(y) - ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape) + ROOT gte = pred[5,7,2]{2,1,0} compare(x, reshape), direction=GE })", config) .ValueOrDie(); @@ -98,7 +99,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { y = f32[14]{0} parameter(1) reshape = f32[7,2]{1,0} reshape(y) broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2} - ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast) + ROOT gte = pred[5,7,2]{2,1,0} compare(x, broadcast), direction=GE })", config) .ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 869724db601..dd3b92ec3af 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -411,6 +411,132 @@ TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) { // Check that the kernel runs correctly. EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); } + +TEST_F(GpuKernelTilingTest, ColumnReductionWithLayoutChangeTiled) { + const char *const kHloString = R"( + HloModule reduce_with_layout_change + reduction0 { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) + } + + ENTRY kernel_entry { + arg0 = f32[4,32,32,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[4,32,16,12,12]{4,3,2,1,0} reduce(arg0, constant0), + dimensions={1,6,7}, to_apply=reduction0 + })"; + + // Check that the kernel is tiled by looking for llvm.nvvm.atomic. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @reduce +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +} + +TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { + const char *const kHloString = R"( + HloModule reduce_with_layout_change + reduction0 { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) + } + + ENTRY kernel_entry { + arg0 = f32[8,6,64]{2,1,0} parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[8,6]{0,1} reduce(arg0, constant0), dimensions={2}, + to_apply=reduction0 + })"; + + // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @reduce +; CHECK: call float @llvm.nvvm.shfl.sync.down.f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +} + +TEST_F(GpuKernelTilingTest, + ColumnReductionResultTwoPartsWithLayoutChangeTiled) { + const char *const kHloString = R"( + HloModule reduce_with_no_layout_change + reduction0 { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) + } + + ENTRY kernel_entry { + arg0 = f32[8,64,32]{2,1,0} parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[8,32]{0,1} reduce(arg0, constant0), dimensions={1}, + to_apply=reduction0 + })"; + + // Check that the kernel is tiled by looking for llvm.nvvm.atomic. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @reduce +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +} + +TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) { + const char *const kHloString = R"( + HloModule reduction + reduction0 { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) + } + + ENTRY kernel_entry { + arg0 = f32[8,6,16]{2,1,0} parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[8,6]{1,0} reduce(arg0, constant0), dimensions={2}, + to_apply=reduction0 + })"; + + // Check that the kernel is not tiled by looking for llvm.nvvm.shfl.sync.down. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @reduce +; CHECK-NOT: call float @llvm.nvvm.shfl.sync.down.f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 3019215c015..8b844e66b90 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -111,8 +111,8 @@ TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { hlo_module->AddEmbeddedComputation(embedded_builder.Build()); } - auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); - auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + auto param_shape = ShapeUtil::MakeShape(F32, {32, 32}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {32}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce( diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc index 6b98cbb6570..5a9b7bdf902 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk.cc @@ -18,48 +18,54 @@ limitations under the License. namespace xla { namespace gpu { -std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { +absl::string_view ThunkKindToString(Thunk::Kind kind) { switch (kind) { case Thunk::kCholesky: - return os << "kCholesky"; + return "kCholesky"; case Thunk::kConditional: - return os << "kConditional"; + return "kConditional"; case Thunk::kConvolution: - return os << "kConvolution"; + return "kConvolution"; case Thunk::kCopy: - return os << "kCopy"; + return "kCopy"; case Thunk::kCudnnBatchNormBackward: - return os << "kCudnnBatchNormBackward"; + return "kCudnnBatchNormBackward"; case Thunk::kCudnnBatchNormForwardInference: - return os << "kCudnnBatchNormForwardInference"; + return "kCudnnBatchNormForwardInference"; case Thunk::kCudnnBatchNormForwardTraining: - return os << "kCudnnBatchNormForwardTraining"; + return "kCudnnBatchNormForwardTraining"; + case Thunk::kCustomCall: + return "kCustomCall"; case Thunk::kNcclAllReduce: - return os << "kNcclAllReduce"; + return "kNcclAllReduce"; case Thunk::kFft: - return os << "kFft"; + return "kFft"; case Thunk::kGemm: - return os << "kGemm"; + return "kGemm"; case Thunk::kInfeed: - return os << "kInfeed"; + return "kInfeed"; case Thunk::kKernel: - return os << "kKernel"; + return "kKernel"; case Thunk::kMemset32BitValue: - return os << "kMemset32BitValue"; + return "kMemset32BitValue"; case Thunk::kMemzero: - return os << "kMemzero"; + return "kMemzero"; case Thunk::kOutfeed: - return os << "kOutfeed"; + return "kOutfeed"; case Thunk::kSequential: - return os << "kSequential"; + return "kSequential"; case Thunk::kTriangularSolve: - return os << "kTriangularSolve"; + return "kTriangularSolve"; case Thunk::kTuple: - return os << "kTuple"; + return "kTuple"; case Thunk::kWhile: - return os << "kWhile"; + return "kWhile"; } } +std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { + return os << ThunkKindToString(kind); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 442506f002c..bdd06718717 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -49,13 +49,14 @@ class Thunk { kCudnnBatchNormBackward, kCudnnBatchNormForwardInference, 
kCudnnBatchNormForwardTraining, - kNcclAllReduce, + kCustomCall, kFft, kGemm, kInfeed, kKernel, kMemset32BitValue, kMemzero, + kNcclAllReduce, kOutfeed, kSequential, kTriangularSolve, @@ -85,10 +86,6 @@ class Thunk { return Status::OK(); } - // Returns true if this kernel will autotune for the stream device the next - // time it is run. - virtual bool WillAutotuneKernel(se::Stream* /*stream*/) { return false; } - // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. 'stream' and 'profiler' must be non-null. @@ -98,6 +95,11 @@ class Thunk { se::Stream* stream, HloExecutionProfiler* profiler) = 0; + protected: + const HloModuleConfig& GetModuleConfig() const { + return hlo_instruction()->GetModule()->config(); + } + private: Kind kind_; const HloInstruction* hlo_instruction_; @@ -106,6 +108,7 @@ class Thunk { // A sequence of thunks. using ThunkSequence = std::vector>; +absl::string_view ThunkKindToString(Thunk::Kind); std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 25bad67bab9..daa5f33e560 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -14,7 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" @@ -144,11 +147,32 @@ const std::list& ThunkSchedule::DependsOn( } string ThunkSchedule::ToString() const { + if (thunk_total_order_.empty()) { + return "No thunks."; + } + + const Thunk* thunk_with_longest_kind = *absl::c_max_element( + thunk_total_order_, [](const Thunk* a, const Thunk* b) { + return ThunkKindToString(a->kind()).length() < + ThunkKindToString(b->kind()).length(); + }); + int64 max_thunk_kind_len = + ThunkKindToString(thunk_with_longest_kind->kind()).length(); + string result = "Total order:\n"; for (Thunk* thunk : thunk_total_order_) { - absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n"); + // Write out the thunk kind, padded out to max_thunk_kind_len. 
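+    // For example (illustrative), a schedule containing kCopy and kKernel
+    // thunks would render as:
+    //
+    //   kCopy  \t%copy.3 = ...
+    //   kKernel\t%fusion.1 = ...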
+ absl::string_view kind_str = ThunkKindToString(thunk->kind()); + absl::StrAppend(&result, kind_str, + string(max_thunk_kind_len - kind_str.length(), ' '), "\t"); + if (thunk->hlo_instruction() != nullptr) { + absl::StrAppend(&result, thunk->hlo_instruction()->ToString()); + } else { + absl::StrAppend(&result, "(no HloInstruction)"); + } + absl::StrAppend(&result, "\n"); } - absl::StrAppend(&result, "Dependencies:\n"); + absl::StrAppend(&result, "\nDependencies:\n"); for (const auto& entry : depends_on_) { const Thunk* dependent = entry.first; for (const Thunk* dependency : entry.second) { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 77e49f0e46b..64a5fe5fdd2 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -44,9 +44,9 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, tuple_index)); - builder.AddInstruction( - HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, - induction_variable, limit_const)); + builder.AddInstruction(HloInstruction::CreateCompare( + condition_result_shape_, induction_variable, limit_const, + ComparisonDirection::kLt)); return builder.Build(); } diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc deleted file mode 100644 index ef70b688778..00000000000 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Example HLO graph which demonstrates Graphviz dumper for HLO -// computations. When run, pushes the example DOT graph to the Graphviz service -// and prints the URL. Useful for seeing effect of changes to the graph -// generation code. - -#include -#include -#include - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -// Adds a computation to the given HLO module which adds a scalar constant to -// its parameter and returns the result. 
-HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { - auto builder = HloComputation::Builder(absl::StrCat("add_", addend)); - auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "x_value")); - auto half = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); - builder.AddInstruction(HloInstruction::CreateBinary( - half->shape(), HloOpcode::kAdd, x_value, half)); - return module->AddEmbeddedComputation(builder.Build()); -} - -// Adds a computation to the given HLO module which sums its two parameters and -// returns the result. -HloComputation* ScalarSumComputation(HloModule* module) { - auto builder = HloComputation::Builder("add"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "lhs")); - auto rhs = builder.AddInstruction( - HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "rhs")); - builder.AddInstruction( - HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); - return module->AddEmbeddedComputation(builder.Build()); -} - -// Adds a computation to the given HLO module which forwards its argument to a -// kCall instruction which then calls the given computation. -HloComputation* CallForwardingComputation(HloComputation* computation, - HloModule* module) { - auto builder = HloComputation::Builder("call_forward"); - auto arg = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "arg")); - builder.AddInstruction( - HloInstruction::CreateCall(arg->shape(), {arg}, computation)); - return module->AddEmbeddedComputation(builder.Build()); -} - -// Create a large, arbitrary computation with many different kinds of -// instructions. Sets the computation as the entry to an HLO module and returns -// the module. -std::unique_ptr MakeBigGraph() { - HloModuleConfig config; - auto module = absl::make_unique("BigGraph", config); - - auto builder = HloComputation::Builder("TestBigGraphvizGraph"); - - // Shapes used in the computation. - auto mshape = ShapeUtil::MakeShape(F32, {3, 5}); - auto vshape = ShapeUtil::MakeShape(F32, {3}); - auto sshape = ShapeUtil::MakeShape(F32, {3}); - - // Create a set of parameter instructions. - auto param_v0 = - builder.AddInstruction(HloInstruction::CreateParameter(0, vshape, "foo")); - auto param_v1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, vshape, "bar")); - auto param_v2 = - builder.AddInstruction(HloInstruction::CreateParameter(2, vshape, "baz")); - auto param_s = - builder.AddInstruction(HloInstruction::CreateParameter(3, sshape, "qux")); - auto param_m = - builder.AddInstruction(HloInstruction::CreateParameter(4, mshape, "zzz")); - - // Add an arbitrary expression of different instructions. 
- auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); - auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( - vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfig precision_config; - precision_config.mutable_operand_precision()->Resize( - /*new_size=*/2, PrecisionConfig::DEFAULT); - auto dot = builder.AddInstruction(HloInstruction::CreateDot( - vshape, clamp, param_v0, dot_dnums, precision_config)); - auto tuple = builder.AddInstruction( - HloInstruction::CreateTuple({dot, param_s, clamp})); - auto scalar = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(sshape, tuple, 2)); - auto add_one = AddScalarConstantComputation(1.0, module.get()); - auto rng = builder.AddInstruction( - HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto add_computation = ScalarSumComputation(module.get()); - builder.AddInstruction( - HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); - auto map1 = builder.AddInstruction( - HloInstruction::CreateMap(sshape, {scalar}, add_one)); - auto map2 = builder.AddInstruction( - HloInstruction::CreateMap(sshape, {map1}, add_one)); - auto map3 = builder.AddInstruction( - HloInstruction::CreateMap(sshape, {map2}, add_one)); - - // Create a fusion instruction containing the chain of map instructions. - auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( - sshape, HloInstruction::FusionKind::kLoop, map3)); - fusion->FuseInstruction(map2); - fusion->FuseInstruction(map1); - - // Add a random trace instruction. - builder.AddInstruction(HloInstruction::CreateTrace("trace", dot)); - - // Add a call instruction will calls the call-forwarding computation to call - // another computation. - auto call_computation = CallForwardingComputation(add_one, module.get()); - builder.AddInstruction( - HloInstruction::CreateCall(fusion->shape(), {fusion}, call_computation)); - - module->AddEntryComputation(builder.Build()); - return module; -} - -} // namespace -} // namespace xla - -int main(int argc, char** argv) { - tensorflow::port::InitMain(argv[0], &argc, &argv); - - auto module = xla::MakeBigGraph(); - - printf("Graph URL: %s\n", xla::hlo_graph_dumper::DumpGraph( - *module->entry_computation(), - "Example computation", xla::DebugOptions()) - .c_str()); - return 0; -} diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 4fca981c6a5..2af8e1d6ea6 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -251,7 +251,26 @@ Status HeapSimulator::RunComputation( // We can only share with the operand buffer if it is about to be freed; // we must be the last user of the buffer. bool shared = false; - if (options_.may_reuse_operand_buffers) { + auto shared_it = shared_buffers_.find(buffer); + if (shared_it != shared_buffers_.end()) { + std::shared_ptr group = shared_it->second; + if (group->refcount != 0) { + // This buffer has a shared group with already some instructions + // scheduled (refcount > 0), find and share buffer with the + // canonical instruction. 
+ shared = true; + VLOG(3) << " Sharing: " << buffer->ToString() + << " with must aliased buffer " + << group->canonical->ToString(); + FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, + instruction, group->canonical); + } else { + VLOG(3) << " New shared group, canonical buffer: " + << buffer->ToString(); + group->canonical = buffer; + } + group->refcount++; + } else if (options_.may_reuse_operand_buffers) { for (const BufferValue* operand_buffer : operand_buffers_to_free) { if (reused_buffers.contains(operand_buffer)) { continue; @@ -261,12 +280,17 @@ Status HeapSimulator::RunComputation( points_to_analysis.CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), buffer->instruction(), buffer->index())) { - VLOG(3) << " Sharing: " << buffer->ToString() << " with " - << operand_buffer->ToString(); - ShareBuffer(buffer, operand_buffer, instruction); - shared = true; - reused_buffers.insert(operand_buffer); - break; + // Make sure the two buffers belong to the same shared groups. + // Otherwise we'd need to merge those shared groups which is not + // suported. + if (InSameSharedGroup(buffer, operand_buffer)) { + VLOG(3) << " Sharing: " << buffer->ToString() << " with " + << operand_buffer->ToString(); + ShareBuffer(buffer, operand_buffer, instruction); + shared = true; + reused_buffers.insert(operand_buffer); + break; + } } } } @@ -358,6 +382,17 @@ HeapSimulator::HeapSimulator( options_(options), schedule_(schedule), memory_by_computation_(memory_by_computation) { + for (const BufferValueFlatSet& value_set : options.must_alias_sets) { + auto group = std::make_shared(); + group->refcount = 0; + VLOG(2) << "Shared buffers:"; + for (const BufferValue* buffer_value : value_set) { + VLOG(2) << " " << buffer_value->ToString(); + shared_buffers_.emplace(buffer_value, group); + // Refcounts are not incremented here as buffers are shared but not + // referenced yet. + } + } debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } @@ -402,9 +437,13 @@ void HeapSimulator::Free(const BufferValue* buffer, if (shared_it != shared_buffers_.end()) { std::shared_ptr group = shared_it->second; --group->refcount; + VLOG(3) << " Decrementing refcount : " << group->canonical->ToString(); if (group->refcount > 0) { + // Another buffer still holds the reference to this shared group, don't + // free the underlying canonical buffer. return; } + VLOG(3) << " Ref == 0 " << group->canonical->ToString(); CHECK_EQ(group->refcount, 0) << "Free caused negative refcount on shared buffer: " << *buffer; buffer = group->canonical; @@ -423,6 +462,21 @@ void HeapSimulator::Free(const BufferValue* buffer, FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr); } +bool HeapSimulator::InSameSharedGroup(const BufferValue* left, + const BufferValue* right) { + auto left_it = shared_buffers_.find(left); + if (left_it == shared_buffers_.end()) { + return true; + } + + auto right_it = shared_buffers_.find(right); + if (right_it == shared_buffers_.end()) { + return true; + } + + return left_it->second == right_it->second; +} + // ShareBuffer associates buffers with their SharedGroup in shared_buffers_. // The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to // Alloc. The 'shared' buffer must be a previously allocated or shared buffer. @@ -445,6 +499,12 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, // The 'shared' buffer already has a group; it might be the canonical, but // also might not be. 
Just add 'buffer' to the existing group.
     std::shared_ptr<SharedGroup> group = shared_it->second;
+
+    if (group->refcount == 0) {
+      // Nothing is scheduled at the shared group yet. This must be the
+      // canonical.
+      group->canonical = shared;
+    }
     canonical = group->canonical;
     ++group->refcount;
     shared_buffers_.emplace(buffer, group);
@@ -475,7 +535,7 @@ HeapSimulator::Result HeapSimulator::Finish() {
     for (const auto& share_pair : shared_buffers_) {
       const BufferValue* buffer = share_pair.first;
       std::shared_ptr<SharedGroup> group = share_pair.second;
-      if (buffer != group->canonical) {
+      if (buffer != group->canonical && group->canonical != nullptr) {
         // The canonical must already exist in the chunk_map, since we called
         // Alloc(canonical) on the underlying algorithm. Add non-canonical
         // chunks with the same offset as the canonical.
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 3e0631aeb4a..ef1a62ed414 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -85,6 +85,9 @@ class HeapSimulator {
     // If 'buffers_to_assign' is provided, only those buffers are assigned
     // offsets, otherwise all buffers defined by the instructions are assigned.
     const BufferValueFlatSet* buffers_to_assign;
+    // A vector of buffer value sets. Each set enforces a must-alias
+    // relationship for all buffers inside it.
+    std::vector<BufferValueFlatSet> must_alias_sets;
   };

   // Returns the minimum memory required to compute an HLO module where all
@@ -153,6 +156,11 @@ class HeapSimulator {
   void Free(const BufferValue* buffer, const HloInstruction* instruction);
   void ShareBuffer(const BufferValue* buffer, const BufferValue* shared,
                    const HloInstruction* instruction);
+
+  // Returns true if either:
+  //   - the two buffers belong to the same shared group, or
+  //   - at least one of the buffers has no shared group assigned.
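+  //
+  // Example (illustrative; `a` and `b` stand for arbitrary BufferValues):
+  //
+  //   HeapSimulator::Options options;
+  //   options.must_alias_sets = {BufferValueFlatSet({a, b})};
+  //
+  // After construction both values share one group, so InSameSharedGroup(a, b)
+  // returns true and the simulator assigns them a single chunk.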
+ bool InSameSharedGroup(const BufferValue* left, const BufferValue* right); Result Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index dc40b9446ad..8cb70a1d088 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -54,8 +54,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_data, ComparisonDirection::kLt)); HloComputation* cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -113,7 +113,8 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { // %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]} // %reshape = f32[] reshape(f32[1]{0} %slice) // %constant = f32[] constant(0) - // ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant) + // ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant), + // direction=NE // } // ENTRY %SubcomputationAccounting () -> f32[2,4] { @@ -143,9 +144,9 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice)); HloInstruction* zero = cond_builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); - HloInstruction* cond_comparison = - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero)); + HloInstruction* cond_comparison = cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape, + zero, ComparisonDirection::kNe)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); // param - 1 @@ -258,7 +259,8 @@ class HeapSimulatorTracker { // Constructor for testing a single entry computation. 
HeapSimulatorTracker( const string& name, std::unique_ptr computation, - const std::vector& instruction_sequence) { + const std::vector& instruction_sequence, + const std::vector& must_alias_set = {}) { HloModuleConfig config; module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); @@ -271,10 +273,19 @@ class HeapSimulatorTracker { auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); + BufferValueFlatSet must_alias_buffer_value_set; + + for (HloInstruction* hlo : must_alias_set) { + must_alias_buffer_value_set.insert( + points_to_analysis_->GetBufferDefinedAt(hlo, {}).ValueOrDie()); + } + + HeapSimulator::Options options; + options.must_alias_sets = {must_alias_buffer_value_set}; result_ = HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(), HloInstructionSequence(instruction_sequence), - *points_to_analysis_, zero_size) + *points_to_analysis_, zero_size, options) .ConsumeValueOrDie(); } @@ -409,6 +420,46 @@ TEST_F(HeapSimulatorTest, Multiply) { }); } +TEST_F(HeapSimulatorTest, MustAliasBuffers) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32scalar_, "paramA")); + auto paramX = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec4_, "paramX")); + auto paramY = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec4_, "paramY")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + auto add_1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY)); + + auto add_2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY)); + + auto add_3 = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, add_1, add_2)); + + // Check that mul and add_2 are collocated as requested by the user. + HeapSimulatorTracker tracker( + TestName(), builder.Build(), + {paramA, paramX, mul, paramY, add_1, add_2, add_3}, {mul, add_2}); + tracker.ExpectCallSequence({ + {kAlloc, tracker.BufferAt(paramA, {})}, + {kAlloc, tracker.BufferAt(paramX, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(paramY, {})}, + {kAlloc, tracker.BufferAt(add_1, {})}, + // All params and outputs are freed at the end. 
+ {kFree, tracker.BufferAt(paramA, {})}, + {kFree, tracker.BufferAt(paramX, {})}, + {kFree, tracker.BufferAt(mul, {})}, + {kFree, tracker.BufferAt(paramY, {})}, + {kFree, tracker.BufferAt(add_1, {})}, + {kFinish, nullptr}, + }); + tracker.ExpectSharedBuffers(add_2, {}, mul, {}); +} + TEST_F(HeapSimulatorTest, MultiplyAdd) { auto builder = HloComputation::Builder(TestName()); auto paramA = builder.AddInstruction( @@ -703,8 +754,8 @@ TEST_F(HeapSimulatorTest, WholeModule) { HloInstruction* cond_data = cond_builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_data, ComparisonDirection::kLt)); HloComputation* cond_computation = tracker.module()->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 1413ce3062d..135e7c9c1d3 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 63 +// Next ID: 65 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -49,6 +49,9 @@ message HloInstructionProto { reserved "called_computation_names"; reserved 44; reserved "replica_group_ids"; + // Use backend_config instead for custom_call_opaque. + reserved 53; + reserved "custom_call_opaque"; string name = 1; string opcode = 2; @@ -131,9 +134,6 @@ message HloInstructionProto { // kCustomCall. string custom_call_target = 28; - // Opaque string, only present for kCustomCall. - string custom_call_opaque = 53; - // Shape of outfeed request. xla.ShapeProto outfeed_shape = 29; @@ -146,6 +146,9 @@ message HloInstructionProto { // FFT length. repeated int64 fft_length = 32; + // Comparison direction only used for kCompare. + string comparison_direction = 63; + // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; repeated int64 gather_slice_sizes = 34; @@ -206,6 +209,15 @@ message HloInstructionProto { // Describes how parameters behave with regards to replicas. xla.ParameterReplication parameter_replication = 61; + + // If set, the given instruction is run in parallel on e.g. multiple CPU + // cores. The outermost dimension gets split up into + // outer_dimension_partitions[0] pieces, the next-outermost dim gets split + // into outer_dimension_partitions[1] pieces, etc. + // + // It's illegal to partition a dimension into more shards than there are + // elements in that dimension. + repeated int64 outer_dimension_partitions = 64; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 7d02f4b3d75..8e10d6a376a 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -59,7 +59,7 @@ class BufferValueMap { // construction process. 
using BufferNumber = int64; - explicit BufferValueMap(HloModule* module, + explicit BufferValueMap(const HloModule* module, const HloDataflowAnalysis& dataflow) : module_(module), dataflow_(dataflow) { buffers_.reserve(dataflow_.values().size()); @@ -325,7 +325,7 @@ class BufferValueMap { return aliased_buffers; } - HloModule* module_; + const HloModule* module_ = nullptr; // Dataflow analysis used to construct the buffer map. const HloDataflowAnalysis& dataflow_; @@ -341,7 +341,7 @@ class BufferValueMap { BufferNumber next_buffer_number_ = 0; }; -HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {} +HloAliasAnalysis::HloAliasAnalysis(const HloModule* module) : module_(module) {} const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt( const HloInstruction* instruction, const ShapeIndex& index) const { @@ -488,8 +488,9 @@ string HloAliasAnalysis::ToString() const { /* static */ StatusOr> HloAliasAnalysis::Run( - HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& - fusion_can_share_buffer) { + const HloModule* module, + const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer) { VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); @@ -523,10 +524,75 @@ StatusOr> HloAliasAnalysis::Run( TF_DCHECK_OK(alias_analysis->Verify()); + HloInstruction* root = module->entry_computation()->root_instruction(); + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { + for (const HloBuffer* buffer : + alias_analysis->ComputeBuffersAt(root, index)) { + alias_analysis->live_out_buffers_.insert(buffer); + } + }); + XLA_VLOG_LINES(2, alias_analysis->ToString()); return std::move(alias_analysis); } +void HloAliasAnalysis::MergeBuffers(const HloBuffer& to, + const HloBuffer& from) { + CHECK(to.id() != from.id()); + VLOG(2) << "Merge buffer: " << from.ToString() << " into :" << to.ToString(); + + CHECK(from.id() < buffers_.size()); + CHECK(to.id() < buffers_.size()); + + // Merge the values of `to` and `from`, creates a new buffer with the + // merged values. + std::vector merged_values(to.values().begin(), + to.values().end()); + + merged_values.insert(merged_values.end(), from.values().begin(), + from.values().end()); + absl::c_sort(merged_values, [](const HloValue* a, const HloValue* b) { + return a->id() < b->id(); + }); + + buffers_[to.id()] = HloBuffer(to.id(), merged_values); + for (const HloValue* value : merged_values) { + // Update references of values. + value_to_buffer_[value] = &buffers_[to.id()]; + } + + if (live_out_buffers_.count(&from) > 0) { + // Update live out set to erase `from` and add `to`. + live_out_buffers_.erase(&from); + live_out_buffers_.insert(&buffers_[to.id()]); + } + + int64 from_id = from.id(); + if (from_id != buffers_.size() - 1) { + // Now `from` is invalid, move the last element of buffers to replace `from` + // and update references to the last element. + const HloBuffer& last_elem = buffers_.back(); + buffers_[from.id()] = HloBuffer(from_id, last_elem.values()); + + if (live_out_buffers_.count(&last_elem) > 0) { + // Update live out set to redirect the last element to its new position. + live_out_buffers_.erase(&last_elem); + live_out_buffers_.insert(&buffers_[from_id]); + } + + // Update references of values. + for (const HloValue* value : buffers_[from_id].values()) { + value_to_buffer_[value] = &buffers_[from_id]; + } + } + + // Remove the last element. 
+ buffers_.pop_back(); + + CHECK(Verify().ok()); +} + bool HloAliasAnalysis::HasLiveRangeInterference( const HloOrdering& ordering) const { for (const HloBuffer& buffer : buffers()) { diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 372f99ff01c..d09ec15e83a 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -41,7 +41,7 @@ class HloAliasAnalysis { // The callgraph of the given HloModule must be flattened // (xla::FlattenCallGraph) prior to running the analysis. static StatusOr> Run( - HloModule* module, + const HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& fusion_can_share_buffer); @@ -82,9 +82,7 @@ class HloAliasAnalysis { const std::vector& buffers() const { return buffers_; } // Returns the underlying dataflow analysis used by this alias analysis. - const HloDataflowAnalysis& dataflow_analysis() const { - return *dataflow_analysis_; - } + HloDataflowAnalysis& dataflow_analysis() const { return *dataflow_analysis_; } // Returns true if any index in the output of the given instruction has more // than one buffer. That is, ComputeBuffersAt returns a vector with more than @@ -95,17 +93,44 @@ class HloAliasAnalysis { // output of the given instruction. bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const; + // Merge buffer `from` into buffer `to`. Caller has to make sure no + // interference will be introduced after merging. This rebuilds internal data + // structures and invalidates references to all existing buffers. + void MergeBuffers(const HloBuffer& to, const HloBuffer& from); + // Returns true if any HLO values in the module have interfering live ranges // assuming the given ordering. bool HasLiveRangeInterference(const HloOrdering& ordering) const; + // Returns true if a buffer lives out of the module. + bool BufferLivesOut(const HloBuffer& buffer) const { + return live_out_buffers_.count(&buffer); + } + + // Returns true if an HLO value lives out of the module. + bool ValueLivesOut(const HloValue& value) const { + return live_out_buffers_.count(&GetBufferContainingValue(value)); + } + + std::vector LiveOutBuffers() const { + std::vector results(live_out_buffers_.begin(), + live_out_buffers_.end()); + absl::c_sort(results, [](const HloBuffer* a, const HloBuffer* b) { + return a->id() < b->id(); + }); + return results; + } + protected: - explicit HloAliasAnalysis(HloModule* module); + explicit HloAliasAnalysis(const HloModule* module); // Verify various invariants of the alias analysis. Status Verify() const; - HloModule* module_; + const HloModule* module_; + + // A set of buffers that live out of the module. + absl::flat_hash_set live_out_buffers_; // The underlying dataflow analysis used by this alias analysis. std::unique_ptr dataflow_analysis_; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index b6dbf07959c..89eda8552d4 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -20,7 +20,6 @@ limitations under the License.
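A minimal caller-side sketch of the HloAliasAnalysis API added above, assuming an existing flattened `module`. The function name and the choice to merge the first two buffers are purely illustrative (they mirror the new MergeBuffers test), and a real caller must first prove that the merge introduces no live-range interference.

StatusOr<bool> MergeFirstTwoBuffers(const HloModule* module) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloAliasAnalysis> analysis,
      HloAliasAnalysis::Run(module, /*fusion_can_share_buffer=*/nullptr));
  if (analysis->buffers().size() < 2) {
    return false;
  }
  const HloBuffer& to = analysis->buffers()[0];
  const HloBuffer& from = analysis->buffers()[1];
  // Query live-out status before merging: MergeBuffers invalidates
  // references to all existing buffers.
  bool from_lives_out = analysis->BufferLivesOut(from);
  VLOG(2) << "`from` lives out of the module: " << from_lives_out;
  analysis->MergeBuffers(to, from);
  return true;
}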
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -48,7 +47,6 @@ class HloAliasAnalysisTest : public HloTestBase { // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); analysis_ = HloAliasAnalysis::Run(module_.get(), /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); @@ -126,6 +124,7 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -160,6 +159,7 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -203,6 +203,7 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({param0, param1, param0})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -237,6 +238,8 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); @@ -281,6 +284,8 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}, /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); @@ -370,6 +375,8 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); @@ -421,6 +428,7 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -462,6 +470,7 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); + 
SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -547,6 +556,7 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -647,6 +657,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { FlattenCallGraph flattener; TF_ASSERT_OK(flattener.Run(module_.get()).status()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -738,6 +749,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { auto entry_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -811,6 +823,7 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -872,6 +885,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -960,6 +974,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { HloInstruction::CreateWhile(tuple_shape, condition, body, select)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -997,6 +1012,7 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { scalar_shape_, HloOpcode::kBitcast, constant)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -1006,6 +1022,54 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { analysis.GetUniqueBufferAt(bitcast)); } +TEST_F(HloAliasAnalysisTest, MergeBuffers) { + // Bitcasting a value should not produce a new buffer. + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, elem_shape, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(elem_shape, HloOpcode::kNegate, param0)); + builder.AddInstruction( + HloInstruction::CreateUnary(elem_shape, HloOpcode::kNegate, negate)); + + module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.buffers().size(), 3); + analysis.MergeBuffers(analysis.buffers()[0], analysis.buffers()[1]); + EXPECT_EQ(analysis.buffers().size(), 2); + analysis.MergeBuffers(analysis.buffers()[0], analysis.buffers()[1]); + EXPECT_EQ(analysis.buffers().size(), 1); + analysis.BufferLivesOut(analysis.buffers()[0]); +} + +TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) { + // Bitcasting a value should not produce a new buffer. 
+ Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, elem_shape, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(elem_shape, HloOpcode::kNegate, param0)); + builder.AddInstruction( + HloInstruction::CreateUnary(elem_shape, HloOpcode::kNegate, negate)); + + module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.buffers().size(), 3); + analysis.MergeBuffers(analysis.buffers()[2], analysis.buffers()[1]); + EXPECT_EQ(analysis.buffers().size(), 2); + analysis.MergeBuffers(analysis.buffers()[1], analysis.buffers()[0]); + EXPECT_EQ(analysis.buffers().size(), 1); + analysis.BufferLivesOut(analysis.buffers()[0]); +} + TEST_F(HloAliasAnalysisTest, BitcastInterference) { // A bitcast value simultaneously live with its operand should not cause // interference. @@ -1017,6 +1081,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -1056,6 +1121,7 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { builder.AddInstruction(HloInstruction::CreateTuple({negate, xla_while})); HloComputation* entry = module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h index a88c87e46c8..a81078fdc96 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.h +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -109,11 +109,11 @@ class HloBuffer { private: // Unique identifier for this HloBuffer. - const Id id_; + Id id_; // The set of values contained in this buffer. Vector contains no duplicates // and is sorted stably by HloValue::Id. 
- const std::vector values_; + std::vector values_; }; std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 48a51d302bb..195c84b034f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -716,7 +716,8 @@ ProgramShape HloComputation::ComputeProgramShape() const { return program_shape; } -bool HloComputation::operator==(const HloComputation& other) const { +bool HloComputation::Equal(const HloComputation& other, + bool is_layout_sensitive) const { if (this == &other) { return true; } @@ -741,7 +742,8 @@ bool HloComputation::operator==(const HloComputation& other) const { [](const HloInstruction*, const HloInstruction*) { return true; }, [](const HloComputation* a, const HloComputation* b) { return *a == *b; - }); + }, + is_layout_sensitive); if (!identical_ignoring_operands) { return false; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index a48cfa1f1b2..89dbe93b36b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -198,6 +198,13 @@ class HloComputation { const HloComputationProto& proto, const absl::flat_hash_map& computation_map); + using InstructionSequence = tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>>; + + using ConstInstructionSequence = + tensorflow::gtl::iterator_range>::const_iterator>>; + // Gets the instructions in this computation. // // The returned type is a range of HloInstruction*s, so you can iterate over @@ -205,15 +212,11 @@ class HloComputation { // // for (HloInstruction* instr : computation->instructions()) { ... } // - tensorflow::gtl::iterator_range>::const_iterator>> - instructions() const { + ConstInstructionSequence instructions() const { return {MakeUnwrappingIterator(instructions_.begin()), MakeUnwrappingIterator(instructions_.end())}; } - tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> - instructions() { + InstructionSequence instructions() { return {MakeUnwrappingIterator(instructions_.begin()), MakeUnwrappingIterator(instructions_.end())}; } @@ -270,7 +273,12 @@ class HloComputation { ProgramShape ComputeProgramShape() const; // Return whether `*this` and `other` are functionally equivalent. - bool operator==(const HloComputation& other) const; + bool Equal(const HloComputation& other, bool is_layout_sensitive) const; + + // Return whether `*this` and `other` are functionally equivalent. + bool operator==(const HloComputation& other) const { + return Equal(other, true); + } // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. 
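A small sketch of the layout-insensitive comparison path introduced by HloComputation::Equal above; `a` and `b` are assumed to be existing computations, and the helper names are illustrative only.

// operator== keeps its old behavior by forwarding to Equal(other, true).
bool SameProgram(const HloComputation& a, const HloComputation& b) {
  return a == b;  // layout-sensitive
}

bool SameProgramIgnoringLayouts(const HloComputation& a,
                                const HloComputation& b) {
  return a.Equal(b, /*is_layout_sensitive=*/false);
}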
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index fe37ca6b396..3fa6f80b1b9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -509,8 +509,9 @@ TEST_F(HloComputationTest, CloneWithReplacements) { HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs")); auto param2 = builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1")); - auto lt = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1)); + auto lt = builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + param1, ComparisonDirection::kLt)); auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/lt)); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index e7ed858e8c5..e0f18c4efd1 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -130,7 +130,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { int64 elements_in_constant = ShapeUtil::ElementsIn(instruction->shape()); - static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000; + static const int64 kMaximumConstantSizeElements = 45 * 1000 * 1000; if (elements_in_constant > elements_in_removed_operands && elements_in_constant > kMaximumConstantSizeElements) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 13b1c827095..8c1b22e0a10 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -17,6 +17,10 @@ limitations under the License. #include +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -129,6 +133,42 @@ int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const { return shape_size_(shape); } +int64 HloCostAnalysis::FusionParameterReadBytes( + const HloInstruction* hlo) const { + int64 size = 0; + bool seen_trivial_user = false; + CHECK(hlo->IsFused() && hlo->opcode() == HloOpcode::kParameter); + for (const HloInstruction* user : hlo->users()) { + switch (user->opcode()) { + case HloOpcode::kFusion: { + for (int64 idx : user->OperandIndices(hlo)) { + size += FusionParameterReadBytes(user->fused_parameter(idx)); + } + break; + } + case HloOpcode::kSlice: + size += GetShapeSize(user->shape()); + break; + case HloOpcode::kDynamicSlice: + size += hlo == user->operand(0) ? GetShapeSize(user->shape()) + : GetShapeSize(hlo->shape()); + break; + case HloOpcode::kBroadcast: + case HloOpcode::kReshape: + size += GetShapeSize(hlo->shape()); + break; + default: + // Other instructions reading this parameter are assumed to be able to + // share the read from memory. 
+ if (!seen_trivial_user) { + seen_trivial_user = true; + size += GetShapeSize(hlo->shape()); + } + } + } + return size; +} + Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } @@ -598,6 +638,10 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } +Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) { return Status::OK(); } @@ -612,6 +656,17 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { } Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { + if (fusion->IsCustomFusion()) { + for (const HloInstruction* hlo : + fusion->fused_instructions_computation()->instructions()) { + if (hlo->opcode() == HloOpcode::kGather) { + return HandleGather(hlo); + } + if (hlo->opcode() == HloOpcode::kScatter) { + return HandleScatter(hlo); + } + } + } TF_ASSIGN_OR_RETURN( current_properties_, ProcessNestedSubcomputation(fusion->fused_instructions_computation())); @@ -622,12 +677,34 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { current_properties_[kBytesAccessedKey] = 0; ShapeUtil::ForEachSubshape( fusion->shape(), - [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { + [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) { + if (!subshape.IsArray()) { + return; + } + if (shape_index.empty()) { + if (fusion->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + current_properties_[kBytesAccessedKey] += GetShapeSize( + fusion->fused_expression_root()->operand(0)->shape()); + return; + } + } else if (shape_index.size() == 1) { + if (fusion->fused_expression_root() + ->operand(shape_index[0]) + ->opcode() == HloOpcode::kDynamicUpdateSlice) { + current_properties_[kBytesAccessedKey] += + GetShapeSize(fusion->fused_expression_root() + ->operand(shape_index[0]) + ->operand(0) + ->shape()); + return; + } + } current_properties_[kBytesAccessedKey] += GetShapeSize(subshape); }); - for (const HloInstruction* operand : fusion->operands()) { - current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape()); + for (const HloInstruction* operand : fusion->fused_parameters()) { + current_properties_[kBytesAccessedKey] += FusionParameterReadBytes(operand); } return Status::OK(); @@ -718,8 +795,10 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { } Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { + // Scatter accesses the equivalent of 3 update shapes (input, output, and + // updates), and the scatter indices. 
current_properties_[kBytesAccessedKey] = - GetShapeSize(scatter->operand(2)->shape()) * 2 + + GetShapeSize(scatter->operand(2)->shape()) * 3 + GetShapeSize(scatter->operand(1)->shape()); const int64 element_count = ShapeUtil::ElementsIn(scatter->operand(2)->shape()); @@ -777,6 +856,7 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { StatusOr HloCostAnalysis::ProcessNestedSubcomputation(HloComputation* computation) { HloCostAnalysis visitor(shape_size_, per_second_rates_); + visitor.ReserveVisitStates(computation->instruction_count()); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.properties(); } @@ -784,6 +864,7 @@ HloCostAnalysis::ProcessNestedSubcomputation(HloComputation* computation) { StatusOr HloCostAnalysis::ProcessUnnestedSubcomputation(HloComputation* computation) { HloCostAnalysis visitor(shape_size_, per_second_rates_); + visitor.ReserveVisitStates(computation->instruction_count()); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); hlo_properties_.insert(visitor.hlo_properties_.begin(), visitor.hlo_properties_.end()); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 4480554de50..b76465531f0 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -77,6 +77,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleReplicaId(const HloInstruction* hlo) override; + Status HandlePartitionId(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleRng(const HloInstruction* random) override; @@ -196,6 +197,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // a layout. int64 GetShapeSize(const Shape& shape) const; + // Traverses a fusion operand to find the actual bytes accessed by the fusion + // node. + int64 FusionParameterReadBytes(const HloInstruction* hlo) const; + // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 4d42770ba78..e9e4bfb3bff 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -688,7 +688,7 @@ TEST_F(HloCostAnalysisTest, Scatter) { ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); - EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 2 * (2 * 3))); + EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 3 * (2 * 3))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index b5d9e8e7f1a..99e6217e7dd 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,6 +14,7 @@ limitations under the License. 
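To make the scatter-cost change above concrete: assuming the test's updates operand is f32[2,3] and its scatter indices hold two 32-bit elements (which is what the updated expectation implies), the new accounting works out as sketched below.

// updates = f32[2,3] -> 2 * 3 = 6 elements, 4 bytes each
// indices            -> 2 elements, 4 bytes each
//
// bytes_accessed = 3 * (6 * 4)  // input, output, and updates each count one
//                               // update-shaped access
//                + (2 * 4)      // the scatter indices
//                = 72 + 8
//                = 80           // i.e. 4 * (2 + 3 * (2 * 3)), matching the test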
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" + #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -42,6 +43,18 @@ StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs)); } +StatusOr MakeCompareHlo(ComparisonDirection direction, + HloInstruction* lhs, + HloInstruction* rhs) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN( + Shape binary_op_shape, + ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs)); + return computation->AddInstruction( + HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction)); +} + StatusOr MakePadHlo(HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { @@ -262,15 +275,37 @@ StatusOr MakeReduceHlo(HloInstruction* operand, StatusOr MakeSelectHlo(HloInstruction* pred, HloInstruction* on_true, - HloInstruction* on_false) { + HloInstruction* on_false, + HloInstruction* derived_from) { HloComputation* computation = pred->parent(); DCHECK_EQ(computation, on_true->parent()); DCHECK_EQ(computation, on_false->parent()); + Shape op_shape = on_true->shape(); + if (ShapeUtil::IsScalar(pred->shape())) { + if (!ShapeUtil::IsScalar(op_shape) && !op_shape.IsTuple()) { + // If the output is not scalar, we need to broadcast the condition + // to match the contract of kSelect. For tuples, we use kTupleSelect + // which expects the condition to be a scalar. + pred = computation->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType(op_shape, PrimitiveType::PRED), pred, + {})); + if (derived_from) { + derived_from->SetupDerivedInstruction(pred); + } + } + } + HloOpcode select_op_code = + op_shape.IsTuple() ? 
HloOpcode::kTupleSelect : HloOpcode::kSelect; TF_ASSIGN_OR_RETURN(Shape select_shape, - ShapeInference::InferTernaryOpShape( - HloOpcode::kSelect, pred, on_true, on_false)); - return computation->AddInstruction(HloInstruction::CreateTernary( - select_shape, HloOpcode::kSelect, pred, on_true, on_false)); + ShapeInference::InferTernaryOpShape(select_op_code, pred, + on_true, on_false)); + HloInstruction* select = + computation->AddInstruction(HloInstruction::CreateTernary( + select_shape, select_op_code, pred, on_true, on_false)); + if (derived_from) { + derived_from->SetupDerivedInstruction(select); + } + return select; } StatusOr MakeSortHlo( @@ -350,27 +385,11 @@ StatusOr ExpandFirstDimIntoNDims( StatusOr ElideDegenerateDims( HloInstruction* operand, absl::Span dims_to_elide) { - CHECK(absl::c_is_sorted(dims_to_elide)); - - const Shape& input_shape = operand->shape(); - // First accumulate in reverse - std::vector new_shape_dim_bounds; - new_shape_dim_bounds.reserve(input_shape.dimensions_size() - - dims_to_elide.size()); - int64 dims_to_elide_idx = dims_to_elide.size() - 1; - for (int64 i = input_shape.dimensions_size() - 1; i >= 0; i--) { - if (dims_to_elide_idx >= 0 && i == dims_to_elide[dims_to_elide_idx]) { - CHECK_EQ(input_shape.dimensions(i), 1); - dims_to_elide_idx--; - } else { - new_shape_dim_bounds.push_back(input_shape.dimensions(i)); - } - } - - absl::c_reverse(new_shape_dim_bounds); - Shape output_shape = - ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); - return MakeReshapeHlo(output_shape, operand); + return MakeReshapeHlo( + ShapeUtil::FilterDimensions( + [&](int64 dim) { return !absl::c_linear_search(dims_to_elide, dim); }, + operand->shape()), + operand); } StatusOr InsertDegenerateDims( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 17b7a2da6a9..61df5fb328f 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -32,6 +32,12 @@ namespace xla { StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs); +// Creates a compare HLO instruction and adds it to the computation containing +// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeCompareHlo(ComparisonDirection direction, + HloInstruction* lhs, + HloInstruction* rhs); + // Creates a pad HLO instruction and adds it to the computation containing // `operand` and `padding_value` (`operand` and `padding_value` must be in the // same computation). @@ -118,10 +124,12 @@ StatusOr MakeReduceHlo(HloInstruction* operand, // Creates a Select HLO instruction and adds it to the computation containing // the predicate. The on_true and on_false instructions must also be contained -// in the same computation. +// in the same computation. If on_true and on_false are tuples, create a tuple +// select instead. `pred` is broadcasted up from a scalar if necessary. StatusOr MakeSelectHlo(HloInstruction* pred, HloInstruction* on_true, - HloInstruction* on_false); + HloInstruction* on_false, + HloInstruction* derived_from = nullptr); // Creates a Sort HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. 
Also creates a diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 849cac278ee..1e7e125d956 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -143,7 +143,9 @@ StatusOr HloCSE::Run(HloModule* module) { for (auto instruction : computation->MakeInstructionPostOrder()) { // If the instruction has zero operands (constants, parameters, etc.) skip // over it. - if (instruction->operand_count() == 0) { + if (instruction->operand_count() == 0 && + instruction->opcode() != HloOpcode::kPartitionId && + instruction->opcode() != HloOpcode::kReplicaId) { continue; } // Skip instructions which have side effects. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 80215c92a9f..9036ae8d5fd 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -929,8 +929,7 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer( for (const HloValue* value : GetValueSet(operand, index).values()) { for (const HloUse& use : value->uses()) { if (use.instruction == user) { - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->IsLoopFusion()) { HloInstruction* fusion_param = user->fused_parameter(use.operand_number); const HloValue& value = @@ -958,7 +957,6 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer( // // Returns true if: // -// * fusion is a loop or input fusion, AND // * fusion_param is used by the root of dynamic-update-slice as the "base" of // the update, i.e. the thing being updated, AND // * all other uses of fusion_param are dynamic-slices that slice the same @@ -978,13 +976,6 @@ static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion, CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter); CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation()); - // fusion must be a loop or input fusion. - auto kind = fusion->fusion_kind(); - if (kind != HloInstruction::FusionKind::kLoop && - kind != HloInstruction::FusionKind::kInput) { - return false; - } - // fusion_param must be used by the root as the "base" of the // dynamic-update-slice. The natural way to check this would be // @@ -1033,9 +1024,6 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kFusion) { - if (fusion_can_share_buffer_ != nullptr) { - return fusion_can_share_buffer_(user, operand); - } // Get the parameter associated with 'operand'; HloInstruction* fusion_param = user->fused_parameter(user->operand_index(operand)); const HloValue& fusion_param_value = GetValueDefinedAt(fusion_param, operand_index); + // TODO(b/80315712): This code is in a bit of a weird intermediate state + // at the moment. The in-place DUS check really needs to be common to all + // backends, so it runs first. Then we run the backend-specific check if + // provided, or go through the target-independent check if not. + // Unfortunately, the notionally "target-independent" path actually contains + // some target-specific code, so we can't run all of it *in addition* to the + // target-specific function, like the interface documentation says.
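As a caller-side illustration of the ordering described in the TODO above, a backend that wants its own buffer-sharing rule supplies a FusionCanShareBufferFunction. The wrapper function name and the policy below (allow sharing only for loop fusions) are hypothetical; the lambda signature mirrors the updated test later in this change.

void RunAliasAnalysisWithBackendRule(const HloModule* module) {
  auto fusion_can_share_buffer = [](const HloInstruction* fusion,
                                    const HloInstruction* /*operand*/) {
    // Hypothetical backend policy; the in-place dynamic-update-slice check
    // still runs before this callback is consulted.
    return fusion->IsLoopFusion();
  };
  std::unique_ptr<HloAliasAnalysis> analysis =
      HloAliasAnalysis::Run(module, fusion_can_share_buffer)
          .ConsumeValueOrDie();
  (void)analysis;
}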
if (user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value); } - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - user->fusion_kind() == HloInstruction::FusionKind::kInput) { + if (fusion_can_share_buffer_ != nullptr) { + return fusion_can_share_buffer_(user, operand); + } + + if (user->IsLoopFusion() || user->IsInputFusion()) { return AreTransitiveUsesElementwiseOrTuple(fusion_param); } - if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + if (user->IsOutputFusion() && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 13027fd5463..cb2341a80be 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -50,7 +50,6 @@ class HloDataflowAnalysisTest : public HloTestBase, // reference to the generated analysis stored in analysis_. const HloDataflowAnalysis& RunAnalysis(bool ssa_form, bool bitcast_defines_value = false) { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis"); analysis_ = HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); @@ -109,6 +108,7 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -157,6 +157,7 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -212,6 +213,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { auto gte_out = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -267,6 +269,7 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -320,6 +323,7 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { auto sub = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kSubtract, call1, call2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -372,6 +376,7 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); 
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -434,6 +439,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -509,6 +515,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -614,6 +621,7 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { auto xla_while2 = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -701,6 +709,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { auto entry_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -796,6 +805,7 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -851,6 +861,7 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -893,6 +904,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -964,6 +976,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1053,6 +1066,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1095,6 +1109,7 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { scalar_shape_, HloOpcode::kBitcast, constant)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); { @@ -1131,6 +1146,7 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { auto copy = builder.AddInstruction( HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); module_->AddEntryComputation(builder.Build()); + 
SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1164,6 +1180,7 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { HloInstruction::CreateSend(param, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1189,6 +1206,7 @@ TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1224,6 +1242,7 @@ TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); DependencyHloOrdering ordering(module_.get()); @@ -1261,6 +1280,7 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { vector_shape_, HloOpcode::kAdd, negate, exp)); auto entry = module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); HloSchedule schedule(module_.get()); @@ -1339,6 +1359,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { HloInstruction::CreateWhile(scalar_shape_, condition, body, param)); auto entry = module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); bool ssa_form = GetParam(); RunAnalysis(ssa_form); @@ -1409,6 +1430,7 @@ TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) { HloInstruction::CreateReverse(vector_shape_, negate, {0})); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); DependencyHloOrdering ordering(module_.get()); @@ -1440,6 +1462,7 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValues) { vector_shape_, HloOpcode::kAdd, negate, exp)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); DependencyHloOrdering ordering(module_.get()); @@ -1479,6 +1502,7 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { vector_shape_, HloOpcode::kAdd, negate, exp)); auto entry = module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); HloSchedule schedule(module_.get()); @@ -1537,6 +1561,7 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { builder.AddInstruction(HloInstruction::CreateBinary( vector_shape_, HloOpcode::kAdd, negate, call)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); RunAnalysis(GetParam()); DependencyHloOrdering ordering(module_.get()); @@ -1589,6 +1614,7 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { scalar_shape_, pred, constant1, true_computation, constant2, false_computation)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); @@ -1682,6 +1708,7 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { scalar_shape_, pred, tuple_operand, true_computation, tuple_operand, false_computation)); 
module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); @@ -1816,6 +1843,7 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { scalar_shape_, pred1, tuple_operand, inner_conditional_computation, constant3, computation3)); module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); @@ -2239,8 +2267,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { HloInstruction::CreateParameter(0, in_shape, "param0")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + auto result = builder.AddInstruction(HloInstruction::CreateCompare( + out_shape, param0, param1, ComparisonDirection::kEq)); BuildModuleAndRunAnalysis(builder.Build()); @@ -2549,7 +2577,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { {add, two, mul}, HloInstruction::FusionKind::kInput); RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion, const HloInstruction*) { - return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop; + return fusion->IsLoopFusion(); }); EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, @@ -2563,8 +2591,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq)); return builder.Build(); }; diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index b5d72b386f8..d0073237ac2 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -223,8 +223,9 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { HloInstruction::CreateParameter(0, shape, "cond_param")); auto constant = cond_builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant)); + cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param, + constant, ComparisonDirection::kLt)); } auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index 19b5734825d..5b388bc0bd8 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -27,8 +27,6 @@ namespace { StatusOr RunInternal(HloModule* module, HloDomainIsolator::DomainCreator* creator) { - hlo_graph_dumper::MaybeDumpHloModule(*module, "Before Domain Isolator"); - int64 added_domains = 0; for (HloComputation* computation : module->computations()) { // Walk in post order and place all the required kDomain instructions. 
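The comparison churn in the tests above follows a single pattern, shown here in isolation; `builder`, `lhs`, and `rhs` are assumed to exist.

// Before: comparisons were ordinary binary ops selected by opcode.
//   builder.AddInstruction(HloInstruction::CreateBinary(
//       ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, lhs, rhs));
//
// After: a single kCompare instruction carries an explicit direction.
HloInstruction* lt = builder.AddInstruction(HloInstruction::CreateCompare(
    ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kLt));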
@@ -49,16 +47,17 @@ StatusOr RunInternal(HloModule* module, HloInstruction* domain = (*creator)(instruction, root, operand); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); + // Call ReplaceUseWithDifferentShape even though the shapes are + // expected to match to avoid an expensive shape check between the + // original and the new instruction. + TF_RETURN_IF_ERROR( + operand->ReplaceUseWithDifferentShape(instruction, domain)); ++added_domains; } } } } VLOG(3) << "Added " << added_domains << " kDomain instructions"; - if (added_domains > 0) { - hlo_graph_dumper::MaybeDumpHloModule(*module, "After Domain Isolator"); - } return added_domains > 0; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc index 67fad0769f5..4975c3fbb93 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -59,8 +59,6 @@ Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( StatusOr HloDomainRemover::RunContext::Run() { VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover"); - int64 removed_domains = 0; for (HloComputation* computation : module_->computations()) { // First create the domain instruciton sets. A domain instruction set is @@ -97,9 +95,6 @@ StatusOr HloDomainRemover::RunContext::Run() { } VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '" << remover_->kind_ << "' kind"; - if (removed_domains > 0) { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover"); - } return removed_domains > 0; } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 63bdbc52e82..0320979102f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -56,43 +57,40 @@ namespace xla { namespace { template -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, ComparisonDirection direction, LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el != rhs_el; }; break; - case HloOpcode::kGe: + case ComparisonDirection::kGe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el >= rhs_el; }; break; - case HloOpcode::kGt: + case ComparisonDirection::kGt: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el > rhs_el; }; break; - case HloOpcode::kLe: + case ComparisonDirection::kLe: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el <= rhs_el; }; break; - case HloOpcode::kLt: + case ComparisonDirection::kLt: compare_op = [](OperandT lhs_el, OperandT rhs_el) { return lhs_el < rhs_el; }; break; - default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); } Literal result(shape); @@ -106,24 +104,25 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, + ComparisonDirection direction, LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](complex64 lhs_el, complex64 rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](complex64 lhs_el, complex64 rhs_el) { return lhs_el != rhs_el; }; break; default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); + LOG(FATAL) << "unhandled direction for conversion to Comparison: " + << ComparisonDirectionToString(direction); } Literal result(shape); @@ -137,24 +136,25 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr Compare(const Shape& shape, HloOpcode opcode, +StatusOr Compare(const Shape& shape, + ComparisonDirection direction, LiteralSlice lhs_literal, LiteralSlice rhs_literal) { std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: + switch (direction) { + case ComparisonDirection::kEq: compare_op = [](complex128 lhs_el, complex128 rhs_el) { return lhs_el == rhs_el; }; break; - case HloOpcode::kNe: + case ComparisonDirection::kNe: compare_op = [](complex128 lhs_el, complex128 rhs_el) { return lhs_el != rhs_el; }; break; default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); + LOG(FATAL) << "unhandled direction for conversion to Comparison: " + << ComparisonDirectionToString(direction); } Literal result(shape); @@ -216,10 +216,10 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) return Unimplemented( "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); }); - typed_visitors_[OPAQUE] = + 
typed_visitors_[OPAQUE_TYPE] = absl::make_unique([](HloInstruction*) { return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE_TYPE."); }); typed_visitors_[TOKEN] = absl::make_unique([](HloInstruction*) { @@ -244,12 +244,13 @@ StatusOr HloEvaluator::Evaluate( const auto& computation_shape = computation.parameter_instruction(i)->shape(); const auto& arg_shape = arg_literals[i]->shape(); - if (!ShapeUtil::Equal(computation_shape, arg_shape)) { + if (!Shape::Equal().MinorToMajorOnlyInLayout()(computation_shape, + arg_shape)) { return InvalidArgument( "Shape mismatch at parameter %d. Computation expected %s, but arg " "was %s.", i, ShapeUtil::HumanStringWithLayout(computation_shape), - ShapeUtil::HumanString(arg_shape)); + ShapeUtil::HumanStringWithLayout(arg_shape)); } } @@ -423,10 +424,12 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { #ifndef NDEBUG const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); - DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())) - << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape()) + DCHECK(Shape::Equal().MinorToMajorOnlyInLayout()(parameter->shape(), + input_literal->shape())) + << "parameter shape is: " + << ShapeUtil::HumanStringWithLayout(parameter->shape()) << ", but input literal shape is: " - << ShapeUtil::HumanString(input_literal->shape()); + << ShapeUtil::HumanStringWithLayout(input_literal->shape()); #endif return Status::OK(); @@ -495,7 +498,7 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { switch (elem_ty) { case PRED: case TUPLE: - case OPAQUE: + case OPAQUE_TYPE: case TOKEN: case S8: case S16: @@ -671,20 +674,11 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) { } Status HloEvaluator::HandleCompare(HloInstruction* compare) { - HloOpcode opcode = compare->opcode(); + ComparisonDirection direction = compare->comparison_direction(); auto lhs = compare->operand(0); auto rhs = compare->operand(1); - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. 
- if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s", - ShapeUtil::HumanString(compare->shape()), - ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape())); - } + DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); @@ -696,76 +690,76 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { case PRED: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case U8: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U16: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case U64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S8: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case S16: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case S64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case F16: { TF_ASSIGN_OR_RETURN( evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case BF16: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case F32: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case F64: { - TF_ASSIGN_OR_RETURN( - evaluated_[compare], - Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + TF_ASSIGN_OR_RETURN(evaluated_[compare], + 
Compare(compare->shape(), direction, + lhs_literal, rhs_literal)); } break; case C64: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; case C128: { TF_ASSIGN_OR_RETURN(evaluated_[compare], - Compare(compare->shape(), opcode, + Compare(compare->shape(), direction, lhs_literal, rhs_literal)); } break; default: @@ -786,6 +780,545 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { return Status::OK(); } +namespace { + +// Straightforward implementation of 1D DFT transform. Uses passed-in start +// index and stride to gather inputs from the data vector into the preallocated +// buffer, computes the result, and writes it back to the same locations in the +// data vector. Runs in O(length^2) time. +// +// Parameters contract_output and expand_input are used to avoid unnecessary +// calculations. When contract_output is set to true, then only (length / 2) + 1 +// output values are computed. When expand_input is set to true, then +// (length / 2) + 1 values from the data set are used to re-create the full set +// of size 'length', on which the transform is then performed. +// +void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse, + bool contract_output, bool expand_input, + absl::Span data, absl::Span buffer) { + CHECK_GT(data.size(), start + (length - 1) * stride); + CHECK_GT(buffer.size(), length - 1); + + // Copy input data to 1D vector. + bool input_is_zero = true; + const int64 ub = expand_input ? length / 2 + 1 : length; + for (int64 k = 0; k < ub; k++) { + complex128 value = data[start + k * stride]; + input_is_zero &= value == complex128(0.0, 0.0); + buffer[k] = value; + if (expand_input) { + // Use conjugates of the values at indices [1 ... (ub - 2)] when the + // length is even and at indices [1 ... (ub - 1)] when the length is odd + // to calculate missing values at indices [(length - 1) ... ub]. + if (k > 0 && k < (length - ub + 1)) { + buffer[length - k] = std::conj(value); + } + } + } + + // Do 1D transformation with double precision. + if (!input_is_zero) { + const int64 ub = contract_output ? length / 2 + 1 : length; + for (int64 k = 0; k < ub; k++) { + complex128 value = complex128(0.0, 0.0); + for (int n = 0; n < length; n++) { + auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * n * k / length)); + value += (inverse ? std::conj(buffer[n]) : buffer[n]) * coeff; + } + data[start + k * stride] = + inverse ? std::conj(value) / complex128(length, 0.0) : value; + } + } +} + +// Helper to reverse the order of dimension lengths in the passed-in literal. +std::vector GetDimensionLengths(const Literal& literal) { + std::vector lengths = literal.shape().dimensions(); + absl::c_reverse(lengths); + return lengths; +} + +// Helper to compute strides for creating linear indices into multidimensional +// data from the dimension lengths and the layout. Returns a new vector of size +// lengths.size() + 1. The last element of the returned vector at index +// [lengths.size()] contains the product of all dimension lengths. +std::vector ComputeStrides(const absl::Span lengths, + const Layout& layout) { + const int64 num_dimensions = lengths.size(); + + // Make sure that the layout length matches the number of dimensions. + CHECK_EQ(num_dimensions, layout.minor_to_major_size()); + + // Calculate strides using layout-specified ordering of the dimensions and + // place the stride for axis 0 at index 0, for axis 1 at index 1, etc. 
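As a quick illustration of the stride layout described in that comment, here is a simplified, standalone sketch that assumes the default layout and skips the Layout-aware reordering performed by the function body that follows (names are illustrative, not the XLA helpers):

```c++
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// lengths[0] is the minor-most (fastest varying) axis, matching the evaluator
// convention after GetDimensionLengths() reverses the literal's dimensions.
std::vector<int64_t> MinorFirstStrides(const std::vector<int64_t>& lengths) {
  std::vector<int64_t> strides(lengths.size() + 1);
  int64_t stride = 1;
  for (size_t axis = 0; axis < lengths.size(); ++axis) {
    strides[axis] = stride;
    stride *= lengths[axis];
  }
  // Last slot holds the total element count.
  strides[lengths.size()] = stride;
  return strides;
}

int main() {
  // For a c64[2,4,8] literal with the default layout, the reversed lengths are
  // {8, 4, 2}; the strides come out as {1, 8, 32} and the total size is 64.
  for (int64_t s : MinorFirstStrides({8, 4, 2})) std::cout << s << " ";
  std::cout << "\n";
}
```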
+ std::vector strides(num_dimensions + 1); + int64 stride = 1; + for (int64 i = 0; i < num_dimensions; i++) { + // Reverse the ordering of the dimensions in the layout. + const int64 index = (num_dimensions - 1) - layout.minor_to_major(i); + strides[index] = stride; + stride *= lengths[index]; + } + strides[num_dimensions] = stride; + + return strides; +} + +// Compute strides as above using the default layout. +std::vector ComputeStrides(const absl::Span lengths) { + return ComputeStrides(lengths, + LayoutUtil::GetDefaultLayoutForRank(lengths.size())); +} + +// Compute strides as above using the layout from the literal, if available. +std::vector ComputeStrides(const absl::Span lengths, + const Literal& literal) { + return literal.shape().has_layout() + ? ComputeStrides(lengths, literal.shape().layout()) + : ComputeStrides(lengths); +} + +// Make 1D sweeps along each transform axis. +void Sweep(int64 fft_rank, FftType fft_type, + const absl::Span fft_lengths, + const absl::Span fft_strides, + absl::Span data, absl::Span buffer) { + const bool inverse = fft_type == FftType::IFFT || fft_type == FftType::IRFFT; + const bool input_is_truncated = fft_type == FftType::IRFFT; + const bool output_is_truncated = fft_type == FftType::RFFT; + + // Recursively visit each column of the data along the sweep_axis. Calculate + // linearized index of that column's first element and the stride, then invoke + // 1D transform. + // For RFFT, avoid calculating unused output values: first, compute only + // (length_x / 2) + 1 values along the X axis, then limit the X coordinate to + // [0 ... (length / 2)] during the sweeps along other axes. Similarly, for + // IRFFT sweep along higher dimensions first, while keeping the X coordinate + // in the [0 ... (length / 2)] range, then re-create negative frequencies + // omitted in the input and perform the full-length transform along the X axis + // in the last sweep. + std::function sweep = [&](int64 sweep_axis, + int64 axis, + int64 start) { + if (axis < 0) { + // Base case: invoke 1D transform. + const int64 length = fft_lengths[sweep_axis]; + const int64 stride = fft_strides[sweep_axis]; + const bool expand_input = input_is_truncated && sweep_axis == 0; + const bool contract_oputput = output_is_truncated && sweep_axis == 0; + NaiveDft1D(length, start, stride, inverse, contract_oputput, expand_input, + data, buffer); + } else if (axis == sweep_axis) { + // Visit only the elements with coordinate 0 along the sweep axis. + sweep(sweep_axis, axis - 1, start); + } else { + const int64 length = fft_lengths[axis]; + const bool is_truncated = input_is_truncated || output_is_truncated; + const int64 ub = is_truncated && axis == 0 ? (length / 2) + 1 : length; + for (int64 i = 0; i < ub; i++) { + sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]); + } + } + }; + if (input_is_truncated) { + // Sweep along the X axis last for IRFFT. + for (int64 sweep_axis = fft_rank - 1; sweep_axis >= 0; sweep_axis--) { + sweep(sweep_axis, fft_rank - 1, 0); + } + } else { + // Sweep along the X axis first for RFFT. The order does not matter for FFT + // and IFFT types; handle them here as well. + for (int64 sweep_axis = 0; sweep_axis < fft_rank; sweep_axis++) { + sweep(sweep_axis, fft_rank - 1, 0); + } + } +} + +// These templates convert the data from the input data type to the type used in +// calculations and then to the output data type. They are intended to be used +// only within the DFT implementation. 
One special case is IRFFT, where the +// specialization drops imaginary parts of complex values (which is expected to +// be 0) and returns real numbers. +template +ToType GetAs(FromType value) { + return static_cast(value); +} + +template <> +float GetAs(complex128 value) { + return static_cast(value.real()); +} + +// This template generates two linearized indices, which can be used to access +// multidimensional arrays. It uses a recursive function, which passes the +// indices to the user-supplied callback function. The destination index is +// always within dst_lengths[] bounds. The boolean parameter within_src_bounds +// indicates whether the source index is within src_lengths[] bounds. +// +// The value returned from the callback function controls the recursion depth. +// Returning true indicates that the base case had been hit and the recursion +// stops. Otherwise, the recursion proceeds along the next less-major axis. +// +// For example, the base case when the axis value becomes negative invokes the +// callback function for each possible index within dst_lengths[] bounds. The +// base case when the axis value is equal to zero limits the indices to point +// only to first elements along the minor-most dimension, allowing the callback +// function to handle all values along the X axis. +// +template +void GenerateIndices(const absl::Span dst_lengths, + const absl::Span dst_strides, + const absl::Span src_lengths, + const absl::Span src_strides, int64 fft_rank, + int64 dst_start, int64 src_start, BaseFn&& base) { + CHECK_EQ(dst_lengths.size() + 1, dst_strides.size()); + CHECK_GE(dst_lengths.size(), fft_rank); + CHECK_EQ(src_lengths.size() + 1, src_strides.size()); + CHECK_GE(src_lengths.size(), fft_rank); + + std::function generate = + [&](int64 axis, int64 dst_index, int64 src_index, + bool within_src_bounds) { + if (!base(axis, dst_index, src_index, within_src_bounds)) { + for (int64 i = 0; i < dst_lengths[axis]; i++) { + // Because the loop goes over dst_lengths[], the source index may be + // out of src_lengths[] bounds. In this case, within_src_bounds is + // false. + within_src_bounds &= i < src_lengths[axis]; + generate(axis - 1, dst_index, src_index, within_src_bounds); + dst_index += dst_strides[axis]; + src_index += src_strides[axis]; + } + } + }; + generate(fft_rank - 1, dst_start, src_start, true); +} + +// Copies the input data from a literal to a pre-allocated vector. The sizes of +// the input and the transform do not need to match. For each axis of the +// transform, any extra input values beyond the transform length are ignored. +// Conversely, if the input does not contain enough elements along any axis, the +// data is padded with zeroes. +// +// For IRFFT transforms, we use (length_x / 2) + 1 elements from the input, +// where length_x is the size of the full transform along the X axis. +// +// The input literal may have a rank higher than the rank of the transform. +// Passed-in input_index value points to the first element of the input literal +// to be copied. +// +// Returns true if all values in the work data set are zeroes. 
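Before the copy helpers, a simplified standalone sketch of the dual-index recursion that GenerateIndices() above implements: walk every destination index, carry the corresponding source linear index, and report when the source runs out of bounds. The function and variable names below are illustrative, not the XLA helpers:

```c++
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Calls fn(dst_linear_index, src_linear_index, src_in_bounds) for every index
// within dst_lengths. Axis 0 is the innermost (stride 1) axis.
void ForEachDstSrcPair(const std::vector<int64_t>& dst_lengths,
                       const std::vector<int64_t>& dst_strides,
                       const std::vector<int64_t>& src_lengths,
                       const std::vector<int64_t>& src_strides,
                       const std::function<void(int64_t, int64_t, bool)>& fn) {
  std::function<void(int64_t, int64_t, int64_t, bool)> recurse =
      [&](int64_t axis, int64_t dst, int64_t src, bool in_bounds) {
        if (axis < 0) {
          fn(dst, src, in_bounds);
          return;
        }
        for (int64_t i = 0; i < dst_lengths[axis]; ++i) {
          recurse(axis - 1, dst + i * dst_strides[axis],
                  src + i * src_strides[axis],
                  in_bounds && i < src_lengths[axis]);
        }
      };
  recurse(static_cast<int64_t>(dst_lengths.size()) - 1, 0, 0, true);
}

int main() {
  // Destination is 2x3 (X=3, Y=2); source is 2x2, so X == 2 is out of bounds.
  ForEachDstSrcPair(/*dst_lengths=*/{3, 2}, /*dst_strides=*/{1, 3},
                    /*src_lengths=*/{2, 2}, /*src_strides=*/{1, 2},
                    [](int64_t dst, int64_t src, bool ok) {
                      std::cout << dst << " <- " << (ok ? src : -1) << "\n";
                    });
}
```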
+// +template +bool CopyDataFromInput(const Literal& input_literal, int64 input_start, + int64 fft_rank, FftType fft_type, int64 fft_size, + const absl::Span fft_lengths, + const absl::Span fft_strides, + const absl::Span input_lengths, + const absl::Span input_strides, + absl::Span data) { + CHECK_GE(data.size(), fft_size); + + const bool input_is_truncated = fft_type == FftType::IRFFT; + + // Recursively visit each transform dimension to copy input values to the + // working data set. The base case handles inputs along the X axis. + bool input_is_zero = true; + const InputType* input_data = input_literal.data().data(); + auto base_case = [&](int64 axis, int64 dst_index, int64 src_index, + bool within_src_bounds) { + if (axis == 0) { + // For IRFFT, the negavie frequencies are only needed for the sweep along + // the X axis, which is performed last. Leave this part of the working set + // uninitialized until then. + const int64 length = fft_lengths[axis]; + const int64 ub = input_is_truncated ? (length / 2) + 1 : length; + for (int64 i = 0; i < ub; i++) { + complex128 value = InputType(0); + // Read input value only if the index is within bounds. + if (within_src_bounds && i < input_lengths[axis]) { + value = GetAs( + input_data[src_index + i * input_strides[axis]]); + input_is_zero &= value == complex128(0.0, 0.0); + } + data[dst_index + i * fft_strides[axis]] = value; + } + return true; + } + return false; + }; + GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides, + fft_rank, 0, input_start, base_case); + return input_is_zero; +} + +// Copies the result of the transform to the literal output. The sizes of the +// transform and output must match. +// +// For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is +// the size of the full transform along the X axis (the most minor dimension). +// +// The output literal may have a rank higher than the rank of the transform. +// Passed-in output_index value points to the first element of the output +// literal to be filled in. +// +template +void CopyDataToOutput(const absl::Span data, int64 output_start, + int64 fft_rank, FftType fft_type, + const absl::Span fft_lengths, + const absl::Span fft_strides, + const absl::Span output_lengths, + const absl::Span output_strides, + Literal* output_literal) { + const bool output_is_truncated = fft_type == FftType::RFFT; + + // Base case for recursive copy of the results to the output. The code avoids + // making a recursive call for each output element by handling axis 0 in the + // loop (as opposed to making "axis < 0" to be the base case). + OutputType* output_data = output_literal->data().data(); + auto base_case = [&](int64 axis, int64 dst_index, int64 src_index, + bool within_src_bounds) { + if (axis == 0) { + // Drop negative frequencies for RFFT. + const int64 length = fft_lengths[axis]; + const int64 ub = output_is_truncated ? (length / 2) + 1 : length; + for (int64 i = 0; i < output_lengths[axis]; i++) { + OutputType value = OutputType(0); + // Read data only if the index is within bounds. + if (within_src_bounds && i < ub) { + value = GetAs( + data[src_index + i * fft_strides[axis]]); + } + output_data[dst_index + i * output_strides[axis]] = value; + } + return true; + } + return false; + }; + GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides, + fft_rank, output_start, 0, base_case); +} + +// Determine the type to use with the CopyDataFromInput<> template above. 
+bool CopyDataFromInput(const Literal& input_literal, int64 input_start, + int64 fft_rank, FftType fft_type, int64 fft_size, + const absl::Span fft_lengths, + const absl::Span fft_strides, + const absl::Span input_lengths, + const absl::Span input_strides, + absl::Span data) { + const bool input_is_float = fft_type == FftType::RFFT; + if (input_is_float) { + return CopyDataFromInput( + input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths, + fft_strides, input_lengths, input_strides, data); + } else { + return CopyDataFromInput( + input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths, + fft_strides, input_lengths, input_strides, data); + } +} + +// Determine the type to use with the CopyDataToOutput<> template above. +void CopyDataToOutput(const absl::Span data, int64 output_start, + int64 fft_rank, FftType fft_type, + const absl::Span fft_lengths, + const absl::Span fft_strides, + const absl::Span output_lengths, + const absl::Span output_strides, + Literal* output_literal) { + const bool output_is_float = fft_type == FftType::IRFFT; + if (output_is_float) { + CopyDataToOutput(data, output_start, fft_rank, fft_type, fft_lengths, + fft_strides, output_lengths, output_strides, + output_literal); + } else { + CopyDataToOutput(data, output_start, fft_rank, fft_type, + fft_lengths, fft_strides, output_lengths, + output_strides, output_literal); + } +} + +Status CheckParameters(const Shape& input_shape, const Shape& output_shape, + int64 fft_rank, FftType fft_type, + const absl::Span fft_lengths) { + // Check FFT parameters. + if (fft_rank <= 0) { + return InvalidArgument("Zero or negative FFT rank."); + } + if (*absl::c_min_element(fft_lengths) < 0) { + return InvalidArgument("Negative FFT length."); + } + + // Check input-related values. + TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape)); + if (!input_shape.IsArray()) { + return Unimplemented("Only array input shapes are supported."); + } + auto input_elt_type = input_shape.element_type(); + if (fft_type == FftType::RFFT && input_elt_type != PrimitiveType::F32) { + return InvalidArgument("Invalid input type: %d, must be %d (float).", + input_elt_type, PrimitiveType::F32); + } + if (fft_type != FftType::RFFT && input_elt_type != PrimitiveType::C64) { + return InvalidArgument("Invalid input type: %d, must be %d (complex64).", + input_elt_type, PrimitiveType::C64); + } + const int64 input_rank = input_shape.rank(); + if (input_rank < fft_rank) { + return InvalidArgument("Input shape rank is smaller than FFT rank."); + } + + // Check output-related values. + TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape)); + if (!output_shape.IsArray()) { + return Unimplemented("Only array output shapes are supported."); + } + auto output_elt_type = output_shape.element_type(); + if (fft_type == FftType::IRFFT && output_elt_type != PrimitiveType::F32) { + return InvalidArgument("Invalid output type: %d, must be %d (float).", + output_elt_type, PrimitiveType::F32); + } + if (fft_type != FftType::IRFFT && output_elt_type != PrimitiveType::C64) { + return InvalidArgument("Invalid output type: %d, must be %d (complex64).", + output_elt_type, PrimitiveType::C64); + } + const int64 output_rank = output_shape.rank(); + if (output_rank < fft_rank) { + return InvalidArgument("Output shape rank is smaller than FFT rank."); + } + + // Consistency of input and output parameters. 
+ if (input_rank != output_rank) { + return InvalidArgument( + "Ranks of input shape and output shape do not match."); + } + for (int64 dim = 0; dim < input_rank - fft_rank; dim++) { + if (ShapeUtil::GetDimension(input_shape, dim) != + ShapeUtil::GetDimension(output_shape, dim)) { + return InvalidArgument( + "Higher dimension lengths of input shape and output shape do not " + "match."); + } + } + + return Status::OK(); +} + +} // namespace + +// Flexible but slow implementation of the discrete Fourier transform. All +// transform types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the +// arbitrary rank and length of each dimension of the transform, and arbitrary +// layouts of the input and output literals. +// +// The input literal in operand 0 provides input data, which must be complex64 +// for FFT, IFFT, IRFFT transforms and float for RFFT. The transform is computed +// over the innermost dimensions of the input, thus the rank of the input data +// must be same as fft_rank or larger. The input is expected to provide Ni +// values along each transform axis with one exception: for IRFFT, only +// (N0 / 2) + 1 values are needed along the X axis (the innermost index). To +// increase flexibility, this implementation can handle mismatches between the +// input size and transform lengths by either dropping extra input values or +// using zeroes in place of missing input values as necessary. If the input data +// has rank higher than the transform, the transform is applied for each valid +// combination of the higher-ranking indices. +// +// The output contains complex64 values for FFT, IFFT, RFFT, and float values +// for IRFFT. The rank of the output as well as the sizes of the dimensions +// above the rank of the transform must match those of the input. Sizes of the +// output's "fft_rank" innermost dimensions are expected to match the length of +// the transform along respective axes with one exception: for RFFT, the output +// is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the +// length(s) mismatch, the FFT output is trimmed to fit into the provided output +// shape, or the output is padded with zero values appropriately. +// +// For example, 2D FFT transform of size 16x16 applied to complex64[2][15][17] +// input array will perform two transforms over the [][15][17] data in the sub +// arrays [0][][] and [1][][], dropping the values along axis X and padding axis +// Y with zeroes to create 16x16 working sets, and generating +// complex64[2][16][16] output. 3D IRFFT transform of size 64x16x16 applied to +// complex64[64][16][9] input array will use all input values and will produce +// float[64][16][16] output. +// +// The implementation of the 1D transform is a straightforward loop nest. The +// transforms of higher ranks apply sets of 1D transforms along each axis. For +// example, the 2D transform is computed by applying 1D transforms to each +// column followed by applying 1D transforms to each row. +// +// In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn)) +// time, where Ni is the length of the transform's i-th dimension. It is +// possible to reduce the run time to O(N0*N1*...(log(N0)+log(N1)+...)) by +// plugging in a more efficient 1D implementation. 
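A standalone numerical sanity check of the O(length^2) 1D transform described above. It is independent of the evaluator; it simply evaluates X[k] = sum_n x[n] * exp(-2*pi*i*n*k/N) and reproduces the expected values used by the 1D FFT test later in this change:

```c++
#include <cmath>
#include <complex>
#include <cstddef>
#include <iostream>
#include <vector>

// Naive O(N^2) forward DFT of a complex sequence.
std::vector<std::complex<double>> NaiveDft(
    const std::vector<std::complex<double>>& x) {
  const double kPi = 3.14159265358979323846;
  const size_t n = x.size();
  std::vector<std::complex<double>> result(n);
  for (size_t k = 0; k < n; ++k) {
    for (size_t j = 0; j < n; ++j) {
      const double angle = -2.0 * kPi * j * k / n;
      result[k] +=
          x[j] * std::complex<double>(std::cos(angle), std::sin(angle));
    }
  }
  return result;
}

int main() {
  // Matches the 1D_FFT_4_on_c64x4 test below:
  // fft({1, 2, 3, 4}) == {10, -2+2i, -2, -2-2i}.
  for (const auto& v : NaiveDft({{1, 0}, {2, 0}, {3, 0}, {4, 0}})) {
    std::cout << v.real() << (v.imag() < 0 ? "" : "+") << v.imag() << "i\n";
  }
}
```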
+// +Status HloEvaluator::HandleFft(HloInstruction* fft) { + const FftType fft_type = fft->fft_type(); + std::vector fft_lengths = fft->fft_length(); + const int64 fft_rank = fft_lengths.size(); + const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0)); + const Shape& input_shape = input_literal.shape(); + const Shape& output_shape = fft->shape(); + Literal output_literal = Literal::CreateFromShape(output_shape); + + // Make fft_lengths[0] the minor-most dimension. + absl::c_reverse(fft_lengths); + + TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape, fft_rank, + fft_type, fft_lengths)); + + const auto fft_strides = ComputeStrides(fft_lengths); + + // Working set size. + const int64 fft_size = fft_strides[fft_rank]; + + if (fft_size > 0) { + // Linearized working data set. + std::vector data(fft_size); + + // Temporary buffer allocated once and used in 1D sweeps. + std::vector buffer(*absl::c_max_element(fft_lengths)); + + // Sizes of each axis of input and output literals. + const auto input_lengths = GetDimensionLengths(input_literal); + const auto output_lengths = GetDimensionLengths(output_literal); + + // Strides for generating linearized indices into multidimensional arrays. + const auto input_strides = ComputeStrides(input_lengths, input_literal); + const auto output_strides = ComputeStrides(output_lengths, output_literal); + + // Visit all elements in the dimensions with ranks above the FFT rank. For + // each such element invoke the transform. Use separate indices for the + // input and the output to allow different layouts. + auto base_case = [&](int64 axis, int64 output_index, int64 input_index, + bool within_src_bounds) { + if (axis == fft_rank - 1) { + // Base case: copy the data from the input literal, apply the + // transform, and copy the result to the output literal. + CHECK(within_src_bounds); + bool input_is_zero = + CopyDataFromInput(input_literal, input_index, fft_rank, fft_type, + fft_size, fft_lengths, fft_strides, input_lengths, + input_strides, absl::MakeSpan(data)); + if (!input_is_zero) { + // Make 1D sweeps along each transform axis. + Sweep(fft_rank, fft_type, fft_lengths, fft_strides, + absl::MakeSpan(data), absl::MakeSpan(buffer)); + } + CopyDataToOutput(absl::MakeSpan(data), output_index, fft_rank, fft_type, + fft_lengths, fft_strides, output_lengths, + output_strides, &output_literal); + return true; + } + return false; + }; + GenerateIndices(output_lengths, output_strides, input_lengths, + input_strides, input_shape.rank(), 0, 0, base_case); + } + + evaluated_[fft] = std::move(output_literal); + return Status::OK(); +} + // Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch // dimensions while keeping the rest of the output dimensions clamped to 0. 
ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( @@ -1226,8 +1759,8 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { HloEvaluator embedded_evaluator; embedded_evaluator.set_dynamic_dimension_inference( dynamic_dimension_inference_); - Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(Literal result, + embedded_evaluator.Evaluate(*computation, arg_literals)); evaluated_[call] = std::move(result); return Status::OK(); @@ -1261,9 +1794,8 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { HloEvaluator embedded_evaluator; embedded_evaluator.set_dynamic_dimension_inference( dynamic_dimension_inference_); - Literal result = - embedded_evaluator.Evaluate(*readded_computation, arg_literals) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(Literal result, embedded_evaluator.Evaluate( + *readded_computation, arg_literals)); evaluated_[fusion] = std::move(result); return Status::OK(); @@ -1287,10 +1819,10 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { HloEvaluator embedded_evaluator; embedded_evaluator.set_dynamic_dimension_inference( dynamic_dimension_inference_); - Literal result = embedded_evaluator - .Evaluate(*conditional->branch_computation(branch_index), - {&branch_computation_arg}) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(Literal result, + embedded_evaluator.Evaluate( + *conditional->branch_computation(branch_index), + {&branch_computation_arg})); evaluated_[conditional] = std::move(result); return Status::OK(); @@ -1568,20 +2100,225 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { return Status::OK(); } -Status HloEvaluator::HandleReduce(HloInstruction* reduce) { - if (!reduce->shape().IsTuple()) { - return DefaultAction(reduce); - } else { - auto first_element_type = reduce->shape().tuple_shapes(0).element_type(); - for (const auto& tuple_shape : reduce->shape().tuple_shapes()) { - if (tuple_shape.element_type() != first_element_type) { - return Unimplemented( - "Reduce with several outputs that have mixed element types is " - "unsupported"); - } - } - return reduce->Visit(typed_visitors_[first_element_type].get()); +static bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; } + return false; +} + +// Run a single step of an inner loop while running reduction, which applies +// the user-provided computation on the accumulator and the output element +// (until the reduction is completed, the output element is also used as +// an accumulator). 
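A simplified standalone sketch of the loop structure this describes: each element of the reduced dimensions is folded into an accumulator that starts at the init value, using a caller-supplied binary function (here plain addition, corresponding to the scalar-add fast path). The helper below is illustrative, not the evaluator code:

```c++
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Reduces a row-major [rows][cols] matrix over the column dimension.
std::vector<double> ReduceColumns(
    const std::vector<double>& data, int64_t rows, int64_t cols, double init,
    const std::function<double(double, double)>& fn) {
  std::vector<double> out(rows, init);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      // The output element doubles as the accumulator until the reduction
      // over this row is complete.
      out[r] = fn(out[r], data[r * cols + c]);
    }
  }
  return out;
}

int main() {
  // reduce(f32[2,3] {{1,2,3},{4,5,6}}, init=0, add, dims={1}) -> {6, 15}.
  for (double v : ReduceColumns({1, 2, 3, 4, 5, 6}, 2, 3, 0.0,
                                [](double a, double b) { return a + b; })) {
    std::cout << v << "\n";
  }
}
```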
+static StatusOr PerformReductionStep( + absl::Span input_index, absl::Span output_index, + absl::Span input_args, absl::Span results, + HloComputation* computation, HloEvaluator* embedded_evaluator) { + int num_args = results.size(); + bool is_tuple = num_args > 1; + + absl::InlinedVector arg_values; + arg_values.reserve(num_args); + absl::InlinedVector accumulators; + accumulators.reserve(num_args); + for (int64 i = 0; i < num_args; ++i) { + arg_values.emplace_back( + ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {})); + accumulators.emplace_back( + ShapeUtil::MakeShape(input_args[i]->shape().element_type(), {})); + + TF_RETURN_IF_ERROR( + arg_values[i].CopyElementFrom(*input_args[i], input_index, {})); + TF_RETURN_IF_ERROR( + accumulators[i].CopyElementFrom(results[i], output_index, {})); + } + + // Evaluate computation with specified literal operands. + absl::InlinedVector embedded_operands; + for (Literal& accumulator : accumulators) { + embedded_operands.push_back(&accumulator); + } + for (Literal& local_input : arg_values) { + embedded_operands.push_back(&local_input); + } + + TF_ASSIGN_OR_RETURN( + Literal computed_result, + embedded_evaluator->Evaluate(*computation, embedded_operands)); + + // Clear visit states so that we can use the evaluator again on the same + // computation. + embedded_evaluator->ResetVisitStates(); + + if (is_tuple) { + std::vector computed_results = computed_result.DecomposeTuple(); + for (int64 i = 0; i < num_args; ++i) { + TF_RETURN_IF_ERROR( + results[i].CopyElementFrom(computed_results[i], {}, output_index)); + } + } else { + TF_RETURN_IF_ERROR( + results[0].CopyElementFrom(computed_result, {}, output_index)); + } + + return true; +} + +static StatusOr GenerateReduceOutputElement( + absl::Span output_index, + + absl::Span init_values, + absl::Span input_args, absl::Span results, + + HloComputation* function, HloEvaluator* embedded_evaluator, + + absl::Span arg_dim_steps, + absl::Span arg_dim_counts, + absl::Span result_to_arg_index) { + bool is_tuple = results.size() > 1; + bool use_fast_add = ShapeUtil::ElementIsFloating(init_values[0]->shape()) && + IsScalarAdd(function) && !is_tuple; + + const Shape& arg_shape = input_args[0]->shape(); + absl::Span arg_dimensions = AsInt64Slice(arg_shape.dimensions()); + std::vector base(arg_dimensions.size()); + for (int64 i = 0; i < output_index.size(); ++i) { + base[result_to_arg_index[i]] = output_index[i]; + } + + for (int64 i = 0; i < results.size(); ++i) { + TF_RETURN_IF_ERROR( + results[i].CopyElementFrom(*init_values[i], {}, output_index)); + } + + if (use_fast_add) { + TF_ASSIGN_OR_RETURN(double computed_result, + init_values[0]->GetAsDouble({})); + auto reduction_step = + [&](absl::Span input_index) -> StatusOr { + TF_ASSIGN_OR_RETURN(double argument, + input_args[0]->GetAsDouble(input_index)); + computed_result += argument; + return true; + }; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step)); + TF_RETURN_IF_ERROR(results[0].SetFromDouble(output_index, computed_result)); + return true; + } + + // Iterates only over reduced shape, as counts and steps are set to zero + // for all non-reduced dimensions. 
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + arg_shape, base, arg_dim_counts, arg_dim_steps, + [&](absl::Span input_index) { + return PerformReductionStep(input_index, output_index, input_args, + results, function, embedded_evaluator); + })); + return true; +} + +Status HloEvaluator::HandleReduce(HloInstruction* instr) { + HloReduceInstruction* reduce = Cast(instr); + int64 num_args = reduce->inputs().size(); + absl::Span dimensions_to_reduce(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); + + absl::InlinedVector operand_shapes; + for (const HloInstruction* operand : reduce->operands()) { + operand_shapes.push_back(&operand->shape()); + } + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReduceShape( + operand_shapes, dimensions_to_reduce, + /*to_apply=*/function->ComputeProgramShape())); + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(reduce->shape(), + inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + absl::InlinedVector input_args(num_args); + absl::InlinedVector init_values(num_args); + for (int64 i = 0; i < num_args; ++i) { + input_args[i] = &GetEvaluatedLiteralFor(reduce->inputs()[i]); + VLOG(3) << "HandleReduce arg_literal: " << input_args[i]->ToString(); + init_values[i] = &GetEvaluatedLiteralFor(reduce->init_values()[i]); + VLOG(3) << "HandleReduce init_literal: " << init_values[i]->ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_values[i]->shape())); + } + + // All args and results have the same dimensions, so pick an arbitrary one. + const Shape& arg_shape = input_args[0]->shape(); + const Shape& out_shape = inferred_return_shape; + bool is_tuple = out_shape.IsTuple(); + const Shape& output_shape = inferred_return_shape.IsTuple() + ? inferred_return_shape.tuple_shapes(0) + : inferred_return_shape; + + absl::Span arg_dimensions = AsInt64Slice(arg_shape.dimensions()); + + // All increments are set to 0. + std::vector arg_dim_steps(arg_dimensions.size()); + + // All counts are set to 0. + std::vector arg_dim_counts(arg_dimensions.size()); + + // Set steps and counts for reduced dimensions. + // This avoids iterating over non-reduced dimensions, as their step + // and count is set to zero. + for (const int64 dim : dimensions_to_reduce) { + arg_dim_steps[dim] = 1; + arg_dim_counts[dim] = arg_dimensions[dim]; + } + auto reduced_dimensions = arg_shape.dimensions(); + + // Map each dimension in the result to a dimension in arg that isn't + // being reduced. + std::vector result_to_arg_index; + for (int64 i = 0; i < arg_dimensions.size(); ++i) { + if (arg_dim_steps[i] == 0) { + result_to_arg_index.push_back(i); + } + } + + HloEvaluator embedded_evaluator(max_loop_iterations_); + absl::InlinedVector results(num_args); + for (int64 i = 0; i < num_args; ++i) { + results[i] = Literal(is_tuple ? 
out_shape.tuple_shapes(i) : out_shape); + } + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + output_shape, [&](absl::Span output_index) { + return GenerateReduceOutputElement( + output_index, init_values, input_args, absl::Span(results), + function, &embedded_evaluator, arg_dim_steps, arg_dim_counts, + result_to_arg_index); + })); + + if (is_tuple) { + Literal tuple_result(inferred_return_shape); + for (int64 i = 0; i < num_args; ++i) { + TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); + } + evaluated_[reduce] = std::move(tuple_result); + } else { + CHECK_EQ(results.size(), 1); + evaluated_[reduce] = std::move(results[0]); + } + if (!ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) { + TF_ASSIGN_OR_RETURN(evaluated_[reduce], + evaluated_[reduce].ConvertToShape(reduce->shape())); + } + return Status::OK(); } Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) { @@ -1615,8 +2352,9 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); // Out of convenience the literal may have been produced with a different // layout. Relayout as indicated by the HLO instruction. - if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), - hlo->shape())) { + if (!Layout::Equal().MinorToMajorOnly()( + GetEvaluatedLiteralFor(hlo).shape().layout(), + hlo->shape().layout())) { evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 357975a131d..45b6a2754d6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -204,6 +204,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 383921fde22..c4266f95fcc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -56,14 +56,14 @@ static std::array use_bf16_params{true, false}; // In bf16 mode, all f32 shapes are converted to bf16 before running. 
class HloEvaluatorTest : public HloTestBase { public: - HloEvaluatorTest() : use_bfloat16_(false) {} + HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); } - Literal Evaluate(absl::Span arg_literals = {}) { + StatusOr Evaluate( + absl::Span arg_literals = {}) { if (use_bfloat16_) { HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_.Evaluate(*m_->entry_computation(), arg_literals) - .ConsumeValueOrDie(); + return evaluator_.Evaluate(*m_->entry_computation(), arg_literals); } // Evaluate function that takes in a local module instead of using m_ @@ -86,7 +86,7 @@ class HloEvaluatorTest : public HloTestBase { b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto element_type = expected.shape().element_type(); if (element_type == F32 || element_type == F64) { @@ -106,7 +106,7 @@ class HloEvaluatorTest : public HloTestBase { HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } @@ -124,17 +124,30 @@ class HloEvaluatorTest : public HloTestBase { expected.shape(), opcode, operand0, operand1, operand2)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } protected: - explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {} + explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) { + InitializeFftData(); + } + + // Initializes data sets used in FFT tests below. + void InitializeFftData(); + HloEvaluator evaluator_; const bool use_bfloat16_; std::unique_ptr m_ = CreateNewVerifiedModule(); + + // Data sets used in FFT tests below. + ErrorSpec fft_error_ = ErrorSpec(1e-4, 1e-5); + Literal fft_c64x2x4x8_; + Literal fft_c64x2x4x8_1d_; + Literal fft_c64x2x4x8_2d_; + Literal fft_c64x2x4x8_3d_; }; // Lets you write TEST_Ps that run twice, once with and once without bf16. 
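The recurring change in this test file is that Evaluate() now returns StatusOr<Literal>, so call sites use TF_ASSERT_OK_AND_ASSIGN instead of ConsumeValueOrDie(). A sketch of the updated call-site pattern, relying only on the fixture, macros, and utilities already used in this file; the test name and the tiny HLO snippet are illustrative:

```c++
// Sketch only: a failed evaluation now fails the test with the Status message
// rather than crashing inside ConsumeValueOrDie().
TEST_F(HloEvaluatorTest, IllustrativeStatusOrUsage) {
  constexpr char kHloText[] = R"(
HloModule Neg

ENTRY main {
  c = f32[2] constant({1, -2})
  ROOT neg = f32[2] negate(c)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(kHloText));
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR1<float>({-1.0f, 2.0f}), result));
}
```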
@@ -163,7 +176,7 @@ TEST_P(HloEvaluatorBf16Test, DoesClamp) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); @@ -189,7 +202,7 @@ TEST_P(HloEvaluatorBf16Test, DoesClampInt64) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{0, ones(55)}, {ones(54), ones(58)}}); @@ -211,7 +224,7 @@ TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); @@ -236,7 +249,7 @@ TEST_P(HloEvaluatorBf16Test, DoesSelect) { HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate({}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({})); auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); @@ -339,6 +352,13 @@ TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) { auto expected = LiteralUtil::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } + +TEST_F(HloEvaluatorTest, DoesAbsC128) { + auto x = LiteralUtil::CreateR0({1, 2}); + auto expected_real = LiteralUtil::CreateR0(2.23607); + TestUnaryOp(HloOpcode::kAbs, std::move(expected_real), std::move(x), 3e-06); +} + TEST_F(HloEvaluatorTest, DoesNegateR2) { auto operand = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); @@ -404,7 +424,7 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { lhs_instruction, param_rhs2)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(args); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args)); auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); @@ -428,7 +448,7 @@ TEST_F(HloEvaluatorTest, DoesReshape) { HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate({}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({})); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; result.EachCell([&](absl::Span indices, NativeT value) { @@ -449,7 +469,7 @@ TEST_F(HloEvaluatorTest, DoesBroadcast) { output_literal.shape(), literal_instruction, {1, 2})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate({}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({})); EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } @@ -468,7 +488,7 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { /*broadcast_dimensions=*/{})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate({}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({})); EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } @@ -488,7 +508,7 @@ TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2( {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); @@ -510,7 +530,7 @@ TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + 
TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR1({100, 200}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); @@ -530,7 +550,7 @@ TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) { b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } @@ -550,7 +570,7 @@ TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) { b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } @@ -585,7 +605,7 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { shape, operand_instruction, padding_value_instruction, padding_config)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); @@ -612,7 +632,7 @@ TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) { shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -656,7 +676,7 @@ TEST_P(HloEvaluatorBf16Test, NegativePadding2D) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = absl::make_unique>(1, 5); @@ -701,7 +721,7 @@ TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); @@ -740,7 +760,7 @@ TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // clang-format off auto expected_array = Array2D({ @@ -786,7 +806,7 @@ TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR1({22.f, 28.f}); @@ -830,7 +850,7 @@ TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected_array = Array2D({ {22.f, 28.f}, @@ -872,7 +892,8 @@ TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); + float expected_1 = 0; for (float i = 1.0f; i < 7.0f; ++i) { expected_1 += i * i + i; @@ -928,7 +949,7 @@ TEST_P(HloEvaluatorBf16Test, SimpleConv1D) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array3D 
expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); @@ -983,7 +1004,7 @@ TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -1067,7 +1088,7 @@ TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -1145,7 +1166,7 @@ TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -1205,7 +1226,7 @@ TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1269,7 +1290,7 @@ TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) { /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1341,7 +1362,7 @@ TEST_P(HloEvaluatorBf16Test, /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1413,7 +1434,7 @@ TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) { DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); Array4D expected_array(1, 1, 1, 8); expected_array.FillWithYX( @@ -1422,6 +1443,1015 @@ TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +// Initialization of data sets for FFT tests: + +void HloEvaluatorTest::InitializeFftData() { + // clang-format off + fft_c64x2x4x8_ = LiteralUtil::CreateR3({ + {{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, + {4.0, 0.0}, {5.0, 0.0}, {6.0, 0.0}, {7.0, 0.0}}, + {{0.0, 0.0}, {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0}, + {0.0, 4.0}, {0.0, 5.0}, {0.0, 6.0}, {0.0, 7.0}}, + {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0}, + {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}}, + {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0}, + {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}}, + {{{-4.0, 0.0}, {-3.0, 0.0}, {-2.0, 0.0}, {-1.0, 0.0}, + {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}}, + {{0.0, -4.0}, {0.0, -3.0}, {0.0, -2.0}, {0.0, -1.0}, + {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0}, {0.0, 4.0}}, + {{3.5, 3.5}, {-1.707107, -0.707107}, {-1.0, -0.0}, {-0.707107, 0.292893}, + {-0.5, 0.5}, {-0.292893, 0.707107}, {0.0, 1.0}, {0.707107, 1.707107}}, + {{3.5, 3.5}, {1.707107, 
0.707107}, {1.0, 0.0}, {0.707107, -0.292893}, + {0.5, -0.5}, {0.292893, -0.707107}, {-0.0, -1.0}, {-0.707107, -1.707107}}} + }); + fft_c64x2x4x8_1d_ = LiteralUtil::CreateR3({ + {{{28.0, 0.0}, {-4.0, 9.656854}, {-4.0, 4.0}, {-4.0, 1.656854}, + {-4.0, 0.0}, {-4.0, -1.656854}, {-4.0, -4.0}, {-4.0, -9.656854}}, + {{0.0, 28.0}, {-9.656854, -4.0}, {-4.0, -4.0}, {-1.656854, -4.0}, + {0.0, -4.0}, {1.656854, -4.0}, {4.0, -4.0}, {9.656854, -4.0}}, + {{28.0, 28.0}, {5.656854, 13.656854}, {0.0, 8.0}, {-2.343146, 5.656854}, + {-4.0, 4.0}, {-5.656854, 2.343146}, {-8.0, -0.0}, {-13.656854, -5.656854}}, // NOLINT + {{28.0, 28.0}, {-5.656854, -13.656854}, {-0.0, -8.0}, {2.343146, -5.656854}, // NOLINT + {4.0, -4.0}, {5.656854, -2.343146}, {8.0, 0.0}, {13.656854, 5.656854}}}, + {{{0.0, 0.0}, {-5.0, 12.071068}, {-4.0, 4.0}, {-5.0, 2.071068}, + {-4.0, 0.0}, {-5.0, -2.071068}, {-4.0, -4.0}, {-5.0, -12.071068}}, + {{0.0, 0.0}, {-12.071068, -5.0}, {-4.0, -4.0}, {-2.071068, -5.0}, + {0.0, -4.0}, {2.071068, -5.0}, {4.0, -4.0}, {12.071068, -5.0}}, + {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0}, + {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}}, + {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0}, + {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}} + }); + fft_c64x2x4x8_2d_ = LiteralUtil::CreateR3({ + {{{84.0, 84.0}, {-13.656854, 5.656854}, {-8.0, 0.0}, {-5.656854, -2.343146}, + {-4.0, -4.0}, {-2.343146, -5.656854}, {0.0, -8.0}, {5.656854, -13.656854}}, // NOLINT + {{0.0, 0.0}, {0.0, -0.0}, {0.0, 0.0}, {0.0, 0.0}, + {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{28.0, -28.0}, {16.970562, 40.970562}, {0.0, 24.0}, {-7.029438, 16.970562}, // NOLINT + {-12.0, 12.0}, {-16.970562, 7.029438}, {-24.0, 0.0}, {-40.970562, -16.970562}}, // NOLINT + {{0.0, -56.0}, {-19.313708, -8.0}, {-8.0, -8.0}, {-3.313708, -8.0}, + {0.0, -8.0}, {3.313708, -8.0}, {8.0, -8.0}, {19.313708, -8.0}}}, + {{{7.0, 7.0}, {-10.071068, 14.071068}, {-1.0, 7.0}, {-0.071068, 4.071068}, + {3.0, 3.0}, {4.071068, -0.071068}, {7.0, -1.0}, {14.071068, -10.071068}}, + {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136}, + {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}}, + {{-7.0, 7.0}, {2.071068, 22.071068}, {-3.0, 11.0}, {-3.928932, 8.071068}, + {-3.0, 3.0}, {-4.071068, -0.071068}, {-3.0, -5.0}, {-10.071068, -14.071068}}, // NOLINT + {{0.0, -14.0}, {0.0, -12.0}, {0.0, -10.0}, {0.0, -8.0}, + {0.0, -6.0}, {0.0, -4.0}, {0.0, -2.0}, {0.0, 0.0}}} + }); + fft_c64x2x4x8_3d_ = LiteralUtil::CreateR3({ + {{{91.0, 91.0}, {-23.727922, 19.727922}, {-9.0, 7.0}, {-5.727922, 1.727922}, + {-1.0, -1.0}, {1.727922, -5.727922}, {7.0, -9}, {19.727922, -23.727922}}, + {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136}, + {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}}, + {{21.0, -21.0}, {19.041630, 63.041630}, {-3.0, 35.0}, {-10.958370, 25.041630}, // NOLINT + {-15.0, 15.0}, {-21.041630, 6.958370}, {-27.0, -5.0}, {-51.041630, -31.041630}}, // NOLINT + {{0.0, -70.0}, {-19.313708, -20.0}, {-8.0, -18.0}, {-3.313708, -16.0}, + {0.0, -14.0}, {3.313708, -12.0}, {8.0, -10.0}, {19.313708, -8.0}}}, + {{{77.0, 77.0}, {-3.585786, -8.414214}, {-7.0, -7.0}, {-5.585786, -6.414214}, // NOLINT + {-7.0, -7.0}, {-6.414214, -5.585786}, {-7.0, -7.0}, {-8.414214, -3.585786}}, // NOLINT + {{0.0, 0.0}, {12.0, -24.142136}, {12.0, -8.0}, {16.0, -4.142136}, + {16.0, 0.0}, {20.0, 4.142136}, {20.0, 8.0}, {24.0, 24.142136}}, + {{35.0, -35.0}, {14.899494, 18.899494}, {3.0, 13.0}, {-3.100506, 8.899494}, + {-9.0, 9.0}, {-12.899494, 
7.100506}, {-21.0, 5.0}, {-30.899494, -2.899494}}, // NOLINT + {{0.0, -42.0}, {-19.313708, 4.0}, {-8.0, 2.0}, {-3.313708, 0.0}, + {0.0, -2.0}, {3.313708, -4.0}, {8.0, -6.0}, {19.313708, -8.0}}} + }); + // clang-format on +} + +// Simple FFT tests: + +TEST_F(HloEvaluatorTest, 1D_FFT_4_on_c64x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[4] parameter(0) + ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4} +} +)"; + auto input = LiteralUtil::CreateR1( + {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}}); + auto expected = LiteralUtil::CreateR1( + {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IFFT_4_on_c64x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[4] parameter(0) + ROOT ifft = c64[4] fft(operand), fft_type=IFFT, fft_length={4} +} +)"; + auto input = LiteralUtil::CreateR1( + {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}}); + auto expected = LiteralUtil::CreateR1( + {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_RFFT_4_on_f32x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[4] parameter(0) + ROOT rfft = c64[3] fft(operand), fft_type=RFFT, fft_length={4} +} +)"; + auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto expected = + LiteralUtil::CreateR1({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IRFFT_4_on_c64x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3] parameter(0) + ROOT irfft = f32[4] fft(operand), fft_type=IRFFT, fft_length={4} +} +)"; + auto input = + LiteralUtil::CreateR1({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}}); + auto expected = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// 1D FFT tests: + +TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IFFT_8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + 
operand = c64[2, 4, 8] parameter(0) + ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_1d_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_RFFT_8_on_f32x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[8] parameter(0) + ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={8} +} +)"; + auto input = + LiteralUtil::CreateR1({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1}); + auto expected = LiteralUtil::CreateR1({{39.6, 0.0}, + {-3.6, 8.691169}, + {-3.6, 3.6}, + {-3.6, 1.491169}, + {-3.6, 0.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IRFFT_8_on_c64x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[5] parameter(0) + ROOT irfft = f32[8] fft(operand), fft_type=IRFFT, fft_length={8} +} +)"; + auto input = LiteralUtil::CreateR1({{39.6, 0.0}, + {-3.6, 8.691169}, + {-3.6, 3.6}, + {-3.6, 1.491169}, + {-3.6, 0.0}}); + auto expected = + LiteralUtil::CreateR1({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_RFFT_9_on_f32x9) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[9] parameter(0) + ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={9} +} +)"; + auto input = LiteralUtil::CreateR1( + {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9}); + auto expected = LiteralUtil::CreateR1({{49.5, 0.0}, + {-3.360560, 11.705792}, + {-3.893717, 5.712929}, + {-4.5, 3.117691}, + {-4.895723, 1.021942}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IRFFT_9_on_c64x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[5] parameter(0) + ROOT irfft = f32[9] fft(operand), fft_type=IRFFT, fft_length={9} +} +)"; + auto input = LiteralUtil::CreateR1({{49.5, 0.0}, + {-3.360560, 11.705792}, + {-3.893717, 5.712929}, + {-4.5, 3.117691}, + {-4.895723, 1.021942}}); + auto expected = LiteralUtil::CreateR1( + {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// 2D FFT tests: + +TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={4, 8} +} +)"; + 
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_IFFT_4x8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={4, 8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_2d_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_RFFT_3x8_on_f32x3x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 8] parameter(0) + ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 8} +} +)"; + auto input = + LiteralUtil::CreateR2({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1}, + {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8}, + {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}}); + auto expected = LiteralUtil::CreateR2({{{118.8, 0.0}, + {-4.4, 10.622540}, + {-4.4, 4.4}, + {-4.4, 1.822540}, + {-4.4, 0.0}}, + {{0.0, 0.0}, + {-19.926162, 0.797280}, + {-10.128203, -3.728203}, + {-6.069756, -5.602720}, + {-3.2, -6.928203}}, + {{0.0, 0.0}, + {13.526162, 14.653687}, + {3.728203, 10.128203}, + {-0.330244, 8.253687}, + {-3.2, 6.928203}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_IRFFT_3x8_on_c64x3x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 5] parameter(0) + ROOT irfft = f32[3, 8] fft(operand), fft_type=IRFFT, fft_length={3, 8} +} +)"; + auto input = LiteralUtil::CreateR2({{{118.8, 0.0}, + {-4.4, 10.622540}, + {-4.4, 4.4}, + {-4.4, 1.822540}, + {-4.4, 0.0}}, + {{0.0, 0.0}, + {-19.926162, 0.797280}, + {-10.128203, -3.728203}, + {-6.069756, -5.602720}, + {-3.2, -6.928203}}, + {{0.0, 0.0}, + {13.526162, 14.653687}, + {3.728203, 10.128203}, + {-0.330244, 8.253687}, + {-3.2, 6.928203}}}); + auto expected = + LiteralUtil::CreateR2({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1}, + {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8}, + {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_RFFT_3x9_on_f32x3x9) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 9] parameter(0) + ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 9} +} +)"; + auto input = LiteralUtil::CreateR2( + {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1}, + {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9}, + {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}}); + auto expected = LiteralUtil::CreateR2({{{148.5, 0.0}, + {-4.95, 13.600013}, + {-4.95, 5.899180}, + {-4.95, 2.857884}, + {-4.95, 0.872819}}, + {{0.0, 0.0}, + {-25.014467, 2.096690}, + {-12.888800, -3.503916}, + {-8.1, 
-5.715768}, + {-4.974333, -7.159452}}, + {{0.0, 0.0}, + {17.814467, 17.685147}, + {5.688800, 12.084542}, + {0.9, 9.872690}, + {-2.225667, 8.429006}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_IRFFT_3x9_on_c64x3x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 5] parameter(0) + ROOT irfft = f32[3, 9] fft(operand), fft_type=IRFFT, fft_length={3, 9} +} +)"; + auto input = LiteralUtil::CreateR2({{{148.5, 0.0}, + {-4.95, 13.600013}, + {-4.95, 5.899180}, + {-4.95, 2.857884}, + {-4.95, 0.872819}}, + {{0.0, 0.0}, + {-25.014467, 2.096690}, + {-12.888800, -3.503916}, + {-8.1, -5.715768}, + {-4.974333, -7.159452}}, + {{0.0, 0.0}, + {17.814467, 17.685147}, + {5.688800, 12.084542}, + {0.9, 9.872690}, + {-2.225667, 8.429006}}}); + auto expected = LiteralUtil::CreateR2( + {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1}, + {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9}, + {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// 3D FFT tests: + +TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={2, 4, 8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_IFFT_2x4x8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={2, 4, 8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_3d_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_f32x3x3x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 3, 4] parameter(0) + ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}}, + {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}}, + {{-1.8, -2.7, -3.6, -4.5}, + {-5.4, -6.3, -7.2, -8.1}, + {1.9, 2.9, 3.9, 4.9}}}); + auto expected = LiteralUtil::CreateR3( + {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}}, + {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}}, + {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}}, + {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}}, + {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}}, + {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}}, + {{{29.5, 81.579593}, {-5.190897, 
-1.390897}, {-1.9, -3.290897}}, + {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}}, + {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_c64x3x3x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 3] parameter(0) + ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}}, + {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}}, + {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}}, + {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}}, + {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}}, + {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}}, + {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}}, + {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}}, + {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}}); + auto expected = LiteralUtil::CreateR3( + {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}}, + {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}}, + {{-1.8, -2.7, -3.6, -4.5}, + {-5.4, -6.3, -7.2, -8.1}, + {1.9, 2.9, 3.9, 4.9}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x5_on_f32x3x3x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 3, 5] parameter(0) + ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 5} +} +)"; + auto input = LiteralUtil::CreateR3({{{1.8, 2.7, 3.6, 4.5, 5.4}, + {8.1, 7.2, 6.3, 5.4, 4.5}, + {1.1, 2.2, 3.3, 4.4, 5.5}}, + {{5.4, 6.3, 7.2, 8.1, 9.0}, + {4.5, 3.6, 2.7, 1.8, 0.9}, + {5.5, 6.6, 7.7, 8.8, 9.9}}, + {{-1.8, -2.7, -3.6, -4.5, -5.4}, + {-5.4, -6.3, -7.2, -8.1, -9.0}, + {1.9, 2.9, 3.9, 4.9, 5.9}}}); + auto expected = LiteralUtil::CreateR3( + {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}}, + {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}}, + {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}}, + {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}}, + {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}}, + {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}}, + {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}}, + {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}}, + {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x5_on_c64x3x3x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 3] parameter(0) + ROOT irfft = f32[3, 3, 5] 
fft(operand), fft_type=IRFFT, fft_length={3, 3, 5} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}}, + {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}}, + {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}}, + {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}}, + {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}}, + {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}}, + {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}}, + {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}}, + {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}}); + auto expected = LiteralUtil::CreateR3({{{1.8, 2.7, 3.6, 4.5, 5.4}, + {8.1, 7.2, 6.3, 5.4, 4.5}, + {1.1, 2.2, 3.3, 4.4, 5.5}}, + {{5.4, 6.3, 7.2, 8.1, 9.0}, + {4.5, 3.6, 2.7, 1.8, 0.9}, + {5.5, 6.6, 7.7, 8.8, 9.9}}, + {{-1.8, -2.7, -3.6, -4.5, -5.4}, + {-5.4, -6.3, -7.2, -8.1, -9.0}, + {1.9, 2.9, 3.9, 4.9, 5.9}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// FFT tests with non-default data layout: + +TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8_with_layout) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8]{0, 2, 1} parameter(0) + ROOT fft = c64[2, 4, 8]{1, 2, 0} fft(operand), fft_type=FFT, fft_length={8} +} +)"; + auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({0, 2, 1})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8_with_layout) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8]{2, 0, 1} parameter(0) + ROOT fft = c64[2, 4, 8]{1, 0, 2} fft(operand), fft_type=FFT, fft_length={4, 8} +} +)"; + auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({2, 0, 1})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8_with_layout) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8]{1, 2, 0} parameter(0) + ROOT fft = + c64[2, 4, 8]{0, 2, 1} fft(operand), fft_type=FFT, fft_length={2, 4, 8} +} +)"; + auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({1, 2, 0})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_)); +} + +// FFT tests with unusual parameters: + +// Zero-length transform. 
+TEST_F(HloEvaluatorTest, 1D_FFT_0_on_c64x1x1x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 1] parameter(0) + ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={0} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}}}); + auto expected = LiteralUtil::CreateR4({{{{{0.0, 0.0}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Zero-length axis. +TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x0) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 0] parameter(0) + ROOT fft = c64[1, 1, 1, 0] fft(operand), fft_type=FFT, fft_length={1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto input, + LiteralUtil::CreateR4({{{{}}}}).Reshape({1, 1, 1, 0})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// Some/all dimensions have length 1. +TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 1] parameter(0) + ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// Zero-length transform. +TEST_F(HloEvaluatorTest, 3D_FFT_1x0x1_on_c64x1x1x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 1] parameter(0) + ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 0, 1} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}}}); + auto expected = LiteralUtil::CreateR4({{{{{0.0, 0.0}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Zero-length axis. +TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x0x1x0x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[0, 1, 0, 1] parameter(0) + ROOT fft = c64[0, 1, 0, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto input, + LiteralUtil::CreateR4({{{{}}}}).Reshape({0, 1, 0, 1})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// Some/all dimensions have length 1. 
+TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x1x1x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 1] parameter(0) + ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// Some/all dimensions have length 1. +TEST_F(HloEvaluatorTest, 3D_FFT_3x1x1_on_c64x1x3x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 3, 1, 1] parameter(0) + ROOT fft = c64[1, 3, 1, 1] fft(operand), fft_type=FFT, fft_length={3, 1, 1} +} +)"; + auto input = LiteralUtil::CreateR4( + {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}}); + auto expected = + LiteralUtil::CreateR4({{{{{42.24, 24.42}}}, + {{{84.5367, 97.5818}}}, + {{{-0.0566792, -48.7418}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Some/all dimensions have length 1. +TEST_F(HloEvaluatorTest, 3D_IFFT_3x1x1_on_c64x1x3x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 3, 1, 1] parameter(0) + ROOT ifft = c64[1, 3, 1, 1] fft(operand), fft_type=IFFT, fft_length={3, 1, 1} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}, + {{{84.5367, 97.5818}}}, + {{{-0.0566792, -48.7418}}}}}); + auto expected = LiteralUtil::CreateR4( + {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Odd transform length. +TEST_F(HloEvaluatorTest, 1D_FFT_5_on_c64x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[5] parameter(0) + ROOT fft = c64[5] fft(operand), fft_type=FFT, fft_length={5} +} +)"; + auto input = LiteralUtil::CreateR1( + {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}}); + auto expected = LiteralUtil::CreateR1({{15.0, 15.0}, + {0.940955, 5.94095}, + {-1.6877, 3.3123}, + {-3.3123, 1.6877}, + {-5.94095, -0.940955}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Odd transform length. 
+TEST_F(HloEvaluatorTest, 1D_IFFT_5_on_c64x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[5] parameter(0) + ROOT ifft = c64[5] fft(operand), fft_type=IFFT, fft_length={5} +} +)"; + auto input = LiteralUtil::CreateR1({{15.0, 15.0}, + {0.940955, 5.94095}, + {-1.6877, 3.3123}, + {-3.3123, 1.6877}, + {-5.94095, -0.940955}}); + auto expected = LiteralUtil::CreateR1( + {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// All input values are zero. +TEST_F(HloEvaluatorTest, 1D_FFT_4_on_zero_c64x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[4] parameter(0) + ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4} +} +)"; + auto input = LiteralUtil::CreateR1( + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// All input values are zero. +TEST_F(HloEvaluatorTest, 3D_FFT_3x3x4_on_zero_c64x3x3x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 4] parameter(0) + ROOT fft = c64[3, 3, 4] fft(operand), fft_type=FFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// All input values are zero. +TEST_F(HloEvaluatorTest, 3D_IFFT_3x3x4_on_zero_c64x3x3x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 4] parameter(0) + ROOT ifft = c64[3, 3, 4] fft(operand), fft_type=IFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// All input values are zero. 
+TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_zero_f32x3x3x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 3, 4] parameter(0) + ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}}); + auto expected = LiteralUtil::CreateR3( + {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// All input values are zero. +TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_zero_c64x3x3x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 3] parameter(0) + ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}}); + auto expected = LiteralUtil::CreateR3( + {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Input values, for which IRFFT discards non-zero imaginary parts. 
+TEST_F(HloEvaluatorTest, 2D_IRFFT_3x4_on_c64x3x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3] parameter(0) + ROOT irfft = f32[3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 4} +} +)"; + auto input = + LiteralUtil::CreateR2({{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}}, + {{3.0, 0.0}, {4.0, 0.0}, {5.0, 0.0}}, + {{6.0, 0.0}, {7.0, 0.0}, {8.0, 0.0}}}); + auto expected = + LiteralUtil::CreateR2({{4.0, -0.5, 0.0, -0.5}, + {-1.5, 0.433013, 0.0, -0.433013}, + {-1.5, -0.433013, 0.0, 0.433013}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + class HloEvaluatorPreciseReduceTest : public HloTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because @@ -1531,7 +2561,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceAdd) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR1({6, 18}); @@ -1583,7 +2613,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{6, 7}}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); @@ -1635,7 +2665,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{11}}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); @@ -1692,7 +2722,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); @@ -1753,7 +2783,7 @@ TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); std::vector output_dims = {4, 3, 3, 3, 4, 4}; Literal result_literal = @@ -1785,7 +2815,7 @@ TEST_P(HloEvaluatorBf16Test, StridedSlice) { /*strides=*/{2, 3})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {3}, @@ -1821,7 +2851,7 @@ TEST_P(HloEvaluatorBf16Test, DynamicSlice) { HloInstruction::CreateDynamicSlice(shape, operand, {zero, one}, {2, 3})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, @@ -1859,7 +2889,7 @@ TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) { HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, @@ -1898,7 +2928,7 @@ TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) { shape, operand, update, {zero, one})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, @@ -1934,7 
+2964,7 @@ TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, @@ -1973,7 +3003,7 @@ TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) { m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); auto result_inner_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -2011,7 +3041,7 @@ TEST_P(HloEvaluatorBf16Test, Reverse) { b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); m_->AddEntryComputation(b.Build()); - Literal result = Evaluate(); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate()); // clang-format off auto expected = LiteralUtil::CreateR4FromArray4D({ @@ -2051,11 +3081,12 @@ TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) { HloEvaluator evaluator; Literal param0_literal = LiteralUtil::CreateR1({1, 2, 3, 4}); Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); - auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, ¶m0_literal}, {square, &square_literal}}); - TF_ASSERT_OK(result.status()); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator.EvaluateWithSubstitutions( + add, {{param0, ¶m0_literal}, {square, &square_literal}})); EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result)); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -2076,11 +3107,11 @@ TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) { // Evaluate add with square = {10, 20, 30, 40}. HloEvaluator evaluator; Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); - auto result = - evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}); - TF_ASSERT_OK(result.status()); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}})); EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -2102,9 +3133,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - Evaluate({&operand, &start_indices}))); + LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -2126,9 +3157,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - Evaluate({&operand, &start_indices}))); + LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -2150,10 +3181,11 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, 
&start_indices})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - Evaluate({&operand, &start_indices}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -2177,9 +3209,9 @@ ENTRY main { {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), result)); } TEST_F(HloEvaluatorTest, @@ -2204,9 +3236,9 @@ ENTRY main { {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -2228,8 +3260,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({1, 1}); - EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -2251,9 +3284,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{8}}, {{5}}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR3({{{8}}, {{5}}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2274,8 +3307,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); - EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), result)); } TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2298,9 +3332,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - Evaluate({&operand, &start_indices}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices})); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{0, 1}, {2, 1}}), result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { @@ -2329,9 +3363,11 @@ ENTRY main { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = 
LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { @@ -2361,9 +3397,11 @@ ENTRY main { Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { @@ -2393,9 +3431,11 @@ ENTRY main { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { @@ -2425,9 +3465,11 @@ ENTRY main { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) { @@ -2458,10 +3500,12 @@ ENTRY main { Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); Literal updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Near( LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), - Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); + result, ErrorSpec{0.1, 0.01})); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { @@ -2491,9 +3535,11 @@ ENTRY main { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), - Evaluate({&operand, &scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { @@ -2524,9 +3570,11 @@ ENTRY main { Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); EXPECT_TRUE(LiteralTestUtil::Equal( LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), - Evaluate({&operand, 
&scatter_indices, &updates}))); + result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { @@ -2561,8 +3609,9 @@ ENTRY main { LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // {{-40, 40}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, @@ -2598,8 +3647,9 @@ ENTRY main { LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { @@ -2630,8 +3680,9 @@ ENTRY main { Literal updates = LiteralUtil::CreateR2({{10}}); Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { @@ -2662,8 +3713,9 @@ ENTRY main { Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { @@ -2691,8 +3743,9 @@ ENTRY main { Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{}, {}}); - EXPECT_TRUE(LiteralTestUtil::Equal( - operand, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(operand, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { @@ -2724,8 +3777,9 @@ ENTRY main { LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); Literal expected = LiteralUtil::CreateR1({10, 61, 32}); - EXPECT_TRUE(LiteralTestUtil::Equal( - expected, Evaluate({&operand, &scatter_indices, &updates}))); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + Evaluate({&operand, &scatter_indices, &updates})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { @@ -2848,8 +3902,16 @@ TEST_F(HloEvaluatorTest, DoesCompareBF16) { {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); auto expected = LiteralUtil::CreateR2({{false, true, true}, {false, true, true}}); - TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), - std::move(rhs)); + + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); + 
b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2,
+                                                ComparisonDirection::kGe));
+  m_->AddEntryComputation(b.Build());
+
+  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
@@ -2873,7 +3935,72 @@ ENTRY main {
   Literal arg = LiteralUtil::CreateR1<bfloat16>(
       {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
   Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
-  EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
+  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&arg}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
+TEST_F(HloEvaluatorTest, MixedPrecisionReduction) {
+  const string hlo_text = R"(
+HloModule MixedPrecisionReduction
+
+add_f32 {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY main {
+  arg0 = f32[4]{0} parameter(0)
+  init = f32[] constant(0)
+  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_f32
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
+
+  Literal arg = LiteralUtil::CreateR1<float>({1.0f, 3.0f, -2.0f, 42.0f});
+  Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
+  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&arg}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
+TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) {
+  // Infeed triggers an unimplemented error inside HandleCall; verify that the
+  // evaluator reports this as an error status instead of crashing.
+  const string hlo_text = R"(
+HloModule DontFailOnCall
+
+call {
+  token0 = token[] after-all()
+  ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
+}
+
+ENTRY main {
+  ROOT result = ((u32[3]{0}, pred[]), token[]) call(), to_apply=call
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
+  auto statusor = Evaluate();
+  EXPECT_FALSE(statusor.status().ok());
+}
+
+TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) {
+  // Infeed triggers an unimplemented error inside HandleFusion; verify that
+  // the evaluator reports this as an error status instead of crashing.
+ const string hlo_text = R"( +HloModule DontFailOnFusion + +fused_computation { + token0 = token[] after-all() + ROOT infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) +} + +ENTRY main { + ROOT result = ((u32[3]{0}, pred[]), token[]) fusion(), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto statusor = Evaluate(); + EXPECT_FALSE(statusor.status().ok()); } TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) { @@ -2891,7 +4018,7 @@ ENTRY main { Literal arg = LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, LayoutUtil::MakeLayout({0, 1, 2})); - Literal actual = Evaluate({&arg}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&arg})); EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } @@ -2913,7 +4040,7 @@ ENTRY main { } TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); - Literal actual = Evaluate({&args[0]}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]})); if (use_bfloat16_) { EXPECT_TRUE( absl::c_equal(args[0].data(), actual.data())); @@ -2939,7 +4066,8 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); - std::vector actual = Evaluate({}).DecomposeTuple(); + TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({})); + std::vector actual = literal.DecomposeTuple(); ASSERT_EQ(actual.size(), 3); uint32 pow30 = uint32{1} << 30; @@ -2979,7 +4107,7 @@ ENTRY main { Literal size_arg = LiteralUtil::CreateR0(3); Literal data_arg = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal actual = Evaluate({&size_arg, &data_arg}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg})); EXPECT_EQ(actual.GetFirstElement(), static_cast(3)); } @@ -3002,13 +4130,13 @@ ENTRY main { .status() .error_message(), "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " - "but arg was s32[2]."); + "but arg was s32[2]{0}."); EXPECT_EQ(HloEvaluator() .Evaluate(*m_->entry_computation(), {&input_wrong_shape}) .status() .error_message(), "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " - "but arg was s32[2]."); + "but arg was s32[2]{0}."); } // Check that we get a useful error if we pass too many or too few inputs. 
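Editorial note, not part of the patch: nearly every hunk in this test file makes the same mechanical change, because Evaluate() now hands back a StatusOr<Literal> instead of a bare Literal, and each call site unwraps it with TF_ASSERT_OK_AND_ASSIGN. The sketch below is illustrative only; it spells out what the macro replaces, using only the StatusOr members (ok(), status(), ValueOrDie()) already visible elsewhere in this diff. The two "after" forms are alternatives, not one function body.

// Before this change, Evaluate() returned the Literal directly:
Literal result = Evaluate();

// After this change, the result must be unwrapped. TF_ASSERT_OK_AND_ASSIGN
// aborts the current test with the underlying error message if evaluation
// failed, and otherwise moves the Literal out of the StatusOr:
TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

// Manual equivalent using the StatusOr API directly:
StatusOr<Literal> status_or_result = Evaluate();
TF_ASSERT_OK(status_or_result.status());
Literal unwrapped = std::move(status_or_result).ValueOrDie();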
@@ -3051,7 +4179,8 @@ TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) { TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); - Literal actual = Evaluate({&args[0]}); + + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]})); EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); } @@ -3072,7 +4201,7 @@ TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) { TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); - Literal actual = Evaluate({&args[0]}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&args[0]})); EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); } @@ -3094,7 +4223,7 @@ TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) { TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); - Literal actual_tuple = Evaluate({&args[0]}); + TF_ASSERT_OK_AND_ASSIGN(Literal actual_tuple, Evaluate({&args[0]})); std::vector actual_literals = actual_tuple.DecomposeTuple(); EXPECT_TRUE( absl::c_equal(args[0].data(), actual_literals[0].data())); @@ -3212,5 +4341,20 @@ TEST_F(HloEvaluatorTest, IsFiniteBf16) { ::testing::ElementsAre(false, true, false, true, false, false)); } +// Check that evaluating `f32[, 0] iota` doesn't oom (it's an empty +// array!). +TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) { + constexpr absl::string_view hlo_text = R"( + HloModule test + ENTRY t { + ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_THAT(actual_literal.data(), ::testing::IsEmpty()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 2d8a578985e..c3b5838cf0a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include #include #include @@ -67,8 +68,8 @@ T ToArithmeticSafeType(T t) { // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated -// Handle* method of a TypedVisitor. There are however a few notable exceptions -// to this rule, notably: +// Handle* method of a TypedVisitor. There are however a few exceptions to this +// rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. // - HandleImag and HandleReal: where the resulting literal type is always float @@ -80,7 +81,7 @@ T ToArithmeticSafeType(T t) { // - ReturnT: The type of input and output of each operation. // - ElementwiseT: The type in which internal computation are done. // -// This a logically a private part of HloEvaluator. It lives in this header +// This is logically a private part of HloEvaluator. It lives in this header // file rather than in hlo_evaluator.cc because we use extern templates and a // bunch of independent cc files to speed up compiling the many instantiations // of this class. 
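The note about extern templates refers to the usual split for heavyweight class templates: the header declares the instantiations extern, and each instantiation lives in its own .cc file so the expensive compilations happen once each and in parallel. A simplified sketch of that split; the instantiation list and file name here are illustrative assumptions, and the real class is HloEvaluatorTypedVisitor<ReturnT, ElementwiseT> with ElementwiseT defaulting to ReturnT:

```cpp
// hlo_evaluator_typed_visitor.h, after the class template definition:
// including files see only these declarations and never instantiate the class
// themselves.
extern template class HloEvaluatorTypedVisitor<float>;
extern template class HloEvaluatorTypedVisitor<double>;

// hlo_evaluator_typed_visitor_float.cc: exactly one explicit instantiation,
// so this instantiation is compiled once, in this translation unit.
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
template class HloEvaluatorTypedVisitor<float>;
```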
@@ -179,7 +180,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(abs->operand(0)); TF_ASSIGN_OR_RETURN( parent_->evaluated_[abs], - (HloEvaluator::ElementWiseUnaryOpImpl( + (HloEvaluator::ElementWiseUnaryOpImpl( abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, operand_literal))); @@ -937,7 +939,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - if (std::isnan(low) || std::isnan(high)) { + if (std::isnan(low) || std::isnan(high) || std::isnan(value)) { return static_cast(NAN); } return static_cast( @@ -1681,178 +1683,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return UnsupportedTypeError(sort); } - Status HandleReduce(HloInstruction* hlo) override { - HloReduceInstruction* reduce = Cast(hlo); - int64 num_args = reduce->inputs().size(); - bool has_tuple_output = reduce->shape().IsTuple(); - absl::Span dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - - absl::InlinedVector operand_shapes; - for (const HloInstruction* operand : reduce->operands()) { - operand_shapes.push_back(&operand->shape()); - } - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReduceShape( - operand_shapes, - /*dimensions_to_reduce=*/dimensions, - /*to_apply=*/function->ComputeProgramShape())); - TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - absl::InlinedVector arg_literals(num_args); - absl::InlinedVector init_literals(num_args); - for (int64 i = 0; i < num_args; ++i) { - arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]); - VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString(); - init_literals[i] = - &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]); - VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape())); - } - - // All args and results have the same dimensions, so pick an arbitrary one. - const Shape& arg_shape = arg_literals[0]->shape(); - const Shape& result_shape = reduce->shape().IsTuple() - ? reduce->shape().tuple_shapes(0) - : reduce->shape(); - const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); - std::vector arg_dim_steps(arg_dimensions.size()); - std::vector arg_dim_counts(arg_dimensions.size()); - for (const int64 dim : dimensions) { - arg_dim_steps[dim] = 1; - arg_dim_counts[dim] = arg_dimensions[dim]; - } - - // Map each dimension in the result to a dimension in arg that isn't - // being reduced. - std::vector result_to_arg_index; - for (int64 i = 0; i < arg_dimensions.size(); ++i) { - if (arg_dim_steps[i] == 0) { - result_to_arg_index.push_back(i); - } - } - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - absl::InlinedVector results(num_args); - for (int64 i = 0; i < num_args; ++i) { - results[i] = Literal(result_shape); - } - - Status eval_status; - // For each resulting dimension, calculate and assign computed values. - // This is really wasteful when num_args > 1, since we re-run the - // reduction num_args time. The alternative is to teach Populate() about - // tuples, which we should probably do. 
- absl::InlinedVector init_scalars(num_args); - for (int i = 0; i < num_args; ++i) { - init_scalars[i] = init_literals[i]->Get({}); - } - - for (int64 input = 0; input < num_args; ++input) { - TF_RETURN_IF_ERROR(results[input].Populate( - [&](absl::Span multi_index) { - if (!eval_status.ok()) { - return init_scalars[input]; - } - absl::InlinedVector result_values(init_scalars.begin(), - init_scalars.end()); - std::vector base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } - - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) && - IsScalarAdd(function)) { - CHECK_EQ(num_args, 1); - double computed_result = 0; - auto func = [&](absl::Span input_index) { - computed_result += - GetAsDouble(*arg_literals[0], input_index); - return true; - }; - ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base, - arg_dim_counts, arg_dim_steps, func); - return static_cast(computed_result); - } - auto func = - [&](absl::Span input_index) -> StatusOr { - absl::InlinedVector arg_values(num_args); - for (int64 i = 0; i < num_args; ++i) { - arg_values[i] = arg_literals[i]->Get(input_index); - } - - // Evaluate computation with specified literal operands. - absl::InlinedVector embedded_operands; - for (ReturnT value : result_values) { - embedded_operands.push_back( - LiteralUtil::CreateR0(value)); - } - for (ReturnT value : arg_values) { - embedded_operands.push_back( - LiteralUtil::CreateR0(value)); - } - absl::InlinedVector embedded_operands_ptrs( - embedded_operands.size()); - std::transform(embedded_operands.begin(), embedded_operands.end(), - embedded_operands_ptrs.begin(), - [](Literal& literal) { return &literal; }); - - TF_ASSIGN_OR_RETURN(Literal computed_result, - embedded_evaluator.Evaluate( - *function, embedded_operands_ptrs)); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - if (!has_tuple_output) { - result_values[0] = computed_result.Get({}); - } else { - for (int64 i = 0; i < num_args; ++i) { - result_values[i] = computed_result.Get( - /*multi_index=*/{}, /*shape_index=*/{i}); - } - } - return true; - }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. 
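One detail of the removed HandleReduce worth keeping in mind is the fast path described a few lines up: when the reduction computation is a plain scalar float addition, it skipped the embedded evaluator entirely and accumulated in a double. A standalone sketch of that precision trick (illustrative only, not the removed code itself):

```cpp
#include <vector>

// Sum float values the way the removed fast path did: accumulate in a double
// and narrow back to float only once at the end.
float ReduceAddF32(const std::vector<float>& values, float init) {
  double acc = init;
  for (float v : values) {
    acc += v;
  }
  return static_cast<float>(acc);
}
```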
- eval_status = ShapeUtil::ForEachIndexWithStatus( - arg_shape, base, arg_dim_counts, arg_dim_steps, func); - return result_values[input]; - })); - } - if (!has_tuple_output) { - parent_->evaluated_[reduce] = std::move(results[0]); - } else { - Literal tuple_result(reduce->shape()); - for (int64 i = 0; i < num_args; ++i) { - TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); - } - parent_->evaluated_[reduce] = std::move(tuple_result); - } - return eval_status; - } - - bool IsScalarAdd(HloComputation* computation) { - HloInstruction* instruction = computation->root_instruction(); - if (instruction->opcode() == HloOpcode::kAdd && - computation->num_parameters() == 2) { - const HloInstruction* lhs = instruction->operand(0); - const HloInstruction* rhs = instruction->operand(1); - return lhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(lhs->shape()) && - rhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; - } - return false; - } - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { auto operand = select_and_scatter->operand(0); auto source = select_and_scatter->operand(1); @@ -2482,6 +2312,37 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleClz(clz); } + // Enable Popcnt only for int32, uint32, int64 and uint64. + template ::value || + std::is_same::value || + std::is_same::value || + std::is_same::value)>::type* = nullptr> + Status HandlePopulationCount(HloInstruction* popcnt) { + return UnsupportedTypeError(popcnt); + } + + template ::value || + std::is_same::value || + std::is_same::value || + std::is_same::value>::type* = nullptr> + Status HandlePopulationCount(HloInstruction* popcnt) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[popcnt], + ElementWiseUnaryOp(popcnt, [](ElementwiseT elem_operand) { + return std::bitset(elem_operand) + .count(); + })); + return Status::OK(); + } + + Status HandlePopulationCount(HloInstruction* popcnt) override { + return HandlePopulationCount(popcnt); + } + template ::value>::type* = nullptr> Status HandleSin(HloInstruction* sin) { @@ -2644,32 +2505,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); - const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); - // Avoid using std::vector since std::vector does not convert to - // absl::Span. - absl::InlinedVector data(iota_size); - // We don't use std::iota for two reasons: - // - // (1) std:iota does not support bfloat16 and float16. - // - // (2) std::iota saturates for floating point types when the value is not - // representable, but the definition of HLO iota is the value as a - // 64-bit integer cast to the native type. - for (int64 i = 0; i < iota_size; ++i) { - // static_cast is required for Eigen::half (F16). 
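The deleted comment pins down the iota semantics that both the old and the new implementation honor: element i is the 64-bit index cast to the element type, which is why std::iota (which increments a running value and saturates for narrow floating-point types) is not used. A standalone sketch of that fill rule, with NativeT standing in for the element type:

```cpp
#include <cstdint>
#include <vector>

// HLO-style iota along one dimension: each element is its 64-bit index cast
// to the element type, never an accumulated running value.
template <typename NativeT>
std::vector<NativeT> HloIota(int64_t iota_size) {
  std::vector<NativeT> data(iota_size);
  for (int64_t i = 0; i < iota_size; ++i) {
    data[i] = static_cast<NativeT>(i);
  }
  return data;
}
```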
- data[i] = static_cast(i); - } - auto result = LiteralUtil::CreateR1(data); - - if (iota->shape().rank() > 1) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[iota], - result.Broadcast(iota->shape(), {iota->iota_dimension()})); - } else { - TF_RET_CHECK(iota->shape().rank() == 1); - parent_->evaluated_[iota] = std::move(result); - } + Literal result(iota->shape()); + ShapeUtil::ForEachIndex(iota->shape(), [&](absl::Span idx) { + result.Set(idx, static_cast(idx[iota->iota_dimension()])); + return true; + }); + parent_->evaluated_[iota] = std::move(result); return Status::OK(); } template < diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 9623edcf5eb..3a1ba773645 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -41,17 +41,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" namespace xla { -namespace hlo_graph_dumper { namespace { using absl::nullopt; @@ -60,9 +61,6 @@ using absl::StrAppend; using absl::StrCat; using absl::StrFormat; using absl::StrJoin; -using tensorflow::Env; -using tensorflow::WriteStringToFile; -using tensorflow::io::JoinPath; // Used to indicate how we should treat a given HLOInstruction in the graph. // should we treat it like normal, hide it, and so on? @@ -119,7 +117,7 @@ class NodeFilter { // We arbitrarily set this as the boundary between "large" and "small" // instructions. bool IsSmall(const HloInstruction* instr) { - if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) || + if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) || ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) { return true; } @@ -258,14 +256,16 @@ optional MatchTrivialComputation(const HloComputation* computation) { // param0), check that the operation being performed is commutative. 
if (root->operand(0) == param1) { CHECK_EQ(root->operand(1), param0); - switch (root->opcode()) { - case HloOpcode::kLe: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLt: - return nullopt; - default: - break; + if (root->opcode() == HloOpcode()) { + switch (root->comparison_direction()) { + case ComparisonDirection::kLe: + case ComparisonDirection::kGe: + case ComparisonDirection::kGt: + case ComparisonDirection::kLt: + return nullopt; + default: + break; + } } } @@ -279,18 +279,22 @@ optional MatchTrivialComputation(const HloComputation* computation) { return "min"; case HloOpcode::kMaximum: return "max"; - case HloOpcode::kLe: - return "less-or-equal"; - case HloOpcode::kGe: - return "greater-or-equal"; - case HloOpcode::kGt: - return "greater-than"; - case HloOpcode::kLt: - return "less-than"; - case HloOpcode::kEq: - return "equal-to"; - case HloOpcode::kNe: - return "not-equal-to"; + case HloOpcode::kCompare: { + switch (root->comparison_direction()) { + case ComparisonDirection::kLe: + return "less-or-equal"; + case ComparisonDirection::kGe: + return "greater-or-equal"; + case ComparisonDirection::kGt: + return "greater-than"; + case ComparisonDirection::kLt: + return "less-than"; + case ComparisonDirection::kEq: + return "equal-to"; + case ComparisonDirection::kNe: + return "not-equal-to"; + } + } default: return nullopt; } @@ -922,29 +926,25 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIota: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: + case HloOpcode::kPopulationCount: case HloOpcode::kOr: case HloOpcode::kXor: case HloOpcode::kPower: @@ -1040,6 +1040,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: + case HloOpcode::kPartitionId: case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kSend: @@ -1256,36 +1257,6 @@ const HloInstruction* HloDotDumper::GetNodeForEdge( return instr; } -class GraphRendererRegistry { - public: - void SetRenderer(std::shared_ptr graph_renderer) { - tensorflow::mutex_lock lock(mu_); - graph_renderer_ = graph_renderer; - } - - std::shared_ptr GetDefaultRenderer() { - tensorflow::mutex_lock lock(mu_); - return graph_renderer_; - } - - static GraphRendererRegistry* Default() { - static GraphRendererRegistry* registry = new GraphRendererRegistry(); - return registry; - } - - private: - tensorflow::mutex mu_; - std::shared_ptr graph_renderer_ GUARDED_BY(mu_); -}; - -} // namespace - -Registrar::Registrar(std::shared_ptr dumper) { - GraphRendererRegistry::Default()->SetRenderer(dumper); -} - -namespace { - // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. 
NodeFilter MakeNodeRadiusAroundFilter( @@ -1448,157 +1419,7 @@ NodeFilter MakeNodeFromToFilter(const HloInstruction* from, }); } -string SaveGraph(const string& graph, - GraphRendererInterface::GraphKind graph_kind, - const string& dest_path) { - static std::atomic output_num(0); - string file_extension; - switch (graph_kind) { - case GraphRendererInterface::DOT_GRAPH: - file_extension = ".dot"; - break; - } - string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, ".")); - auto status = Status::OK(); - auto env = tensorflow::Env::Default(); - if (!env->CreateUniqueFileName(&path, file_extension)) { - status = - Status(tensorflow::error::Code::UNKNOWN, - StrCat("Failed to create temporary file to dump HLO graph: ", - strerror(errno))); - } else { - status = tensorflow::WriteStringToFile(env, path, graph); - } - if (!status.ok()) { - LOG(WARNING) << "Saving HLO graph failed: " << status; - } - return path; -} - -string ExportGraph(const string& graph, - GraphRendererInterface::GraphKind graph_kind, - const DebugOptions& debug_options) { - string path = debug_options.xla_hlo_graph_path(); - if (!path.empty() && !debug_options.xla_hlo_dump_as_html()) { - return SaveGraph(graph, graph_kind, path); - } else { - auto graph_renderer = - GraphRendererRegistry::Default()->GetDefaultRenderer(); - CHECK(graph_renderer != nullptr) - << "No registered renderer for the HLO graph. " - "Use --xla_hlo_graph_path=PATH --xla_hlo_dump_as_html=false to " - "export to local file system"; - return graph_renderer->RenderGraph(graph, graph_kind, debug_options); - } -} - -} // namespace - -string HloComputationToDotGraph(const HloComputation& computation, - const DotGraphOptions& options) { - DebugOptions default_debug_options; - return HloDotDumper(&computation, options.label, - options.debug_options ? 
*options.debug_options - : default_debug_options, - options.show_backend_config, options.profile, - NodeFilter()) - .Dump(); -} - -string DumpGraph(const HloComputation& computation, const string& label, - const DebugOptions& debug_options, - const HloExecutionProfile* hlo_execution_profile, - bool show_backend_config) { - GraphRendererInterface::GraphKind graph_kind; - string graph = - HloDotDumper(&computation, label, debug_options, show_backend_config, - hlo_execution_profile, NodeFilter()) - .Dump(); - graph_kind = GraphRendererInterface::DOT_GRAPH; - - string graph_url = ExportGraph(graph, graph_kind, debug_options); - LOG(INFO) << "computation " << computation.name() << " [" << label - << "]: " << graph_url; - return graph_url; -} - -string DumpNeighborhoodAround( - const HloInstruction& node, int radius, bool show_backend_config, - const absl::flat_hash_set& boundary) { - auto debug_options = node.GetModule()->config().debug_options(); - string label = - StrCat("Neighborhood of ", radius, " nodes around ", node.name()); - NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius, boundary); - string graph = - HloDotDumper(node.parent(), label, debug_options, show_backend_config, - /*profile=*/nullptr, filter) - .Dump(); - return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); -} - -string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, - int64 max_nodes, bool show_backend_config) { - CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!"; - auto debug_options = from.GetModule()->config().debug_options(); - - bool hit_limit = false; - NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit); - string label; - if (!hit_limit) { - label = StrCat("All paths from ", from.name(), " to ", to.name()); - } else { - label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(), - " to ", to.name(), - "

<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN " - "NODES***<br/><br/>
"); - } - string graph = - HloDotDumper(from.parent(), label, debug_options, show_backend_config, - /*profile=*/nullptr, filter) - .Dump(); - return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); -} - -void DumpText(const HloModule& module, const string& label, - const string& directory_path, bool do_prefix) { - Env* env = Env::Default(); - TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); - string prefix = StrCat(env->NowMicros()); - string filename = - do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); - string path = JoinPath(directory_path, filename); - TF_CHECK_OK(WriteStringToFile( - env, path, - module.ToString(HloPrintOptions().set_print_large_constants(true)))); - LOG(INFO) << "dumping module '" << module.name() << "' to " << path; -} - -string MaybeDumpHloModule(const HloModule& module, const string& label, - const HloExecutionProfile* profile) { - const DebugOptions& debug_options = module.config().debug_options(); - VLOG(2) << "MaybeDumpHloModule called on module " << module.name() - << " with generate_hlo_graph regex \"" - << debug_options.xla_generate_hlo_graph() << "\""; - string graph_url; - if (!debug_options.xla_generate_hlo_graph().empty() && - RE2::PartialMatch(module.name(), - debug_options.xla_generate_hlo_graph())) { - graph_url = - DumpGraph(*module.entry_computation(), label, debug_options, profile); - } - if (!debug_options.xla_log_hlo_text().empty() && - RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) { - LOG(INFO) << "HLO for module " << module.name(); - LOG(INFO) << "Label: " << label; - XLA_LOG_LINES(2, module.ToString()); - } - if (!debug_options.xla_generate_hlo_text_to().empty()) { - DumpText(module, label, debug_options.xla_generate_hlo_text_to()); - } - return graph_url; -} - -string WrapDotInHTML(const string& dot) { +string WrapDotInHtml(absl::string_view dot) { static const char html_prefix[] = R"html( @@ -1639,6 +1460,9 @@ string WrapDotInHTML(const string& dot) { var css_data = '' if (results !== null) { css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field. + // CSS inside DOT is URL-escaped, so we must unescape it + // before we can insert it into SVG. + css_data = unescape(css_data); dot_data = data.replace(cssregex, ''); // Remove the stylesheet } @@ -1706,37 +1530,117 @@ string WrapDotInHTML(const string& dot) { )html"; - return html_prefix + dot + html_suffix; + return absl::StrCat(html_prefix, dot, html_suffix); } -string RenderDotAsHTMLFile(const string& dot, - const DebugOptions& debug_options) { - string html = WrapDotInHTML(dot); +tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED); +std::function(absl::string_view)>* url_renderer + GUARDED_BY(url_renderer_mu) = nullptr; - auto env = tensorflow::Env::Default(); - std::vector dirs; - string output_dir = debug_options.xla_hlo_graph_path(); - if (output_dir.empty()) { - env->GetLocalTempDirectories(&dirs); +// Precondition: url_renderer != nullptr. +// +// (We specify this as a precondition rather than checking it in here and +// returning an error because we want to fail quickly when there's no URL +// renderer available, and this function runs only after we've done all the work +// of producing dot for the graph.) 
+StatusOr WrapDotInFormat(absl::string_view dot, + RenderedGraphFormat format) + EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) { + switch (format) { + case RenderedGraphFormat::kUrl: + CHECK(url_renderer != nullptr) + << "Should have checked url_renderer != null before calling."; + return (*url_renderer)(dot); + case RenderedGraphFormat::kHtml: + return WrapDotInHtml(dot); + case RenderedGraphFormat::kDot: + return string(dot); + } +} + +} // namespace + +void RegisterGraphToURLRenderer( + std::function(absl::string_view)> renderer) { + tensorflow::mutex_lock lock(url_renderer_mu); + if (url_renderer != nullptr) { + LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call " + "wins, but because order of initialization in C++ is " + "nondeterministic, this may not be what you want."; + } + delete url_renderer; + url_renderer = new std::function(absl::string_view)>( + std::move(renderer)); +} + +StatusOr RenderGraph(const HloComputation& computation, + absl::string_view label, + const DebugOptions& debug_options, + RenderedGraphFormat format, + const HloExecutionProfile* hlo_execution_profile, + bool show_backend_config) { + tensorflow::mutex_lock lock(url_renderer_mu); + if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { + return Unavailable("Can't render as URL; no URL renderer was registered."); + } + + string rendered_dot = + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) + .Dump(); + return WrapDotInFormat(rendered_dot, format); +} + +StatusOr RenderNeighborhoodAround( + const HloInstruction& node, int radius, RenderedGraphFormat format, + bool show_backend_config, + const absl::flat_hash_set& boundary) { + tensorflow::mutex_lock lock(url_renderer_mu); + if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { + return FailedPrecondition( + "Can't render as URL; no URL renderer was registered."); + } + + string label = + StrCat("Neighborhood of ", radius, " nodes around ", node.name()); + string rendered_dot = + HloDotDumper(node.parent(), label, + node.GetModule()->config().debug_options(), + show_backend_config, /*profile=*/nullptr, + MakeNodeRadiusAroundFilter(&node, radius, boundary)) + .Dump(); + return WrapDotInFormat(rendered_dot, format); +} + +StatusOr RenderAllPathsFromTo(const HloInstruction& from, + const HloInstruction& to, int64 max_nodes, + RenderedGraphFormat format, + bool show_backend_config) { + tensorflow::mutex_lock lock(url_renderer_mu); + if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { + return FailedPrecondition( + "Can't render as URL; no URL renderer was registered."); + } + + CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!"; + auto debug_options = from.GetModule()->config().debug_options(); + + bool hit_limit = false; + NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit); + string label; + if (!hit_limit) { + label = StrCat("All paths from ", from.name(), " to ", to.name()); } else { - dirs.push_back(output_dir); + label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(), + " to ", to.name(), + "

<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN " + "NODES***<br/><br/>
"); } - // Try each directory, as they might be full, have inappropriate - // permissions or have different problems at times. - string output; - for (const string& dir : dirs) { - string filename = tensorflow::io::JoinPath(dir, "graph-"); - if (env->CreateUniqueFileName(&filename, ".html")) { - output = filename; - break; - } - } - if (output.empty()) { - LOG(FATAL) << "Failed to create unique output file name."; - } - TF_CHECK_OK(tensorflow::WriteStringToFile(env, output, html)); - return "file://" + output; + string rendered_dot = + HloDotDumper(from.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); + return WrapDotInFormat(rendered_dot, format); } -} // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 563cea42371..324ac67a6dd 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -23,52 +23,47 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" +// This file contains routines for rendering HLO computations into a +// human-readable graphical format. +// +// Fundamentally all graphs are rendered using the DOT language, but they can be +// packaged three different ways: +// +// - as a raw DOT file, which can be rendered using `graphviz`. +// +// - as an HTML file with an embedded DOT file, which can be viewed in a +// browser using a version of graphviz compiled to JavaScript +// +// - as a URL hosted somewhere which somehow embeds the DOT file. +// +// This last option is not implemented by default, but you can add a plugin to +// implement it via RegisterGraphToURLRenderer. +// +// TODO(jlebar): Rename this file to hlo_graph_renderer. + namespace xla { -namespace hlo_graph_dumper { -// Converts a HLO module to a DOT (graphviz) graph. Returns the dot graph as -// a string. -struct DotGraphOptions { - absl::string_view label; - const DebugOptions* debug_options = nullptr; - const HloExecutionProfile* profile = nullptr; - bool show_backend_config = false; -}; -string HloComputationToDotGraph(const HloComputation& computation, - const DotGraphOptions& options); - -// Abstract interface for classes that render HLO graphs (e.g. DOT graph, -// tensorflow GraphDef) to files or services. -class GraphRendererInterface { - public: - enum GraphKind { - DOT_GRAPH, - }; - - virtual ~GraphRendererInterface() = default; - - // Renders a DOT graph, returning a description of the rendered output - // (e.g., a URL) - virtual string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) = 0; +// Different formats that a graph can be packaged as. +enum class RenderedGraphFormat { + kDot, + kHtml, + kUrl, }; -// Dump the given HLO module if a dump is requested in its debug options. Based -// on the debug options, either a graph dump, a text dump or both may be -// generated. If a graph dump is generated, the description (e.g. an URL) is -// returned; otherwise an empty string is returned. -string MaybeDumpHloModule(const HloModule& module, const string& label, - const HloExecutionProfile* profile = nullptr); +// Renders an HLO module as a human-readable visual graph. +// +// Note that this only works well for relatively small graphs (no more than a +// few hundred nodes). Beyond that, the dot is usually unrenderable, +// unreadable, or both. 
To view such graphs, use a tool such as +// interactive_graphviz, which calls RenderNeighborhoodAround to render subsets +// of a graph. +StatusOr RenderGraph( + const HloComputation& computation, absl::string_view label, + const DebugOptions& debug_options, RenderedGraphFormat format, + const HloExecutionProfile* hlo_execution_profile = nullptr, + bool show_backend_config = false); -// Dumps a graph of the computation and returns a description of the rendered -// graph (e.g., a URL) based on the renderer. The "best" renderer in the -// registry is used. -string DumpGraph(const HloComputation& computation, const string& label, - const DebugOptions& debug_options, - const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_backend_config = false); - -// Like DumpGraph, but renders only nodes "near" the given node in the graph. +// Like RenderGraph, but renders only nodes "near" the given node in the graph. // // The number of nodes dumped is controlled by the radius parameter, which // (roughly) corresponds to the max distance a node may be from the primary node @@ -76,55 +71,28 @@ string DumpGraph(const HloComputation& computation, const string& label, // // The optional boundary specifies a set of boundary nodes, beyond which nodes // will be omitted even if they are within the radius. -string DumpNeighborhoodAround( - const HloInstruction& node, int radius, bool show_backend_config = false, +StatusOr RenderNeighborhoodAround( + const HloInstruction& node, int radius, RenderedGraphFormat format, + bool show_backend_config = false, const absl::flat_hash_set& boundary = {}); -// Dumps nodes on any of the paths from `from` to `to`. If there are more than -// max_nodes on all paths, restricts to the max_nodes nodes on the shortest +// Renders nodes on any of the paths from `from` to `to`. If there are more +// than max_nodes on all paths, restricts to the max_nodes nodes on the shortest // paths. -string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, - int64 max_nodes, bool show_backend_config = false); +StatusOr RenderAllPathsFromTo(const HloInstruction& from, + const HloInstruction& to, int64 max_nodes, + RenderedGraphFormat format, + bool show_backend_config = false); -// Dumps the HloModule::ToString() as a file into the provided directory path -// suffixed with the provided label. +// Registers a function which implements RenderedGraphFormat::kUrl. // -// If do_prefix is true, a timestamp will be prepended onto the label to -// construct a filename in the directory path; otherwise, the label is used -// as the filename directly. -void DumpText(const HloModule& module, const string& label, - const string& directory_path, bool do_prefix = true); +// The input to the function is dot, and the output should be a URL or an error. +// +// There can only be one active renderer, and the last call to this function +// wins. +void RegisterGraphToURLRenderer( + std::function(absl::string_view dot)> renderer); -// Renders DOT graph as inline SVG and saves it in an HTML file in a temprary -// directory or directory specified via --xla_hlo_graph_path. Returns the file -// URI pointing to the file. -string RenderDotAsHTMLFile(const string& dot, - const DebugOptions& debug_options); - -// Graph renderers may be added using a registration mechanism, e.g.: -// XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) -// The renderer with the highest numeric priority value is used. - -#define XLA_REGISTER_GRAPH_RENDERER(factory, ...) 
\ - XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, __COUNTER__, ##__VA_ARGS__) - -// Internal implementation details below this point. - -// Class that registers a graph renderer. -class Registrar { - public: - Registrar(std::shared_ptr dumper); -}; - -#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ - static ::xla::hlo_graph_dumper::Registrar \ - XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)( \ - std::make_shared(), ##__VA_ARGS__) - -// __COUNTER__ must go through another macro to be properly expanded -#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr) ___##ctr##__object_ - -} // namespace hlo_graph_dumper } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GRAPH_DUMPER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 064c53252c0..fa1ff49de87 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -31,24 +32,13 @@ namespace { using absl::StrCat; using ::testing::HasSubstr; +using HloGraphDumperTest = HloTestBase; + string TestName() { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } -class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { - public: - string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) override { - return graph; - } - - private: - string last_graph_; -}; - -XLA_REGISTER_GRAPH_RENDERER(DotRenderer); - -TEST(HloGraphDumperTest, NestedFusion) { +TEST_F(HloGraphDumperTest, NestedFusion) { HloComputation::Builder b("b"); // Build param0 + param1 + param2 + param3 + param4. @@ -90,8 +80,9 @@ TEST(HloGraphDumperTest, NestedFusion) { {fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop); // Generate the graph; all nodes should be present. 
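For readers wiring up the new rendering API documented in the header above, a hypothetical end-to-end sketch: register a URL renderer once, then request RenderedGraphFormat::kUrl. MyUploadService is an assumption for illustration only; the RenderGraph and RegisterGraphToURLRenderer signatures are the ones introduced in this change:

```cpp
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

// Returns a URL for a rendering of the module's entry computation.
xla::StatusOr<std::string> RenderEntryToUrl(const xla::HloModule& module) {
  // In real code this registration would happen once at startup; the callback
  // receives DOT text and must return a URL (or an error).
  xla::RegisterGraphToURLRenderer(
      [](absl::string_view dot) -> xla::StatusOr<std::string> {
        return MyUploadService::Upload(std::string(dot));  // hypothetical
      });
  return xla::RenderGraph(*module.entry_computation(), /*label=*/"my_graph",
                          module.config().debug_options(),
                          xla::RenderedGraphFormat::kUrl);
}
```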
- string graph = hlo_graph_dumper::DumpGraph(*root_computation, /*label=*/"", - DebugOptions()); + TF_ASSERT_OK_AND_ASSIGN( + string graph, RenderGraph(*root_computation, /*label=*/"", DebugOptions(), + RenderedGraphFormat::kDot)); for (const HloComputation* computation : {root_computation, // inner_fusion->fused_instructions_computation(), @@ -113,12 +104,13 @@ TEST(HloGraphDumperTest, NestedFusion) { } } ASSERT_NE(inner_sum, nullptr); - EXPECT_THAT( - hlo_graph_dumper::DumpNeighborhoodAround(*inner_sum, /*radius=*/1), - HasSubstr(inner_sum->name())); + TF_ASSERT_OK_AND_ASSIGN(string neighborhood_graph, + RenderNeighborhoodAround(*inner_sum, /*radius=*/1, + RenderedGraphFormat::kDot)); + EXPECT_THAT(neighborhood_graph, HasSubstr(inner_sum->name())); } -TEST(HloGraphDumperTest, Constant) { +TEST_F(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(-42))); @@ -126,13 +118,14 @@ TEST(HloGraphDumperTest, Constant) { HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); - string graph = hlo_graph_dumper::DumpGraph( - *root_computation, /*label=*/"an_empty_graph", DebugOptions()); + TF_ASSERT_OK_AND_ASSIGN( + string graph, RenderGraph(*root_computation, /*label=*/"an_empty_graph", + DebugOptions(), RenderedGraphFormat::kDot)); EXPECT_THAT(graph, HasSubstr("an_empty_graph")); EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); } -TEST(HloGraphDumperTest, TupleConstant) { +TEST_F(HloGraphDumperTest, TupleConstant) { Shape tuple_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})}); HloComputation::Builder b("b"); @@ -144,11 +137,30 @@ TEST(HloGraphDumperTest, TupleConstant) { HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build(gte)); - string graph = hlo_graph_dumper::DumpGraph( - *root_computation, /*label=*/"tuple_constant", DebugOptions()); + TF_ASSERT_OK_AND_ASSIGN( + string graph, RenderGraph(*root_computation, /*label=*/"tuple_constant", + DebugOptions(), RenderedGraphFormat::kDot)); EXPECT_THAT(graph, HasSubstr("tuple_constant")); EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])")); } +TEST_F(HloGraphDumperTest, Compare) { + const char* hlo_string = R"( + HloModule comp + + ENTRY comp { + param.0 = f32[10] parameter(0) + param.1 = f32[10] parameter(1) + ROOT lt = pred[10] compare(param.0, param.1), direction=LT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + string graph, + RenderGraph(*module->entry_computation(), /*label=*/"tuple_constant", + DebugOptions(), RenderedGraphFormat::kDot)); + EXPECT_THAT(graph, HasSubstr("direction=LT")); +} + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc deleted file mode 100644 index 84c4cf18df6..00000000000 --- a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Implementation of an DOT graph renderer that uses Javascript to render DOT to -// SVG in a browser. - -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -class GraphHtmlRenderer : public GraphRendererInterface { - public: - string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) override { - switch (graph_kind) { - case DOT_GRAPH: - return RenderDotAsHTMLFile(graph, debug_options); - default: - LOG(FATAL) << "Only DOT graphs can be rendered"; - } - } -}; - -XLA_REGISTER_GRAPH_RENDERER(GraphHtmlRenderer); - -} // namespace -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index b01c00121b3..c74ed27f484 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -30,9 +30,9 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias) << kind; TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) - << absl::StrCat("Tring to set up alias at ", output_index.ToString(), - " which is an invalid index for shape ", - ShapeUtil::HumanString(alias_.shape())); + << "Trying to set up alias at " << output_index.ToString() + << " which is an invalid index for shape " + << ShapeUtil::HumanString(alias_.shape()); TF_RET_CHECK(param_number >= 0) << param_number; TF_RET_CHECK(!OutputHasAlias(output_index)) << "Output index " << output_index << " already has an alias setup"; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index bb45eb4fa0d..c36834c9756 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -64,7 +64,35 @@ StatusOr> HloInstruction::CreateFromProto( const absl::flat_hash_map& instruction_map, const absl::flat_hash_map& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); + HloOpcode opcode; + auto opcode_or = StringToHloOpcode(proto.opcode()); + absl::optional comparison_direction; + if (opcode_or.ok()) { + opcode = opcode_or.ConsumeValueOrDie(); + } else { + // Unknown opcode. Try auto-upgrading deprecated "less-than", + // "greater-than", etc opcodes, which are now rolled into the kCompare + // opcode. 
+ if (proto.opcode() == "equal-to") { + comparison_direction = ComparisonDirection::kEq; + } else if (proto.opcode() == "not-equal-to") { + comparison_direction = ComparisonDirection::kNe; + } else if (proto.opcode() == "greater-than-or-equal-to") { + comparison_direction = ComparisonDirection::kGe; + } else if (proto.opcode() == "greater-than") { + comparison_direction = ComparisonDirection::kGt; + } else if (proto.opcode() == "less-than-or-equal-to") { + comparison_direction = ComparisonDirection::kLe; + } else if (proto.opcode() == "less-than") { + comparison_direction = ComparisonDirection::kLt; + } + if (comparison_direction) { + opcode = HloOpcode::kCompare; + } else { + return InvalidArgument("Unknown opcode: %s", proto.opcode()); + } + } + TF_RET_CHECK(proto.has_shape()); std::unique_ptr instruction; @@ -136,6 +164,17 @@ StatusOr> HloInstruction::CreateFromProto( absl::Span(fft_length)); break; } + case HloOpcode::kCompare: { + // Auto-upgraded from deprecated opcode skips the following. + if (!comparison_direction) { + TF_ASSIGN_OR_RETURN( + comparison_direction, + StringToComparisonDirection(proto.comparison_direction())); + } + instruction = + CreateCompare(shape, operands(0), operands(1), *comparison_direction); + break; + } case HloOpcode::kTriangularSolve: { instruction = CreateTriangularSolve(shape, operands(0), operands(1), proto.triangular_solve_options()); @@ -264,6 +303,10 @@ StatusOr> HloInstruction::CreateFromProto( TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); + // Literal's shape may have no/different tiling info. + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + instruction->shape(), shape)); + *instruction->mutable_shape() = shape; } else { instruction = absl::make_unique(shape); } @@ -373,6 +416,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateReplicaId(); break; } + case HloOpcode::kPartitionId: { + instruction = CreatePartitionId(); + break; + } case HloOpcode::kConvolution: { TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); @@ -413,11 +460,11 @@ StatusOr> HloInstruction::CreateFromProto( } instruction = CreateCustomCall(shape, all_operands(), proto.custom_call_target(), - operand_shapes, proto.custom_call_opaque()); + operand_shapes, proto.backend_config()); } else { instruction = CreateCustomCall(shape, all_operands(), proto.custom_call_target(), - proto.custom_call_opaque()); + proto.backend_config()); } if (proto.has_window()) { static_cast(instruction.get()) @@ -556,6 +603,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction->AppendOperand(instruction_map.at(operand_id)); } if (instruction->opcode() != HloOpcode::kFusion) { + if (instruction->opcode() == HloOpcode::kCall) { + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Call should have 1 called computation but has " + << proto.called_computation_ids_size(); + } for (const int64 computation_id : proto.called_computation_ids()) { instruction->called_computations_.push_back( computation_map.at(computation_id)); @@ -579,6 +631,9 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + instruction->outer_dimension_partitions_.assign( + proto.outer_dimension_partitions().begin(), + proto.outer_dimension_partitions().end()); TF_RET_CHECK(proto.id() >= 0) << "Instruction with negative id: " << proto.id(); 
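To make the proto-level upgrade at the top of this hunk concrete: a serialized module that still uses one of the retired comparison opcodes is rebuilt as a kCompare instruction carrying an equivalent direction. Expressed in HLO text, the two spellings correspond roughly as below; the old textual form is shown only for illustration, using the same R"(...)" convention as the tests in this change:

```cpp
// Old-style comparison, one opcode per direction (deprecated):
const char* kOldStyle = R"(
  ROOT gt = pred[10] greater-than(param.0, param.1)
)";

// Upgraded form: a single kCompare opcode with a direction attribute.
const char* kUpgraded = R"(
  ROOT gt = pred[10] compare(param.0, param.1), direction=GT
)";
```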
@@ -664,6 +719,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: case HloOpcode::kReal: case HloOpcode::kRsqrt: case HloOpcode::kSign: @@ -688,15 +744,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: @@ -761,6 +811,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, fft_length); } +/* static */ std::unique_ptr HloInstruction::CreateCompare( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + ComparisonDirection direction) { + return absl::make_unique(shape, lhs, rhs, direction); +} + /* static */ std::unique_ptr HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a, HloInstruction* b, @@ -820,6 +876,12 @@ HloInstruction::CreateCollectivePermute( new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {}))); } +/* static */ std::unique_ptr +HloInstruction::CreatePartitionId() { + return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, + ShapeUtil::MakeShape(U32, {}))); +} + /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -1240,18 +1302,18 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, absl::Span operands, - absl::string_view custom_call_target, absl::string_view opaque) { + absl::string_view custom_call_target, string opaque) { return absl::make_unique( - shape, operands, custom_call_target, opaque); + shape, operands, custom_call_target, std::move(opaque)); } /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, absl::Span operands, absl::string_view custom_call_target, - absl::Span operand_shapes_with_layout, - absl::string_view opaque) { + absl::Span operand_shapes_with_layout, string opaque) { return absl::make_unique( - shape, operands, custom_call_target, opaque, operand_shapes_with_layout); + shape, operands, custom_call_target, std::move(opaque), + operand_shapes_with_layout); } /* static */ std::unique_ptr HloInstruction::CreateTuple( @@ -1311,6 +1373,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kFft: + case HloOpcode::kCompare: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kRecv: @@ -1368,6 +1431,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: case HloOpcode::kReal: case HloOpcode::kRsqrt: case HloOpcode::kSign: @@ -1384,12 +1448,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDivide: case HloOpcode::kMultiply: case HloOpcode::kSubtract: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1461,10 +1519,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( 
CHECK_EQ(new_operands.size(), 0); clone = CreateReplicaId(); break; + case HloOpcode::kPartitionId: + CHECK_EQ(new_operands.size(), 0); + clone = CreatePartitionId(); + break; } // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_outer_dimension_partitions(outer_dimension_partitions_); clone->set_raw_backend_config_string(backend_config_); if (context != nullptr) { context->MapInstruction(this, clone.get()); @@ -1565,12 +1628,12 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { } const HloInstruction* HloInstruction::operand(int64 i) const { - return operands_[i]; + return operands_.at(i); } HloInstruction* HloInstruction::mutable_operand(int64 i) { CHECK(operands_[i] != nullptr); - return operands_[i]; + return operands_.at(i); } int64 HloInstruction::operand_index(const HloInstruction* target) const { @@ -1705,27 +1768,23 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: case HloOpcode::kXor: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: + case HloOpcode::kPartitionId: + case HloOpcode::kPopulationCount: case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: @@ -1772,6 +1831,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kFft: + case HloOpcode::kCompare: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kRecv: @@ -1878,6 +1938,7 @@ Status HloInstruction::ReplaceUseWithDifferentShape( std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + // Custom fusions may not be able to handle deduplicated operands. if (user->opcode() == HloOpcode::kFusion) { TF_RETURN_IF_ERROR( Cast(user)->DeduplicateFusionOperands()); @@ -2106,6 +2167,7 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: case HloOpcode::kReal: case HloOpcode::kReducePrecision: case HloOpcode::kRsqrt: @@ -2119,17 +2181,12 @@ bool HloInstruction::IsElementwiseImpl( // Binary elementwise operations, the same as in IsElementwiseBinary(). case HloOpcode::kAdd: case HloOpcode::kAtan2: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: @@ -2181,9 +2238,15 @@ string HloInstruction::ToStringWithCanonicalNameMap( StrAppend(&result, PrintName(name(), options), " = "); } - // Print opcode, operand(s) and shape. - StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", + // Print shape. 
+ if (options.include_layout_in_shapes()) { + StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape())); + } else { + StrAppend(&result, ShapeUtil::HumanString(shape())); + } + + // Print opcode, operand(s). + StrAppend(&result, " ", HloOpcodeString(opcode()), "(", OperandsToStringWithCanonicalNameMap(options, canonical_name_map), ")"); @@ -2227,7 +2290,11 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( } std::vector str; if (options.print_operand_shape()) { - str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); + if (options.include_layout_in_shapes()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); + } else { + str.push_back(ShapeUtil::HumanString(operand->shape())); + } } // In a top-level HloInstruction::ToString() call, the operand name is not @@ -2351,6 +2418,11 @@ std::vector HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } + if (!outer_dimension_partitions_.empty()) { + extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}", + StrJoin(outer_dimension_partitions_, ","))); + } + if (options.print_control_dependencies() && !control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", StrJoin(control_predecessors_, ", ", @@ -2400,6 +2472,11 @@ HloInstructionProto HloInstruction::ToProto() const { if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } + if (!outer_dimension_partitions_.empty()) { + for (const auto& idx : outer_dimension_partitions_) { + proto.mutable_outer_dimension_partitions()->Add(idx); + } + } return proto; } @@ -2425,6 +2502,22 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } +bool HloInstruction::IsInputFusion() const { + return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput; +} + +bool HloInstruction::IsLoopFusion() const { + return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop; +} + +bool HloInstruction::IsOutputFusion() const { + return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput; +} + +bool HloInstruction::IsCustomFusion() const { + return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom; +} + bool HloInstruction::IsFusible() const { // Instructions which are traced should not be fused. 
if (tracing()) { @@ -2472,12 +2565,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleGetTupleElement(this); case HloOpcode::kParameter: return visitor->HandleParameter(this); - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: + case HloOpcode::kCompare: return visitor->HandleCompare(this); case HloOpcode::kComplex: return visitor->HandleComplex(this); @@ -2535,6 +2623,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCollectivePermute(this); case HloOpcode::kReplicaId: return visitor->HandleReplicaId(this); + case HloOpcode::kPartitionId: + return visitor->HandlePartitionId(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2581,6 +2671,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleIsFinite(this); case HloOpcode::kNot: return visitor->HandleNot(this); + case HloOpcode::kPopulationCount: + return visitor->HandlePopulationCount(this); case HloOpcode::kBitcast: return visitor->HandleBitcast(this); case HloOpcode::kBroadcast: @@ -2691,7 +2783,12 @@ template static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, const InternalCompareFunction* operand_order, bool ignore_control_predecessors) { - visitor->ReserveVisitStates(root->GetModule()->instruction_count()); + // Calculating the instruction count within a module can be expensive on large + // models so only do it if the visit state is empty. This will help when the + // same visitor is reused across many computations of a single module. + if (visitor->VisitStateCapacity() == 0) { + visitor->ReserveVisitStates(root->GetModule()->instruction_count()); + } // dfs_stack holds pairs of unique_id(), HloInstruction*>. // @@ -2835,11 +2932,6 @@ bool HloInstruction::IsElementwise() const { return IsElementwiseImpl(absl::nullopt); } -bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { - CHECK(IsElementwise()); - return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); -} - bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { return IsElementwiseImpl(operand_idx); } @@ -2949,9 +3041,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { } return UseKind::kReuse; default: - return IsElementwise() && !ImplicitlyBroadcastsOperand(i) - ? UseKind::kUse - : UseKind::kReuse; + return IsElementwise() ? UseKind::kUse : UseKind::kReuse; } } @@ -3064,6 +3154,16 @@ string ConvolutionDimensionNumbersToString( StrJoin(output_dims, "")); } +string ReplicaGroupsToString(const std::vector& replica_groups) { + std::vector replica_group_str; + replica_group_str.reserve(replica_groups.size()); + for (const ReplicaGroup& group : replica_groups) { + replica_group_str.push_back( + StrCat("{", StrJoin(group.replica_ids(), ","), "}")); + } + return StrCat("{", StrJoin(replica_group_str, ","), "}"); +} + StatusOr StringToRandomDistribution(const string& name) { static std::unordered_map* map = [] { static auto* map = new std::unordered_map; @@ -3157,7 +3257,12 @@ Status HloInstruction::set_backend_config( /* static */ StatusOr HloInstruction::BackendConfigToRawString( const tensorflow::protobuf::Message& proto) { string ret; - TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret)); + // Pass ignore_accuracy_loss = true because estimated_cycles field can be + // INT64_MAX. 
If ignore_accuracy_loss = false and estimated_cycles = + // INT64_MAX, JsonFormat will return an error status, although there is no + // accuracy loss for int64. + TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson( + proto, &ret, /*ignore_accuracy_loss=*/true)); return ret; } @@ -3363,6 +3468,13 @@ void HloInstruction::set_parameter_replicated_at_leaf_buffers( parameter_replicated_at_leaf_buffers); } +void HloInstruction::set_parameter_replicated_at_leaf_buffers( + const std::vector& parameter_replicated_at_leaf_buffers) { + return Cast(this) + ->set_parameter_replicated_at_leaf_buffers( + parameter_replicated_at_leaf_buffers); +} + const absl::optional>& HloInstruction::parameter_replicated_at_leaf_buffers() const { return Cast(this) @@ -3373,6 +3485,11 @@ int64 HloInstruction::tuple_index() const { return Cast(this)->tuple_index(); } +void HloInstruction::set_tuple_index(int64 new_tuple_index) { + return Cast(this)->set_tuple_index( + new_tuple_index); +} + int32 HloInstruction::exponent_bits() const { return Cast(this)->exponent_bits(); } @@ -3526,6 +3643,10 @@ const DomainMetadata& HloInstruction::user_side_metadata() const { return Cast(this)->user_side_metadata(); } +ComparisonDirection HloInstruction::comparison_direction() const { + return Cast(this)->direction(); +} + const TriangularSolveOptions& HloInstruction::triangular_solve_options() const { return Cast(this)->triangular_solve_options(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 9a46626f5f9..b9c4ba62d4d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -79,6 +80,7 @@ class HloPrintOptions { print_metadata_(true), print_backend_config_(true), compact_operands_(false), + include_layout_in_shapes_(true), print_operand_shape_(true), print_operand_names_(true), print_program_shape_(true), @@ -177,6 +179,13 @@ class HloPrintOptions { return *this; } + // If true, include the layout in any shapes that are printed (instruction + // and operands). + HloPrintOptions& set_include_layout_in_shapes(bool value) { + include_layout_in_shapes_ = value; + return *this; + } + // If true, canonicalizes instructions' name. Instead of using "%foo.1" as // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. 
HloPrintOptions& set_canonicalize_instruction_names(bool value) { @@ -204,6 +213,7 @@ class HloPrintOptions { bool print_metadata() const { return print_metadata_; } bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } + bool include_layout_in_shapes() const { return include_layout_in_shapes_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_operand_names() const { return print_operand_names_; } bool print_program_shape() const { return print_program_shape_; } @@ -223,6 +233,7 @@ class HloPrintOptions { bool print_metadata_; bool print_backend_config_; bool compact_operands_; + bool include_layout_in_shapes_; bool print_operand_shape_; bool print_operand_names_; bool print_program_shape_; @@ -444,6 +455,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, FftType fft_type, absl::Span fft_length); + // Creates a compare op, performing the comparison specified in direction. + static std::unique_ptr CreateCompare( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + ComparisonDirection direction); + static std::unique_ptr CreateTriangularSolve( const Shape& shape, HloInstruction* a, HloInstruction* b, const TriangularSolveOptions& options); @@ -513,6 +529,9 @@ class HloInstruction { // Creates an instruction that returns a U32 replica ID. static std::unique_ptr CreateReplicaId(); + // Creates an instruction that returns a U32 partition ID. + static std::unique_ptr CreatePartitionId(); + // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr CreateConvert(const Shape& shape, @@ -758,7 +777,7 @@ class HloInstruction { // backend-specific interpretation. "shape" is the resultant shape. static std::unique_ptr CreateCustomCall( const Shape& shape, absl::Span operands, - absl::string_view custom_call_target, absl::string_view opaque = ""); + absl::string_view custom_call_target, string opaque = ""); // Overload which constrains the layouts of the operand and result. 'shape' // and 'operand_shapes_with_layout' must have layouts. @@ -767,8 +786,7 @@ class HloInstruction { static std::unique_ptr CreateCustomCall( const Shape& shape, absl::Span operands, absl::string_view custom_call_target, - absl::Span operand_shapes_with_layout, - absl::string_view opaque = ""); + absl::Span operand_shapes_with_layout, string opaque = ""); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. @@ -1124,6 +1142,11 @@ class HloInstruction { // instruction. bool IsFused() const; + bool IsLoopFusion() const; + bool IsInputFusion() const; + bool IsOutputFusion() const; + bool IsCustomFusion() const; + // Returns true if this instruction can be legally fused into a fusion // instruction. bool IsFusible() const; @@ -1220,10 +1243,8 @@ class HloInstruction { // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, - // after performing necessary implicit broadcast - // (cs/IrArray::EmitArrayElementAddress), to compute the output at index - // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is - // the element at {i_0,i_1,...,i_n}. + // to compute the output at index {i_0,i_1,...,i_n}, the only element required + // from the operand (if any) is the element at {i_0,i_1,...,i_n}. 
// // Note on performance: when this instruction is kFusion, this method, in the // worst case, scans all fused instructions. We could speed this up by @@ -1239,12 +1260,6 @@ class HloInstruction { // Returns true if this is a cross-replica all-reduce instruction. bool IsCrossReplicaAllReduce() const; - // Returns true if this elementwise instruction implicitly broadcasts operand - // `operand_idx`. - // - // Precondition: this instruction should be an elementwise operation. - bool ImplicitlyBroadcastsOperand(int64 operand_idx) const; - // Returns true if this instruction is binary and elementwise. bool IsElementwiseBinary() const; @@ -1299,7 +1314,8 @@ class HloInstruction { // this HLO. The meaning of the field is backend specific. Not for use before // or during general HLO optimization, since HLO optimizations do not preserve // this field and they cannot interpret it due to its meaning being backend - // specific. + // specific. Except for CustomCall, where this field is preserved and no + // general HLO optimization needs to interpret it. // // ConfigProto should be a protobuf Message type. template @@ -1494,6 +1510,8 @@ class HloInstruction { // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers. void set_parameter_replicated_at_leaf_buffers( absl::Span parameter_replicated_at_leaf_buffers); + void set_parameter_replicated_at_leaf_buffers( + const std::vector& parameter_replicated_at_leaf_buffers); // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers. const absl::optional>& @@ -1502,6 +1520,9 @@ class HloInstruction { // Delegates to HloGetTupleElementInstruction::tuple_index. int64 tuple_index() const; + // Delegates to HloGetTupleElementInstruction::set_tuple_index. + void set_tuple_index(int64 new_tuple_index); + // Delegates to HloReducePrecisionInstruction::exponent_bits. int32 exponent_bits() const; @@ -1608,6 +1629,9 @@ class HloInstruction { // Delegates to HloDomainInstruction::user_side_metadata(). const DomainMetadata& user_side_metadata() const; + // Delegates to HloCompareInstruction::direction(). + ComparisonDirection comparison_direction() const; + // Delegates to HloTriangularSolveInstruction::triangular_solve_options(). 
const TriangularSolveOptions& triangular_solve_options() const; @@ -1813,6 +1837,7 @@ string RandomDistributionToString(const RandomDistribution& distribution); string PrecisionToString(const PrecisionConfig::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); +string ReplicaGroupsToString(const std::vector& replica_groups); StatusOr StringToRandomDistribution(const string& name); StatusOr StringToPrecision(const string& name); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 35f031f29a7..dee4af42bbc 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1331,6 +1331,16 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); + auto options2 = HloPrintOptions() + .set_print_metadata(false) + .set_print_operand_shape(false) + .set_print_percent(false) + .set_include_layout_in_shapes(false); + + EXPECT_EQ(dot->ToString(options2), + "dot = f32[5,20] dot(x, transpose), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1655,7 +1665,7 @@ body (bparam: s32[]) -> s32[] { condition (cparam: s32[]) -> pred[] { xconstant = s32[] constant(5) cparam = s32[] parameter(0) - ROOT greater-than = pred[] greater-than(xconstant, cparam) + ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT } ENTRY entry (param: s32[]) -> s32[] { @@ -1759,5 +1769,19 @@ TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) { ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT)); } +TEST_F(HloInstructionTest, PreserveOuterDimensionPartitionsOnClone) { + constexpr char kHloString[] = R"( + HloModule test_module + ENTRY test { + ROOT iota = f32[100] iota(), iota_dimension=1, outer_dimension_partitions={0, 50} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHloString)); + auto* iota = module->entry_computation()->root_instruction(); + + auto clone = iota->Clone(); + EXPECT_THAT(clone->outer_dimension_partitions(), + ::testing::ElementsAre(0, 50)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 7d18b35c2bb..7a6d563b83f 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/window_util.h" @@ -202,6 +203,42 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } +HloCompareInstruction::HloCompareInstruction(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs, + ComparisonDirection direction) + : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) { + AppendOperand(lhs); + AppendOperand(rhs); +} + +HloInstructionProto HloCompareInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_comparison_direction(ComparisonDirectionToString(direction_)); + return proto; +} + +std::vector HloCompareInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("direction=", ComparisonDirectionToString(direction()))}; +} + +bool HloCompareInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return direction() == casted_other.direction(); +} + +std::unique_ptr HloCompareInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique(shape, new_operands[0], + new_operands[1], direction()); +} + namespace { // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector @@ -456,15 +493,7 @@ HloInstructionProto HloCollectiveInstruction::ToProto() const { std::vector HloCollectiveInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& /*options*/) const { - std::vector result; - std::vector replica_group_str; - for (const ReplicaGroup& group : replica_groups()) { - replica_group_str.push_back( - StrCat("{", StrJoin(group.replica_ids(), ","), "}")); - } - result.push_back( - StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}")); - return result; + return {StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))}; } bool HloCollectiveInstruction::IdenticalSlowPath( @@ -989,6 +1018,11 @@ HloConstantInstruction::HloConstantInstruction(Literal literal) : HloInstruction(HloOpcode::kConstant, literal.shape()), literal_(std::move(literal)) {} +HloConstantInstruction::HloConstantInstruction(Literal literal, + const Shape& shape) + : HloInstruction(HloOpcode::kConstant, shape), + literal_(std::move(literal)) {} + HloConstantInstruction::HloConstantInstruction(const Shape& shape) : HloInstruction(HloOpcode::kConstant, shape) {} @@ -1034,7 +1068,12 @@ HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK(literal_.has_value()); - return absl::make_unique(literal_->Clone()); + // Literal's shape may have no/different tiling info. Use this instruction's + // shape instead. 
+ CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(literal_->shape(), + this->shape())); + return absl::make_unique(literal_->Clone(), + this->shape()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -1541,6 +1580,9 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( } Status HloFusionInstruction::DeduplicateFusionOperands() { + if (IsCustomFusion()) { + return Status::OK(); + } absl::flat_hash_map operand_indices; std::vector operands_to_remove; for (int i = 0; i < operand_count(); ++i) { @@ -2007,13 +2049,13 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, absl::Span operands, - absl::string_view custom_call_target, absl::string_view opaque) + absl::string_view custom_call_target, string opaque) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()), - opaque_(opaque.begin(), opaque.end()), feature_group_count_(1), batch_group_count_(1), layout_constrained_(false) { + set_raw_backend_config_string(std::move(opaque)); for (auto operand : operands) { AppendOperand(operand); } @@ -2021,16 +2063,16 @@ HloCustomCallInstruction::HloCustomCallInstruction( HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, absl::Span operands, - absl::string_view custom_call_target, absl::string_view opaque, + absl::string_view custom_call_target, string opaque, absl::Span operand_shapes_with_layout) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()), - opaque_(opaque.begin(), opaque.end()), feature_group_count_(1), batch_group_count_(1), layout_constrained_(true), operand_shapes_with_layout_(operand_shapes_with_layout.begin(), operand_shapes_with_layout.end()) { + set_raw_backend_config_string(std::move(opaque)); for (auto operand : operands) { AppendOperand(operand); } @@ -2046,7 +2088,6 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); - proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); proto.set_batch_group_count(batch_group_count_); if (layout_constrained()) { @@ -2080,11 +2121,7 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( // an HloComputation. extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); - // If the opaque string becomes enormous we may want to reconsider printing - // this inline and consider other options. - if (!opaque_.empty()) { - extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\"")); - } + if (layout_constrained()) { std::vector shape_strings; for (const Shape& shape : operand_shapes_with_layout_) { @@ -2132,8 +2169,9 @@ bool HloCustomCallInstruction::IdenticalSlowPath( } } } - return custom_call_target_ == casted_other.custom_call_target_ && - opaque_ == casted_other.opaque_; + // Note: backend_config comparison is done in Identical, which is the + // intended/exposed way to compare computations, and so not repeated here. 
+ return custom_call_target_ == casted_other.custom_call_target_; } std::unique_ptr @@ -2282,19 +2320,19 @@ HloGatherInstruction::HloGatherInstruction( absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } -string HloGatherInstruction::GatherDimensionNumbersToString() const { - CHECK(gather_dimension_numbers_ != nullptr); +/*static*/ string HloGatherInstruction::GatherDimensionNumbersToString( + const GatherDimensionNumbers& gather_dimension_numbers) { string offset_dims = StrCat("offset_dims={", - StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}"); + StrJoin(gather_dimension_numbers.offset_dims(), ","), "}"); string collapsed_slice_dims = StrCat( "collapsed_slice_dims={", - StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + StrJoin(gather_dimension_numbers.collapsed_slice_dims(), ","), "}"); string start_index_map = StrCat("start_index_map={", - StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}"); - string index_vector_dim = StrCat( - "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); + StrJoin(gather_dimension_numbers.start_index_map(), ","), "}"); + string index_vector_dim = + StrCat("index_vector_dim=", gather_dimension_numbers.index_vector_dim()); return StrJoin>( {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, @@ -2331,7 +2369,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {GatherDimensionNumbersToString(), + return {GatherDimensionNumbersToString(gather_dimension_numbers()), StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")}; } @@ -2369,19 +2407,20 @@ HloScatterInstruction::HloScatterInstruction( absl::make_unique(scatter_dim_numbers); } -string HloScatterInstruction::ScatterDimensionNumbersToString() const { - string update_window_dims = StrCat( - "update_window_dims={", - StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}"); +/*static*/ string HloScatterInstruction::ScatterDimensionNumbersToString( + const ScatterDimensionNumbers& scatter_dimension_numbers) { + string update_window_dims = + StrCat("update_window_dims={", + StrJoin(scatter_dimension_numbers.update_window_dims(), ","), "}"); string inserted_window_dims = StrCat( "inserted_window_dims={", - StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + StrJoin(scatter_dimension_numbers.inserted_window_dims(), ","), "}"); string scatter_dims_to_operand_dims = StrCat( "scatter_dims_to_operand_dims={", - StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + StrJoin(scatter_dimension_numbers.scatter_dims_to_operand_dims(), ","), "}"); - string index_vector_dim = StrCat( - "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); + string index_vector_dim = + StrCat("index_vector_dim=", scatter_dimension_numbers.index_vector_dim()); return StrJoin>( {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, @@ -2418,7 +2457,7 @@ HloInstructionProto HloScatterInstruction::ToProto() const { std::vector HloScatterInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {ScatterDimensionNumbersToString()}; + return {ScatterDimensionNumbersToString(scatter_dimension_numbers())}; } bool HloScatterInstruction::IdenticalSlowPath( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 43aa12c10f2..dc3634c41ce 
100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -131,6 +131,28 @@ class HloFftInstruction : public HloInstruction { std::vector fft_length_; }; +class HloCompareInstruction : public HloInstruction { + public: + explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + ComparisonDirection direction); + ComparisonDirection direction() const { return direction_; } + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + ComparisonDirection direction_; +}; + class HloTriangularSolveInstruction : public HloInstruction { public: explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, @@ -628,6 +650,7 @@ class HloSliceInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction { public: explicit HloConstantInstruction(Literal literal); + explicit HloConstantInstruction(Literal literal, const Shape& shape); // Used when the literal is too large and dropped. explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. @@ -853,6 +876,13 @@ class HloParameterInstruction : public HloInstruction { parameter_replicated_at_leaf_buffers.begin(), parameter_replicated_at_leaf_buffers.end()); } + void set_parameter_replicated_at_leaf_buffers( + const std::vector& parameter_replicated_at_leaf_buffers) { + CHECK_EQ(ShapeUtil::GetLeafCount(shape()), + parameter_replicated_at_leaf_buffers.size()); + parameter_replicated_at_leaf_buffers_ = + parameter_replicated_at_leaf_buffers; + } const absl::optional>& parameter_replicated_at_leaf_buffers() const { return parameter_replicated_at_leaf_buffers_; @@ -889,6 +919,10 @@ class HloGetTupleElementInstruction : public HloInstruction { HloInstruction* operand, int64 index); // Returns the tuple index associated with this instruction. int64 tuple_index() const { return tuple_index_; } + // Sets the tuple index associated with this instruction. + void set_tuple_index(int64 new_tuple_index) { + tuple_index_ = new_tuple_index; + } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1146,15 +1180,13 @@ class HloCustomCallInstruction : public HloInstruction { public: HloCustomCallInstruction(const Shape& shape, absl::Span operands, - absl::string_view custom_call_target, - absl::string_view opaque); + absl::string_view custom_call_target, string opaque); // Constructor for a custom call with constrained layout. 'shape' and // 'operands_with_layout' must all have layouts. HloCustomCallInstruction(const Shape& shape, absl::Span operands, - absl::string_view custom_call_target, - absl::string_view opaque, + absl::string_view custom_call_target, string opaque, absl::Span operand_shapes_with_layout); const Window& window() const override { @@ -1176,7 +1208,8 @@ class HloCustomCallInstruction : public HloInstruction { convolution_dimension_numbers_ = absl::make_unique(dnums); } - const string& opaque() const { return opaque_; } + // TODO(jpienaar): Remove this accessor in the follow up. 
+ const string& opaque() const { return raw_backend_config_string(); } const string& custom_call_target() const { return custom_call_target_; } void set_feature_group_count(int64 feature_group_count) { feature_group_count_ = feature_group_count; @@ -1212,8 +1245,6 @@ class HloCustomCallInstruction : public HloInstruction { HloCloneContext* context) const override; // Name of a global symbol to call. string custom_call_target_; - // Opaque string interpreted by the backend. - string opaque_; // Describes the window in a windowed operation such as convolution. std::unique_ptr window_; // Describes the dimension numbers used for a convolution. @@ -1348,8 +1379,6 @@ class HloGatherInstruction : public HloInstruction { absl::Span gather_slice_sizes() const { return gather_slice_sizes_; } - // Returns the dump string of the gather dimension numbers. - string GatherDimensionNumbersToString() const; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1358,6 +1387,9 @@ class HloGatherInstruction : public HloInstruction { absl::Span offset_dims, absl::Span collapsed_slice_dims, absl::Span start_index_map, int64 index_vector_dim); + // Returns the dump string of the given gather dimension numbers. + static string GatherDimensionNumbersToString( + const GatherDimensionNumbers& gather_dimension_numbers); private: std::vector ExtraAttributesToStringImpl( @@ -1385,8 +1417,6 @@ class HloScatterInstruction : public HloInstruction { CHECK(scatter_dimension_numbers_ != nullptr); return *scatter_dimension_numbers_; } - // Returns the dump string of the scatter dimension numbers. - string ScatterDimensionNumbersToString() const; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1396,6 +1426,9 @@ class HloScatterInstruction : public HloInstruction { absl::Span inserted_window_dims, absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim); + // Returns the dump string of the given scatter dimension numbers. 
+ static string ScatterDimensionNumbersToString( + const ScatterDimensionNumbers& scatter_dimension_numbers); private: std::vector ExtraAttributesToStringImpl( diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 436cccb1fb9..45d3e9c460e 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -255,7 +255,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -308,7 +308,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(add.1, constant.2) + ROOT less-than = pred[] compare(add.1, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -360,7 +360,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { loop_var.2 = (s32[], s32[], s32[]) parameter(0) get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 constant.1 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + ROOT less-than = pred[] compare(get-tuple-element.4, constant.1), direction=LT } ENTRY SimpleLoop { constant.2 = s32[] constant(0) @@ -415,7 +415,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -448,13 +448,13 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } OuterWhileCondition { cond_param.2 = (s32[]) parameter(0) get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0 constant.5 = s32[] constant(5) - ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5) + ROOT less-than.2 = pred[] compare(get-tuple-element.5, constant.5), direction=LT } OuterWhileBody { body_param.2 = (s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index d28e79d41ad..47ed85be196 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -89,6 +89,22 @@ bool HloParameterMatcher::MatchAndExplain( return true; } +bool HloComparisonMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->comparison_direction() != direction_) { + *listener << "has wrong comparison direction (got " + << 
ComparisonDirectionToString( + instruction->comparison_direction()) + << ", want " << ComparisonDirectionToString(direction_) << ")"; + return false; + } + return true; +} + bool HloGetTupleElementMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 67488a6a9a0..cf0f4bc912c 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -54,6 +54,21 @@ class HloParameterMatcher : public HloMatcher { int64 parameter_number_; }; +// Custom matcher for comparisons, which accepts a comparison direction. +class HloComparisonMatcher : public HloMatcher { + public: + explicit HloComparisonMatcher( + ComparisonDirection direction, + std::vector<::testing::Matcher> operands) + : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + ComparisonDirection direction_; +}; + // Custom matcher for get-tuple-element instructions, which accepts a tuple // index to match. class HloGetTupleElementMatcher : public HloMatcher { @@ -172,6 +187,7 @@ HLO_MATCHER(BatchNormGrad); HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); +HLO_MATCHER(Compare); HLO_MATCHER(Concatenate); HLO_MATCHER(Conditional); HLO_MATCHER(Constant); @@ -184,31 +200,27 @@ HLO_MATCHER(Divide); HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); -HLO_MATCHER(Eq); HLO_MATCHER(Exp); +HLO_MATCHER(Fft); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); -HLO_MATCHER(Ge); HLO_MATCHER(AfterAll); -HLO_MATCHER(Gt); HLO_MATCHER(Iota); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); -HLO_MATCHER(Le); HLO_MATCHER(Log); HLO_MATCHER(And); HLO_MATCHER(Not); HLO_MATCHER(Or); HLO_MATCHER(Xor); -HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); HLO_MATCHER(Minimum); HLO_MATCHER(Multiply); -HLO_MATCHER(Ne); HLO_MATCHER(Negate); HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); +HLO_MATCHER(PartitionId); HLO_MATCHER(Power); HLO_MATCHER(Recv); HLO_MATCHER(RecvDone); @@ -216,6 +228,7 @@ HLO_MATCHER(Reduce); HLO_MATCHER(ReducePrecision); HLO_MATCHER(ReduceWindow); HLO_MATCHER(Remainder); +HLO_MATCHER(ReplicaId); HLO_MATCHER(Reshape); HLO_MATCHER(Reverse); HLO_MATCHER(Rng); @@ -256,6 +269,38 @@ inline ::testing::Matcher Parameter() { new ::xla::testing::HloMatcher(HloOpcode::kParameter, {})); } +// Comparison matchers below do not require any additional arguments. +template +inline ::testing::Matcher Eq(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kEq, {operands...})); +} +template +inline ::testing::Matcher Ne(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kNe, {operands...})); +} +template +inline ::testing::Matcher Ge(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kGe, {operands...})); +} +template +inline ::testing::Matcher Gt(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kGt, {operands...})); +} +template +inline ::testing::Matcher Le(M... 
operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kLe, {operands...})); +} +template +inline ::testing::Matcher Lt(M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher( + ComparisonDirection::kLt, {operands...})); +} + // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th // tuple element of operand, while GetTupleElement(operand) matches any GTE // operation on operand, and GetTupleElement() matches any GTE operation at all. diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 7961aece541..549fc603c70 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -220,5 +220,33 @@ ENTRY DotOperationFusion_TransposeFusion { "rhs_contracting_dimensions (got {0} want {1})"); } +TEST(HloMatchersTest, ComparisonMatcher) { + auto shape = ShapeUtil::MakeShape(F32, {1}); + auto p0 = HloInstruction::CreateParameter(0, shape, "param.0"); + auto p1 = HloInstruction::CreateParameter(1, shape, "param.1"); + auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kEq); + auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kNe); + auto add = + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get()); + auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(), + ComparisonDirection::kLe); + + EXPECT_THAT(eq.get(), op::Compare()); + EXPECT_THAT(eq.get(), op::Eq()); + EXPECT_THAT(ne.get(), op::Compare()); + EXPECT_THAT(ne.get(), op::Ne()); + EXPECT_THAT(le.get(), + op::Compare(op::Parameter(0), + op::Add(op::Parameter(0), op::Parameter(1)))); + EXPECT_THAT(le.get(), op::Le(op::Parameter(0), + op::Add(op::Parameter(0), op::Parameter(1)))); + + EXPECT_THAT(Explain(eq.get(), op::Add()), Eq("")); + EXPECT_THAT(Explain(eq.get(), op::Ne()), + Eq("has wrong comparison direction (got EQ, want NE)")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index d2740bcce26..ba3c06981e1 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -126,6 +127,7 @@ class ListScheduler { // Create map containing the number of unscheduled uses (hlo instructions) // of each logical buffer. + unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers()); for (auto* instruction : computation->instructions()) { for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { @@ -205,6 +207,7 @@ class ListScheduler { // than not taking subcomputations into account at all. In the future, we may // improve accounting for subcomputation memory (b/65409243). 
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { + auto instruction = entry.instruction; int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { auto buffer = kv->first; @@ -216,7 +219,7 @@ class ListScheduler { // We only count the memory usage of the largest subcomputation, instead of // adding them all, because subcomputations won't execute in parallel. int64 max_subcomputation_bytes = 0; - for (const auto* c : entry.instruction->called_computations()) { + for (const auto* c : instruction->called_computations()) { auto it = memory_by_computation_.find(c); if (it != memory_by_computation_.end()) { int64 subcomputation_bytes = it->second; @@ -226,10 +229,10 @@ class ListScheduler { } } int64 bytes_defined; + auto opcode = instruction->opcode(); if (max_subcomputation_bytes > 0 && - (entry.instruction->opcode() == HloOpcode::kWhile || - entry.instruction->opcode() == HloOpcode::kCall || - entry.instruction->opcode() == HloOpcode::kConditional)) { + (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kConditional)) { // The output buffer of while/call/conditional is always aliased with the // output buffer of the root instruction in the body. Don't double count. bytes_defined = max_subcomputation_bytes; @@ -459,6 +462,7 @@ StatusOr DFSMemoryScheduler( sequence.push_back(hlo); return Status::OK(); }); + visitor.ReserveVisitStates(computation->instruction_count()); TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder( &visitor, [&extra_users, &total_sizes](const HloInstruction* a, const HloInstruction* b) { @@ -611,11 +615,13 @@ StatusOr HloTrivialScheduler::Run(HloModule* module) { if (!computation->IsFusionComputation()) { HloInstructionSequence& computation_sequence = schedule.GetOrCreateSequence(computation); - TF_RETURN_IF_ERROR(computation->Accept( + FunctionVisitor visitor( [&computation_sequence](HloInstruction* instruction) { computation_sequence.push_back(instruction); return Status::OK(); - })); + }); + visitor.ReserveVisitStates(computation->instruction_count()); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); } } TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 76cc29cbb78..d42ec929c1f 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -98,13 +98,16 @@ class HloMemoryScheduler : public HloModulePass { // specified, then DefaultMemoryScheduler is used. 
HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); + ~HloMemoryScheduler() override = default; + absl::string_view name() const override { return "hlo-memory-scheduler"; } StatusOr Run(HloModule* module) override; private: LogicalBuffer::SizeFunction size_function_; + MemorySchedulerAlgorithm algorithm_; }; diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index bc0d7e2bc00..80a0c444e94 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -147,6 +147,47 @@ ENTRY root { instructions_by_name.at("e"))); } +TEST_F(HloSchedulingTest, HostSendDoneSchedule) { + const char* const module_str = R"( +HloModule module + +ENTRY entry { + %p = f32[1000, 1000] parameter(0) + %token.0 = token[] after-all() + %send = (f32[1000, 1000], token[]) send(%p, %token.0), + channel_id=0, is_host_transfer=true + %n1 = f32[1000, 1000] negate(%p) + %n2 = f32[1000, 1000] negate(%n1) + %n3 = f32[1000, 1000] negate(%n2) + %send-done = token[] send-done(%send), channel_id=0, is_host_transfer=true +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); + + std::unordered_map instructions_by_name; + for (const HloInstruction* instruction : sequence) { + instructions_by_name[instruction->name()] = instruction; + } + + SequentialHloOrdering ordering(schedule); + EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("send-done"), + instructions_by_name.at("n1"))); +} + TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto builder = HloComputation::Builder(TestName()); const auto TUPLE_SIZE = 1; @@ -254,8 +295,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { HloInstruction* zero_vector = cond_builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({0, 0, 0, 0}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_param, + zero_vector, ComparisonDirection::kNe)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); // param - 1 diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index edcda8f9a7b..135e10081ae 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -132,6 +132,7 @@ void HloModule::ReplaceComputations( for (std::unique_ptr& computation : computations_) { for (auto* instruction : computation->instructions()) { switch (instruction->opcode()) { + case HloOpcode::kAllReduce: case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduce: diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 68c18836eb0..cee46fe10a2 100644 --- 
a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -101,6 +102,20 @@ class HloModuleConfig { return intra_op_parallelism_threads_; } + // Checks if this config has a static device assignment. + bool has_static_device_assignment() const { + return static_device_assignment_.has_value(); + } + + // Getter and setter of the compile-time known device assignment. + const DeviceAssignment& static_device_assignment() const { + CHECK(static_device_assignment_.has_value()); + return *static_device_assignment_; + } + void set_static_device_assignment(const DeviceAssignment& device_assignment) { + static_device_assignment_ = device_assignment; + } + private: // If you add new members, be sure to update compilation_cache_key. @@ -117,6 +132,9 @@ class HloModuleConfig { int64 intra_op_parallelism_threads_ = -1; DebugOptions debug_options_; + + // Compile-time known device assignment. + absl::optional static_device_assignment_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index f6e28662049..84988a9ecb3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -86,7 +86,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -125,7 +125,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { loop_var.2 = (s32[], f32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.3 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.3), direction=LT } ENTRY SimpleLoop { constant.4 = s32[] constant(0) @@ -163,7 +163,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -206,7 +206,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { loop_var.2 = (s32[], s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -248,7 +248,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(5) 
- ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } SimpleLoop.body1 { loop_var.3 = (s32[], s32[3]{0}) parameter(0) @@ -263,7 +263,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { loop_var.4 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 constant.4 = s32[] constant(5) - ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT } ENTRY SimpleLoop { constant.5 = s32[] constant(0) @@ -316,7 +316,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { loop_var.2 = (s32[3]{0}, s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 constant.2 = s32[] constant(5) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } SimpleLoop.body1 { loop_var.3 = (s32[], s32[3]{0}) parameter(0) @@ -331,7 +331,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { loop_var.4 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 constant.4 = s32[] constant(5) - ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT } ENTRY SimpleLoop { constant.5 = s32[] constant(0) @@ -383,7 +383,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { cond_param = (s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(0) @@ -418,7 +418,7 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { cond_param = (s32[], s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { p0 = (s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index bc258a77000..ab65bb4685e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -122,8 +123,9 @@ Status HloModuleGroupMetadata::Build() { // Visit the computations in postorder so that the companion information grows // from inner computations to outer ones. 
for (HloModule* module : modules_) { + FunctionVisitor function_visitor(visitor); for (HloComputation* computation : module->MakeComputationPostOrder()) { - TF_RETURN_IF_ERROR(computation->Accept(visitor)); + TF_RETURN_IF_ERROR(computation->Accept(&function_visitor)); } } TF_RETURN_IF_ERROR(VerifyCompanionSets()); @@ -370,8 +372,9 @@ Status HloModuleGroupMetadata::RecordInstructions() { }; for (HloModule* module : modules_) { + FunctionVisitor function_visitor(visitor); for (auto* computation : module->computations()) { - TF_RETURN_IF_ERROR(computation->Accept(visitor)); + TF_RETURN_IF_ERROR(computation->Accept(&function_visitor)); } } VLOG(2) << "Created " << channels_.size() << " channels"; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 91417bd2d9a..9b7f54c5c6f 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -207,6 +207,34 @@ std::vector HloModuleGroupUtil::RootInstructions( return roots; } +string HloModuleGroupUtil::CycleToString(HloInstruction* init_instruction) { + std::vector names; + absl::flat_hash_set seen; + + std::function helper = + [&](HloInstruction* instruction) { + if (seen.find(instruction) != seen.end()) { + if (instruction == init_instruction) { + names.push_back(instruction->name()); + return true; + } + return false; + } + seen.insert(instruction); + for (HloInstruction* predecessor : GlobalPredecessors(instruction)) { + bool init_found = helper(predecessor); + if (init_found) { + names.push_back(instruction->name()); + return true; + } + } + return false; + }; + + helper(init_instruction); + return absl::StrJoin(names, " --> "); +} + Status HloModuleGroupUtil::VisitTopologicalOrder( VisitStates* visit_state, const VisitFunction& visit_function, HloInstruction* root) { @@ -269,22 +297,9 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( // a cycle. Generate an error with the list of instructions in the // cycle. if ((*visit_state)[predecessor] == VisitState::kVisiting) { - string cyclic_instructions; - for (const auto& state : *visit_state) { - if (state.second == VisitState::kVisiting) { - absl::StrAppend(&cyclic_instructions, state.first->ToString(), - "\n"); - } - } - // TODO(b/64305524): Improve the error message to print out the - // instructions in a deterministic order that forms the cycle. return FailedPrecondition( - "Cross-computation cycle detected via communicating nodes. The " - "cycle contains the node %s. The cycle is found among the " - "following nodes. 
Note that the order of the nodes is arbitrary " - "and that the list may include nodes that are not part of the " - "cycle.\n%s", - predecessor->ToString(), cyclic_instructions); + "Cross-computation cycle detected via communicating nodes.\n%s", + CycleToString(predecessor)); } stack.push(predecessor); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index 862666b48c9..d388fe51d0d 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -108,6 +108,8 @@ class HloModuleGroupUtil { HloInstruction* instruction, HloReachabilityMap* reachability_map); private: + string CycleToString(HloInstruction* instruction); + const HloModuleGroupMetadata& metadata_; }; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 548fbb873aa..8f459107b32 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -44,21 +44,8 @@ StatusOr StringToHloOpcode(const string& opcode_name) { return it->second; } -#define CHECK_DEFAULT(property_name, opcode_name) false -#define CHECK_PROPERTY(property_name, opcode_name, value) \ - (value & property_name) -#define RESOLVE(_1, _2, target, ...) target -#define HAS_PROPERTY(property, ...) \ - RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__) - bool HloOpcodeIsComparison(HloOpcode opcode) { - switch (opcode) { -#define CASE_IS_COMPARISON(enum_name, opcode_name, ...) \ - case HloOpcode::enum_name: \ - return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__); - HLO_OPCODE_LIST(CASE_IS_COMPARISON) -#undef CASE_IS_COMPARISON - } + return opcode == HloOpcode::kCompare; } bool HloOpcodeIsVariadic(HloOpcode opcode) { @@ -82,9 +69,4 @@ absl::optional HloOpcodeArity(HloOpcode opcode) { } } -#undef HAS_PROPERTY -#undef RESOLVE -#undef CHECK_DEFAULT -#undef CHECK_PROPERTY - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 3e144c4472f..ecd4eb3cbc0 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -19,8 +19,10 @@ limitations under the License. 
#include #include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,6 +67,7 @@ namespace xla { V(kClamp, "clamp", 3) \ V(kCollectivePermute, "collective-permute", 1) \ V(kClz, "count-leading-zeros", 1) \ + V(kCompare, "compare", 2) \ V(kComplex, "complex", 2) \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ V(kConditional, "conditional", kHloOpcodeIsVariadic) \ @@ -79,38 +82,34 @@ namespace xla { V(kDot, "dot", 2) \ V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \ V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \ - V(kEq, "equal-to", 2, kHloOpcodeIsComparison) \ V(kExp, "exponential", 1) \ V(kExpm1, "exponential-minus-one", 1) \ V(kFft, "fft", 1) \ V(kFloor, "floor", 1) \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather", 2) \ - V(kGe, "greater-than-or-equal-to", 2, kHloOpcodeIsComparison) \ V(kGetDimensionSize, "get-dimension-size", 1) \ V(kGetTupleElement, "get-tuple-element", 1) \ - V(kGt, "greater-than", 2, kHloOpcodeIsComparison) \ V(kImag, "imag", 1) \ V(kInfeed, "infeed", 1) \ V(kIota, "iota", 0) \ V(kIsFinite, "is-finite", 1) \ - V(kLe, "less-than-or-equal-to", 2, kHloOpcodeIsComparison) \ V(kLog, "log", 1) \ V(kLog1p, "log-plus-one", 1) \ V(kAnd, "and", 2) \ V(kNot, "not", 1) \ V(kOr, "or", 2) \ V(kXor, "xor", 2) \ - V(kLt, "less-than", 2, kHloOpcodeIsComparison) \ V(kMap, "map", kHloOpcodeIsVariadic) \ V(kMaximum, "maximum", 2) \ V(kMinimum, "minimum", 2) \ V(kMultiply, "multiply", 2) \ - V(kNe, "not-equal-to", 2, kHloOpcodeIsComparison) \ V(kNegate, "negate", 1) \ V(kOutfeed, "outfeed", 2) \ V(kPad, "pad", 2) \ V(kParameter, "parameter", 0) \ + V(kPartitionId, "partition-id", 0) \ + V(kPopulationCount, "popcnt", 1) \ V(kPower, "power", 2) \ V(kReal, "real", 1) \ V(kRecv, "recv", 1) \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 910cc25a591..136e6702b21 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -42,12 +42,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { // Test some properties. switch (opcode) { - case HloOpcode::kEq: - case HloOpcode::kNe: - case HloOpcode::kGt: - case HloOpcode::kLt: - case HloOpcode::kGe: - case HloOpcode::kLe: + case HloOpcode::kCompare: EXPECT_TRUE(HloOpcodeIsComparison(opcode)); break; default: diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 831771fe63b..a4804a8faef 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -190,17 +190,30 @@ bool HloOrdering::UseIsBeforeValueDefinition( } // The use at a while is an input to a phi, and logically occurs before values - // are defined in the body or condition computations. + // are defined in the body. Note that the use is *not* before the value if the + // value is defined in the condition and is not the condition parameter, since + // the input of a while's life range is only ended at the start the body. 
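Illustrative aside, not part of the patch: the hunk that follows implements the rule described in the comment above. A condensed restatement as a standalone predicate is sketched below; IsNestedInBody and IsNestedInCondition are hypothetical stand-ins for the call_graph_->InstructionIsNestedIn(...) checks used in the real code, and cases outside the while special-casing are simplified to false here.

// Sketch only. Returns whether a use at 'xla_while' occurs before the
// definition of 'value': definitions in the while body come after the use,
// while definitions in the condition do so only if they are the condition's
// parameter (which aliases the while input). Other cases are handled by the
// remaining checks in the real function and are simplified to false here.
bool UseAtWhileIsBeforeValueDefinition(const HloInstruction* xla_while,
                                       const HloValue& value) {
  const HloInstruction* def = value.defining_instruction();
  if (IsNestedInBody(xla_while, def)) {
    return true;  // The while input feeds the body's phi.
  }
  if (IsNestedInCondition(xla_while, def)) {
    return def == xla_while->while_condition()->parameter_instruction(0);
  }
  return false;
}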
if (use.instruction->opcode() == HloOpcode::kWhile) { const HloInstruction* xla_while = use.instruction; if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), - xla_while->while_body()) || - call_graph_->InstructionIsNestedIn(value.defining_instruction(), - xla_while->while_condition())) { + xla_while->while_body())) { VLOG(4) << " use is while " << use.instruction->name() - << " and def is in condition or body"; + << " and def is in body"; return true; } + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_condition())) { + if (value.defining_instruction() != + xla_while->while_condition()->parameter_instruction(0)) { + VLOG(4) << " use is while " << use.instruction->name() + << " and def is in condition and is not the parameter"; + return false; + } else { + VLOG(4) << " use is while " << use.instruction->name() + << " and def is in condition and is the parameter"; + return true; + } + } } // Similarly if the value is defined at a while, it logically occurs after any @@ -263,10 +276,23 @@ bool HloOrdering::LiveRangeStrictlyBefore( } if (a.live_out_of_module()) { - VLOG(4) << a << " is live out of module and defined before " << b; + VLOG(4) << a << " is live out of module and not defined before " << b; return false; } + // If the root instruction aliases the buffer 'a', the live range of 'a' is + // until the end of the computation and can never be strictly before another + // buffer nested in the same computation. This is needed to prevent the root + // instruction's buffers from being reused by later instructions even when + // the root is not the last instruction in the schedule. + for (const HloPosition& pos : a.positions()) { + if (pos.instruction->parent()->root_instruction() == pos.instruction && + call_graph().InstructionIsNestedIn(b.instruction(), + pos.instruction->parent())) { + return false; + } + } + // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 3ca77e60cd5..11408114ab1 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -247,6 +247,11 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( dataflow->GetValueDefinedAt(constant), dataflow->GetValueDefinedAt(xla_while), *dataflow)); + // Value defined as init of while interferes with instructions in the + // condition other than the parameter. 
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(convert), *dataflow)); EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant), dataflow->GetValueDefinedAt(xla_while), *dataflow)); @@ -261,8 +266,10 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate), dataflow->GetValueDefinedAt(xla_while), *dataflow)); - - EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert), + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(xla_while), + *dataflow)); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant), dataflow->GetValueDefinedAt(xla_while))); EXPECT_TRUE(ordering.LiveRangeStrictlyBefore( dataflow->GetValueDefinedAt(convert), @@ -306,7 +313,7 @@ condition.v4 { constant.2 = s32[] constant(2) prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0 - ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8) + ROOT greater-than = pred[] compare(constant.2, get-tuple-element.8), direction=GT } fused_computation { @@ -496,5 +503,36 @@ TEST_F(HloOrderingTest, *dataflow)); } +TEST_F(HloOrderingTest, InterferenceWithOuterRoot) { + absl::string_view hlo_string = R"( +HloModule InterferenceWithOuterRoot, is_scheduled=true + +Emmbedded (embedded_param: f32[42]) -> f32[42] { + embedded_param = f32[42]{0} parameter(0) + multiply = f32[42]{0} multiply(embedded_param, embedded_param) + ROOT log = f32[42]{0} log(multiply) +} + +ENTRY InterferenceWithOuterRoot { + param = f32[4096,4096]{1,0} parameter(0) + ROOT add = f32[4096,4096]{1,0} add(param, param) + call = f32[42]{0} call(param), to_apply=Emmbedded +} + +)"; + HloModuleConfig hlo_config; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string, hlo_config)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + DependencyHloOrdering ordering(module.get()); + auto multiply = FindInstruction(module.get(), "multiply"); + auto add = FindInstruction(module.get(), "add"); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(multiply), + dataflow->GetValueDefinedAt(add), + *dataflow)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index fd55d92c04a..f32346806af 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/protobuf.h" @@ -88,6 +89,7 @@ class HloParser { StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); StatusOr ParsePaddingConfigOnly(); + StatusOr> ParseReplicaGroupsOnly(); private: using InstrNameTable = @@ -183,6 +185,7 @@ class HloParser { kHloComputation, kBracedHloComputationList, kFftType, + kComparisonDirection, kWindow, kConvolutionDimensionNumbers, kSharding, @@ -267,6 +270,7 @@ class HloParser { bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); bool ParseParameterReplication(ParameterReplication* parameter_replication); + bool ParseReplicaGroupsOnly(std::vector* replica_groups); // Parses the metadata behind a kDOmain instruction. bool ParseDomain(DomainData* domain); @@ -283,6 +287,9 @@ class HloParser { bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + bool ParseInt64ListList(const TokKind start, const TokKind end, + const TokKind delim, + std::vector>* result); // 'parse_and_add_item' is an lambda to parse an element in the list and add // the parsed element to the result. It's supposed to capture the result. bool ParseList(const TokKind start, const TokKind end, const TokKind delim, @@ -300,6 +307,7 @@ class HloParser { bool ParseTiles(std::vector* tiles); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); + bool ParseComparisonDirection(ComparisonDirection* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); bool ParsePrecision(PrecisionConfig::Precision* result); @@ -678,6 +686,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional backend_config; attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; + optional> outer_dimension_partitions; + attrs["outer_dimension_partitions"] = {/*required=*/false, + AttrTy::kBracedInt64List, + &outer_dimension_partitions}; HloInstruction* instruction; switch (opcode) { @@ -742,6 +754,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: case HloOpcode::kReal: case HloOpcode::kRsqrt: case HloOpcode::kSign: @@ -763,12 +776,6 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kSubtract: case HloOpcode::kAtan2: case HloOpcode::kComplex: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -888,6 +895,15 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, instruction = builder->AddInstruction(HloInstruction::CreateReplicaId()); break; } + case HloOpcode::kPartitionId: { + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreatePartitionId()); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -1133,6 
+1149,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, shape, operands[0], operands[1], options)); break; } + case HloOpcode::kCompare: { + optional direction; + attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection, + &direction}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCompare( + shape, operands[0], operands[1], *direction)); + break; + } case HloOpcode::kCholesky: { CholeskyOptions options; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1444,6 +1472,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (!ParseOperands(&operands)) { return false; } + if (!ShapeUtil::IsScalar(operands[0]->shape())) { + return Error(lexer_.GetLoc(), "The first operand must be a scalar"); + } const bool branch_index_is_bool = operands[0]->shape().element_type() == PRED; if (branch_index_is_bool) { @@ -1452,6 +1483,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, attrs["false_computation"] = { /*required=*/true, AttrTy::kHloComputation, &false_computation}; } else { + if (operands[0]->shape().element_type() != S32) { + return Error(lexer_.GetLoc(), + "The first operand must be a scalar of PRED or S32"); + } attrs["branch_computations"] = {/*required=*/true, AttrTy::kBracedHloComputationList, &branch_computations}; @@ -1474,7 +1509,6 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional custom_call_target; - optional opaque; optional window; optional dnums; optional feature_group_count; @@ -1482,7 +1516,6 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; - attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; @@ -1530,11 +1563,11 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( shape, operands, *custom_call_target, *operand_layout_constraints, - opaque.has_value() ? *opaque : "")); + backend_config ? *backend_config : "")); } else { instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( shape, operands, *custom_call_target, - opaque.has_value() ? *opaque : "")); + backend_config ? 
*backend_config : "")); } if (window.has_value()) { instruction->set_window(*window); @@ -1741,6 +1774,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } + if (outer_dimension_partitions) { + instruction->set_outer_dimension_partitions(*outer_dimension_partitions); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1911,6 +1947,25 @@ bool HloParser::ParseParameterReplication( "expected '}' to end parameter_replication attribute"); } +// replica_groups ::='{' int64list_elements '}' +// int64list_elements +// ::= /*empty*/ +// ::= int64list (',' int64list)* +// int64list ::= '{' int64_elements '}' +// int64_elements +// ::= /*empty*/ +// ::= int64_val (',' int64_val)* +bool HloParser::ParseReplicaGroupsOnly( + std::vector* replica_groups) { + std::vector> result; + if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + &result)) { + return false; + } + *replica_groups = CreateReplicaGroups(result); + return true; +} + // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' // 'exit=' exit_sharding '}' bool HloParser::ParseDomain(DomainData* domain) { @@ -2637,7 +2692,7 @@ bool HloParser::ParseAttributeHelper( if (!ParseAttributeName(&name)) { return Error(loc, "error parsing attributes"); } - VLOG(1) << "Parsing attribute " << name; + VLOG(3) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { return Error(loc, StrFormat("attribute %s already exists", name)); } @@ -2728,6 +2783,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kComparisonDirection: { + ComparisonDirection result; + if (!ParseComparisonDirection(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kWindow: { Window result; if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { @@ -2792,17 +2856,8 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kBracedInt64ListList: { std::vector> result; - auto parse_and_add_item = [&]() { - std::vector item; - if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, - TokKind::kComma, &item)) { - return false; - } - result.push_back(item); - return true; - }; - if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, - parse_and_add_item)) { + if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, + TokKind::kComma, &result)) { return false; } static_cast>>*>(attr_out_ptr) @@ -2909,7 +2964,7 @@ bool HloParser::ParseAttributeAsProtoMessageHelper( if (!ParseAttributeName(&name)) { return Error(loc, "error parsing attributes"); } - VLOG(1) << "Parsing attribute " << name; + VLOG(3) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { return Error(loc, StrFormat("attribute %s already exists", name)); } @@ -3311,6 +3366,28 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, return ParseList(start, end, delim, parse_and_add_item); } +// int64listlist ::= start int64list_elements end +// int64list_elements +// ::= /*empty*/ +// ::= int64list (delim int64list)* +// int64list ::= start int64_elements end +// int64_elements +// ::= /*empty*/ +// ::= int64_val (delim int64_val)* +bool HloParser::ParseInt64ListList(const TokKind start, const TokKind end, + const TokKind delim, + std::vector>* result) { + auto parse_and_add_item = [&]() { + std::vector item; + if (!ParseInt64List(start, 
end, delim, &item)) { + return false; + } + result->push_back(item); + return true; + }; + return ParseList(start, end, delim, parse_and_add_item); +} + bool HloParser::ParseList(const TokKind start, const TokKind end, const TokKind delim, const std::function& parse_and_add_item) { @@ -3594,7 +3671,7 @@ bool HloParser::CanBeShape() { } bool HloParser::ParseName(string* result) { - VLOG(1) << "ParseName"; + VLOG(3) << "ParseName"; if (lexer_.GetKind() != TokKind::kIdent && lexer_.GetKind() != TokKind::kName) { return TokenError("expects name"); @@ -3614,7 +3691,7 @@ bool HloParser::ParseAttributeName(string* result) { } bool HloParser::ParseString(string* result) { - VLOG(1) << "ParseString"; + VLOG(3) << "ParseString"; if (lexer_.GetKind() != TokKind::kString) { return TokenError("expects string"); } @@ -3728,7 +3805,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { } bool HloParser::ParseOpcode(HloOpcode* result) { - VLOG(1) << "ParseOpcode"; + VLOG(3) << "ParseOpcode"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects opcode"); } @@ -3744,7 +3821,7 @@ bool HloParser::ParseOpcode(HloOpcode* result) { } bool HloParser::ParseFftType(FftType* result) { - VLOG(1) << "ParseFftType"; + VLOG(3) << "ParseFftType"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects fft type"); } @@ -3756,8 +3833,24 @@ bool HloParser::ParseFftType(FftType* result) { return true; } +bool HloParser::ParseComparisonDirection(ComparisonDirection* result) { + VLOG(1) << "ParseComparisonDirection"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects comparison direction"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToComparisonDirection(val); + if (!status_or_result.ok()) { + return TokenError( + StrFormat("expects comparison direction but sees: %s", val)); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { - VLOG(1) << "ParseFusionKind"; + VLOG(3) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects fusion kind"); } @@ -3774,7 +3867,7 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { } bool HloParser::ParseRandomDistribution(RandomDistribution* result) { - VLOG(1) << "ParseRandomDistribution"; + VLOG(3) << "ParseRandomDistribution"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects random distribution"); } @@ -3791,7 +3884,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { } bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { - VLOG(1) << "ParsePrecision"; + VLOG(3) << "ParsePrecision"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects random distribution"); } @@ -3808,7 +3901,7 @@ bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { } bool HloParser::ParseInt64(int64* result) { - VLOG(1) << "ParseInt64"; + VLOG(3) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); } @@ -3897,7 +3990,7 @@ bool HloParser::ParseBool(bool* result) { } bool HloParser::ParseToken(TokKind kind, const string& msg) { - VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg; + VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg; if (lexer_.GetKind() != kind) { return TokenError(msg); } @@ -3974,6 +4067,18 @@ StatusOr> HloParser::ParseParameterReplicationOnly() { parameter_replication.replicated_at_leaf_buffers().end()); } 
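Illustrative aside, not part of the patch: the member function added below backs the free function ParseReplicaGroupsOnly(absl::string_view) declared in hlo_parser.h later in this diff, and the ParseReplicaGroups test added to hlo_parser_test.cc round-trips it roughly as sketched here. ReplicaGroupsToString is the existing stringifier that test compares against (assumed available through the test's includes); error handling is elided.

// Sketch mirroring the ParseReplicaGroups test: parse just the right-hand
// side of a "replica_groups={...}" attribute and print it back.
#include "tensorflow/compiler/xla/service/hlo_parser.h"

namespace xla {
void ReplicaGroupsRoundTrip() {
  StatusOr<std::vector<ReplicaGroup>> groups =
      ParseReplicaGroupsOnly("{{0,1},{2,3}}");
  if (groups.ok()) {
    // Reproduces the input text, e.g. "{{0,1},{2,3}}".
    string round_tripped = ReplicaGroupsToString(groups.ValueOrDie());
    (void)round_tripped;
  }
}
}  // namespace xla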
+StatusOr> HloParser::ParseReplicaGroupsOnly() { + lexer_.Lex(); + std::vector replica_groups; + if (!ParseReplicaGroupsOnly(&replica_groups)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after replica groups"); + } + return replica_groups; +} + StatusOr HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; @@ -4053,6 +4158,14 @@ bool HloParser::ParseSingleInstruction(HloModule* module) { } } + if (lexer_.GetKind() != TokKind::kEof) { + Error( + lexer_.GetLoc(), + "Syntax error:\nExpected eof after parsing single instruction. Did " + "you mean to write an HLO module and forget the \"HloModule\" header?"); + return false; + } + module->AddEntryComputation(builder.Build()); for (auto& comp : computations_) { module->AddEmbeddedComputation(std::move(comp)); @@ -4094,6 +4207,12 @@ StatusOr> ParseParameterReplication(absl::string_view str) { return parser.ParseParameterReplicationOnly(); } +StatusOr> ParseReplicaGroupsOnly( + absl::string_view str) { + HloParser parser(str); + return parser.ParseReplicaGroupsOnly(); +} + StatusOr ParseWindow(absl::string_view str) { HloParser parser(str); return parser.ParseWindowOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index a96260b4d75..b18b03ff083 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -68,6 +68,12 @@ StatusOr ParsePaddingConfig(absl::string_view str); // Parses and returns a Shape::ToString-format string. StatusOr ParseShape(absl::string_view str); +// Parses and returns a std::vector from str. str is supposed to +// contain a list of the replica groups, i.e. just the rhs of the +// "replica_groups={...}" attribute string, e.g., "{{0,1}, {2,3}}". +StatusOr> ParseReplicaGroupsOnly( + absl::string_view str); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 1ba2d718ecc..9fc3af7254f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -222,7 +223,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} - %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated} + %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated} ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} } @@ -292,7 +293,7 @@ R"(HloModule WhileWithScalarS32Result_module %condition.v3 (prev.2: s32[]) -> pred[] { %constant.1 = s32[] constant(5) %prev.2 = s32[] parameter(0) - ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2) + ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT } ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { @@ -361,6 +362,18 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1 } +)" +}, +// CustomCall with backend_config. +{ +"CustomCallWithOpaque", +R"(HloModule custom_call + +ENTRY %CustomCall () -> f32[1,2,3] { + %constant = f32[1]{0} constant({12345}) + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque" +} + )" }, // reduce window @@ -474,7 +487,7 @@ R"(HloModule R4F32OverlapSmall_module %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) + ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE } %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { @@ -500,7 +513,7 @@ R"(HloModule select_and_scatter_scalar %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) + ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE } %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { @@ -1037,7 +1050,7 @@ R"(HloModule TupleReduce max_argmax { value = f32[] parameter(2) prev_max = f32[] parameter(0) - is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + is_next_larger = pred[] compare(value, prev_max), direction=GE max = f32[] select(is_next_larger, value, prev_max) index = s32[] parameter(3) prev_argmax = s32[] parameter(1) @@ -1106,7 +1119,7 @@ R"(HloModule sort compare { p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1126,7 +1139,7 @@ compare { p.1.rhs = s32[] parameter(3) p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1145,7 +1158,7 @@ R"(HloModule sort compare { p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT 
lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1165,7 +1178,7 @@ compare { p.1.rhs = s32[] parameter(3) p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1190,7 +1203,7 @@ compare { p.3.rhs = f32[] parameter(7) p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1211,7 +1224,7 @@ R"(HloModule sort compare { p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1285,18 +1298,6 @@ ENTRY CustomCall { ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar" } -)" -}, -// CustomCall with opaque value. -{ -"CustomCallWithOpaque", -R"(HloModule custom_call - -ENTRY CustomCall { - constant = f32[1]{0} constant({12345}) - ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque" -} - )" }, // Variables with non-default names @@ -1435,6 +1436,17 @@ ENTRY Replica-id { ROOT replica-id = u32[] replica-id() } +)" +}, +// partition-id +{ +"PartitionId", +R"(HloModule partition-id + +ENTRY PartitionId { + ROOT id = u32[] partition-id() +} + )" }, // Iota @@ -1469,7 +1481,7 @@ compare { p.1.rhs = s32[] parameter(3) p.0.lhs = f32[] parameter(0) p.0.rhs = f32[] parameter(1) - ROOT lhs = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY Sort { @@ -1533,6 +1545,29 @@ ENTRY MinMaxValues { ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)}) } +)" +}, + +// Bitcast-convert usage +{ +"BitcastConvert", +R"(HloModule BitcastConvert + +ENTRY BitcastConvertUsage { + p = f32[100]{0} parameter(0) + ROOT out = u32[100]{0} bitcast-convert(p) +} + +)" +}, +{ +"OuterDimensionPartitions", +R"(HloModule OuterDimensionPartitions + +ENTRY Test { + ROOT foo = f32[100]{0} parameter(0), outer_dimension_partitions={0,10,20} +} + )" }, }); @@ -1656,7 +1691,7 @@ TEST_F(HloParserTest, WrongOperandsSize) { ENTRY %blabla (x: f32[]) -> pred[] { %x = f32[]{} parameter(0) - %eq = pred[]{} equal-to(f32[]{} %x) + %eq = pred[]{} compare(f32[]{} %x), direction=EQ } )"; @@ -1668,7 +1703,7 @@ TEST_F(HloParserTest, OperandNotFound) { const string original = R"(HloModule operand_not_found: ENTRY %blabla (x: f32[]) -> pred[] { %x = f32[]{} parameter(0) - %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) + %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ } )"; auto result = ParseHloString(original); @@ -2263,6 +2298,13 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); } +TEST_F(HloParserTest, ParseReplicaGroups) { + const string original = "{{0,1},{2,3}}"; + TF_ASSERT_OK_AND_ASSIGN(std::vector replica_groups, + ParseReplicaGroupsOnly(original)); + EXPECT_EQ(original, ReplicaGroupsToString(replica_groups)); +} + TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) { const string original = "0_1x2_3"; TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original)); @@ -2470,6 +2512,16 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { EXPECT_EQ(convolution->feature_group_count(), 1); } +TEST(HloParserSingleOpTest, 
MultipleOpsProducesError) { + const string text = R"( + param = f32[2,5,1,3] parameter(0) + transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3} + )"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("Expected eof")); +} + TEST_F(HloParserTest, IsScheduledIsFalse) { const string text = R"( HloModule axpy_module, is_scheduled=false @@ -2832,5 +2884,89 @@ TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) { "parameter_replication has 3 elements")); } +TEST_F(HloParserTest, CheckIndexedConditionalDimension) { + const char* const hlo_string = R"( + HloModule Module + + branch0 { + tparam = f32[4] parameter(0) + ROOT tgte1 = f32[4] ceil(tparam) + } + + branch1 { + fparam = f32[4] parameter(0) + ROOT fgte1 = f32[4] floor(fparam) + } + + ENTRY entry { + p0 = f32[4] parameter(0) + b0 = s32[2] parameter(1) + ROOT conditional = f32[4] conditional(b0, p0, p0), + branch_computations={branch0, branch1} + } + )"; + auto result = ParseHloString(hlo_string); + EXPECT_NE(Status::OK(), result.status()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("The first operand must be a scalar")); +} + +TEST_F(HloParserTest, CheckIndexedConditionalElementType) { + const char* const hlo_string = R"( + HloModule Module + + branch0 { + tparam = f32[4] parameter(0) + ROOT tgte1 = f32[4] ceil(tparam) + } + + branch1 { + fparam = f32[4] parameter(0) + ROOT fgte1 = f32[4] floor(fparam) + } + + ENTRY entry { + p0 = f32[4] parameter(0) + b0 = f32[] parameter(1) + ROOT conditional = f32[4] conditional(b0, p0, p0), + branch_computations={branch0, branch1} + } + )"; + auto result = ParseHloString(hlo_string); + EXPECT_NE(Status::OK(), result.status()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr( + "The first operand must be a scalar of PRED or S32")); +} + +TEST_F(HloParserTest, + CheckPredicatedConditionalRequiresTrueAndFalseComputation) { + const char* const hlo_string = R"( + HloModule Module + + branch0 { + tparam = f32[4] parameter(0) + ROOT tgte1 = f32[4] ceil(tparam) + } + + branch1 { + fparam = f32[4] parameter(0) + ROOT fgte1 = f32[4] floor(fparam) + } + + ENTRY entry { + p0 = f32[4] parameter(0) + b0 = pred[] parameter(1) + ROOT conditional = f32[4] conditional(b0, p0, p0), + branch_computations={branch0, branch1} + } + )"; + auto result = ParseHloString(hlo_string); + EXPECT_NE(Status::OK(), result.status()); + EXPECT_THAT( + result.status().error_message(), + ::testing::HasSubstr("unexpected attribute \"branch_computations\"")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index fdaac34386c..b793f6f7276 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -41,6 +41,8 @@ class HloPassInterface { // module group. Ideally, the module group variant would be named "Run" as // well, but C++ does not handle overloaded virtual methods well. virtual StatusOr RunOnModuleGroup(HloModuleGroup* module_group) = 0; + + virtual bool IsPassPipeline() { return false; } }; // Base class for passes which are module-scoped. @@ -56,6 +58,14 @@ class HloModulePass : public HloPassInterface { } return changed; }; + + // Update the layout of a Shape to one that is supported by a given backend. 
+ // One can call this function after modifying the Shape in case that modifying + // the Shape requires changes to the layout for the given Backend. + // + // TODO(b/129084868): Make this Backend dependent instead of requiring + // deriving from the pass the and overriding this function. + virtual void UpdateLayout(Shape* shape) {} }; // Base class for passes which are module-group scoped. These passes cannot run diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index ae8c08cf1d1..c1c17155876 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -57,14 +58,21 @@ StatusOr HloPassPipeline::RunPassesInternal( TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name)); bool changed = false; for (HloPassInterface* pass : passes) { - VLOG(1) << " HLO pass " << pass->name(); + absl::string_view pass_name = pass->name(); + VLOG(1) << " HLO pass " << pass_name; MaybeDumpHlo(*hlo, /*after_pass_name=*/last_pass_name, - /*before_pass_name=*/pass->name()); + /*before_pass_name=*/pass_name); + if (!pass->IsPassPipeline()) { + compilation_stats_->StartPass(pass_name); + } TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo)); changed |= pass_changed; - TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name())); - last_pass_name = string(pass->name()); + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name)); + last_pass_name = string(pass_name); + if (!pass->IsPassPipeline()) { + compilation_stats_->EndPass(pass_name); + } } MaybeDumpHlo(*hlo, /*after_pass_name=*/last_pass_name, @@ -99,30 +107,8 @@ std::vector HloPassPipeline::GetEnabledPasses( void HloPassPipeline::MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name, absl::string_view before_pass_name) { - const string& proto_dump_path = - module.config().debug_options().xla_dump_per_pass_hlo_proto_to(); - if (!proto_dump_path.empty()) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static auto* const module_id_to_pass_number = - new absl::flat_hash_map(); - - tensorflow::mutex_lock lock(mu); - const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; - - const string filename = SanitizeFileName( - absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), - pass_number, name(), after_pass_name)); - - TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory( - MakeHloProto(module), proto_dump_path, filename)); - } - - const string message = - absl::StrCat("after ", after_pass_name, ", before ", before_pass_name); - hlo_graph_dumper::MaybeDumpHloModule(module, message); - VLOG(3) << "HLO " << message << ":"; - VLOG(3) << module.entry_computation_layout().ToString(); - XLA_VLOG_LINES(3, module.ToString()); + DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name, + module); } void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group, diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 60d72b9d296..ad4070e3e23 100644 --- 
a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/compilation_stats.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,7 +35,14 @@ namespace xla { // Pipeline of HLO passes. class HloPassPipeline : public HloPassInterface { public: - explicit HloPassPipeline(const string& name) : name_(name) {} + explicit HloPassPipeline(const string& name, + CompilationStats* compilation_stats = nullptr) + : name_(name), compilation_stats_(compilation_stats) { + if (compilation_stats == nullptr) { + empty_compilation_stats_ = CompilationStats::MakeNoopStats(); + compilation_stats_ = empty_compilation_stats_.get(); + } + } absl::string_view name() const override { return name_; } // Add a pass to the pipeline. It should be called with the arguments for the @@ -65,6 +73,8 @@ class HloPassPipeline : public HloPassInterface { StatusOr Run(HloModule* module) override; StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override; + bool IsPassPipeline() override { return true; } + private: // Returns the set of passes which are enabled. DebugOptions can selectively // disable passes via --xla_disable_hlo_passes flag. @@ -105,6 +115,11 @@ class HloPassPipeline : public HloPassInterface { std::vector> passes_; std::vector> invariant_checkers_; bool run_called_ = false; + + CompilationStats* compilation_stats_; + // Default stats instance for when one is not passed in the constructor. + // Use via compilation_stats_, not directly. + std::unique_ptr empty_compilation_stats_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index b7f507b1184..af07eb83a5c 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -54,7 +54,9 @@ void HloReachabilityMap::SetReachabilityToUnionHelper( } bit_vector->Set(GetIndex(instruction)); for (const HloInstruction* input : inputs) { - bit_vector->OrWith(GetBitVector(input)); + if (input != instruction) { + bit_vector->OrWith(GetBitVector(input)); + } } } diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc new file mode 100644 index 00000000000..e11d3920f95 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc @@ -0,0 +1,299 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +namespace { + +// Determines whether an HLO instruction is replicated at index based on current +// knowledge in hlo_replication. +bool DetermineHloInstructionIsReplicated( + const HloInstruction* hlo, const ShapeIndex& index, + const absl::flat_hash_map>& + hlo_replication) { + if (hlo->HasSideEffectNoRecurse()) { + return false; + } + if (hlo->opcode() == HloOpcode::kReplicaId) { + return false; + } + auto it = hlo_replication.find(hlo); + if (hlo->opcode() == HloOpcode::kParameter) { + // Parameters should have been processed. + return it != hlo_replication.end() && it->second.element(index); + } + if (it != hlo_replication.end() && !it->second.element(index)) { + // The HLO is already marked as non-replicated. + return false; + } + if (hlo->opcode() == HloOpcode::kConstant) { + return true; + } + if (hlo->opcode() == HloOpcode::kAllReduce) { + // Only all-reduce across all cores are replicated, which means there + // is only one subgroup. + return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1; + } + + if (hlo->IsElementwise() || // + hlo->opcode() == HloOpcode::kConcatenate || // + hlo->opcode() == HloOpcode::kConvolution || // + hlo->opcode() == HloOpcode::kDot || // + hlo->opcode() == HloOpcode::kReduce || // + hlo->opcode() == HloOpcode::kBroadcast || // + hlo->opcode() == HloOpcode::kTranspose || // + hlo->opcode() == HloOpcode::kReshape || // + hlo->opcode() == HloOpcode::kBitcast || // + hlo->opcode() == HloOpcode::kReverse || // + hlo->opcode() == HloOpcode::kGather || // + hlo->opcode() == HloOpcode::kScatter || // + hlo->opcode() == HloOpcode::kIota || // + hlo->opcode() == HloOpcode::kPad || // + hlo->opcode() == HloOpcode::kSlice || // + hlo->opcode() == HloOpcode::kDynamicSlice || // + hlo->opcode() == HloOpcode::kDynamicUpdateSlice || // + hlo->opcode() == HloOpcode::kReduceWindow || // + hlo->opcode() == HloOpcode::kCopy) { + for (auto operand : hlo->operands()) { + auto operand_it = hlo_replication.find(operand); + if (operand_it == hlo_replication.end() || + !operand_it->second.element({})) { + return false; + } + } + return true; + } + return false; +} + +} // namespace + +bool HloReplicationAnalysis::ComputeHloReplicationOnComputation( + const HloComputation* computation, bool mark_everything_not_replicated) { + bool changed = false; + for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { + // Assigns the shape tree to dest if dest doesn't have one yet, or combines + // it with the existing one by and'ing them. Returns if anything is updated. 
+ auto assign_or_combine_shapetree = [&](ShapeTree&& to_combine, + const HloInstruction* dest) { + auto it = hlo_replication_.find(dest); + if (it == hlo_replication_.end()) { + hlo_replication_[dest] = std::move(to_combine); + return true; + } + bool updated = false; + it->second.ForEachMutableElement( + [&](const ShapeIndex& index, bool* element) { + if (*element && !to_combine.element(index)) { + *element = false; + updated = true; + } + }); + return updated; + }; + // Assigns or combines source's shape tree to dest. Returns if anything is + // updated. + auto propagate_shapetree = [&](const HloInstruction* source, + const HloInstruction* dest) { + auto source_it = hlo_replication_.find(source); + if (source_it == hlo_replication_.end()) { + return false; + } + return assign_or_combine_shapetree(ShapeTree(source_it->second), + dest); + }; + // For the opcodes below that we do special handling, we don't need to + // explicitly check mark_everything_not_replicated because if it is set, the + // operands should already be marked as not replicated. + if (inst->opcode() == HloOpcode::kWhile) { + // Since while body's input and output alias each other, we need to run it + // multiple times until a fixed point is reached. + while (true) { + // First, propagate the input's and body root's shape trees to the + // parameters of the body and condition. + bool updated = propagate_shapetree( + inst->operand(0), + inst->while_condition()->parameter_instruction(0)); + updated |= propagate_shapetree( + inst->while_body()->root_instruction(), + inst->while_condition()->parameter_instruction(0)); + updated |= propagate_shapetree( + inst->operand(0), inst->while_body()->parameter_instruction(0)); + updated |= + propagate_shapetree(inst->while_body()->root_instruction(), + inst->while_body()->parameter_instruction(0)); + // Compute the condition. + updated |= ComputeHloReplicationOnComputation( + inst->while_condition(), mark_everything_not_replicated); + // Compute the body. If the condition is not replicated, the while body + // should be different across replicas. + if (!ContainsKey(loops_known_with_same_iterations_, inst) && + !hlo_replication_[inst->while_condition()->root_instruction()] + .element({})) { + updated |= ComputeHloReplicationOnComputation( + inst->while_body(), /*mark_everything_not_replicated=*/true); + } else { + updated |= ComputeHloReplicationOnComputation( + inst->while_body(), mark_everything_not_replicated); + } + if (!updated) { + break; + } + changed = true; + } + // Propagate the input's and body root's shape trees to the while HLO. + changed |= propagate_shapetree(inst->operand(0), inst); + changed |= + propagate_shapetree(inst->while_body()->root_instruction(), inst); + } else if (inst->opcode() == HloOpcode::kCall || + inst->opcode() == HloOpcode::kFusion) { + auto called = inst->called_computations().front(); + for (int64 i = 0; i < inst->operand_count(); ++i) { + changed |= propagate_shapetree(inst->operand(i), + called->parameter_instruction(i)); + } + changed |= ComputeHloReplicationOnComputation( + called, mark_everything_not_replicated); + changed |= propagate_shapetree(called->root_instruction(), inst); + } else if (inst->opcode() == HloOpcode::kConditional) { + // Propagate inputs' shape trees to the called computations' parameters. 
+ for (int64 i = 0; i < inst->called_computations().size(); ++i) { + changed |= propagate_shapetree( + inst->operand(i + 1), + inst->called_computations()[i]->parameter_instruction(0)); + } + // If the condition is not replicated, the conditional result should be + // different across replicas. + if (!hlo_replication_[inst->operand(0)].element({})) { + for (auto called : inst->called_computations()) { + changed |= ComputeHloReplicationOnComputation( + called, + /*mark_everything_not_replicated=*/true); + } + changed |= assign_or_combine_shapetree( + ShapeTree(inst->shape(), false), inst); + } else { + for (auto called : inst->called_computations()) { + changed |= ComputeHloReplicationOnComputation( + called, mark_everything_not_replicated); + changed |= propagate_shapetree(called->root_instruction(), inst); + } + } + } else if (inst->opcode() == HloOpcode::kTupleSelect) { + if (!hlo_replication_[inst->operand(0)].element({})) { + // The predicate is not replicated, so the result is different across + // replicas. + changed |= assign_or_combine_shapetree( + ShapeTree(inst->shape(), false), inst); + } else { + changed |= propagate_shapetree(inst->operand(1), inst); + changed |= propagate_shapetree(inst->operand(2), inst); + } + } else if (inst->opcode() == HloOpcode::kTuple) { + ShapeTree shape_tree(inst->shape(), true); + for (int64 i = 0; i < inst->operand_count(); ++i) { + shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(i)], {}, {i}); + } + changed |= assign_or_combine_shapetree(std::move(shape_tree), inst); + } else if (inst->opcode() == HloOpcode::kGetTupleElement) { + ShapeTree shape_tree(inst->shape(), true); + shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(0)], + {inst->tuple_index()}, {}); + changed |= assign_or_combine_shapetree(std::move(shape_tree), inst); + } else { + if (mark_everything_not_replicated) { + changed |= assign_or_combine_shapetree( + ShapeTree(inst->shape(), false), inst); + } else { + ShapeTree shape_tree(inst->shape(), true); + ShapeUtil::ForEachSubshape( + inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + *shape_tree.mutable_element(index) = + DetermineHloInstructionIsReplicated(inst, index, + hlo_replication_); + return Status::OK(); + }); + changed |= assign_or_combine_shapetree(std::move(shape_tree), inst); + } + } + } + return changed; +} + +void HloReplicationAnalysis::ComputeHloReplication() { + // Add entry parameters to the above sets according to user annotation. 
+ auto entry = module_->entry_computation(); + for (int i = 0; i < entry->num_parameters(); ++i) { + auto param = entry->parameter_instruction(i); + ShapeTree shape_tree(param->shape(), false); + const auto& replication = param->parameter_replicated_at_leaf_buffers(); + int leaf_index = 0; + ShapeUtil::ForEachSubshape( + param->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(param->shape(), index)) { + return Status::OK(); + } + if (replication && replication->at(leaf_index)) { + *shape_tree.mutable_element(index) = true; + } + ++leaf_index; + return Status::OK(); + }); + hlo_replication_[param] = std::move(shape_tree); + } + ComputeHloReplicationOnComputation(entry, + /*mark_everything_not_replicated=*/false); +} + +bool HloReplicationAnalysis::HloInstructionIsReplicatedAt( + const HloInstruction* inst, const ShapeIndex& index) const { + auto it = hlo_replication_.find(inst); + if (it == hlo_replication_.end()) { + return false; + } + return it->second.element(index); +} + +/* static */ StatusOr> +HloReplicationAnalysis::Run(const HloModule* module) { + const absl::flat_hash_set empty; + return Run(module, &empty); +} + +/* static */ StatusOr> +HloReplicationAnalysis::Run(const HloModule* module, + const absl::flat_hash_set* + loops_known_with_same_iterations) { + auto analysis = absl::WrapUnique( + new HloReplicationAnalysis(module, loops_known_with_same_iterations)); + analysis->ComputeHloReplication(); + return analysis; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.h b/tensorflow/compiler/xla/service/hlo_replication_analysis.h new file mode 100644 index 00000000000..3175fc35102 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.h @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// An HLO pass that determines whether each instruction in the module outputs +// the same value across replicas. It propagates sources of replicated values to +// the rest of the module, where sources include cross-replica-sum, annotated +// entry parameters, and constants. +class HloReplicationAnalysis { + public: + // Runs the analysis on module and returns the result or an error. + static StatusOr> Run( + const HloModule* module); + + // Same as above, but the caller can provide additional annotations: a set of + // while loops that are known to have the same iteration counts across + // replicas. 
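Illustrative aside, not part of the patch: this second Run overload and the HloInstructionIsReplicatedAt query declared just below are exercised by the tests added later in this diff roughly as sketched here; 'module' is assumed to be an already-built HloModule* and 'inst' an instruction in it.

// Sketch only: run the analysis and ask whether every replica computes the
// same value for 'inst' at the top-level shape index {}.
bool IsReplicatedAtTopLevel(const HloModule* module,
                            const HloInstruction* inst) {
  StatusOr<std::unique_ptr<HloReplicationAnalysis>> analysis_or =
      HloReplicationAnalysis::Run(module);
  if (!analysis_or.ok()) {
    return false;  // Simplified error handling for the sketch.
  }
  return analysis_or.ValueOrDie()->HloInstructionIsReplicatedAt(inst, {});
}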
+ static StatusOr> Run( + const HloModule* module, const absl::flat_hash_set* + loops_known_with_same_iterations); + + // Returns if the HLO instruction outputs the same value (i.e., replicated) at + // the given index across all replicas. + bool HloInstructionIsReplicatedAt(const HloInstruction* inst, + const ShapeIndex& index) const; + + private: + HloReplicationAnalysis(const HloModule* module, + const absl::flat_hash_set* + loops_known_with_same_iterations) + : module_(module), + loops_known_with_same_iterations_(*loops_known_with_same_iterations) {} + + // Computes hlo_replication_. + void ComputeHloReplication(); + + // A helper function to recursively compute hlo_replication on a computation. + // Returns whether hlo_replication_ is changed. + bool ComputeHloReplicationOnComputation(const HloComputation* computation, + bool mark_everything_not_replicated); + + const HloModule* module_; + + // A set of while loops that are known to have the same iteration counts + // across replicas. This is provided by the caller as additional annotations. + const absl::flat_hash_set& + loops_known_with_same_iterations_; + + // A map from each analyzed HLO instruction to a shape tree that represents + // whether the instruction outputs the same value across replicas at each + // shape index. + absl::flat_hash_map> hlo_replication_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc new file mode 100644 index 00000000000..ea1f5b2b4b8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -0,0 +1,447 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloReplicationAnalysisTest : public HloTestBase {}; + +TEST_F(HloReplicationAnalysisTest, NoControlFlow) { + const string module_str = R"( +HloModule NoControlFlow + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +ENTRY entry { + param = (f32[4096,4096]{1,0}, f32[4096,4096]{1,0}) parameter(0) + get-tuple-element.2 = f32[4096,4096]{1,0} get-tuple-element(param), index=0 + get-tuple-element.3 = f32[4096,4096]{1,0} get-tuple-element(param), index=1 + after-all.1 = token[] after-all() + infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1) + get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0 + dot = f32[4096,4096]{1,0} dot(get-tuple-element.5, get-tuple-element.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + all-reduce = f32[4096,4096]{1,0} all-reduce(dot), replica_groups={}, to_apply=sum + subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.3, all-reduce) + ROOT add = f32[4096,4096]{1,0} add(get-tuple-element.2, subtract) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{false, true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.2"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.3"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.5"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "dot"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "all-reduce"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "subtract"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "add"), {})); +} + +TEST_F(HloReplicationAnalysisTest, NestedCall) { + const string module_str = R"( +HloModule NestedCall + +fusion_computation { + fusion_p0 = f32[] parameter(0) + fusion_p1 = f32[] parameter(1) + add = f32[] add(fusion_p0, fusion_p0) + multiply = f32[] multiply(add, fusion_p1) + ROOT tuple = (f32[], f32[]) tuple(add, multiply) +} + +call_body { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT fusion = (f32[], f32[]) fusion(a, b), kind=kLoop, calls=fusion_computation +} + +ENTRY entry { + param = (f32[], f32[]) parameter(0) + get-tuple-element = f32[] get-tuple-element(param), index=0 + get-tuple-element.1 = f32[] get-tuple-element(param), index=1 + ROOT call = (f32[], f32[]) call(get-tuple-element, 
get-tuple-element.1), to_apply=call_body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{true, false}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "add"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "multiply"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "fusion"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "fusion"), {1})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "call"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "call"), {1})); +} + +TEST_F(HloReplicationAnalysisTest, SimpleWhileLoop) { + const string module_str = R"( +HloModule SimpleWhileLoop + +cond { + cond_param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + get-tuple-element = u32[] get-tuple-element(cond_param), index=1 + constant.3 = u32[] constant(5) + ROOT greater-than = pred[] compare(get-tuple-element, constant.3), direction=LT +} + +body { + body_param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + get-tuple-element.1 = f32[4096,4096]{1,0} get-tuple-element(body_param), index=0 + multiply = f32[4096,4096]{1,0} multiply(get-tuple-element.1, get-tuple-element.1) + get-tuple-element.6 = u32[] get-tuple-element(body_param), index=1 + constant.1 = u32[] constant(1) + add = u32[] add(get-tuple-element.6, constant.1) + ROOT tuple = (f32[4096,4096]{1,0}, u32[]) tuple(multiply, add) +} + +ENTRY SimpleWhileLoop { + param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + ROOT while = (f32[4096,4096]{1,0}, u32[]) while(param), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{true, true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple"), {0})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple"), {1})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "while"), {0})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "while"), {1})); +} + +TEST_F(HloReplicationAnalysisTest, + WhileLoopParameterAliasingNonReplicatedOutput) { + const string module_str = R"( +HloModule WhileLoopParameterAliasingNonReplicatedOutput + +cond { + cond_param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + get-tuple-element = u32[] get-tuple-element(cond_param), index=1 + constant.3 = u32[] constant(5) + ROOT greater-than = pred[] compare(get-tuple-element, constant.3), direction=LT +} + +body { + body_param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + get-tuple-element.1 = f32[4096,4096]{1,0} get-tuple-element(body_param), index=0 + 
multiply = f32[4096,4096]{1,0} multiply(get-tuple-element.1, get-tuple-element.1) + after-all.1 = token[] after-all() + infeed = (f32[4096,4096]{1,0}, token[]) infeed(after-all.1) + get-tuple-element.5 = f32[4096,4096]{1,0} get-tuple-element(infeed), index=0 + subtract = f32[4096,4096]{1,0} subtract(get-tuple-element.5, multiply) + get-tuple-element.6 = u32[] get-tuple-element(body_param), index=1 + constant.1 = u32[] constant(1) + add = u32[] add(get-tuple-element.6, constant.1) + ROOT tuple = (f32[4096,4096]{1,0}, u32[]) tuple(subtract, add) +} + +ENTRY WhileLoopParameterAliasingNonReplicatedOutput { + param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + ROOT while = (f32[4096,4096]{1,0}, u32[]) while(param), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{true, true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "multiply"), {})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple"), {0})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple"), {1})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "while"), {0})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "while"), {1})); +} + +TEST_F(HloReplicationAnalysisTest, WhileLoopDifferentCondition) { + const string module_str = R"( +HloModule WhileLoopDifferentCondition + +cond { + cond_param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + get-tuple-element = u32[] get-tuple-element(cond_param), index=1 + constant.3 = u32[] constant(5) + ROOT greater-than = pred[] compare(get-tuple-element, constant.3), direction=LT +} + +body { + body_param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + get-tuple-element.1 = f32[4096,4096]{1,0} get-tuple-element(body_param), index=0 + multiply = f32[4096,4096]{1,0} multiply(get-tuple-element.1, get-tuple-element.1) + get-tuple-element.6 = u32[] get-tuple-element(body_param), index=1 + replica-id = u32[] replica-id() + add = u32[] add(get-tuple-element.6, replica-id) + ROOT tuple = (f32[4096,4096]{1,0}, u32[]) tuple(multiply, add) +} + +ENTRY WhileLoopDifferentCondition { + param = (f32[4096,4096]{1,0}, u32[]) parameter(0) + ROOT while = (f32[4096,4096]{1,0}, u32[]) while(param), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{true, true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "while"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "while"), {1})); +} + +TEST_F(HloReplicationAnalysisTest, SimpleConditional) { + const string module_str = R"( +HloModule SimpleConditional + +Negate { + x = (f32[], f32[]) parameter(0) + get-tuple-element = f32[] get-tuple-element(x), index=0 + negate = f32[] negate(get-tuple-element) + get-tuple-element.1 = f32[] get-tuple-element(x), index=1 + negate.1 = f32[] negate(get-tuple-element.1) + 
ROOT tuple = (f32[], f32[]) tuple(negate, negate.1) +} + +Identity { + ROOT y = (f32[], f32[]) parameter(0) +} + +Floor { + z = (f32[], f32[]) parameter(0) + get-tuple-element.2 = f32[] get-tuple-element(z), index=0 + floor = f32[] floor(get-tuple-element.2) + get-tuple-element.3 = f32[] get-tuple-element(z), index=1 + floor.1 = f32[] floor(get-tuple-element.3) + ROOT tuple.1 = (f32[], f32[]) tuple(floor, floor.1) +} + +ENTRY entry { + param = ((f32[], f32[]), (f32[], f32[]), (f32[], f32[]), s32[]) parameter(0) + get-tuple-element.4 = (f32[], f32[]) get-tuple-element(param), index=0 + get-tuple-element.5 = (f32[], f32[]) get-tuple-element(param), index=1 + get-tuple-element.6 = (f32[], f32[]) get-tuple-element(param), index=2 + get-tuple-element.7 = s32[] get-tuple-element(param), index=3 + ROOT conditional = (f32[], f32[]) conditional(get-tuple-element.7, get-tuple-element.4, get-tuple-element.5, get-tuple-element.6), branch_computations={Negate, Identity, Floor} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{true, true, true, true, false, true, true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple"), {0})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple"), {1})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "y"), {0})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "y"), {1})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple.1"), {1})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "conditional"), {0})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "conditional"), {1})); +} + +TEST_F(HloReplicationAnalysisTest, ConditionalWithDifferentPredicates) { + const string module_str = R"( +HloModule ConditionalWithDifferentPredicates + +Negate { + x = (f32[], f32[]) parameter(0) + get-tuple-element = f32[] get-tuple-element(x), index=0 + negate = f32[] negate(get-tuple-element) + get-tuple-element.1 = f32[] get-tuple-element(x), index=1 + negate.1 = f32[] negate(get-tuple-element.1) + ROOT tuple = (f32[], f32[]) tuple(negate, negate.1) +} + +Identity { + ROOT y = (f32[], f32[]) parameter(0) +} + +Floor { + z = (f32[], f32[]) parameter(0) + get-tuple-element.2 = f32[] get-tuple-element(z), index=0 + floor = f32[] floor(get-tuple-element.2) + get-tuple-element.3 = f32[] get-tuple-element(z), index=1 + floor.1 = f32[] floor(get-tuple-element.3) + ROOT tuple.1 = (f32[], f32[]) tuple(floor, floor.1) +} + +ENTRY entry { + param = ((f32[], f32[]), (f32[], f32[]), (f32[], f32[])) parameter(0) + get-tuple-element.4 = (f32[], f32[]) get-tuple-element(param), index=0 + get-tuple-element.5 = (f32[], f32[]) get-tuple-element(param), index=1 + get-tuple-element.6 = (f32[], f32[]) get-tuple-element(param), index=2 + replica-id = u32[] replica-id() + id = s32[] bitcast-convert(replica-id) + ROOT conditional = (f32[], f32[]) conditional(id, get-tuple-element.4, + get-tuple-element.5, get-tuple-element.6), + branch_computations={Negate, 
Identity, Floor} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{true, true, true, true, true, true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple"), {1})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "y"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "y"), {1})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple.1"), {1})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "conditional"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "conditional"), {1})); +} + +TEST_F(HloReplicationAnalysisTest, SimpleTupleSelect) { + const string module_str = R"( +HloModule SimpleTupleSelect + +ENTRY entry { + param = ((f32[], f32[]), (f32[], f32[]), pred[]) parameter(0) + get-tuple-element.4 = (f32[], f32[]) get-tuple-element(param), index=0 + get-tuple-element.5 = (f32[], f32[]) get-tuple-element(param), index=1 + get-tuple-element.6 = pred[] get-tuple-element(param), index=2 + ROOT tuple-select = (f32[], f32[]) tuple-select(get-tuple-element.6, get-tuple-element.4, get-tuple-element.5) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{true, false, true, true, true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple-select"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple-select"), {1})); +} + +TEST_F(HloReplicationAnalysisTest, TupleSelectWithDifferentPredicates) { + const string module_str = R"( +HloModule TupleSelectWithDifferentPredicates + +ENTRY entry { + param = ((f32[], f32[]), (f32[], f32[]), pred[]) parameter(0) + get-tuple-element.4 = (f32[], f32[]) get-tuple-element(param), index=0 + get-tuple-element.5 = (f32[], f32[]) get-tuple-element(param), index=1 + get-tuple-element.6 = pred[] get-tuple-element(param), index=2 + ROOT tuple-select = (f32[], f32[]) tuple-select(get-tuple-element.6, get-tuple-element.4, get-tuple-element.5) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers( + absl::Span{true, true, true, true, false}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run(module.get())); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple-select"), {0})); + EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "tuple-select"), {1})); +} + +} // namespace +} // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 8f44e1b37ee..5ba390acfd4 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -85,9 +84,11 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename, return ParseHloString(hlo_string, config); } -HloRunner::HloRunner(se::Platform* platform) { +HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) { BackendOptions backend_options; backend_options.set_platform(platform); + backend_options.set_intra_op_parallelism_threads( + intra_op_parallelism_threads); backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie(); VLOG(1) << "Created HloRunner for platform: " << platform->Name(); } @@ -108,7 +109,7 @@ StatusOr HloRunner::TransferLiteralToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const absl::Span literals) { + absl::Span literals) { std::vector buffers; for (const Literal* literal : literals) { CHECK(literal != nullptr); @@ -120,7 +121,7 @@ StatusOr> HloRunner::TransferLiteralsToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const absl::Span literals) { + absl::Span literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { @@ -137,10 +138,10 @@ StatusOr HloRunner::TransferLiteralFromDevice( buffer); } -StatusOr HloRunner::Execute( - std::unique_ptr module, - const absl::Span arguments, bool run_hlo_passes, - ExecutionProfile* profile) { +StatusOr HloRunner::Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) { TF_ASSIGN_OR_RETURN(std::vector argument_buffers, TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, @@ -153,7 +154,7 @@ StatusOr HloRunner::Execute( } StatusOr HloRunner::Execute(std::unique_ptr module, - const absl::Span arguments, + absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. @@ -169,10 +170,9 @@ StatusOr HloRunner::Execute(std::unique_ptr module, /*profile=*/profile); } -StatusOr HloRunner::Execute( - std::unique_ptr executable, - const absl::Span arguments, - ExecutionProfile* profile) { +StatusOr HloRunner::Execute(std::unique_ptr executable, + absl::Span arguments, + ExecutionProfile* profile) { TF_ASSIGN_OR_RETURN(std::vector argument_buffers, TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, @@ -184,7 +184,7 @@ StatusOr HloRunner::Execute( } StatusOr HloRunner::Execute(std::unique_ptr executable, - const absl::Span arguments, + absl::Span arguments, ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. 
std::vector argument_pointers; @@ -200,7 +200,7 @@ StatusOr HloRunner::Execute(std::unique_ptr executable, StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const absl::Span arguments, bool run_hlo_passes, + absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { // Get service run options. se::Stream stream(backend().default_stream_executor()); @@ -221,7 +221,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const absl::Span arguments, bool run_hlo_passes, + absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { std::vector argument_pointers; argument_pointers.reserve(arguments.size()); @@ -236,8 +236,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - Executable* executable, - const absl::Span arguments, + Executable* executable, absl::Span arguments, ExecutionProfile* profile) { // Get service run options. se::Stream stream(backend().default_stream_executor()); @@ -255,8 +254,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - Executable* executable, - const absl::Span arguments, + Executable* executable, absl::Span arguments, ExecutionProfile* profile) { std::vector argument_pointers; argument_pointers.reserve(arguments.size()); @@ -271,13 +269,16 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options, - bool use_threads) { + DeviceAssignment* device_assignment) { TF_ASSIGN_OR_RETURN( std::unique_ptr executable, CreateExecutable(std::move(module), options.run_hlo_passes)); - TF_ASSIGN_OR_RETURN( - DeviceAssignment device_assignment, - backend().computation_placer()->AssignDevices(options.num_replicas, 1)); + return ExecuteReplicated(executable.get(), options, device_assignment); +} + +StatusOr> HloRunner::ExecuteReplicated( + Executable* executable, const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment, ExecutionProfile* profile) { std::vector> streams; std::vector service_run_options; @@ -294,13 +295,13 @@ StatusOr> HloRunner::ExecuteReplicated( std::vector> argument_buffer_slices; int64 index = 0; for (int64 i = 0; i < options.num_replicas; ++i) { - int64 device = device_assignment(i, 0); + int64 device = (*device_assignment)(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); streams.push_back(absl::make_unique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( - device, streams.back().get(), &device_assignment)); + device, streams.back().get(), device_assignment)); // Copy arguments to device. 
for (const Literal* argument : options.arguments) { @@ -330,7 +331,7 @@ StatusOr> HloRunner::ExecuteReplicated( } if (options.infeed != nullptr) { for (int64 i = 0; i < options.num_replicas; ++i) { - int64 device = device_assignment(i, 0); + int64 device = (*device_assignment)(i, 0); pool->Schedule([this, device, &options]() { se::StreamExecutor* executor = backend().stream_executor(device).ValueOrDie(); @@ -348,7 +349,7 @@ StatusOr> HloRunner::ExecuteReplicated( } if (ShapeUtil::IsInitialized(options.outfeed_shape)) { for (int64 i = 0; i < options.num_replicas; ++i) { - int64 device = device_assignment(i, 0); + int64 device = (*device_assignment)(i, 0); pool->Schedule([this, device, &options]() { se::StreamExecutor* executor = backend().stream_executor(device).ValueOrDie(); @@ -371,7 +372,7 @@ StatusOr> HloRunner::ExecuteReplicated( LOG(INFO) << "Replicated execution started"; std::vector results; - if (!use_threads) { + if (!options.use_threads) { TF_ASSIGN_OR_RETURN(results, executable->ExecuteOnStreams(service_run_options, argument_buffer_slices)); @@ -416,6 +417,15 @@ StatusOr> HloRunner::ExecuteReplicated( return std::move(exec_results); } +StatusOr> HloRunner::ExecuteReplicated( + std::unique_ptr module, + const ReplicatedExecuteOptions& options) { + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + backend().computation_placer()->AssignDevices(options.num_replicas, 1)); + return ExecuteReplicated(std::move(module), options, &device_assignment); +} + StatusOr> HloRunner::CreateExecutable( std::unique_ptr module, bool run_hlo_passes) { if (run_hlo_passes) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 88a137e6452..7e666a8186e 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -78,9 +78,19 @@ class HloRunner { // saved modules are coming from after the HLO pass pipeline, so triggering // another run will likely cause errors. bool run_hlo_passes = false; + + // If true, executes on multiple threads using se::Stream::ExecuteOnStream. + // Othewise, executes using xla::Executable::ExecuteOnStreams. + bool use_threads = false; }; - explicit HloRunner(se::Platform* platform); + // intra_op_parallelism_threads: For the CPU backend only. It is the thread + // pool size for parallel execution of an individual operator. The default + // value of -1 will result in initializing the thread pool with the number of + // threads equal to the number of + // cores in the system. + explicit HloRunner(se::Platform* platform, + int intra_op_parallelism_threads = -1); ~HloRunner(); @@ -104,9 +114,9 @@ class HloRunner { // Transfers data between the host and device. StatusOr TransferLiteralToDevice(const Literal& literal); StatusOr> TransferLiteralsToDevice( - const absl::Span literals); + absl::Span literals); StatusOr> TransferLiteralsToDevice( - const absl::Span literals); + absl::Span literals); StatusOr TransferLiteralFromDevice(const ShapedBuffer& buffer); // Executes the given module with given literals as input and returns the @@ -115,46 +125,44 @@ class HloRunner { // If run_hlo_passes is false, the module will be executed without Hlo // optimization. 
   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
-                            const absl::Span<const Literal* const> arguments,
+                            absl::Span<const Literal* const> arguments,
                             bool run_hlo_passes = true,
                             ExecutionProfile* profile = nullptr);

   StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
-                            const absl::Span<const Literal> arguments,
+                            absl::Span<const Literal> arguments,
                             bool run_hlo_passes = true,
                             ExecutionProfile* profile = nullptr);

   StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
-                            const absl::Span<const Literal* const> arguments,
+                            absl::Span<const Literal* const> arguments,
                             ExecutionProfile* profile = nullptr);

   StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
-                            const absl::Span<const Literal> arguments,
+                            absl::Span<const Literal> arguments,
                             ExecutionProfile* profile = nullptr);

   // As Execute(), but accepts and returns device buffers instead of host
   // buffers.
   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
       std::unique_ptr<HloModule> module,
-      const absl::Span<const ShapedBuffer* const> arguments,
+      absl::Span<const ShapedBuffer* const> arguments,
       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);

   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
       std::unique_ptr<HloModule> module,
-      const absl::Span<const ShapedBuffer> arguments,
+      absl::Span<const ShapedBuffer> arguments,
       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);

   // In the following two calls, "executable" is not a unique_ptr to allow
   // reuse of the Executable. This call may update the profile information in
   // *executable.
   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      Executable* executable,
-      const absl::Span<const ShapedBuffer* const> arguments,
+      Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
       ExecutionProfile* profile = nullptr);

   StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      Executable* executable,
-      const absl::Span<const ShapedBuffer> arguments,
+      Executable* executable, absl::Span<const ShapedBuffer> arguments,
       ExecutionProfile* profile = nullptr);

   // Creates an executable object given an HLO module. If run_hlo_passes is
@@ -165,13 +173,24 @@ class HloRunner {
   // Executes a given HLO module into a set of replicas, and returns a map
   // with the replica number as key, and the corresponding returned literal as
   // value.
-  //
-  // use_threads indicates whether this replicated computation will be executed
-  // with a thread-per-replica, vs using an implicitly async call such as
-  // Executable::ExecuteOnStreams.
   StatusOr<std::vector<Literal>> ExecuteReplicated(
       std::unique_ptr<HloModule> module,
-      const ReplicatedExecuteOptions& options, bool use_threads = false);
+      const ReplicatedExecuteOptions& options);
+
+  // Same as above, but with specified device assignment.
+  StatusOr<std::vector<Literal>> ExecuteReplicated(
+      std::unique_ptr<HloModule> module,
+      const ReplicatedExecuteOptions& options,
+      DeviceAssignment* device_assignment);
+
+  // Same as above, but with a reusable Executable. This may update the profile
+  // information in *executable.
+  //
+  // Note that this call ignores ReplicatedExecuteOptions::run_hlo_passes,
+  // since we've already compiled the Executable.
+  StatusOr<std::vector<Literal>> ExecuteReplicated(
+      Executable* executable, const ReplicatedExecuteOptions& options,
+      DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);

   // If backend is not created in the constructor, creates and returns the
   // default backend. If creation fails, crashes the program.
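Not part of the patch, but as a reading aid for the declarations above: a minimal sketch of how a caller might drive the new `ExecuteReplicated` overload that takes an explicit `DeviceAssignment`. The wrapper name `RunReplicatedExample` and its parameters are illustrative assumptions; the `HloRunner`, `ReplicatedExecuteOptions`, and `computation_placer()` calls are the ones added or referenced in this change.

```cpp
#include "tensorflow/compiler/xla/service/hlo_runner.h"

namespace se = ::stream_executor;  // as used by hlo_runner.h

// Illustrative sketch only. Runs `module` on `num_replicas` replicas with a
// caller-built (num_replicas x 1) device assignment, mirroring what the
// parameterless ExecuteReplicated overload does internally.
xla::StatusOr<std::vector<xla::Literal>> RunReplicatedExample(
    se::Platform* platform, std::unique_ptr<xla::HloModule> module,
    absl::Span<const xla::Literal* const> arguments, int num_replicas) {
  // Default intra-op parallelism (-1): thread pool sized to the core count.
  xla::HloRunner runner(platform);

  xla::HloRunner::ReplicatedExecuteOptions options;
  options.num_replicas = num_replicas;
  options.use_threads = true;  // thread-per-replica instead of ExecuteOnStreams
  for (const xla::Literal* argument : arguments) {
    options.arguments.push_back(argument);
  }

  // Build the replica-by-computation device assignment explicitly.
  TF_ASSIGN_OR_RETURN(xla::DeviceAssignment device_assignment,
                      runner.backend().computation_placer()->AssignDevices(
                          num_replicas, /*computation_count=*/1));
  return runner.ExecuteReplicated(std::move(module), options,
                                  &device_assignment);
}
```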
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index 0e56e6f760e..ecc8dbe6560 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -228,7 +228,7 @@ HloModule UpdateScheduleWithMultipleComputations %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %WhileLoop () -> s32[] { @@ -297,7 +297,7 @@ HloModule UpdateScheduleWithMultipleComputations %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %WhileLoop () -> s32[] { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index dd57ea83f1c..90a80a4421b 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -207,6 +207,7 @@ class HloSharding { // Returns the flattened list of all the leaf shardings in a tuple shape, by // pre-order walk (ShapeTree iterator order). // REQUIRES: IsTuple(). + std::vector& tuple_elements() { return tuple_elements_; } const std::vector& tuple_elements() const { return tuple_elements_; } diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index c1073911ea9..6c0a1926c41 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -87,17 +86,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", - module->config().debug_options()); - } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", - module->config().debug_options()); - } EXPECT_EQ(2, module->computation_count()); EXPECT_EQ(x->to_apply(), y->to_apply()); } @@ -126,17 +115,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", - module->config().debug_options()); - } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", - module->config().debug_options()); - } EXPECT_EQ(2, module->computation_count()); EXPECT_EQ(x->to_apply(), y->to_apply()); } @@ -166,17 +145,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", - module->config().debug_options()); - } EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); - if (VLOG_IS_ON(1)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", - module->config().debug_options()); - } EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); } diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 218b33b2ac2..ba856fc17af 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -177,12 +177,16 @@ void HloValue::SetPositionsAndComputeUses( // Build vector of HloUses for the value. for (const HloPosition& position : positions_) { for (HloInstruction* user : position.instruction->users()) { - for (int64 operand_number : user->OperandIndices(position.instruction)) { + for (int64 i = 0; i < user->operand_count(); ++i) { + if (user->operand(i) != position.instruction) { + continue; + } + // Root instructions of computations are considered to be uses whether // or not the root instruction itself actually uses the value. - if (MayUseOperandValue(operand_number, position.index, user) || + if (MayUseOperandValue(i, position.index, user) || ContainsKey(root_positions, user)) { - HloUse new_use{user, operand_number, position.index}; + HloUse new_use{user, i, position.index}; // The new use must not already exist in uses_. 
for (const HloUse& use : uses_) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 375ae2c477d..6cbfb784cdc 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/service/hlo_verifier.h" + #include #include "absl/container/flat_hash_map.h" @@ -21,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -201,6 +202,10 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandlePartitionId(HloInstruction* hlo) { + return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); +} + Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) { return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); } @@ -343,7 +348,7 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { // Check that the 'compare' computation returns a PRED. Shape compare_shape = compare->root_instruction()->shape(); - if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "The Sort compare computation shape does not lead to a scalar " "predicate shape: %s", @@ -393,7 +398,8 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return InternalError("Constant is required to have a valid literal: %s", constant->ToString()); } - return CheckShape(constant, constant->literal().shape()); + return CheckShape(constant, constant->literal().shape(), + /*only_compare_minor_to_major_in_layout=*/true); } Status ShapeVerifier::HandleIota(HloInstruction* instruction) { @@ -406,9 +412,10 @@ Status ShapeVerifier::HandleIota(HloInstruction* instruction) { return InternalError("Iota does not support scalars."); } int64 iota_dimension = iota->iota_dimension(); - if (iota_dimension >= rank) { + if (iota_dimension >= rank || iota_dimension < 0) { return InternalError( - "The iota dimension cannot go beyond the operation rank."); + "The iota dimension cannot go beyond the operation rank or be " + "negative."); } return Status::OK(); } @@ -654,7 +661,8 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); - if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapeUtil::Compatible(conditional_shape, + ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", @@ -667,10 +675,23 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + if (!ShapeUtil::IsScalar(conditional->operand(0)->shape())) { + return InvalidArgument( + "The first operand of conditional must be a 
scalar. Got %s", + conditional->operand(0)->shape().DebugString()); + } const int num_branches = conditional->branch_count(); - if (conditional->operand(0)->shape().element_type() == PRED) { + PrimitiveType operand0_type = conditional->operand(0)->shape().element_type(); + if (operand0_type == PRED) { TF_RET_CHECK(num_branches == 2); } else { + if (operand0_type != S32) { + return InvalidArgument( + "The first operand of indexed conditional must be a scalar of S32. " + "Got" + " type %s.", + PrimitiveType_Name(operand0_type)); + } TF_RET_CHECK(num_branches >= 1); } TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1)); @@ -696,7 +717,8 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), - ShapeUtil::MakeTokenShape()})); + ShapeUtil::MakeTokenShape()}), + /*only_compare_minor_to_major_in_layout=*/true); } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { @@ -705,9 +727,11 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { Status ShapeVerifier::HandleRecv(HloInstruction* recv) { return CheckShape( - recv, ShapeUtil::MakeTupleShape( - {ShapeUtil::GetTupleElementShape(recv->shape(), 0), - ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})); + recv, + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(recv->shape(), 0), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}), + /*only_compare_minor_to_major_in_layout=*/true); } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { @@ -844,7 +868,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const Shape& inferred_shape) { + const Shape& inferred_shape, + bool only_compare_minor_to_major_in_layout) { // If allow_mixed_precision_ is false, check if there are operands with // different precisions. We need this check because ShapeInference allows // mixed precision inputs. 
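Not part of the patch: the branch-selector rules that the new `HandleConditional` checks enforce, condensed into one sketch. The helper name `CheckConditionalSelector` is a hypothetical illustration (not a verifier API), and it assumes the usual hlo_verifier.cc includes; the logic mirrors the checks added above.

```cpp
// Sketch: a conditional's first operand must be a scalar. A PRED selector
// chooses between exactly two branches; an S32 selector indexes into one or
// more branch_computations; any other element type is rejected.
Status CheckConditionalSelector(const HloInstruction* conditional) {
  const Shape& selector = conditional->operand(0)->shape();
  if (!ShapeUtil::IsScalar(selector)) {
    return InvalidArgument(
        "The first operand of conditional must be a scalar.");
  }
  const int num_branches = conditional->branch_count();
  if (selector.element_type() == PRED) {
    TF_RET_CHECK(num_branches == 2);  // true_computation / false_computation
  } else if (selector.element_type() == S32) {
    TF_RET_CHECK(num_branches >= 1);  // indexed form
  } else {
    return InvalidArgument(
        "The first operand of indexed conditional must be a scalar of S32.");
  }
  return Status::OK();
}
```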
@@ -878,7 +903,8 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, case HloOpcode::kTuple: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: - return ShapesSame(instruction->shape(), inferred_shape); + return ShapesSame(instruction->shape(), inferred_shape, + only_compare_minor_to_major_in_layout); // We allow arbitrary layout and f32->bf16 transformations on all other // instructions, although this may be made more strict pending discussion @@ -961,7 +987,7 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { if (computation->num_parameters() != layout.parameter_count()) { return InternalError( "Number of parameters in entry computation layout (%d) must be same " - "as number of parameters of entry computation computation (%d)", + "as number of parameters of entry computation (%d)", layout.parameter_count(), computation->num_parameters()); } @@ -1478,11 +1504,9 @@ StatusOr HloVerifier::Run(HloModule* module) { std::unique_ptr shape_verifier = target_metadata_->GetVerifier(); + InstructionVerifier instruction_verifier(instruction_can_change_layout_func_); for (auto* computation : module->computations()) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); - - InstructionVerifier instruction_verifier( - instruction_can_change_layout_func_); TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index d427a1586c3..45e472bbdf2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -57,6 +57,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; + Status HandlePartitionId(HloInstruction* hlo) override; Status HandleReplicaId(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; @@ -106,7 +107,8 @@ class ShapeVerifier : public DfsHloVisitor { // Check the instruction's shape against the shape given by ShapeInference // and return an appropriate error if there is a mismatch. Status CheckShape(const HloInstruction* instruction, - const Shape& inferred_shape); + const Shape& inferred_shape, + bool only_compare_minor_to_major_in_layout = false); // Overload which takes a StatusOr to reduce boilerplate in the caller. Status CheckShape(const HloInstruction* instruction, @@ -120,14 +122,31 @@ class ShapeVerifier : public DfsHloVisitor { private: // Helpers that switch on layout_sensitive_. - bool ShapesSame(const Shape& a, const Shape& b) { - return layout_sensitive_ ? ShapeUtil::Equal(a, b) - : ShapeUtil::Compatible(a, b); + bool ShapesSame(const Shape& a, const Shape& b, + bool minor_to_major_only = false) { + if (!layout_sensitive_) { + return ShapeUtil::Compatible(a, b); + } + Shape::Equal equal; + if (minor_to_major_only) { + equal.MinorToMajorOnlyInLayout(); + } + return equal(a, b); } - bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { - return layout_sensitive_ ? 
ShapeUtil::EqualIgnoringFpPrecision(a, b) - : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + + bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b, + bool minor_to_major_only = false) { + if (!layout_sensitive_) { + return ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + } + Shape::Equal equal; + if (minor_to_major_only) { + equal.MinorToMajorOnlyInLayout(); + } + equal.IgnoreFpPrecision(); + return equal(a, b); } + string StringifyShape(const Shape& s) { return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) : ShapeUtil::HumanString(s); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 523890b3c72..201fc654ad0 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -155,17 +155,17 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) { const char* const hlo_string = R"( -HloModule Module + HloModule Module -callme { - ROOT param = (s32[], f32[4]) parameter(0) -} + callme { + ROOT param = (s32[], f32[4]) parameter(0) + } -ENTRY entry { - p0 = (f32[4], s32[]) parameter(0) - ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme -} -)"; + ENTRY entry { + p0 = (f32[4], s32[]) parameter(0) + ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme + } + )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); auto status = verifier().Run(module.get()).status(); @@ -176,25 +176,25 @@ ENTRY entry { TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) { const char* const hlo_string = R"( -HloModule Module + HloModule Module -true_branch { - tparam = (s32[], f32[4]) parameter(0) - ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1 -} + true_branch { + tparam = (s32[], f32[4]) parameter(0) + ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1 + } -false_branch { - fparam = (s32[], f32[4]) parameter(0) - ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1 -} + false_branch { + fparam = (s32[], f32[4]) parameter(0) + ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1 + } -ENTRY entry { - p0 = (f32[4], s32[]) parameter(0) - constant = pred[] constant(true) - ROOT conditional = f32[4] conditional(constant, p0, p0), - true_computation=true_branch, false_computation=false_branch -} -)"; + ENTRY entry { + p0 = (f32[4], s32[]) parameter(0) + constant = pred[] constant(true) + ROOT conditional = f32[4] conditional(constant, p0, p0), + true_computation=true_branch, false_computation=false_branch + } + )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); auto status = verifier().Run(module.get()).status(); @@ -203,6 +203,51 @@ ENTRY entry { HasSubstr("shape does not match parameter")); } +TEST_F(HloVerifierTest, CheckConditionalBranchIndexOperandShape) { + const char* const hlo_string = R"( + HloModule Module + + branch0 { + tparam = f32[4] parameter(0) + ROOT tgte1 = f32[4] ceil(tparam) + } + + branch1 { + fparam = f32[4] parameter(0) + ROOT fgte1 = f32[4] floor(fparam) + } + + branch2 { + sparam = f32[4] parameter(0) + ROOT sgte1 = f32[4] ceil(sparam) + } + + ENTRY entry { + p0 = f32[4] parameter(0) + b0 = s32[] parameter(1) + ROOT conditional = f32[4] conditional(b0, p0, p0, p0), + branch_computations={branch0, branch1, branch2} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + auto status = verifier().Run(module.get()).status(); + + HloInstruction* condition = 
FindInstruction(module.get(), "b0"); + *condition->mutable_shape() = ShapeUtil::MakeShape(F32, {}); + status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.error_message(), + HasSubstr( + "first operand of indexed conditional must be a scalar of S32")); + + *condition->mutable_shape() = ShapeUtil::MakeShape(S32, {4}); + status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("first operand of conditional must be a scalar")); +} + TEST_F(HloVerifierTest, RngOpnd0NotScalar) { const char* const hlo_string = R"( HloModule Module @@ -504,7 +549,7 @@ TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) { HloModule Module ENTRY SelectMixedPrecisionNotAllowed { - p0 = pred[] parameter(0) + p0 = pred[32] parameter(0) p1 = f32[32] parameter(1) p2 = bf16[32] parameter(2) ROOT select = f32[32] select(p0, p1, p2) @@ -523,7 +568,7 @@ TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) { HloModule Module ENTRY SelectMixedPrecisionAllowed { - p0 = pred[] parameter(0) + p0 = pred[32] parameter(0) p1 = f32[32] parameter(1) p2 = bf16[32] parameter(2) ROOT select = f32[32] select(p0, p1, p2) @@ -535,6 +580,25 @@ TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTest, SelectTupleNotAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY SelectWithTuple { + p0 = (f32[], f32[]) parameter(0) + p1 = (f32[], f32[]) parameter(1) + p2 = pred[] parameter(2) + ROOT select = (f32[], f32[]) select(p2, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected array argument for select")); +} + TEST_F(HloVerifierTest, IotaNonArrayResult) { const char* const hlo_string = R"( HloModule IotaTupleResult @@ -552,6 +616,22 @@ TEST_F(HloVerifierTest, IotaNonArrayResult) { HasSubstr("does not support non-array result")); } +TEST_F(HloVerifierTest, IotaNegativeDimension) { + const char* const hlo_string = R"( + HloModule IotaTupleResult + + ENTRY kernelEntry { + ROOT iota = s32[128,1001]{1,0} iota(), iota_dimension=-1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("negative")); +} + static const char* const kMapOperandComputationMismatchHlo = R"( HloModule MapOperandComputationMismatch diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc deleted file mode 100644 index ada21345014..00000000000 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -namespace { - -// Visitor for removing implicit broadcasts. -class ImplicitBroadcastVisitor : public DfsHloVisitorWithDefault { - public: - Status DefaultAction(HloInstruction* hlo_instruction) override { - return Status::OK(); - } - - Status HandleElementwiseBinary(HloInstruction* hlo) override { - return ReplaceImplicitBroadcastOperands(hlo); - } - - Status HandleClamp(HloInstruction* hlo) override { - // Clamp is the only element-wise ternary operation. - return ReplaceImplicitBroadcastOperands(hlo); - } - - // Returns whether any modification has been made to any visited instruction. - bool changed() const { return changed_; } - - private: - // Iterates through the operands of 'hlo' and replace any operands which are - // implicitly broadcast with the equivalent sequence of broadcast and reshape - // instructions. An operand is considered to be implicitly broadcast if the - // operand shape does have the same dimensions as the shape of 'hlo'. - Status ReplaceImplicitBroadcastOperands(HloInstruction* hlo) { - auto fadd = [hlo](std::unique_ptr x) { - return hlo->parent()->AddInstruction(std::move(x)); - }; - std::vector operands; - bool operands_changed = false; - for (int i = 0; i < hlo->operand_count(); ++i) { - HloInstruction* operand = hlo->mutable_operand(i); - if (!ShapeUtil::SameDimensions(hlo->shape(), operand->shape())) { - HloInstruction* new_operand = hlo->parent()->AddInstruction( - HloInstruction::CreateBroadcastSequence(hlo->shape(), operand, - fadd)); - operands.push_back(new_operand); - operands_changed = true; - } else { - operands.push_back(operand); - } - } - if (operands_changed) { - // Create a new HLO instruction because the HloInstruction::Replace* - // methods check that the shape does not change with the replacement. - HloInstruction* new_hlo = hlo->parent()->AddInstruction( - hlo->CloneWithNewOperands(hlo->shape(), operands)); - TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo)); - changed_ = true; - } - return Status::OK(); - } - - bool changed_ = false; -}; - -} // namespace - -StatusOr ImplicitBroadcastRemover::Run(HloModule* module) { - VLOG(1) << "Removing implicit broadcast from module " << module->name(); - XLA_VLOG_LINES(2, - "Before removing implicit broadcasts:\n" + module->ToString()); - - ImplicitBroadcastVisitor visitor; - for (HloComputation* computation : module->computations()) { - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); - } - - if (visitor.changed()) { - // HLO instructions with implicitly broadcast operands are cloned and left - // for dead. Remove them. 
- HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); - } - - XLA_VLOG_LINES(2, - "After removing implicit broadcasts:\n" + module->ToString()); - - return visitor.changed(); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc deleted file mode 100644 index cf6cf897fe1..00000000000 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" - -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -class ImplicitBroadcastRemoverTest : public HloTestBase { - protected: - ImplicitBroadcastRemover remover_; -}; - -TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); - auto param0 = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); - auto param1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); - builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_FALSE(remover_.Run(m.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Parameter(), op::Parameter())); -} - -TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_param")); - auto param1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); - builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kPower, param0, param1)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - - EXPECT_FALSE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - root = computation->root_instruction(); - - EXPECT_THAT(root, op::Power(op::Broadcast(op::Parameter()), op::Parameter())); - - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), 
root->operand(1)->shape())); -} - -TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); - auto param0 = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 4, 1}), "p1")); - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, param0, param1)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Subtract(op::Parameter(), - op::Broadcast(op::Reshape(op::Parameter())))); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); -} - -TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {1, 4, 1}); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_param")); - auto param1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, param0, param1)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, - op::Subtract(op::Broadcast(op::Parameter()), op::Parameter())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); -} - -TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6, 8}); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 4, 1, 8}), "p0")); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 6, 8}), "p1")); - auto param2 = builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(F32, {2, 1, 6, 8}), "p2")); - builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, - param0, param1, param2)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Reshape(op::Parameter())), - op::Broadcast(op::Reshape(op::Parameter())), - op::Broadcast(op::Reshape(op::Parameter())))); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(2)->shape())); -} - -TEST_F(ImplicitBroadcastRemoverTest, - TernaryScalarAndDegenerateDimensionBroadcast) { - auto m = CreateNewVerifiedModule(); - auto builder = HloComputation::Builder(TestName()); - - const Shape shape = 
ShapeUtil::MakeShape(F32, {2, 4, 6}); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 4, 6}), "p1")); - auto param2 = - builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2")); - builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, - param0, param1, param2)); - - HloComputation* computation = m->AddEntryComputation(builder.Build()); - - EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); - - HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Parameter()), - op::Broadcast(op::Reshape(op::Parameter())), - op::Parameter())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(2)->shape())); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index e4a78af7c72..1be029a7c03 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -65,6 +65,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConstant: @@ -72,27 +73,23 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: case HloOpcode::kFloor: - case HloOpcode::kGe: case HloOpcode::kGetTupleElement: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kInfeed: case HloOpcode::kIota: case HloOpcode::kIsFinite: - case HloOpcode::kLe: - case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: case HloOpcode::kXor: case HloOpcode::kOutfeed: case HloOpcode::kPad: + case HloOpcode::kPartitionId: + case HloOpcode::kPopulationCount: case HloOpcode::kReal: case HloOpcode::kReducePrecision: case HloOpcode::kReplicaId: @@ -254,67 +251,63 @@ InstructionFusion::ComputeGloballyUnfusible( HloInstructionSet do_not_duplicate; absl::flat_hash_map, bool> can_fuse_on_all_paths_result_cache; - for (HloInstruction* consumer : post_order) { - for (HloInstruction* producer : consumer->operands()) { - if (do_not_duplicate.count(producer) > 0) { - continue; - } - - // If the producer is effectively not more than unary, duplicating it - // will not increase the number of relevant inputs read, as the fusion - // node will only need to read at most 1 relevant input (the input of - // the producer). In that case, we do not forbid fusion of the operation - // here. - if (EffectivelyAtMostUnary(producer)) { - continue; - } - - // If the total size of the inputs is less than or equal to the total size - // of the outputs for the producer then duplicating it won't increase the - // memory traffic. In that case, we do not forbid fusion of the operation - // here. 
- auto total_size = [](const Shape& shape) { - int64 size = 0; - ShapeUtil::ForEachSubshape( - shape, - [&size](const Shape& subshape, const ShapeIndex& shape_index) { - if (subshape.IsArray()) { - size += ShapeUtil::ElementsIn(subshape); - } - }); - return size; - }; - int64 operands_size = 0; - for (const HloInstruction* op : producer->operands()) { - operands_size += total_size(op->shape()); - } - if (operands_size <= total_size(producer->shape())) { - continue; - } - - // Otherwise we will forbid fusing the op unless we can fuse it into - // all of its consumers on all paths. - // - // That means, that for: - // A --> B (fusible) - // \-> C (non-fusible) - // A will be not allowed to be fused into B, as it cannot be fused into C. - // - // Similarly, for: - // A -------------> B - // \-> C -> D -/ - // If: - // - A is fusible into B and C, and D is fusible into B - // - C is *not* fusible into D - // A will be not allowed to be fused into B, as it cannot be fused via - // all paths. - if (producer->IsFusible() && - CanFuseOnAllPaths(producer, consumer, do_not_duplicate, - &can_fuse_on_all_paths_result_cache)) { - continue; - } - do_not_duplicate.insert(producer); + for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) { + HloInstruction* producer = *it; + // If the producer is effectively not more than unary, duplicating it + // will not increase the number of relevant inputs read, as the fusion + // node will only need to read at most 1 relevant input (the input of + // the producer). In that case, we do not forbid fusion of the operation + // here. + if (EffectivelyAtMostUnary(producer)) { + continue; } + + // If the total size of the inputs is less than or equal to the total size + // of the outputs for the producer then duplicating it won't increase the + // memory traffic. In that case, we do not forbid fusion of the operation + // here. + auto total_size = [](const Shape& shape) { + int64 size = 0; + ShapeUtil::ForEachSubshape( + shape, [&size](const Shape& subshape, const ShapeIndex& shape_index) { + if (subshape.IsArray()) { + size += ShapeUtil::ElementsIn(subshape); + } + }); + return size; + }; + int64 operands_size = 0; + for (const HloInstruction* op : producer->unique_operands()) { + operands_size += total_size(op->shape()); + } + if (operands_size <= total_size(producer->shape())) { + continue; + } + + // Otherwise we will forbid fusing the op unless we can fuse it into + // all of its consumers on all paths. + // + // That means, that for: + // A --> B (fusible) + // \-> C (non-fusible) + // A will be not allowed to be fused into B, as it cannot be fused into C. + // + // Similarly, for: + // A -------------> B + // \-> C -> D -/ + // If: + // - A is fusible into B and C, and D is fusible into B + // - C is *not* fusible into D + // A will be not allowed to be fused into B, as it cannot be fused via + // all paths. + if (producer->IsFusible() && + absl::c_all_of(producer->users(), [&](HloInstruction* consumer) { + return CanFuseOnAllPaths(producer, consumer, do_not_duplicate, + &can_fuse_on_all_paths_result_cache); + })) { + continue; + } + do_not_duplicate.insert(producer); } return do_not_duplicate; @@ -413,13 +406,11 @@ class ReversePostOrderFusionQueue : public FusionQueue { } sorted_operand_numbers.push_back(i); } - absl::c_sort( - sorted_operand_numbers, [&](int64 i, int64 j) { - // Instructions with higher priority in the queue come first. 
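The rewritten ComputeGloballyUnfusible loop above keeps the old size heuristic but walks producers in reverse post-order, requires fusibility into every user via absl::c_all_of, and sums operand sizes over unique_operands(), so a binary op that reads the same value twice (such as the add(p0, p0) exercised by the new AllowBinarySameValueOperandsDuplication test further below) is no longer counted as reading its input twice. A minimal standalone sketch of that size check, with shapes reduced to flat element counts; FakeProducer and DuplicationIncreasesTraffic are illustrative names, not XLA API:

```cpp
// Minimal sketch of the duplication-size heuristic, assuming shapes are
// reduced to flat element counts. Not the XLA types; illustrative only.
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

struct FakeProducer {
  int64_t output_elements;               // elements the producer writes
  std::vector<const int64_t*> operands;  // element counts of its operands (may repeat)
};

// Duplicating the producer into its consumers is only treated as harmful when
// it reads strictly more data than it writes; counting each distinct operand
// once mirrors the switch from operands() to unique_operands() above.
bool DuplicationIncreasesTraffic(const FakeProducer& p) {
  std::set<const int64_t*> unique(p.operands.begin(), p.operands.end());
  int64_t operands_size = 0;
  for (const int64_t* operand : unique) operands_size += *operand;
  return operands_size > p.output_elements;
}

int main() {
  int64_t p0 = 4 * 3;  // e.g. f32[4,3] has 12 elements
  FakeProducer add_p0_p0{/*output_elements=*/12, {&p0, &p0}};
  std::cout << std::boolalpha << DuplicationIncreasesTraffic(add_p0_p0)
            << "\n";  // false: the add may be duplicated
  return 0;
}
```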
- return ( - FindOrDie(post_order_index_, instruction->mutable_operand(i)) > + absl::c_sort(sorted_operand_numbers, [&](int64 i, int64 j) { + // Instructions with higher priority in the queue come first. + return (FindOrDie(post_order_index_, instruction->mutable_operand(i)) > FindOrDie(post_order_index_, instruction->mutable_operand(j))); - }); + }); return std::make_pair(instruction, sorted_operand_numbers); } @@ -451,9 +442,6 @@ std::unique_ptr InstructionFusion::GetFusionQueue( } StatusOr InstructionFusion::Run(HloModule* module) { - VLOG(2) << "Before instruction fusion:"; - XLA_VLOG_LINES(2, module->ToString()); - bool changed = false; module_ = module; for (auto* computation : module->MakeNonfusionComputations()) { @@ -531,9 +519,6 @@ StatusOr InstructionFusion::Run(HloModule* module) { } } - VLOG(2) << "After instruction fusion:"; - XLA_VLOG_LINES(2, module->ToString()); - return changed; } @@ -623,10 +608,8 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() != HloInstruction::FusionKind::kLoop && - consumer->fusion_kind() != HloInstruction::FusionKind::kInput && - consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) { + if (consumer->opcode() == HloOpcode::kFusion && !consumer->IsLoopFusion() && + !consumer->IsInputFusion() && !consumer->IsOutputFusion()) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 611cfd404d7..864a9ac2069 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -219,13 +219,14 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // Make sure we do not duplicate the add, as we cannot fuse through the rng. // - // p0 -> add -------------------------> sub - // \-> abs1 -> rng -> abs2 -/ + // (p0, p1) -> add -------------------------> sub + // \-> abs1 -> rng -> abs2 -/ auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) - add = f32[4,3]{1,0} add(p0, p0) + p1 = f32[4,3]{1,0} parameter(1) + add = f32[4,3]{1,0} add(p0, p1) abs1 = f32[4,3]{1,0} abs(add) rng = f32[4,3]{1,0} rng(abs1), distribution=rng_uniform abs2 = f32[4,3]{1,0} abs(rng) @@ -249,14 +250,15 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // Use a log node with a second consumer to break the fusion. // - // p0 -> add -------------------------> sub - // \-> abs1 -> log -> abs2 -/ - // \-> send + // (p0, p1) -> add -------------------------> sub + // \-> abs1 -> log -> abs2 -/ + // \-> send module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) - add = f32[4,3]{1,0} add(p0, p0) + p1 = f32[4,3]{1,0} parameter(1) + add = f32[4,3]{1,0} add(p0, p1) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) token0 = token[] after-all() @@ -280,15 +282,16 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // Make sure we still fuse ops where one operand in the chain to the producer // can't be fused. 
// - // p0 ---> add1 -----------> sub - // \ \-> add2 -/ - // \-> log -/ - // \-> send + // (p0, p1) ---> add1 -----------> sub + // \ \-> add2 -/ + // \-> log -/ + // \-> send module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) - add1 = f32[4,3]{1,0} add(p0, p0) + p1 = f32[4,3]{1,0} parameter(1) + add1 = f32[4,3]{1,0} add(p0, p1) log = f32[4,3]{1,0} log(p0) token0 = token[] after-all() send = f32[4,3]{1,0} send(log, token0), channel_id=0 @@ -313,16 +316,17 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // into sub2. // // /---------------\ - // p0 ---> add1 ---> add2 ------> sub2 - // \------> sub1 - // log -/ - // \-> send + // (p0, p1) ---> add1 ---> add2 ------> sub2 + // \------> sub1 + // log -/ + // \-> send module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) - add1 = f32[4,3]{1,0} add(p0, p0) - add2 = f32[4,3]{1,0} add(add1, add1) + p1 = f32[4,3]{1,0} parameter(1) + add1 = f32[4,3]{1,0} add(p0, p1) + add2 = f32[4,3]{1,0} add(add1, p1) log = f32[4,3]{1,0} log(add2) token0 = token[] after-all() send = f32[4,3]{1,0} send(log, token0), channel_id=0 @@ -394,6 +398,40 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) { + // Make sure we do duplicate the add of the same values, even though we cannot + // fuse through the rng. + // + // p0 -> add -------------------------> sub + // \-> abs1 -> rng -> abs2 -/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + abs1 = f32[4,3]{1,0} abs(add) + rng = f32[4,3]{1,0} rng(abs1), distribution=rng_uniform + abs2 = f32[4,3]{1,0} abs(rng) + ROOT root = f32[4,3]{1,0} subtract(abs2, add) + })") + .ValueOrDie(); + // We expect abs2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Subtract(op::Abs(op::Parameter()), + op::Add(op::Parameter(), op::Parameter()))) + << module->ToString(); + + // Make sure the add has been duplicated. 
+ EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); +} + TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { auto module = ParseHloString(R"( HloModule test_module diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 599489b3785..7f0c1ccc728 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -35,6 +35,7 @@ cc_library( "//tensorflow/compiler/xla/service:cholesky_expander", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -53,7 +54,6 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:while_loop_simplifier", - "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:lib", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index a8f8ab4f725..80a3ebccff1 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/cholesky_expander.h" #include "tensorflow/compiler/xla/service/computation_placer.h" -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -52,8 +52,8 @@ namespace { StatusOr HandleEvaluatorCustomCall( HloInstruction* custom_call, absl::Span operands) { // Find the target C function in the global registry. 
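The interpreter's BUILD dependency above and the compiler.cc hunk just below retire the CPU-specific xla::cpu::CustomCallTargetRegistry in favor of the shared xla::CustomCallTargetRegistry, whose Lookup also takes a platform string ("Host"). A rough standalone model of a platform-keyed registry; MiniRegistry and MyHostKernel are stand-ins, not the actual XLA interface:

```cpp
// Simplified model of a platform-keyed custom-call registry; not the actual
// xla::CustomCallTargetRegistry interface.
#include <iostream>
#include <map>
#include <string>
#include <utility>

class MiniRegistry {
 public:
  static MiniRegistry* Global() {
    static MiniRegistry* registry = new MiniRegistry;
    return registry;
  }
  void Register(const std::string& name, void* fn, const std::string& platform) {
    targets_[{name, platform}] = fn;
  }
  void* Lookup(const std::string& name, const std::string& platform) const {
    auto it = targets_.find({name, platform});
    return it == targets_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::pair<std::string, std::string>, void*> targets_;
};

void MyHostKernel() {}  // stand-in for a registered custom-call target

int main() {
  MiniRegistry::Global()->Register(
      "my_call", reinterpret_cast<void*>(&MyHostKernel), "Host");
  // Resolution succeeds only for the platform the target was registered under.
  std::cout << (MiniRegistry::Global()->Lookup("my_call", "Host") != nullptr)
            << (MiniRegistry::Global()->Lookup("my_call", "CUDA") != nullptr)
            << "\n";  // prints 10
  return 0;
}
```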
- auto* registry = xla::cpu::CustomCallTargetRegistry::Global(); - void* target_fn = registry->Lookup(custom_call->custom_call_target()); + auto* registry = CustomCallTargetRegistry::Global(); + void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host"); if (!target_fn) { return NotFound("Custom call target '%s' was not registered", custom_call->custom_call_target()); @@ -96,7 +96,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { StatusOr> InterpreterCompiler::RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* /*stream_exec*/, - DeviceMemoryAllocator* /*device_allocator*/) { + se::DeviceMemoryAllocator* /*device_allocator*/) { VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); return std::move(hlo_module); @@ -105,13 +105,13 @@ StatusOr> InterpreterCompiler::RunHloPasses( Status InterpreterCompiler::RunHloPassesOnModuleGroup( HloModuleGroup* module_group, absl::Span executors, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { return Unimplemented("Module group compilation not supported on Interpreter"); } StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* /*device_allocator*/) { + se::DeviceMemoryAllocator* /*device_allocator*/) { TF_RET_CHECK(stream_exec != nullptr); VLOG(1) << "Run backend " << hlo_module->name(); @@ -137,7 +137,7 @@ StatusOr>> InterpreterCompiler::RunBackendOnModuleGroup( std::unique_ptr module_group, std::vector> stream_exec, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { return Unimplemented( "Module group compilation is not supported on Interpreter."); } @@ -145,7 +145,7 @@ InterpreterCompiler::RunBackendOnModuleGroup( StatusOr>> InterpreterCompiler::Compile( std::unique_ptr module_group, std::vector> stream_exec, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { if (module_group->empty()) { return std::vector>(); } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index 591272951a0..dc83295b527 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -45,24 +45,24 @@ class InterpreterCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; Status RunHloPassesOnModuleGroup( HloModuleGroup* module_group, absl::Span executors, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr>> RunBackendOnModuleGroup( std::unique_ptr module_group, std::vector> stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::unique_ptr module_group, diff --git 
a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 7a6ebdef708..167a013408b 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -72,11 +72,12 @@ StatusOr InterpreterExecutable::ExecuteOnStream( for (int64 i = 0; i < computation->num_parameters(); ++i) { const auto& expected_shape = computation->parameter_instruction(i)->shape(); const auto& actual_shape = arguments[i]->on_device_shape(); - if (!ShapeUtil::Equal(expected_shape, actual_shape)) { + if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape, + actual_shape)) { return InvalidArgument( "Shape mismatch on parameter %d. Expected %s, but was %s.", i, - ShapeUtil::HumanString(expected_shape), - ShapeUtil::HumanString(actual_shape)); + ShapeUtil::HumanStringWithLayout(expected_shape), + ShapeUtil::HumanStringWithLayout(actual_shape)); } } diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index e3e5fa71543..b1a26b3b586 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -35,16 +35,14 @@ XlaInterpreterExecutor::~XlaInterpreterExecutor() {} void *XlaInterpreterExecutor::Allocate(uint64 size) { return new char[size]; } -void *XlaInterpreterExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, - uint64 offset_bytes, - uint64 /*size_bytes*/) { +void *XlaInterpreterExecutor::GetSubBuffer(DeviceMemoryBase *parent, + uint64 offset_bytes, + uint64 /*size_bytes*/) { return parent + offset_bytes; } void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { - if (!mem->is_sub_buffer()) { - delete[] static_cast(mem->opaque()); - } + delete[] static_cast(mem->opaque()); } bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, @@ -112,7 +110,8 @@ port::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { return port::Status::OK(); } -DeviceDescription *XlaInterpreterExecutor::PopulateDeviceDescription() const { +port::StatusOr> +XlaInterpreterExecutor::CreateDeviceDescription(int device_ordinal) { internal::DeviceDescriptionBuilder builder; builder.set_device_address_bits(64); @@ -121,7 +120,7 @@ DeviceDescription *XlaInterpreterExecutor::PopulateDeviceDescription() const { builder.set_device_memory_size(static_cast(4) * 1024 * 1024 * 1024); builder.set_clock_rate_ghz(static_cast(CLOCKS_PER_SEC) / 1e9); - return builder.Build().release(); + return builder.Build(); } } // namespace interpreter diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 400c3051546..6d337688a94 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -69,8 +69,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } void *Allocate(uint64 size) override; - void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes, - uint64 size_bytes) override; + void *GetSubBuffer(DeviceMemoryBase *parent, uint64 offset_bytes, + uint64 size_bytes) override; void Deallocate(DeviceMemoryBase *mem) override; void *HostMemoryAllocate(uint64 size) override { return new char[size]; } @@ -80,9 +80,9 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { bool HostMemoryRegister(void *mem, uint64 size) override { return true; } bool 
HostMemoryUnregister(void *mem) override { return true; } - bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &pop_src, + bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &dev_src, uint64 size) override; - bool Memcpy(Stream *stream, DeviceMemoryBase *pop_dst, const void *host_src, + bool Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, const void *host_src, uint64 size) override; bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst, const DeviceMemoryBase &host_src, @@ -114,10 +114,10 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return false; } - port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst, + port::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst, const void *host_src, uint64 size) override; port::Status SynchronousMemcpy(void *host_dst, - const DeviceMemoryBase &pop_src, + const DeviceMemoryBase &dev_src, uint64 size) override; port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, const DeviceMemoryBase &pop_src, @@ -165,7 +165,13 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return false; } - DeviceDescription *PopulateDeviceDescription() const override; + port::StatusOr> CreateDeviceDescription() + const override { + return CreateDeviceDescription(0); + } + + static port::StatusOr> + CreateDeviceDescription(int device_ordinal); port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { return port::Status::OK(); diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index b0fc1af8b89..aa17c20bc7f 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" -#include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/multi_platform_manager.h" @@ -43,6 +42,11 @@ int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } const string& XlaInterpreterPlatform::Name() const { return name_; } +port::StatusOr> +XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const { + return XlaInterpreterExecutor::CreateDeviceDescription(ordinal); +} + port::StatusOr XlaInterpreterPlatform::ExecutorForDevice( int ordinal) { StreamExecutorConfig config; diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index 0187f6d473b..ff9c5d07f8d 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -29,8 +29,9 @@ namespace interpreter { class XlaInterpreterPlatform : public Platform { public: - XlaInterpreterPlatform(const string& name = "Interpreter", - const Platform::Id& id = kXlaInterpreterPlatformId); + XlaInterpreterPlatform() + : XlaInterpreterPlatform("Interpreter", kXlaInterpreterPlatformId) {} + XlaInterpreterPlatform(const string& name, const Platform::Id& id); ~XlaInterpreterPlatform() override; Platform::Id id() const override; @@ -39,6 +40,9 @@ class XlaInterpreterPlatform : public Platform { const string& Name() const override; + port::StatusOr> DescriptionForDevice( + int ordinal) const override; + port::StatusOr ExecutorForDevice(int ordinal) override; port::StatusOr ExecutorForDeviceWithPluginConfig( diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index a0cb8fcaf3d..b1303f17580 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -33,15 +34,16 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -171,7 +173,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, auto iter = buffer_constraints_.find(&buffer); if (iter != buffer_constraints_.end()) { const BufferLayoutConstraint& curr_constraint = iter->second; - if (LayoutUtil::Equal(curr_constraint.layout(), layout)) { + if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) { // New constraint matches existing constraint. Nothing to do. return Status::OK(); } @@ -208,7 +210,7 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, GetOperandLayoutConstraint(instruction, operand_no); if (curr_shape_layout != nullptr) { if (curr_shape_layout->shape_layout().MatchesLayoutInShape( - shape_with_layout)) { + shape_with_layout, /*minor_to_major_only=*/true)) { // New constraint matches existing constraint. Nothing to do. return Status::OK(); } @@ -267,7 +269,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, const ShapeLayout* curr_shape_layout = ResultLayout(); if (curr_shape_layout != nullptr) { - if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) { + if (!curr_shape_layout->MatchesLayoutInShape( + shape_with_layout, /*minor_to_major_only=*/true)) { return FailedPrecondition( "Result of computation %s already has the layout constraint %s, " "cannot add incompatible constraint %s", @@ -277,7 +280,6 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, // New constraint matches existing constraint. Nothing to do. return Status::OK(); } - result_constraint_.reset( new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs)); added_constraints_.push_back(result_constraint_.get()); @@ -583,10 +585,10 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain the output and the operand of the while instruction to match // the computations. - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( - body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( body_layout.result_shape(), instruction, 0)); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + body_layout.result_shape(), instruction)); } else if (instruction->opcode() == HloOpcode::kConditional) { // Find the conditional branch with the most instructions and force all // other computations to match that layout. 
A potentially better decison @@ -610,38 +612,32 @@ Status LayoutAssignment::AddMandatoryConstraints( int j = (k + largest_branch) % instruction->branch_count(); TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1); ComputationLayout& branch_computation_layout = - FindOrDie(computation_layouts_, instruction->branch_computation(j)); - - DCHECK(ShapeUtil::Compatible( - instruction->operand(j + 1)->shape(), - branch_computation_layout.parameter_shape(0))); - if (best_branch_computation_layout.result_layout() != - branch_computation_layout.result_layout()) { - // We assign layouts in DFS fashion, so the largest_branch and current - // branch computations might have negotiated a different layout. But - // for the case instruction POV the layout must match, so we run again - // on the branch j computation, this time with proper computation - // layout. - VLOG(2) << "Reset %conditional branch " << j - << " computation result layout: branch_computation=" - << instruction->branch_computation(j)->name() - << " case=" << instruction->name() << " shape=" - << best_branch_computation_layout.result_layout().ToString(); - *branch_computation_layout.mutable_result_layout() = - best_branch_computation_layout.result_layout(); + FindOrDie(computation_layouts_, instruction->branch_computation(k)); + if (branch_computation_layout.result_layout() != + best_branch_computation_layout.result_layout()) { + computation_layouts_.erase(instruction->branch_computation(k)); + InsertOrDie(&conditional_mismatch_, + instruction->branch_computation(k), + best_branch_computation_layout); + } else { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + branch_computation_layout.parameter_shape(0), instruction, k + 1, + /*mandatory=*/true)); } - if (k == 0) { - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( - best_branch_computation_layout.result_shape(), instruction)); - } - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - branch_computation_layout.parameter_shape(0), instruction, j + 1, - /*mandatory=*/true)); } + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + best_branch_computation_layout.parameter_shape(0), instruction, + largest_branch + 1, + /*mandatory=*/true)); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + best_branch_computation_layout.result_shape(), instruction)); } } // Finally set the result layout to match ComputationLayout, if there is one. - if (computation_layout != nullptr) { + if (conditional_mismatch_.count(computation) > 0) { + TF_RETURN_IF_ERROR(constraints->SetResultLayout( + FindOrDie(conditional_mismatch_, computation).result_layout().shape())); + } else if (computation_layout != nullptr) { const ShapeLayout& result_layout = computation_layout->result_layout(); if (result_layout.LayoutIsSet()) { TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape())); @@ -652,6 +648,10 @@ Status LayoutAssignment::AddMandatoryConstraints( namespace { +bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { + return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout()); +} + // The operands of a call must match the layouts of parameters in the // ComputationLayout, and the call instruction itself must match the result // layout in the ComputationLayout. 
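A recurring change in these layout-assignment hunks (and in the interpreter executable earlier) is relaxing exact layout equality to Layout::Equal().MinorToMajorOnly() and Shape::Equal().MinorToMajorOnlyInLayout(), with the new LayoutsInShapesEqual helper comparing only the minor_to_major ordering of two shapes' layouts. A toy model of that relaxation; ToyLayout is not the XLA Layout class, and its tiles field merely stands in for whatever additional attributes a full comparison would also consider:

```cpp
// Toy model of minor_to_major-only layout comparison. ToyLayout is a stand-in,
// not the XLA Layout class.
#include <iostream>
#include <vector>

struct ToyLayout {
  std::vector<int> minor_to_major;  // dimension ordering
  std::vector<int> tiles;           // placeholder for non-ordering attributes
};

bool FullEqual(const ToyLayout& a, const ToyLayout& b) {
  return a.minor_to_major == b.minor_to_major && a.tiles == b.tiles;
}

bool MinorToMajorOnlyEqual(const ToyLayout& a, const ToyLayout& b) {
  return a.minor_to_major == b.minor_to_major;
}

int main() {
  ToyLayout row_major{{1, 0}, {}};
  ToyLayout row_major_tiled{{1, 0}, {256}};
  std::cout << std::boolalpha
            << FullEqual(row_major, row_major_tiled) << " "              // false
            << MinorToMajorOnlyEqual(row_major, row_major_tiled) << "\n";  // true
  return 0;
}
```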
@@ -661,10 +661,10 @@ Status CheckCallLayout(HloInstruction* call, TF_RET_CHECK(computation->num_parameters() == call->operand_count()); for (int64 i = 0; i < computation->num_parameters(); ++i) { TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape( - call->operand(i)->shape())); + call->operand(i)->shape(), /*minor_to_major_only=*/true)); } - TF_RET_CHECK( - computation_layout.result_layout().MatchesLayoutInShape(call->shape())); + TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape( + call->shape(), /*minor_to_major_only=*/true)); return Status::OK(); } @@ -675,9 +675,9 @@ Status CheckCustomCallLayout(HloInstruction* instruction) { const HloCustomCallInstruction* custom_call = DynCast(instruction); for (int64 i = 0; i < custom_call->operand_count(); ++i) { - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - custom_call->operand(i)->shape(), - custom_call->operand_shapes_with_layout()[i])); + TF_RET_CHECK( + LayoutsInShapesEqual(custom_call->operand(i)->shape(), + custom_call->operand_shapes_with_layout()[i])); } } return Status::OK(); @@ -695,13 +695,12 @@ Status CheckWhileLayout(HloInstruction* while_inst, auto init_shape = while_inst->operand(0)->shape(); TF_RET_CHECK( condition_computation_layout.parameter_layout(0).MatchesLayoutInShape( - init_shape)); + init_shape, /*minor_to_major_only=*/true)); TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape( - init_shape)); - TF_RET_CHECK( - body_computation_layout.result_layout().MatchesLayoutInShape(init_shape)); - TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape())); + init_shape, /*minor_to_major_only=*/true)); + TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape( + init_shape, /*minor_to_major_only=*/true)); + TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape())); return Status::OK(); } @@ -714,13 +713,14 @@ Status CheckConditionalLayout( branch_computation_layouts[j].result_layout()); TF_RET_CHECK( branch_computation_layouts[j].result_layout().MatchesLayoutInShape( - instruction->shape())); + instruction->shape(), /*minor_to_major_only=*/true)); TF_RET_CHECK( branch_computation_layouts[j].result_layout().MatchesLayoutInShape( - instruction->branch_computation(j)->root_instruction()->shape())); + instruction->branch_computation(j)->root_instruction()->shape(), + /*minor_to_major_only=*/true)); TF_RET_CHECK( branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape( - branch_operand->shape())); + branch_operand->shape(), /*minor_to_major_only=*/true)); } return Status::OK(); } @@ -731,11 +731,11 @@ Status CheckConditionalLayout( Status CheckFusionLayout(HloInstruction* fusion) { TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode()); - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - fusion->shape(), fusion->fused_expression_root()->shape())); + TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(), + fusion->fused_expression_root()->shape())); for (int64 i = 0; i < fusion->operand_count(); ++i) { - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape())); + TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(), + fusion->operand(i)->shape())); } return Status::OK(); } @@ -747,7 +747,8 @@ Status CheckParameterLayout(HloInstruction* parameter, const ShapeLayout& parameter_layout = computation_layout.parameter_layout(parameter->parameter_number()); if (parameter_layout.LayoutIsSet() && - 
!parameter_layout.MatchesLayoutInShape(parameter->shape())) { + !parameter_layout.MatchesLayoutInShape(parameter->shape(), + /*minor_to_major_only=*/true)) { return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", @@ -758,8 +759,7 @@ Status CheckParameterLayout(HloInstruction* parameter, // The layout of a constant instruction must match the layout of its literal. Status CheckConstantLayout(HloInstruction* constant) { - if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(), - constant->shape())) { + if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) { return InternalError( "constant instruction %s does not match the layout of its literal %s", constant->ToString(), @@ -790,7 +790,8 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( HloInstruction* gte = instruction->parent()->AddInstruction( HloInstruction::CreateGetTupleElement(instr_shape, instruction, i)); - if (ShapeUtil::Equal(target_shape, instr_shape)) { + if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape, + instr_shape)) { // Shapes and layouts are equal, no need to copy. element_copies.push_back(gte); } else { @@ -836,7 +837,8 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer( TF_RET_CHECK(operand_layout.LayoutIsSet()); TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); - if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(), + operand->shape())) { VLOG(5) << "Operand " << operand->ToString() << " layout matches in " << instruction->ToString(); // Operand layout already matches our constraint. Nothing to do. @@ -897,7 +899,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { const Shape& instruction_subshape = ShapeUtil::GetSubshape(instruction->shape(), index); for (const LogicalBuffer* buffer : buffers) { - if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) { + if (!Shape::Equal().MinorToMajorOnlyInLayout()( + instruction_subshape, buffer->shape())) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", @@ -959,8 +962,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, module->entry_computation()) .result_layout(); if (result_layout.LayoutIsSet()) { - TF_RET_CHECK( - ShapeUtil::Equal(module->result_shape(), result_layout.shape())); + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + module->result_shape(), result_layout.shape())); } return Status::OK(); } @@ -1012,7 +1015,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // operations. For similar reasons, if the operand and output have the same // rank, try to match the operand's layout to the output. if (ShapeUtil::TrueRank(operand->shape()) == 1 && - instruction->shape().rank() == 1) { + ShapeUtil::TrueRank(instruction->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. return nullptr; } @@ -1072,7 +1075,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( // reshape is a bitcast when using the same layout. This may avoid copy // operations. For similar reasons, if the operand and output have the same // rank, try to match the outputs's layout to the operand. - if (operand->shape().rank() == 1 && + if (ShapeUtil::TrueRank(operand->shape()) == 1 && ShapeUtil::TrueRank(user->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. 
return nullptr; @@ -1223,8 +1226,14 @@ namespace { // unassigned layouts in the graph. bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) { switch (hlo.opcode()) { + case HloOpcode::kFusion: + return hlo.IsCustomFusion(); + case HloOpcode::kGather: + return true; case HloOpcode::kReshape: - return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); + return hlo.operand(0)->shape().rank() == 1 || + std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); + case HloOpcode::kScatter: case HloOpcode::kTranspose: return true; default: @@ -1515,8 +1524,8 @@ StatusOr InferArrayLayout( if (first_buffer_layout == nullptr) { first_buffer_layout = &source_buffer->shape().layout(); - } else if (!LayoutUtil::Equal(source_buffer->shape().layout(), - *first_buffer_layout)) { + } else if (!Layout::Equal().MinorToMajorOnly()( + source_buffer->shape().layout(), *first_buffer_layout)) { // The points-to set is ambiguous for this index and the different source // buffers have different layouts. This case is possible in valid XLA // computations because we do not propagate BufferLayoutConstraints to all @@ -1569,7 +1578,7 @@ Status SetFusionLayouts(HloInstruction* fusion) { fused_instruction->mutable_shape())); } else if (fused_instruction->opcode() == HloOpcode::kInfeed) { // Nop; leave the infeed layout alone. - } else if (fusion->fusion_kind() != HloInstruction::FusionKind::kCustom) { + } else if (!fusion->IsCustomFusion()) { // Other instructions don't have layouts inside of fusion nodes. // But do not clear layouts for other instructions in custom fusion nodes. LayoutUtil::ClearLayout(fused_instruction->mutable_shape()); @@ -1590,18 +1599,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { LayoutUtil::ClearLayout(instruction->mutable_shape()); - // Create a copy of an operand if the operand instruction's layout does not - // match the use constraint (OperandLayoutConstraint). - for (int64 operand_no = 0; operand_no < instruction->operand_count(); - ++operand_no) { - const ShapeLayout* operand_layout = - constraints.OperandLayout(instruction, operand_no); - if (operand_layout != nullptr) { - TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout, - instruction, operand_no)); - } - } - // Set the layouts of the array shapes this instruction defines as indicated // by the respective BufferLayoutConstraints. Any array shapes in the output // of the instruction which are not defined by the instruction (eg, array @@ -1644,6 +1641,18 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, return Status::OK(); })); + // Create a copy of an operand if the operand instruction's layout does not + // match the use constraint (OperandLayoutConstraint). + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + const ShapeLayout* operand_layout = + constraints.OperandLayout(instruction, operand_no); + if (operand_layout != nullptr) { + TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout, + instruction, operand_no)); + } + } + // Fusion instructions require some layouts to be set on fused instructions // inside the fusion instruction. 
if (instruction->opcode() == HloOpcode::kFusion) { @@ -1698,9 +1707,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { } Status LayoutAssignment::RunOnComputation( - ComputationLayout* computation_layout, - const TuplePointsToAnalysis& points_to_analysis, - HloComputation* computation, + ComputationLayout* computation_layout, HloComputation* computation, ChannelLayoutConstraints* channel_constraints) { VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; @@ -1725,7 +1732,7 @@ Status LayoutAssignment::RunOnComputation( } // Construct LayoutConstraints with all layout constraints of the computation. - LayoutConstraints constraints(points_to_analysis, computation); + LayoutConstraints constraints(*points_to_analysis_, computation); // Add constraints required for correctness on all backends (eg, entry // parameter layout constraints). @@ -1742,7 +1749,7 @@ Status LayoutAssignment::RunOnComputation( // which lack a layout constraint. for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) { unconstrained_layout_instructions_.insert( - points_to_analysis.GetBuffer(buffer_id).instruction()); + points_to_analysis_->GetBuffer(buffer_id).instruction()); } // While any unconstrained buffers remain, pick an arbitrary buffer, give it a @@ -1753,7 +1760,7 @@ Status LayoutAssignment::RunOnComputation( // Arbitrarily pick the first unconstrained buffer and give it the default // layout (or the literal layout, in case of constants). By construction // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id. - const LogicalBuffer& buffer = points_to_analysis.GetBuffer( + const LogicalBuffer& buffer = points_to_analysis_->GetBuffer( *constraints.unconstrained_buffer_ids().begin()); const HloInstruction* instruction = buffer.instruction(); Layout new_layout = @@ -1796,7 +1803,12 @@ Status LayoutAssignment::RunOnComputation( // layout constraint. if (constraints.ResultLayout() != nullptr && !constraints.ResultLayout()->MatchesLayoutInShape( - computation->root_instruction()->shape())) { + computation->root_instruction()->shape(), + /*minor_to_major_only=*/true)) { + if (conditional_mismatch_.count(computation) > 0) { + *FindOrDie(computation_layouts_, computation).mutable_result_layout() = + FindOrDie(conditional_mismatch_, computation).result_layout(); + } TF_ASSIGN_OR_RETURN( HloInstruction * new_root, CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), @@ -1910,20 +1922,42 @@ Status LayoutAssignment::PropagateComputationLayouts( << ": " << computed_computation_layout.result_layout().ToString(); *result_layout = computed_computation_layout.result_layout(); } else { - TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + computed_computation_layout.result_layout().shape(), + result_layout->shape())); } return Status::OK(); } StatusOr LayoutAssignment::Run(HloModule* module) { VLOG(2) << "Running layout assignment on module " << module->name(); - XLA_VLOG_LINES(3, module->ToString()); - if (VLOG_IS_ON(10)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before layout assignment", - module->config().debug_options()); - } TF_RETURN_IF_ERROR(Init()); + std::unique_ptr call_graph = CallGraph::Build(module); + auto computations = module->computations(); + // Clone Conditional computations wiht multiple callsites. 
+ for (HloComputation* computation : computations) { + CallGraphNode& node = call_graph->GetNode(computation); + if (node.caller_callsites().size() == 1) { + continue; + } + if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) { + return caller.instruction()->opcode() == HloOpcode::kConditional; + })) { + continue; + } + for (int64 i = 0; i < node.caller_callsites().size() - 1; ++i) { + HloInstruction* caller = node.caller_callsites()[i].instruction(); + if (caller->opcode() == HloOpcode::kConditional) { + for (int64 k = 0; k < caller->branch_count(); ++k) { + if (computation == caller->branch_computation(k)) { + caller->set_branch_computation( + k, module->AddEmbeddedComputation(computation->Clone())); + break; + } + } + } + } + } // Verify computation layout is sane. const HloComputation* entry = module->entry_computation(); @@ -1956,19 +1990,21 @@ StatusOr LayoutAssignment::Run(HloModule* module) { TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); + points_to_analysis_ = std::move(points_to_analysis); for (auto* computation : module->MakeComputationPostOrder()) { if (computation->IsFusionComputation()) { continue; } if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation( - entry_computation_layout_, *points_to_analysis, - module->entry_computation(), channel_layout_constraints_)); + TF_RETURN_IF_ERROR(RunOnComputation(entry_computation_layout_, + module->entry_computation(), + channel_layout_constraints_)); } else { ComputationLayout* computation_layout = - (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation); - TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation, + (i == 0 || conditional_mismatch_.count(computation) > 0) + ? nullptr + : &FindOrDie(computation_layouts_, computation); + TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation, channel_layout_constraints_)); } } @@ -1977,13 +2013,6 @@ StatusOr LayoutAssignment::Run(HloModule* module) { entry_computation_layout_)); TF_RETURN_IF_ERROR(CheckLayouts(module)); - VLOG(3) << "After layout assignment:"; - XLA_VLOG_LINES(3, module->ToString()); - if (VLOG_IS_ON(10)) { - hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after layout assignment", - module->config().debug_options()); - } // All layouts are reset then reassigned by this pass. 
return true; } @@ -2001,6 +2030,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCompare: case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConditional: @@ -2012,24 +2042,18 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFft: case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kIsFinite: - case HloOpcode::kLe: case HloOpcode::kLog: case HloOpcode::kLog1p: - case HloOpcode::kLt: case HloOpcode::kMap: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: - case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: @@ -2056,6 +2080,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kSqrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: + case HloOpcode::kPopulationCount: case HloOpcode::kTriangularSolve: case HloOpcode::kCholesky: case HloOpcode::kTupleSelect: @@ -2080,6 +2105,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kIota: case HloOpcode::kOutfeed: case HloOpcode::kParameter: + case HloOpcode::kPartitionId: case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReduce: @@ -2109,6 +2135,7 @@ bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { Status LayoutAssignment::Init() { computation_layouts_.clear(); + conditional_mismatch_.clear(); *entry_computation_layout_ = saved_entry_computation_layout_; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 5701cb5b025..6b6b3665317 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -359,7 +359,7 @@ class LayoutAssignment : public HloModulePass { // the cost of `instruction`. `output_layout` is the layout of `instruction`. // Returns null if it can't decide the best layout. // Precondition: `instruction` and the operand are array-shaped. - std::unique_ptr ChooseOperandLayoutFromOutputLayout( + virtual std::unique_ptr ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no); // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of @@ -396,7 +396,6 @@ class LayoutAssignment : public HloModulePass { // Layouts constraints are added, then propagated until all LogicalBuffers in // the computation are constrained. Status RunOnComputation(ComputationLayout* computation_layout, - const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints); @@ -466,9 +465,9 @@ class LayoutAssignment : public HloModulePass { // Creates a copy of the given operand if the operand's layout does not match // the given layout. This copy replaces the use in the given instruction. // Tuple operands will be deep-copied. - Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no); + virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); // Registers a copy instruction added by the layout assignment pass. 
void RegisterAddedCopy(HloInstruction* copy) { @@ -504,6 +503,9 @@ class LayoutAssignment : public HloModulePass { // instructions can be set to match the computation. std::map computation_layouts_; + // Map from branch computations to the result layout they shuould apply. + std::map conditional_mismatch_; + // Every copy added to the module by the layout assignment pass is registered // here. absl::flat_hash_set added_copies_; @@ -521,6 +523,9 @@ class LayoutAssignment : public HloModulePass { // host. ChannelLayoutConstraints host_channel_constraints_; + // Module points to analysis that can be updated for cloned computations. + std::unique_ptr points_to_analysis_; + // The set of HLO instructions which lacked any layout constraint, thus // receiving propagated default layouts. absl::flat_hash_set unconstrained_layout_instructions_; diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index c8cf3c47d38..5597afc15a3 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -570,13 +570,10 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); - EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); - EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), - ElementsAre(1, 0)); - EXPECT_THAT(concatenate->operand(1)->shape().layout().minor_to_major(), - ElementsAre(1, 0)); - EXPECT_THAT(concatenate->shape().layout().minor_to_major(), - ElementsAre(1, 0)); + EXPECT_EQ(concatenate->operand(0)->shape().layout().minor_to_major(), + concatenate->operand(1)->shape().layout().minor_to_major()); + EXPECT_EQ(concatenate->shape().layout().minor_to_major(), + concatenate->operand(1)->shape().layout().minor_to_major()); } // Test layout assignment of a transpose into a bitcast based on its operand. 
@@ -1084,7 +1081,7 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) counter.1 = s32[] get-tuple-element(tup.1), index=0 five = s32[] constant(5) - ROOT lt = pred[] less-than(counter.1, five) + ROOT lt = pred[] compare(counter.1, five), direction=LT } body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) { diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index 382b5751202..82e955c818e 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -24,7 +24,7 @@ namespace xla { Status LLVMCompiler::RunHloPassesOnModuleGroup( HloModuleGroup* module_group, absl::Span executors, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { return Unimplemented( "Model partitioning not implemented for the CPU/GPU compilers!"); } @@ -33,7 +33,7 @@ StatusOr>> LLVMCompiler::RunBackendOnModuleGroup( std::unique_ptr module_group, std::vector> stream_exec, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { return Unimplemented( "Model partitioning not implemented for the CPU/GPU compilers!"); } @@ -41,7 +41,7 @@ LLVMCompiler::RunBackendOnModuleGroup( StatusOr>> LLVMCompiler::Compile( std::unique_ptr module_group, std::vector> stream_execs, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { // Tensorflow tries to enable the following behaviors in all its threads: // // - Denormals are zero (DAZ): roughly, operations treat denormal floats as diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index 182d8edbe30..888815bea3d 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -37,7 +37,7 @@ class LLVMCompiler : public Compiler { // A callback of this type can be run before and/or after IR-level // optimization to e.g. dump out the generated IR to disk or gather some // statistics. 
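The llvm_compiler.h comment above describes ModuleHook as a callback run before and/or after IR-level optimization, e.g. to dump the generated IR or gather statistics. The std::function template arguments are not legible in this diff rendering, so the sketch below assumes a void(const Module&)-style hook and uses stand-in types (FakeModule, MiniCompiler) rather than the real LLVMCompiler API:

```cpp
// Standalone sketch of the hook mechanism described above; the hook signature
// and the FakeModule/MiniCompiler names are assumptions, not the XLA/LLVM types.
#include <functional>
#include <iostream>
#include <string>
#include <utility>

struct FakeModule {
  std::string name;  // stands in for an llvm::Module
};

class MiniCompiler {
 public:
  using ModuleHook = std::function<void(const FakeModule&)>;
  void SetPreOptimizationHook(ModuleHook hook) { pre_hook_ = std::move(hook); }
  void RunBackend(const FakeModule& m) {
    if (pre_hook_) pre_hook_(m);  // e.g. dump the IR or gather statistics
    // ... optimization and code generation would follow here ...
  }

 private:
  ModuleHook pre_hook_;
};

int main() {
  MiniCompiler compiler;
  compiler.SetPreOptimizationHook([](const FakeModule& m) {
    std::cerr << "about to optimize " << m.name << "\n";
  });
  compiler.RunBackend(FakeModule{"entry_module"});
  return 0;
}
```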
- using ModuleHook = std::function; + using ModuleHook = std::function; void SetPreOptimizationHook(ModuleHook hook) { CHECK(!user_pre_optimization_hook_) @@ -61,28 +61,28 @@ class LLVMCompiler : public Compiler { // StatusOr> RunBackend( // std::unique_ptr module, // se::StreamExecutor* stream_exec, - // DeviceMemoryAllocator* device_allocator) + // se::DeviceMemoryAllocator* device_allocator) // StatusOr> RunHloPasses( // std::unique_ptr module, // se::StreamExecutor* stream_exec, - // DeviceMemoryAllocator* device_allocator) + // se::DeviceMemoryAllocator* device_allocator) using Compiler::RunBackend; using Compiler::RunHloPasses; Status RunHloPassesOnModuleGroup( HloModuleGroup* module_group, absl::Span executors, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr>> RunBackendOnModuleGroup( std::unique_ptr module_group, std::vector> stream_exec, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, - DeviceMemoryAllocator* device_allocator) override; + se::DeviceMemoryAllocator* device_allocator) override; protected: ModuleHook user_pre_optimization_hook_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index c5d59fb28e0..e1303f60779 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -49,8 +49,8 @@ tf_cc_test( srcs = ["alias_analysis_test.cc"], deps = [ ":alias_analysis", + "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/core:test", @@ -67,9 +67,11 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service/cpu:cpu_options", "//tensorflow/core:lib", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", @@ -111,6 +113,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -146,6 +149,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/gpu:target_util", "//tensorflow/core:lib", "@llvm//:core", ], @@ -161,6 +165,7 @@ cc_library( ":llvm_util", ":loop_emitter", ":tuple_ops", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -169,6 +174,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -206,6 +212,7 @@ cc_library( "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", + "//tensorflow/compiler/xla/service/gpu:target_util", "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -236,7 +243,7 @@ cc_library( hdrs = ["kernel_support_library.h"], deps = [ ":llvm_loop", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + ":llvm_util", "@com_google_absl//absl/strings", "@llvm//:core", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index ce3d922ca7a..761c6879db8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -35,7 +35,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, const ShapeIndex& index) { BufferAllocation::Slice buffer_slice; if (hlo.opcode() == HloOpcode::kParameter && - hlo.parent() == hlo.parent()->parent()->entry_computation()) { + hlo.parent() == module_.entry_computation()) { // Entry computation parameters may alias with each other but may not alias // with our temporary buffers. buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0); @@ -63,7 +63,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, } if (module_.config().debug_options().xla_llvm_enable_noalias_metadata()) { - llvm::MDNode*& noalias_md = noalias_metadata_[buffer_slice]; + llvm::MDNode*& noalias_md = noalias_metadata_[{buffer_slice, &hlo}]; if (noalias_md == nullptr) { noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(), assignment_, hlo); @@ -78,12 +78,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, .xla_llvm_enable_invariant_load_metadata()) { // Parameters of the entry computation are never stored to, loading from a // parameter pointer should always return the same result within a loop. 
- if (hlo.opcode() == HloOpcode::kParameter) { - const std::vector& parameter_instructions = - module_.entry_computation()->parameter_instructions(); - if (absl::c_linear_search(parameter_instructions, &hlo)) { - array->MarkInvariantOverWholeProgram(context_); - } + if (hlo.opcode() == HloOpcode::kParameter && + hlo.parent() == module_.entry_computation()) { + array->MarkInvariantOverWholeProgram(context_); } } } @@ -115,7 +112,7 @@ llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( llvm::MDBuilder metadata_builder(domain->getContext()); llvm::MDNode* scope = metadata_builder.createAliasScope( - AsStringRef("buffer: " + buffer_slice.ToString()), domain); + "buffer: " + buffer_slice.ToString(), domain); llvm::MDNode* scope_list = llvm::MDNode::get(domain->getContext(), scope); return scope_list; } @@ -197,7 +194,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( std::vector scopes; for (const BufferAllocation::Slice noalias_slice : buffers) { llvm::MDNode* scope = metadata_builder.createAliasScope( - AsStringRef("buffer: " + noalias_slice.ToString()), domain); + "buffer: " + noalias_slice.ToString(), domain); scopes.push_back(scope); } llvm::MDNode* noalias_list = diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 12e2f449e23..7e7a6f6f820 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -79,9 +79,11 @@ class AliasAnalysis { absl::flat_hash_map alias_scope_metadata_; - // A map from a buffer slice to metadata corresponding to its noalias - // metadata. - absl::flat_hash_map noalias_metadata_; + // A map from a buffer slice and producer to metadata corresponding to its + // noalias metadata. + absl::flat_hash_map, + llvm::MDNode*> + noalias_metadata_; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index b6ae4932f57..db60e08472d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" + #include #include -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/test.h" @@ -29,7 +30,7 @@ class AliasAnalysisTest : public CpuCodegenTest {}; void FakeCustomCallTarget(float* out, float** in) {} -REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget); TEST_F(AliasAnalysisTest, EmbeddedComputationParamsMayAliasTemps) { const char* hlo_string = R"( @@ -46,7 +47,7 @@ condition { condition.state = f32[] parameter(0) addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget" add = f32[] add(addend, condition.state) - ROOT greater-than = pred[] greater-than(const.100, add) + ROOT greater-than = pred[] compare(const.100, add), direction=GT } ENTRY while3 { diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index 1ea5a42b0b3..f96c985da71 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -23,7 +23,7 @@ static const HloInstruction& InstrForConstantBufferAllocation( CHECK(allocation.is_constant()); HloInstruction* const_instr = nullptr; for (const auto& buffer_offset_pair : allocation.assigned_buffers()) { - const LogicalBuffer* buffer = buffer_offset_pair.first; + const BufferValue* buffer = buffer_offset_pair.first; // BufferAssignment may have assigned non-constant instructions to this // allocation too so we can't CHECK this condition. E.g. for // diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 3acceccfa55..4974cb57db3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -47,29 +47,30 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // Read start indices from start_indices_generator. const int64 rank = output_shape.rank(); - IrArray::Index start_index(b->getInt64Ty(), rank); + std::vector start_multi_index(rank); for (int64 i = 0; i < rank; ++i) { - TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(i)); + TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i)); llvm::Value* output_dim_size = llvm::ConstantInt::get( - start_index[i]->getType(), output_shape.dimensions(i)); + start_multi_index[i]->getType(), output_shape.dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( - start_index[i]->getType(), update_shape.dimensions(i)); + start_multi_index[i]->getType(), update_shape.dimensions(i)); // Clamp the start index so that the update region fits in the operand. 
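A scalar model of this clamp may help: the start index is forced into `[0, output_dim_size - update_dim_size]` so the update region always lies inside the operand. The function below is purely illustrative, not code from this change:

```c++
#include <algorithm>
#include <cstdint>

// E.g. with an output dimension of 10 and an update of 4, start 9 clamps to 6
// and start -2 clamps to 0, so the update always fits in bounds.
int64_t ClampStartIndex(int64_t start, int64_t output_dim_size,
                        int64_t update_dim_size) {
  const int64_t max_bound = output_dim_size - update_dim_size;
  return std::min(std::max<int64_t>(start, 0), max_bound);
}
```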
// start_index = clamp(start_index, 0, output_dim_size - update_dim_size) llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size); - llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); - start_index[i] = + llvm::Value* zero = + llvm::ConstantInt::get(start_multi_index[i]->getType(), 0); + start_multi_index[i] = b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, - zero, start_index[i]), - zero, start_index[i]); + zero, start_multi_index[i]), + zero, start_multi_index[i]); - start_index[i] = + start_multi_index[i] = b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, - max_bound, start_index[i]), - max_bound, start_index[i]); + max_bound, start_multi_index[i]), + max_bound, start_multi_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { @@ -78,14 +79,16 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // // output_index[dim] = start_index[dim] + update_index[dim] // - IrArray::Index output_index(start_index.GetType(), rank); + std::vector output_multi_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* start_index0 = - b->CreateSExtOrBitCast(start_index[i], update_index[i]->getType()); - output_index[i] = b->CreateAdd(start_index0, update_index[i]); + llvm::Value* start_index0 = b->CreateSExtOrBitCast( + start_multi_index[i], update_index[i]->getType()); + output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. + IrArray::Index output_index(output_multi_index, output_shape, + b->getInt64Ty()); TF_ASSIGN_OR_RETURN(llvm::Value * update_data, update_array_generator(update_index)); output_array.EmitWriteArrayElement(output_index, update_data, b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index 7fe803d1f8d..c4da28229d0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -45,7 +45,7 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); HloInstruction* fused_root = fusion->fused_expression_root(); if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice || - fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { + !fusion->IsLoopFusion()) { return false; } // Walk DynamicUpdateSlice operand(0) to fused parameter and get its diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 9f094330bdc..630b58f3f47 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -15,14 +15,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -58,9 +66,9 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { } VLOG(3) << "The cached generated value can't be reused, because it is in " "a different BB (" - << llvm_ir::AsString(generated_value_bb->getName()) + << generated_value_bb->getName().str() << ") from the current insertion block (" - << llvm_ir::AsString(b_->GetInsertBlock()->getName()) << ")."; + << b_->GetInsertBlock()->getName().str() << ")."; } TF_ASSIGN_OR_RETURN(generated_value_cache_[hlo][index.multidim()], @@ -113,9 +121,9 @@ Status FusedIrEmitter::HandleGetTupleElement( } // Lookup tuple element pointer. - return llvm_ir::EmitGetTupleElement( - get_tuple_element->shape(), get_tuple_element->tuple_index(), - /*alignment=*/1, tuple_ptr, b_, module_); + return llvm_ir::EmitGetTupleElement(get_tuple_element->shape(), + get_tuple_element->tuple_index(), + /*alignment=*/1, tuple_ptr, b_); }; if (!get_tuple_element->shape().IsTuple()) { @@ -195,4 +203,101 @@ FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator( return indexed_generators_.at(instruction); } +bool FusedIrEmitter::IsFusedIrEmitterInefficient( + const HloInstruction* consumer, const HloInstruction* producer) { + if (consumer->opcode() != HloOpcode::kFusion) { + return false; + } + // Collects for each instruction in the fusion node from which (indirect) + // users newly created index values are passed. Roughly speaking, we reuse + // index values if the shapes are equal when ignoring the element type (we may + // reuse also if the shape change is a bitcast, but we don't consider that + // here). By ignoring potential reuses our estimate whether the fusion emitter + // is inefficient is a bit more conservative than necessary. + absl::flat_hash_map> + indexing_users; + // Stores the number of different index accesses for each instruction in the + // fusion node. The fusion emitter caches access with the same index, so this + // value indicates how many times a specific instruction will be emitted. 
+ absl::flat_hash_map index_usage_count; + index_usage_count[consumer] = 1; + + auto evaluate_fusion_computation = [&indexing_users, &index_usage_count]( + const HloInstruction* fusion) { + auto postorder = + fusion->fused_instructions_computation()->MakeInstructionPostOrder(); + std::reverse(postorder.begin(), postorder.end()); + for (const auto* instruction : postorder) { + if (instruction->opcode() == HloOpcode::kParameter) { + continue; + } + int64& total = index_usage_count[instruction]; + if (indexing_users[instruction].empty()) { + total = index_usage_count[fusion]; + } else { + total = 0; + for (const auto* user : indexing_users[instruction]) { + int64 weight = 1; + // Concatenate is special: the index differs for each operand, so + // in the worst case we have to deal with as many index values as + // the number of operands of Concatenate. By considering the worst + // case, we are more conservative than necessary regarding + // refusing to fuse. + if (user->opcode() == HloOpcode::kConcatenate) { + weight = user->operand_count(); + } + total += index_usage_count[user] * weight; + } + } + for (const auto* operand : instruction->operands()) { + // For simplicity we assume that all shape and layout changing + // operations invalidate index reuse. + if (Shape::Equal().IgnoreElementType()(operand->shape(), + instruction->shape())) { + // If the index is reused, it means the operand gets index values + // from the same set of (indirect) users as 'instruction' itself. + indexing_users[operand].insert(indexing_users[instruction].begin(), + indexing_users[instruction].end()); + } else { + // If the index is not reused, it means 'instruction' computes a + // new index derived from the index it gets. + indexing_users[operand].insert(instruction); + } + } + } + }; + evaluate_fusion_computation(consumer); + + // Also account for the 'producer' if it would be fused. Find the operand it + // corresponds to. + for (int64 operand_num = 0; operand_num < consumer->operand_count(); + ++operand_num) { + if (consumer->operand(operand_num) == producer) { + auto instruction = consumer->fused_parameter(operand_num); + int64& total = index_usage_count[producer]; + total = 0; + for (const auto* user : indexing_users[instruction]) { + total += index_usage_count[user]; + } + break; + } + } + + // If 'producer' is a fusion node as well, also evaluate it. + if (producer->opcode() == HloOpcode::kFusion) { + evaluate_fusion_computation(producer); + } + + // Sum up the total number of emitted ops. + int64 total = 0; + for (const auto& entry : index_usage_count) { + total += entry.second; + } + + // Check that the code duplication has at most a factor of 15 (where 15 is an + // arbitrary constant that seems to work). + return total > 15 * index_usage_count.size(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index e6d52a580c0..b1aa6d59634 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -91,6 +91,14 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { tiled_parameter_info_ = info; } + // Evaluates whether fusing 'producer' into 'consumer' might cause exponential + // behavior in FusedIrEmitter. We currently can have exponential time/memory + // requirements for emitting certain fusion kernels, in which case we don't + // want to fuse. + // TODO(b/119692968): Remove this once we have fixed our fusion emitter. 
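A plausible call site, sketched only for illustration (the surrounding fusion pass and the `producer`/`consumer` variables are assumptions; the declaration itself follows below):

```c++
// Inside a fusion pass's "should we fuse?" decision:
if (FusedIrEmitter::IsFusedIrEmitterInefficient(/*consumer=*/consumer,
                                                /*producer=*/producer)) {
  return false;  // Fusing would duplicate too much code in the emitter.
}
```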
+ static bool IsFusedIrEmitterInefficient(const HloInstruction* consumer, + const HloInstruction* producer); + protected: // Returns the IrArrays for the fusion instruction operands. llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 8ee07ae8331..241eea87a30 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,6 +29,14 @@ limitations under the License. namespace xla { namespace llvm_ir { +IrArray::Index::Index(absl::Span multidim, + llvm::Value* linear, const Shape& shape, + llvm::Type* index_type) + : Index(multidim, shape, index_type) { + CHECK_NE(linear, nullptr); + linear_ = linear; +} + void IrArray::Index::Delinearize(std::vector* multidim, llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b) const { @@ -74,36 +82,28 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, } IrArray::Index::Index(absl::Span multidim, - llvm::Value* linear, const Shape& shape) + absl::Span dimensions, + llvm::Type* index_type) + : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions), + index_type) {} + +IrArray::Index::Index(absl::Span multidim, + const Shape& shape, llvm::Type* index_type) : multidim_(multidim.begin(), multidim.end()), - linear_(linear), + linear_(nullptr), layout_(shape.layout()), - dims_(shape.dimensions().begin(), shape.dimensions().end()) { - if (size()) { - index_type_ = multidim_[0]->getType(); - } else { - CHECK_NE(linear_, nullptr); - index_type_ = linear_->getType(); - } + dims_(shape.dimensions().begin(), shape.dimensions().end()), + index_type_(index_type) { CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); + for (const auto* dim : multidim) { + CHECK_NE(dim, nullptr); + } CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; } -IrArray::Index::Index(absl::Span multidim, - const Shape& shape, llvm::IRBuilder<>* b) - : multidim_(multidim.begin(), multidim.end()), - layout_(shape.layout()), - dims_(shape.dimensions().begin(), shape.dimensions().end()) { - CHECK_GT(multidim_.size(), 0); - index_type_ = multidim[0]->getType(); - CHECK_NE(index_type_, nullptr); - CHECK_EQ(shape.dimensions_size(), multidim.size()); - CHECK(LayoutUtil::HasLayout(shape)); -} - IrArray::IrArray(llvm::Value* base_ptr, Shape shape) : base_ptr_(base_ptr), shape_(std::move(shape)) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); @@ -117,10 +117,10 @@ IrArray::IrArray(llvm::Value* base_ptr, Shape shape) ++depth; } - if (!shape_->IsArray() || ShapeUtil::IsScalar(*shape_)) { + if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { - DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString(); + DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString(); } } @@ -136,8 +136,7 @@ bool IrArray::Index::LinearValidOnShape(const Shape& a) const { IrArray::Index IrArray::Index::SourceIndexOfReshape( const Shape& output_shape, const Shape& input_shape, llvm::IRBuilder<>* builder) const { - const auto& target_index = *this; - CHECK_EQ(target_index.size(), output_shape.rank()); + CHECK_EQ(multidim_.size(), output_shape.rank()); std::vector> common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); @@ -146,16 +145,16 @@ IrArray::Index 
IrArray::Index::SourceIndexOfReshape( // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { + absl::Span dimensions = + AsInt64Slice(output_shape.dimensions()) + .subspan(common_factors[k].second, + common_factors[k + 1].second - common_factors[k].second); llvm::Value* logical_linear_index = Index(absl::Span(multidim_).subspan( common_factors[k].second, common_factors[k + 1].second - common_factors[k].second), - index_type_) - .Linearize(AsInt64Slice(output_shape.dimensions()) - .subspan(common_factors[k].second, - common_factors[k + 1].second - - common_factors[k].second), - builder); + dimensions, index_type_) + .Linearize(dimensions, builder); // Delinearizes logical_linear_index for the source array in row-major // collapsed order. The first rank-1 indices are the remainder of the // linear index by each dimension size. @@ -178,30 +177,30 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) && LayoutUtil::HasLayout(output_shape) && ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { - return Index(source_multidim_index, linear(), input_shape); + return Index(source_multidim_index, linear(), input_shape, index_type_); } - return Index(source_multidim_index, index_type_); + return Index(source_multidim_index, input_shape, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfSlice( - const Shape& shape, absl::Span starts, + const Shape& operand_shape, absl::Span starts, absl::Span strides, llvm::IRBuilder<>* builder) const { - Index source_index(index_type_, multidim_.size()); + std::vector source_multi_index(multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; auto type = multidim_[i]->getType(); if (stride != 1) { - source_index[i] = builder->CreateAdd( + source_multi_index[i] = builder->CreateAdd( builder->CreateMul(multidim_[i], llvm::ConstantInt::get(type, stride)), llvm::ConstantInt::get(type, starts[i])); } else { - source_index[i] = builder->CreateAdd( + source_multi_index[i] = builder->CreateAdd( multidim_[i], llvm::ConstantInt::get(type, starts[i])); } } - return source_index; + return Index(source_multi_index, operand_shape, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfTranspose( @@ -214,10 +213,10 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose( if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) && LayoutUtil::HasLayout(shape) && ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { - return Index(operand_multidim_index, linear(), operand_shape); + return Index(operand_multidim_index, linear(), operand_shape, index_type_); } - return Index(operand_multidim_index); + return Index(operand_multidim_index, operand_shape, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfBitcast( @@ -246,11 +245,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( scale *= shape.dimensions(dimension); } - // Now delinearize it for the input of the bitcast. 
- std::vector multi_index(operand_shape.dimensions_size()); - Delinearize(&multi_index, linear_index, operand_shape, builder); - - return Index(multi_index, linear_index, operand_shape); + return Index(linear_index, operand_shape, builder); } IrArray::Index IrArray::Index::SourceIndexOfBroadcast( @@ -264,7 +259,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( } if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || !LayoutUtil::HasLayout(shape)) { - return Index(source_index, index_type_); + return Index(source_index, operand_shape, index_type_); } // High-level idea: we can reuse the linear index if the broadcasted // dimensions are contiguous, and this part of the operation is a bitcast. @@ -286,7 +281,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( bool contiguous_broadcast_dimensions = max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; if (!contiguous_broadcast_dimensions) { - return Index(source_index, index_type_); + return Index(source_index, operand_shape, index_type_); } // Check if the mapped dimensions are a bitcast. std::vector operand_logical_to_physical = @@ -294,7 +289,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( for (int64 i = 0; i < rank; ++i) { if (operand_logical_to_physical[i] != logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { - return Index(source_index, index_type_); + return Index(source_index, operand_shape, index_type_); } } llvm::Value* linear = linear_; @@ -303,9 +298,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } if (divisor > 1) { - linear = builder->CreateUDiv( - linear, - IrArray::Index(linear->getType()).GetConstantWithIndexType(divisor)); + linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor)); } if (min_broadcasted_dimension > 0) { int64 mod = 1; @@ -313,11 +306,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( ++i) { mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - linear = builder->CreateURem( - linear, - IrArray::Index(linear->getType()).GetConstantWithIndexType(mod)); + linear = builder->CreateURem(linear, GetConstantWithIndexType(mod)); } - return Index(source_index, linear, operand_shape); + return Index(source_index, linear, operand_shape, index_type_); } llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, @@ -341,20 +332,22 @@ llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, llvm::IRBuilder<>* b, - absl::string_view name) const { - if (ShapeUtil::IsScalar(*shape_)) { + absl::string_view name, + bool use_linear_index) const { + if (ShapeUtil::IsScalar(shape_)) { // Special handling of scalars: a scalar pretends to have the same value for // every index, thus effectively implementing broadcasting of its value // over higher-rank arrays. 
return base_ptr_; } - CHECK_EQ(index.size(), shape_->rank()); + CHECK_EQ(index.size(), shape_.rank()); + CHECK(index.ShapeIsCompatible(shape_)); - if (index.LinearValidOnShape(*shape_)) { + if (use_linear_index && index.LinearValidOnShape(shape_)) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); return b->CreateInBoundsGEP( b->CreateBitCast(base_ptr_, - PrimitiveTypeToIrType(shape_->element_type(), module) + PrimitiveTypeToIrType(shape_.element_type(), module) ->getPointerTo()), {index.linear()}, llvm_ir::AsStringRef(name)); } @@ -364,7 +357,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, // When dimension i is of size 1, LLVM optimization is able to replace // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to // produce better code in some cases. - auto dim = shape_->dimensions(i); + auto dim = shape_.dimensions(i); actual_index.push_back( dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]); } @@ -377,8 +370,8 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, CHECK_GT(index.size(), 0); std::vector gep_indices( 1, llvm::ConstantInt::get(index[0]->getType(), 0)); - for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { - int64 dimension = LayoutUtil::Major(shape_->layout(), i); + for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); gep_indices.push_back(actual_index[dimension]); } return b->CreateInBoundsGEP(base_ptr_, gep_indices, @@ -399,16 +392,20 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - absl::string_view name) const { - llvm::Value* element_address = EmitArrayElementAddress(index, b, name); + absl::string_view name, + bool use_linear_index) const { + llvm::Value* element_address = + EmitArrayElementAddress(index, b, name, use_linear_index); llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); return load; } void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, - llvm::IRBuilder<>* b) const { - llvm::Value* element_address = EmitArrayElementAddress(index, b); + llvm::IRBuilder<>* b, + bool use_linear_index) const { + llvm::Value* element_address = + EmitArrayElementAddress(index, b, "", use_linear_index); llvm::StoreInst* store = b->CreateStore(value, element_address); AnnotateLoadStoreInstructionWithMetadata(store); } @@ -423,18 +420,5 @@ IrArray IrArray::CastToShape(const Shape& new_shape, return new_irarray; } -/* static */ IrArray::Index IrArray::BumpIndex(const Index& index, - int64 which_dimension, - int64 addend, - llvm::IRBuilder<>* b) { - Index new_index = index; - new_index[which_dimension] = b->CreateAdd( - index[which_dimension], - llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "", - /*HasNUW=*/true, - /*HasNSW=*/true); - return new_index; -} - } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index b706ebd311c..b043f95b1de 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -42,12 +43,9 @@ namespace llvm_ir { // are supported. class IrArray { public: - // A multidimensional index into an IrArray. The index for dimension zero is - // first in the vector. This is the reverse order of the notation used for - // describing the dimensions of an array. That is, for a [4 x 3 x 2] array - // dimension zero has size 2, dimension one has size 3, and dimension two has - // size 4. Thus the index {1, 2, 3} indexes the last element of this [4 x 3 x - // 2] array. + // A multidimensional index into an IrArray. All the runtime indices + // (multidim) and dimensions (Shape::dimensions(), absl::Span) + // are major-first. // // This may also keep a linear index and the layout and dimensions it was // emitted for; if the shape where this `Index` is used matches, the linear @@ -55,39 +53,11 @@ class IrArray { // multidimensional index, which LLVM DCE can delete. class Index { public: - // Constructs an index of rank "size". Each dimension of the index is - // initialized to "value". - explicit Index(size_t size, llvm::Value* value) - : multidim_(size, value), index_type_(value->getType()) { - CHECK_NE(index_type_, nullptr); - } - - // Constructs an index of rank "size". Each dimension of the index is - // initialized to nullptr. - explicit Index(llvm::Type* index_ty, size_t size = 0) - : multidim_(size, nullptr), index_type_(index_ty) { + // Constructs an index for a scalar shape. + explicit Index(llvm::Type* index_ty) : index_type_(index_ty) { CHECK(index_ty->isIntegerTy()); } - // Constructs an index from multi-dimensional index "multidim". The linear - // index is set to nullptr. - explicit Index(absl::Span multidim, - llvm::Type* index_ty = nullptr) - : multidim_(multidim.begin(), multidim.end()) { - if (size() == 0) { - index_type_ = index_ty; - } else { - index_type_ = (*this)[0]->getType(); - if (index_ty != nullptr) { - CHECK_EQ(index_type_, index_ty); - } - } - CHECK_NE(index_type_, nullptr); - CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) { - return index_type_ == v->getType(); - })); - } - // Constructs an index from linear index "linear" and computes the // multi-dimensional index from "linear" and "shape". "b" is the IR // builder to emit the index of each dimension in the multi-dimensional @@ -96,70 +66,65 @@ class IrArray { // Precondition: "shape" has a layout. Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b); - // Constructs an index from the given multi-dimensional index and the shape - // that it indexes into. + // Constructs an index from a multi-dimensional index. 'shape' is the shape + // for which the multi-dimensional index is used. 'index_type' is the type + // of the index. // // Precondition: "shape" has a layout. Index(absl::Span multidim, const Shape& shape, - llvm::IRBuilder<>* b); + llvm::Type* index_type); - // Constructs an index from both a multi-dimensional index and a linear - // index. "shape" has the same meaning as that in the constructor that takes - // only a linear index. - Index(absl::Span multidim, llvm::Value* linear, - const Shape& shape); + // Same as above, but only the dimensions of the shape without layout is + // passed. The layout is assumed to be the default (descending + // minor-to-major) layout. 
+ Index(absl::Span multidim, + absl::Span dimensions, llvm::Type* index_type); // Returns an index that adds `addend` to the given `dim` of the object. Index AddOffsetToDim(llvm::Value* addend, int64 dim, llvm::IRBuilder<>* b) const { - IrArray::Index index = *this; - index[dim] = b->CreateAdd(index[dim], addend); - return index; + Index with_offset = *this; + with_offset.linear_ = nullptr; + with_offset.multidim_[dim] = + b->CreateAdd(with_offset.multidim_[dim], addend); + return with_offset; } const std::vector& multidim() const { return multidim_; } + const std::vector& dims() const { return dims_; } llvm::Value* linear() const { return linear_; } size_t size() const { return multidim().size(); } llvm::Value* operator[](size_t i) const { return multidim()[i]; } - llvm::Value*& operator[](size_t i) { return mutable_multidim()[i]; } - void push_back(llvm::Value* value) { mutable_multidim().push_back(value); } - void InsertAt(int64 index, llvm::Value* value) { - CHECK_LE(index, size()); - mutable_multidim().insert(mutable_multidim().begin() + index, value); - } - void InsertAt(int64 index, int64 count, llvm::Value* value) { - CHECK_LE(index, size()); - mutable_multidim().insert(mutable_multidim().begin() + index, count, - value); - } - - using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; - iterator begin() { return mutable_multidim().begin(); } - iterator end() { return mutable_multidim().end(); } - const_iterator begin() const { return multidim().begin(); } const_iterator end() const { return multidim().end(); } - llvm::Value* back() const { return multidim().back(); } - bool LinearValidOnShape(const Shape& a) const; - // Given that "this" is the target index of a reshape from `operand_shape` - // to `shape`, returns the source index. - Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape, + bool ShapeIsCompatible(const Shape& a) const { + Shape own_shape = ShapeUtil::MakeShape(a.element_type(), dims_); + *own_shape.mutable_layout() = layout_; + return ShapeUtil::Equal(own_shape, a); + } + + // Given that "this" is the target index of a reshape from `input_shape` + // to `output_shape`, returns the source index. + Index SourceIndexOfReshape(const Shape& output_shape, + const Shape& input_shape, llvm::IRBuilder<>* builder) const; // Returns the index into the source operand from which a slice operation // selects a value to be placed into index "this". The slice is described // by starting indices `starts` and stride values `strides`. // - // Precondition: "this" is an index into a slice whose shape is `shape`. - Index SourceIndexOfSlice(const Shape& shape, absl::Span starts, + // Precondition: "this" is an index into a slice whose operand shape is + // `operand_shape`. + Index SourceIndexOfSlice(const Shape& operand_shape, + absl::Span starts, absl::Span strides, llvm::IRBuilder<>* builder) const; @@ -194,14 +159,14 @@ class IrArray { return llvm::ConstantInt::get(index_type_, c); } - void ClearLinearIndex() { linear_ = nullptr; } - private: - // Changing the multi-dimensional index invalidates the linear index. - std::vector& mutable_multidim() { - linear_ = nullptr; - return multidim_; - } + // Constructs an index from both a multi-dimensional index and a linear + // index. 'shape' is the shape on which the index is used. 'index_type' is + // the type of the index. + // + // Precondition: "shape" has a layout. 
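Taken together, the two multi-dimensional constructors above are typically used as in the following sketch (the builder `b`, the index values `i` and `j`, the shape, and the `array` are assumptions made up for illustration):

```c++
// Build a 2-D index {i, j} against a shape that carries an explicit layout...
IrArray::Index index({i, j}, shape, b->getInt64Ty());

// ...or, when only the logical dimensions are known, rely on the default
// descending (minor-to-major) layout.
IrArray::Index index_default_layout({i, j}, /*dimensions=*/{128, 256},
                                    b->getInt64Ty());

// The index can then be handed to an IrArray to emit an element load.
llvm::Value* element = array.EmitReadArrayElement(index, b);
```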
+ Index(absl::Span multidim, llvm::Value* linear, + const Shape& shape, llvm::Type* index_type); void Delinearize(std::vector* multidim, llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b) const; @@ -242,9 +207,7 @@ class IrArray { llvm::Value* GetBasePointer() const { return base_ptr_; } llvm::Type* GetElementLlvmType() const { return element_type_; } - const Shape& GetShape() const { - return *shape_; - } + const Shape& GetShape() const { return shape_; } // Emit a sequence of instructions to compute the address of the element in // the given array at the given index. Returns the address of the element as @@ -253,7 +216,8 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, - absl::string_view name = "") const; + absl::string_view name = "", + bool use_linear_index = true) const; // Attach metadata this IrArray instance knows about to "instruction". void AnnotateLoadStoreInstructionWithMetadata( @@ -266,15 +230,23 @@ class IrArray { // // The optional name is useful for debugging when looking at // the emitted LLVM IR. + // 'use_linear_index' can be used to specify whether the linear index (if + // available) or the multi-dimensional index should be used. llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - absl::string_view name = "") const; + absl::string_view name = "", + bool use_linear_index = true) const; // Emit IR to write the given value to the array element at the given index. + // 'use_linear_index' can be used to specify whether the linear index (if + // available) or the multi-dimensional index should be used. void EmitWriteArrayElement(const Index& index, llvm::Value* value, - llvm::IRBuilder<>* b) const; + llvm::IRBuilder<>* b, + bool use_linear_index = true) const; // Returns a new IrArray whose shape is "new_shape" and base pointer is a // bitcast of the base pointer of "this" IrArray. + // 'use_linear_index' can be used to specify whether the linear index (if + // available) or the multi-dimensional index should be used. IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const; void AddAliasScopeMetadata(llvm::MDNode* alias_scope) { @@ -318,11 +290,6 @@ class IrArray { const std::map& metadata() const { return metadata_; } - // Bumps the "which_dimension" value within the provided index by the provided - // addend. - static Index BumpIndex(const Index& index, int64 which_dimension, - int64 addend, llvm::IRBuilder<>* b); - private: // Add the specified LLVM IR metadata to loads/stores associated with this // IrArray. @@ -337,7 +304,7 @@ class IrArray { llvm::Type* element_type_; // Shape of the XLA array. - absl::optional shape_; + Shape shape_; // The list of key/value pairs used when attaching metadata to emitted // loads/stores for this array. 
They keys are the metadata kinds and the diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 5eeb29c478a..e1dc7e74765 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -70,7 +70,7 @@ Status KernelSupportLibrary::IfWithStatus( } void KernelSupportLibrary::EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + const HloModuleConfig& module_config, llvm::IRBuilder<>* b, absl::string_view kernel_name, KernelSupportLibrary::ArgumentVector arguments, const std::function& @@ -101,10 +101,9 @@ void KernelSupportLibrary::EmitAndCallOutlinedKernel( auto* function_type = llvm::FunctionType::get(b->getVoidTy(), arg_types, /*isVarArg=*/false); - function = llvm_ir::CreateFunction( - function_type, llvm::GlobalValue::InternalLinkage, - /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, kernel_name, module); + function = llvm_ir::CreateCpuFunction(function_type, + llvm::GlobalValue::InternalLinkage, + module_config, kernel_name, module); llvm::IRBuilder<>::InsertPointGuard guard(*b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 612b839cfa1..b66ce6b835e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -263,33 +263,33 @@ class KernelSupportLibrary { // in a nullptr llvm::Value* in its position to `kernel_body_generator`. // Currently we only support at most one nullptr value in `arguments`. static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + const HloModuleConfig& module_config, llvm::IRBuilder<>* b, absl::string_view kernel_name, ArgumentVector arguments, const std::function& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. 
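A hypothetical call to the general form might look like this sketch (the kernel name, the argument values, and `module_config` are assumptions):

```c++
KernelSupportLibrary::EmitAndCallOutlinedKernel(
    module_config, b, "dot_kernel", {lhs_ptr, rhs_ptr, out_ptr},
    [&](KernelSupportLibrary::ArgumentVector args) {
      // Emit the kernel body here using args[0], args[1] and args[2]; the
      // body lands in an outlined function that is then called at this point.
    });
```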
static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + const HloModuleConfig& module_config, llvm::IRBuilder<>* b, absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, const std::function& kernel_body_generator) { - EmitAndCallOutlinedKernel( - enable_fast_math, optimize_for_size, b, kernel_name, {arg0, arg1, arg2}, - [&](ArgumentVector args) { - kernel_body_generator(args[0], args[1], args[2]); - }); + EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2}, + [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], + args[2]); + }); } static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + const HloModuleConfig& module_config, llvm::IRBuilder<>* b, absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, llvm::Value* arg3, const std::function& kernel_body_generator) { EmitAndCallOutlinedKernel( - enable_fast_math, optimize_for_size, b, kernel_name, - {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) { + module_config, b, kernel_name, {arg0, arg1, arg2, arg3}, + [&](ArgumentVector args) { kernel_body_generator(args[0], args[1], args[2], args[3]); }); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index cd8dd72cd77..2ef844ffa62 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -53,28 +54,6 @@ Shape MergeDimensions(absl::Span segs, const Shape& shape) { dimensions); } -// Given an index for a shape, return the equivalent new index if the shape is -// reshaped to another shape. 
-IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape, - const Shape& reshaped_shape, - llvm::IRBuilder<>* b) { - auto bounds = shape.dimensions(); - auto minor_to_major = shape.layout().minor_to_major(); - llvm::Value* linear_index = index.GetConstantWithIndexType(0); - int64 multiplier = 1; - for (int i = 0; i < index.size(); ++i) { - int64 dim = minor_to_major[i]; - llvm::Value* addend = b->CreateMul( - index[dim], index.GetConstantWithIndexType(multiplier), "linearizing", - /*HasNUW=*/true, /*HasNSW=*/true); - linear_index = b->CreateAdd(linear_index, addend, "", - /*HasNUW=*/true, /*HasNSW=*/true); - multiplier *= bounds[dim]; - } - - return IrArray::Index(linear_index, reshaped_shape, b); -} - } // namespace absl::optional > FindTranspose021(const Shape& a, @@ -150,15 +129,14 @@ IrArray::Index KernelMappingScheme::GetUnnormalizedIndex( const IrArray::Index& normalized_shape_index, const Shape& unnormalized_shape) { DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size()); - Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout( - unnormalized_shape.element_type(), GetDimensionsInElements()); - return GetReshapedIndex(normalized_shape_index, output_shape, - unnormalized_shape, b_); + llvm::Value* linear = + normalized_shape_index.Linearize(GetDimensionsInElements(), b_); + return IrArray::Index(linear, unnormalized_shape, b_); } IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); + llvm::Value* block_id = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdx, {}, {}, b_); llvm_ir::AddRangeMetadata(0, GetNumberOfBlocks(), llvm::cast(block_id)); llvm::Value* linear_block_id = @@ -180,20 +158,20 @@ IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), "block_origin." + std::to_string(i))); } - return IrArray::Index(multidim, block_index[0]->getType()); + return IrArray::Index(multidim, dims_in_tiles_, block_index.GetType()); } IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( const IrArray::Index& tile_index) { - IrArray::Index elem_index = tile_index; + std::vector elem_multi_index = tile_index.multidim(); for (int i = DimY; i < DimTot; ++i) { - elem_index[i] = + elem_multi_index[i] = b_->CreateMul(tile_index[i], llvm::ConstantInt::get(tile_index[i]->getType(), GetTileSizeForDimension(i)), "tile_origin." + std::to_string(i)); } - return elem_index; + return IrArray::Index(elem_multi_index, dims_in_elems_, tile_index.GetType()); } llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType( @@ -218,8 +196,8 @@ std::tuple KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { // Calculate (y, x) coordinate of the thread in the 2D view of thread block // defined by (num_thread_y, num_thread_x) from thread_id. 
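The decomposition described in the comment above is just integer division and remainder; a scalar model, with made-up values, purely for intuition:

```c++
#include <cstdint>
#include <utility>

// Splits a flat thread id into its (y, x) coordinate within a thread block of
// shape (num_thread_y, num_thread_x). E.g. with num_thread_x = 32, thread id
// 70 maps to y = 2, x = 6.
std::pair<int64_t, int64_t> ThreadYX(int64_t thread_id, int64_t num_thread_x) {
  return {thread_id / num_thread_x, thread_id % num_thread_x};
}
```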
- llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); + llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b_); llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw); llvm::Value* thread_id_int = b_->CreateIntCast(thread_id_raw, index_ty, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 3a35405a2da..2e769b5588a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -73,8 +74,8 @@ void ForLoop::Emit(llvm::IRBuilder<>* b) { // Split the preheader to create an exit basic block. The exit basic block // will contain all instructions at or after insert_point. - exit_bb_ = preheader_bb_->splitBasicBlock( - insert_point, AsStringRef(GetQualifiedName("loop_exit"))); + exit_bb_ = preheader_bb_->splitBasicBlock(insert_point, + GetQualifiedName("loop_exit")); // splitBasicBlock adds an unconditional branch between the split basic // blocks. Remove it. An unconditional branch will be added below from the @@ -94,9 +95,8 @@ void ForLoop::Emit(llvm::IRBuilder<>* b) { llvm::Function* func = preheader_bb_->getParent(); b->SetInsertPoint(&func->getEntryBlock(), func->getEntryBlock().getFirstInsertionPt()); - llvm::Value* indvar_address = - b->CreateAlloca(start_index_->getType(), nullptr, - AsStringRef(GetQualifiedName("invar_address"))); + llvm::Value* indvar_address = b->CreateAlloca( + start_index_->getType(), nullptr, GetQualifiedName("invar_address")); // Preheader basic block. // Initialize induction variable starting index. Create branch to the header. @@ -110,8 +110,7 @@ void ForLoop::Emit(llvm::IRBuilder<>* b) { // Emit the loop conditional branch. Load and compare indvar with ending // index and jump to loop exit if equal. Jump to body otherwise. 
b->SetInsertPoint(header_bb_); - indvar_ = - b->CreateLoad(indvar_address, AsStringRef(GetQualifiedName("indvar"))); + indvar_ = b->CreateLoad(indvar_address, GetQualifiedName("indvar")); llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_); b->CreateCondBr(/*Cond=*/exit_cond, /*True=*/exit_bb_, /*False=*/body_bb_); @@ -236,25 +235,26 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, absl::string_view suffix) { std::vector dimensions(shape.rank()); std::iota(dimensions.begin(), dimensions.end(), 0); - return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); + return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix), + shape, index_type_); } -IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( +std::vector ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, absl::Span dimensions, absl::string_view suffix) { - llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); + std::vector multi_index(shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ llvm_ir::IrName(suffix, absl::StrCat(dimension))); - index[dimension] = loop->GetIndVarValue(); + multi_index[dimension] = loop->GetIndVarValue(); } - return index; + return multi_index; } -IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( +std::vector ForLoopNest::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, absl::string_view name_suffix) { // Prepares the dimension list we will use to emit the loop nest. Outermost @@ -262,26 +262,28 @@ IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( // 'dimension_to_skip' dimension. std::vector dimensions; const Shape& shape = operand_array.GetShape(); + // Initially get the dimensions in minor to major order, then reverse them. for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { if (dimension != dimension_to_skip) { dimensions.push_back(dimension); } } + absl::c_reverse(dimensions); // Create loop nest with one for-loop for each dimension of the // output. - llvm_ir::IrArray::Index index = + std::vector multi_index = AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); // Verify every dimension except the 'dimension_to_skip' dimension was set in // the index. - for (size_t dimension = 0; dimension < index.size(); ++dimension) { + for (size_t dimension = 0; dimension < multi_index.size(); ++dimension) { if (dimension == dimension_to_skip) { - DCHECK_EQ(nullptr, index[dimension]); + DCHECK_EQ(nullptr, multi_index[dimension]); } else { - DCHECK_NE(nullptr, index[dimension]); + DCHECK_NE(nullptr, multi_index[dimension]); } } - return index; + return multi_index; } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index ac3bba3c9fd..1dbc9745c08 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -43,6 +43,9 @@ enum class UnrollMode { // A class for constructing a for-loop in LLVM IR. class ForLoop { public: + ForLoop(const ForLoop&) = delete; + ForLoop& operator=(const ForLoop&) = delete; + // Emit a for-loop at the current insert point of the given IRBuilder. // // start_index and end_index are the loop bounds (end_index is not inclusive). 
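As an illustration of how these loop helpers are typically driven (a sketch only; the builder `b`, `shape`, and the source/target IrArrays are assumptions, and the basic-block accessors are the convenience methods declared in llvm_loop.h):

```c++
llvm_ir::ForLoopNest loop_nest("copy", b, b->getInt64Ty());
llvm_ir::IrArray::Index index = loop_nest.AddLoopsForShape(shape, "elem");

// Emit the loop body: copy one element per iteration.
llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b);
llvm::Value* value = source_array.EmitReadArrayElement(index, b);
target_array.EmitWriteArrayElement(index, value, b);

// Code emitted after the loop nest goes into the outer exit block.
llvm_ir::SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), b);
```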
@@ -169,18 +172,11 @@ class ForLoop { llvm::Value* indvar_; UnrollMode unroll_mode_; bool prevent_vectorization_; - - TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); }; // A simple class for constructing nested for-loops. class ForLoopNest { public: - explicit ForLoopNest(llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) - : ForLoopNest(/*name=*/"", b) { - SetIndexType(index_ty); - } - ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) : name_(name), @@ -190,6 +186,8 @@ class ForLoopNest { b_(b) { SetIndexType(index_ty); } + ForLoopNest(const ForLoopNest&) = delete; + ForLoopNest& operator=(const ForLoopNest&) = delete; // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have @@ -241,7 +239,7 @@ class ForLoopNest { // The return value is an index with the induction variables. The // size equals the rank of shape and there is a null for each // dimension that is not in "dimensions". - IrArray::Index AddLoopsForShapeOnDimensions( + std::vector AddLoopsForShapeOnDimensions( const Shape& shape, absl::Span dimensions, absl::string_view suffix); @@ -252,9 +250,9 @@ class ForLoopNest { // dimensions of the index are filled except for 'dimension_to_skip'. // name_suffix is the string to append to the names of LLVM constructs (eg, // basic blocks) constructed by this method. - IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, - int64 dimension_to_skip, - absl::string_view name_suffix); + std::vector EmitOperandArrayLoopNest( + const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, + absl::string_view name_suffix); // Convenience methods which return particular basic blocks of the outermost // or innermost loops. These methods return nullptr if no loops have been @@ -289,8 +287,6 @@ class ForLoopNest { llvm::IRBuilder<>* b_; llvm::Type* index_type_; - - TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest); }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 807296329c0..815598929f5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -31,6 +31,8 @@ limitations under the License. #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -58,14 +60,6 @@ llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) { } // namespace -string AsString(const std::string& str) { - return string(str.data(), str.length()); -} - -llvm::StringRef AsStringRef(absl::string_view str) { - return llvm::StringRef(str.data(), str.size()); -} - std::unique_ptr DropConstantInitializers( const llvm::Module& module) { std::unique_ptr cloned_module = CloneModule(module); @@ -81,7 +75,7 @@ string DumpModuleToString(const llvm::Module& module) { llvm::raw_string_ostream ostream(buffer_string); module.print(ostream, nullptr); ostream.flush(); - return AsString(buffer_string); + return buffer_string; } llvm::CallInst* EmitCallToIntrinsic( @@ -200,7 +194,7 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } // A Tuple contains an array of pointers. 
Use i8*. case TUPLE: // An Opaque is like a void*, use i8*. - case OPAQUE: + case OPAQUE_TYPE: return llvm::Type::getInt8PtrTy(module->getContext()); case TOKEN: // Tokens do not have a physical representation, but the compiler needs @@ -248,7 +242,7 @@ StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, return InternalError("Encoded shape size exceeded int32 size limit."); } *shape_size = static_cast(encoded_shape.size()); - return b->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape)); + return b->CreateGlobalStringPtr(encoded_shape); } StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, @@ -293,7 +287,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, absl::string_view name, llvm::IRBuilder<>* b, int alignment) { - llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); + llvm::IRBuilder<>::InsertPointGuard guard(*b); llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), function->getEntryBlock().getFirstInsertionPt()); @@ -302,7 +296,6 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, if (alignment != 0) { alloca->setAlignment(alignment); } - b->restoreIP(insert_point); return alloca; } @@ -334,7 +327,7 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after"))); + b->GetInsertPoint(), absl::StrCat(name, "-after")); } // Our basic block should now end with an unconditional branch. Remove it; @@ -507,24 +500,25 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { return ShapeUtil::ByteSizeOf(shape, pointer_size); } -llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) { +llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) { llvm::FastMathFlags flags; - if (fast_math_enabled) { - // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, - // AllowReciprocal, AllowContract, and ApproxFunc. - flags.setFast(); + if (!module_config.debug_options().xla_cpu_enable_fast_math()) { + return flags; } - return flags; -} -void SetTargetOptions(bool fast_math_enabled, - llvm::TargetOptions* target_options) { - // In LLVM backend flags, UnsafeFPMath does not explicitly imply - // NoInfs, etc. - target_options->UnsafeFPMath = fast_math_enabled; - target_options->NoInfsFPMath = fast_math_enabled; - target_options->NoNaNsFPMath = fast_math_enabled; - target_options->NoSignedZerosFPMath = fast_math_enabled; + // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal, + // AllowContract, and ApproxFunc. 
+ flags.setFast(); + + if (module_config.debug_options().xla_cpu_fast_math_honor_nans()) { + flags.setNoNaNs(false); + } + + if (module_config.debug_options().xla_cpu_fast_math_honor_infs()) { + flags.setNoInfs(false); + } + + return flags; } std::map MergeMetadata( @@ -575,14 +569,6 @@ std::map MergeMetadata( return result; } -static string GetProcessUniqueIrFileName(absl::string_view prefix) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); - - tensorflow::mutex_lock lock(mu); - return uniquer->GetUniqueName(prefix); -} - static Status CreateAndWriteStringToFile(const string& directory_name, const string& file_name, const string& text) { @@ -596,35 +582,34 @@ static Status CreateAndWriteStringToFile(const string& directory_name, return Status::OK(); } -Status DumpIRToDirectory(const string& directory_name, - const string& hlo_module_name, - const llvm::Module& llvm_module, bool optimized) { +void DumpIrIfEnabled(const HloModule& hlo_module, + const llvm::Module& llvm_module, bool optimized) { + const auto& debug_opts = hlo_module.config().debug_options(); + if (!DumpingEnabledForHloModule(hlo_module)) { + return; + } // We can end up compiling different modules with the same name when using // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously // dumped from the same process in such cases. - string unique_and_safe_file_name = GetProcessUniqueIrFileName( - absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", - optimized ? "with" : "no", "-opt")); - - string ir_file_name = tensorflow::io::JoinPath( - directory_name, absl::StrCat(unique_and_safe_file_name, ".ll")); + string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt"); + DumpToFileInDirOrStdout(hlo_module, absl::StrCat(suffix, ".ll"), + DumpModuleToString(llvm_module)); // For some models the embedded constants can be huge, so also dump the module - // with the constants stripped to get IR that is easier to manipulate. - string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( - directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll")); - - TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( - directory_name, ir_file_name, DumpModuleToString(llvm_module))); - return CreateAndWriteStringToFile( - directory_name, ir_no_constant_initializers_file_name, - DumpModuleToString(*DropConstantInitializers(llvm_module))); + // with the constants stripped to get IR that is easier to manipulate. Skip + // this if we're dumping to stdout; there's no point in duplicating everything + // when writing to the terminal. + if (!DumpingToStdout(debug_opts)) { + DumpToFileInDir(hlo_module, absl::StrCat(suffix, "-noconst.ll"), + DumpModuleToString(*DropConstantInitializers(llvm_module))); + } } -llvm::Function* CreateFunction(llvm::FunctionType* function_type, - llvm::GlobalValue::LinkageTypes linkage, - bool enable_fast_math, bool optimize_for_size, - absl::string_view name, llvm::Module* module) { +llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + const HloModuleConfig& module_config, + absl::string_view name, + llvm::Module* module) { llvm::Function* function = llvm::Function::Create(function_type, linkage, AsStringRef(name), module); function->setCallingConv(llvm::CallingConv::C); @@ -634,17 +619,23 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type, // created by the JIT compiled code. 
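The new `GetCpuFastMathFlags` above starts from the fully "fast" setting and then selectively re-enables NaN and infinity semantics when the honor flags are set. A standalone sketch of that control flow, with hypothetical stand-in structs rather than the real `DebugOptions` and `llvm::FastMathFlags`:

```cpp
#include <iostream>

struct FastMathOptions {
  bool enable_fast_math = true;
  bool honor_nans = false;
  bool honor_infs = false;
};

struct FastMathFlagsLike {
  bool allow_reassoc = false;
  bool no_nans = false;
  bool no_infs = false;
  bool no_signed_zeros = false;
  void SetFast() { allow_reassoc = no_nans = no_infs = no_signed_zeros = true; }
};

FastMathFlagsLike BuildFlags(const FastMathOptions& opts) {
  FastMathFlagsLike flags;
  if (!opts.enable_fast_math) return flags;  // Stay fully conservative.
  flags.SetFast();                           // Enable everything first...
  if (opts.honor_nans) flags.no_nans = false;  // ...then keep NaN semantics.
  if (opts.honor_infs) flags.no_infs = false;  // ...and infinity semantics.
  return flags;
}

int main() {
  FastMathOptions opts;
  opts.honor_nans = true;
  const FastMathFlagsLike flags = BuildFlags(opts);
  std::cout << "no_nans=" << flags.no_nans << " no_infs=" << flags.no_infs << "\n";
  return 0;
}
```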
function->setHasUWTable(); - if (enable_fast_math) { + if (module_config.debug_options().xla_cpu_enable_fast_math()) { function->addFnAttr("unsafe-fp-math", "true"); - function->addFnAttr("no-infs-fp-math", "true"); - function->addFnAttr("no-nans-fp-math", "true"); function->addFnAttr("no-signed-zeros-fp-math", "true"); + + if (!module_config.debug_options().xla_cpu_fast_math_honor_nans()) { + function->addFnAttr("no-nans-fp-math", "true"); + } + + if (!module_config.debug_options().xla_cpu_fast_math_honor_infs()) { + function->addFnAttr("no-infs-fp-math", "true"); + } } // Add the optize attribute to the function if optimizing for size. This // controls internal behavior of some optimization passes (e.g. loop // unrolling). - if (optimize_for_size) { + if (cpu::options::OptimizeForSizeRequested(module_config)) { function->addFnAttr(llvm::Attribute::OptimizeForSize); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index c604c7c870a..7b7d86364e2 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -45,14 +45,13 @@ class TargetOptions; namespace xla { namespace llvm_ir { -// Convert a std::string (used by LLVM's interfaces) to string. -string AsString(const std::string& str); - // Convert a absl::string_view to a llvm::StringRef. Note: both // absl::string_view and llvm::StringRef are non-owning pointers into a // string in memory. This method is used to feed strings to LLVM // & Clang APIs that expect llvm::StringRef. -llvm::StringRef AsStringRef(absl::string_view str); +inline llvm::StringRef AsStringRef(absl::string_view str) { + return llvm::StringRef(str.data(), str.size()); +} template llvm::ArrayRef AsArrayRef(const std::vector& vec) { @@ -71,7 +70,7 @@ string DumpToString(const T& entity) { llvm::raw_string_ostream ostream(buffer_string); entity.print(ostream); ostream.flush(); - return AsString(buffer_string); + return buffer_string; } // Dump the given LLVM module to a string. This requires a function distinct @@ -264,12 +263,7 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout); // Gets an llvm::FastMathFlags that reflects the settings in the given // module config. -llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled); - -// Sets values in the given TargetOptions struct according to the given -// compilation options. -void SetTargetOptions(bool fast_math_enabled, - llvm::TargetOptions* target_options); +llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config); // Computes a conservative union of the metadata in "a" and "b". For // aliasing-related metadata, this means the result can be applied to @@ -279,19 +273,19 @@ std::map MergeMetadata( llvm::LLVMContext* context, const std::map& a, const std::map& b); -// Dumps out `llvm_module` to a file in the directory named `directory_name`, -// creating the directory if necessary. A sanitized version of -// `hlo_module_name` is incorporated into the file name. If `optimized` is true -// then a suffix of "-with-opt.ll" is used, else a suffix of "-no-opt.ll" is -// used. -Status DumpIRToDirectory(const string& directory_name, - const string& hlo_module_name, - const llvm::Module& llvm_module, bool optimized); +// Dumps out `llvm_module` to the path specified in DebugOptions, if dumping is +// enabled for the given HLO module. +// +// A sanitized version of `hlo_module_name` is incorporated into the file name. 
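The header change above turns `AsStringRef` into an inline conversion between two non-owning (pointer, length) views. A small sketch of the same idea with a hypothetical `BorrowedRef` standing in for `llvm::StringRef`:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <string_view>

// Both view types merely point into memory owned elsewhere, so the conversion
// copies no characters and is cheap enough to inline.
struct BorrowedRef {
  const char* data;
  std::size_t size;
};

inline BorrowedRef AsBorrowedRef(std::string_view str) {
  return BorrowedRef{str.data(), str.size()};
}

int main() {
  std::string owned = "loop.body.suffix";
  BorrowedRef ref = AsBorrowedRef(owned);  // Valid only while `owned` is alive.
  std::cout << std::string(ref.data, ref.size) << "\n";
  return 0;
}
```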
+// If `optimized` is true then a suffix of "-with-opt.ll" is used, else a suffix +// of "-no-opt.ll" is used. +void DumpIrIfEnabled(const HloModule& hlo_module, + const llvm::Module& llvm_module, bool optimized); -llvm::Function* CreateFunction(llvm::FunctionType* function_type, - llvm::GlobalValue::LinkageTypes linkage, - bool enable_fast_math, bool optimize_for_size, - absl::string_view name, llvm::Module* module); +llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + const HloModuleConfig& module_config, + absl::string_view name, llvm::Module* module); // Extracts the xla_backend_extra_options from `config` and passes those that // don't start with xla_ to LLVM. diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index a689881e65e..83be4334269 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -47,14 +47,14 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, shape_(target_array.GetShape()), b_(b) {} -static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( +static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutput( const ElementGenerator& target_element_generator, const std::vector& target_arrays, llvm::IRBuilder<>* b) { return [=](const llvm_ir::IrArray::Index array_index) { TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); CHECK(target_element->getType()->isStructTy()) - << "This BodyEmitter is for multi-output fusion, but target element " + << "This BodyEmitter is for multi-output, but target element " "generator does not produce values of struct type."; CHECK_EQ(target_element->getType()->getStructNumElements(), target_arrays.size()); @@ -70,7 +70,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, absl::Span target_arrays, llvm::IRBuilder<>* b) - : body_emitter_(MakeBodyEmitterForMultiOutputFusion( + : body_emitter_(MakeBodyEmitterForMultiOutput( target_element_generator, std::vector(target_arrays.begin(), target_arrays.end()), b)), shape_(target_arrays[0].GetShape()), @@ -98,15 +98,16 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). ForLoopNest loop_nest(loop_name, b_); - IrArray::Index array_index(index_type, shape_.dimensions_size()); + std::vector array_multi_index(shape_.dimensions_size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), /*suffix=*/absl::StrFormat("dim.%d", dimension)); - array_index[dimension] = loop->GetIndVarValue(); + array_multi_index[dimension] = loop->GetIndVarValue(); } + IrArray::Index array_index(array_multi_index, shape_, index_type); // Set IR builder insertion point to the loop body basic block of the // innermost loop. diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index d71addec9b7..f2f4f306941 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -30,6 +30,7 @@ limitations under the License. 
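The `loop_emitter.cc` hunk above fills a plain `std::vector` multi-index while the loop nest is built and only then constructs the `IrArray::Index` from it. A host-side sketch of that indexing pattern for a rank-2 shape (plain integers stand in for the emitted induction variables):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> dims = {2, 3};            // Logical shape [2, 3].
  const std::vector<int64_t> minor_to_major = {1, 0};  // Row-major layout.

  const int64_t major_dim = minor_to_major[1];  // Iterated by the outer loop.
  const int64_t minor_dim = minor_to_major[0];  // Iterated by the inner loop.

  std::vector<int64_t> multi_index(dims.size(), 0);
  for (int64_t a = 0; a < dims[major_dim]; ++a) {
    for (int64_t b = 0; b < dims[minor_dim]; ++b) {
      // Fill the multi-index per dimension; the "real" index object would be
      // constructed once from this vector after the nest is complete.
      multi_index[major_dim] = a;
      multi_index[minor_dim] = b;
      std::printf("index = {%lld, %lld}\n",
                  static_cast<long long>(multi_index[0]),
                  static_cast<long long>(multi_index[1]));
    }
  }
  return 0;
}
```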
#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" @@ -149,8 +150,8 @@ Status EmitTiledCompareLoop( const EmitCallToNestedComputationCallback& emit_compare_callback, llvm::IRBuilder<>* b) { KernelSupportLibrary ksl(b); - llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); + llvm::Value* thread_id = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b); llvm_ir::AddRangeMetadata(0, tile_size / 2, llvm::cast(thread_id)); thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(), @@ -187,10 +188,12 @@ Status EmitTiledCompareLoop( }; // Copy operand tiles from the operand buffers to shared memory. - IrArray::Index keys_index = tiled_keys_index; + std::vector keys_multi_index = tiled_keys_index.multidim(); for (int64 i = 0; i < params.size(); ++i) { copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { - keys_index[dimension_to_sort] = index; + keys_multi_index[dimension_to_sort] = index; + IrArray::Index keys_index(keys_multi_index, params[i].GetShape(), + tiled_keys_index.GetType()); auto value = params[i].EmitReadArrayElement(keys_index, b); b->CreateStore(value, b->CreateGEP(param_shmem_buffers[i], @@ -199,7 +202,7 @@ Status EmitTiledCompareLoop( }); } // Wait until all reads have happened. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); + gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {}, b); // Now emit the bodies of the comparison loops. auto element_address = [&](int64 operand, llvm::Value* index) { @@ -260,13 +263,16 @@ Status EmitTiledCompareLoop( /*needs_bounds_checks=*/false)); } // Wait until all comparisons have happened. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); + gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {}, + b); } // Copy the operand tiles back from shared memory to the operand buffers. 
for (int64 i = 0; i < params.size(); ++i) { copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { - keys_index[dimension_to_sort] = index; + keys_multi_index[dimension_to_sort] = index; + IrArray::Index keys_index(keys_multi_index, params[i].GetShape(), + tiled_keys_index.GetType()); auto value = b->CreateLoad(b->CreateGEP( param_shmem_buffers[i], {tiled_keys_index.GetConstantWithIndexType(0), cache_index})); @@ -348,23 +354,31 @@ Status EmitSortInPlace( // // This follows the algorithm described on Wikipedia: // https://en.wikipedia.org/wiki/Bitonic_sorter - IrArray::Index keys_index(tiles_index.GetType(), rank); + std::vector keys_multi_index(rank); for (int64 i = 0; i < rank; ++i) { - keys_index[iteration_order_to_logical_order[i]] = tiles_index[i]; + keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i]; } if (xor_masks.size() > 1) { + IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(), + tiles_index.GetType()); TF_RETURN_IF_ERROR(EmitTiledCompareLoop( keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks, values_arrays, param_shmem_buffers, tile_size, emit_compare_callback, b)); } else { auto element_address = [&](int64 operand, llvm::Value* index) { - keys_index[dimension_to_sort] = index; + keys_multi_index[dimension_to_sort] = index; + IrArray::Index keys_index(keys_multi_index, + values_arrays[operand].GetShape(), + tiles_index.GetType()); return values_arrays[operand].EmitArrayElementAddress(keys_index, b); }; auto write_element = [&](int64 operand, llvm::Value* index, llvm::Value* value) { - keys_index[dimension_to_sort] = index; + keys_multi_index[dimension_to_sort] = index; + IrArray::Index keys_index(keys_multi_index, + values_arrays[operand].GetShape(), + tiles_index.GetType()); values_arrays[operand].EmitWriteArrayElement(keys_index, value, b); }; TF_RETURN_IF_ERROR(EmitCompareLoopBody( diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index d8d2700e193..93dc66f9ac1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -29,9 +29,14 @@ limitations under the License. 
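`EmitSortInPlace` above generates a bitonic sorting network: each pass compares element `i` with element `i ^ mask`, and the direction of the compare-exchange depends on a bit of `i`. A host-side sketch of the same network on a `std::vector` (no emitted IR; the length must be a power of two):

```cpp
#include <cstdio>
#include <utility>
#include <vector>

void BitonicSort(std::vector<int>& a) {
  const size_t n = a.size();
  for (size_t k = 2; k <= n; k *= 2) {                  // Length of bitonic runs.
    for (size_t mask = k / 2; mask > 0; mask /= 2) {    // The "xor mask" passes.
      for (size_t i = 0; i < n; ++i) {
        const size_t partner = i ^ mask;
        if (partner <= i) continue;  // Handle each pair exactly once.
        const bool ascending = (i & k) == 0;
        if ((ascending && a[i] > a[partner]) ||
            (!ascending && a[i] < a[partner])) {
          std::swap(a[i], a[partner]);
        }
      }
    }
  }
}

int main() {
  std::vector<int> v = {7, 3, 1, 8, 5, 2, 6, 4};
  BitonicSort(v);
  for (int x : v) std::printf("%d ", x);  // Prints: 1 2 3 4 5 6 7 8
  std::printf("\n");
  return 0;
}
```

Because every pass touches fixed index pairs regardless of the data, the same schedule can be emitted once and executed by many GPU threads in lockstep, which is why the emitter above only needs the xor masks and a barrier between passes.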
namespace xla { namespace llvm_ir { +static llvm::Module* getModuleFromBuilder(llvm::IRBuilder<>* b) { + return b->GetInsertBlock()->getModule(); +} + void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::IRBuilder<>* b) { + llvm::Module* module = getModuleFromBuilder(b); CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = @@ -45,27 +50,17 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, VLOG(2) << " pred_value: " << DumpToString(*pred_value); VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); - for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - llvm::Value* const element_index[] = {b->getInt64(0), b->getInt64(i)}; - llvm::Value* on_true_element_address = - b->CreateInBoundsGEP(on_true, element_index); - llvm::Value* on_true_element = b->CreateLoad( - on_true_element_address, "on_true_element_" + llvm::Twine(i)); - llvm::Value* on_false_element_address = - b->CreateInBoundsGEP(on_false, element_index); - llvm::Value* on_false_element = b->CreateLoad( - on_false_element_address, "on_false_element_" + llvm::Twine(i)); - - llvm::Value* output_element_address = - b->CreateInBoundsGEP(select.GetBasePointer(), element_index); - b->CreateStore(b->CreateSelect(pred_cond, on_true_element, on_false_element, - "select_output_element_" + llvm::Twine(i)), - output_element_address); - } + llvm::Value* src = b->CreateSelect(pred_cond, on_true, on_false); + llvm::Value* dst = select.GetBasePointer(); + int64 table_size = ShapeUtil::ByteSizeOfTupleIndexTable( + select.GetShape(), module->getDataLayout().getPointerSize()); + b->CreateMemCpy(dst, /*DstAlign=*/1, src, /*SrcAlign=*/1, + b->getInt64(table_size)); } void EmitTuple(const IrArray& tuple, absl::Span operands, - llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::IRBuilder<>* b) { + llvm::Module* module = getModuleFromBuilder(b); for (size_t i = 0; i < operands.size(); ++i) { auto* store = b->CreateStore( b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)), @@ -76,18 +71,45 @@ void EmitTuple(const IrArray& tuple, absl::Span operands, } void EmitTuple(const IrArray& tuple, absl::Span buffers, - llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::IRBuilder<>* b) { std::vector buffer_ptrs; buffer_ptrs.reserve(buffers.size()); absl::c_transform( buffers, std::back_inserter(buffer_ptrs), [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); }); - llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module); + llvm_ir::EmitTuple(tuple, buffer_ptrs, b); +} + +std::vector EmitTupleAllocasAtFunctionEntry( + const Shape& tuple_shape, llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getModule(); + + llvm::IRBuilder<>::InsertPointGuard guard(*b); + llvm::Function* function = b->GetInsertBlock()->getParent(); + b->SetInsertPoint(&function->getEntryBlock(), + function->getEntryBlock().getFirstInsertionPt()); + CHECK(tuple_shape.IsTuple()); + int tuple_size = tuple_shape.tuple_shapes_size(); + + std::vector generated_allocas; + for (int i = 0; i < tuple_size; i++) { + const Shape& element_shape = tuple_shape.tuple_shapes(i); + CHECK(ShapeUtil::IsScalar(element_shape)); + llvm::Type* type = + llvm_ir::PrimitiveTypeToIrType(element_shape.element_type(), module); + llvm::AllocaInst* alloca = b->CreateAlloca( + type, + /*ArraySize=*/nullptr, AsStringRef(absl::StrCat("tuple_element_", i))); + generated_allocas.push_back(alloca); + } + + 
return generated_allocas; } llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::IRBuilder<>* b) { + llvm::Module* module = getModuleFromBuilder(b); llvm::Value* element_ptr = b->CreateInBoundsGEP(operand, {b->getInt64(0), b->getInt64(index)}); llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr); diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index 94340b91d8e..1e173801139 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -61,17 +61,23 @@ namespace llvm_ir { // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* b, llvm::Module* module); + llvm::IRBuilder<>* b); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. void EmitTuple(const IrArray& tuple, absl::Span operands, - llvm::IRBuilder<>* b, llvm::Module* module); + llvm::IRBuilder<>* b); + +// Emits one alloca for each element in the tuple of shape tuple_shape, +// returns the emitted allocas. +// Precondition: tuple_shape should be a tuple of scalars. +std::vector EmitTupleAllocasAtFunctionEntry( + const Shape& tuple_shape, llvm::IRBuilder<>* b); // Similar to EmitTuple above, except that the output buffers are provided in // the form of IrArray. void EmitTuple(const IrArray& tuple, absl::Span buffers, - llvm::IRBuilder<>* b, llvm::Module* module); + llvm::IRBuilder<>* b); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. A GetTupleElement instruction @@ -79,7 +85,7 @@ void EmitTuple(const IrArray& tuple, absl::Span buffers, // Returns an llvm value representing a pointer to the tuple element buffer. llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* b, llvm::Module* module); + llvm::IRBuilder<>* b); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index f56ba32b04b..170d226e336 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -23,13 +23,13 @@ limitations under the License. 
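The `EmitTupleSelect` rewrite in the hunk above exploits the fact that a tuple buffer is just a table of pointers to element buffers: instead of selecting and storing each element pointer in a loop, it selects one of the two tables and copies it wholesale. A host-side sketch of that idea (plain arrays stand in for the emitted buffers):

```cpp
#include <cstdio>
#include <cstring>

int main() {
  float a0 = 1.0f, a1 = 2.0f;    // Elements of tuple "on_true".
  float b0 = 10.0f, b1 = 20.0f;  // Elements of tuple "on_false".

  void* on_true[2] = {&a0, &a1};    // Tuple = table of element pointers.
  void* on_false[2] = {&b0, &b1};
  void* selected[2] = {nullptr, nullptr};

  const bool pred = false;
  void** src = pred ? on_true : on_false;        // One select on the table.
  std::memcpy(selected, src, sizeof(selected));  // Copy the whole pointer table.

  std::printf("element 0 = %f\n", *static_cast<float*>(selected[0]));  // 10.0
  std::printf("element 1 = %f\n", *static_cast<float*>(selected[1]));  // 20.0
  return 0;
}
```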
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc index 8269842426e..5fe5fea71ac 100644 --- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc @@ -17,25 +17,24 @@ limitations under the License. #include "absl/types/variant.h" namespace xla { -se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() { +tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() { if (HasOwnership()) { - return absl::get(mem_).AsDeviceMemoryBase(); + return *absl::get(mem_); } else { - return absl::get(mem_); + return absl::get(mem_); } } bool MaybeOwningDeviceMemory::HasOwnership() const { - return absl::holds_alternative(mem_); + return absl::holds_alternative(mem_); } -absl::optional MaybeOwningDeviceMemory::Release() { +absl::optional +MaybeOwningDeviceMemory::Release() { if (!HasOwnership()) { return {}; } - OwningDeviceMemory result = std::move(absl::get(mem_)); - mem_ = result.AsDeviceMemoryBase(); - return absl::make_optional(std::move(result)); + return std::move(absl::get(mem_)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h index 82e7f1183c0..8edd64cf681 100644 --- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h @@ -18,30 +18,29 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/variant.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/owning_device_memory.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { // MaybeOwningDeviceMemory represents either an owned or unowned device memory. -// Like std::variant. When the object goes +// Like std::variant. When the object goes // output of scope, it will free the underlying memory if it owns it. 
class MaybeOwningDeviceMemory { public: MaybeOwningDeviceMemory() = default; - explicit MaybeOwningDeviceMemory(OwningDeviceMemory owned) + explicit MaybeOwningDeviceMemory(tensorflow::se::OwningDeviceMemory owned) : mem_(std::move(owned)) {} - explicit MaybeOwningDeviceMemory(se::DeviceMemoryBase unowned) + explicit MaybeOwningDeviceMemory(tensorflow::se::DeviceMemoryBase unowned) : mem_(unowned) {} MaybeOwningDeviceMemory(MaybeOwningDeviceMemory&&) = default; ~MaybeOwningDeviceMemory() = default; - MaybeOwningDeviceMemory& operator=(se::DeviceMemoryBase unowned) { + MaybeOwningDeviceMemory& operator=(tensorflow::se::DeviceMemoryBase unowned) { mem_ = unowned; return *this; } - MaybeOwningDeviceMemory& operator=(OwningDeviceMemory owned) { + MaybeOwningDeviceMemory& operator=(tensorflow::se::OwningDeviceMemory owned) { mem_ = std::move(owned); return *this; } @@ -50,19 +49,21 @@ class MaybeOwningDeviceMemory { // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The // caller of this function is *not* responsible for freeing the memory. - se::DeviceMemoryBase AsDeviceMemoryBase(); + tensorflow::se::DeviceMemoryBase AsDeviceMemoryBase(); - // Release the OwningDeviceMemory without freeing it, and moves the ownership - // of the memory buffer from the object to the caller. + // Release the tensorflow::se::OwningDeviceMemory without freeing it, and + // moves the ownership of the memory buffer from the object to the caller. // // A nullopt is returned if the HasOwnership() == false; - absl::optional Release(); + absl::optional Release(); // Returns true if the device_memory has ownership over underlying memory. bool HasOwnership() const; private: - absl::variant mem_; + absl::variant + mem_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index e55b83d17e9..70742b67a28 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -62,6 +63,14 @@ NameUniquer::NameUniquer(const string& separator) { if (primitive_util::IsPrimitiveTypeName(result) && result != "tuple") { result += "_"; } + + if (absl::StartsWith(result, "__") && !absl::StartsWith(result, "__xla_")) { + // Morph name prefix __ that is not __xla_, to avoid using name prefixes + // reserved by the backends, such as __llvm_retpoline_ reserved by the LLVM + // x86 backend. + result[0] = 'a'; + } + return result; } diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index d0d04147e0c..1007c2aeae8 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -111,6 +111,12 @@ TEST_F(NameUniquerTest, AvoidKeywords) { EXPECT_EQ("s64_", uniquer.GetUniqueName("s64")); EXPECT_EQ("pred_", uniquer.GetUniqueName("pred")); + // Name prefix __xla_ is preserved. + EXPECT_NE(uniquer.GetUniqueName("__xla_").find("__xla_"), std::string::npos); + // Other form of __ prefixes is not preserved to avoid using name prefixes + // reserved by backends. 
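The `NameUniquer` change above rewrites names that start with `__` (but not `__xla_`) so generated symbols cannot collide with prefixes reserved by backends such as `__llvm_retpoline_`. A standalone sketch of just that sanitization rule (the function name is illustrative):

```cpp
#include <iostream>
#include <string>

std::string MorphReservedPrefix(std::string name) {
  const bool double_underscore = name.rfind("__", 0) == 0;     // starts_with "__"
  const bool xla_prefix = name.rfind("__xla_", 0) == 0;        // starts_with "__xla_"
  if (double_underscore && !xla_prefix) {
    name[0] = 'a';  // "__abx" becomes "a_abx", removing the reserved "__" prefix.
  }
  return name;
}

int main() {
  std::cout << MorphReservedPrefix("__abx") << "\n";      // a_abx
  std::cout << MorphReservedPrefix("__xla_foo") << "\n";  // unchanged
  std::cout << MorphReservedPrefix("loop_body") << "\n";  // unchanged
  return 0;
}
```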
+ EXPECT_EQ(uniquer.GetUniqueName("__abx").find("__"), std::string::npos); + // Though a primitive type, "tuple" is not a keyword. EXPECT_EQ("tuple", uniquer.GetUniqueName("tuple")); diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc index 701c629add5..c1d401613d7 100644 --- a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" @@ -45,44 +46,42 @@ StatusOr OptimizeInputOutputBufferAlias::Build( VLOG(1) << "input_shape:" << input_shape.ToString(); VLOG(1) << "output_shape:" << output_shape.ToString(); - // For all buffers defined by the parameter, build a map from the byte - // size to the list of the buffers of that size. - absl::flat_hash_map> size_to_input_index; + // Tracks all buffers defined by the parameter in a flatten list. + struct Entry { + Shape shape; + ShapeIndex index; + bool used; + }; + std::vector parameter_entries; ShapeUtil::ForEachSubshape( input_shape, [&](const Shape& subshape, const ShapeIndex& index) { if (subshape.IsTuple()) { return; } - int64 bytes = size_func_(subshape); - size_to_input_index[bytes].push(index); + parameter_entries.emplace_back(Entry{subshape, index, false}); }); // For each result buffer shape index, take the first unused parameter - // buffer that matches the size. + // buffer that matches the shape. TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( output_shape, [&](const Shape& subshape, const ShapeIndex& index) { if (subshape.IsTuple()) { return Status::OK(); } - int64 bytes = size_func_(subshape); - - auto it = size_to_input_index.find(bytes); - if (it != size_to_input_index.end() && !it->second.empty()) { - changed = true; - const ShapeIndex& input_index = it->second.front(); - const ShapeIndex& output_index = index; - if (!alias_config->ParameterHasAlias(0, input_index) && - !alias_config->OutputHasAlias(output_index)) { - TF_RETURN_IF_ERROR(alias_config->SetUpAlias( - output_index, 0, input_index, - HloInputOutputAliasConfig::AliasKind::kSystemAlias)); + for (Entry& entry : parameter_entries) { + if (Shape::Equal()(entry.shape, subshape) && !entry.used) { + changed = true; + const ShapeIndex& input_index = entry.index; + const ShapeIndex& output_index = index; + if (!alias_config->ParameterHasAlias(0, input_index) && + !alias_config->OutputHasAlias(output_index)) { + TF_RETURN_IF_ERROR(alias_config->SetUpAlias( + output_index, 0, input_index, + HloInputOutputAliasConfig::AliasKind::kSystemAlias)); + } + entry.used = true; + break; } - VLOG(3) << "Set up alias from with param index " - << it->second.front().ToString() << ", shape size " << bytes - << " and result subshape " - << ShapeUtil::HumanStringWithLayout(subshape) << " at index " - << index.ToString(); - it->second.pop(); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h index 79ce468e975..90c35251ea9 100644 --- a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h @@ -36,20 +36,17 @@ namespace xla { // aliased, and writes the alias config into the HloModule. 
// // The input and the output buffers can be in any shape, and each output buffer -// can alias with an input buffer with the same size. Each input buffer may only -// alias with a single output buffer. For example, for the following parameter -// and the output buffers, +// can alias with an input buffer with the same shape. Each input buffer may +// only alias with a single output buffer. For example, for the following +// parameter and the output buffers, // -// Parameters : { P1(2MiB), P2(4MiB), P3(8MiB), P4(4MiB), P5(4MiB), ... } -// Outputs : { O1(4MiB), O2(2MiB), O3(4MiB), O4(6MiB), O5(4MiB), ... } +// Parameters : { P1(f32[3]), P2(s32[3]), P3(f32[3,12]), P4(f32[16,12]), ... } +// Outputs : { O1(s32[3]), O2(f32[3]), O3(f32[16,12]), ... } // -// one potential aliasing would be (O1, P2), (O2, P1), (O3, P4), (O5, P5), .. +// one potential aliasing would be (O1, P2), (O2, P1), (O3, P4), .. class OptimizeInputOutputBufferAlias : public HloModulePass { - using ShapeSizeFunction = std::function; - public: - OptimizeInputOutputBufferAlias(ShapeSizeFunction size_func) - : size_func_(size_func) {} + OptimizeInputOutputBufferAlias() = default; ~OptimizeInputOutputBufferAlias() override = default; absl::string_view name() const override { @@ -63,7 +60,6 @@ class OptimizeInputOutputBufferAlias : public HloModulePass { StatusOr Build(const Shape& input_shape, const Shape& output_shape, HloInputOutputAliasConfig* alias_config); - ShapeSizeFunction size_func_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc index 41e90f9b693..214ee663ac6 100644 --- a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc @@ -37,12 +37,7 @@ class OptimizeInputOutputBufferAliasTest : public HloTestBase { r3f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6}); r4f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - auto size_func = [](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape); - }; - - optimize_pass_ = - absl::make_unique(size_func); + optimize_pass_ = absl::make_unique(); } // Returns the number of output indices that aliases with the input. diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h deleted file mode 100644 index 9cf071f0d9d..00000000000 --- a/tensorflow/compiler/xla/service/owning_device_memory.h +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
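The `OptimizeInputOutputBufferAlias` change above switches from matching output buffers to parameter buffers by byte size to matching by shape equality, taking the first unused parameter for each output. A standalone sketch of that greedy matching, using plain strings as stand-in shapes and the parameter/output example from the header comment:

```cpp
#include <cstdio>
#include <string>
#include <vector>

struct Entry {
  std::string shape;
  int param_index;
  bool used = false;
};

int main() {
  std::vector<Entry> params = {
      {"f32[3]", 0}, {"s32[3]", 1}, {"f32[3,12]", 2}, {"f32[16,12]", 3}};
  const std::vector<std::string> outputs = {"s32[3]", "f32[3]", "f32[16,12]"};

  for (size_t out = 0; out < outputs.size(); ++out) {
    for (Entry& entry : params) {
      if (!entry.used && entry.shape == outputs[out]) {
        std::printf("alias output %zu <-> parameter %d (%s)\n", out,
                    entry.param_index, entry.shape.c_str());
        entry.used = true;  // Each parameter buffer may alias only one output.
        break;
      }
    }
  }
  return 0;
}
```

Running this reproduces the pairing described in the header comment: output 0 aliases P2, output 1 aliases P1, and output 2 aliases P4.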
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ - -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -namespace xla { - -// Break circular dependency between this file and device_memory_allocator.h. -class DeviceMemoryAllocator; - -// Owning pointer for memory on a device. -// -// OwningDeviceMemory is an owning pointer like std::unique_ptr, but it can -// point to memory that resides on a "device" (e.g. a GPU). When an -// OwningDeviceMemory goes out of scope, it frees the memory it owns. -// -// We say that an instance of OwningDeviceMemory is "active" if it currently -// owns a (possibly empty) slice of memory on the device. Moving, Forget()'ing, -// Free()'ing, and other actions can deactive an active object. -// -// Note that we can't simply use stream_executor::ScopedDeviceMemory instead of -// OwningDeviceMemory, because ScopedDeviceMemory frees its pointer via a -// StreamExecutor. This class needs to free via a xla::DeviceMemoryAllocator. -class OwningDeviceMemory { - public: - OwningDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {} - - explicit OwningDeviceMemory(se::DeviceMemoryBase mem, int device_ordinal, - DeviceMemoryAllocator* allocator) - : mem_(mem), device_ordinal_(device_ordinal), allocator_(allocator) { - CHECK(allocator != nullptr) << "allocator cannot be null."; - } - - OwningDeviceMemory(OwningDeviceMemory&& other) - : mem_(other.mem_), - device_ordinal_(other.device_ordinal_), - allocator_(other.allocator_) { - other.mem_ = se::DeviceMemoryBase(); - other.allocator_ = nullptr; - } - - OwningDeviceMemory& operator=(OwningDeviceMemory&& other) { - if (allocator_ != nullptr) { - Free(); - } - mem_ = other.mem_; - device_ordinal_ = other.device_ordinal_; - allocator_ = other.allocator_; - - other.mem_ = se::DeviceMemoryBase(); - other.allocator_ = nullptr; - return *this; - } - - // Deactivates this instance if it's active. Nop if it's not active. - OwningDeviceMemory& operator=(std::nullptr_t) { - if (allocator_ != nullptr) { - Free(); - } - return *this; - } - - ~OwningDeviceMemory() { - if (allocator_ != nullptr) { - Free(); - } - } - - // The returned allocator is nonnull iff this object is active. - DeviceMemoryAllocator* allocator() const { return allocator_; } - - int device_ordinal() const { return device_ordinal_; } - - // Gets the device memory pointer. - const void* opaque() const { return mem_.opaque(); } - void* opaque() { return mem_.opaque(); } - - uint64 size() const { return mem_.size(); } - - // Determines whether this wraps a null pointer. - // - // !is_null() is sufficient but not necessary to imply `this` is active. - bool is_null() const { return mem_.is_null(); } - - se::DeviceMemoryBase AsDeviceMemoryBase() { - return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false); - } - - // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates - // this object. Precondition: `this` is active. - TF_MUST_USE_RESULT se::DeviceMemoryBase Forget() { - CHECK(allocator_ != nullptr) - << "Can't call Forget() on an inactive (i.e. 
moved from, Forget()'ten, " - "or Free()'ed) instance."; - allocator_ = nullptr; - se::DeviceMemoryBase mem(mem_); - mem_ = se::DeviceMemoryBase(); - return mem; - } - - // Frees the wrapped DeviceMemoryBase and deactivates this object. - // Precondition: `this` is active. - void Free(); - - private: - se::DeviceMemoryBase mem_; - int device_ordinal_; - DeviceMemoryAllocator* allocator_; // Null if this object is inactive. -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 7164bfc4cd4..ae1df60d350 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -67,6 +67,7 @@ namespace xla { // - WithOneUse: Instruction is used as an operand exactly once. // - WithOneUser: Instruction is used by exactly one other instruction, but // is possibly used more than once as an operand (e.g. multiply(x,x)). +// - WithComparisonDirection: instr has the given direction // // Shape(): // - EqualTo @@ -1671,6 +1672,40 @@ class HloInstructionPatternOneUserImpl } }; +class HloInstructionPatternComparisonDirectionImpl { + public: + explicit constexpr HloInstructionPatternComparisonDirectionImpl( + ComparisonDirection direction) + : direction_(direction) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has comparison direction " + << ComparisonDirectionToString(direction_); + } + + private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kCompare || + inst->comparison_direction() != direction_) { + EXPLAIN << "HloInstruction is not comparison " + << ComparisonDirectionToString(direction_); + return false; + } + return true; + } + + ComparisonDirection direction_; +}; + // Matches a constant scalar or effective scalar, optionally with a given value. template class HloConstantScalarImpl { @@ -1956,6 +1991,14 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOneUserImpl()); } + // Modifies the pattern to match only if the instruction has the given + // comparison direction. 
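The pattern-matcher addition above follows the library's usual shape for a constraint "impl": it stores one piece of state (the comparison direction), answers `Match`, and can describe itself for error messages. A toy, self-contained version of that building block (the `Instr` struct is a stand-in, not `HloInstruction`):

```cpp
#include <iostream>
#include <string>

enum class Direction { kEq, kNe, kLt, kLe, kGt, kGe };

struct Instr {
  std::string opcode;
  Direction direction;
};

class ComparisonDirectionMatcher {
 public:
  explicit ComparisonDirectionMatcher(Direction d) : direction_(d) {}

  // Succeeds only for compare instructions with the stored direction.
  bool Match(const Instr& instr) const {
    return instr.opcode == "compare" && instr.direction == direction_;
  }

  // Used when composing human-readable explanations of why a match failed.
  void DescribeTo(std::ostream* os) const {
    *os << "which has comparison direction "
        << (direction_ == Direction::kEq ? "EQ" : "non-EQ");
  }

 private:
  Direction direction_;
};

int main() {
  const Instr eq{"compare", Direction::kEq};
  const Instr add{"add", Direction::kEq};
  ComparisonDirectionMatcher m(Direction::kEq);
  std::cout << m.Match(eq) << " " << m.Match(add) << "\n";  // 1 0
  m.DescribeTo(&std::cout);
  std::cout << "\n";
  return 0;
}
```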
+ auto WithComparisonDirection(ComparisonDirection direction) const + -> decltype(this->AppendImpl( + HloInstructionPatternComparisonDirectionImpl(direction))) { + return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction)); + } + void DescribeTo(std::ostream* os, int64 indent = 0) const { impl_.DescribeTo(os, indent); } @@ -2118,18 +2161,13 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Compare) XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) -XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) -XLA_BINOP_PATTERN(Ge) -XLA_BINOP_PATTERN(Gt) -XLA_BINOP_PATTERN(Le) -XLA_BINOP_PATTERN(Lt) XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) -XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Pad) XLA_BINOP_PATTERN(Power) @@ -2242,6 +2280,73 @@ XLA_VARIADIC_OP_PATTERN(Reduce); XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); +// Helpers for comparison instructions. +#define XLA_COMPARE_PATTERN(NAME) \ + inline auto NAME()->decltype( \ + Op().WithOpcode(HloOpcode::kCompare) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op().WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } + +#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \ + XLA_COMPARE_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs))) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ + } + +XLA_COMMUTATIVE_COMPARE_PATTERN(Eq); +XLA_COMMUTATIVE_COMPARE_PATTERN(Ne); +XLA_COMPARE_PATTERN(Ge); +XLA_COMPARE_PATTERN(Gt); +XLA_COMPARE_PATTERN(Le); +XLA_COMPARE_PATTERN(Lt); + // Helpers for matching non-constant instructions. 
inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 5c3c009a68b..cbe8c4a2410 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -931,5 +931,48 @@ TEST(PatternMatcherTest, OneUseAndOneUser) { "in p0 = f32[] parameter(0)"); } +TEST(HloMatchersTest, Comparison) { + auto shape = ShapeUtil::MakeShape(F32, {1}); + auto p0 = HloInstruction::CreateParameter(0, shape, "param.0"); + auto p1 = HloInstruction::CreateParameter(1, shape, "param.1"); + auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kEq); + auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(), + ComparisonDirection::kNe); + auto add = + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get()); + auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(), + ComparisonDirection::kLe); + + EXPECT_TRUE(Match(eq.get(), m::Compare())); + EXPECT_TRUE(Match(eq.get(), m::Eq())); + EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1)))); + EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0)))); + EXPECT_TRUE(Match(ne.get(), m::Compare())); + EXPECT_TRUE(Match(ne.get(), m::Ne())); + EXPECT_TRUE(Match( + le.get(), + m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1))))); + EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0), + m::Add(m::Parameter(0), m::Parameter(1))))); + + EXPECT_FALSE(Match(eq.get(), m::Add())); + EXPECT_FALSE(Match(eq.get(), m::Ne())); + EXPECT_FALSE( + Match(le.get(), + m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1))))); + EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0)))); + EXPECT_DESC_AND_EXPLANATION( + eq, m::Ne().WithOneUser(), + "an HloInstruction:\n" + " * with opcode compare AND\n" + " * which has comparison direction NE AND\n" + " * which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction is not comparison NE\n" + "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), " + "direction=EQ"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index b70cb705747..98964cc2006 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -121,9 +121,7 @@ StatusOr ReducePrecisionInsertion::insert_on_inputs( continue; } - if (instruction->opcode() == HloOpcode::kFusion && - (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop || - instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) { + if (instruction->IsInputFusion() || instruction->IsLoopFusion()) { // Insert the reduce-precision operation inside the fusion computation, // after the corresponding parameter instruction. 
TF_ASSIGN_OR_RETURN( @@ -172,9 +170,7 @@ StatusOr ReducePrecisionInsertion::insert_on_outputs( continue; } - if (instruction->opcode() == HloOpcode::kFusion && - (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop || - instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) { + if (instruction->IsLoopFusion() || instruction->IsOutputFusion()) { // Insert the reduce-precision operation as the last operation inside // the fusion computation. HloInstruction* fusion_root = instruction->fused_expression_root(); diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index a62118df157..9e2d7406940 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -29,11 +29,6 @@ limitations under the License. // // Where the instruction must be elementwise, and both reshapes and transposes // are moved. -// -// Most elementwise instructions support implicit broadcast of scalar operands, -// but select is a special-case. The signature is Select(Pred, A, B), and the -// only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or -// transposes to a scalar should be cheap, we simply never move them. #include "tensorflow/compiler/xla/service/reshape_mover.h" @@ -64,20 +59,14 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { // // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially // reshapable if *all* instructions in the chain have user_count == 1. And - // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar; we - // rely on implicit scalar broadcast for scalars to be trivial. In addition, - // these cases make it harder to maintain correctness of the UpdateOperand - // logic below. + // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar. + // In addition, these cases make it harder to maintain correctness of the + // UpdateOperand logic below. // // So don't handle these chains, unless you update the tests and code to deal // with these properly. One idea is to add a pass immediately beforehand that // collapses trivial runs of reshapes / transposes. - // Scalars can operate with any shape. - if (ShapeUtil::IsScalar(instruction->shape())) { - return true; - } - // A constant can trivially reshape the literal it holds. if (instruction->opcode() == HloOpcode::kConstant) { return true; @@ -143,8 +132,8 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { // This function is called once we've decided to sink reshape/transpose operands // across an instruction. It returns an updated `operand` with a shape that -// plays nicely with `new_operand_shape`; either it has the same shape (of the -// correct type), or it is a scalar that may be implicitly broadcast. +// plays nicely with `new_operand_shape`; it has the same shape (of the +// correct type). HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, const Shape& new_operand_shape, HloInstruction* operand) { @@ -221,9 +210,8 @@ StatusOr PerformSinkReshapeOrTranspose( UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { - // Here we already know `instruction` is elementwise, and no operand is - // implicit broadcast as if it were the operands would not have easy shape - // changes, so all the fused instructions have the same dimensions. 
+ // Here we already know `instruction` is elementwise, and all the fused + // instructions have the same dimensions. for (const auto& fused_instruction : instruction->fused_instructions()) { Shape* shape = fused_instruction->mutable_shape(); shape->clear_dimensions(); @@ -287,21 +275,17 @@ bool IsReshapeMoveCandidate(HloInstruction* instruction) { } // Check whether all operands: - // 0. Have the same dimensions as the output -- if not, they may be - // implicitly broadcast, which can confound the movement's - // correctness. + // 0. Have the same dimensions as the output. // // And one of the following: // 1. Are reshapes or transposes that have the same input and // output shapes as all other reshaped or transposed operands. // or - // 2. Are one of kConstant, kRng, broadcast of a scalar value, and scalars - // that can change shape trivially. + // 2. Are one of kConstant, kRng, broadcast of a scalar value. const HloInstruction* first_reshape_operand = nullptr; for (const HloInstruction* operand : instruction->operands()) { if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " + VLOG(5) << "Operand shape differs from output shape; so preventing " "movement\n\toperand: " << operand->ToString(print_no_metadata) << "\n\tinstruction: " << instruction->ToString(print_no_metadata); diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index acad871c4d4..e3a3feb8640 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -134,6 +134,13 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( int64 operand_rank) { HloComputation* computation = index_vector->parent(); const Shape& index_shape = index_vector->shape(); + + // Scatter of a scalar. Return a zero-sized vector of indices. + if (operand_rank == 0) { + return computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); + } + HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); @@ -174,8 +181,9 @@ static StatusOr CheckIndexValidity( HloInstruction* zero_index = BroadcastZeros(computation, index->shape().element_type(), AsInt64Slice(index->shape().dimensions())); - TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, - MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); + TF_ASSIGN_OR_RETURN( + HloInstruction * negative_index_check, + MakeCompareHlo(ComparisonDirection::kLe, zero_index, index)); // Check if the index is OOB w.r.t. the operand dimensions and window sizes. std::vector max_valid_index(operand_dims.size()); @@ -186,9 +194,9 @@ static StatusOr CheckIndexValidity( HloInstruction * max_valid_index_constant, MakeR1ConstantHlo(computation, index->shape().element_type(), max_valid_index)); - TF_ASSIGN_OR_RETURN( - HloInstruction * oob_index_check, - MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index)); + TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check, + MakeCompareHlo(ComparisonDirection::kGe, + max_valid_index_constant, index)); // Combine the results of the two checks above. 
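The scatter-expander hunk above builds the index validity test as two emitted compares, `0 <= index` and `index <= operand_dims - window_sizes`, and then ANDs the results. A scalar host-side sketch of that predicate (ordinary integers instead of emitted HLO values):

```cpp
#include <cstdio>
#include <vector>

bool IndexIsValid(const std::vector<long long>& index,
                  const std::vector<long long>& operand_dims,
                  const std::vector<long long>& window_sizes) {
  bool valid = true;
  for (size_t i = 0; i < index.size(); ++i) {
    // Largest start position whose window still fits inside the operand.
    const long long max_valid = operand_dims[i] - window_sizes[i];
    valid = valid && (0 <= index[i]) && (index[i] <= max_valid);
  }
  return valid;
}

int main() {
  std::printf("%d\n", IndexIsValid({2, 0}, {4, 5}, {2, 5}));  // 1: window fits.
  std::printf("%d\n", IndexIsValid({3, 0}, {4, 5}, {2, 5}));  // 0: row out of bounds.
  return 0;
}
```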
TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 9bda6fba3aa..42b9e566d71 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -28,7 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -56,6 +57,7 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/ptr_util.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace { @@ -63,6 +65,10 @@ namespace { using absl::StrCat; using absl::StrFormat; +// Argument used when calling DumpHloModuleIfEnabled before optimizations are +// performed on an HloModule. +constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations"; + // Records the arguments used to invoke a computation in an HloSnapshot proto. Status RecordArguments(const absl::Span arguments, se::Stream* stream, TransferManager* transfer_manager, @@ -314,6 +320,15 @@ StatusOr> Service::CreateModuleConfig( config->set_intra_op_parallelism_threads( execute_backend_->eigen_intra_op_thread_pool()->NumThreads()); } + + if (execution_options != nullptr && + execution_options->has_device_assignment()) { + TF_ASSIGN_OR_RETURN( + auto device_assignment, + DeviceAssignment::Deserialize(execution_options->device_assignment())); + config->set_static_device_assignment(*device_assignment); + } + return std::move(config); } @@ -332,27 +347,14 @@ StatusOr>> Service::BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator) { + se::DeviceMemoryAllocator* device_allocator) { VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. 
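`CreateModuleConfig` above now copies a device assignment from the execution options into the module config when one is present. A heavily simplified sketch of that plumbing; the struct names and fields below are hypothetical stand-ins for the XLA protos, not their real definitions:

```cpp
#include <iostream>
#include <optional>

struct DeviceAssignment {
  int replica_count = 1;
  int computation_count = 1;
};

struct ExecutionOptions {
  std::optional<DeviceAssignment> device_assignment;  // Present only if set by the client.
};

struct ModuleConfig {
  std::optional<DeviceAssignment> static_device_assignment;
};

// Copy the assignment into the config only when the options carry one.
void ApplyExecutionOptions(const ExecutionOptions& opts, ModuleConfig* config) {
  if (opts.device_assignment.has_value()) {
    config->static_device_assignment = *opts.device_assignment;
  }
}

int main() {
  ExecutionOptions opts;
  opts.device_assignment = DeviceAssignment{2, 1};
  ModuleConfig config;
  ApplyExecutionOptions(opts, &config);
  std::cout << "has static assignment: "
            << config.static_device_assignment.has_value() << "\n";
  return 0;
}
```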
std::vector> hlo_snapshots; for (int64 i = 0; i < module_protos.size(); ++i) { - const string& directory_path = - module_configs[i]->debug_options().xla_dump_computations_to(); - const string& execution_directory_path = - module_configs[i]->debug_options().xla_dump_executions_to(); - if (directory_path.empty() && execution_directory_path.empty()) { - continue; - } auto hlo_snapshot = absl::make_unique(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; - if (!directory_path.empty()) { - string filename = StrFormat("computation_%d__%s", module_protos[i]->id(), - module_protos[i]->entry_computation_name()); - TF_RETURN_IF_ERROR( - Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); - } hlo_snapshots.push_back(std::move(hlo_snapshot)); } @@ -368,7 +370,7 @@ StatusOr>> Service::BuildExecutables( const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); - TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); + DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); module_group->push_back(std::move(module)); } @@ -378,7 +380,9 @@ StatusOr>> Service::BuildExecutables( std::move(executors), device_allocator)); for (size_t i = 0; i < module_protos.size(); ++i) { - if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { + const auto& debug_opts = module_configs[i]->debug_options(); + if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) && + debug_opts.xla_dump_hlo_snapshots()) { executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i])); } } @@ -476,24 +480,6 @@ Service::ExecuteParallelAndRegisterResult( } } - // For every stream that had profiling enabled, obtain and debug-dump the HLO - // profile. - for (auto& index_to_profiled_stream : index_to_profiled_streams) { - int64 device = index_to_profiled_stream.first; - se::Stream* stream = index_to_profiled_stream.second; - Executable* executable = executables[device]; - const HloModule& module = executable->module(); - HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), - &executable->hlo_profile_index_map()); - TF_RETURN_IF_ERROR( - executable->PopulateExecutionProfile(&hlo_profile, stream)); - XLA_LOG_LINES( - tensorflow::INFO, - hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); - hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute", - &hlo_profile); - } - if (profile != nullptr) { CHECK(!timers.empty()); std::vector timer_nanoseconds; @@ -752,16 +738,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, } for (int i = 0; i < executable_ptrs.size(); i++) { - if (executable_ptrs[i]->dumping_snapshot()) { + Executable* executable = executable_ptrs[i]; + if (executable->dumping_snapshot()) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, allocation_tracker_.ResolveForReplica(outputs[i], 0)); TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(all_executors[i][0])); TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), - executable_ptrs[i]->hlo_snapshot())); - // Dump out the ith snapshot. 
- TF_RETURN_IF_ERROR(executable_ptrs[i]->DumpHloSnapshot()); + executable->hlo_snapshot())); + DumpHloSnapshotIfEnabled(executable->module(), + *executable->hlo_snapshot()); } } @@ -796,31 +783,14 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { + se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) { VLOG(1) << StrFormat( "BuildExecutable on service %p with serialized module proto: %s", this, module_proto.name()); - // Dump computation proto state if flag is set. - auto hlo_snapshot = absl::make_unique(); - const string& directory_path = - module_config->debug_options().xla_dump_computations_to(); - const string& execution_directory_path = - module_config->debug_options().xla_dump_executions_to(); - if (!directory_path.empty() || !execution_directory_path.empty()) { - *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; - if (!directory_path.empty()) { - string filename = StrFormat("computation_%d__%s", module_proto.id(), - module_proto.entry_computation_name()); - TF_RETURN_IF_ERROR( - Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); - } - } - TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(module_proto, *module_config)); - - TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); + DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -830,7 +800,11 @@ StatusOr> Service::BuildExecutable( backend->compiler()->RunBackend( std::move(module), executor, device_allocator)); - if (!execution_directory_path.empty()) { + const auto& debug_opts = module_config->debug_options(); + if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) && + debug_opts.xla_dump_hlo_snapshots()) { + auto hlo_snapshot = absl::make_unique(); + *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; executable->set_hlo_snapshot(std::move(hlo_snapshot)); } @@ -940,7 +914,7 @@ Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), executable->hlo_snapshot())); - TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); + DumpHloSnapshotIfEnabled(executable->module(), *executable->hlo_snapshot()); } VLOG(1) << "successfully completed 'execute' request"; @@ -1162,9 +1136,7 @@ Status Service::GetComputationGraphStats( config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); - - hlo_graph_dumper::MaybeDumpHloModule(*module, - "computation statistics subject"); + DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); // Run HLO analysis to get the computation statistics. 
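These hunks replace the old per-flag dump paths (`xla_dump_computations_to`, `xla_dump_executions_to`, `xla_dump_unoptimized_hlo_proto_to`) with the centralized helpers from `dump.h` (`DumpHloModuleIfEnabled`, `DumpHloSnapshotIfEnabled`, `DumpingEnabledForHloModule`), gated by `xla_dump_hlo_snapshots`. A rough sketch of `DebugOptions` that would trigger snapshot dumping under the new scheme; `xla_dump_hlo_snapshots` is the field this change checks, while `xla_dump_to` is an assumed companion flag naming the output directory and does not appear in this diff:

```c++
#include "tensorflow/compiler/xla/xla.pb.h"

// DebugOptions requesting HloSnapshot dumps under the new dump machinery.
xla::DebugOptions SnapshotDumpOptions() {
  xla::DebugOptions opts;
  opts.set_xla_dump_to("/tmp/xla_dump");   // assumed companion flag
  opts.set_xla_dump_hlo_snapshots(true);   // gates the snapshot dumps above
  return opts;
}
```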
HloCostAnalysis analysis( @@ -1203,16 +1175,4 @@ StatusOr> Service::Replicas( return replicas; } -Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { - const string xla_dump_unoptimized_hlo_proto_to = - module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); - if (xla_dump_unoptimized_hlo_proto_to.empty()) { - return Status::OK(); - } - HloProto proto = MakeHloProto(module); - return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, - StrCat(module.name(), ".unoptimized")); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index fd907d07dae..ba51e457c20 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" #include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" @@ -43,6 +42,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { @@ -234,7 +234,7 @@ class Service : public ServiceInterface { const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, - DeviceMemoryAllocator* device_allocator = nullptr); + se::DeviceMemoryAllocator* device_allocator = nullptr); // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. @@ -242,7 +242,7 @@ class Service : public ServiceInterface { const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator); + se::DeviceMemoryAllocator* device_allocator); // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is @@ -275,10 +275,6 @@ class Service : public ServiceInterface { StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - // Dumps the (unoptimized) module given if the corresponding DebugOptions - // field has been set. - Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; - // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. DeviceHandle SingleComputationDeviceHandle() const; diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 6bee6710565..7fc66310ee7 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -43,7 +43,9 @@ class ServiceExecutableRunOptions { // Delegate to `ExecutableRunOptions` member. 
se::Stream* stream() const { return run_options_.stream(); } - DeviceMemoryAllocator* allocator() const { return run_options_.allocator(); } + se::DeviceMemoryAllocator* allocator() const { + return run_options_.allocator(); + } int device_ordinal() const { return run_options_.device_ordinal(); } // Borrows a stream and returns a smart pointer which returns the stream on diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 53b5d18a065..3510e4913f4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -156,14 +156,6 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, return Status::OK(); } -bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { - return window_dimension.size() == 1 && window_dimension.stride() == 1 && - window_dimension.padding_low() == 0 && - window_dimension.padding_high() == 0 && - window_dimension.window_dilation() == 1 && - window_dimension.base_dilation() == 1; -} - StatusOr InferWindowOutputShape(const Shape& base_shape, const Window& window, PrimitiveType element_type, @@ -205,7 +197,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, window.DebugString()); } - if (base_shape.is_dynamic_dimension(i) && !IsTrivialWindowDimension(dim)) { + if (base_shape.is_dynamic_dimension(i) && + !window_util::IsTrivialWindowDimension(dim)) { return Unimplemented( "Dynamic shape is not supported for non trivial window: %s", window_util::ToString(window)); @@ -315,6 +308,14 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; + case HloOpcode::kPopulationCount: + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Expected an integral element type in argument to PopulationCount " + "operation; got %s.", + PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kSign: if (!ShapeUtil::ElementIsSigned(shape) && !ShapeUtil::ElementIsComplex(shape)) { @@ -633,10 +634,6 @@ Status ValidateDotDimensionNumbers( return fail("Element types do not match."); } - if ((lhs.rank() < 1) || (rhs.rank() < 1)) { - return fail("Dot only supports rank 1 or above."); - } - // Validate basic properties of dot dimension numbers. 
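The new `kPopulationCount` case only constrains the element type: integral operands pass through with their shape unchanged, anything else is an `InvalidArgument`. A small sketch, assuming the `Shape`-based `InferUnaryOpShape` overload:

```c++
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"

void PopulationCountShapeExamples() {
  const xla::Shape u32_vec = xla::ShapeUtil::MakeShape(xla::U32, {8});
  const xla::Shape f32_vec = xla::ShapeUtil::MakeShape(xla::F32, {8});
  // Integral input: the inferred shape equals the operand shape, u32[8].
  auto ok = xla::ShapeInference::InferUnaryOpShape(
      xla::HloOpcode::kPopulationCount, u32_vec);
  // Floating-point input: rejected with the InvalidArgument added above.
  auto err = xla::ShapeInference::InferUnaryOpShape(
      xla::HloOpcode::kPopulationCount, f32_vec);
}
```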
TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); @@ -988,12 +985,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kNe: { + case HloOpcode::kCompare: { TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); @@ -1721,11 +1713,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (batch_group_count > 1 && input_batch % kernel_output_features != 0) { return InvalidArgument( - "Expected output feature dimension (value %d) to be divisible by " - "input_batch (value %d) for batch group count %d; " + "Expected input batch (value %d) to be divisible by output feature " + "dimension size (value %d) for batch group count %d; " "got (%s, %s)\n" "Dimension numbers: {%s}.", - kernel_output_features, input_batch, batch_group_count, + input_batch, kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } @@ -1868,12 +1860,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]); } } - if (ShapeUtil::IsZeroElementArray(in)) { - return in; - } Shape result = ShapeUtil::ChangeElementType(in, C64); - result.set_dimensions(result.dimensions_size() - 1, - fft_length[fft_rank - 1] / 2 + 1); + // Preserve the size of zero-sized dimensions. + if (fft_length[fft_rank - 1] != 0) { + result.set_dimensions(result.dimensions_size() - 1, + fft_length[fft_rank - 1] / 2 + 1); + } return result; } case IRFFT: { @@ -1894,8 +1886,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]); } } - if (in.dimensions(in.dimensions_size() - 1) != - fft_length[fft_rank - 1] / 2 + 1) { + // The size of zero-sized dimensions is preserved. + if ((in.dimensions(in.dimensions_size() - 1) != 0 || + fft_length[fft_rank - 1] != 0) && + in.dimensions(in.dimensions_size() - 1) != + fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " "dimension %d is %d and should be %d.", @@ -2548,7 +2543,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, }; // Check the shapes of computation parameters and return types. 
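The RFFT/IRFFT hunks change how the innermost dimension is computed: the RFFT output keeps `fft_length.back() / 2 + 1` elements, but a zero-sized innermost dimension is now preserved instead of short-circuiting to the input shape, and IRFFT applies the matching check. A worked sketch mirroring the tests added further below:

```c++
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"

void RfftShapeExamples() {
  // f32[16,8], fft_length={16,8}  ->  c64[16,5]   (8 / 2 + 1 == 5)
  auto full = xla::ShapeInference::InferFftShape(
      xla::ShapeUtil::MakeShape(xla::F32, {16, 8}), xla::FftType::RFFT,
      {16, 8});
  // f32[16,0], fft_length={16,0}  ->  c64[16,0]   (zero-sized dim preserved)
  auto empty = xla::ShapeInference::InferFftShape(
      xla::ShapeUtil::MakeShape(xla::F32, {16, 0}), xla::FftType::RFFT,
      {16, 0});
}
```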
- if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapeUtil::Compatible(condition.result(), + ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Condition must return a boolean; got %s.", shape_string()); } @@ -2568,8 +2564,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& branch_index, absl::Span branch_computations, absl::Span branch_operands) { - if (!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(PRED, {})) && - !ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(S32, {}))) { + if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) && + !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) { return InvalidArgument("branch_index must be bool or int32; got %s.", ShapeUtil::HumanString(branch_index)); } @@ -2744,45 +2740,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand); } -// TODO(b/36794510): Make broadcast semantics more consistent, by supporting -// "degenerate" cases, as with binary elementwise ops. /* static */ StatusOr ShapeInference::InferClampShape( const Shape& min, const Shape& operand, const Shape& max) { TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min")); TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max")); - if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || - !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { - return InvalidArgument("Clamp with different operand types: %s, %s, %s.", - ShapeUtil::HumanString(min), - ShapeUtil::HumanString(operand), - ShapeUtil::HumanString(max)); + + if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || + !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) { + return InvalidArgument( + "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min), + ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max)); } - if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || - ShapeUtil::IsScalar(min)) && - (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) || - ShapeUtil::IsScalar(max)))) { - return operand; - } - if (ShapeUtil::IsScalar(operand)) { - if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) { - return ShapeUtil::ChangeElementType(min, operand.element_type()); - } else if (ShapeUtil::IsScalar(min)) { - return ShapeUtil::ChangeElementType(max, operand.element_type()); - } else if (ShapeUtil::IsScalar(max)) { - return ShapeUtil::ChangeElementType(min, operand.element_type()); - } - } - return Unimplemented("%s, %s %s is not implemented.", - min.ShortDebugString(), max.ShortDebugString(), - operand.ShortDebugString()); + return operand; } -// TODO(b/36794510): Make broadcast semantics more consistent, by supporting -// "degenerate" cases, as with binary elementwise ops, as well as scalar -// broadcast from all operands, not just the predicate. 
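With the rewritten `InferClampShape`, scalar broadcasting of `min` and `max` is gone: all three operands must be compatible (ignoring floating-point precision) and the result is simply the operand shape. A short sketch of both outcomes:

```c++
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"

void ClampShapeExamples() {
  const xla::Shape mat = xla::ShapeUtil::MakeShape(xla::F32, {64, 48});
  const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  // Compatible operands: the inferred shape is the operand shape, f32[64,48].
  auto ok = xla::ShapeInference::InferTernaryOpShape(
      xla::HloOpcode::kClamp, mat, mat, mat);
  // Scalar min: now rejected with "Clamp with different shapes".
  auto err = xla::ShapeInference::InferTernaryOpShape(
      xla::HloOpcode::kClamp, scalar, mat, mat);
}
```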
/* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { + TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred")); + TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true")); + TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false")); + if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", @@ -2793,31 +2771,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred)); } - if (Shape::Equal() - .IgnoreElementType() - .IgnoreLayout() - .IgnoreDynamicDimension()(pred, on_true) || - ShapeUtil::IsScalar(pred)) { - // By this stage we know that pred's element type is PRED. Therefore, this - // check restricts pred to be a PRED scalar, or a PRED array with the same - // dimensions as on_true and on_false. - Shape inferred_shape = ShapeUtil::ChangeElementType( - on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); - - // Propagate dynamic dimensions if pred is not a scalar. - if (!ShapeUtil::IsScalar(pred)) { - for (int i = 0; i < inferred_shape.rank(); i++) { - if (pred.is_dynamic_dimension(i)) { - inferred_shape.set_dynamic_dimension(i, true); - } - } - } - return inferred_shape; + if (!Shape::Equal() + .IgnoreElementType() + .IgnoreLayout() + .IgnoreDynamicDimension()(pred, on_true)) { + return InvalidArgument( + "Operands to select and predicate must be the same shape; got %s and " + "%s.", + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred)); } - return InvalidArgument( - "Select operation with non-scalar predicate with dimensionality " - "different from the other operands: %s.", - ShapeUtil::HumanString(pred)); + + return ShapeUtil::ChangeElementType( + pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); } /* static */ StatusOr ShapeInference::InferTupleSelectShape( @@ -2971,7 +2936,7 @@ static Status ValidateGatherDimensionNumbers( const GatherDimensionNumbers& gather_dim_numbers, absl::Span slice_sizes) { TF_RETURN_IF_ERROR( - ExpectArray(input_shape, "input tensor operand gather op")); + ExpectArray(input_shape, "input tensor operand of gather op")); TF_RETURN_IF_ERROR( ExpectArray(start_indices_shape, "gather indices operand of gather op")); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 6f8cc6136bb..3bfa971f857 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
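`InferSelectShape` is tightened the same way: the predicate must be a `PRED` array with exactly the dimensions of the branches, a scalar predicate is rejected, and the result element type is the higher-precision branch type. A short sketch:

```c++
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"

void SelectShapeExamples() {
  const xla::Shape pred10 = xla::ShapeUtil::MakeShape(xla::PRED, {10});
  const xla::Shape f32_10 = xla::ShapeUtil::MakeShape(xla::F32, {10});
  const xla::Shape bf16_10 = xla::ShapeUtil::MakeShape(xla::BF16, {10});
  // Same-shape predicate and branches: element type is the higher-precision
  // branch type, so this infers f32[10].
  auto ok = xla::ShapeInference::InferTernaryOpShape(
      xla::HloOpcode::kSelect, pred10, f32_10, bf16_10);
  // Scalar predicate: now rejected ("Operands to select and predicate must be
  // the same shape"); broadcast the predicate first if that pattern is needed.
  auto err = xla::ShapeInference::InferTernaryOpShape(
      xla::HloOpcode::kSelect, xla::ShapeUtil::MakeShape(xla::PRED, {}),
      f32_10, f32_10);
}
```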
#include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -112,15 +113,18 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); auto inferred_status = ShapeInference::InferTernaryOpShape( HloOpcode::kSelect, pred_, tuple, tuple); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("Expected array argument for select")); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { auto inferred_status = ShapeInference::InferTernaryOpShape( HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT( + inferred_status.status().error_message(), + HasSubstr("Operands to select and predicate must be the same shape")); } TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { @@ -148,8 +152,9 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("with non-scalar predicate with dimensionality")); + ASSERT_THAT( + inferred_status_error3.status().error_message(), + HasSubstr("Operands to select and predicate must be the same shape")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( @@ -158,7 +163,7 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), - HasSubstr("pred operand must have PRED element type")); + HasSubstr("Expected array argument for select pred")); } TEST_F(ShapeInferenceTest, ClampAllMatrix) { @@ -178,43 +183,49 @@ TEST_F(ShapeInferenceTest, ClampAllScalar) { TEST_F(ShapeInferenceTest, ClampMinScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("Clamp with different shapes")); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("Clamp with different shapes")); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("Clamp with different shapes")); } TEST_F(ShapeInferenceTest, 
ClampMinMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( HloOpcode::kClamp, matrix_64_48_, f32_, f32_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("Clamp with different shapes")); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( HloOpcode::kClamp, f32_, f32_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("Clamp with different shapes")); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( HloOpcode::kClamp, f32_, matrix_64_48_, f32_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("Clamp with different shapes")); } TEST_F(ShapeInferenceTest, ClampBadShapes) { @@ -562,6 +573,148 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { HasSubstr("each dimension exactly once")); } +namespace fft { + +static const char* unsupported_rank = "only supports ranks 1-3"; +static const char* invalid_rank = "requires input of at least same rank"; +static const char* requires_complex_input = "requires complex input type"; +static const char* requires_f32_input = "requires F32 input type"; +static const char* requires_c64_input = "requires C64 input type"; +static const char* dimensions_match = "innermost dimensions match fft_length"; +static const char* innermost_dimension_matches = + "innermost dimension matches fft_length/2+1"; + +static void Pass(const Shape& shape, FftType type, + absl::Span length, const Shape& expected_shape) { + auto inferred_status = ShapeInference::InferFftShape(shape, type, length); + ASSERT_IS_OK(inferred_status.status()); + Shape inferred_shape = inferred_status.ValueOrDie(); + ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape)); +} + +static void Fail(const Shape& shape, FftType type, + absl::Span length, absl::string_view message) { + auto inferred_status = ShapeInference::InferFftShape(shape, type, length); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr(std::string(message))); +} + +} // namespace fft + +TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) { + FftType type = FftType::FFT; + Shape shape = ShapeUtil::MakeShape(C64, {16, 8}); + fft::Fail(shape, type, {}, fft::unsupported_rank); + fft::Pass(shape, type, {8}, shape); + fft::Pass(shape, type, {16, 8}, shape); + fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank); + fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) { + FftType type = FftType::FFT; + Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); + Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); + fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); + fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) { + FftType type = FftType::IFFT; + Shape shape = ShapeUtil::MakeShape(C64, {16, 8}); + fft::Fail(shape, type, {}, 
fft::unsupported_rank); + fft::Pass(shape, type, {8}, shape); + fft::Pass(shape, type, {16, 8}, shape); + fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank); + fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) { + FftType type = FftType::IFFT; + Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); + Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); + fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); + fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) { + FftType type = FftType::RFFT; + Shape shape_in = ShapeUtil::MakeShape(F32, {16, 8}); + Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5}); + fft::Fail(shape_in, type, {}, fft::unsupported_rank); + fft::Pass(shape_in, type, {8}, shape_out); + fft::Pass(shape_in, type, {16, 8}, shape_out); + fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank); + fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) { + FftType type = FftType::RFFT; + Shape shape = ShapeUtil::MakeShape(F32, {16, 8}); + fft::Fail(shape, type, {4}, fft::dimensions_match); + fft::Fail(shape, type, {16, 4}, fft::dimensions_match); + fft::Fail(shape, type, {8, 8}, fft::dimensions_match); + fft::Fail(shape, type, {8, 16}, fft::dimensions_match); + + Shape zero_shape_in = ShapeUtil::MakeShape(F32, {16, 0}); + Shape zero_shape_out = ShapeUtil::MakeShape(C64, {16, 0}); + fft::Pass(zero_shape_in, type, {0}, zero_shape_out); + fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out); + + Shape even_shape_in = ShapeUtil::MakeShape(F32, {16, 8}); + Shape odd_shape_in = ShapeUtil::MakeShape(F32, {16, 9}); + Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5}); + fft::Pass(even_shape_in, type, {16, 8}, shape_out); + fft::Pass(odd_shape_in, type, {16, 9}, shape_out); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) { + FftType type = FftType::RFFT; + Shape shape_c64 = ShapeUtil::MakeShape(C64, {16, 8}); + Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); + fft::Fail(shape_c64, type, {16, 8}, fft::requires_f32_input); + fft::Fail(shape_c128, type, {16, 8}, fft::requires_f32_input); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) { + FftType type = FftType::IRFFT; + Shape shape_in = ShapeUtil::MakeShape(C64, {16, 5}); + Shape shape_out = ShapeUtil::MakeShape(F32, {16, 8}); + fft::Fail(shape_in, type, {}, fft::unsupported_rank); + fft::Pass(shape_in, type, {8}, shape_out); + fft::Pass(shape_in, type, {16, 8}, shape_out); + fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank); + fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) { + FftType type = FftType::IRFFT; + Shape shape = ShapeUtil::MakeShape(C64, {16, 5}); + fft::Fail(shape, type, {5}, fft::innermost_dimension_matches); + fft::Fail(shape, type, {16, 5}, fft::innermost_dimension_matches); + fft::Fail(shape, type, {8, 8}, fft::dimensions_match); + fft::Fail(shape, type, {8, 9}, fft::dimensions_match); + + Shape zero_shape_in = ShapeUtil::MakeShape(C64, {16, 0}); + Shape zero_shape_out = ShapeUtil::MakeShape(F32, {16, 0}); + fft::Pass(zero_shape_in, type, {0}, zero_shape_out); + fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out); + + Shape even_shape_out = ShapeUtil::MakeShape(F32, {16, 8}); + Shape odd_shape_out = ShapeUtil::MakeShape(F32, {16, 9}); + 
fft::Pass(shape, type, {16, 8}, even_shape_out); + fft::Pass(shape, type, {16, 9}, odd_shape_out); +} + +TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) { + FftType type = FftType::IRFFT; + Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); + Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); + fft::Fail(shape_f32, type, {16, 8}, fft::requires_c64_input); + fft::Fail(shape_c128, type, {16, 8}, fft::requires_c64_input); +} + TEST_F(ShapeInferenceTest, MapThatChangesElementType) { Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_); @@ -918,55 +1071,10 @@ TEST_F(ShapeInferenceTest, InferPowShape) { ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } -TEST_F(ShapeInferenceTest, InferCompareShapeEq) { +TEST_F(ShapeInferenceTest, InferCompareShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeGe) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeGt) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeLe) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeLt) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), - inferred_status.ValueOrDie())); -} - -TEST_F(ShapeInferenceTest, InferCompareShapeNe) { - auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kCompare, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -1005,16 +1113,13 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { } } -// scalar vector: error +// scalar vector: ok TEST_F(ShapeInferenceTest, ScalarDotVector) { DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Dot only supports rank")); + EXPECT_TRUE(inferred_status.ok()); + 
EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_); } // 3D 2D: error diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index d90dde3b13d..9b0ec31e9da 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -67,6 +67,20 @@ ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { ShapedBuffer::~ShapedBuffer() {} +StatusOr ShapedBuffer::SubShapedBuffer( + const ShapeIndex& index) const { + TF_ASSIGN_OR_RETURN(const Shape* host_sub_shape, + ShapeUtil::TryGetSubshape(on_host_shape(), index)); + TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape, + ShapeUtil::TryGetSubshape(on_device_shape(), index)); + ShapedBuffer sub_shaped_buffer(*host_sub_shape, *device_sub_shape, platform_, + device_ordinal_); + TF_ASSIGN_OR_RETURN(ShapeTree sub_buffers, + buffers_.SubShapeTree(index)); + sub_shaped_buffer.set_buffers(std::move(sub_buffers)); + return std::move(sub_shaped_buffer); +} + void ShapedBuffer::clear() { for (auto& pair : buffers_) { // A default constructed DeviceMemoryBase is a null pointer. @@ -105,14 +119,14 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, - DeviceMemoryAllocator* allocator, + se::DeviceMemoryAllocator* allocator, int device_ordinal) : ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(), device_ordinal), allocator_(allocator) {} ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, - DeviceMemoryAllocator* allocator) + se::DeviceMemoryAllocator* allocator) : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {} ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index f5210c9cfa6..2351e901887 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -21,12 +21,12 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { @@ -90,6 +90,7 @@ class ShapedBuffer { void set_buffers(ShapeTree buffers) { CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_)); buffers_ = std::move(buffers); + buffers_.replace_shape_ptr(&on_device_shape_); } // Returns the underlying ShapeTree containing all the device addresses in the @@ -97,6 +98,8 @@ class ShapedBuffer { const ShapeTree& buffers() const { return buffers_; } ShapeTree& buffers() { return buffers_; } + StatusOr SubShapedBuffer(const ShapeIndex& index) const; + // Set all device memory pointers in the object to null. void clear(); @@ -135,13 +138,13 @@ class ScopedShapedBuffer : public ShapedBuffer { // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index. explicit ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, - DeviceMemoryAllocator* allocator, + se::DeviceMemoryAllocator* allocator, int device_ordinal); // Create a ScopedShapedBuffer by taking over the memory from the incoming // ShapedBuffer. 
explicit ScopedShapedBuffer(ShapedBuffer shaped_buffer, - DeviceMemoryAllocator* allocator); + se::DeviceMemoryAllocator* allocator); // Movable, but not copyable. ScopedShapedBuffer(ScopedShapedBuffer&& s); @@ -154,17 +157,17 @@ class ScopedShapedBuffer : public ShapedBuffer { // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. - DeviceMemoryAllocator* memory_allocator() const { return allocator_; } + se::DeviceMemoryAllocator* memory_allocator() const { return allocator_; } // Sets the device memory buffer at the given index. // // If the given buffer's device memory is non-null, its device_ordinal and // allocator must match those in `this`. - void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) { + void set_buffer(se::OwningDeviceMemory buffer, const ShapeIndex& index) { if (!buffer.is_null()) { CHECK_EQ(buffer.device_ordinal(), device_ordinal()); CHECK_EQ(buffer.allocator(), allocator_); - *buffers_.mutable_element(index) = buffer.Forget(); + *buffers_.mutable_element(index) = buffer.Release(); } else { *buffers_.mutable_element(index) = se::DeviceMemoryBase(); } @@ -184,7 +187,7 @@ class ScopedShapedBuffer : public ShapedBuffer { protected: void Deallocate(); - DeviceMemoryAllocator* allocator_; + se::DeviceMemoryAllocator* allocator_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index ca64bd3c8dd..887e8b7146d 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -16,13 +16,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/util/ptr_util.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace { @@ -34,7 +34,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { auto* platform = platforms[0]; TF_ASSERT_OK_AND_ASSIGN(auto executors, xla::PlatformUtil::GetStreamExecutors(platform)); - xla::StreamExecutorMemoryAllocator allocator(platform, executors); + xla::se::StreamExecutorMemoryAllocator allocator(platform, executors); const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); const int kDeviceOrdinal = 0; auto scoped_buffer = absl::make_unique( @@ -43,11 +43,11 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { buffer = nullptr; } -class TestAllocator : public DeviceMemoryAllocator { +class TestAllocator : public se::DeviceMemoryAllocator { public: TestAllocator() - : DeviceMemoryAllocator(PlatformUtil::GetDefaultPlatform().ValueOrDie()) { - } + : se::DeviceMemoryAllocator( + PlatformUtil::GetDefaultPlatform().ValueOrDie()) {} ~TestAllocator() override { if (!allocations_.empty()) { @@ -56,18 +56,18 @@ class TestAllocator : public DeviceMemoryAllocator { } // Pull in two-arg overload of Allocate. 
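Alongside the move of the allocator types into namespace `se`, `OwningDeviceMemory::Forget()` becomes `Release()`, as the `set_buffer` hunk above shows. An updated call-site sketch; `platform` and `executors` are placeholders for values obtained via `PlatformUtil` as in the surrounding test, and the constructor parameter types are assumed to match that usage:

```c++
#include "absl/types/span.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace se = ::stream_executor;

void ReleaseInsteadOfForget(se::Platform* platform,
                            absl::Span<se::StreamExecutor* const> executors) {
  se::StreamExecutorMemoryAllocator allocator(platform, executors);
  auto owned_or = allocator.Allocate(/*device_ordinal=*/0, /*size=*/64);
  TF_CHECK_OK(owned_or.status());
  // Previously owned_or.ValueOrDie().Forget(); the buffer is handed back to
  // the caller, so it must be deallocated manually.
  se::DeviceMemoryBase raw = owned_or.ValueOrDie().Release();
  TF_CHECK_OK(allocator.Deallocate(/*device_ordinal=*/0, raw));
}
```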
- using DeviceMemoryAllocator::Allocate; + using se::DeviceMemoryAllocator::Allocate; - StatusOr Allocate(int device_ordinal, uint64 size, - bool /*retry_on_failure*/) override { + StatusOr Allocate( + int device_ordinal, uint64 size, bool /*retry_on_failure*/) override { // By contract, we must return null if size == 0. if (size == 0) { - return OwningDeviceMemory(); + return se::OwningDeviceMemory(); } void* buf = malloc(size); allocations_.insert({device_ordinal, buf}); - return OwningDeviceMemory(se::DeviceMemoryBase(buf, size), device_ordinal, - this); + return se::OwningDeviceMemory(se::DeviceMemoryBase(buf, size), + device_ordinal, this); } Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override { @@ -120,9 +120,9 @@ TEST(ScopedShapedBufferTest, TestTakeSubTree) { sb.buffers().ForEachMutableElement( [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { TF_ASSERT_OK_AND_ASSIGN( - OwningDeviceMemory m, + se::OwningDeviceMemory m, allocator.Allocate(/*device_ordinal=*/0, /*size=*/77)); - *buffer = m.Forget(); + *buffer = m.Release(); }); ShapeTree buffers = sb.buffers(); @@ -148,6 +148,27 @@ TEST(ScopedShapedBufferTest, TestTakeSubTree) { }); } +TEST(ScopedShapedBufferTest, TestSubShapeTree) { + Shape array_shape = ShapeUtil::MakeShape(F32, {1}); + Shape tuple_shape = + xla::ShapeUtil::MakeTupleShape({array_shape, array_shape}); + TestAllocator allocator; + ScopedShapedBuffer sb(tuple_shape, tuple_shape, &allocator, + /*device_ordinal=*/0); + sb.buffers().ForEachMutableElement( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + TF_ASSERT_OK_AND_ASSIGN( + se::OwningDeviceMemory m, + allocator.Allocate(/*device_ordinal=*/0, /*size=*/32)); + *buffer = m.Release(); + }); + auto ssb_statusor = sb.SubShapedBuffer({1}); + ASSERT_TRUE(ssb_statusor.ok()); + auto ssb = ssb_statusor.ConsumeValueOrDie(); + EXPECT_EQ(ssb.on_host_shape(), array_shape); + EXPECT_EQ(ssb.on_device_shape(), array_shape); +} + // Test TakeSubTree with different depths (depth of ShapeTree) and fan-outs // (cardinality of each non-leaf node's children). void BM_TakeSubTree(int iters, int depth, int fan_out) { diff --git a/tensorflow/compiler/xla/service/slice_sinker.cc b/tensorflow/compiler/xla/service/slice_sinker.cc new file mode 100644 index 00000000000..a8e681db78d --- /dev/null +++ b/tensorflow/compiler/xla/service/slice_sinker.cc @@ -0,0 +1,278 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/slice_sinker.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +namespace { + +// Returns whether two slices are taken from the same indices, assuming the +// slices are taking from tensors with the same dimensions. 
+bool SameSliceConfiguration(const HloInstruction* slice_1, + const HloInstruction* slice_2) { + CHECK_EQ(slice_1->opcode(), HloOpcode::kSlice); + CHECK_EQ(slice_2->opcode(), HloOpcode::kSlice); + CHECK(slice_1->operand(0)->shape().dimensions() == + slice_2->operand(0)->shape().dimensions()); + return slice_1->slice_starts() == slice_2->slice_starts() && + slice_1->slice_limits() == slice_2->slice_limits() && + slice_1->slice_strides() == slice_2->slice_strides(); +} + +// Returns true if all the operands of the given elementwise operation are +// slices from the same indices of tensors with compatible shapes. +bool IsElementwiseOperationOnSimilarSlices(const HloInstruction* inst) { + CHECK(inst->IsElementwise()); + + // Check that all operands are slices. + if (absl::c_any_of(inst->operands(), [](const HloInstruction* operand) { + return operand->opcode() != HloOpcode::kSlice; + })) { + return false; + } + + // Check that all slices are from the same indices of slice sources with + // compatible shapes. + const HloInstruction* slice0 = inst->operand(0); + return absl::c_all_of(absl::MakeSpan(inst->operands()).subspan(1), + [slice0](const HloInstruction* slice) { + return ShapeUtil::CompatibleIgnoringElementType( + slice0->operand(0)->shape(), + slice->operand(0)->shape()) && + SameSliceConfiguration(slice0, slice); + }); +} + +// Given an elementwise operation with all slice operands, operation_on_slices, +// checks whether another operation, candidate, is an operation that hasn't been +// transformed and is similar to operation_on_slices as defined by the following +// criteria: +// (1) candidate has the same opcode as the operation_on_slices. +// (2) The ith operand of candidate is a slice from the same slice source of +// the ith operand in operation_on_slices. +// (3) All operands of candidate are slices taken from the same indices as the +// operands of operation_on_slices are. +bool IsSimilarOperationOnSlices(const HloInstruction* operation_on_slices, + const HloInstruction* candidate) { + // Instructions that have already been transformed have user_count 0. Avoid + // transforming such instructions again. + if (candidate->user_count() == 0) { + return false; + } + + if (candidate->opcode() != operation_on_slices->opcode()) { + return false; + } + + const HloInstruction* operand_slice0 = candidate->operand(0); + for (int64 i = 0; i < candidate->operand_count(); ++i) { + const HloInstruction* operand_slice = candidate->operand(i); + if (operand_slice->opcode() != HloOpcode::kSlice || + operand_slice->operand(0) != + operation_on_slices->operand(i)->operand(0) || + !SameSliceConfiguration(operand_slice0, operand_slice)) { + return false; + } + } + return true; +} + +// Given a group of elementwise operations on slices that can be transformed to +// one elementwise operation on the slice sources, compares the cost of +// implementing the new elementwise operation on the slice sources with the cost +// of implementing all the individual elementwise operations independently. +// Returns true if the former is less expensive. +// +// Currently we don't support the following transformation that produces a new +// elementwise operation on bigger slices of the slice sources. 
This is because +// we don't have such a use case yet: +// Transform +// p = f32[20] parameter(0) +// a = f32[8] slice(p), slice=[0:8] +// aa = add(a, a) +// b = f32[7] slice(p), slice=[2:9] +// bb = add(b, b) +// +// to +// p = f32[20] parameter(0) +// x = f32[9] slice(p), slice=[0:8] +// xx = add(x,x) +// aa = f32[8] slice(xx), slice=[0:8] +// bb = f32[7] slice(xx), slice=[2:9] +bool ShouldTransform(const std::vector& operations_on_slices) { + int64 sum = 0; + for (HloInstruction* user : operations_on_slices) { + sum += ShapeUtil::ElementsIn(user->shape()); + } + return sum >= xla::ShapeUtil::ElementsIn( + operations_on_slices[0]->operand(0)->operand(0)->shape()); +} + +// Returns a group of elementwise operations on slices that are similar to the +// given operations_on_slices. See IsSimilarOperationOnSlices for what are +// considered similar operation on slices. +absl::optional> FindElementwiseOperationGroup( + const HloInstruction* operation_on_slices) { + std::vector operations; + const HloInstruction* slice_source0 = + operation_on_slices->operand(0)->operand(0); + + // Traverse the slices taken from the first slice sources. + for (const HloInstruction* operand_slice0 : slice_source0->users()) { + if (operand_slice0->opcode() != HloOpcode::kSlice) { + continue; + } + + for (HloInstruction* user : operand_slice0->users()) { + if (IsSimilarOperationOnSlices(operation_on_slices, user)) { + operations.push_back(user); + } + } + } + + return ShouldTransform(operations) ? absl::make_optional(operations) + : absl::nullopt; +} + +// Generates a new elementwise operation using the slice_sources as operands, +// and replaces the uses of elementwise operation_on_slices with slices of the +// new elementwise operations. +Status SinkSlices(const std::vector& slice_sources, + const std::vector& operation_on_slices) { + const Shape shape = slice_sources[0]->shape(); + PrimitiveType element_type = operation_on_slices[0]->shape().element_type(); + Shape new_shape = ShapeUtil::ChangeElementType(shape, element_type); + + HloComputation* computation = operation_on_slices[0]->parent(); + auto operation_on_slice_sources = computation->AddInstruction( + operation_on_slices[0]->CloneWithNewOperands(new_shape, slice_sources)); + VLOG(10) << "Adding operation_on_slice_sources: " + << operation_on_slice_sources->ToString(); + + // Replace each operation on slices with a slice of the operation on the slice + // sources. + for (HloInstruction* user : operation_on_slices) { + const HloInstruction* operand_slice = user->operand(0); + auto user_slice = + computation->AddInstruction(operand_slice->CloneWithNewOperands( + user->shape(), {operation_on_slice_sources})); + VLOG(10) << "Adding new slice: " << user_slice->ToString() + << " to replace: " << user->ToString(); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(user_slice)); + } + return Status::OK(); +} + +} // namespace + +// There are two purposes of this pass. +// +// 1. Eliminates redundant work that occurs when two slices overlap. For +// example: +// p = f32[10] parameter(0) +// a = f32[9] slice(p), slice=[0:9] +// aa = add(a, a) +// b = f32[8] slice(p), slice=[2:10] +// bb = add(b, b) +// ... +// Here we do 17 scalar add operations, while we actually only need to do 10 if +// we can transform the code to the following: +// p = f32[10] parameter(0) +// add = add(p, p) +// aa = f32[9] slice(add), slice=[0:9] +// bb = f32[8] slice(add), slice=[2:10] +// ... +// +// 2. Merges elementwise operations when two slices are "adjacent". 
+// p = f32[10] parameter(0) +// a = f32[6] slice(p), slice=[0:6] +// aa = add(a, a) +// b = f32[4] slice(p), slice=[6:10] +// bb = add(b, b) +// ... +// Here we're not doing any redundant work, but transforming this graph to the +// following graph allows us to run fewer kernels: +// p = f32[10] parameter(0) +// add = add(p, p) +// aa = f32[6] slice(add), slice=[0:6] +// bb = f32[4] slice(add), slice=[6:10] +// +// As can be seen from the examples, the group of elementwise operations being +// transformed must meet the following requirements: +// (1) The operands of each operation are slices taken from the same indices of +// bigger tensors with the same dimensions. +// (2) All operations have the same opcode. +// (3) The corresponding operands of all operations are slices taken +// from the same bigger tensors. +// (4) The accumulated size of the group of operations is not less than the size +// of such a bigger tensor. This is a heuristic to ensure that the +// transformation never causes us to do more elementwise operations. +// +// This pass currently doesn't transform non-elementwise instructions. We may +// extend this pass to transform non-elementwise instructions, such as dot, +// broadcast and reduce in the future. +StatusOr SliceSinker::Run(HloModule* module) { + bool changed = false; + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + // When processing instruction A in this loop, we may transform A along + // with instruction B, which is after A in the post order. An instruction + // that has been transformed has a user_count 0. We use this fact to + // avoid transforming an instruction that has been transformed. + if (!instruction->IsElementwise() || instruction->operand_count() == 0 || + instruction->user_count() == 0) { + continue; + } + VLOG(10) << "Processing instruction : " << instruction->ToString(); + + // This checks condition (1). + if (!IsElementwiseOperationOnSimilarSlices(instruction)) { + continue; + } + + // Try to find a group of elementwise operations that are similar to + // the current instruction. This checks conditions (2)-(4). + absl::optional> similar_operations = + FindElementwiseOperationGroup(instruction); + if (!similar_operations.has_value()) { + continue; + } + + std::vector slice_sources; + absl::c_transform( + instruction->operands(), std::back_inserter(slice_sources), + [](HloInstruction* slice) { return slice->mutable_operand(0); }); + + TF_RETURN_IF_ERROR(SinkSlices(slice_sources, similar_operations.value())); + changed = true; + } + } + + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/owning_device_memory.cc b/tensorflow/compiler/xla/service/slice_sinker.h similarity index 51% rename from tensorflow/compiler/xla/service/owning_device_memory.cc rename to tensorflow/compiler/xla/service/slice_sinker.h index c115bc097f3..4615b5f2e69 100644 --- a/tensorflow/compiler/xla/service/owning_device_memory.cc +++ b/tensorflow/compiler/xla/service/slice_sinker.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. 
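Since `SliceSinker` is a standard `HloModulePass`, it can be wired into an `HloPassPipeline` as well as run directly via `RunHloPass` as the tests below do. A minimal sketch; the pipeline name is an arbitrary placeholder:

```c++
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/slice_sinker.h"
#include "tensorflow/compiler/xla/statusor.h"

// Runs SliceSinker on a module through a single-pass pipeline; returns true
// if the pass changed the module.
xla::StatusOr<bool> RunSliceSinker(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("slice-sinker-example");
  pipeline.AddPass<xla::SliceSinker>();
  return pipeline.Run(module);
}
```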
==============================================================================*/ -#include "tensorflow/compiler/xla/service/owning_device_memory.h" +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SLICE_SINKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SLICE_SINKER_H_ -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { -void OwningDeviceMemory::Free() { - CHECK(allocator_ != nullptr) - << "Can't call Free() on an inactive (i.e. moved from, Forget()'ten, " - "or Free()'ed) instance."; - auto status = allocator_->Deallocate(device_ordinal_, mem_); - if (!status.ok()) { - LOG(WARNING) << "Deallocating buffer " << mem_.opaque() << " failed."; - } +// An HLO pass that sinks slice operations used by a group of elementwise +// operations and merges the group of elementwise operations. +class SliceSinker : public HloModulePass { + public: + tensorflow::StringPiece name() const override { return "slice-sinker"; } - allocator_ = nullptr; - mem_ = se::DeviceMemoryBase(); -} + StatusOr Run(HloModule* module) override; +}; } // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SLICE_SINKER_H_ diff --git a/tensorflow/compiler/xla/service/slice_sinker_test.cc b/tensorflow/compiler/xla/service/slice_sinker_test.cc new file mode 100644 index 00000000000..f09a7a8288a --- /dev/null +++ b/tensorflow/compiler/xla/service/slice_sinker_test.cc @@ -0,0 +1,498 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/slice_sinker.h" + +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace m = match; +using ::testing::ElementsAre; + +class SliceSinkerTest : public HloTestBase {}; + +TEST_F(SliceSinkerTest, TernaryOperation) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + p2 = f32[8,9] parameter(2) + s00 = pred[2,9] slice(pred[8,9] p0), slice={[0:2], [0:9]} + s01 = pred[6,9] slice(pred[8,9] p0), slice={[2:8], [0:9]} + s10 = f32[2,9] slice(f32[8,9] p1), slice={[0:2], [0:9]} + s11 = f32[6,9] slice(f32[8,9] p1), slice={[2:8], [0:9]} + s20 = f32[2,9] slice(f32[8,9] p2), slice={[0:2], [0:9]} + s21 = f32[6,9] slice(f32[8,9] p2), slice={[2:8], [0:9]} + sel0 = f32[2,9] select(pred[2,9] s00, f32[2,9] s10, f32[2,9] s20) + sel1 = f32[6,9] select(pred[6,9] s01, f32[6,9] s11, f32[6,9] s21) + ROOT tuple = (f32[2,9], f32[6,9]) tuple(sel0, sel1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_TRUE(result); + HloInstruction* inst = module->entry_computation()->root_instruction(); + const HloInstruction* slice0; + const HloInstruction* slice1; + EXPECT_THAT(inst, + GmockMatch(m::Tuple( + m::Slice(&slice0, m::Select(m::Parameter(0), m::Parameter(1), + m::Parameter(2))), + m::Slice(&slice1, m::Select(m::Parameter(0), m::Parameter(1), + m::Parameter(2)))))); + EXPECT_THAT(slice0->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice0->slice_limits(), ElementsAre(2, 9)); + EXPECT_THAT(slice0->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice1->slice_starts(), ElementsAre(2, 0)); + EXPECT_THAT(slice1->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice1->slice_strides(), ElementsAre(1, 1)); +} + +TEST_F(SliceSinkerTest, OverlappingPartialSlicesBeneficial) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[2,9] slice(f32[8,9] p0), slice={[0:2], [0:9]} + s01 = f32[5,9] slice(f32[8,9] p0), slice={[3:8], [0:9]} + s02 = f32[8,4] slice(f32[8,9] p0), slice={[0:8], [0:4]} + s10 = f32[2,9] slice(f32[8,9] p1), slice={[0:2], [0:9]} + s11 = f32[5,9] slice(f32[8,9] p1), slice={[3:8], [0:9]} + s12 = f32[8,4] slice(f32[8,9] p1), slice={[0:8], [0:4]} + add0 = f32[2,9] add(f32[2,9] s00, f32[2,9] s10) + add1 = f32[5,9] add(f32[5,9] s01, f32[5,9] s11) + add2 = f32[8,4] add(f32[8,4] s02, f32[8,4] s12) + ROOT tuple = (f32[2,9], f32[5,9], f32[8,4]) tuple(add0, add1, add2) + } + )"; 
+ TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_TRUE(result); + HloInstruction* inst = module->entry_computation()->root_instruction(); + const HloInstruction* slice0; + const HloInstruction* slice1; + const HloInstruction* slice2; + EXPECT_THAT( + inst, GmockMatch(m::Tuple( + m::Slice(&slice0, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice1, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice2, m::Add(m::Parameter(0), m::Parameter(1)))))); + EXPECT_THAT(slice0->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice0->slice_limits(), ElementsAre(2, 9)); + EXPECT_THAT(slice0->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice1->slice_starts(), ElementsAre(3, 0)); + EXPECT_THAT(slice1->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice1->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice2->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice2->slice_limits(), ElementsAre(8, 4)); + EXPECT_THAT(slice2->slice_strides(), ElementsAre(1, 1)); +} + +TEST_F(SliceSinkerTest, SameSliceSourcesTwoPeerGroups) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[2,9] slice(f32[8,9] p0), slice={[0:2], [0:9]} + s01 = f32[6,9] slice(f32[8,9] p0), slice={[2:8], [0:9]} + s02 = f32[8,2] slice(f32[8,9] p0), slice={[0:8], [0:2]} + s03 = f32[8,7] slice(f32[8,9] p0), slice={[0:8], [2:9]} + s10 = f32[2,9] slice(f32[8,9] p1), slice={[0:2], [0:9]} + s11 = f32[6,9] slice(f32[8,9] p1), slice={[2:8], [0:9]} + s12 = f32[8,2] slice(f32[8,9] p1), slice={[0:8], [0:2]} + s13 = f32[8,7] slice(f32[8,9] p1), slice={[0:8], [2:9]} + add0 = f32[2,9] add(f32[2,9] s00, f32[2,9] s10) + add1 = f32[6,9] add(f32[6,9] s01, f32[6,9] s11) + mul0 = f32[8,2] multiply(f32[8,2] s02, f32[8,2] s12) + mul1 = f32[8,7] multiply(f32[8,7] s03, f32[8,7] s13) + ROOT tuple = (f32[2,9], f32[6,9], f32[8,2], f32[8,7]) tuple(add0, add1, mul0, mul1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_TRUE(result); + HloInstruction* inst = module->entry_computation()->root_instruction(); + const HloInstruction* slice0; + const HloInstruction* slice1; + const HloInstruction* slice2; + const HloInstruction* slice3; + EXPECT_THAT( + inst, + GmockMatch(m::Tuple( + m::Slice(&slice0, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice1, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice2, m::Multiply(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice3, m::Multiply(m::Parameter(0), m::Parameter(1)))))); + EXPECT_THAT(slice0->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice0->slice_limits(), ElementsAre(2, 9)); + EXPECT_THAT(slice0->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice1->slice_starts(), ElementsAre(2, 0)); + EXPECT_THAT(slice1->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice1->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice2->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice2->slice_limits(), ElementsAre(8, 2)); + EXPECT_THAT(slice2->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice3->slice_starts(), ElementsAre(0, 2)); + EXPECT_THAT(slice3->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice3->slice_strides(), ElementsAre(1, 1)); +} + +TEST_F(SliceSinkerTest, 
OverlappingMultipleSlices) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[2,9] slice(f32[8,9] p0), slice={[0:2], [0:9]} + s01 = f32[5,9] slice(f32[8,9] p0), slice={[3:8], [0:9]} + s02 = f32[3,9] slice(f32[8,9] p0), slice={[2:5], [0:9]} + s10 = f32[2,9] slice(f32[8,9] p1), slice={[0:2], [0:9]} + s11 = f32[5,9] slice(f32[8,9] p1), slice={[3:8], [0:9]} + s12 = f32[3,9] slice(f32[8,9] p1), slice={[2:5], [0:9]} + add0 = f32[2,9] add(f32[2,9] s00, f32[2,9] s10) + add1 = f32[5,9] add(f32[5,9] s01, f32[5,9] s11) + add2 = f32[3,9] add(f32[3,9] s02, f32[3,9] s12) + ROOT tuple = (f32[2,9], f32[5,9], f32[3,9]) tuple(add0, add1, add2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_TRUE(result); + HloInstruction* inst = module->entry_computation()->root_instruction(); + const HloInstruction* slice0; + const HloInstruction* slice1; + const HloInstruction* slice2; + EXPECT_THAT( + inst, GmockMatch(m::Tuple( + m::Slice(&slice0, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice1, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice2, m::Add(m::Parameter(0), m::Parameter(1)))))); + EXPECT_THAT(slice0->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice0->slice_limits(), ElementsAre(2, 9)); + EXPECT_THAT(slice0->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice1->slice_starts(), ElementsAre(3, 0)); + EXPECT_THAT(slice1->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice1->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice2->slice_starts(), ElementsAre(2, 0)); + EXPECT_THAT(slice2->slice_limits(), ElementsAre(5, 9)); + EXPECT_THAT(slice2->slice_strides(), ElementsAre(1, 1)); +} + +TEST_F(SliceSinkerTest, DisjointedPartialSlices) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[2,9] slice(f32[8,9] p0), slice={[0:2], [0:9]} + s01 = f32[5,9] slice(f32[8,9] p0), slice={[2:7], [0:9]} + s10 = f32[2,9] slice(f32[8,9] p1), slice={[0:2], [0:9]} + s11 = f32[5,9] slice(f32[8,9] p1), slice={[2:7], [0:9]} + add0 = f32[2,9] add(f32[2,9] s00, f32[2,9] s10) + add1 = f32[5,9] add(f32[5,9] s01, f32[5,9] s11) + ROOT tuple = (f32[2,9], f32[5,9]) tuple(add0, add1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(SliceSinkerTest, OverlappingPartialSlicesNotBeneficial) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[2,7] slice(f32[8,9] p0), slice={[0:2], [0:7]} + s01 = f32[6,7] slice(f32[8,9] p0), slice={[2:8], [0:7]} + s10 = f32[2,7] slice(f32[8,9] p1), slice={[0:2], [0:7]} + s11 = f32[6,7] slice(f32[8,9] p1), slice={[2:8], [0:7]} + add0 = f32[2,7] add(f32[2,7] s00, f32[2,7] s10) + add1 = f32[6,7] add(f32[6,7] s01, f32[6,7] s11) + ROOT tuple = (f32[2,7], f32[6,7]) tuple(add0, add1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(SliceSinkerTest, DifferentOrderingOfSliceSources) { + const char* kModuleStr = R"( + HloModule m 
+ test { + p0 = f32[8,7] parameter(0) + p1 = f32[8,7] parameter(1) + s00 = f32[2,7] slice(f32[8,7] p0), slice={[0:2], [0:7]} + s01 = f32[6,7] slice(f32[8,7] p0), slice={[2:8], [0:7]} + s10 = f32[2,7] slice(f32[8,7] p1), slice={[0:2], [0:7]} + s11 = f32[6,7] slice(f32[8,7] p1), slice={[2:8], [0:7]} + add0 = f32[2,7] add(f32[2,7] s00, f32[2,7] s10) + add1 = f32[6,7] add(f32[6,7] s11, f32[6,7] s01) + ROOT tuple = (f32[2,7], f32[6,7]) tuple(add0, add1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(SliceSinkerTest, SlicesFromDifferentIndices) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[4,9] slice(f32[8,9] p0), slice={[0:4], [0:9]} + s01 = f32[4,9] slice(f32[8,9] p0), slice={[4:8], [0:9]} + s10 = f32[4,9] slice(f32[8,9] p1), slice={[0:4], [0:9]} + s11 = f32[4,9] slice(f32[8,9] p1), slice={[4:8], [0:9]} + add0 = f32[4,9] add(f32[4,9] s01, f32[4,9] s10) + add1 = f32[4,9] add(f32[4,9] s00, f32[4,9] s11) + ROOT tuple = (f32[4,9], f32[4,9]) tuple(add0, add1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(SliceSinkerTest, DifferentOperator) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[2,9] slice(f32[8,9] p0), slice={[0:2], [0:9]} + s01 = f32[6,9] slice(f32[8,9] p0), slice={[2:8], [0:9]} + s10 = f32[2,9] slice(f32[8,9] p1), slice={[0:2], [0:9]} + s11 = f32[6,9] slice(f32[8,9] p1), slice={[2:8], [0:9]} + mul = f32[2,9] multiply(f32[2,9] s00, f32[2,9] s10) + add = f32[6,9] add(f32[6,9] s01, f32[6,9] s11) + ROOT tuple = (f32[2,9], f32[6,9]) tuple(mul, add) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(SliceSinkerTest, SlicesWithMultiUsers) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[2,9] slice(f32[8,9] p0), slice={[0:2], [0:9]} + s01 = f32[6,9] slice(f32[8,9] p0), slice={[2:8], [0:9]} + s10 = f32[2,9] slice(f32[8,9] p1), slice={[0:2], [0:9]} + s11 = f32[6,9] slice(f32[8,9] p1), slice={[2:8], [0:9]} + add0 = f32[2,9] add(f32[2,9] s00, f32[2,9] s10) + add1 = f32[6,9] add(f32[6,9] s01, f32[6,9] s11) + mul0 = f32[2,9] multiply(f32[2,9] s00, f32[2,9] s10) + mul1 = f32[6,9] multiply(f32[6,9] s01, f32[6,9] s11) + ROOT tuple = (f32[2,9], f32[6,9]) tuple(add0, add1, mul0, mul1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_TRUE(result); + HloInstruction* inst = module->entry_computation()->root_instruction(); + const HloInstruction* slice0; + const HloInstruction* slice1; + const HloInstruction* slice2; + const HloInstruction* slice3; + EXPECT_THAT( + inst, + GmockMatch(m::Tuple( + m::Slice(&slice0, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice1, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice2, 
m::Multiply(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice3, m::Multiply(m::Parameter(0), m::Parameter(1)))))); + EXPECT_THAT(slice0->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice0->slice_limits(), ElementsAre(2, 9)); + EXPECT_THAT(slice0->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice1->slice_starts(), ElementsAre(2, 0)); + EXPECT_THAT(slice1->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice1->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice2->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice2->slice_limits(), ElementsAre(2, 9)); + EXPECT_THAT(slice2->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice3->slice_starts(), ElementsAre(2, 0)); + EXPECT_THAT(slice3->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice3->slice_strides(), ElementsAre(1, 1)); +} + +TEST_F(SliceSinkerTest, NonElementWise) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8] parameter(0) + s00 = f32[2] slice(f32[8] p0), slice={[0:2]} + s01 = f32[6] slice(f32[8] p0), slice={[2:8]} + bc0 = f32[2,9] broadcast(f32[2] s00), dimensions={0} + bc1 = f32[6,9] broadcast(f32[6] s01), dimensions={0} + ROOT tuple = (f32[2,9], f32[6,9]) tuple(bc0, bc1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(SliceSinkerTest, SlicesWithNontrivialStrides) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[4,9] slice(f32[8,9] p0), slice={[0:7:2], [0:9]} + s01 = f32[4,9] slice(f32[8,9] p0), slice={[1:8:2], [0:9]} + s10 = f32[4,9] slice(f32[8,9] p1), slice={[0:7:2], [0:9]} + s11 = f32[4,9] slice(f32[8,9] p1), slice={[1:8:2], [0:9]} + add0 = f32[4,9] add(f32[4,9] s00, f32[4,9] s10) + add1 = f32[4,9] add(f32[4,9] s01, f32[4,9] s11) + ROOT tuple = (f32[4,9], f32[4,9]) tuple(add0, add1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_TRUE(result); + HloInstruction* inst = module->entry_computation()->root_instruction(); + const HloInstruction* slice0; + const HloInstruction* slice1; + EXPECT_THAT( + inst, GmockMatch(m::Tuple( + m::Slice(&slice0, m::Add(m::Parameter(0), m::Parameter(1))), + m::Slice(&slice1, m::Add(m::Parameter(0), m::Parameter(1)))))); + EXPECT_THAT(slice0->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice0->slice_limits(), ElementsAre(7, 9)); + EXPECT_THAT(slice0->slice_strides(), ElementsAre(2, 1)); + EXPECT_THAT(slice1->slice_starts(), ElementsAre(1, 0)); + EXPECT_THAT(slice1->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice1->slice_strides(), ElementsAre(2, 1)); +} + +TEST_F(SliceSinkerTest, NotAllSliceOperand) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[2,9] parameter(1) + p2 = f32[6,9] parameter(2) + s00 = f32[2,9] slice(f32[8,9] p0), slice={[0:2], [0:9]} + s01 = f32[6,9] slice(f32[8,9] p0), slice={[2:8], [0:9]} + abs0 = f32[2,9] abs(f32[2,9] p1) + abs1 = f32[6,9] abs(f32[6,9] p2) + add0 = f32[2,9] add(f32[2,9] s00, f32[2,9] abs0) + add1 = f32[6,9] add(f32[6,9] s01, f32[6,9] abs1) + ROOT tuple = (f32[2,9], f32[6,9]) tuple(add0, add1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + 
TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(SliceSinkerTest, Cascade) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[8,9] parameter(0) + p1 = f32[8,9] parameter(1) + s00 = f32[2,9] slice(f32[8,9] p0), slice={[0:2], [0:9]} + s01 = f32[6,9] slice(f32[8,9] p0), slice={[2:8], [0:9]} + s10 = f32[2,9] slice(f32[8,9] p1), slice={[0:2], [0:9]} + s11 = f32[6,9] slice(f32[8,9] p1), slice={[2:8], [0:9]} + abs0 = f32[2,9] abs(f32[2,9] s10) + abs1 = f32[6,9] abs(f32[6,9] s11) + add0 = f32[2,9] add(f32[2,9] s00, f32[2,9] abs0) + add1 = f32[6,9] add(f32[6,9] s01, f32[6,9] abs1) + ROOT tuple = (f32[2,9], f32[6,9]) tuple(add0, add1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + SliceSinker slice_sinker; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&slice_sinker, module.get())); + EXPECT_TRUE(result); + HloInstruction* inst = module->entry_computation()->root_instruction(); + const HloInstruction* slice0; + const HloInstruction* slice1; + EXPECT_THAT( + inst, + GmockMatch(m::Tuple( + m::Slice(&slice0, m::Add(m::Parameter(0), m::Abs(m::Parameter(1)))), + m::Slice(&slice1, + m::Add(m::Parameter(0), m::Abs(m::Parameter(1))))))); + EXPECT_THAT(slice0->slice_starts(), ElementsAre(0, 0)); + EXPECT_THAT(slice0->slice_limits(), ElementsAre(2, 9)); + EXPECT_THAT(slice0->slice_strides(), ElementsAre(1, 1)); + EXPECT_THAT(slice1->slice_starts(), ElementsAre(2, 0)); + EXPECT_THAT(slice1->slice_limits(), ElementsAre(8, 9)); + EXPECT_THAT(slice1->slice_strides(), ElementsAre(1, 1)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/sort_simplifier_test.cc b/tensorflow/compiler/xla/service/sort_simplifier_test.cc index 696ac1b4658..284d5095277 100644 --- a/tensorflow/compiler/xla/service/sort_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/sort_simplifier_test.cc @@ -39,7 +39,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -73,7 +73,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) { p.1.rhs = s32[] parameter(3) p.2.lhs = u32[] parameter(4) p.2.rhs = u32[] parameter(5) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -109,7 +109,7 @@ TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -134,7 +134,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.1.lhs, p.1.rhs) + ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT } ENTRY sort_computation { diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.cc b/tensorflow/compiler/xla/service/stable_sort_expander.cc index 1aa7e5fe7c0..ae4ce32569a 100644 --- a/tensorflow/compiler/xla/service/stable_sort_expander.cc +++ b/tensorflow/compiler/xla/service/stable_sort_expander.cc @@ -180,13 +180,13 @@ StatusOr StableSortExpander::ExpandInstruction( CHECK_NE(cloned_root, nullptr); Shape scalar_pred = 
ShapeUtil::MakeShape(PRED, {}); HloInstruction* same = - comparator->AddInstruction(HloInstruction::CreateBinary( - scalar_pred, HloOpcode::kEq, old_root, cloned_root)); + comparator->AddInstruction(HloInstruction::CreateCompare( + scalar_pred, old_root, cloned_root, ComparisonDirection::kEq)); HloInstruction* tie_breaker = - comparator->AddInstruction(HloInstruction::CreateBinary( - scalar_pred, HloOpcode::kLt, - comparator->parameter_instruction(2 * iota_index), - comparator->parameter_instruction(2 * iota_index + 1))); + comparator->AddInstruction(HloInstruction::CreateCompare( + scalar_pred, comparator->parameter_instruction(2 * iota_index), + comparator->parameter_instruction(2 * iota_index + 1), + ComparisonDirection::kLt)); HloInstruction* new_root = comparator->AddInstruction(HloInstruction::CreateTernary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker, diff --git a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc index a62d953e6e8..61fb4392a32 100644 --- a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc +++ b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc @@ -65,7 +65,8 @@ void CheckComputationHasTieBreaker(const HloInstruction* root, // the copied comparison function where the parameters are reversed. Lt() is // the tie breaker comparison using the Iota operand. ASSERT_EQ(root->opcode(), HloOpcode::kSelect); - ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq); + ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kCompare); + ASSERT_EQ(root->operand(0)->comparison_direction(), ComparisonDirection::kEq); // Check that the tie breaker instruction is correct. EXPECT_THAT(root->operand(1), @@ -88,7 +89,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -126,15 +127,15 @@ TEST_F(StableSortExpanderTest, lhs.unsigned = u32[] bitcast-convert(p.0.lhs) lhs.flipped = u32[] subtract(max, lhs.unsigned) lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped) - lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero) + lhs.is_negative = pred[] compare(lhs.flipped.signed, zero), direction=LT lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed) rhs.signed = s32[] bitcast-convert(p.0.rhs) rhs.unsigned = u32[] bitcast-convert(p.0.rhs) rhs.flipped = u32[] subtract(max, rhs.unsigned) rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped) - rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero) + rhs.is_negative = pred[] compare(rhs.flipped.signed, zero), direction=LT rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed) - ROOT lt = pred[] less-than(lhs.converted, rhs.converted) + ROOT lt = pred[] compare(lhs.converted, rhs.converted), direction=LT } ENTRY sort_computation { @@ -165,7 +166,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -200,7 +201,7 @@ TEST_F(StableSortExpanderTest, HonorIsStableFlag) { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, 
p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -227,7 +228,7 @@ TEST_F(StableSortExpanderTest, p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -264,7 +265,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) { p.0.rhs = f32[] parameter(1) p.1.lhs = f32[] parameter(2) p.1.rhs = f32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY sort_computation { @@ -302,7 +303,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1) { mask = s32[] constant(65535) lhs = s32[] and(p.0.lhs, mask) rhs = s32[] and(p.0.rhs, mask) - ROOT lt = pred[] less-than(lhs, rhs) + ROOT lt = pred[] compare(lhs, rhs), direction=LT } ENTRY sort_computation { @@ -332,7 +333,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) { mask = s32[] constant(65535) lhs = s32[] and(p.0.lhs, mask) rhs = s32[] and(p.0.rhs, mask) - ROOT lt = pred[] less-than(lhs, rhs) + ROOT lt = pred[] compare(lhs, rhs), direction=LT } ENTRY sort_computation { diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 15ef623cc7b..6b089f8f3f0 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -42,8 +42,11 @@ TransferManager::GetPlatformTransferManagers() { return r; } +TransferManager::TransferMetadata::~TransferMetadata() {} + StatusOr TransferManager::TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer) { + se::Stream* stream, const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata) { StatusOr ret; se::Stream* substream = stream->GetOrCreateSubStream(); @@ -54,11 +57,13 @@ StatusOr TransferManager::TransferLiteralFromDevice( tensorflow::Notification n; Status s; Literal literal(device_buffer.on_host_shape()); - TransferLiteralFromDevice(substream, device_buffer, literal, - [&](Status status) { - s = status; - n.Notify(); - }); + TransferLiteralFromDevice( + substream, device_buffer, literal, + [&](Status status) { + s = status; + n.Notify(); + }, + transfer_metadata); n.WaitForNotification(); if (!s.ok()) { return s; @@ -68,25 +73,29 @@ StatusOr TransferManager::TransferLiteralFromDevice( Status TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - const MutableBorrowingLiteral& literal) { + const MutableBorrowingLiteral& literal, + const TransferMetadata* transfer_metadata) { se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); Status ret; tensorflow::Notification n; - TransferLiteralFromDevice(substream, device_buffer, literal, - [&](Status status) { - ret = status; - n.Notify(); - }); + TransferLiteralFromDevice( + substream, device_buffer, literal, + [&](Status status) { + ret = status; + n.Notify(); + }, + transfer_metadata); n.WaitForNotification(); return ret; } Status TransferManager::TransferLiteralToDevice( se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) { + const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata) { // Implement the synchronous version by waiting on the asynchronous version. 
// Use a substream so that if we are called from a HostCallback we don't // deadlock. @@ -94,14 +103,14 @@ Status TransferManager::TransferLiteralToDevice( substream->ThenWaitFor(stream); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); - TF_RETURN_IF_ERROR( - TransferLiteralToDeviceAsync(substream, literal, device_buffer)); + TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync( + substream, literal, device_buffer, transfer_metadata)); return substream->BlockHostUntilDone(); } StatusOr TransferManager::TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source) { + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, + const TransferMetadata* transfer_metadata) { StatusOr ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't @@ -113,11 +122,13 @@ StatusOr TransferManager::TransferArrayFromDevice( tensorflow::Notification n; Literal literal(shape); Status s; - TransferArrayFromDevice(substream, shape, source, literal, - [&](Status status) { - s = status; - n.Notify(); - }); + TransferArrayFromDevice( + substream, shape, source, literal, + [&](Status status) { + s = status; + n.Notify(); + }, + transfer_metadata); n.WaitForNotification(); if (!s.ok()) { return s; @@ -127,20 +138,23 @@ StatusOr TransferManager::TransferArrayFromDevice( Status TransferManager::TransferArrayToDevice( se::Stream* stream, const LiteralSlice& literal, - const se::DeviceMemoryBase& dest) { + const se::DeviceMemoryBase& dest, + const TransferMetadata* transfer_metadata) { // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. 
se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); - TF_RETURN_IF_ERROR(TransferArrayToDeviceAsync(substream, literal, dest)); + TF_RETURN_IF_ERROR( + TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata)); return substream->BlockHostUntilDone(); } Status TransferManager::TransferArrayToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, - const se::DeviceMemoryBase& dest) { + const se::DeviceMemoryBase& dest, + const TransferMetadata* transfer_metadata) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(on_device_shape.IsArray()) << "On-device representation of " @@ -156,13 +170,16 @@ Status TransferManager::TransferArrayToDeviceAsync( stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(dest, /*index=*/{}); - return TransferLiteralToDevice(stream, literal, shaped_buffer); + return TransferLiteralToDevice(stream, literal, shaped_buffer, + transfer_metadata); } void TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, - const MutableBorrowingLiteral& literal, std::function done) { - if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { + const MutableBorrowingLiteral& literal, std::function done, + const TransferMetadata* transfer_metadata) { + if (!Shape::Equal().MinorToMajorOnlyInLayout()(HostShapeToDeviceShape(shape), + shape)) { auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); @@ -179,7 +196,7 @@ void TransferManager::TransferArrayFromDevice( stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); return TransferLiteralFromDevice(stream, shaped_buffer, literal, - std::move(done)); + std::move(done), transfer_metadata); } /* static */ void TransferManager::RegisterTransferManager( @@ -269,7 +286,7 @@ Status TransferManager::TransferBufferFromDevice( void* destination) { if (source.size() < size) { return FailedPrecondition( - "Source allocation on device not large enough for data tranfer: " + "Source allocation on device not large enough for data transfer: " "%d < %d", source.size(), size); } @@ -282,7 +299,7 @@ Status TransferManager::TransferBufferToDevice( se::DeviceMemoryBase* destination) { if (destination->size() < size) { return FailedPrecondition( - "Destination allocation on device not large enough for data tranfer: " + "Destination allocation on device not large enough for data transfer: " "%d < %d", destination->size(), size); } @@ -291,7 +308,7 @@ Status TransferManager::TransferBufferToDevice( } StatusOr TransferManager::AllocateScopedShapedBuffer( - const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal) { if (!LayoutUtil::HasLayout(on_host_shape)) { return InvalidArgument("Shape must have a layout: %s", @@ -314,10 +331,15 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( allocator->Allocate(shaped_buffer.device_ordinal(), GetByteSizeRequirement(subshape))); // Move the allocated buffer into the ScopedShapedBuffer, which owns it. 
- memory_base = memory.Forget(); + memory_base = memory.Release(); } return std::move(shaped_buffer); } +StatusOr TransferManager::ChooseCompactLayoutForShape( + const Shape& host_shape) const { + return LayoutUtil::GetWithDefaultLayout(host_shape); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 43a50487c63..f08862bff26 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -52,16 +52,38 @@ class TransferManager { return host_shape; } + // Base class for specifying platform specific transfer metadata that can be + // used to tell the underlying implementation to perform specific optimization + // to a transfer. Actual metadata passed to supported transfer methods should + // subclass this class. + class TransferMetadata { + public: + virtual ~TransferMetadata() = 0; + }; // Returns a literal containing the data held in the given ShapedBuffer // using the provided executor. This operation is performed synchronously // without waiting for any other operation on a stream to complete. // // This function should be avoided in favor of the asynchronous version below. + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. virtual StatusOr TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer); + se::Stream* stream, const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata); + StatusOr TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer) { + return TransferLiteralFromDevice(stream, device_buffer, nullptr); + } virtual Status TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - const MutableBorrowingLiteral& literal); + const MutableBorrowingLiteral& literal, + const TransferMetadata* transfer_metadata); + Status TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal) { + return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr); + } // Begins transferring a literal containing the data held in the given // ShapedBuffer using the provided executor. @@ -72,10 +94,20 @@ class TransferManager { // // device_buffer is copied by reference and must live at least until done() is // invoked. - virtual void TransferLiteralFromDevice(se::Stream* stream, - const ShapedBuffer& device_buffer, - MutableBorrowingLiteral literal, - std::function done) = 0; + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. + virtual void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, std::function done, + const TransferMetadata* transfer_metadata) = 0; + void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) { + return TransferLiteralFromDevice(stream, device_buffer, literal, done, + nullptr); + } // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -85,9 +117,18 @@ class TransferManager { // This operation is performed synchronously without waiting for any other // operation on a stream to complete. 
This function should be avoided in favor // of the asynchronous version below. - virtual Status TransferLiteralToDevice(se::Stream* stream, - const LiteralSlice& literal, - const ShapedBuffer& device_buffer); + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. + virtual Status TransferLiteralToDevice( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata); + Status TransferLiteralToDevice(se::Stream* stream, + const LiteralSlice& literal, + const ShapedBuffer& device_buffer) { + return TransferLiteralToDevice(stream, literal, device_buffer, nullptr); + } // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -102,26 +143,44 @@ class TransferManager { // immediately after this function returns, however their constituent buffers // on both host and device must remain valid until the enqueued transfer has // completed on 'stream'. + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. virtual Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) = 0; + const ShapedBuffer& device_buffer, + const TransferMetadata* transfer_metadata) = 0; + Status TransferLiteralToDeviceAsync(se::Stream* stream, + const LiteralSlice& literal, + const ShapedBuffer& device_buffer) { + return TransferLiteralToDeviceAsync(stream, literal, device_buffer, + nullptr); + } // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. - Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, - const se::DeviceMemoryBase& dest); - void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source, - const MutableBorrowingLiteral& literal, - std::function done); + // + // Optionally caller can specify platform-specific transfer metadata that + // tells the actual implementation to do something special. + Status TransferArrayToDevice( + se::Stream* stream, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest, + const TransferMetadata* transfer_metadata = nullptr); + void TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, std::function done, + const TransferMetadata* transfer_metadata = nullptr); - Status TransferArrayToDeviceAsync(se::Stream* stream, - const LiteralSlice& literal, - const se::DeviceMemoryBase& dest); - StatusOr TransferArrayFromDevice(se::Stream* stream, - const Shape& shape, - const se::DeviceMemoryBase& source); + Status TransferArrayToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest, + const TransferMetadata* transfer_metadata = nullptr); + StatusOr TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const TransferMetadata* transfer_metadata = nullptr); // Transfers the given literal into the Infeed interface of the device, // using the given executor. @@ -157,11 +216,20 @@ class TransferManager { // region for a host-to-device transfer. 
virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; + // Chooses a compact layout for 'shape', ignoring any existing layout on + // 'shape'. What "reasonable" means is left up to the backend. The + // intended use case is to choose a layout that avoids excessive padding on + // devices that have tiled memory architectures. + // The default implementation always picks a default (major-to-minor) layout. + // Fails if 'shape' cannot be represented by the device. + virtual StatusOr ChooseCompactLayoutForShape( + const Shape& host_shape) const; + // Allocates a ScopedShapedBuffer which can hold data with the given on-host // shape. The on-device shape may be different as indicated by // HostShapeToDeviceShape. StatusOr AllocateScopedShapedBuffer( - const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal); // The given ShapedBuffer holds a handle to allocated memory, but it is not diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc index b26cdc1db59..57efee700be 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -317,13 +317,9 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = - b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x, - precision); + remainder = b_row - BatchDot(a_row, transpose_a, x, false, precision); } else { - remainder = - b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a), - precision); + remainder = b_row - BatchDot(x, false, a_row, transpose_a, precision); } } @@ -332,12 +328,11 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, auto start_index = ConstantR0WithType(builder, S32, j * block_size); std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a), - remainder, precision); + x_update = + BatchDot(inv_block, transpose_a, remainder, false, precision); } else { - x_update = BatchDot(remainder, - MaybeTransposeInMinorDims(inv_block, transpose_a), - precision); + x_update = + BatchDot(remainder, false, inv_block, transpose_a, precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -403,6 +398,9 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, block_size); } + block_size = std::max( + int64{1}, std::min(block_size, ShapeUtil::GetDimension(a_shape, -1))); + if (ShapeUtil::IsZeroElementArray(b_shape)) { // The output has the same shape as 'b', and since the output has zero // elements, any such array will do. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index cc82e9bb028..638c3b4a88c 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -598,8 +599,7 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( // GetTupleElement instructions only access the top-level buffer of their // operand. return true; - } else if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + } else if (user->IsLoopFusion()) { // Find fusion parameter associated with 'operand'. auto it = absl::c_find_if( user->fused_parameters(), [&](HloInstruction* fused_param) { @@ -717,8 +717,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return false; } if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - user->fusion_kind() == HloInstruction::FusionKind::kInput) { + if (user->IsLoopFusion() || user->IsInputFusion()) { if (user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { // Loop fusion with kDynamicUpdateSlice fused root. @@ -733,7 +732,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( fusion_param); } - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + } else if (user->IsOutputFusion() && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -756,6 +755,14 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( // index 'other_add_operand_index'). return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, other_add_operand_index); + } else if (user->IsCustomFusion()) { + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0 && + absl::c_any_of( + user->fused_instructions_computation()->instructions(), + [](const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kScatter; + }); } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 6f61fc44166..61b98673cbe 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -934,8 +934,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { HloInstruction::CreateParameter(0, in_shape, "param0")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + auto result = builder.AddInstruction(HloInstruction::CreateCompare( + out_shape, param0, param1, ComparisonDirection::kEq)); BuildModuleAndRunAnalysis(builder.Build()); @@ -1185,8 +1185,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + builder.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq)); return builder.Build(); }; diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc 
index 40f268f889b..ffa89b6a797 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -80,7 +80,7 @@ static optional GetGTEOperandIndex(const HloInstruction* instr, // Tries to get the tuple index of the induction variable of a while loop. // -// Checks that the loop condition and root both plumb the induction variable +// Checks that the loop condition and body both plumb the induction variable // through the same tuple index, and that they both apply exactly one op to the // induction variable before deciding whether to do another loop iteration (in // the loop condition's case) or packing the induction variable into the result @@ -96,8 +96,7 @@ static optional GetGTEOperandIndex(const HloInstruction* instr, // root = tuple(..., inc, ...) // inc is N'th operand of tuple(). // // If so, returns N. Otherwise, returns nullopt. -static optional GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { +optional GetLoopInductionVarTupleIdx(const HloInstruction* while_op) { CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); VLOG(2) << "Finding induction variable for loop " << while_op->ToShortString(); @@ -287,7 +286,7 @@ static optional PatternMatchLoopTripCount(HloInstruction* while_op, // Handle `i = K; i < N; ++i`. if (Match(while_cond_root, m::Op() - .WithOpcode(HloOpcode::kLt) + .WithComparisonDirection(ComparisonDirection::kLt) .WithOperand(0, m::Op().Is(while_cond_indvar)))) { VLOG(2) << "Pattern-match succeeded: loop condition is i < N: " << while_cond_root->ToString(); @@ -304,7 +303,7 @@ static optional PatternMatchLoopTripCount(HloInstruction* while_op, // Handle `i = K; i <= N; ++i`. if (Match(while_cond_root, m::Op() - .WithOpcode(HloOpcode::kLe) + .WithComparisonDirection(ComparisonDirection::kLe) .WithOperand(0, m::Op().Is(while_cond_indvar)))) { VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: " << while_cond_root->ToString(); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index 9bb784a544b..10b64459974 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -35,6 +35,11 @@ absl::optional ComputeWhileLoopTripCount( // known, nullopt otherwise. absl::optional ComputeWhileLoopTripCountUpperBound( HloInstruction *while_op); + +// Returns the tuple index of the loop induction variable if there is such an +// induction variable detected. Otherwise returns nullopt. 
+absl::optional GetLoopInductionVarTupleIdx( + const HloInstruction *while_op); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc index 1da0fbeac89..5a5dc742c03 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -40,7 +40,7 @@ TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { @@ -71,7 +71,7 @@ TEST_F(WhileLoopAnalysisTest, NoUpperBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { @@ -104,7 +104,7 @@ TEST_F(WhileLoopAnalysisTest, ExactBound) { p_cond = (f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=1 const = s32[] constant(42) - ROOT result = pred[] less-than(gte, const) + ROOT result = pred[] compare(gte, const), direction=LT } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 3bcf5c38309..8ab5e433e0f 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -260,7 +260,7 @@ condition { p_cond = (f32[],f32[]) parameter(0) p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0 p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1 - ROOT result = pred[] less-than(p_cond.0, p_cond.1) + ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT } ENTRY entry { @@ -300,7 +300,7 @@ condition { p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0 p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1 p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1 - ROOT result = pred[] less-than(p_c.0, p_c.1.1) + ROOT result = pred[] compare(p_c.0, p_c.1.1), direction=LT } ENTRY entry { @@ -342,7 +342,7 @@ condition { p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - ROOT result = pred[] less-than(p_cond.0, p_cond.1) + ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT } ENTRY entry { @@ -389,10 +389,10 @@ condition { p_cond = (f32[],f32[],f32[]) parameter(0) p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - lt.0 = pred[] less-than(p_cond.0, p_cond.2) + lt.0 = pred[] compare(p_cond.0, p_cond.2), direction=LT p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 - lt.1 = pred[] less-than(p_cond.1, p_cond.2.c) + lt.1 = pred[] compare(p_cond.1, p_cond.2.c), direction=LT ROOT result = pred[] and(lt.0, lt.1) } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc 
b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 3587c016b44..f0bb646d9c0 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -556,7 +556,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0) gte = s32[] get-tuple-element(p_cond), index=3 const = s32[] constant(42) - ROOT result = pred[] equal-to(gte, const) + ROOT result = pred[] compare(gte, const), direction=EQ } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 386ffb99547..999e8a9c0ac 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -46,7 +46,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Don't try this transformation if the while loop isn't removable, since if // it succeeds ultimately we're going to have to replace the old while loop // with a new one. - if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { + if (!while_op->parent()->IsRemovable(while_op)) { VLOG(2) << "Can't remove dead parameters from non-removable while op."; return false; } @@ -300,8 +300,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { } HloInstruction* new_tuple = computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); - TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple)); - + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); return true; } @@ -454,27 +453,29 @@ static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { // // Returns true if it made a change to the graph. static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { - // Cowardly refuse to remove loops that are not removable. In practice, - // this means that we can't remove loops that contain side-effecting - // instructions or have control predecessors/successors. + // Cowardly refuse to remove loops that are not removable. In practice, this + // means that we can't remove loops that have control predecessors/successors. + if (!while_op->parent()->IsRemovable(while_op)) { + VLOG(2) << "Not attempting to remove while loop that is not removable: " + << while_op->ToShortString(); + return false; + } + + // Refuse to remove while loops with a condition that contain side-effects, + // because removing a while loop is tantamount to removing its condition. // - // This is not a fundamental limitation. The control operands can be moved - // onto the new HLOs after simplification, and any side-effecting ops inside - // the loop aren't removed, just cloned and added back to the loop. But - // moving an op out of the loop also removes implicit control dependencies - // between the op and the ops outside the loop, so we'd have to add those back - // for things like infeed/outfeed. It gets complicated. So for now we just - // avoid it. - if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { - VLOG(2) << "Not attempting to remove while loop it is not removable: " + // TODO(jlebar): This is conservative: We could instead just run the while + // condition once (trip-count == 0) or twice (trip-count == 1). 
+ if (while_op->while_condition()->HasSideEffect()) { + VLOG(2) << "Not attempting to remove while loop whose condition contains " + "side-effecting instructions: " << while_op->ToShortString(); return false; } // Remove while loops with static trip count of 0. optional trip_count = - ComputeWhileLoopTripCount(while_op, - /*max_value_returned=*/1); + ComputeWhileLoopTripCount(while_op, /*max_brute_force_iters=*/1); if (trip_count && *trip_count == 0) { // The loop never executes, so the value of the loop is the value of its // "init" operand. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index ecca76b1e86..8ec6e40044c 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -72,7 +72,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant({{LOOP_BOUND}}) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -107,7 +107,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2 - ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4) + ROOT less-than = pred[] compare(get-tuple-element.3, get-tuple-element.4), direction=LT } ENTRY SimpleLoopWithIndirectLoopBound { constant.3 = s32[] constant(42) @@ -209,11 +209,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } -// The limitation on not being able to simplify loops that contain infeeds (and -// other non-removable instructions) isn't fundamental -- it just stems from the -// fact that our infrastructure sees simplifying such a loop as tantamount to -// removing the non-removable instruction. -TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { +// We can simplify loops whose bodies contain infeed or other side-effecting +// instructions other than send/recv. +TEST_F(WhileLoopSimplifierTest, LoopWithInfeedSimplified) { auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); @@ -222,6 +220,22 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { auto token = while_body->AddInstruction(HloInstruction::CreateToken()); while_body->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple()); +} + +// We don't simplify trip-count-1 loops whose *conditions* contain infeed or +// other side-effecting instructions, because simplifying such a loop always +// removes its condition! 
+TEST_F(WhileLoopSimplifierTest, LoopWithInfeedInCondNotSimplified) { + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* while_cond = while_op->while_condition(); + auto token = while_cond->AddInstruction(HloInstruction::CreateToken()); + while_cond->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } @@ -237,7 +251,7 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { NonTupleShapedLoop.condition { loop_var = s32[] parameter(0) constant = s32[] constant(100) - ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant) + ROOT less-than = pred[] compare(s32[] loop_var, s32[] constant), direction=LT } ENTRY INonTupleShapedLoop { constant.2 = s32[] constant(42) @@ -387,7 +401,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { param0 = (s32[], s32[], s32[]) parameter(0) get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0), index=2 - ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element) + ROOT equal-to = pred[] compare(s32[] constant.2, s32[] get-tuple-element), direction=EQ } ENTRY RemoveUnusedOperands { x = s32[] parameter(0) @@ -431,6 +445,47 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } +// Check that we can remove unused loop operands even if the loop contains a +// side-effecting instruction. +TEST_F(WhileLoopSimplifierTest, + RemoveUnusedLoopOperandsDespiteSideEffectingOps) { + const string hlo_string = R"( + HloModule RemoveUnusedOperands + body { + loop_var = (s32[]) parameter(0) + gte0 = s32[] get-tuple-element(loop_var), index=0 + token0 = token[] after-all() + unused = ((s32[], pred[]), token[]) infeed(token0) + ROOT tuple = (s32[]) tuple(gte0) + } + cond { + loop_var = (s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY RemoveUnusedOperands { + x = s32[] parameter(0) + tuple.1 = (s32[]) tuple(s32[] x) + ROOT while = (s32[]) while((s32[]) tuple.1), + condition=cond, body=body + } + )"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + + // The original while instruction is still left in the module as a dead + // instruction, find a while instruction with a different name as the new + // while instruction. 
+ const auto& instrs = m->entry_computation()->instructions(); + HloInstruction* new_while_op = + *absl::c_find_if(instrs, [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(new_while_op->shape())) + << new_while_op->shape().ToString(); +} + TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { const string hlo_string = R"( HloModule BodyHasNonTupleRoot @@ -471,7 +526,7 @@ TEST_F(WhileLoopSimplifierTest, loop_var.2 = (s32[], s32[3]{0}) parameter(0) get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(44) - ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -503,7 +558,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 constant.2 = s32[] constant(47) - ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2) + ROOT less-than = pred[] compare(get-tuple-element.4, constant.2), direction=LT } ENTRY SimpleLoop { constant.3 = s32[] constant(42) @@ -679,7 +734,7 @@ const char* const kSimpleMergeInductionVariablesModule = R"( b = TYPE[] get-tuple-element(param), index=1 sum = TYPE[] power(a, b) ten = TYPE[] constant(10) - ROOT cond = pred[] less-than(sum, ten) + ROOT cond = pred[] compare(sum, ten), direction=LT } ENTRY Loop { a = TYPE[] constant(10) diff --git a/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc index 5c19cbc015d..a1e18bbdef6 100644 --- a/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_trip_count_annotator_test.cc @@ -41,7 +41,7 @@ TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) { param = (s32[]) parameter(0) i = s32[] get-tuple-element(param), index=0 trip_count = s32[] constant(10) - ROOT done = pred[] less-than(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LT } ENTRY test { @@ -77,7 +77,7 @@ TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) { param = (s32[]) parameter(0) i = s32[] get-tuple-element(param), index=0 trip_count = s32[] constant(1000000) - ROOT done = pred[] less-than(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LT } ENTRY test { @@ -113,7 +113,7 @@ TEST_F(TripCountAnnotatorTest, NonzeroStart) { param = (s32[]) parameter(0) i = s32[] get-tuple-element(param), index=0 trip_count = s32[] constant(1000000) - ROOT done = pred[] less-than(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LT } ENTRY test { @@ -149,7 +149,7 @@ TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) { param = (s32[]) parameter(0) i = s32[] get-tuple-element(param), index=0 trip_count = s32[] constant(1000000) - ROOT done = pred[] less-than-or-equal-to(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LE } ENTRY test { @@ -188,7 +188,7 @@ TEST_F(TripCountAnnotatorTest, Int64Overflow) { param = (s64[]) parameter(0) i = s64[] get-tuple-element(param), index=0 trip_count = s64[] constant(9223372036854775807) // 2^63-1 - ROOT done = pred[] less-than-or-equal-to(i, trip_count) + ROOT done = pred[] compare(i, trip_count), direction=LE } ENTRY test { diff --git 
a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index d77386497a1..b6f65c763ea 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -166,7 +166,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape, TF_ASSIGN_OR_RETURN( HloInstruction * compare, - MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant)); + MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant)); cond_computation->set_root_instruction(compare); return std::move(cond_computation); } diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 94854047e53..27d24514f8f 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -141,6 +141,11 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { } } + if (!ShapeUtil::SameDimensions(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; + return false; + } + if (!ignore_layout_) { if (lhs.layout().format() != rhs.layout().format()) { VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; @@ -161,11 +166,6 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { } } - if (!ShapeUtil::SameDimensions(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; - return false; - } - if (!ignore_dynamic_dimension_) { for (int i = 0; i < lhs.rank(); ++i) { if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 78cea83c6d7..b6e1cb621d7 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -56,7 +56,7 @@ class Shape { bool IsArray() const { return primitive_util::IsArrayType(element_type()); } bool IsTuple() const { return element_type() == TUPLE; } bool IsToken() const { return element_type() == TOKEN; } - bool IsOpaque() const { return element_type() == OPAQUE; } + bool IsOpaque() const { return element_type() == OPAQUE_TYPE; } // Returns true if no array dimension in the shape is dynamically sized. Tuple // shapes are traversed recursively. @@ -169,6 +169,11 @@ class Shape { ignore_element_size_in_layout_ = true; return *this; } + Equal& MinorToMajorOnlyInLayout() { + ignore_tiles_in_layout_ = true; + ignore_element_size_in_layout_ = true; + return *this; + } Equal& IgnoreElementType() { ignore_element_type_ = true; return *this; @@ -195,6 +200,12 @@ class Shape { bool operator==(const Shape& other) const { return Equal()(*this, other); } bool operator!=(const Shape& other) const { return !(*this == other); } + template + friend H AbslHashValue(H h, const Shape& s) { + return H::combine(std::move(h), s.element_type_, s.dimensions_, + s.dynamic_dimensions_, s.tuple_shapes_, s.layout_); + } + private: // The element type of this shape (tuple, array, etc). 
PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index a000886d60d..44ed3181162 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -46,8 +46,13 @@ void ShapeLayout::SetToDefaultLayout() { LayoutUtil::SetToDefaultLayout(&shape_); } -bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const { - return ShapeUtil::Equal(shape, shape_); +bool ShapeLayout::MatchesLayoutInShape(const Shape& shape, + bool minor_to_major_only) const { + auto equal = Shape::Equal(); + if (minor_to_major_only) { + equal.MinorToMajorOnlyInLayout(); + } + return equal(shape, shape_); } const Layout& ShapeLayout::layout() const { diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 214cf988549..b4982f1d8e4 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -45,7 +45,8 @@ class ShapeLayout { // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible // with the ShapeLayout's shape, then false is returned. - bool MatchesLayoutInShape(const Shape& shape) const; + bool MatchesLayoutInShape(const Shape& shape, + bool minor_to_major_only = false) const; // Copies the layout from the given shape into this ShapeLayout. 'other_shape' // must be compatible with the ShapeLayout's shape. diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc index dbdafcc0a1f..aa6c7d10989 100644 --- a/tensorflow/compiler/xla/shape_test.cc +++ b/tensorflow/compiler/xla/shape_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape.h" #include + +#include "absl/hash/hash_testing.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -210,5 +212,11 @@ TEST_F(ShapeTest, ProgramShapeToString) { prog.ToString()); } +TEST_F(ShapeTest, SupportsAbslHash) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly( + {opaque_, token_, scalar_, scalar_with_tile_, matrix_, matrix2_, tuple_, + nested_tuple_, dyanmic_matrix_})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 089120179e2..75eb34f294b 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -122,15 +122,16 @@ class ShapeTree { // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } - // Replaces *only* the underlying shape of this ShapeTree. The caller must own - // the Shape object and hence shape_storage_ is not updated. - // - // Only safe to use this if the ShapeTree was constructed with 'explicit - // ShapeTree(const Shape* shape)' or is moved from one such ShapeTree. The - // caller must ensure that the input shape is consistent with the underlying - // tree. + // A ShapeTree object can own the underlying Shape pointer (via the + // shape_storage_ member), or can point to a Shape object owned by the caller. + // This API replaces the underlying Shape object to the one supplied by the + // caller, whom must ensure the object remain valid for the whole lifetime of + // this ShapeTree object, and also that the Shape is consistent with it. 
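+ // Illustrative usage (a minimal sketch; the value type int64 is arbitrary):
+ //
+ //   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});  // owned by the caller
+ //   ShapeTree<int64> tree(shape);    // the tree stores its own copy
+ //   tree.replace_shape_ptr(&shape);  // the tree now aliases the caller's
+ //                                    // Shape; the two must compare equal,
+ //                                    // and `shape` must outlive `tree`.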
void replace_shape_ptr(const Shape* shape) { - CHECK(shape_storage_.get() == nullptr); + if (shape_storage_ != nullptr) { + CHECK_EQ(*shape, *shape_storage_); + shape_storage_ = nullptr; + } shape_ = shape; } @@ -290,6 +291,8 @@ class ShapeTree { const ShapeIndex& source_base_index, const ShapeIndex& target_base_index); + StatusOr> SubShapeTree(const ShapeIndex& index) const; + bool operator==(const ShapeTree& other) const; bool operator!=(const ShapeTree& other) const { return !(*this == other); } @@ -664,6 +667,16 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, }); } +template +StatusOr> ShapeTree::SubShapeTree( + const ShapeIndex& index) const { + TF_ASSIGN_OR_RETURN(const Shape* sub_shape, + ShapeUtil::TryGetSubshape(shape(), index)); + ShapeTree sub_shape_tree(*sub_shape); + sub_shape_tree.CopySubtreeFrom(*this, index, {}); + return std::move(sub_shape_tree); +} + template bool ShapeTree::operator==(const ShapeTree& other) const { bool equal = true; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index acaa9cae7c2..2eb9d278bd9 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -96,12 +96,17 @@ StatusOr MakeShapeWithLayoutInternal( return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); } - if (element_type == OPAQUE || element_type == TUPLE) { + if (element_type == OPAQUE_TYPE || element_type == TUPLE) { return InvalidArgument("Unsupported element type: %s", PrimitiveType_Name(element_type)); } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); + if (element_size_in_bits == + ShapeUtil::ByteSizeOfPrimitiveType(element_type) * 8) { + // Only set element_size_in_bits if it's different from the default value. + element_size_in_bits = 0; + } *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits); if (!shape.has_layout()) { @@ -219,7 +224,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( for (int i = 0; i < shape.dimensions_size(); ++i) { dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - return MakeShapeWithDescendingLayout(shape.element_type(), dims); + Shape new_shape = MakeShapeWithDescendingLayout(shape.element_type(), dims); + // Since the physical layout is kept the same, the tiles and element size are + // the same also. + *new_shape.mutable_layout()->mutable_tiles() = shape.layout().tiles(); + new_shape.mutable_layout()->set_element_size_in_bits( + shape.layout().element_size_in_bits()); + return new_shape; } /* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type, @@ -247,7 +258,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::MakeOpaqueShape() { Shape result; - result.set_element_type(OPAQUE); + result.set_element_type(OPAQUE_TYPE); TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); return result; } @@ -308,7 +319,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case C64: case C128: case TUPLE: - case OPAQUE: + case OPAQUE_TYPE: case TOKEN: return false; @@ -559,7 +570,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( // Tokens require no space. 
return 0; case TUPLE: - case OPAQUE: + case OPAQUE_TYPE: LOG(FATAL) << PrimitiveType_Name(primitive_type) << " primitive type has no definitive size"; default: @@ -580,7 +591,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return byte_size; } else if (shape.element_type() == TOKEN) { return 0; - } else if (shape.element_type() == OPAQUE) { + } else if (shape.element_type() == OPAQUE_TYPE) { CHECK_GT(pointer_size, 0); return pointer_size; } @@ -642,7 +653,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } // Tokens and opaques can should not have layout or dimensions. - if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) { + if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE_TYPE) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", @@ -1067,6 +1078,32 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return common_factors; } +/* static */ absl::optional> +ShapeUtil::ReshapeLeavesDimensionsUnmodified( + const Shape& from_shape, const Shape& to_shape, + absl::Span input_dim_indices) { + CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); + + std::vector output_dim_indices; + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape); + size_t i = 0; // index to unmodified_dims + for (int64 input_dim_index : input_dim_indices) { + // Search unmodified_dims for input_dim_index. We can search from the last + // matching position because input_dim_indices is guaranteed to be sorted. + while (i < unmodified_dims.size() && + unmodified_dims[i].first < input_dim_index) { + ++i; + } + if (i >= unmodified_dims.size() || + unmodified_dims[i].first != input_dim_index) { + return absl::nullopt; + } + output_dim_indices.push_back(unmodified_dims[i].second); + } + return output_dim_indices; +} + /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, absl::Span dimension_mapping) { @@ -1294,8 +1331,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } } Shape output_shape_with_layout = output_shape; - *output_shape_with_layout.mutable_layout()->mutable_minor_to_major() = - layout; + *output_shape_with_layout.mutable_layout() = Layout{layout}; return output_shape_with_layout; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 7f610a6085d..0065a3b8784 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -578,6 +578,20 @@ class ShapeUtil { static std::vector> DimensionsUnmodifiedByReshape( const Shape& input_shape, const Shape& output_shape); + // Return whether the given reshape instruction leaves the dimensions at the + // given input indices unmodified, and returns their output indices. + // + // Example: + // input_dim_indices = {2, 3} + // input shape = T[a, b, x, y, cd] + // output shape = T[ab, x, 1, y, c, d] + // return value = {1, 3} + // + // Precondition: input_dim_indices is sorted. + static absl::optional> ReshapeLeavesDimensionsUnmodified( + const Shape& from_shape, const Shape& to_shape, + absl::Span input_dim_indices); + // Returns whether a transpose from input_shape to output_shape with dimension // mapping "dimension_mapping" produces a result which is bit-wise identical // to its input and thus may be replaced with a bitcast. @@ -624,7 +638,7 @@ class ShapeUtil { // continue, or false otherwise. 
// // visitor_function must be a callable of type - // StatusOr(Span) or compatible. + // StatusOr(absl::Span) or compatible. template static Status ForEachIndexWithStatus(const Shape& shape, absl::Span base, diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 3ede5e6e38a..a2b76fafb12 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -28,9 +28,6 @@ limitations under the License. // This module contains a minimal subset of gmock functionality just // sufficient to execute the currently existing tests. -namespace util { -class Status; -} // namespace util namespace xla { template diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 15d05f6a6e6..cff87c59938 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -55,6 +55,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -259,7 +260,6 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", @@ -268,6 +268,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory_allocator", "//third_party/eigen3", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", @@ -299,6 +300,7 @@ xla_test( srcs = ["conv_depthwise_test.cc"], shard_count = 50, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -324,6 +326,7 @@ xla_test( ], shard_count = 6, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -350,6 +353,7 @@ xla_test( ], shard_count = 50, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -368,6 +372,7 @@ xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -405,6 +410,7 @@ xla_test( name = "while_test", srcs = ["while_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -432,6 +438,7 @@ xla_test( "interpreter", ], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -455,6 +462,7 @@ xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", @@ -469,6 +477,7 @@ xla_test( name = "map_test", srcs = ["map_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", 
"//tensorflow/compiler/xla:shape_util", @@ -500,6 +509,7 @@ xla_test( "optonly", ], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -536,6 +546,7 @@ xla_test( name = "select_test", srcs = ["select_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", @@ -554,6 +565,7 @@ xla_test( srcs = ["conditional_test.cc"], shard_count = 2, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", @@ -571,6 +583,7 @@ xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", @@ -590,6 +603,7 @@ xla_test( srcs = ["scalar_computations_test.cc"], shard_count = 32, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", @@ -614,6 +628,7 @@ xla_test( name = "deallocation_test", srcs = ["deallocation_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -631,6 +646,7 @@ xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -653,6 +669,7 @@ xla_test( srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", @@ -687,6 +704,7 @@ xla_test( deps = [ ":client_library_test_base", ":literal_test_util", + ":test_macros_header", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", @@ -700,6 +718,7 @@ xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -728,6 +747,39 @@ xla_test( "optonly", ], deps = [ + ":test_macros_header", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + +# Run dot tests with auto-tuning disabled. This just does a basic sanity check +# that enabling xla_gpu_disable_autotune does not break simple graphs. 
+xla_test( + name = "dot_operation_test_autotune_disabled", + srcs = ["dot_operation_test.cc"], + args = ["--xla_gpu_disable_autotune"], + backends = ["gpu"], + shard_count = 20, + tags = [ + "optonly", + ], + deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", @@ -753,6 +805,7 @@ xla_test( deps = [ ":client_library_test_base", ":hlo_test_base", + ":test_macros_header", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -768,6 +821,7 @@ xla_test( deps = [ ":client_library_test_base", ":hlo_test_base", + ":test_macros_header", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", @@ -787,6 +841,7 @@ xla_test( shard_count = 20, tags = ["optonly"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", @@ -810,6 +865,7 @@ xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -829,6 +885,7 @@ xla_test( name = "constants_test", srcs = ["constants_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", @@ -849,6 +906,7 @@ xla_test( ) CONVOLUTION_TEST_DEPS = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -881,6 +939,23 @@ xla_test( ], ) +# Run convolution tests with auto-tuning disabled. This just does a basic +# sanity check that enabling xla_gpu_disable_autotune does not break simple +# graphs. 
+xla_test( + name = "convolution_test_autotune_disabled", + timeout = "long", + srcs = ["convolution_test.cc"], + args = ["--xla_gpu_disable_autotune"], + backends = ["gpu"], + shard_count = 40, + tags = ["optonly"], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + xla_test( name = "convolution_test_gpu_alternative_layout", timeout = "long", @@ -904,6 +979,7 @@ xla_test( }, shard_count = 30, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -927,6 +1003,7 @@ xla_test( srcs = ["convolution_dimension_numbers_test.cc"], shard_count = 20, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", @@ -956,6 +1033,7 @@ xla_test( ], shard_count = 40, deps = [ + ":test_macros_header", ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -990,6 +1068,7 @@ xla_test( srcs = ["bfloat16_test.cc"], shard_count = 40, deps = [ + ":test_macros_header", ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1025,6 +1104,7 @@ xla_test( "gpu", ], deps = [ + ":test_macros_header", ":test_utils", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", @@ -1046,6 +1126,7 @@ xla_test( srcs = ["slice_test.cc"], shard_count = 40, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", @@ -1066,6 +1147,7 @@ xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla/client:local_client", @@ -1083,6 +1165,7 @@ xla_test( timeout = "moderate", srcs = ["dynamic_ops_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:test_helpers", @@ -1090,7 +1173,6 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", @@ -1101,6 +1183,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -1108,6 +1191,7 @@ xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -1131,6 +1215,7 @@ xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", @@ -1153,6 +1238,7 @@ xla_test( "optonly", ], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", @@ -1168,6 +1254,7 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", 
"//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1219,7 +1306,7 @@ xla_test( "optonly", ], xla_test_library_deps = [":reduce_window_test_library"], - deps = [], + deps = [":test_macros_header"], ) xla_test( @@ -1230,6 +1317,7 @@ xla_test( "optonly", ], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", @@ -1254,6 +1342,7 @@ xla_test( srcs = ["copy_test.cc"], deps = [ ":client_library_test_base", + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", @@ -1273,6 +1362,7 @@ xla_test( name = "reduce_hlo_test", srcs = ["reduce_hlo_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1286,6 +1376,7 @@ xla_test( name = "token_hlo_test", srcs = ["token_hlo_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1300,6 +1391,7 @@ xla_test( name = "call_test", srcs = ["call_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1320,14 +1412,15 @@ xla_test( srcs = ["custom_call_test.cc"], backends = ["cpu"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1342,6 +1435,7 @@ xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", @@ -1359,6 +1453,7 @@ xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -1378,6 +1473,7 @@ xla_test( name = "pad_test", srcs = ["pad_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", @@ -1413,6 +1509,7 @@ xla_test( name = "log_test", srcs = ["log_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", @@ -1427,6 +1524,7 @@ xla_test( name = "matrix_ops_simple_test", srcs = ["matrix_ops_simple_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", 
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", @@ -1453,6 +1551,7 @@ xla_test( name = "prng_test", srcs = ["prng_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1473,6 +1572,7 @@ xla_test( srcs = ["reshape_test.cc"], shard_count = 30, deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", @@ -1498,6 +1598,7 @@ xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla/client:local_client", @@ -1516,6 +1617,7 @@ xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1539,6 +1641,7 @@ xla_test( name = "concat_test", srcs = ["concat_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", @@ -1559,6 +1662,7 @@ xla_test( name = "convert_test", srcs = ["convert_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", @@ -1582,6 +1686,7 @@ xla_test( "interpreter", ], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1607,12 +1712,17 @@ xla_test( srcs = ["multi_device_all_reduce_test.cc"], backends = ["gpu"], tags = [ + # This test is tagged "manual" because it requires multiple GPUs, and + # Forge only supports single-GPU tests. Guitar skips "manual" tests + # unless they're also tagged "guitar". + "noguitar", # TODO(b/131524578): Re-enable this. 
"manual", "multi_gpu", "no_oss", "notap", ], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1639,6 +1749,7 @@ xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", @@ -1656,6 +1767,7 @@ xla_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1678,6 +1790,7 @@ xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1694,6 +1807,7 @@ xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1717,6 +1831,7 @@ xla_test( name = "client_test", srcs = ["client_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1744,6 +1859,7 @@ xla_test( ], deps = [ ":client_library_test_base", + ":test_macros_header", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", @@ -1762,6 +1878,7 @@ xla_test( ], deps = [ ":client_library_test_base", + ":test_macros_header", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", @@ -1774,6 +1891,7 @@ xla_test( name = "replay_test", srcs = ["replay_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", @@ -1796,6 +1914,7 @@ xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -1823,7 +1942,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", - "//tensorflow/compiler/xla/service/gpu:gpu_compiler", + "//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/stream_executor", @@ -1836,6 +1955,7 @@ xla_test( name = "round_trip_packed_literal_test", srcs = ["round_trip_packed_literal_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:packed_literal_reader", "//tensorflow/compiler/xla:shape_util", @@ -1857,6 +1977,7 @@ xla_test( name = "fusion_test", srcs = ["fusion_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1884,6 +2005,7 @@ xla_test( srcs = ["multioutput_fusion_test.cc"], backends = ["gpu"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", 
"//tensorflow/compiler/xla:xla_data_proto", @@ -1922,6 +2044,7 @@ xla_test( name = "local_client_allocation_test", srcs = ["local_client_allocation_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", @@ -1945,6 +2068,7 @@ xla_test( shard_count = 30, tags = ["optonly"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1955,7 +2079,6 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", @@ -1967,6 +2090,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -1978,6 +2102,7 @@ xla_test( "interpreter", ], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla/tests:local_client_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -2004,6 +2129,7 @@ xla_test( name = "round_trip_transfer_test", srcs = ["round_trip_transfer_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2024,6 +2150,7 @@ xla_test( name = "reshape_motion_test", srcs = ["reshape_motion_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -2073,19 +2200,20 @@ xla_test( deps = [ ":literal_test_util", ":local_client_test_base", + ":test_macros_header", ":xla_internal_test_main", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -2100,6 +2228,7 @@ xla_test( ], deps = [ ":hlo_test_base", + ":test_macros_header", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -2136,6 +2265,7 @@ xla_test( backends = ["cpu"], deps = [ ":local_client_test_base", + ":test_macros_header", ":test_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", @@ -2158,6 +2288,7 @@ xla_test( ], deps = [ ":client_library_test_base", + ":test_macros_header", ":xla_internal_test_main", "//tensorflow/core:lib", ], @@ -2190,6 +2321,7 @@ xla_test( ], deps = [ ":hlo_test_base", + ":test_macros_header", ":xla_internal_test_main", # fixdeps: keep "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:test", @@ -2204,6 +2336,7 @@ xla_test( "noasan", # sometimes times out, http://b/78650012 ], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", 
"//tensorflow/compiler/xla:statusor", @@ -2224,6 +2357,7 @@ xla_test( srcs = ["cholesky_test.cc"], tags = ["optonly"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index acdd3c9da92..a5e27cd67a7 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -349,9 +349,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { error_spec_); } -// TODO(b/119692968): This test runs OOM on the GPU and CPU backend. -XLA_TEST_F(ArrayElementwiseOpTest, - DISABLED_ON_GPU(DISABLED_ON_CPU(DeeplyNestedAddWithSlices))) { +XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) { XlaBuilder builder(TestName()); std::vector values(30, 0.0); auto a_literal = LiteralUtil::CreateR1(values); @@ -1067,6 +1065,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {0, 1, -15, 341}); + PopulationCount(a); + ComputeAndCompareR1(&builder, {0, 1, 29, 5}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{0, 1}, {-15, 341}}); + PopulationCount(a); + Array2D expected_array({{0, 1}, {29, 5}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{0, -1}, {INT64_MAX, INT64_MAX - 1}}); + PopulationCount(a); + Array2D expected_array({{0, 64}, {63, 62}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { XlaBuilder builder(TestName()); auto a = ConstantR1( diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index c14d279ac56..48719c6c47c 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -128,7 +128,7 @@ def xla_test( srcs = srcs, copts = copts, testonly = True, - deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], + deps = deps, ) for backend in filter_backends(backends): @@ -265,6 +265,8 @@ def generate_backend_test_macros(backends = []): "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, ], deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index cfee9c0f8a4..0ab765aefa0 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -1801,7 +1801,8 @@ INSTANTIATE_TEST_CASE_P( Convolve1DTestParam{24, 1, 1, 10, 5}, Convolve1DTestParam{160, 1, 1, 10, 1}, Convolve1DTestParam{255, 1, 1, 3, 1}, - Convolve1DTestParam{130, 1, 1, 1, 3}, + Convolve1DTestParam{130, 1, 1, 1, 2}, + Convolve1DTestParam{136, 1, 1, 1, 2}, Convolve1DTestParam{64, 1, 1, 1, 1}, Convolve1DTestParam{128, 1, 1, 1, 1}, Convolve1DTestParam{139, 1, 1, 128, 1}, diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 
4687ed61a7d..63c3b4b5b02 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -64,10 +64,10 @@ void F32TupleSwap(float** out, float** in) { } // namespace -REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); -REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); -REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); -REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); namespace xla { namespace { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 414d0b14a6b..59c3d4f5c7e 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -395,6 +396,8 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { ParametricDotTestWithoutLayoutAssignment() { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "layout-assignment"); + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "tiling-assignment"); // Disable algebraic simplification because the pass may replace a dot // instruction with a layout-changing multiplication instruction. execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( @@ -1198,12 +1201,52 @@ std::vector GetEinsumTestCases() { p{v{16, 34}, v{16, 34}, "ab,ab->ab"}, p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"}, p{v{5, 19}, v{}, "ab,->ab"}, + p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf->bhqk"}, }; return test_cases; } -INSTANTIATE_TEST_CASE_P(Einsum, EinsumTest, - ::testing::ValuesIn(GetEinsumTestCases())); +INSTANTIATE_TEST_SUITE_P(Einsum, EinsumTest, + ::testing::ValuesIn(GetEinsumTestCases())); + +using BatchDotParamType = + std::tuple, std::vector, std::vector>; +class BatchDotTest : public DotOperationTest, + public ::testing::WithParamInterface {}; +XLA_TEST_P(BatchDotTest, BroadcastingBatchDotTest) { + XlaBuilder builder(TestName()); + auto x = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam()))) + .ValueOrDie(), + &builder); + auto y = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam()))) + .ValueOrDie(), + &builder); + auto batch_dot = BatchDot(x, y); + auto output_shape = builder.GetShape(batch_dot).ValueOrDie(); + EXPECT_EQ(output_shape.dimensions(), std::get<2>(GetParam())); + ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3}); +} + +std::vector GetBatchDotTestCases() { + using v = std::vector; + using p = BatchDotParamType; + std::vector
test_cases = { + p{v{5, 6}, v{6, 7}, v{5, 7}}, + p{v{5, 6, 11}, v{5, 11, 7}, v{5, 6, 7}}, + p{v{5, 6, 11}, v{11, 7}, v{5, 6, 7}}, + p{v{5, 6, 11}, v{1, 11, 7}, v{5, 6, 7}}, + p{v{6, 11}, v{5, 11, 7}, v{5, 6, 7}}, + p{v{1, 6, 11}, v{5, 11, 7}, v{5, 6, 7}}, + p{v{8, 1, 2, 3}, v{8, 3, 4}, v{8, 8, 2, 4}}, + p{v{8, 8, 2, 3}, v{8, 1, 3, 2}, v{8, 8, 2, 2}}, + }; + return test_cases; +} + +INSTANTIATE_TEST_SUITE_P(BatchDot, BatchDotTest, + ::testing::ValuesIn(GetBatchDotTestCases())); class DotOperationTextTest : public HloTestBase {}; @@ -1361,5 +1404,183 @@ ENTRY MatrixVectorComplex { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); } +XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) { + Array3D input_arr(2, 3, 2); + Array2D const_arr(2, 6); + input_arr.FillIota(0); + const_arr.FillIota(0); + + XlaBuilder builder(TestName()); + auto t0 = + AddParam(LiteralUtil::CreateR3FromArray3D(input_arr), &builder); + auto t1 = Transpose(t0, {1, 0, 2}); + auto rhs = Reshape(t1, {6, 2}); + auto lhs = ConstantR2FromArray2D(&builder, const_arr); + Dot(lhs, rhs); + + ComputeAndCompare(&builder, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_LR) { + Array3D input_arr(2, 3, 2); + Array2D const_arr(2, 6); + input_arr.FillIota(0); + const_arr.FillIota(0); + + XlaBuilder builder(TestName()); + auto t0 = + AddParam(LiteralUtil::CreateR3FromArray3D(input_arr), &builder); + auto t1 = Transpose(t0, {1, 0, 2}); + auto lhs = Reshape(t1, {6, 2}); + auto rhs = ConstantR2FromArray2D(&builder, const_arr); + + DotDimensionNumbers dims; + dims.add_lhs_contracting_dimensions(0); + dims.add_rhs_contracting_dimensions(1); + DotGeneral(lhs, rhs, dims); + + ComputeAndCompare(&builder, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_RL) { + Array4D input_arr(2, 2, 3, 4); + Array2D const_arr(24, 2); + input_arr.FillIota(0); + const_arr.FillIota(0); + + XlaBuilder builder(TestName()); + auto t0 = + AddParam(LiteralUtil::CreateR4FromArray4D(input_arr), &builder); + auto t1 = Transpose(t0, {0, 2, 3, 1}); + auto lhs = Reshape(t1, {2, 24}); + auto rhs = ConstantR2FromArray2D(&builder, const_arr); + Dot(lhs, rhs); + + ComputeAndCompare(&builder, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_MM) { + Array3D input_arr(2, 6, 2); + Array3D const_arr(2, 6, 3); + input_arr.FillIota(0); + const_arr.FillIota(0); + + XlaBuilder builder(TestName()); + auto t0 = + AddParam(LiteralUtil::CreateR3FromArray3D(input_arr), &builder); + auto t1 = Reshape(t0, {2, 2, 3, 2}); + auto t2 = Transpose(t1, {0, 2, 1, 3}); + auto lhs = Reshape(t2, {2, 6, 2}); + auto rhs = ConstantR3FromArray3D(&builder, const_arr); + + DotDimensionNumbers dims; + dims.add_lhs_contracting_dimensions(1); + dims.add_rhs_contracting_dimensions(1); + dims.add_lhs_batch_dimensions(0); + dims.add_rhs_batch_dimensions(0); + DotGeneral(lhs, rhs, dims); + + ComputeAndCompare(&builder, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) { + Array4D input_arr(2, 2, 3, 5); + Array2D const_arr(2, 30); + input_arr.FillIota(0); + const_arr.FillIota(0); + + XlaBuilder builder(TestName()); + auto t0 = + AddParam(LiteralUtil::CreateR4FromArray4D(input_arr), &builder); + auto t1 = Transpose(t0, {0, 2, 1, 3}); + auto t2 = Reshape(t1, {2, 6, 5}); + auto t3 = Transpose(t2, {0, 2, 1}); + auto lhs = Reshape(t3, {2, 30}); + auto rhs = ConstantR2FromArray2D(&builder, const_arr); + + DotDimensionNumbers dims; + 
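+ // Contract dimension 1 of the reshaped {2, 30} lhs against dimension 1 of
+ // the {2, 30} constant rhs, yielding a {2, 2} result.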
dims.add_lhs_contracting_dimensions(1); + dims.add_rhs_contracting_dimensions(1); + DotGeneral(lhs, rhs, dims); + + // Constant folding are disabled by default in unit tests. algsimp + // optimization can be applied multiple times if we fold the transpose + // and reshape that are moved to the constant side of the dot. + mutable_debug_options()->clear_xla_disable_hlo_passes(); + ComputeAndCompare(&builder, {}, error_spec_); +} + +// This benchmark is to show the performance impact of the following +// transformation: +// dot(reshape(transpose(A)), Const) ==> +// dot(reshape(A), reshape(transpose(reshape(Const)))), +// and then fold the reshape and transpose on the Const side. +// We can compare performance with and without algsimp pass to see the impact. +void DOT_ReorderContracting(int num_iters) { + tensorflow::testing::StopTiming(); + + se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); + auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); + se::StreamExecutorMemoryAllocator allocator(platform, executors); + + xla::LocalClientOptions client_options; + client_options.set_platform(platform); + auto client = + ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); + + int device_ordinal = client->default_device_ordinal(); + + const int64 d0 = 128; + const int64 d1 = 128; + const int64 d2 = 128; + const int64 d3 = 128; + + Array3D input_arr(d0, d1, d2); + Array2D const_arr(d1 * d2, d3); + input_arr.FillIota(0); + const_arr.FillIota(0); + XlaBuilder builder("ReorderContracting"); + auto t0 = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {d0, d1, d2}), "param0"); + auto t1 = Transpose(t0, {0, 2, 1}); + auto lhs = Reshape(t1, {d0, d2 * d1}); + auto rhs = ConstantR2FromArray2D(&builder, const_arr); + Dot(lhs, rhs); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto input_literal = LiteralUtil::CreateR3FromArray3D(input_arr); + ScopedShapedBuffer buffer0 = + client->LiteralToShapedBuffer(input_literal, device_ordinal) + .ConsumeValueOrDie(); + + std::unique_ptr executable = + client + ->Compile(computation, {&buffer0.on_host_shape()}, + ExecutableBuildOptions()) + .ConsumeValueOrDie(); + + se::Stream stream(executors[device_ordinal]); + stream.Init(); + + ExecutableRunOptions options; + options.set_allocator(&allocator); + + const int kWarmups = 2; + for (int i = 0; i < kWarmups; ++i) { + ASSERT_IS_OK(executable->Run({&buffer0}, options)); + } + + const int64 total_bytes = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3; + tensorflow::testing::BytesProcessed(static_cast(num_iters) * + total_bytes * sizeof(float)); + tensorflow::testing::UseRealTime(); + tensorflow::testing::StartTiming(); + for (int i = 0; i < num_iters; ++i) { + ASSERT_IS_OK(executable->Run({&buffer0}, options)); + } +} + +BENCHMARK(DOT_ReorderContracting); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 82e2db36143..1ea72af5f5f 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -34,6 +33,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace { @@ -736,7 +736,7 @@ void BM_DynamicSlice(int num_iters) { se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); - StreamExecutorMemoryAllocator allocator(platform, executors); + se::StreamExecutorMemoryAllocator allocator(platform, executors); LocalClient* client = ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); auto* transfer_manager = diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc index 58bb9a217b8..7df01e04c6d 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc @@ -215,7 +215,7 @@ class ExhaustiveOpTest RunImpl(enqueue_op, evaluate_op); break; case BF16: - SetDefaultErrSpec(0.001, 0.01); + SetDefaultErrSpec(0.002, 0.02); RunImpl(enqueue_op, evaluate_op); break; default: @@ -245,14 +245,6 @@ class ExhaustiveOpTest int64 begin, end; std::tie(begin, end) = test_range; - if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) { - LOG(INFO) << absl::StreamFormat( - "Skipping this shard, as the range under test, [%d, %d), falls " - "entirely within the known-incorrect range [%d, %d).", - begin, end, known_incorrect_begin_, known_incorrect_end_); - return; - } - LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; int64 input_size = end - begin; @@ -262,8 +254,7 @@ class ExhaustiveOpTest IntegralT input_val = i + begin; // If the operation is known to be buggy on a specific input clamp that // input to 0 under the assumption that the op is at least correct on 0. - if (input_val >= known_incorrect_begin_ && - input_val < known_incorrect_end_) { + if (known_incorrect_fn_ && known_incorrect_fn_(input_val)) { input_arr[i] = T{0}; } else { input_arr[i] = absl::bit_cast(input_val); @@ -347,6 +338,10 @@ class ExhaustiveOpTest // denormals. const T expected_at_pos_zero = static_cast(evaluate_op(0)); const T expected_at_neg_zero = static_cast(evaluate_op(-0.0)); + const T expected_at_pos_min_normal_float = + static_cast(evaluate_op(std::numeric_limits::min())); + const T expected_at_neg_min_normal_float = + static_cast(evaluate_op(-std::numeric_limits::min())); for (int64 i = 0; i < input_arr.size(); ++i) { T input = input_arr[i]; float input_f32 = static_cast(input); @@ -378,13 +373,23 @@ class ExhaustiveOpTest // - evaluate_op(input) // - evaluate_op(+/-0), where the sign of 0 equal to the sign of // `input`, + // - evaluate_op(+/-min_normal_float), where the sign of + // min_normal_float matches `input`. // - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of // 0 is the opposite of `input`. + // + // (In particular, the XLA:CPU implementation of log flushes positive + // denormals to min-normal-float. 
This seems kind of reasonable if our + // goal is to avoid infinities because they cause nans?) T sign_preserving_ftz_expected = std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero; + T flush_to_normal_expected = std::signbit(input_f32) + ? expected_at_neg_min_normal_float + : expected_at_pos_min_normal_float; T sign_nonpreserving_ftz_expected = std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero; if (IsClose(sign_preserving_ftz_expected, actual) || + IsClose(flush_to_normal_expected, actual) || (relaxed_denormal_signs_ && IsClose(sign_nonpreserving_ftz_expected, actual))) { continue; @@ -395,11 +400,13 @@ class ExhaustiveOpTest return absl::StrFormat( "Mismatch on denormal value %s. Expected one of:\n" " %10s (evaluated at full-precision value)\n" + " %10s (evaluated at sign-preserving min-normal-float)\n" " %10s (evaluated after flushing to sign-preserving zero)\n" " %10s (evaluated after flushing to non-sign-preserving " "zero)\n" "but got %s.", - StringifyNum(input), StringifyNum(expected), + StringifyNum(input), // + StringifyNum(expected), StringifyNum(flush_to_normal_expected), StringifyNum(sign_preserving_ftz_expected), StringifyNum(sign_nonpreserving_ftz_expected), StringifyNum(actual)); @@ -409,10 +416,13 @@ class ExhaustiveOpTest return absl::StrFormat( "Mismatch on denormal value %s. Expected one of:\n" " %10s (evaluated at full-precision value)\n" + " %10s (evaluated at sign-preserving min-normal-float)\n" " %10s (evaluated after flushing to sign-preserving zero)\n" "but got %s.", - StringifyNum(input), StringifyNum(expected), - StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual)); + StringifyNum(input), // + StringifyNum(expected), StringifyNum(flush_to_normal_expected), + StringifyNum(sign_preserving_ftz_expected), // + StringifyNum(actual)); }); } } @@ -434,11 +444,14 @@ class ExhaustiveOpTest LOG(ERROR) << err_generator(); } else if (*mismatches == kMaxMismatchesLoggedToErr) { LOG(ERROR) << "Not printing any more mismatches; pass " - "--vmodule=exhaustive_f32__op_test=2 to see " + "--vmodule=exhaustive_op_test=2 to see " "all of them."; } } + // Sets error parameters appropriately for testing sin/cos/tan. + void SetParamsForSinCosTan(); + // The following members are set during construction so testcases can read // these values and use them e.g. to influence the values given to the mutable // members below. @@ -452,10 +465,9 @@ class ExhaustiveOpTest // Tests can set the following variables for control over execution. This is // safe because each XLA_TEST_P instantiates a new instance of this class. - // Testing will ignore the given range (encoded as bitwise representations of - // the type under test zero-extended to int64). - int64 known_incorrect_begin_ = 0; - int64 known_incorrect_end_ = 0; + // Testing will ignore inputs for which known_incorect_fn_ returns true. (Its + // argument is the type under test, e.g. f32, zero-extended to int64). + std::function known_incorrect_fn_; // If unset, reasonable defaults will be used depending on the type under // test. @@ -496,40 +508,39 @@ XLA_TEST_P(ExhaustiveOpTest, Log1p) { } XLA_TEST_P(ExhaustiveOpTest, Exp) { - if (platform_ == "Host" && ty_ == F32) { - // TODO(b/73142289): The vectorized Exp implementation gives results outside - // our error spec in this range. - known_incorrect_begin_ = 1107296256 + 11583654; - known_incorrect_end_ = 1107296256 + 11629080; - } else if (platform_ == "Host" && ty_ == BF16) { - // TODO(jlebar): Is this a rounding error? 
Why doesn't it occur on XLA:GPU? - // - // Mismatch on 88.5 (0x42b1). - // Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80). - known_incorrect_begin_ = 0x42b1; - known_incorrect_end_ = 0x42b2; + // Our CPU implementation of exp returns one incorrect value: says + // exp(88.7228394) = max-float, but the correct answer is inf. We deem this + // acceptable and check for it explicitly so that we can be aware if anything + // changes. + if (platform_ == "Host") { + auto host_exp_with_overflow = +[](float f) { + if (f == 88.7228394f) { + return 3.40282347e+38f; + } + return std::exp(f); + }; + Run(Exp, host_exp_with_overflow); + } else { + Run(Exp, std::exp); } - - Run(Exp, std::exp); } XLA_TEST_P(ExhaustiveOpTest, Expm1) { - // Expm1 has the same erroneous behavior on CPU as Exp. - if (platform_ == "Host" && ty_ == F32) { - // TODO(b/73142289): The vectorized Exp implementation gives results outside - // our error spec in this range. - known_incorrect_begin_ = 1107296256 + 11583654; - known_incorrect_end_ = 1107296256 + 11629080; - } else if (platform_ == "Host" && ty_ == BF16) { - // TODO(jlebar): Is this a rounding error? Why doesn't it occur on XLA:GPU? - // - // Mismatch on 88.5 (0x42b1). - // Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80). - known_incorrect_begin_ = 0x42b1; - known_incorrect_end_ = 0x42b2; + // Our CPU implementation of expm1 returns one incorrect value: says + // exp(88.7228394) = max-float, but the correct answer is inf. We deem this + // acceptable and check for it explicitly so that we can be aware if anything + // changes. + if (platform_ == "Host") { + auto host_expm1_with_overflow = +[](float f) { + if (f == 88.7228394f) { + return 3.40282347e+38f; + } + return std::expm1(f); + }; + Run(Expm1, host_expm1_with_overflow); + } else { + Run(Expm1, std::expm1); } - - Run(Expm1, std::expm1); } // It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but @@ -553,10 +564,111 @@ XLA_TEST_P(ExhaustiveOpTest, Sqrt) { Run(Sqrt, std::sqrt); } -// TODO(jlebar): Add remaining trig functions. Don't forget Atan2! // TODO(jlebar): Test trig functions over complex inputs. + +XLA_TEST_P(ExhaustiveOpTest, Acosh) { + // Error inherited from Log, which our implementation of Acosh uses. + if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { + abs_err_ = 0.001; + rel_err_ = 0.001; + } + Run(Acosh, std::acosh); +} +XLA_TEST_P(ExhaustiveOpTest, Asinh) { + // Error inherited from Log, which our implementation of Asinh uses. + if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { + abs_err_ = 0.001; + rel_err_ = 0.001; + } + Run(Asinh, std::asinh); +} +XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); } +XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); } +XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); } + +XLA_TEST_P(ExhaustiveOpTest, Cosh) { + // Our cosh implementation incorrectly overflows to inf for +/-89.4159851. + // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to + // max-float, so we deem this acceptable. + // + // This does not occur on CPU because we have an offsetting error in our + // implementation of exp. 
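  // (Rough sanity check on those constants, not something the test enforces:
  // for large |x|, cosh(x) ~= 0.5 * exp(|x|), and float exp runs out of range
  // near 88.7228394 (see the Exp test above), so cosh should run out of range
  // near 88.7228394 + ln(2) ~= 89.4160. The failing input 89.4159851 sits
  // right at that edge, where the exact result still squeaks in just under
  // max-float but an exp-based implementation rounds over to inf. The CPU exp
  // bug noted above clamps its overflow to max-float instead of inf, which is
  // the offsetting error referred to in the comment.)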
+ float (*host_cosh)(float); + if (platform_ == "Host") { + host_cosh = &std::cosh; + } else { + host_cosh = +[](float x) { + if (std::abs(x) == 89.4159851f) { + return std::numeric_limits::infinity(); + } + return std::cosh(x); + }; + } + Run(Cosh, host_cosh); +} +XLA_TEST_P(ExhaustiveOpTest, Sinh) { + // Our sinh implementation incorrectly overflows to +/-inf for +/-89.4159851. + // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to + // max-float, so we deem this acceptable. + // + // This does not occur on CPU because we have an offsetting error in our + // implementation of exp. + float (*host_sinh)(float); + if (platform_ == "Host") { + host_sinh = &std::sinh; + } else { + host_sinh = +[](float x) { + if (std::abs(x) == 89.4159851f) { + return std::copysign(std::numeric_limits::infinity(), x); + } + return std::sinh(x); + }; + } + Run(Sinh, host_sinh); +} XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); } +void ExhaustiveOpTest::SetParamsForSinCosTan() { + if (platform_ == "Host" || platform_ == "CUDA") { + return; + } + + // Non CPU/GPU targets may have used the Cody-Waite range reduction technique + // and will not provide meaningful results for sin/cos/tan if magnitudes + // exceed 2**p. + if (ty_ == F32) { + rel_err_ = 0.001; + abs_err_ = 0.001; + known_incorrect_fn_ = [](int64 v) { + float f = absl::bit_cast(static_cast(v)); + return std::abs(f) > (1 << 13); + }; + } else if (ty_ == BF16) { + known_incorrect_fn_ = [](int64 v) { + float f = + static_cast(absl::bit_cast(static_cast(v))); + return std::abs(f) > (1 << 13); + }; + } +} + +XLA_TEST_P(ExhaustiveOpTest, Cos) { + SetParamsForSinCosTan(); + Run(Cos, std::cos); +} +XLA_TEST_P(ExhaustiveOpTest, Sin) { + SetParamsForSinCosTan(); + Run(Sin, std::sin); +} +XLA_TEST_P(ExhaustiveOpTest, Tan) { + SetParamsForSinCosTan(); + Run(Tan, std::tan); +} + +// TODO(jlebar): Enable these. +// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); } +// XLA_TEST_P(ExhaustiveOpTest, Atan2) { Run(Atan2, std::atan2); } + XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); } XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); } XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); } @@ -595,19 +707,24 @@ XLA_TEST_P(ExhaustiveOpTest, Lgamma) { if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) { rel_err_ = 0.001; } + float (*host_lgamma)(float) = std::lgamma; if (platform_ != "Host" && platform_ != "CUDA") { // TODO(b/123956399): This is a fairly high error, significantly higher than // we see on CPU/GPU. rel_err_ = 0.01; abs_err_ = 0.01; - // Overflows for to inf for input 4.08500343e+36 (0x7c44af8e). + // Overflows to inf for input 4.08500343e+36 (0x7c44af8e). 
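    // (Back-of-the-envelope check, using Stirling's approximation
    // lgamma(x) ~= x * (ln(x) - 1) for large x: ln(4.08500343e+36) ~= 84.3,
    // so lgamma(4.08500343e+36) ~= 4.085e+36 * 83.3 ~= 3.40e+38, i.e. right
    // at max-float, which is why a backend with slightly looser accuracy can
    // tip over into inf on exactly this input.)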
if (ty_ == F32) { - known_incorrect_begin_ = 0x7c44af8e; - known_incorrect_end_ = 0x7c44af8e + 1; + host_lgamma = +[](float v) { + if (absl::bit_cast(v) == 0x7c44af8e) { + return std::numeric_limits::infinity(); + } + return std::lgamma(v); + }; } } - Run(Lgamma, std::lgamma); + Run(Lgamma, host_lgamma); } XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); } diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index 1b0bebe2d03..5d91326aad0 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -47,8 +47,9 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { } tensorflow::SubProcess file_check_process; - file_check_process.SetProgram(file_check_path, - {file_check_path, "-v", pattern_path}); + file_check_process.SetProgram( + file_check_path, + {file_check_path, "-v", "-dump-input=always", pattern_path}); file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, tensorflow::ACTION_PIPE); file_check_process.SetChannelAction(tensorflow::CHAN_STDERR, diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 2178c9b3f3d..2d0805cdb0e 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include + #include #include #include @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -63,7 +63,11 @@ const float test_float_vals[3][test_width][test_height] = { class FusionTest : public HloTestBase { protected: template - void TestElementwise2D(HloOpcode opcode) { + void TestElementwise2D( + HloOpcode opcode, + absl::optional direction = absl::nullopt) { + // Create a variable for comparisons since they require the direction. 
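    // (Context for the parameter above and the flag below: comparisons are no
    // longer separate binary opcodes such as kLt or kEq but a single kCompare
    // op carrying a ComparisonDirection. They are built with
    //   HloInstruction::CreateCompare(shape, lhs, rhs, ComparisonDirection::kLt)
    // instead of
    //   HloInstruction::CreateBinary(shape, HloOpcode::kLt, lhs, rhs)
    // and print in HLO text as "compare(a, b), direction=LT" rather than
    // "less-than(a, b)". The flag below records whether the element type
    // under test is bool, i.e. whether this run exercises a comparison.)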
+ bool is_compare = std::is_same::value; Array2D operand_data[Arity]; for (int i = 0; i < Arity; ++i) { new (&operand_data[i]) Array2D(test_width, test_height); @@ -76,12 +80,16 @@ class FusionTest : public HloTestBase { xs[k] = test_float_vals[k][i][j]; operand_data[k](i, j) = xs[k]; } - answer_data(i, j) = ComputeElementwiseAnswer(opcode, xs); + if (is_compare) { + answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs); + } else { + answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs); + } } } auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -98,8 +106,13 @@ class FusionTest : public HloTestBase { root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]); break; case 2: - root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], - hlos[2]); + if (is_compare) { + root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1], + hlos[2], *direction); + } else { + root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1], + hlos[2]); + } break; case 3: root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1], @@ -124,13 +137,19 @@ class FusionTest : public HloTestBase { } private: - template - T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span xs); + float ComputeElementwiseAnswerFloat(HloOpcode opcode, + absl::Span xs); + bool ComputeElementwiseAnswerCompare(ComparisonDirection direction, + absl::Span xs); + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + return debug_options; + } }; -template <> -float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - absl::Span xs) { +float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode, + absl::Span xs) { switch (opcode) { case HloOpcode::kAdd: return xs[0] + xs[1]; @@ -153,24 +172,21 @@ float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, } } -template <> -bool FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - absl::Span xs) { - switch (opcode) { - case HloOpcode::kEq: +bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction, + absl::Span xs) { + switch (direction) { + case ComparisonDirection::kEq: return xs[0] == xs[1]; - case HloOpcode::kNe: + case ComparisonDirection::kNe: return xs[0] != xs[1]; - case HloOpcode::kGt: + case ComparisonDirection::kGt: return xs[0] > xs[1]; - case HloOpcode::kLt: + case ComparisonDirection::kLt: return xs[0] < xs[1]; - case HloOpcode::kGe: + case ComparisonDirection::kGe: return xs[0] >= xs[1]; - case HloOpcode::kLe: + case ComparisonDirection::kLe: return xs[0] <= xs[1]; - default: - LOG(FATAL) << "No comparatory opcode: " << opcode; } } @@ -183,7 +199,7 @@ XLA_TEST_F(FusionTest, Test) { // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -231,7 +247,7 @@ XLA_TEST_F(FusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. 
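  // (As elsewhere in this file, the module is created with
  // CreateNewVerifiedModule rather than CreateNewUnverifiedModule, so the
  // hand-built HLO graph below is additionally checked by the HLO verifier
  // rather than only being executed.)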
auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -266,7 +282,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); // Build simple fusion computation: y = x^2 (elementwise). auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto two = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); @@ -290,7 +306,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( @@ -314,7 +330,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto single_element_array = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( @@ -329,7 +345,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -344,7 +360,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( @@ -359,7 +375,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( @@ -374,7 +390,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -389,7 +405,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = 
CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( @@ -404,7 +420,7 @@ XLA_TEST_F(FusionTest, Reshape__) { XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( @@ -419,7 +435,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -434,7 +450,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -449,7 +465,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -465,7 +481,7 @@ XLA_TEST_F(FusionTest, Reverse) { XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -483,7 +499,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -501,7 +517,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( @@ -519,7 +535,7 @@ XLA_TEST_F(FusionTest, SliceNegate) { XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 
4}))); auto const1 = builder.AddInstruction( @@ -541,7 +557,7 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( @@ -559,7 +575,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -587,11 +603,10 @@ std::unique_ptr MakeReduceTestComputation() { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { - auto hlo_module = CreateNewUnverifiedModule(); - + auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {32}), 0)); auto const1 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( @@ -602,12 +617,12 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { HloInstruction::FusionKind::kInput); EXPECT_TRUE( - LiteralTestUtil::Equal(LiteralUtil::CreateR0(15), + LiteralTestUtil::Equal(LiteralUtil::CreateR0(496), ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { - auto hlo_module = CreateNewUnverifiedModule(); +XLA_TEST_F(FusionTest, ReduceImplicitBroadcast) { + auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -630,7 +645,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( @@ -682,7 +697,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { // into a fusion, it should remain shared, rather than being duplicated // within the fusion. 
XLA_TEST_F(FusionTest, SharedConstant) { - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( @@ -740,64 +755,34 @@ XLA_TEST_F(FusionTest, Maximum2D) { TestElementwise2D(HloOpcode::kMaximum); } -XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D(HloOpcode::kEq); } +XLA_TEST_F(FusionTest, Equal2D) { + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kEq); +} XLA_TEST_F(FusionTest, Inequal2D) { - TestElementwise2D(HloOpcode::kNe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kNe); } XLA_TEST_F(FusionTest, Greater2D) { - TestElementwise2D(HloOpcode::kGt); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGt); } -XLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D(HloOpcode::kLt); } +XLA_TEST_F(FusionTest, Lesser2D) { + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLt); +} XLA_TEST_F(FusionTest, GreaterOrEqual2D) { - TestElementwise2D(HloOpcode::kGe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGe); } XLA_TEST_F(FusionTest, LesserOrEqual2D) { - TestElementwise2D(HloOpcode::kLe); + TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLe); } XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } -// TODO(b/117156505): Remove this test when the bug is fixed and the CPU backend -// should not generate layout changing elementwise operations. -#ifdef XLA_TEST_BACKEND_CPU -XLA_TEST_F(FusionTest, LayoutChangingElementWiseOp) { - const string hlo_text = R"( -HloModule Cluster - -fusion_c { - fusion.arg = f32[2,2]{1,0} parameter(0) - bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg) - tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0) - ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0) -} - -ENTRY main { - arg = f32[2,2]{1,0} parameter(0) - ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c -} -)"; - - Literal operand = LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text, config)); - TF_ASSERT_OK_AND_ASSIGN(Literal result, - test_runner_.Execute(std::move(module), {&operand}, - /*run_hlo_passes=*/false)); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), - result)); -} -#endif - class FusionClientLibraryTest : public ClientLibraryTestBase {}; XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { @@ -844,7 +829,7 @@ void BM_ParallelFusion(int num_iters) { se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); - StreamExecutorMemoryAllocator allocator(platform, executors); + se::StreamExecutorMemoryAllocator allocator(platform, executors); const int64 intra_op_parallelism_threads = 24; xla::LocalClientOptions client_options; @@ -910,8 +895,7 @@ void BM_ParallelFusion(int num_iters) { // Initialize thread pool. tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", intra_op_parallelism_threads); - tensorflow::EigenThreadPoolWrapper tp(&pool); - Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads()); // Initialize ExecutableRunOptions. 
ExecutableRunOptions options; diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index d65b67a535d..16a1371ec8d 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -598,6 +598,26 @@ ENTRY main { RunTest(hlo_text, &operand, &start_indices); } +XLA_TEST_F(GatherOperationTest, GatherFromScalar) { + const string hlo_text = R"( +HloModule GatherFromScalar + +ENTRY main { + operand = f32[] parameter(0) + indices = s32[0]{0} parameter(1) + ROOT gather = f32[] gather(operand, indices), + offset_dims={}, + collapsed_slice_dims={}, + start_index_map={}, + index_vector_dim=0, + slice_sizes={} +} +)"; + Literal operand = LiteralUtil::CreateR0(1); + Literal start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &start_indices); +} + class GatherClientLibraryTest : public ClientLibraryTestBase {}; // Disabled on interpreter since ExectuteAsyncOnStream is not supported. diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 1115e50fe31..74333d66610 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -73,7 +73,7 @@ half sign_imp(half value) { } half round_imp(half value) { - return half(round(static_cast(std::move(value)))); + return half(std::round(static_cast(std::move(value)))); } INSTANTIATE_TEST_CASE_P( @@ -163,8 +163,8 @@ XLA_TEST_P(BinaryOpTest, Ops) { } half atan2_imp(half x, half y) { - return half(atan2(static_cast(std::move(x)), - static_cast(std::move(y)))); + return half(std::atan2(static_cast(std::move(x)), + static_cast(std::move(y)))); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 62e2b465cfe..79974723b8b 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -143,6 +143,11 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( backend().compiler()->ShapeSizeBytesFunction()); } +StatusOr> +HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) { + return ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()); +} + StatusOr> HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { @@ -210,11 +215,26 @@ StatusOr> HloTestBase::ExecuteReplicated( int64 num_replicas, bool use_threads) { HloRunner::ReplicatedExecuteOptions options; options.num_replicas = num_replicas; + options.use_threads = use_threads; + for (auto argument : arguments) { + options.arguments.push_back(argument); + } + return test_runner_.ExecuteReplicated(std::move(module), options); +} + +StatusOr> HloTestBase::ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas, DeviceAssignment* device_assignment, + bool run_hlo_passes, bool use_threads) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.run_hlo_passes = run_hlo_passes; + options.use_threads = use_threads; for (auto argument : arguments) { options.arguments.push_back(argument); } return test_runner_.ExecuteReplicated(std::move(module), options, - use_threads); + device_assignment); } StatusOr> HloTestBase::MakeReferenceModule( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index df9c29a186f..7a78307a467 100644 --- 
a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -114,8 +114,9 @@ class HloTestBase : public ::testing::Test { // Parses the given string and returns module as a VerifiedHloModule. StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); + absl::string_view hlo_text); + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config); // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass @@ -182,6 +183,12 @@ class HloTestBase : public ::testing::Test { std::unique_ptr module, absl::Span arguments, int64 num_replicas, bool use_threads); + // Same as above, but uses specified device assignment. + StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas, DeviceAssignment* device_assignment, + bool run_hlo_passes, bool use_threads); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index a2fd6070731..7f725a97f28 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -26,19 +26,26 @@ namespace { // Writes the given literal to a file in the test temporary directory. void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { - auto get_hostname = [] { - char hostname[1024]; - gethostname(hostname, sizeof hostname); - hostname[sizeof hostname - 1] = 0; - return string(hostname); - }; - int64 now_usec = tensorflow::Env::Default()->NowMicros(); + // Bazel likes for tests to write "debugging outputs" like these to + // TEST_UNDECLARED_OUTPUTS_DIR. This plays well with tools that inspect test + // results, especially when they're run on remote machines. 
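  // (For context: Bazel exports TEST_UNDECLARED_OUTPUTS_DIR to each test, and
  // files written there are collected together with the test's other outputs,
  // typically packaged next to the test log, so the dumps survive runs on
  // remote executors where a purely local TmpDir would be thrown away.)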
+ string outdir; + const char* undeclared_outputs_dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); + if (undeclared_outputs_dir != nullptr) { + outdir = undeclared_outputs_dir; + } else { + outdir = tensorflow::testing::TmpDir(); + } + + auto* env = tensorflow::Env::Default(); string filename = tensorflow::io::JoinPath( - tensorflow::testing::TmpDir(), - absl::StrFormat("tempfile-%s-%x-%s", get_hostname(), now_usec, name)); - TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, + outdir, absl::StrFormat("tempfile-%d-%s", env->NowMicros(), name)); + TF_CHECK_OK(tensorflow::WriteBinaryProto(env, absl::StrCat(filename, ".pb"), literal.ToProto())); - LOG(ERROR) << "wrote to " << name << " file: " << filename; + TF_CHECK_OK(tensorflow::WriteStringToFile(env, absl::StrCat(filename, ".txt"), + literal.ToString())); + LOG(ERROR) << "wrote Literal to " << name << " file: " << filename + << ".{pb,txt}"; } // Callback helper that dumps literals to temporary files in the event of a diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index ea9b3037cf4..c54b28c142a 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -38,6 +38,68 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); } +TEST(LiteralTestUtilTest, ComparesEqualComplex64TuplesEqual) { + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); +} + +TEST(LiteralTestUtilTest, ComparesEqualComplex128TuplesEqual) { + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); +} + +TEST(LiteralTestUtilTest, ComparesUnequalComplex64TuplesUnequal) { + Literal literal0 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + Literal literal1 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({64.0, 42.0}), + LiteralUtil::CreateR0({42.0, 64.0}), + }); + Literal literal2 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.42, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + Literal literal3 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.42}), + }); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal1)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal2)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal3)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal2, literal3)); +} + +TEST(LiteralTestUtilTest, ComparesUnequalComplex128TuplesUnequal) { + Literal literal0 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + Literal literal1 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({64.0, 42.0}), + LiteralUtil::CreateR0({42.0, 64.0}), + }); + Literal literal2 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.42, 64.0}), + LiteralUtil::CreateR0({64.0, 42.0}), + }); + Literal literal3 = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0({42.0, 64.0}), + LiteralUtil::CreateR0({64.0, 42.42}), + }); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal1)); + 
EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal2)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal0, literal3)); + EXPECT_FALSE(LiteralTestUtil::Equal(literal2, literal3)); +} + TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // Implementation note: we have to use a death test here, because you can't // un-fail an assertion failure. The CHECK-failure is death, so we can make a @@ -65,8 +127,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { }; tensorflow::Env* env = tensorflow::Env::Default(); - string pattern = - tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/tempfile-*"); + + string outdir; + const char* undeclared_outputs_dir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); + if (undeclared_outputs_dir != nullptr) { + outdir = undeclared_outputs_dir; + } else { + outdir = tensorflow::testing::TmpDir(); + } + string pattern = tensorflow::io::JoinPath(outdir, "/tempfile-*.pb"); std::vector files; TF_CHECK_OK(env->GetMatchingPaths(pattern, &files)); for (const auto& f : files) { @@ -118,6 +187,92 @@ TEST(LiteralTestUtilTest, NearComparatorR1) { EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } +TEST(LiteralTestUtilTest, NearComparatorR1Complex64) { + auto a = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.8}}); + auto b = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.8}}); + auto c = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.9, 1.8}}); + auto d = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.9}}); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, c, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, d, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(c, d, ErrorSpec{0.0001})); +} + +TEST(LiteralTestUtilTest, NearComparatorR1Complex128) { + auto a = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.8}}); + auto b = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.8}}); + auto c = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.9, 1.8}}); + auto d = LiteralUtil::CreateR1({{0.0, 1.0}, + {0.1, 1.1}, + {0.2, 1.2}, + {0.3, 1.3}, + {0.4, 1.4}, + {0.5, 1.5}, + {0.6, 1.6}, + {0.7, 1.7}, + {0.8, 1.9}}); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, c, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, d, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(c, d, ErrorSpec{0.0001})); +} + TEST(LiteralTestUtilTest, NearComparatorR1Nan) { auto a = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); @@ -135,5 +290,13 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) { EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001})); } +TEST(LiteralTestUtilTest, ExpectNearDoubleOutsideFloatValueRange) { + auto two_times_float_max = + LiteralUtil::CreateR0(2.0 * std::numeric_limits::max()); + ErrorSpec error(0.001); + 
EXPECT_TRUE( + LiteralTestUtil::Near(two_times_float_max, two_times_float_max, error)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 60eb21aafd2..f1779c856bb 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -69,12 +69,12 @@ int main(int argc, char** argv) { } else if (target_cpu == "arm") { triple_string = "aarch64-none-linux-gnu"; } else if (target_cpu == "local") { - triple_string = xla::llvm_ir::AsString(llvm::sys::getDefaultTargetTriple()); + triple_string = llvm::sys::getDefaultTargetTriple(); } else { LOG(FATAL) << "unsupported TARGET_CPU: " << target_cpu; } - llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); + llvm::Triple triple(triple_string); xla::XlaComputation computation = builder.Build().ConsumeValueOrDie(); xla::CompileOnlyClient::AotXlaComputationInstance instance{ diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 96527886b71..67a1abacd18 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -41,6 +40,7 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace { @@ -130,14 +130,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { // Create x as a col-major array. auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); - EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), - LayoutUtil::MakeLayout({0, 1}))); + EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()( + x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. 
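  // (As with x above, the device layout of y is checked with
  // Layout::Equal().MinorToMajorOnly() rather than LayoutUtil::Equal,
  // presumably so that only the minor-to-major dimension order is compared
  // and any additional layout fields the on-device shape may carry are
  // ignored.)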
auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); - EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), - LayoutUtil::MakeLayout({1, 0}))); + EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()( + y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -171,8 +171,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), - LayoutUtil::MakeLayout({0, 1}))); + EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()( + result_colmaj.on_device_shape().layout(), + LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, ShapedBufferToLiteral(result_colmaj), error_spec_); @@ -183,8 +184,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), - LayoutUtil::MakeLayout({1, 0}))); + EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()( + result_rowmaj.on_device_shape().layout(), + LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, ShapedBufferToLiteral(result_rowmaj), error_spec_); @@ -900,7 +902,7 @@ void BM_LocalClientOverhead(int num_iters) { se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); - StreamExecutorMemoryAllocator allocator(platform, executors); + se::StreamExecutorMemoryAllocator allocator(platform, executors); LocalClient* client = ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); auto* transfer_manager = diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index f90ef22d2d5..7eaa2791d47 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" @@ -36,17 +35,16 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -StatusOr TestAllocator::Allocate(int device_ordinal, - uint64 size, - bool retry_on_failure) { +StatusOr TestAllocator::Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; { tensorflow::mutex_lock lock(count_mutex_); allocation_count_++; device_allocation_count_[device_ordinal]++; } - return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size, - retry_on_failure); + return se::StreamExecutorMemoryAllocator::Allocate(device_ordinal, size, + retry_on_failure); } Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { @@ -56,7 +54,7 @@ Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { deallocation_count_++; device_deallocation_count_[device_ordinal]++; } - return StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem); + return se::StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem); } int64 TestAllocator::allocation_count() const { @@ -108,12 +106,10 @@ struct LocalClientTestBase::EigenThreadPoolWrapper { explicit EigenThreadPoolWrapper() : pool(new tensorflow::thread::ThreadPool( tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)), - wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), - device(new Eigen::ThreadPoolDevice(wrapper.get(), - wrapper->NumThreads())) {} + device(new Eigen::ThreadPoolDevice(pool->AsEigenThreadPool(), + pool->NumThreads())) {} std::unique_ptr pool; - std::unique_ptr wrapper; std::unique_ptr device; }; diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 4027c7b124f..292baacf969 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -36,18 +35,19 @@ limitations under the License. 
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { -class TestAllocator : public StreamExecutorMemoryAllocator { +class TestAllocator : public se::StreamExecutorMemoryAllocator { public: explicit TestAllocator(se::Platform* platform) - : StreamExecutorMemoryAllocator( + : se::StreamExecutorMemoryAllocator( platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) { } - StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. diff --git a/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc index 1513d89ba9c..7895895e3e7 100644 --- a/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc @@ -14,35 +14,86 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +// Tests cross-GPU all-reduce operatons. +// +// This test requires multiple GPUs. For instructions on running this within +// Google, see go/multi-gpu-unit-test. namespace xla { namespace { -class MultiDeviceAllReduceTest : public HloTestBase {}; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +class MultiDeviceAllReduceTest : public HloTestBase { + protected: + std::unique_ptr MakeCrsModule(int64 num_elems, + const HloModuleConfig& config) { + const char* kTemplate = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p = f32[NUM_ELEMS] parameter(0) + ROOT crs = f32[NUM_ELEMS] all-reduce(p), to_apply=add + } + )"; + return ParseHloString( + absl::StrReplaceAll(kTemplate, + {{"NUM_ELEMS", absl::StrCat(num_elems)}}), + config) + .ValueOrDie(); + } +}; + +// Returns the non-empty subsets of {0, 1, ..., n}. For example, +// PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}. +std::vector> PowerSetOfIota(int64 n) { + std::vector> power_set; + for (int64 i = 1; i < (1 << n); ++i) { + power_set.emplace_back(); + for (int64 j = 0; j < n; ++j) { + if (i & (1 << j)) { + power_set.back().push_back(j); + } + } + } + return power_set; +} + +// Makes a DeviceAssignment assigning replica-id i to devices[i]. +DeviceAssignment MakeDeviceAssn(std::vector devices) { + DeviceAssignment assn(/*replica_count=*/devices.size(), + /*computation_count=*/1); + for (int64 i = 0; i < devices.size(); ++i) { + assn(i, 0) = devices[i]; + } + return assn; +} + +// Shorter alias for this function. 
+absl::flat_hash_set OpenNcclChannels() { + return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels(); +} XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) { - const char* module_str = R"( - HloModule test - - add { - x = f32[] parameter(0) - y = f32[] parameter(1) - add = f32[] add(x, y) - } - - ENTRY test_computation { - p = f32[3] parameter(0) - ROOT crs = f32[3] all-reduce(p), to_apply=add - })"; auto config = GetModuleConfigForTest(); config.set_replica_count(2); - auto module = ParseHloString(module_str, config).ValueOrDie(); + auto module = MakeCrsModule(/*num_elems=*/3, config); auto literal = LiteralUtil::CreateR1({1, 2, 3}); auto expected = LiteralUtil::CreateR1({2, 4, 6}); TF_ASSERT_OK_AND_ASSIGN(std::vector results, @@ -52,5 +103,112 @@ XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) { EXPECT_EQ(expected, results[1]); } +// Tries all-to-all operations across all 2^kNumDevices - 1 combinations of +// devices in sequence. +XLA_TEST_F(MultiDeviceAllReduceTest, AllCombinations) { + const int64 kNumDevices = 4; + const int64 kNumElems = 1024; + + for (std::vector devices : PowerSetOfIota(kNumDevices)) { + SCOPED_TRACE(absl::StrFormat("Running on devices {%s}", + absl::StrJoin(devices, ", "))); + + DeviceAssignment device_assn = MakeDeviceAssn(devices); + + auto config = GetModuleConfigForTest(); + config.set_replica_count(devices.size()); + config.set_static_device_assignment(device_assn); + + auto module = MakeCrsModule(kNumElems, config); + + std::vector input_vec(kNumElems); + absl::c_iota(input_vec, 0); + auto input_literal = LiteralUtil::CreateR1(input_vec); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), {&input_literal}, + /*num_replicas=*/devices.size(), &device_assn, + /*run_hlo_passes=*/true, /*use_threads=*/true)); + } +} + +// Check that the NCCL data structures in our all-reduce implementation are +// cached as we expect. +XLA_TEST_F(MultiDeviceAllReduceTest, NcclChannelCaching) { + const int64 kNumElems = 1024; + + std::vector input_vec(kNumElems); + absl::c_iota(input_vec, 0); + auto input_literal = LiteralUtil::CreateR1(input_vec); + + // Initially no NCCL channels should be open. + EXPECT_THAT(OpenNcclChannels(), IsEmpty()); + + // Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}. + struct ExecutableInfo { + std::unique_ptr executable; + DeviceAssignment device_assn; + HloRunner::ReplicatedExecuteOptions opts; + }; + std::vector executables; + for (const auto& devices : + std::vector>{{0, 1}, {1, 2}, {0, 1, 2}}) { + executables.emplace_back(); + auto& e = executables.back(); + + e.device_assn = MakeDeviceAssn(devices); + + auto config = GetModuleConfigForTest(); + config.set_replica_count(devices.size()); + config.set_static_device_assignment(e.device_assn); + auto module = MakeCrsModule(kNumElems, config); + e.executable = + test_runner_ + .CreateExecutable(std::move(module), /*run_hlo_passes=*/true) + .ValueOrDie(); + + e.opts.num_replicas = devices.size(); + e.opts.use_threads = true; + e.opts.arguments.push_back(&input_literal); + } + + auto run_executable = [&](int64 i) { + auto& e = executables[i]; + TF_ASSERT_OK( + test_runner_ + .ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn) + .status()); + }; + + // Compiling executables above shouldn't cause us to open any channels. + EXPECT_THAT(OpenNcclChannels(), IsEmpty()); + + // Run the executables and check that channels are opened as we expect. 
+ run_executable(0); + EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1)); + + run_executable(2); + EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2)); + + run_executable(1); + EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2)); + + // Tear down the executables and check that channels are closed as we expect. + // Note that after we tear down an executable *all* the nccl channels may go + // away, so we rerun all of the executables that haven't been torn down. + executables[2].executable.reset(); + run_executable(0); + run_executable(1); + EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2)); + + executables[0].executable.reset(); + run_executable(1); + EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(1, 2)); + + executables[1].executable.reset(); + EXPECT_THAT(OpenNcclChannels(), IsEmpty()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 1fd9cb055c0..7578094e07f 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -227,7 +227,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { fused_computation { p = f32[4] parameter(0) multiply = f32[4] multiply(p, p) - less-than = pred[4] less-than(p, multiply) + less-than = pred[4] compare(p, multiply), direction=LT ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply) } @@ -252,7 +252,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { fused_computation { p = f32[] parameter(0) multiply = f32[] multiply(p, p) - less-than = pred[] less-than(p, multiply) + less-than = pred[] compare(p, multiply), direction=LT ROOT tuple = (pred[], f32[]) tuple(less-than, multiply) } @@ -295,255 +295,191 @@ XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = f32[32,32,32]{2,1,0} parameter(0) c0 = f32[] constant(0) - r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add - mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + r1 = f32[32,32]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[32,32,32]{2,1,0} multiply(p0, p0) c1 = f32[] constant(5) - r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max - ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2) } ENTRY reduce { - p = f32[2,2,2]{2,1,0} parameter(0) - ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, - calls=fused_reduce + p = f32[32,32,32]{2,1,0} parameter(0) + ROOT fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce })"); auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); - auto param = - LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR2({{3, 7}, {11, 15}}), - LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - result)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = 
f32[32,32,32]{2,1,0} parameter(0) c0 = f32[] constant(0) - r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add - mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + r1 = f32[32,32]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[32,32,32]{2,1,0} multiply(p0, p0) c1 = f32[] constant(5) - r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max - ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2) } ENTRY reduce { - p = f32[2,2,2]{2,1,0} parameter(0) - ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, - calls=fused_reduce + p = f32[32,32,32]{2,1,0} parameter(0) + ROOT fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce })"); auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); - auto param = - LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR2({{6, 8}, {10, 12}}), - LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - result)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = f32[2,32,32]{2,1,0} parameter(0) c0 = f32[] constant(0) - r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add - mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + r1 = f32[32]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,32,32]{2,1,0} multiply(p0, p0) c1 = f32[] constant(1.17549e-38) - r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max - r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add - ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3) + r2 = f32[32]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max + r3 = f32[32]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add + ROOT tuple = (f32[32]{0}, f32[32]{0}, f32[32]{0}) tuple(r1, r2, r3) } ENTRY reduce { - p = f32[2,2,2]{2,1,0} parameter(0) - ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, - calls=fused_reduce + p = f32[2,32,32]{2,1,0} parameter(0) + ROOT fusion = (f32[32]{0}, f32[32]{0}, f32[32]{0}) fusion(p), kind=kInput, + calls=fused_reduce })"); auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); - auto param = - LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), - LiteralUtil::CreateR1({36, 64}), - LiteralUtil::CreateR1({66, 138})), - result)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = f32[2,32,32]{2,1,0} parameter(0) c0 = f32[] constant(0) - r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add - mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + r1 = f32[2,32]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,32,32]{2,1,0} multiply(p0, p0) c1 = f32[] constant(5) - r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, 
to_apply=Max - ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) + r2 = f32[2,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,32,32]{2,1,0}, f32[2,32]{1,0}, f32[2,2]{1,0}) tuple(p0, r1, r2) } ENTRY reduce { - p = f32[2,2,2]{2,1,0} parameter(0) - ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), - kind=kInput, calls=fused_reduce + p = f32[2,32,32]{2,1,0} parameter(0) + ROOT fusion = (f32[2,32,32]{2,1,0}, f32[2,32]{1,0}, f32[2,32]{1,0}) + fusion(p), kind=kInput, calls=fused_reduce })"); auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); - auto param = - LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), - LiteralUtil::CreateR2({{3, 7}, {11, 15}}), - LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - result)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = f32[32,32,2]{2,1,0} parameter(0) c0 = f32[] constant(0) - r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add - mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + r1 = f32[32,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[32,32,2]{2,1,0} multiply(p0, p0) c1 = f32[] constant(5) - r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max - ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) + r2 = f32[32,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[32,2]{1,0}, f32[32,32,2]{2,1,0}, f32[32,2]{1,0}) tuple(r1, mul, r2) } ENTRY reduce { - p = f32[2,2,2]{2,1,0} parameter(0) - ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), - kind=kInput, calls=fused_reduce + p = f32[32,32,2]{2,1,0} parameter(0) + ROOT fusion = (f32[32,2]{1,0}, f32[32,32,2]{2,1,0}, f32[32,2]{1,0}) + fusion(p), kind=kInput, calls=fused_reduce })"); auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); - auto param = - LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR2({{6, 8}, {10, 12}}), - LiteralUtil::CreateR3( - {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - result)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { - const string testcase = absl::StrCat(kScalarOps, R"( + const string testcase = R"( + HloModule m, is_scheduled=true + + Add { + lhsadd = f32[] parameter(0) + rhsadd = f32[] parameter(1) + ROOT add = f32[] add(lhsadd, rhsadd) + } fused_reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = f32[2,32,32]{2,1,0} parameter(0) c0 = f32[] constant(0) - r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add - mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + r1 = f32[32]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,32,32]{2,1,0} multiply(p0, p0) c1 = f32[] constant(5) - b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={} - mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1) - ROOT tuple = (f32[2]{0}, 
f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) - tuple(r1, mul, mul2) + b1 = f32[2,32,32]{2,1,0} broadcast(c1), dimensions={} + mul2 = f32[2,32,32]{2,1,0} multiply(p0, b1) + ROOT tuple = (f32[32]{0}, f32[2,32,32]{2,1,0}, f32[2,32,32]{2,1,0}) + tuple(r1, mul, mul2) } ENTRY reduce { - p = f32[2,2,2]{2,1,0} parameter(0) - ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), - kind=kInput, calls=fused_reduce - })"); + p = f32[2,32,32]{2,1,0} parameter(0) + ROOT fusion = (f32[32]{0}, f32[2,32,32]{2,1,0}, f32[2,32,32]{2,1,0}) + fusion(p), kind=kInput, calls=fused_reduce + })"; auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); - auto param = - LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR1({14, 22}), - LiteralUtil::CreateR3( - {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - LiteralUtil::CreateR3( - {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), - result)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { - p0 = f32[2,2,2]{2,1,0} parameter(0) + p0 = f32[2,32,32]{2,1,0} parameter(0) init1 = f32[] parameter(1) init2 = f32[] parameter(2) - r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add - r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max - ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + r1 = f32[2,32]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add + r2 = f32[2,32]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,32]{1,0}, f32[2,32]{1,0}) tuple(r1, r2) } ENTRY reduce { - p = f32[2,2,2]{2,1,0} parameter(0) + p = f32[2,32,32]{2,1,0} parameter(0) i = f32[] parameter(1) j = f32[] parameter(2) - ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, - calls=fused_reduce + ROOT fusion = (f32[2,32]{1,0}, f32[2,32]{1,0}) fusion(p, i, j), + kind=kInput, calls=fused_reduce })"); auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); - auto param = - LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - auto init1 = LiteralUtil::CreateR0(5); - auto init2 = LiteralUtil::CreateR0(6); - Literal result = - ExecuteNoHloPasses(std::move(module), {¶m, &init1, &init2}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR2({{167, 172}, {176, 180}}), - LiteralUtil::CreateR2({{6, 6}, {6, 8}})), - result)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { const string testcase = absl::StrCat(kScalarOps, R"( - fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { - p0 = f16[2,2,2]{2,1,0} parameter(0) - convert = f32[2,2,2]{2,1,0} convert(p0) + fused_reduce (p0: f16[2,32,32]) -> (f32[2,32], f32[2,32], f16[2,32,32]) { + p0 = f16[2,32,32]{2,1,0} parameter(0) + convert = f32[2,32,32]{2,1,0} convert(p0) c0 = f32[] constant(0) - r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add - mul = f32[2,2,2]{2,1,0} multiply(convert, convert) + r1 = f32[2,32]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add + mul = f32[2,32,32]{2,1,0} multiply(convert, convert) c1 = f32[] constant(5) - r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, 
to_apply=Max - ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) + r2 = f32[2,32]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,32]{1,0}, f32[2,32]{1,0}, f16[2,32,32]{2,1,0}) tuple(r1, r2, p0) } ENTRY reduce { - p = f16[2,2,2]{2,1,0} parameter(0) - ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p), + p = f16[2,32,32]{2,1,0} parameter(0) + ROOT fusion = (f32[2,32]{1,0}, f32[2,32]{1,0}, f16[2,32,32]{2,1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); - auto param = LiteralUtil::CreateR3( - {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, - {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); - Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); - EXPECT_TRUE(LiteralTestUtil::Equal( - LiteralUtil::MakeTupleOwned( - LiteralUtil::CreateR2({{3, 7}, {11, 15}}), - LiteralUtil::CreateR2({{5, 16}, {36, 64}}), - LiteralUtil::CreateR3( - {{{Eigen::half(1), Eigen::half(2)}, - {Eigen::half(3), Eigen::half(4)}}, - {{Eigen::half(5), Eigen::half(6)}, - {Eigen::half(7), Eigen::half(8)}}})), - result)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 18c99490a38..5b3f30a4da4 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -50,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -455,7 +457,7 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { for (int64 colno = 0; colno < cols; ++colno) { float column_sum = 0; for (int64 rowno = 0; rowno < rows; ++rowno) { - column_sum += log(input_data(rowno, colno)); + column_sum += std::log(input_data(rowno, colno)); } expected.push_back(column_sum); } @@ -486,7 +488,7 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { for (int64 colno = 0; colno < cols; ++colno) { float column_sum = 0; for (int64 rowno = 0; rowno < rows; ++rowno) { - column_sum += log(input_data(rowno, colno)); + column_sum += std::log(input_data(rowno, colno)); } expected.push_back(column_sum); } @@ -533,7 +535,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { for (int64 colno = 0; colno < cols / 2; ++colno) { float column_sum = 0; for (int64 rowno = 0; rowno < rows; ++rowno) { - column_sum += tanh(input_data(rowno, major, colno)); + column_sum += std::tanh(input_data(rowno, major, colno)); } expected.push_back(column_sum); } @@ -1001,5 +1003,193 @@ XLA_TEST_F(ReduceTest, R0ReduceInDisguise) { ErrorSpec(0.001)); } +class ReduceHloTest : public HloTestBase {}; + +XLA_TEST_F(ReduceHloTest, HandleReductionToVectorAndOtherReduction) { + absl::string_view hlo_string = R"( + HloModule HandleReductionToVectorAndOtherReduction + + add { + acc = f32[] parameter(1) + op = f32[] parameter(0) + ROOT out = f32[] add(acc, op) + } + + ENTRY main { + iota.3 = s32[2,2]{1,0} iota(), iota_dimension=0 + 
iota.2 = s32[2,2]{1,0} iota(), iota_dimension=1 + compare.0 = pred[2,2]{1,0} compare(iota.3, iota.2), direction=EQ + broadcast = pred[2,2,2,2]{3,2,1,0} broadcast(compare.0), dimensions={2,3} + param_0.16 = f32[2,2,2,2]{3,2,1,0} parameter(0) + constant_4 = f32[] constant(0) + broadcast.9 = f32[2,2,2,2]{3,2,1,0} broadcast(constant_4), dimensions={} + select.0 = f32[2,2,2,2]{3,2,1,0} select(broadcast, param_0.16, broadcast.9) + reduce.1 = f32[2,2,2]{2,1,0} reduce(select.0, constant_4), dimensions={2}, + to_apply=add + abs.0 = f32[2,2,2]{2,1,0} abs(reduce.1) + log.0 = f32[2,2,2]{2,1,0} log(abs.0) + reduce.0 = f32[2,2]{1,0} reduce(log.0, constant_4), dimensions={2}, + to_apply=add + ROOT tuple = (f32[2,2]{1,0}, f64[2,2,2]{2,1,0}) tuple(reduce.0, reduce.1) + } + )"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + +class VariadicReduceTest : public HloTestBase {}; + +XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R2x2_simple) { + absl::string_view hlo_string = R"( + HloModule Reduce_R3x2_to_R1x2_simple + + add { + op1 = f32[] parameter(0) + op2 = f32[] parameter(1) + acc1 = f32[] parameter(2) + acc2 = f32[] parameter(3) + out1 = f32[] add(acc1, op1) + out2 = f32[] add(acc2, op2) + ROOT result = (f32[], f32[]) tuple(out1, out2) + } + + ENTRY main { + inp1 = f32[3,4,5] parameter(0) + inp2 = f32[3,4,5] parameter(1) + zero = f32[] constant(0) + + ROOT out = (f32[3,5], f32[3,5]) reduce(inp1, inp2, zero, zero), + dimensions={1}, + to_apply=add + } +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + +XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R1x2_simple) { + absl::string_view hlo_string = R"( + HloModule Reduce_R3x2_to_R1x2_simple + + add { + op1 = f32[] parameter(0) + op2 = f32[] parameter(1) + acc1 = f32[] parameter(2) + acc2 = f32[] parameter(3) + out1 = f32[] add(acc1, op1) + out2 = f32[] add(acc2, op2) + ROOT result = (f32[], f32[]) tuple(out1, out2) + } + + ENTRY main { + inp1 = f32[10,20,3] parameter(0) + inp2 = f32[10,20,3] parameter(1) + zero = f32[] constant(0) + + ROOT out = (f32[10], f32[10]) reduce(inp1, inp2, zero, zero), + dimensions={1,2}, + to_apply=add + } +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + +XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_simple) { + absl::string_view hlo_string = R"( + HloModule Reduce_R1x2_to_R0x2_simple + + add { + op1 = f32[] parameter(0) + op2 = f32[] parameter(1) + acc1 = f32[] parameter(2) + acc2 = f32[] parameter(3) + out1 = f32[] add(acc1, op1) + out2 = f32[] add(acc2, op2) + ROOT result = (f32[], f32[]) tuple(out1, out2) + } + + ENTRY main { + inp1 = f32[100] parameter(0) + inp2 = f32[100] parameter(1) + zero = f32[] constant(0) + + ROOT out = (f32[], f32[]) reduce(inp1, inp2, zero, zero), + dimensions={0}, + to_apply=add + } +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + +XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_argmax) { + absl::string_view hlo_string = R"( + HloModule Reduce_R1x2_to_R0x2_argmax + + argmax { + running_max = f32[] parameter(0) + running_max_idx = u32[] parameter(1) + current_value = f32[] parameter(2) + current_value_idx = u32[] parameter(3) + + current = (f32[], u32[]) tuple(running_max, running_max_idx) + potential = (f32[], u32[]) tuple(current_value, current_value_idx) + + cmp_code = pred[] compare(current_value, running_max), direction=GT + + new_max = f32[] select(cmp_code, current_value, running_max) + new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx) + + ROOT out = (f32[], u32[]) 
tuple(new_max, new_idx) + } + + ENTRY main { + input = f32[100] parameter(0) + idxs = u32[100]{0} iota(), iota_dimension=0 + zero = f32[] constant(0) + zero_idx = u32[] constant(0) + + ROOT out = (f32[], u32[]) reduce( + input, idxs, zero, zero_idx), + dimensions={0}, + to_apply=%argmax + } +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + +XLA_TEST_F(VariadicReduceTest, ReduceMultiOutputVariadicAnd) { + absl::string_view hlo_string = R"( + HloModule VariadicReduceMultiOutput + + VariadicAnd { + value = pred[] parameter(0) + value_idx = u32[] parameter(1) + current_value = pred[] parameter(2) + current_value_idx = u32[] parameter(3) + ROOT out = (pred[], u32[]) tuple(value, value_idx) + } + + ENTRY CheckBuffer { + test_value = f32[] parameter(0) + buffer = f32[100] parameter(1) + value_broadcast = f32[100] broadcast(test_value), dimensions={} + comparison_result = pred[100] compare(buffer, value_broadcast), direction=EQ + true_constant = pred[] constant(true) + + zero_idx = u32[] constant(0) + idxs = u32[100]{0} iota(), iota_dimension=0 + out = (pred[], u32[]) reduce( + comparison_result, idxs, true_constant, zero_idx + ), dimensions={0}, to_apply=VariadicAnd + + ROOT returned = u32[] get-tuple-element(out), index=1 + } +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 30e2d24184a..352b59f248b 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -611,6 +611,12 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, // values. (Technically, the requirement is that the iota length is // relatively prime to all of the dimensions involved in the reduce-window.) input.FillRepeatedIota(0, 137); + // Floating point sum reduction requires higher localized precision. We need + // the following normalization in order to enable testing of kAdd on large + // windows. + input.Each([&](absl::Span /*indices*/, float* value) { + *value = *value / 10000000000.f; + }); Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -626,12 +632,6 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto reducer = param.reducer; - if (use_bfloat16()) { - // To avoid numerical issues, force the reducer to be kMax for bf16 - // inputs. - reducer = kMax; - } - auto computation = reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); @@ -697,15 +697,6 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, - // With non-1x1 window. - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, - /*window_bounds=*/{2, 3, 1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{3, 2, 1, 0}, - /*reducer=*/kAdd}, - // With max instead of add. R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, /*window_bounds=*/{2, 3, 1, 1}, @@ -778,15 +769,6 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, - // With second minor dimension == 9. 
- R4ReduceWindowTestData{/*base_bounds=*/{2, 3, 9, 127}, - /*window_bounds=*/{1, 1, 1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{3, 2, 1, 0}, - /*reducer=*/kAdd}, - // With minor dimension == 129. R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129}, /*window_bounds=*/{1, 1, 1, 1}, @@ -814,7 +796,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, - R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3}, + R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3}, /*window_bounds=*/{1, 64, 64, 1}, /*strides=*/{1, 64, 64, 1}, /*pad_low=*/{0, 0, 0, 0}, @@ -828,6 +810,32 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kMax}, + + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 4, 5}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + + // With 0321 layout. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 4, 5}, + /*strides=*/{1, 2, 3, 4}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{0, 3, 2, 1}, + /*reducer=*/kAdd}, + + // With 0123 layout. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17}, + /*window_bounds=*/{2, 3, 7, 9}, + /*strides=*/{1, 2, 5, 8}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{0, 1, 2, 3}, /*reducer=*/kAdd}, }; @@ -866,6 +874,55 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*pad_high=*/{0, 0, 2, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kMax}, + + // Patterns generated by cumsum/cumprod. + R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, + /*window_bounds=*/{1021, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{1020, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, + /*window_bounds=*/{1, 1, 1021, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 1020, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021}, + /*window_bounds=*/{1, 1, 1, 1021}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 1020}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16}, + /*window_bounds=*/{1021, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{1021, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16}, + /*window_bounds=*/{1, 1, 1021, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 1021, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021}, + /*window_bounds=*/{1, 1, 1, 1021}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 1021}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -874,53 +931,6 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(use_bfloat16_params)), R4ReduceWindowTestDataToString); -class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {}; - -// TODO(b/72234705): Fix the test cases failed on CPU and GPU. 
-XLA_TEST_P(R4ReduceWindowAnyDimsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) { - DoIt(); -} - -const R4ReduceWindowTestData kR4ReduceWindowAnyDimsTestValues[] = { - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, - /*window_bounds=*/{2, 3, 4, 5}, - /*strides=*/{1, 1, 1, 1}, - /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{3, 2, 1, 0}, - /*reducer=*/kAdd}, - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, - /*window_bounds=*/{2, 3, 1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{3, 2, 1, 0}, - /*reducer=*/kMax}, - // With 0321 layout. - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, - /*window_bounds=*/{2, 3, 4, 5}, - /*strides=*/{1, 2, 3, 4}, - /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{0, 3, 2, 1}, - /*reducer=*/kAdd}, - - // With 0123 layout. - R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23}, - /*window_bounds=*/{2, 3, 7, 9}, - /*strides=*/{1, 2, 5, 8}, - /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, - /*layout=*/{0, 1, 2, 3}, - /*reducer=*/kAdd}, -}; - -INSTANTIATE_TEST_CASE_P( - R4ReduceWindowAnyDimsTestInstantiation, R4ReduceWindowAnyDimsTest, - ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowAnyDimsTestValues), - ::testing::ValuesIn(use_bfloat16_params)), - R4ReduceWindowTestDataToString); - struct R3ReduceWindowTestData { int64 base_bounds[3]; int64 window_bounds[3]; @@ -1113,6 +1123,11 @@ struct R2ReduceWindowTestData { {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4}, /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, + // Regression test for b/72234705: bf16 lacks precision to store incremental + // results on very large windows. Using smaller window with minor dim 128. + {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, }; string R2ReduceWindowTestDataToString( @@ -1191,27 +1206,6 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(use_bfloat16_params)), R2ReduceWindowTestDataToString); -class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {}; - -// TODO(b/72234705): Fix the test cases failed on CPU and GPU. -XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, - DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) { - DoIt(); -} - -const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { - {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, - /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, - /*layout=*/{1, 0}, - /*reducer=*/Reducer::kAdd}, -}; - -INSTANTIATE_TEST_CASE_P( - R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test, - ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test), - ::testing::ValuesIn(use_bfloat16_params)), - R2ReduceWindowTestDataToString); - struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; @@ -1321,9 +1315,9 @@ struct R1ReduceWindowTestData { /*reducer=*/Reducer::kMax}, // The pattern generated by exclusive scan (cumsum/cumprod). 
- {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + {/*base_bounds=*/{4095}, /*window_bounds=*/{4095}, /*strides=*/{1}, - /*pad_low=*/{4096}, + /*pad_low=*/{4095}, /*pad_high=*/{0}, /*reducer=*/Reducer::kMax}, }; diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 32de0fdf78f..86d9999b4a4 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -718,5 +718,32 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, ScatterIntoScalar) { + const char* hlo_text = R"( +HloModule ScatterIntoScalar + +update_s32 { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + parameter.1 = s32[] parameter(0) + parameter.2 = s32[0]{0} parameter(1) + parameter.3 = s32[] parameter(2) + ROOT scatter = s32[] scatter(parameter.1, parameter.2, parameter.3), + update_window_dims={}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={}, + index_vector_dim=0, + to_apply=update_s32 +} +)"; + Literal operand = LiteralUtil::CreateR0(1); + Literal scatter_indices = LiteralUtil::CreateR1({}); + Literal updates = LiteralUtil::CreateR0(2); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 0dcb1c42db1..4b3283b5cd7 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -84,7 +84,7 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) { GetParam().window_strides, GetParam().padding_type, source, ConstantR0(&builder_, 0.0f), add_f32_); - ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5)); + ComputeAndCompare(&builder_, {}, ErrorSpec(1e-4)); } INSTANTIATE_TEST_CASE_P( @@ -199,7 +199,10 @@ INSTANTIATE_TEST_CASE_P( SelectAndScatterTestParam{ {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}}, SelectAndScatterTestParam{ - {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}})); + {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}}, + SelectAndScatterTestParam{{1104}, {551}, Padding::kValid, {3}, {2}}, + SelectAndScatterTestParam{ + {1300}, {1171}, Padding::kValid, {130}, {1}})); // Test for F32 1D array, with a zero-element input. XLA_TEST_F(SelectAndScatterTest, R1S0F32) { diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index a9874a91865..4241d813356 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -18,9 +18,8 @@ limitations under the License. #include #include #include -#include -#include "absl/strings/ascii.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "tensorflow/core/platform/logging.h" @@ -31,7 +30,7 @@ namespace { // Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is // disabled - a sequence of regexps. 
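// Illustrative example (the test and case names here are hypothetical): if
// the manifest maps "MyTest.MyCase" to a regexp matching XLA_PLATFORM, then
// PrependDisabledIfIndicated("MyTest", "MyCase/3") below first strips the
// "/3" shard suffix and returns "DISABLED_MyCase"; otherwise it returns
// "MyCase" unchanged.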
-using ManifestT = std::unordered_map>; +using ManifestT = absl::flat_hash_map>; ManifestT ReadManifest() { ManifestT manifest; @@ -68,10 +67,21 @@ ManifestT ReadManifest() { } // namespace -string PrependDisabledIfIndicated(const string& test_case_name, - const string& test_name) { +std::string PrependDisabledIfIndicated(absl::string_view test_case_name, + absl::string_view test_name) { ManifestT manifest = ReadManifest(); + // If the test name ends with a slash followed by one or more digits, strip + // that off; this is just a shard number, and matching on this would be + // unstable even if someone wanted to do it. + static auto* shard_num_pattern = new RE2(R"(/\d+$)"); + tensorflow::RegexpStringPiece suffix; + if (RE2::PartialMatch( + tensorflow::RegexpStringPiece(test_name.data(), test_name.size()), + *shard_num_pattern, &suffix)) { + test_name.remove_suffix(suffix.size()); + } + // First try full match: test_case_name.test_name // If that fails, try to find just the test_case_name; this would disable all // tests in the test case. @@ -79,7 +89,7 @@ string PrependDisabledIfIndicated(const string& test_case_name, if (it == manifest.end()) { it = manifest.find(test_case_name); if (it == manifest.end()) { - return test_name; + return std::string(test_name); } } @@ -88,12 +98,12 @@ string PrependDisabledIfIndicated(const string& test_case_name, string platform_string = XLA_PLATFORM; for (const auto& s : disabled_platforms) { if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { - return "DISABLED_" + test_name; + return absl::StrCat("DISABLED_", test_name); } } // We didn't hit in the disabled manifest entries, so don't disable it. - return test_name; + return std::string(test_name); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 80a6868485c..9636df2ff5f 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -30,6 +30,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/test.h" @@ -68,8 +69,8 @@ namespace xla { // disabled on a particular platform. For a test that should be disabled, // returns DISABLED_ prepended to its name; otherwise returns the test name // unmodified. -string PrependDisabledIfIndicated(const string& test_case_name, - const string& test_name); +std::string PrependDisabledIfIndicated(absl::string_view test_case_name, + absl::string_view test_name); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 4ac3dbd80cf..07dabc2cfaf 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -207,7 +207,12 @@ StatusOr MakeFakeLiteralInternal(const Shape& shape, if (engine == nullptr) { return Literal::CreateFromShape(shape); } - Literal literal(shape); + // Clear tiles/element size in shape's layout before using it for creating + // literal. 
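// (Presumably this is because host-side Literals are stored densely and
// untiled, so tile and element-size annotations copied from a device layout
// would not describe the buffer being populated below.)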
+ Shape new_shape = shape; + new_shape.mutable_layout()->clear_tiles(); + new_shape.mutable_layout()->set_element_size_in_bits(0); + Literal literal(new_shape); switch (shape.element_type()) { case BF16: PopulateWithFloatingPointData(&literal, engine, no_duplicates); @@ -300,7 +305,12 @@ StatusOr MakeFakeLiteralInternalWithBounds(const Shape& shape, if (engine == nullptr) { return Literal::CreateFromShape(shape); } - Literal literal(shape); + // Clear tiles/element size in shape's layout before using it for creating + // literal. + Shape new_shape = shape; + new_shape.mutable_layout()->clear_tiles(); + new_shape.mutable_layout()->set_element_size_in_bits(0); + Literal literal(new_shape); switch (shape.element_type()) { case S8: PopulateWithRandomIntegralDataWithBounds( diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index f68ee04565f..4337aa4bf9a 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -143,7 +143,7 @@ compare { p.0.rhs = f32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) { @@ -174,7 +174,7 @@ compare { p.0.rhs = s32[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) { @@ -205,7 +205,7 @@ compare { p.0.rhs = bf16[] parameter(1) p.1.lhs = s32[] parameter(2) p.1.rhs = s32[] parameter(3) - ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) { diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index b77cf38ed8e..38a2a9b8fba 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -129,7 +129,7 @@ HloModule TokenInWhileLoop %param = (s32[], token[]) parameter(0) %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) + ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT } ENTRY %TokenInWhileLoop () -> s32[] { diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index d6641d257a7..00b72cedbf5 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/stream_pool.h" @@ -34,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" namespace xla { namespace { @@ -117,6 +117,26 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { LiteralTestUtil::ExpectR1Equal(test_vector, result); } +XLA_TEST_F(TransferManagerTest, TransferR1LargeUnalignedF32) { + std::vector test_vector(1025); + std::iota(test_vector.begin(), test_vector.end(), 0); + Shape shape = ShapeUtil::MakeShape(F32, {1024}); + BorrowingLiteral literal(reinterpret_cast(&test_vector[1]), + shape); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + + std::vector expected_output(1024); + std::iota(expected_output.begin(), expected_output.end(), 1); + LiteralTestUtil::ExpectR1Equal(expected_output, result); +} + XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; Literal literal = LiteralUtil::CreateR1U8(test_string); diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 3848ec1684c..3407a68f709 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -308,6 +308,19 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) { + XlaBuilder builder(TestName()); + mutable_debug_options()->set_xla_cpu_enable_fast_math(false); + mutable_debug_options()->set_xla_gpu_enable_fast_min_max(false); + auto low = ConstantR1(&builder, {NAN, 1, 1}); + auto high = ConstantR1(&builder, {3, NAN, 3}); + auto x = ConstantR1(&builder, {2, 2, NAN}); + Clamp(low, x, high); + + std::vector expected = {NAN, NAN, NAN}; + ComputeAndCompareR1(&builder, expected, {}); +} + XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { XlaBuilder builder(TestName()); auto zero = ConstantR0(&builder, 0); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 85212fa56d7..4d80a57ad40 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -1265,7 +1265,7 @@ void BM_WhileLoop(int num_iters) { se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); - StreamExecutorMemoryAllocator allocator(platform, executors); + se::StreamExecutorMemoryAllocator allocator(platform, executors); LocalClient* client = ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 7b7b8f5d02d..b36fc4174ae 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -135,7 +135,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, LocalService* service = ClientLibrary::GetXlaService(client->platform()); Backend* backend = service->mutable_backend(); se::StreamExecutor* executor = backend->default_stream_executor(); - 
DeviceMemoryAllocator* allocator = backend->memory_allocator(); + se::DeviceMemoryAllocator* allocator = backend->memory_allocator(); auto* transfer_manager = backend->transfer_manager(); TF_ASSERT_OK_AND_ASSIGN( StreamPool::Ptr stream_ptr, diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index ebd4bb1e42c..4edd13c79c7 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -72,6 +72,7 @@ cc_library( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//third_party/eigen3", "@com_google_absl//absl/types:span", ], alwayslink = True, @@ -231,6 +232,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc index 8460ae3e499..88f3a8bdde2 100644 --- a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc +++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc @@ -19,7 +19,9 @@ limitations under the License. // // Reads one serilized Hlo module, convert it into JSON format and dump into // some output directory. some_binaray_proto is obtained by serializing Hlo -// module to disk using --xla_dump_optimized_hlo_proto_to debug option. +// module to disk using the debug options +// +// --xla_dump_to=DIR --xla_dump_hlo_as_proto #include #include diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc index 0c7c078b9b9..5652d303f02 100644 --- a/tensorflow/compiler/xla/tools/interactive_graphviz.cc +++ b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -38,6 +38,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/tools/hlo_extractor.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/subprocess.h" @@ -388,22 +390,18 @@ bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) { return false; } -void DisplayGraphHandle(const Options &opts, const string& handle) { - std::cout << handle << std::endl; +void OpenUrl(const Options& opts, absl::string_view url) { + std::cout << url << std::endl; // If it is a url, try to open it up in the user's browser too. - if (absl::StartsWithIgnoreCase(handle, "http://") || - absl::StartsWithIgnoreCase(handle, "https://") || - absl::StartsWithIgnoreCase(handle, "file://")) { + if (absl::StartsWithIgnoreCase(url, "http://") || + absl::StartsWithIgnoreCase(url, "https://") || + absl::StartsWithIgnoreCase(url, "file://")) { const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser" : opts.browser.c_str(); tensorflow::SubProcess p; - p.SetProgram(browser_bin, {browser_bin, handle}); + p.SetProgram(browser_bin, {browser_bin, string(url)}); p.Start(); - } else if (handle.empty()) { - std::cerr << "Unable to render graph, perhaps due to graphviz server " - "timeout. Run with --logtostderr to see." - << std::endl; } else { std::cerr << "\nExpected a URL, but got strange graph result (dumped " "above). 
If this isn't what you expected, maybe file a bug?" @@ -411,6 +409,65 @@ void DisplayGraphHandle(const Options &opts, const string& handle) { } } +// Renders a graph by calling `renderer`, and then tries to open it. +// +// `renderer` is a callback so we can try various formats. In particular, the +// URL format doesn't work out of the box; it requires you to register a plugin. +void RenderAndDisplayGraph( + const Options& opts, + const std::function(RenderedGraphFormat)>& renderer) { + StatusOr url_result = renderer(RenderedGraphFormat::kUrl); + if (url_result.ok()) { + string url = url_result.ValueOrDie(); + OpenUrl(opts, url); + return; + } + + // Ignore UNAVAILABLE errors; these are expected when there's no URL renderer + // plugin registered. + if (url_result.status().code() != tensorflow::error::UNAVAILABLE) { + std::cerr << "Unable to render graph as URL: " << url_result.status() + << std::endl; + std::cerr << "Trying as HTML..." << std::endl; + } + + auto* env = tensorflow::Env::Default(); + StatusOr html_result = renderer(RenderedGraphFormat::kHtml); + if (!html_result.ok()) { + std::cerr << "Failed to render graph as HTML: " << html_result.status() + << std::endl; + return; + } + + std::vector temp_dirs; + env->GetLocalTempDirectories(&temp_dirs); + if (temp_dirs.empty()) { + std::cerr << "Can't render graph as HTML because we can't find a suitable " + "temp directory. Try setting $TMPDIR?" + << std::endl; + return; + } + + // Try to create a unique file inside of temp_dirs.front(). Notably, this + // file's name must end with ".html", otherwise web browsers will treat it as + // plain text, so we can't use Env::CreateUniqueFileName(). + string temp_file_path = tensorflow::io::JoinPath( + temp_dirs.front(), + absl::StrFormat("interactive_graphviz.%d.html", env->NowMicros())); + auto status = tensorflow::WriteStringToFile( + env, temp_file_path, std::move(html_result).ValueOrDie()); + if (status.ok()) { + OpenUrl(opts, absl::StrCat("file://", temp_file_path)); + return; + } + + std::cerr << "Failed to write rendered HTML graph to " << temp_file_path + << ": " << status; + + // We don't bother trying kDot, because kHTML should always work (or if it + // doesn't, we don't have any reason to believe kDot will work better). +} + void DoAllPathsCommand(const Options& opts, const HloModule& module, const std::vector& tokens) { if (tokens.size() > 4) { @@ -451,8 +508,10 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module, std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2]; return; } - DisplayGraphHandle(opts, hlo_graph_dumper::DumpAllPathsFromTo( - *from, *to, max_nodes, /*show_backend_config=*/show_backend_config)); + RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { + return RenderAllPathsFromTo(*from, *to, max_nodes, format, + /*show_backend_config=*/show_backend_config); + }); } // Plot a given instruction neighborhood or computation with graphviz. @@ -513,14 +572,19 @@ void DoPlotCommand(const Options& opts, const HloModule& module, // Generate the graph and print the resulting string, which should be a // graphviz url. 
if (comp) { - DisplayGraphHandle(opts, hlo_graph_dumper::DumpGraph( - *comp, "", comp->parent()->config().debug_options(), nullptr, - /*show_backend_config=*/show_backend_config)); + RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { + return RenderGraph(*comp, /*label=*/"", + comp->parent()->config().debug_options(), format, + /*hlo_execution_profile=*/nullptr, + /*show_backend_config=*/show_backend_config); + }); } else { - DisplayGraphHandle(opts, hlo_graph_dumper::DumpNeighborhoodAround( - *instr, graph_width, - /*show_backend_config=*/show_backend_config, - /*boundary=*/boundary)); + RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { + return RenderNeighborhoodAround( + *instr, graph_width, format, + /*show_backend_config=*/show_backend_config, + /*boundary=*/boundary); + }); } } diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index d66561315b4..257b1ef5c3d 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -15,6 +15,13 @@ limitations under the License. // Usage: replay_computation some_binary_snapshot_proto* // +// Where some_binary_snapshot_proto is [type_prefix:]file_path. Supported +// type_prefixes: +// * recordio_hlo_proto - for a Tensorflow recordio file containing serialized +// xla.HloProtos. +// +// If type_prefix is omitted, the program will make several guesses. +// // Replays computations and shows the results on the command line. // // some_binary_snapshot_proto is obtained by serializing the HloSnapshot from @@ -34,13 +41,17 @@ limitations under the License. // Note: If you pass multiple modules, they will be compiled in parallel but run // in series. +#define EIGEN_USE_THREADS + #include + #include #include #include #include #include "absl/types/span.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -61,6 +72,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -73,6 +86,9 @@ namespace { // Command-line opts to this tool. See main() for descriptions of these // fields. struct Options { + Options() + : intra_op_thread_pool_size(tensorflow::port::NumSchedulableCPUs()) {} + string fake_infeed_shape; string fake_outfeed_shape; @@ -88,6 +104,8 @@ struct Options { bool use_fake_data = false; bool print_result = true; int num_runs = 1; + + int intra_op_thread_pool_size; }; StatusOr> CompileExecutable( @@ -271,7 +289,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, // Run the computation num_runs times, and return the result from the last // execution. 
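// (Each run below also constructs an Eigen thread-pool device and passes it
// to the executable via run_options.set_intra_op_thread_pool(); the pool
// size comes from the new --intra_op_thread_pool_size flag, which defaults
// to the number of schedulable CPUs.)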
const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile(); - StreamExecutorMemoryAllocator allocator( + se::StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); absl::optional final_result; @@ -282,10 +300,16 @@ StatusOr ReplayComputation(const HloSnapshot& module, if (xla_hlo_profile && is_final_result) { LOG(INFO) << "\n\n***** Final run below ******"; } + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", + opts.intra_op_thread_pool_size); + Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(), + pool.NumThreads()); + ExecutionProfile profile; ExecutableRunOptions run_options; run_options.set_execution_profile(&profile); run_options.set_allocator(&allocator); + run_options.set_intra_op_thread_pool(&thread_pool); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, executable->Run(argument_ptrs, run_options)); @@ -306,9 +330,41 @@ StatusOr ReplayComputation(const HloSnapshot& module, return result_literal; } -StatusOr ParseInputFile(const string& filename, - const Options& opts) { +StatusOr> ParseRecordIoFile(absl::string_view filename, + const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); + + std::unique_ptr file; + TF_RETURN_IF_ERROR(env->NewRandomAccessFile( + string(filename.begin(), filename.end()), &file)); + tensorflow::io::RecordReader reader( + file.get(), + tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions("ZLIB")); + + std::vector snapshots; + uint64 offset = 0; + string record; + while (reader.ReadRecord(&offset, &record).ok()) { + HloSnapshot snapshot; + if (snapshot.mutable_hlo()->ParseFromString(record)) { + snapshots.push_back(std::move(snapshot)); + } else { + LOG(ERROR) << "Encountered bad proto"; + } + } + CHECK(!snapshots.empty()) + << "No proto is successfully parsed from the file - the file possibly " + "has a mismatched compression option, format, etc."; + CHECK(opts.use_fake_data) + << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " + "and textual HLO don't carry real data."; + return snapshots; +} + +StatusOr ParseSingleHloFile(const string& filename, + const Options& opts) { + tensorflow::Env* env = tensorflow::Env::Default(); + HloSnapshot snapshot; auto s = tensorflow::ReadBinaryProto(env, filename, &snapshot); if (s.ok()) { @@ -337,24 +393,40 @@ StatusOr ParseInputFile(const string& filename, *snapshot.mutable_hlo()->mutable_hlo_module() = module.ValueOrDie()->ToProto(); return snapshot; + } else { + LOG(ERROR) << module.status(); } fprintf(stderr, "%s: is not HLO text. 
Nothing left to try.\n", filename.c_str()); return InvalidArgument("Could not parse %s.", filename); } +StatusOr> ParseInputFile(const string& filename, + const Options& opts) { + std::vector snapshots; + absl::string_view filename_view = filename; + if (absl::ConsumePrefix(&filename_view, "recordio_hlo_proto:")) { + return ParseRecordIoFile(filename_view, opts); + } + TF_ASSIGN_OR_RETURN(auto snapshot, ParseSingleHloFile(filename, opts)); + return std::vector{std::move(snapshot)}; +} + int RealMain(absl::Span args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; std::vector snapshots; for (char* arg : args) { - StatusOr maybe_snapshot = ParseInputFile(arg, opts); + StatusOr> maybe_snapshot = + ParseInputFile(arg, opts); if (maybe_snapshot.ok()) { - snapshots.push_back(std::move(maybe_snapshot).ValueOrDie()); + auto new_snapshots = std::move(maybe_snapshot).ValueOrDie(); + snapshots.insert(snapshots.end(), + std::make_move_iterator(new_snapshots.begin()), + std::make_move_iterator(new_snapshots.end())); } else { - LOG(ERROR) << "Can't handle file " << arg << ": " - << maybe_snapshot.status(); + LOG(ERROR) << maybe_snapshot.status(); } } @@ -362,10 +434,12 @@ int RealMain(absl::Span args, const Options& opts) { LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel."; std::vector>> executables; { + constexpr size_t kThreadLimits = 100; // ThreadPool CHECK-fails if we give it 0 threads. tensorflow::thread::ThreadPool thread_pool( tensorflow::Env::Default(), tensorflow::ThreadOptions(), - "compile_modules", std::max(size_t{1}, snapshots.size()), + "compile_modules", + std::min(std::max(kThreadLimits, snapshots.size()), 1), /*low_latency_hint=*/false); executables.resize(snapshots.size()); for (int64 i = 0; i < snapshots.size(); ++i) { @@ -378,7 +452,8 @@ int RealMain(absl::Span args, const Options& opts) { for (int64 i = 0; i < executables.size(); ++i) { if (!executables[i].ok()) { - LOG(ERROR) << "Compilation failed: " << executables[i].status(); + LOG(ERROR) << "Compilation failed: " << executables[i].status() << ": " + << snapshots[i].ShortDebugString(); exit_status = EXIT_FAILURE; continue; } @@ -439,6 +514,10 @@ int main(int argc, char** argv) { tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed, "Whether a fake outfeed shape should be derived " "from the computation"), + tensorflow::Flag("intra_op_thread_pool_size", + &opts.intra_op_thread_pool_size, + "How many threads to use in the intra-op thread pool. " + "Defaults to the number of CPUs."), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index bb8bbf57c42..732b7f2efd7 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -16,8 +16,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include + #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -39,23 +41,41 @@ Status WithLogBacktrace(const Status& status) { return status; } -ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled) - : enabled(enabled), label(label) { +ScopedLoggingTimer::ScopedLoggingTimer(const std::string& label, bool enabled, + TimerStats* timer_stats) + : enabled(enabled), label(label), timer_stats(timer_stats) { if (enabled) { start_micros = tensorflow::Env::Default()->NowMicros(); } } -ScopedLoggingTimer::~ScopedLoggingTimer() { +void ScopedLoggingTimer::StopAndLog() { if (enabled) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); double secs = (end_micros - start_micros) / 1000000.0; + TimerStats& stats = *timer_stats; + tensorflow::mutex_lock lock(stats.stats_mutex); + stats.cumulative_secs += secs; + if (secs > stats.max_secs) { + stats.max_secs = secs; + } + stats.times_called++; + LOG(INFO) << label << " time: " - << tensorflow::strings::HumanReadableElapsedTime(secs); + << tensorflow::strings::HumanReadableElapsedTime(secs) + << " (cumulative: " + << tensorflow::strings::HumanReadableElapsedTime( + stats.cumulative_secs) + << ", max: " + << tensorflow::strings::HumanReadableElapsedTime(stats.max_secs) + << ", #called: " << stats.times_called << ")"; + enabled = false; } } +ScopedLoggingTimer::~ScopedLoggingTimer() { StopAndLog(); } + Status AddStatus(Status prior, absl::string_view context) { CHECK(!prior.ok()); return Status{prior.code(), @@ -91,7 +111,7 @@ std::vector InversePermutation( DCHECK(IsPermutation(input_permutation, input_permutation.size())); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { - output_permutation[input_permutation[i]] = i; + output_permutation.at(input_permutation.at(i)) = i; } return output_permutation; } @@ -101,7 +121,7 @@ std::vector ComposePermutations(absl::Span p1, CHECK_EQ(p1.size(), p2.size()); std::vector output; for (size_t i = 0; i < p1.size(); ++i) { - output.push_back(p1[p2[i]]); + output.push_back(p1.at(p2.at(i))); } return output; } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 1754ae0e44f..55b092cfbaa 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -63,6 +64,8 @@ using DimensionVector = absl::InlinedVector; // readable form. This differs from base's ElapsedTimer primarily in that it // spits out the human-readable duration form. // +// Keeps track of global maximum and cumulative times across all invocations. +// // By default, the timing traces are only printed at VLOG(1) and above: // // XLA_SCOPED_LOGGING_TIMER("fooing bar"); // nop if !VLOG_IS_ON(1). @@ -83,9 +86,17 @@ using DimensionVector = absl::InlinedVector; XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) // Helper for macros above. Don't use directly. 
-#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) \ - ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \ - label, VLOG_IS_ON(level)) +#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) \ + static ::xla::TimerStats XLA_TimerStats##counter; \ + ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \ + label, /*enabled=*/VLOG_IS_ON(level), &XLA_TimerStats##counter); + +struct TimerStats { + tensorflow::mutex stats_mutex; + double cumulative_secs GUARDED_BY(stats_mutex) = 0; + double max_secs GUARDED_BY(stats_mutex) = 0; + uint64 times_called GUARDED_BY(stats_mutex) = 0; +}; // RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL // macros above. Recommended usage is via the macros so you don't have to give @@ -93,12 +104,22 @@ using DimensionVector = absl::InlinedVector; struct ScopedLoggingTimer { // The timer does nothing if enabled is false. This lets you pass in your // file's VLOG_IS_ON value. - ScopedLoggingTimer(const string& label, bool enabled); + // + // timer_stats is unowned non-null pointer which is used to populate the + // global timer statistics. + ScopedLoggingTimer(const std::string& label, bool enabled, + TimerStats* timer_stats); + + // Stop the timer and log the tracked time. Timer is disabled after this + // function is called. + void StopAndLog(); + ~ScopedLoggingTimer(); bool enabled; string label; uint64 start_micros; + TimerStats* timer_stats; }; // Given a vector, returns a Span that points at its diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index e001cc35f9f..f2e18311039 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -204,6 +204,14 @@ bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) { window_dim.padding_low() == 0 && window_dim.padding_high() == 0; } +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} + int64 DilatedBound(int64 bound, int64 dilation) { CHECK_GE(bound, 0); CHECK_GE(dilation, 1); diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 099d7ecdd5c..e7099285c34 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -62,6 +62,10 @@ bool AllOrNoneReversed(const Window& window); // has window bound 1, no striding and no padding. bool IsInactiveWindowDimension(const Window& window, int64 logical_dim); +// Returns true if the provided window dimension is trivial (inactive and has no +// dilation) +bool IsTrivialWindowDimension(const WindowDimension& window_dimension); + // Returns the new bound after dilation. // // If a window with the given bound in some dimension is dilated with the given diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index cda2d7c7c6b..16ef0caf29b 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -8,7 +8,10 @@ load( "//tensorflow/core:platform/default/build_config_root.bzl", "if_static", ) -load("//tensorflow:tensorflow.bzl", "if_cuda_is_configured") +load( + "//tensorflow/core:platform/default/cuda_build_defs.bzl", + "if_cuda_is_configured", +) # xla_proto_library() is a convenience wrapper around cc_proto_library. 
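For readers unfamiliar with the pattern above, the following self-contained sketch (standard C++ only, with illustrative names rather than the XLA classes) shows an RAII timer that folds each measurement into a shared statistics object under a mutex, which is essentially what `ScopedLoggingTimer` now does with its per-call-site static `TimerStats`.

```c++
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <mutex>

struct TimerStats {
  std::mutex mu;
  double cumulative_secs = 0;
  double max_secs = 0;
  std::uint64_t times_called = 0;
};

class ScopedTimer {
 public:
  ScopedTimer(const char* label, TimerStats* stats)
      : label_(label), stats_(stats),
        start_(std::chrono::steady_clock::now()) {}

  ~ScopedTimer() {
    const double secs =
        std::chrono::duration<double>(std::chrono::steady_clock::now() -
                                      start_).count();
    std::lock_guard<std::mutex> lock(stats_->mu);
    stats_->cumulative_secs += secs;
    stats_->max_secs = std::max(stats_->max_secs, secs);
    ++stats_->times_called;
    std::cout << label_ << ": " << secs << "s (cumulative "
              << stats_->cumulative_secs << "s, max " << stats_->max_secs
              << "s, #called " << stats_->times_called << ")\n";
  }

 private:
  const char* label_;
  TimerStats* stats_;
  std::chrono::steady_clock::time_point start_;
};

void Compile() {
  static TimerStats stats;  // one stats object per timed call site
  ScopedTimer timer("compile", &stats);
  // ... work being timed ...
}
```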
def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = 0, **kwargs): diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 925fcbf88c1..43666758e64 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -61,40 +61,12 @@ message HloReducePrecisionOptions { // Debugging options for XLA. These options may change at any time - there are // no guarantees about backward or forward compatibility for these fields. message DebugOptions { - // HLO modules matching this regex will be dumped to a .dot file throughout - // various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to - // dump *all* HLO modules. - string xla_generate_hlo_graph = 1; - // Show addresses of HLO ops in graph dump. bool xla_hlo_graph_addresses = 2; - // Path to dump HLO graphs to. - string xla_hlo_graph_path = 4; - - reserved 5; // Was xla_hlo_dump_as_graphdef - - // HLO modules matching this regex will be dumped to LOG(INFO). Set to ".*" to - // dump *all* HLO modules. - string xla_log_hlo_text = 6; - - // Dump all HLO modules as text into the provided directory path. - string xla_generate_hlo_text_to = 7; - - // Dump Hlo after all hlo passes are executed as proto binary into this - // directory. - string xla_dump_optimized_hlo_proto_to = 8; - // Instrument the computation to collect per-HLO cycle counts. bool xla_hlo_profile = 9; - // Dumps computations that XLA executes into the provided directory path. - string xla_dump_computations_to = 10; - - // Dumps parameters and results of computations that XLA executes into the - // provided directory path. - string xla_dump_executions_to = 11; - // List of HLO passes to disable. These names must exactly match the pass // names as specified by the HloPassInterface::name() method. repeated string xla_disable_hlo_passes = 30; @@ -114,9 +86,6 @@ message DebugOptions { // Embed the compiler IR as a string in the executable. bool xla_embed_ir_in_executable = 33; - // Dump the compiler IR into this directory as individual files. - string xla_dump_ir_to = 34; - // Eliminate implicit broadcasts when lowering user computations to HLO // instructions; use explicit broadcast instead. bool xla_eliminate_hlo_implicit_broadcast = 35; @@ -176,14 +145,6 @@ message DebugOptions { // ops. bool xla_gpu_use_cudnn_batchnorm = 94; - // Dump HLO before any hlo passes are executed as proto binary into this - // directory. - string xla_dump_unoptimized_hlo_proto_to = 95; - - // Dump HLO after each pass as an HloProto in binary file format into this - // directory. - string xla_dump_per_pass_hlo_proto_to = 96; - // Generate calls to MKL-DNN in the CPU backend. bool xla_cpu_use_mkl_dnn = 97; @@ -195,20 +156,37 @@ message DebugOptions { // // - Reducing the precision of operations (e.g. using an approximate sin // function, or transforming x/y into x * (1/y)). - // - Assuming that operations never produce or consume NaN or +/- Inf. + // - Assuming that operations never produce or consume NaN or +/- Inf (this + // behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}). // - Assuming that +0 and -0 are indistinguishable. bool xla_cpu_enable_fast_math = 99; + // When xla_cpu_enable_fast_math is true then this controls whether we allow + // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is + // false. + bool xla_cpu_fast_math_honor_nans = 120; + + // When xla_cpu_enable_fast_math is true then this controls whether we allow + // operations to produce infinites. 
Ignored when xla_cpu_enable_fast_math is + // false. + bool xla_cpu_fast_math_honor_infs = 121; + // When true we lower the Minimum and Maximum hlos in the GPU backend such // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag // this is true we don't propagate NaNs through Min and Max. bool xla_gpu_enable_fast_min_max = 100; + // Allows xla to increase the output precision of floating point operations. + bool xla_allow_excess_precision = 122; + // Crashes the program when any kind of verification fails, instead of just // logging the failures. One example is cross checking of convolution results // among different algorithms. bool xla_gpu_crash_on_verification_failures = 101; + // Disable GEMM and Convolution auto-tuning. + bool xla_gpu_disable_autotune = 123; + // Force the host platform to pretend that there are these many host // "devices". All these devices are backed by the same threadpool. Defaults // to 1. @@ -221,9 +199,6 @@ message DebugOptions { // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). bool xla_gpu_disable_ptxas_optimizations = 103; - // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) - bool xla_hlo_dump_as_html = 105; - // Enable fast math with eigen in the HLO evaluator. bool xla_hlo_evaluator_use_fast_path = 106; @@ -232,14 +207,18 @@ message DebugOptions { bool xla_allow_scalar_index_dynamic_ops = 107; enum StepMarkerLocation { - // Generate step mark at each iteration of top level while loop, which - // is assumed to be a training loop. This is the default. + // Generate a step marker at the program entry. This handles the case where + // each step is done by one or multiple program execution(s). Only the first + // program will be tagged for generating a step marker at the program entry. + // This is the default. STEP_MARK_AT_ENTRY = 0; - // Generate step mark at program entry. This handles the case where each - // step are done by one or multiple programs execution. Only the first - // program will be tagged for generating step mark at program entry. + // Generate a step marker at each iteration of the top level while loop, + // which is assumed to be a training loop. STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1; - // No step mark. + // Generate a step marker at each iteration of the second level while loops, + // which is assumed to be a training or eval loop. + STEP_MARK_AT_SECOND_LEVEL_WHILE_LOOP = 3; + // No step marker generated. STEP_MARK_NONE = 2; } // Option to emit a target-specific marker to indicate the start of a training @@ -247,11 +226,59 @@ message DebugOptions { // value. StepMarkerLocation xla_step_marker_location = 108; - // Next id: 109 + // + // BEGIN flags controlling dumping HLO modules for debugging. + // + // When dumping is enabled, HLO modules dumped at the very beginning and end + // of compilation, and optionally also during the pass pipeline. + // + // In general, if you set one of these flags, we will try to infer reasonable + // defaults for the others. For example: + // + // * Setting --xla_dump_to=/tmp/foo without specifying a format + // with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text. + // + // * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will + // dump to stdout. + // + + // Directory to dump into. + string xla_dump_to = 109; + + // If specified, will only dump modules which match this regexp. + string xla_dump_hlo_module_re = 110; + + // If this flag is specified, will also HLO before and after passes that match + // this regular expression. 
Set to .* to dump before/after all passes. + string xla_dump_hlo_pass_re = 111; + + // Specifies the format that HLO is dumped in. Multiple of these may be + // specified. + bool xla_dump_hlo_as_text = 112; + bool xla_dump_hlo_as_proto = 113; + bool xla_dump_hlo_as_dot = 114; + bool xla_dump_hlo_as_url = 115; + + // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + bool xla_dump_hlo_as_html = 116; + + // If true, every time an HLO module is run, we will dump an HloSnapshot + // (essentially, a serialized module plus its inputs) to the --xla_dump_to + // directory. + bool xla_dump_hlo_snapshots = 118; + + // + // END flags controlling dumping HLO modules. + // + + // Next id: 124 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; + + reserved 117; // was xla_dump_to + reserved 5; // Was xla_hlo_dump_as_graphdef } // These settings control how XLA compiles and/or runs code. Not all settings @@ -282,6 +309,10 @@ message ExecutionOptions { // Number of replicas of the computation to run. If zero, uses the default // number of replicas for the XLA service. int32 num_replicas = 6; + + // This optional field specifies the device assignment if known at compile + // time. + DeviceAssignmentProto device_assignment = 7; } message GetDeviceHandlesRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 6e5772a7396..67f76d00703 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -71,7 +71,10 @@ enum PrimitiveType { // An opaque type used for passing context-specific data to a custom // operation. Shapes of this primitive type will have empty dimensions and // tuple_shapes fields. - OPAQUE = 14; + // + // (OPAQUE would be a better name for this identifier, but that conflicts with + // a macro defined in windows.h.) + OPAQUE_TYPE = 14; // A token type threaded between side-effecting operations. Shapes of this // primitive type will have empty dimensions and tuple_shapes fields. 
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index b2718c5c283..acd984f9e99 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -54,6 +54,7 @@ cc_library( "xrt_util.h", ], deps = [ + ":xrt_proto", "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:debug_options_flags", @@ -66,13 +67,13 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:backend", - "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/stream_executor", + "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", diff --git a/tensorflow/compiler/xrt/client/BUILD b/tensorflow/compiler/xrt/client/BUILD new file mode 100644 index 00000000000..3908f026bcf --- /dev/null +++ b/tensorflow/compiler/xrt/client/BUILD @@ -0,0 +1,115 @@ +# Description: Operations defined for XRT + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//tensorflow:internal", + "//tensorflow/compiler:__subpackages__", + ], +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + +cc_library( + name = "xrt_grpc_eager_client", + srcs = ["xrt_grpc_eager_client.cc"], + hdrs = ["xrt_grpc_eager_client.h"], + deps = [ + "//tensorflow:grpc++", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:eager_service_proto_cc", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/distributed_runtime:call_options", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_client_cq_tag", + "//tensorflow/core/distributed_runtime/rpc:grpc_state", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "xrt_tf_client", + srcs = ["xrt_tf_client.cc"], + hdrs = ["xrt_tf_client.h"], + deps = [ + ":xrt_grpc_eager_client", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:eager_service_proto_cc", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/distributed_runtime:call_options", + "//tensorflow/core/distributed_runtime:request_id", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "xrt_client", + srcs = ["xrt_client.cc"], + hdrs = ["xrt_client.h"], + deps = [ + ":xrt_tf_client", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xrt:xrt_proto", + 
"//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf/tpu:topology_proto_cc", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +tf_cc_test( + name = "xrt_client_test", + srcs = ["xrt_client_test.cc"], + data = [":xrt_testlib_server"], + deps = [ + ":xrt_client", + ":xrt_grpc_eager_client", + ":xrt_tf_client", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:eager_service_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_session", + "//tensorflow/core/distributed_runtime/rpc:grpc_testlib", + ], +) + +tf_cc_binary( + name = "xrt_testlib_server", + testonly = 1, + deps = [ + "//tensorflow/compiler/jit:xla_cpu_device", + "//tensorflow/compiler/xrt:xrt_server", + "//tensorflow/core/distributed_runtime/rpc:grpc_testlib_server_main", + ], +) diff --git a/tensorflow/compiler/xrt/client/xrt_client.cc b/tensorflow/compiler/xrt/client/xrt_client.cc new file mode 100644 index 00000000000..0c19b0dcee3 --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_client.cc @@ -0,0 +1,654 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xrt/client/xrt_client.h" + +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xrt/client/xrt_tf_client.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/tensor_coding.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" + +namespace tensorflow { + +namespace { + +// Deserializes a TensorProto containing a scalar string value. 
+xla::StatusOr DeserializeTensorProtoAsString( + const TensorProto& proto) { + if (proto.dtype() != DT_STRING) { + return errors::InvalidArgument("Tensors must be of type DT_STRING, got ", + DataType_Name(proto.dtype())); + } + if (proto.tensor_shape().dim_size() != 0 || + proto.tensor_shape().unknown_rank()) { + return errors::InvalidArgument("String tensor must be a scalar, got ", + proto.tensor_shape().DebugString()); + } + if (proto.string_val_size() > 0) { + if (proto.string_val_size() != 1) { + return errors::InvalidArgument( + "Expected at most one string_val in TensorProto, got ", + proto.string_val_size()); + } + return proto.string_val(0); + } else { + std::string data; + port::DecodeStringList(proto.tensor_content(), &data, 1); + return data; + } +} + +// Deserializes a xla::Literal from a TensorProto. +xla::StatusOr DeserializeTensorProtoAsLiteral( + const TensorProto& proto) { + TF_ASSIGN_OR_RETURN(std::string data, DeserializeTensorProtoAsString(proto)); + xla::LiteralProto literal_proto; + literal_proto.ParsePartialFromString(data); + return xla::Literal::CreateFromProto(literal_proto); +} + +} // namespace + +XrtBuffer::XrtBuffer(XrtTensorHandle handle, int xrt_device_ordinal, + xla::Shape shape) + : handle_(std::move(handle)), + xrt_device_ordinal_(xrt_device_ordinal), + shape_(std::move(shape)) {} + +XrtBuffer::~XrtBuffer() { Delete(); } + +/*static*/ xla::StatusOr> XrtBuffer::FromLiteral( + const std::shared_ptr& context, int xrt_device_ordinal, + const xla::LiteralSlice& literal) { + xrt::XLAAllocation allocation; + *allocation.mutable_value() = literal.ToProto(); + + auto proto = absl::make_unique(); + proto->set_dtype(DT_STRING); + allocation.SerializeToString(proto->add_string_val()); + + if (xrt_device_ordinal < 0 || + xrt_device_ordinal >= context->tf_device_ids().size()) { + return errors::InvalidArgument("Invalid XRT device ordinal ", + xrt_device_ordinal); + } + int tf_device_id = context->tf_device_ids().at(xrt_device_ordinal); + XrtTensorHandle literal_handle = + context->tf_context()->SendTensor(std::move(proto), tf_device_id, + /*host_memory=*/true); + + XrtTensorHandle buffer_handle = std::move(context->tf_context()->EnqueueOp( + "XRTAllocate", {&literal_handle}, /*output_arity=*/1, /*attrs=*/{}, + tf_device_id)[0]); + + return std::make_shared(std::move(buffer_handle), + xrt_device_ordinal, literal.shape()); +} + +/*static*/ xla::StatusOr> XrtBuffer::MakeTuple( + const std::shared_ptr& context, + const std::vector>& elements) { + if (elements.empty()) { + return errors::Unimplemented( + "The arity zero case of MakeTuple is not implemented."); + } + int xrt_device_ordinal = elements[0]->xrt_device_ordinal(); + int tf_device_id = elements[0]->handle().device_id(); + xrt::XLATupleNode tuple_description; + std::vector element_shapes; + element_shapes.reserve(elements.size()); + for (int index = 0; index < elements.size(); ++index) { + xrt::XLATupleNode* node = tuple_description.add_tuples(); + node->set_input_index(index); + element_shapes.push_back(elements[index]->shape()); + if (elements[index]->handle().device_id() != tf_device_id) { + return errors::InvalidArgument( + "All elements of tuple must be on the same device ( ", + elements[index]->handle().device_id(), " vs. 
", tf_device_id, ")"); + } + } + auto proto = absl::make_unique(); + proto->set_dtype(DT_STRING); + tuple_description.SerializeToString(proto->add_string_val()); + + XrtTensorHandle description_handle = + context->tf_context()->SendTensor(std::move(proto), tf_device_id, + /*host_memory=*/true); + + protobuf::Map attrs; + attrs["Ninputs"] = MakeAttrValue(elements.size()); + + std::vector args; + args.reserve(elements.size() + 1); + args.push_back(&description_handle); + for (const auto& element : elements) { + args.push_back(&element->handle()); + } + XrtTensorHandle buffer_handle = std::move(context->tf_context()->EnqueueOp( + "XRTMakeTuple", args, /*output_arity=*/1, attrs, tf_device_id)[0]); + return std::make_shared( + std::move(buffer_handle), xrt_device_ordinal, + xla::ShapeUtil::MakeTupleShape(element_shapes)); +} + +xla::StatusOr XrtBuffer::ToLiteral() const { + TF_RET_CHECK(handle_.valid()); + XrtTensorHandle literal_handle = std::move(handle_.context()->EnqueueOp( + "XRTReadLiteral", {&handle_}, /*output_arity=*/1, /*attrs=*/{}, + handle_.device_id())[0]); + + std::shared_ptr future = + handle_.context()->RecvTensor(literal_handle, DT_STRING, + /*host_memory=*/true); + + // Flush the queue to make sure the producers are dispatched before blocking + // on the future. + handle_.context()->FlushQueue(); + + TF_ASSIGN_OR_RETURN(RecvTensorResponse * response, future->Get()); + VLOG(10) << "ToLiteral received tensor " << response->DebugString(); + TF_RET_CHECK(!response->is_dead()); + return DeserializeTensorProtoAsLiteral(response->tensor()); +} + +void XrtBuffer::Delete() { + if (handle_.valid()) { + handle_.context()->EnqueueOp("XRTReleaseAllocationHandle", {&handle_}, + /*output_arity=*/0, + /*attrs=*/{}, handle_.device_id()); + handle_ = XrtTensorHandle(); + } +} + +xla::StatusOr>> +XrtBuffer::DestructureTuple() { + TF_RET_CHECK(shape_.IsTuple()); + std::vector> output; + output.reserve(shape_.tuple_shapes().size()); + for (int i = 0; i < shape_.tuple_shapes().size(); ++i) { + TensorProto index_proto; + index_proto.set_dtype(DT_INT32); + index_proto.mutable_tensor_shape()->add_dim()->set_size(1); + index_proto.add_int_val(i); + XrtTensorHandle index = + EnqueueConst(handle_.context().get(), handle_.device_id(), index_proto, + /*host_memory=*/true); + XrtTensorHandle sub = std::move( + handle_.context()->EnqueueOp("XRTSubTuple", {&handle_, &index}, + /*output_arity=*/1, + /*attrs=*/{}, handle_.device_id())[0]); + output.push_back(std::make_shared( + std::move(sub), xrt_device_ordinal_, shape_.tuple_shapes(i))); + } + return output; +} + +/*static*/ xla::StatusOr> XrtExecutable::Compile( + std::shared_ptr context, + const xla::HloModuleProto& hlo_module_proto, + const std::vector& argument_shapes, + const xla::Shape& result_shape, xla::DeviceAssignment device_assignment) { + if (device_assignment.replica_count() <= 0 || + device_assignment.computation_count() <= 0) { + return errors::InvalidArgument( + "Device assignment must be non-empty; got ", + device_assignment.replica_count(), " replicas and ", + device_assignment.computation_count(), " computations per replica."); + } + + // TODO(phawkins): add support for per-core argument and return shapes. 
+ TF_RET_CHECK(device_assignment.computation_count() == 1) + << "Computation count != 1 not implemented"; + + xrt::XLAComputation computation; + computation.mutable_config()->set_num_replicas( + device_assignment.replica_count()); + computation.mutable_config()->set_num_cores_per_replica( + device_assignment.computation_count()); + + xrt::DeviceAssignment* xrt_assignment = + computation.mutable_config()->mutable_device_assignment(); + for (int computation = 0; computation < device_assignment.computation_count(); + ++computation) { + xrt::DeviceAssignment::ComputationDevice* xrt_devices = + xrt_assignment->add_computation_devices(); + for (int replica = 0; replica < device_assignment.replica_count(); + ++replica) { + int xrt_device_ordinal = device_assignment(replica, computation); + if (xrt_device_ordinal < 0 || + xrt_device_ordinal >= context->tf_device_ids().size()) { + return errors::InvalidArgument("Invalid device ordinal in device ", + "assignment: ", xrt_device_ordinal); + } + *xrt_devices->add_replica_devices() = + context->device_mesh_coordinates().at(xrt_device_ordinal); + } + } + + xla::ProgramShape program_shape; + for (const xla::Shape& shape : argument_shapes) { + xla::Shape* param_shape = program_shape.add_parameters(); + *param_shape = shape; + if (!xla::LayoutUtil::HasLayout(shape)) { + xla::LayoutUtil::SetToDefaultLayout(param_shape); + } + } + *program_shape.mutable_result() = result_shape; + if (!xla::LayoutUtil::HasLayout(result_shape)) { + xla::LayoutUtil::SetToDefaultLayout(program_shape.mutable_result()); + } + *computation.mutable_config()->mutable_program_shape() = + program_shape.ToProto(); + *computation.mutable_hlo_snapshot()->mutable_hlo()->mutable_hlo_module() = + hlo_module_proto; + + auto proto = absl::make_unique(); + proto->set_dtype(DT_STRING); + computation.SerializeToString(proto->add_string_val()); + + int xrt_device_ordinal_for_compilation = device_assignment(0, 0); + int tf_device_id = + context->tf_device_ids().at(xrt_device_ordinal_for_compilation); + XrtTensorHandle computation_handle = + context->tf_context()->SendTensor(std::move(proto), tf_device_id, + /*host_memory=*/true); + + XrtTensorHandle executable_handle = + std::move(context->tf_context()->EnqueueOp( + "XRTCompile", {&computation_handle}, /*output_arity=*/2, /*attrs=*/{}, + tf_device_id)[0]); + + if (device_assignment.num_elements() > 1) { + string wire_id = XrtGetUniqueWireID(); + int recv_tf_device_id = context->tf_context()->cpu_device_id(); + EnqueueSend(context->tf_context().get(), executable_handle, DT_INT64, + recv_tf_device_id, wire_id, /*host_memory=*/true); + executable_handle = + EnqueueRecv(context->tf_context().get(), DT_INT64, tf_device_id, + recv_tf_device_id, wire_id, /*host_memory=*/true); + } + + return std::make_shared( + std::move(context), std::move(executable_handle), program_shape, + std::move(device_assignment)); +} + +XrtExecutable::XrtExecutable(std::shared_ptr context, + XrtTensorHandle handle, xla::ProgramShape shape, + xla::DeviceAssignment device_assignment) + : context_(std::move(context)), + handle_(std::move(handle)), + shape_(std::move(shape)), + device_assignment_(std::move(device_assignment)) {} + +XrtExecutable::~XrtExecutable() { Delete(); } + +void XrtExecutable::Delete() { + if (handle_.valid()) { + handle_.context()->EnqueueOp("XRTReleaseCompilationHandle", {&handle_}, + /*output_arity=*/0, + /*attrs=*/{}, handle_.device_id()); + handle_ = XrtTensorHandle(); + } +} + +xla::StatusOr> XrtExecutable::Execute( + const std::vector>& args) { + 
TF_RET_CHECK(device_assignment_.replica_count() == 1 && + device_assignment_.computation_count() == 1) + << device_assignment_.ToString(); + int xrt_device_ordinal = device_assignment_(0, 0); + int tf_device_id = context_->tf_device_ids().at(xrt_device_ordinal); + + TensorProto config_proto; + config_proto.set_dtype(DT_STRING); + config_proto.add_string_val(); + XrtTensorHandle execution_config_handle = + EnqueueConst(handle_.context().get(), tf_device_id, config_proto, + /*host_memory=*/true); + + protobuf::Map attrs; + attrs["Ninputs"] = MakeAttrValue(args.size()); + + std::vector inputs; + inputs.reserve(args.size() + 2); + inputs.push_back(&handle_); + inputs.push_back(&execution_config_handle); + for (const std::shared_ptr& arg : args) { + if (arg->handle().device_id() != tf_device_id) { + return errors::InvalidArgument( + "Input buffer to Execute() is not on the device for which the " + "computation was compiled. Target device is ", + tf_device_id, ", buffer is on device ", arg->handle().device_id()); + } + inputs.push_back(&arg->handle()); + } + + XrtTensorHandle result_handle = std::move(handle_.context()->EnqueueOp( + "XRTExecute", inputs, /*output_arity=*/1, attrs, tf_device_id)[0]); + + return std::make_shared(std::move(result_handle), + xrt_device_ordinal, shape_.result()); +} + +xla::StatusOr>> +XrtExecutable::ExecuteReplicated( + absl::Span>> args) { + if (args.size() != device_assignment_.computation_count()) { + return errors::InvalidArgument( + "Mismatched number of computation per replica between executable and " + "arguments. Expected computations_per_replica=", + device_assignment_.computation_count(), + "; got computations_per_replica=", args.size()); + } + + for (int computation = 0; + computation < device_assignment_.computation_count(); ++computation) { + if (args[computation].n1() != device_assignment_.replica_count()) { + return errors::InvalidArgument( + "Mismatched number of replicas between executable and arguments for " + " computation ", + computation, + ". Expected replicas=", device_assignment_.replica_count(), + "; got replicas=", args[computation].n1()); + } + for (int replica = 0; replica < device_assignment_.replica_count(); + ++replica) { + int xrt_device_ordinal = device_assignment_(replica, computation); + int tf_device_id = context_->tf_device_ids().at(xrt_device_ordinal); + for (int arg = 0; arg < args[computation].n2(); ++arg) { + const std::shared_ptr& buffer = + args[computation](replica, arg); + if (buffer->handle().device_id() != tf_device_id) { + return errors::InvalidArgument( + "Input buffer to ExecuteReplicated() is not on the device for " + "which the computation was compiled. 
Target device is ", + tf_device_id, ", buffer is on device ", + buffer->handle().device_id()); + } + } + } + } + + std::vector input_arity; + input_arity.reserve(args.size()); + for (const auto& arg : args) { + input_arity.push_back(arg.n2()); + } + TF_ASSIGN_OR_RETURN(string exec_fn, context_->GetExecuteReplicatedFunction( + input_arity, device_assignment_)); + + std::vector input_types; + std::vector inputs; + inputs.push_back(&handle_); + input_types.push_back(DT_INT64); + + std::vector execution_config_handles( + device_assignment_.computation_count()); + int tf_cpu_device_id = context_->tf_context()->cpu_device_id(); + for (int j = 0; j < device_assignment_.computation_count(); ++j) { + TensorProto config_proto; + config_proto.set_dtype(DT_STRING); + xrt::XRTExecutionConfig config; + config.set_core_index_in_replica(j); + config_proto.add_string_val(config.SerializeAsString()); + execution_config_handles[j] = EnqueueConst(context_->tf_context().get(), + tf_cpu_device_id, config_proto, + /*host_memory=*/true); + inputs.push_back(&execution_config_handles[j]); + input_types.push_back(DT_STRING); + } + + for (int i = 0; i < device_assignment_.replica_count(); ++i) { + for (int j = 0; j < device_assignment_.computation_count(); ++j) { + for (int k = 0; k < args[j].n2(); ++k) { + inputs.push_back(&args[j](i, k)->handle()); + input_types.push_back(DT_INT64); + } + } + } + + // Run all the XRTExecute ops in parallel using a multi-device function. + // We do this for two reasons: + // a) we need the operators to run in parallel, but without async mode enabled + // they might not. + // b) we need the operators to all be issued as part of the same + // EnqueueRequest batch, otherwise we will deadlock. + // TODO(phawkins): It would be even better to enable async mode, when its + // error semantics have been improved. + std::vector output_types(device_assignment_.num_elements(), + DT_INT64); + std::vector outputs = context_->tf_context()->EnqueueOp( + exec_fn, inputs, /*output_arity=*/output_types.size(), /*attrs=*/{}, + tf_cpu_device_id); + + xla::Array2D> results( + device_assignment_.computation_count(), + device_assignment_.replica_count()); + int output_num = 0; + for (int i = 0; i < device_assignment_.computation_count(); ++i) { + for (int j = 0; j < device_assignment_.replica_count(); ++j) { + int xrt_device_ordinal = device_assignment_(j, i); // NB. different order + int tf_device_id = context_->tf_device_ids().at(xrt_device_ordinal); + + // EnqueueOp doesn't know about multidevice functions, so it will assume + // that the outputs are on the CPU. Override the device IDs it assigned; + // we know better. + outputs[output_num].set_device_id(tf_device_id); + + // TODO(phawkins): use a per-core result shape here. + results(i, j) = std::make_shared( + std::move(outputs[output_num]), xrt_device_ordinal, shape_.result()); + ++output_num; + } + } + return results; +} + +/*static*/ xla::StatusOr> XrtContext::Create( + std::shared_ptr tf_context, string device_type) { + auto context = std::make_shared(tf_context, device_type); + if (context->tf_device_ids().empty()) { + return errors::NotFound("No accelerator devices of type ", device_type, + " are present."); + } + if (device_type == "TPU") { + TF_RETURN_IF_ERROR(context->InitializeTPU()); + } else { + // Fill in a dummy topology mapping for CPU/GPU. 
+ for (int i = 0; i < context->tf_device_ids().size(); ++i) { + context->device_mesh_coordinates_.push_back({}); + context->device_mesh_coordinates_.back().add_value(i); + } + } + return context; +} + +XrtContext::XrtContext(std::shared_ptr tf_context, + string device_type) + : tf_context_(std::move(tf_context)), device_type_(std::move(device_type)) { + for (int i = 0; i < tf_context_->devices().size(); ++i) { + const DeviceAttributes& device = tf_context_->devices()[i]; + VLOG(2) << "Device: " << i << ": " << device.DebugString(); + if (device.device_type() == device_type_) { + tf_device_ids_.push_back(i); + VLOG(1) << "Accelerator device " << i << ": " << device.name(); + } + } +} + +int XrtContext::device_count() const { return tf_device_ids_.size(); } + +static Status RegisterTPUInitializeFunction(XrtTfContext* context) { + FunctionDef fdef; + OpDef* opdef = fdef.mutable_signature(); + opdef->set_name("TPUInitFunc"); + OpDef::ArgDef* outdef = opdef->add_output_arg(); + outdef->set_name("topology"); + outdef->set_type(DT_STRING); + + NodeDef* ndef = fdef.add_node_def(); + ndef->set_name("n"); + ndef->set_op("ConfigureDistributedTPU"); + + (*fdef.mutable_ret())["topology"] = "n:topology"; + + Status status = context->RegisterFunction(fdef); + VLOG(10) << "RegisterTPUInitializeFunction returned " << status; + return status; +} + +Status XrtContext::InitializeTPU() { + LOG(INFO) << "Initializing TPU devices."; + TF_RETURN_IF_ERROR(RegisterTPUInitializeFunction(tf_context_.get())); + + TensorProto index_proto; + index_proto.set_dtype(DT_INT32); + index_proto.add_int_val(0); + XrtTensorHandle device_ordinal = EnqueueConst( + tf_context_.get(), /*device_id=*/tf_context_->cpu_device_id(), + index_proto, /*host_memory=*/false); + + protobuf::Map attrs; + attrs["f"].mutable_func()->set_name("TPUInitFunc"); + attrs["Tin"].mutable_list(); + attrs["Tout"].mutable_list()->add_type(DT_STRING); + XrtTensorHandle t = std::move( + tf_context_->EnqueueOp("TPUPartitionedCall", {&device_ordinal}, + /*output_arity=*/1, + /*attrs=*/attrs, tf_context_->cpu_device_id())[0]); + + auto result = tf_context_->RecvTensor(t, DT_STRING, /*host_memory=*/false); + TF_ASSIGN_OR_RETURN(RecvTensorResponse * response, result->Get()); + VLOG(10) << "TPU topology " << response->DebugString(); + + TF_ASSIGN_OR_RETURN(std::string data, + DeserializeTensorProtoAsString(response->tensor())); + + tpu::TopologyProto tpu_topology; + tpu_topology.ParsePartialFromString(data); + VLOG(4) << "TPU topology:\n" << tpu_topology.DebugString(); + + TF_RET_CHECK(tpu_topology.num_tasks() == 1) << tpu_topology.DebugString(); + TF_RET_CHECK(tpu_topology.num_tpu_devices_per_task() == tf_device_ids_.size()) + << tpu_topology.DebugString() << " " << tf_device_ids_.size(); + + const int mesh_rank = tpu_topology.mesh_shape_size(); + TF_RET_CHECK(tpu_topology.device_coordinates_size() == + tf_device_ids_.size() * mesh_rank); + + for (int i = 0; i < tf_device_ids_.size(); ++i) { + device_mesh_coordinates_.push_back({}); + auto& coords = device_mesh_coordinates_.back(); + for (int j = 0; j < mesh_rank; ++j) { + coords.add_value(tpu_topology.device_coordinates(i * mesh_rank + j)); + } + } + + LOG(INFO) << "TPU initialization succeeded."; + return Status::OK(); +} + +XrtContext::ExecuteReplicatedKey::ExecuteReplicatedKey( + absl::Span input_arity, xla::DeviceAssignment device_assignment) + : input_arity(input_arity.begin(), input_arity.end()), + device_assignment(std::move(device_assignment)) {} + +bool XrtContext::ExecuteReplicatedKey::operator==( + 
const ExecuteReplicatedKey& other) const { + return input_arity == other.input_arity && + device_assignment == other.device_assignment; +} + +xla::StatusOr XrtContext::GetExecuteReplicatedFunction( + absl::Span input_arity, + const xla::DeviceAssignment& device_assignment) { + ExecuteReplicatedKey key(input_arity, device_assignment); + + absl::MutexLock lock(&mu_); + auto it = replicated_fns_.find(key); + if (it != replicated_fns_.end()) { + return it->second; + } + + string name = absl::StrCat("ExecuteReplicated_", replicated_fns_.size()); + + FunctionDef fdef; + OpDef* opdef = fdef.mutable_signature(); + opdef->set_name(name); + OpDef::ArgDef* execution_handle = opdef->add_input_arg(); + execution_handle->set_name("execution_handle"); + execution_handle->set_type(DT_INT64); + + TF_RET_CHECK(device_assignment.computation_count() == input_arity.size()); + + std::vector execution_configs; + execution_configs.reserve(device_assignment.computation_count()); + for (int j = 0; j < device_assignment.computation_count(); ++j) { + OpDef::ArgDef* execution_config = opdef->add_input_arg(); + execution_config->set_name(absl::StrCat("execution_config_computation", j)); + execution_config->set_type(DT_STRING); + execution_configs.push_back(execution_config); + } + + for (int i = 0; i < device_assignment.replica_count(); ++i) { + for (int j = 0; j < device_assignment.computation_count(); ++j) { + NodeDef* ndef = fdef.add_node_def(); + ndef->set_name(absl::StrFormat("execute_replica%d_computation%d", i, j)); + ndef->set_op("XRTExecute"); + (*ndef->mutable_attr())["Ninputs"] = MakeAttrValue(input_arity[j]); + ndef->add_input(execution_handle->name()); + ndef->add_input(execution_configs[j]->name()); + int tf_device_id = tf_device_ids_.at(device_assignment(i, j)); + ndef->set_device(tf_context_->devices().at(tf_device_id).name()); + + for (int k = 0; k < input_arity[j]; ++k) { + OpDef::ArgDef* arg = opdef->add_input_arg(); + arg->set_name( + absl::StrFormat("in_replica%d_computation%d_arg%d", i, j, k)); + arg->set_type(DT_INT64); + + ndef->add_input(arg->name()); + } + OpDef::ArgDef* ret = opdef->add_output_arg(); + ret->set_name(absl::StrFormat("out_replica%d_computation%d", i, j)); + ret->set_type(DT_INT64); + + (*fdef.mutable_ret())[ret->name()] = + absl::StrCat(ndef->name(), ":output_handle"); + } + } + + VLOG(10) << fdef.DebugString(); + + Status status = tf_context_->RegisterFunction(fdef); + VLOG(4) << "GetExecuteReplicatedFunction returned " << status; + if (!status.ok()) return status; + + replicated_fns_[key] = name; + return name; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/client/xrt_client.h b/tensorflow/compiler/xrt/client/xrt_client.h new file mode 100644 index 00000000000..29469b0e888 --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_client.h @@ -0,0 +1,250 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file contains a C++ client for XRT, that communicates with a remote +// TensorFlow Eager server over gRPC. +// +// This client is a prototype and its API is not stable yet. +// +// TODO(phawkins): add support for multi-host configurations. +// * currently the API names accelerator devices using a flat space of device +// ordinals, with no particular meaning to the device ordinals. The plan is to +// instead to use the linearized device topology coordinates as device +// ordinals. + +#ifndef TENSORFLOW_COMPILER_XRT_CLIENT_XRT_CLIENT_H_ +#define TENSORFLOW_COMPILER_XRT_CLIENT_XRT_CLIENT_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xrt/client/xrt_tf_client.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" + +namespace tensorflow { + +class XrtContext; + +// RAII class that holds ownership of an XRT buffer. +class XrtBuffer { + public: + // Builds a new XrtBuffer from an XLA literal, copying the buffer to the + // remote host. + static xla::StatusOr> FromLiteral( + const std::shared_ptr& context, int xrt_device_ordinal, + const xla::LiteralSlice& literal); + + // Builds a new XrtBuffer tuple from its constituent parts. + static xla::StatusOr> MakeTuple( + const std::shared_ptr& context, + const std::vector>& elements); + + // Converts an XrtBuffer to an XLA literal, copying the buffer from the remote + // host. Blocks until the buffer is available. + xla::StatusOr ToLiteral() const; + + // Deletes the remote buffer. + void Delete(); + + // Destructures a tuple-shaped buffer into its constituent pieces. + xla::StatusOr>> DestructureTuple(); + + // TODO(phawkins): add a mechanism for converting XrtBuffers into remote + // tensors and vice-versa for TF interoperability. + + XrtBuffer() = default; + XrtBuffer(XrtTensorHandle handle, int xrt_device_ordinal, xla::Shape shape); + ~XrtBuffer(); // Calls Delete(). + + // A buffer reference is moveable but not copyable. + XrtBuffer(const XrtBuffer&) = delete; + XrtBuffer(XrtBuffer&&) = default; + XrtBuffer& operator=(const XrtBuffer&) = delete; + XrtBuffer& operator=(XrtBuffer&&) = default; + + const XrtTensorHandle& handle() const { return handle_; } + int xrt_device_ordinal() const { return xrt_device_ordinal_; } + const xla::Shape& shape() const { return shape_; } + + private: + // Tensor that contains the XRT allocation ID. + XrtTensorHandle handle_; + int xrt_device_ordinal_; + xla::Shape shape_; +}; + +// RAII class that holds ownership of an XRT executable. +class XrtExecutable { + public: + // Constructs an XrtExecutable by compiling a program. + // `xrt_device_ordinal` must be the ordinal of a device known to XrtContext + // on which the compile operator should be placed. + // `hlo_module_proto` is the serialized HLO program to compile. + // `argument_shapes` and `result_shape` describe the shapes of the + // arguments/result and their layout. + // `device_assignment` is the set of devices to which compilation should be + // targeted. The device numbers in the device assignment are the XRT device + // ordinals. 
+ // TODO(phawkins): device assignments with more than one computation per + // replica do not work yet, even though the API appears to support them. + static xla::StatusOr> Compile( + std::shared_ptr context, + const xla::HloModuleProto& hlo_module_proto, + const std::vector& argument_shapes, + const xla::Shape& result_shape, xla::DeviceAssignment device_assignment); + + explicit XrtExecutable(std::shared_ptr context, + XrtTensorHandle handles, xla::ProgramShape shape, + xla::DeviceAssignment device_assignment); + ~XrtExecutable(); // Calls Delete(). + + // Deletes the XrtExecutable. + void Delete(); + + // Runs the executable. Simplified API without replication or model + // parallelism. + xla::StatusOr> Execute( + const std::vector>& args); + + // General API that runs replicated, model-parallel computations. + // + // Arguments are indexed by [computation][replica][arg]. Since each + // computation may have a different arity, we use a Span to represent + // a possibly ragged array. + // + // Return values are indexed by [computation][replica]. XLA computations + // always have exactly one return value, so there is no possibility of + // raggedness. + xla::StatusOr>> ExecuteReplicated( + absl::Span>> args); + + // Moveable but not copyable. + XrtExecutable(const XrtExecutable&) = delete; + XrtExecutable(XrtExecutable&&) = default; + XrtExecutable& operator=(const XrtExecutable&) = delete; + XrtExecutable& operator=(XrtExecutable&&) = default; + + const xla::DeviceAssignment& device_assignment() const { + return device_assignment_; + } + + private: + std::shared_ptr context_; + + // A copy of the executable's handle in host memory. If the computation is + // unreplicated, this lives on the target device. If the computation is + // replicated, this lives on the CPU device. + XrtTensorHandle handle_; + xla::ProgramShape shape_; + + // The TF device ordinal on which this handle was compiled and on which it + // should be deleted. + xla::DeviceAssignment device_assignment_; +}; + +// Manages an XRT session. +// +// The XrtTfClient/XrtTfContext classes wrap the TensorFlow API more directly, +// without any XRT-specific knowledge. The higher level XrtClient +// adds XRT-specific functionality on top. +// +// It is intended that all clients talking to the same XRT session use the same +// XrtContext and that objects such as buffers and executables must not be +// shared between XrtContexts. However, clients may run non-XRT TensorFlow ops +// using the XrtTfContext that underlies an XrtContext. +// +// TODO(phawkins): Currently this code only supports a single remote host; each +// XrtContext communicates via a single XrtTfContext. The plan is to support +// multihost configurations (e.g., TPU pods) in the future, in which case +// XrtContext will be extended to have one XrtTfContext per remote host. +// +// TODO(phawkins): This API is intended to be thread-safe, but this is untested. +class XrtContext { + public: + // Creates an XrtContext. Fails if no accelerators of 'device_type' are found. + static xla::StatusOr> Create( + std::shared_ptr tf_context, string device_type); + + // Use Create() instead. + XrtContext(std::shared_ptr tf_context, string device_type); + + // Returns the number of accelerator devices of 'device_type'. 
+ int device_count() const; + + const std::shared_ptr& tf_context() const { + return tf_context_; + } + const std::vector& tf_device_ids() const { return tf_device_ids_; } + + const std::vector< + xrt::DeviceAssignment::ComputationDevice::DeviceMeshCoordinates>& + device_mesh_coordinates() const { + return device_mesh_coordinates_; + } + + private: + friend class XrtExecutable; + + const std::shared_ptr tf_context_; + const string device_type_; // Type of accelerator device to use (e.g., TPU) + + // Initializes TPU devices. Synchronous; called by Create(). + Status InitializeTPU(); + + // IDs of devices of type `device_type_` in `tf_context_`. + std::vector tf_device_ids_; + + // Device coordinates of each device, indexed by XRT device ordinal. + std::vector + device_mesh_coordinates_; + + // Returns the name of a function that launches a replicated computation + // with input arity `input_arity` and device assignment `device_assignment`. + xla::StatusOr GetExecuteReplicatedFunction( + absl::Span input_arity, + const xla::DeviceAssignment& device_assignment); + + struct ExecuteReplicatedKey { + ExecuteReplicatedKey(absl::Span input_arity, + xla::DeviceAssignment device_assignment); + std::vector input_arity; + xla::DeviceAssignment device_assignment; + bool operator==(const ExecuteReplicatedKey& other) const; + }; + template + friend H AbslHashValue(H h, const ExecuteReplicatedKey& key); + + absl::Mutex mu_; + absl::flat_hash_map replicated_fns_ + GUARDED_BY(mu_); +}; + +template +H AbslHashValue(H h, const XrtContext::ExecuteReplicatedKey& key) { + h = H::combine_contiguous(std::move(h), key.input_arity.data(), + key.input_arity.size()); + return H::combine_contiguous(std::move(h), key.device_assignment.data(), + key.device_assignment.num_elements()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_CLIENT_XRT_CLIENT_H_ diff --git a/tensorflow/compiler/xrt/client/xrt_client_test.cc b/tensorflow/compiler/xrt/client/xrt_client_test.cc new file mode 100644 index 00000000000..e64c986f44e --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_client_test.cc @@ -0,0 +1,348 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xrt/client/xrt_client.h" + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h" +#include "tensorflow/compiler/xrt/client/xrt_tf_client.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class XrtClientTest : public ::testing::Test { + protected: + XrtClientTest() { + string binary_path = + absl::StrCat(testing::TensorFlowSrcRoot(), + "/compiler/xrt/client/xrt_testlib_server"); + + TF_CHECK_OK(test::TestCluster::MakeTestCluster( + binary_path, SessionOptions(), /*n=*/1, &cluster_)); + + CHECK_EQ(cluster_->targets().size(), 1); + JobDef* job = cluster_def_.add_job(); + job->set_name("localhost"); + (*job->mutable_tasks())[0] = cluster_->targets()[0]; + } + + xla::StatusOr> MakeContext(); + + std::unique_ptr cluster_; + ClusterDef cluster_def_; +}; + +// Test some connection basics using XrtGrpcEagerClient directly. +TEST_F(XrtClientTest, XrtGrpcEagerClientWorks) { + ChannelCreationFunction channel_func = + ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr channel_cache, + GetGrpcChannelCache(cluster_def_, channel_func)); + XrtGrpcEagerClientCache client_cache(channel_cache); + + TF_ASSERT_OK_AND_ASSIGN( + XrtGrpcEagerClient * client, + client_cache.GetClient("/job:localhost/task:0/replica:0")); + + // Create and destroy a context to verify we can make RPCs. + eager::CreateContextRequest request; + ServerDef* server_def = request.mutable_server_def(); + *server_def->mutable_cluster() = cluster_def_; + server_def->set_job_name("localhost"); + server_def->set_protocol("grpc"); + request.set_keep_alive_secs(60); + request.set_rendezvous_id(random::New64()); + + eager::CreateContextResponse create_response; + TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CreateContextAsync, + &request, &create_response)); + + eager::CloseContextRequest close_request; + close_request.set_context_id(create_response.context_id()); + + eager::CloseContextResponse close_response; + TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CloseContextAsync, + &close_request, &close_response)); +} + +// Tests that we can connect to a server using the higher-level XrtTfClient API, +// transfer tensors to the device, run an Add operator, and retrieve the result. 
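The tests that follow exercise the full client flow; the condensed sketch below strings the same calls together as an ordinary program. Names such as `RunDoubler` are illustrative, and the exact smart-pointer and cache types (`GrpcChannelCache`, `XrtTfClient`) are inferred from the surrounding code, so treat this as a sketch rather than canonical usage; it assumes `cluster_def` already describes a reachable TensorFlow server.

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xrt/client/xrt_client.h"
#include "tensorflow/compiler/xrt/client/xrt_tf_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/protobuf/cluster.pb.h"

namespace tensorflow {

Status RunDoubler(const ClusterDef& cluster_def) {
  // Connect to the remote TF eager service.
  ChannelCreationFunction channel_func =
      ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
  TF_ASSIGN_OR_RETURN(std::shared_ptr<GrpcChannelCache> channel_cache,
                      GetGrpcChannelCache(cluster_def, channel_func));
  auto client = std::make_shared<XrtTfClient>(cluster_def, channel_cache);
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<XrtTfContext> tf_context,
      XrtTfContext::Create(XrtTfContext::Options(), client,
                           /*job=*/"localhost", /*task=*/0));
  TF_ASSIGN_OR_RETURN(std::shared_ptr<XrtContext> context,
                      XrtContext::Create(tf_context, "XLA_CPU"));

  // Build an XLA computation that doubles its f32[2] argument.
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
  xla::XlaBuilder builder("doubler");
  xla::XlaOp p = xla::Parameter(&builder, 0, shape, "p");
  xla::Add(p, p);
  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build());

  // Compile for a single replica and a single computation.
  TF_ASSIGN_OR_RETURN(xla::DeviceAssignment assignment,
                      xla::ComputationPlacer().AssignDevices(1, 1));
  TF_ASSIGN_OR_RETURN(auto executable,
                      XrtExecutable::Compile(context, computation.proto(),
                                             {shape}, shape, assignment));

  // Transfer an argument, run, and fetch the result.
  xla::Literal arg = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  TF_ASSIGN_OR_RETURN(auto arg_buffer, XrtBuffer::FromLiteral(context, 0, arg));
  TF_ASSIGN_OR_RETURN(auto result_buffer, executable->Execute({arg_buffer}));
  TF_ASSIGN_OR_RETURN(xla::Literal result, result_buffer->ToLiteral());
  LOG(INFO) << "Result: " << result.ToString();
  return Status::OK();
}

}  // namespace tensorflow
```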
+TEST_F(XrtClientTest, XrtTfClientWorks) { + ChannelCreationFunction channel_func = + ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr channel_cache, + GetGrpcChannelCache(cluster_def_, channel_func)); + + auto client = std::make_shared(cluster_def_, channel_cache); + TF_ASSERT_OK_AND_ASSIGN( + std::shared_ptr context, + XrtTfContext::Create(XrtTfContext::Options(), client, /*job=*/"localhost", + /*task=*/0)); + + auto a_proto = absl::make_unique(); + a_proto->set_dtype(DT_INT32); + a_proto->add_int_val(47); + XrtTensorHandle a = + context->SendTensor(std::move(a_proto), context->cpu_device_id()); + auto b_proto = absl::make_unique(); + b_proto->set_dtype(DT_INT32); + b_proto->mutable_tensor_shape()->add_dim()->set_size(2); + b_proto->add_int_val(-101); + b_proto->add_int_val(3); + XrtTensorHandle b = + context->SendTensor(std::move(b_proto), context->cpu_device_id()); + + protobuf::Map attrs; + attrs["T"] = MakeAttrValue(DT_INT32); + std::vector add_outputs = context->EnqueueOp( + "Add", {&a, &b}, /*output_arity=*/1, attrs, context->cpu_device_id()); + ASSERT_EQ(add_outputs.size(), 1); + + std::shared_ptr future = + context->RecvTensor(add_outputs[0], DT_INT32, /*host_memory=*/false); + + TF_ASSERT_OK_AND_ASSIGN(RecvTensorResponse * response, future->Get()); + const TensorProto& out_proto = response->tensor(); + EXPECT_EQ(out_proto.dtype(), DT_INT32); + + ASSERT_EQ(out_proto.tensor_content().size(), sizeof(int32) * 2); + std::vector out(2); + out_proto.tensor_content().CopyToArray(reinterpret_cast(out.data())); + // TODO(phawkins): handle endian conversion. + EXPECT_EQ(out[0], -54); + EXPECT_EQ(out[1], 50); +} + +xla::StatusOr> XrtClientTest::MakeContext() { + ChannelCreationFunction channel_func = + ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + TF_ASSIGN_OR_RETURN(std::shared_ptr channel_cache, + GetGrpcChannelCache(cluster_def_, channel_func)); + + auto client = std::make_shared(cluster_def_, channel_cache); + TF_ASSIGN_OR_RETURN( + std::shared_ptr tf_context, + XrtTfContext::Create(XrtTfContext::Options(), client, /*job=*/"localhost", + /*task=*/0)); + + TF_ASSIGN_OR_RETURN(auto context, XrtContext::Create(tf_context, "XLA_CPU")); + + // There should be exactly one XLA_CPU device. + TF_RET_CHECK(context->device_count() == 1); + return context; +} + +// Tests that we can use the XRT client to perform some simple operations. +TEST_F(XrtClientTest, XrtClientWorks) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); + + ASSERT_TRUE(context->tf_context() != nullptr); + + EXPECT_EQ(context->tf_device_ids().size(), 1); + + ASSERT_EQ(context->device_mesh_coordinates().size(), 1); + ASSERT_EQ(context->device_mesh_coordinates()[0].value_size(), 1); + EXPECT_EQ(context->device_mesh_coordinates()[0].value(0), 0); + + // Tests sending a literal to and from the device. + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a, + xla::LiteralUtil::CreateRandomLiteral( + shape, + /*mean=*/7.0, /*stddev=*/13.5)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer, XrtBuffer::FromLiteral(context, 0, a)); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b, buffer->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(a, b)); + + // Run a simple computation, fetch its output, and check it is what we expect. 
+ auto build_computation = [&]() { + xla::XlaBuilder builder("test_computation"); + xla::XlaOp p = xla::Parameter(&builder, 0, shape, "param"); + xla::Add(p, p); + return builder.Build(); + }; + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, build_computation()); + + TF_ASSERT_OK_AND_ASSIGN(xla::DeviceAssignment assignment, + xla::ComputationPlacer().AssignDevices(1, 1)); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + XrtExecutable::Compile(context, computation.proto(), + {shape}, shape, assignment)); + EXPECT_EQ(executable->device_assignment(), assignment); + TF_ASSERT_OK_AND_ASSIGN(auto c_buffer, executable->Execute({buffer})); + + xla::Literal expected = a.Clone(); + for (float& elem : expected.data()) { + elem *= 2; + } + TF_ASSERT_OK_AND_ASSIGN(xla::Literal out, c_buffer->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, out)); + + // Explicitly delete the executable, and then compile and run a different + // computation. + executable->Delete(); + + auto build_sub_computation = [&]() { + xla::XlaBuilder builder("test_computation"); + xla::XlaOp p = xla::Parameter(&builder, 0, shape, "p"); + xla::XlaOp q = xla::Parameter(&builder, 1, shape, "q"); + xla::Sub(p, q); + return builder.Build(); + }; + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation sub_computation, + build_sub_computation()); + + TF_ASSERT_OK_AND_ASSIGN( + auto sub_executable, + XrtExecutable::Compile(context, sub_computation.proto(), {shape, shape}, + shape, assignment)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_out, + sub_executable->Execute({c_buffer, buffer})); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal sub_out, buffer_out->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(a, sub_out)); +} + +TEST_F(XrtClientTest, ErrorsPropagateCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a, + xla::LiteralUtil::CreateRandomLiteral( + shape, + /*mean=*/7.0, /*stddev=*/13.5)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer, XrtBuffer::FromLiteral(context, 0, a)); + + auto build_computation = [&]() { + xla::XlaBuilder builder("test_computation"); + xla::XlaOp p = xla::Parameter(&builder, 0, shape, "p"); + xla::XlaOp q = xla::Parameter(&builder, 1, shape, "q"); + xla::Add(p, q); + return builder.Build(); + }; + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, build_computation()); + + TF_ASSERT_OK_AND_ASSIGN(xla::DeviceAssignment assignment, + xla::ComputationPlacer().AssignDevices(1, 1)); + // Call Compile() with an arity mismatch. + TF_ASSERT_OK_AND_ASSIGN(auto sub_executable, + XrtExecutable::Compile(context, computation.proto(), + {shape}, shape, assignment)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer_out, sub_executable->Execute({buffer})); + + // The compilation error should be reported when we consumer the computation's + // output. + EXPECT_FALSE(buffer_out->ToLiteral().ok()); + + // Further, we expect a clean shutdown at this point. + context = nullptr; +} + +TEST_F(XrtClientTest, TupleDestructuringAndDelete) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); + + // Tests sending a literal to and from the device. 
+ xla::Shape a_shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a, + xla::LiteralUtil::CreateRandomLiteral( + a_shape, + /*mean=*/7.0, /*stddev=*/13.5)); + + xla::Shape b_shape = xla::ShapeUtil::MakeShape(xla::F64, {2, 7}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b, + xla::LiteralUtil::CreateRandomLiteral( + b_shape, + /*mean=*/3.15, /*stddev=*/-2.1)); + xla::Literal tuple = xla::LiteralUtil::MakeTuple({&a, &b}); + TF_ASSERT_OK_AND_ASSIGN(auto buffer, + XrtBuffer::FromLiteral(context, 0, tuple)); + + TF_ASSERT_OK_AND_ASSIGN(std::vector> pieces, + buffer->DestructureTuple()); + + // Explicitly delete the tuple, which should have no effect on its + // constituents. + buffer->Delete(); + + ASSERT_EQ(pieces.size(), 2); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a_out, pieces[0]->ToLiteral()); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b_out, pieces[1]->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(a, a_out)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(b, b_out)); + + // Explicitly delete one of the pieces, use RAII to delete the other. + pieces[1]->Delete(); +} + +TEST_F(XrtClientTest, TupleConstructionAndDestructuring) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr context, MakeContext()); + + // Tests sending a literal to and from the device. + xla::Shape a_shape = xla::ShapeUtil::MakeShape(xla::F32, {3, 4, 5}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a, + xla::LiteralUtil::CreateRandomLiteral( + a_shape, + /*mean=*/7.0, /*stddev=*/13.5)); + TF_ASSERT_OK_AND_ASSIGN(auto a_buffer, XrtBuffer::FromLiteral(context, 0, a)); + + xla::Shape b_shape = xla::ShapeUtil::MakeShape(xla::F64, {2, 7}); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b, + xla::LiteralUtil::CreateRandomLiteral( + b_shape, + /*mean=*/3.15, /*stddev=*/-2.1)); + TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, XrtBuffer::FromLiteral(context, 0, b)); + + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a_in, a_buffer->ToLiteral()); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b_in, b_buffer->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(a, a_in)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(b, b_in)); + + std::vector> elems = {a_buffer, b_buffer}; + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr buffer, + XrtBuffer::MakeTuple(context, elems)); + TF_ASSERT_OK_AND_ASSIGN(std::vector> pieces, + buffer->DestructureTuple()); + + // Explicitly delete the tuple, which should have no effect on its + // constituents. + buffer->Delete(); + + ASSERT_EQ(pieces.size(), 2); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal a_out, pieces[0]->ToLiteral()); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal b_out, pieces[1]->ToLiteral()); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(a, a_out)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(b, b_out)); + + // Explicitly delete one of the pieces, use RAII to delete the other. + pieces[1]->Delete(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc new file mode 100644 index 00000000000..39c83c14f0a --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" + +namespace tensorflow { + +XrtGrpcEagerClient::XrtGrpcEagerClient(const SharedGrpcChannelPtr& channel, + ::grpc::CompletionQueue* cq) + : stub_(channel), cq_(cq) {} + +#define EAGER_CLIENT_METHOD(method) \ + void XrtGrpcEagerClient::method##Async( \ + const eager::method##Request* request, \ + eager::method##Response* response, StatusCallback done, \ + CallOptions* call_opts) { \ + new RPCState( \ + &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \ + response, std::move(done), call_opts, nullptr); \ + } + +EAGER_CLIENT_METHOD(CreateContext); +EAGER_CLIENT_METHOD(Enqueue); +EAGER_CLIENT_METHOD(WaitQueueDone); +EAGER_CLIENT_METHOD(KeepAlive); +EAGER_CLIENT_METHOD(CloseContext); +EAGER_CLIENT_METHOD(RegisterFunction); +EAGER_CLIENT_METHOD(SendTensor); +#undef EAGER_CLIENT_METHOD + +#define WORKER_CLIENT_METHOD(method) \ + void XrtGrpcEagerClient::method##Async( \ + const method##Request* request, method##Response* response, \ + StatusCallback done, CallOptions* call_opts) { \ + new RPCState( \ + &stub_, cq_, "/tensorflow.WorkerService/" #method, *request, response, \ + std::move(done), call_opts, nullptr); \ + } + +WORKER_CLIENT_METHOD(GetStatus); +WORKER_CLIENT_METHOD(RecvTensor); +#undef WORKER_CLIENT_METHOD + +class XrtGrpcEagerClientThread { + public: + XrtGrpcEagerClientThread() { + thread_.reset(Env::Default()->StartThread( + ThreadOptions(), "xrt_eager_client_thread", [this]() { + void* tag; + bool ok; + while (completion_queue_.Next(&tag, &ok)) { + GrpcClientCQTag* callback_tag = static_cast(tag); + callback_tag->OnCompleted(ok); + } + })); + } + + ~XrtGrpcEagerClientThread() { + completion_queue_.Shutdown(); + thread_.reset(); + } + + ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; } + + private: + ::grpc::CompletionQueue completion_queue_; + std::unique_ptr thread_; +}; // XrtGrpcEagerClientThread + +XrtGrpcEagerClientCache::XrtGrpcEagerClientCache( + std::shared_ptr channel_cache) + : next_round_robin_assignment_(0), cache_(channel_cache), threads_(4) {} + +XrtGrpcEagerClientCache::~XrtGrpcEagerClientCache() { threads_.clear(); } + +xla::StatusOr XrtGrpcEagerClientCache::GetClient( + const string& target) { + auto it = clients_.find(target); + if (it == clients_.end()) { + tensorflow::SharedGrpcChannelPtr shared = cache_->FindWorkerChannel(target); + if (!shared) { + return errors::NotFound("Unknown target ", target); + } + auto worker = absl::make_unique( + shared, threads_[AssignClientToThread(target)].completion_queue()); + + it = clients_.emplace(target, std::move(worker)).first; + } + + return 
it->second.get(); +} + +size_t XrtGrpcEagerClientCache::AssignClientToThread(const string& target) { + // Round-robin target assignment, but keeps the same target on the same + // polling thread always, as this is important for gRPC performance. + mutex_lock lock(assignment_mu_); + auto it = target_assignments_.find(target); + if (it == target_assignments_.end()) { + it = target_assignments_ + .insert(std::make_pair( + target, (next_round_robin_assignment_++) % threads_.size())) + .first; + } + return it->second; +} + +xla::StatusOr> GetGrpcChannelCache( + const ClusterDef& cluster_def, ChannelCreationFunction channel_func) { + GrpcChannelSpec channel_spec; + for (const JobDef& job : cluster_def.job()) { + std::map host_ports(job.tasks().begin(), job.tasks().end()); + TF_RETURN_IF_ERROR(channel_spec.AddHostPortsJob(job.name(), host_ports)); + } + return std::shared_ptr( + NewGrpcChannelCache(channel_spec, channel_func)); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h new file mode 100644 index 00000000000..18275c7d002 --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Self-contained client for communicating with a TensorFlow Eager remote +// service over gRPC. +// +// Unlike, say, the TensorFlow C API, this class is intended to be a +// self-contained, minimal-dependency way to interact with a remote TF eager +// server, containing just enough functionality for the XRT use case. + +#ifndef TENSORFLOW_COMPILER_XRT_CLIENT_XRT_GRPC_EAGER_CLIENT_H_ +#define TENSORFLOW_COMPILER_XRT_CLIENT_XRT_GRPC_EAGER_CLIENT_H_ + +#include "grpcpp/generic/generic_stub.h" +#include "absl/synchronization/notification.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +// This class is a self-contained cousin of the standard EagerClient class. +// Unlike EagerClient, this class includes all of the methods needed by XRT, +// including methods from both EagerService and WorkerService. +// This reduces the dependency footprint, since in particular RecvTensor's +// implementation depends on a bunch of TF framework infrastructure (e.g., +// Device, Tensor) that we don't need for the XRT client use case. 
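+//
+// A rough usage sketch (see also xrt_client_test.cc), assuming a ClusterDef
+// `cluster_def` and a ChannelCreationFunction `channel_func` are already
+// available; request fields and error handling are elided:
+//
+//   TF_ASSIGN_OR_RETURN(std::shared_ptr<GrpcChannelCache> channels,
+//                       GetGrpcChannelCache(cluster_def, channel_func));
+//   XrtGrpcEagerClientCache clients(channels);
+//   TF_ASSIGN_OR_RETURN(XrtGrpcEagerClient* client,
+//                       clients.GetClient("/job:worker/task:0/replica:0"));
+//   eager::CreateContextRequest request;
+//   eager::CreateContextResponse response;
+//   TF_RETURN_IF_ERROR(client->SyncCall(
+//       &XrtGrpcEagerClient::CreateContextAsync, &request, &response));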
+class XrtGrpcEagerClient { + public: + XrtGrpcEagerClient(const SharedGrpcChannelPtr& channel, + ::grpc::CompletionQueue* cq); + ~XrtGrpcEagerClient() = default; + + XrtGrpcEagerClient(const XrtGrpcEagerClient&) = delete; + XrtGrpcEagerClient(XrtGrpcEagerClient&&) = delete; + XrtGrpcEagerClient& operator=(const XrtGrpcEagerClient&) = delete; + XrtGrpcEagerClient& operator=(XrtGrpcEagerClient&&) = delete; + + void CreateContextAsync(const eager::CreateContextRequest* request, + eager::CreateContextResponse* response, + StatusCallback done, + CallOptions* call_opts = nullptr); + void EnqueueAsync(const eager::EnqueueRequest* request, + eager::EnqueueResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); + void WaitQueueDoneAsync(const eager::WaitQueueDoneRequest* request, + eager::WaitQueueDoneResponse* response, + StatusCallback done, + CallOptions* call_opts = nullptr); + void KeepAliveAsync(const eager::KeepAliveRequest* request, + eager::KeepAliveResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); + void CloseContextAsync(const eager::CloseContextRequest* request, + eager::CloseContextResponse* response, + StatusCallback done, CallOptions* call_opts = nullptr); + void RegisterFunctionAsync(const eager::RegisterFunctionRequest* request, + eager::RegisterFunctionResponse* response, + StatusCallback done, + CallOptions* call_opts = nullptr); + void SendTensorAsync(const eager::SendTensorRequest* request, + eager::SendTensorResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); + + // The following two methods are actually from the WorkerService API, not + // EagerService, but are necessary for using remote Eager, and we include them + // here for self-containedness. + + // We use RecvTensor to copy tensors back from a remote worker to the client. + void RecvTensorAsync(const RecvTensorRequest* request, + RecvTensorResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); + + // We use GetStatus to discover device incarnation values for use in + // RecvTensor. + // TODO(phawkins): We need to call GetStatus to work around a bug in the + // TFE server implementation. Remove this API call and use the device + // information from CreateContext once the bug fix is deployed everywhere. + void GetStatusAsync(const GetStatusRequest* request, + GetStatusResponse* response, StatusCallback done, + CallOptions* call_opts = nullptr); + + // Helper method for calling any of the ...Async methods synchronously. + template + Status SyncCall(Method m, const Request* request, Response* response, + CallOptions* call_opts = nullptr) { + absl::Notification done; + Status status; + (this->*(m))( + request, response, + [&](Status s) { + status = s; + done.Notify(); + }, + call_opts); + done.WaitForNotification(); + return status; + } + + private: + ::grpc::GenericStub stub_; + ::grpc::CompletionQueue* cq_; +}; + +class XrtGrpcEagerClientThread; + +// Simple wrapper class that can be used to retrieve XrtGrpcEagerClients. +class XrtGrpcEagerClientCache { + public: + explicit XrtGrpcEagerClientCache( + std::shared_ptr channel_cache); + ~XrtGrpcEagerClientCache(); + + XrtGrpcEagerClientCache(const XrtGrpcEagerClientCache&) = delete; + XrtGrpcEagerClientCache(XrtGrpcEagerClientCache&&) = delete; + XrtGrpcEagerClientCache& operator=(const XrtGrpcEagerClientCache&) = delete; + XrtGrpcEagerClientCache& operator=(XrtGrpcEagerClientCache&&) = delete; + + // Returns a cached client for 'target'. 
'target' should be a task name known + // te the channel cache, e.g., "/job:worker/task:0/replica:0". + xla::StatusOr GetClient(const string& target); + + private: + size_t AssignClientToThread(const string& target); + + mutex assignment_mu_; + std::unordered_map target_assignments_ + GUARDED_BY(assignment_mu_); + size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_); + + std::shared_ptr cache_; + std::unordered_map> clients_; + std::vector threads_; +}; + +// Builds a GrpcChannelCache for a TF cluster `cluster_def`. `channel_func` +// is a function to use to create channels; it is client-provided so clients can +// set up custom authentication, etc. +xla::StatusOr> GetGrpcChannelCache( + const ClusterDef& cluster_def, ChannelCreationFunction channel_func); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_CLIENT_XRT_GRPC_EAGER_CLIENT_H_ diff --git a/tensorflow/compiler/xrt/client/xrt_tf_client.cc b/tensorflow/compiler/xrt/client/xrt_tf_client.cc new file mode 100644 index 00000000000..5388338fd36 --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_tf_client.cc @@ -0,0 +1,533 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xrt/client/xrt_tf_client.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/distributed_runtime/request_id.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +namespace tensorflow { + +XrtTfClient::XrtTfClient(ClusterDef cluster_def, + std::shared_ptr channel_cache) + : cluster_def_(cluster_def), channel_cache_(std::move(channel_cache)) { + eager_client_cache_ = + absl::make_unique(channel_cache_); +} + +xla::StatusOr> XrtTfContext::Create( + const XrtTfContext::Options& options, + std::shared_ptr tf_client, const std::string& job, int task) { + int64 rendezvous_id = random::New64(); + + eager::CreateContextRequest request; + ServerDef* server_def = request.mutable_server_def(); + *server_def->mutable_cluster() = tf_client->cluster_def(); + server_def->set_job_name(job); + server_def->set_protocol("grpc"); + request.set_keep_alive_secs(60); + request.set_rendezvous_id(rendezvous_id); + request.set_async(options.async); + + eager::CreateContextResponse response; + + std::string target = absl::StrFormat("/job:%s/task:%d/replica:0", job, task); + TF_ASSIGN_OR_RETURN(XrtGrpcEagerClient * 
eager_client, + tf_client->eager_client_cache()->GetClient(target)); + + TF_RETURN_IF_ERROR(eager_client->SyncCall( + &XrtGrpcEagerClient::CreateContextAsync, &request, &response)); + + // Due to a TFE server-side bug, devices returned by the eager CreateContext + // method have the wrong device incarnation numbers, which we need to call + // RecvTensor. Use the device attributes from WorkerService.GetStatus instead. + // TODO(phawkins): revert to using device information from CreateContext once + // the bug is fixed. + GetStatusRequest status_request; + GetStatusResponse status_response; + TF_RETURN_IF_ERROR(eager_client->SyncCall(&XrtGrpcEagerClient::GetStatusAsync, + &status_request, &status_response)); + + std::vector devices( + status_response.device_attributes().begin(), + status_response.device_attributes().end()); + VLOG(1) << "Remote devices: " << devices.size(); + int cpu_device_id = -1; + for (int i = 0; i < devices.size(); ++i) { + const auto& device = devices[i]; + VLOG(2) << "Remote device: " << device.DebugString(); + if (cpu_device_id < 0 && device.device_type() == "CPU") { + cpu_device_id = i; + VLOG(1) << "Remote CPU device: " << i << " name: " << device.name(); + } + } + if (cpu_device_id < 0) { + return errors::FailedPrecondition( + "Remote TensorFlow worker does not have a CPU device."); + } + + std::sort(devices.begin(), devices.end(), + [](const DeviceAttributes& a, const DeviceAttributes& b) { + return a.name() < b.name(); + }); + return std::make_shared(options, tf_client, eager_client, + rendezvous_id, response.context_id(), + std::move(devices), cpu_device_id); +} + +XrtTfContext::XrtTfContext(const XrtTfContext::Options& options, + std::shared_ptr tf_client, + XrtGrpcEagerClient* eager_client, + int64 rendezvous_id, int64 context_id, + std::vector devices, + int cpu_device_id) + : options_(options), + tf_client_(tf_client), + eager_client_(eager_client), + rendezvous_id_(rendezvous_id), + context_id_(context_id), + devices_(std::move(devices)), + cpu_device_id_(cpu_device_id) { + CHECK_GE(cpu_device_id_, 0); + enqueue_request_ = absl::make_unique(); + queue_thread_.reset(Env::Default()->StartThread(ThreadOptions(), + "xrt_tf_client_queue_thread", + [this]() { QueueThread(); })); +} + +XrtTfContext::~XrtTfContext() { + Status status = Close(); + if (!status.ok()) { + LOG(ERROR) << "XrtTfContext::Close failed with error: " << status; + } +} + +Status XrtTfContext::Close() { + { + absl::MutexLock lock(&mu_); + shutting_down_ = true; + } + + eager::CloseContextRequest request; + request.set_context_id(context_id_); + + Status status; + absl::Notification done; + eager::CloseContextResponse response; + eager_client_->CloseContextAsync(&request, &response, [&](Status s) { + status = s; + done.Notify(); + }); + done.WaitForNotification(); + return status; +} + +void XrtTfContext::QueueThread() { + auto should_flush_queue = [this]() { + mu_.AssertHeld(); // For annotalysis. + return enqueue_request_->queue_size() > options_.max_queue_size || + flush_requested_ || shutting_down_; + }; + while (true) { + auto request = absl::make_unique(); + { + absl::MutexLock lock(&mu_); + // To keep the connection alive, make sure we send an EnqueueRequest + // regularly, currently every 5 seconds. 
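+      // AwaitWithTimeout returns either when should_flush_queue() becomes
+      // true or when the 5 second timeout expires, so a (possibly empty)
+      // EnqueueRequest still goes out periodically while the queue is idle.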
+ mu_.AwaitWithTimeout(absl::Condition(&should_flush_queue), + absl::Seconds(5)); + if (shutting_down_) break; + std::swap(request, enqueue_request_); + flush_requested_ = false; + } + + std::vector op_ids; + for (const auto& item : request->queue()) { + if (item.has_operation()) { + op_ids.push_back(item.operation().id()); + } + } + request->set_context_id(context_id_); + + VLOG(10) << "Enqueue:\n" << request->DebugString(); + eager::EnqueueResponse response; + Status status; + absl::Notification done; + eager_client_->EnqueueAsync(request.get(), &response, [&](Status s) { + status = s; + done.Notify(); + }); + + done.WaitForNotification(); + + VLOG(10) << "EnqueueResponse: " << status << "\n" << response.DebugString(); + { + absl::MutexLock lock(&mu_); + if (status.ok()) { + for (OperationId op_id : op_ids) { + DeleteOperation(op_id); + } + } else { + ReportError(op_ids, status); + } + } + } +} + +void XrtTfContext::ReportError(absl::Span op_ids, + Status status) { + auto shared_error = std::make_shared(status); + absl::flat_hash_set visited(op_ids.begin(), op_ids.end()); + std::stack stack; + for (OperationId op_id : op_ids) { + stack.push(LookupOperation(op_id)); + } + while (!stack.empty()) { + Operation* op = stack.top(); + stack.pop(); + VLOG(10) << "Reporting error for " << op->id; + for (const std::shared_ptr& future : + op->tensor_futures) { + VLOG(10) << "Reporting error for " << op->id << " future"; + future->call_options_.StartCancel(); + future->Notify(status); + } + for (OperationId consumer_id : op->consumers) { + Operation* consumer = LookupOperation(consumer_id); + stack.push(consumer); + } + DeleteOperation(op->id); + } +} + +XrtTfContext::Operation* XrtTfContext::AddOperation() { + OperationId id = ++next_op_id_; + auto result = operations_.emplace(id, Operation(id)); + return &result.first->second; +} + +void XrtTfContext::DeleteOperation(OperationId id) { + CHECK_GT(operations_.erase(id), 0); +} + +XrtTfContext::Operation* XrtTfContext::LookupOperation(OperationId id) { + auto it = operations_.find(id); + CHECK(it != operations_.end()) << id; + return &it->second; +} + +std::vector XrtTfContext::EnqueueOp( + absl::string_view name, absl::Span inputs, + int output_arity, protobuf::Map attrs, + int device_id, std::shared_ptr future) { + std::vector outputs; + absl::MutexLock lock(&mu_); + Operation* op = AddOperation(); + + eager::Operation* proto = enqueue_request_->add_queue()->mutable_operation(); + proto->set_id(op->id); + proto->set_name(static_cast(name)); + for (const XrtTensorHandle* input : inputs) { + input->Serialize(proto->add_inputs()); + } + proto->mutable_attrs()->swap(attrs); + proto->set_device(devices_.at(device_id).name()); + + outputs.reserve(output_arity); + for (int i = 0; i < output_arity; ++i) { + outputs.push_back( + XrtTensorHandle(shared_from_this(), device_id, TensorId{op->id, i})); + } + if (future) { + op->tensor_futures.push_back(future); + } + + return outputs; +} + +XrtTensorHandle XrtTfContext::SendTensor( + std::unique_ptr tensor_proto, int device_id, + bool host_memory) { + DataType dtype = tensor_proto->dtype(); + bool transfer_via_cpu_device = host_memory && device_id != cpu_device_id_; + int rpc_device_id = transfer_via_cpu_device ? 
cpu_device_id_ : device_id; + OperationId op_id; + { + absl::MutexLock lock(&mu_); + Operation* op = AddOperation(); + op_id = op->id; + } + + eager::SendTensorRequest request; + request.set_context_id(context_id_); + request.set_op_id(op_id); + request.mutable_tensors()->AddAllocated(tensor_proto.release()); + request.set_device_name(devices_.at(rpc_device_id).name()); + auto response = std::make_shared(); + auto context_ptr = shared_from_this(); + absl::Notification done; + eager_client_->SendTensorAsync( + &request, response.get(), + [context_ptr, op_id, response, &done](Status status) { + absl::MutexLock lock(&context_ptr->mu_); + if (!status.ok()) { + context_ptr->ReportError({op_id}, status); + } else { + context_ptr->DeleteOperation(op_id); + } + done.Notify(); + }); + XrtTensorHandle handle(context_ptr, rpc_device_id, TensorId{op_id, 0}); + + // TODO(phawkins): we block here to avoid a race. We must not + // enqueue any dependent operations until the SendTensor has been + // acknowledged. + done.WaitForNotification(); + + // TODO(phawkins): EagerService.SendTensor could use a host_memory option. + if (!transfer_via_cpu_device) { + return handle; + } + std::string wire_id = XrtGetUniqueWireID(); + EnqueueSend(this, handle, dtype, device_id, wire_id, /*host_memory=*/false); + return EnqueueRecv(this, dtype, rpc_device_id, device_id, wire_id, + /*host_memory=*/true); +} + +// This gets a unique wire ID. We add a random identifier so that if the +// worker has other clients that it is servicing, we don't have any collision. +std::string XrtGetUniqueWireID() { + static uint64 random_seed = random::New64(); + static std::atomic wireid(0); + return absl::StrCat(random_seed, "_", ++wireid); +} + +static std::string GetReceiverDevice(XrtTfContext* context, + int recv_device_id) { + if (recv_device_id < 0) { + return "/job:xrt_client/task:0/replica:0/device:CPU:0"; + } else { + return context->devices().at(recv_device_id).name(); + } +} + +static std::string GetRendezvousKey(absl::string_view send_device, + absl::string_view recv_device, + const uint64 send_device_incarnation, + absl::string_view tensor_name) { + return absl::StrCat(send_device, ";", + strings::FpToString(send_device_incarnation), ";", + recv_device, ";", tensor_name, ";0:0"); +} + +std::shared_ptr XrtTfContext::RecvTensor( + const XrtTensorHandle& tensor, DataType dtype, bool host_memory) { + auto response = std::make_shared(); + + int device_id = tensor.device_id(); + + std::string wire_id = XrtGetUniqueWireID(); + EnqueueSend(this, tensor, dtype, /*recv_device_id=*/-1, wire_id, + /*host_memory=*/host_memory, /*future=*/response); + + const DeviceAttributes& device = devices().at(device_id); + RecvTensorRequest request; + request.set_step_id(rendezvous_id_); + request.set_rendezvous_key(GetRendezvousKey(device.name(), + GetReceiverDevice(this, -1), + device.incarnation(), wire_id)); + request.set_request_id(GetUniqueRequestId()); + // TODO(phawkins): verify uniqueness of request ID. Random IDs won't collide + // with high probability, but we should probably add code to guard against + // collisions nonetheless. 
+ + eager_client_->RecvTensorAsync( + &request, &response->value_, + [response, wire_id](Status status) { + VLOG(10) << "RecvTensor complete for " << wire_id; + response->Notify(status); + }, + &response->call_options_); + return response; +} + +Status XrtTfContext::RegisterFunction(const FunctionDef& def) { + eager::RegisterFunctionRequest request; + request.set_context_id(context_id_); + *request.mutable_function_def() = def; + + eager::RegisterFunctionResponse response; + Status status; + absl::Notification done; + eager_client_->RegisterFunctionAsync(&request, &response, [&](Status s) { + status = s; + done.Notify(); + }); + done.WaitForNotification(); + return status; +} +void XrtTfContext::EnqueueDecrefTensorHandle(eager::RemoteTensorHandle handle) { + absl::MutexLock lock(&mu_); + eager::QueueItem* item = enqueue_request_->add_queue(); + *item->mutable_handle_to_decref() = handle; +} + +void XrtTfContext::FlushQueue() { + absl::MutexLock lock(&mu_); + FlushQueueLocked(); +} + +void XrtTfContext::FlushQueueLocked() { flush_requested_ = true; } + +XrtTensorHandle::XrtTensorHandle() = default; +XrtTensorHandle::~XrtTensorHandle() { + if (context_) { + eager::RemoteTensorHandle proto; + Serialize(&proto); + context_->EnqueueDecrefTensorHandle(proto); + } +} + +XrtTensorHandle::XrtTensorHandle(XrtTensorHandle&& other) { + context_ = other.context_; + device_id_ = other.device_id_; + tensor_id_ = other.tensor_id_; + + other.context_ = nullptr; + other.device_id_ = -1; + other.tensor_id_ = XrtTfContext::TensorId{-1, -1}; +} + +XrtTensorHandle& XrtTensorHandle::operator=(XrtTensorHandle&& other) { + context_ = other.context_; + device_id_ = other.device_id_; + tensor_id_ = other.tensor_id_; + + other.context_ = nullptr; + other.device_id_ = -1; + other.tensor_id_ = XrtTfContext::TensorId{-1, -1}; + return *this; +} + +void XrtTensorHandle::Serialize(eager::RemoteTensorHandle* proto) const { + proto->set_op_id(tensor_id_.first); + proto->set_output_num(tensor_id_.second); +} + +AttrValue MakeAttrValue(std::string s) { + AttrValue a; + a.set_s(std::move(s)); + return a; +} + +AttrValue MakeAttrValue(int64 i) { + AttrValue a; + a.set_i(i); + return a; +} + +AttrValue MakeBoolAttrValue(bool b) { + AttrValue a; + a.set_b(b); + return a; +} + +AttrValue MakeAttrValue(DataType dtype) { + AttrValue a; + a.set_type(dtype); + return a; +} + +AttrValue MakeAttrValue(TensorProto tensor) { + AttrValue a; + *a.mutable_tensor() = tensor; + return a; +} + +AttrValue MakeAttrValue(absl::Span dtypes) { + AttrValue a; + auto* list = a.mutable_list(); + for (DataType dtype : dtypes) { + list->add_type(dtype); + } + return a; +} + +void EnqueueSend(XrtTfContext* context, const XrtTensorHandle& tensor, + DataType dtype, int recv_device_id, std::string wire_id, + bool host_memory, + std::shared_ptr future) { + protobuf::Map attrs; + const DeviceAttributes& device = context->devices().at(tensor.device_id()); + attrs["tensor_name"] = MakeAttrValue(wire_id); + attrs["send_device"] = MakeAttrValue(device.name()); + attrs["send_device_incarnation"] = MakeAttrValue(device.incarnation()); + attrs["recv_device"] = + MakeAttrValue(GetReceiverDevice(context, recv_device_id)); + attrs["client_terminated"] = MakeBoolAttrValue(false); + attrs["T"] = MakeAttrValue(dtype); + + context->EnqueueOp(host_memory ? 
"_HostSend" : "_Send", {&tensor}, + /*output_arity=*/0, std::move(attrs), tensor.device_id(), + future); +} + +XrtTensorHandle EnqueueRecv(XrtTfContext* context, DataType dtype, + int send_device_id, int recv_device_id, + std::string wire_id, bool host_memory) { + protobuf::Map attrs; + const DeviceAttributes& send_device = context->devices().at(send_device_id); + const DeviceAttributes& recv_device = context->devices().at(recv_device_id); + attrs["tensor_name"] = MakeAttrValue(wire_id); + attrs["send_device"] = MakeAttrValue(send_device.name()); + attrs["send_device_incarnation"] = MakeAttrValue(send_device.incarnation()); + attrs["recv_device"] = MakeAttrValue(recv_device.name()); + attrs["client_terminated"] = MakeBoolAttrValue(false); + attrs["tensor_type"] = MakeAttrValue(dtype); + + return std::move(context->EnqueueOp(host_memory ? "_HostRecv" : "_Recv", + /*inputs=*/{}, + /*output_arity=*/1, std::move(attrs), + recv_device_id)[0]); +} + +XrtTensorHandle EnqueueConst(XrtTfContext* context, int device_id, + TensorProto value, bool host_memory) { + protobuf::Map attrs; + attrs["value"] = MakeAttrValue(value); + attrs["dtype"] = MakeAttrValue(value.dtype()); + + return std::move(context->EnqueueOp(host_memory ? "HostConst" : "Const", + /*inputs=*/{}, + /*output_arity=*/1, std::move(attrs), + device_id)[0]); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/client/xrt_tf_client.h b/tensorflow/compiler/xrt/client/xrt_tf_client.h new file mode 100644 index 00000000000..220c57305b0 --- /dev/null +++ b/tensorflow/compiler/xrt/client/xrt_tf_client.h @@ -0,0 +1,338 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains a simplified TF client that talks to a remote TF eager +// service over gRPC. Unlike the standard TF client libraries, this is a small +// self-contained client with few TF core dependencies that allows clients to: +// * send tensors to and from remote devices +// * run ops +// which is all the functionality we need for running XLA programs using XRT. +// The API is intended to be minimal and does not take dependencies on classes +// such as Tensor or Device. +// +// The main feature this client adds over the remote eager TF client is +// batching. Rather than synchronously executing each operator, the client +// accumulates batches of operators and enqueues them as a unit. This is +// important to hide latency; clients of XRT make large numbers of cheap +// operator calls to perform operations like allocation and deallocation. +// +// The auto-batching client is also more ergonomic that using graph mode or +// functions to batch computations. The graphs an XRT client runs may often be +// ephemeral and may rarely be the same. By allowing the XRT to enqueue +// operators eagerly and performing batching in the RPC client we can hide +// latency without requiring users to manage functions/graphs and their +// lifetimes. 
+// +// An important future direction for the client and something that cannot be +// supported by the TF graph mode API is asynchronous execution. However, we +// do not yet use asynchronous execution, mostly because of some problematic +// error handling semantics in the remote eager service API that make it +// difficult to attribute errors to asynchronously-launched operations. + +// TODO(phawkins): handle client shutdown more gracefully; abandon all pending +// operations on shutdown. + +#ifndef TENSORFLOW_COMPILER_XRT_CLIENT_XRT_TF_CLIENT_H_ +#define TENSORFLOW_COMPILER_XRT_CLIENT_XRT_TF_CLIENT_H_ + +#include + +#include "absl/container/inlined_vector.h" +#include "absl/container/node_hash_map.h" +#include "absl/synchronization/notification.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h" +#include "tensorflow/compiler/xrt/client/xrt_tf_client.h" +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/cluster.pb.h" + +namespace tensorflow { + +// Class that manages a connection to a TensorFlow cluster. +class XrtTfClient { + public: + XrtTfClient(ClusterDef cluster_def, + std::shared_ptr channel_cache); + + const ClusterDef& cluster_def() const { return cluster_def_; } + XrtGrpcEagerClientCache* eager_client_cache() const { + return eager_client_cache_.get(); + } + + private: + const ClusterDef cluster_def_; + const std::shared_ptr channel_cache_; + + std::unique_ptr eager_client_cache_; +}; + +class XrtTensorHandle; +class XrtRecvTensorFuture; + +// Class that manages a TensorFlow Eager context. +// TODO(phawkins): Intended to be thread-safe +class XrtTfContext : public std::enable_shared_from_this { + public: + struct Options { + // Enable async mode. + // TODO(phawkins): this is not tested. + bool async = false; + + // Maximum number of ops to keep queued. + int max_queue_size = 100; + }; + static xla::StatusOr> Create( + const Options& options, std::shared_ptr client, + const std::string& job, int task); + + XrtTfContext(const Options& options, std::shared_ptr client, + XrtGrpcEagerClient* eager_client, int64 rendezvous_id, + int64 context_id, std::vector devices, + int cpu_device_id); + + ~XrtTfContext(); + + const Options& options() const { return options_; } + + // The set of devices that were known to the remote worker when the context + // was created. + const std::vector& devices() const { return devices_; } + + // The CPU device on the remote worker. + int cpu_device_id() const { return cpu_device_id_; } + + // Sends `tensor_proto` to `devices_[device_id]`. If `host_memory` is true, + // sends to the tensor to host memory on `device_id`. + XrtTensorHandle SendTensor(std::unique_ptr tensor_proto, + int device_id, bool host_memory = false); + + // Receives `tensor` from the remote host. Does not flush the queue. + std::shared_ptr RecvTensor(const XrtTensorHandle& tensor, + DataType dtype, + bool host_memory); + + // Enqueues an operator onto the remote host. + // 'future' is an optional future that depends on the op. + std::vector EnqueueOp( + absl::string_view name, absl::Span inputs, + int output_arity, protobuf::Map attrs, int device_id, + std::shared_ptr future = {}); + + // Registers a function `def` on the remote host. 
+ Status RegisterFunction(const FunctionDef& def); + + // Flushes any enqueued work to the remote host. + void FlushQueue(); + + private: + friend class XrtTensorHandle; + + // An operation ID on the remote worker. + typedef int64 OperationId; + + // Names a tensor on the remote worker. + typedef std::pair TensorId; + + // An Operation describes an operation to be enqueued to a remote worker, + // together with its consumers. We need to know the set of consumers so we + // can propagate errors to dependent operations in the event of failure. + struct Operation { + explicit Operation(OperationId id) : id(id) {} + + OperationId id; + + // Operations that depend on the output of this operation. + absl::InlinedVector consumers; + + // Tensor futures that consume the output of this operator. + std::vector> tensor_futures; + }; + + // Allocates and returns new operation. Does not return ownership. + Operation* AddOperation() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Flushes the queue of pending work. + void FlushQueueLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Shuts down the context, abandoning any pending operations. + Status Close(); + + // Enqueues an operation that releases the client's handle to a remote tensor. + void EnqueueDecrefTensorHandle(eager::RemoteTensorHandle handle); + + // Reports the failure of a set of operations. Propagates the failure to + // any dependent operations. + void ReportError(absl::Span op_ids, Status status) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Looks up operation 'id'. Dies if 'id' does not exist. + Operation* LookupOperation(OperationId id) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Deletes an operation 'id'. Dies if 'id' does not exist. + void DeleteOperation(OperationId id) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + const Options options_; + + std::shared_ptr tf_client_; + XrtGrpcEagerClient* eager_client_; + + // The rendezvous ID to use when performing Send/Recv operations. + const int64 rendezvous_id_; + + // A unique ID for this context on the remote worker. + const int64 context_id_; + + // The set of devices present on the remote worker. + std::vector devices_; + + // The CPU device on the remote worker. A CPU device must exist or Create() + // fails with an error. + int cpu_device_id_; + + absl::Mutex mu_; + + // The next available operation ID. + int64 next_op_id_ GUARDED_BY(mu_) = 0; + // The set of pending operations. + absl::node_hash_map operations_ GUARDED_BY(mu_); + + // The queue of operations to run next. + std::unique_ptr enqueue_request_ GUARDED_BY(mu_); + + // Requests the queue thread to flush the queue. + bool flush_requested_ GUARDED_BY(mu_) = false; + + // Notifies the queue thread that we are shutting down. + bool shutting_down_ GUARDED_BY(mu_) = false; + + // Thread responsible for enqueueing queued ops to the remote worker. + // Also responsible for sending regular RPCs to keep the connection alive. + void QueueThread(); + std::unique_ptr queue_thread_; +}; + +// RAII class that owns a reference to a remote TF tensor. +class XrtTensorHandle { + public: + XrtTensorHandle(); + XrtTensorHandle(std::shared_ptr context, int device_id, + XrtTfContext::TensorId tensor_id) + : context_(context), device_id_(device_id), tensor_id_(tensor_id) {} + ~XrtTensorHandle(); + + // Moveable but not copyable; the handle cannot be duplicated. 
+ XrtTensorHandle(const XrtTensorHandle&) = delete; + XrtTensorHandle& operator=(const XrtTensorHandle&) = delete; + XrtTensorHandle(XrtTensorHandle&& other); + XrtTensorHandle& operator=(XrtTensorHandle&& other); + + // Serializes the handle's ID to a protocol buffer. + void Serialize(eager::RemoteTensorHandle* proto) const; + + // The context to which the handle belongs. + const std::shared_ptr& context() const { return context_; } + + int device_id() const { return device_id_; } + void set_device_id(int device_id) { device_id_ = device_id; } + + // Returns true if the handle refers to valid context. + bool valid() const { return context_ != nullptr; } + + private: + friend class XrtTfContext; + std::shared_ptr context_; + int device_id_ = -1; + XrtTfContext::TensorId tensor_id_ = {-1, -1}; +}; + +// Future that holds the result of a RecvTensor call. +class XrtRecvTensorFuture { + public: + XrtRecvTensorFuture() = default; + + // Returns either an error or a pointer to the RecvTensorResponse. + // Blocks waiting for the future if it is not yet available. + xla::StatusOr Get() { + done_.WaitForNotification(); + absl::MutexLock lock(&mu_); + if (!status_.ok()) return status_; + return &value_; + } + + private: + friend class XrtTfContext; + + // Marks the future as completed, with `status`. + void Notify(Status status) { + absl::MutexLock lock(&mu_); + if (done_.HasBeenNotified()) { + LOG(ERROR) << "Duplicate notification for XrtRecvTensorFuture. " + "Previous status: " + << status_ << " new status: " << status; + return; + } + status_ = status; + done_.Notify(); + } + + absl::Mutex mu_; + absl::Notification done_; + Status status_ GUARDED_BY(mu_); + RecvTensorResponse value_ GUARDED_BY(mu_); + + CallOptions call_options_; +}; + +// This gets a unique wire ID. We add a random identifier so that if the +// worker has other clients that it is servicing, we don't have collisions. +std::string XrtGetUniqueWireID(); + +// Helpers for enqueuing common TF ops. + +// Enqueues a _Send operator that sends a tensor located on a device on the +// remote worker. If recv_device_id < 0 the target of the send is the client, +// and a fake device name is used (since the client has no real name in the +// TF cluster). +// 'future' may be null. If non-null it gives a future that depends on the +// output of the send and that must be aborted if the send fails. +void EnqueueSend(XrtTfContext* context, const XrtTensorHandle& tensor, + DataType dtype, int recv_device_id, std::string wire_id, + bool host_memory, + std::shared_ptr future = {}); + +// Enqueues a _Recv operator that receives a tensor onto a remote device. +XrtTensorHandle EnqueueRecv(XrtTfContext* context, DataType dtype, + int send_device_id, int recv_device_id, + std::string wire_id, bool host_memory); + +// Enqueues a Const operator operator on a remote device. +XrtTensorHandle EnqueueConst(XrtTfContext* context, int device_id, + TensorProto value, bool host_memory); + +// Helpers for building AttrValue protos. We have our own versions of these +// to avoid depending on TF framework code. 
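+//
+// For example (a sketch; mirrors the call sites in xrt_tf_client.cc):
+//   protobuf::Map<string, AttrValue> attrs;
+//   attrs["T"] = MakeAttrValue(DT_FLOAT);
+//   attrs["client_terminated"] = MakeBoolAttrValue(false);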
+AttrValue MakeAttrValue(std::string s); +AttrValue MakeAttrValue(int64 i); +AttrValue MakeBoolAttrValue(bool b); +AttrValue MakeAttrValue(DataType dtype); +AttrValue MakeAttrValue(TensorProto tensor); +AttrValue MakeAttrValue(absl::Span dtypes); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_CLIENT_XRT_TF_CLIENT_H_ diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 42ef88168af..d89dc4642be 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/compiler/xrt/xrt_device.h" #include "tensorflow/compiler/xrt/xrt_state.h" +#include "tensorflow/compiler/xrt/xrt_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" @@ -41,6 +42,12 @@ namespace tensorflow { namespace { +struct InputBuffers { + std::vector> input_tuples; + std::vector input_allocations; + std::vector input_pointers; +}; + uint32 InitialRandomSeed() { // Support plumbing the TF seed through to XLA is being worked on. // If a user wants deterministic behavior, their best option @@ -64,52 +71,134 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } -// Populates `inputs` with the input tensors to the computation. -Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm, - bool release_inputs, - std::vector* input_tuples, - std::vector* input_allocations, - std::vector* input_pointers) { - std::vector input_uids; - OpInputList arg_list; - TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list)); - - // Concatenate all input uids from list of scalars-or-vectors carrying them. - for (int i = 0; i < arg_list.size(); ++i) { - const Tensor& arg = arg_list[i]; - if (TensorShapeUtils::IsScalar(arg.shape())) { - input_uids.push_back(arg.scalar()()); - } else { - TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape())); - auto arg_vec = arg.vec(); - const int64 num_elts = arg.shape().dim_size(0); - for (int i = 0; i < num_elts; ++i) { - input_uids.push_back(arg_vec(i)); - } - } - } - - // Retrieve allocations for the uids. - input_tuples->resize(input_uids.size()); - input_pointers->resize(input_uids.size()); - for (int i = 0; i < input_uids.size(); ++i) { - const int64 input_uid = input_uids[i]; +xla::StatusOr GetInputBuffers( + ResourceMgr* rm, const std::vector& input_coords, + bool release_inputs) { + InputBuffers input_buffers; + input_buffers.input_tuples.reserve(input_coords.size()); + input_buffers.input_allocations.reserve(input_coords.size()); + input_buffers.input_pointers.reserve(input_coords.size()); + for (size_t i = 0; i < input_coords.size(); ++i) { + XRTTupleAllocation* tuple; TF_RETURN_IF_ERROR( - XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i])); + XRTTupleAllocation::Lookup(rm, input_coords[i].handle, &tuple)); + input_buffers.input_tuples.emplace_back(tuple); if (release_inputs) { // We are holding a reference to the tuple, so we can safely delete it // from the resource manager here. 
- TF_RETURN_IF_ERROR( - XRTTupleAllocation::DeleteFromResourceManager(rm, input_uid)); - VLOG(2) << "Released allocation handle " << input_uid; + TF_RETURN_IF_ERROR(XRTTupleAllocation::DeleteFromResourceManager( + rm, input_coords[i].handle)); + VLOG(2) << "Released allocation handle " << input_coords[i].handle; + } + if (input_coords[i].index.empty()) { + input_buffers.input_allocations.emplace_back(tuple->ToShapedBuffer()); + } else { + xla::ShapedBuffer shaped_buffer = tuple->ToShapedBuffer(); + TF_ASSIGN_OR_RETURN(xla::ShapedBuffer sub_shaped_buffer, + shaped_buffer.SubShapedBuffer(input_coords[i].index)); + input_buffers.input_allocations.emplace_back( + std::move(sub_shaped_buffer)); } - XRTTupleAllocation* tuple = (*input_tuples)[i]; - input_allocations->emplace_back(tuple->ToShapedBuffer()); } - for (int i = 0; i < input_uids.size(); ++i) { - (*input_pointers)[i] = &(*input_allocations)[i]; + for (size_t i = 0; i < input_buffers.input_allocations.size(); ++i) { + input_buffers.input_pointers.push_back(&input_buffers.input_allocations[i]); } - return Status::OK(); + return std::move(input_buffers); +} + +xla::StatusOr GetChainedOpInputs( + const xrt::XRTChainedExecuteOp& op, int current_index, + absl::Span> ops_outputs) { + InputBuffers input_buffers; + input_buffers.input_tuples.reserve(op.inputs_size()); + input_buffers.input_allocations.reserve(op.inputs_size()); + input_buffers.input_pointers.reserve(op.inputs_size()); + for (auto& input : op.inputs()) { + if (input.op_index() >= current_index) { + return errors::InvalidArgument( + "Input index ", input.op_index(), + " is above the current position: ", current_index); + } + input_buffers.input_tuples.emplace_back(ops_outputs[input.op_index()]); + // Thanks to the greatness of proto3, there is no way to query for + // explicitly set fields, so the default for output_index (zero) means no + // sub-index. As consequence, the real index is output_index - 1. 
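+    // For example, output_index == 1 selects tuple sub-index {0} of the
+    // producing op's output, while output_index == 0 selects the whole output.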
+ if (input.output_index() == 0) { + input_buffers.input_allocations.emplace_back( + input_buffers.input_tuples.back()->ToShapedBuffer()); + } else { + xla::ShapedBuffer shaped_buffer = + input_buffers.input_tuples.back()->ToShapedBuffer(); + TF_ASSIGN_OR_RETURN( + xla::ShapedBuffer sub_shaped_buffer, + shaped_buffer.SubShapedBuffer({input.output_index() - 1})); + input_buffers.input_allocations.emplace_back( + std::move(sub_shaped_buffer)); + } + } + for (size_t i = 0; i < input_buffers.input_allocations.size(); ++i) { + input_buffers.input_pointers.push_back(&input_buffers.input_allocations[i]); + } + return std::move(input_buffers); +} + +xla::StatusOr> ExecuteComputation( + OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, + xla::LocalExecutable* executable, const InputBuffers& input_buffers, + se::Stream* stream, int rng_seed) { + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(device_ref->backend()->memory_allocator()); + run_options.set_intra_op_thread_pool(&context->eigen_cpu_device()); + run_options.set_rng_seed(rng_seed); + + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + TF_ASSIGN_OR_RETURN( + xla::ScopedShapedBuffer run_result, + executable->Run(input_buffers.input_pointers, run_options)); + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time: " << elapsed << "us"; + + auto shaped_buffer = run_result.release(); + XRTTupleAllocation* output_tuple; + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, device_ref->backend(), device_ref->device_ordinal(), + &output_tuple)); + RefPtr output_tuple_ptr(output_tuple); + + // The ScopedShapedBuffer returned by the executable Run() API, in case of + // input/output buffer aliasing, might have holes in it, which need to be + // filled using the proper input tuples buffers which are the source of + // aliasing. + const xla::HloInputOutputAliasConfig& input_output_alias = + executable->executable()->module().input_output_alias_config(); + auto alias_function = + [&](const xla::ShapeIndex& output_index, + const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { + TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size()); + return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias + ? output_tuple->AliasBufferFrom( + *input_buffers.input_tuples[alias.parameter_number], + alias.parameter_index, output_index) + : Status::OK(); + }; + TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function)); + + return std::move(output_tuple_ptr); +} + +xla::StatusOr> ExecuteComputation( + OpKernelContext* context, ResourceMgr* rm, + XRTGenericDeviceAccessor::ScopedRef* device_ref, + xla::LocalExecutable* executable, + const std::vector& input_coords, bool release_inputs, + se::Stream* stream, int rng_seed) { + TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, + GetInputBuffers(rm, input_coords, release_inputs)); + return ExecuteComputation(context, device_ref, executable, input_buffers, + stream, rng_seed); } // XRTExecuteOp @@ -162,31 +251,6 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { rm->default_container(), kXRTCompilationCacheResourceName, &cache)); core::ScopedUnref cache_unref(cache); - std::unique_ptr entry; - TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry)); - - if (release_compilation) { - // Process-wide cache of XLA executables. 
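+  // For each user-specified alias, point the aliased output sub-buffer back
+  // at the corresponding input sub-buffer; aliases that are not user-specified
+  // are left untouched.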
- TF_RETURN_IF_ERROR(cache->Release(compilation_handle)); - VLOG(2) << "Released compilation handle " << compilation_handle; - } - - std::vector input_tuples; - // Make a cleanup method so that we can safely return in error conditions - // without leaking references to allocations. - auto buffer_releaser = gtl::MakeCleanup([&input_tuples]() { - for (auto tuple : input_tuples) { - if (tuple != nullptr) { - tuple->Unref(); - } - } - }); - std::vector input_allocations; - std::vector input_pointers; - TF_RETURN_IF_ERROR(GetComputationInputs(context, rm, release_inputs, - &input_tuples, &input_allocations, - &input_pointers)); - // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. class XRTGenericDeviceAccessor::ScopedRef device_ref; @@ -201,86 +265,107 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { se::Stream* stream = context->op_device_context() ? context->op_device_context()->stream() : nullptr; + TF_ASSIGN_OR_RETURN(std::vector input_coords, + GetComputationInputs(context, rm, "input_handles")); - // Execute the computation. - VLOG(2) << "Executing computation."; - xla::ExecutableRunOptions run_options; - run_options.set_stream(stream); - run_options.set_allocator(device_ref.backend()->memory_allocator()); - run_options.set_intra_op_thread_pool(&context->eigen_cpu_device()); - run_options.set_rng_seed(rng_seed); - - Env* env = Env::Default(); - auto start_time = env->NowMicros(); - + std::unique_ptr entry; + TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry)); xla::LocalExecutable* executable = entry->get().get_executable(); - auto run_result = executable->Run(input_pointers, run_options); - if (!run_result.ok()) { - return run_result.status(); + if (release_compilation) { + // Process-wide cache of XLA executables. + TF_RETURN_IF_ERROR(cache->Release(compilation_handle)); + VLOG(2) << "Released compilation handle " << compilation_handle; } - auto elapsed = env->NowMicros() - start_time; - VLOG(2) << "Elapsed time: " << elapsed << "us"; + TF_ASSIGN_OR_RETURN( + RefPtr output_tuple, + ExecuteComputation(context, rm, &device_ref, executable, input_coords, + release_inputs, stream, rng_seed)); - auto scoped_buffer = run_result.ConsumeValueOrDie(); - auto shaped_buffer = scoped_buffer.release(); - XRTTupleAllocation* output_tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), - &output_tuple)); - - // The ScopedShapedBuffer returned by the executable Run() API, in case of - // input/output buffer aliasing, might have holes in it, which need to be - // filled using the proper input tuples buffers which are the source of - // aliasing. - const xla::HloInputOutputAliasConfig& input_output_alias = - executable->executable()->module().input_output_alias_config(); - auto alias_function = - [&](const xla::ShapeIndex& output_index, - const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { - TF_RET_CHECK(alias.parameter_number < input_tuples.size()); - return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias - ? 
output_tuple->AliasBufferFrom( - *input_tuples[alias.parameter_number], - alias.parameter_index, output_index) - : Status::OK(); - }; - TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function)); - - if (config_proto.return_exploded_tuple() && - output_tuple->on_device_shape().IsTuple()) { - int64 tuple_element_count = - xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); - Tensor* output_tensor; - TF_RETURN_IF_ERROR(context->allocate_output( - 0, TensorShape({tuple_element_count}), &output_tensor)); - - for (int64 i = 0; i < tuple_element_count; ++i) { - xla::ShapeIndex shape_index; - shape_index.push_back(i); - - XRTTupleAllocation* suballocation; - TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( - output_tuple, shape_index, &suballocation, - /*alias_parent_allocation=*/false)); - int64 key; - TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); - output_tensor->vec()(i) = key; - } - output_tuple->Unref(); - } else { - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - context->allocate_output(0, TensorShape({}), &output_tensor)); - int64 key; - TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); - output_tensor->scalar()() = key; - } - return Status::OK(); + return CreateExecuteOutput(context, rm, std::move(output_tuple), + config_proto.return_exploded_tuple()); } XRTExecuteOp::~XRTExecuteOp() = default; +class XRTExecuteChainedOp : public AsyncOpKernel { + public: + explicit XRTExecuteChainedOp(OpKernelConstruction* context); + ~XRTExecuteChainedOp() override; + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + + private: + Status DoWork(OpKernelContext* context); +}; + +XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + +void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context, + DoneCallback done) { + // Schedule onto the default queue, for unbounded concurrency. See b/73520706 + Env::Default()->SchedClosure([this, context, done]() { + OP_REQUIRES_OK_ASYNC(context, DoWork(context), done); + done(); + }); +} + +Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { + VLOG(1) << "XRTExecuteChainedOp::Compute"; + ResourceMgr* rm; + TF_RETURN_IF_ERROR( + XRTGenericDeviceAccessor::GetResourceManager(context, &rm)); + + const Tensor& execution_plan = context->input(0); + TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape())); + xrt::XRTChainedExecutePlan plan; + TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar()())); + + const Tensor& execution_config = context->input(1); + TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); + xrt::XRTChainedExecuteConfig config; + TF_RET_CHECK(config.ParseFromString(execution_config.scalar()())); + + XRTCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->Lookup( + rm->default_container(), kXRTCompilationCacheResourceName, &cache)); + core::ScopedUnref cache_unref(cache); + + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + class XRTGenericDeviceAccessor::ScopedRef device_ref; + TF_RETURN_IF_ERROR( + XRTGenericDeviceAccessor::InitScopedRef(context, 0, &device_ref)); + + int rng_seed = config.rng_seed(); + if (rng_seed == 0) { + rng_seed = GetXLARandomSeed(); + } + + se::Stream* stream = context->op_device_context() + ? 
context->op_device_context()->stream() + : nullptr; + auto execute_op = + [&](const xrt::XRTChainedExecuteOp& op, int current_index, + absl::Span> ops_outputs) + -> xla::StatusOr> { + TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, + GetChainedOpInputs(op, current_index, ops_outputs)); + + std::unique_ptr entry; + TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry)); + xla::LocalExecutable* executable = entry->get().get_executable(); + + return ExecuteComputation(context, &device_ref, executable, input_buffers, + stream, rng_seed); + }; + + return ExecuteChained(context, rm, plan, config, execute_op); +} + +XRTExecuteChainedOp::~XRTExecuteChainedOp() = default; + } // namespace REGISTER_KERNEL_BUILDER(Name("XRTExecute") @@ -299,4 +384,18 @@ REGISTER_KERNEL_BUILDER(Name("XRTExecute") .HostMemory("output_handle"), XRTExecuteOp); +REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained") + .Device(DEVICE_XLA_CPU) + .HostMemory("execution_plan") + .HostMemory("execution_config") + .HostMemory("output_handle"), + XRTExecuteChainedOp); + +REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained") + .Device(DEVICE_XLA_GPU) + .HostMemory("execution_plan") + .HostMemory("execution_config") + .HostMemory("output_handle"), + XRTExecuteChainedOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index 343f43b7159..9020fe8ea78 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -147,4 +147,9 @@ REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), XRTReleaseAllAllocationsOp); +REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_GPU), + XRTCompactAllocationsOp); +REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_CPU), + XRTCompactAllocationsOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 6af73ecc853..8a54e0987e5 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -688,6 +688,27 @@ class XRTReleaseAllAllocationsOp : public OpKernel { } }; +template +class XRTCompactAllocationsOp : public OpKernel { + public: + explicit XRTCompactAllocationsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~XRTCompactAllocationsOp() override = default; + XRTCompactAllocationsOp(const XRTCompactAllocationsOp&) = delete; + XRTCompactAllocationsOp& operator=(const XRTCompactAllocationsOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTCompactAllocationsOp::Compute"; + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + class DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); + OP_REQUIRES_OK(ctx, + XRTTupleAllocation::CompactAllocations( + rm, device_ref.backend(), device_ref.device_ordinal())); + } +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc index 4f59fccaf12..a52b2a78455 100644 --- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc @@ -50,4 +50,22 @@ computation. 'Ninputs' is the number of input handles. 
)"); +REGISTER_OP("XRTExecuteChained") + .Input("execution_plan: string") + .Input("execution_config: string") + .Output("output_handle: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + return tensorflow::shape_inference::ScalarShape(c); + }) + .Doc( + R"( +Runs a sequence of previously-compiled computations on a core. +The 'execution_plan' input is a serialized xrt::XRTChainedExecutePlan proto +describing the post-order of the chained execution. +The 'execution_config' input is a serialized xrt::XRTChainedExecuteConfig +proto describing the configuration for the chained execution operation. +Returns one of more int64 handles to the XRT device data generated by the +chained execution. +)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 87546fce4e4..6d4e70fad53 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -191,4 +191,14 @@ REGISTER_OP("XRTReleaseAllAllocations") Discards all the XRT allocations. All the client held handles will be invalid. )"); +REGISTER_OP("XRTCompactAllocations") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc( + R"( +Runs a device memory compaction cycle. This copies the device data behind the +currently alive allocation handles into host memory, releases the device memory +backing the handles, and re-allocate and send back the data to the device. +This operation helps with device memory fragmentation. +)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 1111f824051..305b3a67fae 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -70,6 +70,14 @@ xla::LiteralProto TwoElementTuple() { return tuple.ToProto(); } +xla::LiteralProto BasedTwoElementTuple(float base) { + auto array = xla::LiteralUtil::CreateR1({base, base + 1}); + auto matrix = xla::LiteralUtil::CreateR2( + {{base + 2, base + 3}, {base + 4, base + 5}}); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); + return tuple.ToProto(); +} + xla::LiteralProto ScalarLiteral() { auto scalar = xla::LiteralUtil::CreateR0(12.0f); return scalar.ToProto(); @@ -167,6 +175,18 @@ xla::XlaComputation AddAndScale() { return builder.Build().ValueOrDie(); } +xla::XlaComputation SubAndScale() { + xla::XlaBuilder builder("SubAndScale"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1"); + auto sum = xla::Sub(p0, p1); + auto c = xla::ConstantR0(&builder, 11.0f); + xla::Mul(sum, c); + return builder.Build().ValueOrDie(); +} + xla::XlaComputation Dot() { xla::XlaBuilder builder("Dot"); auto p0 = xla::Parameter( @@ -369,8 +389,8 @@ TEST(RawApiTest, AllocAndRewrite) { auto read_back = ops::XRTReadLiteral(root, handle); TF_ASSERT_OK(root.status()); - tensorflow::ClientSession session(root); - std::vector outputs; + ClientSession session(root); + std::vector outputs; TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); EXPECT_EQ(outputs.size(), 2); @@ -378,7 +398,6 @@ TEST(RawApiTest, AllocAndRewrite) { xla::LiteralProto response; EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); - outputs.clear(); xla::LiteralProto new_literal = xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto(); 
@@ -390,7 +409,6 @@ TEST(RawApiTest, AllocAndRewrite) { TF_EXPECT_OK(session.Run({write_op}, &outputs)); EXPECT_EQ(outputs.size(), 1); EXPECT_EQ(allocation_handle, outputs[0].scalar()()); - outputs.clear(); auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle)); TF_EXPECT_OK(session.Run({read_after_write}, &outputs)); @@ -404,8 +422,7 @@ TEST(RawApiTest, AllocAndRewrite) { release_tensor.flat()(0) = allocation_handle; auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); - TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, - &outputs)); + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs)); } TEST(RawApiTest, AllocReleaseMany) { @@ -425,8 +442,8 @@ TEST(RawApiTest, AllocReleaseMany) { auto handle2 = ops::XRTAllocate(root, value2); TF_ASSERT_OK(root.status()); - tensorflow::ClientSession session(root); - std::vector outputs; + ClientSession session(root); + std::vector outputs; TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs)); EXPECT_EQ(outputs.size(), 2); @@ -438,9 +455,7 @@ TEST(RawApiTest, AllocReleaseMany) { release_tensor.flat()(1) = allocation_handle2; auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); - outputs.clear(); - TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, - &outputs)); + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs)); } TEST(RawApiTest, CompileAndReleaseMany) { @@ -467,13 +482,7 @@ TEST(RawApiTest, CompileAndReleaseMany) { .ToProto(); StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot()); - xrt::XRTExecutionConfig e; - e.set_release_input_handles(true); - e.set_release_compilation_handle(false); - Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); - auto e_config = - ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); auto computation1 = ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString()); auto c_handle1 = ops::XRTCompile(root, computation1); @@ -495,9 +504,7 @@ TEST(RawApiTest, CompileAndReleaseMany) { release_tensor.flat()(1) = compilation_handle2; auto release = ops::XRTReleaseCompilationHandle(root, release_tensor); - outputs.clear(); - TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, - &outputs)); + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs)); } TEST(RawApiTest, AllocAndClearAll) { @@ -511,8 +518,8 @@ TEST(RawApiTest, AllocAndClearAll) { auto handle = ops::XRTAllocate(root, value); TF_ASSERT_OK(root.status()); - tensorflow::ClientSession session(root); - std::vector outputs; + ClientSession session(root); + std::vector outputs; TF_EXPECT_OK(session.Run({handle}, &outputs)); EXPECT_EQ(outputs.size(), 1); @@ -520,14 +527,13 @@ TEST(RawApiTest, AllocAndClearAll) { auto clear_all = ops::XRTReleaseAllAllocations(root); - outputs.clear(); - TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, - {clear_all}, &outputs)); + TF_EXPECT_OK( + session.Run(ClientSession::FeedType(), {}, {clear_all}, &outputs)); EXPECT_EQ(outputs.size(), 0); auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle)); EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(), - tensorflow::error::Code::NOT_FOUND); + error::Code::NOT_FOUND); } TEST(RawApiTest, ReadAndWriteState) { @@ -543,10 +549,10 @@ TEST(RawApiTest, ReadAndWriteState) { root.WithControlDependencies(read_back), handle); TF_ASSERT_OK(root.status()); - tensorflow::ClientSession session(root); - 
std::vector outputs; - TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back}, - {release}, &outputs)); + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK( + session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs)); xla::LiteralProto response; EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); @@ -686,6 +692,196 @@ TEST(RawApiTest, MakeTuple) { EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1)); } +TEST(RawApiTest, ExecuteChainedOpByOp) { + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + + auto make_computation = [](const std::function& fn) { + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot()); + return c.SerializeAsString(); + }; + + auto c_add_scale = make_computation(AddAndScale); + auto c_sub_scale = make_computation(SubAndScale); + + auto c_add_scale_op = ops::XRTCompile( + root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale)); + auto c_sub_scale_op = ops::XRTCompile( + root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale)); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK( + session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 c_add_scale_handle = outputs[0].scalar()(); + int64 c_sub_scale_handle = outputs[1].scalar()(); + + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({8.0f, 5.0f}); + + auto p0_handle = ops::XRTAllocate( + root, + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString())); + auto p1_handle = ops::XRTAllocate( + root, + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString())); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(false); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto result0 = ops::XRTExecute(root, Input(c_add_scale_handle), e_config, + {Output(p0_handle), Output(p1_handle)}); + auto result1 = ops::XRTExecute(root, Input(c_sub_scale_handle), e_config, + {Output(p0_handle), Output(p1_handle)}); + auto result = ops::XRTExecute(root, Input(c_add_scale_handle), e_config, + {result0.output_handle, result1.output_handle}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({-150.0f, -36.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, ExecuteChained) { + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + + auto make_computation = [](const std::function& fn) { + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + 
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot()); + return c.SerializeAsString(); + }; + + auto c_add_scale = make_computation(AddAndScale); + auto c_sub_scale = make_computation(SubAndScale); + + auto c_add_scale_op = ops::XRTCompile( + root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale)); + auto c_sub_scale_op = ops::XRTCompile( + root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale)); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK( + session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 c_add_scale_handle = outputs[0].scalar()(); + int64 c_sub_scale_handle = outputs[1].scalar()(); + + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({8.0f, 5.0f}); + + auto p0_handle_op = ops::XRTAllocate( + root, + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString())); + auto p1_handle_op = ops::XRTAllocate( + root, + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString())); + + TF_EXPECT_OK(session.Run({p0_handle_op, p1_handle_op}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 p0_handle = outputs[0].scalar()(); + int64 p1_handle = outputs[1].scalar()(); + + xrt::XRTChainedExecuteConfig config; + auto config_const = + ops::Const(root.WithDevice("/device:CPU:0"), config.SerializeAsString()); + + xrt::XRTChainedExecutePlan plan; + xrt::XRTChainedExecuteOp* op; + xrt::XRTChainedExecuteOp::Input* input; + xrt::XRTChainedExecuteOp::Output* output; + + // Index 0 + op = plan.add_ops(); + op->set_data_handle(p0_handle); + + // Index 1 + op = plan.add_ops(); + op->set_data_handle(p1_handle); + + // Index 2 + op = plan.add_ops(); + op->set_computation_handle(c_add_scale_handle); + input = op->add_inputs(); + input->set_op_index(0); + input = op->add_inputs(); + input->set_op_index(1); + + // Index 3 + op = plan.add_ops(); + op->set_computation_handle(c_sub_scale_handle); + input = op->add_inputs(); + input->set_op_index(0); + input = op->add_inputs(); + input->set_op_index(1); + + // Index 4 + op = plan.add_ops(); + op->set_computation_handle(c_add_scale_handle); + input = op->add_inputs(); + input->set_op_index(2); + input = op->add_inputs(); + input->set_op_index(3); + output = op->add_outputs(); + output->set_result_index(0); + + auto plan_const = + ops::Const(root.WithDevice("/device:CPU:0"), plan.SerializeAsString()); + auto result = ops::XRTExecuteChained(root, plan_const, config_const); + TF_ASSERT_OK(root.status()); + + TF_EXPECT_OK(session.Run({result}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + auto handles_vec = outputs[0].vec(); + EXPECT_EQ(handles_vec.size(), 1); + + auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(0))); + TF_ASSERT_OK(root.status()); + + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({-150.0f, -36.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, CompileAndExecute) { xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f}); @@ -831,8 +1027,8 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { ClientSession session(root); std::vector outputs; - TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), - 
{c_handle.program_shape}, {release}, &outputs)); + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {c_handle.program_shape}, + {release}, &outputs)); xla::ProgramShapeProto program_shape_proto; EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec()(0))); @@ -846,16 +1042,16 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { xla::ProgramShape xla_program_shape = XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes)); - EXPECT_TRUE(xla::LayoutUtil::Equal( + EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) .layout())); - EXPECT_TRUE(xla::LayoutUtil::Equal( + EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(), xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1}) .layout())); - EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(), - xla_program_shape.result().layout())); + EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()( + program_shape.result().layout(), xla_program_shape.result().layout())); } TEST(RawApiTest, DotGeneralWithLayoutTest) { @@ -1146,9 +1342,8 @@ TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) { root.WithControlDependencies(read_back), result); TF_ASSERT_OK(root.status()); - outputs.clear(); - TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back}, - {release}, &outputs)); + TF_EXPECT_OK( + session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs)); xla::Literal exec_literal = ReadOutputLiteral(outputs, 0); auto exec_literal_parts = exec_literal.DecomposeTuple(); @@ -1165,8 +1360,7 @@ TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) { root.WithControlDependencies(read_handle), Input(alloc_handle)); TF_ASSERT_OK(root.status()); - outputs.clear(); - TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_handle}, + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle}, {release_handle}, &outputs)); xla::Literal return_literal = ReadOutputLiteral(outputs, 0); @@ -1235,6 +1429,65 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xla::Shape(program_shape.result()), xla::S64)); } +// Tests the XRT device memory compation API (XRTCompactAllocations). +TEST(RawApiTest, TestDeviceMemoryCompaction) { + static const int kNumAllocs = 32; + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + + std::vector allocs(kNumAllocs); + std::vector handle_outputs; + for (int i = 0; i < kNumAllocs; ++i) { + *allocs[i].mutable_value() = BasedTwoElementTuple(i * 4.0f); + auto value = ops::Const(root.WithDevice("/device:CPU:0"), + allocs[i].SerializeAsString()); + handle_outputs.push_back(ops::XRTAllocate(root, value)); + } + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run(handle_outputs, &outputs)); + EXPECT_EQ(outputs.size(), handle_outputs.size()); + + std::vector handles; + for (auto& output : outputs) { + handles.push_back(output.scalar()()); + } + // Create holes by releasing even allocations. + std::vector handle_releases; + for (size_t i = 0; i < handles.size(); i += 2) { + handle_releases.push_back( + ops::XRTReleaseAllocationHandle(root, Input(handles[i]))); + } + TF_ASSERT_OK(root.status()); + + TF_EXPECT_OK( + session.Run(ClientSession::FeedType(), {}, handle_releases, &outputs)); + + // Run the compaction API. 
+ auto compact_op = ops::XRTCompactAllocations(root); + TF_EXPECT_OK( + session.Run(ClientSession::FeedType(), {}, {compact_op}, &outputs)); + + // Read back the allocation left at odd indices. + std::vector read_outputs; + for (size_t i = 1; i < handles.size(); i += 2) { + read_outputs.push_back(ops::XRTReadLiteral(root, Input(handles[i]))); + } + TF_ASSERT_OK(root.status()); + + TF_EXPECT_OK(session.Run(read_outputs, &outputs)); + EXPECT_EQ(outputs.size(), read_outputs.size()); + + // Verify that everything got moved correctly and the device data matches what + // we have on record. + for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) { + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[j].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response)); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 84adee73928..a598b800329 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -3,9 +3,9 @@ syntax = "proto3"; package xrt; import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; import "tensorflow/compiler/xla/xla.proto"; import "tensorflow/compiler/xla/xla_data.proto"; -import "tensorflow/compiler/xla/service/hlo.proto"; message DeviceAssignment { message ComputationDevice { @@ -106,3 +106,61 @@ message XRTExecutionConfig { // allocations, one for each of the first-level elements of the result tuple. bool return_exploded_tuple = 7; } + +message XRTChainedExecuteConfig { + // If non-zero, rng_seed to reset the core with. + uint32 rng_seed = 1; + // Which model-parallel computation to run from the compiled bundle. + int32 core_index_in_replica = 2; + // Optional key to disambiguate between executions. This is only needed if + // multiple host send/recvs may be outstanding concurrently with executions. + string execution_instance_key = 3; +} + +// A single chained execute operation. An operation can either be a device data +// load, or an existing (as in, previously compiled and accessible via its int64 +// handle) XLA computation execution. +message XRTChainedExecuteOp { + // Represents an input for this operation. + message Input { + // The index within the XRTChainedExecutePlan.ops post-order of the source + // operation for this input. + int64 op_index = 1; + // The output index of the value generated by the operation at op_index. + // Zero (default value) means no index ({}) while if an indexing is + // required, output_index needs to be set to index+1. + // Thanks proto3! + int64 output_index = 2; + } + // Represents an output of the XRTChainedExecute operation, which should + // originate by the output of this operation. + message Output { + // The index in the value generated by this operation, which should be + // forwarded as XRTChainedExecute output. If output_index is zero (default + // value) the whole output will be used as result. This means that if the + // output shape is a tuple, the result will be the full tuple. Otherwise the + // real sub-tuple index will be output_index - 1. + int64 output_index = 1; + // The index in the vector of the results returned by the XRTChainedExecute + // operation, where this output should be forwarded. + int64 result_index = 2; + } + + oneof op_oneof { + // The handle to an existing XRT device data. + int64 data_handle = 1; + // The handle to an existing XRT compiled computation. 
+ int64 computation_handle = 2; + } + // The outputs of this XRTChainedExecuteOp operation. + repeated Output outputs = 3; + // The inputs of this XRTChainedExecuteOp operation. If data_handle is set, + // there are no inputs. + repeated Input inputs = 4; +} + +// Execution plan for the XRTChainedExecute operation. +message XRTChainedExecutePlan { + // The post order with the XRT computations to be executed. + repeated XRTChainedExecuteOp ops = 1; +} diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 1b3bcbea4c1..fa25b727a3d 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_state.h" #include + #include #include #include @@ -73,8 +74,11 @@ class BufferAllocStats { const char* kTupleContainer = "tuples"; int64 get_uid() { - uint64 unsigned_rand = random::New64() & INT64_MAX; - return static_cast(unsigned_rand); + int64 uid; + do { + uid = random::New64() & INT64_MAX; + } while (uid == XRTTupleAllocation::InvalidKey()); + return uid; } BufferAllocStats* GetAllocStats() { @@ -113,10 +117,10 @@ Status AllocateScopedShapedBuffer( xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = transfer_manager->GetByteSizeRequirement(subshape); TF_ASSIGN_OR_RETURN( - xla::OwningDeviceMemory buffer, + se::OwningDeviceMemory buffer, allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false)); // Move our buffer into shaped_buffer, which takes ownership of it. - index_to_buffer.second = buffer.Forget(); + index_to_buffer.second = buffer.Release(); VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque() << " index " << index_to_buffer.first.ToString(); } @@ -131,7 +135,7 @@ Status AllocateScopedShapedBuffer( XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, int device_ordinal, - xla::DeviceMemoryAllocator* allocator) + se::DeviceMemoryAllocator* allocator) : size_(allocation.size()), allocation_(allocation), device_ordinal_(device_ordinal), @@ -165,7 +169,7 @@ void XRTBufferAllocation::DiscardAllocation() { } XRTTupleAllocation::XRTTupleAllocation(int device_ordinal, - xla::DeviceMemoryAllocator* allocator, + se::DeviceMemoryAllocator* allocator, const xla::Shape& on_host_shape, const xla::Shape& on_device_shape) : device_ordinal_(device_ordinal), @@ -336,9 +340,41 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; return Status::OK(); } +/* static */ Status XRTTupleAllocation::CompactAllocations( + ResourceMgr* rm, xla::Backend* backend, int device_ordinal) { + std::vector tuples; + rm->GetContainerResources(kTupleContainer, &tuples); + + std::vector> host_tuples; + for (auto& rm_tuple : tuples) { + XRTTupleAllocation* tuple = + dynamic_cast(rm_tuple.resource.get()); + if (tuple->device_ordinal() == device_ordinal) { + xla::Literal literal(tuple->on_host_shape()); + TF_RETURN_IF_ERROR(tuple->ToLiteral(backend, device_ordinal, &literal)); + host_tuples.emplace_back(rm_tuple.name, std::move(literal)); + // At this point there are two references held onto the XRTTupleAllocation + // object. One in the ResourceMgr, which we release here, and one held + // within the tuples vector, which we release in the tuples.clear() call + // below. 
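The `output_index` fields of the XRTChainedExecuteOp messages defined above use a shifted encoding: 0 selects the whole result, and tuple element k is addressed as k + 1. Since the ExecuteChained test in this change only exercises the default value, here is a hedged sketch of selecting a sub-tuple element while building a plan; `tuple_producer_handle` and `consumer_handle` are hypothetical compilation handles, not part of this patch.

```c++
// Hedged sketch, not part of this patch: selecting a sub-tuple result inside
// an XRTChainedExecutePlan. Assumes `tuple_producer_handle` refers to a
// compiled computation that takes no arguments and returns a 2-element tuple,
// and `consumer_handle` to one that takes a single non-tuple argument.
xrt::XRTChainedExecutePlan plan;

// Op 0: run the tuple-producing computation.
xrt::XRTChainedExecuteOp* producer = plan.add_ops();
producer->set_computation_handle(tuple_producer_handle);

// Op 1: feed element 1 of op 0's tuple output into the consumer. Note the
// shifted encoding: output_index == 0 would mean "the whole tuple", so tuple
// element k is addressed as k + 1.
xrt::XRTChainedExecuteOp* consumer = plan.add_ops();
consumer->set_computation_handle(consumer_handle);
xrt::XRTChainedExecuteOp::Input* input = consumer->add_inputs();
input->set_op_index(0);
input->set_output_index(2);  // tuple element 1 == 1 + 1

// Expose op 1's full result as result 0 of the XRTExecuteChained op
// (Output.output_index == 0 means "the whole result").
xrt::XRTChainedExecuteOp::Output* out = consumer->add_outputs();
out->set_result_index(0);
```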
+ TF_RETURN_IF_ERROR( + rm->Delete(kTupleContainer, rm_tuple.name)); + } + } + tuples.clear(); + + for (auto& name_literal : host_tuples) { + XRTTupleAllocation* tuple; + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer( + name_literal.second, backend, device_ordinal, &tuple)); + TF_RETURN_IF_ERROR(rm->Create(kTupleContainer, name_literal.first, tuple)); + } + return Status::OK(); +} + /* static */ Status XRTTupleAllocation::ExpandTreeOfTuples( const xla::ShapeTree& elements, int device_ordinal, - xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape, + se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape, xla::Shape* device_shape) { // Initialize both host and device shape to be the 'spine' of the new tuple // shape, given by the shape of the tree of tuples. @@ -411,10 +447,10 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; xla::Shape subshape = xla::ShapeUtil::GetSubshape(device_shape, index); uint64 size = transfer_manager->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer, + TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer, allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false)); - VLOG(2) << "Allocated buffer at " << buffer.opaque() << " index " + VLOG(2) << "Allocated buffer at " << buffer->opaque() << " index " << index.ToString(); // Move the new buffer into new_tuple_buffers, which takes ownership // of it. @@ -498,7 +534,7 @@ bool XRTTupleAllocation::IsExclusiveOwner() { void XRTTupleAllocation::InitializeFromShapedBuffer( const xla::ShapedBuffer& shaped_buffer, - xla::DeviceMemoryAllocator* allocator, int device_ordinal) { + se::DeviceMemoryAllocator* allocator, int device_ordinal) { for (auto& buffer : buffers_) { // Make a reference-counted version of the allocated buffer. buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first), @@ -545,7 +581,7 @@ XRTTupleAllocation::ToDeviceMemoryTree( if (!release_checker(buffer.first)) { *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation(); } else { - *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory( + *shaped_tree.mutable_element(buffer.first) = se::OwningDeviceMemory( buffer.second->allocation(), device_ordinal_, allocator_); DiscardAllocation(buffer.first); } diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 6519da30d02..4d284382532 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/stream_executor.h" namespace tensorflow { @@ -45,8 +45,7 @@ namespace tensorflow { class XRTBufferAllocation : public core::RefCounted { public: XRTBufferAllocation(const se::DeviceMemoryBase& allocation, - int device_ordinal, - xla::DeviceMemoryAllocator* allocator); + int device_ordinal, se::DeviceMemoryAllocator* allocator); ~XRTBufferAllocation() override; // The region of device memory being wrapped. @@ -69,7 +68,7 @@ class XRTBufferAllocation : public core::RefCounted { uint64 size_ = 0; se::DeviceMemoryBase allocation_; int device_ordinal_; - xla::DeviceMemoryAllocator* allocator_; + se::DeviceMemoryAllocator* allocator_; }; // Entry in the resource manager corresponding to an allocation handle returned @@ -107,6 +106,11 @@ class XRTTupleAllocation : public ResourceBase { XRTTupleAllocation** allocation, bool alias_parent_allocation); + // Runs a compaction cycle which copies the device data to host, frees the + // device data, and then reallocate and send back the data. + static Status CompactAllocations(ResourceMgr* rm, xla::Backend* backend, + int device_ordinal); + // A structure describing a leaf of a tree of tuples to expand. Each leaf // contains an allocation and indicates whether or not the allocation's handle // should be freed after incorporating its buffers into the expanded tree. @@ -141,6 +145,10 @@ class XRTTupleAllocation : public ResourceBase { // manager. static Status ReleaseAllAllocations(ResourceMgr* rm); + // Returns the invalid key value, which will be never generated by the + // Intern() API. + static int64 InvalidKey() { return 0; } + // Adds the allocation to a ResourceMgr and returns the key that will be used // to retrieve it. Transfers a reference on *this to rm. Status Intern(ResourceMgr* rm, int64* key); @@ -193,14 +201,14 @@ class XRTTupleAllocation : public ResourceBase { private: // Creates a new handle with (tuple) shape. - XRTTupleAllocation(int device_ordinal, xla::DeviceMemoryAllocator* allocator, + XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator, const xla::Shape& on_host_shape, const xla::Shape& on_device_shape); // Inherits the allocations represented in buffer, which must have the same // shape as buffers_. void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer, - xla::DeviceMemoryAllocator* allocator, + se::DeviceMemoryAllocator* allocator, int device_ordinal); // Takes a tree 'elements' where each leaf is an allocation, validates that @@ -210,12 +218,12 @@ class XRTTupleAllocation : public ResourceBase { // grafted on. static Status ExpandTreeOfTuples( const xla::ShapeTree& elements, int device_ordinal, - xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape, + se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape, xla::Shape* device_shape); // Location of the memory that is being managed. int device_ordinal_; - xla::DeviceMemoryAllocator* allocator_; + se::DeviceMemoryAllocator* allocator_; // The shape that the caller thinks the tuple has. 
const xla::Shape on_host_shape_; diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc index 3ef8bedc732..518c993f390 100644 --- a/tensorflow/compiler/xrt/xrt_util.cc +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -47,6 +47,20 @@ string SafeDebugPath(const string& path) { return string(); } +Status MakeOutput(const RefPtr& output, int64 index, + RefPtr* result) { + if (index == 0) { + *result = output; + } else { + XRTTupleAllocation* tuple; + TF_RETURN_IF_ERROR( + XRTTupleAllocation::MakeSubBuffer(output.get(), {index - 1}, &tuple, + /*alias_parent_allocation=*/true)); + result->reset(tuple); + } + return Status::OK(); +} + } // namespace xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { @@ -55,22 +69,133 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { return ref_options; } xla::DebugOptions options = xla::GetDebugOptionsFromFlags(); - options.set_xla_generate_hlo_text_to( - SafeDebugPath(ref_options.xla_generate_hlo_text_to())); - options.set_xla_dump_optimized_hlo_proto_to( - SafeDebugPath(ref_options.xla_dump_optimized_hlo_proto_to())); - options.set_xla_dump_computations_to( - SafeDebugPath(ref_options.xla_dump_computations_to())); - options.set_xla_dump_executions_to( - SafeDebugPath(ref_options.xla_dump_executions_to())); + options.set_xla_dump_to(SafeDebugPath(ref_options.xla_dump_to())); + options.set_xla_dump_hlo_as_proto(ref_options.xla_dump_hlo_as_proto()); + options.set_xla_dump_hlo_as_text(ref_options.xla_dump_hlo_as_text()); + options.set_xla_dump_hlo_snapshots(ref_options.xla_dump_hlo_snapshots()); + options.set_xla_dump_hlo_pass_re(ref_options.xla_dump_hlo_pass_re()); for (auto& pass : ref_options.xla_disable_hlo_passes()) { options.add_xla_disable_hlo_passes(pass); } - options.set_xla_dump_unoptimized_hlo_proto_to( - SafeDebugPath(ref_options.xla_dump_unoptimized_hlo_proto_to())); - options.set_xla_dump_per_pass_hlo_proto_to( - SafeDebugPath(ref_options.xla_dump_per_pass_hlo_proto_to())); return options; } +xla::StatusOr> GetComputationInputs( + OpKernelContext* context, ResourceMgr* rm, const char* input_name) { + OpInputList arg_list; + TF_RETURN_IF_ERROR(context->input_list(input_name, &arg_list)); + // Concatenate all input uids from list of scalars-or-vectors carrying them. 
+ std::vector input_coords; + for (int i = 0; i < arg_list.size(); ++i) { + const Tensor& arg = arg_list[i]; + if (TensorShapeUtils::IsScalar(arg.shape())) { + input_coords.emplace_back(arg.scalar()()); + } else { + TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape())); + auto arg_vec = arg.vec(); + const int64 num_elts = arg.shape().dim_size(0); + for (int i = 0; i < num_elts; ++i) { + input_coords.emplace_back(arg_vec(i)); + } + } + } + return std::move(input_coords); +} + +Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm, + RefPtr output_tuple, + bool return_exploded_tuple) { + if (return_exploded_tuple && output_tuple->on_host_shape().IsTuple()) { + int64 tuple_element_count = + xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); + Tensor* output_tensor; + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({tuple_element_count}), &output_tensor)); + + for (int64 i = 0; i < tuple_element_count; ++i) { + XRTTupleAllocation* suballocation; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + output_tuple.get(), {i}, &suballocation, + /*alias_parent_allocation=*/false)); + int64 key; + TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); + output_tensor->vec()(i) = key; + } + } else { + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({}), &output_tensor)); + int64 key; + TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); + output_tuple.release(); + output_tensor->scalar()() = key; + } + return Status::OK(); +} + +Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm, + const xrt::XRTChainedExecutePlan& plan, + const xrt::XRTChainedExecuteConfig& config, + const ChainedExecuteFn& execute_op) { + // Create the vector which tracks the uses of the intermediate chained + // operations outputs. + std::vector uses(plan.ops_size(), 0); + for (auto& op : plan.ops()) { + for (auto& input : op.inputs()) { + uses[input.op_index()] += 1; + } + } + std::vector> ops_outputs(plan.ops_size()); + std::vector> results; + for (int i = 0; i < plan.ops_size(); ++i) { + auto& op = plan.ops(i); + if (op.op_oneof_case() == xrt::XRTChainedExecuteOp::kDataHandle) { + // This operation is a device data load. Fetch the proper + // XRTTupleAllocation behind the user handle and fill up the op output at + // the current position. + XRTTupleAllocation* tuple; + TF_RETURN_IF_ERROR( + XRTTupleAllocation::Lookup(rm, op.data_handle(), &tuple)); + ops_outputs[i].reset(tuple); + } else if (op.op_oneof_case() == + xrt::XRTChainedExecuteOp::kComputationHandle) { + // This is an XRT execute operation, forward to the device specific + // handler. + TF_ASSIGN_OR_RETURN(ops_outputs[i], execute_op(op, i, ops_outputs)); + } else { + return errors::InvalidArgument( + "Undefined operation kind at post-order position ", i); + } + // If the result of this chained operation is an output result, feed the + // results vector at the desired position. + for (auto& output : op.outputs()) { + if (output.result_index() >= results.size()) { + results.resize(output.result_index() + 1); + } + TF_RETURN_IF_ERROR(MakeOutput(ops_outputs[i], output.output_index(), + &results[output.result_index()])); + } + // Drop intermediate results which have no more users. 
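CreateExecuteOutput above either interns the whole result under a single scalar handle or, when `return_exploded_tuple` is set and the result shape is a tuple, interns one handle per first-level tuple element. A hedged client-side sketch of the exploded form follows; `root`, `session`, and `outputs` follow the raw_api_test.cc conventions, while `c_handle` (the handle output of an earlier ops::XRTCompile node for a tuple-returning computation) and `p0_handle` (an earlier ops::XRTAllocate node) are hypothetical.

```c++
// Hedged sketch, not part of this patch. Assumes `c_handle` is the
// compilation-handle Output of an earlier ops::XRTCompile whose result shape
// is a tuple, and `p0_handle` is the Output of an earlier ops::XRTAllocate.
xrt::XRTExecutionConfig e;
e.set_return_exploded_tuple(true);  // one output handle per tuple element
auto e_config =
    ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
auto result = ops::XRTExecute(root, c_handle, e_config, {Output(p0_handle)});
TF_CHECK_OK(session.Run({result}, &outputs));

// Because the result shape is a tuple and return_exploded_tuple is set,
// output_handle is a rank-1 tensor holding one allocation handle per
// first-level tuple element.
auto handles = outputs[0].vec<int64>();
for (int64 i = 0; i < handles.size(); ++i) {
  auto read = ops::XRTReadLiteralAndRelease(root, Input(handles(i)));
  TF_CHECK_OK(session.Run({read}, &outputs));
  // outputs[0] now carries the serialized xla::LiteralProto for element i.
}
```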
+ for (auto& input : op.inputs()) { + uses[input.op_index()] -= 1; + if (uses[input.op_index()] == 0) { + ops_outputs[input.op_index()].reset(); + } + } + } + + Tensor* output_tensor; + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({static_cast(results.size())}), &output_tensor)); + for (size_t i = 0; i < results.size(); ++i) { + int64 key = XRTTupleAllocation::InvalidKey(); + if (results[i] != nullptr) { + TF_RETURN_IF_ERROR(results[i]->Intern(rm, &key)); + results[i].release(); + } + output_tensor->vec()(i) = key; + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h index d9c05a7f340..07159dd5677 100644 --- a/tensorflow/compiler/xrt/xrt_util.h +++ b/tensorflow/compiler/xrt/xrt_util.h @@ -18,10 +18,106 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ #define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/compiler/xrt/xrt_state.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { +// Reference counted smart pointer for XRT objects providing the standard +// Ref()/Unref() APIs. +template +class RefPtr { + public: + RefPtr() = default; + // Creates a RefPtr from a pointer. This is an ownership transfer operation, + // and the caller has to own a valid reference to ptr (unless ptr is nullptr). + RefPtr(T* ptr) : ptr_(ptr) {} + RefPtr(const RefPtr& other) : ptr_(other.ptr_) { Acquire(ptr_); } + RefPtr(RefPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; } + + ~RefPtr() { Release(ptr_); } + + RefPtr& operator=(const RefPtr& other) { + if (this != &other) { + Acquire(other.ptr_); + Release(ptr_); + ptr_ = other.ptr_; + } + return *this; + } + + RefPtr& operator=(RefPtr&& other) { + if (this != &other) { + Release(ptr_); + ptr_ = other.ptr_; + other.ptr_ = nullptr; + } + return *this; + } + + operator bool() const { return ptr_ != nullptr; } + bool operator==(const RefPtr& rhs) const { return ptr_ == rhs.ptr_; } + bool operator!=(const RefPtr& rhs) const { return ptr_ != rhs.ptr_; } + bool operator==(const T* ptr) const { return ptr_ == ptr; } + bool operator!=(const T* ptr) const { return ptr_ != ptr; } + bool operator==(std::nullptr_t ptr) const { return ptr_ == ptr; } + bool operator!=(std::nullptr_t ptr) const { return ptr_ != ptr; } + + T* get() const { return ptr_; } + + T* operator->() const { + CHECK(ptr_ != nullptr); // Crash OK + return ptr_; + } + + T& operator*() const { + CHECK(ptr_ != nullptr); // Crash OK + return *ptr_; + } + + T* release() { + T* ptr = ptr_; + ptr_ = nullptr; + return ptr; + } + + // Resets the RefPtr from a pointer. This is an ownership transfer operation, + // and the caller has to own a valid reference to ptr (unless ptr is nullptr). 
+ void reset(T* ptr = nullptr) { + Release(ptr_); + ptr_ = ptr; + } + + private: + static void Release(T* ptr) { + if (ptr != nullptr) { + ptr->Unref(); + } + } + + static void Acquire(T* ptr) { + if (ptr != nullptr) { + ptr->Ref(); + } + } + + T* ptr_ = nullptr; +}; + +struct InputCoords { + explicit InputCoords(int64 handle) : handle(handle) {} + InputCoords(int64 handle, xla::ShapeIndex index) + : handle(handle), index(std::move(index)) {} + + int64 handle = 0; + xla::ShapeIndex index; +}; + // Filters the debug options provided as argument according to the value of the // TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable. If such variable is // set to "1" or "true", the debug options will be returned as is. Otherwise @@ -29,6 +125,29 @@ namespace tensorflow { // contained in it, will be limited to gs:// and bigstore:// ones. xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options); +// Populates the input_coords with a list of input coordinates from a input_name +// op argument. +xla::StatusOr> GetComputationInputs( + OpKernelContext* context, ResourceMgr* rm, const char* input_name); + +// Create the XRT execute output tensor given the computation result +// (output_tuple). The return_exploded_tuple tells whether a tuple result should +// be returned as vector of handles representing each tuple child. +Status CreateExecuteOutput(OpKernelContext* context, ResourceMgr* rm, + RefPtr output_tuple, + bool return_exploded_tuple); + +// Drives the XRT chained computation execution given the supplied core execute +// function. +using ChainedExecuteFn = + std::function>( + const xrt::XRTChainedExecuteOp&, int, + absl::Span>)>; +Status ExecuteChained(OpKernelContext* context, ResourceMgr* rm, + const xrt::XRTChainedExecutePlan& plan, + const xrt::XRTChainedExecuteConfig& config, + const ChainedExecuteFn& execute_op); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 0173b8bb064..6760ef265d3 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -8,7 +8,6 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//tensorflow:tensorflow.bzl", "if_not_windows") -load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") py_library( name = "contrib_py", @@ -27,7 +26,6 @@ py_library( "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/checkpoint/python:checkpoint", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", - "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/contrib/compiler:xla", "//tensorflow/contrib/autograph", @@ -59,6 +57,7 @@ py_library( "//tensorflow/contrib/labeled_tensor", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", + "//tensorflow/contrib/learn:head_test_lib", "//tensorflow/contrib/legacy_seq2seq:seq2seq_py", "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", @@ -165,9 +164,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", - "//tensorflow/contrib/coder:all_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/hadoop:dataset_kernels", + "//tensorflow/contrib/image:image_ops_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", 
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels", @@ -205,7 +204,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", - "//tensorflow/contrib/coder:all_ops", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/hadoop:dataset_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 48d5296c71c..7253ec4c9d5 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -30,7 +30,6 @@ from tensorflow.contrib import checkpoint if os.name != "nt" and platform.machine() != "s390x": from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver -from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import constrained_optimization from tensorflow.contrib import copy_graph diff --git a/tensorflow/contrib/autograph/examples/benchmarks/BUILD b/tensorflow/contrib/autograph/examples/benchmarks/BUILD index 6d2d70c99b4..651b108e239 100644 --- a/tensorflow/contrib/autograph/examples/benchmarks/BUILD +++ b/tensorflow/contrib/autograph/examples/benchmarks/BUILD @@ -17,6 +17,7 @@ py_test( name = "cartpole_benchmark", size = "enormous", srcs = ["cartpole_benchmark.py"], + python_version = "PY2", tags = [ "local", "manual", diff --git a/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py index 93c694849c4..25414fbda62 100644 --- a/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py +++ b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py @@ -47,7 +47,7 @@ class ReportingBenchmark(tf.test.Benchmark): avg_time = np.average(all_times) - extras = dict() + extras = {} extras['all_times'] = all_times if isinstance(name, tuple): diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 5174afe0a63..bc9b2b05172 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -6,31 +6,18 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") load( "//tensorflow:tensorflow.bzl", "py_test", - "tf_custom_op_library", - "tf_gen_op_libs", - "tf_gen_op_wrapper_py", - "tf_kernel_library", ) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") py_library( name = "batch_py", srcs = glob(["python/ops/*.py"]) + ["__init__.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:array_ops", + "//tensorflow/python:batch_ops", "//tensorflow/python:batch_ops_gen", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:platform", - "//tensorflow/python:script_ops", - "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 3e4d0dc1cec..2a4f3c36fbf 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,14 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.eager import function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec 
from tensorflow.python.ops import gen_batch_ops -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.python.ops.gen_batch_ops import * -# pylint: enable=wildcard-import +# pylint: disable=unused-import +from tensorflow.python.ops.batch_ops import batch +from tensorflow.python.ops.batch_ops import batch_function +from tensorflow.python.ops.batch_ops import unbatch +# pylint: enable=unused-import @ops.RegisterGradient("Batch") @@ -55,85 +54,6 @@ def _UnbatchGrad(op, grad): # pylint: disable=invalid-name ] -def batch_function(num_batch_threads, - max_batch_size, - batch_timeout_micros, - allowed_batch_sizes=None, - max_enqueued_batches=10): - """Batches the computation done by the decorated function. - - So, for example, in the following code - - ```python - @batch_function(1, 2, 3) - def layer(a): - return tf.matmul(a, a) - - b = layer(w) - ``` - - if more than one session.run call is simultaneously trying to compute `b` - the values of `w` will be gathered, non-deterministically concatenated - along the first axis, and only one thread will run the computation. See the - documentation of the `Batch` op for more details. - - Assumes that all arguments of the decorated function are Tensors which will - be batched along their first dimension. - - SparseTensor is not supported. The return value of the decorated function - must be a Tensor or a list/tuple of Tensors. - - Args: - num_batch_threads: Number of scheduling threads for processing batches - of work. Determines the number of batches processed in parallel. - max_batch_size: Batch sizes will never be bigger than this. - batch_timeout_micros: Maximum number of microseconds to wait before - outputting an incomplete batch. - allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, - does nothing. Otherwise, supplies a list of batch sizes, causing the op - to pad batches up to one of those sizes. The entries must increase - monotonically, and the final entry must equal max_batch_size. - max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. - - Returns: - The decorated function will return the unbatched computation output Tensors. 
- """ - - def decorator(fn): # pylint: disable=missing-docstring - - def decorated(*args): # pylint: disable=missing-docstring - - @function.defun() - def computation(*computation_args): - return fn(*computation_args) - - computation = computation.get_concrete_function( - *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i)) - for i, x in enumerate(args)]) - - with ops.name_scope("batch") as name: - for a in args: - if not isinstance(a, ops.Tensor): - raise ValueError("All arguments to functions decorated with " - "`batch_function` are supposed to be Tensors; " - "found %s" % repr(a)) - return gen_batch_ops.batch_function( - num_batch_threads=num_batch_threads, - max_batch_size=max_batch_size, - batch_timeout_micros=batch_timeout_micros, - allowed_batch_sizes=allowed_batch_sizes, - max_enqueued_batches=max_enqueued_batches, - shared_name=name, - f=computation, - in_tensors=list(args), - captured_tensors=computation.captured_inputs, - Tout=[o.dtype for o in computation.outputs]) - - return decorated - - return decorator - - def batch_function_v1(num_batch_threads, max_batch_size, batch_timeout_micros, diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 9109b9c1c91..e224588fa30 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -23,12 +23,8 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -41,153 +37,6 @@ def delayed_plus1(x): class BatchOpsTest(test.TestCase): """Tests for batch_ops.{un,}batch.""" - def testBasicBatch(self): - """Tests that a single batched tensor executes together and only once.""" - with self.cached_session() as sess: - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - batched, index, _ = batch_ops.batch( - [inp], num_batch_threads=1, max_batch_size=2, - batch_timeout_micros=36000000, grad_timeout_micros=0, - batching_queue="") - thread_results = [] - - def worker(): - thread_results.extend( - sess.run([batched, index], feed_dict={inp: [1]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([batched, index], feed_dict={inp: [2]}) - worker_thread.join() - - # At this point either the thread or the main did the batch and the other - # should have empty results. - if list(thread_results[0][0]): - batch_t = thread_results[0][0] - index_t = thread_results[1] - empty_b = main_results[0][0] - empty_m = main_results[1] - else: - batch_t = main_results[0][0] - index_t = main_results[1] - empty_b = thread_results[0][0] - empty_m = thread_results[1] - - # Check that both the inputs made it out exactly once. - self.assertAllEqual(sorted(batch_t), (1, 2)) - # Check that we get 2 rows in the index tensor. - self.assertEqual(len(index_t), 2) - # Check that the other ones are empty. 
- self.assertEqual(len(empty_b), 0) - self.assertEqual(len(empty_m), 0) - - def testBatchWithPadding(self): - """Test that batching with padding up to an allowed batch size works.""" - with self.cached_session() as sess: - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) - batched, index, _ = batch_ops.batch( - [inp], num_batch_threads=1, max_batch_size=10, - batch_timeout_micros=100000, # 100ms - allowed_batch_sizes=[5, 10], - grad_timeout_micros=0, batching_queue="") - thread_results = [] - - def worker(): - thread_results.extend( - sess.run([batched, index], feed_dict={inp: [1, 3]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([batched, index], feed_dict={inp: [2, 4]}) - worker_thread.join() - - # At this point either the thread or the main did the batch and the other - # should have empty results. - if list(thread_results[0][0]): - batch_t = thread_results[0][0] - else: - batch_t = main_results[0][0] - - # Check that the batch tensor incorporates the padding. - self.assertEqual(len(batch_t), 5) - - def testMultipleBatch(self): - """Tests that multiple batched tensors execute together.""" - with self.cached_session() as sess: - inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - batched, _, _ = batch_ops.batch( - [inp0, inp1], - num_batch_threads=1, - max_batch_size=2, - batch_timeout_micros=36000000, - grad_timeout_micros=0, - batching_queue="") - thread_results = [] - - def worker(): - thread_results.extend( - sess.run([batched], feed_dict={inp0: [1], - inp1: [2]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]}) - worker_thread.join() - - # At this point either the thread or the main did the batch and the other - # should have empty results. - if list(thread_results[0][0]): - batch_t = thread_results[0] - empty_t = main_results[0] - else: - batch_t = main_results[0] - empty_t = thread_results[0] - - # Assert that the tensors were batched together. 
- self.assertAllEqual(sorted(batch_t[0]), [1, 2]) - self.assertAllEqual(sorted(batch_t[1]), [2, 3]) - self.assertAllEqual(empty_t[0], []) - self.assertAllEqual(empty_t[1], []) - - def testIllegalBatchDifferentDim0Sizes(self): - """Tests illegally feeding tensors with different dim0 sizes.""" - with self.cached_session() as sess: - inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) - batched, index, _ = batch_ops.batch( - [inp0, inp1], num_batch_threads=1, max_batch_size=2, - batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="") - with self.assertRaises(Exception) as raised: - _ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]}) - self.assertGreater( - raised.exception.message.find("must have equal 0th-dimension size"), - 0) - - def testBasicUnbatch(self): - """Tests that batch and unbatch work together.""" - with self.cached_session() as sess: - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - batched, index, id_t = batch_ops.batch( - [inp], num_batch_threads=1, max_batch_size=10, - batch_timeout_micros=100000, # 100ms - allowed_batch_sizes=[3, 10], - grad_timeout_micros=0, batching_queue="") - computation = batched[0] + 1 - result = batch_ops.unbatch(computation, index, id_t, - timeout_micros=1000000, shared_name="unbatch") - thread_results = [] - - def worker(): - thread_results.extend(sess.run([result], feed_dict={inp: [1]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([result], feed_dict={inp: [2]}) - worker_thread.join() - self.assertEqual(thread_results[0], [2]) - self.assertEqual(main_results[0], [3]) - def testBasicUnbatchV1Decorated(self): """Tests that the batch_function_v1 decorator works.""" with self.cached_session() as sess: @@ -210,206 +59,6 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) - def testBasicUnbatchDecorated(self): - """Tests that the batch_function decorator works.""" - with self.cached_session() as sess: - # TODO(apassos): Removing this line causes test flakiness! Ideally should - # be investigated. 
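The removed `testBasicUnbatch` above is the clearest illustration of the low-level batch/unbatch pairing; condensed, the pattern it exercises looks like this (TF 1.x graph mode, with the numeric arguments taken from that test):

```python
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
# `batch` gathers inputs from concurrent Session.run calls into one tensor.
batched, index, id_t = batch_ops.batch(
    [inp], num_batch_threads=1, max_batch_size=10,
    batch_timeout_micros=100000,  # 100ms
    allowed_batch_sizes=[3, 10],
    grad_timeout_micros=0, batching_queue="")
# The user computation runs once on the whole batch.
computation = batched[0] + 1
# `unbatch` routes each caller's slice of the result back to that caller.
result = batch_ops.unbatch(computation, index, id_t,
                           timeout_micros=1000000, shared_name="unbatch")
```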
- default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable - - @batch_ops.batch_function(1, 10, 100000) - def computation(in_t): - self.assertTrue(in_t.shape is not None) - return in_t + 1 - - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - result = computation(inp) - thread_results = [] - - def worker(): - thread_results.extend(sess.run([result], feed_dict={inp: [1]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([result], feed_dict={inp: [2]}) - worker_thread.join() - self.assertEqual(thread_results[0], [2]) - self.assertEqual(main_results[0], [3]) - - def testBatchDecoratedWithCapturedInput(self): - """Tests that the batch_function decorator works.""" - with self.cached_session() as sess: - captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) - captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) - - @batch_ops.batch_function(1, 10, 100000) - def computation(in_t): - return in_t + captured_inp0 - captured_inp1 - - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - result = computation(inp) - thread_results = [] - - def worker(): - thread_results.extend(sess.run([result], feed_dict={inp: [1]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([result], feed_dict={inp: [2]}) - worker_thread.join() - self.assertEqual(thread_results[0], [2]) - self.assertEqual(main_results[0], [3]) - - def testBatchFunctionOp(self): - """Tests that the batch_function op works.""" - with self.cached_session() as sess: - - @function.Defun(dtypes.int32) - def computation(in_t): - return in_t + 1 - - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - result = gen_batch_ops.batch_function( - [inp], - num_batch_threads=1, - max_batch_size=10, - batch_timeout_micros=100000, - Tout=[dtypes.int32], - f=computation, - captured_tensors=computation.captured_inputs) - thread_results = [] - - def worker(): - thread_results.extend(sess.run([result], feed_dict={inp: [1]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([result], feed_dict={inp: [2]}) - worker_thread.join() - self.assertEqual(thread_results[0], [2]) - self.assertEqual(main_results[0], [3]) - - def testBatchFunctionOpWithCapturedInput(self): - """Tests that batch_function op works with captured input.""" - with self.cached_session() as sess: - captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) - captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - - @function.Defun(dtypes.int32) - def computation(inp): - return inp + captured_inp0 - captured_inp1 - - result = gen_batch_ops.batch_function( - num_batch_threads=1, - max_batch_size=10, - batch_timeout_micros=100000, # 100ms - allowed_batch_sizes=[3, 10], - batching_queue="", - f=computation, - in_tensors=[inp], - captured_tensors=computation.captured_inputs, - Tout=[o.type for o in computation.definition.signature.output_arg]) - - thread_results = [] - - def worker(): - thread_results.extend(sess.run([result], feed_dict={inp: [1]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([result], feed_dict={inp: [2]}) - worker_thread.join() - self.assertEqual(thread_results[0], [2]) - self.assertEqual(main_results[0], [3]) - - def testBatchFunctionOpWithInputError(self): - """Tests that batch_function op works with error in 
the inputs.""" - with self.cached_session() as sess: - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - - @function.Defun(dtypes.int32, dtypes.int32) - def computation(in0, in1): - return in0 + in1 - - result = gen_batch_ops.batch_function( - [inp], # computation actually expects 2 inputs. - num_batch_threads=1, - max_batch_size=10, - batch_timeout_micros=100000, # 100ms - batching_queue="", - f=computation, - captured_tensors=computation.captured_inputs, - Tout=[o.type for o in computation.definition.signature.output_arg]) - - with self.assertRaisesRegexp(InvalidArgumentError, - ".*2 arguments.*but 1.*"): - sess.run([result], feed_dict={inp: [2]}) - - def testBasicUnbatchDecoratedWithReshape(self): - """Tests that the batch_function decorator works.""" - with self.cached_session() as sess: - - @batch_ops.batch_function(1, 10, 100000) - def computation(in_t): - return array_ops.reshape(in_t, [-1]) + 1 - - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) - result = computation(inp) - thread_results = [] - - def worker(): - thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) - - worker_thread = threading.Thread(target=worker) - worker_thread.start() - main_results = sess.run([result], feed_dict={inp: [[2]]}) - worker_thread.join() - self.assertEqual(thread_results[0], [2]) - self.assertEqual(main_results[0], [3]) - - def testUnbatchTimeout(self): - """Tests that the unbatch timeout works.""" - with self.cached_session() as sess: - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) - batched, index, id_t = batch_ops.batch( - [inp], num_batch_threads=1, max_batch_size=2, - batch_timeout_micros=36000000, grad_timeout_micros=0, - batching_queue="") - computation = batched[0] + 1 - timeout_micros = 10 - result = batch_ops.unbatch(computation, index, id_t, timeout_micros, - shared_name="shared_unbatch") - # Set up a parallel pipeline that delays the computation, but uses the - # same unbatch resource object as the non-delayed pipeline. - computation_delayed = script_ops.py_func(delayed_plus1, - [batched[0]], - dtypes.int32) - result_delayed = batch_ops.unbatch(computation_delayed, - index, - id_t, - timeout_micros, - shared_name="shared_unbatch") - - thread_results = [] - def worker(): - # A first call using the non-delayed pipeline. The batcher will send an - # empty tensor along the non-delayed pipeline. - thread_results.extend(sess.run([result], feed_dict={inp: [1]})) - worker_thread = threading.Thread(target=worker) - worker_thread.start() - time.sleep(0.1) # Ensure the thread's call starts first. - # A second call using the delayed pipeline. The batcher will send the - # batched tensor along the delayed pipeline, thus delaying the arrival of - # the batched tensor at the unbatch op, relative to the empty tensor. - # - # TODO(olston, apassos): Avoid relying on the order in which the batch op - # emits the empty tensor versus the batched one. - _ = sess.run([result_delayed], feed_dict={inp: [2]}) - worker_thread.join() - # The thread's call should hit the timeout, and thus get 0 results. 
- self.assertEqual(len(thread_results), 0) - def testUnbatchGrad(self): """Tests that batch and unbatch are differentiable.""" with self.cached_session() as sess: @@ -434,6 +83,5 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [4]) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index f0637595db0..002d68111cd 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" - #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -262,32 +261,31 @@ class ToBigtableOp : public AsyncOpKernel { } components.clear(); } - grpc::Status mutation_status; + ::google::cloud::Status mutation_status; std::vector<::google::cloud::bigtable::FailedMutation> failures = - resource->table().BulkApply(std::move(mutation), mutation_status); - if (!mutation_status.ok()) { - LOG(ERROR) << "Failure applying mutation: " - << mutation_status.error_code() << " - " - << mutation_status.error_message() << " (" - << mutation_status.error_details() << ")."; - } + resource->table().BulkApply(mutation); if (!failures.empty()) { + mutation_status = failures.front().status(); + if (!mutation_status.ok()) { + LOG(ERROR) << "Failure applying mutation: " + << mutation_status.code() << " - " + << mutation_status.message() << "."; + } + ::google::bigtable::v2::MutateRowsRequest request; + mutation.MoveTo(&request); for (const auto& failure : failures) { LOG(ERROR) << "Failure applying mutation on row (" - << failure.original_index() - << "): " << failure.mutation().row_key() - << " - error: " << failure.status().error_message() - << " (Details: " << failure.status().error_details() - << ")."; + << failure.original_index() << "): " + << request.entries(failure.original_index()).row_key() + << " - error: " << failure.status().message() << "."; } } OP_REQUIRES_ASYNC( - ctx, failures.empty() && mutation_status.ok(), + ctx, failures.empty(), errors::Unknown("Failure while writing to Cloud Bigtable: ", - mutation_status.error_code(), " - ", - mutation_status.error_message(), " (", - mutation_status.error_details(), - "), # of mutation failures: ", failures.size(), + mutation_status.code(), " - ", + mutation_status.message(), + "; # of mutation failures: ", failures.size(), ". See the log for the specific error details."), done); } while (!end_of_sequence); diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc index a81c3ec5c23..0bdaf3ae0bd 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -16,20 +16,55 @@ limitations under the License. 
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" namespace tensorflow { +namespace { +::tensorflow::error::Code GcpErrorCodeToTfErrorCode( + ::google::cloud::StatusCode code) { + switch (code) { + case ::google::cloud::StatusCode::kOk: + return ::tensorflow::error::OK; + case ::google::cloud::StatusCode::kCancelled: + return ::tensorflow::error::CANCELLED; + case ::google::cloud::StatusCode::kUnknown: + return ::tensorflow::error::UNKNOWN; + case ::google::cloud::StatusCode::kInvalidArgument: + return ::tensorflow::error::INVALID_ARGUMENT; + case ::google::cloud::StatusCode::kDeadlineExceeded: + return ::tensorflow::error::DEADLINE_EXCEEDED; + case ::google::cloud::StatusCode::kNotFound: + return ::tensorflow::error::NOT_FOUND; + case ::google::cloud::StatusCode::kAlreadyExists: + return ::tensorflow::error::ALREADY_EXISTS; + case ::google::cloud::StatusCode::kPermissionDenied: + return ::tensorflow::error::PERMISSION_DENIED; + case ::google::cloud::StatusCode::kUnauthenticated: + return ::tensorflow::error::UNAUTHENTICATED; + case ::google::cloud::StatusCode::kResourceExhausted: + return ::tensorflow::error::RESOURCE_EXHAUSTED; + case ::google::cloud::StatusCode::kFailedPrecondition: + return ::tensorflow::error::FAILED_PRECONDITION; + case ::google::cloud::StatusCode::kAborted: + return ::tensorflow::error::ABORTED; + case ::google::cloud::StatusCode::kOutOfRange: + return ::tensorflow::error::OUT_OF_RANGE; + case ::google::cloud::StatusCode::kUnimplemented: + return ::tensorflow::error::UNIMPLEMENTED; + case ::google::cloud::StatusCode::kInternal: + return ::tensorflow::error::INTERNAL; + case ::google::cloud::StatusCode::kUnavailable: + return ::tensorflow::error::UNAVAILABLE; + case ::google::cloud::StatusCode::kDataLoss: + return ::tensorflow::error::DATA_LOSS; + } +} +} // namespace -Status GrpcStatusToTfStatus(const ::grpc::Status& status) { +Status GcpStatusToTfStatus(const ::google::cloud::Status& status) { if (status.ok()) { return Status::OK(); } - auto grpc_code = status.error_code(); - if (status.error_code() == ::grpc::StatusCode::ABORTED || - status.error_code() == ::grpc::StatusCode::UNAVAILABLE || - status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) { - grpc_code = ::grpc::StatusCode::INTERNAL; - } - return Status(static_cast<::tensorflow::error::Code>(grpc_code), - strings::StrCat("Error reading from Cloud Bigtable: ", - status.error_message())); + return Status( + GcpErrorCodeToTfErrorCode(status.code()), + strings::StrCat("Error reading from Cloud Bigtable: ", status.message())); } string RegexFromStringSet(const std::vector& strs) { diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index e3b4535bac4..1325560e772 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -16,16 +16,14 @@ limitations under the License. 
#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ #define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ -// Note: we use bigtable/client/internal/table.h as this is the no-exception API - #include "google/cloud/bigtable/data_client.h" -#include "google/cloud/bigtable/internal/table.h" +#include "google/cloud/bigtable/table.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/resource_mgr.h" namespace tensorflow { -Status GrpcStatusToTfStatus(const ::grpc::Status& status); +Status GcpStatusToTfStatus(const ::google::cloud::Status& status); string RegexFromStringSet(const std::vector& strs); @@ -65,7 +63,7 @@ class BigtableTableResource : public ResourceBase { ~BigtableTableResource() override { client_->Unref(); } - ::google::cloud::bigtable::noex::Table& table() { return table_; } + ::google::cloud::bigtable::Table& table() { return table_; } string DebugString() const override { return strings::StrCat( @@ -76,7 +74,7 @@ class BigtableTableResource : public ResourceBase { private: BigtableClientResource* client_; // Ownes one ref. const string table_name_; - ::google::cloud::bigtable::noex::Table table_; + ::google::cloud::bigtable::Table table_; }; namespace data { @@ -89,22 +87,21 @@ class BigtableReaderDatasetIterator : public DatasetIterator { public: explicit BigtableReaderDatasetIterator( const typename DatasetIterator::Params& params) - : DatasetIterator(params), iterator_(nullptr, false) {} + : DatasetIterator(params) {} Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(EnsureIteratorInitialized()); if (iterator_ == reader_->end()) { - grpc::Status status = reader_->Finish(); - if (status.ok()) { - *end_of_sequence = true; - return Status::OK(); - } - return GrpcStatusToTfStatus(status); + *end_of_sequence = true; + return Status::OK(); + } + if (!*iterator_) { + return GcpStatusToTfStatus(iterator_->status()); } *end_of_sequence = false; - google::cloud::bigtable::Row& row = *iterator_; + google::cloud::bigtable::Row& row = **iterator_; Status s = ParseRow(ctx, row, out_tensors); // Ensure we always advance. ++iterator_; diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index 2c6317157d2..98ec991a934 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -152,18 +152,19 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { } if (input_tensors[0].NumElements() == 1) { // Single key lookup. - ::grpc::Status status; - auto pair = dataset()->table_->table().ReadRow( - input_tensors[0].scalar()(), dataset()->filter_, status); - if (!status.ok()) { - return GrpcStatusToTfStatus(status); + ::google::cloud::StatusOr< + std::pair> + row = dataset()->table_->table().ReadRow( + input_tensors[0].scalar()(), dataset()->filter_); + if (!row.ok()) { + return GcpStatusToTfStatus(row.status()); } - if (!pair.first) { + if (!row->first) { return errors::DataLoss("Row key '", input_tensors[0].scalar()(), "' not found."); } - TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors)); + TF_RETURN_IF_ERROR(ParseRow(ctx, row->second, out_tensors)); } else { // Batched get. 
return errors::Unimplemented( diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h index 44c628e366c..d6d00476612 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_ #include +#include #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index f0c3ef4e2ec..88284c5a4e9 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -125,15 +125,15 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { // ensure we don't accidentally miss any subsets of the requested range by // including `begin_key()` and `end_key()` as appropriate. Status Initialize(IteratorContext* ctx) override { - grpc::Status status; - std::vector row_keys = - dataset()->table().table().SampleRows(status); - if (!status.ok()) { - return GrpcStatusToTfStatus(status); + ::google::cloud::StatusOr< + std::vector<::google::cloud::bigtable::RowKeySample>> + row_key_samples = dataset()->table().table().SampleRows(); + if (!row_key_samples.ok()) { + return GcpStatusToTfStatus(row_key_samples.status()); } - for (size_t i = 0; i < row_keys.size(); ++i) { - string row_key(row_keys[i].row_key); + for (const auto& row_key_sample : *row_key_samples) { + string row_key(row_key_sample.row_key); if (dataset()->key_range_.contains_key(row_key)) { // First key: check to see if we need to add the begin_key. 
if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) { diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index 9b60e0a6672..119da35973a 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -80,12 +80,14 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { - ::grpc::Status status; - row_keys_ = dataset()->table()->table().SampleRows(status); - if (!status.ok()) { + ::google::cloud::StatusOr< + std::vector<::google::cloud::bigtable::RowKeySample>> + sampled_rows = dataset()->table()->table().SampleRows(); + if (!sampled_rows.ok()) { row_keys_.clear(); - return GrpcStatusToTfStatus(status); + return GcpStatusToTfStatus(sampled_rows.status()); } + row_keys_ = std::move(*sampled_rows); return Status::OK(); } diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index d9fce6e09f4..4b688f2d22f 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -432,6 +432,17 @@ BigtableTestClient::AsyncReadRows( return nullptr; } +std::unique_ptr< + grpc::ClientAsyncReaderInterface> +BigtableTestClient::PrepareAsyncMutateRows( + grpc::ClientContext* context, + const google::bigtable::v2::MutateRowsRequest& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index 63d59b32dd1..299494b7180 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -100,6 +100,12 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { const google::bigtable::v2::ReadRowsRequest& request, grpc::CompletionQueue* cq, void* tag) override; + std::unique_ptr> + PrepareAsyncMutateRows(grpc::ClientContext* context, + const google::bigtable::v2::MutateRowsRequest& request, + grpc::CompletionQueue* cq) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc index 32611e2590d..cf6e619bfaf 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc @@ -39,7 +39,6 @@ TEST(BigtableTestClientTest, EmptyRowRead) { ::google::cloud::bigtable::Filter::Latest(1)); auto rows = table.ReadRows(std::move(rowset), filter); EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!"; - EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows."; } TEST(BigtableTestClientTest, SingleRowWriteAndRead) { @@ -55,15 +54,15 @@ TEST(BigtableTestClientTest, SingleRowWriteAndRead) { auto rows = 
table.ReadRows(std::move(rowset), filter); auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "No rows were returned in response!"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_EQ(itr, rows.end()); - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) { @@ -82,15 +81,15 @@ TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) { auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, MultiRowWriteAndRead) { @@ -109,33 +108,35 @@ TEST(BigtableTestClientTest, MultiRowWriteAndRead) { auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r2"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v2"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r2"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v2"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r3"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v3"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r3"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v3"); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - 
EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { @@ -154,33 +155,35 @@ TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r2"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v2"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r2"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v2"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r3"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v3"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r3"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v3"); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, ColumnFiltering) { @@ -206,33 +209,35 @@ TEST(BigtableTestClientTest, ColumnFiltering) { auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v1"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v1"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r2"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v2"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r2"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v2"); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r3"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - 
EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), "v3"); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r3"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), "v3"); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, RowKeys) { @@ -257,33 +262,35 @@ TEST(BigtableTestClientTest, RowKeys) { table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); auto itr = rows.begin(); EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r1"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), ""); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r1"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), ""); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r2"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), ""); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r2"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), ""); ++itr; EXPECT_NE(itr, rows.end()) << "Missing rows"; - EXPECT_EQ(itr->row_key(), "r3"); - EXPECT_EQ(itr->cells().size(), 1); - EXPECT_EQ(itr->cells()[0].family_name(), "f1"); - EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); - EXPECT_EQ(itr->cells()[0].value(), ""); + EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message(); + EXPECT_EQ((*itr)->row_key(), "r3"); + EXPECT_EQ((*itr)->cells().size(), 1); + EXPECT_EQ((*itr)->cells()[0].family_name(), "f1"); + EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ((*itr)->cells()[0].value(), ""); ++itr; EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; - EXPECT_TRUE(rows.Finish().ok()); } TEST(BigtableTestClientTest, SampleKeys) { diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 736cf3da49e..aa476281c90 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -221,7 +221,7 @@ class BigtableTable(object): A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all of the row keys matching that prefix. """ - return dataset_ops.DatasetV1Adapter(_BigtablePrefixKeyDataset(self, prefix)) + return _BigtablePrefixKeyDataset(self, prefix) def sample_keys(self): """Retrieves a sampling of row keys from the Bigtable table. @@ -233,7 +233,7 @@ class BigtableTable(object): Returns: A `tf.data.Dataset` returning string row keys. 
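With the `DatasetV1Adapter` wrappers removed in these hunks, the `BigtableTable` methods hand back the underlying datasets directly. A rough usage sketch; the project, instance, and table names are placeholders, and the client construction assumes the existing `tf.contrib.bigtable.BigtableClient` API rather than anything introduced by this change:

```python
import tensorflow as tf

# Placeholder names; a reachable Cloud Bigtable instance is assumed.
client = tf.contrib.bigtable.BigtableClient("my-project", "my-instance")
table = client.table("my-table")

keys_ds = table.sample_keys()    # now returned directly, no DatasetV1Adapter
keys_ds = keys_ds.shuffle(16)    # ordinary tf.data transformations still compose
itr = tf.compat.v1.data.make_one_shot_iterator(keys_ds)
next_key = itr.get_next()
```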
""" - return dataset_ops.DatasetV1Adapter(_BigtableSampleKeysDataset(self)) + return _BigtableSampleKeysDataset(self) def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): """Retrieves row (including values) from the Bigtable service. @@ -278,8 +278,7 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return dataset_ops.DatasetV1Adapter( - _BigtableScanDataset(self, prefix, "", "", normalized, probability)) + return _BigtableScanDataset(self, prefix, "", "", normalized, probability) def scan_range(self, start, end, probability=None, columns=None, **kwargs): """Retrieves rows (including values) from the Bigtable service. @@ -324,8 +323,7 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return dataset_ops.DatasetV1Adapter( - _BigtableScanDataset(self, "", start, end, normalized, probability)) + return _BigtableScanDataset(self, "", start, end, normalized, probability) def parallel_scan_prefix(self, prefix, @@ -381,8 +379,7 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = dataset_ops.DatasetV1Adapter( - _BigtableSampleKeyPairsDataset(self, prefix, "", "")) + ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "") return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -444,8 +441,7 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = dataset_ops.DatasetV1Adapter( - _BigtableSampleKeyPairsDataset(self, "", start, end)) + ds = _BigtableSampleKeyPairsDataset(self, "", start, end) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index f7f15a302a0..6791e379107 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -1,4 +1,5 @@ # TensorFlow code for training gradient boosted trees. 
+ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) @@ -71,6 +72,7 @@ py_test( name = "losses_test", size = "small", srcs = ["python/utils/losses_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":losses", @@ -121,6 +123,7 @@ py_test( name = "gbdt_batch_test", size = "medium", srcs = ["python/training/functions/gbdt_batch_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "notsan", # b/62863147 @@ -150,6 +153,7 @@ py_test( name = "model_ops_test", size = "small", srcs = ["python/kernel_tests/model_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":model_ops_py", @@ -170,6 +174,7 @@ py_test( name = "prediction_ops_test", size = "small", srcs = ["python/kernel_tests/prediction_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":model_ops_py", @@ -187,6 +192,7 @@ py_test( name = "quantile_ops_test", size = "small", srcs = ["python/kernel_tests/quantile_ops_test.py"], + python_version = "PY2", shard_count = 3, srcs_version = "PY2AND3", deps = [ @@ -209,6 +215,7 @@ py_test( name = "split_handler_ops_test", size = "small", srcs = ["python/kernel_tests/split_handler_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":split_handler_ops_py", @@ -225,6 +232,7 @@ py_test( name = "stats_accumulator_ops_test", size = "small", srcs = ["python/kernel_tests/stats_accumulator_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":stats_accumulator_ops_py", @@ -239,6 +247,7 @@ py_test( name = "training_ops_test", size = "small", srcs = ["python/kernel_tests/training_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":model_ops_py", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 64e4c4560ba..968aff18053 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -1,5 +1,6 @@ # This directory contains estimators to train and run inference on # gradient boosted trees on top of TensorFlow. + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) @@ -69,6 +70,7 @@ py_test( name = "trainer_hooks_test", size = "small", srcs = ["trainer_hooks_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":trainer_hooks", @@ -118,6 +120,7 @@ py_test( name = "custom_export_strategy_test", size = "small", srcs = ["custom_export_strategy_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":custom_export_strategy", @@ -176,6 +179,7 @@ py_test( size = "medium", timeout = "long", srcs = ["dnn_tree_combined_estimator_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_gpu", @@ -195,6 +199,7 @@ py_test( name = "estimator_test", size = "medium", srcs = ["estimator_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py index 48a7f85eada..c4f94a6554a 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py @@ -44,7 +44,7 @@ def _export_outputs_to_output_alternatives(export_outputs): Returns: converted output_alternatives. 
""" - output = dict() + output = {} if export_outputs is not None: for key, value in export_outputs.items(): if isinstance(value, export_output.ClassificationOutput): diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index eecf3c5aeb6..07fa4ca684b 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -186,6 +186,10 @@ def model_builder(features, train_op_fn=_train_op_fn, logits=logits) + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + estimator_spec.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + estimator_spec = estimator_spec._replace( training_hooks=training_hooks + list(estimator_spec.training_hooks)) return estimator_spec diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index fd832de982a..634dfab1090 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -1,5 +1,6 @@ # Description: # This directory contains common utilities used in boosted_trees. + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) @@ -284,6 +285,7 @@ py_library( py_test( name = "categorical_split_handler_test", srcs = ["learner/batch/categorical_split_handler_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":categorical_split_handler", @@ -324,6 +326,7 @@ py_library( py_test( name = "ordinal_split_handler_test", srcs = ["learner/batch/ordinal_split_handler_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":ordinal_split_handler", diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index d26af584197..22ad181fc3f 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -193,7 +193,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): num_minibatches = control_flow_ops.cond( ops.convert_to_tensor(self._loss_uses_sum_reduction), - lambda: math_ops.to_int64(1), lambda: num_minibatches) + lambda: math_ops.cast(1, dtypes.int64), + lambda: num_minibatches) partition_ids, gains, split_infos = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=num_minibatches, diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 0476bed2cd3..0e6a9f8f3a0 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -312,9 +312,10 @@ def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, gen_stats_accumulator_ops.stats_accumulator_scalar_flush( stats_accumulator_handle, stamp_token, next_stamp_token)) # For sum_reduction, we don't need to divide by number of minibatches. - num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, - lambda: math_ops.to_int64(1), - lambda: num_minibatches) + num_minibatches = control_flow_ops.cond( + loss_uses_sum_reduction, + lambda: math_ops.cast(1, dtypes.int64), + lambda: num_minibatches) # Put quantile and stats accumulator flushing in the dependency path. 
with ops.control_dependencies([flush_quantiles, partition_ids]): are_splits_ready = array_ops.identity(are_splits_ready) @@ -488,9 +489,10 @@ def _make_sparse_split( num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( gen_stats_accumulator_ops.stats_accumulator_scalar_flush( stats_accumulator_handle, stamp_token, next_stamp_token)) - num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, - lambda: math_ops.to_int64(1), - lambda: num_minibatches) + num_minibatches = control_flow_ops.cond( + loss_uses_sum_reduction, + lambda: math_ops.cast(1, dtypes.int64), + lambda: num_minibatches) # Put quantile and stats accumulator flushing in the dependency path. with ops.control_dependencies([flush_quantiles, partition_ids]): are_splits_ready = array_ops.identity(are_splits_ready) diff --git a/tensorflow/contrib/boosted_trees/lib/utils/random.h b/tensorflow/contrib/boosted_trees/lib/utils/random.h index 249651e99ed..f0eaef24cbb 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/random.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/random.h @@ -15,6 +15,8 @@ #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_RANDOM_H_ #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_RANDOM_H_ +#include + #include "tensorflow/core/lib/random/simple_philox.h" namespace tensorflow { @@ -24,7 +26,7 @@ namespace utils { // Generates a poisson distributed number with mean 1 for use in bootstrapping. inline int32 PoissonBootstrap(random::SimplePhilox* rng) { // Knuth, special cased for lambda = 1.0 for efficiency. - static const float lbound = exp(-1.0f); + static const float lbound = std::exp(-1.0f); int32 n = 0; for (float r = 1; r > lbound; r *= rng->RandFloat()) { ++n; diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index e78ec476ab3..4a13da4b5be 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -228,7 +228,7 @@ def extract_features(features, feature_columns, use_core_columns): indices = array_ops.concat([ array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]), array_ops.expand_dims( - math_ops.to_int64(categorical_tensor.values), -1) + math_ops.cast(categorical_tensor.values, dtypes.int64), -1) ], 1) tensor = sparse_tensor.SparseTensor( indices=indices, values=weight_tensor.values, dense_shape=shape) @@ -590,10 +590,14 @@ class GradientBoostedDecisionTreeModel(object): stamp_token=ensemble_stamp, tree_ensemble_config=serialized_model), ensemble_stamp - refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond( - math_ops.not_equal(ensemble_stamp, - local_stamp), _refresh_local_ensemble_fn, - lambda: (control_flow_ops.no_op(), ensemble_stamp)) + with ops.device(local_ensemble_handle.device): + # Need to colocate stamps for cond. + colocated_ensemble_stamp = array_ops.identity(ensemble_stamp) + + refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond( + math_ops.not_equal(colocated_ensemble_stamp, + local_stamp), _refresh_local_ensemble_fn, + lambda: (control_flow_ops.no_op(), colocated_ensemble_stamp)) # Once updated, use the local model for prediction. with ops.control_dependencies([refresh_local_ensemble]): @@ -611,8 +615,9 @@ class GradientBoostedDecisionTreeModel(object): learner_pb2.LearnerConfig.TREE_PER_CLASS and self._logits_dimension != 1): # Choose the class for which the tree is built (one vs rest). 
- return math_ops.to_int32( - predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) + return math_ops.cast( + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension, + dtypes.int32) return constant_op.constant(-1, dtype=dtypes.int32) def update_stats(self, loss, predictions_dict, gradients=None, hessians=None): diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 7e45d0b2cec..728b764898a 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -1149,9 +1149,9 @@ class GbdtTest(test_util.TensorFlowTestCase): expected_leaf_1 = [-3.4480, -3.4429, 13.8490, -3.45, -3.4508] expected_leaf_2 = [-1.2547, -1.3145, 1.52, 2.3875, -1.3264] self.assertArrayNear(expected_leaf_1, - output.trees[0].nodes[1].leaf.vector.value, 1e-3) + output.trees[0].nodes[1].leaf.vector.value, 7e-3) self.assertArrayNear(expected_leaf_2, - output.trees[0].nodes[2].leaf.vector.value, 1e-3) + output.trees[0].nodes[2].leaf.vector.value, 7e-3) def testTrainFnMulticlassDiagonalHessian(self): """Tests the GBDT train for multiclass diagonal hessian.""" diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 1ad40aca288..40fdfcf45ac 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -43,7 +44,7 @@ def per_example_logistic_loss(labels, weights, predictions): loss: A Rank 2 (N, 1) tensor of per-example logistic loss. update_op: An update operation to update the loss's internal state. """ - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) unweighted_loss = nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() @@ -74,7 +75,7 @@ def per_example_quantile_regression_loss(labels, weights, predictions, loss: A Rank 2 (N, 1) tensor of per-example quantile loss. update_op: An update operation to update the loss's internal state. """ - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) error = labels - predictions square_loss_right = array_ops.where(error * quantile < 1.0, math_ops.square(quantile * error), @@ -112,7 +113,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): loss: A Rank 2 (N, 1) tensor of per-example maxent loss update_op: An update operation to update the loss's internal state. """ - labels = math_ops.to_int64(labels) + labels = math_ops.cast(labels, dtypes.int64) # If labels are of rank 1, make them rank 2. labels_shape = labels.get_shape() if len(labels_shape) != 2: @@ -120,7 +121,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): # Labels are indices of classes, convert them to one hot encodings. 
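As a concrete illustration of the label handling in `per_example_maxent_loss` at this point (rank-2 integer labels, one-hot expansion, then collapsing the middle axis), with made-up values:

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

labels = constant_op.constant([[2], [0]], dtype=dtypes.int64)   # shape (N, 1)
target_one_hot = array_ops.one_hot(indices=labels, depth=3)     # shape (N, 1, 3)
labels_dense = math_ops.reduce_sum(input_tensor=target_one_hot, axis=[1])  # shape (N, 3)
labels_float = math_ops.cast(labels_dense, dtypes.float32)      # was: math_ops.to_float
# labels_float == [[0., 0., 1.], [1., 0., 0.]]
```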
target_one_hot = array_ops.one_hot(indices=labels, depth=num_classes) labels = math_ops.reduce_sum(input_tensor=target_one_hot, axis=[1]) - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) # Calculate softmax probabilities for each class. unnormalized_probs = math_ops.exp(logits) @@ -253,7 +254,7 @@ def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1): preds_converted = min_res return math_ops.exp(-preds_converted * labels_converted) - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) unweighted_loss = exp_with_logits( name=name, eps=eps, labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() @@ -312,7 +313,7 @@ def per_example_full_exp_loss(labels, weights, predictions, name=None): return math_ops.exp(-1.0 * logits * labels_converted) - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) unweighted_loss = full_exp_with_logits( name=name, labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index cd9c94c9bd7..caedf5b2d1d 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -120,5 +120,4 @@ tf_py_test( "//tensorflow/python/keras:layers", "//tensorflow/python/training/tracking:util", ], - tags = ["no_oss"], # b/124472244 ) diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 1ada05227ba..e9618b972d9 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -108,8 +108,8 @@ class NumpyState(base.Trackable): except AttributeError: value = _NumpyWrapper(value) self._track_trackable(value, name=name, overwrite=True) - elif (name not in ("_setattr_tracking", "_update_uid") - and getattr(self, "_setattr_tracking", True)): + elif (name not in ("_self_setattr_tracking", "_self_update_uid") + and getattr(self, "_self_setattr_tracking", True)): # Mixing restore()-created attributes with user-added trackable # objects is tricky, since we can't use the `_lookup_dependency` trick to # re-create attributes (we might accidentally steal the restoration for @@ -154,4 +154,3 @@ class _NumpyWrapper(core_python_state.PythonState): self.array = numpy.load(string_file, allow_pickle=False) finally: string_file.close() - diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index d7b02b53890..aaabe4e3f57 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -27,13 +27,14 @@ from tensorflow.python.training.tracking import base as trackable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): """Wraps save and restore callbacks as a `SaveableObject`.""" - def __init__(self, name, dtype, save_callback, restore_callback): + def __init__(self, name, dtype, device, save_callback, restore_callback): self._restore_callback = restore_callback spec = saver_lib.BaseSaverBuilder.SaveSpec( tensor=save_callback, slice_spec="", name=name, - dtype=dtype) + dtype=dtype, + device=device) super(_CallbackSaveable, self).__init__( save_callback, [spec], name) @@ -46,12 +47,13 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): class 
_SplitDependency(trackable.Trackable): """Looks like a regular variable while synchronizing save/restores.""" - def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, - fill_save_buffer_fn, consume_restore_buffer_fn): + def __init__(self, save_buffer, restore_buffer, name, dtype, device, + num_components, fill_save_buffer_fn, consume_restore_buffer_fn): self._save_buffer = save_buffer self._restore_buffer = restore_buffer self._name = name self._dtype = dtype + self._device = device self._num_components = num_components self._fill_save_buffer_fn = fill_save_buffer_fn self._consume_restore_buffer_fn = consume_restore_buffer_fn @@ -86,13 +88,15 @@ class _SplitDependency(trackable.Trackable): trackable.VARIABLE_VALUE_KEY: functools.partial(_CallbackSaveable, dtype=self._dtype, + device=self._device, save_callback=self._save, restore_callback=self._restore) } def split_dependency(component_names, component_dtypes, - fill_save_buffer_fn, consume_restore_buffer_fn): + fill_save_buffer_fn, consume_restore_buffer_fn, + device): """Creates multiple dependencies with a synchronized save/restore. Useful when a single op produces `Tensor`s which should each be saved under @@ -115,6 +119,7 @@ def split_dependency(component_names, component_dtypes, `component_names` as keys mapping to restored individual `Tensor`s and returns a restore op (or if executing eagerly, runs the restoration and may return `None`). + device: The device on which to run save and restore operations. Returns: A dictionary mapping from names to Trackable objects. If one is @@ -130,6 +135,7 @@ def split_dependency(component_names, component_dtypes, restore_buffer=restore_buffer, name=name, dtype=dtype, + device=device, num_components=len(component_names), fill_save_buffer_fn=fill_save_buffer_fn, consume_restore_buffer_fn=consume_restore_buffer_fn) diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 9bc01059481..8660cc12f28 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -54,7 +54,8 @@ class SaveTensorSlicesAsDeps(base.Trackable): fill_save_buffer_fn=_split_variable_closure( self.combined), consume_restore_buffer_fn=_combine_variable_closure( - self.combined)) + self.combined), + device=self.combined.device) for name, dep in split_dependencies.items(): self._track_trackable(dep, name=name) diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py index 76c6bc05ff7..6606213cbfc 100644 --- a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py +++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops.py @@ -35,9 +35,9 @@ class BigQueryReader(io_ops.ReaderBase): # Create the parse_examples list of features. features = dict( - name=tf.FixedLenFeature([1], tf.string), - age=tf.FixedLenFeature([1], tf.int32), - state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK")) + name=tf.io.FixedLenFeature([1], tf.string), + age=tf.io.FixedLenFeature([1], tf.int32), + state=tf.io.FixedLenFeature([1], dtype=tf.string, default_value="UNK")) # Create a Reader. reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT, @@ -48,11 +48,11 @@ class BigQueryReader(io_ops.ReaderBase): features=features) # Populate a queue with the BigQuery Table partitions. 
- queue = tf.train.string_input_producer(reader.partitions()) + queue = tf.compat.v1.train.string_input_producer(reader.partitions()) # Read and parse examples. row_id, examples_serialized = reader.read(queue) - examples = tf.parse_example(examples_serialized, features=features) + examples = tf.io.parse_example(examples_serialized, features=features) # Process the Tensors examples["name"], examples["age"], etc... ``` diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py index cb45e427342..806d5630464 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -63,7 +63,7 @@ class ConfigureGcsHook(training.SessionRunHook): Example: ``` - sess = tf.Session() + sess = tf.compat.v1.Session() refresh_token = raw_input("Refresh token: ") client_secret = raw_input("Client secret: ") client_id = "" @@ -154,8 +154,8 @@ def configure_gcs(session, credentials=None, block_cache=None, device=None): at https://cloud.google.com/security/encryption-in-transit/. Args: - session: A `tf.Session` session that should be used to configure the GCS - file system. + session: A `tf.compat.v1.Session` session that should be used to configure + the GCS file system. credentials: [Optional.] A JSON string block_cache: [Optional.] A BlockCacheParams to configure the block cache . device: [Optional.] The device to place the configure ops. @@ -186,7 +186,7 @@ def configure_colab_session(session): """ConfigureColabSession configures the GCS file system in Colab. Args: - session: A `tf.Session` session. + session: A `tf.compat.v1.Session` session. """ # Read from the application default credentials (adc). adc_filename = os.environ.get( diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 2ad9ae42a16..44431d5010d 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -522,6 +522,7 @@ if (tensorflow_ENABLE_GPU) "#define CUDA_CUDA_CONFIG_H_\n" "#define TF_CUDA_CAPABILITIES ${TF_CUDA_CAP}\n" "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" + "#define TF_CUDA_LIB_VERSION \"64_${short_CUDA_VER}\"\n" "#define TF_CUDNN_VERSION \"64_${CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" @@ -615,4 +616,4 @@ if(tensorflow_BUILD_SHARED_LIB) endif() if(tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS) include(tf_tests.cmake) -endif() \ No newline at end of file +endif() diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 60ee1b4b3fd..9e9d85def83 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -294,11 +294,12 @@ Here we assume that you have basic knowledge on gathering dependency with `CMakeLists.txt` and the c++ file `main.cxx` 2. Fill in the `main.cxx` with the code provided in [official c++ api basic](https://www.tensorflow.org/api_guides/cc/guide). -3. Fill in the `CMakeLists.txt` with following code: ``` cmake +3. 
Fill in the `CMakeLists.txt` with following code: + + ```cmake cmake_minimum_required (VERSION 2.6) project (tf_hello) # Tensorflow - find_package(Tensorflow REQUIRED) include_directories(${TENSORFLOW_INCLUDE_DIRS}) @@ -314,7 +315,8 @@ Here we assume that you have basic knowledge on gathering dependency with this CMakeList.txt, under development") endif() add_executable(tf_hello main.cxx) target_link_libraries(tf_hello - ${TENSORFLOW_LIBRARIES}) ``` + ${TENSORFLOW_LIBRARIES}) + ``` 4. Configure the folder with cmake-gui, an error should be prompted out, requesting you to locate the folder containing `TensorflowConfig.cmake`. diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index 6c6a5df7f76..53ad3648d61 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -48,7 +48,7 @@ else (systemlib_ABSEIL_CPP) set(abseil_cpp_STATIC_LIBRARIES ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_malloc_internal.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_throw_delegate.lib ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib @@ -104,4 +104,4 @@ else (systemlib_ABSEIL_CPP) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) -endif (systemlib_ABSEIL_CPP) \ No newline at end of file +endif (systemlib_ABSEIL_CPP) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 30b4e2dbdee..22f4ea89e3c 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 62688b6a05cc85b47fb77dd408611734253e47e2) +set(GRPC_TAG 4566c2a29ebec0835643b972eb99f4306c4234a3) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 32e6d78e508..174f7d1d47f 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -16,8 +16,8 @@ include (ExternalProject) include (GNUInstallDirs) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) -set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) -set(png_HASH SHA256=e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef) +set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.35.tar.gz) +set(png_HASH SHA256=6d59d6a154ccbb772ec11772cb8f8beb0d382b61e7ccc62435bf7311c9f4b210) set(png_BUILD ${CMAKE_BINARY_DIR}/png/src/png) set(png_INSTALL ${CMAKE_BINARY_DIR}/png/install) diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 7f835d2d519..ef9226a9388 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -15,8 +15,8 @@ include (ExternalProject) set(sqlite_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/sqlite) -set(sqlite_URL 
https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip) -set(sqlite_HASH SHA256=4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc) +set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2019/sqlite-amalgamation-3280000.zip) +set(sqlite_HASH SHA256=d02fc4e95cfef672b45052e221617a050b7f2e20103661cda88387349a9b1327) set(sqlite_BUILD ${CMAKE_CURRENT_BINARY_DIR}/sqlite/src/sqlite) set(sqlite_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/sqlite/install) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index fd205a4b9b0..95e50728773 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -119,11 +119,6 @@ tensorflow/contrib/cloud/python/ops tensorflow/contrib/cluster_resolver tensorflow/contrib/cluster_resolver/python tensorflow/contrib/cluster_resolver/python/training -tensorflow/contrib/coder -tensorflow/contrib/coder/kernels -tensorflow/contrib/coder/ops -tensorflow/contrib/coder/python -tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization tensorflow/contrib/constrained_optimization/python diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index cc263d7995c..24e45236a63 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -274,10 +274,9 @@ if (NOT WIN32) COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} DEPENDS __force_rebuild) + set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) endif() -set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) - ######################################################## # tf_core_framework library ######################################################## diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index d66e39ac07c..e8972098c7e 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -63,11 +63,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder.cc" - "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" - "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 310eed4ecbf..f73f89ce379 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -92,7 +92,6 @@ 
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_training "${tensorflow_source_dir}/ten GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 1fe8795ddf0..c0ab327ac94 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -378,8 +378,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_quantiles_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_quantile_ops.py) GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_coder_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index ed31351d9ea..af4483f3cb0 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -163,7 +163,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/profiler/internal/*_test.py" "${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py" "${tensorflow_source_dir}/tensorflow/python/training/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/coder/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/feature_column/python/feature_column/*_test.py" diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD deleted file mode 100644 index 7f96a103d4c..00000000000 --- a/tensorflow/contrib/coder/BUILD +++ /dev/null @@ -1,207 +0,0 @@ -# Description: -# Contains ops related to data compression. 
- -package(default_visibility = [ - "//learning/brain:__subpackages__", - "//research/vision/piedpiper:__subpackages__", - "//tensorflow:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 - -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", - "tf_custom_op_library", - "tf_gen_op_libs", - "tf_gen_op_wrapper_py", - "tf_kernel_library", - "tf_py_test", -) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") - -cc_library( - name = "range_coder", - srcs = [ - "kernels/range_coder.cc", - ], - hdrs = [ - "kernels/range_coder.h", - ], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "range_coder_test", - size = "small", - srcs = ["kernels/range_coder_test.cc"], - deps = [ - ":range_coder", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -tf_gen_op_libs( - op_lib_names = ["coder_ops"], - deps = [ - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "range_coder_ops_util", - srcs = ["kernels/range_coder_ops_util.cc"], - hdrs = ["kernels/range_coder_ops_util.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], -) - -tf_kernel_library( - name = "range_coder_ops", - srcs = [ - "kernels/range_coder_ops.cc", - ], - visibility = ["//visibility:public"], - deps = [ - ":coder_ops_op_lib", - ":range_coder", - ":range_coder_ops_util", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], - alwayslink = 1, -) - -tf_cc_test( - name = "range_coder_ops_test", - size = "small", - srcs = ["kernels/range_coder_ops_test.cc"], - deps = [ - ":range_coder", - ":range_coder_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels:ops_testutil", - ], -) - -tf_kernel_library( - name = "pmf_to_cdf_op", - srcs = ["kernels/pmf_to_cdf_op.cc"], - visibility = ["//visibility:public"], - deps = [ - ":coder_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "pmf_to_cdf_op_test", - size = "small", - srcs = ["kernels/pmf_to_cdf_op_test.cc"], - deps = [ - ":pmf_to_cdf_op", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels:ops_testutil", - ], -) - -cc_library( - name = "all_ops", - deps = [":coder_ops_op_lib"], -) - -cc_library( - name = "all_kernels", - deps = [ - ":pmf_to_cdf_op", - ":range_coder_ops", - ], -) - -tf_custom_op_library( - name = "python/ops/_coder_ops.so", - srcs = [ - "kernels/pmf_to_cdf_op.cc", - "kernels/range_coder.cc", - "kernels/range_coder.h", - "kernels/range_coder_ops.cc", - "kernels/range_coder_ops_util.cc", - "kernels/range_coder_ops_util.h", - "ops/coder_ops.cc", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_coder_ops", - out = "python/ops/gen_coder_ops.py", - deps = [":coder_ops_op_lib"], -) - -py_library( - name = "coder_py", - srcs = [ - "__init__.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":coder_ops_py", - ], -) - -tf_custom_op_py_library( - name = "coder_ops_py", - srcs = [ - "python/ops/coder_ops.py", - ], - dso = [ - 
":python/ops/_coder_ops.so", - ], - kernels = [ - ":all_kernels", - ], - srcs_version = "PY2AND3", - deps = [ - ":gen_coder_ops", - "//tensorflow/contrib/util:util_py", - ], -) - -tf_py_test( - name = "coder_ops_py_test", - srcs = [ - "python/ops/coder_ops_test.py", - ], - additional_deps = [ - ":coder_ops_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - ], - main = "python/ops/coder_ops_test.py", -) diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc deleted file mode 100644 index bd5272ee6f2..00000000000 --- a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include -#include -#include -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace { -using errors::InvalidArgument; - -class PmfToCdfOp : public OpKernel { - public: - explicit PmfToCdfOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); - OP_REQUIRES( - context, 0 < precision_ && precision_ <= 16, - InvalidArgument("`precision` must be in [1, 16]: ", precision_)); - } - - void Compute(OpKernelContext* context) override { - const Tensor& pmf_tensor = context->input(0); - - TensorShape shape = pmf_tensor.shape(); - OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(shape), - InvalidArgument("`pmf` should be at least 1-D.")); - OP_REQUIRES( - context, shape.dim_size(shape.dims() - 1) > 1, - InvalidArgument("`pmf` size should be at least 2 in the last axis.")); - shape.set_dim(shape.dims() - 1, shape.dim_size(shape.dims() - 1) + 1); - - Tensor* cdf_tensor; - OP_REQUIRES_OK(context, context->allocate_output(0, shape, &cdf_tensor)); - - auto pmf = pmf_tensor.flat_inner_dims(); - auto cdf = cdf_tensor->flat_inner_dims(); - CHECK_EQ(pmf.dimension(0), cdf.dimension(0)); - CHECK_EQ(pmf.dimension(1) + 1, cdf.dimension(1)); - - const double n = pmf.dimension(1); - const int64 cost_per_unit = static_cast(50.0 * n * std::log2(n)); - thread::ThreadPool* thread_pool = - context->device()->tensorflow_cpu_worker_threads()->workers; - thread_pool->ParallelFor( - pmf.dimension(0), cost_per_unit, - [this, pmf, &cdf](int64 start, int64 limit) { - const 
gtl::ArraySlice::size_type pmf_size = pmf.dimension(1); - for (int64 i = start; i < limit; ++i) { - cdf(i, 0) = 0; - PerShard({&pmf(i, 0), pmf_size}, {&cdf(i, 1), pmf_size}); - } - }); - } - - private: - struct PenaltyItem { - PenaltyItem(int32* p, double mass) : pointer(p), mass(mass) { - penalty = ComputeNextPenalty(); - } - - void Decrease() { - CHECK_GT(*pointer, 1); - --*pointer; - penalty = ComputeNextPenalty(); - } - - friend bool operator<(const PenaltyItem& lhs, const PenaltyItem& rhs) { - return lhs.penalty < rhs.penalty; - } - - double ComputeNextPenalty() { - if (*pointer <= 1) { - return std::numeric_limits::infinity(); - } - return mass * (std::log2(*pointer) - std::log2(*pointer - 1)); - } - - int32* pointer; - double mass; - double penalty; - }; - - struct GainItem { - GainItem(int32* p, double mass) : pointer(p), mass(mass) { - gain = ComputeNextGain(); - } - - void Increase() { - CHECK_GT(*pointer, 0); - ++*pointer; - gain = ComputeNextGain(); - } - - friend bool operator>(const GainItem& lhs, const GainItem& rhs) { - return lhs.gain > rhs.gain; - } - - double ComputeNextGain() { - // Never increment zero value to non-zero value. - if (*pointer < 1) { - return -std::numeric_limits::infinity(); - } - return mass * (std::log2(*pointer + 1) - std::log2(*pointer)); - } - - int32* pointer; - double mass; - double gain; - }; - - void PerShard(gtl::ArraySlice pmf, - gtl::MutableArraySlice cdf) const { - CHECK_EQ(pmf.size(), cdf.size()); - - const int32 normalizer = 1 << precision_; - std::transform(pmf.begin(), pmf.end(), cdf.begin(), - [normalizer](float mass) { - int32 value = std::rint(mass * normalizer); - // NOTE: Consider checking if mass > 0. - value = std::max(value, 1); - return value; - }); - - int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0); - if (sum > normalizer) { - std::vector queue; - queue.reserve(cdf.size()); - for (int i = 0; i < cdf.size(); ++i) { - queue.emplace_back(&cdf[i], pmf[i]); - } - - std::sort(queue.begin(), queue.end()); - while (sum-- > normalizer) { - queue[0].Decrease(); - // Performs a linear search because this find_if is likely to return - // iterator very close to the begin. - auto iter = std::find_if( - std::next(queue.begin()), queue.end(), - [&queue](const PenaltyItem& rhs) { return queue[0] < rhs; }); - std::rotate(queue.begin(), std::next(queue.begin()), iter); - } - } else if (sum < normalizer) { - std::vector queue; - queue.reserve(cdf.size()); - for (int i = 0; i < cdf.size(); ++i) { - queue.emplace_back(&cdf[i], pmf[i]); - } - - std::sort(queue.begin(), queue.end(), std::greater()); - while (sum++ < normalizer) { - queue[0].Increase(); - // Performs a linear search because this find_if is likely to return - // iterator very close to the begin. - auto iter = std::find_if( - std::next(queue.begin()), queue.end(), - [&queue](const GainItem& rhs) { return queue[0] > rhs; }); - std::rotate(queue.begin(), std::next(queue.begin()), iter); - } - } - std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); - } - - int precision_; -}; - -REGISTER_KERNEL_BUILDER(Name("PmfToQuantizedCdf").Device(DEVICE_CPU), - PmfToCdfOp); -} // namespace -} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc deleted file mode 100644 index 3408f6b519a..00000000000 --- a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
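To make the behaviour of the deleted `PmfToCdfOp` concrete: it scales each probability mass by `2^precision`, rounds, clamps every bin to at least 1, nudges the counts until they sum exactly to the normalizer, and returns the running sum with a leading 0. The Python sketch below captures that idea; it uses a much cruder rebalancing step than the penalty/gain queues in the kernel above, so treat it as an approximation rather than a port.

```python
def pmf_to_quantized_cdf(pmf, precision):
    # Returns len(pmf) + 1 integers starting at 0 and ending at 2**precision.
    # Assumes len(pmf) <= 2**precision so every bin can keep a count >= 1.
    normalizer = 1 << precision
    counts = [max(1, int(round(p * normalizer))) for p in pmf]
    total = sum(counts)
    while total > normalizer:            # shrink the currently largest count
        j = max(range(len(counts)), key=lambda k: counts[k])
        counts[j] -= 1
        total -= 1
    while total < normalizer:            # grow the most probable bin
        j = max(range(len(pmf)), key=lambda k: pmf[k])
        counts[j] += 1
        total += 1
    cdf = [0]
    for c in counts:
        cdf.append(cdf[-1] + c)
    return cdf

print(pmf_to_quantized_cdf([0.1, 0.2, 0.3, 0.4], precision=10))
# [0, 102, 307, 614, 1024]
```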
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/shape_inference_testutil.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/random/philox_random.h" -#include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/random/simple_philox.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { -class PmfToQuantizedCdfOpTest : public OpsTestBase { - protected: - void SetupOp(int precision, Tensor* input) { - TF_ASSERT_OK(NodeDefBuilder("pmf_to_cdf", "PmfToQuantizedCdf") - .Input(FakeInput(DT_FLOAT)) - .Attr("precision", precision) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - - inputs_.clear(); - inputs_.emplace_back(input); - } - - void GenerateData(random::SimplePhilox* rand, - gtl::MutableArraySlice slice) { - constexpr float minimum = std::numeric_limits::epsilon(); - float sum = 0; - for (float& value : slice) { - value = std::max(rand->RandFloat(), minimum); - sum += value; - } - for (float& value : slice) { - value /= sum; - } - } - - void Verify(int precision, const Tensor& pmf_tensor, - const Tensor& cdf_tensor) { - ASSERT_EQ(pmf_tensor.dims(), cdf_tensor.dims()); - const int n = pmf_tensor.dims(); - - for (int i = 0; i < n - 1; ++i) { - EXPECT_EQ(pmf_tensor.dim_size(i), cdf_tensor.dim_size(i)); - } - - auto pmf = pmf_tensor.flat_inner_dims(); - auto cdf = cdf_tensor.flat_inner_dims(); - EXPECT_EQ(pmf.dimension(1) + 1, cdf.dimension(1)); - - const int normalizer = 1 << precision; - for (int i = 0; i < pmf.dimension(0); ++i) { - EXPECT_EQ(0, cdf(i, 0)); - - TTypes::UnalignedConstVec cdf_slice(&cdf(i, 0), cdf.dimension(1)); - - for (int j = 1; j < cdf_slice.size(); ++j) { - const int32 diff = cdf_slice(j) - cdf_slice(j - 1); - EXPECT_GT(diff, 0); - } - - EXPECT_EQ(cdf_slice(cdf_slice.size() - 1), normalizer); - } - } -}; - -TEST_F(PmfToQuantizedCdfOpTest, UnderSum) { - Tensor pmf(DT_FLOAT, {1, 10, 1, 32}); - auto matrix = pmf.flat_inner_dims(); - const std::size_t n = matrix.dimension(1); - - random::PhiloxRandom gen(random::New64(), random::New64()); - random::SimplePhilox rand(&gen); - for (int64 i = 0; i < matrix.dimension(0); ++i) { - GenerateData(&rand, {&matrix(i, 0), n}); - } - - pmf.flat() = pmf.flat() * 0.85f; - - constexpr int kPrecision = 10; - SetupOp(kPrecision, &pmf); - TF_ASSERT_OK(RunOpKernel()); - - Verify(kPrecision, pmf, *GetOutput(0)); -} - -TEST_F(PmfToQuantizedCdfOpTest, OverSum) { - Tensor pmf(DT_FLOAT, {10, 1, 1, 100}); - auto matrix = pmf.flat_inner_dims(); - - // Half of each PMF is filled with zeros. 
The op will round up zeros to ones, - // post quantization. These round ups are likely to make the sum over - // normalizer value. - matrix.setZero(); - const std::size_t n = matrix.dimension(1) / 2; - - random::PhiloxRandom gen(random::New64(), random::New64()); - random::SimplePhilox rand(&gen); - for (int64 i = 0; i < matrix.dimension(0); ++i) { - GenerateData(&rand, {&matrix(i, 0), n}); - } - - constexpr int kPrecision = 7; - SetupOp(kPrecision, &pmf); - TF_ASSERT_OK(RunOpKernel()); - - Verify(kPrecision, pmf, *GetOutput(0)); -} - -TEST_F(PmfToQuantizedCdfOpTest, ShapeFn) { - ShapeInferenceTestOp op("PmfToQuantizedCdf"); - - INFER_OK(op, "?", "?"); - INFER_OK(op, "[3]", "[4]"); - INFER_OK(op, "[3,4]", "[d0_0,5]"); - INFER_OK(op, "[3,4,5]", "[d0_0,d0_1,6]"); -} -} // namespace -} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder.cc b/tensorflow/contrib/coder/kernels/range_coder.cc deleted file mode 100644 index 21b35155ff3..00000000000 --- a/tensorflow/contrib/coder/kernels/range_coder.cc +++ /dev/null @@ -1,374 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Range coder implementation, based on [1]. -// -// [1] G. N. N. Martin, "Range coding: an algorithm for removing redundancy from -// a digitised message", presented to the Video & Data Recording Conference, -// held in Southampton, July 24-27, 1979. -// -#include "tensorflow/contrib/coder/kernels/range_coder.h" - -#include -#include - -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -RangeEncoder::RangeEncoder(int precision) : precision_(precision) { - CHECK_GT(precision, 0); - CHECK_LE(precision, 16); -} - -void RangeEncoder::Encode(int32 lower, int32 upper, string* sink) { - // Input requirement: 0 <= lower < upper <= 2^precision. - DCHECK_LE(0, lower); - DCHECK_LT(lower, upper); - DCHECK_LE(upper, 1 << precision_); - - // `base` and `size` represent a half-open interval [base, base + size). - // Loop invariant: 2^16 <= size <= 2^32. - // - // Note that keeping size above 2^16 is important. Since the interval sizes - // are quantized to up to 16 bits, the smallest interval size the encode may - // handle is 2^-16. If size is smaller than 2^16, a small interval input may - // collapse the encoder range into an empty interval. - const uint64 size = static_cast(size_minus1_) + 1; - DCHECK_NE(size >> 16, 0); - - // For short notation, let u := lower and v := upper. - // - // The input u, v represents a half-open interval [u, v) / 2^precision. - // This narrows the current interval roughly to - // [base + (size * u) / 2^precision, base + (size * v) / 2^precision). - // - // TODO(sjhwang): Try rounding if it helps improve compression ratio, at the - // expense of more operations. 
In the test using Zipf distribution, the - // overhead over the theoretical compression ratio was ~0.01%. - // NOTE: The max value of `size` is 2^32 and size > 0. Therefore `size * u` - // can be rewritten as `(size - 1) * u + u` and all the computation can be - // done in 32-bit mode. If 32-bit multiply is faster, then rewrite. - const uint32 a = (size * static_cast(lower)) >> precision_; - const uint32 b = ((size * static_cast(upper)) >> precision_) - 1; - DCHECK_LE(a, b); - - // Let's confirm the RHS of a, b fit in uint32 type. - // Recall that 0 <= u < 2^precision, and size <= 2^32. Therefore - // (size * u) / 2^precision < size <= 2^32, - // and the value of a fits in uint32 type. Similarly, since v <= 2^precision, - // (size * v) / 2^precision - 1 <= size - 1 < 2^32. - // For lower bound of b, note that 1 <= v, 2^16 <= size, and 16 <= precision. - // Therefore (size * v) / 2^precision - 1 >= 2^16 / 2^precision - 1 >= 0. - - // The new interval is [base + a, base + b] = [base + a, base + b + 1). - base_ += a; // May overflow. - size_minus1_ = b - a; - const bool base_overflow = (base_ < a); - - // The encoder has two states. Let's call them state 0 and state 1. - // State 0 is when base < base + size <= 2^32. - // State 1 is when base < 2^32 < base + size. - // - // The encoder initially starts in state 0, with base = 0, size = 2^32. - // - // TODO(sjhwang): Requires some profiling, but the encoder stays in state 0 - // most of the time. Should optimize code for state 0. - // - // Each Encode() has up to two places where the interval changes: - // #1. Refine the interval. [base, base + size) -> [base + a, base + b + 1). - // #2. Expand interval if the new size is too small, - // and each change may cause a state transition. - // - // First, consider when the current state is 0. - // - // In this case, the next state after #1 is always state 0, since refining - // interval only shrinks the interval, therefore new_base + new_size <= 2^32. - // - // Let us explain #2. - // - // Recall that at the beginning of each Encode(), the encoder requires - // 2^16 < size <= 2^32. As precision <= 16, the new interval size can be as - // small as 1, but never zero. - // - // To keep size above 2^16, if new size is smaller than or equal to 2^16, the - // encoder would left-shift base and size by 16 bits: size' <- size * 2^16. - // Note that new size' is now in the range [2^16, 2^32]. - // - // Since size is left-shifted, the same should be applied to base as well. - // However, after the left-shift, base will then contain 48 bits instead of 32 - // bits. Therefore prior to the shift, The upper 16 bits in base should be - // stored somewhere else. - // - // If the upper 16 bits of all values in the interval were the same, i.e., if - // base[32:16] == (base + size - 1)[32:16], then base[32:16] can be written - // out to `output` string, since any further Encode() only narrows down the - // interval and that 16 bits would never change. - // - // If the upper 16 bits were not all the same, since this happens only when - // size <= 2^16, the upper 16 bits may differ only by one, i.e., - // base[32:16] + 1 == (base + size - 1)[32:16]. At this stage, it is not - // determined yet whether base[32:16] should be written to the output or - // (base[32:16] + 1) should be written to the output. In this case, - // (base[32:16] + 1) is temporarily stored in `delay`, and base is - // left-shifted by 16 bits. 
- // - // In the latter case, the condition implies that (base // 2^16) and - // ((base + size - 1) // 2^16) were different. Therefore after left-shift by - // 16 bits, the new (base + size) is greater than 2^32, i.e., the encoder - // transition to state 1. - // - // ==== Summary ==== - // To detect the current encoder state, - // state 0: delay == 0 iff (base mod 2^32) < (base + size) mod 2^32, - // state 1: delay != 0 iff (base + size) mod 2^32 <= base mod 2^32, - // because size <= 2^32. - // - // ==== Summary for state 0 ==== - // 1. Interval refinement does not cause state transition. - // 2. Interval expansion may cause state transition, depending on the upper 16 - // bits of base and base + size - 1. - // - // Now suppose the previous state was 1. This means that - // base <= 2^32 < base + size. - // - // When in state 1, an interval refinement may trigger state transition. - // After Encode() refines the interval, there are three possibilities: - // #1. base <= 2^32 < base + size (unchanged), - // #2. 2^32 <= base < base + size (base overflowed), - // #3. base < base + size <= 2^32 (base + size - 1 underflowed). - // - // In case #1, the encoder remains in state 1. - // In case #2 or #3, the encoder state changes to state 0. - // - // ==== State transition for interval refinement ==== - // 1. state 0 -> state 0, - // 2. state 1 -> state 0 or state 1. - // - // Therefore if the new state is 1, then the previous state must have been - // state 1. - if (base_ + size_minus1_ < base_) { - // If statement checked if 2^32 < base + size. The new state is 1, hence the - // previous state was also state 1. - DCHECK_NE(((base_ - a) + size) >> 32, 0); - DCHECK_NE(delay_ & 0xFFFF, 0); - - // Like in state 0, if the new size is <= 2^16, then base and size should - // be left-shifted by 16 bits. Combine the conditions - // base <= 2^32 < base + size and size <= 2^16 to conclude that - // base[32:16] >= 0xFFFF and (base + size - 1)[32:16] = 0x0000. - // - // Note that 2^32 - base < size, and since base is at least 0xFFFF0000, - // 2^16 - base[16:0] < size. Let base' and size' be the new base and size - // after the bit-shift. Then 2^32 - base' < size' => 2^32 < base' + size'. - // Therefore the encoder remains in state 1. - // - // Lastly, `delay` is modified. Conceptually, delay has to be changed to - // delay' <- delay * 2^16 + (base + size - 1)[32:16]. - // Since we know above that (base + size - 1)[32:16] = 0x0000, there is no - // need to explicitly do the computation above, but rather store how many - // trailing zeros there were. For this reason, the lower 16 bits of - // `delay` stores the delayed value when state changed from 0 to 1, and - // delay[32:16] stores the # of trailing zeros (in bytes). - // - // ==== State transition for interval expansion ==== - // 1. state 0 -> state 0 or state 1, - // 2. state 1 -> state 1. - if (size_minus1_ >> 16 == 0) { - DCHECK_EQ(base_ >> 16, 0xFFFF); - base_ <<= 16; - size_minus1_ <<= 16; - size_minus1_ |= 0xFFFF; - // TODO(sjhwang): It is possible that for very long input, delay - // overflow during below. If overflow is detected, this delay is too - // long the encoder should forcefully move to state 0. In such case, - // base can be raised to 2^32 (force case #2), or (base + size) can be - // lowered to 2^32 (force case #3), depending on which transition - // keeps size larger. - CHECK_LT(delay_, static_cast(1) << 62); - delay_ += 0x20000; // Two more bytes of zeros. Check overflow? - } - return; - } - - // If reached here, the current state is 0. 
- // First handle the case when the previous state was state 1. - if (delay_ != 0) { - // In case #2 or #3, the encoder state changes to state 0. Recall that when - // the encoder state changed from state 0 to state 1, the top 16 bits of - // (base + size - 1) was temporarily stored in `delay`, because the output - // could be either (delay - 1) or (delay). - // - // And from above, the delayed value encoded in `delay` is - // delay' <- delay[16:0] * 2^(8 * delay[MAX:16]) - // - // In case #2, the interval moved below 2^32. So (delay' - 1) is the - // converged value after interval refinements. Write out - // (delay[16:0] - 1) and write (8 * delay[MAX:16]) bytes of 0xFF. - // - // In case #3, the interval moved above 2^32. So delay' is the converged - // value after interval refinement. Write out delay[16:0] and write - // (8 * delay[MAX:16]) bytes of 0x00. - if (base_overflow) { - // Case #2. - DCHECK_NE((static_cast(base_ - a) + a) >> 32, 0); - sink->push_back(static_cast(delay_ >> 8)); - sink->push_back(static_cast(delay_ >> 0)); - sink->append(delay_ >> 16, static_cast(0)); - } else { - // Case #3. - DCHECK_EQ(static_cast(base_ + size_minus1_) >> 32, 0); - --delay_; - sink->push_back(static_cast(delay_ >> 8)); - sink->push_back(static_cast(delay_ >> 0)); - sink->append(delay_ >> 16, static_cast(0xFF)); - } - // Reset to state 0. - delay_ = 0; - } - - if (size_minus1_ >> 16 == 0) { - const uint32 top = base_ >> 16; - - base_ <<= 16; - size_minus1_ <<= 16; - size_minus1_ |= 0xFFFF; - - if (base_ <= base_ + size_minus1_) { - // Still in state 0. Write the top 16 bits. - sink->push_back(static_cast(top >> 8)); - sink->push_back(static_cast(top)); - } else { - // New state is 1. - DCHECK_LT(top, 0xFFFF); - delay_ = top + 1; - } - } -} - -void RangeEncoder::Finalize(string* sink) { - // Finalize the encode by writing out any number in the interval - // [base, base + size). - // - // Trailing zeros are not explicitly written out as decoder can fill in zeros - // by default. - if (delay_ != 0) { - // The last state was state 1. Since base < 2^32 < base + size, pick 2^32 - // (state 1, case #3). - // NOTE: It is a bit difficult to trigger this code path on purpose. - // TODO(sjhwang): Find a way to trigger this code path for test coverage. - sink->push_back(static_cast(delay_ >> 8)); - if ((delay_ & 0xFF) != 0) { - sink->push_back(static_cast(delay_)); - } - } else if (base_ != 0) { - // If base == 0, then pick 0 from [base, base + size) and no zeros are - // explicitly written. - // - // Otherwise, pick (base + (2^16 - base[16:0])), i.e., round up base to the - // next multiple of 2^16. As 2^16 < size, this value should be in the - // interval [base, base + size). - const uint32 mid = ((base_ - 1) >> 16) + 1; - DCHECK_EQ(mid & 0xFFFF, mid); - sink->push_back(static_cast(mid >> 8)); - if ((mid & 0xFF) != 0) { - sink->push_back(static_cast(mid >> 0)); - } - } - - base_ = 0; - size_minus1_ = std::numeric_limits::max(); - delay_ = 0; -} - -RangeDecoder::RangeDecoder(const string& source, int precision) - : current_(source.begin()), - begin_(source.begin()), - end_(source.end()), - precision_(precision) { - CHECK_LE(precision, 16); - - Read16BitValue(); - Read16BitValue(); -} - -int32 RangeDecoder::Decode(tensorflow::gtl::ArraySlice cdf) { - const uint64 size = static_cast(size_minus1_) + 1; - const uint64 offset = - ((static_cast(value_ - base_) + 1) << precision_) - 1; - - // This is similar to std::lower_range() with std::less_equal as comparison. 
- // After the binary search, `pv` points to the smallest number v that - // satisfies offset < (size * v) / 2^precision. - - // Assumes that cdf[0] == 0. Therefore (size * cdf[0]) / 2^precision is always - // less than or equal to offset. - const int32* pv = cdf.data() + 1; - // `len` can be cdf.size() - 2 if there is guarantee that the last element of - // cdf is 2^precision. - auto len = cdf.size() - 1; - DCHECK_GT(len, 0); - - do { - const auto half = len / 2; - const int32* mid = pv + half; - DCHECK_GE(*mid, 0); - DCHECK_LE(*mid, 1 << precision_); - if (size * static_cast(*mid) <= offset) { - pv = mid + 1; - len -= half + 1; - } else { - len = half; - } - } while (len > 0); - - // If (size * v) / 2^precision <= offset for all v in cdf, then pv points to - // one after the last element of cdf. That is a decoding error. - // - // TODO(sjhwang): Consider returning -1 to indicate error. Or start len = - // cdf.size() - 2 instead and give up detecting this error. - CHECK_LT(pv, cdf.data() + cdf.size()); - - const uint32 a = (size * static_cast(*(pv - 1))) >> precision_; - const uint32 b = ((size * static_cast(*pv)) >> precision_) - 1; - DCHECK_LE(a, offset >> precision_); - DCHECK_LE(offset >> precision_, b); - - base_ += a; - size_minus1_ = b - a; - - if (size_minus1_ >> 16 == 0) { - base_ <<= 16; - size_minus1_ <<= 16; - size_minus1_ |= 0xFFFF; - - Read16BitValue(); - } - - return pv - cdf.data() - 1; -} - -void RangeDecoder::Read16BitValue() { - value_ <<= 8; - if (current_ != end_) { - value_ |= static_cast(*current_++); - } - value_ <<= 8; - if (current_ != end_) { - value_ |= static_cast(*current_++); - } -} -} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder.h b/tensorflow/contrib/coder/kernels/range_coder.h deleted file mode 100644 index f46413072e3..00000000000 --- a/tensorflow/contrib/coder/kernels/range_coder.h +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ -#define TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ - -#include -#include - -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -class RangeEncoder { - public: - // `precision` determines the granularity of probability masses passed to - // Encode() function below. - // - // REQUIRES: 0 < precision <= 16. - explicit RangeEncoder(int precision); - - // Encodes a half-open interval [lower / 2^precision, upper / 2^precision). - // Suppose each character to be encoded is from an integer-valued - // distribution. When encoding a random character x0, the arguments lower and - // upper represent - // Pr(X < x0) = lower / 2^precision, - // Pr(X < x0 + 1) = upper / 2^precision, - // where X is a random variable following the distribution. 
- // - // For example, assume that the distribution has possible outputs 0, 1, 2, ... - // To encode value 0, lower = 0 and upper = Pr(X = 0). - // To encode value 1, lower = Pr(X = 0) and upper = Pr(X = 0 or 1). - // To encode value 2, lower = Pr(X = 0 or 1) and upper = Pr(X = 0, 1, or 2). - // ... - // - // REQUIRES: 0 <= lower < upper <= 2^precision. - void Encode(int32 lower, int32 upper, string* sink); - - // The encode may contain some under-determined values from previous encoding. - // After Encode() calls, Finalize() must be called. Otherwise the encoded - // string may not be decoded. - void Finalize(string* sink); - - private: - uint32 base_ = 0; - uint32 size_minus1_ = std::numeric_limits::max(); - uint64 delay_ = 0; - - const int precision_; -}; - -class RangeDecoder { - public: - // Holds a reference to `source`. The caller has to make sure that `source` - // outlives the decoder object. - // - // REQUIRES: `precision` must be the same as the encoder's precision. - // REQUIRES: 0 < precision <= 16. - RangeDecoder(const string& source, int precision); - - // Decodes a character from `source` using CDF. The size of `cdf` should be - // one more than the number of the character in the alphabet. - // - // If x0, x1, x2, ... are the possible characters (in increasing order) from - // the distribution, then - // cdf[0] = 0 - // cdf[1] = Pr(X <= x0), - // cdf[2] = Pr(X <= x1), - // cdf[3] = Pr(X <= x2), - // ... - // - // The returned value is an index to `cdf` where the decoded character - // corresponds to. - // - // REQUIRES: cdf.size() > 1. - // REQUIRES: cdf[i] <= cdf[i + 1] for i = 0, 1, ..., cdf.size() - 2. - // REQUIRES: cdf[cdf.size() - 1] <= 2^precision. - // - // In practice the last element of `cdf` should equal to 2^precision. - int32 Decode(gtl::ArraySlice cdf); - - private: - void Read16BitValue(); - - uint32 base_ = 0; - uint32 size_minus1_ = std::numeric_limits::max(); - uint32 value_ = 0; - - string::const_iterator current_; - const string::const_iterator begin_; - const string::const_iterator end_; - - const int precision_; -}; -} // namespace tensorflow - -#endif // TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops.cc b/tensorflow/contrib/coder/kernels/range_coder_ops.cc deleted file mode 100644 index cde7982530f..00000000000 --- a/tensorflow/contrib/coder/kernels/range_coder_ops.cc +++ /dev/null @@ -1,307 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
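The class comments above pin down the contract between a quantized CDF and the coder: symbol `x` is encoded with `lower = cdf[x]`, `upper = cdf[x + 1]`, and decoding searches for the interval that contains the decoded offset. The short Python sketch below illustrates only that mapping; the carry and renormalization logic of `range_coder.cc` is deliberately left out.

```python
import bisect

def symbol_to_interval(cdf, symbol):
    # cdf[0] == 0 and cdf[-1] == 2**precision, as RangeEncoder requires.
    lower, upper = cdf[symbol], cdf[symbol + 1]
    assert 0 <= lower < upper <= cdf[-1]
    return lower, upper

def offset_to_symbol(cdf, offset):
    # Conceptual analogue of RangeDecoder::Decode: the decoded symbol is the
    # largest index whose CDF entry is <= offset.
    return bisect.bisect_right(cdf, offset) - 1

cdf = [0, 102, 307, 614, 1024]   # e.g. the quantized CDF from the sketch above
for symbol in range(len(cdf) - 1):
    lower, upper = symbol_to_interval(cdf, symbol)
    assert offset_to_symbol(cdf, lower) == symbol
    assert offset_to_symbol(cdf, upper - 1) == symbol
```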
-==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/coder/kernels/range_coder.h" -#include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace { -// A helper class to iterate over data and cdf simultaneously, while cdf is -// broadcasted to data. -// NOTE: Moving this class out of anonymous namespace impacts compiler -// optimization and affects performance. When moving this code around (e.g., -// into a library header), be sure to check the benchmark tests. -template -class BroadcastRange { - public: - BroadcastRange(T* data_pointer, gtl::ArraySlice data_shape, - const U* cdf_pointer, gtl::ArraySlice cdf_shape) - : data_pointer_(data_pointer), cdf_pointer_(cdf_pointer) { - CHECK(!data_shape.empty()); - CHECK_EQ(data_shape.size(), N); - CHECK_EQ(cdf_shape.size(), N + 1); - - std::copy(data_shape.begin(), data_shape.end(), &data_shape_[0]); - data_index_.fill(0); - - const int64 innermost_stride = cdf_shape[N]; - cdf_displace_.fill(innermost_stride); - - // Pre-compute the pointer displacement for cdf. - int64 stride = innermost_stride; - for (int i = N - 1; i >= 0; --i) { - const bool broadcasting = (cdf_shape[i] <= 1); - - // When the data linear index advances by one, the cdf linear index - // advances by `innermost_stride`. - // - // Suppose that the i-th axis coordinate of data increased by one, and - // that i-th axis is broadcasting. The cdf linear index should be wound - // back by i-th axis stride, so that i-th axis coordinate of cdf is - // effectively kept at 0. - if (broadcasting) { - cdf_displace_[i] -= stride; - } - stride *= cdf_shape[i]; - } - } - - // Returns the pointers to the current iterating locations to data and cdf - // tensors. - // - // Note that this function does not track whether data pointer is running past - // the end of data buffer. The caller has to make sure Next() is called no - // more than that. - std::pair Next() { - std::pair return_value = {data_pointer_, cdf_pointer_}; - - int i = N - 1; - for (; i > 0; --i) { - ++data_index_[i]; - if (data_index_[i] < data_shape_[i]) { - break; - } - data_index_[i] = 0; - } - - // Advance data pointer by one. - data_pointer_ += 1; - - // For cdf pointer, it's more complicated because of broadcasting. When i-th - // coordinate increase by one, and if i-th axis is broadcasting, then we - // need to rewind back the pointer so that the effective i-th axis - // coordinate for cdf is always 0. This value is precomputed as - // cdf_displace_. 
- cdf_pointer_ += cdf_displace_[i]; - return return_value; - } - - private: - std::array data_shape_; - std::array cdf_displace_; - std::array data_index_; - - T* data_pointer_; - const U* cdf_pointer_; -}; - -Status CheckCdfShape(const TensorShape& data_shape, - const TensorShape& cdf_shape) { - if (TF_PREDICT_FALSE(cdf_shape.dims() != data_shape.dims() + 1)) { - return errors::InvalidArgument( - "`cdf` should have one more axis than `data`: data shape=", - data_shape.DebugString(), ", cdf shape=", cdf_shape.DebugString()); - } - - if (TF_PREDICT_FALSE(cdf_shape.dim_size(cdf_shape.dims() - 1) <= 1)) { - return errors::InvalidArgument( - "The last dimension of `cdf` should be > 1: ", cdf_shape.DebugString()); - } - - return Status::OK(); -} - -// Non-incremental encoder op ------------------------------------------------- -class RangeEncodeOp : public OpKernel { - public: - explicit RangeEncodeOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); - OP_REQUIRES(context, 0 < precision_ && precision_ <= 16, - errors::InvalidArgument("`precision` must be in [1, 16]: ", - precision_)); - } - - void Compute(OpKernelContext* context) override { - const Tensor& data = context->input(0); - const Tensor& cdf = context->input(1); - - OP_REQUIRES_OK(context, CheckCdfShape(data.shape(), cdf.shape())); - - std::vector data_shape, cdf_shape; - OP_REQUIRES_OK( - context, MergeAxes(data.shape(), cdf.shape(), &data_shape, &cdf_shape)); - - Tensor* output_tensor; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape{}, &output_tensor)); - string* output = &output_tensor->scalar()(); - - switch (data_shape.size()) { -#define RANGE_ENCODE_CASE(dims) \ - case dims: { \ - RangeEncodeImpl(data.flat(), data_shape, \ - cdf.flat_inner_dims(), cdf_shape, output); \ - } break - RANGE_ENCODE_CASE(1); - RANGE_ENCODE_CASE(2); - RANGE_ENCODE_CASE(3); - RANGE_ENCODE_CASE(4); - RANGE_ENCODE_CASE(5); - RANGE_ENCODE_CASE(6); -#undef RANGE_ENCODE_CASE - default: - context->CtxFailure(errors::InvalidArgument( - "Irregular broadcast pattern: ", data.shape().DebugString(), ", ", - cdf.shape().DebugString())); - return; - } - } - - private: - template - void RangeEncodeImpl(TTypes::ConstFlat data, - gtl::ArraySlice data_shape, - TTypes::ConstMatrix cdf, - gtl::ArraySlice cdf_shape, string* output) const { - const int64 data_size = data.size(); - const int64 cdf_size = cdf.size(); - const int64 chip_size = cdf.dimension(1); - - BroadcastRange view{data.data(), data_shape, - cdf.data(), cdf_shape}; - RangeEncoder encoder{precision_}; - for (int64 linear = 0; linear < data_size; ++linear) { - const auto pair = view.Next(); - - const int64 index = *pair.first; - DCHECK_GE(index, 0); - DCHECK_LT(index + 1, chip_size); - - const int32* cdf_slice = pair.second; - DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size); - - const int32 lower = cdf_slice[index]; - const int32 upper = cdf_slice[index + 1]; - encoder.Encode(lower, upper, output); - } - - encoder.Finalize(output); - } - - int precision_; -}; - -REGISTER_KERNEL_BUILDER(Name("RangeEncode").Device(DEVICE_CPU), RangeEncodeOp); - -// Non-incremental decoder op ------------------------------------------------- -class RangeDecodeOp : public OpKernel { - public: - explicit RangeDecodeOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); - OP_REQUIRES(context, 0 < precision_ && precision_ <= 16, - 
errors::InvalidArgument("`precision` must be in [1, 16]: ", - precision_)); - } - - void Compute(OpKernelContext* context) override { - const Tensor& encoded_tensor = context->input(0); - const Tensor& shape = context->input(1); - const Tensor& cdf = context->input(2); - - OP_REQUIRES(context, TensorShapeUtils::IsScalar(encoded_tensor.shape()), - errors::InvalidArgument("Invalid `encoded` shape: ", - encoded_tensor.shape().DebugString())); - OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()), - errors::InvalidArgument("Invalid `shape` shape: ", - shape.shape().DebugString())); - TensorShape output_shape; - OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(shape.vec(), - &output_shape)); - OP_REQUIRES_OK(context, CheckCdfShape(output_shape, cdf.shape())); - - std::vector data_shape, cdf_shape; - OP_REQUIRES_OK( - context, MergeAxes(output_shape, cdf.shape(), &data_shape, &cdf_shape)); - - const string& encoded = encoded_tensor.scalar()(); - - Tensor* output; - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - - switch (data_shape.size()) { -#define RANGE_DECODE_CASE(dim) \ - case dim: { \ - RangeDecodeImpl(output->flat(), data_shape, \ - cdf.flat_inner_dims(), cdf_shape, encoded); \ - } break - RANGE_DECODE_CASE(1); - RANGE_DECODE_CASE(2); - RANGE_DECODE_CASE(3); - RANGE_DECODE_CASE(4); - RANGE_DECODE_CASE(5); - RANGE_DECODE_CASE(6); -#undef RANGE_DECODE_CASE - default: - context->CtxFailure(errors::InvalidArgument( - "Irregular broadcast pattern: ", output_shape.DebugString(), ", ", - cdf.shape().DebugString())); - return; - } - } - - private: - template - void RangeDecodeImpl(TTypes::Flat output, - gtl::ArraySlice output_shape, - TTypes::ConstMatrix cdf, - gtl::ArraySlice cdf_shape, - const string& encoded) const { - BroadcastRange view{output.data(), output_shape, - cdf.data(), cdf_shape}; - - RangeDecoder decoder{encoded, precision_}; - - const int64 output_size = output.size(); - const int64 cdf_size = cdf.size(); - const auto chip_size = - static_cast::size_type>(cdf.dimension(1)); - - for (int64 i = 0; i < output_size; ++i) { - const auto pair = view.Next(); - - int16* data = pair.first; - DCHECK_LT(data, output.data() + output_size); - - const int32* cdf_slice = pair.second; - DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size); - - *data = decoder.Decode(gtl::ArraySlice{cdf_slice, chip_size}); - } - } - - int precision_; -}; - -REGISTER_KERNEL_BUILDER(Name("RangeDecode").Device(DEVICE_CPU), RangeDecodeOp); -} // namespace -} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc deleted file mode 100644 index 81b36ca902b..00000000000 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc +++ /dev/null @@ -1,520 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
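One more detail worth calling out from the deleted `RangeEncodeOp`/`RangeDecodeOp`: `CheckCdfShape` plus the axis-merging logic require `cdf` to have exactly one more axis than `data`, a last axis larger than 1, and leading axes that either match `data` or are 1 (broadcast). The standalone Python sketch below approximates those checks, which in the kernel are spread across `CheckCdfShape` and `MergeAxes`.

```python
def check_data_cdf_shapes(data_shape, cdf_shape):
    # data_shape, cdf_shape: tuples of ints, e.g. (2, 3) and (1, 3, 17).
    if len(cdf_shape) != len(data_shape) + 1:
        raise ValueError("cdf needs exactly one more axis than data")
    if cdf_shape[-1] <= 1:
        raise ValueError("the last axis of cdf must be > 1")
    for d, c in zip(data_shape, cdf_shape[:-1]):
        if c != 1 and c != d:
            raise ValueError("leading cdf axes must be 1 or match data")

check_data_cdf_shapes((2, 3), (1, 3, 17))   # one shared cdf per column
check_data_cdf_shapes((2, 3), (2, 3, 17))   # one cdf per element
```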
-==============================================================================*/ - -#include -#include - -#include "tensorflow/contrib/coder/kernels/range_coder.h" -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/common_runtime/shape_refiner.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/graph/testlib.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/random/simple_philox.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/public/session_options.h" - -namespace tensorflow { -namespace { -int LogUniform(random::SimplePhilox* gen, uint32 n) { - CHECK_GT(n, 0); - - // Split [0, n) into {0}, [1, 2), [2, 4), [4, 8), ..., [2^(m-1), n). - const int m = Log2Ceiling(n); - - int outcome; - do { - // Uniform() consumes at least 32 bits per call, therefore this is somewhat - // wasteful implementation. Since this is used only for test, we do not - // refine this implementation further. - const int k = gen->Uniform(m + 1) - 1; - // If k == -1, then sample from {0}. - // If k == 0, then sample from [1, 2). - // If k == 1, then sample from [2, 4), ... and so on. 
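In plain Python the bucket scheme just described looks roughly like the sketch below (using the standard `random` module instead of Philox; the function name is illustrative). The C++ body that implements it continues right after.

```python
import random


def log_uniform(n):
    """Samples an int in [0, n): pick one of the buckets {0}, [1, 2), [2, 4),
    ..., [2**(m-1), n) uniformly, then pick uniformly inside that bucket,
    rejecting draws that land at or beyond n."""
    assert n > 0
    m = (n - 1).bit_length()  # ceil(log2(n)), the number of non-{0} buckets
    while True:
        k = random.randrange(m + 1) - 1  # k == -1 selects the {0} bucket
        if k < 1:
            outcome = k + 1              # 0 for k == -1, 1 for k == 0
        else:
            outcome = (1 << k) + random.randrange(1 << k)
        if outcome < n:
            return outcome
```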
- if (k < 1) { - outcome = k + 1; - } else { - outcome = (1 << k) + gen->Uniform(1 << k); - } - } while (n <= outcome); - return outcome; -} - -std::vector ComputeStrides(const TensorShape& shape) { - std::vector stride(shape.dims()); - int64 current = 1; - for (int i = shape.dims() - 1; i >= 0; --i) { - stride[i] = current; - current *= shape.dim_size(i); - } - return stride; -} - -class RangeCoderOpsTest : public OpsTestBase { - protected: - Status RunEncodeOp(int precision, gtl::ArraySlice input, - Tensor* output) { - TF_RETURN_IF_ERROR(NodeDefBuilder("encode", "RangeEncode") - .Input(tensorflow::FakeInput(DT_INT16)) - .Input(tensorflow::FakeInput(DT_INT32)) - .Attr("precision", precision) - .Finalize(node_def())); - TF_RETURN_IF_ERROR(InitOp()); - - inputs_.clear(); - std::vector copies(input.size()); - for (int i = 0; i < input.size(); ++i) { - copies[i] = input[i]; - inputs_.emplace_back(&copies[i]); - } - - TF_RETURN_IF_ERROR(RunOpKernel()); - - *output = *GetOutput(0); - inputs_.clear(); - - return Status::OK(); - } - - Status RunDecodeOp(int precision, gtl::ArraySlice input, - Tensor* output) { - TF_RETURN_IF_ERROR(NodeDefBuilder("decode", "RangeDecode") - .Input(tensorflow::FakeInput(DT_STRING)) - .Input(tensorflow::FakeInput(DT_INT32)) - .Input(tensorflow::FakeInput(DT_INT32)) - .Attr("precision", precision) - .Finalize(node_def())); - TF_RETURN_IF_ERROR(InitOp()); - - inputs_.clear(); - std::vector copies(input.size()); - for (int i = 0; i < input.size(); ++i) { - copies[i] = input[i]; - inputs_.emplace_back(&copies[i]); - } - - TF_RETURN_IF_ERROR(RunOpKernel()); - - *output = *GetOutput(0); - inputs_.clear(); - - return Status::OK(); - } - - void TestEncodeAndDecode(int precision, const Tensor& data, - const Tensor& cdf) { - Tensor encoded; - TF_ASSERT_OK(RunEncodeOp(precision, {data, cdf}, &encoded)); - - const TensorShape& data_shape = data.shape(); - Tensor shape{DT_INT32, {data_shape.dims()}}; - for (int i = 0; i < data_shape.dims(); ++i) { - shape.flat()(i) = data_shape.dim_size(i); - } - - Tensor decoded; - TF_ASSERT_OK(RunDecodeOp(precision, {encoded, shape, cdf}, &decoded)); - - EXPECT_EQ(decoded.dtype(), data.dtype()); - EXPECT_EQ(decoded.shape(), data.shape()); - EXPECT_EQ(decoded.tensor_data(), data.tensor_data()); - } - - void PopulateMaxValues(random::SimplePhilox* gen, Tensor* maxvalue_tensor, - int min_maxvalue, int max_maxvalue) { - const int range = max_maxvalue - min_maxvalue; - TTypes::Flat flat = maxvalue_tensor->flat(); - - for (int64 i = 0; i < flat.size(); ++i) { - flat(i) = min_maxvalue + gen->Uniform(range); - } - } - - void BuildCdf(random::SimplePhilox* gen, Tensor* data_tensor, - Tensor* cdf_tensor, const Tensor& maxvalue_tensor) { - CHECK(TensorShapeUtils::StartsWith(cdf_tensor->shape(), - maxvalue_tensor.shape())); - CHECK_EQ(cdf_tensor->dims(), maxvalue_tensor.dims() + 1); - const int64 chip_size = cdf_tensor->dim_size(cdf_tensor->dims() - 1); - - std::vector data_stride = ComputeStrides(data_tensor->shape()); - std::vector cdf_stride = ComputeStrides(cdf_tensor->shape()); - - for (int i = 0; i < cdf_tensor->dims(); ++i) { - if (cdf_tensor->dim_size(i) == 1) { - cdf_stride[i] = 0; - } - } - - Tensor histogram_tensor{DT_INT32, cdf_tensor->shape()}; - TTypes::Flat data = data_tensor->flat(); - TTypes::Flat histogram = histogram_tensor.flat(); - TTypes::ConstFlat maxvalue = maxvalue_tensor.flat(); - histogram.setZero(); - - for (int64 index = 0; index < data.size(); ++index) { - int64 temp = index; - int64 offset = 0; - for (int dim = 0; dim < 
data_stride.size(); ++dim) { - const int64 coord = temp / data_stride[dim]; - offset += coord * cdf_stride[dim]; - temp -= coord * data_stride[dim]; - } - ASSERT_EQ(temp, 0); - - const int64 maxvalue_offset = offset / chip_size; - CHECK_EQ(maxvalue_offset * chip_size, offset); - CHECK_LT(maxvalue(maxvalue_offset) + 1, chip_size); - const int value = LogUniform(gen, maxvalue(maxvalue_offset)); - data(index) = value; - histogram(offset + value + 1) += 1; - } - - cdf_tensor->flat_inner_dims() = - histogram_tensor.flat_inner_dims().cumsum(1); - } -}; - -TEST_F(RangeCoderOpsTest, NoBroadcast) { - constexpr int kPrecision = 14; - constexpr int kMaxValue = 10; - - Tensor data{DT_INT16, {1, 32, 32, 16}}; - Tensor temp{DT_INT32, {1, 1, 1, 1, kMaxValue + 2}}; - Tensor maxvalue{DT_INT16, {1, 1, 1, 1}}; - maxvalue.flat()(0) = kMaxValue; - - ASSERT_LE(data.shape().num_elements(), 1 << kPrecision); - - random::PhiloxRandom philox(random::New64(), random::New64()); - random::SimplePhilox gen(&philox); - BuildCdf(&gen, &data, &temp, maxvalue); - - const Eigen::array broadcast = {1, 32, 32, 16, 1}; - - Tensor cdf{DT_INT32, {1, 32, 32, 16, kMaxValue + 2}}; - cdf.tensor() = temp.tensor().broadcast(broadcast); - - TestEncodeAndDecode(kPrecision, data, cdf); -} - -TEST_F(RangeCoderOpsTest, Broadcast1Axis) { - constexpr int kPrecision = 9; - constexpr int kDimensionSize = 1 << kPrecision; - constexpr int kMinMaxValue = 10; - constexpr int kMaxMaxValue = 64; - - random::PhiloxRandom philox(random::New64(), random::New64()); - random::SimplePhilox gen(&philox); - Tensor data{DT_INT16, {1, kDimensionSize, kDimensionSize}}; - - Tensor maxvalue{DT_INT16, {kDimensionSize}}; - PopulateMaxValues(&gen, &maxvalue, kMinMaxValue, kMaxMaxValue); - - { - // Axis 1. - Tensor maxvalue1; - ASSERT_TRUE(maxvalue1.CopyFrom(maxvalue, {1, 1, kDimensionSize})); - - Tensor cdf{DT_INT32, {1, 1, kDimensionSize, kMaxMaxValue + 2}}; - BuildCdf(&gen, &data, &cdf, maxvalue1); - TestEncodeAndDecode(kPrecision, data, cdf); - } - - { - // Axis 2. 
- Tensor maxvalue2; - ASSERT_TRUE(maxvalue2.CopyFrom(maxvalue, {1, kDimensionSize, 1})); - - Tensor cdf{DT_INT32, {1, kDimensionSize, 1, kMaxMaxValue + 2}}; - BuildCdf(&gen, &data, &cdf, maxvalue2); - TestEncodeAndDecode(kPrecision, data, cdf); - } -} - -TEST_F(RangeCoderOpsTest, Broadcast2Axes) { - constexpr int kPrecision = 13; - constexpr int kDimensionSize1 = 1 << (kPrecision / 2); - constexpr int kDimensionSize2 = 1 << (kPrecision - kPrecision / 2); - constexpr int kMinMaxValue = 10; - constexpr int kMaxMaxValue = 64; - - random::PhiloxRandom philox(random::New64(), random::New64()); - random::SimplePhilox gen(&philox); - Tensor maxvalue{DT_INT16, {2, 1, 1, 7}}; - PopulateMaxValues(&gen, &maxvalue, kMinMaxValue, kMaxMaxValue); - - Tensor data{DT_INT16, {2, kDimensionSize1, kDimensionSize2, 7}}; - Tensor cdf{DT_INT32, {2, 1, 1, 7, kMaxMaxValue + 2}}; - BuildCdf(&gen, &data, &cdf, maxvalue); - TestEncodeAndDecode(kPrecision, data, cdf); -} - -TEST_F(RangeCoderOpsTest, InvalidCdfShape) { - Tensor data{DT_INT16, {3, 3}}; - Tensor cdf{DT_INT32, {3, 3}}; - - Tensor unused; - { - const Status status = RunEncodeOp(10, {data, cdf}, &unused); - EXPECT_FALSE(status.ok()); - EXPECT_NE(status.error_message().find("`cdf` should have one more axis"), - string::npos); - } - - Tensor empty{DT_STRING, {}}; - Tensor shape{DT_INT32, {2}}; - shape.vec().setValues({3, 3}); - { - const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); - EXPECT_FALSE(status.ok()); - EXPECT_NE(status.error_message().find("`cdf` should have one more axis"), - string::npos); - } - - cdf = Tensor{DT_INT32, {3, 3, 1}}; - { - const Status status = RunEncodeOp(10, {data, cdf}, &unused); - EXPECT_FALSE(status.ok()); - EXPECT_NE( - status.error_message().find("last dimension of `cdf` should be > 1"), - string::npos); - } - { - const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); - EXPECT_FALSE(status.ok()); - EXPECT_NE( - status.error_message().find("last dimension of `cdf` should be > 1"), - string::npos); - } -} - -TEST_F(RangeCoderOpsTest, DecoderShapeFn) { - Tensor encoded_tensor{DT_STRING, {}}; - Tensor shape_tensor{DT_INT32, {3}}; - Tensor cdf_tensor{DT_INT32, {4, 6, 8, 2}}; - - shape_tensor.flat().setValues({4, 6, 8}); - - Graph g{OpRegistry::Global()}; - Node* encoded = test::graph::Constant(&g, encoded_tensor); - Node* shape = test::graph::Constant(&g, shape_tensor); - Node* cdf = test::graph::Constant(&g, cdf_tensor); - Node* decode; - TF_ASSERT_OK(NodeBuilder("range_decode", "RangeDecode", g.op_registry()) - .Input(encoded) - .Input(shape) - .Input(cdf) - .Attr("precision", 10) - .Finalize(&g, &decode)); - - ShapeRefiner refiner{g.versions().producer(), g.op_registry()}; - TF_ASSERT_OK(refiner.AddNode(encoded)); - TF_ASSERT_OK(refiner.AddNode(shape)); - TF_ASSERT_OK(refiner.AddNode(cdf)); - TF_ASSERT_OK(refiner.AddNode(decode)); - - auto* context = refiner.GetContext(decode); - ASSERT_NE(context, nullptr); - - ASSERT_EQ(context->num_outputs(), 1); - auto shape_handle = context->output(0); - - ASSERT_EQ(context->Rank(shape_handle), 3); - EXPECT_EQ(context->Value(context->Dim(shape_handle, 0)), 4); - EXPECT_EQ(context->Value(context->Dim(shape_handle, 1)), 6); - EXPECT_EQ(context->Value(context->Dim(shape_handle, 2)), 8); -} - -TEST_F(RangeCoderOpsTest, InvalidBroadcast) { - Tensor data{DT_INT16, {3, 3}}; - Tensor cdf{DT_INT32, {3, 2, 2}}; - - Tensor unused; - { - const Status status = RunEncodeOp(10, {data, cdf}, &unused); - EXPECT_FALSE(status.ok()); - 
EXPECT_NE(status.error_message().find("Cannot broadcast shape"), - string::npos); - } - - data = Tensor{DT_INT16, {3, 1}}; - cdf = Tensor{DT_INT32, {3, 3, 2}}; - Tensor empty{DT_STRING, {}}; - Tensor shape{DT_INT32, {2}}; - shape.vec().setValues({3, 1}); - { - const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); - EXPECT_FALSE(status.ok()); - EXPECT_NE(status.error_message().find("Cannot broadcast shape"), - string::npos); - } - - std::vector shape_vector = {2, 2, 2, 2, 2, 2, 2, 2, 2}; - data = Tensor{DT_INT16, TensorShape{shape_vector}}; - cdf = Tensor{DT_INT32, {2, 1, 2, 1, 2, 1, 2, 1, 2, 2}}; - { - const Status status = RunEncodeOp(10, {data, cdf}, &unused); - EXPECT_FALSE(status.ok()); - EXPECT_NE(status.error_message().find("Irregular broadcast"), string::npos); - } - - shape = Tensor{DT_INT32, {static_cast(shape_vector.size())}}; - for (int i = 0; i < shape_vector.size(); ++i) { - shape.flat()(i) = shape_vector[i]; - } - { - const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); - EXPECT_FALSE(status.ok()); - EXPECT_NE(status.error_message().find("Irregular broadcast"), string::npos); - } -} - -// Benchmark ------------------------------------------------------------- - -// This function creates RangeEncode graph with CDF built from a separate data -// sample. -Graph* CreateRangeEncodeFullBroadcastGraph(const TensorShape& shape, - int precision) { - CHECK_EQ(shape.dims(), 4); - - constexpr int kAlphabetSize = 70; - - Tensor histogram{DT_INT32, {kAlphabetSize + 1}}; - TTypes::Vec h = histogram.vec(); - h.setConstant(1); - h(0) = 0; - - random::PhiloxRandom philox(random::New64(), random::New64()); - random::SimplePhilox gen(&philox); - for (int i = 0; i < (1 << precision) - kAlphabetSize; ++i) { - const int value = LogUniform(&gen, kAlphabetSize - 1); - h(value + 1) += 1; - } - - Tensor cdf{DT_INT32, {1, 1, 1, 1, kAlphabetSize + 1}}; - cdf.flat() = h.cumsum(0); - - Tensor data{DT_INT16, shape}; - TTypes::Flat d = data.flat(); - for (int64 i = 0; i < d.size(); ++i) { - d(i) = LogUniform(&gen, kAlphabetSize - 1); - } - - Graph* g = new Graph(OpRegistry::Global()); - TF_CHECK_OK(NodeBuilder("range_encode", "RangeEncode", g->op_registry()) - .Input(test::graph::Constant(g, data)) - .Input(test::graph::Constant(g, cdf)) - .Attr("precision", precision) - .Finalize(g, nullptr)); - return g; -} - -// This function creates RangeDecode graph with CDF built from a separate data -// sample. 
-Graph* CreateRangeDecodeFullBroadcastGraph(const TensorShape& shape, - int precision) { - CHECK_EQ(shape.dims(), 4); - - constexpr int kAlphabetSize = 200; - const int64 num_elements = shape.num_elements(); - - Tensor histogram{DT_INT32, {kAlphabetSize + 1}}; - TTypes::Vec h = histogram.vec(); - h.setConstant(1); - h(0) = 0; - - random::PhiloxRandom philox(random::New64(), random::New64()); - random::SimplePhilox gen(&philox); - for (int i = 0; i < (1 << precision) - kAlphabetSize; ++i) { - const int value = LogUniform(&gen, kAlphabetSize - 1); - h(value + 1) += 1; - } - - Tensor cdf_tensor{DT_INT32, {1, 1, 1, 1, kAlphabetSize + 1}}; - TTypes::Flat cdf = cdf_tensor.flat(); - cdf = h.cumsum(0); - - Tensor string_tensor{DT_STRING, TensorShape{}}; - string& sink = string_tensor.scalar()(); - - RangeEncoder encoder{precision}; - for (int64 i = 0; i < num_elements; ++i) { - const int value = LogUniform(&gen, kAlphabetSize - 1); - encoder.Encode(cdf(value), cdf(value + 1), &sink); - } - encoder.Finalize(&sink); - - Tensor shape_tensor{DT_INT32, {shape.dims()}}; - for (int i = 0; i < shape.dims(); ++i) { - shape_tensor.flat()(i) = shape.dim_size(i); - } - - Graph* g = new Graph(OpRegistry::Global()); - TF_CHECK_OK(NodeBuilder("range_decode", "RangeDecode", g->op_registry()) - .Input(test::graph::Constant(g, string_tensor)) - .Input(test::graph::Constant(g, shape_tensor)) - .Input(test::graph::Constant(g, cdf_tensor)) - .Attr("precision", precision) - .Finalize(g, nullptr)); - return g; -} - -void RunTensorFlowBenchmark(int iters, Graph* g, int64 num_elements) { - SessionOptions opts; - opts.config.set_intra_op_parallelism_threads(1); - opts.config.set_inter_op_parallelism_threads(1); - - testing::UseRealTime(); - test::Benchmark("cpu", g, &opts).Run(iters); - - const int64 num_items = static_cast(iters) * num_elements; - testing::ItemsProcessed(num_items); -} - -void BM_RangeEncodeFullBroadcast(int iters, int code_size) { - constexpr int kPrecision = 14; - const TensorShape shape = {1, code_size, code_size, 256}; - Graph* g = CreateRangeEncodeFullBroadcastGraph(shape, kPrecision); - RunTensorFlowBenchmark(iters, g, shape.num_elements()); -} - -BENCHMARK(BM_RangeEncodeFullBroadcast)->Arg(32)->Arg(64); - -void BM_RangeDecodeFullBroadcast(int iters, int code_size) { - constexpr int kPrecision = 14; - const TensorShape shape = {1, code_size, code_size, 256}; - Graph* g = CreateRangeDecodeFullBroadcastGraph(shape, kPrecision); - RunTensorFlowBenchmark(iters, g, shape.num_elements()); -} - -BENCHMARK(BM_RangeDecodeFullBroadcast)->Arg(32)->Arg(64); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc deleted file mode 100644 index d66730cb488..00000000000 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h" - -#include - -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -using tensorflow::errors::InvalidArgument; - -namespace tensorflow { -Status MergeAxes(const TensorShape& broadcast_shape, - const TensorShape& storage_shape, - std::vector* merged_broadcast_shape_pointer, - std::vector* merged_storage_shape_pointer) { - CHECK_EQ(storage_shape.dims(), broadcast_shape.dims() + 1); - - std::vector& merged_broadcast_shape = *merged_broadcast_shape_pointer; - std::vector& merged_storage_shape = *merged_storage_shape_pointer; - - // The shapes are simplified so that the conversions between linear index - // and coordinates takes less CPU cycles. Two adjacent dimensions are - // merged if they both are broadcasting dimensions or if they both are - // non-broadcasting dimensions. - merged_broadcast_shape.resize(1); - merged_broadcast_shape[0] = 1; - merged_storage_shape.resize(1); - merged_storage_shape[0] = 1; - - for (int i = 0, j = 0; j < broadcast_shape.dims(); ++j) { - if (TF_PREDICT_FALSE( - (broadcast_shape.dim_size(j) != storage_shape.dim_size(j)) && - (storage_shape.dim_size(j) != 1))) { - return InvalidArgument("Cannot broadcast shape ", - storage_shape.DebugString(), " to ", - broadcast_shape.DebugString()); - } - - const bool was_broadcasting = (merged_storage_shape[i] == 1); - const bool is_broadcasting = (storage_shape.dim_size(j) == 1); - - // Merge two adjacent axes if they both are broadcasting or both are - // non-broadcasting axes. The second and the third conditions in the if - // clause below are when the previously merged axis or the next j-th axis - // may be interpreted as either a broadcasting or a non-broadcasting axis. - const bool merge = (was_broadcasting == is_broadcasting) || - (broadcast_shape.dim_size(j) <= 1) || - (merged_broadcast_shape[i] <= 1); - - if (merge) { - merged_broadcast_shape[i] *= broadcast_shape.dim_size(j); - merged_storage_shape[i] *= storage_shape.dim_size(j); - } else { - // Move to the next axis. - merged_broadcast_shape.push_back(broadcast_shape.dim_size(j)); - merged_storage_shape.push_back(storage_shape.dim_size(j)); - ++i; - } - } - - int64 storage_stride = 1; - for (int i = broadcast_shape.dims(); i < storage_shape.dims(); ++i) { - storage_stride *= storage_shape.dim_size(i); - } - merged_storage_shape.push_back(storage_stride); - - return Status::OK(); -} -} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_util.h b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h deleted file mode 100644 index b8aabcef62e..00000000000 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_util.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
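The merging rule implemented above can be summarized in a few lines of Python: walk the two shapes together, collapse neighbouring axes that are both broadcasting or both non-broadcasting (or degenerate), and carry the trailing CDF axis along as a stride. The sketch below is illustrative only; the function name and error text are not the library's.

```python
def merge_axes(broadcast_shape, storage_shape):
    """Collapses compatible adjacent axes; `storage_shape` has one extra
    trailing axis (the CDF axis), as in the C++ MergeAxes above."""
    assert len(storage_shape) == len(broadcast_shape) + 1
    merged_broadcast, merged_storage = [1], [1]
    for b, s in zip(broadcast_shape, storage_shape):
        if s != 1 and s != b:
            raise ValueError("Cannot broadcast shape %s to %s"
                             % (storage_shape, broadcast_shape))
        was_broadcasting = merged_storage[-1] == 1
        is_broadcasting = s == 1
        # Merge when both axes are the same kind, or when either side is
        # degenerate (size <= 1) and can be read either way.
        if was_broadcasting == is_broadcasting or b <= 1 or merged_broadcast[-1] <= 1:
            merged_broadcast[-1] *= b
            merged_storage[-1] *= s
        else:
            merged_broadcast.append(b)
            merged_storage.append(s)
    merged_storage.append(storage_shape[-1])  # stride of the CDF axis
    return merged_broadcast, merged_storage


# (2, 3, 4) against a (2, 1, 1, 10) CDF collapses to (2, 12) vs (2, 1) + stride 10.
assert merge_axes((2, 3, 4), (2, 1, 1, 10)) == ([2, 12], [2, 1, 10])
```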
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ -#define TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ - -#include - -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -// The shapes are simplified to reduce indexing cost. -Status MergeAxes(const TensorShape& broadcast_shape, - const TensorShape& storage_shape, - std::vector* merged_broadcast_shape_pointer, - std::vector* merged_storage_shape_pointer); -} // namespace tensorflow - -#endif // TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ diff --git a/tensorflow/contrib/coder/kernels/range_coder_test.cc b/tensorflow/contrib/coder/kernels/range_coder_test.cc deleted file mode 100644 index 442994bf7c7..00000000000 --- a/tensorflow/contrib/coder/kernels/range_coder_test.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/coder/kernels/range_coder.h" - -#include - -#include "tensorflow/core/lib/random/distribution_sampler.h" -#include "tensorflow/core/lib/random/philox_random.h" -#include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/random/simple_philox.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { -void RangeEncodeDecodeTest(int precision, random::SimplePhilox* gen) { - constexpr int kAlphabetSize = 256; - - std::vector distribution_weight; - distribution_weight.reserve(kAlphabetSize); - for (int i = 1; i <= kAlphabetSize; ++i) { - distribution_weight.push_back(std::pow(static_cast(i), -2.0f)); - } - - random::DistributionSampler sampler(distribution_weight); - - const int multiplier = (precision > 7) ? 
32 : 1; - std::vector histogram(kAlphabetSize, multiplier - 1); - - const int data_size = - (multiplier << precision) - histogram.size() * (multiplier - 1); - CHECK_GE(data_size, 0); - std::vector data(data_size); - for (uint8& x : data) { - x = sampler.Sample(gen); - ++histogram[x]; - } - - std::vector cdf(histogram.size() + 1, 0); - int partial_sum = 0; - for (int i = 0; i < histogram.size(); ++i) { - partial_sum += histogram[i]; - cdf[i + 1] = partial_sum / multiplier; - } - - ASSERT_EQ(cdf.front(), 0); - ASSERT_EQ(cdf.back(), 1 << precision); - - std::vector ideal_code_length(histogram.size()); - const double normalizer = static_cast(1 << precision); - for (int i = 0; i < ideal_code_length.size(); ++i) { - ideal_code_length[i] = -std::log2((cdf[i + 1] - cdf[i]) / normalizer); - } - - RangeEncoder encoder(precision); - string encoded; - double ideal_length = 0.0; - for (uint8 x : data) { - encoder.Encode(cdf[x], cdf[x + 1], &encoded); - ideal_length += ideal_code_length[x]; - } - encoder.Finalize(&encoded); - - LOG(INFO) << "Encoded string length (bits): " << 8 * encoded.size() - << ", whereas ideal " << ideal_length << " (" - << (8 * encoded.size()) / ideal_length << " of ideal) " - << " (ideal compression rate " << ideal_length / (8 * data.size()) - << ")"; - - RangeDecoder decoder(encoded, precision); - for (int i = 0; i < data.size(); ++i) { - const int32 decoded = decoder.Decode(cdf); - ASSERT_EQ(decoded, static_cast(data[i])) << i; - } -} - -TEST(RangeCoderTest, Precision1To11) { - random::PhiloxRandom gen(random::New64(), random::New64()); - random::SimplePhilox rand(&gen); - const int precision = 1 + rand.Uniform(11); - RangeEncodeDecodeTest(precision, &rand); -} - -TEST(RangeCoderTest, Precision12To16) { - random::PhiloxRandom gen(random::New64(), random::New64()); - random::SimplePhilox rand(&gen); - for (int precision = 12; precision < 17; ++precision) { - RangeEncodeDecodeTest(precision, &rand); - } -} - -TEST(RangeCoderTest, FinalizeState0) { - constexpr int kPrecision = 2; - - string output; - RangeEncoder encoder(kPrecision); - encoder.Encode(0, 2, &output); - encoder.Finalize(&output); - - RangeDecoder decoder(output, kPrecision); - EXPECT_EQ(decoder.Decode({0, 2, 4}), 0); -} -} // namespace -} // namespace tensorflow diff --git a/tensorflow/contrib/coder/ops/coder_ops.cc b/tensorflow/contrib/coder/ops/coder_ops.cc deleted file mode 100644 index a185e07913f..00000000000 --- a/tensorflow/contrib/coder/ops/coder_ops.cc +++ /dev/null @@ -1,149 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
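The deleted test builds its CDF straight from integer sample counts: every symbol starts with `multiplier - 1` pseudo-counts so no bin is empty, the sample size is chosen so the grand total is `multiplier << precision`, and dividing the running sums by `multiplier` makes the last CDF entry land exactly on `2^precision`. A compact sketch of that construction, with illustrative names and the sizing left to the caller:

```python
def counts_to_cdf(counts, precision, multiplier=1):
    """Turns per-symbol counts into a quantized CDF whose last entry is
    exactly 1 << precision, assuming the caller arranged the totals the way
    the deleted range_coder_test does."""
    smoothed = [c + multiplier - 1 for c in counts]
    assert sum(smoothed) == multiplier << precision
    cdf, running = [0], 0
    for count in smoothed:
        running += count
        cdf.append(running // multiplier)
    return cdf


# precision=2, multiplier=1, counts [1, 2, 1] -> [0, 1, 3, 4]; 4 == 1 << 2.
assert counts_to_cdf([1, 2, 1], precision=2) == [0, 1, 3, 4]
```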
-==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { -using shape_inference::DimensionHandle; -using shape_inference::InferenceContext; -using shape_inference::ShapeHandle; - -// clang-format off -REGISTER_OP("RangeEncode") - .Input("data: int16") - .Input("cdf: int32") - .Output("encoded: string") - .Attr("precision: int >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Using the provided cumulative distribution functions (CDF) inside `cdf`, returns -a range-code of `data`. - -The shape of `cdf` should have one more axis than the shape of `data`, and the -prefix `cdf.shape[:-1]` should be broadcastable to `data.shape`. That is, for -every `i = 0,...,rank(data) - 1`, the op requires that either -`cdf.shape[i] == 1` or `cdf.shape[i] == data.shape[i]`. Note that this -broadcasting is limited in the sense that the number of axes must match, and -broadcasts only `cdf` but not `data`. - -`data` should have an upper bound `m > 0` such that each element is an integer -in range `[0, m)`. Then the last dimension size of `cdf` must be `m + 1`. For -each element of `data`, the innermost strip of `cdf` is a vector representing a -CDF. For each k = 0,...,m, `cdf[..., k] / 2^precision` is the probability that -an outcome is less than `k` (not less than or equal to). - -``` - cdf[..., 0] / 2^precision = Pr(data[...] < 0) - cdf[..., 1] / 2^precision = Pr(data[...] < 1) = Pr(data[...] <= 0) - cdf[..., 2] / 2^precision = Pr(data[...] < 2) = Pr(data[...] <= 1) - ... - cdf[..., m] / 2^precision = Pr(data[...] < m) = 1 -``` - -Therefore each element of `cdf` must be in `[0, 2^precision]`. - -Ideally `cdf[..., m]` should equal to `2^precision` but this is not a hard -requirement as long as `cdf[..., m] <= 2^precision`. - -The encoded string neither contains the shape information of the encoded data -nor a termination symbol. Therefore the shape of the encoded data must be -explicitly provided to the decoder. - -Implementation notes: - -- Because of potential performance issues, the op does not check whether -elements of `data` is in the correct range `[0, m)`, or if `cdf` satisfies -monotonic increase property. - -- For the range coder to decode the encoded string correctly, the decoder should -be able to reproduce the internal states of the encoder precisely. Otherwise, -the decoding would fail and once an error occur, all subsequent decoded values -are incorrect. For this reason, the range coder uses integer arithmetics and -avoids using any floating point operations internally, and `cdf` should contain -integers representing quantized probability mass rather than floating points. - -data: An int16 tensor. -cdf: An int32 tensor representing the CDF's of `data`. Each integer is divided - by `2^precision` to represent a fraction. -encoded: A range-coded scalar string. -precision: The number of bits for probability quantization. Must be <= 16. 
-)doc"); - - -REGISTER_OP("RangeDecode") - .Input("encoded: string") - .Input("shape: int32") - .Input("cdf: int32") - .Output("decoded: int16") - .Attr("precision: int >= 1") - .SetShapeFn([] (InferenceContext* c) { - ShapeHandle out; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); - c->set_output(0, out); - return Status::OK(); - }) - .Doc(R"doc( -Decodes a range-coded `code` into an int32 tensor of shape `shape`. - -This is the reverse op of RangeEncode. The shape of the tensor that was encoded -should be known by the caller. - -Implementation notes: - -- If wrong input was given (e.g., corrupt `encoded` string, or `cdf` or -`precision` do not match encoder), the decode is unsuccessful. Because of -potential performance issues, the decoder does not return error status. - -encoded: A scalar string tensor from RangeEncode. -shape: An int32 1-D tensor representing the shape of the data encoded by - RangeEncode. -decoded: An int16 tensor with shape equal to `shape`. -precision: The number of bits for probability quantization. Must be <= 16, and - must match the precision used by RangeEncode that produced `encoded`. -)doc"); - -REGISTER_OP("PmfToQuantizedCdf") - .Input("pmf: float") - .Output("cdf: int32") - .Attr("precision: int >= 1") - .SetShapeFn([] (InferenceContext* c) { - ShapeHandle in; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in)); - DimensionHandle last; - TF_RETURN_IF_ERROR(c->Add(c->Dim(in, -1), 1, &last)); - ShapeHandle out; - TF_RETURN_IF_ERROR(c->ReplaceDim(in, -1, last, &out)); - c->set_output(0, out); - return Status::OK(); - }) - .Doc(R"doc( -Converts PMF to quantized CDF. This op uses floating-point operations -internally. Therefore the quantized output may not be consistent across multiple -platforms. For entropy encoders and decoders to have the same quantized CDF on -different platforms, the quantized CDF should be produced once and saved, then -the saved quantized CDF should be used everywhere. - -After quantization, if PMF does not sum to 2^precision, then some values of PMF -are increased or decreased to adjust the sum to equal to 2^precision. - -Note that the input PMF is pre-quantization. The input PMF is not normalized -by this op prior to quantization. Therefore the user is responsible for -normalizing PMF if necessary. -)doc"); -// clang-format on -} // namespace tensorflow diff --git a/tensorflow/contrib/coder/python/ops/coder_ops_test.py b/tensorflow/contrib/coder/python/ops/coder_ops_test.py deleted file mode 100644 index f5431ca1ffd..00000000000 --- a/tensorflow/contrib/coder/python/ops/coder_ops_test.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Coder operations tests.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.coder.python.ops import coder_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.platform import test - - -class CoderOpsTest(test.TestCase): - """Coder ops test. - - Coder ops have C++ tests. Python test just ensures that Python binding is not - broken. - """ - - def testReadmeExample(self): - data = random_ops.random_uniform((128, 128), 0, 10, dtype=dtypes.int32) - histogram = math_ops.bincount(data, minlength=10, maxlength=10) - cdf = math_ops.cumsum(histogram, exclusive=False) - cdf = array_ops.pad(cdf, [[1, 0]]) - cdf = array_ops.reshape(cdf, [1, 1, -1]) - - data = math_ops.cast(data, dtypes.int16) - encoded = coder_ops.range_encode(data, cdf, precision=14) - decoded = coder_ops.range_decode( - encoded, array_ops.shape(data), cdf, precision=14) - - with self.cached_session() as sess: - self.assertAllEqual(*sess.run((data, decoded))) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 79c61589112..773560fcd0b 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -24,49 +24,20 @@ py_library( srcs_version = "PY2AND3", deps = [ ":xla", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/compiler/xla:compiler_py", ], ) -cuda_py_test( - name = "jit_test", - size = "small", - srcs = ["jit_test.py"], - additional_deps = [ - ":compiler_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], - xla_enabled = True, -) - py_library( name = "xla", srcs = ["xla.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/compiler/jit:xla_ops_py", - "//tensorflow/compiler/jit/ops:xla_ops_grad", "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", - "//tensorflow/python:summary_op_util", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", + "//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/python/estimator:estimator_py", ], ) @@ -79,17 +50,12 @@ cuda_py_test( "@absl_py//absl/testing:parameterized", "//tensorflow/compiler/tests:xla_test", "//tensorflow/contrib/tpu:tpu_estimator", - "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:control_flow_util", - "//tensorflow/python:math_ops", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", - "//tensorflow/python:variable_scope", "//tensorflow/python/data/ops:dataset_ops", ], tags = [ diff 
--git a/tensorflow/contrib/compiler/jit.py b/tensorflow/contrib/compiler/jit.py index c516ab658d7..70898aeb974 100644 --- a/tensorflow/contrib/compiler/jit.py +++ b/tensorflow/contrib/compiler/jit.py @@ -18,101 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib +from tensorflow.python.compiler.xla import jit -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.framework import ops - - -_XLA_SCOPE_KEY = ("__xla_scope",) - - -class _XlaScope(object): - """Keeps track of previous XLA scope calls, and depth of current call.""" - - def __init__(self, count, depth): - self.count = count - self.depth = depth - - -@contextlib.contextmanager -def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False): - """Enable or disable JIT compilation of operators within the scope. - - NOTE: This is an experimental feature. - - The compilation is a hint and only supported on a best-effort basis. - - Example usage: - with tf.contrib.compiler.experimental_jit_scope(): - c = tf.matmul(a, b) # compiled - with tf.contrib.compiler.experimental_jit_scope(compile_ops=False): - d = tf.matmul(a, c) # not compiled - with tf.contrib.compiler.experimental_jit_scope( - compile_ops=lambda node_def: 'matmul' in node_def.op.lower()): - e = tf.matmul(a, b) + d # matmul is compiled, the addition is not. - - Example of separate_compiled_gradients: - # In the example below, the computations for f, g and h will all be compiled - # in separate scopes. - with tf.contrib.compiler.experimental_jit_scope( - separate_compiled_gradients=True): - f = tf.matmul(a, b) - g = tf.gradients([f], [a, b], name='mygrads1') - h = tf.gradients([f], [a, b], name='mygrads2') - - Args: - compile_ops: Whether to enable or disable compilation in the scope. - Either a Python bool, or a callable that accepts the parameter - `node_def` and returns a python bool. - separate_compiled_gradients: If true put each gradient subgraph into a - separate compilation scope. This gives fine-grained control over which - portions of the graph will be compiled as a single unit. Compiling - gradients separately may yield better performance for some graphs. - The scope is named based on the scope of the forward computation as well - as the name of the gradients. As a result, the gradients will be compiled - in a scope that is separate from both the forward computation, and from - other gradients. - Yields: - The current scope, enabling or disabling compilation. - - """ - if callable(compile_ops): - def xla_compile(node_def): - return attr_value_pb2.AttrValue(b=compile_ops(node_def)) - else: - xla_compile = attr_value_pb2.AttrValue(b=compile_ops) - - attrs = { - "_XlaCompile": - xla_compile, - "_XlaSeparateCompiledGradients": - attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients)) - } - - # Find the singleton counter for the current scoped graph. If it - # doesn't exist, create one. - xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY) - if not xla_scope_counter: - xla_scope_counter = _XlaScope(0, 0) - ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter) - else: - xla_scope_counter = xla_scope_counter[0] - - if xla_scope_counter.depth == 0: - # If we're at the root xla scope, we can increase the counter so - # future calls to jit_scope use a different scope value. - # If we're already within a scope, we'll be fusing using the scope - # controlled by the parent. 
- attrs["_XlaScope"] = attr_value_pb2.AttrValue( - s=("jit_scope_%d" % xla_scope_counter.count).encode()) - xla_scope_counter.count += 1 - - xla_scope_counter.depth += 1 - - # pylint: disable=protected-access - with ops.get_default_graph()._attr_scope(attrs): - yield - # pylint: enable=protected-access - - xla_scope_counter.depth -= 1 +experimental_jit_scope = jit.experimental_jit_scope diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 238c6ab1366..eec0b0ccb09 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -18,511 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import contextlib -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.compiler.jit.ops import xla_ops -from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import -from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.compiler.xla import xla from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import summary_op_util -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import compat from tensorflow.python.util import function_utils -from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator -from tensorflow.python.util import tf_inspect -_XLA_COMPILE_ATTR = '_xla_compile_id' -_MAX_WARNING_LINES = 5 - -# Operations that indicate some error in the users graph. For example, XLA -# computation should not have any Placeholder op. -_BLACKLISTED_OPS = set([ - 'Placeholder', -]) - -# XLA doesn't currently support reading of intermediate tensors, thus some ops -# are not supported. -_UNSUPPORTED_OPS = set([ - 'AudioSummary', - 'AudioSummaryV2', - 'HistogramSummary', - 'ImageSummary', - 'MergeSummary', - 'Print', - 'ScalarSummary', - 'TensorSummary', - 'TensorSummaryV2', -]) - - -def compile(computation, inputs=None): # pylint: disable=redefined-builtin - """Builds an operator that compiles and runs `computation` with XLA. - - Args: - computation: A Python function that builds a computation to apply to the - input. If the function takes n inputs, 'inputs' should be a list of n - tensors. - - `computation` may return a list of operations and tensors. Tensors must - come before operations in the returned list. The return value of - `compile` is a list of tensors corresponding to the tensors from the - output of `computation`. - - All `Operation`s returned from `computation` will be executed when - evaluating any of the returned output tensors. - inputs: A list of inputs or `None` (equivalent to an empty list). Each input - can be a nested structure containing values that are convertible to - tensors. Note that passing an N-dimension list of compatible values will - result in a N-dimention list of scalar tensors rather than a single Rank-N - tensors. If you need different behavior, convert part of inputs to tensors - with `tf.convert_to_tensor`. - - Returns: - Same data structure as if computation(*inputs) is called directly with some - exceptions for correctness. Exceptions include: - 1) None output: a NoOp would be returned which control-depends on - computation. 
- 2) Single value output: A tuple containing the value would be returned. - 3) Operation-only outputs: a NoOp would be returned which - control-depends on computation. - TODO(b/121383831): Investigate into removing these special cases. - """ - # pylint: disable=protected-access - return _compile_internal(computation, inputs) - - -class XLACompileContext(control_flow_ops.XLAControlFlowContext): - """A `ControlFlowContext` for nodes inside an XLA computation cluster. - - THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY. - - The primary role of `XLACompileContext` is to mark operators inside a - xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is - a unique name. - - `ControlFlowContext` is used to perform the annotation since it integrates - with Tensorflow constructs like ResourceVariables. For example, if a - `ResourceVariable` is constructed inside a xla.compile() block, the - `ResourceVariable` implementation can use - `with ops.control_dependencies(None)` to build the variable's definition - outside the compiled computation. - """ - - def __init__(self, name, pivot): - """Builds a new XLACompileContext. - - Args: - name: a unique name for the context, used to populate the - `_xla_compile_id` attribute. - pivot: a pivot node. Nodes in the XLACompileContext that do not have any - inputs will have a control dependency on the pivot node. This ensures - that nodes are correctly included in any enclosing control flow - contexts. - """ - super(XLACompileContext, self).__init__() - self._name = name - self._name_as_bytes = compat.as_bytes(name) - self._unsupported_ops = [] - self._pivot = pivot - - def report_unsupported_operations(self): - if self._unsupported_ops: - op_str = '\n'.join([ - ' %s (%s)' % (op.type, op.name) - for op in self._unsupported_ops[:_MAX_WARNING_LINES] - ]) - logging.warning('%d unsupported operations found: \n%s', - len(self._unsupported_ops), op_str) - if len(self._unsupported_ops) > _MAX_WARNING_LINES: - logging.warning('... and %d more', - len(self._unsupported_ops) - _MAX_WARNING_LINES) - - def _RemoveExternalControlEdges(self, op): - """Remove any external control dependency on this op.""" - internal_control_inputs = [] - external_control_inputs = [] - for x in op.control_inputs: - # pylint: disable=protected-access - is_internal_op = False - ctxt = x._get_control_flow_context() - while ctxt is not None: - if ctxt == self: - is_internal_op = True - break - ctxt = ctxt._outer_context - if is_internal_op: - internal_control_inputs.append(x) - else: - external_control_inputs.append(x) - # pylint: enable=protected-access - # pylint: disable=protected-access - op._remove_all_control_inputs() - op._add_control_inputs(internal_control_inputs) - # pylint: enable=protected-access - return internal_control_inputs, external_control_inputs - - def AddOp(self, op): - """Create op in XLACompileContext and notifies outer context recursively.""" - # pylint: disable=protected-access - if op.type in _BLACKLISTED_OPS: - logging.error( - 'Operation of type %s (%s) is not supported in XLA. Execution will ' - 'fail if this op is used in the graph. ', op.type, op.name) - - # TODO(ycao): Automatically disable summaries instead of reporting them. 
- if op.type in _UNSUPPORTED_OPS: - self._unsupported_ops.append(op) - - if any(x.dtype._is_ref_dtype for x in op.inputs): - raise NotImplementedError( - 'Non-resource Variables are not supported inside XLA computations ' - '(operator name: %s)' % op.name) - - if _XLA_COMPILE_ATTR in op.node_def.attr: - raise ValueError('XLA compiled computations cannot be nested, (operator ' - 'name: %s)' % op.name) - - op._set_attr( - _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) - - op.graph.prevent_feeding(op) - op.graph.prevent_fetching(op) - - # Remove any control edges from outer control flow contexts. These may cause - # mismatched frame errors. An example is when one of op's inputs is - # generated in a different While control flow context. - (internal_control_inputs, - external_control_inputs) = self._RemoveExternalControlEdges(op) - - if not op.inputs: - # Add a control edge from the control pivot to this op. - if not internal_control_inputs: - # pylint: disable=protected-access - op._add_control_input(self._pivot) - # pylint: enable=protected-access - else: - for index in xrange(len(op.inputs)): - x = op.inputs[index] - real_x = self.AddValue(x) - if real_x != x: - op._update_input(index, real_x) # pylint: disable=protected-access - - if external_control_inputs: - # Use an identity to pull control inputs as data inputs. Note that we - # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - self.Exit() - # pylint: disable=protected-access - op._add_control_inputs(external_control_inputs) - # pylint: enable=protected-access - - # Mark op's outputs as seen by this context and any outer contexts. - output_names = [x.name for x in op.outputs] - context = self - while context is not None: - # pylint: disable=protected-access - context._values.update(output_names) - context = context._outer_context - # pylint: enable=protected-access - - if self._outer_context: - self._outer_context.AddInnerOp(op) - - def AddValue(self, val): - """Add `val` to the current context and its outer context recursively.""" - if val.name in self._values: - # Use the real value if it comes from outer context. - result = self._external_values.get(val.name) - return val if result is None else result - - result = val - self._values.add(val.name) - if self._outer_context: - result = self._outer_context.AddValue(val) - self._values.add(result.name) - - self._external_values[val.name] = result - - return result - - def AddInnerOp(self, op): - self.AddOp(op) - if self._outer_context: - self._outer_context.AddInnerOp(op) - - @property - def grad_state(self): - # Define the gradient loop state associated with the XLACompileContext to - # be None as the XLACompileContext does not get nested nor does the - # grad_state outside the XLACompileContext affect the graph inside so the - # grad_state should be as if this is the top-level gradient state. - return None - - @property - def back_prop(self): - """Forwards to the enclosing while context, if any.""" - if self.GetWhileContext(): - return self.GetWhileContext().back_prop - return False - - -def _compile_internal(computation, inputs=None): - """Builds graph operators that compiles and symbolically executes computation. - - Args: - computation: A Python function that builds the computation to compile and - execute. - inputs: A list of inputs or `None` (equivalent to an empty list). 
Each input - can be a nested structure containing values that are convertible to - tensors. Note that passing an N-dimension list of compatible values will - result in a N-dimension list of scalar tensors rather than a single Rank-N - tensors. If you need different behavior, convert part of inputs to tensors - with `tf.convert_to_tensor`. - - Returns: - Same data structure as if computation(*inputs) is called directly with some - exceptions for correctness. Exceptions include: 1) None output 2) Single - value output 3) Operation-only outputs - Raises: - ValueError: If any element in computation outputs is neither an operations - or a value that can be converted to tensor. - ValueError: If computation outputs is non-flat and contains any Operations. - TypeError: If `inputs` is not a list or tuple. - """ - if inputs is None: - inputs = [] - - if not isinstance(inputs, collections.Sequence): - raise TypeError('inputs must be a list') - - # Flatten inputs. - flat_inputs = nest.flatten(inputs) - # Converts inputs to Tensors. - flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] - - cluster_name = ops.get_default_graph().unique_name('cluster') - pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') - context = XLACompileContext(name=cluster_name, pivot=pivot) - try: - context.Enter() - - # Add identity ops so even unused inputs are 'consumed' by the - # computation. - flat_inputs = [ - array_ops.identity(x, name='input_{}'.format(i)) - for i, x in enumerate(flat_inputs) - ] - - # Re-pack flat_inputs in same structure as 'inputs'. - computation_inputs = nest.pack_sequence_as( - structure=inputs, flat_sequence=flat_inputs) - - # Only resource variables work inside an XLA computation, so turn on - # resource variables for the computation. - vscope = variable_scope.get_variable_scope() - saved_use_resource = vscope.use_resource - vscope.set_use_resource(True) - - with _disable_summary_context(): - outputs = computation(*computation_inputs) - - # Restore variable scope after computation. - vscope.set_use_resource(saved_use_resource) - - outputs_is_flat = is_flat(outputs) - if outputs_is_flat: - output_tensors, control_deps = _postprocess_flat_outputs(outputs) - else: - output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - - context.ExitResult(output_tensors) - finally: - context.report_unsupported_operations() - context.Exit() - - # When XLA computation returns only operations and no tensors, a NoOp - # dependent on the operations in outputs is returned. Otherwise final - # outputs would be empty and there is no way to trigger returned - # operations. - if not output_tensors: - return control_flow_ops.group(control_deps, name='output_0') - - output_tensors = [ - xla_ops.xla_cluster_output(o, name='output{}'.format(i)) - for i, o in enumerate(output_tensors) - ] - - with ops.control_dependencies(control_deps): - # Wraps the outputs in identity operators that carries control - # dependencies. - output_tensors = [ - array_ops.identity(o, name='output_%d' % i) - for i, o in enumerate(output_tensors) - ] - - # If `computation` returned non-flat output structure, pack output tensors - # back into same structure. - if not outputs_is_flat: - output_tensors = nest.pack_sequence_as( - structure=outputs, flat_sequence=output_tensors) - - return output_tensors - - -def is_flat(outputs): - """Checks if outputs is a flat structure. 
- - Following structures and values are considered flat: - 1) None - 2) A single object - 3) A list or tuple of Tensors/Operations - - The only structures that this function understands are sequences and - dictionaries. E.g. this means that if outputs contains a single - user-defined Object, it is considered to be flat. Errors are raised later on - if that Object cannot be converted to a Tensor. - - Args: - outputs: Output from `computation` inside `xla.compile`. - - Returns: - A boolean indicates whether outputs is flat. - """ - # If outputs is a list or tuple, check if it has any nested structure. If - # there is, then outputs is non-flat. - if isinstance(outputs, collections.Sequence): - for o in outputs: - if isinstance(o, collections.Sequence) or isinstance(o, dict): - return False - - # If outputs is a dict, it is non-flat. - if isinstance(outputs, dict): - return False - - # Getting here means either outputs itself is a single non-structured value - # or it is a flat list of single non-structured values. - return True - - -def _postprocess_flat_outputs(outputs): - """Validates flat outputs and adds back device assignments. - - Args: - outputs: Output from `computation` inside `xla.compile`. - - Returns: - Tensors and Operations extracted from outputs. - """ - # Following code segment is to preserve legacy behavior. Previously we only - # supported flat outputs and thus for consistency it was nice to convert even - # single element into a tuple. But now that we support arbitrary output - # structure, this is no longer necessary. - # TODO(b/121383831): Migrate all legacy use cases and delete this special - # case. - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, make it a tuple. - if not isinstance(outputs, collections.Sequence): - outputs = (outputs,) - - # Append `no_op` here so that return value of this function always contains - # at least one op that can trigger XlaLaunch node. - outputs += (control_flow_ops.no_op(),) - try: - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - 'XLA computation function return values must all either be Operations' - ' or convertible to Tensors. Got error: "%s"' % str(e)) - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - 'XLA computation function must return zero or more Tensor values ' - 'followed by zero or more Operations.') - - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else ''): - new_output_tensors.append(array_ops.identity(t)) - - return new_output_tensors, output_operations - - -def _postprocess_non_flat_outputs(outputs): - """Validates non-flat outputs and adds back device assignments. - - Args: - outputs: Output from `computation` inside `xla.compile`. - - Returns: - Tensors extracted from outputs and an empty list because Operations are not - allowed in non-flat outputs.. - """ - # Convert all non-Operation outputs to Tensors. - new_output_tensors = [] - for o in nest.flatten(outputs): - if isinstance(o, ops.Operation): - raise ValueError( - 'xla.compile does not support Operation as return value in non-flat ' - 'output structure. 
You can set returned Operations as control ' - 'dependencies of returned Tensors so Operations are triggered when ' - 'Tensors are evaluated. Operation found: "%s"' % o.name) - - try: - o = ops.convert_to_tensor(o) - except Exception as e: - raise ValueError( - 'XLA computation function return values must all either be ' - 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) - - # Makes sure even pass-through inputs/outputs are touched in compile - # context by creating an Identity node inside compile context. - with ops.device(o.device if o.device else ''): - new_output_tensors.append(array_ops.identity(o)) - - return new_output_tensors, [] - - -@contextlib.contextmanager -def _disable_summary_context(): - """Enters a context where all summary ops are skipped. - - Summaries are not yet supported in xla.compile(). So we provide this context - manager that can skip creating summary ops. This is a temporary workaround due - to XLA not supporting summary ops. - - Yields: - None. - """ - original_skip_summary_func = summary_op_util.skip_summary - summary_op_util.skip_summary = lambda: True - - try: - yield - finally: - summary_op_util.skip_summary = original_skip_summary_func +compile = xla.compile # pylint: disable=redefined-builtin +check_function_argument_count = xla.check_function_argument_count class _CapturedObject(object): """A placeholder to capture an object.""" @@ -750,8 +256,8 @@ def estimator_model_fn(target_model_fn=None): """estimator_model_fn decorates a model_fn to be compiled for execution. Currently it only works with `TPUEstimator`. If you need to use it with base - `Estimator`, please add `tf.enable_resource_variables()` at the beginning of - your program. + `Estimator`, please add `tf.compat.v1.enable_resource_variables()` at the + beginning of your program. Example 1, decorating model_fn: ``` @@ -788,51 +294,3 @@ def estimator_model_fn(target_model_fn=None): return tf_decorator.make_decorator(function, _ModelFnWrapper(function)) return decorated(target_model_fn) if target_model_fn else decorated - - -def check_function_argument_count(func, input_arity, infeed_queue): - """Validate the number of input arguments to an XLA function. - - Args: - func: the Python function that will be called to generate the body of an XLA - computation graph. - input_arity: the number of explicit arguments supplied by the caller. - infeed_queue: if not None, the infeed queue that will supply - additional arguments to the function. - - Returns: - None if function can be called with the supplied number of - arguments, or an error string if it cannot. - """ - def format_error(complaint, quantity): - return '%s %d argument%s' % (complaint, quantity, '' - if quantity == 1 else 's') - - num_args_supplied = input_arity - if infeed_queue is not None: - num_args_supplied += infeed_queue.number_of_tuple_elements - arg_spec = tf_inspect.getargspec(func) - num_func_args = len(arg_spec.args) - if arg_spec.defaults is None: - num_func_defaults = 0 - else: - num_func_defaults = len(arg_spec.defaults) - min_func_args = num_func_args - num_func_defaults - if num_args_supplied < min_func_args: - # The required number of arguments is not enough to call the function. - if num_func_defaults == 0 and arg_spec.varargs is None: - return format_error('exactly', num_func_args) - else: - return format_error('at least', min_func_args) - if arg_spec.varargs is None and num_args_supplied > num_func_args: - # The required number of arguments is too many to call the function. 
- if num_func_defaults == 0: - return format_error('exactly', num_func_args) - else: - return format_error('at most', num_func_args) - # Reaching here means either - # 1) There are varargs, func can accept any number of arguments greater than - # the minimum. - # 2) Number of supplied arguments falls in range of acceptable argument count - # of func. - return None diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index c4384dcde75..0df7c3706aa 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -23,20 +23,13 @@ from absl.testing import parameterized from tensorflow.contrib.compiler import xla from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.training.python.training import hparam from tensorflow.python import summary from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_util -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import training @@ -48,226 +41,6 @@ _EXPECTED_FEATURE = 2 _EXPECTED_LABEL = 3 -class XLACompileContextTest(test.TestCase): - - def create_test_xla_compile_context(self): - computation_name = ops.get_default_graph().unique_name('computation') - pivot = control_flow_ops.no_op(name=computation_name + '/pivot') - return xla.XLACompileContext(name=computation_name, pivot=pivot) - - def test_report_unsupported_operations(self): - """Tests that unsupported operations are detected.""" - context = self.create_test_xla_compile_context() - context.Enter() - dummy_tensor = constant_op.constant(1.1) - audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5) - histogram_summary = summary.histogram('histogram_summary', dummy_tensor) - image_summary = summary.image('image_summary', dummy_tensor) - scalar_summary = summary.scalar('scalar_summary', dummy_tensor) - tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor) - summary.merge( - [ - audio_summary, histogram_summary, image_summary, scalar_summary, - tensor_summary - ], - name='merge_summary') - logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op') - context.Exit() - - unsupported_ops_names = [op.name for op in context._unsupported_ops] - self.assertEqual(unsupported_ops_names, [ - u'audio_summary', u'histogram_summary', u'image_summary', - u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary', - u'print_op' - ]) - - def test_resource_variable(self): - """Tests that resource variable usage is allowed.""" - a = variable_scope.get_variable( - name='variable_a', shape=(1), use_resource=True) - - context = self.create_test_xla_compile_context() - context.Enter() - state_ops.assign(a, a + 1) - context.Exit() - - def test_non_resource_variable_error(self): - """Tests that non-resource variable usage is disallowed.""" - a = variable_scope.get_variable( - name='variable_a', shape=(1), use_resource=False) - - context = self.create_test_xla_compile_context() - context.Enter() - with 
self.assertRaisesRegexp( - NotImplementedError, 'Non-resource Variables are not supported inside ' - r'XLA computations \(operator name: Assign\)'): - state_ops.assign(a, a + 1) - context.Exit() - - def test_nested_xla_compile_error(self): - """Tests that nested XLA computation leads to fatal error.""" - context1 = self.create_test_xla_compile_context() - context1.Enter() - - context2 = self.create_test_xla_compile_context() - context2.Enter() - with self.assertRaisesRegexp(ValueError, - 'XLA compiled computations cannot be nested'): - constant_op.constant(1) - context2.Exit() - context1.Exit() - - def test_xla_compile_attr(self): - """Tests that ops are tagged with XLA compile ID attribute.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - self.assertIn('_xla_compile_id', op.op.node_def.attr) - - def test_op_without_input(self): - """Tests that ops without inputs depend on pivot correctly.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - - self.assertIn(context._pivot, op.op.control_inputs) - - def test_external_control_edges(self): - """Tests that external control edges are handled correctly.""" - i = constant_op.constant(1) - op1 = constant_op.constant(1) - - with ops.control_dependencies([op1]): - op2 = constant_op.constant(1) - self.assertIn(op1.op, op2.op.control_inputs) - - def while_body(i): - del i # unused - context = self.create_test_xla_compile_context() - context.Enter() - with ops.control_dependencies([op1]): - op3 = constant_op.constant(1) - context.Exit() - self.assertNotIn(op1.op, op3.op.control_inputs) - return op3 - - control_flow_ops.while_loop( - cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i]) - - def test_op_output_marked_as_seen(self): - """Tests that any op output is marked as seen in context.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - - self.assertIn(op.name, context._values) - - def testOpIsInContext(self): - """Tests that XLACompileContext is recognized as an XLA context.""" - op1 = constant_op.constant(1) - context = self.create_test_xla_compile_context() - context.Enter() - op2 = constant_op.constant(2) - context.Exit() - self.assertFalse(control_flow_util.IsInXLAContext(op1.op)) - self.assertTrue(control_flow_util.IsInXLAContext(op2.op)) - - def testOpPreventFeeding(self): - """Tests that ops created inside XLACompileContext can not be fed.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - self.assertFalse(op.graph.is_feedable(op.op)) - - def testOpPreventFetching(self): - """Tests that ops created inside XLACompileContext can not be fetched.""" - context = self.create_test_xla_compile_context() - context.Enter() - op = constant_op.constant(1) - context.Exit() - self.assertFalse(op.graph.is_fetchable(op.op)) - - -class CheckFunctionArgumentCountTest(test.TestCase): - - def testSimple(self): - """Tests that arg checker works for functions with no varargs or defaults. 
- """ - - def func(x, y, z): - return x + y + z - - self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) - self.assertEqual('exactly 3 arguments', - xla.check_function_argument_count(func, 2, None)) - queue = tpu_feed.InfeedQueue(2) - self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) - self.assertEqual('exactly 3 arguments', - xla.check_function_argument_count(func, 2, queue)) - - def testDefaultArgs(self): - """Tests that arg checker works for a function with no varargs.""" - - def func(x, y, z=17): - return x + y + z - - self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 1, None)) - self.assertEqual('at most 3 arguments', - xla.check_function_argument_count(func, 4, None)) - queue = tpu_feed.InfeedQueue(1) - self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 0, queue)) - self.assertEqual('at most 3 arguments', - xla.check_function_argument_count(func, 4, queue)) - - def testVarArgs(self): - """Tests that arg checker works for a function with varargs.""" - - def func(x, y, *z): - return x + y + len(z) - - self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 4, None)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 1, None)) - queue = tpu_feed.InfeedQueue(1) - self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 3, queue)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 0, queue)) - - def testVarArgsAndDefaults(self): - """Tests that arg checker works for a function with varargs and defaults.""" - - def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg - return x + y + z + len(q) - - self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 4, None)) - self.assertEqual(None, xla.check_function_argument_count(func, 5, None)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 1, None)) - queue = tpu_feed.InfeedQueue(1) - self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 3, queue)) - self.assertEqual(None, xla.check_function_argument_count(func, 4, queue)) - self.assertEqual('at least 2 arguments', - xla.check_function_argument_count(func, 0, queue)) - - def _test_train_model_fn(features, labels, mode, params): """A dummy model_fn for testing purpose.""" del features, labels, params diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD index 619153df67c..bd81e36c423 100644 --- a/tensorflow/contrib/constrained_optimization/BUILD +++ 
b/tensorflow/contrib/constrained_optimization/BUILD @@ -41,7 +41,12 @@ py_library( py_test( name = "candidates_test", srcs = ["python/candidates_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", + tags = [ + # TODO(b/129496144): Re-enable MSAN test. + "nomsan", + ], deps = [ ":constrained_optimization", "//tensorflow/python:client_testlib", @@ -65,6 +70,7 @@ py_library( py_test( name = "external_regret_optimizer_test", srcs = ["python/external_regret_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":constrained_optimization", @@ -79,6 +85,7 @@ py_test( py_test( name = "swap_regret_optimizer_test", srcs = ["python/swap_regret_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":constrained_optimization", diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py index 0b79bdf7c05..09249884423 100644 --- a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py @@ -32,17 +32,18 @@ from tensorflow.python.training import optimizer as train_optimizer class ConstrainedOptimizer(object): """Base class representing a constrained optimizer. - A ConstrainedOptimizer wraps a tf.train.Optimizer (or more than one), and - applies it to a ConstrainedMinimizationProblem. Unlike a tf.train.Optimizer, - which takes a tensor to minimize as a parameter to its minimize() method, a - constrained optimizer instead takes a ConstrainedMinimizationProblem. + A ConstrainedOptimizer wraps a tf.compat.v1.train.Optimizer (or more than + one), and applies it to a ConstrainedMinimizationProblem. Unlike a + tf.compat.v1.train.Optimizer, which takes a tensor to minimize as a parameter + to its minimize() method, a constrained optimizer instead takes a + ConstrainedMinimizationProblem. """ def __init__(self, optimizer): """Constructs a new `ConstrainedOptimizer`. Args: - optimizer: tf.train.Optimizer, used to optimize the + optimizer: tf.compat.v1.train.Optimizer, used to optimize the ConstraintedMinimizationProblem. Returns: @@ -52,7 +53,7 @@ class ConstrainedOptimizer(object): @property def optimizer(self): - """Returns the `tf.train.Optimizer` used for optimization.""" + """Returns the `tf.compat.v1.train.Optimizer` used for optimization.""" return self._optimizer @abc.abstractmethod @@ -74,14 +75,15 @@ class ConstrainedOptimizer(object): Args: minimization_problem: ConstrainedMinimizationProblem, the problem to optimize. - global_step: as in `tf.train.Optimizer`'s `minimize` method. - var_list: as in `tf.train.Optimizer`'s `minimize` method. - gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. - aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. - colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + global_step: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + var_list: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. - name: as in `tf.train.Optimizer`'s `minimize` method. - grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.compat.v1.train.Optimizer`'s + `minimize` method. + name: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. 
+ grad_loss: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. Returns: `Operation`, the train_op. @@ -106,14 +108,15 @@ class ConstrainedOptimizer(object): Args: minimization_problem: ConstrainedMinimizationProblem, the problem to optimize. - global_step: as in `tf.train.Optimizer`'s `minimize` method. - var_list: as in `tf.train.Optimizer`'s `minimize` method. - gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. - aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. - colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + global_step: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + var_list: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. - name: as in `tf.train.Optimizer`'s `minimize` method. - grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.compat.v1.train.Optimizer`'s + `minimize` method. + name: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. Returns: `Operation`, the train_op. @@ -159,14 +162,15 @@ class ConstrainedOptimizer(object): Args: minimization_problem: ConstrainedMinimizationProblem, the problem to optimize. - global_step: as in `tf.train.Optimizer`'s `minimize` method. - var_list: as in `tf.train.Optimizer`'s `minimize` method. - gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. - aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. - colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + global_step: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + var_list: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. - name: as in `tf.train.Optimizer`'s `minimize` method. - grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.compat.v1.train.Optimizer`'s + `minimize` method. + name: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. Returns: `Operation`, the train_op. @@ -220,14 +224,15 @@ class ConstrainedOptimizer(object): optimize. unconstrained_steps: int, number of steps for which we should perform unconstrained updates, before transitioning to constrained updates. - global_step: as in `tf.train.Optimizer`'s `minimize` method. - var_list: as in `tf.train.Optimizer`'s `minimize` method. - gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. - aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. - colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + global_step: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + var_list: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. - name: as in `tf.train.Optimizer`'s `minimize` method. - grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.compat.v1.train.Optimizer`'s + `minimize` method. + name: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. 
+ grad_loss: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. Returns: `Operation`, the train_op. diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py index fb0f849b33b..0cfe354bce0 100644 --- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py @@ -15,7 +15,8 @@ """Defines `AdditiveExternalRegretOptimizer`. This optimizer minimizes a `ConstrainedMinimizationProblem` by introducing -Lagrange multipliers, and using `tf.train.Optimizer`s to jointly optimize over +Lagrange multipliers, and using `tf.compat.v1.train.Optimizer`s to jointly +optimize over the model parameters and Lagrange multipliers. For the purposes of constrained optimization, at least in theory, @@ -33,8 +34,8 @@ The formulation used by the AdditiveExternalRegretOptimizer--which is simply the usual Lagrangian formulation--can be found in Definition 1, and is discussed in Section 3. This optimizer is most similar to Algorithm 3 in Appendix C.3, with the two differences being that it uses proxy constraints (if they're provided) -in the update of the model parameters, and uses `tf.train.Optimizer`s, instead -of SGD, for the "inner" updates. +in the update of the model parameters, and uses `tf.compat.v1.train.Optimizer`s, +instead of SGD, for the "inner" updates. """ from __future__ import absolute_import @@ -99,9 +100,8 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius): del old_inactive # Needed by the condition, but not the body. iteration += 1 scale = standard_ops.minimum( - 0.0, - (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum( - 1.0, standard_ops.reduce_sum(inactive))) + 0.0, (radius - standard_ops.reduce_sum(multipliers)) / + standard_ops.maximum(1.0, standard_ops.reduce_sum(inactive))) multipliers = multipliers + (scale * inactive) new_inactive = standard_ops.cast(multipliers > 0, multipliers.dtype) multipliers = multipliers * new_inactive @@ -157,12 +157,12 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): `constraint_optimizer` is provided, then `optimizer` is used for both. Args: - optimizer: tf.train.Optimizer, used to optimize the objective and - proxy_constraints portion of the ConstrainedMinimizationProblem. If + optimizer: tf.compat.v1.train.Optimizer, used to optimize the objective + and proxy_constraints portion of the ConstrainedMinimizationProblem. If constraint_optimizer is not provided, this will also be used to optimize the Lagrange multipliers. - constraint_optimizer: optional tf.train.Optimizer, used to optimize the - Lagrange multipliers. + constraint_optimizer: optional tf.compat.v1.train.Optimizer, used to + optimize the Lagrange multipliers. Returns: A new `_ExternalRegretOptimizer`. @@ -172,7 +172,7 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): @property def constraint_optimizer(self): - """Returns the `tf.train.Optimizer` used for the Lagrange multipliers.""" + """Returns the `tf.compat.v1.train.Optimizer` used for the Lagrange multipliers.""" return self._constraint_optimizer @abc.abstractmethod @@ -209,14 +209,15 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): Args: minimization_problem: ConstrainedMinimizationProblem, the problem to optimize. - global_step: as in `tf.train.Optimizer`'s `minimize` method. 
- var_list: as in `tf.train.Optimizer`'s `minimize` method. - gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. - aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. - colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + global_step: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + var_list: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. - name: as in `tf.train.Optimizer`'s `minimize` method. - grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.compat.v1.train.Optimizer`'s + `minimize` method. + name: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. Raises: ValueError: If the minimization_problem tensors have different dtypes. @@ -318,10 +319,10 @@ class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): """A `ConstrainedOptimizer` based on external-regret minimization. - This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly - minimize over the model parameters, and maximize over Lagrange multipliers, - with the latter maximization using additive updates and an algorithm that - minimizes external regret. + This `ConstrainedOptimizer` uses the given `tf.compat.v1.train.Optimizer`s to + jointly minimize over the model parameters, and maximize over Lagrange + multipliers, with the latter maximization using additive updates and an + algorithm that minimizes external regret. For more specifics, please refer to: @@ -333,8 +334,8 @@ class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): formulation--can be found in Definition 1, and is discussed in Section 3. It is most similar to Algorithm 3 in Appendix C.3, with the two differences being that it uses proxy constraints (if they're provided) in the update of the - model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for the - "inner" updates. + model parameters, and uses `tf.compat.v1.train.Optimizer`s, instead of SGD, + for the "inner" updates. """ def __init__(self, @@ -344,12 +345,12 @@ class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): """Constructs a new `AdditiveExternalRegretOptimizer`. Args: - optimizer: tf.train.Optimizer, used to optimize the objective and - proxy_constraints portion of ConstrainedMinimizationProblem. If + optimizer: tf.compat.v1.train.Optimizer, used to optimize the objective + and proxy_constraints portion of ConstrainedMinimizationProblem. If constraint_optimizer is not provided, this will also be used to optimize the Lagrange multipliers. - constraint_optimizer: optional tf.train.Optimizer, used to optimize the - Lagrange multipliers. + constraint_optimizer: optional tf.compat.v1.train.Optimizer, used to + optimize the Lagrange multipliers. maximum_multiplier_radius: float, an optional upper bound to impose on the sum of the Lagrange multipliers. @@ -379,7 +380,7 @@ class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): return state def _constraint_grad_and_var(self, state, gradient): - # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + # TODO(acotter): v1.colocate_with(), if colocate_gradients_with_ops is True? 
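# --- Illustrative sketch, not part of the patch ---
# _project_multipliers_wrt_euclidean_norm (reformatted above) keeps the
# Lagrange multipliers nonnegative while capping their sum at `radius`. A
# plain-NumPy rendering of that iteration, assuming a nonnegative starting
# point and a simple fixed iteration count, might look like this:
import numpy as np

def project_multipliers(multipliers, radius):
  multipliers = np.maximum(multipliers, 0.0)
  inactive = np.ones_like(multipliers)
  for _ in range(multipliers.size + 1):
    # Shift the still-active coordinates so the total is at most `radius`,
    # then clip at zero and repeat on the coordinates that stayed positive.
    scale = min(0.0,
                (radius - multipliers.sum()) / max(1.0, inactive.sum()))
    multipliers = multipliers + scale * inactive
    inactive = (multipliers > 0.0).astype(multipliers.dtype)
    multipliers = multipliers * inactive
  return multipliers

print(project_multipliers(np.array([0.7, 0.3, 0.2]), radius=1.0))
# -> roughly [0.633, 0.233, 0.133]; nonnegative and summing to 1.0.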
return (-gradient, state) def _projection_op(self, state, name=None): diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index 14e6d870112..fb02106f025 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -37,7 +37,8 @@ For more specifics, please refer to: The formulation used by both of the SwapRegretOptimizers can be found in Definition 2, and is discussed in Section 4. The `MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 in Section 4, -with the difference being that it uses `tf.train.Optimizer`s, instead of SGD, +with the difference being that it uses `tf.compat.v1.train.Optimizer`s, instead +of SGD, for the "inner" updates. The `AdditiveSwapRegretOptimizer` differs further in that it performs additive (instead of multiplicative) updates of the stochastic matrix. @@ -74,9 +75,7 @@ def _maximal_eigenvector_power_method(matrix, L2 norm) by no more than epsilon, we will terminate. maximum_iterations: nonnegative int, if we perform this many iterations, we will terminate. - - Result: - The maximal right-eigenvector of `matrix`. + Result: The maximal right-eigenvector of `matrix`. Raises: ValueError: If the `matrix` tensor is not floating-point, or if the @@ -255,12 +254,12 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): is provided, then `optimizer` is used for both. Args: - optimizer: tf.train.Optimizer, used to optimize the objective and - proxy_constraints portion of ConstrainedMinimizationProblem. If + optimizer: tf.compat.v1.train.Optimizer, used to optimize the objective + and proxy_constraints portion of ConstrainedMinimizationProblem. If constraint_optimizer is not provided, this will also be used to optimize the Lagrange multiplier analogues. - constraint_optimizer: optional tf.train.Optimizer, used to optimize the - Lagrange multiplier analogues. + constraint_optimizer: optional tf.compat.v1.train.Optimizer, used to + optimize the Lagrange multiplier analogues. Returns: A new `_SwapRegretOptimizer`. @@ -270,7 +269,7 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): @property def constraint_optimizer(self): - """Returns the `tf.train.Optimizer` used for the matrix.""" + """Returns the `tf.compat.v1.train.Optimizer` used for the matrix.""" return self._constraint_optimizer @abc.abstractmethod @@ -316,14 +315,15 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): Args: minimization_problem: ConstrainedMinimizationProblem, the problem to optimize. - global_step: as in `tf.train.Optimizer`'s `minimize` method. - var_list: as in `tf.train.Optimizer`'s `minimize` method. - gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. - aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. - colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + global_step: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + var_list: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. - name: as in `tf.train.Optimizer`'s `minimize` method. - grad_loss: as in `tf.train.Optimizer`'s `minimize` method. 
+ colocate_gradients_with_ops: as in `tf.compat.v1.train.Optimizer`'s + `minimize` method. + name: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.compat.v1.train.Optimizer`'s `minimize` method. Raises: ValueError: If the minimization_problem tensors have different dtypes. @@ -359,9 +359,9 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): trainable=False, name="swap_regret_optimizer_state") - zero_and_constraints = standard_ops.concat( - (standard_ops.zeros((1,), dtype=constraints.dtype), constraints), - axis=0) + zero_and_constraints = standard_ops.concat((standard_ops.zeros( + (1,), dtype=constraints.dtype), constraints), + axis=0) objective_and_proxy_constraints = standard_ops.concat( (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0) @@ -433,7 +433,8 @@ class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer): """A `ConstrainedOptimizer` based on swap-regret minimization. - This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + This `ConstrainedOptimizer` uses the given `tf.compat.v1.train.Optimizer`s to + jointly minimize over the model parameters, and maximize over constraint/objective weight matrix (the analogue of Lagrange multipliers), with the latter maximization using additive updates and an algorithm that minimizes swap @@ -447,7 +448,8 @@ class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer): The formulation used by this optimizer can be found in Definition 2, and is discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with - the differences being that it uses `tf.train.Optimizer`s, instead of SGD, for + the differences being that it uses `tf.compat.v1.train.Optimizer`s, instead of + SGD, for the "inner" updates, and performs additive (instead of multiplicative) updates of the stochastic matrix. """ @@ -456,12 +458,12 @@ class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer): """Constructs a new `AdditiveSwapRegretOptimizer`. Args: - optimizer: tf.train.Optimizer, used to optimize the objective and - proxy_constraints portion of ConstrainedMinimizationProblem. If + optimizer: tf.compat.v1.train.Optimizer, used to optimize the objective + and proxy_constraints portion of ConstrainedMinimizationProblem. If constraint_optimizer is not provided, this will also be used to optimize the Lagrange multiplier analogues. - constraint_optimizer: optional tf.train.Optimizer, used to optimize the - Lagrange multiplier analogues. + constraint_optimizer: optional tf.compat.v1.train.Optimizer, used to + optimize the Lagrange multiplier analogues. Returns: A new `AdditiveSwapRegretOptimizer`. @@ -479,16 +481,16 @@ class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer): dimension = num_constraints + 1 # Initialize by putting all weight on the objective, and none on the # constraints. - return standard_ops.concat( - (standard_ops.ones( - (1, dimension)), standard_ops.zeros((dimension - 1, dimension))), - axis=0) + return standard_ops.concat((standard_ops.ones( + (1, dimension)), standard_ops.zeros((dimension - 1, dimension))), + axis=0) def _stochastic_matrix(self, state): return state def _constraint_grad_and_var(self, state, gradient): - # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + # TODO(acotter): tf.compat.v1.colocate_with(), + # if colocate_gradients_with_ops is True? 
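# --- Illustrative sketch, not part of the patch ---
# _maximal_eigenvector_power_method (whose docstring is cleaned up earlier in
# this file's diff) is documented as returning the maximal right-eigenvector
# of a matrix, stopping once the iterate moves by at most `epsilon` in L2 norm
# or after `maximum_iterations` steps. A generic NumPy power iteration along
# those lines, not the actual TF implementation, could be:
import numpy as np

def maximal_eigenvector_power_method(matrix, epsilon=1e-6,
                                     maximum_iterations=100):
  vector = np.ones(matrix.shape[1]) / np.sqrt(matrix.shape[1])
  for _ in range(maximum_iterations):
    new_vector = matrix.dot(vector)
    new_vector /= np.linalg.norm(new_vector)
    # Terminate once consecutive iterates are within epsilon of each other.
    if np.linalg.norm(new_vector - vector) <= epsilon:
      return new_vector
    vector = new_vector
  return vector

print(maximal_eigenvector_power_method(np.array([[0.9, 0.1], [0.4, 0.6]])))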
return (-gradient, state) def _projection_op(self, state, name=None): @@ -502,7 +504,8 @@ class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer): class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): """A `ConstrainedOptimizer` based on swap-regret minimization. - This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + This `ConstrainedOptimizer` uses the given `tf.compat.v1.train.Optimizer`s to + jointly minimize over the model parameters, and maximize over constraint/objective weight matrix (the analogue of Lagrange multipliers), with the latter maximization using multiplicative updates and an algorithm that minimizes swap @@ -516,7 +519,8 @@ class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): The formulation used by this optimizer can be found in Definition 2, and is discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with - the difference being that it uses `tf.train.Optimizer`s, instead of SGD, for + the difference being that it uses `tf.compat.v1.train.Optimizer`s, instead of + SGD, for the "inner" updates. """ @@ -528,12 +532,12 @@ class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): """Constructs a new `MultiplicativeSwapRegretOptimizer`. Args: - optimizer: tf.train.Optimizer, used to optimize the objective and - proxy_constraints portion of ConstrainedMinimizationProblem. If + optimizer: tf.compat.v1.train.Optimizer, used to optimize the objective + and proxy_constraints portion of ConstrainedMinimizationProblem. If constraint_optimizer is not provided, this will also be used to optimize the Lagrange multiplier analogues. - constraint_optimizer: optional tf.train.Optimizer, used to optimize the - Lagrange multiplier analogues. + constraint_optimizer: optional tf.compat.v1.train.Optimizer, used to + optimize the Lagrange multiplier analogues. minimum_multiplier_radius: float, each element of the matrix will be lower bounded by `minimum_multiplier_radius` divided by one plus the number of constraints. @@ -575,20 +579,20 @@ class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): (dimension - 1) / (dimension))) log_initial_zero = math.log(self._initial_multiplier_radius / dimension) # FUTURE WORK: make the dtype a parameter. - return standard_ops.concat( - (standard_ops.constant( - log_initial_one, dtype=dtypes.float32, shape=(1, dimension)), - standard_ops.constant( - log_initial_zero, - dtype=dtypes.float32, - shape=(dimension - 1, dimension))), - axis=0) + return standard_ops.concat((standard_ops.constant( + log_initial_one, dtype=dtypes.float32, shape=(1, dimension)), + standard_ops.constant( + log_initial_zero, + dtype=dtypes.float32, + shape=(dimension - 1, dimension))), + axis=0) def _stochastic_matrix(self, state): return standard_ops.exp(state) def _constraint_grad_and_var(self, state, gradient): - # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + # TODO(acotter): tf.compat.v1.colocate_with(), + # if colocate_gradients_with_ops is True? 
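# --- Illustrative sketch, not part of the patch ---
# AdditiveSwapRegretOptimizer._initial_state (reformatted above) builds a
# (dimension x dimension) objective/constraint weight matrix whose columns
# start with all weight on the objective row; MultiplicativeSwapRegretOptimizer
# keeps the same kind of matrix in log space and recovers it through
# _stochastic_matrix, i.e. exp(state). In NumPy, the additive initial state is:
import numpy as np

num_constraints = 3              # example value
dimension = num_constraints + 1  # objective row plus one row per constraint
initial_state = np.concatenate(
    [np.ones((1, dimension)), np.zeros((dimension - 1, dimension))], axis=0)
# Every column sums to one, with all of the mass on row 0 (the objective).
assert (initial_state.sum(axis=0) == 1.0).all()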
return (-gradient, state) def _projection_op(self, state, name=None): diff --git a/tensorflow/contrib/copy_graph/BUILD b/tensorflow/contrib/copy_graph/BUILD index fa44c4d54e1..6273bcf7a5c 100644 --- a/tensorflow/contrib/copy_graph/BUILD +++ b/tensorflow/contrib/copy_graph/BUILD @@ -28,6 +28,7 @@ py_library( py_test( name = "copy_test", srcs = glob(["python/util/copy_test.py"]), + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":copy_graph_py", diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 40e159b8fcb..974423fec99 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -24,7 +24,7 @@ log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood( unary_scores, gold_tags, sequence_lengths) loss = tf.reduce_mean(-log_likelihood) -train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss) +train_op = tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(loss) # Decoding in Tensorflow. viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode( @@ -283,7 +283,7 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0) # Use int32 or int64 based on tag_indices' dtype. if tag_indices.dtype == dtypes.int64: - offsets = math_ops.to_int64(offsets) + offsets = math_ops.cast(offsets, dtypes.int64) flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1]) unary_scores = array_ops.reshape( diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index 8d35622e393..174d82c1b9a 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -1,6 +1,7 @@ # Description: # A Cudnn RNN wrapper. # APIs are meant to change over time. 
+ package( default_visibility = ["//visibility:private"], ) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index f5219eb134d..5c63ee7a97b 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -21,19 +21,17 @@ from __future__ import print_function import collections import itertools import os -import unittest from absl.testing import parameterized import numpy as np from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.core.protobuf import saver_pb2 -from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops @@ -44,7 +42,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as saver_lib @@ -69,6 +66,8 @@ def RunLSTM(sess, time, num_layers=1, variable_seq_lengths=False, + time_major=True, + dynamic_shape_input=False, is_training=True, dropout=0., num_dirs=True, @@ -84,11 +83,14 @@ def RunLSTM(sess, random_seed.set_random_seed(0) np.random.seed(0) - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(time, batch_size, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) + shape = ([time, batch_size, input_size] + if time_major else [batch_size, time, input_size]) + inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype) + inputs_static = variable_scope.get_variable( + "inputs", initializer=inputs_np, dtype=dtype) + inputs_dynamic = array_ops.placeholder( + dtype, shape=[None, None, None], name="inputs") + inputs = inputs_dynamic if dynamic_shape_input else inputs_static initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -122,12 +124,12 @@ def RunLSTM(sess, cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, - inputs, + inputs_static, sequence_length=lengths, initial_state=rnn_cell_impl.LSTMStateTuple( h=initial_h_op, c=initial_c_op), dtype=dtype, - time_major=True, + time_major=time_major, scope=None) # Convert to cudnn opaque param. 
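# --- Illustrative sketch, not part of the patch ---
# The new `time_major` test parameter above switches RunLSTM between
# [time, batch, input] and [batch, time, input] input layouts. The two layouts
# hold the same data with the first two axes swapped; `time`, `batch_size`,
# and `input_size` below are arbitrary example values.
import numpy as np

time, batch_size, input_size = 4, 2, 3
time_major_inputs = np.random.rand(time, batch_size, input_size)
batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2])
assert batch_major_inputs.shape == (batch_size, time, input_size)
assert np.array_equal(time_major_inputs[1, 0], batch_major_inputs[0, 1])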
@@ -135,35 +137,38 @@ def RunLSTM(sess, num_layers, num_units, input_size) opaque_params = format_converter.tf_canonical_to_opaque([w, b]) - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) - cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) + cu_initial_h_op = array_ops.expand_dims( + initial_h_op, axis=(0 if time_major else 1)) + cu_initial_c_op = array_ops.expand_dims( + initial_c_op, axis=(0 if time_major else 1)) cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, cu_initial_c_op, opaque_params, sequence_lengths=lengths, + time_major=time_major, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) # Remove the trivial 1st dimension. cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( - c=array_ops.squeeze(cu_c_op, axis=0), - h=array_ops.squeeze(cu_h_op, axis=0)) + c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1), + h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1)) if is_training: (inp_grad_op, hgrad_op, cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( - outputs_op, [inputs, initial_h_op, initial_c_op, w, b]) + outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b]) (cu_inp_grad_op, cu_hgrad_op, cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( cu_outputs_op, [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params]) # Remove the trivial 1st dimension - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1) # Remove the trivial 1st dimension - cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1) cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( opaque_grad_op) @@ -183,10 +188,12 @@ def RunLSTM(sess, (hgrad_op, cgrad_op), wgrad_op, bgrad_op ]) (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, - cu_bgrad) = sess.run([ - cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, - (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op - ]) + cu_bgrad) = sess.run( + [ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -205,7 +212,10 @@ def RunLSTM(sess, cu_bgrad) else: outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) - cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op]) + cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op], + feed_dict=({ + inputs: inputs_np + } if dynamic_shape_input else None)) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -326,7 +336,7 @@ def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs): return [dict(t) for t in {tuple(d.items()) for d in res}] -class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): +class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase): def _test_training_helper(self, num_units, @@ -336,6 +346,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtype, variable_seq_lengths, + time_major, + dynamic_shape_input=False, rtol=3e-6, atol=3e-6): with self.session(use_gpu=True) as sess: @@ -347,7 +359,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - variable_seq_lengths=variable_seq_lengths) + 
variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -359,15 +373,15 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") + variable_seq_lengths, time_major, dynamic_shape_input): self._test_training_helper( num_units, input_size, @@ -375,18 +389,21 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtypes.float32, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): self._test_training_helper( num_units, input_size, @@ -396,18 +413,20 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): dtypes.float16, rtol=5e-3, atol=5e-4, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") + variable_seq_lengths, time_major, dynamic_shape_input): with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( sess, @@ -417,7 +436,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs) # h @@ -426,15 +447,16 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): 
self.assertAllClose(state_tuple.c, cu_state_tuple.c) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( sess, @@ -445,7 +467,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dtype=dtypes.float16, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -457,16 +481,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): """Validates that dropout does not affect Cudnn Rnn inference.""" - if not context.context().num_gpus(): - self.skipTest("No GPUs found") # Hand-picked dropouts are used below (0. and 1.) 
with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -480,7 +505,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=0., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -493,7 +520,9 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=1., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(cu_outputs, cu_outputs2) # h @@ -510,6 +539,8 @@ def RunGRU(sess, num_layers=1, is_training=True, variable_seq_lengths=False, + time_major=True, + dynamic_shape_input=False, dropout=0., num_dirs=True, dtype=dtypes.float32): @@ -524,11 +555,14 @@ def RunGRU(sess, random_seed.set_random_seed(0) np.random.seed(0) - inputs = variable_scope.get_variable( - "inputs", - initializer=np.random.rand(time, batch_size, - input_size).astype(dtype.as_numpy_dtype), - dtype=dtype) + shape = ([time, batch_size, input_size] + if time_major else [batch_size, time, input_size]) + inputs_np = np.random.rand(*shape).astype(dtype.as_numpy_dtype) + inputs_static = variable_scope.get_variable( + "inputs", initializer=inputs_np, dtype=dtype) + inputs_dynamic = array_ops.placeholder( + dtype, shape=[None, None, None], name="inputs") + inputs = inputs_dynamic if dynamic_shape_input else inputs_static initial_h_op = variable_scope.get_variable( "initial_h_op", initializer=np.random.rand(batch_size, @@ -573,11 +607,11 @@ def RunGRU(sess, cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True) outputs_op, h_op = rnn.dynamic_rnn( cell, - inputs, + inputs_static, sequence_length=lengths, initial_state=initial_h_op, dtype=dtype, - time_major=True, + time_major=time_major, scope=None) ws = [gate_kernel, candidate_inp_kernel, candidate_hid_kernel] @@ -588,13 +622,15 @@ def RunGRU(sess, opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) - cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_initial_h_op = array_ops.expand_dims( + initial_h_op, axis=(0 if time_major else 1)) cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, array_ops.zeros_like(cu_initial_h_op), # not used opaque_params, sequence_lengths=lengths, + time_major=time_major, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_GRU) @@ -602,12 +638,12 @@ def RunGRU(sess, if is_training: (inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op, cib_grad_op, chb_grad_op) = gradients_impl.gradients( - outputs_op, [inputs, initial_h_op] + ws + bs) + outputs_op, [inputs_static, initial_h_op] + ws + bs) (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) # Remove the trivial 1st dimension - cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0 if time_major else 1) cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( opaque_grad_op) @@ -627,13 +663,15 @@ def RunGRU(sess, (gk_grad_op, cik_grad_op, chk_grad_op), (gb_grad_op, cib_grad_op, chb_grad_op) ]) - (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run([ - cu_outputs_op, cu_h_op, cu_inp_grad_op, 
cu_hgrad_op, - (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), - (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) - ]) + (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run( + [ + cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) + ], + feed_dict={inputs: inputs_np} if dynamic_shape_input else None) # Remove the trivial 1st dimension - cu_h = np.squeeze(cu_h, axis=0) + cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -651,9 +689,12 @@ def RunGRU(sess, cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) else: outputs, h = sess.run([outputs_op, h_op]) - cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) + cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op], + feed_dict=({ + inputs: inputs_np + } if dynamic_shape_input else None)) # Remove the trivial 1st dimension. - cu_h = np.squeeze(cu_h, axis=0) + cu_h = np.squeeze(cu_h, axis=0 if time_major else 1) logging.vlog(1, "outputs: %s" % outputs) logging.vlog(1, "cu_outputs: %s" % cu_outputs) @@ -662,7 +703,7 @@ def RunGRU(sess, return outputs, cu_outputs, h, cu_h -class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): +class CudnnGRUTest(test_util.TensorFlowTestCase, parameterized.TestCase): def _test_training_helper(self, num_units, @@ -672,6 +713,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtype, variable_seq_lengths, + time_major, + dynamic_shape_input=False, rtol=3e-6, atol=3e-6): with self.session(use_gpu=True) as sess: @@ -683,7 +726,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @@ -695,15 +740,15 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_training(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") + variable_seq_lengths, time_major, dynamic_shape_input): self._test_training_helper( num_units, input_size, @@ -711,18 +756,21 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtypes.float32, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) 
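# --- Illustrative sketch, not part of the patch ---
# The parameterized decorators above feed ExpandNamedTestCases extra boolean
# flags (variable_seq_lengths, time_major, dynamic_shape_input) so every named
# RNN test case is expanded into one case per flag combination. A hypothetical,
# simplified stand-in for that expansion (not the real helper) could be:
import itertools

def expand_named_cases(cases, **extra_configs):
  keys = sorted(extra_configs)
  expanded = []
  for case in cases:
    for values in itertools.product(*(extra_configs[k] for k in keys)):
      new_case = dict(case, **dict(zip(keys, values)))
      new_case["testcase_name"] = "_".join(
          [case["testcase_name"]] + ["%s_%s" % (k, v)
                                     for k, v in zip(keys, values)])
      expanded.append(new_case)
  return expanded

cases = [{"testcase_name": "small", "num_units": 4}]
print(len(expand_named_cases(
    cases, time_major=[True, False], dynamic_shape_input=[True, False])))  # 4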
+ @test_util.run_gpu_only def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): self._test_training_helper( num_units, input_size, @@ -732,18 +780,20 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): dtypes.float16, rtol=5e-3, atol=5e-4, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_inference(self, num_units, input_size, batch_size, time, num_layers, - variable_seq_lengths): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") + variable_seq_lengths, time_major, dynamic_shape_input): with self.session(use_gpu=True) as sess: (outputs, cu_outputs, h, cu_h) = RunGRU( sess, @@ -753,20 +803,23 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(outputs, cu_outputs) self.assertAllClose(h, cu_h) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): with self.session(use_gpu=True) as sess: (outputs, cu_outputs, h, cu_h) = RunGRU( sess, @@ -777,24 +830,27 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dtype=dtypes.float16, - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @parameterized.named_parameters( - ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ - "variable_seq_lengths": [True, False], - })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + "time_major": [True, False], + "dynamic_shape_input": [True, False], + })) + @test_util.run_gpu_only def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers, variable_seq_lengths): + num_layers, variable_seq_lengths, time_major, + dynamic_shape_input): """Validates that dropout does not affect Cudnn 
Rnn inference.""" # Hand-picked dropouts are used below (0. and 1.) - if not context.context().num_gpus(): - self.skipTest("No GPUs found") with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: # 1st time w/o dropout. @@ -807,7 +863,9 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=0., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -820,13 +878,15 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, is_training=False, dropout=1., - variable_seq_lengths=variable_seq_lengths) + variable_seq_lengths=variable_seq_lengths, + time_major=time_major, + dynamic_shape_input=dynamic_shape_input) self.assertAllClose(cu_outputs, cu_outputs2) self.assertAllClose(cu_h[0], cu_h2[0]) -class CudnnParamsFormatConverterTest(TensorFlowTestCase, +class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase, parameterized.TestCase): """Class for testing various format converters.""" @@ -877,22 +937,16 @@ class CudnnParamsFormatConverterTest(TensorFlowTestCase, @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) for c in NAMED_RNN_TESTCASES) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + @test_util.run_gpu_only def test_lstm(self, num_units, input_size, num_layers): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") self._test_lstm_helper(num_units, input_size, num_layers, cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) for c in NAMED_RNN_TESTCASES) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + @test_util.run_gpu_only def test_lstm_bidi(self, num_units, input_size, num_layers): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") self._test_lstm_helper(num_units, input_size, num_layers, cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) @@ -951,27 +1005,22 @@ class CudnnParamsFormatConverterTest(TensorFlowTestCase, @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) for c in NAMED_RNN_TESTCASES) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + @test_util.run_gpu_only def test_gru(self, num_units, input_size, num_layers): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") self._test_gru_helper(num_units, input_size, num_layers, cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) @parameterized.named_parameters((c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"]) for c in NAMED_RNN_TESTCASES) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + @test_util.run_gpu_only def test_gru_bidi(self, num_units, input_size, num_layers): - if not context.context().num_gpus(): - self.skipTest("No GPUs found") self._test_gru_helper(num_units, input_size, num_layers, cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) -class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase): +class CudnnRnnSaveRestoreTest(test_util.TensorFlowTestCase, + parameterized.TestCase): """Class for testing various Cudnn Rnn SaveableObjects.""" def _create_opaque_param(self, @@ -1019,14 +1068,11 @@ 
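Editor's note: the dropout test above leans on dropout being a training-only transformation; at inference it is the identity, so the hand-picked rates 0. and 1. must give identical outputs. A minimal NumPy sketch of that property (generic inverted dropout, not the Cudnn kernel itself):

```python
import numpy as np

def dropout(x, rate, training, rng=np.random.default_rng(0)):
    """Inverted dropout: only active in training mode."""
    if not training or rate == 0.0:
        return x
    keep = 1.0 - rate
    mask = rng.random(x.shape) < keep
    return np.where(mask, x / keep, 0.0)

x = np.ones((2, 3), dtype=np.float32)
# At inference the rate is irrelevant; both calls return x unchanged.
assert np.array_equal(dropout(x, 0.0, training=False),
                      dropout(x, 1.0, training=False))
```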
class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase): ], "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + @test_util.run_gpu_only def test_save_restore_variable(self, rnn_mode, num_units, input_size, num_layers, direction): # Verify the restored opaque param, once converted to tf_canonical format, # is the same as the tf canonicals of the pre-restored param. - if not context.context().num_gpus(): - self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: opaque_param = self._create_opaque_param(rnn_mode, num_units, input_size, num_layers, direction) @@ -1071,14 +1117,11 @@ class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase): ], "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] })) - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") + @test_util.run_gpu_only def test_save_restore_multi_variables(self, rnn_mode, num_units, input_size, num_layers, direction): # Verify the restored opaque param, once converted to tf_canonical format, # is the same as the tf canonicals of the pre-restored param. - if not context.context().num_gpus(): - self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: opaque_params = [] saveables = [] diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 403f3090952..be66fac66b8 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -1155,7 +1155,7 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): The gradient check verifies the expected delta_y calculated by the above equation is close to the actual delta_y. Args: - sess: tf.Session object. + sess: tf.compat.v1.Session object. y: output tensor. xs: a tensor or a list of input tensors. num_samples: number of test samples to run. diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index c4e37b41c85..2401870a455 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging - CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -46,7 +45,6 @@ CUDNN_INPUT_LINEAR_MODE = cudnn_rnn_ops.CUDNN_INPUT_LINEAR_MODE CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE - __all__ = ["CudnnLSTM", "CudnnGRU", "CudnnRNNTanh", "CudnnRNNRelu"] @@ -57,7 +55,7 @@ class _CudnnRNN(base_layer.Layer): Cudnn RNNs have two major differences from other platform-independent RNNs tf provides: * Cudnn LSTM and GRU are mathematically different from their tf counterparts. - (e.g. `tf.contrib.rnn.LSTMBlockCell` and `tf.nn.rnn_cell.GRUCell`. + (e.g. `tf.contrib.rnn.LSTMBlockCell` and `tf.compat.v1.nn.rnn_cell.GRUCell`. 
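Editor's note: the `cudnn_rnn_test.py` docstring touched above describes a gradient check that compares the expected change in the output (predicted from the gradient) with the actual change. A generic NumPy sketch of that kind of directional finite-difference check, under the assumption of a scalar-valued function; it is not the test's exact implementation:

```python
import numpy as np

def directional_gradient_check(f, grad_f, x, delta=1e-3,
                               rng=np.random.default_rng(0)):
    """Compare the actual change of f along a small random direction dx
    with the first-order prediction grad_f(x) . dx."""
    dx = delta * rng.standard_normal(x.shape)
    actual_dy = f(x + dx) - f(x)
    expected_dy = np.sum(grad_f(x) * dx)
    return actual_dy, expected_dy

f = lambda x: np.sum(x ** 2)
grad_f = lambda x: 2.0 * x
actual, expected = directional_gradient_check(f, grad_f, np.ones(5))
assert abs(actual - expected) < 1e-4
```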
* Cudnn-trained checkpoints are not directly compatible with tf RNNs: * They use a single opaque parameter buffer for the entire (possibly) multi-layer multi-directional RNN; Whereas tf RNN weights are per-cell and @@ -67,7 +65,8 @@ class _CudnnRNN(base_layer.Layer): does not have a static shape and is not partitionable. Instead of using partitioning to alleviate the PS's traffic load, try building a multi-tower model and do gradient aggregation locally within the host - before updating the PS. See https://www.tensorflow.org/performance/performance_models#parameter_server_variables + before updating the PS. See + https://www.tensorflow.org/performance/performance_models#parameter_server_variables for a detailed performance guide. Consequently, if one plans to use Cudnn trained models on both GPU and CPU @@ -104,15 +103,17 @@ class _CudnnRNN(base_layer.Layer): # Inference subgraph for unidirectional RNN on, e.g., CPU or mobile. with tf.Graph().as_default(): - single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units) + single_cell = lambda: + tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units) # NOTE: Even if there's only one layer, the cell needs to be wrapped in # MultiRNNCell. - cell = tf.nn.rnn_cell.MultiRNNCell( + cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell( [single_cell() for _ in range(num_layers)]) # Leave the scope arg unset. - outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state, ...) + outputs, final_state = tf.compat.v1.nn.dynamic_rnn(cell, inputs, + initial_state, ...) saver = Saver() @@ -124,7 +125,8 @@ class _CudnnRNN(base_layer.Layer): # Inference subgraph for bidirectional RNN with tf.Graph().as_default(): - single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units) + single_cell = lambda: + tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units) cells_fw = [single_cell() for _ in range(num_layers)] cells_bw = [single_cell() for _ in range(num_layers)] @@ -171,26 +173,25 @@ class _CudnnRNN(base_layer.Layer): num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It can be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and the actual computation before the first layer. It can be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Can be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: dropout rate, a number between [0, 1]. Dropout is applied between - each layer (no dropout is applied for a model with a single layer). - When set to 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + each layer (no dropout is applied for a model with a single layer). When + set to 0, dropout is disabled. 
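Editor's note: the reflowed `input_mode` documentation above is easy to misread, so here is a small sketch of the documented semantics. `resolve_input_mode` is a hypothetical helper used only for illustration, not part of the patch:

```python
def resolve_input_mode(input_mode, input_size, num_units):
    """'linear_input' always projects the input; 'skip_input' is only legal
    when input_size == num_units; 'auto_select' picks 'skip_input' exactly
    when that projection would be a no-op."""
    if input_mode == "auto_select":
        return "skip_input" if input_size == num_units else "linear_input"
    if input_mode == "skip_input" and input_size != num_units:
        raise ValueError("skip_input requires input_size == num_units")
    return input_mode

assert resolve_input_mode("auto_select", 128, 128) == "skip_input"
assert resolve_input_mode("auto_select", 64, 128) == "linear_input"
```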
+ seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. dtype: tf.float16, tf.float32 or tf.float64 kernel_initializer: starting value to initialize the weight. - bias_initializer: starting value to initialize the bias - (default is all zeros). - name: VariableScope for the created subgraph; defaults to class name. - This only serves the default scope if later no scope is specified when + bias_initializer: starting value to initialize the bias (default is all + zeros). + name: VariableScope for the created subgraph; defaults to class name. This + only serves the default scope if later no scope is specified when invoking __call__(). Raises: @@ -201,8 +202,8 @@ class _CudnnRNN(base_layer.Layer): cudnn_rnn_ops.check_input_mode(input_mode) if dtype not in [dtypes.float16, dtypes.float32, dtypes.float64]: - raise ValueError( - "Only support float16, float32, float64, provided %s" % dtype) + raise ValueError("Only support float16, float32, float64, provided %s" % + dtype) # Layer self.dtype is type name, the original DType object is kept here. self._plain_dtype = dtype self._num_layers = num_layers @@ -309,6 +310,7 @@ class _CudnnRNN(base_layer.Layer): Args: input_shape: network input tensor shape, a python list or a TensorShape object with 3 dimensions. + Raises: ValueError: if input_shape has wrong dimension or unknown 3rd dimension. """ @@ -328,9 +330,9 @@ class _CudnnRNN(base_layer.Layer): self._set_scope(None) # Not using base class `add_variable()` since the it calls - # `tf.get_variable()` with a callable initializer whereas here with a - # tensor. The difference is mandated to support forward-compatibility with - # Cudnn. + # `tf.compat.v1.get_variable()` with a callable initializer whereas here + # with a tensor. The difference is mandated to support forward-compatibility + # with Cudnn. with vs.variable_scope( self._scope, reuse=self.built, @@ -360,8 +362,10 @@ class _CudnnRNN(base_layer.Layer): # Initialize opaque params with a tensor with unknown shape, thus couldn't # use self.add_variable(name, shape, initializer, ...) self.kernel = vs.get_variable( - "opaque_kernel", dtype=self._plain_dtype, - initializer=opaque_params_t, validate_shape=False) + "opaque_kernel", + dtype=self._plain_dtype, + initializer=opaque_params_t, + validate_shape=False) # Create saveable in the outer scope of the cudnn subgraph, such that # alternative subgraph with platform-independent rnn cells can load the # checkpoints directly. @@ -378,20 +382,34 @@ class _CudnnRNN(base_layer.Layer): inputs, initial_state=None, sequence_lengths=None, + time_major=True, training=True): """Runs the forward step for the RNN model. Args: - inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`. - initial_state: a tuple of tensor(s) of shape - `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use - zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. + inputs: `3-D` tensor. If `time_major` is True (default), the Tensor shape + is [time_len, batch_size, input_size]. If `time_major` is False, the + shape is [batch_size, time_len, input_size]. + initial_state: a tuple of tensor(s) of shape `[num_layers * num_dirs, + batch_size, num_units]` if `time_major` is True (default) or + `[batch_size, num_layers * num_dirs, num_units]` if `time_major` is + False. If not provided, use zero initial states. The tuple size is 2 for + LSTM and 1 for other RNNs. 
sequence_lengths: an int32 array representing the variable sequence - lengths in a batch. The size of the array has to equal the - batch_size. If not provided, the same sequence length will be assumed. + lengths in a batch. The size of the array has to equal the batch_size. + If not provided, the same sequence length will be assumed. + time_major: The shape format of the `inputs` and `outputs` Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' is + used. training: whether this operation will be used in training or inference. + Returns: - output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. + output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]` + if `time_major` is True (default) or `[batch_size, time_len, + num_dirs * num_units]` if `time_major` is False. It is a `concat([fwd_output, bak_output], axis=2)`. output_states: a tuple of tensor(s) of the same shape and structure as `initial_state`. @@ -418,7 +436,8 @@ class _CudnnRNN(base_layer.Layer): # For model that doesn't take input_c, replace with a dummy tensor. c = array_ops.constant([], dtype=dtype) outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, - sequence_lengths, training) + sequence_lengths, time_major, + training) if self._rnn_mode == CUDNN_LSTM: return outputs, (output_h, output_c) else: @@ -437,7 +456,7 @@ class _CudnnRNN(base_layer.Layer): """Shapes of Cudnn canonical weight tensors for given layer.""" if layer < 0 or layer >= self._num_layers: raise ValueError("\'layer\' is not valid, got %s, expecting [%d, %d]" % - (layer, 0, self._num_layers-1)) + (layer, 0, self._num_layers - 1)) if not self._input_size: raise RuntimeError( "%s._canonical_weight_shape invoked before input shape is known" % @@ -482,7 +501,8 @@ class _CudnnRNN(base_layer.Layer): dropout=self._dropout, direction=self._direction) - def _forward(self, inputs, h, c, opaque_params, sequence_lengths, training): + def _forward(self, inputs, h, c, opaque_params, sequence_lengths, time_major, + training): output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access inputs, h, @@ -491,6 +511,7 @@ class _CudnnRNN(base_layer.Layer): training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -519,7 +540,8 @@ class _CudnnRNN(base_layer.Layer): scope=vs.get_variable_scope(), name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) self._saveable._add_trackable_dependencies( # pylint: disable=protected-access - trackable=self, dtype=self._plain_dtype) + trackable=self, + dtype=self._plain_dtype) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) @@ -536,6 +558,7 @@ class CudnnLSTM(_CudnnRNN): [num_layers * num_dirs, batch_size, num_units] Args: batch_size: an int + Returns: a tuple of python arrays. """ @@ -557,12 +580,15 @@ class _CudnnRNNNoInputC(_CudnnRNN): """Abstract simple CudnnRNN layer without input_c.""" def state_shape(self, batch_size): - """Shape of the state of Cudnn RNN cells w/o. input_c. + """Shape of the state of Cudnn RNN cells w/o. + + input_c. Shape is a 1-element tuple, [num_layers * num_dirs, batch_size, num_units] Args: batch_size: an int + Returns: a tuple of python arrays. 
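Editor's note: the `time_major` argument documented in the `__call__` hunks above only changes the axis layout of the input/output tensors, not their contents. A minimal NumPy sketch of the two layouts and how one maps onto the other:

```python
import numpy as np

time_len, batch_size, input_size = 3, 2, 4
time_major_inputs = np.arange(
    time_len * batch_size * input_size,
    dtype=np.float32).reshape(time_len, batch_size, input_size)

# The same batch of sequences in batch-major layout; only axes are swapped.
batch_major_inputs = np.transpose(time_major_inputs, (1, 0, 2))

assert batch_major_inputs.shape == (batch_size, time_len, input_size)
# Element (t, b, :) in time-major form is element (b, t, :) in batch-major form.
assert np.array_equal(time_major_inputs[1, 0], batch_major_inputs[0, 1])
```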
""" diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 7d848e2ec2d..3694d112ce4 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -68,15 +68,19 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( - num_units, forget_bias=0, cell_clip=None, use_peephole=False, - reuse=reuse, name="cudnn_compatible_lstm_cell") + num_units, + forget_bias=0, + cell_clip=None, + use_peephole=False, + reuse=reuse, + name="cudnn_compatible_lstm_cell") self._names.update({"scope": "cudnn_compatible_lstm_cell"}) class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): r"""Cudnn Compatible GRUCell. - A GRU impl akin to `tf.nn.rnn_cell.GRUCell` to use along with + A GRU impl akin to `tf.compat.v1.nn.rnn_cell.GRUCell` to use along with `tf.contrib.cudnn_rnn.CudnnGRU`. The latter's params can be used by it seamlessly. @@ -97,7 +101,8 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): $$h_t = (1 - u_t) .* h'_t + u_t .* h_t-1$$ ``` - Other GRU (see `tf.nn.rnn_cell.GRUCell` and `tf.contrib.rnn.GRUBlockCell`): + Other GRU (see `tf.compat.v1.nn.rnn_cell.GRUCell` and + `tf.contrib.rnn.GRUBlockCell`): ```python # new memory gate \\(h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_{Wh})\\) @@ -117,8 +122,8 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): def build(self, inputs_shape): if inputs_shape[1].value is None: - raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % + inputs_shape) input_depth = inputs_shape[1].value self._gate_kernel = self.add_variable( @@ -128,10 +133,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): self._gate_bias = self.add_variable( "gates/%s" % _BIAS_VARIABLE_NAME, shape=[2 * self._num_units], - initializer=( - self._bias_initializer - if self._bias_initializer is not None - else init_ops.constant_initializer(1.0, dtype=self.dtype))) + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.constant_initializer(1.0, dtype=self.dtype))) self._candidate_input_kernel = self.add_variable( "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME, @@ -145,17 +149,15 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): self._candidate_input_bias = self.add_variable( "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME, shape=[self._num_units], - initializer=( - self._bias_initializer - if self._bias_initializer is not None - else init_ops.zeros_initializer(dtype=self.dtype))) + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) self._candidate_hidden_bias = self.add_variable( "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME, shape=[self._num_units], - initializer=( - self._bias_initializer - if self._bias_initializer is not None - else init_ops.zeros_initializer(dtype=self.dtype))) + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) def call(self, inputs, state): """Gated recurrent unit (GRU) with nunits cells.""" @@ -173,7 +175,7 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): math_ops.matmul(state, self._candidate_hidden_kernel), self._candidate_hidden_bias) candidate = self._activation(candidate) - new_h 
= (1-u) * candidate + u * state + new_h = (1 - u) * candidate + u * state return new_h, new_h @@ -231,6 +233,7 @@ class CudnnParamsFormatConverter(object): Args: opaque_param: An opaque tensor storing cudnn rnn params (weights and biases). + Returns: 2 list for weights and biases respectively. """ @@ -252,6 +255,7 @@ class CudnnParamsFormatConverter(object): Args: cu_weights: a list of tensors, Cudnn canonical weights. cu_biases: a list of tensors, Cudnn canonical biases. + Returns: a single opaque tensor. """ @@ -285,6 +289,7 @@ class CudnnParamsFormatConverter(object): Args: cu_weights: a list of tensors of Cudnn canonical weights. cu_biases: a list of tensors of Cudnn canonical biases. + Returns: 1 tuple, tf canonical weights and biases. """ @@ -298,8 +303,9 @@ class CudnnParamsFormatConverter(object): layer_weights_num] layer_biases = cu_biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: - self._cu_canonical_to_tf_canonical_single_layer( - layer_weights, layer_biases, tf_weights, tf_biases) + self._cu_canonical_to_tf_canonical_single_layer(layer_weights, + layer_biases, + tf_weights, tf_biases) else: fw_weights = layer_weights[:len(layer_weights) // 2] bw_weights = layer_weights[len(layer_weights) // 2:] @@ -372,10 +378,12 @@ class CudnnParamsFormatConverter(object): cu_weights.extend(self._tf_to_cudnn_weights(i, *layer_weights)) cu_biases.extend(self._tf_to_cudnn_biases(*layer_biases)) else: - fw_weights, bw_weights = layer_weights[:len( - layer_weights) // 2], layer_weights[len(layer_weights) // 2:] - fw_biases, bw_biases = layer_biases[:len( - layer_biases) // 2], layer_biases[len(layer_biases) // 2:] + fw_weights, bw_weights = layer_weights[:len(layer_weights) // + 2], layer_weights[ + len(layer_weights) // 2:] + fw_biases, bw_biases = layer_biases[:len(layer_biases) // + 2], layer_biases[len(layer_biases + ) // 2:] cu_weights.extend(self._tf_to_cudnn_weights(i, *fw_weights)) cu_biases.extend(self._tf_to_cudnn_biases(*fw_biases)) @@ -424,7 +432,7 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): W_o = array_ops.concat([w_o, r_o], axis=1) # pylint: enable=invalid-name # Cudnn LSTM weights are in ifco order, other tf LSTMs are in icfo order. 
- reordered = self._cudnn_to_tf_gate_params(* [W_i, W_f, W_c, W_o]) + reordered = self._cudnn_to_tf_gate_params(*[W_i, W_f, W_c, W_o]) return (array_ops.transpose(array_ops.concat(reordered, axis=0)),) def _tf_to_cudnn_weights(self, layer, *tf_weights): @@ -441,8 +449,8 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): (tf_weight,) = tf_weights w = array_ops.transpose(tf_weight) # pylint: disable=invalid-name - W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params(*array_ops.split( - w, 4, axis=0)) + W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params( + *array_ops.split(w, 4, axis=0)) w_i, r_i = array_ops.split(W_i, [input_weight_width, num_units], axis=1) w_c, r_c = array_ops.split(W_c, [input_weight_width, num_units], axis=1) @@ -463,15 +471,15 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): B_c = b_wc + b_rc B_o = b_wo + b_ro # pylint: enable=invalid-name - reordered = self._cudnn_to_tf_gate_params(* [B_i, B_f, B_c, B_o]) + reordered = self._cudnn_to_tf_gate_params(*[B_i, B_f, B_c, B_o]) return (array_ops.concat(reordered, axis=0),) def _tf_to_cudnn_biases(self, *tf_biases): r"""Reverse the operations in StitchBiases().""" (tf_bias,) = tf_biases # pylint: disable=invalid-name - B_i, B_f, B_c, B_o = self._tf_to_cudnn_gate_params(*array_ops.split( - tf_bias, 4, axis=0)) + B_i, B_f, B_c, B_o = self._tf_to_cudnn_gate_params( + *array_ops.split(tf_bias, 4, axis=0)) # pylint: enable=invalid-name # pylint: disable=unbalanced-tuple-unpacking b_wi, b_ri = (B_i * 0.5,) * 2 @@ -539,8 +547,8 @@ class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter): # return two biases each with half the value. Since RNN does not # regularize by weight decay, it has no side effect in training or # inference. - array_ops.concat([b_wi, b_wr], axis=0) + array_ops.concat( - [b_ri, b_rr], axis=0), + array_ops.concat([b_wi, b_wr], axis=0) + + array_ops.concat([b_ri, b_rr], axis=0), b_wh, b_rh) @@ -720,8 +728,8 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) for param, param_name in zip(params, prefixed_param_names) ] - super(CudnnOpaqueParamsSaveable, self).__init__( - array_ops.identity(self._variables), specs, name) + super(CudnnOpaqueParamsSaveable, + self).__init__(array_ops.identity(self._variables), specs, name) @property def format_converter(self): @@ -760,15 +768,16 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): cuDNN-compatible cells. Args: - trackable: An object inheriting from `Trackable` to add - dependencies too (typically the cuDNN `Layer`). + trackable: An object inheriting from `Trackable` to add dependencies too + (typically the cuDNN `Layer`). dtype: The dtype for the canonical parameter Tensors. 
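Editor's note: the LSTM/GRU converter hunks above do two things worth making explicit: they reorder gate parameters between Cudnn's (i, f, c, o) layout and TF's (i, c, f, o) layout, and they split the single TF bias into two equal halves so that converting back (which sums the two Cudnn biases) reproduces it exactly. A NumPy sketch with hypothetical helper names:

```python
import numpy as np

def cudnn_to_tf_gate_order(i, f, c, o):
    """Cudnn gate order (i, f, c, o) -> TF gate order (i, c, f, o)."""
    return i, c, f, o

def tf_to_cudnn_gate_order(i, c, f, o):
    """TF gate order (i, c, f, o) -> Cudnn gate order (i, f, c, o)."""
    return i, f, c, o

gates = tuple(np.full(3, v) for v in (1.0, 2.0, 3.0, 4.0))  # i, f, c, o
round_trip = tf_to_cudnn_gate_order(*cudnn_to_tf_gate_order(*gates))
assert all(np.array_equal(a, b) for a, b in zip(round_trip, gates))

# Cudnn keeps two bias vectors per gate (b_w, b_r); the single TF bias is
# their sum. Splitting it into equal halves makes the round trip exact:
# 0.5*B + 0.5*B == B.
tf_bias = np.array([0.1, -2.0, 3.5])
b_w, b_r = tf_bias * 0.5, tf_bias * 0.5
assert np.array_equal(b_w + b_r, tf_bias)
```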
""" split_dependencies = split_dependency.split_dependency( component_names=self._param_names, component_dtypes=(dtype,) * len(self._param_names), fill_save_buffer_fn=self._trackable_save, - consume_restore_buffer_fn=self._trackable_restore) + consume_restore_buffer_fn=self._trackable_restore, + device=self._variables[0].device) self._trackable_track_params(trackable, split_dependencies) def _trackable_track_params(self, trackable, params): @@ -904,9 +913,9 @@ _cudnn_rnn_common_doc_string = """ def _check_rnn_mode(rnn_mode): if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU): - raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % - (rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, - CUDNN_RNN_RELU)) + raise ValueError( + "Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % + (rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU)) def _get_seed(seed): @@ -956,6 +965,7 @@ def _cudnn_rnn(inputs, is_training, rnn_mode, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -964,12 +974,14 @@ def _cudnn_rnn(inputs, """Cudnn RNN. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. - input_c: the initial hidden state for c. This is only relevant for LSTM. - A Tensor of the same shape as input_h. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. + input_c: the initial hidden state for c. This is only relevant for LSTM. A + Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). @@ -977,20 +989,25 @@ def _cudnn_rnn(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 
'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: outputs, output_h, output_c """ @@ -1017,6 +1034,14 @@ def _cudnn_rnn(inputs, } if sequence_lengths is not None: args["sequence_lengths"] = sequence_lengths + args["time_major"] = time_major + outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) + elif time_major is False: + batch_size = array_ops.shape(inputs)[0] + max_time = array_ops.shape(inputs)[1] + sequence_lengths = array_ops.fill([batch_size], max_time) + args["sequence_lengths"] = sequence_lengths + args["time_major"] = time_major outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) @@ -1031,6 +1056,7 @@ def cudnn_lstm(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1039,38 +1065,45 @@ def cudnn_lstm(inputs, """Cudnn LSTM. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. - input_c: the initial hidden state for c. This is only relevant for LSTM. - A Tensor of the same shape as input_h. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. + input_c: the initial hidden state for c. This is only relevant for LSTM. A + Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. - direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' - dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. 
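Editor's note: the `elif time_major is False:` branch above covers the case where no explicit `sequence_lengths` is given but the caller asked for batch-major data; it synthesizes a full-length lengths vector from the input shape so the CudnnRNNV3 kernel can still be dispatched. The same construction in NumPy terms:

```python
import numpy as np

# Batch-major input: [batch_size, max_time, input_size].
inputs = np.zeros((5, 7, 16), dtype=np.float32)
batch_size, max_time = inputs.shape[0], inputs.shape[1]

# Every sequence is assumed to run for the full max_time, mirroring
# array_ops.fill([batch_size], max_time) in the hunk above.
sequence_lengths = np.full([batch_size], max_time, dtype=np.int32)
print(sequence_lengths)  # [7 7 7 7 7]
```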
The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: outputs, output_h, output_c """ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, - sequence_lengths, input_mode, direction, dropout, seed, - name) + sequence_lengths, time_major, input_mode, direction, + dropout, seed, name) def _cudnn_rnn_no_input_c(inputs, @@ -1079,6 +1112,7 @@ def _cudnn_rnn_no_input_c(inputs, is_training, rnn_mode, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1087,10 +1121,12 @@ def _cudnn_rnn_no_input_c(inputs, """Cudnn RNN w/o input_c. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). @@ -1098,27 +1134,33 @@ def _cudnn_rnn_no_input_c(inputs, in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + time_major: The shape format of the `inputs` and `outputs` Tensors. 
If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: outputs, output_h """ input_c = array_ops.constant([], dtype=input_h.dtype) outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths, - input_mode, direction, dropout, seed, name) + time_major, input_mode, direction, dropout, + seed, name) return outputs, output_h @@ -1127,6 +1169,7 @@ def cudnn_gru(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1135,36 +1178,43 @@ def cudnn_gru(inputs, """Cudnn GRU. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. 
Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU, - sequence_lengths, input_mode, direction, dropout, - seed, name) + sequence_lengths, time_major, input_mode, + direction, dropout, seed, name) def cudnn_rnn_relu(inputs, @@ -1176,14 +1226,17 @@ def cudnn_rnn_relu(inputs, dropout=0., seed=0, sequence_lengths=None, + time_major=True, name=None): """Cudnn RNN Relu. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the @@ -1196,19 +1249,24 @@ def cudnn_rnn_relu(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. If not provided, the same sequence length will be assumed. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. name: name of the operation. 
Returns: outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_RELU, sequence_lengths, input_mode, - direction, dropout, seed, name) + CUDNN_RNN_RELU, sequence_lengths, time_major, + input_mode, direction, dropout, seed, name) def cudnn_rnn_tanh(inputs, @@ -1216,6 +1274,7 @@ def cudnn_rnn_tanh(inputs, params, is_training, sequence_lengths=None, + time_major=True, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1224,36 +1283,43 @@ def cudnn_rnn_tanh(inputs, """Cudnn RNN Tanh. Args: - inputs: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + inputs: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True (default), + the Tensor shape is [num_layers, batch_size, num_units]. If `time_major` + is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If true, + these Tensors must be shaped ['max_time', 'batch_size', 'depth']. If + false, these Tensors must be shaped ['batch_size', 'max_time', 'depth']. + By default this function accepts input and emits output in time-major + form. This param is only effective when 'sequence_lengths' is used. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. 
+ Returns: outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_TANH, sequence_lengths, input_mode, - direction, dropout, seed, name) + CUDNN_RNN_TANH, sequence_lengths, time_major, + input_mode, direction, dropout, seed, name) def cudnn_rnn_opaque_params_to_canonical(rnn_mode, @@ -1270,26 +1336,25 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, Args: rnn_mode: a string specifies the mode, under which this RNN model runs. - Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. - input_size: the size of the input, it could be different from the - num_units. + input_size: the size of the input, it could be different from the num_units. params: opaque cudnn params var. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: weights list and bias list Raises: @@ -1332,27 +1397,26 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, Args: rnn_mode: a string specifies the mode, under which this RNN model runs. - Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. - input_size: the size of the input, it could be different from the - num_units. + input_size: the size of the input, it could be different from the num_units. weights: a Tensor for weight parameters. biases: a Tensor for bias parameters. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. 
+ input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: an opaque Cudnn param. Raises: @@ -1391,26 +1455,25 @@ def cudnn_rnn_opaque_params_size(rnn_mode, Args: rnn_mode: a string specifies the mode, under which this RNN model runs. - Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. - input_size: the size of the input, it could be different from the - num_units. - input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input_size: the size of the input, it could be different from the num_units. + input_mode: indicate whether there is a linear projection between the input + and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dtype: one of tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. name: name of the operation. + Returns: a int, size of Cudnn opaque params. Raises: @@ -1458,25 +1521,25 @@ class _CudnnRNN(object): Args: rnn_mode: a string specifies the mode, under which this RNN model runs. - Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. 
input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + seed: the op seed used for initializing dropout. See + `tf.compat.v1.set_random_seed` for behavior. + Raises: ValueError: if direction is invalid. """ @@ -1537,22 +1600,32 @@ class _CudnnRNN(object): input_c, params, is_training=True, - sequence_lengths=None): + sequence_lengths=None, + time_major=True): """Runs the forward step for the RNN model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. - input_c: the initial hidden state for c. This is only relevant for LSTM. - A Tensor of the same shape as input_h. + input_data: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. + input_c: the initial hidden state for c. This is only relevant for LSTM. A + Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' is + used. + Returns: output: the output sequence. output_h: the final state for h. 
@@ -1566,6 +1639,7 @@ class _CudnnRNN(object): is_training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -1636,15 +1710,14 @@ class CudnnLSTM(_CudnnRNN): num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and The actual computation before the first layer. It could be - 'skip_input', 'linear_input' or 'auto_select'. - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and The actual computation before the first layer. It could be + 'skip_input', 'linear_input' or 'auto_select'. 'skip_input' is only + allowed when input_size == num_units; 'auto_select' implies 'skip_input' + when input_size == num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the seed used for initializing dropout. @@ -1666,14 +1739,17 @@ class CudnnLSTM(_CudnnRNN): input_c, params, sequence_lengths=None, + time_major=True, is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the LSTM model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + input_data: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. input_c: the initial hidden state for c. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. @@ -1681,7 +1757,14 @@ class CudnnLSTM(_CudnnRNN): lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' is + used. is_training: whether this operation will be used in training or inference. + Returns: output: the output sequence. output_h: the final state for h. @@ -1693,6 +1776,7 @@ class CudnnLSTM(_CudnnRNN): input_c, params, sequence_lengths=sequence_lengths, + time_major=time_major, is_training=is_training) return (output, output_h, output_c) @@ -1716,15 +1800,14 @@ class _CudnnRNNNoInputC(_CudnnRNN): num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. 
input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and The actual computation before the first layer. It could be - 'skip_input', 'linear_input' or 'auto_select'. - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and The actual computation before the first layer. It could be + 'skip_input', 'linear_input' or 'auto_select'. 'skip_input' is only + allowed when input_size == num_units; 'auto_select' implies 'skip_input' + when input_size == num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the seed used for initializing dropout. @@ -1752,20 +1835,30 @@ class _CudnnRNNNoInputC(_CudnnRNN): input_h, params, sequence_lengths=None, + time_major=True, is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the RNN model. A Tensor of shape [?, - batch_size, input_size]. - input_h: the initial hidden state for h. A Tensor of shape [num_layers, - batch_size, num_units]. + input_data: the input sequence to the RNN model. If `time_major` is True + (default), the Tensor shape is [max_time, batch_size, input_size]. If + `time_major` is False, the shape is [batch_size, max_time, input_size]. + input_h: the initial hidden state for h. If `time_major` is True + (default), the Tensor shape is [num_layers, batch_size, num_units]. If + `time_major` is False, the shape is [batch_size, num_layers, num_units]. params: the parameter buffer created for this model. sequence_lengths: an int32 array representing the variable sequence lengths in a batch. The size of the array has to equal the batch_size. Default to None, in which case sequences in the batch are assumed to have the same length, which is inferred from inputs. + time_major: The shape format of the `inputs` and `outputs` Tensors. If + true, these Tensors must be shaped ['max_time', 'batch_size', 'depth']. + If false, these Tensors must be shaped ['batch_size', 'max_time', + 'depth']. By default this function accepts input and emits output in + time-major form. This param is only effective when 'sequence_lengths' is + used. is_training: whether this operation will be used in training or inference. + Returns: output: the output sequence. output_h: the final state for h. 
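The `input_mode` rules repeated in these docstrings ('skip_input' only when `input_size == num_units`; 'auto_select' resolving to 'skip_input' or 'linear_input') can be summarized in a tiny helper. This is only an illustration of the documented behavior, not code from the library:

```python
def resolve_input_mode(input_mode, input_size, num_units):
  """Illustrative only: mirrors the input_mode rules quoted above."""
  if input_mode == "auto_select":
    return "skip_input" if input_size == num_units else "linear_input"
  if input_mode == "skip_input" and input_size != num_units:
    raise ValueError("'skip_input' requires input_size == num_units")
  return input_mode
```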
@@ -1777,6 +1870,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): is_training, self._rnn_mode, sequence_lengths=sequence_lengths, + time_major=time_major, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 42f538b4ba1..10475cf2866 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -10,6 +10,7 @@ load("//tensorflow:tensorflow.bzl", "py_test") py_test( name = "assert_element_shape_test", srcs = ["assert_element_shape_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:batching", @@ -32,6 +33,7 @@ py_test( size = "medium", srcs = ["lmdb_dataset_op_test.py"], data = ["//tensorflow/core:lmdb_testdata"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_pip", @@ -57,6 +59,7 @@ py_test( name = "reduce_dataset_test", size = "small", srcs = ["reduce_dataset_test.py"], + python_version = "PY2", deps = [ "//tensorflow/contrib/data/python/ops:get_single_element", "//tensorflow/contrib/data/python/ops:grouping", @@ -73,6 +76,7 @@ py_test( name = "slide_dataset_op_test", size = "small", srcs = ["slide_dataset_op_test.py"], + python_version = "PY2", deps = [ "//tensorflow/contrib/data/python/ops:sliding", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 9275a36582a..95cf659a84b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -232,7 +232,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=array_ops.expand_dims( math_ops.range(i, dtype=dtypes.int64), 1), - values=array_ops.fill([math_ops.to_int32(i)], i), + values=array_ops.fill([math_ops.cast(i, dtypes.int32)], i), dense_shape=[i]) iterator = dataset_ops.make_initializable_iterator( diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index f8bb942c0a5..6a88cc68162 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -43,7 +43,8 @@ def dense_to_sparse_batch(batch_size, row_shape): # contents of a dataset. a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } - a.apply(tf.contrib.data.dense_to_sparse_batch(batch_size=2, row_shape=[6])) == + a.apply(tf.data.experimental.dense_to_sparse_batch(batch_size=2, + row_shape=[6])) == { ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices ['a', 'b', 'c', 'a', 'b'], # values @@ -55,14 +56,13 @@ def dense_to_sparse_batch(batch_size, row_shape): ``` Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the - number of consecutive elements of this dataset to combine in a - single batch. - row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like - object representing the equivalent dense shape of a row in the - resulting `tf.SparseTensor`. Each element of this dataset must - have the same rank as `row_shape`, and must have size less - than or equal to `row_shape` in each dimension. + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. 
+ row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object + representing the equivalent dense shape of a row in the resulting + `tf.SparseTensor`. Each element of this dataset must have the same rank as + `row_shape`, and must have size less than or equal to `row_shape` in each + dimension. Returns: A `Dataset` transformation function, which can be passed to @@ -85,7 +85,7 @@ def unbatch(): # of a dataset. a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } - a.apply(tf.contrib.data.unbatch()) == { + a.apply(tf.data.experimental.unbatch()) == { 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'} ``` @@ -111,7 +111,8 @@ def batch_and_drop_remainder(batch_size): ```python dataset = tf.data.Dataset.range(200) - batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128)) + batched = + dataset.apply(tf.contrib.data.batch_and_drop_remainder(128)) print(batched.output_shapes) # ==> "(128,)" (the batch dimension is known) ``` @@ -121,7 +122,7 @@ def batch_and_drop_remainder(batch_size): Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of this dataset to combine in a single batch. + consecutive elements of this dataset to combine in a single batch. Returns: A `Dataset` transformation function, which can be passed to @@ -152,11 +153,10 @@ def padded_batch_and_drop_remainder(batch_size, Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - padded_shapes: A nested structure of `tf.TensorShape` or - `tf.int64` vector tensor-like objects. See - `tf.data.Dataset.padded_batch` for details. - padding_values: (Optional.) A nested structure of scalar-shaped - `tf.Tensor`. See `tf.data.Dataset.padded_batch` for details. + padded_shapes: A nested structure of `tf.TensorShape` or `tf.int64` vector + tensor-like objects. See `tf.data.Dataset.padded_batch` for details. + padding_values: (Optional.) A nested structure of scalar-shaped `tf.Tensor`. + See `tf.data.Dataset.padded_batch` for details. Returns: A `Dataset` transformation function, which can be passed to @@ -179,7 +179,7 @@ def assert_element_shape(expected_shapes): ```python shapes = [tf.TensorShape([16, 256]), tf.TensorShape([None, 2])] - result = dataset.apply(tf.contrib.data.assert_element_shape(shapes)) + result = dataset.apply(tf.data.experimental.assert_element_shape(shapes)) print(result.output_shapes) # ==> "((16, 256), (, 2))" ``` @@ -245,8 +245,8 @@ def map_and_batch(map_func, deprecated. Args: - map_func: A function mapping a nested structure of tensors to another - nested structure of tensors. + map_func: A function mapping a nested structure of tensors to another nested + structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, @@ -257,9 +257,9 @@ def map_and_batch(map_func, whether the last batch should be dropped in case its size is smaller than desired; the default behavior is not to drop the smaller batch. num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, - representing the number of elements to process in parallel. If not - specified, `batch_size * num_parallel_batches` elements will be - processed in parallel. + representing the number of elements to process in parallel. If not + specified, `batch_size * num_parallel_batches` elements will be processed + in parallel. 
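As an aside, a minimal usage sketch of the fused `map_and_batch` transformation whose arguments are reflowed above; the dataset and map function are invented for illustration, and the `tf.data.experimental.map_and_batch` spelling follows the renames applied elsewhere in this patch:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# Fuse the map and batch steps: square each element, batch by 4, and drop
# the final partial batch.
dataset = dataset.apply(
    tf.data.experimental.map_and_batch(
        lambda x: x * x,
        batch_size=4,
        num_parallel_calls=2,
        drop_remainder=True))
# ==> batches [0, 1, 4, 9] and [16, 25, 36, 49]
```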
Returns: A `Dataset` transformation function, which can be passed to diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 0559a2e09cc..b22e11a7044 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -32,12 +32,14 @@ def ignore_errors(): ```python dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.]) - # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError. - dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error")) + # Computing `tf.debugging.check_numerics(1. / 0.)` will raise an + InvalidArgumentError. + dataset = dataset.map(lambda x: tf.debugging.check_numerics(1. / x, "error")) # Using `ignore_errors()` will drop the element that causes an error. dataset = - dataset.apply(tf.contrib.data.ignore_errors()) # ==> { 1., 0.5, 0.2 } + dataset.apply(tf.data.experimental.ignore_errors()) # ==> { 1., 0.5, 0.2 + } ``` Returns: diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index 58ad9eea903..9df55faf291 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -30,13 +30,14 @@ def get_single_element(dataset): """Returns the single element in `dataset` as a nested structure of tensors. This function enables you to use a `tf.data.Dataset` in a stateless - "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`. + "tensor-in tensor-out" expression, without creating a + `tf.compat.v1.data.Iterator`. This can be useful when your preprocessing transformations are expressed as a `Dataset`, and you want to use the transformation at serving time. For example: ```python - input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) + input_batch = tf.compat.v1.placeholder(tf.string, shape=[BATCH_SIZE]) def preprocessing_fn(input_str): # ... @@ -46,7 +47,7 @@ def get_single_element(dataset): .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) .batch(BATCH_SIZE)) - image_batch, label_batch = tf.contrib.data.get_single_element(dataset) + image_batch, label_batch = tf.data.experimental.get_single_element(dataset) ``` Args: @@ -70,7 +71,8 @@ def reduce_dataset(dataset, reducer): Args: dataset: A `tf.data.Dataset` object. - reducer: A `tf.contrib.data.Reducer` object representing the reduce logic. + reducer: A `tf.data.experimental.Reducer` object representing the reduce + logic. Returns: A nested structure of `tf.Tensor` objects, corresponding to the result diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index f50da4d429f..4543bd2ecc3 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -46,7 +46,7 @@ def parallel_interleave(map_func, # Preprocess 4 files concurrently. filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords") dataset = filenames.apply( - tf.contrib.data.parallel_interleave( + tf.data.experimental.parallel_interleave( lambda filename: tf.data.TFRecordDataset(filename), cycle_length=4)) ``` @@ -146,7 +146,7 @@ def sample_from_datasets(datasets, weights=None, seed=None): `datasets`. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - `tf.set_random_seed` for behavior. + `tf.compat.v1.set_random_seed` for behavior. 
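A short, hedged sketch of `sample_from_datasets` as described by the arguments above (the datasets and weights are made up; the `tf.data.experimental` spelling matches the renames in this patch):

```python
import tensorflow as tf

zeros = tf.data.Dataset.from_tensors(0).repeat()
ones = tf.data.Dataset.from_tensors(1).repeat()

# Draw from `zeros` roughly 70% of the time and from `ones` roughly 30%.
mixed = tf.data.experimental.sample_from_datasets(
    [zeros, ones], weights=[0.7, 0.3], seed=42)
```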
Returns: A dataset that interleaves elements from `datasets` at random, according to @@ -175,7 +175,7 @@ def choose_from_datasets(datasets, choice_dataset): # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. choice_dataset = tf.data.Dataset.range(3).repeat(3) - result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset) + result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset) ``` The elements of `result` will be: diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 48c325c86f7..013fedb6186 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -36,11 +36,11 @@ def make_saveable_from_iterator(iterator): ds = tf.data.Dataset.range(10) iterator = ds.make_initializable_iterator() # Build the iterator SaveableObject. - saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator) + saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator) # Add the SaveableObject to the SAVEABLE_OBJECTS collection so # it can be automatically saved using Saver. - tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) - saver = tf.train.Saver() + tf.compat.v1.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = tf.compat.v1.train.Saver() while continue_training: ... Perform training ... @@ -82,7 +82,7 @@ class CheckpointInputPipelineHook(iterator_ops.CheckpointInputPipelineHook): while True: est.train( train_input_fn, - hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)], + hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)], steps=train_steps_per_eval) # Note: We do not pass the hook here. metrics = est.evaluate(eval_input_fn) @@ -99,7 +99,7 @@ class CheckpointInputPipelineHook(iterator_ops.CheckpointInputPipelineHook): pipeline. For saving the input pipeline checkpoint alongside the model weights use - `tf.contrib.data.make_saveable_from_iterator` directly to create a + `tf.data.experimental.make_saveable_from_iterator` directly to create a `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however, that you will need to be careful not to restore the training iterator during eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py index 3aeee9d8e42..7bc4f0a0193 100644 --- a/tensorflow/contrib/data/python/ops/parsing_ops.py +++ b/tensorflow/contrib/data/python/ops/parsing_ops.py @@ -34,7 +34,7 @@ def parse_example_dataset(features, num_parallel_calls=1): and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`, `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature` and `SparseFeature` is mapped to a `SparseTensor`, and each - `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more + `FixedLenFeature` is mapped to a `Tensor`. See `tf.io.parse_example` for more details about feature dictionaries. Args: diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index c6bf5215c94..70fbff9e6aa 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -60,7 +60,7 @@ def make_csv_dataset( Args: file_pattern: List of files or patterns of file paths containing CSV - records. See `tf.gfile.Glob` for pattern rules. + records. See `tf.io.gfile.glob` for pattern rules. 
batch_size: An int representing the number of records to combine in a single batch. column_names: An optional list of strings that corresponds to the CSV @@ -225,11 +225,11 @@ def make_batched_features_dataset(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int representing the number of records to combine in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or - `VarLenFeature` values. See `tf.parse_example`. + `VarLenFeature` values. See `tf.io.parse_example`. reader: A function or class that can be called with a `filenames` tensor and (optional) `reader_args` and returns a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. @@ -328,11 +328,11 @@ def read_batch_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int representing the number of records to combine in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or - `VarLenFeature` values. See `tf.parse_example`. + `VarLenFeature` values. See `tf.io.parse_example`. reader: A function or class that can be called with a `filenames` tensor and (optional) `reader_args` and returns a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. @@ -378,7 +378,7 @@ class LMDBDataset(dataset_ops.DatasetSource): (key value) pairs sequentially. For example: ```python - tf.enable_eager_execution() + tf.compat.v1.enable_eager_execution() dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb") diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index 329b34fdfec..ef9944e6143 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -26,7 +26,7 @@ from tensorflow.python.util import deprecation def shuffle_and_repeat(buffer_size, count=None, seed=None): """Shuffles and repeats a Dataset returning a new permutation for each epoch. - `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))` + `dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size, count))` is equivalent to @@ -45,7 +45,7 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None): indefinitely. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - `tf.set_random_seed` for behavior. + `tf.compat.v1.set_random_seed` for behavior. 
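For reference, a minimal sketch of the `shuffle_and_repeat` fusion described above, next to the chained form it is documented as being equivalent to (the exact chained spelling is an assumption based on the surrounding text, not copied from it):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

# Fused form: a fresh shuffle order for each of the 3 repetitions.
fused = dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=5, count=3, seed=42))

# Roughly equivalent chained form (assumed): reshuffles on every epoch.
chained = dataset.shuffle(5, seed=42, reshuffle_each_iteration=True).repeat(3)
```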
Returns: A `Dataset` transformation function, which can be passed to diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 2897c72b445..b3c2c984a9d 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -56,7 +56,7 @@ class _SlideDataset(dataset_ops.UnaryDataset): None, "stride is deprecated, use window_shift instead", "stride") @deprecation.deprecated( None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, " - "stride=window_stride).flat_map(lambda x: x.batch(window.size))` " + "stride=window_stride).flat_map(lambda x: x.batch(window_size))` " "instead.") def sliding_window_batch(window_size, stride=None, diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 909d06c677e..9129599eb95 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -32,7 +32,7 @@ def unique(): dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1]) # Using `unique()` will drop the duplicate elements. - dataset = dataset.apply(tf.contrib.data.unique()) # ==> { 1, 37, 2 } + dataset = dataset.apply(tf.data.experimental.unique()) # ==> { 1, 37, 2 } ``` Returns: diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto b/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto index 4c0cceaddca..2a41b321b78 100644 --- a/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto @@ -10,9 +10,21 @@ import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto"; // Used in generic_tree_model.BinaryNode.left_child_test. // Tests whether the feature's value belongs to the specified list, // (or does not belong if inverse=True). +// For empty list use ConstResultTest instead. message MatchingValuesTest { // When the feature is missing, the test's outcome is undefined. FeatureId feature_id = 1; repeated Value value = 2; bool inverse = 3; } + +// Used in generic_tree_model.BinaryNode.left_child_test. +// Returns test_result if feature value is not missed. Otherwise +// BinaryNode.default_direction is used. +message ConstResultTest { + FeatureId feature_id = 1; + // value_for_dtype is used to store the type of the feature. The value itself + // should be ignored, only its type is used. + Value value_for_dtype = 2; + bool test_result = 3; +} diff --git a/tensorflow/contrib/deprecated/BUILD b/tensorflow/contrib/deprecated/BUILD index 401527f1e74..035d8cfc37e 100644 --- a/tensorflow/contrib/deprecated/BUILD +++ b/tensorflow/contrib/deprecated/BUILD @@ -22,6 +22,7 @@ py_library( py_test( name = "summaries_test", srcs = ["summaries_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/deprecated/__init__.py b/tensorflow/contrib/deprecated/__init__.py index 7aff045de30..fc4742bc2c8 100644 --- a/tensorflow/contrib/deprecated/__init__.py +++ b/tensorflow/contrib/deprecated/__init__.py @@ -19,12 +19,12 @@ submodule, and made some semantic tweaks. 
The first thing to note is that we moved the APIs around as follows: ```python -tf.scalar_summary -> tf.summary.scalar -tf.histogram_summary -> tf.summary.histogram -tf.audio_summary -> tf.summary.audio -tf.image_summary -> tf.summary.image -tf.merge_summary -> tf.summary.merge -tf.merge_all_summaries -> tf.summary.merge_all +tf.scalar_summary -> tf.compat.v1.summary.scalar +tf.histogram_summary -> tf.compat.v1.summary.histogram +tf.audio_summary -> tf.compat.v1.summary.audio +tf.image_summary -> tf.compat.v1.summary.image +tf.merge_summary -> tf.compat.v1.summary.merge +tf.merge_all_summaries -> tf.compat.v1.summary.merge_all ``` We think this API is cleaner and will improve long-term discoverability and @@ -59,8 +59,8 @@ def add_activation_summaries(v, scope): # After def add_activation_summaries(v): - tf.summary.scalar("fraction_of_zero", tf.nn.fraction_of_zero(v)) - tf.summary.histogram("activations", v) + tf.compat.v1.summary.scalar("fraction_of_zero", tf.nn.fraction_of_zero(v)) + tf.compat.v1.summary.histogram("activations", v) ``` Now, so long as the add_activation_summaries function is called from within the @@ -74,10 +74,12 @@ In addition to the name change described above, there are two further changes to the new summary ops: - the "max_images" argument for `tf.image_summary` was renamed to "max_outputs - for `tf.summary.image` + for `tf.compat.v1.summary.image` - `tf.scalar_summary` accepted arbitrary tensors of tags and values. But - `tf.summary.scalar` requires a single scalar name and scalar value. In most - cases, you can create `tf.summary.scalar` in a loop to get the same behavior + `tf.compat.v1.summary.scalar` requires a single scalar name and scalar value. + In most + cases, you can create `tf.compat.v1.summary.scalar` in a loop to get the same + behavior As before, TensorBoard groups charts by the top-level `tf.name_scope` which may be inconvenient, for in the new summary ops, the summary will inherit that @@ -90,7 +92,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - # pylint: disable=unused-import from tensorflow.python.ops.logging_ops import audio_summary from tensorflow.python.ops.logging_ops import histogram_summary @@ -102,8 +103,9 @@ from tensorflow.python.ops.logging_ops import scalar_summary from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long -_allowed_symbols = ['audio_summary', 'histogram_summary', - 'image_summary', 'merge_all_summaries', - 'merge_summary', 'scalar_summary'] +_allowed_symbols = [ + 'audio_summary', 'histogram_summary', 'image_summary', + 'merge_all_summaries', 'merge_summary', 'scalar_summary' +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index 3ecd755d86f..1fa4a9bcee1 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -30,12 +30,12 @@ py_library( "//tensorflow/contrib/distribute/python:monitor", "//tensorflow/contrib/distribute/python:one_device_strategy", "//tensorflow/contrib/distribute/python:parameter_server_strategy", - "//tensorflow/contrib/distribute/python:step_fn", "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:distribute_config", "//tensorflow/python/distribute:distribute_coordinator", + "//tensorflow/python/distribute:step_fn", 
], ) diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index dbcaf8185fb..ea48cb390b9 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -187,7 +187,7 @@ in the input function gives a solid boost in performance. When using For multi-worker training, no code change is required to the `Estimator` code. You can run the same model code for all tasks in your cluster including parameter servers and the evaluator. But you need to use -`tf.estimator.train_and_evaluate`, explicitly specify `num_gpus_per_workers` +`tf.estimator.train_and_evaluate`, explicitly specify `num_gpus_per_worker` for your strategy object, and set "TF\_CONFIG" environment variables for each binary running in your cluster. We'll provide a Kubernetes template in the [tensorflow/ecosystem](https://github.com/tensorflow/ecosystem) repo which sets diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 59d76f5d1c8..40a0e773978 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -29,7 +29,6 @@ from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrat from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy -from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import initialize_tpu_system from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.distribute.cross_device_ops import * @@ -38,9 +37,13 @@ from tensorflow.python.distribute.distribute_coordinator import run_standard_ten from tensorflow.python.distribute.distribute_lib import * from tensorflow.python.distribute.distribution_strategy_context import * +from tensorflow.python.distribute.step_fn import * from tensorflow.python.util.all_util import remove_undocumented +DistributionStrategy = StrategyV1 + + _allowed_symbols = [ 'AllReduceCrossDeviceOps', 'CollectiveAllReduceStrategy', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 88d6c7a6d27..c5ddf6b5533 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -2,7 +2,6 @@ load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test") -load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( @@ -19,51 +18,16 @@ py_library( name = "distribute_test_lib_pip", visibility = ["//tensorflow:internal"], deps = [ - ":combinations", - ":keras_correctness_test_lib", - ":keras_test_lib", - ":multi_worker_test_base", - ":single_loss_example", - ":strategy_test_lib", - ], -) - -cuda_py_test( - name = "values_test", - srcs = ["values_test.py"], - additional_deps = [ - ":combinations", - ":mirrored_strategy", - "@absl_py//absl/testing:parameterized", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/distribute:device_util", - 
"//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:estimator_py", - ], -) - -cuda_py_test( - name = "input_lib_test", - srcs = ["input_lib_test.py"], - additional_deps = [ - ":combinations", - ":mirrored_strategy", - ":multi_worker_test_base", - "@absl_py//absl/testing:parameterized", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute:input_lib", - "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", + ":keras_multi_worker_test_base", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:model_combinations", + "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute:saved_model_test_base", + "//tensorflow/python/distribute:single_loss_example", + "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:strategy_test_lib", + "//tensorflow/python/keras/distribute:keras_correctness_test_lib", + "//tensorflow/python/keras/distribute:keras_test_lib", ], ) @@ -94,10 +58,13 @@ cuda_py_test( name = "parameter_server_strategy_test", srcs = ["parameter_server_strategy_test.py"], additional_deps = [ - ":combinations", - ":multi_worker_test_base", ":parameter_server_strategy", - ":strategy_test_lib", + "//tensorflow/python/distribute:central_storage_strategy", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:parameter_server_strategy", + "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute:strategy_test_lib", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -131,16 +98,6 @@ py_library( ], ) -cuda_py_test( - name = "one_device_strategy_test", - srcs = ["one_device_strategy_test.py"], - additional_deps = [ - ":strategy_test_lib", - ":combinations", - "//tensorflow/python/eager:test", - ], -) - py_library( name = "collective_all_reduce_strategy", srcs = ["collective_all_reduce_strategy.py"], @@ -152,106 +109,59 @@ py_library( ], ) -py_library( - name = "strategy_test_lib", - srcs = ["strategy_test_lib.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:layers", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/eager:backprop", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - "//third_party/py/numpy", - ], -) - -py_library( - name = "combinations", - srcs = ["combinations.py"], - srcs_version = "PY2AND3", - deps = [ - ":mirrored_strategy", - ":one_device_strategy", - ":parameter_server_strategy", - ":tpu_strategy", - "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", - "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/eager:context", - "//tensorflow/python/keras/optimizer_v2", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "combinations_test", - srcs = ["combinations_test.py"], - deps = [ - ":combinations", - 
"//tensorflow/python/eager:test", - ], -) - -# TODO(priyag): Rename this test to mirrored_strategy_test cuda_py_test( - name = "mirrored_strategy_multigpu_test", - srcs = ["mirrored_strategy_multigpu_test.py"], + name = "contrib_mirrored_strategy_test", + srcs = ["contrib_mirrored_strategy_test.py"], additional_deps = [ - ":combinations", ":mirrored_strategy", - ":multi_worker_test_base", - ":strategy_test_lib", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:layers", - "//tensorflow/python:state_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:values", ], - shard_count = 5, + shard_count = 1, tags = [ "guitar", "multi_and_single_gpu", ], ) -py_library( - name = "multi_worker_test_base", - srcs = ["multi_worker_test_base.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/core:protos_all_py", +cuda_py_test( + name = "keras_multi_worker_correctness_test", + srcs = ["keras_multi_worker_correctness_test.py"], + additional_deps = [ + ":collective_all_reduce_strategy", + ":mirrored_strategy", + ":parameter_server_strategy", + ":keras_multi_worker_test_base", "//tensorflow/python:client_testlib", - "//tensorflow/python:distributed_framework_test_lib", - "//tensorflow/python:session", - "//tensorflow/python:util", - "//tensorflow/python/estimator:estimator_py", - "//third_party/py/numpy", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:distribute_config", + "//tensorflow/python/distribute:distribute_coordinator", + "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/eager:context", + "//tensorflow/python/keras", + ], + tags = [ + "multi_and_single_gpu", + "nomsan", # TODO(b/130299192) ], ) py_library( - name = "step_fn", - srcs = ["step_fn.py"], - visibility = ["//tensorflow:internal"], + name = "keras_multi_worker_test_base", + srcs = ["keras_multi_worker_test_base.py"], deps = [ - "//tensorflow/python:training", - "//tensorflow/python/eager:backprop", + ":collective_all_reduce_strategy", + ":mirrored_strategy", + ":parameter_server_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:multi_worker_test_base", ], ) @@ -270,9 +180,10 @@ cuda_py_test( srcs = ["collective_all_reduce_strategy_test.py"], additional_deps = [ ":collective_all_reduce_strategy", - ":combinations", - ":multi_worker_test_base", - ":strategy_test_lib", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute:strategy_test_lib", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/core:protos_all_py", @@ -295,59 +206,21 @@ cuda_py_test( ], ) -distribute_py_test( - name = "minimize_loss_test", - srcs = ["minimize_loss_test.py"], - main = "minimize_loss_test.py", - tags = [ - "multi_and_single_gpu", - ], - deps = [ - ":combinations", - ":mirrored_strategy", - ":single_loss_example", - "//tensorflow/contrib/tpu:tpu_lib", 
- "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - "//tensorflow/python/ops/losses", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -cuda_py_test( - name = "moving_averages_test", - srcs = ["moving_averages_test.py"], - additional_deps = [ - ":combinations", - "@absl_py//absl/testing:parameterized", - "//tensorflow/python/eager:test", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:training", - "//tensorflow/python:variables", - ], -) - cuda_py_test( name = "optimizer_v2_test", srcs = ["optimizer_v2_test.py"], additional_deps = [ - ":combinations", - ":single_loss_example", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:single_loss_example", + ":mirrored_strategy", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:control_flow_ops", "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/contrib/optimizer_v2:training", ], tags = [ "multi_and_single_gpu", @@ -358,7 +231,8 @@ cuda_py_test( name = "estimator_integration_test", srcs = ["estimator_integration_test.py"], additional_deps = [ - ":combinations", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:strategy_combinations", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/contrib/optimizer_v2:training", @@ -377,28 +251,15 @@ cuda_py_test( ], ) -cuda_py_test( - name = "keras_optimizer_v2_test", - srcs = ["keras_optimizer_v2_test.py"], - additional_deps = [ - ":keras_test_lib", - ], - shard_count = 4, - tags = [ - "multi_and_single_gpu", - "no_oss", # http://b/119349471 - "tf_integration_test", - ], -) - cuda_py_test( name = "estimator_training_test", srcs = ["estimator_training_test.py"], additional_deps = [ ":collective_all_reduce_strategy", - ":combinations", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:strategy_combinations", ":mirrored_strategy", - ":multi_worker_test_base", + "//tensorflow/python/distribute:multi_worker_test_base", ":parameter_server_strategy", "//third_party/py/numpy", "//tensorflow/contrib/optimizer_v2:training", @@ -424,38 +285,6 @@ cuda_py_test( ], ) -py_library( - name = "single_loss_example", - srcs = ["single_loss_example.py"], - deps = [ - ":step_fn", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:layers", - "//tensorflow/python:math_ops", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -distribute_py_test( - name = "step_fn_test", - srcs = ["step_fn_test.py"], - main = "step_fn_test.py", - tags = [ - "multi_and_single_gpu", - ], - deps = [ - ":combinations", - ":single_loss_example", - "//tensorflow/contrib/tpu:tpu_lib", - "//tensorflow/python:variables", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - py_library( name = "monitor", srcs = ["monitor.py"], @@ -470,9 +299,10 @@ cuda_py_test( name = "monitor_test", srcs = ["monitor_test.py"], additional_deps = [ - ":combinations", + 
"//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:strategy_combinations", ":monitor", - ":single_loss_example", + "//tensorflow/python/distribute:single_loss_example", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python/distribute:one_device_strategy", @@ -486,100 +316,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "cross_device_utils_test", - srcs = ["cross_device_utils_test.py"], - additional_deps = [ - ":combinations", - "@absl_py//absl/testing:parameterized", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/distribute:cross_device_utils", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - -cuda_py_test( - name = "cross_device_ops_test", - srcs = ["cross_device_ops_test.py"], - additional_deps = [ - ":collective_all_reduce_strategy", - ":combinations", - ":multi_worker_test_base", - ":mirrored_strategy", - "@absl_py//absl/testing:parameterized", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], - tags = [ - "multi_and_single_gpu", - ], -) - -py_library( - name = "keras_test_lib", - srcs = [ - "keras_backward_compat_test.py", - "keras_test.py", - "keras_utils_test.py", - ], - deps = [ - ":combinations", - "//tensorflow/contrib/distribute/python:mirrored_strategy", - "//tensorflow/contrib/distribute/python:tpu_strategy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/keras", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -distribute_py_test( - name = "keras_test", - srcs = ["keras_test.py"], - full_precision = True, - main = "keras_test.py", - shard_count = 32, - tags = [ - "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_windows_gpu", - "notsan", - ], - deps = [ - ":keras_test_lib", - ], -) - -distribute_py_test( - name = "keras_utils_test", - srcs = ["keras_utils_test.py"], - full_precision = True, - main = "keras_utils_test.py", - shard_count = 32, - tags = [ - "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_windows_gpu", - "notsan", - ], - deps = [ - ":keras_test", - ":keras_test_lib", - ], -) - # TODO(b/121200287): Remove this in 2.0 distribute_py_test( name = "keras_backward_compat_test", @@ -589,208 +325,12 @@ distribute_py_test( shard_count = 31, tags = [ "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. 
"no_windows_gpu", "notsan", ], deps = [ - ":keras_test_lib", - ], -) - -py_library( - name = "keras_correctness_test_lib", - srcs = [ - "keras_correctness_test_base.py", - "keras_dnn_correctness_test.py", - "keras_embedding_model_correctness_test.py", - "keras_image_model_correctness_test.py", - "keras_lstm_model_correctness_test.py", - "keras_stateful_lstm_model_correctness_test.py", - ], - deps = [ - ":combinations", - "//tensorflow/contrib/distribute/python:mirrored_strategy", - "//tensorflow/contrib/distribute/python:tpu_strategy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/keras", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -distribute_py_test( - name = "keras_dnn_correctness_test", - size = "medium", - srcs = ["keras_dnn_correctness_test.py"], - full_precision = True, - main = "keras_dnn_correctness_test.py", - # Shard count is set to an odd number to distribute tasks across - # shards more evenly. - shard_count = 19, - tags = [ - "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_windows_gpu", - "notsan", - ], - deps = [ - ":keras_correctness_test_lib", - ], -) - -distribute_py_test( - name = "keras_image_model_correctness_test", - size = "medium", - srcs = ["keras_image_model_correctness_test.py"], - full_precision = True, - main = "keras_image_model_correctness_test.py", - # Shard count is set to an odd number to distribute tasks across - # shards more evenly. - shard_count = 31, - tags = [ - "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_windows_gpu", - "notsan", - ], - deps = [ - ":keras_correctness_test_lib", - ], -) - -distribute_py_test( - name = "keras_embedding_model_correctness_test", - size = "medium", - srcs = ["keras_embedding_model_correctness_test.py"], - full_precision = True, - main = "keras_embedding_model_correctness_test.py", - # Shard count is set to an odd number to distribute tasks across - # shards more evenly. - shard_count = 31, - tags = [ - "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_windows_gpu", - "notsan", - ], - deps = [ - ":keras_correctness_test_lib", - ], -) - -distribute_py_test( - name = "keras_lstm_model_correctness_test", - size = "medium", - srcs = ["keras_lstm_model_correctness_test.py"], - full_precision = True, - main = "keras_lstm_model_correctness_test.py", - # Shard count is set to an odd number to distribute tasks across - # shards more evenly. - shard_count = 31, - tags = [ - "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. - "no_windows_gpu", - "notsan", - ], - deps = [ - ":keras_correctness_test_lib", - ], -) - -distribute_py_test( - name = "keras_stateful_lstm_model_correctness_test", - size = "medium", - srcs = ["keras_stateful_lstm_model_correctness_test.py"], - full_precision = True, - main = "keras_stateful_lstm_model_correctness_test.py", - # Shard count is set to an odd number to distribute tasks across - # shards more evenly. - shard_count = 31, - tags = [ - "multi_and_single_gpu", - "no_oss", # TODO(b/117919883): Fix python error. 
- "no_pip", - "no_windows_gpu", - "notsan", - ], - deps = [ - ":keras_correctness_test_lib", - ], -) - -distribute_py_test( - name = "metrics_v1_test", - srcs = ["metrics_v1_test.py"], - main = "metrics_v1_test.py", - tags = [ - "multi_and_single_gpu", - ], - deps = [ - ":combinations", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:test", - "@absl_py//absl/testing:parameterized", - ], -) - -cuda_py_test( - name = "warm_starting_util_test", - size = "medium", - srcs = ["warm_starting_util_test.py"], - additional_deps = [ - ":combinations", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], - tags = [ - "multi_and_single_gpu", - ], -) - -cuda_py_test( - name = "checkpoint_utils_test", - size = "medium", - srcs = ["checkpoint_utils_test.py"], - additional_deps = [ - ":combinations", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], - tags = [ - "multi_and_single_gpu", - ], -) - -tf_xla_py_test( - name = "checkpointing_test", - srcs = ["checkpointing_test.py"], - disabled_backends = [ - # Only makes sense on TPUs - "cpu", - "gpu", - "cpu_ondemand", - ], - tags = [ - "no_oss", - ], - deps = [ - ":tpu_strategy", - "//tensorflow/compiler/tests:xla_test", - "//tensorflow/python/eager:test", - "//tensorflow/python/training/tracking:util", + ":mirrored_strategy", + "//tensorflow/python/distribute:tpu_strategy", + "//tensorflow/python/keras/distribute:keras_test_lib", ], ) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index d4f76e3e7b9..5f944e493dc 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -26,7 +26,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolve # TODO(yuefengz): support in-graph replication. -class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): +class CollectiveAllReduceStrategy(distribute_lib.StrategyV1): """Distribution strategy that uses collective ops for all-reduce. 
*** contrib version *** diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index cef3cc60737..d6eff47fdc5 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -22,19 +22,19 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.distribute.python import collective_all_reduce_strategy -from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import collective_all_reduce_strategy as core_collective_all_reduce_strategy +from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import strategy_test_lib from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import context @@ -57,7 +57,7 @@ from tensorflow.python.training import training_util from tensorflow.python.training.server_lib import ClusterSpec -class MockCollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): +class MockCollectiveAllReduceStrategy(distribute_lib.StrategyV1): """Mock the strategy to allow cluster resolver as an argument.""" def __init__(self, cluster_resolver): @@ -176,7 +176,7 @@ class CollectiveAllReduceStrategyTestBase( def update(v, g): return v.assign_sub(0.05 * g, use_locking=True) - one = d.broadcast(constant_op.constant([[1.]])) + one = constant_op.constant([[1.]]) def step(): """Perform one optimization step.""" @@ -266,7 +266,7 @@ class CollectiveAllReduceStrategyTestBase( target=master_target) as sess: with d.scope(): train_op = d.extended.call_for_each_replica(model_fn) - train_op = d.group(d.unwrap(train_op)) + train_op = d.group(d.experimental_local_results(train_op)) sess.run(variables.global_variables_initializer()) sess.run(train_op) @@ -293,8 +293,8 @@ class CollectiveAllReduceStrategyTestBase( return array_ops.identity(x) x = distribution.extended.call_for_each_replica(model_fn) - reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x) - x = distribution.unwrap(x)[0] + reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x, axis=None) + x = distribution.experimental_local_results(x)[0] sess.run(variables.global_variables_initializer()) @@ -312,6 +312,7 @@ class CollectiveAllReduceStrategyTestBase( input_fn, expected_values, test_reinitialize=True, + ignore_order=False, use_core_strategy=False): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) @@ -327,7 +328,10 @@ class CollectiveAllReduceStrategyTestBase( next_element = iterator.get_next() computed_value = sess.run([values.select_replica(r, next_element) for r in 
range(len(devices))]) - self.assertEqual(expected_value, computed_value) + if ignore_order: + self.assertCountEqual(expected_value, computed_value) + else: + self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() @@ -342,7 +346,10 @@ class CollectiveAllReduceStrategyTestBase( next_element = iterator.get_next() computed_value = sess.run([values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertEqual(expected_value, computed_value) + if ignore_order: + self.assertCountEqual(expected_value, computed_value) + else: + self.assertEqual(expected_value, computed_value) class DistributedCollectiveAllReduceStrategyTest( @@ -413,7 +420,6 @@ class DistributedCollectiveAllReduceStrategyTest( num_gpus=num_gpus, use_core_strategy=use_core_strategy) - # TODO(b/124344198): Re-enable after fixing this flaky test. # TODO(yuefengz): Update how we use num_gpus and required_gpus @combinations.generate( combinations.combine( @@ -422,8 +428,7 @@ class DistributedCollectiveAllReduceStrategyTest( required_gpus=1, use_dataset=[True, False], use_core_strategy=[True, False])) - def DISABLED_testMakeInputFnIterator(self, num_gpus, use_dataset, - use_core_strategy): + def testMakeInputFnIterator(self, num_gpus, use_dataset, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -450,6 +455,7 @@ class DistributedCollectiveAllReduceStrategyTest( input_fn, expected_values, test_reinitialize=use_dataset, + ignore_order=not use_dataset, use_core_strategy=use_core_strategy) @combinations.generate( @@ -576,7 +582,7 @@ class LocalCollectiveAllReduceStrategy( required_gpus=2, use_dataset=[True, False], use_core_strategy=[True, False])) - def DISABLED_testMakeInputFnIterator(self, use_dataset, use_core_strategy): + def testMakeInputFnIterator(self, use_dataset, use_core_strategy): num_gpus = 2 if use_dataset: fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) @@ -599,6 +605,7 @@ class LocalCollectiveAllReduceStrategy( input_fn, expected_values, test_reinitialize=use_dataset, + ignore_order=not use_dataset, use_core_strategy=use_core_strategy) @combinations.generate( diff --git a/tensorflow/contrib/distribute/python/contrib_mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/contrib_mirrored_strategy_test.py new file mode 100644 index 00000000000..8642e0cf4e9 --- /dev/null +++ b/tensorflow/contrib/distribute/python/contrib_mirrored_strategy_test.py @@ -0,0 +1,91 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests the contrib MirroredStrategy specific features.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import test + +contrib_mirrored_strategies = [ + combinations.NamedDistribution( + "ContribMirrored1CPU", + lambda: mirrored_strategy.MirroredStrategy(["/cpu:0"])), + combinations.NamedDistribution( + "ContribMirrored1GPU", + lambda: mirrored_strategy.MirroredStrategy(["/gpu:0"]), + required_gpus=1), + combinations.NamedDistribution( + "ContribMirroredCPUAndGPU", + lambda: mirrored_strategy.MirroredStrategy(["/cpu:0", "/gpu:0"]), + required_gpus=1), + combinations.NamedDistribution( + "ContribMirrored2GPU", + lambda: mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), + required_gpus=2), +] + + +def all_strategy_and_eager_plus_graph(): + return combinations.times( + combinations.combine(distribution=contrib_mirrored_strategies), + combinations.combine(mode=["eager", "graph"])) + + +class ContribMirroredStrategyTest(test.TestCase, parameterized.TestCase): + + def _initialize_and_evaluate_iterator(self, iterator): + if context.executing_eagerly(): + iterator.initialize() + res = iterator.get_next() + if isinstance(res, values.PerReplica): + res = res.values + else: + with self.cached_session() as sess: + sess.run(iterator.initialize()) + res = iterator.get_next() + if isinstance(res, values.PerReplica): + res = sess.run(res.values) + else: + res = sess.run(res) + + return res + + @combinations.generate(all_strategy_and_eager_plus_graph()) + def test_dataset_iterator(self, distribution): + data = np.array([[1, 1], [2, 1], [3, 1], [4, 1]]) + dataset = dataset_ops.Dataset.from_tensors(data).repeat() + iterator = distribution.make_dataset_iterator(dataset) + res = self._initialize_and_evaluate_iterator(iterator) + + if isinstance(res, tuple): + self.assertLen(res, 2) + self.assertAllEqual(data, res[0]) + self.assertAllEqual(data, res[1]) + else: + self.assertAllEqual(data, res) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 1ff1e7c1d25..c46616ce60f 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -22,10 +22,10 @@ import shutil import tempfile from absl.testing import parameterized import numpy as np - -from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.optimizer_v2 import adagrad from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import strategy_combinations from tensorflow.python.eager import test from tensorflow.python.estimator import run_config from tensorflow.python.estimator import training @@ -60,11 +60,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, combinations.combine( mode=['graph'], distribution=[ - combinations.one_device_strategy, - 
combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_two_gpus, ], use_train_and_evaluate=[True, False])) def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 1b422ef2d19..9eebdfd68d8 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -28,15 +28,15 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.distribute.python import collective_all_reduce_strategy -from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.contrib.optimizer_v2 import adagrad from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import estimator_training as dc_training +from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.eager import context from tensorflow.python.estimator import exporter as exporter_lib @@ -249,9 +249,25 @@ class DistributeCoordinatorIntegrationTest( ]) self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) - def _get_strategy_object(self, strategy_cls): + def _make_cross_device_ops(self, num_gpus_per_worker): + return cross_device_ops_lib.MultiWorkerAllReduce( + ["/job:worker/task:0", "/job:worker/task:1", "/job:worker/task:2"], + num_gpus_per_worker) + + def _get_strategy_object(self, strategy_cls, eval_strategy=False): if strategy_cls == mirrored_strategy.CoreMirroredStrategy: - return strategy_cls() + if eval_strategy: + return strategy_cls() + else: + return strategy_cls( + cross_device_ops=self._make_cross_device_ops( + num_gpus_per_worker=context.num_gpus())) + elif (strategy_cls == mirrored_strategy.MirroredStrategy and + not eval_strategy): + return strategy_cls( + num_gpus_per_worker=context.num_gpus(), + cross_device_ops=self._make_cross_device_ops( + num_gpus_per_worker=context.num_gpus())) else: return strategy_cls(num_gpus_per_worker=context.num_gpus()) @@ -277,7 +293,8 @@ class DistributeCoordinatorIntegrationTest( train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = self._get_strategy_object(eval_distribute_cls) + eval_distribute = self._get_strategy_object( + eval_distribute_cls, eval_strategy=True) else: eval_distribute = None @@ -307,7 +324,8 @@ class DistributeCoordinatorIntegrationTest( communication=cross_device_ops_lib.CollectiveCommunication.NCCL)) if eval_distribute_class: - eval_distribute = self._get_strategy_object(eval_distribute_class) + eval_distribute = self._get_strategy_object( + eval_distribute_class, 
eval_strategy=True) else: eval_distribute = None @@ -388,7 +406,8 @@ class DistributeCoordinatorIntegrationTest( train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = self._get_strategy_object(eval_distribute_cls) + eval_distribute = self._get_strategy_object( + eval_distribute_cls, eval_strategy=True) else: eval_distribute = None @@ -436,7 +455,8 @@ class DistributeCoordinatorIntegrationTest( train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = self._get_strategy_object(eval_distribute_cls) + eval_distribute = self._get_strategy_object( + eval_distribute_cls, eval_strategy=True) else: eval_distribute = None diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD index 58bede801ff..75fbc3bf53f 100644 --- a/tensorflow/contrib/distribute/python/examples/BUILD +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -13,6 +13,7 @@ exports_files(["LICENSE"]) py_binary( name = "simple_estimator_example", srcs = ["simple_estimator_example.py"], + python_version = "PY2", deps = [ "//tensorflow:tensorflow_py", ], @@ -23,6 +24,7 @@ py_binary( srcs = [ "keras_model_with_estimator.py", ], + python_version = "PY2", deps = [ "//tensorflow:tensorflow_py", "//third_party/py/numpy", @@ -32,6 +34,7 @@ py_binary( py_binary( name = "keras_mnist", srcs = ["keras_mnist.py"], + python_version = "PY2", deps = [":keras_mnist_lib"], ) @@ -51,6 +54,7 @@ py_binary( srcs = [ "mnist_eager_multigpu.py", ], + python_version = "PY2", deps = [ "//tensorflow:tensorflow_py", "//third_party/py/numpy", @@ -62,6 +66,7 @@ py_binary( srcs = [ "mnist_tf1_tpu.py", ], + python_version = "PY2", deps = [ "//tensorflow:tensorflow_py", "//third_party/py/numpy", diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py deleted file mode 100644 index 204f52b034f..00000000000 --- a/tensorflow/contrib/distribute/python/input_lib_test.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for the input_lib library.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized - -from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import input_lib -from tensorflow.python.distribute import values -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import errors -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.util import nest - - -class InputIteratorTestBase(test.TestCase): - - def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, split_batch_by=None): - devices = nest.flatten([ds for _, ds in worker_device_pairs]) - device_map = values.ReplicaDeviceMap(devices) - input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) - - if input_type == "input_fn": - input_contexts = [ - distribute_lib.InputContext() for _ in worker_device_pairs] - input_fn = lambda _: dataset_fn() - iterator = input_lib.InputFunctionIterator( - input_fn, input_workers, input_contexts) - else: - iterator = input_lib.DatasetIterator( - dataset_fn(), input_workers, split_batch_by) - - evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - evaluate([values.select_replica(r, next_element) - for r in range(len(devices))]) - - # After re-initializing the iterator, should be able to iterate again. 
- evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertAllEqual(expected_value, computed_value) - - -class InputIteratorSingleWorkerTest(InputIteratorTestBase, - parameterized.TestCase): - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"])) - def testOneDeviceCPU(self, input_type): - worker_device_pairs = [("", ["/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesOneGPUOneCPU(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTupleDataset(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testUnevenDatasetBatches(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["dataset"], - split_batch_by=[None, 2], - required_gpus=1)) - def testBatchSplitting(self, input_type, split_batch_by): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - batch_size = 10 - dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) - - updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) - expected_values = [[range(i, i+updated_batch_size), - range(i+updated_batch_size, i+2*updated_batch_size)] - for i in range(0, 100, updated_batch_size*2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, - split_batch_by=split_batch_by) - - -class InputIteratorMultiWorkerTest( - multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, - parameterized.TestCase): - - def _cpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])] - - def _cpu_and_one_gpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - ]), - 
("/job:worker/replica:0/task:1", [ - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testOneDevicePerWorker(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesPerWorker(self, input_type): - worker_devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 1, 0, 1], [2, 3, 2, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testTupleDataset(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(4) - dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py index c49b5522f91..a134b124744 100644 --- a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py +++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py @@ -19,17 +19,18 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np - -from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import test from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils -from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.distribute import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops.parsing_ops import gen_parsing_ops @@ -163,7 +164,8 @@ def batch_wrapper(dataset, batch_size, distribution, repeat=None): dataset = dataset.repeat(repeat) # TPUs currently require fully defined input shapes, drop_remainder ensures # the input will have fully defined shapes. 
- if isinstance(distribution, tpu_strategy.TPUStrategy): + if isinstance(distribution, (tpu_strategy.TPUStrategy, + tpu_strategy.TPUStrategyV1)): return dataset.batch(batch_size, drop_remainder=True) else: return dataset.batch(batch_size) @@ -289,16 +291,16 @@ def get_correctness_test_inputs(use_numpy, use_validation_data, strategies_minus_tpu = [ - combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus] + strategy_combinations.default_strategy, + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_two_gpus, +] tpu_strategies = [ - combinations.tpu_strategy, # steps_per_run=2 - combinations.tpu_strategy_one_step] + strategy_combinations.tpu_strategy, # steps_per_run=2 + strategy_combinations.tpu_strategy_one_step +] def strategy_minus_tpu_combinations(): @@ -321,14 +323,14 @@ def strategy_and_optimizer_combinations(): return combinations.times( all_strategy_combinations(), combinations.combine(optimizer=[ - combinations.adagrad_optimizer_v1_fn, - combinations.adagrad_optimizer_keras_v2_fn, - combinations.adam_optimizer_v1_fn, - combinations.adam_optimizer_keras_v2_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_keras_v2_fn, - combinations.rmsprop_optimizer_v1_fn, - combinations.rmsprop_optimizer_keras_v2_fn + strategy_combinations.adagrad_optimizer_v1_fn, + strategy_combinations.adagrad_optimizer_keras_v2_fn, + strategy_combinations.adam_optimizer_v1_fn, + strategy_combinations.adam_optimizer_keras_v2_fn, + strategy_combinations.gradient_descent_optimizer_v1_fn, + strategy_combinations.gradient_descent_optimizer_keras_v2_fn, + strategy_combinations.rmsprop_optimizer_v1_fn, + strategy_combinations.rmsprop_optimizer_keras_v2_fn ])) @@ -355,6 +357,7 @@ def strategy_for_numpy_input_combinations(): mode=['graph']) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): @@ -463,6 +466,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, self.assertAllEqual([6, 7], outs[1].shape) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): @@ -529,11 +533,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, # as clone_model's input_tensors argument only seems to accept list and not # tuples or dict. 
- @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + mode=['graph', 'eager'])) def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): with self.cached_session(): model = multi_input_output_model() @@ -614,11 +619,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.evaluate(dataset, steps=2, verbose=1) model.predict(dataset, steps=2) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + mode=['graph', 'eager'])) # TODO(b/120943676, b/120957836): Re-enable once the validation code is # restored. def DISABLED_test_dataset_wrong_input_shape(self, distribution): @@ -640,9 +646,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - @combinations.generate(combinations.combine( - distribution=[combinations.mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) # TODO(b/120943676, b/120957836): Re-enable once the validation code is # restored. def DISABLED_test_dataset_no_batch_input_validation(self, distribution): @@ -662,32 +671,13 @@ class TestDistributionStrategyWithDatasets(test.TestCase, with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - @combinations.generate(combinations.combine( - distribution=[combinations.tpu_strategy_one_step], - mode=['graph'])) - def test_dataset_input_shape_fully_defined(self, distribution): - with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) - - dataset = get_dataset(distribution) - # Input shapes are not fully known. Batch dimension is unknown as we are - # not using the drop_remainder argument. - dataset = dataset.repeat(100).batch(10) - - with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph', 'eager'])) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_two_gpus, + ], + mode=['graph', 'eager'])) def test_learning_phase_value(self, distribution): # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare # meaningful values. 
Currently we don't pass the learning phase if the @@ -746,7 +736,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - grouped_models = distribution.unwrap( + grouped_models = distribution.experimental_local_results( distributed_training_utils.get_distributed_model( model, ModeKeys.TRAIN)) with distribution.scope(): @@ -755,13 +745,15 @@ class TestDistributionStrategyWithDatasets(test.TestCase, m.optimizer.lr), atol=1e-05, rtol=1e-05) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + mode=['graph', 'eager'])) def test_unsupported_features(self, distribution): with self.cached_session(): model = get_model() @@ -811,11 +803,12 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): '`steps` argument'): model.predict(dataset, verbose=0) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + mode=['graph', 'eager'])) def test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): model = get_model() @@ -842,16 +835,18 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): callbacks=[keras.callbacks.ReduceLROnPlateau()]) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyWithLossMasking(test.TestCase, parameterized.TestCase): # TODO(priyag): Enable all strategies for this test. Currently it does not # work for TPU due to some invalid datatype. - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + mode=['graph', 'eager'])) def test_masking(self, distribution): with self.cached_session(): np.random.seed(1337) @@ -872,6 +867,7 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, self.assertEqual(hist.history['loss'][0], 0) +@test_util.run_v1_only('model.compile(..distribute=..) only works in TF v1') class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): @@ -904,6 +900,7 @@ class TestDistributionStrategyWithNormalizationLayer( np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) +@test_util.run_v1_only('model.compile(..distribute=..) 
only works in TF v1') class TestDistributionStrategyCorrectness(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py deleted file mode 100644 index 61202e30c4f..00000000000 --- a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Correctness tests for tf.keras DNN model using DistributionStrategy.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import keras_correctness_test_base -from tensorflow.python import keras -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import test -from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras -from tensorflow.python.training import gradient_descent - - -def all_strategy_combinations_with_eager_and_graph_modes(): - return combinations.combine(distribution=keras_correctness_test_base. - all_strategies, - mode=['graph', 'eager']) - - -def all_strategy_combinations_with_graph_mode(): - return combinations.combine(distribution=keras_correctness_test_base. - all_strategies, mode=['graph']) - - -class TestDistributionStrategyDnnCorrectness( - keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): - - def get_model(self, initial_weights=None, distribution=None): - with keras_correctness_test_base.MaybeDistributionScope(distribution): - # We add few non-linear layers to make it non-trivial. - model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense( - 10, activation='relu', - kernel_regularizer=keras.regularizers.l2(1e-4))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) - - if initial_weights: - model.set_weights(initial_weights) - - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent_keras.SGD(0.5), - metrics=['mse']) - return model - - def get_data(self): - # TODO(xiejw): Change this back to 10000, once we support final partial - # batch. - num_samples = 9984 - x_train = np.random.rand(num_samples, 1) - y_train = 3 * x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32) - return x_train, y_train, x_predict - - @combinations.generate(keras_correctness_test_base. 
- all_strategy_and_input_config_combinations()) - def test_dnn_correctness(self, distribution, use_numpy, use_validation_data): - self.run_correctness_test(distribution, use_numpy, use_validation_data) - - @combinations.generate(all_strategy_combinations_with_graph_mode()) - def test_dnn_with_dynamic_learning_rate(self, distribution): - self.run_dynamic_lr_test(distribution) - - -class TestDistributionStrategyDnnMetricCorrectness( - keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): - - def get_model(self, distribution=None): - with distribution.scope(): - model = keras.Sequential() - model.add(keras.layers.Dense(1, - input_shape=(1,), - kernel_initializer='ones')) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), - metrics=[keras.metrics.BinaryAccuracy()]) - return model - - def run_metric_correctness_test(self, distribution): - with self.cached_session(): - self.set_up_test_config() - self.skip_unsupported_test_configuration(distribution) - - x_train, y_train, _ = self.get_data() - model = self.get_model(distribution=distribution) - - batch_size = 64 - batch_size = (keras_correctness_test_base. - get_batch_size(batch_size, distribution)) - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = (keras_correctness_test_base. - batch_wrapper(train_dataset, batch_size, distribution)) - - history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) - - @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) - def test_simple_dnn_metric_correctness(self, distribution): - self.run_metric_correctness_test(distribution) - - -class TestDistributionStrategyDnnMetricEvalCorrectness( - keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): - - def get_model(self, distribution=None): - with distribution.scope(): - model = keras.Sequential() - model.add( - keras.layers.Dense( - 3, activation='relu', input_dim=4, kernel_initializer='ones')) - model.add( - keras.layers.Dense( - 1, activation='sigmoid', kernel_initializer='ones')) - model.compile( - loss='mae', - metrics=['accuracy', keras.metrics.BinaryAccuracy()], - optimizer=gradient_descent.GradientDescentOptimizer(0.001)) - return model - - def run_eval_metrics_correctness_test(self, distribution): - with self.cached_session(): - self.set_up_test_config() - self.skip_unsupported_test_configuration(distribution) - - model = self.get_model(distribution=distribution) - - # verify correctness of stateful and stateless metrics. - x = np.ones((100, 4)).astype('float32') - y = np.ones((100, 1)).astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() - dataset = (keras_correctness_test_base. - batch_wrapper(dataset, 4, distribution)) - outs = model.evaluate(dataset, steps=10) - self.assertEqual(outs[1], 1.) - self.assertEqual(outs[2], 1.) - - y = np.zeros((100, 1)).astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() - dataset = (keras_correctness_test_base. - batch_wrapper(dataset, 4, distribution)) - outs = model.evaluate(dataset, steps=10) - self.assertEqual(outs[1], 0.) - self.assertEqual(outs[2], 0.) 
- - @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) - def test_identity_model_metric_eval_correctness(self, distribution): - self.run_eval_metrics_correctness_test(distribution) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/distribute/python/keras_multi_worker_correctness_test.py b/tensorflow/contrib/distribute/python/keras_multi_worker_correctness_test.py new file mode 100644 index 00000000000..1223206a497 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_multi_worker_correctness_test.py @@ -0,0 +1,228 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for accuracy and mathematical correctness of tf.keras multi-worker.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl.testing import parameterized +import numpy as np + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy as collective_strategy +from tensorflow.contrib.distribute.python import keras_multi_worker_test_base +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver +from tensorflow.python.framework import ops +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.platform import test + + +np.random.seed(99) +EMBED_INPUTS = np.random.randint(0, 10, (6400, 1)).astype(np.int32) +EMBED_TARGETS = np.random.normal(0, 0.1, (6400, 1)).astype(np.float32) +IMAGE_INPUTS = np.random.normal(0, 0.1, (6400, 28, 28, 3)).astype(np.float32) +IMAGE_TARGETS = np.random.randint(0, 10, (6400, 1)) +LSTM_INPUTS = np.random.normal(0, 0.1, (6400, 10, 20)).astype(np.float32) +LSTM_TARGETS = np.random.normal(0, 0.1, (6400, 1)).astype(np.float32) + + +def get_num_workers(): + cluster_resolver = TFConfigClusterResolver() + cluster_spec = cluster_resolver.cluster_spec().as_dict() + if cluster_spec: + task_type = cluster_resolver.task_type + return int(multi_worker_util.worker_count(cluster_spec, task_type)) + return 1 + + +class Bias(keras.layers.Layer): + + def build(self, input_shape): + self.bias = self.add_weight(shape=(), initializer='zeros', name='bias') + + def call(self, inputs): + return inputs + self.bias + + +class SimpleBiasTest( + keras_multi_worker_test_base.KerasIndependentWorkerTestBase, + parameterized.TestCase): + + @keras_multi_worker_test_base.run_sync_strategies + def test_multi_worker_simple_bias_fit(self, strategy_cls): + + def _worker_fn(results_without_ds=None): + # Make sure Session is cleared at the start of each run. 
+ keras.backend._SESSION.session = None + + x = ops.convert_to_tensor([[0.], [1.], [2.], [0.], [1.], [2.], [0.], + [1.]]) + y = ops.convert_to_tensor([[0.5], [2.], [3.5], [0.5], [2.], [3.5], [0.5], + [2.]]) + ds = dataset_ops.Dataset.from_tensor_slices((x, y)) + ds = ds.batch(8) + model = keras.Sequential([Bias(input_shape=(1,))]) + model.compile( + keras.optimizer_v2.gradient_descent.SGD(0.1), 'mae', metrics=['mae']) + history = model.fit(ds, epochs=5) + self.assertAllClose(history.history['loss'], + [0.9375, 0.8375, 0.7375, 0.6375, 0.5375]) + self.assertAllClose(history.history['mean_absolute_error'], + [0.9375, 0.8375, 0.7375, 0.6375, 0.5375]) + + results = {'training': history.history} + if results_without_ds: + for key in results: + self.assertAllClose( + results[key], + results_without_ds[key], + msg='Fail to assert {}'.format(key)) + + return results + + results_without_ds = _worker_fn() + self.run_independent_workers( + _worker_fn, + strategy_cls, + num_workers=2, + results_without_ds=results_without_ds) + + +def make_image_model(initial_weights=None): + image = keras.layers.Input(shape=(28, 28, 3), name='image') + c1 = keras.layers.Conv2D( + name='conv1', + filters=16, + kernel_size=(3, 3), + strides=(4, 4), + kernel_regularizer=keras.regularizers.l2(1e-4))( + image) + c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1) + c1 = keras.layers.Flatten()(c1) + logits = keras.layers.Dense(10, activation='softmax', name='pred')(c1) + model = keras.Model(inputs=[image], outputs=[logits]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + 'sgd', + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + + return model, IMAGE_INPUTS, IMAGE_TARGETS + + +def make_lstm_model(initial_weights=None): + inputs = keras.layers.Input(shape=(10, 20)) + rnn_out = keras.layers.LSTM(4)(inputs) + outputs = keras.layers.Dense(1)(rnn_out) + model = keras.Model(inputs, outputs) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + gradient_descent.SGD(0.1), + 'sparse_categorical_crossentropy', + metrics=['sparse_categorical_crossentropy']) + + return model, LSTM_INPUTS, LSTM_TARGETS + + +def make_embedding_model(initial_weights=None): + inputs = keras.layers.Input(shape=(1,), dtype='int32') + embeddings = keras.layers.Embedding(100, 5)(inputs) + outputs = keras.layers.Dense(1, activation='softmax')(embeddings) + model = keras.Model(inputs, outputs) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile('rmsprop', 'mae', metrics=['binary_crossentropy']) + + return model, EMBED_INPUTS, EMBED_TARGETS + + +class ModelCorrectnessTest( + keras_multi_worker_test_base.KerasIndependentWorkerTestBase, + parameterized.TestCase): + + def make_dataset(self, inputs, targets, batch_size=64): + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.batch(batch_size) + return dataset + + @combinations.generate( + combinations.combine( + mode=['graph'], + strategy_cls=[ + collective_strategy.CollectiveAllReduceStrategy, + ], + make_model=[make_image_model, make_lstm_model, make_embedding_model], + required_gpus=[0, 1])) + def test_correctness(self, strategy_cls, make_model): + + def _worker_fn(initial_weights=None, results_without_ds=None): + # Make sure Session is cleared at each run + # so that it can be configured properly for the DistributionStrategy. 
+ keras.backend._SESSION.session = None + + results = {} + model, inputs, targets = make_model(initial_weights) + + data = self.make_dataset(inputs, targets) + + # TODO(b/129363441): Remove `steps_per_epoch`. + results['training'] = model.fit( + data, steps_per_epoch=50, epochs=2).history + results['trained_weights'] = model.get_weights() + + eval_data = self.make_dataset(inputs, targets) + results['evaluation'] = model.evaluate(eval_data, steps=50) + + if results_without_ds: + for key in results: + self.assertAllClose( + results[key], + results_without_ds[key], + rtol=1e-5, + atol=1e-5, + msg='Fail to assert {}'.format(key)) + + return results + + model, _, _ = make_model() + initial_weights = model.get_weights() + results_without_ds = _worker_fn(initial_weights=initial_weights) + self.run_independent_workers( + _worker_fn, + strategy_cls, + num_workers=2, + initial_weights=initial_weights, + results_without_ds=results_without_ds) + + +if __name__ == '__main__': + with test.mock.patch.object(sys, 'exit', os._exit): + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_multi_worker_test_base.py b/tensorflow/contrib/distribute/python/keras_multi_worker_test_base.py new file mode 100644 index 00000000000..324b10fdae1 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_multi_worker_test_base.py @@ -0,0 +1,103 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test base for tf.keras Models in multi-worker mode.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy as collective_strategy +from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import distribute_coordinator as dc +from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python.eager import context +from tensorflow.python.platform import test + +_original_run_std_server = dc._run_std_server # pylint: disable=protected-access + +# Used as a decorator on test methods. +run_sync_strategies = combinations.generate( + combinations.combine( + mode=['graph'], + strategy_cls=[ + collective_strategy.CollectiveAllReduceStrategy, + ], + required_gpus=[0, 1])) + +# Used as a decorator on test methods. 
+run_async_strategies = combinations.generate( + combinations.combine( + mode=['graph'], + strategy_cls=[parameter_server_strategy.ParameterServerStrategy], + required_gpus=[0, 1])) + + +def get_strategy_object(strategy_cls): + return strategy_cls(num_gpus_per_worker=context.num_gpus()) + + +# TODO(omalleyt): Merge with keras_multiworker_callback_test +class KerasIndependentWorkerTestBase( + multi_worker_test_base.IndependentWorkerTestBase): + """Test base for simulating Keras Multi-Worker in threads.""" + + def _make_mock_run_std_server(self): + thread_local = threading.local() + + def _mock_run_std_server(*args, **kwargs): + ret = _original_run_std_server(*args, **kwargs) + # Wait for all std servers to be brought up in order to reduce the chance + # of remote sessions taking local ports that have been assigned to std + # servers. Only call this barrier the first time this function is run for + # each thread. + if not getattr(thread_local, 'server_started', False): + self._barrier.wait() + thread_local.server_started = True + return ret + + return _mock_run_std_server + + def run_independent_workers(self, + worker_fn, + strategy_cls, + num_workers, + num_ps=None, + **kwargs): + cluster_spec = multi_worker_test_base.create_cluster_spec( + num_workers=num_workers, num_ps=num_ps) + self._barrier = dc._Barrier(num_workers + (num_ps or 0)) # pylint: disable=protected-access + + def _worker_fn(**kwargs): + """Runs the worker function in a thread.""" + with test.mock.patch.object(dc, '_run_std_server', + self._make_mock_run_std_server()): + strategy = get_strategy_object(strategy_cls) + with strategy.scope(): + return worker_fn(**kwargs) + + threads = self.run_multiple_tasks_in_threads(_worker_fn, cluster_spec, + **kwargs) + strategy = get_strategy_object(strategy_cls) + if strategy.extended.experimental_between_graph: + threads_to_join = threads.get('chief', []) + threads.get('worker', []) + else: + threads_to_join = [ + threads['chief'][0] if 'chief' in threads else threads['worker'][0] + ] + self.join_independent_workers(threads_to_join) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 5391e083fc9..fe88d431dca 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -32,7 +32,7 @@ CoreMirroredExtended = mirrored_strategy.MirroredExtended # pylint: enable=protected-access,invalid-name -class MirroredStrategy(distribute_lib.DistributionStrategy): +class MirroredStrategy(distribute_lib.StrategyV1): """Mirrors vars to distribute across multiple devices and machines. *** contrib version *** @@ -127,36 +127,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): """ return super(MirroredStrategy, self).make_dataset_iterator(dataset) - # Override to change the documentation to reflect the different handling of - # global vs. local batch size between core and contrib. - def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation - self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): - """Makes an iterator for input provided via a nest of numpy arrays. - - NOTE: The `batch_size` argument here has different behavior for this - contrib version of `MirroredStrategy`. - - Args: - numpy_input: A nest of NumPy input arrays that will be distributed evenly - across all replicas. 
- batch_size: The number of entries from the array we should consume in one - step of the computation, across all replicas. This is the per-replica - batch size. The global batch size will be this times - `num_replicas_in_sync`. - num_epochs: The number of times to iterate through the examples. A value - of `None` means repeat forever. - shuffle: Size of buffer to use for shuffling the input examples. - Use `None` to disable shuffling. - session: (TensorFlow v1.x graph execution only) A session used for - initialization. - - Returns: - An `tf.distribute.InputIterator` which returns inputs for each step of the - computation. User should call `initialize` on the returned iterator. - """ - return super(MirroredStrategy, self).experimental_make_numpy_iterator( - numpy_input, batch_size, num_epochs, shuffle, session) - class MirroredExtended(CoreMirroredExtended): """Implementation of (contrib) MirroredStrategy.""" @@ -188,7 +158,8 @@ class MirroredExtended(CoreMirroredExtended): Returns: An `InputIterator` which returns inputs for each step of the computation. """ - return input_lib.DatasetIterator(dataset, self._input_workers) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._container_strategy()) # TODO(priyag): Delete this once all strategies use global batch size. @property diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index c0651610caf..397ce8743d8 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -20,12 +20,12 @@ from __future__ import print_function from absl.testing import parameterized import numpy - -from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import monitor as monitor_lib -from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example from tensorflow.python.client import session +from tensorflow.python.distribute import combinations from tensorflow.python.distribute import one_device_strategy +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute.single_loss_example import single_loss_example from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import ops @@ -36,8 +36,9 @@ class MonitorTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.times( - combinations.distributions_and_v1_optimizers(), - combinations.combine(mode=combinations.graph_and_eager_modes))) + strategy_combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=strategy_combinations.graph_and_eager_modes))) def testTrainNetwork(self, distribution, optimizer_fn): with distribution.scope(): single_loss_step, layer = single_loss_example(optimizer_fn, distribution) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 13a501394ee..6ae847c2938 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -20,4 +20,4 @@ from __future__ import print_function from tensorflow.python.distribute import one_device_strategy -OneDeviceStrategy = one_device_strategy.OneDeviceStrategy +OneDeviceStrategy = one_device_strategy.OneDeviceStrategyV1 diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py 
b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index e388061b17a..df5e5595ccb 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -20,20 +20,53 @@ from __future__ import print_function from absl.testing import parameterized import numpy - -from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib +from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 +from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute.single_loss_example import minimize_loss_example from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables +mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution( + "MirroredCPUAndGPU", + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]), + required_gpus=1) +mirrored_strategy_with_two_gpus = combinations.NamedDistribution( + "Mirrored2GPUs", + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]), + required_gpus=2) + +# pylint: disable=g-long-lambda +gradient_descent_optimizer_v2_fn = combinations.NamedObject( + "GradientDescentV2", lambda: gradient_descent_v2.GradientDescentOptimizer( + 0.2)) +adagrad_optimizer_v2_fn = combinations.NamedObject( + "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) + +optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn] + + +def distributions_and_v2_optimizers(): + """DistributionStrategies and V2 Optimizers.""" + return combinations.combine( + distribution=[ + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_two_gpus, + ], + optimizer_fn=optimizers_v2) + + class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.times( - combinations.distributions_and_v2_optimizers(), + distributions_and_v2_optimizers(), combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True]))) def testTrainNetwork(self, distribution, optimizer_fn, @@ -45,7 +78,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): def run_step(): return control_flow_ops.group( - distribution.unwrap( + distribution.experimental_local_results( distribution.extended.call_for_each_replica( model_fn, args=(iterator.get_next(),)))) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index be863322256..a5fead9596d 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -31,7 +31,7 @@ CoreParameterServerExtended = parameter_server_strategy.ParameterServerStrategyE # pylint: enable=protected-access,invalid-name,line-too-long -class ParameterServerStrategy(distribute_lib.DistributionStrategy): +class ParameterServerStrategy(distribute_lib.StrategyV1): """A parameter server DistributionStrategy. 
*** contrib version *** @@ -61,8 +61,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra caution needs to be taken: - 1) Always use `tf.get_variable` instead of `tf.Variable` which is not able - to refer to the same variable on different replicas. + 1) Always use `tf.compat.v1.get_variable` instead of `tf.Variable` which + is not able to refer to the same variable on different replicas. 2) It is generally not recommended to open a device scope under the strategy's scope. A device scope (i.e. calling `tf.device`) will be merged with or @@ -70,9 +70,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): variables. 3) It is also not recommended to open a colocation scope (i.e. calling - `tf.colocate_with`) under the strategy's scope. For colocating variables, use - `strategy.extended.colocate_vars_with` instead. Colocation of ops will - possibly create conflicts of device assignment. + `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating + variables, use `strategy.extended.colocate_vars_with` instead. Colocation of + ops will possibly create conflicts of device assignment. """ def __init__(self, num_gpus_per_worker=0): @@ -114,37 +114,6 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): """ return super(ParameterServerStrategy, self).make_dataset_iterator(dataset) - # Override to change the documentation to reflect the different handling of - # global vs. local batch size between core and contrib. - def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation - self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): - """Makes an iterator for input provided via a nest of numpy arrays. - - NOTE: The `batch_size` argument here has different behavior for this - contrib version of `ParameterServerStrategy`. - - Args: - numpy_input: A nest of NumPy input arrays that will be distributed evenly - across all replicas. - batch_size: The number of entries from the array we should consume in one - step of the computation, across all replicas. This is the per-replica - batch size. The global batch size will be this times - `num_replicas_in_sync`. - num_epochs: The number of times to iterate through the examples. A value - of `None` means repeat forever. - shuffle: Size of buffer to use for shuffling the input examples. - Use `None` to disable shuffling. - session: (TensorFlow v1.x graph execution only) A session used for - initialization. - - Returns: - An `tf.distribute.InputIterator` which returns inputs for each step of the - computation. User should call `initialize` on the returned iterator. - """ - return super(ParameterServerStrategy, - self).experimental_make_numpy_iterator( - numpy_input, batch_size, num_epochs, shuffle, session) - class ParameterServerExtended(CoreParameterServerExtended): """Implementation of ParameterServerStrategy.""" @@ -163,7 +132,8 @@ class ParameterServerExtended(CoreParameterServerExtended): container_strategy, cluster_resolver=cluster_resolver) def _make_dataset_iterator(self, dataset): - return input_lib.DatasetIterator(dataset, self._input_workers) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._container_strategy()) # TODO(priyag): Delete this once all strategies use global batch size. 
@property diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 18f6904959d..da3cd4843be 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -21,19 +21,18 @@ from __future__ import print_function import copy import threading from absl.testing import parameterized - -from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy -from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import central_storage_strategy +from tensorflow.python.distribute import combinations from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import parameter_server_strategy as core_parameter_server_strategy from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import strategy_test_lib from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import backprop @@ -53,7 +52,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import training_util -from tensorflow.python.training.server_lib import ClusterSpec CHIEF = run_config.TaskType.CHIEF WORKER = run_config.TaskType.WORKER @@ -67,15 +65,6 @@ def _get_replica_id_integer(): return replica_id -class MockCoreParameterServerStrategy(distribute_lib.DistributionStrategy): - """Mock the strategy to allow cluster resolver as an argument.""" - - def __init__(self, cluster_resolver): - super(MockCoreParameterServerStrategy, self).__init__( - core_parameter_server_strategy.ParameterServerStrategyExtended( - self, cluster_resolver=cluster_resolver)) - - def create_test_objects(cluster_spec=None, task_type=None, task_id=None, @@ -92,13 +81,15 @@ def create_test_objects(cluster_spec=None, task_type=task_type, task_id=task_id, num_accelerators={'GPU': num_gpus}) + distribution = core_parameter_server_strategy.ParameterServerStrategy( + cluster_resolver) target = 'grpc://' + cluster_spec[WORKER][task_id] else: - cluster_resolver = SimpleClusterResolver( - ClusterSpec({}), num_accelerators={'GPU': num_gpus}) + distribution = ( + central_storage_strategy.CentralStorageStrategy._from_num_gpus( + num_gpus)) target = '' - distribution = MockCoreParameterServerStrategy(cluster_resolver) sess_config = copy.deepcopy(sess_config) sess_config = distribution.update_config_proto(sess_config) else: @@ -441,7 +432,8 @@ class ParameterServerStrategyTestBase( x, y, z, train_op = d.extended.call_for_each_replica(model_fn) train_op = d.group(train_op) - if context.num_gpus() < d.extended._num_gpus_per_worker: + if context.num_gpus() < sum( + 1 for d in d.extended.worker_devices if 'GPU' in d.upper()): return True if task_id == 0: @@ -514,7 +506,7 
@@ class ParameterServerStrategyTestBase( def update(v, g): return v.assign_sub(0.05 * g, use_locking=True) - one = d.broadcast(constant_op.constant([[1.]])) + one = constant_op.constant([[1.]]) def step(): """Perform one optimization step.""" @@ -537,7 +529,8 @@ class ParameterServerStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d.extended._num_gpus_per_worker: + if context.num_gpus() < sum( + 1 for d in d.extended.worker_devices if 'GPU' in d.upper()): return True if (not task_type or @@ -572,6 +565,7 @@ class ParameterServerStrategyTestBase( input_fn, expected_values, test_reinitialize=True, + ignore_order=False, use_core_strategy=False): distribution, master_target, config = self._get_test_objects( task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) @@ -587,7 +581,10 @@ class ParameterServerStrategyTestBase( next_element = iterator.get_next() computed_value = sess.run([values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertEqual(expected_value, computed_value) + if ignore_order: + self.assertCountEqual(expected_value, computed_value) + else: + self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() @@ -602,7 +599,10 @@ class ParameterServerStrategyTestBase( next_element = iterator.get_next() computed_value = sess.run([values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertEqual(expected_value, computed_value) + if ignore_order: + self.assertCountEqual(expected_value, computed_value) + else: + self.assertEqual(expected_value, computed_value) class ParameterServerStrategyTest( @@ -696,7 +696,6 @@ class ParameterServerStrategyTest( def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) - # TODO(b/124344198): Re-enable after fixing this flaky test. # TODO(priyag): Refactor this and other multi worker tests. @combinations.generate( combinations.combine( @@ -705,7 +704,7 @@ class ParameterServerStrategyTest( required_gpus=1, use_core_strategy=[True, False], use_dataset=[True, False])) - def DISABLED_testMakeInputFnIteratorDistributed( + def testMakeInputFnIteratorDistributed( self, num_gpus, use_core_strategy, use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') @@ -731,9 +730,9 @@ class ParameterServerStrategyTest( input_fn, expected_values, test_reinitialize=use_dataset, + ignore_order=not use_dataset, use_core_strategy=use_core_strategy) - # TODO(b/124344198): Re-enable after fixing this flaky test. 
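For reference, a plain-Python illustration of the replacement GPU check above: it counts GPU entries in the strategy's public `worker_devices` list instead of reading the private `_num_gpus_per_worker` attribute (the device strings below are made up for the example):

```python
worker_devices = (
    '/job:worker/replica:0/task:0/device:GPU:0',
    '/job:worker/replica:0/task:0/device:GPU:1',
    '/job:worker/replica:0/task:0/device:CPU:0',
)
num_worker_gpus = sum(1 for dev in worker_devices if 'GPU' in dev.upper())
assert num_worker_gpus == 2
```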
@combinations.generate( combinations.combine( mode=['graph'], @@ -741,8 +740,8 @@ class ParameterServerStrategyTest( required_gpus=1, use_core_strategy=[True, False], use_dataset=[True, False])) - def DISABLED_testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, - use_dataset): + def testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, + use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -767,6 +766,7 @@ class ParameterServerStrategyTest( input_fn, expected_values, test_reinitialize=use_dataset, + ignore_order=not use_dataset, use_core_strategy=use_core_strategy) @combinations.generate( @@ -779,9 +779,11 @@ class ParameterServerStrategyTest( combinations.combine(mode=['graph'], use_core_strategy=[True, False])) def testUpdateConfigProtoMultiWorker(self, use_core_strategy): strategy, _, _ = create_test_objects( - num_gpus=2, use_core_strategy=use_core_strategy) - strategy.configure( - cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + cluster_spec=self._cluster_spec, + task_type='worker', + task_id=1, + num_gpus=2, + use_core_strategy=use_core_strategy) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) @@ -806,31 +808,37 @@ class ParameterServerStrategyTest( # Verify isolate_session_state self.assertTrue(new_config.isolate_session_state) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceSum(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_sum(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceSumGradients(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_sum_gradients(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceSumGradientTape(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_sum_gradient_tape(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceMean(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_mean(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceMeanGradients(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) self._test_all_reduce_mean_gradients(distribution) + @combinations.generate(combinations.combine(required_gpus=[2])) def testAllReduceMeanGradientTape(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=2) @@ -918,16 +926,16 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, strategy.extended.call_for_each_replica(f) -class LocalParameterServerStrategyTest(strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class CentralStorageStrategyTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): @combinations.generate(combinations.combine(mode=['graph', 'eager'], use_core_strategy=[True, False], required_gpus=2)) - def testNumpyIterator(self, use_core_strategy): + def testNumpyDataset(self, use_core_strategy): strategy, _, _ = create_test_objects( num_gpus=2, use_core_strategy=use_core_strategy) - self._test_numpy_iterator(strategy) + self._test_numpy_dataset(strategy) if __name__ == '__main__': diff --git 
a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 04e0af767bf..88f97dd9226 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -22,5 +22,5 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import -from tensorflow.python.distribute.tpu_strategy import TPUStrategy +from tensorflow.python.distribute.tpu_strategy import TPUStrategyV1 as TPUStrategy from tensorflow.python.tpu.tpu_strategy_util import initialize_tpu_system diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index c2300286d3b..e4b7b81d083 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -294,7 +294,10 @@ cuda_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", ], - tags = ["nomsan"], # disable to avoid false positives from scipy. + tags = [ + "nomsan", # disable to avoid false positives from scipy. + "notap", # TODO(b/130421237) + ], ) cuda_py_test( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py index d5b3367f9a3..1b88c1d130a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -41,10 +41,10 @@ class ScaleTriLBijectorTest(test.TestCase): diag_shift=shift) y_ = self.evaluate(b.forward(x)) - self.assertAllClose(y, y_) + self.assertAllClose(y, y_, rtol=1e-4) x_ = self.evaluate(b.inverse(y)) - self.assertAllClose(x, x_) + self.assertAllClose(x, x_, rtol=1e-4) @test_util.run_in_graph_and_eager_modes def testInvertible(self): @@ -52,18 +52,18 @@ class ScaleTriLBijectorTest(test.TestCase): # Generate random inputs from an unconstrained space, with # event size 6 to specify 3x3 triangular matrices. 
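The "event size 6" comment relies on the triangular-number relation that `vector_size_to_square_matrix_size` (updated later in this diff) inverts: a length-`d` vector fills an `n x n` lower-triangular matrix when `d = n(n + 1) / 2`, so `n = (-1 + sqrt(1 + 8d)) / 2`. A quick standalone check, illustrative only:

```python
import math

d = 6                                       # event size used in the test above
n = (-1.0 + math.sqrt(1.0 + 8.0 * d)) / 2.0
assert n == 3.0 and n * (n + 1) / 2 == d    # 6 entries fill a 3x3 triangle
```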
batch_shape = [2, 1] - x = np.float32(np.random.randn(*(batch_shape + [6]))) + x = np.float32(self._rng.randn(*(batch_shape + [6]))) b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(), diag_shift=3.14159) y = self.evaluate(b.forward(x)) self.assertAllEqual(y.shape, batch_shape + [3, 3]) x_ = self.evaluate(b.inverse(y)) - self.assertAllClose(x, x_) + self.assertAllClose(x, x_, rtol=1e-4) fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) - self.assertAllClose(fldj, -ildj) + self.assertAllClose(fldj, -ildj, rtol=1e-4) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 6a3d171f6c2..21fb54d1dc0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -32,7 +32,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test - ds = distributions @@ -66,9 +65,7 @@ class MultivariateNormalDiagTest(test.TestCase): with self.cached_session(): base_dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True) dist = ds.TransformedDistribution( - base_dist, - validate_args=True, - bijector=bijectors.Softplus()) + base_dist, validate_args=True, bijector=bijectors.Softplus()) samps = dist.sample(5) # Shape [5, 1, 3]. self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape()) @@ -103,10 +100,8 @@ class MultivariateNormalDiagTest(test.TestCase): samps = dist.sample(int(1e3), seed=0).eval() cov_mat = array_ops.matrix_diag(diag).eval()**2 - self.assertAllClose(mu, samps.mean(axis=0), - atol=0., rtol=0.05) - self.assertAllClose(cov_mat, np.cov(samps.T), - atol=0.05, rtol=0.05) + self.assertAllClose(mu, samps.mean(axis=0), atol=0., rtol=0.05) + self.assertAllClose(cov_mat, np.cov(samps.T), atol=0.05, rtol=0.05) def testSingularScaleRaises(self): mu = [-1., 1] @@ -133,13 +128,11 @@ class MultivariateNormalDiagTest(test.TestCase): n = int(1e3) samps = dist.sample(n, seed=0).eval() cov_mat = array_ops.matrix_diag(diag).eval()**2 - sample_cov = np.matmul(samps.transpose([1, 2, 0]), - samps.transpose([1, 0, 2])) / n + sample_cov = np.matmul( + samps.transpose([1, 2, 0]), samps.transpose([1, 0, 2])) / n - self.assertAllClose(mu, samps.mean(axis=0), - atol=0.10, rtol=0.05) - self.assertAllClose([cov_mat, cov_mat], sample_cov, - atol=0.10, rtol=0.05) + self.assertAllClose(mu, samps.mean(axis=0), atol=0.10, rtol=0.05) + self.assertAllClose([cov_mat, cov_mat], sample_cov, atol=0.10, rtol=0.05) def testCovariance(self): with self.cached_session(): @@ -155,12 +148,8 @@ class MultivariateNormalDiagTest(test.TestCase): self.assertAllEqual([2], mvn.batch_shape) self.assertAllEqual([3], mvn.event_shape) self.assertAllClose( - np.array([[[3., 0, 0], - [0, 3, 0], - [0, 0, 3]], - [[2, 0, 0], - [0, 2, 0], - [0, 0, 2]]])**2., + np.array([[[3., 0, 0], [0, 3, 0], [0, 0, 3]], + [[2, 0, 0], [0, 2, 0], [0, 0, 2]]])**2., mvn.covariance().eval()) mvn = ds.MultivariateNormalDiag( @@ -169,61 +158,48 @@ class MultivariateNormalDiagTest(test.TestCase): self.assertAllEqual([2], mvn.batch_shape) self.assertAllEqual([3], mvn.event_shape) self.assertAllClose( - np.array([[[3., 0, 0], - [0, 2, 0], - [0, 0, 1]], - [[4, 0, 0], - [0, 5, 0], - [0, 0, 6]]])**2., + np.array([[[3., 0, 0], [0, 2, 0], [0, 0, 1]], + [[4, 0, 0], [0, 5, 
0], [0, 0, 6]]])**2., mvn.covariance().eval()) def testVariance(self): with self.cached_session(): mvn = ds.MultivariateNormalDiag( loc=array_ops.zeros([2, 3], dtype=dtypes.float32)) - self.assertAllClose( - np.ones([3], dtype=np.float32), - mvn.variance().eval()) + self.assertAllClose(np.ones([3], dtype=np.float32), mvn.variance().eval()) mvn = ds.MultivariateNormalDiag( loc=array_ops.zeros([3], dtype=dtypes.float32), scale_identity_multiplier=[3., 2.]) self.assertAllClose( - np.array([[3., 3, 3], - [2, 2, 2]])**2., + np.array([[3., 3, 3], [2, 2, 2]])**2., mvn.variance().eval()) mvn = ds.MultivariateNormalDiag( loc=array_ops.zeros([3], dtype=dtypes.float32), - scale_diag=[[3., 2, 1], - [4, 5, 6]]) + scale_diag=[[3., 2, 1], [4, 5, 6]]) self.assertAllClose( - np.array([[3., 2, 1], - [4, 5, 6]])**2., + np.array([[3., 2, 1], [4, 5, 6]])**2., mvn.variance().eval()) def testStddev(self): with self.cached_session(): mvn = ds.MultivariateNormalDiag( loc=array_ops.zeros([2, 3], dtype=dtypes.float32)) - self.assertAllClose( - np.ones([3], dtype=np.float32), - mvn.stddev().eval()) + self.assertAllClose(np.ones([3], dtype=np.float32), mvn.stddev().eval()) mvn = ds.MultivariateNormalDiag( loc=array_ops.zeros([3], dtype=dtypes.float32), scale_identity_multiplier=[3., 2.]) self.assertAllClose( - np.array([[3., 3, 3], - [2, 2, 2]]), + np.array([[3., 3, 3], [2, 2, 2]]), mvn.stddev().eval()) mvn = ds.MultivariateNormalDiag( loc=array_ops.zeros([3], dtype=dtypes.float32), scale_diag=[[3., 2, 1], [4, 5, 6]]) self.assertAllClose( - np.array([[3., 2, 1], - [4, 5, 6]]), + np.array([[3., 2, 1], [4, 5, 6]]), mvn.stddev().eval()) def testMultivariateNormalDiagWithSoftplusScale(self): @@ -242,9 +218,8 @@ class MultivariateNormalDiagTest(test.TestCase): num_draws = 50 dims = 3 with self.cached_session() as sess: - x_pl = array_ops.placeholder(dtype=dtypes.float32, - shape=[None, dims], - name="x") + x_pl = array_ops.placeholder( + dtype=dtypes.float32, shape=[None, dims], name="x") mu_var = variable_scope.get_variable( name="mu", shape=[dims], @@ -257,8 +232,8 @@ class MultivariateNormalDiagTest(test.TestCase): scale_diag=array_ops.ones(shape=[dims], dtype=dtypes.float32)) # Typically you'd use `mvn.log_prob(x_pl)` which is always at least as - # numerically stable as `tf.log(mvn.prob(x_pl))`. However in this test - # we're testing a bug specific to `prob` and not `log_prob`; + # numerically stable as `tf.math.log(mvn.prob(x_pl))`. However in this + # test we're testing a bug specific to `prob` and not `log_prob`; # http://stackoverflow.com/q/45109305. (The underlying issue was not # related to `Distributions` but that `reduce_prod` didn't correctly # handle negative indexes.) @@ -268,12 +243,13 @@ class MultivariateNormalDiagTest(test.TestCase): x = np.zeros([num_draws, dims], dtype=np.float32) grad_neg_log_likelihood_ = sess.run( - grad_neg_log_likelihood, - feed_dict={x_pl: x}) + grad_neg_log_likelihood, feed_dict={x_pl: x}) self.assertEqual(1, len(grad_neg_log_likelihood_)) - self.assertAllClose(grad_neg_log_likelihood_[0], - np.tile(num_draws, dims), - rtol=1e-6, atol=0.) + self.assertAllClose( + grad_neg_log_likelihood_[0], + np.tile(num_draws, dims), + rtol=1e-6, + atol=0.) 
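A NumPy-only sketch (not TF; the per-batch scales reuse the test's values, sample count and tolerances are made up) of the batched sample-covariance estimate used in the MultivariateNormalDiag test above: for zero-mean draws of shape `[n, batch, k]`, the matmul of the transposed sample tensors divided by `n` approaches `diag(scale_diag)**2`.

```python
import numpy as np

rng = np.random.RandomState(0)
n = 100000
scale_diag = np.array([[3., 2., 1.], [4., 5., 6.]])      # per-batch scales
samps = rng.randn(n, 2, 3) * scale_diag                  # zero-mean samples
sample_cov = np.matmul(samps.transpose(1, 2, 0), samps.transpose(1, 0, 2)) / n
expected = np.stack([np.diag(s) for s in scale_diag**2])
assert np.allclose(sample_cov, expected, rtol=0.05, atol=0.5)
```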
def testDynamicBatchShape(self): mvn = ds.MultivariateNormalDiag( @@ -294,12 +270,10 @@ class MultivariateNormalDiagTest(test.TestCase): with self.cached_session() as sess: loc = array_ops.zeros([dims], dtype=dtypes.float32) mvn = ds.MultivariateNormalDiag( - loc=loc, - scale_diag=np.ones([dims], dtype=np.float32)) + loc=loc, scale_diag=np.ones([dims], dtype=np.float32)) g = gradients_impl.gradients(ds.kl_divergence(mvn, mvn), loc) g_ = sess.run(g) - self.assertAllEqual(np.ones_like(g_, dtype=np.bool), - np.isfinite(g_)) + self.assertAllEqual(np.ones_like(g_, dtype=np.bool), np.isfinite(g_)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index ba31697c589..fcc8898f6eb 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -59,8 +59,8 @@ class Affine(bijector.Bijector): ```python scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + + scale_identity_multiplier * tf.linalg.tensor_diag(tf.ones(d)) + + tf.linalg.tensor_diag(scale_diag) + scale_tril + scale_perturb_factor @ diag(scale_perturb_diag) @ tf.transpose([scale_perturb_factor]) @@ -84,7 +84,7 @@ class Affine(bijector.Bijector): b = Affine(shift=[1., 2, 3], scale_identity_multiplier=2.) - # Y = tf.diag(d1) @ X.T + shift + # Y = tf.linalg.tensor_diag(d1) @ X.T + shift b = Affine(shift=[1., 2, 3], scale_diag=[-1., 2, 1]) # Implicitly 3x3. @@ -136,8 +136,8 @@ class Affine(bijector.Bijector): ```python scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + + scale_identity_multiplier * tf.linalg.tensor_diag(tf.ones(d)) + + tf.linalg.tensor_diag(scale_diag) + scale_tril + scale_perturb_factor @ diag(scale_perturb_diag) @ tf.transpose([scale_perturb_factor]) @@ -147,7 +147,7 @@ class Affine(bijector.Bijector): If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are specified then `scale += IdentityMatrix`. Otherwise specifying a `scale` argument has the semantics of `scale += Expand(arg)`, i.e., - `scale_diag != None` means `scale += tf.diag(scale_diag)`. + `scale_diag != None` means `scale += tf.linalg.tensor_diag(scale_diag)`. Args: shift: Floating-point `Tensor`. If this is set to `None`, no shift is diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py index f19f147dd64..f891e418427 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - import numpy as np from tensorflow.python.framework import ops @@ -29,15 +28,13 @@ from tensorflow.python.ops import nn from tensorflow.python.ops.distributions import bijector from tensorflow.python.util import deprecation - __all__ = [ "BatchNormalization", ] @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -56,10 +53,10 @@ def _undo_batch_normalization(x, x: Input `Tensor` of arbitrary dimensionality. 
mean: A mean `Tensor`. variance: A variance `Tensor`. - offset: An offset `Tensor`, often denoted `beta` in equations, or - None. If present, will be added to the normalized tensor. - scale: A scale `Tensor`, often denoted `gamma` in equations, or - `None`. If present, the scale is applied to the normalized tensor. + offset: An offset `Tensor`, often denoted `beta` in equations, or None. If + present, will be added to the normalized tensor. + scale: A scale `Tensor`, often denoted `gamma` in equations, or `None`. If + present, the scale is applied to the normalized tensor. variance_epsilon: A small `float` added to the minibatch `variance` to prevent dividing by zero. name: A name for this operation (optional). @@ -67,8 +64,8 @@ def _undo_batch_normalization(x, Returns: batch_unnormalized: The de-normalized, de-scaled, de-offset `Tensor`. """ - with ops.name_scope( - name, "undo_batchnorm", [x, mean, variance, scale, offset]): + with ops.name_scope(name, "undo_batchnorm", + [x, mean, variance, scale, offset]): # inv = math_ops.rsqrt(variance + variance_epsilon) # if scale is not None: # inv *= scale @@ -83,7 +80,9 @@ def _undo_batch_normalization(x, class BatchNormalization(bijector.Bijector): - """Compute `Y = g(X) s.t. X = g^-1(Y) = (Y - mean(Y)) / std(Y)`. + """Compute `Y = g(X) s.t. + + X = g^-1(Y) = (Y - mean(Y)) / std(Y)`. Applies Batch Normalization [(Ioffe and Szegedy, 2015)][1] to samples from a data distribution. This can be used to stabilize training of normalizing @@ -138,8 +137,7 @@ class BatchNormalization(bijector.Bijector): """ @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -153,19 +151,20 @@ class BatchNormalization(bijector.Bijector): """Instantiates the `BatchNorm` bijector. Args: - batchnorm_layer: `tf.layers.BatchNormalization` layer object. If `None`, - defaults to - `tf.layers.BatchNormalization(gamma_constraint=nn_ops.relu(x) + 1e-6)`. - This ensures positivity of the scale variable. - + batchnorm_layer: `tf.compat.v1.layers.BatchNormalization` layer object. If + `None`, defaults to + `tf.compat.v1.layers.BatchNormalization(gamma_constraint=nn_ops.relu(x) + + 1e-6)`. This ensures positivity of the scale variable. training: If True, updates running-average statistics during call to `inverse()`. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. + Raises: ValueError: If bn_layer is not an instance of - `tf.layers.BatchNormalization`, or if it is specified with `renorm=True` + `tf.compat.v1.layers.BatchNormalization`, or if it is specified with + `renorm=True` or a virtual batch size. """ # Scale must be positive. @@ -180,16 +179,19 @@ class BatchNormalization(bijector.Bijector): forward_min_event_ndims = len(self.batchnorm.axis) super(BatchNormalization, self).__init__( forward_min_event_ndims=forward_min_event_ndims, - validate_args=validate_args, name=name) + validate_args=validate_args, + name=name) def _validate_bn_layer(self, layer): """Check for valid BatchNormalization layer. Args: - layer: Instance of `tf.layers.BatchNormalization`. + layer: Instance of `tf.compat.v1.layers.BatchNormalization`. 
+ Raises: ValueError: If batchnorm_layer argument is not an instance of - `tf.layers.BatchNormalization`, or if `batchnorm_layer.renorm=True` or + `tf.compat.v1.layers.BatchNormalization`, or if + `batchnorm_layer.renorm=True` or if `batchnorm_layer.virtual_batch_size` is specified. """ if not isinstance(layer, normalization.BatchNormalization): @@ -214,12 +216,13 @@ class BatchNormalization(bijector.Bijector): broadcast_shape = [1] * ndims broadcast_shape[self.batchnorm.axis[0]] = ( input_shape[self.batchnorm.axis[0]]) + def _broadcast(v): - if (v is not None and - len(v.get_shape()) != ndims and + if (v is not None and len(v.get_shape()) != ndims and reduction_axes != list(range(ndims - 1))): return array_ops.reshape(v, broadcast_shape) return v + return _broadcast def _normalize(self, y): @@ -235,8 +238,8 @@ class BatchNormalization(bijector.Bijector): variance = broadcast_fn(self.batchnorm.moving_variance) beta = broadcast_fn(self.batchnorm.beta) if self.batchnorm.center else None gamma = broadcast_fn(self.batchnorm.gamma) if self.batchnorm.scale else None - return _undo_batch_normalization( - x, mean, variance, beta, gamma, self.batchnorm.epsilon) + return _undo_batch_normalization(x, mean, variance, beta, gamma, + self.batchnorm.epsilon) def _forward(self, x): return self._de_normalize(x) @@ -261,12 +264,12 @@ class BatchNormalization(bijector.Bijector): reduction_axes = [i for i in range(len(input_shape)) if i not in event_dims] if use_saved_statistics or not self._training: - log_variance = math_ops.log( - self.batchnorm.moving_variance + self.batchnorm.epsilon) + log_variance = math_ops.log(self.batchnorm.moving_variance + + self.batchnorm.epsilon) else: # At training-time, ildj is computed from the mean and log-variance across # the current minibatch. - _, v = nn.moments(y, axes=reduction_axes, keep_dims=True) + _, v = nn.moments(y, axes=reduction_axes, keepdims=True) log_variance = math_ops.log(v + self.batchnorm.epsilon) # `gamma` and `log Var(y)` reductions over event_dims. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 910774ea5bb..3c61b7eb232 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -142,7 +142,7 @@ class Chain(bijector.Bijector): softplus = Softplus() Chain([exp, softplus]).forward(x) = exp.forward(softplus.forward(x)) - = tf.exp(tf.log(1. + tf.exp(x))) + = tf.exp(tf.math.log(1. + tf.exp(x))) = 1. + tf.exp(x) ``` @@ -153,8 +153,8 @@ class Chain(bijector.Bijector): softplus = Softplus() Chain([exp, softplus]).inverse(y) = softplus.inverse(exp.inverse(y)) - = tf.log(tf.exp(tf.log(y)) - 1.) - = tf.log(y - 1.) + = tf.math.log(tf.exp(tf.math.log(y)) - 1.) + = tf.math.log(y - 1.) 
``` """ diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py index 7ae98878986..daab24e4333 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -158,10 +159,13 @@ def vector_size_to_square_matrix_size(d, validate_args, name=None): return int(n) else: with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: - n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.cast(d, dtypes.float32))) / 2. if validate_args: - with ops.control_dependencies([check_ops.assert_equal( - math_ops.to_float(math_ops.to_int32(n)), n, - message="Vector length is not a triangular number")]): + with ops.control_dependencies([ + check_ops.assert_equal( + math_ops.cast(math_ops.cast(n, dtypes.int32), dtypes.float32), + n, + message="Vector length is not a triangular number") + ]): n = array_ops.identity(n) return math_ops.cast(n, d.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index 1504bd27204..1f8ffc554a5 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -35,9 +35,9 @@ class Inline(bijector.Bijector): ```python exp = Inline( forward_fn=tf.exp, - inverse_fn=tf.log, + inverse_fn=tf.math.log, inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), + lambda y: -tf.reduce_sum(tf.math.log(y), axis=-1)), name="exp") ``` diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index c30de1f989a..88855b27fd3 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -36,7 +36,6 @@ from tensorflow.python.ops import variable_scope as variable_scope_lib from tensorflow.python.ops.distributions import bijector from tensorflow.python.util import deprecation - __all__ = [ "MaskedAutoregressiveFlow", "masked_autoregressive_default_template", @@ -190,8 +189,7 @@ class MaskedAutoregressiveFlow(bijector.Bijector): """ @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -210,10 +208,11 @@ class MaskedAutoregressiveFlow(bijector.Bijector): `log_scale` from both the forward domain (`x`) and the inverse domain (`y`). Calculation must respect the "autoregressive property" (see class docstring). Suggested default - `masked_autoregressive_default_template(hidden_layers=...)`. - Typically the function contains `tf.Variables` and is wrapped using - `tf.make_template`. Returning `None` for either (both) `shift`, - `log_scale` is equivalent to (but more efficient than) returning zero. 
+ `masked_autoregressive_default_template(hidden_layers=...)`. Typically + the function contains `tf.Variables` and is wrapped using + `tf.compat.v1.make_template`. Returning `None` for either (both) + `shift`, `log_scale` is equivalent to (but more efficient than) + returning zero. is_constant_jacobian: Python `bool`. Default: `False`. When `True` the implementation assumes `log_scale` does not depend on the forward domain (`x`) or inverse domain (`y`) values. (No validation is made; @@ -222,9 +221,9 @@ class MaskedAutoregressiveFlow(bijector.Bijector): validate_args: Python `bool` indicating whether arguments should be checked for correctness. unroll_loop: Python `bool` indicating whether the `tf.while_loop` in - `_forward` should be replaced with a static for loop. Requires that - the final dimension of `x` be known at graph construction time. Defaults - to `False`. + `_forward` should be replaced with a static for loop. Requires that the + final dimension of `x` be known at graph construction time. Defaults to + `False`. name: Python `str`, name given to ops managed by this object. """ name = name or "masked_autoregressive_flow" @@ -267,6 +266,7 @@ class MaskedAutoregressiveFlow(bijector.Bijector): y0 = array_ops.zeros_like(x, name="y0") # call the template once to ensure creation _ = self._shift_and_log_scale_fn(y0) + def _loop_body(index, y0): """While-loop body for autoregression calculation.""" # Set caching device to avoid re-getting the tf.Variable for every while @@ -282,6 +282,7 @@ class MaskedAutoregressiveFlow(bijector.Bijector): if shift is not None: y += shift return index + 1, y + _, y = control_flow_ops.while_loop( cond=lambda index, _: index < event_size, body=_loop_body, @@ -310,8 +311,7 @@ MASK_EXCLUSIVE = "exclusive" @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -335,8 +335,7 @@ def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -357,8 +356,7 @@ def _gen_mask(num_blocks, @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -373,7 +371,9 @@ def masked_dense(inputs, name=None, *args, **kwargs): - """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. + """A autoregressively masked dense layer. + + Analogous to `tf.compat.v1.layers.dense`. See [Germain et al. (2015)][1] for detailed explanation. @@ -385,14 +385,14 @@ def masked_dense(inputs, MADE masks. exclusive: Python `bool` scalar representing whether to zero the diagonal of the mask, used for the first layer of a MADE. - kernel_initializer: Initializer function for the weight matrix. - If `None` (default), weights are initialized using the + kernel_initializer: Initializer function for the weight matrix. 
If `None` + (default), weights are initialized using the `tf.glorot_random_initializer`. reuse: Python `bool` scalar representing whether to reuse the weights of a previous layer by the same name. name: Python `str` used to describe ops managed by this function. - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. + *args: `tf.compat.v1.layers.dense` arguments. + **kwargs: `tf.compat.v1.layers.dense` keyword arguments. Returns: Output tensor. @@ -438,8 +438,7 @@ def masked_dense(inputs, @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -464,7 +463,7 @@ def masked_autoregressive_default_template( Warning: This function uses `masked_dense` to create randomly initialized `tf.Variables`. It is presumed that these will be fit, just as you would any - other neural architecture which uses `tf.layers.dense`. + other neural architecture which uses `tf.compat.v1.layers.dense`. #### About Hidden Layers @@ -500,8 +499,8 @@ def masked_autoregressive_default_template( `tf.clip_by_value` should be preserved. Default: `False`. name: A name for ops managed by this function. Default: "masked_autoregressive_default_template". - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. + *args: `tf.compat.v1.layers.dense` arguments. + **kwargs: `tf.compat.v1.layers.dense` keyword arguments. Returns: shift: `Float`-like `Tensor` of shift terms (the "mu" in @@ -521,6 +520,7 @@ def masked_autoregressive_default_template( """ name = name or "masked_autoregressive_default_template" with ops.name_scope(name, values=[log_scale_min_clip, log_scale_max_clip]): + def _fn(x): """MADE parameterized via `masked_autoregressive_default_template`.""" # TODO(b/67594795): Better support of dynamic shape. @@ -529,8 +529,9 @@ def masked_autoregressive_default_template( if input_depth is None: raise NotImplementedError( "Rightmost dimension must be known prior to graph execution.") - input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() - else array_ops.shape(x)) + input_shape = ( + np.int32(x.shape.as_list()) + if x.shape.is_fully_defined() else array_ops.shape(x)) for i, units in enumerate(hidden_layers): x = masked_dense( inputs=x, @@ -553,16 +554,17 @@ def masked_autoregressive_default_template( x = array_ops.reshape( x, shape=array_ops.concat([input_shape, [2]], axis=0)) shift, log_scale = array_ops.unstack(x, num=2, axis=-1) - which_clip = (math_ops.clip_by_value if log_scale_clip_gradient - else _clip_by_value_preserve_grad) + which_clip = ( + math_ops.clip_by_value + if log_scale_clip_gradient else _clip_by_value_preserve_grad) log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) return shift, log_scale + return template_ops.make_template(name, _fn) @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). 
You " "should update all references to use `tfp.distributions` " diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index 178c3c94bfd..4d136c59899 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -31,7 +31,6 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.util import deprecation - __all__ = [ "Permute", ] @@ -62,12 +61,13 @@ class Permute(bijector.Bijector): Warning: `tf.estimator` may repeatedly build the graph thus `Permute(np.random.permutation(event_size)).astype("int32"))` is not a reliable parameterization (nor would it be even if using `tf.constant`). A - safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + safe alternative is to use `tf.compat.v1.get_variable` to achieve "init once" + behavior, i.e., ```python def init_once(x, name): - return tf.get_variable(name, initializer=x, trainable=False) + return tf.compat.v1.get_variable(name, initializer=x, trainable=False) Permute(permutation=init_once( np.random.permutation(event_size).astype("int32"), @@ -77,8 +77,7 @@ class Permute(bijector.Bijector): """ @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -101,9 +100,7 @@ class Permute(bijector.Bijector): `{0, 1, ..., d}`. """ with ops.name_scope(name, "permute", values=[permutation]): - permutation = ops.convert_to_tensor( - permutation, - name="permutation") + permutation = ops.convert_to_tensor(permutation, name="permutation") if not permutation.dtype.is_integer: raise TypeError("permutation.dtype ({}) should be `int`-like.".format( permutation.dtype.name)) @@ -113,12 +110,12 @@ class Permute(bijector.Bijector): raise ValueError("Permutation over `d` must contain exactly one of " "each of `{0, 1, ..., d}`.") elif validate_args: - p, _ = nn_ops.top_k(-permutation, - k=array_ops.shape(permutation)[-1], - sorted=True) + p, _ = nn_ops.top_k( + -permutation, k=array_ops.shape(permutation)[-1], sorted=True) permutation = control_flow_ops.with_dependencies([ check_ops.assert_equal( - -p, math_ops.range(array_ops.size(p)), + -p, + math_ops.range(array_ops.size(p)), message=("Permutation over `d` must contain exactly one of " "each of `{0, 1, ..., d}`.")), ], permutation) @@ -138,9 +135,7 @@ class Permute(bijector.Bijector): def _inverse(self, y): return array_ops.gather( - y, - array_ops.invert_permutation(self.permutation), - axis=-1) + y, array_ops.invert_permutation(self.permutation), axis=-1) def _inverse_log_det_jacobian(self, y): # is_constant_jacobian = True for this bijector, hence the diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 17e9b8dec9f..c0b20fd8637 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -29,11 +29,7 @@ from tensorflow.python.ops import template as template_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.util import deprecation - -__all__ = [ - "RealNVP", - 
"real_nvp_default_template" -] +__all__ = ["RealNVP", "real_nvp_default_template"] class RealNVP(bijector.Bijector): @@ -132,8 +128,7 @@ class RealNVP(bijector.Bijector): """ @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -155,10 +150,11 @@ class RealNVP(bijector.Bijector): `log_scale` from both the forward domain (`x`) and the inverse domain (`y`). Calculation must respect the "autoregressive property" (see class docstring). Suggested default - `masked_autoregressive_default_template(hidden_layers=...)`. - Typically the function contains `tf.Variables` and is wrapped using - `tf.make_template`. Returning `None` for either (both) `shift`, - `log_scale` is equivalent to (but more efficient than) returning zero. + `masked_autoregressive_default_template(hidden_layers=...)`. Typically + the function contains `tf.Variables` and is wrapped using + `tf.compat.v1.make_template`. Returning `None` for either (both) + `shift`, `log_scale` is equivalent to (but more efficient than) + returning zero. is_constant_jacobian: Python `bool`. Default: `False`. When `True` the implementation assumes `log_scale` does not depend on the forward domain (`x`) or inverse domain (`y`) values. (No validation is made; @@ -243,8 +239,7 @@ class RealNVP(bijector.Bijector): @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -272,8 +267,8 @@ def real_nvp_default_template( implies a linear activation. name: A name for ops managed by this function. Default: "real_nvp_default_template". - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. + *args: `tf.compat.v1.layers.dense` arguments. + **kwargs: `tf.compat.v1.layers.dense` keyword arguments. 
Returns: shift: `Float`-like `Tensor` of shift terms ("mu" in @@ -293,15 +288,12 @@ def real_nvp_default_template( """ with ops.name_scope(name, "real_nvp_default_template"): + def _fn(x, output_units): """Fully connected MLP parameterized via `real_nvp_template`.""" for units in hidden_layers: x = layers.dense( - inputs=x, - units=units, - activation=activation, - *args, - **kwargs) + inputs=x, units=units, activation=activation, *args, **kwargs) x = layers.dense( inputs=x, units=(1 if shift_only else 2) * output_units, @@ -312,5 +304,5 @@ def real_nvp_default_template( return x, None shift, log_scale = array_ops.split(x, 2, axis=-1) return shift, log_scale - return template_ops.make_template( - "real_nvp_default_template", _fn) + + return template_ops.make_template("real_nvp_default_template", _fn) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index 74765f19e58..26d5407125c 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -46,12 +46,12 @@ class SoftmaxCentered(bijector.Bijector): Example Use: ```python - bijector.SoftmaxCentered().forward(tf.log([2, 3, 4])) + bijector.SoftmaxCentered().forward(tf.math.log([2, 3, 4])) # Result: [0.2, 0.3, 0.4, 0.1] # Extra result: 0.1 bijector.SoftmaxCentered().inverse([0.2, 0.3, 0.4, 0.1]) - # Result: tf.log([2, 3, 4]) + # Result: tf.math.log([2, 3, 4]) # Extra coordinate removed. ``` diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 1415f85e5cb..85692d271b6 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -515,7 +515,7 @@ def move_dimension(x, source_idx, dest_idx): Example: ```python - x = tf.placeholder(shape=[200, 30, 4, 1, 6]) + x = tf.compat.v1.placeholder(shape=[200, 30, 4, 1, 6]) x_perm = _move_dimension(x, 1, 1) # no-op x_perm = _move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6] x_perm = _move_dimension(x, 0, -2) # equivalent to previous diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index cf15deebb78..7e1411ea89e 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -115,7 +115,7 @@ class Independent(distribution_lib.Distribution): reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims which will be regarded as event dims. When `None` all but the first batch axis (batch axis 0) will be transferred to event dimensions - (analogous to `tf.layers.flatten`). + (analogous to `tf.compat.v1.layers.flatten`). validate_args: Python `bool`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 1006dfac49f..9f1e9d5cd1b 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -84,7 +84,7 @@ class InverseGamma(distribution.Distribution): examples for details. 
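A NumPy analogue of the shapes quoted in the `move_dimension` docstring a few hunks above (illustrative only; `np.moveaxis` stands in for the TF helper):

```python
import numpy as np

x = np.zeros([200, 30, 4, 1, 6])
assert np.moveaxis(x, 0, 3).shape == (30, 4, 1, 200, 6)
assert np.moveaxis(x, 0, -2).shape == (30, 4, 1, 200, 6)   # same destination
```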
WARNING: This distribution may draw 0-valued samples for small concentration - values. See note in `tf.random_gamma` docstring. + values. See note in `tf.random.gamma` docstring. #### Examples @@ -190,7 +190,7 @@ class InverseGamma(distribution.Distribution): return tensor_shape.scalar() @distribution_util.AppendDocstring( - """Note: See `tf.random_gamma` docstring for sampling details and + """Note: See `tf.random.gamma` docstring for sampling details and caveats.""") def _sample_n(self, n, seed=None): return 1. / random_ops.random_gamma( diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index c21f70fc3b3..2f09f49f8c6 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -24,14 +24,12 @@ from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation - __all__ = [ "MultivariateNormalTriL", ] -class MultivariateNormalTriL( - mvn_linop.MultivariateNormalLinearOperator): +class MultivariateNormalTriL(mvn_linop.MultivariateNormalLinearOperator): """The multivariate normal distribution on `R^k`. The Multivariate Normal distribution is defined over `R^k` and parameterized @@ -63,7 +61,7 @@ class MultivariateNormalTriL( ``` where `scale_tril` is lower-triangular `k x k` matrix with non-zero diagonal, - i.e., `tf.diag_part(scale_tril) != 0`. + i.e., `tf.linalg.tensor_diag_part(scale_tril) != 0`. Additional leading dimensions (if any) will index batches. @@ -91,7 +89,7 @@ class MultivariateNormalTriL( cov = [[ 0.36, 0.12, 0.06], [ 0.12, 0.29, -0.13], [ 0.06, -0.13, 0.26]] - scale = tf.cholesky(cov) + scale = tf.linalg.cholesky(cov) # ==> [[ 0.6, 0. , 0. ], # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) @@ -126,19 +124,19 @@ class MultivariateNormalTriL( # Instantiate a "learnable" MVN. dims = 4 - with tf.variable_scope("model"): + with tf.compat.v1.variable_scope("model"): mvn = tfd.MultivariateNormalTriL( - loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), + loc=tf.compat.v1.get_variable(shape=[dims], dtype=tf.float32, + name="mu"), scale_tril=tfd.fill_triangular( - tf.get_variable(shape=[dims * (dims + 1) / 2], + tf.compat.v1.get_variable(shape=[dims * (dims + 1) / 2], dtype=tf.float32, name="chol_Sigma"))) ``` """ @deprecation.deprecated( - "2018-10-01", - "The TensorFlow Distributions library has moved to " + "2018-10-01", "The TensorFlow Distributions library has moved to " "TensorFlow Probability " "(https://github.com/tensorflow/probability). You " "should update all references to use `tfp.distributions` " @@ -165,7 +163,7 @@ class MultivariateNormalTriL( ``` where `scale_tril` is lower-triangular `k x k` matrix with non-zero - diagonal, i.e., `tf.diag_part(scale_tril) != 0`. + diagonal, i.e., `tf.linalg.tensor_diag_part(scale_tril) != 0`. Additional leading dimensions (if any) will index batches. @@ -174,24 +172,26 @@ class MultivariateNormalTriL( implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where `b >= 0` and `k` is the event size. scale_tril: Floating-point, lower-triangular `Tensor` with non-zero - diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where - `b >= 0` and `k` is the event size. + diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where `b + >= 0` and `k` is the event size. validate_args: Python `bool`, default `False`. 
When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. - allow_nan_stats: Python `bool`, default `True`. When `True`, - statistics (e.g., mean, mode, variance) use the value "`NaN`" to - indicate the result is undefined. When `False`, an exception is raised - if one or more of the statistic's batch members are undefined. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or more + of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. Raises: ValueError: if neither `loc` nor `scale_tril` are specified. """ parameters = dict(locals()) + def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) + if loc is None and scale_tril is None: raise ValueError("Must specify one or both of `loc`, `scale_tril`.") with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index 19e99e03803..ad0f2317c99 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -304,14 +304,14 @@ def percentile(x, x = ops.convert_to_tensor(x, name="x") # Double is needed here and below, else we get the wrong index if the array # is huge along axis. - q = math_ops.to_double(q, name="q") + q = math_ops.cast(q, dtypes.float64, name="q") _get_static_ndims(q, expect_ndims=0) if validate_args: q = control_flow_ops.with_dependencies([ check_ops.assert_rank(q, 0), - check_ops.assert_greater_equal(q, math_ops.to_double(0.)), - check_ops.assert_less_equal(q, math_ops.to_double(100.)) + check_ops.assert_greater_equal(q, math_ops.cast(0., dtypes.float64)), + check_ops.assert_less_equal(q, math_ops.cast(100., dtypes.float64)) ], q) if axis is None: @@ -336,7 +336,7 @@ def percentile(x, y = _move_dims_to_flat_end(x, axis, x_ndims) frac_at_q_or_above = 1. - q / 100. - d = math_ops.to_double(array_ops.shape(y)[-1]) + d = math_ops.cast(array_ops.shape(y)[-1], dtypes.float64) if interpolation == "lower": index = math_ops.ceil((d - 1) * frac_at_q_or_above) @@ -349,7 +349,7 @@ def percentile(x, # let's use max/min to avoid out of bounds errors. d = array_ops.shape(y)[-1] # d - 1 will be distinct from d in int32. - index = clip_ops.clip_by_value(math_ops.to_int32(index), 0, d - 1) + index = clip_ops.clip_by_value(math_ops.cast(index, dtypes.int32), 0, d - 1) # Sort everything, not just the top 'k' entries, which allows multiple calls # to sort only once (under the hood) and use CSE. 
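A short sketch of the dtype-cast migration applied throughout `percentile` above (TF 1.x session-style API assumed; the constants stand in for the function's `q` and axis length `d`): the removed `tf.to_double` / `tf.to_int32` helpers become explicit `tf.cast` calls.

```python
import tensorflow.compat.v1 as tf

q = tf.constant(50)
d = tf.constant(10)
q_f64 = tf.cast(q, tf.float64)                                # was tf.to_double(q)
index = tf.clip_by_value(tf.cast(q_f64, tf.int32), 0, d - 1)  # was tf.to_int32(...)
```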
diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py index cf505ac627b..3d39e9ce507 100644 --- a/tensorflow/contrib/distributions/python/ops/seed_stream.py +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -38,8 +38,8 @@ class SeedStream(object): ```python def broken_beta(shape, alpha, beta, seed): - x = tf.random_gamma(shape, alpha, seed=seed) - y = tf.random_gamma(shape, beta, seed=seed) + x = tf.random.gamma(shape, alpha, seed=seed) + y = tf.random.gamma(shape, beta, seed=seed) return x / (x + y) ``` @@ -83,8 +83,8 @@ class SeedStream(object): ```python def random_beta(shape, alpha, beta, seed): # (a) seed = SeedStream(seed, salt="random_beta") # (b) - x = tf.random_gamma(shape, alpha, seed=seed()) # (c) - y = tf.random_gamma(shape, beta, seed=seed()) # (c) + x = tf.random.gamma(shape, alpha, seed=seed()) # (c) + y = tf.random.gamma(shape, beta, seed=seed()) # (c) return x / (x + y) ``` @@ -123,12 +123,12 @@ class SeedStream(object): ```python def tfp_foo(seed): seed = SeedStream(seed, salt="") - foo_stuff = tf.random_normal(seed=seed()) + foo_stuff = tf.random.normal(seed=seed()) ... def tfp_bar(seed): seed = SeedStream(seed, salt="") - bar_stuff = tf.random_normal(seed=seed()) + bar_stuff = tf.random.normal(seed=seed()) ... def client_baz(seed): diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 4f348be2806..19d88d5ab5d 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -114,7 +114,7 @@ class _DistributionShape(object): E.g., Jacobian of the transform `Y = g(X) = exp(X)`: ```python - tf.div(1., tf.reduce_prod(x, event_dims)) + tf.compat.v1.div(1., tf.reduce_prod(x, event_dims)) ``` We show examples using this class. diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py index af22f4843a0..ed64c71a218 100644 --- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -43,7 +43,7 @@ is some expected constant. 
Suppose the support of P is the interval # Check that the difference in means detectable with 5000 samples is # small enough - check2 = tf.assert_less( + check2 = tf.compat.v1.assert_less( statistical_testing.min_discrepancy_of_true_means_detectable_by_dkwm( num_samples, low=0., high=1.0, false_fail_rate=1e-6, false_pass_rate=1e-6), diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py index 15b0820cbdf..73ab3d818be 100644 --- a/tensorflow/contrib/distributions/python/ops/test_util.py +++ b/tensorflow/contrib/distributions/python/ops/test_util.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import histogram_ops @@ -27,7 +28,6 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables as variables_ops - __all__ = [ "DiscreteScalarDistributionTestHelpers", "VectorDistributionTestHelpers", @@ -37,11 +37,15 @@ __all__ = [ class DiscreteScalarDistributionTestHelpers(object): """DiscreteScalarDistributionTestHelpers.""" - def run_test_sample_consistent_log_prob( - self, sess_run_fn, dist, - num_samples=int(1e5), num_threshold=int(1e3), seed=42, - batch_size=None, - rtol=1e-2, atol=0.): + def run_test_sample_consistent_log_prob(self, + sess_run_fn, + dist, + num_samples=int(1e5), + num_threshold=int(1e3), + seed=42, + batch_size=None, + rtol=1e-2, + atol=0.): """Tests that sample/log_prob are consistent with each other. "Consistency" means that `sample` and `log_prob` correspond to the same @@ -61,9 +65,9 @@ class DiscreteScalarDistributionTestHelpers(object): samples to draw from `dist`. num_threshold: Python `int` scalar indicating the number of samples a bucket must contain before being compared to the probability. - Default value: 1e3; must be at least 1. - Warning, set too high will cause test to falsely pass but setting too - low will cause the test to falsely fail. + Default value: 1e3; must be at least 1. Warning, set too high will cause + test to falsely pass but setting too low will cause the test to + falsely fail. seed: Python `int` indicating the seed to use when sampling from `dist`. In general it is not recommended to use `None` during a test as this increases the likelihood of spurious test failure. @@ -78,8 +82,8 @@ class DiscreteScalarDistributionTestHelpers(object): ValueError: if `num_threshold < 1`. """ if num_threshold < 1: - raise ValueError("num_threshold({}) must be at least 1.".format( - num_threshold)) + raise ValueError( + "num_threshold({}) must be at least 1.".format(num_threshold)) # Histogram only supports vectors so we call it once per batch coordinate. 
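A toy NumPy illustration (not TF, with made-up probabilities) of the consistency check this helper performs: empirical bucket frequencies from a large sample should match the analytic probabilities to within a loose tolerance.

```python
import numpy as np

rng = np.random.RandomState(0)
probs = np.array([0.2, 0.3, 0.5])
samples = rng.choice(len(probs), size=100000, p=probs)
freqs = np.bincount(samples, minlength=len(probs)) / samples.size
assert np.allclose(freqs, probs, rtol=1e-2, atol=0.01)
```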
y = dist.sample(num_samples, seed=seed) y = array_ops.reshape(y, shape=[num_samples, -1]) @@ -97,13 +101,15 @@ class DiscreteScalarDistributionTestHelpers(object): valid = counts_ > num_threshold probs_ = probs_[valid] counts_ = counts_[valid] - self.assertAllClose(probs_, counts_ / num_samples, - rtol=rtol, atol=atol) + self.assertAllClose(probs_, counts_ / num_samples, rtol=rtol, atol=atol) - def run_test_sample_consistent_mean_variance( - self, sess_run_fn, dist, - num_samples=int(1e5), seed=24, - rtol=1e-2, atol=0.): + def run_test_sample_consistent_mean_variance(self, + sess_run_fn, + dist, + num_samples=int(1e5), + seed=24, + rtol=1e-2, + atol=0.): """Tests that sample/mean/variance are consistent with each other. "Consistency" means that `sample`, `mean`, `variance`, etc all correspond @@ -125,27 +131,21 @@ class DiscreteScalarDistributionTestHelpers(object): atol: Python `float`-type indicating the admissible absolute error between analytical and sample statistics. """ - x = math_ops.to_float(dist.sample(num_samples, seed=seed)) + x = math_ops.cast(dist.sample(num_samples, seed=seed), dtypes.float32) sample_mean = math_ops.reduce_mean(x, axis=0) sample_variance = math_ops.reduce_mean( math_ops.square(x - sample_mean), axis=0) sample_stddev = math_ops.sqrt(sample_variance) - [ - sample_mean_, - sample_variance_, - sample_stddev_, - mean_, - variance_, - stddev_ - ] = sess_run_fn([ - sample_mean, - sample_variance, - sample_stddev, - dist.mean(), - dist.variance(), - dist.stddev(), - ]) + [sample_mean_, sample_variance_, sample_stddev_, mean_, variance_, + stddev_] = sess_run_fn([ + sample_mean, + sample_variance, + sample_stddev, + dist.mean(), + dist.variance(), + dist.stddev(), + ]) self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol) self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol) @@ -180,7 +180,7 @@ class DiscreteScalarDistributionTestHelpers(object): lo = value_range[0] hi = value_range[1] if nbins is None: - nbins = math_ops.to_int32(hi - lo) + nbins = math_ops.cast(hi - lo, dtypes.int32) delta = (hi - lo) / math_ops.cast( nbins, dtype=value_range.dtype.base_dtype) edges = math_ops.range( @@ -193,16 +193,15 @@ class DiscreteScalarDistributionTestHelpers(object): class VectorDistributionTestHelpers(object): """VectorDistributionTestHelpers helps test vector-event distributions.""" - def run_test_sample_consistent_log_prob( - self, - sess_run_fn, - dist, - num_samples=int(1e5), - radius=1., - center=0., - seed=42, - rtol=1e-2, - atol=0.): + def run_test_sample_consistent_log_prob(self, + sess_run_fn, + dist, + num_samples=int(1e5), + radius=1., + center=0., + seed=42, + rtol=1e-2, + atol=0.): """Tests that sample/log_prob are mutually consistent. "Consistency" means that `sample` and `log_prob` correspond to the same @@ -268,24 +267,23 @@ class VectorDistributionTestHelpers(object): rtol: Python `float`-type indicating the admissible relative error between actual- and approximate-volumes. atol: Python `float`-type indicating the admissible absolute error between - actual- and approximate-volumes. In general this should be zero since - a typical radius implies a non-zero volume. + actual- and approximate-volumes. In general this should be zero since a + typical radius implies a non-zero volume. """ def actual_hypersphere_volume(dims, radius): # https://en.wikipedia.org/wiki/Volume_of_an_n-ball - # Using tf.lgamma because we'd have to otherwise use SciPy which is not - # a required dependency of core. 
+ # Using tf.math.lgamma because we'd have to otherwise use SciPy which is + # not a required dependency of core. radius = np.asarray(radius) dims = math_ops.cast(dims, dtype=radius.dtype) - return math_ops.exp( - (dims / 2.) * np.log(np.pi) - - math_ops.lgamma(1. + dims / 2.) - + dims * math_ops.log(radius)) + return math_ops.exp((dims / 2.) * np.log(np.pi) - + math_ops.lgamma(1. + dims / 2.) + + dims * math_ops.log(radius)) def is_in_ball(x, radius, center): - return math_ops.cast(linalg_ops.norm(x - center, axis=-1) <= radius, - dtype=x.dtype) + return math_ops.cast( + linalg_ops.norm(x - center, axis=-1) <= radius, dtype=x.dtype) def monte_carlo_hypersphere_volume(dist, num_samples, radius, center): # https://en.wikipedia.org/wiki/Importance_sampling @@ -301,35 +299,32 @@ class VectorDistributionTestHelpers(object): values=[num_samples, radius, center] + dist._graph_parents): # pylint: disable=protected-access batch_shape = dist.batch_shape_tensor() actual_volume = actual_hypersphere_volume( - dims=dist.event_shape_tensor()[0], - radius=radius) + dims=dist.event_shape_tensor()[0], radius=radius) sample_volume = monte_carlo_hypersphere_volume( - dist, - num_samples=num_samples, - radius=radius, - center=center) + dist, num_samples=num_samples, radius=radius, center=center) init_op = variables_ops.global_variables_initializer() # Execute graph. sess_run_fn(init_op) - [batch_shape_, actual_volume_, sample_volume_] = sess_run_fn([ - batch_shape, actual_volume, sample_volume]) + [batch_shape_, actual_volume_, + sample_volume_] = sess_run_fn([batch_shape, actual_volume, sample_volume]) # Check results. - self.assertAllClose(np.tile(actual_volume_, reps=batch_shape_), - sample_volume_, - rtol=rtol, atol=atol) + self.assertAllClose( + np.tile(actual_volume_, reps=batch_shape_), + sample_volume_, + rtol=rtol, + atol=atol) - def run_test_sample_consistent_mean_covariance( - self, - sess_run_fn, - dist, - num_samples=int(1e5), - seed=24, - rtol=1e-2, - atol=0.1, - cov_rtol=None, - cov_atol=None): + def run_test_sample_consistent_mean_covariance(self, + sess_run_fn, + dist, + num_samples=int(1e5), + seed=24, + rtol=1e-2, + atol=0.1, + cov_rtol=None, + cov_atol=None): """Tests that sample/mean/covariance are consistent with each other. 
"Consistency" means that `sample`, `mean`, `covariance`, etc all correspond @@ -364,14 +359,8 @@ class VectorDistributionTestHelpers(object): sample_stddev = math_ops.sqrt(sample_variance) [ - sample_mean_, - sample_covariance_, - sample_variance_, - sample_stddev_, - mean_, - covariance_, - variance_, - stddev_ + sample_mean_, sample_covariance_, sample_variance_, sample_stddev_, + mean_, covariance_, variance_, stddev_ ] = sess_run_fn([ sample_mean, sample_covariance, @@ -384,9 +373,11 @@ class VectorDistributionTestHelpers(object): ]) self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol) - self.assertAllClose(covariance_, sample_covariance_, - rtol=cov_rtol or rtol, - atol=cov_atol or atol) + self.assertAllClose( + covariance_, + sample_covariance_, + rtol=cov_rtol or rtol, + atol=cov_atol or atol) self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol) self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol) diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index a648d61ac8d..f9748466c2e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -682,7 +682,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._covariance_of_mean_given_quadrature_component(diag_only=False)) def _variance(self): - # Equivalent to: tf.diag_part(self._covariance()), + # Equivalent to: tf.linalg.tensor_diag_part(self._covariance()), return add( self._mean_of_covariance_given_quadrature_component(diag_only=True), self._covariance_of_mean_given_quadrature_component(diag_only=True)) diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 2d83f0c13f1..a5bb880bed9 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -486,7 +486,7 @@ class WishartCholesky(_WishartLinearOperator): # Initialize a single 3x3 Wishart with Cholesky factored scale matrix and 5 # degrees-of-freedom.(*) df = 5 - chol_scale = tf.cholesky(...) # Shape is [3, 3]. + chol_scale = tf.linalg.cholesky(...) # Shape is [3, 3]. dist = tfd.WishartCholesky(df=df, scale=chol_scale) # Evaluate this on an observation in R^3, returning a scalar. @@ -500,7 +500,7 @@ class WishartCholesky(_WishartLinearOperator): # Initialize two 3x3 Wisharts with Cholesky factored scale matrices. df = [5, 4] - chol_scale = tf.cholesky(...) # Shape is [2, 3, 3]. + chol_scale = tf.linalg.cholesky(...) # Shape is [2, 3, 3]. dist = tfd.WishartCholesky(df=df, scale=chol_scale) # Evaluate this on four observations. 
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index d441e4735b6..a500f9fd34c 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -151,6 +151,7 @@ py_library( py_test( name = "metrics_test", srcs = ["metrics_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":metrics", @@ -188,6 +189,7 @@ py_library( py_test( name = "evaluator_test", srcs = ["evaluator_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":evaluator", @@ -220,6 +222,7 @@ py_library( py_test( name = "network_test", srcs = ["network_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":network", diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 34614b86a75..97bf02f6539 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -24,7 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops -class Iterator(iterator_ops.EagerIterator): +class Iterator(iterator_ops.IteratorV2): """An iterator producing tf.Tensor objects from a tf.data.Dataset. NOTE: Unlike the iterator created by the diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 51443d24829..fa46d73241d 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -165,8 +165,15 @@ class Evaluator(object): self.__call__(example, *args, **kwargs) return self.all_metric_results(summary_logdir) # Graph construction - call_op = self.__call__( - dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs) + next_value = dataset_ops.make_one_shot_iterator(dataset).get_next() + # Function inlining destroys strict inputs semantics (function body might + # start execution before all inputs are ready). When iterator is exhausted + # and throws out of range error, function body might be partially executed. + # To prevent this we add an explicit control dependency from the 'get_next'. + with ops.control_dependencies([next_value]): + has_next_value = control_flow_ops.no_op(name="iterator_has_next") + with ops.control_dependencies([has_next_value]): + call_op = self.__call__(next_value, *args, **kwargs) init_op = self.init_variables() results_op = self.all_metric_results(summary_logdir) return (init_op, call_op, results_op) diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 3e0881754c7..fd5a44a7975 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -1,4 +1,5 @@ # TensorFlow code for training gradient boosted trees. 
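Regarding the `Evaluator` graph-construction change above: function inlining can break strict-input semantics, so the function body could start executing before `get_next` has produced a value and be left partially executed when the iterator raises its out-of-range error; the explicit control dependency acts as a barrier. A small sketch of the same pattern, assuming the `tf.compat.v1` API (TF 1.15+/2.x) and a toy dataset chosen purely for illustration:

```python
# Sketch of the control-dependency barrier added to Evaluator: the no-op can
# only run once get_next has produced a value, and the "body" (here tf.square,
# standing in for self.__call__) can only run after the no-op. In this toy the
# data dependency already enforces the ordering; the explicit barrier matters
# when the body contains ops that do not consume every component of next_value.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

dataset = tf.data.Dataset.range(5)
next_value = tf.data.make_one_shot_iterator(dataset).get_next()
with tf.control_dependencies([next_value]):
    has_next_value = tf.no_op(name="iterator_has_next")
with tf.control_dependencies([has_next_value]):
    call_op = tf.square(next_value)

with tf.Session() as sess:
    print(sess.run(call_op))  # 0; each additional run advances the iterator
```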
+ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index fbb5daf230b..a001d426fe2 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -8,6 +8,7 @@ load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "densenet", srcs = ["densenet.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":densenet_lib"], ) diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD index d99a5191127..be561a1da66 100644 --- a/tensorflow/contrib/eager/python/examples/gan/BUILD +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -8,6 +8,7 @@ load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "mnist", srcs = ["mnist.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":mnist_lib"], ) diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py index 81ac05e26d2..cb99e2b416b 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -22,7 +22,6 @@ import time import tensorflow as tf -import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.gan import mnist NOISE_DIM = 100 @@ -37,14 +36,14 @@ def data_format(): def device(): - return '/gpu:0' if tfe.num_gpus() else '/cpu:0' + return '/gpu:0' if tf.test.is_gpu_available() else '/cpu:0' class MnistEagerGanBenchmark(tf.test.Benchmark): def _report(self, test_name, start, num_iters, batch_size): avg_time = (time.time() - start) / num_iters - dev = 'gpu' if tfe.num_gpus() else 'cpu' + dev = 'gpu' if tf.test.is_gpu_available() else 'cpu' name = 'eager_%s_%s_batch_%d_%s' % (test_name, dev, batch_size, data_format()) extras = {'examples_per_sec': batch_size / avg_time} diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index 30afef83bc5..8536fdbf705 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -8,6 +8,7 @@ load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "linear_regression", srcs = ["linear_regression.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":linear_regression_lib"], ) diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb index d60ee185861..57bd18d7529 100644 --- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb +++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb @@ -11,777 +11,17 @@ "\n", "Licensed under the Apache License, Version 2.0 (the \"License\").\n", "\n", - "# Pix2Pix: An example with tf.keras and eager\n", - "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", - "\u003c/td\u003e\u003ctd\u003e\n", 
- "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "# Pix2Pix" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "ITZuApL56Mny" + "id": "c7W3j96p219v" }, "source": [ - "This notebook demonstrates image to image translation using conditional GAN's, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings. We use [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) to achieve this.\n", - "\n", - "In example, we will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n", - "\n", - "Each epoch takes around 58 seconds on a single P100 GPU.\n", - "\n", - "Below is the output generated after training the model for 200 epochs.\n", - "\n", - "\n", - "![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)\n", - "![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "source": [ - "## Import TensorFlow and enable eager execution" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "YfIk2es3hJEd" - }, - "outputs": [], - "source": [ - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "import os\n", - "import time\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import PIL\n", - "from IPython.display import clear_output" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "source": [ - "## Load the dataset\n", - "\n", - "You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.\n", - "* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`\n", - "* In random mirroring, the image is randomly flipped horizontally i.e left to right." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Kn-k8kTXuAlv" - }, - "outputs": [], - "source": [ - "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n", - " cache_subdir=os.path.abspath('.'),\n", - " origin='https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz', \n", - " extract=True)\n", - "\n", - "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "2CbTEt448b4R" - }, - "outputs": [], - "source": [ - "BUFFER_SIZE = 400\n", - "BATCH_SIZE = 1\n", - "IMG_WIDTH = 256\n", - "IMG_HEIGHT = 256" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "tyaP4hLJ8b4W" - }, - "outputs": [], - "source": [ - "def load_image(image_file, is_train):\n", - " image = tf.read_file(image_file)\n", - " image = tf.image.decode_jpeg(image)\n", - "\n", - " w = tf.shape(image)[1]\n", - "\n", - " w = w // 2\n", - " real_image = image[:, :w, :]\n", - " input_image = image[:, w:, :]\n", - "\n", - " input_image = tf.cast(input_image, tf.float32)\n", - " real_image = tf.cast(real_image, tf.float32)\n", - "\n", - " if is_train:\n", - " # random jittering\n", - " \n", - " # resizing to 286 x 286 x 3\n", - " input_image = tf.image.resize_images(input_image, [286, 286], \n", - " align_corners=True, \n", - " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", - " real_image = tf.image.resize_images(real_image, [286, 286], \n", - " align_corners=True, \n", - " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", - " \n", - " # randomly cropping to 256 x 256 x 3\n", - " stacked_image = tf.stack([input_image, real_image], axis=0)\n", - " cropped_image = tf.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n", - " input_image, real_image = cropped_image[0], cropped_image[1]\n", - "\n", - " if np.random.random() \u003e 0.5:\n", - " # random mirroring\n", - " input_image = tf.image.flip_left_right(input_image)\n", - " real_image = tf.image.flip_left_right(real_image)\n", - " else:\n", - " input_image = tf.image.resize_images(input_image, size=[IMG_HEIGHT, IMG_WIDTH], \n", - " align_corners=True, method=2)\n", - " real_image = tf.image.resize_images(real_image, size=[IMG_HEIGHT, IMG_WIDTH], \n", - " align_corners=True, method=2)\n", - " \n", - " # normalizing the images to [-1, 1]\n", - " input_image = (input_image / 127.5) - 1\n", - " real_image = (real_image / 127.5) - 1\n", - "\n", - " return input_image, real_image" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "source": [ - "## Use tf.data to create batches, map(do preprocessing) and shuffle the dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "SQHmYSmk8b4b" - }, - "outputs": [], - "source": [ - "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n", - "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", - "train_dataset = train_dataset.map(lambda x: load_image(x, True))\n", - "train_dataset = train_dataset.batch(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "MS9J0yA58b4g" - }, - "outputs": [], - "source": [ - "test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n", - "test_dataset = test_dataset.map(lambda 
x: load_image(x, False))\n", - "test_dataset = test_dataset.batch(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "source": [ - "## Write the generator and discriminator models\n", - "\n", - "* **Generator** \n", - " * The architecture of generator is a modified U-Net.\n", - " * Each block in the encoder is (Conv -\u003e Batchnorm -\u003e Leaky ReLU)\n", - " * Each block in the decoder is (Transposed Conv -\u003e Batchnorm -\u003e Dropout(applied to the first 3 blocks) -\u003e ReLU)\n", - " * There are skip connections between the encoder and decoder (as in U-Net).\n", - " \n", - "* **Discriminator**\n", - " * The Discriminator is a PatchGAN.\n", - " * Each block in the discriminator is (Conv -\u003e BatchNorm -\u003e Leaky ReLU)\n", - " * The shape of the output after the last layer is (batch_size, 30, 30, 1)\n", - " * Each 30x30 patch of the output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).\n", - " * Discriminator receives 2 inputs.\n", - " * Input image and the target image, which it should classify as real.\n", - " * Input image and the generated image (output of generator), which it should classify as fake. \n", - " * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`)\n", - "\n", - "* Shape of the input travelling through the generator and the discriminator is in the comments in the code.\n", - "\n", - "To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1611.07004).\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "tqqvWxlw8b4l" - }, - "outputs": [], - "source": [ - "OUTPUT_CHANNELS = 3" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "lFPI4Nu-8b4q" - }, - "outputs": [], - "source": [ - "class Downsample(tf.keras.Model):\n", - " \n", - " def __init__(self, filters, size, apply_batchnorm=True):\n", - " super(Downsample, self).__init__()\n", - " self.apply_batchnorm = apply_batchnorm\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - "\n", - " self.conv1 = tf.keras.layers.Conv2D(filters, \n", - " (size, size), \n", - " strides=2, \n", - " padding='same',\n", - " kernel_initializer=initializer,\n", - " use_bias=False)\n", - " if self.apply_batchnorm:\n", - " self.batchnorm = tf.keras.layers.BatchNormalization()\n", - " \n", - " def call(self, x, training):\n", - " x = self.conv1(x)\n", - " if self.apply_batchnorm:\n", - " x = self.batchnorm(x, training=training)\n", - " x = tf.nn.leaky_relu(x)\n", - " return x \n", - "\n", - "\n", - "class Upsample(tf.keras.Model):\n", - " \n", - " def __init__(self, filters, size, apply_dropout=False):\n", - " super(Upsample, self).__init__()\n", - " self.apply_dropout = apply_dropout\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - "\n", - " self.up_conv = tf.keras.layers.Conv2DTranspose(filters, \n", - " (size, size), \n", - " strides=2, \n", - " padding='same',\n", - " kernel_initializer=initializer,\n", - " use_bias=False)\n", - " self.batchnorm = tf.keras.layers.BatchNormalization()\n", - " if self.apply_dropout:\n", - " self.dropout = tf.keras.layers.Dropout(0.5)\n", - "\n", - " def call(self, x1, x2, training):\n", - " x = self.up_conv(x1)\n", - " x = self.batchnorm(x, training=training)\n", - " if self.apply_dropout:\n", - " x = self.dropout(x, 
training=training)\n", - " x = tf.nn.relu(x)\n", - " x = tf.concat([x, x2], axis=-1)\n", - " return x\n", - "\n", - "\n", - "class Generator(tf.keras.Model):\n", - " \n", - " def __init__(self):\n", - " super(Generator, self).__init__()\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - " \n", - " self.down1 = Downsample(64, 4, apply_batchnorm=False)\n", - " self.down2 = Downsample(128, 4)\n", - " self.down3 = Downsample(256, 4)\n", - " self.down4 = Downsample(512, 4)\n", - " self.down5 = Downsample(512, 4)\n", - " self.down6 = Downsample(512, 4)\n", - " self.down7 = Downsample(512, 4)\n", - " self.down8 = Downsample(512, 4)\n", - "\n", - " self.up1 = Upsample(512, 4, apply_dropout=True)\n", - " self.up2 = Upsample(512, 4, apply_dropout=True)\n", - " self.up3 = Upsample(512, 4, apply_dropout=True)\n", - " self.up4 = Upsample(512, 4)\n", - " self.up5 = Upsample(256, 4)\n", - " self.up6 = Upsample(128, 4)\n", - " self.up7 = Upsample(64, 4)\n", - "\n", - " self.last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, \n", - " (4, 4), \n", - " strides=2, \n", - " padding='same',\n", - " kernel_initializer=initializer)\n", - " \n", - " @tf.contrib.eager.defun\n", - " def call(self, x, training):\n", - " # x shape == (bs, 256, 256, 3) \n", - " x1 = self.down1(x, training=training) # (bs, 128, 128, 64)\n", - " x2 = self.down2(x1, training=training) # (bs, 64, 64, 128)\n", - " x3 = self.down3(x2, training=training) # (bs, 32, 32, 256)\n", - " x4 = self.down4(x3, training=training) # (bs, 16, 16, 512)\n", - " x5 = self.down5(x4, training=training) # (bs, 8, 8, 512)\n", - " x6 = self.down6(x5, training=training) # (bs, 4, 4, 512)\n", - " x7 = self.down7(x6, training=training) # (bs, 2, 2, 512)\n", - " x8 = self.down8(x7, training=training) # (bs, 1, 1, 512)\n", - "\n", - " x9 = self.up1(x8, x7, training=training) # (bs, 2, 2, 1024)\n", - " x10 = self.up2(x9, x6, training=training) # (bs, 4, 4, 1024)\n", - " x11 = self.up3(x10, x5, training=training) # (bs, 8, 8, 1024)\n", - " x12 = self.up4(x11, x4, training=training) # (bs, 16, 16, 1024)\n", - " x13 = self.up5(x12, x3, training=training) # (bs, 32, 32, 512)\n", - " x14 = self.up6(x13, x2, training=training) # (bs, 64, 64, 256)\n", - " x15 = self.up7(x14, x1, training=training) # (bs, 128, 128, 128)\n", - "\n", - " x16 = self.last(x15) # (bs, 256, 256, 3)\n", - " x16 = tf.nn.tanh(x16)\n", - "\n", - " return x16" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "ll6aNeQx8b4v" - }, - "outputs": [], - "source": [ - "class DiscDownsample(tf.keras.Model):\n", - " \n", - " def __init__(self, filters, size, apply_batchnorm=True):\n", - " super(DiscDownsample, self).__init__()\n", - " self.apply_batchnorm = apply_batchnorm\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - "\n", - " self.conv1 = tf.keras.layers.Conv2D(filters, \n", - " (size, size), \n", - " strides=2, \n", - " padding='same',\n", - " kernel_initializer=initializer,\n", - " use_bias=False)\n", - " if self.apply_batchnorm:\n", - " self.batchnorm = tf.keras.layers.BatchNormalization()\n", - " \n", - " def call(self, x, training):\n", - " x = self.conv1(x)\n", - " if self.apply_batchnorm:\n", - " x = self.batchnorm(x, training=training)\n", - " x = tf.nn.leaky_relu(x)\n", - " return x \n", - "\n", - "class Discriminator(tf.keras.Model):\n", - " \n", - " def __init__(self):\n", - " super(Discriminator, self).__init__()\n", - " initializer = tf.random_normal_initializer(0., 0.02)\n", - " \n", 
- " self.down1 = DiscDownsample(64, 4, False)\n", - " self.down2 = DiscDownsample(128, 4)\n", - " self.down3 = DiscDownsample(256, 4)\n", - " \n", - " # we are zero padding here with 1 because we need our shape to \n", - " # go from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512)\n", - " self.zero_pad1 = tf.keras.layers.ZeroPadding2D()\n", - " self.conv = tf.keras.layers.Conv2D(512, \n", - " (4, 4), \n", - " strides=1, \n", - " kernel_initializer=initializer, \n", - " use_bias=False)\n", - " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", - " \n", - " # shape change from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1)\n", - " self.zero_pad2 = tf.keras.layers.ZeroPadding2D()\n", - " self.last = tf.keras.layers.Conv2D(1, \n", - " (4, 4), \n", - " strides=1,\n", - " kernel_initializer=initializer)\n", - " \n", - " @tf.contrib.eager.defun\n", - " def call(self, inp, tar, training):\n", - " # concatenating the input and the target\n", - " x = tf.concat([inp, tar], axis=-1) # (bs, 256, 256, channels*2)\n", - " x = self.down1(x, training=training) # (bs, 128, 128, 64)\n", - " x = self.down2(x, training=training) # (bs, 64, 64, 128)\n", - " x = self.down3(x, training=training) # (bs, 32, 32, 256)\n", - "\n", - " x = self.zero_pad1(x) # (bs, 34, 34, 256)\n", - " x = self.conv(x) # (bs, 31, 31, 512)\n", - " x = self.batchnorm1(x, training=training)\n", - " x = tf.nn.leaky_relu(x)\n", - " \n", - " x = self.zero_pad2(x) # (bs, 33, 33, 512)\n", - " # don't add a sigmoid activation here since\n", - " # the loss function expects raw logits.\n", - " x = self.last(x) # (bs, 30, 30, 1)\n", - "\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "gDkA05NE6QMs" - }, - "outputs": [], - "source": [ - "# The call function of Generator and Discriminator have been decorated\n", - "# with tf.contrib.eager.defun()\n", - "# We get a performance speedup if defun is used (~25 seconds per epoch)\n", - "generator = Generator()\n", - "discriminator = Discriminator()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "source": [ - "## Define the loss functions and the optimizer\n", - "\n", - "* **Discriminator loss**\n", - " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", - " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n", - " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n", - " * Then the total_loss is the sum of real_loss and the generated_loss\n", - " \n", - "* **Generator loss**\n", - " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.\n", - " * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.\n", - " * This allows the generated image to become structurally similar to the target image.\n", - " * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "cyhxTuvJyIHV" - }, - "outputs": [], - "source": [ - "LAMBDA = 100" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "wkMNfBWlT-PV" - }, - "outputs": [], - "source": [ - "def discriminator_loss(disc_real_output, disc_generated_output):\n", - " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_real_output), \n", - " logits = disc_real_output)\n", - " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.zeros_like(disc_generated_output), \n", - " logits = disc_generated_output)\n", - "\n", - " total_disc_loss = real_loss + generated_loss\n", - "\n", - " return total_disc_loss" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "90BIcCKcDMxz" - }, - "outputs": [], - "source": [ - "def generator_loss(disc_generated_output, gen_output, target):\n", - " gan_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_generated_output),\n", - " logits = disc_generated_output) \n", - " # mean absolute error\n", - " l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", - "\n", - " total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n", - "\n", - " return total_gen_loss" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "iWCn_PVdEJZ7" - }, - "outputs": [], - "source": [ - "generator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)\n", - "discriminator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "aKUZnDiqQrAh" - }, - "source": [ - "## Checkpoints (Object-based saving)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "WJnftd5sQsv6" - }, - "outputs": [], - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", - " discriminator_optimizer=discriminator_optimizer,\n", - " generator=generator,\n", - " discriminator=discriminator)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "source": [ - "## Training\n", - "\n", - "* We start by iterating over the dataset\n", - "* The generator gets the input image and we get a generated output.\n", - "* The discriminator receives the input_image and the generated image as the first input. 
The second input is the input_image and the target_image.\n", - "* Next, we calculate the generator and the discriminator loss.\n", - "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n", - "\n", - "## Generate Images\n", - "\n", - "* After training, its time to generate some images!\n", - "* We pass images from the test dataset to the generator.\n", - "* The generator will then translate the input image into the output we expect.\n", - "* Last step is to plot the predictions and **voila!**" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "NS2GWywBbAWo" - }, - "outputs": [], - "source": [ - "EPOCHS = 200" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "RmdVsmvhPxyy" - }, - "outputs": [], - "source": [ - "def generate_images(model, test_input, tar):\n", - " # the training=True is intentional here since\n", - " # we want the batch statistics while running the model\n", - " # on the test dataset. If we use training=False, we will get \n", - " # the accumulated statistics learned from the training dataset\n", - " # (which we don't want)\n", - " prediction = model(test_input, training=True)\n", - " plt.figure(figsize=(15,15))\n", - "\n", - " display_list = [test_input[0], tar[0], prediction[0]]\n", - " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", - "\n", - " for i in range(3):\n", - " plt.subplot(1, 3, i+1)\n", - " plt.title(title[i])\n", - " # getting the pixel values between [0, 1] to plot it.\n", - " plt.imshow(display_list[i] * 0.5 + 0.5)\n", - " plt.axis('off')\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "2M7LmLtGEMQJ" - }, - "outputs": [], - "source": [ - "def train(dataset, epochs): \n", - " for epoch in range(epochs):\n", - " start = time.time()\n", - "\n", - " for input_image, target in dataset:\n", - "\n", - " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", - " gen_output = generator(input_image, training=True)\n", - "\n", - " disc_real_output = discriminator(input_image, target, training=True)\n", - " disc_generated_output = discriminator(input_image, gen_output, training=True)\n", - "\n", - " gen_loss = generator_loss(disc_generated_output, gen_output, target)\n", - " disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n", - "\n", - " generator_gradients = gen_tape.gradient(gen_loss, \n", - " generator.variables)\n", - " discriminator_gradients = disc_tape.gradient(disc_loss, \n", - " discriminator.variables)\n", - "\n", - " generator_optimizer.apply_gradients(zip(generator_gradients, \n", - " generator.variables))\n", - " discriminator_optimizer.apply_gradients(zip(discriminator_gradients, \n", - " discriminator.variables))\n", - "\n", - " if epoch % 1 == 0:\n", - " clear_output(wait=True)\n", - " for inp, tar in test_dataset.take(1):\n", - " generate_images(generator, inp, tar)\n", - " \n", - " # saving (checkpoint) the model every 20 epochs\n", - " if (epoch + 1) % 20 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - "\n", - " print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n", - " time.time()-start))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "a1zZmKmvOH85" 
- }, - "outputs": [], - "source": [ - "train(train_dataset, EPOCHS)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "kz80bY3aQ1VZ" - }, - "source": [ - "## Restore the latest checkpoint and test" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "4t4x69adQ5xb" - }, - "outputs": [], - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1RGysMU_BZhx" - }, - "source": [ - "## Testing on the entire test dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "KUgSnmy2nqSP" - }, - "outputs": [], - "source": [ - "# Run the trained model on the entire test dataset\n", - "for inp, tar in test_dataset:\n", - " generate_images(generator, inp, tar)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "3AJXOByaZVOf" - }, - "outputs": [], - "source": [ - "" + "This notebook has been moved to [https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/pix2pix.ipynb](https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/pix2pix.ipynb)" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index f2851d97223..a80f3d210a4 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -37,9 +37,6 @@ cuda_py_test( ], shard_count = 4, tags = [ - "noasan", # Fix b/118130911 - "nomsan", # Fix b/118130911 - "notsan", # Fix b/118130911 "optonly", "oss_serial", ], diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index cb207b8ddf3..a48d08b8a3a 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -83,17 +83,27 @@ cuda_py_test( name = "blocks_test", size = "medium", srcs = ["blocks_test.py"], - additional_deps = [ - ":blocks", - "//tensorflow:tensorflow_py", - ], + additional_deps = [":blocks_test_main_lib"], shard_count = 4, tags = [ - "no_oss", # b/123045964 + "no_oss", # TODO(b/132387200): Segfaulting "optonly", ], ) +py_library( + name = "blocks_test_main_lib", + testonly = True, + srcs = ["blocks_test.py"], + tags = [ + "optonly", + ], + deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], +) + cuda_py_test( name = "revnet_test", size = "medium", @@ -125,6 +135,7 @@ py_library( py_binary( name = "cifar_tfrecords", srcs = ["cifar_tfrecords.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", @@ -134,6 +145,7 @@ py_binary( py_binary( name = "main", srcs = ["main.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":main_lib"], ) @@ -153,6 +165,7 @@ py_library( py_binary( name = "main_estimator", srcs = ["main_estimator.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":cifar_input", diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops.py b/tensorflow/contrib/eager/python/examples/revnet/ops.py index 9ed5d363e6c..af17a22b4d1 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/ops.py +++ b/tensorflow/contrib/eager/python/examples/revnet/ops.py @@ 
-32,12 +32,14 @@ def downsample(x, filters, strides, axis=1): def pad_strides(strides, axis=1): """Convert length 2 to length 4 strides. - Needed since `tf.layers.Conv2D` uses length 2 strides, whereas operations - such as `tf.nn.avg_pool` use length 4 strides. + Needed since `tf.compat.v1.layers.Conv2D` uses length 2 strides, whereas + operations + such as `tf.nn.avg_pool2d` use length 4 strides. Args: strides: length 2 list/tuple strides for height and width axis: integer specifying feature dimension according to data format + Returns: length 4 strides padded with 1 on batch and channel dimension """ diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops_test.py b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py index 5bc2641faf5..e92c0fd9258 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/ops_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py @@ -47,7 +47,7 @@ class OpsTest(tf.test.TestCase): tape.watch(x) y = ops.downsample(x, filters=3, strides=(1, 1)) self.assertEqual(y.shape, x.shape) - dy = tf.random_normal(shape=[batch_size, 3, 32, 32]) + dy = tf.random_normal(shape=[batch_size, 32, 32, 3]) grad, = tape.gradient(y, [x], output_gradients=[dy]) self.assertEqual(grad.shape, x.shape) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index f4dbe7ac16f..aca0b2f05f6 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -8,6 +8,7 @@ load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "rnn_colorbot", srcs = ["rnn_colorbot.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":rnn_colorbot_lib"], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py index b7d8395e277..2955b94037f 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py @@ -19,17 +19,13 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.contrib.eager.python import tfe from tensorflow.contrib.eager.python.examples.rnn_colorbot import rnn_colorbot +from tensorflow.python.framework import test_util LABEL_DIMENSION = 5 -def device(): - return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0" - - def random_dataset(): batch_size = 64 time_steps = 10 @@ -53,7 +49,7 @@ class RNNColorbotTest(tf.test.TestCase): keep_prob=1.0) optimizer = tf.train.AdamOptimizer(learning_rate=.01) dataset = random_dataset() - with tf.device(device()): + with test_util.use_gpu(): rnn_colorbot.train_one_epoch(model, optimizer, dataset) def testTest(self): @@ -62,7 +58,7 @@ class RNNColorbotTest(tf.test.TestCase): label_dimension=LABEL_DIMENSION, keep_prob=1.0) dataset = random_dataset() - with tf.device(device()): + with test_util.use_gpu(): rnn_colorbot.test(model, dataset) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index 43a6ca526d3..ef683ce232b 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -8,6 +8,7 @@ load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "rnn_ptb", srcs = ["rnn_ptb.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":rnn_ptb_lib"], 
) @@ -32,7 +33,6 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], - tags = ["no_oss"], # b/123045964 ) cuda_py_test( diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 9b5a2c947b1..56aeb534230 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -44,7 +44,7 @@ layers = tf.keras.layers class RNN(tf.keras.Model): """A static RNN. - Similar to tf.nn.static_rnn, implemented as a class. + Similar to tf.compat.v1.nn.static_rnn, implemented as a class. """ def __init__(self, hidden_dim, num_layers, keep_ratio): diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 9b0fbaa6793..72f1829ffc4 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -17,6 +17,7 @@ py_test( name = "data_test", size = "small", srcs = ["data_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":data", diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py index 3bc3bb49bcb..72d23630cd9 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/data.py +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -179,7 +179,7 @@ def load_word_vectors(data_root, vocab): print("Loading word vectors...") - word2index = dict() + word2index = {} embed = [] embed.append([0] * WORD_VECTOR_LEN) # diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 3143270ccfe..6e47cc57051 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -30,7 +30,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf # pylint: disable=g-bad-import-order -import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util @@ -59,7 +58,7 @@ def _generate_synthetic_snli_data_batch(sequence_length, [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, 3, 2, 2]] * batch_size, dtype=np.int64).T) - if tfe.num_gpus(): + if test_util.is_gpu_available(): labels = labels.gpu() prem = prem.gpu() prem_trans = prem_trans.gpu() @@ -121,7 +120,7 @@ class SpinnTest(test_util.TensorFlowTestCase): def setUp(self): super(SpinnTest, self).setUp() - self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + self._test_device = "gpu:0" if test_util.is_gpu_available() else "cpu:0" self._temp_data_dir = tempfile.mkdtemp() def tearDown(self): @@ -436,7 +435,7 @@ class SpinnTest(test_util.TensorFlowTestCase): class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): def benchmarkEagerSpinnSNLIClassifier(self): - test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + test_device = "gpu:0" if test_util.is_gpu_available() else "cpu:0" with tf.device(test_device): burn_in_iterations = 2 benchmark_iterations = 10 diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index b32501c2e80..7885eb84f04 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ 
b/tensorflow/contrib/eager/python/metrics_impl.py @@ -56,9 +56,9 @@ class Metric(trackable.Trackable): inputs = ... # Some tensors to compute the metric on. m_update = m(inputs) # Variables defined in first call, so get the initialization op afterwards. - m_init = m.init_variables() # or tf.global_variables_initializer() + m_init = m.init_variables() # or tf.compat.v1.global_variables_initializer() m_result = m.result() - with tf.Session() as sess: + with tf.compat.v1.Session() as sess: sess.run(m_init) for input in ...: sess.run(m_update) @@ -67,12 +67,12 @@ class Metric(trackable.Trackable): Example use with graph execution with placeholders and feed_dict: ```python m = SomeMetric(...) - m_placeholder = tf.placeholder(...) + m_placeholder = tf.compat.v1.placeholder(...) m_update = m(m_placeholder) # Variables defined in first call, so get the initialization op afterwards. - m_init = m.init_variables() # or tf.global_variables_initializer() + m_init = m.init_variables() # or tf.compat.v1.global_variables_initializer() m_result = m.result() - with tf.Session() as sess: + with tf.compat.v1.Session() as sess: sess.run(m_init) for input in ...: sess.run(m_update, feed_dict={m_placeholder: input}) @@ -406,8 +406,8 @@ class CategoricalAccuracy(Mean): """Calculates how often `predictions` matches `labels`. This class is compatible with `tf.keras.losses.categorical_crossentropy`, - `tf.nn.softmax_cross_entropy_with_logits_v2`, - `tf.losses.softmax_cross_entropy`. + `tf.nn.softmax_cross_entropy_with_logits`, + `tf.compat.v1.losses.softmax_cross_entropy`. Attributes: name: name of the accuracy object. @@ -450,7 +450,7 @@ class BinaryAccuracy(Mean): """Calculates how often `predictions` matches `labels`. This class is compatible with `tf.keras.losses.binary_crossentropy`, - `tf.losses.sigmoid_cross_entropy`, + `tf.compat.v1.losses.sigmoid_cross_entropy`, `tf.nn.sigmoid_cross_entropy_with_logits`. If there is more than one label, this will become multi-label classification. @@ -505,7 +505,7 @@ class SparseAccuracy(Mean): This class is compatible with `tf.keras.losses.sparse_categorical_crossentropy`, `tf.nn.sparse_softmax_cross_entropy_with_logits`, - `tf.losses.sparse_softmax_cross_entropy`. + `tf.compat.v1.losses.sparse_softmax_cross_entropy`. Attributes: name: name of the accuracy object diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 5cc0c4f23d9..363e2191c3d 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -24,7 +24,7 @@ import weakref from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras import backend from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -48,6 +48,7 @@ def _network_name_scope_naming(current_variable_scope): Args: current_variable_scope: A VariableScope object. + Returns: A name scope name. """ @@ -66,8 +67,7 @@ _NETWORK_DEPRECATION_MESSAGE = ( "`Layer` instances, including those from `tf.layers`, but switching to " "the `tf.keras.layers` versions along with the migration to " "`tf.keras.Model` is recommended, since it will preserve variable names. " - "Feel free to import it with an alias to avoid excess typing :)." 
-) + "Feel free to import it with an alias to avoid excess typing :).") class Network(base.Layer): @@ -99,8 +99,10 @@ class Network(base.Layer): def __init__(self, name): super(TwoLayerNetwork, self).__init__(name=name) - self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,))) - self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,))) + self.layer_one = self.track_layer(tf.compat.v1.layers.Dense(16, + input_shape=(8,))) + self.layer_two = self.track_layer(tf.compat.v1.layers.Dense(1, + input_shape=(16,))) def call(self, inputs): return self.layer_two(self.layer_one(inputs)) @@ -116,7 +118,7 @@ class Network(base.Layer): ``` This example prints variable names, one kernel and one bias per - `tf.layers.Dense` layer: + `tf.compat.v1.layers.Dense` layer: ``` ['net/dense/kernel:0', @@ -125,9 +127,10 @@ class Network(base.Layer): 'net/dense_1/bias:0'] ``` - These variables can be passed to a `Saver` (`tf.train.Saver`, or + These variables can be passed to a `Saver` (`tf.compat.v1.train.Saver`, or `tf.contrib.eager.Saver` when executing eagerly) to save or restore the - `Network`, typically alongside a global step and `tf.train.Optimizer` + `Network`, typically alongside a global step and + `tf.compat.v1.train.Optimizer` variables when checkpointing during training. Note that the semantics of calling a `Network` with graph execution (i.e. not @@ -151,14 +154,12 @@ class Network(base.Layer): Args: name: The name to use for this `Network`. If specified, it must be unique - in the context where this `Network` is first - (1) added to another `Network` (in which case it must not share a name - with other `Layers` added to that `Network`), or - (2) built/called (in which case no other 'top-level' `Network`s may - share this name). - If unspecified or None, the `Network` will be named using its class - name, with a number appended if necessary for uniqueness (e.g. MyNetwork - -> 'my_network_1'). + in the context where this `Network` is first (1) added to another + `Network` (in which case it must not share a name with other `Layers` + added to that `Network`), or (2) built/called (in which case no other + 'top-level' `Network`s may share this name). If unspecified or None, the + `Network` will be named using its class name, with a number appended if + necessary for uniqueness (e.g. MyNetwork -> 'my_network_1'). Raises: ValueError: If `name` is not valid. 
Note that some naming errors will @@ -167,7 +168,7 @@ class Network(base.Layer): if context.executing_eagerly(): logging.warning( ("** tfe.Network is deprecated and will be removed in a future " - "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE) + "version.\n\n%s"), _NETWORK_DEPRECATION_MESSAGE) if isinstance(name, variable_scope.VariableScope): raise ValueError("VariableScopes are not valid Network names.") if name is not None and "/" in name: @@ -192,8 +193,8 @@ class Network(base.Layer): def _gather_saveables_for_checkpoint(self): raise NotImplementedError( - "tfe.Network does not support object-based checkpointing.\n\n%s" - % _NETWORK_DEPRECATION_MESSAGE) + "tfe.Network does not support object-based checkpointing.\n\n%s" % + _NETWORK_DEPRECATION_MESSAGE) def _name_scope_name(self, current_variable_scope): """Overrides Layer op naming to match variable naming.""" @@ -220,22 +221,26 @@ class Network(base.Layer): avoid_names = parent_network._owned_layers name_uid_map = parent_network._sub_layer_name_uids else: - name_uid_map = base_layer_utils.get_default_graph_uid_map() + name_uid_map = backend.get_default_graph_uid_map() # Figure out which names we have to avoid based on which variable scope # we're nested in. strip_name = self._default_parent_variable_scope.name if strip_name: strip_name += "/" + def _strip_on_init_scope(name): if name.startswith(strip_name): return name[len(strip_name):] else: return None + avoid_names = set( _strip_on_init_scope(name) - for name in self._variable_scope_counts_on_init.keys() if name) + for name in self._variable_scope_counts_on_init.keys() + if name) self._name, self._base_name = self._make_unique_name( - name_uid_map=name_uid_map, avoid_names=avoid_names, + name_uid_map=name_uid_map, + avoid_names=avoid_names, namespace=self._default_parent_variable_scope.name, zero_based=True) if self._first_parent is None or (self._first_parent # False = no parent @@ -272,8 +277,8 @@ class Network(base.Layer): if expected_scope_name in self._variable_scope_counts_on_init: raise ValueError( ("A Network named '%s' already exists (or a variable_scope was " - "created with this name). Names must be unique.") % ( - self._name,)) + "created with this name). Names must be unique.") % + (self._name,)) # Make sure variables with this prefix will be unique. with variable_scope.variable_scope( None, use_resource=True, default_name=self._name) as scope: @@ -287,19 +292,16 @@ class Network(base.Layer): if scope_suffix != self._name: raise ValueError( ("A Network named '%s' already exists (or a variable_scope was " - "created with this name). Names must be unique.") % ( - self._name,)) - if (first_parent - and scope_prefix[:-1] != first_parent.scope_name): + "created with this name). Names must be unique.") % + (self._name,)) + if (first_parent and scope_prefix[:-1] != first_parent.scope_name): raise ValueError( ("Network variable names must match a nesting of sub-Network " "names. Expected prefix '%s' from parent network, but got " "'%s' when attempting to create a variable_scope for Network " "'%s'. Likely an explicit variable_scope was inserted into " - "the nesting.") % ( - first_parent.scope_name, - scope_prefix[:-1], - self._name)) + "the nesting.") % + (first_parent.scope_name, scope_prefix[:-1], self._name)) elif not first_parent and scope_prefix: # For the case when this Network is not nested inside any other # Network, but is in a variable_scope. 
This Network's name takes on @@ -323,15 +325,13 @@ class Network(base.Layer): raise ValueError( ("The parent of a Layer added to Network %s was garbage collected " "before the Layer was built. If this limitation bothers you " - "please file a feature request.") % - (self.name,)) + "please file a feature request.") % (self.name,)) with variable_scope.variable_scope(parent_scope): # Horrid hack to make Layer variable names which are direct # sub-layers of Networks conform to the Network variable naming # conventions. with variable_scope.variable_scope( - None, use_resource=True, - default_name=sublayer.name) as sub_scope: + None, use_resource=True, default_name=sublayer.name) as sub_scope: sublayer._scope = sub_scope # Also switch op naming for this Layer to match Network conventions, # i.e. op naming matching variable naming. @@ -354,7 +354,7 @@ class Network(base.Layer): `Network` can export a complete list of variables. Args: - layer: A `tf.layers.Layer` object. + layer: A `tf.compat.v1.layers.Layer` object. Returns: The passed in `layer`. @@ -398,9 +398,8 @@ class Network(base.Layer): ) layer._first_parent = weakref.ref(self) self._non_network_sublayers.append(layer) - if (not layer.built - and layer._first_parent - and self is layer._first_parent()): + if (not layer.built and layer._first_parent and + self is layer._first_parent()): if layer.name in self._owned_layers: if self._owned_layers[layer.name] is layer: return layer @@ -412,7 +411,7 @@ class Network(base.Layer): return layer def get_layer(self, name=None, index=None): - """Get a contained `tf.layers.Layer` either by name or index. + """Get a contained `tf.compat.v1.layers.Layer` either by name or index. Args: name: String matching one of the names of a contained `Layer`. Note that @@ -420,11 +419,11 @@ class Network(base.Layer): layer sharing (i.e. adding a `Layer` to this `Network` which was already added to another `Network`). The lowest index `Layer` with a matching name will be returned. - index: Integer in [0, number of layers). Layers are assigned an index - by the order they are added. + index: Integer in [0, number of layers). Layers are assigned an index by + the order they are added. Returns: - A `tf.layers.Layer` object. + A `tf.compat.v1.layers.Layer` object. Raises: ValueError: If neither or both of 'index' or 'name' is specified, or the @@ -490,8 +489,14 @@ class Network(base.Layer): def layers(self): return self._layers - def add_variable(self, name, shape, dtype=None, initializer=None, - regularizer=None, trainable=True, constraint=None): + def add_variable(self, + name, + shape, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + constraint=None): raise RuntimeError( "add_variable not supported in Network class yet. Please file an issue " "at https://github.com/tensorflow/tensorflow/issues/new if this is " @@ -532,7 +537,7 @@ class Sequential(Network): Args: layers_funcs: An optional sequence where each element is either a - tf.layers.Layer object or a callable. + tf.compat.v1.layers.Layer object or a callable. name: An optional string name to use for this Network. """ @@ -571,7 +576,6 @@ class Sequential(Network): _DeferredRestoration = collections.namedtuple( - "_DeferredRestoration", [ # The map_func to use (either user-specified or the default). 
@@ -595,9 +599,9 @@ _DeferredRestoration = collections.namedtuple( ]) -def _default_naming_conflict_error_message( - mapped_name, first_variable, second_variable, - network_name, network_scope_name): +def _default_naming_conflict_error_message(mapped_name, first_variable, + second_variable, network_name, + network_scope_name): return ( ("The default checkpoint variable name mapping strategy for Network " "'%s' resulted in a naming conflict. We attempted to strip off the " @@ -609,17 +613,15 @@ def _default_naming_conflict_error_message( "`map_func=lambda n: n` to save and restore to use fully qualified " "variable names in the checkpoint, although this will require that the " "variable prefix of the Network being restored into is also '%s'. You " - "may alternatively write an arbitrary mapping.") - % ( - network_name, network_scope_name, mapped_name, - first_variable._shared_name, - second_variable._shared_name, network_scope_name - )) + "may alternatively write an arbitrary mapping.") % + (network_name, network_scope_name, mapped_name, + first_variable._shared_name, second_variable._shared_name, + network_scope_name)) -def _restore_custom_map_func_error_message( - mapped_name, first_variable, second_variable, - network_name, network_scope_name): +def _restore_custom_map_func_error_message(mapped_name, first_variable, + second_variable, network_name, + network_scope_name): return ( ("The map_func passed to restore_network_checkpoint for the Network '%s' " "resulted in two variables named '%s' (originally '%s' and '%s'). Since " @@ -631,11 +633,9 @@ def _restore_custom_map_func_error_message( "of the Network. For reference, variables created by sub-Layers " "of this Network are prefixed with '%s', but if they are " "re-used after being added to another Network they will have " - "that Network's full variable prefix instead.") % ( - network_name, mapped_name, - first_variable._shared_name, - second_variable._shared_name, - network_scope_name)) + "that Network's full variable prefix instead.") % + (network_name, mapped_name, first_variable._shared_name, + second_variable._shared_name, network_scope_name)) def _make_custom_getter_for_deferred_restorations(): @@ -651,9 +651,13 @@ def _make_custom_getter_for_deferred_restorations(): """ deferred_restorations = [] - def _custom_getter(getter, name, shape=None, dtype=None, + def _custom_getter(getter, + name, + shape=None, + dtype=None, initializer=None, - *args, **kwargs): + *args, + **kwargs): """A custom getter which processes deferred restorations.""" # Iterate over restorations, newest first (newer restorations will take # precedence over older restorations, just like with immediate restorations @@ -661,15 +665,14 @@ def _make_custom_getter_for_deferred_restorations(): delayed_restoration = None found_value = False value_to_restore = None - for delayed_restoration in reversed( - deferred_restorations): + for delayed_restoration in reversed(deferred_restorations): checkpoint_name = delayed_restoration.map_func(name) - if (checkpoint_name - in delayed_restoration.checkpointed_variables_to_restore): + if (checkpoint_name in + delayed_restoration.checkpointed_variables_to_restore): found_value = True value_to_restore = ( - delayed_restoration.checkpointed_variables_to_restore[ - checkpoint_name]) + delayed_restoration + .checkpointed_variables_to_restore[checkpoint_name]) if found_value: break # value_to_restore may be False because this variable is not in any @@ -679,8 +682,13 @@ def _make_custom_getter_for_deferred_restorations(): if 
found_value and value_to_restore is not None: initializer = value_to_restore shape = None - variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, - *args, **kwargs) + variable = getter( + name, + shape=shape, + dtype=dtype, + initializer=initializer, + *args, + **kwargs) if found_value and value_to_restore is not None: # Mark as already restored from this checkpoint. delayed_restoration.checkpointed_variables_to_restore[ @@ -697,8 +705,8 @@ def _make_custom_getter_for_deferred_restorations(): raise ValueError( _restore_custom_map_func_error_message( mapped_name=checkpoint_name, - first_variable=delayed_restoration.restored_variables[ - checkpoint_name], + first_variable=delayed_restoration + .restored_variables[checkpoint_name], second_variable=variable, network_name=delayed_restoration.network_name, network_scope_name=delayed_restoration.network_scope_name)) @@ -706,12 +714,13 @@ def _make_custom_getter_for_deferred_restorations(): raise ValueError( _default_naming_conflict_error_message( mapped_name=checkpoint_name, - first_variable=delayed_restoration.restored_variables[ - checkpoint_name], + first_variable=delayed_restoration + .restored_variables[checkpoint_name], second_variable=variable, network_name=delayed_restoration.network_name, network_scope_name=delayed_restoration.network_scope_name)) return variable + return _custom_getter, deferred_restorations @@ -724,6 +733,7 @@ def _make_prefix_stripping_map_fn(scope_name): Args: scope_name: The Network.scope_name to strip from variables. + Returns: A scope_name-stripping default `map_fn` for the Network. """ @@ -735,8 +745,9 @@ def _make_prefix_stripping_map_fn(scope_name): and leaves other variable names fully qualified in the checkpoint. Args: - original_variable_name: The _shared_name of the variable (no :0 - suffix) to map. + original_variable_name: The _shared_name of the variable (no :0 suffix) to + map. + Returns: The checkpoint name of the variable. """ @@ -749,28 +760,30 @@ def _make_prefix_stripping_map_fn(scope_name): return _strip_variable_prefix -@deprecation.deprecated(date=None, instructions=( - "Please inherit from tf.keras.Model instead of tfe.Network, and use " - "tf.keras.Model.save_weights.")) -def save_network_checkpoint( - network, save_path, global_step=None, map_func=None): +@deprecation.deprecated( + date=None, + instructions=( + "Please inherit from tf.keras.Model instead of tfe.Network, and use " + "tf.keras.Model.save_weights.")) +def save_network_checkpoint(network, save_path, global_step=None, + map_func=None): """Save variables from the Network to a checkpoint. Args: network: A Network object to save. - save_path: Either a checkpoint prefix or the name of a directory to save - the checkpoint in (in which case the checkpoint will be named based on - the Network name). + save_path: Either a checkpoint prefix or the name of a directory to save the + checkpoint in (in which case the checkpoint will be named based on the + Network name). global_step: The global step to use when naming the checkpoint. If None - (default), we will first try to get the default global step. If that - fails because no default global step exists, then the checkpoint is - created without a global step suffix. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 
'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. + (default), we will first try to get the default global step. If that fails + because no default global step exists, then the checkpoint is created + without a global step suffix. + map_func: A function mapping fully qualified variable names (e.g. + 'my_network_1/dense_1/kernel') to names in the checkpoint. By default (if + `map_func=None`), the variable prefix for the network being restored + (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped and all + other variable names (shared with other Networks) are left unchanged. + Returns: The checkpoint prefix for the saved checkpoint, which may be passed to `Network.restore`. @@ -801,12 +814,13 @@ def save_network_checkpoint( # full variable names in the checkpoint. This could be odd for deeply # nested sub-Networks (since the full prefix from the nesting would # get added), so for now we'll let the user deal with this case. - raise ValueError(_default_naming_conflict_error_message( - mapped_name=mapped_name, - first_variable=variable_map[mapped_name], - second_variable=variable, - network_name=network.name, - network_scope_name=network.scope_name)) + raise ValueError( + _default_naming_conflict_error_message( + mapped_name=mapped_name, + first_variable=variable_map[mapped_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) else: # The user passed their own problematic map_func. raise ValueError( @@ -816,17 +830,17 @@ def save_network_checkpoint( "the Network. For reference, variables created by sub-Layers of " "this Network are prefixed with '%s', but if they are re-used " "after being added to another Network, they will have that " - "Network's full variable prefix instead.") % ( - network.name, mapped_name, - variable_map[mapped_name]._shared_name, - variable._shared_name, - network.scope_name)) + "Network's full variable prefix instead.") % + (network.name, mapped_name, variable_map[mapped_name]._shared_name, + variable._shared_name, network.scope_name)) if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() return saver_lib.Saver(variable_map).save( - sess=sess, save_path=save_path, write_meta_graph=False, + sess=sess, + save_path=save_path, + write_meta_graph=False, global_step=global_step) @@ -869,10 +883,10 @@ def _restore_existing_variables(network, save_path, map_func, user_map_func): Args: network: A Network object to restore. save_path: The checkpoint prefix or directory to read from. - map_func: The function to use when mapping from variable names to - checkpoint names. - user_map_func: The original map_func passed by the user, for error - checking. + map_func: The function to use when mapping from variable names to checkpoint + names. + user_map_func: The original map_func passed by the user, for error checking. 
+ Returns: A dictionary mapping from checkpoint names to variable objects which have been restored (for bookkeeping to avoid deferred restorations on these @@ -886,21 +900,23 @@ def _restore_existing_variables(network, save_path, map_func, user_map_func): if existing_variables_by_checkpoint_name.setdefault( checkpoint_name, variable) is not variable: if user_map_func is None: - raise ValueError(_default_naming_conflict_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=network.name, - network_scope_name=network.scope_name)) + raise ValueError( + _default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) else: - raise ValueError(_restore_custom_map_func_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=network.name, - network_scope_name=network.scope_name)) + raise ValueError( + _restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) if existing_variables_by_checkpoint_name: if context.executing_eagerly(): sess = None @@ -951,9 +967,11 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func, _add_deferred_restoration(network, deferred_restoration) -@deprecation.deprecated(date=None, instructions=( - "Please inherit from tf.keras.Model instead of tfe.Network, and use " - "tf.keras.Model.load_weights.")) +@deprecation.deprecated( + date=None, + instructions=( + "Please inherit from tf.keras.Model instead of tfe.Network, and use " + "tf.keras.Model.load_weights.")) def restore_network_checkpoint(network, save_path, map_func=None): """Restore the Network from a checkpoint. @@ -976,13 +994,13 @@ def restore_network_checkpoint(network, save_path, map_func=None): network: A Network object to restore. save_path: The return value of `tfe.save_network_checkpoint`, or a directory to search for a checkpoint. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. Note that this is the _same_ map_func as - `tfe.save_network_checkpoint`, not an inverse mapping. + map_func: A function mapping fully qualified variable names (e.g. + 'my_network_1/dense_1/kernel') to names in the checkpoint. By default (if + `map_func=None`), the variable prefix for the network being restored + (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped and all + other variable names (shared with other Networks) are left unchanged. Note + that this is the _same_ map_func as `tfe.save_network_checkpoint`, not an + inverse mapping. 
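To make the save/restore contract described above concrete, here is a minimal eager-mode sketch (the Network subclass, layer sizes, and checkpoint path are illustrative and not taken from this patch) that round-trips a `tfe.Network` through `save_network_checkpoint` and `restore_network_checkpoint` with the default prefix-stripping `map_func`:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.compat.v1.enable_eager_execution()


class TwoLayerNetwork(tfe.Network):  # hypothetical example Network

  def __init__(self):
    super(TwoLayerNetwork, self).__init__()
    self.l1 = self.track_layer(tf.compat.v1.layers.Dense(4))
    self.l2 = self.track_layer(tf.compat.v1.layers.Dense(1))

  def call(self, x):
    return self.l2(self.l1(x))


net = TwoLayerNetwork()
net(tf.ones([1, 3]))  # Build the Network so its variables exist before saving.
prefix = tfe.save_network_checkpoint(net, '/tmp/net_ckpt')  # placeholder path

restored = TwoLayerNetwork()
tfe.restore_network_checkpoint(restored, prefix)
restored(tf.ones([1, 3]))  # Deferred restorations apply as variables are created.
```

As the docstrings above note, variables that do not yet exist when `restore_network_checkpoint` is called are restored on creation, so restoring into a freshly constructed, not-yet-built Network is also supported.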
""" network._finalize_name(parent_network=False) network._set_scope() # scope_name should be available to map_funcs diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py index 258f0a19309..d221d9790a6 100644 --- a/tensorflow/contrib/eager/python/parameter_server.py +++ b/tensorflow/contrib/eager/python/parameter_server.py @@ -144,7 +144,7 @@ class SharedVariable(resource_variable_ops.ResourceVariable): with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: # pylint: disable=protected-access - handle_name = ops._name_from_scope_name(name) + handle_name = ops.name_from_scope_name(name) shared_name = handle_name if init_from_fn: # Use attr_scope and device(None) to simulate the behavior of diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index f540d9b37b6..fc78e46a5b1 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -23,6 +23,7 @@ import os import numpy as np +from tensorflow.python import pywrap_tensorflow from tensorflow.contrib.eager.python import parameter_server from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 @@ -31,6 +32,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import remote from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -91,10 +93,11 @@ class RemoteExecutionTest(test.TestCase): def setUp(self): # Start the local server. + local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie() context.set_server_def( server_def=get_server_def( JOB_NAME, - local_server_port=0, + local_server_port=local_port, remote_server_addresses=[ self._cached_server1_target, self._cached_server2_target ], @@ -220,12 +223,10 @@ class RemoteExecutionTest(test.TestCase): self.assertEqual(y.device, "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME) + @test_util.run_gpu_only @run_sync_and_async def testGPUToRemoteCopy(self): """Tests that the remote copy happens satisfactorily.""" - if not context.context().num_gpus(): - self.skipTest("No GPUs.") - x1 = array_ops.ones([2, 2]).gpu() with ops.device("/job:remote_device/replica:0/task:1/device:CPU:0"): diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index 1d0d6c6c14c..8649e56556a 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -85,7 +85,7 @@ def restore_variables_on_create(save_path, map_func=None): raise ValueError("map_func must be callable.") map_func_wrapper = lambda self, x: map_func(x) - ckpt_var_cache = dict() + ckpt_var_cache = {} reader = checkpoint_utils.load_checkpoint(save_path) for k, _ in checkpoint_utils.list_variables(save_path): ckpt_var_cache[k] = reader.get_tensor(k) @@ -114,7 +114,7 @@ def restore_variables_on_create(save_path, map_func=None): class Saver(object): - """A tf.train.Saver adapter for use when eager execution is enabled. + """A tf.compat.v1.train.Saver adapter for use when eager execution is enabled. `Saver`'s name-based checkpointing strategy is fragile. 
Please switch to `tf.train.Checkpoint` or `tf.keras.Model.save_weights`, which perform a more @@ -123,9 +123,9 @@ class Saver(object): """ def __init__(self, var_list): - """A tf.train.Saver adapter for use when eager execution is enabled. + """A tf.compat.v1.train.Saver adapter for use when eager execution is enabled. - The API, and on-disk format, mimic tf.train.Saver except that no + The API, and on-disk format, mimic tf.compat.v1.train.Saver except that no Session is needed. Args: @@ -173,13 +173,14 @@ class Saver(object): def get_optimizer_variables(optimizer): - """Returns a list of variables for the given `tf.train.Optimizer`. + """Returns a list of variables for the given `tf.compat.v1.train.Optimizer`. Equivalent to `optimizer.variables()`. Args: - optimizer: An instance of `tf.train.Optimizer` which has created variables - (typically after a call to `Optimizer.minimize`). + optimizer: An instance of `tf.compat.v1.train.Optimizer` which has created + variables (typically after a call to `Optimizer.minimize`). + Returns: A list of variables which have been created by the `Optimizer`. """ diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index df5b059448f..8080d954eb7 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -16,7 +16,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. -To use, at program startup, call `tf.enable_eager_execution()`. +To use, at program startup, call `tf.compat.v1.enable_eager_execution()`. @@metrics @@ -138,7 +138,7 @@ from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable -from tensorflow.python.training.tracking.util import Checkpoint +from tensorflow.python.training.tracking.util import CheckpointV1 as Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 6881fabdc09..2e44ff4096a 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import tempfile from tensorflow.contrib.eager.python import tfe -from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -41,9 +40,10 @@ class TFETest(test_util.TensorFlowTestCase): self.assertAllEqual([[4.]], y.numpy()) def testInstantError(self): - if context.num_gpus(): + if test_util.is_gpu_available(): # TODO(nareshmodi): make this test better self.skipTest("Gather doesn't do index checking on GPUs") + with self.assertRaisesRegexp(errors.InvalidArgumentError, r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) @@ -79,10 +79,8 @@ class TFETest(test_util.TensorFlowTestCase): grad = tfe.gradients_function(f) self.assertEquals([12], [x.numpy() for x in grad(3.)]) + @test_util.run_gpu_only def testGPU(self): - if tfe.num_gpus() <= 0: - self.skipTest('No GPUs available') - # tf.Tensor.as_gpu_device() moves a tensor to GPU. 
x = constant_op.constant([[1., 2.], [3., 4.]]).gpu() # Alternatively, tf.device() as a context manager places tensors and diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index da2479a0b7b..ab510b86d15 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -201,6 +201,7 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/ops/kmeans_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", tags = ["notsan"], diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index d48b89cbacc..505d8d731fa 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -613,7 +613,8 @@ class _InitializeClustersOpFactory(object): inp = nn_impl.l2_normalize(inp, dim=1) return gen_clustering_ops.kmeans_plus_plus_initialization( inp, - math_ops.to_int64(self._num_remaining), self._random_seed, + math_ops.cast(self._num_remaining, dtypes.int64), + self._random_seed, self._kmeans_plus_plus_num_retries) def _kmc2_multiple_centers(self): diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 7ab70fbcfd7..5c55f7f597b 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -134,7 +134,7 @@ class WALSModel(object): # model_init_op is passed to Supervisor. Chief trainer runs it. Other # trainers wait. - sv = tf.train.Supervisor(is_chief=is_chief, + sv = tf.compat.v1.train.Supervisor(is_chief=is_chief, ..., init_op=tf.group(..., model_init_op, ...), ...) ... @@ -912,7 +912,7 @@ class WALSModel(object): total_rhs = ( self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul( new_sp_input, right, adjoint_a=transpose_input)) - # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of + # TODO(rmlarsen): handle transposing in tf.linalg.solve instead of # transposing explicitly. # TODO(rmlarsen): multi-thread tf.matrix_solve. new_left_values = array_ops.transpose( diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index 9f0664dfe5b..000b9832aa4 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -53,7 +53,7 @@ def _covariance(x, diag): A Tensor representing the covariance of x. In the case of diagonal matrix just the diagonal is returned. """ - num_points = math_ops.to_float(array_ops.shape(x)[0]) + num_points = math_ops.cast(array_ops.shape(x)[0], dtypes.float32) x -= math_ops.reduce_mean(x, 0, keepdims=True) if diag: cov = math_ops.reduce_sum( @@ -297,8 +297,9 @@ class GmmAlgorithm(object): cholesky, array_ops.transpose( diff, perm=[0, 2, 1]), lower=True)) diag_m = array_ops.transpose(math_ops.reduce_sum(x_mu_cov, 1)) - self._probs[shard_id] = -0.5 * (diag_m + math_ops.to_float(self._dimensions) - * math_ops.log(2 * np.pi) + log_det_covs) + self._probs[shard_id] = ( + -0.5 * (diag_m + math_ops.cast(self._dimensions, dtypes.float32) * + math_ops.log(2 * np.pi) + log_det_covs)) def _define_diag_covariance_probs(self, shard_id, shard): """Defines the diagonal covariance probabilities per example in a class. 
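For readers untangling the expression in the hunk above, the full-covariance branch is just the standard multivariate Gaussian log-density, \\(-0.5\,(d\log 2\pi + \log|\Sigma| + (x-\mu)^\top\Sigma^{-1}(x-\mu))\\), evaluated through a Cholesky factor. A small NumPy sketch of the same quantity (an illustration only, not the TensorFlow implementation):

```python
import numpy as np


def log_gaussian_density(x, mu, sigma):
  """Log N(x; mu, sigma) for x of shape [n, d], mu of shape [d], sigma [d, d]."""
  d = x.shape[-1]
  chol = np.linalg.cholesky(sigma)            # sigma = L L^T
  diff = np.linalg.solve(chol, (x - mu).T)    # z = L^{-1} (x - mu); general solve here,
                                              # the TF code uses a triangular solve.
  maha = np.sum(diff ** 2, axis=0)            # (x - mu)^T sigma^{-1} (x - mu)
  log_det = 2.0 * np.sum(np.log(np.diag(chol)))  # log|sigma|
  return -0.5 * (maha + d * np.log(2.0 * np.pi) + log_det)
```

The three terms map directly onto `diag_m`, `math_ops.cast(self._dimensions, dtypes.float32) * math_ops.log(2 * np.pi)`, and `log_det_covs` in the updated code.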
@@ -320,7 +321,8 @@ class GmmAlgorithm(object): x2_cov = math_ops.matmul(x2, cov_expanded) x2_cov = array_ops.transpose(array_ops.squeeze(x2_cov, [2])) self._probs[shard_id] = -0.5 * ( - math_ops.to_float(self._dimensions) * math_ops.log(2.0 * np.pi) + + math_ops.cast(self._dimensions, dtypes.float32) * + math_ops.log(2.0 * np.pi) + array_ops.transpose(det_expanded) + x2_cov) def _define_log_prob_operation(self, shard_id, shard): @@ -400,7 +402,8 @@ class GmmAlgorithm(object): # Update alpha. if 'w' in self._params: final_points_in_k = points_in_k / num_batches - num_examples = math_ops.to_float(math_ops.reduce_sum(final_points_in_k)) + num_examples = math_ops.cast(math_ops.reduce_sum(final_points_in_k), + dtypes.float32) self._alpha_op = self._alpha.assign(final_points_in_k / (num_examples + MEPS)) else: diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 3eb396a29cc..260ea5bf127 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -81,9 +81,9 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): Args: init_op: An op that, when run, will choose some initial cluster centers. - This op may need to be run multiple times to choose all the centers. + This op may need to be run multiple times to choose all the centers. is_initialized_var: A boolean variable reporting whether all initial - centers have been chosen. + centers have been chosen. is_chief: A boolean specifying whether this task is the chief. """ self._init_op = init_op @@ -113,8 +113,8 @@ def _parse_features_if_necessary(features, feature_columns): features: The input features. feature_columns: An optionable iterable containing all the feature columns used by the model. All items in the set should be feature column instances - that can be passed to `tf.feature_column.input_layer`. If this is None, - all features will be used. + that can be passed to `tf.compat.v1.feature_column.input_layer`. If this + is None, all features will be used. Returns: If `features` is a dict of `k` features (optionally filtered by @@ -255,7 +255,7 @@ class KMeansClustering(estimator.Estimator): points = np.random.uniform(0, 1000, [num_points, dimensions]) def input_fn(): - return tf.train.limit_epochs( + return tf.compat.v1.train.limit_epochs( tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1) num_clusters = 5 @@ -358,26 +358,25 @@ class KMeansClustering(estimator.Estimator): argument is ignored if `initial_clusters` is a tensor or numpy array. model_dir: The directory to save the model results and log files. initial_clusters: Specifies how the initial cluster centers are chosen. - One of the following: - * a tensor or numpy array with the initial cluster centers. - * a callable `f(inputs, k)` that selects and returns up to `k` centers - from an input batch. `f` is free to return any number of centers - from `0` to `k`. It will be invoked on successive input batches - as necessary until all `num_clusters` centers are chosen. + One of the following: * a tensor or numpy array with the initial cluster + centers. * a callable `f(inputs, k)` that selects and returns up to + `k` centers from an input batch. `f` is free to return any number of + centers from `0` to `k`. It will be invoked on successive input + batches as necessary until all `num_clusters` centers are chosen. * `KMeansClustering.RANDOM_INIT`: Choose centers randomly from an input - batch. 
If the batch size is less than `num_clusters` then the - entire batch is chosen to be initial cluster centers and the - remaining centers are chosen from successive input batches. + batch. If the batch size is less than `num_clusters` then the entire + batch is chosen to be initial cluster centers and the remaining + centers are chosen from successive input batches. * `KMeansClustering.KMEANS_PLUS_PLUS_INIT`: Use kmeans++ to choose - centers from the first input batch. If the batch size is less - than `num_clusters`, a TensorFlow runtime error occurs. + centers from the first input batch. If the batch size is less than + `num_clusters`, a TensorFlow runtime error occurs. distance_metric: The distance metric used for clustering. One of: * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance - between vectors `u` and `v` is defined as \\(||u - v||_2\\) - which is the square root of the sum of the absolute squares of - the elements' difference. + between vectors `u` and `v` is defined as \\(||u - v||_2\\) which is + the square root of the sum of the absolute squares of the elements' + difference. * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors - `u` and `v` is defined as \\(1 - (u . v) / (||u||_2 ||v||_2)\\). + `u` and `v` is defined as \\(1 - (u . v) / (||u||_2 ||v||_2)\\). random_seed: Python integer. Seed for PRNG used to initialize centers. use_mini_batch: A boolean specifying whether to use the mini-batch k-means algorithm. See explanation above. @@ -396,8 +395,9 @@ class KMeansClustering(estimator.Estimator): config: See `tf.estimator.Estimator`. feature_columns: An optionable iterable containing all the feature columns used by the model. All items in the set should be feature column - instances that can be passed to `tf.feature_column.input_layer`. If this - is None, all features will be used. + instances that can be passed to + `tf.compat.v1.feature_column.input_layer`. If this is None, all features + will be used. Raises: ValueError: An invalid argument was passed to `initial_clusters` or @@ -406,19 +406,19 @@ class KMeansClustering(estimator.Estimator): if isinstance(initial_clusters, str) and initial_clusters not in [ KMeansClustering.RANDOM_INIT, KMeansClustering.KMEANS_PLUS_PLUS_INIT ]: - raise ValueError( - "Unsupported initialization algorithm '%s'" % initial_clusters) + raise ValueError("Unsupported initialization algorithm '%s'" % + initial_clusters) if distance_metric not in [ KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, KMeansClustering.COSINE_DISTANCE ]: raise ValueError("Unsupported distance metric '%s'" % distance_metric) super(KMeansClustering, self).__init__( - model_fn=_ModelFn( - num_clusters, initial_clusters, distance_metric, random_seed, - use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance, - feature_columns).model_fn, + model_fn=_ModelFn(num_clusters, initial_clusters, distance_metric, + random_seed, use_mini_batch, + mini_batch_steps_per_iteration, + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns).model_fn, model_dir=model_dir, config=config) @@ -447,7 +447,7 @@ class KMeansClustering(estimator.Estimator): Args: input_fn: Input points. See `tf.estimator.Estimator.evaluate`. Only one - batch is retrieved. + batch is retrieved. 
Returns: The sum of the squared distance from each point in the first batch of diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 64df44fe436..456304ac61c 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -64,11 +64,11 @@ def sequence_input_layer( watches_embedding = embedding_column(watches, dimension=10) columns = [rating, watches] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) + features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) input_layer, sequence_length = sequence_input_layer(features, columns) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( + rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.compat.v1.nn.dynamic_rnn( rnn_cell, inputs=input_layer, sequence_length=sequence_length) ``` @@ -199,11 +199,11 @@ def sequence_categorical_column_with_identity( watches_embedding = embedding_column(watches, dimension=10) columns = [watches_embedding] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) + features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) input_layer, sequence_length = sequence_input_layer(features, columns) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( + rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.compat.v1.nn.dynamic_rnn( rnn_cell, inputs=input_layer, sequence_length=sequence_length) ``` @@ -243,11 +243,11 @@ def sequence_categorical_column_with_hash_bucket( tokens_embedding = embedding_column(tokens, dimension=10) columns = [tokens_embedding] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) + features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) input_layer, sequence_length = sequence_input_layer(features, columns) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( + rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.compat.v1.nn.dynamic_rnn( rnn_cell, inputs=input_layer, sequence_length=sequence_length) ``` @@ -286,11 +286,11 @@ def sequence_categorical_column_with_vocabulary_file( states_embedding = embedding_column(states, dimension=10) columns = [states_embedding] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) + features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) input_layer, sequence_length = sequence_input_layer(features, columns) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( + rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.compat.v1.nn.dynamic_rnn( rnn_cell, inputs=input_layer, sequence_length=sequence_length) ``` @@ -347,11 +347,11 @@ def sequence_categorical_column_with_vocabulary_list( colors_embedding = embedding_column(colors, dimension=3) columns = [colors_embedding] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) + features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) input_layer, sequence_length = sequence_input_layer(features, columns) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, 
state = tf.nn.dynamic_rnn( + rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.compat.v1.nn.dynamic_rnn( rnn_cell, inputs=input_layer, sequence_length=sequence_length) ``` @@ -403,11 +403,11 @@ def sequence_numeric_column( temperature = sequence_numeric_column('temperature') columns = [temperature] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) + features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) input_layer, sequence_length = sequence_input_layer(features, columns) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( + rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.compat.v1.nn.dynamic_rnn( rnn_cell, inputs=input_layer, sequence_length=sequence_length) ``` diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 45a67acb5b2..42e212b0c11 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -30,7 +30,10 @@ _ffmpeg_so = loader.load_op_library( resource_loader.get_path_to_datafile('ffmpeg.so')) -@deprecated('2018-09-04', 'This will be deleted and should not be used.') +@deprecated('2018-09-04', + 'tf.contrib.ffmpeg will be removed in 2.0, the support for video ' + 'and audio will continue to be provided in tensorflow-io: ' + 'https://github.com/tensorflow/io') def decode_audio(contents, file_format=None, samples_per_second=None, channel_count=None, stream=None): """Create an op that decodes the contents of an audio file. @@ -71,7 +74,10 @@ def decode_audio(contents, file_format=None, samples_per_second=None, ops.NotDifferentiable('DecodeAudio') -@deprecated('2018-09-04', 'This will be deleted and should not be used.') +@deprecated('2018-09-04', + 'tf.contrib.ffmpeg will be removed in 2.0, the support for video ' + 'and audio will continue to be provided in tensorflow-io: ' + 'https://github.com/tensorflow/io') def encode_audio(audio, file_format=None, samples_per_second=None): """Creates an op that encodes an audio file using sampled audio from a tensor. @@ -98,13 +104,15 @@ def encode_audio(audio, file_format=None, samples_per_second=None): ops.NotDifferentiable('EncodeAudio') -@deprecated('2018-09-04', 'This will be deleted and should not be used.') +@deprecated('2018-09-04', + 'tf.contrib.ffmpeg will be removed in 2.0, the support for video ' + 'and audio will continue to be provided in tensorflow-io: ' + 'https://github.com/tensorflow/io') def decode_video(contents): """Create an op that decodes the contents of a video file. Args: - contents: The binary contents of the video file to decode. This is a - scalar. + contents: The binary contents of the video file to decode. This is a scalar. Returns: A rank-4 `Tensor` that has `[frames, height, width, 3]` RGB as output. 
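Since these ops are easy to misuse after deprecation, here is a short graph-mode sketch of the documented call pattern (file paths, format, and sample rate are placeholders; new code should prefer the tensorflow-io package referenced in the deprecation message):

```python
import tensorflow as tf
from tensorflow.contrib import ffmpeg

video_bytes = tf.io.read_file('/path/to/clip.mp4')   # placeholder path
frames = ffmpeg.decode_video(video_bytes)            # rank-4: [frames, height, width, 3]

audio_bytes = tf.io.read_file('/path/to/clip.wav')   # placeholder path
waveform = ffmpeg.decode_audio(
    audio_bytes, file_format='wav',
    samples_per_second=16000, channel_count=1)       # rank-2: [samples, channels]

with tf.compat.v1.Session() as sess:
  print(sess.run([tf.shape(frames), tf.shape(waveform)]))
```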
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 8fd2b5f39bc..91e2954079e 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -145,6 +145,7 @@ py_test( name = "arg_scope_test", size = "small", srcs = ["python/ops/arg_scope_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":framework_py", @@ -156,6 +157,7 @@ py_test( name = "checkpoint_utils_test", size = "small", srcs = ["python/framework/checkpoint_utils_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["manual"], # http://b/30468735 deps = [ @@ -175,6 +177,7 @@ py_test( name = "ops_test", size = "small", srcs = ["python/ops/ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":framework_py", @@ -187,6 +190,7 @@ py_test( name = "prettyprint_ops_test", size = "small", srcs = ["python/ops/prettyprint_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":framework_py", @@ -203,6 +207,7 @@ py_test( py_test( name = "experimental_test", srcs = ["python/framework/experimental_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":framework_py", @@ -214,6 +219,7 @@ py_test( py_test( name = "graph_util_test", srcs = ["python/framework/graph_util_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":framework_py", @@ -225,6 +231,7 @@ py_test( py_test( name = "tensor_util_test", srcs = ["python/framework/tensor_util_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":framework_py", @@ -242,6 +249,7 @@ py_test( name = "variables_test", size = "small", srcs = ["python/ops/variables_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["manual"], deps = [ @@ -280,6 +288,7 @@ py_test( size = "medium", srcs = ["python/ops/checkpoint_ops_test.py"], data = [":checkpoint_ops_testdata"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py index e7184a01fbf..6dd887edf59 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py @@ -177,11 +177,11 @@ def init_from_checkpoint(checkpoint_dir, assignment_map): ```python # Create variables. - with tf.variable_scope('test'): - m = tf.get_variable('my_var') - with tf.variable_scope('test2'): - var2 = tf.get_variable('my_var') - var3 = tf.get_variable(name="my1", shape=[100, 100], + with tf.compat.v1.variable_scope('test'): + m = tf.compat.v1.get_variable('my_var') + with tf.compat.v1.variable_scope('test2'): + var2 = tf.compat.v1.get_variable('my_var') + var3 = tf.compat.v1.get_variable(name="my1", shape=[100, 100], partitioner=lambda shape, dtype: [5, 1]) ... # Specify which variables to initialize from checkpoint. 
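The hunk above cuts the `init_from_checkpoint` example off right after the variables are created. As a hedged reminder of how that example typically continues (the checkpoint directory and checkpoint-side names here are invented), the assignment map pairs names stored in the checkpoint with variables or scopes in the current graph:

```python
from tensorflow.contrib.framework import init_from_checkpoint

# Keys are names/scopes as stored in the checkpoint; values are the current
# variables or scope names to initialize (names invented for illustration).
init_from_checkpoint('/tmp/pretrained_ckpt', {
    'some_scope/': 'test/',          # remap an entire scope
    'some_scope_2/my_var': var2,     # or target a specific Variable object
})
```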
diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 2703224b1bf..6dd82a47572 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -141,10 +141,10 @@ def get_placeholders(graph): For example: ```python - a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a') - a = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b') + a = tf.compat.v1.placeholder(dtype=tf.float32, shape=[2, 2], name='a') + a = tf.compat.v1.placeholder(dtype=tf.int32, shape=[3, 2], name='b') - tf.contrib.framework.get_placeholders(tf.get_default_graph()) + tf.contrib.framework.get_placeholders(tf.compat.v1.get_default_graph()) # Returns: # [, # ] diff --git a/tensorflow/contrib/framework/python/ops/script_ops.py b/tensorflow/contrib/framework/python/ops/script_ops.py index d5cb679e2c0..21b00fcdaa8 100644 --- a/tensorflow/contrib/framework/python/ops/script_ops.py +++ b/tensorflow/contrib/framework/python/ops/script_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Script Language Operators. @@py_func @@ -39,7 +38,8 @@ def py_func(func, name=None): """Wraps a python function and uses it as a TensorFlow op. - This function is a wrapper around `tf.py_func` and improve it with kwargs + This function is a wrapper around `tf.compat.v1.py_func` and improve it with + kwargs and output_shapes. Further it changed some argument names. Given a python function `func`, which takes numpy arrays as its @@ -52,27 +52,31 @@ def py_func(func, def my_func(x): # x will be a numpy array with the contents of the placeholder below return np.sinh(x) - inp = tf.placeholder(tf.float32) - y = tf.py_func(my_func, [inp], tf.float32) + inp = tf.compat.v1.placeholder(tf.float32) + y = tf.compat.v1.py_func(my_func, [inp], tf.float32) ``` - **N.B.** The `tf.py_func()` operation has the following known limitations: + **N.B.** The `tf.compat.v1.py_func()` operation has the following known + limitations: * The body of the function (i.e. `func`) will not be serialized in a `GraphDef`. Therefore, you should not use this function if you need to serialize your model and restore it in a different environment. * The operation must run in the same address space as the Python program - that calls `tf.py_func()`. If you are using distributed TensorFlow, you - must run a `tf.train.Server` in the same process as the program that calls - `tf.py_func()` and you must pin the created operation to a device in that + that calls `tf.compat.v1.py_func()`. If you are using distributed + TensorFlow, you + must run a `tf.distribute.Server` in the same process as the program that + calls + `tf.compat.v1.py_func()` and you must pin the created operation to a device + in that server (e.g. using `with tf.device():`). Args: func: A Python function, which accepts a list of NumPy `ndarray` objects - having element types that match the corresponding `tf.Tensor` objects - in `inp`, and returns a list of `ndarray` objects (or a single `ndarray`) + having element types that match the corresponding `tf.Tensor` objects in + `inp`, and returns a list of `ndarray` objects (or a single `ndarray`) having element types that match the corresponding values in `Tout`. args: A list of `Tensor` objects. kwargs: A dict with `Tensor` objects as values. 
@@ -80,11 +84,10 @@ def py_func(func, tensorflow data type if there is only one, indicating what `func` returns. output_shapes: Same as output_types, except the types are replaces with shapes (optional). - stateful: (Boolean.) If True, the function should be considered stateful. - If a function is stateless, when given the same input it will return the - same output and have no observable side effects. Optimizations such as - common subexpression elimination are only performed on stateless - operations. + stateful: (Boolean.) If True, the function should be considered stateful. If + a function is stateless, when given the same input it will return the same + output and have no observable side effects. Optimizations such as common + subexpression elimination are only performed on stateless operations. name: A name for the operation (optional). Returns: @@ -133,8 +136,9 @@ def py_func(func, if output_shapes is not None: # I am not sure if this is nessesary - output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) + output_shapes = nest.map_structure_up_to(output_types, + tensor_shape.as_shape, + output_shapes) flattened_shapes = nest.flatten(output_shapes) for ret_t, shape in zip(flat_values, flattened_shapes): diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index a7acae804a0..5b957817227 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - -"""Variable functions. -""" +"""Variable functions.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -72,18 +70,20 @@ __all__ = ['add_model_variable', def zero_initializer(ref, use_locking=True, name="zero_initializer"): """Initialize 'ref' with all zeros, ref tensor should be uninitialized. + If already initialized, you will get ValueError. This op is intended to save memory during initialization. Args: ref: ref of the tensor need to be zero initialized. name: optional name for this operation. + Returns: ref that initialized. Raises: ValueError: If ref tensor is initialized. """ loader.load_op_library( - resource_loader.get_path_to_datafile("_variable_ops.so")) + resource_loader.get_path_to_datafile('_variable_ops.so')) if resource_variable_ops.is_resource_variable(ref): return gen_variable_ops.zero_var_initializer( ref.handle, shape=ref.shape, dtype=ref.dtype, name=name) @@ -91,7 +91,7 @@ def zero_initializer(ref, use_locking=True, name="zero_initializer"): return gen_variable_ops.zero_initializer(ref, name=name) -@deprecated(None, "Please switch to tf.train.assert_global_step") +@deprecated(None, 'Please switch to tf.train.assert_global_step') def assert_global_step(global_step_tensor): training_util.assert_global_step(global_step_tensor) @@ -105,8 +105,8 @@ def assert_or_get_global_step(graph=None, global_step_tensor=None): Args: graph: The graph to find the global step tensor for. - global_step_tensor: The tensor to check for suitability as a global step. - If None is given (the default), find a global step tensor. + global_step_tensor: The tensor to check for suitability as a global step. If + None is given (the default), find a global step tensor. 
Returns: A tensor suitable as a global step, or `None` if none was provided and none @@ -119,19 +119,21 @@ def assert_or_get_global_step(graph=None, global_step_tensor=None): assert_global_step(global_step_tensor) return global_step_tensor -@deprecated(None, "Please switch to tf.train.get_global_step") + +@deprecated(None, 'Please switch to tf.train.get_global_step') def get_global_step(graph=None): return training_util.get_global_step(graph) -@deprecated(None, "Please switch to tf.train.create_global_step") + +@deprecated(None, 'Please switch to tf.train.create_global_step') def create_global_step(graph=None): """Create global step tensor in graph. This API is deprecated. Use core framework training version instead. Args: - graph: The graph in which to create the global step tensor. If missing, - use default graph. + graph: The graph in which to create the global step tensor. If missing, use + default graph. Returns: Global step tensor. @@ -141,7 +143,8 @@ def create_global_step(graph=None): """ return training_util.create_global_step(graph) -@deprecated(None, "Please switch to tf.train.get_or_create_global_step") + +@deprecated(None, 'Please switch to tf.train.get_or_create_global_step') def get_or_create_global_step(graph=None): """Returns and create (if necessary) the global step tensor. @@ -166,11 +169,13 @@ def local_variable(initial_value, validate_shape: See variables.Variable.__init__. name: See variables.Variable.__init__. use_resource: If `True` use a ResourceVariable instead of a Variable. + Returns: New variable. """ return variable_scope.variable( - initial_value, trainable=False, + initial_value, + trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], validate_shape=validate_shape, use_resource=use_resource, @@ -188,11 +193,13 @@ def global_variable(initial_value, validate_shape: See variables.Variable.__init__. name: See variables.Variable.__init__. use_resource: If `True` use a ResourceVariable instead of a Variable. + Returns: New variable. """ return variable_scope.variable( - initial_value, trainable=False, + initial_value, + trainable=False, collections=[ops.GraphKeys.GLOBAL_VARIABLES], validate_shape=validate_shape, use_resource=use_resource, @@ -221,30 +228,29 @@ def variable(name, shape: shape of the new or existing variable. dtype: type of the new or existing variable (defaults to `DT_FLOAT`). initializer: initializer for the variable if one is created. - regularizer: a (Tensor -> Tensor or None) function; the result of - applying it on a newly created variable will be added to the collection - GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. + regularizer: a (Tensor -> Tensor or None) function; the result of applying + it on a newly created variable will be added to the collection + GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. trainable: If `True` also add the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). collections: A list of collection names to which the Variable will be added. If None it would default to `tf.GraphKeys.GLOBAL_VARIABLES`. caching_device: Optional device string or function describing where the - Variable should be cached for reading. Defaults to the Variable's - device. + Variable should be cached for reading. Defaults to the Variable's device. device: Optional device to place the variable. It can be an string or a function that is called to get the device for the variable. 
partitioner: Optional callable that accepts a fully defined `TensorShape` and dtype of the `Variable` to be created, and returns a list of partitions for each axis (currently only one axis can be partitioned). - custom_getter: Callable that allows overwriting the internal - get_variable method and has to have the same signature. + custom_getter: Callable that allows overwriting the internal get_variable + method and has to have the same signature. use_resource: If `True` use a ResourceVariable instead of a Variable. - synchronization: Indicates when a distributed a variable will be - aggregated. Accepted values are constants defined in the class + synchronization: Indicates when a distributed a variable will be aggregated. + Accepted values are constants defined in the class `tf.VariableSynchronization`. By default the synchronization is set to - `AUTO` and the current `DistributionStrategy` chooses - when to synchronize. If `synchronization` is set to `ON_READ`, - `trainable` must not be set to `True`. + `AUTO` and the current `DistributionStrategy` chooses when to synchronize. + If `synchronization` is set to `ON_READ`, `trainable` must not be set to + `True`. aggregation: Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class `tf.VariableAggregation`. @@ -252,15 +258,15 @@ def variable(name, Returns: The created or existing variable. """ - collections = list(collections if collections is not None - else [ops.GraphKeys.GLOBAL_VARIABLES]) + collections = list(collections if collections is not None else + [ops.GraphKeys.GLOBAL_VARIABLES]) # Remove duplicates collections = list(set(collections)) getter = variable_scope.get_variable if custom_getter is not None: - getter = functools.partial(custom_getter, - reuse=variable_scope.get_variable_scope().reuse) + getter = functools.partial( + custom_getter, reuse=variable_scope.get_variable_scope().reuse) with ops.device(device or ''): return getter( name, @@ -299,31 +305,30 @@ def model_variable(name, shape: shape of the new or existing variable. dtype: type of the new or existing variable (defaults to `DT_FLOAT`). initializer: initializer for the variable if one is created. - regularizer: a (Tensor -> Tensor or None) function; the result of - applying it on a newly created variable will be added to the collection - GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. + regularizer: a (Tensor -> Tensor or None) function; the result of applying + it on a newly created variable will be added to the collection + GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. trainable: If `True` also add the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). collections: A list of collection names to which the Variable will be added. Note that the variable is always also added to the `GraphKeys.GLOBAL_VARIABLES` and `GraphKeys.MODEL_VARIABLES` collections. caching_device: Optional device string or function describing where the - Variable should be cached for reading. Defaults to the Variable's - device. + Variable should be cached for reading. Defaults to the Variable's device. device: Optional device to place the variable. It can be an string or a function that is called to get the device for the variable. partitioner: Optional callable that accepts a fully defined `TensorShape` and dtype of the `Variable` to be created, and returns a list of partitions for each axis (currently only one axis can be partitioned). 
- custom_getter: Callable that allows overwriting the internal - get_variable method and has to have the same signature. + custom_getter: Callable that allows overwriting the internal get_variable + method and has to have the same signature. use_resource: If `True` use a ResourceVariable instead of a Variable. - synchronization: Indicates when a distributed a variable will be - aggregated. Accepted values are constants defined in the class + synchronization: Indicates when a distributed a variable will be aggregated. + Accepted values are constants defined in the class `tf.VariableSynchronization`. By default the synchronization is set to - `AUTO` and the current `DistributionStrategy` chooses - when to synchronize. If `synchronization` is set to `ON_READ`, - `trainable` must not be set to `True`. + `AUTO` and the current `DistributionStrategy` chooses when to synchronize. + If `synchronization` is set to `ON_READ`, `trainable` must not be set to + `True`. aggregation: Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class `tf.VariableAggregation`. @@ -361,7 +366,8 @@ def add_model_variable(var): ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var) -def get_variables(scope=None, suffix=None, +def get_variables(scope=None, + suffix=None, collection=ops.GraphKeys.GLOBAL_VARIABLES): """Gets the list of variables, filtered by scope and/or suffix. @@ -535,7 +541,7 @@ def assign_from_values(var_names_to_values): if not var: raise ValueError('Variable %s wasn\'t found' % var_name) elif len(var) > 1: - # tf.get_collection is just a filter on the prefix: find the exact match: + # tf.compat.v1.get_collection is just a filter on the prefix: find the exact match: found = False for v in var: if v.op.name == var_name: @@ -574,15 +580,18 @@ def assign_from_values_fn(var_names_to_values): var_names_to_values: A map from variable names to values. Returns: - A function that takes a single argument, a `tf.Session`, that applies the + A function that takes a single argument, a `tf.compat.v1.Session`, that + applies the assignment operation. Raises: ValueError: if any of the given variable names were not found. """ assign_op, feed_dict = assign_from_values(var_names_to_values) + def callback(session): return session.run(assign_op, feed_dict) + return callback @@ -619,16 +628,15 @@ def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False): Args: model_path: The full path to the model checkpoint. To get latest checkpoint - use `model_path = tf.train.latest_checkpoint(checkpoint_dir)` - var_list: A list of (possibly partitioned) `Variable` objects - or a dictionary mapping names in the checkpoint to the - corresponding variables or list of variables to initialize - from that checkpoint value. For partitioned Variables, the - name in the checkpoint must be the full variable, not the - name of the partitioned variable, eg. "my_var" rather than - "my_var/part_4". If empty, returns no_op(), {}. + use `model_path = tf.train.latest_checkpoint(checkpoint_dir)` + var_list: A list of (possibly partitioned) `Variable` objects or a + dictionary mapping names in the checkpoint to the corresponding variables + or list of variables to initialize from that checkpoint value. For + partitioned Variables, the name in the checkpoint must be the full + variable, not the name of the partitioned variable, eg. "my_var" rather + than "my_var/part_4". If empty, returns no_op(), {}. 
ignore_missing_vars: Boolean, if True ignore variables missing in the - checkpoint with a warning instead of failing. + checkpoint with a warning instead of failing. Returns: the restore_op and the feed_dict that need to be run to restore var_list. @@ -682,8 +690,8 @@ def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False): if var.get_shape() != ckpt_value.shape: raise ValueError( 'Total size of new array must be unchanged for %s ' - 'lh_shape: [%s], rh_shape: [%s]' - % (ckpt_name, str(ckpt_value.shape), str(var.get_shape()))) + 'lh_shape: [%s], rh_shape: [%s]' % + (ckpt_name, str(ckpt_value.shape), str(var.get_shape()))) feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape) else: @@ -697,10 +705,14 @@ def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False): assign_op = control_flow_ops.group(*assign_ops) return assign_op, feed_dict + + # pylint: enable=protected-access -def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, +def assign_from_checkpoint_fn(model_path, + var_list, + ignore_missing_vars=False, reshape_variables=False): """Returns a function that assigns specific variables from a checkpoint. @@ -709,18 +721,19 @@ def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, Args: model_path: The full path to the model checkpoint. To get latest checkpoint - use `model_path = tf.train.latest_checkpoint(checkpoint_dir)` + use `model_path = tf.train.latest_checkpoint(checkpoint_dir)` var_list: A list of `Variable` objects or a dictionary mapping names in the - checkpoint to the corresponding variables to initialize. If empty or - `None`, it would return `no_op(), None`. + checkpoint to the corresponding variables to initialize. If empty or + `None`, it would return `no_op(), None`. ignore_missing_vars: Boolean, if True it would ignore variables missing in - the checkpoint with a warning instead of failing. + the checkpoint with a warning instead of failing. reshape_variables: Boolean, if True it would automatically reshape variables - which are of different shape then the ones stored in the checkpoint but - which have the same number of elements. + which are of different shape then the ones stored in the checkpoint but + which have the same number of elements. Returns: - A function that takes a single argument, a `tf.Session`, that applies the + A function that takes a single argument, a `tf.compat.v1.Session`, that + applies the assignment operation. If no matching variables were found in the checkpoint then `None` is returned. @@ -740,14 +753,17 @@ def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, if reader.has_tensor(var): available_vars[var] = var_dict[var] else: - logging.warning( - 'Variable %s missing in checkpoint %s', var, model_path) + logging.warning('Variable %s missing in checkpoint %s', var, model_path) var_list = available_vars if var_list: - saver = tf_saver.Saver(var_list, reshape=reshape_variables, - write_version=saver_pb2.SaverDef.V1) + saver = tf_saver.Saver( + var_list, + reshape=reshape_variables, + write_version=saver_pb2.SaverDef.V1) + def callback(session): saver.restore(session, model_path) + return callback else: logging.warning('No Variables to restore') @@ -781,8 +797,8 @@ class VariableDeviceChooser(object): num_tasks: number of tasks. job_name: String, a name for the parameter server job. device_type: Optional device type string (e.g. "CPU" or "GPU") - device_index: int. Optional device index. 
If left - unspecified, device represents 'any' device_index. + device_index: int. Optional device index. If left unspecified, device + represents 'any' device_index. """ self._job_name = job_name self._device_type = device_type @@ -804,7 +820,9 @@ class VariableDeviceChooser(object): return device_spec.to_string() -def filter_variables(var_list, include_patterns=None, exclude_patterns=None, +def filter_variables(var_list, + include_patterns=None, + exclude_patterns=None, reg_search=True): """Filter a list of variables using regular expressions. @@ -825,15 +843,15 @@ def filter_variables(var_list, include_patterns=None, exclude_patterns=None, Args: var_list: list of variables. include_patterns: list of regular expressions to include. Defaults to None, - which means all variables are selected according to the include rules. - A variable is included if it matches any of the include_patterns. + which means all variables are selected according to the include rules. A + variable is included if it matches any of the include_patterns. exclude_patterns: list of regular expressions to exclude. Defaults to None, - which means all variables are selected according to the exclude rules. - A variable is excluded if it matches any of the exclude_patterns. + which means all variables are selected according to the exclude rules. A + variable is excluded if it matches any of the exclude_patterns. reg_search: boolean. If True (default), performs re.search to find matches - (i.e. pattern can match any substring of the variable name). If False, - performs re.match (i.e. regexp should match from the beginning of the - variable name). + (i.e. pattern can match any substring of the variable name). If False, + performs re.match (i.e. regexp should match from the beginning of the + variable name). Returns: filtered list of variables. 
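To make the relationship between these helpers concrete, here is a minimal sketch that filters a scope's variables by regular expression and then restores only the kept subset through the session callback returned by `assign_from_checkpoint_fn` (the callback documented above as taking a single `tf.compat.v1.Session`). It assumes the usual `tf.contrib.framework` re-exports of `get_variables`, `filter_variables`, and `assign_from_checkpoint_fn`; the scope name and checkpoint location are illustrative only, not part of this change.

# Minimal TF 1.x / contrib-era sketch; scope and checkpoint names are hypothetical.
import os
import tempfile

import tensorflow as tf
from tensorflow.contrib import framework as contrib_framework

with tf.compat.v1.variable_scope('my_model'):
  with tf.compat.v1.variable_scope('conv'):
    tf.compat.v1.get_variable('weights', shape=[3, 3, 1, 8])
    tf.compat.v1.get_variable('biases', shape=[8])

all_vars = contrib_framework.get_variables(scope='my_model')
# Keep the conv weights, drop the biases. reg_search=True (the default) lets a
# pattern match any substring of the variable name.
weights_only = contrib_framework.filter_variables(
    all_vars, include_patterns=['conv'], exclude_patterns=['biases'])

# Write a checkpoint so the restore below has something to read.
ckpt = os.path.join(tempfile.mkdtemp(), 'model.ckpt')
with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  tf.compat.v1.train.Saver(all_vars).save(sess, ckpt)

# assign_from_checkpoint_fn returns a callback that applies the assignment when
# called with a session; it returns None if no matching variables were found.
init_fn = contrib_framework.assign_from_checkpoint_fn(
    ckpt, weights_only, ignore_missing_vars=True)
with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  if init_fn is not None:
    init_fn(sess)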
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index f65f450eba4..2dfbd646a65 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -19,10 +19,10 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "tf_kernel_library", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", + "tf_kernel_library", ) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -68,6 +68,8 @@ tf_kernel_library( prefix = "fused_conv2d_bias_activation_op", visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core:conv_autotuning_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", @@ -75,6 +77,7 @@ tf_kernel_library( "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:conv_2d_hdrs", "//tensorflow/core/kernels:conv_ops_gpu_hdrs", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "//third_party/eigen3", @@ -92,10 +95,13 @@ tf_custom_op_library( "ops/fused_conv2d_bias_activation_op.cc", ], deps = [ + "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core:conv_autotuning_proto_cc", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/kernels:bounds_check_lib", "//tensorflow/core/kernels:conv_2d_hdrs", "//tensorflow/core/kernels:conv_ops_gpu_hdrs", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "@local_config_cuda//cuda:cudnn_header", diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index f13a66717f6..9dda04f3929 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -17,6 +17,8 @@ limitations under the License. #define EIGEN_USE_GPU #endif // GOOGLE_CUDA +#define EIGEN_USE_THREADS + #include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h" #include "tensorflow/core/framework/bounds_check.h" @@ -33,16 +35,27 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/use_cudnn.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) + #if GOOGLE_CUDA -#include "cuda/include/cudnn.h" +#include "google/protobuf/duration.pb.h" +#include "absl/time/time.h" +#include "third_party/gpus/cudnn/cudnn.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/protobuf/autotuning.pb.h" +#include "tensorflow/core/protobuf/conv_autotuning.pb.h" #include "tensorflow/core/util/activation_mode.h" +#include "tensorflow/stream_executor/dnn.h" #endif // GOOGLE_CUDA namespace tensorflow { namespace { +typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; template @@ -70,6 +83,174 @@ struct Int8x4ToInt32 { }; } // namespace +// WARNING: Packing specializations defined in eigen_spatial_convolutions.h do +// not support packing expressions of QInt8 type. 
However, default Eigen +// gebp_kernel for QInt8 is too slow to be considered useful for anything. +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) + +template +class LaunchFusedConv2DBiasActivationOp { + using T = qint8; // conv_input and filter type + using TempT = qint32; // temporary accumulator type for tensor contraction + + public: + void launch(OpKernelContext* ctx, bool cudnn_use_autotune, + const Tensor& conv_input, ScaleType conv_input_scale, + const Tensor& filter, int32 row_stride, int32 col_stride, + const Eigen::PaddingType& padding, const Tensor& side_input, + ScaleType side_input_scale, const Tensor& bias, + ActivationMode activation_mode, TensorFormat data_format, + FilterTensorFormat filter_format, Tensor* output) { + static_assert(std::is_same::value, + "Scale and Bias must be of the same type."); + + // Output tensor has type T (QInt8), but we can only evaluate Int8 Tensor + // contraction using 32-bit accumulation (QInt32). + Tensor temp_output(DT_QINT32, output->shape()); + + constexpr int32 row_dilation = 1; + constexpr int32 col_dilation = 1; + + auto& device = ctx->eigen_device(); + + // CPU convolution works with input in NHWC and filter in HWIO data formats. + // NOTE: This code is mostly shared with 'Conv2D' and 'FusedConv2D'. + + BiasActivationOutputKernel output_kernel(conv_input_scale, side_input, + side_input_scale, bias, + activation_mode, output); + + if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 && + col_stride == 1) { + int conv_width = // Width for the convolution step. + output->dim_size(0) * output->dim_size(1) * output->dim_size(2); + + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + + auto out = temp_output.shaped({conv_width, filter.dim_size(3)}); + auto in0 = conv_input.shaped({conv_width, filter.dim_size(2)}); + auto in1 = filter.shaped({filter.dim_size(2), filter.dim_size(3)}); + + out.device(device) = in0.contract(in1, dim_pair, output_kernel); + + } else if (filter.dim_size(0) == conv_input.dim_size(1) && + filter.dim_size(1) == conv_input.dim_size(2) && + row_dilation == 1 && col_dilation == 1 && + padding == Eigen::PaddingType::PADDING_VALID) { + // If the input data and filter have the same height/width, + // reduce the 2D convolution to matrix multiplication. + const auto k = // Length of reduction dimension. + filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2); + + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + + auto out = temp_output.shaped( + {conv_input.dim_size(0), filter.dim_size(3)}); + auto in0 = conv_input.shaped({conv_input.dim_size(0), k}); + auto in1 = filter.shaped({k, filter.dim_size(3)}); + + out.device(device) = in0.contract(in1, dim_pair, output_kernel); + + } else { + auto out = temp_output.tensor(); + auto in0 = conv_input.tensor(); + auto in1 = filter.tensor(); + + // Need to swap row/col when calling Eigen. + out.device(device) = + Eigen::SpatialConvolution(in0, in1, col_stride, row_stride, padding, + col_dilation, row_dilation, output_kernel); + } + } + + private: + // Contraction output mapper for temporary QInt32 tensor. 
+ using ContractionOutputMapper = + Eigen::internal::blas_data_mapper; + + // This output kernel computes an expressions corresponding to cuDNN + // implementation of INT8 cudnnConvolutionBiasActivationForward: + // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#scaling-parameters__fig-conv-bias-activation-forward + struct BiasActivationOutputKernel { + static constexpr int8 kMaxRange = 127; + static constexpr int8 kMinRange = -128; + + explicit BiasActivationOutputKernel(ScaleType conv_input_scale, + const Tensor& side_input, + ScaleType side_input_scale, + const Tensor& bias, + ActivationMode activation_mode, + Tensor* output) + : activation_mode(activation_mode), + conv_input_scale(conv_input_scale), + bias_data(bias.flat().data()), + side_input_data(side_input.flat().data()), + side_input_scale(side_input_scale), + output_data(const_cast(output->flat().data())) {} + + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& conv_output_mapper, + const Eigen::TensorContractionParams& params, Eigen::Index i, + Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const { + DCHECK(params.swapped_arguments); + + const auto stride = conv_output_mapper.stride(); + + const BiasType* bias_base = bias_data + i; + typename TTypes::UnalignedConstTensor bias(bias_base, num_rows); + + const T* side_input_base = side_input_data + i + j * stride; + T* output_base = output_data + i + j * stride; + + for (int col = 0; col < num_cols; ++col) { + // A column of an output tensor after QInt8xQInt8 -> QInt32 contraction. + // This is a temporary tensor, that we will scale, add bias with + // side_input, and quantize before writing to final output tensor. + typename TTypes::UnalignedConstTensor conv_output( + &conv_output_mapper(0, col), num_rows); + + // A column of side input tensor corresponding to conv output row. + typename TTypes::UnalignedConstTensor side_input( + side_input_base + col * stride, num_rows); + + // A column of output quantized tensor corresponding to conv output row. + typename TTypes::UnalignedTensor output(output_base + col * stride, + num_rows); + + auto conv_output_scaled = + conv_output.cast() * conv_input_scale; + ScaleType lower_bound = (activation_mode == ActivationMode::NONE + ? static_cast(kMinRange) + : 0); + if (side_input_scale == 0.0f) { + output = (conv_output_scaled + bias) + .round() + .clip(lower_bound, static_cast(kMaxRange)) + .template cast(); + } else { + auto side_input_scaled = + side_input.cast() * side_input_scale; + output = (conv_output_scaled + bias + side_input_scaled) + .round() + .clip(lower_bound, static_cast(kMaxRange)) + .template cast(); + } + } + } + + private: + ActivationMode activation_mode; + ScaleType conv_input_scale; + const BiasType* bias_data; + const T* side_input_data; + ScaleType side_input_scale; + T* output_data; + }; +}; +#endif // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) + // T is the element type of the conv_input, filter and side_input tensors. // BiasType is the element type of the bias tensor, which can be different. // ScaleType is the type used for conv_input_scale, side_input_scale. @@ -114,19 +295,45 @@ class FusedConv2DBiasActivationOp : public OpKernel { errors::Unimplemented("Convolutional strides are not supported in " "the batch and depth dimensions.")); - // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here. 
- constexpr bool is_int8x4 = std::is_same::value; + std::vector dilations; + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations)); + OP_REQUIRES(context, dilations == std::vector({1, 1, 1, 1}), + errors::InvalidArgument("Dilations must be all equal to 1.")); - // Note: Only NCHW_VECT_C format is supported for int8. - // This is because it is expected to be the fastest, and our previous tests - // found cudnn 6 does not fully support the other formats for int8 mode. - OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)), - errors::InvalidArgument( - "qint8 should be used with data_format NCHW_VECT_C.")); + constexpr bool is_cpu = std::is_same::value; + constexpr bool is_gpu = std::is_same::value; + OP_REQUIRES(context, is_cpu || is_gpu, + errors::InvalidArgument("Unknown Device type.")); - OP_REQUIRES(context, (is_int8x4 == (filter_format_ == FORMAT_OIHW_VECT_I)), - errors::InvalidArgument( - "qint8 should be used with filter_format OIHW_VECT_I.")); + constexpr bool is_qint8 = std::is_same::value; + + if (is_qint8 && is_gpu) { + // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here. + + // Note: Only NCHW_VECT_C format is supported for int8 on GPU. + // This is because it is expected to be the fastest, and our previous + // tests found cudnn 6 does not fully support the other formats for int8 + // mode. + OP_REQUIRES( + context, data_format_ == FORMAT_NCHW_VECT_C, + errors::InvalidArgument( + "qint8 should be used with data_format NCHW_VECT_C on GPU.")); + OP_REQUIRES( + context, filter_format_ == FORMAT_OIHW_VECT_I, + errors::InvalidArgument( + "qint8 should be used with filter_format OIHW_VECT_I on GPU.")); + + } else if (is_qint8 && is_cpu) { + // On CPU we implement convolution with Eigen Tensor contraction, it + // requries NHWC and HWIO formats for input and kernel. + + OP_REQUIRES(context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "qint8 should be used with data_format NHWC on CPU.")); + OP_REQUIRES(context, filter_format_ == FORMAT_HWIO, + errors::InvalidArgument( + "qint8 should be used with filter_format HWIO on CPU.")); + } OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_)); eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_); @@ -249,9 +456,132 @@ class FusedConv2DBiasActivationOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp); }; +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +REGISTER_KERNEL_BUILDER( + Name("FusedConv2DBiasActivation") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .TypeConstraint("Tbias"), + FusedConv2DBiasActivationOp); +#endif // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) + #if GOOGLE_CUDA namespace dnn = se::dnn; +// Several functions are copyed over from tensorflow/core/kernels/gpu_utils, +// since this file may be compiled down to a tf_custom_op_library .so file, +// which can't depend on basic dependencies like tensorflow/core:lib. Instead, +// the code has to depend on whatever is the same in libtensorflow_framework.so. +// +// In theory, we can lift the dependencies of gpu_utils by turning it into a +// template library that provides duck typing, but I think duplication is the +// lesser of two evils. 
+namespace internal { +namespace { + +tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { + tensorflow::CudnnVersion cudnn_version; + if (auto* dnn = stream_executor->AsDnn()) { + se::port::StatusOr version_or = dnn->GetVersion(); + if (version_or.ok()) { + const auto& version = version_or.ValueOrDie(); + cudnn_version.set_major(version.major_version()); + cudnn_version.set_minor(version.minor_version()); + cudnn_version.set_patch(version.patch()); + } + } + return cudnn_version; +} + +// Converts an absl::Duration to a google::protobuf::Duration. +inline google::protobuf::Duration ToDurationProto(absl::Duration duration) { + google::protobuf::Duration proto; + proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); + proto.set_nanos( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + return proto; +} + +// Converts a google::protobuf::Duration to an absl::Duration. +inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + +tensorflow::ComputeCapability GetComputeCapability( + se::StreamExecutor* stream_executor) { + tensorflow::ComputeCapability cc; + int cc_major, cc_minor; + stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + cc.set_major(cc_major); + cc.set_minor(cc_minor); + return cc; +} + +void LogFusedConvForwardAutotuneResults( + se::dnn::DataType element_type, const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, double conv_scale, + double side_value_scale, se::dnn::ActivationMode activation_mode, + se::StreamExecutor* stream_exec, absl::Span results) { + AutotuningLog log; + { + ConvolutionProto instr; + instr.set_kind(se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION); + *instr.mutable_input() = input_desc.ToProto(element_type); + *instr.mutable_filter() = filter_desc.ToProto(element_type); + *instr.mutable_output() = output_desc.ToProto(element_type); + *instr.mutable_conv_desc() = conv_desc.ToProto(); + instr.set_conv_scale(conv_scale); + instr.set_side_value_scale(side_value_scale); + instr.set_activation(activation_mode); + log.mutable_instr()->PackFrom(std::move(instr)); + } + *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec); + *log.mutable_compute_capability() = GetComputeCapability(stream_exec); + log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id()); + for (const auto& result : results) { + *log.add_results() = result; + } + Logger::Singleton()->LogProto(log); +} + +Status BestCudnnConvAlgorithm(absl::Span results, + se::dnn::AlgorithmConfig* algo) { + const AutotuneResult* best_result = std::min_element( + results.begin(), results.end(), + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return internal::FromDurationProto(lhs.run_time()) < + internal::FromDurationProto(rhs.run_time()); + }); + + const AutotuneResult* best_result_no_scratch = std::min_element( + results.begin(), results.end(), + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return std::make_tuple(lhs.scratch_bytes(), + internal::FromDurationProto(lhs.run_time())) < + std::make_tuple(rhs.scratch_bytes(), + internal::FromDurationProto(rhs.run_time())); + }); + + if (best_result == results.end()) { + return errors::NotFound("No algorithm worked!"); + } + algo->set_algorithm({best_result->conv().algorithm(), + 
best_result->conv().tensor_ops_enabled()}); + if (best_result_no_scratch != results.end() && + best_result_no_scratch->scratch_bytes() == 0) { + algo->set_algorithm_no_scratch( + {best_result_no_scratch->conv().algorithm(), + best_result_no_scratch->conv().tensor_ops_enabled()}); + } + return Status::OK(); +} + +} // namespace +} // namespace internal + // A dummy type to group forward convolution autotune results together. struct ConvBiasActivationAutoTuneGroup { static string name() { return "ConvBiasActivation"; } @@ -579,8 +909,7 @@ void LaunchFusedConv2DBiasActivationOp:: }), algorithms.end()); } - dnn::ProfileResult best_result; - dnn::ProfileResult best_result_no_scratch; + std::vector results; for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. @@ -595,30 +924,23 @@ void LaunchFusedConv2DBiasActivationOp:: output_desc, &output_ptr, &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), &profile_result) .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; - } - } + if (cudnn_launch_status && profile_result.is_valid()) { + results.emplace_back(); + auto& result = results.back(); + result.mutable_conv()->set_algorithm(profile_algorithm.algo_id()); + result.mutable_conv()->set_tensor_ops_enabled( + profile_algorithm.tensor_ops_enabled()); + result.set_scratch_bytes(scratch_allocator.TotalByteSize()); + *result.mutable_run_time() = internal::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); } } - OP_REQUIRES(ctx, - best_result.is_valid() || best_result_no_scratch.is_valid(), - errors::NotFound("No algorithm worked!")); - if (best_result.is_valid()) { - algorithm_config.set_algorithm(best_result.algorithm()); - } - if (best_result_no_scratch.is_valid()) { - algorithm_config.set_algorithm_no_scratch( - best_result_no_scratch.algorithm()); - } + internal::LogFusedConvForwardAutotuneResults( + se::dnn::ToDataType::type>::value, conv_input_desc, + filter_desc, output_desc, conv_desc, conv_input_scale, side_input_scale, + dnn_activation_mode, stream->parent(), results); + OP_REQUIRES_OK( + ctx, internal::BestCudnnConvAlgorithm(results, &algorithm_config)); AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters, algorithm_config); } diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h index 869e899ac87..19236ff4247 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h @@ -22,8 +22,13 @@ limitations under the License. #include "tensorflow/core/util/activation_mode.h" #include "tensorflow/core/util/tensor_format.h" -#if GOOGLE_CUDA +// FixedPoint header must be included after Tensor. 
+// clang-format off #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" +// clang-format on + +#if GOOGLE_CUDA #include "tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index e5c8a34fc14..6cc5d697efe 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -23,14 +23,19 @@ from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activatio from tensorflow.python.platform import test -# Instantiate the two test suites from test_base, mixing in test.TestCase as +# Instantiate three test suites from test_base, mixing in test.TestCase as # the test framework. class FusedConv2DBiasActivationTest(test_base.FusedConv2DBiasActivationTest, test.TestCase): pass -class FusedConvInt8Tests(test_base.FusedConvInt8Tests, test.TestCase): +class FusedConvInt8CPUTests(test_base.FusedConvInt8CPUTests, test.TestCase): + pass + + +class FusedConvInt8CorrespondenceTests( + test_base.FusedConvInt8CorrespondenceTests, test.TestCase): pass diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py index 35fc65e4ba8..04edc7593a2 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Provides test suites that can be run to test fused convolutions. Each of the two test suites in this module, FusedConv2DBiasActivationTest and @@ -34,9 +33,11 @@ from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activatio from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -236,36 +237,35 @@ class FusedConv2DBiasActivationTest(object): # This is to guarantee that there are always negative values after # bias add so that we can test whether relu works correctly. 
x3 = bias - with self.cached_session(use_gpu=True), self.test_scope(): - t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) - t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) - fused_t2 = t2 - if filter_format == "OIHW": - fused_t2 = _HwioToOihw(t2) - t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype) - strides = [1] + strides + [1] - if data_format == "NCHW": - t1 = test_util.NHWCToNCHW(t1) - strides = test_util.NHWCToNCHW(strides) - output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - t1, - fused_t2, - t3, - strides=strides, - padding=padding, - data_format=data_format, - filter_format=filter_format, - activation_mode=activation_mode) - ref_conv_output = nn_ops.conv2d( - t1, t2, strides=strides, padding=padding, data_format=data_format) - ref_bias_output = nn_ops.bias_add( - ref_conv_output, t3, data_format=data_format) - ref_output = nn_ops.relu(ref_bias_output) - if data_format == "NCHW": - output = test_util.NCHWToNHWC(output) - ref_output = test_util.NCHWToNHWC(ref_output) + t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) + t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) + fused_t2 = t2 + if filter_format == "OIHW": + fused_t2 = _HwioToOihw(t2) + t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype) + strides = [1] + strides + [1] + if data_format == "NCHW": + t1 = test_util.NHWCToNCHW(t1) + strides = test_util.NHWCToNCHW(strides) + output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + t1, + fused_t2, + t3, + strides=strides, + padding=padding, + data_format=data_format, + filter_format=filter_format, + activation_mode=activation_mode) + ref_conv_output = nn_ops.conv2d( + t1, t2, strides=strides, padding=padding, data_format=data_format) + ref_bias_output = nn_ops.bias_add( + ref_conv_output, t3, data_format=data_format) + ref_output = nn_ops.relu(ref_bias_output) + if data_format == "NCHW": + output = test_util.NCHWToNHWC(output) + ref_output = test_util.NCHWToNHWC(ref_output) - return output, ref_output + return output, ref_output def CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides, padding): @@ -284,62 +284,62 @@ class FusedConv2DBiasActivationTest(object): x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) def _SetupVal(data_format, use_gpu): - with self.cached_session(use_gpu=use_gpu), self.test_scope(): - t1 = constant_op.constant(x1, shape=tensor_in_sizes) - t2 = constant_op.constant(x2, shape=filter_in_sizes) - t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) - strides = [1] + conv_strides + [1] - if data_format == "NCHW": - t1 = test_util.NHWCToNCHW(t1) - strides = test_util.NHWCToNCHW(strides) - output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - t1, - t2, - t3, - strides=strides, - padding=padding, - data_format=data_format, - activation_mode="Relu") + t1 = constant_op.constant(x1, shape=tensor_in_sizes) + t2 = constant_op.constant(x2, shape=filter_in_sizes) + t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) + strides = [1] + conv_strides + [1] + if data_format == "NCHW": + t1 = test_util.NHWCToNCHW(t1) + strides = test_util.NHWCToNCHW(strides) + output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + t1, + t2, + t3, + strides=strides, + padding=padding, + data_format=data_format, + activation_mode="Relu") - if data_format == "NCHW": - output = test_util.NCHWToNHWC(output) - return output + if data_format == "NCHW": + output = test_util.NCHWToNHWC(output) + 
return output - tensors = [] - for (data_format, use_gpu) in _GetTestConfigs(): - tensors.append(_SetupVal(data_format, use_gpu)) - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): + tensors = [] + for (data_format, use_gpu) in _GetTestConfigs(): + tensors.append(_SetupVal(data_format, use_gpu)) values = sess.run(tensors) - for i in range(1, len(values)): - self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3) + for i in range(1, len(values)): + self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3) def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides, padding): - tensors = [] - ref_tensors = [] - for (data_format, use_gpu) in _GetTestConfigs(): - for dtype in self._DtypesToTest(use_gpu): - for filter_format in self._FilterFormatsToTest(use_gpu): - result, expected = self._SetupValuesForDevice( - tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu", - data_format, filter_format, dtype) - tensors.append(result) - ref_tensors.append(expected) - with self.cached_session() as sess, self.test_scope(): - values = sess.run(tensors) - ref_values = sess.run(ref_tensors) - for i in range(len(tensors)): - conv = tensors[i] - value = values[i] - ref_value = ref_values[i] - tf_logging.info("expected = %s", ref_value) - tf_logging.info("actual = %s", value) - tol = 1e-5 - if value.dtype == np.float16: - tol = 1e-3 - self.assertAllClose( - np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol) - self.assertShapeEqual(value, conv) + with self.session() as sess, self.test_scope(): + tensors = [] + ref_tensors = [] + for (data_format, use_gpu) in _GetTestConfigs(): + for dtype in self._DtypesToTest(use_gpu): + for filter_format in self._FilterFormatsToTest(use_gpu): + result, expected = self._SetupValuesForDevice( + tensor_in_sizes, filter_in_sizes, bias, strides, padding, + "Relu", data_format, filter_format, dtype) + tensors.append(result) + ref_tensors.append(expected) + + values = sess.run(tensors) + ref_values = sess.run(ref_tensors) + for i in range(len(tensors)): + conv = tensors[i] + value = values[i] + ref_value = ref_values[i] + tf_logging.info("expected = %s", ref_value) + tf_logging.info("actual = %s", value) + tol = 1e-5 + if value.dtype == np.float16: + tol = 1e-3 + self.assertAllClose( + np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol) + self.assertShapeEqual(value, conv) def testConv2D1x1Filter(self, gpu_only=True): if gpu_only and not test.is_gpu_available(): @@ -536,7 +536,7 @@ class FusedConv2DBiasActivationTest(object): if gpu_only and not test.is_gpu_available(): tf_logging.info("Skipping OpEdgeCases tests.") return - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): # Illegal strides. 
with self.assertRaisesRegexp( errors_impl.UnimplementedError, @@ -636,81 +636,8 @@ def _CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type): return (input_dim + stride - 1) // stride -def _NchwVectCToNchw(in_tensor): - # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W] - t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3]) - n = in_tensor.shape.dims[0].value - c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value - h = in_tensor.shape.dims[2].value - w = in_tensor.shape.dims[3].value - return array_ops.reshape(t, [n, c, h, w]) - - -def _OihwVectIToHwio(in_tensor): - # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W] - t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0]) - o = in_tensor.shape.dims[0].value - i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value - h = in_tensor.shape.dims[2].value - w = in_tensor.shape.dims[3].value - return array_ops.reshape(t, [h, w, i, o]) - - -def _NchwToNchwVectC(in_tensor): - n, c, h, w = in_tensor.shape.as_list() - assert c % 4 == 0 - t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w]) - return array_ops.transpose(t, [0, 1, 3, 4, 2]) - - -def _HwioToOihw(in_tensor): - return array_ops.transpose(in_tensor, [3, 2, 0, 1]) - - -def _SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, - padding, strides, side_input_scale, - side_input, biases, apply_relu): - """Simulates the int8 fused 2-D convolution op using separate float ops. - - The arguments and return values have the same format, meanings and - restrictions as the actual op. - Args: - conv_input_scale: A scalar 'float'. - conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. - kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout. - padding: A `string` from: `"SAME", "VALID"`. - strides: A list of `ints`. - side_input_scale: A scalar 'float'. - side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. - biases: A `Tensor` of type `float32` in NCHW layout. - apply_relu: A boolean to specify whether to apply "Relu" activation function - that clips outputs to the range [0, 127], or "None" activation that clips - to the range [-128, 127]. - Returns: - A `Tensor` of type `qint8` in NCHW_VECT_C layout. - """ - conv_result = nn_ops.conv2d( - _NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)), - _OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)), - strides=strides, - padding=padding, - data_format="NCHW") * conv_input_scale - - conv_and_side_inputs = conv_result + side_input_scale * _NchwVectCToNchw( - gen_array_ops.dequantize(side_input, -128, 127)) - - output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") - if apply_relu: - output = nn_ops.relu(output) - - result, _, _ = gen_array_ops.quantize_v2( - _NchwToNchwVectC(output), -128, 127, dtypes.qint8) - return result - - -# TODO(b/114580749): XLA:CPU/GPU don't support int8 at the moment, so this test -# doesn't currently use XLA. 
-class FusedConvInt8Tests(object): +def _GetFusedConvInt8TestParams(): + """Returns test parameters shared by all Int8 FusedConv tests.""" _test_params = [ { "batch_size": 1, @@ -848,6 +775,111 @@ class FusedConvInt8Tests(object): "padding_type": "SAME" }, ] + return _test_params + + +def _Int8Roundtrip(fn, tensor): + return array_ops.bitcast( + fn(array_ops.bitcast(tensor, dtypes.int8)), dtypes.qint8) + + +def _NchwVectCToNchw(in_tensor): + # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W] + t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3]) + n = in_tensor.shape.dims[0].value + c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + return array_ops.reshape(t, [n, c, h, w]) + + +def _NchwVectCToNhwc(in_tensor): + # [N, C / 4, H, W, 4] => [N, H, W, C / 4, 4] == [N, H, W, C] + t = array_ops.transpose(in_tensor, [0, 2, 3, 1, 4]) + n = in_tensor.shape.dims[0].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + return array_ops.reshape(t, [n, h, w, c]) + + +def _OihwVectIToHwio(in_tensor): + # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W] + t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0]) + o = in_tensor.shape.dims[0].value + i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + return array_ops.reshape(t, [h, w, i, o]) + + +def _NchwToNchwVectC(in_tensor): + n, c, h, w = in_tensor.shape.as_list() + assert c % 4 == 0 + t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w]) + return array_ops.transpose(t, [0, 1, 3, 4, 2]) + + +def _NhwcToNchwVectC(in_tensor): + # [H, H, W, C] => [N, H, W, C //4, 4] => [N, C / 4, H, W, 4] + n, h, w, c = in_tensor.shape.as_list() + assert c % 4 == 0 + t = array_ops.reshape(in_tensor, [n, h, w, c // 4, 4]) + return array_ops.transpose(t, [0, 3, 1, 2, 4]) + + +def _HwioToOihw(in_tensor): + return array_ops.transpose(in_tensor, [3, 2, 0, 1]) + + +def _SimulateFusedConv2dBiasActivationInt8OnCpu(conv_input_scale, conv_input, + kernel, padding, strides, + side_input_scale, side_input, + biases, apply_relu): + """Simulates the int8 fused 2-D convolution op using separate float ops. + + The arguments and return values have the same format, meanings and + restrictions as the actual op. + + Args: + conv_input_scale: A scalar 'float'. + conv_input: A `Tensor` of type `qint8` in NHWC layout. + kernel: A `Tensor` of type `qint8` in HWIO layout. + padding: A `string` from: `"SAME", "VALID"`. + strides: A list of `ints`. + side_input_scale: A scalar 'float'. + side_input: A `Tensor` of type `qint8` in NHWC layout. + biases: A `Tensor` of type `float32` in NHWC layout. + apply_relu: A boolean to specify whether to apply "Relu" activation function + that clips outputs to the range [0, 127], or "None" activation that clips + to the range [-128, 127]. + + Returns: + A `Tensor` of type `qint8` in NHWC layout. 
+ """ + conv_result = nn_ops.conv2d( + math_ops.cast(conv_input, dtypes.float32), + math_ops.cast(kernel, dtypes.float32), + strides=strides, + padding=padding, + data_format="NHWC") * conv_input_scale + + conv_and_side_inputs = conv_result + side_input_scale * math_ops.cast( + side_input, dtypes.float32) + + output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NHWC") + if apply_relu: + output = nn_ops.relu(output) + + # In this case quantization is identical to clipping and casting. + result, _, _ = gen_array_ops.quantize_v2(output, -128, 127, dtypes.qint8) + return result + + +# FusedConv2DBiasActivation on CPU supports only NHWC/HWIO data format. +class FusedConvInt8CPUTests(object): + """Verify quantization with CPU kernel.""" + _test_params = _GetFusedConvInt8TestParams() @contextlib.contextmanager def test_scope(self): # pylint: disable=invalid-name @@ -855,6 +887,8 @@ class FusedConvInt8Tests(object): yield def runTest(self, test_param, apply_relu): + """Runs tests for dimensions configured in test_param.""" + batch_size = test_param["batch_size"] input_channels = test_param["input_channels"] output_channels = test_param["output_channels"] @@ -869,7 +903,98 @@ class FusedConvInt8Tests(object): bias_scale = test_param["bias_scale"] padding_type = test_param["padding_type"] - with self.cached_session(use_gpu=True) as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): + conv_input, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [batch_size, input_height, input_width, input_channels], + minval=-0.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + kernel, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [filter_height, filter_width, input_channels, output_channels], + minval=-1.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + output_height = _CalculateConvolvedOutputDim(input_height, filter_height, + vertical_stride, + padding_type) + output_width = _CalculateConvolvedOutputDim(input_width, filter_width, + horizontal_stride, + padding_type) + tf_logging.info("output_height=%s, output_width=%s", output_height, + output_width) + + side_input, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [batch_size, output_height, output_width, output_channels], + minval=0.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + biases = random_ops.random_uniform([output_channels], + minval=-10 * bias_scale, + maxval=20 * bias_scale, + dtype=dtypes.float32) + + strides = [1, vertical_stride, horizontal_stride, 1] + + actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + conv_input, + kernel, + biases, + strides=strides, + padding=padding_type, + conv_input_scale=conv_input_scale, + side_input_scale=side_input_scale, + side_input=(None if side_input_scale == 0.0 else side_input), + activation_mode="Relu" if apply_relu else "None", + data_format="NHWC", + filter_format="HWIO") + + expected = _SimulateFusedConv2dBiasActivationInt8OnCpu( + conv_input_scale, conv_input, kernel, padding_type, strides, + side_input_scale, side_input, biases, apply_relu) + + actual_y, expected_y = sess.run([actual, expected]) + self.assertAllClose(actual_y, expected_y, rtol=0, atol=1) + + def testFusedConvInt8(self): + for apply_relu in [True, False]: + for test_param in self._test_params: + self.runTest(test_param, apply_relu) + + +# Test that GPU and CPU kernels produce identical results for QInt8 data type. 
+class FusedConvInt8CorrespondenceTests(object): + """Verify quantization with CPU kernel.""" + _test_params = _GetFusedConvInt8TestParams() + + @contextlib.contextmanager + def test_scope(self): # pylint: disable=invalid-name + """Can be overridden in base classes to provide a test scope.""" + yield + + def runTest(self, test_param, apply_relu): + """Runs tests for dimensions configured in test_param.""" + + batch_size = test_param["batch_size"] + input_channels = test_param["input_channels"] + output_channels = test_param["output_channels"] + input_height = test_param["input_height"] + input_width = test_param["input_width"] + filter_height = test_param["filter_height"] + filter_width = test_param["filter_width"] + vertical_stride = test_param["vertical_stride"] + horizontal_stride = test_param["horizontal_stride"] + conv_input_scale = test_param["conv_input_scale"] + side_input_scale = test_param["side_input_scale"] + bias_scale = test_param["bias_scale"] + padding_type = test_param["padding_type"] + + with self.session() as sess, self.test_scope(): conv_input, _, _ = gen_array_ops.quantize_v2( random_ops.random_uniform( [batch_size, input_channels // 4, input_height, input_width, 4], @@ -887,10 +1012,12 @@ class FusedConvInt8Tests(object): dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) - output_height = _CalculateConvolvedOutputDim( - input_height, filter_height, vertical_stride, padding_type) - output_width = _CalculateConvolvedOutputDim( - input_width, filter_width, horizontal_stride, padding_type) + output_height = _CalculateConvolvedOutputDim(input_height, filter_height, + vertical_stride, + padding_type) + output_width = _CalculateConvolvedOutputDim(input_width, filter_width, + horizontal_stride, + padding_type) tf_logging.info("output_height=%s, output_width=%s", output_height, output_width) @@ -908,27 +1035,39 @@ class FusedConvInt8Tests(object): maxval=20 * bias_scale, dtype=dtypes.float32) - strides = [1, 1, vertical_stride, horizontal_stride] + with ops.device("/cpu:0"): + t = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + _Int8Roundtrip(_NchwVectCToNhwc, conv_input), + _Int8Roundtrip(_OihwVectIToHwio, kernel), + biases, + strides=[1, vertical_stride, horizontal_stride, 1], + padding=padding_type, + conv_input_scale=conv_input_scale, + side_input_scale=side_input_scale, + side_input=(None if side_input_scale == 0.0 else _Int8Roundtrip( + _NchwVectCToNhwc, side_input)), + activation_mode="Relu" if apply_relu else "None", + data_format="NHWC", + filter_format="HWIO") + cpu_result = _Int8Roundtrip(_NhwcToNchwVectC, t) - actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - conv_input, - kernel, - biases, - strides=strides, - padding=padding_type, - conv_input_scale=conv_input_scale, - side_input_scale=side_input_scale, - side_input=side_input, - activation_mode="Relu" if apply_relu else "None", - data_format="NCHW_VECT_C", - filter_format="OIHW_VECT_I") + with ops.device("/gpu:0"): + t = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + conv_input, + kernel, + biases, + strides=[1, 1, vertical_stride, horizontal_stride], + padding=padding_type, + conv_input_scale=conv_input_scale, + side_input_scale=side_input_scale, + side_input=(None if side_input_scale == 0.0 else side_input), + activation_mode="Relu" if apply_relu else "None", + data_format="NCHW_VECT_C", + filter_format="OIHW_VECT_I") + gpu_result = t - expected = _SimulateFusedConv2dBiasActivationInt8( - conv_input_scale, conv_input, kernel, padding_type, strides, - 
side_input_scale, side_input, biases, apply_relu) - - actual_y, expected_y = sess.run([actual, expected]) - self.assertAllClose(actual_y, expected_y, rtol=0, atol=1) + cpu_y, gpu_y = sess.run([cpu_result, gpu_result]) + self.assertAllClose(cpu_y, gpu_y, rtol=0, atol=0) def testFusedConvInt8(self): if not test.is_gpu_available( diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 386e4cf69b7..3165e007996 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -1,4 +1,5 @@ # Files for using TF-GAN framework. + load("//tensorflow:tensorflow.bzl", "py_test") package(default_visibility = [ @@ -58,6 +59,7 @@ py_library( py_test( name = "train_test", srcs = ["python/train_test.py"], + python_version = "PY2", shard_count = 50, srcs_version = "PY2AND3", tags = ["notsan"], @@ -161,6 +163,7 @@ py_library( py_test( name = "losses_impl_test", srcs = ["python/losses/python/losses_impl_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":losses_impl", @@ -198,6 +201,7 @@ py_library( py_test( name = "tuple_losses_test", srcs = ["python/losses/python/tuple_losses_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":losses_impl", @@ -236,6 +240,7 @@ py_library( py_test( name = "conditioning_utils_test", srcs = ["python/features/python/conditioning_utils_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":conditioning_utils", @@ -266,6 +271,7 @@ py_library( py_test( name = "random_tensor_pool_test", srcs = ["python/features/python/random_tensor_pool_test.py"], + python_version = "PY2", shard_count = 6, srcs_version = "PY2AND3", deps = [ @@ -303,6 +309,7 @@ py_library( py_test( name = "virtual_batchnorm_test", srcs = ["python/features/python/virtual_batchnorm_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":virtual_batchnorm", @@ -338,6 +345,7 @@ py_library( py_test( name = "clip_weights_test", srcs = ["python/features/python/clip_weights_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":clip_weights", @@ -376,6 +384,7 @@ py_library( py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_pip", @@ -411,6 +420,7 @@ py_library( py_test( name = "eval_utils_test", srcs = ["python/eval/python/eval_utils_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":eval_utils", @@ -443,6 +453,7 @@ py_library( py_test( name = "summaries_test", srcs = ["python/eval/python/summaries_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":namedtuples", @@ -475,6 +486,7 @@ py_library( py_test( name = "head_test", srcs = ["python/estimator/python/head_test.py"], + python_version = "PY2", shard_count = 1, srcs_version = "PY2AND3", deps = [ @@ -512,6 +524,7 @@ py_library( py_test( name = "gan_estimator_test", srcs = ["python/estimator/python/gan_estimator_test.py"], + python_version = "PY2", shard_count = 1, srcs_version = "PY2AND3", tags = ["notsan"], @@ -567,6 +580,7 @@ py_library( py_test( name = "stargan_estimator_test", srcs = ["python/estimator/python/stargan_estimator_test.py"], + python_version = "PY2", shard_count = 1, srcs_version = "PY2AND3", tags = ["notsan"], @@ -617,6 +631,7 @@ py_library( py_test( name = "tpu_gan_estimator_test", srcs = ["python/estimator/python/tpu_gan_estimator_test.py"], + python_version = "PY2", shard_count = 11, srcs_version = "PY2AND3", tags = ["notsan"], @@ -670,6 +685,7 @@ py_test( srcs 
= [ "python/estimator/python/latent_gan_estimator_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":latent_gan_estimator", @@ -705,6 +721,7 @@ py_library( py_test( name = "sliced_wasserstein_test", srcs = ["python/eval/python/sliced_wasserstein_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":sliced_wasserstein", @@ -738,6 +755,7 @@ py_library( py_test( name = "spectral_normalization_test", srcs = ["python/features/python/spectral_normalization_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":spectral_normalization", diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index dd904611d1a..d234558d4da 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -85,8 +85,8 @@ class GANEstimator(estimator.Estimator): discriminator_fn=discriminator_fn, generator_loss_fn=tfgan.losses.wasserstein_generator_loss, discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, - generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), - discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5)) + generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5)) # Train estimator. gan_estimator.train(train_input_fn, steps) diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py index 2a485e7d47f..06a1480c072 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py @@ -79,8 +79,8 @@ class StarGANEstimator(estimator.Estimator): generator_fn=generator_fn, discriminator_fn=discriminator_fn, loss_fn=loss_fn, - generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), - discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5)) + generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5)) # Train estimator. 
stargan_estimator.train(train_input_fn, steps) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py index 8f2a22c78a3..8ed64e869a0 100644 --- a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py @@ -64,8 +64,8 @@ class TPUGANEstimator(tpu_estimator.TPUEstimator): discriminator_fn=discriminator_fn, generator_loss_fn=tfgan.losses.wasserstein_generator_loss, discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, - generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), - discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5), train_batch_size=4, config=config) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index ff19ce2f78e..2c301267900 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -140,7 +140,7 @@ def preprocess_image(images, is_single = images.shape.ndims == 3 with ops.name_scope(scope, 'preprocess', [images, height, width]): if not images.dtype.is_floating: - images = math_ops.to_float(images) + images = math_ops.cast(images, dtypes.float32) if is_single: images = array_ops.expand_dims(images, axis=0) resized = image_ops.resize_bilinear(images, [height, width]) @@ -189,7 +189,7 @@ def _kl_divergence(p, p_logits, q): def get_graph_def_from_disk(filename): """Get a GraphDef proto from a disk location.""" - with gfile.FastGFile(filename, 'rb') as f: + with gfile.GFile(filename, 'rb') as f: return graph_pb2.GraphDef.FromString(f.read()) @@ -1057,7 +1057,8 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, n_g = array_ops.shape(generated_activations)[0] n_bigger = math_ops.maximum(n_r, n_g) - n_blocks = math_ops.to_int32(math_ops.ceil(n_bigger / max_block_size)) + n_blocks = math_ops.cast(math_ops.ceil(n_bigger / max_block_size), + dtypes.int32) v_r = n_r // n_blocks v_g = n_g // n_blocks diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py index 4b1105f6bd4..9657d4e3d0c 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -28,6 +28,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -74,7 +75,7 @@ def _laplacian_pyramid(batch, num_levels): res = spatial_conv(res, 4) return res - pyramid = [math_ops.to_float(batch)] + pyramid = [math_ops.cast(batch, dtypes.float32)] for _ in range(1, num_levels): pyramid.append(pyr_down(pyramid[-1])) pyramid[-2] -= pyr_up(pyramid[-1]) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index c7bbd65bbff..3eb4f5db0c8 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ 
b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import eval_utils +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import map_fn @@ -171,8 +172,10 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, gan_model.generated_data[:num_comparisons]) real_list = array_ops.unstack(gan_model.real_data[:num_comparisons]) diffs = [ - math_ops.abs(math_ops.to_float(generated) - math_ops.to_float(real)) for - generated, real in zip(generated_list, real_list)] + math_ops.abs(math_ops.cast(generated, dtypes.float32) - + math_ops.cast(real, dtypes.float32)) + for generated, real in zip(generated_list, real_list) + ] image_list.extend(diffs) # Reshape image and display. diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py index 3764c43cdfc..9004be6229f 100644 --- a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py @@ -191,9 +191,9 @@ def spectral_normalization_custom_getter(name_filter=_default_name_filter, of output channels. Apply this to layers by supplying this as the `custom_getter` of a - `tf.variable_scope`. For example: + `tf.compat.v1.variable_scope`. For example: - with tf.variable_scope('discriminator', + with tf.compat.v1.variable_scope('discriminator', custom_getter=spectral_norm_getter()): net = discriminator_fn(net) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py index f5c448db41c..030ce942607 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py @@ -52,8 +52,7 @@ def _statistics(x, axes): Args: x: A `Tensor`. - axes: Array of ints. Axes along which to compute mean and - variance. + axes: Array of ints. Axes along which to compute mean and variance. Returns: Two `Tensor` objects: `mean` and `square mean`. @@ -97,10 +96,12 @@ def _validate_init_input_and_get_axis(reference_batch, axis): def _validate_call_input(tensor_list, batch_dim): """Verifies that tensor shapes are compatible, except for `batch_dim`.""" + def _get_shape(tensor): shape = tensor.shape.as_list() del shape[batch_dim] return shape + base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0])) for tensor in tensor_list: base_shape.assert_is_compatible_with(_get_shape(tensor)) @@ -121,7 +122,8 @@ class VBN(object): Note that if `center` or `scale` variables are created, they are shared between all calls to this object. - The `__init__` API is intended to mimic `tf.layers.batch_normalization` as + The `__init__` API is intended to mimic + `tf.compat.v1.layers.batch_normalization` as closely as possible. """ @@ -157,9 +159,9 @@ class VBN(object): epsilon: Small float added to variance to avoid dividing by zero. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is - not used. When the next layer is linear (also e.g. `nn.relu`), this can - be disabled since the scaling can be done by the next layer. 
+ scale: If True, multiply by `gamma`. If False, `gamma` is not used. When + the next layer is linear (also e.g. `nn.relu`), this can be disabled + since the scaling can be done by the next layer. beta_initializer: Initializer for the beta weight. gamma_initializer: Initializer for the gamma weight. beta_regularizer: Optional regularizer for the beta weight. @@ -185,8 +187,8 @@ class VBN(object): if axis == self._batch_axis: raise ValueError('`axis` and `batch_axis` cannot be the same.') - with variable_scope.variable_scope(name, 'VBN', - values=[reference_batch]) as self._vs: + with variable_scope.variable_scope( + name, 'VBN', values=[reference_batch]) as self._vs: self._reference_batch = reference_batch # Calculate important shapes: @@ -217,14 +219,15 @@ class VBN(object): # that can be easily modified by additional examples. self._ref_mean, self._ref_mean_squares = _statistics( self._reference_batch, reduction_axes) - self._ref_variance = (self._ref_mean_squares - - math_ops.square(self._ref_mean)) + self._ref_variance = ( + self._ref_mean_squares - math_ops.square(self._ref_mean)) # Virtual batch normalization uses a weighted average between example # statistics and the reference batch statistics. - ref_batch_size = _static_or_dynamic_batch_size( - self._reference_batch, self._batch_axis) - self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.) + ref_batch_size = _static_or_dynamic_batch_size(self._reference_batch, + self._batch_axis) + self._example_weight = 1. / ( + math_ops.cast(ref_batch_size, dtypes.float32) + 1.) self._ref_weight = 1. - self._example_weight # Make the variables, if necessary. @@ -246,10 +249,11 @@ class VBN(object): def _virtual_statistics(self, inputs, reduction_axes): """Compute the statistics needed for virtual batch normalization.""" cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes) - vb_mean = (self._example_weight * cur_mean + - self._ref_weight * self._ref_mean) - vb_mean_sq = (self._example_weight * cur_mean_sq + - self._ref_weight * self._ref_mean_squares) + vb_mean = ( + self._example_weight * cur_mean + self._ref_weight * self._ref_mean) + vb_mean_sq = ( + self._example_weight * cur_mean_sq + + self._ref_weight * self._ref_mean_squares) return (vb_mean, vb_mean_sq) def _broadcast(self, v, broadcast_shape=None): @@ -268,8 +272,7 @@ class VBN(object): self._broadcast(self._ref_mean), self._broadcast(self._ref_variance), self._broadcast(self._beta), - self._broadcast(self._gamma), - self._epsilon) + self._broadcast(self._gamma), self._epsilon) def __call__(self, inputs): """Run virtual batch normalization on inputs. 
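Editor's note: the spectral-normalization and virtual batch-normalization hunks above only move docstring references to the `tf.compat.v1` names and reflow arguments; behaviour is unchanged. As a reference point, a minimal, hypothetical sketch of wiring the documented `custom_getter` into a discriminator scope under TF 1.x (the discriminator body, shapes, and variable names are invented for illustration) might look like this:

```python
# Sketch only: assumes a TF 1.x environment where tf.contrib is importable.
import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import (
    spectral_normalization_impl as spectral_norm)


def hypothetical_discriminator_fn(net):
  # Any stack of layers would do; a single linear layer keeps the sketch short.
  # The weight is named 'kernel', which the default name filter is expected to
  # match (assumption); unmatched variables are simply left un-normalized.
  kernel = tf.compat.v1.get_variable('kernel', shape=[64, 1])
  return tf.matmul(net, kernel)


images = tf.compat.v1.placeholder(tf.float32, [None, 64])

# Weights created inside this scope are spectrally normalized, as the
# docstring above describes.
with tf.compat.v1.variable_scope(
    'discriminator',
    custom_getter=spectral_norm.spectral_normalization_custom_getter()):
  logits = hypothetical_discriminator_fn(images)
```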
@@ -298,9 +301,7 @@ class VBN(object): b_shape[self._batch_axis] = _static_or_dynamic_batch_size( inputs, self._batch_axis) return nn.batch_normalization( - inputs, - self._broadcast(vb_mean, b_shape), + inputs, self._broadcast(vb_mean, b_shape), self._broadcast(vb_variance, b_shape), self._broadcast(self._beta, self._broadcast_shape), - self._broadcast(self._gamma, self._broadcast_shape), - self._epsilon) + self._broadcast(self._gamma, self._broadcast_shape), self._epsilon) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py index ecfbb8a432e..9848f654bad 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py @@ -112,7 +112,7 @@ class VirtualBatchnormTest(test.TestCase): batch, axis, training=True) # Get VBN's batch normalization on reference batch. - batch_axis = 0 if axis is not 0 else 1 # axis and batch_axis can't same + batch_axis = 0 if axis != 0 else 1 # axis and batch_axis can't same vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis) vbn_normalized = vbn.reference_batch_normalization() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 1f1ae2df4d6..99bdf5b20d3 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -36,7 +36,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -51,7 +50,6 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.ops.losses import util from tensorflow.python.summary import summary - __all__ = [ 'acgan_discriminator_loss', 'acgan_generator_loss', @@ -95,19 +93,19 @@ def wasserstein_generator_loss( the same as the corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add detailed summaries for the loss. Returns: A loss Tensor. The shape depends on `reduction`. """ - with ops.name_scope(scope, 'generator_wasserstein_loss', ( - discriminator_gen_outputs, weights)) as scope: + with ops.name_scope(scope, 'generator_wasserstein_loss', + (discriminator_gen_outputs, weights)) as scope: discriminator_gen_outputs = _to_float(discriminator_gen_outputs) - loss = - discriminator_gen_outputs - loss = losses.compute_weighted_loss( - loss, weights, scope, loss_collection, reduction) + loss = -discriminator_gen_outputs + loss = losses.compute_weighted_loss(loss, weights, scope, loss_collection, + reduction) if add_summaries: summary.scalar('generator_wass_loss', loss) @@ -140,25 +138,31 @@ def wasserstein_discriminator_loss( `discriminator_gen_outputs`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. 
Returns: A loss Tensor. The shape depends on `reduction`. """ - with ops.name_scope(scope, 'discriminator_wasserstein_loss', ( - discriminator_real_outputs, discriminator_gen_outputs, real_weights, - generated_weights)) as scope: + with ops.name_scope(scope, 'discriminator_wasserstein_loss', + (discriminator_real_outputs, discriminator_gen_outputs, + real_weights, generated_weights)) as scope: discriminator_real_outputs = _to_float(discriminator_real_outputs) discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) loss_on_generated = losses.compute_weighted_loss( - discriminator_gen_outputs, generated_weights, scope, - loss_collection=None, reduction=reduction) + discriminator_gen_outputs, + generated_weights, + scope, + loss_collection=None, + reduction=reduction) loss_on_real = losses.compute_weighted_loss( - discriminator_real_outputs, real_weights, scope, loss_collection=None, + discriminator_real_outputs, + real_weights, + scope, + loss_collection=None, reduction=reduction) loss = loss_on_generated - loss_on_real util.add_loss(loss, loss_collection) @@ -173,17 +177,16 @@ def wasserstein_discriminator_loss( # ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs` # (https://arxiv.org/abs/1610.09585). -def acgan_discriminator_loss( - discriminator_real_classification_logits, - discriminator_gen_classification_logits, - one_hot_labels, - label_smoothing=0.0, - real_weights=1.0, - generated_weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): +def acgan_discriminator_loss(discriminator_real_classification_logits, + discriminator_gen_classification_logits, + one_hot_labels, + label_smoothing=0.0, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): """ACGAN loss for the discriminator. The ACGAN loss adds a classification loss to the conditional discriminator. @@ -212,7 +215,7 @@ def acgan_discriminator_loss( `discriminator_gen_classification_logits`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. 
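Editor's note: the Wasserstein and ACGAN loss hunks above are argument reflows plus the `tf.losses.Reduction` to `tf.compat.v1.losses.Reduction` docstring rename; no signatures change. For orientation, a hedged sketch of calling the element-wise (wargs-style) Wasserstein losses with an explicit reduction, using hypothetical placeholder logits:

```python
# Sketch only: TF 1.x with tf.contrib; the logits below are hypothetical placeholders.
import tensorflow as tf
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses

d_on_real = tf.compat.v1.placeholder(tf.float32, [None, 1])       # D(x)
d_on_generated = tf.compat.v1.placeholder(tf.float32, [None, 1])  # D(G(z))

# `reduction` is a tf.compat.v1.losses.Reduction, as the docstrings above now state.
reduction = tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE

gen_loss = tfgan_losses.wasserstein_generator_loss(
    d_on_generated, reduction=reduction)
dis_loss = tfgan_losses.wasserstein_discriminator_loss(
    d_on_real, d_on_generated, reduction=reduction)
```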
Returns: @@ -226,13 +229,20 @@ def acgan_discriminator_loss( (discriminator_real_classification_logits, discriminator_gen_classification_logits, one_hot_labels)) as scope: loss_on_generated = losses.softmax_cross_entropy( - one_hot_labels, discriminator_gen_classification_logits, - weights=generated_weights, scope=scope, loss_collection=None, + one_hot_labels, + discriminator_gen_classification_logits, + weights=generated_weights, + scope=scope, + loss_collection=None, reduction=reduction) loss_on_real = losses.softmax_cross_entropy( - one_hot_labels, discriminator_real_classification_logits, - weights=real_weights, label_smoothing=label_smoothing, scope=scope, - loss_collection=None, reduction=reduction) + one_hot_labels, + discriminator_real_classification_logits, + weights=real_weights, + label_smoothing=label_smoothing, + scope=scope, + loss_collection=None, + reduction=reduction) loss = loss_on_generated + loss_on_real util.add_loss(loss, loss_collection) @@ -244,14 +254,13 @@ def acgan_discriminator_loss( return loss -def acgan_generator_loss( - discriminator_gen_classification_logits, - one_hot_labels, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): +def acgan_generator_loss(discriminator_gen_classification_logits, + one_hot_labels, + weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): """ACGAN loss for the generator. The ACGAN loss adds a classification loss to the conditional discriminator. @@ -273,7 +282,7 @@ def acgan_generator_loss( either `1`, or the same as the corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. Returns: @@ -287,8 +296,11 @@ def acgan_generator_loss( scope, 'acgan_generator_loss', (discriminator_gen_classification_logits, one_hot_labels)) as scope: loss = losses.softmax_cross_entropy( - one_hot_labels, discriminator_gen_classification_logits, - weights=weights, scope=scope, loss_collection=loss_collection, + one_hot_labels, + discriminator_gen_classification_logits, + weights=weights, + scope=scope, + loss_collection=loss_collection, reduction=reduction) if add_summaries: @@ -323,8 +335,8 @@ def wasserstein_gradient_penalty( Args: real_data: Real data. generated_data: Output of the generator. - generator_inputs: Exact argument to pass to the generator, which is used - as optional conditioning to the discriminator. + generator_inputs: Exact argument to pass to the generator, which is used as + optional conditioning to the discriminator. discriminator_fn: A discriminator function that conforms to TF-GAN API. discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when @@ -334,12 +346,12 @@ def wasserstein_gradient_penalty( one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `real_data` and `generated_data`, and must be broadcastable to - them (i.e., all dimensions must be either `1`, or the same as the - corresponding dimension). 
+ `real_data` and `generated_data`, and must be broadcastable to them (i.e., + all dimensions must be either `1`, or the same as the corresponding + dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. Returns: @@ -366,8 +378,10 @@ def wasserstein_gradient_penalty( with ops.name_scope(None): # Clear scope so update ops are added properly. # Reuse variables if variables already exists. - with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', - reuse=variable_scope.AUTO_REUSE): + with variable_scope.variable_scope( + discriminator_scope, + 'gpenalty_dscope', + reuse=variable_scope.AUTO_REUSE): disc_interpolates = discriminator_fn(interpolates, generator_inputs) if isinstance(disc_interpolates, tuple): @@ -379,8 +393,8 @@ def wasserstein_gradient_penalty( math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) # Propagate shape information, if possible. if isinstance(batch_size, int): - gradient_squares.set_shape([ - batch_size] + gradient_squares.shape.as_list()[1:]) + gradient_squares.set_shape([batch_size] + + gradient_squares.shape.as_list()[1:]) # For numerical stability, add epsilon to the sum before taking the square # root. Note tf.norm does not add epsilon. slopes = math_ops.sqrt(gradient_squares + epsilon) @@ -389,8 +403,11 @@ def wasserstein_gradient_penalty( penalties = math_ops.maximum(0., penalties) penalties_squared = math_ops.square(penalties) penalty = losses.compute_weighted_loss( - penalties_squared, weights, scope=scope, - loss_collection=loss_collection, reduction=reduction) + penalties_squared, + weights, + scope=scope, + loss_collection=loss_collection, + reduction=reduction) if add_summaries: summary.scalar('gradient_penalty_loss', penalty) @@ -437,26 +454,34 @@ def minimax_discriminator_loss( generated_weights: Same as `real_weights`, but for `generated_data`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. Returns: A loss Tensor. The shape depends on `reduction`. 
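Editor's note: `wasserstein_gradient_penalty` above likewise only reflows its arguments. A hedged sketch of the reuse pattern its docstring and the `AUTO_REUSE` hunk describe, with an invented linear critic and placeholder data:

```python
# Sketch only: TF 1.x; the critic, shapes, and data below are invented for illustration.
import tensorflow as tf
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses


def hypothetical_discriminator_fn(data, unused_generator_inputs):
  # A linear critic; get_variable keeps variable reuse explicit and simple.
  kernel = tf.compat.v1.get_variable('kernel', shape=[32, 1])
  return tf.matmul(data, kernel)


real_data = tf.compat.v1.placeholder(tf.float32, [16, 32])
generated_data = tf.compat.v1.placeholder(tf.float32, [16, 32])
noise = tf.compat.v1.placeholder(tf.float32, [16, 8])

# Build the critic once so its variables exist in `dis_scope` ...
with tf.compat.v1.variable_scope('discriminator') as dis_scope:
  _ = hypothetical_discriminator_fn(real_data, noise)

# ... then the penalty re-enters that scope with AUTO_REUSE, as in the hunk above.
penalty = tfgan_losses.wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs=noise,
    discriminator_fn=hypothetical_discriminator_fn,
    discriminator_scope=dis_scope)
```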
""" - with ops.name_scope(scope, 'discriminator_minimax_loss', ( - discriminator_real_outputs, discriminator_gen_outputs, real_weights, - generated_weights, label_smoothing)) as scope: + with ops.name_scope( + scope, 'discriminator_minimax_loss', + (discriminator_real_outputs, discriminator_gen_outputs, real_weights, + generated_weights, label_smoothing)) as scope: # -log((1 - label_smoothing) - sigmoid(D(x))) loss_on_real = losses.sigmoid_cross_entropy( array_ops.ones_like(discriminator_real_outputs), - discriminator_real_outputs, real_weights, label_smoothing, scope, - loss_collection=None, reduction=reduction) + discriminator_real_outputs, + real_weights, + label_smoothing, + scope, + loss_collection=None, + reduction=reduction) # -log(- sigmoid(D(G(x)))) loss_on_generated = losses.sigmoid_cross_entropy( array_ops.zeros_like(discriminator_gen_outputs), - discriminator_gen_outputs, generated_weights, scope=scope, - loss_collection=None, reduction=reduction) + discriminator_gen_outputs, + generated_weights, + scope=scope, + loss_collection=None, + reduction=reduction) loss = loss_on_real + loss_on_generated util.add_loss(loss, loss_collection) @@ -469,14 +494,13 @@ def minimax_discriminator_loss( return loss -def minimax_generator_loss( - discriminator_gen_outputs, - label_smoothing=0.0, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): +def minimax_generator_loss(discriminator_gen_outputs, + label_smoothing=0.0, + weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): """Original minimax generator loss for GANs. Note that the authors don't recommend using this loss. A more practically @@ -499,17 +523,23 @@ def minimax_generator_loss( the same as the corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. Returns: A loss Tensor. The shape depends on `reduction`. """ with ops.name_scope(scope, 'generator_minimax_loss') as scope: - loss = - minimax_discriminator_loss( + loss = -minimax_discriminator_loss( array_ops.ones_like(discriminator_gen_outputs), - discriminator_gen_outputs, label_smoothing, weights, weights, scope, - loss_collection, reduction, add_summaries=False) + discriminator_gen_outputs, + label_smoothing, + weights, + weights, + scope, + loss_collection, + reduction, + add_summaries=False) if add_summaries: summary.scalar('generator_minimax_loss', loss) @@ -547,32 +577,26 @@ def modified_discriminator_loss( `discriminator_gen_outputs`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. Returns: A loss Tensor. The shape depends on `reduction`. 
""" - return minimax_discriminator_loss( - discriminator_real_outputs, - discriminator_gen_outputs, - label_smoothing, - real_weights, - generated_weights, - scope or 'discriminator_modified_loss', - loss_collection, - reduction, - add_summaries) + return minimax_discriminator_loss(discriminator_real_outputs, + discriminator_gen_outputs, label_smoothing, + real_weights, generated_weights, scope or + 'discriminator_modified_loss', + loss_collection, reduction, add_summaries) -def modified_generator_loss( - discriminator_gen_outputs, - label_smoothing=0.0, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): +def modified_generator_loss(discriminator_gen_outputs, + label_smoothing=0.0, + weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): """Modified generator loss for GANs. L = -log(sigmoid(D(G(z)))) @@ -593,7 +617,7 @@ def modified_generator_loss( dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. Returns: @@ -644,7 +668,7 @@ def least_squares_generator_loss( the same as the corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. Returns: @@ -653,10 +677,10 @@ def least_squares_generator_loss( with ops.name_scope(scope, 'lsq_generator_loss', (discriminator_gen_outputs, real_label)) as scope: discriminator_gen_outputs = _to_float(discriminator_gen_outputs) - loss = math_ops.squared_difference( - discriminator_gen_outputs, real_label) / 2.0 - loss = losses.compute_weighted_loss( - loss, weights, scope, loss_collection, reduction) + loss = math_ops.squared_difference(discriminator_gen_outputs, + real_label) / 2.0 + loss = losses.compute_weighted_loss(loss, weights, scope, loss_collection, + reduction) if add_summaries: summary.scalar('generator_lsq_loss', loss) @@ -699,7 +723,7 @@ def least_squares_discriminator_loss( `discriminator_gen_outputs`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. 
Returns: @@ -712,16 +736,22 @@ def least_squares_discriminator_loss( discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) - real_losses = math_ops.squared_difference( - discriminator_real_outputs, real_label) / 2.0 - fake_losses = math_ops.squared_difference( - discriminator_gen_outputs, fake_label) / 2.0 + real_losses = math_ops.squared_difference(discriminator_real_outputs, + real_label) / 2.0 + fake_losses = math_ops.squared_difference(discriminator_gen_outputs, + fake_label) / 2.0 loss_on_real = losses.compute_weighted_loss( - real_losses, real_weights, scope, loss_collection=None, + real_losses, + real_weights, + scope, + loss_collection=None, reduction=reduction) loss_on_generated = losses.compute_weighted_loss( - fake_losses, generated_weights, scope, loss_collection=None, + fake_losses, + generated_weights, + scope, + loss_collection=None, reduction=reduction) loss = loss_on_real + loss_on_generated @@ -745,7 +775,7 @@ def _validate_distributions(distributions): raise ValueError('`distributions` must be a list or tuple. Instead, ' 'found %s.' % type(distributions)) for x in distributions: - # We used to check with `isinstance(x, tf.distributions.Distribution)`. + # We used to check with `isinstance(x, tf.compat.v1.distributions.Distribution)`. # However, distributions have migrated to `tfp.distributions.Distribution`, # which is a new code repo, so we can't check this way anymore until # TF-GAN is migrated to a new repo as well. @@ -755,15 +785,15 @@ def _validate_distributions(distributions): 'Instead, found %s.' % type(x)) -def _validate_information_penalty_inputs( - structured_generator_inputs, predicted_distributions): +def _validate_information_penalty_inputs(structured_generator_inputs, + predicted_distributions): """Validate input to `mutual_information_penalty`.""" _validate_distributions(predicted_distributions) if len(structured_generator_inputs) != len(predicted_distributions): - raise ValueError('`structured_generator_inputs` length %i must be the same ' - 'as `predicted_distributions` length %i.' % ( - len(structured_generator_inputs), - len(predicted_distributions))) + raise ValueError( + '`structured_generator_inputs` length %i must be the same ' + 'as `predicted_distributions` length %i.' % + (len(structured_generator_inputs), len(predicted_distributions))) def mutual_information_penalty( @@ -789,21 +819,26 @@ def mutual_information_penalty( `structured_generator_inputs`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. - reduction: A `tf.losses.Reduction` to apply to loss. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. add_summaries: Whether or not to add summaries for the loss. Returns: A scalar Tensor representing the mutual information loss. """ - _validate_information_penalty_inputs( - structured_generator_inputs, predicted_distributions) + _validate_information_penalty_inputs(structured_generator_inputs, + predicted_distributions) with ops.name_scope(scope, 'mutual_information_loss') as scope: # Calculate the negative log-likelihood of the reconstructed noise. 
- log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in - zip(predicted_distributions, structured_generator_inputs)] + log_probs = [ + math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in zip( + predicted_distributions, structured_generator_inputs) + ] loss = -1 * losses.compute_weighted_loss( - log_probs, weights, scope, loss_collection=loss_collection, + log_probs, + weights, + scope, + loss_collection=loss_collection, reduction=reduction) if add_summaries: @@ -828,10 +863,11 @@ def _numerically_stable_global_norm(tensor_list): if all(x is None for x in tensor_list): return 0.0 - list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in - tensor_list if x is not None]) - return list_max * clip_ops.global_norm([x / list_max for x in tensor_list - if x is not None]) + list_max = math_ops.reduce_max([ + math_ops.reduce_max(math_ops.abs(x)) for x in tensor_list if x is not None + ]) + return list_max * clip_ops.global_norm( + [x / list_max for x in tensor_list if x is not None]) def _used_weight(weights_list): @@ -879,9 +915,9 @@ def combine_adversarial_loss(main_loss, adversarial loss. Exactly one of this and `gradient_ratio` must be non-None. gradient_ratio: If not `None`, the ratio of the magnitude of the gradients. - Specifically, - gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss) - Exactly one of this and `weight_factor` must be non-None. + Specifically, gradient_ratio = grad_mag(main_loss) / + grad_mag(adversarial_loss) Exactly one of this and `weight_factor` must be + non-None. gradient_ratio_epsilon: An epsilon to add to the adversarial loss coefficient denominator, to avoid division-by-zero. variables: List of variables to calculate gradients with respect to. If not @@ -900,8 +936,8 @@ def combine_adversarial_loss(main_loss, if variables is None: variables = contrib_variables_lib.get_trainable_variables() - with ops.name_scope(scope, 'adversarial_loss', - values=[main_loss, adversarial_loss]): + with ops.name_scope( + scope, 'adversarial_loss', values=[main_loss, adversarial_loss]): # Compute gradients if we will need them. 
if gradient_summaries or gradient_ratio is not None: main_loss_grad_mag = _numerically_stable_global_norm( @@ -923,15 +959,15 @@ def combine_adversarial_loss(main_loss, if _used_weight((weight_factor, gradient_ratio)) == 0: final_loss = main_loss elif weight_factor is not None: - final_loss = (main_loss + - array_ops.stop_gradient(weight_factor) * adversarial_loss) + final_loss = ( + main_loss + array_ops.stop_gradient(weight_factor) * adversarial_loss) elif gradient_ratio is not None: grad_mag_ratio = main_loss_grad_mag / ( adv_loss_grad_mag + gradient_ratio_epsilon) adv_coeff = grad_mag_ratio / gradient_ratio summary.scalar('adversarial_coefficient', adv_coeff) - final_loss = (main_loss + - array_ops.stop_gradient(adv_coeff) * adversarial_loss) + final_loss = ( + main_loss + array_ops.stop_gradient(adv_coeff) * adversarial_loss) return final_loss diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 9bff8090d93..422e16f0bfe 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -127,11 +127,12 @@ def gan_model( generator_variables = variables_lib.get_trainable_variables(gen_scope) discriminator_variables = variables_lib.get_trainable_variables(dis_scope) - return namedtuples.GANModel( - generator_inputs, generated_data, generator_variables, gen_scope, - generator_fn, real_data, discriminator_real_outputs, - discriminator_gen_outputs, discriminator_variables, dis_scope, - discriminator_fn) + return namedtuples.GANModel(generator_inputs, generated_data, + generator_variables, gen_scope, generator_fn, + real_data, discriminator_real_outputs, + discriminator_gen_outputs, + discriminator_variables, dis_scope, + discriminator_fn) def infogan_model( @@ -158,10 +159,10 @@ def infogan_model( of Tensorflow distributions representing the predicted noise distribution of the ith structure noise. real_data: A Tensor representing the real data. - unstructured_generator_inputs: A list of Tensors to the generator. - These tensors represent the unstructured noise or conditioning. - structured_generator_inputs: A list of Tensors to the generator. - These tensors must have high mutual information with the recognizer. + unstructured_generator_inputs: A list of Tensors to the generator. These + tensors represent the unstructured noise or conditioning. + structured_generator_inputs: A list of Tensors to the generator. These + tensors must have high mutual information with the recognizer. generator_scope: Optional generator variable scope. Useful if you want to reuse a subgraph that has already been created. discriminator_scope: Optional discriminator variable scope. Useful if you @@ -246,9 +247,9 @@ def acgan_model( generator_fn: A python lambda that takes `generator_inputs` as inputs and returns the outputs of the GAN generator. discriminator_fn: A python lambda that takes `real_data`/`generated data` - and `generator_inputs`. Outputs a tuple consisting of two Tensors: - (1) real/fake logits in the range [-inf, inf] - (2) classification logits in the range [-inf, inf] + and `generator_inputs`. Outputs a tuple consisting of two Tensors: (1) + real/fake logits in the range [-inf, inf] (2) classification logits in + the range [-inf, inf] real_data: A Tensor representing the real data. generator_inputs: A Tensor or list of Tensors to the generator. In the vanilla GAN case, this might be a single noise Tensor. 
In the conditional @@ -296,13 +297,14 @@ def acgan_model( generator_variables = variables_lib.get_trainable_variables(gen_scope) discriminator_variables = variables_lib.get_trainable_variables(dis_scope) - return namedtuples.ACGANModel( - generator_inputs, generated_data, generator_variables, gen_scope, - generator_fn, real_data, discriminator_real_outputs, - discriminator_gen_outputs, discriminator_variables, dis_scope, - discriminator_fn, one_hot_labels, - discriminator_real_classification_logits, - discriminator_gen_classification_logits) + return namedtuples.ACGANModel(generator_inputs, generated_data, + generator_variables, gen_scope, generator_fn, + real_data, discriminator_real_outputs, + discriminator_gen_outputs, + discriminator_variables, dis_scope, + discriminator_fn, one_hot_labels, + discriminator_real_classification_logits, + discriminator_gen_classification_logits) def cyclegan_model( @@ -538,8 +540,8 @@ def _tensor_pool_adjusted_model(model, tensor_pool_fn): generator_inputs=pooled_generator_inputs, generated_data=pooled_generated_data, discriminator_gen_outputs=pooled_discriminator_gen_outputs, - discriminator_gen_classification_logits= - pooled_discriminator_gen_classification_logits) + discriminator_gen_classification_logits=pooled_discriminator_gen_classification_logits # pylint: disable=line-too-long + ) elif isinstance(model, namedtuples.InfoGANModel): pooled_generator_inputs, pooled_generated_data, pooled_structured_input = ( tensor_pool_fn((model.generator_inputs, model.generated_data, @@ -598,7 +600,7 @@ def gan_loss( mutual_information_penalty_weight: If not `None`, must be a non-negative Python number or Tensor indicating how much to weight the mutual information penalty. See https://arxiv.org/abs/1606.03657 for more - details. + details. aux_cond_generator_weight: If not None: add a classification loss as in https://arxiv.org/abs/1610.09585 aux_cond_discriminator_weight: If not None: add a classification loss as in @@ -730,8 +732,8 @@ def cyclegan_loss( """ # Sanity checks. if not isinstance(model, namedtuples.CycleGANModel): - raise ValueError( - '`model` must be a `CycleGANModel`. Instead, was %s.' % type(model)) + raise ValueError('`model` must be a `CycleGANModel`. Instead, was %s.' % + type(model)) # Defines cycle consistency loss. cycle_consistency_loss = cycle_consistency_loss_fn( @@ -757,6 +759,7 @@ def cyclegan_loss( return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) + # Begin google-internal # The four major parts can be found here: http://screen/tMRMBAohDYG. # End google-internal @@ -786,7 +789,7 @@ def stargan_loss( `StarGANModel` namedtuple. gradient_penalty_weight: (float) Gradient penalty weight. Default to 10 per the original paper https://arxiv.org/abs/1711.09020. Set to 0 or None to - turn off gradient penalty. + turn off gradient penalty. gradient_penalty_epsilon: (float) A small positive number added for numerical stability when computing the gradient norm. gradient_penalty_target: (float, or tf.float `Tensor`) The target value of @@ -944,9 +947,8 @@ def gan_train_ops( update ops outside of the generator or discriminator scopes. is_chief: Specifies whether or not the training is being run by the primary replica during replica training. - **kwargs: Keyword args to pass directly to - `training.create_train_op` for both the generator and - discriminator train op. + **kwargs: Keyword args to pass directly to `training.create_train_op` for + both the generator and discriminator train op. 
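Editor's note: the `train.py` hunks above re-indent the model constructors and docstrings without changing the training API. A hedged end-to-end sketch of that API, with intentionally tiny, invented networks (a real setup would use proper generator and discriminator architectures and an input pipeline):

```python
# Sketch only: TF 1.x; tfgan is tf.contrib.gan, networks and sizes are invented.
import tensorflow as tf

tfgan = tf.contrib.gan


def generator_fn(noise):
  w = tf.compat.v1.get_variable('w', shape=[8, 32])
  return tf.matmul(noise, w)


def discriminator_fn(data, unused_generator_inputs):
  w = tf.compat.v1.get_variable('w', shape=[32, 1])
  return tf.matmul(data, w)


real_data = tf.compat.v1.random_normal([16, 32])
noise = tf.compat.v1.random_normal([16, 8])

model = tfgan.gan_model(generator_fn, discriminator_fn, real_data, noise)
loss = tfgan.gan_loss(model)  # default TF-GAN generator/discriminator losses

train_ops = tfgan.gan_train_ops(
    model,
    loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(1e-4, 0.5))

# `gan_train` alternates generator and discriminator steps via the sequential
# hooks documented above (GANTrainSteps(1, 1) by default).
tfgan.gan_train(
    train_ops,
    logdir='/tmp/tfgan_sketch',  # hypothetical log directory
    hooks=[tf.compat.v1.train.StopAtStepHook(num_steps=1)])
```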
Returns: A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can @@ -1065,8 +1067,8 @@ def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): """Returns a hooks function for sequential GAN training. Args: - train_steps: A `GANTrainSteps` tuple that determines how many generator - and discriminator training steps to take. + train_steps: A `GANTrainSteps` tuple that determines how many generator and + discriminator training steps to take. Returns: A function that takes a GANTrainOps tuple and returns a list of hooks. @@ -1106,7 +1108,8 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates for the generator and discriminator simultaneously whenever possible. This - reduces the number of `tf.Session` calls, and can also change the training + reduces the number of `tf.compat.v1.Session` calls, and can also change the + training semantics. To illustrate the difference look at the following example: @@ -1121,8 +1124,8 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): 2) 2 discriminator steps Args: - train_steps: A `GANTrainSteps` tuple that determines how many generator - and discriminator training steps to take. + train_steps: A `GANTrainSteps` tuple that determines how many generator and + discriminator training steps to take. Returns: A function that takes a GANTrainOps tuple and returns a list of hooks. @@ -1165,11 +1168,11 @@ def gan_train(train_ops, master: The URL of the master. is_chief: Specifies whether or not the training is being run by the primary replica during replica training. - scaffold: An tf.train.Scaffold instance. - hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the - training loop. - chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run - inside the training loop for the chief trainer only. + scaffold: An tf.compat.v1.train.Scaffold instance. + hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside + the training loop. + chief_only_hooks: List of `tf.estimator.SessionRunHook` instances which are + run inside the training loop for the chief trainer only. save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved using a default checkpoint saver. If `save_checkpoint_secs` is set to `None`, then the default checkpoint saver isn't used. @@ -1177,7 +1180,7 @@ def gan_train(train_ops, summaries are written to disk using a default summary saver. If `save_summaries_steps` is set to `None`, then the default summary saver isn't used. - config: An instance of `tf.ConfigProto`. + config: An instance of `tf.compat.v1.ConfigProto`. Returns: Output of the call to `training.train`. @@ -1207,8 +1210,8 @@ def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)): use `MonitoredSession` and `get_sequential_train_hooks`. Args: - train_steps: A `GANTrainSteps` tuple that determines how many generator - and discriminator training steps to take. + train_steps: A `GANTrainSteps` tuple that determines how many generator and + discriminator training steps to take. Returns: A function that can be used for `train_step_fn` for GANs. @@ -1238,8 +1241,9 @@ def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)): # Run generator training steps. 
gen_loss = 0 for _ in range(train_steps.generator_train_steps): - cur_gen_loss, _ = slim_learning.train_step( - sess, train_ops.generator_train_op, global_step, train_kwargs) + cur_gen_loss, _ = slim_learning.train_step(sess, + train_ops.generator_train_op, + global_step, train_kwargs) gen_loss += cur_gen_loss # Run discriminator training steps. @@ -1306,7 +1310,9 @@ def _generate_stargan_random_domain_target(batch_size, num_domains): Returns: Tensor of shape (batch_size, num_domains) representing random label. """ - domain_idx = random_ops.random_uniform( - [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32) + domain_idx = random_ops.random_uniform([batch_size], + minval=0, + maxval=num_domains, + dtype=dtypes.int32) return array_ops.one_hot(domain_idx, num_domains) diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 7321e973191..9b8e832fd96 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -250,10 +250,9 @@ Status GdrMemoryManager::Init() { LOG(INFO) << "Instrumenting CPU allocator(s)"; for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { - GPUProcessState::singleton()->AddCUDAHostAllocVisitor(numa_idx, - alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostFreeVisitor(numa_idx, - free_visitor); + GPUProcessState::singleton()->AddGpuHostAllocVisitor(numa_idx, + alloc_visitor); + GPUProcessState::singleton()->AddGpuHostFreeVisitor(numa_idx, free_visitor); } if (IsGDRAvailable()) { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 1204b8ca501..9dfca258a6f 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -128,7 +128,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, StatusCallback copy_ready = [response, done, copy, is_dead](const Status& s) { // The value is now ready to be returned on the wire. 
- grpc::EncodeTensorToByteBuffer(is_dead, *copy, response); + grpc::EncodeTensorToByteBuffer(is_dead, *copy, false, response); done(s); delete copy; }; @@ -136,7 +136,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, send_dev_context->CopyDeviceTensorToCPU( &val, request->rendezvous_key(), src_dev, copy, copy_ready); } else { - grpc::EncodeTensorToByteBuffer(is_dead, val, response); + grpc::EncodeTensorToByteBuffer(is_dead, val, false, response); done(Status::OK()); } } diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index 1711100e3a8..35b6e638763 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -53,6 +53,7 @@ py_library( py_test( name = "util_test", srcs = ["tests/util_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_editor_py", @@ -65,6 +66,7 @@ py_test( py_test( name = "select_test", srcs = ["tests/select_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_editor_py", @@ -77,6 +79,7 @@ py_test( py_test( name = "match_test", srcs = ["tests/match_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":match", @@ -89,6 +92,7 @@ py_test( py_test( name = "subgraph_test", srcs = ["tests/subgraph_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_editor_py", @@ -101,6 +105,7 @@ py_test( py_test( name = "reroute_test", srcs = ["tests/reroute_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_editor_py", @@ -114,6 +119,7 @@ py_test( py_test( name = "edit_test", srcs = ["tests/edit_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_editor_py", @@ -127,6 +133,7 @@ py_test( py_test( name = "transform_test", srcs = ["tests/transform_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_editor_py", diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 5b37239665d..0a0c476dd1e 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -86,7 +86,7 @@ def assign_renamed_collections_handler(info, elem, elem_): """Add the transformed elem to the (renamed) collections of elem. A collection is renamed only if is not a known key, as described in - `tf.GraphKeys`. + `tf.compat.v1.GraphKeys`. Args: info: Transform._TmpInfo instance. diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 584f4509ccc..4b53d182f34 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -450,18 +450,19 @@ def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): def make_placeholder_from_tensor(t, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): - """Create a `tf.placeholder` for the Graph Editor. + """Create a `tf.compat.v1.placeholder` for the Graph Editor. Note that the correct graph scope must be set by the calling function. Args: - t: a `tf.Tensor` whose name will be used to create the placeholder - (see function placeholder_name). - scope: absolute scope within which to create the placeholder. None - means that the scope of `t` is preserved. `""` means the root scope. + t: a `tf.Tensor` whose name will be used to create the placeholder (see + function placeholder_name). + scope: absolute scope within which to create the placeholder. None means + that the scope of `t` is preserved. `""` means the root scope. 
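Editor's note: the `graph_editor/util.py` hunks above only rename `tf.placeholder` to `tf.compat.v1.placeholder` in the docstrings and reflow the argument descriptions. A hedged sketch of the two helpers, with an invented tensor:

```python
# Sketch only: TF 1.x; imports the util module via the path shown in this diff.
import tensorflow as tf
from tensorflow.contrib.graph_editor import util as ge_util

x = tf.constant([[1.0, 2.0]], name='x')
y = tf.square(x, name='y')

# A placeholder whose name, dtype, and shape are derived from `y`.
ph_from_tensor = ge_util.make_placeholder_from_tensor(y)

# A placeholder built from an explicit dtype and shape instead.
ph_from_shape = ge_util.make_placeholder_from_dtype_and_shape(
    tf.float32, shape=[None, 2])
```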
prefix: placeholder name prefix. + Returns: - A newly created `tf.placeholder`. + A newly created `tf.compat.v1.placeholder`. Raises: TypeError: if `t` is not `None` or a `tf.Tensor`. """ @@ -472,7 +473,7 @@ def make_placeholder_from_tensor(t, scope=None, def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): - """Create a tf.placeholder for the Graph Editor. + """Create a tf.compat.v1.placeholder for the Graph Editor. Note that the correct graph scope must be set by the calling function. The placeholder is named using the function placeholder_name (with no @@ -481,9 +482,10 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, Args: dtype: the tensor type. shape: the tensor shape (optional). - scope: absolute scope within which to create the placeholder. None - means that the scope of t is preserved. "" means the root scope. + scope: absolute scope within which to create the placeholder. None means + that the scope of t is preserved. "" means the root scope. prefix: placeholder name prefix. + Returns: A newly created tf.placeholder. """ diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index 71eac729a8a..a5b1c2cbea6 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -44,7 +44,7 @@ class SequenceFileDataset(dataset_ops.DatasetSource): For example: ```python - tf.enable_eager_execution() + tf.compat.v1.enable_eager_execution() dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq") # Prints the (key, value) pairs inside a hadoop sequence file. diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc index f97e790b56c..a0542b399fd 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" #include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index b25a6f7b574..05ba9155c40 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -518,7 +518,7 @@ def connected_components(images): def has_zero(): # Insert a zero in the consecutive ids where zero appears in unique_ids. # id_is_zero has length 1. 
- zero_id_ind = math_ops.to_int32(id_is_zero[0]) + zero_id_ind = math_ops.cast(id_is_zero[0], dtypes.int32) ids_before = nonzero_consecutive_ids[:zero_id_ind] ids_after = nonzero_consecutive_ids[zero_id_ind:] return array_ops.concat([ids_before, [0], ids_after], axis=0) diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index 0ceb683ff4c..2b0bcf64019 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -58,7 +58,7 @@ def single_image_random_dot_stereograms(depth_values, [1,2,3,4,5,3], [1,2,3,4,5,4], [6,5,4,4,5,5]] - session = tf.InteractiveSession() + session = tf.compat.v1.InteractiveSession() sirds = single_image_random_dot_stereograms( img, convergence_dots_size=8, diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index 0e34315db45..cf786c062ea 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -2,17 +2,11 @@ # Contains ops to build an input pipeline for tensorflow. # APIs here are meant to evolve over time. -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -package(default_visibility = ["//visibility:public"]) - +load("//tensorflow:tensorflow.bzl", "py_test") load( "//tensorflow:tensorflow.bzl", - "py_test", - "tf_custom_op_library", "tf_cc_tests", + "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library", @@ -23,6 +17,12 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + tf_custom_op_library( # TODO(sibyl-Mooth6ku,ptucker): Understand why 'python/ops/_' is needed and fix it. 
name = "python/ops/_input_pipeline_ops.so", @@ -79,6 +79,7 @@ py_test( name = "input_pipeline_ops_test", size = "small", srcs = ["python/ops/input_pipeline_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":input_pipeline_py", diff --git a/tensorflow/contrib/integrate/BUILD b/tensorflow/contrib/integrate/BUILD index 0b7d64f4edd..9a2c94446fd 100644 --- a/tensorflow/contrib/integrate/BUILD +++ b/tensorflow/contrib/integrate/BUILD @@ -31,6 +31,7 @@ py_library( py_test( name = "odes_test", srcs = ["python/ops/odes_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":integrate_py", diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index 87c2dcd89b6..833771eda0f 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -47,6 +47,7 @@ py_library( py_test( name = "random_fourier_features_test", srcs = ["python/mappers/random_fourier_features_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":dense_kernel_mapper_py", @@ -63,6 +64,7 @@ py_test( py_test( name = "kernel_estimators_test", srcs = ["python/kernel_estimators_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -80,6 +82,7 @@ py_test( py_test( name = "losses_test", srcs = ["python/losses_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":kernel_methods", diff --git a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py index 1626e55b9b3..0f863c5a906 100644 --- a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py +++ b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py @@ -77,7 +77,7 @@ def _update_features_and_columns(features, feature_columns, return features, feature_columns # First construct new columns and features affected by kernel_mappers_dict. - mapped_features = dict() + mapped_features = {} mapped_columns = set() for feature_column in kernel_mappers_dict: column_name = feature_column.name diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 294a7d69a70..0d43bc2101b 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -80,7 +80,7 @@ def sparse_multiclass_hinge_loss( ' {}'.format(logits_rank)) logits_shape = array_ops.shape(logits) batch_size, num_classes = logits_shape[0], logits_shape[1] - logits = math_ops.to_float(logits) + logits = math_ops.cast(logits, dtypes.float32) # Check labels have valid type. if labels.dtype != dtypes.int32 and labels.dtype != dtypes.int64: diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py index af7018f8368..e20ed4e1cac 100644 --- a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py +++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py @@ -110,7 +110,7 @@ class KinesisDatasetTest(test.TestCase): init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() - data = list() + data = [] with self.cached_session() as sess: # Basic test: read from shard 0 of stream 2. 
sess.run( diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 9479afb180d..e3918b91d1e 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -35,7 +35,7 @@ class KinesisDataset(dataset_ops.DatasetSource): For example, we can construct and use the KinesisDataset as follows: ```python - tf.enable_eager_execution() + tf.compat.v1.enable_eager_execution() dataset = tf.contrib.kinesis.KinesisDataset( "kinesis_stream_name", read_indefinitely=False) diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 7e19ae7c13d..fb28d6689a6 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -69,6 +69,7 @@ py_test( srcs = [ "python/ops/core_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_windows", # TODO: needs investigation on Windows @@ -106,6 +107,7 @@ py_test( srcs = [ "python/ops/io_ops_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":core", @@ -136,6 +138,7 @@ py_test( srcs = [ "python/ops/nn_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":core", @@ -171,6 +174,7 @@ py_test( srcs = [ "python/ops/ops_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":core", @@ -205,6 +209,7 @@ py_test( srcs = [ "python/ops/sugar_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":core", diff --git a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py index 80fa17ec1f7..1783a07fac9 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Minimal runtime type checking library. This module should not be considered public API. 
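Editor's note: the `SequenceFileDataset` and `KinesisDataset` docstrings above now spell out `tf.compat.v1.enable_eager_execution()`. The iteration pattern those docstrings describe looks roughly like the following (the sequence-file path is the docstring's own example, not a real file):

```python
# Sketch only: TF 1.x with tf.contrib; eager execution must be enabled at startup.
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq")
for key, value in dataset:
  # Each element is a (key, value) pair read from the Hadoop SequenceFile.
  print(key, value)
```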
@@ -54,7 +53,7 @@ class Type(object): def __repr__(self): args_repr = ", ".join(repr(t) for t in self._types) - return "typecheck.%s(%s)" % (type(self).__name__, args_repr) + return "typecheck.%s(%s)" % (type(self).__name__, args_repr) class _SingleArgumentType(Type): @@ -104,8 +103,8 @@ class List(_SingleArgumentType): """ def __instancecheck__(self, instance): - return (isinstance(instance, list) - and all(isinstance(x, self._type) for x in instance)) + return (isinstance(instance, list) and + all(isinstance(x, self._type) for x in instance)) class Sequence(_SingleArgumentType): @@ -115,8 +114,8 @@ class Sequence(_SingleArgumentType): """ def __instancecheck__(self, instance): - return (isinstance(instance, collections.Sequence) - and all(isinstance(x, self._type) for x in instance)) + return (isinstance(instance, collections.Sequence) and + all(isinstance(x, self._type) for x in instance)) class Collection(_SingleArgumentType): @@ -131,10 +130,10 @@ class Collection(_SingleArgumentType): """ def __instancecheck__(self, instance): - return (isinstance(instance, collections.Iterable) - and isinstance(instance, collections.Sized) - and isinstance(instance, collections.Container) - and all(isinstance(x, self._type) for x in instance)) + return (isinstance(instance, collections.Iterable) and + isinstance(instance, collections.Sized) and + isinstance(instance, collections.Container) and + all(isinstance(x, self._type) for x in instance)) class Tuple(Type): @@ -145,9 +144,9 @@ class Tuple(Type): """ def __instancecheck__(self, instance): - return (isinstance(instance, tuple) - and len(instance) == len(self._types) - and all(isinstance(x, t) for x, t in zip(instance, self._types))) + return (isinstance(instance, tuple) and + len(instance) == len(self._types) and + all(isinstance(x, t) for x, t in zip(instance, self._types))) class Mapping(_TwoArgumentType): @@ -158,9 +157,9 @@ class Mapping(_TwoArgumentType): def __instancecheck__(self, instance): key_type, value_type = self._types # pylint: disable=unbalanced-tuple-unpacking - return (isinstance(instance, collections.Mapping) - and all(isinstance(k, key_type) for k in instance.keys()) - and all(isinstance(k, value_type) for k in instance.values())) + return (isinstance(instance, collections.Mapping) and + all(isinstance(k, key_type) for k in instance.keys()) and + all(isinstance(k, value_type) for k in instance.values())) class Dict(Mapping): @@ -170,8 +169,8 @@ class Dict(Mapping): """ def __instancecheck__(self, instance): - return (isinstance(instance, dict) - and super(Dict, self).__instancecheck__(instance)) + return (isinstance(instance, dict) and + super(Dict, self).__instancecheck__(instance)) def _replace_forward_references(t, context): @@ -190,7 +189,8 @@ def register_type_abbreviation(name, alias): This makes otherwise very long typecheck errors much more readable. Example: - typecheck.register_type_abbreviation(tf.Dimension, 'tf.Dimension') + typecheck.register_type_abbreviation(tf.compat.v1.Dimension, + 'tf.compat.v1.Dimension') Args: name: type or class to abbreviate. 
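Editor's note: the `_typecheck.py` hunks above are pure reformatting of the instance checks and error messages. The module is explicitly not public API, so the sketch below only mirrors the decorator pattern `labeled_tensor` itself uses, on an invented function:

```python
# Sketch only: mirrors labeled_tensor's internal use of its _typecheck module,
# which is not public API; the function below is invented for illustration.
from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc


@tc.returns(tc.List(int))
@tc.accepts(tc.Collection(int), int)
def scale_all(values, factor=2):
  # tc.Error is raised if the arguments or the return value have the wrong types.
  return [v * factor for v in values]


print(scale_all([1, 2, 3]))      # [2, 4, 6]
print(scale_all((1, 2, 3), 10))  # [10, 20, 30]
```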
@@ -240,14 +240,13 @@ def accepts(*types): if spec.defaults: num_defaults = len(spec.defaults) - for (name, a, t) in zip(spec.args[-num_defaults:], - spec.defaults, + for (name, a, t) in zip(spec.args[-num_defaults:], spec.defaults, types[-num_defaults:]): allowed_type = _replace_forward_references(t, f.__globals__) if not isinstance(a, allowed_type): raise Error("default argument value %r of type %r is not an instance " - "of the allowed type %s for the %s argument to %r" - % (a, type(a), _type_repr(allowed_type), name, f)) + "of the allowed type %s for the %s argument to %r" % + (a, type(a), _type_repr(allowed_type), name, f)) @functools.wraps(f) def new_f(*args, **kwds): @@ -273,11 +272,10 @@ def returns(*types): https://www.python.org/dev/peps/pep-0318/ Args: - *types: A list of Python types. - A list of one element corresponds to a single return value. - A list of several elements corresponds to several return values. - Note that a function with no explicit return value has an implicit - NoneType return and should be annotated correspondingly. + *types: A list of Python types. A list of one element corresponds to a + single return value. A list of several elements corresponds to several + return values. Note that a function with no explicit return value has an + implicit NoneType return and should be annotated correspondingly. Returns: A function to use as a decorator. @@ -297,17 +295,16 @@ def returns(*types): # The function has a single return value. allowed_type = _replace_forward_references(types[0], f.__globals__) if not isinstance(return_value, allowed_type): - raise Error("%r of type %r is not an instance of the allowed type %s " - "for %r" - % (return_value, type(return_value), - _type_repr(allowed_type), f)) + raise Error( + "%r of type %r is not an instance of the allowed type %s " + "for %r" % + (return_value, type(return_value), _type_repr(allowed_type), f)) else: if len(return_value) != len(types): - raise Error( - "Function %r has %d return values but only %d types were " - "provided in the annotation." % - (f, len(return_value), len(types))) + raise Error("Function %r has %d return values but only %d types were " + "provided in the annotation." % + (f, len(return_value), len(types))) for (r, t) in zip(return_value, types): allowed_type = _replace_forward_references(t, f.__globals__) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py index 8ee554ffa7a..b0961e5b3a2 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py @@ -48,7 +48,7 @@ from tensorflow.python.ops import math_ops # We use this instead of collections.Sequence to exclude strings. LabelsLike = tc.Union(np.ndarray, range, list, tuple) -# Types coercible to a tf.Dimension +# Types coercible to a tf.compat.v1.Dimension DimensionLike = tc.Optional(tc.Union(tensor_shape.Dimension, int)) # Types usable for axis values @@ -63,7 +63,7 @@ Scalar = tc.Union(numbers.Number, bool, binary_type, text_type) class Axis(object): """Size and label information for an axis. - Axis contains either a tf.Dimension indicating the size of an axis, + Axis contains either a tf.compat.v1.Dimension indicating the size of an axis, or a tuple of tick labels for the axis. If tick labels are provided, they must be unique. @@ -75,9 +75,9 @@ class Axis(object): Args: name: Name of the axis. 
- value: Either None, an int or tf.Dimension giving the size of the axis, - or a sequence that is not a string additionally providing coordinate - (tick) labels. + value: Either None, an int or tf.compat.v1.Dimension giving the size of + the axis, or a sequence that is not a string additionally providing + coordinate (tick) labels. Raises: ValueError: If the user provides labels with duplicate values. @@ -99,8 +99,8 @@ class Axis(object): if labels is not None: index = dict(zip(labels, range(len(labels)))) if len(index) != len(labels): - raise ValueError('Tick labels must be unique, but got {}' - .format(labels)) + raise ValueError( + 'Tick labels must be unique, but got {}'.format(labels)) else: index = None @@ -152,7 +152,7 @@ class Axis(object): @property @tc.returns(tc.Union(tuple, tensor_shape.Dimension)) def value(self): - """Returns the tf.Dimension or tuple specifying axis ticks.""" + """Returns the tf.compat.v1.Dimension or tuple specifying axis ticks.""" if self.labels is None: return self.dimension else: @@ -313,8 +313,9 @@ class LabeledTensor(object): # First, the rank of the tensor must be equal to the number of axes. if len(shape) != len(unvalidated_axes): - raise ValueError('Tensor rank was not equal to the number of axes: %r, %r' - % (shape, unvalidated_axes)) + raise ValueError( + 'Tensor rank was not equal to the number of axes: %r, %r' % + (shape, unvalidated_axes)) # Second, the size of each tensor dimension must match the size of the # corresponding indices. @@ -608,16 +609,14 @@ def identity(labeled_tensor, name=None): with ops.name_scope(name, 'lt_identity', [labeled_tensor]) as scope: labeled_tensor = convert_to_labeled_tensor(labeled_tensor) return LabeledTensor( - array_ops.identity( - labeled_tensor.tensor, name=scope), + array_ops.identity(labeled_tensor.tensor, name=scope), labeled_tensor.axes) # We don't call this slice because that shadows a built-in. Instead, we alias # this to lt.slice in __init__.py. @tc.returns(LabeledTensor) -@tc.accepts(LabeledTensorLike, - tc.Mapping(string_types, tc.Union(int, slice)), +@tc.accepts(LabeledTensorLike, tc.Mapping(string_types, tc.Union(int, slice)), tc.Optional(string_types)) def slice_function(labeled_tensor, selection, name=None): """Slice out a subset of the tensor. @@ -632,8 +631,8 @@ def slice_function(labeled_tensor, selection, name=None): Args: labeled_tensor: The input tensor. - selection: A dictionary of type str -> Union(int, slice of int) mapping - axis names to sub-selections. + selection: A dictionary of type str -> Union(int, slice of int) mapping axis + names to sub-selections. name: Optional op name. Returns: @@ -669,13 +668,12 @@ def slice_function(labeled_tensor, selection, name=None): assert isinstance(s, int) return LabeledTensor( - array_ops.identity( - sliced_tensor, name=scope), sliced_axes) + array_ops.identity(sliced_tensor, name=scope), sliced_axes) @tc.returns(LabeledTensor) -@tc.accepts(LabeledTensorLike, - tc.Optional(tc.Collection(string_types)), tc.Optional(string_types)) +@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)), + tc.Optional(string_types)) def transpose(labeled_tensor, axis_order=None, name=None): """Permute a tensor's axes. 
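Since several of the reformatted signatures above (`slice_function`, `transpose`, `expand_dims`) are easiest to read next to a call site, here is a hedged sketch of the public `lt` API they back. It assumes a TF 1.x build where `tf.contrib.labeled_tensor` is importable; the axis and label names are made up for illustration:

```python
import tensorflow as tf
from tensorflow.contrib import labeled_tensor as lt

# A rank-2 tensor with named axes: 'batch' (size 2) and 'channel' (labeled ticks).
image = lt.LabeledTensor(tf.zeros([2, 3]),
                         ['batch', ('channel', ['r', 'g', 'b'])])

red = lt.slice(image, {'channel': 0})                # select index 0 along 'channel'
flipped = lt.transpose(image, ['channel', 'batch'])  # permute axes by name
```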
@@ -718,11 +716,11 @@ def transpose(labeled_tensor, axis_order=None, name=None): @tc.returns(LabeledTensor) -@tc.accepts( - LabeledTensorLike, - tc.Collection( - tc.Union(string_types, tc.Tuple(string_types, collections.Hashable))), - tc.Optional(string_types)) +@tc.accepts(LabeledTensorLike, + tc.Collection( + tc.Union(string_types, + tc.Tuple(string_types, collections.Hashable))), + tc.Optional(string_types)) def expand_dims(labeled_tensor, axes, name=None): """Insert dimensions of size 1. @@ -730,10 +728,10 @@ def expand_dims(labeled_tensor, axes, name=None): Args: labeled_tensor: The input tensor. - axes: The desired axis names as strings or tuples of (name, label), - where `label` is the coordinate name for the new dimension `name`. - These must include the existing axis names, and the existing names must - appear in the same order in this list as they do in the input tensor. + axes: The desired axis names as strings or tuples of (name, label), where + `label` is the coordinate name for the new dimension `name`. These must + include the existing axis names, and the existing names must appear in the + same order in this list as they do in the input tensor. name: Optional op name. Returns: @@ -886,8 +884,8 @@ def check_axis_order(labeled_tensor, axis_order=None): @tc.returns(LabeledTensor) -@tc.accepts(LabeledTensorLike, - tc.Optional(tc.Collection(string_types)), tc.Optional(string_types)) +@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)), + tc.Optional(string_types)) def impose_axis_order(labeled_tensor, axis_order=None, name=None): """Impose desired axis order on a labeled tensor. @@ -1065,7 +1063,7 @@ def define_unary_op(op_name, elementwise_function): op_name: string name of the TensorFlow op. elementwise_function: function to call to evaluate the op on a single tf.Tensor object. This function must accept two arguments: a tf.Tensor - object, and an optional `name`. + object, and an optional `name`. Returns: Function defining the given op that acts on LabeledTensors. @@ -1134,7 +1132,7 @@ def define_binary_op(op_name, elementwise_function): op_name: string name of the TensorFlow op. elementwise_function: function to call to evaluate the op on tf.Tensor objects. This function must accept three arguments: two tf.Tensor objects, - and an optional `name`. + and an optional `name`. Returns: Function defining the given op that acts on LabeledTensors. diff --git a/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py b/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py index 3bb9c21c2e3..3cda9c82788 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/io_ops.py @@ -159,7 +159,7 @@ def placeholder(dtype, axes, name=None): lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])]) - See tf.placeholder for more details. + See tf.compat.v1.placeholder for more details. Args: dtype: The type of elements in the tensor to be fed. 
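The docstring now points at `tf.compat.v1.placeholder`, so for completeness here is the plain graph-mode pattern it refers to (lt.placeholder wraps such a placeholder in a LabeledTensor). A minimal sketch assuming a TF 1.x-style runtime with eager execution disabled; not part of the patch:

```python
import numpy as np
import tensorflow as tf

x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name='rgb')
total = tf.reduce_sum(x, axis=1)

with tf.compat.v1.Session() as sess:
  # Values are supplied at run time through feed_dict.
  print(sess.run(total, feed_dict={x: np.ones((2, 3), dtype=np.float32)}))
```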
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index a65f045cc88..a04e3772799 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -1244,7 +1244,7 @@ def boolean_mask(labeled_tensor, mask, name=None): 'are not equal:\n%r\n%r' % (lt_axis, mask_axis)) op = array_ops.boolean_mask(labeled_tensor.tensor, mask.tensor, name=scope) # TODO(shoyer): attempt to infer labels for the masked values, by calling - # tf.contrib.util.constant_value on the mask? + # tf.get_static_value on the mask? axes = [lt_axis.name] + list(labeled_tensor.axes.values())[1:] return core.LabeledTensor(op, axes) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 69d5496f8ae..c6f6e722a4f 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -162,6 +162,7 @@ py_test( name = "regularizers_test", size = "small", srcs = ["python/layers/regularizers_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -178,6 +179,7 @@ py_test( name = "initializers_test", size = "small", srcs = ["python/layers/initializers_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -194,6 +196,7 @@ py_test( name = "normalization_test", size = "medium", srcs = ["python/layers/normalization_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ @@ -211,6 +214,7 @@ py_test( py_test( name = "optimizers_test", srcs = ["python/layers/optimizers_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -232,6 +236,7 @@ py_test( name = "summaries_test", size = "small", srcs = ["python/layers/summaries_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -247,6 +252,7 @@ py_test( name = "feature_column_test", size = "small", srcs = ["python/layers/feature_column_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -268,6 +274,7 @@ py_test( name = "feature_column_ops_test", size = "medium", srcs = ["python/layers/feature_column_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -292,6 +299,7 @@ py_test( name = "target_column_test", size = "small", srcs = ["python/layers/target_column_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -306,6 +314,7 @@ py_test( name = "sparse_feature_cross_op_test", size = "medium", srcs = ["python/kernel_tests/sparse_feature_cross_op_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -323,6 +332,7 @@ py_test( size = "small", timeout = "moderate", srcs = ["python/layers/embedding_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -345,6 +355,7 @@ py_test( name = "utils_test", size = "small", srcs = ["python/layers/utils_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", @@ -360,6 +371,7 @@ py_test( name = "sparse_ops_test", size = "small", srcs = ["python/ops/sparse_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ @@ -376,6 +388,7 @@ py_test( name = "encoders_test", size = "small", srcs = ["python/layers/encoders_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", 
@@ -390,6 +403,7 @@ py_test( name = "rev_block_lib_test", size = "medium", srcs = ["python/layers/rev_block_lib_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers_py", diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 429d696daf0..14bbe5f9b30 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -58,7 +58,8 @@ def safe_embedding_lookup_sparse(embedding_weights, The partitioned embedding in `embedding_weights` must all be the same shape except for the first dimension. The first dimension is allowed to vary as the vocabulary size is not necessarily a multiple of `P`. `embedding_weights` - may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a + may be a `PartitionedVariable` as returned by using + `tf.compat.v1.get_variable()` with a partitioner. Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs @@ -70,25 +71,24 @@ def safe_embedding_lookup_sparse(embedding_weights, Args: embedding_weights: A list of `P` float tensors or values representing - partitioned embedding tensors. Alternatively, a `PartitionedVariable`, - created by partitioning along dimension 0. The total unpartitioned - shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the - vocab size and `e_1, ..., e_m` are the embedding dimensions. + partitioned embedding tensors. Alternatively, a `PartitionedVariable`, + created by partitioning along dimension 0. The total unpartitioned shape + should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the vocab size + and `e_1, ..., e_m` are the embedding dimensions. sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the - ids. `d_0` is typically batch size. + ids. `d_0` is typically batch size. sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing - float weights corresponding to `sparse_ids`, or `None` if all weights - are be assumed to be 1.0. + float weights corresponding to `sparse_ids`, or `None` if all weights are + be assumed to be 1.0. combiner: A string specifying how to combine embedding results for each - entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" - the default. + entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the + default. default_id: The id to use for an entry with no features. name: A name for this operation (optional). - partition_strategy: A string specifying the partitioning strategy. - Currently `"div"` and `"mod"` are supported. Default is `"div"`. + partition_strategy: A string specifying the partitioning strategy. Currently + `"div"` and `"mod"` are supported. Default is `"div"`. max_norm: If not None, all embeddings are l2-normalized to max_norm before - combining. - + combining. Returns: Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`. @@ -119,25 +119,24 @@ def safe_embedding_lookup_sparse(embedding_weights, contrib_tensor_util.assert_same_float_dtype(embedding_weights + [sparse_weights]) - with ops.name_scope(name, "embedding_lookup", - embedding_weights + [sparse_ids, - sparse_weights]) as scope: + with ops.name_scope(name, "embedding_lookup", embedding_weights + + [sparse_ids, sparse_weights]) as scope: # Reshape higher-rank sparse ids and weights to linear segment ids. 
original_shape = sparse_ids.dense_shape - original_rank_dim = tensor_shape.Dimension(tensor_shape.dimension_value( - sparse_ids.dense_shape.get_shape()[0])) + original_rank_dim = tensor_shape.Dimension( + tensor_shape.dimension_value(sparse_ids.dense_shape.get_shape()[0])) original_rank = ( array_ops.size(original_shape) - if original_rank_dim.value is None - else original_rank_dim.value) + if original_rank_dim.value is None else original_rank_dim.value) sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [ math_ops.reduce_prod( array_ops.slice(original_shape, [0], [original_rank - 1])), - array_ops.gather(original_shape, original_rank - 1)]) + array_ops.gather(original_shape, original_rank - 1) + ]) if sparse_weights is not None: - sparse_weights = sparse_tensor.SparseTensor( - sparse_ids.indices, - sparse_weights.values, sparse_ids.dense_shape) + sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices, + sparse_weights.values, + sparse_ids.dense_shape) # Prune invalid ids and weights. sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) @@ -146,9 +145,8 @@ def safe_embedding_lookup_sparse(embedding_weights, sparse_ids, sparse_weights) # Fill in dummy values for empty features, if necessary. - sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids, - default_id or - 0) + sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows( + sparse_ids, default_id or 0) if sparse_weights is not None: sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0) @@ -168,10 +166,8 @@ def safe_embedding_lookup_sparse(embedding_weights, array_ops.reshape(is_row_empty, [-1, 1]), array_ops.stack([1, array_ops.shape(result)[1]])) - result = array_ops.where(is_row_empty, - array_ops.zeros_like(result), - result, - name=scope) + result = array_ops.where( + is_row_empty, array_ops.zeros_like(result), result, name=scope) # Reshape back from linear ids back into higher-dimensional dense result. final_result = array_ops.reshape( @@ -182,8 +178,9 @@ def safe_embedding_lookup_sparse(embedding_weights, [original_rank - 1]), array_ops.slice(array_ops.shape(result), [1], [-1]) ], 0)) - final_result.set_shape(tensor_shape.unknown_shape( - (original_rank_dim - 1).value).concatenate(result.get_shape()[1:])) + final_result.set_shape( + tensor_shape.unknown_shape( + (original_rank_dim - 1).value).concatenate(result.get_shape()[1:])) return final_result @@ -238,8 +235,8 @@ def scattered_embedding_lookup(params, partitioned in 4 tensors with length `[3, 3, 2, 2]`. Args: - params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`. - Each tensor must be of rank 1 with fully-defined shape. + params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`. Each + tensor must be of rank 1 with fully-defined shape. values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`. dimension: Embedding dimension. name: An optional name for this op. 
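The `safe_embedding_lookup_sparse` docstring above describes combiners and weights in the abstract; the hedged sketch below shows the underlying core call, `tf.nn.embedding_lookup_sparse`, on a tiny batch. It assumes TF 1.x graph mode, and it does not reproduce the extra work the safe_ wrapper does (pruning invalid ids and filling empty rows):

```python
import numpy as np
import tensorflow as tf

# vocab_size=5, embedding dimension=4
embeddings = tf.constant(np.arange(20, dtype=np.float32).reshape(5, 4))

# Row 0 carries ids [1, 3] with weights [2.0, 1.0]; row 1 carries id [4].
sp_ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                         values=np.array([1, 3, 4], dtype=np.int64),
                         dense_shape=[2, 2])
sp_weights = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                             values=np.array([2.0, 1.0, 1.0], dtype=np.float32),
                             dense_shape=[2, 2])

combined = tf.nn.embedding_lookup_sparse(embeddings, sp_ids, sp_weights,
                                         combiner='mean')  # weighted mean per row

with tf.compat.v1.Session() as sess:
  print(sess.run(combined))  # shape [2, 4]
```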
@@ -256,13 +253,20 @@ def scattered_embedding_lookup(params, if dimension is None: raise ValueError("You must specify dimension.") return _sampled_scattered_embedding_lookup( - params, values, dimension=dimension, sampled_candidates=None, - hash_key=hash_key, name=name) + params, + values, + dimension=dimension, + sampled_candidates=None, + hash_key=hash_key, + name=name) -def _sampled_scattered_embedding_lookup( - params, values, dimension=None, sampled_candidates=None, hash_key=None, - name=None): +def _sampled_scattered_embedding_lookup(params, + values, + dimension=None, + sampled_candidates=None, + hash_key=None, + name=None): """Looks up embeddings using parameter hashing for each value in `values`. This method looks up selected embedding dimensions if `sampled_candidates` is @@ -290,8 +294,8 @@ def _sampled_scattered_embedding_lookup( partitioned in 4 tensors with length `[3, 3, 2, 2]`. Args: - params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`. - Each tensor must be of rank 1 with fully-defined shape. + params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`. Each + tensor must be of rank 1 with fully-defined shape. values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`. dimension: Embedding dimension. The user must specify either `dimension` or `sampled_candidates`. @@ -327,19 +331,27 @@ def _sampled_scattered_embedding_lookup( "You must specify either dimension or sampled_candidates.") if dimension <= 0: raise ValueError("Dimension must be >0. Given is %d" % dimension) - sampled_candidates = array_ops.tile(array_ops.expand_dims( - math_ops.range(0, dimension), 0), array_ops.shape(values)) + sampled_candidates = array_ops.tile( + array_ops.expand_dims(math_ops.range(0, dimension), 0), + array_ops.shape(values)) else: - dimension = array_ops.shape(sampled_candidates)[ - math_ops.subtract(array_ops.rank(sampled_candidates), 1)] + dimension = array_ops.shape(sampled_candidates)[math_ops.subtract( + array_ops.rank(sampled_candidates), 1)] sampled_candidates_shape = array_ops.shape(sampled_candidates) - dimension_tensor = array_ops.reshape(dimension, shape=[1,]) + dimension_tensor = array_ops.reshape( + dimension, shape=[ + 1, + ]) expected_shape = array_ops.concat([values_shape, dimension_tensor], 0) - with ops.control_dependencies([control_flow_ops.Assert( - math_ops.reduce_all(math_ops.equal(sampled_candidates_shape, - expected_shape)), - ["The shape of sampled_candidates: ", sampled_candidates_shape, - " does not match the shape of values: ", values_shape])]): + with ops.control_dependencies([ + control_flow_ops.Assert( + math_ops.reduce_all( + math_ops.equal(sampled_candidates_shape, expected_shape)), + [ + "The shape of sampled_candidates: ", sampled_candidates_shape, + " does not match the shape of values: ", values_shape + ]) + ]): # Flatten sampled_candidates, same way as values are flattened. sampled_candidates = array_ops.reshape(sampled_candidates, [-1, dimension]) @@ -364,7 +376,9 @@ def _sampled_scattered_embedding_lookup( # [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]]. tensors_to_cross = [sampled_candidates, values] ids = sparse_feature_cross_op.sparse_feature_cross( - tensors_to_cross, hashed_output=True, num_buckets=num_params, + tensors_to_cross, + hashed_output=True, + num_buckets=num_params, hash_key=hash_key) ids = sparse_ops.sparse_tensor_to_dense(ids) @@ -389,14 +403,14 @@ def scattered_embedding_lookup_sparse(params, See `tf.contrib.layers.scattered_embedding_lookup` for embedding with hashing. 
Args: - params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`. - Each tensor must be of rank 1 with fully-defined shape. + params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`. Each + tensor must be of rank 1 with fully-defined shape. sparse_values: A 2-D `SparseTensor` containing the values to be embedded. Some rows may be empty. dimension: Embedding dimension combiner: A string specifying how to combine embedding results for each - entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" - the default. + entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the + default. default_value: The value to use for an entry with no features. name: An optional name for this op. hash_key: Specify the hash_key that will be used by the `FingerprintCat64` @@ -445,14 +459,14 @@ def scattered_embedding_lookup_sparse(params, params, values, dimension, hash_key=hash_key) if combiner == "sum": - embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids, - name=scope) + embeddings = math_ops.sparse_segment_sum( + embeddings, idx, segment_ids, name=scope) elif combiner == "mean": - embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids, - name=scope) + embeddings = math_ops.sparse_segment_mean( + embeddings, idx, segment_ids, name=scope) elif combiner == "sqrtn": - embeddings = math_ops.sparse_segment_sqrt_n(embeddings, idx, segment_ids, - name=scope) + embeddings = math_ops.sparse_segment_sqrt_n( + embeddings, idx, segment_ids, name=scope) else: raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.") @@ -469,8 +483,8 @@ def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None): Args: params: A list of tensors with the same shape and type, or a `PartitionedVariable`. Shape `[index, d1, d2, ...]`. - ids: A one-dimensional `Tensor` with type `int32` or `int64` containing - the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`. + ids: A one-dimensional `Tensor` with type `int32` or `int64` containing the + ids to be looked up in `params`. Shape `[ids1, ids2, ...]`. partition_strategy: A string specifying the partitioning strategy, relevant if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default is `"mod"`. @@ -486,8 +500,8 @@ def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None): with ops.name_scope(name, "EmbeddingLookupUnique", [params, ids]): ids = ops.convert_to_tensor(ids) shape = array_ops.shape(ids) - ids_flat = array_ops.reshape( - ids, math_ops.reduce_prod(shape, keepdims=True)) + ids_flat = array_ops.reshape(ids, + math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids, partition_strategy) @@ -528,15 +542,15 @@ def _sampled_scattered_embedding_lookup_sparse(params, sp_values: A 2D `SparseTensor` to be embedded with shape `[d0, d1]`. dimension: An int `Tensor` of the final dimension. The user needs to provide either `dimension` or `sampled_candidates`. - sampled_candidates: An optional `Tensor` of column indices to keep along - the final dimension with shape `[d0, N]`. If given, `dimension` is - ignored. If `None`, looks up all candidates. + sampled_candidates: An optional `Tensor` of column indices to keep along the + final dimension with shape `[d0, N]`. If given, `dimension` is ignored. If + `None`, looks up all candidates. 
hash_key: Specify the hash_key that will be used by the `FingerprintCat64` function to combine the crosses fingerprints on SparseFeatureCrossOp (optional). with_sign_hash: A `bool` indicating whether `h(i, j)` should be multiplied - by `+1` or `-1`, where the value selected is determined by hashing - `(i, j)`. This is often necessary to remove bias resulting from hash + by `+1` or `-1`, where the value selected is determined by hashing `(i, + j)`. This is often necessary to remove bias resulting from hash collisions. name: An optional name for this op. @@ -562,13 +576,19 @@ def _sampled_scattered_embedding_lookup_sparse(params, sampled_candidates = array_ops.gather(sampled_candidates, segment_ids) embeddings = _sampled_scattered_embedding_lookup( - params, sp_values.values, dimension=dimension, + params, + sp_values.values, + dimension=dimension, sampled_candidates=sampled_candidates, - hash_key=hash_key, name="values_lookup") + hash_key=hash_key, + name="values_lookup") if with_sign_hash: signs = _sampled_scattered_embedding_lookup( - array_ops.constant([-1., 1.]), sp_values.values, dimension=dimension, - sampled_candidates=sampled_candidates, hash_key=hash_key, + array_ops.constant([-1., 1.]), + sp_values.values, + dimension=dimension, + sampled_candidates=sampled_candidates, + hash_key=hash_key, name="signs_lookup") embeddings = math_ops.multiply(signs, embeddings, name="signs_hash") @@ -576,9 +596,8 @@ def _sampled_scattered_embedding_lookup_sparse(params, segment_ids = math_ops.cast(segment_ids, dtypes.int32) num_segments = array_ops.shape(sp_values)[0] - return math_ops.unsorted_segment_sum(embeddings, segment_ids, - num_segments=num_segments, - name=name_scope) + return math_ops.unsorted_segment_sum( + embeddings, segment_ids, num_segments=num_segments, name=name_scope) def embedding_lookup_sparse_with_distributed_aggregation( @@ -596,8 +615,8 @@ def embedding_lookup_sparse_with_distributed_aggregation( `tf.nn.embedding_lookup_sparse` for the functionality and example of this op. Args: - params: A single tensor representing the complete embedding tensor, - or a list of P tensors all of same shape except for the first dimension, + params: A single tensor representing the complete embedding tensor, or a + list of P tensors all of same shape except for the first dimension, representing sharded embedding tensors. Alternatively, a `PartitionedVariable`, created by partitioning along dimension 0. Each element must be appropriately sized for the given `partition_strategy`. @@ -611,13 +630,12 @@ def embedding_lookup_sparse_with_distributed_aggregation( is `"mod"`. See `tf.nn.embedding_lookup` for more details. name: Optional name for the op. combiner: A string specifying the reduction op. Currently "mean", "sqrtn" - and "sum" are supported. - "sum" computes the weighted sum of the embedding results for each row. - "mean" is the weighted sum divided by the total weight. - "sqrtn" is the weighted sum divided by the square root of the sum of the - squares of the weights. - max_norm: If not None, each embedding is normalized to have l2 norm equal - to max_norm before combining. + and "sum" are supported. "sum" computes the weighted sum of the embedding + results for each row. "mean" is the weighted sum divided by the total + weight. "sqrtn" is the weighted sum divided by the square root of the sum + of the squares of the weights. + max_norm: If not None, each embedding is normalized to have l2 norm equal to + max_norm before combining. 
Returns: A dense tensor representing the combined embeddings for the @@ -798,16 +816,18 @@ def _embedding_lookup_with_distributed_aggregation(params, ids_per_partition = num_total_ids // np extras = num_total_ids % np - p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1), ( - flat_ids - extras) // ids_per_partition) + p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1), + (flat_ids - extras) // + ids_per_partition) # Emulate a conditional using a boolean indicator tensor is_in_first_extras_partitions = math_ops.cast(p_assignments < extras, flat_ids.dtype) - new_ids = (is_in_first_extras_partitions * (flat_ids % - (ids_per_partition + 1)) + - (1 - is_in_first_extras_partitions) * ( - (flat_ids - extras) % ids_per_partition)) + new_ids = ( + is_in_first_extras_partitions * (flat_ids % + (ids_per_partition + 1)) + + (1 - is_in_first_extras_partitions) * + ((flat_ids - extras) % ids_per_partition)) else: raise ValueError("Unrecognized partition strategy: " + partition_strategy) @@ -851,8 +871,8 @@ def _embedding_lookup_with_distributed_aggregation(params, partitioned_result[p] = array_ops.reshape( partitioned_result[p], array_ops.concat([ - array_ops.shape(pindices[p]), array_ops.slice( - params_shape, [1], [-1]) + array_ops.shape(pindices[p]), + array_ops.slice(params_shape, [1], [-1]) ], 0)) # Normalize each partition result. for p in xrange(np): @@ -877,9 +897,8 @@ def _embedding_lookup_with_distributed_aggregation(params, if partitioned_result[p].get_shape().ndims is not None: partitioned_weight[p].set_shape( orig_weights_shape.concatenate([ - 1 - for _ in range(partitioned_result[p].get_shape().ndims - - 1) + 1 for _ in range(partitioned_result[p].get_shape().ndims - + 1) ])) partitioned_result[p] *= partitioned_weight[p] partitioned_segment_ids = [] diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 00d819ed0e9..c9f4d1eb148 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -156,18 +156,14 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation from tensorflow.python.util import nest - # Imports the core `InputLayer` symbol in contrib during development. InputLayer = fc_core.InputLayer # pylint: disable=invalid-name class _LinearEmbeddingLookupArguments( - collections.namedtuple("_LinearEmbeddingLookupArguments", - ["input_tensor", - "weight_tensor", - "vocab_size", - "initializer", - "combiner"])): + collections.namedtuple("_LinearEmbeddingLookupArguments", [ + "input_tensor", "weight_tensor", "vocab_size", "initializer", "combiner" + ])): """Represents the information needed from a column for embedding lookup. Used to compute DNN inputs and weighted sum. @@ -176,17 +172,11 @@ class _LinearEmbeddingLookupArguments( class _DeepEmbeddingLookupArguments( - collections.namedtuple("_DeepEmbeddingLookupArguments", - ["input_tensor", - "weight_tensor", - "vocab_size", - "initializer", - "combiner", - "dimension", - "shared_embedding_name", - "hash_key", - "max_norm", - "trainable"])): + collections.namedtuple("_DeepEmbeddingLookupArguments", [ + "input_tensor", "weight_tensor", "vocab_size", "initializer", + "combiner", "dimension", "shared_embedding_name", "hash_key", + "max_norm", "trainable" + ])): """Represents the information needed from a column for embedding lookup. Used to compute DNN inputs and weighted sum. 
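The reformatted `"div"` partitioning arithmetic near the top of this file's hunks (`ids_per_partition`, `extras`, `p_assignments`, `new_ids`) is easier to follow on concrete numbers. A plain-Python worked example, purely illustrative, with 10 ids spread over 3 partitions:

```python
num_total_ids, num_partitions = 10, 3
ids_per_partition = num_total_ids // num_partitions   # 3
extras = num_total_ids % num_partitions                # 1: partition 0 gets one extra id

def div_partition(flat_id):
  # Same expressions as the TensorFlow code above, evaluated on plain ints.
  p = max(flat_id // (ids_per_partition + 1),
          (flat_id - extras) // ids_per_partition)
  in_first_extras = 1 if p < extras else 0
  new_id = (in_first_extras * (flat_id % (ids_per_partition + 1)) +
            (1 - in_first_extras) * ((flat_id - extras) % ids_per_partition))
  return p, new_id

print([div_partition(i) for i in range(num_total_ids)])
# Partition sizes come out as [4, 3, 3]: ids 0-3 -> partition 0 (local ids 0-3),
# ids 4-6 -> partition 1 (local ids 0-2), ids 7-9 -> partition 2 (local ids 0-2).
```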
@@ -208,33 +198,25 @@ class _FeatureColumn(object): """ @abc.abstractproperty - @deprecation.deprecated( - "2016-09-25", - "Should be private.") + @deprecation.deprecated("2016-09-25", "Should be private.") def name(self): """Returns the name of column or transformed column.""" pass @abc.abstractproperty - @deprecation.deprecated( - "2016-09-25", - "Should be private.") + @deprecation.deprecated("2016-09-25", "Should be private.") def config(self): - """Returns configuration of the base feature for `tf.parse_example`.""" + """Returns configuration of the base feature for `tf.io.parse_example`.""" pass @abc.abstractproperty - @deprecation.deprecated( - "2016-09-25", - "Should be private.") + @deprecation.deprecated("2016-09-25", "Should be private.") def key(self): """Returns a string which will be used as a key when we do sorting.""" pass @abc.abstractmethod - @deprecation.deprecated( - "2016-09-25", - "Should be private.") + @deprecation.deprecated("2016-09-25", "Should be private.") def insert_transformed_feature(self, columns_to_tensors): """Apply transformation and inserts it into columns_to_tensors. @@ -243,8 +225,8 @@ class _FeatureColumn(object): key means a base feature (not-transformed). It can have _FeatureColumn as a key too. That means that _FeatureColumn is already transformed. """ - raise NotImplementedError("Transform is not implemented for {}.".format( - self)) + raise NotImplementedError( + "Transform is not implemented for {}.".format(self)) # pylint: disable=unused-argument def _to_dnn_input_layer(self, @@ -367,8 +349,8 @@ class _SparseColumn( if bucket_size is not None and bucket_size < 1: raise ValueError("bucket_size must be at least 1. " - "bucket_size: {}, column_name: {}".format(bucket_size, - column_name)) + "bucket_size: {}, column_name: {}".format( + bucket_size, column_name)) if ((lookup_config) and (not isinstance(lookup_config, _SparseIdLookupConfig))): @@ -459,12 +441,13 @@ class _SparseColumn( """Check compatibility of two sparse columns.""" if self.lookup_config and other_column.lookup_config: return self.lookup_config == other_column.lookup_config - compatible = (self.length == other_column.length and - (self.dtype == other_column.dtype or - (self.dtype.is_integer and other_column.dtype.is_integer))) + compatible = ( + self.length == other_column.length and + (self.dtype == other_column.dtype or + (self.dtype.is_integer and other_column.dtype.is_integer))) if compatible: - logging.warn("Column {} and {} may not have the same vocabulary.". - format(self.name, other_column.name)) + logging.warn("Column {} and {} may not have the same vocabulary.".format( + self.name, other_column.name)) return compatible @abc.abstractmethod @@ -501,8 +484,8 @@ class _SparseColumnIntegerized(_SparseColumn): """See `sparse_column_with_integerized_feature`.""" def _do_transform(self, input_tensor): - sparse_id_values = math_ops.mod(input_tensor.values, self.bucket_size, - name="mod") + sparse_id_values = math_ops.mod( + input_tensor.values, self.bucket_size, name="mod") return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values, input_tensor.dense_shape) @@ -548,8 +531,11 @@ def sparse_column_with_integerized_feature(column_name, ValueError: dtype is not integer. 
""" return _SparseColumnIntegerized( - column_name, is_integerized=True, bucket_size=bucket_size, - combiner=combiner, dtype=dtype) + column_name, + is_integerized=True, + bucket_size=bucket_size, + combiner=combiner, + dtype=dtype) class _SparseColumnHashed(_SparseColumn): @@ -601,8 +587,9 @@ class _SparseColumnHashed(_SparseColumn): else: sparse_id_values = string_ops.string_to_hash_bucket_fast( sparse_values, self.bucket_size, name="lookup") - return sparse_tensor_py.SparseTensor( - input_tensor.indices, sparse_id_values, input_tensor.dense_shape) + return sparse_tensor_py.SparseTensor(input_tensor.indices, + sparse_id_values, + input_tensor.dense_shape) def sparse_column_with_hash_bucket(column_name, @@ -662,8 +649,11 @@ class _SparseColumnKeys(_SparseColumn): return table.lookup(input_tensor) -def sparse_column_with_keys( - column_name, keys, default_value=-1, combiner="sum", dtype=dtypes.string): +def sparse_column_with_keys(column_name, + keys, + default_value=-1, + combiner="sum", + dtype=dtypes.string): """Creates a _SparseColumn with keys. Look up logic is as follows: @@ -702,9 +692,8 @@ class _SparseColumnVocabulary(_SparseColumn): def _do_transform(self, st): if self.dtype.is_integer: sparse_string_values = string_ops.as_string(st.values) - sparse_string_tensor = sparse_tensor_py.SparseTensor(st.indices, - sparse_string_values, - st.dense_shape) + sparse_string_tensor = sparse_tensor_py.SparseTensor( + st.indices, sparse_string_values, st.dense_shape) else: sparse_string_tensor = st @@ -774,8 +763,8 @@ class _WeightedSparseColumn( _FeatureColumn, fc_core._CategoricalColumn, # pylint: disable=protected-access collections.namedtuple("_WeightedSparseColumn", - ["sparse_id_column", "weight_column_name", - "dtype"])): + ["sparse_id_column", "weight_column_name", "dtype"]) +): """See `weighted_sparse_column`.""" def __new__(cls, sparse_id_column, weight_column_name, dtype): @@ -840,7 +829,7 @@ class _WeightedSparseColumn( # The weight tensor can be a regular Tensor. In such case, sparsify it. weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor) if not self.dtype.is_floating: - weight_tensor = math_ops.to_float(weight_tensor) + weight_tensor = math_ops.cast(weight_tensor, dtypes.float32) return tuple([id_tensor, weight_tensor]) def insert_transformed_feature(self, columns_to_tensors): @@ -919,8 +908,8 @@ def weighted_sparse_column(sparse_id_column, ValueError: if dtype is not convertible to float. """ if not (dtype.is_integer or dtype.is_floating): - raise ValueError("dtype is not convertible to float. Given {}".format( - dtype)) + raise ValueError( + "dtype is not convertible to float. Given {}".format(dtype)) return _WeightedSparseColumn(sparse_id_column, weight_column_name, dtype) @@ -970,7 +959,7 @@ class _OneHotColumn( Args: transformed_input_tensor: A tensor that has undergone the transformations - in `insert_transformed_feature`. Rank should be >= `output_rank`. + in `insert_transformed_feature`. Rank should be >= `output_rank`. unused_weight_collections: Unused. One hot encodings are not variable. unused_trainable: Unused. One hot encodings are not trainable. output_rank: the desired rank of the output `Tensor`. 
@@ -991,23 +980,23 @@ class _OneHotColumn( weight_tensor = self.sparse_id_column.weight_tensor( transformed_input_tensor) if weight_tensor is not None: - weighted_column = sparse_ops.sparse_merge(sp_ids=sparse_id_column, - sp_values=weight_tensor, - vocab_size=self.length) + weighted_column = sparse_ops.sparse_merge( + sp_ids=sparse_id_column, + sp_values=weight_tensor, + vocab_size=self.length) # Remove (?, -1) index weighted_column = sparse_ops.sparse_slice( - weighted_column, - array_ops.zeros_like(weighted_column.dense_shape), + weighted_column, array_ops.zeros_like(weighted_column.dense_shape), weighted_column.dense_shape) dense_tensor = sparse_ops.sparse_tensor_to_dense(weighted_column) batch_shape = array_ops.shape(dense_tensor)[:-1] - dense_tensor_shape = array_ops.concat( - [batch_shape, [self.length]], axis=0) + dense_tensor_shape = array_ops.concat([batch_shape, [self.length]], + axis=0) dense_tensor = array_ops.reshape(dense_tensor, dense_tensor_shape) return dense_tensor - dense_id_tensor = sparse_ops.sparse_tensor_to_dense(sparse_id_column, - default_value=-1) + dense_id_tensor = sparse_ops.sparse_tensor_to_dense( + sparse_id_column, default_value=-1) # One hot must be float for tf.concat reasons since all other inputs to # input_layer are float32. @@ -1048,8 +1037,8 @@ class _EmbeddingColumn( sparse_id_column: A `_SparseColumn` which is created by `sparse_column_with_*` or `weighted_sparse_column` functions. dimension: An integer specifying dimension of the embedding. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently "mean", "sqrtn" and "sum" are supported, with + combiner: A string specifying how to reduce if there are multiple entries in + a single row. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the default. "sqrtn" often achieves good accuracy, in particular with bag-of-words columns. Each of this can be thought as example level normalizations on the column: @@ -1059,8 +1048,8 @@ class _EmbeddingColumn( For more information: `tf.embedding_lookup_sparse`. initializer: A variable initializer function to be used in embedding variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean 0.0 and standard deviation - 1/sqrt(sparse_id_column.length). + `tf.compat.v1.truncated_normal_initializer` with mean 0.0 and standard + deviation 1/sqrt(sparse_id_column.length). ckpt_to_load_from: (Optional). String representing checkpoint name/pattern to restore the column weights. Required if `tensor_name_in_ckpt` is not None. @@ -1070,8 +1059,8 @@ class _EmbeddingColumn( shared_embedding_name: (Optional). The common name for shared embedding. shared_vocab_size: (Optional). The common vocab_size used for shared embedding space. - max_norm: (Optional). If not None, embedding values are l2-normalized to - the value of max_norm. + max_norm: (Optional). If not None, embedding values are l2-normalized to the + value of max_norm. trainable: (Optional). Should the embedding be trainable. Default is True. 
Raises: @@ -1105,14 +1094,11 @@ class _EmbeddingColumn( stddev = 1 / math.sqrt(sparse_id_column.length) initializer = init_ops.truncated_normal_initializer( mean=0.0, stddev=stddev) - return super(_EmbeddingColumn, cls).__new__(cls, sparse_id_column, - dimension, combiner, - initializer, ckpt_to_load_from, - tensor_name_in_ckpt, - shared_embedding_name, - shared_vocab_size, - max_norm, - trainable) + return super(_EmbeddingColumn, + cls).__new__(cls, sparse_id_column, dimension, combiner, + initializer, ckpt_to_load_from, + tensor_name_in_ckpt, shared_embedding_name, + shared_vocab_size, max_norm, trainable) @property def name(self): @@ -1172,8 +1158,7 @@ class _EmbeddingColumn( def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): return _embeddings_from_arguments( - self, - self._deep_embedding_lookup_arguments(inputs.get(self)), + self, self._deep_embedding_lookup_arguments(inputs.get(self)), weight_collections, trainable) def _transform_feature(self, inputs): @@ -1186,8 +1171,8 @@ class _EmbeddingColumn( def _is_variable(v): """Returns true if `v` is a variable.""" - return isinstance(v, (variables.Variable, - resource_variable_ops.ResourceVariable)) + return isinstance( + v, (variables.Variable, resource_variable_ops.ResourceVariable)) def _embeddings_from_arguments(column, @@ -1237,26 +1222,26 @@ def _embeddings_from_arguments(column, name="lookup") if args.shared_embedding_name is not None: - shared_embedding_collection_name = ( - "SHARED_EMBEDDING_COLLECTION_" + args.shared_embedding_name.upper()) + shared_embedding_collection_name = ("SHARED_EMBEDDING_COLLECTION_" + + args.shared_embedding_name.upper()) graph = ops.get_default_graph() shared_embedding_collection = ( graph.get_collection_ref(shared_embedding_collection_name)) shape = [args.vocab_size, args.dimension] if shared_embedding_collection: if len(shared_embedding_collection) > 1: - raise ValueError( - "Collection %s can only contain one " - "(partitioned) variable." % shared_embedding_collection_name) + raise ValueError("Collection %s can only contain one " + "(partitioned) variable." % + shared_embedding_collection_name) else: embeddings = shared_embedding_collection[0] if embeddings.get_shape() != shape: - raise ValueError( - "The embedding variable with name {} already " - "exists, but its shape does not match required " - "embedding shape here. Please make sure to use " - "different shared_embedding_name for different " - "shared embeddings.".format(args.shared_embedding_name)) + raise ValueError("The embedding variable with name {} already " + "exists, but its shape does not match required " + "embedding shape here. Please make sure to use " + "different shared_embedding_name for different " + "shared embeddings.".format( + args.shared_embedding_name)) else: embeddings = contrib_variables.model_variable( name=args.shared_embedding_name, @@ -1305,9 +1290,8 @@ def one_hot_column(sparse_id_column): Args: sparse_id_column: A _SparseColumn which is created by - `sparse_column_with_*` - or crossed_column functions. Note that `combiner` defined in - `sparse_id_column` is ignored. + `sparse_column_with_*` or crossed_column functions. Note that `combiner` + defined in `sparse_id_column` is ignored. Returns: An _OneHotColumn. @@ -1330,8 +1314,8 @@ def embedding_column(sparse_id_column, `sparse_column_with_*` or crossed_column functions. Note that `combiner` defined in `sparse_id_column` is ignored. dimension: An integer specifying dimension of the embedding. 
- combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently "mean", "sqrtn" and "sum" are supported, with + combiner: A string specifying how to reduce if there are multiple entries in + a single row. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the default. "sqrtn" often achieves good accuracy, in particular with bag-of-words columns. Each of this can be thought as example level normalizations on the column: @@ -1341,24 +1325,30 @@ def embedding_column(sparse_id_column, For more information: `tf.embedding_lookup_sparse`. initializer: A variable initializer function to be used in embedding variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean 0.0 and standard deviation - 1/sqrt(sparse_id_column.length). + `tf.compat.v1.truncated_normal_initializer` with mean 0.0 and standard + deviation 1/sqrt(sparse_id_column.length). ckpt_to_load_from: (Optional). String representing checkpoint name/pattern to restore the column weights. Required if `tensor_name_in_ckpt` is not None. tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided checkpoint from which to restore the column weights. Required if `ckpt_to_load_from` is not None. - max_norm: (Optional). If not None, embedding values are l2-normalized to - the value of max_norm. + max_norm: (Optional). If not None, embedding values are l2-normalized to the + value of max_norm. trainable: (Optional). Should the embedding be trainable. Default is True Returns: An `_EmbeddingColumn`. """ - return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer, - ckpt_to_load_from, tensor_name_in_ckpt, - max_norm=max_norm, trainable=trainable) + return _EmbeddingColumn( + sparse_id_column, + dimension, + combiner, + initializer, + ckpt_to_load_from, + tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable) def shared_embedding_columns(sparse_id_columns, @@ -1377,8 +1367,8 @@ def shared_embedding_columns(sparse_id_columns, `sparse_column_with_*` or crossed_column functions. Note that `combiner` defined in each sparse_id_column is ignored. dimension: An integer specifying dimension of the embedding. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently "mean", "sqrtn" and "sum" are supported, with + combiner: A string specifying how to reduce if there are multiple entries in + a single row. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the default. "sqrtn" often achieves good accuracy, in particular with bag-of-words columns. Each of this can be thought as example level normalizations on the column: @@ -1391,16 +1381,16 @@ def shared_embedding_columns(sparse_id_columns, embedding separately from the generated `_EmbeddingColumn`. initializer: A variable initializer function to be used in embedding variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean 0.0 and standard deviation - 1/sqrt(sparse_id_columns[0].length). + `tf.compat.v1.truncated_normal_initializer` with mean 0.0 and standard + deviation 1/sqrt(sparse_id_columns[0].length). ckpt_to_load_from: (Optional). String representing checkpoint name/pattern to restore the column weights. Required if `tensor_name_in_ckpt` is not None. tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided checkpoint from which to restore the column weights. Required if `ckpt_to_load_from` is not None. - max_norm: (Optional). 
If not None, embedding values are l2-normalized to - the value of max_norm. + max_norm: (Optional). If not None, embedding values are l2-normalized to the + value of max_norm. trainable: (Optional). Should the embedding be trainable. Default is True Returns: @@ -1424,16 +1414,23 @@ def shared_embedding_columns(sparse_id_columns, for sparse_id_column in sparse_id_columns: if not (isinstance(sparse_id_column, _SparseColumn) or isinstance(sparse_id_column, _WeightedSparseColumn)): - raise TypeError("Elements of sparse_id_columns must be _SparseColumn or " - "_WeightedSparseColumn, but {} is not." - .format(sparse_id_column)) + raise TypeError( + "Elements of sparse_id_columns must be _SparseColumn or " + "_WeightedSparseColumn, but {} is not.".format(sparse_id_column)) if len(sparse_id_columns) == 1: return [ - _EmbeddingColumn(sparse_id_columns[0], dimension, combiner, initializer, - ckpt_to_load_from, tensor_name_in_ckpt, - shared_embedding_name, max_norm=max_norm, - trainable=trainable)] + _EmbeddingColumn( + sparse_id_columns[0], + dimension, + combiner, + initializer, + ckpt_to_load_from, + tensor_name_in_ckpt, + shared_embedding_name, + max_norm=max_norm, + trainable=trainable) + ] else: # Check compatibility of sparse_id_columns compatible = True @@ -1460,11 +1457,11 @@ def shared_embedding_columns(sparse_id_columns, sorted_columns = sorted(sparse_columns) + sorted( weighted_sparse_columns, key=lambda x: x.name) if len(sorted_columns) <= 3: - shared_embedding_name = "_".join([column.name - for column in sorted_columns]) + shared_embedding_name = "_".join( + [column.name for column in sorted_columns]) else: - shared_embedding_name = "_".join([column.name - for column in sorted_columns[0:3]]) + shared_embedding_name = "_".join( + [column.name for column in sorted_columns[0:3]]) shared_embedding_name += ( "_plus_{}_others".format(len(sorted_columns) - 3)) shared_embedding_name += "_shared_embedding" @@ -1473,10 +1470,17 @@ def shared_embedding_columns(sparse_id_columns, embedded_columns = [] for column in sparse_id_columns: embedded_columns.append( - _EmbeddingColumn(column, dimension, combiner, initializer, - ckpt_to_load_from, tensor_name_in_ckpt, - shared_embedding_name, shared_vocab_size, - max_norm=max_norm, trainable=trainable)) + _EmbeddingColumn( + column, + dimension, + combiner, + initializer, + ckpt_to_load_from, + tensor_name_in_ckpt, + shared_embedding_name, + shared_vocab_size, + max_norm=max_norm, + trainable=trainable)) return tuple(embedded_columns) @@ -1503,10 +1507,9 @@ class _ScatteredEmbeddingColumn( stddev = 0.1 initializer = init_ops.truncated_normal_initializer( mean=0.0, stddev=stddev) - return super(_ScatteredEmbeddingColumn, cls).__new__(cls, column_name, size, - dimension, hash_key, - combiner, - initializer) + return super(_ScatteredEmbeddingColumn, + cls).__new__(cls, column_name, size, dimension, hash_key, + combiner, initializer) @property def name(self): @@ -1543,8 +1546,7 @@ class _ScatteredEmbeddingColumn( def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): return _embeddings_from_arguments( - self, - self._deep_embedding_lookup_arguments(inputs.get(self)), + self, self._deep_embedding_lookup_arguments(inputs.get(self)), weight_collections, trainable) def _transform_feature(self, inputs): @@ -1593,8 +1595,8 @@ def scattered_embedding_column(column_name, dimension: An integer specifying dimension of the embedding. 
hash_key: Specify the hash_key that will be used by the `FingerprintCat64` function to combine the crosses fingerprints on SparseFeatureCrossOp. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently "mean", "sqrtn" and "sum" are supported, with + combiner: A string specifying how to reduce if there are multiple entries in + a single row. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the default. "sqrtn" often achieves good accuracy, in particular with bag-of-words columns. Each of this can be thought as example level normalizations on the column: @@ -1604,7 +1606,8 @@ def scattered_embedding_column(column_name, For more information: `tf.embedding_lookup_sparse`. initializer: A variable initializer function to be used in embedding variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean 0 and standard deviation 0.1. + `tf.compat.v1.truncated_normal_initializer` with mean 0 and standard + deviation 0.1. Returns: A _ScatteredEmbeddingColumn. @@ -1621,8 +1624,8 @@ def scattered_embedding_column(column_name, if combiner not in ("mean", "sqrtn", "sum"): raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'. " - "combiner: {}, column_name: {}".format(combiner, - column_name)) + "combiner: {}, column_name: {}".format( + combiner, column_name)) return _ScatteredEmbeddingColumn(column_name, size, dimension, hash_key, combiner, initializer) @@ -1644,6 +1647,7 @@ def _reshape_real_valued_tensor(input_tensor, output_rank, column_name=None): output_rank: the desired rank of the reshaped `Tensor`. column_name: (optional) the name of the associated column. Used for error messages. + Returns: A `Tensor` with the same entries as `input_tensor` and rank `output_rank`. Raises: @@ -1658,8 +1662,9 @@ def _reshape_real_valued_tensor(input_tensor, output_rank, column_name=None): "data is typically 2 dimensional (rank 2).".format( input_rank, output_rank)) if column_name is not None: - error_string = ("Error while processing column {}.".format(column_name) - + error_string) + error_string = ( + "Error while processing column {}.".format(column_name) + + error_string) raise ValueError(error_string) if output_rank == input_rank + 1: logging.warning( @@ -1674,9 +1679,11 @@ def _reshape_real_valued_tensor(input_tensor, output_rank, column_name=None): return layers._inner_flatten(input_tensor, output_rank) # pylint: disable=protected-access -class _RealValuedVarLenColumn(_FeatureColumn, collections.namedtuple( - "_RealValuedVarLenColumn", - ["column_name", "default_value", "dtype", "normalizer", "is_sparse"])): +class _RealValuedVarLenColumn( + _FeatureColumn, + collections.namedtuple( + "_RealValuedVarLenColumn", + ["column_name", "default_value", "dtype", "normalizer", "is_sparse"])): """Represents a real valued feature column for variable length Features. Instances of this class are immutable. 
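The `config` property shown in the next hunk returns the `tf.io.parse_example` feature spec for this column: a `VarLenFeature` in the sparse case and a padded `FixedLenSequenceFeature` otherwise. A small hedged sketch of those two specs in isolation (the feature name is made up; not part of the patch):

```python
import tensorflow as tf

# Sparse branch: values stay in a SparseTensor after parsing.
sparse_spec = {'prices': tf.io.VarLenFeature(tf.float32)}

# Dense branch: variable-length values are padded with default_value.
dense_spec = {'prices': tf.io.FixedLenSequenceFeature(
    [], tf.float32, allow_missing=True, default_value=0.0)}

example = tf.train.Example(features=tf.train.Features(feature={
    'prices': tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.5]))
}))
parsed = tf.io.parse_example([example.SerializeToString()], dense_spec)
print(parsed)
```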
@@ -1695,9 +1702,14 @@ class _RealValuedVarLenColumn(_FeatureColumn, collections.namedtuple( if self.is_sparse: return {self.column_name: parsing_ops.VarLenFeature(self.dtype)} else: - return {self.column_name: parsing_ops.FixedLenSequenceFeature( - [], self.dtype, allow_missing=True, - default_value=self.default_value)} + return { + self.column_name: + parsing_ops.FixedLenSequenceFeature( + [], + self.dtype, + allow_missing=True, + default_value=self.default_value) + } @property def key(self): @@ -1714,10 +1726,9 @@ class _RealValuedVarLenColumn(_FeatureColumn, collections.namedtuple( if self.normalizer is None: return input_tensor if self.is_sparse: - return sparse_tensor_py.SparseTensor( - input_tensor.indices, - self.normalizer(input_tensor.values), - input_tensor.dense_shape) + return sparse_tensor_py.SparseTensor(input_tensor.indices, + self.normalizer(input_tensor.values), + input_tensor.dense_shape) else: return self.normalizer(input_tensor) @@ -1731,7 +1742,7 @@ class _RealValuedVarLenColumn(_FeatureColumn, collections.namedtuple( """ # Transform the input tensor according to the normalizer function. input_tensor = self._normalized_input_tensor(columns_to_tensors[self.name]) - columns_to_tensors[self] = math_ops.to_float(input_tensor) + columns_to_tensors[self] = math_ops.cast(input_tensor, dtypes.float32) # pylint: disable=unused-argument def _to_dnn_input_layer(self, @@ -1780,9 +1791,10 @@ def _real_valued_var_len_column(column_name, of the real valued column after default_value is applied for parsing. Normalizer function takes the input tensor as its argument, and returns the output tensor. (e.g. lambda x: (x - 3.0) / 4.2). Note that for - is_sparse=False, the normalizer will be run on the values of the - `SparseTensor`. + is_sparse=False, the normalizer will be run on the values of the + `SparseTensor`. is_sparse: A boolean defining whether to create a SparseTensor or a Tensor. + Returns: A _RealValuedSparseColumn. Raises: @@ -1825,13 +1837,12 @@ class _RealValuedColumn( (batch_size, dimension). """ - def __new__(cls, column_name, dimension, default_value, - dtype, normalizer): + def __new__(cls, column_name, dimension, default_value, dtype, normalizer): if default_value is not None: default_value = tuple(default_value) - return super(_RealValuedColumn, cls).__new__(cls, column_name, dimension, - default_value, dtype, - normalizer) + return super(_RealValuedColumn, + cls).__new__(cls, column_name, dimension, default_value, dtype, + normalizer) @property def name(self): @@ -1842,9 +1853,11 @@ class _RealValuedColumn( default_value = self.default_value if default_value is not None: default_value = list(default_value) - return {self.column_name: parsing_ops.FixedLenFeature([self.dimension], - self.dtype, - default_value)} + return { + self.column_name: + parsing_ops.FixedLenFeature([self.dimension], self.dtype, + default_value) + } @property def key(self): @@ -1858,8 +1871,8 @@ class _RealValuedColumn( def _normalized_input_tensor(self, input_tensor): """Returns the input tensor after custom normalization is applied.""" - return (self.normalizer(input_tensor) if self.normalizer is not None else - input_tensor) + return (self.normalizer(input_tensor) + if self.normalizer is not None else input_tensor) def insert_transformed_feature(self, columns_to_tensors): """Apply transformation and inserts it into columns_to_tensors. @@ -1871,7 +1884,7 @@ class _RealValuedColumn( """ # Transform the input tensor according to the normalizer function. 
input_tensor = self._normalized_input_tensor(columns_to_tensors[self.name]) - columns_to_tensors[self] = math_ops.to_float(input_tensor) + columns_to_tensors[self] = math_ops.cast(input_tensor, dtypes.float32) # pylint: disable=unused-argument def _to_dnn_input_layer(self, @@ -1881,7 +1894,7 @@ class _RealValuedColumn( output_rank=2): input_tensor = self._to_dense_tensor(input_tensor) if input_tensor.dtype != dtypes.float32: - input_tensor = math_ops.to_float(input_tensor) + input_tensor = math_ops.cast(input_tensor, dtypes.float32) return _reshape_real_valued_tensor(input_tensor, output_rank, self.name) def _to_dense_tensor(self, input_tensor): @@ -1897,8 +1910,8 @@ class _RealValuedColumn( return inputs.get(self) def _transform_feature(self, inputs): - return math_ops.to_float( - self._normalized_input_tensor(inputs.get(self.name))) + return math_ops.cast( + self._normalized_input_tensor(inputs.get(self.name)), dtypes.float32) @property def _parse_example_spec(self): @@ -1914,24 +1927,25 @@ def real_valued_column(column_name, Args: column_name: A string defining real valued column name. - dimension: An integer specifying dimension of the real valued column. - The default is 1. + dimension: An integer specifying dimension of the real valued column. The + default is 1. default_value: A single value compatible with dtype or a list of values compatible with dtype which the column takes on during tf.Example parsing if data is missing. When dimension is not None, a default value of None - will cause tf.parse_example to fail if an example does not contain this + will cause tf.io.parse_example to fail if an example does not contain this column. If a single value is provided, the same value will be applied as the default value for every dimension. If a list of values is provided, - the length of the list should be equal to the value of `dimension`. - Only scalar default value is supported in case dimension is not specified. + the length of the list should be equal to the value of `dimension`. Only + scalar default value is supported in case dimension is not specified. dtype: defines the type of values. Default value is tf.float32. Must be a non-quantized, real integer or floating point type. normalizer: If not None, a function that can be used to normalize the value of the real valued column after default_value is applied for parsing. Normalizer function takes the input tensor as its argument, and returns the output tensor. (e.g. lambda x: (x - 3.0) / 4.2). Note that for - variable length columns, the normalizer should expect an input_tensor of - type `SparseTensor`. + variable length columns, the normalizer should expect an input_tensor of + type `SparseTensor`. + Returns: A _RealValuedColumn. Raises: @@ -1946,16 +1960,16 @@ def real_valued_column(column_name, if dimension is None: raise TypeError("dimension must be an integer. Use the " "_real_valued_var_len_column for variable length features." - "dimension: {}, column_name: {}".format(dimension, - column_name)) + "dimension: {}, column_name: {}".format( + dimension, column_name)) if not isinstance(dimension, int): raise TypeError("dimension must be an integer. " - "dimension: {}, column_name: {}".format(dimension, - column_name)) + "dimension: {}, column_name: {}".format( + dimension, column_name)) if dimension < 1: raise ValueError("dimension must be greater than 0. 
" - "dimension: {}, column_name: {}".format(dimension, - column_name)) + "dimension: {}, column_name: {}".format( + dimension, column_name)) if not (dtype.is_integer or dtype.is_floating): raise ValueError("dtype must be convertible to float. " @@ -1967,21 +1981,21 @@ def real_valued_column(column_name, if isinstance(default_value, int): if dtype.is_integer: - default_value = ([default_value for _ in range(dimension)] if dimension - else [default_value]) + default_value = ([default_value for _ in range(dimension)] + if dimension else [default_value]) return _RealValuedColumn(column_name, dimension, default_value, dtype, normalizer) if dtype.is_floating: default_value = float(default_value) - default_value = ([default_value for _ in range(dimension)] if dimension - else [default_value]) + default_value = ([default_value for _ in range(dimension)] + if dimension else [default_value]) return _RealValuedColumn(column_name, dimension, default_value, dtype, normalizer) if isinstance(default_value, float): if dtype.is_floating and (not dtype.is_integer): - default_value = ([default_value for _ in range(dimension)] if dimension - else [default_value]) + default_value = ([default_value for _ in range(dimension)] + if dimension else [default_value]) return _RealValuedColumn(column_name, dimension, default_value, dtype, normalizer) @@ -2023,8 +2037,8 @@ class _BucketizedColumn( _FeatureColumn, fc_core._CategoricalColumn, # pylint: disable=protected-access fc_core._DenseColumn, # pylint: disable=protected-access - collections.namedtuple("_BucketizedColumn", ["source_column", - "boundaries"])): + collections.namedtuple("_BucketizedColumn", + ["source_column", "boundaries"])): """Represents a bucketization transformation also known as binning. Instances of this class are immutable. Values in `source_column` will be @@ -2039,8 +2053,9 @@ class _BucketizedColumn( Attributes: source_column: A _RealValuedColumn defining dense column. boundaries: A list or tuple of floats specifying the boundaries. It has to - be sorted. [a, b, c] defines following buckets: (-inf., a), [a, b), - [b, c), [c, inf.) + be sorted. [a, b, c] defines following buckets: (-inf., a), [a, b), [b, + c), [c, inf.) + Raises: ValueError: if 'boundaries' is empty or not sorted. """ @@ -2104,7 +2119,7 @@ class _BucketizedColumn( raise ValueError("BucketizedColumn currently only supports output_rank=2") return array_ops.reshape( array_ops.one_hot( - math_ops.to_int64(input_tensor), + math_ops.cast(input_tensor, dtypes.int64), self.length, 1., 0., @@ -2136,10 +2151,12 @@ class _BucketizedColumn( i2 = array_ops.zeros([batch_size], dtype=dtypes.int32, name="zeros") bucket_indices = array_ops.reshape(input_tensor, [-1], name="reshape") - indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2)))) - shape = math_ops.to_int64(array_ops.stack([batch_size, dimension])) - sparse_id_values = sparse_tensor_py.SparseTensor( - indices, bucket_indices, shape) + indices = math_ops.cast( + array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64) + shape = math_ops.cast( + array_ops.stack([batch_size, dimension]), dtypes.int64) + sparse_id_values = sparse_tensor_py.SparseTensor(indices, bucket_indices, + shape) return sparse_id_values @@ -2242,8 +2259,8 @@ class _CrossedColumn( columns: An iterable of _FeatureColumn. Items can be an instance of _SparseColumn, _CrossedColumn, or _BucketizedColumn. hash_bucket_size: An int that is > 1. The number of buckets. 
- combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently "mean", "sqrtn" and "sum" are supported, with + combiner: A string specifying how to reduce if there are multiple entries in + a single row. Currently "mean", "sqrtn" and "sum" are supported, with "sum" the default. "sqrtn" often achieves good accuracy, in particular with bag-of-words columns. Each of this can be thought as example level normalizations on the column:: @@ -2296,13 +2313,12 @@ class _CrossedColumn( raise ValueError("Must specify both `ckpt_to_load_from` and " "`tensor_name_in_ckpt` or none of them.") - sorted_columns = sorted( - [column for column in columns], key=lambda column: column.name) - return super(_CrossedColumn, cls).__new__(cls, tuple(sorted_columns), - hash_bucket_size, hash_key, - combiner, - ckpt_to_load_from, - tensor_name_in_ckpt) + sorted_columns = sorted([column for column in columns], + key=lambda column: column.name) + return super(_CrossedColumn, + cls).__new__(cls, tuple(sorted_columns), hash_bucket_size, + hash_key, combiner, ckpt_to_load_from, + tensor_name_in_ckpt) @property def name(self): @@ -2429,7 +2445,9 @@ class _LazyBuilderByColumnsToTensor(object): return self._columns_to_tensors[key] -def crossed_column(columns, hash_bucket_size, combiner="sum", +def crossed_column(columns, + hash_bucket_size, + combiner="sum", ckpt_to_load_from=None, tensor_name_in_ckpt=None, hash_key=None): @@ -2439,8 +2457,8 @@ def crossed_column(columns, hash_bucket_size, combiner="sum", columns: An iterable of _FeatureColumn. Items can be an instance of _SparseColumn, _CrossedColumn, or _BucketizedColumn. hash_bucket_size: An int that is > 1. The number of buckets. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently "mean", "sqrtn" and "sum" are supported, with + combiner: A string specifying how to reduce if there are multiple entries in + a single row. Currently "mean", "sqrtn" and "sum" are supported, with "sum" the default. "sqrtn" often achieves good accuracy, in particular with bag-of-words columns. Each of this can be thought as example level normalizations on the column:: @@ -2527,7 +2545,7 @@ class DataFrameColumn(_FeatureColumn, trainable=True, output_rank=2): if input_tensor.dtype != dtypes.float32: - input_tensor = math_ops.to_float(input_tensor) + input_tensor = math_ops.cast(input_tensor, dtypes.float32) return _reshape_real_valued_tensor(input_tensor, output_rank, self.name) def _to_dense_tensor(self, input_tensor): @@ -2549,11 +2567,10 @@ def _get_feature_config(feature_column): raise TypeError( "feature_columns should only contain instances of _FeatureColumn. " "Given column is {}".format(feature_column)) - if isinstance(feature_column, (_SparseColumn, _WeightedSparseColumn, - _EmbeddingColumn, _RealValuedColumn, - _RealValuedVarLenColumn, - _BucketizedColumn, _CrossedColumn, - _OneHotColumn, _ScatteredEmbeddingColumn)): + if isinstance(feature_column, + (_SparseColumn, _WeightedSparseColumn, _EmbeddingColumn, + _RealValuedColumn, _RealValuedVarLenColumn, _BucketizedColumn, + _CrossedColumn, _OneHotColumn, _ScatteredEmbeddingColumn)): return feature_column.config raise TypeError("Not supported _FeatureColumn type. 
" @@ -2577,7 +2594,7 @@ def create_feature_spec_for_parsing(feature_columns): feature_columns = set( [feature_b, feature_c_bucketized, feature_a_x_feature_c]) - batch_examples = tf.parse_example( + batch_examples = tf.io.parse_example( serialized=serialized_examples, features=create_feature_spec_for_parsing(feature_columns)) ``` @@ -2594,6 +2611,7 @@ def create_feature_spec_for_parsing(feature_columns): should be instances of classes derived from _FeatureColumn, unless feature_columns is a dict -- in which case, this should be true of all values in the dict. + Returns: A dict mapping feature keys to FixedLenFeature or VarLenFeature values. """ @@ -2615,6 +2633,7 @@ def _create_sequence_feature_spec_for_parsing(sequence_feature_columns, All items should be instances of classes derived from `_FeatureColumn`. allow_missing_by_default: whether to set `allow_missing=True` by default for `FixedLenSequenceFeature`s. + Returns: A dict mapping feature keys to `FixedLenSequenceFeature` or `VarLenFeature`. """ @@ -2629,15 +2648,15 @@ def _create_sequence_feature_spec_for_parsing(sequence_feature_columns, if default_is_set: logging.warning( 'Found default value {} for feature "{}". Ignoring this value and ' - 'setting `allow_missing=True` instead.'. - format(feature.default_value, key)) + "setting `allow_missing=True` instead.".format( + feature.default_value, key)) sequence_feature = parsing_ops.FixedLenSequenceFeature( shape=feature.shape, dtype=feature.dtype, allow_missing=(allow_missing_by_default or default_is_set)) else: - raise TypeError( - "Unsupported feature type: {}".format(type(feature).__name__)) + raise TypeError("Unsupported feature type: {}".format( + type(feature).__name__)) sequence_feature_spec[key] = sequence_feature return sequence_feature_spec @@ -2648,6 +2667,7 @@ def make_place_holder_tensors_for_base_features(feature_columns): Args: feature_columns: An iterable containing all the feature columns. All items should be instances of classes derived from _FeatureColumn. + Returns: A dict mapping feature keys to SparseTensors (sparse columns) or placeholder Tensors (dense columns). @@ -2670,9 +2690,10 @@ def make_place_holder_tensors_for_base_features(feature_columns): class _SparseIdLookupConfig( - collections.namedtuple("_SparseIdLookupConfig", - ["vocabulary_file", "keys", "num_oov_buckets", - "vocab_size", "default_value"])): + collections.namedtuple("_SparseIdLookupConfig", [ + "vocabulary_file", "keys", "num_oov_buckets", "vocab_size", + "default_value" + ])): """Defines lookup configuration for a sparse feature. An immutable object defines lookup table configuration used by @@ -2697,6 +2718,6 @@ class _SparseIdLookupConfig( vocab_size=None, default_value=-1): - return super(_SparseIdLookupConfig, cls).__new__(cls, vocabulary_file, keys, - num_oov_buckets, - vocab_size, default_value) + return super(_SparseIdLookupConfig, + cls).__new__(cls, vocabulary_file, keys, num_oov_buckets, + vocab_size, default_value) diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index a85cff4f709..37594fb81dd 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -170,7 +170,7 @@ def input_from_feature_columns(columns_to_tensors, ```python # Building model for training - columns_to_tensor = tf.parse_example(...) + columns_to_tensor = tf.io.parse_example(...) 
first_layer = input_from_feature_columns( columns_to_tensors=columns_to_tensor, feature_columns=feature_columns) @@ -449,7 +449,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors, real_valued_column("my_feature1"), ... ) - columns_to_tensor = tf.parse_example(...) + columns_to_tensor = tf.io.parse_example(...) logits = weighted_sum_from_feature_columns( columns_to_tensors=columns_to_tensor, feature_columns=feature_columns, @@ -488,7 +488,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors, default_name='weighted_sum_from_feature_columns', values=columns_to_tensors.values()): output_tensors = [] - column_to_variable = dict() + column_to_variable = {} transformer = _Transformer(columns_to_tensors) # pylint: disable=protected-access for column in sorted(set(feature_columns), key=lambda x: x.key): @@ -548,7 +548,7 @@ def parse_feature_columns_from_examples(serialized, example_names=None): """Parses tf.Examples to extract tensors for given feature_columns. - This is a wrapper of 'tf.parse_example'. + This is a wrapper of 'tf.io.parse_example'. Example: @@ -806,7 +806,7 @@ class _Transformer(object): sparse_x_real = crossed_column( columns=[sparse_feature, real_valued_buckets], hash_bucket_size=10000) - columns_to_tensor = tf.parse_example(...) + columns_to_tensor = tf.io.parse_example(...) transformer = Transformer(columns_to_tensor) sparse_x_real_tensor = transformer.transform(sparse_x_real) diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index 655f038b184..51e5f4d68b9 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -47,7 +47,7 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32): Args: uniform: Whether to use uniform or normal distributed random initialization. seed: A Python integer. Used to create random seeds. See - `tf.set_random_seed` for behavior. + `tf.compat.v1.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. Returns: @@ -98,7 +98,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'. uniform: Whether to use uniform or normal distributed random initialization. seed: A Python integer. Used to create random seeds. See - `tf.set_random_seed` for behavior. + `tf.compat.v1.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. Returns: diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 1d959b3c784..7507e1fffa6 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -91,11 +91,11 @@ def avg_pool2d(inputs, `data_format` is `NHWC`, and `[batch_size, channels, height, width]` if `data_format` is `NCHW`. kernel_size: A list of length 2: [kernel_height, kernel_width] of the - pooling kernel over which the op is computed. Can be an int if both - values are the same. - stride: A list of length 2: [stride_height, stride_width]. - Can be an int if both strides are the same. Note that presently - both strides must have the same value. + pooling kernel over which the op is computed. Can be an int if both values + are the same. + stride: A list of length 2: [stride_height, stride_width]. Can be an int if + both strides are the same. Note that presently both strides must have the + same value. 
padding: The padding method, either 'VALID' or 'SAME'. data_format: A string. `NHWC` (default) and `NCHW` are supported. outputs_collections: The collections to which the outputs are added. @@ -142,9 +142,9 @@ def avg_pool3d(inputs, kernel_size: A list of length 3: [kernel_depth, kernel_height, kernel_width] of the pooling kernel over which the op is computed. Can be an int if both values are the same. - stride: A list of length 3: [stride_depth, stride_height, stride_width]. - Can be an int if both strides are the same. Note that presently - both strides must have the same value. + stride: A list of length 3: [stride_depth, stride_height, stride_width]. Can + be an int if both strides are the same. Note that presently both strides + must have the same value. padding: The padding method, either 'VALID' or 'SAME'. data_format: A string. `NDHWC` (default) and `NCDHW` are supported. outputs_collections: The collections to which the outputs are added. @@ -203,7 +203,7 @@ def _fused_batch_norm(inputs, need to be added as a dependency to the `train_op`. For example: ```python - update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) ``` @@ -218,14 +218,14 @@ def _fused_batch_norm(inputs, `NCHW`. decay: Decay for the moving average. Reasonable values for `decay` are close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. - Lower `decay` value (recommend trying `decay`=0.9) if model experiences - reasonably good training performance but poor validation and/or test - performance. + Lower `decay` value (recommend trying `decay`=0.9) if model experiences + reasonably good training performance but poor validation and/or test + performance. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is - not used. When the next layer is linear (also e.g. `nn.relu`), this can be - disabled since the scaling can be done by the next layer. + scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the + next layer is linear (also e.g. `nn.relu`), this can be disabled since the + scaling can be done by the next layer. epsilon: Small float added to variance to avoid dividing by zero. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. @@ -233,9 +233,8 @@ def _fused_batch_norm(inputs, moving variance. param_regularizers: Optional regularizer for beta and gamma. updates_collections: Collections to collect the update ops for computation. - The updates_ops need to be executed with the train_op. - If None, a control dependency would be added to make sure the updates are - computed in place. + The updates_ops need to be executed with the train_op. If None, a control + dependency would be added to make sure the updates are computed in place. is_training: Whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given @@ -272,8 +271,8 @@ def _fused_batch_norm(inputs, raise ValueError('Inputs %s has undefined rank' % inputs.name) elif original_rank not in [2, 4]: raise ValueError('Inputs %s has unsupported rank.' 
- ' Expected 2 or 4 but got %d' % (inputs.name, - original_rank)) + ' Expected 2 or 4 but got %d' % + (inputs.name, original_rank)) if original_rank == 2: channels = inputs.get_shape().dims[-1].value if channels is None: @@ -379,8 +378,9 @@ def _fused_batch_norm(inputs, is_training=False, data_format=data_format) - outputs, mean, variance = utils.smart_cond( - is_training, _fused_batch_norm_training, _fused_batch_norm_inference) + outputs, mean, variance = utils.smart_cond(is_training, + _fused_batch_norm_training, + _fused_batch_norm_inference) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and @@ -413,8 +413,9 @@ def _fused_batch_norm(inputs, moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance - update_mean, update_variance = utils.smart_cond( - is_training, _delay_updates, moving_vars_fn) + update_mean, update_variance = utils.smart_cond(is_training, + _delay_updates, + moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) @@ -469,7 +470,7 @@ def batch_norm(inputs, need to be added as a dependency to the `train_op`. For example: ```python - update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) ``` @@ -484,14 +485,14 @@ def batch_norm(inputs, `NCHW`. decay: Decay for the moving average. Reasonable values for `decay` are close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. - Lower `decay` value (recommend trying `decay`=0.9) if model experiences - reasonably good training performance but poor validation and/or test - performance. Try zero_debias_moving_mean=True for improved stability. + Lower `decay` value (recommend trying `decay`=0.9) if model experiences + reasonably good training performance but poor validation and/or test + performance. Try zero_debias_moving_mean=True for improved stability. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is - not used. When the next layer is linear (also e.g. `nn.relu`), this can be - disabled since the scaling can be done by the next layer. + scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the + next layer is linear (also e.g. `nn.relu`), this can be disabled since the + scaling can be done by the next layer. epsilon: Small float added to variance to avoid dividing by zero. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. @@ -499,9 +500,8 @@ def batch_norm(inputs, moving variance. param_regularizers: Optional regularizer for beta and gamma. updates_collections: Collections to collect the update ops for computation. - The updates_ops need to be executed with the train_op. - If None, a control dependency would be added to make sure the updates are - computed in place. + The updates_ops need to be executed with the train_op. If None, a control + dependency would be added to make sure the updates are computed in place. is_training: Whether or not the layer is in training mode. 
In training mode it would accumulate the statistics of the moments into `moving_mean` and `moving_variance` using an exponential moving average with the given @@ -513,11 +513,10 @@ def batch_norm(inputs, outputs_collections: Collections to add the outputs. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - batch_weights: An optional tensor of shape `[batch_size]`, - containing a frequency weight for each batch item. If present, - then the batch normalization uses weighted mean and - variance. (This can be used to correct for bias in training - example selection.) + batch_weights: An optional tensor of shape `[batch_size]`, containing a + frequency weight for each batch item. If present, then the batch + normalization uses weighted mean and variance. (This can be used to + correct for bias in training example selection.) fused: if `None` or `True`, use a faster, fused implementation if possible. If `False`, use the system recommended implementation. data_format: A string. `NHWC` (default) and `NCHW` are supported. @@ -526,28 +525,28 @@ def batch_norm(inputs, scope: Optional scope for `variable_scope`. renorm: Whether to use Batch Renormalization (https://arxiv.org/abs/1702.03275). This adds extra variables during - training. The inference is the same for either value of this parameter. + training. The inference is the same for either value of this parameter. renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to - scalar `Tensors` used to clip the renorm correction. The correction - `(r, d)` is used as `corrected_value = normalized_value * r + d`, with - `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, + scalar `Tensors` used to clip the renorm correction. The correction `(r, + d)` is used as `corrected_value = normalized_value * r + d`, with `r` + clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, dmax are set to inf, 0, inf, respectively. renorm_decay: Momentum used to update the moving means and standard - deviations with renorm. Unlike `momentum`, this affects training - and should be neither too small (which would add noise) nor too large - (which would give stale estimates). Note that `decay` is still applied - to get the means and variances for inference. + deviations with renorm. Unlike `momentum`, this affects training and + should be neither too small (which would add noise) nor too large (which + would give stale estimates). Note that `decay` is still applied to get the + means and variances for inference. adjustment: A function taking the `Tensor` containing the (dynamic) shape of the input tensor and returning a pair (scale, bias) to apply to the normalized values (before gamma and beta), only during training. For example, `adjustment = lambda shape: ( - tf.random_uniform(shape[-1:], 0.93, 1.07), - tf.random_uniform(shape[-1:], -0.1, 0.1))` - will scale the normalized value by up to 7% up or down, then shift the - result by up to 0.1 (with independent scaling and bias for each feature - but shared across all examples), and finally apply gamma and/or beta. If - `None`, no adjustment is applied. + tf.random.uniform(shape[-1:], 0.93, 1.07), + tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized + value by up to 7% up or down, then shift the result by up to 0.1 + (with independent scaling and bias for each feature but shared + across all examples), and finally apply gamma and/or beta. If + `None`, no adjustment is applied. 
Returns: A `Tensor` representing the output of the operation. @@ -692,8 +691,8 @@ def batch_norm(inputs, # For NCHW format, rather than relying on implicit broadcasting, we # explicitly reshape the params to params_shape_broadcast when computing # the moments and the batch normalization. - params_shape_broadcast = list( - [1, inputs_shape.dims[1].value] + [1 for _ in range(2, inputs_rank)]) + params_shape_broadcast = list([1, inputs_shape.dims[1].value] + + [1 for _ in range(2, inputs_rank)]) else: moments_axes = list(range(inputs_rank - 1)) params_shape = inputs_shape[-1:] @@ -811,8 +810,9 @@ def batch_norm(inputs, moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance - update_mean, update_variance = utils.smart_cond( - is_training, _delay_updates, moving_vars_fn) + update_mean, update_variance = utils.smart_cond(is_training, + _delay_updates, + moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. @@ -858,8 +858,8 @@ def bias_add(inputs, activation_fn: Activation function, default set to None to skip it and maintain a linear activation. initializer: An initializer for the bias, defaults to 0. - regularizer: A regularizer like the result of - `l1_regularizer` or `l2_regularizer`. + regularizer: A regularizer like the result of `l1_regularizer` or + `l2_regularizer`. reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. variables_collections: Optional collections for the variables. @@ -952,9 +952,8 @@ def convolution(inputs, `stride` values != 1 are not supported. Args: - inputs: A Tensor of rank N+2 of shape - `[batch_size] + input_spatial_shape + [in_channels]` if data_format does - not start with "NC" (default), or + inputs: A Tensor of rank N+2 of shape `[batch_size] + input_spatial_shape + + [in_channels]` if data_format does not start with "NC" (default), or `[batch_size, in_channels] + input_spatial_shape` if data_format starts with "NC". num_outputs: Integer, the number of output filters. @@ -970,11 +969,11 @@ def convolution(inputs, the `input` and output is the last dimension (default, or if `data_format` does not start with "NC"), or the second dimension (if `data_format` starts with "NC"). For N=1, the valid values are "NWC" (default) and - "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". - For N=3, the valid values are "NDHWC" (default) and "NCDHW". + "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For + N=3, the valid values are "NDHWC" (default) and "NCDHW". rate: A sequence of N positive integers specifying the dilation rate to use - for atrous convolution. Can be a single integer to specify the same - value for all spatial dimensions. Specifying any `rate` value != 1 is + for atrous convolution. Can be a single integer to specify the same value + for all spatial dimensions. Specifying any `rate` value != 1 is incompatible with specifying any `stride` value != 1. activation_fn: Activation function. The default value is a ReLU function. Explicitly set it to None to skip it and maintain a linear activation. @@ -997,8 +996,8 @@ def convolution(inputs, scope: Optional scope for `variable_scope`. conv_dims: Optional convolution dimensionality, when set it would use the corresponding convolution (e.g. 2 for Conv 2D, 3 for Conv 3D, ..). 
When - leaved to None it would select the convolution dimensionality based on - the input rank (i.e. Conv ND, with N = input_rank - 2). + leaved to None it would select the convolution dimensionality based on the + input rank (i.e. Conv ND, with N = input_rank - 2). Returns: A tensor representing the output of the operation. @@ -1070,6 +1069,7 @@ def convolution(inputs, outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + @add_arg_scope def convolution1d(inputs, num_outputs, @@ -1090,29 +1090,32 @@ def convolution1d(inputs, outputs_collections=None, trainable=True, scope=None): - return convolution(inputs, - num_outputs, - kernel_size, - stride, - padding, - data_format, - rate, - activation_fn, - normalizer_fn, - normalizer_params, - weights_initializer, - weights_regularizer, - biases_initializer, - biases_regularizer, - reuse, - variables_collections, - outputs_collections, - trainable, - scope, - conv_dims=1) + return convolution( + inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=1) + convolution1d.__doc__ = convolution.__doc__ + @add_arg_scope def convolution2d(inputs, num_outputs, @@ -1133,29 +1136,32 @@ def convolution2d(inputs, outputs_collections=None, trainable=True, scope=None): - return convolution(inputs, - num_outputs, - kernel_size, - stride, - padding, - data_format, - rate, - activation_fn, - normalizer_fn, - normalizer_params, - weights_initializer, - weights_regularizer, - biases_initializer, - biases_regularizer, - reuse, - variables_collections, - outputs_collections, - trainable, - scope, - conv_dims=2) + return convolution( + inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=2) + convolution2d.__doc__ = convolution.__doc__ + @add_arg_scope def convolution3d(inputs, num_outputs, @@ -1176,29 +1182,32 @@ def convolution3d(inputs, outputs_collections=None, trainable=True, scope=None): - return convolution(inputs, - num_outputs, - kernel_size, - stride, - padding, - data_format, - rate, - activation_fn, - normalizer_fn, - normalizer_params, - weights_initializer, - weights_regularizer, - biases_initializer, - biases_regularizer, - reuse, - variables_collections, - outputs_collections, - trainable, - scope, - conv_dims=3) + return convolution( + inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=3) + convolution3d.__doc__ = convolution.__doc__ + @add_arg_scope def convolution2d_in_plane( inputs, @@ -1234,9 +1243,9 @@ def convolution2d_in_plane( inputs: A 4-D tensor with dimensions [batch_size, height, width, channels]. kernel_size: A list of length 2 holding the [kernel_height, kernel_width] of of the pooling. Can be an int if both values are the same. - stride: A list of length 2 `[stride_height, stride_width]`. 
- Can be an int if both strides are the same. Note that presently - both strides must have the same value. + stride: A list of length 2 `[stride_height, stride_width]`. Can be an int if + both strides are the same. Note that presently both strides must have the + same value. padding: The padding type to use, either 'SAME' or 'VALID'. activation_fn: Activation function. The default value is a ReLU function. Explicitly set it to None to skip it and maintain a linear activation. @@ -1332,15 +1341,15 @@ def convolution2d_transpose( second variable called 'biases' is added to the result of the operation. Args: - inputs: A 4-D `Tensor` of type `float` and shape - `[batch, height, width, in_channels]` for `NHWC` data format or - `[batch, in_channels, height, width]` for `NCHW` data format. + inputs: A 4-D `Tensor` of type `float` and shape `[batch, height, width, + in_channels]` for `NHWC` data format or `[batch, in_channels, height, + width]` for `NCHW` data format. num_outputs: Integer, the number of output filters. kernel_size: A list of length 2 holding the [kernel_height, kernel_width] of of the filters. Can be an int if both values are the same. - stride: A list of length 2: [stride_height, stride_width]. - Can be an int if both strides are the same. Note that presently - both strides must have the same value. + stride: A list of length 2: [stride_height, stride_width]. Can be an int if + both strides are the same. Note that presently both strides must have the + same value. padding: One of 'VALID' or 'SAME'. data_format: A string. `NHWC` (default) and `NCHW` are supported. activation_fn: Activation function. The default value is a ReLU function. @@ -1447,15 +1456,15 @@ def convolution3d_transpose( kernel, that is convolved with the input. If `batch_norm_params` is `None`, a second variable called 'biases' is added to the result of the operation. Args: - inputs: A 5-D `Tensor` of type `float` and shape - `[batch, depth, height, width, in_channels]` for `NDHWC` data format or - `[batch, in_channels, depth, height, width]` for `NCDHW` data format. + inputs: A 5-D `Tensor` of type `float` and shape `[batch, depth, height, + width, in_channels]` for `NDHWC` data format or `[batch, in_channels, + depth, height, width]` for `NCDHW` data format. num_outputs: Integer, the number of output filters. kernel_size: A list of length 3 holding the [kernel_depth, kernel_height, kernel_width] of the filters. Can be an int if both values are the same. - stride: A list of length 3: [stride_depth, stride_height, stride_width]. - Can be an int if both strides are the same. Note that presently - both strides must have the same value. + stride: A list of length 3: [stride_depth, stride_height, stride_width]. Can + be an int if both strides are the same. Note that presently both strides + must have the same value. padding: One of 'VALID' or 'SAME'. data_format: A string. `NDHWC` (default) and `NCDHW` are supported. activation_fn: Activation function. The default value is a ReLU function. @@ -1476,6 +1485,7 @@ def convolution3d_transpose( outputs_collections: Collection to add the outputs. trainable: Whether or not the variables should be trainable or not. scope: Optional scope for variable_scope. + Returns: A tensor representing the output of the operation. Raises: @@ -1543,8 +1553,8 @@ def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): Args: tensor: An `int` `Tensor` to be converted to a `Sparse`. - eos_token: An integer. 
- It is part of the target label that signifies the end of a sentence. + eos_token: An integer. It is part of the target label that signifies the + end of a sentence. outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. """ @@ -1575,17 +1585,17 @@ def dropout(inputs, Args: inputs: The tensor to pass to the nn.dropout op. - keep_prob: A scalar `Tensor` with the same type as x. The probability - that each element is kept. - noise_shape: A 1-D `Tensor` of type `int32`, representing the - shape for randomly generated keep/drop flags. - is_training: A bool `Tensor` indicating whether or not the model - is in training mode. If so, dropout is applied and values scaled. - Otherwise, inputs is returned. + keep_prob: A scalar `Tensor` with the same type as x. The probability that + each element is kept. + noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for + randomly generated keep/drop flags. + is_training: A bool `Tensor` indicating whether or not the model is in + training mode. If so, dropout is applied and values scaled. Otherwise, + inputs is returned. outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. seed: A Python integer. Used to create random seeds. See - `tf.set_random_seed` for behavior. + `tf.compat.v1.set_random_seed` for behavior. Returns: A tensor representing the output of the operation. @@ -1675,7 +1685,7 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): For example: ''' - x = tf.random_uniform(shape=[1, 2, 3, 4, 5, 6]) + x = tf.random.uniform(shape=[1, 2, 3, 4, 5, 6]) y = _inner_flatten(x, 4) assert y.get_shape().as_list() == [1, 2, 3, (4 * 5 * 6)] ''' @@ -1687,6 +1697,7 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): new_rank: The desired rank of the returned `Tensor` or `SparseTensor`. output_collections: Collection to which the outputs will be added. scope: Optional scope for `name_scope`. + Returns: A `Tensor` or `SparseTensor` containing the same values as `inputs`, but with innermost dimensions flattened to obtain rank `new_rank`. @@ -1824,8 +1835,8 @@ def fully_connected(inputs, ValueError: If x has rank less than 2 or if its last dimension is not set. """ if not isinstance(num_outputs, six.integer_types): - raise ValueError('num_outputs type should be one of %s, got %s.' % ( - list(six.integer_types), type(num_outputs))) + raise ValueError('num_outputs type should be one of %s, got %s.' % + (list(six.integer_types), type(num_outputs))) layer_variable_getter = _build_variable_getter({ 'bias': 'biases', @@ -1902,14 +1913,14 @@ class GDN(base.Layer): Arguments: inverse: If `False` (default), compute GDN response. If `True`, compute IGDN - response (one step of fixed point iteration to invert GDN; the division - is replaced by multiplication). + response (one step of fixed point iteration to invert GDN; the division is + replaced by multiplication). beta_min: Lower bound for beta, to prevent numerical error from causing square root of zero or negative values. gamma_init: The gamma matrix will be initialized as the identity matrix multiplied with this value. If set to zero, the layer is effectively - initialized to the identity operation, since beta is initialized as one. - A good default setting is somewhere between 0 and 0.5. + initialized to the identity operation, since beta is initialized as one. A + good default setting is somewhere between 0 and 0.5. 
reparam_offset: Offset added to the reparameterization of beta and gamma. The reparameterization of beta and gamma as their square roots lets the training slow down when their values are close to zero, which is desirable @@ -1926,10 +1937,8 @@ class GDN(base.Layer): activity_regularizer: Regularizer function for the output. trainable: Boolean, if `True`, also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: String, the name of the layer. Layers with the same name will - share weights, but to avoid mistakes we require `reuse=True` in such - cases. - + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require `reuse=True` in such cases. Properties: inverse: Boolean, whether GDN is computed (`True`) or IGDN (`False`). data_format: Format of input tensor. Currently supports `'channels_first'` @@ -1987,9 +1996,8 @@ class GDN(base.Layer): with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope: inputs = ops.convert_to_tensor(inputs, name='inputs') bound = ops.convert_to_tensor(bound, name='bound') - with ops.get_default_graph().gradient_override_map({ - 'Maximum': 'GDNLowerBound' - }): + with ops.get_default_graph().gradient_override_map( + {'Maximum': 'GDNLowerBound'}): return math_ops.maximum(inputs, bound, name=scope) @staticmethod @@ -2017,9 +2025,7 @@ class GDN(base.Layer): 'must be defined.') self._input_rank = input_shape.ndims self.input_spec = input_spec.InputSpec( - ndim=input_shape.ndims, axes={ - channel_axis: num_channels - }) + ndim=input_shape.ndims, axes={channel_axis: num_channels}) pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) beta_bound = array_ops.constant( @@ -2147,14 +2153,14 @@ def gdn(inputs, Args: inputs: Tensor input. inverse: If `False` (default), compute GDN response. If `True`, compute IGDN - response (one step of fixed point iteration to invert GDN; the division - is replaced by multiplication). + response (one step of fixed point iteration to invert GDN; the division is + replaced by multiplication). beta_min: Lower bound for beta, to prevent numerical error from causing square root of zero or negative values. gamma_init: The gamma matrix will be initialized as the identity matrix multiplied with this value. If set to zero, the layer is effectively - initialized to the identity operation, since beta is initialized as one. - A good default setting is somewhere between 0 and 0.5. + initialized to the identity operation, since beta is initialized as one. A + good default setting is somewhere between 0 and 0.5. reparam_offset: Offset added to the reparameterization of beta and gamma. The reparameterization of beta and gamma as their square roots lets the training slow down when their values are close to zero, which is desirable @@ -2171,9 +2177,8 @@ def gdn(inputs, activity_regularizer: Regularizer function for the output. trainable: Boolean, if `True`, also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: String, the name of the layer. Layers with the same name will - share weights, but to avoid mistakes we require `reuse=True` in such - cases. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require `reuse=True` in such cases. reuse: Boolean, whether to reuse the weights of a previous layer by the same name. 
@@ -2234,14 +2239,14 @@ def layer_norm(inputs, and this part of the inputs' shape must be fully defined. Args: - inputs: A tensor having rank `R`. The normalization is performed over - axes `begin_norm_axis ... R - 1` and centering and scaling parameters - are calculated over `begin_params_axis ... R - 1`. + inputs: A tensor having rank `R`. The normalization is performed over axes + `begin_norm_axis ... R - 1` and centering and scaling parameters are + calculated over `begin_params_axis ... R - 1`. center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is - not used. When the next layer is linear (also e.g. `nn.relu`), this can be - disabled since the scaling can be done by the next layer. + scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the + next layer is linear (also e.g. `nn.relu`), this can be disabled since the + scaling can be done by the next layer. activation_fn: Activation function, default set to None to skip it and maintain a linear activation. reuse: Whether or not the layer and its variables should be reused. To be @@ -2252,10 +2257,10 @@ def layer_norm(inputs, `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). begin_norm_axis: The first normalization dimension: normalization will be performed along dimensions `begin_norm_axis : rank(inputs)` - begin_params_axis: The first parameter (beta, gamma) dimension: scale - and centering parameters will have dimensions + begin_params_axis: The first parameter (beta, gamma) dimension: scale and + centering parameters will have dimensions `begin_params_axis : rank(inputs)` and will be broadcast with the - normalized inputs accordingly. + normalized inputs accordingly. scope: Optional scope for `variable_scope`. Returns: @@ -2380,11 +2385,11 @@ def max_pool2d(inputs, `data_format` is `NHWC`, and `[batch_size, channels, height, width]` if `data_format` is `NCHW`. kernel_size: A list of length 2: [kernel_height, kernel_width] of the - pooling kernel over which the op is computed. Can be an int if both - values are the same. - stride: A list of length 2: [stride_height, stride_width]. - Can be an int if both strides are the same. Note that presently - both strides must have the same value. + pooling kernel over which the op is computed. Can be an int if both values + are the same. + stride: A list of length 2: [stride_height, stride_width]. Can be an int if + both strides are the same. Note that presently both strides must have the + same value. padding: The padding method, either 'VALID' or 'SAME'. data_format: A string. `NHWC` (default) and `NCHW` are supported. outputs_collections: The collections to which the outputs are added. @@ -2432,9 +2437,9 @@ def max_pool3d(inputs, kernel_size: A list of length 3: [kernel_depth, kernel_height, kernel_width] of the pooling kernel over which the op is computed. Can be an int if both values are the same. - stride: A list of length 3: [stride_depth, stride_height, stride_width]. - Can be an int if both strides are the same. Note that presently - both strides must have the same value. + stride: A list of length 3: [stride_depth, stride_height, stride_width]. Can + be an int if both strides are the same. Note that presently both strides + must have the same value. padding: The padding method, either 'VALID' or 'SAME'. data_format: A string. `NDHWC` (default) and `NCDHW` are supported. outputs_collections: The collections to which the outputs are added. 
@@ -2478,9 +2483,8 @@ def pool(inputs, Args: - inputs: Tensor of rank N+2, of shape - `[batch_size] + input_spatial_shape + [num_channels]` if data_format does - not start with "NC" (default), or + inputs: Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape + + [num_channels]` if data_format does not start with "NC" (default), or `[batch_size, num_channels] + input_spatial_shape` if data_format starts with "NC". Pooling happens over the spatial dimensions only. kernel_size: Sequence of N ints >= 1. Can also be a single integer to @@ -2491,8 +2495,8 @@ def pool(inputs, the `input` and output is the last dimension (default, or if `data_format` does not start with "NC"), or the second dimension (if `data_format` starts with "NC"). For N=1, the valid values are "NWC" (default) and - "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". - For N=3, the valid values are "NDHWC" (default) and "NCDHW". + "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For + N=3, the valid values are "NDHWC" (default) and "NCDHW". dilation_rate: Optional. Dilation rate. Sequence of N ints >= 1. Defaults to [1]*N. Can also be a single integer to specify the same value for all spatial dimensions. If any value of dilation_rate is > 1, then all values @@ -2693,10 +2697,10 @@ def separable_convolution2d( Args: inputs: A tensor of size [batch_size, height, width, channels]. - num_outputs: The number of pointwise convolution output filters. If is - None, then we skip the pointwise convolution stage. - kernel_size: A list of length 2: [kernel_height, kernel_width] of - of the filters. Can be an int if both values are the same. + num_outputs: The number of pointwise convolution output filters. If is None, + then we skip the pointwise convolution stage. + kernel_size: A list of length 2: [kernel_height, kernel_width] of of the + filters. Can be an int if both values are the same. depth_multiplier: The number of depthwise convolution output channels for each input channel. The total number of depthwise convolution output channels will be equal to `num_filters_in * depth_multiplier`. @@ -2705,8 +2709,8 @@ def separable_convolution2d( padding: One of 'VALID' or 'SAME'. data_format: A string. `NHWC` (default) and `NCHW` are supported. rate: A list of length 2: [rate_height, rate_width], specifying the dilation - rates for atrous convolution. Can be an int if both rates are the same. - If any value is larger than one, then both stride values need to be one. + rates for atrous convolution. Can be an int if both rates are the same. If + any value is larger than one, then both stride values need to be one. activation_fn: Activation function. The default value is a ReLU function. Explicitly set it to None to skip it and maintain a linear activation. normalizer_fn: Normalization function to use instead of `biases`. If @@ -2715,8 +2719,8 @@ def separable_convolution2d( default set to None for no normalizer function normalizer_params: Normalization function parameters. weights_initializer: An initializer for the depthwise weights. - pointwise_initializer: An initializer for the pointwise weights. - default set to None, means use weights_initializer. + pointwise_initializer: An initializer for the pointwise weights. default set + to None, means use weights_initializer. weights_regularizer: Optional regularizer for the weights. biases_initializer: An initializer for the biases. If None skip biases. biases_regularizer: Optional regularizer for the biases. 
@@ -2810,10 +2814,9 @@ def separable_convolution2d( regularizer=weights_regularizer, trainable=trainable, collections=weights_collections) - strides = [1, 1, stride_h, - stride_w] if data_format.startswith('NC') else [ - 1, stride_h, stride_w, 1 - ] + strides = [ + 1, 1, stride_h, stride_w + ] if data_format.startswith('NC') else [1, stride_h, stride_w, 1] outputs = nn.depthwise_conv2d( inputs, @@ -2859,8 +2862,8 @@ def sequence_to_images(inputs, Args: inputs: (num_steps, num_batches, depth) sequence tensor height: the height of the images - output_data_format: Format of output tensor. - Currently supports `'channels_first'` and `'channels_last'`. + output_data_format: Format of output tensor. Currently supports + `'channels_first'` and `'channels_last'`. outputs_collections: The collections to which the outputs are added. scope: Optional scope for name_scope. @@ -2874,8 +2877,7 @@ def sequence_to_images(inputs, num_batches = -1 else: num_batches //= height - reshaped = array_ops.reshape(inputs, - [width, num_batches, height, depth]) + reshaped = array_ops.reshape(inputs, [width, num_batches, height, depth]) if output_data_format == 'channels_first': outputs = array_ops.transpose(reshaped, [1, 3, 2, 0]) else: @@ -2936,6 +2938,7 @@ def spatial_softmax(features, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). data_format: A string. `NHWC` (default) and `NCHW` are supported. + Returns: feature_keypoints: A `Tensor` with size [batch_size, num_channels * 2]; the expected 2D locations of each channel's feature keypoint (normalized @@ -2998,8 +3001,7 @@ def spatial_softmax(features, pos_y * softmax_attention, [1], keepdims=True) expected_xy = array_ops.concat([expected_x, expected_y], 1) feature_keypoints = array_ops.reshape( - expected_xy, - [-1, tensor_shape.dimension_value(num_channels) * 2]) + expected_xy, [-1, tensor_shape.dimension_value(num_channels) * 2]) feature_keypoints.set_shape( [None, tensor_shape.dimension_value(num_channels) * 2]) return feature_keypoints @@ -3112,11 +3114,11 @@ def maxout(inputs, num_units, axis=-1, scope=None): Arguments: inputs: Tensor input - num_units: Specifies how many features will remain after maxout - in the `axis` dimension (usually channel). - This must be a factor of number of features. - axis: The dimension where max pooling will be performed. Default is the - last dimension. + num_units: Specifies how many features will remain after maxout in the + `axis` dimension (usually channel). This must be a factor of number of + features. + axis: The dimension where max pooling will be performed. Default is the last + dimension. scope: Optional scope for variable_scope. Returns: @@ -3166,8 +3168,7 @@ def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): Args: x: A `Tensor`. - axis: Axis along which to normalize. A scalar or a vector of - integers. + axis: Axis along which to normalize. A scalar or a vector of integers. epsilon: A small deviation from the edge of the unit sphere for numerical stability. name: A name for this operation (optional). @@ -3217,8 +3218,9 @@ def legacy_fully_connected(x, This op creates `w` and optionally `b`. Bias (`b`) can be disabled by setting `bias_init` to `None`. - The variable creation is compatible with `tf.variable_scope` and so can be - reused with `tf.variable_scope` or `tf.make_template`. 
+ The variable creation is compatible with `tf.compat.v1.variable_scope` and so + can be + reused with `tf.compat.v1.variable_scope` or `tf.compat.v1.make_template`. Most of the details of variable creation can be controlled by specifying the initializers (`weight_init` and `bias_init`) and in which collections to place @@ -3245,16 +3247,16 @@ def legacy_fully_connected(x, name: The name for this operation is used to name operations and to find variables. If specified it must be unique for this scope, otherwise a unique name starting with "fully_connected" will be created. See - `tf.variable_scope` for details. + `tf.compat.v1.variable_scope` for details. weight_collections: List of graph collections to which weights are added. bias_collections: List of graph collections to which biases are added. output_collections: List of graph collections to which outputs are added. trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). - weight_regularizer: A regularizer like the result of - `l1_regularizer` or `l2_regularizer`. Used for weights. - bias_regularizer: A regularizer like the result of - `l1_regularizer` or `l2_regularizer`. Used for biases. + weight_regularizer: A regularizer like the result of `l1_regularizer` or + `l2_regularizer`. Used for weights. + bias_regularizer: A regularizer like the result of `l1_regularizer` or + `l2_regularizer`. Used for biases. Returns: The output of the fully connected layer. diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 2cd72410d3d..90fd55cf389 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1399,9 +1399,10 @@ class DropoutTest(test.TestCase): with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') - num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) + num_elem_initial = math_ops.reduce_mean( + math_ops.cast(images > 0, dtypes.float32)) output = _layers.dropout(images) - num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem = math_ops.reduce_mean(math_ops.cast(output > 0, dtypes.float32)) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertLess(num_elem, num_elem_initial / 2 + 0.1) self.assertGreater(num_elem, num_elem_initial / 2 - 0.1) @@ -1421,9 +1422,10 @@ class DropoutTest(test.TestCase): with self.cached_session() as sess: images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') - num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) + num_elem_initial = math_ops.reduce_mean( + math_ops.cast(images > 0, dtypes.float32)) output = _layers.dropout(images, is_training=False) - num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem = math_ops.reduce_mean(math_ops.cast(output > 0, dtypes.float32)) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertEqual(num_elem, num_elem_initial) outputs, inputs = sess.run([output, images]) @@ -1435,9 +1437,10 @@ class DropoutTest(test.TestCase): images = random_ops.random_uniform( (5, height, width, 3), seed=1, name='images') output = _layers.fully_connected(images, 50) - num_elem_initial = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem_initial = math_ops.reduce_mean( + math_ops.cast(output > 0, dtypes.float32)) output = _layers.dropout(output) - num_elem = 
math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem = math_ops.reduce_mean(math_ops.cast(output > 0, dtypes.float32)) sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertLess(num_elem, num_elem_initial / 2 + 0.1) @@ -1450,7 +1453,7 @@ class DropoutTest(test.TestCase): (5, height, width, 3), seed=1, name='images') output = _layers.fully_connected( images, 50, normalizer_fn=_layers.dropout) - num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) + num_elem = math_ops.reduce_mean(math_ops.cast(output > 0, dtypes.float32)) sess.run(variables_lib.global_variables_initializer()) num_elem = sess.run(num_elem) self.assertLess(num_elem, 0.5) diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index 3b075875035..2c18bfa7b91 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -21,6 +21,7 @@ from __future__ import print_function import six from tensorflow.contrib import framework as contrib_framework +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -75,64 +76,62 @@ def optimize_loss(loss, for full list. E.g. `optimize_loss(..., optimizer='Adam')`. - by function taking learning rate `Tensor` as argument and returning an `Optimizer` instance. E.g. `optimize_loss(..., - optimizer=lambda lr: tf.train.MomentumOptimizer(lr, momentum=0.5))`. + optimizer=lambda lr: tf.compat.v1.train.MomentumOptimizer(lr, + momentum=0.5))`. Alternatively, if `learning_rate` is `None`, the function takes no arguments. E.g. `optimize_loss(..., learning_rate=None, - optimizer=lambda: tf.train.MomentumOptimizer(0.5, momentum=0.5))`. + optimizer=lambda: tf.compat.v1.train.MomentumOptimizer(0.5, + momentum=0.5))`. - by a subclass of `Optimizer` having a single-argument constructor (the argument is the learning rate), such as AdamOptimizer or AdagradOptimizer. E.g. `optimize_loss(..., - optimizer=tf.train.AdagradOptimizer)`. + optimizer=tf.compat.v1.train.AdagradOptimizer)`. - by an instance of a subclass of `Optimizer`. - E.g., `optimize_loss(..., optimizer=tf.train.AdagradOptimizer(0.5))`. + E.g., `optimize_loss(..., + optimizer=tf.compat.v1.train.AdagradOptimizer(0.5))`. Args: loss: Scalar `Tensor`. - global_step: Scalar int `Tensor`, step counter to update on each step - unless `increment_global_step` is `False`. If not supplied, - it will be fetched from the default graph (see - `tf.train.get_global_step` for details). If it has - not been created, no step will be incremented with each weight - update. `learning_rate_decay_fn` requires `global_step`. + global_step: Scalar int `Tensor`, step counter to update on each step unless + `increment_global_step` is `False`. If not supplied, it will be fetched + from the default graph (see `tf.compat.v1.train.get_global_step` for + details). If it has not been created, no step will be incremented with + each weight update. `learning_rate_decay_fn` requires `global_step`. learning_rate: float or `Tensor`, magnitude of update per each training - step. Can be `None`. - optimizer: string, class or optimizer instance, used as trainer. - string should be name of optimizer, like 'SGD', - 'Adam', 'Adagrad'. Full list in OPTIMIZER_CLS_NAMES constant. 
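The `optimize_loss` docstring above accepts the `optimizer` argument in several forms (name string, optimizer class, optimizer instance, or a callable taking the learning rate). A minimal sketch of one of those forms, assuming TF 1.x since `tf.contrib` is required; the loss and values are illustrative:

```python
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers  # tf.contrib exists only in TF 1.x

graph = tf.Graph()
with graph.as_default():
    global_step = tf.train.get_or_create_global_step()
    x = tf.constant([[1.0], [2.0]])
    w = tf.get_variable("w", shape=[1, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - 3.0))

    # Other accepted forms for `optimizer`:
    #   optimizer="SGD"
    #   optimizer=tf.train.AdagradOptimizer
    #   optimizer=tf.train.AdagradOptimizer(0.5)
    train_op = contrib_layers.optimize_loss(
        loss=loss,
        global_step=global_step,
        learning_rate=0.1,
        optimizer=lambda lr: tf.train.MomentumOptimizer(lr, momentum=0.5))
```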
- class should be sub-class of `tf.Optimizer` that implements - `compute_gradients` and `apply_gradients` functions. - optimizer instance should be instantiation of `tf.Optimizer` - sub-class and have `compute_gradients` and `apply_gradients` - functions. + step. Can be `None`. + optimizer: string, class or optimizer instance, used as trainer. string + should be name of optimizer, like 'SGD', 'Adam', 'Adagrad'. Full list in + OPTIMIZER_CLS_NAMES constant. class should be sub-class of `tf.Optimizer` + that implements `compute_gradients` and `apply_gradients` functions. + optimizer instance should be instantiation of `tf.Optimizer` sub-class and + have `compute_gradients` and `apply_gradients` functions. gradient_noise_scale: float or None, adds 0-mean normal noise scaled by this - value. - gradient_multipliers: dict of variables or variable names to floats. - If present, gradients for specified - variables will be multiplied by given constant. + value. + gradient_multipliers: dict of variables or variable names to floats. If + present, gradients for specified variables will be multiplied by given + constant. clip_gradients: float, callable or `None`. If a float is provided, a global clipping is applied to prevent the norm of the gradient from exceeding this value. Alternatively, a callable can be provided, e.g., - `adaptive_clipping_fn()`. This callable takes a list of - `(gradients, variables)` tuples and returns the same thing with the - gradients modified. + `adaptive_clipping_fn()`. This callable takes a list of `(gradients, + variables)` tuples and returns the same thing with the gradients modified. learning_rate_decay_fn: function, takes `learning_rate` and `global_step` - `Tensor`s, returns `Tensor`. - Can be used to implement any learning rate decay - functions. - For example: `tf.train.exponential_decay`. - Ignored if `learning_rate` is not supplied. + `Tensor`s, returns `Tensor`. Can be used to implement any learning rate + decay functions. + For example: `tf.compat.v1.train.exponential_decay`. + Ignored if `learning_rate` is not supplied. update_ops: list of update `Operation`s to execute at each step. If `None`, - uses elements of UPDATE_OPS collection. The order of execution - between `update_ops` and `loss` is non-deterministic. - variables: list of variables to optimize or - `None` to use all trainable variables. + uses elements of UPDATE_OPS collection. The order of execution between + `update_ops` and `loss` is non-deterministic. + variables: list of variables to optimize or `None` to use all trainable + variables. name: The name for this operation is used to scope operations and summaries. summaries: List of internal quantities to visualize on tensorboard. If not - set, the loss, the learning rate, and the global norm of the - gradients will be reported. The complete list of possible values - is in OPTIMIZER_SUMMARIES. + set, the loss, the learning rate, and the global norm of the gradients + will be reported. The complete list of possible values is in + OPTIMIZER_SUMMARIES. colocate_gradients_with_ops: If True, try colocating gradients with the - corresponding op. + corresponding op. increment_global_step: Whether to increment `global_step`. If your model calls `optimize_loss` multiple times per training step (e.g. to optimize different parts of the model), use this arg to avoid incrementing @@ -181,8 +180,8 @@ def optimize_loss(loss, initializer=init_ops.constant_initializer(learning_rate)) else: raise ValueError("Learning rate should be 0d Tensor or float. 
" - "Got %s of type %s" % (str(learning_rate), - str(type(learning_rate)))) + "Got %s of type %s" % + (str(learning_rate), str(type(learning_rate)))) if summaries is None: summaries = ["loss", "learning_rate", "global_gradient_norm"] else: @@ -262,8 +261,8 @@ def optimize_loss(loss, elif callable(clip_gradients): gradients = clip_gradients(gradients) elif clip_gradients is not None: - raise ValueError( - "Unknown type %s for clip_gradients" % type(clip_gradients)) + raise ValueError("Unknown type %s for clip_gradients" % + type(clip_gradients)) # Add scalar summary for loss. if "loss" in summaries: @@ -325,7 +324,7 @@ def _adaptive_max_norm(norm, std_factor, decay, global_step, epsilon, name): # quicker adaptation at the beginning if global_step is not None: - n = math_ops.to_float(global_step) + n = math_ops.cast(global_step, dtypes.float32) decay = math_ops.minimum(decay, n / (n + 1.)) # update averages @@ -355,8 +354,8 @@ def adaptive_clipping_fn(std_factor=2., rescaled such that the global norm becomes `exp(mean)`. Args: - std_factor: Python scaler (or tensor). - `max_norm = exp(mean + std_factor*std)` + std_factor: Python scaler (or tensor). `max_norm = exp(mean + + std_factor*std)` decay: The smoothing factor of the moving averages. static_max_norm: If provided, will threshold the norm to this value as an extra safety. @@ -384,8 +383,7 @@ def adaptive_clipping_fn(std_factor=2., summary.scalar("global_norm/adaptive_max_gradient_norm", max_norm) # factor will be 1. if norm is smaller than max_norm - factor = array_ops.where(norm < max_norm, - array_ops.ones_like(norm), + factor = array_ops.where(norm < max_norm, array_ops.ones_like(norm), math_ops.exp(log_mean) / norm) if static_max_norm is not None: diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 5234869718b..131b1e0dba2 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -23,6 +23,7 @@ import six from tensorflow.contrib.framework import deprecated from tensorflow.contrib.losses.python.losses import loss_ops from tensorflow.contrib.metrics.python.ops import metric_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -185,7 +186,8 @@ class _TargetColumn(object): return None else: return array_ops.reshape( - math_ops.to_float(features[self._weight_column_name]), shape=(-1,)) + math_ops.cast(features[self._weight_column_name], dtypes.float32), + shape=(-1,)) @property def problem_type(self): @@ -252,9 +254,10 @@ class _TargetColumn(object): if weight_tensor is None: return math_ops.reduce_mean(loss_unweighted, name="loss") loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor) - return math_ops.div(math_ops.reduce_sum(loss_weighted), - math_ops.to_float(math_ops.reduce_sum(weight_tensor)), - name="loss") + return math_ops.div( + math_ops.reduce_sum(loss_weighted), + math_ops.cast(math_ops.reduce_sum(weight_tensor), dtypes.float32), + name="loss") class _RegressionTargetColumn(_TargetColumn): @@ -323,7 +326,7 @@ class _MultiClassTargetColumn(_TargetColumn): metrics = {("accuracy", "classes"): metric_ops.streaming_accuracy} predictions = math_ops.sigmoid(logits) - labels_float = math_ops.to_float(labels) + labels_float = math_ops.cast(labels, dtypes.float32) default_metrics = self._default_eval_metrics() for metric_name, 
metric_op in default_metrics.items(): @@ -399,7 +402,8 @@ def _mean_squared_loss(logits, target): target = array_ops.expand_dims(target, axis=1) logits.get_shape().assert_is_compatible_with(target.get_shape()) - return math_ops.squared_difference(logits, math_ops.to_float(target)) + return math_ops.squared_difference(logits, + math_ops.cast(target, dtypes.float32)) def _log_loss_with_two_classes(logits, target): @@ -407,7 +411,7 @@ def _log_loss_with_two_classes(logits, target): if len(target.get_shape()) == 1: target = array_ops.expand_dims(target, axis=1) loss_vec = nn.sigmoid_cross_entropy_with_logits( - labels=math_ops.to_float(target), logits=logits) + labels=math_ops.cast(target, dtypes.float32), logits=logits) return loss_vec @@ -475,7 +479,7 @@ def get_default_binary_metrics_for_eval(thresholds): def _float_weights_or_none(weights): if weights is None: return None - return math_ops.to_float(weights) + return math_ops.cast(weights, dtypes.float32) def _labels_streaming_mean(unused_predictions, labels, weights=None): @@ -494,8 +498,8 @@ def _streaming_auc(predictions, labels, weights=None): def _accuracy_at_threshold(threshold): def _accuracy_metric(predictions, labels, weights=None): - threshold_predictions = math_ops.to_float( - math_ops.greater_equal(predictions, threshold)) + threshold_predictions = math_ops.cast( + math_ops.greater_equal(predictions, threshold), dtypes.float32) return metric_ops.streaming_accuracy( predictions=threshold_predictions, labels=labels, weights=weights) diff --git a/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py b/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py index 91684dc61e4..934a7f06069 100644 --- a/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py +++ b/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py @@ -86,11 +86,11 @@ def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0, internal_type = dtypes.string for i in range(len(values)): if values[i].dtype != dtypes.string: - values[i] = math_ops.to_int64(values[i]) + values[i] = math_ops.cast(values[i], dtypes.int64) internal_type = dtypes.int64 for i in range(len(dense_inputs)): if dense_inputs[i].dtype != dtypes.string: - dense_inputs[i] = math_ops.to_int64(dense_inputs[i]) + dense_inputs[i] = math_ops.cast(dense_inputs[i], dtypes.int64) internal_type = dtypes.int64 if hash_key: diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 4749371248e..1d0cac308f3 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -261,6 +261,9 @@ py_test( name = "tensor_signature_test", srcs = ["python/learn/estimators/tensor_signature_test.py"], srcs_version = "PY2AND3", + tags = [ + "manual", # b/130760310 + ], deps = [ ":learn", "//tensorflow/python:array_ops", @@ -387,6 +390,13 @@ py_test( shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 + deps = [":head_test_lib"], +) + +py_library( + name = "head_test_lib", + srcs = ["python/learn/estimators/head_test.py"], + srcs_version = "PY2AND3", deps = [ ":learn", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/learn/python/learn/datasets/BUILD b/tensorflow/contrib/learn/python/learn/datasets/BUILD index 2c7215bba38..d6a43ee3a69 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/BUILD +++ b/tensorflow/contrib/learn/python/learn/datasets/BUILD @@ -37,6 +37,7 @@ py_library( py_binary( name = "produce_small_datasets", srcs = ["produce_small_datasets.py"], + python_version = "PY2", 
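The `target_column` and `sparse_feature_cross_op` hunks above replace the deprecated `math_ops.to_float` / `math_ops.to_int64` helpers with an explicit `math_ops.cast` and a `dtypes` argument. A minimal sketch of the same change in terms of the public API; the tensor values are illustrative:

```python
import tensorflow.compat.v1 as tf

x = tf.constant([1, 2, 3], dtype=tf.int32)

# Before: math_ops.to_float(x)  /  math_ops.to_int64(x)
# After:  math_ops.cast(x, dtypes.float32)  /  math_ops.cast(x, dtypes.int64)
x_float = tf.cast(x, tf.float32)
x_int64 = tf.cast(x, tf.int64)
```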
srcs_version = "PY2AND3", deps = [ ":datasets", @@ -48,6 +49,7 @@ py_test( name = "base_test", size = "small", srcs = ["base_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":datasets", @@ -59,6 +61,7 @@ py_test( name = "load_csv_test", size = "small", srcs = ["load_csv_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":datasets", @@ -70,6 +73,7 @@ py_test( name = "synthetic_test", size = "small", srcs = ["synthetic_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":datasets", diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 3e64595f312..ce644dde04f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -101,7 +101,7 @@ my_features = [embedding_feature_a, embedding_feature_b] estimator = DNNClassifier( feature_columns=my_features, hidden_units=[1024, 512, 256], - optimizer=tf.train.ProximalAdagradOptimizer( + optimizer=tf.compat.v1.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001 )) @@ -123,7 +123,7 @@ hidden_units=[1024, 512, 256]) estimator = DNNRegressor( feature_columns=my_features, hidden_units=[1024, 512, 256], - optimizer=tf.train.ProximalAdagradOptimizer( + optimizer=tf.compat.v1.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001 )) @@ -145,11 +145,11 @@ estimator = DNNLinearCombinedClassifier( weight_column_name=weight_column_name, # Wide settings linear_feature_columns=my_linear_features, - linear_optimizer=tf.train.FtrlOptimizer(...), + linear_optimizer=tf.compat.v1.train.FtrlOptimizer(...), # Deep settings dnn_feature_columns=my_deep_features, dnn_hidden_units=[1000, 500, 100], - dnn_optimizer=tf.train.AdagradOptimizer(...)) + dnn_optimizer=tf.compat.v1.train.AdagradOptimizer(...)) ``` #### LinearClassifier @@ -161,7 +161,7 @@ classes. When number of possible classes is 2, this is binary classification. my_features = [sparse_feature_b, crossed_feature_a_x_b] estimator = LinearClassifier( feature_columns=my_features, - optimizer=tf.train.FtrlOptimizer( + optimizer=tf.compat.v1.train.FtrlOptimizer( learning_rate=0.1, l1_regularization_strength=0.001 )) diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py index 4c206839300..99f22d182cd 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py @@ -50,7 +50,7 @@ class _BaseEstimator(object): params : mapping of string to any Parameter names mapped to their values. """ - out = dict() + out = {} param_names = [name for name in self.__dict__ if not name.startswith('_')] for key in param_names: value = getattr(self, key, None) diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py index b968aeed1b7..ab0ce6d581a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py @@ -474,7 +474,7 @@ class DebugClassifierTest(test.TestCase): def _my_metric_op(predictions, labels): # For the case of binary classification, the 2nd column of "predictions" # denotes the model predictions. 
- labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) predictions = array_ops.strided_slice( predictions, [0, 1], [-1, 2], end_mask=1) labels = math_ops.cast(labels, predictions.dtype) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index 10fbd60ba2d..7e4e3a8d287 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -241,7 +241,7 @@ class DNNClassifier(estimator.Estimator): estimator = DNNClassifier( feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], hidden_units=[1024, 512, 256], - optimizer=tf.train.ProximalAdagradOptimizer( + optimizer=tf.compat.v1.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001 )) @@ -554,7 +554,7 @@ class DNNRegressor(estimator.Estimator): estimator = DNNRegressor( feature_columns=[sparse_feature_a, sparse_feature_b], hidden_units=[1024, 512, 256], - optimizer=tf.train.ProximalAdagradOptimizer( + optimizer=tf.compat.v1.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001 )) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 2ade6b7b6ce..5d09ac4069b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -525,11 +525,11 @@ class DNNLinearCombinedClassifier(estimator.Estimator): weight_column_name=weight_column_name, # wide settings linear_feature_columns=[sparse_feature_a_x_sparse_feature_b], - linear_optimizer=tf.train.FtrlOptimizer(...), + linear_optimizer=tf.compat.v1.train.FtrlOptimizer(...), # deep settings dnn_feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], dnn_hidden_units=[1000, 500, 100], - dnn_optimizer=tf.train.AdagradOptimizer(...)) + dnn_optimizer=tf.compat.v1.train.AdagradOptimizer(...)) # Input builders def input_fn_train: # returns x, y (where y represents label's class index). @@ -870,14 +870,14 @@ class DNNLinearCombinedRegressor(estimator.Estimator): weight_column_name=weight_column_name, # wide settings linear_feature_columns=[sparse_feature_a_x_sparse_feature_b], - linear_optimizer=tf.train.FtrlOptimizer(...), + linear_optimizer=tf.compat.v1.train.FtrlOptimizer(...), # deep settings dnn_feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], dnn_hidden_units=[1000, 500, 100], - dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) + dnn_optimizer=tf.compat.v1.train.ProximalAdagradOptimizer(...)) # To apply L1 and L2 regularization, you can set optimizers as follows: - tf.train.ProximalAdagradOptimizer( + tf.compat.v1.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index d46a873bfaa..4f636ce69dd 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -807,7 +807,7 @@ class DNNLinearCombinedClassifierTest(test.TestCase): def _my_metric_op(predictions, labels): # For the case of binary classification, the 2nd column of "predictions" # denotes the model predictions. 
- labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) predictions = array_ops.strided_slice( predictions, [0, 1], [-1, 2], end_mask=1) return math_ops.reduce_sum(math_ops.multiply(predictions, labels)) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index ee25cebd484..d779495720b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -815,7 +815,7 @@ class DNNClassifierTest(test.TestCase): def _my_metric_op(predictions, labels): # For the case of binary classification, the 2nd column of "predictions" # denotes the model predictions. - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) predictions = array_ops.strided_slice( predictions, [0, 1], [-1, 2], end_mask=1) labels = math_ops.cast(labels, predictions.dtype) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index c3e9e3af942..7a96f6d3ea4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -372,9 +372,10 @@ class DynamicRnnEstimatorTest(test.TestCase): labels = array_ops.slice(random_sequence, [0, 0], [batch_size, sequence_length]) inputs = array_ops.expand_dims( - math_ops.to_float( + math_ops.cast( array_ops.slice(random_sequence, [0, 1], - [batch_size, sequence_length])), 2) + [batch_size, sequence_length]), + dtypes.float32), 2) input_dict = { dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform( [batch_size, cell_size], seed=((i + 1) * seed)) @@ -430,9 +431,10 @@ class DynamicRnnEstimatorTest(test.TestCase): labels = array_ops.slice(sequence, [0, 0], [batch_size, sequence_length]) inputs = array_ops.expand_dims( - math_ops.to_float( + math_ops.cast( array_ops.slice(sequence, [0, 1], [batch_size, sequence_length - ])), 2) + ]), + dtypes.float32), 2) input_dict = state_dict input_dict['inputs'] = inputs return input_dict, labels @@ -587,9 +589,11 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): labels = array_ops.slice(random_sequence, [0, 0], [batch_size, sequence_length]) inputs = array_ops.expand_dims( - math_ops.to_float( + math_ops.cast( array_ops.slice(random_sequence, [0, 1], - [batch_size, sequence_length])), 2) + [batch_size, sequence_length]), + dtypes.float32), + 2) return {'inputs': inputs}, labels return input_fn @@ -719,11 +723,13 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): def input_fn(): random_sequence = random_ops.random_uniform( [batch_size, sequence_length], 0, 2, dtype=dtypes.int32, seed=seed) - inputs = array_ops.expand_dims(math_ops.to_float(random_sequence), 2) - labels = math_ops.to_int32( + inputs = array_ops.expand_dims( + math_ops.cast(random_sequence, dtypes.float32), 2) + labels = math_ops.cast( array_ops.squeeze( math_ops.reduce_sum(inputs, axis=[1]) > ( - sequence_length / 2.0))) + sequence_length / 2.0)), + dtypes.int32) return {'inputs': inputs}, labels return input_fn diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index cbcae338a0a..153d4867961 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ 
b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -220,7 +220,7 @@ def _build_estimator_for_export_tests(tmpdir): hashtable = lookup.HashTable( lookup.TextFileStringTableInitializer(vocab_file_name), 'x') features['bogus_lookup'] = hashtable.lookup( - math_ops.to_int64(features['feature'])) + math_ops.cast(features['feature'], dtypes.int64)) return input_fn_utils.InputFnOps(features, labels, inputs) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 4bb14a6e63b..0dd835f8fb5 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -126,7 +126,7 @@ class Head(object): scope=...) if mode == tf.contrib.learn.ModeKeys.TRAIN: optimizer = ... - sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...) + sync = tf.compat.v1.train.SyncReplicasOptimizer(opt=optimizer, ...) update_op = tf.contrib.layers.optimize_loss(optimizer=sync, loss=model_fn_ops.loss, ...) hooks = [sync.make_session_run_hook(is_chief)] @@ -568,7 +568,7 @@ def _mean_squared_loss(labels, logits, weights=None): logits = array_ops.expand_dims(logits, axis=1) logits.get_shape().assert_is_compatible_with(labels.get_shape()) loss = math_ops.squared_difference( - logits, math_ops.to_float(labels), name=name) + logits, math_ops.cast(labels, dtypes.float32), name=name) return _compute_weighted_loss(loss, weights) @@ -793,7 +793,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None): with ops.name_scope(None, "log_loss_with_two_classes", (logits, labels)) as name: logits = ops.convert_to_tensor(logits) - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) # TODO(ptucker): This will break for dynamic shapes. # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. 
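The `Head` docstring above wraps the training optimizer in `tf.compat.v1.train.SyncReplicasOptimizer` so gradients are aggregated across workers before a single update is applied. A minimal sketch of that wrapping; the base optimizer and replica counts are illustrative:

```python
import tensorflow.compat.v1 as tf

base_optimizer = tf.train.AdagradOptimizer(0.1)

# Aggregate gradients from 4 replicas before applying one update.
sync_optimizer = tf.train.SyncReplicasOptimizer(
    base_optimizer, replicas_to_aggregate=4, total_num_replicas=4)

# The chief worker additionally runs the hook that sets up the sync queues.
hooks = [sync_optimizer.make_session_run_hook(is_chief=True)]
```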
if len(labels.get_shape()) == 1: @@ -1214,8 +1214,8 @@ def _sparse_labels_to_indicator(labels, num_classes): if num_classes < 2: raise ValueError("Must set num_classes >= 2 when passing labels as a " "SparseTensor.") - return math_ops.to_int64( - sparse_ops.sparse_to_indicator(labels, num_classes)) + return math_ops.cast( + sparse_ops.sparse_to_indicator(labels, num_classes), dtypes.int64) return labels @@ -1400,8 +1400,9 @@ class _MultiLabelHead(_SingleHead): math_ops.sigmoid( logits, name=prediction_key.PredictionKey.PROBABILITIES), prediction_key.PredictionKey.CLASSES: - math_ops.to_int64( + math_ops.cast( math_ops.greater(logits, 0), + dtypes.int64, name=prediction_key.PredictionKey.CLASSES) } @@ -1783,7 +1784,7 @@ def _weight_tensor(features, weight_column_name): raise ValueError("Weights {} missing from features.".format( weight_column_name)) with ops.name_scope(None, "weight_tensor", tuple(six.itervalues(features))): - weight_tensor = math_ops.to_float(features[weight_column_name]) + weight_tensor = math_ops.cast(features[weight_column_name], dtypes.float32) shape = weight_tensor.get_shape() rank = shape.ndims # We don't bother with expanding dims of non-staticly shaped tensors or @@ -1833,7 +1834,7 @@ def _compute_weighted_loss(loss_unweighted, weight, name="loss"): weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope) weighted_loss_normalized = math_ops.div( math_ops.reduce_sum(weighted_loss), - math_ops.to_float(math_ops.reduce_sum(weight)), + math_ops.cast(math_ops.reduce_sum(weight), dtypes.float32), name="weighted_average_loss") return weighted_loss_mean, weighted_loss_normalized @@ -1952,7 +1953,7 @@ def _sigmoid_cross_entropy_loss(labels, logits, weights=None): (logits, labels)) as name: # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels. 
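As the comments here note, `sigmoid_cross_entropy_with_logits` expects float labels with the same shape and dtype as the logits, which is why these hunks cast integer labels to `float32` before the call. A minimal sketch with illustrative values:

```python
import tensorflow.compat.v1 as tf

logits = tf.constant([[0.5], [-1.2]])
labels = tf.constant([[1], [0]])  # integer class labels

# Cast replaces the removed math_ops.to_float on the labels.
loss_vec = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.cast(labels, tf.float32), logits=logits)
```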
loss = nn.sigmoid_cross_entropy_with_logits( - labels=math_ops.to_float(labels), logits=logits, name=name) + labels=math_ops.cast(labels, dtypes.float32), logits=logits, name=name) return _compute_weighted_loss(loss, weights) @@ -1960,11 +1961,11 @@ def _float_weights_or_none(weights): if weights is None: return None with ops.name_scope(None, "float_weights", (weights,)) as name: - return math_ops.to_float(weights, name=name) + return math_ops.cast(weights, dtypes.float32, name=name) def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): - labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) weights = _float_weights_or_none(weights) if weights is not None: weights = weights_broadcast_ops.broadcast_weights(weights, labels) @@ -1978,7 +1979,7 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None): def _predictions_streaming_mean(predictions, weights=None, class_id=None): - predictions = math_ops.to_float(predictions) + predictions = math_ops.cast(predictions, dtypes.float32) weights = _float_weights_or_none(weights) if weights is not None: weights = weights_broadcast_ops.broadcast_weights(weights, predictions) @@ -2002,9 +2003,9 @@ def _class_predictions_streaming_mean(predictions, weights, class_id): return metrics_lib.mean( array_ops.where( math_ops.equal( - math_ops.to_int32(class_id), math_ops.to_int32(predictions)), - array_ops.ones_like(predictions), - array_ops.zeros_like(predictions)), + math_ops.cast(class_id, dtypes.int32), + math_ops.cast(predictions, dtypes.int32)), + array_ops.ones_like(predictions), array_ops.zeros_like(predictions)), weights=weights) @@ -2012,15 +2013,16 @@ def _class_labels_streaming_mean(labels, weights, class_id): return metrics_lib.mean( array_ops.where( math_ops.equal( - math_ops.to_int32(class_id), math_ops.to_int32(labels)), - array_ops.ones_like(labels), array_ops.zeros_like(labels)), + math_ops.cast(class_id, dtypes.int32), + math_ops.cast(labels, dtypes.int32)), array_ops.ones_like(labels), + array_ops.zeros_like(labels)), weights=weights) def _streaming_auc(predictions, labels, weights=None, class_id=None, curve="ROC"): # pylint: disable=missing-docstring - predictions = math_ops.to_float(predictions) + predictions = math_ops.cast(predictions, dtypes.float32) if labels.dtype.base_dtype != dtypes.bool: logging.warning("Casting %s labels to bool.", labels.dtype) labels = math_ops.cast(labels, dtypes.bool) @@ -2047,8 +2049,8 @@ def _assert_class_id(class_id, num_classes=None): def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold): - threshold_predictions = math_ops.to_float( - math_ops.greater_equal(predictions, threshold)) + threshold_predictions = math_ops.cast( + math_ops.greater_equal(predictions, threshold), dtypes.float32) return metrics_lib.accuracy(labels, threshold_predictions, weights) diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 9ee8d8004bf..d2ee0ebfa48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -334,7 +334,7 @@ class LinearClassifier(estimator.Estimator): # Or estimator using the FTRL optimizer with regularization. 
estimator = LinearClassifier( feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b], - optimizer=tf.train.FtrlOptimizer( + optimizer=tf.compat.v1.train.FtrlOptimizer( learning_rate=0.1, l1_regularization_strength=0.001 )) diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py index 3cbcc6e98de..8981432f7f2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py @@ -31,6 +31,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import metric_key from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib +from tensorflow.python.framework import dtypes from tensorflow.python.ops import math_ops @@ -160,8 +161,9 @@ def _make_logistic_eval_metric_ops(labels, predictions, thresholds): labels=labels_tensor, predictions=predictions) for threshold in thresholds: - predictions_at_threshold = math_ops.to_float( + predictions_at_threshold = math_ops.cast( math_ops.greater_equal(predictions, threshold), + dtypes.float32, name='predictions_at_threshold_%f' % threshold) metrics[metric_key.MetricKey.ACCURACY_MEAN % threshold] = ( metrics_lib.streaming_accuracy(labels=labels_tensor, diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 96adc8b83b5..5ce5c02cc63 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -131,7 +131,7 @@ class ModelFnOps( run on the chief worker during training. training_hooks: A list of `SessionRunHook` objects that will be run on all workers during training. - scaffold: A `tf.train.Scaffold` object that can be used to set + scaffold: A `tf.compat.v1.train.Scaffold` object that can be used to set initialization, saver, and more to be used in training. Returns: @@ -220,7 +220,7 @@ class ModelFnOps( on. Pass the key of the output alternative here that you want to designate as default. A separate ExportOutpout for this default head will be added to the export_outputs dict with the special key - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is + saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, unless there is already an enry in output_alternatives with this special key. 
Returns: diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 08f23aa2231..b51ea30959e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -303,6 +303,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): # just manually initialize this field: self._train_distribute = None self._eval_distribute = None + self._experimental_max_worker_delay_secs = None self._device_fn = None gpu_options = config_pb2.GPUOptions( diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index 06c61554fa2..0689be88c5e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -396,8 +396,9 @@ class StateSavingRnnEstimatorTest(test.TestCase): random_sequence = random_ops.random_uniform( [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) labels = array_ops.slice(random_sequence, [0], [sequence_length]) - inputs = math_ops.to_float( - array_ops.slice(random_sequence, [1], [sequence_length])) + inputs = math_ops.cast( + array_ops.slice(random_sequence, [1], [sequence_length]), + dtypes.float32) features = {'inputs': inputs} if mode == model_fn_lib.ModeKeys.INFER: @@ -450,8 +451,9 @@ class LegacyConstructorTest(test.TestCase): random_sequence = random_ops.random_uniform( [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) labels = array_ops.slice(random_sequence, [0], [sequence_length]) - inputs = math_ops.to_float( - array_ops.slice(random_sequence, [1], [sequence_length])) + inputs = math_ops.cast( + array_ops.slice(random_sequence, [1], [sequence_length]), + dtypes.float32) return {'inputs': inputs}, labels return input_fn @@ -537,8 +539,9 @@ class StateSavingRNNEstimatorLearningTest(test.TestCase): random_sequence = random_ops.random_uniform( [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) labels = array_ops.slice(random_sequence, [0], [sequence_length]) - inputs = math_ops.to_float( - array_ops.slice(random_sequence, [1], [sequence_length])) + inputs = math_ops.cast( + array_ops.slice(random_sequence, [1], [sequence_length]), + dtypes.float32) return {'inputs': inputs}, labels return input_fn diff --git a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py index 71b5658dd17..9eccb2ed185 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py +++ b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """TensorSignature class and utilities (deprecated). This module and all its submodules are deprecated. See @@ -34,8 +33,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops -class TensorSignature(collections.namedtuple( - "TensorSignature", ["dtype", "shape", "is_sparse"])): +class TensorSignature( + collections.namedtuple("TensorSignature", ["dtype", "shape", "is_sparse"])): """Signature of the `Tensor` object. THIS CLASS IS DEPRECATED. 
See @@ -47,7 +46,7 @@ class TensorSignature(collections.namedtuple( Example: ```python - examples = tf.placeholder(...) + examples = tf.compat.v1.placeholder(...) inputs = {'a': var_a, 'b': var_b} signatures = tensor_signature.create_signatures(inputs) result = tensor_signature.create_example_parser_from_signatures( @@ -94,8 +93,8 @@ class TensorSignature(collections.namedtuple( def get_placeholder(self): if self.is_sparse: return array_ops.sparse_placeholder(dtype=self.dtype) - return array_ops.placeholder(dtype=self.dtype, - shape=[None] + list(self.shape[1:])) + return array_ops.placeholder( + dtype=self.dtype, shape=[None] + list(self.shape[1:])) def get_feature_spec(self): dtype = self.dtype @@ -114,8 +113,8 @@ def tensors_compatible(tensors, signatures): Args: tensors: Dict of `Tensor` objects or single `Tensor` object. - signatures: Dict of `TensorSignature` objects or - single `TensorSignature` object. + signatures: Dict of `TensorSignature` objects or single `TensorSignature` + object. Returns: True if all tensors are compatible, False otherwise. @@ -150,8 +149,7 @@ def create_signatures(tensors): Dict of `TensorSignature` objects or single `TensorSignature`. """ if isinstance(tensors, dict): - return { - key: TensorSignature(tensors[key]) for key in tensors} + return {key: TensorSignature(tensors[key]) for key in tensors} if tensors is None: return None return TensorSignature(tensors) @@ -165,18 +163,18 @@ def create_placeholders_from_signatures(signatures): or `None`. Returns: - Dict of `tf.placeholder` objects or single `tf.placeholder`, or `None`. + Dict of `tf.compat.v1.placeholder` objects or single + `tf.compat.v1.placeholder`, or `None`. """ if signatures is None: return None if not isinstance(signatures, dict): return signatures.get_placeholder() - return { - key: signatures[key].get_placeholder() - for key in signatures} + return {key: signatures[key].get_placeholder() for key in signatures} -def create_example_parser_from_signatures(signatures, examples_batch, +def create_example_parser_from_signatures(signatures, + examples_batch, single_feature_name="feature"): """Creates example parser from given signatures. @@ -192,8 +190,9 @@ def create_example_parser_from_signatures(signatures, examples_batch, if not isinstance(signatures, dict): feature_spec[single_feature_name] = signatures.get_feature_spec() else: - feature_spec = {key: signatures[key].get_feature_spec() - for key in signatures} + feature_spec = { + key: signatures[key].get_feature_spec() for key in signatures + } features = parsing_ops.parse_example(examples_batch, feature_spec) if not isinstance(signatures, dict): # Returns single feature, casts if needed. diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index a997fab723a..244cb3fd438 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -183,7 +183,7 @@ def train(graph, keep_checkpoint_max: The maximum number of recent checkpoint files to keep. As new files are created, older files are deleted. If None or 0, all checkpoint files are kept. This is simply passed as the max_to_keep - arg to tf.train.Saver constructor. + arg to tf.compat.v1.train.Saver constructor. supervisor_save_summaries_steps: Save summaries every `supervisor_save_summaries_steps` seconds when training. 
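The `train` docstring above notes that `keep_checkpoint_max` is forwarded to the `tf.compat.v1.train.Saver` constructor as `max_to_keep`. A minimal sketch of that constructor argument; the variable and the value 5 are illustrative:

```python
import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
    step = tf.get_variable("step", shape=[], dtype=tf.int64,
                           initializer=tf.zeros_initializer())
    # Keep at most the 5 most recent checkpoint files on disk.
    saver = tf.train.Saver(max_to_keep=5)
```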
feed_fn: A function that is called every iteration to produce a `feed_dict` diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index c45b1d18647..b7dba0f2775 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -887,8 +887,8 @@ class DaskDataFeeder(object): """Returns a function, that will sample data and provide it to placeholders. Args: - input_placeholder: tf.placeholder for input features mini batch. - output_placeholder: tf.placeholder for output labels. + input_placeholder: tf.compat.v1.placeholder for input features mini batch. + output_placeholder: tf.compat.v1.placeholder for output labels. Returns: A function that when called samples a random subset of batch size diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py index f8aaa0c9e3e..20fb6a5fd9d 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py @@ -59,7 +59,7 @@ def generator_input_fn(x, 'age': np.random.randint(18, 80), 'label': np.ones(1)} - with tf.Session() as session: + with tf.compat.v1.Session() as session: input_fn = generator_io.generator_input_fn( generator, target_key="label", batch_size=2, shuffle=False, num_epochs=1) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 9e816f54b6c..4017929bc5f 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -72,7 +72,7 @@ def read_batch_examples(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int or scalar `Tensor` specifying the batch size to use. reader: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). @@ -80,7 +80,7 @@ def read_batch_examples(file_pattern, num_epochs: Integer specifying the number of times to read through the dataset. If `None`, cycles through the dataset forever. NOTE - If specified, creates a variable that must be initialized, so call - `tf.local_variables_initializer()` and run the op in a session. + `tf.compat.v1.local_variables_initializer()` and run the op in a session. queue_capacity: Capacity for input queue. num_threads: The number of threads enqueuing examples. In order to have predictable and repeatable order of reading and enqueueing, such as in @@ -140,7 +140,7 @@ def read_keyed_batch_examples(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int or scalar `Tensor` specifying the batch size to use. reader: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). @@ -148,7 +148,7 @@ def read_keyed_batch_examples(file_pattern, num_epochs: Integer specifying the number of times to read through the dataset. If `None`, cycles through the dataset forever. 
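The `graph_io` docstrings above and below point at `tf.io.gfile.glob` for expanding file patterns and at `tf.compat.v1.local_variables_initializer` for the local epoch counter created when `num_epochs` is set. A minimal sketch, with a hypothetical file pattern:

```python
import tensorflow.compat.v1 as tf

# tf.io.gfile.glob replaces the older tf.gfile.Glob spelling.
files = tf.io.gfile.glob("/tmp/*.tfrecord")  # hypothetical pattern

graph = tf.Graph()
with graph.as_default():
    # When num_epochs is set, the input pipeline creates a local counter,
    # so this initializer must be run in the session before reading.
    init_local_op = tf.local_variables_initializer()
```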
NOTE - If specified, creates a variable that must be initialized, so call - `tf.local_variables_initializer()` and run the op in a session. + `tf.compat.v1.local_variables_initializer()` and run the op in a session. queue_capacity: Capacity for input queue. num_threads: The number of threads enqueuing examples. In order to have predictable and repeatable order of reading and enqueueing, such as in @@ -215,7 +215,7 @@ def read_keyed_batch_examples_shared_queue(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int or scalar `Tensor` specifying the batch size to use. reader: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). @@ -223,7 +223,7 @@ def read_keyed_batch_examples_shared_queue(file_pattern, num_epochs: Integer specifying the number of times to read through the dataset. If `None`, cycles through the dataset forever. NOTE - If specified, creates a variable that must be initialized, so call - `tf.local_variables_initializer()` and run the op in a session. + `tf.compat.v1.local_variables_initializer()` and run the op in a session. queue_capacity: Capacity for input queue. num_threads: The number of threads enqueuing examples. read_batch_size: An int or scalar `Tensor` specifying the number of @@ -352,7 +352,7 @@ def _read_keyed_batch_examples_helper(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int or scalar `Tensor` specifying the batch size to use. reader: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). @@ -360,7 +360,7 @@ def _read_keyed_batch_examples_helper(file_pattern, num_epochs: Integer specifying the number of times to read through the dataset. If `None`, cycles through the dataset forever. NOTE - If specified, creates a variable that must be initialized, so call - `tf.local_variables_initializer()` and run the op in a session. + `tf.compat.v1.local_variables_initializer()` and run the op in a session. queue_capacity: Capacity for input queue. num_threads: The number of threads enqueuing examples. read_batch_size: An int or scalar `Tensor` specifying the number of @@ -489,7 +489,7 @@ def read_keyed_batch_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int or scalar `Tensor` specifying the batch size to use. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. @@ -499,7 +499,7 @@ def read_keyed_batch_features(file_pattern, num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. NOTE - If specified, creates a variable that must be initialized, so call - tf.local_variables_initializer() and run the op in a session. + tf.compat.v1.local_variables_initializer() and run the op in a session. queue_capacity: Capacity for input queue. reader_num_threads: The number of threads to read examples. 
In order to have predictable and repeatable order of reading and enqueueing, such as in @@ -578,7 +578,7 @@ def read_keyed_batch_features_shared_queue(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int or scalar `Tensor` specifying the batch size to use. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. @@ -588,7 +588,7 @@ def read_keyed_batch_features_shared_queue(file_pattern, num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. NOTE - If specified, creates a variable that must be initialized, so call - tf.local_variables_initializer() and run the op in a session. + tf.compat.v1.local_variables_initializer() and run the op in a session. queue_capacity: Capacity for input queue. reader_num_threads: The number of threads to read examples. feature_queue_capacity: Capacity of the parsed features queue. @@ -782,7 +782,7 @@ def read_batch_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int or scalar `Tensor` specifying the batch size to use. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. @@ -792,7 +792,7 @@ def read_batch_features(file_pattern, num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. NOTE - If specified, creates a variable that must be initialized, so call - tf.local_variables_initializer() and run the op in a session. + tf.compat.v1.local_variables_initializer() and run the op in a session. queue_capacity: Capacity for input queue. feature_queue_capacity: Capacity of the parsed features queue. Set this value to a small number, for example 5 if the parsed features are large. @@ -849,7 +849,7 @@ def read_batch_record_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. + `Example` records. See `tf.io.gfile.glob` for pattern rules. batch_size: An int or scalar `Tensor` specifying the batch size to use. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. @@ -857,7 +857,7 @@ def read_batch_record_features(file_pattern, num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. NOTE - If specified, creates a variable that must be initialized, so call - tf.local_variables_initializer() and run the op in a session. + tf.compat.v1.local_variables_initializer() and run the op in a session. queue_capacity: Capacity for input queue. reader_num_threads: The number of threads to read examples. In order to have predictable and repeatable order of reading and enqueueing, such as in diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 3d691d43404..22b85cb034b 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -122,7 +122,7 @@ class BaseMonitor(object): """Callback at the end of training/evaluation. Args: - session: A `tf.Session` object that can be used to run ops. 
+ session: A `tf.compat.v1.Session` object that can be used to run ops. Raises: ValueError: if we've not begun a run. diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py index 916aecbea88..f69a4dd6ad5 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc.py @@ -62,7 +62,7 @@ For example, # Delete everything not in 'both'. to_delete = gc.negation(both) for p in to_delete(all_paths): - gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2", + gfile.rmtree(p.path) # deletes: "/tmp/1", "/tmp/2", # "/tmp/3", "/tmp/4", "/tmp/6", """ diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index e7424472089..9555bb2fe8b 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -124,10 +124,10 @@ def rnn_decoder(decoder_inputs, in order to generate the i+1-st input, and decoder_inputs will be ignored, except for the first element ("GO" symbol). This can be used for decoding, but also for training to emulate http://arxiv.org/abs/1506.03099. - Signature -- loop_function(prev, i) = next - * prev is a 2D Tensor of shape [batch_size x output_size], - * i is an integer, the step number (when advanced control is needed), - * next is a 2D Tensor of shape [batch_size x input_size]. + Signature -- loop_function(prev, i) = next * prev is a 2D Tensor of + shape [batch_size x output_size], * i is an integer, the step number + (when advanced control is needed), * next is a 2D Tensor of shape + [batch_size x input_size]. scope: VariableScope for the created subgraph; defaults to "rnn_decoder". Returns: @@ -170,7 +170,7 @@ def basic_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 2D Tensors [batch_size x input_size]. decoder_inputs: A list of 2D Tensors [batch_size x input_size]. - cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. + cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the cell function and size. dtype: The dtype of the initial state of the RNN cell (default: tf.float32). scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". @@ -202,10 +202,10 @@ def tied_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 2D Tensors [batch_size x input_size]. decoder_inputs: A list of 2D Tensors [batch_size x input_size]. - cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. - loop_function: If not None, this function will be applied to i-th output - in order to generate i+1-th input, and decoder_inputs will be ignored, - except for the first element ("GO" symbol), see rnn_decoder for details. + cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the cell function and size. + loop_function: If not None, this function will be applied to i-th output in + order to generate i+1-th input, and decoder_inputs will be ignored, except + for the first element ("GO" symbol), see rnn_decoder for details. dtype: The dtype of the initial state of the rnn cell (default: tf.float32). scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". @@ -244,24 +244,24 @@ def embedding_rnn_decoder(decoder_inputs, Args: decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). initial_state: 2D Tensor [batch_size x cell.state_size]. - cell: tf.nn.rnn_cell.RNNCell defining the cell function. + cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the cell function. 
num_symbols: Integer, how many symbols come into the embedding. embedding_size: Integer, the length of the embedding vector for each symbol. output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [output_size x num_symbols] and B has - shape [num_symbols]; if provided and feed_previous=True, each fed - previous output will first be multiplied by W and added B. + biases; W has shape [output_size x num_symbols] and B has shape + [num_symbols]; if provided and feed_previous=True, each fed previous + output will first be multiplied by W and added B. feed_previous: Boolean; if True, only the first of decoder_inputs will be used (the "GO" symbol), and all other decoder inputs will be generated by: - next = embedding_lookup(embedding, argmax(previous_output)), - In effect, this implements a greedy decoder. It can also be used - during training to emulate http://arxiv.org/abs/1506.03099. - If False, decoder_inputs are used as given (the standard decoder case). + next = embedding_lookup(embedding, argmax(previous_output)), In effect, + this implements a greedy decoder. It can also be used + during training to emulate http://arxiv.org/abs/1506.03099. If False, + decoder_inputs are used as given (the standard decoder case). update_embedding_for_previous: Boolean; if False and feed_previous=True, only the embedding for the first symbol of decoder_inputs (the "GO" symbol) will be updated by back propagation. Embeddings for the symbols - generated from the decoder itself remain unchanged. This parameter has - no effect if feed_previous=False. + generated from the decoder itself remain unchanged. This parameter has no + effect if feed_previous=False. scope: VariableScope for the created subgraph; defaults to "embedding_rnn_decoder". @@ -292,8 +292,8 @@ def embedding_rnn_decoder(decoder_inputs, loop_function = _extract_argmax_and_embed( embedding, output_projection, update_embedding_for_previous) if feed_previous else None - emb_inp = (embedding_ops.embedding_lookup(embedding, i) - for i in decoder_inputs) + emb_inp = ( + embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs) return rnn_decoder( emb_inp, initial_state, cell, loop_function=loop_function) @@ -320,16 +320,16 @@ def embedding_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. + cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols: Integer; number of symbols on the decoder side. embedding_size: Integer, the length of the embedding vector for each symbol. output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [output_size x num_decoder_symbols] and B has - shape [num_decoder_symbols]; if provided and feed_previous=True, each - fed previous output will first be multiplied by W and added B. - feed_previous: Boolean or scalar Boolean Tensor; if True, only the first - of decoder_inputs will be used (the "GO" symbol), and all other decoder + biases; W has shape [output_size x num_decoder_symbols] and B has shape + [num_decoder_symbols]; if provided and feed_previous=True, each fed + previous output will first be multiplied by W and added B. 
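The `feed_previous` behaviour described above feeds the embedding of the argmax of the previous output back in as the next decoder input (a greedy decoder). A minimal sketch of that single step, with illustrative shapes:

```python
import tensorflow.compat.v1 as tf

num_symbols, embedding_size, batch_size = 10, 4, 3
embedding = tf.random.uniform([num_symbols, embedding_size])
previous_output = tf.random.uniform([batch_size, num_symbols])  # decoder logits

# next = embedding_lookup(embedding, argmax(previous_output))
next_input = tf.nn.embedding_lookup(
    embedding, tf.argmax(previous_output, axis=1))  # [batch_size, embedding_size]
```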
+ feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of + decoder_inputs will be used (the "GO" symbol), and all other decoder inputs will be taken from previous outputs (as in embedding_rnn_decoder). If False, decoder_inputs are used as given (the standard decoder case). dtype: The dtype of the initial state for both the encoder and encoder @@ -395,9 +395,8 @@ def embedding_rnn_seq2seq(encoder_inputs, state_list = nest.flatten(state) return outputs + state_list - outputs_and_state = control_flow_ops.cond(feed_previous, - lambda: decoder(True), - lambda: decoder(False)) + outputs_and_state = control_flow_ops.cond( + feed_previous, lambda: decoder(True), lambda: decoder(False)) outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. state_list = outputs_and_state[outputs_len:] state = state_list[0] @@ -430,7 +429,7 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. + cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the cell function and size. num_symbols: Integer; number of symbols for both encoder and decoder. embedding_size: Integer, the length of the embedding vector for each symbol. num_decoder_symbols: Integer; number of output symbols for decoder. If @@ -439,11 +438,11 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, this assumes that the vocabulary is set up such that the first num_decoder_symbols of num_symbols are part of decoding. output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [output_size x num_symbols] and B has - shape [num_symbols]; if provided and feed_previous=True, each - fed previous output will first be multiplied by W and added B. - feed_previous: Boolean or scalar Boolean Tensor; if True, only the first - of decoder_inputs will be used (the "GO" symbol), and all other decoder + biases; W has shape [output_size x num_symbols] and B has shape + [num_symbols]; if provided and feed_previous=True, each fed previous + output will first be multiplied by W and added B. + feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of + decoder_inputs will be used (the "GO" symbol), and all other decoder inputs will be taken from previous outputs (as in embedding_rnn_decoder). If False, decoder_inputs are used as given (the standard decoder case). dtype: The dtype to use for the initial RNN states (default: tf.float32). @@ -516,9 +515,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, state_list = nest.flatten(state) return outputs + state_list - outputs_and_state = control_flow_ops.cond(feed_previous, - lambda: decoder(True), - lambda: decoder(False)) + outputs_and_state = control_flow_ops.cond( + feed_previous, lambda: decoder(True), lambda: decoder(False)) outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. state_list = outputs_and_state[outputs_len:] state = state_list[0] @@ -559,23 +557,23 @@ def attention_decoder(decoder_inputs, decoder_inputs: A list of 2D Tensors [batch_size x input_size]. initial_state: 2D Tensor [batch_size x cell.state_size]. attention_states: 3D Tensor [batch_size x attn_length x attn_size]. - cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. + cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the cell function and size. output_size: Size of the output vectors; if None, we use cell.output_size. 
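For context, the 3D `attention_states` argument is the encoder's per-step top-layer outputs stacked along a time axis, which is essentially how `embedding_attention_seq2seq` builds it internally; a small sketch of that reshaping (the helper name is hypothetical).

```python
import tensorflow.compat.v1 as tf

def build_attention_states(encoder_outputs, output_size):
  """Stacks a list of [batch_size, output_size] encoder outputs into a
  [batch_size, attn_length, attn_size] tensor for attention_decoder."""
  top_states = [tf.reshape(o, [-1, 1, output_size]) for o in encoder_outputs]
  return tf.concat(top_states, axis=1)
```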
num_heads: Number of attention heads that read from attention_states. - loop_function: If not None, this function will be applied to i-th output - in order to generate i+1-th input, and decoder_inputs will be ignored, - except for the first element ("GO" symbol). This can be used for decoding, + loop_function: If not None, this function will be applied to i-th output in + order to generate i+1-th input, and decoder_inputs will be ignored, except + for the first element ("GO" symbol). This can be used for decoding, but also for training to emulate http://arxiv.org/abs/1506.03099. - Signature -- loop_function(prev, i) = next - * prev is a 2D Tensor of shape [batch_size x output_size], - * i is an integer, the step number (when advanced control is needed), - * next is a 2D Tensor of shape [batch_size x input_size]. + Signature -- loop_function(prev, i) = next * prev is a 2D Tensor of + shape [batch_size x output_size], * i is an integer, the step number + (when advanced control is needed), * next is a 2D Tensor of shape + [batch_size x input_size]. dtype: The dtype to use for the RNN initial state (default: tf.float32). scope: VariableScope for the created subgraph; default: "attention_decoder". - initial_state_attention: If False (default), initial attentions are zero. - If True, initialize the attentions from the initial state and attention - states -- useful when we wish to resume decoding from a previously - stored decoder state and attention states. + initial_state_attention: If False (default), initial attentions are zero. If + True, initialize the attentions from the initial state and attention + states -- useful when we wish to resume decoding from a previously stored + decoder state and attention states. Returns: A tuple of the form (outputs, state), where: @@ -626,8 +624,7 @@ def attention_decoder(decoder_inputs, attention_vec_size = attn_size # Size of query vectors for attention. for a in xrange(num_heads): k = variable_scope.get_variable( - "AttnW_%d" % a, [1, 1, attn_size, attention_vec_size], - dtype=dtype) + "AttnW_%d" % a, [1, 1, attn_size, attention_vec_size], dtype=dtype) hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) v.append( variable_scope.get_variable( @@ -665,8 +662,7 @@ def attention_decoder(decoder_inputs, prev = None batch_attn_size = array_ops.stack([batch_size, attn_size]) attns = [ - array_ops.zeros( - batch_attn_size, dtype=dtype) for _ in xrange(num_heads) + array_ops.zeros(batch_attn_size, dtype=dtype) for _ in xrange(num_heads) ] for a in attns: # Ensure the second shape of attention vectors is set. a.set_shape([None, attn_size]) @@ -728,7 +724,7 @@ def embedding_attention_decoder(decoder_inputs, decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). initial_state: 2D Tensor [batch_size x cell.state_size]. attention_states: 3D Tensor [batch_size x attn_length x attn_size]. - cell: tf.nn.rnn_cell.RNNCell defining the cell function. + cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the cell function. num_symbols: Integer, how many symbols come into the embedding. embedding_size: Integer, the length of the embedding vector for each symbol. num_heads: Number of attention heads that read from attention_states. @@ -739,22 +735,22 @@ def embedding_attention_decoder(decoder_inputs, output will first be multiplied by W and added B. 
feed_previous: Boolean; if True, only the first of decoder_inputs will be used (the "GO" symbol), and all other decoder inputs will be generated by: - next = embedding_lookup(embedding, argmax(previous_output)), - In effect, this implements a greedy decoder. It can also be used - during training to emulate http://arxiv.org/abs/1506.03099. - If False, decoder_inputs are used as given (the standard decoder case). + next = embedding_lookup(embedding, argmax(previous_output)), In effect, + this implements a greedy decoder. It can also be used + during training to emulate http://arxiv.org/abs/1506.03099. If False, + decoder_inputs are used as given (the standard decoder case). update_embedding_for_previous: Boolean; if False and feed_previous=True, only the embedding for the first symbol of decoder_inputs (the "GO" symbol) will be updated by back propagation. Embeddings for the symbols - generated from the decoder itself remain unchanged. This parameter has - no effect if feed_previous=False. + generated from the decoder itself remain unchanged. This parameter has no + effect if feed_previous=False. dtype: The dtype to use for the RNN initial states (default: tf.float32). scope: VariableScope for the created subgraph; defaults to "embedding_attention_decoder". - initial_state_attention: If False (default), initial attentions are zero. - If True, initialize the attentions from the initial state and attention - states -- useful when we wish to resume decoding from a previously - stored decoder state and attention states. + initial_state_attention: If False (default), initial attentions are zero. If + True, initialize the attentions from the initial state and attention + states -- useful when we wish to resume decoding from a previously stored + decoder state and attention states. Returns: A tuple of the form (outputs, state), where: @@ -822,24 +818,24 @@ def embedding_attention_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. + cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols: Integer; number of symbols on the decoder side. embedding_size: Integer, the length of the embedding vector for each symbol. num_heads: Number of attention heads that read from attention_states. output_projection: None or a pair (W, B) of output projection weights and - biases; W has shape [output_size x num_decoder_symbols] and B has - shape [num_decoder_symbols]; if provided and feed_previous=True, each - fed previous output will first be multiplied by W and added B. - feed_previous: Boolean or scalar Boolean Tensor; if True, only the first - of decoder_inputs will be used (the "GO" symbol), and all other decoder + biases; W has shape [output_size x num_decoder_symbols] and B has shape + [num_decoder_symbols]; if provided and feed_previous=True, each fed + previous output will first be multiplied by W and added B. + feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of + decoder_inputs will be used (the "GO" symbol), and all other decoder inputs will be taken from previous outputs (as in embedding_rnn_decoder). If False, decoder_inputs are used as given (the standard decoder case). dtype: The dtype of the initial RNN state (default: tf.float32). 
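To make the reworked argument list concrete, a typical TF 1.x call with greedy decoding enabled is sketched below (vocabulary sizes, shapes, and the GRU cell choice are illustrative, not part of this change).

```python
import tensorflow as tf  # assuming TF 1.x, where tf.contrib is still available

seq_len, batch_size = 10, 32
encoder_inputs = [tf.compat.v1.placeholder(tf.int32, [batch_size]) for _ in range(seq_len)]
decoder_inputs = [tf.compat.v1.placeholder(tf.int32, [batch_size]) for _ in range(seq_len)]
cell = tf.compat.v1.nn.rnn_cell.GRUCell(128)

outputs, state = tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
    encoder_inputs, decoder_inputs, cell,
    num_encoder_symbols=10000, num_decoder_symbols=10000,
    embedding_size=128, feed_previous=True)  # decode greedily from the "GO" symbol
```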
scope: VariableScope for the created subgraph; defaults to "embedding_attention_seq2seq". - initial_state_attention: If False (default), initial attentions are zero. - If True, initialize the attentions from the initial state and attention + initial_state_attention: If False (default), initial attentions are zero. If + True, initialize the attentions from the initial state and attention states. Returns: @@ -911,9 +907,8 @@ def embedding_attention_seq2seq(encoder_inputs, state_list = nest.flatten(state) return outputs + state_list - outputs_and_state = control_flow_ops.cond(feed_previous, - lambda: decoder(True), - lambda: decoder(False)) + outputs_and_state = control_flow_ops.cond( + feed_previous, lambda: decoder(True), lambda: decoder(False)) outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. state_list = outputs_and_state[outputs_len:] state = state_list[0] @@ -941,14 +936,14 @@ def one2many_rnn_seq2seq(encoder_inputs, Args: encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. - decoder_inputs_dict: A dictionary mapping decoder name (string) to - the corresponding decoder_inputs; each decoder_inputs is a list of 1D - Tensors of shape [batch_size]; num_decoders is defined as + decoder_inputs_dict: A dictionary mapping decoder name (string) to the + corresponding decoder_inputs; each decoder_inputs is a list of 1D Tensors + of shape [batch_size]; num_decoders is defined as len(decoder_inputs_dict). - enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and - size. - dec_cells_dict: A dictionary mapping encoder name (string) to an - instance of tf.nn.rnn_cell.RNNCell. + enc_cell: tf.compat.v1.nn.rnn_cell.RNNCell defining the encoder cell + function and size. + dec_cells_dict: A dictionary mapping encoder name (string) to an instance of + tf.nn.rnn_cell.RNNCell. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an integer specifying number of symbols for the corresponding decoder; @@ -1004,8 +999,8 @@ def one2many_rnn_seq2seq(encoder_inputs, num_decoder_symbols = num_decoder_symbols_dict[name] dec_cell = dec_cells_dict[name] - with variable_scope.variable_scope("one2many_decoder_" + str( - name)) as scope: + with variable_scope.variable_scope("one2many_decoder_" + + str(name)) as scope: dec_cell = core_rnn_cell.OutputProjectionWrapper( dec_cell, num_decoder_symbols) if isinstance(feed_previous, bool): @@ -1038,8 +1033,8 @@ def one2many_rnn_seq2seq(encoder_inputs, return outputs + state_list outputs_and_state = control_flow_ops.cond( - feed_previous, lambda: filled_embedding_rnn_decoder(True), - lambda: filled_embedding_rnn_decoder(False)) + feed_previous, lambda: filled_embedding_rnn_decoder(True), lambda: + filled_embedding_rnn_decoder(False)) # Outputs length is the same as for decoder inputs. outputs_len = len(decoder_inputs) outputs = outputs_and_state[:outputs_len] @@ -1068,10 +1063,10 @@ def sequence_loss_by_example(logits, weights: List of 1D batch-sized float-Tensors of the same length as logits. average_across_timesteps: If set, divide the returned cost by the total label weight. - softmax_loss_function: Function (labels, logits) -> loss-batch - to be used instead of the standard softmax (the default if this is None). 
- **Note that to avoid confusion, it is required for the function to accept - named arguments.** + softmax_loss_function: Function (labels, logits) -> loss-batch to be used + instead of the standard softmax (the default if this is None). **Note that + to avoid confusion, it is required for the function to accept named + arguments.** name: Optional name for this operation, default: "sequence_loss_by_example". Returns: @@ -1121,10 +1116,10 @@ def sequence_loss(logits, average_across_timesteps: If set, divide the returned cost by the total label weight. average_across_batch: If set, divide the returned cost by the batch size. - softmax_loss_function: Function (labels, logits) -> loss-batch - to be used instead of the standard softmax (the default if this is None). - **Note that to avoid confusion, it is required for the function to accept - named arguments.** + softmax_loss_function: Function (labels, logits) -> loss-batch to be used + instead of the standard softmax (the default if this is None). **Note that + to avoid confusion, it is required for the function to accept named + arguments.** name: Optional name for this operation, defaults to "sequence_loss". Returns: @@ -1169,16 +1164,16 @@ def model_with_buckets(encoder_inputs, targets: A list of 1D batch-sized int32 Tensors (desired output sequence). weights: List of 1D batch-sized float-Tensors to weight the targets. buckets: A list of pairs of (input size, output size) for each bucket. - seq2seq: A sequence-to-sequence model function; it takes 2 input that - agree with encoder_inputs and decoder_inputs, and returns a pair - consisting of outputs and states (as, e.g., basic_rnn_seq2seq). - softmax_loss_function: Function (labels, logits) -> loss-batch - to be used instead of the standard softmax (the default if this is None). - **Note that to avoid confusion, it is required for the function to accept - named arguments.** + seq2seq: A sequence-to-sequence model function; it takes 2 input that agree + with encoder_inputs and decoder_inputs, and returns a pair consisting of + outputs and states (as, e.g., basic_rnn_seq2seq). + softmax_loss_function: Function (labels, logits) -> loss-batch to be used + instead of the standard softmax (the default if this is None). **Note that + to avoid confusion, it is required for the function to accept named + arguments.** per_example_loss: Boolean. If set, the returned loss will be a batch-sized - tensor of losses for each sequence in the batch. If unset, it will be - a scalar with the averaged loss from all examples. + tensor of losses for each sequence in the batch. If unset, it will be a + scalar with the averaged loss from all examples. name: Optional name for this operation, defaults to "model_with_buckets". 
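These loss helpers pass `softmax_loss_function` its arguments by name, which is why named arguments are required; a minimal compliant function, sketched with the standard sparse cross-entropy op, is:

```python
import tensorflow.compat.v1 as tf

def softmax_loss_function(labels, logits):
  # Invoked as softmax_loss_function(labels=..., logits=...), so the parameter
  # names must be exactly `labels` and `logits`.
  # Returns a 1D batch of per-example losses, as the docstrings describe.
  return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
```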
Returns: diff --git a/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py index b3022505635..937a49395e2 100644 --- a/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py +++ b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py @@ -22,12 +22,17 @@ from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.platform import resource_loader +from tensorflow.python.util.deprecation import deprecated _libsvm_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_libsvm_ops.so")) +@deprecated(None, + 'tf.contrib.libsvm will be removed in 2.0, the support for libsvm ' + 'format will continue to be provided in tensorflow-io: ' + 'https://github.com/tensorflow/io') def decode_libsvm(content, num_features, dtype=None, label_dtype=None): """Convert Libsvm records to a tensor of label and a tensor of feature. diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index 7534b50a4ae..ec0cbf92dd2 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -40,6 +40,7 @@ py_test( name = "sdca_ops_test", size = "medium", srcs = ["python/kernel_tests/sdca_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_gpu", @@ -80,6 +81,7 @@ py_test( name = "sharded_mutable_dense_hashtable_test", size = "small", srcs = ["python/ops/sharded_mutable_dense_hashtable_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":sharded_mutable_dense_hashtable_py", @@ -100,6 +102,7 @@ py_test( name = "sparse_feature_column_test", size = "small", srcs = ["python/ops/sparse_feature_column_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":sparse_feature_column_py", @@ -130,6 +133,7 @@ py_library( py_test( name = "sdca_estimator_test", srcs = ["python/sdca_estimator_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index d49834dc860..9dea5eff337 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -465,6 +465,9 @@ class SdcaWithLogisticLossTest(SdcaModelTest): dtypes.string, shape=(len(example_weights),)) examples['example_ids'] = example_ids variables = make_variable_dict(1, 1) + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() for num_shards in _SHARD_NUMBERS: for num_loss_partitions in _NUM_LOSS_PARTITIONS: with self._single_threaded_test_session(): diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index c056a12fa53..950840c6b77 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -624,7 +624,7 @@ class SdcaModel(object): # Note that we need double precision to get accurate results. 
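The double-precision sum flagged in this comment is rewritten in the hunk that follows to use an explicit `cast` instead of the deprecated `math_ops.to_double`; the same migration at the public-API level looks like this (values are illustrative).

```python
import tensorflow.compat.v1 as tf

x = tf.constant([1, 2, 3], dtype=tf.int32)

# y = tf.to_double(x)        # deprecated 1.x spelling, absent from the 2.x API
y = tf.cast(x, tf.float64)   # preferred spelling, valid in both 1.x and 2.x
z = tf.cast(x, tf.float32)   # likewise replaces tf.to_float(x)
```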
with ops.control_dependencies(shard_sums): shard_sums.append( - math_ops.reduce_sum(math_ops.to_double(values), 0)) + math_ops.reduce_sum(math_ops.cast(values, dtypes.float64), 0)) summed_values = math_ops.add_n(shard_sums) primal_loss = summed_values[1] diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 8ebe45d8510..58ab3aec664 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -135,7 +135,7 @@ class SDCAOptimizer(object): array_ops.reshape( array_ops.split( value=sparse_indices, num_or_size_splits=2, axis=1)[1], [-1]), - array_ops.reshape(math_ops.to_float(sparse_values), [-1])) + array_ops.reshape(math_ops.cast(sparse_values, dtypes.float32), [-1])) def _training_examples_and_variables(): """Returns dictionaries for training examples and variables.""" @@ -254,8 +254,8 @@ class SDCAOptimizer(object): examples = dict( sparse_features=sparse_feature_with_values, dense_features=dense_features, - example_labels=math_ops.to_float( - array_ops.reshape(targets, shape=[-1])), + example_labels=math_ops.cast( + array_ops.reshape(targets, shape=[-1]), dtypes.float32), example_weights=example_weights, example_ids=example_ids) sdca_variables = dict( diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 20e86e56bbe..60aaae7e4c9 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -51,8 +51,13 @@ def string_to_index_table_from_file(vocabulary_file=None, hasher_spec=FastHashSpec, name=None): return index_table_from_file( - vocabulary_file, num_oov_buckets, vocab_size, default_value, hasher_spec, - key_dtype=dtypes.string, name=name) + vocabulary_file, + num_oov_buckets, + vocab_size, + default_value, + hasher_spec, + key_dtype=dtypes.string, + name=name) @deprecated("2017-04-10", "Use `index_table_from_tensor`.") @@ -88,7 +93,8 @@ def index_table_from_tensor(mapping, The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. The underlying table must be initialized by calling - `session.run(tf.tables_initializer)` or `session.run(table.init)` once. + `session.run(tf.compat.v1.tables_initializer)` or `session.run(table.init)` + once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -102,7 +108,7 @@ def index_table_from_tensor(mapping, features = tf.constant(["emerson", "lake", "and", "palmer"]) ids = table.lookup(features) ... - tf.tables_initializer().run() + tf.compat.v1.tables_initializer().run() ids.eval() ==> [0, 1, 3, 2] ``` @@ -137,10 +143,9 @@ def index_table_from_tensor(mapping, name=name) -@deprecated( - "2017-01-07", "This op will be removed after the deprecation date. " - "Please switch to index_table_from_tensor and call the lookup " - "method of the returned table.") +@deprecated("2017-01-07", "This op will be removed after the deprecation date. " + "Please switch to index_table_from_tensor and call the lookup " + "method of the returned table.") def string_to_index(tensor, mapping, default_value=-1, name=None): """Maps `tensor` of strings into `int64` indices based on `mapping`. @@ -155,7 +160,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): will throw a FailedPreconditionError. The underlying table must be initialized by calling - `session.run(tf.tables_initializer)` once. 
+ `session.run(tf.compat.v1.tables_initializer)` once. For example: @@ -165,7 +170,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): ids = tf.contrib.lookup.string_to_index( feats, mapping=mapping_strings, default_value=-1) ... - tf.tables_initializer().run() + tf.compat.v1.tables_initializer().run() ids.eval() ==> [0, 1, -1, 2] ``` @@ -199,7 +204,8 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `session.run(tf.tables_initializer)` or `session.run(table.init)` once. + `session.run(tf.compat.v1.tables_initializer)` or `session.run(table.init)` + once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -213,7 +219,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): mapping_string, default_value="UNKNOWN") values = table.lookup(indices) ... - tf.tables_initializer().run() + tf.compat.v1.tables_initializer().run() values.eval() ==> ["lake", "UNKNOWN"] ``` @@ -254,7 +260,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `session.run(tf.tables_initializer)` once. + `session.run(tf.compat.v1.tables_initializer)` once. For example: @@ -264,7 +270,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): values = tf.contrib.lookup.index_to_string( indices, mapping=mapping_string, default_value="UNKNOWN") ... - tf.tables_initializer().run() + tf.compat.v1.tables_initializer().run() values.eval() ==> ["lake", "UNKNOWN"] ``` diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD index f4ebbdeee88..c51b651d1a4 100644 --- a/tensorflow/contrib/losses/BUILD +++ b/tensorflow/contrib/losses/BUILD @@ -39,6 +39,7 @@ py_library( py_test( name = "loss_ops_test", srcs = glob(["python/losses/loss_ops_test.py"]), + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":losses_py", @@ -86,6 +87,7 @@ py_test( srcs = [ "python/metric_learning/metric_loss_ops_test.py", ], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 5ebdd0b8b50..dea111f9a0f 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import add_arg_scope +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -100,8 +101,8 @@ def compute_weighted_loss(losses, weights=1.0, scope=None): with ops.name_scope(scope, "weighted_loss", [losses, weights]): losses = ops.convert_to_tensor(losses) input_dtype = losses.dtype - losses = math_ops.to_float(losses) - weights = math_ops.to_float(ops.convert_to_tensor(weights)) + losses = math_ops.cast(losses, dtypes.float32) + weights = math_ops.cast(ops.convert_to_tensor(weights), dtypes.float32) if losses.get_shape().ndims is None: raise ValueError("losses.get_shape().ndims cannot be None") @@ -147,8 +148,8 @@ def _num_present(losses, weights, per_batch=False): batch_size 
= array_ops.reshape( array_ops.slice(array_ops.shape(losses), [0], [1]), []) num_per_batch = math_ops.div( - math_ops.to_float(array_ops.size(losses)), - math_ops.to_float(batch_size)) + math_ops.cast(array_ops.size(losses), dtypes.float32), + math_ops.cast(batch_size, dtypes.float32)) num_per_batch = array_ops.where( math_ops.equal(weights, 0), 0.0, num_per_batch) num_per_batch = math_ops.multiply( @@ -159,12 +160,14 @@ def _num_present(losses, weights, per_batch=False): if weights.get_shape().ndims >= 1: axis = list(range(1, weights.get_shape().ndims)) num_nonzero_per_batch = math_ops.reduce_sum( - math_ops.to_float(math_ops.not_equal(weights, 0)), axis=axis) + math_ops.cast(math_ops.not_equal(weights, 0), dtypes.float32), + axis=axis) # Next, determine the number of elements that weights would broadcast to: broadcast_dims = array_ops.slice( array_ops.shape(losses), [weights.get_shape().ndims], [-1]) - num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims)) + num_to_broadcast = math_ops.cast(math_ops.reduce_prod(broadcast_dims), + dtypes.float32) num_per_batch = math_ops.multiply(num_nonzero_per_batch, num_to_broadcast) return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) @@ -262,8 +265,8 @@ def absolute_difference(predictions, labels=None, weights=1.0, scope=None): with ops.name_scope(scope, "absolute_difference", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) losses = math_ops.abs(math_ops.subtract(predictions, labels)) return compute_weighted_loss(losses, weights, scope=scope) @@ -438,8 +441,8 @@ def log_loss(predictions, labels=None, weights=1.0, epsilon=1e-7, scope=None): with ops.name_scope(scope, "log_loss", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) losses = -math_ops.multiply( labels, math_ops.log(predictions + epsilon)) - math_ops.multiply( (1 - labels), math_ops.log(1 - predictions + epsilon)) @@ -473,7 +476,7 @@ def hinge_loss(logits, labels=None, scope=None): with ops.name_scope(scope, "hinge_loss", [logits, labels]) as scope: logits.get_shape().assert_is_compatible_with(labels.get_shape()) # We first need to convert binary labels to -1/1 labels (as floats). 
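This comment refers to the usual hinge formulation max(0, 1 - y * logit) with y in {-1, +1}; written with public ops and illustrative values, the same computation reads:

```python
import tensorflow.compat.v1 as tf

labels = tf.constant([0., 1., 1., 0.])       # binary {0, 1} labels
logits = tf.constant([-2.3, 0.4, 3.1, 0.9])

signed = 2.0 * labels - 1.0                  # map {0, 1} -> {-1, +1}
hinge = tf.nn.relu(1.0 - signed * logits)    # max(0, 1 - y * logit), per example
```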
- labels = math_ops.to_float(labels) + labels = math_ops.cast(labels, dtypes.float32) all_ones = array_ops.ones_like(labels) labels = math_ops.subtract(2 * labels, all_ones) return nn_ops.relu( @@ -509,8 +512,8 @@ def mean_squared_error(predictions, labels=None, weights=1.0, scope=None): with ops.name_scope(scope, "mean_squared_error", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) losses = math_ops.squared_difference(predictions, labels) return compute_weighted_loss(losses, weights, scope=scope) @@ -563,9 +566,9 @@ def mean_pairwise_squared_error(predictions, with ops.name_scope(scope, "mean_pairwise_squared_error", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) - weights = math_ops.to_float(ops.convert_to_tensor(weights)) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) + weights = math_ops.cast(ops.convert_to_tensor(weights), dtypes.float32) diffs = math_ops.subtract(predictions, labels) @@ -638,8 +641,8 @@ def cosine_distance(predictions, [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - predictions = math_ops.to_float(predictions) - labels = math_ops.to_float(labels) + predictions = math_ops.cast(predictions, dtypes.float32) + labels = math_ops.cast(labels, dtypes.float32) radial_diffs = math_ops.multiply(predictions, labels) losses = 1 - math_ops.reduce_sum( diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py index f3b0e77740f..226527a49c7 100644 --- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -67,11 +67,13 @@ def pairwise_distance(feature, squared=False): pairwise_distances = pairwise_distances_squared else: pairwise_distances = math_ops.sqrt( - pairwise_distances_squared + math_ops.to_float(error_mask) * 1e-16) + pairwise_distances_squared + + math_ops.cast(error_mask, dtypes.float32) * 1e-16) # Undo conditionally adding 1e-16. pairwise_distances = math_ops.multiply( - pairwise_distances, math_ops.to_float(math_ops.logical_not(error_mask))) + pairwise_distances, + math_ops.cast(math_ops.logical_not(error_mask), dtypes.float32)) num_data = array_ops.shape(feature)[0] # Explicitly set diagonals to zero. @@ -111,8 +113,8 @@ def contrastive_loss(labels, embeddings_anchor, embeddings_positive, # Add contrastive loss for the siamese network. # label here is {0,1} for neg, pos. return math_ops.reduce_mean( - math_ops.to_float(labels) * math_ops.square(distances) + - (1. - math_ops.to_float(labels)) * + math_ops.cast(labels, dtypes.float32) * math_ops.square(distances) + + (1. 
- math_ops.cast(labels, dtypes.float32)) * math_ops.square(math_ops.maximum(margin - distances, 0.)), name='contrastive_loss') @@ -284,8 +286,8 @@ def npairs_loss(labels, embeddings_anchor, embeddings_positive, assert lshape.shape == 1 labels = array_ops.reshape(labels, [lshape[0], 1]) - labels_remapped = math_ops.to_float( - math_ops.equal(labels, array_ops.transpose(labels))) + labels_remapped = math_ops.cast( + math_ops.equal(labels, array_ops.transpose(labels)), dtypes.float32) labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True) # Add the softmax loss. @@ -318,9 +320,10 @@ def _build_multilabel_adjacency(sparse_labels): adjacency_matrix = array_ops.zeros([num_pairs, num_pairs]) for i in range(num_pairs): for j in range(num_pairs): - sparse_dot_product = math_ops.to_float( + sparse_dot_product = math_ops.cast( sparse_ops.sparse_reduce_sum(sparse_ops.sparse_minimum( - sparse_labels[i], sparse_labels[j]))) + sparse_labels[i], sparse_labels[j])), + dtypes.float32) sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 0) sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 1) one_hot_matrix = array_ops.pad(sparse_dot_product, @@ -390,7 +393,7 @@ def npairs_loss_multilabel(sparse_labels, embeddings_anchor, # TODO(coreylynch): are composed only of 0's and 1's. multilabel_adjacency_matrix = _build_multilabel_adjacency(sparse_labels) - labels_remapped = math_ops.to_float(multilabel_adjacency_matrix) + labels_remapped = math_ops.cast(multilabel_adjacency_matrix, dtypes.float32) labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True) # Add the softmax loss. @@ -542,7 +545,8 @@ def get_cluster_assignment(pairwise_distances, centroid_ids): array_ops.constant(0, dtype=dtypes.int64), axis=0, dtype=dtypes.int64), - math_ops.to_int64(math_ops.range(array_ops.shape(centroid_ids)[0]))) + math_ops.cast(math_ops.range(array_ops.shape(centroid_ids)[0]), + dtypes.int64)) constraint_vect = math_ops.reduce_sum( array_ops.transpose(constraint_one_hot), axis=0) @@ -606,46 +610,51 @@ def compute_clustering_score(labels, predictions, margin_type): def _compute_nmi_score(labels, predictions): - return math_ops.to_float( + return math_ops.cast( script_ops.py_func( metrics.normalized_mutual_info_score, [labels, predictions], [dtypes.float64], - name='nmi')) + name='nmi'), + dtypes.float32) def _compute_ami_score(labels, predictions): - ami_score = math_ops.to_float( + ami_score = math_ops.cast( script_ops.py_func( metrics.adjusted_mutual_info_score, [labels, predictions], [dtypes.float64], - name='ami')) + name='ami'), + dtypes.float32) return math_ops.maximum(0.0, ami_score) def _compute_ari_score(labels, predictions): - ari_score = math_ops.to_float( + ari_score = math_ops.cast( script_ops.py_func( metrics.adjusted_rand_score, [labels, predictions], [dtypes.float64], - name='ari')) + name='ari'), + dtypes.float32) # ari score can go below 0 # http://scikit-learn.org/stable/modules/clustering.html#adjusted-rand-score return math_ops.maximum(0.0, ari_score) def _compute_vmeasure_score(labels, predictions): - vmeasure_score = math_ops.to_float( + vmeasure_score = math_ops.cast( script_ops.py_func( metrics.v_measure_score, [labels, predictions], [dtypes.float64], - name='vmeasure')) + name='vmeasure'), + dtypes.float32) return math_ops.maximum(0.0, vmeasure_score) def _compute_zeroone_score(labels, predictions): - zeroone_score = math_ops.to_float( + zeroone_score = math_ops.cast( math_ops.equal( math_ops.reduce_sum( - math_ops.to_int32(math_ops.equal(labels, 
predictions))), - array_ops.shape(labels)[0])) + math_ops.cast(math_ops.equal(labels, predictions), dtypes.int32)), + array_ops.shape(labels)[0]), + dtypes.float32) return zeroone_score @@ -711,8 +720,8 @@ def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids, candidate_scores = math_ops.add( candidate_scores, margin_multiplier * nmi_scores) - argmax_index = math_ops.to_int32( - math_ops.argmax(candidate_scores, axis=0)) + argmax_index = math_ops.cast( + math_ops.argmax(candidate_scores, axis=0), dtypes.int32) return candidate_ids[argmax_index] @@ -787,7 +796,7 @@ def update_medoid_per_cluster(pairwise_distances, pairwise_distances_subset, def func_body(iteration, scores_margin): # swap the current medoid with the candidate cluster member - candidate_medoid = math_ops.to_int32(cluster_member_ids[iteration]) + candidate_medoid = math_ops.cast(cluster_member_ids[iteration], dtypes.int32) tmp_chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, candidate_medoid) predictions = get_cluster_assignment(pairwise_distances, tmp_chosen_ids) metric_score = compute_clustering_score(labels, predictions, margin_type) @@ -811,10 +820,10 @@ def update_medoid_per_cluster(pairwise_distances, pairwise_distances_subset, [iteration, scores_margin]) candidate_scores = math_ops.add(scores_fac, margin_multiplier * scores_margin) - argmax_index = math_ops.to_int32( - math_ops.argmax(candidate_scores, axis=0)) + argmax_index = math_ops.cast( + math_ops.argmax(candidate_scores, axis=0), dtypes.int32) - best_medoid = math_ops.to_int32(cluster_member_ids[argmax_index]) + best_medoid = math_ops.cast(cluster_member_ids[argmax_index], dtypes.int32) chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, best_medoid) return chosen_ids @@ -842,7 +851,8 @@ def update_all_medoids(pairwise_distances, predictions, labels, chosen_ids, def func_body_augmented_pam(iteration, chosen_ids): """Call the update_medoid_per_cluster subroutine.""" mask = math_ops.equal( - math_ops.to_int64(predictions), math_ops.to_int64(iteration)) + math_ops.cast(predictions, dtypes.int64), + math_ops.cast(iteration, dtypes.int64)) this_cluster_ids = array_ops.where(mask) pairwise_distances_subset = array_ops.transpose( diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index d22548d5007..13f84313314 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -509,6 +509,7 @@ ifeq ($(TARGET),IOS) -fembed-bitcode \ -miphoneos-version-min=${MIN_SDK_VERSION} \ -framework Accelerate \ + -framework CoreFoundation \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ @@ -533,6 +534,7 @@ ifeq ($(TARGET),IOS) -fembed-bitcode \ -miphoneos-version-min=${MIN_SDK_VERSION} \ -framework Accelerate \ + -framework CoreFoundation \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ @@ -555,7 +557,8 @@ ifeq ($(TARGET),IOS) LDFLAGS := -arch arm64 \ -fembed-bitcode \ -miphoneos-version-min=${MIN_SDK_VERSION} \ - -framework Accelerate \ + -framework Accelerate \ + -framework CoreFoundation \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ @@ -579,7 +582,8 @@ ifeq ($(TARGET),IOS) LDFLAGS := -arch i386 \ -fembed-bitcode \ -mios-simulator-version-min=${MIN_SDK_VERSION} \ - -framework Accelerate \ + -framework Accelerate \ + -framework CoreFoundation \ -Xlinker -S \ -Xlinker -x \ -Xlinker -dead_strip \ @@ -603,6 +607,7 @@ ifeq ($(TARGET),IOS) -fembed-bitcode \ -mios-simulator-version-min=${MIN_SDK_VERSION} \ -framework Accelerate \ + -framework CoreFoundation \ -Xlinker -S \ 
-Xlinker -x \ -Xlinker -dead_strip \ @@ -629,6 +634,9 @@ BENCHMARK_NAME := $(BINDIR)benchmark CORE_CC_ALL_SRCS := \ $(ABSL_CC_SRCS) \ +tensorflow/c/c_api.cc \ +tensorflow/c/kernels.cc \ +tensorflow/c/tf_status_helper.cc \ $(wildcard tensorflow/core/*.cc) \ $(wildcard tensorflow/core/common_runtime/*.cc) \ $(wildcard tensorflow/core/framework/*.cc) \ @@ -642,6 +650,10 @@ $(wildcard tensorflow/core/platform/*/*/*.cc) \ $(wildcard tensorflow/core/util/*.cc) \ $(wildcard tensorflow/core/util/*/*.cc) \ $(wildcard tensorflow/contrib/makefile/downloads/double_conversion/double-conversion/*.cc) \ +tensorflow/core/profiler/internal/profiler_interface.cc \ +tensorflow/core/profiler/internal/traceme_recorder.cc \ +tensorflow/core/profiler/lib/profiler_session.cc \ +tensorflow/core/profiler/lib/traceme.cc \ tensorflow/core/util/version_info.cc # Remove duplicates (for version_info.cc) CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) @@ -677,7 +689,7 @@ $(wildcard tensorflow/core/platform/windows/*) \ $(wildcard tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.*) \ $(wildcard tensorflow/core/grappler/inputs/file_input_yielder.*) \ $(wildcard tensorflow/core/grappler/clusters/single_machine.*) \ -tensorflow/core/util/cuda_kernel_helper_test.cu.cc +tensorflow/core/util/gpu_kernel_helper_test.cu.cc CORE_CC_EXCLUDE_SRCS := \ $(CORE_CC_EXCLUDE_SRCS_NON_GPU) \ @@ -893,6 +905,14 @@ $(HOST_OBJDIR)%.pb.o: $(HOST_GENDIR)%.pb.cc # we compile the C++. $(PROTO_TEXT_OBJS) : $(PROTO_TEXT_PB_H_FILES) +# Ensures we link CoreFoundation as it is used for time library when building +# for Mac. +ifeq ($(TARGET),IOS) + ifeq ($(IOS_ARCH),X86_64) + HOST_LDOPTS += -framework CoreFoundation + endif +endif + # Runs proto_text to generate C++ source files from protos. $(PROTO_TEXT): $(PROTO_TEXT_OBJS) $(PROTO_TEXT_PB_H_FILES) @mkdir -p $(dir $@) diff --git a/tensorflow/contrib/makefile/compile_ios_protobuf.sh b/tensorflow/contrib/makefile/compile_ios_protobuf.sh index 8fa20213633..d2fbf696f8f 100755 --- a/tensorflow/contrib/makefile/compile_ios_protobuf.sh +++ b/tensorflow/contrib/makefile/compile_ios_protobuf.sh @@ -24,11 +24,11 @@ fi usage() { echo "Usage: $(basename "$0") [-a]" echo "-a [build_arch] build for specified arch comma separate for multiple archs (eg: x86_64,arm64)" - echo "default arch i386, x86_64, armv7, armv7s, arm64" + echo "default arch x86_64, armv7, armv7s, arm64" exit 1 } -BUILD_TARGET="i386 x86_64 armv7 armv7s arm64" +BUILD_TARGET="x86_64 armv7 armv7s arm64" while getopts "a:" opt_name; do case "$opt_name" in a) BUILD_TARGET="${OPTARG}";; @@ -115,39 +115,6 @@ package_pb_library() { build_target() { case "$1" in - i386) make distclean - ./configure \ - --host=i386-apple-${OSX_VERSION} \ - --disable-shared \ - --enable-cross-compile \ - --with-protoc="${PROTOC_PATH}" \ - --prefix=${LIBDIR}/iossim_386 \ - --exec-prefix=${LIBDIR}/iossim_386 \ - "CFLAGS=${CFLAGS} \ - -mios-simulator-version-min=${MIN_SDK_VERSION} \ - -arch i386 \ - -fembed-bitcode \ - -isysroot ${IPHONESIMULATOR_SYSROOT}" \ - "CXX=${CXX}" \ - "CXXFLAGS=${CXXFLAGS} \ - -mios-simulator-version-min=${MIN_SDK_VERSION} \ - -arch i386 \ - -fembed-bitcode \ - -isysroot \ - ${IPHONESIMULATOR_SYSROOT}" \ - LDFLAGS="-arch i386 \ - -fembed-bitcode \ - -mios-simulator-version-min=${MIN_SDK_VERSION} \ - ${LDFLAGS} \ - -L${IPHONESIMULATOR_SYSROOT}/usr/lib/ \ - -L${IPHONESIMULATOR_SYSROOT}/usr/lib/system" \ - "LIBS=${LIBS}" - make -j"${JOB_COUNT}" - make install - - package_pb_library "iossim_386" - ;; - x86_64) make distclean ./configure \ 
--host=x86_64-apple-${OSX_VERSION} \ diff --git a/tensorflow/contrib/makefile/compile_ios_tensorflow.sh b/tensorflow/contrib/makefile/compile_ios_tensorflow.sh index ae82163e117..3822f0d7da7 100755 --- a/tensorflow/contrib/makefile/compile_ios_tensorflow.sh +++ b/tensorflow/contrib/makefile/compile_ios_tensorflow.sh @@ -46,11 +46,11 @@ fi usage() { echo "Usage: $(basename "$0") [-a]" echo "-a [build_arch] build for specified arch comma separate for multiple archs (eg: x86_64,arm64)" - echo "default is [i386, x86_64, armv7, armv7s, arm64]" + echo "default is [x86_64, armv7, armv7s, arm64]" exit 1 } -BUILD_TARGET="i386 x86_64 armv7 armv7s arm64" +BUILD_TARGET="x86_64 armv7 armv7s arm64" while getopts "a:f:h:n:" opt_name; do case "$opt_name" in a) BUILD_TARGET="${OPTARG}";; @@ -126,18 +126,6 @@ case "$1" in fi package_tf_library "arm64" ;; - i386) - make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \ - TARGET=IOS IOS_ARCH=I386 LIB_NAME=${LIB_PREFIX}-i386.a \ - OPTFLAGS="${BUILD_OPT}" HOST_NSYNC_LIB="${NSYNC_HOST}" \ - TARGET_NSYNC_LIB="${NSYNC_TARGET}" - if [ $? -ne 0 ] - then - echo "i386 compilation failed." - exit 1 - fi - package_tf_library "i386" - ;; x86_64) make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \ TARGET=IOS IOS_ARCH=X86_64 LIB_NAME=${LIB_PREFIX}-x86_64.a \ diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh index cb4c94d92fc..e154b8223c6 100755 --- a/tensorflow/contrib/makefile/compile_nsync.sh +++ b/tensorflow/contrib/makefile/compile_nsync.sh @@ -22,7 +22,7 @@ set -e prog=compile_nsync.sh android_api_version=21 default_android_arch=armeabi-v7a -default_ios_arch="i386 x86_64 armv7 armv7s arm64" +default_ios_arch="x86_64 armv7 armv7s arm64" usage="usage: $prog [-t linux|ios|android|macos|native] [-a architecture] [-v android_api_version] @@ -130,7 +130,7 @@ for arch in $archs; do ios) arch_flags= case "$arch" in - i386|x86_64) + x86_64) arch_flags="$arch_flags -mios-simulator-version-min=8.0" arch_flags="$arch_flags -isysroot $(xcrun --sdk iphonesimulator --show-sdk-path)" ;; diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index af3c541dc21..7566733680c 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -27,9 +27,9 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'http://mirror.tensorflow.org/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" -NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +NSYNC_URL="$(grep -o 'http://mirror.tensorflow.org/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" # Note: The protobuf repo needs to be cloned due to its submodules. # These variables contain the GitHub repo and the sha, from `tensorflow/workspace.bzl`, @@ -37,8 +37,8 @@ NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\. 
readonly PROTOBUF_REPO="https://github.com/protocolbuffers/protobuf.git" readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/archive/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1 | awk '{print substr($0, index($0, "archive") + 8, index($0, "tar") - index($0, "archive") - 9) }')" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once -# the archive has been propagated in mirror.bazel.build. +# TODO (yongtang): Replace the following with 'http://mirror.tensorflow.org/github.com/google/re2/.*tar\.gz' once +# the archive has been propagated in mirror.tensorflow.org. RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)" @@ -46,8 +46,8 @@ ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_ CUB_URL="$(grep -o 'https.*cub/archive.*zip' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" # Required for TensorFlow Lite Flex runtime. -FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" -FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz" +FARMHASH_URL="http://mirror.tensorflow.org/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" +FLATBUFFERS_URL="http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.11.0.tar.gz" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 1c1460ce77c..93ce366b7eb 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -40,6 +40,7 @@ tensorflow/core/protobuf/saver.pb.cc tensorflow/core/protobuf/struct.pb.cc tensorflow/core/protobuf/tensorflow_server.pb.cc tensorflow/core/protobuf/verifier_config.pb.cc +tensorflow/core/protobuf/trace_events.pb.cc tensorflow/core/util/event.pb.cc tensorflow/core/util/memmapped_file_system.pb.cc tensorflow/core/util/saved_tensor_slice.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 5def632e8a7..bb14a539958 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -41,6 +41,7 @@ tensorflow/core/protobuf/struct.pb.h tensorflow/core/protobuf/tensor_bundle.pb.h tensorflow/core/protobuf/tensorflow_server.pb.h tensorflow/core/protobuf/verifier_config.pb.h +tensorflow/core/protobuf/trace_events.pb.h tensorflow/core/util/event.pb.h tensorflow/core/util/memmapped_file_system.pb.h tensorflow/core/util/saved_tensor_slice.pb.h diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 0ed87544ce3..ac54c0c3a80 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -1,3 +1,4 @@ +tensorflow/c/kernels/bitcast_op.cc tensorflow/contrib/boosted_trees/ops/model_ops.cc tensorflow/contrib/boosted_trees/ops/prediction_ops.cc tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -124,6 +125,7 @@ 
tensorflow/core/kernels/fill_functor.cc tensorflow/core/kernels/fft_ops.cc tensorflow/core/kernels/function_ops.cc tensorflow/core/kernels/fused_batch_norm_op.cc +tensorflow/core/kernels/fused_eigen_output_kernels.cc tensorflow/core/kernels/gather_functor.cc tensorflow/core/kernels/gather_nd_op.cc tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc @@ -148,6 +150,8 @@ tensorflow/core/kernels/lookup_table_op.cc tensorflow/core/kernels/lookup_util.cc tensorflow/core/kernels/lrn_op.cc tensorflow/core/kernels/matmul_op.cc +tensorflow/core/kernels/matrix_diag_op.cc +tensorflow/core/kernels/matrix_set_diag_op.cc tensorflow/core/kernels/maxpooling_op.cc tensorflow/core/kernels/meta_support.cc tensorflow/core/kernels/mfcc.cc @@ -256,6 +260,7 @@ tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack_ops.cc +tensorflow/core/kernels/stateful_random_ops.cc tensorflow/core/kernels/stateless_random_ops.cc tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index 13e3b6422d1..57c977bed4e 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -29,6 +29,7 @@ tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc tensorflow/core/protobuf/tensor_bundle.pb_text.cc +tensorflow/core/protobuf/trace_events.pb_text.cc tensorflow/core/protobuf/verifier_config.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/util/saved_tensor_slice.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index deb6a5b9402..8e68ac46d9f 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -46,6 +46,7 @@ tensorflow/core/protobuf/saver.proto tensorflow/core/protobuf/struct.proto tensorflow/core/protobuf/tensor_bundle.proto tensorflow/core/protobuf/tensorflow_server.proto +tensorflow/core/protobuf/trace_events.proto tensorflow/core/protobuf/verifier_config.proto tensorflow/core/util/event.proto tensorflow/core/util/memmapped_file_system.proto diff --git a/tensorflow/contrib/meta_graph_transform/BUILD b/tensorflow/contrib/meta_graph_transform/BUILD index 24400789f8a..d667b8e1449 100644 --- a/tensorflow/contrib/meta_graph_transform/BUILD +++ b/tensorflow/contrib/meta_graph_transform/BUILD @@ -36,6 +36,7 @@ py_test( name = "meta_graph_transform_test", size = "small", srcs = ["meta_graph_transform_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", visibility = ["//visibility:private"], deps = [ diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 21cd34f73ff..858fd1ede45 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -31,7 +31,6 @@ py_library( "//tensorflow/python:check_ops", "//tensorflow/python:confusion_matrix", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:histogram_ops", "//tensorflow/python:init_ops", @@ -44,6 +43,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:weights_broadcast_ops", + "//tensorflow/python/distribute:distribute_lib", 
"//tensorflow/python/ops/distributions", ], ) @@ -51,6 +51,7 @@ py_library( py_test( name = "classification_test", srcs = ["python/metrics/classification_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":metrics_py", @@ -64,6 +65,7 @@ py_test( name = "histogram_ops_test", size = "medium", srcs = ["python/kernel_tests/histogram_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":metrics_py", @@ -78,6 +80,7 @@ py_test( py_test( name = "metric_ops_test", srcs = ["python/ops/metric_ops_test.py"], + python_version = "PY2", shard_count = 30, srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 @@ -103,6 +106,7 @@ py_test( name = "metric_ops_large_test", size = "large", srcs = ["python/ops/metric_ops_large_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 deps = [ diff --git a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py index 81bbe935e74..1fb15bfcd6d 100644 --- a/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py +++ b/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py @@ -24,7 +24,7 @@ from tensorflow.python.ops import confusion_matrix as cm def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, name=None, weights=None): - """Deprecated. Use tf.confusion_matrix instead.""" + """Deprecated. Use tf.math.confusion_matrix instead.""" return cm.confusion_matrix(labels=labels, predictions=predictions, num_classes=num_classes, dtype=dtype, name=name, weights=weights) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index ece246b7c28..eae04c7ba3e 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -63,11 +63,10 @@ def streaming_true_positives(predictions, labels: The ground truth values, a `Tensor` whose dimensions must match `predictions`. Will be cast to `bool`. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions - must be either `1`, or the same as the corresponding `labels` - dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric value + variable should be added to. updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -109,11 +108,10 @@ def streaming_true_negatives(predictions, labels: The ground truth values, a `Tensor` whose dimensions must match `predictions`. Will be cast to `bool`. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions - must be either `1`, or the same as the corresponding `labels` - dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). 
+ metrics_collections: An optional list of collections that the metric value + variable should be added to. updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -155,11 +153,10 @@ def streaming_false_positives(predictions, labels: The ground truth values, a `Tensor` whose dimensions must match `predictions`. Will be cast to `bool`. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions - must be either `1`, or the same as the corresponding `labels` - dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric value + variable should be added to. updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -201,11 +198,10 @@ def streaming_false_negatives(predictions, labels: The ground truth values, a `Tensor` whose dimensions must match `predictions`. Will be cast to `bool`. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions - must be either `1`, or the same as the corresponding `labels` - dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric value + variable should be added to. updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -253,10 +249,10 @@ def streaming_mean(values, weights: `Tensor` whose rank is either 0, or the same rank as `values`, and must be broadcastable to `values` (i.e., all dimensions must be either `1`, or the same as the corresponding `values` dimension). - metrics_collections: An optional list of collections that `mean` - should be added to. - updates_collections: An optional list of collections that `update_op` - should be added to. + metrics_collections: An optional list of collections that `mean` should be + added to. + updates_collections: An optional list of collections that `update_op` should + be added to. name: An optional variable_scope name. Returns: @@ -307,10 +303,10 @@ def streaming_mean_tensor(values, weights: `Tensor` whose rank is either 0, or the same rank as `values`, and must be broadcastable to `values` (i.e., all dimensions must be either `1`, or the same as the corresponding `values` dimension). - metrics_collections: An optional list of collections that `mean` - should be added to. - updates_collections: An optional list of collections that `update_op` - should be added to. + metrics_collections: An optional list of collections that `mean` should be + added to. + updates_collections: An optional list of collections that `update_op` should + be added to. name: An optional variable_scope name. 
Returns: @@ -477,8 +473,8 @@ def streaming_recall(predictions, weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that `recall` should - be added to. + metrics_collections: An optional list of collections that `recall` should be + added to. updates_collections: An optional list of collections that `update_op` should be added to. name: An optional variable_scope name. @@ -535,7 +531,7 @@ def streaming_false_positive_rate(predictions, `labels`, and must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). metrics_collections: An optional list of collections that - `false_positive_rate` should be added to. + `false_positive_rate` should be added to. updates_collections: An optional list of collections that `update_op` should be added to. name: An optional variable_scope name. @@ -712,9 +708,8 @@ def _streaming_confusion_matrix_at_thresholds(predictions, to `bool`. thresholds: A python list or tuple of float thresholds in `[0, 1]`. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions - must be either `1`, or the same as the corresponding `labels` - dimension). + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`, default to all four. @@ -772,7 +767,7 @@ def _streaming_confusion_matrix_at_thresholds(predictions, if weights is not None: broadcast_weights = weights_broadcast_ops.broadcast_weights( - math_ops.to_float(weights), predictions) + math_ops.cast(weights, dtypes.float32), predictions) weights_tiled = array_ops.tile( array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1]) thresh_tiled.get_shape().assert_is_compatible_with( @@ -784,51 +779,51 @@ def _streaming_confusion_matrix_at_thresholds(predictions, update_ops = {} if 'tp' in includes: - true_positives = metrics_impl.metric_variable( - [num_thresholds], dtypes.float32, name='true_positives') - is_true_positive = math_ops.to_float( - math_ops.logical_and(label_is_pos, pred_is_pos)) + true_positives = metrics_impl.metric_variable([num_thresholds], + dtypes.float32, + name='true_positives') + is_true_positive = math_ops.cast( + math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32) if weights_tiled is not None: is_true_positive *= weights_tiled - update_ops['tp'] = state_ops.assign_add(true_positives, - math_ops.reduce_sum( - is_true_positive, 1)) + update_ops['tp'] = state_ops.assign_add( + true_positives, math_ops.reduce_sum(is_true_positive, 1)) values['tp'] = true_positives if 'fn' in includes: - false_negatives = metrics_impl.metric_variable( - [num_thresholds], dtypes.float32, name='false_negatives') - is_false_negative = math_ops.to_float( - math_ops.logical_and(label_is_pos, pred_is_neg)) + false_negatives = metrics_impl.metric_variable([num_thresholds], + dtypes.float32, + name='false_negatives') + is_false_negative = math_ops.cast( + math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32) if weights_tiled is not None: is_false_negative *= weights_tiled - update_ops['fn'] = state_ops.assign_add(false_negatives, - math_ops.reduce_sum( - is_false_negative, 1)) + 
update_ops['fn'] = state_ops.assign_add( + false_negatives, math_ops.reduce_sum(is_false_negative, 1)) values['fn'] = false_negatives if 'tn' in includes: - true_negatives = metrics_impl.metric_variable( - [num_thresholds], dtypes.float32, name='true_negatives') - is_true_negative = math_ops.to_float( - math_ops.logical_and(label_is_neg, pred_is_neg)) + true_negatives = metrics_impl.metric_variable([num_thresholds], + dtypes.float32, + name='true_negatives') + is_true_negative = math_ops.cast( + math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32) if weights_tiled is not None: is_true_negative *= weights_tiled - update_ops['tn'] = state_ops.assign_add(true_negatives, - math_ops.reduce_sum( - is_true_negative, 1)) + update_ops['tn'] = state_ops.assign_add( + true_negatives, math_ops.reduce_sum(is_true_negative, 1)) values['tn'] = true_negatives if 'fp' in includes: - false_positives = metrics_impl.metric_variable( - [num_thresholds], dtypes.float32, name='false_positives') - is_false_positive = math_ops.to_float( - math_ops.logical_and(label_is_neg, pred_is_pos)) + false_positives = metrics_impl.metric_variable([num_thresholds], + dtypes.float32, + name='false_positives') + is_false_positive = math_ops.cast( + math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32) if weights_tiled is not None: is_false_positive *= weights_tiled - update_ops['fp'] = state_ops.assign_add(false_positives, - math_ops.reduce_sum( - is_false_positive, 1)) + update_ops['fp'] = state_ops.assign_add( + false_positives, math_ops.reduce_sum(is_false_positive, 1)) values['fp'] = false_positives return values, update_ops @@ -1020,7 +1015,7 @@ def streaming_auc(predictions, updates_collections: An optional list of collections that `update_op` should be added to. curve: Specifies the name of the curve to be computed, 'ROC' [default] or - 'PR' for the Precision-Recall-curve. + 'PR' for the Precision-Recall-curve. name: An optional variable_scope name. Returns: @@ -1141,8 +1136,8 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None): # exception seems excessive) so we return 0, otherwise we finish computing. return control_flow_ops.cond( math_ops.logical_or( - math_ops.equal(total_positive, 0), math_ops.equal( - total_positive, total_weight)), + math_ops.equal(total_positive, 0), + math_ops.equal(total_positive, total_weight)), true_fn=lambda: array_ops.constant(0, dtypes.float64), false_fn=continue_computing_dynamic_auc) @@ -1271,8 +1266,7 @@ def _compute_placement_auc(labels, predictions, weights, alpha, # Count the total number of positive and negative labels in the input. total_0 = math_ops.reduce_sum( math_ops.cast(1 - labels, weights.dtype) * weights) - total_1 = math_ops.reduce_sum( - math_ops.cast(labels, weights.dtype) * weights) + total_1 = math_ops.reduce_sum(math_ops.cast(labels, weights.dtype) * weights) # Sort the predictions ascending, as well as # (i) the corresponding labels and @@ -1308,10 +1302,10 @@ def _compute_placement_auc(labels, predictions, weights, alpha, # These cumulative sums of weights importantly exclude the current weight # sums. - cum_weight_totals_for_true = math_ops.cumsum(weight_totals_for_true, - exclusive=True) - cum_weight_totals_for_false = math_ops.cumsum(weight_totals_for_false, - exclusive=True) + cum_weight_totals_for_true = math_ops.cumsum( + weight_totals_for_true, exclusive=True) + cum_weight_totals_for_false = math_ops.cumsum( + weight_totals_for_false, exclusive=True) # Compute placement values using the formula. 
Values with the same segmented # indices and labels share the same placement values. @@ -1334,20 +1328,20 @@ def _compute_placement_auc(labels, predictions, weights, alpha, placements_for_false * float_labels_for_false) # Split placement values by labeled groups. - placement_values_0 = placement_values * math_ops.cast( - 1 - ordered_labels, weights.dtype) - weights_0 = ordered_weights * math_ops.cast( - 1 - ordered_labels, weights.dtype) - placement_values_1 = placement_values * math_ops.cast( - ordered_labels, weights.dtype) - weights_1 = ordered_weights * math_ops.cast( - ordered_labels, weights.dtype) + placement_values_0 = placement_values * math_ops.cast(1 - ordered_labels, + weights.dtype) + weights_0 = ordered_weights * math_ops.cast(1 - ordered_labels, weights.dtype) + placement_values_1 = placement_values * math_ops.cast(ordered_labels, + weights.dtype) + weights_1 = ordered_weights * math_ops.cast(ordered_labels, weights.dtype) # Calculate AUC using placement values - auc_0 = (math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) / - (total_0 + _EPSILON)) - auc_1 = (math_ops.reduce_sum(weights_1 * (placement_values_1)) / - (total_1 + _EPSILON)) + auc_0 = ( + math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) / + (total_0 + _EPSILON)) + auc_1 = ( + math_ops.reduce_sum(weights_1 * (placement_values_1)) / + (total_1 + _EPSILON)) auc = array_ops.where(math_ops.less(total_0, total_1), auc_1, auc_0) # Calculate variance and standard error using the placement values. @@ -1356,10 +1350,11 @@ def _compute_placement_auc(labels, predictions, weights, alpha, weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) / (total_0 - 1. + _EPSILON)) var_1 = ( - math_ops.reduce_sum(weights_1 * math_ops.squared_difference( - placement_values_1, auc_1)) / (total_1 - 1. + _EPSILON)) - auc_std_err = math_ops.sqrt( - (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON))) + math_ops.reduce_sum( + weights_1 * math_ops.squared_difference(placement_values_1, auc_1)) / + (total_1 - 1. + _EPSILON)) + auc_std_err = math_ops.sqrt((var_0 / (total_0 + _EPSILON)) + + (var_1 / (total_1 + _EPSILON))) # Calculate asymptotic normal confidence intervals std_norm_dist = Normal(loc=0., scale=1.) @@ -1369,6 +1364,7 @@ def _compute_placement_auc(labels, predictions, weights, alpha, std_err = auc_std_err / (auc * (1. - auc + _EPSILON)) transformed_auc_lower = estimate + (z_value * std_err) transformed_auc_upper = estimate - (z_value * std_err) + def inverse_logit_transformation(x): exp_negative = math_ops.exp(math_ops.negative(x)) return 1. / (1. + exp_negative + _EPSILON) @@ -1386,20 +1382,18 @@ def _compute_placement_auc(labels, predictions, weights, alpha, lower = array_ops.where( math_ops.logical_or( math_ops.equal(auc, array_ops.ones_like(auc)), - math_ops.equal(auc, array_ops.zeros_like(auc))), - auc, auc_lower) + math_ops.equal(auc, array_ops.zeros_like(auc))), auc, auc_lower) upper = array_ops.where( math_ops.logical_or( math_ops.equal(auc, array_ops.ones_like(auc)), - math_ops.equal(auc, array_ops.zeros_like(auc))), - auc, auc_upper) + math_ops.equal(auc, array_ops.zeros_like(auc))), auc, auc_upper) # If all the labels are the same, AUC isn't well-defined (but raising an # exception seems excessive) so we return 0, otherwise we finish computing. 
trivial_value = array_ops.constant(0.0) return AucData(*control_flow_ops.cond( - is_valid, lambda: [auc, lower, upper], lambda: [trivial_value]*3)) + is_valid, lambda: [auc, lower, upper], lambda: [trivial_value] * 3)) def auc_with_confidence_intervals(labels, @@ -1486,12 +1480,13 @@ def auc_with_confidence_intervals(labels, ]): preds_accum, update_preds = streaming_concat( predictions, name='concat_preds') - labels_accum, update_labels = streaming_concat(labels, - name='concat_labels') + labels_accum, update_labels = streaming_concat( + labels, name='concat_labels') weights_accum, update_weights = streaming_concat( weights, name='concat_weights') - update_op_for_valid_case = control_flow_ops.group( - update_labels, update_preds, update_weights) + update_op_for_valid_case = control_flow_ops.group(update_labels, + update_preds, + update_weights) # Only perform updates if this case is valid. all_labels_positive_or_0 = math_ops.logical_and( @@ -1502,8 +1497,8 @@ def auc_with_confidence_intervals(labels, sums_of_weights_at_least_1) update_op = control_flow_ops.cond( - sums_of_weights_at_least_1, - lambda: update_op_for_valid_case, control_flow_ops.no_op) + sums_of_weights_at_least_1, lambda: update_op_for_valid_case, + control_flow_ops.no_op) auc = _compute_placement_auc( labels_accum, @@ -1549,12 +1544,12 @@ def precision_recall_at_equal_thresholds(labels, labels: A bool `Tensor` whose shape matches `predictions`. predictions: A floating point `Tensor` of arbitrary shape and whose values are in the range `[0, 1]`. - weights: Optional; If provided, a `Tensor` that has the same dtype as, - and broadcastable to, `predictions`. This tensor is multiplied by counts. - num_thresholds: Optional; Number of thresholds, evenly distributed in - `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins - is 1 less than `num_thresholds`. Using an even `num_thresholds` value - instead of an odd one may yield unfriendly edges for bins. + weights: Optional; If provided, a `Tensor` that has the same dtype as, and + broadcastable to, `predictions`. This tensor is multiplied by counts. + num_thresholds: Optional; Number of thresholds, evenly distributed in `[0, + 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins is 1 + less than `num_thresholds`. Using an even `num_thresholds` value instead + of an odd one may yield unfriendly edges for bins. use_locking: Optional; If True, the op will be protected by a lock. Otherwise, the behavior is undefined, but may exhibit less contention. Defaults to True. @@ -1621,7 +1616,7 @@ def precision_recall_at_equal_thresholds(labels, f_labels = math_ops.cast(labels, agg_dtype) weights = math_ops.cast(weights, agg_dtype) - true_labels = f_labels * weights + true_labels = f_labels * weights false_labels = (1.0 - f_labels) * weights # Flatten predictions and labels. @@ -1650,7 +1645,7 @@ def precision_recall_at_equal_thresholds(labels, # thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] # Given a prediction value p, we can map it to its bucket by # bucket_index(p) = floor( p * (num_thresholds - 1) ) - # so we can use tf.scatter_add() to update the buckets in one pass. + # so we can use tf.compat.v1.scatter_add() to update the buckets in one pass. # # This implementation exhibits a run time and space complexity of O(T + N), # where T is the number of thresholds and N is the size of predictions. 
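As an editorial aside (not part of this patch): the bucketing scheme described in the comment above can be sketched in a few lines of NumPy. The helper name, the epsilon guard, and the plain suffix-sum precision/recall below are illustrative assumptions rather than the op's exact behavior, but they show why one scatter-add pass plus a cumulative sum gives the stated O(T + N) cost.

```python
# Illustrative sketch only (not from the patch): precision/recall at evenly
# spaced thresholds via bucketing, mirroring the comment above.
import numpy as np

def pr_at_equal_thresholds(labels, predictions, num_thresholds=201):
    """labels: 0/1 array; predictions: float array in [0, 1]."""
    labels = labels.astype(np.float64)
    # bucket_index(p) = floor(p * (num_thresholds - 1)), as in the comment.
    buckets = np.floor(predictions * (num_thresholds - 1)).astype(np.int64)
    tp_buckets = np.zeros(num_thresholds)
    fp_buckets = np.zeros(num_thresholds)
    # One pass over the data: scatter-add positive/negative counts per bucket.
    np.add.at(tp_buckets, buckets, labels)
    np.add.at(fp_buckets, buckets, 1.0 - labels)
    # Suffix sums give TP/FP counts for "prediction >= threshold_i".
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    fn = tp[0] - tp  # tp[0] is the total number of positives.
    precision = tp / np.maximum(tp + fp, 1e-7)  # epsilon is an assumption
    recall = tp / np.maximum(tp + fn, 1e-7)
    return precision, recall
```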
@@ -1662,10 +1657,12 @@ def precision_recall_at_equal_thresholds(labels, math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32) with ops.name_scope('variables'): - tp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], agg_dtype, name='tp_buckets') - fp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], agg_dtype, name='fp_buckets') + tp_buckets_v = metrics_impl.metric_variable([num_thresholds], + agg_dtype, + name='tp_buckets') + fp_buckets_v = metrics_impl.metric_variable([num_thresholds], + agg_dtype, + name='fp_buckets') with ops.name_scope('update_op'): update_tp = state_ops.scatter_add( @@ -2164,7 +2161,7 @@ def streaming_recall_at_k(predictions, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) + in_top_k = math_ops.cast(nn.in_top_k(predictions, labels, k), dtypes.float32) return streaming_mean(in_top_k, weights, metrics_collections, updates_collections, name or _at_k_name('recall', k)) @@ -2207,17 +2204,17 @@ def streaming_sparse_recall_at_k(predictions, If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. Args: - predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where - N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. - The final dimension contains the logit values for each class. [D1, ... DN] + predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where N >= + 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. The + final dimension contains the logit values for each class. [D1, ... DN] must match `labels`. - labels: `int64` `Tensor` or `SparseTensor` with shape - [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of - target classes for the associated prediction. Commonly, N=1 and `labels` - has shape [batch_size, num_labels]. [D1, ... DN] must match `predictions`. - Values should be in range [0, num_classes), where num_classes is the last - dimension of `predictions`. Values outside this range always count - towards `false_negative_at_`. + labels: `int64` `Tensor` or `SparseTensor` with shape [D1, ... DN, + num_labels], where N >= 1 and num_labels is the number of target classes + for the associated prediction. Commonly, N=1 and `labels` has shape + [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values + should be in range [0, num_classes), where num_classes is the last + dimension of `predictions`. Values outside this range always count towards + `false_negative_at_`. k: Integer, k for @k metric. class_id: Integer class ID for which we want binary metrics. This should be in range [0, num_classes), where num_classes is the last dimension of @@ -2226,10 +2223,10 @@ def streaming_sparse_recall_at_k(predictions, `labels`. If the latter, it must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that values should - be added to. - updates_collections: An optional list of collections that updates should - be added to. + metrics_collections: An optional list of collections that values should be + added to. + updates_collections: An optional list of collections that updates should be + added to. name: Name of new update operation, and namespace for other dependent ops. Returns: @@ -2295,17 +2292,16 @@ def streaming_sparse_precision_at_k(predictions, If `weights` is `None`, weights default to 1. 
Use weights of 0 to mask values. Args: - predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where - N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. - The final dimension contains the logit values for each class. [D1, ... DN] + predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where N >= + 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. The + final dimension contains the logit values for each class. [D1, ... DN] must match `labels`. - labels: `int64` `Tensor` or `SparseTensor` with shape - [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of - target classes for the associated prediction. Commonly, N=1 and `labels` - has shape [batch_size, num_labels]. [D1, ... DN] must match - `predictions`. Values should be in range [0, num_classes), where - num_classes is the last dimension of `predictions`. Values outside this - range are ignored. + labels: `int64` `Tensor` or `SparseTensor` with shape [D1, ... DN, + num_labels], where N >= 1 and num_labels is the number of target classes + for the associated prediction. Commonly, N=1 and `labels` has shape + [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values + should be in range [0, num_classes), where num_classes is the last + dimension of `predictions`. Values outside this range are ignored. k: Integer, k for @k metric. class_id: Integer class ID for which we want binary metrics. This should be in range [0, num_classes], where num_classes is the last dimension of @@ -2315,10 +2311,10 @@ def streaming_sparse_precision_at_k(predictions, `labels`. If the latter, it must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that values should - be added to. - updates_collections: An optional list of collections that updates should - be added to. + metrics_collections: An optional list of collections that values should be + added to. + updates_collections: An optional list of collections that updates should be + added to. name: Name of new update operation, and namespace for other dependent ops. Returns: @@ -2381,17 +2377,16 @@ def streaming_sparse_precision_at_top_k(top_k_predictions, If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. Args: - top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where - N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. - The final dimension contains the indices of top-k labels. [D1, ... DN] - must match `labels`. - labels: `int64` `Tensor` or `SparseTensor` with shape - [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of - target classes for the associated prediction. Commonly, N=1 and `labels` - has shape [batch_size, num_labels]. [D1, ... DN] must match - `top_k_predictions`. Values should be in range [0, num_classes), where - num_classes is the last dimension of `predictions`. Values outside this - range are ignored. + top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. + Commonly, N=1 and top_k_predictions has shape [batch size, k]. The final + dimension contains the indices of top-k labels. [D1, ... DN] must match + `labels`. + labels: `int64` `Tensor` or `SparseTensor` with shape [D1, ... DN, + num_labels], where N >= 1 and num_labels is the number of target classes + for the associated prediction. Commonly, N=1 and `labels` has shape + [batch_size, num_labels]. [D1, ... 
DN] must match `top_k_predictions`. + Values should be in range [0, num_classes), where num_classes is the last + dimension of `predictions`. Values outside this range are ignored. class_id: Integer class ID for which we want binary metrics. This should be in range [0, num_classes), where num_classes is the last dimension of `predictions`. If `class_id` is outside this range, the method returns @@ -2400,10 +2395,10 @@ def streaming_sparse_precision_at_top_k(top_k_predictions, `labels`. If the latter, it must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that values should - be added to. - updates_collections: An optional list of collections that updates should - be added to. + metrics_collections: An optional list of collections that values should be + added to. + updates_collections: An optional list of collections that updates should be + added to. name: Name of new update operation, and namespace for other dependent ops. Returns: @@ -2464,17 +2459,17 @@ def sparse_recall_at_top_k(labels, If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. Args: - labels: `int64` `Tensor` or `SparseTensor` with shape - [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of - target classes for the associated prediction. Commonly, N=1 and `labels` - has shape [batch_size, num_labels]. [D1, ... DN] must match - `top_k_predictions`. Values should be in range [0, num_classes), where - num_classes is the last dimension of `predictions`. Values outside this - range always count towards `false_negative_at_`. - top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where - N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. - The final dimension contains the indices of top-k labels. [D1, ... DN] - must match `labels`. + labels: `int64` `Tensor` or `SparseTensor` with shape [D1, ... DN, + num_labels], where N >= 1 and num_labels is the number of target classes + for the associated prediction. Commonly, N=1 and `labels` has shape + [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`. + Values should be in range [0, num_classes), where num_classes is the last + dimension of `predictions`. Values outside this range always count towards + `false_negative_at_`. + top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. + Commonly, N=1 and top_k_predictions has shape [batch size, k]. The final + dimension contains the indices of top-k labels. [D1, ... DN] must match + `labels`. class_id: Integer class ID for which we want binary metrics. This should be in range [0, num_classes), where num_classes is the last dimension of `predictions`. If class_id is outside this range, the method returns NAN. @@ -2482,10 +2477,10 @@ def sparse_recall_at_top_k(labels, `labels`. If the latter, it must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that values should - be added to. - updates_collections: An optional list of collections that updates should - be added to. + metrics_collections: An optional list of collections that values should be + added to. + updates_collections: An optional list of collections that updates should be + added to. name: Name of new update operation, and namespace for other dependent ops. 
Returns: @@ -2523,9 +2518,9 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name, fn: The number of false negatives. precision: The precision for which the recall will be calculated. name: An optional variable_scope name. - strict_mode: If true and there exists a threshold where the precision is - no smaller than the target precision, return the corresponding recall at - the threshold. Otherwise, return 0. If false, find the threshold where the + strict_mode: If true and there exists a threshold where the precision is no + smaller than the target precision, return the corresponding recall at the + threshold. Otherwise, return 0. If false, find the threshold where the precision is closest to the target precision and return the recall at the threshold. @@ -2596,8 +2591,8 @@ def recall_at_precision(labels, be either `1`, or the same as the corresponding `labels` dimension). num_thresholds: The number of thresholds to use for matching the given `precision`. - metrics_collections: An optional list of collections that `recall` - should be added to. + metrics_collections: An optional list of collections that `recall` should be + added to. updates_collections: An optional list of collections that `update_op` should be added to. name: An optional variable_scope name. @@ -2747,13 +2742,12 @@ def precision_at_recall(labels, # Now we have the threshold at which to compute precision: return math_ops.div(tp[tf_index] + kepsilon, - tp[tf_index] + fp[tf_index] + kepsilon, - name) + tp[tf_index] + fp[tf_index] + kepsilon, name) - precision_value = compute_precision_at_recall( - values['tp'], values['fp'], values['fn'], 'value') - update_op = compute_precision_at_recall( - update_ops['tp'], update_ops['fp'], update_ops['fn'], 'update_op') + precision_value = compute_precision_at_recall(values['tp'], values['fp'], + values['fn'], 'value') + update_op = compute_precision_at_recall(update_ops['tp'], update_ops['fp'], + update_ops['fn'], 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, precision_value) @@ -2793,27 +2787,26 @@ def streaming_sparse_average_precision_at_k(predictions, If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. Args: - predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where - N >= 1. Commonly, N=1 and `predictions` has shape - [batch size, num_classes]. The final dimension contains the logit values - for each class. [D1, ... DN] must match `labels`. - labels: `int64` `Tensor` or `SparseTensor` with shape - [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of - target classes for the associated prediction. Commonly, N=1 and `labels` - has shape [batch_size, num_labels]. [D1, ... DN] must match - `predictions_`. Values should be in range [0, num_classes), where - num_classes is the last dimension of `predictions`. Values outside this - range are ignored. + predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where N >= + 1. Commonly, N=1 and `predictions` has shape [batch size, num_classes]. + The final dimension contains the logit values for each class. [D1, ... DN] + must match `labels`. + labels: `int64` `Tensor` or `SparseTensor` with shape [D1, ... DN, + num_labels], where N >= 1 and num_labels is the number of target classes + for the associated prediction. Commonly, N=1 and `labels` has shape + [batch_size, num_labels]. [D1, ... DN] must match `predictions_`. Values + should be in range [0, num_classes), where num_classes is the last + dimension of `predictions`. 
Values outside this range are ignored. k: Integer, k for @k metric. This will calculate an average precision for range `[1,k]`, as documented above. weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of `labels`. If the latter, it must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that values should - be added to. - updates_collections: An optional list of collections that updates should - be added to. + metrics_collections: An optional list of collections that values should be + added to. + updates_collections: An optional list of collections that updates should be + added to. name: Name of new update operation, and namespace for other dependent ops. Returns: @@ -2859,22 +2852,22 @@ def streaming_sparse_average_precision_at_top_k(top_k_predictions, top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final dimension must be set and contains the top `k` predicted class indices. - [D1, ... DN] must match `labels`. Values should be in range - [0, num_classes). - labels: `int64` `Tensor` or `SparseTensor` with shape - [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies - num_labels=1. N >= 1 and num_labels is the number of target classes for - the associated prediction. Commonly, N=1 and `labels` has shape - [batch_size, num_labels]. [D1, ... DN] must match `top_k_predictions`. - Values should be in range [0, num_classes). + [D1, ... DN] must match `labels`. Values should be in range [0, + num_classes). + labels: `int64` `Tensor` or `SparseTensor` with shape [D1, ... DN, + num_labels] or [D1, ... DN], where the latter implies num_labels=1. N >= 1 + and num_labels is the number of target classes for the associated + prediction. Commonly, N=1 and `labels` has shape [batch_size, num_labels]. + [D1, ... DN] must match `top_k_predictions`. Values should be in range [0, + num_classes). weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of `labels`. If the latter, it must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that values should - be added to. - updates_collections: An optional list of collections that updates should - be added to. + metrics_collections: An optional list of collections that values should be + added to. + updates_collections: An optional list of collections that updates should be + added to. name: Name of new update operation, and namespace for other dependent ops. Returns: @@ -2927,8 +2920,8 @@ def streaming_mean_absolute_error(predictions, labels: A `Tensor` of the same shape as `predictions`. weights: Optional `Tensor` indicating the frequency with which an example is sampled. Rank must be 0, or the same rank as `labels`, and must be - broadcastable to `labels` (i.e., all dimensions must be either `1`, or - the same as the corresponding `labels` dimension). + broadcastable to `labels` (i.e., all dimensions must be either `1`, or the + same as the corresponding `labels` dimension). metrics_collections: An optional list of collections that `mean_absolute_error` should be added to. 
updates_collections: An optional list of collections that `update_op` should @@ -2987,8 +2980,8 @@ def streaming_mean_relative_error(predictions, normalizer: A `Tensor` of the same shape as `predictions`. weights: Optional `Tensor` indicating the frequency with which an example is sampled. Rank must be 0, or the same rank as `labels`, and must be - broadcastable to `labels` (i.e., all dimensions must be either `1`, or - the same as the corresponding `labels` dimension). + broadcastable to `labels` (i.e., all dimensions must be either `1`, or the + same as the corresponding `labels` dimension). metrics_collections: An optional list of collections that `mean_relative_error` should be added to. updates_collections: An optional list of collections that `update_op` should @@ -3016,6 +3009,7 @@ def streaming_mean_relative_error(predictions, updates_collections=updates_collections, name=name) + @deprecated(None, 'Please switch to tf.metrics.mean_squared_error. Note that the ' 'order of the labels and predictions arguments has been switched.') @@ -3048,8 +3042,8 @@ def streaming_mean_squared_error(predictions, labels: A `Tensor` of the same shape as `predictions`. weights: Optional `Tensor` indicating the frequency with which an example is sampled. Rank must be 0, or the same rank as `labels`, and must be - broadcastable to `labels` (i.e., all dimensions must be either `1`, or - the same as the corresponding `labels` dimension). + broadcastable to `labels` (i.e., all dimensions must be either `1`, or the + same as the corresponding `labels` dimension). metrics_collections: An optional list of collections that `mean_squared_error` should be added to. updates_collections: An optional list of collections that `update_op` should @@ -3076,9 +3070,9 @@ def streaming_mean_squared_error(predictions, updates_collections=updates_collections, name=name) + @deprecated( - None, - 'Please switch to tf.metrics.root_mean_squared_error. Note that the ' + None, 'Please switch to tf.metrics.root_mean_squared_error. Note that the ' 'order of the labels and predictions arguments has been switched.') def streaming_root_mean_squared_error(predictions, labels, @@ -3109,8 +3103,8 @@ def streaming_root_mean_squared_error(predictions, labels: A `Tensor` of the same shape as `predictions`. weights: Optional `Tensor` indicating the frequency with which an example is sampled. Rank must be 0, or the same rank as `labels`, and must be - broadcastable to `labels` (i.e., all dimensions must be either `1`, or - the same as the corresponding `labels` dimension). + broadcastable to `labels` (i.e., all dimensions must be either `1`, or the + same as the corresponding `labels` dimension). metrics_collections: An optional list of collections that `root_mean_squared_error` should be added to. updates_collections: An optional list of collections that `update_op` should @@ -3174,10 +3168,10 @@ def streaming_covariance(predictions, labels: A `Tensor` of the same size as `predictions`. weights: Optional `Tensor` indicating the frequency with which an example is sampled. Rank must be 0, or the same rank as `labels`, and must be - broadcastable to `labels` (i.e., all dimensions must be either `1`, or - the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. + broadcastable to `labels` (i.e., all dimensions must be either `1`, or the + same as the corresponding `labels` dimension). 
+ metrics_collections: An optional list of collections that the metric value + variable should be added to. updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -3197,15 +3191,18 @@ def streaming_covariance(predictions, predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') - mean_prediction = metrics_impl.metric_variable( - [], dtypes.float32, name='mean_prediction') - mean_label = metrics_impl.metric_variable( - [], dtypes.float32, name='mean_label') + mean_prediction = metrics_impl.metric_variable([], + dtypes.float32, + name='mean_prediction') + mean_label = metrics_impl.metric_variable([], + dtypes.float32, + name='mean_label') comoment = metrics_impl.metric_variable( # C_A in update equation [], dtypes.float32, name='comoment') if weights is None: - batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn + batch_count = math_ops.cast(array_ops.size(labels), + dtypes.float32) # n_B in eqn weighted_predictions = predictions weighted_labels = labels else: @@ -3243,8 +3240,8 @@ def streaming_covariance(predictions, if weights is None: batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) else: - batch_comoment = math_ops.reduce_sum( - unweighted_batch_coresiduals * weights) + batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * + weights) # View delta_comoment as = C_AB - C_A in the update equation above. # Since C_A is stored in a var, by how much do we need to increment that var @@ -3307,10 +3304,10 @@ def streaming_pearson_correlation(predictions, labels: A `Tensor` of the same size as predictions. weights: Optional `Tensor` indicating the frequency with which an example is sampled. Rank must be 0, or the same rank as `labels`, and must be - broadcastable to `labels` (i.e., all dimensions must be either `1`, or - the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. + broadcastable to `labels` (i.e., all dimensions must be either `1`, or the + same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric value + variable should be added to. updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -3392,8 +3389,8 @@ def streaming_mean_cosine_distance(predictions, dim: The dimension along which the cosine distance is computed. weights: An optional `Tensor` whose shape is broadcastable to `predictions`, and whose dimension `dim` is 1. - metrics_collections: An optional list of collections that the metric - value variable should be added to. + metrics_collections: An optional list of collections that the metric value + variable should be added to. updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -3456,8 +3453,8 @@ def streaming_percentage_less(values, values: A numeric `Tensor` of arbitrary size. threshold: A scalar threshold. weights: An optional `Tensor` whose shape is broadcastable to `values`. - metrics_collections: An optional list of collections that the metric - value variable should be added to. + metrics_collections: An optional list of collections that the metric value + variable should be added to. 
updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -3510,12 +3507,12 @@ def streaming_mean_iou(predictions, flattened, if its rank > 1. labels: A `Tensor` of ground truth labels with shape [batch size] and of type `int32` or `int64`. The tensor will be flattened, if its rank > 1. - num_classes: The possible number of labels the prediction task can - have. This value must be provided, since a confusion matrix of - dimension = [num_classes, num_classes] will be allocated. + num_classes: The possible number of labels the prediction task can have. + This value must be provided, since a confusion matrix of dimension = + [num_classes, num_classes] will be allocated. weights: An optional `Tensor` whose shape is broadcastable to `predictions`. - metrics_collections: An optional list of collections that `mean_iou` - should be added to. + metrics_collections: An optional list of collections that `mean_iou` should + be added to. updates_collections: An optional list of collections `update_op` should be added to. name: An optional variable_scope name. @@ -3552,8 +3549,8 @@ def _next_array_size(required_size, growth_factor=1.5): tf.Tensor with dtype=int32 giving the next array size. """ exponent = math_ops.ceil( - math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log( - math_ops.cast(growth_factor, dtypes.float32))) + math_ops.log(math_ops.cast(required_size, dtypes.float32)) / + math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32) @@ -3584,8 +3581,8 @@ def streaming_concat(values, max_size: optional integer maximum size of `value` along the given axis. Once the maximum size is reached, further updates are no-ops. By default, there is no maximum size: the array is resized as necessary. - metrics_collections: An optional list of collections that `value` - should be added to. + metrics_collections: An optional list of collections that `value` should be + added to. updates_collections: An optional list of collections `update_op` should be added to. name: An optional variable_scope name. @@ -3652,8 +3649,9 @@ def streaming_concat(values, new_size = size + batch_size array_size = array_ops.shape_internal(array, optimize=False)[0] - maybe_reallocate_op = control_flow_ops.cond( - new_size > array_size, reallocate, control_flow_ops.no_op) + maybe_reallocate_op = control_flow_ops.cond(new_size > array_size, + reallocate, + control_flow_ops.no_op) with ops.control_dependencies([maybe_reallocate_op]): append_values_op = array[size:new_size].assign(batch_values) with ops.control_dependencies([append_values_op]): @@ -3727,7 +3725,7 @@ def count(values, name=None): """Computes the number of examples, or sum of `weights`. - This metric keeps track of the denominator in `tf.metrics.mean`. + This metric keeps track of the denominator in `tf.compat.v1.metrics.mean`. When evaluating some metric (e.g. mean) on one or more subsets of the data, this auxiliary metric is useful for keeping track of how many examples there are in each subset. @@ -3737,11 +3735,10 @@ def count(values, Args: values: A `Tensor` of arbitrary dimensions. Only it's shape is used. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions - must be either `1`, or the same as the corresponding `labels` - dimension). 
- metrics_collections: An optional list of collections that the metric - value variable should be added to. + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric value + variable should be added to. updates_collections: An optional list of collections that the metric update ops should be added to. name: An optional variable_scope name. @@ -3765,15 +3762,15 @@ def count(values, count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') if weights is None: - num_values = math_ops.to_float(array_ops.size(values)) + num_values = math_ops.cast(array_ops.size(values), dtypes.float32) else: - values = math_ops.to_float(values) + values = math_ops.cast(values, dtypes.float32) values, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=values, labels=None, weights=weights) weights = weights_broadcast_ops.broadcast_weights( - math_ops.to_float(weights), values) + math_ops.cast(weights, dtypes.float32), values) num_values = math_ops.reduce_sum(weights) with ops.control_dependencies([values]): @@ -3828,8 +3825,8 @@ def cohen_kappa(labels, classification. Must have the same type as `labels`. num_classes: The possible number of labels. weights: Optional `Tensor` whose shape matches `predictions`. - metrics_collections: An optional list of collections that `kappa` should - be added to. + metrics_collections: An optional list of collections that `kappa` should be + added to. updates_collections: An optional list of collections that `update_op` should be added to. name: An optional variable_scope name. @@ -3869,10 +3866,12 @@ def cohen_kappa(labels, dtypes.int64 if weights is None or weights.dtype.is_integer else dtypes.float32) po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po') - pe_row = metrics_impl.metric_variable( - (num_classes,), stat_dtype, name='pe_row') - pe_col = metrics_impl.metric_variable( - (num_classes,), stat_dtype, name='pe_col') + pe_row = metrics_impl.metric_variable((num_classes,), + stat_dtype, + name='pe_row') + pe_col = metrics_impl.metric_variable((num_classes,), + stat_dtype, + name='pe_col') # Table of the counts of agreement: counts_in_table = confusion_matrix.confusion_matrix( @@ -3895,10 +3894,11 @@ def cohen_kappa(labels, total = math_ops.reduce_sum(pe_row) pe_sum = math_ops.reduce_sum( math_ops.div_no_nan( - math_ops.to_double(pe_row * pe_col), math_ops.to_double(total))) - po_sum, pe_sum, total = (math_ops.to_double(po_sum), - math_ops.to_double(pe_sum), - math_ops.to_double(total)) + math_ops.cast(pe_row * pe_col, dtypes.float64), + math_ops.cast(total, dtypes.float64))) + po_sum, pe_sum, total = (math_ops.cast(po_sum, dtypes.float64), + math_ops.cast(pe_sum, dtypes.float64), + math_ops.cast(total, dtypes.float64)) # kappa = (po - pe) / (N - pe) k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access po_sum - pe_sum, diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index fc64f343ab4..aec07241e7a 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -5810,9 +5810,10 @@ class StreamingCovarianceTest(test.TestCase): def testVars(self): metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + + 
predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10])) + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10]))) _assert_metric_variables(self, ( 'covariance/comoment:0', 'covariance/count:0', @@ -5823,18 +5824,20 @@ class StreamingCovarianceTest(test.TestCase): def testMetricsCollection(self): my_collection_name = '__metrics__' cov, _ = metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10])), metrics_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [cov]) def testUpdatesCollection(self): my_collection_name = '__updates__' _, update_op = metrics.streaming_covariance( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10])), updates_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) @@ -5857,8 +5860,8 @@ class StreamingCovarianceTest(test.TestCase): def testSingleUpdateIdentical(self): with self.cached_session() as sess: - predictions = math_ops.to_float(math_ops.range(10)) - labels = math_ops.to_float(math_ops.range(10)) + predictions = math_ops.cast(math_ops.range(10), dtypes_lib.float32) + labels = math_ops.cast(math_ops.range(10), dtypes_lib.float32) cov, update_op = metrics.streaming_covariance(predictions, labels) @@ -5982,9 +5985,10 @@ class StreamingPearsonRTest(test.TestCase): def testVars(self): metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10])) + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10]))) _assert_metric_variables(self, ( 'pearson_r/covariance/comoment:0', 'pearson_r/covariance/count:0', @@ -6003,18 +6007,20 @@ class StreamingPearsonRTest(test.TestCase): def testMetricsCollection(self): my_collection_name = '__metrics__' pearson_r, _ = metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 10])), metrics_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [pearson_r]) def testUpdatesCollection(self): my_collection_name = '__updates__' _, update_op = metrics.streaming_pearson_correlation( - predictions=math_ops.to_float(math_ops.range(10)) + + predictions=math_ops.cast(math_ops.range(10), dtypes_lib.float32) + array_ops.ones([10, 10]), - labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]), + labels=(math_ops.cast(math_ops.range(10), dtypes_lib.float32) + + array_ops.ones([10, 
10])), updates_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) @@ -6038,8 +6044,8 @@ class StreamingPearsonRTest(test.TestCase): def testSingleUpdateIdentical(self): with self.cached_session() as sess: - predictions = math_ops.to_float(math_ops.range(10)) - labels = math_ops.to_float(math_ops.range(10)) + predictions = math_ops.cast(math_ops.range(10), dtypes_lib.float32) + labels = math_ops.cast(math_ops.range(10), dtypes_lib.float32) pearson_r, update_op = metrics.streaming_pearson_correlation( predictions, labels) diff --git a/tensorflow/contrib/mixed_precision/python/BUILD b/tensorflow/contrib/mixed_precision/python/BUILD index 1d769e16141..39821399fc9 100644 --- a/tensorflow/contrib/mixed_precision/python/BUILD +++ b/tensorflow/contrib/mixed_precision/python/BUILD @@ -28,6 +28,7 @@ py_test( name = "loss_scale_manager_test", size = "small", srcs = ["loss_scale_manager_test.py"], + python_version = "PY2", deps = [ ":loss_scale_manager", "//tensorflow/python:client_testlib", @@ -62,6 +63,7 @@ py_test( name = "loss_scale_optimizer_test", size = "small", srcs = ["loss_scale_optimizer_test.py"], + python_version = "PY2", deps = [ ":loss_scale_optimizer", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py index a5621b44cd3..86306050560 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -104,8 +104,8 @@ class LossScaleOptimizer(optimizer.Optimizer): Args: opt: The actual optimizer that will be used to compute and apply the - gradients. Must be an implementation of the `tf.train.Optimizer` - interface. + gradients. Must be an implementation of the + `tf.compat.v1.train.Optimizer` interface. loss_scale_manager: A LossScaleManager object. """ self._opt = opt @@ -118,7 +118,7 @@ class LossScaleOptimizer(optimizer.Optimizer): aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None): - """Compute gradients. See base class `tf.train.Optimizer`.""" + """Compute gradients. See base class `tf.compat.v1.train.Optimizer`.""" loss_scale = self._loss_scale_manager.get_loss_scale() if context.executing_eagerly(): @@ -142,7 +142,7 @@ class LossScaleOptimizer(optimizer.Optimizer): return self._down_scale(grads_and_vars, loss_scale) def apply_gradients(self, grads_and_vars, global_step=None, name=None): - """Apply gradients. See base class `tf.train.Optimizer`.""" + """Apply gradients. See base class `tf.compat.v1.train.Optimizer`.""" grads = [g for (g, _) in grads_and_vars] is_finite_grad = [] @@ -154,8 +154,9 @@ class LossScaleOptimizer(optimizer.Optimizer): def true_apply_gradients_fn(): return self._opt.apply_gradients(grads_and_vars, global_step, name) - update_vars = control_flow_ops.cond( - is_overall_finite, true_apply_gradients_fn, gen_control_flow_ops.no_op) + update_vars = control_flow_ops.cond(is_overall_finite, + true_apply_gradients_fn, + gen_control_flow_ops.no_op) # Potentially adjust gradient scale in case of finite gradients. 
return control_flow_ops.group( update_vars, diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD index 3cffd76a255..ce77143e0c3 100644 --- a/tensorflow/contrib/model_pruning/BUILD +++ b/tensorflow/contrib/model_pruning/BUILD @@ -46,6 +46,7 @@ py_test( name = "layers_test", size = "small", srcs = ["python/layers/layers_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":layers", @@ -115,6 +116,7 @@ py_test( name = "pruning_utils_test", size = "medium", srcs = ["python/pruning_utils_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":pruning_utils", @@ -127,6 +129,7 @@ py_test( name = "pruning_test", size = "small", srcs = ["python/pruning_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":pruning", @@ -138,6 +141,7 @@ py_test( name = "rnn_cells_test", size = "small", srcs = ["python/layers/rnn_cells_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":pruning", @@ -150,7 +154,11 @@ py_test( name = "strip_pruning_vars_test", size = "small", srcs = ["python/strip_pruning_vars_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", + tags = [ + "no_oss", # b/132443370 + ], deps = [ ":layers", ":pruning", @@ -163,6 +171,7 @@ py_test( py_binary( name = "strip_pruning_vars", srcs = ["python/strip_pruning_vars.py"], + python_version = "PY2", srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 710a262f338..98760ea7050 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -65,7 +65,7 @@ The pruning library allows for specification of the following hyper parameters: The sparsity $$s_t$$ at global step $$t$$ is given by: -$$ s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3} $$ +$$s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3}$$ The interval between sparsity_function_begin_step and sparsity_function_end_step is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta @@ -133,9 +133,10 @@ For now, it is assumed that the underlying hardware platform will provide mechan ## Example: Pruning and training deep CNNs on the cifar10 dataset -Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural -network architecture, setting up inputs etc. The additional changes needed to -incorporate pruning are captured in the following: +Please see +[Advanced Convolutional Neural Networks](https://www.tensorflow.org/tutorials/images/deep_cnn) +for details on neural network architecture, setting up inputs etc. 
The +additional changes needed to incorporate pruning are captured in the following: * [cifar10_pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py) creates a deep CNN with the same architecture, but adds mask and threshold diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD index 30ea9122229..805a6eab236 100644 --- a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD +++ b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD @@ -48,6 +48,7 @@ py_binary( srcs = [ "cifar10_eval.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":cifar10_pruning", @@ -61,6 +62,7 @@ py_binary( srcs = [ "cifar10_train.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":cifar10_pruning", diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py index 660f0168b10..96303f3984c 100644 --- a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py @@ -75,6 +75,7 @@ def _activation_summary(x): Args: x: Tensor + Returns: nothing """ @@ -112,16 +113,15 @@ def _variable_with_weight_decay(name, shape, stddev, wd): name: name of the variable shape: list of ints stddev: standard deviation of a truncated Gaussian - wd: add L2Loss weight decay multiplied by this float. If None, weight - decay is not added for this Variable. + wd: add L2Loss weight decay multiplied by this float. If None, weight decay + is not added for this Variable. Returns: Variable Tensor """ dtype = tf.float32 - var = _variable_on_cpu(name, shape, - tf.truncated_normal_initializer( - stddev=stddev, dtype=dtype)) + var = _variable_on_cpu( + name, shape, tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) if wd is not None: weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') tf.add_to_collection('losses', weight_decay) @@ -176,10 +176,10 @@ def inference(images): Returns: Logits. """ - # We instantiate all variables using tf.get_variable() instead of + # We instantiate all variables using tf.compat.v1.get_variable() instead of # tf.Variable() in order to share variables across multiple GPU training runs. # If we only ran this model on a single GPU, we could simplify this function - # by replacing all instances of tf.get_variable() with tf.Variable(). + # by replacing all instances of tf.compat.v1.get_variable() with tf.Variable(). # # While instantiating conv and local layers, we add mask and threshold # variables to the layer by calling the pruning.apply_mask() function. @@ -276,8 +276,8 @@ def loss(logits, labels): Add summary for "Loss" and "Loss/avg". Args: logits: Logits from inference(). - labels: Labels from distorted_inputs or inputs(). 1-D tensor - of shape [batch_size] + labels: Labels from distorted_inputs or inputs(). 1-D tensor of shape + [batch_size] Returns: Loss tensor of type float. @@ -302,6 +302,7 @@ def _add_loss_summaries(total_loss): Args: total_loss: Total loss from loss(). + Returns: loss_averages_op: op for generating moving averages of losses. """ @@ -331,6 +332,7 @@ def train(total_loss, global_step): total_loss: Total loss from loss(). global_step: Integer Variable counting the number of training steps processed. + Returns: train_op: op for training. 
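As context for the `_variable_with_weight_decay` cleanup above: the helper follows the usual TF1 pattern of creating the variable and, when `wd` is set, registering `wd * l2_loss(var)` in the `'losses'` collection so it can later be summed into the total loss. A self-contained sketch of that pattern, assuming TF 1.x graph mode (names are placeholders, not the tutorial's code):

```python
import tensorflow as tf  # assumes TF 1.x

def variable_with_l2(name, shape, stddev, wd):
    """Create a variable; if wd is not None, record an L2 penalty for it."""
    var = tf.get_variable(
        name, shape,
        initializer=tf.truncated_normal_initializer(stddev=stddev))
    if wd is not None:
        tf.add_to_collection(
            'losses', tf.multiply(tf.nn.l2_loss(var), wd, name=name + '_wd'))
    return var

# Later: total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
```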
""" @@ -388,9 +390,9 @@ def maybe_download_and_extract(): if not os.path.exists(filepath): def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % - (filename, - float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.write( + '\r>> Downloading %s %.1f%%' % + (filename, float(count * block_size) / float(total_size) * 100.0)) sys.stdout.flush() filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py index 2959019d6d8..abbc7d01e2d 100644 --- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py +++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py @@ -48,7 +48,7 @@ class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell): It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. - For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` + For advanced models, please use the full `tf.compat.v1.nn.rnn_cell.LSTMCell` that follows. """ diff --git a/tensorflow/contrib/model_pruning/python/learning.py b/tensorflow/contrib/model_pruning/python/learning.py index 26695237c27..ca34de61a7c 100644 --- a/tensorflow/contrib/model_pruning/python/learning.py +++ b/tensorflow/contrib/model_pruning/python/learning.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Wrapper around tf-slim's training code contrib/slim/python/slim/learning.py + to support training of pruned models ******************************************************************* @@ -28,7 +29,8 @@ to support training of pruned models total_loss = slim.losses.get_total_loss() # Define the optimizer: - optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) + optimizer = tf.compat.v1.train.MomentumOptimizer(FLAGS.learning_rate, + FLAGS.momentum) # Create the train_op train_op = slim.learning.create_train_op(total_loss, optimizer) @@ -98,7 +100,8 @@ def train(train_op, thresholds. train_step_fn: The function to call in order to execute a single gradient step. The function must have take exactly four arguments: the current - session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary. + session, the `train_op` `Tensor`, a global step `Tensor` and a + dictionary. train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By default, two `Boolean`, scalar ops called "should_stop" and "should_log" are provided. @@ -112,35 +115,36 @@ def train(train_op, global_step: The `Tensor` representing the global step. If left as `None`, then slim.variables.get_or_create_global_step() is used. number_of_steps: The max number of gradient steps to take during training, - as measured by 'global_step': training will stop if global_step is - greater than 'number_of_steps'. If the value is left as None, training - proceeds indefinitely. + as measured by 'global_step': training will stop if global_step is greater + than 'number_of_steps'. If the value is left as None, training proceeds + indefinitely. init_op: The initialization operation. If left to its default value, then - the session is initialized by calling `tf.global_variables_initializer()`. + the session is initialized by calling + `tf.compat.v1.global_variables_initializer()`. init_feed_dict: A feed dictionary to use when executing the `init_op`. local_init_op: The local initialization operation. 
If left to its default value, then the session is initialized by calling - `tf.local_variables_initializer()` and `tf.tables_initializer()`. + `tf.compat.v1.local_variables_initializer()` and + `tf.compat.v1.tables_initializer()`. init_fn: An optional callable to be executed after `init_op` is called. The callable must accept one argument, the session being initialized. ready_op: Operation to check if the model is ready to use. If left to its default value, then the session checks for readiness by calling - `tf.report_uninitialized_variables()`. + `tf.compat.v1.report_uninitialized_variables()`. summary_op: The summary operation. save_summaries_secs: How often, in seconds, to save summaries. - summary_writer: `SummaryWriter` to use. Can be `None` - to indicate that no summaries should be written. If unset, we - create a SummaryWriter. + summary_writer: `SummaryWriter` to use. Can be `None` to indicate that no + summaries should be written. If unset, we create a SummaryWriter. startup_delay_steps: The number of steps to wait for before beginning. Note that this must be 0 if a sync_optimizer is supplied. - saver: Saver to save checkpoints. If None, a default one will be created - and used. + saver: Saver to save checkpoints. If None, a default one will be created and + used. save_interval_secs: How often, in seconds, to save the model to `logdir`. - sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of - them. If the argument is supplied, gradient updates will be synchronous. - If left as `None`, gradient updates will be asynchronous. - session_config: An instance of `tf.ConfigProto` that will be used to - configure the `Session`. If left as `None`, the default will be used. + sync_optimizer: an instance of tf.compat.v1.train.SyncReplicasOptimizer, or + a list of them. If the argument is supplied, gradient updates will be + synchronous. If left as `None`, gradient updates will be asynchronous. + session_config: An instance of `tf.compat.v1.ConfigProto` that will be used + to configure the `Session`. If left as `None`, the default will be used. trace_every_n_steps: produce and save a `Timeline` in Chrome trace format and add it to the summaries every `trace_every_n_steps`. If None, no trace information will be produced or saved. diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc index b9967fe76dc..c2e1edb1366 100644 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc +++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/tensor_coding.h" +#include "tensorflow/core/framework/allocator.h" namespace tensorflow { @@ -122,7 +123,7 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync( } else { TensorResponse tr; tr.InitAlloc(dst_device, recv_args.alloc_attrs); - tr.InitPartial(mpi_response.response()); + tr.InitPartial(mpi_response.response(), AllocationAttributes()); const size_t nBytes = tr.tensor().TotalBytes(); void* data = const_cast(DMAHelper::base(&tr.tensor())); MPI_Status status; diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc index ca3ddfa721d..c8e3e81c8ba 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc @@ -20,6 +20,7 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "tensorflow/contrib/mpi_collectives/kernels/ring.h" +#include "tensorflow/core/util/gpu_launch_config.h" namespace tensorflow { namespace contrib { diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD index ef7ab226465..e3e36c4fdf5 100644 --- a/tensorflow/contrib/nn/BUILD +++ b/tensorflow/contrib/nn/BUILD @@ -44,6 +44,7 @@ py_test( name = "alpha_dropout_test", size = "small", srcs = ["python/ops/alpha_dropout_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":nn_py", @@ -61,6 +62,7 @@ py_test( name = "fwd_gradients_test", size = "small", srcs = ["python/ops/fwd_gradients_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":nn_py", @@ -74,6 +76,7 @@ py_test( name = "sampling_ops_test", size = "small", srcs = ["python/ops/sampling_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":nn_py", @@ -89,6 +92,7 @@ py_test( name = "scaled_softplus_test", size = "small", srcs = ["python/ops/scaled_softplus_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":nn_py", diff --git a/tensorflow/contrib/nn/python/ops/alpha_dropout.py b/tensorflow/contrib/nn/python/ops/alpha_dropout.py index 98f4264fe08..2b64a78c223 100644 --- a/tensorflow/contrib/nn/python/ops/alpha_dropout.py +++ b/tensorflow/contrib/nn/python/ops/alpha_dropout.py @@ -43,7 +43,7 @@ def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylin noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for randomly generated keep/drop flags. seed: A Python integer. Used to create random seeds. See - `tf.set_random_seed` for behavior. + `tf.compat.v1.set_random_seed` for behavior. name: A name for this operation (optional). Returns: diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index f30643cf305..6c85533d774 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -65,6 +65,7 @@ py_library( py_test( name = "adam_gs_optimizer_test", srcs = ["python/training/adam_gs_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -80,6 +81,7 @@ py_test( py_test( name = "adamax_test", srcs = ["python/training/adamax_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -95,6 +97,7 @@ py_test( py_test( name = "external_optimizer_test", srcs = ["python/training/external_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no-internal-py3", @@ -115,6 +118,7 @@ py_test( py_test( name = "moving_average_optimizer_test", srcs = ["python/training/moving_average_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "notsan", # b/31055119 @@ -151,6 +155,7 @@ tf_py_test( py_test( name = "multitask_optimizer_wrapper_test", srcs = ["python/training/multitask_optimizer_wrapper_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -168,6 +173,7 @@ py_test( py_test( name = "lazy_adam_gs_optimizer_test", srcs = ["python/training/lazy_adam_gs_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -187,6 +193,7 @@ py_test( py_test( name = "lazy_adam_optimizer_test", srcs = ["python/training/lazy_adam_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -206,6 +213,7 @@ py_test( py_test( name = "reg_adagrad_optimizer_test", srcs = ["python/training/reg_adagrad_optimizer_test.py"], + python_version = "PY2", 
srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -225,6 +233,7 @@ py_test( py_test( name = "nadam_optimizer_test", srcs = ["python/training/nadam_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -244,6 +253,7 @@ py_test( py_test( name = "weight_decay_optimizers_test", srcs = ["python/training/weight_decay_optimizers_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -348,6 +358,7 @@ tf_py_test( py_test( name = "sign_decay_test", srcs = ["python/training/sign_decay_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -358,6 +369,7 @@ py_test( py_test( name = "addsign_test", srcs = ["python/training/addsign_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -377,6 +389,7 @@ py_test( py_test( name = "powersign_test", srcs = ["python/training/powersign_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -396,6 +409,7 @@ py_test( py_test( name = "ggt_test", srcs = ["python/training/ggt_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -415,6 +429,7 @@ py_test( name = "shampoo_test", size = "medium", srcs = ["python/training/shampoo_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", deps = [ @@ -435,6 +450,7 @@ py_test( py_test( name = "lars_optimizer_test", srcs = ["python/training/lars_optimizer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", @@ -450,6 +466,7 @@ py_test( py_test( name = "matrix_functions_test", srcs = ["python/training/matrix_functions_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":opt_py", diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index fa1a7aaff0a..d24de9efeec 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -39,6 +39,7 @@ GLOBAL_STEP = 'global_step' class ElasticAverageCustomGetter(object): """Custom_getter class is used to do: + 1. Change trainable variables to local collection and place them at worker device 2. Generate global variables(global center variables) @@ -46,22 +47,23 @@ class ElasticAverageCustomGetter(object): variables and place them at worker device Notice that the class should be used with tf.replica_device_setter, so that the global center variables and global step variable can be placed - at ps device. Besides, use 'tf.get_variable' instead of 'tf.Variable' to + at ps device. Besides, use 'tf.compat.v1.get_variable' instead of + 'tf.Variable' to use this custom getter. For example, ea_custom_getter = ElasticAverageCustomGetter(worker_device) with tf.device( - tf.train.replica_device_setter( + tf.compat.v1.train.replica_device_setter( worker_device=worker_device, ps_device="/job:ps", cluster=cluster)), - tf.variable_scope('',custom_getter=ea_custom_getter): + tf.compat.v1.variable_scope('',custom_getter=ea_custom_getter): ... create your model here ... with tf.device(worker_device): - opt = tf.train.MomentumOptimizer(...) + opt = tf.compat.v1.train.MomentumOptimizer(...) optimizer = ElasticAverageOptimizer( opt, num_worker=2, @@ -75,7 +77,7 @@ class ElasticAverageCustomGetter(object): ... hooks = [optimizer.make_session_run_hook(is_chief, task_index)] ... 
- with tf.train.MonitoredTrainingSession(master=server.target, + with tf.compat.v1.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief, checkpoint_dir=("...), save_checkpoint_secs=600, @@ -138,9 +140,9 @@ class ElasticAverageCustomGetter(object): return getter(name, *args, **kwargs) - class ElasticAverageOptimizer(optimizer.Optimizer): """Wrapper optimizer that implements the Elastic Average SGD algorithm. + This is an async optimizer. During the training, Each worker will update the local variables and maintains its own local_step, which starts from 0 and is incremented by 1 after each update of local variables. Whenever @@ -170,19 +172,18 @@ class ElasticAverageOptimizer(optimizer.Optimizer): Must be one of the Optimizer classes. num_worker: The number of workers ea_custom_getter: The ElasticAverageCustomGetter - communication_period: An int point value to controls the frequency - of the communication between every worker and the ps. + communication_period: An int point value to controls the frequency of the + communication between every worker and the ps. moving_rate: A floating point value to control the elastic difference. - rho: the amount of exploration we allow in the model. The default - value is moving_rate/learning_rate - rho=0.0 is suggested in async mode. + rho: the amount of exploration we allow in the model. The default value is + moving_rate/learning_rate rho=0.0 is suggested in async mode. use_locking: If True use locks for update operations. synchronous: Add_sync_queues_and_barrier or not. True: all workers will wait for each other before start training False: worker can start training when its initilization is done, - no need to wait for everyone is ready. - in case one worker is restarted, it can join and continue - training without being blocked. + no need to wait for everyone is ready. in case one worker is + restarted, it can join and continue training without being + blocked. name: Optional name prefix for the operations created when applying gradients. Defaults to "ElasticAverageOptimizer". """ @@ -229,14 +230,14 @@ class ElasticAverageOptimizer(optimizer.Optimizer): Args: loss: A Tensor containing the value to minimize. var_list: Optional list or tuple of `tf.Variable` to update to minimize - `loss`. Defaults to the list of variables collected in the graph - under the key `GraphKey.TRAINABLE_VARIABLES`. + `loss`. Defaults to the list of variables collected in the graph under + the key `GraphKey.TRAINABLE_VARIABLES`. gate_gradients: How to gate the computation of gradients. Can be `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. - colocate_gradients_with_ops: If True, try colocating gradients with - the corresponding op. + colocate_gradients_with_ops: If True, try colocating gradients with the + corresponding op. grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. Returns: @@ -273,10 +274,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer): Args: grads_and_vars: List of (gradient, variable) pairs as returned by `compute_gradients()`. - global_step: Optional `Variable` to increment by one after the - variables have been updated. - name: Optional name for the returned operation. Default to the - name passed to the `Optimizer` constructor. + global_step: Optional `Variable` to increment by one after the variables + have been updated. + name: Optional name for the returned operation. 
Default to the name + passed to the `Optimizer` constructor. Returns: An `Operation` that applies the specified gradients. If `global_step` @@ -344,13 +345,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer): with ops.control_dependencies([local_update]): condition = math_ops.equal( math_ops.mod(self._local_step, self._period), 0) - conditional_update = control_flow_ops.cond( - condition, _Update_global_variables, control_flow_ops.no_op) + conditional_update = control_flow_ops.cond(condition, + _Update_global_variables, + control_flow_ops.no_op) return conditional_update def get_init_op(self, task_index): """Returns the op to let all the local variables and local center - variables equal to the global center variables before the training begins""" + + variables equal to the global center variables before the training begins + """ def _Add_sync_queues_and_barrier(enqueue_after_list): """Adds ops to enqueue on all worker queues""" @@ -399,18 +403,20 @@ class ElasticAverageOptimizer(optimizer.Optimizer): def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs): """Create a saver copy global_center_variable to trainable variables + Please call this function after all your variables created with ElasticAverageCustomGetter. For evaluations or inference, use this saver during training. It will save the global_center_variable of the trained parameters under the original parameter names. Args: - var_list: List of variables to save, as per `Saver()`. - If set to None, save all the trainable_variables that have - been created before this call. + var_list: List of variables to save, as per `Saver()`. If set to None, + save all the trainable_variables that have been created before this + call. name: The name of the saver. **kwargs: Keyword arguments of `Saver()`. + Returns: - A `tf.train.Saver` object. + A `tf.compat.v1.train.Saver` object. Raises: RuntimeError: global_center_variable is empty, please make sure this is called after model created and @@ -436,13 +442,14 @@ class ElasticAverageOptimizer(optimizer.Optimizer): if tvar.op.name == var.op.name: tensor = self._global_map.get(tvar, var) break - else: #partitioned variable + else: #partitioned variable tensor = [self._global_map.get(lvar, lvar) for lvar in var] swapped_var_list[key] = tensor return saver.Saver(swapped_var_list, name=name, **kwargs) + class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): def __init__(self, ea_optimizer, is_chief, task_index): diff --git a/tensorflow/contrib/opt/python/training/external_optimizer.py b/tensorflow/contrib/opt/python/training/external_optimizer.py index e5e52f7dc3a..814f980c791 100644 --- a/tensorflow/contrib/opt/python/training/external_optimizer.py +++ b/tensorflow/contrib/opt/python/training/external_optimizer.py @@ -297,7 +297,7 @@ class ExternalOptimizerInterface(object): class ScipyOptimizerInterface(ExternalOptimizerInterface): - """Wrapper allowing `scipy.optimize.minimize` to operate a `tf.Session`. + """Wrapper allowing `scipy.optimize.minimize` to operate a `tf.compat.v1.Session`. Example: @@ -309,7 +309,7 @@ class ScipyOptimizerInterface(ExternalOptimizerInterface): optimizer = ScipyOptimizerInterface(loss, options={'maxiter': 100}) - with tf.Session() as session: + with tf.compat.v1.Session() as session: optimizer.minimize(session) # The value of vector should now be [0., 0.]. 
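For comparison with the docstring example just above: the wrapper ultimately hands the problem to SciPy, so the same toy minimization can be reproduced directly with `scipy.optimize.minimize`, which can be handy when checking what the interface should converge to (a sketch, not taken from the TensorFlow sources):

```python
import numpy as np
from scipy.optimize import minimize

# Same toy objective as above: minimize ||x||^2 starting from [7., 7.].
result = minimize(lambda x: np.sum(x ** 2), x0=np.array([7.0, 7.0]),
                  method='L-BFGS-B', options={'maxiter': 100})
print(result.x)  # approximately [0., 0.]
```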
@@ -326,7 +326,7 @@ class ScipyOptimizerInterface(ExternalOptimizerInterface): optimizer = ScipyOptimizerInterface( loss, var_to_bounds={vector: ([1, 2], np.infty)}) - with tf.Session() as session: + with tf.compat.v1.Session() as session: optimizer.minimize(session) # The value of vector should now be [1., 2.]. @@ -349,7 +349,7 @@ class ScipyOptimizerInterface(ExternalOptimizerInterface): optimizer = ScipyOptimizerInterface( loss, equalities=equalities, inequalities=inequalities, method='SLSQP') - with tf.Session() as session: + with tf.compat.v1.Session() as session: optimizer.minimize(session) # The value of vector should now be [1., 1.]. diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py index 6dc17fe5a52..df0cb2b0071 100644 --- a/tensorflow/contrib/opt/python/training/ggt.py +++ b/tensorflow/contrib/opt/python/training/ggt.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import numpy as np from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -224,7 +225,7 @@ class GGTOptimizer(optimizer_v2.OptimizerV2): window = state.get_hyper("window") grad_buffer = self._get_grad_buffer(state) next_grad_index = math_ops.floormod( - math_ops.to_int32(update_global_step - 1.), window) + math_ops.cast(update_global_step - 1., dtypes.int32), window) # grad_buffer[(t-1) % window] := moment1_t update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index, update_moment1) diff --git a/tensorflow/contrib/opt/python/training/matrix_functions.py b/tensorflow/contrib/opt/python/training/matrix_functions.py index baab5776386..1c5d2fe1787 100644 --- a/tensorflow/contrib/opt/python/training/matrix_functions.py +++ b/tensorflow/contrib/opt/python/training/matrix_functions.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -57,7 +58,7 @@ def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4): current_err = math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / norm return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err - identity = linalg_ops.eye(math_ops.to_int32(mat_a_size)) + identity = linalg_ops.eye(math_ops.cast(mat_a_size, dtypes.int32)) mat_a = mat_a + ridge_epsilon * identity norm = math_ops.sqrt(math_ops.reduce_sum(mat_a * mat_a)) mat_init_y = mat_a / norm @@ -100,7 +101,7 @@ def matrix_inverse_pth_root(mat_g, mat_g^alpha """ - identity = linalg_ops.eye(math_ops.to_int32(mat_g_size)) + identity = linalg_ops.eye(math_ops.cast(mat_g_size, dtypes.int32)) def mat_power(mat_m, p): """Computes mat_m^p, for p a positive integer. diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py index 746df77ba2c..669d83a3a49 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -41,23 +41,25 @@ class ModelAverageCustomGetter(object): 2. 
Generate global variables Notice that the class should be used with tf.replica_device_setter, so that the global center variables and global step variable can be placed - at ps device. Besides, use 'tf.get_variable' instead of 'tf.Variable' to + at ps device. Besides, use 'tf.compat.v1.get_variable' instead of + 'tf.Variable' to use this custom getter. For example, ma_custom_getter = ModelAverageCustomGetter(worker_device) with tf.device( - tf.train.replica_device_setter( + tf.compat.v1.train.replica_device_setter( worker_device=worker_device, ps_device="/job:ps/cpu:0", cluster=cluster)), - tf.variable_scope('',custom_getter=ma_custom_getter): - hid_w = tf.get_variable( - initializer=tf.truncated_normal( + tf.compat.v1.variable_scope('',custom_getter=ma_custom_getter): + hid_w = tf.compat.v1.get_variable( + initializer=tf.random.truncated_normal( [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w") - hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]), + hid_b = + tf.compat.v1.get_variable(initializer=tf.zeros([FLAGS.hidden_units]), name="hid_b") """ @@ -89,8 +91,8 @@ class ModelAverageCustomGetter(object): self._local_2_global[local_var] = global_variable return local_var else: - kwargs['trainable'] = trainable - kwargs['collections'] = collections + kwargs["trainable"] = trainable + kwargs["collections"] = collections if ops.GraphKeys.LOCAL_VARIABLES in collections: with ops.device(self._worker_device): return getter(name, *args, **kwargs) @@ -191,10 +193,10 @@ class ModelAverageOptimizer(optimizer.Optimizer): Args: grads_and_vars: List of (gradient, variable) pairs as returned by compute_gradients(). - global_step: Optional Variable to increment by one after the - variables have been updated. - name: Optional name for the returned operation. Default to the - name passed to the Optimizer constructor. + global_step: Optional Variable to increment by one after the variables + have been updated. + name: Optional name for the returned operation. Default to the name + passed to the Optimizer constructor. Returns: A conditional 'Operation' that update both local and global variables or @@ -268,8 +270,9 @@ class ModelAverageOptimizer(optimizer.Optimizer): with ops.control_dependencies([local_update]): condition = math_ops.equal( math_ops.mod(self._local_step, self._interval_steps), 0) - conditional_update = control_flow_ops.cond( - condition, _update_global_variables, control_flow_ops.no_op) + conditional_update = control_flow_ops.cond(condition, + _update_global_variables, + control_flow_ops.no_op) chief_init_ops = [] for accum, dev in self._accumulator_list: diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index bf3e5c51f78..8b5a49f8ec7 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -44,7 +44,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): // Encapsulate your favorite optimizer (here the momentum one) // inside the MovingAverageOptimizer. - opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum) + opt = tf.compat.v1.train.MomentumOptimizer(learning_rate, FLAGS.momentum) opt = tf.contrib.opt.MovingAverageOptimizer(opt) // Then create your model and all its variables. model = build_model() @@ -152,7 +152,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): **kwargs: Keyword arguments of `Saver()`. 
Returns: - A `tf.train.Saver` object. + A `tf.compat.v1.train.Saver` object. Raises: RuntimeError: If apply_gradients or minimize has not been called before. diff --git a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py index 9076cc9d128..30d96e16be0 100644 --- a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py +++ b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py @@ -78,7 +78,7 @@ class MultitaskOptimizerWrapper(object): Example: ```python - momentum_optimizer = tf.train.MomentumOptimizer( + momentum_optimizer = tf.compat.v1.train.MomentumOptimizer( learning_rate, momentum=0.9) multitask_momentum_optimizer = tf.contrib.opt.MultitaskOptimizerWrapper( momentum_optimizer) diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py index e542f46892a..efbafac662b 100644 --- a/tensorflow/contrib/opt/python/training/shampoo.py +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -24,6 +24,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.opt.python.training import matrix_functions +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -120,7 +121,7 @@ class ShampooOptimizer(optimizer.Optimizer): super(ShampooOptimizer, self).__init__(use_locking, name) - self._global_step = math_ops.to_float(global_step) + self._global_step = math_ops.cast(global_step, dtypes.float32) self._max_matrix_size = max_matrix_size self._gbar_decay = gbar_decay self._gbar_weight = gbar_weight @@ -246,7 +247,8 @@ class ShampooOptimizer(optimizer.Optimizer): if mat_g_size == 1: mat_h = math_ops.pow(mat_g + self._epsilon, alpha) else: - damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size)) + damping = self._epsilon * linalg_ops.eye( + math_ops.cast(mat_g_size, dtypes.int32)) diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True) mat_h = math_ops.matmul( mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha), diff --git a/tensorflow/contrib/opt/python/training/sign_decay.py b/tensorflow/contrib/opt/python/training/sign_decay.py index e8870c07211..99cd0f6e60e 100644 --- a/tensorflow/contrib/opt/python/training/sign_decay.py +++ b/tensorflow/contrib/opt/python/training/sign_decay.py @@ -23,7 +23,9 @@ from __future__ import division from __future__ import print_function import math + from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -51,10 +53,10 @@ def get_linear_decay_fn(decay_steps): if global_step is None: raise ValueError("global_step is required for linear_decay.") global_step = math_ops.minimum(global_step, decay_steps) - remaining_steps = math_ops.to_int32(decay_steps) - math_ops.to_int32( - global_step) - decayed = math_ops.to_float(remaining_steps) / math_ops.to_float( - decay_steps) + remaining_steps = math_ops.cast( + decay_steps, dtypes.int32) - math_ops.cast(global_step, dtypes.int32) + decayed = (math_ops.cast(remaining_steps, dtypes.float32) / + math_ops.cast(decay_steps, dtypes.float32)) return math_ops.maximum(0.0, decayed) # pylint:enable=missing-docstring return linear_decay_fn @@ -92,8 +94,8 @@ def get_cosine_decay_fn(decay_steps, num_periods=0.5, 
zero_after=None): if global_step is None: raise ValueError("global_step is required for cosine_decay.") global_step = math_ops.minimum(global_step, decay_steps) - completed_fraction = math_ops.to_float(global_step) / math_ops.to_float( - decay_steps) + completed_fraction = (math_ops.cast(global_step, dtypes.float32) / + math_ops.cast(decay_steps, dtypes.float32)) fraction = 2.0 * num_periods * completed_fraction decayed = 0.5 * ( 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) @@ -143,14 +145,14 @@ def get_restart_decay_fn(decay_steps, num_periods=1, zero_after=None): if global_step is None: raise ValueError("global_step is required for cosine_decay.") global_step = math_ops.minimum(global_step, decay_steps) - num = math_ops.mod(num_periods * math_ops.to_float(global_step), + num = math_ops.mod(num_periods * math_ops.cast(global_step, dtypes.float32), decay_steps) - fraction = num / math_ops.to_float(decay_steps) + fraction = num / math_ops.cast(decay_steps, dtypes.float32) decayed = 0.5 * ( 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) if zero_after is not None: - tmp = math_ops.to_float( - num_periods * global_step) / math_ops.to_float(decay_steps) + tmp = (math_ops.cast(num_periods * global_step, dtypes.float32) / + math_ops.cast(decay_steps, dtypes.float32)) decayed = array_ops.where( math_ops.greater_equal(tmp, zero_after), 0.0, decayed) return decayed diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 8b8065c678e..e2bcee51130 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Base class to make optimizers weight decay ready.""" from __future__ import absolute_import from __future__ import division @@ -59,12 +58,13 @@ class DecoupledWeightDecayExtension(object): Note that this extension decays weights BEFORE applying the update based on the gradient, i.e. this extension only has the desired behaviour for optimizers which do not depend on the value of'var' in the update step! - + Note: when applying a decay to the learning rate, be sure to manually apply the decay to the `weight_decay` as well. For example: ```python - schedule = tf.train.piecewise_constant(tf.train.get_global_step(), + schedule = + tf.compat.v1.train.piecewise_constant(tf.compat.v1.train.get_global_step(), [10000, 15000], [1e-0, 1e-1, 1e-2]) lr = 1e-1 * schedule() wd = lambda: 1e-4 * schedule() @@ -82,10 +82,9 @@ class DecoupledWeightDecayExtension(object): """Construct the extension class that adds weight decay to an optimizer. Args: - weight_decay: A `Tensor` or a floating point value, the factor by which - a variable is decayed in the update step. - **kwargs: Optional list or tuple or set of `Variable` objects to - decay. + weight_decay: A `Tensor` or a floating point value, the factor by which a + variable is decayed in the update step. + **kwargs: Optional list or tuple or set of `Variable` objects to decay. 
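The contrast the docstring draws with plain L2 regularization is easiest to see as an update rule: the weights are shrunk directly, before and independently of the gradient-based step, rather than through an extra loss term. A small sketch of one decoupled SGD step under that reading (illustrative only, not the contrib implementation; it works elementwise on NumPy arrays):

```python
def decoupled_sgd_step(w, grad, lr=0.1, weight_decay=1e-4):
    """Decay the weights first, then take an ordinary gradient step."""
    w = w * (1.0 - weight_decay)  # decoupled decay, independent of the loss
    return w - lr * grad

# With L2 regularization the decay would instead flow through the gradient:
#   w - lr * (grad + weight_decay * w)
```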
""" self._decay_var_list = None # is set in minimize or apply_gradients self._weight_decay = weight_decay @@ -93,10 +92,16 @@ class DecoupledWeightDecayExtension(object): self._weight_decay_tensor = None super(DecoupledWeightDecayExtension, self).__init__(**kwargs) - def minimize(self, loss, global_step=None, var_list=None, + def minimize(self, + loss, + global_step=None, + var_list=None, gate_gradients=optimizer.Optimizer.GATE_OP, - aggregation_method=None, colocate_gradients_with_ops=False, - name=None, grad_loss=None, decay_var_list=None): + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None, + decay_var_list=None): """Add operations to minimize `loss` by updating `var_list` with decay. This function is the same as Optimizer.minimize except that it allows to @@ -107,17 +112,17 @@ class DecoupledWeightDecayExtension(object): Args: loss: A `Tensor` containing the value to minimize. - global_step: Optional `Variable` to increment by one after the - variables have been updated. + global_step: Optional `Variable` to increment by one after the variables + have been updated. var_list: Optional list or tuple of `Variable` objects to update to - minimize `loss`. Defaults to the list of variables collected in - the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. + minimize `loss`. Defaults to the list of variables collected in the + graph under the key `GraphKeys.TRAINABLE_VARIABLES`. gate_gradients: How to gate the computation of gradients. Can be `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. - colocate_gradients_with_ops: If True, try colocating gradients with - the corresponding op. + colocate_gradients_with_ops: If True, try colocating gradients with the + corresponding op. name: Optional name for the returned operation. grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. decay_var_list: Optional list of decay variables. @@ -129,12 +134,19 @@ class DecoupledWeightDecayExtension(object): """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).minimize( - loss, global_step=global_step, var_list=var_list, - gate_gradients=gate_gradients, aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops, name=name, + loss, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, grad_loss=grad_loss) - def apply_gradients(self, grads_and_vars, global_step=None, name=None, + def apply_gradients(self, + grads_and_vars, + global_step=None, + name=None, decay_var_list=None): """Apply gradients to variables and decay the variables. @@ -148,10 +160,10 @@ class DecoupledWeightDecayExtension(object): Args: grads_and_vars: List of (gradient, variable) pairs as returned by `compute_gradients()`. - global_step: Optional `Variable` to increment by one after the - variables have been updated. - name: Optional name for the returned operation. Default to the - name passed to the `Optimizer` constructor. + global_step: Optional `Variable` to increment by one after the variables + have been updated. + name: Optional name for the returned operation. Default to the name + passed to the `Optimizer` constructor. decay_var_list: Optional list of decay variables. 
Returns: @@ -190,15 +202,14 @@ class DecoupledWeightDecayExtension(object): def _resource_apply_dense(self, grad, var): with ops.control_dependencies([self._decay_weights_op(var)]): - return super(DecoupledWeightDecayExtension, self)._resource_apply_dense( - grad, var) + return super(DecoupledWeightDecayExtension, + self)._resource_apply_dense(grad, var) def _apply_sparse(self, grad, var): scatter_add = state_ops.scatter_add decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add) with ops.control_dependencies([decay_op]): - return super(DecoupledWeightDecayExtension, self)._apply_sparse( - grad, var) + return super(DecoupledWeightDecayExtension, self)._apply_sparse(grad, var) def _resource_scatter_add(self, x, i, v, _=None): # last argument allows for one overflow argument, to have the same function @@ -211,8 +222,8 @@ class DecoupledWeightDecayExtension(object): scatter_add = self._resource_scatter_add decay_op = self._decay_weights_sparse_op(var, indices, scatter_add) with ops.control_dependencies([decay_op]): - return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse( - grad, var, indices) + return super(DecoupledWeightDecayExtension, + self)._resource_apply_sparse(grad, var, indices) def extend_with_decoupled_weight_decay(base_optimizer): @@ -221,7 +232,8 @@ def extend_with_decoupled_weight_decay(base_optimizer): Returns an optimizer class. An instance of the returned class computes the update step of `base_optimizer` and additionally decays the weights. E.g., the class returned by - `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to + `extend_with_decoupled_weight_decay(tf.compat.v1.train.AdamOptimizer)` is + equivalent to `tf.contrib.opt.AdamWOptimizer`. The API of the new optimizer class slightly differs from the API of the @@ -234,7 +246,7 @@ def extend_with_decoupled_weight_decay(base_optimizer): Usage example: ```python # MyAdamW is a new class - MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + MyAdamW = extend_with_decoupled_weight_decay(tf.compat.v1.train.AdamOptimizer) # Create a MyAdamW object optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) @@ -270,8 +282,8 @@ def extend_with_decoupled_weight_decay(base_optimizer): def __init__(self, weight_decay, *args, **kwargs): # super delegation is necessary here # pylint: disable=useless-super-delegation - super(OptimizerWithDecoupledWeightDecay, self).__init__( - weight_decay, *args, **kwargs) + super(OptimizerWithDecoupledWeightDecay, + self).__init__(weight_decay, *args, **kwargs) # pylint: enable=useless-super-delegation return OptimizerWithDecoupledWeightDecay @@ -296,13 +308,18 @@ class MomentumWOptimizer(DecoupledWeightDecayExtension, Note that this optimizer can also be instantiated as ```python - extend_with_weight_decay(tf.train.MomentumOptimizer, + extend_with_weight_decay(tf.compat.v1.train.MomentumOptimizer, weight_decay=weight_decay) ``` """ - def __init__(self, weight_decay, learning_rate, momentum, - use_locking=False, name="MomentumW", use_nesterov=False): + def __init__(self, + weight_decay, + learning_rate, + momentum, + use_locking=False, + name="MomentumW", + use_nesterov=False): """Construct a new MomentumW optimizer. For further information see the documentation of the Momentum Optimizer. @@ -314,23 +331,25 @@ class MomentumWOptimizer(DecoupledWeightDecayExtension, use_locking: If `True` use locks for update operations. 
name: Optional name prefix for the operations created when applying gradients. Defaults to "Momentum". - use_nesterov: If `True` use Nesterov Momentum. - See [Sutskever et al., 2013]( - http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). - This implementation always computes gradients at the value of the - variable(s) passed to the optimizer. Using Nesterov Momentum makes the - variable(s) track the values called `theta_t + mu*v_t` in the paper. - - @compatibility(eager) - When eager execution is enabled, learning_rate, weight_decay and momentum - can each be a callable that takes no arguments and returns the actual value - to use. This can be useful for changing these values across different - invocations of optimizer functions. - @end_compatibility + use_nesterov: If `True` use Nesterov Momentum. See [Sutskever et al., + 2013]( + http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). This + implementation always computes gradients at the value of the + variable(s) passed to the optimizer. Using Nesterov Momentum makes the + variable(s) track the values called `theta_t + mu*v_t` in the paper. + @compatibility(eager) When eager execution is enabled, learning_rate, + weight_decay and momentum can each be a callable that takes no + arguments and returns the actual value to use. This can be useful for + changing these values across different invocations of optimizer + functions. @end_compatibility """ super(MomentumWOptimizer, self).__init__( - weight_decay, learning_rate=learning_rate, momentum=momentum, - use_locking=use_locking, name=name, use_nesterov=use_nesterov) + weight_decay, + learning_rate=learning_rate, + momentum=momentum, + use_locking=use_locking, + name=name, + use_nesterov=use_nesterov) @tf_export("contrib.opt.AdamWOptimizer") @@ -352,12 +371,19 @@ class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): Note that this optimizer can also be instantiated as ```python - extend_with_weight_decay(tf.train.AdamOptimizer, weight_decay=weight_decay) + extend_with_weight_decay(tf.compat.v1.train.AdamOptimizer, + weight_decay=weight_decay) ``` """ - def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999, - epsilon=1e-8, use_locking=False, name="AdamW"): + def __init__(self, + weight_decay, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + use_locking=False, + name="AdamW"): """Construct a new AdamW optimizer. For further information see the documentation of the Adam Optimizer. @@ -365,10 +391,10 @@ class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): Args: weight_decay: A `Tensor` or a floating point value. The weight decay. learning_rate: A Tensor or a floating point value. The learning rate. - beta1: A float value or a constant float tensor. - The exponential decay rate for the 1st moment estimates. - beta2: A float value or a constant float tensor. - The exponential decay rate for the 2nd moment estimates. + beta1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. The exponential decay + rate for the 2nd moment estimates. epsilon: A small constant for numerical stability. This epsilon is "epsilon hat" in the Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm 1 of the paper. @@ -377,8 +403,13 @@ class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): Defaults to "Adam". 
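Pulling the pieces of this file together, a typical TF1-style use of the decoupled-weight-decay Adam variant looks roughly like the sketch below; the data, variable, and loss are placeholders chosen for illustration, and only the `AdamWOptimizer` constructor and the `decay_var_list` argument to `minimize` come from this module:

```python
import numpy as np
import tensorflow as tf  # assumes TF 1.x with tf.contrib available

x = tf.constant(np.random.randn(32, 10), dtype=tf.float32)
y = tf.constant(np.random.randn(32, 1), dtype=tf.float32)
weights = tf.get_variable('weights', shape=[10, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, weights) - y))

opt = tf.contrib.opt.AdamWOptimizer(weight_decay=1e-4, learning_rate=1e-3)
# Decay only `weights`; other variables can be excluded via decay_var_list.
train_op = opt.minimize(loss, decay_var_list=[weights])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
```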
""" super(AdamWOptimizer, self).__init__( - weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2, - epsilon=epsilon, use_locking=use_locking, name=name) + weight_decay, + learning_rate=learning_rate, + beta1=beta1, + beta2=beta2, + epsilon=epsilon, + use_locking=use_locking, + name=name) @tf_export("contrib.opt.ShampooWOptimizer") diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 1e210ec666c..1c8cdc5a42f 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -143,11 +143,6 @@ class CheckpointingTests(test.TestCase): suffix = "/.ATTRIBUTES/VARIABLE_VALUE" expected_checkpoint_names = [ name + suffix for name in expected_checkpoint_names] - # The optimizer and Dense layers also save get_config() JSON - expected_checkpoint_names.extend([ - "model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON", - "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON" - ]) named_variables = {v.name: v for v in named_variables} six.assertCountEqual(self, expected_checkpoint_names, named_variables.keys()) @@ -302,7 +297,7 @@ class CheckpointingTests(test.TestCase): with ops.Graph().as_default(): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = util.Checkpoint( + root = util.CheckpointV1( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) input_value = constant_op.constant([[3.]]) @@ -726,10 +721,9 @@ class CheckpointCompatibilityTests(test.TestCase): with context.graph_mode(): save_graph = ops.Graph() with save_graph.as_default(), self.test_session( - graph=save_graph) as session: + graph=save_graph): root = self._initialized_model() - save_path = root.save( - session=session, file_prefix=checkpoint_prefix) + save_path = root.save(file_prefix=checkpoint_prefix) with context.eager_mode(): root = self._initialized_model() self._set_sentinels(root) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 436ece79a79..7bcf07fbddc 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -188,7 +188,7 @@ def _is_dynamic(value): return True # Don't need to do anything special in graph mode, since dynamic values # will propagate correctly automatically. - # TODO(josh11b): Add per-device caching across steps using variables for + # TODO(josh11b): Add per-replica caching across steps using variables for # truly static values once we add distributed support. 
if context.executing_eagerly() and isinstance( value, resource_variable_ops.ResourceVariable): @@ -916,7 +916,8 @@ class OptimizerV2(optimizer_v1.Optimizer): var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) - unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)] + unwrapped_var_list = [ + x for v in var_list for x in distribution.experimental_local_results(v)] eager_execution = context.executing_eagerly() if eager_execution: # Give a clear error in this case instead of "name not supported diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py index 202c1e9afc0..a161538d151 100644 --- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py +++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -448,5 +449,54 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): ]), var1.eval()) +class SlotColocationTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([True, False]) + @test_util.run_gpu_only + @test_util.run_in_graph_and_eager_modes + def testRunMinimizeOnGPUForCPUVariables(self, use_resource): + with ops.device("/device:CPU:0"): + if use_resource: + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], + dtype=dtypes.float32) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], + dtype=dtypes.float32) + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step") + else: + var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32) + var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32) + global_step = variables.Variable( + array_ops.zeros([], dtypes.int64), name="global_step") + + def loss(): + return 5 * var0 + 3 * var1 + + opt = rmsprop.RMSPropOptimizer( + learning_rate=1.0, decay=0.9, momentum=0.5, epsilon=1.0) + + # Fetch params to validate initial values + self.evaluate(variables.global_variables_initializer()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 1 step through optimizer on GPU. + # Slot variables are created the first time optimizer is used on some + # variable. This tests that slot variables will be colocated with the base + # variable. + with ops.device("/device:GPU:0"): + # Note that for eager execution, minimize expects a function instead of a + # Tensor. + opt_op = opt.minimize(loss, global_step, [var0, var1]) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(opt_op) + + # Validate updated params, All variables should have decreased. 
+ self.assertTrue(all(v < 0.0 for v in self.evaluate(var0)), + msg="updated variables: %s" % self.evaluate(var0)) + self.assertTrue(all(v < 2.0 for v in self.evaluate(var1)), + msg="updated variables: %s" % self.evaluate(var1)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index c980a9342e4..37674071e41 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -1,19 +1,19 @@ +load("//tensorflow:tensorflow.bzl", "py_test") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", - "tf_gen_op_libs", - "tf_custom_op_library", - "tf_gen_op_wrapper_py", -) -load("//tensorflow:tensorflow.bzl", "py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") - cc_library( name = "all_ops", srcs = [":custom_op_sources"], @@ -77,6 +77,7 @@ py_library( py_test( name = "periodic_resample_op_test", srcs = ["python/kernel_tests/periodic_resample_op_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "notap", diff --git a/tensorflow/contrib/pi_examples/label_image/Makefile b/tensorflow/contrib/pi_examples/label_image/Makefile index 9d054a3133a..58fbd18dc3a 100644 --- a/tensorflow/contrib/pi_examples/label_image/Makefile +++ b/tensorflow/contrib/pi_examples/label_image/Makefile @@ -34,12 +34,14 @@ CXXFLAGS := --std=c++11 $(OPTFLAGS) LDFLAGS := \ -L/usr/local/lib \ -L$(TFLIBDIR) \ +-L$(DOWNLOADSDIR)/nsync/builds/default.linux.c++11/ \ -Wl,--no-whole-archive INCLUDES := \ -I/usr/local/include \ -I. \ -I$(DOWNLOADSDIR) \ -I$(DOWNLOADSDIR)/eigen/ \ +-I$(DOWNLOADSDIR)/absl/ \ -I$(PROTOGENDIR) \ -I$(PBTGENDIR) LIBS := \ @@ -49,6 +51,7 @@ LIBS := \ -Wl,--no-whole-archive \ -lstdc++ \ -lprotobuf \ +-lnsync \ -ldl \ -lpthread \ -lm \ diff --git a/tensorflow/contrib/pi_examples/label_image/label_image.cc b/tensorflow/contrib/pi_examples/label_image/label_image.cc index c6935a093f7..97a6e69ac03 100644 --- a/tensorflow/contrib/pi_examples/label_image/label_image.cc +++ b/tensorflow/contrib/pi_examples/label_image/label_image.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include #include #include + #include #include diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index 53a3bc63e1d..3189bb97ca3 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -110,6 +110,7 @@ py_test( name = "saved_model_predictor_test", srcs = ["saved_model_predictor_test.py"], data = [":test_export_dir"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -126,6 +127,7 @@ py_test( name = "predictor_factories_test", srcs = ["predictor_factories_test.py"], data = [":test_export_dir"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -137,6 +139,7 @@ py_test( py_test( name = "core_estimator_predictor_test", srcs = ["core_estimator_predictor_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -150,6 +153,7 @@ py_test( py_test( name = "contrib_estimator_predictor_test", srcs = ["contrib_estimator_predictor_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD index b27142cf4a6..c167fd70189 100644 --- a/tensorflow/contrib/proto/BUILD +++ b/tensorflow/contrib/proto/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "proto", @@ -14,5 +14,15 @@ py_library( deps = [ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", + "//tensorflow/python:proto_ops", + ], +) + +tf_py_test( + name = "import_test", + srcs = ["import_test.py"], + additional_deps = [ + ":proto", + "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/proto/__init__.py b/tensorflow/contrib/proto/__init__.py index bc5a49de78e..1fe17324cd9 100644 --- a/tensorflow/contrib/proto/__init__.py +++ b/tensorflow/contrib/proto/__init__.py @@ -21,8 +21,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.proto.python.ops.decode_proto_op import decode_proto -from tensorflow.contrib.proto.python.ops.encode_proto_op import encode_proto +from tensorflow.python.ops.proto_ops import decode_proto +from tensorflow.python.ops.proto_ops import encode_proto from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) diff --git a/tensorflow/contrib/proto/import_test.py b/tensorflow/contrib/proto/import_test.py new file mode 100644 index 00000000000..5da74a8c578 --- /dev/null +++ b/tensorflow/contrib/proto/import_test.py @@ -0,0 +1,33 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +"""Backwards compatibility tests for imports of tf.contrib.proto.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import proto +from tensorflow.python.platform import test + + +class ProtoImportTest(test.TestCase): + + def testImport(self): + self.assertTrue(proto.decode_proto) # Should be accessible + self.assertTrue(proto.encode_proto) # Should be accessible + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/proto/python/ops/BUILD b/tensorflow/contrib/proto/python/ops/BUILD index f17065477e1..ac09934b77d 100644 --- a/tensorflow/contrib/proto/python/ops/BUILD +++ b/tensorflow/contrib/proto/python/ops/BUILD @@ -1,44 +1,20 @@ -package(default_visibility = ["//visibility:public"]) - licenses(["notice"]) # Apache 2.0 -exports_files(["LICENSE"]) +package(default_visibility = ["//tensorflow:__subpackages__"]) -load( - "//tensorflow:tensorflow.bzl", - "tf_gen_op_wrapper_py", +# Placeholders for folks with old dependencies. +py_library( + name = "encode_proto_op_py", + srcs = ["encode_proto_op.py"], + deps = [ + "//tensorflow/python:proto_ops", + ], ) py_library( name = "decode_proto_op_py", srcs = ["decode_proto_op.py"], deps = [ - ":gen_decode_proto_op_py", - "//tensorflow/python:framework_ops", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_decode_proto_op_py", - out = "gen_decode_proto_op.py", - deps = [ - "//tensorflow/core:decode_proto_ops_op_lib", - ], -) - -py_library( - name = "encode_proto_op_py", - srcs = ["encode_proto_op.py"], - deps = [ - ":gen_encode_proto_op_py", - "//tensorflow/python:framework_ops", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_encode_proto_op_py", - out = "gen_encode_proto_op.py", - deps = [ - "//tensorflow/core:encode_proto_ops_op_lib", + "//tensorflow/python:proto_ops", ], ) diff --git a/tensorflow/contrib/proto/python/ops/decode_proto_op.py b/tensorflow/contrib/proto/python/ops/decode_proto_op.py index 7dc000ebe49..1347ebe3346 100644 --- a/tensorflow/contrib/proto/python/ops/decode_proto_op.py +++ b/tensorflow/contrib/proto/python/ops/decode_proto_op.py @@ -1,4 +1,3 @@ -# ============================================================================= # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,14 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================= +# ============================================================================== -# pylint: disable=wildcard-import,unused-import -"""Protocol Buffer decoding from tensors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.proto.python.ops.gen_decode_proto_op import decode_proto_v2 as decode_proto -from tensorflow.python.framework import ops -ops.NotDifferentiable("DecodeProtoV2") +# pylint: disable=unused-import +from tensorflow.python.ops.proto_ops import decode_proto diff --git a/tensorflow/contrib/proto/python/ops/encode_proto_op.py b/tensorflow/contrib/proto/python/ops/encode_proto_op.py index ac12198b2e4..6c1fcf68566 100644 --- a/tensorflow/contrib/proto/python/ops/encode_proto_op.py +++ b/tensorflow/contrib/proto/python/ops/encode_proto_op.py @@ -11,15 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= +# ============================================================================== -# pylint: disable=wildcard-import,unused-import -"""Protocol Buffer encoding from tensors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.proto.python.ops.gen_encode_proto_op import encode_proto -from tensorflow.python.framework import ops - -ops.NotDifferentiable("EncodeProto") +# pylint: disable=unused-import +from tensorflow.python.ops.proto_ops import encode_proto diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b67e68ea96a..598f6d15676 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -1,4 +1,4 @@ -package(default_visibility = ["//tensorflow:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -17,6 +17,7 @@ py_test( name = "common_test", size = "small", srcs = ["python/common_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":common", @@ -46,6 +47,7 @@ py_test( name = "graph_matcher_test", size = "small", srcs = ["python/graph_matcher_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_matcher", @@ -75,6 +77,7 @@ py_test( name = "input_to_ops_test", size = "small", srcs = ["python/input_to_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":input_to_ops", @@ -112,6 +115,7 @@ py_library( py_test( name = "fold_batch_norms_test", srcs = ["python/fold_batch_norms_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":fold_batch_norms", @@ -152,6 +156,7 @@ py_test( name = "quant_ops_test", size = "small", srcs = ["python/quant_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":quant_ops", @@ -185,6 +190,7 @@ py_test( name = "quantize_test", size = "small", srcs = ["python/quantize_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":quantize", @@ -204,6 +210,7 @@ py_test( name = "quantize_parameterized_test", size = "medium", srcs = ["python/quantize_parameterized_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", # TODO(b/118839526): Re-enable msan test. 
@@ -243,6 +250,7 @@ py_test( name = "quantize_graph_test", size = "small", srcs = ["python/quantize_graph_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":quantize_graph", diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py index e6c04bcf554..3c553d07102 100644 --- a/tensorflow/contrib/quantize/python/common.py +++ b/tensorflow/contrib/quantize/python/common.py @@ -115,7 +115,8 @@ def CreateOrGetQuantizationStep(): dtype=dtypes.int64, initializer=init_ops.zeros_initializer(), trainable=False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA) with g.name_scope(quantization_step_tensor.op.name + '/'): # We return the incremented variable tensor. Since this is used in conds # for quant_delay and freeze_bn_delay, it will run once per graph diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 39082cacf97..ecee2549683 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -55,7 +55,8 @@ def _ModelVariable(name, shape=shape, initializer=initializer, collections=collections, - trainable=trainable) + trainable=trainable, + aggregation=variable_scope.VariableAggregation.MEAN) def LastValueQuantize(inputs, @@ -161,12 +162,12 @@ def LastValueQuantize(inputs, # than the positive range. min_max_ratio = -((1 << num_bits) - 2) / (1 << num_bits) - # TFLite requires that 0.0 if always in the [min; max] range. Because + # TFLite requires that 0.0 is always in the [min; max] range. Because # batch_min <= batch_max, it follows that range_min <= 0 <= range_max. range_min = math_ops.minimum(batch_min, batch_max / min_max_ratio) range_max = math_ops.maximum(batch_max, batch_min * min_max_ratio) else: - # TFLite requires that 0.0 if always in the [min; max] range. + # TFLite requires that 0.0 is always in the [min; max] range. range_min = math_ops.minimum(batch_min, 0.0) range_max = math_ops.maximum(batch_max, 0.0) @@ -286,12 +287,12 @@ def MovingAvgQuantize(inputs, # than the positive range. min_max_ratio = -((1 << num_bits) - 2) / (1 << num_bits) - # TFLite requires that 0.0 if always in the [min; max] range. Because + # TFLite requires that 0.0 is always in the [min; max] range. Because # batch_min <= batch_max, it follows that range_min <= 0 <= range_max. range_min = math_ops.minimum(batch_min, batch_max / min_max_ratio) range_max = math_ops.maximum(batch_max, batch_min * min_max_ratio) else: - # TFLite requires that 0.0 if always in the [min; max] range. + # TFLite requires that 0.0 is always in the [min; max] range. 
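# --- Illustrative aside (editor's note, not part of this change) -------------
# A minimal pure-Python sketch of the range-nudging logic in LastValueQuantize /
# MovingAvgQuantize above, added only to make the corrected comment concrete:
# TFLite requires that 0.0 always lies inside [range_min, range_max]. The
# helper and values below are hypothetical and mirror, not replace, the TF code.
def _nudged_range(batch_min, batch_max, symmetric=False, num_bits=8):
  if symmetric:
    # Negative range is one quantization step larger than the positive range.
    min_max_ratio = -((1 << num_bits) - 2) / (1 << num_bits)
    return (min(batch_min, batch_max / min_max_ratio),
            max(batch_max, batch_min * min_max_ratio))
  return min(batch_min, 0.0), max(batch_max, 0.0)

assert _nudged_range(1.0, 5.0) == (0.0, 5.0)     # all-positive stats still span 0.0
assert _nudged_range(-3.0, -1.0) == (-3.0, 0.0)  # all-negative stats still span 0.0
# ------------------------------------------------------------------------------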
range_min = math_ops.minimum(batch_min, 0.0) range_max = math_ops.maximum(batch_max, 0.0) diff --git a/tensorflow/contrib/rate/BUILD b/tensorflow/contrib/rate/BUILD index 76db9aecf61..4a60b4703ec 100644 --- a/tensorflow/contrib/rate/BUILD +++ b/tensorflow/contrib/rate/BUILD @@ -34,6 +34,7 @@ py_test( name = "rate_test", size = "small", srcs = ["rate_test.py"], + python_version = "PY2", tags = [ "manual", # TODO(b/120555555) "no_oss", # TODO(b/120555555) diff --git a/tensorflow/contrib/receptive_field/BUILD b/tensorflow/contrib/receptive_field/BUILD index 9325a14745c..18ef0205941 100644 --- a/tensorflow/contrib/receptive_field/BUILD +++ b/tensorflow/contrib/receptive_field/BUILD @@ -62,6 +62,7 @@ py_library( py_test( name = "graph_compute_order_test", srcs = ["python/util/graph_compute_order_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_compute_order_py", @@ -78,6 +79,7 @@ py_test( py_test( name = "parse_layer_parameters_test", srcs = ["python/util/parse_layer_parameters_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":graph_compute_order_py", @@ -94,6 +96,7 @@ py_test( py_test( name = "receptive_field_test", srcs = ["python/util/receptive_field_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":receptive_field_py", diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index 3abf7bd6dad..6b4a5fbe8bc 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A tf.nn.dynamic_rnn variant, built on the Recurrent class. -""" +"""A tf.compat.v1.nn.dynamic_rnn variant, built on the Recurrent class.""" from __future__ import absolute_import from __future__ import division @@ -143,10 +142,12 @@ class _FunctionalRnnCell(object): @property def extended_initial_state(self): if self._prepend_output: - return [array_ops.zeros( - self._output_shape, - dtype=_GetDTypesFromStructure(self._state_template)[0]), - self._state_template] + return [ + array_ops.zeros( + self._output_shape, + dtype=_GetDTypesFromStructure(self._state_template)[0]), + self._state_template + ] else: # The base case, where the output is just the hidden state. 
return self._state_template @@ -189,8 +190,7 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output): is_less = math_ops.cast( math_ops.less(output_time, lengths), dtype=tf_output.dtype) keep_mask = array_ops.tile( - array_ops.expand_dims(is_less, -1), - [1, 1, vector_size]) + array_ops.expand_dims(is_less, -1), [1, 1, vector_size]) final_output = keep_mask * tf_output return final_output @@ -206,10 +206,10 @@ def _PickFinalStateFromHistory(acc_state, sequence_length): max_time, batch_size = shape[0], shape[1] output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size]) output_time = array_ops.reshape(output_time, [batch_size, max_time]) - lengths = array_ops.tile(array_ops.reshape(sequence_length, - [-1, 1]), [1, max_time]) - last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1), - dtype=state_var.dtype) + lengths = array_ops.tile( + array_ops.reshape(sequence_length, [-1, 1]), [1, max_time]) + last_idx = math_ops.cast( + math_ops.equal(output_time, lengths - 1), dtype=state_var.dtype) last_idx = array_ops.transpose(last_idx) last_idx_for_bcast = array_ops.expand_dims(last_idx, -1) sliced = math_ops.multiply(last_idx_for_bcast, state_var) @@ -235,8 +235,8 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell, Args: extended_acc_state: A structure containing the accumulated state at each time. It may contain the output at each time as well. - extended_final_state: A structure containing the final state. It may - contain the output at the final time. + extended_final_state: A structure containing the final state. It may contain + the output at the final time. func_cell: The functional wrapper around the cell. total_time: A scalar integer tensor. inputs_lengths: An integer tensor with one entry per input. @@ -254,8 +254,7 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell, # out from the acc_state sequence. flat_acc_state = func_cell.MaybeRemoveOutputFromState( nest.flatten(extended_acc_state)) - acc_state = nest.pack_sequence_as( - func_cell.state_template, flat_acc_state) + acc_state = nest.pack_sequence_as(func_cell.state_template, flat_acc_state) tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths) output_from_state = func_cell.GetOutputFromState(extended_acc_state) @@ -281,11 +280,11 @@ def functional_rnn(cell, scope=None, use_tpu=False, reverse=False): - """Same interface as `tf.nn.dynamic_rnn`.""" + """Same interface as `tf.compat.v1.nn.dynamic_rnn`.""" with variable_scope.variable_scope(scope or 'rnn'): if not time_major: - inputs = nest.map_structure( - lambda t: array_ops.transpose(t, [1, 0, 2]), inputs) + inputs = nest.map_structure(lambda t: array_ops.transpose(t, [1, 0, 2]), + inputs) inputs_flat = nest.flatten(inputs) batch_size = array_ops.shape(inputs_flat[0])[1] if initial_state is None: @@ -333,19 +332,19 @@ def bidirectional_functional_rnn(cell_fw, """Creates a bidirectional recurrent neural network. Performs fully dynamic unrolling of inputs in both directions. Built to be API - compatible with `tf.nn.bidirectional_dynamic_rnn`, but implemented with + compatible with `tf.compat.v1.nn.bidirectional_dynamic_rnn`, but implemented + with functional control flow for TPU compatibility. Args: - cell_fw: An instance of `tf.contrib.rnn.RNNCell`. - cell_bw: An instance of `tf.contrib.rnn.RNNCell`. + cell_fw: An instance of `tf.compat.v1.nn.rnn_cell.RNNCell`. + cell_bw: An instance of `tf.compat.v1.nn.rnn_cell.RNNCell`. inputs: The RNN inputs. 
If time_major == False (default), this must be a - Tensor (or hierarchical structure of Tensors) of shape - [batch_size, max_time, ...]. If time_major == True, this must be a Tensor - (or hierarchical structure of Tensors) of shape: - [max_time, batch_size, ...]. The first two dimensions must match across - all the inputs, but otherwise the ranks and other shape components may - differ. + Tensor (or hierarchical structure of Tensors) of shape [batch_size, + max_time, ...]. If time_major == True, this must be a Tensor + (or hierarchical structure of Tensors) of shape: [max_time, batch_size, + ...]. The first two dimensions must match across all the inputs, but + otherwise the ranks and other shape components may differ. initial_state_fw: An optional initial state for `cell_fw`. Should match `cell_fw.zero_state` in structure and type. initial_state_bw: An optional initial state for `cell_bw`. Should match @@ -384,14 +383,19 @@ def bidirectional_functional_rnn(cell_fw, ValueError: If `initial_state_fw` is None or `initial_state_bw` is None and `dtype` is not provided. """ - # Keep this code in sync with tf.nn.dynamic_rnn for compatibility. + # Keep this code in sync with tf.compat.v1.nn.dynamic_rnn for compatibility. with variable_scope.variable_scope(scope or 'bidirectional_rnn'): # Forward direction with variable_scope.variable_scope('fw') as fw_scope: output_fw, output_state_fw = functional_rnn( - cell=cell_fw, inputs=inputs, sequence_length=sequence_length, - initial_state=initial_state_fw, dtype=dtype, - time_major=time_major, scope=fw_scope, use_tpu=use_tpu) + cell=cell_fw, + inputs=inputs, + sequence_length=sequence_length, + initial_state=initial_state_fw, + dtype=dtype, + time_major=time_major, + scope=fw_scope, + use_tpu=use_tpu) # Backward direction if not time_major: time_dim = 1 @@ -403,8 +407,10 @@ def bidirectional_functional_rnn(cell_fw, def _reverse(input_, seq_lengths, seq_dim, batch_dim): if seq_lengths is not None: return array_ops.reverse_sequence( - input=input_, seq_lengths=seq_lengths, - seq_dim=seq_dim, batch_dim=batch_dim) + input=input_, + seq_lengths=seq_lengths, + seq_dim=seq_dim, + batch_dim=batch_dim) else: # See b/69305369. assert not use_tpu, ( @@ -440,4 +446,6 @@ def bidirectional_functional_rnn(cell_fw, output_states = (output_state_fw, output_state_bw) return (outputs, output_states) + + # pylint: enable=invalid-name diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py index f51de755d81..fbe11a11589 100644 --- a/tensorflow/contrib/recurrent/python/ops/recurrent.py +++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py @@ -100,7 +100,8 @@ def _Update(struct_acc, struct_x, t): to_skip_update = set() acc_lst = nest.flatten(struct_acc) x_lst = nest.flatten(struct_x) - t = math_ops.to_int32([t]) # tf.to_int32 casts on-device tensors. + t = math_ops.cast( + [t], dtypes.int32) # tf.compat.v1.to_int32 casts on-device tensors. lst = [] for acc, x in zip(acc_lst, x_lst): if acc in to_skip_update: @@ -429,7 +430,8 @@ class _Recurrent(object): acc_extras = _EmptyAcc(slen_dim, extras) t = slen_dim - max_input_length if self._aligned_end else 0 - dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t) + dev_t = math_ops.cast(t, dtypes.int32) if use_tpu else math_ops.cast( + t, dtypes.int64) run = functional_ops.For( start=t, limit=slen_dim if self._aligned_end else max_input_length, @@ -568,7 +570,8 @@ class _Recurrent(object): # Loop backwards. 
Note the loop's limit is open-ended, so goes through # t=0. t = slen_dim - 1 if self._aligned_end else max_input_length - 1 - dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t) + dev_t = math_ops.cast(t, dtypes.int32) if use_tpu else math_ops.cast( + t, dtypes.int64) limit = slen_dim - max_input_length - 1 if self._aligned_end else -1 run = functional_ops.For( start=t, diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc index 13fbd974e9c..be09076e862 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -37,7 +37,7 @@ namespace functor { #define GPUReduceSliceFunctorReduceop(reduceop, beginning) \ template \ __global__ void ReduceSliceDeviceKernel##reduceop( \ - Cuda3DLaunchConfig config, Index indices_width, Index bound, \ + Gpu3DLaunchConfig config, Index indices_width, Index bound, \ const T begin, const Index *indices, const T *input, T *out) { \ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \ @@ -73,7 +73,7 @@ namespace functor { if (sizex * sizey * sizez == 0) { \ return; \ } \ - Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \ + Gpu3DLaunchConfig config = GetGpu3DLaunchConfig( \ sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop, \ 0, 0); \ \ diff --git a/tensorflow/contrib/remote_fused_graph/pylib/BUILD b/tensorflow/contrib/remote_fused_graph/pylib/BUILD index 3aa8a14f44f..274bdbeacf7 100644 --- a/tensorflow/contrib/remote_fused_graph/pylib/BUILD +++ b/tensorflow/contrib/remote_fused_graph/pylib/BUILD @@ -37,6 +37,7 @@ py_test( name = "remote_fused_graph_ops_test", size = "small", srcs = ["python/ops/remote_fused_graph_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":remote_fused_graph_ops_py", diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc index 3b2ee098b3e..bdadc36bbc7 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc +++ b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc @@ -17,13 +17,13 @@ #define EIGEN_USE_GPU -#include "tensorflow/contrib/resampler/kernels/resampler_ops.h" - #include + #include +#include "tensorflow/contrib/resampler/kernels/resampler_ops.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -117,8 +117,8 @@ struct Resampler2DFunctor { const int data_channels, const int num_sampling_points) { const int output_data_size = batch_size * num_sampling_points * data_channels; - ::tensorflow::CudaLaunchConfig config = - ::tensorflow::GetCudaLaunchConfig(output_data_size, d); + ::tensorflow::GpuLaunchConfig config = + ::tensorflow::GetGpuLaunchConfig(output_data_size, d); TF_CHECK_OK(CudaLaunchKernel( Resampler2DKernel, config.block_count, config.thread_per_block, 0, d.stream(), data, warp, output, batch_size, data_height, data_width, @@ -252,20 
+252,20 @@ struct ResamplerGrad2DFunctor { const int grad_data_size = batch_size * data_height * data_width * data_channels; - ::tensorflow::CudaLaunchConfig config = - ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d); + ::tensorflow::GpuLaunchConfig config = + ::tensorflow::GetGpuLaunchConfig(grad_warp_size, d); TF_CHECK_OK(::tensorflow::CudaLaunchKernel( SetZero, config.block_count, config.thread_per_block, 0, d.stream(), grad_warp_size, grad_warp)); - config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d); + config = ::tensorflow::GetGpuLaunchConfig(grad_data_size, d); TF_CHECK_OK(::tensorflow::CudaLaunchKernel( SetZero, config.block_count, config.thread_per_block, 0, d.stream(), grad_data_size, grad_data)); const int resampler_output_size = batch_size * num_sampling_points * data_channels; - config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d); + config = ::tensorflow::GetGpuLaunchConfig(resampler_output_size, d); TF_CHECK_OK(CudaLaunchKernel(ResamplerGrad2DKernel, config.block_count, config.thread_per_block, 0, d.stream(), data, warp, grad_output, grad_data, grad_warp, diff --git a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py index cec4c3c2330..558cb9015ad 100644 --- a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py +++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class ResamplerOpsTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, warp_np, expected): - with self.test_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): input_image = array_ops.placeholder(image_np.dtype) warp = array_ops.placeholder(warp_np.dtype) resampled = resampler.resampler(input_image, warp, name='resampler') @@ -41,7 +41,7 @@ class ResamplerOpsTest(xla_test.XLATestCase): def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np, expected_grad_data, expected_grad_warp): - with self.cached_session() as sess, self.test_scope(): + with self.session() as sess, self.test_scope(): input_image = array_ops.placeholder(input_np.dtype) warp = array_ops.placeholder(warp_np.dtype) grad_output = array_ops.placeholder(grad_output_np.dtype) diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 24fa740d245..66fadcc16b5 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -99,6 +99,7 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = ["optonly"], xla_enabled = True, ) @@ -182,8 +183,10 @@ tf_custom_op_library( "kernels/lstm_ops.h", ], deps = [ - "//tensorflow/core/kernels:eigen_contraction_kernel", - "//tensorflow/core/kernels:eigen_helpers", + # _lstm_ops.so and _gru_ops.so cannot both be linked to MKL-DNN, + # or there will be duplicate thread_local symbols, which can cause + # problems in TensorFlow Transform. + "//tensorflow/core/kernels:eigen_helpers_no_mkl", ], ) @@ -207,8 +210,10 @@ tf_custom_op_library( "kernels/gru_ops.h", ], deps = [ - "//tensorflow/core/kernels:eigen_contraction_kernel", - "//tensorflow/core/kernels:eigen_helpers", + # _lstm_ops.so and _gru_ops.so cannot both be linked to MKL-DNN, + # or there will be duplicate thread_local symbols, which can cause + # problems in TensorFlow Transform. 
+ "//tensorflow/core/kernels:eigen_helpers_no_mkl", ], ) @@ -341,6 +346,7 @@ tf_kernel_library( py_binary( name = "checkpoint_convert", srcs = ["python/tools/checkpoint_convert.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":checkpoint_convert_lib"], ) @@ -364,6 +370,7 @@ py_test( name = "checkpoint_convert_test", size = "small", srcs = ["python/tools/checkpoint_convert_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index 81beb2942c1..472058b9a9e 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -17,12 +17,11 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/contrib/rnn/kernels/lstm_ops.h" - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/contrib/rnn/kernels/lstm_ops.h" #include "tensorflow/core/kernels/eigen_activations.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py index d5700d2a200..1a9e7053c55 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py @@ -63,7 +63,7 @@ class _MaskedRandomUniformInitializer(init_ops.RandomUniform): maxval: A python scalar or a scalar tensor. Upper bound of the range of random values to generate. Defaults to 1 for float types. seed: A Python integer. Used to create random seeds. See - `tf.set_random_seed` for behavior. + `tf.compat.v1.set_random_seed` for behavior. dtype: The data type. Only supports tf.float16 for now. num_valid_mantissa_bits: number of non-zero mantissa bits, default to 4. 
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index fbf2d4fcb8a..921b4baae43 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -1575,6 +1575,61 @@ class RNNCellTest(test.TestCase): self.assertEqual(len(outputs), batch) self.assertEqual(len(state), batch) + def testNTMCell(self): + expected_output = np.array( + [[-0.04973561, -0.00020032, -0.09586009, -0.05049511], + [-0.02199885, 0.02302885, -0.05558189, -0.02051288], + [-0.01399924, 0.02543444, -0.06975862, -0.03782758], + [-0.02238393, 0.0135776, -0.09102941, -0.05594013]], + dtype=np.float32) + expected_read_vector_list = np.array( + [[1e-6, 1e-6, 1e-6, 1e-6], [1e-6, 1e-6, 1e-6, 1e-6], + [1e-6, 1e-6, 1e-6, 1e-6], [1e-6, 1e-6, 1e-6, 1e-6]], + dtype=np.float32) + expected_w_list = np.array( + [[[0.15837428, 0.21354634, 0.22115856, 0.21117255, 0.19574821], + [0.15826838, 0.2150458, 0.2228198, 0.20747298, 0.19639312], + [0.15750293, 0.21550071, 0.22280747, 0.20737495, 0.19681393], + [0.15763053, 0.21473582, 0.22187267, 0.20920397, 0.19655706]], + [[0.21703579, 0.19425659, 0.22143759, 0.18024713, 0.18702294], + [0.2164267, 0.19451937, 0.22112325, 0.18051708, 0.18741359], + [0.21567065, 0.1947548, 0.22107735, 0.18058982, 0.18790732], + [0.2163743, 0.194361, 0.22131558, 0.18042919, 0.1875199]]], + dtype=np.float32) + expected_M_0 = np.array( + [[-0.00553495, -0.01089884, 0.00683121, -0.00273276], + [-0.00495392, -0.00975483, 0.00611433, -0.00244583], + [-0.00564722, -0.0111199, 0.00696973, -0.0027882], + [-0.00459658, -0.00905126, 0.00567345, -0.00226937], + [-0.00476941, -0.00939155, 0.00588669, -0.00235472]], + dtype=np.float32) + + with session.Session() as sess: + with variable_scope.variable_scope("root"): + seed = 1234 + random_seed.set_random_seed(seed) + batch_size = 4 + inputs = random_ops.random_uniform((batch_size, 4), + 0.0, + 1.0, + seed=seed + 1) + cell = contrib_rnn_cell.NTMCell( + controller=rnn_cell_impl.LSTMCell(num_units=4), + memory_size=5, + memory_vector_dim=4, + read_head_num=1, + write_head_num=1) + output, state = cell(inputs, cell.zero_state(batch_size, + dtypes.float32)) + sess.run([variables.global_variables_initializer()]) + res, read_vector_list, w_list, M = sess.run( + [output, state.read_vector_list, state.w_list, state.M]) + # Smoke test + self.assertAllClose(res, expected_output) + self.assertAllClose(read_vector_list[0], expected_read_vector_list) + self.assertAllClose(w_list, expected_w_list) + self.assertAllClose(M[0], expected_M_0) + class LayerNormBasicLSTMCellTest(test.TestCase): diff --git a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py index f90fd40990a..74b422a3845 100644 --- a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py @@ -136,7 +136,7 @@ class TimeReversedFusedRNN(FusedRNNCell): For example, ```python - cell = tf.contrib.rnn.BasicRNNCell(10) + cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(10) fw_lstm = tf.contrib.rnn.FusedRNNCellAdaptor(cell, use_dynamic_rnn=True) bw_lstm = tf.contrib.rnn.TimeReversedFusedRNN(fw_lstm) fw_out, fw_state = fw_lstm(inputs) diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py index 251a933eaec..09907af7dbe 100644 --- a/tensorflow/contrib/rnn/python/ops/gru_ops.py +++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py @@ 
-151,7 +151,7 @@ class GRUBlockCell(LayerRNNCell): name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. By default this is "lstm_cell", for variable-name compatibility - with `tf.nn.rnn_cell.GRUCell`. + with `tf.compat.v1.nn.rnn_cell.GRUCell`. Raises: ValueError: if both cell_size and num_units are not None; diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index b043026bc55..af7f16b8e7f 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -367,7 +367,7 @@ class LSTMBlockCell(LayerRNNCell): name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. By default this is "lstm_cell", for variable-name compatibility - with `tf.nn.rnn_cell.LSTMCell`. + with `tf.compat.v1.nn.rnn_cell.LSTMCell`. When restoring from CudnnLSTM-trained checkpoints, must use CudnnCompatibleLSTMBlockCell instead. @@ -619,7 +619,7 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. By default this is "lstm_cell", for variable-name compatibility - with `tf.nn.rnn_cell.LSTMCell`. + with `tf.compat.v1.nn.rnn_cell.LSTMCell`. """ super(LSTMBlockFusedCell, self).__init__( _reuse=reuse, name=name, dtype=dtype) @@ -691,9 +691,10 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype) if sequence_length is None: - max_seq_len = math_ops.to_int64(time_len) + max_seq_len = math_ops.cast(time_len, dtypes.int64) else: - max_seq_len = math_ops.to_int64(math_ops.reduce_max(sequence_length)) + max_seq_len = math_ops.cast(math_ops.reduce_max(sequence_length), + dtypes.int64) _, cs, _, _, _, _, h = gen_lstm_ops.block_lstm( seq_len_max=max_seq_len, diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 9111044e5b3..75710ea4190 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -33,6 +33,7 @@ from tensorflow.python.keras import initializers from tensorflow.python.keras.engine import input_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -3295,6 +3296,8 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. + + For a detailed analysis of IndyLSTMs, see https://arxiv.org/abs/1903.08023. """ def __init__(self, @@ -3409,6 +3412,354 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): return new_h, new_state +NTMControllerState = collections.namedtuple( + "NTMControllerState", + ("controller_state", "read_vector_list", "w_list", "M", "time")) + + +class NTMCell(rnn_cell_impl.LayerRNNCell): + """Neural Turing Machine Cell with RNN controller. 
+ + Implementation based on: + https://arxiv.org/abs/1807.08518 + Mark Collier, Joeran Beel + + which is in turn based on the source code of: + https://github.com/snowkylin/ntm + + and of course the original NTM paper: + Neural Turing Machines + https://arxiv.org/abs/1410.5401 + A Graves, G Wayne, I Danihelka + """ + + def __init__(self, + controller, + memory_size, + memory_vector_dim, + read_head_num, + write_head_num, + shift_range=1, + output_dim=None, + clip_value=20, + dtype=dtypes.float32, + name=None): + """Initialize the NTM Cell. + + Args: + controller: an RNNCell, the RNN controller. + memory_size: int, The number of memory locations in the NTM memory + matrix + memory_vector_dim: int, The dimensionality of each location in the NTM + memory matrix + read_head_num: int, The number of read heads from the controller into + memory + write_head_num: int, The number of write heads from the controller into + memory + shift_range: int, The number of places to the left/right it is possible + to iterate the previous address to in a single step + output_dim: int, The number of dimensions to make a linear projection of + the NTM controller outputs to. If None, no linear projection is + applied + clip_value: float, The maximum absolute value the controller parameters + are clipped to + dtype: Default dtype of the layer (default of `None` means use the type + of the first input). Required when `build` is called before `call`. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + """ + super(NTMCell, self).__init__(dtype=dtype, name=name) + + rnn_cell_impl.assert_like_rnncell("NTM RNN controller cell", controller) + + self.controller = controller + self.memory_size = memory_size + self.memory_vector_dim = memory_vector_dim + self.read_head_num = read_head_num + self.write_head_num = write_head_num + self.clip_value = clip_value + + self.output_dim = output_dim + self.shift_range = shift_range + + self.num_parameters_per_head = ( + self.memory_vector_dim + 2 * self.shift_range + 4) + self.num_heads = self.read_head_num + self.write_head_num + self.total_parameter_num = ( + self.num_parameters_per_head * self.num_heads + + self.memory_vector_dim * 2 * self.write_head_num) + + @property + def state_size(self): + return NTMControllerState( + controller_state=self.controller.state_size, + read_vector_list=[ + self.memory_vector_dim for _ in range(self.read_head_num) + ], + w_list=[ + self.memory_size + for _ in range(self.read_head_num + self.write_head_num) + ], + M=tensor_shape.TensorShape([self.memory_size * self.memory_vector_dim]), + time=tensor_shape.TensorShape([])) + + @property + def output_size(self): + return self.output_dim + + def build(self, inputs_shape): + if self.output_dim is None: + if inputs_shape[1].value is None: + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % + inputs_shape) + else: + self.output_dim = inputs_shape[1].value + + def _create_linear_initializer(input_size, dtype=dtypes.float32): + stddev = 1.0 / math.sqrt(input_size) + return init_ops.truncated_normal_initializer(stddev=stddev, dtype=dtype) + + self._params_kernel = self.add_variable( + "parameters_kernel", + shape=[self.controller.output_size, self.total_parameter_num], + initializer=_create_linear_initializer(self.controller.output_size)) + + self._params_bias = self.add_variable( + "parameters_bias", + shape=[self.total_parameter_num], + initializer=init_ops.constant_initializer(0.0, 
dtype=self.dtype)) + + self._output_kernel = self.add_variable( + "output_kernel", + shape=[ + self.controller.output_size + + self.memory_vector_dim * self.read_head_num, self.output_dim + ], + initializer=_create_linear_initializer(self.controller.output_size + + self.memory_vector_dim * + self.read_head_num)) + + self._output_bias = self.add_variable( + "output_bias", + shape=[self.output_dim], + initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) + + self._init_read_vectors = [ + self.add_variable( + "initial_read_vector_%d" % i, + shape=[1, self.memory_vector_dim], + initializer=initializers.glorot_uniform()) + for i in range(self.read_head_num) + ] + + self._init_address_weights = [ + self.add_variable( + "initial_address_weights_%d" % i, + shape=[1, self.memory_size], + initializer=initializers.glorot_uniform()) + for i in range(self.read_head_num + self.write_head_num) + ] + + self._M = self.add_variable( + "memory", + shape=[self.memory_size, self.memory_vector_dim], + initializer=init_ops.constant_initializer(1e-6, dtype=self.dtype)) + + self.built = True + + def call(self, x, prev_state): + # Addressing Mechanisms (Sec 3.3) + + def _prev_read_vector_list_initial_value(): + return [ + self._expand( + math_ops.tanh( + array_ops.squeeze( + math_ops.matmul( + array_ops.ones([1, 1]), self._init_read_vectors[i]))), + dim=0, + N=x.shape[0].value or array_ops.shape(x)[0]) + for i in range(self.read_head_num) + ] + + prev_read_vector_list = control_flow_ops.cond( + math_ops.equal(prev_state.time, + 0), _prev_read_vector_list_initial_value, lambda: + prev_state.read_vector_list) + if self.read_head_num == 1: + prev_read_vector_list = [prev_read_vector_list] + + controller_input = array_ops.concat([x] + prev_read_vector_list, axis=1) + controller_output, controller_state = self.controller( + controller_input, prev_state.controller_state) + + parameters = math_ops.matmul(controller_output, self._params_kernel) + parameters = nn_ops.bias_add(parameters, self._params_bias) + parameters = clip_ops.clip_by_value(parameters, -self.clip_value, + self.clip_value) + head_parameter_list = array_ops.split( + parameters[:, :self.num_parameters_per_head * self.num_heads], + self.num_heads, + axis=1) + erase_add_list = array_ops.split( + parameters[:, self.num_parameters_per_head * self.num_heads:], + 2 * self.write_head_num, + axis=1) + + def _prev_w_list_initial_value(): + return [ + self._expand( + nn_ops.softmax( + array_ops.squeeze( + math_ops.matmul( + array_ops.ones([1, 1]), + self._init_address_weights[i]))), + dim=0, + N=x.shape[0].value or array_ops.shape(x)[0]) + for i in range(self.read_head_num + self.write_head_num) + ] + + prev_w_list = control_flow_ops.cond( + math_ops.equal(prev_state.time, 0), + _prev_w_list_initial_value, lambda: prev_state.w_list) + if (self.read_head_num + self.write_head_num) == 1: + prev_w_list = [prev_w_list] + + prev_M = control_flow_ops.cond( + math_ops.equal(prev_state.time, 0), lambda: self._expand( + self._M, dim=0, N=x.shape[0].value or array_ops.shape(x)[0]), + lambda: prev_state.M) + + w_list = [] + for i, head_parameter in enumerate(head_parameter_list): + k = math_ops.tanh(head_parameter[:, 0:self.memory_vector_dim]) + beta = nn_ops.softplus(head_parameter[:, self.memory_vector_dim]) + g = math_ops.sigmoid(head_parameter[:, self.memory_vector_dim + 1]) + s = nn_ops.softmax(head_parameter[:, self.memory_vector_dim + + 2:(self.memory_vector_dim + 2 + + (self.shift_range * 2 + 1))]) + gamma = nn_ops.softplus(head_parameter[:, -1]) + 1 + w = 
self._addressing(k, beta, g, s, gamma, prev_M, prev_w_list[i]) + w_list.append(w) + + # Reading (Sec 3.1) + + read_w_list = w_list[:self.read_head_num] + read_vector_list = [] + for i in range(self.read_head_num): + read_vector = math_ops.reduce_sum( + array_ops.expand_dims(read_w_list[i], dim=2) * prev_M, axis=1) + read_vector_list.append(read_vector) + + # Writing (Sec 3.2) + + write_w_list = w_list[self.read_head_num:] + M = prev_M + for i in range(self.write_head_num): + w = array_ops.expand_dims(write_w_list[i], axis=2) + erase_vector = array_ops.expand_dims( + math_ops.sigmoid(erase_add_list[i * 2]), axis=1) + add_vector = array_ops.expand_dims( + math_ops.tanh(erase_add_list[i * 2 + 1]), axis=1) + erase_M = array_ops.ones_like(M) - math_ops.matmul(w, erase_vector) + M = M * erase_M + math_ops.matmul(w, add_vector) + + output = math_ops.matmul( + array_ops.concat([controller_output] + read_vector_list, axis=1), + self._output_kernel) + output = nn_ops.bias_add(output, self._output_bias) + output = clip_ops.clip_by_value(output, -self.clip_value, self.clip_value) + + return output, NTMControllerState( + controller_state=controller_state, + read_vector_list=read_vector_list, + w_list=w_list, + M=M, + time=prev_state.time + 1) + + def _expand(self, x, dim, N): + return array_ops.concat([array_ops.expand_dims(x, dim) for _ in range(N)], + axis=dim) + + def _addressing(self, k, beta, g, s, gamma, prev_M, prev_w): + # Sec 3.3.1 Focusing by Content + + k = array_ops.expand_dims(k, axis=2) + inner_product = math_ops.matmul(prev_M, k) + k_norm = math_ops.sqrt( + math_ops.reduce_sum(math_ops.square(k), axis=1, keepdims=True)) + M_norm = math_ops.sqrt( + math_ops.reduce_sum(math_ops.square(prev_M), axis=2, keepdims=True)) + norm_product = M_norm * k_norm + + # eq (6) + K = array_ops.squeeze(inner_product / (norm_product + 1e-8)) + + K_amplified = math_ops.exp(array_ops.expand_dims(beta, axis=1) * K) + + # eq (5) + w_c = K_amplified / math_ops.reduce_sum(K_amplified, axis=1, keepdims=True) + + # Sec 3.3.2 Focusing by Location + + g = array_ops.expand_dims(g, axis=1) + + # eq (7) + w_g = g * w_c + (1 - g) * prev_w + + s = array_ops.concat([ + s[:, :self.shift_range + 1], + array_ops.zeros([ + s.shape[0].value or array_ops.shape(s)[0], self.memory_size - + (self.shift_range * 2 + 1) + ]), s[:, -self.shift_range:] + ], + axis=1) + t = array_ops.concat( + [array_ops.reverse(s, axis=[1]), + array_ops.reverse(s, axis=[1])], + axis=1) + s_matrix = array_ops.stack([ + t[:, self.memory_size - i - 1:self.memory_size * 2 - i - 1] + for i in range(self.memory_size) + ], + axis=1) + + # eq (8) + w_ = math_ops.reduce_sum( + array_ops.expand_dims(w_g, axis=1) * s_matrix, axis=2) + w_sharpen = math_ops.pow(w_, array_ops.expand_dims(gamma, axis=1)) + + # eq (9) + w = w_sharpen / math_ops.reduce_sum(w_sharpen, axis=1, keepdims=True) + + return w + + def zero_state(self, batch_size, dtype): + read_vector_list = [ + array_ops.zeros([batch_size, self.memory_vector_dim]) + for _ in range(self.read_head_num) + ] + + w_list = [ + array_ops.zeros([batch_size, self.memory_size]) + for _ in range(self.read_head_num + self.write_head_num) + ] + + controller_init_state = self.controller.zero_state(batch_size, dtype) + + M = array_ops.zeros([batch_size, self.memory_size, self.memory_vector_dim]) + + return NTMControllerState( + controller_state=controller_init_state, + read_vector_list=read_vector_list, + w_list=w_list, + M=M, + time=0) + + class MinimalRNNCell(rnn_cell_impl.LayerRNNCell): """MinimalRNN cell. 
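# --- Illustrative aside (editor's note, not part of this change) -------------
# A NumPy sketch of the content-addressing step implemented in
# NTMCell._addressing above: eq. (6), cosine similarity between the key and
# each memory row, followed by eq. (5), the beta-sharpened softmax. Shapes are
# reduced to a single unbatched example; the memory, key, and beta values are
# made up for illustration.
import numpy as np

def content_addressing(memory, key, beta):
  # memory: [memory_size, memory_vector_dim], key: [memory_vector_dim]
  similarity = memory.dot(key) / (
      np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + 1e-8)  # eq (6)
  scores = np.exp(beta * similarity)
  return scores / scores.sum()                                      # eq (5)

memory = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
print(content_addressing(memory, key=np.array([1.0, 0.0]), beta=2.0))
# Weights concentrate on the rows most similar to the key; a larger beta
# sharpens the distribution, matching the w_c computation in the cell above.
# ------------------------------------------------------------------------------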
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py index d8ab9eba704..110025030b7 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py @@ -35,7 +35,6 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase): _protocol = 'grpc' invalid_method_string = 'Method not found' - connect_failed_string = 'Connect Failed' def __init__(self, methodName='runTest'): # pylint: disable=invalid-name super(RpcOpTest, self).__init__(methodName) diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index d6148715be9..ec073011f75 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -22,13 +22,12 @@ import itertools import numpy as np -from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.contrib.proto.python.ops import encode_proto_op from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2 from tensorflow.contrib.rpc.python.ops import rpc_op from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import proto_ops __all__ = ['I_WARNED_YOU', 'RpcOpTestBase'] @@ -124,8 +123,6 @@ class RpcOpTestBase(object): address=address, request='')) self.assertEqual(errors.UNAVAILABLE, status_code_value) - self.assertTrue( - self.connect_failed_string in status_message_value.decode('ascii')) def testAlwaysFailingMethod(self): with self.cached_session() as sess: @@ -222,7 +219,7 @@ class RpcOpTestBase(object): def testVecHostPortRpcUsingEncodeAndDecodeProto(self): with self.cached_session() as sess: - request_tensors = encode_proto_op.encode_proto( + request_tensors = proto_ops.encode_proto( message_type='tensorflow.contrib.rpc.TestCase', field_names=['values'], sizes=[[3]] * 20, @@ -233,7 +230,7 @@ class RpcOpTestBase(object): method=self.get_method_name('Increment'), address=self._address, request=request_tensors) - _, (response_shape,) = decode_proto_op.decode_proto( + _, (response_shape,) = proto_ops.decode_proto( bytes=response_tensor_strings, message_type='tensorflow.contrib.rpc.TestCase', field_names=['values'], diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index f0242a3b40f..969ff19eca6 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -69,6 +69,7 @@ py_test( name = "reader_test", size = "small", srcs = ["python/saved_model/reader_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows visibility = ["//visibility:private"], diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index be2aa4782c3..4af15095eec 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -18,7 +18,7 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -90,7 +90,7 @@ struct GatherTree { // First kernel launch to "zero" things out beams.device(d) = beams.constant(end_token); - CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); + GpuLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); TF_CHECK_OK(CudaLaunchKernel( GatherTreeOpKernel, config.block_count, config.thread_per_block, 0, d.stream(), batch_size, max_time, beam_width, step_ids.data(), diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 98e54db4584..634bc4ea21e 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -65,7 +65,7 @@ def get_result_summary(x): return x -@test_util.run_v1_only +@test_util.run_v1_only('contrib code not supported in TF2.0') class AttentionWrapperTest(test.TestCase): def assertAllCloseOrEqual(self, x, y, **kwargs): @@ -514,7 +514,7 @@ class AttentionWrapperTest(test.TestCase): for axis in [0, 1]: for exclusive in [True, False]: with self.cached_session(): - # Compute cumprod with regular tf.cumprod + # Compute cumprod with regular tf.math.cumprod cumprod_output = math_ops.cumprod( test_input, axis=axis, exclusive=exclusive).eval() # Compute cumprod with safe_cumprod diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py index 6d3192c9dae..66a464dc218 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py @@ -135,7 +135,7 @@ class AttentionMechanismTest(test.TestCase, parameterized.TestCase): encoder_input = keras.layers.Embedding( vocab, embedding_dim, mask_zero=True)( inputs) - encoder_output = keras.layers.UnifiedLSTM( + encoder_output = keras.layers.LSTM( self.memory_size, return_sequences=True)( encoder_input) @@ -314,12 +314,10 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): attention_layer_size=attention_layer_size, alignment_history=alignment_history, attention_layer=attention_layer) - # Set the attention_layer within AttentionWrapper to have deterministic - # kernel initializer, for testing purpose. 
if cell._attention_layers is not None: for layer in cell._attention_layers: if getattr(layer, "kernel_initializer") is None: - layer.kernel_initializer = initializers.ones() + layer.kernel_initializer = initializers.glorot_uniform(seed=1337) sampler = sampler_py.TrainingSampler() my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) @@ -453,16 +451,17 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): create_attention_kwargs = {"kernel_initializer": "ones"} expected_final_output = basic_decoder.BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324), - sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0)) + shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=0.051747426), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype(np.int32), mean=3.33333333)) expected_final_state = wrapper.AttentionWrapperState( cell_state=[ ResultSummary( - shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824), + shape=(5, 9), dtype=np.dtype(np.float32), mean=0.44189346), ResultSummary( - shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636)], + shape=(5, 9), dtype=np.dtype(np.float32), mean=0.65429491)], attention=ResultSummary( - shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569), + shape=(5, 6), dtype=np.dtype(np.float32), mean=0.073610783), time=3, alignments=ResultSummary( shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125), @@ -481,23 +480,23 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): expected_final_alignment_history=expected_final_alignment_history, create_attention_kwargs=create_attention_kwargs) - def DISABLED_testBahdanauNormalized(self): + def testBahdanauNormalized(self): create_attention_mechanism = wrapper.BahdanauAttentionV2 create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True} expected_final_output = basic_decoder.BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.9548259), + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.047594748), sample_id=ResultSummary( - shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + shape=(5, 3), dtype=np.dtype("int32"), mean=3.6)) expected_final_state = wrapper.AttentionWrapperState( cell_state=[ ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983), + shape=(5, 9), dtype=np.dtype("float32"), mean=0.41311637), ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209)], + shape=(5, 9), dtype=np.dtype("float32"), mean=0.61683208)], attention=ResultSummary( - shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728), + shape=(5, 6), dtype=np.dtype("float32"), mean=0.090581432), time=3, alignments=ResultSummary( shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), @@ -517,17 +516,17 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): expected_final_output = basic_decoder.BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 6), dtype=np.dtype("float32"), mean=2.6605489), + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.05481226), sample_id=ResultSummary( - shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333)) expected_final_state = wrapper.AttentionWrapperState( cell_state=[ ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088), + shape=(5, 9), dtype=np.dtype("float32"), mean=0.38453412), ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)], + shape=(5, 9), 
dtype=np.dtype("float32"), mean=0.5785929)], attention=ResultSummary( - shape=(5, 6), dtype=np.dtype("float32"), mean=4.084631), + shape=(5, 6), dtype=np.dtype("float32"), mean=0.16311775), time=3, alignments=ResultSummary( shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), @@ -541,23 +540,23 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): expected_final_state, attention_mechanism_depth=9) - def DISABLED_testLuongScaled(self): + def testLuongScaled(self): create_attention_mechanism = wrapper.LuongAttentionV2 create_attention_kwargs = {"scale": True} expected_final_output = basic_decoder.BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 6), dtype=np.dtype("float32"), mean=2.6605489), + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.05481226), sample_id=ResultSummary( - shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333)) expected_final_state = wrapper.AttentionWrapperState( cell_state=[ ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088), + shape=(5, 9), dtype=np.dtype("float32"), mean=0.38453412), ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)], + shape=(5, 9), dtype=np.dtype("float32"), mean=0.5785929)], attention=ResultSummary( - shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314), + shape=(5, 6), dtype=np.dtype("float32"), mean=0.16311775), time=3, alignments=ResultSummary( shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), @@ -604,31 +603,31 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): create_query_layer=True, create_attention_kwargs=create_attention_kwargs) - def DISABLED_testBahdanauMonotonicNotNormalized(self): + def testBahdanauMonotonicNotNormalized(self): create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2 create_attention_kwargs = {"kernel_initializer": "ones"} expected_final_output = basic_decoder.BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 6), dtype=np.dtype("float32"), mean=5.9850435), + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.041342419), sample_id=ResultSummary( - shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + shape=(5, 3), dtype=np.dtype("int32"), mean=3.53333333)) expected_final_state = wrapper.AttentionWrapperState( cell_state=[ ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248), + shape=(5, 9), dtype=np.dtype("float32"), mean=0.33866978), ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492)], + shape=(5, 9), dtype=np.dtype("float32"), mean=0.46913195)], attention=ResultSummary( - shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186), + shape=(5, 6), dtype=np.dtype("float32"), mean=0.092498459), time=3, alignments=ResultSummary( - shape=(5, 8), dtype=np.dtype("float32"), mean=0.10989678), + shape=(5, 8), dtype=np.dtype("float32"), mean=0.12079944), attention_state=ResultSummary( - shape=(5, 8), dtype=np.dtype("float32"), mean=0.10989678), + shape=(5, 8), dtype=np.dtype("float32"), mean=0.12079944), alignment_history=()) expected_final_alignment_history = ResultSummary( - shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.117412611) + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.121448785067) self._testWithAttention( create_attention_mechanism, @@ -639,28 +638,28 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): create_query_layer=True, create_attention_kwargs=create_attention_kwargs) - def DISABLED_testBahdanauMonotonicNormalized(self): + def 
testBahdanauMonotonicNormalized(self): create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2 create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True} expected_final_output = basic_decoder.BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 6), dtype=np.dtype("float32"), mean=4.5706983), + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.043294173), sample_id=ResultSummary( - shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + shape=(5, 3), dtype=np.dtype("int32"), mean=3.53333333)) expected_final_state = wrapper.AttentionWrapperState( cell_state=[ ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038), + shape=(5, 9), dtype=np.dtype("float32"), mean=0.40034312), ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473)], + shape=(5, 9), dtype=np.dtype("float32"), mean=0.5925445)], attention=ResultSummary( - shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721), + shape=(5, 6), dtype=np.dtype("float32"), mean=0.096119694), time=3, alignments=ResultSummary( - shape=(5, 8), dtype=np.dtype("float32"), mean=0.12258384), + shape=(5, 8), dtype=np.dtype("float32"), mean=0.1211452), attention_state=ResultSummary( - shape=(5, 8), dtype=np.dtype("float32"), mean=0.12258384), + shape=(5, 8), dtype=np.dtype("float32"), mean=0.1211452), alignment_history=()) expected_final_alignment_history = ResultSummary( shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12258384) @@ -679,25 +678,25 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): expected_final_output = basic_decoder.BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.159497), + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.027387079), sample_id=ResultSummary( - shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + shape=(5, 3), dtype=np.dtype("int32"), mean=3.133333333)) expected_final_state = wrapper.AttentionWrapperState( cell_state=[ ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038), + shape=(5, 9), dtype=np.dtype("float32"), mean=0.32660431), ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)], + shape=(5, 9), dtype=np.dtype("float32"), mean=0.52464348)], attention=ResultSummary( - shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605), + shape=(5, 6), dtype=np.dtype("float32"), mean=0.089345723), time=3, alignments=ResultSummary( - shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035), attention_state=ResultSummary( - shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035), alignment_history=()) expected_final_alignment_history = ResultSummary( - shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.11899644) + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12194442004) self._testWithAttention( create_attention_mechanism, @@ -707,31 +706,31 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): alignment_history=True, expected_final_alignment_history=expected_final_alignment_history) - def DISABLED_testLuongMonotonicScaled(self): + def testLuongMonotonicScaled(self): create_attention_mechanism = wrapper.LuongMonotonicAttentionV2 create_attention_kwargs = {"scale": True} expected_final_output = basic_decoder.BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.159497), + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=0.027387079), sample_id=ResultSummary( 
- shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + shape=(5, 3), dtype=np.dtype("int32"), mean=3.13333333)) expected_final_state = wrapper.AttentionWrapperState( cell_state=[ ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038), + shape=(5, 9), dtype=np.dtype("float32"), mean=0.32660431), ResultSummary( - shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)], + shape=(5, 9), dtype=np.dtype("float32"), mean=0.52464348)], attention=ResultSummary( - shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605), + shape=(5, 6), dtype=np.dtype("float32"), mean=0.089345723), time=3, alignments=ResultSummary( - shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035), attention_state=ResultSummary( - shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11831035), alignment_history=()) expected_final_alignment_history = ResultSummary( - shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.11899644) + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12194442004) self._testWithAttention( create_attention_mechanism, diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index 599abf5a361..2183761bf11 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -36,7 +36,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -@test_util.run_v1_only +@test_util.run_v1_only("contrib code not supported in TF2.0") class BasicDecoderTest(test.TestCase): def _testStepWithTrainingHelper(self, use_output_layer): diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 8c84cd13588..6360d1cfdc1 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -472,7 +472,7 @@ class TestLargeBeamStep(test.TestCase): self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]]) -@test_util.run_v1_only +@test_util.run_v1_only('contrib code not supported in TF2.0') class BeamSearchDecoderTest(test.TestCase): def _testDynamicDecodeRNN(self, time_major, has_attention, diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py index 4a420221e27..3661daf3a26 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import test -@test_util.run_v1_only +@test_util.run_v1_only("contrib code not supported in TF2.0") class DynamicDecodeRNNTest(test.TestCase): def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None): diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 577a3efbd7d..a9215e88000 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -47,7 +47,6 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope from 
tensorflow.python.util import nest - __all__ = [ "AttentionMechanism", "AttentionWrapper", @@ -61,7 +60,6 @@ __all__ = [ "LuongMonotonicAttention", ] - _zero_state_tensors = rnn_cell_impl._zero_state_tensors # pylint: disable=protected-access @@ -92,41 +90,43 @@ class _BaseAttentionMechanism(AttentionMechanism): memory_layer=None, check_inner_dims_defined=True, score_mask_value=None, + custom_key_value_fn=None, name=None): """Construct base AttentionMechanism class. Args: - query_layer: Callable. Instance of `tf.layers.Layer`. The layer's depth - must match the depth of `memory_layer`. If `query_layer` is not - provided, the shape of `query` must match that of `memory_layer`. + query_layer: Callable. Instance of `tf.compat.v1.layers.Layer`. The + layer's depth must match the depth of `memory_layer`. If `query_layer` + is not provided, the shape of `query` must match that of `memory_layer`. memory: The memory to query; usually the output of an RNN encoder. This tensor should be shaped `[batch_size, max_time, ...]`. probability_fn: A `callable`. Converts the score and previous alignments - to probabilities. Its signature should be: - `probabilities = probability_fn(score, state)`. + to probabilities. Its signature should be: `probabilities = + probability_fn(score, state)`. memory_sequence_length (optional): Sequence lengths for the batch entries in memory. If provided, the memory tensor rows are masked with zeros for values past the respective sequence lengths. - memory_layer: Instance of `tf.layers.Layer` (may be None). The layer's - depth must match the depth of `query_layer`. - If `memory_layer` is not provided, the shape of `memory` must match - that of `query_layer`. + memory_layer: Instance of `tf.compat.v1.layers.Layer` (may be None). The + layer's depth must match the depth of `query_layer`. If `memory_layer` + is not provided, the shape of `memory` must match that of `query_layer`. check_inner_dims_defined: Python boolean. If `True`, the `memory` argument's shape is checked to ensure all but the two outermost dimensions are fully defined. score_mask_value: (optional): The mask value for score before passing into `probability_fn`. The default is -inf. Only used if `memory_sequence_length` is not None. + custom_key_value_fn: (optional): The custom function for + computing keys and values. name: Name to use when creating ops. 
""" - if (query_layer is not None - and not isinstance(query_layer, layers_base.Layer)): - raise TypeError( - "query_layer is not a Layer: %s" % type(query_layer).__name__) - if (memory_layer is not None - and not isinstance(memory_layer, layers_base.Layer)): - raise TypeError( - "memory_layer is not a Layer: %s" % type(memory_layer).__name__) + if (query_layer is not None and + not isinstance(query_layer, layers_base.Layer)): + raise TypeError("query_layer is not a Layer: %s" % + type(query_layer).__name__) + if (memory_layer is not None and + not isinstance(memory_layer, layers_base.Layer)): + raise TypeError("memory_layer is not a Layer: %s" % + type(memory_layer).__name__) self._query_layer = query_layer self._memory_layer = memory_layer self.dtype = memory_layer.dtype @@ -138,23 +138,27 @@ class _BaseAttentionMechanism(AttentionMechanism): self._memory_layer.dtype).as_numpy_dtype(-np.inf) self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda probability_fn( - _maybe_mask_score(score, - memory_sequence_length=memory_sequence_length, - score_mask_value=score_mask_value), - prev)) - with ops.name_scope( - name, "BaseAttentionMechanismInit", nest.flatten(memory)): + _maybe_mask_score( + score, + memory_sequence_length=memory_sequence_length, + score_mask_value=score_mask_value), prev)) + with ops.name_scope(name, "BaseAttentionMechanismInit", + nest.flatten(memory)): self._values = _prepare_memory( - memory, memory_sequence_length=memory_sequence_length, + memory, + memory_sequence_length=memory_sequence_length, check_inner_dims_defined=check_inner_dims_defined) self._keys = ( self.memory_layer(self._values) if self.memory_layer # pylint: disable=not-callable else self._values) + if custom_key_value_fn is not None: + self._keys, self._values = custom_key_value_fn(self._keys, self._values) self._batch_size = ( tensor_shape.dimension_value(self._keys.shape[0]) or array_ops.shape(self._keys)[0]) - self._alignments_size = (tensor_shape.dimension_value(self._keys.shape[1]) - or array_ops.shape(self._keys)[1]) + self._alignments_size = ( + tensor_shape.dimension_value(self._keys.shape[1]) or + array_ops.shape(self._keys)[1]) @property def memory_layer(self): @@ -262,29 +266,27 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): memory: The memory to query; usually the output of an RNN encoder. This tensor should be shaped `[batch_size, max_time, ...]`. probability_fn: A `callable`. Converts the score and previous alignments - to probabilities. Its signature should be: - `probabilities = probability_fn(score, state)`. + to probabilities. Its signature should be: `probabilities = + probability_fn(score, state)`. query_layer: (optional): Instance of `tf.keras.Layer`. The layer's depth must match the depth of `memory_layer`. If `query_layer` is not provided, the shape of `query` must match that of `memory_layer`. - memory_layer: (optional): Instance of `tf.keras.Layer`. The layer's - depth must match the depth of `query_layer`. - If `memory_layer` is not provided, the shape of `memory` must match - that of `query_layer`. + memory_layer: (optional): Instance of `tf.keras.Layer`. The layer's depth + must match the depth of `query_layer`. If `memory_layer` is not + provided, the shape of `memory` must match that of `query_layer`. memory_sequence_length (optional): Sequence lengths for the batch entries - in memory. If provided, the memory tensor rows are masked with zeros - for values past the respective sequence lengths. + in memory. 
If provided, the memory tensor rows are masked with zeros for + values past the respective sequence lengths. **kwargs: Dictionary that contains other common arguments for layer creation. """ - if (query_layer is not None - and not isinstance(query_layer, layers.Layer)): - raise TypeError( - "query_layer is not a Layer: %s" % type(query_layer).__name__) - if (memory_layer is not None - and not isinstance(memory_layer, layers.Layer)): - raise TypeError( - "memory_layer is not a Layer: %s" % type(memory_layer).__name__) + if (query_layer is not None and not isinstance(query_layer, layers.Layer)): + raise TypeError("query_layer is not a Layer: %s" % + type(query_layer).__name__) + if (memory_layer is not None and + not isinstance(memory_layer, layers.Layer)): + raise TypeError("memory_layer is not a Layer: %s" % + type(memory_layer).__name__) self.query_layer = query_layer self.memory_layer = memory_layer if self.memory_layer is not None and "dtype" not in kwargs: @@ -369,23 +371,22 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): Args: inputs: a list of tensor that could either be `query` and `state`, or - `memory` and `memory_sequence_length`. - `query` is the tensor of dtype matching `memory` and shape - `[batch_size, query_depth]`. - `state` is the tensor of dtype matching `memory` and shape - `[batch_size, alignments_size]`. (`alignments_size` is memory's - `max_time`). - `memory` is the memory to query; usually the output of an RNN encoder. - The tensor should be shaped `[batch_size, max_time, ...]`. - `memory_sequence_length` (optional) is the sequence lengths for the - batch entries in memory. If provided, the memory tensor rows are masked - with zeros for values past the respective sequence lengths. + `memory` and `memory_sequence_length`. `query` is the tensor of dtype + matching `memory` and shape `[batch_size, query_depth]`. `state` is the + tensor of dtype matching `memory` and shape `[batch_size, + alignments_size]`. (`alignments_size` is memory's `max_time`). `memory` + is the memory to query; usually the output of an RNN encoder. The tensor + should be shaped `[batch_size, max_time, ...]`. `memory_sequence_length` + (optional) is the sequence lengths for the batch entries in memory. If + provided, the memory tensor rows are masked with zeros for values past + the respective sequence lengths. mask: optional bool tensor with shape `[batch, max_time]` for the mask of memory. If it is not None, the corresponding item of the memory should be filtered out during calculation. setup_memory: boolean, whether the input is for setting up memory, or query attention. **kwargs: Dict, other keyword arguments for the call method. + Returns: Either processed memory or attention score, based on `setup_memory`. 
""" @@ -440,8 +441,8 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): if memory_sequence_length is not None and memory_mask is not None: raise ValueError("memory_sequence_length and memory_mask cannot be " "used at same time for attention.") - with ops.name_scope( - self.name, "BaseAttentionMechanismInit", nest.flatten(memory)): + with ops.name_scope(self.name, "BaseAttentionMechanismInit", + nest.flatten(memory)): self.values = _prepare_memory( memory, memory_sequence_length=memory_sequence_length, @@ -459,10 +460,12 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): self.batch_size = ( tensor_shape.dimension_value(self.keys.shape[0]) or array_ops.shape(self.keys)[0]) - self._alignments_size = (tensor_shape.dimension_value(self.keys.shape[1]) - or array_ops.shape(self.keys)[1]) + self._alignments_size = ( + tensor_shape.dimension_value(self.keys.shape[1]) or + array_ops.shape(self.keys)[1]) if memory_mask is not None: unwrapped_probability_fn = self.probability_fn + def _mask_probability_fn(score, prev): return unwrapped_probability_fn( _maybe_mask_score( @@ -470,6 +473,7 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): memory_mask=memory_mask, memory_sequence_length=memory_sequence_length, score_mask_value=self.score_mask_value), prev) + self.probability_fn = _mask_probability_fn self._memory_initialized = True @@ -526,6 +530,7 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): config: dict, the configs that will be used to reconstruct the object. custom_objects: dict mapping class names (or function names) of custom (non-Keras) objects to class/functions. + Returns: config: dict, the config with layer instance created, which is ready to be used as init parameters. @@ -536,13 +541,13 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): config = config.copy() query_layer_config = config.pop("query_layer", None) if query_layer_config: - query_layer = deserialize_layer(query_layer_config, - custom_objects=custom_objects) + query_layer = deserialize_layer( + query_layer_config, custom_objects=custom_objects) config["query_layer"] = query_layer memory_layer_config = config.pop("memory_layer", None) if memory_layer_config: - memory_layer = deserialize_layer(memory_layer_config, - custom_objects=custom_objects) + memory_layer = deserialize_layer( + memory_layer_config, custom_objects=custom_objects) config["memory_layer"] = memory_layer return config @@ -623,8 +628,8 @@ def _luong_score(query, keys, scale): raise ValueError( "Incompatible or unknown inner dimensions between query and keys. " "Query (%s) has units: %s. Keys (%s) have units: %s. " - "Perhaps you need to set num_units to the keys' dimension (%s)?" - % (query, depth, keys, key_units, key_units)) + "Perhaps you need to set num_units to the keys' dimension (%s)?" % + (query, depth, keys, key_units, key_units)) # Reshape from [batch_size, depth] to [batch_size, 1, depth] # for matmul. @@ -672,6 +677,7 @@ class LuongAttention(_BaseAttentionMechanism): probability_fn=None, score_mask_value=None, dtype=None, + custom_key_value_fn=None, name="LuongAttention"): """Construct the AttentionMechanism mechanism. @@ -691,6 +697,8 @@ class LuongAttention(_BaseAttentionMechanism): `probability_fn`. The default is -inf. Only used if `memory_sequence_length` is not None. dtype: The data type for the memory layer of the attention mechanism. + custom_key_value_fn: (optional): The custom function for + computing keys and values. 
name: Name to use when creating ops. """ # For LuongAttention, we only transform the memory layer; thus @@ -708,6 +716,7 @@ class LuongAttention(_BaseAttentionMechanism): probability_fn=wrapped_probability_fn, memory_sequence_length=memory_sequence_length, score_mask_value=score_mask_value, + custom_key_value_fn=custom_key_value_fn, name=name) self._num_units = num_units self._scale = scale @@ -717,11 +726,10 @@ class LuongAttention(_BaseAttentionMechanism): """Score the query based on the keys and values. Args: - query: Tensor of dtype matching `self.values` and shape - `[batch_size, query_depth]`. - state: Tensor of dtype matching `self.values` and shape - `[batch_size, alignments_size]` - (`alignments_size` is memory's `max_time`). + query: Tensor of dtype matching `self.values` and shape `[batch_size, + query_depth]`. + state: Tensor of dtype matching `self.values` and shape `[batch_size, + alignments_size]` (`alignments_size` is memory's `max_time`). Returns: alignments: Tensor of dtype matching `self.values` and shape @@ -732,8 +740,10 @@ class LuongAttention(_BaseAttentionMechanism): attention_g = None if self._scale: attention_g = variable_scope.get_variable( - "attention_g", dtype=query.dtype, - initializer=init_ops.ones_initializer, shape=()) + "attention_g", + dtype=query.dtype, + initializer=init_ops.ones_initializer, + shape=()) score = _luong_score(query, self._keys, attention_g) alignments = self._probability_fn(score, state) next_state = alignments @@ -821,11 +831,10 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): """Score the query based on the keys and values. Args: - query: Tensor of dtype matching `self.values` and shape - `[batch_size, query_depth]`. - state: Tensor of dtype matching `self.values` and shape - `[batch_size, alignments_size]` - (`alignments_size` is memory's `max_time`). + query: Tensor of dtype matching `self.values` and shape `[batch_size, + query_depth]`. + state: Tensor of dtype matching `self.values` and shape `[batch_size, + alignments_size]` (`alignments_size` is memory's `max_time`). Returns: alignments: Tensor of dtype matching `self.values` and shape @@ -854,8 +863,11 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): return cls(**config) -def _bahdanau_score(processed_query, keys, attention_v, - attention_g=None, attention_b=None): +def _bahdanau_score(processed_query, + keys, + attention_v, + attention_g=None, + attention_b=None): """Implements Bahdanau-style (additive) scoring function. This attention has two forms. The first is Bhandanau attention, @@ -927,6 +939,7 @@ class BahdanauAttention(_BaseAttentionMechanism): probability_fn=None, score_mask_value=None, dtype=None, + custom_key_value_fn=None, name="BahdanauAttention"): """Construct the Attention mechanism. @@ -934,7 +947,7 @@ class BahdanauAttention(_BaseAttentionMechanism): num_units: The depth of the query mechanism. memory: The memory to query; usually the output of an RNN encoder. This tensor should be shaped `[batch_size, max_time, ...]`. - memory_sequence_length (optional): Sequence lengths for the batch entries + memory_sequence_length: (optional) Sequence lengths for the batch entries in memory. If provided, the memory tensor rows are masked with zeros for values past the respective sequence lengths. normalize: Python boolean. Whether to normalize the energy term. @@ -947,6 +960,8 @@ class BahdanauAttention(_BaseAttentionMechanism): `memory_sequence_length` is not None. dtype: The data type for the query and memory layers of the attention mechanism. 
+ custom_key_value_fn: (optional): The custom function for + computing keys and values. name: Name to use when creating ops. """ if probability_fn is None: @@ -961,6 +976,7 @@ class BahdanauAttention(_BaseAttentionMechanism): num_units, name="memory_layer", use_bias=False, dtype=dtype), memory=memory, probability_fn=wrapped_probability_fn, + custom_key_value_fn=custom_key_value_fn, memory_sequence_length=memory_sequence_length, score_mask_value=score_mask_value, name=name) @@ -972,11 +988,10 @@ class BahdanauAttention(_BaseAttentionMechanism): """Score the query based on the keys and values. Args: - query: Tensor of dtype matching `self.values` and shape - `[batch_size, query_depth]`. - state: Tensor of dtype matching `self.values` and shape - `[batch_size, alignments_size]` - (`alignments_size` is memory's `max_time`). + query: Tensor of dtype matching `self.values` and shape `[batch_size, + query_depth]`. + state: Tensor of dtype matching `self.values` and shape `[batch_size, + alignments_size]` (`alignments_size` is memory's `max_time`). Returns: alignments: Tensor of dtype matching `self.values` and shape @@ -992,16 +1007,22 @@ class BahdanauAttention(_BaseAttentionMechanism): attention_b = None else: attention_g = variable_scope.get_variable( - "attention_g", dtype=query.dtype, + "attention_g", + dtype=query.dtype, initializer=init_ops.constant_initializer( math.sqrt((1. / self._num_units))), shape=()) attention_b = variable_scope.get_variable( - "attention_b", [self._num_units], dtype=query.dtype, + "attention_b", [self._num_units], + dtype=query.dtype, initializer=init_ops.zeros_initializer()) - score = _bahdanau_score(processed_query, self._keys, attention_v, - attention_g=attention_g, attention_b=attention_b) + score = _bahdanau_score( + processed_query, + self._keys, + attention_v, + attention_g=attention_g, + attention_b=attention_b) alignments = self._probability_fn(score, state) next_state = alignments return alignments, next_state @@ -1100,10 +1121,13 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): initializer=self.kernel_initializer) if self.normalize and self.attention_g is None and self.attention_b is None: self.attention_g = self.add_weight( - "attention_g", initializer=init_ops.constant_initializer( - math.sqrt((1. / self.units))), shape=()) + "attention_g", + initializer=init_ops.constant_initializer( + math.sqrt((1. / self.units))), + shape=()) self.attention_b = self.add_weight( - "attention_b", shape=[self.units], + "attention_b", + shape=[self.units], initializer=init_ops.zeros_initializer()) self.built = True @@ -1111,11 +1135,10 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): """Score the query based on the keys and values. Args: - query: Tensor of dtype matching `self.values` and shape - `[batch_size, query_depth]`. - state: Tensor of dtype matching `self.values` and shape - `[batch_size, alignments_size]` - (`alignments_size` is memory's `max_time`). + query: Tensor of dtype matching `self.values` and shape `[batch_size, + query_depth]`. + state: Tensor of dtype matching `self.values` and shape `[batch_size, + alignments_size]` (`alignments_size` is memory's `max_time`). Returns: alignments: Tensor of dtype matching `self.values` and shape @@ -1124,9 +1147,12 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): next_state: same as alignments. 
""" processed_query = self.query_layer(query) if self.query_layer else query - score = _bahdanau_score(processed_query, self.keys, self.attention_v, - attention_g=self.attention_g, - attention_b=self.attention_b) + score = _bahdanau_score( + processed_query, + self.keys, + self.attention_v, + attention_g=self.attention_g, + attention_b=self.attention_b) alignments = self.probability_fn(score, state) next_state = alignments return alignments, next_state @@ -1160,14 +1186,16 @@ def safe_cumprod(x, *args, **kwargs): x: Tensor to take the cumulative product of. *args: Passed on to cumsum; these are identical to those in cumprod. **kwargs: Passed on to cumsum; these are identical to those in cumprod. + Returns: Cumulative product of x. """ with ops.name_scope(None, "SafeCumprod", [x]): x = ops.convert_to_tensor(x, name="x") tiny = np.finfo(x.dtype.as_numpy_dtype).tiny - return math_ops.exp(math_ops.cumsum( - math_ops.log(clip_ops.clip_by_value(x, tiny, 1)), *args, **kwargs)) + return math_ops.exp( + math_ops.cumsum( + math_ops.log(clip_ops.clip_by_value(x, tiny, 1)), *args, **kwargs)) def monotonic_attention(p_choose_i, previous_attention, mode): @@ -1190,19 +1218,17 @@ def monotonic_attention(p_choose_i, previous_attention, mode): the first output timestep, preevious_attention[n] should be [1, 0, 0, ..., 0] for all n in [0, ... batch_size - 1]. mode: How to compute the attention distribution. Must be one of - 'recursive', 'parallel', or 'hard'. - * 'recursive' uses tf.scan to recursively compute the distribution. - This is slowest but is exact, general, and does not suffer from - numerical instabilities. - * 'parallel' uses parallelized cumulative-sum and cumulative-product - operations to compute a closed-form solution to the recurrence - relation defining the attention distribution. This makes it more - efficient than 'recursive', but it requires numerical checks which - make the distribution non-exact. This can be a problem in particular - when input_sequence_length is long and/or p_choose_i has entries very - close to 0 or 1. - * 'hard' requires that the probabilities in p_choose_i are all either 0 - or 1, and subsequently uses a more efficient and exact solution. + 'recursive', 'parallel', or 'hard'. * 'recursive' uses tf.scan to + recursively compute the distribution. This is slowest but is exact, + general, and does not suffer from numerical instabilities. * 'parallel' + uses parallelized cumulative-sum and cumulative-product operations to + compute a closed-form solution to the recurrence relation defining the + attention distribution. This makes it more efficient than 'recursive', + but it requires numerical checks which make the distribution non-exact. + This can be a problem in particular when input_sequence_length is long + and/or p_choose_i has entries very close to 0 or 1. * 'hard' requires that + the probabilities in p_choose_i are all either 0 or 1, and subsequently + uses a more efficient and exact solution. 
Returns: A tensor of shape (batch_size, input_sequence_length) representing the @@ -1225,22 +1251,26 @@ def monotonic_attention(p_choose_i, previous_attention, mode): # Compute attention distribution recursively as # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i] # attention[i] = p_choose_i[i]*q[i] - attention = p_choose_i*array_ops.transpose(functional_ops.scan( - # Need to use reshape to remind TF of the shape between loop iterations - lambda x, yz: array_ops.reshape(yz[0]*x + yz[1], (batch_size,)), - # Loop variables yz[0] and yz[1] - [array_ops.transpose(shifted_1mp_choose_i), - array_ops.transpose(previous_attention)], - # Initial value of x is just zeros - array_ops.zeros((batch_size,)))) + attention = p_choose_i * array_ops.transpose( + functional_ops.scan( + # Need to use reshape to remind TF of the shape between loop iterations + lambda x, yz: array_ops.reshape(yz[0] * x + yz[1], (batch_size,)), + # Loop variables yz[0] and yz[1] + [ + array_ops.transpose(shifted_1mp_choose_i), + array_ops.transpose(previous_attention) + ], + # Initial value of x is just zeros + array_ops.zeros((batch_size,)))) elif mode == "parallel": # safe_cumprod computes cumprod in logspace with numeric checks cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True) # Compute recurrence relation solution - attention = p_choose_i*cumprod_1mp_choose_i*math_ops.cumsum( + attention = p_choose_i * cumprod_1mp_choose_i * math_ops.cumsum( previous_attention / # Clip cumprod_1mp to avoid divide-by-zero - clip_ops.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.), axis=1) + clip_ops.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.), + axis=1) elif mode == "hard": # Remove any probabilities before the index chosen last time step p_choose_i *= math_ops.cumsum(previous_attention, axis=1) @@ -1249,14 +1279,17 @@ def monotonic_attention(p_choose_i, previous_attention, mode): # p_choose_i = [0, 0, 0, 1, 1, 0, 1, 1] # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0] # Product of above: [0, 0, 0, 1, 0, 0, 0, 0] - attention = p_choose_i*math_ops.cumprod( + attention = p_choose_i * math_ops.cumprod( 1 - p_choose_i, axis=1, exclusive=True) else: raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.") return attention -def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode, +def _monotonic_probability_fn(score, + previous_alignments, + sigmoid_noise, + mode, seed=None): """Attention probability function for monotonic attention. @@ -1271,8 +1304,8 @@ def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode, Args: score: Unnormalized attention scores, shape `[batch_size, alignments_size]` - previous_alignments: Previous attention distribution, shape - `[batch_size, alignments_size]` + previous_alignments: Previous attention distribution, shape `[batch_size, + alignments_size]` sigmoid_noise: Standard deviation of pre-sigmoid noise. 
Setting this larger than 0 will encourage the model to produce large attention scores, effectively making the choosing probabilities discrete and the resulting @@ -1289,9 +1322,9 @@ def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode, """ # Optionally add pre-sigmoid noise to the scores if sigmoid_noise > 0: - noise = random_ops.random_normal(array_ops.shape(score), dtype=score.dtype, - seed=seed) - score += sigmoid_noise*noise + noise = random_ops.random_normal( + array_ops.shape(score), dtype=score.dtype, seed=seed) + score += sigmoid_noise * noise # Compute "choosing" probabilities from the attention scores if mode == "hard": # When mode is hard, use a hard sigmoid @@ -1326,7 +1359,8 @@ class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism): """ max_time = self._alignments_size return array_ops.one_hot( - array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time, + array_ops.zeros((batch_size,), dtype=dtypes.int32), + max_time, dtype=dtype) @@ -1354,7 +1388,8 @@ class _BaseMonotonicAttentionMechanismV2(_BaseAttentionMechanismV2): """ max_time = self._alignments_size return array_ops.one_hot( - array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time, + array_ops.zeros((batch_size,), dtype=dtypes.int32), + max_time, dtype=dtype) @@ -1417,7 +1452,9 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): if dtype is None: dtype = dtypes.float32 wrapped_probability_fn = functools.partial( - _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + _monotonic_probability_fn, + sigmoid_noise=sigmoid_noise, + mode=mode, seed=sigmoid_noise_seed) super(BahdanauMonotonicAttention, self).__init__( query_layer=layers_core.Dense( @@ -1438,19 +1475,18 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): """Score the query based on the keys and values. Args: - query: Tensor of dtype matching `self.values` and shape - `[batch_size, query_depth]`. - state: Tensor of dtype matching `self.values` and shape - `[batch_size, alignments_size]` - (`alignments_size` is memory's `max_time`). + query: Tensor of dtype matching `self.values` and shape `[batch_size, + query_depth]`. + state: Tensor of dtype matching `self.values` and shape `[batch_size, + alignments_size]` (`alignments_size` is memory's `max_time`). Returns: alignments: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]` (`alignments_size` is memory's `max_time`). """ - with variable_scope.variable_scope( - None, "bahdanau_monotonic_attention", [query]): + with variable_scope.variable_scope(None, "bahdanau_monotonic_attention", + [query]): processed_query = self.query_layer(query) if self.query_layer else query attention_v = variable_scope.get_variable( "attention_v", [self._num_units], dtype=query.dtype) @@ -1459,17 +1495,24 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): attention_b = None else: attention_g = variable_scope.get_variable( - "attention_g", dtype=query.dtype, + "attention_g", + dtype=query.dtype, initializer=init_ops.constant_initializer( math.sqrt((1. 
/ self._num_units))), shape=()) attention_b = variable_scope.get_variable( - "attention_b", [self._num_units], dtype=query.dtype, + "attention_b", [self._num_units], + dtype=query.dtype, initializer=init_ops.zeros_initializer()) - score = _bahdanau_score(processed_query, self._keys, attention_v, - attention_g=attention_g, attention_b=attention_b) + score = _bahdanau_score( + processed_query, + self._keys, + attention_v, + attention_g=attention_g, + attention_b=attention_b) score_bias = variable_scope.get_variable( - "attention_score_bias", dtype=processed_query.dtype, + "attention_score_bias", + dtype=processed_query.dtype, initializer=self._score_bias_init) score += score_bias alignments = self._probability_fn(score, state) @@ -1538,7 +1581,9 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): if dtype is None: dtype = dtypes.float32 wrapped_probability_fn = functools.partial( - _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + _monotonic_probability_fn, + sigmoid_noise=sigmoid_noise, + mode=mode, seed=sigmoid_noise_seed) query_layer = kwargs.pop("query_layer", None) if not query_layer: @@ -1573,21 +1618,26 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): super(BahdanauMonotonicAttentionV2, self).build(input_shape) if self.attention_v is None: self.attention_v = self.add_weight( - "attention_v", [self.units], dtype=self.dtype, + "attention_v", [self.units], + dtype=self.dtype, initializer=self.kernel_initializer) if self.attention_score_bias is None: self.attention_score_bias = self.add_weight( - "attention_score_bias", shape=(), dtype=self.dtype, + "attention_score_bias", + shape=(), + dtype=self.dtype, initializer=init_ops.constant_initializer( self.score_bias_init, dtype=self.dtype)) if self.normalize and self.attention_g is None and self.attention_b is None: self.attention_g = self.add_weight( - "attention_g", dtype=self.dtype, + "attention_g", + dtype=self.dtype, initializer=init_ops.constant_initializer( math.sqrt((1. / self.units))), shape=()) self.attention_b = self.add_weight( - "attention_b", [self.units], dtype=self.dtype, + "attention_b", [self.units], + dtype=self.dtype, initializer=init_ops.zeros_initializer()) self.built = True @@ -1595,11 +1645,10 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): """Score the query based on the keys and values. Args: - query: Tensor of dtype matching `self.values` and shape - `[batch_size, query_depth]`. - state: Tensor of dtype matching `self.values` and shape - `[batch_size, alignments_size]` - (`alignments_size` is memory's `max_time`). + query: Tensor of dtype matching `self.values` and shape `[batch_size, + query_depth]`. + state: Tensor of dtype matching `self.values` and shape `[batch_size, + alignments_size]` (`alignments_size` is memory's `max_time`). Returns: alignments: Tensor of dtype matching `self.values` and shape @@ -1607,9 +1656,12 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): `max_time`). 
""" processed_query = self.query_layer(query) if self.query_layer else query - score = _bahdanau_score(processed_query, self.keys, self.attention_v, - attention_g=self.attention_g, - attention_b=self.attention_b) + score = _bahdanau_score( + processed_query, + self.keys, + self.attention_v, + attention_g=self.attention_g, + attention_b=self.attention_b) score += self.attention_score_bias alignments = self.probability_fn(score, state) next_state = alignments @@ -1692,7 +1744,9 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): if dtype is None: dtype = dtypes.float32 wrapped_probability_fn = functools.partial( - _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + _monotonic_probability_fn, + sigmoid_noise=sigmoid_noise, + mode=mode, seed=sigmoid_noise_seed) super(LuongMonotonicAttention, self).__init__( query_layer=None, @@ -1712,11 +1766,10 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): """Score the query based on the keys and values. Args: - query: Tensor of dtype matching `self.values` and shape - `[batch_size, query_depth]`. - state: Tensor of dtype matching `self.values` and shape - `[batch_size, alignments_size]` - (`alignments_size` is memory's `max_time`). + query: Tensor of dtype matching `self.values` and shape `[batch_size, + query_depth]`. + state: Tensor of dtype matching `self.values` and shape `[batch_size, + alignments_size]` (`alignments_size` is memory's `max_time`). Returns: alignments: Tensor of dtype matching `self.values` and shape @@ -1728,11 +1781,14 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): attention_g = None if self._scale: attention_g = variable_scope.get_variable( - "attention_g", dtype=query.dtype, - initializer=init_ops.ones_initializer, shape=()) + "attention_g", + dtype=query.dtype, + initializer=init_ops.ones_initializer, + shape=()) score = _luong_score(query, self._keys, attention_g) score_bias = variable_scope.get_variable( - "attention_score_bias", dtype=query.dtype, + "attention_score_bias", + dtype=query.dtype, initializer=self._score_bias_init) score += score_bias alignments = self._probability_fn(score, state) @@ -1796,7 +1852,9 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): if dtype is None: dtype = dtypes.float32 wrapped_probability_fn = functools.partial( - _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + _monotonic_probability_fn, + sigmoid_noise=sigmoid_noise, + mode=mode, seed=sigmoid_noise_seed) memory_layer = kwargs.pop("memory_layer", None) if not memory_layer: @@ -1827,7 +1885,8 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): "attention_g", initializer=init_ops.ones_initializer, shape=()) if self.attention_score_bias is None: self.attention_score_bias = self.add_weight( - "attention_score_bias", shape=(), + "attention_score_bias", + shape=(), initializer=init_ops.constant_initializer( self.score_bias_init, dtype=self.dtype)) self.built = True @@ -1836,11 +1895,10 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): """Score the query based on the keys and values. Args: - query: Tensor of dtype matching `self.values` and shape - `[batch_size, query_depth]`. - state: Tensor of dtype matching `self.values` and shape - `[batch_size, alignments_size]` - (`alignments_size` is memory's `max_time`). + query: Tensor of dtype matching `self.values` and shape `[batch_size, + query_depth]`. 
+ state: Tensor of dtype matching `self.values` and shape `[batch_size, + alignments_size]` (`alignments_size` is memory's `max_time`). Returns: alignments: Tensor of dtype matching `self.values` and shape @@ -1917,6 +1975,7 @@ class AttentionWrapperState( A new `AttentionWrapperState` whose properties are the same as this one, except any overridden properties as provided in `kwargs`. """ + def with_same_shape(old, new): """Check and set new tensor's shape.""" if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor): @@ -1932,12 +1991,13 @@ class AttentionWrapperState( return new return nest.map_structure( - with_same_shape, - self, + with_same_shape, self, super(AttentionWrapperState, self)._replace(**kwargs)) -def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, +def _prepare_memory(memory, + memory_sequence_length=None, + memory_mask=None, check_inner_dims_defined=True): """Convert to tensor and possibly mask `memory`. @@ -1947,8 +2007,8 @@ def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, memory_mask: `boolean` tensor with shape [batch_size, max_time]. The memory should be skipped when the corresponding mask is False. check_inner_dims_defined: Python boolean. If `True`, the `memory` - argument's shape is checked to ensure all but the two outermost - dimensions are fully defined. + argument's shape is checked to ensure all but the two outermost dimensions + are fully defined. Returns: A (possibly masked), checked, new `memory`. @@ -1957,8 +2017,8 @@ def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, ValueError: If `check_inner_dims_defined` is `True` and not `memory.shape[2:].is_fully_defined()`. """ - memory = nest.map_structure( - lambda m: ops.convert_to_tensor(m, name="memory"), memory) + memory = nest.map_structure(lambda m: ops.convert_to_tensor(m, name="memory"), + memory) if memory_sequence_length is not None and memory_mask is not None: raise ValueError("memory_sequence_length and memory_mask can't be provided " "at same time.") @@ -1966,10 +2026,12 @@ def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, memory_sequence_length = ops.convert_to_tensor( memory_sequence_length, name="memory_sequence_length") if check_inner_dims_defined: + def _check_dims(m): if not m.get_shape()[2:].is_fully_defined(): raise ValueError("Expected memory %s to have fully defined inner dims, " "but saw shape: %s" % (m.name, m.get_shape())) + nest.map_structure(_check_dims, memory) if memory_sequence_length is None and memory_mask is None: return memory @@ -1982,6 +2044,7 @@ def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, # For memory_mask is not None seq_len_mask = math_ops.cast( memory_mask, dtype=nest.flatten(memory)[0].dtype) + def _maybe_mask(m, seq_len_mask): """Mask the memory based on the memory mask.""" rank = m.get_shape().ndims @@ -1995,7 +2058,9 @@ def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) -def _maybe_mask_score(score, memory_sequence_length=None, memory_mask=None, +def _maybe_mask_score(score, + memory_sequence_length=None, + memory_mask=None, score_mask_value=None): """Mask the attention score based on the masks.""" if memory_sequence_length is None and memory_mask is None: @@ -2004,7 +2069,7 @@ def _maybe_mask_score(score, memory_sequence_length=None, memory_mask=None, raise ValueError("memory_sequence_length and memory_mask can't be provided " "at same 
time.") if memory_sequence_length is not None: - message = "All values in memory_sequence_length must greater than zero." + message = "All values in memory_sequence_length must be greater than zero." with ops.control_dependencies( [check_ops.assert_positive(memory_sequence_length, message=message)]): memory_mask = array_ops.sequence_mask( @@ -2021,6 +2086,7 @@ def hardmax(logits, name=None): Args: logits: A batch tensor of logit values. name: Name to use when creating ops. + Returns: A batched one-hot tensor. """ @@ -2069,8 +2135,7 @@ def _compute_attention(attention_mechanism, cell_output, attention_state, class AttentionWrapper(rnn_cell_impl.RNNCell): - """Wraps another `RNNCell` with attention. - """ + """Wraps another `RNNCell` with attention.""" def __init__(self, cell, @@ -2126,40 +2191,39 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. If - attention_layer is set, this must be None. If attention_fn is set, - it must guaranteed that the outputs of attention_fn also meet the - above requirements. - alignment_history: Python boolean, whether to store alignment history - from all time steps in the final output state (currently stored as a - time major `TensorArray` on which you must call `stack()`). + attention_layer is set, this must be None. If attention_fn is set, it + must guaranteed that the outputs of attention_fn also meet the above + requirements. + alignment_history: Python boolean, whether to store alignment history from + all time steps in the final output state (currently stored as a time + major `TensorArray` on which you must call `stack()`). cell_input_fn: (optional) A `callable`. The default is: `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`. output_attention: Python bool. If `True` (default), the output at each time step is the attention value. This is the behavior of Luong-style - attention mechanisms. If `False`, the output at each time step is - the output of `cell`. This is the behavior of Bhadanau-style - attention mechanisms. In both cases, the `attention` tensor is - propagated to the next time step via the state and is used there. - This flag only controls whether the attention mechanism is propagated - up to the next cell in an RNN stack or to the top RNN output. - initial_cell_state: The initial state value to use for the cell when - the user calls `zero_state()`. Note that if this value is provided - now, and the user uses a `batch_size` argument of `zero_state` which - does not match the batch size of `initial_cell_state`, proper - behavior is not guaranteed. + attention mechanisms. If `False`, the output at each time step is the + output of `cell`. This is the behavior of Bhadanau-style attention + mechanisms. In both cases, the `attention` tensor is propagated to the + next time step via the state and is used there. This flag only controls + whether the attention mechanism is propagated up to the next cell in an + RNN stack or to the top RNN output. + initial_cell_state: The initial state value to use for the cell when the + user calls `zero_state()`. Note that if this value is provided now, and + the user uses a `batch_size` argument of `zero_state` which does not + match the batch size of `initial_cell_state`, proper behavior is not + guaranteed. name: Name to use when creating ops. 
- attention_layer: A list of `tf.layers.Layer` instances or a - single `tf.layers.Layer` instance taking the context and cell output as - inputs to generate attention at each time step. If None (default), use - the context as attention at each time step. If attention_mechanism is a - list, attention_layer must be a list of the same length. If - attention_layers_size is set, this must be None. + attention_layer: A list of `tf.compat.v1.layers.Layer` instances or a + single `tf.compat.v1.layers.Layer` instance taking the context and cell + output as inputs to generate attention at each time step. If None + (default), use the context as attention at each time step. If + attention_mechanism is a list, attention_layer must be a list of the + same length. If attention_layers_size is set, this must be None. attention_fn: An optional callable function that allows users to provide their own customized attention function, which takes input (attention_mechanism, cell_output, attention_state, attention_layer) and - outputs (attention, alignments, next_attention_state). If provided, - the attention_layer_size should be the size of the outputs of - attention_fn. + outputs (attention, alignments, next_attention_state). If provided, the + attention_layer_size should be the size of the outputs of attention_fn. Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` @@ -2175,17 +2239,16 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): attention_mechanisms = attention_mechanism for attention_mechanism in attention_mechanisms: if not isinstance(attention_mechanism, AttentionMechanism): - raise TypeError( - "attention_mechanism must contain only instances of " - "AttentionMechanism, saw type: %s" - % type(attention_mechanism).__name__) + raise TypeError("attention_mechanism must contain only instances of " + "AttentionMechanism, saw type: %s" % + type(attention_mechanism).__name__) else: self._is_multi = False if not isinstance(attention_mechanism, AttentionMechanism): raise TypeError( "attention_mechanism must be an AttentionMechanism or list of " - "multiple AttentionMechanism instances, saw type: %s" - % type(attention_mechanism).__name__) + "multiple AttentionMechanism instances, saw type: %s" % + type(attention_mechanism).__name__) attention_mechanisms = (attention_mechanism,) if cell_input_fn is None: @@ -2193,9 +2256,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): lambda inputs, attention: array_ops.concat([inputs, attention], -1)) else: if not callable(cell_input_fn): - raise TypeError( - "cell_input_fn must be callable, saw type: %s" - % type(cell_input_fn).__name__) + raise TypeError("cell_input_fn must be callable, saw type: %s" % + type(cell_input_fn).__name__) if attention_layer_size is not None and attention_layer is not None: raise ValueError("Only one of attention_layer_size and attention_layer " @@ -2203,14 +2265,13 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): if attention_layer_size is not None: attention_layer_sizes = tuple( - attention_layer_size - if isinstance(attention_layer_size, (list, tuple)) - else (attention_layer_size,)) + attention_layer_size if isinstance(attention_layer_size, ( + list, tuple)) else (attention_layer_size,)) if len(attention_layer_sizes) != len(attention_mechanisms): raise ValueError( "If provided, attention_layer_size must contain exactly one " - "integer per attention_mechanism, saw: %d vs %d" - % (len(attention_layer_sizes), len(attention_mechanisms))) + "integer per attention_mechanism, saw: %d vs %d" % + 
(len(attention_layer_sizes), len(attention_mechanisms))) self._attention_layers = tuple( layers_core.Dense( attention_layer_size, @@ -2221,21 +2282,20 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): self._attention_layer_size = sum(attention_layer_sizes) elif attention_layer is not None: self._attention_layers = tuple( - attention_layer - if isinstance(attention_layer, (list, tuple)) - else (attention_layer,)) + attention_layer if isinstance(attention_layer, (list, tuple)) else ( + attention_layer,)) if len(self._attention_layers) != len(attention_mechanisms): raise ValueError( "If provided, attention_layer must contain exactly one " - "layer per attention_mechanism, saw: %d vs %d" - % (len(self._attention_layers), len(attention_mechanisms))) + "layer per attention_mechanism, saw: %d vs %d" % + (len(self._attention_layers), len(attention_mechanisms))) self._attention_layer_size = sum( - tensor_shape.dimension_value(layer.compute_output_shape( - [None, - cell.output_size + tensor_shape.dimension_value( - mechanism.values.shape[-1])])[-1]) - for layer, mechanism in zip( - self._attention_layers, attention_mechanisms)) + tensor_shape.dimension_value( + layer.compute_output_shape([ + None, cell.output_size + + tensor_shape.dimension_value(mechanism.values.shape[-1]) + ])[-1]) for layer, mechanism in zip(self._attention_layers, + attention_mechanisms)) else: self._attention_layers = None self._attention_layer_size = sum( @@ -2257,8 +2317,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): else: final_state_tensor = nest.flatten(initial_cell_state)[-1] state_batch_size = ( - tensor_shape.dimension_value(final_state_tensor.shape[0]) - or array_ops.shape(final_state_tensor)[0]) + tensor_shape.dimension_value(final_state_tensor.shape[0]) or + array_ops.shape(final_state_tensor)[0]) error_message = ( "When constructing AttentionWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " @@ -2273,10 +2333,11 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): initial_cell_state) def _batch_size_checks(self, batch_size, error_message): - return [check_ops.assert_equal(batch_size, - attention_mechanism.batch_size, - message=error_message) - for attention_mechanism in self._attention_mechanisms] + return [ + check_ops.assert_equal( + batch_size, attention_mechanism.batch_size, message=error_message) + for attention_mechanism in self._attention_mechanisms + ] def _item_or_tuple(self, seq): """Returns `seq` as tuple or the singular element. 
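Editor's aside (not part of the diff): the `monotonic_attention` docstring reflowed in the hunks above describes the recurrence `q[i] = (1 - p_choose_i[i-1]) * q[i-1] + previous_attention[i]` with `attention[i] = p_choose_i[i] * q[i]`, plus a closed-form "parallel" solution built from an exclusive cumprod and a cumsum. The sketch below checks the two forms against each other in plain NumPy; the helper names `monotonic_attention_recursive` / `monotonic_attention_parallel` and the sample values are illustrative, not library code.

```python
import numpy as np


def monotonic_attention_recursive(p_choose_i, previous_attention):
    """Direct implementation of the recurrence from the docstring:
    q[i] = (1 - p_choose_i[i-1]) * q[i-1] + previous_attention[i]
    attention[i] = p_choose_i[i] * q[i]
    """
    q = np.zeros_like(p_choose_i)
    q[..., 0] = previous_attention[..., 0]
    for i in range(1, p_choose_i.shape[-1]):
        q[..., i] = ((1 - p_choose_i[..., i - 1]) * q[..., i - 1]
                     + previous_attention[..., i])
    return p_choose_i * q


def monotonic_attention_parallel(p_choose_i, previous_attention):
    """Closed-form 'parallel' solution via exclusive cumprod and cumsum."""
    # Exclusive cumprod of (1 - p): [1, 1-p[0], (1-p[0])(1-p[1]), ...]
    cumprod_1mp = np.cumprod(
        np.concatenate([np.ones_like(p_choose_i[..., :1]),
                        1 - p_choose_i[..., :-1]], axis=-1), axis=-1)
    # Clip to avoid divide-by-zero, mirroring the numerical check in the op.
    safe_cumprod_1mp = np.clip(cumprod_1mp, 1e-10, 1.0)
    return p_choose_i * cumprod_1mp * np.cumsum(
        previous_attention / safe_cumprod_1mp, axis=-1)


p = np.array([[0.1, 0.7, 0.3, 0.9]])           # choosing probabilities
prev = np.array([[1.0, 0.0, 0.0, 0.0]])        # all mass on step 0 last time
print(monotonic_attention_recursive(p, prev))  # [[0.1, 0.63, 0.081, 0.1701]]
print(np.allclose(monotonic_attention_recursive(p, prev),
                  monotonic_attention_parallel(p, prev)))  # True
```

As the docstring itself warns, the two forms only diverge when the exclusive cumprod underflows (long sequences, or probabilities very close to 0 or 1), which is exactly where the clip makes the "parallel" variant approximate rather than exact.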
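Editor's aside (not part of the diff): the `AttentionWrapper` constructor docstring rewrapped above documents how `attention_layer_size`, `alignment_history`, and `output_attention` interact. A minimal, hypothetical usage sketch under the TF 1.x `tf.contrib.seq2seq` API follows; the batch size, depths, and the choice of `LuongAttention` over the Bahdanau variants are illustrative assumptions, not taken from the patch.

```python
import tensorflow as tf  # assumes a TF 1.x runtime with tf.contrib available

batch_size, max_time, depth, num_units = 4, 7, 16, 32
encoder_outputs = tf.random.uniform([batch_size, max_time, depth])
memory_sequence_length = tf.fill([batch_size], max_time)

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=num_units,
    memory=encoder_outputs,
    memory_sequence_length=memory_sequence_length)

attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.LSTMCell(num_units),
    attention_mechanism,
    attention_layer_size=num_units,  # mutually exclusive with attention_layer
    alignment_history=True,          # stored as a time-major TensorArray in the state
    output_attention=True)           # Luong-style: the attention is the cell output

# AttentionWrapperState carries cell_state, attention, time, alignments, etc.
initial_state = attn_cell.zero_state(batch_size, tf.float32)
step_input = tf.random.uniform([batch_size, depth])
output, next_state = attn_cell(step_input, initial_state)
```

With `output_attention=False` the wrapper would instead emit the raw LSTM output and only thread the attention through the state, the Bahdanau-style behavior the docstring describes.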
@@ -2347,8 +2408,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): if self._initial_cell_state is not None: cell_state = self._initial_cell_state else: - cell_state = self._cell.get_initial_state(batch_size=batch_size, - dtype=dtype) + cell_state = self._cell.get_initial_state( + batch_size=batch_size, dtype=dtype) error_message = ( "When calling zero_state of AttentionWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " @@ -2364,7 +2425,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): cell_state) initial_alignments = [ attention_mechanism.initial_alignments(batch_size, dtype) - for attention_mechanism in self._attention_mechanisms] + for attention_mechanism in self._attention_mechanisms + ] return AttentionWrapperState( cell_state=cell_state, time=array_ops.zeros([], dtype=dtypes.int32), @@ -2379,9 +2441,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): dtype, size=0, dynamic_size=True, - element_shape=alignment.shape) - if self._alignment_history else () - for alignment in initial_alignments)) + element_shape=alignment.shape) if self._alignment_history else + () for alignment in initial_alignments)) def call(self, inputs, state): """Perform a step of attention-wrapped RNN. @@ -2400,8 +2461,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): Args: inputs: (Possibly nested tuple of) Tensor, the input at this time step. - state: An instance of `AttentionWrapperState` containing - tensors from the previous time step. + state: An instance of `AttentionWrapperState` containing tensors from the + previous time step. Returns: A tuple `(attention_or_cell_output, next_state)`, where: @@ -2415,7 +2476,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): """ if not isinstance(state, AttentionWrapperState): raise TypeError("Expected state to be instance of AttentionWrapperState. " - "Received type %s instead." % type(state)) + "Received type %s instead." % type(state)) # Step 1: Calculate the true inputs to the cell based on the # previous attention value. @@ -2435,8 +2496,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): "multiple=beam_width.") with ops.control_dependencies( self._batch_size_checks(cell_batch_size, error_message)): - cell_output = array_ops.identity( - cell_output, name="checked_cell_output") + cell_output = array_ops.identity(cell_output, name="checked_cell_output") if self._is_multi: previous_attention_state = state.attention_state diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index 16dfa7ed826..d4a9c211214 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A class of Decoders that may sample to generate the next input. -""" +"""A class of Decoders that may sample to generate the next input.""" from __future__ import absolute_import from __future__ import division @@ -31,7 +30,6 @@ from tensorflow.python.layers import base as layers_base from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.util import nest - __all__ = [ "BasicDecoderOutput", "BasicDecoder", @@ -54,9 +52,9 @@ class BasicDecoder(decoder.Decoder): helper: A `Helper` instance. initial_state: A (possibly nested tuple of...) tensors and TensorArrays. The initial state of the RNNCell. 
- output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output prior - to storing the result or sampling. + output_layer: (Optional) An instance of `tf.compat.v1.layers.Layer`, i.e., + `tf.compat.v1.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. Raises: TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. @@ -64,10 +62,10 @@ class BasicDecoder(decoder.Decoder): rnn_cell_impl.assert_like_rnncell("cell", cell) if not isinstance(helper, helper_py.Helper): raise TypeError("helper must be a Helper, received: %s" % type(helper)) - if (output_layer is not None - and not isinstance(output_layer, layers_base.Layer)): - raise TypeError( - "output_layer must be a Layer, received: %s" % type(output_layer)) + if (output_layer is not None and + not isinstance(output_layer, layers_base.Layer)): + raise TypeError("output_layer must be a Layer, received: %s" % + type(output_layer)) self._cell = cell self._helper = helper self._initial_state = initial_state @@ -89,8 +87,7 @@ class BasicDecoder(decoder.Decoder): # dimensions to get the output size of the rnn with the layer # applied to the top. output_shape_with_unknown_batch = nest.map_structure( - lambda s: tensor_shape.TensorShape([None]).concatenate(s), - size) + lambda s: tensor_shape.TensorShape([None]).concatenate(s), size) layer_output_shape = self._output_layer.compute_output_shape( output_shape_with_unknown_batch) return nest.map_structure(lambda s: s[1:], layer_output_shape) @@ -159,9 +156,9 @@ class BasicDecoderV2(decoder.BaseDecoder): Args: cell: An `RNNCell` instance. sampler: A `Sampler` instance. - output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output prior to - storing the result or sampling. + output_layer: (Optional) An instance of `tf.compat.v1.layers.Layer`, i.e., + `tf.compat.v1.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. **kwargs: Other keyward arguments for layer creation. Raises: @@ -172,8 +169,8 @@ class BasicDecoderV2(decoder.BaseDecoder): raise TypeError("sampler must be a Sampler, received: %s" % (sampler,)) if (output_layer is not None and not isinstance(output_layer, layers.Layer)): - raise TypeError( - "output_layer must be a Layer, received: %s" % (output_layer,)) + raise TypeError("output_layer must be a Layer, received: %s" % + (output_layer,)) self.cell = cell self.sampler = sampler self.output_layer = output_layer diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index e67e5c0d9c5..139f37c0472 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -149,8 +149,8 @@ def gather_tree_from_array(t, parent_ids, sequence_length): array_ops.expand_dims(math_ops.range(beam_width), 0), 0) beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) - max_sequence_lengths = math_ops.to_int32( - math_ops.reduce_max(sequence_length, axis=1)) + max_sequence_lengths = math_ops.cast( + math_ops.reduce_max(sequence_length, axis=1), dtypes.int32) sorted_beam_ids = beam_search_ops.gather_tree( step_ids=beam_ids, parent_ids=parent_ids, @@ -351,8 +351,8 @@ class BeamSearchDecoderMixin(object): """ del sequence_lengths # Get max_sequence_length across all beams for each batch. 
- max_sequence_lengths = math_ops.to_int32( - math_ops.reduce_max(final_state.lengths, axis=1)) + max_sequence_lengths = math_ops.cast( + math_ops.reduce_max(final_state.lengths, axis=1), dtypes.int32) predicted_ids = beam_search_ops.gather_tree( outputs.predicted_ids, outputs.parent_ids, @@ -982,10 +982,10 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, lengths_to_add = array_ops.one_hot( indices=array_ops.fill([batch_size, beam_width], end_token), depth=vocab_size, - on_value=np.int64(0), - off_value=np.int64(1), + on_value=math_ops.to_int64(0), + off_value=math_ops.to_int64(1), dtype=dtypes.int64) - add_mask = math_ops.to_int64(not_finished) + add_mask = math_ops.cast(not_finished, dtypes.int64) lengths_to_add *= array_ops.expand_dims(add_mask, 2) new_prediction_lengths = ( lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) @@ -996,7 +996,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, attention_probs = get_attention_probs( next_cell_state, coverage_penalty_weight) if attention_probs is not None: - attention_probs *= array_ops.expand_dims(math_ops.to_float(not_finished), 2) + attention_probs *= array_ops.expand_dims( + math_ops.cast(not_finished, dtypes.float32), 2) accumulated_attention_probs = ( beam_state.accumulated_attention_probs + attention_probs) @@ -1030,15 +1031,17 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, gather_shape=[-1], name="next_beam_probs") # Note: just doing the following - # math_ops.to_int32(word_indices % vocab_size, + # math_ops.cast( + # word_indices % vocab_size, + # dtypes.int32, # name="next_beam_word_ids") # would be a lot cleaner but for reasons unclear, that hides the results of # the op which prevents capturing it with tfdbg debug ops. raw_next_word_ids = math_ops.mod( word_indices, vocab_size, name="next_beam_word_ids") - next_word_ids = math_ops.to_int32(raw_next_word_ids) - next_beam_ids = math_ops.to_int32( - word_indices / vocab_size, name="next_beam_parent_ids") + next_word_ids = math_ops.cast(raw_next_word_ids, dtypes.int32) + next_beam_ids = math_ops.cast( + word_indices / vocab_size, dtypes.int32, name="next_beam_parent_ids") # Append new ids to current predictions previously_finished = _tensor_gather_helper( @@ -1057,7 +1060,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, # 2. Beams that are now finished (EOS predicted) have their length # increased by 1. # 3. Beams that are not yet finished have their length increased by 1. - lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished)) + lengths_to_add = math_ops.cast( + math_ops.logical_not(previously_finished), dtypes.int64) next_prediction_len = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=beam_state.lengths, @@ -1204,7 +1208,7 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight, coverage_penalty = math_ops.reduce_sum( math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2) # Apply coverage penalty to finished predictions. 
- coverage_penalty *= math_ops.to_float(finished) + coverage_penalty *= math_ops.cast(finished, dtypes.float32) weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1] weighted_coverage_penalty = array_ops.expand_dims( @@ -1257,8 +1261,9 @@ def _length_penalty(sequence_lengths, penalty_factor): static_penalty = tensor_util.constant_value(penalty_factor) if static_penalty is not None and static_penalty == 0: return 1.0 - return math_ops.div((5. + math_ops.to_float(sequence_lengths)) - **penalty_factor, (5. + 1.)**penalty_factor) + return math_ops.div( + (5. + math_ops.cast(sequence_lengths, dtypes.float32))**penalty_factor, + (5. + 1.)**penalty_factor) def _mask_probs(probs, eos_token, finished): diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 9c088591807..40774c2238a 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -65,6 +65,7 @@ py_test( "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], main = "bundle_shim_test.py", + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -114,6 +115,7 @@ py_test( name = "exporter_test", size = "small", srcs = ["exporter_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", visibility = ["//visibility:private"], deps = [ @@ -150,6 +152,7 @@ py_library( py_test( name = "gc_test", srcs = ["gc_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows visibility = ["//visibility:private"], @@ -266,6 +269,7 @@ py_test( srcs = ["session_bundle_test.py"], data = [":session_bundle_half_plus_two"], main = "session_bundle_test.py", + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -413,5 +417,4 @@ tf_proto_library( visibility = ["//visibility:public"], ) -# ----------------------------------------------------------------------------- -# Google-internal targets go here (must be at the end). +# Placeholder for Google-internal load statements. 
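The `beam_search_decoder.py` hunks above replace the deprecated `math_ops.to_int32`, `math_ops.to_int64`, and `math_ops.to_float` helpers with explicit `math_ops.cast` calls. A minimal sketch of the same change in terms of the public API, assuming a TF 1.x installation where the deprecated `tf.to_*` helpers still exist; the tensors are illustrative placeholders.

```python
import tensorflow as tf

# Illustrative stand-ins for the beam-search tensors in the hunks above.
sequence_length = tf.constant([[3, 5], [4, 2]], dtype=tf.int64)
finished = tf.constant([[True, False], [False, True]])

# Before (deprecated, removed in TF 2.x):
#   max_sequence_lengths = tf.to_int32(tf.reduce_max(sequence_length, axis=1))
#   mask = tf.to_float(finished)

# After, mirroring the diff: explicit casts with a target dtype.
max_sequence_lengths = tf.cast(tf.reduce_max(sequence_length, axis=1), tf.int32)
mask = tf.cast(finished, tf.float32)
```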
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.py b/tensorflow/contrib/session_bundle/bundle_shim.py index 1db97020a2a..f05cd5dab68 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.py +++ b/tensorflow/contrib/session_bundle/bundle_shim.py @@ -40,10 +40,8 @@ def _add_input_to_signature_def(tensor_name, map_key, signature_def): tensor_name: string name of tensor to add to signature_def inputs map_key: string key to key into signature_def inputs map signature_def: object of type meta_graph_pb2.SignatureDef() - - Sideffect: - adds a TensorInfo with tensor_name to signature_def inputs map keyed with - map_key + Sideffect: adds a TensorInfo with tensor_name to signature_def inputs map + keyed with map_key """ tensor_info = meta_graph_pb2.TensorInfo(name=tensor_name) signature_def.inputs[map_key].CopyFrom(tensor_info) @@ -56,10 +54,8 @@ def _add_output_to_signature_def(tensor_name, map_key, signature_def): tensor_name: string name of tensor to add to signature_def outputs map_key: string key to key into signature_def outputs map signature_def: object of type meta_graph_pb2.SignatureDef() - - Sideffect: - adds a TensorInfo with tensor_name to signature_def outputs map keyed with - map_key + Sideffect: adds a TensorInfo with tensor_name to signature_def outputs map + keyed with map_key """ tensor_info = meta_graph_pb2.TensorInfo(name=tensor_name) @@ -106,9 +102,10 @@ def _convert_default_signature_to_signature_def(signatures): signature_constants.CLASSIFY_OUTPUT_SCORES, signature_def) else: - logging.error("Only classification and regression default signatures " - "are supported for up-conversion. %s is not " - "supported" % default_signature.WhichOneof("type")) + logging.error( + "Only classification and regression default signatures " + "are supported for up-conversion. %s is not " + "supported", default_signature.WhichOneof("type")) return None return signature_def @@ -156,7 +153,7 @@ def _convert_signatures_to_signature_defs(metagraph_def): Args: metagraph_def: object of type meta_graph_pb2.MetaGraphDef containing legacy - format Session Bundle signatures + format Session Bundle signatures Returns: default_signature_def: object of type SignatureDef which contains an @@ -186,9 +183,10 @@ def _load_saved_model_from_session_bundle_path(export_dir, target, config): Args: export_dir: the directory that contains files exported by exporter. - target: The execution engine to connect to. See target in tf.Session() + target: The execution engine to connect to. See target in + tf.compat.v1.Session() config: A ConfigProto proto with configuration options. See config in - tf.Session() + tf.compat.v1.Session() Returns: session: a tensorflow session created from the variable files. @@ -247,11 +245,12 @@ def load_session_bundle_or_saved_model_bundle_from_path(export_dir, Args: export_dir: the directory that contains files exported by exporter. tags: Set of string tags to identify the required MetaGraphDef when model is - saved as SavedModel. These should correspond to the tags used when - saving the variables using the SavedModel `save()` API. - target: The execution engine to connect to. See target in tf.Session() + saved as SavedModel. These should correspond to the tags used when saving + the variables using the SavedModel `save()` API. + target: The execution engine to connect to. See target in + tf.compat.v1.Session() config: A ConfigProto proto with configuration options. 
See config in - tf.Session() + tf.compat.v1.Session() Returns: session: a tensorflow session created from the variable files. @@ -267,9 +266,8 @@ def load_session_bundle_or_saved_model_bundle_from_path(export_dir, sess = session.Session(target, graph=None, config=config) metagraph_def = loader.load(sess, tags, export_dir) elif session_bundle.maybe_session_bundle_dir(export_dir): - sess, metagraph_def = _load_saved_model_from_session_bundle_path(export_dir, - target, - config) + sess, metagraph_def = _load_saved_model_from_session_bundle_path( + export_dir, target, config) else: raise RuntimeError("SessionBundle or SavedModelBundle not found at " "specified export location: %s" % export_dir) diff --git a/tensorflow/contrib/session_bundle/example/BUILD b/tensorflow/contrib/session_bundle/example/BUILD index 9a56eab431d..18a075943c2 100644 --- a/tensorflow/contrib/session_bundle/example/BUILD +++ b/tensorflow/contrib/session_bundle/example/BUILD @@ -15,6 +15,7 @@ py_binary( srcs = [ "export_half_plus_two.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py index f3efd292cf5..a78985b2e8f 100644 --- a/tensorflow/contrib/session_bundle/exporter.py +++ b/tensorflow/contrib/session_bundle/exporter.py @@ -44,7 +44,7 @@ from tensorflow.python.util.deprecation import deprecated @deprecated("2017-06-30", "No longer supported. Switch to SavedModel immediately.") def gfile_copy_callback(files_to_copy, export_dir_path): - """Callback to copy files using `gfile.Copy` to an export directory. + """Callback to copy files using `gfile.copy` to an export directory. This method is used as the default `assets_callback` in `Exporter.init` to copy assets from the `assets_collection`. 
It can also be invoked directly to diff --git a/tensorflow/contrib/session_bundle/exporter_test.py b/tensorflow/contrib/session_bundle/exporter_test.py index 68419ffea04..33f10a47c59 100644 --- a/tensorflow/contrib/session_bundle/exporter_test.py +++ b/tensorflow/contrib/session_bundle/exporter_test.py @@ -88,12 +88,12 @@ class SaveRestoreShardedTest(test.TestCase): asset_file = constant_op.constant(asset_filepath_orig, name="filename42") ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file) - with gfile.FastGFile(asset_filepath_orig, "w") as f: + with gfile.GFile(asset_filepath_orig, "w") as f: f.write("your data here") assets_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) ignored_asset = os.path.join(test.get_temp_dir(), "ignored.txt") - with gfile.FastGFile(ignored_asset, "w") as f: + with gfile.GFile(ignored_asset, "w") as f: f.write("additional data here") variables.global_variables_initializer().run() diff --git a/tensorflow/contrib/session_bundle/gc.py b/tensorflow/contrib/session_bundle/gc.py index 514cc0f652c..4a9adcf4838 100644 --- a/tensorflow/contrib/session_bundle/gc.py +++ b/tensorflow/contrib/session_bundle/gc.py @@ -57,7 +57,7 @@ For example, # delete everything not in 'both' to_delete = gc.negation(both) for p in to_delete(all_paths): - gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2", + gfile.rmtree(p.path) # deletes: "/tmp/1", "/tmp/2", # "/tmp/3", "/tmp/4", "/tmp/6", """ diff --git a/tensorflow/contrib/session_bundle/session_bundle.py b/tensorflow/contrib/session_bundle/session_bundle.py index 66f2e32f58e..0911432823b 100644 --- a/tensorflow/contrib/session_bundle/session_bundle.py +++ b/tensorflow/contrib/session_bundle/session_bundle.py @@ -64,11 +64,12 @@ def load_session_bundle_from_path(export_dir, Args: export_dir: the directory that contains files exported by exporter. - target: The execution engine to connect to. See target in tf.Session() + target: The execution engine to connect to. See target in + tf.compat.v1.Session() config: A ConfigProto proto with configuration options. See config in - tf.Session() + tf.compat.v1.Session() meta_graph_def: optional object of type MetaGraphDef. If this object is - present, then it is used instead of parsing MetaGraphDef from export_dir. + present, then it is used instead of parsing MetaGraphDef from export_dir. Returns: session: a tensorflow session created from the variable files. 
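The `bundle_shim.py` and `session_bundle.py` docstrings above now describe `target` and `config` in terms of `tf.compat.v1.Session()`. For orientation, a minimal sketch of loading a SavedModel into such a session with the public loader API; this assumes a TF 1.x-style graph runtime, and `export_dir` is a hypothetical path.

```python
import tensorflow as tf

export_dir = "/tmp/my_saved_model"  # hypothetical export location

# `target` and `config` here correspond to the arguments documented above.
config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
sess = tf.compat.v1.Session(target="", graph=None, config=config)

# Load the MetaGraphDef tagged for serving into the session's graph.
meta_graph_def = tf.compat.v1.saved_model.loader.load(
    sess, [tf.compat.v1.saved_model.tag_constants.SERVING], export_dir)
```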
diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD index 516e3ea0732..96e2dcecbdf 100644 --- a/tensorflow/contrib/slim/BUILD +++ b/tensorflow/contrib/slim/BUILD @@ -23,6 +23,7 @@ py_library( py_test( name = "evaluation_test", srcs = ["python/slim/evaluation_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":evaluation", @@ -70,6 +71,7 @@ py_library( py_test( name = "learning_test", srcs = ["python/slim/learning_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["manual"], deps = [ @@ -168,6 +170,7 @@ py_library( py_test( name = "summaries_test", srcs = ["python/slim/summaries_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":summaries", diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md index 7b54aafeb2c..2f6b006da77 100644 --- a/tensorflow/contrib/slim/README.md +++ b/tensorflow/contrib/slim/README.md @@ -346,7 +346,7 @@ we can both ensure that each layer uses the same values and simplify the code: with slim.arg_scope([slim.conv2d], padding='SAME', weights_initializer=tf.truncated_normal_initializer(stddev=0.01) weights_regularizer=slim.l2_regularizer(0.0005)): - net = slim.conv2d(inputs, 64, [11, 11], scope='conv1') + net = slim.conv2d(inputs, 64, [11, 11], 4, scope='conv1') net = slim.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2') net = slim.conv2d(net, 256, [11, 11], scope='conv3') ``` @@ -681,11 +681,11 @@ name to each graph variable. Consider the following example where the checkpoint variables names are obtained via a simple function: ```python -# Assuming than 'conv1/weights' should be restored from 'vgg16/conv1/weights' +# Assuming that 'conv1/weights' should be restored from 'vgg16/conv1/weights' def name_in_checkpoint(var): return 'vgg16/' + var.op.name -# Assuming than 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2' +# Assuming that 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2' def name_in_checkpoint(var): if "weights" in var.op.name: return var.op.name.replace("weights", "params1") diff --git a/tensorflow/contrib/slim/python/slim/data/BUILD b/tensorflow/contrib/slim/python/slim/data/BUILD index eef043e8327..f1b57361ac6 100644 --- a/tensorflow/contrib/slim/python/slim/data/BUILD +++ b/tensorflow/contrib/slim/python/slim/data/BUILD @@ -60,6 +60,7 @@ py_library( py_test( name = "dataset_data_provider_test", srcs = ["dataset_data_provider_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ @@ -98,6 +99,7 @@ py_test( name = "parallel_reader_test", size = "small", srcs = ["parallel_reader_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":parallel_reader", @@ -130,6 +132,7 @@ py_test( name = "prefetch_queue_test", size = "small", srcs = ["prefetch_queue_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":prefetch_queue", @@ -179,6 +182,7 @@ py_library( py_test( name = "tfexample_decoder_test", srcs = ["tfexample_decoder_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":tfexample_decoder", diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py index 99ad4876303..9988c6cf635 100644 --- a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py +++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py @@ -61,13 +61,13 @@ 
class ParallelReader(io_ops.ReaderBase): If the `common_queue` is a shuffling queue, then the examples are shuffled. Usage: - common_queue = tf.RandomShuffleQueue( + common_queue = tf.queue.RandomShuffleQueue( capacity=256, min_after_dequeue=128, dtypes=[tf.string, tf.string]) - p_reader = ParallelReader(tf.TFRecordReader, common_queue) + p_reader = ParallelReader(tf.compat.v1.TFRecordReader, common_queue) - common_queue = tf.FIFOQueue( + common_queue = tf.queue.FIFOQueue( capacity=256, dtypes=[tf.string, tf.string]) p_reader = ParallelReader(readers, common_queue, num_readers=2) @@ -77,7 +77,8 @@ class ParallelReader(io_ops.ReaderBase): reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader common_queue: a Queue to hold (key, value pairs) with `dtypes` equal to [tf.string, tf.string]. Must be one of the data_flow_ops.Queues - instances, ex. `tf.FIFOQueue()`, `tf.RandomShuffleQueue()`, ... + instances, ex. `tf.queue.FIFOQueue()`, `tf.queue.RandomShuffleQueue()`, + ... num_readers: a integer, number of instances of reader_class to create. reader_kwargs: an optional dict of kwargs to create the readers. @@ -119,8 +120,8 @@ class ParallelReader(io_ops.ReaderBase): to the TF QueueRunners collection. Args: - queue: A Queue or a mutable string Tensor representing a handle - to a Queue, with string work items. + queue: A Queue or a mutable string Tensor representing a handle to a + Queue, with string work items. name: A name for the operation (optional). Returns: @@ -143,8 +144,8 @@ class ParallelReader(io_ops.ReaderBase): `tf.errors.UnimplementedError` is raised. Args: - queue: A Queue or a mutable string Tensor representing a handle - to a Queue, with string work items. + queue: A Queue or a mutable string Tensor representing a handle to a + Queue, with string work items. num_records: Number of records to read. name: A name for the operation (optional). @@ -218,14 +219,14 @@ def parallel_read(data_sources, /path/to/train@128, /path/to/train* or /tmp/.../train* reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader num_epochs: The number of times each data source is read. If left as None, - the data will be cycled through indefinitely. + the data will be cycled through indefinitely. num_readers: a integer, number of Readers to create. reader_kwargs: an optional dict, of kwargs for the reader. shuffle: boolean, whether should shuffle the files and the records by using RandomShuffleQueue as common_queue. - dtypes: A list of types. The length of dtypes must equal the number - of elements in each record. If it is None it will default to - [tf.string, tf.string] for (key, value). + dtypes: A list of types. The length of dtypes must equal the number of + elements in each record. If it is None it will default to [tf.string, + tf.string] for (key, value). capacity: integer, capacity of the common_queue. min_after_dequeue: integer, minimum number of records in the common_queue after dequeue. Needed for a good shuffle. 
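The updated `ParallelReader` docstring above spells out its usage in terms of `tf.queue.RandomShuffleQueue` and `tf.compat.v1.TFRecordReader`. A small sketch of that usage, assuming the contrib slim data module is importable (TF 1.x); the shard filename is hypothetical.

```python
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.data import parallel_reader

# Shared (key, value) string queue, shuffled as in the docstring above.
common_queue = tf.queue.RandomShuffleQueue(
    capacity=256, min_after_dequeue=128, dtypes=[tf.string, tf.string])

# Several TFRecordReader instances feed the shared queue in parallel.
reader = parallel_reader.ParallelReader(
    tf.compat.v1.TFRecordReader, common_queue, num_readers=4)

# Standard ReaderBase interface: read (key, value) records off a filename queue.
filename_queue = tf.compat.v1.train.string_input_producer(
    ["/tmp/data/train-00000.tfrecord"])  # hypothetical shard list
key, value = reader.read(filename_queue)
```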
@@ -238,7 +239,10 @@ def parallel_read(data_sources, data_files = get_data_files(data_sources) with ops.name_scope(scope, 'parallel_read'): filename_queue = tf_input.string_input_producer( - data_files, num_epochs=num_epochs, shuffle=shuffle, seed=seed, + data_files, + num_epochs=num_epochs, + shuffle=shuffle, + seed=seed, name='filenames') dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string] if shuffle: @@ -252,8 +256,9 @@ def parallel_read(data_sources, common_queue = data_flow_ops.FIFOQueue( capacity=capacity, dtypes=dtypes, name='common_queue') - summary.scalar('fraction_of_%d_full' % capacity, - math_ops.to_float(common_queue.size()) * (1. / capacity)) + summary.scalar( + 'fraction_of_%d_full' % capacity, + math_ops.cast(common_queue.size(), tf_dtypes.float32) * (1. / capacity)) return ParallelReader( reader_class, diff --git a/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py b/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py index 62bd2003612..7895e809f82 100644 --- a/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py +++ b/tensorflow/contrib/slim/python/slim/data/prefetch_queue.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops @@ -26,8 +27,8 @@ from tensorflow.python.training import queue_runner def _which_queue(dynamic_pad): - return (data_flow_ops.PaddingFIFOQueue if dynamic_pad - else data_flow_ops.FIFOQueue) + return (data_flow_ops.PaddingFIFOQueue + if dynamic_pad else data_flow_ops.FIFOQueue) def prefetch_queue(tensors, @@ -43,10 +44,12 @@ def prefetch_queue(tensors, Example: This is for example useful to pre-assemble input batches read with - `tf.train.batch()` and enqueue the pre-assembled batches. Ops that dequeue + `tf.compat.v1.train.batch()` and enqueue the pre-assembled batches. Ops that + dequeue from the pre-assembled queue will not pay the cost of assembling the batch. - images, labels = tf.train.batch([image, label], batch_size=32, num_threads=4) + images, labels = tf.compat.v1.train.batch([image, label], batch_size=32, + num_threads=4) batch_queue = prefetch_queue([images, labels]) images, labels = batch_queue.dequeue() logits = Net(images) @@ -86,6 +89,7 @@ def prefetch_queue(tensors, enqueue_op = queue.enqueue(tensors) queue_runner.add_queue_runner( queue_runner.QueueRunner(queue, [enqueue_op] * num_threads)) - summary.scalar("fraction_of_%d_full" % capacity, - math_ops.to_float(queue.size()) * (1. / capacity)) + summary.scalar( + "fraction_of_%d_full" % capacity, + math_ops.cast(queue.size(), _dtypes.float32) * (1. 
/ capacity)) return queue diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index c63a3ca19b6..8fca63292e6 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -329,7 +329,7 @@ class SparseTensor(ItemHandler): shape = indices.dense_shape indices_shape = array_ops.shape(indices.indices) rank = indices_shape[1] - ids = math_ops.to_int64(indices.values) + ids = math_ops.cast(indices.values, dtypes.int64) indices_columns_to_preserve = array_ops.slice( indices.indices, [0, 0], array_ops.stack([-1, rank - 1])) new_indices = array_ops.concat( @@ -367,7 +367,7 @@ class Image(ItemHandler): dtype: images will be decoded at this bit depth. Different formats support different bit depths. See tf.image.decode_image, - tf.decode_raw, + tf.io.decode_raw, repeated: if False, decodes a single image. If True, decodes a variable number of image strings from a 1D tensor of strings. dct_method: An optional string. Defaults to empty string. It only takes @@ -464,7 +464,7 @@ class TFExampleDecoder(data_decoder.DataDecoder): Decoding Example proto buffers is comprised of two stages: (1) Example parsing and (2) tensor manipulation. - In the first stage, the tf.parse_example function is called with a list of + In the first stage, the tf.io.parse_example function is called with a list of FixedLenFeatures and SparseLenFeatures. These instances tell TF how to parse the example. The output of this stage is a set of tensors. @@ -481,7 +481,7 @@ class TFExampleDecoder(data_decoder.DataDecoder): Args: keys_to_features: a dictionary from TF-Example keys to either - tf.VarLenFeature or tf.FixedLenFeature instances. See tensorflow's + tf.io.VarLenFeature or tf.io.FixedLenFeature instances. See tensorflow's parsing_ops.py. items_to_handlers: a dictionary from items (strings) to ItemHandler instances. Note that the ItemHandler's are provided the keys that they diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py index 0feb3925eb8..b272e2543e2 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation.py +++ b/tensorflow/contrib/slim/python/slim/evaluation.py @@ -39,8 +39,8 @@ method: log_dir = '/tmp/my_model_eval/' initial_op = tf.group( - tf.global_variables_initializer(), - tf.local_variables_initializer()) + tf.compat.v1.global_variables_initializer(), + tf.compat.v1.local_variables_initializer()) metric_values = slim.evaluate_once( master='', @@ -76,7 +76,7 @@ more summaries and call the evaluation_loop method: # Define the summaries to write: for metric_name, metric_value in metrics_to_values.iteritems(): - tf.summary.scalar(metric_name, metric_value) + tf.compat.v1.summary.scalar(metric_name, metric_value) checkpoint_dir = '/tmp/my_model_dir/' log_dir = '/tmp/my_model_eval/' @@ -106,8 +106,8 @@ with only summaries. The user need only leave out the 'eval_op' argument: predictions = MyModel(images) # Define the summaries to write: - tf.summary.scalar(...) - tf.summary.histogram(...) + tf.compat.v1.summary.scalar(...) + tf.compat.v1.summary.histogram(...) checkpoint_dir = '/tmp/my_model_dir/' log_dir = '/tmp/my_model_eval/' @@ -175,14 +175,14 @@ def evaluate_once(master, value of `final_op` is returned. final_op_feed_dict: A feed dictionary to use when executing `final_op`. summary_op: The summary_op to evaluate after running TF-Slims metric ops. 
By - default the summary_op is set to tf.summary.merge_all(). + default the summary_op is set to tf.compat.v1.summary.merge_all(). summary_op_feed_dict: An optional feed dictionary to use when running the `summary_op`. variables_to_restore: A list of TensorFlow variables to restore during evaluation. If the argument is left as `None` then slim.variables.GetVariablesToRestore() is used. - session_config: An instance of `tf.ConfigProto` that will be used to - configure the `Session`. If left as `None`, the default will be used. + session_config: An instance of `tf.compat.v1.ConfigProto` that will be used + to configure the `Session`. If left as `None`, the default will be used. hooks: A list of additional `SessionRunHook` objects to pass during the evaluation. @@ -192,11 +192,16 @@ def evaluate_once(master, if summary_op == _USE_DEFAULT: summary_op = summary.merge_all() - all_hooks = [evaluation.StopAfterNEvalsHook(num_evals),] + all_hooks = [ + evaluation.StopAfterNEvalsHook(num_evals), + ] if summary_op is not None: - all_hooks.append(evaluation.SummaryAtEndHook( - log_dir=logdir, summary_op=summary_op, feed_dict=summary_op_feed_dict)) + all_hooks.append( + evaluation.SummaryAtEndHook( + log_dir=logdir, + summary_op=summary_op, + feed_dict=summary_op_feed_dict)) if hooks is not None: all_hooks.extend(hooks) @@ -254,7 +259,7 @@ def evaluation_loop(master, value of `final_op` is returned. final_op_feed_dict: A feed dictionary to use when executing `final_op`. summary_op: The summary_op to evaluate after running TF-Slims metric ops. By - default the summary_op is set to tf.summary.merge_all(). + default the summary_op is set to tf.compat.v1.summary.merge_all(). summary_op_feed_dict: An optional feed dictionary to use when running the `summary_op`. variables_to_restore: A list of TensorFlow variables to restore during @@ -263,15 +268,15 @@ def evaluation_loop(master, eval_interval_secs: The minimum number of seconds between evaluations. max_number_of_evaluations: the max number of iterations of the evaluation. If the value is left as 'None', the evaluation continues indefinitely. - session_config: An instance of `tf.ConfigProto` that will be used to - configure the `Session`. If left as `None`, the default will be used. + session_config: An instance of `tf.compat.v1.ConfigProto` that will be used + to configure the `Session`. If left as `None`, the default will be used. timeout: The maximum amount of time to wait between checkpoints. If left as `None`, then the process will wait indefinitely. timeout_fn: Optional function to call after a timeout. If the function returns True, then it means that no new checkpoints will be generated and the iterator will exit. The function is called with no arguments. - hooks: A list of additional `SessionRunHook` objects to pass during - repeated evaluations. + hooks: A list of additional `SessionRunHook` objects to pass during repeated + evaluations. Returns: The value of `final_op` or `None` if `final_op` is `None`. 
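The `evaluation.py` changes above mostly re-wrap the `evaluate_once`/`evaluation_loop` docstrings around `tf.compat.v1` names. A compact sketch of the evaluation-loop pattern those docstrings describe, assuming the contrib slim API of that era; the placeholder batch, metric, and directories are illustrative only.

```python
import tensorflow as tf
import tensorflow.contrib.slim as slim

# Placeholder batch and model; a real pipeline would read a dataset here.
images = tf.zeros([8, 28, 28, 1])
labels = tf.zeros([8], dtype=tf.int64)
logits = tf.compat.v1.layers.dense(tf.reshape(images, [8, 28 * 28]), 10)
predictions = tf.argmax(logits, axis=1)

# Streaming metric: (value_op, update_op) pairs keyed by name.
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'eval/accuracy': slim.metrics.streaming_accuracy(predictions, labels),
})

# One scalar summary per metric, as in the module docstring above.
for name, value in names_to_values.items():
  tf.compat.v1.summary.scalar(name, value)

slim.evaluation.evaluation_loop(
    master='',
    checkpoint_dir='/tmp/my_model_dir/',
    logdir='/tmp/my_model_eval/',
    num_evals=10,
    eval_op=list(names_to_updates.values()),
    eval_interval_secs=600)
```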
@@ -279,11 +284,16 @@ def evaluation_loop(master, if summary_op == _USE_DEFAULT: summary_op = summary.merge_all() - all_hooks = [evaluation.StopAfterNEvalsHook(num_evals),] + all_hooks = [ + evaluation.StopAfterNEvalsHook(num_evals), + ] if summary_op is not None: - all_hooks.append(evaluation.SummaryAtEndHook( - log_dir=logdir, summary_op=summary_op, feed_dict=summary_op_feed_dict)) + all_hooks.append( + evaluation.SummaryAtEndHook( + log_dir=logdir, + summary_op=summary_op, + feed_dict=summary_op_feed_dict)) if hooks is not None: # Add custom hooks if provided. @@ -297,8 +307,10 @@ def evaluation_loop(master, checkpoint_dir, master=master, scaffold=monitored_session.Scaffold( - init_op=initial_op, init_feed_dict=initial_op_feed_dict, - init_fn=init_fn, saver=saver), + init_op=initial_op, + init_feed_dict=initial_op_feed_dict, + init_fn=init_fn, + saver=saver), eval_ops=eval_op, feed_dict=eval_op_feed_dict, final_ops=final_op, diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 6e55b9407bc..605191b654f 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -19,7 +19,8 @@ manipulating gradients, creating a `train_op` (an operation that computes the loss and applies the gradients) and a training loop function. The training loop allows the user to pass in the `train_op` and runs the optimization according to user-specified arguments. Note that the training loop uses the -tf.train.Supervisor and its managed_session in its implementation to ensure the +tf.compat.v1.train.Supervisor and its managed_session in its implementation to +ensure the ability of worker processes to recover from failures. ************************************ @@ -35,7 +36,8 @@ ability of worker processes to recover from failures. total_loss = slim.losses.get_total_loss() # Define the optimizer: - optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) + optimizer = tf.compat.v1.train.MomentumOptimizer(FLAGS.learning_rate, + FLAGS.momentum) # Create the train_op train_op = slim.learning.create_train_op(total_loss, optimizer) @@ -104,8 +106,8 @@ default update ops or simply add additional update ops to the update_ops=my_other_update_ops) # Use an alternative set of update ops in addition to the default updates: - tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update0) - tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1) + tf.compat.v1.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update0) + tf.compat.v1.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1) train_op = slim.learning.create_train_op( total_loss, @@ -115,7 +117,7 @@ default update ops or simply add additional update ops to the train_op = slim.learning.create_train_op( total_loss, optimizer, - update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS)) + update_ops=tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS)) ****************************************** * Initializing a model from a checkpoint * @@ -393,12 +395,12 @@ def create_train_op(total_loss, update_ops: An optional list of updates to execute. If `update_ops` is `None`, then the update ops are set to the contents of the `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but - it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`, - a warning will be displayed. + it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`, a + warning will be displayed. 
variables_to_train: an optional list of variables to train. If None, it will - default to all tf.trainable_variables(). - clip_gradient_norm: If greater than 0 then the gradients would be clipped - by it. + default to all tf.compat.v1.trainable_variables(). + clip_gradient_norm: If greater than 0 then the gradients would be clipped by + it. summarize_gradients: Whether or not add summaries for each gradient. gate_gradients: How to gate the computation of gradients. See tf.Optimizer. aggregation_method: Specifies the method used to combine gradient terms. @@ -414,6 +416,7 @@ def create_train_op(total_loss, A `Tensor` that when evaluated, computes the gradients and returns the total loss value. """ + def transform_grads_fn(grads): if gradient_multipliers: with ops.name_scope('multiply_grads'): @@ -458,8 +461,8 @@ def train_step(sess, train_op, global_step, train_step_kwargs): Args: sess: The current session. - train_op: An `Operation` that evaluates the gradients and returns the - total loss. + train_op: An `Operation` that evaluates the gradients and returns the total + loss. global_step: A `Tensor` representing the global training step. train_step_kwargs: A dictionary of keyword arguments. @@ -495,9 +498,8 @@ def train_step(sess, train_op, global_step, train_step_kwargs): logging.info('Writing trace to %s', trace_filename) file_io.write_string_to_file(trace_filename, trace) if 'summary_writer' in train_step_kwargs: - train_step_kwargs['summary_writer'].add_run_metadata(run_metadata, - 'run_metadata-%d' % - np_global_step) + train_step_kwargs['summary_writer'].add_run_metadata( + run_metadata, 'run_metadata-%d' % np_global_step) if 'should_log' in train_step_kwargs: if sess.run(train_step_kwargs['should_log']): @@ -566,7 +568,8 @@ def train(train_op, checkpoints and summaries will not be written. train_step_fn: The function to call in order to execute a single gradient step. The function must have take exactly four arguments: the current - session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary. + session, the `train_op` `Tensor`, a global step `Tensor` and a + dictionary. train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By default, two `Boolean`, scalar ops called "should_stop" and "should_log" are provided. @@ -581,44 +584,45 @@ def train(train_op, then training_util.get_or_create_global_step(), that is, tf.contrib.framework.global_step() is used. number_of_steps: The max number of gradient steps to take during training, - as measured by 'global_step': training will stop if global_step is - greater than 'number_of_steps'. If the value is left as None, training - proceeds indefinitely. + as measured by 'global_step': training will stop if global_step is greater + than 'number_of_steps'. If the value is left as None, training proceeds + indefinitely. init_op: The initialization operation. If left to its default value, then - the session is initialized by calling `tf.global_variables_initializer()`. + the session is initialized by calling + `tf.compat.v1.global_variables_initializer()`. init_feed_dict: A feed dictionary to use when executing the `init_op`. local_init_op: The local initialization operation. If left to its default value, then the session is initialized by calling - `tf.local_variables_initializer()` and `tf.tables_initializer()`. + `tf.compat.v1.local_variables_initializer()` and + `tf.compat.v1.tables_initializer()`. init_fn: An optional callable to be executed after `init_op` is called. 
The callable must accept one argument, the session being initialized. ready_op: Operation to check if the model is ready to use. If left to its default value, then the session checks for readiness by calling - `tf.report_uninitialized_variables()`. + `tf.compat.v1.report_uninitialized_variables()`. summary_op: The summary operation. save_summaries_secs: How often, in seconds, to save summaries. - summary_writer: `SummaryWriter` to use. Can be `None` - to indicate that no summaries should be written. If unset, we - create a SummaryWriter. + summary_writer: `SummaryWriter` to use. Can be `None` to indicate that no + summaries should be written. If unset, we create a SummaryWriter. startup_delay_steps: The number of steps to wait for before beginning. Note that this must be 0 if a sync_optimizer is supplied. - saver: Saver to save checkpoints. If None, a default one will be created - and used. + saver: Saver to save checkpoints. If None, a default one will be created and + used. save_interval_secs: How often, in seconds, to save the model to `logdir`. - sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of - them. If the argument is supplied, gradient updates will be synchronous. - If left as `None`, gradient updates will be asynchronous. - session_config: An instance of `tf.ConfigProto` that will be used to - configure the `Session`. If left as `None`, the default will be used. - session_wrapper: A function that takes a `tf.Session` object as the only - argument and returns a wrapped session object that has the same methods - that the original object has, or `None`. Iff not `None`, the wrapped - object will be used for training. + sync_optimizer: an instance of tf.compat.v1.train.SyncReplicasOptimizer, or + a list of them. If the argument is supplied, gradient updates will be + synchronous. If left as `None`, gradient updates will be asynchronous. + session_config: An instance of `tf.compat.v1.ConfigProto` that will be used + to configure the `Session`. If left as `None`, the default will be used. + session_wrapper: A function that takes a `tf.compat.v1.Session` object as + the only argument and returns a wrapped session object that has the same + methods that the original object has, or `None`. Iff not `None`, the + wrapped object will be used for training. trace_every_n_steps: produce and save a `Timeline` in Chrome trace format and add it to the summaries every `trace_every_n_steps`. If None, no trace information will be produced or saved. - ignore_live_threads: If `True` ignores threads that remain running after - a grace period when stopping the supervisor, instead of raising a + ignore_live_threads: If `True` ignores threads that remain running after a + grace period when stopping the supervisor, instead of raising a RuntimeError. 
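The `learning.py` docstring edits above are again largely `tf.compat.v1` renames around `create_train_op` and `train`. A minimal end-to-end sketch of that training path, assuming contrib slim is available; the toy model, loss, and logdir are placeholders, and `tf.compat.v1.losses` stands in for the `slim.losses` aliases used in the module docstring.

```python
import tensorflow as tf
import tensorflow.contrib.slim as slim

# Toy regression problem; real code would build a proper input pipeline.
inputs = tf.random.uniform([32, 4])
targets = tf.reduce_sum(inputs, axis=1, keepdims=True)
predictions = slim.fully_connected(inputs, 1, activation_fn=None)

# Register the loss, then build the train_op as described above.
tf.compat.v1.losses.mean_squared_error(labels=targets, predictions=predictions)
total_loss = tf.compat.v1.losses.get_total_loss()
optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
train_op = slim.learning.create_train_op(total_loss, optimizer)

# Runs the managed training loop, checkpointing and summarizing into logdir.
slim.learning.train(train_op, logdir='/tmp/my_model_dir/', number_of_steps=1000)
```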
Returns: @@ -677,8 +681,8 @@ def train(train_op, lookup_ops.tables_initializer()) if sync_optimizer is not None and isinstance(sync_optimizer, list): - with ops.control_dependencies([local_init_op] if local_init_op is - not None else []): + with ops.control_dependencies( + [local_init_op] if local_init_op is not None else []): if is_chief: local_init_op = control_flow_ops.group( *[opt.chief_init_op for opt in sync_optimizer]) @@ -700,7 +704,8 @@ def train(train_op, # Need to create these BEFORE the supervisor finalizes the graph: init_tokens_op = [opt.get_init_tokens_op() for opt in sync_optimizer] chief_queue_runner = [ - opt.get_chief_queue_runner() for opt in sync_optimizer] + opt.get_chief_queue_runner() for opt in sync_optimizer + ] if train_step_kwargs == _USE_DEFAULT: with ops.name_scope('train_step'): @@ -748,17 +753,17 @@ def train(train_op, master, start_standard_services=False, config=session_config) as sess: logging.info('Starting Session.') if session_wrapper is not None: - logging.info( - 'Wrapping session with wrapper function: %s', session_wrapper) + logging.info('Wrapping session with wrapper function: %s', + session_wrapper) sess = session_wrapper(sess) if is_chief: if logdir: sv.start_standard_services(sess) elif startup_delay_steps > 0: - # (use sys.maxsize because sys.maxint doesn't exist in Python 3) - _wait_for_step(sess, global_step, - min(startup_delay_steps, number_of_steps or - sys.maxsize)) + # (use sys.maxsize because sys.maxint doesn't exist in Python 3) + _wait_for_step( + sess, global_step, + min(startup_delay_steps, number_of_steps or sys.maxsize)) threads = sv.start_queue_runners(sess) logging.info('Starting Queues.') if is_chief and sync_optimizer is not None: @@ -766,15 +771,15 @@ def train(train_op, sess.run(init_tokens_op) try: while not sv.should_stop(): - total_loss, should_stop = train_step_fn( - sess, train_op, global_step, train_step_kwargs) + total_loss, should_stop = train_step_fn(sess, train_op, global_step, + train_step_kwargs) if should_stop: logging.info('Stopping Training.') sv.request_stop() break except errors.OutOfRangeError as e: # OutOfRangeError is thrown when epoch limit per - # tf.train.limit_epochs is reached. + # tf.compat.v1.train.limit_epochs is reached. logging.info('Caught OutOfRangeError. Stopping Training. %s', e) if logdir and sv.is_chief: logging.info('Finished training! Saving model to disk.') diff --git a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py index d92a7fbb472..5db4fe02b8e 100644 --- a/tensorflow/contrib/slim/python/slim/learning_test.py +++ b/tensorflow/contrib/slim/python/slim/learning_test.py @@ -67,8 +67,8 @@ class ClipGradientNormsTest(test.TestCase): gradient = constant_op.constant(self._grad_vec, dtype=dtypes.float32) variable = variables_lib.Variable(self._zero_vec, dtype=dtypes.float32) gradients_to_variables = (gradient, variable) - [gradients_to_variables] = learning.clip_gradient_norms( - [gradients_to_variables], self._max_norm) + [gradients_to_variables + ] = learning.clip_gradient_norms([gradients_to_variables], self._max_norm) # Ensure the variable passed through. 
self.assertEqual(gradients_to_variables[1], variable) @@ -82,8 +82,8 @@ class ClipGradientNormsTest(test.TestCase): variable = variables_lib.Variable(self._zero_vec, dtype=dtypes.float32) gradients_to_variables = (gradient, variable) - [gradients_to_variables] = learning.clip_gradient_norms( - [gradients_to_variables], self._max_norm) + [gradients_to_variables + ] = learning.clip_gradient_norms([gradients_to_variables], self._max_norm) self.assertEqual(gradients_to_variables[0], None) self.assertEqual(gradients_to_variables[1], variable) @@ -172,8 +172,8 @@ class MultiplyGradientsTest(test.TestCase): def testIndexedSlicesGradIsMultiplied(self): values = constant_op.constant(self._grad_vec, dtype=dtypes.float32) indices = constant_op.constant([0, 1, 2], dtype=dtypes.int32) - dense_shape = constant_op.constant( - [self._grad_vec.size], dtype=dtypes.int32) + dense_shape = constant_op.constant([self._grad_vec.size], + dtype=dtypes.int32) gradient = ops.IndexedSlices(values, indices, dense_shape) variable = variables_lib.Variable(array_ops.zeros((1, 3))) @@ -289,8 +289,8 @@ class CreateTrainOpTest(test.TestCase): train_op = learning.create_train_op(total_loss, optimizer) moving_mean = variables_lib2.get_variables_by_name('moving_mean')[0] - moving_variance = variables_lib2.get_variables_by_name('moving_variance')[ - 0] + moving_variance = variables_lib2.get_variables_by_name( + 'moving_variance')[0] with session.Session() as sess: # Initialize all variables @@ -323,8 +323,8 @@ class CreateTrainOpTest(test.TestCase): train_op = learning.create_train_op(total_loss, optimizer, update_ops=[]) moving_mean = variables_lib2.get_variables_by_name('moving_mean')[0] - moving_variance = variables_lib2.get_variables_by_name('moving_variance')[ - 0] + moving_variance = variables_lib2.get_variables_by_name( + 'moving_variance')[0] with session.Session() as sess: # Initialize all variables @@ -492,7 +492,8 @@ class TrainTest(test.TestCase): """Test that slim.learning.train can take `session_wrapper` args. One of the applications of `session_wrapper` is the wrappers of TensorFlow - Debugger (tfdbg), which intercept methods calls to `tf.Session` (e.g., run) + Debugger (tfdbg), which intercept methods calls to `tf.compat.v1.Session` + (e.g., run) to achieve debugging. `DumpingDebugWrapperSession` is used here for testing purpose. """ diff --git a/tensorflow/contrib/slim/python/slim/model_analyzer.py b/tensorflow/contrib/slim/python/slim/model_analyzer.py index 74617928a71..aad968997ea 100644 --- a/tensorflow/contrib/slim/python/slim/model_analyzer.py +++ b/tensorflow/contrib/slim/python/slim/model_analyzer.py @@ -19,11 +19,12 @@ To analyze the operations in a graph: images, labels = LoadData(...) predictions = MyModel(images) - slim.model_analyzer.analyze_ops(tf.get_default_graph(), print_info=True) + slim.model_analyzer.analyze_ops(tf.compat.v1.get_default_graph(), + print_info=True) To analyze the model variables in a graph: - variables = tf.model_variables() + variables = tf.compat.v1.model_variables() slim.model_analyzer.analyze_vars(variables, print_info=False) """ from __future__ import absolute_import @@ -84,7 +85,7 @@ def analyze_vars(variables, print_info=False): """Prints the names and shapes of the variables. Args: - variables: list of variables, for example tf.global_variables(). + variables: list of variables, for example tf.compat.v1.global_variables(). print_info: Optional, if true print variables and their shape. 
Returns: @@ -103,8 +104,8 @@ def analyze_vars(variables, print_info=False): total_size += var_size total_bytes += var_bytes if print_info: - print(var.name, tensor_description(var), '[%d, bytes: %d]' % - (var_size, var_bytes)) + print(var.name, tensor_description(var), + '[%d, bytes: %d]' % (var_size, var_bytes)) if print_info: print('Total size of variables: %d' % total_size) print('Total bytes of variables: %d' % total_bytes) diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD index e9595d1b324..f19177b1881 100644 --- a/tensorflow/contrib/slim/python/slim/nets/BUILD +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -45,6 +45,7 @@ py_test( name = "alexnet_test", size = "medium", srcs = ["alexnet_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":alexnet", @@ -117,6 +118,7 @@ py_test( name = "inception_v1_test", size = "medium", srcs = ["inception_v1_test.py"], + python_version = "PY2", shard_count = 8, srcs_version = "PY2AND3", deps = [ @@ -137,6 +139,7 @@ py_test( name = "inception_v2_test", size = "medium", srcs = ["inception_v2_test.py"], + python_version = "PY2", shard_count = 8, srcs_version = "PY2AND3", deps = [ @@ -157,6 +160,7 @@ py_test( name = "inception_v3_test", size = "medium", srcs = ["inception_v3_test.py"], + python_version = "PY2", shard_count = 8, srcs_version = "PY2AND3", deps = [ @@ -191,6 +195,7 @@ py_test( name = "overfeat_test", size = "medium", srcs = ["overfeat_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":overfeat", @@ -235,6 +240,7 @@ py_test( name = "resnet_v1_test", size = "medium", srcs = ["resnet_v1_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", deps = [ @@ -271,6 +277,7 @@ py_test( name = "resnet_v2_test", size = "medium", srcs = ["resnet_v2_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", deps = [ @@ -307,6 +314,7 @@ py_test( name = "vgg_test", size = "medium", srcs = ["vgg_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":vgg", diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py index 8ff44fe4b5f..1cc54b15514 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py @@ -54,18 +54,19 @@ def create_test_input(batch_size, height, width, channels): return array_ops.placeholder(dtypes.float32, (batch_size, height, width, channels)) else: - return math_ops.to_float( + return math_ops.cast( np.tile( np.reshape( np.reshape(np.arange(height), [height, 1]) + np.reshape( np.arange(width), [1, width]), [1, height, width, 1]), - [batch_size, 1, 1, channels])) + [batch_size, 1, 1, channels]), dtypes.float32) class ResnetUtilsTest(test.TestCase): def testSubsampleThreeByThree(self): - x = array_ops.reshape(math_ops.to_float(math_ops.range(9)), [1, 3, 3, 1]) + x = array_ops.reshape(math_ops.cast(math_ops.range(9), dtypes.float32), + [1, 3, 3, 1]) x = resnet_utils.subsample(x, 2) expected = array_ops.reshape( constant_op.constant([0, 2, 6, 8]), [1, 2, 2, 1]) @@ -73,7 +74,8 @@ class ResnetUtilsTest(test.TestCase): self.assertAllClose(x.eval(), expected.eval()) def testSubsampleFourByFour(self): - x = array_ops.reshape(math_ops.to_float(math_ops.range(16)), [1, 4, 4, 1]) + x = array_ops.reshape(math_ops.cast(math_ops.range(16), dtypes.float32), + [1, 4, 4, 1]) x = resnet_utils.subsample(x, 2) expected = 
array_ops.reshape( constant_op.constant([0, 2, 8, 10]), [1, 2, 2, 1]) @@ -95,19 +97,20 @@ class ResnetUtilsTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() y1 = layers.conv2d(x, 1, [3, 3], stride=1, scope='Conv') - y1_expected = math_ops.to_float([[14, 28, 43, 26], [28, 48, 66, 37], - [43, 66, 84, 46], [26, 37, 46, 22]]) + y1_expected = math_ops.cast([[14, 28, 43, 26], [28, 48, 66, 37], + [43, 66, 84, 46], [26, 37, 46, 22]], + dtypes.float32) y1_expected = array_ops.reshape(y1_expected, [1, n, n, 1]) y2 = resnet_utils.subsample(y1, 2) - y2_expected = math_ops.to_float([[14, 43], [43, 84]]) + y2_expected = math_ops.cast([[14, 43], [43, 84]], dtypes.float32) y2_expected = array_ops.reshape(y2_expected, [1, n2, n2, 1]) y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv') y3_expected = y2_expected y4 = layers.conv2d(x, 1, [3, 3], stride=2, scope='Conv') - y4_expected = math_ops.to_float([[48, 37], [37, 22]]) + y4_expected = math_ops.cast([[48, 37], [37, 22]], dtypes.float32) y4_expected = array_ops.reshape(y4_expected, [1, n2, n2, 1]) with self.cached_session() as sess: @@ -132,14 +135,19 @@ class ResnetUtilsTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() y1 = layers.conv2d(x, 1, [3, 3], stride=1, scope='Conv') - y1_expected = math_ops.to_float([[14, 28, 43, 58, 34], [28, 48, 66, 84, 46], - [43, 66, 84, 102, 55], - [58, 84, 102, 120, 64], - [34, 46, 55, 64, 30]]) + y1_expected = math_ops.cast([[14, 28, 43, 58, 34], + [28, 48, 66, 84, 46], + [43, 66, 84, 102, 55], + [58, 84, 102, 120, 64], + [34, 46, 55, 64, 30]], + dtypes.float32) y1_expected = array_ops.reshape(y1_expected, [1, n, n, 1]) y2 = resnet_utils.subsample(y1, 2) - y2_expected = math_ops.to_float([[14, 43, 34], [43, 84, 55], [34, 55, 30]]) + y2_expected = math_ops.cast([[14, 43, 34], + [43, 84, 55], + [34, 55, 30]], + dtypes.float32) y2_expected = array_ops.reshape(y2_expected, [1, n2, n2, 1]) y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv') diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py index 055ecff1c32..31bdea9fbcd 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py @@ -54,18 +54,20 @@ def create_test_input(batch_size, height, width, channels): return array_ops.placeholder(dtypes.float32, (batch_size, height, width, channels)) else: - return math_ops.to_float( + return math_ops.cast( np.tile( np.reshape( np.reshape(np.arange(height), [height, 1]) + np.reshape( np.arange(width), [1, width]), [1, height, width, 1]), - [batch_size, 1, 1, channels])) + [batch_size, 1, 1, channels]), + dtypes.float32) class ResnetUtilsTest(test.TestCase): def testSubsampleThreeByThree(self): - x = array_ops.reshape(math_ops.to_float(math_ops.range(9)), [1, 3, 3, 1]) + x = array_ops.reshape(math_ops.cast(math_ops.range(9), dtypes.float32), + [1, 3, 3, 1]) x = resnet_utils.subsample(x, 2) expected = array_ops.reshape( constant_op.constant([0, 2, 6, 8]), [1, 2, 2, 1]) @@ -73,7 +75,8 @@ class ResnetUtilsTest(test.TestCase): self.assertAllClose(x.eval(), expected.eval()) def testSubsampleFourByFour(self): - x = array_ops.reshape(math_ops.to_float(math_ops.range(16)), [1, 4, 4, 1]) + x = array_ops.reshape(math_ops.cast(math_ops.range(16), dtypes.float32), + [1, 4, 4, 1]) x = resnet_utils.subsample(x, 2) expected = array_ops.reshape( constant_op.constant([0, 2, 8, 10]), [1, 2, 2, 1]) @@ -95,19 +98,22 @@ class 
ResnetUtilsTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() y1 = layers.conv2d(x, 1, [3, 3], stride=1, scope='Conv') - y1_expected = math_ops.to_float([[14, 28, 43, 26], [28, 48, 66, 37], - [43, 66, 84, 46], [26, 37, 46, 22]]) + y1_expected = math_ops.cast([[14, 28, 43, 26], + [28, 48, 66, 37], + [43, 66, 84, 46], + [26, 37, 46, 22]], + dtypes.float32) y1_expected = array_ops.reshape(y1_expected, [1, n, n, 1]) y2 = resnet_utils.subsample(y1, 2) - y2_expected = math_ops.to_float([[14, 43], [43, 84]]) + y2_expected = math_ops.cast([[14, 43], [43, 84]], dtypes.float32) y2_expected = array_ops.reshape(y2_expected, [1, n2, n2, 1]) y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv') y3_expected = y2_expected y4 = layers.conv2d(x, 1, [3, 3], stride=2, scope='Conv') - y4_expected = math_ops.to_float([[48, 37], [37, 22]]) + y4_expected = math_ops.cast([[48, 37], [37, 22]], dtypes.float32) y4_expected = array_ops.reshape(y4_expected, [1, n2, n2, 1]) with self.cached_session() as sess: @@ -132,17 +138,19 @@ class ResnetUtilsTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() y1 = layers.conv2d(x, 1, [3, 3], stride=1, scope='Conv') - y1_expected = math_ops.to_float([[14, 28, 43, 58, 34], - [28, 48, 66, 84, 46], - [43, 66, 84, 102, 55], - [58, 84, 102, 120, 64], - [34, 46, 55, 64, 30]]) + y1_expected = math_ops.cast([[14, 28, 43, 58, 34], + [28, 48, 66, 84, 46], + [43, 66, 84, 102, 55], + [58, 84, 102, 120, 64], + [34, 46, 55, 64, 30]], + dtypes.float32) y1_expected = array_ops.reshape(y1_expected, [1, n, n, 1]) y2 = resnet_utils.subsample(y1, 2) - y2_expected = math_ops.to_float([[14, 43, 34], - [43, 84, 55], - [34, 55, 30]]) + y2_expected = math_ops.cast([[14, 43, 34], + [43, 84, 55], + [34, 55, 30]], + dtypes.float32) y2_expected = array_ops.reshape(y2_expected, [1, n2, n2, 1]) y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv') diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index 7dd52df6b68..4085801342b 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -9,6 +9,7 @@ load("//tensorflow:tensorflow.bzl", "py_test") py_test( name = "summary_ops_test", srcs = ["summary_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":summary", @@ -29,6 +30,7 @@ py_test( py_test( name = "summary_ops_graph_test", srcs = ["summary_ops_graph_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":summary", diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 42898e797cc..e0159a8e8e4 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -64,7 +64,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import -from tensorflow.python.ops.summary_ops_v2 import all_summary_ops +from tensorflow.python.ops.summary_ops_v2 import all_v2_summary_ops as all_summary_ops from tensorflow.python.ops.summary_ops_v2 import always_record_summaries from tensorflow.python.ops.summary_ops_v2 import audio from tensorflow.python.ops.summary_ops_v2 import create_db_writer diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py index 8e13f7f56b2..6606a4c33c5 100644 --- a/tensorflow/contrib/summary/summary_ops_graph_test.py +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -108,7 +108,7 @@ class GraphFileTest(test_util.TensorFlowTestCase): with 
self.cached_session() as sess: sess.run(summary_ops.summary_writer_initializer_op()) get_total = lambda: len(summary_test_util.events_from_logdir(logdir)) - # Note: First tf.Event is always file_version. + # Note: First tf.compat.v1.Event is always file_version. self.assertEqual(1, get_total()) sess.run(summary_ops.all_summary_ops()) self.assertEqual(1, get_total()) @@ -126,7 +126,7 @@ class GraphFileTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: sess.run(summary_ops.summary_writer_initializer_op()) get_total = lambda: len(summary_test_util.events_from_logdir(logdir)) - # Note: First tf.Event is always file_version. + # Note: First tf.compat.v1.Event is always file_version. self.assertEqual(1, get_total()) sess.run(summary_ops.all_summary_ops()) self.assertEqual(1, get_total()) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 27bfdeb3601..6cdcea1801a 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -179,7 +179,7 @@ class EagerFileTest(test_util.TensorFlowTestCase): logs, max_queue=1, flush_millis=999999, name='lol').as_default(), summary_ops.always_record_summaries(): get_total = lambda: len(summary_test_util.events_from_logdir(logs)) - # Note: First tf.Event is always file_version. + # Note: First tf.compat.v1.Event is always file_version. self.assertEqual(1, get_total()) summary_ops.scalar('scalar', 2.0, step=1) self.assertEqual(1, get_total()) @@ -193,7 +193,7 @@ class EagerFileTest(test_util.TensorFlowTestCase): logs, max_queue=999999, flush_millis=999999, name='lol') with writer.as_default(), summary_ops.always_record_summaries(): get_total = lambda: len(summary_test_util.events_from_logdir(logs)) - # Note: First tf.Event is always file_version. + # Note: First tf.compat.v1.Event is always file_version. self.assertEqual(1, get_total()) summary_ops.scalar('scalar', 2.0, step=1) summary_ops.scalar('scalar', 2.0, step=2) diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py index b4ae43302cb..f15b7aebcc9 100644 --- a/tensorflow/contrib/summary/summary_test_util.py +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -59,7 +59,7 @@ def events_from_file(filepath): filepath: Path to the event file. Returns: - A list of all tf.Event protos in the event file. + A list of all tf.compat.v1.Event protos in the event file. """ records = list(tf_record.tf_record_iterator(filepath)) result = [] @@ -77,7 +77,7 @@ def events_from_logdir(logdir): logdir: The directory in which the single event file is sought. Returns: - A list of all tf.Event protos from the single event file. + A list of all tf.compat.v1.Event protos from the single event file. Raises: AssertionError: If logdir does not contain exactly one file. 
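The `events_from_file`/`events_from_logdir` helpers above read back `tf.compat.v1.Event` protos to count what was written. A rough standalone equivalent, using the public `summary_iterator` instead of the internal tf_record wrapper (the function name and the globbing here are illustrative, not the library's API):

```python
import glob
import os

import tensorflow as tf


def events_from_logdir(logdir):
  """Returns all tf.compat.v1.Event protos from the single file in `logdir`."""
  files = glob.glob(os.path.join(logdir, "*"))
  assert len(files) == 1, "expected exactly one event file, got %r" % files
  # The first Event is always the file_version record, which is why the tests
  # above treat a count of 1 as "nothing written yet".
  return list(tf.compat.v1.train.summary_iterator(files[0]))
```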
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 583bbf97c57..a7f8819915b 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -436,6 +436,7 @@ py_test( name = "eval_metrics_test", size = "small", srcs = ["client/eval_metrics_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":eval_metrics", @@ -461,6 +462,7 @@ py_test( name = "scatter_add_ndim_op_test", size = "small", srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_gpu", @@ -502,6 +504,7 @@ py_test( name = "tensor_forest_test", size = "small", srcs = ["python/tensor_forest_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":tensor_forest_py", @@ -539,6 +542,7 @@ py_test( name = "random_forest_test", size = "medium", srcs = ["client/random_forest_test.py"], + python_version = "PY2", shard_count = 6, srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index 0d87cea9fba..0b4125f00f9 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib import losses from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -35,7 +36,7 @@ FEATURE_IMPORTANCE_NAME = 'global_feature_importance' def _top_k_generator(k): def _top_k(probabilities, targets): - targets = math_ops.to_int32(targets) + targets = math_ops.cast(targets, dtypes.int32) if targets.get_shape().ndims > 1: targets = array_ops.squeeze(targets, axis=[1]) return metrics.mean(nn.in_top_k(probabilities, targets, k)) @@ -48,7 +49,7 @@ def _accuracy(predictions, targets, weights=None): def _r2(probabilities, targets, weights=None): - targets = math_ops.to_float(targets) + targets = math_ops.cast(targets, dtypes.float32) y_mean = math_ops.reduce_mean(targets, 0) squares_total = math_ops.reduce_sum( math_ops.squared_difference(targets, y_mean), 0) @@ -60,7 +61,7 @@ def _r2(probabilities, targets, weights=None): def _squeeze_and_onehot(targets, depth): targets = array_ops.squeeze(targets, axis=[1]) - return array_ops.one_hot(math_ops.to_int32(targets), depth) + return array_ops.one_hot(math_ops.cast(targets, dtypes.int32), depth) def _sigmoid_entropy(probabilities, targets, weights=None): @@ -75,7 +76,7 @@ def _sigmoid_entropy(probabilities, targets, weights=None): def _softmax_entropy(probabilities, targets, weights=None): return metrics.mean( losses.sparse_softmax_cross_entropy(probabilities, - math_ops.to_int32(targets)), + math_ops.cast(targets, dtypes.int32)), weights=weights) diff --git a/tensorflow/contrib/tensor_forest/hybrid/BUILD b/tensorflow/contrib/tensor_forest/hybrid/BUILD index b7185e09c70..64176a0dd07 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/BUILD +++ b/tensorflow/contrib/tensor_forest/hybrid/BUILD @@ -122,6 +122,7 @@ py_test( name = "hybrid_layer_test", size = "small", srcs = ["python/hybrid_layer_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":fully_connected_layer", @@ -163,6 +164,7 @@ py_test( name = "routing_function_op_test", size = "small", srcs = 
["python/kernel_tests/routing_function_op_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["manual"], deps = [ @@ -177,6 +179,7 @@ py_test( name = "k_feature_routing_function_op_test", size = "small", srcs = ["python/kernel_tests/k_feature_routing_function_op_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["manual"], deps = [ @@ -206,6 +209,7 @@ py_library( py_test( name = "decisions_to_data_test", srcs = ["python/layers/decisions_to_data_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":decisions_to_data_layer", @@ -247,6 +251,7 @@ py_test( name = "decisions_to_data_then_nn_test", size = "small", srcs = ["python/models/decisions_to_data_then_nn_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":decisions_to_data_then_nn", @@ -274,6 +279,7 @@ py_test( name = "k_feature_decisions_to_data_then_nn_test", size = "small", srcs = ["python/models/k_feature_decisions_to_data_then_nn_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":k_feature_decisions_to_data_then_nn", @@ -301,6 +307,7 @@ py_test( name = "forest_to_data_then_nn_test", size = "small", srcs = ["python/models/forest_to_data_then_nn_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":forest_to_data_then_nn", diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc index 34388fe1aab..c7c6a85bf6a 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc @@ -15,6 +15,7 @@ #include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h" #include +#include #include #include "tensorflow/core/lib/random/philox_random.h" @@ -36,7 +37,7 @@ float LeftProbability(const Tensor& point, const Tensor& weight, float bias, // TODO(thomaswc): At some point we should consider // //learning/logistic/logodds-to-prob.h - return 1.0 / (1.0 + exp(-dot_product + bias)); + return 1.0 / (1.0 + std::exp(-dot_product + bias)); } float LeftProbabilityK(const Tensor& point, std::vector feature_set, @@ -54,7 +55,7 @@ float LeftProbabilityK(const Tensor& point, std::vector feature_set, // TODO(thomaswc): At some point we should consider // //learning/logistic/logodds-to-prob.h - return 1.0 / (1.0 + exp(-dot_product + bias)); + return 1.0 / (1.0 + std::exp(-dot_product + bias)); } void GetFeatureSet(int32 tree_num, int32 node_num, int32 random_seed, diff --git a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py index a427a02b7cd..926e4dda916 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py +++ b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py @@ -22,6 +22,7 @@ import collections from tensorflow.contrib import layers from tensorflow.contrib.framework.python.ops import variables as framework_variables +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -110,14 +111,15 @@ class HybridModel(object): """The loss to minimize while training.""" if self.is_regression: - diff = self.training_inference_graph(data) - math_ops.to_float(labels) + diff = self.training_inference_graph(data) - math_ops.cast( + labels, dtypes.float32) mean_squared_error = math_ops.reduce_mean(diff * diff) root_mean_squared_error = math_ops.sqrt(mean_squared_error, 
name="loss") loss = root_mean_squared_error else: loss = math_ops.reduce_mean( nn_ops.sparse_softmax_cross_entropy_with_logits( - labels=array_ops.squeeze(math_ops.to_int32(labels)), + labels=array_ops.squeeze(math_ops.cast(labels, dtypes.int32)), logits=self.training_inference_graph(data)), name="loss") if self.regularizer: diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc index 63d4d9ba506..e8b494fc166 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc @@ -15,6 +15,7 @@ #include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h" #include +#include #include #include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" #include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" @@ -272,7 +273,8 @@ void ClassificationStats::CheckPruneHoeffding() { // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted. const float num_classes = params_.num_outputs(); const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes); - float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_); + float epsilon = + gini_diff_range * std::sqrt(half_ln_dominate_frac_ / weight_sum_); for (int i = num_splits() - 1; i >= 0; i--) { if (split_scores[i] - best_split_score > epsilon) { RemoveSplit(i); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params.cc b/tensorflow/contrib/tensor_forest/kernels/v4/params.cc index a3b09c17d51..6387e7c2bc7 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/params.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/params.cc @@ -15,6 +15,7 @@ #include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" #include #include +#include #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -35,8 +36,8 @@ float ResolveParam(const DepthDependentParam& param, int32 depth) { return param.exponential().bias() + param.exponential().multiplier() * static_cast( - pow(param.exponential().base(), - param.exponential().depth_multiplier() * depth)); + std::pow(param.exponential().base(), + param.exponential().depth_multiplier() * depth)); case DepthDependentParam::kThreshold: if (depth >= param.threshold().threshold()) { diff --git a/tensorflow/contrib/tensor_forest/python/ops/data_ops.py b/tensorflow/contrib/tensor_forest/python/ops/data_ops.py index f878e5989cf..5c1fe23981d 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/data_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/data_ops.py @@ -44,7 +44,7 @@ def CastToFloat(tensor): if tensor.dtype == dtypes.string: return tensor_forest_ops.reinterpret_string_to_float(tensor) elif tensor.dtype.is_integer: - return math_ops.to_float(tensor) + return math_ops.cast(tensor, dtypes.float32) else: return tensor @@ -195,7 +195,7 @@ def ParseLabelTensorOrDict(labels): A 2-D tensor for labels/outputs. 
""" if isinstance(labels, dict): - return math_ops.to_float( + return math_ops.cast( array_ops.concat( [ sparse_ops.sparse_tensor_to_dense( @@ -203,10 +203,12 @@ def ParseLabelTensorOrDict(labels): labels, sparse_tensor.SparseTensor) else labels[k] for k in sorted(labels.keys()) ], - 1)) + 1), + dtypes.float32) else: if isinstance(labels, sparse_tensor.SparseTensor): - return math_ops.to_float(sparse_ops.sparse_tensor_to_dense( - labels, default_value=-1)) + return math_ops.cast( + sparse_ops.sparse_tensor_to_dense(labels, default_value=-1), + dtypes.float32) else: - return math_ops.to_float(labels) + return math_ops.cast(labels, dtypes.float32) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 6f62cd11a97..df10997d633 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -30,6 +30,7 @@ from tensorflow.contrib.tensor_forest.python.ops import data_ops from tensorflow.contrib.tensor_forest.python.ops import model_ops from tensorflow.contrib.tensor_forest.python.ops import stats_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -39,21 +40,18 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging - # Stores tuples of (leaf model type, stats model type) CLASSIFICATION_LEAF_MODEL_TYPES = { 'all_dense': (_params_proto.MODEL_DENSE_CLASSIFICATION, _params_proto.STATS_DENSE_GINI), 'all_sparse': (_params_proto.MODEL_SPARSE_CLASSIFICATION, _params_proto.STATS_SPARSE_GINI), - 'sparse_then_dense': - (_params_proto.MODEL_SPARSE_OR_DENSE_CLASSIFICATION, - _params_proto.STATS_SPARSE_THEN_DENSE_GINI), + 'sparse_then_dense': (_params_proto.MODEL_SPARSE_OR_DENSE_CLASSIFICATION, + _params_proto.STATS_SPARSE_THEN_DENSE_GINI), } -REGRESSION_MODEL_TYPE = ( - _params_proto.MODEL_REGRESSION, - _params_proto.STATS_LEAST_SQUARES_REGRESSION, - _params_proto.COLLECTION_BASIC) +REGRESSION_MODEL_TYPE = (_params_proto.MODEL_REGRESSION, + _params_proto.STATS_LEAST_SQUARES_REGRESSION, + _params_proto.COLLECTION_BASIC) FINISH_TYPES = { 'basic': _params_proto.SPLIT_FINISH_BASIC, @@ -205,9 +203,10 @@ class ForestHParams(object): self.bagged_features = None if self.feature_bagging_fraction < 1.0: - self.bagged_features = [random.sample( - range(self.num_features), - self.bagged_num_features) for _ in range(self.num_trees)] + self.bagged_features = [ + random.sample(range(self.num_features), self.bagged_num_features) + for _ in range(self.num_trees) + ] self.regression = getattr(self, 'regression', False) @@ -240,8 +239,8 @@ class ForestHParams(object): CLASSIFICATION_LEAF_MODEL_TYPES[self.model_name][1]) self.finish_type = ( - _params_proto.SPLIT_FINISH_BASIC if self.regression else - FINISH_TYPES[self.split_finish_name]) + _params_proto.SPLIT_FINISH_BASIC + if self.regression else FINISH_TYPES[self.split_finish_name]) self.pruning_type = PRUNING_TYPES[self.split_pruning_name] @@ -258,8 +257,8 @@ class ForestHParams(object): # default, making it easy to select the number being pruned with # pruning_type while not paying the cost of pruning too often. Note that # this only holds if not using a depth-dependent split_after_samples. 
- self.prune_every_samples = (self.prune_every_samples or - int(self.split_after_samples) / 2) + self.prune_every_samples = ( + self.prune_every_samples or int(self.split_after_samples) / 2) if self.finish_type == _params_proto.SPLIT_FINISH_BASIC: self.early_finish_check_every_samples = 0 @@ -298,15 +297,15 @@ def get_epoch_variable(): class TreeVariables(object): """Stores tf.Variables for training a single random tree. - Uses tf.get_variable to get tree-specific names so that this can be used + Uses tf.compat.v1.get_variable to get tree-specific names so that this can be + used with a tf.learn-style implementation (one that trains a model, saves it, then relies on restoring that model to evaluate). """ def __init__(self, params, tree_num, training, tree_config='', tree_stat=''): if (not hasattr(params, 'params_proto') or - not isinstance(params.params_proto, - _params_proto.TensorForestParams)): + not isinstance(params.params_proto, _params_proto.TensorForestParams)): params.params_proto = build_params_proto(params) params.serialized_params_proto = params.params_proto.SerializeToString() @@ -316,8 +315,8 @@ class TreeVariables(object): # multiple machines. self.stats = stats_ops.fertile_stats_variable( params, tree_stat, self.get_tree_name('stats', tree_num)) - self.tree = model_ops.tree_variable( - params, tree_config, self.stats, self.get_tree_name('tree', tree_num)) + self.tree = model_ops.tree_variable(params, tree_config, self.stats, + self.get_tree_name('tree', tree_num)) def get_tree_name(self, name, num): return '{0}-{1}'.format(name, num) @@ -334,17 +333,21 @@ class ForestVariables(object): ... forest_variables.tree ... """ - def __init__(self, params, device_assigner, training=True, + def __init__(self, + params, + device_assigner, + training=True, tree_variables_class=TreeVariables, - tree_configs=None, tree_stats=None): + tree_configs=None, + tree_stats=None): self.variables = [] # Set up some scalar variables to run through the device assigner, then # we can use those to colocate everything related to a tree. 
self.device_dummies = [] with ops.device(device_assigner): for i in range(params.num_trees): - self.device_dummies.append(variable_scope.get_variable( - name='device_dummy_%d' % i, shape=0)) + self.device_dummies.append( + variable_scope.get_variable(name='device_dummy_%d' % i, shape=0)) for i in range(params.num_trees): with ops.device(self.device_dummies[i].device): @@ -353,8 +356,8 @@ class ForestVariables(object): kwargs.update(dict(tree_config=tree_configs[i])) if tree_stats is not None: kwargs.update(dict(tree_stat=tree_stats[i])) - self.variables.append(tree_variables_class( - params, i, training, **kwargs)) + self.variables.append( + tree_variables_class(params, i, training, **kwargs)) def __setitem__(self, t, val): self.variables[t] = val @@ -381,9 +384,12 @@ class RandomForestGraphs(object): logging.info('Constructing forest with params = ') logging.info(self.params.__dict__) self.variables = variables or ForestVariables( - self.params, device_assigner=self.device_assigner, training=training, + self.params, + device_assigner=self.device_assigner, + training=training, tree_variables_class=tree_variables_class, - tree_configs=tree_configs, tree_stats=tree_stats) + tree_configs=tree_configs, + tree_stats=tree_stats) tree_graph_class = tree_graphs or RandomTreeGraphs self.trees = [ tree_graph_class(self.variables[i], self.params, i) @@ -453,9 +459,9 @@ class RandomForestGraphs(object): array_ops.shape(processed_dense_features), [0], [1]) r = random_ops.random_uniform(batch_size, seed=seed) mask = math_ops.less( - r, array_ops.ones_like(r) * self.params.bagging_fraction) - gather_indices = array_ops.squeeze( - array_ops.where(mask), axis=[1]) + r, + array_ops.ones_like(r) * self.params.bagging_fraction) + gather_indices = array_ops.squeeze(array_ops.where(mask), axis=[1]) # TODO(thomaswc): Calculate out-of-bag data and labels, and store # them for use in calculating statistics later. tree_data = array_ops.gather(processed_dense_features, gather_indices) @@ -480,11 +486,11 @@ class RandomForestGraphs(object): """Constructs a TF graph for evaluating a random forest. Args: - input_data: A tensor or dict of string->Tensor for the input data. - This input_data must generate the same spec as the - input_data used in training_graph: the dict must have - the same keys, for example, and all tensors must have - the same size in their first dimension. + input_data: A tensor or dict of string->Tensor for the input data. This + input_data must generate the same spec as the + input_data used in training_graph: the dict must have the + same keys, for example, and all tensors must have the same + size in their first dimension. **inference_args: Keyword arguments to pass through to each tree. 
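The training path shown above bags input rows per tree: draw a uniform random number per row, keep the rows that fall below `bagging_fraction`, and gather the survivors. A standalone sketch with public `tf.*` names and toy shapes:

```python
import tensorflow as tf

bagging_fraction = 0.6
features = tf.random.normal([8, 3])  # toy [batch, features] input
labels = tf.range(8)

# Keep each row independently with probability bagging_fraction.
batch_size = tf.shape(features)[0:1]
r = tf.random.uniform(batch_size, seed=42)
mask = tf.less(r, tf.ones_like(r) * bagging_fraction)
gather_indices = tf.squeeze(tf.where(mask), axis=[1])

tree_data = tf.gather(features, gather_indices)
tree_labels = tf.gather(labels, gather_indices)
```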
Returns: @@ -540,7 +546,8 @@ class RandomForestGraphs(object): for i in range(self.params.num_trees): with ops.device(self.variables.device_dummies[i].device): sizes.append(self.trees[i].size()) - return math_ops.reduce_mean(math_ops.to_float(array_ops.stack(sizes))) + return math_ops.reduce_mean( + math_ops.cast(array_ops.stack(sizes), dtypes.float32)) # pylint: disable=unused-argument def training_loss(self, features, labels, name='training_loss'): @@ -563,8 +570,10 @@ class RandomForestGraphs(object): return math_ops.reduce_mean(array_ops.stack(impurities)) def feature_importances(self): - tree_counts = [self.trees[i].feature_usage_counts() - for i in range(self.params.num_trees)] + tree_counts = [ + self.trees[i].feature_usage_counts() + for i in range(self.params.num_trees) + ] total_counts = math_ops.reduce_sum(array_ops.stack(tree_counts, 0), 0) return total_counts / math_ops.reduce_sum(total_counts) @@ -584,7 +593,6 @@ class RandomTreeGraphs(object): data_spec, sparse_features=None, input_weights=None): - """Constructs a TF graph for training a random tree. Args: @@ -593,17 +601,17 @@ class RandomTreeGraphs(object): input_data. random_seed: The random number generator seed to use for this tree. 0 means use the current time as the seed. - data_spec: A data_ops.TensorForestDataSpec object specifying the - original feature/columns of the data. + data_spec: A data_ops.TensorForestDataSpec object specifying the original + feature/columns of the data. sparse_features: A tf.SparseTensor for sparse input data. - input_weights: A float tensor or placeholder holding per-input weights, - or None if all inputs are to be weighted equally. + input_weights: A float tensor or placeholder holding per-input weights, or + None if all inputs are to be weighted equally. Returns: The last op in the random tree training graph. """ # TODO(gilberth): Use this. - unused_epoch = math_ops.to_int32(get_epoch_variable()) + unused_epoch = math_ops.cast(get_epoch_variable(), dtypes.int32) if input_weights is None: input_weights = [] @@ -661,8 +669,8 @@ class RandomTreeGraphs(object): Args: input_data: A tensor or placeholder for input data. - data_spec: A TensorForestDataSpec proto specifying the original - input columns. + data_spec: A TensorForestDataSpec proto specifying the original input + columns. sparse_features: A tf.SparseTensor for sparse input data. Returns: diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index 7f0b3255ed6..85070cfad01 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -50,6 +50,7 @@ py_test( name = "projector_api_test", size = "small", srcs = ["plugins/projector/projector_api_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":projector", diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index b683c14c0d7..1c6e5ee7d05 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -21,10 +21,11 @@ limitations under the License. 
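`average_size()` and `feature_importances()` above reduce per-tree statistics into forest-level summaries, with the same cast-to-float32 change. A sketch with made-up per-tree numbers standing in for `tree.size()` and `tree.feature_usage_counts()`:

```python
import tensorflow as tf

# Made-up stand-ins for self.trees[i].size() and feature_usage_counts().
sizes = [tf.constant(41), tf.constant(57), tf.constant(38)]
tree_counts = [tf.constant([3, 0, 5, 2]),
               tf.constant([1, 4, 2, 1]),
               tf.constant([0, 2, 3, 4])]

# average_size(): stack integer sizes, cast (previously to_float), average.
average_size = tf.reduce_mean(tf.cast(tf.stack(sizes), tf.float32))

# feature_importances(): sum usage counts over trees, normalize to sum to 1.
total_counts = tf.cast(tf.reduce_sum(tf.stack(tree_counts, 0), 0), tf.float32)
importances = total_counts / tf.reduce_sum(total_counts)
```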
#include #define EIGEN_USE_GPU -#include "cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/core/util/cuda_launch_config.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h index c35955e1057..0d4893cd5d6 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -18,7 +18,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index fb048d7b19d..80a5252e3db 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -23,7 +23,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 0cae401023e..c29665b9a82 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -19,6 +19,5 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 4a959378138..8f4f1edae0b 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -31,7 +31,6 @@ def create_inference_graph( is_dynamic_op=False, maximum_cached_engines=1, cached_engine_batches=None, - use_calibration=True, input_saved_model_dir=None, input_saved_model_tags=None, output_saved_model_dir=None, @@ -46,13 +45,7 @@ def create_inference_graph( is_dynamic_op=is_dynamic_op, maximum_cached_engines=maximum_cached_engines, cached_engine_batches=cached_engine_batches, - use_calibration=use_calibration, input_saved_model_dir=input_saved_model_dir, input_saved_model_tags=input_saved_model_tags, output_saved_model_dir=output_saved_model_dir, session_config=session_config) - - -def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): - return trt_convert.calib_graph_to_infer_graph( - calibration_graph_def=calibration_graph_def, is_dynamic_op=is_dynamic_op) diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 5c60d6b589e..6dd83452e3a 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorrt/include/NvInfer.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace shape_inference { diff --git a/tensorflow/contrib/text/BUILD b/tensorflow/contrib/text/BUILD index a434c120393..9f9e19a7cd6 100644 --- a/tensorflow/contrib/text/BUILD +++ b/tensorflow/contrib/text/BUILD @@ -97,6 +97,7 @@ py_test( name = "skip_gram_ops_test", size = "medium", srcs = ["python/ops/skip_gram_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":text_py", diff --git a/tensorflow/contrib/timeseries/BUILD b/tensorflow/contrib/timeseries/BUILD index f2b8786a527..18933227b34 100644 --- a/tensorflow/contrib/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/BUILD @@ -23,10 +23,10 @@ py_library( name = "timeseries_pip", deps = [ ":timeseries", - "//tensorflow/contrib/timeseries/examples:known_anomaly", - "//tensorflow/contrib/timeseries/examples:lstm", - "//tensorflow/contrib/timeseries/examples:multivariate", - "//tensorflow/contrib/timeseries/examples:predict", + "//tensorflow/contrib/timeseries/examples:known_anomaly_main_lib", + "//tensorflow/contrib/timeseries/examples:lstm_main_lib", + "//tensorflow/contrib/timeseries/examples:multivariate_main_lib", + "//tensorflow/contrib/timeseries/examples:predict_main_lib", "//tensorflow/contrib/timeseries/python/timeseries:test_utils", "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils", ], diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 70c3a0720ee..235f3adb92f 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -17,6 +17,7 @@ config_setting( py_binary( name = "predict", srcs = ["predict.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [":predict_main_lib"], @@ -42,6 +43,7 @@ py_test( timeout = "long", # Moderate but for asan srcs = ["predict_test.py"], data = ["data/period_trend.csv"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_windows", # TODO: needs investigation on Windows @@ -56,6 +58,7 @@ py_test( py_binary( name = "known_anomaly", srcs = ["known_anomaly.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [":known_anomaly_main_lib"], @@ -80,6 +83,7 @@ py_test( name = "known_anomaly_test", timeout = "long", # Moderate but for asan srcs = ["known_anomaly_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":known_anomaly_main_lib", @@ -90,6 +94,7 @@ py_test( py_binary( name = "multivariate", srcs = ["multivariate.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], deps = [":multivariate_main_lib"], @@ -116,6 +121,7 @@ py_test( srcs = [ "multivariate_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":multivariate_main_lib", @@ -126,6 +132,7 @@ py_test( py_binary( name = "lstm", srcs = ["lstm.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["no_pip"], visibility = ["//visibility:public"], @@ -155,6 +162,7 @@ py_test( name = "lstm_test", timeout = "long", # Moderate but for asan srcs = ["lstm_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["notsan"], deps = [ diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD 
b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 4ba814b9e3d..ae2c4a5cb72 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -104,6 +104,7 @@ py_test( srcs = [ "estimators_test.py", ], + python_version = "PY2", shard_count = 3, srcs_version = "PY2AND3", tags = [ @@ -159,6 +160,7 @@ py_test( srcs = [ "head_test.py", ], + python_version = "PY2", shard_count = 10, srcs_version = "PY2AND3", tags = [ @@ -172,7 +174,7 @@ py_test( ":input_pipeline", ":model", ":state_management", - "//tensorflow/contrib/timeseries/examples:lstm", + "//tensorflow/contrib/timeseries/examples:lstm_main_lib", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -214,6 +216,7 @@ py_test( srcs = [ "model_utils_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_pip_gpu", # b/63391119 @@ -249,6 +252,7 @@ py_test( srcs = [ "state_management_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_oss", @@ -314,6 +318,7 @@ py_test( srcs = [ "input_pipeline_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_oss", # b/63709811 @@ -390,6 +395,7 @@ py_test( srcs = [ "ar_model_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -443,6 +449,7 @@ py_test( srcs = [ "math_utils_test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_pip_gpu", # b/63391119 diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 146ed9f2713..ce50d3c9849 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -47,20 +47,25 @@ from tensorflow.python.util import nest class TimeSeriesRegressor(estimator_lib.Estimator): """An Estimator to fit and evaluate a time series model.""" - def __init__(self, model, state_manager=None, optimizer=None, model_dir=None, - config=None, head_type=ts_head_lib.TimeSeriesRegressionHead): + def __init__(self, + model, + state_manager=None, + optimizer=None, + model_dir=None, + config=None, + head_type=ts_head_lib.TimeSeriesRegressionHead): """Initialize the Estimator. Args: model: The time series model to wrap (inheriting from TimeSeriesModel). state_manager: The state manager to use, or (by default) - PassthroughStateManager if none is needed. + PassthroughStateManager if none is needed. optimizer: The optimization algorithm to use when training, inheriting - from tf.train.Optimizer. Defaults to Adam with step size 0.02. + from tf.train.Optimizer. Defaults to Adam with step size 0.02. model_dir: See `Estimator`. config: See `Estimator`. head_type: The kind of head to use for the model (inheriting from - `TimeSeriesRegressionHead`). + `TimeSeriesRegressionHead`). 
""" input_statistics_generator = math_utils.InputStatisticsFromMiniBatch( dtype=model.dtype, num_features=model.num_features) @@ -73,16 +78,17 @@ class TimeSeriesRegressor(estimator_lib.Estimator): optimizer = train.AdamOptimizer(0.02) self._model = model ts_regression_head = head_type( - model=model, state_manager=state_manager, optimizer=optimizer, + model=model, + state_manager=state_manager, + optimizer=optimizer, input_statistics_generator=input_statistics_generator) model_fn = ts_regression_head.create_estimator_spec super(TimeSeriesRegressor, self).__init__( - model_fn=model_fn, - model_dir=model_dir, - config=config) + model_fn=model_fn, model_dir=model_dir, config=config) - def _model_start_state_placeholders( - self, batch_size_tensor, static_batch_size=None): + def _model_start_state_placeholders(self, + batch_size_tensor, + static_batch_size=None): """Creates placeholders with zeroed start state for the current model.""" gathered_state = {} # Models may not know the shape of their state without creating some @@ -90,33 +96,39 @@ class TimeSeriesRegressor(estimator_lib.Estimator): # use only static metadata from the returned Tensors. with ops.Graph().as_default(): self._model.initialize_graph() + # Evaluate the initial state as same-dtype "zero" values. These zero # constants aren't used, but are necessary for feeding to # placeholder_with_default for the "cold start" case where state is not # fed to the model. def _zeros_like_constant(tensor): return tensor_util.constant_value(array_ops.zeros_like(tensor)) - start_state = nest.map_structure( - _zeros_like_constant, self._model.get_start_state()) + + start_state = nest.map_structure(_zeros_like_constant, + self._model.get_start_state()) for prefixed_state_name, state in ts_head_lib.state_to_dictionary( start_state).items(): state_shape_with_batch = tensor_shape.TensorShape( (static_batch_size,)).concatenate(state.shape) default_state_broadcast = array_ops.tile( state[None, ...], - multiples=array_ops.concat( - [batch_size_tensor[None], - array_ops.ones(len(state.shape), dtype=dtypes.int32)], - axis=0)) + multiples=array_ops.concat([ + batch_size_tensor[None], + array_ops.ones(len(state.shape), dtype=dtypes.int32) + ], + axis=0)) gathered_state[prefixed_state_name] = array_ops.placeholder_with_default( input=default_state_broadcast, name=prefixed_state_name, shape=state_shape_with_batch) return gathered_state - def build_one_shot_parsing_serving_input_receiver_fn( - self, filtering_length, prediction_length, default_batch_size=None, - values_input_dtype=None, truncate_values=False): + def build_one_shot_parsing_serving_input_receiver_fn(self, + filtering_length, + prediction_length, + default_batch_size=None, + values_input_dtype=None, + truncate_values=False): """Build an input_receiver_fn for export_savedmodel accepting tf.Examples. Only compatible with `OneShotPredictionHead` (see `head`). 
@@ -167,35 +179,34 @@ class TimeSeriesRegressor(estimator_lib.Estimator): times_column = feature_column.numeric_column( key=feature_keys.TrainEvalFeatures.TIMES, dtype=dtypes.int64) values_column = feature_column.numeric_column( - key=feature_keys.TrainEvalFeatures.VALUES, dtype=values_input_dtype, + key=feature_keys.TrainEvalFeatures.VALUES, + dtype=values_input_dtype, shape=(self._model.num_features,)) parsed_features_no_sequence = ( feature_column.make_parse_example_spec( - list(self._model.exogenous_feature_columns) - + [times_column, values_column])) + list(self._model.exogenous_feature_columns) + + [times_column, values_column])) parsed_features = {} for key, feature_spec in parsed_features_no_sequence.items(): if isinstance(feature_spec, parsing_ops.FixedLenFeature): if key == feature_keys.TrainEvalFeatures.VALUES: parsed_features[key] = feature_spec._replace( - shape=((values_proto_length,) - + feature_spec.shape)) + shape=((values_proto_length,) + feature_spec.shape)) else: parsed_features[key] = feature_spec._replace( - shape=((filtering_length + prediction_length,) - + feature_spec.shape)) + shape=((filtering_length + prediction_length,) + + feature_spec.shape)) elif feature_spec.dtype == dtypes.string: parsed_features[key] = parsing_ops.FixedLenFeature( shape=(filtering_length + prediction_length,), dtype=dtypes.string) else: # VarLenFeature - raise ValueError("VarLenFeatures not supported, got %s for key %s" - % (feature_spec, key)) + raise ValueError("VarLenFeatures not supported, got %s for key %s" % + (feature_spec, key)) tfexamples = array_ops.placeholder( shape=[default_batch_size], dtype=dtypes.string, name="input") features = parsing_ops.parse_example( - serialized=tfexamples, - features=parsed_features) + serialized=tfexamples, features=parsed_features) features[feature_keys.TrainEvalFeatures.TIMES] = array_ops.squeeze( features[feature_keys.TrainEvalFeatures.TIMES], axis=-1) features[feature_keys.TrainEvalFeatures.VALUES] = math_ops.cast( @@ -206,12 +217,13 @@ class TimeSeriesRegressor(estimator_lib.Estimator): batch_size_tensor=array_ops.shape( features[feature_keys.TrainEvalFeatures.TIMES])[0], static_batch_size=default_batch_size)) - return export_lib.ServingInputReceiver( - features, {"examples": tfexamples}) + return export_lib.ServingInputReceiver(features, {"examples": tfexamples}) + return _serving_input_receiver_fn - def build_raw_serving_input_receiver_fn( - self, default_batch_size=None, default_series_length=None): + def build_raw_serving_input_receiver_fn(self, + default_batch_size=None, + default_series_length=None): """Build an input_receiver_fn for export_savedmodel which accepts arrays. Automatically creates placeholders for exogenous `FeatureColumn`s passed to @@ -227,10 +239,12 @@ class TimeSeriesRegressor(estimator_lib.Estimator): which means only this series length will be accepted by the exported model. If None (default), static shape information for series length is omitted. + Returns: An input_receiver_fn which may be passed to the Estimator's export_savedmodel. 
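The one-shot parsing receiver above boils down to: declare `FixedLenFeature`s spanning the filtering + prediction window, parse a batch of serialized tf.Examples, and hand both to a `ServingInputReceiver`. A simplified sketch, with exogenous columns, value truncation, and the start-state placeholders omitted and the feature keys chosen for illustration:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF1-style graph code

filtering_length, prediction_length, num_features = 16, 8, 1
window = filtering_length + prediction_length

# Fixed-length features covering the whole window (keys are illustrative).
feature_spec = {
    "times": tf.io.FixedLenFeature(shape=(window,), dtype=tf.int64),
    "values": tf.io.FixedLenFeature(shape=(window, num_features),
                                    dtype=tf.float32),
}


def serving_input_receiver_fn():
  serialized = tf.compat.v1.placeholder(
      shape=[None], dtype=tf.string, name="input")
  features = tf.io.parse_example(serialized, feature_spec)
  return tf.estimator.export.ServingInputReceiver(
      features, {"examples": serialized})
```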
""" + def _serving_input_receiver_fn(): """A receiver function to be passed to export_savedmodel.""" placeholders = {} @@ -246,9 +260,9 @@ class TimeSeriesRegressor(estimator_lib.Estimator): name=feature_keys.TrainEvalFeatures.VALUES, input=array_ops.zeros( shape=[ - default_batch_size - if default_batch_size else 0, default_series_length - if default_series_length else 0, self._model.num_features + default_batch_size if default_batch_size else 0, + default_series_length if default_series_length else 0, + self._model.num_features ], dtype=self._model.dtype), shape=(default_batch_size, default_series_length, @@ -268,12 +282,12 @@ class TimeSeriesRegressor(estimator_lib.Estimator): exogenous_feature_shapes = { key: (value.get_shape(), value.dtype) for key, value in placeholder_features.items()} - for feature_key, (batch_only_feature_shape, value_dtype) in ( - exogenous_feature_shapes.items()): + for feature_key, (batch_only_feature_shape, + value_dtype) in (exogenous_feature_shapes.items()): batch_only_feature_shape = ( batch_only_feature_shape.with_rank_at_least(1).as_list()) - feature_shape = ([default_batch_size, default_series_length] - + batch_only_feature_shape[1:]) + feature_shape = ([default_batch_size, default_series_length] + + batch_only_feature_shape[1:]) placeholders[feature_key] = array_ops.placeholder( dtype=value_dtype, name=feature_key, shape=feature_shape) batch_size_tensor = array_ops.shape(time_placeholder)[0] @@ -296,12 +310,20 @@ class ARRegressor(TimeSeriesRegressor): evaluation, although it may be seeded for deterministic evaluation. """ - def __init__( - self, periodicities, input_window_size, output_window_size, - num_features, exogenous_feature_columns=None, num_time_buckets=10, - loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, hidden_layer_sizes=None, - anomaly_prior_probability=None, anomaly_distribution=None, - optimizer=None, model_dir=None, config=None): + def __init__(self, + periodicities, + input_window_size, + output_window_size, + num_features, + exogenous_feature_columns=None, + num_time_buckets=10, + loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, + hidden_layer_sizes=None, + anomaly_prior_probability=None, + anomaly_distribution=None, + optimizer=None, + model_dir=None, + config=None): """Initialize the Estimator. Args: @@ -318,7 +340,7 @@ class ARRegressor(TimeSeriesRegressor): `tf.feature_column.embedding_column`) corresponding to exogenous features which provide extra information to the model but are not part of the series to be predicted. Passed to - `tf.feature_column.input_layer`. + `tf.compat.v1.feature_column.input_layer`. num_time_buckets: Number of buckets into which to divide (time % periodicity) for generating time based features. loss: Loss function to use for training. Currently supported values are @@ -338,9 +360,10 @@ class ARRegressor(TimeSeriesRegressor): `ar_model.AnomalyMixtureARModel.CAUCHY_ANOMALY`. See `AnomalyMixtureARModel`. Defaults to `GAUSSIAN_ANOMALY`. optimizer: The optimization algorithm to use when training, inheriting - from tf.train.Optimizer. Defaults to Adagrad with step size 0.1. + from tf.train.Optimizer. Defaults to Adagrad with step size 0.1. model_dir: See `Estimator`. config: See `Estimator`. + Raises: ValueError: For invalid combinations of arguments. 
""" @@ -353,14 +376,16 @@ class ARRegressor(TimeSeriesRegressor): if anomaly_distribution is None: anomaly_distribution = ar_model.AnomalyMixtureARModel.GAUSSIAN_ANOMALY model = ar_model.ARModel( - periodicities=periodicities, num_features=num_features, + periodicities=periodicities, + num_features=num_features, prediction_model_factory=functools.partial( ar_model.FlatPredictionModel, hidden_layer_sizes=hidden_layer_sizes), exogenous_feature_columns=exogenous_feature_columns, num_time_buckets=num_time_buckets, input_window_size=input_window_size, - output_window_size=output_window_size, loss=loss) + output_window_size=output_window_size, + loss=loss) else: if loss != ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS: raise ValueError( @@ -451,13 +476,13 @@ class LSTMAutoRegressor(TimeSeriesRegressor): extra_feature_columns=extra_feature_columns, num_timesteps=50, num_units=10, - optimizer=tf.train.ProximalAdagradOptimizer(...)) + optimizer=tf.compat.v1.train.ProximalAdagradOptimizer(...)) # Input builders def input_fn_train(): return { "times": tf.range(15)[None, :], - "values": tf.random_normal(shape=[1, 15, 1]) + "values": tf.random.normal(shape=[1, 15, 1]) } estimator.train(input_fn=input_fn_train, steps=100) @@ -496,10 +521,10 @@ class LSTMAutoRegressor(TimeSeriesRegressor): output_window_size: Number of future time steps to predict. Note that setting this value to > 1 empirically seems to give a better fit. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator - to continue training a previously saved model. - num_features: The dimensionality of the time series (default value is - one for univariate, more than one for multivariate). + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + num_features: The dimensionality of the time series (default value is one + for univariate, more than one for multivariate). extra_feature_columns: A list of `tf.feature_column`s (for example `tf.feature_column.embedding_column`) corresponding to features which provide extra information to the model but are not part of the series to @@ -515,14 +540,13 @@ class LSTMAutoRegressor(TimeSeriesRegressor): normalized data. num_units: The size of the hidden state in the encoder and decoder LSTM cells. - optimizer: string, `tf.train.Optimizer` object, or callable that defines - the optimizer algorithm to use for training. Defaults to the Adam - optimizer with a learning rate of 0.01. + optimizer: string, `tf.compat.v1.train.Optimizer` object, or callable that + defines the optimizer algorithm to use for training. Defaults to the + Adam optimizer with a learning rate of 0.01. config: Optional `estimator.RunConfig` object to configure the runtime settings. 
""" - optimizer = optimizers.get_optimizer_instance( - optimizer, learning_rate=0.01) + optimizer = optimizers.get_optimizer_instance(optimizer, learning_rate=0.01) model = ar_model.ARModel( periodicities=periodicities, input_window_size=input_window_size, @@ -546,8 +570,13 @@ class LSTMAutoRegressor(TimeSeriesRegressor): class StateSpaceRegressor(TimeSeriesRegressor): """An Estimator for general state space models.""" - def __init__(self, model, state_manager=None, optimizer=None, model_dir=None, - config=None, head_type=ts_head_lib.TimeSeriesRegressionHead): + def __init__(self, + model, + state_manager=None, + optimizer=None, + model_dir=None, + config=None, + head_type=ts_head_lib.TimeSeriesRegressionHead): """See TimeSeriesRegressor. Uses the ChainingStateManager by default.""" if not isinstance(model, state_space_model.StateSpaceModel): raise ValueError( @@ -610,60 +639,59 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): Args: periodicities: The expected periodicity of the data (for example 24 if - feeding hourly data with a daily periodicity, or 60 * 24 if feeding - minute-level data with daily periodicity). Either a scalar or a - list. This parameter can be any real value, and does not control the - size of the model. However, increasing this without increasing - `num_values_per_cycle` will lead to smoother periodic behavior, as the - same number of distinct values will be cycled through over a longer - period of time. + feeding hourly data with a daily periodicity, or 60 * 24 if feeding + minute-level data with daily periodicity). Either a scalar or a list. + This parameter can be any real value, and does not control the size of + the model. However, increasing this without increasing + `num_values_per_cycle` will lead to smoother periodic behavior, as the + same number of distinct values will be cycled through over a longer + period of time. num_features: The dimensionality of the time series (one for univariate, - more than one for multivariate). + more than one for multivariate). cycle_num_latent_values: Along with `moving_average_order` and - `num_features`, controls the latent state size of the model. Square - matrices of size `num_features * (moving_average_order + - cycle_num_latent_values + 3)` are created and multiplied, so larger + `num_features`, controls the latent state size of the model. Square + matrices of size `num_features * (moving_average_order + + cycle_num_latent_values + 3)` are created and multiplied, so larger values may be slow. The trade-off is with resolution: cycling between - a smaller number of latent values means that only smoother functions - can be modeled. + a smaller number of latent values means that only smoother functions + can be modeled. moving_average_order: Controls model size (along with - `cycle_num_latent_values` and `autoregressive_order`) and the number - of steps before transient deviations revert to the mean defined by the - period and level/trend components. + `cycle_num_latent_values` and `autoregressive_order`) and the number of + steps before transient deviations revert to the mean defined by the + period and level/trend components. autoregressive_order: Each contribution from this component is a linear - combination of this many previous contributions. Also helps to - determine the model size. Learning autoregressive coefficients - typically requires more steps and a smaller step size than other - components. + combination of this many previous contributions. Also helps to determine + the model size. 
Learning autoregressive coefficients typically requires + more steps and a smaller step size than other components. exogenous_feature_columns: A list of `tf.feature_column`s (for example - `tf.feature_column.embedding_column`) corresponding to exogenous - features which provide extra information to the model but are not part - of the series to be predicted. Passed to - `tf.feature_column.input_layer`. + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not part + of the series to be predicted. Passed to + `tf.compat.v1.feature_column.input_layer`. exogenous_update_condition: A function taking two Tensor arguments, - `times` (shape [batch size]) and `features` (a dictionary mapping - exogenous feature keys to Tensors with shapes [batch size, ...]), and - returning a boolean Tensor with shape [batch size] indicating whether - state should be updated using exogenous features for each part of the - batch. Where it is False, no exogenous update is performed. If None - (default), exogenous updates are always performed. Useful for avoiding - "leaky" frequent exogenous updates when sparse updates are - desired. Called only during graph construction. See the "known - anomaly" example for example usage. + `times` (shape [batch size]) and `features` (a dictionary mapping + exogenous feature keys to Tensors with shapes [batch size, ...]), and + returning a boolean Tensor with shape [batch size] indicating whether + state should be updated using exogenous features for each part of the + batch. Where it is False, no exogenous update is performed. If None + (default), exogenous updates are always performed. Useful for avoiding + "leaky" frequent exogenous updates when sparse updates are desired. + Called only during graph construction. See the "known anomaly" example + for example usage. dtype: The floating point data type to compute with. float32 may be faster, but can be problematic for larger models and longer time series. anomaly_prior_probability: If not None, the model attempts to - automatically detect and ignore anomalies during training. This - parameter then controls the prior probability of an anomaly. Values - closer to 0 mean that points will be discarded less frequently. The - default value (None) means that anomalies are not discarded, which may - be slightly faster. + automatically detect and ignore anomalies during training. This + parameter then controls the prior probability of an anomaly. Values + closer to 0 mean that points will be discarded less frequently. The + default value (None) means that anomalies are not discarded, which may + be slightly faster. optimizer: The optimization algorithm to use when training, inheriting - from tf.train.Optimizer. Defaults to Adam with step size 0.02. + from tf.train.Optimizer. Defaults to Adam with step size 0.02. model_dir: See `Estimator`. config: See `Estimator`. head_type: The kind of head to use for the model (inheriting from - `TimeSeriesRegressionHead`). + `TimeSeriesRegressionHead`). 
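A hypothetical TF 1.x construction of the `StructuralEnsembleRegressor` documented above, using only arguments from its docstring; the periodicities are invented for an hourly univariate series:

```python
import tensorflow as tf

# Hypothetical daily/weekly periodicities for hourly data.
estimator = tf.contrib.timeseries.StructuralEnsembleRegressor(
    periodicities=[24, 24 * 7],
    num_features=1)
```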
""" if anomaly_prior_probability is not None: filtering_postprocessor = StateInterpolatingAnomalyDetector( diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py index 403c6e2cb4a..f9259a78393 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py +++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py @@ -98,8 +98,10 @@ from tensorflow.python.training import training from tensorflow.python.util import nest -def predict_continuation_input_fn( - evaluation, steps=None, times=None, exogenous_features=None): +def predict_continuation_input_fn(evaluation, + steps=None, + times=None, + exogenous_features=None): """An Estimator input_fn for running predict() after evaluate(). If the call to evaluate() we are making predictions based on had a batch_size @@ -121,6 +123,7 @@ def predict_continuation_input_fn( the batch dimension used when creating `evaluation`, and `window_size` is either the `steps` argument or the `window_size` of the `times` argument (depending on which was specified). + Returns: An `input_fn` suitable for passing to the `predict` function of a time series `Estimator`. @@ -138,6 +141,7 @@ def predict_continuation_input_fn( predict_times } features.update(exogenous_features) + def _predict_input_fn(): """An input_fn for predict().""" # Prevents infinite iteration with a constant output in an Estimator's @@ -148,6 +152,7 @@ def predict_continuation_input_fn( lambda value: training.limit_epochs(value, num_epochs=1), values) limited_features[key] = limited_values return (limited_features, None) + return _predict_input_fn @@ -227,18 +232,16 @@ class NumpyReader(TimeSeriesReader): Args: data: A dictionary mapping feature names to Numpy arrays, with two possible shapes (requires keys `TrainEvalFeatures.TIMES` and - `TrainEvalFeatures.VALUES`): - Univariate; `TIMES` and `VALUES` are both vectors of shape [series - length] - Multivariate; `TIMES` is a vector of shape [series length], `VALUES` - has shape [series length x number of features]. - In any case, `VALUES` and any exogenous features must have their shapes - prefixed by the shape of the value corresponding to the `TIMES` key. + `TrainEvalFeatures.VALUES`): Univariate; `TIMES` and `VALUES` are both + vectors of shape [series length] Multivariate; `TIMES` is a vector of + shape [series length], `VALUES` has shape [series length x number of + features]. In any case, `VALUES` and any exogenous features must have + their shapes prefixed by the shape of the value corresponding to the + `TIMES` key. read_num_records_hint: The maximum number of samples to read at one time, for efficiency. 
""" - self._features = _canonicalize_numpy_data( - data, require_single_batch=True) + self._features = _canonicalize_numpy_data(data, require_single_batch=True) self._read_num_records_hint = read_num_records_hint def check_dataset_size(self, minimum_dataset_size): @@ -254,8 +257,10 @@ class NumpyReader(TimeSeriesReader): def read(self): """Returns a large chunk of the Numpy arrays for later re-chunking.""" # Remove the batch dimension from all features - features = {key: numpy.squeeze(value, axis=0) - for key, value in self._features.items()} + features = { + key: numpy.squeeze(value, axis=0) + for key, value in self._features.items() + } return estimator_lib.inputs.numpy_input_fn( x=features, # The first dimensions of features are the series length, since we have @@ -275,12 +280,14 @@ class NumpyReader(TimeSeriesReader): queue_capacity=2, # Each queue element is a full copy of the dataset shuffle=False)() # TimeSeriesInputFn expect just a batch dimension - return {feature_name: array_ops.squeeze(feature_value, axis=0) - for feature_name, feature_value in features.items()} + return { + feature_name: array_ops.squeeze(feature_value, axis=0) + for feature_name, feature_value in features.items() + } class ReaderBaseTimeSeriesParser(TimeSeriesReader): - """Base for time series readers which wrap a `tf.ReaderBase`.""" + """Base for time series readers which wrap a `tf.compat.v1.ReaderBase`.""" def __init__(self, filenames, read_num_records_hint=4096): """Configure the time series reader. @@ -297,7 +304,7 @@ class ReaderBaseTimeSeriesParser(TimeSeriesReader): @abc.abstractmethod def _get_reader(self): - """Get an instance of the tf.ReaderBase associated with this class.""" + """Get an instance of the tf.compat.v1.ReaderBase associated with this class.""" pass @abc.abstractmethod @@ -326,6 +333,7 @@ class ReaderBaseTimeSeriesParser(TimeSeriesReader): Args: epoch_limit: The maximum number of times to read through the complete list of files before throwing an OutOfRangeError. + Returns: A tuple of (filename_queue, epoch_limiter): filename_queue: A FIFOQueue with filename work items. @@ -348,8 +356,8 @@ class ReaderBaseTimeSeriesParser(TimeSeriesReader): # will start incrementing and checking epoch_limiter, which will interrupt # any in-progress loops. 
conditional_count_up_to = control_flow_ops.cond( - state_ops.is_variable_initialized(epoch_limiter), - lambda: epoch_limiter.count_up_to(epoch_limit), + state_ops.is_variable_initialized( + epoch_limiter), lambda: epoch_limiter.count_up_to(epoch_limit), lambda: constant_op.constant(0, dtype=dtypes.int64)) with ops.control_dependencies([conditional_count_up_to]): filenames_tensor = array_ops.identity(filenames_tensor) @@ -358,7 +366,7 @@ class ReaderBaseTimeSeriesParser(TimeSeriesReader): return filename_queue, epoch_limiter def read(self): - """Reads a chunk of data from the `tf.ReaderBase` for later re-chunking.""" + """Reads a chunk of data from the `tf.compat.v1.ReaderBase` for later re-chunking.""" # Assuming there is at least one item to be read among all of the files in # self._filenames, we will not need to go through more than # self._read_num_records_hint epochs to get a batch of @@ -370,8 +378,8 @@ class ReaderBaseTimeSeriesParser(TimeSeriesReader): reader = self._get_reader() epoch_reset_op = state_ops.assign(epoch_limiter, 0) with ops.control_dependencies([epoch_reset_op]): - _, records = reader.read_up_to( - filename_queue, self._read_num_records_hint) + _, records = reader.read_up_to(filename_queue, + self._read_num_records_hint) return self._process_records(records) def read_full(self): @@ -387,29 +395,34 @@ class ReaderBaseTimeSeriesParser(TimeSeriesReader): with ops.control_dependencies([epoch_reset_op]): first_key, first_value = reader.read_up_to(filename_queue, 1) # Read until we get a duplicate key (one epoch) - def _while_condition( - current_key, current_value, current_index, collected_records): + def _while_condition(current_key, current_value, current_index, + collected_records): del current_value, current_index, collected_records # unused - return math_ops.not_equal(array_ops.squeeze(current_key, axis=0), - array_ops.squeeze(first_key, axis=0)) + return math_ops.not_equal( + array_ops.squeeze(current_key, axis=0), + array_ops.squeeze(first_key, axis=0)) - def _while_body( - current_key, current_value, current_index, collected_records): + def _while_body(current_key, current_value, current_index, + collected_records): del current_key # unused new_key, new_value = reader.read_up_to(filename_queue, 1) new_key.set_shape([1]) new_value.set_shape([1]) - return (new_key, - new_value, - current_index + 1, + return (new_key, new_value, current_index + 1, collected_records.write(current_index, current_value)) + _, _, _, records_ta = control_flow_ops.while_loop( _while_condition, _while_body, - [constant_op.constant([""]), first_value, - 0, # current_index starting value - tensor_array_ops.TensorArray( # collected_records - dtype=dtypes.string, size=0, dynamic_size=True)]) + [ + constant_op.constant([""]), + first_value, + 0, # current_index starting value + tensor_array_ops.TensorArray( # collected_records + dtype=dtypes.string, + size=0, + dynamic_size=True) + ]) records = records_ta.concat() # Reset the reader when we're done so that subsequent requests for data get # the dataset in the proper order. @@ -433,21 +446,21 @@ class CSVReader(ReaderBaseTimeSeriesParser): """CSV-parsing reader for a `TimeSeriesInputFn`. Args: - filenames: A filename or list of filenames to read the time series - from. Each line must have columns corresponding to `column_names`. - column_names: A list indicating names for each - feature. `TrainEvalFeatures.TIMES` and `TrainEvalFeatures.VALUES` are - required; `VALUES` may be repeated to indicate a multivariate series. 
+ filenames: A filename or list of filenames to read the time series from. + Each line must have columns corresponding to `column_names`. + column_names: A list indicating names for each feature. + `TrainEvalFeatures.TIMES` and `TrainEvalFeatures.VALUES` are required; + `VALUES` may be repeated to indicate a multivariate series. column_dtypes: If provided, must be a list with the same length as - `column_names`, indicating dtypes for each column. Defaults to - `tf.int64` for `TrainEvalFeatures.TIMES` and `tf.float32` for - everything else. - skip_header_lines: Passed on to `tf.TextLineReader`; skips this number of - lines at the beginning of each file. + `column_names`, indicating dtypes for each column. Defaults to + `tf.int64` for `TrainEvalFeatures.TIMES` and `tf.float32` for everything + else. + skip_header_lines: Passed on to `tf.compat.v1.TextLineReader`; skips this + number of lines at the beginning of each file. read_num_records_hint: When not reading a full dataset, indicates the - number of records to parse/transfer in a single chunk (for - efficiency). The actual number transferred at one time may be more or - less. + number of records to parse/transfer in a single chunk (for efficiency). + The actual number transferred at one time may be more or less. + Raises: ValueError: If required column names are not specified, or if lengths do not match. @@ -465,9 +478,9 @@ class CSVReader(ReaderBaseTimeSeriesParser): column_dtypes, column_names)) if sum(1 for column_name in column_names if column_name == feature_keys.TrainEvalFeatures.TIMES) != 1: - raise ValueError( - "Got more than one times column ('{}'), but exactly " - "one is required.".format(feature_keys.TrainEvalFeatures.TIMES)) + raise ValueError("Got more than one times column ('{}'), but exactly " + "one is required.".format( + feature_keys.TrainEvalFeatures.TIMES)) self._column_names = column_names self._column_dtypes = column_dtypes self._skip_header_lines = skip_header_lines @@ -480,12 +493,13 @@ class CSVReader(ReaderBaseTimeSeriesParser): def _process_records(self, lines): """Parse `lines` as CSV records.""" if self._column_dtypes is None: - default_values = [(array_ops.zeros([], dtypes.int64),) - if column_name == feature_keys.TrainEvalFeatures.TIMES - else () for column_name in self._column_names] + default_values = [(array_ops.zeros([], dtypes.int64),) if + column_name == feature_keys.TrainEvalFeatures.TIMES else + () for column_name in self._column_names] else: - default_values = [(array_ops.zeros([], dtype),) - for dtype in self._column_dtypes] + default_values = [ + (array_ops.zeros([], dtype),) for dtype in self._column_dtypes + ] columns = parsing_ops.decode_csv(lines, default_values) features_lists = {} for column_name, value in zip(self._column_names, columns): @@ -502,17 +516,17 @@ class CSVReader(ReaderBaseTimeSeriesParser): class TFExampleReader(ReaderBaseTimeSeriesParser): """Reads and parses `tf.Example`s from a TFRecords file.""" - def __init__(self, - filenames, - features): + def __init__(self, filenames, features): """Configure `tf.Example` parsing. Args: - filenames: A filename or list of filenames to read the time series - from. Each line must have columns corresponding to `column_names`. - features: A dictionary mapping from feature keys to `tf.FixedLenFeature` - objects. Must include `TrainEvalFeatures.TIMES` (scalar integer) and - `TrainEvalFeatures.VALUES` (floating point vector) features. + filenames: A filename or list of filenames to read the time series from. 
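Editor's sketch (hypothetical, not part of this diff): a `CSVReader` configured as the reflowed docstring above describes, with one times column and `VALUES` repeated for a two-feature series; the filename is illustrative.

```python
# Hypothetical sketch, not part of the diff: a CSVReader for a two-feature
# series. Per the docstring, TIMES appears once and VALUES may be repeated.
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline

csv_reader = input_pipeline.CSVReader(
    filenames=["series.csv"],  # illustrative path; one row per observation
    column_names=(feature_keys.TrainEvalFeatures.TIMES,
                  feature_keys.TrainEvalFeatures.VALUES,
                  feature_keys.TrainEvalFeatures.VALUES))
```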
+ Each line must have columns corresponding to `column_names`. + features: A dictionary mapping from feature keys to + `tf.io.FixedLenFeature` objects. Must include `TrainEvalFeatures.TIMES` + (scalar integer) and `TrainEvalFeatures.VALUES` (floating point vector) + features. + Raises: ValueError: If required times/values features are not present. """ @@ -572,6 +586,7 @@ class WholeDatasetInputFn(TimeSeriesInputFn): `RandomWindowInputFn` is better suited to training and quantitative evaluation. """ + # TODO(allenl): A SequentialWindowInputFn for getting model end state without # loading the whole dataset into memory (or for quantitative evaluation of # sequential models). Note that an Estimator using such a TimeSeriesInputFn @@ -598,16 +613,17 @@ class WholeDatasetInputFn(TimeSeriesInputFn): """ features = self._reader.read_full() # Add a batch dimension of one to each feature. - return ({feature_name: feature_value[None, ...] - for feature_name, feature_value in features.items()}, - None) + return ({ + feature_name: feature_value[None, ...] + for feature_name, feature_value in features.items() + }, None) class RandomWindowInputFn(TimeSeriesInputFn): """Wraps a `TimeSeriesReader` to create random batches of windows. Tensors are first collected into sequential windows (in a windowing queue - created by `tf.train.batch`, based on the order returned from + created by `tf.compat.v1.train.batch`, based on the order returned from `time_series_reader`), then these windows are randomly batched (in a `RandomShuffleQueue`), the Tensors returned by `create_batch` having shapes prefixed by [`batch_size`, `window_size`]. @@ -619,27 +635,33 @@ class RandomWindowInputFn(TimeSeriesInputFn): `WholeDatasetInputFn`. """ - def __init__( - self, time_series_reader, window_size, batch_size, - queue_capacity_multiplier=1000, shuffle_min_after_dequeue_multiplier=2, - discard_out_of_order=True, discard_consecutive_batches_limit=1000, - jitter=True, num_threads=2, shuffle_seed=None): + def __init__(self, + time_series_reader, + window_size, + batch_size, + queue_capacity_multiplier=1000, + shuffle_min_after_dequeue_multiplier=2, + discard_out_of_order=True, + discard_consecutive_batches_limit=1000, + jitter=True, + num_threads=2, + shuffle_seed=None): """Configure the RandomWindowInputFn. Args: time_series_reader: A TimeSeriesReader object. window_size: The number of examples to keep together sequentially. This controls the length of truncated backpropagation: smaller values mean - less sequential computation, which can lead to faster training, but - create a coarser approximation to the gradient (which would ideally be - computed by a forward pass over the entire sequence in order). + less sequential computation, which can lead to faster training, but + create a coarser approximation to the gradient (which would ideally be + computed by a forward pass over the entire sequence in order). batch_size: The number of windows to place together in a batch. Larger values will lead to more stable gradients during training. queue_capacity_multiplier: The capacity for the queues used to create batches, specified as a multiple of `batch_size` (for - RandomShuffleQueue) and `batch_size * window_size` (for the - FIFOQueue). Controls the maximum number of windows stored. Should be - greater than `shuffle_min_after_dequeue_multiplier`. + RandomShuffleQueue) and `batch_size * window_size` (for the FIFOQueue). + Controls the maximum number of windows stored. Should be greater than + `shuffle_min_after_dequeue_multiplier`. 
shuffle_min_after_dequeue_multiplier: The minimum number of windows in the RandomShuffleQueue after a dequeue, which controls the amount of entropy introduced during batching. Specified as a multiple of `batch_size`. @@ -660,8 +682,8 @@ class RandomWindowInputFn(TimeSeriesInputFn): removes one source of non-determinism (and in combination with shuffle_seed should provide deterministic windowing). shuffle_seed: A seed for window shuffling. The default value of None - provides random behavior. With `shuffle_seed` set and - `num_threads=1`, provides deterministic behavior. + provides random behavior. With `shuffle_seed` set and `num_threads=1`, + provides deterministic behavior. """ self._reader = time_series_reader self._window_size = window_size @@ -703,22 +725,22 @@ class RandomWindowInputFn(TimeSeriesInputFn): features, batch_size=self._window_size * internal_passing_size + jitter, enqueue_many=True, - capacity=(self._queue_capacity_multiplier - * internal_passing_size * self._window_size), + capacity=(self._queue_capacity_multiplier * internal_passing_size * + self._window_size), num_threads=self._num_threads) raw_features_windowed = features_windowed if self._jitter: features_windowed = { - key: value[jitter:] - for key, value in features_windowed.items()} + key: value[jitter:] for key, value in features_windowed.items() + } features_windowed = { key: array_ops.reshape( value, - array_ops.concat( - [[internal_passing_size, self._window_size], - array_ops.shape(value)[1:]], - axis=0)) - for key, value in features_windowed.items()} + array_ops.concat([[internal_passing_size, self._window_size], + array_ops.shape(value)[1:]], + axis=0)) + for key, value in features_windowed.items() + } batch_and_window_shape = tensor_shape.TensorShape( [internal_passing_size, self._window_size]) for key in features_windowed.keys(): @@ -746,6 +768,7 @@ class RandomWindowInputFn(TimeSeriesInputFn): name="discarded_windows_limiter", trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES]) + def _initialized_limit_check(): return control_flow_ops.cond( math_ops.reduce_any(non_decreasing), @@ -785,19 +808,18 @@ def _canonicalize_numpy_data(data, require_single_batch): Args: data: A dictionary mapping keys to Numpy arrays, with several possible shapes (requires keys `TrainEvalFeatures.TIMES` and - `TrainEvalFeatures.VALUES`): - Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a - vector of length [number of features]. + `TrainEvalFeatures.VALUES`): Single example; `TIMES` is a scalar and + `VALUES` is either a scalar or a vector of length [number of features]. Sequence; `TIMES` is a vector of shape [series length], `VALUES` either - has shape [series length] (univariate) or [series length x number of - features] (multivariate). - Batch of sequences; `TIMES` is a vector of shape [batch size x series - length], `VALUES` has shape [batch size x series length] or [batch - size x series length x number of features]. - In any case, `VALUES` and any exogenous features must have their shapes - prefixed by the shape of the value corresponding to the `TIMES` key. + has shape [series length] (univariate) or [series length x number of + features] (multivariate). Batch of sequences; `TIMES` is a vector of + shape [batch size x series length], `VALUES` has shape [batch size x + series length] or [batch size x series length x number of features]. In + any case, `VALUES` and any exogenous features must have their shapes + prefixed by the shape of the value corresponding to the `TIMES` key. 
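Editor's sketch (hypothetical, not part of this diff): pairing a reader with the two `TimeSeriesInputFn`s reformatted above; the window and batch sizes are illustrative, and `csv_reader` refers to the earlier sketch.

```python
# Hypothetical sketch, not part of the diff: wrapping a TimeSeriesReader in
# the input_fns defined in this file. Sizes are illustrative only.
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline

train_input_fn = input_pipeline.RandomWindowInputFn(
    time_series_reader=csv_reader,  # e.g. the CSVReader sketched above
    window_size=16,  # contiguous steps per window (truncated backprop length)
    batch_size=64)   # windows per training batch

# For computing end state or evaluating over the full series at once:
eval_input_fn = input_pipeline.WholeDatasetInputFn(csv_reader)
```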
require_single_batch: If True, raises an error if the provided data has a batch dimension > 1. + Returns: A dictionary with features normalized to have shapes prefixed with [batch size x series length]. The sizes of dimensions which were omitted in the @@ -837,8 +859,8 @@ def _canonicalize_numpy_data(data, require_single_batch): # Add trivial batch and time dimensions for every feature features = {key: value[None, None, ...] for key, value in features.items()} if len(times.shape) == 1: # shape [series length] - if len(features[feature_keys.TrainEvalFeatures.VALUES] - .shape) == 1: # shape [series length] + if len(features[feature_keys.TrainEvalFeatures.VALUES].shape + ) == 1: # shape [series length] # Add a feature dimension (with one feature) features[feature_keys.TrainEvalFeatures.VALUES] = features[ feature_keys.TrainEvalFeatures.VALUES][..., None] @@ -853,8 +875,8 @@ def _canonicalize_numpy_data(data, require_single_batch): features[feature_keys.TrainEvalFeatures.VALUES].shape)) # Add trivial batch dimensions for every feature features = {key: value[None, ...] for key, value in features.items()} - elif len(features[feature_keys.TrainEvalFeatures.TIMES] - .shape) != 2: # shape [batch size, series length] + elif len(features[feature_keys.TrainEvalFeatures.TIMES].shape + ) != 2: # shape [batch size, series length] raise ValueError( ("Got an unexpected number of dimensions for times. Was expecting at " "most two ([batch size, series length]), but got shape {}.").format( diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index a8cd4287e00..53995c15926 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -71,7 +71,7 @@ class TimeSeriesModel(object): `tf.feature_column.embedding_column`) corresponding to exogenous features which provide extra information to the model but are not part of the series to be predicted. Passed to - `tf.feature_column.input_layer`. + `tf.compat.v1.feature_column.input_layer`. dtype: The floating point datatype to use. """ if exogenous_feature_columns: diff --git a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py index 0461abdc19c..27c8bfe653d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py @@ -31,7 +31,9 @@ from tensorflow.contrib.timeseries.python.timeseries import model_utils as _mode from tensorflow.python.util.all_util import remove_undocumented -def _colate_features_to_feeds_and_fetches(signature, features, graph, +def _colate_features_to_feeds_and_fetches(signature, + features, + graph, continue_from=None): """Uses a saved model signature to construct feed and fetch dictionaries.""" if continue_from is None: @@ -71,13 +73,13 @@ def predict_continuation(continue_from, Args: continue_from: A dictionary containing the results of either an Estimator's - evaluate method or filter_continuation. Used to determine the model - state to make predictions starting from. + evaluate method or filter_continuation. Used to determine the model state + to make predictions starting from. signatures: The `MetaGraphDef` protocol buffer returned from - `tf.saved_model.loader.load`. Used to determine the names of Tensors to - feed and fetch. Must be from the same model as `continue_from`. 
+ `tf.compat.v1.saved_model.loader.load`. Used to determine the names of + Tensors to feed and fetch. Must be from the same model as `continue_from`. session: The session to use. The session's graph must be the one into which - `tf.saved_model.loader.load` loaded the model. + `tf.compat.v1.saved_model.loader.load` loaded the model. steps: The number of steps to predict (scalar), starting after the evaluation or filtering. If `times` is specified, `steps` must not be; one is required. @@ -92,6 +94,7 @@ def predict_continuation(continue_from, the batch dimension used when creating `continue_from`, and `window_size` is either the `steps` argument or the `window_size` of the `times` argument (depending on which was specified). + Returns: A dictionary with model-specific predictions (typically having keys "mean" and "covariance") and a feature_keys.PredictionResults.TIMES key indicating @@ -129,23 +132,22 @@ def cold_start_filter(signatures, session, features): Args: signatures: The `MetaGraphDef` protocol buffer returned from - `tf.saved_model.loader.load`. Used to determine the names of Tensors to - feed and fetch. Must be from the same model as `continue_from`. + `tf.compat.v1.saved_model.loader.load`. Used to determine the names of + Tensors to feed and fetch. Must be from the same model as `continue_from`. session: The session to use. The session's graph must be the one into which - `tf.saved_model.loader.load` loaded the model. + `tf.compat.v1.saved_model.loader.load` loaded the model. features: A dictionary mapping keys to Numpy arrays, with several possible shapes (requires keys `FilteringFeatures.TIMES` and - `FilteringFeatures.VALUES`): - Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a - vector of length [number of features]. + `FilteringFeatures.VALUES`): Single example; `TIMES` is a scalar and + `VALUES` is either a scalar or a vector of length [number of features]. Sequence; `TIMES` is a vector of shape [series length], `VALUES` either - has shape [series length] (univariate) or [series length x number of - features] (multivariate). - Batch of sequences; `TIMES` is a vector of shape [batch size x series - length], `VALUES` has shape [batch size x series length] or [batch - size x series length x number of features]. - In any case, `VALUES` and any exogenous features must have their shapes - prefixed by the shape of the value corresponding to the `TIMES` key. + has shape [series length] (univariate) or [series length x number of + features] (multivariate). Batch of sequences; `TIMES` is a vector of + shape [batch size x series length], `VALUES` has shape [batch size x + series length] or [batch size x series length x number of features]. In + any case, `VALUES` and any exogenous features must have their shapes + prefixed by the shape of the value corresponding to the `TIMES` key. + Returns: A dictionary containing model state updated to account for the observations in `features`. @@ -156,9 +158,7 @@ def cold_start_filter(signatures, session, features): data=features, require_single_batch=False) output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches( - signature=filter_signature, - features=features, - graph=session.graph) + signature=filter_signature, features=features, graph=session.graph) output = session.run(output_tensors_by_name, feed_dict=feed_dict) # Make it easier to chain filter -> predict by keeping track of the current # time. 
@@ -176,26 +176,25 @@ def filter_continuation(continue_from, signatures, session, features): Args: continue_from: A dictionary containing the results of either an Estimator's - evaluate method or a previous filter step (cold start or - continuation). Used to determine the model state to start filtering from. + evaluate method or a previous filter step (cold start or continuation). + Used to determine the model state to start filtering from. signatures: The `MetaGraphDef` protocol buffer returned from - `tf.saved_model.loader.load`. Used to determine the names of Tensors to - feed and fetch. Must be from the same model as `continue_from`. + `tf.compat.v1.saved_model.loader.load`. Used to determine the names of + Tensors to feed and fetch. Must be from the same model as `continue_from`. session: The session to use. The session's graph must be the one into which - `tf.saved_model.loader.load` loaded the model. + `tf.compat.v1.saved_model.loader.load` loaded the model. features: A dictionary mapping keys to Numpy arrays, with several possible shapes (requires keys `FilteringFeatures.TIMES` and - `FilteringFeatures.VALUES`): - Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a - vector of length [number of features]. + `FilteringFeatures.VALUES`): Single example; `TIMES` is a scalar and + `VALUES` is either a scalar or a vector of length [number of features]. Sequence; `TIMES` is a vector of shape [series length], `VALUES` either - has shape [series length] (univariate) or [series length x number of - features] (multivariate). - Batch of sequences; `TIMES` is a vector of shape [batch size x series - length], `VALUES` has shape [batch size x series length] or [batch - size x series length x number of features]. - In any case, `VALUES` and any exogenous features must have their shapes - prefixed by the shape of the value corresponding to the `TIMES` key. + has shape [series length] (univariate) or [series length x number of + features] (multivariate). Batch of sequences; `TIMES` is a vector of + shape [batch size x series length], `VALUES` has shape [batch size x + series length] or [batch size x series length x number of features]. In + any case, `VALUES` and any exogenous features must have their shapes + prefixed by the shape of the value corresponding to the `TIMES` key. + Returns: A dictionary containing model state updated to account for the observations in `features`. 
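Editor's sketch (hypothetical, not part of this diff): the cold start, filter continuation, and predict chain that these docstrings describe, against a SavedModel loaded with the `tf.compat.v1.saved_model.loader.load` path the diff now references. The export directory and series values are illustrative.

```python
# Hypothetical sketch, not part of the diff: chaining cold-start filtering,
# filtering continuation, and prediction on an exported time series model.
import numpy as np
import tensorflow.compat.v1 as tf

from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils

with tf.Session() as session:
  signatures = tf.saved_model.loader.load(
      session, [tf.saved_model.tag_constants.SERVING], "/tmp/exported_model")
  # Prime model state from an initial chunk of the series (cold start).
  state = saved_model_utils.cold_start_filter(
      signatures=signatures, session=session,
      features={feature_keys.FilteringFeatures.TIMES: np.arange(20),
                feature_keys.FilteringFeatures.VALUES: np.sin(np.arange(20))})
  # Fold in newly observed points without re-reading the earlier data.
  state = saved_model_utils.filter_continuation(
      continue_from=state, signatures=signatures, session=session,
      features={feature_keys.FilteringFeatures.TIMES: np.arange(20, 30),
                feature_keys.FilteringFeatures.VALUES: np.sin(np.arange(20, 30))})
  # Predict the next ten steps starting from the updated state.
  predictions = saved_model_utils.predict_continuation(
      continue_from=state, signatures=signatures, session=session, steps=10)
```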
@@ -217,4 +216,5 @@ def filter_continuation(continue_from, signatures, session, features): _feature_keys.FilteringFeatures.TIMES] return output + remove_undocumented(module_name=__name__) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index cf5e749042a..08eafece5d3 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -39,6 +39,7 @@ py_test( name = "state_space_model_test", timeout = "long", # Moderate but for asan srcs = ["state_space_model_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no_mac", @@ -180,6 +181,7 @@ py_test( name = "structural_ensemble_test", timeout = "long", # Moderate but for asan/tsan/msan timeouts srcs = ["structural_ensemble_test.py"], + python_version = "PY2", shard_count = 4, srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py index 2ecc7eafdaf..d0f9a7df2f9 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py @@ -117,7 +117,7 @@ class StateSpaceModelConfiguration( `tf.feature_column.embedding_column`) corresponding to exogenous features which provide extra information to the model but are not part of the series to be predicted. Passed to - `tf.feature_column.input_layer`. + `tf.compat.v1.feature_column.input_layer`. exogenous_update_condition: A function taking two Tensor arguments `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]) and returning a diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index ee1cd3213ef..a53cf2b86c0 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -20,6 +20,7 @@ package( "//medical/pathology:__subpackages__", "//smartass/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_models:__subpackages__", "//vr/perception:__subpackages__", ], ) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 32858850cdb..a2848f58ebd 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -31,6 +31,7 @@ int main(int argc, char** argv) { int FLAGS_num_tracing_attempts = 3; bool FLAGS_include_dataset_ops = true; int FLAGS_monitoring_level = 0; + bool FLAGS_timestamp = false; int FLAGS_num_queries = 100; std::vector flag_list = { tensorflow::Flag("service_addr", &FLAGS_service_addr, @@ -54,6 +55,9 @@ int main(int argc, char** argv) { "Choose a monitoring level between 1 and 2 to monitor " "your TPU job continuously. Level 2 is more verbose " "than level 1 and shows more metrics."), + tensorflow::Flag("timestamp", &FLAGS_timestamp, + "Set to true to display timestamp in monitoring " + "results."), tensorflow::Flag("num_queries", &FLAGS_num_queries, "This script will run monitoring for num_queries before " "it stops.")}; @@ -102,7 +106,8 @@ int main(int argc, char** argv) { << "ms and show metrics for " << num_queries << " time(s)." 
<< std::endl; tensorflow::profiler::client::StartMonitoring( - FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, num_queries); + FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, + FLAGS_timestamp, num_queries); } else { status = tensorflow::profiler::client::StartTracing( FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list, diff --git a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py index 41aa4d26781..d85aae64871 100644 --- a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py @@ -19,5 +19,5 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.python.tpu._tpu_estimator_embedding import * +from tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/error_handling.py b/tensorflow/contrib/tpu/python/tpu/error_handling.py index 1b1328b4075..9cbb5084a54 100644 --- a/tensorflow/contrib/tpu/python/tpu/error_handling.py +++ b/tensorflow/contrib/tpu/python/tpu/error_handling.py @@ -19,5 +19,5 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.python.tpu.error_handling import * +from tensorflow_estimator.python.estimator.tpu.error_handling import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 0f95afd6db4..01b1b4af339 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -30,7 +30,7 @@ strategy = keras_support.TPUDistributionStrategy(resolver) model = keras_support.tpu_model(model, strategy=strategy) # Only TF optimizers are currently supported. -model.compile(optimizer=tf.train.AdamOptimizer(), ...) +model.compile(optimizer=tf.compat.v1.train.AdamOptimizer(), ...) # `images` and `labels` should be Numpy arrays. Support for tensor input # (e.g. datasets) is planned. @@ -308,10 +308,11 @@ def _cross_replica_concat(tensor, core_id, num_cores, name): '{}.'.format(input_dtype, name)) batch_size = tensor.shape[0] - mask = math_ops.to_float( - math_ops.equal(np.arange(num_cores, dtype=np.int32), core_id)) + mask = math_ops.cast( + math_ops.equal(np.arange(num_cores, dtype=np.int32), core_id), + dtypes.float32) mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims) - result = mask * math_ops.to_float(tensor) + result = mask * math_ops.cast(tensor, dtypes.float32) local_tensor_with_holes = array_ops.reshape(result, [-1] + result.shape.as_list()[2:]) concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes) @@ -1044,29 +1045,29 @@ class TPUFunction(object): # the Momentum optimizer) when _make_train_function is invoked. 
with keras_tpu_variables.replicated_variable_for_optimizer( self._tpu_assignment.num_towers): - self._cloned_model._make_fit_function() + self._cloned_model._make_train_function() else: - self._cloned_model._make_fit_function() + self._cloned_model._make_train_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in self._cloned_model._fit_function.outputs + for tensor in self._cloned_model.train_function.outputs ] return [ - self._cloned_model._fit_function.updates_op, + self._cloned_model.train_function.updates_op, tpu_ops.outfeed_enqueue_tuple( - self._cloned_model._fit_function.outputs, + self._cloned_model.train_function.outputs, name='outfeed-enqueue-train') ] elif is_test: - self._cloned_model._make_eval_function() + self._cloned_model._make_test_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in self._cloned_model._eval_function.outputs + for tensor in self._cloned_model.test_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - self._cloned_model._eval_function.outputs, + self._cloned_model.test_function.outputs, name='outfeed-enqueue-test') ] elif is_predict: @@ -1405,8 +1406,6 @@ class KerasTPUModel(models.Model): self.predict_function = None self.test_function = None self.train_function = None - self._fit_function = None - self._eval_function = None self._stateful_metric_functions = [] cluster_resolver = strategy._tpu_cluster_resolver @@ -1642,7 +1641,7 @@ class KerasTPUModel(models.Model): validation_split=validation_split) # Prepare validation data - val_x, val_y, val_sample_weights = self._prepare_validation_data( + x, y, val_x, val_y, val_sample_weights = self._prepare_validation_data( validation_data, validation_split, validation_steps, x, y, sample_weights, batch_size) return self._pipeline_fit_loop( @@ -1935,7 +1934,7 @@ class KerasTPUModel(models.Model): batch_size: The training batch size (if provided) Returns: - A 3-tuple of (val_x, val_y, val_sample_weights). + A 5-tuple of (x, y, val_x, val_y, val_sample_weights). Raises: ValueError: If the provided arguments are not compatible with @@ -1947,7 +1946,7 @@ class KerasTPUModel(models.Model): # in TPUs. if validation_data: if (isinstance(validation_data, iterator_ops.Iterator) or - isinstance(validation_data, iterator_ops.EagerIterator) or + isinstance(validation_data, iterator_ops.IteratorV2) or isinstance(validation_data, dataset_ops.DatasetV2)): raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator ' 'for validation_data. 
Please instead pass a function ' @@ -1992,7 +1991,7 @@ class KerasTPUModel(models.Model): val_y = None val_sample_weights = None - return val_x, val_y, val_sample_weights + return x, y, val_x, val_y, val_sample_weights def predict(self, x, @@ -2047,21 +2046,6 @@ class KerasTPUModel(models.Model): self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) return self.test_function - def _make_fit_function(self): - if not self._fit_function: - self._fit_function = TPUFunction( - self, - model_fn_lib.ModeKeys.TRAIN, - tpu_assignment=self._tpu_assignment) - - return self._fit_function - - def _make_eval_function(self): - if not self._eval_function: - self._eval_function = TPUFunction( - self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) - return self._eval_function - def _make_predict_function(self): if not self.predict_function: self.predict_function = TPUFunction( @@ -2217,7 +2201,7 @@ def tpu_model(model, strategy=None): strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver) model = keras_support.tpu_model(model, strategy) model.compile( - optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), + optimizer=tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1.0), ...) ``` diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index c36aaa38c0e..2c9bce0bca2 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -19,5 +19,5 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.python.tpu.tpu_config import * +from tensorflow_estimator.python.estimator.tpu.tpu_config import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index b77b010cba6..573f49b2b9b 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -19,5 +19,5 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.python.tpu.tpu_context import * +from tensorflow_estimator.python.estimator.tpu.tpu_context import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 893118412e1..0ee490681e4 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -19,15 +19,15 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import,redefined-builtin -from tensorflow.python.tpu.tpu_estimator import * +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import * # used by tests -from tensorflow.python.tpu.tpu_estimator import _clone_export_output_with_tensors -from tensorflow.python.tpu.tpu_estimator import _create_global_step -from tensorflow.python.tpu.tpu_estimator import _export_output_to_tensors -from tensorflow.python.tpu.tpu_estimator import _get_scaffold -from tensorflow.python.tpu.tpu_estimator import _Inputs -from tensorflow.python.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR -from tensorflow.python.tpu.tpu_estimator import _TPU_ENQUEUE_OPS -from tensorflow.python.tpu.tpu_estimator import _TPU_ESTIMATOR -from tensorflow.python.tpu.tpu_estimator import _TPU_TRAIN_OP +from 
tensorflow_estimator.python.estimator.tpu.tpu_estimator import _clone_export_output_with_tensors +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _create_global_step +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _export_output_to_tensors +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _get_scaffold +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _Inputs +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ENQUEUE_OPS +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ESTIMATOR +from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_TRAIN_OP # pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/util.py b/tensorflow/contrib/tpu/python/tpu/util.py index 8d9b70d46eb..6e0da240466 100644 --- a/tensorflow/contrib/tpu/python/tpu/util.py +++ b/tensorflow/contrib/tpu/python/tpu/util.py @@ -19,5 +19,5 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.python.tpu.util import * +from tensorflow_estimator.python.estimator.tpu.util import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 5bc4c3b88ef..8f1d5ce2fdf 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -70,6 +70,7 @@ py_test( name = "device_setter_test", size = "small", srcs = ["python/training/device_setter_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":training_py", @@ -85,6 +86,7 @@ py_test( name = "sequence_queueing_state_saver_test", size = "medium", srcs = ["python/training/sequence_queueing_state_saver_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":training_py", @@ -103,6 +105,7 @@ py_test( name = "batch_sequences_with_states_test", size = "medium", srcs = ["python/training/batch_sequences_with_states_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["manual"], deps = [ @@ -126,6 +129,7 @@ py_test( name = "feeding_queue_runner_test", size = "medium", srcs = ["python/training/feeding_queue_runner_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow/python:client_testlib", @@ -141,6 +145,7 @@ py_test( name = "hparam_test", size = "small", srcs = ["python/training/hparam_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":training_py", @@ -152,6 +157,7 @@ py_test( name = "resample_test", size = "small", srcs = ["python/training/resample_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":training_py", @@ -169,6 +175,7 @@ py_test( name = "sampling_ops_test", size = "small", srcs = ["python/training/sampling_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":training_py", @@ -192,6 +199,7 @@ py_test( name = "sampling_ops_threading_test", size = "small", srcs = ["python/training/sampling_ops_threading_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "manual", @@ -214,6 +222,7 @@ py_test( name = "bucket_ops_test", size = "medium", srcs = ["python/training/bucket_ops_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["manual"], deps = [ @@ -234,6 +243,7 @@ py_test( name = "evaluation_test", size = "small", srcs = 
["python/training/evaluation_test.py"], + python_version = "PY2", shard_count = 3, srcs_version = "PY2AND3", tags = [ @@ -266,6 +276,7 @@ py_test( name = "training_test", size = "medium", srcs = ["python/training/training_test.py"], + python_version = "PY2", shard_count = 8, srcs_version = "PY2AND3", tags = ["notsan"], diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index e7f23edc901..10f3f88f3eb 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -400,7 +400,7 @@ def bucket_by_sequence_length(input_length, math_ops.less_equal(buckets_min, input_length), math_ops.less(input_length, buckets_max)) which_bucket = math_ops.reduce_min(array_ops.where(conditions_c)) - which_bucket = math_ops.to_int32(which_bucket) + which_bucket = math_ops.cast(which_bucket, dtypes.int32) if shapes is not None: shapes = [tensor_shape.scalar()] + shapes diff --git a/tensorflow/contrib/training/python/training/device_setter.py b/tensorflow/contrib/training/python/training/device_setter.py index 231fc5788f6..8513aef02fd 100644 --- a/tensorflow/contrib/training/python/training/device_setter.py +++ b/tensorflow/contrib/training/python/training/device_setter.py @@ -71,7 +71,7 @@ class GreedyLoadBalancingStrategy(object): off CPU-intensive ops with RAM-intensive ops with network bandwidth. This class is intended to be used as a `ps_strategy` in - `tf.train.replica_device_setter`. + `tf.compat.v1.train.replica_device_setter`. """ def __init__(self, num_tasks, load_fn): diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index 16a647bf668..e51854ce159 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -36,13 +36,13 @@ out the metrics values to stdout: # Choose the metrics to compute: names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({ - "accuracy": tf.metrics.accuracy(labels, predictions), - "mse": tf.metrics.mean_squared_error(labels, predictions), + "accuracy": tf.compat.v1.metrics.accuracy(labels, predictions), + "mse": tf.compat.v1.metrics.mean_squared_error(labels, predictions), }) # Define the summaries to write: for metric_name, metric_value in metrics_to_values.iteritems(): - tf.summary.scalar(metric_name, metric_value) + tf.compat.v1.summary.scalar(metric_name, metric_value) checkpoint_dir = '/tmp/my_model_dir/' log_dir = '/tmp/my_model_eval/' @@ -80,13 +80,13 @@ more summaries and call the evaluate_repeatedly method: # Choose the metrics to compute: names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({ - "accuracy": tf.metrics.accuracy(labels, predictions), - "mse": tf.metrics.mean_squared_error(labels, predictions), + "accuracy": tf.compat.v1.metrics.accuracy(labels, predictions), + "mse": tf.compat.v1.metrics.mean_squared_error(labels, predictions), }) # Define the summaries to write: for metric_name, metric_value in metrics_to_values.iteritems(): - tf.summary.scalar(metric_name, metric_value) + tf.compat.v1.summary.scalar(metric_name, metric_value) checkpoint_dir = '/tmp/my_model_dir/' log_dir = '/tmp/my_model_eval/' @@ -116,8 +116,8 @@ with only summaries. The user need only leave out the 'eval_ops' argument: predictions = MyModel(images) # Define the summaries to write: - tf.summary.scalar(...) - tf.summary.histogram(...) 
+ tf.compat.v1.summary.scalar(...) + tf.compat.v1.summary.histogram(...) checkpoint_dir = '/tmp/my_model_dir/' log_dir = '/tmp/my_model_eval/' @@ -180,7 +180,7 @@ def wait_for_new_checkpoint(checkpoint_dir, a checkpoint for the first time. seconds_to_sleep: The number of seconds to sleep for before looking for a new checkpoint. - timeout: The maximum amount of time to wait. If left as `None`, then the + timeout: The maximum number of seconds to wait. If left as `None`, then the process will wait indefinitely. Returns: @@ -232,8 +232,8 @@ def checkpoints_iterator(checkpoint_dir, checkpoint_dir: The directory in which checkpoints are saved. min_interval_secs: The minimum number of seconds between yielding checkpoints. - timeout: The maximum amount of time to wait between checkpoints. If left as - `None`, then the process will wait indefinitely. + timeout: The maximum number of seconds to wait between checkpoints. If left + as `None`, then the process will wait indefinitely. timeout_fn: Optional function to call after a timeout. If the function returns True, then it means that no new checkpoints will be generated and the iterator will exit. The function is called with no arguments. @@ -277,7 +277,8 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook): Args: log_dir: The directory where the summary events are saved to. Used only when `summary_writer` is not specified. - summary_writer: A `tf.summary.FileWriter` to write summary events with. + summary_writer: A `tf.compat.v1.summary.FileWriter` to write summary + events with. summary_op: The summary op to run. If left as `None`, then all summaries in the tf.GraphKeys.SUMMARIES collection are used. feed_dict: An optional feed dictionary to use when evaluating the @@ -380,26 +381,26 @@ def evaluate_repeatedly(checkpoint_dir, Args: checkpoint_dir: The directory where checkpoints are stored. master: The address of the TensorFlow master. - scaffold: An tf.train.Scaffold instance for initializing variables and - restoring variables. Note that `scaffold.init_fn` is used by the function - to restore the checkpoint. If you supply a custom init_fn, then it must - also take care of restoring the model from its checkpoint. - eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names - to `Tensors`, which is run until the session is requested to stop, - commonly done by a `tf.contrib.training.StopAfterNEvalsHook`. + scaffold: An tf.compat.v1.train.Scaffold instance for initializing variables + and restoring variables. Note that `scaffold.init_fn` is used by the + function to restore the checkpoint. If you supply a custom init_fn, then + it must also take care of restoring the model from its checkpoint. + eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to + `Tensors`, which is run until the session is requested to stop, commonly + done by a `tf.contrib.training.StopAfterNEvalsHook`. feed_dict: The feed dictionary to use when executing the `eval_ops`. final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to `Tensors`. final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`. eval_interval_secs: The minimum number of seconds between evaluations. - hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the - evaluation loop. - config: An instance of `tf.ConfigProto` that will be used to + hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside + the evaluation loop. 
+ config: An instance of `tf.compat.v1.ConfigProto` that will be used to configure the `Session`. If left as `None`, the default will be used. max_number_of_evaluations: The maximum times to run the evaluation. If left as `None`, then evaluation runs indefinitely. - timeout: The maximum amount of time to wait between checkpoints. If left as - `None`, then the process will wait indefinitely. + timeout: The maximum number of seconds to wait between checkpoints. If left + as `None`, then the process will wait indefinitely. timeout_fn: Optional function to call after a timeout. If the function returns True, then it means that no new checkpoints will be generated and the iterator will exit. The function is called with no arguments. @@ -445,14 +446,14 @@ def evaluate_repeatedly(checkpoint_dir, with monitored_session.MonitoredSession( session_creator=session_creator, hooks=hooks) as session: - logging.info('Starting evaluation at ' + time.strftime( - '%Y-%m-%d-%H:%M:%S', time.gmtime())) + logging.info('Starting evaluation at ' + + time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime())) if eval_ops is not None: while not session.should_stop(): session.run(eval_ops, feed_dict) - logging.info('Finished evaluation at ' + time.strftime( - '%Y-%m-%d-%H:%M:%S', time.gmtime())) + logging.info('Finished evaluation at ' + + time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime())) num_evaluations += 1 if (max_number_of_evaluations is not None and diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index cb0a25f333b..7c9e10105af 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -161,6 +161,11 @@ def _cast_to_type_if_compatible(name, param_type, value): "Could not cast hparam '%s' of type '%s' from value %r" % (name, param_type, value)) + # If `value` is already of type `param_type`, return it directly. + # `isinstance` is too weak (e.g. isinstance(True, int) == True). + if type(value) == param_type: # pylint: disable=unidiomatic-typecheck + return value + # Some callers use None, for which we can't do any casting/checking. :( if issubclass(param_type, type(None)): return value @@ -545,7 +550,7 @@ class HParams(object): ValueError: If `values` cannot be parsed or a hyperparameter in `values` doesn't exist. 
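Editor's note (illustrative, not from the diff): the early return added to `_cast_to_type_if_compatible` above exists because `isinstance` is too permissive for this purpose; Python's `bool` is a subclass of `int`, so only an exact type comparison keeps a value that is already of the declared type from being coerced or misclassified.

```python
# Illustration only, not part of the diff: why the strict type check is used.
print(isinstance(True, int))  # True  -- isinstance treats bool as an int
print(type(True) == int)      # False -- the exact check added in this diff
```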
""" - type_map = dict() + type_map = {} for name, t in self._hparam_types.items(): param_type, _ = t type_map[name] = param_type @@ -658,10 +663,13 @@ class HParams(object): return key in self._hparam_types def __str__(self): - return str(sorted(self.values().items())) + hpdict = self.values() + output_list = ['{}={}'.format(key, hpdict[key]) for key in hpdict] + return ','.join(output_list) def __repr__(self): - return '%s(%s)' % (type(self).__name__, self.__str__()) + strval = str(sorted(self.values().items())) + return '%s(%s)' % (type(self).__name__, strval) @staticmethod def _get_kind_name(param_type, is_list): diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index a990e04711c..1877d2b38ac 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -39,13 +39,16 @@ class HParamsTest(test.TestCase): def testSomeValues(self): hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d') - self.assertDictEqual( - {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'}, - hparams.values()) - expected_str = ('[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'), ' - '(\'d\', \'/a/b=c/d\')]') - self.assertEqual(expected_str, str(hparams.__str__())) - self.assertEqual(expected_str, str(hparams)) + self.assertDictEqual({ + 'aaa': 1, + 'b': 2.0, + 'c_c': 'relu6', + 'd': '/a/b=c/d' + }, hparams.values()) + expected_str = ('HParams([(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'),' + ' (\'d\', \'/a/b=c/d\')])') + self.assertEqual(expected_str, repr(hparams)) + self.assertEqual(expected_str, repr(hparams)) self.assertEqual(1, hparams.aaa) self.assertEqual(2.0, hparams.b) self.assertEqual('relu6', hparams.c_c) @@ -73,8 +76,12 @@ class HParamsTest(test.TestCase): self.assertEqual('relu4', hparams.c_c) self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=,b=0,') - self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'}, - hparams.values()) + self.assertDictEqual({ + 'aaa': 12, + 'b': 0, + 'c_c': '', + 'd': '/a/b=c/d' + }, hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(0.0, hparams.b) self.assertEqual('', hparams.c_c) @@ -140,8 +147,11 @@ class HParamsTest(test.TestCase): hparams = hparam.HParams(x=1, b=2.0, d=[0.5]) hparams.override_from_dict({'d': [0.1, 0.2, 0.3]}) - self.assertDictEqual({'d': [0.1, 0.2, 0.3], 'x': 1, 'b': 2.0}, - hparams.values()) + self.assertDictEqual({ + 'd': [0.1, 0.2, 0.3], + 'x': 1, + 'b': 2.0 + }, hparams.values()) def testBoolParsing(self): for value in 'true', 'false', 'True', 'False', '1', '0': @@ -209,6 +219,21 @@ class HParamsTest(test.TestCase): self.assertEqual([1.0], hparams2.b) self.assertEqual(['_12', '3\'4"'], hparams2.c_c) + def testStr(self): + hparam1 = hparam.HParams(a=1, b=[2.0, 3.0], c='relu6') + hparam1_str = str(hparam1) + # Create the signature + hparam2 = hparam.HParams() + hparam2.add_hparam('a', 4) + hparam2.add_hparam('b', [5.0, 6.0]) + hparam2.add_hparam('c', 'relu10') + # Load from string + hparam2.parse(hparam1_str) + # Verifies all hparams are restored + self.assertEqual(hparam2.a, hparam1.a) + self.assertEqual(hparam2.b, hparam1.b) + self.assertEqual(hparam2.c, hparam1.c) + def testParseValuesWithIndexAssigment1(self): """Assignment to an index position.""" parse_dict = hparam.parse_values('arr[1]=10', {'arr': int}) @@ -241,9 +266,10 @@ class HParamsTest(test.TestCase): def testParseValuesWithIndexAssigment3(self): """Assignment to index 
positions in multiple names.""" - parse_dict = hparam.parse_values('arr[0]=10,arr[1]=20,L[5]=100,L[10]=200', - {'arr': int, - 'L': int}) + parse_dict = hparam.parse_values('arr[0]=10,arr[1]=20,L[5]=100,L[10]=200', { + 'arr': int, + 'L': int + }) self.assertEqual(len(parse_dict), 2) self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {0: 10, 1: 20}) @@ -253,8 +279,11 @@ class HParamsTest(test.TestCase): def testParseValuesWithIndexAssigment3_IgnoreUnknown(self): """Assignment to index positions in multiple names.""" parse_dict = hparam.parse_values( - 'arr[0]=10,C=5,arr[1]=20,B[0]=kkk,L[5]=100,L[10]=200', - {'arr': int, 'L': int}, ignore_unknown=True) + 'arr[0]=10,C=5,arr[1]=20,B[0]=kkk,L[5]=100,L[10]=200', { + 'arr': int, + 'L': int + }, + ignore_unknown=True) self.assertEqual(len(parse_dict), 2) self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {0: 10, 1: 20}) @@ -263,10 +292,11 @@ class HParamsTest(test.TestCase): def testParseValuesWithIndexAssigment4(self): """Assignment of index positions and scalars.""" - parse_dict = hparam.parse_values('x=10,arr[1]=20,y=30', - {'x': int, - 'y': int, - 'arr': int}) + parse_dict = hparam.parse_values('x=10,arr[1]=20,y=30', { + 'x': int, + 'y': int, + 'arr': int + }) self.assertEqual(len(parse_dict), 3) self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {1: 20}) @@ -276,8 +306,12 @@ class HParamsTest(test.TestCase): def testParseValuesWithIndexAssigment4_IgnoreUnknown(self): """Assignment of index positions and scalars.""" parse_dict = hparam.parse_values( - 'x=10,foo[0]=bar,arr[1]=20,zzz=78,y=30', - {'x': int, 'y': int, 'arr': int}, ignore_unknown=True) + 'x=10,foo[0]=bar,arr[1]=20,zzz=78,y=30', { + 'x': int, + 'y': int, + 'arr': int + }, + ignore_unknown=True) self.assertEqual(len(parse_dict), 3) self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {1: 20}) @@ -305,8 +339,12 @@ class HParamsTest(test.TestCase): def testParseValuesWithIndexAssigment5_IgnoreUnknown(self): """Different variable types.""" parse_dict = hparam.parse_values( - 'a[0]=5,cc=4,b[1]=true,c[2]=abc,mm=2,d[3]=3.14', - {'a': int, 'b': bool, 'c': str, 'd': float}, + 'a[0]=5,cc=4,b[1]=true,c[2]=abc,mm=2,d[3]=3.14', { + 'a': int, + 'b': bool, + 'c': str, + 'd': float + }, ignore_unknown=True) self.assertEqual(set(parse_dict.keys()), {'a', 'b', 'c', 'd'}) self.assertTrue(isinstance(parse_dict['a'], dict)) @@ -404,9 +442,8 @@ class HParamsTest(test.TestCase): self.assertEqual('{"aaa"=123}', hparams3.to_json(separators=(';', '='))) hparams4 = hparam.HParams(aaa=123, b='hello', c_c=False) - self.assertEqual( - '{"aaa": 123, "b": "hello", "c_c": false}', - hparams4.to_json(sort_keys=True)) + self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}', + hparams4.to_json(sort_keys=True)) def testSetHParam(self): hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True) @@ -454,6 +491,26 @@ class HParamsTest(test.TestCase): with self.assertRaises(ValueError): hparams.set_hparam('bool_', 1) + # Unfortunately there is no automagic conversion of bool-like strings to + # bool. 
+ with self.assertRaises(ValueError): + hparams.set_hparam('bool_', 'true') + + with self.assertRaises(ValueError): + hparams.set_hparam('bool_', 'True') + + with self.assertRaises(ValueError): + hparams.set_hparam('bool_', 'false') + + with self.assertRaises(ValueError): + hparams.set_hparam('bool_', 'False') + + with self.assertRaises(ValueError): + hparams.set_hparam('bool_', '0') + + with self.assertRaises(ValueError): + hparams.set_hparam('bool_', '1') + with self.assertRaises(ValueError): hparams.set_hparam('int_', 2.2) @@ -470,6 +527,20 @@ class HParamsTest(test.TestCase): hparams.set_hparam('none', '1') self.assertEqual('1', hparams.none) + def testSetHParamExactTypeMatch(self): + + class DummyContext(object): + + def __init__(self, a, b=0): + self.a = a + self.b = b + + hparams = hparam.HParams(x=DummyContext(a=100, b=100)) + # Verify x is assigned directly, without casting. + hparams.set_hparam('x', DummyContext(a=100, b=100)) + self.assertEqual(hparams.x.a, 100) + self.assertEqual(hparams.x.b, 100) + def testNonProtoFails(self): with self.assertRaisesRegexp(AssertionError, ''): hparam.HParams(hparam_def=1) diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index 53e4f23a7cd..e44c4f8c0ef 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -673,7 +673,7 @@ class SequenceQueueingStateSaver(object): batch_size = 32 num_unroll = 20 lstm_size = 8 - cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size) + cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_units=lstm_size) initial_state_values = tf.zeros(cell.state_size, dtype=tf.float32) raw_data = get_single_input_from_input_reader() @@ -702,12 +702,12 @@ class SequenceQueueingStateSaver(object): state_name="lstm_state") # Start a prefetcher in the background - sess = tf.Session() + sess = tf.compat.v1.Session() num_threads = 3 - queue_runner = tf.train.QueueRunner( + queue_runner = tf.compat.v1.train.QueueRunner( stateful_reader, [stateful_reader.prefetch_op] * num_threads) - tf.train.add_queue_runner(queue_runner) - tf.train.start_queue_runners(sess=session) + tf.compat.v1.train.add_queue_runner(queue_runner) + tf.compat.v1.train.start_queue_runners(sess=session) while True: # Step through batches, perform training or inference... @@ -1320,7 +1320,7 @@ def batch_sequences_with_states(input_key, num_unroll = 20 num_enqueue_threads = 3 lstm_size = 8 - cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size) + cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_units=lstm_size) key, sequences, context = my_parser(raw_data) initial_state_values = tf.zeros((state_size,), dtype=tf.float32) @@ -1349,9 +1349,9 @@ def batch_sequences_with_states(input_key, state_name="lstm_state") # Start a prefetcher in the background - sess = tf.Session() + sess = tf.compat.v1.Session() - tf.train.start_queue_runners(sess=session) + tf.compat.v1.train.start_queue_runners(sess=session) while True: # Step through batches, perform training or inference... 
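A note on the docstring examples updated in the hunks above: they create `sess = tf.compat.v1.Session()` but then call `tf.compat.v1.train.start_queue_runners(sess=session)`, so the session variable names do not match (the mismatch predates this change). Below is a minimal, illustrative sketch of the same queue-runner startup pattern with one consistent name, using a plain `FIFOQueue` as a stand-in for the state saver's prefetch queue; it is not part of the diff, and the queue and enqueue op here are hypothetical examples rather than the `SequenceQueueingStateSaver` API.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # required when running under TF 2.x

# A small FIFOQueue stands in for the state saver's prefetch queue
# (`stateful_reader.prefetch_op` in the docstring above).
queue = tf.FIFOQueue(capacity=10, dtypes=[tf.float32])
enqueue_op = queue.enqueue([tf.random_uniform([])])

num_threads = 3
sess = tf.Session()
queue_runner = tf.train.QueueRunner(queue, [enqueue_op] * num_threads)
tf.train.add_queue_runner(queue_runner)
# Pass the session object that was actually created above.
tf.train.start_queue_runners(sess=sess)
print(sess.run(queue.dequeue()))
```

The `tf.compat.v1` spellings used throughout these hunks keep the examples working after the 2.0 symbol removals; the queue-runner machinery itself is deprecated in TF 2.x in favor of `tf.data`, so the compatibility aliases are the only spellings that remain valid.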
@@ -1597,7 +1597,7 @@ def _padding(sequences, num_unroll): else: # Only have SparseTensors sparse_lengths = [value.dense_shape[0] for value in sequences_dict.values() if isinstance(value, sparse_tensor.SparseTensor)] - length = math_ops.reduce_max(math_ops.to_int32(sparse_lengths)) + length = math_ops.reduce_max(math_ops.cast(sparse_lengths, dtypes.int32)) unroll = array_ops.constant(num_unroll) padded_length = length + ((unroll - (length % unroll)) % unroll) @@ -1620,8 +1620,9 @@ def _padding(sequences, num_unroll): # 3. concat values with paddings padded_sequences[key] = array_ops.concat([value, paddings], 0) else: - padded_shape = array_ops.concat([[math_ops.to_int64(padded_length)], - value.dense_shape[1:]], 0) + padded_shape = array_ops.concat( + [[math_ops.cast(padded_length, dtypes.int64)], value.dense_shape[1:]], + 0) padded_sequences[key] = sparse_tensor.SparseTensor( indices=value.indices, values=value.values, @@ -1834,8 +1835,8 @@ def _reconstruct_sparse_tensor_seq(sequence, Returns: A SparseTensor with a +1 higher rank than the input. """ - idx_batch = math_ops.to_int64( - math_ops.floor(sp_tensor.indices[:, 0] / num_unroll)) + idx_batch = math_ops.cast( + math_ops.floor(sp_tensor.indices[:, 0] / num_unroll), dtypes.int64) idx_time = math_ops.mod(sp_tensor.indices[:, 0], num_unroll) indices = array_ops.concat( [ diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py index 8932b905c91..15dc1622054 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py @@ -366,7 +366,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): update_2 = next_batch.save_state("state2", -1 + next_batch.state("state2")) - original_values = dict() + original_values = {} def insert(which): for i in range(20): diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py index ed0f398e30a..7f1cf3c8d33 100644 --- a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py +++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py @@ -90,7 +90,7 @@ def sgdr_decay(learning_rate, global_step, initial_period_steps, initial_period_steps=10000, t_mul=2, m_mul=0.5) # Passing global_step to minimize() will increment it at each step. learning_step = ( - tf.train.GradientDescentOptimizer(learning_rate) + tf.compat.v1.train.GradientDescentOptimizer(learning_rate) .minimize(...my loss..., global_step=global_step) ) diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index 4ceb6e9350f..36d6e828476 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -33,7 +33,8 @@ to user-specified arguments. 
total_loss = tf.contrib.losses.get_total_loss() # Define the optimizer: - optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) + optimizer = tf.compat.v1.train.MomentumOptimizer(FLAGS.learning_rate, + FLAGS.momentum) # Create the train_op train_op = tf.contrib.training.create_train_op(total_loss, optimizer) @@ -108,8 +109,8 @@ default update ops or simply add additional update ops to the update_ops=my_other_update_ops) # Use a set of update ops in addition to the default updates: - tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update0) - tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1) + tf.compat.v1.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update0) + tf.compat.v1.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1) train_op = tf.contrib.training.create_train_op( total_loss, @@ -119,7 +120,7 @@ default update ops or simply add additional update ops to the train_op = tf.contrib.training.create_train_op( total_loss, optimizer, - update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS)) + update_ops=tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS)) ****************************************** * Initializing a model from a checkpoint * @@ -318,8 +319,10 @@ def clip_gradient_norms(gradients_to_variables, max_norm): def clip_gradient_norms_fn(max_norm): """Returns a `transform_grads_fn` function for gradient clipping.""" + def clip_norms(gradients_to_variables): return clip_gradient_norms(gradients_to_variables, max_norm) + return clip_norms @@ -387,10 +390,10 @@ def create_train_op(total_loss, update_ops: An optional list of updates to execute. If `update_ops` is `None`, then the update ops are set to the contents of the `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but - it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`, - a warning will be displayed. + it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`, a + warning will be displayed. variables_to_train: an optional list of variables to train. If None, it will - default to all tf.trainable_variables(). + default to all tf.compat.v1.trainable_variables(). transform_grads_fn: A function which takes a single argument, a list of gradient to variable pairs (tuples), performs any requested gradient updates, such as gradient clipping or multipliers, and returns the updated @@ -427,10 +430,11 @@ def create_train_op(total_loss, total_loss = control_flow_ops.with_dependencies([barrier], total_loss) if variables_to_train is None: - # Default to tf.trainable_variables() + # Default to tf.compat.v1.trainable_variables() variables_to_train = tf_variables.trainable_variables() else: - # Make sure that variables_to_train are in tf.trainable_variables() + # Make sure that variables_to_train are in + # tf.compat.v1.trainable_variables() for v in variables_to_train: assert v.trainable or v in tf_variables.trainable_variables() @@ -494,11 +498,11 @@ def train(train_op, master: The URL of the master. is_chief: Specifies whether or not the training is being run by the primary replica during replica training. - scaffold: An tf.train.Scaffold instance. - hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the - training loop. - chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run - inside the training loop for the chief trainer only. + scaffold: An tf.compat.v1.train.Scaffold instance. + hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside + the training loop. 
+ chief_only_hooks: List of `tf.estimator.SessionRunHook` instances which are + run inside the training loop for the chief trainer only. save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved using a default checkpoint saver. If `save_checkpoint_secs` is set to `None`, then the default checkpoint saver isn't used. @@ -506,11 +510,11 @@ def train(train_op, summaries are written to disk using a default summary saver. If `save_summaries_steps` is set to `None`, then the default summary saver isn't used. - config: An instance of `tf.ConfigProto`. - max_wait_secs: Maximum time workers should wait for the session to - become available. This should be kept relatively short to help detect - incorrect code, but sometimes may need to be increased if the chief takes - a while to start up. + config: An instance of `tf.compat.v1.ConfigProto`. + max_wait_secs: Maximum time workers should wait for the session to become + available. This should be kept relatively short to help detect incorrect + code, but sometimes may need to be increased if the chief takes a while to + start up. run_metadata: A [`RunMetadata`] protocol buffer. Returns: diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc index af29abd91fe..0f92ed3fe78 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc @@ -15,11 +15,8 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include "grpcpp/alarm.h" -#include "grpcpp/grpcpp.h" -#include "grpcpp/server_builder.h" - #include "tensorflow/contrib/verbs/grpc_verbs_service.h" + #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.h b/tensorflow/contrib/verbs/grpc_verbs_service.h index e616778665a..97da84e3128 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service.h @@ -18,6 +18,9 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS +#include "grpcpp/alarm.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/server_builder.h" #include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" #include "tensorflow/contrib/verbs/rdma_mgr.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 9db80f6b573..b77f61e2e7b 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -1086,7 +1086,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed, // The tensor must be copied from GPU to CPU, because either: // 1. The tensor is located on a non GDR compatible GPU. // 2. The tensor's meta-data has changed. - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetGpuHostAllocator(0); copy = Tensor(alloc, in.dtype(), in.shape()); CountCopies(rm_.name_, (void*)DMAHelper::base(&in), (void*)DMAHelper::base(©), in.TotalBytes(), true); @@ -1543,7 +1543,7 @@ bool RdmaTensorRequest::AllocateTensors() { if (mr_ == nullptr) { // Can't RDMA directly to result. Use a proxy. 
proxy_tensor_ = - new Tensor(GPUProcessState::singleton()->GetCUDAHostAllocator(0), + new Tensor(GPUProcessState::singleton()->GetGpuHostAllocator(0), result_tensor_->dtype(), result_tensor_->shape()); rdma_addr_ = DMAHelper::base(proxy_tensor_); mr_ = @@ -1629,12 +1629,13 @@ void RdmaTensorRequest::RecvTensorContent() { CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_), (void*)DMAHelper::base(result_tensor_), result_tensor_->TotalBytes(), false); - GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context, - dst_dev_, result_tensor_, - [this](const Status& s) { - CHECK(s.ok()) << "copy tensor to gpu sync"; - Done(s); - }); + GPUUtil::CopyCPUTensorToGPU( + proxy_tensor_, recv_args_.device_context, dst_dev_, result_tensor_, + [this](const Status& s) { + CHECK(s.ok()) << "copy tensor to gpu sync"; + Done(s); + }, + true /*sync_dst_compute*/); return; } #endif diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 2f237542786..5ac9f46447c 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -277,8 +277,8 @@ void RdmaMgr::InitAllocators() { ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); #if GOOGLE_CUDA - GPUProcessState::singleton()->AddCUDAHostAllocVisitor(0, alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostFreeVisitor(0, free_visitor); + GPUProcessState::singleton()->AddGpuHostAllocVisitor(0, alloc_visitor); + GPUProcessState::singleton()->AddGpuHostFreeVisitor(0, free_visitor); if (IsGDRAvailable()) { // Note we don't free allocated GPU memory so there is no free visitor diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 25fc51b55a2..bcd02aa8410 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -63,7 +63,16 @@ # //tensorflow/tensorflow.bzl) will include the necessary symbols in binary # build targets. 
+package_group( + name = "dependency_whitelist", + packages = [ + "//learning/freud/topic_models/tensorflow/...", + "//quality/webanswers/brain/tokenization/custom_tf_ops/kernels/...", + ], +) + package(default_visibility = [ + ":dependency_whitelist", "//tensorflow:internal", "//tensorflow_models:__subpackages__", ]) @@ -107,7 +116,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") # For platform specific build config load( - "//tensorflow/core:platform/default/build_config.bzl", + ":platform/default/build_config.bzl", "tf_additional_all_protos", "tf_additional_cloud_kernel_deps", "tf_additional_cloud_op_deps", @@ -127,15 +136,18 @@ load( "tf_additional_libdevice_deps", "tf_additional_libdevice_srcs", "tf_additional_minimal_lib_srcs", + "tf_additional_monitoring_hdrs", + "tf_additional_monitoring_srcs", "tf_additional_mpi_lib_defines", + "tf_additional_numa_copts", "tf_additional_numa_deps", "tf_additional_numa_lib_defines", - "tf_additional_numa_copts", "tf_additional_proto_hdrs", "tf_additional_proto_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", "tf_additional_verbs_lib_defines", + "tf_grpc_service_all", "tf_jspb_proto_library", "tf_kernel_tests_linkstatic", "tf_lib_proto_compiler_deps", @@ -149,16 +161,14 @@ load( "tf_protos_grappler", "tf_protos_grappler_impl", "tf_pyclif_proto_library", - "tf_grpc_service_all", ) load( - "//tensorflow/core:platform/default/build_config_root.bzl", + ":platform/default/build_config_root.bzl", "if_dynamic_kernels", "if_static", "tf_cuda_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") load( "//third_party/mkl:build_defs.bzl", @@ -213,6 +223,7 @@ COMMON_PROTO_SRCS = [ "protobuf/tensor_bundle.proto", "protobuf/saver.proto", "protobuf/verifier_config.proto", + "protobuf/trace_events.proto", "util/event.proto", "util/memmapped_file_system.proto", "util/saved_tensor_slice.proto", @@ -234,6 +245,7 @@ ADDITIONAL_CORE_PROTO_SRCS = [ "example/example_parser_configuration.proto", "protobuf/trackable_object_graph.proto", "protobuf/control_flow.proto", + "protobuf/data/experimental/snapshot.proto", # TODO(ebrevdo): Re-enable once CriticalSection is in core. 
# "protobuf/critical_section.proto", "protobuf/meta_graph.proto", @@ -333,7 +345,7 @@ cc_library( hdrs = [":platform_base_hdrs"], copts = tf_copts(), tags = ["avoid_dep"], - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [":__subpackages__"], deps = [ ":lib_platform", "//tensorflow/core/platform/default/build_config:base", @@ -345,7 +357,7 @@ cc_library( hdrs = ["framework/bounds_check.h"], visibility = ["//tensorflow/core/kernels:friends"], deps = [ - "//tensorflow/core:platform_base", + ":platform_base", "//third_party/eigen3", ], ) @@ -392,7 +404,7 @@ cc_library( ":platform_port_internal_hdrs", ], copts = tf_copts() + tf_additional_numa_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [":__subpackages__"], deps = [ ":lib_platform", ":platform_base", @@ -433,7 +445,7 @@ cc_library( ":platform_protobuf_internal_hdrs", ], copts = tf_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [":__subpackages__"], deps = [ ":lib_platform", ":platform_base", @@ -511,8 +523,8 @@ cc_library( ], copts = tf_copts(), visibility = [ + ":__subpackages__", "//tensorflow/c:__subpackages__", - "//tensorflow/core:__subpackages__", ], deps = [ ":error_codes_proto_cc", @@ -548,7 +560,7 @@ cc_library( ":platform_file_system_hdrs", ], copts = tf_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [":__subpackages__"], deps = [ ":lib", ":lib_platform", @@ -565,7 +577,7 @@ cc_library( hdrs = [ "platform/platform_strings.h", ], - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [":__subpackages__"], deps = [":lib"], ) @@ -577,6 +589,7 @@ filegroup( "platform/cpu_feature_guard.h", "platform/error.h", "platform/fingerprint.h", + "platform/monitoring.h", "platform/net.h", "platform/notification.h", "platform/prefetch.h", @@ -588,7 +601,7 @@ filegroup( "platform/stacktrace_handler.h", "platform/strong_hash.h", "platform/subprocess.h", - ], + ] + tf_additional_monitoring_hdrs(), visibility = ["//visibility:private"], ) @@ -633,7 +646,7 @@ cc_library( ":platform_other_internal_hdrs", ], copts = tf_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [":__subpackages__"], deps = [ ":lib", ":lib_platform", @@ -897,6 +910,7 @@ tf_cuda_library( "framework/kernel_def_builder.h", "framework/kernel_def_util.h", "framework/log_memory.h", + "framework/logging.h", "framework/lookup_interface.h", "framework/memory_types.h", "framework/node_def_builder.h", @@ -929,17 +943,19 @@ tf_cuda_library( "framework/tracking_allocator.h", "framework/type_index.h", "framework/type_traits.h", + "framework/typed_allocator.h", "framework/types.h", "public/version.h", "util/activation_mode.h", "util/batch_util.h", "util/bcast.h", - "util/cuda_kernel_helper.h", + "util/matmul_bcast.h", "util/device_name_utils.h", "util/dump_graph.h", "util/events_writer.h", "util/example_proto_fast_parsing.h", "util/example_proto_helper.h", + "util/gpu_kernel_helper.h", "util/guarded_philox_random.h", "util/mirror_pad_mode.h", "util/padding.h", @@ -978,6 +994,32 @@ tf_cuda_library( ], ) +# This is redundant with the "framework" target above. It's useful for +# applications that want to depend on a minimal subset of TensorFlow (e.g. XLA). 
+cc_library( + name = "allocator", + srcs = [ + "framework/allocator.cc", + "framework/allocator_registry.cc", + "framework/allocator_registry.h", + "framework/numeric_types.h", + "framework/tracking_allocator.cc", + "framework/tracking_allocator.h", + "framework/type_traits.h", + ], + hdrs = [ + "framework/allocator.h", + ], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + ":lib", + "//third_party/eigen3", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "stats_calculator_portable", srcs = [ @@ -1140,6 +1182,10 @@ tf_gen_op_libs( "summary_ops", "training_ops", ], + deps = [ + ":lib", + ":protos_all_cc", + ], ) tf_gen_op_libs( @@ -1157,7 +1203,10 @@ tf_gen_op_libs( op_lib_names = [ "array_ops", ], - deps = [":protos_all_cc"], + deps = [ + ":lib", + ":protos_all_cc", + ], ) tf_gen_op_libs( @@ -1226,7 +1275,7 @@ cc_library( srcs = ["ops/word2vec_ops.cc"], linkstatic = 1, visibility = ["//tensorflow:internal"], - deps = ["//tensorflow/core:framework"], + deps = [":framework"], alwayslink = 1, ) @@ -1238,10 +1287,10 @@ cc_library( linkstatic = 1, visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", + ":framework", + ":lib", + ":lib_internal", + ":stream_executor", "//tensorflow/core/kernels:bounds_check_lib", ], alwayslink = 1, @@ -1385,6 +1434,7 @@ cc_library( ":framework", ":lib", ":math_ops_op_lib", + ":protos_all_cc", ], alwayslink = 1, ) @@ -1444,8 +1494,9 @@ cc_library( # This includes implementations of all kernels built into TensorFlow. cc_library( name = "all_kernels_impl", - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [":__subpackages__"], deps = [ + "//tensorflow/c/kernels:bitcast_op", "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", "//tensorflow/core/kernels:batch_kernels", @@ -1484,6 +1535,7 @@ cc_library( "//tensorflow/core/kernels:ragged_ops", "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:stateful_random_ops", + "//tensorflow/core/kernels:random_binomial_op", "//tensorflow/core/kernels:random_poisson_op", "//tensorflow/core/kernels:remote_fused_graph_ops", "//tensorflow/core/kernels:required", @@ -1508,6 +1560,7 @@ cc_library( "//tensorflow/core/kernels/neon:neon_depthwise_conv_op", ]) + if_mkl([ "//tensorflow/core/kernels:mkl_concat_op", + "//tensorflow/core/kernels:mkl_dequantize_op", "//tensorflow/core/kernels:mkl_conv_op", "//tensorflow/core/kernels:mkl_cwise_ops_common", "//tensorflow/core/kernels:mkl_fused_batch_norm_op", @@ -1516,6 +1569,7 @@ cc_library( "//tensorflow/core/kernels:mkl_lrn_op", "//tensorflow/core/kernels:mkl_requantize_ops", "//tensorflow/core/kernels:mkl_pooling_ops", + "//tensorflow/core/kernels:mkl_quantize_op", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_slice_op", @@ -1588,6 +1642,7 @@ cc_library( "framework/function_testlib.h", "framework/shape_inference_testutil.h", "framework/tensor_testutil.h", + "graph/benchmark_testlib.h", "graph/testlib.h", # TODO(josh11b): Drop this once users are depending on # kernels:ops_testutil instead. 
@@ -1615,9 +1670,14 @@ cc_library( ] + if_dynamic_kernels( [], otherwise = [ + "//tensorflow/core/kernels:aggregate_ops", + "//tensorflow/core/kernels:bcast_ops", "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:reduction_ops", + "//tensorflow/core/kernels:reshape_op", ], ), ) @@ -1628,8 +1688,8 @@ cc_library( srcs = ["common_runtime/testlib_ops.cc"], linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", + ":framework", + ":lib", ], alwayslink = 1, ) @@ -1647,6 +1707,13 @@ tf_cuda_library( alwayslink = 1, ) +# ----------------------------------------------------------------------------- +# MKL targets +cc_library( + name = "mkl_graph_util", + hdrs = ["graph/mkl_graph_util.h"], +) + # ----------------------------------------------------------------------------- # Public Android targets @@ -1671,6 +1738,8 @@ filegroup( ":protos_all_proto_text_srcs", ":error_codes_proto_text_srcs", "//tensorflow/core/platform/default/build_config:android_srcs", + "//tensorflow/core/util/ctc:android_srcs", + "//tensorflow/core/profiler:mobile_srcs", ] + glob( [ "client/**/*.cc", @@ -1702,8 +1771,13 @@ filegroup( "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/**/logger.cc", + # Exclude env_time and logging to avoid collisions with + # :platform_base, a common dependency for downstream targets. + "platform/**/env_time.cc", + "platform/**/logging.cc", "platform/default/test_benchmark.*", "platform/cuda.h", + "platform/rocm.h", "platform/google/**/*", "platform/hadoop/**/*", "platform/gif.h", @@ -1727,9 +1801,12 @@ filegroup( filegroup( name = "mobile_srcs_only_runtime", srcs = [ + "//tensorflow/core/common_runtime/eager:srcs", "//tensorflow/core/kernels:android_srcs", "//tensorflow/core/util/ctc:android_srcs", "//tensorflow/core/util/tensor_bundle:android_srcs", + "//tensorflow/c:srcs", + "//tensorflow/c/eager:srcs", ] + glob( [ "common_runtime/**/*.h", @@ -1746,7 +1823,6 @@ filegroup( "**/*testlib*", "**/*main.cc", "common_runtime/gpu/**/*", - "common_runtime/eager/*", "common_runtime/gpu_device_factory.*", "graph/dot.*", ], @@ -1779,7 +1855,7 @@ filegroup( # --host_crosstool_top=@bazel_tools//tools/cpp:toolchain cc_library( name = "android_tensorflow_lib_lite", - srcs = if_android(["//tensorflow/core:android_srcs"]), + srcs = if_android([":android_srcs"]), copts = tf_copts(android_optimization_level_override = None) + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], @@ -1795,6 +1871,7 @@ cc_library( ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", + "@farmhash_archive//:farmhash", "@nsync//:nsync_cpp", "@protobuf_archive//:protobuf", ], @@ -1803,7 +1880,7 @@ cc_library( cc_library( name = "android_tensorflow_lib_lite_nortti", - srcs = if_android(["//tensorflow/core:android_srcs"]), + srcs = if_android([":android_srcs"]), copts = tf_copts(android_optimization_level_override = None) + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ] + tf_opts_nortti_if_android(), @@ -1819,6 +1896,7 @@ cc_library( ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", + "@farmhash_archive//:farmhash", "@nsync//:nsync_cpp", "@protobuf_archive//:protobuf", ], @@ -1828,6 +1906,7 @@ cc_library( cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ + ":platform_base", 
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -1836,10 +1915,9 @@ cc_library( cc_library( name = "emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", - srcs = if_emscripten(["//tensorflow/core:mobile_srcs_no_runtime"]), + srcs = if_emscripten([":mobile_srcs_no_runtime"]), copts = ["-DSUPPORT_SELECTIVE_REGISTRATION"] + tf_opts_nortti_if_emscripten(), defines = ["TENSORFLOW_LITE_PROTOS"], - linkopts = ["-lz"], tags = [ "manual", "notap", @@ -1851,6 +1929,7 @@ cc_library( ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", + "@farmhash_archive//:farmhash", "@nsync//:nsync_cpp", "@zlib_archive//:zlib", ], @@ -1860,7 +1939,7 @@ cc_library( # Native library support for iOS applications. # # bazel build --config=ios_x86_64 \ -# //third_party/tensorflow/core:ios_tensorflow_lib +# :ios_tensorflow_lib cc_library( name = "ios_tensorflow_lib", srcs = if_ios([ @@ -1884,7 +1963,7 @@ cc_library( cc_library( name = "ios_tensorflow_lib_lite", - srcs = if_ios(["//tensorflow/core:android_srcs"]), + srcs = if_ios([":android_srcs"]), copts = tf_copts() + ["-Os"] + ["-std=c++11"], visibility = ["//visibility:public"], deps = [ @@ -1893,6 +1972,7 @@ cc_library( ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", + "@farmhash_archive//:farmhash", "@nsync//:nsync_cpp", "@protobuf_archive//:protobuf", ], @@ -2162,9 +2242,41 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/variable_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/variable.proto", + visibility = ["//visibility:public"], +) + # ----------------------------------------------------------------------------- # Internal targets +tf_proto_library( + name = "autotuning_proto", + srcs = ["protobuf/autotuning.proto"], + cc_api_version = 2, + default_header = True, + provide_cc_alias = True, + visibility = [ + "//tensorflow:internal", + ], +) + +tf_proto_library( + name = "conv_autotuning_proto", + srcs = ["protobuf/conv_autotuning.proto"], + cc_api_version = 2, + default_header = True, + protodeps = [ + "//tensorflow/stream_executor:dnn_proto", + ], + provide_cc_alias = True, + visibility = [ + "//tensorflow:internal", + ], +) + tf_proto_library_cc( name = "worker_proto", srcs = ["protobuf/worker.proto"], @@ -2234,6 +2346,7 @@ LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob( "platform/jpeg.h", "platform/png.h", "platform/**/cuda.h", + "platform/**/rocm.h", "platform/**/stream_executor.h", ], ) @@ -2273,6 +2386,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "platform/denormal.h", "platform/host_info.h", "platform/platform.h", + "platform/monitoring.h", "platform/protobuf_internal.h", "platform/setround.h", "platform/snappy.h", @@ -2333,6 +2447,7 @@ cc_library( "lib/jpeg/**/*", "lib/png/**/*", "platform/**/env_time.cc", + "platform/**/monitoring.cc", "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", "platform/**/logger.cc", @@ -2346,6 +2461,8 @@ cc_library( "**/*test*", "platform/**/cuda.h", "platform/**/cuda_libdevice_path.cc", + "platform/**/rocm.h", + "platform/**/monitoring.cc", "platform/**/stream_executor.h", "platform/**/env_time.cc", "platform/**/device_tracer.cc", @@ -2357,7 +2474,7 @@ cc_library( # Protobuf deps already included through the ":lib_proto_parsing" # dependency. 
tf_additional_proto_srcs(), - ), + ) + tf_additional_monitoring_srcs(), hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), defines = LIB_INTERNAL_DEFINES, @@ -2465,12 +2582,6 @@ cc_library( cc_library( name = "tflite_portable_logging", - srcs = [ - ] + if_ios([ - "platform/default/logging.cc", - "platform/env_time.cc", - "platform/posix/env_time.cc", - ]), hdrs = [ "lib/bfloat16/bfloat16.h", "platform/default/integral_types.h", @@ -2479,10 +2590,11 @@ cc_library( "platform/macros.h", "platform/platform.h", "platform/types.h", - ] + if_windows(["platform/windows/integral_types.h"]) + if_ios(["platform/env_time.h"]), + ], copts = tf_copts(), linkopts = ["-ldl"], deps = [ + ":platform_base", "//tensorflow/core/platform/default/build_config:logging", ], ) @@ -2759,6 +2871,9 @@ tf_cuda_library( exclude = [ "**/*test*", "**/*main.cc", + "framework/allocator.cc", + "framework/allocator_registry.cc", + "framework/tracking_allocator.cc", "example/example_parser_configuration.*", "example/feature_util.cc", "util/reporter.cc", @@ -2790,6 +2905,7 @@ tf_cuda_library( ], }), deps = [ + ":allocator", ":feature_util", ":lib", ":lib_internal", @@ -2802,8 +2918,10 @@ tf_cuda_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/profiler/lib:traceme", "//third_party/eigen3", ] + if_static( extra_deps = ["@protobuf_archive//:protobuf"], @@ -2846,6 +2964,7 @@ tf_cuda_library( srcs = ["platform/stream_executor.h"], hdrs = [ "platform/cuda.h", + "platform/rocm.h", "platform/stream_executor.h", ], deps = [ @@ -2869,7 +2988,9 @@ cc_library( tf_cuda_library( name = "cuda_device_functions", - hdrs = ["util/cuda_device_functions.h"], + hdrs = [ + "util/gpu_device_functions.h", + ], visibility = ["//visibility:public"], deps = [":framework_lite"], ) @@ -2962,14 +3083,13 @@ tf_cuda_library( "common_runtime/scoped_allocator.cc", "common_runtime/scoped_allocator_mgr.cc", "common_runtime/shape_refiner.cc", - "common_runtime/shape_refiner.h", - "framework/versions.h", + "common_runtime/graph_optimizer.h", "graph/graph_constructor.cc", # Depends on common_runtime. "graph/graph_def_builder_util.cc", # Depends on common_runtime. 
"public/session.h", "public/session_options.h", "public/version.h", - ], + ] + CORE_CPU_BASE_HDRS, hdrs = CORE_CPU_BASE_HDRS, copts = tf_copts(), deps = [ @@ -2980,11 +3100,13 @@ tf_cuda_library( ":lib_internal", ":proto_text", ":protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", "//third_party/eigen3", ] + if_static([ ":function_ops_op_lib", ":functional_grad", ":functional_ops_op_lib", + "@com_google_absl//absl/algorithm:container", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:required", ]), @@ -3006,16 +3128,18 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/constant_folding.h", "common_runtime/copy_tensor.h", "common_runtime/costmodel_manager.h", + "common_runtime/placer_inspection_required_ops_utils.h", "common_runtime/debugger_state_interface.h", "common_runtime/device_resolver_local.h", "common_runtime/dma_helper.h", - "common_runtime/eigen_thread_pool.h", "common_runtime/executor.h", "common_runtime/executor_factory.h", "common_runtime/graph_optimizer.h", + "common_runtime/isolate_placer_inspection_required_ops_pass.h", "common_runtime/local_device.h", + "common_runtime/lower_function_call_op.h", "common_runtime/lower_if_op.h", - "common_runtime/lower_if_while.h", + "common_runtime/lower_functional_ops.h", "common_runtime/lower_while_op.h", "common_runtime/memory_types.h", "common_runtime/metrics.h", @@ -3047,9 +3171,7 @@ tf_cuda_library( name = "core_cpu_impl", srcs = [ "common_runtime/accumulate_n_optimizer.cc", - "common_runtime/allocator_retry.cc", "common_runtime/base_collective_executor.cc", - "common_runtime/bfc_allocator.cc", "common_runtime/buf_rendezvous.cc", "common_runtime/build_graph_options.cc", "common_runtime/collective_executor_mgr.cc", @@ -3073,9 +3195,13 @@ tf_cuda_library( "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", "common_runtime/hierarchical_tree_broadcaster.cc", + "common_runtime/inspecting_placer.cc", + "common_runtime/inspecting_placer.h", + "common_runtime/isolate_placer_inspection_required_ops_pass.cc", "common_runtime/local_device.cc", + "common_runtime/lower_function_call_op.cc", + "common_runtime/lower_functional_ops.cc", "common_runtime/lower_if_op.cc", - "common_runtime/lower_if_while.cc", "common_runtime/lower_while_op.cc", "common_runtime/memory_types.cc", "common_runtime/metrics.cc", @@ -3084,6 +3210,8 @@ tf_cuda_library( "common_runtime/parallel_concat_optimizer.cc", "common_runtime/partitioning_utils.cc", "common_runtime/placer.cc", + "common_runtime/placer_inspection_required_ops_utils.cc", + "common_runtime/placer_inspection_required_ops_utils.h", "common_runtime/pool_allocator.cc", "common_runtime/process_function_library_runtime.cc", "common_runtime/process_state.cc", @@ -3098,6 +3226,7 @@ tf_cuda_library( "common_runtime/session_factory.cc", "common_runtime/session_options.cc", "common_runtime/session_state.cc", + "common_runtime/single_threaded_cpu_device.cc", "common_runtime/stats_publisher_interface.cc", "common_runtime/step_stats_collector.cc", "common_runtime/threadpool_device.cc", @@ -3113,6 +3242,7 @@ tf_cuda_library( hdrs = CORE_CPU_LIB_HEADERS, copts = tf_copts(), deps = [ + ":bfc_allocator", ":graph", ":framework", ":framework_internal", @@ -3120,10 +3250,14 @@ tf_cuda_library( ":lib_internal", ":proto_text", ":protos_all_cc", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "//third_party/eigen3", 
"//tensorflow/core/grappler/utils:functions", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/profiler/internal:traceme_recorder", ] + mkl_deps(), alwayslink = 1, ) @@ -3153,6 +3287,7 @@ tf_cuda_library( ":lib", ":proto_text", ":protos_all_cc", + "@com_google_absl//absl/strings", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/clusters:utils", "//tensorflow/core/grappler/clusters:virtual_cluster", @@ -3168,6 +3303,36 @@ tf_cuda_library( alwayslink = 1, ) +# This is redundant with the "core_cpu_*" targets above. It's useful for +# applications that want to depend on a minimal subset of TensorFlow (e.g. XLA). +cc_library( + name = "bfc_allocator", + srcs = [ + "common_runtime/allocator_retry.cc", + "common_runtime/allocator_retry.h", + "common_runtime/bfc_allocator.cc", + ], + hdrs = ["common_runtime/bfc_allocator.h"], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + ":allocator", + ":lib", + ":lib_internal", + ":shared_counter", + ], +) + +cc_library( + name = "shared_counter", + hdrs = ["common_runtime/shared_counter.h"], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + ":lib", + ], +) + cc_library( name = "regexp_internal", hdrs = [ @@ -3191,7 +3356,6 @@ tf_cuda_library( copts = tf_copts(), deps = [ ":core_cpu_internal", - ":device_tracer", ":framework", ":framework_internal", ":graph", @@ -3201,6 +3365,9 @@ tf_cuda_library( ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/profiler/lib:profiler_graph_lib", + "//tensorflow/core/profiler/lib:profiler_session", + "//tensorflow/core/profiler/lib:traceme", ], alwayslink = 1, ) @@ -3227,11 +3394,8 @@ cc_library( tf_cuda_library( name = "device_tracer", srcs = tf_additional_device_tracer_srcs(), - hdrs = [ - "platform/device_tracer.h", - ], copts = tf_copts(), - cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()), + cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(), visibility = [ "//tensorflow:internal", ], @@ -3239,7 +3403,9 @@ tf_cuda_library( ":core_cpu_internal", ":lib", ":protos_all_cc", + "//tensorflow/core/profiler/internal:profiler_interface", ] + tf_additional_device_tracer_deps(), + alwayslink = True, ) tf_proto_library_cc( @@ -3280,16 +3446,17 @@ cc_library( ) GPU_RUNTIME_HEADERS = [ - "common_runtime/gpu/cuda_host_allocator.h", "common_runtime/gpu/gpu_bfc_allocator.h", "common_runtime/gpu/gpu_cudamalloc_allocator.h", "common_runtime/gpu/gpu_debug_allocator.h", "common_runtime/gpu/gpu_device.h", + "common_runtime/gpu/gpu_host_allocator.h", "common_runtime/gpu/gpu_id.h", "common_runtime/gpu/gpu_id_manager.h", "common_runtime/gpu/gpu_id_utils.h", "common_runtime/gpu/gpu_init.h", "common_runtime/gpu/gpu_managed_allocator.h", + "common_runtime/gpu/gpu_mem_allocator.h", "common_runtime/gpu/gpu_process_state.h", "common_runtime/gpu/gpu_stream_util.h", "common_runtime/gpu/gpu_util.h", @@ -3325,6 +3492,7 @@ tf_cuda_library( ":lib_internal", ":protos_all_cc", ":stream_executor", + "//tensorflow/core/profiler/lib:traceme", "//third_party/eigen3", ], alwayslink = 1, @@ -3346,6 +3514,40 @@ tf_cuda_library( ] + if_static([":gpu_runtime_impl"]), ) +# This is redundant with the "gpu_runtime_*" targets above. It's useful for +# applications that want to depend on a minimal subset of TensorFlow (e.g. XLA). 
+tf_cuda_library( + name = "gpu_bfc_allocator", + srcs = [ + "common_runtime/gpu/gpu_bfc_allocator.cc", + ], + hdrs = ["common_runtime/gpu/gpu_bfc_allocator.h"], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + ":bfc_allocator", + ":gpu_mem_allocator", + ":lib", + ":lib_internal", + ":protos_all_cc", + ], +) + +tf_cuda_library( + name = "gpu_mem_allocator", + srcs = [ + "common_runtime/gpu/gpu_id.h", + ], + hdrs = ["common_runtime/gpu/gpu_mem_allocator.h"], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + ":allocator", + ":lib_internal", + ":stream_executor", + ], +) + tf_cuda_library( name = "gpu_init", hdrs = [ @@ -3843,10 +4045,13 @@ tf_cc_tests( "common_runtime/collective_rma_local_test.cc", "common_runtime/device_resolver_local_test.cc", "common_runtime/device_set_test.cc", + "common_runtime/isolate_placer_inspection_required_ops_pass_test.cc", "common_runtime/optimization_registry_test.cc", "common_runtime/pending_counts_test.cc", + "common_runtime/placer_inspection_required_ops_utils_test.cc", "common_runtime/placer_test.cc", "common_runtime/session_test.cc", + "common_runtime/threadpool_device_test.cc", "example/feature_util_test.cc", "framework/allocator_test.cc", "framework/attr_value_util_test.cc", @@ -3903,6 +4108,7 @@ tf_cc_tests( "util/events_writer_test.cc", "util/example_proto_fast_parsing_test.cc", "util/example_proto_helper_test.cc", + "util/matmul_bcast_test.cc", "util/memmapped_file_system_test.cc", "util/presized_cuckoo_map_test.cc", "util/reffed_status_callback_test.cc", @@ -3948,6 +4154,8 @@ tf_cc_tests( "//tensorflow/core/kernels:ops_util", "//third_party/eigen3", "@com_google_absl//absl/base", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -3995,12 +4203,12 @@ tf_cc_test( "ops/cudnn_rnn_ops_test.cc", ], deps = [ - "//tensorflow/core", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", + ":core", + ":framework", + ":lib", + ":test", + ":test_main", + ":testlib", ], ) @@ -4166,11 +4374,13 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", "//tensorflow/core/kernels:mkl_cwise_ops_common", + "//tensorflow/core/kernels:mkl_dequantize_op", "//tensorflow/core/kernels:mkl_fused_batch_norm_op", "//tensorflow/core/kernels:mkl_identity_op", "//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_lrn_op", "//tensorflow/core/kernels:mkl_pooling_ops", + "//tensorflow/core/kernels:mkl_quantize_op", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_slice_op", @@ -4236,6 +4446,7 @@ tf_cc_test_gpu( ":test", ":test_main", ":testlib", + "//tensorflow/core/kernels:cwise_op", ], ) @@ -4271,9 +4482,9 @@ tf_cuda_cc_test( ) tf_cuda_only_cc_test( - name = "util_cuda_kernel_helper_test", + name = "util_gpu_kernel_helper_test", srcs = [ - "util/cuda_kernel_helper_test.cu.cc", + "util/gpu_kernel_helper_test.cu.cc", ], deps = [ ":test", @@ -4765,7 +4976,7 @@ tf_cc_test_gpu( name = "gpu_debug_allocator_test", size = "medium", srcs = ["common_runtime/gpu/gpu_debug_allocator_test.cc"], - args = ["\"--gtest_death_test_style=threadsafe\""], + args = ["--gtest_death_test_style=threadsafe"], linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags(), deps = [ @@ -4862,6 +5073,7 @@ tf_cc_test( "//tensorflow/core/kernels:array", 
"//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:math", "//third_party/eigen3", ], ) @@ -5011,6 +5223,31 @@ tf_cc_test_gpu( ":testlib", "//tensorflow/cc:cc_ops", "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/profiler/internal:profiler_interface", + ], +) + +tf_cc_tests( + name = "common_runtime_lower_function_call_test", + size = "small", + srcs = ["common_runtime/lower_function_call_op_test.cc"], + deps = [ + ":all_kernels", + ":core_cpu", + ":core_cpu_internal", + ":direct_session", + ":framework", + ":framework_internal", + ":lib", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:client_session", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", ], ) @@ -5058,13 +5295,14 @@ tf_cc_tests( "//tensorflow/cc:client_session", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "@com_google_absl//absl/algorithm:container", ], ) tf_cc_tests( - name = "common_runtime_lower_if_while_test", + name = "common_runtime_lower_functional_ops_test", size = "small", - srcs = ["common_runtime/lower_if_while_test.cc"], + srcs = ["common_runtime/lower_functional_ops_test.cc"], deps = [ ":all_kernels", ":core_cpu", @@ -5163,18 +5401,16 @@ transitive_hdrs( name = "headers", visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:platform_strings", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:stream_executor", + ":core_cpu", + ":framework", + ":lib", + ":platform_strings", + ":protos_all_cc", + ":stream_executor", ], ) -# ----------------------------------------------------------------------------- -# Google-internal targets go here (must be at the end). - +# Placeholder for Google-internal load statements. load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library") genrule( @@ -5204,7 +5440,7 @@ tf_portable_proto_library( # There is currently no need for a full proto version of emscripten tf lib lite. alias( name = "emscripten_lib_lite_no_runtime", - actual = "//tensorflow/core:emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", + actual = ":emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousIteratorV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousIteratorV2.pbtxt new file mode 100644 index 00000000000..7416bf5ae55 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AnonymousIteratorV2.pbtxt @@ -0,0 +1,20 @@ +op { + graph_op_name: "AnonymousIteratorV2" + visibility: HIDDEN + out_arg { + name: "handle" + description: <>> a = [1., 2., 3.] +>>> equality_bitcast = tf.bitcast(a,tf.complex128) +tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast] +>>> equality_cast = tf.cast(a,tf.complex128) +>>> print(equality_cast) +tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128) +``` +Example 2: +```python +>>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8) + +``` +Example 3: +```python +>>> x = [1., 2., 3.] +>>> y = [0., 2., 3.] 
+>>> equality= tf.equal(x,y) +>>> equality_cast = tf.cast(equality,tf.float32) +>>> equality_bitcast = tf.bitcast(equality_cast,tf.uint8) +>>> print(equality) +tf.Tensor([False True True], shape=(3,), dtype=bool) +>>> print(equality_cast) +tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32) +>>> print(equality_bitcast) +tf.Tensor( +[[ 0 0 0 0] + [ 0 0 128 63] + [ 0 0 128 63]], shape=(3, 4), dtype=uint8) +``` + *NOTE*: Bitcast is implemented as a low-level cast, so machines with different endian orderings will give different results. END diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesAggregateStats.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesAggregateStats.pbtxt new file mode 100644 index 00000000000..d5a5502565d --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesAggregateStats.pbtxt @@ -0,0 +1,51 @@ +op { + graph_op_name: "BoostedTreesAggregateStats" + visibility: HIDDEN + in_arg { + name: "node_ids" + description: <

diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc
--- a/tensorflow/core/kernels/data/dataset_test_base.cc
+++ b/tensorflow/core/kernels/data/dataset_test_base.cc
 #define CASE(DT)                           \
   case DataTypeToEnum<DT>::value:          \
     TF_RETURN_IF_ERROR(IsEqual<DT>
(a, b)); \ break; TF_CALL_NUMBER_TYPES(CASE); TF_CALL_string(CASE); // TODO(feihugis): figure out how to support variant tensors. #undef CASE default: - return errors::Internal("Unsupported dtype", a.dtype()); + return errors::Internal("Unsupported dtype: ", a.dtype()); + } + return Status::OK(); +} + +template +bool compare(const Tensor& t1, const Tensor& t2) { + auto flat_t1 = t1.flat(); + auto flat_t2 = t2.flat(); + auto length = std::min(flat_t1.size(), flat_t2.size()); + for (int i = 0; i < length; ++i) { + if (flat_t1(i) < flat_t2(i)) return true; + if (flat_t1(i) > flat_t2(i)) return false; + } + return flat_t1.size() < length; +} + +Status DatasetOpsTestBase::ExpectEqual(std::vector produced_tensors, + std::vector expected_tensors, + bool compare_order) { + if (produced_tensors.size() != expected_tensors.size()) { + return Status(tensorflow::errors::Internal( + "The two tensor vectors have different size (", produced_tensors.size(), + " v.s. ", expected_tensors.size(), ")")); + } + + if (produced_tensors.empty()) return Status::OK(); + if (produced_tensors[0].dtype() != expected_tensors[0].dtype()) { + return Status(tensorflow::errors::Internal( + "The two tensor vectors have different dtypes (", + produced_tensors[0].dtype(), " v.s. ", expected_tensors[0].dtype(), + ")")); + } + + if (!compare_order) { + const DataType& dtype = produced_tensors[0].dtype(); + switch (dtype) { +#define CASE(DT) \ + case DT: \ + std::sort(produced_tensors.begin(), produced_tensors.end(), \ + compare::Type>); \ + std::sort(expected_tensors.begin(), expected_tensors.end(), \ + compare::Type>); \ + break; + CASE(DT_FLOAT); + CASE(DT_DOUBLE); + CASE(DT_INT32); + CASE(DT_UINT8); + CASE(DT_INT16); + CASE(DT_INT8); + CASE(DT_STRING); + CASE(DT_INT64); + CASE(DT_BOOL); + CASE(DT_QINT8); + CASE(DT_QUINT8); + CASE(DT_QINT32); + CASE(DT_QINT16); + CASE(DT_QUINT16); + CASE(DT_UINT16); + CASE(DT_HALF); + CASE(DT_UINT32); + CASE(DT_UINT64); + // TODO(feihugis): support other dtypes. 
+#undef CASE + default: + return errors::Internal("Unsupported dtype: ", dtype); + } + } + + for (int i = 0; i < produced_tensors.size(); ++i) { + TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(produced_tensors[i], + expected_tensors[i])); } return Status::OK(); } @@ -88,11 +186,12 @@ Status DatasetOpsTestBase::CreateTensorSliceDataset( Status DatasetOpsTestBase::CreateOpKernel( const NodeDef& node_def, std::unique_ptr* op_kernel) { - Status status; - *op_kernel = - tensorflow::CreateOpKernel(device_type_, device_.get(), allocator_, - node_def, TF_GRAPH_DEF_VERSION, &status); - return status; + OpKernel* kernel; + TF_RETURN_IF_ERROR(tensorflow::CreateOpKernel(device_type_, device_.get(), + allocator_, flr_, node_def, + TF_GRAPH_DEF_VERSION, &kernel)); + op_kernel->reset(kernel); + return Status::OK(); } Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel, @@ -105,10 +204,20 @@ Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel, return Status::OK(); } +Status DatasetOpsTestBase::RestoreIterator( + IteratorContext* ctx, IteratorStateReader* reader, + const string& output_prefix, const DatasetBase& dataset, + std::unique_ptr* iterator) { + TF_RETURN_IF_ERROR(dataset.MakeIterator(ctx, output_prefix, iterator)); + TF_RETURN_IF_ERROR((*iterator)->Restore(ctx, reader)); + return Status::OK(); +} + Status DatasetOpsTestBase::CreateIteratorContext( OpKernelContext* const op_context, std::unique_ptr* iterator_context) { IteratorContext::Params params(op_context); + params.resource_mgr = op_context->resource_manager(); function_handle_cache_ = absl::make_unique(flr_); params.function_handle_cache = function_handle_cache_.get(); *iterator_context = absl::make_unique(params); @@ -130,7 +239,7 @@ Status DatasetOpsTestBase::InitThreadPool(int thread_num) { "The `thread_num` argument should be positive but got: ", thread_num); } thread_pool_ = absl::make_unique( - Env::Default(), ThreadOptions(), "inter_op", thread_num); + Env::Default(), ThreadOptions(), "test_thread_pool", thread_num); return Status::OK(); } @@ -147,6 +256,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime( TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( options, "/job:localhost/replica:0/task:0", &devices)); device_mgr_ = absl::make_unique(std::move(devices)); + resource_mgr_ = absl::make_unique("default_container"); FunctionDefLibrary proto; for (const auto& fdef : flib) *(proto.add_function()) = fdef; @@ -188,6 +298,7 @@ Status DatasetOpsTestBase::CreateOpKernelContext( step_container_ = absl::make_unique(0, [](const string&) {}); params_->step_container = step_container_.get(); + params_->resource_manager = resource_mgr_.get(); checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; slice_reader_cache_ = absl::make_unique(); @@ -210,9 +321,8 @@ Status DatasetOpsTestBase::CreateOpKernelContext( Status DatasetOpsTestBase::CreateSerializationContext( std::unique_ptr* context) { - SerializationContext::Params params; - params.flib_def = lib_def_.get(); - *context = absl::make_unique(params); + *context = + absl::make_unique(SerializationContext::Params{}); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h index 1d14608e346..d82a0c38583 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.h +++ b/tensorflow/core/kernels/data/dataset_test_base.h @@ -51,6 +51,13 @@ class DatasetOpsTestBase : public ::testing::Test { // and value. 
static Status ExpectEqual(const Tensor& a, const Tensor& b); + // The method validates whether the two tensor vectors have the same tensors. + // If `compare_order` is false, the method will only evaluate whether the two + // vectors have the same elements regardless of order. + static Status ExpectEqual(std::vector produced_tensors, + std::vector expected_tensors, + bool compare_order); + // Creates a tensor with the specified dtype, shape, and value. template static Tensor CreateTensor(TensorShape input_shape, @@ -68,6 +75,15 @@ class DatasetOpsTestBase : public ::testing::Test { Status CreateDataset(OpKernel* kernel, OpKernelContext* context, DatasetBase** const dataset); + // Restores the state of the input iterator. It resets the iterator before + // restoring it to make sure the input iterator does not hold any + // resources or tasks. Otherwise, restoring an existing iterator may cause + // the timeout issue or duplicated elements. + Status RestoreIterator(IteratorContext* ctx, IteratorStateReader* reader, + const string& output_prefix, + const DatasetBase& dataset, + std::unique_ptr* iterator); + // Creates a new RangeDataset op kernel. `T` specifies the output dtype of the // op kernel. template @@ -190,6 +206,7 @@ class DatasetOpsTestBase : public ::testing::Test { std::function)> runner_; std::unique_ptr device_mgr_; std::unique_ptr lib_def_; + std::unique_ptr resource_mgr_; std::unique_ptr params_; std::unique_ptr slice_reader_cache_; diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index 9def8c91618..d6101c1740c 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -14,94 +14,193 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/dataset_utils.h" + #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/grappler_item_builder.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace data { +namespace { -Status ComputeShortCircuitIndices(OpKernelContext* ctx, - const NameAttrList& func, - std::vector* indices) { - FunctionLibraryRuntime::Handle fn_handle; - TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( - func.name(), AttrSlice(&func.attr()), &fn_handle)); - auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { - Status s = ctx->function_library()->ReleaseHandle(fn_handle); - if (!s.ok()) { - LOG(WARNING) << "Failed to release handle: " << s.error_message(); - } - }); +void AddFakeSinks(FunctionDef* function_def) { + int counter = 0; + for (const auto& output : function_def->signature().output_arg()) { + NodeDef* node = function_def->add_node_def(); + 
tensorflow::grappler::function_utils::SetUniqueFunctionNodeName( + strings::StrCat("FakeSink", counter++), function_def, node); + node->set_op("Identity"); + node->add_input(function_def->ret().at(output.name())); + (*node->mutable_attr())["T"].set_type(output.type()); - // If the function contains any stateful operations, we conservatively execute - // the entire function. - if (ctx->function_library()->IsStateful(func.name())) { - indices->clear(); - return Status::OK(); + (*function_def->mutable_ret())[output.name()] = + strings::StrCat(node->name(), ":output:0"); } +} - const FunctionBody* fn_body = - ctx->function_library()->GetFunctionBody(fn_handle); - indices->resize(fn_body->ret_nodes.size()); - - for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) { - Node* ret_node = fn_body->ret_nodes[i]; - Node* ret_input_node; - TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node)); - - while (ret_input_node->def().op() == "Identity") { - TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node)); - } - - if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) { - TF_RETURN_IF_ERROR( - GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i]))); - } else { - indices->clear(); - break; +void RemoveFakeSinks(FunctionDef* function_def) { + // Map from identity node names to their input tensor strings + std::map identity_map; + for (const auto& node : function_def->node_def()) { + if (node.op() == "Identity" && node.input_size() == 1) { + identity_map[node.name()] = node.input(0); } } + for (const auto& output_arg : function_def->signature().output_arg()) { + const string& tensor = function_def->ret().at(output_arg.name()); + const string& output_node = tensor.substr(0, tensor.find(':')); + if (identity_map.find(output_node) != identity_map.end()) { + (*function_def->mutable_ret())[output_arg.name()] = + identity_map.at(output_node); + } + } +} + +Status ApplyRewrites(OpKernelContext* ctx, + const std::function config_factory, + bool optimize_function_library, GraphDef* graph_def, + string* output_node) { + // Add an identity node as the fetch node, otherwise we might get 'placeholder + // is both fed and fetched' errors in some cases when using input list with + // placeholder dataset nodes. + NodeDef* node = graph_def->mutable_node()->Add(); + tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def, + node); + node->set_op("Identity"); + node->add_input(*output_node); + (*node->mutable_attr())["T"].set_type(DT_VARIANT); + *output_node = node->name(); + + // Add fake sink node to graph and functions to allow rewriting the actual + // sink nodes. + // + // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals + // to be optimizable, we will no longer need this. + for (auto& function_def : *graph_def->mutable_library()->mutable_function()) { + AddFakeSinks(&function_def); + } + + // Create metagraph. + MetaGraphDef meta_graph_def; + (*meta_graph_def.mutable_graph_def()) = *graph_def; + + // Grappler determines fetch ops from collection 'train_op'. + CollectionDef collection_def; + auto node_list = collection_def.mutable_node_list(); + node_list->add_value(*output_node); + (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def; + + // Create Grappler item. 
+ tensorflow::grappler::ItemConfig item_config; + item_config.apply_optimizations = true; + std::unique_ptr grappler_item = + tensorflow::grappler::GrapplerItemFromMetaGraphDef( + "graph", meta_graph_def, item_config); + grappler_item->optimization_options().optimize_function_library = + optimize_function_library; + std::unordered_map device_map; + tensorflow::grappler::VirtualCluster cluster(device_map); + + // Run data optimizer using grappler's meta optimizer. + tensorflow::ConfigProto config; + *config.mutable_graph_options()->mutable_rewrite_options() = config_factory(); + TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer( + *grappler_item, config, ctx->device(), &cluster, graph_def)); + + // Remove fake sinks after optimizations are done. + // + // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals + // to be optimizable, we will no longer need this. + for (auto& function_def : *graph_def->mutable_library()->mutable_function()) { + RemoveFakeSinks(&function_def); + } + return Status::OK(); } -std::vector ComputeMoveVector(const std::vector& indices) { - std::map last_use; - for (size_t i = 0; i < indices.size(); ++i) { - last_use[indices[i]] = i; - } - std::vector can_move; - can_move.resize(indices.size()); - for (size_t i = 0; i < indices.size(); ++i) { - can_move[i] = last_use[indices[i]] == i; - } - return can_move; +} // anonymous namespace + +Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset, + SerializationContext&& serialization_ctx, + GraphDef* graph_def) { + GraphDefBuilder b; + DatasetBase::DatasetGraphDefBuilder db(&b); + Node* output_node = nullptr; + TF_RETURN_IF_ERROR( + db.AddInputDataset(&serialization_ctx, dataset, &output_node)); + // Insert a purely symbolic _Retval node to indicate to consumers which Tensor + // represents this Dataset. + ops::UnaryOp("_Retval", output_node, + b.opts() + .WithName("dataset") + .WithAttr("T", DT_VARIANT) + .WithAttr("index", 0)); + TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def)); + return Status::OK(); } -Status MakeIteratorFromInputElement( - IteratorContext* ctx, const std::vector& input_element, - int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func, - StringPiece prefix, std::unique_ptr* out_iterator) { - std::vector return_values; +Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, + std::function config_factory, + bool optimize_function_library, + DatasetBase** rewritten_input) { + SerializationContext::Params params; + std::vector> input_list; + params.input_list = &input_list; + params.optimization_only = true; + SerializationContext serialization_ctx(params); + GraphDef graph_def; + TF_RETURN_IF_ERROR( + AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def)); - TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element, - &return_values)); - - if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT && - TensorShapeUtils::IsScalar(return_values[0].shape()))) { - return errors::InvalidArgument( - "Function must return a single scalar of dtype DT_VARIANT."); + string output_node; + for (const auto& node : graph_def.node()) { + if (node.op() == "_Retval") { + output_node = node.input(0); + } } - // Retrieve the dataset that was created in `f`. 
- DatasetBase* returned_dataset; - TF_RETURN_IF_ERROR( - GetDatasetFromVariantTensor(return_values[0], &returned_dataset)); + VLOG(3) << "Before graph rewrites: " << graph_def.DebugString(); + TF_RETURN_IF_ERROR(ApplyRewrites(ctx, config_factory, + optimize_function_library, &graph_def, + &output_node)); + VLOG(3) << "After graph rewrites: " << graph_def.DebugString(); - // Create an iterator for the dataset that was returned by `f`. - return returned_dataset->MakeIterator( - ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator); + // Instantiate the optimized input pipeline by running the optimized graph + // using the optimized function library. + FunctionLibraryRuntime* flr = nullptr; + std::unique_ptr pflr = nullptr; + std::unique_ptr lib_def = nullptr; + TF_RETURN_IF_ERROR( + ctx->function_library()->Clone(&lib_def, &pflr, &flr, true)); + + // Some functions may have been modified without having their names + // changed (for example, nested dataset graphs from FlatMap or + // Interleave). + TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library())); + + Graph graph(OpRegistry::Global()); + TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); + std::vector outputs; + GraphRunner graph_runner(flr->device()); + + TF_RETURN_IF_ERROR( + graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs)); + TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], rewritten_input)); + (*rewritten_input)->Ref(); + return Status::OK(); } Status VerifyTypesMatch(const DataTypeVector& expected, @@ -261,5 +360,24 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base, } return base->AddLibrary(to_add); } + +std::function)> RunnerWithMaxParallelism( + std::function)> runner, int max_parallelism) { + return std::bind( + [max_parallelism]( + // Note: `runner` is a const reference to avoid copying it. + const std::function)>& runner, + std::function fn) { + std::function scoped_fn = std::bind( + [max_parallelism](const std::function& fn) { + ScopedPerThreadMaxParallelism scope(max_parallelism); + fn(); + }, + std::move(fn)); + runner(std::move(scoped_fn)); + }, + std::move(runner), std::placeholders::_1); +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index d85e87ca098..9d8f4a3824d 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -17,35 +17,20 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/captured_function.h" namespace tensorflow { namespace data { -// This method is used to determine whether we can short-circuit the evaluation -// of the user-defined function `func`. Short-circuting is possible if every -// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) = -// (y,x)`, or `f(x) = (x,x)`). -// -// If short-circuiting is possible, the method stores the mapping from output -// indices to input indices in `indices`. Otherwise, `indices` will be empty. -// -// Returns non-ok status if analysis of the function fails. -// -// TODO(jsimsa): Extend this to support constants as well. -Status ComputeShortCircuitIndices(OpKernelContext* ctx, - const NameAttrList& func, - std::vector* indices); +// Returns a GraphDef representation of the given dataset. 
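`RunnerWithMaxParallelism` above wraps an existing runner so that every function it executes runs inside a `ScopedPerThreadMaxParallelism` scope; the nested `std::bind` calls mainly serve to move the runner and the wrapped function into the closures without copies. A hypothetical equivalent written with C++14 init-captures (a sketch, not the code in this change; it assumes `work_sharder.h` for the scope class) reads more directly:

```cpp
#include <functional>
#include <utility>

#include "tensorflow/core/util/work_sharder.h"  // ScopedPerThreadMaxParallelism

// Sketch: same behavior as RunnerWithMaxParallelism, using lambdas instead of
// std::bind (assumes C++14 init-captures are available).
std::function<void(std::function<void()>)> WithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return [runner = std::move(runner), max_parallelism](std::function<void()> fn) {
    runner([max_parallelism, fn = std::move(fn)]() {
      // Every wrapped function observes the reduced per-thread parallelism.
      tensorflow::ScopedPerThreadMaxParallelism scope(max_parallelism);
      fn();
    });
  };
}
```

Either formulation gives the same observable behavior checked by the `RunnerWithMaxParallelism` test below: the wrapped function sees the configured per-thread parallelism limit while it runs.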
+Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset, + SerializationContext&& serialization_ctx, + GraphDef* graph_def); -// Given a vector that maps output indices to input indices, return a vector -// that identifies for which output indices can we move the input (assuming -// output indices are processed left to right). -std::vector ComputeMoveVector(const std::vector& indices); - -Status MakeIteratorFromInputElement( - IteratorContext* ctx, const std::vector& input_element, - int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func, - StringPiece prefix, std::unique_ptr* out_iterator); +// Rewrites the input dataset using the given config. +Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, + std::function config_factory, + bool optimize_function_library, + DatasetBase** rewritten_input); // Returns Status::OK() if `expected` and `received` types match, // errors::InvalidArgument otherwise. @@ -105,6 +90,11 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base, const FunctionLibraryDefinition& to_add); Status AddToFunctionLibrary(FunctionLibraryDefinition* base, const FunctionDefLibrary& to_add); + +// Creates a runner that runs functions with limited parallelism. +std::function)> RunnerWithMaxParallelism( + std::function)> runner, int max_parallelism); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc index bddd2d455e5..a553b8ab67d 100644 --- a/tensorflow/core/kernels/data/dataset_utils_test.cc +++ b/tensorflow/core/kernels/data/dataset_utils_test.cc @@ -19,31 +19,12 @@ limitations under the License. #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace data { namespace { -TEST(DatasetUtilsTest, ComputeMoveVector) { - struct TestCase { - std::vector indices; - std::vector expected; - }; - - TestCase test_cases[] = { - TestCase{{}, {}}, - TestCase{{1}, {true}}, - TestCase{{1, 1}, {false, true}}, - TestCase{{1, 2}, {true, true}}, - TestCase{{1, 1, 2}, {false, true, true}}, - TestCase{{1, 2, 2}, {true, false, true}}, - }; - - for (auto& test_case : test_cases) { - EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices)); - } -} - TEST(DatasetUtilsTest, VariantTensorDataRoundtrip) { VariantTensorData data; VariantTensorDataWriter writer(&data); @@ -163,6 +144,13 @@ TEST(DatasetUtilsTest, AddToFunctionLibraryWithConflictingSignatures) { "signature already exists.", s.error_message()); } + +TEST(DatasetUtilsTest, RunnerWithMaxParallelism) { + auto runner = + RunnerWithMaxParallelism([](const std::function fn) { fn(); }, 2); + auto fn = []() { ASSERT_EQ(GetPerThreadMaxParallelism(), 2); }; + runner(fn); +} } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 569501dd03c..ce31fc3403a 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -1,5 +1,6 @@ # Description: # Contains experimental kernels for datasets and iterators. 
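A kernel that uses the `RewriteDataset` helper declared above supplies a factory producing a `RewriterConfig` that names a custom Grappler optimizer, as `AutoShardDatasetOp` and the rebatch op do elsewhere in this change. A minimal, hypothetical `MakeDataset` override following that pattern (the optimizer name `my_optimizer` and the choice of `optimize_function_library` are illustrative assumptions, not part of this change):

```cpp
// Hypothetical MakeDataset body for a rewrite-based dataset op (sketch).
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                 DatasetBase** output) override {
  auto config_factory = []() {
    RewriterConfig rewriter_config;
    rewriter_config.set_fail_on_optimizer_errors(true);
    rewriter_config.add_optimizers("my_optimizer");
    rewriter_config.set_meta_optimizer_iterations(
        RewriterConfig_NumIterationsType_ONE);
    rewriter_config.add_custom_optimizers()->set_name("my_optimizer");
    return rewriter_config;
  };
  OP_REQUIRES_OK(ctx, RewriteDataset(ctx, input, std::move(config_factory),
                                     /*optimize_function_library=*/true, output));
}
```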
+ package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 @@ -21,6 +22,36 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "auto_shard_dataset_op", + srcs = ["auto_shard_dataset_op.cc"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/optimizers/data:auto_shard", + "//tensorflow/core/kernels/data:dataset_utils", + ], +) + +tf_kernel_library( + name = "choose_fastest_branch_dataset_op", + srcs = ["choose_fastest_branch_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels/data:captured_function", + "//tensorflow/core/kernels/data:dataset_utils", + "//tensorflow/core/kernels/data:take_dataset_op", + ], +) + tf_kernel_library( name = "csv_dataset_op", srcs = ["csv_dataset_op.cc"], @@ -54,21 +85,6 @@ tf_kernel_library( ], ) -tf_kernel_library( - name = "auto_shard_dataset_op", - srcs = ["auto_shard_dataset_op.cc"], - deps = [ - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:dataset_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler/optimizers/data:auto_shard", - "//tensorflow/core/kernels/data:graph_rewrite_dataset", - ], -) - tf_kernel_library( name = "group_by_reducer_dataset_op", srcs = ["group_by_reducer_dataset_op.cc"], @@ -79,6 +95,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/data:captured_function", + "//tensorflow/core/kernels/data:dataset_utils", ], ) @@ -92,6 +109,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/data:captured_function", + "//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:window_dataset", ], ) @@ -106,18 +124,6 @@ tf_kernel_library( ], ) -tf_kernel_library( - name = "indexed_dataset_op", - srcs = ["indexed_dataset_op.cc"], - deps = [ - "//tensorflow/core:experimental_dataset_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core/kernels/data:dataset_utils", - "//third_party/eigen3", - ], -) - tf_kernel_library( name = "lmdb_dataset_op", srcs = ["lmdb_dataset_op.cc"], @@ -182,23 +188,6 @@ tf_kernel_library( ], ) -tf_kernel_library( - name = "numa_map_and_batch_dataset_op", - srcs = ["numa_map_and_batch_dataset_op.cc"], - deps = [ - "//tensorflow/core:array_ops_op_lib", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:experimental_dataset_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:nn_ops_op_lib", - "//tensorflow/core/kernels:inplace_ops", - "//tensorflow/core/kernels/data:captured_function", - "@com_google_absl//absl/memory", - ], -) - tf_kernel_library( name = "parallel_interleave_dataset_op", srcs = ["parallel_interleave_dataset_op.cc"], @@ -259,7 +248,18 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/optimizers/data:rebatch", - "//tensorflow/core/kernels/data:graph_rewrite_dataset", + "//tensorflow/core/kernels/data:dataset_utils", + ], +) 
+ +tf_kernel_library( + name = "sampling_dataset_op", + srcs = ["sampling_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) @@ -273,6 +273,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/kernels/data:captured_function", + "//tensorflow/core/kernels/data:dataset_utils", ], ) @@ -283,6 +284,7 @@ tf_kernel_library( "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels/data:stats_utils", ], ) @@ -306,6 +308,18 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "snapshot_dataset_op", + srcs = ["snapshot_dataset_op.cc"], + deps = [ + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + tf_kernel_library( name = "sql_dataset_op", srcs = [ @@ -327,6 +341,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:summary_interface", ], ) @@ -356,6 +371,18 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "threadpool_dataset_op", + srcs = ["threadpool_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/kernels/data:dataset_utils", + "//third_party/eigen3", + ], +) + tf_kernel_library( name = "to_tf_record_op", srcs = ["to_tf_record_op.cc"], @@ -368,17 +395,6 @@ tf_kernel_library( ], ) -tf_kernel_library( - name = "threadpool_dataset_op", - srcs = ["threadpool_dataset_op.cc"], - deps = [ - "//tensorflow/core:experimental_dataset_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//third_party/eigen3", - ], -) - tf_kernel_library( name = "unbatch_dataset_op", srcs = ["unbatch_dataset_op.cc"], @@ -406,6 +422,7 @@ tf_kernel_library( deps = [ ":assert_next_dataset_op", ":auto_shard_dataset_op", + ":choose_fastest_branch_dataset_op", ":choose_fastest_dataset_op", ":csv_dataset_op", ":dense_to_sparse_batch_dataset_op", @@ -413,21 +430,21 @@ tf_kernel_library( ":group_by_reducer_dataset_op", ":group_by_window_dataset_op", ":ignore_errors_dataset_op", - ":indexed_dataset_op", ":lmdb_dataset_op", ":map_and_batch_dataset_op", ":matching_files_dataset_op", ":non_serializable_dataset_op", - ":numa_map_and_batch_dataset_op", ":parallel_interleave_dataset_op", ":parse_example_dataset_op", ":prefetching_kernels", ":random_dataset_op", ":rebatch_dataset_op", + ":sampling_dataset_op", ":scan_dataset_op", ":set_stats_aggregator_dataset_op", ":sleep_dataset_op", ":sliding_window_dataset_op", + ":snapshot_dataset_op", ":sql_dataset_op", ":stats_aggregator_ops", ":stats_dataset_ops", diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc index eb547133609..cda0885e7f0 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -99,7 +99,7 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel { Status Initialize(IteratorContext* ctx) override { std::vector tokens = - str_util::Split(prefix(), ':', str_util::SkipEmpty()); + absl::StrSplit(prefix(), ':', absl::SkipEmpty()); if (dataset()->transformations_.size() > 
tokens.size() - 2) { return errors::InvalidArgument( "Asserted next ", dataset()->transformations_.size(), diff --git a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc index 3728c64ab5d..7531225b817 100644 --- a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace data { @@ -24,17 +25,12 @@ constexpr char kOptimizerName[] = "tf_auto_shard"; class AutoShardDatasetOp : public UnaryDatasetOpKernel { public: explicit AutoShardDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - } + : UnaryDatasetOpKernel(ctx) {} protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - int64 index; - int64 num_workers; + int64 index, num_workers; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers)); OP_REQUIRES( ctx, num_workers > 0, @@ -45,69 +41,39 @@ class AutoShardDatasetOp : public UnaryDatasetOpKernel { errors::InvalidArgument("index must be between 0 and ", num_workers - 1)); - Dataset* dataset = new Dataset(ctx, input, num_workers, index, - output_types_, output_shapes_); - const Status s = dataset->Optimize(ctx); + auto config_factory = [num_workers, index]() { + return CreateConfig(num_workers, index); + }; - if (s.ok()) { - *output = dataset; - } else { - dataset->Unref(); - OP_REQUIRES_OK(ctx, s); - } + // We only want to optimize functions for some particular datasets like + // FlatMapDataset, InterleaveDataset etc. So we disable generalized + // function optimization and explicitly handle function modifications + // for those datasets in the rewrite. 
+ OP_REQUIRES_OK(ctx, + RewriteDataset(ctx, input, std::move(config_factory), + /*optimize_function_library=*/false, output)); } private: - class Dataset : public GraphRewriteDataset { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const int64 num_workers, const int64 index, - const DataTypeVector& output_types, - const std::vector& output_shapes) - : GraphRewriteDataset(ctx, input, output_types, output_shapes), - num_workers_(num_workers), - index_(index) {} + static RewriterConfig CreateConfig(int64 num_workers, int64 index) { + RewriterConfig rewriter_config; + rewriter_config.set_fail_on_optimizer_errors(true); + rewriter_config.add_optimizers(kOptimizerName); + rewriter_config.set_meta_optimizer_iterations( + RewriterConfig_NumIterationsType_ONE); + auto custom_optimizer = rewriter_config.add_custom_optimizers(); + custom_optimizer->set_name(kOptimizerName); + AttrValue num_workers_attr; + num_workers_attr.set_i(num_workers); + (*custom_optimizer->mutable_parameter_map())["num_workers"] = + num_workers_attr; - string DebugString() const override { - return "AutoShardDatasetOp::Dataset"; - } + AttrValue index_attr; + index_attr.set_i(index); + (*custom_optimizer->mutable_parameter_map())["index"] = index_attr; - private: - bool ShouldOptimizeFunctions() override { - // We only want to optimize functions for some particular datasets like - // FlatMapDataset, InterleaveDataset etc. So we disable generalized - // function optimization and explicitly handle function modifications - // for those datasets in the rewrite. - return false; - } - - RewriterConfig CreateGrapplerRewriteConfig() override { - RewriterConfig rewriter_config; - rewriter_config.set_fail_on_optimizer_errors(true); - rewriter_config.add_optimizers(kOptimizerName); - rewriter_config.set_meta_optimizer_iterations( - RewriterConfig_NumIterationsType_ONE); - auto custom_optimizer = rewriter_config.add_custom_optimizers(); - custom_optimizer->set_name(kOptimizerName); - AttrValue num_workers_attr; - num_workers_attr.set_i(num_workers_); - (*custom_optimizer->mutable_parameter_map())["num_workers"] = - num_workers_attr; - - AttrValue index_attr; - index_attr.set_i(index_); - (*custom_optimizer->mutable_parameter_map())["index"] = index_attr; - - return rewriter_config; - } - - const int64 num_workers_; - const int64 index_; - }; - - const int graph_def_version_; - DataTypeVector output_types_; - std::vector output_shapes_; + return rewriter_config; + } }; REGISTER_KERNEL_BUILDER(Name("ExperimentalAutoShardDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc new file mode 100644 index 00000000000..8b4bafe1f5b --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc @@ -0,0 +1,557 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/take_dataset_op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/histogram/histogram.h" + +namespace tensorflow { +namespace data { +namespace { + +static const double kPercentile = 90.0; + +// Each instance of this class wraps an iterator. Whenever an iterator created +// for this dataset invokes the `GetNext` method, the call is delegated to the +// wrapped iterator's `GetNext` method. +class WrapperDataset : public DatasetBase { + public: + WrapperDataset(DatasetContext::Params params, + const DataTypeVector* output_dtypes, + const std::vector* output_shapes, + IteratorBase* iterator) + : DatasetBase(DatasetContext(std::move(params))), + output_dtypes_(output_dtypes), + output_shapes_(output_shapes), + real_iterator_(iterator) {} + + const DataTypeVector& output_dtypes() const override { + return *output_dtypes_; + } + + const std::vector& output_shapes() const override { + return *output_shapes_; + } + + string DebugString() const override { return "WrapperDataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const override { + return errors::Unimplemented(DebugString(), "::AsGraphDefInternal"); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + // MakeIterator should only be called once per WrapperDataset. However, + // since this function expects an iterator return value, we raise the + // error only at iterator initialization time. + bool error = iterator_created_; + iterator_created_ = true; + return absl::make_unique( + WrapperIterator::Params{this, strings::StrCat(prefix, "::Wrapper")}, + error); + } + + private: + class WrapperIterator : public DatasetIterator { + public: + explicit WrapperIterator(const Params& params, bool error) + : DatasetIterator(params), error_(error) {} + + Status Initialize(IteratorContext* ctx) override { + if (error_) { + return errors::InvalidArgument( + "Cannot create more than one WrapperIterator per WrapperDataset. " + "Make sure the branches to ChooseFastestDataset do not expect the " + "input to repeat."); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + return dataset()->real_iterator_->GetNext(ctx, out_tensors, + end_of_sequence); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1.0); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return Status::OK(); + } + + private: + const bool error_; + }; + + mutable bool iterator_created_ = false; + const DataTypeVector* const output_dtypes_; + const std::vector* const output_shapes_; + IteratorBase* const real_iterator_; // not owned. +}; + +// This Dataset picks between some dataset function branches. 
Each function is +// expected to input a dataset and output a dataset. The datasets in the +// branches are expected to be stateless. For each iterator that can be produced +// by a functions output, it is expected to call the input dataset's +// MakeIterator method at most once; otherwise, undefined behavior may occur. +class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ChooseFastestBranchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + std::vector funcs; + OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &funcs)); + func_metadatas_.resize(funcs.size()); + for (int i = 0; i < funcs.size(); ++i) { + OP_REQUIRES_OK( + ctx, FunctionMetadata::Create(ctx, std::move(funcs[i]), /*params=*/{}, + &func_metadatas_[i])); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_elements_per_branch", + &num_elements_per_branch_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("other_arguments_lengths", + &other_arguments_lengths_)); + + OP_REQUIRES( + ctx, func_metadatas_.size() == other_arguments_lengths_.size(), + errors::InvalidArgument( + "branches and other_arguments_lengths must have the same length.")); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "ratio_numerator", + &ratio_numerator_)); + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "ratio_denominator", + &ratio_denominator_)); + OP_REQUIRES(ctx, ratio_numerator_ > 0, + errors::InvalidArgument( + "`ratio_numerator` must be greater than zero.")); + OP_REQUIRES(ctx, ratio_denominator_ > 0, + errors::InvalidArgument( + "`ratio_denominator` must be greater than zero.")); + OP_REQUIRES(ctx, num_elements_per_branch_ % ratio_denominator_ == 0, + errors::InvalidArgument("`num_elements_per_branch` must be " + "divisible by `ratio_denominator`.")); + + std::vector> captured_funcs( + func_metadatas_.size()); + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); + + // Keeps track of starting index into other_arguments for a given function. 
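The constructor above requires `branches` and `other_arguments_lengths` to have equal length because the flat `other_arguments` input list is sliced into per-branch captured arguments by those lengths, as the loop that follows does. A toy, plain-C++ rendering of that slicing (hypothetical values; ints stand in for captured tensors):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> other_arguments = {10, 11, 12, 13, 14};  // flat captured inputs
  std::vector<int> other_arguments_lengths = {2, 0, 3};     // per-branch counts
  int index = 0;  // running start offset into other_arguments
  for (size_t i = 0; i < other_arguments_lengths.size(); ++i) {
    std::cout << "branch " << i << ":";
    for (int end = index + other_arguments_lengths[i]; index < end; ++index) {
      std::cout << " " << other_arguments[index];
    }
    std::cout << "\n";  // branch 0: 10 11 / branch 1: (none) / branch 2: 12 13 14
  }
  return 0;
}
```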
+ int index = 0; + for (int i = 0; i < func_metadatas_.size(); ++i) { + std::vector captured_args; + captured_args.reserve(other_arguments_lengths_[i]); + int end_index = index + other_arguments_lengths_[i]; + for (; index < end_index; ++index) { + captured_args.push_back(inputs[index]); + } + OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_metadatas_[i], + std::move(captured_args), + &captured_funcs[i])); + } + *output = new Dataset(ctx, input, std::move(captured_funcs), output_types_, + output_shapes_, num_elements_per_branch_, + ratio_numerator_, ratio_denominator_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, DatasetBase* input, + std::vector> captured_funcs, + const DataTypeVector& output_types, + const std::vector& output_shapes, + int64 num_elements_per_branch, int64 ratio_numerator, + int64 ratio_denominator) + : DatasetBase(DatasetContext(ctx)), + input_(input), + captured_funcs_(std::move(captured_funcs)), + output_types_(output_types), + output_shapes_(output_shapes), + num_elements_per_branch_(num_elements_per_branch), + ratio_numerator_(ratio_numerator), + ratio_denominator_(ratio_denominator) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique( + ChooseFastestIterator::Params{ + this, strings::StrCat(prefix, "::ChooseFastestBranch")}); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "ChooseFastestBranchDatasetOp::Dataset"; + } + + int64 Cardinality() const override { + int64 n = input_->Cardinality(); + if (n == kInfiniteCardinality || n == kUnknownCardinality) { + return n; + } + // TODO(rachelim): this might be wrong if the ratio is not fixed, for + // example, from a BatchDataset with drop_remainder = False + return static_cast(n) * ratio_numerator_ / ratio_denominator_; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + + Node* ratio_numerator_node; + TF_RETURN_IF_ERROR(b->AddScalar(ratio_numerator_, &ratio_numerator_node)); + Node* ratio_denominator_node; + TF_RETURN_IF_ERROR( + b->AddScalar(ratio_denominator_, &ratio_denominator_node)); + + std::vector other_arguments_lengths; + other_arguments_lengths.reserve(captured_funcs_.size()); + int num_captured_inputs = 0; + for (const auto& func : captured_funcs_) { + num_captured_inputs += func->captured_inputs().size(); + other_arguments_lengths.push_back(func->captured_inputs().size()); + } + std::vector other_arguments; + DataTypeVector other_arguments_types; + other_arguments_types.reserve(num_captured_inputs); + other_arguments.reserve(num_captured_inputs); + for (const auto& captured_func : captured_funcs_) { + TF_RETURN_IF_ERROR(captured_func->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); + } + + // Targuments + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + // num_elements_per_branch + AttrValue num_elements_per_branch_attr; + b->BuildAttrValue(num_elements_per_branch_, + &num_elements_per_branch_attr); + + // branches + AttrValue branches_attr; + std::vector funcs; + 
funcs.resize(captured_funcs_.size()); + for (int i = 0; i < captured_funcs_.size(); ++i) { + funcs[i] = captured_funcs_[i]->func(); + } + b->BuildAttrValue(funcs, &branches_attr); + + // other_arguments_lengths + AttrValue other_arguments_lengths_attr; + b->BuildAttrValue(other_arguments_lengths, &other_arguments_lengths_attr); + + return b->AddDataset( + this, + /*inputs=*/ + {std::make_pair(0, input_graph_node), + std::make_pair(1, ratio_numerator_node), + std::make_pair(2, ratio_denominator_node)}, + /*list_inputs=*/{std::make_pair(3, other_arguments)}, + /*attrs=*/ + {std::make_pair("Targuments", other_arguments_types_attr), + std::make_pair("num_elements_per_branch", + num_elements_per_branch_attr), + std::make_pair("branches", branches_attr), + std::make_pair("other_arguments_lengths", + other_arguments_lengths_attr)}, + output); + } + + private: + // This iterator picks the fastest of dataset branches by running + // experiments for the first dataset()->num_elements_per_branch_ * + // num_branches iterations. + class ChooseFastestIterator : public DatasetIterator { + public: + explicit ChooseFastestIterator(const Params& params) + : DatasetIterator(params), + instantiated_captured_funcs_(dataset()->captured_funcs_.size()), + histograms_(dataset()->captured_funcs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + + for (int i = 0; i < dataset()->captured_funcs_.size(); ++i) { + TF_RETURN_IF_ERROR(dataset()->captured_funcs_[i]->Instantiate( + ctx, &instantiated_captured_funcs_[i])); + } + + return Status::OK(); + } + + // The first num_elements_per_branch * num_branches iterations, we run + // experiments on the branches, using (branch_index_, experiment_counter_) + // to keep track of which experiment we're on. + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + { // Locking scope + mutex_lock l(mu_); + if (branch_index_ < dataset()->captured_funcs_.size()) { + // Still running experiments + if (!current_iterator_) { + TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_, + /*is_experiment=*/true)); + } + + Status s = GetNextFromExperiment(ctx, out_tensors, end_of_sequence); + experiment_counter_++; + + if (experiment_counter_ >= dataset()->num_elements_per_branch_) { + // Done experimenting with this branch. Increment the branch index + // so that on the next iteration, we will draw from the next + // branch. + experiment_counter_ = 0; + branch_index_++; + current_iterator_.reset(); + } + return s; + } + if (!current_iterator_) { + SelectFastestInputIndex(); + TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_, + /*is_experiment=*/false)); + } + } + + return current_iterator_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode( + std::move(args), + /*ratio=*/static_cast(dataset()->ratio_numerator_) / + dataset()->ratio_denominator_); + } + + // TODO(rachelim): Save and restore histogram state as well. Currently, + // if an iterator is saved and restored, the histograms start recording + // from scratch. 
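`GetNextInternal` above rotates through the branches, giving each one `num_elements_per_branch` trial iterations before the fastest branch is selected and serves all remaining elements. A self-contained toy simulation of that schedule (plain C++; the constants are made up for illustration):

```cpp
#include <iostream>

int main() {
  const int num_branches = 3, num_elements_per_branch = 2;
  int branch_index = 0, experiment_counter = 0;
  for (int step = 0; step < 10; ++step) {
    if (branch_index < num_branches) {
      // Still experimenting: this element is drawn from the current branch.
      std::cout << "step " << step << ": experiment on branch " << branch_index << "\n";
      if (++experiment_counter >= num_elements_per_branch) {
        experiment_counter = 0;
        ++branch_index;  // move on to the next branch
      }
    } else {
      // All experiments done: the fastest branch serves every element.
      std::cout << "step " << step << ": fastest branch serves the element\n";
    }
  }
  return 0;
}
```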
+ Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"), + experiment_counter_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("branch_index"), branch_index_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("fastest_index"), fastest_index_)); + if (current_iterator_) { + TF_RETURN_IF_ERROR(SaveInput(writer, current_iterator_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"), + &experiment_counter_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("branch_index"), &branch_index_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("fastest_index"), &fastest_index_)); + + // Restore state of `current_iterator_` if it exists. + if (!reader->Contains(full_name("input_impl_empty"))) { + if (branch_index_ < dataset()->captured_funcs_.size()) { + TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_, + /*is_experiment=*/true)); + } else { + TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_, + /*is_experiment=*/false)); + } + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_iterator_)); + } + return Status::OK(); + } + + private: + Status GetNextFromExperiment(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + DCHECK_GE(branch_index_, 0); + DCHECK_LT(branch_index_, histograms_.size()); + + int64 start = ctx->env()->NowNanos(); + Status s = + current_iterator_->GetNext(ctx, out_tensors, end_of_sequence); + + if (experiment_counter_ > 0) { + // Ignore the first experiment when benchmarking. It may be an outlier + // due to session set up time and other overheads. + histograms_[branch_index_].Add( + static_cast(ctx->env()->NowNanos() - start)); + } + return s; + } + + // Select the fastest input to use based on the histograms of timings + // of the completed iterations. The input with the best 90th percentile + // iteration time is selected. + void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + fastest_index_ = 0; + + VLOG(2) << "90.0 percentile iteration time:"; + double best_percentile = histograms_[0].Percentile(kPercentile); + VLOG(2) << "Branch 0: " << best_percentile; + for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs; + ++i) { + double percentile = histograms_[i].Percentile(kPercentile); + VLOG(2) << "Branch " << i << ": " << percentile; + if (percentile <= best_percentile) { + best_percentile = percentile; + fastest_index_ = i; + } + } + VLOG(1) << "Selecting index " << fastest_index_ + << " as the fastest index."; + } + + Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index, + bool is_experiment) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + DCHECK_GE(branch_index, 0); + DCHECK_LT(branch_index, histograms_.size()); + + // `StoreDatasetInVariantTensor` transfers ownership of the dataset + // to the tensor, so the tensor must persist between iterations. 
+ wrapper_dataset_tensor_ = + absl::make_unique(DT_VARIANT, TensorShape({})); + + DatasetContext::Params params; + params.type_string = "ChooseFastestBranch_Wrapper"; + params.node_name = strings::StrCat(params.type_string, branch_index); + DatasetBase* temp_dataset = + new WrapperDataset(std::move(params), &dataset()->output_types_, + &dataset()->output_shapes_, input_impl_.get()); + + if (is_experiment) { + // When running experiment iterations, we add a TakeDataset in between + // the input and the function datasets. This is so that function + // datasets with prefetching behavior won't consume more input + // elements than they actually use to produce output. + DatasetContext::Params take_dataset_params; + take_dataset_params.type_string = "ChooseFastestBranch_Take"; + take_dataset_params.node_name = + strings::StrCat(take_dataset_params.type_string, branch_index); + int64 count = dataset()->num_elements_per_branch_ * + dataset()->ratio_numerator_ / + dataset()->ratio_denominator_; + temp_dataset = new TakeDataset(std::move(take_dataset_params), count, + temp_dataset); + } + + TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor( + temp_dataset, wrapper_dataset_tensor_.get())); + + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( + ctx, {*wrapper_dataset_tensor_}, branch_index, + *instantiated_captured_funcs_[branch_index], prefix(), + ¤t_iterator_)); + + return Status::OK(); + } + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + std::vector> + instantiated_captured_funcs_ GUARDED_BY(mu_); + + // For tracking the time taken for each input's iterations. + std::vector histograms_ GUARDED_BY(mu_); + int64 fastest_index_ = -1; + std::unique_ptr wrapper_dataset_tensor_; + std::unique_ptr current_iterator_; + + // Keeps track of which (branch, experiment) the next iteration is on. + int64 branch_index_ GUARDED_BY(mu_) = 0; + int64 experiment_counter_ GUARDED_BY(mu_) = 0; + }; // class Iterator + + const DatasetBase* const input_; + const std::vector> captured_funcs_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + const int64 num_elements_per_branch_; + const int64 ratio_numerator_; + const int64 ratio_denominator_; + }; // class Dataset + + int64 ratio_numerator_; + int64 ratio_denominator_; + int64 num_elements_per_branch_; + std::vector> func_metadatas_; + DataTypeVector output_types_; + std::vector output_shapes_; + std::vector other_arguments_lengths_; +}; // class ChooseFastestBranchDatasetOp + +// Register the kernel implementation for ChooseFastestBranchDataset. +REGISTER_KERNEL_BUILDER(Name("ChooseFastestBranchDataset").Device(DEVICE_CPU), + ChooseFastestBranchDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc index bfa2bf6bc46..1ae86c1dbfa 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc @@ -217,8 +217,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { } return threads[0].result->status; } - return input_impls_[fastest_index_]->GetNext(ctx, out_tensors, - end_of_sequence); + return fastest_input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } protected: @@ -232,7 +231,14 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { // from scratch. 
Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - if (input_impls_.empty()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"), + experiment_counter_)); + + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("fastest_index"), fastest_index_)); + if (fastest_index_ != -1) { + TF_RETURN_IF_ERROR(SaveInput(writer, fastest_input_impl_)); + } else if (input_impls_.empty()) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impls_empty"), "")); } else { @@ -240,17 +246,22 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(SaveInput(writer, input_impl)); } } - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"), - experiment_counter_)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("fastest_index"), fastest_index_)); return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - if (reader->Contains(full_name("input_impls_empty"))) { + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"), + &experiment_counter_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("fastest_index"), &fastest_index_)); + if (fastest_index_ != -1) { + TF_RETURN_IF_ERROR(dataset()->inputs_[fastest_index_]->MakeIterator( + ctx, strings::StrCat(prefix(), "_", fastest_index_), + &fastest_input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, fastest_input_impl_)); + } else if (reader->Contains(full_name("input_impls_empty"))) { input_impls_.clear(); } else { DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size()); @@ -258,10 +269,6 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl)); } } - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"), - &experiment_counter_)); - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("fastest_index"), &fastest_index_)); return Status::OK(); } @@ -279,6 +286,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { }; std::vector> input_impls_; + std::unique_ptr fastest_input_impl_; // For tracking the time taken for each input's iterations. std::vector histograms_; @@ -317,15 +325,23 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { void SelectFastestInputIndex() EXCLUSIVE_LOCKS_REQUIRED(mu_) { fastest_index_ = 0; + VLOG(2) << "90.0 percentile iteration time:"; double best_percentile = histograms_[0].Percentile(kPercentile); + VLOG(2) << "Branch 0: " << best_percentile; for (size_t i = 1, num_inputs = histograms_.size(); i < num_inputs; ++i) { double percentile = histograms_[i].Percentile(kPercentile); + VLOG(2) << "Branch " << i << ": " << percentile; if (percentile <= best_percentile) { best_percentile = percentile; fastest_index_ = i; } } + VLOG(1) << "Selecting index " << fastest_index_ + << " as the fastest index."; + + fastest_input_impl_ = std::move(input_impls_[fastest_index_]); + input_impls_.clear(); // Delete the unused iterators. 
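Both `SelectFastestInputIndex` implementations above choose the input whose 90th-percentile recorded iteration time is lowest, so a single slow outlier does not disqualify an otherwise fast branch. A standalone sketch of that selection rule (plain C++ with a simplified nearest-rank percentile; the kernels use `histogram::Histogram`, not this helper):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Simplified nearest-rank percentile over a copy of the samples.
double Percentile(std::vector<double> samples, double p) {
  size_t k = static_cast<size_t>(p / 100.0 * (samples.size() - 1));
  std::nth_element(samples.begin(), samples.begin() + k, samples.end());
  return samples[k];
}

int main() {
  // Hypothetical per-branch iteration times (ns); branch 0 has one outlier.
  std::vector<std::vector<double>> timings = {{5, 6, 7, 50}, {8, 8, 9, 9}};
  size_t fastest = 0;
  double best = Percentile(timings[0], 90.0);
  for (size_t i = 1; i < timings.size(); ++i) {
    double p = Percentile(timings[i], 90.0);
    if (p <= best) {
      best = p;
      fastest = i;
    }
  }
  // Prints 0: the single 50 ns outlier does not outweigh branch 0's lower
  // typical iteration time.
  std::cout << "fastest branch: " << fastest << "\n";
  return 0;
}
```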
} }; // class Iterator diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc index 4435c2a1313..fecafaacf2d 100644 --- a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc @@ -61,7 +61,7 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, errors::InvalidArgument("`select_cols` must be a vector.")); - int64 buffer_size; + int64 buffer_size = 0; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); OP_REQUIRES(ctx, buffer_size > 0, diff --git a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc index 56159593a9c..88e53dfe6c1 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { @@ -31,10 +32,17 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { public: explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "key_func", /*params=*/{}, + &key_func_metadata_)); + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "init_func", /*params=*/{}, + &init_func_metadata_)); + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "reduce_func", /*params=*/{}, + &reduce_func_metadata_)); + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "finalize_func", /*params=*/{}, + &finalize_func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -42,20 +50,20 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { std::unique_ptr captured_key_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx, + OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, key_func_metadata_, "key_func_other_arguments", &captured_key_func)); std::unique_ptr captured_init_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(init_func_, ctx, + OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, init_func_metadata_, "init_func_other_arguments", &captured_init_func)); std::unique_ptr captured_reduce_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx, + OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, reduce_func_metadata_, "reduce_func_other_arguments", &captured_reduce_func)); std::unique_ptr captured_finalize_func; OP_REQUIRES_OK(ctx, - CapturedFunction::Create(finalize_func_, ctx, + CapturedFunction::Create(ctx, finalize_func_metadata_, "finalize_func_other_arguments", &captured_finalize_func)); @@ -109,45 +117,41 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { Status 
AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); std::vector key_func_other_arguments_node; DataTypeVector key_func_other_arguments_types; - TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( - ctx, b, captured_key_func_, &key_func_other_arguments_node, - &key_func_other_arguments_types)); + TF_RETURN_IF_ERROR( + captured_key_func_->AddToGraph(ctx, b, &key_func_other_arguments_node, + &key_func_other_arguments_types)); std::vector init_func_other_arguments_node; DataTypeVector init_func_other_arguments_types; - TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( - ctx, b, captured_init_func_, &init_func_other_arguments_node, + TF_RETURN_IF_ERROR(captured_init_func_->AddToGraph( + ctx, b, &init_func_other_arguments_node, &init_func_other_arguments_types)); std::vector reduce_func_other_arguments_node; DataTypeVector reduce_func_other_arguments_types; - TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( - ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node, + TF_RETURN_IF_ERROR(captured_reduce_func_->AddToGraph( + ctx, b, &reduce_func_other_arguments_node, &reduce_func_other_arguments_types)); std::vector finalize_func_other_arguments_node; DataTypeVector finalize_func_other_arguments_types; - TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( - ctx, b, captured_finalize_func_, &finalize_func_other_arguments_node, + TF_RETURN_IF_ERROR(captured_finalize_func_->AddToGraph( + ctx, b, &finalize_func_other_arguments_node, &finalize_func_other_arguments_types)); AttrValue key_func; - b->BuildAttrValue(this->key_func(), &key_func); + b->BuildAttrValue(captured_key_func_->func(), &key_func); AttrValue init_func; - b->BuildAttrValue(this->init_func(), &init_func); + b->BuildAttrValue(captured_init_func_->func(), &init_func); AttrValue reduce_func; - b->BuildAttrValue(this->reduce_func(), &reduce_func); + b->BuildAttrValue(captured_reduce_func_->func(), &reduce_func); AttrValue finalize_func; - b->BuildAttrValue(this->finalize_func(), &finalize_func); + b->BuildAttrValue(captured_finalize_func_->func(), &finalize_func); AttrValue key_func_other_arguments_types_attr; b->BuildAttrValue(key_func_other_arguments_types, @@ -391,42 +395,6 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr instantiated_finalize_func_; }; - const NameAttrList& key_func() const { return captured_key_func_->func(); } - - const NameAttrList& init_func() const { - return captured_init_func_->func(); - } - - const NameAttrList& reduce_func() const { - return captured_reduce_func_->func(); - } - - const NameAttrList& finalize_func() const { - return captured_finalize_func_->func(); - } - - Status OtherArgumentsNodeAndType( - SerializationContext* ctx, DatasetGraphDefBuilder* b, - const std::unique_ptr& captured_func, - std::vector* other_arguments_node, - DataTypeVector* other_arguments_types) const { - other_arguments_node->reserve(captured_func->captured_inputs().size()); - other_arguments_types->reserve(captured_func->captured_inputs().size()); - for (const Tensor& t : captured_func->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if 
(s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments_node->emplace_back(node); - other_arguments_types->emplace_back(t.dtype()); - } - return Status::OK(); - } - const DatasetBase* const input_; const std::unique_ptr captured_key_func_; const std::unique_ptr captured_init_func_; @@ -436,12 +404,12 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { const std::vector output_shapes_; }; + std::shared_ptr key_func_metadata_ = nullptr; + std::shared_ptr init_func_metadata_ = nullptr; + std::shared_ptr reduce_func_metadata_ = nullptr; + std::shared_ptr finalize_func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList key_func_; - NameAttrList init_func_; - NameAttrList reduce_func_; - NameAttrList finalize_func_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index 49122807b28..1bdb25be51c 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/window_dataset.h" #include "tensorflow/core/lib/random/random.h" @@ -32,9 +33,14 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { public: explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "key_func", /*params=*/{}, + &key_func_metadata_)); + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "reduce_func", /*params=*/{}, + &reduce_func_metadata_)); + OP_REQUIRES_OK( + ctx, FunctionMetadata::Create(ctx, "window_size_func", /*params=*/{}, + &window_size_func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -42,31 +48,31 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { std::unique_ptr captured_key_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx, + OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, key_func_metadata_, "key_func_other_arguments", &captured_key_func)); + std::unique_ptr captured_reduce_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx, + OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, reduce_func_metadata_, "reduce_func_other_arguments", &captured_reduce_func)); + std::unique_ptr captured_window_size_func; OP_REQUIRES_OK(ctx, - CapturedFunction::Create(window_size_func_, ctx, + CapturedFunction::Create(ctx, window_size_func_metadata_, "window_size_func_other_arguments", &captured_window_size_func)); - *output = new Dataset( - ctx, input, key_func_, reduce_func_, window_size_func_, - std::move(captured_key_func), std::move(captured_reduce_func), - std::move(captured_window_size_func), 
output_types_, output_shapes_); + *output = new Dataset(ctx, input, std::move(captured_key_func), + std::move(captured_reduce_func), + std::move(captured_window_size_func), output_types_, + output_shapes_); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& key_func, const NameAttrList& reduce_func, - const NameAttrList& window_size_func, std::unique_ptr captured_key_func, std::unique_ptr captured_reduce_func, std::unique_ptr captured_window_size_func, @@ -74,9 +80,6 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), input_(input), - key_func_(key_func), - reduce_func_(reduce_func), - window_size_func_(window_size_func), captured_key_func_(std::move(captured_key_func)), captured_reduce_func_(std::move(captured_reduce_func)), captured_window_size_func_(std::move(captured_window_size_func)), @@ -108,37 +111,33 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); std::vector key_func_other_arguments_node; DataTypeVector key_func_other_arguments_types; - TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( - ctx, b, captured_key_func_, &key_func_other_arguments_node, - &key_func_other_arguments_types)); + TF_RETURN_IF_ERROR( + captured_key_func_->AddToGraph(ctx, b, &key_func_other_arguments_node, + &key_func_other_arguments_types)); std::vector reduce_func_other_arguments_node; DataTypeVector reduce_func_other_arguments_types; - TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( - ctx, b, captured_reduce_func_, &reduce_func_other_arguments_node, + TF_RETURN_IF_ERROR(captured_reduce_func_->AddToGraph( + ctx, b, &reduce_func_other_arguments_node, &reduce_func_other_arguments_types)); std::vector window_size_func_other_arguments_node; DataTypeVector window_size_func_other_arguments_types; - TF_RETURN_IF_ERROR( - OtherArgumentsNodeAndType(ctx, b, captured_window_size_func_, - &window_size_func_other_arguments_node, - &window_size_func_other_arguments_types)); + TF_RETURN_IF_ERROR(captured_window_size_func_->AddToGraph( + ctx, b, &window_size_func_other_arguments_node, + &window_size_func_other_arguments_types)); AttrValue key_func; - b->BuildAttrValue(key_func_, &key_func); + b->BuildAttrValue(captured_key_func_->func(), &key_func); AttrValue reduce_func; - b->BuildAttrValue(reduce_func_, &reduce_func); + b->BuildAttrValue(captured_reduce_func_->func(), &reduce_func); AttrValue window_size_func; - b->BuildAttrValue(window_size_func_, &window_size_func); + b->BuildAttrValue(captured_window_size_func_->func(), &window_size_func); AttrValue key_func_other_arguments_types_attr; b->BuildAttrValue(key_func_other_arguments_types, @@ -472,8 +471,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { GetDatasetFromVariantTensor(return_values[0], &returned_dataset)); // Create an iterator for the dataset that was returned by `f`. 
- return returned_dataset->MakeIterator(ctx, prefix(), - &current_group_iterator_); + return returned_dataset->MakeIterator( + ctx, strings::StrCat(prefix(), "::Reduce"), + &current_group_iterator_); } mutex mu_; @@ -490,32 +490,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { instantiated_window_size_func_; }; - Status OtherArgumentsNodeAndType( - SerializationContext* ctx, DatasetGraphDefBuilder* b, - const std::unique_ptr& captured_func, - std::vector* other_arguments_node, - DataTypeVector* other_arguments_types) const { - other_arguments_node->reserve(captured_func->captured_inputs().size()); - other_arguments_types->reserve(captured_func->captured_inputs().size()); - for (const Tensor& t : captured_func->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments_node->emplace_back(node); - other_arguments_types->emplace_back(t.dtype()); - } - return Status::OK(); - } - const DatasetBase* const input_; - const NameAttrList key_func_; - const NameAttrList reduce_func_; - const NameAttrList window_size_func_; const std::unique_ptr captured_key_func_; const std::unique_ptr captured_reduce_func_; const std::unique_ptr captured_window_size_func_; @@ -523,11 +498,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { const std::vector output_shapes_; }; + std::shared_ptr key_func_metadata_ = nullptr; + std::shared_ptr reduce_func_metadata_ = nullptr; + std::shared_ptr window_size_func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList key_func_; - NameAttrList reduce_func_; - NameAttrList window_size_func_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset_op.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset_op.cc deleted file mode 100644 index e75e6e4b80b..00000000000 --- a/tensorflow/core/kernels/data/experimental/indexed_dataset_op.cc +++ /dev/null @@ -1,547 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/cleanup.h" - -namespace tensorflow { -namespace data { -namespace { - -// TODO(saeta): Urgh, this is ugly. -class MaterializedIndexedDataset { - public: - virtual ~MaterializedIndexedDataset() = default; - - // Retrieve the element at a given index. The output tensors are stored in - // out_tensors. - // - // If `index` is greater than `Size()`, tensorflow::errors::OutOfRangeError is - // returned.
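For reference, the serialization pattern that the group-by ops above converge on can be condensed into one sketch. This is an illustrative composite, not code from this change: it assumes a dataset with a single captured function stored in a `captured_func_` member, and it reuses only the builder calls that appear in the hunks above (`AddInputDataset`, `AddToGraph`, `BuildAttrValue`, `AddDataset`).

Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b,
                          Node** output) const override {
  Node* input_graph_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
  // One call now registers the function and serializes its captured inputs,
  // replacing the per-op OtherArgumentsNodeAndType helpers deleted above.
  std::vector<Node*> other_arguments;
  DataTypeVector other_arguments_types;
  TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                &other_arguments_types));
  AttrValue f;
  b->BuildAttrValue(captured_func_->func(), &f);
  AttrValue other_arguments_types_attr;
  b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
  TF_RETURN_IF_ERROR(b->AddDataset(
      this, {std::make_pair(0, input_graph_node)},  // single-tensor inputs
      {std::make_pair(1, other_arguments)},         // tensor-list inputs
      {std::make_pair("f", f),
       std::make_pair("Targuments", other_arguments_types_attr)},  // attrs
      output));
  return Status::OK();
}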
- // - // Get is thread-safe. - virtual Status Get(IteratorContext&& ctx, uint64 index, - std::vector* out_tensors) const = 0; - - // Size determines the number of elements in this IndexedDataset. - // - // Size is thread-safe. - virtual Status Size(uint64* size) const = 0; - - // Returns a vector of DataType values, representing the respective - // element types of each tuple component in the outputs of this dataset. - virtual const DataTypeVector& output_dtypes() const = 0; - - // Returns a vector of tensor shapes, representing the respective - // (and possibly partially defined) shapes of each tuple component - // in the outputs of this dataset. - virtual const std::vector& output_shapes() const = 0; -}; - -// IndexedDataset represents a dataset that supports random access in addition -// to iterator-based sequential access. -// -// Note: IndexedDatasets are HIGHLY experimental at this time. Expect -// significant (backwards incompatible) changes! -class IndexedDataset : public DatasetBase { - public: - explicit IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {} - - // Materialize (if necessary) the dataset, and return a pointer. - // TODO(saeta): Add in `IteratorContext* ctx` when materializing. - virtual Status MaterializeDataset( - std::shared_ptr* materialized) = 0; -}; - -// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the -// rest of the TensorFlow runtime. -// -// Most IndexedDataset's will be private members of classes inheriting from this -// class. -class IndexedDatasetOpKernel : public OpKernel { - public: - explicit IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} - void Compute(OpKernelContext* ctx) final; - - protected: - // Subclasses should implement this method. It will be called during Compute - // execution. - virtual void MakeIndexedDataset(OpKernelContext* ctx, - IndexedDataset** output) = 0; - - template - Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece argument_name, T* output) { - const Tensor* argument_t; - TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); - if (!TensorShapeUtils::IsScalar(argument_t->shape())) { - return errors::InvalidArgument(argument_name, " must be a scalar"); - } - *output = argument_t->scalar()(); - return Status::OK(); - } -}; - -class MaterializedDatasetResource : public ResourceBase { - public: - MaterializedDatasetResource( - const DataTypeVector& output_dtypes, - const std::vector& output_shapes) - : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {} - - string DebugString() const override { - return "Materialized IndexedDataset resource"; - } - - Status Get(IteratorContext&& ctx, uint64 index, - std::vector* out_tensors) { - std::shared_ptr captured(materialized_); - if (captured) { - return captured->Get(std::move(ctx), index, out_tensors); - } else { - return errors::FailedPrecondition( - "Get() failed because the MaterializedIndexedDataset has not been " - "initialized. 
Ensure that you have run the materialization operation " - "for this MaterializedIndexedDataset before retrieving elements."); - } - } - - // TODO(saeta): Implement Save and Restore - - const DataTypeVector& output_dtypes() const { return output_dtypes_; } - const std::vector& output_shapes() const { - return output_shapes_; - } - - Status set_materialized_dataset( - const std::shared_ptr& dataset) { - if (dataset) { - TF_RETURN_IF_ERROR( - VerifyTypesMatch(output_dtypes_, dataset->output_dtypes())); - TF_RETURN_IF_ERROR( - VerifyShapesCompatible(output_shapes_, dataset->output_shapes())); - } - materialized_ = dataset; - return Status::OK(); - } - - private: - std::shared_ptr materialized_; - const DataTypeVector output_dtypes_; - const std::vector output_shapes_; -}; - -// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT -// tensor. Objects of the wrapper class own a reference on an instance of an -// `IndexedTensor` and the wrapper's copy constructor and desctructor take care -// of managing the reference count. -// -// NOTE: This is not a feature-complete implementation of the DT_VARIANT -// specification. In particular, we cannot currently serialize an arbitrary -// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not -// implemented. -// -// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just -// use `tensorflow::DatasetVariantWrapper`. -class IndexedDatasetVariantWrapper { - public: - IndexedDatasetVariantWrapper() : dataset_(nullptr) {} - - // Transfers ownership of `dataset` to `*this`. - explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset) - : dataset_(dataset) {} - - IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other) - : dataset_(other.dataset_) { - if (dataset_) dataset_->Ref(); - } - - ~IndexedDatasetVariantWrapper() { - if (dataset_) dataset_->Unref(); - } - - IndexedDataset* get() const { return dataset_; } - - string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; } - string DebugString() const { - if (dataset_) { - return dataset_->DebugString(); - } else { - return ""; - } - } - - void Encode(VariantTensorData* data) const { - LOG(ERROR) << "The Encode() method is not implemented for " - "IndexedDatasetVariantWrapper objects."; - } - - bool Decode(const VariantTensorData& data) { - LOG(ERROR) << "The Decode() method is not implemented for " - "IndexedDatasetVariantWrapper objects."; - return false; - } - - private: - IndexedDataset* const dataset_; // Owns one reference. 
-}; - -Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, - IndexedDataset** out_dataset) { - if (!(tensor.dtype() == DT_VARIANT || - TensorShapeUtils::IsScalar(tensor.shape()))) { - return errors::InvalidArgument( - "IndexedDataset tensor must be a scalar of dtype DT_VARIANT."); - } - const Variant& variant = tensor.scalar()(); - const IndexedDatasetVariantWrapper* wrapper = - variant.get(); - if (wrapper == nullptr) { - return errors::InvalidArgument("Tensor must be an IndexedDataset object."); - } - *out_dataset = wrapper->get(); - if (*out_dataset == nullptr) { - return errors::Internal("Read uninitialized IndexedDataset variant."); - } - return Status::OK(); -} - -Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, - Tensor* tensor) { - if (!(tensor->dtype() == DT_VARIANT || - TensorShapeUtils::IsScalar(tensor->shape()))) { - return errors::InvalidArgument( - "Dataset tensor must be a scalar of dtype DT_VARIANT."); - } - tensor->scalar()() = IndexedDatasetVariantWrapper(dataset); - return Status::OK(); -} - -void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) { - IndexedDataset* dataset = nullptr; - MakeIndexedDataset(ctx, &dataset); - - if (ctx->status().ok()) { - OP_REQUIRES(ctx, dataset != nullptr, - errors::Internal("MakeIndexedDataset did not correctly " - "construct the IndexedDataset")); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output)); - } -} - -class MaterializedHandleOp : public OpKernel { - public: - explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - } - - ~MaterializedHandleOp() override { - if (resource_ != nullptr) { - resource_->Unref(); - if (cinfo_.resource_is_private_to_kernel()) { - if (!cinfo_.resource_manager() - ->template Delete( - cinfo_.container(), cinfo_.name()) - .ok()) { - // Do nothing; the resource can have been deleted by session resets. - // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h - } - } - } - } - - void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { - { - mutex_lock l(mu_); - if (resource_ == nullptr) { - ResourceMgr* mgr = context->resource_manager(); - OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); - - MaterializedDatasetResource* resource; - OP_REQUIRES_OK(context, - mgr->LookupOrCreate( - cinfo_.container(), cinfo_.name(), &resource, - [this](MaterializedDatasetResource** ret) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new MaterializedDatasetResource( - output_dtypes_, output_shapes_); - return Status::OK(); - })); - Status s = VerifyResource(resource); - if (TF_PREDICT_FALSE(!s.ok())) { - resource->Unref(); - context->SetStatus(s); - return; - } - - resource_ = resource; - } - } - OP_REQUIRES_OK(context, MakeResourceHandleToOutput( - context, 0, cinfo_.container(), cinfo_.name(), - MakeTypeIndex())); - } - - private: - // During the first Compute(), resource is either created or looked up using - // shared_name. In the latter case, the resource found should be verified if - // it is compatible with this op's configuration. The verification may fail in - // cases such as two graphs asking queues of the same shared name to have - // inconsistent capacities. 
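The verification described in the comment above leans on VerifyTypesMatch and VerifyShapesCompatible from dataset_utils, the same pair used when a materialized dataset is attached to the resource. A toy illustration of what those checks accept, with dtypes and shapes invented purely for the example (the calls themselves match how this file uses them, and the snippet assumes it sits inside a function returning Status):

// Invented values, for illustration only.
DataTypeVector expected_dtypes = {DT_UINT64};
DataTypeVector received_dtypes = {DT_UINT64};
TF_RETURN_IF_ERROR(VerifyTypesMatch(expected_dtypes, received_dtypes));  // OK: identical dtypes.

std::vector<PartialTensorShape> expected_shapes = {PartialTensorShape({-1})};  // [?]
std::vector<PartialTensorShape> received_shapes = {PartialTensorShape({10})};  // [10]
// OK: an unknown dimension is compatible with any concrete size; a second op
// declaring, say, [7] against a resource built for [10] would fail this check.
TF_RETURN_IF_ERROR(VerifyShapesCompatible(expected_shapes, received_shapes));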
- Status VerifyResource(MaterializedDatasetResource* resource) { - TF_RETURN_IF_ERROR( - VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); - TF_RETURN_IF_ERROR( - VerifyShapesCompatible(output_shapes_, resource->output_shapes())); - return Status::OK(); - } - - mutex mu_; - ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. - MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr; - DataTypeVector output_dtypes_; - std::vector output_shapes_; -}; - -// TODO(saeta): Make async. -class MaterializeDatasetOp : public OpKernel { - public: - explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - IndexedDataset* dataset; - OP_REQUIRES_OK(ctx, - GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset)); - - MaterializedDatasetResource* materialized_resource; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), - &materialized_resource)); - core::ScopedUnref unref(materialized_resource); - std::shared_ptr materialized; - OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized)); - OP_REQUIRES_OK( - ctx, materialized_resource->set_materialized_dataset(materialized)); - } -}; - -// TODO(saeta): Make async -class IndexedDatasetGet : public OpKernel { - public: - explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - MaterializedDatasetResource* materialized_resource; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), - &materialized_resource)); - auto cleanup = gtl::MakeCleanup([materialized_resource] { - materialized_resource->Unref(); // Note: can't use core::ScopedUnref. - }); - - const Tensor* index_t; - OP_REQUIRES_OK(ctx, ctx->input("index", &index_t)); - // TODO(saeta): Support batch reads (indexes should be non-scalar!) - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()), - errors::InvalidArgument("index must be a scalar")); - const uint64 index = index_t->scalar()(); - - std::vector out_tensors; - Status s = - materialized_resource->Get(IteratorContext(ctx), index, &out_tensors); - - // Note: Unref materialized_resource to avoid destruction races. (Important - // in a [future] async op implementation.) - cleanup.release()(); - - if (!s.ok()) { - ctx->SetStatus(s); - } else { - auto expected_shapes = materialized_resource->output_shapes(); - auto expected_types = materialized_resource->output_dtypes(); - for (size_t i = 0; i < out_tensors.size(); ++i) { - OP_REQUIRES( - ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()), - errors::Internal( - "Materialized dataset output at index ", i, - " is incompatible with the expected shape. (Expected: ", - expected_shapes[i], ", got: ", out_tensors[i].shape(), ")")); - OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i], - errors::Internal("Materialized dataset output at index ", i, - " was not the expected dtype. 
(Expected: ", - expected_types[i], - ", got: ", out_tensors[i].dtype(), ")")); - ctx->set_output(i, out_tensors[i]); - } - } - } -}; - -REGISTER_KERNEL_BUILDER( - Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU), - MaterializedHandleOp); -REGISTER_KERNEL_BUILDER( - Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU), - MaterializeDatasetOp); -REGISTER_KERNEL_BUILDER( - Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU), - IndexedDatasetGet); - -class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel { - public: - using IndexedDatasetOpKernel::IndexedDatasetOpKernel; - - void MakeIndexedDataset(OpKernelContext* ctx, - IndexedDataset** output) override { - uint64 size = -1; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "size", &size)); - OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0")); - *output = new Dataset(ctx, size); - } - - class Dataset : public IndexedDataset { - public: - Dataset(OpKernelContext* ctx, uint64 size) - : IndexedDataset(DatasetContext(ctx)), size_(size) {} - - Status MaterializeDataset( - std::shared_ptr* materialized) override { - (*materialized) = std::make_shared(this); - return Status::OK(); - } - - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64}); - return *dtypes; - } - - const std::vector& output_shapes() const override { - static std::vector* shapes = - new std::vector({{}}); - return *shapes; - } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique(Iterator::Params{ - this, strings::StrCat(prefix, "::IdentityIndexedDataset")}); - } - - string DebugString() const override { - return "IdentityIndexedDataset::Dataset"; - } - - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** node) const override { - return errors::Unimplemented( - "identity_indexed_dataset.AsGraphDefInternal"); - } - - private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - if (cur_ < dataset()->size_) { - out_tensors->emplace_back(ctx->allocator({}), DT_UINT64, - TensorShape({})); - out_tensors->back().scalar()() = cur_++; - *end_of_sequence = false; - return Status::OK(); - } - *end_of_sequence = true; - return Status::OK(); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - private: - mutex mu_; - uint64 cur_ GUARDED_BY(mu_); - }; - - class Materialized : public MaterializedIndexedDataset { - public: - explicit Materialized(Dataset* dataset) : dataset_(dataset) { - dataset->Ref(); - } - - ~Materialized() override { - // TODO(saeta): Pull this into MaterializedIndexedDataset - dataset_->Unref(); - } - - const DataTypeVector& output_dtypes() const override { - return dataset_->output_dtypes(); - } - - const std::vector& output_shapes() const override { - return dataset_->output_shapes(); - } - - Status Get(IteratorContext&& ctx, uint64 index, - std::vector* out_tensors) const override { - LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index - << ")"; - if (index >= dataset_->size_) { - // Note: use InvalidArgument instead of OutOfRange error because many - // things consider OutOfRange to be a 
"clean termination" error. - return errors::InvalidArgument( - "Index ", index, - " is out of range for this dataset. (Size is: ", dataset_->size_, - ".)"); - } - out_tensors->emplace_back(ctx.allocator({}), DT_UINT64, - TensorShape({})); - out_tensors->back().scalar()() = index; - return Status::OK(); - } - - Status Size(uint64* size) const override { - *size = dataset_->size_; - return Status::OK(); - } - - private: - const Dataset* const dataset_; // Not owned. - }; - - const uint64 size_; - std::shared_ptr materialized_; - }; -}; - -REGISTER_KERNEL_BUILDER( - Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU), - IdentityIndexedDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index fb7a6204a04..a7472a49e4a 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -48,14 +48,10 @@ constexpr int64 kMaxBatchResults = 16; // description of the following op. class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: - using MapAndBatchIteratorFunction = - std::function, - std::shared_ptr>, StatusCallback)>; - explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "f", /*params=*/{}, + &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); OP_REQUIRES_OK( @@ -65,13 +61,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - int64 batch_size; + int64 batch_size = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size)); OP_REQUIRES( ctx, batch_size > 0, errors::InvalidArgument("batch_size must be greater than zero.")); - int64 num_parallel_calls; + int64 num_parallel_calls = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); OP_REQUIRES( @@ -84,84 +80,36 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ParseScalarArgument(ctx, "drop_remainder", &drop_remainder)); std::unique_ptr captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - &captured_func)); - - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - - MapAndBatchIteratorFunction map_func; - CapturedFunction* raw_captured_func = captured_func.get(); - if (indices.empty()) { - map_func = [](IteratorContext* ctx, - InstantiatedCapturedFunction* instantiated_captured_func, - const string& prefix, std::vector args, - std::shared_ptr> out_tensors, - StatusCallback done) { - instantiated_captured_func->RunAsync( - ctx, std::move(args), out_tensors.get(), std::move(done), prefix); - }; - } else { - std::vector can_move = ComputeMoveVector(indices); - map_func = [raw_captured_func, indices, can_move]( - IteratorContext* ctx, - InstantiatedCapturedFunction* instantiated_captured_func, - const string& prefix, std::vector args, - std::shared_ptr> out_tensors, - StatusCallback done) { - const std::vector& captured_inputs = - raw_captured_func->captured_inputs(); - size_t num_args = args.size(); - for (size_t i = 0; i < indices.size(); ++i) { - if 
(indices[i] < num_args) { - if (can_move[i]) { - out_tensors->push_back(std::move(args[indices[i]])); - } else { - out_tensors->push_back(args[indices[i]]); - } - } else { - out_tensors->push_back(captured_inputs[indices[i] - num_args]); - } - } - // Run the `done` callback on a threadpool thread, because it will - // potentially do a lot of copying work, and we want to run that - // concurrently with the next invocation. - (*ctx->runner())(std::bind(std::move(done), Status::OK())); - }; - } + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func)); if (num_parallel_calls == model::kAutoTune) { metrics::RecordTFDataAutotune(kDatasetName); } - *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls, + *output = new Dataset(ctx, input, batch_size, num_parallel_calls, drop_remainder, output_types_, output_shapes_, - std::move(captured_func), &ctx->eigen_cpu_device(), - std::move(map_func), preserve_cardinality_); + std::move(captured_func), preserve_cardinality_); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, int64 batch_size, + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, int64 num_parallel_calls, bool drop_remainder, const DataTypeVector& output_types, const std::vector& output_shapes, std::unique_ptr captured_func, - const Eigen::ThreadPoolDevice* device, - MapAndBatchIteratorFunction map_func, bool preserve_cardinality) + bool preserve_cardinality) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), batch_size_(batch_size), num_parallel_calls_(num_parallel_calls), drop_remainder_(drop_remainder), output_types_(output_types), output_shapes_(output_shapes), captured_func_(std::move(captured_func)), - device_(device), - map_func_(std::move(map_func)), preserve_cardinality_(preserve_cardinality) { input_->Ref(); } @@ -171,8 +119,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetName)}, - map_func_); + Iterator::Params{this, strings::StrCat(prefix, "::", kDatasetName)}); } const DataTypeVector& output_dtypes() const override { @@ -200,7 +147,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -210,25 +156,12 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); Node* drop_remainder_node; TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node)); - - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); 
- } + DataTypeVector other_arguments_types; + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); AttrValue f; - b->BuildAttrValue(func_, &f); + b->BuildAttrValue(captured_func_->func(), &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); AttrValue preserve_cardinality_attr; @@ -252,19 +185,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params, - MapAndBatchIteratorFunction map_func) + explicit Iterator(const Params& params) : DatasetIterator(params), mu_(std::make_shared()), cond_var_(std::make_shared()), num_parallel_calls_(std::make_shared( params.dataset->num_parallel_calls_, mu_, cond_var_)), - map_func_(std::move(map_func)), max_batch_results_(std::min(kMaxBatchResults, (params.dataset->num_parallel_calls_ + params.dataset->batch_size_ - 1) / - params.dataset->batch_size_)) { - } + params.dataset->batch_size_)) {} ~Iterator() override { mutex_lock l(*mu_); @@ -277,6 +207,15 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } } + string BuildTraceMeName() override { + int64 parallelism; + { + tf_shared_lock l(*mu_); + parallelism = num_parallel_calls_->value; + } + return strings::StrCat(prefix(), "#parallelism=", parallelism, "#"); + } + Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (num_parallel_calls_->value == model::kAutoTune) { @@ -403,7 +342,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { stats_aggregator->AddScalar( stats_utils::ThreadUtilizationScalarName(dataset()->node_name()), static_cast(num_calls_) / - static_cast(num_parallel_calls_->value)); + static_cast(num_parallel_calls_->value), + num_elements()); } cond_var_->notify_all(); } @@ -484,9 +424,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { // Apply the map function on `input_element`, storing the result in // `return_values`, and invoking `done` when finished. - map_func_(ctx.get(), instantiated_captured_func_.get(), prefix(), - std::move(input_element), std::move(return_values), - std::move(done)); + instantiated_captured_func_->RunAsync( + ctx.get(), std::move(input_element), return_values.get(), + std::move(done), prefix()); } Status CopyPartialBatch(Tensor* output, const Tensor& value, @@ -649,7 +589,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { stats_utils::ThreadUtilizationScalarName( dataset()->node_name()), static_cast(num_calls_) / - static_cast(num_parallel_calls_->value)); + static_cast(num_parallel_calls_->value), + num_elements()); } for (const auto& call : new_calls) { CallFunction(ctx, call.first, call.second); @@ -787,7 +728,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const std::shared_ptr cond_var_; // Identifies the maximum number of parallel calls. const std::shared_ptr num_parallel_calls_; - const MapAndBatchIteratorFunction map_func_; // Counts the number of outstanding calls for this batch. 
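One detail in the iterator constructor above that is easy to read past: max_batch_results_ is a capped ceiling division, bounding how many batch results may be buffered at once. A worked example with invented values:

// max_batch_results = min(kMaxBatchResults, ceil(num_parallel_calls / batch_size))
// With num_parallel_calls = 8 and batch_size = 3 (invented values):
//   (8 + 3 - 1) / 3  ->  10 / 3  ->  3   (integer division, i.e. the ceiling)
//   min(16, 3)       ->  3               (kMaxBatchResults = 16 caps the result)
int64 num_parallel_calls = 8;
int64 batch_size = 3;
int64 max_batch_results =
    std::min<int64>(/*kMaxBatchResults=*/16,
                    (num_parallel_calls + batch_size - 1) / batch_size);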
int64 num_calls_ GUARDED_BY(*mu_) = 0; @@ -808,21 +748,18 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; - const NameAttrList func_; const int64 batch_size_; const int64 num_parallel_calls_; const bool drop_remainder_; const DataTypeVector output_types_; const std::vector output_shapes_; const std::unique_ptr captured_func_; - const Eigen::ThreadPoolDevice* device_; // not owned - const MapAndBatchIteratorFunction map_func_; const bool preserve_cardinality_; }; + std::shared_ptr func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList func_; bool preserve_cardinality_; }; diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index 381b9691d14..6a8c9939a33 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -308,7 +308,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { const string child_path = io::JoinPath(current_dir, children[i]); // In case the child_path doesn't start with the fixed_prefix, then // we don't need to explore this path. - if (!str_util::StartsWith(child_path, fixed_prefix)) { + if (!absl::StartsWith(child_path, fixed_prefix)) { children_dir_status[i] = errors::Cancelled("Operation not needed"); } else { diff --git a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc deleted file mode 100644 index ce8a20a783f..00000000000 --- a/tensorflow/core/kernels/data/experimental/numa_map_and_batch_dataset_op.cc +++ /dev/null @@ -1,1160 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#define EIGEN_USE_THREADS - -#include -#include - -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/captured_function.h" -#include "tensorflow/core/kernels/inplace_ops_functor.h" -#include "tensorflow/core/lib/core/blocking_counter.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/platform/numa.h" -#include "tensorflow/core/platform/tracing.h" - -namespace tensorflow { -namespace data { -namespace { - -// kWindowSize is the fixed constant controlling the number of batch outputs -// each NumaWorkerBlock may be processing at a time. 
This is currently a -// constant and not user configurable to enable future performance optimizations -// in the implementation. -const int64 kWindowSize = 10; - -// Define a helper for more consistent logging. -#define WORKER_VLOG(verbose_level) \ - VLOG(verbose_level) << "WorkerThread (" << numa_node << ", " << thread_num \ - << "): " - -// See documentation in ../ops/dataset_ops.cc for a high-level -// description of the following op. - -class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { - public: - explicit NumaMapAndBatchDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - // TODO(saeta): Implement support for preserve_cardinality logic. - OP_REQUIRES_OK( - ctx, ctx->GetAttr("preserve_cardinality", &preserve_cardinality_)); - } - - protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) override { - int64 batch_size; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size)); - OP_REQUIRES( - ctx, batch_size > 0, - errors::InvalidArgument("batch_size must be greater than zero.")); - - int64 num_parallel_calls; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", - &num_parallel_calls)); - OP_REQUIRES( - ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutoTune, - errors::InvalidArgument( - "num_parallel_calls must be greater than zero.")); - - bool drop_remainder; - OP_REQUIRES_OK(ctx, - ParseScalarArgument(ctx, "drop_remainder", &drop_remainder)); - - std::unique_ptr captured_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - /* use_inter_op_parallelism = */ false, - &captured_func)); - - *output = new Dataset(ctx, input, batch_size, num_parallel_calls, - drop_remainder, output_types_, output_shapes_, func_, - std::move(captured_func)); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, - int64 num_parallel_calls, bool drop_remainder, - const DataTypeVector& output_types, - const std::vector& output_shapes, - const NameAttrList& func, - std::unique_ptr captured_func) - : DatasetBase(DatasetContext(ctx)), - input_(input), - batch_size_(batch_size), - num_parallel_calls_(num_parallel_calls), - drop_remainder_(drop_remainder), - output_types_(output_types), - output_shapes_(output_shapes), - func_(func), - captured_func_(std::move(captured_func)) { - input_->Ref(); - } - - ~Dataset() override { input_->Unref(); } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::NumaMapAndBatch")}); - } - - const DataTypeVector& output_dtypes() const override { - return output_types_; - } - - const std::vector& output_shapes() const override { - return output_shapes_; - } - - string DebugString() const override { - return "NumaMapAndBatchDatasetOp::Dataset"; - } - - // TODO(b/120482302): Note that this is inaccurate until - // NumaMapAndBatchMapDataset modified to preserve cardinality. - int64 Cardinality() const override { - int64 n = input_->Cardinality(); - if (n == kInfiniteCardinality || n == kUnknownCardinality) { - return n; - } - return n / batch_size_ + - (n % batch_size_ == 0 || drop_remainder_ ? 
0 : 1); - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - Node* batch_size_node; - TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node)); - Node* num_parallel_calls_node; - TF_RETURN_IF_ERROR( - b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); - Node* drop_remainder_node; - TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node)); - - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); - std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - } - AttrValue f; - b->BuildAttrValue(func_, &f); - AttrValue other_arguments_types_attr; - b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); - - TF_RETURN_IF_ERROR(b->AddDataset( - this, - {std::make_pair(0, input_graph_node), - std::make_pair(2, batch_size_node), - std::make_pair(3, num_parallel_calls_node), - std::make_pair(4, drop_remainder_node)}, // Single tensor inputs. - {std::make_pair(1, other_arguments)}, // Tensor list inputs. - {std::make_pair("f", f), - std::make_pair("Targuments", other_arguments_types_attr)}, // Attrs - output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params), - mu_(std::make_shared()), - autotune_cond_var_(std::make_shared()), - num_parallel_calls_(std::make_shared( - params.dataset->num_parallel_calls_, mu_, autotune_cond_var_)) { - } - - ~Iterator() override { - mutex_lock l(*mu_); - cancelled_ = true; - VLOG(3) << "NumaMapAndBatchIterator::~Iterator: cancelling operations."; - for (size_t i = 0; i < workers_.size(); ++i) { - workers_[i]->manager.Cancel(); - } - VLOG(3) << "NumaMapAndBatchIterator::~Iterator: waiting for threads to " - "shut down."; - } - - Status Initialize(IteratorContext* ctx) override { - mutex_lock l(*mu_); - if (num_parallel_calls_->value == model::kAutoTune) { - num_parallel_calls_->value = ctx->runner_threadpool_size(); - } - TF_RETURN_IF_ERROR( - dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); - TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate( - ctx, &instantiated_captured_func_)); - return Status::OK(); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - auto cleanup = gtl::MakeCleanup( - [] { VLOG(3) << "GetNextInternal call returning."; }); - NumaWorkerBlock* worker = nullptr; - { - mutex_lock l(*mu_); - VLOG(3) << "GetNextInternal call; current block: " << cur_block_; - if (global_end_of_input_) { - *end_of_sequence = true; - return Status::OK(); - } - TF_RETURN_IF_ERROR(EnsureBackgroundThreadsStarted(ctx)); - worker = workers_[cur_block_].get(); - cur_block_ = (cur_block_ + 1) % workers_.size(); - } - bool global_end_of_input_local = false; - Status s = worker->manager.GetBatch(ctx, 
dataset()->drop_remainder_, - &global_end_of_input_local, - out_tensors, end_of_sequence); - if (global_end_of_input_local) { - mutex_lock l(*mu_); - global_end_of_input_ = global_end_of_input_local; - } - return s; - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeAsyncKnownRatioNode( - std::move(args), dataset()->batch_size_, - {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, - /*max=*/ctx->runner_threadpool_size())}); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(*mu_); - for (size_t i = 0; i < workers_.size(); ++i) { - if (!workers_[i]->manager.Quiesce()) { - return errors::Cancelled( - "The iterator was deleted before it could reach a " - "checkpointable state."); - } - } - - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("num_workers"), workers_.size())); - - for (size_t i = 0; i < workers_.size(); ++i) { - size_t index = (cur_block_ + i) % workers_.size(); - TF_RETURN_IF_ERROR(workers_[index]->manager.Save(writer, this, i)); - } - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(*mu_); - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - int64 num_workers = -1; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("num_workers"), &num_workers)); - // Note: num_workers can be 0 if the iterator wasn't started when - // first checkpointed. - if (num_workers < 0) { - return errors::DataLoss( - "When restoring from checkpoint, we encountered a data " - "consistency error: num_workers has an invalid value: ", - num_workers); - } - if (port::NUMAEnabled()) { - int actual_numa_domains = port::NUMANumNodes(); - if (actual_numa_domains != num_workers && num_workers > 0) { - LOG(WARNING) << "# NUMA domains mismatch when restoring from " - "checkpoint: checkpoint has " - << num_workers - << " NUMA domains, while this host has: " - << actual_numa_domains << " NUMA domains."; - } - } - if (num_workers > 1 && !port::NUMAEnabled()) { - LOG(WARNING) << "NUMA is not enabled for this process, but restoring " - "a checkpoint that assumes " - << num_workers << " NUMA domains."; - } - workers_.resize(num_workers); - for (size_t i = 0; i < num_workers; ++i) { - workers_[i] = absl::make_unique(this); - TF_RETURN_IF_ERROR( - workers_[i]->manager.Restore(ctx, reader, this, i)); - } - cur_block_ = 0; - return Status::OK(); - } - - private: - // NumaBlockManager manages all the state for a set of threads pinned to a - // single NUMA domain. - // - // The methods can be divided into 3 categories based on who should call - // them: - // - // (1) RunnerThread: WaitForInputSpace, PushInputs, SetEndOfInput. - // (2) WorkerThread: RetrieveInput, GetBatchTensors. - // RecordBatchEntryComplete - // (3) Client threads: GetBatch, Cancel, Save, Restore. - // - // Internally, we manage state in a circular buffer of size `kWindowSize`. - // There are 3 pointers into the circular buffer, and must maintain the - // following order: (1) next_input_batch_ (corresponding to the next input - // batch to be pulled from the input iterator), (2) next_input_ - // (corresponding to the batch the WorkerThreads should pull from for - // their next inputs), and (3) next_output_ corresponding to the next - // value to be consumed by the output iterator. - // - // Methods return errors::Cancelled if the iteration is cancelled before - // completing. 
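As a reading aid for the deleted manager below, the windowing protocol described in the comment above can be modelled without any TensorFlow machinery: a fixed ring of slots whose states advance kEmpty -> kInputsFilled -> kAllMapsStarted -> kOutputsComplete, with one cursor per role. This toy model only illustrates the invariant; it is not the removed implementation.

#include <array>
#include <cstddef>

enum class SlotState { kEmpty, kInputsFilled, kAllMapsStarted, kOutputsComplete };

struct ToyWindow {
  static constexpr std::size_t kWindowSize = 10;  // mirrors the constant above
  std::array<SlotState, kWindowSize> slots{};     // value-initialized to kEmpty
  std::size_t next_input_batch = 0;  // runner thread: slot being filled with inputs
  std::size_t next_input = 0;        // worker threads: slot whose inputs get mapped
  std::size_t next_output = 0;       // client thread: slot consumed by GetBatch

  bool PushInputs() {  // runner: may only fill an empty slot
    if (slots[next_input_batch] != SlotState::kEmpty) return false;  // window full
    slots[next_input_batch] = SlotState::kInputsFilled;
    next_input_batch = (next_input_batch + 1) % kWindowSize;
    return true;
  }
  void AllMapsStarted() {  // workers: every input of the slot has been handed out
    slots[next_input] = SlotState::kAllMapsStarted;
    next_input = (next_input + 1) % kWindowSize;
  }
  void OutputsComplete(std::size_t slot) {  // workers: all map results copied out
    slots[slot] = SlotState::kOutputsComplete;
  }
  bool GetBatch() {  // client: consuming a completed slot recycles it
    if (slots[next_output] != SlotState::kOutputsComplete) return false;
    slots[next_output] = SlotState::kEmpty;
    next_output = (next_output + 1) % kWindowSize;
    return true;
  }
};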
- // - // NumaBlockManager is thread safe. - class NumaBlockManager { - public: - explicit NumaBlockManager(Iterator* itr) : itr_(itr) {} - - // WaitForInputSpace blocks until there is space in the circular buffer - // to begin processing a new batch of elements. - // - // Returns true when there is space, false if the Iterator is cancelled. - bool WaitForInputSpace(IteratorContext* ctx) { - mutex_lock l(mu_); - - size_t next = (next_input_batch_ + 1) % kWindowSize; - DCHECK(next < kWindowSize) << next; - - // Wait for space in the circular buffer. - while (!cancelled_ && batches_[next].state != BatchState::kEmpty) { - VLOG(3) << "Waiting for input space; next: " << next - << ", next_output_: " << next_output_ - << ", next_input_batch_: " << next_input_batch_; - itr_->RecordStop(ctx); - runner_cond_var_.wait(l); - itr_->RecordStart(ctx); - } - if (cancelled_) { - VLOG(3) << "WaitForInputSpace cancelled."; - return false; - } - - DCHECK(batches_[next].state == BatchState::kEmpty); - - next_input_batch_ = next; - return true; - } - - // PushInputs sets the inputs for the next batch as retrieved from the - // input iterator. - void PushInputs(const Status& status, - std::vector> inputs) { - mutex_lock l(mu_); - - DCHECK(next_input_ < kWindowSize) << next_input_; - DCHECK(batches_[next_input_batch_].state == BatchState::kEmpty); - DCHECK(batches_[next_input_batch_].next_input_to_process == 0) - << batches_[next_input_batch_].next_input_to_process; - DCHECK(batches_[next_input_batch_].status.ok()) - << batches_[next_input_batch_].status; - - batches_[next_input_batch_].inputs.swap(inputs); - batches_[next_input_batch_].state = BatchState::kInputsFilled; - batches_[next_input_batch_].status.Update(status); - if (batches_[next_input_batch_].status.ok()) { - worker_cond_var_.notify_all(); - } else { - client_cond_var_.notify_all(); - batches_[next_input_batch_].error_index = 0; - } - } - - // SetEndOfInput records the fact that we have reached the end of the - // input iterator, and that we should return end_of_sequence = true when - // we have exhaused all buffered batches. - void SetEndOfInput() { - mutex_lock l(mu_); - reached_eof_ = true; - worker_cond_var_.notify_all(); - client_cond_var_.notify_all(); - } - - // RetrieveInput gets the next input tuple to be mapped by a worker - // thread. - // - // Returns true if an input was retrieved, false if the iterator has - // been cancelled. - bool RetrieveInput(IteratorContext* ctx, std::vector* input, - uint64* index, size_t* sequence_number) { - mutex_lock l(mu_); - - // Wait for inputs to be ready. - while (!cancelled_ && - batches_[next_input_].state != BatchState::kInputsFilled) { - itr_->RecordStop(ctx); - worker_cond_var_.wait(l); - itr_->RecordStart(ctx); - } - - if (cancelled_) { - return false; - } - - DCHECK(batches_[next_input_].next_input_to_process < - batches_[next_input_].inputs.size()) - << "next_input_: " << next_input_ << ", next_input_to_process: " - << batches_[next_input_].next_input_to_process - << ", inputs.size(): " << batches_[next_input_].inputs.size() - << ", state: " << static_cast(batches_[next_input_].state) - << ", this: " << this; - *index = batches_[next_input_].next_input_to_process; - *sequence_number = next_input_; - input->swap(batches_[next_input_] - .inputs[batches_[next_input_].next_input_to_process]); - // Increment pointers. 
- batches_[next_input_].next_input_to_process++; - - if (batches_[next_input_].next_input_to_process == - batches_[next_input_].inputs.size()) { - batches_[next_input_].state = BatchState::kAllMapsStarted; - next_input_ = (next_input_ + 1) % kWindowSize; - } - return true; - } - - // GetBatchTensors returns a pointer to the output batch tensors for the - // worker thread to copy into. - // - // allocate_output is a function taking a batch size, and a pointer to - // the output tuple of Tensors to allocate them. The allocate_output - // function is called at most once per output batch. - std::vector* GetBatchTensors( - size_t sequence_number, - std::function*)> allocate_output) { - mutex_lock l(mu_); - DCHECK(sequence_number < kWindowSize) << sequence_number; - DCHECK(batches_[sequence_number].state == BatchState::kInputsFilled || - batches_[sequence_number].state == BatchState::kAllMapsStarted) - << sequence_number; - - if (batches_[sequence_number].outputs.empty()) { - allocate_output(batches_[sequence_number].inputs.size(), - &batches_[sequence_number].outputs); - } - return &batches_[sequence_number].outputs; - } - - // RecordBatchEntryComplete records an element of the batch has finished - // copying into the output tensors. - void RecordBatchEntryComplete(size_t sequence_number, uint64 index, - Status s) { - mutex_lock l(mu_); - DCHECK(sequence_number < kWindowSize) << sequence_number; - DCHECK(batches_[sequence_number].state == BatchState::kInputsFilled || - batches_[sequence_number].state == BatchState::kAllMapsStarted) - << sequence_number; - - batches_[sequence_number].num_outputs_complete++; - if (!s.ok() && batches_[sequence_number].error_index > index) { - batches_[sequence_number].status = s; - batches_[sequence_number].error_index = index; - } - - if (batches_[sequence_number].num_outputs_complete == - batches_[sequence_number].inputs.size()) { - DCHECK(batches_[sequence_number].state == - BatchState::kAllMapsStarted); - batches_[sequence_number].state = BatchState::kOutputsComplete; - batches_[sequence_number].inputs.clear(); // Eagerly save memory. - batches_[sequence_number].inputs.shrink_to_fit(); - client_cond_var_.notify_all(); - } - } - - // GetBatch retrieves the next output batch tensors. - Status GetBatch(IteratorContext* ctx, bool drop_remainder, - bool* global_eof, std::vector* out_tensor, - bool* end_of_sequence) { - mutex_lock l(mu_); - // Wait until one of 3 conditions occurs: - // (1) we're cancelled. - // (2) the state becomes kOutputsComplete - // (3) state is empty && reached_eof. - while (!cancelled_ && - batches_[next_output_].state != BatchState::kOutputsComplete && - !(reached_eof_ && - batches_[next_output_].state == BatchState::kEmpty)) { - VLOG(3) << "Waiting in GetBatch."; - itr_->RecordStop(ctx); - client_cond_var_.wait(l); - itr_->RecordStart(ctx); - } - - if (cancelled_) { - return errors::Cancelled( - "Cancelled in NumaMapAndBatch::GetNext call."); - } - - if (reached_eof_ && - batches_[next_output_].state == BatchState::kEmpty) { - VLOG(4) << "GetBatch returning end of sequence."; - *end_of_sequence = true; - *global_eof = true; - return Status::OK(); - } - - VLOG(3) << "Returning output index: " << next_output_ - << ", this: " << this; - - *end_of_sequence = false; - Status s = batches_[next_output_].status; - if (s.ok()) { - out_tensor->swap(batches_[next_output_].outputs); - } - // Handle early termination. 
- if (errors::IsOutOfRange(s)) { - *global_eof = true; - s = Status::OK(); - if (drop_remainder || batches_[next_output_].error_index == 0) { - *end_of_sequence = true; - } else { - std::vector true_outputs; - for (size_t i = 0; i < batches_[next_output_].outputs.size(); - ++i) { - TensorShape component_shape( - batches_[next_output_].outputs[i].shape()); - component_shape.set_dim(0, batches_[next_output_].error_index); - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - true_outputs.emplace_back( - ctx->allocator(attr), - batches_[next_output_].outputs[i].dtype(), component_shape); - TF_RETURN_IF_ERROR(CopyPartialBatch( - &true_outputs.back(), batches_[next_output_].outputs[i], - batches_[next_output_].error_index)); - } - out_tensor->swap(true_outputs); - } - } - - batches_[next_output_].Reset(); - next_output_ = (next_output_ + 1) % kWindowSize; - runner_cond_var_.notify_all(); - - return s; - } - - void Cancel() { - mutex_lock l(mu_); - VLOG(3) << "Cancelling NUMA block."; - cancelled_ = true; - runner_cond_var_.notify_all(); - worker_cond_var_.notify_all(); - client_cond_var_.notify_all(); - } - - // Waits until all the worker threads have completed their work and all - // internal state has reached a "safe-point" where we can safely - // checkpoint. - // - // Returns true if completed successfully, false if cancelled while - // waiting. - bool Quiesce() { - mutex_lock l(mu_); - VLOG(3) << "Waiting until the operations have quiesced."; - while (!cancelled_ && !AllMapOperationsFinished()) { - client_cond_var_.wait(l); - } - if (cancelled_) { - return false; - } - return true; - } - - Status Save(IteratorStateWriter* writer, Iterator* itr, size_t index) { - mutex_lock l(mu_); - string prefix = itr->full_name(strings::StrCat("numa_block_", index)); - if (reached_eof_) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - strings::StrCat(prefix, "_end_of_input"), "")); - } - for (size_t i = 0; i < kWindowSize; ++i) { - size_t index = (next_output_ + i) % kWindowSize; - if (batches_[index].state == BatchState::kEmpty) { - break; - } - string batch_prefix = strings::StrCat(prefix, "_batch_", i); - TF_RETURN_IF_ERROR(writer->WriteScalar( - strings::StrCat(batch_prefix, "_code"), - static_cast(batches_[index].status.code()))); - if (!batches_[index].status.ok()) { - TF_RETURN_IF_ERROR( - writer->WriteScalar(strings::StrCat(batch_prefix, "_msg"), - batches_[index].status.error_message())); - TF_RETURN_IF_ERROR(writer->WriteScalar( - strings::StrCat(batch_prefix, "_error_index"), - batches_[index].error_index)); - } - - TF_RETURN_IF_ERROR(writer->WriteScalar( - strings::StrCat(batch_prefix, "_output_size"), - batches_[index].outputs.size())); - for (size_t j = 0; j < batches_[index].outputs.size(); ++j) { - string tensor_prefix = - strings::StrCat(batch_prefix, "_output_", j); - if (!batches_[index].status.ok()) { - DCHECK(batches_[index].error_index >= 0 && - batches_[index].error_index < - itr_->dataset()->batch_size_); - // If the batch is not full, we only store the first - // `error_index` values. The rest of the batch tensor might not - // be initialized, and accessing that will raise msan errors. 
- TF_RETURN_IF_ERROR(writer->WriteTensor( - tensor_prefix, batches_[index].outputs[j].Slice( - 0, batches_[index].error_index))); - } else { - TF_RETURN_IF_ERROR(writer->WriteTensor( - tensor_prefix, batches_[index].outputs[j])); - } - } - } - return Status::OK(); - } - - Status Restore(IteratorContext* ctx, IteratorStateReader* reader, - Iterator* itr, size_t index) { - mutex_lock l(mu_); - if (reached_eof_) { - return errors::FailedPrecondition( - "Already reached the end of the sequence."); - } - string prefix = itr->full_name(strings::StrCat("numa_block_", index)); - reached_eof_ = - reader->Contains(strings::StrCat(prefix, "_end_of_input")); - for (size_t i = 0; i < kWindowSize; ++i) { - string batch_prefix = strings::StrCat(prefix, "_batch_", i); - if (!reader->Contains(strings::StrCat(batch_prefix, "_code"))) { - break; - } - Batch batch; - batch.state = BatchState::kOutputsComplete; - int64 code_int; - TF_RETURN_IF_ERROR(reader->ReadScalar( - strings::StrCat(batch_prefix, "_code"), &code_int)); - error::Code code = static_cast(code_int); - if (code != error::Code::OK) { - string error_message; - TF_RETURN_IF_ERROR(reader->ReadScalar( - strings::StrCat(batch_prefix, "_msg"), &error_message)); - batch.status = Status(code, error_message); - int64 error_index_int = -1; - TF_RETURN_IF_ERROR(reader->ReadScalar( - strings::StrCat(batch_prefix, "_error_index"), - &error_index_int)); - if (error_index_int < 0 || - error_index_int > itr->dataset()->batch_size_) { - return errors::FailedPrecondition( - "Error index out of bounds when restoring from checkpoint; " - "error index: ", - error_index_int); - } - batch.error_index = static_cast(error_index_int); - } - int64 output_size = -1; - TF_RETURN_IF_ERROR(reader->ReadScalar( - strings::StrCat(batch_prefix, "_output_size"), &output_size)); - batch.outputs.reserve(output_size); - for (size_t j = 0; j < output_size; ++j) { - string tensor_name = strings::StrCat(batch_prefix, "_output_", j); - Tensor t; - TF_RETURN_IF_ERROR(reader->ReadTensor(tensor_name, &t)); - batch.outputs.emplace_back(std::move(t)); - } - batches_[i] = std::move(batch); - } - return Status::OK(); - } - - private: - bool AllMapOperationsFinished() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - for (size_t i = 0; i < kWindowSize; ++i) { - if (batches_[i].state == BatchState::kInputsFilled || - batches_[i].state == BatchState::kAllMapsStarted) { - return false; - } - if (batches_[i].state != BatchState::kOutputsComplete && - !reached_eof_) { - return false; - } - } - return true; - } - - // Batches begin in the `kEmpty` state. Once the RunnerThread has - // filled the `inputs` to a `Batch`, it transitions to the - // `kInputsFilled` state. At this point, the Worker threads run the map - // function and copy the outputs appropriately. Once all worker threads - // have started, it transitions to `kAllMapsStarted`. After the outputs - // are complete, the GetNext call can consume the outputs, and return - // the batch to the kEmpty state. - enum class BatchState { - kEmpty, - kInputsFilled, - kAllMapsStarted, - kOutputsComplete, - }; - - // Batch captures all the state of an output batch as it progresses - // through the machinery. Once the RunnerThread fills inputs, it - // transitions to `kInputsFilled`. At this point, the worker threads can - // work on it, incrementing outputs_complete for every element of the - // input set that is copied into the output Tensors. Once all the input - // tuples have been processed (i.e. 
num_outputs_complete == - // inputs.size()), it transitions to the `kOutputsComplete` stage, where - // it is ready to be returned by a `GetBatch` call (called from - // `GetNextInternal`). - struct Batch { - BatchState state; - // Aggregates the Status of the input iterator's GetNext - // calls, in addition to the Status of the map function invocations. - // - // In the case where multiple non-OK statuses are encountered, we - // return the first one encountered. - Status status; - // In order to return the correct error status, we keep track of the - // error_index. - size_t error_index; - // The batch_size input tuples (or fewer in the case of the last - // batch). - // TODO(saeta): Avoid re-allocating vectors all the time! - std::vector> inputs; - std::vector outputs; - size_t next_input_to_process; - size_t num_outputs_complete; - - Batch() { Reset(); } - - // Resets the Batch state (e.g. after consuming the outputs). - void Reset() { - state = BatchState::kEmpty; - status = Status::OK(); - inputs.clear(); - inputs.shrink_to_fit(); - outputs.clear(); - outputs.shrink_to_fit(); - next_input_to_process = 0; - num_outputs_complete = 0; - error_index = -1; - } - }; - - Iterator* itr_; // Not owned. - mutex mu_; - Batch batches_[kWindowSize] GUARDED_BY(mu_); - size_t next_input_batch_ GUARDED_BY(mu_) = -1; - size_t next_input_ GUARDED_BY(mu_) = 0; - size_t next_output_ GUARDED_BY(mu_) = 0; - bool cancelled_ GUARDED_BY(mu_) = false; - bool reached_eof_ GUARDED_BY(mu_) = false; - - // The runner thread waits on this condition variable for space to be - // available. When the client thread takes a value out of the circular - // buffer, it notifies this condition variable that space is now - // available. - condition_variable runner_cond_var_ GUARDED_BY(mu_); - // The worker threads wait on this condition variable for available - // inputs. When the runner thread makes new inputs available, it - // notifies this condition variable. - condition_variable worker_cond_var_ GUARDED_BY(mu_); - // The client threads wait on this condition variable for available - // batched outputs. When worker threads complete a batch, they notify - // this condition variable. - condition_variable client_cond_var_ GUARDED_BY(mu_); - }; - // Mark NumaBlockManager as a friend of Iterator in order to call - // protected Iterator methods during checkpointing. - friend NumaBlockManager; - - struct NumaWorkerBlock { - NumaBlockManager manager; - // TODO(saeta): Migrate to BackgroundWorker. 
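// Illustrative aside, not part of the original change: a minimal sketch of
// the batch lifecycle the comments above describe. The enum mirrors the
// removed code; the transition helper is hypothetical and only spells out
// the legal order of states.
enum class BatchState { kEmpty, kInputsFilled, kAllMapsStarted, kOutputsComplete };

BatchState NextBatchState(BatchState s) {
  switch (s) {
    case BatchState::kEmpty:            // runner thread has filled `inputs`
      return BatchState::kInputsFilled;
    case BatchState::kInputsFilled:     // every worker has picked up its element
      return BatchState::kAllMapsStarted;
    case BatchState::kAllMapsStarted:   // all outputs copied into the batch
      return BatchState::kOutputsComplete;
    case BatchState::kOutputsComplete:  // consumer took the batch; Reset() follows
      return BatchState::kEmpty;
  }
  return s;
}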
- std::vector> threads; - - explicit NumaWorkerBlock(Iterator* itr) : manager(itr) {} - }; - - static void CustomNumaWorkerBlockDeleter(NumaWorkerBlock* ptr) { - ptr->~NumaWorkerBlock(); - port::NUMAFree(ptr, sizeof(NumaWorkerBlock)); - } - static void DefaultNumaWorkerBlockDeleter(NumaWorkerBlock* ptr) { - delete ptr; - } - - static Status CopyPartialBatch(Tensor* output, const Tensor& value, - int64 num_elements) { - switch (value.dtype()) { -#define HANDLE_TYPE(type) \ - case DataTypeToEnum::value: { \ - auto output_t = output->flat_outer_dims(); \ - auto value_t = value.flat_outer_dims(); \ - for (size_t i = 0; i < num_elements; i++) { \ - output_t.template chip<0>(i) = value_t.template chip<0>(i); \ - } \ - return Status::OK(); \ - } - TF_CALL_DATASET_TYPES(HANDLE_TYPE); -#undef HANDLE_TYPE - default: - return errors::InvalidArgument("Unsupported data type: ", - DataTypeString(value.dtype())); - } - return Status::OK(); - } - - Status EnsureBackgroundThreadsStarted(IteratorContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(*mu_) { - if (curr_num_parallel_calls_ >= num_parallel_calls_->value) { - // All necessary threads have been started. - curr_num_parallel_calls_ = num_parallel_calls_->value; - return Status::OK(); - } - - VLOG(4) << "Starting workers"; - bool numa_enabled = port::NUMAEnabled(); - - if (!numa_enabled) { - LOG(INFO) << "NUMA not enabled on this host."; - } - - int num_numa_nodes = port::NUMANumNodes(); - if (num_numa_nodes < 1) { - return errors::Internal("The number of NUMA nodes is invalid: ", - num_numa_nodes); - } - - // Only resize when empty to support restoring from checkpoints. - if (workers_.empty()) { - VLOG(3) << "# NUMA Nodes: " << num_numa_nodes - << ", # Parallel Calls: " << num_parallel_calls_->value; - workers_.resize(num_numa_nodes); - } else { - num_numa_nodes = workers_.size(); - } - - // Round up num_parallel_calls, with a minimum of 1. - const size_t num_threads_per_block = - std::max(1LL, (num_parallel_calls_->value + num_numa_nodes - 1) / - num_numa_nodes); - - VLOG(3) << "Starting " << num_threads_per_block * num_numa_nodes - << " worker threads, with " << num_threads_per_block - << " threads per block."; - - // Only allocate new_ctx if required. - std::shared_ptr new_ctx; - - for (int i = 0; i < num_numa_nodes; ++i) { - if (!workers_[i]) { - if (numa_enabled) { - // Allocate in appropriate NUMA domain. - // 4k page align. - void* ptr = port::NUMAMalloc(i, sizeof(NumaWorkerBlock), 0); - if (ptr != nullptr) { - NumaWorkerBlock* block = new (ptr) NumaWorkerBlock(this); - workers_[i] = - std::unique_ptr>( - block, CustomNumaWorkerBlockDeleter); - } else { - LOG(ERROR) << "Could not NUMA-allocate worker block: " << i; - } - } - // If the NUMA allocation fails, or NUMA is not enabled. - if (!workers_[i]) { - workers_[i] = - std::unique_ptr>( - new NumaWorkerBlock(this), DefaultNumaWorkerBlockDeleter); - } - } - // Be sure to start threads if num_parallel_calls_ has changed. 
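// Illustrative aside, not part of the original change: roughly what the
// HANDLE_TYPE macro in CopyPartialBatch above generates for one concrete
// dtype (float here), written as a standalone hypothetical helper with the
// template arguments spelled out. It copies only the first `num_elements`
// rows of `value` into `output`.
Status CopyPartialBatchFloat(Tensor* output, const Tensor& value,
                             int64 num_elements) {
  auto output_t = output->flat_outer_dims<float>();  // [batch, rest flattened]
  auto value_t = value.flat_outer_dims<float>();
  for (int64 i = 0; i < num_elements; i++) {
    output_t.chip<0>(i) = value_t.chip<0>(i);  // copy row i
  }
  return Status::OK();
}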
- for (size_t j = workers_[i]->threads.size(); - j < num_threads_per_block; ++j) { - VLOG(3) << "Starting worker " << i << ", " << j; - if (!new_ctx) { - new_ctx = std::make_shared(*ctx); - } - workers_[i]->threads.emplace_back(ctx->StartThread( - strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j), - [this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); })); - VLOG(3) << "Worker " << i << ", " << j << " successfully started."; - } - } - if (!runner_thread_) { - if (!new_ctx) { - new_ctx = std::make_shared(*ctx); - } - runner_thread_ = - ctx->StartThread("tf_data_numa_map_and_batch", - [this, new_ctx] { RunnerThread(new_ctx); }); - } - VLOG(3) << "All workers & runner thread started."; - return Status::OK(); - } - - void AllocateOutput(IteratorContext* ctx, size_t batch_size, - const std::vector& map_fn_outputs, - std::vector* batch_outputs) { - DCHECK(dataset()->output_dtypes().size() == - dataset()->output_shapes().size()); - DCHECK(map_fn_outputs.size() == dataset()->output_dtypes().size()); - for (size_t i = 0; i < dataset()->output_dtypes().size(); ++i) { - TensorShape component_shape({static_cast(batch_size)}); - component_shape.AppendShape(map_fn_outputs.at(i).shape()); - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - batch_outputs->emplace_back(ctx->allocator(attr), - map_fn_outputs.at(i).dtype(), - component_shape); - } - } - - void RunnerThread(std::shared_ptr ctx) - LOCKS_EXCLUDED(mu_) { - RecordStart(ctx.get()); - auto cleanup = gtl::MakeCleanup([this, &ctx] { - // Set end of input on all the managers in order to clean up in an - // orderly fashion. - VLOG(3) << "Setting End of Input on workers_[*]->manager"; - for (size_t i = 0; i < workers_.size(); ++i) { - workers_[i]->manager.SetEndOfInput(); - } - RecordStop(ctx.get()); - }); - - const size_t num_blocks = workers_.size(); - - while (true) { - for (size_t block = 0; block < num_blocks; ++block) { - VLOG(4) << "RunnerThread waiting for input space in block: " - << block; - if (TF_PREDICT_FALSE( - !workers_[block]->manager.WaitForInputSpace(ctx.get()))) { - VLOG(3) << "RunnerThread exiting due to cancellation."; - return; - } - VLOG(4) << "RunnerThread has space; pulling on upstream for block " - << block; - - Status s; - std::vector> inputs; - bool end_of_sequence = false; - for (size_t i = 0; i < dataset()->batch_size_; ++i) { - std::vector tuple; - s.Update( - input_impl_->GetNext(ctx.get(), &tuple, &end_of_sequence)); - if (!s.ok()) { - break; - } - if (end_of_sequence) { - VLOG(4) << "Runner thread encountered end of sequence."; - if (dataset()->drop_remainder_) { - return; - } - break; - } - inputs.push_back(std::move(tuple)); - } - - VLOG(4) << "Moving inputs to block " << block - << ", which has size: " << inputs.size(); - if (!s.ok() || !inputs.empty()) { - workers_[block]->manager.PushInputs(s, std::move(inputs)); - VLOG(4) << "Inputs moved into block " << block; - } - if (end_of_sequence) { - return; - } - } - } - } - - void WorkerThread(std::shared_ptr ctx, - const int numa_node, const int thread_num) { - RecordStart(ctx.get()); - WORKER_VLOG(3) << "started."; - auto stop_cleanup = - gtl::MakeCleanup([this, numa_node, thread_num, &ctx]() { - RecordStop(ctx.get()); - WORKER_VLOG(3) << "exiting."; - }); - - NumaWorkerBlock* block = workers_[numa_node].get(); - port::NUMASetThreadNodeAffinity(numa_node); - const int num_numa_nodes = port::NUMANumNodes(); - const int minimum_num_parallel_calls = thread_num * num_numa_nodes; - - while (true) { - // Put threads to sleep based on autotuner. 
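// Illustrative aside, not part of the original change: the thread-count
// rounding used by EnsureBackgroundThreadsStarted above, pulled out as a
// hypothetical helper. Each NUMA block gets ceil(num_parallel_calls /
// num_numa_nodes) worker threads, but never fewer than one.
int64 ThreadsPerBlock(int64 num_parallel_calls, int64 num_numa_nodes) {
  return std::max(int64{1},
                  (num_parallel_calls + num_numa_nodes - 1) / num_numa_nodes);
}
// e.g. 5 parallel calls across 2 nodes -> (5 + 1) / 2 = 3 threads per node
// (6 worker threads in total); 1 parallel call across 4 nodes -> 1 per node.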
- { - mutex_lock l(*mu_); - while (minimum_num_parallel_calls >= num_parallel_calls_->value && - !cancelled_) { - RecordStop(ctx.get()); - autotune_cond_var_->wait(l); - RecordStart(ctx.get()); - } - if (cancelled_) { - return; - } - } - - std::vector input; - uint64 index = 0; - size_t sequence_number = 0; - WORKER_VLOG(4) << "retrieving input."; - { - tracing::ScopedActivity trace( - "NumaMapAndBatch::Iterator::Worker::RetrieveInput"); - if (!block->manager.RetrieveInput(ctx.get(), &input, &index, - &sequence_number)) { - return; - } - } - - WORKER_VLOG(4) << "retrieved input; index: " << index - << ", sequence_number: " << sequence_number; - - std::vector return_values; - Status s; - { - tracing::ScopedActivity trace( - "NumaMapAndBatch::Iterator::Worker::FunctionExecution"); - s = instantiated_captured_func_->Run(ctx.get(), std::move(input), - &return_values); - } - WORKER_VLOG(4) << "ran function for index: " << index - << ", sequence_number: " << sequence_number; - - if (s.ok()) { - std::vector* output = block->manager.GetBatchTensors( - sequence_number, - [this, ctx, &return_values](size_t batch_size, - std::vector* output) { - AllocateOutput(ctx.get(), batch_size, return_values, output); - }); - WORKER_VLOG(4) << "copying tensors to batch output."; - { - tracing::ScopedActivity trace( - "NumaMapAndBatch::Iterator::Worker::BatchCopy"); - for (size_t i = 0; i < return_values.size() && s.ok(); ++i) { - Tensor& tensor = return_values.at(i); - Tensor* batch = &output->at(i); - if (tensor.NumElements() != - (batch->NumElements() / batch->dim_size(0))) { - s.Update(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does " - "not match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch->shape().DebugString())); - break; - } - s.Update(batch_util::CopyElementToSlice(std::move(tensor), - batch, index)); - } - } - } - - block->manager.RecordBatchEntryComplete(sequence_number, index, s); - WORKER_VLOG(4) << "finished index: " << index - << ", sequence_number: " << sequence_number; - } - } - - // mu_ protects shared internal state and is used to coordinate between - // the auto-tuner, client threads, worker threads, and the runner thread. - const std::shared_ptr mu_; - const std::shared_ptr autotune_cond_var_; - // The maximum number of parallel calls (can be auto-tuned). - const std::shared_ptr num_parallel_calls_; - std::unique_ptr instantiated_captured_func_; - - // Caches the last-seen value of num_parallel_calls_->value to - // short-circuit starting workers. - int64 curr_num_parallel_calls_ GUARDED_BY(*mu_) = 0; - - std::unique_ptr input_impl_; - int64 cur_block_ GUARDED_BY(*mu_) = 0; - bool global_end_of_input_ GUARDED_BY(*mu_) = false; - bool cancelled_ GUARDED_BY(*mu_) = false; - std::vector>> - workers_; // Const after initialization. 
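// Illustrative aside, not part of the original change: the autotune gate at
// the top of WorkerThread above, restated as a hypothetical predicate. A
// worker with index `thread_num` on any node runs only while
// thread_num * num_numa_nodes < num_parallel_calls, so each time the
// parallelism value crosses another multiple of num_numa_nodes, one more
// worker per node is woken.
bool WorkerIsActive(int thread_num, int num_numa_nodes,
                    int64 num_parallel_calls) {
  return static_cast<int64>(thread_num) * num_numa_nodes < num_parallel_calls;
}
// e.g. with 2 NUMA nodes and num_parallel_calls == 3, workers 0 and 1 on each
// node are active (thresholds 0 and 2) while worker 2 (threshold 4) sleeps.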
- std::unique_ptr runner_thread_ GUARDED_BY(*mu_); - }; - - const DatasetBase* const input_; - const int64 batch_size_; - const int64 num_parallel_calls_; - const bool drop_remainder_; - const DataTypeVector output_types_; - const std::vector output_shapes_; - const NameAttrList func_; - const std::unique_ptr captured_func_; - }; - - DataTypeVector output_types_; - std::vector output_shapes_; - NameAttrList func_; - bool preserve_cardinality_; -}; - -REGISTER_KERNEL_BUILDER( - Name("ExperimentalNumaMapAndBatchDataset").Device(DEVICE_CPU), - NumaMapAndBatchDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index 54c1d839e60..d89518eefc4 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -38,7 +38,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "f", /*params=*/{}, + &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -76,12 +77,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr captured_func; OP_REQUIRES_OK( - ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments", + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", &captured_func)); *output = - new Dataset(ctx, input, interleave_func_, std::move(captured_func), - cycle_length, block_length, sloppy, buffer_output_elements, + new Dataset(ctx, input, std::move(captured_func), cycle_length, + block_length, sloppy, buffer_output_elements, prefetch_input_elements, output_types_, output_shapes_); } @@ -89,14 +90,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, std::unique_ptr captured_func, int64 cycle_length, int64 block_length, bool sloppy, int64 buffer_output_elements, int64 prefetch_input_elements, const DataTypeVector& output_types, const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), input_(input), - interleave_func_(func), captured_func_(std::move(captured_func)), cycle_length_(cycle_length), block_length_(block_length), @@ -132,7 +131,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; @@ -147,24 +145,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { Node* prefetch_input_elements_node; TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node)); - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : 
captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - } + DataTypeVector other_arguments_types; + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); AttrValue f; - b->BuildAttrValue(interleave_func_, &f); + b->BuildAttrValue(captured_func_->func(), &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -1065,7 +1051,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; - const NameAttrList interleave_func_; const std::unique_ptr captured_func_; const int64 cycle_length_; const int64 block_length_; @@ -1076,9 +1061,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const std::vector output_shapes_; }; + std::shared_ptr func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList interleave_func_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc index c207cf7ae4f..3dbb4df8ada 100644 --- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc @@ -72,7 +72,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - int64 num_parallel_calls; + int64 num_parallel_calls = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); OP_REQUIRES(ctx, num_parallel_calls > 0, @@ -265,9 +265,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { void MapFunc(IteratorContext* ctx, const string& prefix, std::vector input, std::vector* output, StatusCallback callback) override { - (*ctx->runner())([this, ctx, input, output, callback]() { + (*ctx->runner())([this, ctx, prefix, input, output, callback]() { thread::ThreadPool* device_threadpool = - ctx->lib()->device()->tensorflow_cpu_worker_threads()->workers; + ctx->flr()->device()->tensorflow_cpu_worker_threads()->workers; std::vector slice_vec; for (const Tensor& t : input) { auto serialized_t = t.flat(); @@ -341,19 +341,22 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { example_result.feature_stats.size()); for (example::PerExampleFeatureStats feature_stats : example_result.feature_stats) { - stats_aggregator->AddToHistogram( - stats_utils::FeatureHistogramName(dataset_->node_name()), - {static_cast(feature_stats.features_count)}); stats_aggregator->IncrementCounter( stats_utils::kFeaturesCount, "trainer", feature_stats.features_count); stats_aggregator->IncrementCounter( stats_utils::kFeatureValuesCount, "trainer", feature_stats.feature_values_count); + int64 steps = ctx->model()->NumElements(prefix); + stats_aggregator->AddToHistogram( + stats_utils::FeatureHistogramName(dataset_->node_name()), + {static_cast(feature_stats.features_count)}, steps); + stats_aggregator->AddToHistogram( stats_utils::FeatureValueHistogramName( dataset_->node_name()), - {static_cast(feature_stats.feature_values_count)}); + {static_cast(feature_stats.feature_values_count)}, + 
steps); } } } diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc index 0397ca01c4e..3078005a00d 100644 --- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace data { @@ -24,11 +25,7 @@ constexpr char kOptimizerName[] = "tf_data_rebatcher"; class RebatchDatasetOp : public UnaryDatasetOpKernel { public: explicit RebatchDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - } + : UnaryDatasetOpKernel(ctx) {} protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, @@ -39,58 +36,32 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { ctx, num_workers > 0, errors::InvalidArgument("num_workers must be greater than zero.")); - Dataset* dataset = - new Dataset(ctx, input, num_workers, output_types_, output_shapes_); - Status s = dataset->Optimize(ctx); - if (s.ok()) { - *output = dataset; - } else { - dataset->Unref(); - OP_REQUIRES_OK(ctx, s); - } + auto config_factory = [num_workers]() { return CreateConfig(num_workers); }; + + // We only want to optimize functions for some particular datasets like + // FlatMapDataset, InterleaveDataset etc. So we disable generalized + // function optimization and explicitly handle function modifications + // for those datasets in the rewrite. + OP_REQUIRES_OK(ctx, + RewriteDataset(ctx, input, std::move(config_factory), + /*optimize_function_library=*/false, output)); } private: - class Dataset : public GraphRewriteDataset { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const int64 num_workers, const DataTypeVector& output_types, - const std::vector& output_shapes) - : GraphRewriteDataset(ctx, input, output_types, output_shapes), - num_workers_(num_workers) {} - - string DebugString() const override { return "RebatchDatasetOp::Dataset"; } - - private: - bool ShouldOptimizeFunctions() override { - // We only want to optimize functions for some particular datasets like - // FlatMapDataset, InterleaveDataset etc. So we disable generalized - // function optimization and explicitly handle function modifications - // for those datasets in the rewrite. 
- return false; - } - - RewriterConfig CreateGrapplerRewriteConfig() override { - RewriterConfig rewriter_config; - rewriter_config.set_fail_on_optimizer_errors(true); - rewriter_config.add_optimizers(kOptimizerName); - rewriter_config.set_meta_optimizer_iterations( - RewriterConfig_NumIterationsType_ONE); - auto custom_optimizer = rewriter_config.add_custom_optimizers(); - custom_optimizer->set_name(kOptimizerName); - AttrValue num_workers_attr; - num_workers_attr.set_i(num_workers_); - (*custom_optimizer->mutable_parameter_map())["num_workers"] = - num_workers_attr; - return rewriter_config; - } - - const int64 num_workers_; - }; - - const int graph_def_version_; - DataTypeVector output_types_; - std::vector output_shapes_; + static RewriterConfig CreateConfig(int64 num_workers) { + RewriterConfig rewriter_config; + rewriter_config.set_fail_on_optimizer_errors(true); + rewriter_config.add_optimizers(kOptimizerName); + rewriter_config.set_meta_optimizer_iterations( + RewriterConfig_NumIterationsType_ONE); + auto custom_optimizer = rewriter_config.add_custom_optimizers(); + custom_optimizer->set_name(kOptimizerName); + AttrValue num_workers_attr; + num_workers_attr.set_i(num_workers); + (*custom_optimizer->mutable_parameter_map())["num_workers"] = + num_workers_attr; + return rewriter_config; + } }; REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc new file mode 100644 index 00000000000..a118fd81763 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc @@ -0,0 +1,224 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { +namespace data { +namespace { + +// See documentation in ../../ops/dataset_ops.cc for a high-level +// description of the following op. + +class SamplingDatasetOp : public UnaryDatasetOpKernel { + public: + explicit SamplingDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + protected: + // Create a new SamplingDatasetOp::Dataset, and return it as the output. 
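// Illustrative aside, not part of the original change: a hedged sketch of how
// the custom "tf_data_rebatcher" optimizer named above could read back the
// num_workers value that CreateConfig stores in its parameter_map. The helper
// name is hypothetical; parameter_map() and the AttrValue accessor are the
// usual RewriterConfig_CustomGraphOptimizer API.
Status ReadNumWorkers(const RewriterConfig_CustomGraphOptimizer* config,
                      int64* num_workers) {
  const auto& params = config->parameter_map();
  const auto it = params.find("num_workers");
  if (it == params.end()) {
    return errors::InvalidArgument("num_workers parameter is missing.");
  }
  *num_workers = it->second.i();
  return Status::OK();
}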
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + float rate; + int64 seed; + int64 seed2; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "rate", &rate)); + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "seed", &seed)); + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "seed2", &seed2)); + + if (seed == 0 && seed2 == 0) { + seed = random::New64(); + seed2 = random::New64(); + } + *output = new Dataset(ctx, rate, seed, seed2, input); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, float rate, int64 seed, int64 seed2, + const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), + rate_(rate), + seed_(seed), + seed2_(seed2), + input_(input) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::Sampling")}, seed_, seed2_)); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { return "SamplingDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* rate = nullptr; + Node* seed = nullptr; + Node* seed2 = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(rate_, &rate)); + TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); + TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node, rate, seed, seed2}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params, int64 seed, int64 seed2) + : DatasetIterator(params), + seed_(seed), + seed2_(seed2), + parent_generator_(seed, seed2), + generator_(&parent_generator_) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + bool rand_val_hit; + do { + { + tf_shared_lock l(mu_); + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + } + if (*end_of_sequence) { + mutex_lock l(mu_); + input_impl_.reset(); + return Status::OK(); + } + + // generate a number from random uniform [0, 1) + float rand_val = Random(); + rand_val_hit = rand_val < dataset()->rate_; + if (!rand_val_hit) { + // Clear the output tensor list since it doesn't match. + out_tensors->clear(); + } + } while (!rand_val_hit); + *end_of_sequence = false; + return Status::OK(); + } + + protected: + void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // Reset the generators based on the current iterator seeds. + parent_generator_ = random::PhiloxRandom(seed_, seed2_); + generator_ = random::SimplePhilox(&parent_generator_); + + parent_generator_.Skip(num_random_samples_); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // Save state needed to restore the random number generators. 
+ TF_RETURN_IF_ERROR(writer->WriteScalar( + this->full_name("num_random_samples"), num_random_samples_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("seed"), seed_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(this->full_name("seed2"), seed2_)); + + if (input_impl_) { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + // Restore the random number generators. + TF_RETURN_IF_ERROR(reader->ReadScalar( + this->full_name("num_random_samples"), &num_random_samples_)); + TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(this->full_name("seed2"), &seed2_)); + ResetRngs(); + + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } + return Status::OK(); + } + + mutex mu_; + int64 seed_ GUARDED_BY(mu_); + int64 seed2_ GUARDED_BY(mu_); + + private: + std::unique_ptr input_impl_ GUARDED_BY(mu_); + + float Random() { + mutex_lock l(mu_); + num_random_samples_++; + auto out = generator_.RandFloat(); + return out; + } + + // random util + random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); + random::SimplePhilox generator_ GUARDED_BY(mu_); + int64 num_random_samples_ GUARDED_BY(mu_) = 0; + }; + + const float rate_; + const int64 seed_, seed2_; + const DatasetBase* const input_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("SamplingDataset").Device(DEVICE_CPU), + SamplingDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc index 55e22c1cac6..0fcfadbe59b 100644 --- a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
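// Illustrative aside, not part of the original change: the two ideas behind
// SamplingDatasetOp above, condensed into hypothetical helpers. Each upstream
// element is kept when a single uniform [0, 1) draw falls below `rate`, and
// the iterator stays checkpointable by counting draws so a restore can replay
// the Philox generator to the same position.
bool KeepElement(random::SimplePhilox* gen, float rate, int64* num_samples) {
  ++*num_samples;                  // persisted as "num_random_samples"
  return gen->RandFloat() < rate;  // true with probability `rate`
}

void RestoreParentGenerator(random::PhiloxRandom* parent, int64 seed,
                            int64 seed2, int64 num_samples) {
  *parent = random::PhiloxRandom(seed, seed2);
  parent->Skip(num_samples);  // fast-forward past the draws already taken
}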
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { @@ -33,7 +34,10 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { public: explicit ScanDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + FunctionMetadata::Params params; + params.is_multi_device_function = true; + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "f", params, &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tstate", &state_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -50,10 +54,11 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { initial_state_inputs.end()); std::unique_ptr captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - &captured_func)); + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func)); - *output = new Dataset(ctx, input, func_, std::move(initial_state), + *output = new Dataset(ctx, input, std::move(initial_state), std::move(captured_func), state_types_, output_types_, output_shapes_, preserve_cardinality_); } @@ -62,7 +67,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, std::vector initial_state, + std::vector initial_state, std::unique_ptr captured_func, const DataTypeVector& state_types, const DataTypeVector& output_types, @@ -70,7 +75,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { bool preserve_cardinality) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), initial_state_(std::move(initial_state)), captured_func_(std::move(captured_func)), state_types_(state_types), @@ -103,7 +107,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); std::vector initial_state_nodes; @@ -114,23 +117,11 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { initial_state_nodes.emplace_back(node); } std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - } + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); AttrValue f; - b->BuildAttrValue(func_, &f); + b->BuildAttrValue(captured_func_->func(), &f); AttrValue state_types; b->BuildAttrValue(state_types_, &state_types); AttrValue other_arguments_types_attr; @@ -283,7 +274,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; - const 
NameAttrList func_; const std::vector initial_state_; const std::unique_ptr captured_func_; const DataTypeVector state_types_; @@ -292,10 +282,10 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { const bool preserve_cardinality_; }; + std::shared_ptr func_metadata_ = nullptr; DataTypeVector state_types_; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList func_; bool preserve_cardinality_; }; diff --git a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc index 67bb1e160b9..dcd4e68e65e 100644 --- a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/data/stats_utils.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { @@ -32,21 +33,13 @@ class StatsAggregatorWithTagAndPrefix : public StatsAggregator { const string& prefix) : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {} - void AddToHistogram(const string& name, - gtl::ArraySlice values) override { - if (!tag_.empty()) { - wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values); - } else { - wrapped_->AddToHistogram(name, values); - } + void AddToHistogram(const string& name, gtl::ArraySlice values, + int64 steps) override { + wrapped_->AddToHistogram(TaggedName(name), values, steps); } - void AddScalar(const string& name, float value) override { - if (!tag_.empty()) { - wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value); - } else { - wrapped_->AddScalar(name, value); - } + void AddScalar(const string& name, float value, int64 steps) override { + wrapped_->AddScalar(TaggedName(name), value, steps); } void EncodeToProto(Summary* out_summary) override { @@ -56,15 +49,27 @@ class StatsAggregatorWithTagAndPrefix : public StatsAggregator { void IncrementCounter(const string& name, const string& label, int64 val) override { if (!prefix_.empty()) { - wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label, - val); + wrapped_->IncrementCounter( + strings::StrCat(prefix_, "/", TaggedName(name)), label, val); } else { - wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label, - val); + wrapped_->IncrementCounter( + strings::StrCat("/tensorflow/", TaggedName(name)), label, val); } } + Status SetSummaryWriter(SummaryWriterInterface* summary_writer) override { + return wrapped_->SetSummaryWriter(summary_writer); + } + private: + string TaggedName(const string& name) const { + if (!tag_.empty()) { + string tagged_name = strings::StrCat(tag_, stats_utils::kDelimiter, name); + return tagged_name; + } + return name; + } + std::shared_ptr wrapped_; string tag_; string prefix_; diff --git a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc index 9d63690622d..a3bccedd014 100644 --- a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc @@ -135,6 +135,12 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ExperimentalSleepDataset").Device(DEVICE_CPU), SleepDatasetOp); +REGISTER_KERNEL_BUILDER(Name("ExperimentalSleepDataset") + 
.Device(DEVICE_GPU) + .HostMemory("sleep_microseconds") + .HostMemory("input_dataset") + .HostMemory("handle"), + SleepDatasetOp); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc index c5851eaf86b..dec136dd35e 100644 --- a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc @@ -267,7 +267,7 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { input_impl_.reset(); } // Restore buffer. - int64 buffer_size; + int64 buffer_size = 0; TF_RETURN_IF_ERROR( reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size)); buffer_.resize(buffer_size); diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc new file mode 100644 index 00000000000..1ff5878bb65 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -0,0 +1,488 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/compression.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/base64.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/protobuf/data/experimental/snapshot.pb.h" +#include "tensorflow/core/util/batch_util.h" + +namespace tensorflow { +namespace data { +namespace { + +enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; + +const uint64 kReaderBufferSize = 8 * 1024 * 1024; // 8 MB + +const char* kCompressionType = io::compression::kGzip; + +const uint64 kOneDayInMicroseconds = 24L * 60L * 60L * 1e6L; + +const uint64 kNumElementsPerShard = 10000; + +const char kSnapshotFilename[] = "snapshot.metadata"; + +string GetCurrentSnapshotDataFilename(uint64 next_index, + const string& run_dir) { + uint64_t shard_id = next_index / kNumElementsPerShard; + return absl::StrCat(run_dir, "/", strings::Printf("%08lu", shard_id), + ".snapshot"); +} + +Status WriteMetadataFile(const string& fingerprint_dir, + const experimental::SnapshotMetadataRecord& metadata) { + string metadata_filename = + absl::StrCat(fingerprint_dir, "/", kSnapshotFilename); + + 
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(fingerprint_dir)); + + std::unique_ptr file; + TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(metadata_filename, &file)); + + auto writer = absl::make_unique(file.get()); + TF_RETURN_IF_ERROR(writer->WriteRecord(metadata.SerializeAsString())); + TF_RETURN_IF_ERROR(writer->Close()); + + return Status::OK(); +} + +Status ReadMetadataFile(const string& fingerprint_dir, + experimental::SnapshotMetadataRecord* metadata) { + string metadata_filename = + absl::StrCat(fingerprint_dir, "/", kSnapshotFilename); + TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename)); + + std::unique_ptr file; + TF_CHECK_OK(Env::Default()->NewRandomAccessFile(metadata_filename, &file)); + + string record_bytes; + auto reader = absl::make_unique(file.get()); + TF_CHECK_OK(reader->ReadRecord(&record_bytes)); + + metadata->ParseFromString(record_bytes); + return Status::OK(); +} + +SnapshotMode DetermineOpState( + const Status& file_status, + const experimental::SnapshotMetadataRecord& metadata) { + if (errors::IsNotFound(file_status)) { + return WRITER; + } + + if (metadata.finalized()) { + // File found, snapshot has been finalized. + return READER; + } + + if (metadata.creation_timestamp() >= + Env::Default()->NowMicros() - kOneDayInMicroseconds) { + // TODO(frankchn): Make this timestamp configurable. + // Someone else is already writing and time has not expired. + return PASSTHROUGH; + } else { + // Time has expired, we write regardless. + return WRITER; + } +} + +class SnapshotDatasetOp : public UnaryDatasetOpKernel { + public: + explicit SnapshotDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + string path; + + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path)); + + GraphDef graph_def; + OP_REQUIRES_OK( + ctx, AsGraphDef(ctx, input, SerializationContext({}), &graph_def)); + + // TODO(frankchn): Find a better way than SerializeToStringDeterministic() + // This is not deterministic across different builds of binaries right now. 
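// Illustrative aside, not part of the original change: a short recap of the
// snapshot bookkeeping defined above. DetermineOpState picks WRITER when no
// metadata file exists, READER once the metadata is finalized, PASSTHROUGH
// while another run's metadata is less than a day old, and WRITER again after
// that grace period expires. Within a run, elements are sharded into files of
// kNumElementsPerShard records, e.g.:
//   next_index = 123456, kNumElementsPerShard = 10000
//   shard_id   = 123456 / 10000 = 12
//   filename   = <run_dir>/00000012.snapshot
// so a new .snapshot file is started every 10000 elements.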
+ string graph_def_serialized; + SerializeToStringDeterministic(graph_def, &graph_def_serialized); + + string graph_fingerprint = strings::StrCat( + strings::Hex(Fingerprint64(graph_def_serialized), strings::kZeroPad16)); + + *output = new Dataset(ctx, input, path, graph_fingerprint); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, const string& path, + const string& graph_fingerprint) + : DatasetBase(DatasetContext(ctx)), + input_(input), + dir_(path), + graph_fingerprint_(graph_fingerprint) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique( + Iterator::Params{this, strings::StrCat(prefix, "::Snapshot")}); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { return "SnapshotDatasetOp::Dataset"; } + + int64 Cardinality() const override { return input_->Cardinality(); } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* path = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(dir_, &path)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, path}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + fingerprint_dir_ = + absl::StrCat(dataset()->dir_, "/", dataset()->graph_fingerprint_); + + experimental::SnapshotMetadataRecord metadata; + Status s = ReadMetadataFile(fingerprint_dir_, &metadata); + state_ = DetermineOpState(s, metadata); + + switch (state_) { + case WRITER: + iterator_ = absl::make_unique( + SnapshotWriterIterator::Params{ + dataset(), strings::StrCat(prefix(), "Impl")}, + fingerprint_dir_); + break; + case READER: + iterator_ = absl::make_unique( + SnapshotReaderIterator::Params{ + dataset(), strings::StrCat(prefix(), "Impl")}, + fingerprint_dir_, metadata); + break; + case PASSTHROUGH: + iterator_ = absl::make_unique( + SnapshotPassthroughIterator::Params{ + dataset(), strings::StrCat(prefix(), "Impl")}); + break; + } + + return iterator_->Initialize(ctx); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + return iterator_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + // TODO(frankchn): Make save iterators work + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + // TODO(frankchn): Make iterator restores work + return Status::OK(); + } + + private: + class SnapshotReaderIterator : public DatasetIterator { + public: + explicit SnapshotReaderIterator( + const Params& params, const string& fingerprint_dir, + const experimental::SnapshotMetadataRecord& metadata) + : DatasetIterator(params), + fingerprint_dir_(fingerprint_dir), + metadata_(metadata) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + + run_id_ = metadata_.run_id(); + run_dir_ = absl::StrCat(fingerprint_dir_, "/", 
run_id_); + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + + string snapshot_data_filename = + GetCurrentSnapshotDataFilename(next_index_, run_dir_); + + if (current_read_filename_ != snapshot_data_filename) { + current_reader_.reset(); + current_read_file_.reset(); + + // The current implementation here assumes that tensors are stored + // in files which are named sequentially. If a file doesn't exist + // when we try reading that item, we assume that we have reached the + // end of the snapshot. + Status s = Env::Default()->FileExists(snapshot_data_filename); + if (!s.ok()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_CHECK_OK(Env::Default()->NewRandomAccessFile( + snapshot_data_filename, ¤t_read_file_)); + auto reader_options = + io::RecordReaderOptions::CreateRecordReaderOptions( + kCompressionType); + reader_options.buffer_size = kReaderBufferSize; + + current_reader_ = absl::make_unique( + current_read_file_.get(), reader_options); + current_read_filename_ = snapshot_data_filename; + } + + string record_bytes; + Status s = current_reader_->ReadRecord(&record_bytes); + + if (errors::IsOutOfRange(s)) { + *end_of_sequence = true; + return Status::OK(); + } else if (!s.ok()) { + return s; + } + + *end_of_sequence = false; + experimental::SnapshotRecord record; + record.ParseFromString(record_bytes); + + for (int i = 0; i < record.tensor_size(); ++i) { + Tensor t; + if (!t.FromProto(record.tensor(i))) { + return errors::DataLoss("Unable to parse Tensor from proto."); + } + out_tensors->push_back(t); + } + + next_index_++; + return Status::OK(); + } + + private: + const string fingerprint_dir_; + const experimental::SnapshotMetadataRecord metadata_; + string run_id_ GUARDED_BY(mu_); + string run_dir_ GUARDED_BY(mu_); + + std::unique_ptr input_impl_ GUARDED_BY(mu_); + + string current_read_filename_ GUARDED_BY(mu_); + std::unique_ptr current_read_file_ GUARDED_BY(mu_); + std::unique_ptr current_reader_ + GUARDED_BY(mu_); + + int64 next_index_ GUARDED_BY(mu_) = 0; + + mutex mu_; + }; + + class SnapshotWriterIterator : public DatasetIterator { + public: + explicit SnapshotWriterIterator(const Params& params, + const string& fingerprint_dir) + : DatasetIterator(params), + fingerprint_dir_(fingerprint_dir) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + + run_id_ = strings::StrCat( + strings::Hex(random::New64(), strings::kZeroPad4)); + run_dir_ = absl::StrCat(fingerprint_dir_, "/", run_id_); + + TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir_)); + + experimental::SnapshotMetadataRecord metadata; + metadata.set_creation_timestamp(Env::Default()->NowMicros()); + metadata.set_graph_fingerprint(dataset()->graph_fingerprint_); + metadata.set_run_id(run_id_); + metadata.set_finalized(false); + + TF_RETURN_IF_ERROR(WriteMetadataFile(fingerprint_dir_, metadata)); + + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + + if (*end_of_sequence) { + experimental::SnapshotMetadataRecord metadata; + TF_RETURN_IF_ERROR(ReadMetadataFile(fingerprint_dir_, &metadata)); + + if (metadata.run_id() == run_id_) { + if (current_writer_) TF_RETURN_IF_ERROR(current_writer_->Close()); + if 
(current_write_file_) + TF_RETURN_IF_ERROR(current_write_file_->Close()); + current_writer_.reset(); + current_write_file_.reset(); + + current_write_filename_ = ""; + + metadata.set_finalized(true); + TF_RETURN_IF_ERROR(WriteMetadataFile(fingerprint_dir_, metadata)); + } else { + // TODO(frankchn): We lost the race, remove all snapshots. + } + + return Status::OK(); + } + + string snapshot_data_filename = + GetCurrentSnapshotDataFilename(next_index_, run_dir_); + + if (current_write_filename_ != snapshot_data_filename) { + if (current_writer_) TF_RETURN_IF_ERROR(current_writer_->Close()); + if (current_write_file_) + TF_RETURN_IF_ERROR(current_write_file_->Close()); + + current_writer_.reset(); + current_write_file_.reset(); + + auto writer_options = + io::RecordWriterOptions::CreateRecordWriterOptions( + kCompressionType); + + TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile( + snapshot_data_filename, ¤t_write_file_)); + current_writer_ = absl::make_unique( + current_write_file_.get(), writer_options); + current_write_filename_ = snapshot_data_filename; + } + + experimental::SnapshotRecord record; + + for (auto out_tensor : *out_tensors) { + TensorProto* t = record.add_tensor(); + out_tensor.AsProtoTensorContent(t); + } + + TF_RETURN_IF_ERROR( + current_writer_->WriteRecord(record.SerializeAsString())); + + next_index_++; + return Status::OK(); + } + + private: + std::unique_ptr input_impl_; + + const string fingerprint_dir_; + string run_id_ GUARDED_BY(mu_); + string run_dir_ GUARDED_BY(mu_); + + string current_write_filename_ GUARDED_BY(mu_); + std::unique_ptr current_write_file_ GUARDED_BY(mu_); + std::unique_ptr current_writer_ GUARDED_BY(mu_); + + uint64 next_index_ GUARDED_BY(mu_) = 0; + + mutex mu_; + }; + + class SnapshotPassthroughIterator : public DatasetIterator { + public: + explicit SnapshotPassthroughIterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } + + private: + std::unique_ptr input_impl_; + }; + + string fingerprint_dir_; + SnapshotMode state_; + + std::unique_ptr iterator_; + }; + + const DatasetBase* const input_; + const string dir_; + const string graph_fingerprint_; + }; + + const int graph_def_version_; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("SnapshotDataset").Device(DEVICE_CPU), + SnapshotDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc index 84f6fba36d1..bd941550488 100644 --- a/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sql_dataset_op.cc @@ -199,7 +199,7 @@ class SqlDatasetOp : public DatasetOpKernel { } mutex mu_; - // TODO(shivaniagrawal): explore ways to seek into a SQLite databases. + // TODO(b/129062371): explore ways to seek into a SQLite databases. 
int64 next_calls_ GUARDED_BY(mu_) = 0; std::unique_ptr query_connection_ GUARDED_BY(mu_); bool query_connection_initialized_ GUARDED_BY(mu_) = false; diff --git a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc index 1d1b788b6c1..0d6ec07364d 100644 --- a/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc @@ -19,11 +19,14 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/kernels/summary_interface.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/histogram/histogram.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/events_writer.h" namespace tensorflow { namespace data { @@ -44,8 +47,8 @@ class StatsAggregatorImpl : public StatsAggregator { public: StatsAggregatorImpl() {} - void AddToHistogram(const string& name, - gtl::ArraySlice values) override { + void AddToHistogram(const string& name, gtl::ArraySlice values, + const int64 steps) override { mutex_lock l(mu_); histogram::Histogram& histogram = histograms_[name]; for (double value : values) { @@ -53,7 +56,7 @@ class StatsAggregatorImpl : public StatsAggregator { } } - void AddScalar(const string& name, float value) override { + void AddScalar(const string& name, float value, const int64 steps) override { mutex_lock l(mu_); scalars_[name] = value; } @@ -76,6 +79,13 @@ class StatsAggregatorImpl : public StatsAggregator { } } + // StatsAggregator implementation for V2 is based on push-based summary, no-op + // in V1. + Status SetSummaryWriter( + SummaryWriterInterface* summary_writer_interface) override { + return Status::OK(); + } + void IncrementCounter(const string& name, const string& label, int64 val) override { mutex_lock l(*get_counters_map_lock()); @@ -112,8 +122,128 @@ class StatsAggregatorHandleOp new StatsAggregatorResource(absl::make_unique()); return Status::OK(); } +}; - Status VerifyResource(StatsAggregatorResource* resource) override { +class StatsAggregatorImplV2 : public StatsAggregator { + public: + StatsAggregatorImplV2() {} + + ~StatsAggregatorImplV2() override { + if (summary_writer_interface_) { + summary_writer_interface_->Unref(); + } + } + + void AddToHistogram(const string& name, gtl::ArraySlice values, + const int64 steps) override { + mutex_lock l(mu_); + histogram::Histogram& histogram = histograms_[name]; + for (double value : values) { + histogram.Add(value); + } + AddToEvents(name, steps, histogram); + } + + void AddScalar(const string& name, float value, const int64 steps) override { + mutex_lock l(mu_); + AddToEvents(name, steps, value); + } + + // TODO(b/116314787): expose this is public API to manually flush summary. 
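// Illustrative aside, not part of the original change: the push-based path
// StatsAggregatorImplV2 uses, mirroring the AddToEvents helpers that follow
// in this file. Each AddScalar/AddToHistogram call becomes a summary Event
// stamped with the element count as its step and is written immediately
// through the attached SummaryWriterInterface, instead of being held for a
// later pull-based EncodeToProto as in V1. The helper name below is
// hypothetical.
std::unique_ptr<Event> MakeScalarEvent(const string& name, float value,
                                       int64 step) {
  auto e = absl::make_unique<Event>();
  e->set_step(step);
  e->set_wall_time(Env::Default()->NowMicros() / 1.0e6);
  Summary::Value* v = e->mutable_summary()->add_value();
  v->set_tag(name);
  v->set_simple_value(value);
  return e;
}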
+ Status Flush() { + mutex_lock l(mu_); + if (summary_writer_interface_) + TF_RETURN_IF_ERROR(summary_writer_interface_->Flush()); + return Status::OK(); + } + + void IncrementCounter(const string& name, const string& label, + int64 val) override { + mutex_lock l(*get_counters_map_lock()); + auto counters_map = get_counters_map(); + if (counters_map->find(name) == counters_map->end()) { + counters_map->emplace( + name, monitoring::Counter<1>::New( + /*streamz name*/ "/tensorflow/" + name, + /*streamz description*/ + name + " generated or consumed by the component.", + /*streamz label name*/ "component_descriptor")); + } + counters_map->at(name)->GetCell(label)->IncrementBy(val); + } + + // StatsAggregator implementation for V1 is based on pull-based summary, no-op + // in V2. + void EncodeToProto(Summary* out_summary) override {} + + Status SetSummaryWriter( + SummaryWriterInterface* summary_writer_interface) override { + mutex_lock l(mu_); + if (summary_writer_interface_) { + summary_writer_interface_->Unref(); + // If we create stats_aggregator twice in a program, we would end up with + // already existing resource. In this case emitting an error if a + // `summary_writer_resource` is present is not the intended behavior, we + // could either Unref the existing sumary_writer_resource or not set the + // new resource at all. + } + summary_writer_interface_ = summary_writer_interface; + summary_writer_interface_->Ref(); + return Status::OK(); + } + + private: + void AddToEvents(const string& name, const int64 steps, + const float scalar_value) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (summary_writer_interface_ == nullptr) { + return; + } + std::unique_ptr e{new Event}; + e->set_step(steps); + tensorflow::Env* env = tensorflow::Env::Default(); + e->set_wall_time(env->NowMicros() / 1.0e6); + // maybe expose GetWallTime in SummaryWriterInterface + Summary::Value* v = e->mutable_summary()->add_value(); + v->set_tag(name); + v->set_simple_value(scalar_value); + TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e))); + } + + void AddToEvents(const string& name, const int64 steps, + const histogram::Histogram& histogram) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (summary_writer_interface_ == nullptr) { + return; + } + std::unique_ptr e{new Event}; + e->set_step(steps); + tensorflow::Env* env = tensorflow::Env::Default(); + e->set_wall_time(env->NowMicros() / 1.0e6); + Summary::Value* v = e->mutable_summary()->add_value(); + v->set_tag(name); + histogram.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); + TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e))); + } + + mutex mu_; + SummaryWriterInterface* summary_writer_interface_ GUARDED_BY(mu_) = nullptr; + // not owned, we might be associating the default summary_writer from the + // context + std::unordered_map histograms_ GUARDED_BY(mu_); + TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImplV2); +}; + +class StatsAggregatorHandleOpV2 + : public ResourceOpKernel { + public: + explicit StatsAggregatorHandleOpV2(OpKernelConstruction* ctx) + : ResourceOpKernel(ctx) {} + + private: + Status CreateResource(StatsAggregatorResource** ret) override + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = + new StatsAggregatorResource(absl::make_unique()); return Status::OK(); } }; @@ -141,12 +271,45 @@ class StatsAggregatorSummaryOp : public OpKernel { } }; +class StatsAggregatorSetSummaryWriterOp : public OpKernel { + public: + explicit StatsAggregatorSetSummaryWriterOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void 
Compute(OpKernelContext* ctx) override { + const Tensor& resource_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), + errors::InvalidArgument("resource_handle must be a scalar")); + + StatsAggregatorResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + core::ScopedUnref unref_iterator(resource); + + const Tensor& summary_resource_handle_t = ctx->input(1); + OP_REQUIRES(ctx, + TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()), + errors::InvalidArgument("resource_handle must be a scalar")); + SummaryWriterInterface* sumamry_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &sumamry_resource)); + core::ScopedUnref unref_sumamry_resource(sumamry_resource); + TF_CHECK_OK( + resource->stats_aggregator()->SetSummaryWriter(sumamry_resource)); + } +}; + REGISTER_KERNEL_BUILDER( Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU), StatsAggregatorHandleOp); +REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU), + StatsAggregatorHandleOpV2); REGISTER_KERNEL_BUILDER( Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU), StatsAggregatorSummaryOp); +REGISTER_KERNEL_BUILDER( + Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU), + StatsAggregatorSetSummaryWriterOp); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc index be5fa4c789b..08a144049db 100644 --- a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc @@ -108,8 +108,9 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { uint64 end = ctx->env()->NowMicros(); auto stats_aggregator = ctx->stats_aggregator(); if (stats_aggregator && !*end_of_sequence) { - ctx->stats_aggregator()->AddToHistogram( - dataset()->tag_, {static_cast(end - start)}); + int64 steps = num_elements(); + stats_aggregator->AddToHistogram( + dataset()->tag_, {static_cast(end - start)}, steps); } return s; } @@ -220,8 +221,9 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { for (const Tensor& t : *out_tensors) { total_bytes += t.TotalBytes(); } - ctx->stats_aggregator()->AddToHistogram( - dataset()->tag_, {static_cast(total_bytes)}); + int64 steps = num_elements(); + stats_aggregator->AddToHistogram( + dataset()->tag_, {static_cast(total_bytes)}, steps); } return s; } diff --git a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc index 3a6f70e504e..79c5ec0aebd 100644 --- a/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc @@ -32,74 +32,32 @@ namespace { class TakeWhileDatasetOp : public UnaryDatasetOpKernel { public: - using LoopIteratorPredicate = - std::function&, bool*)>; - explicit TakeWhileDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create( + ctx, "predicate", /*params=*/{}, &func_metadata_)); + OP_REQUIRES(ctx, func_metadata_->short_circuit_info().indices.size() <= 1, + errors::InvalidArgument( + "predicate function has more than one return value.")); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { std::unique_ptr 
captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - &captured_func)); - - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - OP_REQUIRES( - ctx, indices.size() <= 1, - errors::InvalidArgument("`predicate` has more than one return value.")); - - LoopIteratorPredicate loop_pred; - if (indices.empty()) { - loop_pred = [](IteratorContext* ctx, - InstantiatedCapturedFunction* inst_captured_func, - const std::vector& args, bool* end_of_sequence) { - std::vector result; - TF_RETURN_IF_ERROR( - inst_captured_func->RunWithBorrowedArgs(ctx, args, &result)); - - if (result.size() != 1 || result[0].dtype() != DT_BOOL || - result[0].NumElements() != 1) { - return errors::InvalidArgument( - "`predicate` must returns a scalar bool tensor."); - } - *end_of_sequence = !result[0].scalar()(); - return Status::OK(); - }; - } else { - loop_pred = [indices](IteratorContext* ctx, - InstantiatedCapturedFunction* inst_captured_func, - const std::vector& args, - bool* end_of_sequence) { - const Tensor& predicate = args[indices[0]]; - if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { - return errors::InvalidArgument( - "`predicate` must returns a scalar bool tensor."); - } - *end_of_sequence = !predicate.scalar()(); - return Status::OK(); - }; - } - *output = new Dataset(ctx, input, func_, std::move(captured_func), - std::move(loop_pred)); + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func)); + *output = new Dataset(ctx, input, std::move(captured_func)); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr captured_func, - LoopIteratorPredicate loop_pred) + std::unique_ptr captured_func) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), - captured_func_(std::move(captured_func)), - loop_pred_(std::move(loop_pred)) { + captured_func_(std::move(captured_func)) { input_->Ref(); } @@ -108,8 +66,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return MakeUnique( - Iterator::Params{this, strings::StrCat(prefix, "::TakeWhile")}, - loop_pred_); + Iterator::Params{this, strings::StrCat(prefix, "::TakeWhile")}); } const DataTypeVector& output_dtypes() const override { @@ -130,28 +87,15 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - } + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); AttrValue f_attr; - b->BuildAttrValue(func_, &f_attr); + 
b->BuildAttrValue(captured_func_->func(), &f_attr); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -168,8 +112,8 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params, LoopIteratorPredicate loop_pred) - : DatasetIterator(params), loop_pred_(loop_pred) {} + explicit Iterator(const Params& params) + : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -195,8 +139,20 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { input_impl_.reset(); return Status::OK(); } - return loop_pred_(ctx, instantiated_captured_func_.get(), *out_tensors, - end_of_sequence); + std::vector result; + TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs( + ctx, *out_tensors, &result)); + + if (result.size() != 1 || result[0].dtype() != DT_BOOL || + result[0].NumElements() != 1) { + return errors::InvalidArgument( + "`predicate` must returns a scalar bool tensor."); + } + *end_of_sequence = !result[0].scalar()(); + if (*end_of_sequence) { + out_tensors->clear(); + } + return Status::OK(); } protected: @@ -230,16 +186,13 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel { mutex mu_; std::unique_ptr input_impl_ GUARDED_BY(mu_); std::unique_ptr instantiated_captured_func_; - const LoopIteratorPredicate loop_pred_; }; const DatasetBase* const input_; - const NameAttrList func_; const std::unique_ptr captured_func_; - const LoopIteratorPredicate loop_pred_; }; - NameAttrList func_; + std::shared_ptr func_metadata_ = nullptr; }; REGISTER_KERNEL_BUILDER(Name("ExperimentalTakeWhileDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index d017d5ed405..9d1649bf021 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/util/work_sharder.h" @@ -307,19 +308,8 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { bool* end_of_sequence) override { IteratorContext::Params params(ctx); auto max_parallelism = dataset()->max_intra_op_parallelism_; - params.runner = std::bind( - [max_parallelism]( - const std::function)>& runner, - std::function fn) { - std::function scoped_fn = std::bind( - [max_parallelism](const std::function& fn) { - ScopedPerThreadMaxParallelism scope(max_parallelism); - fn(); - }, - std::move(fn)); - (runner)(std::move(scoped_fn)); - }, - std::move(*ctx->runner()), std::placeholders::_1); + params.runner = + RunnerWithMaxParallelism(*ctx->runner(), max_parallelism); return input_impl_->GetNext(IteratorContext{std::move(params)}, out_tensors, end_of_sequence); } diff --git a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc index 6cf6198432b..b450970a717 100644 --- a/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc +++ b/tensorflow/core/kernels/data/experimental/to_tf_record_op.cc @@ -69,7 +69,7 @@ class ToTFRecordOp : public AsyncOpKernel { std::unique_ptr iterator; IteratorContext::Params params(ctx); std::unique_ptr function_handle_cache = - absl::make_unique(params.lib); + absl::make_unique(params.flr); params.function_handle_cache = function_handle_cache.get(); IteratorContext iter_ctx(std::move(params)); diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc index 3b9b319ea94..bbcc84db31b 100644 --- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc @@ -130,7 +130,13 @@ class FilterByLastComponentDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - matched = out_tensors->back().scalar()(); + const Tensor& last_component = out_tensors->back(); + if (last_component.NumElements() != 1 || + last_component.dtype() != DT_BOOL) { + return errors::InvalidArgument( + "Last component must be a bool scalar."); + } + matched = last_component.scalar()(); out_tensors->pop_back(); if (!matched) { // Clear the output tensor list since it didn't match. diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op_test.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op_test.cc new file mode 100644 index 00000000000..04627dfae93 --- /dev/null +++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op_test.cc @@ -0,0 +1,589 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "filter_by_last_component_dataset"; +constexpr char kOpName[] = "FilterByLastComponentDataset"; + +class FilterByLastComponentDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `TensorSliceDataset` variant tensor from the input vector of + // tensors. + Status CreateTensorSliceDatasetTensor( + std::vector *const tensor_vector, Tensor *dataset_tensor) { + DatasetBase *tensor_slice_dataset; + TF_RETURN_IF_ERROR(CreateTensorSliceDataset( + "tensor_slice_node", tensor_vector, &tensor_slice_dataset)); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor)); + return Status::OK(); + } + + // Creates a new `FilterByLastComponentDataset` op kernel. + Status CreateFilterByLastComponentDatasetKernel( + const DataTypeVector &output_types, + const std::vector &output_shapes, + std::unique_ptr *op_kernel) { + NodeDef node_def = test::function::NDef( + kNodeName, kOpName, {"input_dataset"}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); + return Status::OK(); + } + + // Creates a new `FilterByLastComponentDataset` op kernel context. + Status CreateFilterByLastComponentDatasetContext( + OpKernel *const op_kernel, + gtl::InlinedVector *const inputs, + std::unique_ptr *context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct TestCase { + std::vector input_tensors; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +// Test case 1: simple case. +TestCase TestCase1() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5}), + DatasetOpsTestBase::CreateTensor(TensorShape{3, 1}, + {true, false, true})}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{2}, {0, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2}, {4, 5})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({2})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {0, 1, 5}}; +} + +// Test case 2: the output of input dataset is empty. +TestCase TestCase2() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{0}, {})}, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {0}}; +} + +// Test case 3: the output of input dataset has only one component. +TestCase TestCase3() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 1}, + {true, false, true})}, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_BOOL}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {0, 1, 5}}; +} + +// Test case 4: the last component has more than one element. 
+TestCase InvalidLastComponentShape() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5}), + DatasetOpsTestBase::CreateTensor( + TensorShape{3, 2}, {true, false, true, true, false, true})}, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({2})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {}}; +} + +// Test case 5: the data type of last component is not DT_BOOL. +TestCase InvalidLastComponentDType() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5}), + DatasetOpsTestBase::CreateTensor(TensorShape{3}, {1, 1, 0})}, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({2})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {}}; +} + +class ParameterizedFilterByLastComponentDatasetOpTest + : public FilterByLastComponentDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedFilterByLastComponentDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext( + filter_by_last_component_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(filter_by_last_component_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +TEST_F(FilterByLastComponentDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = TestCase1(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, 
TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + EXPECT_EQ(filter_by_last_component_dataset->node_name(), kNodeName); +} + +TEST_F(FilterByLastComponentDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = TestCase1(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + EXPECT_EQ(filter_by_last_component_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedFilterByLastComponentDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + 
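Each test above repeats the same scaffolding before its single assertion: wrap the input tensors in a `TensorSliceDataset` variant tensor, build the op kernel and its context, create the dataset, and, where needed, an iterator. The sketch below collapses that shared flow into one test body; it assumes the same fixture and anonymous namespace as this file, and the test name `ScaffoldingSketch` is illustrative only.

```cpp
// Sketch of the setup shared by the FilterByLastComponentDataset tests above.
TEST_F(FilterByLastComponentDatasetOpTest, ScaffoldingSketch) {
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  const TestCase& test_case = TestCase1();

  // 1. Wrap the raw input tensors in a TensorSliceDataset variant tensor.
  Tensor slices(DT_VARIANT, TensorShape({}));
  std::vector<Tensor> input_tensors = test_case.input_tensors;
  TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&input_tensors, &slices));

  // 2. Build the op kernel and an OpKernelContext whose input is that tensor.
  std::unique_ptr<OpKernel> kernel;
  TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel(
      test_case.expected_output_dtypes, test_case.expected_output_shapes,
      &kernel));
  gtl::InlinedVector<TensorValue, 4> inputs({&slices});
  std::unique_ptr<OpKernelContext> context;
  TF_ASSERT_OK(
      CreateFilterByLastComponentDatasetContext(kernel.get(), &inputs, &context));

  // 3. Materialize the dataset and, when the test needs one, an iterator.
  DatasetBase* dataset;
  TF_ASSERT_OK(CreateDataset(kernel.get(), context.get(), &dataset));
  core::ScopedUnref unref(dataset);
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(CreateIteratorContext(context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator));
}
```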
TF_EXPECT_OK( + VerifyTypesMatch(filter_by_last_component_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedFilterByLastComponentDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + TF_EXPECT_OK( + VerifyShapesCompatible(filter_by_last_component_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedFilterByLastComponentDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + EXPECT_EQ(filter_by_last_component_dataset->Cardinality(), + test_case.expected_cardinality); +} + +TEST_P(ParameterizedFilterByLastComponentDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector 
inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK( + filter_by_last_component_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedFilterByLastComponentDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext( + filter_by_last_component_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(filter_by_last_component_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedFilterByLastComponentDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + 
std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext( + filter_by_last_component_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(filter_by_last_component_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(FilterByLastComponentDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = TestCase1(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext( + filter_by_last_component_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(filter_by_last_component_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::FilterByLastComponent"); +} + +TEST_P(ParameterizedFilterByLastComponentDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + 
&filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext( + filter_by_last_component_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(filter_by_last_component_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *filter_by_last_component_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + cur_iteration++; + } + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +INSTANTIATE_TEST_SUITE_P(FilterByLastComponentDatasetOpTest, + ParameterizedFilterByLastComponentDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3()}))); + +TEST_F(FilterByLastComponentDatasetOpTest, InvalidLastComponent) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + std::vector test_cases = {InvalidLastComponentShape(), + InvalidLastComponentDType()}; + for (const TestCase &test_case : test_cases) { + std::unique_ptr filter_by_last_component_dataset_kernel; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetKernel( + test_case.expected_output_dtypes, test_case.expected_output_shapes, + &filter_by_last_component_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = + test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor( + &inputs_for_tensor_slice_dataset, &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_by_last_component_dataset_context; + TF_ASSERT_OK(CreateFilterByLastComponentDatasetContext( + filter_by_last_component_dataset_kernel.get(), &inputs, + &filter_by_last_component_dataset_context)); + DatasetBase *filter_by_last_component_dataset; + TF_ASSERT_OK(CreateDataset(filter_by_last_component_dataset_kernel.get(), + filter_by_last_component_dataset_context.get(), + &filter_by_last_component_dataset)); + core::ScopedUnref scoped_unref(filter_by_last_component_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext( + filter_by_last_component_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(filter_by_last_component_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + + std::vector next; + bool end_of_sequence = false; + EXPECT_EQ( + iterator->GetNext(iterator_ctx.get(), &next, 
&end_of_sequence).code(), + tensorflow::error::INVALID_ARGUMENT); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 8615e1b45b4..688d120ba8e 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" @@ -33,75 +34,33 @@ namespace { class FilterDatasetOp : public UnaryDatasetOpKernel { public: - using FilterIteratorPredicate = - std::function, bool*)>; - explicit FilterDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create( + ctx, "predicate", /*params=*/{}, &func_metadata_)); + OP_REQUIRES(ctx, func_metadata_->short_circuit_info().indices.size() <= 1, + errors::InvalidArgument( + "predicate function has more than one return value.")); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { std::unique_ptr captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - &captured_func)); + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func)); - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - OP_REQUIRES(ctx, indices.size() <= 1, - errors::InvalidArgument( - "predicate function has more than one return value.")); - - FilterIteratorPredicate filter_pred; - if (indices.empty()) { - filter_pred = [](IteratorContext* ctx, - InstantiatedCapturedFunction* inst_captured_func, - const std::vector& args, bool* out_matched) { - std::vector result; - TF_RETURN_IF_ERROR( - inst_captured_func->RunWithBorrowedArgs(ctx, args, &result)); - - if (result.size() != 1 || result[0].dtype() != DT_BOOL || - result[0].NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = result[0].scalar()(); - return Status::OK(); - }; - } else { - filter_pred = [indices](IteratorContext* ctx, - InstantiatedCapturedFunction* inst_captured_func, - const std::vector& args, - bool* out_matched) { - const Tensor& predicate = args[indices[0]]; - if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = predicate.scalar()(); - return Status::OK(); - }; - } - - *output = new Dataset(ctx, input, func_, std::move(captured_func), - std::move(filter_pred)); + *output = new Dataset(ctx, input, std::move(captured_func)); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr captured_func, - FilterIteratorPredicate filter_pred) + std::unique_ptr captured_func) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), - captured_func_(std::move(captured_func)), - filter_pred_(std::move(filter_pred)) { + 
captured_func_(std::move(captured_func)) { input_->Ref(); } @@ -110,8 +69,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::Filter")}, - filter_pred_); + Iterator::Params{this, strings::StrCat(prefix, "::Filter")}); } const DataTypeVector& output_dtypes() const override { @@ -127,28 +85,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - } + DataTypeVector other_arguments_types; + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); AttrValue f; - b->BuildAttrValue(func_, &f); + b->BuildAttrValue(captured_func_->func(), &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -162,13 +106,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params, - FilterIteratorPredicate filter_pred) + explicit Iterator(const Params& params) : DatasetIterator(params), filtered_elements_(0), - dropped_elements_(0), - filter_pred_(std::move(filter_pred)) { - } + dropped_elements_(0) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -202,8 +143,20 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - TF_RETURN_IF_ERROR(filter_pred_( - ctx, instantiated_captured_func_.get(), *out_tensors, &matched)); + std::vector result; + TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs( + ctx, *out_tensors, &result)); + + if (result.size() != 1 || result[0].dtype() != DT_BOOL || + result[0].NumElements() != 1) { + // Clear the output tensor list since there were errors with Filter + // prediction result. + out_tensors->clear(); + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + matched = result[0].scalar()(); + if (!matched) { // Clear the output tensor list since it didn't match. out_tensors->clear(); @@ -213,10 +166,8 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { stats_aggregator->AddScalar( stats_utils::DroppedElementsScalarName( dataset()->node_name()), - static_cast((dropped_elements_))); - // TODO(shivaniagrawal): multiple pipelines would collect - // aggregated number of dropped elements for all the pipelines, - // exploit tagged_context here. 
+ static_cast(dropped_elements_), num_elements()); + stats_aggregator->IncrementCounter(dataset()->node_name(), stats_utils::kDroppedElements, static_cast(1)); @@ -230,10 +181,8 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { filtered_elements_++; stats_aggregator->AddScalar( stats_utils::FilterdElementsScalarName(dataset()->node_name()), - static_cast((filtered_elements_))); - // TODO(shivaniagrawal): multiple pipelines would collect aggregated - // number of filtered elements for all the pipelines, exploit - // tagged_context here. + static_cast(filtered_elements_), num_elements()); + stats_aggregator->IncrementCounter(dataset()->node_name(), stats_utils::kFilteredElements, static_cast(1)); @@ -281,18 +230,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr input_impl_ GUARDED_BY(mu_); int64 filtered_elements_ GUARDED_BY(mu_); int64 dropped_elements_ GUARDED_BY(mu_); - const FilterIteratorPredicate filter_pred_; std::unique_ptr instantiated_captured_func_; }; const DatasetBase* const input_; - const NameAttrList func_; const std::unique_ptr captured_func_; - const FilterIteratorPredicate filter_pred_; }; private: - NameAttrList func_; + std::shared_ptr func_metadata_ = nullptr; }; REGISTER_KERNEL_BUILDER(Name("FilterDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/filter_dataset_op_test.cc b/tensorflow/core/kernels/data/filter_dataset_op_test.cc new file mode 100644 index 00000000000..b145600b833 --- /dev/null +++ b/tensorflow/core/kernels/data/filter_dataset_op_test.cc @@ -0,0 +1,593 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "filter_dataset"; +constexpr char kOpName[] = "FilterDataset"; + +class FilterDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `TensorSliceDataset` variant tensor from the input vector of + // tensors. 
+ Status CreateTensorSliceDatasetTensor( + std::vector *const tensor_vector, Tensor *dataset_tensor) { + DatasetBase *tensor_slice_dataset; + TF_RETURN_IF_ERROR(CreateTensorSliceDataset( + "tensor_slice_node", tensor_vector, &tensor_slice_dataset)); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor)); + return Status::OK(); + } + + // Creates a new `FilterDataset` op kernel + Status CreateFilterDatasetKernel( + const FunctionDefHelper::AttrValueWrapper &func, + const DataTypeVector &output_types, + const std::vector &output_shapes, + std::unique_ptr *op_kernel) { + NodeDef node_def = + test::function::NDef(kNodeName, kOpName, {"input_dataset"}, + {{"predicate", func}, + {"Targuments", {}}, + {"output_types", output_types}, + {"output_shapes", output_shapes}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); + return Status::OK(); + } + + // Creates a new `ParallelInterleaveDataset` op kernel context. + Status CreateFilterDatasetContext( + OpKernel *const op_kernel, + gtl::InlinedVector *const inputs, + std::unique_ptr *context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct TestCase { + std::vector input_tensors; + FunctionDefHelper::AttrValueWrapper func; + std::vector func_lib; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +template +std::vector ConvertToTensorVec(std::vector values) { + std::vector tensors; + tensors.reserve(values.size()); + for (auto &value : values) { + tensors.emplace_back( + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {value})); + } + return tensors; +} + +// Test case 1: norm case. +TestCase TestCase1() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{9, 1}, {0, 0, 0, 3, 4, 5, 6, 7, 8})}, + /*func*/ FunctionDefHelper::FunctionRef("IsZero", {{"T", DT_INT64}}), + /*func_lib*/ {test::function::IsZero()}, + /*expected_outputs*/ + ConvertToTensorVec({0, 0, 0}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {0, 2, 6}}; +} + +// Test case 2: the input dataset has no outputs. +TestCase TestCase2() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{0}, {})}, + /*func*/ FunctionDefHelper::FunctionRef("IsZero", {{"T", DT_INT64}}), + /*func_lib*/ {test::function::IsZero()}, + /*expected_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {0, 2, 6}}; +} + +// Test case 3: the filter function returns two outputs. +TestCase InvalidFuncTestCase1() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3}, {0, 0, 0, 3, 4, 5, 6, 7, 8})}, + /*func*/ + FunctionDefHelper::FunctionRef( + "GetUnique", {{"T", DT_INT64}, {"out_idx", DT_INT32}}), + /*func_lib*/ {test::function::Unique()}, + /*expected_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({3, 1})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {}}; +} + +// Test case 4: the filter function returns a 1-D bool tensor. 
+TestCase InvalidFuncTestCase2() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 0, 0, 3, 4, 5, 6, 7, 8})}, + /*func*/ FunctionDefHelper::FunctionRef("IsZero", {{"T", DT_INT64}}), + /*func_lib*/ {test::function::IsZero()}, + /*expected_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({3, 1})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {}}; +} + +// Test case 5: the filter function returns a scalar int64 tensor. +TestCase InvalidFuncTestCase3() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{9}, {0, 0, 0, 3, 4, 5, 6, 7, 8})}, + /*func*/ FunctionDefHelper::FunctionRef("NonZero", {{"T", DT_INT64}}), + /*func_lib*/ {test::function::NonZero()}, + /*expected_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ kUnknownCardinality, + /*breakpoints*/ {}}; +} + +class ParameterizedFilterDatasetOpTest + : public FilterDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedFilterDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(filter_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + filter_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +TEST_F(FilterDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + 
&tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + EXPECT_EQ(filter_dataset->node_name(), kNodeName); +} + +TEST_F(FilterDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + EXPECT_EQ(filter_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedFilterDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(filter_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedFilterDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + 
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(filter_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedFilterDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + EXPECT_EQ(filter_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_P(ParameterizedFilterDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(filter_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedFilterDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + 
test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(filter_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + filter_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedFilterDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(filter_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + filter_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(ParameterizedFilterDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + 
TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(filter_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + filter_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::Filter"); +} + +TEST_P(ParameterizedFilterDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs, + &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(filter_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + filter_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *filter_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + cur_iteration++; + } + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +INSTANTIATE_TEST_SUITE_P( + FilterDatasetOpTest, ParameterizedFilterDatasetOpTest, + ::testing::ValuesIn(std::vector({TestCase1(), TestCase2()}))); + +TEST_F(ParameterizedFilterDatasetOpTest, InvalidFuncs) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime( + {test::function::IsZero(), test::function::Unique(), + test::function::NonZero()}, + cpu_num)); + + std::vector test_cases( + {InvalidFuncTestCase1(), InvalidFuncTestCase2(), InvalidFuncTestCase3()}); + for (const auto &test_case : test_cases) { + std::unique_ptr filter_dataset_kernel; + TF_ASSERT_OK(CreateFilterDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &filter_dataset_kernel)); + Tensor 
tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = + test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor( + &inputs_for_tensor_slice_dataset, &tensor_slice_dataset_tensor)); + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr filter_dataset_context; + TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), + &inputs, &filter_dataset_context)); + DatasetBase *filter_dataset; + TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(), + filter_dataset_context.get(), &filter_dataset)); + core::ScopedUnref scoped_unref(filter_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(filter_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(filter_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + bool end_of_sequence = false; + std::vector out_tensors; + EXPECT_EQ( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence) + .code(), + tensorflow::error::INVALID_ARGUMENT); + EXPECT_TRUE(out_tensors.empty()); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index 3f01ac55699..d3e571a8a12 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -32,7 +32,8 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { explicit FlatMapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "f", /*params=*/{}, + &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -40,23 +41,22 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { std::unique_ptr captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - &captured_func)); - *output = new Dataset(ctx, input, func_, std::move(captured_func), - output_types_, output_shapes_); + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func)); + *output = new Dataset(ctx, input, std::move(captured_func), output_types_, + output_shapes_); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, std::unique_ptr captured_func, const DataTypeVector& output_types, const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), output_shapes_(output_shapes) { @@ -85,28 +85,14 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t 
: captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - } + DataTypeVector other_arguments_types; + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); AttrValue f; - b->BuildAttrValue(func_, &f); + b->BuildAttrValue(captured_func_->func(), &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -262,7 +248,6 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; - const NameAttrList func_; const std::unique_ptr captured_func_; const DataTypeVector output_types_; const std::vector output_shapes_; @@ -271,7 +256,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { const int graph_def_version_; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList func_; + std::shared_ptr func_metadata_ = nullptr; }; REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc b/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc new file mode 100644 index 00000000000..4cb4f0471e8 --- /dev/null +++ b/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc @@ -0,0 +1,581 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "flat_map_dataset"; +constexpr char kOpName[] = "FlatMapDataset"; + +class FlatMapDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `TensorSliceDataset` variant tensor from the input vector of + // tensors. + Status CreateTensorSliceDatasetTensor( + std::vector *const tensor_vector, Tensor *dataset_tensor) { + DatasetBase *tensor_slice_dataset; + TF_RETURN_IF_ERROR(CreateTensorSliceDataset( + "tensor_slice_node", tensor_vector, &tensor_slice_dataset)); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor)); + return Status::OK(); + } + + // Creates a new `FlatMapDataset` op kernel + Status CreateFlatMapDatasetKernel( + const FunctionDefHelper::AttrValueWrapper &func, + const DataTypeVector &output_types, + const std::vector &output_shapes, + std::unique_ptr *op_kernel) { + NodeDef node_def = + test::function::NDef(kNodeName, kOpName, {"input_dataset"}, + {{"f", func}, + {"Targuments", {}}, + {"output_types", output_types}, + {"output_shapes", output_shapes}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); + return Status::OK(); + } + + // Creates a new `FlatMapDataset` op kernel context. 
+ Status CreateFlatMapDatasetContext( + OpKernel *const op_kernel, + gtl::InlinedVector *const inputs, + std::unique_ptr *context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct TestCase { + std::vector input_tensors; + FunctionDefHelper::AttrValueWrapper func; + std::vector func_lib; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +TestCase MakeTensorSliceDatasetFuncTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + FunctionDefHelper::FunctionRef( + /*name*/ "MakeTensorSliceDataset", + /*attrs*/ {{"Toutput_types", DataTypeVector({DT_INT64})}, + {"output_shapes", std::vector( + {PartialTensorShape({1})})}}), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {0}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {4}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {5}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {6}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {7}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// Test case 2: test the case if the function does not return a single scalar +// of dtype DT_VARIANT. 
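For reference, MakeTensorSliceDatasetFuncTestCase above corresponds roughly to the following user-level pipeline, in which each slice of the input is re-sliced by the mapped function. This is an illustrative eager-mode (TF 2.x) sketch only and is not part of this patch.

```python
import tensorflow as tf

# A 3x3x1 int64 input: from_tensor_slices yields three [3, 1] elements, and
# flat_map re-slices each of them into [1]-shaped elements, giving 0..8 in order.
data = tf.constant([[[0], [1], [2]], [[3], [4], [5]], [[6], [7], [8]]],
                   dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices)

for element in dataset:
  print(element.numpy())  # Prints [0], [1], ..., [8].
```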
+TestCase InvalidFuncTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + FunctionDefHelper::FunctionRef(/*name*/ "NonZero", + /*attrs*/ {{"T", DT_INT64}}), + /*func_lib*/ {test::function::NonZero()}, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} + +class ParameterizedFlatMapDatasetOpTest + : public FlatMapDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedFlatMapDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(flat_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(flat_map_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + auto expected_outputs_it = test_case.expected_outputs.begin(); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + for (const auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); +} + +TEST_F(FlatMapDatasetOpTest, InvalidFunc) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = InvalidFuncTestCase(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + 
flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(flat_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(flat_map_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + EXPECT_EQ( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(FlatMapDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = MakeTensorSliceDatasetFuncTestCase(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + EXPECT_EQ(flat_map_dataset->node_name(), kNodeName); +} + +TEST_F(FlatMapDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = MakeTensorSliceDatasetFuncTestCase(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + EXPECT_EQ(flat_map_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedFlatMapDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector 
inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(flat_map_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedFlatMapDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(flat_map_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedFlatMapDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + EXPECT_EQ(flat_map_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_F(FlatMapDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = MakeTensorSliceDatasetFuncTestCase(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + 
test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(flat_map_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedFlatMapDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(flat_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(flat_map_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedFlatMapDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + 
TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(flat_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(flat_map_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(FlatMapDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = MakeTensorSliceDatasetFuncTestCase(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(flat_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(flat_map_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::FlatMap"); +} + +TEST_P(ParameterizedFlatMapDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr flat_map_dataset_kernel; + TF_ASSERT_OK(CreateFlatMapDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &flat_map_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor}); + std::unique_ptr flat_map_dataset_context; + TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(), + &inputs, &flat_map_dataset_context)); + DatasetBase *flat_map_dataset; + TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(), + flat_map_dataset_context.get(), + &flat_map_dataset)); + core::ScopedUnref scoped_unref(flat_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(flat_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(flat_map_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + 
auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *flat_map_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + for (auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_outputs.size()) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); + } + } +} + +INSTANTIATE_TEST_SUITE_P(FlatMapDatasetOpTest, + ParameterizedFlatMapDatasetOpTest, + ::testing::ValuesIn(std::vector( + {MakeTensorSliceDatasetFuncTestCase()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 3469743af63..a3c39a48945 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { @@ -154,9 +155,13 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "init_func", /*params=*/{}, + &init_func_metadata_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "next_func", /*params=*/{}, + &next_func_metadata_)); + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "finalize_func", /*params=*/{}, + &finalize_func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -164,15 +169,17 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { std::unique_ptr init_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - init_func_, ctx, "init_func_other_args", &init_func)); + OP_REQUIRES_OK(ctx, + CapturedFunction::Create(ctx, init_func_metadata_, + "init_func_other_args", &init_func)); std::unique_ptr next_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - next_func_, ctx, "next_func_other_args", &next_func)); + OP_REQUIRES_OK(ctx, + CapturedFunction::Create(ctx, next_func_metadata_, + "next_func_other_args", &next_func)); std::unique_ptr finalize_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx, + OP_REQUIRES_OK(ctx, 
CapturedFunction::Create(ctx, finalize_func_metadata_, "finalize_func_other_args", &finalize_func)); diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h index d23ed97ec3a..951440eeaa7 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.h +++ b/tensorflow/core/kernels/data/generator_dataset_op.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/kernels/data/captured_function.h" namespace tensorflow { namespace data { @@ -32,9 +33,9 @@ class GeneratorDatasetOp : public DatasetOpKernel { DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList init_func_; - NameAttrList next_func_; - NameAttrList finalize_func_; + std::shared_ptr init_func_metadata_ = nullptr; + std::shared_ptr next_func_metadata_ = nullptr; + std::shared_ptr finalize_func_metadata_ = nullptr; }; } // namespace data diff --git a/tensorflow/core/kernels/data/graph_rewrite_dataset.cc b/tensorflow/core/kernels/data/graph_rewrite_dataset.cc deleted file mode 100644 index cd8026607e7..00000000000 --- a/tensorflow/core/kernels/data/graph_rewrite_dataset.cc +++ /dev/null @@ -1,250 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h" - -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" -#include "tensorflow/core/protobuf/meta_graph.pb.h" -#include "tensorflow/core/protobuf/rewriter_config.pb.h" - -namespace tensorflow { -namespace data { - -GraphRewriteDataset::~GraphRewriteDataset() { - input_->Unref(); - if (optimized_input_) { - optimized_input_->Unref(); - } -} - -Status GraphRewriteDataset::Optimize(OpKernelContext* ctx) { - GraphDefBuilder b; - DatasetGraphDefBuilder db(&b); - Node* input_node = nullptr; - SerializationContext::Params params; - std::vector> input_list; - params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - params.input_list = &input_list; - params.optimization_only = true; - SerializationContext serialization_ctx(params); - TF_RETURN_IF_ERROR( - db.AddInputDataset(&serialization_ctx, input_, &input_node)); - string output_node = input_node->name(); - - GraphDef graph_def; - TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); - VLOG(3) << "Before optimization: " << graph_def.DebugString(); - - TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node)); - VLOG(3) << "After optimization: " << graph_def.DebugString(); - - // Instantiate the optimized input pipeline by running the optimized graph - // using the optimized function library. - TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_)); - - // Create a FunctionHandleCache. 
- function_handle_cache_ = absl::make_unique(lib_); - - // Some functions may have been modified without having their names - // changed (for example, nested dataset graphs from FlatMap or - // Interleave). - TF_RETURN_IF_ERROR( - AddToFunctionLibrary(flib_def_.get(), graph_def.library())); - - Graph graph(OpRegistry::Global()); - TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); - std::vector outputs; - GraphRunner graph_runner(ctx->function_library()->device()); - - TF_RETURN_IF_ERROR( - graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs)); - TF_RETURN_IF_ERROR( - GetDatasetFromVariantTensor(outputs[0], &optimized_input_)); - optimized_input_->Ref(); - return Status::OK(); -} - -Status GraphRewriteDataset::AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const { - SerializationContext::Params params; - // The optimized input needs access to the newly optimized functions when - // it is serialized. Here, we use the optimized function library for - // serialization, which is the union of the function library from the - // OpKernelContext at dataset creation time and newly optimized functions. - // This includes all functions that optimized_input_ may use. - params.flib_def = flib_def_.get(); - params.input_list = ctx->input_list(); - params.optimization_only = ctx->optimization_only(); - SerializationContext optimized_ctx(params); - - // We only serialize the optimized dataset to avoid re-running - // optimizations when the input pipeline is restored from a checkpoint. - TF_RETURN_IF_ERROR( - b->AddInputDataset(&optimized_ctx, optimized_input_, output)); - return Status::OK(); -} - -namespace { -void AddFakeSinks(FunctionDef* function_def) { - int counter = 0; - for (const auto& output : function_def->signature().output_arg()) { - NodeDef* node = function_def->add_node_def(); - tensorflow::grappler::function_utils::SetUniqueFunctionNodeName( - strings::StrCat("FakeSink", counter++), function_def, node); - node->set_op("Identity"); - node->add_input(function_def->ret().at(output.name())); - (*node->mutable_attr())["T"].set_type(output.type()); - - (*function_def->mutable_ret())[output.name()] = - strings::StrCat(node->name(), ":output:0"); - } -} - -void RemoveFakeSinks(FunctionDef* function_def) { - // Map from identity node names to their input tensor strings - std::map identity_map; - for (const auto& node : function_def->node_def()) { - if (node.op() == "Identity" && node.input_size() == 1) { - identity_map[node.name()] = node.input(0); - } - } - for (const auto& output_arg : function_def->signature().output_arg()) { - const string& tensor = function_def->ret().at(output_arg.name()); - const string& output_node = tensor.substr(0, tensor.find(':')); - if (identity_map.find(output_node) != identity_map.end()) { - (*function_def->mutable_ret())[output_arg.name()] = - identity_map.at(output_node); - } - } -} -} // anonymous namespace - -Status GraphRewriteDataset::ApplyOptimizations(OpKernelContext* ctx, - GraphDef* graph_def, - string* output_node) { - // Add an identity node as the fetch node, otherwise we might get - // 'placeholder is both fed and fetched' errors in some cases when using - // input list with placeholder dataset nodes. 
- NodeDef* node = graph_def->mutable_node()->Add(); - tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def, - node); - node->set_op("Identity"); - node->add_input(*output_node); - (*node->mutable_attr())["T"].set_type(DT_VARIANT); - *output_node = node->name(); - - // Add fake sink node to graph and functions to allow rewriting the actual - // sink nodes. - // TODO(b/118820916): When MetaOptimizer adds provisions for function - // retvals to be optimizable, we will no longer need this. - for (auto& function_def : *graph_def->mutable_library()->mutable_function()) { - AddFakeSinks(&function_def); - } - - // Create metagraph. - MetaGraphDef meta_graph_def; - (*meta_graph_def.mutable_graph_def()) = *graph_def; - - // Grappler determines fetch ops from collection 'train_op'. - CollectionDef collection_def; - auto node_list = collection_def.mutable_node_list(); - node_list->add_value(*output_node); - (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def; - - // Create Grappler item. - tensorflow::grappler::ItemConfig item_config; - item_config.apply_optimizations = true; - std::unique_ptr grappler_item = - tensorflow::grappler::GrapplerItemFromMetaGraphDef( - "graph", meta_graph_def, item_config); - grappler_item->optimization_options().optimize_function_library = - ShouldOptimizeFunctions(); - std::unordered_map device_map; - tensorflow::grappler::VirtualCluster cluster(device_map); - - // Run data optimizer using grappler's meta optimizer. - tensorflow::ConfigProto config; - *config.mutable_graph_options()->mutable_rewrite_options() = - CreateGrapplerRewriteConfig(); - TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer( - *grappler_item, config, ctx->device(), &cluster, graph_def)); - - // Remove fake sinks after optimizations are done. - // TODO(b/118820916): When MetaOptimizer adds provisions for function - // retvals to be optimizable, we will no longer need this. 
- for (auto& function_def : *graph_def->mutable_library()->mutable_function()) { - RemoveFakeSinks(&function_def); - } - - return Status::OK(); -} - -class GraphRewriteDataset::Iterator - : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params) {} - - Status Initialize(IteratorContext* ctx) override { - IteratorContext::Params params(ctx); - params.lib = dataset()->lib_; - params.function_handle_cache = dataset()->function_handle_cache_.get(); - return dataset()->optimized_input_->MakeIterator( - IteratorContext(std::move(params)), prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) override { - IteratorContext::Params params(ctx); - params.lib = dataset()->lib_; - params.function_handle_cache = dataset()->function_handle_cache_.get(); - return input_impl_->GetNext(IteratorContext(std::move(params)), out_tensors, - end_of_sequence); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return Status::OK(); - } - - private: - std::unique_ptr input_impl_; -}; - -std::unique_ptr GraphRewriteDataset::MakeIteratorInternal( - const string& prefix) const { - // We do not add a token for this dataset to the prefix. The - // prefix is used to identify checkpoint elements and since this - // dataset is excluded from the checkpoint, adding a token - // here would result in invalid checkpoint identifiers. - return absl::make_unique(Iterator::Params{this, prefix}); -} - -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/graph_rewrite_dataset.h b/tensorflow/core/kernels/data/graph_rewrite_dataset.h deleted file mode 100644 index 856fcd3ea72..00000000000 --- a/tensorflow/core/kernels/data/graph_rewrite_dataset.h +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_ -#define TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_ - -#include "tensorflow/core/common_runtime/graph_runner.h" -#include "tensorflow/core/common_runtime/process_function_library_runtime.h" -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/function_handle_cache.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/grappler/clusters/virtual_cluster.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/grappler_item_builder.h" -#include "tensorflow/core/grappler/optimizers/data/function_utils.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" -#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" - -namespace tensorflow { -namespace data { - -class GraphRewriteDataset : public DatasetBase { - public: - GraphRewriteDataset(OpKernelContext* ctx, const DatasetBase* input, - const DataTypeVector& output_types, - const std::vector& output_shapes) - : DatasetBase(DatasetContext(ctx)), - optimized_input_(nullptr), - input_(input), - output_types_(output_types), - output_shapes_(output_shapes) { - input_->Ref(); - } - - ~GraphRewriteDataset() override; - - // Runs Grappler to transform the input dataset into optimized_input_ - // dataset. - Status Optimize(OpKernelContext* ctx); - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override; - - const DataTypeVector& output_dtypes() const override { return output_types_; } - - const std::vector& output_shapes() const override { - return output_shapes_; - } - - int64 Cardinality() const override { return input_->Cardinality(); } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override; - - private: - class Iterator; - - // Create a Grappler RewriteConfig proto that defines the list of - // optimizations to be run by the Grappler Meta Optimizer. - virtual RewriterConfig CreateGrapplerRewriteConfig() = 0; - - // Option specifying whether we want to optimize the function library as well. 
- virtual bool ShouldOptimizeFunctions() { return true; } - - Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def, - string* output_node); - - DatasetBase* optimized_input_; - FunctionLibraryRuntime* lib_ = nullptr; - std::unique_ptr pflr_ = nullptr; - std::unique_ptr flib_def_ = nullptr; - std::unique_ptr function_handle_cache_ = nullptr; - const DatasetBase* input_; - const DataTypeVector output_types_; - const std::vector output_shapes_; -}; - -} // namespace data -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_ diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 69310bcff23..0dcdf196033 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -32,7 +32,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { explicit InterleaveDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "f", /*params=*/{}, + &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -58,25 +59,23 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { errors::InvalidArgument("block_length must be greater than zero.")); std::unique_ptr captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - &captured_func)); + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func)); - *output = - new Dataset(ctx, input, func_, std::move(captured_func), cycle_length, - block_length, output_types_, output_shapes_); + *output = new Dataset(ctx, input, std::move(captured_func), cycle_length, + block_length, output_types_, output_shapes_); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, std::unique_ptr captured_func, int64 cycle_length, int64 block_length, const DataTypeVector& output_types, const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), captured_func_(std::move(captured_func)), cycle_length_(cycle_length), block_length_(block_length), @@ -108,31 +107,18 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node)); Node* block_length_node; TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node)); - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - 
} + DataTypeVector other_arguments_types; + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); AttrValue f; - b->BuildAttrValue(func_, &f); + b->BuildAttrValue(captured_func_->func(), &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -317,7 +303,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; - const NameAttrList func_; const std::unique_ptr captured_func_; const int64 cycle_length_; const int64 block_length_; @@ -328,7 +313,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { const int graph_def_version_; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList func_; + std::shared_ptr func_metadata_ = nullptr; }; REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/interleave_dataset_op_test.cc new file mode 100644 index 00000000000..9fb35fc4faa --- /dev/null +++ b/tensorflow/core/kernels/data/interleave_dataset_op_test.cc @@ -0,0 +1,800 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "interleave_dataset"; +constexpr char kOpName[] = "InterleaveDataset"; + +class InterleaveDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `TensorSliceDataset` variant tensor from the input vector of + // tensors. + Status CreateTensorSliceDatasetTensor( + std::vector *const tensor_vector, Tensor *dataset_tensor) { + DatasetBase *tensor_slice_dataset; + TF_RETURN_IF_ERROR(CreateTensorSliceDataset( + "tensor_slice_node", tensor_vector, &tensor_slice_dataset)); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor)); + return Status::OK(); + } + + // Creates a new `InterleaveDataset` op kernel + Status CreateInterleaveDatasetKernel( + const FunctionDefHelper::AttrValueWrapper &func, + const DataTypeVector &output_types, + const std::vector &output_shapes, + std::unique_ptr *op_kernel) { + NodeDef node_def = test::function::NDef( + kNodeName, kOpName, {"input_dataset", "cycle_length", "block_length"}, + {{"f", func}, + {"Targuments", {}}, + {"output_types", output_types}, + {"output_shapes", output_shapes}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); + return Status::OK(); + } + + // Creates a new `InterleaveDataset` op kernel context. 
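+  // It first checks the given inputs against the kernel's signature and then
+  // wraps them in an OpKernelContext that the tests below use to create the
+  // dataset.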
+ Status CreateInterleaveDatasetContext( + OpKernel *const op_kernel, + gtl::InlinedVector *const inputs, + std::unique_ptr *context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct TestCase { + std::vector input_tensors; + FunctionDefHelper::AttrValueWrapper func; + std::vector func_lib; + Tensor cycle_length; + Tensor block_length; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +template +std::vector ConvertToTensorVec(std::vector values) { + std::vector tensors; + tensors.reserve(values.size()); + for (auto &value : values) { + tensors.emplace_back( + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {value})); + } + return tensors; +} + +FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc( + const DataTypeVector &output_types, + const std::vector &output_shapes) { + return FunctionDefHelper::FunctionRef( + /*name*/ "MakeTensorSliceDataset", + /*attrs*/ {{"Toutput_types", output_types}, + {"output_shapes", output_shapes}}); +} + +// test case 1: cycle_length = 1, block_length = 1. +TestCase TestCase1() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*expected_outputs*/ + ConvertToTensorVec({0, 1, 2, 3, 4, 5, 6, 7, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 2: cycle_length = 2, block_length = 1. +TestCase TestCase2() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*expected_outputs*/ + ConvertToTensorVec({0, 3, 1, 4, 2, 5, 6, 7, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 3: cycle_length = 3, block_length = 1. 
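+// The input is sliced into three nested datasets ({0, 1, 2}, {3, 4, 5},
+// {6, 7, 8}); taking one element from each per round yields the expected
+// order 0, 3, 6, 1, 4, 7, 2, 5, 8.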
+TestCase TestCase3() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*expected_outputs*/ + ConvertToTensorVec({0, 3, 6, 1, 4, 7, 2, 5, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 4: cycle_length = 5, block_length = 1. +TestCase TestCase4() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*expected_outputs*/ + ConvertToTensorVec({0, 3, 6, 1, 4, 7, 2, 5, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 5: cycle_length = 2, block_length = 2. +TestCase TestCase5() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "d", "e", "c", "f", "g", "h", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 6: cycle_length = 2, block_length = 3. +TestCase TestCase6() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "c", "d", "e", "f", "g", "h", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 7: cycle_length = 2, block_length = 5. 
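+// Each nested dataset yields only three elements, so a block_length of 5
+// drains one cycle element before the next is started and the output order
+// matches the input order ("a" through "i").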
+TestCase TestCase7() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "c", "d", "e", "f", "g", "h", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 8: cycle_length = 0, block_length = 5. +TestCase InvalidCycleLengthTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*expected_outputs*/ ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {}}; +} + +// test case 9: cycle_length = 1, block_length = -1. +TestCase InvalidBlockLengthTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {-1}), + /*expected_outputs*/ ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {}}; +} + +class ParameterizedInterleaveDatasetOpTest + : public InterleaveDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedInterleaveDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + 
DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(interleave_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(interleave_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + auto expected_outputs_it = test_case.expected_outputs.begin(); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + for (const auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); +} + +TEST_F(InterleaveDatasetOpTest, InvalidCycleLength) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = InvalidCycleLengthTestCase(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + EXPECT_EQ(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), &interleave_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(InterleaveDatasetOpTest, InvalidBlockLength) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = InvalidBlockLengthTestCase(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + EXPECT_EQ(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), &interleave_dataset) + 
.code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(InterleaveDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + EXPECT_EQ(interleave_dataset->node_name(), kNodeName); +} + +TEST_F(InterleaveDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + EXPECT_EQ(interleave_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedInterleaveDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = 
test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(interleave_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedInterleaveDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(interleave_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedInterleaveDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + EXPECT_EQ(interleave_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_P(ParameterizedInterleaveDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + 
TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(interleave_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedInterleaveDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(interleave_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(interleave_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedInterleaveDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, 
&interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(interleave_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(interleave_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(InterleaveDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(interleave_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(interleave_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::Interleave"); +} + +TEST_P(ParameterizedInterleaveDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr interleave_dataset_kernel; + TF_ASSERT_OK(CreateInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor 
cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + gtl::InlinedVector inputs( + {&tensor_slice_dataset_tensor, &cycle_length, &block_length}); + std::unique_ptr interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context)); + DatasetBase *interleave_dataset; + TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(), + interleave_dataset_context.get(), + &interleave_dataset)); + core::ScopedUnref scoped_unref(interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(interleave_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(interleave_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *interleave_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + for (auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_outputs.size()) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); + } + } +} + +INSTANTIATE_TEST_SUITE_P(InterleaveDatasetOpTest, + ParameterizedInterleaveDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3(), + TestCase4(), TestCase5(), TestCase6(), + TestCase7()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 14fb6624ad7..8a37dacd6ca 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/iterator_ops.h" + #include #include "absl/memory/memory.h" @@ -27,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/data/unbounded_thread_pool.h" @@ -59,11 +61,11 @@ class IteratorResource : public ResourceBase { std::unique_ptr device_mgr, std::unique_ptr flib_def, std::unique_ptr pflr, - FunctionLibraryRuntime* lib) + FunctionLibraryRuntime* flr) : unbounded_thread_pool_(env, "tf_data_iterator_resource"), device_mgr_(std::move(device_mgr)), iterator_state_(std::make_shared( - std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */)), + std::move(flib_def), std::move(pflr), flr, nullptr /* iterator */)), output_dtypes_(output_dtypes), output_shapes_(output_shapes) {} @@ -76,7 +78,7 @@ class IteratorResource : public ResourceBase { } if (captured_state->iterator) { IteratorContext::Params params(ctx); - params.lib = captured_state->lib; + params.flr = captured_state->flr; params.function_handle_cache = captured_state->function_handle_cache.get(); params.resource_mgr = &captured_state->resource_mgr; @@ -103,17 +105,7 @@ class IteratorResource : public ResourceBase { captured_state = iterator_state_; } if (captured_state) { - SerializationContext::Params params; - // The iterator state may contain functions that are not present - // in ctx's function library. Namely, an iterator may be restored from - // a serialized iterator with a modified function library (for example, as - // a result of OptimizeDataset). These modified functions are needed - // to serialize the iterator again. - params.flib_def = captured_state->flib_def.get(); - params.input_list = ctx->input_list(); - params.optimization_only = ctx->optimization_only(); - SerializationContext ctx_with_functions(params); - return captured_state->iterator->Save(&ctx_with_functions, writer); + return captured_state->iterator->Save(ctx, writer); } else { return errors::FailedPrecondition( "Save() failed because the iterator has not been initialized. " @@ -144,10 +136,11 @@ class IteratorResource : public ResourceBase { // NOTE(mrry): We clone the existing FLR and use it in the GraphRunner // because some of the OpKernels in the graph might call functions that are // only defined in the loaded GraphDef. - FunctionLibraryRuntime* lib; + FunctionLibraryRuntime* flr; std::unique_ptr flib_def(nullptr); std::unique_ptr pflr(nullptr); - TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib)); + TF_RETURN_IF_ERROR( + ctx->function_library()->Clone(&flib_def, &pflr, &flr, true)); // Some function names may be duplicated (for example, if the serialized // graph has an optimized function that retains its original name). 
We @@ -157,14 +150,14 @@ class IteratorResource : public ResourceBase { TF_RETURN_IF_ERROR( AddToFunctionLibrary(flib_def.get(), graph_def.library())); std::unique_ptr new_state = absl::make_unique( - std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */); + std::move(flib_def), std::move(pflr), flr, nullptr /* iterator */); TF_RETURN_IF_ERROR( - graph_runner.Run(&graph, new_state->lib, {}, {output_node}, &outputs)); + graph_runner.Run(&graph, new_state->flr, {}, {output_node}, &outputs)); TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); IteratorContext::Params params(ctx); - params.lib = new_state->lib; + params.flr = new_state->flr; params.function_handle_cache = new_state->function_handle_cache.get(); params.resource_mgr = &new_state->resource_mgr; params.thread_factory = unbounded_thread_pool_.get_thread_factory(); @@ -178,10 +171,10 @@ class IteratorResource : public ResourceBase { { IteratorContext::Params params(ctx); - params.lib = new_state->lib; + params.flr = new_state->flr; params.function_handle_cache = new_state->function_handle_cache.get(); params.resource_mgr = &new_state->resource_mgr; - DeviceBase* device = new_state->lib->device(); + DeviceBase* device = new_state->flr->device(); params.allocator_getter = [device](AllocatorAttributes attrs) { return device->GetAllocator(attrs); }; @@ -195,49 +188,21 @@ class IteratorResource : public ResourceBase { return Status::OK(); } - Status AddLibrary(const FunctionLibraryDefinition& flib_def) { - mutex_lock l(mu_); - return iterator_state_->flib_def->AddLibrary(flib_def); - } - Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset) { std::shared_ptr new_state; { tf_shared_lock l(mu_); new_state = std::make_shared( iterator_state_->flib_def, iterator_state_->pflr, - iterator_state_->lib, nullptr /* function_handle_cache */, + iterator_state_->flr, nullptr /* function_handle_cache */, nullptr /* iterator */); } - - // Ensure that the iterator has access to all functions in the current - // subgraph, because some functions may have been defined after the resource - // was initially created. - Status s = new_state->flib_def->AddLibrary( - *ctx->function_library()->GetFunctionLibraryDefinition()); - - if (!s.ok()) { - // Adding functions to `flib_def_` may fail, if there are clashes between - // the function names in (e.g.) a restored graph and the currently - // executing graph. In that case, we create a new function runtime for - // this iterator, based on the current `OpKernelContext`, which will have - // the functions we need. - FunctionLibraryRuntime* lib; - std::unique_ptr flib_def(nullptr); - std::unique_ptr pflr(nullptr); - TF_RETURN_IF_ERROR( - ctx->function_library()->Clone(&flib_def, &pflr, &lib)); - new_state->flib_def = std::move(flib_def); - new_state->pflr = std::move(pflr); - new_state->lib = lib; - } - new_state->function_handle_cache = - absl::make_unique(new_state->lib); + absl::make_unique(new_state->flr); // Create new iterator. 
     std::unique_ptr<IteratorBase> iterator;
     IteratorContext::Params params(ctx);
-    params.lib = new_state->lib;
+    params.flr = new_state->flr;
     params.function_handle_cache = new_state->function_handle_cache.get();
     params.resource_mgr = &new_state->resource_mgr;
     params.thread_factory = unbounded_thread_pool_.get_thread_factory();
@@ -262,31 +227,97 @@ class IteratorResource : public ResourceBase {
     return output_shapes_;
   }
 
+  // This class is used to guarantee that an anonymous iterator is deleted
+  // (irrespective of whether DeleteIteratorOp is called explicitly or the
+  // execution encounters an error before the op runs).
+  //
+  // This is achieved by wrapping an instance of this class into a variant
+  // tensor which is passed as an input to the DeleteIteratorOp. If the
+  // execution encounters an error before the op runs, the tensor will be
+  // destroyed, essentially triggering the iterator deletion.
+  class Deleter {
+   public:
+    Deleter() : deleter_() {}
+
+    Deleter(ResourceHandle handle, ResourceMgr* resource_manager)
+        : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
+
+    Deleter(Deleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
+      VLOG(3) << "IteratorResource::Deleter move constructor called.";
+    }
+
+    Deleter(const Deleter& rhs) : deleter_(rhs.deleter_) {
+      VLOG(3) << "IteratorResource::Deleter copy constructor called.";
+    }
+
+    Deleter& operator=(const Deleter& rhs) = delete;
+
+    Deleter& operator=(Deleter&& rhs) = default;
+
+    virtual ~Deleter() {
+      VLOG(3) << "IteratorResource::Deleter destructor called.";
+    }
+
+    void Encode(VariantTensorData*) const {
+      // Not supported.
+    }
+
+    bool Decode(const VariantTensorData&) {
+      return false;  // Not supported.
+    }
+
+   private:
+    // Helper that performs reference counting for the parent class and deletes
+    // the iterator resource when the refcount goes to zero.
+    //
+    // NOTE: The object is borrowing a pointer to the resource manager.
+    // Consequently, the tensor containing this object should not escape the
+    // function in which it was created (so that it is guaranteed that the
+    // resource manager will outlive it).
+    struct Helper {
+      Helper(ResourceHandle handle, ResourceMgr* resource_manager)
+          : handle(handle), resource_manager(resource_manager) {}
+
+      Helper(const Helper& rhs) = delete;
+      Helper(Helper&& rhs) = delete;
+
+      ~Helper() {
+        VLOG(3) << "Deleting IteratorResource: " << handle.DebugString();
+        resource_manager->Delete(handle).IgnoreError();
+      }
+
+      ResourceHandle handle;
+      ResourceMgr* resource_manager;  // not owned
+    };
+
+    std::shared_ptr<Helper> deleter_;
+  };
+
  private:
   struct State {
     State(std::shared_ptr<FunctionLibraryDefinition> flib_def,
           std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
-          FunctionLibraryRuntime* lib, std::unique_ptr<IteratorBase> iterator)
+          FunctionLibraryRuntime* flr, std::unique_ptr<IteratorBase> iterator)
         : flib_def(flib_def),
+          flr(flr),
           pflr(pflr),
-          lib(lib),
-          function_handle_cache(absl::make_unique<FunctionHandleCache>(lib)),
+          function_handle_cache(absl::make_unique<FunctionHandleCache>(flr)),
           iterator(std::move(iterator)) {}
 
     State(std::shared_ptr<FunctionLibraryDefinition> flib_def,
           std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
-          FunctionLibraryRuntime* lib,
+          FunctionLibraryRuntime* flr,
           std::unique_ptr<FunctionHandleCache> function_handle_cache,
           std::unique_ptr<IteratorBase> iterator)
         : flib_def(flib_def),
+          flr(flr),
           pflr(pflr),
-          lib(lib),
           function_handle_cache(std::move(function_handle_cache)),
           iterator(std::move(iterator)) {}
 
     std::shared_ptr<FunctionLibraryDefinition> flib_def;
+    FunctionLibraryRuntime* flr = nullptr;  // not owned.
    std::shared_ptr<ProcessFunctionLibraryRuntime> pflr;
-    FunctionLibraryRuntime* lib = nullptr;  // not owned.
std::unique_ptr function_handle_cache; ResourceMgr resource_mgr; std::unique_ptr iterator; @@ -333,13 +364,14 @@ class IteratorStateVariant { Decode(*other.data_); } } + IteratorStateVariant& operator=(IteratorStateVariant&& other) = default; + IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete; + // Initializes this object with the current state of the iterator so // that it can be written on the next call to Encode(). Status InitializeFromIterator(OpKernelContext* ctx, IteratorResource* iterator_resource) { - SerializationContext::Params params; - params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - SerializationContext serialization_ctx(params); + SerializationContext serialization_ctx({}); data_ = absl::make_unique(); data_->set_type_name(TypeName()); VariantTensorDataWriter writer(data_.get()); @@ -418,7 +450,7 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) { { mutex_lock l(mu_); if (resource_ == nullptr) { - FunctionLibraryRuntime* lib; + FunctionLibraryRuntime* flr; std::unique_ptr device_mgr(nullptr); std::unique_ptr flib_def(nullptr); std::unique_ptr pflr(nullptr); @@ -427,10 +459,10 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) { // functions from the iterator. We may add this functionality if there // is sufficient demand, but it will require a significant refactoring. if (!name_.empty()) { - lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr); + flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr); } else { OP_REQUIRES_OK(context, context->function_library()->Clone( - &flib_def, &pflr, &lib)); + &flib_def, &pflr, &flr, true)); } ResourceMgr* mgr = context->resource_manager(); @@ -441,12 +473,12 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) { context, mgr->LookupOrCreate( cinfo_.container(), cinfo_.name(), &resource, - [context, lib, &device_mgr, &flib_def, &pflr, + [context, flr, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new IteratorResource( context->env(), output_dtypes_, output_shapes_, graph_def_version_, std::move(device_mgr), - std::move(flib_def), std::move(pflr), lib); + std::move(flib_def), std::move(pflr), flr); return Status::OK(); })); @@ -500,21 +532,22 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR( // running them. AnonymousIteratorHandleOp::AnonymousIteratorHandleOp( OpKernelConstruction* context) - : OpKernel(context), graph_def_version_(context->graph_def_version()) { + : OpKernel(context), + graph_def_version_(context->graph_def_version()), + op_version_(context->def().op() == "AnonymousIterator" ? 
1 : 2) {
   OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
   OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
 }
 
-void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
+void AnonymousIteratorHandleOp::Compute(OpKernelContext* ctx) {
   FunctionLibraryRuntime* lib;
   std::unique_ptr<DeviceMgr> device_mgr(nullptr);
   std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
-  OP_REQUIRES_OK(context,
-                 context->function_library()->Clone(&flib_def, &pflr, &lib));
-
-  ResourceMgr* mgr = context->resource_manager();
+  OP_REQUIRES_OK(ctx,
+                 ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
+  ResourceMgr* mgr = ctx->resource_manager();
   const string container_name = "AnonymousIterator";
   string unique_name;
   {
@@ -527,21 +560,30 @@ void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
       if (status.code() == error::NOT_FOUND) {
         break;
       }
-      OP_REQUIRES_OK(context, status);
+      OP_REQUIRES_OK(ctx, status);
       existing_resource->Unref();
     }
     IteratorResource* new_resource = new IteratorResource(
-        context->env(), output_dtypes_, output_shapes_, graph_def_version_,
+        ctx->env(), output_dtypes_, output_shapes_, graph_def_version_,
         std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
     // Create the resource with our chosen name under the resource lookup
     // mutex to avoid another kernel racily creating a resource with this
     // name.
-    OP_REQUIRES_OK(context, mgr->Create<IteratorResource>(
-                                container_name, unique_name, new_resource));
+    OP_REQUIRES_OK(ctx, mgr->Create<IteratorResource>(
+                            container_name, unique_name, new_resource));
+  }
+  Tensor* handle_t;
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
+  ResourceHandle handle = MakeResourceHandle(ctx, container_name, unique_name,
+                                             MakeTypeIndex<IteratorResource>());
+  handle_t->scalar<ResourceHandle>()() = handle;
+
+  if (op_version_ == 2) {
+    Tensor* deleter_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t));
+    deleter_t->scalar<Variant>()() =
+        IteratorResource::Deleter(handle, ctx->resource_manager());
   }
-  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
-                              context, 0, container_name, unique_name,
-                              MakeTypeIndex<IteratorResource>()));
 }
 
 // Static initializers for AnonymousIteratorHandleOp id counting.
@@ -559,6 +601,14 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, iterator_resource->SetIteratorFromDataset(ctx, dataset));
 }
 
+void DeleteIteratorOp::Compute(OpKernelContext* ctx) {
+  ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
+  // The iterator resource is guaranteed to exist because the variant tensor
+  // wrapping the deleter is provided as an unused input to this op, which
+  // guarantees that it has not run yet.
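+  // Once the resource is deleted here, the Deleter wrapped in that variant
+  // will attempt a second Delete() on the same handle when it is destroyed;
+  // that call fails and its status is ignored (see Deleter::Helper above).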
+ OP_REQUIRES_OK(ctx, ctx->resource_manager()->Delete(handle)); +} + namespace { class ToSingleElementOp : public AsyncOpKernel { @@ -578,7 +628,7 @@ class ToSingleElementOp : public AsyncOpKernel { std::unique_ptr iterator; IteratorContext::Params params(ctx); std::unique_ptr function_handle_cache = - absl::make_unique(params.lib); + absl::make_unique(params.flr); params.function_handle_cache = function_handle_cache.get(); std::unique_ptr resource_mgr = absl::make_unique(); @@ -640,11 +690,16 @@ class ReduceDatasetOp : public AsyncOpKernel { explicit ReduceDatasetOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), background_worker_(ctx->env(), "tf_data_reduce_dataset") { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_)); + bool use_inter_op_parallelism; + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism)); + FunctionMetadata::Params params; + params.is_multi_device_function = true; + params.use_inter_op_parallelism = use_inter_op_parallelism; + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "f", params, &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", - &use_inter_op_parallelism_)); } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { @@ -663,13 +718,13 @@ class ReduceDatasetOp : public AsyncOpKernel { std::unique_ptr captured_func; OP_REQUIRES_OK_ASYNC( ctx, - CapturedFunction::Create(reduce_func_, ctx, "other_arguments", - use_inter_op_parallelism_, &captured_func), + CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func), done); IteratorContext::Params params(ctx); std::unique_ptr function_handle_cache = - absl::make_unique(params.lib); + absl::make_unique(params.flr); params.function_handle_cache = function_handle_cache.get(); std::unique_ptr resource_mgr = absl::make_unique(); @@ -746,10 +801,9 @@ class ReduceDatasetOp : public AsyncOpKernel { } private: - NameAttrList reduce_func_; + std::shared_ptr func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_; - bool use_inter_op_parallelism_; BackgroundWorker background_worker_; }; @@ -837,21 +891,22 @@ class OneShotIteratorOp : public AsyncOpKernel { ContainerInfo* cinfo) { TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def())); - FunctionLibraryRuntime* lib; + FunctionLibraryRuntime* flr; std::unique_ptr flib_def(nullptr); std::unique_ptr pflr(nullptr); - TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def, &pflr, &lib)); + TF_RETURN_IF_ERROR( + ctx->function_library()->Clone(&flib_def, &pflr, &flr, true)); // Create an IteratorResource that will hold the iterator for this op. TF_RETURN_IF_ERROR( ctx->resource_manager()->LookupOrCreate( cinfo->container(), cinfo->name(), iterator, - [ctx, lib, this, &flib_def, &pflr](IteratorResource** ret) + [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new IteratorResource( ctx->env(), output_dtypes_, output_shapes_, graph_def_version_, nullptr, std::move(flib_def), - std::move(pflr), lib); + std::move(pflr), flr); return Status::OK(); })); @@ -870,12 +925,6 @@ class OneShotIteratorOp : public AsyncOpKernel { &f_handle)); FunctionLibraryRuntime::Options opts; opts.cancellation_manager = ctx->cancellation_manager(); - // Choose a step ID that is guaranteed not to clash with any - // Session-generated step ID. 
DirectSession only generates - // non-negative step IDs (contiguous, starting from 0), and - // MasterSession generates 56-bit random step IDs whose MSB is - // always 0, so a negative random step ID should suffice. - opts.step_id = -std::abs(static_cast(random::New64())); ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) { ctx->resource_manager()->Cleanup(name).IgnoreError(); }); @@ -1173,12 +1222,25 @@ REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2), REGISTER_KERNEL_BUILDER( Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"), MakeIteratorOp); +REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2), + DeleteIteratorOp); +REGISTER_KERNEL_BUILDER( + Name("DeleteIterator").Device(DEVICE_GPU).HostMemory("deleter").Priority(1), + DeleteIteratorOp); REGISTER_KERNEL_BUILDER( Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2), AnonymousIteratorHandleOp); REGISTER_KERNEL_BUILDER( Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1), AnonymousIteratorHandleOp); +REGISTER_KERNEL_BUILDER( + Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2), + AnonymousIteratorHandleOp); +REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2") + .Device(DEVICE_GPU) + .HostMemory("deleter") + .Priority(1), + AnonymousIteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU), ToSingleElementOp); REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index 7d769d365e9..2e887d379e3 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -96,6 +96,7 @@ class AnonymousIteratorHandleOp : public OpKernel { DataTypeVector output_dtypes_; std::vector output_shapes_; const int graph_def_version_; + const int op_version_; }; class MakeIteratorOp : public OpKernel { @@ -117,6 +118,13 @@ class IteratorGetNextOp : public AsyncOpKernel { BackgroundWorker background_worker_; }; +class DeleteIteratorOp : public OpKernel { + public: + explicit DeleteIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + class IteratorGetNextAsOptionalOp : public AsyncOpKernel { public: explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index e516d7791bf..68b15651755 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -34,11 +34,13 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::vector, std::vector*)>; explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + FunctionMetadata::Params params; + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + ¶ms.use_inter_op_parallelism)); + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "f", params, &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", - &use_inter_op_parallelism_)); OP_REQUIRES_OK( ctx, ctx->GetAttr("preserve_cardinality", &preserve_cardinality_)); } @@ -46,72 +48,28 @@ class MapDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** 
output) override { std::unique_ptr captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - use_inter_op_parallelism_, - &captured_func)); + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func)); - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); - - MapIteratorFunction map_func; - CapturedFunction* raw_captured_func = captured_func.get(); - if (indices.empty()) { - map_func = [](IteratorContext* ctx, - InstantiatedCapturedFunction* inst_captured_func, - std::vector args, - std::vector* out_tensors) { - return inst_captured_func->Run(ctx, std::move(args), out_tensors); - }; - } else { - std::vector can_move = ComputeMoveVector(indices); - map_func = [raw_captured_func, indices, can_move]( - IteratorContext* ctx, - InstantiatedCapturedFunction* inst_captured_func, - std::vector args, - std::vector* out_tensors) { - const std::vector& captured_inputs = - raw_captured_func->captured_inputs(); - size_t num_args = args.size(); - for (size_t i = 0; i < indices.size(); ++i) { - if (indices[i] < num_args) { - if (can_move[i]) { - out_tensors->push_back(std::move(args[indices[i]])); - } else { - out_tensors->push_back(args[indices[i]]); - } - } else { - out_tensors->push_back(captured_inputs[indices[i] - num_args]); - } - } - return Status::OK(); - }; - } - - *output = - new Dataset(ctx, input, func_, std::move(captured_func), output_types_, - output_shapes_, use_inter_op_parallelism_, - std::move(map_func), preserve_cardinality_); + *output = new Dataset(ctx, input, std::move(captured_func), output_types_, + output_shapes_, preserve_cardinality_); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, std::unique_ptr captured_func, const DataTypeVector& output_types, const std::vector& output_shapes, - bool use_inter_op_parallelism, MapIteratorFunction map_func, bool preserve_cardinality) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), - use_inter_op_parallelism_(use_inter_op_parallelism), preserve_cardinality_(preserve_cardinality), captured_func_(std::move(captured_func)), output_types_(output_types), - output_shapes_(output_shapes), - map_func_(std::move(map_func)) { + output_shapes_(output_shapes) { input_->Ref(); } @@ -120,7 +78,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_); + Iterator::Params{this, strings::StrCat(prefix, "::Map")}); } const DataTypeVector& output_dtypes() const override { @@ -141,27 +99,14 @@ class MapDatasetOp : public UnaryDatasetOpKernel { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - } + DataTypeVector 
other_arguments_types; + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); // Attr: f - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); AttrValue f_attr; - b->BuildAttrValue(func_, &f_attr); + b->BuildAttrValue(captured_func_->func(), &f_attr); // Attr: Targuments AttrValue other_arguments_types_attr; @@ -169,7 +114,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { // Attr: use_inter_op_parallelism AttrValue use_inter_op_parallelism_attr; - b->BuildAttrValue(use_inter_op_parallelism_, + b->BuildAttrValue(captured_func_->use_inter_op_parallelism(), &use_inter_op_parallelism_attr); // Attr: preserve_cardinality @@ -192,8 +137,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params, MapIteratorFunction map_func) - : DatasetIterator(params), map_func_(std::move(map_func)) {} + explicit Iterator(const Params& params) + : DatasetIterator(params) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -216,8 +161,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - Status s = map_func_(ctx, instantiated_captured_func_.get(), args, - out_tensors); + Status s = + instantiated_captured_func_->Run(ctx, std::move(args), out_tensors); if (errors::IsOutOfRange(s)) { if (dataset()->preserve_cardinality_) { // To guarantee that the transformation preserves the cardinality of @@ -257,24 +202,19 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: std::unique_ptr input_impl_; - const MapIteratorFunction map_func_; std::unique_ptr instantiated_captured_func_; }; const DatasetBase* const input_; - const NameAttrList func_; - const bool use_inter_op_parallelism_; const bool preserve_cardinality_; const std::unique_ptr captured_func_; const DataTypeVector output_types_; const std::vector output_shapes_; - const MapIteratorFunction map_func_; }; + std::shared_ptr func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList func_; - bool use_inter_op_parallelism_; bool preserve_cardinality_; }; diff --git a/tensorflow/core/kernels/data/map_dataset_op_test.cc b/tensorflow/core/kernels/data/map_dataset_op_test.cc index b0d17ab2865..ac70d39cda3 100644 --- a/tensorflow/core/kernels/data/map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/map_dataset_op_test.cc @@ -49,7 +49,7 @@ class MapDatasetOpTest : public DatasetOpsTestBase { FunctionDefHelper::AttrValueWrapper func = FunctionDefHelper::FunctionRef(func_name, {{"T", DT_INT64}}); - map_node_def_ = test::function::NDef( + NodeDef map_dataset_node_def = test::function::NDef( kNodeName, kOpName, {input_dataset}, {{"f", func}, {"Targuments", {}}, @@ -58,136 +58,198 @@ class MapDatasetOpTest : public DatasetOpsTestBase { gtl::ArraySlice{tensorflow::DataTypeToEnum::value}}, {"use_inter_op_parallelism", true}, {"preserve_cardinality", false}}); - TF_CHECK_OK(CreateOpKernel(map_node_def_, map_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(map_dataset_node_def, map_kernel)); return Status::OK(); } // Creates a new MapDataset op kernel context. Status CreateMapDatasetContext( - DatasetBase* const input_dataset, OpKernel* const map_kernel, + OpKernel* const map_kernel, gtl::InlinedVector* inputs, std::unique_ptr* map_context) { - map_inputs_.clear(); - // Save the input dataset into a variant tensor as the input of MapDataset. 
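// The short-circuit path removed above built each output by indexing directly into the
// op's arguments or captured inputs, moving a tensor when it had no other consumer. A
// minimal standalone sketch of that selection logic, with std::string standing in for
// Tensor and plain vectors in place of the TensorFlow types (illustrative names only):
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Builds `out` by picking elements from `args` (indices < args.size()) or from
// `captured` (indices >= args.size()), moving an argument when `can_move[i]`
// says no other consumer needs it.
void ShortCircuitSelect(std::vector<std::string> args,
                        const std::vector<std::string>& captured,
                        const std::vector<size_t>& indices,
                        const std::vector<bool>& can_move,
                        std::vector<std::string>* out) {
  const size_t num_args = args.size();
  for (size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] < num_args) {
      if (can_move[i]) {
        out->push_back(std::move(args[indices[i]]));  // zero-copy hand-off
      } else {
        out->push_back(args[indices[i]]);             // another use remains
      }
    } else {
      out->push_back(captured[indices[i] - num_args]);  // captured input
    }
  }
}

int main() {
  std::vector<std::string> args = {"a0", "a1"};
  std::vector<std::string> captured = {"c0"};
  std::vector<std::string> out;
  // Output layout: (arg 1, captured 0, arg 0); only safe uses are moved.
  ShortCircuitSelect(args, captured, /*indices=*/{1, 2, 0},
                     /*can_move=*/{true, false, true}, &out);
  for (const auto& s : out) std::cout << s << "\n";  // a1 c0 a0
}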
- Tensor dataset_tensor(DT_VARIANT, TensorShape({})); - TF_RETURN_IF_ERROR( - StoreDatasetInVariantTensor(input_dataset, &dataset_tensor)); - Variant variant = dataset_tensor.scalar()(); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &map_inputs_, map_kernel->input_types(), TensorShape({}), {variant})); - input_dataset->Ref(); - TF_RETURN_IF_ERROR( - CreateOpKernelContext(map_kernel, &map_inputs_, map_context)); - TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, map_inputs_)); + TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(map_kernel, inputs, map_context)); return Status::OK(); } - - private: - NodeDef map_node_def_; - gtl::InlinedVector map_inputs_; }; -struct GetNextTestParams { - explicit GetNextTestParams(int64 input_start, int64 input_end, - int64 input_step, string input_func_name, - std::vector input_expected_values, - std::vector input_func_lib) - : start(input_start), - end(input_end), - step(input_step), - func_name(std::move(input_func_name)), - expected_values(std::move(input_expected_values)), - func_lib(std::move(input_func_lib)) {} - +struct TestCase { int64 start; int64 end; int64 step; string func_name; - std::vector expected_values; std::vector func_lib; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; }; -struct DatasetGetNextTest : MapDatasetOpTest, - ::testing::WithParamInterface {}; +TestCase TestCase1() { + return {/*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*func_name*/ "XTimesTwo", + /*func_lib*/ {test::function::XTimesTwo()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {18})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} -TEST_P(DatasetGetNextTest, GetNext) { +TestCase TestCase2() { + return {/*start*/ 10, + /*end*/ 0, + /*step*/ -3, + /*func_name*/ "XAddX", + /*func_lib*/ {test::function::XAddX()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {20}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {14}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// In this test case, the function `XTimesFour()` will call `XTimesTwo()`, so +// both of them are added to the function library. 
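// The test refactoring above folds several per-test parameter structs into a single
// TestCase struct plus factory functions that feed one value-parameterized test. A
// self-contained GoogleTest sketch of the same pattern, assuming gtest_main supplies
// main() (types simplified; these are not the tf.data test fixtures):
#include <cstdint>
#include <vector>

#include <gtest/gtest.h>

// One struct carries every input and every expectation for a test variant.
struct TestCase {
  int64_t start;
  int64_t end;
  int64_t step;
  std::vector<int64_t> expected_outputs;
};

TestCase TestCase1() { return {0, 10, 3, {0, 3, 6, 9}}; }
TestCase TestCase2() { return {10, 0, -3, {10, 7, 4, 1}}; }

class ParameterizedRangeTest : public ::testing::TestWithParam<TestCase> {};

TEST_P(ParameterizedRangeTest, ProducesExpectedValues) {
  const TestCase& test_case = GetParam();
  std::vector<int64_t> outputs;
  for (int64_t i = test_case.start;
       test_case.step > 0 ? i < test_case.end : i > test_case.end;
       i += test_case.step) {
    outputs.push_back(i);
  }
  EXPECT_EQ(outputs, test_case.expected_outputs);
}

// INSTANTIATE_TEST_SUITE_P supersedes the deprecated INSTANTIATE_TEST_CASE_P.
INSTANTIATE_TEST_SUITE_P(RangeTests, ParameterizedRangeTest,
                         ::testing::ValuesIn(std::vector<TestCase>(
                             {TestCase1(), TestCase2()})));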
+TestCase TestCase3() { + return { + /*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*func_name*/ "XTimesFour", + /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {24}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {36})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +class ParameterizedMapDatasetOpTest + : public MapDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedMapDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; - GetNextTestParams test_params = GetParam(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scored_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), test_params.func_name, &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); + bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector out_tensors; while (!end_of_sequence) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } } - - EXPECT_EQ(out_tensors.size(), test_params.expected_values.size()); - for (size_t i = 0; i < out_tensors.size(); ++i) { - int64 actual_value = out_tensors[i].flat()(0); - int64 expect_value = test_params.expected_values[i]; - EXPECT_EQ(actual_value, 
expect_value); - } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -INSTANTIATE_TEST_CASE_P( - MapDatasetOpTest, DatasetGetNextTest, - ::testing::Values( - GetNextTestParams( - 0, 10, 3, "XTimesTwo", std::vector{0, 6, 12, 18}, - std::vector{test::function::XTimesTwo()}), - GetNextTestParams(0, 10, 3, "XAddX", std::vector{0, 6, 12, 18}, - std::vector{test::function::XAddX()}), - GetNextTestParams( - 10, 0, -3, "XTimesFour", std::vector{40, 28, 16, 4}, - std::vector{test::function::XTimesTwo(), - test::function::XTimesFour()}))); - -TEST_F(MapDatasetOpTest, DatasetName) { +TEST_F(MapDatasetOpTest, DatasetNodeName) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(map_dataset); + + EXPECT_EQ(map_dataset->node_name(), kNodeName); +} + +TEST_F(MapDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. 
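// The recurring ownership comment above relies on a reference-counting convention:
// storing the dataset in the variant tensor hands the caller's reference to a wrapper
// that Unrefs it on destruction, which is why the tests drop core::ScopedUnref for the
// range dataset. A toy standalone sketch of that hand-off (illustrative types, not the
// real DatasetVariantWrapper):
#include <iostream>

// Toy intrusive reference counting: objects start with one reference and
// Unref() deletes them when the count reaches zero.
class RefCounted {
 public:
  void Ref() { ++count_; }
  void Unref() {
    if (--count_ == 0) delete this;
  }
  virtual ~RefCounted() { std::cout << "released\n"; }

 private:
  int count_ = 1;
};

class ToyDataset : public RefCounted {};

// The wrapper adopts the caller's reference and releases it on destruction,
// so a caller that stores an object here must not also unref it.
class ToyVariantWrapper {
 public:
  explicit ToyVariantWrapper(RefCounted* owned) : owned_(owned) {}
  ~ToyVariantWrapper() {
    if (owned_ != nullptr) owned_->Unref();
  }
  ToyVariantWrapper(const ToyVariantWrapper&) = delete;
  ToyVariantWrapper& operator=(const ToyVariantWrapper&) = delete;

 private:
  RefCounted* owned_;
};

int main() {
  auto* dataset = new ToyDataset();        // refcount == 1
  { ToyVariantWrapper wrapper(dataset); }  // wrapper Unrefs -> "released"
  return 0;
}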
TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); + + std::unique_ptr map_dataset_kernel; + TF_ASSERT_OK(CreateMapDatasetOpKernel( + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); + DatasetBase* map_dataset; + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); EXPECT_EQ(map_dataset->type_string(), kOpName); @@ -195,138 +257,125 @@ TEST_F(MapDatasetOpTest, DatasetName) { TEST_F(MapDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(map_dataset->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(map_dataset->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(MapDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to 
DatasetVariantWrapper, + // which will handle the release of memory. TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < map_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE( - map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(map_dataset->output_shapes(), + test_case.expected_output_shapes)); } -struct CardinalityTestParams { - explicit CardinalityTestParams(int64 input_start, int64 input_end, - int64 input_step, - int input_expected_cardinality) - : start(input_start), - end(input_end), - step(input_step), - expected_cardinality(input_expected_cardinality) {} - - int64 start; - int64 end; - int64 step; - int expected_cardinality; -}; - -struct DatasetCardinalityTest - : MapDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetCardinalityTest, Cardinality) { +TEST_P(ParameterizedMapDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; - CardinalityTestParams test_params = GetParam(); - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. 
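// The per-element EXPECT loops above are replaced by status-returning helpers
// (VerifyTypesMatch, VerifyShapesCompatible), presumably provided by the test base, so
// mismatches are reported uniformly. A simplified standalone sketch of such a helper,
// treating a shape as a vector of dims where -1 means "unknown" (an assumption made for
// the sketch):
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;  // -1 marks an unknown dimension.

// Returns an empty string on success, otherwise a description of the first
// incompatibility. Two shapes are compatible when they have the same rank and
// every pair of dims is equal or at least one of them is unknown.
std::string VerifyShapesCompatible(const std::vector<Shape>& actual,
                                   const std::vector<Shape>& expected) {
  if (actual.size() != expected.size())
    return "different number of components";
  for (size_t i = 0; i < actual.size(); ++i) {
    if (actual[i].size() != expected[i].size())
      return "component " + std::to_string(i) + ": rank mismatch";
    for (size_t d = 0; d < actual[i].size(); ++d) {
      if (actual[i][d] != -1 && expected[i][d] != -1 &&
          actual[i][d] != expected[i][d])
        return "component " + std::to_string(i) + ": dim " +
               std::to_string(d) + " mismatch";
    }
  }
  return "";  // compatible
}

int main() {
  std::string err = VerifyShapesCompatible({{-1, 3}}, {{2, 3}});
  std::cout << (err.empty() ? "compatible" : err) << "\n";  // compatible
}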
+ TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); - EXPECT_EQ(map_dataset->Cardinality(), test_params.expected_cardinality); + EXPECT_EQ(map_dataset->Cardinality(), test_case.expected_cardinality); } -INSTANTIATE_TEST_CASE_P(MapDatasetOpTest, DatasetCardinalityTest, - ::testing::Values(CardinalityTestParams(0, 10, 1, 10), - CardinalityTestParams(0, 10, 3, 4), - CardinalityTestParams(10, 0, -3, 4))); - -TEST_F(MapDatasetOpTest, DatasetSave) { +TEST_P(ParameterizedMapDatasetOpTest, DatasetSave) { int thread_num = 2, cpu_num = 2; - int64 start = 0, end = 10, step = 1; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. 
TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr serialization_context; @@ -338,101 +387,114 @@ TEST_F(MapDatasetOpTest, DatasetSave) { } TEST_F(MapDatasetOpTest, IteratorOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. 
TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(iterator->output_dtypes(), expected_dtypes); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(MapDatasetOpTest, IteratorOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. 
TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < map_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); } TEST_F(MapDatasetOpTest, IteratorOutputPrefix) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - FunctionDef func_def = test::function::XTimesTwo(); - + TestCase test_case = TestCase1(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. 
TF_ASSERT_OK( - CreateRangeDataset(start, end, step, "range", &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), func_def.signature().name(), &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); @@ -440,95 +502,80 @@ TEST_F(MapDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), "Iterator::Map"); } -struct RoundtripTestParams { - explicit RoundtripTestParams(int64 input_start, int64 input_end, - int64 input_step, int input_breakpoint, - int64 input_expected_value, - string input_func_name, - std::vector input_func_lib) - : start(input_start), - end(input_end), - step(input_step), - breakpoint(input_breakpoint), - expected_value(input_expected_value), - func_name(std::move(input_func_name)), - func_lib(std::move(input_func_lib)) {} - - int64 start; - int64 end; - int64 step; - int breakpoint; - int64 expected_value; - string func_name; - std::vector func_lib; -}; - -struct IteratorRoundtripTest - : MapDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorRoundtripTest, Roundtrip) { +TEST_P(ParameterizedMapDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; - RoundtripTestParams test_params = GetParam(); - + TestCase test_case = GetParam(); TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); DatasetBase* range_dataset; - TF_ASSERT_OK(CreateRangeDataset(test_params.start, test_params.end, - test_params.step, "range", - &range_dataset)); - core::ScopedUnref scoped_unref_range_dataset(range_dataset); + TF_ASSERT_OK(CreateRangeDataset( + test_case.start, test_case.end, test_case.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + // The ownership of range_dataset is transferred to DatasetVariantWrapper, + // which will handle the release of memory. 
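// The Roundtrip test that follows drives the iterator to each breakpoint, serializes
// its state, restores a fresh iterator from that state, and keeps consuming while
// checking expected_outputs. A minimal standalone sketch of that save/restore pattern
// with a toy counting iterator (illustrative types, not VariantTensorDataWriter/Reader):
#include <cassert>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy iterator over [0, end) that can checkpoint its position.
class CountingIterator {
 public:
  explicit CountingIterator(int64_t end) : end_(end) {}

  bool GetNext(int64_t* value) {
    if (next_ >= end_) return false;  // end of sequence
    *value = next_++;
    return true;
  }

  // Save/Restore stand in for Iterator::Save and RestoreIterator.
  void Save(std::map<std::string, int64_t>* state) const {
    (*state)["next"] = next_;
  }
  void Restore(const std::map<std::string, int64_t>& state) {
    next_ = state.at("next");
  }

 private:
  int64_t end_;
  int64_t next_ = 0;
};

int main() {
  const std::vector<int64_t> expected = {0, 1, 2, 3, 4};
  const std::vector<int> breakpoints = {0, 2, 5};

  CountingIterator iterator(/*end=*/5);
  size_t produced = 0;
  for (int breakpoint : breakpoints) {
    // Checkpoint, then restore into a fresh iterator before continuing.
    std::map<std::string, int64_t> state;
    iterator.Save(&state);
    CountingIterator restored(/*end=*/5);
    restored.Restore(state);
    iterator = restored;

    int64_t value;
    while (produced < static_cast<size_t>(breakpoint) &&
           iterator.GetNext(&value)) {
      assert(value == expected[produced]);
      ++produced;
    }
  }
  std::cout << "roundtrip ok, produced " << produced << " elements\n";
}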
+ TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + gtl::InlinedVector map_dataset_inputs; + map_dataset_inputs.emplace_back(&range_dataset_tensor); - std::unique_ptr map_kernel; + std::unique_ptr map_dataset_kernel; TF_ASSERT_OK(CreateMapDatasetOpKernel( - range_dataset->node_name(), test_params.func_name, &map_kernel)); - std::unique_ptr map_context; - TF_ASSERT_OK( - CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context)); + range_dataset->node_name(), test_case.func_name, &map_dataset_kernel)); + std::unique_ptr map_dataset_context; + TF_ASSERT_OK(CreateMapDatasetContext( + map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context)); DatasetBase* map_dataset; - TF_ASSERT_OK( - CreateDataset(map_kernel.get(), map_context.get(), &map_dataset)); + TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(), + map_dataset_context.get(), &map_dataset)); core::ScopedUnref scoped_unref_map_dataset(map_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(map_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK( map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector out_tensors; + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; - for (int i = 0; i < test_params.breakpoint; i++) { - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - } + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader, "Iterator", + *map_dataset, &iterator)); - std::unique_ptr serialization_context; - TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); - VariantTensorData data; - VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); - TF_ASSERT_OK(writer.Flush()); - VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - EXPECT_EQ(out_tensors.back().flat()(0), test_params.expected_value); + while (cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_cardinality) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); + } + } } -INSTANTIATE_TEST_CASE_P( - MapDatasetOpTest, IteratorRoundtripTest, - ::testing::Values(RoundtripTestParams(0, 10, 2, 0, 0, "XTimesTwo", - std::vector{ - test::function::XTimesTwo()}), - RoundtripTestParams(0, 10, 2, 4, 16, "XAddX", - std::vector{ - test::function::XAddX()}), - RoundtripTestParams(0, 10, 2, 6, 32, 
"XTimesFour", - std::vector{ - test::function::XTimesTwo(), - test::function::XTimesFour()}))); +INSTANTIATE_TEST_SUITE_P(MapDatasetOpTest, ParameterizedMapDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3()}))); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index f1be942a633..cae0facfba3 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/mutex.h" @@ -28,16 +29,18 @@ namespace tensorflow { namespace data { namespace { -void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, - bool always_collect_stats) { - opts->step_id = ctx->step_id(); - opts->rendezvous = ctx->rendezvous(); - if (always_collect_stats) { - opts->stats_collector = ctx->stats_collector(); - } - opts->runner = ctx->runner(); -} - +// This op runs a given defun on slices of the input arguments. The function +// given by "f" is assumed to be stateless, and is executed concurrently +// on all the slices; up to batch_size (i.e. the 0th dimension of each argument) +// functions will be scheduled at once. +// +// The "max_intra_op_parallelism" attr, which defaults to 1, can be used to +// limit the intra op parallelism. To limit inter-op parallelism, a user +// can set a private threadpool on the dataset using `tf.data.Options`'s +// `ThreadingOptions`. +// +// Note that this op is not exposed to users directly, but is invoked in +// tf.data rewrites. class MapDefunOp : public AsyncOpKernel { public: explicit MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { @@ -50,6 +53,8 @@ class MapDefunOp : public AsyncOpKernel { func_lib->Instantiate(func->name(), AttrSlice(&func->attr()), &func_handle_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", + &max_intra_op_parallelism_)); OP_REQUIRES(ctx, ctx->num_inputs() >= 0, errors::InvalidArgument("Must have at least one input.")); @@ -72,7 +77,7 @@ class MapDefunOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC(ctx, s, done); FunctionLibraryRuntime::Options opts; - SetRunOptions(ctx, &opts, false); + SetRunOptions(ctx, &opts, compute_opts, /*always_collect_stats=*/false); // Run loop StatusCallback callback = std::bind( @@ -124,9 +129,6 @@ class MapDefunOp : public AsyncOpKernel { } private: - FunctionLibraryRuntime::Handle func_handle_; - std::vector output_shapes_; - struct ComputeOptions { // These vary per MapDefunOp::ComputeAsync call, but must persist until // all calls to the function are complete. This struct also encapsulates @@ -136,6 +138,7 @@ class MapDefunOp : public AsyncOpKernel { const std::vector arg_shapes; OpInputList captured_inputs; const int64 batch_size; + std::function)> runner; // Output of a compute call std::vector output_shapes GUARDED_BY(mu); @@ -144,68 +147,22 @@ class MapDefunOp : public AsyncOpKernel { // Create a copy of output_shapes because every `Compute` may expect a // different output shape. 
- ComputeOptions(OpInputList args, OpInputList captured_inputs, + ComputeOptions(OpKernelContext* ctx, OpInputList args, + OpInputList captured_inputs, std::vector arg_shapes, int64 batch_size, - const std::vector& output_shapes_attr) + const std::vector& output_shapes_attr, + int max_parallelism) : args(args), arg_shapes(std::move(arg_shapes)), captured_inputs(captured_inputs), batch_size(batch_size), - output_shapes(output_shapes_attr) {} + output_shapes(output_shapes_attr) { + if (max_parallelism >= 1) { + runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism); + } + } }; - // Get inputs to Compute and check that they are valid. - Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { - OpInputList arguments; - TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments)); - OpInputList captured_inputs; - TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs)); - - int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1; - - for (size_t i = 0; i < arguments.size(); ++i) { - if (arguments[i].dims() == 0) { - return errors::InvalidArgument( - "All inputs must have rank at least 1. Input ", i, - " has a rank of 0."); - } else if (arguments[i].dim_size(0) != batch_size) { - return errors::InvalidArgument( - "All inputs must have the same dimension 0. Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size); - } - } - - std::vector arg_shapes; - arg_shapes.reserve(arguments.size()); - - for (size_t i = 0; i < arguments.size(); ++i) { - arg_shapes.push_back(arguments[i].shape()); - arg_shapes.at(i).RemoveDim(0); - } - - *compute_opts = - new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes), - batch_size, output_shapes_); - return Status::OK(); - } - - Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) { - mutex_lock l(opts->mu); - TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output)); - - for (size_t i = 0; i < output_types().size(); ++i) { - if (output_shapes_.at(i).IsFullyDefined()) { - Tensor* out = nullptr; - TensorShape output_shape; - output_shapes_.at(i).AsTensorShape(&output_shape); - output_shape.InsertDim(0, opts->batch_size); - TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out)); - } - } - return Status::OK(); - } - class MapFunctionCallFrame : public CallFrameInterface { public: MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel, @@ -288,8 +245,80 @@ class MapDefunOp : public AsyncOpKernel { ComputeOptions* const compute_opts_; // Not owned const OpKernel* kernel_; const size_t iter_; - }; -}; + }; // MapFunctionCallFrame + + void SetRunOptions(OpKernelContext* ctx, + FunctionLibraryRuntime::Options* opts, + ComputeOptions* compute_opts, bool always_collect_stats) { + opts->rendezvous = ctx->rendezvous(); + if (always_collect_stats) { + opts->stats_collector = ctx->stats_collector(); + } + if (max_intra_op_parallelism_ >= 1) { + opts->runner = &compute_opts->runner; + } else { + opts->runner = ctx->runner(); + } + } + + // Get inputs to Compute and check that they are valid. + Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { + OpInputList arguments; + TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments)); + OpInputList captured_inputs; + TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs)); + + int64 batch_size = arguments[0].dims() > 0 ? 
arguments[0].dim_size(0) : -1; + + for (size_t i = 0; i < arguments.size(); ++i) { + if (arguments[i].dims() == 0) { + return errors::InvalidArgument( + "All inputs must have rank at least 1. Input ", i, + " has a rank of 0."); + } else if (arguments[i].dim_size(0) != batch_size) { + return errors::InvalidArgument( + "All inputs must have the same dimension 0. Input ", i, + " has leading dimension ", ctx->input(i).dim_size(0), + ", while all previous inputs have leading dimension ", batch_size); + } + } + + std::vector arg_shapes; + arg_shapes.reserve(arguments.size()); + + for (size_t i = 0; i < arguments.size(); ++i) { + arg_shapes.push_back(arguments[i].shape()); + arg_shapes.at(i).RemoveDim(0); + } + + *compute_opts = new ComputeOptions( + ctx, arguments, captured_inputs, std::move(arg_shapes), batch_size, + output_shapes_, max_intra_op_parallelism_); + return Status::OK(); + } + + Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) { + mutex_lock l(opts->mu); + TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output)); + + for (size_t i = 0; i < output_types().size(); ++i) { + if (output_shapes_.at(i).IsFullyDefined()) { + Tensor* out = nullptr; + TensorShape output_shape; + output_shapes_.at(i).AsTensorShape(&output_shape); + output_shape.InsertDim(0, opts->batch_size); + TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out)); + } + } + return Status::OK(); + } + + FunctionLibraryRuntime::Handle func_handle_; + std::vector output_shapes_; + // If this value is positive, limit the max intra op parallelism when the + // function is run on slices of the input. + int max_intra_op_parallelism_; +}; // MapDefunOp REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp); } // namespace diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 4b8b68f2a38..6e54ceedab1 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -26,23 +26,33 @@ namespace tensorflow { namespace data { namespace { -constexpr int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros; +constexpr int64 kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMillis; class ModelDatasetOp : public UnaryDatasetOpKernel { public: explicit ModelDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("cpu_budget", &cpu_budget_)); + if (cpu_budget_ == 0) { + cpu_budget_ = port::NumSchedulableCPUs(); + } + OP_REQUIRES(ctx, cpu_budget_ > 0, + errors::InvalidArgument("CPU budget must be positive but is ", + cpu_budget_, ".")); + } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - *output = new Dataset(ctx, input); + *output = new Dataset(ctx, input, cpu_budget_); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input) - : DatasetBase(DatasetContext(ctx)), input_(input) { + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 cpu_budget) + : DatasetBase(DatasetContext(ctx)), + input_(input), + cpu_budget_(cpu_budget) { input_->Ref(); } @@ -149,31 +159,32 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { void OptimizeThread(const std::shared_ptr& ctx) { int64 last_optimization_ms = 0; int64 optimization_period_ms = 10; + int64 current_time_ms = + ctx->env()->NowMicros() / EnvTime::kMillisToMicros; while (true) { { mutex_lock l(mu_); while (!cancelled_ && - 
last_optimization_ms + optimization_period_ms >= - ctx->env()->NowMicros() / EnvTime::kMillisToMicros) { - cond_var_.wait_for( - l, std::chrono::milliseconds( - last_optimization_ms + optimization_period_ms - - ctx->env()->NowMicros() / EnvTime::kMillisToMicros)); + last_optimization_ms + optimization_period_ms > + current_time_ms) { + auto wait_ms = last_optimization_ms + optimization_period_ms - + current_time_ms; + VLOG(2) << "Waiting for " << wait_ms << " ms."; + cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms)); + current_time_ms = + ctx->env()->NowMicros() / EnvTime::kMillisToMicros; } if (cancelled_) return; } - model_->Optimize(port::NumSchedulableCPUs()); + model_->Optimize(dataset()->cpu_budget_); // Exponentially increase the period of running the optimization // until a threshold is reached. - if (optimization_period_ms < kOptimizationPeriodThresholdMs) { - if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) { - optimization_period_ms <<= 1; - } else { - optimization_period_ms = kOptimizationPeriodThresholdMs; - } + if (optimization_period_ms != kOptimizationPeriodThresholdMs) { + optimization_period_ms = std::min(optimization_period_ms << 1, + kOptimizationPeriodThresholdMs); } - last_optimization_ms = - ctx->env()->NowMicros() / EnvTime::kMillisToMicros; + current_time_ms = ctx->env()->NowMicros() / EnvTime::kMillisToMicros; + last_optimization_ms = current_time_ms; } } @@ -186,7 +197,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* input_; + const int64 cpu_budget_; }; + + int64 cpu_budget_; }; REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index 6a600a72dfa..34d9ece8b06 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -48,17 +48,17 @@ class MultiDeviceIterator : public ResourceBase { const std::vector& devices, std::unique_ptr flib_def, std::unique_ptr pflr, - FunctionLibraryRuntime* lib, + FunctionLibraryRuntime* flr, std::unique_ptr function_handle_cache) : unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"), output_types_(output_types), output_shapes_(output_shapes), devices_(devices), flib_def_(std::move(flib_def)), + flr_(flr), pflr_(std::move(pflr)), - lib_(lib), function_handle_cache_(std::move(function_handle_cache)) { - DCHECK(lib_ != nullptr); + DCHECK(flr_ != nullptr); } string DebugString() const override { @@ -94,8 +94,7 @@ class MultiDeviceIterator : public ResourceBase { MultiDeviceIteratorCallback callback) { tf_shared_lock l(mu_); IteratorContext::Params params(ctx); - params.function_library = lib_def_; - params.lib = lib_; + params.flr = flr_; params.function_handle_cache = function_handle_cache_.get(); params.resource_mgr = &resource_mgr_; params.thread_factory = unbounded_thread_pool_.get_thread_factory(); @@ -111,14 +110,9 @@ class MultiDeviceIterator : public ResourceBase { return output_shapes_; } - std::shared_ptr function_library() { + FunctionLibraryRuntime* const flr() { tf_shared_lock l(mu_); - return lib_def_; - } - - FunctionLibraryRuntime* const lib() { - tf_shared_lock l(mu_); - return lib_; + return flr_; } FunctionHandleCache* function_handle_cache() { @@ -355,8 +349,8 @@ class MultiDeviceIterator : public ResourceBase { const std::vector output_shapes_; const std::vector devices_; const std::unique_ptr flib_def_; + 
FunctionLibraryRuntime* const flr_ = nullptr; // not owned. const std::unique_ptr pflr_; - FunctionLibraryRuntime* const lib_ = nullptr; // not owned. const std::unique_ptr function_handle_cache_; ResourceMgr resource_mgr_; std::shared_ptr lib_def_ GUARDED_BY(mu_); @@ -402,13 +396,13 @@ class MultiDeviceIteratorHandleOp : public OpKernel { { mutex_lock l(mu_); if (resource_ == nullptr) { - FunctionLibraryRuntime* lib; + FunctionLibraryRuntime* flr; std::unique_ptr flib_def(nullptr); std::unique_ptr pflr(nullptr); OP_REQUIRES_OK(context, context->function_library()->Clone( - &flib_def, &pflr, &lib)); + &flib_def, &pflr, &flr)); std::unique_ptr function_handle_cache = - absl::make_unique(lib); + absl::make_unique(flr); ResourceMgr* mgr = context->resource_manager(); OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); @@ -420,7 +414,7 @@ class MultiDeviceIteratorHandleOp : public OpKernel { container_name = "AnonymousMultiDeviceIterator"; resource = new MultiDeviceIterator( context->env(), output_types_, output_shapes_, devices_, - std::move(flib_def), std::move(pflr), lib, + std::move(flib_def), std::move(pflr), flr, std::move(function_handle_cache)); // NOTE: `mgr->Create()` transfers the one reference on `resource` to // `mgr`. @@ -432,14 +426,14 @@ class MultiDeviceIteratorHandleOp : public OpKernel { OP_REQUIRES_OK(context, mgr->LookupOrCreate( container_name, unique_name, &resource, - [this, context, lib, &flib_def, &pflr, + [this, context, flr, &flib_def, &pflr, &function_handle_cache](MultiDeviceIterator** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new MultiDeviceIterator( context->env(), output_types_, output_shapes_, devices_, std::move(flib_def), std::move(pflr), - lib, std::move(function_handle_cache)); + flr, std::move(function_handle_cache)); return Status::OK(); })); Status s = VerifyResource(resource); @@ -505,7 +499,7 @@ class MultiDeviceIteratorInitOp : public OpKernel { std::unique_ptr iterator; IteratorContext::Params params(ctx); - params.lib = resource->lib(); + params.flr = resource->flr(); params.function_handle_cache = resource->function_handle_cache(); params.resource_mgr = resource->resource_mgr(); IteratorContext iter_ctx(std::move(params)); diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 17094e30017..896e080ae62 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. 
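// MapDefun, documented a few hunks above, runs a stateless function over the
// 0th-dimension slices of its inputs, with max_intra_op_parallelism capping how many
// invocations run at once. A standalone sketch of that slice-and-map scheme using plain
// threads and nested vectors in place of Tensors (the fixed worker count here is only a
// stand-in for TensorFlow's runner machinery):
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <thread>
#include <vector>

// Applies `fn` to every slice along dimension 0 of `input`, producing one
// output slice per input slice. The worker-thread count caps how many
// invocations run concurrently.
std::vector<std::vector<int64_t>> MapOverSlices(
    const std::vector<std::vector<int64_t>>& input,
    const std::function<std::vector<int64_t>(const std::vector<int64_t>&)>& fn,
    size_t max_parallelism) {
  const size_t batch_size = input.size();
  if (batch_size == 0) throw std::invalid_argument("need at least one slice");

  std::vector<std::vector<int64_t>> output(batch_size);
  std::atomic<size_t> next_slice{0};
  auto worker = [&]() {
    // Workers claim slice indices until none remain.
    for (size_t i = next_slice.fetch_add(1); i < batch_size;
         i = next_slice.fetch_add(1)) {
      output[i] = fn(input[i]);
    }
  };

  std::vector<std::thread> workers;
  const size_t num_workers = std::min(max_parallelism, batch_size);
  for (size_t w = 0; w < num_workers; ++w) workers.emplace_back(worker);
  for (auto& t : workers) t.join();
  return output;
}

int main() {
  auto times_two = [](const std::vector<int64_t>& slice) {
    std::vector<int64_t> out(slice);
    for (auto& v : out) v *= 2;  // the "defun" applied to each slice
    return out;
  };
  auto result = MapOverSlices({{1, 2}, {3, 4}, {5, 6}}, times_two,
                              /*max_parallelism=*/2);
  for (const auto& row : result) std::cout << row[0] << " " << row[1] << "\n";
}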
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" @@ -32,10 +32,9 @@ constexpr char kOptimizerName[] = "tf_data_meta_optimizer"; class OptimizeDatasetOp : public UnaryDatasetOpKernel { public: explicit OptimizeDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK( + ctx, ctx->GetAttr("optimization_configs", &optimization_configs_)); } protected: @@ -44,52 +43,41 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { std::vector optimizations; OP_REQUIRES_OK( ctx, ParseVectorArgument(ctx, "optimizations", &optimizations)); - Dataset* dataset = - new Dataset(ctx, input, optimizations, output_types_, output_shapes_); - Status s = dataset->Optimize(ctx); - if (s.ok()) { - *output = dataset; - } else { - dataset->Unref(); - OP_REQUIRES_OK(ctx, s); - } + + auto config_factory = [this, &optimizations]() { + return CreateConfig(optimizations, optimization_configs_); + }; + OP_REQUIRES_OK(ctx, + RewriteDataset(ctx, input, std::move(config_factory), + /*optimize_function_library=*/true, output)); } private: - class Dataset : public GraphRewriteDataset { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const std::vector& optimizations, - const DataTypeVector& output_types, - const std::vector& output_shapes) - : GraphRewriteDataset(ctx, input, output_types, output_shapes), - optimizations_(optimizations) {} - - string DebugString() const override { return "OptimizeDatasetOp::Dataset"; } - - private: - RewriterConfig CreateGrapplerRewriteConfig() override { - RewriterConfig rewriter_config; - rewriter_config.add_optimizers(kOptimizerName); - rewriter_config.set_meta_optimizer_iterations( - RewriterConfig_NumIterationsType_ONE); - auto custom_optimizer = rewriter_config.add_custom_optimizers(); - custom_optimizer->set_name(kOptimizerName); - auto* custom_optimizations_list = - (*custom_optimizer->mutable_parameter_map())["optimizers"] - .mutable_list(); - for (const auto& opt : optimizations_) { - custom_optimizations_list->add_s(opt); - } - return rewriter_config; + static RewriterConfig CreateConfig( + std::vector optimizations, + std::vector optimizations_configs) { + RewriterConfig rewriter_config; + rewriter_config.add_optimizers(kOptimizerName); + rewriter_config.set_meta_optimizer_iterations( + RewriterConfig_NumIterationsType_ONE); + auto custom_optimizer = rewriter_config.add_custom_optimizers(); + custom_optimizer->set_name(kOptimizerName); + auto* custom_optimizations_list = + (*custom_optimizer->mutable_parameter_map())["optimizers"] + .mutable_list(); + for (const auto& opt : optimizations) { + custom_optimizations_list->add_s(opt); } + auto* config_list = + (*custom_optimizer->mutable_parameter_map())["optimizer_configs"] + .mutable_list(); + for (const auto& config : optimizations_configs) { + config_list->add_s(config); + } + return rewriter_config; + } - const std::vector optimizations_; - }; - - const int graph_def_version_; - DataTypeVector output_types_; - std::vector output_shapes_; + 
std::vector optimization_configs_; }; REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h index 24eb1b81d90..91fa253b70d 100644 --- a/tensorflow/core/kernels/data/optional_ops.h +++ b/tensorflow/core/kernels/data/optional_ops.h @@ -90,10 +90,10 @@ class OptionalVariant { string DebugString() const { if (values_) { return strings::StrCat("OptionalVariant<", "values: (", - str_util::Join(*values_, ", ", - [](string* s, const Tensor& elem) { - *s = elem.DebugString(); - }), + absl::StrJoin(*values_, ", ", + [](string* s, const Tensor& elem) { + *s = elem.DebugString(); + }), ")>"); } else { return strings::StrCat("OptionalVariant"); diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index 41ea36263c7..8086253cf73 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -13,9 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/batch_util.h" namespace tensorflow { @@ -29,7 +34,11 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), - op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {} + op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) { + if (ctx->HasAttr("parallel_copy")) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("parallel_copy", ¶llel_copy_)); + } + } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { @@ -93,31 +102,32 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { } *output = - new Dataset(ctx, batch_size, drop_remainder, std::move(padded_shapes), - std::move(padding_values), input); + new Dataset(ctx, batch_size, drop_remainder, parallel_copy_, + std::move(padded_shapes), std::move(padding_values), input); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder, - std::vector padded_shapes, + bool parallel_copy, std::vector padded_shapes, std::vector padding_values, const DatasetBase* input) : DatasetBase(DatasetContext(ctx)), batch_size_(batch_size), drop_remainder_(drop_remainder), + parallel_copy_(parallel_copy), padded_shapes_(std::move(padded_shapes)), padding_values_(std::move(padding_values)), input_(input) { input_->Ref(); - // NOTE(mrry): Currently we implement "batch up to" - // semantics. If we could tell statically that the input dataset - // is infinite, then we could always report `batch_size` as the - // 0th dimension. - // TODO(mrry): Need to validate that the input shape and the - // padded shape are "compatible" (i.e. that padded shape is >= - // input shape, with both static and dynamic checks as appropriate). 
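Elsewhere in this hunk, `OptionalVariant::DebugString` switches from `str_util::Join` to `absl::StrJoin` with a custom formatter. For reference, a small standalone program using the same `absl::StrJoin` formatter overload; the `Item` type is invented for the example:

// Standalone example of absl::StrJoin with a custom formatter, the same
// overload used in OptionalVariant::DebugString above. Link against
// Abseil's strings library.
#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"

struct Item {
  std::string name;
  int value;
};

int main() {
  std::vector<Item> items = {{"a", 1}, {"b", 2}, {"c", 3}};
  // The formatter receives the output string and one element, mirroring the
  // (string* s, const Tensor& elem) lambda in the kernel code.
  std::string joined =
      absl::StrJoin(items, ", ", [](std::string* out, const Item& item) {
        out->append(item.name + "=" + std::to_string(item.value));
      });
  std::cout << joined << std::endl;  // Prints: a=1, b=2, c=3
  return 0;
}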
+ // NOTE(mrry): Currently we implement "batch up to" semantics. If we could + // tell statically that the input dataset is infinite, then we could + // always report `batch_size` as the 0th dimension. + // + // TODO(mrry): Need to validate that the input shape and the padded shape + // are "compatible" (i.e. that padded shape is >= input shape, with both + // static and dynamic checks as appropriate). const auto& input_shapes = input_->output_shapes(); output_shapes_.reserve(input_shapes.size()); for (size_t i = 0; i < input_shapes.size(); ++i) { @@ -193,6 +203,9 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { Node* drop_remainder = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder)); + AttrValue parallel_copy; + b->BuildAttrValue(parallel_copy_, ¶llel_copy); + AttrValue output_types; b->BuildAttrValue(output_dtypes(), &output_types); @@ -202,14 +215,14 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddDataset( this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}}, {{2, padded_shapes}, {3, padding_values}}, - {{"Toutput_types", output_types}, {"N", N}}, output)); + {{"parallel_copy", parallel_copy}, + {"Toutput_types", output_types}, + {"N", N}}, + output)); return Status::OK(); } private: - // Copies element into the index^th slice of parent (in the 0th dimension). - // - class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) @@ -259,13 +272,14 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - // Copy the retrieved batch elements into one output tensor - // per tuple component. - // NOTE(mrry): If the input or output sizes are statically - // known, we could potentially read the input values in-place - // into their respective slice locations. This would require a - // different GetNext() overload that supports zero-copy, and might - // make sense in an optimization pass. + // Copy the retrieved batch elements into one output tensor per tuple + // component. + // + // NOTE(mrry): If the input or output sizes are statically known, we + // could potentially read the input values in-place into their + // respective slice locations. This would require a different GetNext() + // overload that supports zero-copy, and might make sense in an + // optimization pass. const size_t num_tuple_components = batch_elements[0].size(); const int64 num_batch_elements = batch_elements.size(); for (size_t component_index = 0; component_index < num_tuple_components; @@ -330,16 +344,43 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { for (int i = 1; i < batch_component_shape.dims(); ++i) { component_shape.AddDim(batch_component_shape.dim_size(i)); } - for (int64 i = 0; i < num_batch_elements; ++i) { + auto copy_element_fn = [component_index, &batch_elements, + &batch_component, + &component_shape](int index) { // Take the fast path if possible. 
- if (batch_elements[i][component_index].shape() == component_shape) { + if (batch_elements[index][component_index].shape() == + component_shape) { TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice( - batch_elements[i][component_index], &batch_component, i)); + batch_elements[index][component_index], &batch_component, + index)); } else { TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice( - batch_elements[i][component_index], &batch_component, i)); + batch_elements[index][component_index], &batch_component, + index)); + } + return Status::OK(); + }; + BlockingCounter counter(num_batch_elements); + Status status; + mutex status_mu; + for (size_t i = 0; i < num_batch_elements; ++i) { + if (TF_PREDICT_FALSE(dataset()->parallel_copy_)) { + (*ctx->runner())( + [i, &status, &status_mu, &counter, ©_element_fn]() { + Status s = copy_element_fn(i); + { + mutex_lock l(status_mu); + status.Update(s); + } + counter.DecrementCount(); + }); + } else { + status.Update(copy_element_fn(i)); + counter.DecrementCount(); } } + counter.Wait(); + TF_RETURN_IF_ERROR(status); } *end_of_sequence = false; return Status::OK(); @@ -381,6 +422,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { const int64 batch_size_; const bool drop_remainder_; + const bool parallel_copy_; const std::vector padded_shapes_; const std::vector padding_values_; const DatasetBase* const input_; @@ -388,6 +430,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { }; const int op_version_; + bool parallel_copy_ = false; }; REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc new file mode 100644 index 00000000000..89d42d67855 --- /dev/null +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc @@ -0,0 +1,1246 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "padded_batch_datasetv2"; +constexpr char kOpName[] = "PaddedBatchDatasetV2"; + +class PaddedBatchDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `ConcatenateDataset` variant tensor from the input vector of + // tensor vectors. + Status CreateConcatenateDatasetTensor( + const std::vector> &tensor_vectors, + const DataTypeVector &output_types, + const std::vector &output_shapes, + Tensor *concatenate_dataset_tensor) { + // Create two `TensorSliceDataset` tensors as the inputs for + // `ConcatenateDataset`. 
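The new `parallel_copy` code path shown above fans one copy task per batch element out to `ctx->runner()`, folds per-task errors into a shared `Status` under a mutex, and blocks on a `BlockingCounter` until every task has finished before returning the batch. Below is a minimal sketch of that fan-out/fan-in shape in standard C++; the `Latch` and `CopyAll` names are invented, plain threads stand in for the runner, and an error string stands in for `Status`:

#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Hand-rolled stand-in for BlockingCounter: Wait() returns once
// DecrementCount() has been called `count` times.
class Latch {
 public:
  explicit Latch(int count) : count_(count) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

// Runs copy_fn(i) for each element, in parallel when `parallel` is set,
// keeping only the first error message (mirroring Status::Update).
std::string CopyAll(int num_elements, bool parallel,
                    const std::function<std::string(int)>& copy_fn) {
  Latch latch(num_elements);
  std::mutex status_mu;
  std::string first_error;  // Empty string means OK.
  std::vector<std::thread> workers;
  for (int i = 0; i < num_elements; ++i) {
    auto task = [&, i] {
      std::string s = copy_fn(i);
      {
        std::lock_guard<std::mutex> lock(status_mu);
        if (first_error.empty() && !s.empty()) first_error = s;
      }
      latch.DecrementCount();
    };
    if (parallel) {
      workers.emplace_back(task);
    } else {
      task();  // Sequential fallback, as in the !parallel_copy_ branch.
    }
  }
  for (auto& t : workers) t.join();
  latch.Wait();
  return first_error;
}

When `parallel` is false the copies run inline and decrement the counter themselves, which is exactly how the kernel keeps the sequential path free of scheduling overhead.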
+ std::vector tensor_slice_dataset_tensors; + for (int i = 0; i < tensor_vectors.size(); ++i) { + std::vector tensors = tensor_vectors[i]; + DatasetBase *tensor_slice_dataset; + TF_RETURN_IF_ERROR( + CreateTensorSliceDataset(strings::StrCat("tensor_slice_node_", i), + &tensors, &tensor_slice_dataset)); + Tensor dataset_tensor(DT_VARIANT, TensorShape({})); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(tensor_slice_dataset, &dataset_tensor)); + tensor_slice_dataset_tensors.emplace_back(std::move(dataset_tensor)); + } + + // Create a `ConcatenateDataset` dataset. + std::unique_ptr concatenate_dataset_op_kernel; + NodeDef concatenate_node_def = test::function::NDef( + "concatenate_dataset", "ConcatenateDataset", + {"input_dataset", "another_dataset"}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); + TF_RETURN_IF_ERROR( + CreateOpKernel(concatenate_node_def, &concatenate_dataset_op_kernel)); + + gtl::InlinedVector concatenate_dataset_inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + concatenate_dataset_inputs.emplace_back(&tensor); + } + + std::unique_ptr concatenate_dataset_op_context; + TF_RETURN_IF_ERROR(CheckOpKernelInput(*concatenate_dataset_op_kernel, + concatenate_dataset_inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext( + concatenate_dataset_op_kernel.get(), &concatenate_dataset_inputs, + &concatenate_dataset_op_context)); + DatasetBase *concatenate_dataset; + TF_RETURN_IF_ERROR(CreateDataset(concatenate_dataset_op_kernel.get(), + concatenate_dataset_op_context.get(), + &concatenate_dataset)); + + // Store the `ConcatenateDataset` dataset in a tensor. + TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(concatenate_dataset, + concatenate_dataset_tensor)); + return Status::OK(); + } + + // Creates a new `PaddedBatchDataset` op kernel + Status CreatePaddedBatchDatasetKernel( + bool parallel_copy, int n, const DataTypeVector &output_types, + const std::vector &output_shapes, + std::unique_ptr *op_kernel) { + std::vector inputs({"input_dataset", "batch_size"}); + // Create the placeholder names for the input padded_shapes. + for (int i = 0; i < n; ++i) { + inputs.emplace_back(strings::StrCat("padded_shapes_", i)); + } + // Create the placeholder names for the input padding_values. + for (int j = 0; j < output_types.size(); ++j) { + inputs.emplace_back(strings::StrCat("padding_values_", j)); + } + inputs.emplace_back("drop_remainder"); + + NodeDef node_def = test::function::NDef(kNodeName, kOpName, inputs, + {{"parallel_copy", parallel_copy}, + {"Toutput_types", output_types}, + {"output_shapes", output_shapes}, + {"N", n}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); + return Status::OK(); + } + + // Creates a new `PaddedBatchDataset` op kernel context. + Status CreatePaddedBatchDatasetContext( + OpKernel *const op_kernel, + gtl::InlinedVector *const inputs, + std::unique_ptr *context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct TestCase { + // Used for creating two `TensorSliceDataset` datasets, which will be the + // input datasets for `ConcatenateDataset`. Then the `ConcatenateDataset` + // dataset will be the input for `PaddedBatchDataset`. 
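The test cases that follow check how input rows are padded to the requested shape and then grouped into batches. As a plain-C++ illustration of the behavior being asserted (not the kernel itself), padding every row to the requested width and then batching reproduces the expected outputs of TestCase1:

#include <algorithm>
#include <iostream>
#include <vector>

// Pads each row out to `padded_width` with `pad_value`, then groups rows into
// batches of `batch_size`, optionally dropping a short final batch.
std::vector<std::vector<std::vector<int>>> PadAndBatch(
    std::vector<std::vector<int>> rows, int padded_width, int pad_value,
    int batch_size, bool drop_remainder) {
  for (auto& row : rows) {
    row.resize(padded_width, pad_value);  // Assumes no row exceeds padded_width.
  }
  std::vector<std::vector<std::vector<int>>> batches;
  for (size_t start = 0; start + batch_size <= rows.size() ||
                         (!drop_remainder && start < rows.size());
       start += batch_size) {
    size_t end = std::min(rows.size(), start + batch_size);
    batches.emplace_back(rows.begin() + start, rows.begin() + end);
  }
  return batches;
}

int main() {
  // The seven length-2 elements from TestCase1: batch_size = 2, padded shape
  // {3}, padding value 1, drop_remainder = true.
  auto batches = PadAndBatch(
      {{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}, {10, 11}, {12, 13}},
      /*padded_width=*/3, /*pad_value=*/1,
      /*batch_size=*/2, /*drop_remainder=*/true);
  std::cout << batches.size() << " batches\n";  // 3: {12, 13} is dropped.
  return 0;
}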
+ std::vector> input_tensors; + DataTypeVector concatenate_output_dtypes; + std::vector concatenate_output_shapes; + Tensor batch_size; + std::vector padded_shapes; + std::vector padding_values; + Tensor drop_remainder; + bool parallel_copy; + int64 n; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +template +std::vector ConvertToTensorVec(std::vector values) { + std::vector tensors; + tensors.reserve(values.size()); + for (auto &value : values) { + tensors.emplace_back( + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {value})); + } + return tensors; +} + +// Test case 1: input elements with same shapes. +TestCase TestCase1() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor( + TensorShape{4, 2}, {6, 7, 8, 9, 10, 11, 12, 13})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({2})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {true}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {0, 1, 1, 2, 3, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {4, 5, 1, 6, 7, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {8, 9, 1, 10, 11, 1})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({2, 3})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 2, 5}}; +} + +// Test case 2: input elements with different shapes. +TestCase TestCase2() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{4, 1}, + {6, 7, 8, 9})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({-1})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {true}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {0, 1, 1, 2, 3, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {4, 5, 1, 6, 1, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {7, 1, 1, 8, 1, 1})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({2, 3})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 2, 5}}; +} + +// Test case 3: similar with the test case 2 but drop_remainder = false. 
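The `expected_cardinality` values across these cases follow directly from the batch-up-to semantics: with `drop_remainder` the short final batch is discarded, without it the batch is still emitted. A tiny illustration of that arithmetic (names are illustrative only):

#include <cstdint>

// Number of batches PaddedBatchDataset is expected to yield for a finite
// input of `num_elements`, matching the expected_cardinality fields above.
int64_t ExpectedCardinality(int64_t num_elements, int64_t batch_size,
                            bool drop_remainder) {
  return drop_remainder ? num_elements / batch_size
                        : (num_elements + batch_size - 1) / batch_size;
}
// ExpectedCardinality(7, 2, /*drop_remainder=*/true)  == 3  (TestCase1/2)
// ExpectedCardinality(7, 2, /*drop_remainder=*/false) == 4  (TestCase3)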
+TestCase TestCase3() { + return { + /*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{4, 1}, + {6, 7, 8, 9})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({-1})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ false, + /*n*/ 1, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {0, 1, 1, 2, 3, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {4, 5, 1, 6, 1, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {7, 1, 1, 8, 1, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{1, 3}, {9, 1, 1})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, 3})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 2, 5}}; +} + +// Test case 4: similar with the test case 3 but the input elements can be +// divided by the batch size evenly. As drop_remainder = false, the output +// shape is still {-1, 3} instead of {2, 3}. +TestCase TestCase4() { + return { + /*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 1}, {6, 7, 8})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({-1})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ false, + /*n*/ 1, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {0, 1, 1, 2, 3, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {4, 5, 1, 6, 1, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, + {7, 1, 1, 8, 1, 1})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, 3})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 2, 5}}; +} + +// Test case 5: similar with the test case 3 but padded_shapes = {-1}. 
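The `padded_shapes = {-1}` cases below request dynamic padding: each batch is padded only as far as its own longest element, so consecutive batches can have different widths and the declared output shape stays partially unknown. A small helper sketch of that per-batch width computation (illustrative, not the kernel's code):

#include <algorithm>
#include <cstddef>
#include <vector>

// When a padded dimension is -1, each batch is padded to the longest element
// it contains.
size_t DynamicPaddedWidth(const std::vector<std::vector<int>>& batch) {
  size_t width = 0;
  for (const auto& row : batch) {
    width = std::max(width, row.size());
  }
  return width;
}
// For TestCase5, the batches {[0,1],[2,3]}, {[4,5],[6]}, {[7],[8]}, {[9]}
// are padded to widths 2, 2, 1, and 1 respectively.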
+TestCase TestCase5() { + return { + /*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{4, 1}, + {6, 7, 8, 9})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({-1})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {-1})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ false, + /*n*/ 1, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{2, 2}, {0, 1, 2, 3}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 2}, {4, 5, 6, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 1}, {7, 8}), + DatasetOpsTestBase::CreateTensor(TensorShape{1, 1}, {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 2, 5}}; +} + +// Test case 6: similar with the test case 5 but parallel_copy = true. +TestCase TestCase6() { + return { + /*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{4, 1}, + {6, 7, 8, 9})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({-1})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {-1})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{2, 2}, {0, 1, 2, 3}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 2}, {4, 5, 6, 1}), + DatasetOpsTestBase::CreateTensor(TensorShape{2, 1}, {7, 8}), + DatasetOpsTestBase::CreateTensor(TensorShape{1, 1}, {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 2, 5}}; +} + +// Test case 7: empty input elements. 
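The cases that follow cover the degenerate and invalid inputs: an empty input dataset, a padded shape shorter than an element (surfaced as a data-loss error at `GetNext` time), and mismatched batch sizes, shape counts, value counts, dtypes, and value shapes (all rejected as invalid arguments when the dataset is constructed). A hedged sketch of the construction-time consistency checks, with invented names and simplified types:

#include <string>

// Simplified construction-time validation, in the spirit of the
// INVALID_ARGUMENT cases above. Returns an empty string when the arguments
// look consistent.
std::string ValidatePaddedBatchArgs(long long batch_size,
                                    size_t num_components,
                                    size_t num_padded_shapes,
                                    size_t num_padding_values) {
  if (batch_size <= 0) {
    return "Batch size must be greater than zero.";
  }
  if (num_padded_shapes != num_components) {
    return "Number of padded shapes must match the number of components.";
  }
  if (num_padding_values != num_components) {
    return "Number of padding values must match the number of components.";
  }
  return "";  // OK. A padding shorter than an element is only caught later,
              // when GetNext() tries to copy the element into its slice.
}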
+TestCase TestCase7() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{0}, {})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{0}, {})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({-1})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {-1})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase ShortPaddingTestCase() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {6, 7, 8, 9, 10, 11})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({2})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase InvalidPaddingShapesTestCase() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {6, 7, 8, 9, 10, 11})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({2})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{2}, {1, 2})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase InvalidBatchSizeTestCase() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {6, 7, 8, 9, 10, 11})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({2})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {-1}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase InvalidPaddedShapesSizeTestCase() { + return {/*input_tensors*/ + 
{{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {6, 7, 8, 9, 10, 11})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({2})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 2, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase InvalidPaddedValuesSizeTestCase() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {6, 7, 8, 9, 10, 11})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({2})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {1}), + DatasetOpsTestBase::CreateTensor(TensorShape{}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64, DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase InvalidPaddedValuesDTypeTestCase() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {6, 7, 8, 9, 10, 11})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({2})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{}, {"a"})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5}}; +} + +TestCase InvalidPaddedValuesShapeTestCase() { + return {/*input_tensors*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {0, 1, 2, 3, 4, 5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape{3, 2}, + {6, 7, 8, 9, 10, 11})}}, + /*concatenate_output_dtypes*/ {DT_INT64}, + /*concatenate_output_shapes*/ {PartialTensorShape({2})}, + /*batch_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {2}), + /*padded_shapes*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3})}, + /*padding_values*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1})}, + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape{}, {false}), + /*parallel_copy*/ true, + /*n*/ 1, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({-1, -1})}, + /*expected_cardinality*/ 0, + 
/*breakpoints*/ {0, 2, 5}}; +} + +class ParameterizedPaddedBatchDatasetOpTest + : public PaddedBatchDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedPaddedBatchDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(padded_batch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(padded_batch_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +TEST_F(PaddedBatchDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + 
inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + EXPECT_EQ(padded_batch_dataset->node_name(), kNodeName); +} + +TEST_F(PaddedBatchDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + EXPECT_EQ(padded_batch_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedPaddedBatchDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + 
CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(padded_batch_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedPaddedBatchDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(padded_batch_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedPaddedBatchDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + 
CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + EXPECT_EQ(padded_batch_dataset->Cardinality(), + test_case.expected_cardinality); +} + +TEST_P(ParameterizedPaddedBatchDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(padded_batch_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedPaddedBatchDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + 
inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(padded_batch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(padded_batch_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedPaddedBatchDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(padded_batch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(padded_batch_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(PaddedBatchDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = 
test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(padded_batch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(padded_batch_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::PaddedBatch"); +} + +TEST_P(ParameterizedPaddedBatchDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(padded_batch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(padded_batch_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader 
reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *padded_batch_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + cur_iteration++; + } + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +INSTANTIATE_TEST_SUITE_P(PaddedBatchDatasetOpTest, + ParameterizedPaddedBatchDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3(), + TestCase4(), TestCase5(), TestCase6(), + TestCase7()}))); + +TEST_F(PaddedBatchDatasetOpTest, ShortPadding) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + TestCase test_case = ShortPaddingTestCase(); + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(padded_batch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(padded_batch_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + EXPECT_EQ( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence) + .code(), + tensorflow::error::DATA_LOSS); +} + +TEST_F(PaddedBatchDatasetOpTest, InvalidPaddedShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + TestCase test_case = InvalidPaddingShapesTestCase(); + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + 
std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK( + CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(), + &inputs, &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), + &padded_batch_dataset)); + core::ScopedUnref scoped_unref(padded_batch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(padded_batch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(padded_batch_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + EXPECT_EQ( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(PaddedBatchDatasetOpTest, InvalidArguments) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::vector test_cases = { + InvalidBatchSizeTestCase(), InvalidPaddedShapesSizeTestCase(), + InvalidPaddedValuesSizeTestCase(), InvalidPaddedValuesDTypeTestCase(), + InvalidPaddedValuesShapeTestCase()}; + for (const TestCase &test_case : test_cases) { + std::unique_ptr padded_batch_dataset_kernel; + TF_ASSERT_OK(CreatePaddedBatchDatasetKernel( + test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &padded_batch_dataset_kernel)); + + Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK(CreateConcatenateDatasetTensor( + test_case.input_tensors, test_case.concatenate_output_dtypes, + test_case.concatenate_output_shapes, &concatenate_dataset_tensor)); + Tensor batch_size = test_case.batch_size; + std::vector padded_shapes = test_case.padded_shapes; + std::vector padding_values = test_case.padding_values; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&concatenate_dataset_tensor, &batch_size}); + for (auto &padded_shape : padded_shapes) { + inputs.emplace_back(&padded_shape); + } + for (auto &padding_value : padding_values) { + inputs.emplace_back(&padding_value); + } + inputs.emplace_back(&drop_remainder); + + std::unique_ptr padded_batch_dataset_context; + TF_ASSERT_OK(CreatePaddedBatchDatasetContext( + padded_batch_dataset_kernel.get(), &inputs, + &padded_batch_dataset_context)); + DatasetBase *padded_batch_dataset; + EXPECT_EQ( + CreateDataset(padded_batch_dataset_kernel.get(), + padded_batch_dataset_context.get(), &padded_batch_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 4dd5c379c03..835b2387c1e 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -30,6 +30,8 @@ limitations 
under the License. #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/cpu_info.h" namespace tensorflow { @@ -54,7 +56,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "f", /*params=*/{}, + &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("sloppy", &sloppy_)); @@ -74,7 +77,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES(ctx, block_length > 0, errors::InvalidArgument("`block_length` must be > 0")); - int64 num_parallel_calls; + int64 num_parallel_calls = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); OP_REQUIRES( @@ -88,31 +91,28 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr captured_func; OP_REQUIRES_OK( - ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments", + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", &captured_func)); if (num_parallel_calls == model::kAutoTune) { metrics::RecordTFDataAutotune(kDatasetName); } - *output = - new Dataset(ctx, input, interleave_func_, std::move(captured_func), - cycle_length, block_length, num_parallel_calls, sloppy_, - output_types_, output_shapes_); + *output = new Dataset(ctx, input, std::move(captured_func), cycle_length, + block_length, num_parallel_calls, sloppy_, + output_types_, output_shapes_); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, std::unique_ptr captured_func, int64 cycle_length, int64 block_length, int64 num_parallel_calls, bool sloppy, const DataTypeVector& output_types, const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), input_(input), - interleave_func_(func), captured_func_(std::move(captured_func)), cycle_length_(cycle_length), block_length_(block_length), @@ -149,7 +149,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; @@ -159,24 +158,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { Node* num_parallel_calls_node; TF_RETURN_IF_ERROR( b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - 
other_arguments_types.emplace_back(t.dtype()); - } + DataTypeVector other_arguments_types; + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); AttrValue f; - b->BuildAttrValue(interleave_func_, &f); + b->BuildAttrValue(captured_func_->func(), &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); AttrValue sloppy_attr; @@ -220,11 +207,20 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { cancelled_ = true; cond_var_->notify_all(); // Wait for all in-flight calls to complete. - while (num_calls_ > 0) { + while (current_num_calls_ > 0 || future_num_calls_ > 0) { cond_var_->wait(l); } } + string BuildTraceMeName() override { + int64 parallelism; + { + tf_shared_lock l(*mu_); + parallelism = num_parallel_calls_->value; + } + return strings::StrCat(prefix(), "#parallelism=", parallelism, "#"); + } + Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (num_parallel_calls_->value == model::kAutoTune) { @@ -267,16 +263,17 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { return model::MakeAsyncInterleaveManyNode( std::move(args), {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, - /*max=*/port::NumSchedulableCPUs())}); + /*max=*/dataset()->cycle_length_)}); } Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(*mu_); // Wait for all in-flight calls to complete. - while (num_calls_ > 0) { + while (current_num_calls_ > 0 || future_num_calls_ > 0) { cond_var_->wait(l); } - DCHECK_EQ(num_calls_, 0); + DCHECK_EQ(current_num_calls_, 0); + DCHECK_EQ(future_num_calls_, 0); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("block_index"), block_index_)); @@ -453,7 +450,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } } } - return all_elements_busy || num_calls_ >= num_parallel_calls_->value; + return all_elements_busy || + current_num_calls_ >= num_parallel_calls_->value; }; while (true) { mutex_lock l(*mu_); @@ -477,6 +475,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (!future_elements_.empty()) { current_elements_[idx] = std::move(future_elements_.back()); future_elements_.pop_back(); + if (current_elements_[idx]->iterator) { + EnableAutotune(ctx.get(), + current_elements_[idx]->iterator.get()); + } } else { current_elements_[idx] = MakeElement(ctx); if (!current_elements_[idx]) { @@ -493,11 +495,23 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { dataset()->block_length_ - element->results.size(); } if (num_results > 0) { - num_calls_++; + current_num_calls_++; element->in_use = true; - thread_pool_->Schedule( - std::bind(&ParallelInterleaveIterator::FetchResults, this, - ctx, std::move(element), num_results)); + thread_pool_->Schedule(std::bind( + &ParallelInterleaveIterator::FetchResults, this, ctx, + std::move(element), num_results, + [this, ctx]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + --current_num_calls_; + const auto& stats_aggregator = ctx->stats_aggregator(); + if (stats_aggregator) { + stats_aggregator->AddScalar( + stats_utils::ThreadUtilizationScalarName( + dataset()->node_name()), + static_cast(current_num_calls_) / + static_cast(num_parallel_calls_->value), + num_elements()); + } + })); } } } @@ -506,8 +520,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { stats_aggregator->AddScalar( stats_utils::ThreadUtilizationScalarName( 
dataset()->node_name()), - static_cast(num_calls_) / - static_cast(num_parallel_calls_->value)); + static_cast(current_num_calls_) / + static_cast(num_parallel_calls_->value), + num_elements()); } cond_var_->notify_all(); } @@ -532,7 +547,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Fetches up to `dataset()->block_length_` results from `element`. void FetchResults(const std::shared_ptr& ctx, const std::shared_ptr& element, - int64 num_results) LOCKS_EXCLUDED(*mu_) { + int64 num_results, std::function done) + LOCKS_EXCLUDED(*mu_) { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); bool end_of_input = false; @@ -560,14 +576,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { element->inputs.clear(); --num_open_; } - --num_calls_; - const auto& stats_aggregator = ctx->stats_aggregator(); - if (stats_aggregator) { - stats_aggregator->AddScalar( - stats_utils::ThreadUtilizationScalarName(dataset()->node_name()), - static_cast(num_calls_) / - static_cast(num_parallel_calls_->value)); - } + done(); cond_var_->notify_all(); } @@ -579,9 +588,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { - // TODO(jsimsa): Autotune the buffer size. - return num_calls_ >= num_parallel_calls_->value || - future_elements_.size() >= 2 * dataset()->cycle_length_; + // TODO(jsimsa): Autotune the number of iterators to prefetch. + return future_elements_.size() >= 2 * dataset()->cycle_length_; }; while (true) { mutex_lock l(*mu_); @@ -608,19 +616,14 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (!element->iterator) { continue; } - ++num_calls_; + DisableAutotune(ctx.get(), element->iterator.get()); + ++future_num_calls_; element->in_use = true; - thread_pool_->Schedule( - std::bind(&ParallelInterleaveIterator::FetchResults, this, ctx, - std::move(element), dataset()->block_length_)); - } - const auto& stats_aggregator = ctx->stats_aggregator(); - if (stats_aggregator) { - stats_aggregator->AddScalar( - stats_utils::ThreadUtilizationScalarName( - dataset()->node_name()), - static_cast(num_calls_) / - static_cast(num_parallel_calls_->value)); + thread_pool_->Schedule(std::bind( + &ParallelInterleaveIterator::FetchResults, this, ctx, + std::move(element), dataset()->block_length_, + [this]() + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { --future_num_calls_; })); } cond_var_->notify_all(); } @@ -905,8 +908,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Identifies the number of open iterators. int64 num_open_ GUARDED_BY(*mu_) = 0; - // Identifies the number of outstanding calls. - int64 num_calls_ GUARDED_BY(*mu_) = 0; + // Identifies the number of outstanding calls for CurrentElementsManager. + int64 current_num_calls_ GUARDED_BY(*mu_) = 0; + // Identifies the number of outstanding calls for FutureElementsManager. 
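// [Editorial sketch; illustrative only, not part of this patch.] The change
// above replaces the single num_calls_ counter with one counter per element
// manager, so that teardown and SaveInternal can wait until both the
// current-cycle workers and the prefetch ("future") workers have drained.
// A minimal standalone version of that pattern, using only the standard
// library and invented names:

#include <condition_variable>
#include <cstdint>
#include <mutex>

class CallCounters {
 public:
  void StartCurrent() { std::lock_guard<std::mutex> l(mu_); ++current_; }
  void StartFuture() { std::lock_guard<std::mutex> l(mu_); ++future_; }
  void FinishCurrent() { Finish(&current_); }
  void FinishFuture() { Finish(&future_); }

  // Blocks until no call of either kind is in flight, mirroring the
  // `while (current_num_calls_ > 0 || future_num_calls_ > 0)` waits above.
  void AwaitQuiescent() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return current_ == 0 && future_ == 0; });
  }

 private:
  void Finish(int64_t* counter) {
    std::lock_guard<std::mutex> l(mu_);
    --*counter;
    cv_.notify_all();
  }

  std::mutex mu_;
  std::condition_variable cv_;
  int64_t current_ = 0;  // analogous to current_num_calls_
  int64_t future_ = 0;   // analogous to future_num_calls_
};

// The member declared on the next line is the second of those two counters.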
+ int64 future_num_calls_ GUARDED_BY(*mu_) = 0; std::unique_ptr thread_pool_; std::unique_ptr current_elements_manager_ GUARDED_BY(*mu_); @@ -919,7 +924,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; - const NameAttrList interleave_func_; const std::unique_ptr captured_func_; const int64 cycle_length_; const int64 block_length_; @@ -929,10 +933,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const std::vector output_shapes_; }; - bool sloppy_; + std::shared_ptr func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_; - NameAttrList interleave_func_; + bool sloppy_; }; REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc new file mode 100644 index 00000000000..6f30cce3fe1 --- /dev/null +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc @@ -0,0 +1,964 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "parallel_interleave_dataset"; +constexpr char kOpName[] = "ParallelInterleaveDatasetV2"; + +class ParallelInterleaveDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `TensorSliceDataset` variant tensor from the input vector of + // tensors. + Status CreateTensorSliceDatasetTensor( + std::vector *const tensor_vector, Tensor *dataset_tensor) { + DatasetBase *tensor_slice_dataset; + TF_RETURN_IF_ERROR(CreateTensorSliceDataset( + "tensor_slice_node", tensor_vector, &tensor_slice_dataset)); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor)); + return Status::OK(); + } + + // Creates a new `ParallelInterleaveDataset` op kernel + Status CreateParallelInterleaveDatasetKernel( + const FunctionDefHelper::AttrValueWrapper &func, + const DataTypeVector &output_types, + const std::vector &output_shapes, bool sloppy, + std::unique_ptr *op_kernel) { + NodeDef node_def = test::function::NDef( + kNodeName, kOpName, + {"input_dataset", "cycle_length", "block_length", "num_parallel_calls"}, + {{"f", func}, + {"Targuments", {}}, + {"output_types", output_types}, + {"output_shapes", output_shapes}, + {"sloppy", sloppy}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); + return Status::OK(); + } + + // Creates a new `ParallelInterleaveDataset` op kernel context. 
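// [Editorial sketch; illustrative only, not part of this patch.] The
// expected_outputs in the deterministic (sloppy = false) test cases below
// follow the interleave contract: `cycle_length` input elements are open at
// once and up to `block_length` items are taken from each in round-robin
// order. A small standalone reference for that ordering (invented names,
// standard library only):

#include <deque>
#include <vector>

std::vector<int> InterleaveOrder(const std::vector<std::vector<int>>& inputs,
                                 int cycle_length, int block_length) {
  std::vector<int> out;
  std::deque<std::deque<int>> pending(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    pending[i].assign(inputs[i].begin(), inputs[i].end());
  }
  std::deque<std::deque<int>> cycle;
  while (!pending.empty() || !cycle.empty()) {
    // Refill the cycle with up to `cycle_length` open inputs.
    while (cycle.size() < static_cast<size_t>(cycle_length) &&
           !pending.empty()) {
      cycle.push_back(std::move(pending.front()));
      pending.pop_front();
    }
    // Take up to `block_length` items from each open input in turn.
    for (size_t i = 0; i < cycle.size();) {
      for (int b = 0; b < block_length && !cycle[i].empty(); ++b) {
        out.push_back(cycle[i].front());
        cycle[i].pop_front();
      }
      if (cycle[i].empty()) {
        cycle.erase(cycle.begin() + i);  // exhausted inputs leave the cycle
      } else {
        ++i;
      }
    }
  }
  return out;
}

// For example, InterleaveOrder({{0, 1, 2}, {3, 4, 5}, {6, 7, 8}},
// /*cycle_length=*/2, /*block_length=*/1) yields {0, 3, 1, 4, 2, 5, 6, 7, 8},
// matching TestCase2 below. The fixture helper that follows builds the op
// kernel context those tests run against.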
+ Status CreateInterleaveDatasetContext( + OpKernel *const op_kernel, + gtl::InlinedVector *const inputs, + std::unique_ptr *context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct TestCase { + std::vector input_tensors; + FunctionDefHelper::AttrValueWrapper func; + std::vector func_lib; + Tensor cycle_length; + Tensor block_length; + Tensor num_parallel_calls; + bool sloppy; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +template +std::vector ConvertToTensorVec(std::vector values) { + std::vector tensors; + tensors.reserve(values.size()); + for (auto &value : values) { + tensors.emplace_back( + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {value})); + } + return tensors; +} + +FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc( + const DataTypeVector &output_types, + const std::vector &output_shapes) { + return FunctionDefHelper::FunctionRef( + /*name*/ "MakeTensorSliceDataset", + /*attrs*/ {{"Toutput_types", output_types}, + {"output_shapes", output_shapes}}); +} + +// test case 1: cycle_length = 1, block_length = 1, num_parallel_calls = 1, +// sloppy = false +TestCase TestCase1() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*sloppy*/ false, + /*expected_outputs*/ + ConvertToTensorVec({0, 1, 2, 3, 4, 5, 6, 7, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 2: cycle_length = 2, block_length = 1, num_parallel_calls = 2, +// sloppy = false +TestCase TestCase2() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*sloppy*/ false, + /*expected_outputs*/ + ConvertToTensorVec({0, 3, 1, 4, 2, 5, 6, 7, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 3: cycle_length = 3, block_length = 1, num_parallel_calls = 2, +// sloppy = true +TestCase TestCase3() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ 
{test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({0, 3, 6, 1, 4, 7, 2, 5, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 4: cycle_length = 5, block_length = 1, num_parallel_calls = 4, +// sloppy = true +TestCase TestCase4() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({0, 3, 6, 1, 4, 7, 2, 5, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 5: cycle_length = 2, block_length = 2, num_parallel_calls = 1, +// sloppy = false +TestCase TestCase5() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*sloppy*/ false, + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "d", "e", "c", "f", "g", "h", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 6: cycle_length = 2, block_length = 3, num_parallel_calls = 2, +// sloppy = true +TestCase TestCase6() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "c", "d", "e", "f", "g", "h", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 7: cycle_length = 3, block_length = 2, num_parallel_calls = 
2, +// sloppy = false +TestCase TestCase7() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*sloppy*/ false, + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "d", "e", "g", "h", "c", "f", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 8: cycle_length = 3, block_length = 3, num_parallel_calls = 3, +// sloppy = true +TestCase TestCase8() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "c", "d", "e", "f", "g", "h", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 9: cycle_length = 4, block_length = 4, num_parallel_calls = 4, +// sloppy = true +TestCase TestCase9() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "c", "d", "e", "f", "g", "h", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 10: cycle_length = 3, block_length = 3, +// num_parallel_calls = kAutoTune, sloppy = true +TestCase TestCase10() { + return { + /*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {"a", "b", "c", "d", "e", "f", "g", "h", "i"})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_STRING}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), + 
{model::kAutoTune}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({"a", "b", "c", "d", "e", "f", "g", "h", "i"}), + /*expected_output_dtypes*/ {DT_STRING}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {0, 4, 11}}; +} + +// test case 11: cycle_length = 0, block_length = 1, num_parallel_calls = 2, +// sloppy = true +TestCase InvalidCycleLengthTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {}}; +} + +// test case 12: cycle_length = 1, block_length = -1, num_parallel_calls = 2, +// sloppy = true +TestCase InvalidBlockLengthTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {-1}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {}}; +} + +// test case 13: cycle_length = 1, block_length = 1, num_parallel_calls = -5, +// sloppy = true +TestCase InvalidNumParallelCallsTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{3, 3, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8})}, + /*func*/ + MakeTensorSliceDatasetFunc( + DataTypeVector({DT_INT64}), + std::vector({PartialTensorShape({1})})), + /*func_lib*/ {test::function::MakeTensorSliceDataset()}, + /*cycle_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*block_length*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {-5}), + /*sloppy*/ true, + /*expected_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ tensorflow::data::kUnknownCardinality, + /*breakpoints*/ {}}; +} + +class ParameterizedParallelInterleaveDatasetOpTest + : public ParallelInterleaveDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedParallelInterleaveDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr 
parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(), + &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ !test_case.sloppy)); +} + +TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + + std::vector test_cases({InvalidCycleLengthTestCase(), + InvalidBlockLengthTestCase(), + InvalidNumParallelCallsTestCase()}); + for (const auto &test_case : test_cases) { + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = + test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor( + &inputs_for_tensor_slice_dataset, &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + EXPECT_EQ(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); + } +} + +TEST_F(ParallelInterleaveDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + 
const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + EXPECT_EQ(parallel_interleave_dataset->node_name(), kNodeName); +} + +TEST_F(ParallelInterleaveDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + EXPECT_EQ(parallel_interleave_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor 
tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(parallel_interleave_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + TF_EXPECT_OK( + VerifyShapesCompatible(parallel_interleave_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor 
block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + EXPECT_EQ(parallel_interleave_dataset->Cardinality(), + test_case.expected_cardinality); +} + +TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK( + parallel_interleave_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + 
&cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(), + &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(), + &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(ParallelInterleaveDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = 
test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(), + &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::ParallelInterleaveV2"); +} + +TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + const TestCase &test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_interleave_dataset_kernel; + TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.sloppy, + ¶llel_interleave_dataset_kernel)); + + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor cycle_length = test_case.cycle_length; + Tensor block_length = test_case.block_length; + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector inputs({&tensor_slice_dataset_tensor, + &cycle_length, &block_length, + &num_parallel_calls}); + std::unique_ptr parallel_interleave_dataset_context; + TF_ASSERT_OK(CreateInterleaveDatasetContext( + parallel_interleave_dataset_kernel.get(), &inputs, + ¶llel_interleave_dataset_context)); + DatasetBase *parallel_interleave_dataset; + TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(), + parallel_interleave_dataset_context.get(), + ¶llel_interleave_dataset)); + core::ScopedUnref scoped_unref(parallel_interleave_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(parallel_interleave_dataset_context.get(), + &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_interleave_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *parallel_interleave_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + 
out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + cur_iteration++; + } + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ !test_case.sloppy)); +} + +INSTANTIATE_TEST_SUITE_P( + ParallelInterleaveDatasetOpTest, + ParameterizedParallelInterleaveDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3(), TestCase4(), TestCase5(), + TestCase6(), TestCase7(), TestCase8(), TestCase9(), TestCase10()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 34f341d1d12..e9a648b8e9e 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -38,11 +38,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelMapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + FunctionMetadata::Params params; + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + ¶ms.use_inter_op_parallelism)); + OP_REQUIRES_OK(ctx, + FunctionMetadata::Create(ctx, "f", params, &func_metadata_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", - &use_inter_op_parallelism_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("sloppy", &sloppy_)); OP_REQUIRES_OK( ctx, ctx->GetAttr("preserve_cardinality", &preserve_cardinality_)); @@ -60,46 +62,35 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { "num_parallel_calls must be greater than zero.")); std::unique_ptr captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", - use_inter_op_parallelism_, - &captured_func)); - - std::vector indices; - OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(ctx, func_metadata_, "other_arguments", + &captured_func)); if (num_parallel_calls == model::kAutoTune) { metrics::RecordTFDataAutotune(kDatasetName); } - *output = - new Dataset(ctx, input, func_, num_parallel_calls, output_types_, - output_shapes_, use_inter_op_parallelism_, sloppy_, - std::move(captured_func), indices, preserve_cardinality_); + *output = new Dataset(ctx, input, num_parallel_calls, output_types_, + output_shapes_, sloppy_, std::move(captured_func), + preserve_cardinality_); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, int32 num_parallel_calls, - const DataTypeVector& output_types, - const std::vector& output_shapes, - bool use_inter_op_parallelism, bool sloppy, + int32 num_parallel_calls, const DataTypeVector& output_types, + const std::vector& output_shapes, bool sloppy, std::unique_ptr captured_func, - const std::vector indices, bool preserve_cardinality) + bool preserve_cardinality) : DatasetBase(DatasetContext(ctx)), input_(input), - func_(func), num_parallel_calls_(num_parallel_calls), output_types_(output_types), output_shapes_(output_shapes), - use_inter_op_parallelism_(use_inter_op_parallelism), sloppy_(sloppy), preserve_cardinality_(preserve_cardinality), - captured_func_(std::move(captured_func)), - indices_(indices), - can_move_(indices.empty() ? 
std::vector() - : ComputeMoveVector(indices)) { + captured_func_(std::move(captured_func)) { input_->Ref(); } @@ -107,13 +98,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - std::unique_ptr parallel_map_functor(nullptr); - if (indices_.empty()) { - parallel_map_functor = - absl::make_unique(this); - } else { - parallel_map_functor = absl::make_unique(this); - } + std::unique_ptr parallel_map_functor = + absl::make_unique(this); return NewParallelMapIterator( {this, strings::StrCat(prefix, "::", kDatasetName)}, input_, std::move(parallel_map_functor), num_parallel_calls_, sloppy_, @@ -143,22 +129,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); // Input: other_arguments - DataTypeVector other_arguments_types; - other_arguments_types.reserve(captured_func_->captured_inputs().size()); std::vector other_arguments; - other_arguments.reserve(captured_func_->captured_inputs().size()); - for (const Tensor& t : captured_func_->captured_inputs()) { - Node* node; - DatasetBase* input; - Status s = GetDatasetFromVariantTensor(t, &input); - if (s.ok()) { - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node)); - } else { - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - } - other_arguments.emplace_back(node); - other_arguments_types.emplace_back(t.dtype()); - } + DataTypeVector other_arguments_types; + TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, + &other_arguments_types)); // Input: num_parallel_calls Node* num_parallel_calls = nullptr; @@ -166,9 +140,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { b->AddScalar(num_parallel_calls_, &num_parallel_calls)); // Attr: f - TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); AttrValue f_attr; - b->BuildAttrValue(func_, &f_attr); + b->BuildAttrValue(captured_func_->func(), &f_attr); // Attr: Targuments AttrValue other_arguments_types_attr; @@ -176,7 +149,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { // Attr: use_inter_op_parallelism AttrValue use_inter_op_parallelism_attr; - b->BuildAttrValue(use_inter_op_parallelism_, + b->BuildAttrValue(captured_func_->use_inter_op_parallelism(), &use_inter_op_parallelism_attr); // Attr: sloppy @@ -204,36 +177,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { } private: - class ShortCircuitFunctor : public ParallelMapFunctor { - public: - explicit ShortCircuitFunctor(const Dataset* dataset) - : dataset_(dataset) {} - - void MapFunc(IteratorContext* ctx, const string& prefix, - std::vector input_element, - std::vector* result, StatusCallback done) override { - const std::vector& captured_inputs = - dataset_->captured_func_->captured_inputs(); - size_t num_args = input_element.size(); - for (size_t i = 0; i < dataset_->indices_.size(); ++i) { - if (dataset_->indices_[i] < num_args) { - if (dataset_->can_move_[i]) { - result->push_back( - std::move(input_element[dataset_->indices_[i]])); - } else { - result->push_back(input_element[dataset_->indices_[i]]); - } - } else { - result->push_back( - captured_inputs[dataset_->indices_[i] - num_args]); - } - } - done(Status::OK()); - } - - const Dataset* const dataset_; - }; - class ParallelMapDatasetFunctor : public ParallelMapFunctor { public: explicit ParallelMapDatasetFunctor(const Dataset* dataset) @@ -254,7 +197,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { instantiated_captured_func_->RunAsync( ctx, 
std::move(input_element), result, std::move(done), prefix); }; - if (!dataset_->use_inter_op_parallelism_) { + if (!dataset_->captured_func_->use_inter_op_parallelism()) { (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(input_element), result, std::move(done))); @@ -270,24 +213,19 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; - const NameAttrList func_; const int32 num_parallel_calls_; const DataTypeVector output_types_; const std::vector output_shapes_; - const bool use_inter_op_parallelism_; const bool sloppy_; const bool preserve_cardinality_; const std::unique_ptr captured_func_; - const std::vector indices_; - const std::vector can_move_; }; + std::shared_ptr func_metadata_ = nullptr; DataTypeVector output_types_; std::vector output_shapes_; - bool use_inter_op_parallelism_; bool sloppy_; bool preserve_cardinality_; - NameAttrList func_; }; REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc new file mode 100644 index 00000000000..abb6e81aff6 --- /dev/null +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc @@ -0,0 +1,819 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "parallel_map_dataset"; +constexpr char kOpName[] = "ParallelMapDataset"; + +class ParallelMapDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates a new `ParallelMapDataset` op kernel + Status CreateParallelMapDatasetOpKernel( + const FunctionDefHelper::AttrValueWrapper& func, + const DataTypeVector& output_types, + const std::vector& output_shapes, + bool use_inter_op_parallelism, bool sloppy, bool preserve_cardinality, + std::unique_ptr* parallel_map_kernel) { + NodeDef node_def = test::function::NDef( + kNodeName, kOpName, {"input_dataset", "num_parallel_calls"}, + {{"f", func}, + {"Targuments", {}}, + {"output_types", output_types}, + {"output_shapes", output_shapes}, + {"use_inter_op_parallelism", use_inter_op_parallelism}, + {"sloppy", sloppy}, + {"preserve_cardinality", preserve_cardinality}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, parallel_map_kernel)); + return Status::OK(); + } + + // Creates a new `ParallelMapDataset` op kernel context. 
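// [Editorial sketch; illustrative only, not part of this patch.] The test
// cases further below feed range(0, 10, 3) = {0, 3, 6, 9} through the
// XTimesTwo or XTimesFour map functions, so the expected_outputs are just the
// element-wise results {0, 6, 12, 18} and {0, 12, 24, 36}. A tiny standalone
// reference for that computation (invented names, not TensorFlow API):

#include <cstdint>
#include <vector>

// Applies `fn` to every element of the half-open range [start, end) taken
// with the given step, mirroring a RangeDataset followed by a map.
std::vector<int64_t> RangeThenMap(int64_t start, int64_t end, int64_t step,
                                  int64_t (*fn)(int64_t)) {
  std::vector<int64_t> out;
  for (int64_t i = start; i < end; i += step) {
    out.push_back(fn(i));
  }
  return out;
}

// RangeThenMap(0, 10, 3, [](int64_t x) { return 2 * x; }) -> {0, 6, 12, 18}
// RangeThenMap(0, 10, 3, [](int64_t x) { return 4 * x; }) -> {0, 12, 24, 36}
// The fixture helper that follows builds the OpKernelContext those tests use.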
+ Status CreateParallelMapDatasetContext( + OpKernel* const op_kernel, + gtl::InlinedVector* const inputs, + std::unique_ptr* context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct RangeDatasetParam { + int64 start; + int64 end; + int64 step; +}; + +struct TestCase { + RangeDatasetParam range_data_param; + Tensor num_parallel_calls; + FunctionDefHelper::AttrValueWrapper func; + std::vector func_lib; + bool use_inter_op_parallelism; + bool sloppy; + bool preserve_cardinality; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +FunctionDefHelper::AttrValueWrapper MapFunc(const string& func_name, + const DataType& dtype) { + return FunctionDefHelper::FunctionRef(func_name, {{"T", dtype}}); +} + +// test case 1: num_parallel_calls = 1, use_inter_op_parallelism = false, +// sloppy = false, preserve_cardinality = false, MapFunc = XTimesTwo +TestCase TestCase1() { + return {/*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*func*/ MapFunc("XTimesTwo", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo()}, + /*use_inter_op_parallelism*/ false, + /*sloppy*/ false, + /*preserve_cardinality*/ false, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {18})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 9}}; +} + +// test case 2: num_parallel_calls = 2, use_inter_op_parallelism = true, +// sloppy = true, preserve_cardinality = true, MapFunc = XTimesTwo +TestCase TestCase2() { + return {/*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*func*/ MapFunc("XTimesTwo", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo()}, + /*use_inter_op_parallelism*/ true, + /*sloppy*/ true, + /*preserve_cardinality*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {18})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// test case 3: num_parallel_calls = 3, use_inter_op_parallelism = true, +// sloppy = false, preserve_cardinality = false, MapFunc = XTimesFour +TestCase TestCase3() { + return { + /*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*func*/ MapFunc("XTimesFour", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()}, + /*use_inter_op_parallelism*/ true, + /*sloppy*/ false, + /*preserve_cardinality*/ false, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {24}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {36})}, + /*expected_output_dtypes*/ {DT_INT64}, + 
/*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// test case 4: num_parallel_calls = 4, use_inter_op_parallelism = false, +// sloppy = false, preserve_cardinality = false, MapFunc = XTimesTwo +TestCase TestCase4() { + return {/*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*func*/ MapFunc("XTimesTwo", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo()}, + /*use_inter_op_parallelism*/ false, + /*sloppy*/ false, + /*preserve_cardinality*/ false, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {18})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// test case 5: num_parallel_calls = kAutoTune, use_inter_op_parallelism = true, +// sloppy = true, preserve_cardinality = true, MapFunc = XTimesFour +TestCase TestCase5() { + return { + /*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), + {model::kAutoTune}), + /*func*/ MapFunc("XTimesFour", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()}, + /*use_inter_op_parallelism*/ true, + /*sloppy*/ true, + /*preserve_cardinality*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {24}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {36})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// test case 6: num_parallel_calls = 4, use_inter_op_parallelism = true, +// sloppy = false, preserve_cardinality = false, MapFunc = XTimesFour +TestCase TestCase6() { + return { + /*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*func*/ MapFunc("XTimesFour", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()}, + /*use_inter_op_parallelism*/ true, + /*sloppy*/ false, + /*preserve_cardinality*/ false, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {24}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {36})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// TODO(feihugis): make this test case work. 
+// test case 7: num_parallel_calls = 2, use_inter_op_parallelism = false, +// sloppy = false, preserve_cardinality = false, MapFunc = XTimesFour +TestCase TestCase7() { + return { + /*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*func*/ MapFunc("XTimesFour", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()}, + /*use_inter_op_parallelism*/ false, + /*sloppy*/ false, + /*preserve_cardinality*/ false, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {24}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {36})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +// TODO(feihugis): make this test case work. +// test case 8: num_parallel_calls = kAutoTune, use_inter_op_parallelism = +// false, sloppy = true, preserve_cardinality = true, MapFunc = XTimesFour +TestCase TestCase8() { + return { + /*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), + {model::kAutoTune}), + /*func*/ MapFunc("XTimesFour", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo(), test::function::XTimesFour()}, + /*use_inter_op_parallelism*/ false, + /*sloppy*/ true, + /*preserve_cardinality*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {12}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {24}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {36})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 5}}; +} + +TestCase InvalidNumParallelCallsTestCase() { + return {/*range_data_param*/ {0, 10, 3}, + /*num_parallel_calls*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {-4}), + /*func*/ MapFunc("XTimesTwo", DT_INT64), + /*func_lib*/ {test::function::XTimesTwo()}, + /*use_inter_op_parallelism*/ true, + /*sloppy*/ true, + /*preserve_cardinality*/ true, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ -1, + /*breakpoints*/ {0, 1, 5}}; +} + +class ParameterizedParallelMapDatasetOpTest + : public ParallelMapDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedParallelMapDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + 
gtl::InlinedVector<TensorValue, 4> parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr<OpKernelContext> parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), &parallel_map_dataset_inputs, + &parallel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + &parallel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + std::unique_ptr<IteratorContext> iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(parallel_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr<IteratorBase> iterator; + TF_ASSERT_OK(parallel_map_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + bool end_of_sequence = false; + std::vector<Tensor> out_tensors; + while (!end_of_sequence) { + std::vector<Tensor> next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ !test_case.sloppy)); +} + +TEST_F(ParallelMapDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr<OpKernel> parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + &parallel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector<TensorValue, 4> parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr<OpKernelContext> parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), &parallel_map_dataset_inputs, + &parallel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + &parallel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + EXPECT_EQ(parallel_map_dataset->node_name(), kNodeName); +} + +TEST_F(ParallelMapDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr<OpKernel> parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + &parallel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset,
&range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + EXPECT_EQ(parallel_map_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedParallelMapDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(parallel_map_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedParallelMapDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* 
parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(parallel_map_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedParallelMapDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + EXPECT_EQ(parallel_map_dataset->Cardinality(), + test_case.expected_cardinality); +} + +TEST_P(ParameterizedParallelMapDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + std::unique_ptr serialization_context; + TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); + VariantTensorData data; + VariantTensorDataWriter 
writer(&data); + TF_ASSERT_OK( + parallel_map_dataset->Save(serialization_context.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedParallelMapDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(parallel_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_map_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedParallelMapDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + 
CreateIteratorContext(parallel_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_map_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(ParallelMapDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(parallel_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_map_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::ParallelMap"); +} + +TEST_P(ParameterizedParallelMapDatasetOpTest, Roundtrip) { + int thread_num = 3, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), + ¶llel_map_dataset)); + core::ScopedUnref 
scoped_unref_map_dataset(parallel_map_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(parallel_map_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(parallel_map_dataset->MakeIterator(iterator_ctx.get(), + "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *parallel_map_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + cur_iteration++; + } + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ !test_case.sloppy)); +} + +TEST_F(ParallelMapDatasetOpTest, InvalidNumParallelCalls) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = InvalidNumParallelCallsTestCase(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num)); + + std::unique_ptr parallel_map_dataset_kernel; + TF_ASSERT_OK(CreateParallelMapDatasetOpKernel( + test_case.func, test_case.expected_output_dtypes, + test_case.expected_output_shapes, test_case.use_inter_op_parallelism, + test_case.sloppy, test_case.preserve_cardinality, + ¶llel_map_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor num_parallel_calls = test_case.num_parallel_calls; + gtl::InlinedVector parallel_map_dataset_inputs( + {&range_dataset_tensor, &num_parallel_calls}); + + std::unique_ptr parallel_map_dataset_context; + TF_ASSERT_OK(CreateParallelMapDatasetContext( + parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs, + ¶llel_map_dataset_context)); + DatasetBase* parallel_map_dataset; + EXPECT_EQ( + CreateDataset(parallel_map_dataset_kernel.get(), + parallel_map_dataset_context.get(), ¶llel_map_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +INSTANTIATE_TEST_SUITE_P(ParallelMapDatasetOpTest, + ParameterizedParallelMapDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3(), + TestCase4(), TestCase5(), TestCase6()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 3b0d6d7a449..52befecb12e 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/kernels/data/stats_utils.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/cpu_info.h" namespace tensorflow { @@ -72,6 +73,15 @@ class ParallelMapIterator : public DatasetBaseIterator { } } + string BuildTraceMeName() override { + int64 parallelism; + { + tf_shared_lock l(*mu_); + parallelism = num_parallel_calls_->value; + } + return strings::StrCat(prefix(), "#parallelism=", parallelism, "#"); + } + Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (num_parallel_calls_->value == model::kAutoTune) { @@ -148,6 +158,7 @@ class ParallelMapIterator : public DatasetBaseIterator { int64 invocation_results_size; TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("invocation_results.size"), &invocation_results_size)); + if (!invocation_results_.empty()) invocation_results_.clear(); for (size_t i = 0; i < invocation_results_size; i++) { invocation_results_.push_back(std::make_shared()); auto& result = *invocation_results_.back(); @@ -207,7 +218,8 @@ class ParallelMapIterator : public DatasetBaseIterator { stats_aggregator->AddScalar( stats_utils::ThreadUtilizationScalarName(key_prefix_), static_cast(num_calls_) / - static_cast(num_parallel_calls_->value)); + static_cast(num_parallel_calls_->value), + num_elements()); } RecordBufferEnqueue(ctx.get(), result->return_values); result->notification.Notify(); @@ -302,7 +314,8 @@ class ParallelMapIterator : public DatasetBaseIterator { stats_aggregator->AddScalar( stats_utils::ThreadUtilizationScalarName(key_prefix_), static_cast(num_calls_) / - static_cast(num_parallel_calls_->value)); + static_cast(num_parallel_calls_->value), + num_elements()); } cond_var_->notify_all(); } diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 9773b492905..e356044492c 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { namespace data { @@ -31,14 +32,18 @@ namespace data { // See documentation in ../../ops/dataset_ops.cc for a high-level // description of the following op. +// Determines the fraction of slack time by which to delay prefetching of data. 
+constexpr double kSleepFactor = 0.2; constexpr char kDatasetName[] = "Prefetch"; class PrefetchDatasetOp::Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size) + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size, + int64 slack_period) : DatasetBase(DatasetContext(ctx)), input_(input), - buffer_size_(buffer_size) { + buffer_size_(buffer_size), + slack_period_(slack_period) { input_->Ref(); } @@ -70,8 +75,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* buffer_size = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, buffer_size}, output)); + AttrValue slack_period_attr; + b->BuildAttrValue(slack_period_, &slack_period_attr); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, buffer_size}, + {std::make_pair("slack_period", slack_period_attr)}, output)); return Status::OK(); } @@ -81,6 +89,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params), auto_tuner_(params.dataset->buffer_size_) { + slack_us_ = 0; } ~Iterator() override { @@ -99,6 +108,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } } + string BuildTraceMeName() override { + int64 buffer_limit; + { + tf_shared_lock l(mu_); + buffer_limit = auto_tuner_.buffer_limit(); + } + return strings::StrCat(prefix(), "#buffer_limit=", buffer_limit, "#"); + } + Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } @@ -142,10 +160,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { if (stats_aggregator) { stats_aggregator->AddScalar( stats_utils::BufferSizeScalarName(dataset()->node_name()), - static_cast<double>(buffer_.size())); + static_cast<double>(buffer_.size()), num_elements()); stats_aggregator->AddScalar( stats_utils::BufferCapacityScalarName(dataset()->node_name()), - static_cast<double>(auto_tuner_.buffer_limit())); + static_cast<double>(auto_tuner_.buffer_limit()), num_elements()); } return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -227,6 +245,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status status; // The buffered data element. std::vector<Tensor> value; + int64 created_us; }; Status Consume(IteratorContext* ctx, std::vector<Tensor>* out_tensors, @@ -236,18 +255,33 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { stats_aggregator->AddToHistogram( stats_utils::BufferUtilizationHistogramName(dataset()->node_name()), {static_cast<double>(buffer_.size()) / - static_cast<double>(auto_tuner_.buffer_limit())}); + static_cast<double>(auto_tuner_.buffer_limit())}, + num_elements()); stats_aggregator->AddScalar( stats_utils::BufferSizeScalarName(dataset()->node_name()), - static_cast<double>(buffer_.size())); + static_cast<double>(buffer_.size()), num_elements()); stats_aggregator->AddScalar( stats_utils::BufferCapacityScalarName(dataset()->node_name()), - static_cast<double>(auto_tuner_.buffer_limit())); + static_cast<double>(auto_tuner_.buffer_limit()), num_elements()); } // A new element is available. Forward the status from computing it, and // (if we successfully got an element) the output values. Status s = buffer_.front().status; if (s.ok()) { + if (dataset()->slack_period_ > 0 && + (num_elements() + 1) % dataset()->slack_period_ == 0) { + // TODO(rachelim): Consider doing something more sophisticated + // to decide how long to sleep for; e.g.
using a kalman filter. + int64 slack_us = + Env::Default()->NowMicros() - buffer_.front().created_us; + // Every slack_period_-th element, update the most recent slack time, + // measured by the duration between when the element is prefetched + // and when it is consumed. We add kSleepFactor * slack_us_ to the + // measurement because we slept for that duration before prefetching + // the element. + slack_us_ = kSleepFactor * slack_us_ + slack_us; + VLOG(2) << "Setting slack_us_: " << slack_us_; + } *out_tensors = std::move(buffer_.front().value); RecordBufferDequeue(ctx, *out_tensors); } @@ -282,9 +316,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); + // Keep track of where we are in an iteration "burst" + int num_produced = 0; while (true) { - std::vector<Tensor> value; - // 1. Wait for a slot in the buffer. { mutex_lock l(mu_); @@ -299,6 +333,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } } + if (dataset()->slack_period_ > 0 && + num_produced % dataset()->slack_period_ == 0) { + // For the first element in the "burst", sleep for a bit if there is + // slack. + VLOG(2) << "Sleeping for: " << slack_us_ * kSleepFactor; + ctx->env()->SleepForMicroseconds(slack_us_ * kSleepFactor); + } + // 2. Read the next element. // Acquire the parent lock since we will be reading an element // from the input iterator. Note that we do not wish to release @@ -321,9 +363,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { { mutex_lock l(mu_); RecordBufferEnqueue(ctx.get(), buffer_element.value); + buffer_element.created_us = ctx->env()->NowMicros(); buffer_.push_back(std::move(buffer_element)); cond_var_.notify_all(); } + ++num_produced; } } @@ -377,25 +421,34 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_) = false; bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; + + std::atomic<int64> slack_us_; }; const DatasetBase* const input_; const int64 buffer_size_; + + // If non-zero, determines the period between injecting "slack" into the + // execution.
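+  // For example, with slack_period_ = 8 the consumer re-measures slack on
+  // every 8th element it takes from the buffer, and the prefetch thread
+  // sleeps before producing the first element of each 8-element burst.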
+ const int64 slack_period_; }; void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) { - int64 buffer_size; + int64 buffer_size = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); OP_REQUIRES(ctx, buffer_size >= 0 || buffer_size == PrefetchAutotuner::kAutoTune, - errors::InvalidArgument("buffer_size must be >= 0")); + errors::InvalidArgument("buffer_size must be >= 0 or set " + "buffer_size to be ", + PrefetchAutotuner::kAutoTune, + " for auto-tuning")); if (buffer_size == PrefetchAutotuner::kAutoTune) { metrics::RecordTFDataAutotune(kDatasetName); } - *output = new Dataset(ctx, input, buffer_size); + *output = new Dataset(ctx, input, buffer_size, slack_period_); } namespace { diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h index 83206374946..d42e14373bd 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.h +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h @@ -25,7 +25,11 @@ namespace data { class PrefetchDatasetOp : public UnaryDatasetOpKernel { public: explicit PrefetchDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} + : UnaryDatasetOpKernel(ctx) { + if (ctx->HasAttr("slack_period")) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("slack_period", &slack_period_)); + } + } protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, @@ -33,6 +37,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { private: class Dataset; + int64 slack_period_ = 0; }; } // namespace data diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc new file mode 100644 index 00000000000..56dfbc510e8 --- /dev/null +++ b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc @@ -0,0 +1,635 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "prefetch_dataset"; +constexpr char kOpName[] = "PrefetchDataset"; + +class PrefetchDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `TensorSliceDataset` variant tensor from the input vector of + // tensors. + Status CreateTensorSliceDatasetTensor( + std::vector<Tensor> *const tensor_vector, Tensor *dataset_tensor) { + DatasetBase *tensor_slice_dataset; + TF_RETURN_IF_ERROR(CreateTensorSliceDataset( + "tensor_slice_node", tensor_vector, &tensor_slice_dataset)); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor)); + return Status::OK(); + } + + // Create a new `PrefetchDataset` op kernel.
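+  // The kernel's NodeDef wires "input_dataset" and "buffer_size" as inputs
+  // and pins the "slack_period" attr to 0, so these tests exercise the
+  // non-slack code path.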
+ Status CreatePrefetchDatasetKernel( + const DataTypeVector &output_types, + const std::vector &output_shapes, + std::unique_ptr *op_kernel) { + NodeDef node_def = test::function::NDef(kNodeName, kOpName, + {"input_dataset", "buffer_size"}, + {{"output_types", output_types}, + {"output_shapes", output_shapes}, + {"slack_period", 0}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); + return Status::OK(); + } + + // Create a new `PrefetchDataset` op kernel context. + Status CreatePrefetchDatasetContext( + OpKernel *op_kernel, gtl::InlinedVector *const inputs, + std::unique_ptr *context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct TestCase { + std::vector input_tensors; + int64 buffer_size; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +TestCase PositiveBufferSizeTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, + /*buffer_size*/ 5, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {0}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {4}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {5}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {6}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {7}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 4, 11}}; +} + +TestCase ZeroBufferSizeTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, + /*buffer_size*/ 0, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {0}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {4}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {5}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {6}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {7}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 4, 11}}; +} + +TestCase AutoTuneTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, + /*buffer_size*/ -1, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {0}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {4}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {5}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {6}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {7}), + 
DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 4, 11}}; +} + +TestCase InvalidBufferSizeTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, + /*buffer_size*/ -2, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 4, 11}}; +} + +class ParameterizedPrefetchDatasetOpTest + : public PrefetchDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedPrefetchDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(prefetch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(prefetch_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + auto expected_outputs_it = test_case.expected_outputs.begin(); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + for (const auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); +} + +TEST_F(PrefetchDatasetOpTest, InvalidBufferSize) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = InvalidBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + 
TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + EXPECT_EQ(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), &prefetch_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(PrefetchDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + EXPECT_EQ(prefetch_dataset->node_name(), kNodeName); +} + +TEST_F(PrefetchDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + EXPECT_EQ(prefetch_dataset->type_string(), kOpName); +} + +TEST_F(PrefetchDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + 
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(prefetch_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_F(PrefetchDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(prefetch_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedPrefetchDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + 
&prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + EXPECT_EQ(prefetch_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_F(PrefetchDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(prefetch_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_F(PrefetchDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(prefetch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(prefetch_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_F(PrefetchDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor 
tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(prefetch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(prefetch_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(PrefetchDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + {&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(prefetch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(prefetch_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::Prefetch"); +} + +TEST_P(ParameterizedPrefetchDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PositiveBufferSizeTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor buffer_size = + CreateTensor(TensorShape{}, {test_case.buffer_size}); + gtl::InlinedVector inputs_for_prefetch_dataset( + 
{&tensor_slice_dataset_tensor, &buffer_size}); + + std::unique_ptr prefetch_dataset_kernel; + TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &prefetch_dataset_kernel)); + std::unique_ptr prefetch_dataset_context; + TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(), + &inputs_for_prefetch_dataset, + &prefetch_dataset_context)); + DatasetBase *prefetch_dataset; + TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(), + prefetch_dataset_context.get(), + &prefetch_dataset)); + core::ScopedUnref scoped_unref(prefetch_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(prefetch_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(prefetch_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *prefetch_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + for (auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_outputs.size()) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); + } + } +} + +INSTANTIATE_TEST_SUITE_P(PreFetchDatasetOpTest, + ParameterizedPrefetchDatasetOpTest, + ::testing::ValuesIn(std::vector( + {PositiveBufferSizeTestCase(), + ZeroBufferSizeTestCase(), AutoTuneTestCase()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/range_dataset_op_test.cc b/tensorflow/core/kernels/data/range_dataset_op_test.cc index bfe091fd524..608b8e81b51 100644 --- a/tensorflow/core/kernels/data/range_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/range_dataset_op_test.cc @@ -13,237 +13,324 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/kernels/data/dataset_test_base.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" -#include "tensorflow/core/kernels/data/iterator_ops.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { namespace { +constexpr char kNodeName[] = "range_dataset"; constexpr char kOpName[] = "RangeDataset"; class RangeDatasetOpTest : public DatasetOpsTestBase { protected: // Creates a new RangeDataset op kernel context. Status CreateRangeDatasetContext( - int64 start, int64 end, int64 step, OpKernel* const range_kernel, + OpKernel* const range_kernel, + gtl::InlinedVector* const inputs, std::unique_ptr* range_context) { - inputs_.clear(); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {start})); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {end})); - TF_RETURN_IF_ERROR(AddDatasetInputFromArray( - &inputs_, range_kernel->input_types(), TensorShape({}), {step})); - + TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, *inputs)); TF_RETURN_IF_ERROR( - CreateOpKernelContext(range_kernel, &inputs_, range_context)); - TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, inputs_)); + CreateOpKernelContext(range_kernel, inputs, range_context)); return Status::OK(); } - - private: - gtl::InlinedVector inputs_; }; -struct GetNextTestParams { - explicit GetNextTestParams(int64 input_start, int64 input_end, - int64 input_step) - : start(input_start), end(input_end), step(input_step) {} - +struct TestCase { int64 start; int64 end; int64 step; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; }; -struct DatasetGetNextTest : RangeDatasetOpTest, - ::testing::WithParamInterface {}; +TestCase PositiveStepTestCase() { + return {/*start*/ 0, + /*end*/ 10, + /*step*/ 3, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 4}}; +} -TEST_P(DatasetGetNextTest, GetNext) { +TestCase NegativeStepTestCase() { + return {/*start*/ 10, + /*end*/ 0, + /*step*/ -3, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {7}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 4}}; +} + +TestCase ZeroStepTestCase() { + return {/*start*/ 0, + /*end*/ 10, + 
/*step*/ 0, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {}, + /*expected_output_shapes*/ {}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} + +class ParameterizedRangeDatasetOpTest + : public RangeDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedRangeDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; - GetNextTestParams params = GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector out_tensors; while (!end_of_sequence) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } } - std::vector expected_values; - for (int i = params.start; (params.end - i) * params.step > 0; - i = i + params.step) { - expected_values.reserve(1); - expected_values.emplace_back(i); - } - EXPECT_EQ(out_tensors.size(), expected_values.size()); - for (size_t i = 0; i < out_tensors.size(); ++i) { - int64 actual_value = out_tensors[i].flat()(0); - int64 expect_value = expected_values[i]; - EXPECT_EQ(actual_value, expect_value); - } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetGetNextTest, - ::testing::Values(GetNextTestParams(0, 10, 1), - GetNextTestParams(0, 10, 3), - GetNextTestParams(10, 0, -1), - GetNextTestParams(10, 0, -3))); - -TEST_F(RangeDatasetOpTest, DatasetName) { - int64 start = 0, end = 10, step = 1; +TEST_F(RangeDatasetOpTest, ZeroStep) { int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - 
TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = ZeroStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + EXPECT_EQ(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +TEST_F(RangeDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; + TF_ASSERT_OK( + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); + core::ScopedUnref scoped_unref(range_dataset); + + EXPECT_EQ(range_dataset->node_name(), kNodeName); +} + +TEST_F(RangeDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; + TF_ASSERT_OK( + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); EXPECT_EQ(range_dataset->type_string(), kOpName); } TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, 
range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(range_dataset->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(range_dataset->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(RangeDatasetOpTest, DatasetOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < range_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE( - range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(range_dataset->output_shapes(), + test_case.expected_output_shapes)); } -struct CardinalityTestParams { - explicit CardinalityTestParams(int64 input_start, int64 input_end, - int64 input_step, - int input_expected_cardinality) - : start(input_start), - end(input_end), - step(input_step), - expected_cardinality(input_expected_cardinality) {} - - int64 start; - int64 end; - int64 step; - int expected_cardinality; -}; - -struct DatasetCardinalityTest - : RangeDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetCardinalityTest, Cardinality) { +TEST_P(ParameterizedRangeDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; - CardinalityTestParams params = 
GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); - EXPECT_EQ(range_dataset->Cardinality(), params.expected_cardinality); + EXPECT_EQ(range_dataset->Cardinality(), test_case.expected_cardinality); } -INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetCardinalityTest, - ::testing::Values(CardinalityTestParams(0, 10, 1, 10), - CardinalityTestParams(0, 10, 3, 4), - CardinalityTestParams(10, 0, -3, 4))); - TEST_F(RangeDatasetOpTest, DatasetSave) { int64 thread_num = 2, cpu_num = 2; - int start = 0, end = 10, step = 1; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr serialization_context; @@ -256,81 +343,105 @@ TEST_F(RangeDatasetOpTest, DatasetSave) { } TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - 
DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - DataTypeVector expected_dtypes({DT_INT64}); - EXPECT_EQ(iterator->output_dtypes(), expected_dtypes); + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); } TEST_F(RangeDatasetOpTest, IteratorOutputShapes) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector expected_shapes({PartialTensorShape({})}); - EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size()); - for (int i = 0; i < range_dataset->output_shapes().size(); ++i) { - EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i])); - } + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); } TEST_F(RangeDatasetOpTest, 
IteratorOutputPrefix) { - int64 start = 0, end = 10, step = 1; int thread_num = 2, cpu_num = 2; - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(), - &range_context)); - DatasetBase* range_dataset; + TestCase test_case = PositiveStepTestCase(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), &range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); @@ -338,83 +449,78 @@ TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), "Iterator::Range"); } -struct RoundtripTestParams { - explicit RoundtripTestParams(int64 input_start, int64 input_end, - int64 input_step, int input_breakpoint) - : start(input_start), - end(input_end), - step(input_step), - breakpoint(input_breakpoint) {} - - int64 start; - int64 end; - int64 step; - int breakpoint; -}; - -struct IteratorRoundtripTest - : RangeDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorRoundtripTest, Roundtrip) { +TEST_P(ParameterizedRangeDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; - RoundtripTestParams params = GetParam(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr range_kernel; - TF_ASSERT_OK(CreateRangeDatasetOpKernel("range", &range_kernel)); - std::unique_ptr range_context; - TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step, - range_kernel.get(), &range_context)); - DatasetBase* range_dataset; + TestCase test_case = GetParam(); + gtl::InlinedVector inputs; + Tensor start = CreateTensor(TensorShape({}), {test_case.start}); + Tensor end = CreateTensor(TensorShape({}), {test_case.end}); + Tensor step = CreateTensor(TensorShape({}), {test_case.step}); + inputs.emplace_back(&start); + inputs.emplace_back(&end); + inputs.emplace_back(&step); + + std::unique_ptr range_dataset_kernel; TF_ASSERT_OK( - CreateDataset(range_kernel.get(), range_context.get(), &range_dataset)); + CreateRangeDatasetOpKernel(kNodeName, &range_dataset_kernel)); + std::unique_ptr range_dataset_context; + TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs, + &range_dataset_context)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(), + range_dataset_context.get(), 
&range_dataset)); core::ScopedUnref scoped_unref(range_dataset); std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context)); + TF_ASSERT_OK( + CreateIteratorContext(range_dataset_context.get(), &iterator_context)); std::unique_ptr iterator; TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator)); - std::vector out_tensors; + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; - int64 cur_val = params.start - params.step; - for (int i = 0; i < params.breakpoint; i++) { - if (!end_of_sequence) { + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader, "Iterator", + *range_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, &end_of_sequence)); - cur_val = ((params.end - cur_val - params.step) * params.step > 0) - ? cur_val + params.step - : cur_val; + if (!end_of_sequence) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_cardinality) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); } } - - std::unique_ptr serialization_context; - TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); - VariantTensorData data; - VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); - TF_ASSERT_OK(writer.Flush()); - VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - int64 expect_next = ((params.end - cur_val - params.step) * params.step > 0) - ? 
cur_val + params.step - : cur_val; - EXPECT_EQ(out_tensors.back().flat()(0), expect_next); } -INSTANTIATE_TEST_CASE_P( - RangeDatasetOpTest, IteratorRoundtripTest, - ::testing::Values( - RoundtripTestParams(0, 10, 2, 0), // unused_iterator - RoundtripTestParams(0, 10, 2, 4), // fully_used_iterator_increase - RoundtripTestParams(10, 0, -2, 4), // fully_used_iterator_decrease - RoundtripTestParams(0, 10, 2, 6))); // exhausted_iterator +INSTANTIATE_TEST_SUITE_P( + RangeDatasetOpTest, ParameterizedRangeDatasetOpTest, + ::testing::ValuesIn(std::vector({PositiveStepTestCase(), + NegativeStepTestCase()}))); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index c8e0e9ea944..9ab687c0d7d 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -751,6 +751,7 @@ class TFRecordDatasetOp : public DatasetOpKernel { std::vector filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + VLOG(2) << "Reading file: " << filenames_tensor->flat()(i); filenames.push_back(filenames_tensor->flat()(i)); } diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index ef507ffdd1d..77c54c3dd7d 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -220,6 +220,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); } Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + DCHECK(!*end_of_sequence || out_tensors->empty()); if (first_call_ && *end_of_sequence) { // If the first call to GetNext() fails because the end // of sequence has been reached, we terminate the diff --git a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc new file mode 100644 index 00000000000..3aa58a2767b --- /dev/null +++ b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc @@ -0,0 +1,591 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "repeat_dataset"; +constexpr char kOpName[] = "RepeatDataset"; + +class RepeatDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `TensorSliceDataset` variant tensor from the input vector of + // tensors. 
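+  // The helper below wraps the test input tensors in a `TensorSliceDataset`
+  // and stores that dataset in a scalar DT_VARIANT tensor, which is the form
+  // the `input_dataset` argument of `RepeatDataset` expects. A minimal sketch
+  // of the intended usage (the names and values here are illustrative only,
+  // not part of this change):
+  //
+  //   std::vector<Tensor> components =
+  //       {CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4})};
+  //   Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
+  //   TF_ASSERT_OK(
+  //       CreateTensorSliceDatasetTensor(&components, &dataset_tensor));
+  //   // `dataset_tensor` can now be fed as the first input of the kernel.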
+  Status CreateTensorSliceDatasetTensor(
+      std::vector<Tensor> *const tensor_vector, Tensor *dataset_tensor) {
+    DatasetBase *tensor_slice_dataset;
+    TF_RETURN_IF_ERROR(CreateTensorSliceDataset(
+        "tensor_slice_node", tensor_vector, &tensor_slice_dataset));
+    TF_RETURN_IF_ERROR(
+        StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor));
+    return Status::OK();
+  }
+
+  // Creates a new `RepeatDataset` op kernel.
+  Status CreateRepeatDatasetKernel(
+      const DataTypeVector &output_types,
+      const std::vector<PartialTensorShape> &output_shapes,
+      std::unique_ptr<OpKernel> *op_kernel) {
+    NodeDef node_def = test::function::NDef(
+        kNodeName, kOpName, {"input_dataset", "count"},
+        {{"output_types", output_types}, {"output_shapes", output_shapes}});
+    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
+    return Status::OK();
+  }
+
+  // Creates a new `RepeatDataset` op kernel context.
+  Status CreateRepeatDatasetContext(
+      OpKernel *op_kernel, gtl::InlinedVector<TensorValue, 4> *const inputs,
+      std::unique_ptr<OpKernelContext> *context) {
+    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
+    TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
+    return Status::OK();
+  }
+};
+
+struct TestCase {
+  std::vector<Tensor> input_tensors;
+  int64 count;
+  std::vector<Tensor> expected_outputs;
+  DataTypeVector expected_output_dtypes;
+  std::vector<PartialTensorShape> expected_output_shapes;
+  int64 expected_cardinality;
+  std::vector<int> breakpoints;
+};
+
+TestCase FiniteRepeatTestCase() {
+  return {
+      /*input_tensors*/
+      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
+       DatasetOpsTestBase::CreateTensor<string>(TensorShape{2, 1}, {"a", "b"})},
+      /*count*/ 2,
+      /*expected_outputs*/
+      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {1, 2}),
+       DatasetOpsTestBase::CreateTensor<string>(TensorShape{1}, {"a"}),
+       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {3, 4}),
+       DatasetOpsTestBase::CreateTensor<string>(TensorShape{1}, {"b"}),
+       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {1, 2}),
+       DatasetOpsTestBase::CreateTensor<string>(TensorShape{1}, {"a"}),
+       DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2}, {3, 4}),
+       DatasetOpsTestBase::CreateTensor<string>(TensorShape{1}, {"b"})},
+      /*expected_output_dtypes*/ {DT_INT64, DT_STRING},
+      /*expected_output_shapes*/
+      {PartialTensorShape({2}), PartialTensorShape({1})},
+      /*expected_cardinality*/ 4,
+      /*breakpoints*/ {0, 1, 3}};
+}
+
+TestCase EmptyRepeatTestCase() {
+  return {
+      /*input_tensors*/
+      {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
+       DatasetOpsTestBase::CreateTensor<string>(TensorShape{2, 1}, {"a", "b"})},
+      /*count*/ 0,
+      /*expected_outputs*/
+      {},
+      /*expected_output_dtypes*/ {DT_INT64, DT_STRING},
+      /*expected_output_shapes*/
+      {PartialTensorShape({2}), PartialTensorShape({1})},
+      /*expected_cardinality*/ 0,
+      /*breakpoints*/ {0, 1, 3}};
+}
+
+TestCase ForeverRepeatTestCase() {
+  return {/*input_tensors*/
+          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 1}, {1, 2})},
+          /*count*/ -1,
+          /*expected_outputs*/
+          // Use the first group of the repeated tensors to represent the
+          // infinite outputs.
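+          // A `count` of -1 asks RepeatDataset to repeat its input
+          // indefinitely; the expected cardinality of -1 below matches the
+          // sentinel `Cardinality()` reports for infinite datasets
+          // (kInfiniteCardinality).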
+ {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ -1, + /*breakpoints*/ {0, 1, 3}}; +} + +class ParameterizedDatasetOpTest + : public RepeatDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(repeat_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + repeat_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + auto expected_outputs_it = test_case.expected_outputs.begin(); + bool end_of_sequence = false; + std::vector out_tensors; + + if (test_case.count < 0) { + // We test only a finite number of steps of the infinite sequence. + for (int i = 0; i < 100; ++i) { + out_tensors.clear(); + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + for (const auto &tensor : out_tensors) { + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + // In the forever-repeat test case, the first group of the repeated + // tensors is used to represent the expected outputs, so the iterator + // of the expected outputs needs to be reset once it reaches the end. 
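+        // For example, with two expected tensors {t0, t1}, the i-th element
+        // produced by GetNext() is compared against expected_outputs[i % 2].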
+ if (expected_outputs_it == test_case.expected_outputs.end()) { + expected_outputs_it = test_case.expected_outputs.begin(); + } + } + } + EXPECT_FALSE(end_of_sequence); + } else { + while (!end_of_sequence) { + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + for (const auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } +} + +TEST_F(RepeatDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = FiniteRepeatTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + EXPECT_EQ(repeat_dataset->node_name(), kNodeName); +} + +TEST_F(RepeatDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = FiniteRepeatTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + EXPECT_EQ(repeat_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor 
tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + TF_EXPECT_OK(VerifyTypesMatch(repeat_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + TF_EXPECT_OK(VerifyShapesCompatible(repeat_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + 
&inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + EXPECT_EQ(repeat_dataset->Cardinality(), GetParam().expected_cardinality); +} + +TEST_F(RepeatDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = FiniteRepeatTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(repeat_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(repeat_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + repeat_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedDatasetOpTest, IteratorOutputShapes) { + 
int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(repeat_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + repeat_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(repeat_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + repeat_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + if (test_case.count < 0) { + EXPECT_EQ(iterator->prefix(), "Iterator::ForeverRepeat"); + } else if (test_case.count == 0) { + EXPECT_EQ(iterator->prefix(), "Iterator::EmptyRepeat"); + } else { + EXPECT_EQ(iterator->prefix(), "Iterator::FiniteRepeat"); + } +} + +TEST_P(ParameterizedDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = 
GetParam(); + auto expected_outputs_it = test_case.expected_outputs.begin(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(repeat_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + repeat_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = repeat_dataset->Cardinality() == 0; + std::vector out_tensors; + int cur_iteration = 0; + std::vector breakpoints = GetParam().breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *repeat_dataset, &iterator)); + + while (cur_iteration < breakpoint) { + out_tensors.clear(); + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + for (auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + cur_iteration++; + if (test_case.count < 0 && + expected_outputs_it == test_case.expected_outputs.end()) { + expected_outputs_it = test_case.expected_outputs.begin(); + } + } + + if (breakpoint >= repeat_dataset->Cardinality()) { + if (test_case.count < 0) { + EXPECT_FALSE(end_of_sequence); + } else { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } + } else { + EXPECT_FALSE(end_of_sequence); + } + } +} + +INSTANTIATE_TEST_SUITE_P(RepeatDatasetOpTest, ParameterizedDatasetOpTest, + ::testing::ValuesIn(std::vector( + {FiniteRepeatTestCase(), EmptyRepeatTestCase(), + ForeverRepeatTestCase()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc index 9bb64911aa8..59825b463cb 100644 --- a/tensorflow/core/kernels/data/shard_dataset_op.cc +++ b/tensorflow/core/kernels/data/shard_dataset_op.cc @@ -15,6 +15,7 @@ limitations under the License. 
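+// ShardDataset keeps every element of its input whose position satisfies
+// position % num_shards == index; for example, sharding range(0, 10) with
+// num_shards = 5 and index = 2 yields {2, 7}. When the `require_non_empty`
+// attribute is set, the iterator reports an InvalidArgument error if the
+// input has fewer than num_shards elements, i.e. some shard would be empty.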
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/batch_util.h" namespace tensorflow { @@ -27,7 +28,9 @@ namespace { class ShardDatasetOp : public UnaryDatasetOpKernel { public: explicit ShardDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("require_non_empty", &require_non_empty_)); + } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { @@ -48,18 +51,19 @@ class ShardDatasetOp : public UnaryDatasetOpKernel { errors::InvalidArgument("Index must be between 0 and ", num_shards - 1, " (currently index = ", index, ").")); - *output = new Dataset(ctx, num_shards, index, input); + *output = new Dataset(ctx, num_shards, index, require_non_empty_, input); } private: class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, int64 num_shards, int64 index, - const DatasetBase* input) + bool require_non_empty, const DatasetBase* input) : DatasetBase(DatasetContext(ctx)), num_shards_(num_shards), index_(index), - input_(input) { + input_(input), + require_non_empty_(require_non_empty) { input_->Ref(); } @@ -102,8 +106,13 @@ class ShardDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddScalar(num_shards_, &num_shards)); Node* index = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(index_, &index)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, num_shards, index}, output)); + + AttrValue require_non_empty_attr; + b->BuildAttrValue(require_non_empty_, &require_non_empty_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, num_shards, index}, + {{"require_non_empty", require_non_empty_attr}}, output)); return Status::OK(); } @@ -138,6 +147,26 @@ class ShardDatasetOp : public UnaryDatasetOpKernel { } } while ((next_index_++ % dataset()->num_shards_) != dataset()->index_); + while (dataset()->require_non_empty_ && + next_index_ < dataset()->num_shards_) { + std::vector unused_result; + + Status s = input_impl_->GetNext(ctx, &unused_result, end_of_sequence); + if (*end_of_sequence || errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "There aren't enough elements in this dataset for each shard " + "to have at least one element (# elems = ", + next_index_, ", ", "# shards = ", dataset()->num_shards_, + "). If you are using ", + "datasets with distribution strategy, consider turning ", + "dataset autosharding off with `tf.data.Options`."); + } else if (!s.ok()) { + return s; + } + + next_index_++; + } + *out_tensors = std::move(result); return Status::OK(); } @@ -184,7 +213,10 @@ class ShardDatasetOp : public UnaryDatasetOpKernel { const int64 num_shards_; const int64 index_; const DatasetBase* const input_; + const bool require_non_empty_; }; + + bool require_non_empty_; }; REGISTER_KERNEL_BUILDER(Name("ShardDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/shard_dataset_op_test.cc b/tensorflow/core/kernels/data/shard_dataset_op_test.cc new file mode 100644 index 00000000000..6da1ff3b570 --- /dev/null +++ b/tensorflow/core/kernels/data/shard_dataset_op_test.cc @@ -0,0 +1,821 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "shard_dataset"; +constexpr char kOpName[] = "ShardDataset"; + +class ShardDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates a new `ShardDataset` op kernel. + Status CreateShardDatasetOpKernel( + bool require_non_empty, const DataTypeVector& output_types, + const std::vector& output_shapes, + std::unique_ptr* op_kernel) { + NodeDef node_def = test::function::NDef( + kNodeName, kOpName, {"input_dataset", "num_shards", "index"}, + {{"require_non_empty", require_non_empty}, + {"output_types", output_types}, + {"output_shapes", output_shapes}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); + return Status::OK(); + } + + // Create a new `ShardDataset` op kernel context + Status CreateShardDatasetContext( + OpKernel* const op_kernel, + gtl::InlinedVector* const inputs, + std::unique_ptr* context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct RangeDatasetParam { + int64 start; + int64 end; + int64 step; +}; + +struct TestCase { + RangeDatasetParam range_dataset_param; + Tensor num_shards; + Tensor index; + bool require_non_empty; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +// Test Case 1: simple case. +TestCase TestCase1() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*require_non_empty*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {7})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 2, + /*breakpoints*/ {0, 1, 5}}; +} + +// Test Case 2: zero offset. +TestCase TestCase2() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + /*require_non_empty*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 2, + /*breakpoints*/ {0, 1, 5}}; +} + +// Test Case 3: iterator ends before first element. 
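+// Note: with a single-element range [0, 1) split across 5 shards, shard
+// index 2 receives nothing, so the expected output is empty and the
+// expected cardinality is 0.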
+TestCase TestCase3() { + return {/*range_data_param*/ {0, 1, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*require_non_empty*/ true, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1}}; +} + +// Test Case 4: larger num_shards. +TestCase TestCase4() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {7}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*require_non_empty*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {5})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 1, + /*breakpoints*/ {0, 5}}; +} + +// Test Case 5: index == num_shards. +TestCase TestCase5() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*require_non_empty*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 2, + /*breakpoints*/ {0, 1, 5}}; +} + +// Test Case 6: similar with test_case_5 but the number of outputs could not be +// divided evenly by num_shards. +TestCase TestCase6() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*require_non_empty*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {7})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 2, + /*breakpoints*/ {0, 1, 5}}; +} + +// Test Case 7: num_shard is larger than the cardinality of input dataset; +// require_non_empty = false. +TestCase TestCase7() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {20}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*require_non_empty*/ false, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {5})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 1, + /*breakpoints*/ {0, 5}}; +} + +// Test Case 8: similar with test_case_7 but require_non_empty = true. 
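+// Note: unlike Test Case 7, require_non_empty is true here, so iterating
+// this case is expected to fail with InvalidArgument (see the
+// NoElemForEachShard test below): 20 shards cannot each receive an element
+// from a 10-element range.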
+TestCase NoElemForEachShardTestCase() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {20}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*require_non_empty*/ true, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {5})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 1, + /*breakpoints*/ {0, 5}}; +} + +TestCase IndexGreaterNumShardsCase() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {7}), + /*require_non_empty*/ true, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} + +TestCase NegativeIndexTestCase() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {-3}), + /*require_non_empty*/ true, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} + +TestCase NegativeNumShardsTestCase() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {-3}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*require_non_empty*/ true, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} + +TestCase ZeroNumShardsTestCase() { + return {/*range_data_param*/ {0, 10, 1}, + /*num_shards*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + /*index*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*require_non_empty*/ true, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} + +class ParameterizedShardDatasetOpTest + : public ShardDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedShardDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + 
shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(shard_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + shard_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); + std::vector out_tensors; + while (!end_of_sequence) { + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_LT(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(out_tensors.back(), *expected_outputs_it)); + expected_outputs_it++; + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); +} + +TEST_F(ShardDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + EXPECT_EQ(shard_dataset->node_name(), kNodeName); +} + +TEST_F(ShardDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + EXPECT_EQ(shard_dataset->type_string(), 
kOpName); +} + +TEST_P(ParameterizedShardDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(shard_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedShardDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(shard_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedShardDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + 
TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + EXPECT_EQ(shard_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_P(ParameterizedShardDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + std::unique_ptr serialization_context; + TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(shard_dataset->Save(serialization_context.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedShardDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref 
scoped_unref_batch_dataset(shard_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(shard_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + shard_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedShardDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(shard_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + shard_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(ShardDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(shard_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + shard_dataset->MakeIterator(iterator_ctx.get(), 
"Iterator", &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::Shard"); +} + +TEST_P(ParameterizedShardDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(shard_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + shard_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *shard_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + cur_iteration++; + } + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, + /*compare_order*/ true)); +} + +INSTANTIATE_TEST_SUITE_P(ShardDatasetOpTest, ParameterizedShardDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3(), + TestCase4(), TestCase5(), TestCase6(), + TestCase7()}))); + +TEST_F(ShardDatasetOpTest, InvalidArguments) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::vector test_cases = { + IndexGreaterNumShardsCase(), NegativeIndexTestCase(), + NegativeNumShardsTestCase(), ZeroNumShardsTestCase()}; + for (const auto& test_case : test_cases) { + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", 
&range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + EXPECT_EQ(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); + } +} + +TEST_F(ShardDatasetOpTest, NoElemForEachShard) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + TestCase test_case = NoElemForEachShardTestCase(); + + std::unique_ptr shard_dataset_kernel; + TF_ASSERT_OK(CreateShardDatasetOpKernel( + test_case.require_non_empty, test_case.expected_output_dtypes, + test_case.expected_output_shapes, &shard_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_dataset_param.start, test_case.range_dataset_param.end, + test_case.range_dataset_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + + Tensor num_shards = test_case.num_shards; + Tensor index = test_case.index; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &num_shards, &index}); + std::unique_ptr shard_dataset_context; + TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs, + &shard_dataset_context)); + + DatasetBase* shard_dataset; + TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(), + shard_dataset_context.get(), &shard_dataset)); + core::ScopedUnref scoped_unref_batch_dataset(shard_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(shard_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + shard_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + bool end_of_sequence = false; + std::vector out_tensors; + + EXPECT_EQ( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence) + .code(), + tensorflow::error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 1a193b1d235..add526704f8 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -63,7 +63,15 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { return input_->output_shapes(); } - int64 Cardinality() const override { return input_->Cardinality(); } + int64 Cardinality() const override { + if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) { + return kInfiniteCardinality; + } else if (input_->Cardinality() == kUnknownCardinality) { + return kUnknownCardinality; + } else { + return input_->Cardinality() * count_; + } + } protected: template @@ -129,6 +137,10 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { ctx, this->prefix(), &input_impl_)); } if (!end_of_input_sequence) { + if (num_elements_ == 0) { + VLOG(1) << "Starting to fill up shuffle buffer of size: " + << this->dataset()->buffer_size_; + } this->RecordBufferEnqueue(ctx, 
input_element); buffer_[slices_.back()->end % this->dataset()->buffer_size_] = std::move(input_element); @@ -352,7 +364,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - int64 buffer_size; + int64 buffer_size = 0; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); OP_REQUIRES( @@ -625,7 +637,7 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - int64 buffer_size; + int64 buffer_size = 0; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); OP_REQUIRES( @@ -641,6 +653,10 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { int64 count; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "count", &count)); + OP_REQUIRES(ctx, count > 0 || count == -1, + errors::InvalidArgument( + "count must be greater than zero or equal to -1.")); + // By TensorFlow convention, if both seeds are 0, then shuffling should be // seeded non-deterministically. if (seed == 0 && seed2 == 0) { diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc new file mode 100644 index 00000000000..38b93f13808 --- /dev/null +++ b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc @@ -0,0 +1,915 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kShuffleNodeName[] = "shuffle_dataset"; +constexpr char kShuffleOpName[] = "ShuffleDataset"; +constexpr char kShuffleAndRepeatNodeName[] = "shuffle_and_repeat_dataset"; +constexpr char kShuffleAndRepeatOpName[] = "ShuffleAndRepeatDataset"; + +class ShuffleDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates a new `ShuffleDataset`/`ShuffleAndRepeatDataset` op kernel + Status CreateDatasetOpKernel( + int64 count, bool reshuffle_each_iteration, + const DataTypeVector& output_types, + const std::vector& output_shapes, + std::unique_ptr* shuffle_dataset_kernel) { + NodeDef node_def; + if (count == 1) { + node_def = test::function::NDef( + kShuffleNodeName, kShuffleOpName, + {"input_dataset", "buffer_size", "seed", "seed2"}, + {{"reshuffle_each_iteration", reshuffle_each_iteration}, + {"output_types", output_types}, + {"output_shapes", output_shapes}}); + } else { + node_def = test::function::NDef( + kShuffleAndRepeatNodeName, kShuffleAndRepeatOpName, + {"input_dataset", "buffer_size", "seed", "seed2", "count"}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); + } + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, shuffle_dataset_kernel)); + return Status::OK(); + } + + // Creates a new `ShuffleDataset`/`ShuffleAndRepeatDataset` op kernel context. 
+ Status CreateDatasetContext(OpKernel* const op_kernel, + gtl::InlinedVector* const inputs, + std::unique_ptr* context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); + TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); + return Status::OK(); + } +}; + +struct RangeDatasetParam { + int64 start; + int64 end; + int64 step; +}; + +struct TestCase { + RangeDatasetParam range_data_param; + Tensor buffer_size; + Tensor seed; + Tensor seed2; + Tensor count; + bool reshuffle_each_iteration; + std::vector expected_shuffle_outputs; + std::vector expected_reshuffle_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +template +std::vector ConvertToTensorVec(std::vector values) { + std::vector tensors; + tensors.reserve(values.size()); + for (auto& value : values) { + tensors.emplace_back( + DatasetOpsTestBase::CreateTensor(TensorShape({}), {value})); + } + return tensors; +} + +// Test case 1: test shuffle_dataset with reshuffle_each_iteration = false. +TestCase TestCase1() { + return { + /*range_data_param*/ {0, 10, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*reshuffle_each_iteration*/ false, + /*expected_shuffle_outputs*/ + ConvertToTensorVec({2, 3, 0, 5, 6, 4, 7, 8, 9, 1}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec({2, 3, 0, 5, 6, 4, 7, 8, 9, 1}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 2: test shuffle_dataset with reshuffle_each_iteration = true. +TestCase TestCase2() { + return { + /*range_data_param*/ {0, 10, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*reshuffle_each_iteration*/ true, + /*expected_shuffle_outputs*/ + ConvertToTensorVec({2, 6, 1, 3, 9, 5, 0, 8, 7, 4}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec({1, 6, 0, 5, 2, 7, 4, 3, 9, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 3: similar with the test case 2 but a smaller buffer size than +// the input dataset. +TestCase TestCase3() { + return { + /*range_data_param*/ {0, 10, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*reshuffle_each_iteration*/ true, + /*expected_shuffle_outputs*/ + ConvertToTensorVec({0, 2, 1, 3, 5, 6, 4, 7, 8, 9}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec({1, 0, 2, 3, 4, 5, 6, 7, 9, 8}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 4: similar with the test case 2 but has different seeds. 
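+// Note: the seed pair here is (2, 2) rather than (1, 2), so both the first
+// pass and the reshuffled pass are expected to yield a different permutation
+// of the same ten values than in test case 2.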
+TestCase TestCase4() { + return { + /*range_data_param*/ {0, 10, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*reshuffle_each_iteration*/ true, + /*expected_shuffle_outputs*/ + ConvertToTensorVec({3, 0, 8, 1, 5, 4, 7, 2, 6, 9}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec({4, 6, 9, 0, 1, 8, 2, 7, 3, 5}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 5: test shuffle_dataset with buffer_size = 1 & +// reshuffle_each_iteration = true. +TestCase TestCase5() { + return { + /*range_data_param*/ {0, 10, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*reshuffle_each_iteration*/ true, + /*expected_shuffle_outputs*/ + ConvertToTensorVec({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 6: test shuffle_dataset with an empty input dataset. +TestCase TestCase6() { + return { + /*range_data_param*/ {0, 0, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*reshuffle_each_iteration*/ true, + /*expected_shuffle_outputs*/ + ConvertToTensorVec({}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 7: test shuffle_and_repeat_dataset with buffer_size = 10 & +// count = 2. 
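+// Illustrative note on the expected cardinality: with the updated
+// Cardinality() in shuffle_dataset_op.cc,
+//   count_ == -1 or infinite input   -> kInfiniteCardinality
+//   unknown input cardinality        -> kUnknownCardinality
+//   otherwise                        -> input_->Cardinality() * count_
+// so this case expects 10 * 2 = 20, and count != 1 means the
+// ShuffleAndRepeatDataset kernel is created.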
+TestCase TestCase7() { + return { + /*range_data_param*/ {0, 10, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*reshuffle_each_iteration*/ false, + /*expected_shuffle_outputs*/ + ConvertToTensorVec( + {9, 0, 8, 6, 1, 3, 7, 2, 4, 5, 4, 3, 0, 5, 8, 2, 6, 9, 7, 1}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec( + {9, 0, 8, 6, 1, 3, 7, 2, 4, 5, 4, 3, 0, 5, 8, 2, 6, 9, 7, 1}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 20, + /*breakpoints*/ {0, 5, 22}}; +} + +// Test case 8: test shuffle_and_repeat_dataset with buffer_size = 10 & +// count = -1 +TestCase TestCase8() { + return { + /*range_data_param*/ {0, 3, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {-1}), + /*reshuffle_each_iteration*/ false, + /*expected_shuffle_outputs*/ + ConvertToTensorVec( + {2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2, 1, 0}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec( + {2, 0, 1, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 0, 2, 2, 1, 0}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ kInfiniteCardinality, + /*breakpoints*/ {0, 5, 20}}; +} + +TestCase InvalidBufferSizeTestCaseForShuffleDataset() { + return { + /*range_data_param*/ {0, 10, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {-1}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*reshuffle_each_iteration*/ true, + /*expected_shuffle_outputs*/ ConvertToTensorVec({}), + /*expected_reshuffle_outputs*/ ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +TestCase InvalidBufferSizeTestCaseForShuffleAndRepeatDataset() { + return { + /*range_data_param*/ {0, 10, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {-1}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*reshuffle_each_iteration*/ true, + /*expected_shuffle_outputs*/ ConvertToTensorVec({}), + /*expected_reshuffle_outputs*/ ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +TestCase InvalidCountTestCaseForShuffleAndRepeatDataset() { + return { + /*range_data_param*/ {0, 3, 1}, + /*buffer_size*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {10}), + /*seed*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*seed2*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*count*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + /*reshuffle_each_iteration*/ false, + /*expected_shuffle_outputs*/ + 
ConvertToTensorVec({}), + /*expected_reshuffle_outputs*/ + ConvertToTensorVec({}), + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 5, 20}}; +} + +class ParameterizedShuffleDatasetOpTest + : public ShuffleDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedShuffleDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + bool end_of_sequence = false; + std::vector shuffled_out_tensors; + while (!end_of_sequence) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + shuffled_out_tensors.insert(shuffled_out_tensors.end(), next.begin(), + next.end()); + // For the forever-repeat case, we test only a finite number of steps of + // the infinite sequence. + if (count_value == -1 && shuffled_out_tensors.size() == + test_case.expected_shuffle_outputs.size()) { + break; + } + } + + // Reshuffle the dataset. + end_of_sequence = false; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + std::vector reshuffled_out_tensors; + while (!end_of_sequence) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + reshuffled_out_tensors.insert(reshuffled_out_tensors.end(), next.begin(), + next.end()); + // For the forever-repeat case, we test only a finite number of steps of + // the infinite sequence. 
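+      // (We stop once as many elements as expected_shuffle_outputs.size()
+      // have been produced, mirroring the first pass above.)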
+ if (count_value == -1 && reshuffled_out_tensors.size() == + test_case.expected_shuffle_outputs.size()) { + break; + } + } + + TF_EXPECT_OK(ExpectEqual(shuffled_out_tensors, + test_case.expected_shuffle_outputs, + /*compare_order*/ true)); + TF_EXPECT_OK(ExpectEqual(reshuffled_out_tensors, + test_case.expected_reshuffle_outputs, + /*compare_order*/ true)); +} + +TEST_P(ParameterizedShuffleDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + if (count_value == 1) { + EXPECT_EQ(dataset->node_name(), kShuffleNodeName); + } else { + EXPECT_EQ(dataset->node_name(), kShuffleAndRepeatNodeName); + } +} + +TEST_P(ParameterizedShuffleDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + if (count_value == 1) { + EXPECT_EQ(dataset->type_string(), kShuffleOpName); + } else { + EXPECT_EQ(dataset->type_string(), kShuffleAndRepeatOpName); + } +} + +TEST_P(ParameterizedShuffleDatasetOpTest, 
DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + TF_EXPECT_OK(VerifyTypesMatch(dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedShuffleDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedShuffleDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + 
TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_P(ParameterizedShuffleDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr serialization_context; + TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = 
test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector 
inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + if (count_value == 1) { + EXPECT_EQ(iterator->prefix(), "Iterator::Shuffle"); + } else { + EXPECT_EQ(iterator->prefix(), "Iterator::ShuffleAndRepeat"); + } +} + +TEST_P(ParameterizedShuffleDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + Tensor count = test_case.count; + int64 count_value = count.flat()(0); + std::unique_ptr dataset_kernel; + TF_ASSERT_OK( + CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration, + test_case.expected_output_dtypes, + test_case.expected_output_shapes, &dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor buffer_size = test_case.buffer_size; + Tensor seed = test_case.seed; + Tensor seed2 = test_case.seed2; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &buffer_size, &seed, &seed2}); + if (count_value != 1) inputs.push_back(&count); + + std::unique_ptr dataset_context; + TF_ASSERT_OK( + CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context)); + DatasetBase* dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector& breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + std::vector next; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence)); + out_tensors.insert(out_tensors.end(), next.begin(), next.end()); + cur_iteration++; + } + } + + TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_shuffle_outputs, + /*compare_order*/ true)); +} + +INSTANTIATE_TEST_SUITE_P(ShuffleDatasetOpTest, + ParameterizedShuffleDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3(), + TestCase4(), TestCase5(), 
TestCase6(),
+                              TestCase7(), TestCase8()})));
+
+TEST_F(ShuffleDatasetOpTest, InvalidArguments) {
+  int thread_num = 2, cpu_num = 2;
+  std::vector<TestCase> test_cases = {
+      InvalidBufferSizeTestCaseForShuffleDataset(),
+      InvalidBufferSizeTestCaseForShuffleAndRepeatDataset(),
+      InvalidCountTestCaseForShuffleAndRepeatDataset()};
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  for (const auto& test_case : test_cases) {
+    Tensor count = test_case.count;
+    int64 count_value = count.flat<int64>()(0);
+    std::unique_ptr<OpKernel> dataset_kernel;
+    TF_ASSERT_OK(CreateDatasetOpKernel(
+        count_value, test_case.reshuffle_each_iteration,
+        test_case.expected_output_dtypes, test_case.expected_output_shapes,
+        &dataset_kernel));
+
+    DatasetBase* range_dataset;
+    TF_ASSERT_OK(CreateRangeDataset<int64>(
+        test_case.range_data_param.start, test_case.range_data_param.end,
+        test_case.range_data_param.step, "range", &range_dataset));
+    Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
+    TF_ASSERT_OK(
+        StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
+    Tensor buffer_size = test_case.buffer_size;
+    Tensor seed = test_case.seed;
+    Tensor seed2 = test_case.seed2;
+    gtl::InlinedVector<TensorValue, 4> inputs(
+        {&range_dataset_tensor, &buffer_size, &seed, &seed2});
+    if (count_value != 1) inputs.push_back(&count);
+
+    std::unique_ptr<OpKernelContext> dataset_context;
+    TF_ASSERT_OK(
+        CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
+    DatasetBase* shuffle_dataset;
+    EXPECT_EQ(CreateDataset(dataset_kernel.get(), dataset_context.get(),
+                            &shuffle_dataset)
+                  .code(),
+              tensorflow::error::INVALID_ARGUMENT);
+  }
+}
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index 5b85a10edf1..4fd28956bf2 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -30,7 +30,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    // Create a new RepeatDatasetOp::Dataset, and return it as the output.
+    // Create a new SkipDatasetOp::Dataset, and return it as the output.
     int64 count;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
 
@@ -72,7 +72,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
     if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
-    return std::max(0LL, n - count_);
+    return count_ < 0 ? 0 : std::max(0LL, n - count_);
   }
 
  protected:
diff --git a/tensorflow/core/kernels/data/skip_dataset_op_test.cc b/tensorflow/core/kernels/data/skip_dataset_op_test.cc
new file mode 100644
index 00000000000..44e41502867
--- /dev/null
+++ b/tensorflow/core/kernels/data/skip_dataset_op_test.cc
@@ -0,0 +1,589 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+constexpr char kNodeName[] = "skip_dataset";
+constexpr char kOpName[] = "SkipDataset";
+
+class SkipDatasetOpTest : public DatasetOpsTestBase {
+ protected:
+  // Create `TensorSliceDataset` variant tensor from the input vector of
+  // tensors.
+  Status CreateTensorSliceDatasetTensor(
+      std::vector<Tensor> *const tensor_vector, Tensor *dataset_tensor) {
+    DatasetBase *tensor_slice_dataset;
+    TF_RETURN_IF_ERROR(CreateTensorSliceDataset(
+        "tensor_slice_node", tensor_vector, &tensor_slice_dataset));
+    TF_RETURN_IF_ERROR(
+        StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor));
+    return Status::OK();
+  }
+
+  // Creates a new `SkipDataset` op kernel.
+  Status CreateSkipDatasetKernel(
+      const DataTypeVector &output_types,
+      const std::vector<PartialTensorShape> &output_shapes,
+      std::unique_ptr<OpKernel> *op_kernel) {
+    NodeDef node_def = test::function::NDef(
+        kNodeName, kOpName, {"input_dataset", "count"},
+        {{"output_types", output_types}, {"output_shapes", output_shapes}});
+    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
+    return Status::OK();
+  }
+
+  // Create a new `SkipDataset` op kernel context.
+  Status CreateSkipDatasetContext(
+      OpKernel *op_kernel, gtl::InlinedVector<TensorValue, 4> *const inputs,
+      std::unique_ptr<OpKernelContext> *context) {
+    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
+    TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
+    return Status::OK();
+  }
+};
+
+struct TestCase {
+  std::vector<Tensor> input_tensors;
+  int64 count;
+  std::vector<Tensor> expected_outputs;
+  DataTypeVector expected_output_dtypes;
+  std::vector<PartialTensorShape> expected_output_shapes;
+  int64 expected_cardinality;
+  std::vector<int> breakpoints;
+};
+
+// Test case 1: skip fewer than input size.
+TestCase SkipLessTestCase() {
+  return {/*input_tensors*/
+          {DatasetOpsTestBase::CreateTensor<int64>(
+              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+          /*count*/ 4,
+          /*expected_outputs*/
+          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({1})},
+          /*expected_cardinality*/ 6,
+          /*breakpoints*/ {0, 2, 7}};
+}
+
+// Test case 2: skip more than input size.
+TestCase SkipMoreTestCase() {
+  return {/*input_tensors*/
+          {DatasetOpsTestBase::CreateTensor<int64>(
+              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+          /*count*/ 25,
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({1})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 2, 5}};
+}
+
+// Test case 3: skip exactly the input size.
+TestCase SkipAllTestCase() {
+  return {/*input_tensors*/
+          {DatasetOpsTestBase::CreateTensor<int64>(
+              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+          /*count*/ 10,
+          /*expected_outputs*/ {},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({1})},
+          /*expected_cardinality*/ 0,
+          /*breakpoints*/ {0, 2, 5}};
+}
+
+// Test case 4: skip nothing.
+TestCase SkipNothingTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, + /*count*/ 0, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {0}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {4}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {5}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {6}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {7}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ 10, + /*breakpoints*/ {0, 2, 5, 11}}; +} + +// Test case 5: set -1 for `count` to skip the entire dataset. +TestCase SkipEntireDatasetTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, + /*count*/ -1, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5}}; +} + +class ParameterizedSkipDatasetOpTest + : public SkipDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedSkipDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(skip_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + skip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + auto expected_outputs_it = test_case.expected_outputs.begin(); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + for (const auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); +} + +TEST_F(SkipDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + 
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = SkipLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + EXPECT_EQ(skip_dataset->node_name(), kNodeName); +} + +TEST_F(SkipDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = SkipLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + EXPECT_EQ(skip_dataset->type_string(), kOpName); +} + +TEST_F(SkipDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = SkipLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(skip_dataset->output_dtypes(), + 
test_case.expected_output_dtypes)); +} + +TEST_F(SkipDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = SkipLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(skip_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedSkipDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + EXPECT_EQ(skip_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_F(SkipDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = SkipLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase 
*skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(skip_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedSkipDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(skip_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + skip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedSkipDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(skip_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + skip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedSkipDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 
2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(skip_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + skip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + if (test_case.count < 0) { + EXPECT_EQ(iterator->prefix(), "Iterator::EmptySkip"); + } else { + EXPECT_EQ(iterator->prefix(), "Iterator::FiniteSkip"); + } +} + +TEST_P(ParameterizedSkipDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_skip_dataset( + {&tensor_slice_dataset_tensor, &count}); + + std::unique_ptr skip_dataset_kernel; + TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &skip_dataset_kernel)); + std::unique_ptr skip_dataset_context; + TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(), + &inputs_for_skip_dataset, + &skip_dataset_context)); + DatasetBase *skip_dataset; + TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(), + skip_dataset_context.get(), &skip_dataset)); + core::ScopedUnref scoped_unref(skip_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(skip_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + skip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *skip_dataset, &iterator)); + + while 
(cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + for (auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_outputs.size()) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); + } + } +} + +INSTANTIATE_TEST_SUITE_P(SkipDatasetOpTest, ParameterizedSkipDatasetOpTest, + ::testing::ValuesIn(std::vector( + {SkipLessTestCase(), SkipMoreTestCase(), + SkipAllTestCase(), SkipNothingTestCase(), + SkipEntireDatasetTestCase()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc index cbae2372457..e636463d423 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc @@ -13,19 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/kernels/data/dataset_test_base.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" -#include "tensorflow/core/kernels/data/iterator_ops.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -39,10 +27,10 @@ class SparseTensorSliceDatasetOpTest : public DatasetOpsTestBase { // Creates a new SparseTensorSliceDataset op kernel. 
Status CreateSparseTensorSliceDatasetKernel( DataType tvalues, std::unique_ptr *op_kernel) { - node_def_ = test::function::NDef(kNodeName, kOpName, - {"indices", "values", "dense_shape"}, - {{"Tvalues", tvalues}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); + NodeDef node_def = test::function::NDef( + kNodeName, kOpName, {"indices", "values", "dense_shape"}, + {{"Tvalues", tvalues}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); return Status::OK(); } @@ -54,9 +42,6 @@ class SparseTensorSliceDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; struct SparseTensorParam { @@ -65,72 +50,108 @@ struct SparseTensorParam { Tensor dense_shape; }; -struct TestParam { +struct TestCase { SparseTensorParam input_sparse_tensor; std::vector expected_outputs; std::vector breakpoints; -} TestCases[] = { - {{{DatasetOpsTestBase::CreateTensor({2, 2}, {0, 0, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({2}, {888, 999})}, - {DatasetOpsTestBase::CreateTensor({2}, {2, 2})}}, - {{{DatasetOpsTestBase::CreateTensor({1, 1}, {0})}, - {DatasetOpsTestBase::CreateTensor({1}, {888})}, - {DatasetOpsTestBase::CreateTensor({1}, {2})}}, - {{DatasetOpsTestBase::CreateTensor({1, 1}, {1})}, - {DatasetOpsTestBase::CreateTensor({1}, {999})}, - {DatasetOpsTestBase::CreateTensor({1}, {2})}}}, - {0, 1, 2}}, // 2-D sparse tensor - {{{DatasetOpsTestBase::CreateTensor({2, 3}, {0, 0, 0, 1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({2}, {888.0, 999.0})}, - {DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}, - {{{DatasetOpsTestBase::CreateTensor({1, 2}, {0, 0})}, - {DatasetOpsTestBase::CreateTensor({1}, {888.0})}, - {DatasetOpsTestBase::CreateTensor({2}, {2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({1, 2}, {1, 1})}, - {DatasetOpsTestBase::CreateTensor({1}, {999.0})}, - {DatasetOpsTestBase::CreateTensor({2}, {2, 2})}}}, - {0, 1, 2}}, // 3-D sparse tensor - {{{DatasetOpsTestBase::CreateTensor({2, 4}, - {0, 0, 0, 0, 1, 1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({2}, {"a", "b"})}, - {DatasetOpsTestBase::CreateTensor({4}, {3, 2, 2, 2})}}, - {{{DatasetOpsTestBase::CreateTensor({1, 3}, {0, 0, 0})}, - {DatasetOpsTestBase::CreateTensor({1}, {"a"})}, - {DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({1, 3}, {1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({1}, {"b"})}, - {DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({0, 3}, {})}, - {DatasetOpsTestBase::CreateTensor({0}, {})}, - {DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}}, - {0, 1, 3}}, // 4-D sparse tensor - {{{DatasetOpsTestBase::CreateTensor({2, 5}, - {0, 0, 0, 0, 0, 1, 1, 1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({2}, {888, 999})}, - {DatasetOpsTestBase::CreateTensor({5}, {3, 2, 2, 2, 2})}}, - {{{DatasetOpsTestBase::CreateTensor({1, 4}, {0, 0, 0, 0})}, - {DatasetOpsTestBase::CreateTensor({1}, {888})}, - {DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({1, 4}, {1, 1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({1}, {999})}, - {DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({0, 4}, {})}, - {DatasetOpsTestBase::CreateTensor({0}, {})}, - {DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}}}, - {0, 1, 3}} // 5-D sparse tensor - }; -struct DatasetGetNextTest : SparseTensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; +TestCase 
TwoDimsTestCase() { + return { + /*input_sparse_tensor*/ + {/*indices*/ DatasetOpsTestBase::CreateTensor({2, 2}, + {0, 0, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({2}, {888, 999}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({2}, {2, 2})}, + /*expected_outputs*/ + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 1}, {0}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {888}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({1}, {2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({1, 1}, {1}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {999}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({1}, {2})}}, + /*breakpoints*/ {0, 1, 2}}; +} -TEST_P(DatasetGetNextTest, GetNext) { +TestCase ThreeDimsTestCase() { + return { + /*input_sparse_tensor*/ + {/*indices*/ DatasetOpsTestBase::CreateTensor({2, 3}, + {0, 0, 0, 1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({2}, {888.0, 999.0}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}, + /*expected_outputs*/ + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 2}, {0, 0}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {888.0}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({2}, {2, 2})}, + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 2}, {1, 1})}, + {/*values*/ DatasetOpsTestBase::CreateTensor({1}, {999.0})}, + {/*dense_shape*/ DatasetOpsTestBase::CreateTensor({2}, + {2, 2})}}}, + /*breakpoints*/ {0, 1, 2}}; +} + +TestCase FourDimsTestCase() { + return { + /*input_sparse_tensor*/ + {/*indices*/ DatasetOpsTestBase::CreateTensor( + {2, 4}, {0, 0, 0, 0, 1, 1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({2}, {"a", "b"}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({4}, {3, 2, 2, 2})}, + /*expected_outputs*/ + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 3}, {0, 0, 0}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {"a"}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({1, 3}, {1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {"b"}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({0, 3}, {}), + /*values*/ DatasetOpsTestBase::CreateTensor({0}, {}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}, + /*breakpoints*/ {0, 1, 3}}; +} + +TestCase FiveDimsTestCase() { + return {/*input_sparse_tensor*/ + {/*indices*/ DatasetOpsTestBase::CreateTensor( + {2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({2}, {888, 999}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({5}, {3, 2, 2, 2, 2})}, + /*expected_outputs*/ + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 4}, + {0, 0, 0, 0}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {888}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({1, 4}, + {1, 1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {999}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({0, 4}, {}), + /*values*/ DatasetOpsTestBase::CreateTensor({0}, {}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}}, + /*breakpoints*/ {0, 1, 3}}; +} + +class ParameterizedSparseTensorSliceDatasetOpTest + : public SparseTensorSliceDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, 
GetNext) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - SparseTensorParam input_sparse_tensor = GetParam().input_sparse_tensor; - std::vector expected_outputs = GetParam().expected_outputs; + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values, @@ -153,39 +174,62 @@ TEST_P(DatasetGetNextTest, GetNext) { dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); bool end_of_sequence = false; std::vector out_tensors; - int cur_slice = 0; + auto expected_outputs_it = expected_outputs.begin(); while (!end_of_sequence) { TF_EXPECT_OK( iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); if (!end_of_sequence) { + TF_EXPECT_OK(ExpectEqual(out_tensors[0], expected_outputs_it->indices)); + TF_EXPECT_OK(ExpectEqual(out_tensors[1], expected_outputs_it->values)); TF_EXPECT_OK( - ExpectEqual(out_tensors[0], expected_outputs[cur_slice].indices)); - TF_EXPECT_OK( - ExpectEqual(out_tensors[1], expected_outputs[cur_slice].values)); - TF_EXPECT_OK( - ExpectEqual(out_tensors[2], expected_outputs[cur_slice].dense_shape)); - cur_slice++; + ExpectEqual(out_tensors[2], expected_outputs_it->dense_shape)); + expected_outputs_it++; } } + EXPECT_EQ(expected_outputs_it, expected_outputs.end()); } -INSTANTIATE_TEST_CASE_P(SparseTensorSliceDatasetOpTest, DatasetGetNextTest, - ::testing::ValuesIn(TestCases)); - -TEST_F(SparseTensorSliceDatasetOpTest, DatasetName) { +TEST_F(SparseTensorSliceDatasetOpTest, DatasetNodeName) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - int N = 2; - const int NDIM = 2; - Tensor indices = CreateTensor(TensorShape({N, NDIM}), {0, 0, 1, 1}); - Tensor values = CreateTensor(TensorShape({N}), {888, 999}); - Tensor dense_shape = CreateTensor(TensorShape({NDIM}), {5, 5}); - gtl::InlinedVector inputs = {&indices, &values, &dense_shape}; + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; std::unique_ptr dataset_kernel; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); + + EXPECT_EQ(dataset->node_name(), kNodeName); +} + +TEST_F(SparseTensorSliceDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = 
input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; + + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); @@ -197,16 +241,14 @@ TEST_F(SparseTensorSliceDatasetOpTest, DatasetName) { EXPECT_EQ(dataset->type_string(), kOpName); } -struct DatasetOutputDtypesTest : SparseTensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetOutputDtypesTest, DatasetOutputDtypes) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - SparseTensorParam input_sparse_tensor = GetParam().input_sparse_tensor; - std::vector expected_outputs = GetParam().expected_outputs; + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values, @@ -229,19 +271,14 @@ TEST_P(DatasetOutputDtypesTest, DatasetOutputDtypes) { VerifyTypesMatch(dataset->output_dtypes(), expected_output_dtypes)); } -INSTANTIATE_TEST_CASE_P(SparseTensorDatasetSliceOpTest, DatasetOutputDtypesTest, - ::testing::ValuesIn(TestCases)); - -struct DatasetOutputShapesTest : SparseTensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetOutputShapesTest, DatasetOutputShapes) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - SparseTensorParam input_sparse_tensor = GetParam().input_sparse_tensor; - std::vector expected_outputs = GetParam().expected_outputs; + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values, @@ -264,19 +301,14 @@ TEST_P(DatasetOutputShapesTest, DatasetOutputShapes) { VerifyShapesCompatible(dataset->output_shapes(), expected_output_shapes)); } -INSTANTIATE_TEST_CASE_P(SparseTensorDatasetSliceOpTest, DatasetOutputShapesTest, - ::testing::ValuesIn(TestCases)); - -struct DatasetCardinalityTest : SparseTensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetCardinalityTest, Cardinality) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - SparseTensorParam input_sparse_tensor = GetParam().input_sparse_tensor; - std::vector expected_outputs = GetParam().expected_outputs; + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values, @@ -295,23 
+327,21 @@ TEST_P(DatasetCardinalityTest, Cardinality) { EXPECT_EQ(dataset->Cardinality(), expected_outputs.size()); } -INSTANTIATE_TEST_CASE_P(SparseTensorDatasetSliceOpTest, DatasetCardinalityTest, - ::testing::ValuesIn(TestCases)); - TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - int N = 2; - const int NDIM = 2; - Tensor indices = CreateTensor(TensorShape({N, NDIM}), {0, 0, 1, 1}); - Tensor values = CreateTensor(TensorShape({N}), {888, 999}); - Tensor dense_shape = CreateTensor(TensorShape({NDIM}), {5, 5}); - gtl::InlinedVector inputs = {&indices, &values, &dense_shape}; + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; std::unique_ptr dataset_kernel; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); @@ -328,16 +358,14 @@ TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) { TF_ASSERT_OK(writer.Flush()); } -struct IteratorOutputDtypesTest : SparseTensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorOutputDtypesTest, IteratorOutputDtypes) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - SparseTensorParam input_sparse_tensor = GetParam().input_sparse_tensor; - std::vector expected_outputs = GetParam().expected_outputs; + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values, @@ -365,20 +393,14 @@ TEST_P(IteratorOutputDtypesTest, IteratorOutputDtypes) { VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes)); } -INSTANTIATE_TEST_CASE_P(SparseTensorSliceDatasetOpTest, - IteratorOutputDtypesTest, - ::testing::ValuesIn(TestCases)); - -struct IteratorOutputShapesTest : SparseTensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorOutputShapesTest, IteratorOutputShapes) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - SparseTensorParam input_sparse_tensor = GetParam().input_sparse_tensor; - std::vector expected_outputs = GetParam().expected_outputs; + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values, @@ -406,24 +428,21 @@ TEST_P(IteratorOutputShapesTest, IteratorOutputShapes) { 
expected_output_shapes)); } -INSTANTIATE_TEST_CASE_P(SparseTensorSliceDatasetOpTest, - IteratorOutputShapesTest, - ::testing::ValuesIn(TestCases)); - TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - int N = 2; - const int NDIM = 2; - Tensor indices = CreateTensor(TensorShape({N, NDIM}), {0, 0, 1, 1}); - Tensor values = CreateTensor(TensorShape({N}), {888, 999}); - Tensor dense_shape = CreateTensor(TensorShape({NDIM}), {5, 5}); - gtl::InlinedVector inputs = {&indices, &values, &dense_shape}; + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; std::unique_ptr dataset_kernel; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); @@ -440,17 +459,15 @@ TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), strings::StrCat("Iterator::SparseTensorSlice")); } -struct IteratorRoundtripTest : SparseTensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorRoundtripTest, Roundtrip) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - SparseTensorParam input_sparse_tensor = GetParam().input_sparse_tensor; - std::vector expected_outputs = GetParam().expected_outputs; - std::vector breakpoints = GetParam().breakpoints; + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + std::vector breakpoints = test_case.breakpoints; DataType tvalues = input_sparse_tensor.values.dtype(); gtl::InlinedVector inputs = { &input_sparse_tensor.indices, &input_sparse_tensor.values, @@ -507,12 +524,16 @@ TEST_P(IteratorRoundtripTest, Roundtrip) { TF_ASSERT_OK(iterator->Save(serialization_ctx.get(), &writer)); TF_ASSERT_OK(writer.Flush()); VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_ctx.get(), &reader)); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *dataset, &iterator)); } } -INSTANTIATE_TEST_CASE_P(SparseTensorSliceDatasetOpTest, IteratorRoundtripTest, - ::testing::ValuesIn(TestCases)); +INSTANTIATE_TEST_SUITE_P(SparseTensorSliceDatasetOpTest, + ParameterizedSparseTensorSliceDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TwoDimsTestCase(), ThreeDimsTestCase(), + FourDimsTestCase(), FiveDimsTestCase()}))); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/stats_utils.cc b/tensorflow/core/kernels/data/stats_utils.cc index eefd92bc665..6dc82cc22d1 100644 --- a/tensorflow/core/kernels/data/stats_utils.cc +++ b/tensorflow/core/kernels/data/stats_utils.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/kernels/data/stats_utils.h" +#include "absl/base/attributes.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 0dd0c0c80de..2983ab51762 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/data/take_dataset_op.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" @@ -20,9 +21,6 @@ namespace tensorflow { namespace data { namespace { -// See documentation in ../../ops/dataset_ops.cc for a high-level -// description of the following op. - class TakeDatasetOp : public UnaryDatasetOpKernel { public: explicit TakeDatasetOp(OpKernelConstruction* ctx) @@ -34,168 +32,130 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { // Create a new TakeDatasetOp::Dataset, and return it as the output. int64 count; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "count", &count)); - *output = new Dataset(ctx, count, input); + *output = new TakeDataset(ctx, count, input); } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input) - : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) { - input_->Ref(); - } - - ~Dataset() override { input_->Unref(); } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - if (count_ == 0) { - return absl::make_unique(EmptyIterator::Params{ - this, strings::StrCat(prefix, "::EmptyTake")}); - } else { - return absl::make_unique(FiniteIterator::Params{ - this, strings::StrCat(prefix, "::FiniteTake")}); - } - } - - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); - } - - const std::vector& output_shapes() const override { - return input_->output_shapes(); - } - - string DebugString() const override { return "TakeDatasetOp::Dataset"; } - - int64 Cardinality() const override { - int64 n = input_->Cardinality(); - if (n == kUnknownCardinality) { - return kUnknownCardinality; - } - if (n == kInfiniteCardinality) { - return count_; - } - return std::min(n, count_); - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - Node* count = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, count}, output)); - return Status::OK(); - } - - private: - class EmptyIterator : public DatasetIterator { - public: - explicit EmptyIterator(const Params& params) - : DatasetIterator(params) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - *end_of_sequence = true; - return Status::OK(); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return 
model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - return Status::OK(); - } - }; - - class FiniteIterator : public DatasetIterator { - public: - explicit FiniteIterator(const Params& params) - : DatasetIterator(params), i_(0) {} - - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. - if (!input_impl_) { - *end_of_sequence = true; - return Status::OK(); - } - while (dataset()->count_ < 0 || i_ < dataset()->count_) { - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (!*end_of_sequence) { - ++i_; - return Status::OK(); - } - break; - } - *end_of_sequence = true; - input_impl_.reset(); - return Status::OK(); - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); - if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - } else { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("input_impl_empty"), "")); - } - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); - if (!reader->Contains(full_name("input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - } else { - input_impl_.reset(); - } - return Status::OK(); - } - - private: - mutex mu_; - int64 i_ GUARDED_BY(mu_); - std::unique_ptr input_impl_ GUARDED_BY(mu_); - }; - - const int64 count_; - const DatasetBase* const input_; - }; }; REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp); - } // namespace + +class TakeDataset::EmptyIterator : public DatasetIterator { + public: + explicit EmptyIterator(const Params& params) + : DatasetIterator(params) {} + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + *end_of_sequence = true; + return Status::OK(); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return Status::OK(); + } +}; + +class TakeDataset::FiniteIterator : public DatasetIterator { + public: + explicit FiniteIterator(const Params& params) + : DatasetIterator(params), i_(0) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. 
+ if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + while (dataset()->count_ < 0 || i_ < dataset()->count_) { + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (!*end_of_sequence) { + ++i_; + return Status::OK(); + } + break; + } + *end_of_sequence = true; + input_impl_.reset(); + return Status::OK(); + } + + protected: + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } + + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); + if (input_impl_) { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } + return Status::OK(); + } + + private: + mutex mu_; + int64 i_ GUARDED_BY(mu_); + std::unique_ptr input_impl_ GUARDED_BY(mu_); +}; + +// See documentation in ../../ops/dataset_ops.cc for a high-level +// description of the following op. +std::unique_ptr TakeDataset::MakeIteratorInternal( + const string& prefix) const { + if (count_ == 0) { + return absl::make_unique( + EmptyIterator::Params{this, strings::StrCat(prefix, "::EmptyTake")}); + } else { + return absl::make_unique( + FiniteIterator::Params{this, strings::StrCat(prefix, "::FiniteTake")}); + } +} + +Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* count = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output)); + return Status::OK(); +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/take_dataset_op.h b/tensorflow/core/kernels/data/take_dataset_op.h new file mode 100644 index 00000000000..e35a26bfff4 --- /dev/null +++ b/tensorflow/core/kernels/data/take_dataset_op.h @@ -0,0 +1,81 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { + +class TakeDataset : public DatasetBase { + public: + TakeDataset(OpKernelContext* ctx, int64 count, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) { + input_->Ref(); + } + + TakeDataset(DatasetContext::Params params, int64 count, + const DatasetBase* input) + : DatasetBase(DatasetContext(std::move(params))), + count_(count), + input_(input) { + input_->Ref(); + } + + ~TakeDataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override; + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { return "TakeDatasetOp::Dataset"; } + + int64 Cardinality() const override { + int64 n = input_->Cardinality(); + if (n == kUnknownCardinality) { + return kUnknownCardinality; + } + if (n == kInfiniteCardinality) { + return count_; + } + return std::min(n, count_); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override; + + private: + class EmptyIterator; + class FiniteIterator; + const int64 count_; + const DatasetBase* const input_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/take_dataset_op_test.cc b/tensorflow/core/kernels/data/take_dataset_op_test.cc new file mode 100644 index 00000000000..1b8051df1e9 --- /dev/null +++ b/tensorflow/core/kernels/data/take_dataset_op_test.cc @@ -0,0 +1,583 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_test_base.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "take_dataset"; +constexpr char kOpName[] = "TakeDataset"; + +class TakeDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates `TensorSliceDataset` variant tensor from the input vector of + // tensors. + Status CreateTensorSliceDatasetTensor( + std::vector *const tensor_vector, Tensor *dataset_tensor) { + DatasetBase *tensor_slice_dataset; + TF_RETURN_IF_ERROR(CreateTensorSliceDataset( + "tensor_slice_node", tensor_vector, &tensor_slice_dataset)); + TF_RETURN_IF_ERROR( + StoreDatasetInVariantTensor(tensor_slice_dataset, dataset_tensor)); + return Status::OK(); + } + + // Create a new `TakeDataset` op kernel. 
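Before the kernel-construction helper that follows, here is a small standalone sketch of the `Cardinality()` contract declared in the new `take_dataset_op.h` above. The function name and the sentinel constants are assumptions of the sketch (TensorFlow defines its own sentinels in `dataset.h`); the branching mirrors the method body in the header.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Assumed stand-ins for the sentinels used by DatasetBase::Cardinality().
constexpr int64_t kInfiniteCardinality = -1;
constexpr int64_t kUnknownCardinality = -2;

// Mirrors TakeDataset::Cardinality(): unknown stays unknown, an infinite
// input is clipped to `count`, and a finite input yields min(n, count).
int64_t TakeCardinality(int64_t input_cardinality, int64_t count) {
  if (input_cardinality == kUnknownCardinality) return kUnknownCardinality;
  if (input_cardinality == kInfiniteCardinality) return count;
  return std::min(input_cardinality, count);
}

int main() {
  assert(TakeCardinality(10, 4) == 4);    // take fewer than available
  assert(TakeCardinality(10, 25) == 10);  // take more than available
  assert(TakeCardinality(10, 0) == 0);    // take nothing
  // A negative count means "take everything"; min(10, -1) evaluates to -1,
  // which is exactly the value the take-all test case below expects.
  assert(TakeCardinality(10, -1) == -1);
  assert(TakeCardinality(kUnknownCardinality, 4) == kUnknownCardinality);
  return 0;
}
```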
+  Status CreateTakeDatasetKernel(
+      const DataTypeVector &output_types,
+      const std::vector<PartialTensorShape> &output_shapes,
+      std::unique_ptr<OpKernel> *op_kernel) {
+    NodeDef node_def = test::function::NDef(
+        kNodeName, kOpName, {"input_dataset", "count"},
+        {{"output_types", output_types}, {"output_shapes", output_shapes}});
+    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
+    return Status::OK();
+  }
+
+  // Create a new `TakeDataset` op kernel context.
+  Status CreateTakeDatasetContext(
+      OpKernel *op_kernel, gtl::InlinedVector<TensorValue, 4> *const inputs,
+      std::unique_ptr<OpKernelContext> *context) {
+    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
+    TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
+    return Status::OK();
+  }
+};
+
+struct TestCase {
+  std::vector<Tensor> input_tensors;
+  int64 count;
+  std::vector<Tensor> expected_outputs;
+  DataTypeVector expected_output_dtypes;
+  std::vector<PartialTensorShape> expected_output_shapes;
+  int64 expected_cardinality;
+  std::vector<int> breakpoints;
+};
+
+// Test case 1: take fewer than input size.
+TestCase TakeLessTestCase() {
+  return {/*input_tensors*/
+          {DatasetOpsTestBase::CreateTensor<int64>(
+              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+          /*count*/ 4,
+          /*expected_outputs*/
+          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({1})},
+          /*expected_cardinality*/ 4,
+          /*breakpoints*/ {0, 2, 5}};
+}
+
+// Test case 2: take more than input size.
+TestCase TakeMoreTestCase() {
+  return {/*input_tensors*/
+          {DatasetOpsTestBase::CreateTensor<int64>(
+              TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
+          /*count*/ 25,
+          /*expected_outputs*/
+          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {0}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {1}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {2}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {3}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {4}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {5}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {6}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {7}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {8}),
+           DatasetOpsTestBase::CreateTensor<int64>(TensorShape{1}, {9})},
+          /*expected_output_dtypes*/ {DT_INT64},
+          /*expected_output_shapes*/ {PartialTensorShape({1})},
+          /*expected_cardinality*/ 10,
+          /*breakpoints*/ {0, 2, 5, 11}};
+}
+
+// Test case 3: take all of input.
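Before the remaining take-all and take-nothing cases, a quick standalone sketch of how many elements each `count` should yield from the ten-row input shared by all four cases. The helper name is invented for illustration; the branching follows the EmptyIterator/FiniteIterator logic from `take_dataset_op.cc` earlier in this patch. Note that for a negative count this differs from `Cardinality()`, which reports -1 even though iteration still produces every row.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// count == 0 selects the EmptyIterator; a negative count lets FiniteIterator
// run until the input is exhausted; otherwise iteration stops after `count`
// elements or when the input runs out, whichever comes first.
int64_t ElementsYielded(int64_t input_size, int64_t count) {
  if (count == 0) return 0;
  if (count < 0) return input_size;
  return std::min(input_size, count);
}

int main() {
  assert(ElementsYielded(10, 4) == 4);    // TakeLessTestCase
  assert(ElementsYielded(10, 25) == 10);  // TakeMoreTestCase
  assert(ElementsYielded(10, -1) == 10);  // TakeAllTestCase: all ten rows
  assert(ElementsYielded(10, 0) == 0);    // TakeNothingTestCase
  return 0;
}
```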
+TestCase TakeAllTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, + /*count*/ -1, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape{1}, {0}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {1}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {2}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {3}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {4}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {5}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {6}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {7}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {8}), + DatasetOpsTestBase::CreateTensor(TensorShape{1}, {9})}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ -1, + /*breakpoints*/ {0, 2, 5, 11}}; +} + +// Test case 4: take nothing. +TestCase TakeNothingTestCase() { + return {/*input_tensors*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape{10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}, + /*count*/ 0, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({1})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 2, 5, 11}}; +} + +class ParameterizedTakeDatasetOpTest + : public TakeDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedTakeDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(take_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + take_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + auto expected_outputs_it = test_case.expected_outputs.begin(); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + for (const auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); +} + +TEST_F(TakeDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + 
TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = TakeLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + EXPECT_EQ(take_dataset->node_name(), kNodeName); +} + +TEST_F(TakeDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = TakeLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + EXPECT_EQ(take_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + 
&take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + TF_EXPECT_OK(VerifyTypesMatch(take_dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedTakeDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(take_dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + EXPECT_EQ(take_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_F(TakeDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = TakeLessTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector 
inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(take_dataset->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(take_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + take_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + 
&inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(take_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + take_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(take_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + take_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + if (test_case.count == 0) { + EXPECT_EQ(iterator->prefix(), "Iterator::EmptyTake"); + } else { + EXPECT_EQ(iterator->prefix(), "Iterator::FiniteTake"); + } +} + +TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + auto expected_outputs_it = test_case.expected_outputs.begin(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_take_dataset; + inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_take_dataset.emplace_back(&count); + + std::unique_ptr take_dataset_kernel; + TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &take_dataset_kernel)); + std::unique_ptr take_dataset_context; + TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(), + &inputs_for_take_dataset, + &take_dataset_context)); + DatasetBase *take_dataset; + TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(), + take_dataset_context.get(), &take_dataset)); + core::ScopedUnref scoped_unref(take_dataset); + + std::unique_ptr 
iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(take_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + take_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *take_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + for (auto &tensor : out_tensors) { + EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(tensor, *expected_outputs_it)); + expected_outputs_it++; + } + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_outputs.size()) { + EXPECT_TRUE(end_of_sequence); + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); + } else { + EXPECT_FALSE(end_of_sequence); + } + } +} + +INSTANTIATE_TEST_SUITE_P(TakeDatasetOpTest, ParameterizedTakeDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TakeLessTestCase(), TakeMoreTestCase(), + TakeAllTestCase(), TakeNothingTestCase()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index a44dbd0d4d4..04698751f80 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" namespace tensorflow { namespace data { @@ -26,15 +27,20 @@ namespace { class TensorDatasetOp : public DatasetOpKernel { public: - explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} + explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutput_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs)); - // TODO(mrry): Validate that the shapes of the "components" tensors match - // the "shapes" attr.; std::vector components(inputs.begin(), inputs.end()); *output = new Dataset(ctx, std::move(components)); + OP_REQUIRES_OK(ctx, + VerifyTypesMatch((*output)->output_dtypes(), output_types_)); + OP_REQUIRES_OK(ctx, VerifyShapesCompatible((*output)->output_shapes(), + output_shapes_)); } private: @@ -137,6 +143,9 @@ class TensorDatasetOp : public DatasetOpKernel { DataTypeVector dtypes_; std::vector shapes_; }; + + DataTypeVector output_types_; + std::vector output_shapes_; }; REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/tensor_dataset_op_test.cc b/tensorflow/core/kernels/data/tensor_dataset_op_test.cc new file mode 100644 index 00000000000..6232a7fa64a --- /dev/null +++ b/tensorflow/core/kernels/data/tensor_dataset_op_test.cc @@ -0,0 +1,536 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/kernels/data/dataset_test_base.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace data { +namespace { + +constexpr char kNodeName[] = "tensor_dataset"; +constexpr char kOpName[] = "TensorDataset"; + +class TensorDatasetOpTest : public DatasetOpsTestBase { + protected: + // Creates a new TensorDataset op kernel. 
+ Status CreateTensorDatasetKernel( + DataTypeVector dtypes, std::vector shapes, + std::unique_ptr *tensor_dataset_kernel) { + std::vector components; + components.reserve(dtypes.size()); + for (int i = 0; i < dtypes.size(); i++) { + components.emplace_back(strings::StrCat("component_", i)); + } + node_def_ = test::function::NDef( + kNodeName, kOpName, components, + {{"Toutput_types", dtypes}, {"output_shapes", shapes}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel)); + return Status::OK(); + } + + // Creates a new TensorDataset op kernel context. + Status CreateTensorDatasetContext(OpKernel *const tensor_dataset_kernel, + gtl::InlinedVector *inputs, + std::unique_ptr *context) { + TF_RETURN_IF_ERROR(CheckOpKernelInput(*tensor_dataset_kernel, *inputs)); + TF_RETURN_IF_ERROR( + CreateOpKernelContext(tensor_dataset_kernel, inputs, context)); + return Status::OK(); + } + + private: + NodeDef node_def_; +}; + +struct TestCase { + std::vector components; + std::vector expected_outputs; + DataTypeVector expected_output_dtypes; + std::vector expected_output_shapes; + int64 expected_cardinality; + std::vector breakpoints; +}; + +// Test case 1: test a dataset that represents a single tuple of plain tensors. +TestCase PlainTensorsTestCase() { + return { + /*components*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({1, 3}), {1, 2, 3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {37.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({1, 2}), + {"a", "b"})}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({1, 3}), {1, 2, 3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {37.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({1, 2}), + {"a", "b"})}, + /*expected_output_dtypes*/ + {DT_INT64, DT_INT64, DT_DOUBLE, DT_STRING}, + /*expected_output_shapes*/ + {PartialTensorShape({}), PartialTensorShape({1, 3}), + PartialTensorShape({}), PartialTensorShape({1, 2})}, + /*expected_cardinality*/ 1, + /*breakpoints*/ {0, 1, 2}}; +} + +// Test case 2: test a dataset that represents a tuple of nested tensors. 
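`PlainTensorsTestCase` above expects a cardinality of 1 because `TensorDataset` emits the whole component tuple as a single element, whereas the `TensorSliceDataset` tests later in this patch slice every component along dimension 0. A small standalone sketch of that distinction, with invented helper names and shapes reduced to their dimension lists:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// A shape as a list of dimension sizes; {} denotes a scalar.
using Shape = std::vector<int64_t>;

// TensorDataset: the whole tuple of components is one element.
int64_t FromTensorCardinality(const std::vector<Shape>& /*components*/) {
  return 1;
}

// TensorSliceDataset: components are sliced along dimension 0, so the number
// of elements is the shared leading dimension.
int64_t FromTensorSlicesCardinality(const std::vector<Shape>& components) {
  return (components.empty() || components[0].empty()) ? 0 : components[0][0];
}

int main() {
  // Shapes from PlainTensorsTestCase: {}, {1,3}, {}, {1,2} -> one element.
  assert(FromTensorCardinality({{}, {1, 3}, {}, {1, 2}}) == 1);
  // Shapes from the TensorSliceDataset test below: {2}, {2,2}, {2,1}, {2,1}
  // -> two elements, each holding one row of every component.
  assert(FromTensorSlicesCardinality({{2}, {2, 2}, {2, 1}, {2, 1}}) == 2);
  return 0;
}
```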
+TestCase NestedTensorsTestCase() { + return { + /*components*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape({}), {DatasetOpsTestBase::CreateTensor( + TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}), + DatasetOpsTestBase::CreateTensor( + TensorShape({}), {DatasetOpsTestBase::CreateTensor( + TensorShape({1, 2}), {"a", "b"})}), + DatasetOpsTestBase::CreateTensor(TensorShape({1, 3}), {1, 2, 3})}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape({}), {DatasetOpsTestBase::CreateTensor( + TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}), + DatasetOpsTestBase::CreateTensor( + TensorShape({}), {DatasetOpsTestBase::CreateTensor( + TensorShape({1, 2}), {"a", "b"})}), + DatasetOpsTestBase::CreateTensor(TensorShape({1, 3}), {1, 2, 3})}, + /*expected_output_dtypes*/ + {DT_VARIANT, DT_VARIANT, DT_INT64}, + /*expected_output_shapes*/ + {PartialTensorShape({}), PartialTensorShape({}), + PartialTensorShape({1, 3})}, + /*expected_cardinality*/ 1, + /*breakpoints*/ {0, 1, 2}}; +} + +class ParametrizedTensorDatasetOpTest + : public TensorDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParametrizedTensorDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK( + CreateIteratorContext(tensor_dataset_context.get(), &iterator_context)); + std::unique_ptr iterator; + TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator", + &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + while (!end_of_sequence) { + TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, + &end_of_sequence)); + } + EXPECT_EQ(out_tensors.size(), test_case.expected_outputs.size()); + for (int i = 0; i < out_tensors.size(); ++i) { + if (out_tensors[i].dtype() == DT_VARIANT) { + // Currently `ExpectEqual()` does not support the variant tensor + // yet, so we manually cast the variant to numeric/string tensor. 
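Spelled out with its template arguments, the manual cast described in the comment above amounts to the helper below. This is only a sketch: it assumes the TensorFlow headers (so it builds only inside the TF tree), and the wrapper-function name is invented for illustration.

```cpp
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"

namespace tensorflow {

// Unwraps a Tensor stored inside a scalar DT_VARIANT tensor, as the
// nested-tensor test cases do. Returns nullptr if the variant does not
// actually hold a Tensor.
const Tensor* UnwrapVariantTensor(const Tensor& wrapper) {
  // scalar<Variant>()() reads the single Variant element; get<Tensor>() then
  // extracts the stored Tensor (or nullptr on a type mismatch).
  return wrapper.scalar<Variant>()().get<Tensor>();
}

}  // namespace tensorflow
```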
+ const Tensor *output = out_tensors[i].scalar()().get(); + const Tensor *expected_output = + test_case.expected_outputs[i].scalar()().get(); + TF_EXPECT_OK(ExpectEqual(*output, *expected_output)); + } else { + TF_EXPECT_OK(ExpectEqual(out_tensors[i], test_case.expected_outputs[i])); + } + } +} + +TEST_F(TensorDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PlainTensorsTestCase(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + EXPECT_EQ(tensor_dataset->type_string(), kOpName); +} + +TEST_F(TensorDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PlainTensorsTestCase(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + EXPECT_EQ(tensor_dataset->node_name(), kNodeName); +} + +TEST_F(TensorDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PlainTensorsTestCase(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + EXPECT_EQ(tensor_dataset->output_dtypes(), test_case.expected_output_dtypes); +} + +TEST_F(TensorDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PlainTensorsTestCase(); + std::vector components = test_case.components; + 
gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + EXPECT_EQ(tensor_dataset->output_shapes().size(), + test_case.expected_output_shapes.size()); + for (int i = 0; i < test_case.expected_output_shapes.size(); i++) { + EXPECT_TRUE(test_case.expected_output_shapes[i].IsIdenticalTo( + tensor_dataset->output_shapes()[i])); + } +} + +TEST_F(TensorDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PlainTensorsTestCase(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + EXPECT_EQ(tensor_dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_P(ParametrizedTensorDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + std::unique_ptr serialization_context; + TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(tensor_dataset->Save(serialization_context.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + 
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK( + CreateIteratorContext(tensor_dataset_context.get(), &iterator_context)); + std::unique_ptr iterator; + TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator", + &iterator)); + EXPECT_EQ(iterator->output_dtypes(), test_case.expected_output_dtypes); +} + +TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK( + CreateIteratorContext(tensor_dataset_context.get(), &iterator_context)); + std::unique_ptr iterator; + TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator", + &iterator)); + + EXPECT_EQ(iterator->output_shapes().size(), + test_case.expected_output_shapes.size()); + for (int i = 0; i < test_case.expected_output_shapes.size(); ++i) { + EXPECT_TRUE(test_case.expected_output_shapes[i].IsIdenticalTo( + iterator->output_shapes()[i])); + } +} + +TEST_F(TensorDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PlainTensorsTestCase(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK( + CreateIteratorContext(tensor_dataset_context.get(), &iterator_context)); + std::unique_ptr iterator; + TF_ASSERT_OK(tensor_dataset->MakeIterator(iterator_context.get(), "Iterator", + &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::FromTensor"); +} + +TEST_P(ParametrizedTensorDatasetOpTest, Roundtrip) { + int 
thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + std::vector components = test_case.components; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.push_back(&component); + } + std::unique_ptr tensor_dataset_kernel; + TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &tensor_dataset_kernel)); + std::unique_ptr tensor_dataset_context; + TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs, + &tensor_dataset_context)); + DatasetBase *tensor_dataset; + TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(), + tensor_dataset_context.get(), &tensor_dataset)); + core::ScopedUnref scoped_unref(tensor_dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(tensor_dataset_context.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + tensor_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + bool end_of_sequence = false; + std::vector out_tensors; + int cur_iteration = 0; + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *tensor_dataset, &iterator)); + + while (cur_iteration <= breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + EXPECT_EQ(out_tensors.size(), test_case.expected_outputs.size()); + for (int i = 0; i < out_tensors.size(); ++i) { + if (out_tensors[i].dtype() == DT_VARIANT) { + // Currently `ExpectEqual()` does not support the variant tensor + // yet, so we manually cast the variant to numeric/string tensor. 
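The variant unwrapping recurs here; the more interesting shared structure is the breakpoint-driven save/restore loop that all the Roundtrip tests in this patch follow: serialize the iterator at each breakpoint, restore it, keep pulling elements, and check end-of-sequence against the expected element count. Below is a toy, TensorFlow-free model of that loop, with a plain counter standing in for the iterator state that `VariantTensorDataWriter`/`Reader` and `RestoreIterator` move around in the real tests; all names are invented for the sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for an iterator whose only state is how many elements it has
// produced from a fixed-size input.
struct ToyIterator {
  int64_t next;
  int64_t size;
  void GetNext(int64_t* out, bool* end_of_sequence) {
    if (next >= size) {
      *end_of_sequence = true;
      return;
    }
    *out = next++;
    *end_of_sequence = false;
  }
};

int main() {
  const int64_t kInputSize = 4;
  const std::vector<int> breakpoints = {0, 2, 5};

  ToyIterator iterator{0, kInputSize};
  int64_t cur_iteration = 0;
  int64_t expected_value = 0;
  bool end_of_sequence = false;

  for (int breakpoint : breakpoints) {
    // "Save" the position and "restore" it into a fresh iterator, standing in
    // for the serialize-then-RestoreIterator round trip in the real test.
    const int64_t checkpoint = iterator.next;
    iterator = ToyIterator{checkpoint, kInputSize};

    while (cur_iteration <= breakpoint) {
      int64_t value = 0;
      iterator.GetNext(&value, &end_of_sequence);
      if (!end_of_sequence) {
        assert(value == expected_value);  // element order survives the restore
        ++expected_value;
      }
      ++cur_iteration;
    }

    // Same post-condition the tests check: past the last element we must have
    // hit end-of-sequence, before it we must not.
    if (breakpoint >= kInputSize) {
      assert(end_of_sequence);
    } else {
      assert(!end_of_sequence);
    }
  }
  return 0;
}
```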
+ const Tensor *output = + out_tensors[i].scalar()().get(); + const Tensor *expected_output = + test_case.expected_outputs[i].scalar()().get(); + TF_EXPECT_OK(ExpectEqual(*output, *expected_output)); + } else { + TF_EXPECT_OK( + ExpectEqual(out_tensors[i], test_case.expected_outputs[i])); + } + } + } + cur_iteration++; + } + + if (breakpoint >= test_case.expected_cardinality) { + EXPECT_TRUE(end_of_sequence); + } else { + EXPECT_FALSE(end_of_sequence); + } + } +} + +INSTANTIATE_TEST_CASE_P( + TensorDatasetOpTest, ParametrizedTensorDatasetOpTest, + ::testing::ValuesIn(std::vector({PlainTensorsTestCase(), + NestedTensorsTestCase()}))); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index 97a1ec402f2..bae1530b6cb 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -129,23 +129,28 @@ class TensorSliceDatasetOp : public DatasetOpKernel { Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { - mutex_lock l(mu_); - if (i_ < n_) { - out_tensors->clear(); - out_tensors->reserve(dataset()->tensors_.size()); - for (int i = 0; i < dataset()->tensors_.size(); ++i) { - const Tensor& t = dataset()->tensors_[i]; - out_tensors->emplace_back( - ctx->allocator({}), t.dtype(), - TensorShape(dataset()->shapes_[i].dim_sizes())); - TF_RETURN_IF_ERROR( - batch_util::CopySliceToElement(t, &out_tensors->back(), i_)); + int64 index = 0; + { + mutex_lock l(mu_); + if (i_ < n_) { + index = i_; + ++i_; + } else { + *end_of_sequence = true; + return Status::OK(); } - ++i_; - *end_of_sequence = false; - } else { - *end_of_sequence = true; } + out_tensors->clear(); + out_tensors->reserve(dataset()->tensors_.size()); + for (int i = 0; i < dataset()->tensors_.size(); ++i) { + const Tensor& t = dataset()->tensors_[i]; + out_tensors->emplace_back( + ctx->allocator({}), t.dtype(), + TensorShape(dataset()->shapes_[i].dim_sizes())); + TF_RETURN_IF_ERROR( + batch_util::CopySliceToElement(t, &out_tensors->back(), index)); + } + *end_of_sequence = false; return Status::OK(); } diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc index 591ef2f011b..ee2619663f0 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc @@ -46,10 +46,10 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase { components.emplace_back(strings::StrCat("component_", i)); } - node_def_ = test::function::NDef( + NodeDef node_def = test::function::NDef( kNodeName, kOpName, components, {{"Toutput_types", dtypes}, {"output_shapes", shapes}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_dataset_kernel)); return Status::OK(); } @@ -63,36 +63,39 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase { CreateOpKernelContext(tensor_dataset_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; -struct TestParam { +struct TestCase { std::vector components; std::vector expected_outputs; std::vector breakpoints; -} TestCases[] = { - // A single tuple of tensors. 
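A further aside on the `GetNextInternal()` rewrite in tensor_slice_dataset_op.cc above: the new version only claims the next slice index while holding `mu_`, then performs the per-slice `CopySliceToElement()` work after the lock is released, so concurrent callers no longer serialize on the copy. A self-contained sketch of that locking pattern, using hypothetical names of my own rather than TensorFlow types:

#include <cstddef>
#include <mutex>
#include <utility>
#include <vector>

// Hands out the elements of `data_` one at a time, in order, from any thread.
class SliceProducer {
 public:
  explicit SliceProducer(std::vector<int> data) : data_(std::move(data)) {}

  // Returns false once every element has been handed out.
  bool GetNext(int* out) {
    size_t index;
    {
      // Critical section: only the index reservation happens under the lock.
      std::lock_guard<std::mutex> lock(mu_);
      if (next_ >= data_.size()) return false;  // end of sequence
      index = next_++;
    }
    // The (potentially expensive) copy happens outside the critical section.
    *out = data_[index];
    return true;
  }

 private:
  std::mutex mu_;
  size_t next_ = 0;
  const std::vector<int> data_;
};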
- {{{DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), - DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), - {1, 2, 3, 4}), - DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), - {37.0, 38.0}), - DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), - {"a", "b"})}}, // components - {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), - DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), - DatasetOpsTestBase::CreateTensor(TensorShape({1}), {37.0}), - DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"a"}), - DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), - DatasetOpsTestBase::CreateTensor(TensorShape({2}), {3, 4}), - DatasetOpsTestBase::CreateTensor(TensorShape({1}), {38.0}), - DatasetOpsTestBase::CreateTensor(TensorShape({1}), - {"b"})}}, // expected_outputs - {{0, 1, 3}}}, // breakpoints - // Nested tensors - {{{DatasetOpsTestBase::CreateTensor( +}; + +TestCase PlainTensorTestCase() { + return {/*components*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), + {1, 2, 3, 4}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), + {37.0, 38.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), + {"a", "b"})}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {37.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"a"}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {3, 4}), + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {38.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"b"})}, + /*breakpoints*/ {0, 1, 3}}; +} + +TestCase NestedTensorTestCase() { + return { + /*components*/ + {DatasetOpsTestBase::CreateTensor( TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0}), @@ -103,9 +106,10 @@ struct TestParam { TensorShape({1, 2}), {"a", "b"}), DatasetOpsTestBase::CreateTensor( TensorShape({1, 2}), {"c", "d"})}), - DatasetOpsTestBase::CreateTensor( - TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})}}, // components - {{DatasetOpsTestBase::CreateTensor( + DatasetOpsTestBase::CreateTensor(TensorShape({2, 3}), + {1, 2, 3, 4, 5, 6})}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor( TensorShape({1}), {DatasetOpsTestBase::CreateTensor( TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}), DatasetOpsTestBase::CreateTensor( @@ -118,34 +122,34 @@ struct TestParam { DatasetOpsTestBase::CreateTensor( TensorShape({1}), {DatasetOpsTestBase::CreateTensor( TensorShape({1, 2}), {"c", "d"})}), - DatasetOpsTestBase::CreateTensor( - TensorShape({3}), {4, 5, 6})}}, // expected_outputs - {{0, 1, 2}}} // breakpoints -}; + DatasetOpsTestBase::CreateTensor(TensorShape({3}), {4, 5, 6})}, + /*breakpoints*/ {0, 1, 2}}; +} -struct DatasetGetNextTest : TensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; +class ParameterizedTensorSliceDatasetOpTest + : public TensorSliceDatasetOpTest, + public ::testing::WithParamInterface {}; -TEST_P(DatasetGetNextTest, GetNext) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; - std::vector components = GetParam().components; - std::vector expected_outputs = GetParam().expected_outputs; - size_t num_tensors_per_slice = components.size(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase 
&test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; DataTypeVector dtypes; - std::vector shapes; gtl::InlinedVector inputs; for (auto &component : components) { - inputs.push_back(&component); - dtypes.push_back(component.dtype()); + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); for (int i = 0; i < num_tensors_per_slice; ++i) { shapes.emplace_back(expected_outputs[i].shape()); } - std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -157,7 +161,7 @@ TEST_P(DatasetGetNextTest, GetNext) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); std::unique_ptr iterator_context; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), @@ -194,59 +198,26 @@ TEST_P(DatasetGetNextTest, GetNext) { } } -INSTANTIATE_TEST_CASE_P(TensorDatasetSliceOpTest, DatasetGetNextTest, - ::testing::ValuesIn(TestCases)); - -TEST_F(TensorSliceDatasetOpTest, DatasetName) { +TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - Tensor t1 = CreateTensor(TensorShape({2, 2}), {1, 2, 3, 4}); - Tensor t2 = CreateTensor(TensorShape({2, 2}), {5, 6, 7, 8}); - gtl::InlinedVector inputs = {&t1, &t2}; - DataTypeVector dtypes({DT_INT64, DT_INT64}); - std::vector shapes = {PartialTensorShape({2}), - PartialTensorShape({2})}; - std::unique_ptr tensor_slice_dataset_kernel; - TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, - &tensor_slice_dataset_kernel)); - std::unique_ptr tensor_slice_dataset_context; - TF_ASSERT_OK( - CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), - &inputs, &tensor_slice_dataset_context)); - DatasetBase *tensor_slice_dataset; - TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), - tensor_slice_dataset_context.get(), - &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); - - EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName); -} - -struct DatasetOutputDtypesTest : TensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetOutputDtypesTest, DatasetOutputDtypes) { - int thread_num = 2, cpu_num = 2; - std::vector components = GetParam().components; - std::vector expected_outputs = GetParam().expected_outputs; - size_t num_tensors_per_slice = components.size(); - - TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - + const TestCase &test_case = PlainTensorTestCase(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; DataTypeVector dtypes; - std::vector shapes; gtl::InlinedVector inputs; for (auto &component : components) { inputs.emplace_back(&component); dtypes.emplace_back(component.dtype()); } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); for (int i = 0; i < num_tensors_per_slice; ++i) { shapes.emplace_back(expected_outputs[i].shape()); } - std::unique_ptr tensor_slice_dataset_kernel; 
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -258,7 +229,79 @@ TEST_P(DatasetOutputDtypesTest, DatasetOutputDtypes) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + EXPECT_EQ(tensor_slice_dataset->node_name(), kNodeName); +} + +TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PlainTensorTestCase(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); const DataTypeVector produced_output_dtypes = tensor_slice_dataset->output_dtypes(); @@ -268,28 +311,23 @@ TEST_P(DatasetOutputDtypesTest, DatasetOutputDtypes) { } } -INSTANTIATE_TEST_CASE_P(TensorDatasetSliceOpTest, DatasetOutputDtypesTest, - ::testing::ValuesIn(TestCases)); - -struct DatasetOutputShapesTest : TensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetOutputShapesTest, 
DatasetOutputShapes) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; - std::vector components = GetParam().components; - std::vector expected_outputs = GetParam().expected_outputs; - size_t num_tensors_per_slice = components.size(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; DataTypeVector dtypes; - std::vector shapes; gtl::InlinedVector inputs; for (auto &component : components) { inputs.emplace_back(&component); dtypes.emplace_back(component.dtype()); } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); for (int i = 0; i < num_tensors_per_slice; ++i) { shapes.emplace_back(expected_outputs[i].shape()); } @@ -304,7 +342,7 @@ TEST_P(DatasetOutputShapesTest, DatasetOutputShapes) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); const std::vector produced_output_shapes = tensor_slice_dataset->output_shapes(); @@ -316,28 +354,23 @@ TEST_P(DatasetOutputShapesTest, DatasetOutputShapes) { } } -INSTANTIATE_TEST_CASE_P(TensorDatasetSliceOpTest, DatasetOutputShapesTest, - ::testing::ValuesIn(TestCases)); - -struct DatasetCardinalityTest : TensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(DatasetCardinalityTest, Cardinality) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; - std::vector components = GetParam().components; - std::vector expected_outputs = GetParam().expected_outputs; - size_t num_tensors_per_slice = components.size(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; DataTypeVector dtypes; - std::vector shapes; gtl::InlinedVector inputs; for (auto &component : components) { inputs.emplace_back(&component); dtypes.emplace_back(component.dtype()); } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); for (int i = 0; i < num_tensors_per_slice; ++i) { shapes.emplace_back(expected_outputs[i].shape()); } @@ -352,25 +385,31 @@ TEST_P(DatasetCardinalityTest, Cardinality) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0)); } -INSTANTIATE_TEST_CASE_P(TensorDatasetSliceOpTest, DatasetCardinalityTest, - ::testing::ValuesIn(TestCases)); - TEST_F(TensorSliceDatasetOpTest, DatasetSave) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - Tensor t1 = CreateTensor(TensorShape({2, 2}), {1, 2, 3, 4}); - Tensor t2 = CreateTensor(TensorShape({2, 2}), {5, 6, 7, 8}); - gtl::InlinedVector inputs = {&t1, &t2}; - DataTypeVector dtypes({DT_INT64, DT_INT64}); - std::vector shapes = {PartialTensorShape({2}), - 
PartialTensorShape({2})}; + const TestCase &test_case = PlainTensorTestCase(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -382,7 +421,7 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); std::unique_ptr serialization_context; TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); @@ -393,29 +432,26 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) { TF_ASSERT_OK(writer.Flush()); } -struct IteratorOutputDtypesTest : TensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorOutputDtypesTest, IteratorOutputDtypes) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; - std::vector components = GetParam().components; - std::vector expected_outputs = GetParam().expected_outputs; - size_t num_tensors_per_slice = components.size(); - TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; DataTypeVector dtypes; - std::vector shapes; gtl::InlinedVector inputs; for (auto &component : components) { inputs.emplace_back(&component); dtypes.emplace_back(component.dtype()); } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); for (int i = 0; i < num_tensors_per_slice; ++i) { shapes.emplace_back(expected_outputs[i].shape()); } - std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -427,7 +463,7 @@ TEST_P(IteratorOutputDtypesTest, IteratorOutputDtypes) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); std::unique_ptr iterator_context; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), @@ -443,32 +479,26 @@ TEST_P(IteratorOutputDtypesTest, IteratorOutputDtypes) { } } -INSTANTIATE_TEST_CASE_P(TensorDatasetSliceOpTest, IteratorOutputDtypesTest, - ::testing::ValuesIn(TestCases)); - -struct IteratorOutputShapesTest : TensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorOutputShapesTest, IteratorOutputShapes) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; - std::vector components = GetParam().components; - std::vector expected_outputs = GetParam().expected_outputs; - size_t num_tensors_per_slice = components.size(); - TF_ASSERT_OK(InitThreadPool(thread_num)); 
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; DataTypeVector dtypes; - std::vector shapes; gtl::InlinedVector inputs; for (auto &component : components) { inputs.emplace_back(&component); dtypes.emplace_back(component.dtype()); } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); for (int i = 0; i < num_tensors_per_slice; ++i) { shapes.emplace_back(expected_outputs[i].shape()); } - std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -480,7 +510,7 @@ TEST_P(IteratorOutputShapesTest, IteratorOutputShapes) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); std::unique_ptr iterator_context; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), @@ -497,20 +527,26 @@ TEST_P(IteratorOutputShapesTest, IteratorOutputShapes) { } } -INSTANTIATE_TEST_CASE_P(TensorDatasetSliceOpTest, IteratorOutputShapesTest, - ::testing::ValuesIn(TestCases)); - TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - Tensor t1 = CreateTensor(TensorShape({2, 2}), {1, 2, 3, 4}); - Tensor t2 = CreateTensor(TensorShape({2, 2}), {5, 6, 7, 8}); - gtl::InlinedVector inputs = {&t1, &t2}; - DataTypeVector dtypes({DT_INT64, DT_INT64}); - std::vector shapes = {PartialTensorShape({2}), - PartialTensorShape({2})}; + const TestCase &test_case = PlainTensorTestCase(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -522,7 +558,7 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); std::unique_ptr iterator_context; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), @@ -533,30 +569,26 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice"); } -struct IteratorRoundtripTest : TensorSliceDatasetOpTest, - ::testing::WithParamInterface {}; - -TEST_P(IteratorRoundtripTest, Roundtrip) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; - std::vector components = GetParam().components; - std::vector expected_outputs = GetParam().expected_outputs; - std::vector breakpoints = GetParam().breakpoints; - size_t num_tensors_per_slice = components.size(); 
- TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; DataTypeVector dtypes; - std::vector shapes; gtl::InlinedVector inputs; for (auto &component : components) { inputs.emplace_back(&component); dtypes.emplace_back(component.dtype()); } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); for (int i = 0; i < num_tensors_per_slice; ++i) { shapes.emplace_back(expected_outputs[i].shape()); } - std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -568,7 +600,7 @@ TEST_P(IteratorRoundtripTest, Roundtrip) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); std::unique_ptr iterator_context; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), @@ -583,7 +615,7 @@ TEST_P(IteratorRoundtripTest, Roundtrip) { bool end_of_sequence = false; int64 num_slices = inputs[0].tensor->dim_size(0); std::vector out_tensors; - + const std::vector &breakpoints = test_case.breakpoints; for (int breakpoint : breakpoints) { while (cur_iteration < breakpoint) { TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, @@ -618,12 +650,15 @@ TEST_P(IteratorRoundtripTest, Roundtrip) { TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); TF_ASSERT_OK(writer.Flush()); VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); + TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader, "Iterator", + *tensor_slice_dataset, &iterator)); } } -INSTANTIATE_TEST_CASE_P(TensorDatasetSliceOpTest, IteratorRoundtripTest, - ::testing::ValuesIn(TestCases)); +INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest, + ParameterizedTensorSliceDatasetOpTest, + ::testing::ValuesIn(std::vector( + {PlainTensorTestCase(), NestedTensorTestCase()}))); } // namespace } // namespace data diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 0b24c118914..bfe2ef35280 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -103,8 +103,18 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { if (n == kInfiniteCardinality || n == kUnknownCardinality) { return n; } - return n / window_shift_ + - (n % window_shift_ == 0 || drop_remainder_ ? 0 : 1); + int64 cardinality = 0; + if (drop_remainder_) { + // Compute rest_elements, the number of elements after the last element + // of the initial window. If it is negative, we know that the + // cardinality is 0. Otherwise, it will be the number of valid shifts + // over the rest_elements. + int64 rest_elements = n - ((window_size_ - 1) * window_stride_ + 1); + cardinality = rest_elements < 0 ? 0 : rest_elements / window_shift_ + 1; + } else { + cardinality = n / window_shift_ + (n % window_shift_ == 0 ? 0 : 1); + } + return cardinality; } protected: @@ -288,7 +298,7 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { input_impl_.reset(); } // Restore buffer. 
-      int64 buffer_size;
+      int64 buffer_size = 0;
       TF_RETURN_IF_ERROR(
           reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
       buffer_.resize(buffer_size);
diff --git a/tensorflow/core/kernels/data/window_dataset_op_test.cc b/tensorflow/core/kernels/data/window_dataset_op_test.cc
new file mode 100644
index 00000000000..97debfd7321
--- /dev/null
+++ b/tensorflow/core/kernels/data/window_dataset_op_test.cc
@@ -0,0 +1,883 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+constexpr char kNodeName[] = "window_dataset";
+constexpr char kOpName[] = "WindowDataset";
+
+class WindowDatasetOpTest : public DatasetOpsTestBase {
+ protected:
+  // Creates a new `WindowDataset` op kernel.
+  Status CreateWindowDatasetKernel(
+      const DataTypeVector& output_types,
+      const std::vector<PartialTensorShape>& output_shapes,
+      std::unique_ptr<OpKernel>* op_kernel) {
+    NodeDef node_def = test::function::NDef(
+        kNodeName, kOpName,
+        {"input_dataset", "size", "shift", "stride", "drop_remainder"},
+        {{"output_types", output_types}, {"output_shapes", output_shapes}});
+    TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel));
+    return Status::OK();
+  }
+
+  // Creates a new `WindowDataset` op kernel context.
+  Status CreateWindowDatasetContext(
+      OpKernel* const op_kernel,
+      gtl::InlinedVector<TensorValue, 4>* const inputs,
+      std::unique_ptr<OpKernelContext>* context) {
+    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
+    TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
+    return Status::OK();
+  }
+};
+
+struct RangeDatasetParam {
+  int64 start;
+  int64 end;
+  int64 step;
+};
+
+struct TestCase {
+  RangeDatasetParam range_data_param;
+  Tensor size;
+  Tensor shift;
+  Tensor stride;
+  Tensor drop_remainder;
+  std::vector<std::vector<Tensor>> expected_outputs;
+  DataTypeVector expected_output_dtypes;
+  std::vector<PartialTensorShape> expected_output_shapes;
+  int64 expected_cardinality;
+  std::vector<int> breakpoints;
+};
+
+// Test case 1: size=2, shift=2, stride=1, drop_remainder=false.
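Before the test-case definitions below, it may help to restate the `Cardinality()` formula added to window_dataset_op.cc above. The helper here is my own sketch, not code from the patch; the numbers in the comments use n = 7, the element count of the RangeDataset(0, 7, 1) input shared by all of the following cases:

#include <cstdint>

// Cardinality of a WindowDataset over an input with n elements, mirroring the
// drop_remainder branch introduced above.
int64_t WindowCardinality(int64_t n, int64_t size, int64_t shift,
                          int64_t stride, bool drop_remainder) {
  if (drop_remainder) {
    // Elements left after the last element of the first full window.
    int64_t rest_elements = n - ((size - 1) * stride + 1);
    return rest_elements < 0 ? 0 : rest_elements / shift + 1;
  }
  return n / shift + (n % shift == 0 ? 0 : 1);
}

// WindowCardinality(7, 2, 2, 1, false) == 4  (test case 1: 7/2 + 1)
// WindowCardinality(7, 2, 2, 2, true)  == 3  (test case 2: (7 - 3)/2 + 1)
// WindowCardinality(7, 8, 3, 1, true)  == 0  (test case 4: 7 - 8 < 0)
// WindowCardinality(7, 4, 2, 2, true)  == 1  (test case 9: (7 - 7)/2 + 1)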
+TestCase TestCase1() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {false}), + /*expected_outputs*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {6})}}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 2: size=2, shift=2, stride=2, drop_remainder=true. +TestCase TestCase2() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6})}}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 3: size=8, shift=3, stride=1, drop_remainder=false. +TestCase TestCase3() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {false}), + /*expected_outputs*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {6})}}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 3, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 4: size=8, shift=3, stride=1, drop_remainder=true. 
+TestCase TestCase4() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {3}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 5: size=2, shift=8, stride=1, drop_remainder=false. +TestCase TestCase5() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {false}), + /*expected_outputs*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1})}}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 1, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 6: size=2, shift=8, stride=1, drop_remainder=true. +TestCase TestCase6() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {1})}}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 1, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 7: size=2, shift=2, stride=8, drop_remainder=false. +TestCase TestCase7() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {false}), + /*expected_outputs*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {0})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {2})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {4})}, + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {6})}}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 4, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 8: size=2, shift=2, stride=8, drop_remainder=true. +TestCase TestCase8() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {8}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 9: size=4, shift=2, stride=2, drop_remainder=true. 
+TestCase TestCase9() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ + {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {4}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {6})}}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 1, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 10: size=5, shift=2, stride=2, drop_remainder=true. +TestCase TestCase10() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {5}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 11: size=0, shift=2, stride=2, drop_remainder=true. +TestCase InvalidWindowSizeTestCase() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 12: size=2, shift=0, stride=2, drop_remainder=true. +TestCase InvalidWindowShiftTestCase() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +// Test case 13: size=2, shift=2, stride=0, drop_remainder=true. 
+TestCase InvalidWindowStrideTestCase() { + return { + /*range_data_param*/ {0, 7, 1}, + /*size*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*shift*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + /*stride*/ DatasetOpsTestBase::CreateTensor(TensorShape({}), {0}), + /*drop_remainder*/ + DatasetOpsTestBase::CreateTensor(TensorShape({}), {true}), + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_VARIANT}, + /*expected_output_shapes*/ {PartialTensorShape({})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {0, 1, 9}}; +} + +class ParameterizedWindowDatasetOpTest + : public WindowDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedWindowDatasetOpTest, GetNext) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(window_dataset_op_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); + while (!end_of_sequence) { + // Owns the window_datasets, which are stored as the variant tensors in the + // vector. + std::vector out_tensors; + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + for (const auto& window_dataset_tensor : out_tensors) { + // Not owned. + DatasetBase* window_dataset; + TF_ASSERT_OK(GetDatasetFromVariantTensor(window_dataset_tensor, + &window_dataset)); + std::unique_ptr window_dataset_iterator; + TF_ASSERT_OK(window_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &window_dataset_iterator)); + bool end_of_window_dataset = false; + std::vector window_elements; + // Fetches all the elements in window_dataset. 
+ while (!end_of_window_dataset) { + std::vector next_element; + TF_EXPECT_OK(window_dataset_iterator->GetNext( + iterator_ctx.get(), &next_element, &end_of_window_dataset)); + window_elements.insert(window_elements.end(), next_element.begin(), + next_element.end()); + } + EXPECT_LT(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK(ExpectEqual(window_elements, *expected_outputs_it, false)); + expected_outputs_it++; + } + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); +} + +TEST_F(WindowDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + EXPECT_EQ(dataset->node_name(), kNodeName); +} + +TEST_F(WindowDatasetOpTest, DatasetTypeString) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + EXPECT_EQ(dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedWindowDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr 
window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + TF_EXPECT_OK(VerifyTypesMatch(dataset->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedWindowDatasetOpTest, DatasetOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(dataset->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_P(ParameterizedWindowDatasetOpTest, Cardinality) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + 
gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality); +} + +TEST_P(ParameterizedWindowDatasetOpTest, DatasetSave) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr serialization_context; + TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); +} + +TEST_P(ParameterizedWindowDatasetOpTest, IteratorOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + 
CreateIteratorContext(window_dataset_op_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), + test_case.expected_output_dtypes)); +} + +TEST_P(ParameterizedWindowDatasetOpTest, IteratorOutputShapes) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(window_dataset_op_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + test_case.expected_output_shapes)); +} + +TEST_F(WindowDatasetOpTest, IteratorOutputPrefix) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = TestCase1(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(window_dataset_op_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + 
EXPECT_EQ(iterator->prefix(), "Iterator::Window"); +} + +TEST_P(ParameterizedWindowDatasetOpTest, Roundtrip) { + int thread_num = 2, cpu_num = 2; + TestCase test_case = GetParam(); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs, + &window_dataset_op_ctx)); + DatasetBase* dataset; + TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref_dataset(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK( + CreateIteratorContext(window_dataset_op_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + bool end_of_sequence = false; + auto expected_outputs_it = test_case.expected_outputs.begin(); + int cur_iteration = 0; + for (int breakpoint : test_case.breakpoints) { + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *dataset, &iterator)); + while (cur_iteration <= breakpoint) { + while (!end_of_sequence) { + // Owns the datasets, which are stored as the variant tensors in the + // vector. + std::vector out_tensors; + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + if (!end_of_sequence) { + for (const auto& window_dataset_tensor : out_tensors) { + // Not owned. 
+ DatasetBase* window_dataset; + TF_ASSERT_OK(GetDatasetFromVariantTensor(window_dataset_tensor, + &window_dataset)); + std::unique_ptr window_dataset_iterator; + TF_ASSERT_OK(window_dataset->MakeIterator( + iterator_ctx.get(), "Iterator", &window_dataset_iterator)); + bool end_of_window_dataset = false; + std::vector window_elements; + while (!end_of_window_dataset) { + std::vector next_element; + TF_EXPECT_OK(window_dataset_iterator->GetNext( + iterator_ctx.get(), &next_element, &end_of_window_dataset)); + window_elements.insert(window_elements.end(), + next_element.begin(), next_element.end()); + } + EXPECT_LT(expected_outputs_it, test_case.expected_outputs.end()); + TF_EXPECT_OK( + ExpectEqual(window_elements, *expected_outputs_it, false)); + expected_outputs_it++; + } + } + } + cur_iteration++; + } + } + EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); +} + +INSTANTIATE_TEST_SUITE_P( + WindowDatasetOpTest, ParameterizedWindowDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TestCase1(), TestCase2(), TestCase3(), TestCase4(), TestCase5(), + TestCase6(), TestCase7(), TestCase8(), TestCase9(), TestCase10()}))); + +TEST_F(WindowDatasetOpTest, InvalidArguments) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + std::vector test_cases({InvalidWindowSizeTestCase(), + InvalidWindowShiftTestCase(), + InvalidWindowStrideTestCase()}); + for (const auto& test_case : test_cases) { + std::unique_ptr window_dataset_kernel; + TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &window_dataset_kernel)); + DatasetBase* range_dataset; + TF_ASSERT_OK(CreateRangeDataset( + test_case.range_data_param.start, test_case.range_data_param.end, + test_case.range_data_param.step, "range", &range_dataset)); + Tensor range_dataset_tensor(DT_VARIANT, TensorShape({})); + TF_ASSERT_OK( + StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor)); + Tensor size = test_case.size; + Tensor shift = test_case.shift; + Tensor stride = test_case.stride; + Tensor drop_remainder = test_case.drop_remainder; + gtl::InlinedVector inputs( + {&range_dataset_tensor, &size, &shift, &stride, &drop_remainder}); + + std::unique_ptr window_dataset_op_ctx; + TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), + &inputs, &window_dataset_op_ctx)); + DatasetBase* dataset; + EXPECT_EQ(CreateDataset(window_dataset_kernel.get(), + window_dataset_op_ctx.get(), &dataset) + .code(), + tensorflow::error::INVALID_ARGUMENT); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/zip_dataset_op_test.cc b/tensorflow/core/kernels/data/zip_dataset_op_test.cc index 9f9e86a3d08..41ea16d0009 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op_test.cc @@ -58,10 +58,10 @@ class ZipDatasetOpTest : public DatasetOpsTestBase { // Create the placeholder names for the input components of `ZipDataset`. 
input_datasets.emplace_back(strings::StrCat("input_dataset_", i)); } - node_def_ = test::function::NDef( + NodeDef node_def = test::function::NDef( kNodeName, kOpName, input_datasets, {{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); return Status::OK(); } @@ -74,9 +74,6 @@ class ZipDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; struct TestParam { @@ -85,8 +82,8 @@ struct TestParam { std::vector breakpoints; }; +// Test case 1: the input datasets with same number of outputs. TestParam TestCase1() { - // Test case 1: the input datasets with same number of outputs. return {/*input_range_dataset_params*/ {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}}, /*expected_outputs*/ @@ -99,8 +96,8 @@ TestParam TestCase1() { /*breakpoints*/ {0, 1, 4}}; } +// Test case 2: the input datasets with different number of outputs. TestParam TestCase2() { - // Test case 2: the input datasets with different number of outputs. return {/*input_range_dataset_params*/ {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}}, /*expected_outputs*/ @@ -113,67 +110,48 @@ TestParam TestCase2() { /*breakpoints*/ {0, 1, 4}}; } -class ZipDatasetOpTestHelper : public ZipDatasetOpTest { - public: - ~ZipDatasetOpTestHelper() override { - if (dataset_) dataset_->Unref(); - } - - protected: - Status CreateDatasetFromTestCase(const TestParam &test_case) { - std::vector range_dataset_tensors; - range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); - TF_RETURN_IF_ERROR(CreateRangeDatasetTensors( - test_case.input_range_dataset_params, &range_dataset_tensors)); - gtl::InlinedVector inputs; - inputs.reserve(range_dataset_tensors.size()); - for (auto &tensor : range_dataset_tensors) { - inputs.emplace_back(&tensor); - } - int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_RETURN_IF_ERROR(CreateZipDatasetKernel({DT_INT64}, - {{num_tensors_per_slice}}, - inputs.size(), &dataset_kernel_)); - TF_RETURN_IF_ERROR(CreateZipDatasetContext(dataset_kernel_.get(), &inputs, - &dataset_kernel_ctx_)); - TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(), - dataset_kernel_ctx_.get(), &dataset_)); - return Status::OK(); - } - - Status CreateIteratorFromTestCase(const TestParam &test_case) { - TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case)); - TF_RETURN_IF_ERROR( - CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_)); - TF_RETURN_IF_ERROR( - dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_)); - return Status::OK(); - } - - std::unique_ptr dataset_kernel_; - std::unique_ptr dataset_kernel_ctx_; - DatasetBase *dataset_ = nullptr; // owned by this class. 
- std::unique_ptr iterator_ctx_; - std::unique_ptr iterator_; -}; - -class ParameterizedDatasetTest - : public ZipDatasetOpTestHelper, +class ParameterizedZipDatasetOpTest + : public ZipDatasetOpTest, public ::testing::WithParamInterface {}; -TEST_P(ParameterizedDatasetTest, GetNext) { +TEST_P(ParameterizedZipDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); auto expected_outputs_it = test_case.expected_outputs.begin(); bool end_of_sequence = false; std::vector out_tensors; while (!end_of_sequence) { - TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, - &end_of_sequence)); + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); if (!end_of_sequence) { for (const auto &tensor : out_tensors) { EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); @@ -185,22 +163,92 @@ TEST_P(ParameterizedDatasetTest, GetNext) { EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -TEST_F(ZipDatasetOpTestHelper, DatasetName) { +TEST_F(ZipDatasetOpTest, DatasetNodeName) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1())); - EXPECT_EQ(dataset_->type_string(), kOpName); + const TestParam &test_case = TestCase1(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref 
scoped_unref(zip_dataset); + + EXPECT_EQ(zip_dataset->node_name(), kNodeName); } -TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { +TEST_F(ZipDatasetOpTest, DatasetTypeString) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); + + const TestParam &test_case = TestCase1(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + + EXPECT_EQ(zip_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); DataTypeVector expected_output_dtypes; expected_output_dtypes.reserve(num_tensors_per_slice); @@ -209,16 +257,35 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { } TF_EXPECT_OK( - VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes)); + VerifyTypesMatch(zip_dataset->output_dtypes(), expected_output_dtypes)); } -TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { +TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + 
inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); std::vector expected_output_shapes; expected_output_shapes.reserve(num_tensors_per_slice); @@ -226,43 +293,107 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape()); } - TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(), + TF_EXPECT_OK(VerifyShapesCompatible(zip_dataset->output_shapes(), expected_output_shapes)); } -TEST_P(ParameterizedDatasetTest, Cardinality) { +TEST_P(ParameterizedZipDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); - int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); - EXPECT_EQ(dataset_->Cardinality(), + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + + EXPECT_EQ(zip_dataset->Cardinality(), test_case.expected_outputs.size() / num_tensors_per_slice); } -TEST_F(ZipDatasetOpTestHelper, DatasetSave) { +TEST_F(ZipDatasetOpTest, DatasetSave) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1())); + + const TestParam &test_case = TestCase1(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + 
&dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); std::unique_ptr serialization_ctx; TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); VariantTensorData data; VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(zip_dataset->Save(serialization_ctx.get(), &writer)); TF_ASSERT_OK(writer.Flush()); } -TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { +TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); DataTypeVector expected_output_dtypes; expected_output_dtypes.reserve(num_tensors_per_slice); @@ -271,16 +402,40 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { } TF_EXPECT_OK( - VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes)); + VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes)); } -TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { +TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + 
core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); std::vector expected_output_shapes; expected_output_shapes.reserve(num_tensors_per_slice); @@ -288,43 +443,96 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape()); } - TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(), + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), expected_output_shapes)); } -TEST_F(ZipDatasetOpTestHelper, IteratorOutputPrefix) { +TEST_F(ZipDatasetOpTest, IteratorOutputPrefix) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateIteratorFromTestCase(TestCase1())); - EXPECT_EQ(iterator_->prefix(), "Iterator::Zip"); + + const TestParam &test_case = TestCase1(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::Zip"); } -TEST_P(ParameterizedDatasetTest, Roundtrip) { +TEST_P(ParameterizedZipDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); - auto expected_outputs_it = test_case.expected_outputs.begin(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref 
scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); std::unique_ptr serialization_ctx; TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; std::vector out_tensors; + auto expected_outputs_it = test_case.expected_outputs.begin(); int cur_iteration = 0; for (int breakpoint : test_case.breakpoints) { VariantTensorData data; VariantTensorDataWriter writer(&data); - TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); TF_EXPECT_OK(writer.Flush()); VariantTensorDataReader reader(&data); - TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader)); + TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator", + *zip_dataset, &iterator)); while (cur_iteration < breakpoint) { - TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, - &end_of_sequence)); + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); if (!end_of_sequence) { for (auto &tensor : out_tensors) { EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); @@ -335,7 +543,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) { cur_iteration++; } - if (breakpoint >= dataset_->Cardinality()) { + if (breakpoint >= zip_dataset->Cardinality()) { EXPECT_TRUE(end_of_sequence); EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } else { @@ -345,7 +553,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) { } INSTANTIATE_TEST_SUITE_P( - ZipDatasetOpTest, ParameterizedDatasetTest, + ZipDatasetOpTest, ParameterizedZipDatasetOpTest, ::testing::ValuesIn(std::vector({TestCase1(), TestCase2()}))); } // namespace diff --git a/tensorflow/core/kernels/debug_ops_test.cc b/tensorflow/core/kernels/debug_ops_test.cc index 273962be997..12ea7db1ea1 100644 --- a/tensorflow/core/kernels/debug_ops_test.cc +++ b/tensorflow/core/kernels/debug_ops_test.cc @@ -364,7 +364,7 @@ TEST_F(DebugNumericSummaryOpTest, Float_only_valid_values) { 7.33333333333, // variance of non-inf and non-nan elements. static_cast(DT_FLOAT), // dtype 2.0, // Number of dimensions. - 2.0, 3.0}); // Dimensoin sizes. + 2.0, 3.0}); // Dimension sizes. test::ExpectTensorNear(expected, *GetOutput(0), 1e-8); } diff --git a/tensorflow/core/kernels/decode_padded_raw_op.cc b/tensorflow/core/kernels/decode_padded_raw_op.cc new file mode 100644 index 00000000000..1e6a0cb7606 --- /dev/null +++ b/tensorflow/core/kernels/decode_padded_raw_op.cc @@ -0,0 +1,139 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +template +class DecodePaddedRawOp : public OpKernel { + public: + explicit DecodePaddedRawOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_type_)); + + const bool host_is_little_endian = port::kLittleEndian; + bool data_is_little_endian; + OP_REQUIRES_OK(context, + context->GetAttr("little_endian", &data_is_little_endian)); + convert_data_endianness_ = host_is_little_endian != data_is_little_endian; + } + + void Compute(OpKernelContext* context) override { + const auto& input = context->input(0); + auto flat_in = input.flat(); + + int fixed_length; + const auto& length_input = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(length_input.shape()), + errors::InvalidArgument("k must be scalar, got shape ", + length_input.shape().DebugString())); + fixed_length = length_input.scalar()(); + + OP_REQUIRES( + context, fixed_length % sizeof(T) == 0, + errors::InvalidArgument( + "fixed_length (", fixed_length, + ") must be a multiple of the size of out_type (", sizeof(T), ")")); + + OP_REQUIRES(context, fixed_length > 0, + errors::InvalidArgument("fixed_length (", fixed_length, + ") must be greater than zero.")); + + int width = fixed_length / sizeof(T); + + TensorShape out_shape = input.shape(); + out_shape.AddDim(width); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output("output", out_shape, &output_tensor)); + + if (flat_in.size() == 0) { // Empty input + return; + } + + auto out = output_tensor->flat_inner_dims(); + T* out_data = out.data(); + + // Forcibly clear memory - we're going to copy variable length strings in, + // and need to ensure that if we don't write to byte N when we copy, that + // we're not getting random data. + memset(out_data, 0, fixed_length * flat_in.size()); + + // If the data is already in the host's byte order, or if the width of the + // output type is a single byte (meaning the ordering doesn't matter), we + // can copy the memory directly. + if (!convert_data_endianness_ || sizeof(T) == 1) { + for (int64 i = 0; i < flat_in.size(); ++i) { + const T* in_data = reinterpret_cast(flat_in(i).data()); + + if (flat_in(i).size() > fixed_length) { + memcpy(out_data, in_data, fixed_length); + } else { + memcpy(out_data, in_data, flat_in(i).size()); + } + out_data += fixed_length; + } + } else { + // Otherwise, the data is not in the host's byte order, and rather than a + // direct copy, we need to reverse the byte ordering of each element. + for (int64 i = 0; i < flat_in.size(); ++i) { + const char* in_data_bytes = + reinterpret_cast(flat_in(i).data()); + char* out_data_bytes = reinterpret_cast(out_data); + const char* p_in = in_data_bytes; + char* p_out = out_data_bytes; + for (; p_in < in_data_bytes + fixed_length; + p_in += sizeof(T), p_out += sizeof(T)) { + std::reverse_copy(p_in, p_in + sizeof(T), p_out); + } + out_data += fixed_length; + } + } + } + + private: + // True if the endianness of the data and the endianness of the host are + // different, and the data needs conversion. + bool convert_data_endianness_; + + // Data type of the output tensor. 
+ DataType out_type_; +}; + +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER(Name("DecodePaddedRaw") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("out_type"), \ + DecodePaddedRawOp) + +REGISTER(float); +REGISTER(double); +REGISTER(int32); +REGISTER(uint16); +REGISTER(uint8); +REGISTER(int16); +REGISTER(int8); +REGISTER(int64); + +#undef REGISTER + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc index 3dd019c3d20..e68fa407534 100644 --- a/tensorflow/core/kernels/decode_raw_op.cc +++ b/tensorflow/core/kernels/decode_raw_op.cc @@ -29,8 +29,13 @@ template class DecodeRawOp : public OpKernel { public: explicit DecodeRawOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("little_endian", &little_endian_)); OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_type_)); + + const bool host_is_little_endian = port::kLittleEndian; + bool data_is_little_endian; + OP_REQUIRES_OK(context, + context->GetAttr("little_endian", &data_is_little_endian)); + convert_data_endianness_ = host_is_little_endian != data_is_little_endian; } void Compute(OpKernelContext* context) override { @@ -70,13 +75,18 @@ class DecodeRawOp : public OpKernel { auto out = output_tensor->flat_inner_dims(); DCHECK_EQ(flat_in.size(), out.dimensions()[0]); T* out_data = out.data(); - if (port::kLittleEndian == little_endian_ || sizeof(T) == 1) { + + // If the data is already in the host's byte order, or if the width of the + // output type is a single byte, we can copy the memory directly. + if (!convert_data_endianness_ || sizeof(T) == 1) { for (int64 i = 0; i < flat_in.size(); ++i) { const T* in_data = reinterpret_cast(flat_in(i).data()); memcpy(out_data, in_data, str_size); out_data += added_dim; } } else { + // Otherwise, the data is not in the host's byte order, and rather than a + // direct copy, we need to reverse the byte ordering of each element. for (int64 i = 0; i < flat_in.size(); ++i) { const char* in_data_bytes = reinterpret_cast(flat_in(i).data()); @@ -92,7 +102,12 @@ class DecodeRawOp : public OpKernel { } private: - bool little_endian_; + // True if the endianness of the data and the endianness of the host are + // different, and the data needs conversion. + bool convert_data_endianness_; + + // True if the input data is in little endian format. 
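An illustrative aside, not part of the patch: the new `DecodePaddedRawOp` and the updated `DecodeRawOp` both reduce the `little_endian` attribute to a single `convert_data_endianness_` flag, and when that flag is set and `sizeof(T) > 1` they copy each element with its byte order reversed (via `std::reverse_copy` in the padded op) instead of doing a straight `memcpy`. A minimal standalone sketch of that per-element swap, with a hypothetical `DecodeSwapped` helper standing in for the kernel's inner loop:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper mirroring the conversion branch: reinterpret the raw
// bytes of `record` as values of type T, reversing the byte order of each
// sizeof(T)-sized element while copying into the output vector.
template <typename T>
std::vector<T> DecodeSwapped(const std::string& record) {
  std::vector<T> out(record.size() / sizeof(T));
  const char* in_bytes = record.data();
  char* out_bytes = reinterpret_cast<char*>(out.data());
  for (size_t i = 0; i < out.size(); ++i) {
    std::reverse_copy(in_bytes + i * sizeof(T), in_bytes + (i + 1) * sizeof(T),
                      out_bytes + i * sizeof(T));
  }
  return out;
}

int main() {
  // Two big-endian uint16 values, 0x0001 and 0x0002, as raw bytes.
  std::string record("\x00\x01\x00\x02", 4);
  // On a little-endian host the swap yields {1, 2}; when host and data
  // endianness already match, the ops fall back to a plain memcpy.
  std::vector<uint16_t> values = DecodeSwapped<uint16_t>(record);
  return (values[0] == 1 && values[1] == 2) ? 0 : 1;
}
```

The `memset` in `DecodePaddedRawOp` serves a related purpose: records shorter than `fixed_length` leave their trailing bytes zeroed instead of exposing stale memory.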
+ bool data_is_little_endian_; DataType out_type_; }; @@ -110,6 +125,7 @@ REGISTER(uint8); REGISTER(int16); REGISTER(int8); REGISTER(int64); +REGISTER(bool); REGISTER(complex64); REGISTER(complex128); diff --git a/tensorflow/core/kernels/dense_update_functor.cc b/tensorflow/core/kernels/dense_update_functor.cc index 3ed3794e01d..4d7eafd4f72 100644 --- a/tensorflow/core/kernels/dense_update_functor.cc +++ b/tensorflow/core/kernels/dense_update_functor.cc @@ -105,7 +105,7 @@ struct DenseUpdate { INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define GPU_DENSE_COPY(T) \ case DataTypeToEnum::value: { \ functor::DenseUpdate copy_functor_; \ @@ -121,7 +121,7 @@ INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES, GPU_DENSE_COPY); #undef TF_CALL_GPU_AND_ADDITIONAL_TYPES #undef GPU_DENSE_COPY -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef CPU_DENSE_COPY #undef INSTANTIATE_GET_VARIANT_COPY_FN diff --git a/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc b/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc index 25c57384ca9..daf8a7380e0 100644 --- a/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -72,4 +72,4 @@ TF_CALL_int8(DEFINE_GPU_KERNELS); } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index f942b1a8a92..c68f1891c39 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif @@ -102,7 +102,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); TF_CALL_quint16(REGISTER_KERNELS); #undef REGISTER_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Only register 'Assign' on GPU for the subset of types also supported by // 'Variable' (see variable_ops.cc.) 
#define REGISTER_GPU_KERNELS(type) \ @@ -113,7 +113,7 @@ TF_CALL_quint16(REGISTER_KERNELS); TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNELS(type) \ @@ -136,7 +136,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint("T"), \ @@ -147,7 +147,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // end GOOGLE_CUDA +#endif // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNELS(type) \ diff --git a/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc b/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc index 768dd38a600..2abda846fd6 100644 --- a/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc @@ -17,11 +17,10 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/depthtospace_op.h" - #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/depthtospace_op.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace { @@ -161,7 +160,7 @@ struct DepthToSpaceOpFunctor { if (total_count == 0) { return; } - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); TF_CHECK_OK(CudaLaunchKernel( D2S_NHWC, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, input.data(), block_size, batch_size, @@ -195,7 +194,7 @@ struct DepthToSpaceOpFunctor { if (total_count == 0) { return; } - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); switch (block_size) { case 2: TF_CHECK_OK(CudaLaunchKernel( @@ -226,7 +225,7 @@ struct DepthToSpaceOpFunctor { if (total_count == 0) { return; } - auto config = GetCudaLaunchConfig(total_count, d); + auto config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(CudaLaunchKernel( D2S_NCHW, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, input.data(), block_size, input_width, diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc index ab98cacd1a1..b29e8323332 100644 --- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA -#include "cuda/include/cudnn.h" +#include "third_party/gpus/cudnn/cudnn.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index 11c2b31633d..ceaeaac21de 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -39,7 +39,7 @@ limitations under the License. 
#include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA -#include "cuda/include/cudnn.h" +#include "third_party/gpus/cudnn/cudnn.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.h b/tensorflow/core/kernels/depthwise_conv_op_gpu.h index fcbd8ffd868..33ff78b4c56 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.h +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/depthwise_conv_op.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/tensor_format.h" #if defined(_MSC_VER) && !defined(__clang__) @@ -37,6 +37,17 @@ limitations under the License. namespace tensorflow { +namespace detail { +template +struct PseudoHalfType { + using Type = T; +}; +template <> +struct PseudoHalfType { + using Type = float; +}; +} // namespace detail + using Eigen::GpuDevice; // Returns whether depthwise convolution forward or backward input pass can be @@ -75,6 +86,7 @@ template ::Type S; const int in_height = args.in_rows; const int in_width = args.in_cols; const int in_depth = args.in_depth; @@ -108,7 +120,7 @@ __global__ void __launch_bounds__(1024, 2) const int input_row_end = input_row_start + filter_height; const int input_col_end = input_col_start + filter_width; - T sum = static_cast(0); + S sum = static_cast(0); const int input_offset_temp = in_height * batch; if (input_row_start >= 0 && input_col_start >= 0 && @@ -128,7 +140,8 @@ __global__ void __launch_bounds__(1024, 2) multiplier + depth_multiplier * (in_channel + in_depth * (filter_col + filter_offset_temp)); - sum += ldg(input + input_offset) * ldg(filter + filter_offset); + sum += static_cast(ldg(input + input_offset)) * + static_cast(ldg(filter + filter_offset)); } } } else { @@ -150,12 +163,13 @@ __global__ void __launch_bounds__(1024, 2) multiplier + depth_multiplier * (in_channel + in_depth * (filter_col + filter_offset_temp)); - sum += ldg(input + input_offset) * ldg(filter + filter_offset); + sum += static_cast(ldg(input + input_offset)) * + static_cast(ldg(filter + filter_offset)); } } } } - output[thread_id] = sum; + output[thread_id] = static_cast(sum); } } @@ -172,9 +186,10 @@ __global__ void __launch_bounds__(1024, 2) // same as T for all cases but pseudo half (which has T=Eigen::half, S=float). template + bool kKnownEvenHeight> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( const DepthwiseArgs args, const T* input, const T* filter, T* output) { + typedef typename detail::PseudoHalfType::Type S; assert(CanLaunchDepthwiseConv2dGPUSmall(args)); // Holds block plus halo and filter data for blockDim.x depths. extern __shared__ __align__(8) unsigned char shared_memory[]; @@ -315,6 +330,7 @@ template ::Type S; const int in_height = args.in_rows; const int in_width = args.in_cols; const int in_depth = args.in_depth; @@ -388,7 +404,7 @@ __global__ void __launch_bounds__(1024, 2) const int input_row_end = input_row_start + filter_height; const int input_col_end = input_col_start + filter_width; - T sum = static_cast(0); + S sum = static_cast(0); if (input_row_start >= 0 && input_col_start >= 0 && input_row_end < in_height && input_col_end < in_width) { // Loop that doesn't need to check for boundary conditions. 
@@ -406,7 +422,8 @@ __global__ void __launch_bounds__(1024, 2) multiplier + depth_multiplier * (in_channel + in_depth * (filter_col + filter_offset_temp)); - sum += ldg(input + input_offset) * ldg(filter + filter_offset); + sum += static_cast(ldg(input + input_offset)) * + static_cast(ldg(filter + filter_offset)); } } } else { @@ -433,13 +450,14 @@ __global__ void __launch_bounds__(1024, 2) multiplier + depth_multiplier * (in_channel + in_depth * (filter_col + filter_offset_temp)); - sum += ldg(input + input_offset) * ldg(filter + filter_offset); + sum += static_cast(ldg(input + input_offset)) * + static_cast(ldg(filter + filter_offset)); } } } } - output[thread_id] = sum; + output[thread_id] = static_cast(sum); } } @@ -456,9 +474,10 @@ __global__ void __launch_bounds__(1024, 2) // same as T for all cases but pseudo half (which has T=Eigen::half, S=float). template + bool kKnownEvenHeight> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( const DepthwiseArgs args, const T* input, const T* filter, T* output) { + typedef typename detail::PseudoHalfType::Type S; assert(CanLaunchDepthwiseConv2dGPUSmall(args)); // Holds block plus halo and filter data for blockDim.z depths. extern __shared__ __align__(8) unsigned char shared_memory[]; @@ -596,11 +615,12 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( template + bool kKnownEvenHeight> Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format) { + typedef typename detail::PseudoHalfType::Type S; const int block_height = (args.in_rows + 1) / 2; dim3 block_dim; int block_count; @@ -613,7 +633,7 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, kernel = DepthwiseConv2dGPUKernelNHWCSmall; + kKnownEvenHeight>; break; case FORMAT_NCHW: block_dim = dim3(args.in_cols, block_height, kBlockDepth); @@ -622,7 +642,7 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, kernel = DepthwiseConv2dGPUKernelNCHWSmall; + kKnownEvenHeight>; break; default: return errors::InvalidArgument("FORMAT_", ToString(data_format), @@ -636,29 +656,15 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, kBlockDepth * (tile_pixels + filter_pixels) * sizeof(S); const int num_outputs = args.out_rows * args.out_cols * block_count; auto device = ctx->eigen_gpu_device(); - CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( + GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( num_outputs, device, kernel, shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - kernel<<>>(args, input, filter, output); + TF_CHECK_OK(CudaLaunchKernel(kernel, config.block_count, block_dim, + shared_memory_size, device.stream(), args, input, + filter, output)); return Status::OK(); } -namespace detail { -template -struct PseudoHalfType { - using Type = T; -}; -template <> -struct PseudoHalfType { - using Type = float; -}; -} // namespace detail - -// Maps to float if T is __half, and to T otherwise. -template -using PseudoHalfType = typename detail::PseudoHalfType::Type; - // Returns whether the context's GPU supports efficient fp16 math. 
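As an aside (a sketch, not the kernel code): the `PseudoHalfType` trait introduced above lets the depthwise kernels accumulate their partial sums for `Eigen::half` inputs in `float` and cast back to the storage type only on the final store. The same idea in plain C++, using a stand-in `half` type and a trait name that merely mirrors the one in the kernel:

```cpp
#include <cstddef>
#include <vector>

// Stand-in for Eigen::half; a float payload keeps the sketch simple.
struct half {
  half() : value(0.f) {}
  explicit half(float f) : value(f) {}
  explicit operator float() const { return value; }
  float value;
};

// Maps the storage type to the math type: float for half, identity otherwise.
template <typename T> struct AccumulatorType { using Type = T; };
template <> struct AccumulatorType<half> { using Type = float; };

template <typename T>
T DotProduct(const std::vector<T>& a, const std::vector<T>& b) {
  using S = typename AccumulatorType<T>::Type;
  S sum = static_cast<S>(0);
  for (size_t i = 0; i < a.size(); ++i) {
    // Accumulate the products in S so half inputs do not lose precision
    // element by element; this is what the kernels do with `sum`.
    sum += static_cast<S>(a[i]) * static_cast<S>(b[i]);
  }
  return static_cast<T>(sum);  // Cast back to the storage type on store.
}

int main() {
  std::vector<half> a{half(1.5f), half(2.0f)};
  std::vector<half> b{half(4.0f), half(0.5f)};
  return static_cast<float>(DotProduct(a, b)) == 7.0f ? 0 : 1;
}
```

This is also why the patch can drop the separate `S` template parameter and the half-dispatch wrapper overloads: each kernel now derives its accumulator type internally from `T`.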
inline bool HasFastHalfMath(OpKernelContext* ctx) { int major, minor; @@ -672,27 +678,6 @@ inline bool HasFastHalfMath(OpKernelContext* ctx) { return cuda_arch >= 530 && cuda_arch != 610; } -template -Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, - const DepthwiseArgs& args, const T* input, - const T* filter, T* output, - TensorFormat data_format) { -#if !defined __CUDA_ARCH__ || __CUDA_ARCH__ >= 530 - if (HasFastHalfMath(ctx)) { - return LaunchDepthwiseConv2dGPUSmall( - ctx, args, input, filter, output, data_format); - } -#endif - return LaunchDepthwiseConv2dGPUSmall>( - ctx, args, input, filter, output, data_format); -} - template Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, @@ -759,16 +744,17 @@ Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args, const int num_outputs = args.batch * args.out_rows * args.out_cols * args.out_depth; auto device = ctx->eigen_gpu_device(); - CudaLaunchConfig config = - GetCudaLaunchConfig(num_outputs, device, kernel, 0, 0); + GpuLaunchConfig config = + GetGpuLaunchConfig(num_outputs, device, kernel, 0, 0); // The compile-time constant version runs faster with a single block. const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 || kKnownDepthMultiplier < 0 ? std::numeric_limits::max() : device.getNumGpuMultiProcessors(); - kernel<<>>(args, input, filter, - output, num_outputs); + TF_CHECK_OK(CudaLaunchKernel(kernel, + std::min(max_block_count, config.block_count), + config.thread_per_block, 0, device.stream(), + args, input, filter, output, num_outputs)); return Status::OK(); } @@ -981,8 +967,8 @@ Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx, const int num_in_backprop = args.batch * args.in_rows * args.in_cols * args.in_depth; auto device = ctx->eigen_gpu_device(); - CudaLaunchConfig config = - GetCudaLaunchConfig(num_in_backprop, device, kernel, 0, 0); + GpuLaunchConfig config = + GetGpuLaunchConfig(num_in_backprop, device, kernel, 0, 0); TF_CHECK_OK(CudaLaunchKernel( kernel, config.block_count, config.thread_per_block, 0, device.stream(), args, out_backprop, filter, in_backprop, num_in_backprop)); @@ -1029,6 +1015,8 @@ void LaunchDepthwiseConvBackpropInputOp::operator()( } // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. +// TODO: Add fp32 accumulation to half calls of this function. This addition +// is non-trivial as the partial sums are added directly to the output template __global__ void __launch_bounds__(640, 2) @@ -1163,10 +1151,11 @@ __device__ __forceinline__ T WarpSumReduce(T val) { // T is the tensors' data type. S is the math type the kernel uses. This is the // same as T for all cases but pseudo half (which has T=Eigen::half, S=float). template + int kBlockDepth, int kAccumPixels> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const DepthwiseArgs args, const T* output, const T* input, T* filter) { + typedef typename detail::PseudoHalfType::Type S; assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z)); // Holds block plus halo and filter data for blockDim.x depths. extern __shared__ __align__(8) unsigned char shared_memory[]; @@ -1435,10 +1424,11 @@ __global__ void __launch_bounds__(640, 2) // Requirements: threads per block must be multiple of 32 and <= launch_bounds, // kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth. 
template + int kBlockDepth, int kAccumPixels> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const DepthwiseArgs args, const T* output, const T* input, T* filter) { + typedef typename detail::PseudoHalfType::Type S; assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x)); // Holds block plus halo and filter data for blockDim.z depths. extern __shared__ __align__(8) unsigned char shared_memory[]; @@ -1581,11 +1571,12 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( } template + int kBlockDepth, int kAccumPixels> Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { + typedef typename detail::PseudoHalfType::Type S; auto device = ctx->eigen_gpu_device(); const int tile_width = args.in_cols + args.filter_cols - 1; const int tile_height = block_height * 2 + args.filter_rows - 1; @@ -1606,50 +1597,29 @@ Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( block_count = args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth; kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels, - S>; + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; break; case FORMAT_NCHW: block_dim = dim3(args.in_cols, block_height, kBlockDepth); block_count = DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth; kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels, - S>; + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; break; default: return errors::InvalidArgument("FORMAT_", ToString(data_format), " is not supported"); } const int num_out_backprop = args.out_rows * args.out_cols * block_count; - CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( + GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( num_out_backprop, device, kernel, shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - kernel<<>>(args, out_backprop, input, filter_backprop); + TF_CHECK_OK(CudaLaunchKernel(kernel, config.block_count, block_dim, + shared_memory_size, device.stream(), args, + out_backprop, input, filter_backprop)); return Status::OK(); } -template -Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height, - const T* out_backprop, const T* input, T* filter_backprop, - TensorFormat data_format) { -#if !defined __CUDA_ARCH__ || __CUDA_ARCH__ >= 530 - if (HasFastHalfMath(ctx)) { - return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels, T>( - ctx, args, block_height, out_backprop, input, filter_backprop, - data_format); - } -#endif - return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels, - PseudoHalfType>(ctx, args, block_height, out_backprop, input, - filter_backprop, data_format); -} - template Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( @@ -1745,8 +1715,8 @@ Status LaunchDepthwiseConv2dBackpropFilterGPU( const int num_out_backprop = args.batch * args.out_rows * args.out_cols * args.out_depth; auto device = ctx->eigen_gpu_device(); - CudaLaunchConfig config = - GetCudaLaunchConfig(num_out_backprop, device, kernel, 0, 0); + GpuLaunchConfig 
config = + GetGpuLaunchConfig(num_out_backprop, device, kernel, 0, 0); TF_CHECK_OK(CudaLaunchKernel( kernel, config.block_count, config.thread_per_block, 0, device.stream(), args, out_backprop, input, filter_backprop, num_out_backprop)); diff --git a/tensorflow/core/kernels/determinant_op_gpu.cu.cc b/tensorflow/core/kernels/determinant_op_gpu.cu.cc index 681567ef2d8..387ea3b6607 100644 --- a/tensorflow/core/kernels/determinant_op_gpu.cu.cc +++ b/tensorflow/core/kernels/determinant_op_gpu.cu.cc @@ -17,13 +17,13 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/determinant_op.h" - #include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/cuda_solvers.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/kernels/determinant_op.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -128,7 +128,7 @@ struct DeterminantFromPivotedLUFunctor { int* info) { const int64 num_matrices = output.size(); const int64 n = lu_factor.dimension(2); - CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device); + GpuLaunchConfig config = GetCudaLaunchConfig(num_matrices, device); TF_CHECK_OK(CudaLaunchKernel( DeterminantFromPivotedLUKernel, @@ -151,7 +151,7 @@ struct LogDeterminantFromPivotedLUFunctor { typename TTypes::Tensor log_abs_det) { const int64 num_matrices = sign.size(); const int64 n = lu_factor.dimension(2); - CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device); + GpuLaunchConfig config = GetCudaLaunchConfig(num_matrices, device); TF_CHECK_OK(CudaLaunchKernel( DeterminantFromPivotedLUKernel, config.block_count, config.thread_per_block, 0, device.stream(), diff --git a/tensorflow/core/kernels/diag_op_gpu.cu.cc b/tensorflow/core/kernels/diag_op_gpu.cu.cc index 910f3093b23..7ad967fd92f 100644 --- a/tensorflow/core/kernels/diag_op_gpu.cu.cc +++ b/tensorflow/core/kernels/diag_op_gpu.cu.cc @@ -18,9 +18,10 @@ limitations under the License. #define EIGEN_USE_GPU #include + #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/diag_op.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -49,7 +50,7 @@ struct DiagFunctor { return Status::OK(); } - // CudaLaunchConfig uses an int for virtual_thread_count, + // GpuLaunchConfig uses an int for virtual_thread_count, // so this may overflow for `size*size` in extreme cases, // here is checking the multiplication overflow for integer. if (size && (int(size * size) / size) != size) { @@ -59,11 +60,12 @@ struct DiagFunctor { // Launch the GPU kernel. const GPUDevice& device = context->eigen_device(); - CudaLaunchConfig diag_config = - GetCudaLaunchConfig(virtual_thread_count, device); - DiagCudaKernel<<>>(diag_config.virtual_thread_count, size, - in, out); + GpuLaunchConfig diag_config = + GetGpuLaunchConfig(virtual_thread_count, device); + TF_CHECK_OK( + CudaLaunchKernel(DiagCudaKernel, diag_config.block_count, + diag_config.thread_per_block, 0, device.stream(), + diag_config.virtual_thread_count, size, in, out)); auto err = cudaGetLastError(); if (err != cudaSuccess) { @@ -100,10 +102,11 @@ struct DiagPartFunctor { const GPUDevice& device = context->eigen_device(); // Extract the diagonal elements. 
- CudaLaunchConfig diag_config = GetCudaLaunchConfig(size, device); - DiagPartCudaKernel<<>>(diag_config.virtual_thread_count, - size, in, out); + GpuLaunchConfig diag_config = GetCudaLaunchConfig(size, device); + TF_CHECK_OK( + CudaLaunchKernel(DiagPartCudaKernel, diag_config.block_count, + diag_config.thread_per_block, 0, device.stream(), + diag_config.virtual_thread_count, size, in, out)); auto err = cudaGetLastError(); if (err != cudaSuccess) { diff --git a/tensorflow/core/kernels/dilation_ops_gpu.cu.cc b/tensorflow/core/kernels/dilation_ops_gpu.cu.cc index 12408f2c416..588f5677f40 100644 --- a/tensorflow/core/kernels/dilation_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/dilation_ops_gpu.cu.cc @@ -22,12 +22,11 @@ limitations under the License. #include #include -#include "tensorflow/core/kernels/dilation_ops.h" - #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/dilation_ops.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -194,14 +193,14 @@ struct Dilation { const int output_cols = output.dimension(2); const int total_count = batch * output_rows * output_cols * depth; - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); - DilationKernel<<>>( - config.virtual_thread_count, input.data(), filter.data(), batch, - input_rows, input_cols, depth, filter_rows, filter_cols, output_rows, - output_cols, stride_rows, stride_cols, rate_rows, rate_cols, pad_top, - pad_left, output.data()); + TF_CHECK_OK(CudaLaunchKernel( + DilationKernel, config.block_count, config.thread_per_block, 0, + d.stream(), config.virtual_thread_count, input.data(), filter.data(), + batch, input_rows, input_cols, depth, filter_rows, filter_cols, + output_rows, output_cols, stride_rows, stride_cols, rate_rows, + rate_cols, pad_top, pad_left, output.data())); } }; @@ -225,18 +224,18 @@ struct DilationBackpropInput { const int output_cols = out_backprop.dimension(2); int total_count; - CudaLaunchConfig config; + GpuLaunchConfig config; // Initialize in_backprop with all zeros. total_count = batch * input_rows * input_cols * depth; - config = GetCudaLaunchConfig(total_count, d); + config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(CudaLaunchKernel(SetZero, config.block_count, config.thread_per_block, 0, d.stream(), total_count, in_backprop.data())); // Accumulate. total_count = batch * output_rows * output_cols * depth; - config = GetCudaLaunchConfig(total_count, d); + config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(CudaLaunchKernel( DilationBackpropInputKernel, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, @@ -267,24 +266,25 @@ struct DilationBackpropFilter { const int output_cols = out_backprop.dimension(2); int total_count; - CudaLaunchConfig config; + GpuLaunchConfig config; // Initialize filter_backprop with all zeros. total_count = filter_rows * filter_cols * depth; - config = GetCudaLaunchConfig(total_count, d); + config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(CudaLaunchKernel(SetZero, config.block_count, config.thread_per_block, 0, d.stream(), total_count, filter_backprop.data())); // Accumulate. 
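The hunks above and below this point all replace raw `<<<...>>>` launches with a `GetGpuLaunchConfig` configuration plus `TF_CHECK_OK(CudaLaunchKernel(...))`. A minimal sketch of that target pattern, using the same helpers from `tensorflow/core/util/gpu_kernel_helper.h` as the surrounding code; `ScaleKernel` and `LaunchScale` are hypothetical names, not part of this change:

```c++
// Illustrative sketch only (not part of this patch): the launch pattern these
// hunks migrate to.
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical elementwise kernel: out[i] = alpha * in[i].
template <typename T>
__global__ void ScaleKernel(int n, const T* in, T alpha, T* out) {
  // Grid-stride loop over all n elements.
  CUDA_1D_KERNEL_LOOP(i, n) { out[i] = alpha * in[i]; }
}

template <typename T>
void LaunchScale(const Eigen::GpuDevice& d, int n, const T* in, T alpha,
                 T* out) {
  // Derive a block/thread configuration for n elements, then launch through
  // CudaLaunchKernel so a failed launch is reported as a Status.
  GpuLaunchConfig config = GetGpuLaunchConfig(n, d);
  TF_CHECK_OK(CudaLaunchKernel(ScaleKernel<T>, config.block_count,
                               config.thread_per_block, 0, d.stream(), n, in,
                               alpha, out));
}

}  // namespace tensorflow
```

Routing the launch through `CudaLaunchKernel` lets launch failures surface as a `Status`, which the raw `<<<...>>>` form does not provide by itself.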
total_count = batch * output_rows * output_cols * depth; - config = GetCudaLaunchConfig(total_count, d); - DilationBackpropFilterKernel<<>>( - config.virtual_thread_count, input.data(), filter.data(), - out_backprop.data(), batch, input_rows, input_cols, depth, filter_rows, - filter_cols, output_rows, output_cols, stride_rows, stride_cols, - rate_rows, rate_cols, pad_top, pad_left, filter_backprop.data()); + config = GetGpuLaunchConfig(total_count, d); + TF_CHECK_OK(CudaLaunchKernel( + DilationBackpropFilterKernel, config.block_count, + config.thread_per_block, 0, d.stream(), config.virtual_thread_count, + input.data(), filter.data(), out_backprop.data(), batch, input_rows, + input_cols, depth, filter_rows, filter_cols, output_rows, output_cols, + stride_rows, stride_cols, rate_rows, rate_cols, pad_top, pad_left, + filter_backprop.data())); } }; diff --git a/tensorflow/core/kernels/draw_bounding_box_op.cc b/tensorflow/core/kernels/draw_bounding_box_op.cc index 618c47e6848..30de99b7d56 100644 --- a/tensorflow/core/kernels/draw_bounding_box_op.cc +++ b/tensorflow/core/kernels/draw_bounding_box_op.cc @@ -25,6 +25,30 @@ limitations under the License. namespace tensorflow { +namespace { + +std::vector> DefaultColorTable(int depth) { + std::vector> color_table; + color_table.emplace_back(std::vector({1, 1, 0, 1})); // 0: yellow + color_table.emplace_back(std::vector({0, 0, 1, 1})); // 1: blue + color_table.emplace_back(std::vector({1, 0, 0, 1})); // 2: red + color_table.emplace_back(std::vector({0, 1, 0, 1})); // 3: lime + color_table.emplace_back(std::vector({0.5, 0, 0.5, 1})); // 4: purple + color_table.emplace_back(std::vector({0.5, 0.5, 0, 1})); // 5: olive + color_table.emplace_back(std::vector({0.5, 0, 0, 1})); // 6: maroon + color_table.emplace_back(std::vector({0, 0, 0.5, 1})); // 7: navy blue + color_table.emplace_back(std::vector({0, 1, 1, 1})); // 8: aqua + color_table.emplace_back(std::vector({1, 0, 1, 1})); // 9: fuchsia + + if (depth == 1) { + for (int64 i = 0; i < color_table.size(); i++) { + color_table[i][0] = 1; + } + } + return color_table; +} +} // namespace + template class DrawBoundingBoxesOp : public OpKernel { public: @@ -52,31 +76,32 @@ class DrawBoundingBoxesOp : public OpKernel { const int64 batch_size = images.dim_size(0); const int64 height = images.dim_size(1); const int64 width = images.dim_size(2); - const int64 color_table_length = 10; + std::vector> color_table; + if (context->num_inputs() == 3) { + const Tensor& colors_tensor = context->input(2); + OP_REQUIRES(context, colors_tensor.shape().dims() == 2, + errors::InvalidArgument("colors must be a 2-D matrix", + colors_tensor.shape().DebugString())); + OP_REQUIRES(context, colors_tensor.shape().dim_size(1) >= depth, + errors::InvalidArgument("colors must have equal or more ", + "channels than the image provided: ", + colors_tensor.shape().DebugString())); + if (colors_tensor.NumElements() != 0) { + color_table.clear(); - // 0: yellow - // 1: blue - // 2: red - // 3: lime - // 4: purple - // 5: olive - // 6: maroon - // 7: navy blue - // 8: aqua - // 9: fuchsia - float color_table[color_table_length][4] = { - {1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1}, - {0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1}, - {0, 1, 1, 1}, {1, 0, 1, 1}, - }; - - // Reset first color channel to 1 if image is GRY. - // For GRY images, this means all bounding boxes will be white. 
- if (depth == 1) { - for (int64 i = 0; i < color_table_length; i++) { - color_table[i][0] = 1; + auto colors = colors_tensor.matrix(); + for (int64 i = 0; i < colors.dimension(0); i++) { + std::vector color_value(4); + for (int64 j = 0; j < 4; j++) { + color_value[j] = colors(i, j); + } + color_table.emplace_back(color_value); + } } } + if (color_table.empty()) { + color_table = DefaultColorTable(depth); + } Tensor* output; OP_REQUIRES_OK( context, @@ -90,7 +115,7 @@ class DrawBoundingBoxesOp : public OpKernel { const int64 num_boxes = boxes.dim_size(1); const auto tboxes = boxes.tensor(); for (int64 bb = 0; bb < num_boxes; ++bb) { - int64 color_index = bb % color_table_length; + int64 color_index = bb % color_table.size(); const int64 min_box_row = static_cast(tboxes(b, bb, 0)) * (height - 1); const int64 min_box_row_clamp = std::max(min_box_row, int64{0}); @@ -176,9 +201,12 @@ class DrawBoundingBoxesOp : public OpKernel { } }; -#define REGISTER_CPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint("T"), \ +#define REGISTER_CPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint("T"), \ + DrawBoundingBoxesOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("DrawBoundingBoxesV2").Device(DEVICE_CPU).TypeConstraint("T"), \ DrawBoundingBoxesOp); TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc index f00baa932f8..24cd1b62ce0 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -47,7 +47,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/kernels/gather_functor_gpu.cu.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/transform_output_iterator.h" namespace tensorflow { @@ -78,7 +78,7 @@ __global__ void MoveValuesKernel(const int32* keys, const int32* values, template void RangeInit(const GPUDevice& d, const T start, const T delta, const int32 size, typename TTypes::Flat out) { - CudaLaunchConfig config = GetCudaLaunchConfig(size, d); + GpuLaunchConfig config = GetCudaLaunchConfig(size, d); TF_CHECK_OK(CudaLaunchKernel(RangeInitKernel, config.block_count, config.thread_per_block, 0, d.stream(), start, delta, size, out.data())); @@ -93,7 +93,7 @@ void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs, // This is valid for correct inputs, because then out_size >= *num_runs. // For wrong inputs, we may have out_size < *num_runs. In this case we will // only handle the first out_size values. 
- CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d); + GpuLaunchConfig config = GetCudaLaunchConfig(out_size, d); TF_CHECK_OK(CudaLaunchKernel(MoveValuesKernel, config.block_count, config.thread_per_block, 0, d.stream(), keys, values, num_runs, out_size, out)); @@ -103,7 +103,7 @@ template void CallGatherKernel(const GPUDevice& d, const T* params, const int32* indices, T* out, int64 gather_dim_size, int64 indices_size, int64 slice_size, int64 out_size) { - CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d); + GpuLaunchConfig config = GetCudaLaunchConfig(out_size, d); TF_CHECK_OK(CudaLaunchKernel( GatherOpKernel, config.block_count, config.thread_per_block, 0, d.stream(), params, indices, out, diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc index 5b8845b675d..471bd7fbb1c 100644 --- a/tensorflow/core/kernels/dynamic_stitch_op.cc +++ b/tensorflow/core/kernels/dynamic_stitch_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" #ifdef GOOGLE_CUDA -#include "tensorflow/core/kernels/cuda_device_array.h" +#include "tensorflow/core/kernels/gpu_device_array.h" #endif // GOOGLE_CUDA namespace tensorflow { @@ -138,9 +138,21 @@ class DynamicStitchOpImplBase : public OpKernel { template void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device, const int32 slice_size, const int32 first_dim_size, - const CudaDeviceArrayStruct& input_indices, - const CudaDeviceArrayStruct& input_ptrs, + const GpuDeviceArrayStruct& input_indices, + const GpuDeviceArrayStruct& input_ptrs, T* output); +#define REGISTER_GPU(T) \ + extern template void DynamicStitchGPUImpl( \ + const Eigen::GpuDevice& gpu_device, const int32 slice_size, \ + const int32 first_dim_size, \ + const GpuDeviceArrayStruct& input_indices, \ + const GpuDeviceArrayStruct& input_ptrs, T* output); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); +TF_CALL_complex64(REGISTER_GPU); +TF_CALL_complex128(REGISTER_GPU); +TF_CALL_int64(REGISTER_GPU); +TF_CALL_int32(REGISTER_GPU); +#undef REGISTER_GPU template class DynamicStitchOpGPU : public DynamicStitchOpImplBase { @@ -167,14 +179,14 @@ class DynamicStitchOpGPU : public DynamicStitchOpImplBase { // merged that aren't covered by an index in indices. What should we do? if (first_dim_size > 0) { // because the collision requirements, we have to deal with - // collion first before send data to gpu kernel. + // collision first before send data to gpu kernel. // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the // last of duplicated indices, it could instead be done of the GPU // implicitly using atomics to make sure the last index is the final // write. const int slice_size = merged->flat_outer_dims().dimension(1); - CudaDeviceArrayOnHost indices_flat(c, first_dim_size); - CudaDeviceArrayOnHost data_flat(c, data_elements_size); + GpuDeviceArrayOnHost indices_flat(c, first_dim_size); + GpuDeviceArrayOnHost data_flat(c, data_elements_size); OP_REQUIRES_OK(c, indices_flat.Init()); OP_REQUIRES_OK(c, data_flat.Init()); // initialize the indices_flat (-1 represents missing indices) diff --git a/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc index 9ed2c540091..111b6a0a90c 100644 --- a/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc @@ -19,8 +19,8 @@ limitations under the License. 
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/cuda_device_array_gpu.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/kernels/gpu_device_array_gpu.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -31,11 +31,11 @@ namespace { template __global__ void DynamicStitchKernel(const int32 slice_size, const int32 output_size, - CudaDeviceArrayStruct input_indices, - CudaDeviceArrayStruct input_ptrs, + GpuDeviceArrayStruct input_indices, + GpuDeviceArrayStruct input_ptrs, T* output) { - int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices); - const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs); + int32* data_indices = GetGpuDeviceArrayOnDevice(&input_indices); + const T** data_ptrs = GetGpuDeviceArrayOnDevice(&input_ptrs); CUDA_1D_KERNEL_LOOP(output_index, output_size) { const int32 slice_id = output_index / slice_size; const int32 slice_offset = output_index % slice_size; @@ -51,11 +51,11 @@ __global__ void DynamicStitchKernel(const int32 slice_size, template void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device, const int32 slice_size, const int32 first_dim_size, - const CudaDeviceArrayStruct& input_indices, - const CudaDeviceArrayStruct& input_ptrs, + const GpuDeviceArrayStruct& input_indices, + const GpuDeviceArrayStruct& input_ptrs, T* output) { const int32 output_size = first_dim_size * slice_size; - auto config = GetCudaLaunchConfig(output_size, gpu_device); + auto config = GetGpuLaunchConfig(output_size, gpu_device); TF_CHECK_OK(CudaLaunchKernel(DynamicStitchKernel, config.block_count, config.thread_per_block, 0, gpu_device.stream(), @@ -67,8 +67,8 @@ void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device, template void DynamicStitchGPUImpl( \ const Eigen::GpuDevice& gpu_device, const int32 slice_size, \ const int32 first_dim_size, \ - const CudaDeviceArrayStruct& input_indices, \ - const CudaDeviceArrayStruct& input_ptrs, T* output); + const GpuDeviceArrayStruct& input_indices, \ + const GpuDeviceArrayStruct& input_ptrs, T* output); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); TF_CALL_complex64(REGISTER_GPU); diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h index 960920c55bd..aaea7b1268e 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -366,8 +366,8 @@ SpatialConvolutionBackwardKernel( YOU_MADE_A_PROGRAMMING_MISTAKE); // stride and in_stride cannot both be larger than 1 - eigen_assert(!(row_stride > 1 && row_in_stride > 1) && - !(col_stride > 1 && col_in_stride > 1)); + eigen_assert(!(row_stride > 1 && row_in_stride > 1)); + eigen_assert(!(col_stride > 1 && col_in_stride > 1)); static const bool isColMajor = (internal::traits::Layout == ColMajor); diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc index ec949ddc845..12fa7f3409d 100644 --- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc +++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc @@ -8,7 +8,7 @@ You may obtain a copy of the License at Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONT OF ANY KIND, either express or implied. 
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ diff --git a/tensorflow/core/kernels/eigen_contraction_kernel.h b/tensorflow/core/kernels/eigen_contraction_kernel.h index 1af263cdd5d..493ab2776c5 100644 --- a/tensorflow/core/kernels/eigen_contraction_kernel.h +++ b/tensorflow/core/kernels/eigen_contraction_kernel.h @@ -34,6 +34,11 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +// FixedPoint header must be included after Tensor. +// clang-format off +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" +// clang-format on + #if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL) #include "mkldnn.h" #endif @@ -128,10 +133,15 @@ struct mkldnn_gemm_kernel::max)(); eigen_assert(max_index >= rows); @@ -143,11 +153,8 @@ struct mkldnn_gemm_kernel(cols); const int k = static_cast(depth); - const char transposeA = 'N'; - const char transposeB = 'N'; - - const int ldA = m; - const int ldB = k; + ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA; + ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB; const int ldC = static_cast(output.stride()); const float beta = 1.0; @@ -164,6 +171,72 @@ struct mkldnn_gemm_kernel +struct mkldnn_gemm_s8s8s32_kernel { + static_assert(!ConjugateLhs, "MKL-DNN kernel doesn't support ConjugateLhs"); + static_assert(!ConjugateRhs, "MKL-DNN kernel doesn't support ConjugateRhs"); + + static constexpr int kComputeStrideFromBlockDimensions = -1; + + using LhsScalar = Eigen::QInt8; + using RhsScalar = Eigen::QInt8; + using ResScalar = Eigen::QInt32; + + EIGEN_DONT_INLINE + void operator()(const OutputMapper& output, const LhsScalar* blockA, + const RhsScalar* blockB, const IndexType rows, + const IndexType depth, const IndexType cols, float alpha, + int ldA = kComputeStrideFromBlockDimensions, + int ldB = kComputeStrideFromBlockDimensions, + char transposeA = 'N', char transposeB = 'N') { + static const int max_index = (std::numeric_limits::max)(); + + eigen_assert(max_index >= rows); + eigen_assert(max_index >= cols); + eigen_assert(max_index >= depth); + eigen_assert(max_index >= output.stride()); + + const int m = static_cast(rows); + const int n = static_cast(cols); + const int k = static_cast(depth); + + ldA = ldA == kComputeStrideFromBlockDimensions ? m : ldA; + ldB = ldB == kComputeStrideFromBlockDimensions ? k : ldB; + const int ldC = static_cast(output.stride()); + + const float beta = 1.0; + + // Currently we support only symmetric quantization with zero point at 0. + const int8_t ao = 0; + const int8_t bo = 0; + + // Don't add any offset to the result C. + const char offsetc = 'F'; + const int32_t co = 0; + + const int8_t* A = reinterpret_cast(blockA); + const int8_t* B = reinterpret_cast(blockB); + int32_t* C = + reinterpret_cast(const_cast(output.data())); + + mkldnn_status_t st = + mkldnn_gemm_s8s8s32(&transposeA, &transposeB, &offsetc, // + &m, &n, &k, // + &alpha, // + A, &ldA, &ao, // + B, &ldB, &bo, // + &beta, // + C, &ldC, &co); + eigen_assert(st == 0); + + // eigen_assert is a no-op in optimized mode so we add these to avoid + // compiler's unused-variable errors. + EIGEN_UNUSED_VARIABLE(max_index); + EIGEN_UNUSED_VARIABLE(st); + } +}; + // For mkldnn_sgemm having the right dimensions (especially for small matrices) // is more important than fitting all the working set in L1/L2 caches. 
// TODO(ezhulenev): Do better heuristics. @@ -235,71 +308,524 @@ class TensorContractionBlocking -struct TensorContractionKernel { - // For now mkldnn has only mkldnn_sgemm (gemm for floats). - using Scalar = float; - using Traits = typename internal::gebp_traits; +// If the Lhs or Rhs Tensor expressions are already evaluated and have access to +// raw data, we can skip packing step and setup pointers and a stride to the +// underlying memory buffer and pass them directly to Gemm. +template +struct ColMajorBlock { + bool is_direct_access; - using LhsPacker = - gemm_pack_colmajor_block; - using RhsPacker = - gemm_pack_colmajor_block; - using GemmKernel = mkldnn_gemm_kernel; + // Valid iff `is_direct_access == false` + Scalar* packed_data; - // Fallback on default Eigen pack and GEBP kernel if custom contraction - // kernels disabled at runtime. - using EigenLhsPacker = - gemm_pack_lhs; - using EigenRhsPacker = - gemm_pack_rhs; - using GebpKernel = - gebp_kernel; + // Valid iff `is_direct_access == true` + Scalar* raw_data; + StorageIndex stride; + char transpose; +}; - EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void packLhs( - Scalar* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, - const StorageIndex depth, const StorageIndex rows) { - if (UseCustomContractionKernels()) { - LhsPacker()(lhsBlock, data_mapper, rows, depth); - } else { - EigenLhsPacker()(lhsBlock, data_mapper, depth, rows, /*stride*/ 0, - /*offset*/ 0); - } - } +template +struct DirectColMajorAccess { + enum { value = false }; - EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void packRhs( - Scalar* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, - const StorageIndex depth, const StorageIndex cols) { - if (UseCustomContractionKernels()) { - RhsPacker()(rhsBlock, data_mapper, depth, cols); - } else { - EigenRhsPacker()(rhsBlock, data_mapper, depth, cols); - } - } - - EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void invoke( - const OutputMapper& output_mapper, const Scalar* lhsBlock, - const Scalar* rhsBlock, const StorageIndex rows, const StorageIndex depth, - const StorageIndex cols, const Scalar alpha) { - if (UseCustomContractionKernels()) { - GemmKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha); - } else { - GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha, - /*strideA*/ -1, /*strideB*/ -1, - /*offsetA*/ 0, /*offsetB*/ 0); - } + template + static bool block(const typename DataMapper::SubMapper& data_mapper, + const StorageIndex rows, const StorageIndex cols, + const StorageIndex num_kernels, + ColMajorBlock* block) { + eigen_assert(false && "Not implemented"); + return false; } }; +// If we have an access to raw memory of the contraction input, we can safely +// skip packing if: +// (1) Packing is a no-op. +// (2) Packed block will be used just once. +// +// If a packed block is used many times, it's more efficient to pack it into +// contiguous block of memory to reduce pressure on TLB. +// +// TODO(ezhulenev): Add support for more tensor expressions that matters. 
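The registration macro that follows implements exactly this choice. As a small standalone restatement of the heuristic it applies (the free function `UseDirectColMajorAccess` is an illustrative name, not part of the patch):

```c++
#include <cstddef>

// Standalone restatement of the use_direct_access condition applied by the
// REGISTER_DIRECT_COL_MAJOR_ACCESS macro below. Illustrative only.
template <typename Scalar, typename StorageIndex>
bool UseDirectColMajorAccess(StorageIndex stride, StorageIndex rows,
                             StorageIndex cols, StorageIndex num_kernels) {
  // (1) Packing is a no-op: consecutive columns are already contiguous.
  const bool is_no_op_packing = stride == rows;
  // Memory the GEMM kernel would address through the raw pointer.
  const std::size_t addressable_mem =
      static_cast<std::size_t>(stride) * static_cast<std::size_t>(cols) *
      sizeof(Scalar);
  // (2) The block would be used only once, or twice while staying small
  // enough (< 256 KiB) to keep TLB pressure low.
  return is_no_op_packing || num_kernels == 1 ||
         (num_kernels == 2 && addressable_mem < (256u << 10));
}
```

When this condition holds, the macro below records the raw pointer, its stride, and a `'N'` transpose flag in the `ColMajorBlock` instead of packing the block.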
+#define REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_EXPR) \ + template \ + struct DirectColMajorAccess, \ + nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true, \ + /*inner_dim_reordered=*/false, Alignment>> { \ + enum { value = true }; \ + \ + using DataMapper = TensorContractionInputMapper< \ + Scalar, StorageIndex, Side, TensorEvaluator, \ + nocontract_t, contract_t, packet_size, /*inner_dim_contiguous=*/true, \ + /*inner_dim_reordered=*/false, Alignment>; \ + \ + static bool block(const typename DataMapper::SubMapper& data_mapper, \ + const StorageIndex rows, const StorageIndex cols, \ + const StorageIndex num_kernels, \ + ColMajorBlock* block) { \ + static_assert(DataMapper::DirectOffsets == true, \ + "DataMapper must support direct offsets"); \ + \ + const StorageIndex vert_offset = data_mapper.vert_offset(); \ + const StorageIndex horiz_offset = data_mapper.horiz_offset(); \ + const StorageIndex stride = \ + Side == Lhs ? data_mapper.base_mapper().stride() \ + : data_mapper.base_mapper().nocontract_strides()[0]; \ + const Scalar* data = data_mapper.base_mapper().tensor().data(); \ + data = Side == Lhs ? data : data + vert_offset + horiz_offset * stride; \ + \ + const bool is_no_op_packing = stride == rows; \ + const StorageIndex adressable_mem = (stride * cols * sizeof(Scalar)); \ + const bool use_direct_access = \ + is_no_op_packing || num_kernels == 1 /* used once */ || \ + ((num_kernels == 2) && (adressable_mem < (256 << 10) /* 256 kb */)); \ + \ + if (use_direct_access) { \ + block->is_direct_access = true; \ + block->raw_data = const_cast(data); \ + block->stride = stride; \ + block->transpose = 'N'; \ + return true; \ + } \ + return false; \ + } \ + } + +#define SIMPLE_TENSOR const Tensor + +#define TENSOR_MAP_ROWMAJOR \ + const TensorMap, \ + Eigen::Aligned> + +#define TENSOR_MAP_COLMAJOR \ + const TensorMap, \ + Eigen::Aligned> + +#define TENSOR_MAP_CONST_ROWMAJOR \ + const TensorMap, \ + Eigen::Aligned> + +#define TENSOR_MAP_CONST_COLMAJOR \ + const TensorMap, \ + Eigen::Aligned> + +// This is reshaped convolution filter from `eigen_spatial_convolutions.h`. +#define TENSOR_RESHAPE \ + const TensorReshapingOp< \ + const Eigen::DSizes, \ + const TensorMap, \ + Eigen::Aligned>> + +REGISTER_DIRECT_COL_MAJOR_ACCESS(SIMPLE_TENSOR); +REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_ROWMAJOR); +REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_COLMAJOR); +REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_ROWMAJOR); +REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_MAP_CONST_COLMAJOR); +REGISTER_DIRECT_COL_MAJOR_ACCESS(TENSOR_RESHAPE); + +#undef SIMPLE_TENSOR +#undef TENSOR_MAP_ROWMAJOR +#undef TENSOR_MAP_COLMAJOR +#undef TENSOR_MAP_CONST_ROWMAJOR +#undef TENSOR_MAP_CONST_COLMAJOR +#undef TENSOR_RESHAPE +#undef REGISTER_DIRECT_COL_MAJOR_ACCESS + +template +struct GemmKernelProvider { + enum { Defined = 0 }; + using GemmKernel = void; +}; + +template +struct GemmKernelProvider { + enum { Defined = 1 }; + using GemmKernel = mkldnn_gemm_kernel; +}; + +template +struct GemmKernelProvider { + enum { Defined = 1 }; + using GemmKernel = mkldnn_gemm_s8s8s32_kernel; +}; + +// NOTE: 'std::enable_if' doesn't work for template specializations. See +// "default template argument in a class template partial specialization". + +// Tensor contraction kernel that can fallback on Eigen gebp_kernel at runtime. 
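Because the partial specializations above cannot be constrained with `std::enable_if` (per the note), support for a scalar combination is signalled through the `Defined` flag and enforced by a `static_assert` inside the kernel macros below. A minimal sketch of that pattern with placeholder names (`DemoProvider`, `DemoFloatKernel`, `RequireCustomKernel` are not from the patch):

```c++
// Minimal sketch of the provider pattern: a primary template with
// Defined == 0 and explicit specializations with Defined == 1, so an
// unregistered scalar combination fails at compile time. Placeholder names.
struct DemoFloatKernel {};  // stands in for a concrete GEMM kernel

template <typename ResScalar, typename LhsScalar, typename RhsScalar>
struct DemoProvider {
  enum { Defined = 0 };
  using GemmKernel = void;
};

template <>
struct DemoProvider<float, float, float> {
  enum { Defined = 1 };
  using GemmKernel = DemoFloatKernel;
};

template <typename ResScalar, typename LhsScalar, typename RhsScalar>
void RequireCustomKernel() {
  using Provider = DemoProvider<ResScalar, LhsScalar, RhsScalar>;
  static_assert(Provider::Defined,
                "Custom GEMM kernel is not registered for given scalar types");
  typename Provider::GemmKernel kernel;  // resolved at compile time
  (void)kernel;
}
```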
+#define REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK( \ + RES_SCALAR, LHS_SCALAR, RHS_SCALAR) \ + \ + template \ + struct TensorContractionKernel { \ + TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \ + StorageIndex bm, StorageIndex bk, StorageIndex bn) \ + : m(m), \ + k(k), \ + n(n), \ + bm(bm), \ + bk(bk), \ + bn(bn), \ + nm0(bm > 0 ? divup(m, bm) : 0), \ + nn0(bn > 0 ? divup(n, bn) : 0) {} \ + \ + using ResScalar = RES_SCALAR; \ + using LhsScalar = LHS_SCALAR; \ + using RhsScalar = RHS_SCALAR; \ + \ + using Traits = typename internal::gebp_traits; \ + \ + using LhsBlock = ColMajorBlock; \ + using RhsBlock = ColMajorBlock; \ + \ + using DirectLhsAccess = DirectColMajorAccess; \ + using DirectRhsAccess = DirectColMajorAccess; \ + \ + /* Packed Lhs/Rhs block memory allocator.*/ \ + typedef TensorContractionBlockMemAllocator \ + BlockMemAllocator; \ + typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle; \ + \ + using LhsPacker = \ + gemm_pack_colmajor_block; \ + using RhsPacker = \ + gemm_pack_colmajor_block; \ + \ + using GemmKernelProviderType = \ + GemmKernelProvider; \ + static_assert( \ + GemmKernelProviderType::Defined, \ + "Custom GEMM kernel is not registered for given scalar types"); \ + using GemmKernel = typename GemmKernelProviderType::GemmKernel; \ + \ + /* Fallback on default Eigen pack and GEBP kernel if custom contraction */ \ + /* kernels disabled at runtime. */ \ + using EigenLhsPacker = \ + gemm_pack_lhs; \ + using EigenRhsPacker = \ + gemm_pack_rhs; \ + using GebpKernel = \ + gebp_kernel; \ + \ + template \ + EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, \ + RhsBlock* rhs_block) { \ + return BlockMemAllocator::allocate( \ + d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data); \ + } \ + \ + template \ + EIGEN_DEVICE_FUNC BlockMemHandle \ + allocateSlices(Device& d, const int num_lhs, const int num_rhs, \ + const int num_slices, std::vector* lhs_blocks, \ + std::vector* rhs_blocks) { \ + eigen_assert(num_slices > 0); \ + std::vector> lhs_mem(num_slices); \ + std::vector> rhs_mem(num_slices); \ + \ + BlockMemHandle block_mem = BlockMemAllocator::allocateSlices( \ + d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(), \ + rhs_mem.data()); \ + \ + for (Index x = 0; x < num_slices; x++) { \ + if (num_lhs > 0) lhs_blocks[x].resize(num_lhs); \ + for (Index m = 0; m < num_lhs; m++) { \ + lhs_blocks[x][m].packed_data = lhs_mem[x][m]; \ + } \ + if (num_rhs > 0) rhs_blocks[x].resize(num_rhs); \ + for (Index n = 0; n < num_rhs; n++) { \ + rhs_blocks[x][n].packed_data = rhs_mem[x][n]; \ + } \ + } \ + \ + return block_mem; \ + } \ + \ + template \ + EIGEN_DEVICE_FUNC void deallocate(Device& d, BlockMemHandle handle) { \ + BlockMemAllocator::deallocate(d, handle); \ + } \ + \ + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs( \ + LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, \ + const StorageIndex depth, const StorageIndex rows) { \ + if (UseCustomContractionKernels()) { \ + const bool is_direct_access = \ + DirectLhsAccess::value && \ + DirectLhsAccess::block(data_mapper, rows, depth, nn0, lhsBlock); \ + \ + if (!is_direct_access) { \ + lhsBlock->is_direct_access = false; \ + LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth); \ + } \ + } else { \ + lhsBlock->is_direct_access = false; \ + EigenLhsPacker()(lhsBlock->packed_data, data_mapper, depth, rows, \ + /*stride*/ 0, /*offset*/ 0); \ + } \ + } \ + \ + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs( \ + 
RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, \ + const StorageIndex depth, const StorageIndex cols) { \ + if (UseCustomContractionKernels()) { \ + const bool is_direct_access = \ + DirectRhsAccess::value && \ + DirectRhsAccess::block(data_mapper, depth, cols, nm0, rhsBlock); \ + \ + if (!is_direct_access) { \ + rhsBlock->is_direct_access = false; \ + RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \ + } \ + } else { \ + rhsBlock->is_direct_access = false; \ + EigenRhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \ + } \ + } \ + \ + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke( \ + const OutputMapper& output_mapper, const LhsBlock& lhsBlock, \ + const RhsBlock& rhsBlock, const StorageIndex rows, \ + const StorageIndex depth, const StorageIndex cols, \ + const float alpha) { \ + if (UseCustomContractionKernels()) { \ + if ((DirectLhsAccess::value && lhsBlock.is_direct_access) && \ + (DirectRhsAccess::value && rhsBlock.is_direct_access)) { \ + GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data, \ + rows, depth, cols, alpha, /*ldA=*/lhsBlock.stride, \ + /*ldB=*/rhsBlock.stride, \ + /*transposeA=*/lhsBlock.transpose, \ + /*transposeB=*/rhsBlock.transpose); \ + \ + } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) { \ + GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data, \ + rows, depth, cols, alpha, /*ldA=*/lhsBlock.stride, \ + /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions, \ + /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N'); \ + \ + } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) { \ + GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data, \ + rows, depth, cols, alpha, \ + /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions, \ + /*ldB=*/rhsBlock.stride, /*transposeA=*/'N', \ + /*transposeB=*/rhsBlock.transpose); \ + \ + } else { \ + GemmKernel()(output_mapper, lhsBlock.packed_data, \ + rhsBlock.packed_data, rows, depth, cols, alpha); \ + } \ + } else { \ + GebpKernel()( \ + output_mapper, lhsBlock.packed_data, rhsBlock.packed_data, rows, \ + depth, cols, alpha, \ + /*strideA*/ GemmKernel::kComputeStrideFromBlockDimensions, \ + /*strideB*/ GemmKernel::kComputeStrideFromBlockDimensions, \ + /*offsetA*/ 0, /*offsetB*/ 0); \ + } \ + } \ + \ + private: \ + /* These are dimensions of the original Tensors, and selected block */ \ + /* sizes. The actual block sizes passed to all function above might be */ \ + /* smaller because of the partial blocks at the end. */ \ + const StorageIndex m; \ + const StorageIndex k; \ + const StorageIndex n; \ + const StorageIndex bm; \ + const StorageIndex bk; \ + const StorageIndex bn; \ + /* Number of kernels for each dimension. */ \ + const StorageIndex nm0; \ + const StorageIndex nn0; \ + } + +// Tensor contraction kernel that do not fallback on Eigen. Currently not all +// data types are supported by Eigen data packing and default gebp_kernel. +#define REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(RES_SCALAR, LHS_SCALAR, \ + RHS_SCALAR) \ + \ + template \ + struct TensorContractionKernel { \ + TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, \ + StorageIndex bm, StorageIndex bk, StorageIndex bn) \ + : m(m), \ + k(k), \ + n(n), \ + bm(bm), \ + bk(bk), \ + bn(bn), \ + nm0(bm > 0 ? divup(m, bm) : 0), \ + nn0(bn > 0 ? 
divup(n, bn) : 0) {} \ + \ + using ResScalar = RES_SCALAR; \ + using LhsScalar = LHS_SCALAR; \ + using RhsScalar = RHS_SCALAR; \ + \ + using Traits = typename internal::gebp_traits; \ + \ + using LhsBlock = ColMajorBlock; \ + using RhsBlock = ColMajorBlock; \ + \ + using DirectLhsAccess = DirectColMajorAccess; \ + using DirectRhsAccess = DirectColMajorAccess; \ + \ + /* Packed Lhs/Rhs block memory allocator.*/ \ + typedef TensorContractionBlockMemAllocator \ + BlockMemAllocator; \ + typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle; \ + \ + using LhsPacker = \ + gemm_pack_colmajor_block; \ + using RhsPacker = \ + gemm_pack_colmajor_block; \ + \ + using GemmKernelProviderType = \ + GemmKernelProvider; \ + static_assert( \ + GemmKernelProviderType::Defined, \ + "Custom GEMM kernel is not registered for given scalar types"); \ + using GemmKernel = typename GemmKernelProviderType::GemmKernel; \ + \ + template \ + EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, \ + RhsBlock* rhs_block) { \ + return BlockMemAllocator::allocate( \ + d, bm, bk, bn, &lhs_block->packed_data, &rhs_block->packed_data); \ + } \ + \ + template \ + EIGEN_DEVICE_FUNC BlockMemHandle \ + allocateSlices(Device& d, const int num_lhs, const int num_rhs, \ + const int num_slices, std::vector* lhs_blocks, \ + std::vector* rhs_blocks) { \ + eigen_assert(num_slices > 0); \ + std::vector> lhs_mem(num_slices); \ + std::vector> rhs_mem(num_slices); \ + \ + BlockMemHandle block_mem = BlockMemAllocator::allocateSlices( \ + d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_mem.data(), \ + rhs_mem.data()); \ + \ + for (Index x = 0; x < num_slices; x++) { \ + if (num_lhs > 0) lhs_blocks[x].resize(num_lhs); \ + for (Index m = 0; m < num_lhs; m++) { \ + lhs_blocks[x][m].packed_data = lhs_mem[x][m]; \ + } \ + if (num_rhs > 0) rhs_blocks[x].resize(num_rhs); \ + for (Index n = 0; n < num_rhs; n++) { \ + rhs_blocks[x][n].packed_data = rhs_mem[x][n]; \ + } \ + } \ + \ + return block_mem; \ + } \ + \ + template \ + EIGEN_DEVICE_FUNC void deallocate(Device& d, BlockMemHandle handle) { \ + BlockMemAllocator::deallocate(d, handle); \ + } \ + \ + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs( \ + LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, \ + const StorageIndex depth, const StorageIndex rows) { \ + const bool is_direct_access = \ + DirectLhsAccess::value && \ + DirectLhsAccess::block(data_mapper, rows, depth, nn0, lhsBlock); \ + \ + if (!is_direct_access) { \ + lhsBlock->is_direct_access = false; \ + LhsPacker()(lhsBlock->packed_data, data_mapper, rows, depth); \ + } \ + } \ + \ + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs( \ + RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, \ + const StorageIndex depth, const StorageIndex cols) { \ + const bool is_direct_access = \ + DirectRhsAccess::value && \ + DirectRhsAccess::block(data_mapper, depth, cols, nm0, rhsBlock); \ + \ + if (!is_direct_access) { \ + rhsBlock->is_direct_access = false; \ + RhsPacker()(rhsBlock->packed_data, data_mapper, depth, cols); \ + } \ + } \ + \ + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke( \ + const OutputMapper& output_mapper, const LhsBlock& lhsBlock, \ + const RhsBlock& rhsBlock, const StorageIndex rows, \ + const StorageIndex depth, const StorageIndex cols, \ + const float alpha) { \ + if ((DirectLhsAccess::value && lhsBlock.is_direct_access) && \ + (DirectRhsAccess::value && rhsBlock.is_direct_access)) { \ + GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.raw_data, 
\ + rows, depth, cols, alpha, /*ldA=*/lhsBlock.stride, \ + /*ldB=*/rhsBlock.stride, \ + /*transposeA=*/lhsBlock.transpose, \ + /*transposeB=*/rhsBlock.transpose); \ + \ + } else if (DirectLhsAccess::value && lhsBlock.is_direct_access) { \ + GemmKernel()(output_mapper, lhsBlock.raw_data, rhsBlock.packed_data, \ + rows, depth, cols, alpha, /*ldA=*/lhsBlock.stride, \ + /*ldB=*/GemmKernel::kComputeStrideFromBlockDimensions, \ + /*transposeA=*/lhsBlock.transpose, /*transposeB=*/'N'); \ + \ + } else if (DirectRhsAccess::value && rhsBlock.is_direct_access) { \ + GemmKernel()(output_mapper, lhsBlock.packed_data, rhsBlock.raw_data, \ + rows, depth, cols, alpha, \ + /*ldA=*/GemmKernel::kComputeStrideFromBlockDimensions, \ + /*ldB=*/rhsBlock.stride, /*transposeA=*/'N', \ + /*transposeB=*/rhsBlock.transpose); \ + \ + } else { \ + GemmKernel()(output_mapper, lhsBlock.packed_data, \ + rhsBlock.packed_data, rows, depth, cols, alpha); \ + } \ + } \ + \ + private: \ + /* These are dimensions of the original Tensors, and selected block */ \ + /* sizes. The actual block sizes passed to all function above might be */ \ + /* smaller because of the partial blocks at the end. */ \ + const StorageIndex m; \ + const StorageIndex k; \ + const StorageIndex n; \ + const StorageIndex bm; \ + const StorageIndex bk; \ + const StorageIndex bn; \ + /* Number of kernels for each dimension. */ \ + const StorageIndex nm0; \ + const StorageIndex nn0; \ + } + +REGISTER_TENSOR_CONTRACTION_KERNEL_WITH_FALLBACK(float, float, float); +REGISTER_TENSOR_CONTRACTION_KERNEL_NO_FALLBACK(Eigen::QInt32, Eigen::QInt8, + Eigen::QInt8); + +#undef REGISTER_TENSOR_CONTRACTION_KERNEL + #endif // defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL) } // namespace internal diff --git a/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc b/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc index 0234c7006ea..86938938f83 100644 --- a/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc +++ b/tensorflow/core/kernels/eigen_mkldnn_contraction_kernel_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" #include "tensorflow/core/kernels/eigen_contraction_kernel.h" #include "tensorflow/core/platform/test.h" @@ -145,5 +146,62 @@ TEST(EigenMkldnnTest, MkldnnGemm) { } } +TEST(EigenMkldnnTest, MkldnnGemmQInt8) { + // Mkldnn pack and gemm are used only in Tensor contractions, and it's + // guaranteed that Tensors will have ColMajor layout. + static const int Options = ColMajor; + + using Tensor2dQInt8 = Eigen::Tensor; + using Tensor2dQInt32 = Eigen::Tensor; + + int m = internal::random(1, 1000); + int n = internal::random(1, 1000); + int k = internal::random(1, 1000); + + Tensor2dQInt8 lhs(m, k); + lhs.setRandom(); + + Tensor2dQInt8 rhs(k, n); + rhs.setRandom(); + + Eigen::array, 1> contract_dims; + contract_dims[0].first = 1; + contract_dims[0].second = 0; + + Tensor2dQInt32 res = lhs.contract(rhs, contract_dims); + + // Compute matmul with Eigen::Matrix. We explicitly cast inputs to int32_t not + // to test QInt8->QInt32 type promotion during accumulation. 
+ using Matrix = Eigen::Matrix; + + Matrix lhs_mat(m, k); + Matrix rhs_mat(k, n); + + for (int i = 0; i < m; ++i) { + for (int j = 0; j < k; ++j) { + lhs_mat(i, j) = static_cast(lhs(i, j)); + } + } + + for (int i = 0; i < k; ++i) { + for (int j = 0; j < n; ++j) { + rhs_mat(i, j) = static_cast(rhs(i, j)); + } + } + + Matrix matmul_result(m, n); + matmul_result.setZero(); + matmul_result = lhs_mat * rhs_mat; + + // Verify that results are equal. + for (Index i = 0; i < m; ++i) { + for (Index j = 0; j < n; ++j) { + Scalar gemm = res(i, j); + Scalar matmul = matmul_result(i, j); + EXPECT_EQ(gemm, matmul); + } + } +} + } // namespace internal } // namespace Eigen diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h b/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h new file mode 100644 index 00000000000..324e7ac58bd --- /dev/null +++ b/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h @@ -0,0 +1,1765 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_ + +// Note this header is used in both TF and TFLite. +namespace Eigen { + +namespace internal { + +// TensorEvaluatorHasPartialPacket +// provides `value` that is true if TensorEvaluatorType has `PacketType +// partialPacket(IndexType, unpacket_traits::mask_t) +// const` and if the PacketType supports masked load. +// +// Partial packets are used to: +// +// 1) Split the packet over two columns and use partial loads for each +// individual part before combining them to get the required packet. This +// class is used to pick the correct implementation of loadPacketStandard +// function below. +// +// 2) Finalize packing of columns in gemm_pack_colmajor after processing +// vectorized part with full packets (see eigen_spatiual_convolutions.h). +template +class TensorEvaluatorHasPartialPacket { + public: + template + static auto functionExistsSfinae( + typename std::enable_if< + unpacket_traits::masked_load_available && + std::is_same() + .template partialPacket( + std::declval(), + std::declval::mask_t>()))>::value>:: + type*) -> std::true_type; + + template + static auto functionExistsSfinae(...) -> std::false_type; + + typedef decltype( + functionExistsSfinae( + nullptr)) status; + + static const bool value = status::value; +}; + +// Compute a mask for loading/storing coefficients in/from a packet in a +// [from, to) range. If the mask bit is 1, element will be loaded/stored. 
+template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + typename std::enable_if::masked_load_available, + typename unpacket_traits::mask_t>::type + mask(int from, int to) { + const Index packet_size = internal::unpacket_traits::size; + eigen_assert(0 <= from && to <= (packet_size + 1) && from < to); + + using Mask = typename internal::unpacket_traits::mask_t; + const Mask mask_max = std::numeric_limits::max(); + + return (mask_max >> (packet_size - to)) ^ (mask_max >> (packet_size - from)); +} + +// WARNING: Most of the code here implicitly assumes that the matrix is in +// ColMajor layout. This is guaranteed by the tensor contraction (see +// TensorContraction.h). +// +// Inside Eigen a tensor contraction is represented by a matrix multiplication. +// We don't want to actually extract image patches and reshape the result into +// a matrix (this involves allocating huge extra memory), so the patch +// extraction and reshape operations are implicit. +// +// TensorContractionInputMapper takes a matrix index and returns the coefficient +// (or the packet) of the "virtual tensor", that would be at that index if we +// were to actually reshape the result of patch extraction. +// +// TensorContractionSubMapper provides a similar view into the "virtual matrix" +// at the given vertical and horizontal offsets. +// +// "Virtual matrix" dimensions: +// *0: kernelChannels * kernelRows * kernelCols; +// 1: out_height * out_width; * OTHERS (e.g batches, etc...) +// +// *) extracted patches are continuous in memory (innermost dimension assuming +// col major layout) +// +// With this dimensions: +// row - offset within a single patch (in code: patchId) +// col - index of the extracted patch (in code: patchIndex) +// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions) +// +// TODO(ezhulenev): Consolidate this part of the code with the image patch +// extraction code since they are both very similar. 
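As a concrete illustration of the row coordinate ("patchId") described above: `loadCoeffStandard()` further down decomposes a row index into a channel, a row offset, and a column offset within the patch. A self-contained sketch with made-up dimensions (3 channels, a 4x5 patch); `DecomposePatchId` is an illustrative helper, not part of the patch:

```c++
#include <cassert>

// Decompose a "virtual matrix" row index (patchId) into the offset within a
// single extracted patch, mirroring the arithmetic in loadCoeffStandard():
//   patchOffset = patchId / patch_depth
//   colOffset   = patchOffset / patch_rows
//   rowOffset   = patchOffset - colOffset * patch_rows
//   depth       = patchId - patchOffset * patch_depth
// Illustrative helper only.
struct PatchCoords {
  long depth;       // channel within the patch
  long row_offset;  // row within the patch
  long col_offset;  // column within the patch
};

inline PatchCoords DecomposePatchId(long patch_id, long patch_depth,
                                    long patch_rows) {
  const long patch_offset = patch_id / patch_depth;
  const long col_offset = patch_offset / patch_rows;
  return {patch_id - patch_offset * patch_depth,
          patch_offset - col_offset * patch_rows, col_offset};
}

int main() {
  // 3 input channels and a 4x5 patch: each column of the virtual matrix has
  // 3 * 4 * 5 = 60 rows. Row 17 is channel 2 of patch element (row 1, col 1).
  const PatchCoords c = DecomposePatchId(/*patch_id=*/17, /*patch_depth=*/3,
                                         /*patch_rows=*/4);
  assert(c.depth == 2 && c.row_offset == 1 && c.col_offset == 1);
  return 0;
}
```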
+ +template +class TensorContractionInputMapper< + Scalar_, Index, Side, + TensorEvaluator< + const TensorReshapingOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef Scalar_ Scalar; + + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + + typedef SubMapper VectorMapper; + typedef SubMapper LinearMapper; + typedef typename packet_traits::type Packet; + + typedef TensorEvaluator TensorEvaluatorT; + + EIGEN_DEVICE_FUNC + TensorContractionInputMapper( + const TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>& tensor, + const nocontract_t&, const nocontract_t&, const contract_t&, + const contract_t&) + : m_impl(tensor.impl().impl()) { + Index patch_rows; + Index patch_depth; + if (internal::traits::Layout == ColMajor) { + patch_depth = tensor.impl().dimensions()[0]; + patch_rows = tensor.impl().dimensions()[1]; + m_patch_cols = tensor.impl().dimensions()[2]; + m_num_patches = tensor.impl().dimensions()[3]; + } else { + const size_t NumDims = tensor.impl().dimensions().size(); + patch_depth = tensor.impl().dimensions()[NumDims - 1]; + patch_rows = tensor.impl().dimensions()[NumDims - 2]; + m_patch_cols = tensor.impl().dimensions()[NumDims - 3]; + m_num_patches = tensor.impl().dimensions()[NumDims - 4]; + } + + // Strides for navigating through the single patch. 
+ m_patch_row_stride = patch_depth; + m_patch_col_stride = patch_rows * m_patch_row_stride; + + m_patch_row_inflate_strides = tensor.impl().rowInflateStride(); + m_patch_col_inflate_strides = tensor.impl().colInflateStride(); + + m_colStride = patch_rows; + + m_outputRows = tensor.impl().outputRows(); + m_row_strides = tensor.impl().userRowStride(); + m_col_strides = tensor.impl().userColStride(); + + m_in_row_strides = tensor.impl().userInRowStride(); + m_in_col_strides = tensor.impl().userInColStride(); + + if (internal::traits::Layout == ColMajor) { + m_inputRows = tensor.impl().impl().dimensions()[1]; + m_inputCols = tensor.impl().impl().dimensions()[2]; + } else { + const int NumDims = tensor.impl().impl().dimensions().size(); + m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2]; + m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3]; + } + + m_rowInputStride = patch_depth; + m_colInputStride = patch_depth * m_inputRows; + m_patchInputStride = patch_depth * m_inputRows * m_inputCols; + + m_rowPaddingTop = tensor.impl().rowPaddingTop(); + m_colPaddingLeft = tensor.impl().colPaddingLeft(); + + m_fastPatchRowStride = + internal::TensorIntDivisor(m_patch_row_stride); + m_fastPatchColStride = + internal::TensorIntDivisor(m_patch_col_stride); + m_fastInputRowStride = + internal::TensorIntDivisor(m_patch_row_inflate_strides); + m_fastInputColStride = + internal::TensorIntDivisor(m_patch_col_inflate_strides); + m_fastNumPatches = internal::TensorIntDivisor(m_num_patches); + m_fastColStride = internal::TensorIntDivisor(m_colStride); + m_fastOutputRows = internal::TensorIntDivisor(m_outputRows); + m_fastDimZero = internal::TensorIntDivisor(patch_depth); + } + + EIGEN_DEVICE_FUNC + TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) + : m_impl(base_mapper.m_impl) { + m_patch_cols = base_mapper.m_patch_cols; + m_num_patches = base_mapper.m_num_patches; + + m_patch_row_stride = base_mapper.m_patch_row_stride; + m_patch_col_stride = base_mapper.m_patch_col_stride; + + m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides; + m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides; + + m_colStride = base_mapper.m_colStride; + + m_rowInputStride = base_mapper.m_rowInputStride; + m_colInputStride = base_mapper.m_colInputStride; + m_patchInputStride = base_mapper.m_patchInputStride; + + m_inputRows = base_mapper.m_inputRows; + m_inputCols = base_mapper.m_inputCols; + + m_outputRows = base_mapper.m_outputRows; + m_row_strides = base_mapper.m_row_strides; + m_col_strides = base_mapper.m_col_strides; + + m_in_row_strides = base_mapper.m_in_row_strides; + m_in_col_strides = base_mapper.m_in_col_strides; + + m_rowPaddingTop = base_mapper.m_rowPaddingTop; + m_colPaddingLeft = base_mapper.m_colPaddingLeft; + + m_fastPatchRowStride = base_mapper.m_fastPatchRowStride; + m_fastPatchColStride = base_mapper.m_fastPatchColStride; + m_fastInputRowStride = base_mapper.m_fastInputRowStride; + m_fastInputColStride = base_mapper.m_fastInputColStride; + m_fastNumPatches = base_mapper.m_fastNumPatches; + m_fastColStride = base_mapper.m_fastColStride; + m_fastOutputRows = base_mapper.m_fastOutputRows; + m_fastDimZero = base_mapper.m_fastDimZero; + } + + // If true, turns off some optimizations for loading packets since the image + // patches are "non-standard" such as there are non-trivial strides or + // inflations in the input. 
+ EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_in_row_strides != 1 || m_in_col_strides != 1 || + m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { + return SubMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { + return LinearMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { + Index rowIndex, colIndex, otherIndex; + computeBaseIndices(0, rowIndex, colIndex, otherIndex); + return loadCoeff(row, rowIndex, colIndex, otherIndex); + } + + // Load the coefficient at the patchIndex location instead of the usual + // m_rowIndex, + // m_colIndex, m_otherIndex. This is currently only used by the gpu code. + // EIGEN_DEVICE_FUNC + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { + Index rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex); + return loadCoeff(row, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { + Index rowIndex, colIndex, otherIndex; + computeBaseIndices(0, rowIndex, colIndex, otherIndex); + return loadPacket(row, rowIndex, colIndex, otherIndex); + } + + // Load the packet at the patchIndex location instead of the usual m_rowIndex, + // m_colIndex, m_otherIndex. This is currently only used by the gpu code. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { + Index rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE const TensorEvaluator& impl() const { + return m_impl; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } + + private: + friend class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>; + + // Load coefficient from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, + Index colIndex, Index otherIndex) const { + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset * m_in_col_strides; + const Index origInputCol = + (m_patch_col_inflate_strides == 1) + ? inputCol + : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); + + const Index rowOffset = patchOffset - colOffset * m_colStride; + const Index inputRow = rowIndex + rowOffset * m_in_row_strides; + const Index origInputRow = + (m_patch_row_inflate_strides == 1) + ? inputRow + : ((inputRow >= 0) ? 
(inputRow / m_fastInputRowStride) : 0); + if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols || + origInputRow >= m_inputRows || + (inputCol != origInputCol * m_patch_col_inflate_strides) || + (inputRow != origInputRow * m_patch_row_inflate_strides)) { + return Scalar(0); + } + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + origInputRow * m_rowInputStride + + origInputCol * m_colInputStride + otherIndex; + return m_impl.coeff(inputIndex); + } + + // This is the same as loadCoeff(...), but optimized for all `inflate_strides` + // and `in_strides` equal to 1 (template specialization without templates). + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, + Index colIndex, + Index otherIndex) const { + eigen_assert(!nonStandardPatches()); + + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + const Index colOffset = patchOffset / m_fastColStride; + const Index rowOffset = patchOffset - colOffset * m_colStride; + const Index inputCol = colIndex + colOffset; + const Index inputRow = rowIndex + rowOffset; + if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || + inputRow >= m_inputRows) { + return Scalar(0); + } + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + return m_impl.coeff(inputIndex); + } + + // Load packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, + Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); + + if (nonStandardPatches()) { + return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); + } + typedef decltype(m_impl) TensorEvaluatorT; + return loadPacketStandard(patchId, rowIndex, + colIndex, otherIndex); + } + + // Helper function to load a 'partial' packet - this is the single column + // part of a packet that is split across two columns. In the 'partial' packet, + // the elements corresponding to the column (specified through colOffset) are + // loaded and the rest of the elements are zero-filled into the 'partial' + // packet. This function is called from loadPacketStandardFromTwoColumns(). + // This code path is exercied only when the packet type supports masked load + // and when the partial packet load is available in the TensorEvaluator. 
+ EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard( + Index rowIndex, Index colIndex, Index otherIndex, Index patchId, + const Index span[], const Index patchOffsets[], Index colOffset) const { + const Index inputCol = colIndex + colOffset; + const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride, + patchOffsets[1] - colOffset * m_colStride}; + const Index inputRows[2] = {rowIndex + rowOffsets[0], + rowIndex + rowOffsets[1]}; + + if (inputRows[0] >= m_inputRows || inputRows[1] < 0 || + inputCol >= m_inputCols || inputCol < 0) { + // Partial packet is all zeros + return internal::pset1(Scalar(0)); + } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) { + // From inputIndex-span[0], we need to load elements starting from index + // span[0] all the way upto (and including) span[1]. + const Index depth = patchId - patchOffsets[0] * patchDepth(); + const Index inputIndex = depth + inputRows[0] * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + return m_impl.template partialPacket( + inputIndex - span[0], mask(span[0], span[1] + 1)); + } else { + // Using slow path for this partial packet. + // We need to load elements starting from index span[0] all the way upto + // (and including) span[1]. We split this load into 3 parts: + // 0 : span[0]-1 - Zeros will be loaded for these indices + // span[0] : span[1] - Elements will be loaded here for these indices + // span[1]+1 : packetSize-1 - Zeross will be loaded for these indices + const Index packetSize = internal::unpacket_traits::size; + EIGEN_ALIGN_MAX + typename internal::remove_const::type values[packetSize]; + for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0); + for (int i = span[0]; i < span[1] + 1; ++i) + values[i] = + loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex); + for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0); + return internal::pload(values); + } + } + + // Helper function to load a packet that is split across two columns. + // If required, this function is called from loadPacketStandard() when the + // packet type supports masked load and when the partial packet load is + // available in the TensorEvaluator. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns( + Index patchId, Index rowIndex, Index colIndex, Index otherIndex, + const Index patchOffsets[], const Index colOffsets[]) const { + eigen_assert(colOffsets[1] == colOffsets[0] + 1); + const Index packetSize = internal::unpacket_traits::size; + + // Packet to load will be split into 2 parts where each part spans a single + // column. First determine where to split. 
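+    // Illustrative example (hypothetical values, not from the original code):
+    // with patch_depth = 3, patch_rows = 5 and a packet of 8 elements starting
+    // at patchId = 13, patchOffsets = {4, 6} and colOffsets = {0, 1}, so the
+    // packet straddles two columns. Then patchIdSplit = (1 * 5) * 3 - 1 = 14;
+    // the first partial packet covers lanes 0..1 (patchIds 13..14, column 0)
+    // and the second covers lanes 2..7 (patchIds 15..20, column 1).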
+ const Index patchIdSplit = + ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1; + const Index patchOffsetSplit = patchIdSplit / m_fastDimZero; + + // patchIds[i]: patchId corresponding to partial packet i + // spans[i]: Start and end indices corresponding to the elements + // to be loaded for partial packet i + // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i + const Index patchIds[2] = {patchId, patchIdSplit + 1}; + const Index spans[2][2] = {{0, patchIdSplit - patchId}, + {patchIdSplit - patchId + 1, packetSize - 1}}; + const Index patchOffsets2Cols[2][2] = { + {patchOffsets[0], patchOffsetSplit}, + {patchOffsetSplit + 1, patchOffsets[1]}}; + + // Load partial packets and do bit-wise OR to generate required packet + return internal::por( + loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0], + spans[0], patchOffsets2Cols[0], + colOffsets[0]), + loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1], + spans[1], patchOffsets2Cols[1], + colOffsets[1])); + } + + // Helper function to load a packet that is present in a single columns. + // If required, this function is called from loadPacketStandard(). + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn( + Index patchId, Index rowIndex, Index colIndex, Index otherIndex, + const Index patchOffsets[], const Index colOffsets[], + const Index inputCols[]) const { + eigen_assert(colOffsets[0] == colOffsets[1]); + const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride, + patchOffsets[1] - colOffsets[1] * m_colStride}; + eigen_assert(rowOffsets[0] <= rowOffsets[1]); + const Index inputRows[2] = {rowIndex + rowOffsets[0], + rowIndex + rowOffsets[1]}; + + if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { + // all zeros + return internal::pset1(Scalar(0)); // all zeros + } + + if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) { + // no padding + const Index depth = patchId - patchOffsets[0] * patchDepth(); + const Index inputIndex = depth + inputRows[0] * m_rowInputStride + + inputCols[0] * m_colInputStride + otherIndex; + return m_impl.template packet(inputIndex); + } + return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); + } + + // Load standard packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + // This function will be called if partial packet loading is not available + // for the TesnorEvaluator or if the packet type does not support masked + // load. + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< + !TensorEvaluatorHasPartialPacket::value, + PacketT>::type + loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); + + eigen_assert(!nonStandardPatches()); + + if ((patchDepth() % packetSize) == 0) { + return loadPacketFast(patchId, rowIndex, colIndex, otherIndex); + } + + // Offsets and input calculation here are identical to + // loadCoeffStandard(...), but repeated twice. 
+ const Index patchOffsets[2] = {patchId / m_fastDimZero, + (patchId + packetSize - 1) / m_fastDimZero}; + const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, + patchOffsets[1] / m_fastColStride}; + const Index inputCols[2] = {colIndex + colOffsets[0], + colIndex + colOffsets[1]}; + + if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { + // all zeros + return internal::pset1(Scalar(0)); + } + if (inputCols[0] == inputCols[1]) { + return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, + otherIndex, patchOffsets, + colOffsets, inputCols); + } + return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); + } + + // Load standard packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + // This function will be called if partial packet loading is available for + // the TesnorEvaluator and if the packet type supports masked load. + // The only difference between this and the other case is that if the packet + // to load is split across two columns, then in this case instead of going to + // the slow (element-by-element) load, we load two packets - each containing + // elements from one of the columns (rest of the elements of the packets are + // zeroes), and then combine these two packets to generate the required + // packet. The idea is to enable fast load (if possible) of these 'partial' + // packets. + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< + TensorEvaluatorHasPartialPacket::value, + PacketT>::type + loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); + + eigen_assert(!nonStandardPatches()); + + if ((patchDepth() % packetSize) == 0) { + return loadPacketFast(patchId, rowIndex, colIndex, otherIndex); + } + + // Offsets and input calculation here are identical to + // loadCoeffStandard(...), but repeated twice. + const Index patchOffsets[2] = {patchId / m_fastDimZero, + (patchId + packetSize - 1) / m_fastDimZero}; + const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, + patchOffsets[1] / m_fastColStride}; + const Index inputCols[2] = {colIndex + colOffsets[0], + colIndex + colOffsets[1]}; + + if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { + // all zeros + return internal::pset1(Scalar(0)); + } + if (inputCols[0] == inputCols[1]) { + return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, + otherIndex, patchOffsets, + colOffsets, inputCols); + } + if (inputCols[1] == inputCols[0] + 1) { + return loadPacketStandardFromTwoColumns( + patchId, rowIndex, colIndex, otherIndex, patchOffsets, colOffsets); + } + return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, + Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); + + eigen_assert(!nonStandardPatches()); + eigen_assert((patchDepth() % packetSize) == 0); + // Find the offset of the element wrt the location of the first element. 
+ const Index patchOffset = patchId / m_fastDimZero; + eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); + + const Index colOffset = patchOffset / m_fastColStride; + const Index rowOffset = patchOffset - colOffset * m_colStride; + const Index inputCol = colIndex + colOffset; + const Index inputRow = rowIndex + rowOffset; + if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || + inputRow >= m_inputRows) { + // all zeros + return internal::pset1(Scalar(0)); + } + // no padding + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + return m_impl.template packet(inputIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero( + Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const { + const int packetSize = internal::unpacket_traits::size; + EIGEN_ALIGN_MAX + typename internal::remove_const::type values[packetSize]; + for (int i = 0; i < packetSize; ++i) { + values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex); + } + Packet rslt = internal::pload(values); + return rslt; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( + Index patchIndex, Index& rowIndex, Index& colIndex, + Index& otherIndex) const { + const size_t NumInputDims = array_size< + typename TensorEvaluator::Dimensions>::value; + otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches; + const Index patch2DIndex = (NumInputDims == 3) + ? patchIndex + : (patchIndex - otherIndex * m_num_patches); + otherIndex *= m_patchInputStride; + colIndex = patch2DIndex / m_fastOutputRows; + rowIndex = patch2DIndex - colIndex * m_outputRows; + colIndex = colIndex * m_col_strides - m_colPaddingLeft; + rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; + } + + Index m_patch_cols; // number of columns in the patch + Index m_num_patches; // number of patches to extract. + + // Strides for navigating through the single patch. + Index m_patch_row_stride; + Index m_patch_col_stride; + internal::TensorIntDivisor m_fastPatchRowStride; + internal::TensorIntDivisor m_fastPatchColStride; + + Index m_patch_row_inflate_strides; // the strides for row inflation in the + // image patch + Index m_patch_col_inflate_strides; // the strides for col inflation in the + // image patch + // Fast representation of inflation strides. 
+ internal::TensorIntDivisor m_fastInputRowStride; + internal::TensorIntDivisor m_fastInputColStride; + + Index m_otherStride; + Index m_colStride; + internal::TensorIntDivisor m_fastNumPatches; + internal::TensorIntDivisor m_fastColStride; + + Index m_rowInputStride; // row stride in the input tensor + Index m_colInputStride; // col stride in the input tensor + Index m_patchInputStride; // patch stride in the input tensor + + Index m_inputRows; // Number of rows in the input tensor + Index m_inputCols; // Number of cols in the input tensor + + Index m_outputRows; // Number of patch rows + + Index m_row_strides; // User specified row stride + Index m_col_strides; // User specified col stride + + Index m_in_row_strides; // User specified input row stride + Index m_in_col_strides; // User specified input col stride + + Index m_rowPaddingTop; // Row padding + Index m_colPaddingLeft; // Column padding + + internal::TensorIntDivisor m_fastOutputRows; + internal::TensorIntDivisor m_fastDimZero; + + const TensorEvaluator m_impl; +}; + +template +class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator< + const TensorReshapingOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef typename packet_traits::type Packet; + typedef typename packet_traits::half HalfPacket; + + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + ParentMapper; + + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + + typedef Self LinearMapper; + + typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) + : m_depth_offset(vert_offset), + m_col_offset(horiz_offset), + m_base_mapper(base_mapper) { + m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, + m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const Self& base_mapper, Index vert_offset, Index horiz_offset) + : m_depth_offset(vert_offset + base_mapper.m_depth_offset), + m_col_offset(horiz_offset + base_mapper.m_col_offset), + m_base_mapper(base_mapper.m_base_mapper) { + m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, + m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { + return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, + m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, + Index j) const { + return m_base_mapper(i + m_depth_offset, j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { + return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, + m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, + Index j) const { + return m_base_mapper.template loadPacket(i + m_depth_offset, + j + m_col_offset); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar + loadCoeffStandard(Index i) const { + return m_base_mapper.loadCoeffStandard(i + 
m_depth_offset, m_rowIndex, + m_colIndex, m_otherIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { + return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet + loadPacketStandard(Index i) const { + typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT; + return m_base_mapper.template loadPacketStandard( + i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex); + } + template + EIGEN_DEVICE_FUNC bool aligned(Index) const { + return false; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_base_mapper.nonStandardPatches(); + } + + // Max(Col|Row|Depth): compute the upper limit for the column, row and depth + // index respectively that fits into the peeled_k elements starting at + // m_depth_offset. + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { + const Index max_col = + (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) / + fastPatchColStride(); + return std::min(1 + max_col, patchCols()); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, + const Index col) const { + const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) - + col * patchColStride()) / + fastPatchRowStride(); + return std::min(1 + max_row, patchRows()); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col, + Index row) const { + const Index max_depth = m_depth_offset + peeled_k - // + col * patchColStride() - // + row * patchRowStride(); + return std::min(max_depth, patchDepth()); + } + + // MaxDepth uses only the remaining number of elements in the peeled_k. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, + const Index start_depth) const { + return std::min(start_depth + num_elements, patchDepth()); + } + + // Every register matters in this code, so sometimes to prevent register + // spilling, instead of the variable that you would expect to see, we use + // another one, that is guaranteed to have the same value. E.g. patch depth is + // always the same as input depth, and it's also the same as input row stride. + // Bunch of other parameters have similar relations. 
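The `colOffset()`, `rowOffset()` and `depthOffset()` accessors defined further down in this sub-mapper decompose the linear `m_depth_offset` into patch coordinates, with depth innermost, then rows, then columns. A minimal standalone sketch of that arithmetic, using plain integer division in place of `TensorIntDivisor` (all names and values here are illustrative, not from the header):

```cpp
#include <cassert>

struct PatchCoords {
  long col, row, depth;
};

// Decompose a linear offset within a patch into (column, row, depth)
// coordinates: depth is the innermost dimension, then rows, then columns.
PatchCoords Decompose(long linear_offset, long patch_depth, long patch_rows) {
  const long patch_offset = linear_offset / patch_depth;
  PatchCoords c;
  c.col = patch_offset / patch_rows;
  c.row = patch_offset - c.col * patch_rows;
  c.depth = linear_offset % patch_depth;
  return c;
}

int main() {
  // patch_depth = 3 and patch_rows = 5, so one patch column holds 15 values.
  const PatchCoords c = Decompose(/*linear_offset=*/37, /*patch_depth=*/3,
                                  /*patch_rows=*/5);
  assert(c.col == 2 && c.row == 2 && c.depth == 1);  // 37 = (2*5 + 2)*3 + 1
  return 0;
}
```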
+ + typedef internal::TensorIntDivisor IndexDivisor; + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { + return m_base_mapper.m_rowInputStride; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { + return m_base_mapper.m_colStride; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { + return m_base_mapper.m_patch_cols; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRowStride() const { + eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride && + "Patch depth must be equal to patch row stride."); + return patchDepth(); + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchColStride() const { + return m_base_mapper.m_patch_col_stride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { + eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride && + "Patch depth must be equal to patch row stride."); + return m_base_mapper.m_fastDimZero; // patch_depth + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { + return m_base_mapper.m_fastPatchColStride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_base_mapper.m_impl.template packet(inputIndex); + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_base_mapper.m_impl.coeff(inputIndex); + } + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< + TensorEvaluatorHasPartialPacket::value, + PacketT>::type + partialPacketNoPadding(const Index depth, const Index baseIndex, + Index num_coeffs) const { + const Index inputIndex = depth + baseIndex; + return m_base_mapper.m_impl.template partialPacket( + inputIndex, mask(0, num_coeffs)); + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { + const Index r = m_rowIndex + row; + return r < 0 || r >= m_base_mapper.m_inputRows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row, + const Index last_row) const { + return m_rowIndex + first_row < 0 || + m_rowIndex + last_row >= m_base_mapper.m_inputRows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { + const Index c = m_colIndex + col; + return c < 0 || c >= m_base_mapper.m_inputCols; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const { + const Index r = m_rowIndex + row; + const Index c = m_colIndex + col; + return r * m_base_mapper.m_rowInputStride + + c * m_base_mapper.m_colInputStride + m_otherIndex; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index rowStride() const { + return m_base_mapper.m_row_strides; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index colStride() const { + return m_base_mapper.m_col_strides; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index rowOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + return patchOffset - colOffset * m_base_mapper.m_colStride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index colOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + return colOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index 
depthOffset() const { + return m_depth_offset % patchDepth(); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper + getLinearMapper(Index i, Index j) const { + return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset); + } + + private: + Index m_depth_offset; // First row in the input matrix + Index m_col_offset; // First col in the input matrix + + // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base + // indices for the first element in a patch specified by col_offset + // (see computeBaseIndices(...) for details). + Index m_rowIndex; + Index m_colIndex; + Index m_otherIndex; + + const ParentMapper m_base_mapper; // Keeping a copy instead of a reference + // performs better in benchmarks. +}; + +// Arrange a block of the right input matrix (in our case it's always a "virtual +// matrix" constructed from extracted image patches) in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0 +// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1 +// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2 +// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3 +// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4 +// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5 +// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6 +// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7 +// A8 ... +// ... +// +// *) A, B, C, ... - patches extracted from the original input. +// *) A0, A1, A2 ... - values from the same patch at different offsets. +// +// The traversal (packed rhs memory) order (B0 besides A0 in memory): +// A0 B0 C0 D0 A1 B1 C1 D1 ... +// E0 F0 G0 H0 E1 F1 G1 H1 ... +// ... +// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4) +// +// This traversal order must be the same as in default gemm_pack_rhs defined in +// GeneralBlockPanelKernel.h. +// +// *) nr - number of registers along the 'n' dimension. +// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix +// Multiplication" paper. +template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper DataMapper; + typedef typename packet_traits::type Packet; + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if ((packet_size % 4) == 0 && !non_standard_patches) { + // FAST PATH: + // Iterate over patch columns and rows, if we know that a single + // packet do not span across multiple rows or columns. 
+ if ((rhs.patchDepth() % packet_size) == 0) { + const Index start_col = rhs.colOffset(); + const Index max_col = rhs.maxCol(peeled_k); + + for (Index c = start_col; c < max_col; ++c) { + eigen_assert(k <= peeled_k); + + const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; + const Index max_row = rhs.maxRow(peeled_k, c); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + // Check if we can squeeze reads along the `row` and `depth` + // dimensions (two innermost dimensions). + if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && // + !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && // + !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && // + !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && // + !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) { + // Compute how many elements we can squeeze read. + const Index start_depth = + (c == start_col) ? rhs.depthOffset() : 0; + + // Upper bound for the number of elements in the depth dimension + // that we can squeeze read. + const Index squeeze_length = + (max_row - start_row) * rhs.patchDepth() - start_depth; + + // Do not overshoot beyond the block size. + const Index max_depth = + start_depth + std::min(peeled_k - k, squeeze_length); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + const Index idx0 = dm0.baseIndex(start_row, c); + const Index idx1 = dm1.baseIndex(start_row, c); + const Index idx2 = dm2.baseIndex(start_row, c); + const Index idx3 = dm3.baseIndex(start_row, c); + + for (Index d = start_depth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock kernel; + kernel.packet[0] = rhs.packetNoPadding(d, idx0); + kernel.packet[1] = rhs.packetNoPadding(d, idx1); + kernel.packet[2] = rhs.packetNoPadding(d, idx2); + kernel.packet[3] = rhs.packetNoPadding(d, idx3); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + k += packet_size; + } + + // Go to the next column. + continue; + } + + // If we can't squeeze reads, process rows one by one. + for (Index r = start_row; r < max_row; ++r) { + eigen_assert(k <= peeled_k); + + const bool pad0 = pad_col0 || dm0.padRow(r); + const bool pad1 = pad_col1 || dm1.padRow(r); + const bool pad2 = pad_col2 || dm2.padRow(r); + const bool pad3 = pad_col3 || dm3.padRow(r); + + const Index idx0 = dm0.baseIndex(r, c); + const Index idx1 = dm1.baseIndex(r, c); + const Index idx2 = dm2.baseIndex(r, c); + const Index idx3 = dm3.baseIndex(r, c); + + const Index start_depth = ((c == start_col) && (r == start_row)) + ? rhs.depthOffset() + : 0; + const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + for (Index d = start_depth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock kernel; + kernel.packet[0] = pad0 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel.packet[1] = pad1 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel.packet[2] = pad2 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel.packet[3] = pad3 ? 
pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + k += packet_size; + } + } + } + + // The loop above should fill peeled_k elements. + eigen_assert(peeled_k == k); + + } else { + for (; k < peeled_k; k += packet_size) { + PacketBlock kernel; + kernel.packet[0] = dm0.loadPacketStandard(k); + kernel.packet[1] = dm1.loadPacketStandard(k); + kernel.packet[2] = dm2.loadPacketStandard(k); + kernel.packet[3] = dm3.loadPacketStandard(k); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + } + } + } + + // Copy the remaining coefficients of the column block after the peeled_k. + if (!rhs.nonStandardPatches()) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // copy the remaining columns one at a time (nr==1) + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Template specialization for packet_size = 2. We must special-case packet +// blocks with nr > packet_size, e.g. PacketBlock. +template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, + Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, + Alignment> + SubMapper; + typedef SubMapper DataMapper; + typedef typename packet_traits::type Packet; + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + const int packet_size = 2; + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if (!non_standard_patches) { + // FAST PATH: + // Iterate over patch columns and rows if we know that a single + // packet do not span across multiple rows or columns. 
+ if ((rhs.patchDepth() % packet_size) == 0) { + const Index start_col = rhs.colOffset(); + const Index max_col = rhs.maxCol(peeled_k); + + for (Index c = start_col; c < max_col; ++c) { + eigen_assert(k <= peeled_k); + + const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; + const Index max_row = rhs.maxRow(peeled_k, c); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + // We can squeeze reads along the `row` and `depth` dimensions if + // the row stride is `1`, which means that `row` and `depth` + // dimensions are contiguous (two innermost dimensions). + if (rhs.rowStride() == 1 && // + !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && // + !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && // + !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && // + !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && // + !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) { + // Compute how many elements we can squeeze read. + const Index start_depth = + (c == start_col) ? rhs.depthOffset() : 0; + + // Upper bound for the number of elements in the depth dimension + // that we can squeeze read. + const Index squeeze_length = + (max_row - start_row) * rhs.patchDepth() - start_depth; + + // Do not overshoot beyond the block size. + const Index max_depth = + start_depth + std::min(peeled_k - k, squeeze_length); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + const Index idx0 = dm0.baseIndex(start_row, c); + const Index idx1 = dm1.baseIndex(start_row, c); + const Index idx2 = dm2.baseIndex(start_row, c); + const Index idx3 = dm3.baseIndex(start_row, c); + + for (Index d = start_depth; d < max_depth; d += packet_size) { + PacketBlock kernel0; + PacketBlock kernel1; + kernel0.packet[0] = rhs.packetNoPadding(d, idx0); + kernel0.packet[1] = rhs.packetNoPadding(d, idx1); + kernel1.packet[0] = rhs.packetNoPadding(d, idx2); + kernel1.packet[1] = rhs.packetNoPadding(d, idx3); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + k += packet_size; + } + + // Go to the next column. + continue; + } + + // If we can't squeeze reads, process rows one by one. + for (Index r = start_row; r < max_row; ++r) { + eigen_assert(k <= peeled_k); + + const bool pad0 = pad_col0 || dm0.padRow(r); + const bool pad1 = pad_col1 || dm1.padRow(r); + const bool pad2 = pad_col2 || dm2.padRow(r); + const bool pad3 = pad_col3 || dm3.padRow(r); + + const Index idx0 = dm0.baseIndex(r, c); + const Index idx1 = dm1.baseIndex(r, c); + const Index idx2 = dm2.baseIndex(r, c); + const Index idx3 = dm3.baseIndex(r, c); + + const Index start_depth = ((c == start_col) && (r == start_row)) + ? rhs.depthOffset() + : 0; + const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + for (Index d = start_depth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock kernel0; + PacketBlock kernel1; + kernel0.packet[0] = pad0 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel0.packet[1] = pad1 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel1.packet[0] = pad2 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel1.packet[1] = pad3 ? 
pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + k += packet_size; + } + } + } + + // The loop above should fill peeled_k elements. + eigen_assert(peeled_k == k); + + } else { + // Packet can span multiple rows or columns, so we have to go + // though the slower "standard" path. + for (; k < peeled_k; k += packet_size) { + PacketBlock kernel0; + PacketBlock kernel1; + kernel0.packet[0] = dm0.loadPacketStandard(k); + kernel0.packet[1] = dm1.loadPacketStandard(k); + kernel1.packet[0] = dm2.loadPacketStandard(k); + kernel1.packet[1] = dm3.loadPacketStandard(k); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + } + } + } + + // Copy the remaining coefficients of the column block after the peeled_k. + if (!non_standard_patches) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // Copy the remaining columns one at a time (nr==1). + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Special case for non-vectorized types such as float16. 
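The specialization that follows handles non-vectorizable scalar types (such as float16) by copying one coefficient at a time. The interleaved layout it writes is the same one produced by the vectorized paths above: for nr == 4, coefficients of four neighbouring columns are interleaved as A0 B0 C0 D0 A1 B1 C1 D1 and so on. A standalone scalar sketch of that packing order (plain vectors stand in for the sub-mappers; all names and values are illustrative, not from the header):

```cpp
#include <cstdio>
#include <vector>

// Pack four "virtual matrix" columns coefficient by coefficient, producing
// the interleaved order A0 B0 C0 D0 A1 B1 C1 D1 ... expected by gemm kernels
// with nr == 4.
std::vector<float> PackRhsBlock(const std::vector<float>* cols, int depth) {
  std::vector<float> block;
  block.reserve(4 * depth);
  for (int k = 0; k < depth; ++k) {
    for (int c = 0; c < 4; ++c) block.push_back(cols[c][k]);
  }
  return block;
}

int main() {
  // Four columns (A, B, C, D) of the virtual patch matrix, depth 3.
  const std::vector<float> cols[4] = {{0.f, 1.f, 2.f},
                                      {10.f, 11.f, 12.f},
                                      {20.f, 21.f, 22.f},
                                      {30.f, 31.f, 32.f}};
  const std::vector<float> block = PackRhsBlock(cols, /*depth=*/3);
  for (float v : block) std::printf("%.0f ", v);
  // Prints: 0 10 20 30 1 11 21 31 2 12 22 32
  std::printf("\n");
  return 0;
}
```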
+template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, + Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator< + const TensorReshapingOp< + NewDimension, const TensorImagePatchOp >, + Device>, + nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, + Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + const Index packet_cols4 = (cols / 4) * 4; + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + if (!rhs.nonStandardPatches()) { + for (Index k = 0; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (Index k = 0; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // Copy the remaining columns one at a time (nr==1). + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; +} // end namespace internal + +/** SpatialConvolution + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a 2D convolution over a multichannel input image. + * + * The input parameter is expected to be a tensor with a rank of 3 or more + * (channels, height, width, and optionally others) + * The kernel parameter is expected to be a 4D tensor (filters, channels, + * kernel_height, kernel_width) + * The input and the kernel must both be in col-major layout. The result will + * also be in col-major layout. + * + * If col_in_stride, row_in_stride > 1, then applies convolution with holes + * (aka atrous convolution), sampling every col_in_stride, row_in_stride input + * pixels. + * + * If padding_top, padding_bottom, padding_left, or padding_right is specified, + * then those paddings will be used to pad the input, and padding_type must be + * PADDING_VALID. + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be filters, height, width (and + * others if applicable). + * + * It is possible to swap the order of the width and height dimensions provided + * that the same order is used in the input, the kernel, and the output. + * + * It is also possible to add an output kernel to the contraction, output + * kernel is called by Eigen when it "finalizes" the block of an output tensor. 
+ * + */ +template +EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE static const typename internal::conditional< + internal::traits::Layout == ColMajor, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp >, + const OutputKernel> >, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel>, + const OutputKernel> > >::type + SpatialConvolution(const Input& input, const Kernel& kernel, + const Index row_stride = 1, const Index col_stride = 1, + const PaddingType padding_type = PADDING_SAME, + const Index row_in_stride = 1, + const Index col_in_stride = 1, + const OutputKernel& output_kernel = OutputKernel(), + Index padding_top = 0, Index padding_bottom = 0, + Index padding_left = 0, Index padding_right = 0) { + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + kern(kernel); + + EIGEN_STATIC_ASSERT( + internal::traits::Layout == internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE) + const bool isColMajor = (internal::traits::Layout == ColMajor); + + const int NumDims = internal::traits::NumDimensions; + + // Number of filters to apply. This is the same as the output depth of the + // result + const TensorIndex kernelFilters = + isColMajor ? kern.dimensions()[0] : kern.dimensions()[3]; + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? kern.dimensions()[1] : kern.dimensions()[2]; + const TensorIndex kernelRows = + isColMajor ? kern.dimensions()[2] : kern.dimensions()[1]; + const TensorIndex kernelCols = + isColMajor ? kern.dimensions()[3] : kern.dimensions()[0]; + + const Index kernelRowsEff = + kernelRows + (kernelRows - 1) * (row_in_stride - 1); + const Index kernelColsEff = + kernelCols + (kernelCols - 1) * (col_in_stride - 1); + + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); + + const TensorIndex InputRows = + isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); + const TensorIndex InputCols = + isColMajor ? 
in.dimension(2) : in.dimension(NumDims - 3); + const bool padding_explicit = + (padding_top || padding_bottom || padding_left || padding_right); + + TensorIndex out_height; + TensorIndex out_width; + switch (padding_type) { + case PADDING_VALID: { + const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom; + const TensorIndex InputColsEff = InputCols + padding_left + padding_right; + out_height = numext::ceil((InputRowsEff - kernelRowsEff + 1.f) / + static_cast(row_stride)); + out_width = numext::ceil((InputColsEff - kernelColsEff + 1.f) / + static_cast(col_stride)); + break; + } + case PADDING_SAME: { + eigen_assert(!padding_explicit); + out_height = numext::ceil(InputRows / static_cast(row_stride)); + out_width = numext::ceil(InputCols / static_cast(col_stride)); + break; + } + default: { + // Initialize unused variables to avoid a compiler warning + out_height = 0; + out_width = 0; + eigen_assert(false && "unexpected padding"); + } + } + + // Molds the output of the patch extraction code into a 2d tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[1] = out_height * out_width; + for (int i = 3; i < NumDims; ++i) { + pre_contract_dims[1] *= in.dimension(i); + } + } else { + pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[0] = out_height * out_width; + for (int i = 0; i < NumDims - 3; ++i) { + pre_contract_dims[0] *= in.dimension(i); + } + } + + // Molds the output of the contraction into the shape expected by the used + // (assuming this is ColMajor): + // - 1st dim: kernel filters + // - 2nd dim: output height + // - 3rd dim: output width + // - 4th dim and beyond: everything else including batch size + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = out_height; + post_contract_dims[2] = out_width; + for (int i = 3; i < NumDims; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } else { + post_contract_dims[NumDims - 1] = kernelFilters; + post_contract_dims[NumDims - 2] = out_height; + post_contract_dims[NumDims - 3] = out_width; + for (int i = 0; i < NumDims - 3; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } + + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters; + kernel_dims[1] = kernelChannels * kernelRows * kernelCols; + } else { + kernel_dims[0] = kernelChannels * kernelRows * kernelCols; + kernel_dims[1] = kernelFilters; + } + if (padding_explicit) { + return choose( + Cond::Layout == ColMajor>(), + kernel.reshape(kernel_dims) + .contract(input + .extract_image_patches( + kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, + /*row_inflate_stride=*/1, + /*col_inflate_stride=*/1, padding_top, + padding_bottom, padding_left, padding_right, + /*padding_value=*/0) + .reshape(pre_contract_dims), + contract_dims, output_kernel) + .reshape(post_contract_dims), + input + .extract_image_patches(kernelRows, kernelCols, row_stride, + col_stride, row_in_stride, col_in_stride, + /*row_inflate_stride=*/1, + /*col_inflate_stride=*/1, padding_top, + padding_bottom, padding_left, padding_right, + /*padding_value=*/0) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel) + .reshape(post_contract_dims)); + } else { + return choose( + 
Cond::Layout == ColMajor>(), + kernel.reshape(kernel_dims) + .contract(input + .extract_image_patches( + kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims), + contract_dims, output_kernel) + .reshape(post_contract_dims), + input + .extract_image_patches(kernelRows, kernelCols, row_stride, + col_stride, row_in_stride, col_in_stride, + padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel) + .reshape(post_contract_dims)); + } +} + +} // end namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_ diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h index ca5f4b2d5e7..0127b65a7ef 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h @@ -18,1290 +18,57 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +// Note the following header is used in both TF and TFLite. Particularly, it's +// used for float TFLite Conv2D. +#include "tensorflow/core/kernels/eigen_spatial_convolutions-inl.h" + #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) #include "tensorflow/core/kernels/eigen_contraction_kernel.h" -#endif namespace Eigen { - namespace internal { -// WARNING: Most of the code here implicitly assumes that the matrix is in -// ColMajor layout. This is guaranteed by the tensor contraction (see -// TensorContraction.h). -// -// Inside Eigen a tensor contraction is represented by a matrix multiplication. -// We don't want to actually extract image patches and reshape the result into -// a matrix (this involves allocating huge extra memory), so the patch -// extraction and reshape operations are implicit. -// -// TensorContractionInputMapper takes a matrix index and returns the coefficient -// (or the packet) of the "virtual tensor", that would be at that index if we -// were to actually reshape the result of patch extraction. -// -// TensorContractionSubMapper provides a similar view into the "virtual matrix" -// at the given vertical and horizontal offsets. -// -// "Virtual matrix" dimensions: -// *0: kernelChannels * kernelRows * kernelCols; -// 1: out_height * out_width; * OTHERS (e.g batches, etc...) -// -// *) extracted patches are continuous in memory (innermost dimension assuming -// col major layout) -// -// With this dimensions: -// row - offset within a single patch (in code: patchId) -// col - index of the extracted patch (in code: patchIndex) -// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions) -// -// TODO(ezhulenev): Consolidate this part of the code with the image patch -// extraction code since they are both very similar. +// After we vectorized all loads from the underlying tensor using Packet ops, we +// have to finalize coefficients that do not fit into a packet. 
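A scalar model of the finalize step implemented by the helper defined next may help (illustrative plain C++; `FinalizeTail` and the values are made up and none of the Eigen types are used): after the vectorized loop has copied depths in packet-sized chunks, the remaining `max_depth - depth` coefficients are written one at a time, as explicit zeros when the column is padding.

```cpp
#include <cassert>
#include <vector>

// Copy the tail [depth, max_depth) that does not fill a whole packet,
// writing zeros when the source column is padding.
long FinalizeTail(std::vector<float>& block, const std::vector<float>& column,
                  long depth, long max_depth, bool pad) {
  const long num_coeffs = max_depth - depth;
  for (; depth < max_depth; ++depth) {
    block.push_back(pad ? 0.f : column[depth]);
  }
  return num_coeffs;  // number of coefficients written
}

int main() {
  const std::vector<float> column = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
  std::vector<float> block;
  // Suppose a packet of 4 already covered depths 0..3; finalize depths 4..6.
  const long written = FinalizeTail(block, column, /*depth=*/4, /*max_depth=*/7,
                                    /*pad=*/false);
  assert(written == 3 && block.size() == 3 && block[0] == 5.f);
  return 0;
}
```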
+template +struct FinalizeDataMapperCoeffs { + EIGEN_ALWAYS_INLINE static Index finalize(Scalar* block, + const DataMapper& rhs, + Index base_idx, Index depth, + Index max_depth, bool pad = false) { + const Index num_coeffs = max_depth - depth; + eigen_assert(num_coeffs <= packet_size); -template -class TensorContractionInputMapper< - Scalar_, Index, Side, - TensorEvaluator< - const TensorReshapingOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment> { - public: - typedef Scalar_ Scalar; - - typedef TensorContractionInputMapper< - Scalar, Index, Side, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment> - Self; - - typedef TensorContractionSubMapper< - Scalar, Index, Side, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment> - SubMapper; - - typedef SubMapper VectorMapper; - typedef SubMapper LinearMapper; - typedef typename packet_traits::type Packet; - - EIGEN_DEVICE_FUNC - TensorContractionInputMapper( - const TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>& tensor, - const nocontract_t&, const nocontract_t&, const contract_t&, - const contract_t&) - : m_impl(tensor.impl().impl()) { - Index patch_rows; - Index patch_depth; - if (internal::traits::Layout == ColMajor) { - patch_depth = tensor.impl().dimensions()[0]; - patch_rows = tensor.impl().dimensions()[1]; - m_patch_cols = tensor.impl().dimensions()[2]; - m_num_patches = tensor.impl().dimensions()[3]; - } else { - const size_t NumDims = tensor.impl().dimensions().size(); - patch_depth = tensor.impl().dimensions()[NumDims - 1]; - patch_rows = tensor.impl().dimensions()[NumDims - 2]; - m_patch_cols = tensor.impl().dimensions()[NumDims - 3]; - m_num_patches = tensor.impl().dimensions()[NumDims - 4]; + for (; depth < max_depth; ++depth) { + *block = pad ? Scalar(0) : rhs.coeffNoPadding(depth, base_idx); + ++block; } - // Strides for navigating through the single patch. 
- m_patch_row_stride = patch_depth; - m_patch_col_stride = patch_rows * m_patch_row_stride; - - m_patch_row_inflate_strides = tensor.impl().rowInflateStride(); - m_patch_col_inflate_strides = tensor.impl().colInflateStride(); - - m_colStride = patch_rows; - - m_outputRows = tensor.impl().outputRows(); - m_row_strides = tensor.impl().userRowStride(); - m_col_strides = tensor.impl().userColStride(); - - m_in_row_strides = tensor.impl().userInRowStride(); - m_in_col_strides = tensor.impl().userInColStride(); - - if (internal::traits::Layout == ColMajor) { - m_inputRows = tensor.impl().impl().dimensions()[1]; - m_inputCols = tensor.impl().impl().dimensions()[2]; - } else { - const int NumDims = tensor.impl().impl().dimensions().size(); - m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2]; - m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3]; - } - - m_rowInputStride = patch_depth; - m_colInputStride = patch_depth * m_inputRows; - m_patchInputStride = patch_depth * m_inputRows * m_inputCols; - - m_rowPaddingTop = tensor.impl().rowPaddingTop(); - m_colPaddingLeft = tensor.impl().colPaddingLeft(); - - m_fastPatchRowStride = - internal::TensorIntDivisor(m_patch_row_stride); - m_fastPatchColStride = - internal::TensorIntDivisor(m_patch_col_stride); - m_fastInputRowStride = - internal::TensorIntDivisor(m_patch_row_inflate_strides); - m_fastInputColStride = - internal::TensorIntDivisor(m_patch_col_inflate_strides); - m_fastNumPatches = internal::TensorIntDivisor(m_num_patches); - m_fastColStride = internal::TensorIntDivisor(m_colStride); - m_fastOutputRows = internal::TensorIntDivisor(m_outputRows); - m_fastDimZero = internal::TensorIntDivisor(patch_depth); - } - - EIGEN_DEVICE_FUNC - TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) - : m_impl(base_mapper.m_impl) { - m_patch_cols = base_mapper.m_patch_cols; - m_num_patches = base_mapper.m_num_patches; - - m_patch_row_stride = base_mapper.m_patch_row_stride; - m_patch_col_stride = base_mapper.m_patch_col_stride; - - m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides; - m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides; - - m_colStride = base_mapper.m_colStride; - - m_rowInputStride = base_mapper.m_rowInputStride; - m_colInputStride = base_mapper.m_colInputStride; - m_patchInputStride = base_mapper.m_patchInputStride; - - m_inputRows = base_mapper.m_inputRows; - m_inputCols = base_mapper.m_inputCols; - - m_outputRows = base_mapper.m_outputRows; - m_row_strides = base_mapper.m_row_strides; - m_col_strides = base_mapper.m_col_strides; - - m_in_row_strides = base_mapper.m_in_row_strides; - m_in_col_strides = base_mapper.m_in_col_strides; - - m_rowPaddingTop = base_mapper.m_rowPaddingTop; - m_colPaddingLeft = base_mapper.m_colPaddingLeft; - - m_fastPatchRowStride = base_mapper.m_fastPatchRowStride; - m_fastPatchColStride = base_mapper.m_fastPatchColStride; - m_fastInputRowStride = base_mapper.m_fastInputRowStride; - m_fastInputColStride = base_mapper.m_fastInputColStride; - m_fastNumPatches = base_mapper.m_fastNumPatches; - m_fastColStride = base_mapper.m_fastColStride; - m_fastOutputRows = base_mapper.m_fastOutputRows; - m_fastDimZero = base_mapper.m_fastDimZero; - } - - // If true, turns off some optimizations for loading packets since the image - // patches are "non-standard" such as there are non-trivial strides or - // inflations in the input. 
- EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { - return m_in_row_strides != 1 || m_in_col_strides != 1 || - m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; - } - - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { - return SubMapper(*this, i, j); - } - - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { - return LinearMapper(*this, i, j); - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { - Index rowIndex, colIndex, otherIndex; - computeBaseIndices(0, rowIndex, colIndex, otherIndex); - return loadCoeff(row, rowIndex, colIndex, otherIndex); - } - - // Load the coefficient at the patchIndex location instead of the usual - // m_rowIndex, - // m_colIndex, m_otherIndex. This is currently only used by the gpu code. - // EIGEN_DEVICE_FUNC - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { - Index rowIndex, colIndex, otherIndex; - computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex); - return loadCoeff(row, rowIndex, colIndex, otherIndex); - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { - Index rowIndex, colIndex, otherIndex; - computeBaseIndices(0, rowIndex, colIndex, otherIndex); - return loadPacket(row, rowIndex, colIndex, otherIndex); - } - - // Load the packet at the patchIndex location instead of the usual m_rowIndex, - // m_colIndex, m_otherIndex. This is currently only used by the gpu code. - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { - Index rowIndex, colIndex, otherIndex; - computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex); - return loadPacket(row, rowIndex, colIndex, otherIndex); - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE const TensorEvaluator& impl() const { - return m_impl; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } - - private: - friend class TensorContractionSubMapper< - Scalar, Index, Side, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment>; - - // Load coefficient from a patch specified by the "within patch offset" - // (patchId) and the precomputed indices of the first element of the patch. - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, - Index colIndex, Index otherIndex) const { - // Find the offset of the element wrt the location of the first element. - const Index patchOffset = patchId / m_fastDimZero; - - const Index colOffset = patchOffset / m_fastColStride; - const Index inputCol = colIndex + colOffset * m_in_col_strides; - const Index origInputCol = - (m_patch_col_inflate_strides == 1) - ? inputCol - : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); - - const Index rowOffset = patchOffset - colOffset * m_colStride; - const Index inputRow = rowIndex + rowOffset * m_in_row_strides; - const Index origInputRow = - (m_patch_row_inflate_strides == 1) - ? inputRow - : ((inputRow >= 0) ? 
(inputRow / m_fastInputRowStride) : 0); - if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols || - origInputRow >= m_inputRows || - (inputCol != origInputCol * m_patch_col_inflate_strides) || - (inputRow != origInputRow * m_patch_row_inflate_strides)) { - return Scalar(0); - } - const Index depth = patchId - patchOffset * patchDepth(); - const Index inputIndex = depth + origInputRow * m_rowInputStride + - origInputCol * m_colInputStride + otherIndex; - return m_impl.coeff(inputIndex); - } - - // This is the same as loadCoeff(...), but optimized for all `inflate_strides` - // and `in_strides` equal to 1 (template specialization without templates). - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, - Index colIndex, - Index otherIndex) const { - eigen_assert(!nonStandardPatches()); - - // Find the offset of the element wrt the location of the first element. - const Index patchOffset = patchId / m_fastDimZero; - const Index colOffset = patchOffset / m_fastColStride; - const Index rowOffset = patchOffset - colOffset * m_colStride; - const Index inputCol = colIndex + colOffset; - const Index inputRow = rowIndex + rowOffset; - if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || - inputRow >= m_inputRows) { - return Scalar(0); - } - const Index depth = patchId - patchOffset * patchDepth(); - const Index inputIndex = depth + inputRow * m_rowInputStride + - inputCol * m_colInputStride + otherIndex; - return m_impl.coeff(inputIndex); - } - - // Load packet from a patch specified by the "within patch offset" - // (patchId) and the precomputed indices of the first element of the patch. - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, - Index colIndex, - Index otherIndex) const { - const Index packetSize = internal::unpacket_traits::size; - EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) - eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); - - if (nonStandardPatches()) { - return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); - } - return loadPacketStandard(patchId, rowIndex, colIndex, otherIndex); - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index rowIndex, - Index colIndex, - Index otherIndex) const { - const Index packetSize = internal::unpacket_traits::size; - EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) - eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); - - eigen_assert(!nonStandardPatches()); - - if ((patchDepth() % packetSize) == 0) { - return loadPacketFast(patchId, rowIndex, colIndex, otherIndex); - } else { - // Offsets and input calculation here are identical to - // loadCoeffStandard(...), but repeated twice. 
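As a side note, the index arithmetic that `loadCoeffStandard(...)` performs (and that the packet path above repeats for the first and last element of a packet) can be written out without the `TensorIntDivisor` machinery. The sketch below is illustrative only: plain integer division, made-up names, and a col-major `[depth, rows, cols]` input layout are assumed.

```cpp
#include <cstdint>
#include <vector>

// Illustrative stand-in for loadCoeffStandard: decompose a linear
// within-patch index (patchId) into (depth, rowOffset, colOffset), map it
// onto the input tensor, and return 0 for out-of-bounds (padded) locations.
// Plain division replaces Eigen's TensorIntDivisor; all names are made up.
float LoadPatchCoeffSketch(const std::vector<float>& input,   // flattened input
                           std::int64_t patchId,              // index inside the patch
                           std::int64_t rowIndex, std::int64_t colIndex,  // patch origin
                           std::int64_t patch_depth, std::int64_t patch_rows,
                           std::int64_t input_rows, std::int64_t input_cols) {
  const std::int64_t patchOffset = patchId / patch_depth;   // strip the depth dim
  const std::int64_t colOffset = patchOffset / patch_rows;  // column inside the patch
  const std::int64_t rowOffset = patchOffset - colOffset * patch_rows;
  const std::int64_t inputCol = colIndex + colOffset;
  const std::int64_t inputRow = rowIndex + rowOffset;
  if (inputCol < 0 || inputCol >= input_cols || inputRow < 0 ||
      inputRow >= input_rows) {
    return 0.0f;  // implicit zero padding
  }
  const std::int64_t depth = patchId - patchOffset * patch_depth;
  // Input is laid out with depth innermost, then rows, then cols, so the
  // row stride is patch_depth and the col stride is patch_depth * input_rows.
  const std::int64_t inputIndex =
      depth + inputRow * patch_depth + inputCol * patch_depth * input_rows;
  return input[inputIndex];
}
```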
- - const Index patchOffsets[2] = { - patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; - - const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, - patchOffsets[1] / m_fastColStride}; - const Index inputCols[2] = {colIndex + colOffsets[0], - colIndex + colOffsets[1]}; - if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { - // all zeros - return internal::pset1(Scalar(0)); - } - - if (inputCols[0] == inputCols[1]) { - const Index rowOffsets[2] = { - patchOffsets[0] - colOffsets[0] * m_colStride, - patchOffsets[1] - colOffsets[1] * m_colStride}; - eigen_assert(rowOffsets[0] <= rowOffsets[1]); - const Index inputRows[2] = {rowIndex + rowOffsets[0], - rowIndex + rowOffsets[1]}; - - if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { - // all zeros - return internal::pset1(Scalar(0)); - } - - if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) { - // no padding - const Index depth = patchId - patchOffsets[0] * patchDepth(); - const Index inputIndex = depth + inputRows[0] * m_rowInputStride + - inputCols[0] * m_colInputStride + otherIndex; - return m_impl.template packet(inputIndex); - } - } - } - return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex); - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, - Index colIndex, - Index otherIndex) const { - const Index packetSize = internal::unpacket_traits::size; - EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) - eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols); - - eigen_assert(!nonStandardPatches()); - eigen_assert((patchDepth() % packetSize) == 0); - // Find the offset of the element wrt the location of the first element. - const Index patchOffset = patchId / m_fastDimZero; - eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); - - const Index colOffset = patchOffset / m_fastColStride; - const Index rowOffset = patchOffset - colOffset * m_colStride; - const Index inputCol = colIndex + colOffset; - const Index inputRow = rowIndex + rowOffset; - if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || - inputRow >= m_inputRows) { - // all zeros - return internal::pset1(Scalar(0)); - } - // no padding - const Index depth = patchId - patchOffset * patchDepth(); - const Index inputIndex = depth + inputRow * m_rowInputStride + - inputCol * m_colInputStride + otherIndex; - return m_impl.template packet(inputIndex); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero( - Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const { - const int packetSize = internal::unpacket_traits::size; - EIGEN_ALIGN_MAX - typename internal::remove_const::type values[packetSize]; - for (int i = 0; i < packetSize; ++i) { - values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex); - } - Packet rslt = internal::pload(values); - return rslt; - } - - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( - Index patchIndex, Index& rowIndex, Index& colIndex, - Index& otherIndex) const { - const size_t NumInputDims = array_size< - typename TensorEvaluator::Dimensions>::value; - otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches; - const Index patch2DIndex = (NumInputDims == 3) - ? 
patchIndex - : (patchIndex - otherIndex * m_num_patches); - otherIndex *= m_patchInputStride; - colIndex = patch2DIndex / m_fastOutputRows; - rowIndex = patch2DIndex - colIndex * m_outputRows; - colIndex = colIndex * m_col_strides - m_colPaddingLeft; - rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; - } - - Index m_patch_cols; // number of columns in the patch - Index m_num_patches; // number of patches to extract. - - // Strides for navigating through the single patch. - Index m_patch_row_stride; - Index m_patch_col_stride; - internal::TensorIntDivisor m_fastPatchRowStride; - internal::TensorIntDivisor m_fastPatchColStride; - - Index m_patch_row_inflate_strides; // the strides for row inflation in the - // image patch - Index m_patch_col_inflate_strides; // the strides for col inflation in the - // image patch - // Fast representation of inflation strides. - internal::TensorIntDivisor m_fastInputRowStride; - internal::TensorIntDivisor m_fastInputColStride; - - Index m_otherStride; - Index m_colStride; - internal::TensorIntDivisor m_fastNumPatches; - internal::TensorIntDivisor m_fastColStride; - - Index m_rowInputStride; // row stride in the input tensor - Index m_colInputStride; // col stride in the input tensor - Index m_patchInputStride; // patch stride in the input tensor - - Index m_inputRows; // Number of rows in the input tensor - Index m_inputCols; // Number of cols in the input tensor - - Index m_outputRows; // Number of patch rows - - Index m_row_strides; // User specified row stride - Index m_col_strides; // User specified col stride - - Index m_in_row_strides; // User specified input row stride - Index m_in_col_strides; // User specified input col stride - - Index m_rowPaddingTop; // Row padding - Index m_colPaddingLeft; // Column padding - - internal::TensorIntDivisor m_fastOutputRows; - internal::TensorIntDivisor m_fastDimZero; - - const TensorEvaluator m_impl; -}; - -template -class TensorContractionSubMapper< - Scalar, Index, Side, - TensorEvaluator< - const TensorReshapingOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment> { - public: - typedef typename packet_traits::type Packet; - typedef typename packet_traits::half HalfPacket; - - typedef TensorContractionInputMapper< - Scalar, Index, Side, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment> - ParentMapper; - - typedef TensorContractionSubMapper< - Scalar, Index, Side, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment> - Self; - - typedef Self LinearMapper; - - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( - const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) - : m_depth_offset(vert_offset), - m_col_offset(horiz_offset), - m_base_mapper(base_mapper) { - m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, - m_otherIndex); - } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( - const Self& base_mapper, Index vert_offset, Index horiz_offset) - : m_depth_offset(vert_offset + base_mapper.m_depth_offset), - m_col_offset(horiz_offset + base_mapper.m_col_offset), - m_base_mapper(base_mapper.m_base_mapper) { - m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, - m_otherIndex); 
- } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { - return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, - m_otherIndex); - } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, - Index j) const { - return m_base_mapper(i + m_depth_offset, j + m_col_offset); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { - return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, - m_otherIndex); - } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, - Index j) const { - return m_base_mapper.template loadPacket(i + m_depth_offset, - j + m_col_offset); - } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar - loadCoeffStandard(Index i) const { - return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex, - m_colIndex, m_otherIndex); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { - return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, - m_colIndex, m_otherIndex); - } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet - loadPacketStandard(Index i) const { - return m_base_mapper.loadPacketStandard(i + m_depth_offset, m_rowIndex, - m_colIndex, m_otherIndex); - } - template - EIGEN_DEVICE_FUNC bool aligned(Index) const { - return false; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { - return m_base_mapper.nonStandardPatches(); - } - - // Max(Col|Row|Depth): compute the upper limit for the column, row and depth - // index respectively that fits into the peeled_k elements starting at - // m_depth_offset. - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { - const Index max_col = - (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) / - fastPatchColStride(); - return std::min(1 + max_col, patchCols()); - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, - const Index col) const { - const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) - - col * patchColStride()) / - fastPatchRowStride(); - return std::min(1 + max_row, patchRows()); - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col, - Index row) const { - const Index max_depth = m_depth_offset + peeled_k - // - col * patchColStride() - // - row * patchRowStride(); - return std::min(max_depth, patchDepth()); - } - - // MaxDepth uses only the remaining number of elements in the peeled_k. - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, - const Index start_depth) const { - return std::min(start_depth + num_elements, patchDepth()); - } - - // Every register matters in this code, so sometimes to prevent register - // spilling, instead of the variable that you would expect to see, we use - // another one, that is guaranteed to have the same value. E.g. patch depth is - // always the same as input depth, and it's also the same as input row stride. - // Bunch of other parameters have similar relations. 
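A worked example of what the `Max(Col|Row)` helpers above compute, with arbitrarily chosen sizes (these numbers are not from the source, they are only there to make the divisions concrete):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Example sizes: patch_depth = 8, patch_rows = 3, patch_cols = 3,
// depth_offset = 0, peeled_k = 32. Then patchRowStride == patch_depth (8)
// and patchColStride == patch_depth * patch_rows (24).
int main() {
  const std::int64_t patch_depth = 8, patch_rows = 3, patch_cols = 3;
  const std::int64_t depth_offset = 0, peeled_k = 32;
  const std::int64_t patch_row_stride = patch_depth;
  const std::int64_t patch_col_stride = patch_depth * patch_rows;

  // maxCol: last patch column touched by the peeled_k elements, plus one.
  const std::int64_t max_col = std::min<std::int64_t>(
      1 + (depth_offset + peeled_k - 1) / patch_col_stride, patch_cols);
  assert(max_col == 2);  // elements [0, 32) only reach into column 1

  // maxRow for column 0 and column 1.
  const std::int64_t max_row_c0 = std::min<std::int64_t>(
      1 + (depth_offset + peeled_k - 1 - 0 * patch_col_stride) / patch_row_stride,
      patch_rows);
  const std::int64_t max_row_c1 = std::min<std::int64_t>(
      1 + (depth_offset + peeled_k - 1 - 1 * patch_col_stride) / patch_row_stride,
      patch_rows);
  assert(max_row_c0 == 3);  // column 0 is fully covered
  assert(max_row_c1 == 1);  // column 1 is covered only up to row 0
  return 0;
}
```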
- - typedef internal::TensorIntDivisor IndexDivisor; - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index patchDepth() const { - return m_base_mapper.m_rowInputStride; - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index patchRows() const { - return m_base_mapper.m_colStride; - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index patchCols() const { - return m_base_mapper.m_patch_cols; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index patchRowStride() const { - eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride && - "Patch depth must be equal to patch row stride."); - return patchDepth(); - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index patchColStride() const { - return m_base_mapper.m_patch_col_stride; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { - eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride && - "Patch depth must be equal to patch row stride."); - return m_base_mapper.m_fastDimZero; // patch_depth - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { - return m_base_mapper.m_fastPatchColStride; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, - const Index baseIndex) const { - const Index inputIndex = depth + baseIndex; - return m_base_mapper.m_impl.template packet(inputIndex); - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, - const Index baseIndex) const { - const Index inputIndex = depth + baseIndex; - return m_base_mapper.m_impl.coeff(inputIndex); - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { - const Index r = m_rowIndex + row; - return r < 0 || r >= m_base_mapper.m_inputRows; - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row, - const Index last_row) const { - return m_rowIndex + first_row < 0 || - m_rowIndex + last_row >= m_base_mapper.m_inputRows; - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { - const Index c = m_colIndex + col; - return c < 0 || c >= m_base_mapper.m_inputCols; - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const { - const Index r = m_rowIndex + row; - const Index c = m_colIndex + col; - return r * m_base_mapper.m_rowInputStride + - c * m_base_mapper.m_colInputStride + m_otherIndex; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index rowStride() const { - return m_base_mapper.m_row_strides; - } - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index colStride() const { - return m_base_mapper.m_col_strides; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index rowOffset() const { - const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; - const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; - return patchOffset - colOffset * m_base_mapper.m_colStride; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index colOffset() const { - const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; - const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; - return colOffset; - } - - EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE Index depthOffset() const { - return m_depth_offset % patchDepth(); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper - getLinearMapper(Index i, Index j) const { - return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset); - } - - private: - Index m_depth_offset; // First row in the input matrix - Index m_col_offset; // First col in the input matrix - 
- // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base - // indices for the first element in a patch specified by col_offset - // (see computeBaseIndices(...) for details). - Index m_rowIndex; - Index m_colIndex; - Index m_otherIndex; - - const ParentMapper m_base_mapper; // Keeping a copy instead of a reference - // performs better in benchmarks. -}; - -// Arrange a block of the right input matrix (in our case it's always a "virtual -// matrix" constructed from extracted image patches) in contiguous memory. -// -// Given column major input (A0 beside A1 in memory): -// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0 -// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1 -// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2 -// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3 -// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4 -// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5 -// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6 -// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7 -// A8 ... -// ... -// -// *) A, B, C, ... - patches extracted from the original input. -// *) A0, A1, A2 ... - values from the same patch at different offsets. -// -// The traversal (packed rhs memory) order (B0 besides A0 in memory): -// A0 B0 C0 D0 A1 B1 C1 D1 ... -// E0 F0 G0 H0 E1 F1 G1 H1 ... -// ... -// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4) -// -// This traversal order must be the same as in default gemm_pack_rhs defined in -// GeneralBlockPanelKernel.h. -// -// *) nr - number of registers along the 'n' dimension. -// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix -// Multiplication" paper. -template -struct gemm_pack_rhs< - Scalar, Index, - TensorContractionSubMapper< - Scalar, Index, Rhs, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment>, - nr, ColMajor, false, false> { - typedef TensorContractionSubMapper< - Scalar, Index, Rhs, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, packet_size, inner_dim_contiguous, - inner_dim_reordered, Alignment> - SubMapper; - typedef SubMapper DataMapper; - typedef typename packet_traits::type Packet; - - EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) - - EIGEN_DEVICE_FUNC - EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, - Index depth, Index cols, Index stride = 0, - Index offset = 0) const { - eigen_assert(stride == 0); - eigen_assert(offset == 0); - - const Index packet_cols4 = (cols / 4) * 4; - const Index peeled_k = (depth / packet_size) * packet_size; - const bool non_standard_patches = rhs.nonStandardPatches(); - - for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { - const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); - const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); - const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); - const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); - - Index k = 0; - if ((packet_size % 4) == 0 && !non_standard_patches) { - // FAST PATH: - // Iterate over patch columns and rows, if we know that a single - // packet do not span across multiple rows or columns. - if ((rhs.patchDepth() % packet_size) == 0) { - const Index start_col = rhs.colOffset(); - const Index max_col = rhs.maxCol(peeled_k); - - for (Index c = start_col; c < max_col; ++c) { - eigen_assert(k <= peeled_k); - - const Index start_row = (c == start_col) ? 
rhs.rowOffset() : 0; - const Index max_row = rhs.maxRow(peeled_k, c); - - const bool pad_col0 = dm0.padCol(c); - const bool pad_col1 = dm1.padCol(c); - const bool pad_col2 = dm2.padCol(c); - const bool pad_col3 = dm3.padCol(c); - - // Check if we can squeeze reads along the `row` and `depth` - // dimensions (two innermost dimensions). - if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && // - !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && // - !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && // - !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && // - !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) { - // Compute how many elements we can squeeze read. - const Index start_depth = - (c == start_col) ? rhs.depthOffset() : 0; - - // Upper bound for the number of elements in the depth dimension - // that we can squeeze read. - const Index squeeze_length = - (max_row - start_row) * rhs.patchDepth() - start_depth; - - // Do not overshoot beyond the block size. - const Index max_depth = - start_depth + std::min(peeled_k - k, squeeze_length); - eigen_assert((max_depth - start_depth) % packet_size == 0); - - const Index idx0 = dm0.baseIndex(start_row, c); - const Index idx1 = dm1.baseIndex(start_row, c); - const Index idx2 = dm2.baseIndex(start_row, c); - const Index idx3 = dm3.baseIndex(start_row, c); - - for (Index d = start_depth; d < max_depth; d += packet_size) { - eigen_assert(k < peeled_k); - PacketBlock kernel; - kernel.packet[0] = rhs.packetNoPadding(d, idx0); - kernel.packet[1] = rhs.packetNoPadding(d, idx1); - kernel.packet[2] = rhs.packetNoPadding(d, idx2); - kernel.packet[3] = rhs.packetNoPadding(d, idx3); - ptranspose(kernel); - pstoreu(block + 0 * packet_size, kernel.packet[0]); - pstoreu(block + 1 * packet_size, kernel.packet[1]); - pstoreu(block + 2 * packet_size, kernel.packet[2]); - pstoreu(block + 3 * packet_size, kernel.packet[3]); - block += 4 * packet_size; - k += packet_size; - } - - // Go to the next column. - continue; - } - - // If we can't squeeze reads, process rows one by one. - for (Index r = start_row; r < max_row; ++r) { - eigen_assert(k <= peeled_k); - - const bool pad0 = pad_col0 || dm0.padRow(r); - const bool pad1 = pad_col1 || dm1.padRow(r); - const bool pad2 = pad_col2 || dm2.padRow(r); - const bool pad3 = pad_col3 || dm3.padRow(r); - - const Index idx0 = dm0.baseIndex(r, c); - const Index idx1 = dm1.baseIndex(r, c); - const Index idx2 = dm2.baseIndex(r, c); - const Index idx3 = dm3.baseIndex(r, c); - - const Index start_depth = ((c == start_col) && (r == start_row)) - ? rhs.depthOffset() - : 0; - const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); - eigen_assert((max_depth - start_depth) % packet_size == 0); - - for (Index d = start_depth; d < max_depth; d += packet_size) { - eigen_assert(k < peeled_k); - PacketBlock kernel; - kernel.packet[0] = pad0 ? pset1(Scalar(0)) - : rhs.packetNoPadding(d, idx0); - kernel.packet[1] = pad1 ? pset1(Scalar(0)) - : rhs.packetNoPadding(d, idx1); - kernel.packet[2] = pad2 ? pset1(Scalar(0)) - : rhs.packetNoPadding(d, idx2); - kernel.packet[3] = pad3 ? pset1(Scalar(0)) - : rhs.packetNoPadding(d, idx3); - ptranspose(kernel); - pstoreu(block + 0 * packet_size, kernel.packet[0]); - pstoreu(block + 1 * packet_size, kernel.packet[1]); - pstoreu(block + 2 * packet_size, kernel.packet[2]); - pstoreu(block + 3 * packet_size, kernel.packet[3]); - block += 4 * packet_size; - k += packet_size; - } - } - } - - // The loop above should fill peeled_k elements. 
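A scalar model of what the `ptranspose`/`pstoreu` sequence in the fast path produces (illustrative only; the real code operates on SIMD packets and this helper is not part of the source):

```cpp
#include <array>
#include <cstddef>

// Four columns (dm0..dm3) each contribute `packet_size` consecutive depth
// values. The packed block interleaves them so that the four values for a
// given depth are adjacent -- the same layout gemm_pack_rhs produces by
// transposing a PacketBlock and storing its packets back to back.
void PackFourColumnsSketch(const std::array<const float*, 4>& columns,
                           std::size_t packet_size, float* block) {
  for (std::size_t d = 0; d < packet_size; ++d) {  // depth within the packet
    for (std::size_t c = 0; c < 4; ++c) {          // column (dm0..dm3)
      *block++ = columns[c][d];
    }
  }
}
```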
- eigen_assert(peeled_k == k); - - } else { - for (; k < peeled_k; k += packet_size) { - PacketBlock kernel; - kernel.packet[0] = dm0.loadPacketStandard(k); - kernel.packet[1] = dm1.loadPacketStandard(k); - kernel.packet[2] = dm2.loadPacketStandard(k); - kernel.packet[3] = dm3.loadPacketStandard(k); - ptranspose(kernel); - pstoreu(block + 0 * packet_size, kernel.packet[0]); - pstoreu(block + 1 * packet_size, kernel.packet[1]); - pstoreu(block + 2 * packet_size, kernel.packet[2]); - pstoreu(block + 3 * packet_size, kernel.packet[3]); - block += 4 * packet_size; - } - } - } - - // Copy the remaining coefficients of the column block after the peeled_k. - if (!rhs.nonStandardPatches()) { - for (; k < depth; k++) { - block[0] = dm0.loadCoeffStandard(k); - block[1] = dm1.loadCoeffStandard(k); - block[2] = dm2.loadCoeffStandard(k); - block[3] = dm3.loadCoeffStandard(k); - block += 4; - } - } else { - for (; k < depth; k++) { - block[0] = dm0(k); - block[1] = dm1(k); - block[2] = dm2(k); - block[3] = dm3(k); - block += 4; - } - } - } - - // copy the remaining columns one at a time (nr==1) - for (Index j2 = packet_cols4; j2 < cols; ++j2) { - const SubMapper dm0 = rhs.getLinearMapper(0, j2); - for (Index k = 0; k < depth; k++) { - *block = dm0(k); - block += 1; - } - } + return num_coeffs; } }; -// Template specialization for packet_size = 2. We must special-case packet -// blocks with nr > packet_size, e.g. PacketBlock. -template -struct gemm_pack_rhs< - Scalar, Index, - TensorContractionSubMapper< - Scalar, Index, Rhs, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, - Alignment>, - nr, ColMajor, false, false> { - typedef TensorContractionSubMapper< - Scalar, Index, Rhs, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, - Alignment> - SubMapper; - typedef SubMapper DataMapper; - typedef typename packet_traits::type Packet; +template +struct FinalizeDataMapperCoeffs { + EIGEN_ALWAYS_INLINE static Index finalize(Scalar* block, + const DataMapper& rhs, + Index base_idx, Index depth, + Index max_depth, bool pad = false) { + Index num_coeffs = max_depth - depth; + eigen_assert(num_coeffs <= packet_size); + if (num_coeffs == 0) return 0; - EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) + using Packet = typename packet_traits::type; + Packet p = pad ? 
pset1(Scalar(0)) + : rhs.partialPacketNoPadding(depth, base_idx, num_coeffs); + internal::pstoreu(block, p, mask(0, num_coeffs)); - EIGEN_DEVICE_FUNC - EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, - Index depth, Index cols, Index stride = 0, - Index offset = 0) const { - eigen_assert(stride == 0); - eigen_assert(offset == 0); - - const int packet_size = 2; - const Index packet_cols4 = (cols / 4) * 4; - const Index peeled_k = (depth / packet_size) * packet_size; - const bool non_standard_patches = rhs.nonStandardPatches(); - - for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { - const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); - const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); - const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); - const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); - - Index k = 0; - if (!non_standard_patches) { - // FAST PATH: - // Iterate over patch columns and rows if we know that a single - // packet do not span across multiple rows or columns. - if ((rhs.patchDepth() % packet_size) == 0) { - const Index start_col = rhs.colOffset(); - const Index max_col = rhs.maxCol(peeled_k); - - for (Index c = start_col; c < max_col; ++c) { - eigen_assert(k <= peeled_k); - - const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; - const Index max_row = rhs.maxRow(peeled_k, c); - - const bool pad_col0 = dm0.padCol(c); - const bool pad_col1 = dm1.padCol(c); - const bool pad_col2 = dm2.padCol(c); - const bool pad_col3 = dm3.padCol(c); - - // We can squeeze reads along the `row` and `depth` dimensions if - // the row stride is `1`, which means that `row` and `depth` - // dimensions are contiguous (two innermost dimensions). - if (rhs.rowStride() == 1 && // - !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && // - !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && // - !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && // - !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && // - !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) { - // Compute how many elements we can squeeze read. - const Index start_depth = - (c == start_col) ? rhs.depthOffset() : 0; - - // Upper bound for the number of elements in the depth dimension - // that we can squeeze read. - const Index squeeze_length = - (max_row - start_row) * rhs.patchDepth() - start_depth; - - // Do not overshoot beyond the block size. - const Index max_depth = - start_depth + std::min(peeled_k - k, squeeze_length); - eigen_assert((max_depth - start_depth) % packet_size == 0); - - const Index idx0 = dm0.baseIndex(start_row, c); - const Index idx1 = dm1.baseIndex(start_row, c); - const Index idx2 = dm2.baseIndex(start_row, c); - const Index idx3 = dm3.baseIndex(start_row, c); - - for (Index d = start_depth; d < max_depth; d += packet_size) { - PacketBlock kernel0; - PacketBlock kernel1; - kernel0.packet[0] = rhs.packetNoPadding(d, idx0); - kernel0.packet[1] = rhs.packetNoPadding(d, idx1); - kernel1.packet[0] = rhs.packetNoPadding(d, idx2); - kernel1.packet[1] = rhs.packetNoPadding(d, idx3); - ptranspose(kernel0); - ptranspose(kernel1); - pstoreu(block + 0 * packet_size, kernel0.packet[0]); - pstoreu(block + 1 * packet_size, kernel1.packet[0]); - pstoreu(block + 2 * packet_size, kernel0.packet[1]); - pstoreu(block + 3 * packet_size, kernel1.packet[1]); - block += 4 * packet_size; - k += packet_size; - } - - // Go to the next column. - continue; - } - - // If we can't squeeze reads, process rows one by one. 
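A worked example of the "squeezed read" bound used above, with arbitrary sizes (the numbers are illustrative, not taken from the source): when rows `[start_row, max_row)` of one patch column are all unpadded and contiguous, the depth loop may run across row boundaries for up to `squeeze_length` elements, capped by what remains of the `peeled_k` block.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const std::int64_t patch_depth = 8;
  const std::int64_t start_row = 1, max_row = 3;  // two unpadded rows
  const std::int64_t start_depth = 4;             // offset into the first row
  const std::int64_t peeled_k = 64, k = 48;       // 16 elements left in the block

  const std::int64_t squeeze_length =
      (max_row - start_row) * patch_depth - start_depth;     // 2 * 8 - 4 = 12
  const std::int64_t max_depth =
      start_depth + std::min(peeled_k - k, squeeze_length);  // 4 + min(16, 12)
  assert(max_depth == 16);  // read 12 contiguous elements, then move on
  return 0;
}
```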
- for (Index r = start_row; r < max_row; ++r) { - eigen_assert(k <= peeled_k); - - const bool pad0 = pad_col0 || dm0.padRow(r); - const bool pad1 = pad_col1 || dm1.padRow(r); - const bool pad2 = pad_col2 || dm2.padRow(r); - const bool pad3 = pad_col3 || dm3.padRow(r); - - const Index idx0 = dm0.baseIndex(r, c); - const Index idx1 = dm1.baseIndex(r, c); - const Index idx2 = dm2.baseIndex(r, c); - const Index idx3 = dm3.baseIndex(r, c); - - const Index start_depth = ((c == start_col) && (r == start_row)) - ? rhs.depthOffset() - : 0; - const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); - eigen_assert((max_depth - start_depth) % packet_size == 0); - - for (Index d = start_depth; d < max_depth; d += packet_size) { - eigen_assert(k < peeled_k); - PacketBlock kernel0; - PacketBlock kernel1; - kernel0.packet[0] = pad0 ? pset1(Scalar(0)) - : rhs.packetNoPadding(d, idx0); - kernel0.packet[1] = pad1 ? pset1(Scalar(0)) - : rhs.packetNoPadding(d, idx1); - kernel1.packet[0] = pad2 ? pset1(Scalar(0)) - : rhs.packetNoPadding(d, idx2); - kernel1.packet[1] = pad3 ? pset1(Scalar(0)) - : rhs.packetNoPadding(d, idx3); - ptranspose(kernel0); - ptranspose(kernel1); - pstoreu(block + 0 * packet_size, kernel0.packet[0]); - pstoreu(block + 1 * packet_size, kernel1.packet[0]); - pstoreu(block + 2 * packet_size, kernel0.packet[1]); - pstoreu(block + 3 * packet_size, kernel1.packet[1]); - block += 4 * packet_size; - k += packet_size; - } - } - } - - // The loop above should fill peeled_k elements. - eigen_assert(peeled_k == k); - - } else { - // Packet can span multiple rows or columns, so we have to go - // though the slower "standard" path. - for (; k < peeled_k; k += packet_size) { - PacketBlock kernel0; - PacketBlock kernel1; - kernel0.packet[0] = dm0.loadPacketStandard(k); - kernel0.packet[1] = dm1.loadPacketStandard(k); - kernel1.packet[0] = dm2.loadPacketStandard(k); - kernel1.packet[1] = dm3.loadPacketStandard(k); - ptranspose(kernel0); - ptranspose(kernel1); - pstoreu(block + 0 * packet_size, kernel0.packet[0]); - pstoreu(block + 1 * packet_size, kernel1.packet[0]); - pstoreu(block + 2 * packet_size, kernel0.packet[1]); - pstoreu(block + 3 * packet_size, kernel1.packet[1]); - block += 4 * packet_size; - } - } - } - - // Copy the remaining coefficients of the column block after the peeled_k. - if (!non_standard_patches) { - for (; k < depth; k++) { - block[0] = dm0.loadCoeffStandard(k); - block[1] = dm1.loadCoeffStandard(k); - block[2] = dm2.loadCoeffStandard(k); - block[3] = dm3.loadCoeffStandard(k); - block += 4; - } - } else { - for (; k < depth; k++) { - block[0] = dm0(k); - block[1] = dm1(k); - block[2] = dm2(k); - block[3] = dm3(k); - block += 4; - } - } - } - - // Copy the remaining columns one at a time (nr==1). - for (Index j2 = packet_cols4; j2 < cols; ++j2) { - const SubMapper dm0 = rhs.getLinearMapper(0, j2); - for (Index k = 0; k < depth; k++) { - *block = dm0(k); - block += 1; - } - } + return num_coeffs; } }; -// Special case for non-vectorized types such as float16. 
-template -struct gemm_pack_rhs< - Scalar, Index, - TensorContractionSubMapper< - Scalar, Index, Rhs, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, - Alignment>, - nr, ColMajor, false, false> { - typedef TensorContractionSubMapper< - Scalar, Index, Rhs, - TensorEvaluator< - const TensorReshapingOp< - NewDimension, const TensorImagePatchOp >, - Device>, - nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, - Alignment> - SubMapper; - typedef SubMapper DataMapper; - - EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) - - EIGEN_DEVICE_FUNC - EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, - Index depth, Index cols, Index stride = 0, - Index offset = 0) const { - eigen_assert(stride == 0); - eigen_assert(offset == 0); - - const Index packet_cols4 = (cols / 4) * 4; - - for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { - const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); - const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); - const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); - const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); - - if (!rhs.nonStandardPatches()) { - for (Index k = 0; k < depth; k++) { - block[0] = dm0.loadCoeffStandard(k); - block[1] = dm1.loadCoeffStandard(k); - block[2] = dm2.loadCoeffStandard(k); - block[3] = dm3.loadCoeffStandard(k); - block += 4; - } - } else { - for (Index k = 0; k < depth; k++) { - block[0] = dm0(k); - block[1] = dm1(k); - block[2] = dm2(k); - block[3] = dm3(k); - block += 4; - } - } - } - - // Copy the remaining columns one at a time (nr==1). - for (Index j2 = packet_cols4; j2 < cols; ++j2) { - const SubMapper dm0 = rhs.getLinearMapper(0, j2); - for (Index k = 0; k < depth; k++) { - *block = dm0(k); - block += 1; - } - } - } -}; - -#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) // Pack a block of the right input matrix (in our case it's always a // "virtual matrix" constructed from extracted image patches) in contiguous // block in column-major storage order. Knowing the properties of the @@ -1335,6 +102,12 @@ struct gemm_pack_colmajor_block< typedef SubMapper DataMapper; typedef typename packet_traits::type Packet; + using CoeffFinalizer = FinalizeDataMapperCoeffs< + Scalar, DataMapper, packet_size, + TensorEvaluatorHasPartialPacket::value && + unpacket_traits::masked_store_available>; + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper rhs, StorageIndex rows, StorageIndex cols) { @@ -1428,12 +201,14 @@ struct gemm_pack_colmajor_block< block += packet_size; k += packet_size; } - for (; d < max_depth; d++) { - eigen_assert(k < peeled_k); - *block = rhs.coeffNoPadding(d, base_idx); - ++block; - ++k; - } + + eigen_assert(k <= peeled_k); + const Index num_coeffs = + CoeffFinalizer::finalize(block, rhs, base_idx, d, max_depth); + + k += num_coeffs; + block += num_coeffs; + eigen_assert(k <= peeled_k); } // Go to the next column. @@ -1469,9 +244,9 @@ struct gemm_pack_colmajor_block< } } else { - const StorageIndex max_vectorized_depth = max_depth - packet_size; + const StorageIndex vectorized_depth = max_depth - packet_size; StorageIndex d = start_depth; - for (; d < max_vectorized_depth; d += packet_size) { + for (; d <= vectorized_depth; d += packet_size) { eigen_assert(k < peeled_k); const Packet p = pad ? 
pset1(Scalar(0)) : rhs.packetNoPadding(d, base_idx); @@ -1479,12 +254,14 @@ struct gemm_pack_colmajor_block< block += packet_size; k += packet_size; } - for (; d < max_depth; d++) { - eigen_assert(k < peeled_k); - *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx); - ++block; - ++k; - } + + eigen_assert(k <= peeled_k); + const Index num_coeffs = CoeffFinalizer::finalize( + block, rhs, base_idx, d, max_depth, pad); + + k += num_coeffs; + block += num_coeffs; + eigen_assert(k <= peeled_k); } } } @@ -1500,204 +277,7 @@ struct gemm_pack_colmajor_block< } } }; +} // namespace internal +} // namespace Eigen #endif // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) - -} // end namespace internal - -/** SpatialConvolution - * \ingroup CXX11_NeuralNetworks_Module - * - * \brief Applies a 2D convolution over a multichannel input image. - * - * The input parameter is expected to be a tensor with a rank of 3 or more - * (channels, height, width, and optionally others) - * The kernel parameter is expected to be a 4D tensor (filters, channels, - * kernel_height, kernel_width) - * The input and the kernel must both be in col-major layout. The result will - * also be in col-major layout. - * - * If col_in_stride, row_in_stride > 1, then applies convolution with holes - * (aka atrous convolution), sampling every col_in_stride, row_in_stride input - * pixels. - * - * The result can be assigned to a tensor of rank equal to the rank of the - * input. The dimensions of the result will be filters, height, width (and - * others if applicable). - * - * It is possible to swap the order of the width and height dimensions provided - * that the same order is used in the input, the kernel, and the output. - * - * It is also possible to add an output kernel to the contraction, output - * kernel is called by Eigen when it "finalizes" the block of an output tensor. - * - */ -template -EIGEN_DEVICE_FUNC - EIGEN_ALWAYS_INLINE static const typename internal::conditional< - internal::traits::Layout == ColMajor, - TensorReshapingOp< - const DSizes::Index, - internal::traits::NumDimensions>, - const TensorContractionOp< - const array::Index>, - 1>, - const TensorReshapingOp< - const DSizes::Index, 2>, - const Kernel>, - const TensorReshapingOp< - const DSizes::Index, 2>, - const TensorImagePatchOp >, - const OutputKernel> >, - TensorReshapingOp< - const DSizes::Index, - internal::traits::NumDimensions>, - const TensorContractionOp< - const array::Index>, - 1>, - const TensorReshapingOp< - const DSizes::Index, 2>, - const TensorImagePatchOp >, - const TensorReshapingOp< - const DSizes::Index, 2>, - const Kernel>, - const OutputKernel> > >::type - SpatialConvolution(const Input& input, const Kernel& kernel, - const Index row_stride = 1, const Index col_stride = 1, - const PaddingType padding_type = PADDING_SAME, - const Index row_in_stride = 1, - const Index col_in_stride = 1, - const OutputKernel& output_kernel = OutputKernel()) { - typedef typename internal::traits::Index TensorIndex; - TensorRef::Scalar, - internal::traits::NumDimensions, - internal::traits::Layout, TensorIndex> > - in(input); - TensorRef::Scalar, - internal::traits::NumDimensions, - internal::traits::Layout, TensorIndex> > - kern(kernel); - - EIGEN_STATIC_ASSERT( - internal::traits::Layout == internal::traits::Layout, - YOU_MADE_A_PROGRAMMING_MISTAKE) - const bool isColMajor = (internal::traits::Layout == ColMajor); - - const int NumDims = internal::traits::NumDimensions; - - // Number of filters to apply. 
This is the same as the output depth of the - // result - const TensorIndex kernelFilters = - isColMajor ? kern.dimensions()[0] : kern.dimensions()[3]; - // Number of channels. This is the same as the input depth. - const TensorIndex kernelChannels = - isColMajor ? kern.dimensions()[1] : kern.dimensions()[2]; - const TensorIndex kernelRows = - isColMajor ? kern.dimensions()[2] : kern.dimensions()[1]; - const TensorIndex kernelCols = - isColMajor ? kern.dimensions()[3] : kern.dimensions()[0]; - - const Index kernelRowsEff = - kernelRows + (kernelRows - 1) * (row_in_stride - 1); - const Index kernelColsEff = - kernelCols + (kernelCols - 1) * (col_in_stride - 1); - - array, 1> contract_dims; - contract_dims[0] = IndexPair(1, 0); - - const TensorIndex InputRows = - isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); - const TensorIndex InputCols = - isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); - - TensorIndex out_height; - TensorIndex out_width; - switch (padding_type) { - case PADDING_VALID: - out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) / - static_cast(row_stride)); - out_width = numext::ceil((InputCols - kernelColsEff + 1.f) / - static_cast(col_stride)); - break; - case PADDING_SAME: - out_height = numext::ceil(InputRows / static_cast(row_stride)); - out_width = numext::ceil(InputCols / static_cast(col_stride)); - break; - default: - // Initialize unused variables to avoid a compiler warning - out_height = 0; - out_width = 0; - eigen_assert(false && "unexpected padding"); - } - - // Molds the output of the patch extraction code into a 2d tensor: - // - the first dimension (dims[0]): the patch values to be multiplied with the - // kernels - // - the second dimension (dims[1]): everything else - DSizes pre_contract_dims; - if (isColMajor) { - pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols; - pre_contract_dims[1] = out_height * out_width; - for (int i = 3; i < NumDims; ++i) { - pre_contract_dims[1] *= in.dimension(i); - } - } else { - pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols; - pre_contract_dims[0] = out_height * out_width; - for (int i = 0; i < NumDims - 3; ++i) { - pre_contract_dims[0] *= in.dimension(i); - } - } - - // Molds the output of the contraction into the shape expected by the used - // (assuming this is ColMajor): - // - 1st dim: kernel filters - // - 2nd dim: output height - // - 3rd dim: output width - // - 4th dim and beyond: everything else including batch size - DSizes post_contract_dims; - if (isColMajor) { - post_contract_dims[0] = kernelFilters; - post_contract_dims[1] = out_height; - post_contract_dims[2] = out_width; - for (int i = 3; i < NumDims; ++i) { - post_contract_dims[i] = in.dimension(i); - } - } else { - post_contract_dims[NumDims - 1] = kernelFilters; - post_contract_dims[NumDims - 2] = out_height; - post_contract_dims[NumDims - 3] = out_width; - for (int i = 0; i < NumDims - 3; ++i) { - post_contract_dims[i] = in.dimension(i); - } - } - - DSizes kernel_dims; - if (isColMajor) { - kernel_dims[0] = kernelFilters; - kernel_dims[1] = kernelChannels * kernelRows * kernelCols; - } else { - kernel_dims[0] = kernelChannels * kernelRows * kernelCols; - kernel_dims[1] = kernelFilters; - } - return choose( - Cond::Layout == ColMajor>(), - kernel.reshape(kernel_dims) - .contract(input - .extract_image_patches( - kernelRows, kernelCols, row_stride, col_stride, - row_in_stride, col_in_stride, padding_type) - .reshape(pre_contract_dims), - contract_dims, output_kernel) - .reshape(post_contract_dims), - 
input - .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, - row_in_stride, col_in_stride, padding_type) - .reshape(pre_contract_dims) - .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel) - .reshape(post_contract_dims)); -} - -} // end namespace Eigen - #endif // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc index 9aba7b63278..9b215c55af1 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc +++ b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc @@ -1848,4 +1848,14 @@ BM_PackLhs(/*input channels*/ 128, // /*filter channels*/ 1024, // /*filter dims*/ 3, 3, // /*block*/ 56, 256); + +BM_PackLhs(/*input channels*/ 30, // + /*filter channels*/ 64, // + /*filter dims*/ 3, 3, // + /*block*/ 256, 56); + +BM_PackLhs(/*input channels*/ 50, // + /*filter channels*/ 64, // + /*filter dims*/ 3, 3, // + /*block*/ 56, 256); } // namespace Eigen diff --git a/tensorflow/core/kernels/encode_jpeg_op.cc b/tensorflow/core/kernels/encode_jpeg_op.cc index e80404a4375..547b9d8da4d 100644 --- a/tensorflow/core/kernels/encode_jpeg_op.cc +++ b/tensorflow/core/kernels/encode_jpeg_op.cc @@ -135,4 +135,66 @@ class EncodeJpegOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("EncodeJpeg").Device(DEVICE_CPU), EncodeJpegOp); +class EncodeJpegVariableQualityOp : public OpKernel { + public: + explicit EncodeJpegVariableQualityOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& image = context->input(0); + OP_REQUIRES(context, image.dims() == 3, + errors::InvalidArgument("image must be 3-dimensional", + image.shape().DebugString())); + + OP_REQUIRES( + context, + FastBoundsCheck(image.NumElements(), std::numeric_limits::max()), + errors::InvalidArgument( + "Cannot encode images with >= max int32 elements")); + + const int32 dim_size0 = static_cast(image.dim_size(0)); + const int32 dim_size1 = static_cast(image.dim_size(1)); + const int32 dim_size2 = static_cast(image.dim_size(2)); + + // Use default jpeg compression flags except for format and quality. + jpeg::CompressFlags adjusted_flags; + + // Get jpeg encoding quality. + const Tensor& quality = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(quality.shape()), + errors::InvalidArgument("quality must be scalar: ", + quality.shape().DebugString())); + OP_REQUIRES(context, + 0 <= adjusted_flags.quality && adjusted_flags.quality <= 100, + errors::InvalidArgument("quality must be in [0,100], got ", + adjusted_flags.quality)); + adjusted_flags.quality = quality.scalar()(); + + // Autodetect format. 
+ int channels; + channels = dim_size2; + if (channels == 1) { + adjusted_flags.format = jpeg::FORMAT_GRAYSCALE; + } else if (channels == 3) { + adjusted_flags.format = jpeg::FORMAT_RGB; + } else { + OP_REQUIRES( + context, false, + errors::InvalidArgument("image must have 1 or 3 channels, got ", + image.shape().DebugString())); + } + + // Encode image to jpeg string + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + OP_REQUIRES(context, + jpeg::Compress(image.flat().data(), dim_size1, dim_size0, + adjusted_flags, &output->scalar()()), + errors::Internal("JPEG encoding failed")); + } +}; +REGISTER_KERNEL_BUILDER(Name("EncodeJpegVariableQuality").Device(DEVICE_CPU), + EncodeJpegVariableQualityOp); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc index 4a0c1943e54..b023f1cdeb8 100644 --- a/tensorflow/core/kernels/encode_proto_op.cc +++ b/tensorflow/core/kernels/encode_proto_op.cc @@ -516,7 +516,7 @@ class EncodeProtoOp : public OpKernel { // Check the arguments for consistency. TensorShape common_prefix; - int message_count; + int message_count = 0; for (int i = 0; i < field_descs_.size(); i++) { const Tensor& v = values[i]; @@ -525,11 +525,16 @@ class EncodeProtoOp : public OpKernel { ctx, proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()), errors::InvalidArgument( - "Incompatible type for field " + field_names_[i] + - ". Saw dtype: ", - DataTypeString(v.dtype()), + "Incompatible type for field ", field_names_[i], + ". Saw dtype: ", DataTypeString(v.dtype()), " but field type is: ", field_descs_[i]->type_name())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()), + errors::InvalidArgument("Invalid shape for field ", field_names_[i], + ". Saw shape ", v.shape().DebugString(), + " but it should be at least a matrix.")); + // All value tensors must have the same shape prefix (i.e. batch size). 
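The new checks in `EncodeProtoOp` require every value tensor to be at least a matrix and to share a common batch-shape prefix once the trailing dimension is dropped. A stand-alone sketch of that constraint, with plain vectors standing in for `TensorShape` (simplified, illustrative names):

```cpp
#include <cassert>
#include <vector>

// Returns true if every shape has rank >= 2 and all shapes agree after the
// trailing dimension is removed (the shared prefix is the batch shape).
bool SharesBatchPrefix(const std::vector<std::vector<int>>& value_shapes) {
  if (value_shapes.empty()) return true;
  std::vector<int> prefix;
  for (const auto& shape : value_shapes) {
    if (shape.size() < 2) return false;                  // must be matrix or higher
    std::vector<int> p(shape.begin(), shape.end() - 1);  // drop last dim
    if (prefix.empty()) {
      prefix = p;                                        // first tensor sets the prefix
    } else if (p != prefix) {
      return false;                                      // mismatched batch shape
    }
  }
  return true;
}

int main() {
  assert(SharesBatchPrefix({{4, 3}, {4, 7}}));   // same batch prefix {4}
  assert(!SharesBatchPrefix({{4, 3}, {5, 3}}));  // prefix mismatch
  assert(!SharesBatchPrefix({{4}}));             // rank-1 value rejected
  return 0;
}
```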
TensorShape shape_prefix = v.shape(); shape_prefix.RemoveDim(shape_prefix.dims() - 1); diff --git a/tensorflow/core/kernels/example_parsing_ops_test.cc b/tensorflow/core/kernels/example_parsing_ops_test.cc index 5d06eda79e7..3bb9542b4dd 100644 --- a/tensorflow/core/kernels/example_parsing_ops_test.cc +++ b/tensorflow/core/kernels/example_parsing_ops_test.cc @@ -97,7 +97,15 @@ struct ExampleStore { AddExample(&serialized_example, 10, 512, 1); AddExample(&serialized_example, 100, 512, 1); AddExample(&serialized_example, 1000, 512, 1); + AddExample(&serialized_example, 1, 1, 10); + AddExample(&serialized_example, 1, 1, 100); + AddExample(&serialized_example, 1, 1, 1000); + AddExample(&serialized_example, 1, 1, 10000); + AddExample(&serialized_example, 1, 1, 100000); AddExample(&serialized_example, 1, 1, 1000000); + AddExample(&serialized_example, 10, 1, 100000); + AddExample(&serialized_example, 100, 1, 10000); + AddExample(&serialized_example, 1000, 1, 1000); }); return serialized_example; } @@ -299,11 +307,19 @@ BM_AllParseExample(VarLenDenseFloat); } \ BENCHMARK(BM_ParseSingleExample##_##TYPE##_1_##K##_##F); -#define BM_AllParseSingleExample(Type) \ - BM_ParseSingleExample(Type, 10, 1); \ - BM_ParseSingleExample(Type, 100, 1); \ - BM_ParseSingleExample(Type, 1000, 1); \ - BM_ParseSingleExample(Type, 1, 1000000); +#define BM_AllParseSingleExample(Type) \ + BM_ParseSingleExample(Type, 10, 1); \ + BM_ParseSingleExample(Type, 100, 1); \ + BM_ParseSingleExample(Type, 1000, 1); \ + BM_ParseSingleExample(Type, 1, 10); \ + BM_ParseSingleExample(Type, 1, 100); \ + BM_ParseSingleExample(Type, 1, 1000); \ + BM_ParseSingleExample(Type, 1, 10000); \ + BM_ParseSingleExample(Type, 1, 100000); \ + BM_ParseSingleExample(Type, 1, 1000000); \ + BM_ParseSingleExample(Type, 10, 100000); \ + BM_ParseSingleExample(Type, 100, 10000); \ + BM_ParseSingleExample(Type, 1000, 1000); BM_AllParseSingleExample(SparseString); BM_AllParseSingleExample(DenseString); diff --git a/tensorflow/core/kernels/extract_image_patches_op.cc b/tensorflow/core/kernels/extract_image_patches_op.cc index 9306eccf9f0..0fc1f567a92 100644 --- a/tensorflow/core/kernels/extract_image_patches_op.cc +++ b/tensorflow/core/kernels/extract_image_patches_op.cc @@ -130,7 +130,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER); #undef REGISTER -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) // Forward declarations of the functor specializations for GPU. namespace functor { @@ -160,6 +161,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER); #undef REGISTER -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc index 50159282ff1..650c51fc765 100644 --- a/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc +++ b/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU @@ -35,4 +36,4 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER); } // end namespace functor } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/eye_functor_gpu.cu.cc b/tensorflow/core/kernels/eye_functor_gpu.cu.cc index a620316e275..358584df51f 100644 --- a/tensorflow/core/kernels/eye_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/eye_functor_gpu.cu.cc @@ -17,12 +17,11 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/eye_functor.h" - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/kernels/eye_functor.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -52,10 +51,11 @@ struct EyeFunctor { const int batch_size = matrix_batch.dimension(0); const int m = matrix_batch.dimension(1); const int n = matrix_batch.dimension(2); - CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device); - EyeKernel<<>>(config.virtual_thread_count, batch_size, m, - n, matrix_batch.data()); + GpuLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device); + TF_CHECK_OK(CudaLaunchKernel(EyeKernel, config.block_count, + config.thread_per_block, 0, device.stream(), + config.virtual_thread_count, batch_size, m, n, + matrix_batch.data())); } }; diff --git a/tensorflow/core/kernels/fake_quant_ops.cc b/tensorflow/core/kernels/fake_quant_ops.cc index f5e279eca4c..01e3468c93d 100644 --- a/tensorflow/core/kernels/fake_quant_ops.cc +++ b/tensorflow/core/kernels/fake_quant_ops.cc @@ -15,9 +15,10 @@ limitations under the License. #define EIGEN_USE_THREADS -#ifdef GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/fake_quant_ops_functor.h" @@ -28,9 +29,10 @@ limitations under the License. using tensorflow::BinaryElementWiseOp; using tensorflow::DEVICE_CPU; -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) using tensorflow::DEVICE_GPU; -#endif +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM using tensorflow::OpKernel; using tensorflow::OpKernelConstruction; using tensorflow::OpKernelContext; @@ -143,7 +145,8 @@ REGISTER_KERNEL_BUILDER( Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU), FakeQuantWithMinMaxArgsGradientOp); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) typedef Eigen::GpuDevice GPUDevice; // Forward declarations for functor specializations for GPU. 
@@ -165,7 +168,7 @@ void FakeQuantWithMinMaxArgsGradientFunctor::operator()( REGISTER_KERNEL_BUILDER( Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU), FakeQuantWithMinMaxArgsGradientOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // ----------------------------------------------------------------------------- // Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in @@ -265,7 +268,8 @@ REGISTER_KERNEL_BUILDER( Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU), FakeQuantWithMinMaxVarsGradientOp); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) template <> void FakeQuantWithMinMaxVarsFunctor::operator()( const GPUDevice& d, typename TTypes::ConstFlat inputs, @@ -294,7 +298,7 @@ REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient") .HostMemory("min") .HostMemory("max"), FakeQuantWithMinMaxVarsGradientOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // ----------------------------------------------------------------------------- // Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation @@ -411,7 +415,8 @@ REGISTER_KERNEL_BUILDER( Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU), FakeQuantWithMinMaxVarsPerChannelGradientOp); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) template <> void FakeQuantWithMinMaxVarsPerChannelFunctor::operator()( const GPUDevice& d, typename TTypes::ConstMatrix inputs, @@ -443,6 +448,6 @@ REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient") .HostMemory("min") .HostMemory("max"), FakeQuantWithMinMaxVarsPerChannelGradientOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc b/tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc index f6bfb884d94..b3bd44000ea 100644 --- a/tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define FAKE_QUANT_NO_DEBUG @@ -34,4 +35,4 @@ template struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index d7105a71bb8..e0f326dcea3 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -28,9 +28,10 @@ limitations under the License. 
#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/work_sharder.h" -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #include "tensorflow/core/platform/stream_executor.h" -#endif +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace tensorflow { @@ -286,7 +287,8 @@ REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL), #undef FFT_LABEL -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) namespace { template @@ -547,6 +549,6 @@ REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU), FFTGPU); REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU), FFTGPU); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // end namespace tensorflow diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc index 9c4c0487f09..2435c3eed52 100644 --- a/tensorflow/core/kernels/fill_functor.cc +++ b/tensorflow/core/kernels/fill_functor.cc @@ -137,6 +137,7 @@ struct FillFunctor { TF_CALL_ALL_TYPES(DEFINE_FILL_CPU); DEFINE_FILL_CPU(quint8); DEFINE_FILL_CPU(quint16); +DEFINE_FILL_CPU(uint32); #undef DEFINE_FILL_CPU #ifdef TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/fill_functor.cu.cc b/tensorflow/core/kernels/fill_functor.cu.cc index d4c92586897..4e47de45c3e 100644 --- a/tensorflow/core/kernels/fill_functor.cu.cc +++ b/tensorflow/core/kernels/fill_functor.cu.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU @@ -116,4 +117,4 @@ TF_CALL_bool(DEFINE_SETONE_GPU); } // end namespace functor } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/fingerprint_op.cc b/tensorflow/core/kernels/fingerprint_op.cc new file mode 100644 index 00000000000..20529326b3d --- /dev/null +++ b/tensorflow/core/kernels/fingerprint_op.cc @@ -0,0 +1,136 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/byte_order.h" +#include "tensorflow/core/platform/fingerprint.h" + +namespace tensorflow { +namespace { +template +inline void CopyToBuffer(const T& value, uint8* output) { + // Memcpy to string is endian-dependent. 
We choose little-endian as + // standard. On big-endian machines, bytes should be reversed. +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + static_assert(port::kLittleEndian, ""); + std::memcpy(output, &value, sizeof(value)); +#else + static_assert(!port::kLittleEndian, ""); + std::reverse_copy(reinterpret_cast(&value), + reinterpret_cast(&value + 1), output); +#endif +} + +void FarmhashFingerprint64(TTypes::ConstTensor input, + TTypes::Matrix output) { + DCHECK_EQ(output.dimension(0), input.dimension(0)); + DCHECK_EQ(output.dimension(1), sizeof(uint64)); + for (int64 i = 0; i < output.dimension(0); ++i) { + const uint64 fingerprint = + Fingerprint64({reinterpret_cast(&input(i, 0)), + static_cast(input.dimension(1))}); + CopyToBuffer(fingerprint, &output(i, 0)); + } +} + +void FarmhashFingerprint64(TTypes::ConstFlat input, + TTypes::Matrix output) { + DCHECK_EQ(output.dimension(0), input.dimension(0)); + DCHECK_EQ(output.dimension(1), sizeof(uint64)); + for (int64 i = 0; i < input.dimension(0); ++i) { + const uint64 fingerprint = + Fingerprint64({input(i).data(), input(i).size()}); + CopyToBuffer(fingerprint, &output(i, 0)); + } +} + +class FingerprintOp : public OpKernel { + public: + explicit FingerprintOp(OpKernelConstruction* context) : OpKernel(context) { + DataType dtype; + OP_REQUIRES_OK(context, context->GetAttr("T", &dtype)); + OP_REQUIRES(context, DataTypeCanUseMemcpy(dtype) || dtype == DT_STRING, + errors::InvalidArgument("Data type not supported: ", + DataTypeString(dtype))); + } + + void Compute(tensorflow::OpKernelContext* context) override { + const Tensor& method_tensor = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(method_tensor.shape()), + errors::InvalidArgument("`method` should be a scalar string: ", + method_tensor.shape())); + // For now, farmhash64 is the only function supported. + const string& method = method_tensor.scalar()(); + OP_REQUIRES( + context, method == "farmhash64", + errors::InvalidArgument("Unsupported fingerprint method: ", method)); + + const Tensor& input = context->input(0); + OP_REQUIRES( + context, TensorShapeUtils::IsVectorOrHigher(input.shape()), + errors::InvalidArgument("`data` should have at least one dimension: ", + input.shape())); + + const int64 dim0 = input.shape().dim_size(0); + const int64 dim1 = input.shape().num_elements() / dim0; + + Tensor* output; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape{dim0, kFingerprintSize}, &output)); + + if (input.dtype() == DT_STRING) { + if (dim1 > 1) { + Tensor temp; + OP_REQUIRES_OK(context, context->allocate_temp( + DT_UINT8, + TensorShape{input.shape().num_elements(), + kFingerprintSize}, + &temp)); + // `temp` is a matrix of shape {input.num_elements, fingerprint_size}, + // and each row contains the fingerprint value of corresponding string. + // To compute fingerprints of multiple strings, this op fingerprints the + // buffer containing the string fingerprints. + FarmhashFingerprint64(input.flat(), temp.tensor()); + FarmhashFingerprint64(static_cast(temp).shaped( + {dim0, dim1 * kFingerprintSize}), + output->matrix()); + } else { + // In case dim1 == 1, each string computes into its own fingerprint + // value. There is no need to fingerprint twice. 
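For intuition, the two-pass scheme the comments above describe can be sketched standalone. This is only an illustration: FNV-1a stands in for the op's `Fingerprint64` (farmhash), so the values will not match the kernel's output.

```cpp
// Minimal sketch of the string fingerprint scheme described above. FNV-1a is a
// stand-in for Fingerprint64; it only keeps the example self-contained and
// does not reproduce the op's actual values.
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

static uint64_t Hash64(const void* data, size_t len) {
  const auto* p = static_cast<const uint8_t*>(data);
  uint64_t h = 1469598103934665603ull;
  for (size_t i = 0; i < len; ++i) {
    h ^= p[i];
    h *= 1099511628211ull;
  }
  return h;
}

// Fingerprints one row of dim1 strings into a single 64-bit value: hash each
// string, concatenate the per-string hashes, then hash that buffer. With a
// single string per row the first hash is already the answer (no double hash).
static uint64_t FingerprintRow(const std::vector<std::string>& row) {
  if (row.size() == 1) return Hash64(row[0].data(), row[0].size());
  std::vector<uint8_t> buf(row.size() * sizeof(uint64_t));
  for (size_t i = 0; i < row.size(); ++i) {
    const uint64_t f = Hash64(row[i].data(), row[i].size());
    std::memcpy(&buf[i * sizeof(uint64_t)], &f, sizeof(f));  // little-endian layout
  }
  return Hash64(buf.data(), buf.size());
}
```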
+ FarmhashFingerprint64(input.flat(), output->matrix()); + } + } else { + auto data = input.bit_casted_shaped( + {dim0, dim1 * DataTypeSize(input.dtype())}); + FarmhashFingerprint64(data, output->matrix()); + } + } + + private: + static constexpr int kFingerprintSize = sizeof(uint64); +}; + +REGISTER_KERNEL_BUILDER(Name("Fingerprint").Device(tensorflow::DEVICE_CPU), + FingerprintOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fingerprint_op_test.cc b/tensorflow/core/kernels/fingerprint_op_test.cc new file mode 100644 index 00000000000..febfafb4db3 --- /dev/null +++ b/tensorflow/core/kernels/fingerprint_op_test.cc @@ -0,0 +1,242 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +Status MakeNodeDef(DataType dtype, NodeDef* node_def) { + return NodeDefBuilder("fingerprint", "Fingerprint") + .Input(FakeInput(dtype)) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def); +} + +class FingerprintOpTest : public OpsTestBase { + protected: + Status MakeFingerprintOp(Tensor* tensor) { + return MakeFingerprintOp(tensor, "farmhash64"); + } + + Status MakeFingerprintOp(Tensor* data, const string& method) { + TF_RETURN_IF_ERROR(MakeNodeDef(data->dtype(), node_def())); + TF_RETURN_IF_ERROR(InitOp()); + + inputs_.clear(); + inputs_.push_back(data); + + method_ = Tensor(DT_STRING, TensorShape{}); + method_.scalar()() = method; + inputs_.push_back(&method_); + return Status::OK(); + } + + Tensor batch_dims_; + Tensor method_; +}; + +// This test detects changes in fingerprint method. +TEST_F(FingerprintOpTest, GoldenValue) { + Tensor tensor(DT_UINT8, {1, 3, 4, 5, 6, 7}); + auto buffer = tensor.flat(); + std::iota(buffer.data(), buffer.data() + buffer.size(), + static_cast(47)); + + TF_ASSERT_OK(MakeFingerprintOp(&tensor)); + TF_ASSERT_OK(RunOpKernel()); + EXPECT_EQ(GetOutput(0)->shape(), (TensorShape{1, 8})); + EXPECT_EQ(GetOutput(0)->tensor_data(), "\x2d\x90\xdf\x03\x79\x36\x3c\x43"); +} + +// String types have a different compute path. This test detects changes in this +// special-case handling. 
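For reference, the flattening the kernel above applies before hashing keeps only the leading dimension; everything else is folded into the per-row payload, and the output shape is always `{dim0, 8}`. A standalone sketch with plain vectors in place of `TensorShape`:

```cpp
// How the Fingerprint kernel flattens its input: dim0 is preserved, the
// remaining dimensions are folded into dim1, and each row yields one 8-byte
// fingerprint (output shape {dim0, 8}).
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

struct FingerprintLayout {
  int64_t dim0;  // number of fingerprints produced
  int64_t dim1;  // elements (or strings) hashed per fingerprint
};

FingerprintLayout FlattenForFingerprint(const std::vector<int64_t>& shape) {
  // The op requires at least one dimension, so shape[0] is always present.
  const int64_t num_elements = std::accumulate(
      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  return {shape[0], num_elements / shape[0]};
}
// e.g. {1, 3, 4, 5, 6, 7} -> dim0 = 1, dim1 = 2520;  {1, 2, 2} -> dim0 = 1, dim1 = 4.
```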
+TEST_F(FingerprintOpTest, StringGoldenValue) { + Tensor data(DT_STRING, {1, 2, 2}); + auto buffer = data.flat(); + buffer(0).resize(10); + buffer(1).resize(7); + buffer(2).resize(0); + buffer(3).resize(19); + std::iota(buffer(0).begin(), buffer(0).end(), 0); + std::iota(buffer(1).begin(), buffer(1).end(), 7); + std::iota(buffer(2).begin(), buffer(2).end(), 71); + std::iota(buffer(3).begin(), buffer(3).end(), 41); + + TF_ASSERT_OK(MakeFingerprintOp(&data)); + TF_ASSERT_OK(RunOpKernel()); + ASSERT_EQ(GetOutput(0)->shape(), (TensorShape{1, 8})); + EXPECT_EQ(GetOutput(0)->tensor_data(), "\x92\x43\x28\x52\xa3\x7c\x48\x18"); + + // When each batch item has exactly one string, Fingerprint op avoids + // double-fingerprint. Adding a test to detect any change in this logic. + ASSERT_TRUE(data.CopyFrom(data, TensorShape{4})); + TF_ASSERT_OK(MakeFingerprintOp(&data)); + TF_ASSERT_OK(RunOpKernel()); + ASSERT_EQ(GetOutput(0)->shape(), (TensorShape{4, 8})); + EXPECT_EQ(GetOutput(0)->tensor_data(), + "\xea\xff\xd6\xb2\xb2\x4d\x70\x9b" + "\x6e\x9d\xed\x21\xc6\x4a\x61\x52" + "\x4f\x40\x90\x2f\x3b\x6a\xe1\x9a" + "\x0d\x9b\x7f\x63\x23\x14\x1c\xb8"); +} + +TEST_F(FingerprintOpTest, Collision) { + const TensorShape shape = {1, 2, 4, 6}; + for (DataType dtype : kRealNumberTypes) { + const int64 size = shape.num_elements() * DataTypeSize(dtype); + + Tensor tensor(dtype, shape); + auto buffer = tensor.bit_casted_shaped({size}); + buffer.setRandom(); + + TF_ASSERT_OK(MakeFingerprintOp(&tensor)); + TF_ASSERT_OK(RunOpKernel()); + const Tensor fingerprint0 = *GetOutput(0); + + // Alter a byte value in the buffer. + const int offset = buffer(0) % buffer.size(); + buffer(offset) = ~buffer(offset); + + TF_ASSERT_OK(MakeFingerprintOp(&tensor)); + TF_ASSERT_OK(RunOpKernel()); + const Tensor fingerprint1 = *GetOutput(0); + + EXPECT_NE(fingerprint0.tensor_data(), fingerprint1.tensor_data()); + } +} + +TEST_F(FingerprintOpTest, CollisionString) { + constexpr int64 size = 256; + + Tensor tensor(DT_STRING, {1}); + auto& input = tensor.vec()(0); + input.resize(size); + + TTypes::UnalignedFlat buffer(reinterpret_cast(&*input.begin()), + input.size()); + buffer.setRandom(); + + TF_ASSERT_OK(MakeFingerprintOp(&tensor)); + TF_ASSERT_OK(RunOpKernel()); + const Tensor fingerprint0 = *GetOutput(0); + + // Alter a byte value in the buffer. 
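Both Collision tests encode the same property: flipping a single byte of the input must change the fingerprint. A standalone sketch of that check, again with a stand-in hash rather than the op's farmhash64:

```cpp
// Property exercised by the Collision tests: a one-byte change must produce a
// different fingerprint. std::hash is only a stand-in for farmhash64 here.
#include <cassert>
#include <functional>
#include <string>

void CheckByteFlipChangesFingerprint(std::string data, size_t offset) {
  const size_t before = std::hash<std::string>{}(data);
  data[offset] = static_cast<char>(~data[offset]);  // flip every bit of one byte
  const size_t after = std::hash<std::string>{}(data);
  assert(before != after);  // expected from any reasonable hash, not a hard guarantee
}
```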
+ const int offset = buffer(0) % buffer.size(); + buffer(offset) = ~buffer(offset); + + TF_ASSERT_OK(MakeFingerprintOp(&tensor)); + TF_ASSERT_OK(RunOpKernel()); + const Tensor fingerprint1 = *GetOutput(0); + + EXPECT_NE(fingerprint0.tensor_data(), fingerprint1.tensor_data()); +} + +TEST_F(FingerprintOpTest, CompareBytesAndString) { + Tensor pods_tensor(DT_FLOAT, {4, 64}); + Tensor strings_tensor(DT_STRING, {4}); + + auto pods = pods_tensor.matrix(); + pods.setRandom(); + + auto strings = strings_tensor.vec(); + for (int64 i = 0; i < strings.size(); ++i) { + strings(i).assign(reinterpret_cast(&pods(i, 0)), + pods.dimension(1) * sizeof(pods(i, 0))); + } + + TF_ASSERT_OK(MakeFingerprintOp(&pods_tensor)); + TF_ASSERT_OK(RunOpKernel()); + Tensor pods_fingerprints = *GetOutput(0); + + TF_ASSERT_OK(MakeFingerprintOp(&strings_tensor)); + TF_ASSERT_OK(RunOpKernel()); + Tensor strings_fingerprints = *GetOutput(0); + + EXPECT_EQ(pods_fingerprints.tensor_data(), + strings_fingerprints.tensor_data()); +} + +TEST_F(FingerprintOpTest, SupportedMethods) { + Tensor tensor(DT_STRING, TensorShape{1}); + TF_ASSERT_OK(MakeFingerprintOp(&tensor, "unsupported_method")); + + const Status status = RunOpKernel(); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("unsupported_method"), string::npos); +} + +TEST_F(FingerprintOpTest, SupportedTypes) { + Tensor input(DT_RESOURCE, TensorShape{1}); + EXPECT_FALSE(MakeFingerprintOp(&input).ok()); +} + +TEST(FingerprintOpShapeFnTest, MethodKnownStatically) { + ShapeInferenceTestOp op("Fingerprint"); + + Tensor method(DT_STRING, TensorShape{}); + method.scalar()() = "farmhash64"; + op.input_tensors.assign({nullptr, &method}); + + TF_ASSERT_OK(MakeNodeDef(DT_UINT8, &op.node_def)); + INFER_OK(op, "?;?", "[?,8]"); + INFER_ERROR("must be at least rank 1", op, "[];?"); + INFER_OK(op, "[?];?", "[d0_0,8]"); + INFER_OK(op, "[1,?];?", "[d0_0,8]"); + INFER_OK(op, "[?,2,3];?", "[d0_0,8]"); +} + +TEST(FingerprintOpShapeFnTest, MethodUnknownStatically) { + ShapeInferenceTestOp op("Fingerprint"); + + TF_ASSERT_OK(MakeNodeDef(DT_FLOAT, &op.node_def)); + INFER_OK(op, "?;?", "[?,?]"); + INFER_ERROR("must be at least rank 1", op, "[];?"); + INFER_OK(op, "[?];?", "[d0_0,?]"); + INFER_OK(op, "[1,?];?", "[d0_0,?]"); + INFER_OK(op, "[?,2,3];?", "[d0_0,?]"); +} + +TEST(FingerprintOpShapeFnTest, InvalidMethod) { + ShapeInferenceTestOp op("Fingerprint"); + + // When `method` shape is known statically. + INFER_ERROR("must be rank 0", op, "[1];[1]"); + + // When `method` shape is unknown statically. + Tensor method(DT_STRING, TensorShape{1}); + method.vec()(0) = "farmhash64"; + op.input_tensors.assign({nullptr, &method}); + INFER_ERROR("must be rank 0", op, "?;?"); + + method = Tensor(DT_STRING, TensorShape{}); + method.scalar()() = "unsupported_method"; + op.input_tensors.assign({nullptr, &method}); + INFER_ERROR("unsupported_method", op, "?;?"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc index 61234479eac..dfc2382624e 100644 --- a/tensorflow/core/kernels/fractional_avg_pool_op.cc +++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc @@ -84,7 +84,7 @@ class FractionalAvgPoolOp : public OpKernel { // Output size. 
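The hunk below only qualifies `floor` as `std::floor`; the rule it implements is `output_size[i] = floor(input_size[i] / pooling_ratio[i])`. A standalone sketch with a worked value:

```cpp
// Output-size rule shared by the fractional average/max pooling ops.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int64_t> FractionalPoolOutputSize(
    const std::vector<int64_t>& input_size, const std::vector<float>& ratio) {
  std::vector<int64_t> output_size(input_size.size());
  for (size_t i = 0; i < input_size.size(); ++i) {
    output_size[i] = static_cast<int64_t>(std::floor(input_size[i] / ratio[i]));
  }
  return output_size;
}
// e.g. input {1, 10, 10, 3} with pooling ratio {1, 1.5, 1.5, 1} -> {1, 6, 6, 3}.
```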
for (int i = 0; i < tensor_in_and_out_dims; ++i) { output_size[i] = - static_cast(floor(input_size[i] / pooling_ratio_[i])); + static_cast(std::floor(input_size[i] / pooling_ratio_[i])); DCHECK_GT(output_size[i], 0); } diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc index cf580adab25..619a3507ce4 100644 --- a/tensorflow/core/kernels/fractional_max_pool_op.cc +++ b/tensorflow/core/kernels/fractional_max_pool_op.cc @@ -89,7 +89,7 @@ class FractionalMaxPoolOp : public OpKernel { // This must match the same logic in the shape function in // core/ops/nn_ops.cc. output_size[i] = - static_cast(floor(input_size[i] / pooling_ratio_[i])); + static_cast(std::floor(input_size[i] / pooling_ratio_[i])); DCHECK_GT(output_size[i], 0); } diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 88a8a523e47..33bed217003 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/function_ops.h" + #include #include -#include "tensorflow/core/kernels/function_ops.h" - #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -249,7 +250,6 @@ class SymbolicGradientOp : public AsyncOpKernel { ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done); FunctionLibraryRuntime::Options opts; - opts.step_id = ctx->step_id(); opts.rendezvous = ctx->rendezvous(); opts.cancellation_manager = ctx->cancellation_manager(); opts.runner = ctx->runner(); @@ -262,6 +262,13 @@ class SymbolicGradientOp : public AsyncOpKernel { args.push_back(ctx->input(i)); } std::vector* rets = new std::vector; + profiler::TraceMe trace_me( + [&] { + return absl::StrCat( + "SymbolicGradientOp #parent_step_id=", ctx->step_id(), + ",function_step_id=", opts.step_id, "#"); + }, + /*level=*/2); lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) { if (!status.ok()) { ctx->SetStatus(status); @@ -329,8 +336,12 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { handle = cached_entry->second; } else { VLOG(1) << "Instantiating " << func_.name() << " on " << target_device; - tracing::ScopedActivity activity(strings::StrCat( - "RemoteCall: Instantiate: ", func_.name(), " on ", target_device)); + profiler::TraceMe activity( + [&] { + return strings::StrCat("RemoteCall: Instantiate: ", func_.name(), + " on ", target_device); + }, + profiler::TraceMeLevel::kInfo); OP_REQUIRES_OK_ASYNC( ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), @@ -347,7 +358,6 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); FunctionLibraryRuntime::Options opts; - opts.step_id = ctx->step_id(); opts.runner = ctx->runner(); opts.source_device = source_device; if (opts.source_device != target_device) { @@ -374,10 +384,20 @@ 
void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { opts.rets_alloc_attrs.push_back(ret_alloc_attrs); } auto* rets = new std::vector; - auto* activity = new tracing::ScopedActivity(strings::StrCat( - "RemoteCall: Run: ", func_.name(), " on ", target_device)); + auto* activity = new profiler::TraceMe( + [&] { + return strings::StrCat("RemoteCall: Run: ", func_.name(), " on ", + target_device); + }, + profiler::TraceMeLevel::kInfo); VLOG(1) << "Running " << func_.name() << " on " << target_device << " with handle: " << handle; + profiler::TraceMe trace_me( + [&] { + return absl::StrCat("RemoteCallOp #parent_step_id=", ctx->step_id(), + ",function_step_id=", opts.step_id, "#"); + }, + /*level=*/2); lib->Run(opts, handle, args, rets, [rets, activity, done, ctx](const Status& status) { if (!status.ok()) { diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index 246a6ce04d9..52ec30080cf 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -114,7 +115,6 @@ Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx, void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, bool always_collect_stats) { - opts->step_id = ctx->step_id(); opts->rendezvous = ctx->rendezvous(); opts->cancellation_manager = ctx->cancellation_manager(); if (always_collect_stats) { @@ -183,6 +183,12 @@ class IfOp : public AsyncOpKernel { void Start() { FHandle handle = cond_ ? then_handle_ : else_handle_; rets_.clear(); + profiler::TraceMe trace_me( + [&] { + return absl::StrCat("IfOp #parent_step_id=", ctx_->step_id(), + ",function_step_id=", opts_.step_id, "#"); + }, + /*level=*/2); lib_->Run( // Evaluate one of the branch. opts_, handle, args_, &rets_, @@ -276,6 +282,12 @@ class CaseOp : public AsyncOpKernel { branch = branch_handles_.size() - 1; } rets_.clear(); + profiler::TraceMe trace_me( + [&] { + return absl::StrCat("CaseOp #parent_step_id=", ctx_->step_id(), + ",function_step_id=", opts_.step_id, "#"); + }, + /*level=*/2); lib_->Run( // Evaluate one of the branch. opts_, branch_handles_[branch], args_, &rets_, @@ -384,6 +396,13 @@ class WhileOp : public AsyncOpKernel { TensorVec rets_; void EvalCond() { + profiler::TraceMe trace_me( + [&] { + return absl::StrCat( + "WhileOp-EvalCond #parent_step_id=", ctx_->step_id(), + ",function_step_id=", opts_.step_id, "#"); + }, + /*level=*/2); lib_->Run( // Evaluate the condition. opts_, cond_handle_, args_, &rets_, @@ -444,6 +463,13 @@ class WhileOp : public AsyncOpKernel { return Finish(Status::OK()); } rets_.clear(); + profiler::TraceMe trace_me( + [&] { + return absl::StrCat( + "WhileOp-StartBody #parent_step_id=", ctx_->step_id(), + ",function_step_id=", opts_.step_id, "#"); + }, + /*level=*/2); lib_->Run( // Evaluate the body. 
opts_, body_handle_, args_, &rets_, @@ -594,6 +620,12 @@ class ForOp : public AsyncOpKernel { args_[1 + i] = std::move(rets_[i]); } rets_.clear(); + profiler::TraceMe trace_me( + [&] { + return absl::StrCat("ForOp #parent_step_id=", ctx_->step_id(), + ",function_step_id=", opts_.step_id, "#"); + }, + /*level=*/2); lib_->Run(opts_, kernel_->body_handle_, args_, &rets_, [this](const Status& s) { if (s.ok()) { diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 48b339508b5..40a58defe72 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -253,13 +253,22 @@ struct FusedBatchNorm { const int64 channels = GetTensorDim(x, tensor_format, 'C'); const int64 height = GetTensorDim(x, tensor_format, 'H'); const int64 width = GetTensorDim(x, tensor_format, 'W'); + + // If input tensor is in NHWC format, and we are running in inference mode, + // there is no need to convert to NCHW format, performance is the same. + // However in training mode, performance in NCHW format is much better. + TensorFormat compute_format = !is_training && tensor_format == FORMAT_NHWC + ? FORMAT_NHWC + : FORMAT_NCHW; + VLOG(2) << "FusedBatchNorm:" << " batch_size: " << batch_size << " channels: " << channels << " height: " << height << " width:" << width << " x shape: " << x.shape().DebugString() << " scale shape: " << scale.shape().DebugString() << " offset shape: " << offset.shape().DebugString() - << " tensor format: " << tensor_format; + << " tensor format: " << ToString(tensor_format) + << " compute format: " << ToString(compute_format); // If input is empty, return NaN mean/variance if (x.shape().num_elements() == 0) { @@ -274,12 +283,12 @@ struct FusedBatchNorm { Tensor y_transformed; se::DeviceMemory y_ptr; - if (tensor_format == FORMAT_NCHW) { + if (tensor_format == compute_format) { y_ptr = StreamExecutorUtil::AsDeviceMemory(*y); - } else if (tensor_format == FORMAT_NHWC) { + } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) { OP_REQUIRES_OK(context, context->allocate_temp( DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, batch_size, + ShapeFromFormat(compute_format, batch_size, height, width, channels), &x_transformed)); functor::NHWCToNCHW()( @@ -290,22 +299,27 @@ struct FusedBatchNorm { OP_REQUIRES_OK(context, context->allocate_temp( DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, batch_size, + ShapeFromFormat(compute_format, batch_size, height, width, channels), &y_transformed)); y_ptr = StreamExecutorUtil::AsDeviceMemory(y_transformed); } else { - context->SetStatus( - errors::Internal("Unsupported tensor format: ", tensor_format)); + context->SetStatus(errors::Internal( + "Unsupported tensor format: ", ToString(tensor_format), + " and compute format: ", ToString(compute_format))); return; } + const se::dnn::DataLayout data_layout = + compute_format == FORMAT_NHWC ? 
se::dnn::DataLayout::kBatchYXDepth + : se::dnn::DataLayout::kBatchDepthYX; + se::dnn::BatchDescriptor x_desc; x_desc.set_count(batch_size) .set_feature_map_count(channels) .set_height(height) .set_width(width) - .set_layout(se::dnn::DataLayout::kBatchDepthYX); + .set_layout(data_layout); se::dnn::BatchDescriptor scale_offset_desc; scale_offset_desc.set_count(1) @@ -371,7 +385,8 @@ struct FusedBatchNorm { errors::Internal("cuDNN launch failure : input shape (", x.shape().DebugString(), ")")); } - if (tensor_format == FORMAT_NHWC) { + + if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) { functor::NCHWToNHWC()( context->eigen_device(), const_cast(y_transformed).tensor(), diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cu.cc b/tensorflow/core/kernels/fused_batch_norm_op.cu.cc index 4a67b2b3a30..261cb9d1b31 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cu.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cu.cc @@ -15,9 +15,9 @@ limitations under the License. #if GOOGLE_CUDA #define EIGEN_USE_GPU -#include "cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "tensorflow/core/kernels/fused_batch_norm_op.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -37,10 +37,11 @@ template void VarianceToInvVariance::operator()(const Eigen::GpuDevice& d, const T* variance, double epsilon, int channels, T* inv_variance) { - CudaLaunchConfig config = GetCudaLaunchConfig(channels, d); - VarianceToInvVarianceKernel<<>>(config.virtual_thread_count, - variance, epsilon, inv_variance); + GpuLaunchConfig config = GetCudaLaunchConfig(channels, d); + TF_CHECK_OK(CudaLaunchKernel(VarianceToInvVarianceKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), config.virtual_thread_count, + variance, epsilon, inv_variance)); } template @@ -59,10 +60,11 @@ template void InvVarianceToVariance::operator()(const Eigen::GpuDevice& d, double epsilon, int sample_size, int channels, T* variance) { - CudaLaunchConfig config = GetCudaLaunchConfig(channels, d); - InvVarianceToVarianceKernel<<>>(config.virtual_thread_count, - epsilon, sample_size, variance); + GpuLaunchConfig config = GetCudaLaunchConfig(channels, d); + TF_CHECK_OK(CudaLaunchKernel(InvVarianceToVarianceKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), config.virtual_thread_count, epsilon, + sample_size, variance)); } template diff --git a/tensorflow/core/kernels/fused_batch_norm_op_test.cc b/tensorflow/core/kernels/fused_batch_norm_op_test.cc index a3f760b746a..1b348a600b6 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op_test.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -21,10 +23,12 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { class FusedBatchNormOpTest : public OpsTestBase {}; @@ -124,4 +128,82 @@ TEST_F(FusedBatchNormGradOpTest, Simple) { test::FillValues(&expected_offset, {27, 27}); test::ExpectTensorNear(expected_offset, *GetOutput(2), 0.01); } + +//----------------------------------------------------------------------------// +// Performance benchmarks are below. // +//----------------------------------------------------------------------------// + +using fp32 = float; +using fp16 = Eigen::half; + +template +static Graph* FusedBatchNormInference(int n, int h, int w, int c, + bool is_training, + TensorFormat data_format) { + Graph* g = new Graph(OpRegistry::Global()); + + DataType dtype = DataTypeToEnum::value; + Tensor x_t(dtype, data_format == FORMAT_NHWC ? TensorShape({n, h, w, c}) + : TensorShape({n, c, h, w})); + x_t.flat().setRandom(); + + Tensor other_t(DT_FLOAT, TensorShape({c})); + other_t.flat().setRandom(); + + Tensor empty_t(DT_FLOAT, TensorShape({0})); + + Node* x = test::graph::Constant(g, x_t, "x"); + Node* other = test::graph::Constant(g, other_t, "other"); + Node* empty = test::graph::Constant(g, empty_t, "empty"); + + Node* fused_batch_norm; + TF_CHECK_OK(NodeBuilder(g->NewName("fused_batch_norm"), "FusedBatchNormV2") + .Input(x) + .Input(other) // scale + .Input(other) // offset + .Input(is_training ? empty : other) // mean + .Input(is_training ? 
empty : other) // variance + .Attr("T", dtype) + .Attr("U", DT_FLOAT) + .Attr("epsilon", 0.001) + .Attr("is_training", is_training) + .Attr("data_format", ToString(data_format)) + .Finalize(g, &fused_batch_norm)); + + return g; +} + +#define BM_NAME(N, H, W, C, T, IT, FORMAT, DEVICE) \ + BM_FusedBatchNorm##_##N##_##H##_##W##_##C##_##IT##_##FORMAT##_##T##_##DEVICE + +#define BM_FusedBatchNorm(N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE) \ + static void BM_NAME(N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE)(int iters) { \ + testing::UseRealTime(); \ + testing::ItemsProcessed(static_cast(iters) * N * H * W * C); \ + test::Benchmark(#DEVICE, FusedBatchNormInference( \ + N, H, W, C, IS_TRAINING, FORMAT_##FORMAT)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_NAME(N, H, W, C, T, IS_TRAINING, FORMAT, DEVICE)); + +BM_FusedBatchNorm(64, 14, 14, 256, fp32, false, NHWC, cpu); +BM_FusedBatchNorm(64, 14, 14, 256, fp16, false, NHWC, cpu); + +BM_FusedBatchNorm(64, 14, 14, 256, fp32, true, NHWC, cpu); +BM_FusedBatchNorm(64, 14, 14, 256, fp16, true, NHWC, cpu); + +#ifdef GOOGLE_CUDA +BM_FusedBatchNorm(64, 14, 14, 256, fp32, false, NHWC, gpu); +BM_FusedBatchNorm(64, 14, 14, 256, fp16, false, NHWC, gpu); + +BM_FusedBatchNorm(64, 14, 14, 256, fp32, false, NCHW, gpu); +BM_FusedBatchNorm(64, 14, 14, 256, fp16, false, NCHW, gpu); + +BM_FusedBatchNorm(64, 14, 14, 256, fp32, true, NHWC, gpu); +BM_FusedBatchNorm(64, 14, 14, 256, fp16, true, NHWC, gpu); + +BM_FusedBatchNorm(64, 14, 14, 256, fp32, true, NCHW, gpu); +BM_FusedBatchNorm(64, 14, 14, 256, fp16, true, NCHW, gpu); +#endif // GOOGLE_CUDA + } // namespace tensorflow diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.cc b/tensorflow/core/kernels/fused_eigen_output_kernels.cc new file mode 100644 index 00000000000..94e621ae05b --- /dev/null +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/fused_eigen_output_kernels.h" + +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" + +namespace tensorflow { + +Status InitializeFusedComputation( + OpKernelConstruction* context, const string& kernel_name, + const std::vector& patterns, + FusedComputationType* fused_computation, + FusedComputationArgs* fused_computation_args) { + // 'fused_ops' and 'num_args' attributes are specified by the Grappler + // Remapper optimizer (see grappler/optimizers/remapper.cc). + + std::vector fused_ops; + TF_RETURN_IF_ERROR(context->GetAttr("fused_ops", &fused_ops)); + if (fused_ops.empty()) { + return errors::InvalidArgument("Fused ", kernel_name, + " must have at least one fused op."); + } + + int num_args; + TF_RETURN_IF_ERROR(context->GetAttr("num_args", &num_args)); + + // TODO(ezhulenev): Add support for fusion element-wise op chains defined + // at runtime, e.g. Relu+Sqrt+Tanh+etc. + + // Reset fused computation type. 
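The `fused_ops` and `num_args` attributes parsed here are attached by the Grappler remapper rather than written by users. Purely as an illustration of what such a node carries, here is a sketch in the same `NodeDefBuilder` style as the tests earlier in this change; the `_FusedConv2D` op name and its exact attribute set are assumptions, not defined by this file.

```cpp
// Hypothetical sketch of a remapper-produced fused node: Conv2D + BiasAdd +
// Relu collapsed into one op carrying `fused_ops` and `num_args` attributes.
#include <vector>

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"

namespace tensorflow {

Status MakeFusedConv2DForIllustration(NodeDef* node_def) {
  std::vector<string> fused_ops = {"BiasAdd", "Relu"};  // -> kBiasAddWithRelu
  return NodeDefBuilder("fused_conv", "_FusedConv2D")
      .Input(FakeInput(DT_FLOAT))     // input
      .Input(FakeInput(DT_FLOAT))     // filter
      .Input(FakeInput(1, DT_FLOAT))  // `num_args` extra inputs: {bias}
      .Attr("T", DT_FLOAT)
      .Attr("num_args", 1)            // BiasAdd fusion expects exactly one
      .Attr("strides", std::vector<int32>{1, 1, 1, 1})
      .Attr("padding", "SAME")
      .Attr("fused_ops", fused_ops)
      .Finalize(node_def);
}

}  // namespace tensorflow
```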
+ *fused_computation = FusedComputationType::kUndefined; + + // Match op fusion to one of the supported patterns. + for (const auto& pattern : patterns) { + if (fused_ops == pattern.fused_ops) { + *fused_computation = pattern.fused_computation; + break; + } + } + if (*fused_computation == FusedComputationType::kUndefined) { + return errors::Unimplemented("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"); + } + + // Depending on a picked fusion type validate fusion-specific arguments. + if (*fused_computation == FusedComputationType::kBiasAdd || + *fused_computation == FusedComputationType::kBiasAddWithRelu || + *fused_computation == FusedComputationType::kBiasAddWithRelu6 || + *fused_computation == FusedComputationType::kBiasAddWithElu) { + if (num_args != 1) { + return errors::InvalidArgument( + "Fused ", kernel_name, + " with BiasAdd must have one extra argument: bias."); + } + } + + if (*fused_computation == FusedComputationType::kFusedBatchNorm || + *fused_computation == FusedComputationType::kFusedBatchNormWithRelu || + *fused_computation == FusedComputationType::kFusedBatchNormWithRelu6 || + *fused_computation == FusedComputationType::kFusedBatchNormWithElu) { + if (num_args != 4) { + return errors::InvalidArgument( + "Fused ", kernel_name, + " with FusedBatchNorm must have four extra arguments: scale, offset, " + "mean, variance."); + } + TF_RETURN_IF_ERROR( + context->GetAttr("epsilon", &fused_computation_args->epsilon)); + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.h b/tensorflow/core/kernels/fused_eigen_output_kernels.h new file mode 100644 index 00000000000..2588da10f58 --- /dev/null +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.h @@ -0,0 +1,327 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Output kernels for fusing computation into Eigen Tensor contractions: +// (1) FusedConv2DOp +// (2) FusedMatMulOp +// +// Supported fused computations: +// (1) {Conv2D/MatMul} + BiasAdd + +// (2) {Conv2D/MatMul} + FusedBatchNorm + +// +// Activation: Relu, Relu6, Elu, etc... + +#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_ +#define TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +enum class FusedComputationType { + kUndefined, + kBiasAdd, + kBiasAddWithRelu, + kBiasAddWithRelu6, + kBiasAddWithElu, + kFusedBatchNorm, + kFusedBatchNormWithRelu, + kFusedBatchNormWithRelu6, + kFusedBatchNormWithElu +}; + +// We have to pass around additional arguments for all possible fusion types. 
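To make the contract concrete, here is a sketch of how a fused kernel could declare its supported patterns and resolve them through `InitializeFusedComputation()` (declared further below in this header and implemented in the `.cc` file above). The particular pattern list is illustrative; each kernel supplies its own.

```cpp
// Illustrative caller of InitializeFusedComputation(): build the table of
// supported {FusedComputationType, fused_ops} patterns and let it match the
// node's `fused_ops` attribute and validate `num_args` / `epsilon`.
#include <string>
#include <vector>

#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"

namespace tensorflow {

Status ResolveFusionForIllustration(OpKernelConstruction* context,
                                    FusedComputationType* fused_computation,
                                    FusedComputationArgs* args) {
  using FCT = FusedComputationType;
  const std::vector<FusedComputationPattern> patterns = {
      {FCT::kBiasAdd, {"BiasAdd"}},
      {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
      {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
      {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
      {FCT::kFusedBatchNorm, {"FusedBatchNorm"}},
      {FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}},
  };
  return InitializeFusedComputation(context, "Conv2D", patterns,
                                    fused_computation, args);
}

}  // namespace tensorflow
```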
+struct FusedComputationArgs { + float epsilon = 0.0; // Used by `FusedBatchNorm` fusion only +}; + +struct FusedComputationPattern { + FusedComputationType fused_computation; + std::vector fused_ops; +}; + +// Parse attributes from the kernel construction context, and verifies that they +// specify valid fused computation pattern. +Status InitializeFusedComputation( + OpKernelConstruction* context, const string& kernel_name, + const std::vector& patterns, + FusedComputationType* fused_computation, + FusedComputationArgs* fused_computation_args); + +// Type alias for the tensor contraction output mapper. +template +using ContractionOutputMapper = + Eigen::internal::blas_data_mapper; + +// Returns input expression without any transformations. +struct Identity { + template + static auto apply(XprType expr) -> XprType { + return expr; + }; +}; + +// Applies `Relu` to the passed input expression. +struct Relu { + template + static auto apply(XprType expr) + -> decltype(expr.cwiseMax(std::declval())) { + return expr.cwiseMax(static_cast(0)); + }; +}; + +// Applies `Relu6` to the passed input expression. +struct Relu6 { + template + static auto apply(XprType expr) + -> decltype(expr.cwiseMax(std::declval()) + .cwiseMin(std::declval())) { + return expr.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(6)); + }; +}; + +// Applies `Elu` to the passed input expression. +struct Elu { + template + static auto apply(XprType expr) -> decltype( + (expr < std::declval()) + .select(expr.exp() - + expr.constant(std::declval()), + expr)) { + return (expr < static_cast(0)) + .select(expr.exp() - + expr.constant(static_cast(1)), + expr); + }; +}; + +template +struct BiasAddArgs { + const T* bias_add_data = nullptr; + + static bool IsSupported(FusedComputationType fusion) { + return fusion == FusedComputationType::kBiasAdd || + fusion == FusedComputationType::kBiasAddWithRelu || + fusion == FusedComputationType::kBiasAddWithRelu6 || + fusion == FusedComputationType::kBiasAddWithElu; + } +}; + +template +struct FusedBatchNormArgs { + const T* scale_data = nullptr; + const T* offset_data = nullptr; + const T* estimated_mean_data = nullptr; + const T* estimated_variance_data = nullptr; + + // Precomputed expression: + // scaling_factor = (estimated_variance + epsilon).rsqrt() * scale + Eigen::Tensor scaling_factor; + + static bool IsSupported(FusedComputationType fusion) { + return fusion == FusedComputationType::kFusedBatchNorm || + fusion == FusedComputationType::kFusedBatchNormWithRelu || + fusion == FusedComputationType::kFusedBatchNormWithRelu6 || + fusion == FusedComputationType::kFusedBatchNormWithElu; + } +}; + +// TensorContraction swaps lhs with rhs, and changes layout from RowMajor +// (default in Tensorflow) to ColMajor (preferred in Eigen), and computes matmul +// using these tensors. +// +// (1) Spatial Convolution (see eigen_spatial_convolutions.h): +// +// TensorContraction output matrix (before reshape) has a ColMajor layout, and +// has dimensions: +// - rows: output_channels +// - cols: all other dimensions +// +// First element in every column is: +// [batch ??, height ??, width ??, out_channel = i] +// +// We do not know what are the values of the 'batch', 'height', and 'width' +// here (if we know original dimensions, they can be computed from 'j'). +// +// Each column of an output block is a continuous slice along the output +// channel dimension, so we can use it to efficiently compute any +// transformation that depends only on a channel value (e.g. add channel +// bias). 
+// +// (2) Matrix Multiplication (see matmul_op.cc): +// +// For the `MxK * KxN` matrix multiplication, output matrix has a `MxN` +// dimensions. Each column in output block is a slice of the innermost +// dimension of the output matrix starting at offset 'i'. +// +// Example: In Tensorflow MatMul [8x32] * [32x64], each output block column +// will correspond to MatMul output row of size 64 (because Tensorflow uses +// row major storage order). + +// Output kernel that fuses BiasAdd operation into the output of tensor +// contraction + activation function defined by Activation. +template +struct BiasAddOutputKernel { + explicit BiasAddOutputKernel(const BiasAddArgs& args) + : bias_data(args.bias_add_data) {} + + template + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, StorageIndex i, + StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const { + DCHECK(params.swapped_arguments); + + const T* bias_base = bias_data + i; + typename TTypes::UnalignedConstTensor bias(bias_base, num_rows); + + for (int col = 0; col < num_cols; ++col) { + T* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + const auto expr = output + bias; + output = Activation::template apply(expr); + } + } + + private: + const T* bias_data; +}; + +// Output kernel that fuses FusedBatchNorm operation into the output of tensor +// contraction + activation function defined by Activation. +template +struct FusedBatchNormOutputKernel { + FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs& args) + : epsilon(epsilon), + scaling_factor_data(args.scaling_factor.data()), + offset_data(args.offset_data), + estimated_mean_data(args.estimated_mean_data) {} + + template + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, StorageIndex i, + StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const { + DCHECK(params.swapped_arguments); + + const T* scaling_factor_base = scaling_factor_data + i; + const T* offset_base = offset_data + i; + const T* mean_base = estimated_mean_data + i; + + typename TTypes::UnalignedConstTensor scaling_factor(scaling_factor_base, + num_rows); + typename TTypes::UnalignedConstTensor offset(offset_base, num_rows); + typename TTypes::UnalignedConstTensor mean(mean_base, num_rows); + + for (int col = 0; col < num_cols; ++col) { + T* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + + auto scaled = (output - mean) * scaling_factor; + auto shifted = scaled + offset; + + output = Activation::template apply(shifted); + } + } + + private: + T epsilon; + const T* scaling_factor_data; + const T* offset_data; + const T* estimated_mean_data; +}; + +// Type aliases for the output kernels, purely for the sake of better launch +// dispatching code readability. 
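The aliases that follow are intended to be consumed by the Conv2D/MatMul launch path, which is not part of this header. A hedged sketch of what that dispatch can look like, with the actual contraction launch abstracted into a caller-supplied `launch` functor:

```cpp
// Illustrative dispatch over the resolved fusion type. `launch` stands in for
// the real Eigen contraction launch path (not part of this header); it is
// handed the fully constructed output kernel.
template <typename T, typename Launcher>
Status DispatchFusionForIllustration(OpKernelContext* context,
                                     FusedComputationType fusion, float epsilon,
                                     Launcher&& launch) {
  switch (fusion) {
    case FusedComputationType::kBiasAddWithRelu: {
      BiasAddArgs<T> args;
      TF_RETURN_IF_ERROR(InitBiasAddArgs<T>(context, &args));
      return launch(WithBiasAddAndRelu<T>(args));
    }
    case FusedComputationType::kFusedBatchNormWithRelu: {
      FusedBatchNormArgs<T> args;  // InitFusedBatchNormArgs precomputes scaling_factor
      TF_RETURN_IF_ERROR(InitFusedBatchNormArgs<T>(context, epsilon, &args));
      return launch(WithFusedBatchNormAndRelu<T>(T(epsilon), args));
    }
    default:
      return errors::Unimplemented("Fusion type not handled in this sketch");
  }
}
```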
+template +using WithBiasAdd = BiasAddOutputKernel; +template +using WithBiasAddAndRelu = BiasAddOutputKernel; +template +using WithBiasAddAndRelu6 = BiasAddOutputKernel; +template +using WithBiasAddAndElu = BiasAddOutputKernel; +template +using WithFusedBatchNorm = FusedBatchNormOutputKernel; +template +using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel; +template +using WithFusedBatchNormAndRelu6 = FusedBatchNormOutputKernel; +template +using WithFusedBatchNormAndElu = FusedBatchNormOutputKernel; + +template +Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args) { + // Bias of the following dimensions: [ output_depth ] + const Tensor& bias = context->input(2); + + if (bias.dims() != 1) + return errors::InvalidArgument("bias must be 1-dimensional", + bias.shape().DebugString()); + + const auto data_ptr = [](const Tensor& tensor) -> const T* { + return reinterpret_cast(tensor.tensor_data().data()); + }; + + args->bias_add_data = data_ptr(bias); + + return Status::OK(); +} + +template +Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon, + FusedBatchNormArgs* args) { + const Tensor& scale = context->input(2); + const Tensor& offset = context->input(3); + const Tensor& estimated_mean = context->input(4); + const Tensor& estimated_variance = context->input(5); + + if (scale.dims() != 1) + return errors::InvalidArgument("scale must be 1-dimensional", + scale.shape().DebugString()); + if (offset.dims() != 1) + return errors::InvalidArgument("offset must be 1-dimensional", + offset.shape().DebugString()); + if (estimated_mean.dims() != 1) + return errors::InvalidArgument("estimated_mean must be 1-dimensional", + estimated_mean.shape().DebugString()); + if (estimated_variance.dims() != 1) + return errors::InvalidArgument("estimated_variance must be 1-dimensional", + estimated_variance.shape().DebugString()); + + const auto data_ptr = [](const Tensor& tensor) -> const T* { + return reinterpret_cast(tensor.tensor_data().data()); + }; + + args->scale_data = data_ptr(scale); + args->offset_data = data_ptr(offset); + args->estimated_mean_data = data_ptr(estimated_mean); + args->estimated_variance_data = data_ptr(estimated_variance); + + // Precompute scaling factor once for all output blocks (kernels). 
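The assignment that follows folds the inference-mode batch norm into a single per-channel multiplier, which is why the output kernel above only needs `(output - mean) * scaling_factor + offset`. A standalone numeric check of that identity:

```cpp
// y = scale * (x - mean) / sqrt(var + eps) + offset
//   = (x - mean) * scaling_factor + offset,
// with scaling_factor = rsqrt(var + eps) * scale.
#include <cassert>
#include <cmath>

void CheckScalingFactorIdentity() {
  const float x = 2.5f, mean = 1.0f, var = 4.0f, eps = 0.001f;
  const float scale = 0.5f, offset = -1.0f;
  const float direct = scale * (x - mean) / std::sqrt(var + eps) + offset;
  const float scaling_factor = (1.0f / std::sqrt(var + eps)) * scale;
  const float folded = (x - mean) * scaling_factor + offset;
  assert(std::abs(direct - folded) < 1e-6f);  // one multiply per element at run time
}
```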
+ args->scaling_factor = + (estimated_variance.flat() + static_cast(epsilon)).rsqrt() * + scale.flat(); + + return Status::OK(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_ diff --git a/tensorflow/core/kernels/fuzzing/corpus/decode_png/1b0384bc2d5ac42b7425ac51c374f60c b/tensorflow/core/kernels/fuzzing/corpus/decode_png/1b0384bc2d5ac42b7425ac51c374f60c new file mode 100644 index 00000000000..00930411573 Binary files /dev/null and b/tensorflow/core/kernels/fuzzing/corpus/decode_png/1b0384bc2d5ac42b7425ac51c374f60c differ diff --git a/tensorflow/core/kernels/fuzzing/corpus/decode_png/38bd2bd767d0c4ddd531b3893080b952 b/tensorflow/core/kernels/fuzzing/corpus/decode_png/38bd2bd767d0c4ddd531b3893080b952 new file mode 100644 index 00000000000..07e0bbeab40 Binary files /dev/null and b/tensorflow/core/kernels/fuzzing/corpus/decode_png/38bd2bd767d0c4ddd531b3893080b952 differ diff --git a/tensorflow/core/kernels/fuzzing/corpus/decode_png/41438a3c1c77c64a2f0840a2427f8834 b/tensorflow/core/kernels/fuzzing/corpus/decode_png/41438a3c1c77c64a2f0840a2427f8834 new file mode 100644 index 00000000000..f2da9c416b2 --- /dev/null +++ b/tensorflow/core/kernels/fuzzing/corpus/decode_png/41438a3c1c77c64a2f0840a2427f8834 @@ -0,0 +1 @@ ++ \ No newline at end of file diff --git a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc index 4dbb6a71160..b3b637bac72 100644 --- a/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/string_split_fuzz.cc @@ -34,7 +34,7 @@ class FuzzStringSplit : public FuzzSession { Tensor delimiter_tensor(tensorflow::DT_STRING, TensorShape({})); if (size > 0) { - // The spec for split is that the delimeter should be 0 or 1 characters. + // The spec for split is that the delimiter should be 0 or 1 characters. // Naturally, fuzz it with something larger. (This omits the possibility // of handing it a > int32_max size string, which should be tested for in // an explicit test). diff --git a/tensorflow/core/kernels/gather_functor_gpu.cu.h b/tensorflow/core/kernels/gather_functor_gpu.cu.h index fe7850f9253..2db44621c91 100644 --- a/tensorflow/core/kernels/gather_functor_gpu.cu.h +++ b/tensorflow/core/kernels/gather_functor_gpu.cu.h @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -90,7 +90,7 @@ struct GatherFunctor { const int64 indices_size = indices.size(); const int64 slice_size = params.dimension(2); - CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d); + GpuLaunchConfig config = GetCudaLaunchConfig(out_size, d); if (is_axis_zero) { TF_CHECK_OK(CudaLaunchKernel( GatherOpKernel, config.block_count, diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc index 58867a34bc2..0b82b72ccc3 100644 --- a/tensorflow/core/kernels/gather_nd_op.cc +++ b/tensorflow/core/kernels/gather_nd_op.cc @@ -18,14 +18,11 @@ limitations under the License. 
#include "tensorflow/core/kernels/gather_nd_op.h" #include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/util.h" namespace tensorflow { @@ -74,134 +71,10 @@ class GatherNdOp : public OpKernel { // // Same for the GPU kernel. TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); +TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU); #undef REGISTER_GATHER_ND_CPU -namespace functor { -template -Status DoGatherNd(OpKernelContext* c, const Tensor& params, - const Tensor& indices, Tensor* out) { - if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) { - return errors::InvalidArgument("params must be at least a vector"); - } - if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) { - return errors::InvalidArgument("indices must be at least a vector"); - } - if (indices.dim_size(indices.dims() - 1) > params.dims()) { - return errors::InvalidArgument( - "index innermost dimension length must be <= params rank; saw: ", - indices.dim_size(indices.dims() - 1), " vs. ", params.dims()); - } - - const TensorShape& indices_shape(indices.shape()); - const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1); - - // Check that we have enough index space - int64 N_big = 1; - for (int i = 0; i < indices_shape.dims() - 1; ++i) { - N_big *= indices_shape.dim_size(i); - } - if (N_big > std::numeric_limits::max()) { - return errors::InvalidArgument( - "indices has too many elements for int indexing: ", N_big, " > ", - std::numeric_limits::max()); - } - if (params.NumElements() > std::numeric_limits::max()) { - return errors::InvalidArgument("params.NumElements() too large for ", - DataTypeString(DataTypeToEnum::v()), - " indexing: ", params.NumElements(), " > ", - std::numeric_limits::max()); - } - - // The result shape is - // indices.shape[:-1] + params.shape[indices.shape[-1]:] - Index N_result = 1; - for (int i = 0; i < indices_shape.dims() - 1; ++i) { - N_result *= indices_shape.dim_size(i); - } - - const TensorShape& params_shape(params.shape()); - Index total_nd = params_shape.dims(); - - TensorShape result_shape(indices_shape); - result_shape.RemoveLastDims(1); - - int64 slice_size_big = 1; - for (Index i = indices_nd; i < total_nd; ++i) { - slice_size_big *= params_shape.dim_size(i); - result_shape.AddDim(params_shape.dim_size(i)); - } - - if (slice_size_big > std::numeric_limits::max()) { - return errors::InvalidArgument( - "slice size is too large for indexing: ", slice_size_big, " > ", - std::numeric_limits::max()); - } - - const Index slice_size = static_cast(slice_size_big); - - TF_RETURN_IF_ERROR( - c->allocate_temp(DataTypeToEnum::value, result_shape, out)); - - if (N_result > 0) { - if (params_shape.num_elements() == 0) { - return errors::InvalidArgument( - "Requested more than 0 entries, but " - "params is empty. Params shape: ", - params_shape.DebugString()); - } - - auto indices_mat = indices.flat_inner_dims(); - - Index bad_i = -1; - - // Request to copy slices / subtensors - // Make out a matrix with the slices the col size. 
- auto out_mat = out->shaped({N_result, slice_size}); - Tensor scratch; - TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch)); - auto scratch_scalar = scratch.scalar(); - - switch (indices_nd) { -#define PARAMS_CASE(IXDIM) \ - case IXDIM: { \ - functor::GatherNdSlice func; \ - auto params_flat = params.flat_outer_dims(); \ - bad_i = func(c->eigen_device(), slice_size, scratch_scalar, \ - params_flat, indices_mat, out_mat); \ - } break - PARAMS_CASE(0); - PARAMS_CASE(1); - PARAMS_CASE(2); - PARAMS_CASE(3); - PARAMS_CASE(4); - PARAMS_CASE(5); - PARAMS_CASE(6); - PARAMS_CASE(7); -#undef PARAMS_CASE - default: - return errors::InvalidArgument( - "Only indices.shape[-1] values between 1 and 7 " - "are currently supported. Requested rank: ", - indices_nd); - } - - // bad_i will only return >= 0 on CPUs right now. - if (bad_i >= 0) { - auto shape = indices.shape(); - shape.RemoveLastDims(1); - return errors::InvalidArgument( - "indices", SliceDebugString(shape, bad_i), " = [", - str_util::Join( - gtl::ArraySlice(&indices_mat(bad_i, 0), indices_nd), ", "), - "] does not index into param shape ", params.shape().DebugString()); - } - } - return Status::OK(); -} - -} // namespace functor - #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h index 77c0d7717ee..46414a38fb0 100644 --- a/tensorflow/core/kernels/gather_nd_op.h +++ b/tensorflow/core/kernels/gather_nd_op.h @@ -19,8 +19,11 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" namespace tensorflow { @@ -29,6 +32,7 @@ class Status; class Tensor; namespace functor { + template struct GatherNdSlice { // Performs a slice gather op on (Tparams, Tindices), writing to Tout. @@ -43,7 +47,126 @@ struct GatherNdSlice { template Status DoGatherNd(OpKernelContext* c, const Tensor& params, - const Tensor& indices, Tensor* out); + const Tensor& indices, Tensor* out) { + if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) { + return errors::InvalidArgument("params must be at least a vector"); + } + if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) { + return errors::InvalidArgument("indices must be at least a vector"); + } + if (indices.dim_size(indices.dims() - 1) > params.dims()) { + return errors::InvalidArgument( + "index innermost dimension length must be <= params rank; saw: ", + indices.dim_size(indices.dims() - 1), " vs. 
", params.dims()); + } + + const TensorShape& indices_shape(indices.shape()); + const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1); + + // Check that we have enough index space + int64 N_big = 1; + for (int i = 0; i < indices_shape.dims() - 1; ++i) { + N_big *= indices_shape.dim_size(i); + } + if (N_big > std::numeric_limits::max()) { + return errors::InvalidArgument( + "indices has too many elements for int indexing: ", N_big, " > ", + std::numeric_limits::max()); + } + if (params.NumElements() > std::numeric_limits::max()) { + return errors::InvalidArgument("params.NumElements() too large for ", + DataTypeString(DataTypeToEnum::v()), + " indexing: ", params.NumElements(), " > ", + std::numeric_limits::max()); + } + + // The result shape is + // indices.shape[:-1] + params.shape[indices.shape[-1]:] + Index N_result = 1; + for (int i = 0; i < indices_shape.dims() - 1; ++i) { + N_result *= indices_shape.dim_size(i); + } + + const TensorShape& params_shape(params.shape()); + Index total_nd = params_shape.dims(); + + TensorShape result_shape(indices_shape); + result_shape.RemoveLastDims(1); + + int64 slice_size_big = 1; + for (Index i = indices_nd; i < total_nd; ++i) { + slice_size_big *= params_shape.dim_size(i); + result_shape.AddDim(params_shape.dim_size(i)); + } + + if (slice_size_big > std::numeric_limits::max()) { + return errors::InvalidArgument( + "slice size is too large for indexing: ", slice_size_big, " > ", + std::numeric_limits::max()); + } + + const Index slice_size = static_cast(slice_size_big); + + TF_RETURN_IF_ERROR( + c->allocate_temp(DataTypeToEnum::value, result_shape, out)); + + if (N_result > 0) { + if (params_shape.num_elements() == 0) { + return errors::InvalidArgument( + "Requested more than 0 entries, but " + "params is empty. Params shape: ", + params_shape.DebugString()); + } + + auto indices_mat = indices.flat_inner_dims(); + + Index bad_i = -1; + + // Request to copy slices / subtensors + // Make out a matrix with the slices the col size. + auto out_mat = out->shaped({N_result, slice_size}); + Tensor scratch; + TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch)); + auto scratch_scalar = scratch.scalar(); + + switch (indices_nd) { +#define PARAMS_CASE(IXDIM) \ + case IXDIM: { \ + functor::GatherNdSlice func; \ + auto params_flat = params.flat_outer_dims(); \ + bad_i = func(c->eigen_device(), slice_size, scratch_scalar, \ + params_flat, indices_mat, out_mat); \ + } break + PARAMS_CASE(0); + PARAMS_CASE(1); + PARAMS_CASE(2); + PARAMS_CASE(3); + PARAMS_CASE(4); + PARAMS_CASE(5); + PARAMS_CASE(6); + PARAMS_CASE(7); +#undef PARAMS_CASE + default: + return errors::InvalidArgument( + "Only indices.shape[-1] values between 1 and 7 " + "are currently supported. Requested rank: ", + indices_nd); + } + + // bad_i will only return >= 0 on CPUs right now. 
+ if (bad_i >= 0) { + auto shape = indices.shape(); + shape.RemoveLastDims(1); + return errors::InvalidArgument( + "indices", SliceDebugString(shape, bad_i), " = [", + str_util::Join( + gtl::ArraySlice(&indices_mat(bad_i, 0), indices_nd), ", "), + "] does not index into param shape ", params.shape().DebugString()); + } + } + return Status::OK(); +} + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h index cf9817dc306..c3d2f701398 100644 --- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -152,6 +152,7 @@ struct GatherNdSlice { REGISTER_GATHER_ND_FULL(type, int64) TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); +TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU); } // namespace functor diff --git a/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc b/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc index 22fb6674413..1274e3f75c9 100644 --- a/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/gather_nd_op.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -84,7 +84,7 @@ struct GatherNdSlice { batch_indices[i - 1] = Tparams.dimension(i - 1); batch_strides[i - 1] = batch_strides[i] * Tparams.dimension(i); } - CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d); + GpuLaunchConfig config = GetCudaLaunchConfig(out_size, d); TF_CHECK_OK(CudaLaunchKernel(GatherSliceOpKernel, config.block_count, config.thread_per_block, 0, diff --git a/tensorflow/core/kernels/gather_nd_op_test.cc b/tensorflow/core/kernels/gather_nd_op_test.cc index 9f8658ef0e8..b0b5c958b5a 100644 --- a/tensorflow/core/kernels/gather_nd_op_test.cc +++ b/tensorflow/core/kernels/gather_nd_op_test.cc @@ -57,9 +57,9 @@ namespace { class GatherNdOpTest : public OpsTestBase { protected: - void MakeOp(DataType index_type) { + void MakeOp(DataType param_type, DataType index_type) { TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(param_type)) .Input(FakeInput(index_type)) .Finalize(node_def())); TF_ASSERT_OK(InitOp()); @@ -67,7 +67,7 @@ class GatherNdOpTest : public OpsTestBase { }; TEST_F(GatherNdOpTest, Simple) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5}), {0, 1, 2, 8, 4}); @@ -80,6 +80,32 @@ TEST_F(GatherNdOpTest, Simple) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +TEST_F(GatherNdOpTest, Quantized_UINT8) { + MakeOp(DT_QUINT8, DT_INT32); + + // Feed and run + AddInputFromArray(TensorShape({5}), {0, 1, 2, 8, 4}); + AddInputFromArray(TensorShape({2, 1}), {3, 4}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. 
+ Tensor expected(allocator(), DT_QUINT8, TensorShape({2})); + test::FillValues(&expected, {8, 4}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(GatherNdOpTest, Quantized_INT8) { + MakeOp(DT_QINT8, DT_INT32); + + AddInputFromArray(TensorShape({5}), {0, 1, 2, 8, 4}); + AddInputFromArray(TensorShape({2, 1}), {3, 4}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_QINT8, TensorShape({2})); + test::FillValues(&expected, {8, 4}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + constexpr int kLookups = 2000; template diff --git a/tensorflow/core/kernels/cuda_device_array.h b/tensorflow/core/kernels/gpu_device_array.h similarity index 86% rename from tensorflow/core/kernels/cuda_device_array.h rename to tensorflow/core/kernels/gpu_device_array.h index 74dc298c7a5..51eb8bba60c 100644 --- a/tensorflow/core/kernels/cuda_device_array.h +++ b/tensorflow/core/kernels/gpu_device_array.h @@ -15,20 +15,21 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ #define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/kernels/cuda_device_array_gpu.h" +#include "tensorflow/core/kernels/gpu_device_array_gpu.h" namespace tensorflow { // Create an array of value on the host, to be sent to kernel using -// CudaDeviceArrayStruct. +// GpuDeviceArrayStruct. // // Usage: // int size = ...; -// CudaDeviceArrayOnHost ptrs(context, size); +// GpuDeviceArrayOnHost ptrs(context, size); // OP_REQUIRES_OK(ptrs.Init()); // for (int i = 0; i < size; ++i) { // ptrs.Set(i, ...); @@ -38,9 +39,9 @@ namespace tensorflow { // // ValueType must be memcopyable. template -class CudaDeviceArrayOnHost { +class GpuDeviceArrayOnHost { public: - CudaDeviceArrayOnHost(OpKernelContext* context, int32 size) + GpuDeviceArrayOnHost(OpKernelContext* context, int32 size) : context_(context), total_bytes_(static_cast(size) * sizeof(ValueType)) { data_.size = size; @@ -93,7 +94,7 @@ class CudaDeviceArrayOnHost { return Status::OK(); } - const CudaDeviceArrayStruct& data() const { + const GpuDeviceArrayStruct& data() const { // Ensure Finalize is called. DCHECK(inlined() || out_of_line_values_on_gpu_.IsInitialized()); return data_; @@ -105,16 +106,16 @@ class CudaDeviceArrayOnHost { OpKernelContext* const context_; const int64 total_bytes_; // total size of all pointers. ValueType* values_ = nullptr; - CudaDeviceArrayStruct data_; + GpuDeviceArrayStruct data_; Tensor out_of_line_values_on_host_; Tensor out_of_line_values_on_gpu_; - TF_DISALLOW_COPY_AND_ASSIGN(CudaDeviceArrayOnHost); + TF_DISALLOW_COPY_AND_ASSIGN(GpuDeviceArrayOnHost); }; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_ diff --git a/tensorflow/core/kernels/cuda_device_array_gpu.h b/tensorflow/core/kernels/gpu_device_array_gpu.h similarity index 73% rename from tensorflow/core/kernels/cuda_device_array_gpu.h rename to tensorflow/core/kernels/gpu_device_array_gpu.h index 64fa3cb806b..3d81712dd76 100644 --- a/tensorflow/core/kernels/cuda_device_array_gpu.h +++ b/tensorflow/core/kernels/gpu_device_array_gpu.h @@ -18,15 +18,16 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ #define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) namespace tensorflow { -static constexpr int kMaxInlineCudaPointers = 8; -// To decode on the device side, use GetCudaDeviceArrayOnDevice. -// To encode on the host side, use CudaDeviceArrayOnHost. +static constexpr int kMaxInlineGpuPointers = 8; +// To decode on the device side, use GetGpuDeviceArrayOnDevice. +// To encode on the host side, use GpuDeviceArrayOnHost. template -struct CudaDeviceArrayStruct { +struct GpuDeviceArrayStruct { int32 size; // used if size <= MaxInlineValues; ValueType inline_values[MaxInlineValues]; @@ -34,8 +35,8 @@ struct CudaDeviceArrayStruct { }; template -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice( - CudaDeviceArrayStruct* data) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetGpuDeviceArrayOnDevice( + GpuDeviceArrayStruct* data) { if (data->size <= MaxInlineValues) { return data->inline_values; } else { @@ -45,6 +46,6 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice( } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_ diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc new file mode 100644 index 00000000000..a6a13345c71 --- /dev/null +++ b/tensorflow/core/kernels/gpu_utils.cc @@ -0,0 +1,151 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/gpu_utils.h" + +#if GOOGLE_CUDA + +#include "google/protobuf/any.pb.h" +#include "tensorflow/core/platform/logger.h" +#include "tensorflow/core/protobuf/autotuning.pb.h" +#include "tensorflow/core/protobuf/conv_autotuning.pb.h" +#include "tensorflow/core/util/proto/proto_utils.h" + +namespace tensorflow { +namespace { + +tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { + tensorflow::CudnnVersion cudnn_version; + if (auto* dnn = stream_executor->AsDnn()) { + se::port::StatusOr version_or = dnn->GetVersion(); + if (version_or.ok()) { + const auto& version = version_or.ValueOrDie(); + cudnn_version.set_major(version.major_version()); + cudnn_version.set_minor(version.minor_version()); + cudnn_version.set_patch(version.patch()); + } + } + return cudnn_version; +} + +tensorflow::ComputeCapability GetComputeCapability( + se::StreamExecutor* stream_executor) { + tensorflow::ComputeCapability cc; + int cc_major, cc_minor; + stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + cc.set_major(cc_major); + cc.set_minor(cc_minor); + return cc; +} + +} // namespace + +void LogConvAutotuneResults(se::dnn::ConvolutionKind kind, + se::dnn::DataType element_type, + const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, + se::StreamExecutor* stream_exec, + absl::Span results) { + AutotuningLog log; + { + ConvolutionProto instr; + instr.set_kind(kind); + *instr.mutable_input() = input_desc.ToProto(element_type); + *instr.mutable_filter() = filter_desc.ToProto(element_type); + *instr.mutable_output() = output_desc.ToProto(element_type); + *instr.mutable_conv_desc() = conv_desc.ToProto(); + log.mutable_instr()->PackFrom(std::move(instr)); + instr.set_conv_scale(1); + instr.set_side_value_scale(0); + } + *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec); + *log.mutable_compute_capability() = GetComputeCapability(stream_exec); + log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id()); + for (const auto& result : results) { + *log.add_results() = result; + } + Logger::Singleton()->LogProto(log); +} + +void LogFusedConvForwardAutotuneResults( + se::dnn::DataType element_type, const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, double conv_scale, + double side_value_scale, se::dnn::ActivationMode activation_mode, + se::StreamExecutor* stream_exec, absl::Span results) { + AutotuningLog log; + { + ConvolutionProto instr; + instr.set_kind(se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION); + *instr.mutable_input() = input_desc.ToProto(element_type); + *instr.mutable_filter() = filter_desc.ToProto(element_type); + *instr.mutable_output() = output_desc.ToProto(element_type); + *instr.mutable_conv_desc() = conv_desc.ToProto(); + instr.set_conv_scale(conv_scale); + instr.set_side_value_scale(side_value_scale); + instr.set_activation(activation_mode); + log.mutable_instr()->PackFrom(std::move(instr)); + } + *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec); + *log.mutable_compute_capability() = GetComputeCapability(stream_exec); + log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id()); + for (const auto& result : 
results) { + *log.add_results() = result; + } + Logger::Singleton()->LogProto(log); +} + +Status BestCudnnConvAlgorithm(absl::Span results, + se::dnn::AlgorithmConfig* algo) { + // TODO(jlebar): Exclude conv ops with failures, once we have failure checking + // and have confidence that it's correct. + + const AutotuneResult* best_result = std::min_element( + results.begin(), results.end(), + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return proto_utils::FromDurationProto(lhs.run_time()) < + proto_utils::FromDurationProto(rhs.run_time()); + }); + + const AutotuneResult* best_result_no_scratch = std::min_element( + results.begin(), results.end(), + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return std::make_tuple(lhs.scratch_bytes(), + proto_utils::FromDurationProto(lhs.run_time())) < + std::make_tuple(rhs.scratch_bytes(), + proto_utils::FromDurationProto(rhs.run_time())); + }); + + if (best_result == results.end()) { + return errors::NotFound("No algorithm worked!"); + } + algo->set_algorithm({best_result->conv().algorithm(), + best_result->conv().tensor_ops_enabled()}); + if (best_result_no_scratch != results.end() && + best_result_no_scratch->scratch_bytes() == 0) { + algo->set_algorithm_no_scratch( + {best_result_no_scratch->conv().algorithm(), + best_result_no_scratch->conv().tensor_ops_enabled()}); + } + return Status::OK(); +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index 86146f75f4d..48d6813e9ca 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -16,10 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ #define TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include +#include "absl/types/span.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -28,6 +31,9 @@ limitations under the License. namespace tensorflow { +class NodeDef; +class AutotuneResult; + template inline se::DeviceMemory AsDeviceMemory(const T* cuda_memory, uint64 size) { se::DeviceMemoryBase wrapped(const_cast(cuda_memory), size * sizeof(T)); @@ -90,7 +96,28 @@ class AutoTuneMap { } if (new_score >= min_score_threshold_) { VLOG(1) << GetActionSummary("accepts", params, config); + } else if (autotune_global_count_ >= max_autotune_global_count_) { + // The autotuning exceeds the max iteration threshold and we accept the + // the winner if it exists in the map, otherwise we accept the current + // winner. 
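// Illustrative sketch, not part of this patch: the selection rule used by
// BestCudnnConvAlgorithm above, restated over a plain struct. The overall
// winner is the minimum by run time; the "no scratch" winner is the minimum
// by lexicographic (scratch_bytes, run_time) and is only used when its
// scratch is exactly zero. Names here are hypothetical.
#include <algorithm>
#include <tuple>
#include <vector>
struct AutotuneResultSketch {
  long long scratch_bytes;
  double run_time_sec;
};
static const AutotuneResultSketch* FastestNoScratchSketch(
    const std::vector<AutotuneResultSketch>& results) {
  auto it = std::min_element(
      results.begin(), results.end(),
      [](const AutotuneResultSketch& a, const AutotuneResultSketch& b) {
        return std::make_tuple(a.scratch_bytes, a.run_time_sec) <
               std::make_tuple(b.scratch_bytes, b.run_time_sec);
      });
  // Only counts as a "no scratch" winner when it really uses no scratch.
  return (it != results.end() && it->scratch_bytes == 0) ? &*it : nullptr;
}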
+ auto winner = params_config_map_.find(params); + if (winner == params_config_map_.end()) { + VLOG(1) << GetActionSummary("creates", params, config); + for (int i = 0; i < min_score_threshold_; ++i) { + VLOG(1) << GetActionSummary("promotes", params, config); + } + params_config_map_.insert( + std::make_pair(params, ValueType{config, min_score_threshold_, 1})); + } else { + int promotes_times = min_score_threshold_ - winner->second.score; + for (int i = 0; i < promotes_times; ++i) { + VLOG(1) << GetActionSummary("promotes", params, config); + } + winner->second.score = min_score_threshold_; + } + VLOG(1) << GetActionSummary("accepts", params, config); } + autotune_global_count_++; } private: @@ -109,6 +136,8 @@ class AutoTuneMap { min_score_threshold_ = std::max(min_score_threshold_, 1); max_autotune_count_ = std::max( 5 * min_score_threshold_ * min_score_threshold_, min_warmup_iterations); + max_autotune_global_count_ = 2 * max_autotune_count_; + autotune_global_count_ = 0; } template @@ -138,6 +167,8 @@ class AutoTuneMap { string name_; int32 min_score_threshold_; int32 max_autotune_count_; + int32 max_autotune_global_count_; + int32 autotune_global_count_; TF_DISALLOW_COPY_AND_ASSIGN(AutoTuneMap); }; @@ -156,8 +187,33 @@ class AutoTuneSingleton { } }; +// Logs convolution results to customized back-storage. +void LogConvAutotuneResults(se::dnn::ConvolutionKind kind, + se::dnn::DataType element_type, + const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, + se::StreamExecutor* stream_exec, + absl::Span results); + +// Logs fused convolution results to customized back-storage. +void LogFusedConvForwardAutotuneResults( + se::dnn::DataType element_type, const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, double conv_scale, + double side_value_scale, se::dnn::ActivationMode activation_mode, + se::StreamExecutor* stream_exec, absl::Span results); + +// Returns the best algorithms for the config, one is the fastest, the other is +// other is fastest with 0 scracth space. Unsuccessful autotuning results are +// allowed and ignored. +Status BestCudnnConvAlgorithm(absl::Span results, + se::dnn::AlgorithmConfig* algo); + } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h index 1b382996f88..9c57c1d4298 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h @@ -76,7 +76,7 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor { // TODO(satok): Use actual data passed by FillInputNode and remove // std::vector dummy_input_float_{}; std::unordered_map> input_tensor_data_{}; - // Dummy byte array for cosnt node. + // Dummy byte array for const node. // TODO(satok): Remove std::unordered_map> dummy_const_data_{}; diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc index 374a05850eb..203a4175a6d 100644 --- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc +++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/histogram_op.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/host_constant_op.cc b/tensorflow/core/kernels/host_constant_op.cc index d08a7c9bd27..17dad526ce3 100644 --- a/tensorflow/core/kernels/host_constant_op.cc +++ b/tensorflow/core/kernels/host_constant_op.cc @@ -63,8 +63,6 @@ REGISTER_KERNEL_BUILDER(Name("Const") #endif // TENSORFLOW_USE_SYCL // HostConst: forced to generate output on the host. -// Only used in tests; no op is registered for this kernel -// externally (i.e., in array_ops.cc) REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), _HostConstantOp); REGISTER_KERNEL_BUILDER( Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), _HostConstantOp); diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc index 6f797298837..cf63a975cc8 100644 --- a/tensorflow/core/kernels/identity_op.cc +++ b/tensorflow/core/kernels/identity_op.cc @@ -112,7 +112,8 @@ REGISTER_GPU_KERNEL(Variant); #undef REGISTER_GPU_KERNEL -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) // A special GPU kernel for int32 and bool. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. @@ -149,6 +150,6 @@ REGISTER_GPU_HOST_KERNEL(ResourceHandle); #undef REGISTER_GPU_HOST_KERNEL -#endif +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc index 506091f76ec..a8ee00e080e 100644 --- a/tensorflow/core/kernels/in_topk_op.cc +++ b/tensorflow/core/kernels/in_topk_op.cc @@ -17,15 +17,18 @@ limitations under the License. 
#define EIGEN_USE_THREADS -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/kernels/in_topk_op.h" + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { -template +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template class InTopK : public OpKernel { public: explicit InTopK(OpKernelConstruction* context) : OpKernel(context) { @@ -37,7 +40,10 @@ class InTopK : public OpKernel { void Compute(OpKernelContext* context) override { const auto& predictions_in = context->input(0); const auto& targets_in = context->input(1); - int64 k_val = k_; + + int64 k_value = k_; + const Tensor* k_tensor = nullptr; + if (context->num_inputs() == 3) { const auto& k_in = context->input(2); @@ -45,11 +51,7 @@ class InTopK : public OpKernel { errors::InvalidArgument("k must be 0-D, got shape ", k_in.shape().DebugString())); - if (k_in.dtype() == DT_INT32) { - k_val = k_in.scalar()(); - } else { - k_val = k_in.scalar()(); - } + k_tensor = &k_in; } OP_REQUIRES(context, predictions_in.dims() == 2, @@ -61,8 +63,9 @@ class InTopK : public OpKernel { predictions_in.dim_size(0), " must match length of targets ", targets_in.dim_size(0))); - const auto& predictions = predictions_in.matrix(); - const auto& targets = targets_in.vec(); + + const auto predictions = predictions_in.matrix(); + const auto targets = targets_in.vec(); Tensor* t_out = nullptr; OP_REQUIRES_OK(context, @@ -70,28 +73,11 @@ class InTopK : public OpKernel { 0, TensorShape({targets_in.dim_size(0)}), &t_out)); auto out = t_out->vec(); - const auto size = targets.size(); - const auto num_classes = predictions.dimension(1); - for (int b = 0; b < size; b++) { - auto target = internal::SubtleMustCopy(targets(b)); - OP_REQUIRES(context, FastBoundsCheck(target, num_classes), - errors::InvalidArgument("targets[", b, "] is out of range")); - T target_prediction = predictions(b, target); - bool cannot_say = !std::isfinite(target_prediction); - int more_probable_classes = 0; - if (!cannot_say) { - for (int i = 0; i < num_classes; ++i) { - T pred = predictions(b, i); - if (!std::isfinite(pred)) { - cannot_say = true; - break; - } else if (pred > target_prediction) { - ++more_probable_classes; - } - } - } - out(b) = cannot_say ? false : (more_probable_classes < k_val); - } + functor::InTopKFunctor f; + functor::TopKArg arg; + arg.k_value = k_value; + arg.k_tensor = k_tensor; + f(context, predictions, targets, arg, out); } private: @@ -104,14 +90,14 @@ REGISTER_KERNEL_BUILDER(Name("InTopK") .HostMemory("targets") .HostMemory("precision") .TypeConstraint("T"), - InTopK); + InTopK); REGISTER_KERNEL_BUILDER(Name("InTopK") .Device(DEVICE_CPU) .HostMemory("predictions") .HostMemory("targets") .HostMemory("precision") .TypeConstraint("T"), - InTopK); + InTopK); REGISTER_KERNEL_BUILDER(Name("InTopKV2") .Device(DEVICE_CPU) @@ -120,7 +106,7 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2") .HostMemory("k") .HostMemory("precision") .TypeConstraint("T"), - InTopK); + InTopK); REGISTER_KERNEL_BUILDER(Name("InTopKV2") .Device(DEVICE_CPU) .HostMemory("predictions") @@ -128,6 +114,34 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2") .HostMemory("k") .HostMemory("precision") .TypeConstraint("T"), - InTopK); + InTopK); + +#if GOOGLE_CUDA + +// Forward declarations of the functor specializations for GPU. 
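// Illustrative sketch, not part of this patch: the per-row semantics that the
// InTopK/InTopKV2 kernels registered in this file implement (see the CPU
// functor in in_topk_op.h and the GPU functor in in_topk_op_gpu.cu.cc).
// Plain-vector signature and names are hypothetical.
#include <cmath>
#include <vector>
static bool InTopKRowSketch(const std::vector<float>& predictions, int target,
                            int k) {
  if (target < 0 || target >= static_cast<int>(predictions.size()) ||
      !std::isfinite(predictions[target])) {
    return false;  // invalid target class or non-finite target prediction
  }
  int more_probable = 0;
  for (float p : predictions) {
    if (!std::isfinite(p)) return false;  // any non-finite entry -> false
    if (p > predictions[target]) ++more_probable;
  }
  return more_probable < k;  // target is among the top k predictions
}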
+namespace functor { +#define DECLARE_GPU_SPEC(T, TARGET_T) \ + template <> \ + void InTopKFunctor::operator()( \ + OpKernelContext* context, \ + typename TTypes::ConstTensor predictions, \ + typename TTypes::ConstVec targets, const TopKArg k, \ + typename TTypes::Vec output); \ + extern template struct InTopKFunctor; + +DECLARE_GPU_SPEC(float, int32); +DECLARE_GPU_SPEC(float, int64); + +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNEL_BUILDER( + Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint("T"), + InTopK); +REGISTER_KERNEL_BUILDER( + Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint("T"), + InTopK); + +#endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/in_topk_op.h b/tensorflow/core/kernels/in_topk_op.h new file mode 100644 index 00000000000..52716f2d272 --- /dev/null +++ b/tensorflow/core/kernels/in_topk_op.h @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// InTopK argument can be passed either via mode attribute (InTopK op), or as an +// input tensor (InTopKV2 op). 
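// Illustrative note, not part of this patch: keeping the raw k tensor inside
// TopKArg (rather than reading it in Compute) appears to let each device
// functor decide how to consume it — the CPU functor reads the scalar on the
// host, while the GPU functor in in_topk_op_gpu.cu.cc broadcasts the tensor
// directly on the device, avoiding a device-to-host copy of k.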
+struct TopKArg { + int64 k_value = -1; + const Tensor* k_tensor = nullptr; +}; + +template +struct InTopKFunctor { + template + using Dims = Eigen::DSizes; + + void operator()(OpKernelContext* context, + typename TTypes::ConstTensor predictions, + typename TTypes::ConstVec targets, const TopKArg k, + typename TTypes::Vec output) {} +}; + +template +struct InTopKFunctor { + void operator()(OpKernelContext* context, + typename TTypes::ConstTensor predictions, + typename TTypes::ConstVec targets, const TopKArg k, + typename TTypes::Vec output) { + const Eigen::Index num_targets = predictions.dimension(0); + const Eigen::Index num_classes = predictions.dimension(1); + + int64 k_val = k.k_value; + if (k.k_tensor != nullptr) { + if (k.k_tensor->dtype() == DT_INT32) { + k_val = k.k_tensor->scalar()(); + } else { + k_val = k.k_tensor->scalar()(); + } + } + + for (int batch_idx = 0; batch_idx < num_targets; batch_idx++) { + auto target = internal::SubtleMustCopy(targets(batch_idx)); + + bool cannot_say = !FastBoundsCheck(target, num_classes) || + !std::isfinite(predictions(batch_idx, target)); + + int more_probable_classes = 0; + if (!cannot_say) { + const T target_prediction = predictions(batch_idx, target); + + for (int class_idx = 0; class_idx < num_classes; ++class_idx) { + T pred = predictions(batch_idx, class_idx); + if (!std::isfinite(pred)) { + cannot_say = true; + break; + } else if (pred > target_prediction) { + ++more_probable_classes; + if (more_probable_classes > k_val) break; + } + } + } + output(batch_idx) = cannot_say ? false : (more_probable_classes < k_val); + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_ diff --git a/tensorflow/core/kernels/in_topk_op_gpu.cu.cc b/tensorflow/core/kernels/in_topk_op_gpu.cu.cc new file mode 100644 index 00000000000..3c14838a53b --- /dev/null +++ b/tensorflow/core/kernels/in_topk_op_gpu.cu.cc @@ -0,0 +1,176 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/in_topk_op.h" +#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" +#include "tensorflow/core/kernels/reduction_ops.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +// Compare each prediction in 'predictions' with a target prediction for the +// batch, and write result to the 'mask': +// -1: If the target class is out of range, or if the prediction value is not +// finite and can't be compared to target prediction (and vice versa). +// 0: If prediction is smaller than the target prediction for the batch. +// 1: If prediction is larger than the target prediction for the batch. 
+template +__global__ void ComputePredictionMaskKernel( + const T* predictions, // dims: [ num_targets x num_classes ] + const TargetT* targets, // dims: [ num_targets ] + int64* mask, // dims: [ num_targets x num_classes ] + int num_targets, int num_classes) { + CUDA_1D_KERNEL_LOOP(i, num_targets * num_classes) { + const int batch_index = i / num_classes; + TargetT target_idx = ldg(targets + batch_index); + + if (!FastBoundsCheck(target_idx, num_classes)) { + mask[i] = -1; + return; + } + + T prediction = ldg(predictions + i); + T target_prediction = + ldg(predictions + batch_index * num_classes + target_idx); + + if (!Eigen::numext::isfinite(prediction) || + !Eigen::numext::isfinite(target_prediction)) { + mask[i] = -1; + } else { + mask[i] = prediction > target_prediction ? 1 : 0; + } + } +} + +// Reduce all prediction masks either to the sum of '1' for each prediction +// larger than the target, or to '-1' if target class in invalid of predictions +// in a batch have non-finite values. +struct MaskSum { + __host__ __device__ int64 operator()(const int64& a, const int64& b) const { + if (a < 0 || b < 0) + return -1; + else + return a + b; + } +}; + +namespace reduction_op_helper { +template <> +struct IdentityValue { + int64 operator()() { return 0; } +}; + +} // namespace reduction_op_helper + +template +struct InTopKFunctor { + template + using Dims = Eigen::DSizes; + + void operator()(OpKernelContext* context, + typename TTypes::ConstTensor predictions, + typename TTypes::ConstVec targets, const TopKArg k, + typename TTypes::Vec output) { + const Eigen::Index num_targets = predictions.dimension(0); + const Eigen::Index num_classes = predictions.dimension(1); + + OP_REQUIRES( + context, num_targets * num_classes < std::numeric_limits::max(), + errors::InvalidArgument( + "Number of targets * number of classes must be less than INT_MAX")); + + // Temporary storage for a mask computed by `ComputePredictionMaskKernel`. + Tensor predictions_mask; + OP_REQUIRES_OK( + context, context->allocate_temp(DT_INT64, + TensorShape({num_targets, num_classes}), + &predictions_mask)); + + // Number of predictions for each target that are larger than the target + // prediction (or -1 if we can't compute this number, because not all + // predictions are finite or target class is out of range). + Tensor num_larger_prediction; + OP_REQUIRES_OK(context, + context->allocate_temp(DT_INT64, TensorShape({num_targets}), + &num_larger_prediction)); + + const auto& d = context->eigen_device(); + + // Compute a mask for all predictions. + CudaLaunchConfig config = GetCudaLaunchConfig(num_targets * num_classes, d); + OP_REQUIRES_OK(context, CudaLaunchKernel( + ComputePredictionMaskKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), predictions.data(), targets.data(), + predictions_mask.flat().data(), + num_targets, num_classes)); + + // Reduce prediction masks to number of predictions larger than the target + // prediction, or to the negative value if we can't compute an answer. + { + auto in = predictions_mask.matrix(); + auto out = num_larger_prediction.flat(); + + ReduceImpl>( + context, (int64*)out.data(), (int64*)in.data(), in.rank(), + in.dimension(0), in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), Dims<1>(1), + MaskSum()); + } + + // Compute if target prediction is in top K predictions. 
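// Illustrative sketch, not part of this patch: the row reduction performed by
// MaskSum above, restated on the host over a plain vector. A single -1 entry
// poisons the row; otherwise the result counts predictions that beat the
// target. E.g. {1, 0, 1} -> 2 and {1, -1, 0} -> -1. Names are hypothetical.
#include <cstdint>
#include <numeric>
#include <vector>
static int64_t ReduceMaskRowSketch(const std::vector<int64_t>& row_mask) {
  return std::accumulate(row_mask.begin(), row_mask.end(), int64_t{0},
                         [](int64_t a, int64_t b) {
                           return (a < 0 || b < 0) ? int64_t{-1} : a + b;
                         });
}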
+ auto cnt = num_larger_prediction.flat(); + + if (k.k_tensor != nullptr) { + if (k.k_tensor->dtype() == DT_INT32) { + output.device(d) = + (cnt >= cnt.constant(0)) && + (cnt < k.k_tensor->flat().template cast().broadcast( + Dims<1>(num_targets))); + } else { + output.device(d) = + (cnt >= cnt.constant(0)) && + (cnt < k.k_tensor->flat().broadcast(Dims<1>(num_targets))); + } + } else { + output.device(d) = + (cnt >= cnt.constant(0)) && (cnt < targets.constant(k.k_value)); + } + } +}; + +} // namespace functor + +// Definition of the GPU implementations declared in in_topk_op.cc. +#define DEFINE_GPU_KERNELS(T, TARGET_T) \ + template struct functor::InTopKFunctor; + +DEFINE_GPU_KERNELS(float, int32); +DEFINE_GPU_KERNELS(float, int64); + +#undef DEFINE_GPU_KERNELS + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/in_topk_op_test.cc b/tensorflow/core/kernels/in_topk_op_test.cc new file mode 100644 index 00000000000..aacecb08bbe --- /dev/null +++ b/tensorflow/core/kernels/in_topk_op_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +template +static Graph* InTopK(int num_targets, int num_classes, T top_k) { + Graph* g = new Graph(OpRegistry::Global()); + + DataType dtype = DataTypeToEnum::value; + + Tensor predictions_t(DT_FLOAT, TensorShape({num_targets, num_classes})); + predictions_t.flat().setRandom(); + + Tensor targets_t(dtype, TensorShape({num_targets})); + targets_t.flat().setRandom(); + + Tensor k_t(dtype, TensorShape({})); + k_t.scalar() = k_t.scalar().constant(top_k); + + Node* predictions = test::graph::Constant(g, predictions_t, "predictions"); + Node* targets = test::graph::Constant(g, targets_t, "targets"); + Node* k = test::graph::Constant(g, k_t, "k"); + + Node* in_topk; + TF_CHECK_OK(NodeBuilder(g->NewName("in_topk"), "InTopKV2") + .Input(predictions) + .Input(targets) + .Input(k) + .Attr("T", dtype) + .Finalize(g, &in_topk)); + + return g; +} + +#define BM_NAME(T, TARGETS, CLASSES, K, DEVICE) \ + BM_InTopK##_##T##_##TARGETS##_##CLASSES##_##K##_##DEVICE + +#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE) \ + static void BM_NAME(T, 
TARGETS, CLASSES, K, DEVICE)(int iters) { \ + testing::UseRealTime(); \ + testing::ItemsProcessed(static_cast(iters) * TARGETS * CLASSES); \ + test::Benchmark(#DEVICE, InTopK(TARGETS, CLASSES, K)).Run(iters); \ + } \ + BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE)); + +BM_InTopK(int64, 64, 1000, 10, cpu); +BM_InTopK(int64, 64, 10000, 10, cpu); + +#ifdef GOOGLE_CUDA +BM_InTopK(int64, 64, 1000, 10, gpu); +BM_InTopK(int64, 64, 10000, 10, gpu); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 7f06764d526..51862854a75 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -543,6 +543,7 @@ REGISTER_EMPTY(float, GPU); REGISTER_EMPTY(double, GPU); REGISTER_EMPTY(Eigen::half, GPU); REGISTER_EMPTY(int64, GPU); +REGISTER_EMPTY(int32, GPU); #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc index 35cfe03e8e2..cdb42645ee2 100644 --- a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc @@ -19,7 +19,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/kernels/inplace_ops_functor.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -43,7 +43,7 @@ template Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32 loc, Tensor* output) { const int64 nelem = value.NumElements(); - CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d); + GpuLaunchConfig cfg = GetCudaLaunchConfig(nelem, d); auto Toutput = output->flat_outer_dims(); const int64 nrows = Toutput.dimension(0); const int64 ncols = Toutput.dimension(1); @@ -106,7 +106,7 @@ template void DoInplaceOp(const Device& d, InplaceOpType op, const Tensor& i, const Tensor& v, Tensor* y) { const int64 nelem = v.NumElements(); - CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d); + GpuLaunchConfig cfg = GetCudaLaunchConfig(nelem, d); auto Ty = y->flat_outer_dims(); const int64 nrows = Ty.dimension(0); const int64 ncols = Ty.dimension(1); @@ -141,7 +141,7 @@ template void DoInplaceOp(const Device& d, InplaceOpType op, const Tensor& i, const Tensor& v, Tensor* y) { const int64 nelem = v.NumElements(); - CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d); + GpuLaunchConfig cfg = GetCudaLaunchConfig(nelem, d); auto Ty = y->flat_outer_dims(); const int64 nrows = Ty.dimension(0); const int64 ncols = Ty.dimension(1); diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h index 692f916439c..11ecf7d676e 100644 --- a/tensorflow/core/kernels/linalg_ops_common.h +++ b/tensorflow/core/kernels/linalg_ops_common.h @@ -113,6 +113,8 @@ class LinearAlgebraOp : public OpKernel { Eigen::Matrix; using ConstMatrixMap = Eigen::Map; using MatrixMap = Eigen::Map; + using ConstVectorMap = + Eigen::Map>; using ConstMatrixMaps = gtl::InlinedVector; using MatrixMaps = gtl::InlinedVector; using RealScalar = typename Eigen::NumTraits::Real; @@ -180,6 +182,7 @@ extern template class LinearAlgebraOp; using MatrixMaps = typename Base::MatrixMaps; \ using ConstMatrixMap = typename Base::ConstMatrixMap; \ using ConstMatrixMaps = typename Base::ConstMatrixMaps; \ + using ConstVectorMap = typename Base::ConstVectorMap; \ using TensorShapes = typename 
Base::TensorShapes; #define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \ diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index b5b7b75143b..1fbc967a039 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include "tensorflow/core/framework/allocator.h" #define EIGEN_USE_THREADS #if GOOGLE_CUDA @@ -196,18 +197,19 @@ Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index, const TensorList& input_list, TensorList** output_list) { // Attempt to forward the input tensor to the output if possible. - AllocatorAttributes attr; - attr.set_on_host(true); - std::unique_ptr maybe_output = - c->forward_input(input_index, output_index, DT_VARIANT, TensorShape{}, - c->input_memory_type(input_index), attr); + std::unique_ptr maybe_output = c->forward_input( + input_index, output_index, DT_VARIANT, TensorShape{}, + c->input_memory_type(input_index), AllocatorAttributes()); Tensor* output_tensor; if (maybe_output != nullptr) { // Woohoo, forwarding succeeded! output_tensor = maybe_output.get(); + c->set_output(output_index, *output_tensor); } else { // If forwarding is not possible allocate a new output tensor and copy // the `input_list` to it. + AllocatorAttributes attr; + attr.set_on_host(true); TF_RETURN_IF_ERROR( c->allocate_output(output_index, {}, &output_tensor, attr)); output_tensor->scalar()() = input_list; @@ -425,15 +427,17 @@ class TensorListResize : public OpKernel { errors::InvalidArgument( "TensorListSlice expects size to be non-negative. Got: ", size)); - AllocatorAttributes attr; - attr.set_on_host(true); - std::unique_ptr maybe_result = c->forward_input( - 0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr); + std::unique_ptr maybe_result = + c->forward_input(0, 0, DT_VARIANT, TensorShape{}, + c->input_memory_type(0), AllocatorAttributes()); if (maybe_result != nullptr) { maybe_result->scalar()().get()->tensors.resize( size, Tensor(DT_INVALID)); + c->set_output(0, *maybe_result); } else { Tensor* result; + AllocatorAttributes attr; + attr.set_on_host(true); OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr)); TensorList output_list; output_list.element_shape = input_list->element_shape; diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc index e611ae28b9a..c0ec46aacb4 100644 --- a/tensorflow/core/kernels/logging_ops.cc +++ b/tensorflow/core/kernels/logging_ops.cc @@ -13,16 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/logging_ops.h" - #include #include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" +#include "tensorflow/core/framework/logging.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -44,29 +41,13 @@ Status AppendStringToFile(const std::string& fname, StringPiece data, mutex_lock l(*file_mutex); std::unique_ptr file; TF_RETURN_IF_ERROR(env->NewAppendableFile(fname, &file)); - Status a = file->Append(absl::StrCat(data, "\n")); + Status a = file->Append(data); Status c = file->Close(); return a.ok() ? 
c : a; } } // namespace -namespace logging { - -typedef std::vector Listeners; - -Listeners* GetListeners() { - static Listeners* listeners = new Listeners; - return listeners; -} - -bool RegisterListener(void (*listener)(const char*)) { - GetListeners()->push_back(listener); - return true; -} - -} // end namespace logging - class AssertOp : public OpKernel { public: explicit AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -148,6 +129,7 @@ class PrintV2Op : public OpKernel { public: explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("end", &end_)); SetFilePathIfAny(); if (!file_path_.empty()) return; @@ -171,26 +153,29 @@ class PrintV2Op : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("input", &input_)); const string& msg = input_->scalar()(); + string ended_msg = strings::StrCat(msg, end_); + if (!file_path_.empty()) { // Outputs to a file at the specified path. - OP_REQUIRES_OK(ctx, AppendStringToFile(file_path_, msg, ctx->env())); + OP_REQUIRES_OK(ctx, + AppendStringToFile(file_path_, ended_msg, ctx->env())); return; } - auto listeners = logging::GetListeners(); - if (!listeners->empty()) { - for (auto& listener : *listeners) { - listener(msg.c_str()); - } - } else if (output_stream_ == "stdout") { - std::cout << msg << std::endl; + + if (logging::LogToListeners(ended_msg, "")) { + return; + } + + if (output_stream_ == "stdout") { + std::cout << ended_msg << std::flush; } else if (output_stream_ == "stderr") { - std::cerr << msg << std::endl; + std::cerr << ended_msg << std::flush; } else if (output_stream_ == "log(info)") { - LOG(INFO) << msg << std::endl; + LOG(INFO) << ended_msg << std::flush; } else if (output_stream_ == "log(warning)") { - LOG(WARNING) << msg << std::endl; + LOG(WARNING) << ended_msg << std::flush; } else if (output_stream_ == "log(error)") { - LOG(ERROR) << msg << std::endl; + LOG(ERROR) << ended_msg << std::flush; } else { string error_msg = strings::StrCat( "Unknown output stream: ", output_stream_, ", Valid streams are:"); @@ -206,6 +191,7 @@ class PrintV2Op : public OpKernel { "log(warning)", "log(error)"}; private: + string end_; // Either output_stream_ or file_path_ (but not both) will be non-empty. 
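// Illustrative sketch, not part of this patch: the stream routing PrintV2Op
// performs above once the file path and registered listeners have been ruled
// out. The std::string-based signature is hypothetical; ended_msg mirrors
// strings::StrCat(msg, end_).
#include <iostream>
#include <string>
static void PrintToStreamSketch(const std::string& msg, const std::string& end,
                                const std::string& output_stream) {
  const std::string ended_msg = msg + end;
  if (output_stream == "stdout") {
    std::cout << ended_msg << std::flush;
  } else if (output_stream == "stderr") {
    std::cerr << ended_msg << std::flush;
  }  // "log(info)" / "log(warning)" / "log(error)" route to LOG(...) instead.
}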
string output_stream_; string file_path_; diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index b046401c0ae..28a3d94e579 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -57,19 +57,21 @@ class LookupTableOp : public OpKernel { use_node_name_sharing_)); } - auto creator = [ctx, this](lookup::LookupInterface** ret) { - lookup::LookupInterface* container = new Container(ctx, this); - if (!ctx->status().ok()) { - container->Unref(); - return ctx->status(); - } - if (ctx->track_allocations()) { - ctx->record_persistent_memory_allocation( - container->MemoryUsed() + table_handle_.AllocatedBytes()); - } - *ret = container; - return Status::OK(); - }; + auto creator = + [ctx, this](lookup::LookupInterface** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + lookup::LookupInterface* container = new Container(ctx, this); + if (!ctx->status().ok()) { + container->Unref(); + return ctx->status(); + } + if (ctx->track_allocations()) { + ctx->record_persistent_memory_allocation( + container->MemoryUsed() + table_handle_.AllocatedBytes()); + } + *ret = container; + return Status::OK(); + }; lookup::LookupInterface* table = nullptr; OP_REQUIRES_OK(ctx, diff --git a/tensorflow/core/kernels/lookup_tables/BUILD b/tensorflow/core/kernels/lookup_tables/BUILD deleted file mode 100644 index a25660e987a..00000000000 --- a/tensorflow/core/kernels/lookup_tables/BUILD +++ /dev/null @@ -1,89 +0,0 @@ -# Description: -# OpKernels and resource templates for lookup tables. - -package( - default_visibility = [ - "//tensorflow:__subpackages__", - "//tensorflow:internal", - ], -) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") - -cc_library( - name = "resource_interface_templates", - hdrs = ["resource_interface_templates.h"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -cc_library( - name = "op_kernel_templates", - hdrs = ["op_kernel_templates.h"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:tensor_flag_utils", - "//third_party/eigen3", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/types:span", - ], -) - -tf_kernel_library( - name = "fingerprint64_map_op_kernels", - srcs = [ - "fingerprint64_map_op_kernels.cc", - ], - deps = [ - ":op_kernel_templates", - ":resource_interface_templates", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_kernel_library( - name = "flat_hash_map_op_kernels", - srcs = [ - "flat_hash_map_op_kernels.cc", - ], - deps = [ - ":op_kernel_templates", - ":resource_interface_templates", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:tensor_flag_utils", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -tf_kernel_library( - name = "generic_table_op_kernels", - srcs = [ - "generic_table_op_kernels.cc", - ], - deps = [ - ":op_kernel_templates", - ":resource_interface_templates", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - 
"//tensorflow/core/kernels:string_view_variant_wrapper", - "@com_google_absl//absl/strings", - ], -) diff --git a/tensorflow/core/kernels/lookup_tables/fingerprint64_map_op_kernels.cc b/tensorflow/core/kernels/lookup_tables/fingerprint64_map_op_kernels.cc deleted file mode 100644 index 36274bc6b63..00000000000 --- a/tensorflow/core/kernels/lookup_tables/fingerprint64_map_op_kernels.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "absl/strings/string_view.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/kernels/lookup_tables/op_kernel_templates.h" -#include "tensorflow/core/kernels/lookup_tables/resource_interface_templates.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace tables { - -// Map x -> (Fingerprint64(x) % num_oov_buckets) + offset. -// num_oov_buckets and offset are node attributes provided at construction -// time. -template -class Fingerprint64Map final - : public virtual LookupInterface, - public virtual LookupWithPrefetchInterface, - absl::Span> { - public: - using key_type = KeyType; - - Fingerprint64Map(int64 num_oov_buckets, int64 offset) - : num_oov_buckets_(num_oov_buckets), offset_(offset) {} - - Status Lookup(const KeyType& key_to_find, ValueType* value) const override { - *value = LookupHelper(key_to_find); - return Status::OK(); - } - - Status Lookup(absl::Span keys, absl::Span values, - int64 prefetch_lookahead) const override { - if (ABSL_PREDICT_FALSE(keys.size() != values.size())) { - return errors::InvalidArgument( - "keys and values do not have the same number of elements (found ", - keys.size(), " vs ", values.size(), ")."); - } - for (size_t i = 0; i < keys.size(); ++i) { - values[i] = LookupHelper(keys[i]); - } - return Status::OK(); - } - - mutex* GetMutex() const override { return nullptr; } - - string DebugString() const override { return __PRETTY_FUNCTION__; } - - private: - ABSL_ATTRIBUTE_ALWAYS_INLINE ValueType - LookupHelper(const KeyType& key_to_find) const { - // This can cause a downcast. 
- return static_cast(Fingerprint64(key_to_find) % - num_oov_buckets_) + - offset_; - } - - const int64 num_oov_buckets_; - const int64 offset_; - TF_DISALLOW_COPY_AND_ASSIGN(Fingerprint64Map); -}; - -template -struct Fingerprint64MapFactory { - struct Functor { - using resource_type = Fingerprint64Map; - - static Status AllocateContainer(OpKernelContext* ctx, OpKernel* kernel, - Fingerprint64Map** container) { - int64 num_oov_buckets; - int64 offset; - TF_RETURN_IF_ERROR( - GetNodeAttr(kernel->def(), "num_oov_buckets", &num_oov_buckets)); - TF_RETURN_IF_ERROR(GetNodeAttr(kernel->def(), "offset", &offset)); - *container = new Fingerprint64Map(num_oov_buckets, offset); - return Status::OK(); - } - }; -}; - -template -using ResourceOp = ResourceConstructionOp< - typename Fingerprint64MapFactory< - Fingerprint64Map>::Functor, - // These are the aliases. - LookupInterface, - LookupWithPrefetchInterface, - absl::Span>>; - -#define REGISTER_STRING_KERNEL(ValueType) \ - REGISTER_KERNEL_BUILDER( \ - Name("Fingerprint64Map") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("heterogeneous_key_dtype") \ - .TypeConstraint("table_value_dtype"), \ - ResourceOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Fingerprint64Map") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("heterogeneous_key_dtype") \ - .TypeConstraint("table_value_dtype"), \ - ResourceOp); - -REGISTER_STRING_KERNEL(int32); -REGISTER_STRING_KERNEL(int64); - -#undef REGISTER_STRING_KERNEL - -} // namespace tables -} // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_tables/flat_hash_map_op_kernels.cc b/tensorflow/core/kernels/lookup_tables/flat_hash_map_op_kernels.cc deleted file mode 100644 index 9c37ca87cea..00000000000 --- a/tensorflow/core/kernels/lookup_tables/flat_hash_map_op_kernels.cc +++ /dev/null @@ -1,275 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include "absl/base/attributes.h" -#include "absl/container/flat_hash_map.h" -#include "absl/memory/memory.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/kernels/lookup_tables/op_kernel_templates.h" -#include "tensorflow/core/kernels/lookup_tables/resource_interface_templates.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/fingerprint.h" - -namespace tensorflow { -namespace tables { - -using errors::InvalidArgument; - -// absl::flat_hash_map backed table with inline -// fallback to x -> (Fingerprint64(x) % num_oov_buckets) + offset when looked -// up keys are not in the flat_hash_map. Inlining the fallback table turns out -// to be quite efficient in comparison to virtual dispatch for the fallback -// lookup. 
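// Illustrative sketch, not part of the removed file: the lookup-with-inline-
// fallback pattern described in the class comment above. The caller supplies
// the precomputed fingerprint (Fingerprint64 in the original code); out-of-
// vocabulary ids start after the in-vocabulary ids, i.e. at table.size().
// Function name and signature are hypothetical.
#include <cstdint>
#include <string>
#include "absl/container/flat_hash_map.h"
static int64_t LookupOrBucketSketch(
    const absl::flat_hash_map<std::string, int64_t>& table,
    const std::string& key, uint64_t key_fingerprint,
    int64_t num_oov_buckets) {
  auto it = table.find(key);
  if (it != table.end()) return it->second;  // in-vocabulary hit
  return static_cast<int64_t>(key_fingerprint %
                              static_cast<uint64_t>(num_oov_buckets)) +
         static_cast<int64_t>(table.size());  // OOV ids follow vocab ids
}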
-template -class StaticStringFlatHashMap final - : public virtual LookupInterface, - public virtual LookupInterface, - public virtual LookupWithPrefetchInterface< - absl::Span, absl::Span>, - public virtual LookupWithPrefetchInterface, - absl::Span>, - public virtual KeyValueTableInitializerInterface< - absl::Span, absl::Span>, - public virtual KeyValueTableInitializerInterface< - absl::Span, absl::Span>, - public virtual SizeInterface { - public: - using value_type = ValueType; - - StaticStringFlatHashMap(bool enable_synchronization, int64 num_oov_buckets) - : num_oov_buckets_(num_oov_buckets) { - if (enable_synchronization) { - mutex_ = absl::make_unique(); - } - } - - Status Initialize(absl::Span keys, - absl::Span values) override { - if (ABSL_PREDICT_FALSE(keys.size() != values.size())) { - return errors::InvalidArgument( - "keys and values do not have the same number of elements (found ", - keys.size(), " vs ", values.size(), ")."); - } - - table_.reserve(table_.size() + keys.size()); - for (size_t i = 0; i < keys.size(); ++i) { - table_.insert_or_assign(string(keys[i]), values[i]); - } - return Status::OK(); - } - - Status Initialize(absl::Span keys, - absl::Span values) override { - if (ABSL_PREDICT_FALSE(keys.size() != values.size())) { - return errors::InvalidArgument( - "keys and values do not have the same number of elements (found ", - keys.size(), " vs ", values.size(), ")."); - } - - table_.reserve(table_.size() + keys.size()); - for (size_t i = 0; i < keys.size(); ++i) { - table_.insert_or_assign(keys[i], values[i]); - } - return Status::OK(); - } - - Status Lookup(const absl::string_view& key, ValueType* value) const override { - *value = LookupHelper(key); - return Status::OK(); - } - - Status Lookup(const string& key, ValueType* value) const override { - *value = LookupHelper(key); - return Status::OK(); - } - - // keys and values are guaranteed to have the same size by convention. - Status Lookup(absl::Span keys, - absl::Span values, - int64 prefetch_lookahead) const override { - const auto keys_size = keys.size(); - if (prefetch_lookahead <= 0 || prefetch_lookahead >= keys_size) { - for (size_t i = 0; i < keys_size; ++i) { - values[i] = LookupHelper(keys[i]); - } - } else { - for (size_t i = 0; i < keys_size; ++i) { - if (i + prefetch_lookahead < keys.size()) { - table_.prefetch(keys[i + prefetch_lookahead]); - } - values[i] = LookupHelper(keys[i]); - } - } - return Status::OK(); - } - - // keys and values are guaranteed to have the same size by convention. 
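// Illustrative sketch, not part of the removed file: the software-prefetch
// pattern used by the Lookup overloads here — issue a prefetch for the key
// `lookahead` positions ahead before resolving the current key, so the hash
// probe for a later key overlaps the work on the current one. Names are
// hypothetical; assumes values is presized to keys.size().
#include <cstdint>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
static void PrefetchedLookupSketch(
    const absl::flat_hash_map<std::string, int64_t>& table,
    const std::vector<std::string>& keys, int64_t lookahead,
    std::vector<int64_t>* values) {
  for (size_t i = 0; i < keys.size(); ++i) {
    if (lookahead > 0 && i + lookahead < keys.size()) {
      table.prefetch(keys[i + lookahead]);  // warm the bucket for a later key
    }
    auto it = table.find(keys[i]);
    (*values)[i] = (it != table.end()) ? it->second : -1;  // -1: hypothetical miss value
  }
}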
- Status Lookup(absl::Span keys, absl::Span values, - int64 prefetch_lookahead) const override { - const auto keys_size = keys.size(); - if (prefetch_lookahead <= 0 || prefetch_lookahead >= keys_size) { - for (size_t i = 0; i < keys_size; ++i) { - values[i] = LookupHelper(keys[i]); - } - } else { - for (size_t i = 0; i < keys_size; ++i) { - if (i + prefetch_lookahead < keys.size()) { - table_.prefetch(keys[i + prefetch_lookahead]); - } - values[i] = LookupHelper(keys[i]); - } - } - return Status::OK(); - } - - uint64 Size() const override { return table_.size(); } - - mutex* GetMutex() const override { return mutex_.get(); } - - string DebugString() const override { return __PRETTY_FUNCTION__; } - - private: - template - ABSL_ATTRIBUTE_ALWAYS_INLINE ValueType - LookupHelper(const T& key_to_find) const { - auto it = table_.find(key_to_find); - if (it != table_.end()) { - return it->second; - } else { - return static_cast(Fingerprint64(key_to_find) % - num_oov_buckets_) + - StaticStringFlatHashMap::Size(); - } - } - - const int64 num_oov_buckets_; - std::unique_ptr mutex_; - // The underlying table. - absl::flat_hash_map table_; - TF_DISALLOW_COPY_AND_ASSIGN(StaticStringFlatHashMap); -}; - -// Used to allocate StaticStringFlatHashMap objects via the AllocateContainer -// method. -template -struct StaticStringFlatHashMapFactory { - struct Functor { - using resource_type = StaticStringFlatHashMap; - - template - static Status AllocateContainer(OpKernelContext* ctx, OpKernel* kernel, - StaticStringFlatHashMapBase** container) { - OpInputList table_int64_args; - TF_RETURN_IF_ERROR( - ctx->input_list("table_int64_args", &table_int64_args)); - const size_t variadic_arg_size = table_int64_args.size(); - if (ABSL_PREDICT_FALSE(variadic_arg_size != 2)) { - return errors::InvalidArgument( - "table_int64_args should have 2 elements (found ", - variadic_arg_size, - "). Set the first element to 1 to enable synchronized table use " - "and to 0 otherwise. The second element should be " - "num_oov_buckets."); - } - - const bool enable_synchronization = ctx->input(0).scalar()() != 0; - const int64 num_oov_buckets = ctx->input(1).scalar()(); - if (ABSL_PREDICT_FALSE(num_oov_buckets <= 0)) { - return errors::InvalidArgument( - "num_oov_buckets must be positive. 
Found: ", num_oov_buckets); - } - auto* non_virtual_container = - new StaticStringFlatHashMap(enable_synchronization, num_oov_buckets); - *container = non_virtual_container; - const Tensor& keys = ctx->input(table_int64_args.size()); - const Tensor& values = ctx->input(table_int64_args.size() + 1); - if (keys.NumElements() == 0) { - return Status::OK(); - } else if (keys.dtype() == DT_STRING) { - return Functor::Initialize( - keys.flat(), - values.flat(), - non_virtual_container); - } else if (keys.dtype() == DT_VARIANT) { - auto keys_flat = keys.flat(); - if (keys_flat(0).get() == nullptr) { - return errors::InvalidArgument( - "Variant keys tensor must have subtype absl::string_view."); - } - return Functor::Initialize( - keys.flat(), - values.flat(), - non_virtual_container); - } - return errors::InvalidArgument( - "keys tensor must have type DT_STRING or type DT_VARIANT with " - "subtype absl::string_view."); - } - - static Status Initialize( - const absl::Span keys, - const absl::Span - values, - StaticStringFlatHashMap* container) { - return container->Initialize(keys, values); - } - - static Status Initialize( - const absl::Span keys, - const absl::Span - values, - StaticStringFlatHashMap* container) { - std::vector keys_vec; - keys_vec.reserve(keys.size()); - for (size_t i = 0; i < keys.size(); ++i) { - keys_vec.push_back(*keys[i].get()); - } - return container->Initialize(keys_vec, values); - } - }; -}; - -template -using ResourceOp = ResourceConstructionOp< - typename StaticStringFlatHashMapFactory< - StaticStringFlatHashMap>::Functor, - // These are the aliases. - LookupInterface, - LookupWithPrefetchInterface, - absl::Span>, - LookupInterface, - LookupWithPrefetchInterface, - absl::Span>, - SizeInterface>; - -#define REGISTER_STRING_KERNEL(table_value_dtype) \ - REGISTER_KERNEL_BUILDER( \ - Name("StaticStringFlatHashMap") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("heterogeneous_key_dtype") \ - .TypeConstraint("table_value_dtype"), \ - ResourceOp); - -REGISTER_STRING_KERNEL(int32); -REGISTER_STRING_KERNEL(int64); - -#undef REGISTER_STRING_KERNEL - -} // namespace tables -} // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_tables/generic_table_op_kernels.cc b/tensorflow/core/kernels/lookup_tables/generic_table_op_kernels.cc deleted file mode 100644 index 9bb29afd19a..00000000000 --- a/tensorflow/core/kernels/lookup_tables/generic_table_op_kernels.cc +++ /dev/null @@ -1,227 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include "absl/strings/string_view.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/kernels/lookup_tables/op_kernel_templates.h" -#include "tensorflow/core/kernels/lookup_tables/resource_interface_templates.h" -#include "tensorflow/core/kernels/string_view_variant_wrapper.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { -namespace tables { - -template -struct TensorInsertFactory { - class Functor { - public: - // If KeyType is not 'valid' then use the value it wraps as the table key - // type. - using resource_type = InsertOrAssignInterface< - absl::Span, - typename absl::conditional_t< - IsValidDataType::value, absl::Span, - absl::Span>>; - - static Status TensorInsert(const Tensor& keys, const Tensor& values, - resource_type* table) { - if (keys.NumElements() != values.NumElements()) { - return errors::InvalidArgument( - "OpKernel tried to map keys vector of size ", keys.NumElements(), - " to values vector of size ", values.NumElements()); - } - return TensorInsertHelper(keys, values, table); - } - - private: - // keys and *values arguments to TensorInsert must have the same number of - // elements. This is guaranteed above. - - // 'Simple' types below are types which are natively supported in TF. - // Non-variant KeyType which is the same as Container::key_type. - // No need to static_cast. - template - static absl::enable_if_t::value, Status> - TensorInsertHelper(const Tensor& keys, const Tensor& values, - resource_type* table) { - return table->InsertOrAssign(keys.flat(), - values.flat()); - } - - // Variant KeyType; the wrapped type is convertible to - // Container::key_type. - template - static absl::enable_if_t::value, Status> - TensorInsertHelper(const Tensor& keys, const Tensor& values, - resource_type* table) { - const auto keys_flat = keys.flat(); - std::vector keys_vec; - keys_vec.reserve(keys_flat.size()); - for (size_t i = 0; i < keys_flat.size(); ++i) { - keys_vec.emplace_back( - *keys_flat(i).get()); - } - return table->InsertOrAssign(keys_vec, values.flat()); - } - }; -}; - -template -using InsertOp = LookupTableInsertOp< - typename TensorInsertFactory::Functor>; - -template -struct TensorLookupFactory { - class Functor { - public: - // If KeyType is not 'valid' then use the value it wraps as the table key - // type. - using resource_type = LookupWithPrefetchInterface< - absl::Span, - typename absl::conditional_t< - IsValidDataType::value, absl::Span, - absl::Span>>; - - static Status TensorLookup(const resource_type& table, const Tensor& keys, - const int64 prefetch_lookahead, - const int64 num_keys_per_thread, - thread::ThreadPool* threadpool, Tensor* values) { - if (keys.NumElements() != values->NumElements()) { - return errors::InvalidArgument( - "OpKernel tried to map keys vector of size ", keys.NumElements(), - " to values vector of size ", values->NumElements()); - } - return TensorLookupHelper(table, keys, prefetch_lookahead, - num_keys_per_thread, threadpool, values); - } - - private: - // keys and *values arguments to TensorLookup must have the same number of - // elements. This is guaranteed above. - - // 'Simple' types below are types which are natively supported in TF. 
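[Editor's sketch] The `LookupHelper` removed a few hunks above maps a missing key to a deterministic out-of-vocabulary bucket in the range `[Size(), Size() + num_oov_buckets)`. A standalone illustration, assuming `std::hash` as a stand-in for TF's `Fingerprint64`:

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <string_view>
#include <unordered_map>

// Returns the stored id for an in-vocabulary key, otherwise a stable OOV id
// placed after the in-vocabulary range.
int64_t LookupOrAssignOovBucket(
    const std::unordered_map<std::string, int64_t>& table,
    std::string_view key, int64_t num_oov_buckets) {
  const auto it = table.find(std::string(key));
  if (it != table.end()) return it->second;
  const uint64_t fingerprint = std::hash<std::string_view>{}(key);
  return static_cast<int64_t>(fingerprint %
                              static_cast<uint64_t>(num_oov_buckets)) +
         static_cast<int64_t>(table.size());
}
```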
- template - static absl::enable_if_t::value, Status> - TensorLookupHelper(const resource_type& table, const Tensor& keys, - const int64 prefetch_lookahead, - const int64 num_keys_per_thread, - thread::ThreadPool* threadpool, Tensor* values) { - const auto keys_flat = keys.flat(); - auto key_span = absl::MakeSpan(keys_flat); - auto value_span = absl::MakeSpan(values->flat().data(), - values->NumElements()); - return MultithreadedTensorLookup(table, prefetch_lookahead, - num_keys_per_thread, key_span, - value_span, threadpool); - } - - // Non-simple KeyType. We'll try an implicit conversion to - // Container::key_type. - template - static absl::enable_if_t::value, Status> - TensorLookupHelper(const resource_type& table, const Tensor& keys, - const int64 prefetch_lookahead, - const int64 num_keys_per_thread, - thread::ThreadPool* threadpool, Tensor* values) { - const auto keys_flat = keys.flat(); - std::vector keys_vec; - const auto keys_size = keys_flat.size(); - keys_vec.reserve(keys_size); - for (size_t i = 0; i < keys_size; ++i) { - keys_vec.emplace_back(*keys_flat(i).get()->get()); - } - absl::Span key_span(keys_vec); - auto value_span = absl::MakeSpan(values->flat().data(), - values->NumElements()); - return MultithreadedTensorLookup(table, prefetch_lookahead, - num_keys_per_thread, key_span, - value_span, threadpool); - } - - // Wrapper around table.BatchLookup which permits sharding across cores. - template - static Status MultithreadedTensorLookup(const resource_type& table, - int64 prefetch_lookahead, - int64 num_keys_per_thread, K keys, - V values, - thread::ThreadPool* threadpool) { - mutex temp_mutex; // Protect status. - Status status; - auto lookup_keys = [&](int64 begin, int64 end) { - auto temp_status = table.Lookup(keys.subspan(begin, end - begin), - values.subspan(begin, end - begin), - prefetch_lookahead); - if (ABSL_PREDICT_FALSE(!temp_status.ok())) { - mutex_lock lock(temp_mutex); - status.Update(temp_status); - } - }; - threadpool->TransformRangeConcurrently( - num_keys_per_thread /* block_size */, keys.size(), lookup_keys); - return status; - } - }; -}; - -template -using LookupOp = LookupTableFindOp< - typename TensorLookupFactory::Functor>; - -struct TableSizeFunctor { - using resource_type = SizeInterface; - - static Status Size(const SizeInterface& table, uint64* size) { - *size = table.Size(); - return Status::OK(); - } -}; - -#define REGISTER_STRING_KERNEL(table_value_dtype) \ - REGISTER_KERNEL_BUILDER( \ - Name("LookupTableInsertOrAssignOp") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("insert_key_tensor_dtype") \ - .TypeConstraint("table_value_dtype"), \ - InsertOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("LookupTableInsertOrAssignOp") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("insert_key_tensor_dtype") \ - .TypeConstraint("table_value_dtype"), \ - InsertOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("LookupTableFindOp") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("lookup_key_tensor_dtype") \ - .TypeConstraint("table_value_dtype"), \ - LookupOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("LookupTableFindOp") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("lookup_key_tensor_dtype") \ - .TypeConstraint("table_value_dtype"), \ - LookupOp); \ - REGISTER_KERNEL_BUILDER(Name("ContainerSizeOp").Device(DEVICE_CPU), \ - ContainerSizeOp); - -REGISTER_STRING_KERNEL(int32); -REGISTER_STRING_KERNEL(int64); - -#undef REGISTER_STRING_KERNEL - -} // namespace tables -} // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_tables/op_kernel_templates.h 
b/tensorflow/core/kernels/lookup_tables/op_kernel_templates.h deleted file mode 100644 index d767ca0661e..00000000000 --- a/tensorflow/core/kernels/lookup_tables/op_kernel_templates.h +++ /dev/null @@ -1,448 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_OP_KERNEL_TEMPLATES_H_ -#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_OP_KERNEL_TEMPLATES_H_ - -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/meta/type_traits.h" -#include "absl/types/span.h" -#include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/kernels/tensor_flag_utils.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace tables { - -// Create resources of type ResourceType and AliasesToRegister using -// Functor::AllocateContainer(OpKernelConstruction*, OpKernel*, -// ResourceType**). ResourceType = Functor::resource_type. -// No-op for resources which have already been created. -template -class ResourceConstructionOp : public OpKernel { - public: - explicit ResourceConstructionOp(OpKernelConstruction* ctx) - : OpKernel(ctx), table_handle_set_(false) { - OP_REQUIRES_OK( - ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_)); - } - - void Compute(OpKernelContext* ctx) override { - mutex_lock l(mu_); - - if (!table_handle_set_) { - OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), - use_node_name_sharing_)); - } - - auto creator = [ctx, - this](ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - ResourceType* resource = nullptr; - auto status = Functor::AllocateContainer(ctx, this, &resource); - if (ABSL_PREDICT_FALSE(!status.ok())) { - // Ideally resource is non-null only if status is OK but we try - // to compensate here. - if (resource != nullptr) { - resource->Unref(); - } - return status; - } - if (ctx->track_allocations()) { - ctx->record_persistent_memory_allocation(resource->MemoryUsed()); - } - *ret = resource; - return Status::OK(); - }; - - // Register the ResourceType alias. 
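[Editor's sketch] `ResourceConstructionOp::Compute` above depends on create-once semantics: the first call for a (container, name) pair runs the creator, later calls reuse the stored resource. A minimal sketch of that contract with standard containers; the names are illustrative and do not match the `ResourceMgr` API:

```cpp
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

class MiniResourceManager {
 public:
  std::shared_ptr<void> LookupOrCreate(
      const std::string& key,
      const std::function<std::shared_ptr<void>()>& creator) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = resources_.find(key);
    if (it != resources_.end()) return it->second;  // Already created: reuse.
    std::shared_ptr<void> resource = creator();     // First use: construct.
    resources_.emplace(key, resource);
    return resource;
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::shared_ptr<void>> resources_;
};
```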
- ResourceType* resource = nullptr; - core::ScopedUnref unref_me(resource); - OP_REQUIRES_OK( - ctx, - cinfo_.resource_manager()->template LookupOrCreate( - cinfo_.container(), cinfo_.name(), &resource, creator)); - - // Put a handle to resource in the output tensor (the other aliases will - // have the same handle). - Tensor* handle; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); - handle->scalar()() = MakeResourceHandle( - ctx, cinfo_.container(), cinfo_.name()); - table_handle_set_ = true; - - // Create other alias resources. - Status status; - int dummy[sizeof...(AliasesToRegister)] = { - (status.Update(RegisterAlias(resource)), 0)...}; - (void)dummy; - OP_REQUIRES_OK(ctx, status); - } - - ~ResourceConstructionOp() override { - // If the table object was not shared, delete it. - if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) { - if (!cinfo_.resource_manager() - ->template Delete(cinfo_.container(), - cinfo_.name()) - .ok()) { - // Do nothing; the resource may have been deleted by session resets. - } - // Attempt to delete other resource aliases. - Status dummy_status; - int dummy[sizeof...(AliasesToRegister)] = { - (dummy_status.Update(DeleteAlias()), 0)...}; - (void)dummy; - } - } - - private: - using ResourceType = typename Functor::resource_type; - template - Status RegisterAlias(ResourceType* resource) { - auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = resource; - return Status::OK(); - }; - - T* alias_resource = nullptr; - core::ScopedUnref unref_me(alias_resource); - return cinfo_.resource_manager()->template LookupOrCreate( - cinfo_.container(), cinfo_.name(), &alias_resource, creator); - } - - template - Status DeleteAlias() { - return cinfo_.resource_manager()->template Delete(cinfo_.container(), - cinfo_.name()); - } - - mutex mu_; - bool table_handle_set_ GUARDED_BY(mu_); - ContainerInfo cinfo_; - bool use_node_name_sharing_; - - TF_DISALLOW_COPY_AND_ASSIGN(ResourceConstructionOp); -}; - -// Create resources of type ContainerBase using the static method -// Functor::AllocateContainer(OpKernelConstruction*, OpKernel*, -// FallbackTableBaseType*, ContainerBase**) -// If the resource has already been created it will be looked up. -// Container must decrease the reference count of the FallbackTableBaseType* -// constructor argument before its destructor completes. -template -class TableWithFallbackConstructionOp : public OpKernel { - public: - explicit TableWithFallbackConstructionOp(OpKernelConstruction* ctx) - : OpKernel(ctx), table_handle_set_(false) { - OP_REQUIRES_OK( - ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_)); - } - - void Compute(OpKernelContext* ctx) override { - OpInputList table_int64_args; - OP_REQUIRES_OK(ctx, ctx->input_list("table_int64_args", &table_int64_args)); - if (ctx->num_inputs() == table_int64_args.size()) { - ctx->SetStatus(errors::InvalidArgument( - "Expected op to have a resource input after the table_int64_args " - "input but no such input found.")); - return; - } - - // Look up the fallback table. 
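[Editor's sketch] The `int dummy[sizeof...(AliasesToRegister)] = {(status.Update(...), 0)...};` lines above use a pre-C++17 pack-expansion idiom to run one call per alias type. A self-contained illustration of the idiom only (the alias registration itself is simplified away):

```cpp
#include <iostream>
#include <typeinfo>

template <typename Alias>
void RegisterAlias() {
  std::cout << "registered alias: " << typeid(Alias).name() << "\n";
}

template <typename... AliasesToRegister>
void RegisterAllAliases() {
  // Expanding `(f<Alias>(), 0)...` inside an array initializer runs f once per
  // type, in order. The leading 0 keeps the array non-empty for empty packs.
  int dummy[] = {0, (RegisterAlias<AliasesToRegister>(), 0)...};
  (void)dummy;  // Only the side effects of the expansion matter.
}

int main() { RegisterAllAliases<int, double>(); }
```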
- FallbackTableBaseType* fallback_table = nullptr; - { - const Tensor& table_handle = ctx->input(table_int64_args.size()); - ResourceHandle handle(table_handle.scalar()()); - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->Lookup( - handle.container(), handle.name(), &fallback_table)); - } - mutex_lock l(mu_); - - if (!table_handle_set_) { - OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), - use_node_name_sharing_)); - } - - auto creator = [ctx, this, fallback_table]( - ResourceType** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // container construction logic can't be merged with - // ResourceConstructionOp because Container constructor requires an - // input which can only be constructed if the resource manager - // internal lock is not already held. - ResourceType* resource = nullptr; - auto status = - Functor::AllocateContainer(ctx, this, fallback_table, &resource); - if (ABSL_PREDICT_FALSE(!status.ok())) { - // Ideally resource is non-null only if status is OK but we try - // to compensate here. - if (resource != nullptr) { - resource->Unref(); - } - return status; - } - if (ctx->track_allocations()) { - ctx->record_persistent_memory_allocation(resource->MemoryUsed()); - } - *ret = resource; - return Status::OK(); - }; - - // Register the ResourceType alias. - ResourceType* table = nullptr; - core::ScopedUnref unref_me(table); - OP_REQUIRES_OK( - ctx, - cinfo_.resource_manager()->template LookupOrCreate( - cinfo_.container(), cinfo_.name(), &table, creator)); - - // Put a handle to resource in the output tensor (the other aliases will - // have the same handle). - Tensor* handle; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); - handle->scalar()() = MakeResourceHandle( - ctx, cinfo_.container(), cinfo_.name()); - table_handle_set_ = true; - - // Create other alias resources. - Status status; - int dummy[sizeof...(AliasesToRegister)] = { - (status.Update(RegisterAlias(table)), 0)...}; - (void)dummy; - OP_REQUIRES_OK(ctx, status); - } - - ~TableWithFallbackConstructionOp() override { - // If the table object was not shared, delete it. - if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) { - if (!cinfo_.resource_manager() - ->template Delete(cinfo_.container(), - cinfo_.name()) - .ok()) { - // Do nothing; the resource may have been deleted by session resets. - } - // Attempt to delete other resource aliases. - Status dummy_status; - int dummy[sizeof...(AliasesToRegister)] = { - (dummy_status.Update(DeleteAlias()), 0)...}; - (void)dummy; - } - } - - private: - using ResourceType = typename Functor::resource_type; - using FallbackTableBaseType = typename Functor::fallback_table_type; - - template - Status RegisterAlias(ResourceType* resource) { - auto creator = [resource](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = resource; - return Status::OK(); - }; - - T* alias_resource = nullptr; - core::ScopedUnref unref_me(alias_resource); - return cinfo_.resource_manager()->template LookupOrCreate( - cinfo_.container(), cinfo_.name(), &alias_resource, creator); - } - - template - Status DeleteAlias() { - return cinfo_.resource_manager()->template Delete(cinfo_.container(), - cinfo_.name()); - } - - mutex mu_; - bool table_handle_set_ GUARDED_BY(mu_); - ContainerInfo cinfo_; - bool use_node_name_sharing_; - - TF_DISALLOW_COPY_AND_ASSIGN(TableWithFallbackConstructionOp); -}; - -// Lookup a table of type ResourceAlias and insert the passed in keys and -// values tensors using Functor::TensorInsert(keys, values, table). 
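[Editor's sketch] The insert and find kernels that follow take the table's mutex only when `GetMutex()` returns non-null, i.e. when the table was built with synchronization enabled. A small sketch of that conditional-locking shape with illustrative types, not the TF interfaces:

```cpp
#include <mutex>
#include <optional>
#include <utility>

struct SyncAwareTable {
  explicit SyncAwareTable(bool enable_synchronization) {
    if (enable_synchronization) mu.emplace();
  }
  std::mutex* GetMutex() { return mu ? &*mu : nullptr; }
  std::optional<std::mutex> mu;  // Engaged only when synchronization is on.
};

template <typename Fn>
void WithOptionalLock(SyncAwareTable& table, Fn&& fn) {
  if (std::mutex* m = table.GetMutex()) {
    std::lock_guard<std::mutex> lock(*m);  // Synchronized table: hold the lock.
    std::forward<Fn>(fn)();
  } else {
    std::forward<Fn>(fn)();  // Unsynchronized table: run lock-free.
  }
}
```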
-template -class LookupTableInsertOp : public OpKernel { - public: - explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - OpInputList table_int64_args; - OP_REQUIRES_OK(ctx, ctx->input_list("table_int64_args", &table_int64_args)); - const size_t tensor_index_offset = table_int64_args.size(); - // Business logic for checking tensor shapes, etc, is delegated to the - // Functor. - const Tensor& keys = ctx->input(tensor_index_offset + 1); - const Tensor& values = ctx->input(tensor_index_offset + 2); - - const Tensor& table_handle = ctx->input(tensor_index_offset); - ResourceHandle handle(table_handle.scalar()()); - ResourceAlias* table; - core::ScopedUnref unref_me(table); - OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup( - handle.container(), handle.name(), &table)); - - int memory_used_before = 0; - if (ctx->track_allocations()) { - memory_used_before = table->MemoryUsed(); - } - auto* mutex = table->GetMutex(); - if (mutex != nullptr) { - mutex_lock lock(*mutex); - OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table)); - } else { - OP_REQUIRES_OK(ctx, Functor::TensorInsert(keys, values, table)); - } - if (ctx->track_allocations()) { - ctx->record_persistent_memory_allocation(table->MemoryUsed() - - memory_used_before); - } - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(LookupTableInsertOp); -}; - -// Lookup a table of type ResourceAlias and look up the passed in keys using -// Functor::TensorLookup( -// table, keys, prefetch_lookahead, num_keys_per_thread, threadpool, out). -template -class LookupTableFindOp : public OpKernel { - public: - explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - OpInputList table_int64_args; - { - auto status = ctx->input_list("table_int64_args", &table_int64_args); - if (ABSL_PREDICT_FALSE(!status.ok())) { - ctx->SetStatus(status); - return; - } - } - // We lookup tensors using positional indices because that's more - // efficient than looking up their string names. - const Tensor& prefetch_lookahead_t = ctx->input(0); - const size_t tensor_index_offset = table_int64_args.size(); - const Tensor& keys = ctx->input(tensor_index_offset + 1); - const Tensor& num_threads = ctx->input(tensor_index_offset + 2); - - TensorShape output_shape = keys.shape(); - Tensor* out; - { - auto status = ctx->allocate_output(0, output_shape, &out); - if (ABSL_PREDICT_FALSE(!status.ok())) { - ctx->SetStatus(status); - return; - } - } - - int64 num_threads_scalar; - if (TensorShapeUtils::IsScalar(num_threads.shape())) { - num_threads_scalar = num_threads.template scalar()(); - } else { - // Scans through rows of num_threads and returns second entry of first - // row whose first entry is <= the number of keys to process. - // This allows the user to control parallelism as a function of - // the number of keys to lookup. - num_threads_scalar = tensor_flag_utils::FindConfigValueForKey( - num_threads.template matrix(), keys.dim_size(0)); - } - const int64 num_keys_per_thread = - num_threads_scalar > 0 - ? 
std::max(1ll, keys.dim_size(0) / num_threads_scalar) - : keys.dim_size(0); - - const int64 prefetch_lookahead = prefetch_lookahead_t.scalar()(); - - const Tensor& table_handle = ctx->input(tensor_index_offset); - ResourceHandle handle(table_handle.scalar()()); - ResourceAlias* table; - core::ScopedUnref unref_me(table); - OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup( - handle.container(), handle.name(), &table)); - - auto* mutex = table->GetMutex(); - auto* threadpool = ctx->device()->tensorflow_cpu_worker_threads()->workers; - if (mutex != nullptr) { - // There are many subtle problems with using reader locks so we opt for a - // writer lock here. - mutex_lock lock(*mutex); - OP_REQUIRES_OK( - ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead, - num_keys_per_thread, threadpool, out)); - } else { - OP_REQUIRES_OK( - ctx, Functor::TensorLookup(*table, keys, prefetch_lookahead, - num_keys_per_thread, threadpool, out)); - } - } -}; - -// Lookup a container of type ResourceAlias and return its size using -// Functor::Size(container, &size). -template -class ContainerSizeOp : public OpKernel { - public: - explicit ContainerSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - const Tensor& container_handle = ctx->input(0); - ResourceHandle handle(container_handle.scalar()()); - ResourceAlias* container; - core::ScopedUnref unref_me(container); - OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup( - handle.container(), handle.name(), &container)); - - Tensor* out; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); - - auto* mutex = container->GetMutex(); - if (mutex != nullptr) { - tf_shared_lock lock(*mutex); - OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar()())); - } else { - OP_REQUIRES_OK(ctx, Functor::Size(*container, &out->scalar()())); - } - } -}; - -} // namespace tables -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_OP_KERNEL_TEMPLATES_H_ diff --git a/tensorflow/core/kernels/lookup_tables/resource_interface_templates.h b/tensorflow/core/kernels/lookup_tables/resource_interface_templates.h deleted file mode 100644 index 7331fb400a4..00000000000 --- a/tensorflow/core/kernels/lookup_tables/resource_interface_templates.h +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_RESOURCE_INTERFACE_TEMPLATES_H_ -#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_RESOURCE_INTERFACE_TEMPLATES_H_ - -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { -namespace tables { - -// Interface for resources with mutable state. 
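[Editor's sketch] The find kernel above shards the key range across the CPU thread pool (via the removed `MultithreadedTensorLookup`) and merges per-shard failures under a mutex. A standalone sketch of that sharding, with `std::thread` standing in for the thread pool's `TransformRangeConcurrently` and a bool standing in for `Status`:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

bool ShardedLookup(const std::vector<int64_t>& table,
                   const std::vector<std::size_t>& keys,
                   std::vector<int64_t>* values, std::size_t block_size) {
  if (block_size == 0) block_size = 1;
  std::mutex status_mu;  // Protects `ok`.
  bool ok = true;
  auto lookup_block = [&](std::size_t begin, std::size_t end) {
    for (std::size_t i = begin; i < end; ++i) {
      if (keys[i] >= table.size()) {
        std::lock_guard<std::mutex> lock(status_mu);
        ok = false;  // Per-shard failures are merged under the mutex.
        return;
      }
      (*values)[i] = table[keys[i]];
    }
  };
  std::vector<std::thread> workers;
  for (std::size_t begin = 0; begin < keys.size(); begin += block_size) {
    workers.emplace_back(lookup_block, begin,
                         std::min(begin + block_size, keys.size()));
  }
  for (std::thread& t : workers) t.join();
  return ok;
}
```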
-class SynchronizedInterface : public virtual ResourceBase { - public: - // Return value should be used to synchronize read/write access to - // all public methods. If null, no synchronization is needed. - virtual mutex* GetMutex() const = 0; -}; - -// Interface for containers which support batch lookups. -template -class InsertOrAssignInterface : public virtual SynchronizedInterface { - public: - using value_type = ValueType; - - // Stores each KV pair {keys[i], values[i]} in the underlying map, overriding - // pre-existing pairs which have equivalent keys. - // keys and values should have the same size. - virtual Status InsertOrAssign(KeyContext... key_context, - ValueType values) = 0; -}; - -// Interface for containers which support lookups. -template -class LookupInterface : public virtual SynchronizedInterface { - public: - using value_type = ValueType; - - // Lookup the values for keys and store them in values. - // prefetch_lookahead is used to prefetch the key at index - // i + prefetch_lookahead at the ith iteration of the implemented loop. - // keys and values must have the same size. - virtual Status Lookup(KeyContext... key_context, ValueType values) const = 0; -}; - -// Interface for containers which support lookups with prefetching. -template -class LookupWithPrefetchInterface : public virtual SynchronizedInterface { - public: - using value_type = ValueType; - - // Lookup the values for keys and store them in values. - // prefetch_lookahead is used to prefetch the key at index - // i + prefetch_lookahead at the ith iteration of the implemented loop. - // keys and values must have the same size. - virtual Status Lookup(KeyContext... key_context, ValueType values, - int64 prefetch_lookahead) const = 0; -}; - -// Interface for containers with size concepts. -// Implementations must guarantee thread-safety when GetMutex is used to -// synchronize method access. -class SizeInterface : public virtual SynchronizedInterface { - public: - // Returns the number of elements in the container. - virtual uint64 Size() const = 0; -}; - -// Interface for tables which can be initialized from key and value arguments. -template -class KeyValueTableInitializerInterface : public virtual SynchronizedInterface { - public: - using value_type = ValueType; - - // Lookup the values for keys and store them in values. - // prefetch_lookahead is used to prefetch the key at index - // i + prefetch_lookahead at the ith iteration of the implemented loop. - // keys and values must have the same size. - virtual Status Initialize(KeyContext... key_context, ValueType values) = 0; -}; - -} // namespace tables -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLES_RESOURCE_INTERFACE_TEMPLATES_H_ diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc index ba30432e21a..a58a4acdfba 100644 --- a/tensorflow/core/kernels/lrn_op.cc +++ b/tensorflow/core/kernels/lrn_op.cc @@ -35,7 +35,7 @@ limitations under the License. 
#endif #if GOOGLE_CUDA -#include "cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/util/stream_executor_util.h" #endif // GOOGLE_CUDA @@ -312,7 +312,11 @@ struct LaunchLRNGrad; template struct LaunchLRNGrad { LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta) - : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {} + : depth_radius_(depth_radius), + bias_(bias), + alpha_(alpha), + beta_(beta), + alpha_beta_2_(T(-2) * alpha * beta) {} void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in_grads, const Tensor& in_image, @@ -358,13 +362,15 @@ struct LaunchLRNGrad { } norm = alpha_ * norm + bias_; DCHECK_GT(norm, T(1e-6)); + T pre_computed_pow = Eigen::numext::pow(norm, -beta_); + T activations_ab2 = alpha_beta_2_ * activations(i, j); + T gs = grads_shaped(i, j); for (int64 k = depth_begin; k < depth_end; ++k) { - T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) * - activations(i, j) / norm; + T dyi = in_shaped(i, k) * activations_ab2 / norm; if (k == j) { - dyi += Eigen::numext::pow(norm, -beta_); + dyi += pre_computed_pow; } - dyi *= grads_shaped(i, j); + dyi *= gs; const_cast::Tensor&>(out_shaped)(i, k) += dyi; } } @@ -379,6 +385,7 @@ struct LaunchLRNGrad { T bias_; T alpha_; T beta_; + T alpha_beta_2_; }; #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/lrn_op_test.cc b/tensorflow/core/kernels/lrn_op_test.cc index 496c697ac3f..7604a40d029 100644 --- a/tensorflow/core/kernels/lrn_op_test.cc +++ b/tensorflow/core/kernels/lrn_op_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -23,11 +24,13 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { @@ -195,4 +198,41 @@ TCASE(T3, 128, 4, 3, 2.0f, 1.0f, 1.0f) // clang-format on #undef TCASE + +static Graph* BM_LRNGrad(int batches, int rows, int cols, int depth, + int depth_radius) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor grads(DT_FLOAT, TensorShape({batches, rows, cols, depth})); + grads.flat().setRandom(); + + Tensor in(DT_FLOAT, TensorShape({batches, rows, cols, depth})); + in.flat().setRandom(); + + Tensor out(DT_FLOAT, TensorShape({batches, rows, cols, depth})); + + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("lrn_grad_op"), "LRNGrad") + .Input(test::graph::Constant(g, grads)) + .Input(test::graph::Constant(g, in)) + .Input(test::graph::Constant(g, out)) + .Attr("depth_radius", depth_radius) + .Attr("bias", 1.0f) + .Attr("alpha", 1.0f / 10) + .Attr("beta", 2.0f) + .Finalize(g, &ret)); + return g; +} + +#define BM_LRNGradDev(DEVICE, B, R, C, D, DR) \ + static void BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR(int iters) { \ + testing::ItemsProcessed(static_cast(iters) * B * R * C * D * DR * \ + 4); \ + test::Benchmark(#DEVICE, BM_LRNGrad(B, R, C, D, DR)).Run(iters); \ + } \ + BENCHMARK(BM_LRNGrad_##DEVICE##_##B##_##R##_##C##_##D##_##DR) + +BM_LRNGradDev(cpu, 128, 12, 12, 64, 4); +BM_LRNGradDev(cpu, 128, 56, 56, 64, 2); +BM_LRNGradDev(cpu, 128, 27, 27, 192, 2); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/lu_op_gpu.cu.cc b/tensorflow/core/kernels/lu_op_gpu.cu.cc index f83744b50de..6d94d7f3f64 100644 --- a/tensorflow/core/kernels/lu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/lu_op_gpu.cu.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/transpose_functor.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -61,7 +61,7 @@ __device__ void ComputePermutationFromTranspositions( // transpositions. template __global__ void ComputePermutationFromTranspositionsKernel( - CudaLaunchConfig config, const int64 num_rows, const int* all_pivots, + GpuLaunchConfig config, const int64 num_rows, const int* all_pivots, Scalar* all_permutation_indices) { // We only parallelize over batches here. 
Performance is not critical, // since this cheap O(num_rows) kernel always follows an O(num_rows^3) @@ -222,11 +222,11 @@ class LuOpGpu : public AsyncOpKernel { int* pivots_ptr = pivots.flat().data(); Tidx* permutation_indices_ptr = permutation_indices->template flat().data(); - CudaLaunchConfig cfgPivots = GetCudaLaunchConfig(batch_size, device); - ComputePermutationFromTranspositionsKernel<<>>( - cfgPivots, num_rows, pivots_ptr, permutation_indices_ptr); + GpuLaunchConfig cfgPivots = GetCudaLaunchConfig(batch_size, device); + TF_CHECK_OK(CudaLaunchKernel( + ComputePermutationFromTranspositionsKernel, cfgPivots.block_count, + cfgPivots.thread_per_block, 0, device.stream(), cfgPivots, num_rows, + pivots_ptr, permutation_indices_ptr)); // Callback for checking info after kernels finish. Also capture the // temporary Tensors/ScratchSpace so they don't get deallocated before the diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index 4ebe1659370..107f5a11954 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -25,10 +25,12 @@ limitations under the License. #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/util/matmul_autotune.h" #if GOOGLE_CUDA -#include "cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda.h" +#endif +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/platform/stream_executor.h" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace tensorflow { @@ -111,11 +113,11 @@ bool ExplicitVectorMatrixOptimization( template struct LaunchMatMulBase { -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM typedef se::blas::AlgorithmType AlgorithmType; #else typedef int64 AlgorithmType; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM static void launch( OpKernelContext* ctx, const Tensor& a, const Tensor& b, @@ -154,7 +156,7 @@ template struct LaunchMatMul : public LaunchMatMulSYCL {}; #endif // TENSORFLOW_USE_SYCL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace { @@ -433,7 +435,7 @@ struct LaunchMatMul { } }; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM template class MatMulOp : public OpKernel { @@ -483,7 +485,7 @@ class MatMulOp : public OpKernel { return; } - if (a.NumElements() == 0 || b.NumElements() == 0) { + if (a.NumElements() == 0 && b.NumElements() == 0) { // If a has shape [x, 0] and b has shape [0, y], the // output shape is [x, y] where x and y are non-zero, so we fill // the output with zeros. @@ -622,13 +624,13 @@ TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); #endif // INTEL_MKL && ENABLE_MKL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); TF_CALL_complex64(REGISTER_GPU); TF_CALL_complex128(REGISTER_GPU); TF_CALL_half(REGISTER_GPU); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL(T) \ diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h index 48769f3fe5d..51a4d0e8aa6 100644 --- a/tensorflow/core/kernels/matmul_op.h +++ b/tensorflow/core/kernels/matmul_op.h @@ -58,7 +58,7 @@ struct MatMulFunctor { } // end namespace functor -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Encapsulate all the shape information that is used in matmul operations. 
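[Editor's sketch] The `lrn_op.cc` change above hoists loop invariants out of the inner depth loop: `pow(norm, -beta)`, the gradient value, and the precomputed `-2 * alpha * beta * activation` factor depend only on the output element, not on `k`. A plain C++ sketch of the optimized inner loop (here iterating a whole row instead of the depth window, for brevity):

```cpp
#include <cmath>
#include <vector>

// alpha_beta_2 is the constructor-time precomputed -2 * alpha * beta.
void LrnGradDepthLoop(float norm, float beta, float alpha_beta_2,
                      float activation_ij, float grad_ij, int j,
                      const std::vector<float>& in_row,
                      std::vector<float>* out_row) {
  const float pre_computed_pow = std::pow(norm, -beta);        // Hoisted.
  const float activations_ab2 = alpha_beta_2 * activation_ij;  // Hoisted.
  for (int k = 0; k < static_cast<int>(in_row.size()); ++k) {
    float dyi = in_row[k] * activations_ab2 / norm;
    if (k == j) dyi += pre_computed_pow;
    (*out_row)[k] += dyi * grad_ij;  // Accumulate into the output gradient.
  }
}
```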
class MatmulParameters { public: @@ -117,7 +117,7 @@ class MatmulParameters { typedef Eigen::GpuDevice GPUDevice; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // end namespace tensorflow diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc new file mode 100644 index 00000000000..3bdc303dff5 --- /dev/null +++ b/tensorflow/core/kernels/matmul_op_fused.cc @@ -0,0 +1,199 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implements matmul operations with other kernels baked into the +// processing, to optimize latency and memory usage: +// - MatMul + BiasAdd + +// - MatMul + FusedBatchNorm + +// +// Activation: Relu, Relu6, Elu, etc... +// +// Currently supported only on CPU device. + +#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_ +#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_ + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#include +#include + +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/fused_eigen_output_kernels.h" +#include "tensorflow/core/util/tensor_format.h" + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +struct LaunchFusedMatMulOp { + void operator()( + OpKernelContext* context, const Tensor& a, const Tensor& b, + const Eigen::array, 1>& dim_pair, + FusedComputationType fusion, const FusedComputationArgs& fusion_args, + Tensor* output); +}; + +template +struct LaunchFusedMatMulOp { + void operator()( + OpKernelContext* context, const Tensor& a, const Tensor& b, + const Eigen::array, 1>& dim_pair, + FusedComputationType fusion, const FusedComputationArgs& fusion_args, + Tensor* output) { + auto lhs = a.matrix(); + auto rhs = b.matrix(); + auto out = output->matrix(); + + auto& d = context->eigen_device(); + + BiasAddArgs bias_add_args; + if (BiasAddArgs::IsSupported(fusion)) { + OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args)); + } + + switch (fusion) { + case FusedComputationType::kBiasAdd: + out.device(d) = + lhs.contract(rhs, dim_pair, WithBiasAdd(bias_add_args)); + break; + case FusedComputationType::kBiasAddWithRelu: + out.device(d) = + lhs.contract(rhs, dim_pair, WithBiasAddAndRelu(bias_add_args)); + break; + case FusedComputationType::kBiasAddWithRelu6: + out.device(d) = + lhs.contract(rhs, dim_pair, WithBiasAddAndRelu6(bias_add_args)); + break; + case FusedComputationType::kBiasAddWithElu: + out.device(d) = + lhs.contract(rhs, dim_pair, WithBiasAddAndElu(bias_add_args)); + break; + case 
FusedComputationType::kUndefined: + OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined")); + break; + default: + OP_REQUIRES_OK(context, + errors::Internal("Fusion type is not supported")); + } + } +}; + +template +class FusedMatMulOp : public OpKernel { + public: + explicit FusedMatMulOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_)); + + std::vector patterns; + + using FCT = FusedComputationType; + if (std::is_same::value) { + patterns = {{FCT::kBiasAdd, {"BiasAdd"}}, + {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}, + {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}}, + {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}}; + } + + OP_REQUIRES_OK(context, InitializeFusedComputation( + context, "MatMul", patterns, + &fused_computation_, &fused_computation_args_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& a = ctx->input(0); + const Tensor& b = ctx->input(1); + + // Check that the dimensions of the two matrices are valid. + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(a.shape()), + errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ", + a.shape().DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(b.shape()), + errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ", + b.shape().DebugString())); + Eigen::array, 1> dim_pair; + dim_pair[0].first = transpose_a_ ? 0 : 1; + dim_pair[0].second = transpose_b_ ? 1 : 0; + + OP_REQUIRES( + ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second), + errors::InvalidArgument( + "Matrix size-incompatible: In[0]: ", a.shape().DebugString(), + ", In[1]: ", b.shape().DebugString())); + int a_dim_remaining = 1 - dim_pair[0].first; + int b_dim_remaining = 1 - dim_pair[0].second; + TensorShape out_shape( + {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)}); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); + + if (out->NumElements() == 0) { + // If a has shape [0, x] or b has shape [x, 0], the output shape + // is a 0-element matrix, so there is nothing to do. + return; + } + + if (a.NumElements() == 0 && b.NumElements() == 0) { + // If a has shape [x, 0] and b has shape [0, y], the + // output shape is [x, y] where x and y are non-zero, so we fill + // the output with zeros. + functor::SetZeroFunctor f; + f(ctx->eigen_device(), out->flat()); + return; + } + + auto launch = LaunchFusedMatMulOp(); + launch(ctx, a, b, dim_pair, fused_computation_, fused_computation_args_, + out); + } + + private: + bool transpose_a_; + bool transpose_b_; + + FusedComputationType fused_computation_ = FusedComputationType::kUndefined; + FusedComputationArgs fused_computation_args_; + + TF_DISALLOW_COPY_AND_ASSIGN(FusedMatMulOp); +}; + +// Registration of the CPU implementations. 
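[Editor's sketch] The new `FusedMatMulOp` above resolves its `fused_ops` attribute by matching the ordered list of op names against a small table of supported patterns. A standalone illustration of that dispatch; the enum and pattern names mirror the added kernel, but this is not the TF implementation:

```cpp
#include <string>
#include <vector>

enum class Fusion {
  kUndefined,
  kBiasAdd,
  kBiasAddWithRelu,
  kBiasAddWithRelu6,
  kBiasAddWithElu
};

Fusion MatchFusion(const std::vector<std::string>& fused_ops) {
  const struct Pattern {
    Fusion type;
    std::vector<std::string> ops;
  } kPatterns[] = {
      {Fusion::kBiasAdd, {"BiasAdd"}},
      {Fusion::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
      {Fusion::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
      {Fusion::kBiasAddWithElu, {"BiasAdd", "Elu"}},
  };
  for (const Pattern& p : kPatterns) {
    if (p.ops == fused_ops) return p.type;  // Exact, ordered match.
  }
  return Fusion::kUndefined;
}
```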
+#define REGISTER_FUSED_CPU_MATMUL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ + FusedMatMulOp); + +#ifndef EIGEN_USE_LIBXSMM +TF_CALL_float(REGISTER_FUSED_CPU_MATMUL); +#endif // !EIGEN_USE_LIBXSMM + +#undef REGISTER_FUSED_CPU_MATMUL + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_ diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc index db1dc77bc5f..b442bf84cd0 100644 --- a/tensorflow/core/kernels/matmul_op_test.cc +++ b/tensorflow/core/kernels/matmul_op_test.cc @@ -13,13 +13,334 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/algorithm/container.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" +#include "tensorflow/core/public/session.h" namespace tensorflow { +template +class FusedMatMulOpTest : public OpsTestBase { + protected: + using BiasAddGraphRunner = + std::function; + + // Runs a Tensorflow graph defined by the root scope, and fetches the result + // of 'fetch' node into the output Tensor. Optional `fetch_node` parameter + // allows to define a fetch node directly using a NodeDef for the ops that are + // not supported by the C++ Api. + void RunAndFetch(const tensorflow::Scope& root, const string& fetch, + Tensor* output, bool allow_gpu_device, + const NodeDef* fetch_node = nullptr) { + tensorflow::GraphDef graph; + TF_ASSERT_OK(root.ToGraphDef(&graph)); + + if (fetch_node) { + *graph.add_node() = *fetch_node; + } + + // We really want to make sure that graph executed exactly as we passed it + // to the session, so we disable various optimizations. + tensorflow::SessionOptions session_options; + + // Disable common runtime constant folding. + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_opt_level(OptimizerOptions::L0); + + // Disable Grappler optimizations for tests. + tensorflow::RewriterConfig* cfg = + session_options.config.mutable_graph_options() + ->mutable_rewrite_options(); + cfg->set_constant_folding(tensorflow::RewriterConfig::OFF); + cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF); + cfg->set_remapping(tensorflow::RewriterConfig::OFF); + + std::unique_ptr session( + tensorflow::NewSession(session_options)); + + std::vector available_devices; + TF_ASSERT_OK(session->ListDevices(&available_devices)) + << "Failed to get available session devices"; + + // Check if session has an available GPU device. + const bool has_gpu_device = + absl::c_any_of(available_devices, [](const DeviceAttributes& device) { + return device.device_type() == DEVICE_GPU; + }); + + // If fused computation implemented only for CPU, in this test we don't want + // to compare GPU vs CPU numbers, so place all nodes on CPU in this case. + const bool place_all_on_gpu = allow_gpu_device && has_gpu_device; + + const string device = place_all_on_gpu ? 
"/device:GPU:0" : "/device:CPU:0"; + for (NodeDef& mutable_node : *graph.mutable_node()) { + mutable_node.set_device(device); + } + + TF_ASSERT_OK(session->Create(graph)); + + std::vector unfused_tensors; + TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors)); + + *output = unfused_tensors[0]; + } + + void RunMatMulWithBias(const Tensor& lhs_data, const Tensor& rhs_data, + const Tensor& bias_data, bool transpose_a, + bool transpose_b, Tensor* output, + bool allow_gpu_device = false) { + Scope root = tensorflow::Scope::NewRootScope(); + + ops::MatMul matmul = ops::MatMul( + root.WithOpName("matmul"), + ops::Const(root.WithOpName("lhs"), Input::Initializer(lhs_data)), + ops::Const(root.WithOpName("rhs"), Input::Initializer(rhs_data)), + ops::MatMul::Attrs().TransposeA(transpose_a).TransposeB(transpose_b)); + + ops::BiasAdd with_bias = ops::BiasAdd( + root.WithOpName("with_bias"), matmul, + ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data))); + + RunAndFetch(root, "with_bias", output, allow_gpu_device); + } + + void RunMatMulWithBiasAndActivation( + const Tensor& lhs_data, const Tensor& rhs_data, const Tensor& bias_data, + bool transpose_a, bool transpose_b, const string& activation_type, + Tensor* output, bool allow_gpu_device = false) { + Scope root = tensorflow::Scope::NewRootScope(); + + ops::MatMul matmul = ops::MatMul( + root.WithOpName("matmul"), + ops::Const(root.WithOpName("lhs"), Input::Initializer(lhs_data)), + ops::Const(root.WithOpName("rhs"), Input::Initializer(rhs_data)), + ops::MatMul::Attrs().TransposeA(transpose_a).TransposeB(transpose_b)); + + ops::BiasAdd with_bias = ops::BiasAdd( + root.WithOpName("with_bias"), matmul, + ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data))); + + if (activation_type == "Relu") { + ops::Relu(root.WithOpName("with_activation"), with_bias); + } else if (activation_type == "Relu6") { + ops::Relu6(root.WithOpName("with_activation"), with_bias); + } else if (activation_type == "Elu") { + ops::Elu(root.WithOpName("with_activation"), with_bias); + } else { + ops::Identity(root.WithOpName("with_activation"), with_bias); + } + + RunAndFetch(root, "with_activation", output, allow_gpu_device); + } + + void RunFusedMatMulOp(const Tensor& lhs_data, const Tensor& rhs_data, + const std::vector& args_data, + const std::vector& fused_ops, bool transpose_a, + bool transpose_b, Tensor* output, + bool allow_gpu_device = false) { + Scope root = tensorflow::Scope::NewRootScope(); + + DataType dtype = DataTypeToEnum::v(); + int num_args = static_cast(args_data.size()); + + Output lhs = + ops::Const(root.WithOpName("lhs"), Input::Initializer(lhs_data)); + Output rhs = + ops::Const(root.WithOpName("rhs"), Input::Initializer(rhs_data)); + + std::vector args; + for (int i = 0; i < num_args; ++i) { + Output arg = ops::Const(root.WithOpName(absl::StrCat("arg", i)), + Input::Initializer(args_data[i])); + args.emplace_back(arg.name(), 0, dtype); + } + + NodeDef fused_matmul; + TF_EXPECT_OK(NodeDefBuilder("fused_matmul", "_FusedMatMul") + .Input({lhs.name(), 0, dtype}) + .Input({rhs.name(), 0, dtype}) + .Input(args) + .Attr("num_args", num_args) + .Attr("T", dtype) + .Attr("fused_ops", fused_ops) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b) + .Finalize(&fused_matmul)); + + RunAndFetch(root, fused_matmul.name(), output, allow_gpu_device, + &fused_matmul); + } + + void VerifyBiasAddTensorsNear(int m, int k, int n, + const BiasAddGraphRunner& run_default, + const BiasAddGraphRunner& run_fused) { + DataType 
dtype = DataTypeToEnum::v(); + + Tensor lhs(dtype, {m, k}); + lhs.flat() = lhs.flat().setRandom(); + + // Add some negative values to filter to properly test Relu. + Tensor rhs(dtype, {k, n}); + rhs.flat() = rhs.flat().setRandom(); + rhs.flat() -= rhs.flat().constant(static_cast(0.5f)); + + // Bias added to the inner dimension. + const int bias_size = n; + Tensor bias(dtype, {bias_size}); + bias.flat() = bias.flat().setRandom(); + bias.flat() += bias.flat().constant(static_cast(0.5f)); + + Tensor matmul; + Tensor fused_matmul; + + run_default(lhs, rhs, bias, &matmul); + run_fused(lhs, rhs, bias, &fused_matmul); + + ASSERT_EQ(matmul.dtype(), fused_matmul.dtype()); + ASSERT_EQ(matmul.shape(), fused_matmul.shape()); + + test::ExpectClose(matmul, fused_matmul, /*atol=*/1e-5); + } + + // Verifies that computing MatMul+BiasAdd in a graph is identical to + // FusedMatMul. + void VerifyMatMulWithBias(int m, int k, int n, bool transpose_a, + bool transpose_b) { + const BiasAddGraphRunner run_default = + [&](const Tensor& input_data, const Tensor& filter_data, + const Tensor& bias_data, Tensor* out) { + RunMatMulWithBias(input_data, filter_data, bias_data, transpose_a, + transpose_b, out); + }; + + const BiasAddGraphRunner run_fused = + [&](const Tensor& input_data, const Tensor& filter_data, + const Tensor& bias_data, Tensor* out) { + RunFusedMatMulOp(input_data, filter_data, {bias_data}, {"BiasAdd"}, + transpose_a, transpose_b, out); + }; + + VerifyBiasAddTensorsNear(m, k, n, run_default, run_fused); + } + + // Verifies that computing MatMul+BiasAdd+{Activation} in a graph is identical + // to FusedMatMul. + void VerifyConv2DWithBiasAndActivation(int m, int k, int n, bool transpose_a, + bool transpose_b, + const string& activation) { + const BiasAddGraphRunner run_default = [&](const Tensor& input_data, + const Tensor& filter_data, + const Tensor& bias_data, + Tensor* out) { + RunMatMulWithBiasAndActivation(input_data, filter_data, bias_data, + transpose_a, transpose_b, activation, out); + }; + + const BiasAddGraphRunner run_fused = [&](const Tensor& input_data, + const Tensor& filter_data, + const Tensor& bias_data, + Tensor* out) { + RunFusedMatMulOp(input_data, filter_data, {bias_data}, + {"BiasAdd", activation}, transpose_a, transpose_b, out); + }; + + VerifyBiasAddTensorsNear(m, k, n, run_default, run_fused); + } +}; + +// MatMul with BatchNorm can be tested only with `T=float`, because default +// `FusedBatchNorm` kernel supports only floats for scale, mean and variance. 
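[Editor's sketch] The verification helpers above compare an unfused MatMul → BiasAdd → Activation graph against `_FusedMatMul` element-wise within a 1e-5 absolute tolerance. As a plain reference for what is being checked when `fused_ops = {"BiasAdd", "Relu"}`; row-major flat vectors are an assumption of this sketch, not the test code:

```cpp
#include <cstddef>
#include <vector>

// Naive (m x k) * (k x n) + bias, followed by Relu.
std::vector<float> MatMulBiasRelu(const std::vector<float>& a,     // m x k
                                  const std::vector<float>& b,     // k x n
                                  const std::vector<float>& bias,  // n
                                  int m, int k, int n) {
  std::vector<float> out(static_cast<std::size_t>(m) * n, 0.0f);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = bias[j];
      for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
      out[i * n + j] = acc > 0.0f ? acc : 0.0f;  // Relu.
    }
  }
  return out;
}
```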
+ +template +class FusedMatMulWithBiasOpTest : public FusedMatMulOpTest {}; + +TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest); + +// -------------------------------------------------------------------------- // +// MatMul + BiasAdd + {Activation} // +// -------------------------------------------------------------------------- // + +TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256) { + this->VerifyMatMulWithBias(256, 256, 256, false, false); + this->VerifyMatMulWithBias(256, 256, 256, true, false); + this->VerifyMatMulWithBias(256, 256, 256, false, true); + this->VerifyMatMulWithBias(256, 256, 256, true, true); +} + +TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256) { + this->VerifyMatMulWithBias(1, 256, 256, false, false); +} + +TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1) { + this->VerifyMatMulWithBias(256, 256, 1, false, false); +} + +TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1) { + this->VerifyMatMulWithBias(1, 256, 1, false, false); +} + +TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256WithActivation) { + for (const string& activation : {"Relu", "Relu6", "Elu"}) { + this->VerifyConv2DWithBiasAndActivation(256, 256, 256, false, false, + activation); + this->VerifyConv2DWithBiasAndActivation(256, 256, 256, true, false, + activation); + this->VerifyConv2DWithBiasAndActivation(256, 256, 256, false, true, + activation); + this->VerifyConv2DWithBiasAndActivation(256, 256, 256, true, true, + activation); + } +} + +TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256WithActivation) { + for (const string& activation : {"Relu", "Relu6", "Elu"}) { + this->VerifyConv2DWithBiasAndActivation(1, 256, 256, false, false, + activation); + } +} + +TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1WithActivation) { + for (const string& activation : {"Relu", "Relu6", "Elu"}) { + this->VerifyConv2DWithBiasAndActivation(256, 256, 1, false, false, + activation); + } +} + +TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1WithActivation) { + for (const string& activation : {"Relu", "Relu6", "Elu"}) { + this->VerifyConv2DWithBiasAndActivation(1, 256, 1, false, false, + activation); + } +} + +REGISTER_TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest, // + MatMul256x256x256, // + MatMul1x256x256, // + MatMul256x256x1, // + MatMul1x256x1, // + MatMul256x256x256WithActivation, // + MatMul1x256x256WithActivation, // + MatMul256x256x1WithActivation, // + MatMul1x256x1WithActivation); + +// TODO(ezhulenev): Add support for more data types. +using FusedBiasAddDataTypes = ::testing::Types; +INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedMatMulWithBiasOpTest, + FusedBiasAddDataTypes); + +//----------------------------------------------------------------------------// +// Performance benchmarks are below. 
// +//----------------------------------------------------------------------------// + template static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b, DataType type) { @@ -42,17 +363,27 @@ static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b, } \ BENCHMARK(BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##TFTYPE##_##DEVICE); +#ifdef GOOGLE_CUDA + #define BM_Matmul(M, K, N, TA, TB) \ BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, cpu); \ BM_MatmulDev(M, K, N, TA, TB, std::complex, DT_COMPLEX64, cpu); \ BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, gpu); \ BM_MatmulDev(M, K, N, TA, TB, std::complex, DT_COMPLEX64, gpu); \ -/* Uncomment to enable benchmarks for double/complex128: */ \ -// BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \ + /* Uncomment to enable benchmarks for double/complex128: */ \ + // BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, cpu); \ // BM_MatmulDev(M, K, N, TA, TB, std::complex, DT_COMPLEX128, cpu); \ // BM_MatmulDev(M, K, N, TA, TB, double, DT_DOUBLE, gpu); \ // BM_MatmulDev(M, K, N, TA, TB, std::complex, DT_COMPLEX128, gpu); +#else + +#define BM_Matmul(M, K, N, TA, TB) \ + BM_MatmulDev(M, K, N, TA, TB, float, DT_FLOAT, cpu); \ + BM_MatmulDev(M, K, N, TA, TB, std::complex, DT_COMPLEX64, cpu); + +#endif // GOOGLE_CUDA + // Batch size of 1 included for inference. // Typical fully connected layers BM_Matmul(1, 512, 512, false, false); diff --git a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc index 628d22b4584..2195c583130 100644 --- a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc +++ b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc @@ -18,11 +18,12 @@ limitations under the License. #define EIGEN_USE_GPU #include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/matrix_band_part_op.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -57,11 +58,12 @@ struct MatrixBandPartFunctor { const int batch_size = input.dimension(0); const int m = input.dimension(1); const int n = input.dimension(2); - CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device); - MatrixBandPartKernel<<>>( - config.virtual_thread_count, batch_size, m, n, num_lower_diags, - num_upper_diags, input.data(), output.data()); + GpuLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device); + TF_CHECK_OK(CudaLaunchKernel(MatrixBandPartKernel, + config.block_count, config.thread_per_block, 0, + device.stream(), config.virtual_thread_count, + batch_size, m, n, num_lower_diags, + num_upper_diags, input.data(), output.data())); } }; diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc index 75c49baaa84..7779525d231 100644 --- a/tensorflow/core/kernels/matrix_diag_op.cc +++ b/tensorflow/core/kernels/matrix_diag_op.cc @@ -62,8 +62,8 @@ class MatrixDiagPartOp : public OpKernel { for (int i = 0; i < rank - 2; ++i) { output_shape.AddDim(input_shape.dim_size(i)); } - const int64 min_dim = std::min(input_shape.dim_size(rank - 2), - input_shape.dim_size(rank - 1)); + const Eigen::Index min_dim = std::min(input_shape.dim_size(rank - 2), + input_shape.dim_size(rank - 1)); output_shape.AddDim(min_dim); Tensor* output = nullptr; @@ -97,7 +97,7 @@ class MatrixDiagOp : 
public OpKernel { "input must be at least 1-dim, received shape: ", input.shape().DebugString())); - const int64 k = input_shape.dim_size(rank - 1); + const Eigen::Index k = input_shape.dim_size(rank - 1); auto input_reshaped = input.flat_inner_dims(); TensorShape output_shape = input_shape; @@ -147,8 +147,8 @@ struct MatrixDiag { typename TTypes::ConstTensor input, typename TTypes::Tensor output) { output.device(d) = output.constant(T()); - for (int64 r = 0; r < output.dimension(0); ++r) { - for (int64 d = 0; d < output.dimension(1); ++d) { + for (Eigen::Index r = 0; r < output.dimension(0); ++r) { + for (Eigen::Index d = 0; d < output.dimension(1); ++d) { output(r, d, d) = input(r, d); } } @@ -160,8 +160,8 @@ struct MatrixDiagPart { static void Compute(const CPUDevice& d, typename TTypes::ConstTensor input, typename TTypes::Tensor output) { - for (int64 r = 0; r < output.dimension(0); ++r) { - for (int64 d = 0; d < output.dimension(1); ++d) { + for (Eigen::Index r = 0; r < output.dimension(0); ++r) { + for (Eigen::Index d = 0; d < output.dimension(1); ++d) { output(r, d) = input(r, d, d); } } diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc index 502d593474e..78b1df25399 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op.cc +++ b/tensorflow/core/kernels/matrix_set_diag_op.cc @@ -121,16 +121,17 @@ struct MatrixSetDiag { if (input.data() != output.data()) { output.device(device) = input; } - auto compute_shard = [&output, &diag](int64 begin, int64 end) { - for (int64 batch = begin; batch < end; ++batch) { - for (int64 col = 0; col < diag.dimension(1); ++col) { + auto compute_shard = [&output, &diag](Eigen::Index begin, + Eigen::Index end) { + for (Eigen::Index batch = begin; batch < end; ++batch) { + for (Eigen::Index col = 0; col < diag.dimension(1); ++col) { output(batch, col, col) = diag(batch, col); } } }; auto thread_pool = context->device()->tensorflow_cpu_worker_threads()->workers; - int64 cost_per_batch = 10 * output.dimension(1); // Heuristic. + Eigen::Index cost_per_batch = 10 * output.dimension(1); // Heuristic. thread_pool->ParallelFor(output.dimension(0), cost_per_batch, std::move(compute_shard)); } diff --git a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc index 4abf666fad9..4ee52f57939 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc +++ b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/matrix_set_diag_op.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -71,15 +71,14 @@ struct MatrixSetDiag { CHECK_EQ(diag.dimension(1), minsize); if (batch_size == 0 || minsize == 0) return; if (input.data() == output.data()) { - CudaLaunchConfig config = - GetCudaLaunchConfig(batch_size * minsize, device); + GpuLaunchConfig config = GetGpuLaunchConfig(batch_size * minsize, device); TF_CHECK_OK(CudaLaunchKernel(MatrixSetDiagKernel, config.block_count, config.thread_per_block, 0, device.stream(), config.virtual_thread_count, m, n, minsize, diag.data(), output.data())); } else { - CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device); + GpuLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device); TF_CHECK_OK(CudaLaunchKernel(MatrixCopyInputAndSetDiagKernel, config.block_count, config.thread_per_block, 0, device.stream(), diff --git a/tensorflow/core/kernels/matrix_solve_op.cc b/tensorflow/core/kernels/matrix_solve_op.cc index f3919a16aa5..3a75054f4ea 100644 --- a/tensorflow/core/kernels/matrix_solve_op.cc +++ b/tensorflow/core/kernels/matrix_solve_op.cc @@ -76,7 +76,7 @@ class MatrixSolveOp : public LinearAlgebraOp { MatrixMaps* outputs) final { const ConstMatrixMap& matrix = inputs[0]; const ConstMatrixMap& rhs = inputs[1]; - if (matrix.rows() == 0 || rhs.cols() == 0) { + if (matrix.rows() == 0 || matrix.cols() == 0 || rhs.cols() == 0) { // To be consistent with the MatrixInverse op, we define the solution for // an empty set of equation as the empty matrix. return; @@ -162,7 +162,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel { // To be consistent with the MatrixInverse op, we define the solution for // an empty set of equations as the empty matrix. - if (rhs.NumElements() == 0) { + if (input.NumElements() == 0 || rhs.NumElements() == 0) { done(); return; } diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 38885f1ccf5..a3592d8ec3c 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -41,7 +41,7 @@ limitations under the License. 
#include "tensorflow/core/util/use_cudnn.h" #if GOOGLE_CUDA -#include "cuda/include/cudnn.h" +#include "third_party/gpus/cudnn/cudnn.h" #include "tensorflow/core/kernels/maxpooling_op_gpu.h" #include "tensorflow/core/kernels/pooling_ops_common_gpu.h" #include "tensorflow/core/platform/stream_executor.h" @@ -914,13 +914,6 @@ class MaxPoolingWithArgmaxOp : public OpKernel { "Pooling is not yet supported on the batch dimension.")); OP_REQUIRES_OK(context, context->GetAttr("include_batch_in_index", &include_batch_in_index_)); - if (context->device_type() == DeviceType(DEVICE_GPU)) { - OP_REQUIRES(context, include_batch_in_index_ == false, - errors::Unimplemented( - "include_batch_in_index=true is not yet supported " - "on the GPU kernel.")); - } - TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false, &propagate_nans_)); } @@ -1052,7 +1045,7 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel { params.tensor_in_cols, params.depth}); Tensor* grad_out = nullptr; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {1}, 0, out_shape, &grad_out)); + {0}, 0, out_shape, &grad_out)); LaunchMaxPoolingGradWithArgmax::launch( context, params, grad_in, argmax, grad_out, include_batch_in_index_); @@ -1106,7 +1099,7 @@ class MaxPoolingGradGradWithArgmaxOp : public OpKernel { Tensor* grad_out = nullptr; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {1}, 0, out_shape, &grad_out)); + {0}, 0, out_shape, &grad_out)); LaunchMaxPoolingGradGradWithArgmax::launch( context, params, grad_in, argmax, grad_out, include_batch_in_index_); @@ -1313,7 +1306,7 @@ struct LaunchMaxPoolingNoMask { params.out_width, params.window_rows, params.window_cols, params.row_stride, params.col_stride, params.pad_rows, params.pad_cols, output->flat().data(), nullptr, context->eigen_gpu_device(), - propagate_nans); + propagate_nans, false); if (!status) { context->SetStatus( errors::Internal("Failed launching MaxPoolForwardNoMask")); @@ -1326,10 +1319,6 @@ struct LaunchMaxPoolingWithArgmax { static void launch(OpKernelContext* context, const PoolParameters& params, const Tensor& input, Tensor* output, Tensor* argmax, bool propagate_nans, bool include_batch_in_index) { - OP_REQUIRES(context, include_batch_in_index == false, - errors::Unimplemented( - "include_batch_in_index=true is not yet supported " - "on the GPU kernel.")); bool status = functor::MaxPoolForwardWithOptionalArgmax()( input.flat().data(), params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols, params.depth, params.out_height, @@ -1337,7 +1326,7 @@ struct LaunchMaxPoolingWithArgmax { params.row_stride, params.col_stride, params.pad_rows, params.pad_cols, output->flat().data(), reinterpret_cast(argmax->flat().data()), - context->eigen_gpu_device(), propagate_nans); + context->eigen_gpu_device(), propagate_nans, include_batch_in_index); if (!status) { context->SetStatus( errors::Internal("Failed launching MaxPoolForwardWithArgmax")); @@ -1350,10 +1339,6 @@ struct LaunchMaxPoolingGradWithArgmax { static void launch(OpKernelContext* context, const PoolParameters& params, const Tensor& grad_in, const Tensor& argmax, Tensor* grad_out, const bool include_batch_in_index) { - OP_REQUIRES(context, include_batch_in_index == false, - errors::Unimplemented( - "include_batch_in_index=true is not yet supported " - "on the GPU kernel.")); const int input_size = params.tensor_in_batch * params.tensor_in_rows * params.tensor_in_cols * params.depth; const int output_size = params.tensor_in_batch * params.out_height * @@ 
-1364,7 +1349,8 @@ struct LaunchMaxPoolingGradWithArgmax { bool status = functor::MaxPoolBackwardWithArgmax()( output_size, input_size, grad_in.flat().data(), reinterpret_cast(argmax.flat().data()), top_offset, - bottom_offset, grad_out->flat().data(), context->eigen_gpu_device()); + bottom_offset, grad_out->flat().data(), context->eigen_gpu_device(), + include_batch_in_index); if (!status) { context->SetStatus( errors::Internal("Failed launching MaxPoolBackwardWithArgmax")); @@ -1377,10 +1363,6 @@ struct LaunchMaxPoolingGradGradWithArgmax { static void launch(OpKernelContext* context, const PoolParameters& params, const Tensor& grad_in, const Tensor& argmax, Tensor* grad_out, const bool include_batch_in_index) { - OP_REQUIRES(context, include_batch_in_index == false, - errors::Unimplemented( - "include_batch_in_index=true is not yet supported " - "on the GPU kernel.")); const int input_size = params.tensor_in_batch * params.tensor_in_rows * params.tensor_in_cols * params.depth; const int output_size = params.tensor_in_batch * params.out_height * @@ -1392,7 +1374,8 @@ struct LaunchMaxPoolingGradGradWithArgmax { bool status = functor::MaxPoolGradBackwardWithArgmax()( output_size, input_size, grad_in.flat().data(), reinterpret_cast(argmax.flat().data()), top_offset, - bottom_offset, grad_out->flat().data(), context->eigen_gpu_device()); + bottom_offset, grad_out->flat().data(), context->eigen_gpu_device(), + include_batch_in_index); if (!status) { context->SetStatus( errors::Internal("Failed launching MaxPoolGradBackwardWithArgmax")); @@ -1473,32 +1456,32 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS); // default Eigen implementation so we are using the custom kernel as the // default. However, you can explicitly invoke the eigen version using // kernel_label_map. 
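The comment above refers to `kernel_label_map`: the macro that follows keeps the custom GPU MaxPool kernels as the default and registers the Eigen-based variants under the label `eigen_tensor`. For orientation only, kernel selection keys off a node's `_kernel` attribute, which must match the string passed to `Label()` in `REGISTER_KERNEL_BUILDER`. A minimal sketch of opting a node into the labelled kernel (the node name and attribute values here are illustrative, not part of this change):

```cpp
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"

// Build a MaxPool NodeDef that selects the Eigen-labelled GPU kernel.
tensorflow::NodeDef MakeEigenMaxPoolNode() {
  tensorflow::NodeDef node;
  TF_CHECK_OK(tensorflow::NodeDefBuilder("pool" /*name*/, "MaxPool" /*op*/)
                  .Input("input", 0, tensorflow::DT_FLOAT)
                  .Attr("ksize", {1, 2, 2, 1})
                  .Attr("strides", {1, 2, 2, 1})
                  .Attr("padding", "VALID")
                  // Matches the .Label("eigen_tensor") registration below.
                  .Attr("_kernel", "eigen_tensor")
                  .Finalize(&node));
  return node;
}
```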
-#define REGISTER_GPU_ONLY_POOL_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("MaxPool") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .Label("eigen_tensor"), \ - MaxPoolingOp); \ - REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \ - .Device(DEVICE_GPU) \ - .HostMemory("ksize") \ - .HostMemory("strides") \ - .TypeConstraint("T") \ - .Label("eigen_tensor"), \ - MaxPoolingV2Op); \ - REGISTER_KERNEL_BUILDER( \ - Name("MaxPool").Device(DEVICE_GPU).TypeConstraint("T"), \ - MaxPoolingNoMaskOp); \ - REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \ - .Device(DEVICE_GPU) \ - .HostMemory("ksize") \ - .HostMemory("strides") \ - .TypeConstraint("T"), \ - MaxPoolingNoMaskV2Op); \ - REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Targmax"), \ +#define REGISTER_GPU_ONLY_POOL_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("MaxPool") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .Label("eigen_tensor"), \ + MaxPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("ksize") \ + .HostMemory("strides") \ + .TypeConstraint("T") \ + .Label("eigen_tensor"), \ + MaxPoolingV2Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("MaxPool").Device(DEVICE_GPU).TypeConstraint("T"), \ + MaxPoolingNoMaskOp); \ + REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("ksize") \ + .HostMemory("strides") \ + .TypeConstraint("T"), \ + MaxPoolingNoMaskV2Op); \ + REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Targmax"), \ MaxPoolingGradGradWithArgmaxOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS); diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc index f28811ffa4d..fec6f2ebd85 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -18,6 +18,7 @@ limitations under the License. #define EIGEN_USE_GPU #include + #include #include "tensorflow/core/framework/register_types.h" @@ -25,7 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/kernels/maxpooling_op.h" #include "tensorflow/core/kernels/maxpooling_op_gpu.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace { @@ -54,6 +55,8 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) { // int form, keeping track of the flattened index of the input item that // produces the max output. If a nullptr is passed in for mask, no mask // will be produced. +// include_batch_in_index: whether to include batch dimension in flattened +// index of `argmax`. 
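// For example, for an NCHW input of shape [N, C, H, W], the kernels below
// store for a maximum found in image n:
//   mask = c * H * W + h * W + w                    // include_batch_in_index == false
//   mask = n * C * H * W + c * H * W + h * W + w    // include_batch_in_index == true
// i.e. the only difference is the per-image offset n * C * H * W (the NHWC
// kernel adds the analogous n * H * W * C offset).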
// // To call the forward and backward functions, use e.g.: // const int kThreadsPerBlock = 1024 @@ -61,14 +64,12 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) { // MaxPoolForwardNCHW<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, // kThreadsPerBlock, 0, cuda_stream>>>(...); template -__global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data, - const int channels, const int height, - const int width, const int pooled_height, - const int pooled_width, const int kernel_h, - const int kernel_w, const int stride_h, - const int stride_w, const int pad_t, - const int pad_l, dtype* top_data, - int64* mask) { +__global__ void MaxPoolForwardNCHW( + const int nthreads, const dtype* bottom_data, const int channels, + const int height, const int width, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + dtype* top_data, int64* mask, const bool include_batch_in_index) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int pw = index % pooled_width; int ph = (index / pooled_width) % pooled_height; @@ -82,12 +83,13 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data, wstart = max(wstart, 0); dtype maxval = Eigen::NumTraits::lowest(); int maxidx = -1; - const dtype* bottom_data_n = bottom_data + n * channels * height * width; + const int offset = n * channels * height * width; + const dtype* bottom_data_n = bottom_data + offset; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { int idx = c * height * width + h * width + w; if (IsGreaterThan(bottom_data_n[idx], maxval)) { - maxidx = idx; + maxidx = include_batch_in_index ? idx + offset : idx; maxval = bottom_data_n[idx]; } } @@ -136,14 +138,12 @@ __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C( } template -__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data, - const int height, const int width, - const int channels, const int pooled_height, - const int pooled_width, const int kernel_h, - const int kernel_w, const int stride_h, - const int stride_w, const int pad_t, - const int pad_l, dtype* top_data, - int64* mask) { +__global__ void MaxPoolForwardNHWC( + const int nthreads, const dtype* bottom_data, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + dtype* top_data, int64* mask, const bool include_batch_in_index) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int n = index; int c = n % channels; @@ -158,12 +158,13 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data, wstart = max(wstart, 0); dtype maxval = Eigen::NumTraits::lowest(); int maxidx = -1; - const dtype* bottom_data_n = bottom_data + n * height * width * channels; + const int offset = n * height * width * channels; + const dtype* bottom_data_n = bottom_data + offset; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { int idx = (h * width + w) * channels + c; if (IsGreaterThan(bottom_data_n[idx], maxval)) { - maxidx = idx; + maxidx = include_batch_in_index ? idx + offset : idx; maxval = bottom_data_n[idx]; } } @@ -231,17 +232,20 @@ __global__ void MaxPoolBackwardNoMaskNHWC( // bottom_offset: the pre-computed per-image offset of the maxpool input. // This is equal to H*W*C. 
// bottom_diff: the gradient with respect to the input. +// include_batch_in_index: whether to include batch dimension in flattened +// index of `argmax`. // This function relies on CudaAtomicAdd to avoid race conditions. Also, before // the kernel is run, you will need to make sure that bottom_diff is filled with // zero first. template __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff, const int64* mask, const int top_offset, - const int bottom_offset, dtype* bottom_diff) { + const int bottom_offset, dtype* bottom_diff, + const bool include_batch_in_index) { CUDA_1D_KERNEL_LOOP(index, nthreads) { - int image_id = (index / top_offset); - CudaAtomicAdd(bottom_diff + image_id * bottom_offset + mask[index], - top_diff[index]); + const int offset = + include_batch_in_index ? 0 : (index / top_offset) * bottom_offset; + CudaAtomicAdd(bottom_diff + offset + mask[index], top_diff[index]); } } @@ -358,14 +362,17 @@ __global__ void MaxPoolGradBackwardNoMaskNHWC( // bottom_offset: the pre-computed per-image offset of the maxpool output. // This is equal to Hout*Wout*C. // bottom_diff: the gradient of the gradient w.r.t. output. +// include_batch_in_index: whether to include batch dimension in flattened +// index of `argmax`. template __global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff, const int64* mask, const int top_offset, - const int bottom_offset, - dtype* bottom_diff) { + const int bottom_offset, dtype* bottom_diff, + const bool include_batch_in_index) { CUDA_1D_KERNEL_LOOP(index, nthreads) { - int image_id = (index / bottom_offset); - bottom_diff[index] = top_diff[image_id * top_offset + mask[index]]; + const int offset = + include_batch_in_index ? 0 : (index / bottom_offset) * top_offset; + bottom_diff[index] = top_diff[offset + mask[index]]; } } @@ -385,11 +392,12 @@ bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()( const int kThreadsPerBlock = 1024; const int output_size = batch * channels * pooled_height * pooled_width; if (output_size == 0) return true; - MaxPoolForwardNoMaskKernel_NCHW_VECT_C<<< + TF_CHECK_OK(CudaLaunchKernel( + MaxPoolForwardNoMaskKernel_NCHW_VECT_C, (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, - 0, d.stream()>>>(output_size, bottom_data, height, width, channels, - pooled_height, pooled_width, kernel_h, kernel_w, - stride_h, stride_w, pad_t, pad_l, top_data); + 0, d.stream(), output_size, bottom_data, height, width, channels, + pooled_height, pooled_width, kernel_h, kernel_w, stride_h, stride_w, + pad_t, pad_l, top_data)); return d.ok(); } @@ -399,24 +407,27 @@ bool MaxPoolForwardWithOptionalArgmax::operator()( const int channels, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_t, const int pad_l, T* top_data, - int64* mask, const Eigen::GpuDevice& d, bool propagate_nans) { + int64* mask, const Eigen::GpuDevice& d, bool propagate_nans, + const bool include_batch_in_index) { const int kThreadsPerBlock = 1024; const int output_size = batch * channels * pooled_height * pooled_width; if (output_size == 0) return true; if (propagate_nans) { - MaxPoolForwardNHWC - <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, - kThreadsPerBlock, 0, d.stream()>>>( - output_size, bottom_data, height, width, channels, pooled_height, - pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, - top_data, mask); + TF_CHECK_OK(CudaLaunchKernel( + MaxPoolForwardNHWC, + (output_size + kThreadsPerBlock - 
1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream(), output_size, bottom_data, height, + width, channels, pooled_height, pooled_width, kernel_h, kernel_w, + stride_h, stride_w, pad_t, pad_l, top_data, mask, + include_batch_in_index)); } else { - MaxPoolForwardNHWC - <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, - kThreadsPerBlock, 0, d.stream()>>>( - output_size, bottom_data, height, width, channels, pooled_height, - pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, - top_data, mask); + TF_CHECK_OK(CudaLaunchKernel( + MaxPoolForwardNHWC, + (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream(), output_size, bottom_data, height, + width, channels, pooled_height, pooled_width, kernel_h, kernel_w, + stride_h, stride_w, pad_t, pad_l, top_data, mask, + include_batch_in_index)); } return d.ok(); } @@ -432,16 +443,17 @@ bool MaxPoolBackwardNoMask::operator()( const int bottom_size = batch * channels * height * width; if (bottom_size == 0) return true; - SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock, - kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff); + TF_CHECK_OK(CudaLaunchKernel( + SetZero, (bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream(), bottom_size, bottom_diff)); const int top_size = batch * channels * pooled_height * pooled_width; - MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) / - kThreadsPerBlock, - kThreadsPerBlock, 0, d.stream()>>>( - top_size, bottom_data, height, width, channels, pooled_height, + TF_CHECK_OK(CudaLaunchKernel( + MaxPoolBackwardNoMaskNHWC, + (top_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, + d.stream(), top_size, bottom_data, height, width, channels, pooled_height, pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, - top_diff, bottom_diff); + top_diff, bottom_diff)); return d.ok(); } @@ -449,14 +461,18 @@ template bool MaxPoolBackwardWithArgmax::operator()( const int output_size, const int input_size, const T* top_diff, const int64* mask, const int top_offset, const int bottom_offset, - T* bottom_diff, const Eigen::GpuDevice& d) { + T* bottom_diff, const Eigen::GpuDevice& d, + const bool include_batch_in_index) { const int kThreadsPerBlock = 1024; if (input_size == 0) return true; - SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock, - kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff); - MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, - kThreadsPerBlock, 0, d.stream()>>>( - output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff); + TF_CHECK_OK(CudaLaunchKernel( + SetZero, (input_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream(), input_size, bottom_diff)); + TF_CHECK_OK(CudaLaunchKernel( + MaxPoolBackward, + (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, + 0, d.stream(), output_size, top_diff, mask, top_offset, bottom_offset, + bottom_diff, include_batch_in_index)); return d.ok(); } @@ -470,20 +486,22 @@ bool MaxPoolGradBackwardNoMask::operator()( const Eigen::GpuDevice& d) { const int num_kernels = batch * channels * pooled_height * pooled_width; if (num_kernels == 0) return true; - CudaLaunchConfig config = GetCudaLaunchConfig(num_kernels, d); + GpuLaunchConfig config = GetCudaLaunchConfig(num_kernels, d); if (data_format == FORMAT_NHWC) { - MaxPoolGradBackwardNoMaskNHWC<<>>( - num_kernels, bottom_data, output_data, pooled_height, 
pooled_width, - channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_t, - pad_l, top_diff, bottom_diff); + TF_CHECK_OK( + CudaLaunchKernel(MaxPoolGradBackwardNoMaskNHWC, config.block_count, + config.thread_per_block, 0, d.stream(), num_kernels, + bottom_data, output_data, pooled_height, pooled_width, + channels, height, width, kernel_h, kernel_w, stride_h, + stride_w, pad_t, pad_l, top_diff, bottom_diff)); } else { - MaxPoolGradBackwardNoMaskNCHW<<>>( - num_kernels, bottom_data, output_data, pooled_height, pooled_width, - channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_t, - pad_l, top_diff, bottom_diff); + TF_CHECK_OK( + CudaLaunchKernel(MaxPoolGradBackwardNoMaskNCHW, config.block_count, + config.thread_per_block, 0, d.stream(), num_kernels, + bottom_data, output_data, pooled_height, pooled_width, + channels, height, width, kernel_h, kernel_w, stride_h, + stride_w, pad_t, pad_l, top_diff, bottom_diff)); } return d.ok(); } @@ -492,12 +510,14 @@ template bool MaxPoolGradBackwardWithArgmax::operator()( const int output_size, const int input_size, const T* top_diff, const int64* mask, const int top_offset, const int bottom_offset, - T* bottom_diff, const Eigen::GpuDevice& d) { + T* bottom_diff, const Eigen::GpuDevice& d, + const bool include_batch_in_index) { if (input_size == 0) return true; - CudaLaunchConfig config = GetCudaLaunchConfig(output_size, d); - MaxPoolGradBackward<<>>(output_size, top_diff, mask, top_offset, - bottom_offset, bottom_diff); + GpuLaunchConfig config = GetCudaLaunchConfig(output_size, d); + TF_CHECK_OK(CudaLaunchKernel( + MaxPoolGradBackward, config.block_count, config.thread_per_block, 0, + d.stream(), output_size, top_diff, mask, top_offset, bottom_offset, + bottom_diff, include_batch_in_index)); return d.ok(); } diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h index 38ebb342480..5383833b318 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.h +++ b/tensorflow/core/kernels/maxpooling_op_gpu.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#if !GOOGLE_CUDA -#error This file must only be included when building with Cuda support +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda or ROCm support #endif #ifndef TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_ @@ -39,7 +39,8 @@ struct MaxPoolForwardWithOptionalArgmax { const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_t, const int pad_l, T* top_data, int64* mask, - const Eigen::GpuDevice& d, bool propagate_nans); + const Eigen::GpuDevice& d, bool propagate_nans, + const bool include_batch_in_index); }; struct MaxPoolForwardNoMask_NCHW_VECT_C { @@ -56,7 +57,7 @@ struct MaxPoolBackwardWithArgmax { bool operator()(const int output_size, const int input_size, const T* top_diff, const int64* mask, const int top_offset, const int bottom_offset, T* bottom_diff, - const Eigen::GpuDevice& d); + const Eigen::GpuDevice& d, const bool include_batch_in_index); }; template @@ -74,7 +75,7 @@ struct MaxPoolGradBackwardWithArgmax { bool operator()(const int output_size, const int input_size, const T* top_diff, const int64* mask, const int top_offset, const int bottom_offset, T* bottom_diff, - const Eigen::GpuDevice& d); + const Eigen::GpuDevice& d, const bool include_batch_in_index); }; template diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc index 8eb334f2b49..566ab79fb0d 100644 --- a/tensorflow/core/kernels/mkl_aggregate_ops.cc +++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc @@ -38,202 +38,192 @@ class MklAddNOp : public OpKernel { ~MklAddNOp() {} explicit MklAddNOp(OpKernelConstruction* context) : OpKernel(context) {} - void Compute(OpKernelContext* ctx) override { - const int num = ctx->num_inputs(); - // Only additions of 2 input tensors is supported now - OP_REQUIRES(ctx, num / 2 == 2, - errors::InvalidArgument("Only additions of two tensors " - "supported by MKL. Num inputs: ", - num)); + TensorShape GetTensorShape(OpKernelContext* ctx, size_t src_index) { + const Tensor& src_tensor = MklGetInput(ctx, src_index); + MklDnnShape src_mkl_shape; + GetMklShape(ctx, src_index, &src_mkl_shape); + return src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape() + : src_tensor.shape(); + } - try { - auto cpu_engine = engine(engine::cpu, 0); - size_t src1_idx = 0, src2_idx = 1, output_idx = 0; - const Tensor& src1_tensor = MklGetInput(ctx, src1_idx); - const Tensor& src2_tensor = MklGetInput(ctx, src2_idx); + bool CheckInputShape(OpKernelContext* ctx) { + const int num_inputs = ctx->num_inputs() / 2; + const TensorShape src0_shape = GetTensorShape(ctx, 0); - MklDnnShape src1_mkl_shape, src2_mkl_shape; - GetMklShape(ctx, src1_idx, &src1_mkl_shape); - GetMklShape(ctx, src2_idx, &src2_mkl_shape); - bool input1_in_mkl_format = src1_mkl_shape.IsMklTensor(); - bool input2_in_mkl_format = src2_mkl_shape.IsMklTensor(); - int src1_dims_size = input1_in_mkl_format ? src1_mkl_shape.GetDimension() - : src1_tensor.dims(); - int src2_dims_size = input2_in_mkl_format ? src2_mkl_shape.GetDimension() - : src2_tensor.dims(); - // if the shapes of two tensors are not same raise op error - TensorShape src1_shape, src2_shape; - src1_shape = input1_in_mkl_format ? src1_mkl_shape.GetTfShape() - : src1_tensor.shape(); - src2_shape = input2_in_mkl_format ? 
src2_mkl_shape.GetTfShape() - : src2_tensor.shape(); - - if (!src1_shape.IsSameSize(src2_shape)) { + for (size_t i = 1; i < num_inputs; ++i) { + if (!src0_shape.IsSameSize(GetTensorShape(ctx, i))) { ctx->SetStatus(errors::InvalidArgument( "Inputs to operation ", this->name(), " of type ", this->type_string(), " must have the same size and shape. Input 0: ", - src1_shape.DebugString(), - " != input 1: ", src2_shape.DebugString())); - } + src0_shape.DebugString(), " != input : ", i, + GetTensorShape(ctx, i).DebugString())); - if (!input1_in_mkl_format && src1_dims_size == 0) { - Tensor* dst_tensor = nullptr; - MklDnnShape mkl_shape_dst; - mkl_shape_dst.SetMklTensor(false); - AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, - src1_tensor.shape(), mkl_shape_dst); - float user_i1 = (src1_tensor.scalar()()); - float user_i2 = (src2_tensor.scalar()()); - dst_tensor->scalar()() = std::plus{}(user_i1, user_i2); + return false; + } + } + + return true; + } + + // Return first tensor index which is in MKL layout, or -1 with no MKL input. + int FindMKLInputIndex(OpKernelContext* ctx) { + int mkl_index = -1; + const int num_inputs = ctx->num_inputs() / 2; + + MklDnnShape src_mkl_shape; + for (size_t i = 0; i < num_inputs; ++i) { + GetMklShape(ctx, i, &src_mkl_shape); + if (src_mkl_shape.IsMklTensor()) { + mkl_index = i; + break; + } + } + + return mkl_index; + } + + void ComputeScalar(OpKernelContext* ctx) { + const int num_inputs = ctx->num_inputs() / 2; + const size_t kOutputIdx = 0; + TensorShape output_tf_shape; + MklDnnShape output_mkl_shape; + Tensor* dst_tensor = nullptr; + + T sum = static_cast(0); + for (int src_idx = 0; src_idx < num_inputs; ++src_idx) { + const Tensor& src_tensor = MklGetInput(ctx, src_idx); + T* src_i = const_cast(src_tensor.flat().data()); + sum += src_i[0]; + } + + output_mkl_shape.SetMklTensor(false); + output_tf_shape = MklGetInput(ctx, kOutputIdx).shape(); + AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape, + output_mkl_shape); + + T* out_o = dst_tensor->flat().data(); + out_o[0] = sum; + } + + void Compute(OpKernelContext* ctx) override { + // Each input tensor in MKL layout has additional meta-tensor carrying + // layout information. So the number of actual tensors is half the total + // number of inputs. + const int num_inputs = ctx->num_inputs() / 2; + + MklDnnShape mkl_shape; + const size_t kSrc0Idx = 0; + const size_t kOutputIdx = 0; + + if (num_inputs == 1) { + GetMklShape(ctx, kSrc0Idx, &mkl_shape); + bool input_in_mkl_format = mkl_shape.IsMklTensor(); + + if (input_in_mkl_format) { + ForwardMklTensorInToOut(ctx, kSrc0Idx, kOutputIdx); + } else { + ForwardTfTensorInToOut(ctx, kSrc0Idx, kOutputIdx); + } + return; + } + + // Check if the input shape is same + if (!CheckInputShape(ctx)) return; + + try { + TensorShape output_tf_shape; + MklDnnShape output_mkl_shape; + const Tensor& src_tensor = MklGetInput(ctx, kSrc0Idx); + + Tensor* dst_tensor = nullptr; + + // Nothing to compute, return. + if (src_tensor.shape().num_elements() == 0) { + output_mkl_shape.SetMklTensor(false); + output_tf_shape = src_tensor.shape(); + AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape, + output_mkl_shape); return; } - // If there is nothing to compute, return. 
- if (!input1_in_mkl_format && !input2_in_mkl_format) { - if (src1_tensor.shape().num_elements() == 0) { - Tensor* dst_tensor = nullptr; - MklDnnShape mkl_shape_dst; - mkl_shape_dst.SetMklTensor(false); - AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, - src1_tensor.shape(), mkl_shape_dst); - return; - } - } - - std::vector coeff(2, 1.0); - MklDnnData src1(&cpu_engine); - MklDnnData src2(&cpu_engine); - MklDnnData dst(&cpu_engine); - - int tmp_size = input1_in_mkl_format ? src2_dims_size : src1_dims_size; - memory::dims dims(tmp_size); - memory::dims strides(tmp_size); - memory::desc md1({}, memory::data_undef, memory::format_undef); - memory::desc md2({}, memory::data_undef, memory::format_undef); - - // For creating Sum primitive, we need to ensure that all inputs are in - // same format. What that means is if we have a mixed input case - where - // one input is in Tensorflow format and one input is in MKL format -, - // then we need to ensure that all inputs are in same format for - // primitive construction. For performance reason, we say that all inputs - // are in MKL format in such case, and insert reorder for input that is - // in Tensorflow format into MKL format. On the other hand, if both the - // inputs are in MKL format or both are in Tensorflow format, then we - // dont need reorder. - if (!input1_in_mkl_format && !input2_in_mkl_format) { - // If both the inputs are in Tensorflow format, we create blocked memory - // descriptor. - dims = TFShapeToMklDnnDims(src1_tensor.shape()); - strides = CalculateTFStrides(dims); - md1 = MklDnnData::CreateBlockedMemDesc(dims, strides); - md2 = md1; - } else if (input1_in_mkl_format && !input2_in_mkl_format) { - // If one input is in MKL format and other is in Tensorflow, then - // create respective descriptors describing the actual case. For input - // in Mkl format, we just get Mkl layout from MklDnnShape. For input in - // Tensorflow format, we create memory descriptor using data format. - md1 = src1_mkl_shape.GetMklLayout(); - - memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat(); - auto src1_tf_data_format = - MklDnnDataFormatToTFDataFormat(src1_mkl_data_format); - memory::dims src2_dims; - if (src2_tensor.dims() == 4) { - src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), - src1_tf_data_format); - } else { - src2_dims = TFShapeToMklDnnDimsInNCDHW(src2_tensor.shape(), - src1_tf_data_format); - } - md2 = memory::desc(src2_dims, MklDnnType(), src1_mkl_data_format); - } else if (input2_in_mkl_format && !input1_in_mkl_format) { - // Same comment as above. - memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat(); - auto src2_tf_data_format = - MklDnnDataFormatToTFDataFormat(src2_mkl_data_format); - memory::dims src1_dims; - if (src1_tensor.dims() == 4) { - src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), - src2_tf_data_format); - } else { - src1_dims = TFShapeToMklDnnDimsInNCDHW(src1_tensor.shape(), - src2_tf_data_format); - } - md1 = memory::desc(src1_dims, MklDnnType(), src2_mkl_data_format); - - md2 = src2_mkl_shape.GetMklLayout(); - } else { - // If both the inputs are in MKL format, we use Mkl layout of the input - // tensors. - md1 = src1_mkl_shape.GetMklLayout(); - md2 = src2_mkl_shape.GetMklLayout(); - } - src1.SetUsrMem(md1, &src1_tensor); - src2.SetUsrMem(md2, &src2_tensor); - - // As per comment above, we tell MKLDNN that both the inputs are in same - // format. So we set common memory descriptor in MKL format, if any of the - // inputs are in MKL format. 
Let's get memory descriptor that we will use - // for both the inputs. - // We set output memory descriptor in MKL format, if any of the - // inputs are in MKL format. - memory::desc common_md({}, memory::data_undef, memory::format_undef); - if (input1_in_mkl_format || input2_in_mkl_format) { - common_md = input1_in_mkl_format ? md1 : md2; - dst.SetUsrMem(common_md); - } else { - // Since both the inputs are in Tensorflow format, and have - // same shape, we can get memory descriptor from any input. - common_md = md1; - dst.SetUsrMem(common_md); + if (src_tensor.dims() == 0) { + ComputeScalar(ctx); + return; } + auto cpu_engine = engine(engine::cpu, 0); + std::vector coeff(num_inputs, 1.0); std::vector srcs_pd; - // Memory descriptor for 1st input - srcs_pd.push_back(memory::primitive_desc(common_md, cpu_engine)); - // Memory descriptor for 2nd input - srcs_pd.push_back(memory::primitive_desc(common_md, cpu_engine)); - auto sum_pd = sum::primitive_desc(dst.GetUsrMemDesc(), coeff, srcs_pd); - - // Now we setup resources for primitive execution. - // First, we need to check if any of the inputs need to be reordered as - // per the logic described above. Since output will be in MKL format if - // atleast one input is in MKL format, we choose output descriptor for - // reorder. std::vector inputs; - // Check if actual input format of the tensor is different than common_pd - // we told MKLDNN. In that case, we will need reorder. - src1.CheckReorderToOpMem(srcs_pd[0]); - src2.CheckReorderToOpMem(srcs_pd[1]); - inputs.push_back(src1.GetOpMem()); - inputs.push_back(src2.GetOpMem()); - // Allocate output tensor now. - Tensor* dst_tensor = nullptr; - MklDnnShape output_mkl_shape; - TensorShape output_tf_shape; + MklDnnData dst(&cpu_engine); + MklDnnData src(&cpu_engine); + bool has_mkl_input = false; + int mkl_input_index = FindMKLInputIndex(ctx); + memory::format mkl_data_format; + TensorFormat tf_data_format; + if (mkl_input_index >= 0) { + has_mkl_input = true; + GetMklShape(ctx, mkl_input_index, &mkl_shape); + // MKL input has the data format information. + mkl_data_format = mkl_shape.GetTfDataFormat(); + tf_data_format = MklDnnDataFormatToTFDataFormat(mkl_data_format); + } - if (input2_in_mkl_format || input1_in_mkl_format) { - output_mkl_shape.SetMklTensor(true); - auto output_pd = dst.GetUsrMemPrimDesc(); + // Create memory descriptor for MKL-DNN. + // If all input in Tensorflow format, create block memory descriptor, + // else convet TF format to MKL memory descriptor + for (int src_idx = 0; src_idx < num_inputs; ++src_idx) { + MklDnnShape src_mkl_shape; + GetMklShape(ctx, src_idx, &src_mkl_shape); + memory::desc md({}, memory::data_undef, memory::format_undef); + src = MklDnnData(&cpu_engine); + const Tensor& src_tensor = MklGetInput(ctx, src_idx); + + if (src_mkl_shape.IsMklTensor()) { + md = src_mkl_shape.GetMklLayout(); + } else { + if (has_mkl_input) { + memory::dims src_dims; + if (src_tensor.dims() == 4) { + src_dims = + TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tf_data_format); + } else { + DCHECK(src_tensor.dims() == 5); + src_dims = TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(), + tf_data_format); + } + md = memory::desc(src_dims, MklDnnType(), mkl_data_format); + } else { + // Create block memory descriptor for TensorFlow format input. 
+ auto dims = TFShapeToMklDnnDims(src_tensor.shape()); + auto strides = CalculateTFStrides(dims); + md = MklDnnData::CreateBlockedMemDesc(dims, strides); + } + } + srcs_pd.push_back(memory::primitive_desc(md, cpu_engine)); + src.SetUsrMem(md, &src_tensor); + inputs.push_back(src.GetOpMem()); + } + + auto sum_pd = sum::primitive_desc(coeff, srcs_pd); + output_mkl_shape.SetMklTensor(has_mkl_input); + auto output_pd = sum_pd.dst_primitive_desc(); + dst.SetUsrMem(output_pd); + + if (has_mkl_input) { output_mkl_shape.SetMklLayout(&output_pd); output_mkl_shape.SetElemType(MklDnnType()); - if (input1_in_mkl_format) { - output_mkl_shape.SetTfLayout(src1_dims_size, - src1_mkl_shape.GetSizesAsMklDnnDims(), - src1_mkl_shape.GetTfDataFormat()); - } else { - output_mkl_shape.SetTfLayout(src2_dims_size, - src2_mkl_shape.GetSizesAsMklDnnDims(), - src2_mkl_shape.GetTfDataFormat()); - } + output_mkl_shape.SetTfLayout(mkl_shape.GetDimension(), + mkl_shape.GetSizesAsMklDnnDims(), + mkl_shape.GetTfDataFormat()); output_tf_shape.AddDim((output_pd.get_size() / sizeof(T))); } else { - output_mkl_shape.SetMklTensor(false); - output_tf_shape = src1_tensor.shape(); + // All inputs have TF shapes, get the shape from first one. + output_tf_shape = MklGetInput(ctx, kSrc0Idx).shape(); } - AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, output_tf_shape, + AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape, output_mkl_shape); dst.SetUsrMemDataHandle(dst_tensor); @@ -259,6 +249,7 @@ class MklAddNOp : public OpKernel { MklAddNOp); TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_bfloat16(REGISTER_MKL_CPU); #undef REGISTER_MKL_CPU } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index 28825e1a9c6..f13cfc1782f 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -16,16 +16,14 @@ #ifdef INTEL_MKL #define EIGEN_USE_THREADS +#include "mkldnn.hpp" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/mkl_pooling_ops_common.h" #include "tensorflow/core/util/mkl_util.h" -#include "tensorflow/core/kernels/mkl_pooling_ops_common.h" - -#ifndef INTEL_MKL_ML_ONLY -#include "mkldnn.hpp" using mkldnn::algorithm; using mkldnn::engine; using mkldnn::error; @@ -34,402 +32,11 @@ using mkldnn::padding_kind; using mkldnn::pooling_backward; using mkldnn::pooling_forward; using mkldnn::prop_kind; -#endif namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifdef INTEL_MKL_ML_ONLY - -template -class MklAvgPoolingOp : public OpKernel { - public: - explicit MklAvgPoolingOp(OpKernelConstruction* context) : OpKernel(context) { - string data_format; - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - - OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); - OP_REQUIRES(context, ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 4, - errors::InvalidArgument("Sliding window stride field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, 
context->GetAttr("padding", &padding_)); - OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, - errors::Unimplemented("Pooling is not yet supported on the " - "batch dimension.")); - } - - void Compute(OpKernelContext* context) override { - MklAvgPoolingOpContext mkl_context; - const Tensor& tensor_in = MklGetInput(context, 0); - GetMklShape(context, 0, &mkl_context.input_shape); - bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); - - if (!input_in_mkl_format) - mkl_context.params.in_dim = tensor_in.dims(); - else - mkl_context.params.in_dim = mkl_context.input_shape.GetDimension(); - - MklPoolParameters pool_params; - if (!input_in_mkl_format) { - pool_params.Init(context, ksize_, stride_, padding_, data_format_, - tensor_in.shape()); - } else { - pool_params.Init(context, ksize_, stride_, padding_, data_format_, - &mkl_context.input_shape); - } - - // Extract the parameters for the op from the pooling specs - ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); - - Tensor mkl_tmp_input_buf_tensor_; - mkl_context.MklCreateLayoutsAndPrimitives(context, - &mkl_tmp_input_buf_tensor_); - OP_REQUIRES_OK(context, context->status()); - - Tensor workspace_tensor; - void* workspace_buf; - AllocTmpBuffer(context, &workspace_tensor, mkl_context.lt_workspace, - &workspace_buf); - - if (mkl_context.convert_input != nullptr) { - if (input_in_mkl_format == false) { - CHECK_EQ( - dnnConversionExecute_F32( - mkl_context.convert_input, - static_cast(const_cast(tensor_in.flat().data())), - mkl_context.input_buf), - E_SUCCESS); - CHECK_EQ(dnnDelete_F32(mkl_context.convert_input), E_SUCCESS); - } else { - mkl_context.input_shape.GetConvertedFlatData( - mkl_context.lt_prim_input, - static_cast(const_cast(tensor_in.flat().data())), - mkl_context.input_buf); - } - mkl_context.pooling_res[dnnResourceSrc] = mkl_context.input_buf; - } else { - mkl_context.pooling_res[dnnResourceSrc] = - static_cast(const_cast(tensor_in.flat().data())); - } - - // Declare output tensor and allocate memory - Tensor* output = nullptr; - TensorShape tensor_out_shape; - MklShape mkl_out_shape; - mkl_out_shape.SetMklTensor(true); - mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst); - mkl_out_shape.SetTfLayout(mkl_context.params.in_dim, - mkl_context.params.out_sizes, - mkl_context.params.out_strides); - mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); - - tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast( - mkl_out_shape.GetMklLayout())) / - sizeof(T)); - - AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape, - mkl_out_shape); - mkl_context.pooling_res[dnnResourceDst] = - static_cast(output->flat().data()); - - mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf; - - CHECK_EQ( - dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res), - E_SUCCESS); - - mkl_context.MklCleanup(); - } // Compute - - private: - typedef struct { - MklPoolingOpParams params; - MklShape input_shape; - dnnPrimitive_t prim_pooling_fwd = nullptr, convert_input = nullptr; - dnnLayout_t lt_user_input = nullptr, lt_prim_input = nullptr, - lt_workspace = nullptr; - void* input_buf = nullptr; - void* pooling_res[dnnResourceNumber]; - - void MklCreateLayoutsAndPrimitives(OpKernelContext* context, - Tensor* mkl_tmp_input_buf_tensor) { - bool input_in_mkl_format = input_shape.IsMklTensor(); - - if (!input_in_mkl_format) { - CHECK_EQ(dnnLayoutCreate_F32(<_user_input, params.in_dim, - params.in_sizes, params.in_strides), - E_SUCCESS); - } 
else { - lt_user_input = (dnnLayout_t)input_shape.GetCurLayout(); - } - - dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg; - dnnPrimitiveAttributes_t primAttr = nullptr; - - // Create DNN primitives - CHECK_EQ(dnnPoolingCreateForward_F32( - &prim_pooling_fwd, primAttr, algorithm, lt_user_input, - params.kernel_size, params.kernel_stride, params.in_offset, - dnnBorderZerosAsymm), - E_SUCCESS); - - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( - <_prim_input, prim_pooling_fwd, dnnResourceSrc), - E_SUCCESS); - if (!dnnLayoutCompare_F32(lt_user_input, lt_prim_input)) { - CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_user_input, - lt_prim_input), - E_SUCCESS); - - AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_prim_input, - &input_buf); - } - - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, prim_pooling_fwd, - dnnResourceWorkspace), - E_SUCCESS); - } - - void MklCleanup() { - bool input_in_mkl_format = input_shape.IsMklTensor(); - if (!input_in_mkl_format) { - CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS); - } - - CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); - CHECK_EQ(dnnLayoutDelete_F32(lt_prim_input), E_SUCCESS); - } - } MklAvgPoolingOpContext; - - std::vector ksize_; - std::vector stride_; - Padding padding_; - TensorFormat data_format_; -}; - -//----------------------------------------------------------------------------- - -template -class MklAvgPoolingGradOp : public OpKernel { - public: - explicit MklAvgPoolingGradOp(OpKernelConstruction* context) - : OpKernel(context) { - string data_format; - - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); - OP_REQUIRES(context, ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, - errors::Unimplemented("Pooling is not yet supported on the " - "batch dimension.")); - } - - void Compute(OpKernelContext* context) override { - MklAvgPoolingGradOpContext mkl_context; - const Tensor& tensor_in_shape = MklGetInput(context, 0); - const Tensor& out_backprop = MklGetInput(context, 1); - GetMklShape(context, 1, &mkl_context.out_backprop_shape); - bool outbackprop_in_mkl_format = - mkl_context.out_backprop_shape.IsMklTensor(); - - TensorShape output_shape; - auto shape_vec = tensor_in_shape.vec(); - for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) { - output_shape.AddDim(shape_vec(i)); - } - - MklPoolParameters pool_params; - pool_params.Init(context, ksize_, stride_, padding_, data_format_, - output_shape); - - if (outbackprop_in_mkl_format == false) - mkl_context.params.in_dim = out_backprop.dims(); - else - mkl_context.params.in_dim = mkl_context.out_backprop_shape.GetDimension(); - - // Extract the parameters for the op from the pooling specs - ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); - - // Tensors needed to create temporary buffers - Tensor outbackprop_buf_tensor; - void* outbackprop_buf; - mkl_context.MklCreateLayoutsAndPrimitives(context); - OP_REQUIRES_OK(context, context->status()); - - // 
Check if outbackprop layout requires conversion. - if (!dnnLayoutCompare_F32(mkl_context.lt_user_outbackprop, - mkl_context.lt_prim_outbackprop)) { - CHECK_EQ(dnnConversionCreate_F32(&mkl_context.convert_outbackprop, - mkl_context.lt_user_outbackprop, - mkl_context.lt_prim_outbackprop), - E_SUCCESS); - - AllocTmpBuffer(context, &outbackprop_buf_tensor, - mkl_context.lt_prim_outbackprop, &outbackprop_buf); - - if (!outbackprop_in_mkl_format) { - CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_outbackprop, - static_cast(const_cast( - out_backprop.flat().data())), - outbackprop_buf), - E_SUCCESS); - CHECK_EQ(dnnDelete_F32(mkl_context.convert_outbackprop), E_SUCCESS); - } else { - mkl_context.out_backprop_shape.GetConvertedFlatData( - mkl_context.lt_prim_outbackprop, - static_cast(const_cast(out_backprop.flat().data())), - outbackprop_buf); - } - mkl_context.pooling_res[dnnResourceDiffDst] = outbackprop_buf; - } else { - mkl_context.pooling_res[dnnResourceDiffDst] = - static_cast(const_cast(out_backprop.flat().data())); - } - - // Handle workspace requirements. - Tensor workspace_buf_tensor; - void* workspace_buf; - AllocTmpBuffer(context, &workspace_buf_tensor, mkl_context.lt_workspace, - &workspace_buf); - mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf; - - // Handle MKL output tensor setup. - Tensor* output = nullptr; - TensorShape tensor_out_shape; - MklShape mkl_out_shape; - mkl_out_shape.SetMklTensor(true); - mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_bwd, - dnnResourceDiffSrc); - mkl_out_shape.SetTfLayout(mkl_context.params.in_dim, - mkl_context.params.in_sizes, - mkl_context.params.in_strides); - mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); - - tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast( - mkl_out_shape.GetMklLayout())) / - sizeof(T)); - - AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape, - mkl_out_shape); - - // Set output tensor. - mkl_context.pooling_res[dnnResourceDiffSrc] = - static_cast(output->flat().data()); - - // Execute primitive. - CHECK_EQ( - dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res), - E_SUCCESS); - - mkl_context.MklCleanup(); - } - - private: - typedef struct { - MklPoolingOpParams params; - MklShape out_backprop_shape; - dnnPrimitive_t prim_pooling_bwd = nullptr, convert_outbackprop = nullptr; - void* pooling_res[dnnResourceNumber]; - dnnLayout_t lt_user_input = nullptr, lt_user_outbackprop = nullptr, - lt_prim_outbackprop = nullptr, lt_workspace = nullptr; - - void MklCreateLayoutsAndPrimitives(OpKernelContext* context) { - const Tensor& tensor_in_shape = MklGetInput(context, 0); - const Tensor& out_backprop = MklGetInput(context, 1); - bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor(); - - if (!outbackprop_in_mkl_format) { - // For avgpooling, tensor_in_shape should have 1 dimension, and 4 - // elements. - OP_REQUIRES( - context, - tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4, - errors::InvalidArgument("original input shape must be " - "1-dimensional and 4 elements")); - - // For avgpooling, out_backprop should have 4 dimensions. - OP_REQUIRES( - context, out_backprop.dims() == 4, - errors::InvalidArgument("out_backprop must be 4-dimensional")); - } else { - // Input in MKL format. - // For avgpooling, out_backprop should have 4 dimensions. 
- OP_REQUIRES( - context, out_backprop_shape.GetDimension() == 4, - errors::InvalidArgument("out_backprop must be 4-dimensional")); - } - - // TODO(inteltf): Get outbackprop layout. - // Do we need to create layout in every invocation? - if (!outbackprop_in_mkl_format) { - CHECK_EQ(dnnLayoutCreate_F32(<_user_outbackprop, params.in_dim, - params.out_sizes, params.out_strides), - E_SUCCESS); - } else { - lt_user_outbackprop = (dnnLayout_t)out_backprop_shape.GetCurLayout(); - } - - // Create the backward primitive - // Create DNN user layout - CHECK_EQ(dnnLayoutCreate_F32(<_user_input, params.in_dim, - params.in_sizes, params.in_strides), - E_SUCCESS); - - // Create PoolingBackward primitive - dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg; - dnnPrimitiveAttributes_t primAttr = nullptr; - CHECK_EQ(dnnPoolingCreateBackward_F32( - &prim_pooling_bwd, primAttr, algorithm, lt_user_input, - params.kernel_size, params.kernel_stride, params.in_offset, - dnnBorderZerosAsymm), - E_SUCCESS); - - // Create expected outbackprop layout from the primitive. - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( - <_prim_outbackprop, prim_pooling_bwd, dnnResourceDiffDst), - E_SUCCESS); - - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, prim_pooling_bwd, - dnnResourceWorkspace), - E_SUCCESS); - } - - void MklCleanup() { - bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor(); - CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS); - CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS); - if (!outbackprop_in_mkl_format) { - CHECK_EQ(dnnLayoutDelete_F32(lt_user_outbackprop), E_SUCCESS); - } - CHECK_EQ(dnnLayoutDelete_F32(lt_prim_outbackprop), E_SUCCESS); - CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS); - } - } MklAvgPoolingGradOpContext; - - std::vector ksize_; - std::vector stride_; - Padding padding_; - TensorFormat data_format_; -}; // MklAvgPoolingGradOp - -#else - template class MklAvgPoolingOp : public MklPoolingForwardOpBase { public: @@ -701,25 +308,35 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { } }; // MklAvgPoolingGradOp -REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3D") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklAvgPoolingOp); +#define REGISTER_MKL_AVGPOOL3D_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklAvgPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklAvgPoolingGradOp); -REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3DGrad") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklAvgPoolingGradOp); +TF_CALL_float(REGISTER_MKL_AVGPOOL3D_KERNELS); +TF_CALL_bfloat16(REGISTER_MKL_AVGPOOL3D_KERNELS); -#endif // INTEL_MKL_ML_ONLY +#define REGISTER_MKL_AVGPOOL_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklAvgPool") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklAvgPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklAvgPoolingGradOp); -REGISTER_KERNEL_BUILDER(Name("_MklAvgPool") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklAvgPoolingOp); +TF_CALL_float(REGISTER_MKL_AVGPOOL_KERNELS); +TF_CALL_bfloat16(REGISTER_MKL_AVGPOOL_KERNELS); 
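For readers unfamiliar with the `TF_CALL_*` helpers used in the registrations above: each one simply invokes its macro argument once for a single data type, so the two calls expand `REGISTER_MKL_AVGPOOL_KERNELS` for `float` and `bfloat16`. A simplified sketch of the pattern with stand-in names (the real helpers live in `tensorflow/core/framework/register_types.h`; this demo uses `double` instead of `bfloat16` so it compiles standalone):

```cpp
// Stand-ins for the TF_CALL_<type>(m) convention: expand the macro argument
// for one concrete type each.
#define DEMO_CALL_float(m) m(float)
#define DEMO_CALL_double(m) m(double)

template <typename T>
class MklAvgPoolDemo {};  // hypothetical per-type kernel class

// Analogous to REGISTER_MKL_AVGPOOL_KERNELS(T): one registration body per type.
#define DEMO_REGISTER(T) template class MklAvgPoolDemo<T>;

DEMO_CALL_float(DEMO_REGISTER);   // expands to: template class MklAvgPoolDemo<float>;
DEMO_CALL_double(DEMO_REGISTER);  // expands to: template class MklAvgPoolDemo<double>;
```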
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedAvgPool") .Device(DEVICE_CPU) @@ -733,11 +350,5 @@ REGISTER_KERNEL_BUILDER(Name("_MklQuantizedAvgPool") .Label(mkl_op_registry::kMklQuantizedOpLabel), MklAvgPoolingOp); -REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklAvgPoolingGradOp); - } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc index bc135de11e0..00ba430560b 100644 --- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/batch_matmul_op_impl.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -218,10 +219,13 @@ class BatchMatMulMkl : public OpKernel { } }; -#define REGISTER_BATCH_MATMUL_MKL(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ - BatchMatMulMkl) +#define REGISTER_BATCH_MATMUL_MKL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ + BatchMatMulMkl) \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint("T"), \ + BatchMatMulV2Op) #ifdef ENABLE_MKL TF_CALL_float(REGISTER_BATCH_MATMUL_MKL); diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index d8fbb83940a..adabbb534ed 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -47,6 +47,45 @@ enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM }; // -------------------------------------------------------------------------- // Eigen Concat Op // -------------------------------------------------------------------------- +namespace { +template +struct RequantizeCopier { + RequantizeCopier( + const std::vector>* input_min_and_max, + float output_min, float output_max) + : output_min(output_min), output_max(output_max) { + DCHECK(input_min_and_max); + this->input_min_and_max = input_min_and_max; + } + + inline void Copy(T* dst, const T* src, int input_index, size_t n) { + const float input_min = (*input_min_and_max)[input_index].first; + const float input_max = (*input_min_and_max)[input_index].second; + if (input_min == output_min && input_max == output_max) { + DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum::v())); + memcpy(dst, src, n * sizeof(T)); + } else { + Eigen::array dims; + dims[0] = n; + typename TTypes::UnalignedConstTensor input_array(src, dims); + typename TTypes::UnalignedTensor output_array(dst, dims); + + QuantizedToFloatStruct q2f(input_min, input_max); + auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f); + FloatToQuantizedStruct f2q(output_min, output_max); + // RequantizeCopier::Copy is called from within a shard of computation, so + // don't use the threadpool device here, simply assign with default CPU + // device. 
+ output_array = QUANTIZE_WITH_EIGEN(input_float, f2q, T); + } + } + + float output_min; + float output_max; + const std::vector>* input_min_and_max; +}; +} // namespace + template class EigenConcatBaseOp : public OpKernel { public: @@ -55,12 +94,44 @@ class EigenConcatBaseOp : public OpKernel { explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {} + void CalculateInputAndOutputRange( + const OpInputList& input_mins, const OpInputList& input_maxes, + const size_t N, + std::vector>* input_mins_and_maxes, + float* output_min, float* output_max) { + input_mins_and_maxes->reserve(N); + float overall_min = std::numeric_limits::max(); + float overall_max = std::numeric_limits::lowest(); + for (int i = 0; i < N; ++i) { + const float input_min = input_mins[i].flat()(0); + const float input_max = input_maxes[i].flat()(0); + input_mins_and_maxes->emplace_back(input_min, input_max); + overall_min = std::min(overall_min, input_min); + overall_max = std::max(overall_max, input_max); + } + if (std::is_signed::value) { + // For signed, we want a symmetrical distribution including zero for the + // output, so pick a range that meets that need. + const float largest_value = + std::max(std::abs(overall_min), std::abs(overall_max)); + *output_min = -largest_value; + *output_max = largest_value; + } else { + // For MKL quantization, we only support scaled mode, so the range is + // [0, m] for unsigned data where m is the range maximum + *output_min = 0.0f; + *output_max = overall_max; + } + } + // Although, we modify Compute for this call to accept one extra param, // we need to have empty Compute because Compute is pure virtual function. void Compute(OpKernelContext* c) {} void Compute(OpKernelContext* c, const std::vector& values, - const TensorShapeList& input_shapes) { + const TensorShapeList& input_shapes, + const OpInputList& input_mins, const OpInputList& input_maxes, + bool quantized_input) { const Tensor* concat_dim_tensor; const char* axis_attribute_name = AxisArgName == NAME_IS_AXIS @@ -79,19 +150,28 @@ class EigenConcatBaseOp : public OpKernel { const int input_dims = input_shapes[0].dims(); const TensorShape& input_shape = input_shapes[0]; - int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; - OP_REQUIRES(c, - (0 <= axis && axis < input_dims) || - (allow_legacy_scalars() && concat_dim == 0), - errors::InvalidArgument( - "ConcatOp : Expected concatenating dimensions in the range " - "[", - -input_dims, ", ", input_dims, "), but got ", concat_dim)); + int32 axis = (concat_dim < 0) ? (concat_dim + input_dims) : concat_dim; + OP_REQUIRES( + c, + (0 <= axis && axis < input_dims) || + (allow_legacy_scalars() && concat_dim == 0), + errors::InvalidArgument( + "ConcatOp : Expected concatenating dimensions in the range [", + -input_dims, ", ", input_dims, "), but got ", concat_dim)); + + float output_min = std::numeric_limits::max(); + float output_max = std::numeric_limits::lowest(); + std::vector> input_mins_and_maxes; + if (quantized_input) { + CalculateInputAndOutputRange(input_mins, input_maxes, N, + &input_mins_and_maxes, &output_min, + &output_max); + } // Note that we reduce the concat of n-dimensional tensors into a two // dimensional concat. Assuming the dimensions of any input/output - // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along - // the dimension indicated with size y0, we flatten it to {x, y}, where y = - // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1). 
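// A standalone sketch of the quantized-concat range handling added above
// (CalculateInputAndOutputRange plus RequantizeCopier). The names below are
// hypothetical and the mapping is a simplified linear dequantize/requantize,
// not the exact QuantizedToFloatStruct / FloatToQuantizedStruct arithmetic.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct Range { float min, max; };

// Merge per-input quantization ranges into one output range: signed types get
// a symmetric range around zero, unsigned types get [0, overall_max].
Range MergeRanges(const std::vector<Range>& inputs, bool is_signed) {
  float lo = inputs[0].min, hi = inputs[0].max;
  for (const Range& r : inputs) {
    lo = std::min(lo, r.min);
    hi = std::max(hi, r.max);
  }
  if (is_signed) {
    const float m = std::max(std::fabs(lo), std::fabs(hi));
    return {-m, m};
  }
  return {0.0f, hi};
}

// Requantize one quint8 value from its input range to the merged output range,
// the same idea RequantizeCopier applies element-wise without Eigen.
uint8_t Requantize(uint8_t q, Range in, Range out) {
  const float real = in.min + (in.max - in.min) * (q / 255.0f);
  float scaled = (real - out.min) / (out.max - out.min) * 255.0f;
  scaled = std::min(255.0f, std::max(0.0f, scaled));
  return static_cast<uint8_t>(std::lround(scaled));
}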
+ // tensor are {x_0, x_1,...,x_n-1, y_0, y_1,...,y_m-1}, where the + // concat is along the dimension indicated with size y_0, we flatten it + // to {x, y}, where y = Prod_i(y_i) and x = ((n > 0) ? Prod_i(x_i) : 1). ConstMatrixVector inputs_flat; inputs_flat.reserve(N); int64 inputs_flat_dim0 = 1; @@ -131,7 +211,24 @@ class EigenConcatBaseOp : public OpKernel { if (output->NumElements() > 0) { int64 output_dim1 = output->NumElements() / inputs_flat_dim0; auto output_flat = output->shaped({inputs_flat_dim0, output_dim1}); - ConcatCPU(c->device(), inputs_flat, &output_flat); + if (!quantized_input) { + ConcatCPU(c->device(), inputs_flat, &output_flat); + } else { + ConcatCPUImpl( + c->device(), inputs_flat, sizeof(T) /* cost_per_unit */, + RequantizeCopier(&input_mins_and_maxes, output_min, output_max), + &output_flat); + } + } + + if (quantized_input) { + Tensor* output_min_tensor = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(1, {}, &output_min_tensor)); + output_min_tensor->flat()(0) = output_min; + + Tensor* output_max_tensor = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(2, {}, &output_max_tensor)); + output_max_tensor->flat()(0) = output_max; } } }; @@ -158,7 +255,6 @@ class MklConcatOp : public OpKernel { OpInputList input_tensors; GetMklInputList(context, "values", &input_tensors); const int N = input_tensors.size(); - // Get Tensor shapes. std::vector mkl_input_shapes(N); GetMklShapeList(context, "values", &mkl_input_shapes); @@ -178,6 +274,7 @@ class MklConcatOp : public OpKernel { // check that ranks of all tensors match // and that their shapes match except for concat_dim. int i = 0; + int num_of_empty_inputs = 0; bool invoke_eigen = false; bool are_all_mkl_inputs = true, are_all_tf_inputs = true; const TensorShape expected_shape = mkl_input_shapes[0].IsMklTensor() @@ -217,10 +314,15 @@ class MklConcatOp : public OpKernel { else are_all_mkl_inputs = false; - if (s_dims != 4) invoke_eigen = true; + if (s_dims != 4 && s_dims != 2) invoke_eigen = true; + + if (input_tensors[i].NumElements() == 0) num_of_empty_inputs++; + ++i; } + if (num_of_empty_inputs == i) invoke_eigen = true; + // All inputs are not in one format (TF or MKL). This is mixed input case. // We can potentially optimize this case by converting all TF inputs // to Mkl format. But currently, we fall to Eigen for this case. @@ -229,7 +331,9 @@ class MklConcatOp : public OpKernel { if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true; OpInputList input_mins, input_maxes; - if (std::is_same::value || std::is_same::value) { + bool quantized_input = + std::is_same::value || std::is_same::value; + if (quantized_input) { // MKL-DNN concat does not support input tensors that have different // ranges. Check if the ranges of the all input tensors are the same. // If not, forward it to Eigen implementation. @@ -262,17 +366,8 @@ class MklConcatOp : public OpKernel { // Call Eigen library if (invoke_eigen) { - // MKL-DNN quantized concat does not support input tensors with - // different ranges. - // TODO (mabuzain): Add quantized version of CallEigen() to support - // this case. 
- OP_REQUIRES( - context, - (!std::is_same::value && !std::is_same::value), - errors::Unimplemented("MKL DNN quantized concat does not " - "support input tensors that have " - "different ranges")); - CallEigenVersion(context, input_tensors, mkl_input_shapes); + CallEigenVersion(context, input_tensors, input_mins, input_maxes, + mkl_input_shapes, quantized_input); return; } @@ -292,6 +387,7 @@ class MklConcatOp : public OpKernel { bool isMklReorderNeeded = false; memory::format mkl_common_format = memory::format::any; + std::vector inputs; if (are_all_mkl_inputs) { mkl_common_format = FindMklCommonFormat(mkl_input_shapes, concat_dim, @@ -301,18 +397,17 @@ class MklConcatOp : public OpKernel { // All MKL tensors have a same format. Reorder is not needed. for (int k = 0; k < N; k++) { if (input_tensors[k].NumElements() == 0) continue; - auto src_md = mkl_input_shapes[k].GetMklLayout(); srcs[k].SetUsrMem(src_md, &input_tensors[k]); auto src_mpd = srcs[k].GetUsrMemPrimDesc(); srcs_pd.push_back(src_mpd); + inputs.push_back(srcs[k].GetOpMem()); } } else { // MKL tensors have different formats. // Reorder them to most common format. for (int k = 0; k < N; k++) { if (input_tensors[k].NumElements() == 0) continue; - auto src_md = mkl_input_shapes[k].GetMklLayout(); srcs[k].SetUsrMem(src_md, &input_tensors[k]); @@ -329,18 +424,25 @@ class MklConcatOp : public OpKernel { } else { // All TF inputs for (int k = 0; k < N; k++) { if (input_tensors[k].NumElements() == 0) continue; - - memory::dims src_dims = TFShapeToMklDnnDims(input_tensors[k].shape()); + TensorShape s_shape = input_tensors[k].shape(); + memory::dims src_dims = TFShapeToMklDnnDims(s_shape); dst_concat_dim_size += src_dims[concat_dim]; + size_t s_dims = s_shape.dims(); // It does not matter what data format to be used (NHWC versus NCHW). // We just need to ensure that output uses same data format as inputs. + if (s_dims == 4) + mkl_common_format = memory::format::nchw; + else if (s_dims == 2) + mkl_common_format = memory::format::nc; + auto src_md = - memory::desc(src_dims, MklDnnType(), memory::format::nchw); + memory::desc(src_dims, MklDnnType(), mkl_common_format); srcs[k].SetUsrMem(src_md, &input_tensors[k]); auto src_mpd = srcs[k].GetUsrMemPrimDesc(); srcs_pd.push_back(src_mpd); + inputs.push_back(srcs[k].GetOpMem()); } } dst_dims[concat_dim] = dst_concat_dim_size; @@ -352,31 +454,36 @@ class MklConcatOp : public OpKernel { // Since we are passing a specific format for destination, // we need to have dst_dims in MklDnn order (NCHW). auto orig_tf_format = mkl_input_shapes[0].GetTfDataFormat(); - dst_dims_in_nchw = MklDnnDimsInNCHW( - dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format)); - // Set the output format same as the most common format of inputs - // to avoid layout conversions. - dst_md = - memory::desc(dst_dims_in_nchw, MklDnnType(), mkl_common_format); + if (dst_dims.size() == 4) { + dst_dims_in_nchw = MklDnnDimsInNCHW( + dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format)); + // Set the output format same as the most common format of inputs + // to avoid layout conversions. 
+ dst_md = memory::desc(dst_dims_in_nchw, MklDnnType(), + mkl_common_format); + } else if (dst_dims.size() == 2 && + mkl_common_format == memory::format::nc) { + // When memory::format::nc, dst_dims are already in MKL-DNN order + dst_md = memory::desc(dst_dims, MklDnnType(), mkl_common_format); + } else { + TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION, + "Unsupported tensor dimension or" + "MKL-DNN memory format")); + } } else { // All inputs are TF tensors. - // Set the output format same as input format (nchw). - dst_md = memory::desc(dst_dims, MklDnnType(), memory::format::nchw); + // Set the output format same as input format (nchw/nc). + dst_md = memory::desc(dst_dims, MklDnnType(), mkl_common_format); } - std::vector inputs; if (isMklReorderNeeded) { for (int k = 0; k < input_tensors.size(); k++) { if (input_tensors[k].NumElements() > 0) { srcs[k].CheckReorderToOpMem(srcs_pd[k]); + inputs.push_back(srcs[k].GetOpMem()); } } } - for (int k = 0; k < input_tensors.size(); k++) { - if (input_tensors[k].NumElements() > 0) { - inputs.push_back(srcs[k].GetOpMem()); - } - } // If all inputs are in MKL format, then meaning of concat_dim needs to // change. Value of concat_dim is tied to input Tensorflow data format @@ -388,53 +495,65 @@ class MklConcatOp : public OpKernel { if (are_all_mkl_inputs) concat_dim = mkl_input_shapes[0].TfDimIdx(concat_dim); - auto concat_pd = concat::primitive_desc(concat_dim, srcs_pd); - auto dst_pd = concat_pd.dst_primitive_desc(); - - MklDnnShape dnn_shape_dst; - TensorShape tf_shape_dst; - Tensor* dst_tensor = nullptr; - if (are_all_mkl_inputs) { - dnn_shape_dst.SetMklTensor(true); + if (!inputs.empty()) { + auto concat_pd = concat::primitive_desc(concat_dim, srcs_pd); auto dst_pd = concat_pd.dst_primitive_desc(); - dnn_shape_dst.SetMklLayout(&dst_pd); - dnn_shape_dst.SetElemType(MklDnnType()); - dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw, - mkl_input_shapes[0].GetTfDataFormat()); - tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T))); + + MklDnnShape dnn_shape_dst; + TensorShape tf_shape_dst; + Tensor* dst_tensor = nullptr; + if (are_all_mkl_inputs) { + dnn_shape_dst.SetMklTensor(true); + auto dst_pd = concat_pd.dst_primitive_desc(); + dnn_shape_dst.SetMklLayout(&dst_pd); + dnn_shape_dst.SetElemType(MklDnnType()); + dnn_shape_dst.SetTfLayout(dst_dims.size(), dst_dims_in_nchw, + mkl_input_shapes[0].GetTfDataFormat()); + tf_shape_dst.AddDim((dst_pd.get_size() / sizeof(T))); + } else { + dnn_shape_dst.SetMklTensor(false); + tf_shape_dst = MklDnnDimsToTFShape(dst_dims); + } + AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst, + dnn_shape_dst); + DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL"; + + if (dnn_shape_dst.IsMklTensor()) dst_md = dnn_shape_dst.GetMklLayout(); + dst.SetUsrMem(dst_md, dst_tensor); + + auto concat_op = concat(concat_pd, inputs, dst.GetOpMem()); + std::vector net; + net.push_back(concat_op); + stream(stream::kind::eager).submit(net).wait(); + + // For quantized concat, min and max outputs are also computed. 
+ if (quantized_input) { + Tensor* output_min = nullptr; + Tensor* output_max = nullptr; + MklDnnShape output_min_mkl_shape, output_max_mkl_shape; + output_min_mkl_shape.SetMklTensor(false); + output_max_mkl_shape.SetMklTensor(false); + AllocateOutputSetMklShape(context, 1, &output_min, {}, + output_min_mkl_shape); + AllocateOutputSetMklShape(context, 2, &output_max, {}, + output_max_mkl_shape); + // All input tensors should have the same range, just use the + // first one + output_min->flat()(0) = input_mins[0].flat()(0); + output_max->flat()(0) = input_maxes[0].flat()(0); + } } else { + MklDnnShape dnn_shape_dst; + TensorShape tf_shape_dst; + Tensor* dst_tensor = nullptr; dnn_shape_dst.SetMklTensor(false); tf_shape_dst = MklDnnDimsToTFShape(dst_dims); + + AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst, + dnn_shape_dst); + DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL"; } - AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst, - dnn_shape_dst); - CHECK_NOTNULL(dst_tensor); - dst_md = - dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout() : dst_md; - dst.SetUsrMem(dst_md, dst_tensor); - - auto concat_op = concat(concat_pd, inputs, dst.GetOpMem()); - std::vector net; - net.push_back(concat_op); - stream(stream::kind::eager).submit(net).wait(); - - // For quantized concat, min and max outputs are also computed. - if (std::is_same::value || std::is_same::value) { - Tensor* output_min = nullptr; - Tensor* output_max = nullptr; - MklDnnShape output_min_mkl_shape, output_max_mkl_shape; - output_min_mkl_shape.SetMklTensor(false); - output_max_mkl_shape.SetMklTensor(false); - AllocateOutputSetMklShape(context, 1, &output_min, {}, - output_min_mkl_shape); - AllocateOutputSetMklShape(context, 2, &output_max, {}, - output_max_mkl_shape); - // All input tensors should have the same range, just use the - // first one - output_min->flat()(0) = input_mins[0].flat()(0); - output_max->flat()(0) = input_maxes[0].flat()(0); - } } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -446,45 +565,48 @@ class MklConcatOp : public OpKernel { } void CallEigenVersion(OpKernelContext* context, const OpInputList& values, - const MklDnnShapeList& mkl_input_shapes) { - CHECK_EQ(values.size(), mkl_input_shapes.size()); - - std::vector converted_values; + const OpInputList& input_mins, + const OpInputList& input_maxes, + const MklDnnShapeList& mkl_input_shapes, + bool quantized_input) { + size_t num_mkl_input_shapes = mkl_input_shapes.size(); + DCHECK_EQ(values.size(), num_mkl_input_shapes); + std::vector converted_values(num_mkl_input_shapes); TensorShapeList tf_input_shapes; - for (int i = 0; i < mkl_input_shapes.size(); i++) { + for (size_t i = 0; i < num_mkl_input_shapes; ++i) { if (mkl_input_shapes[i].IsMklTensor()) { // do conversion from MKL to TF Tensor tmp_tensor = ConvertMklToTF(context, values[i], mkl_input_shapes[i]); - converted_values.push_back(tmp_tensor); + converted_values[i] = tmp_tensor; tf_input_shapes.push_back(mkl_input_shapes[i].GetTfShape()); } else { // no conversion since it is TF tensor already - converted_values.push_back(values[i]); + converted_values[i] = values[i]; tf_input_shapes.push_back(values[i].shape()); } } // Call Eigen concat. 
- eigen_concat_op_.Compute(context, converted_values, tf_input_shapes); + eigen_concat_op_.Compute(context, converted_values, tf_input_shapes, + input_mins, input_maxes, quantized_input); - // Set output Mkl tensor for this op. - MklDnnShape dnn_shape_output; - dnn_shape_output.SetMklTensor(false); - dnn_shape_output.SetDimensions(4); - Tensor* output_tensor = nullptr; - TensorShape tf_shape_output; - tf_shape_output.AddDim(dnn_shape_output.GetSerializeBufferSize()); - OP_REQUIRES_OK(context, - context->allocate_output( - GetTensorMetaDataIndex(0, context->num_outputs()), - tf_shape_output, &output_tensor)); - dnn_shape_output.SerializeMklDnnShape( - output_tensor->flat().data(), - output_tensor->flat().size() * sizeof(uint8)); + // Get the number of dims from first input since all input tensors + // should have same rank. + size_t dims = values[0].shape().dims(); + MklDnnShape output_data_mkl_shape; + output_data_mkl_shape.SetMklTensor(false); + output_data_mkl_shape.SetDimensions(dims); + AllocateOutputSetMklShape(context, 0, output_data_mkl_shape); + if (quantized_input) { + MklDnnShape output_min_max_mkl_shape; + output_min_max_mkl_shape.SetMklTensor(false); + AllocateOutputSetMklShape(context, 1, output_min_max_mkl_shape); + AllocateOutputSetMklShape(context, 2, output_min_max_mkl_shape); + } } - // This method finds the most commom format across all MKL inputs + // This method finds the most common format across all MKL inputs // Inputs: // 1. input_shapes: shapes of input (MKL) tensors. // 2. concat_dim: concat dimension. @@ -550,6 +672,7 @@ class MklConcatOp : public OpKernel { MklConcatOp) TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_bfloat16(REGISTER_MKL_CPU); REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2") .Device(DEVICE_CPU) diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 47b2a43ed92..13d07f5dd22 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -414,7 +414,7 @@ class MklConvCustomBackpropFilterOp // if output tensor has more than 0 elements, we need to 0 them out. auto diff_filter_data = diff_filter_tensor->flat().data(); for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) { - diff_filter_data[i] = 0; + diff_filter_data[i] = static_cast(0); } return; } @@ -731,6 +731,7 @@ class MklConvCustomBackpropFilterOp MklConvCustomBackpropFilterOp); TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); +TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS); #undef REGISTER_MKL_FILTER_KERNELS } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 4e955df5fe9..cd03b8ced09 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -351,7 +351,7 @@ class MklConvCustomBackpropInputOp // if output tensor has more than 0 elements, we need to 0 them out. 
auto diff_src_data = diff_src_tensor->flat().data(); for (size_t i = 0; i < diff_src_tf_shape.num_elements(); ++i) { - diff_src_data[i] = 0; + diff_src_data[i] = static_cast(0); } return; } @@ -574,6 +574,7 @@ class MklConvCustomBackpropInputOp .Label(mkl_op_registry::kMklOpLabel), \ MklConvCustomBackpropInputOp); TF_CALL_float(REGISTER_MKL_CPU_KERNELS); +TF_CALL_bfloat16(REGISTER_MKL_CPU_KERNELS); #undef REGISTER_MKL_CPU_KERNELS } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index da999d28b1f..e406081d481 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -16,11 +16,15 @@ limitations under the License. // See docs in ../ops/nn_ops.cc. #ifdef INTEL_MKL +#include "tensorflow/core/kernels/mkl_conv_ops.h" + #include + #include #include #include +#include "mkldnn.hpp" #include "absl/strings/str_join.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" @@ -29,7 +33,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/mkl_conv_ops.h" #include "tensorflow/core/kernels/mkl_quantized_conv_ops.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/kernels/ops_util.h" @@ -40,28 +43,17 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" -#include "tensorflow/core/util/mkl_util.h" - -#ifndef INTEL_MKL_ML_ONLY -#include "mkldnn.hpp" - using mkldnn::prop_kind; using mkldnn::stream; using mkldnn::convolution_forward; using mkldnn::convolution_direct; -#else -#include "mkl_dnn.h" -#include "mkl_dnn_types.h" -#endif - namespace tensorflow { -#ifndef INTEL_MKL_ML_ONLY - // This structure aggregates multiple inputs to Conv2DFwd* methods. 
struct MklConvFwdParams { memory::dims src_dims; @@ -96,9 +88,8 @@ struct MklConvFwdParams { typedef mkldnn::convolution_forward::primitive_desc ConvFwdPd; // With quantization, input, filter, and output can have different types -// so we use differnt template parameter for each type -template +// so we use different template parameter for each type +template class MklConvFwdPrimitive : public MklPrimitive { public: explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims) @@ -261,10 +252,11 @@ class MklConvFwdPrimitive : public MklPrimitive { float op_scale = post_op_param.param[0]; post_ops.append_sum(op_scale); } else if (post_op_param.name == "output_scale") { - DCHECK_EQ(post_op_param.param.size(), 1); - std::vector scales; - scales.push_back(post_op_param.param[0]); - post_ops_attr.set_output_scales(0, scales); + if (post_op_param.param.size() == 1) { + post_ops_attr.set_output_scales(0, post_op_param.param); + } else { + post_ops_attr.set_output_scales(2, post_op_param.param); + } } else { DCHECK((post_op_param.name == "relu") || (post_op_param.name == "sum") || @@ -296,7 +288,7 @@ class MklConvFwdPrimitive : public MklPrimitive { // Create convolution primitive and add it to net if (!convFwdDims.bias_dims.empty()) { context_.bias_mem.reset(new memory( - {{{convFwdDims.bias_dims}, MklDnnType(), memory::format::x}, + {{{convFwdDims.bias_dims}, MklDnnType(), memory::format::x}, cpu_engine_}, DummyData)); context_.conv_fwd.reset(new convolution_forward( @@ -316,29 +308,31 @@ class MklConvFwdPrimitive : public MklPrimitive { engine cpu_engine_; }; -template -class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { +// TODO(nhasabni): We should not require passing a type to MklPrimitiveFactory. +// But removing the need for type in MklPrimitiveFactory is going to require +// change to every MKL op. So not doing it now. Instead passing float. 
+template +class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { public: - static MklConvFwdPrimitive* Get( + static MklConvFwdPrimitive* Get( const MklConvFwdParams& convFwdDims, bool do_not_cache) { - MklConvFwdPrimitive* conv_fwd = nullptr; + MklConvFwdPrimitive* conv_fwd = nullptr; if (do_not_cache) { // Always create a new primitive - conv_fwd = new MklConvFwdPrimitive( - convFwdDims); + conv_fwd = + new MklConvFwdPrimitive(convFwdDims); } else { // Try to find a suitable one in pool - conv_fwd = dynamic_cast< - MklConvFwdPrimitive*>( - MklConvFwdPrimitiveFactory::GetInstance() - .GetConvFwd(convFwdDims)); + conv_fwd = + dynamic_cast*>( + MklConvFwdPrimitiveFactory::GetInstance() + .GetConvFwd(convFwdDims)); if (conv_fwd == nullptr) { - conv_fwd = new MklConvFwdPrimitive( + conv_fwd = new MklConvFwdPrimitive( convFwdDims); - MklConvFwdPrimitiveFactory::GetInstance() .SetConvFwd(convFwdDims, conv_fwd); } @@ -376,21 +370,15 @@ class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { for (auto const& post_op_param : convFwdDims.post_op_params) { if (post_op_param.name == "relu") { DCHECK_EQ(post_op_param.param.size(), 3); - key_creator.AddAsKey(post_op_param.name); - key_creator.AddAsKey(post_op_param.param[0]); - key_creator.AddAsKey(post_op_param.param[1]); - key_creator.AddAsKey(post_op_param.param[2]); } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); - key_creator.AddAsKey(post_op_param.name); - key_creator.AddAsKey(post_op_param.param[0]); - } else if (post_op_param.name == "output_scale") { - DCHECK_EQ(post_op_param.param.size(), 1); - key_creator.AddAsKey(post_op_param.name); - key_creator.AddAsKey(post_op_param.param[0]); - } else { + } else if (post_op_param.name != "output_scale") { return string("not_a_key"); } + key_creator.AddAsKey(post_op_param.name); + for (auto& param : post_op_param.param) { + key_creator.AddAsKey(param); + } } return key_creator.GetKey(); @@ -407,449 +395,8 @@ class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { } }; -#endif - typedef Eigen::ThreadPoolDevice CPUDevice; -// For now, MKL-ML is default. So making MKL-DNN not a default choice. 
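// The CreateKey change above folds every post-op name and all of its
// parameters into the primitive-cache key, so primitives with different
// fusions or scale vectors never collide. A minimal standalone sketch of that
// idea (BuildPostOpKey and PostOpParam are hypothetical, not the real
// FactoryKeyCreator API):
#include <string>
#include <vector>

struct PostOpParam {
  std::string name;
  std::vector<float> param;
};

std::string BuildPostOpKey(const std::vector<PostOpParam>& post_ops) {
  std::string key;
  for (const auto& p : post_ops) {
    // Unknown fusions are not cached; mirror the "not_a_key" fallback above.
    if (p.name != "relu" && p.name != "sum" && p.name != "output_scale")
      return "not_a_key";
    key += p.name;
    for (float v : p.param) key += ":" + std::to_string(v);
  }
  return key;
}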
-#ifdef INTEL_MKL_ML_ONLY -template -class MklConvOp : public OpKernel { - public: - ~MklConvOp() {} - - explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); - string data_format; - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(context, strides_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - - const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); - const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); - OP_REQUIRES( - context, stride_n == 1 && stride_c == 1, - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - } - - void Compute(OpKernelContext* context) override { - MklConv2DOpContext mkl_context; - const Tensor& input = MklGetInput(context, 0); - GetMklShape(context, 0, &(mkl_context.input_shape)); - bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); - - const Tensor& filter = MklGetInput(context, 1); - MklShape mkl_filter_shape; - GetMklShape(context, 1, &mkl_filter_shape); - CHECK(!mkl_filter_shape.IsMklTensor()) - << "Conv filter should not be in MKL Layout"; - - if (bias_enabled) { - const Tensor& bias = MklGetInput(context, 2); - OP_REQUIRES(context, bias.dims() == 1, - errors::InvalidArgument("bias must be 1-dimensional: ", - bias.shape().DebugString())); - } - - if (!input_in_mkl_format) { - OP_REQUIRES(context, input.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); - } - - OP_REQUIRES(context, filter.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", - filter.shape().DebugString())); - - for (int i = 0; i < 3; ++i) { - OP_REQUIRES( - context, - FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), - errors::InvalidArgument("filter too large")); - } - - const int64 input_depth = - input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C') - : GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES(context, input_depth == filter.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", input_depth, - " vs ", filter.dim_size(2))); - // The last dimension for filter is out_depth. - const int out_depth = static_cast(filter.dim_size(3)); - - // The second dimension for input is rows/height. - // The first dimension for filter is rows/height. - const int64 input_rows_raw = - input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H') - : GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES( - context, - FastBoundsCheck(input_rows_raw, std::numeric_limits::max()), - errors::InvalidArgument("Input rows too large")); - const int input_rows = static_cast(input_rows_raw); - const int filter_rows = static_cast(filter.dim_size(0)); - - // The third dimension for input is columns/width. - // The second dimension for filter is columns/width. - const int64 input_cols_raw = - input_in_mkl_format ? 
GetMklTensorDim(mkl_context.input_shape, 'W') - : GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES( - context, - FastBoundsCheck(input_cols_raw, std::numeric_limits::max()), - errors::InvalidArgument("Input cols too large")); - const int input_cols = static_cast(input_cols_raw); - const int filter_cols = static_cast(filter.dim_size(1)); - - // The first dimension for input is batch. - const int64 input_batch_raw = - input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N') - : GetTensorDim(input, data_format_, 'N'); - OP_REQUIRES( - context, - FastBoundsCheck(input_batch_raw, std::numeric_limits::max()), - errors::InvalidArgument("batch is too large")); - const int batch = static_cast(input_batch_raw); - - // For now we take the stride from the second and third dimensions only (we - // do not support striding on the batch or depth dimension). - const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); - const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_rows, filter_rows, stride_rows, - padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_cols, filter_cols, stride_cols, - padding_, &out_cols, &pad_cols)); - TensorShape out_shape = - ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); - - // Output tensor is of the following dimensions: - // [ in_batch, out_rows, out_cols, out_depth ] - Tensor* output = nullptr; - - // If there is nothing to compute, return. - if (out_shape.num_elements() == 0) { - // Nothing to do, allocate output tensor and return - MklShape mkl_output_mkl_shape; - mkl_output_mkl_shape.SetMklTensor(false); - AllocateOutputSetMklShape(context, 0, &output, input.shape(), - mkl_output_mkl_shape); - return; - } - - if (batch == 0) { - // Nothing to do, allocate output tensor and return - MklShape mkl_output_mkl_shape; - mkl_output_mkl_shape.SetMklTensor(false); - AllocateOutputSetMklShape(context, 0, &output, input.shape(), - mkl_output_mkl_shape); - return; - } - - // Create MKL convolution primitives - mkl_context.in_dims = input_in_mkl_format - ? 
mkl_context.input_shape.GetDimension() - : input.dims(); - mkl_context.filter_dims = filter.dims(); - - mkl_context.in_sizes[MklDims::W] = static_cast(input_cols); - mkl_context.in_sizes[MklDims::H] = static_cast(input_rows); - mkl_context.in_sizes[MklDims::C] = static_cast(input_depth); - mkl_context.in_sizes[MklDims::N] = static_cast(batch); - - mkl_context.out_sizes[MklDims::W] = static_cast(out_cols); - mkl_context.out_sizes[MklDims::H] = static_cast(out_rows); - mkl_context.out_sizes[MklDims::C] = static_cast(out_depth); - mkl_context.out_sizes[MklDims::N] = static_cast(batch); - - mkl_context.input_offset[0] = static_cast(-pad_cols); - mkl_context.input_offset[1] = static_cast(-pad_rows); - - mkl_context.conv_stride[0] = static_cast(stride_cols); - mkl_context.conv_stride[1] = static_cast(stride_rows); - - GetStridesFromSizes(data_format_, mkl_context.out_strides, - mkl_context.out_sizes); - GetStridesFromSizes(data_format_, mkl_context.in_strides, - mkl_context.in_sizes); - - // TF filter dimension order (out_depth, in_depth, cols, rows) -> - // MKL filter dimension order (out_depth, in_depth, rows, cols) - mkl_context.filter_sizes[0] = filter.dim_size(1); // cols - mkl_context.filter_sizes[1] = filter.dim_size(0); // rows - mkl_context.filter_sizes[2] = filter.dim_size(2); // in_depth - mkl_context.filter_sizes[3] = filter.dim_size(3); // out_depth - - // TF filter layout - (rows, cols, in_depth, out_depth) - mkl_context.filter_strides[0] = - filter.dim_size(2) * filter.dim_size(3); // cols - mkl_context.filter_strides[1] = - filter.dim_size(1) * filter.dim_size(2) * filter.dim_size(3); // rows - mkl_context.filter_strides[2] = filter.dim_size(3); // in_depth - mkl_context.filter_strides[3] = 1; // out_depth - - if (bias_enabled) { - const Tensor& bias = MklGetInput(context, 2); - mkl_context.bias_sizes[0] = {static_cast(bias.dim_size(0))}; - mkl_context.bias_strides[0] = {1}; - } - - // Create Convolution Primitive - if (bias_enabled) { - CHECK_EQ( - dnnConvolutionCreateForwardBias_F32( - &mkl_context.prim_fwd, nullptr, dnnAlgorithmConvolutionDirect, - mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes, - mkl_context.filter_sizes, mkl_context.conv_stride, - mkl_context.input_offset, dnnBorderZeros), - E_SUCCESS); - } else { - CHECK_EQ( - dnnConvolutionCreateForward_F32( - &mkl_context.prim_fwd, nullptr, dnnAlgorithmConvolutionDirect, - mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes, - mkl_context.filter_sizes, mkl_context.conv_stride, - mkl_context.input_offset, dnnBorderZeros), - E_SUCCESS); - } - - TensorShape mkl_output_tf_shape; - MklShape mkl_output_mkl_shape; - mkl_output_mkl_shape.SetMklTensor(true); - mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd, dnnResourceDst); - mkl_output_mkl_shape.SetTfLayout(mkl_context.in_dims, mkl_context.out_sizes, - mkl_context.out_strides); - // MKL might change the dimension ordering - // Create mapping to recover the original TF dimension order - mkl_output_mkl_shape.SetTfDimOrder(mkl_context.in_dims, data_format_); - - mkl_output_tf_shape.AddDim( - dnnLayoutGetMemorySize_F32( - static_cast(mkl_output_mkl_shape.GetMklLayout())) / - sizeof(T)); - AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape, - mkl_output_mkl_shape); - // Filter output to be used in the backprop_input - TensorShape mkl_filter_output_tf_shape; - MklShape mkl_filter_output_mkl_shape; - mkl_filter_output_mkl_shape.SetMklTensor(true); - mkl_filter_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd, - dnnResourceFilter); 
- - size_t filter_sizes[4] = {static_cast(filter.dim_size(0)), - static_cast(filter.dim_size(1)), - static_cast(filter.dim_size(2)), - static_cast(filter.dim_size(3))}; - mkl_filter_output_mkl_shape.SetTfLayout(filter.dims(), filter_sizes, - mkl_context.filter_strides); - - mkl_filter_output_mkl_shape.SetTfDimOrder(mkl_context.filter_dims, - data_format_); - mkl_filter_output_tf_shape.AddDim( - dnnLayoutGetMemorySize_F32(static_cast( - mkl_filter_output_mkl_shape.GetMklLayout())) / - sizeof(T)); - AllocateOutputSetMklShape(context, 1, &mkl_context.output_filter, - mkl_filter_output_tf_shape, - mkl_filter_output_mkl_shape); - - mkl_context.conv_res[dnnResourceDst] = - static_cast(output->flat().data()); - - mkl_context.MklCreateInputLayouts(context); - - // Temp tensor used to allocate tmp buffers - Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor, - mkl_tmp_bias_buf_tensor; - mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf_tensor, - &mkl_tmp_filter_buf_tensor, - &mkl_tmp_bias_buf_tensor); - - // Execute convolution - CHECK_EQ(dnnExecute_F32(mkl_context.prim_fwd, mkl_context.conv_res), - E_SUCCESS); - - mkl_context.MklCleanup(); - } - - private: - typedef struct { - int in_dims; - size_t in_sizes[4]; - size_t in_strides[4]; - size_t out_sizes[4]; - size_t out_strides[4]; - int filter_dims; - size_t filter_sizes[4]; - size_t filter_strides[4]; - size_t bias_sizes[1]; - size_t bias_strides[1]; - int input_offset[2]; - size_t conv_stride[2]; - MklShape input_shape; - dnnPrimitive_t prim_fwd; - void* conv_res[dnnResourceNumber]; - dnnLayout_t lt_filter, lt_bias, lt_input; - Tensor* output_filter = nullptr; - - // Create MKL dnnLayout_t objects for tensors coming into the layer - void MklCreateInputLayouts(OpKernelContext* context) { - bool input_in_mkl_format = input_shape.IsMklTensor(); - if (input_in_mkl_format) { - lt_input = static_cast(input_shape.GetCurLayout()); - } else { - CHECK_EQ(dnnLayoutCreate_F32(<_input, in_dims, in_sizes, in_strides), - E_SUCCESS); - } - - CHECK_EQ(dnnLayoutCreate_F32(<_filter, filter_dims, filter_sizes, - filter_strides), - E_SUCCESS); - - if (bias_enabled) { - CHECK_EQ(dnnLayoutCreate_F32(<_bias, 1, bias_sizes, bias_strides), - E_SUCCESS); - } - } - - // Compare incoming tensor layouts with MKL preferred layouts and convert - // data to the preferred layout if necessary - void MklPrepareConvolutionInputs(OpKernelContext* context, - Tensor* mkl_tmp_input_buf_tensor, - Tensor* mkl_tmp_filter_buf_tensor, - Tensor* mkl_tmp_bias_buf_tensor) { - bool mkl_convert_input, mkl_convert_filter, mkl_convert_bias; - dnnPrimitive_t mkl_prim_convert_filter, mkl_prim_convert_bias, - mkl_prim_convert_input; - dnnLayout_t mkl_lt_internal_filter, mkl_lt_internal_bias, - mkl_lt_internal_input; - void *mkl_buf_convert_input, *mkl_buf_convert_filter, - *mkl_buf_convert_bias; - mkl_prim_convert_filter = nullptr; - mkl_prim_convert_bias = nullptr; - mkl_prim_convert_input = nullptr; - mkl_lt_internal_filter = nullptr; - mkl_lt_internal_bias = nullptr; - mkl_lt_internal_input = nullptr; - mkl_buf_convert_input = nullptr; - mkl_buf_convert_filter = nullptr; - mkl_buf_convert_bias = nullptr; - - // Compare with internal layouts and convert if needed - const Tensor& input = MklGetInput(context, 0); - void* mkl_buf_input = - const_cast(static_cast(input.flat().data())); - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input, - prim_fwd, dnnResourceSrc), - E_SUCCESS); - mkl_convert_input = - !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input); - 
if (mkl_convert_input) { - CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input, - mkl_lt_internal_input), - E_SUCCESS); - AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input, - &mkl_buf_convert_input); - CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input, - mkl_buf_convert_input), - E_SUCCESS); - dnnDelete_F32(mkl_prim_convert_input); - } - dnnLayoutDelete_F32(mkl_lt_internal_input); - - conv_res[dnnResourceSrc] = - (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input; - - const Tensor& filter = MklGetInput(context, 1); - void* mkl_buf_filter = - const_cast(static_cast(filter.flat().data())); - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_filter, - prim_fwd, dnnResourceFilter), - E_SUCCESS); - mkl_convert_filter = - !dnnLayoutCompare_F32(mkl_lt_internal_filter, lt_filter); - if (mkl_convert_filter) { - CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_filter, lt_filter, - mkl_lt_internal_filter), - E_SUCCESS); - - mkl_buf_convert_filter = const_cast( - static_cast(output_filter->flat().data())); - - CHECK_EQ( - dnnConversionExecute_F32(mkl_prim_convert_filter, mkl_buf_filter, - mkl_buf_convert_filter), - E_SUCCESS); - dnnDelete_F32(mkl_prim_convert_filter); - } - dnnLayoutDelete_F32(mkl_lt_internal_filter); - - conv_res[dnnResourceFilter] = - (mkl_convert_filter) ? mkl_buf_convert_filter : mkl_buf_filter; - - if (bias_enabled) { - const Tensor& bias = MklGetInput(context, 2); - void* mkl_buf_bias = - const_cast(static_cast(bias.flat().data())); - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_bias, - prim_fwd, dnnResourceBias), - E_SUCCESS); - mkl_convert_bias = !dnnLayoutCompare_F32(mkl_lt_internal_bias, lt_bias); - if (mkl_convert_bias) { - CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_bias, lt_bias, - mkl_lt_internal_bias), - E_SUCCESS); - AllocTmpBuffer(context, mkl_tmp_bias_buf_tensor, mkl_lt_internal_bias, - &mkl_buf_convert_bias); - CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_bias, mkl_buf_bias, - mkl_buf_convert_bias), - E_SUCCESS); - dnnDelete_F32(mkl_prim_convert_bias); - } - dnnLayoutDelete_F32(mkl_lt_internal_bias); - - conv_res[dnnResourceBias] = - (mkl_convert_bias) ? 
mkl_buf_convert_bias : mkl_buf_bias; - } - } - - void MklCleanup() { - bool input_in_mkl_format = input_shape.IsMklTensor(); - dnnDelete_F32(prim_fwd); - if (!input_in_mkl_format) dnnLayoutDelete_F32(lt_input); - dnnLayoutDelete_F32(lt_filter); - if (bias_enabled) dnnLayoutDelete_F32(lt_bias); - } - } MklConv2DOpContext; - - std::vector strides_; - Padding padding_; - TensorFormat data_format_; -}; - -// FP32 kernel registration for INTEL_MKL_ML -REGISTER_KERNEL_BUILDER(Name("_MklConv2D") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklConv2DOp); -REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklConv2DOp); - -#else - // Base class for convolution forward operations template * - conv_fwd = nullptr; + MklConvFwdPrimitive* conv_fwd = + nullptr; memory::dims bias_dims = {}; if (fuse_biasadd_) { conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims); @@ -1076,17 +621,16 @@ class MklConvOp : public OpKernel { // TODO(mdfaijul): Extend the basic parameters for data types and fusions this->ExtendConvFwdParams(context, convFwdDims); - conv_fwd = MklConvFwdPrimitiveFactory::Get(convFwdDims, - do_not_cache); + conv_fwd = + MklConvFwdPrimitiveFactory::Get( + convFwdDims, do_not_cache); // Allocate output tensors `output_tensor` and `filter_out_tensor` std::shared_ptr conv_fwd_pd = conv_fwd->GetPrimitiveDesc(); AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt, &dst_tensor); Tensor* filter_out_tensor = nullptr; - if (typeid(Tinput) == typeid(float) && typeid(Tfilter) == typeid(float) && - typeid(Toutput) == typeid(float)) { + if (emit_filter_output) { AllocateFilterOutputTensor(context, *conv_fwd_pd, TFShapeToMklDnnDims(filter_tf_shape), &filter_out_tensor); @@ -1190,8 +734,8 @@ class MklConvOp : public OpKernel { // Similarly, if the data format is NCHW, indices 0, 1, 2 and 3 of // paddings(_tf) will be zero. // i.e. for the above example, paddings = {0, 0, 0, 0, 1, 2, 3, 4}. - int64 pad_top, pad_left; - int64 pad_bottom, pad_right; + int64 pad_top = 0, pad_left = 0; + int64 pad_bottom = 0, pad_right = 0; string data_format = ToString(data_format_); if (data_format == "NHWC") { pad_top = paddings[2]; @@ -1501,10 +1045,10 @@ class MklFusedConvOp // We create new class for each version of Quantized Convolution and inherit // from the FP32 version of the base class template + typename Ttemp_output, bool bias_enabled, bool is_depthwise> class MklQuantizedConv2DOp : public MklConvOp { + int32, bias_enabled, false, is_depthwise> { public: virtual ~MklQuantizedConv2DOp() { if (this->input_bias_ != nullptr) { @@ -1520,7 +1064,7 @@ class MklQuantizedConv2DOp explicit MklQuantizedConv2DOp(OpKernelConstruction* context) : MklConvOp(context) { + bias_enabled, false, is_depthwise>(context) { bool is_filter_const; OP_REQUIRES_OK(context, context->GetAttr("is_filter_const", &is_filter_const)); @@ -1531,7 +1075,7 @@ class MklQuantizedConv2DOp void Compute(OpKernelContext* context) override { // Compute int32 output tensor MklConvOp::Compute(context); + bias_enabled, false, is_depthwise>::Compute(context); // Compute additional outputs: min/max scalars. 
int bias_index_offset; @@ -1541,44 +1085,60 @@ class MklQuantizedConv2DOp context->input(2 + bias_index_offset).flat()(0); const float max_input = context->input(3 + bias_index_offset).flat()(0); - const float min_filter = - context->input(4 + bias_index_offset).flat()(0); - const float max_filter = - context->input(5 + bias_index_offset).flat()(0); - float min_output_value; - float max_output_value; - if (std::is_same::value || - std::is_same::value) { - // This is the case when convolution and requantization are fused. - // min_freezed_output and max_freezed_output are the actual range - // of the output. - min_output_value = context->input(6 + bias_index_offset).flat()(0); - max_output_value = context->input(7 + bias_index_offset).flat()(0); - } else { - MklQuantizationRangeForMultiplication( - min_input, max_input, min_filter, max_filter, &min_output_value, - &max_output_value); - } - - Tensor* output_min = nullptr; - Tensor* output_max = nullptr; MklDnnShape output_min_mkl_shape, output_max_mkl_shape; output_min_mkl_shape.SetMklTensor(false); output_max_mkl_shape.SetMklTensor(false); - AllocateOutputSetMklShape(context, 1, &output_min, {}, - output_min_mkl_shape); - AllocateOutputSetMklShape(context, 2, &output_max, {}, - output_max_mkl_shape); - output_min->flat()(0) = min_output_value; - output_max->flat()(0) = max_output_value; + + Tensor* output_min = nullptr; + Tensor* output_max = nullptr; + if (std::is_same::value || + std::is_same::value) { + AllocateOutputSetMklShape(context, 1, &output_min, {}, + output_min_mkl_shape); + AllocateOutputSetMklShape(context, 2, &output_max, {}, + output_max_mkl_shape); + // This is the case the convolution and requantization are fused. + output_min->flat()(0) = + context->input(6 + bias_index_offset).flat()(0); + output_max->flat()(0) = + context->input(7 + bias_index_offset).flat()(0); + } else { + const Tensor& min_filter = context->input(4 + bias_index_offset); + const Tensor& max_filter = context->input(5 + bias_index_offset); + if (min_filter.dims() == 0) { + float min_output_value; + float max_output_value; + MklQuantizationRangeForMultiplication( + min_input, max_input, min_filter.flat()(0), + max_filter.flat()(0), &min_output_value, &max_output_value); + AllocateOutputSetMklShape(context, 1, &output_min, {}, + output_min_mkl_shape); + AllocateOutputSetMklShape(context, 2, &output_max, {}, + output_max_mkl_shape); + output_min->flat()(0) = min_output_value; + output_max->flat()(0) = max_output_value; + } else { + size_t depth = min_filter.NumElements(); + AllocateOutputSetMklShape(context, 1, &output_min, + {static_cast(depth)}, + output_min_mkl_shape); + AllocateOutputSetMklShape(context, 2, &output_max, + {static_cast(depth)}, + output_max_mkl_shape); + MklQuantizationRangeForMultiplication( + min_input, max_input, min_filter, max_filter, &output_min, + &output_max); + } + } } protected: void ExtendConvFwdParams(OpKernelContext* context, MklConvFwdParams& params) override { MklConvOp::ExtendConvFwdParams(context, params); + bias_enabled, false, is_depthwise>::ExtendConvFwdParams(context, + params); // When the output type is quint8, the output data id requantized // into quint8. A post_op "output_scale" is added to do the conversion. 
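// The requantization into quint8/qint8 mentioned above is driven by one
// "output_scale" per output channel; the next hunk (ExtendConvFwdParams)
// replaces the old scalar scale with this per-channel vector. A sketch of the
// same arithmetic with a hypothetical helper name:
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> OutputScales(float min_input, float max_input,
                                const std::vector<float>& min_filter,
                                const std::vector<float>& max_filter,
                                float min_output, float max_output,
                                bool output_is_quint8) {
  const float factor = output_is_quint8 ? 255.0f : 127.0f;
  const float input_range = std::max(std::fabs(min_input), std::fabs(max_input));
  const float output_range =
      std::max(std::fabs(min_output), std::fabs(max_output));
  std::vector<float> scales(min_filter.size());
  for (std::size_t i = 0; i < scales.size(); ++i) {
    const float filter_range =
        std::max(std::fabs(min_filter[i]), std::fabs(max_filter[i]));
    // Map the int32 accumulator range (input_range * filter_range spread over
    // 255 * 127 quantization levels) onto the frozen output range.
    scales[i] = factor * input_range * filter_range /
                (255.0f * 127.0f * output_range);
  }
  return scales;
}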
@@ -1591,33 +1151,34 @@ class MklQuantizedConv2DOp context->input(2 + bias_index_offset).flat()(0); const float max_input = context->input(3 + bias_index_offset).flat()(0); - const float min_filter = - context->input(4 + bias_index_offset).flat()(0); - const float max_filter = - context->input(5 + bias_index_offset).flat()(0); + const Tensor& min_filter_vector = context->input(4 + bias_index_offset); + const Tensor& max_filter_vector = context->input(5 + bias_index_offset); + + // min_freezed_output and max_freezed_output are the actual range + // for the output. const float min_freezed_output = context->input(6 + bias_index_offset).flat()(0); const float max_freezed_output = context->input(7 + bias_index_offset).flat()(0); - float min_output_value; - float max_output_value; - MklQuantizationRangeForMultiplication( - min_input, max_input, min_filter, max_filter, &min_output_value, - &max_output_value); - float scale_int32 = - std::max(std::abs(min_output_value), std::abs(max_output_value)); - float scale_eightbit = + float factor = std::is_same::value ? 255.0f : 127.0f; + size_t depth = min_filter_vector.NumElements(); + const float* min_filter = min_filter_vector.flat().data(); + const float* max_filter = max_filter_vector.flat().data(); + std::vector scales(depth); + float input_range = std::max(std::abs(min_input), std::abs(max_input)); + float output_range = std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); - float scale = 1.0; - if (std::is_same::value) - scale = scale_int32 / scale_eightbit / static_cast(1 << 23); - else - scale = scale_int32 / scale_eightbit / static_cast(1 << 24); - - std::vector output_scale; - output_scale.push_back(scale); - params.post_op_params.push_back({"output_scale", output_scale}); + for (size_t i = 0; i < depth; ++i) { + // For simplicity and symmetry, we set filter range to be outer + // bounds of min_filter and max_filter. + float filter_range = + std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); + // To understand the scaling, please see mkl_requantize_ops_test. + scales[i] = factor * input_range * filter_range / + (255.0f * 127.0f * output_range); + } + params.post_op_params.push_back({"output_scale", scales}); } } @@ -1631,10 +1192,10 @@ class MklQuantizedConv2DOp context->input(2 + bias_index_offset).flat()(0); const float max_input = context->input(3 + bias_index_offset).flat()(0); - const float min_filter = - context->input(4 + bias_index_offset).flat()(0); - const float max_filter = - context->input(5 + bias_index_offset).flat()(0); + const Tensor& min_filter_vector = context->input(4 + bias_index_offset); + const Tensor& max_filter_vector = context->input(5 + bias_index_offset); + const float* min_filter = min_filter_vector.flat().data(); + const float* max_filter = max_filter_vector.flat().data(); std::vector net; if (bias_enabled) { @@ -1644,17 +1205,29 @@ class MklQuantizedConv2DOp } // If bias is enabled and requantization is not fused, scale the // bias to be consistent with quantized-input and quantized-filter. 
- float bias_scale = 255.0 * 127.0 / - (std::max(std::abs(max_input), std::abs(min_input)) * - std::max(std::abs(max_filter), std::abs(min_filter))); - std::vector scales; - scales.push_back(bias_scale); + size_t depth = min_filter_vector.NumElements(); + std::vector scales(depth); + for (size_t i = 0; i < depth; ++i) { + scales[i] = + 255.0 * 127.0 / + (std::max(std::abs(max_input), std::abs(min_input)) * + std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); + } mkldnn::primitive_attr bias_attr; - bias_attr.set_output_scales(0, scales); + if (depth == 1) { + bias_attr.set_output_scales(0, scales); + } else { + bias_attr.set_output_scales(1, scales); + } + auto bias_pd = + memory::primitive_desc({{static_cast(bias_tensor.NumElements())}, + MklDnnType(), + memory::format::x}, + this->cpu_engine_); void* bias_buf = static_cast( const_cast(bias_tensor.flat().data())); - input_bias_ = new memory(conv_fwd_pd->bias_primitive_desc(), bias_buf); + input_bias_ = new memory(bias_pd, bias_buf); scaled_bias_ = new memory(conv_fwd_pd->bias_primitive_desc()); auto reorder_desc = mkldnn::reorder::primitive_desc( input_bias_->get_primitive_desc(), scaled_bias_->get_primitive_desc(), @@ -1672,31 +1245,31 @@ class MklQuantizedConv2DOp }; template + typename Ttemp_output, bool bias_enabled, bool is_depthwise> class MklQuantizedConv2DReluOp : public MklQuantizedConv2DOp { + bias_enabled, is_depthwise> { public: virtual ~MklQuantizedConv2DReluOp() {} explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context) - : MklQuantizedConv2DOp(context) {} + : MklQuantizedConv2DOp(context) {} protected: void ExtendConvFwdParams(OpKernelContext* context, MklConvFwdParams& params) override { - MklQuantizedConv2DOp::ExtendConvFwdParams(context, params); + MklQuantizedConv2DOp::ExtendConvFwdParams(context, params); params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}}); } }; template + typename Ttemp_output, bool bias_enabled, bool is_depthwise> class MklQuantizedConv2DSumReluOp : public MklQuantizedConv2DOp { + bias_enabled, is_depthwise> { public: virtual ~MklQuantizedConv2DSumReluOp() { if (this->summand_ != nullptr) { @@ -1711,14 +1284,14 @@ class MklQuantizedConv2DSumReluOp } explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context) - : MklQuantizedConv2DOp(context) {} + : MklQuantizedConv2DOp(context) {} protected: void ExtendConvFwdParams(OpKernelContext* context, MklConvFwdParams& params) override { - MklQuantizedConv2DOp::ExtendConvFwdParams(context, params); + MklQuantizedConv2DOp::ExtendConvFwdParams(context, params); // Calculate the scale (beta in mkldnn api term) for sum if (std::is_same::value) { int summand_idx = context->num_inputs() / 2 - 1 - 2; @@ -1740,12 +1313,15 @@ class MklQuantizedConv2DSumReluOp std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); float scale_summand = std::max(std::abs(min_freezed_summand), std::abs(max_freezed_summand)); + // if summand_type is also DT_QUINT8 as the scale_output, + // the scaling factor of 255.0f cancels each other and thus is avoided. + // If it is not then it is DT_INT8 and is scaled appropriately. 
if (summand_type == DT_QUINT8) params.post_op_params.push_back( {"sum", {scale_summand / scale_output}}); else params.post_op_params.push_back( - {"sum", {2.0f * scale_summand / scale_output}}); + {"sum", {255.0f * scale_summand / (scale_output * 127.0f)}}); } else { params.post_op_params.push_back({"sum", {1.0}}); } @@ -1758,7 +1334,6 @@ class MklQuantizedConv2DSumReluOp memory::format output_tf_format, Tensor** output_tensor) override { int summand_idx = context->num_inputs() / 2 - 1; - float reorder_sum_scale = 1.0; if (std::is_same::value) { summand_idx -= 2; DataType summand_type = this->input_type(summand_idx); @@ -1769,25 +1344,22 @@ class MklQuantizedConv2DSumReluOp MklDnnShape summand_mkl_shape; GetMklShape(context, summand_idx, &summand_mkl_shape); auto dst_md = summand_mkl_shape.GetMklLayout(); - if (summand_mkl_shape.IsMklTensor()) { - if (summand_type == DT_QINT8) { - OP_REQUIRES_OK(context, summand.BitcastFrom(summand, DT_QUINT8, - summand.shape())); - dst_md.data.data_type = - static_cast(MklDnnType()); - summand_mkl_shape.SetMklLayout(&dst_md); - summand_mkl_shape.SetElemType(MklDnnType()); - } - ForwardMklTensorInToOutWithMklShape(context, summand_idx, 0, - summand_mkl_shape); - *output_tensor = const_cast(&summand); - return; - } else { - TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION, - "Current fusion is not successful.")); + + // TODO(mdfaijul): handle both non-MKL and MKL tensors + if (summand_type == DT_QINT8) { + OP_REQUIRES_OK( + context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape())); + dst_md.data.data_type = + static_cast(MklDnnType()); + summand_mkl_shape.SetMklLayout(&dst_md); + summand_mkl_shape.SetElemType(MklDnnType()); } + ForwardMklTensorInToOutWithMklShape(context, summand_idx, 0, + summand_mkl_shape); + *output_tensor = const_cast(&summand); + return; } - // TODO(mdfaijul): Add cleaner code for non-mkl tensor + MklConvOp::AllocateOutputTensor(context, conv_prim_desc, @@ -1805,19 +1377,28 @@ class MklQuantizedConv2DSumReluOp context->input(2 + bias_index_offset).flat()(0); const float max_input = context->input(3 + bias_index_offset).flat()(0); - const float min_filter = - context->input(4 + bias_index_offset).flat()(0); - const float max_filter = - context->input(5 + bias_index_offset).flat()(0); + const Tensor& min_filter_vector = context->input(4 + bias_index_offset); + const Tensor& max_filter_vector = context->input(5 + bias_index_offset); + const float* min_filter = min_filter_vector.flat().data(); + const float* max_filter = max_filter_vector.flat().data(); - reorder_sum_scale = 255.0 * 127.0 / - (std::max(std::abs(max_input), std::abs(min_input)) * - std::max(std::abs(max_filter), std::abs(min_filter))); - std::vector scales; - scales.push_back(reorder_sum_scale); + size_t depth = min_filter_vector.NumElements(); + std::vector scales(depth); + for (size_t i = 0; i < depth; ++i) { + // TODO(nammbash): scale factors for UINT8(inputs) & INT8(weights) are + // done regularly. A Cleaner design to address all mapping in one + // function needs to be implemented in future which also supports other + // quantized type mapping in future. 
+ scales[i] = 255.0 * 127.0 / + (std::max(std::abs(max_input), std::abs(min_input)) * + std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); + } mkldnn::primitive_attr reorder_attr; - reorder_attr.set_output_scales(0, scales); - + if (depth == 1) { + reorder_attr.set_output_scales(0, scales); + } else { + reorder_attr.set_output_scales(2, scales); + } auto summand_md = summand_mkl_shape.IsMklTensor() ? summand_mkl_shape.GetMklLayout() @@ -1858,6 +1439,23 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRequantize") .TypeConstraint("out_type"), NoOp); +// Register NoOp kernel for QuantizedConv2DPerChannel. +REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DPerChannel") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type"), + NoOp); +// Register a templatized implementation of MklQuantizedConv2DPerChannel. +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedConv2DPerChannel") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DOp); + // Register a templatized implementation of MklQuantizedConv2D. REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2D") @@ -1866,7 +1464,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tfilter") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DOp); + MklQuantizedConv2DOp); REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DAndRequantize") @@ -1875,7 +1473,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tfilter") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DOp); + MklQuantizedConv2DOp); // Register NoOp kernel for QuantizedConv2DWithBias to get a python interface. // This kernel will be replaced by an MKL kernel during graph @@ -1902,7 +1500,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tfilter") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DOp); + MklQuantizedConv2DOp); REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasAndRequantize") @@ -1912,7 +1510,8 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DOp); + MklQuantizedConv2DOp); + REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasAndRequantize") .Device(DEVICE_CPU) @@ -1921,7 +1520,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DOp); + MklQuantizedConv2DOp); // Register NoOp kernel for QuantizedConv2DAndRelu to get a python interface. // This kernel will be replaced by an MKL kernel during graph-optimization pass. @@ -1947,7 +1546,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tfilter") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DReluOp); + MklQuantizedConv2DReluOp); REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DAndReluAndRequantize") @@ -1956,7 +1555,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tfilter") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DReluOp); + MklQuantizedConv2DReluOp); // Register NoOp kernel for QuantizedConv2DWithBiasAndRelu to get a python // interface. 
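The per-channel reorder scales added above follow the same 255 x 127 convention as the bias scales earlier in this file. A small standalone sketch, with hypothetical quantization ranges, of the resulting values and of the mask passed to set_output_scales:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical ranges: one activation range, two output channels.
  float min_input = -2.0f, max_input = 2.0f;
  std::vector<float> min_filter = {-0.5f, -1.0f};
  std::vector<float> max_filter = {0.5f, 0.25f};

  std::vector<float> scales(min_filter.size());
  for (size_t i = 0; i < scales.size(); ++i) {
    // 255 steps for the unsigned quint8 input, 127 for the signed qint8 filter.
    scales[i] = 255.0f * 127.0f /
                (std::max(std::abs(max_input), std::abs(min_input)) *
                 std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
  }
  // set_output_scales(mask, scales): mask 0 means one scale shared by the whole
  // tensor; a set bit k means the scales run along dimension k (bit 0 for the
  // 1-D bias, bit 1 for the channel dimension of the summand being reordered).
  std::printf("%.1f %.1f\n", scales[0], scales[1]);  // 32385.0 16192.5
  return 0;
}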
@@ -1986,7 +1585,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tfilter") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DReluOp); + MklQuantizedConv2DReluOp); // Register a templatized implementation of // MklQuantizedConv2DWithBiasAndReluAndRequantize. @@ -1998,7 +1597,8 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DReluOp); + MklQuantizedConv2DReluOp); + REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize") .Device(DEVICE_CPU) @@ -2007,7 +1607,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DReluOp); + MklQuantizedConv2DReluOp); // Register NoOp kernel for QuantizedConv2DWithBiasSumAndRelu to get a python // interface. @@ -2025,6 +1625,7 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasSumAndReluAndRequantize") .TypeConstraint("Tfilter") .TypeConstraint("out_type"), NoOp); + REGISTER_KERNEL_BUILDER( Name("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize") .Device(DEVICE_CPU) @@ -2032,7 +1633,9 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tfilter") .TypeConstraint("out_type"), NoOp); -// Register a templatized implementation of MklQuantizedConv2DWithBiasAndRelu. + +// Register a templatized implementation of +// MklQuantizedConv2DWithBiasSumAndRelu. REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasSumAndRelu") .Device(DEVICE_CPU) @@ -2040,7 +1643,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tfilter") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DSumReluOp); + MklQuantizedConv2DSumReluOp); REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize") @@ -2050,7 +1653,8 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DSumReluOp); + MklQuantizedConv2DSumReluOp); REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") @@ -2060,7 +1664,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DSumReluOp); + MklQuantizedConv2DSumReluOp); REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize") @@ -2070,7 +1674,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DSumReluOp); + MklQuantizedConv2DSumReluOp); REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") @@ -2080,60 +1684,144 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklQuantizedConv2DSumReluOp); -#endif // INTEL_MKL_ML + MklQuantizedConv2DSumReluOp); + +// Register NoOp kernels for non-fused and fused versions of +// QuantizedDepthwiseConv2D to get a Python interface. These kernels will be +// replaced by MKL kernels during the graph-optimization pass. 
+REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2D") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type"), + NoOp); + +REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2DWithBias") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type"), + NoOp); + +REGISTER_KERNEL_BUILDER(Name("QuantizedDepthwiseConv2DWithBiasAndRelu") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type"), + NoOp); + +REGISTER_KERNEL_BUILDER( + Name("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type"), + NoOp); + +// Register templatized MKL kernels for non-fused and fused-versions of +// QuantizedDepthwiseConv2D. +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedDepthwiseConv2D") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DOp); + +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedDepthwiseConv2DWithBias") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DOp); + +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedDepthwiseConv2DWithBiasAndRelu") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DReluOp); + +// Tbias -> float +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DReluOp); + +// Tbias -> qint32 +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DReluOp); // Register 2D operations -#define REGISTER_MKL_CPU_2D(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConvOp); \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConvOp); \ - REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklDummyOp); \ - REGISTER_KERNEL_BUILDER(Name("_MklPadWithConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tpaddings") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConvOp); \ - REGISTER_KERNEL_BUILDER(Name("_MklPadWithConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tpaddings") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConvOp); \ - REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tpaddings") \ - .Label(mkl_op_registry::kMklOpLabel), \ +#define REGISTER_MKL_CPU_2D(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") 
\ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklConv2DWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvOp); \ + REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklDummyOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklPadWithConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklPadWithConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvOp); \ + REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklDummyOp); TF_CALL_float(REGISTER_MKL_CPU_2D); +TF_CALL_bfloat16(REGISTER_MKL_CPU_2D); -#define REGISTER_MKL_CPU_2D_DEPTHWISE(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklDepthwiseConv2dNative") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConvOp); +#define REGISTER_MKL_CPU_2D_DEPTHWISE(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklDepthwiseConv2dNative") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvOp); TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE); +TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE); // Note we are registering _MklFusedConv2D. // We check the fused_ops attributes to decide if bias is enabled or not. @@ -2166,6 +1854,7 @@ TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE); MklDummyOp); TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED); +TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED); // Register 3D operations #define REGISTER_MKL_CPU_3D(T) \ @@ -2176,6 +1865,7 @@ TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED); .Label(mkl_op_registry::kMklOpLabel), \ MklConvOp); TF_CALL_float(REGISTER_MKL_CPU_3D); +TF_CALL_bfloat16(REGISTER_MKL_CPU_3D); } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc index 58f0c30f32b..080569bf76a 100644 --- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc +++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc @@ -67,17 +67,17 @@ class MklBinaryOp : public BinaryOp { .Label(mkl_op_registry::kMklOpLabel), \ OP>); -REGISTER5(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half, double, - int32, int64); -REGISTER7(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half, double, - int32, int64, complex64, complex128); -REGISTER5(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double, - uint8, int32); -REGISTER5(MklBinaryOp, CPU, "_MklMaximum", functor::maximum, float, Eigen::half, - double, int32, int64); -REGISTER5(MklBinaryOp, CPU, "_MklSquaredDifference", - functor::squared_difference, float, Eigen::half, double, int32, - int64); +REGISTER6(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half, double, + int32, int64, bfloat16); +REGISTER8(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half, double, + int32, int64, complex64, complex128, bfloat16); +REGISTER6(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double, + uint8, int32, bfloat16); +REGISTER6(MklBinaryOp, CPU, "_MklMaximum", functor::maximum, float, Eigen::half, + double, int32, int64, 
bfloat16); +REGISTER6(MklBinaryOp, CPU, "_MklSquaredDifference", + functor::squared_difference, float, Eigen::half, double, int32, int64, + bfloat16); #undef REGISTER #pragma pop_macro("REGISTER") diff --git a/tensorflow/core/kernels/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl_dequantize_op.cc new file mode 100644 index 00000000000..4c9dbf4274a --- /dev/null +++ b/tensorflow/core/kernels/mkl_dequantize_op.cc @@ -0,0 +1,180 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/meta_support.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/lib/core/errors.h" + +#include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/util/mkl_util.h" + +#include "mkldnn.hpp" +using mkldnn::primitive_attr; +using mkldnn::stream; + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class MklDequantizeOp : public OpKernel { + public: + explicit MklDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string mode_string; + OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string)); + OP_REQUIRES(ctx, mode_string == "SCALED", + errors::InvalidArgument( + "MklDequantizeOp only supports 'SCALED' mode, but got '" + + mode_string + "'")); + } + + void Compute(OpKernelContext* ctx) override { + try { + // Using CPU device + auto cpu_engine = engine(engine::cpu, 0); + + // Get the inputs + const Tensor& src_tensor = MklGetInput(ctx, kSrcIndex); + const float min_range = + MklGetInput(ctx, kMinIndex).template flat()(0); + const float max_range = + MklGetInput(ctx, kMaxIndex).template flat()(0); + + // Get MklShape + MklDnnShape src_mkl_shape; + GetMklShape(ctx, kSrcIndex, &src_mkl_shape); + + // src_dims is the dimension of src_tensor + // output_dims are same as src_dims + auto src_dims = src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDims(src_tensor.shape()); + auto output_dims = src_dims; + + // Create reorder memory for src and dst + MklDnnData src(&cpu_engine); + MklDnnData dst(&cpu_engine); + + // If input is in MKL layout, then simply grab input layout; otherwise, + // construct input TF layout. For TF layout, although input shape + // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's + // layout + auto src_md = + src_mkl_shape.IsMklTensor() + ? 
src_mkl_shape.GetMklLayout() + : memory::desc(src_dims, MklDnnType(), memory::format::nhwc); + + src.SetUsrMem(src_md, &src_tensor); + + Tensor* output_tensor = nullptr; + MklDnnShape output_mkl_shape; + TensorShape output_tf_shape; + + memory::primitive_desc src_pd = + memory::primitive_desc(src_md, cpu_engine); + memory::desc dst_md = src_mkl_shape.IsMklTensor() + ? src_md + : memory::desc(src_dims, MklDnnType(), + memory::format::nhwc); + memory::primitive_desc dst_pd = + memory::primitive_desc(dst_md, cpu_engine); + + // If input is MKL shape, output is also MKL shape. + // If input is TF shape, output is also TF shape. + if (src_mkl_shape.IsMklTensor()) { + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SetMklLayout(&dst_pd); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(), + src_mkl_shape.GetSizesAsMklDnnDims(), + src_mkl_shape.GetTfDataFormat()); + output_tf_shape.AddDim((dst_pd.get_size() / sizeof(float))); + } else { + output_mkl_shape.SetMklTensor(false); + output_tf_shape = MklDnnDimsToTFShape(output_dims); + } + + // Allocate MKL or TF output shape based on the above + AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape, + output_mkl_shape); + dst.SetUsrMem(dst_md, output_tensor); + + // The quantization logic here for mode SCALED is similar to the logic + // in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3. + static constexpr int num_bits = sizeof(T) * 8; + const float max_abs = std::max(std::abs(min_range), std::abs(max_range)); + bool is_signed = std::is_signed::value; + // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For + // example, if it is 8 bits, we have the range [-127, 127]. So for input + // range of [-x, x], the scale should be (2*x)/254. + // + // If it is unsigned and num_bits == 8, the range with 8 bits is [0, 255]. + // If the input range is [0, x], then the scale is x/255 instead of 254 as + // in the case above. + const int target_bits = is_signed ? 
(num_bits - 1) : num_bits; + const float target_range = + static_cast((uint64_t{1} << target_bits) - 1); + const float scale_factor = max_abs / target_range; + + std::vector scales; + scales.push_back(scale_factor); + primitive_attr attr; + attr.set_output_scales(0, scales); + attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); + mkldnn::reorder::primitive_desc reorder_pd = + mkldnn::reorder::primitive_desc(src_pd, dst_pd, attr); + + // Execute MKL-DNN primitive + std::vector net; + net.push_back( + mkldnn::reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem())); + stream(stream::kind::eager).submit(net).wait(); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + ctx, errors::Aborted("Operation received an exception:", error_msg)); + } + } + + private: + const size_t kSrcIndex = 0; + const size_t kMinIndex = 1; + const size_t kMaxIndex = 2; +}; + +REGISTER_KERNEL_BUILDER(Name("_MklDequantize") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklDequantizeOp); +REGISTER_KERNEL_BUILDER(Name("_MklDequantize") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklDequantizeOp); + +} // namespace tensorflow + +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_dequantize_op_test.cc b/tensorflow/core/kernels/mkl_dequantize_op_test.cc new file mode 100644 index 00000000000..23d59ef7ab6 --- /dev/null +++ b/tensorflow/core/kernels/mkl_dequantize_op_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +class MklDequantizeOpTest : public OpsTestBase {}; + +static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0}; +static const TensorShape dummy_shape({8}); + +TEST_F(MklDequantizeOpTest, small) { + TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "_MklDequantize") + .Input(FakeInput(DT_QUINT8)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Attr("T", DataTypeToEnum::v()) + .Attr("mode", "SCALED") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({1, 2, 2, 2}), + {0, 10, 50, 40, 25, 115, 190, 255}); + // min_range = 0 + AddInputFromArray(TensorShape({1}), {0}); + // max_range = 200 + AddInputFromArray(TensorShape({1}), {200.0f}); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&expected, + {0.0, 7.84, 39.21, 31.37, 19.6, 90.2, 149.0, 200}); + const Tensor& output = *GetOutput(0); + test::ExpectTensorNear(expected, output, 0.1); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 1ae42a0d0d7..6b6eaace8b0 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -43,7 +43,7 @@ struct MklBatchNormFwdParams { : src_dims(src_dims), depth(depth), eps(eps), training(training) {} }; -template +template class MklFusedBatchNormFwdPrimitive : public MklPrimitive { public: explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams) @@ -60,15 +60,15 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { // dst_data: output data buffer of dst // mean_data: output data buffer of means // variance_data: output data buffer of variances - void Execute(const T* src_data, const T* weights_data, T* dst_data, - T* mean_data, T* variance_data) { + void Execute(const T* src_data, const U* weights_data, T* dst_data, + U* mean_data, U* variance_data) { context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); if (context_.flags & use_scale_shift) context_.weights_mem->set_data_handle( - static_cast(const_cast(weights_data))); + static_cast(const_cast(weights_data))); if ((context_.pkind == prop_kind::forward_training) || (context_.flags & use_global_stats)) { @@ -158,19 +158,19 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData)); if (context_.flags & use_scale_shift) { - auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType(), + auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType(), 
memory::format::nc); context_.weights_mem.reset( new memory({weights_desc, cpu_engine_}, DummyData)); } if (fwdParams.training || (context_.flags & use_global_stats)) { - auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType(), + auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType(), memory::format::nc); context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); auto variance_desc = - memory::desc({1, fwdParams.depth}, MklDnnType(), memory::nc); + memory::desc({1, fwdParams.depth}, MklDnnType(), memory::nc); context_.variance_mem.reset( new memory({variance_desc, cpu_engine_}, DummyData)); } @@ -219,18 +219,18 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { engine cpu_engine_; }; -template +template class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory { public: - static MklFusedBatchNormFwdPrimitive* Get( + static MklFusedBatchNormFwdPrimitive* Get( const MklBatchNormFwdParams& fwdParams) { - auto bn_fwd = static_cast*>( - MklFusedBatchNormFwdPrimitiveFactory::GetInstance().GetBatchNormFwd( - fwdParams)); + auto bn_fwd = static_cast*>( + MklFusedBatchNormFwdPrimitiveFactory::GetInstance() + .GetBatchNormFwd(fwdParams)); if (bn_fwd == nullptr) { - bn_fwd = new MklFusedBatchNormFwdPrimitive(fwdParams); - MklFusedBatchNormFwdPrimitiveFactory::GetInstance().SetBatchNormFwd( + bn_fwd = new MklFusedBatchNormFwdPrimitive(fwdParams); + MklFusedBatchNormFwdPrimitiveFactory::GetInstance().SetBatchNormFwd( fwdParams, bn_fwd); } return bn_fwd; @@ -253,6 +253,8 @@ class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(fwdParams.depth); key_creator.AddAsKey(fwdParams.eps); key_creator.AddAsKey(fwdParams.training); + key_creator.AddAsKey(typeid(T).name()); + key_creator.AddAsKey(typeid(U).name()); return key_creator.GetKey(); } @@ -284,7 +286,7 @@ struct MklBatchNormBwdParams { training(training) {} }; -template +template class MklFusedBatchNormBwdPrimitive : public MklPrimitive { public: explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams) @@ -303,21 +305,22 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { // weights_data: input data buffer of weights // diff_src_data: output data buffer of diff_src // diff_weights_data: output data buffer of diff_weights - void Execute(const T* src_data, const T* mean_data, const T* variance_data, - const T* diff_dst_data, const T* weights_data, T* diff_src_data, - T* diff_weights_data) { + void Execute(const T* src_data, const U* mean_data, const U* variance_data, + const T* diff_dst_data, const U* weights_data, T* diff_src_data, + U* diff_weights_data) { context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.mean_mem->set_data_handle( - static_cast(const_cast(mean_data))); + static_cast(const_cast(mean_data))); context_.variance_mem->set_data_handle( - static_cast(const_cast(variance_data))); + static_cast(const_cast(variance_data))); context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); + // TODO: type for weights? 
if (context_.flags & use_scale_shift) { context_.weights_mem->set_data_handle( - static_cast(const_cast(weights_data))); + static_cast(const_cast(weights_data))); context_.diff_weights_mem->set_data_handle( static_cast(diff_weights_data)); } @@ -391,11 +394,11 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { memory::desc({bwdParams.diff_dst_dims}, MklDnnType(), get_desired_format(bwdParams.diff_dst_dims[1])); auto variance_desc = - memory::desc({1, bwdParams.depth}, MklDnnType(), memory::nc); + memory::desc({1, bwdParams.depth}, MklDnnType(), memory::nc); auto mean_desc = - memory::desc({1, bwdParams.depth}, MklDnnType(), memory::format::nc); + memory::desc({1, bwdParams.depth}, MklDnnType(), memory::format::nc); auto weights_desc = - memory::desc({2, bwdParams.depth}, MklDnnType(), memory::format::nc); + memory::desc({2, bwdParams.depth}, MklDnnType(), memory::format::nc); auto diff_weights_desc = weights_desc; // fwd desc & primitive desc @@ -443,17 +446,17 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { engine cpu_engine_; }; -template +template class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory { public: - static MklFusedBatchNormBwdPrimitive* Get( + static MklFusedBatchNormBwdPrimitive* Get( const MklBatchNormBwdParams& bwdParams) { - auto bn_bwd = static_cast*>( - MklFusedBatchNormBwdPrimitiveFactory::GetInstance().GetBatchNormBwd( - bwdParams)); + auto bn_bwd = static_cast*>( + MklFusedBatchNormBwdPrimitiveFactory::GetInstance() + .GetBatchNormBwd(bwdParams)); if (bn_bwd == nullptr) { - bn_bwd = new MklFusedBatchNormBwdPrimitive(bwdParams); - MklFusedBatchNormBwdPrimitiveFactory::GetInstance().SetBatchNormBwd( + bn_bwd = new MklFusedBatchNormBwdPrimitive(bwdParams); + MklFusedBatchNormBwdPrimitiveFactory::GetInstance().SetBatchNormBwd( bwdParams, bn_bwd); } return bn_bwd; @@ -477,6 +480,8 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(bwdParams.depth); key_creator.AddAsKey(bwdParams.eps); key_creator.AddAsKey(bwdParams.training); + key_creator.AddAsKey(typeid(T).name()); + key_creator.AddAsKey(typeid(U).name()); return key_creator.GetKey(); } @@ -492,14 +497,14 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory { } }; -template +template class MklFusedBatchNormOp : public OpKernel { public: explicit MklFusedBatchNormOp(OpKernelConstruction* context) : OpKernel(context) { float epsilon; OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); - epsilon_ = T(epsilon); + epsilon_ = epsilon; string tensor_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format)); OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_), @@ -593,7 +598,7 @@ class MklFusedBatchNormOp : public OpKernel { SetMeanVariance(est_mean_tensor, est_variance_tensor); MklDnnData src(&cpu_engine); - MklDnnData weights(&cpu_engine); + MklDnnData weights(&cpu_engine); memory::format format_m; if (dnn_shape_src.IsMklTensor()) { @@ -618,28 +623,28 @@ class MklFusedBatchNormOp : public OpKernel { // MKL-DNN packs scale & shift as "weights": // ...... 
- weights.AllocateBuffer(2 * depth_ * sizeof(T)); - T* weights_data = reinterpret_cast(weights.GetAllocatedBuffer()); - const T* scale_tf = scale_tensor.flat().data(); - const T* shift_tf = shift_tensor.flat().data(); + weights.AllocateBuffer(2 * depth_ * sizeof(U)); + U* weights_data = reinterpret_cast(weights.GetAllocatedBuffer()); + const U* scale_tf = scale_tensor.flat().data(); + const U* shift_tf = shift_tensor.flat().data(); - std::memcpy(weights_data, scale_tf, depth_ * sizeof(T)); - std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(T)); + std::memcpy(weights_data, scale_tf, depth_ * sizeof(U)); + std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(U)); char* saved_mean_data_tf = - reinterpret_cast(saved_mean_tensor->flat().data()); + reinterpret_cast(saved_mean_tensor->flat().data()); std::memcpy(saved_mean_data_tf, reinterpret_cast(mean_values_), - depth_ * sizeof(T)); + depth_ * sizeof(U)); char* saved_variance_data_tf = - reinterpret_cast(saved_variance_tensor->flat().data()); + reinterpret_cast(saved_variance_tensor->flat().data()); std::memcpy(saved_variance_data_tf, reinterpret_cast(variance_values_), - depth_ * sizeof(T)); + depth_ * sizeof(U)); // get batchnorm op from the pool MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_); - MklFusedBatchNormFwdPrimitive* bn_fwd = - MklFusedBatchNormFwdPrimitiveFactory::Get(fwdParams); + MklFusedBatchNormFwdPrimitive* bn_fwd = + MklFusedBatchNormFwdPrimitiveFactory::Get(fwdParams); // check if reorder is needed for src, weights, mean, variance const T* src_data = src_tensor.flat().data(); @@ -669,9 +674,9 @@ class MklFusedBatchNormOp : public OpKernel { AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, dnn_shape_dst); - T* weights_op_data = weights_data; - T* mean_op_data = saved_mean_tensor->flat().data(); - T* variance_op_data = saved_variance_tensor->flat().data(); + U* weights_op_data = weights_data; + U* mean_op_data = saved_mean_tensor->flat().data(); + U* variance_op_data = saved_variance_tensor->flat().data(); T* dst_data = dst_tensor->flat().data(); // execution @@ -679,10 +684,10 @@ class MklFusedBatchNormOp : public OpKernel { variance_op_data); // copy batch_mean data - T* batch_mean_data_tf = batch_mean_tensor->flat().data(); + U* batch_mean_data_tf = batch_mean_tensor->flat().data(); std::memcpy(reinterpret_cast(batch_mean_data_tf), reinterpret_cast(saved_mean_data_tf), - depth_ * sizeof(T)); + depth_ * sizeof(U)); // TODO(yli135): OpMem is same as usr mem since // since its format is hard-coded as nc when primitive is created. 
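A minimal sketch of the gamma/beta packing performed above, assuming U = float and a made-up channel count; it only illustrates the {2, depth} "weights" layout MKL-DNN consumes:

#include <cstring>
#include <vector>

int main() {
  const int depth = 4;                                   // channels (hypothetical)
  std::vector<float> scale = {1.0f, 2.0f, 3.0f, 4.0f};   // gamma
  std::vector<float> shift = {0.1f, 0.2f, 0.3f, 0.4f};   // beta

  // Packed as [scale_0 .. scale_{depth-1}, shift_0 .. shift_{depth-1}],
  // i.e. one {2, depth} buffer in nc format.
  std::vector<float> weights(2 * depth);
  std::memcpy(weights.data(), scale.data(), depth * sizeof(float));
  std::memcpy(weights.data() + depth, shift.data(), depth * sizeof(float));
  return 0;
}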
@@ -694,14 +699,15 @@ class MklFusedBatchNormOp : public OpKernel { adjust_factor = (static_cast(orig_size)) / adjust_size; } - auto variance_data = reinterpret_cast(saved_variance_data_tf); - auto batch_variance_data = batch_variance_tensor->flat().data(); + auto variance_data = reinterpret_cast(saved_variance_data_tf); + auto batch_variance_data = batch_variance_tensor->flat().data(); if (is_training_) { for (int k = 0; k < depth_; k++) { - batch_variance_data[k] = variance_data[k] * adjust_factor; + batch_variance_data[k] = + variance_data[k] * static_cast(adjust_factor); } } else { - std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(T)); + std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(U)); } } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + @@ -714,11 +720,11 @@ class MklFusedBatchNormOp : public OpKernel { } private: - T epsilon_; + float epsilon_; TensorFormat tensor_format_; bool is_training_; - T* mean_values_; - T* variance_values_; + U* mean_values_; + U* variance_values_; size_t depth_; // batch normalization is done for per channel. engine cpu_engine = engine(engine::cpu, 0); @@ -728,9 +734,9 @@ class MklFusedBatchNormOp : public OpKernel { } void SetMeanVariance(const Tensor& mean, const Tensor& variance) { - mean_values_ = reinterpret_cast(const_cast(mean.flat().data())); + mean_values_ = reinterpret_cast(const_cast(mean.flat().data())); variance_values_ = - reinterpret_cast(const_cast(variance.flat().data())); + reinterpret_cast(const_cast(variance.flat().data())); } void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, @@ -778,8 +784,8 @@ class MklFusedBatchNormOp : public OpKernel { CHECK_NOTNULL(*batch_mean_tensor); // set NAN mean value in case of empty input tensor int num_elements = tf_shape_scale.num_elements(); - auto batch_mean_data = (*batch_mean_tensor)->flat().data(); - std::fill_n(batch_mean_data, num_elements, NAN); + auto batch_mean_data = (*batch_mean_tensor)->flat().data(); + std::fill_n(batch_mean_data, num_elements, static_cast(NAN)); // allocate batch variance output tensor MklDnnShape mkl_shape_batch_variance; @@ -789,8 +795,8 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_batch_variance); CHECK_NOTNULL(*batch_variance_tensor); // set NAN variance value in case of empty input tensor - auto batch_variance_data = (*batch_variance_tensor)->flat().data(); - std::fill_n(batch_variance_data, num_elements, NAN); + auto batch_variance_data = (*batch_variance_tensor)->flat().data(); + std::fill_n(batch_variance_data, num_elements, static_cast(NAN)); // Mean and variance (without Bessel's correction) saved for backward // computation to serve as pre-computed mean and variance. 
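The adjust_factor applied above converts the population variance returned by MKL-DNN into the Bessel-corrected batch variance during training. A numeric sketch under the assumption that adjust_size is N - 1 (adjust_size itself is set in unchanged context not shown in this hunk):

#include <cstdio>

int main() {
  const long long orig_size = 32LL * 28 * 28;   // hypothetical elements per channel
  const double adjust_size = orig_size - 1.0;   // assumed N - 1
  const double adjust_factor = static_cast<double>(orig_size) / adjust_size;

  const float saved_variance = 0.25f;           // population variance, no correction
  const float batch_variance =
      saved_variance * static_cast<float>(adjust_factor);  // ~0.250010
  std::printf("%.6f\n", batch_variance);
  return 0;
}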
@@ -800,8 +806,8 @@ class MklFusedBatchNormOp : public OpKernel { tf_shape_scale, mkl_shape_saved_mean); CHECK_NOTNULL(*saved_mean_tensor); // set NAN mean value in case of empty input tensor - auto saved_mean_data = (*saved_mean_tensor)->flat().data(); - std::fill_n(saved_mean_data, num_elements, NAN); + auto saved_mean_data = (*saved_mean_tensor)->flat().data(); + std::fill_n(saved_mean_data, num_elements, static_cast(NAN)); MklDnnShape mkl_shape_saved_variance; mkl_shape_saved_variance.SetMklTensor(false); @@ -810,19 +816,19 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_saved_variance); CHECK_NOTNULL(*saved_variance_tensor); // set NAN variance value in case of empty input tensor - auto saved_variance_data = (*saved_variance_tensor)->flat().data(); - std::fill_n(saved_variance_data, num_elements, NAN); + auto saved_variance_data = (*saved_variance_tensor)->flat().data(); + std::fill_n(saved_variance_data, num_elements, static_cast(NAN)); } }; -template +template class MklFusedBatchNormGradOp : public OpKernel { public: explicit MklFusedBatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) { float epsilon; OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); - epsilon_ = T(epsilon); + epsilon_ = epsilon; string tensor_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format)); OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_), @@ -918,8 +924,8 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnData src(&cpu_engine); MklDnnData diff_dst(&cpu_engine); - MklDnnData weights(&cpu_engine); - MklDnnData diff_weights(&cpu_engine); + MklDnnData weights(&cpu_engine); + MklDnnData diff_weights(&cpu_engine); memory::dims src_dims = dnn_shape_src.IsMklTensor() @@ -943,20 +949,20 @@ class MklFusedBatchNormGradOp : public OpKernel { // weights -- MKL DNN packs scales/ shifts as weights in order // of scale, ..., scale, shift, ...., shift - weights.AllocateBuffer(2 * depth_ * sizeof(T)); - T* weights_data_tf = reinterpret_cast(weights.GetAllocatedBuffer()); - const T* scale_tf = scale_tensor.flat().data(); + weights.AllocateBuffer(2 * depth_ * sizeof(U)); + U* weights_data_tf = reinterpret_cast(weights.GetAllocatedBuffer()); + const U* scale_tf = scale_tensor.flat().data(); for (int k = 0; k < depth_; k++) { weights_data_tf[k] = scale_tf[k]; - weights_data_tf[k + depth_] = 0; + weights_data_tf[k + depth_] = static_cast(0); } - diff_weights.AllocateBuffer(2 * depth_ * sizeof(T)); + diff_weights.AllocateBuffer(2 * depth_ * sizeof(U)); MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, is_training_); - MklFusedBatchNormBwdPrimitive* bn_bwd = - MklFusedBatchNormBwdPrimitiveFactory::Get(bwdParams); + MklFusedBatchNormBwdPrimitive* bn_bwd = + MklFusedBatchNormBwdPrimitiveFactory::Get(bwdParams); // check if src/diff_dst need to be reordered const T* src_data = src_tensor.flat().data(); @@ -1001,13 +1007,13 @@ class MklFusedBatchNormGradOp : public OpKernel { AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, tf_shape_diff_src, dnn_shape_diff_src); - T* mean_data = - static_cast(const_cast(saved_mean_tensor.flat().data())); - T* variance_data = static_cast( - const_cast(saved_variance_tensor.flat().data())); - T* weights_data = weights_data_tf; + U* mean_data = + static_cast(const_cast(saved_mean_tensor.flat().data())); + U* variance_data = static_cast( + const_cast(saved_variance_tensor.flat().data())); + U* weights_data = weights_data_tf; T* diff_src_data = 
static_cast(diff_src_tensor->flat().data()); - T* diff_weights_data = static_cast(diff_weights.GetAllocatedBuffer()); + U* diff_weights_data = static_cast(diff_weights.GetAllocatedBuffer()); // Execute bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data, weights_data, diff_src_data, diff_weights_data); @@ -1019,14 +1025,14 @@ class MklFusedBatchNormGradOp : public OpKernel { &diff_shift_tensor); // copy data: diff_scale and diff_shift - auto diff_scale_data = diff_scale_tensor->flat().data(); - auto diff_shift_data = diff_shift_tensor->flat().data(); + auto diff_scale_data = diff_scale_tensor->flat().data(); + auto diff_shift_data = diff_shift_tensor->flat().data(); std::memcpy(reinterpret_cast(diff_scale_data), reinterpret_cast(diff_weights_data), - depth_ * sizeof(T)); + depth_ * sizeof(U)); std::memcpy(reinterpret_cast(diff_shift_data), reinterpret_cast(diff_weights_data + depth_), - depth_ * sizeof(T)); + depth_ * sizeof(U)); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -1038,7 +1044,7 @@ class MklFusedBatchNormGradOp : public OpKernel { } private: - T epsilon_; + float epsilon_; TensorFormat tensor_format_; int depth_; // batch normalization is done for per channel. bool is_training_; @@ -1059,7 +1065,8 @@ class MklFusedBatchNormGradOp : public OpKernel { AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor, tf_shape_src, dnn_shape_diff_src); auto diff_src_data = (*diff_src_tensor)->flat().data(); - std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), 0); + std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), + static_cast(0)); Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; @@ -1085,18 +1092,18 @@ class MklFusedBatchNormGradOp : public OpKernel { AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, tf_shape_scale_shift, mkl_shape_diff_scale); CHECK_NOTNULL(*diff_scale_tensor); - auto diff_scale_data = (*diff_scale_tensor)->flat().data(); + auto diff_scale_data = (*diff_scale_tensor)->flat().data(); std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(), - 0); + static_cast(0)); MklDnnShape mkl_shape_diff_shift; mkl_shape_diff_shift.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, tf_shape_scale_shift, mkl_shape_diff_shift); CHECK_NOTNULL(*diff_shift_tensor); - auto diff_shift_data = (*diff_shift_tensor)->flat().data(); + auto diff_shift_data = (*diff_shift_tensor)->flat().data(); std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(), - 0); + static_cast(0)); // Placeholders for estimated_mean and estimated_variance, which are // used for inference and thus not needed here for gradient computation. 
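Throughout these hunks the batch-norm kernels gain a second template parameter: activations flow through in T while scale/shift/mean/variance stay in U = float, which is what the V2 registrations below pair up. A minimal sketch of that split (the names here are illustrative, not TensorFlow's):

// Activations in T (float or bfloat16), statistics and affine parameters in U.
template <typename T, typename U>
struct BatchNormArgs {
  const T* src;        // input activations
  const U* scale;      // gamma, kept in float even when T is bfloat16
  const U* shift;      // beta
  U* batch_mean;       // computed statistics
  U* batch_variance;
  T* dst;              // output activations, same type as the input
};
// e.g. BatchNormArgs<bfloat16, float> mirrors
// REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float).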
@@ -1112,23 +1119,52 @@ class MklFusedBatchNormGradOp : public OpKernel { memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); } }; -#define REGISTER_MKL_CPU(T) \ +#define REGISTER_MKL_FUSED_BATCHNORM_CPU(T) \ REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNorm") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ - MklFusedBatchNormOp); -TF_CALL_float(REGISTER_MKL_CPU); -#undef REGISTER_MKL_CPU + MklFusedBatchNormOp); -#define REGISTER_MKL_CPU(T) \ +TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU); +TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU); +#undef REGISTER_MKL_FUSED_BATCHNORM_CPU + +#define REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(T, U) \ + REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklFusedBatchNormOp); + +REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float); +REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float); +#undef REGISTER_MKL_FUSED_BATCHNORM_V2_CPU + +#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU(T) \ REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGrad") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklOpLabel), \ - MklFusedBatchNormGradOp); -TF_CALL_float(REGISTER_MKL_CPU); -#undef REGISTER_MKL_CPU + MklFusedBatchNormGradOp); + +TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); +TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); +#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU + +#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(T, U) \ + REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGradV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklFusedBatchNormGradOp); + +REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float); +REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float); +#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc index c1f6fa3fd0a..f9f58416a54 100644 --- a/tensorflow/core/kernels/mkl_identity_op.cc +++ b/tensorflow/core/kernels/mkl_identity_op.cc @@ -60,6 +60,7 @@ class MklIdentityOp : public OpKernel { MklIdentityOp); TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_bfloat16(REGISTER_MKL_CPU); #undef REGISTER_MKL_CPU } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc index ab89fe7d841..f14b811b341 100644 --- a/tensorflow/core/kernels/mkl_input_conversion_op.cc +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -306,6 +306,7 @@ class MklInputConversionOp : public OpKernel { // not support types. // TF_CALL_NUMBER_TYPES(REGISTER_CPU); TF_CALL_float(REGISTER_CPU); +TF_CALL_bfloat16(REGISTER_CPU); #undef REGISTER_CPU } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index f4788f48519..766a3ea907c 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -80,7 +80,7 @@ class MklMatMulOp : public OpKernel { return; } - if (a.NumElements() == 0 || b.NumElements() == 0) { + if (a.NumElements() == 0 && b.NumElements() == 0) { // If a has shape [x, 0] and b has shape [0, y], the // output shape is [x, y] where x and y are non-zero, so we fill // the output with zeros. 
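For the MklMatMul change just above: with the new condition the zero-fill path is taken only when both operands are empty, matching the [x, 0] x [0, y] case the comment describes. A worked example with hypothetical shapes showing why that product is a block of zeros:

#include <vector>

int main() {
  const int m = 3, k = 0, n = 4;           // a: [3, 0], b: [0, 4]
  std::vector<float> a(m * k), b(k * n);   // both hold zero elements
  std::vector<float> c(m * n, 0.0f);       // output is [3, 4]

  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p)          // k == 0: empty sum, c stays zero
        c[i * n + j] += a[i * k + p] * b[p * n + j];
  return 0;
}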
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index cb494f6c3ec..0e30eb53550 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -16,15 +16,17 @@ limitations under the License. // See docs in ../ops/nn_ops.cc. #ifdef INTEL_MKL #define EIGEN_USE_THREADS + +#include + +#include "mkldnn.hpp" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/mkl_pooling_ops_common.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" -#ifndef INTEL_MKL_ML_ONLY -#include -#include "mkldnn.hpp" using mkldnn::algorithm; using mkldnn::engine; using mkldnn::error; @@ -33,471 +35,11 @@ using mkldnn::padding_kind; using mkldnn::pooling_backward; using mkldnn::pooling_forward; using mkldnn::prop_kind; -#endif namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -// MKL-DNN is now default. MKL-ML must be specified explicitly. -#ifdef INTEL_MKL_ML_ONLY - -// An implementation of MaxPooling (forward). -template -class MklMaxPoolingOp : public OpKernel { - public: - explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) { - string data_format; - - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); - OP_REQUIRES(context, ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 4, - errors::InvalidArgument("Sliding window stride field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, - errors::Unimplemented("Pooling is not yet supported on the " - "batch dimension.")); - - workspace_enabled_ = false; - // We may not get this attribute for this node if it does not go through - // graph rewrite pass. So we do not check for error while retrieving this - // attribute value. 
- OP_REQUIRES_OK(context, - context->GetAttr("workspace_enabled", &workspace_enabled_)); - } - - void Compute(OpKernelContext* context) override { - MklMaxPoolingOpContext mkl_context; - // Get the input tensor - const Tensor& tensor_in = MklGetInput(context, 0); - GetMklShape(context, 0, &mkl_context.input_shape); - bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); - - mkl_context.params.in_dim = 4; - MklPoolParameters pool_params; - if (input_in_mkl_format == false) { - pool_params.Init(context, ksize_, stride_, padding_, data_format_, - tensor_in.shape()); - OP_REQUIRES( - context, (pool_params.depth_window == 1), - errors::Unimplemented("Depthwise max pooling not supported by MKL")); - - } else { - pool_params.Init(context, ksize_, stride_, padding_, data_format_, - &mkl_context.input_shape); - } - - // Extract the parameters for the op from the pooling specs - - ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); - - mkl_context.MklCreateLayoutsAndPrimitives(context); - OP_REQUIRES_OK(context, context->status()); - - // Declare output tensor - TensorShape tensor_out_shape; - MklShape mkl_out_shape, mkl_workspace_shape; - mkl_out_shape.SetMklTensor(true); - mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst); - mkl_out_shape.SetTfLayout(mkl_context.params.in_dim, - mkl_context.params.out_sizes, - mkl_context.params.out_strides); - mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); - - Tensor* output_tensor = nullptr; - tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast( - mkl_out_shape.GetMklLayout())) / - sizeof(T)); - AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape, - mkl_out_shape); - - Tensor* workspace_tensor; - void* workspace_buf = nullptr; - - TensorShape workspace_shape; - mkl_workspace_shape.SetMklTensor(false); - workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast( - mkl_context.lt_workspace)) / - sizeof(T)); - AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape, - mkl_workspace_shape); - - mkl_context.pooling_res[dnnResourceWorkspace] = const_cast( - static_cast(workspace_tensor->flat().data())); - mkl_context.pooling_res[dnnResourceSrc] = - const_cast(static_cast(tensor_in.flat().data())); - mkl_context.pooling_res[dnnResourceDst] = const_cast( - static_cast(output_tensor->flat().data())); - - CHECK_EQ( - dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res), - E_SUCCESS); - - mkl_context.MklCleanup(); - } - - private: - typedef struct { - MklPoolingOpParams params; - MklShape input_shape; - void* pooling_res[dnnResourceNumber]; - dnnPrimitive_t prim_pooling_fwd = nullptr; - dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr; - - void MklCreateLayoutsAndPrimitives(OpKernelContext* context) { - bool input_in_mkl_format = input_shape.IsMklTensor(); - // Create or use existing DNN user layout - if (input_in_mkl_format == false) { - CHECK_EQ(dnnLayoutCreate_F32(<_user_input, params.in_dim, - params.in_sizes, params.in_strides), - E_SUCCESS); - } else { - lt_user_input = (dnnLayout_t)input_shape.GetCurLayout(); - } - - dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax; - dnnPrimitiveAttributes_t primAttr = nullptr; - - // Create DNN primitives - CHECK_EQ(dnnPoolingCreateForward_F32( - &prim_pooling_fwd, primAttr, algorithm, lt_user_input, - params.kernel_size, params.kernel_stride, params.in_offset, - dnnBorderZerosAsymm), - E_SUCCESS); - - // Creates layout for the workspace - 
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, prim_pooling_fwd, - dnnResourceWorkspace), - E_SUCCESS); - } - - void MklCleanup() { - bool input_in_mkl_format = input_shape.IsMklTensor(); - CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); - if (!input_in_mkl_format) { - CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS); - } - CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS); - } - } MklMaxPoolingOpContext; - - std::vector ksize_; - std::vector stride_; - Padding padding_; - TensorFormat data_format_; - bool workspace_enabled_; -}; - -// The operation to compute MaxPool gradients. -// It takes three inputs: -// - The original input tensor -// - The original output tensor -// - Backprop tensor for output -// It produces one output: backprop tensor for input. -template -class MklMaxPoolingGradOp : public OpKernel { - public: - explicit MklMaxPoolingGradOp(OpKernelConstruction* context) - : OpKernel(context) { - string data_format; - - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); - OP_REQUIRES(context, ksize_.size() == 4, - errors::InvalidArgument("Sliding window ksize field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, - errors::Unimplemented( - "Pooling is not yet supported on the batch dimension.")); - workspace_enabled_ = false; - // We may not get this attribute for this node if it does not go through - // graph rewrite pass. So we do not check for error while retrieving this - // attribute value. - OP_REQUIRES_OK(context, - context->GetAttr("workspace_enabled", &workspace_enabled_)); - } - - void Compute(OpKernelContext* context) override { - MklMaxPoolingGradOpContext mkl_context; - // Input - The original input tensor - const Tensor& tensor_in = MklGetInput(context, 0); - - // Output - Backprop tensor for input. 
- Tensor* output_tensor = nullptr; - - GetMklShape(context, 0, &mkl_context.input_shape); - GetMklShape(context, 2, &mkl_context.output_backprop_shape); - bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); - - if (input_in_mkl_format == false) - mkl_context.params.in_dim = tensor_in.dims(); - else - mkl_context.params.in_dim = mkl_context.input_shape.GetDimension(); - - MklPoolParameters pool_params; - if (input_in_mkl_format == false) { - pool_params.Init(context, ksize_, stride_, padding_, data_format_, - tensor_in.shape()); - OP_REQUIRES( - context, (pool_params.depth_window == 1), - errors::Unimplemented("Depthwise max pooling not supported by MKL")); - - } else { - pool_params.Init(context, ksize_, stride_, padding_, data_format_, - &mkl_context.input_shape); - } - - // Extract the parameters for the op from the pooling specs - ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); - - mkl_context.MklCreateLayouts(context); - OP_REQUIRES_OK(context, context->status()); - - mkl_context.MklCreatePrimitives(context, workspace_enabled_); - OP_REQUIRES_OK(context, context->status()); - - mkl_context.MklPrepareInputs(context, workspace_enabled_); - OP_REQUIRES_OK(context, context->status()); - - // Create shape for the input back prop output - TensorShape mkl_input_backprop; - MklShape mkl_output_shape; - mkl_output_shape.SetMklTensor(true); - mkl_output_shape.SetMklLayout(mkl_context.prim_pooling_bwd, - dnnResourceDiffSrc); - mkl_output_shape.SetTfLayout(mkl_context.params.in_dim, - mkl_context.params.in_sizes, - mkl_context.params.in_strides); - mkl_output_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); - - mkl_input_backprop.AddDim( - dnnLayoutGetMemorySize_F32( - static_cast(mkl_output_shape.GetMklLayout())) / - sizeof(T)); - AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop, - mkl_output_shape); - mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast( - static_cast(output_tensor->flat().data())); - - CHECK_EQ( - dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res), - E_SUCCESS); - - mkl_context.MklCleanup(workspace_enabled_); - } - - private: - typedef struct { - MklPoolingOpParams params; - MklShape input_shape, output_backprop_shape; - void* pooling_resfwd[dnnResourceNumber]; - void* pooling_res[dnnResourceNumber]; - dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr, - convert_input = nullptr, convert_outbackprop = nullptr; - dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr, - lt_input_user = nullptr, lt_input_prim = nullptr; - void* input_buf; - void* outbackprop_buf; - Tensor tmp_output_buf_tensor; - Tensor workspace_buf_tensor; - Tensor input_buf_tensor, outbackprop_buf_tensor; - - void MklCreateLayouts(OpKernelContext* context) { - bool input_in_mkl_format = input_shape.IsMklTensor(); - bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); - // Create DNN user layout for input and outbackprop or get existing layout - if (input_in_mkl_format == false) { - CHECK_EQ(dnnLayoutCreate_F32(<_input_user, params.in_dim, - params.in_sizes, params.in_strides), - E_SUCCESS); - } else { - lt_input_user = (dnnLayout_t)input_shape.GetCurLayout(); - } - - // We don't care about the output layout for now as we can create it from - // primitives for the max pooling fwd prop - if (outbackprop_in_mkl_format == false) { - CHECK_EQ(dnnLayoutCreate_F32(<_outbackprop_user, params.in_dim, - params.out_sizes, params.out_strides), - E_SUCCESS); - } else { 
- lt_outbackprop_user = (dnnLayout_t)output_backprop_shape.GetCurLayout(); - } - } - - // Create DNN primitives - void MklCreatePrimitives(OpKernelContext* context, bool workspace_enabled) { - dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax; - dnnPrimitiveAttributes_t primAttr = nullptr; - - if (workspace_enabled == false) { - CHECK_EQ(dnnPoolingCreateForward_F32( - &prim_pooling_fwd, primAttr, algorithm, lt_input_user, - params.kernel_size, params.kernel_stride, params.in_offset, - dnnBorderZerosAsymm), - E_SUCCESS); - } - - CHECK_EQ(dnnPoolingCreateBackward_F32( - &prim_pooling_bwd, primAttr, algorithm, lt_input_user, - params.kernel_size, params.kernel_stride, params.in_offset, - dnnBorderZerosAsymm), - E_SUCCESS); - - // Creates conversions - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( - <_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst), - E_SUCCESS); - - if (workspace_enabled == false) { - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( - <_input_prim, prim_pooling_fwd, dnnResourceSrc), - E_SUCCESS); - if (!dnnLayoutCompare_F32(lt_input_user, lt_input_prim)) { - CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input_user, - lt_input_prim), - E_SUCCESS); - AllocTmpBuffer(context, &input_buf_tensor, lt_input_prim, &input_buf); - } - } - - if (!dnnLayoutCompare_F32(lt_outbackprop_user, lt_outbackprop_prim)) { - CHECK_EQ( - dnnConversionCreate_F32(&convert_outbackprop, lt_outbackprop_user, - lt_outbackprop_prim), - E_SUCCESS); - AllocTmpBuffer(context, &outbackprop_buf_tensor, lt_outbackprop_prim, - &outbackprop_buf); - } - } - - // Compare incoming tensor layouts with MKL preferred layouts and convert - // data to the preferred layout if necessary - void MklPrepareInputs(OpKernelContext* context, bool workspace_enabled) { - const Tensor& tensor_in = MklGetInput(context, 0); - const Tensor& out_backprop = MklGetInput(context, 2); - bool input_in_mkl_format = input_shape.IsMklTensor(); - bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); - - void* tmp_output_buf = nullptr; - void* workspace_buf = nullptr; - - if (workspace_enabled == false) { - if (convert_input != nullptr) { - if (input_in_mkl_format == false) { - CHECK_EQ(dnnConversionExecute_F32( - convert_input, - const_cast(static_cast( - tensor_in.flat().data())), - input_buf), - E_SUCCESS); - CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS); - convert_input = nullptr; - } else { - input_shape.GetConvertedFlatData( - lt_input_prim, - const_cast( - static_cast(tensor_in.flat().data())), - input_buf); - } - pooling_resfwd[dnnResourceSrc] = input_buf; - } else { - pooling_resfwd[dnnResourceSrc] = const_cast( - static_cast(tensor_in.flat().data())); - } - - dnnLayout_t lt_workspace; - CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( - <_workspace, prim_pooling_fwd, dnnResourceWorkspace), - E_SUCCESS); - AllocTmpBuffer(context, &workspace_buf_tensor, lt_workspace, - &workspace_buf); - pooling_resfwd[dnnResourceWorkspace] = workspace_buf; - - dnnLayoutDelete_F32(lt_workspace); - - // We create the layout for max pooling fwd prop tmp output here - AllocTmpBuffer(context, &tmp_output_buf_tensor, lt_outbackprop_prim, - &tmp_output_buf); - pooling_resfwd[dnnResourceDst] = tmp_output_buf; - - CHECK_EQ(dnnExecute_F32(prim_pooling_fwd, pooling_resfwd), E_SUCCESS); - pooling_res[dnnResourceWorkspace] = - pooling_resfwd[dnnResourceWorkspace]; - } else { - const Tensor& workspace = MklGetInput(context, 3); - pooling_res[dnnResourceWorkspace] = const_cast( - static_cast(workspace.flat().data())); - } - - // Out backprop 
conversions if needed - if (convert_outbackprop != nullptr) { - if (outbackprop_in_mkl_format == false) { - CHECK_EQ(dnnConversionExecute_F32( - convert_outbackprop, - const_cast(static_cast( - out_backprop.flat().data())), - outbackprop_buf), - E_SUCCESS); - CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS); - } else { - output_backprop_shape.GetConvertedFlatData( - lt_outbackprop_prim, - const_cast( - static_cast(out_backprop.flat().data())), - outbackprop_buf); - } - pooling_res[dnnResourceDiffDst] = outbackprop_buf; - } else { - pooling_res[dnnResourceDiffDst] = const_cast( - static_cast(out_backprop.flat().data())); - } - } - - void MklCleanup(bool workspace_enabled) { - bool input_in_mkl_format = input_shape.IsMklTensor(); - bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); - if (workspace_enabled == false) { - CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); - } - CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS); - if (outbackprop_in_mkl_format == false) { - CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user), E_SUCCESS); - } - CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim), E_SUCCESS); - if (input_in_mkl_format == false) { - CHECK_EQ(dnnLayoutDelete_F32(lt_input_user), E_SUCCESS); - } - if (workspace_enabled == false) { - CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim), E_SUCCESS); - } - } - } MklMaxPoolingGradOpContext; - - std::vector ksize_; - std::vector stride_; - Padding padding_; - TensorFormat data_format_; - - bool workspace_enabled_; -}; // MklMaxPoolingGradOp - -#else - // An implementation of MaxPooling (forward). template class MklMaxPoolingOp : public MklPoolingForwardOpBase { @@ -863,25 +405,35 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { } }; // MklMaxPoolingGradOp -REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3D") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklMaxPoolingOp); +#define REGISTER_MKL_MAXPOOL3D_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklMaxPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklMaxPoolingGradOp); -REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3DGrad") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklMaxPoolingGradOp); +TF_CALL_float(REGISTER_MKL_MAXPOOL3D_KERNELS); +TF_CALL_bfloat16(REGISTER_MKL_MAXPOOL3D_KERNELS); -#endif // INTEL_MKL_ML_ONLY +#define REGISTER_MKL_MAXPOOL_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklMaxPool") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklMaxPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklMaxPoolingGradOp); -REGISTER_KERNEL_BUILDER(Name("_MklMaxPool") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - MklMaxPoolingOp); +TF_CALL_float(REGISTER_MKL_MAXPOOL_KERNELS); +TF_CALL_bfloat16(REGISTER_MKL_MAXPOOL_KERNELS); REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool") .Device(DEVICE_CPU) @@ -895,11 +447,5 @@ REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool") .Label(mkl_op_registry::kMklQuantizedOpLabel), MklMaxPoolingOp); -REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .Label(mkl_op_registry::kMklOpLabel), - 
MklMaxPoolingGradOp); - } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index a8d1dffd4e5..30f7b3f38f7 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -16,16 +16,15 @@ limitations under the License. #ifdef INTEL_MKL #include "tensorflow/core/kernels/mkl_pooling_ops_common.h" + #include #include + #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/common_shape_fns.h" namespace tensorflow { - -#ifndef INTEL_MKL_ML_ONLY - using mkldnn::pooling_avg; using mkldnn::pooling_avg_exclude_padding; using mkldnn::pooling_avg_include_padding; @@ -129,6 +128,7 @@ void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, template class MklPoolingFwdPrimitive; template class MklPoolingFwdPrimitive; template class MklPoolingFwdPrimitive; +template class MklPoolingFwdPrimitive; template void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { @@ -217,8 +217,7 @@ void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, } template class MklPoolingBwdPrimitive; - -#endif +template class MklPoolingBwdPrimitive; // Initialization for TensorFlow format void MklPoolParameters::Init(OpKernelContext* context, @@ -247,22 +246,6 @@ void MklPoolParameters::Init(OpKernelContext* context, Init(context, ksize, stride, padding, data_format); } -#ifdef INTEL_MKL_ML_ONLY -// Initialization for MKL format -void MklPoolParameters::Init(OpKernelContext* context, - const std::vector& ksize, - const std::vector& stride, Padding padding, - TensorFormat data_format, - const MklShape* mklInputShape) { - // Get the input sizes - depth = mklInputShape->GetSizes()[2]; - tensor_in_cols = mklInputShape->GetSizes()[0]; - tensor_in_rows = mklInputShape->GetSizes()[1]; - tensor_in_batch = mklInputShape->GetSizes()[3]; - - Init(context, ksize, stride, padding, data_format); -} -#else // Initialization for MKL format void MklPoolParameters::Init(OpKernelContext* context, const std::vector& ksize, @@ -287,7 +270,7 @@ void MklPoolParameters::Init(OpKernelContext* context, Init(context, ksize, stride, padding, data_format); } -#endif // INTEL_MKL_ML_ONLY + // Common Initialization for TensorFlow and MKL formats void MklPoolParameters::Init(OpKernelContext* context, const std::vector& ksize, @@ -355,7 +338,7 @@ void MklPoolParameters::Init(OpKernelContext* context, OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( tensor_in_cols, window_cols, col_stride, padding, &out_width, &pad_left, &pad_right)); -#ifndef INTEL_MKL_ML_ONLY + // TF can work with int64, but mkldnn only supports int32 // Fail if the depth, height or width are greater than MAX_INT // We check depth only for 3D pooling case @@ -373,7 +356,7 @@ void MklPoolParameters::Init(OpKernelContext* context, OP_REQUIRES(context, FastBoundsCheck(out_width, std::numeric_limits::max()), errors::InvalidArgument("output width is too large")); -#endif + out_depth = depth; // output will have the same depth as the input } else { // we are pooling in the depth dimension // Our current version of depthwise max pooling does not support diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index 6e42b70d149..ec440a0aedf 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ 
-20,21 +20,17 @@ limitations under the License. #include #include #include + +#include "mkldnn.hpp" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" -#ifndef INTEL_MKL_ML_ONLY -#include "mkldnn.hpp" using mkldnn::memory; using mkldnn::pooling_backward; using mkldnn::pooling_forward; using mkldnn::stream; -#endif namespace tensorflow { - -#ifndef INTEL_MKL_ML_ONLY - using mkldnn::memory; using mkldnn::pooling_avg; using mkldnn::pooling_avg_exclude_padding; @@ -357,7 +353,6 @@ class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory { this->SetOp(key, op); } }; -#endif typedef Eigen::ThreadPoolDevice CPUDevice; @@ -424,15 +419,9 @@ struct MklPoolParameters { void Init(OpKernelContext* context, const std::vector& ksize, const std::vector& stride, Padding padding, TensorFormat data_format, const TensorShape& tensor_in_shape); -#ifdef INTEL_MKL_ML_ONLY - void Init(OpKernelContext* context, const std::vector& ksize, - const std::vector& stride, Padding padding, - TensorFormat data_format, const MklShape* mkl_in_shape); -#else void Init(OpKernelContext* context, const std::vector& ksize, const std::vector& stride, Padding padding, TensorFormat data_format, const MklDnnShape* mkl_in_shape); -#endif private: // Common initialization for TensorFlow and MKL formats @@ -441,8 +430,6 @@ struct MklPoolParameters { TensorFormat data_format); }; -#ifndef INTEL_MKL_ML_ONLY - template class MklPoolingOpBase : public OpKernel { public: @@ -750,7 +737,6 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase { return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md; } }; -#endif // INTEL_MKL_ML_ONLY //------------------------------------------------------------------- // Utility functions diff --git a/tensorflow/core/kernels/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl_quantize_op.cc new file mode 100644 index 00000000000..1c7e6ff6854 --- /dev/null +++ b/tensorflow/core/kernels/mkl_quantize_op.cc @@ -0,0 +1,228 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL + +#define EIGEN_USE_THREADS + +#include "mkldnn.h" +#include "mkldnn.hpp" +#include "mkldnn_types.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/mkl_util.h" + +using mkldnn::primitive_attr; +using mkldnn::prop_kind; +using mkldnn::reorder; +using mkldnn::stream; + +namespace { +enum { QUANTIZE_MODE_SCALED }; +enum { + // Round half to even: if the fraction of y is exactly 0.5, then round(y) is + // the nearest even integer to y. 
+ // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes + // -24, and -24.5 gets rounded to 24. + ROUND_HALF_TO_EVEN, +}; +} // namespace + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +// Quantizes a tensor from float to T, with user-specified min_range and +// max_range. +template +class MklQuantizeV2Op : public OpKernel { + public: + explicit MklQuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { + string mode_string; + OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string)); + OP_REQUIRES(ctx, (mode_string == "SCALED"), + errors::InvalidArgument("mode must be scaled")); + mode_ = QUANTIZE_MODE_SCALED; + string round_mode_string; + OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string)); + OP_REQUIRES(ctx, (round_mode_string == "HALF_TO_EVEN"), + errors::InvalidArgument("Round mode must be half to even")); + round_mode_ = ROUND_HALF_TO_EVEN; + } + + ~MklQuantizeV2Op() {} + + void Compute(OpKernelContext* ctx) override { + const float input_min_range = ctx->input(1).flat()(0); + const float input_max_range = ctx->input(2).flat()(0); + float min_range = std::min(0.0f, input_min_range); + float max_range; + OP_REQUIRES(ctx, (input_max_range > input_min_range), + errors::InvalidArgument( + "input_max_range must be larger than input_min_range.")); + + // When the minimum and maximum ranges are too close together, nudge them + // apart by a small value so that they are slightly different. This helps + // us avoid creating ill-formed buffers where all quantized values map to + // the same float number. These kinds of buffers cause problems for + // downstream ops when they need to do calculations on them. + // We pick the value by making sure that zero is not more than 100x the + // overall range from the maximum, so that the value can be easily + // represented when we promote the quantized value to a higher + // intermediate bit depth, since that's a common requirement. + const float epsilon = std::max(1.0f, std::max(fabsf(input_min_range), + fabsf(input_max_range))) / + 100.0f; + max_range = std::max(input_max_range, min_range + epsilon); + // Clamping the max_range to zero since max_range can also be negative. + max_range = std::max(0.0f, max_range); + auto cpu_engine = engine(engine::cpu, 0); + const unsigned int src_idx = 0; + const Tensor& src_tensor = MklGetInput(ctx, src_idx); + MklDnnShape src_mkl_shape; + GetMklShape(ctx, src_idx, &src_mkl_shape); + auto src_tf_shape = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape() + : src_tensor.shape(); + auto src_dims = src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDims(src_tensor.shape()); + auto output_dims = src_dims; + // Set the dst layout to be the best mkl layout based on dims and type. + memory::format dst_layout_type; + switch (src_tf_shape.dims()) { + case 1: + dst_layout_type = memory::format::x; + break; + case 2: + dst_layout_type = memory::format::nc; + break; + case 3: + dst_layout_type = memory::format::tnc; + break; + case 4: + dst_layout_type = memory::format::nhwc; + break; + case 5: + dst_layout_type = memory::format::ndhwc; + break; + default: + OP_REQUIRES_OK(ctx, + errors::Aborted("Input dims must be <= 5 and >= 1")); + return; + } + // Create reorder memory for src, dst: both are defined in mkl_util.h, + // they are wrapper + MklDnnData src(&cpu_engine); + MklDnnData dst(&cpu_engine); + auto src_md = + src_mkl_shape.IsMklTensor() + ? 
src_mkl_shape.GetMklLayout() + : memory::desc(src_dims, MklDnnType(), dst_layout_type); + src.SetUsrMem(src_md, &src_tensor); + + memory::desc dst_md = + memory::desc(src_dims, MklDnnType(), dst_layout_type); + auto dst_pd = src.GetUsrMemPrimDesc(); + // Standard shape assignments for layout pass + MklDnnShape output_mkl_shape; + TensorShape output_tf_shape; + if (src_mkl_shape.IsMklTensor()) { + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SetMklLayout(&dst_md); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(), + src_mkl_shape.GetSizesAsMklDnnDims(), + src_mkl_shape.GetTfDataFormat()); + output_tf_shape.AddDim(dst_pd.get_size() / sizeof(T)); + } else { + output_mkl_shape.SetMklTensor(false); + output_tf_shape = MklDnnDimsToTFShape(output_dims); + } + + Tensor* output_tensor = nullptr; + AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape, + output_mkl_shape); + TensorShape min_tf_shape = {}; + MklDnnShape min_mkl_shape; + min_mkl_shape.SetMklTensor(false); + Tensor* output_min_tensor = nullptr; + AllocateOutputSetMklShape(ctx, 1, &output_min_tensor, min_tf_shape, + min_mkl_shape); + TensorShape max_tf_shape = {}; + MklDnnShape max_mkl_shape; + max_mkl_shape.SetMklTensor(false); + Tensor* output_max_tensor = nullptr; + AllocateOutputSetMklShape(ctx, 2, &output_max_tensor, max_tf_shape, + max_mkl_shape); + + dst.SetUsrMem(dst_md, output_tensor); + // Estimating scales for quantization. + const int num_bits = sizeof(T) * 8; + const float max_abs = std::max(std::abs(min_range), std::abs(max_range)); + const bool is_signed = std::is_signed::value; + float target_range; + if (is_signed) { + max_range = max_abs; + min_range = -max_abs; + // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For + // example, if it is 8 bits, we have the range [-127, 127]. So for input + // range of [-x, x], the scale should be 254/(2*x). + target_range = static_cast((uint64_t{1} << (num_bits - 1)) - 1); + } else { + max_range = max_abs; + min_range = 0.0; + // If it is unsigned and num_bits == 8, the range with 8 bits is [0, + // 255]. If the input range is [0, x], then the scale is 255/x instead + // of 254 as in the case above. 
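As a side note on the scale selection spelled out in the comments above: the target range is 2^(num_bits-1) - 1 for signed types and 2^num_bits - 1 for unsigned ones, and the scale factor is that target divided by the largest absolute value of the requested range. A minimal standalone sketch of that arithmetic (helper name and layout are illustrative, not part of the kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative only: mirrors the target_range / scale_factor arithmetic for an
// 8-bit type. Signed qint8 uses a symmetric range [-max_abs, max_abs] mapped to
// [-127, 127]; unsigned quint8 uses [0, max_abs] mapped to [0, 255].
float ScaledModeScale(float min_range, float max_range, bool is_signed,
                      int num_bits = 8) {
  const float max_abs = std::max(std::fabs(min_range), std::fabs(max_range));
  const float target_range =
      is_signed ? static_cast<float>((std::uint64_t{1} << (num_bits - 1)) - 1)  // 127
                : static_cast<float>((std::uint64_t{1} << num_bits) - 1);       // 255
  return target_range / max_abs;
}

int main() {
  // Signed example: a requested range of [-50, 127] is widened to the
  // symmetric [-127, 127], so the scale is 127 / 127 = 1.
  std::printf("signed scale:   %f\n", ScaledModeScale(-50.0f, 127.0f, /*is_signed=*/true));
  // Unsigned example: a range of [0, 255] gives scale 255 / 255 = 1.
  std::printf("unsigned scale: %f\n", ScaledModeScale(0.0f, 255.0f, /*is_signed=*/false));
}
```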
+ target_range = static_cast((uint64_t{1} << num_bits) - 1); + } + output_min_tensor->flat()(0) = min_range; + output_max_tensor->flat()(0) = max_range; + const float scale_factor = target_range / max_abs; + // Primitive creation and stream submit + std::vector scales{scale_factor}; + mkldnn::primitive_attr attr; + attr.set_output_scales(0, scales); + auto reorder_desc = reorder::primitive_desc(src.GetUsrMemPrimDesc(), + dst.GetUsrMemPrimDesc(), attr); + reorder my_reorder = reorder(reorder_desc, primitive::at(*src.GetUsrMem()), + *dst.GetUsrMem()); + std::vector net{my_reorder}; + stream(stream::kind::eager).submit(net).wait(); + } + + private: + int mode_; + int round_mode_; +}; + +REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizeV2Op); +REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizeV2Op); +} // namespace tensorflow + +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_quantize_op_test.cc b/tensorflow/core/kernels/mkl_quantize_op_test.cc new file mode 100644 index 00000000000..cb53411ee6c --- /dev/null +++ b/tensorflow/core/kernels/mkl_quantize_op_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +class MklQuantizeV2OpTest : public OpsTestBase {}; + +static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0}; +static const TensorShape dummy_shape({8}); + +TEST_F(MklQuantizeV2OpTest, small_uint8) { + TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Attr("T", DataTypeToEnum::v()) + .Attr("mode", "SCALED") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({8}), + {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0}); + // min_range = 0 + AddInputFromArray(TensorShape({1}), {0}); + // max_range = 255 + AddInputFromArray(TensorShape({1}), {255.0f}); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_QUINT8, TensorShape({8})); + Tensor expected_min(allocator(), DT_FLOAT, TensorShape({})); + Tensor expected_max(allocator(), DT_FLOAT, TensorShape({})); + // Input element 0.0 should map to 0. + // Input element 500.0 is quantized to 255 because max_range = 255. 
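The expected values in this test follow from a scale of 255/255 = 1 together with half-to-even rounding and saturation to the quint8 range. A self-contained reference computation, assuming only the standard library (the helper name is invented for illustration):

```cpp
#include <algorithm>
#include <cfenv>
#include <cmath>
#include <cstdio>
#include <cstdint>

// Reference quantization for the test inputs above: multiply by the scale,
// round to nearest (ties to even), then saturate to [0, 255].
std::uint8_t QuantizeToQuint8(float x, float scale) {
  std::fesetround(FE_TONEAREST);  // round-to-nearest, ties to even
  const float scaled = std::nearbyint(x * scale);
  return static_cast<std::uint8_t>(std::min(255.0f, std::max(0.0f, scaled)));
}

int main() {
  const float inputs[] = {0.0f, 1.0f, 1.25f, 1.75f, 127.0f, 255.0f, 500.0f, 2.0f};
  for (float v : inputs)
    std::printf("%g -> %u\n", v, QuantizeToQuint8(v, 255.0f / 255.0f));
  // Prints 0 1 1 2 127 255 255 2, matching the expected tensor below.
}
```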
+ test::FillValues(&expected, {0, 1, 1, 2, 127, 255, 255, 2}); + test::FillValues(&expected_min, {0.0}); + test::FillValues(&expected_max, {255.0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + test::ExpectTensorEqual(expected_min, *GetOutput(1)); + test::ExpectTensorEqual(expected_max, *GetOutput(2)); +} +TEST_F(MklQuantizeV2OpTest, small_int8) { + TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Attr("T", DataTypeToEnum::v()) + .Attr("mode", "SCALED") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({8}), {0.0, -1.0, 1.25, -1.75, -24.5, + -255.0, -80.315, 256.0}); + AddInputFromArray(TensorShape({1}), {-50.0}); + AddInputFromArray(TensorShape({1}), {127.0}); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_QINT8, TensorShape({8})); + Tensor expected_min(allocator(), DT_FLOAT, TensorShape({})); + Tensor expected_max(allocator(), DT_FLOAT, TensorShape({})); + test::FillValues(&expected, {0, -1, 1, -2, -24, -128, -80, 127}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + test::FillValues(&expected_min, {-127.0}); + test::FillValues(&expected_max, {127.0}); + test::ExpectTensorEqual(expected_min, *GetOutput(1)); + test::ExpectTensorEqual(expected_max, *GetOutput(2)); +} +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc b/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc index fc68480bbe8..d75077aa105 100644 --- a/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc +++ b/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc @@ -86,6 +86,10 @@ TEST_F(QuantizedConcatTest, Small8BitSameRange) { TestSmall8Bit(0.0f, 255.0f, 0.0f, 255.0f); } +TEST_F(QuantizedConcatTest, Small8BitDifferentRange) { + TestSmall8Bit(0.0f, 255.0f, 0.0f, 25.0f); +} + void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max, float second_min, float second_max) { TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2") diff --git a/tensorflow/core/kernels/mkl_quantized_conv_ops_perchannel_test.cc b/tensorflow/core/kernels/mkl_quantized_conv_ops_perchannel_test.cc new file mode 100644 index 00000000000..dcef8360f04 --- /dev/null +++ b/tensorflow/core/kernels/mkl_quantized_conv_ops_perchannel_test.cc @@ -0,0 +1,221 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef INTEL_MKL +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// Some helper constants +static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0}; +static const TensorShape dummy_shape({8}); + +// TODO(nammbash): Move this helper class to mkl_utils or mkl_test_utils +// so that all tests can use. (set a separate PR that changes all MKL tests). +// Helper class for converting MKL tensors to TF tensors +class ConvMklToTF : public OpsTestBase { + public: + template + void ConvertMKL2TF(DataType dtype, const Tensor& first, const Tensor& second, + Tensor& output) { + // Create an MKL to TF conversion node and execute it + TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf") + .Input(FakeInput(dtype)) // Input + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Attr("T", dtype) + .Attr("_kernel", "MklOp") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + AddInputFromArray(first.shape(), first.flat()); + AddInputFromArray(second.shape(), second.flat()); + TF_ASSERT_OK(RunOpKernel()); + + output = *GetOutput(0); + } + void TestBody() {} +}; + +class QuantizedConv2DPerchannelTest : public OpsTestBase {}; + +TEST_F(QuantizedConv2DPerchannelTest, Small) { + const int stride = 1; + TF_ASSERT_OK(NodeDefBuilder("quantized_conv_perchannel_op", + "_MklQuantizedConv2DPerChannel") + .Input(FakeInput(DT_QUINT8)) + .Input(FakeInput(DT_QINT8)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + // MKL metadata tensors + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + // Attributes + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("Tfilter", DataTypeToEnum::v()) + .Attr("T", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Attr("strides", {1, stride, stride, 1}) + .Attr("is_filter_const", true) + .Attr("padding", "SAME") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + + TF_ASSERT_OK(InitOp()); + + // Image shape + const int image_batch_count = 1; + const int image_height = 3; + const int image_width = 4; + const int image_channel = 1; + + // Image is of datatype uint8 + const float image_min = 0.0f; + const float image_max = 255.0f; + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + Tensor image_float( + DT_FLOAT, {image_batch_count, image_height, image_width, image_channel}); + test::FillValues(&image_float, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // Create image tensor + Tensor image_quantized = + FloatTensorToQuantized(image_float, image_min, image_max); + + // Filter shape + const int filter_height = 3; + const int filter_width = 3; + 
const int filter_channel = 1; + const int filter_count = 2; + + // Filter is of datatype int8 + const float filter_min = -127.0f; // (-128.0f changed for symmetry) + const float filter_max = 127.0f; + + // The filter matrix (for each output channel count) is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + Tensor filter_float( + DT_FLOAT, {filter_height, filter_width, filter_channel, filter_count}); + test::FillValues( + &filter_float, {1, 1, 4, 4, 7, 7, 2, 2, 5, 5, 8, 8, 3, 3, 6, 6, 9, 9}); + + // Create filter tensor + Tensor filter_quantized = + FloatTensorToQuantized(filter_float, filter_min, filter_max); + + // Add the tensors as input to the current op. + AddInputFromArray(image_quantized.shape(), + image_quantized.flat()); + AddInputFromArray(filter_quantized.shape(), + filter_quantized.flat()); + AddInputFromArray(TensorShape({1}), {image_min}); + AddInputFromArray(TensorShape({1}), {image_max}); + AddInputFromArray(TensorShape({2}), {filter_min, filter_min}); + AddInputFromArray(TensorShape({2}), {filter_max, filter_max}); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + // Run the op Kernel. + TF_ASSERT_OK(RunOpKernel()); + + // Get the output + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + // Convert the output tensor in MKL to TF format. + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMKL2TF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + const float output_min = GetOutput(1)->flat()(0); + const float output_max = GetOutput(2)->flat()(0); + Tensor output_float = + QuantizedTensorToFloat(output_quantized, output_min, output_max); + + // Get the Expected Output tensor. + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input dimensions set to zero because we're using the 'SAME' padding + // mode. 
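For reference, the 'SAME' padding mentioned above keeps the 3x4 spatial size for a 3x3 filter at stride 1 by zero-padding one row and one column on each side. A quick standalone sketch of that arithmetic (illustrative helper only, not part of the test):

```cpp
#include <algorithm>
#include <cstdio>

// SAME padding at stride 1: the output size equals the input size and the
// total padding is filter_size - 1, split as evenly as possible.
void SamePadding(int in_size, int filter_size, int stride) {
  const int out_size = (in_size + stride - 1) / stride;  // ceil(in / stride)
  const int pad_total =
      std::max((out_size - 1) * stride + filter_size - in_size, 0);
  const int pad_before = pad_total / 2;
  std::printf("in=%d -> out=%d, pad_before=%d, pad_after=%d\n", in_size,
              out_size, pad_before, pad_total - pad_before);
}

int main() {
  SamePadding(/*in_size=*/3, /*filter_size=*/3, /*stride=*/1);  // rows: out=3, pad 1+1
  SamePadding(/*in_size=*/4, /*filter_size=*/3, /*stride=*/1);  // cols: out=4, pad 1+1
}
```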
+ // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + + // This means we should end up with this matrix for each channel: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + + // Shape of expected (output) tensor: N x IH x IW x filter_count + // Create the expected output tensor + const int expected_width = image_width; + const int expected_height = image_height; + + Tensor expected_float( + DT_FLOAT, TensorShape({image_batch_count, expected_height, expected_width, + filter_count})); + + test::FillValues( + &expected_float, + {105, 105, 150, 150, 183, 183, 95, 95, 235, 235, 312, 312, + 357, 357, 178, 178, 187, 187, 234, 234, 261, 261, 121, 121}); + + // Test whether the values are as expected. + test::ExpectTensorNear(expected_float, output_float, 0.5); +} + +} // namespace tensorflow +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc b/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc index 2e599d3d9f8..91eb260889d 100644 --- a/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc +++ b/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc @@ -96,6 +96,90 @@ class QuantizedConv2DTest : public OpsTestBase { .Finalize(node_def())); TF_ASSERT_OK(InitOp()); } + + void RunQuantizedDepthwiseConv2DOp(const bool& bias_enabled) { + const int depth = 2; + const int image_width = 2; + const int image_height = 3; + const int image_batch_count = 1; + // The image matrix is ('first/second' channel): + // | 1/2 | 3/4 | + // | 5/6 | 7/8 | + // | 9/10 | 11/12 | + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + // The filter matrix is: + // | 1/2 | 7/8 | 13/14 | + // | 3/4 | 9/10 | 15/16 | + // | 5/6 | 11/12 | 17/18 | + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {1, 2, 7, 8, 13, 14, 3, 4, 9, 10, 15, 16, 5, 6, 11, 12, 17, 18}); + + if (bias_enabled) { + // Bias -> float + AddInputFromArray(TensorShape({depth}), {1.0f, 1.0f}); + } + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + if (bias_enabled) { + AddInputFromArray(dummy_shape, dummy_tensor); + } + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + 
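The depthwise expectations that follow can be reproduced with a plain reference loop over the two channels. A self-contained sketch under the same assumptions (stride 1, 'SAME' padding, no bias; no TF/MKL types, names invented for illustration):

```cpp
#include <cstdio>

// Reference depthwise conv (stride 1, 'SAME' padding, depth_multiplier 1) for
// the 3x2x2 image and 3x3x2 filter used by RunQuantizedDepthwiseConv2DOp.
int main() {
  const int H = 3, W = 2, C = 2, K = 3, pad = 1;
  // image[h][w][c] and filter[kh][kw][c], laid out as in the test inputs.
  const int image[H][W][C] = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}, {{9, 10}, {11, 12}}};
  const int filter[K][K][C] = {{{1, 2}, {7, 8}, {13, 14}},
                               {{3, 4}, {9, 10}, {15, 16}},
                               {{5, 6}, {11, 12}, {17, 18}}};
  for (int h = 0; h < H; ++h) {
    for (int w = 0; w < W; ++w) {
      for (int c = 0; c < C; ++c) {
        int acc = 0;
        for (int kh = 0; kh < K; ++kh) {
          for (int kw = 0; kw < K; ++kw) {
            const int ih = h + kh - pad, iw = w + kw - pad;
            if (ih >= 0 && ih < H && iw >= 0 && iw < W)
              acc += image[ih][iw][c] * filter[kh][kw][c];
          }
        }
        std::printf("%d ", acc);  // 228 300 132 180 482 596 266 344 372 452 180 236
      }
    }
  }
  std::printf("\n");
}
```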
+ // We're sliding two 3x3 filters across the 3x2 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // This means we should end up with this matrix: + // | 228/300 | 132/180 | + // | 482/596 | 266/344 | + // | 372/452 | 180/236 | + // + // Similarly, after adding a bias of 1.0f across each channel, we should end + // up with this matrix: + // | 229/301 | 133/181 | + // | 483/597 | 267/345 | + // | 373/453 | 181/237 | + + // Output -> qint32 + Tensor expected(DT_QINT32, TensorShape({image_batch_count, image_height, + image_width, depth})); + if (bias_enabled) { + test::FillValues(&expected, {229, 301, 133, 181, 483, 597, 267, + 345, 373, 453, 181, 237}); + } else { + test::FillValues(&expected, {228, 300, 132, 180, 482, 596, 266, + 344, 372, 452, 180, 236}); + } + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); + } }; // Output -> float @@ -454,5 +538,98 @@ TEST_F(QuantizedConv2DTest, OddPaddingBatch) { test::ExpectTensorEqual(expected, output_quantized); } +TEST_F(QuantizedConv2DTest, DepthwiseConv2D) { + const int stride = 1; + TF_ASSERT_OK(NodeDefBuilder("quantized_depthwise_conv_op", + "_MklQuantizedDepthwiseConv2D") + .Input(FakeInput(DT_QUINT8)) // Input + .Input(FakeInput(DT_QINT8)) // Filter + .Input(FakeInput(DT_FLOAT)) // Min input + .Input(FakeInput(DT_FLOAT)) // Max input + .Input(FakeInput(DT_FLOAT)) // Min filter + .Input(FakeInput(DT_FLOAT)) // Max filter + // MKL metadata tensors // + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + /////////////////////////// + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("Tfilter", DataTypeToEnum::v()) + .Attr("T", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + RunQuantizedDepthwiseConv2DOp(false); +} + +TEST_F(QuantizedConv2DTest, DepthwiseConv2DWithBias) { + const int stride = 1; + TF_ASSERT_OK(NodeDefBuilder("quantized_depthwise_conv_op", + "_MklQuantizedDepthwiseConv2DWithBias") + .Input(FakeInput(DT_QUINT8)) // Input + .Input(FakeInput(DT_QINT8)) // Filter + .Input(FakeInput(DT_FLOAT)) // Bias + .Input(FakeInput(DT_FLOAT)) // Min input + .Input(FakeInput(DT_FLOAT)) // Max input + .Input(FakeInput(DT_FLOAT)) // Min filter + .Input(FakeInput(DT_FLOAT)) // Max filter + // MKL metadata tensors // + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + /////////////////////////// + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("Tfilter", DataTypeToEnum::v()) + .Attr("T", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + RunQuantizedDepthwiseConv2DOp(true); +} + +TEST_F(QuantizedConv2DTest, DepthwiseConv2DWithBiasAndRelu) { + const int stride = 1; + TF_ASSERT_OK(NodeDefBuilder("quantized_depthwise_conv_op", + 
"_MklQuantizedDepthwiseConv2DWithBiasAndRelu") + .Input(FakeInput(DT_QUINT8)) // Input + .Input(FakeInput(DT_QINT8)) // Filter + .Input(FakeInput(DT_FLOAT)) // Bias + .Input(FakeInput(DT_FLOAT)) // Min input + .Input(FakeInput(DT_FLOAT)) // Max input + .Input(FakeInput(DT_FLOAT)) // Min filter + .Input(FakeInput(DT_FLOAT)) // Max filter + // MKL metadata tensors // + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + /////////////////////////// + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("Tfilter", DataTypeToEnum::v()) + .Attr("T", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + RunQuantizedDepthwiseConv2DOp(true); +} } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 19585969993..c9d740c9e2c 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -28,12 +28,11 @@ limitations under the License. using mkldnn::algorithm; using mkldnn::eltwise_bounded_relu; using mkldnn::eltwise_elu; +using mkldnn::eltwise_forward; using mkldnn::eltwise_relu; using mkldnn::eltwise_tanh; using mkldnn::memory; using mkldnn::prop_kind; -using mkldnn::relu_backward; -using mkldnn::relu_forward; using mkldnn::stream; namespace tensorflow { @@ -44,11 +43,11 @@ class MklEltwiseFwdParams { memory::dims src_dims; // check if this is needed memory::desc src_md; algorithm alg_kind; - T alpha; - T beta; + float alpha; + float beta; MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md, - algorithm alg_kind, T alpha, T beta) + algorithm alg_kind, float alpha, float beta) : src_dims(src_dims), src_md(src_md), alg_kind(alg_kind), @@ -227,12 +226,12 @@ class MklEltwiseBwdParams { memory::dims src_dims; memory::desc common_md; algorithm alg_kind; - T alpha; - T beta; + float alpha; + float beta; MklEltwiseBwdParams(const memory::dims& src_dims, const memory::desc& common_md, algorithm alg_kind, - T alpha, T beta) + float alpha, float beta) : src_dims(src_dims), common_md(common_md), alg_kind(alg_kind), @@ -542,7 +541,7 @@ class MklReluOpBase : public OpKernel { private: engine cpu_engine = engine(engine::cpu, 0); - std::shared_ptr relu_fwd_pd; + std::shared_ptr relu_fwd_pd; protected: float alpha_; @@ -710,7 +709,7 @@ class MklReluGradOpBase : public OpKernel { private: engine cpu_engine = engine(engine::cpu, 0); - std::shared_ptr relu_fwd_pd; + std::shared_ptr relu_fwd_pd; protected: float alpha_; @@ -775,7 +774,8 @@ class MklReluGradOp : public MklReluGradOpBase { void* user_g = static_cast(const_cast(diff_dst_tensor.flat().data())); (static_cast(out_o))[0] = - (static_cast(user_g))[0] * ((static_cast(user_i))[0] > 0); + (static_cast(user_g))[0] * + (static_cast((static_cast(user_i))[0] > static_cast(0))); return; } }; @@ -805,7 +805,7 @@ class MklEluOp : public MklReluOpBase { void* out_o = static_cast(dst_tensor->flat().data()); // return exp(feature) - 1 if feature > 0; feature otherwise T feature = (static_cast(user_i))[0]; - if (feature < 0) + if (feature < static_cast(0)) (static_cast(out_o))[0] = std::exp(feature); else (static_cast(out_o))[0] = feature; @@ -843,11 +843,12 @@ class MklEluGradOp : public MklReluGradOpBase { 
static_cast(const_cast(diff_dst_tensor.flat().data())); // gradient of elu(x) = 1 if x > 0; elu(x) + 1 otherwise T feature = (static_cast(user_i))[0]; - if (feature > 0) { + if (feature > static_cast(0)) { (static_cast(out_o))[0] = (static_cast(user_g))[0]; } else { - T elu = std::exp(feature) - 1; - (static_cast(out_o))[0] = (static_cast(user_g))[0] * (elu + 1); + T elu = std::exp(feature) - static_cast(1); + (static_cast(out_o))[0] = + (static_cast(user_g))[0] * (elu + static_cast(1)); } } }; @@ -918,7 +919,7 @@ class MklTanhGradOp : public MklReluGradOpBase { void* user_g = static_cast(const_cast(diff_dst_tensor.flat().data())); (static_cast(out_o))[0] = - (static_cast(user_g))[0] * (1 - tanh * tanh); + (static_cast(user_g))[0] * (static_cast(1) - tanh * tanh); } }; @@ -980,8 +981,9 @@ class MklRelu6GradOp T* out_o = diff_src_tensor->flat().data(); T* user_i = const_cast(src_tensor.flat().data()); T* user_g = const_cast(diff_dst_tensor.flat().data()); - out_o[0] = user_g[0] * (user_i[0] > 0 && - (user_i[0] < static_cast(RELU6_UPPER_BOUND))); + out_o[0] = user_g[0] * + static_cast(user_i[0] > static_cast(0) && + (user_i[0] < static_cast(RELU6_UPPER_BOUND))); return; } }; @@ -1018,7 +1020,7 @@ class MklLeakyReluOp : public MklReluOpBase { AllocateOutputSetMklShape(context, dst_index, &dst_tensor, src_tensor.shape(), dnn_shape_dst); T* out_o = dst_tensor->flat().data(); - out_o[0] = user_i[0] >= 0 ? user_i[0] : user_i[0] * this->alpha_; + out_o[0] = user_i[0] >= T(0) ? user_i[0] : user_i[0] * T(this->alpha_); return; } }; @@ -1059,7 +1061,9 @@ class MklLeakyReluGradOp : public MklReluGradOpBase { T* out_o = diff_src_tensor->flat().data(); T* user_i = const_cast(src_tensor.flat().data()); T* user_g = const_cast(diff_dst_tensor.flat().data()); - out_o[0] = user_i[0] >= 0 ? user_g[0] : user_g[0] * this->alpha_; + out_o[0] = user_i[0] >= static_cast(0) + ? 
user_g[0] + : user_g[0] * static_cast(this->alpha_); return; } }; @@ -1077,6 +1081,7 @@ class MklLeakyReluGradOp : public MklReluGradOpBase { .Label(mkl_op_registry::kMklOpLabel), \ MklReluGradOp); TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES); +TF_CALL_bfloat16(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES); // register dnn kernels for supported operations and supported types #define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type) \ @@ -1091,6 +1096,7 @@ TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES); .Label(mkl_op_registry::kMklOpLabel), \ MklEluGradOp); TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES); +TF_CALL_bfloat16(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES); #define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type) \ REGISTER_KERNEL_BUILDER(Name("_MklTanh") \ @@ -1104,6 +1110,7 @@ TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES); .Label(mkl_op_registry::kMklOpLabel), \ MklTanhGradOp); TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES); +TF_CALL_bfloat16(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES); #define REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES(type) \ REGISTER_KERNEL_BUILDER(Name("_MklRelu6") \ @@ -1117,6 +1124,7 @@ TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES); .Label(mkl_op_registry::kMklOpLabel), \ MklRelu6GradOp); TF_CALL_float(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES); +TF_CALL_bfloat16(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES); #define REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES(type) \ REGISTER_KERNEL_BUILDER(Name("_MklLeakyRelu") \ @@ -1130,6 +1138,7 @@ TF_CALL_float(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES); .Label(mkl_op_registry::kMklOpLabel), \ MklLeakyReluGradOp); TF_CALL_float(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES); +TF_CALL_bfloat16(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES); } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc index b5c1a01f831..8fbb16c11fb 100644 --- a/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc @@ -42,7 +42,7 @@ class MklRequantizePerChannelOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_type_)); OP_REQUIRES(ctx, out_type_ == DT_QINT8 || out_type_ == DT_QUINT8, errors::InvalidArgument( - "out_type must be qint8 or quint8, but got: " + out_type_)); + "out_type must be qint8 or quint8, but got: ", out_type_)); } virtual ~MklRequantizePerChannelOp() {} void Compute(OpKernelContext* ctx) override { @@ -162,11 +162,18 @@ class MklRequantizePerChannelOp : public OpKernel { engine cpu_engine_ = engine(engine::cpu, 0); }; +// Registration for out_type: qint8 REGISTER_KERNEL_BUILDER(Name("RequantizePerChannel") .Device(DEVICE_CPU) .TypeConstraint("T") .TypeConstraint("out_type"), MklRequantizePerChannelOp); +// Registration for out_type: quint8 +REGISTER_KERNEL_BUILDER(Name("RequantizePerChannel") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .TypeConstraint("out_type"), + MklRequantizePerChannelOp); } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index 342e2265ee8..58e76805af6 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -68,9 +68,10 @@ class MklReshapeOp : public OpKernel { MklDnnShape mkl_shape_input; GetMklShape(context, kInputSlotIdx, &mkl_shape_input); bool input_in_mkl_format = mkl_shape_input.IsMklTensor(); - const int64 
nelems = input_in_mkl_format - ? mkl_shape_input.GetTfShape().num_elements() - : input_tensor.NumElements(); + TensorShape input_shape = input_in_mkl_format ? mkl_shape_input.GetTfShape() + : input_tensor.shape(); + const int64 nelems = input_in_mkl_format ? input_shape.num_elements() + : input_tensor.NumElements(); // Preliminary validation of sizes. OP_REQUIRES(context, IsLegacyVector(sizes.shape()), @@ -82,14 +83,17 @@ class MklReshapeOp : public OpKernel { TensorShape shape; int64 product = 1; int unknown_index = -1; + bool sizes_has_zero_dim = false; switch (sizes.dtype()) { case DT_INT32: - OP_REQUIRES_OK(context, ValidateSizes(sizes, &product, - &unknown_index, &shape)); + OP_REQUIRES_OK(context, + ValidateSizes(sizes, &product, &unknown_index, + &shape, &sizes_has_zero_dim)); break; case DT_INT64: - OP_REQUIRES_OK(context, ValidateSizes(sizes, &product, - &unknown_index, &shape)); + OP_REQUIRES_OK(context, + ValidateSizes(sizes, &product, &unknown_index, + &shape, &sizes_has_zero_dim)); break; default: context->CtxFailure(errors::InvalidArgument( @@ -98,18 +102,28 @@ class MklReshapeOp : public OpKernel { return; } if (unknown_index != -1) { - OP_REQUIRES( - context, product > 0, - errors::InvalidArgument("Reshape cannot infer the missing input size " - "for an empty tensor unless all specified " - "input sizes are non-zero")); - const int64 missing = nelems / product; - OP_REQUIRES( - context, product * missing == nelems, - errors::InvalidArgument( - "Input to reshape is a tensor with ", nelems, - " values, but the requested shape requires a multiple of ", - product)); + int64 input_num_elements = 1; + bool input_has_zero_dim = false; + for (int dim = 0; dim < input_shape.dims(); ++dim) { + // For zero dimension, we don't count it into `input_num_elements` + // unless `sizes` has no zero dimension, so we are still able to + // infer shapes for other dimensions. + if (input_shape.dim_size(dim) > 0 || !sizes_has_zero_dim) { + input_num_elements *= input_shape.dim_size(dim); + } else { + input_has_zero_dim = true; + } + } + + const int64 missing = input_num_elements / product; + if (!input_has_zero_dim) { + OP_REQUIRES( + context, product * missing == input_num_elements, + errors::InvalidArgument( + "Input to reshape is a tensor with ", input_num_elements, + " values, but the requested shape requires a multiple of ", + product)); + } shape.set_dim(unknown_index, missing); } OP_REQUIRES( @@ -163,7 +177,8 @@ class MklReshapeOp : public OpKernel { // shape_from != shape_to), then we just copy input tensor to // output tensor with target shape (we cannot forward Mkl layout // in such case because shape has changed.) 
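To illustrate the zero-dimension handling added here: when the requested sizes contain a 0, zero-sized input dimensions are excluded from the element count so that a -1 entry can still be inferred. A standalone sketch of that inference rule (not the kernel itself; helper name invented):

```cpp
#include <cstdio>
#include <vector>

// Infers the value of a single -1 entry in `sizes`, skipping zero-sized input
// dimensions whenever `sizes` itself contains a zero (mirrors the rule above).
long long InferMissingDim(const std::vector<long long>& input_shape,
                          const std::vector<long long>& sizes) {
  long long product = 1;
  bool sizes_has_zero_dim = false;
  for (long long s : sizes) {
    if (s == 0)
      sizes_has_zero_dim = true;
    else if (s > 0)
      product *= s;
  }
  long long input_num_elements = 1;
  for (long long d : input_shape)
    if (d > 0 || !sizes_has_zero_dim) input_num_elements *= d;
  return input_num_elements / product;
}

int main() {
  // Reshaping a [0, 2, 3] tensor to [0, -1]: the zero dim is skipped, so the
  // -1 is inferred from 2 * 3 = 6 elements and the result is [0, 6].
  std::printf("inferred dim = %lld\n", InferMissingDim({0, 2, 3}, {0, -1}));
}
```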
- if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, output_tensor)) { + if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, + output_tensor)) { } else { OP_REQUIRES( context, output_tensor->CopyFrom(input_tensor, shape_to), @@ -213,16 +228,16 @@ class MklReshapeOp : public OpKernel { } } - private: const int kInputSlotIdx = 0; const int kOutputSlotIdx = 0; template Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index, - TensorShape* shape) { + TensorShape* shape, bool* has_zero_dim) { *product = 1; *unknown_index = -1; + *has_zero_dim = false; const int64 num_dims = sizes.NumElements(); auto Svec = sizes.flat(); for (int d = 0; d < num_dims; ++d) { @@ -238,6 +253,12 @@ class MklReshapeOp : public OpKernel { } else if (size < 0) { return errors::InvalidArgument("Size ", d, " must be non-negative, not ", size); + } else if (size == 0) { + // We don't include zero-sized dimension in product, so that we can + // still calculate number of elements for non-zero-sized dimensions and + // therefore infer their shapes. + shape->AddDim(size); + *has_zero_dim = true; } else { shape->AddDim(size); (*product) *= size; @@ -263,6 +284,7 @@ class MklReshapeOp : public OpKernel { .Label(mkl_op_registry::kMklOpLabel), \ MklReshapeOp); TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_bfloat16(REGISTER_MKL_CPU); #undef REGISTER_MKL_CPU } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc index e2cbeec2d28..5d238a24bc6 100644 --- a/tensorflow/core/kernels/mkl_slice_op.cc +++ b/tensorflow/core/kernels/mkl_slice_op.cc @@ -16,8 +16,9 @@ limitations under the License. // See docs in ../ops/array_ops.cc. #ifdef INTEL_MKL -#ifndef INTEL_MKL_ML_ONLY +#include "mkldnn.hpp" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -25,9 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/prefetch.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -#include "mkldnn.hpp" #include "tensorflow/core/util/mkl_util.h" using mkldnn::stream; @@ -485,9 +483,9 @@ class MklSliceOp : public OpKernel { MklSliceOp); TF_CALL_float(REGISTER_MKL_SLICE); +TF_CALL_bfloat16(REGISTER_MKL_SLICE); #undef REGISTER_MKL_SLICE } // namespace tensorflow -#endif // INTEL_MKL_DNN #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc index edc71569a60..d3025d34d87 100644 --- a/tensorflow/core/kernels/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl_transpose_op.cc @@ -184,6 +184,10 @@ Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, case DT_FLOAT: return MKLTransposeND(ctx, in, out, perm); break; + // TODO(nhasabni): Enable this case when we turn on bfloat16 compilation. + // case DT_BFLOAT16: + // return MKLTransposeND(ctx, in, out, perm); + // break; // TODO(nhasabni): support other types such as INT8. default: break; @@ -228,6 +232,10 @@ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, case DT_FLOAT: return MKLTransposeND(ctx, in, out, perm); break; + // TODO(nhasabni): Enable this case when we turn on bfloat16 compilation. 
+ // case DT_BFLOAT16: + // return MKLTransposeND(ctx, in, out, perm); + // break; // TODO(nhasabni): support other types such as INT8. default: break; diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc index 82dfece4a2a..46852167ae0 100644 --- a/tensorflow/core/kernels/multinomial_op.cc +++ b/tensorflow/core/kernels/multinomial_op.cc @@ -53,6 +53,20 @@ struct MultinomialFunctor { typename TTypes::Matrix output); }; +#if GOOGLE_CUDA +extern template struct MultinomialFunctor; +extern template struct MultinomialFunctor; +extern template struct MultinomialFunctor; +extern template struct MultinomialFunctor; +extern template struct MultinomialFunctor; + +extern template struct MultinomialFunctor; +extern template struct MultinomialFunctor; +extern template struct MultinomialFunctor; +extern template struct MultinomialFunctor; +extern template struct MultinomialFunctor; +#endif // GOOGLE_CUDA + template struct MultinomialFunctor { void operator()(OpKernelContext* ctx, const CPUDevice& d, diff --git a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc index 62e38694c8f..8143a033960 100644 --- a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc +++ b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc @@ -17,18 +17,17 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/multinomial_op.h" - #include #include #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/multinomial_op.h" #include "tensorflow/core/kernels/random_op.h" #include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" #include "tensorflow/core/kernels/reduction_ops_common.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -110,11 +109,11 @@ struct MultinomialFunctor { output.device(d) = output.constant(0LL); const int32 work_items = batch_size * num_samples * num_classes; - CudaLaunchConfig config = GetCudaLaunchConfig(work_items, d); - MultinomialKernel<<>>(config.virtual_thread_count, num_classes, - num_samples, scores.data(), maxima.data(), - output.data()); + GpuLaunchConfig config = GetCudaLaunchConfig(work_items, d); + TF_CHECK_OK(CudaLaunchKernel( + MultinomialKernel, config.block_count, + config.thread_per_block, 0, d.stream(), config.virtual_thread_count, + num_classes, num_samples, scores.data(), maxima.data(), output.data())); } }; diff --git a/tensorflow/core/kernels/mutex_ops.cc b/tensorflow/core/kernels/mutex_ops.cc index 2f4a5e9aa03..0cc29b42d93 100644 --- a/tensorflow/core/kernels/mutex_ops.cc +++ b/tensorflow/core/kernels/mutex_ops.cc @@ -74,6 +74,8 @@ class Mutex : public ResourceBase { struct SharedLockReleaser { std::shared_ptr shared_lock; + SharedLockReleaser() : shared_lock() {} + explicit SharedLockReleaser(std::shared_ptr&& lock) : shared_lock(std::forward(lock)) { VLOG(3) << "Creating shared_ptr of " << shared_lock.get() @@ -86,6 +88,16 @@ class Mutex : public ResourceBase { << " count is: " << shared_lock.use_count(); } + SharedLockReleaser& operator=(const SharedLockReleaser& rhs) = delete; + + SharedLockReleaser& operator=(SharedLockReleaser&& rhs) { + if (&rhs == this) return *this; + std::swap(shared_lock, rhs.shared_lock); + VLOG(3) << "Move-assign of SharedLockReleaser of " << shared_lock.get() + << " count is: " << 
shared_lock.use_count(); + return *this; + } + SharedLockReleaser(const SharedLockReleaser& rhs) : shared_lock(rhs.shared_lock) { VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get() diff --git a/tensorflow/core/kernels/nn_ops_test.cc b/tensorflow/core/kernels/nn_ops_test.cc index a841291ddd7..e977aa51afb 100644 --- a/tensorflow/core/kernels/nn_ops_test.cc +++ b/tensorflow/core/kernels/nn_ops_test.cc @@ -15,9 +15,12 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "tensorflow/cc/ops/nn_ops.h" #include #include @@ -26,10 +29,8 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/fake_input.h" @@ -735,8 +736,8 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows, DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); thread::ThreadPool threadpool(Env::Default(), "test", num_threads); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(), + num_threads); device->set_eigen_cpu_device(&eigen_cpu_device); gtl::InlinedVector inputs; @@ -817,8 +818,8 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth, DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); thread::ThreadPool threadpool(Env::Default(), "test", num_threads); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(), + num_threads); device->set_eigen_cpu_device(&eigen_cpu_device); gtl::InlinedVector inputs; @@ -909,8 +910,8 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols, DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); thread::ThreadPool threadpool(Env::Default(), "test", num_threads); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(), + num_threads); device->set_eigen_cpu_device(&eigen_cpu_device); gtl::InlinedVector inputs; @@ -1013,8 +1014,8 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth, DeviceFactory::NewDevice("CPU", options, "/job:a/replica:0/task:0")); thread::ThreadPool threadpool(Env::Default(), "test", num_threads); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(), + num_threads); device->set_eigen_cpu_device(&eigen_cpu_device); gtl::InlinedVector inputs; @@ -1193,8 +1194,8 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols, DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); thread::ThreadPool threadpool(Env::Default(), "test", num_threads); - 
EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads); + Eigen::ThreadPoolDevice eigen_cpu_device(threadpool.AsEigenThreadPool(), + num_threads); device->set_eigen_cpu_device(&eigen_cpu_device); gtl::InlinedVector inputs; diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 482b227ccdc..5d46e5bb209 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -227,14 +227,11 @@ void BatchedNonMaxSuppressionOp( OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores, int num_boxes, const int max_size_per_class, const int total_size_per_batch, const float score_threshold, const float iou_threshold, - bool pad_per_class = false) { + bool pad_per_class = false, bool clip_boxes = true) { int q = inp_boxes.dim_size(2); int num_classes = inp_scores.dim_size(2); const int num_batches = inp_boxes.dim_size(0); - // Default clip window of [0, 0, 1, 1] if none specified - std::vector clip_window{0, 0, 1, 1}; - // [num_batches, per_batch_size * 4] std::vector> nmsed_boxes(num_batches); // [num_batches, per_batch_size] @@ -375,18 +372,23 @@ void BatchedNonMaxSuppressionOp( while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) { ResultCandidate next_candidate = result_candidate_vec[result_idx++]; // Add to final output vectors - nmsed_boxes[batch].push_back( - std::max(std::min(next_candidate.box_coord[0], clip_window[2]), - clip_window[0])); - nmsed_boxes[batch].push_back( - std::max(std::min(next_candidate.box_coord[1], clip_window[3]), - clip_window[1])); - nmsed_boxes[batch].push_back( - std::max(std::min(next_candidate.box_coord[2], clip_window[2]), - clip_window[0])); - nmsed_boxes[batch].push_back( - std::max(std::min(next_candidate.box_coord[3], clip_window[3]), - clip_window[1])); + if (clip_boxes) { + const float box_min = 0.0; + const float box_max = 1.0; + nmsed_boxes[batch].push_back( + std::max(std::min(next_candidate.box_coord[0], box_max), box_min)); + nmsed_boxes[batch].push_back( + std::max(std::min(next_candidate.box_coord[1], box_max), box_min)); + nmsed_boxes[batch].push_back( + std::max(std::min(next_candidate.box_coord[2], box_max), box_min)); + nmsed_boxes[batch].push_back( + std::max(std::min(next_candidate.box_coord[3], box_max), box_min)); + } else { + nmsed_boxes[batch].push_back(next_candidate.box_coord[0]); + nmsed_boxes[batch].push_back(next_candidate.box_coord[1]); + nmsed_boxes[batch].push_back(next_candidate.box_coord[2]); + nmsed_boxes[batch].push_back(next_candidate.box_coord[3]); + } nmsed_scores[batch].push_back(next_candidate.score); nmsed_classes[batch].push_back(next_candidate.class_idx); curr_total_size--; @@ -679,6 +681,7 @@ class CombinedNonMaxSuppressionOp : public OpKernel { explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_)); + OP_REQUIRES_OK(context, context->GetAttr("clip_boxes", &clip_boxes_)); } void Compute(OpKernelContext* context) override { @@ -734,11 +737,12 @@ class CombinedNonMaxSuppressionOp : public OpKernel { BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes, max_size_per_class, max_total_size_per_batch, score_threshold_val, iou_threshold_val, - pad_per_class_); + pad_per_class_, clip_boxes_); } private: bool pad_per_class_; + bool clip_boxes_; }; 
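The new clip_boxes attribute only changes the final emission step: when set, each selected coordinate is clamped to the unit interval before being appended, otherwise it is passed through unchanged. A small illustrative helper (not the kernel itself; the box is assumed to hold four corner coordinates):

#include <algorithm>
#include <vector>

// Appends one box to `out`, clamping each coordinate to [0, 1] when
// clip_boxes is set, mirroring the branch added above.
void AppendBox(const float box_coord[4], bool clip_boxes,
               std::vector<float>* out) {
  for (int i = 0; i < 4; ++i) {
    const float c = clip_boxes
                        ? std::max(std::min(box_coord[i], 1.0f), 0.0f)
                        : box_coord[i];
    out->push_back(c);
  }
}

Keeping clipping as the default preserves the previous behaviour, which always clamped against the removed implicit [0, 0, 1, 1] window.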
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc index 242e41b2652..0458a400b26 100644 --- a/tensorflow/core/kernels/non_max_suppression_op_test.cc +++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc @@ -863,7 +863,7 @@ TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestEmptyInput) { class CombinedNonMaxSuppressionOpTest : public OpsTestBase { protected: - void MakeOp(bool pad_per_class = false) { + void MakeOp(bool pad_per_class = false, bool clip_boxes = true) { TF_EXPECT_OK(NodeDefBuilder("combined_non_max_suppression_op", "CombinedNonMaxSuppression") .Input(FakeInput(DT_FLOAT)) @@ -873,6 +873,7 @@ class CombinedNonMaxSuppressionOpTest : public OpsTestBase { .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) .Attr("pad_per_class", pad_per_class) + .Attr("clip_boxes", clip_boxes) .Finalize(node_def())); TF_EXPECT_OK(InitOp()); } @@ -942,6 +943,39 @@ TEST_F(CombinedNonMaxSuppressionOpTest, TestSelectFromThreeClusters) { test::ExpectTensorEqual(expected_valid_d, *GetOutput(3)); } +TEST_F(CombinedNonMaxSuppressionOpTest, + TestSelectFromThreeClustersNoBoxClipping) { + MakeOp(false, false); + AddInputFromArray(TensorShape({1, 6, 1, 4}), + {0, 0, 10, 10, 0, 1, 10, 11, 0, 1, 10, 9, + 0, 11, 10, 20, 0, 12, 10, 21, 0, 30, 100, 40}); + AddInputFromArray(TensorShape({1, 6, 1}), + {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + // boxes + Tensor expected_boxes(allocator(), DT_FLOAT, TensorShape({1, 3, 4})); + test::FillValues(&expected_boxes, + {0, 11, 10, 20, 0, 0, 10, 10, 0, 30, 100, 40}); + test::ExpectTensorEqual(expected_boxes, *GetOutput(0)); + // scores + Tensor expected_scores(allocator(), DT_FLOAT, TensorShape({1, 3})); + test::FillValues(&expected_scores, {0.95, 0.9, 0.3}); + test::ExpectTensorEqual(expected_scores, *GetOutput(1)); + // classes + Tensor expected_classes(allocator(), DT_FLOAT, TensorShape({1, 3})); + test::FillValues(&expected_classes, {0, 0, 0}); + test::ExpectTensorEqual(expected_classes, *GetOutput(2)); + // valid + Tensor expected_valid_d(allocator(), DT_INT32, TensorShape({1})); + test::FillValues(&expected_valid_d, {3}); + test::ExpectTensorEqual(expected_valid_d, *GetOutput(3)); +} + TEST_F(CombinedNonMaxSuppressionOpTest, TestSelectFromThreeClustersWithScoreThreshold) { MakeOp(); diff --git a/tensorflow/core/kernels/one_hot_op.cc b/tensorflow/core/kernels/one_hot_op.cc index c3385091a0d..0548e389b7a 100644 --- a/tensorflow/core/kernels/one_hot_op.cc +++ b/tensorflow/core/kernels/one_hot_op.cc @@ -17,9 +17,10 @@ limitations under the License. 
#define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/one_hot_op.h" @@ -103,7 +104,7 @@ class OneHotOp : public OpKernel { for (int i = 0; i < axis; ++i) { prefix_dim_size *= indices_shape.dim_size(i); } - TI suffix_dim_size = indices_shape.num_elements() / prefix_dim_size; + int64 suffix_dim_size = indices_shape.num_elements() / prefix_dim_size; // Split indices into matrix of size prefix_dim_size x suffix_dim_size auto indices_t = @@ -140,7 +141,8 @@ class OneHotOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_ONE_HOT); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) // Forward declarations of the functor specializations for GPU. namespace functor { @@ -190,6 +192,6 @@ TF_CALL_int64(REGISTER_ONE_HOT_GPU); #undef REGISTER_ONE_HOT_GPU_INDEX #undef REGISTER_ONE_HOT_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc index b7a6da61de1..83ba272433f 100644 --- a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc +++ b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc @@ -15,7 +15,8 @@ limitations under the License. // See docs in ../ops/array_ops.cc -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU @@ -46,4 +47,4 @@ TF_CALL_int64(DEFINE_GPU_SPEC); } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h index 5d607b90446..3c3836352e8 100644 --- a/tensorflow/core/kernels/ops_testutil.h +++ b/tensorflow/core/kernels/ops_testutil.h @@ -140,14 +140,13 @@ class OpsTestBase : public ::testing::Test { CHECK_GT(input_types_.size(), inputs_.size()) << "Adding more inputs than types; perhaps you need to call MakeOp"; ResourceMgr* rm = device_->resource_manager(); - EXPECT_TRUE( - rm->Create(container == "" ? rm->default_container() : container, name, - resource) - .ok()); + std::string container_name = + container == "" ? rm->default_container() : container; + EXPECT_TRUE(rm->Create(container_name, name, resource).ok()); TypeIndex type_index = MakeTypeIndex(); ResourceHandle handle; handle.set_device(device_->name()); - handle.set_container(container); + handle.set_container(container_name); handle.set_name(name); handle.set_hash_code(type_index.hash_code()); handle.set_maybe_type_name(type_index.name()); diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc index 691430ebaff..a55b4afb9c8 100644 --- a/tensorflow/core/kernels/pad_op.cc +++ b/tensorflow/core/kernels/pad_op.cc @@ -294,7 +294,8 @@ TF_CALL_POD_TYPES(REGISTER_KERNEL); TF_CALL_string(REGISTER_KERNEL); #undef REGISTER_KERNEL -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) // Forward declarations of the functor specializations for GPU. 
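The suffix_dim_size change above switches the computed extent to int64 because the index tensor may hold more elements than fit in the 32-bit index type TI. A sketch of the prefix/suffix split (illustrative helper, not the op):

#include <cstdint>
#include <vector>

// Splits an index shape at `axis` into [prefix, suffix] extents. Using int64
// for the suffix avoids truncation when the element count exceeds INT32_MAX.
void SplitAtAxis(const std::vector<int64_t>& dims, int axis,
                 int64_t* prefix_dim_size, int64_t* suffix_dim_size) {
  int64_t num_elements = 1;
  for (int64_t d : dims) num_elements *= d;
  *prefix_dim_size = 1;
  for (int i = 0; i < axis; ++i) *prefix_dim_size *= dims[i];
  if (*prefix_dim_size == 0) {  // guard against a zero-sized prefix dimension
    *suffix_dim_size = 0;
    return;
  }
  *suffix_dim_size = num_elements / *prefix_dim_size;
}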
namespace functor { #define DECLARE_GPU_SPEC(T, Dims) \ @@ -395,7 +396,7 @@ REGISTER_KERNEL_BUILDER(Name("PadV2") .HostMemory("constant_values") .HostMemory("output"), PadOp); -#endif +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL // Registration of the GPU implementations. diff --git a/tensorflow/core/kernels/pad_op_gpu.cu.cc b/tensorflow/core/kernels/pad_op_gpu.cu.cc index 0cd8ef17ba2..ddc12417a91 100644 --- a/tensorflow/core/kernels/pad_op_gpu.cu.cc +++ b/tensorflow/core/kernels/pad_op_gpu.cu.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU @@ -45,4 +46,4 @@ TF_CALL_uint8(DEFINE_GPU_SPECS); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc index 4abbe2fe3b7..d6d2be12391 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc @@ -17,17 +17,17 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h" - #include #include + #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" #if defined(_MSC_VER) && !defined(__clang__) // msvc does not support unroll. One could try the loop pragma but we need to @@ -161,7 +161,7 @@ __global__ void __launch_bounds__(1024) Eigen::array z; Eigen::array g; - const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin; + const T plusFactor = (normMin < T(0)) ? T(0) : T(normMin * normMin); int numIterations = 0; while (numIterations < kMaxIterations) { @@ -240,7 +240,7 @@ struct TruncatedNormalFunctor { typename TTypes::ConstFlat maxvals, const random::PhiloxRandom& gen, typename TTypes::Flat output) { - const auto config = GetCudaLaunchConfig(num_elements, d); + const auto config = GetGpuLaunchConfig(num_elements, d); TF_CHECK_OK(CudaLaunchKernel( TruncatedNormalKernel, config.block_count, config.thread_per_block, diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 4866efef9db..a398e6cbaee 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -12,253 +12,266 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/core/kernels/partitioned_function_ops.h" + #include "absl/strings/match.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/util/ptr_util.h" +#ifndef __ANDROID__ +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#endif #if GOOGLE_CUDA #include "tensorflow/stream_executor/stream.h" #endif // GOOGLE_CUDA namespace tensorflow { -namespace { -// A `PartitionedCallOp` asynchronously executes a function, potentially across -// multiple devices but within a single process. The kernel places and -// partitions a given function's underlying graph, and executes each of the -// partitioned subgraphs as a function. -// -// TODO(akshayka): Support distributed execution. -class PartitionedCallOp : public AsyncOpKernel { - public: - explicit PartitionedCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); - string deprecated_config_serialized; - OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &deprecated_config_serialized)); - string config_proto_serialized; - OP_REQUIRES_OK(ctx, ctx->GetAttr("config_proto", &config_proto_serialized)); + +PartitionedCallOp::PartitionedCallOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + func_(new NameAttrList), + config_proto_(new ConfigProto) { + OP_REQUIRES_OK( + ctx, ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, func_.get())); + string deprecated_config_serialized; + OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &deprecated_config_serialized)); + string config_proto_serialized; + OP_REQUIRES_OK(ctx, ctx->GetAttr("config_proto", &config_proto_serialized)); + OP_REQUIRES( + ctx, + deprecated_config_serialized.empty() || config_proto_serialized.empty(), + errors::InvalidArgument("Provided both 'config' and 'config_proto' but " + "only one should be provided. Note the " + "'config' option is deprecated.")); + if (!deprecated_config_serialized.empty()) { + OP_REQUIRES(ctx, + config_proto_->mutable_graph_options() + ->mutable_rewrite_options() + ->ParseFromString(deprecated_config_serialized), + errors::InvalidArgument("Unable to parse config string as " + "tensorflow::RewriteOptions proto.")); + } else { OP_REQUIRES( - ctx, - deprecated_config_serialized.empty() || config_proto_serialized.empty(), - errors::InvalidArgument("Provided both 'config' and 'config_proto' but " - "only one should be provided. 
Note the " - "'config' option is deprecated.")); - if (!deprecated_config_serialized.empty()) { - OP_REQUIRES(ctx, - config_proto_.mutable_graph_options() - ->mutable_rewrite_options() - ->ParseFromString(deprecated_config_serialized), - errors::InvalidArgument("Unable to parse config string as " - "tensorflow::RewriteOptions proto.")); + ctx, config_proto_->ParseFromString(config_proto_serialized), + errors::InvalidArgument("Unable to parse config_proto string as " + "tensorflow::ConfigProto proto.")); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("executor_type", &executor_type_)); +} + +PartitionedCallOp::~PartitionedCallOp() { + for (const auto& it : handles_) { + Status status = it.first->ReleaseHandle(it.second); + if (!status.ok()) { + LOG(INFO) << "Ignoring error while destructing PartitionedCallOp: " + << status.ToString(); + } + } +} + +void PartitionedCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { + FunctionLibraryRuntime* lib = ctx->function_library(); + OP_REQUIRES_ASYNC(ctx, lib != nullptr, + errors::Internal("No function library is provided."), done); + + // The function body's graph is placed and partitioned the first time + // `ComputeAsync` is invoked; every subsequent invocation calls each + // of the function shards yielded by partitioning. + // + // The partitioning step yields a set of devices on which to run the + // function, and exactly one function shard is created for each device + // Inputs and outputs are pinned to the local device, for simplicity. + // + // TODO(akshayka): Support re-sharding the function on subsequent calls, + // via, e.g., virtual device annotations and a list of device names + // supplied through an attribute. + // + // TODO(akshayka): Add a fastpath for functions that execute on a single + // device. + FunctionLibraryRuntime::Handle handle; + // If we are instantiating the function, we can efficiently extract the + // inputs while instantiating. Else, we extract them separately below. + std::vector inputs; + bool inputs_extracted = false; + { + mutex_lock l(mu_); + auto it = handles_.find(lib); + if (it == handles_.end()) { + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, ctx, &inputs, &handle), done); + inputs_extracted = true; + handles_[lib] = handle; } else { - OP_REQUIRES( - ctx, config_proto_.ParseFromString(config_proto_serialized), - errors::InvalidArgument("Unable to parse config_proto string as " - "tensorflow::ConfigProto proto.")); - } - OP_REQUIRES_OK(ctx, ctx->GetAttr("executor_type", &executor_type_)); - } - - ~PartitionedCallOp() override { - for (const auto& it : handles_) { - Status status = it.first->ReleaseHandle(it.second); - if (!status.ok()) { - LOG(INFO) << "Ignoring error while destructing PartitionedCallOp: " - << status.ToString(); - } + handle = it->second; } } - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - FunctionLibraryRuntime* lib = ctx->function_library(); - OP_REQUIRES_ASYNC(ctx, lib != nullptr, - errors::Internal("No function library is provided."), - done); - - // The function body's graph is placed and partitioned the first time - // `ComputeAsync` is invoked; every subsequent invocation calls each - // of the function shards yielded by partitioning. - // - // The partitioning step yields a set of devices on which to run the - // function, and exactly one function shard is created for each device - // Inputs and outputs are pinned to the local device, for simplicity. 
- // - // TODO(akshayka): Support re-sharding the function on subsequent calls, - // via, e.g., virtual device annotations and a list of device names - // supplied through an attribute. - // - // TODO(akshayka): Add a fastpath for functions that execute on a single - // device. - FunctionLibraryRuntime::Handle handle; - // If we are instantiating the function, we can efficiently extract the - // inputs while instantiating. Else, we extract them separately below. - std::vector inputs; - bool inputs_extracted = false; - { - mutex_lock l(mu_); - auto it = handles_.find(lib); - if (it == handles_.end()) { - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, ctx, &inputs, &handle), - done); - inputs_extracted = true; - handles_[lib] = handle; - } else { - handle = it->second; - } - } - - if (!inputs_extracted) { - OpInputList args; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done); - inputs.reserve(args.size()); - for (const Tensor& tensor : args) { - inputs.push_back(tensor); - } - } - - RunFunction(handle, inputs, lib, ctx, done); - } - - private: - Status FillOutputDevices(const FunctionLibraryRuntime& lib, - const Device& cpu_device, AttrSlice attrs, - FunctionLibraryRuntime::InstantiateOptions* opts) { - const FunctionLibraryDefinition* flib = lib.GetFunctionLibraryDefinition(); - const FunctionDef* fdef = flib->Find(func_.name()); - if (fdef == nullptr) { - return errors::NotFound("Failed for find definiton for function \"", - func_.name(), "\""); - } - - bool is_type_list; - for (const OpDef::ArgDef& ret_def : fdef->signature().output_arg()) { - DataTypeVector dtypes; - TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes)); - for (DataType dtype : dtypes) { - if (MTypeFromDType(dtype) == HOST_MEMORY) { - opts->output_devices.push_back(cpu_device.name()); - } else { - opts->output_devices.push_back(opts->target); - } - } - } - return Status::OK(); - } - - Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx, - std::vector* inputs, - FunctionLibraryRuntime::Handle* handle) { - grappler::GrapplerItem::OptimizationOptions optimization_options; - - // Tensorflow 2.0 in eager mode with automatic control dependencies will - // prune all nodes that are not in the transitive fanin of the fetch nodes. - // However because the function will be executed via FunctionLibraryRuntime, - // and current function implementation does not prune stateful and dataset - // ops, we rely on Grappler to do the correct graph pruning. - optimization_options.allow_pruning_stateful_and_dataset_ops = true; - - // All the nested function calls will be executed and optimized via - // PartitionedCallOp, there is no need to optimize functions now. 
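Both FillOutputDevices and the input-device loop reduce to the same pinning rule: resource inputs follow the device recorded in their handle, host-memory values are pinned to the CPU device, and everything else follows the instantiation target. An illustrative, simplified version (the enum and parameter names are invented for the sketch):

#include <string>

enum class ArgKind { kResourceHandle, kHostMemory, kDeviceMemory };

std::string PickDevice(ArgKind kind, const std::string& handle_device,
                       const std::string& cpu_device,
                       const std::string& target_device) {
  switch (kind) {
    case ArgKind::kResourceHandle: return handle_device;
    case ArgKind::kHostMemory:     return cpu_device;
    case ArgKind::kDeviceMemory:   return target_device;
  }
  return target_device;  // unreachable; keeps compilers happy
}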
- optimization_options.optimize_function_library = false; - - FunctionLibraryRuntime::InstantiateOptions opts; - opts.target = lib->device()->name(); - opts.is_multi_device_function = true; - opts.optimize_graph_fn = - std::bind(grappler::OptimizeGraph, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3, - std::placeholders::_4, std::placeholders::_5, config_proto_, - func_.name(), optimization_options, std::placeholders::_6); - opts.graph_collector = ctx->graph_collector(); - opts.executor_type = executor_type_; - + if (!inputs_extracted) { OpInputList args; - TF_RETURN_IF_ERROR(ctx->input_list("args", &args)); - Device* cpu_device; - TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device)); - - inputs->reserve(args.size()); + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done); + inputs.reserve(args.size()); for (const Tensor& tensor : args) { - inputs->push_back(tensor); - DataType dtype = tensor.dtype(); - if (dtype == DT_RESOURCE) { - const ResourceHandle& handle = tensor.flat()(0); - opts.input_devices.push_back(handle.device()); - } else if (MTypeFromDType(dtype) == HOST_MEMORY) { - opts.input_devices.push_back(cpu_device->name()); + inputs.push_back(tensor); + } + } + + RunFunction(handle, inputs, lib, ctx, done); +} + +Status PartitionedCallOp::FillOutputDevices( + const FunctionLibraryRuntime& lib, const Device& cpu_device, + AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions* opts) { + const FunctionLibraryDefinition* flib = lib.GetFunctionLibraryDefinition(); + const FunctionDef* fdef = flib->Find(func_->name()); + if (fdef == nullptr) { + return errors::NotFound("Failed for find definition for function \"", + func_->name(), "\""); + } + + bool is_type_list; + for (const OpDef::ArgDef& ret_def : fdef->signature().output_arg()) { + DataTypeVector dtypes; + TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes)); + for (DataType dtype : dtypes) { + if (MTypeFromDType(dtype) == HOST_MEMORY) { + opts->output_devices.push_back(cpu_device.name()); } else { - opts.input_devices.push_back(opts.target); + opts->output_devices.push_back(opts->target); } } + } + return Status::OK(); +} - TF_RETURN_IF_ERROR( - FillOutputDevices(*lib, *cpu_device, AttrSlice(&func_.attr()), &opts)); +Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib, + OpKernelContext* ctx, + std::vector* inputs, + FunctionLibraryRuntime::Handle* handle) { + FunctionLibraryRuntime::InstantiateOptions opts; - TF_RETURN_IF_ERROR( - lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts, handle)); - return Status::OK(); +#ifndef __ANDROID__ + // Android tf library does not include grappler. + grappler::GrapplerItem::OptimizationOptions optimization_options; + // Tensorflow 2.0 in eager mode with automatic control dependencies will + // prune all nodes that are not in the transitive fanin of the fetch nodes. + // However because the function will be executed via FunctionLibraryRuntime, + // and current function implementation does not prune stateful and dataset + // ops, we rely on Grappler to do the correct graph pruning. + optimization_options.allow_pruning_stateful_and_dataset_ops = true; + + // All the nested function calls will be executed and optimized via + // PartitionedCallOp, there is no need to optimize functions now. 
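The optimize_graph_fn wiring relies on std::bind with placeholders: configuration values are fixed up front while the per-graph arguments remain open. A self-contained sketch of the same pattern with stand-in types:

#include <functional>

int OptimizeGraph(int graph, int lib, int cfg, int opt_level) {
  return graph + lib + cfg + opt_level;  // stand-in body
}

// Binds the fixed configuration; the two leading arguments stay open.
std::function<int(int, int)> MakeOptimizeFn(int cfg, int opt_level) {
  using std::placeholders::_1;
  using std::placeholders::_2;
  return std::bind(OptimizeGraph, _1, _2, cfg, opt_level);
}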
+ optimization_options.optimize_function_library = false; + + opts.optimize_graph_fn = + std::bind(grappler::OptimizeGraph, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, + std::placeholders::_4, std::placeholders::_5, *config_proto_, + func_->name(), optimization_options, std::placeholders::_6); +#endif + + // In some contexts like running the graph to evaluate constants, + // the FLR won't have any device. + opts.target = lib->device() == nullptr ? "" : lib->device()->name(); + opts.is_multi_device_function = true; + opts.graph_collector = ctx->graph_collector(); + opts.executor_type = executor_type_; + + OpInputList args; + TF_RETURN_IF_ERROR(ctx->input_list("args", &args)); + Device* cpu_device; + TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device)); + + inputs->reserve(args.size()); + for (const Tensor& tensor : args) { + inputs->push_back(tensor); + DataType dtype = tensor.dtype(); + if (dtype == DT_RESOURCE) { + const ResourceHandle& handle = tensor.flat()(0); + opts.input_devices.push_back(handle.device()); + } else if (MTypeFromDType(dtype) == HOST_MEMORY) { + opts.input_devices.push_back(cpu_device->name()); + } else { + opts.input_devices.push_back(opts.target); + } } - void RunFunction(FunctionLibraryRuntime::Handle handle, - const std::vector& inputs, - FunctionLibraryRuntime* lib, OpKernelContext* ctx, - DoneCallback done) { - FunctionLibraryRuntime::Options run_opts; - run_opts.step_id = ctx->step_id(); - run_opts.step_container = ctx->step_container(); - run_opts.cancellation_manager = ctx->cancellation_manager(); - run_opts.stats_collector = ctx->stats_collector(); - run_opts.collective_executor = ctx->collective_executor(); - // TODO(akshayka): Consider selecting a runner on a per-device basis, - // i.e., using device-specific threadpools when available. - run_opts.runner = ctx->runner(); - run_opts.source_device = lib->device()->name(); - run_opts.allow_dead_tensors = true; - // TODO(akshayka): Accommodate the multiple-worker scenario by adding the - // constructed rendezvous to a rendezvous manager. 
- Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr()); - run_opts.rendezvous = rendez; + TF_RETURN_IF_ERROR( + FillOutputDevices(*lib, *cpu_device, AttrSlice(&func_->attr()), &opts)); - std::vector* rets = new std::vector; - const string& func_name = func_.name(); - lib->Run(run_opts, handle, inputs, rets, - [rets, rendez, done, ctx, func_name](const Status& status) { - if (!status.ok()) { - const string function_and_msg = - strings::StrCat(errors::FormatFunctionForError(func_name), - " ", status.error_message()); - ctx->SetStatus(Status(status.code(), function_and_msg)); - } else { - for (int i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); - } + TF_RETURN_IF_ERROR( + lib->Instantiate(func_->name(), AttrSlice(&func_->attr()), opts, handle)); + return Status::OK(); +} + +void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle, + const std::vector& inputs, + FunctionLibraryRuntime* lib, + OpKernelContext* ctx, DoneCallback done) { + FunctionLibraryRuntime::Options run_opts; + ResourceMgr* resource_mgr = lib->device()->resource_manager(); + ScopedStepContainer* step_container = new ScopedStepContainer( + run_opts.step_id, [resource_mgr](const string& name) { + resource_mgr->Cleanup(name).IgnoreError(); + }); + run_opts.step_container = step_container; + run_opts.cancellation_manager = ctx->cancellation_manager(); + run_opts.stats_collector = ctx->stats_collector(); + run_opts.collective_executor = ctx->collective_executor(); + // TODO(akshayka): Consider selecting a runner on a per-device basis, + // i.e., using device-specific threadpools when available. + run_opts.runner = ctx->runner(); + run_opts.source_device = + lib->device() == nullptr ? "" : lib->device()->name(); + run_opts.allow_dead_tensors = true; + + Rendezvous* rendez; + OP_REQUIRES_OK_ASYNC( + ctx, + ctx->create_rendezvous(run_opts.step_id, + ctx->function_library()->device_mgr(), &rendez), + done); + run_opts.rendezvous = rendez; + + std::vector* rets = new std::vector; + const string& func_name = func_->name(); + profiler::TraceMe trace_me( + [&] { + return absl::StrCat( + "PartitionedCallOp #parent_step_id=", ctx->step_id(), + ",function_step_id=", run_opts.step_id, "#"); + }, + /*level=*/2); + lib->Run(run_opts, handle, inputs, rets, + [rets, rendez, done, ctx, func_name, + step_container](const Status& status) { + if (!status.ok()) { + const string function_and_msg = + strings::StrCat(errors::FormatFunctionForError(func_name), + " ", status.error_message()); + ctx->SetStatus(Status(status.code(), function_and_msg)); + } else { + for (int i = 0; i < rets->size(); ++i) { + ctx->set_output(i, (*rets)[i]); } - delete rets; - rendez->Unref(); - done(); - }); - } - - NameAttrList func_; - ConfigProto config_proto_; - string executor_type_; - mutex mu_; - // Cache the handle per FLR because this kernel may be instantiated for - // a stateful op, different invocations of it may use different FLRs. - // Different device placements of PartitionedCallOp also use - // different FLRs. 
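The new RunFunction owns a per-step container whose cleanup callback releases step resources once the asynchronous call completes. A minimal sketch of that cleanup-on-completion idiom (not the TensorFlow classes):

#include <functional>
#include <utility>

class ScopedCleanup {
 public:
  explicit ScopedCleanup(std::function<void()> fn) : fn_(std::move(fn)) {}
  ~ScopedCleanup() {
    if (fn_) fn_();  // run the cleanup exactly once, at end of scope
  }
  ScopedCleanup(const ScopedCleanup&) = delete;
  ScopedCleanup& operator=(const ScopedCleanup&) = delete;

 private:
  std::function<void()> fn_;
};

In the kernel the container and rendezvous are heap-allocated and released inside the done callback instead, since the function call outlives RunFunction's stack frame.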
- gtl::FlatMap handles_ - GUARDED_BY(mu_); -}; + } + delete rets; + delete step_container; + rendez->Unref(); + done(); + }); +} REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU), PartitionedCallOp); @@ -275,5 +288,4 @@ REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_SYCL), PartitionedCallOp); #endif // TENSORFLOW_USE_SYCL -} // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/partitioned_function_ops.h b/tensorflow/core/kernels/partitioned_function_ops.h new file mode 100644 index 00000000000..776ebab9695 --- /dev/null +++ b/tensorflow/core/kernels/partitioned_function_ops.h @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_ + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +class NameAttrList; +class ConfigProto; + +// A `PartitionedCallOp` asynchronously executes a function, potentially across +// multiple devices but within a single process. The kernel places and +// partitions a given function's underlying graph, and executes each of the +// partitioned subgraphs as a function. +// +// TODO(akshayka): Support distributed execution. +class PartitionedCallOp : public AsyncOpKernel { + public: + explicit PartitionedCallOp(OpKernelConstruction* ctx); + + ~PartitionedCallOp() override; + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + Status FillOutputDevices(const FunctionLibraryRuntime& lib, + const Device& cpu_device, AttrSlice attrs, + FunctionLibraryRuntime::InstantiateOptions* opts); + + Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx, + std::vector* inputs, + FunctionLibraryRuntime::Handle* handle); + + void RunFunction(FunctionLibraryRuntime::Handle handle, + const std::vector& inputs, + FunctionLibraryRuntime* lib, OpKernelContext* ctx, + DoneCallback done); + + // Using unique pointers to avoid including proto headers in kernel headers + std::unique_ptr func_; + std::unique_ptr config_proto_; + string executor_type_; + mutex mu_; + // Cache the handle per FLR because this kernel may be instantiated for + // a stateful op, different invocations of it may use different FLRs. + // Different device placements of PartitionedCallOp also use + // different FLRs. 
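As the header comment notes, the proto-typed members are held through std::unique_ptr behind forward declarations so the kernel header does not pull in generated proto headers. A generic sketch of that pattern (hypothetical names):

// foo_kernel.h -- sketch only
#include <memory>

class ConfigLikeProto;  // forward declaration, no proto header needed here

class FooKernel {
 public:
  FooKernel();
  ~FooKernel();  // must be defined in the .cc, where the type is complete

 private:
  std::unique_ptr<ConfigLikeProto> config_;
};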
+ gtl::FlatMap handles_ + GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_ diff --git a/tensorflow/core/kernels/pooling_ops_3d_gpu.cu.cc b/tensorflow/core/kernels/pooling_ops_3d_gpu.cu.cc index 341a43c368e..1b28d8b5923 100644 --- a/tensorflow/core/kernels/pooling_ops_3d_gpu.cu.cc +++ b/tensorflow/core/kernels/pooling_ops_3d_gpu.cu.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/pooling_ops_3d_gpu.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -142,21 +142,21 @@ bool MaxPool3dGradBackward::operator()( const T* top_diff, T* bottom_diff, const Eigen::GpuDevice& d) { int num_kernels = batch * channels * pooled_plane * pooled_height * pooled_width; - CudaLaunchConfig config = GetCudaLaunchConfig(num_kernels, d); + GpuLaunchConfig config = GetCudaLaunchConfig(num_kernels, d); if (data_format == FORMAT_NHWC) { - MaxPoolGradBackwardNoMaskNDHWC<<>>( - num_kernels, bottom_data, output_data, pooled_plane, pooled_height, - pooled_width, channels, plane, height, width, kernel_p, kernel_h, - kernel_w, stride_p, stride_h, stride_w, pad_p, pad_t, pad_l, top_diff, - bottom_diff); + TF_CHECK_OK(CudaLaunchKernel( + MaxPoolGradBackwardNoMaskNDHWC, config.block_count, + config.thread_per_block, 0, d.stream(), num_kernels, bottom_data, + output_data, pooled_plane, pooled_height, pooled_width, channels, plane, + height, width, kernel_p, kernel_h, kernel_w, stride_p, stride_h, + stride_w, pad_p, pad_t, pad_l, top_diff, bottom_diff)); } else { - MaxPoolGradBackwardNoMaskNCDHW<<>>( - num_kernels, bottom_data, output_data, pooled_plane, pooled_height, - pooled_width, channels, plane, height, width, kernel_p, kernel_h, - kernel_w, stride_p, stride_h, stride_w, pad_p, pad_t, pad_l, top_diff, - bottom_diff); + TF_CHECK_OK(CudaLaunchKernel( + MaxPoolGradBackwardNoMaskNCDHW, config.block_count, + config.thread_per_block, 0, d.stream(), num_kernels, bottom_data, + output_data, pooled_plane, pooled_height, pooled_width, channels, plane, + height, width, kernel_p, kernel_h, kernel_w, stride_p, stride_h, + stride_w, pad_p, pad_t, pad_l, top_diff, bottom_diff)); } return d.ok(); } diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 903cf9313a2..01a353cb175 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #if GOOGLE_CUDA -#include "cuda/include/cudnn.h" +#include "third_party/gpus/cudnn/cudnn.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/pooling_ops_common_gpu.h" #include "tensorflow/core/platform/stream_executor.h" diff --git a/tensorflow/core/kernels/pooling_ops_common_gpu.h b/tensorflow/core/kernels/pooling_ops_common_gpu.h index 7362c5275f7..9685bd9fdd0 100644 --- a/tensorflow/core/kernels/pooling_ops_common_gpu.h +++ b/tensorflow/core/kernels/pooling_ops_common_gpu.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
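The GetCudaLaunchConfig/GetGpuLaunchConfig calls above produce a block count, threads per block, and a virtual thread count for a grid-stride kernel. Roughly the following arithmetic, shown with illustrative defaults; the real helper also consults device properties such as occupancy:

struct LaunchConfig {
  int virtual_thread_count;
  int thread_per_block;
  int block_count;
};

LaunchConfig MakeLaunchConfig(int work_items, int thread_per_block = 256,
                              int max_block_count = 1024) {
  LaunchConfig config;
  config.virtual_thread_count = work_items;
  config.thread_per_block = thread_per_block;
  const int needed = (work_items + thread_per_block - 1) / thread_per_block;
  config.block_count = needed < max_block_count ? needed : max_block_count;
  return config;
}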
==============================================================================*/ -#if !GOOGLE_CUDA -#error This file must only be included when building with Cuda support +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda or ROCm support #endif #ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_ diff --git a/tensorflow/core/kernels/population_count_op_gpu.cu.cc b/tensorflow/core/kernels/population_count_op_gpu.cu.cc index b9a7da56872..22beadfe61a 100644 --- a/tensorflow/core/kernels/population_count_op_gpu.cu.cc +++ b/tensorflow/core/kernels/population_count_op_gpu.cu.cc @@ -18,14 +18,13 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/population_count_op.h" - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/population_count_op.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -70,7 +69,7 @@ __global__ void PopulationCountKernel(const int size, const int64* input, TTypes::Flat output) { \ const GPUDevice& d = c->eigen_device(); \ int64 total_count = input.size(); \ - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); \ + GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); \ TF_CHECK_OK(CudaLaunchKernel(PopulationCountKernel, config.block_count, \ config.thread_per_block, 0, d.stream(), \ total_count, input.data(), output.data())); \ diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h index e67a94e5f83..99efa28e2ec 100644 --- a/tensorflow/core/kernels/quantization_utils.h +++ b/tensorflow/core/kernels/quantization_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ #define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ +#include #define EIGEN_USE_THREADS // This is a set of functions that standardizes how quantized values are @@ -102,7 +103,7 @@ float QuantizedToFloat(T input, float range_min, float range_max) { // range_scale to a float, otherwise range_min_rounded might be slightly // different. const double range_min_rounded = - round(range_min / static_cast(range_scale)) * + std::round(range_min / static_cast(range_scale)) * static_cast(range_scale); const double result = range_min_rounded + (offset_input * range_scale); return static_cast(result); @@ -170,7 +171,8 @@ struct QuantizedToFloatStruct { range_scale((range_max - range_min) / (number_of_steps - 1.0)), range_min_rounded(range_max == range_min ? range_min - : round(range_min / range_scale) * range_scale) {} + : std::round(range_min / range_scale) * + range_scale) {} const float range_min; const float range_scale; @@ -207,7 +209,7 @@ struct FloatToQuantizedStruct { range_scale(range_max == range_min ? 
0.0 : (number_of_steps - 1.0) / (range_max - range_min)), - range_min_scaled(round(range_min * range_scale)) {} + range_min_scaled(std::round(range_min * range_scale)) {} const float range_min; const float range_scale; diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc index 176720c22cc..98ee5499b7d 100644 --- a/tensorflow/core/kernels/quantization_utils_test.cc +++ b/tensorflow/core/kernels/quantization_utils_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #define EIGEN_USE_THREADS #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" @@ -212,8 +212,8 @@ void TestRequantizeManyInNewRange8To32Bit() { template void TestRequantizeManyInNewRangeEigenVsNonEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); + Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(), + 2 /* num_threads */); const size_t ranges_count = 6; const float ranges[ranges_count][4] = { @@ -294,8 +294,8 @@ void TimeRequantizeManyInNewRange(int64 num_elements, int64 iterations, } thread::ThreadPool threadpool(Env::Default(), "test", 4 /* num_threads */); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_device(&wrapper, 4 /* num_threads */); + Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(), + 4 /* num_threads */); Tensor i_tensor = tensorflow::test::AsTensor(gtl::ArraySlice(values_quantized)); @@ -499,7 +499,7 @@ void TestAvoidBias() { const float step_size = (max - min) / 255.0f; const float tolerance = step_size / 1000.0f; // This is the smallest perfectly representable float in the range. - float first_float = ceil(min / step_size) * step_size; + float first_float = std::ceil(min / step_size) * step_size; for (float f = first_float; f <= max; f += step_size) { const int as_int = FloatToQuantized(f, min, max); const float back_to_float = QuantizedToFloat(as_int, min, max); @@ -606,8 +606,8 @@ void TestRequantizeManyInNewRange32To8Bit() { void TestRequantizeManyInNewRange32To8BitUsingEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); + Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(), + 2 /* num_threads */); TestRequantizeManyInNewRange32To8Bit(&eigen_device); } @@ -637,8 +637,8 @@ void TestFloatTensorToQuantized() { // FloatToQuantized. 
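For reference, a simplified 8-bit version of the float/quantized mapping these std::round call sites implement; the real helpers additionally snap range_min onto the quantized grid (range_min_rounded above):

#include <algorithm>
#include <cmath>
#include <cstdint>

uint8_t FloatToQuantized(float v, float range_min, float range_max) {
  const float range_scale = 255.0f / (range_max - range_min);
  const float scaled = std::round((v - range_min) * range_scale);
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, scaled)));
}

float QuantizedToFloat(uint8_t q, float range_min, float range_max) {
  const float range_scale = (range_max - range_min) / 255.0f;
  return range_min + static_cast<float>(q) * range_scale;
}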
void TestFloatToQuantizedInPlaceUsingEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); + Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(), + 2 /* num_threads */); TestFloatToQuantizedInPlaceUsingEigen(&eigen_device); TestFloatToQuantizedInPlaceUsingEigen(&eigen_device); @@ -648,8 +648,8 @@ void TestFloatToQuantizedInPlaceUsingEigen() { void TestOverflowWithEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); + Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(), + 2 /* num_threads */); const int num_vals = 4; const float input_min = 0.0f; @@ -716,8 +716,8 @@ void TestQuantizedTensorToFloat() { // QuantizedToFloat. void TestQuantizedToFloatInPlaceUsingEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); + Eigen::ThreadPoolDevice eigen_device(threadpool.AsEigenThreadPool(), + 2 /* num_threads */); TestQuantizedToFloatInPlaceUsingEigen(&eigen_device); TestQuantizedToFloatInPlaceUsingEigen(&eigen_device); diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc index f13341e0afe..43f1c6ea2af 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc @@ -15,9 +15,10 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/quantize_and_dequantize_op.h" @@ -241,7 +242,8 @@ TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define REGISTER_GPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2") \ .Device(DEVICE_GPU) \ @@ -262,5 +264,5 @@ TF_CALL_double(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_GPU_KERNEL); TF_CALL_double(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL -#endif +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc index 5745e418f36..00d2a3b1b30 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU @@ -47,4 +48,4 @@ template struct functor::QuantizeAndDequantizeOneScaleFunctor -void WriteValueSlices(const Tensor& params_dense_values_in, - const std::vector>& value_slices, - int64 value_size, Tensor* values_out) { +template +void WriteValueSlices( + const Tensor& params_dense_values_in, + const std::vector>& value_slices, + SPLITS_TYPE value_size, Tensor* values_out) { const auto& params_dense_values = params_dense_values_in.flat_outer_dims(); auto values = values_out->flat_outer_dims(); @@ -50,7 +51,7 @@ void WriteValueSlices(const Tensor& params_dense_values_in, } // namespace -template +template class RaggedGatherOpBase : public OpKernel { public: using OpKernel::OpKernel; @@ -66,18 +67,18 @@ class RaggedGatherOpBase : public OpKernel { context->input(params_nested_splits_in.size() + 1); DCHECK_GT(params_nested_splits_in.size(), 0); // Enforced by REGISTER_OP. - int64 num_params = params_nested_splits_in[0].dim_size(0) - 1; + SPLITS_TYPE num_params = params_nested_splits_in[0].dim_size(0) - 1; OP_REQUIRES_OK(context, ValidateIndices(indices_in, num_params)); OP_REQUIRES(context, params_dense_values_in.dims() > 0, errors::InvalidArgument("params.rank must be nonzero")); - int64 num_params_dense_values = params_dense_values_in.dim_size(0); + SPLITS_TYPE num_params_dense_values = params_dense_values_in.dim_size(0); // Calculate the `splits`, and store the value slices that we need to // copy in `value_slices`. - std::vector> value_slices; - int64 num_values = 0; - std::vector> out_splits; + std::vector> value_slices; + SPLITS_TYPE num_values = 0; + std::vector> out_splits; OP_REQUIRES_OK(context, MakeSplits(indices_in, params_nested_splits_in, num_params_dense_values, &out_splits, &value_slices, &num_values)); @@ -90,12 +91,14 @@ class RaggedGatherOpBase : public OpKernel { } private: + using ConstFlatType = typename TTypes::ConstFlat; + // Check if any indices are out-of-bounds. ::tensorflow::Status ValidateIndices(const Tensor& indices_in, - int64 num_params) { + SPLITS_TYPE num_params) { const auto& indices = indices_in.flat(); - for (int64 i = 0; i < indices.size(); ++i) { - int64 index = indices(i); + for (SPLITS_TYPE i = 0; i < indices.size(); ++i) { + SPLITS_TYPE index = indices(i); if (index < 0 || index >= num_params) { return errors::InvalidArgument( "indices", SliceDebugString(indices_in.shape(), i), " = ", index, @@ -111,9 +114,10 @@ class RaggedGatherOpBase : public OpKernel { // we need for allocating the output values tensor) is stored in `num_values`. ::tensorflow::Status MakeSplits( const Tensor& indices_in, const OpInputList& params_nested_splits_in, - int64 num_params_dense_values, - std::vector>* out_splits, - std::vector>* value_slices, int64* num_values) { + SPLITS_TYPE num_params_dense_values, + std::vector>* out_splits, + std::vector>* value_slices, + SPLITS_TYPE* num_values) { *num_values = 0; value_slices->clear(); @@ -122,10 +126,10 @@ class RaggedGatherOpBase : public OpKernel { // Get Eigen tensors. 
const auto& indices = indices_in.flat(); - std::vector::ConstFlat> params_nested_splits; + std::vector params_nested_splits; params_nested_splits.reserve(params_nested_splits_in.size()); for (const auto& splits_in : params_nested_splits_in) { - params_nested_splits.push_back(splits_in.flat()); + params_nested_splits.push_back(splits_in.flat()); } TF_RETURN_IF_ERROR( @@ -165,7 +169,7 @@ class RaggedGatherOpBase : public OpKernel { const auto& splits = params_nested_splits[dim]; int out_dim = dim + indices_in.dims() - 1; if (out_dim >= 0) { - int64 delta = out_splits->at(out_dim).back() - splits(start); + SPLITS_TYPE delta = out_splits->at(out_dim).back() - splits(start); for (int j = start; j < limit; ++j) { out_splits->at(out_dim).push_back(splits(j + 1) + delta); } @@ -182,14 +186,14 @@ class RaggedGatherOpBase : public OpKernel { } ::tensorflow::Status ValidateSplits( - const std::vector::ConstFlat>& params_nested_splits, - int64 num_params_dense_values) { + const std::vector& params_nested_splits, + SPLITS_TYPE num_params_dense_values) { // Validate for (int dim = 0; dim < params_nested_splits.size(); ++dim) { const auto& splits = params_nested_splits[dim]; - int64 last_split = (dim == params_nested_splits.size() - 1) - ? num_params_dense_values - : params_nested_splits[dim + 1].size(); + SPLITS_TYPE last_split = (dim == params_nested_splits.size() - 1) + ? num_params_dense_values + : params_nested_splits[dim + 1].size(); if (splits.size() == 0) { return errors::InvalidArgument("Ragged splits may not be empty"); } @@ -210,17 +214,17 @@ class RaggedGatherOpBase : public OpKernel { } ::tensorflow::Status WriteSplits( - const std::vector>& out_splits, + const std::vector>& out_splits, OpKernelContext* context) { OpOutputList splits_out; TF_RETURN_IF_ERROR( context->output_list("output_nested_splits", &splits_out)); for (int i = 0; i < out_splits.size(); ++i) { Tensor* splits; - int64 num_splits = out_splits[i].size(); + SPLITS_TYPE num_splits = out_splits[i].size(); TF_RETURN_IF_ERROR( splits_out.allocate(i, TensorShape({num_splits}), &splits)); - auto splits_flat = splits->flat(); + auto splits_flat = splits->flat(); std::copy_n(out_splits[i].data(), out_splits[i].size(), splits_flat.data()); } @@ -229,15 +233,16 @@ class RaggedGatherOpBase : public OpKernel { ::tensorflow::Status WriteValues( const Tensor& params_dense_values_in, - const std::vector>& value_slices, - int values_index, int64 num_values, OpKernelContext* context) const { + const std::vector>& value_slices, + int values_index, SPLITS_TYPE num_values, + OpKernelContext* context) const { Tensor* values_out = nullptr; TensorShape values_shape = params_dense_values_in.shape(); values_shape.set_dim(0, num_values); TF_RETURN_IF_ERROR( context->allocate_output(values_index, values_shape, &values_out)); - const int64 num_elements = params_dense_values_in.NumElements(); - const int64 value_size = + const SPLITS_TYPE num_elements = params_dense_values_in.NumElements(); + const SPLITS_TYPE value_size = num_elements == 0 ? 0 : (num_elements / params_dense_values_in.dim_size(0)); CallWriteValueSlices(params_dense_values_in, value_slices, value_size, @@ -253,34 +258,39 @@ class RaggedGatherOpBase : public OpKernel { // which cuts the binary size of this op from ~300k to <90k. 
virtual void CallWriteValueSlices( const Tensor& params_dense_values_in, - const std::vector>& value_slices, - int64 value_size, Tensor* values_out) const = 0; + const std::vector>& value_slices, + SPLITS_TYPE value_size, Tensor* values_out) const = 0; }; -template -class RaggedGatherOp : public RaggedGatherOpBase { +template +class RaggedGatherOp : public RaggedGatherOpBase { public: - using RaggedGatherOpBase::RaggedGatherOpBase; + using RaggedGatherOpBase::RaggedGatherOpBase; private: void CallWriteValueSlices( const Tensor& params_dense_values_in, - const std::vector>& value_slices, - int64 value_size, Tensor* values_out) const override { + const std::vector>& value_slices, + SPLITS_TYPE value_size, Tensor* values_out) const override { WriteValueSlices(params_dense_values_in, value_slices, value_size, values_out); } }; -#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type) \ - REGISTER_KERNEL_BUILDER(Name("RaggedGather") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("Tindices") \ - .TypeConstraint("Tvalues"), \ - RaggedGatherOp); -#define REGISTER_CPU_KERNEL(value_type) \ - REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type) \ - REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type) +#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type, \ + splits_type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RaggedGather") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tindices") \ + .TypeConstraint("Tvalues") \ + .TypeConstraint("Tsplits"), \ + RaggedGatherOp); +#define REGISTER_CPU_KERNEL(value_type) \ + REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int32) \ + REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type, int32) \ + REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int64) \ + REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type, int64) TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL); TF_CALL_string(REGISTER_CPU_KERNEL); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL); diff --git a/tensorflow/core/kernels/ragged_range_op.cc b/tensorflow/core/kernels/ragged_range_op.cc index cb7546c3974..024b16ff935 100644 --- a/tensorflow/core/kernels/ragged_range_op.cc +++ b/tensorflow/core/kernels/ragged_range_op.cc @@ -26,7 +26,7 @@ namespace tensorflow { using errors::InvalidArgument; -template +template class RaggedRangeOp : public OpKernel { public: using OpKernel::OpKernel; @@ -60,7 +60,7 @@ class RaggedRangeOp : public OpKernel { InvalidArgument("starts, limits, and deltas must have the " "same shape")); } - int64 nrows = in_sizes.empty() ? 1 : in_sizes[0]; + SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes[0]; const auto& starts = starts_in.flat(); const auto& limits = limits_in.flat(); @@ -71,7 +71,7 @@ class RaggedRangeOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({nrows + 1}), &rt_nested_splits_out)); - auto rt_nested_splits = rt_nested_splits_out->flat(); + auto rt_nested_splits = rt_nested_splits_out->flat(); rt_nested_splits(0) = 0; for (int row = 0; row < nrows; ++row) { T start = broadcast_starts ? starts(0) : starts(row); @@ -81,7 +81,7 @@ class RaggedRangeOp : public OpKernel { rt_nested_splits(row + 1) = rt_nested_splits(row) + RangeSize(start, limit, delta); } - int64 nvals = rt_nested_splits(nrows); + SPLITS_TYPE nvals = rt_nested_splits(nrows); // Construct the rt_dense_values tensor. 
Tensor* rt_dense_values_out = nullptr; @@ -90,10 +90,10 @@ class RaggedRangeOp : public OpKernel { auto rt_dense_values = rt_dense_values_out->flat(); int value_index = 0; for (int row = 0; row < nrows; ++row) { - int64 row_size = rt_nested_splits(row + 1) - rt_nested_splits(row); + SPLITS_TYPE row_size = rt_nested_splits(row + 1) - rt_nested_splits(row); T value = broadcast_starts ? starts(0) : starts(row); T delta = broadcast_deltas ? deltas(0) : deltas(row); - for (int64 i = 0; i < row_size; ++i) { + for (SPLITS_TYPE i = 0; i < row_size; ++i) { rt_dense_values(value_index++) = T(value); value += delta; } @@ -102,7 +102,7 @@ class RaggedRangeOp : public OpKernel { private: // Returns the number of elements in the specified range. - int64 RangeSize(T start, T limit, T delta) { + SPLITS_TYPE RangeSize(T start, T limit, T delta) { if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) { return 0; } @@ -114,10 +114,17 @@ class RaggedRangeOp : public OpKernel { } }; -#define REGISTER_CPU_KERNEL(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("RaggedRange").Device(DEVICE_CPU).TypeConstraint("T"), \ - RaggedRangeOp); +#define REGISTER_CPU_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("RaggedRange") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tsplits"), \ + RaggedRangeOp); \ + REGISTER_KERNEL_BUILDER(Name("RaggedRange") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tsplits"), \ + RaggedRangeOp); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); TF_CALL_int32(REGISTER_CPU_KERNEL); diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc new file mode 100644 index 00000000000..3ba266bfd0f --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc @@ -0,0 +1,314 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +struct RaggedTensor { + Tensor values; + std::vector nested_splits; +}; + +Status RaggedComponentsFromVariant(const Tensor& encoded_variant, + int ragged_rank, DataType value_dtype, + DataType split_dtype, + std::vector* decoded_ragged) { + const auto& flat_variants = encoded_variant.flat(); + decoded_ragged->resize(flat_variants.size()); + // Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the + // input. 
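// Hedged sketch of the layout this decoder expects: every scalar element of
// the encoded input is a rank-1 list of (ragged_rank + 1) tensors, namely
// ragged_rank row-splits tensors followed by one values tensor. The check
// below mirrors that element-count invariant with a plain container model; it
// is not the TensorFlow Variant API, and all names are illustrative.
#include <cstddef>
#include <string>
#include <vector>

struct EncodedRaggedSketch {
  // splits.size() plays the role of ragged_rank; values stands in for the
  // trailing values tensor.
  std::vector<std::vector<long long>> splits;
  std::vector<int> values;
};

std::string ValidateEncodedSketch(const EncodedRaggedSketch& enc,
                                  int expected_ragged_rank) {
  const std::size_t num_tensors = enc.splits.size() + 1;  // splits + values
  if (num_tensors != static_cast<std::size_t>(expected_ragged_rank) + 1 &&
      num_tensors != 1) {
    return "encoded element must hold input_ragged_rank + 1 tensors";
  }
  return "ok";
}
// For the ragged tensor [[1, 2], [], [3]] with ragged_rank 1, the encoding is
// splits {{0, 2, 2, 3}} plus values {1, 2, 3}, i.e. two tensors in total.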
+ for (int i = 0; i < flat_variants.size(); i++) { + const auto& flat_variant = flat_variants(i); + const Tensor* encoded_list = flat_variant.get(); + if (encoded_list == nullptr) { + return errors::InvalidArgument( + "Input Variant element at index ", i, + " doesn't hold a Tensor: ", flat_variant.DebugString()); + } + if (encoded_list->dims() != 1) { + return errors::InvalidArgument( + "Encoded input Variant must have rank 1, but found rank: ", + encoded_list->dims(), + ". encoded input Variant: ", encoded_list->DebugString()); + } + if (encoded_list->NumElements() != (ragged_rank + 1) && + encoded_list->NumElements() != 1) { + return errors::InvalidArgument( + "Encoded input Variant must hold either input_ragged_rank + 1 " + "Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), " + "input_ragged_rank: ", + ragged_rank, + ", encoded input Variant: ", encoded_list->DebugString()); + } + const auto& input_vec = encoded_list->vec(); + + // Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor + // to create the component RaggedTensors. + (*decoded_ragged)[i].nested_splits.reserve(ragged_rank); + for (int j = 0; j < ragged_rank; j++) { + const Tensor* split_tensor = input_vec(j).get(); + if (split_tensor == nullptr) { + return errors::InvalidArgument( + "Encoded scalar element at index ", i, + " doesn't have a splits Tensor at split_index ", j, ": ", + input_vec(j).DebugString()); + } + Tensor splits_tensor = *split_tensor; + if (splits_tensor.dtype() != split_dtype) { + return errors::InvalidArgument( + "Expected splits Tensor dtype: ", split_dtype, + ", found: ", splits_tensor.dtype()); + } + if (splits_tensor.dims() != 1) { + return errors::InvalidArgument( + "Ragged splits must have rank 1; encoded scalar element at index ", + i, " has splits Tensor at split_index ", j, ": ", + splits_tensor.DebugString()); + } + (*decoded_ragged)[i].nested_splits.push_back(splits_tensor); + } + const Tensor* values_tensor = input_vec(ragged_rank).get(); + if (values_tensor == nullptr) { + return errors::InvalidArgument("Encoded scalar element at index ", i, + " doesn't have a values Tensor: ", + input_vec(ragged_rank).DebugString()); + } + if (values_tensor->dtype() != value_dtype) { + return errors::InvalidArgument( + "Expected values Tensor dtype: ", value_dtype, + ", found: ", values_tensor->dtype()); + } + if (values_tensor->dims() < 1) { + return errors::InvalidArgument( + "Ragged values must have rank >= 1; encoded scalar element at index ", + i, " has values Tensor: ", values_tensor->DebugString()); + } + (*decoded_ragged)[i].values = *values_tensor; + } + return Status::OK(); +} + +template +Status NestedStackRaggedTensors( + const std::vector& ragged_components, + const std::vector& nested_dim_sizes, const int input_ragged_rank, + const int output_ragged_rank, RaggedTensor* output_ragged) { + output_ragged->nested_splits.reserve(output_ragged_rank); + const int dims = nested_dim_sizes.size(); + + // Populate first `dims - 1` splits. + for (int i = 0; i < dims - 1; i++) { + int dims_splits_size = nested_dim_sizes[i] + 1; + output_ragged->nested_splits.push_back(Tensor( + DataTypeToEnum::value, TensorShape({dims_splits_size}))); + auto splits_vec = output_ragged->nested_splits[i].vec(); + int split_diff = nested_dim_sizes[i + 1]; + for (int j = 0; j < dims_splits_size; j++) { + splits_vec(j) = j * split_diff; + } + } + + // Populate `dims`-th split. 
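// Hedged sketch of the splits concatenation used when stacking components:
// each component's row_splits (which always start at 0) are appended with an
// offset equal to the last split value written so far. Plain std::vector
// illustration only; the names are not part of the TensorFlow API.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> StackRowSplitsSketch(
    const std::vector<std::vector<int64_t>>& component_splits) {
  std::vector<int64_t> stacked{0};
  for (const auto& splits : component_splits) {
    if (splits.empty()) continue;  // an empty component contributes no rows
    int64_t offset = stacked.back();
    for (std::size_t k = 1; k < splits.size(); ++k) {
      stacked.push_back(splits[k] + offset);
    }
  }
  return stacked;
}
// Stacking components with splits {0, 2, 3} and {0, 1} yields {0, 2, 3, 4}:
// the second component's rows are shifted by 3, the running total.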
+ int splits_size = ragged_components.size() + 1; + output_ragged->nested_splits.push_back( + Tensor(DataTypeToEnum::value, TensorShape({splits_size}))); + auto dims_splits_vec = + output_ragged->nested_splits[dims - 1].vec(); + dims_splits_vec(0) = 0; + for (int i = 0; i < ragged_components.size(); i++) { + int split_val = ragged_components[i].values.NumElements(); + if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) { + split_val = ragged_components[i].nested_splits[0].NumElements() - 1; + } + dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val; + } + + // Populate last `input_ragged_rank` splits. + for (int i = 0; i < input_ragged_rank; i++) { + int split_index = dims + i; + int split_size = 1; + for (int j = 0; j < ragged_components.size(); j++) { + if (!ragged_components[j].nested_splits.empty()) { + split_size += ragged_components[j].nested_splits[i].NumElements() - 1; + } + } + output_ragged->nested_splits.push_back( + Tensor(DataTypeToEnum::value, TensorShape({split_size}))); + auto splits_vec = + output_ragged->nested_splits[split_index].vec(); + splits_vec(0) = 0; + SPLIT_TYPE last_split_value = 0; + int index = 1; + for (int j = 0; j < ragged_components.size(); j++) { + if (ragged_components[j].nested_splits.empty()) { + // Corner case: empty row. e.g [ [[x], [x]], [] ] + continue; + } + auto component_splits_vec = + ragged_components[j].nested_splits[i].vec(); + for (int k = 1; k < component_splits_vec.size(); k++, index++) { + splits_vec(index) = component_splits_vec(k) + last_split_value; + } + last_split_value = splits_vec(index - 1); + } + } + + // Populate values. + TensorShape component_values_shape = ragged_components[0].values.shape(); + int values_size = component_values_shape.dim_size(0); + for (int i = 1; i < ragged_components.size(); i++) { + if (ragged_components[i].values.dims() != component_values_shape.dims()) { + return errors::InvalidArgument( + "Rank of values must match for all " + "components; values shape at index 0: ", + component_values_shape.DebugString(), ", values shape at index ", i, + ": ", ragged_components[i].values.shape().DebugString()); + } + values_size += ragged_components[i].values.shape().dim_size(0); + } + component_values_shape.set_dim(0, values_size); + output_ragged->values = + Tensor(DataTypeToEnum::value, component_values_shape); + auto output_values_flat = + output_ragged->values.flat_outer_dims(); + int values_index = 0; + for (int i = 0; i < ragged_components.size(); i++) { + auto component_values_flat = + ragged_components[i].values.flat_outer_dims(); + int num_inner_elements = ragged_components[i].values.NumElements(); + if (ragged_components[i].values.dim_size(0) > 0) { + num_inner_elements /= ragged_components[i].values.dim_size(0); + } + for (int j = 0; j < ragged_components[i].values.dim_size(0); + j++, values_index++) { + for (int k = 0; k < num_inner_elements; k++) { + output_values_flat(values_index, k) = component_values_flat(j, k); + } + } + } + return Status::OK(); +} +} // namespace + +template +class RaggedTensorFromVariantOp : public OpKernel { + public: + explicit RaggedTensorFromVariantOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("input_ragged_rank", &input_ragged_rank_)); + OP_REQUIRES_OK( + context, context->GetAttr("output_ragged_rank", &output_ragged_rank_)); + } + + void Compute(OpKernelContext* context) override { + // Read input Tensor. 
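// Hedged sketch of the rank bookkeeping done in Compute below: when
// input_ragged_rank is given as -1 it is inferred as output_ragged_rank minus
// the rank of the encoded input, and in all cases output_ragged_rank must
// equal that encoded rank plus input_ragged_rank. Standalone helper for
// illustration only.
#include <stdexcept>

int ResolveInputRaggedRankSketch(int input_ragged_rank, int output_ragged_rank,
                                 int encoded_variant_rank) {
  if (input_ragged_rank == -1) {
    input_ragged_rank = output_ragged_rank - encoded_variant_rank;
  }
  if (input_ragged_rank < 0 ||
      output_ragged_rank != encoded_variant_rank + input_ragged_rank) {
    throw std::invalid_argument(
        "output_ragged_rank must equal input_ragged_rank + encoded rank");
  }
  return input_ragged_rank;
}
// A rank-2 batch of components with input_ragged_rank 2 therefore requires
// output_ragged_rank 4, as exercised by the inferred-rank test further below.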
+ const Tensor& encoded_variant = context->input(0); + + if (input_ragged_rank_ == -1) { // Infer input_ragged_rank_. + input_ragged_rank_ = output_ragged_rank_ - encoded_variant.dims(); + OP_REQUIRES(context, input_ragged_rank_ >= 0, + errors::InvalidArgument( + "Inferred input_ragged_rank (output_ragged_rank - " + "encoded_variant.dims()) must be >= 0, found " + "output_ragged_rank: ", + output_ragged_rank_, + ", encoded_variant.dims(): ", encoded_variant.dims(), + ", inferred input_ragged_rank: ", input_ragged_rank_)); + } + OP_REQUIRES( + context, + output_ragged_rank_ == encoded_variant.dims() + input_ragged_rank_, + errors::InvalidArgument( + "output_ragged_rank must be equal to input_ragged_rank + " + "encoded_ragged.dims(); output_ragged_rank: ", + output_ragged_rank_, ", input_ragged_rank: ", input_ragged_rank_, + ", encoded_variant.dims(): ", encoded_variant.dims(), ".")); + + // Decode all variants. + const auto value_dtype = DataTypeToEnum::v(); + const auto split_dtype = DataTypeToEnum::v(); + std::vector decoded_components; + OP_REQUIRES_OK(context, RaggedComponentsFromVariant( + encoded_variant, input_ragged_rank_, + value_dtype, split_dtype, &decoded_components)); + + // Corner case: input is a scalar. + if (encoded_variant.dims() == 0) { + ReturnRaggedTensor(context, decoded_components[0]); + return; + } + + // Nested-Stack Ragged components into a batched RaggedTensor. + std::vector encoded_dim_sizes(encoded_variant.dims(), 0); + for (int i = 0; i < encoded_variant.dims(); i++) { + encoded_dim_sizes[i] = encoded_variant.dim_size(i); + } + RaggedTensor output_ragged; + OP_REQUIRES_OK( + context, NestedStackRaggedTensors( + decoded_components, encoded_dim_sizes, input_ragged_rank_, + output_ragged_rank_, &output_ragged)); + + // Set output. + ReturnRaggedTensor(context, output_ragged); + } + + private: + int input_ragged_rank_; + int output_ragged_rank_; + + void ReturnRaggedTensor(OpKernelContext* context, + RaggedTensor ragged_tensor) { + int ragged_rank = ragged_tensor.nested_splits.size(); + OpOutputList splits_out; + OP_REQUIRES_OK(context, + context->output_list("output_nested_splits", &splits_out)); + for (int i = 0; i < ragged_rank; i++) { + splits_out.set(i, ragged_tensor.nested_splits[i]); + } + context->set_output(ragged_rank, ragged_tensor.values); + } +}; + +#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \ + REGISTER_KERNEL_BUILDER(Name("RaggedTensorFromVariant") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tvalues") \ + .TypeConstraint("Tsplits"), \ + RaggedTensorFromVariantOp); +#define REGISTER_KERNELS(value_type) \ + REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \ + REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64) +TF_CALL_POD_TYPES(REGISTER_KERNELS); +TF_CALL_string(REGISTER_KERNELS); +TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); +TF_CALL_quint16(REGISTER_KERNELS); +TF_CALL_qint16(REGISTER_KERNELS); +TF_CALL_uint32(REGISTER_KERNELS); +TF_CALL_uint64(REGISTER_KERNELS); +#undef REGISTER_KERNELS +#undef REGISTER_KERNELS_WITH_SPLIT_TYPE +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc new file mode 100644 index 00000000000..cb51be2e03c --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc @@ -0,0 +1,695 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "absl/strings/match.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class RaggedTensorFromVariantKernelTest : public ::tensorflow::OpsTestBase { + protected: + // Builds the tensorflow test graph for the RaggedTensorFromVariant op, and + // populates the variant input with the given values. + template + void BuildDecodeRaggedTensorGraph( + const int input_ragged_rank, const int output_ragged_rank, + const TensorShape& variant_shape, + const std::vector& variant_values) { + const auto value_dtype = DataTypeToEnum::v(); + const auto split_dtype = DataTypeToEnum::v(); + TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorFromVariant") + .Input(FakeInput(DT_VARIANT)) + .Attr("input_ragged_rank", input_ragged_rank) + .Attr("output_ragged_rank", output_ragged_rank) + .Attr("Tvalues", value_dtype) + .Attr("Tsplits", split_dtype) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(variant_shape, variant_values); + } + + template + Tensor CreateVariantFromRagged( + const std::vector>& ragged_splits, + const TensorShape& ragged_values_shape, + const std::vector& ragged_values) { + // Step 1: Create Tensors out of ragged splits and values. + std::vector ragged_components; + for (auto ragged_split : ragged_splits) { + int splits_size = ragged_split.size(); + Tensor splits(DataTypeToEnum::v(), + TensorShape({splits_size})); + test::FillValues(&splits, ragged_split); + ragged_components.push_back(splits); + } + Tensor values(DataTypeToEnum::v(), ragged_values_shape); + test::FillValues(&values, ragged_values); + ragged_components.push_back(values); + + // Step 2: Encode into a 1-D Variant Tensor. 
+ int num_splits = ragged_splits.size(); + Tensor encoded_list(DT_VARIANT, TensorShape({num_splits + 1})); + test::FillValues(&encoded_list, ragged_components); + return encoded_list; + } +}; + +TEST_F(RaggedTensorFromVariantKernelTest, ScalarInput) { + const std::vector split_1 = {0, 1, 2, 3, 4, 5}; + const std::vector split_2 = {0, 1, 2, 5, 6, 7}; + const std::vector values = {0, 1, 1, 2, 2, 3, 4}; + + Tensor encoded_variant = CreateVariantFromRagged( + {split_1, split_2}, TensorShape({7}), values); + Tensor expected_splits_1(DT_INT64, TensorShape({6})); + Tensor expected_splits_2(DT_INT64, TensorShape({6})); + Tensor expected_values(DT_INT32, TensorShape({7})); + + test::FillValues(&expected_splits_1, split_1); + test::FillValues(&expected_splits_2, split_2); + test::FillValues(&expected_values, values); + + int input_ragged_rank = 2; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph(input_ragged_rank, + output_ragged_rank, TensorShape({}), + {encoded_variant}); + TF_ASSERT_OK(RunOpKernel()); + + test::ExpectTensorEqual(*GetOutput(0), expected_splits_1); + test::ExpectTensorEqual(*GetOutput(1), expected_splits_2); + test::ExpectTensorEqual(*GetOutput(2), expected_values); +} + +TEST_F(RaggedTensorFromVariantKernelTest, OneInputElement) { + const std::vector split_1 = {0, 1, 2, 3, 4, 5}; + const std::vector split_2 = {0, 1, 2, 5, 6, 7}; + const std::vector values = {0, 1, 1, 2, 2, 3, 4}; + const std::vector batched_splits_1 = {0, 5}; + + Tensor encoded_variant = CreateVariantFromRagged( + {split_1, split_2}, TensorShape({7}), values); + Tensor expected_splits_1(DT_INT64, TensorShape({2})); + Tensor expected_splits_2(DT_INT64, TensorShape({6})); + Tensor expected_splits_3(DT_INT64, TensorShape({6})); + Tensor expected_values(DT_INT32, TensorShape({7})); + + test::FillValues(&expected_splits_1, batched_splits_1); + test::FillValues(&expected_splits_2, split_1); + test::FillValues(&expected_splits_3, split_2); + test::FillValues(&expected_values, values); + + int input_ragged_rank = 2; + int output_ragged_rank = 3; + BuildDecodeRaggedTensorGraph(input_ragged_rank, + output_ragged_rank, TensorShape({1}), + {encoded_variant}); + TF_ASSERT_OK(RunOpKernel()); + + test::ExpectTensorEqual(*GetOutput(0), expected_splits_1); + test::ExpectTensorEqual(*GetOutput(1), expected_splits_2); + test::ExpectTensorEqual(*GetOutput(2), expected_splits_3); + test::ExpectTensorEqual(*GetOutput(3), expected_values); +} + +TEST_F(RaggedTensorFromVariantKernelTest, TensorIn2DOut) { + // component_1 = [x, x, x] + // component_2 = [] + // component_3 = [x, x] + // component_4 = [] + // batched_ragged = + // [[component_1, component_2], [component_3, component_4]] + // [ + // [ [x, x, x], [] ], + // [ [x, x], [x] ] + // ] + const std::vector values_1 = {1, 2, 3}; + const std::vector values_2 = {}; + const std::vector values_3 = {4, 5}; + const std::vector values_4 = {6}; + const std::vector batched_splits_1 = {0, 2, 4}; + const std::vector batched_splits_2 = {0, 3, 3, 5, 6}; + const std::vector batched_values = {1, 2, 3, 4, 5, 6}; + + Tensor component_variant_1 = + CreateVariantFromRagged({}, TensorShape({3}), values_1); + Tensor component_variant_2 = + CreateVariantFromRagged({}, TensorShape({0}), values_2); + Tensor component_variant_3 = + CreateVariantFromRagged({}, TensorShape({2}), values_3); + Tensor component_variant_4 = + CreateVariantFromRagged({}, TensorShape({1}), values_4); + + Tensor expected_splits_1(DT_INT64, TensorShape({3})); + Tensor expected_splits_2(DT_INT64, TensorShape({5})); + Tensor 
expected_values(DT_INT32, TensorShape({6})); + + test::FillValues(&expected_splits_1, batched_splits_1); + test::FillValues(&expected_splits_2, batched_splits_2); + test::FillValues(&expected_values, batched_values); + + int input_ragged_rank = 0; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2, 2}), + {component_variant_1, component_variant_2, component_variant_3, + component_variant_4}); + TF_ASSERT_OK(RunOpKernel()); + + test::ExpectTensorEqual(*GetOutput(0), expected_splits_1); + test::ExpectTensorEqual(*GetOutput(1), expected_splits_2); + test::ExpectTensorEqual(*GetOutput(2), expected_values); +} + +TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOut) { + // ragged_component_1 = [[x]] + // ragged_component_2 = [[x], [x]] + // ragged_component_3 = [[x, x]] + // ragged_component_4 = [[x, x], [x]] + // ragged_component_5 = [[x], [x, x]] + // batched_ragged = [[rc1, rc2, rc3, rc4, rc5], [rc4, rc5, rc1, rc3, rc2]] + const std::vector component_split_1_1 = {0, 1}; + const std::vector component_split_2_1 = {0, 1, 2}; + const std::vector component_split_3_1 = {0, 2}; + const std::vector component_split_4_1 = {0, 2, 3}; + const std::vector component_split_5_1 = {0, 1, 3}; + const std::vector component_values_1 = {0}; + const std::vector component_values_2 = {0, 1}; + const std::vector component_values_3 = {0, 1}; + const std::vector component_values_4 = {0, 1, 2}; + const std::vector component_values_5 = {0, 1, 2}; + + const std::vector batched_splits_1 = {0, 5, 10}; + const std::vector batched_splits_2 = {0, 1, 3, 4, 6, 8, + 10, 12, 13, 14, 16}; + const std::vector batched_splits_3 = { + 0, 1, 2, 3, 5, 7, 8, 9, 11, 13, 14, 15, 17, 18, 20, 21, 22}; + const std::vector batched_values = {0, 0, 1, 0, 1, 0, 1, 2, 0, 1, 2, + 0, 1, 2, 0, 1, 2, 0, 0, 1, 0, 1}; + + Tensor expected_splits_1(DT_INT64, TensorShape({3})); + Tensor expected_splits_2(DT_INT64, TensorShape({11})); + Tensor expected_splits_3(DT_INT64, TensorShape({17})); + Tensor expected_values(DT_INT32, TensorShape({22})); + + test::FillValues(&expected_splits_1, batched_splits_1); + test::FillValues(&expected_splits_2, batched_splits_2); + test::FillValues(&expected_splits_3, batched_splits_3); + test::FillValues(&expected_values, batched_values); + + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({1}), component_values_1); + Tensor variant_component_2 = CreateVariantFromRagged( + {component_split_2_1}, TensorShape({2}), component_values_2); + Tensor variant_component_3 = CreateVariantFromRagged( + {component_split_3_1}, TensorShape({2}), component_values_3); + Tensor variant_component_4 = CreateVariantFromRagged( + {component_split_4_1}, TensorShape({3}), component_values_4); + Tensor variant_component_5 = CreateVariantFromRagged( + {component_split_5_1}, TensorShape({3}), component_values_5); + int input_ragged_rank = 1; + int output_ragged_rank = 3; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2, 5}), + {variant_component_1, variant_component_2, variant_component_3, + variant_component_4, variant_component_5, variant_component_4, + variant_component_5, variant_component_1, variant_component_3, + variant_component_2}); + TF_ASSERT_OK(RunOpKernel()); + + test::ExpectTensorEqual(*GetOutput(0), expected_splits_1); + test::ExpectTensorEqual(*GetOutput(1), expected_splits_2); + test::ExpectTensorEqual(*GetOutput(2), expected_splits_3); + test::ExpectTensorEqual(*GetOutput(3), 
expected_values); +} + +TEST_F(RaggedTensorFromVariantKernelTest, + NonEmpty2DIn4DOutInferredInputRaggedRank) { + // ragged_component_1 = + // [ + // [ [x] ], + // [ [x], [x] ], + // [ [x, x] ], + // [ [x, x], [x] ], + // [ [x], [x, x] ] + // ] + // ragged_component_2 = + // [ + // [ [x, x], [x] ], + // [ [x], [x, x] ], + // [ [x] ], + // [ [x, x] ], + // [ [x], [x] ] + // ] + // batched_ragged = [[rc1, rc2], [rc2, rc1]] + const std::vector component_split_1_1 = {0, 1, 3, 4, 6, 8}; + const std::vector component_split_1_2 = {0, 1, 2, 3, 5, 7, 8, 9, 11}; + const std::vector component_split_2_1 = {0, 2, 4, 5, 6, 8}; + const std::vector component_split_2_2 = {0, 2, 3, 4, 6, 7, 9, 10, 11}; + const std::vector component_values_1 = {0, 0, 1, 0, 1, 0, 1, 2, 0, 1, 2}; + const std::vector component_values_2 = {0, 1, 2, 0, 1, 2, 0, 0, 1, 0, 1}; + const std::vector batched_splits_1 = {0, 2, 4}; + const std::vector batched_splits_2 = {0, 5, 10, 15, 20}; + const std::vector batched_splits_3 = {0, 1, 3, 4, 6, 8, 10, + 12, 13, 14, 16, 18, 20, 21, + 22, 24, 25, 27, 28, 30, 32}; + const std::vector batched_splits_4 = { + 0, 1, 2, 3, 5, 7, 8, 9, 11, 13, 14, 15, 17, 18, 20, 21, 22, + 24, 25, 26, 28, 29, 31, 32, 33, 34, 35, 36, 38, 40, 41, 42, 44}; + const std::vector batched_values = { + 0, 0, 1, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 0, 1, 0, 1, + 0, 1, 2, 0, 1, 2, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 2, 0, 1, 2}; + + Tensor expected_splits_1(DT_INT64, TensorShape({3})); + Tensor expected_splits_2(DT_INT64, TensorShape({5})); + Tensor expected_splits_3(DT_INT64, TensorShape({21})); + Tensor expected_splits_4(DT_INT64, TensorShape({33})); + Tensor expected_values(DT_INT32, TensorShape({44})); + test::FillValues(&expected_splits_1, batched_splits_1); + test::FillValues(&expected_splits_2, batched_splits_2); + test::FillValues(&expected_splits_3, batched_splits_3); + test::FillValues(&expected_splits_4, batched_splits_4); + test::FillValues(&expected_values, batched_values); + + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1, component_split_1_2}, TensorShape({11}), + component_values_1); + Tensor variant_component_2 = CreateVariantFromRagged( + {component_split_2_1, component_split_2_2}, TensorShape({11}), + component_values_2); + int input_ragged_rank = -1; + int output_ragged_rank = 4; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2, 2}), + {variant_component_1, variant_component_2, variant_component_2, + variant_component_1}); + TF_ASSERT_OK(RunOpKernel()); + + test::ExpectTensorEqual(*GetOutput(0), expected_splits_1); + test::ExpectTensorEqual(*GetOutput(1), expected_splits_2); + test::ExpectTensorEqual(*GetOutput(2), expected_splits_3); + test::ExpectTensorEqual(*GetOutput(3), expected_splits_4); + test::ExpectTensorEqual(*GetOutput(4), expected_values); +} + +TEST_F(RaggedTensorFromVariantKernelTest, EmptyRow1DIn2DOut) { + // ragged_component_1 = [[x, x, x], []] + // ragged_component_2 = [] + // batched_ragged = [rc1, rc2] = [[[x, x, x], []], []] + const std::vector component_split_1_1 = {0, 3, 3}; + const std::vector component_values_1 = {1, 2, 3}; + const std::vector component_split_2_1 = {0}; + const std::vector batched_splits_1 = {0, 2, 2}; + const std::vector batched_splits_2 = {0, 3, 3}; + const std::vector batched_values = {1, 2, 3}; + + Tensor expected_splits_1(DT_INT64, TensorShape({3})); + Tensor expected_splits_2(DT_INT64, TensorShape({3})); + Tensor expected_values(DT_INT32, TensorShape({3})); + 
test::FillValues(&expected_splits_1, batched_splits_1); + test::FillValues(&expected_splits_2, batched_splits_2); + test::FillValues(&expected_values, batched_values); + + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({3}), component_values_1); + Tensor variant_component_2 = CreateVariantFromRagged( + {component_split_2_1}, TensorShape({0}), {}); // Empty row. + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2}), + {variant_component_1, variant_component_2}); + TF_ASSERT_OK(RunOpKernel()); + + test::ExpectTensorEqual(*GetOutput(0), expected_splits_1); + test::ExpectTensorEqual(*GetOutput(1), expected_splits_2); + test::ExpectTensorEqual(*GetOutput(2), expected_values); +} + +TEST_F(RaggedTensorFromVariantKernelTest, NDValues1DIn2DOut) { + // ragged_component_1 = [[x]] + // ragged_component_1 = [[x], [x]] + // batched_ragged = [rc1, rc2] = [[[x]], [[x], [x]]] + const std::vector component_split_1_1 = {0, 1}; + const std::vector component_values_1 = {1, 2}; + const std::vector component_split_2_1 = {0, 1, 2}; + const std::vector component_values_2 = {1, 2, 3, 4}; + const std::vector batched_splits_1 = {0, 1, 3}; + const std::vector batched_splits_2 = {0, 1, 2, 3}; + const std::vector batched_values = {1, 2, 1, 2, 3, 4}; + + Tensor expected_splits_1(DT_INT64, TensorShape({3})); + Tensor expected_splits_2(DT_INT64, TensorShape({4})); + Tensor expected_values(DT_INT32, TensorShape({3, 2})); + test::FillValues(&expected_splits_1, batched_splits_1); + test::FillValues(&expected_splits_2, batched_splits_2); + test::FillValues(&expected_values, batched_values); + + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({1, 2}), component_values_1); + Tensor variant_component_2 = CreateVariantFromRagged( + {component_split_2_1}, TensorShape({2, 2}), component_values_2); + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2}), + {variant_component_1, variant_component_2}); + + TF_ASSERT_OK(RunOpKernel()); + test::ExpectTensorEqual(*GetOutput(0), expected_splits_1); + test::ExpectTensorEqual(*GetOutput(1), expected_splits_2); + test::ExpectTensorEqual(*GetOutput(2), expected_values); +} + +TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) { + // ragged_component_1 = [[x]] + // ragged_component_2 = [[x], [x]] + // ragged_component_3 = [[x, x]] + // ragged_component_4 = [[x, x], [x]] + // ragged_component_5 = [[x], [x, x]] + // batched_ragged = [[rc1, rc2, rc3, rc4, rc5], [rc4, rc5, rc1, rc3, rc2]] + const std::vector component_split_1_1 = {0, 1}; + const std::vector component_split_2_1 = {0, 1, 2}; + const std::vector component_split_3_1 = {0, 2}; + const std::vector component_split_4_1 = {0, 2, 3}; + const std::vector component_split_5_1 = {0, 1, 3}; + const std::vector component_values_1 = {0}; + const std::vector component_values_2 = {0, 1}; + const std::vector component_values_3 = {0, 1}; + const std::vector component_values_4 = {0, 1, 2}; + const std::vector component_values_5 = {0, 1, 2}; + + const std::vector batched_splits_1 = {0, 5, 10}; + const std::vector batched_splits_2 = {0, 1, 3, 4, 6, 8, + 10, 12, 13, 14, 16}; + const std::vector batched_splits_3 = {0, 1, 2, 3, 5, 7, 8, 9, 11, + 13, 14, 15, 17, 18, 20, 21, 22}; + const std::vector batched_values = {0, 0, 1, 0, 1, 0, 1, 2, 0, 1, 2, + 0, 1, 2, 0, 1, 2, 
0, 0, 1, 0, 1}; + + Tensor expected_splits_1(DT_INT32, TensorShape({3})); + Tensor expected_splits_2(DT_INT32, TensorShape({11})); + Tensor expected_splits_3(DT_INT32, TensorShape({17})); + Tensor expected_values(DT_INT32, TensorShape({22})); + + test::FillValues(&expected_splits_1, batched_splits_1); + test::FillValues(&expected_splits_2, batched_splits_2); + test::FillValues(&expected_splits_3, batched_splits_3); + test::FillValues(&expected_values, batched_values); + + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({1}), component_values_1); + Tensor variant_component_2 = CreateVariantFromRagged( + {component_split_2_1}, TensorShape({2}), component_values_2); + Tensor variant_component_3 = CreateVariantFromRagged( + {component_split_3_1}, TensorShape({2}), component_values_3); + Tensor variant_component_4 = CreateVariantFromRagged( + {component_split_4_1}, TensorShape({3}), component_values_4); + Tensor variant_component_5 = CreateVariantFromRagged( + {component_split_5_1}, TensorShape({3}), component_values_5); + int input_ragged_rank = 1; + int output_ragged_rank = 3; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2, 5}), + {variant_component_1, variant_component_2, variant_component_3, + variant_component_4, variant_component_5, variant_component_4, + variant_component_5, variant_component_1, variant_component_3, + variant_component_2}); + TF_ASSERT_OK(RunOpKernel()); + + test::ExpectTensorEqual(*GetOutput(0), expected_splits_1); + test::ExpectTensorEqual(*GetOutput(1), expected_splits_2); + test::ExpectTensorEqual(*GetOutput(2), expected_splits_3); + test::ExpectTensorEqual(*GetOutput(3), expected_values); +} + +// Tests for invalid inputs. +TEST_F(RaggedTensorFromVariantKernelTest, InvalidInferredInputRaggedRank) { + Tensor component_variant_1 = + CreateVariantFromRagged({}, TensorShape({3}), {1, 2, 3}); + Tensor component_variant_2 = + CreateVariantFromRagged({}, TensorShape({0}), {}); + Tensor component_variant_3 = + CreateVariantFromRagged({}, TensorShape({2}), {1, 2}); + Tensor component_variant_4 = + CreateVariantFromRagged({}, TensorShape({1}), {1}); + + int input_ragged_rank = -1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({1, 1, 1, 4}), + {component_variant_1, component_variant_2, component_variant_3, + component_variant_4}); + EXPECT_TRUE( + absl::StartsWith(RunOpKernel().error_message(), + "Inferred input_ragged_rank (output_ragged_rank - " + "encoded_variant.dims()) must be >= 0")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) { + const std::vector component_split_1_1 = {0, 1}; + const std::vector component_split_2_1 = {0, 1, 2}; + const std::vector component_values_1 = {0}; + const std::vector component_values_2 = {0, 1}; + + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({1}), component_values_1); + Tensor variant_component_2 = CreateVariantFromRagged( + {component_split_2_1}, TensorShape({2}), component_values_2); + + int input_ragged_rank = 1; + int output_ragged_rank = 4; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2}), + {variant_component_1, variant_component_2}); + EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), + "output_ragged_rank must be equal to " + "input_ragged_rank + encoded_ragged.dims()")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldTensors) { + 
int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2}), {1, 2}); + EXPECT_TRUE(absl::StartsWith( + RunOpKernel().error_message(), + "Input Variant element at index 0 doesn't hold a Tensor")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, InputVariantTensorRankNotOne) { + Tensor variant_list(DT_VARIANT, TensorShape({2, 1})); + test::FillValues(&variant_list, {1, 2}); + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list}); + EXPECT_TRUE(absl::StartsWith( + RunOpKernel().error_message(), + "Encoded input Variant must have rank 1, but found rank: 2")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, + InputScalarElementDoesNotMatchInputRaggedRank) { + const std::vector component_split_1_1 = {0, 1}; + const std::vector component_values_1 = {1, 2}; + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({1, 2}), component_values_1); + + int input_ragged_rank = 2; + int output_ragged_rank = 3; + BuildDecodeRaggedTensorGraph(input_ragged_rank, + output_ragged_rank, TensorShape({1}), + {variant_component_1}); + EXPECT_TRUE(absl::StartsWith( + RunOpKernel().error_message(), + "Encoded input Variant must hold either input_ragged_rank + 1 " + "Tensors or an empty Tensor")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitNotATensor) { + Tensor variant_list(DT_VARIANT, TensorShape({2})); + test::FillValues(&variant_list, {1, 2}); + + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph(input_ragged_rank, output_ragged_rank, + TensorShape({1}), {variant_list}); + EXPECT_TRUE( + absl::StartsWith(RunOpKernel().error_message(), + "Encoded scalar element at index 0 doesn't have a " + "splits Tensor at split_index 0")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) { + const std::vector component_split_1_1 = {0, 1}; + const std::vector component_values_1 = {0}; + + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({1}), component_values_1); + + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph(input_ragged_rank, output_ragged_rank, + TensorShape({1}), + {variant_component_1}); + EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), + "Expected splits Tensor dtype: 3, found: 9")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitRankNotOne) { + Tensor splits(DT_INT64, TensorShape({2, 1})); + test::FillValues(&splits, {1, 2}); + Tensor values(DT_INT32, {2}); + test::FillValues(&values, {1, 2}); + Tensor encoded_list(DT_VARIANT, TensorShape({2})); + test::FillValues(&encoded_list, {splits, values}); + + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded_list}); + EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), + "Ragged splits must have rank 1")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesNotATensor) { + Tensor splits(DT_INT64, TensorShape({3})); + test::FillValues(&splits, {0, 2, 3}); + Tensor variant_list(DT_VARIANT, TensorShape({2})); + test::FillValues(&variant_list, {splits, 2}); + + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list}); + EXPECT_TRUE( + 
absl::StartsWith(RunOpKernel().error_message(), + "Encoded scalar element at index 0 doesn't have a " + "values Tensor")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) { + const std::vector component_split_1_1 = {0, 1}; + const std::vector component_values_1 = {0}; + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({1}), component_values_1); + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({1}), + {variant_component_1}); + EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), + "Expected values Tensor dtype: 7, found: 3")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankNotGreaterThanOne) { + Tensor variant_component_1 = + CreateVariantFromRagged({{0, 1}}, TensorShape({}), {1}); + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph(input_ragged_rank, + output_ragged_rank, TensorShape({1}), + {variant_component_1}); + EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), + "Ragged values must have rank >= 1")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankMismatch) { + const std::vector component_split_1_1 = {0, 1}; + const std::vector component_split_2_1 = {0, 1, 2}; + const std::vector component_values_1 = {0}; + const std::vector component_values_2 = {0, 1, 2, 3}; + + Tensor variant_component_1 = CreateVariantFromRagged( + {component_split_1_1}, TensorShape({1}), component_values_1); + Tensor variant_component_2 = CreateVariantFromRagged( + {component_split_2_1}, TensorShape({2, 2}), component_values_2); + int input_ragged_rank = 1; + int output_ragged_rank = 2; + BuildDecodeRaggedTensorGraph( + input_ragged_rank, output_ragged_rank, TensorShape({2}), + {variant_component_1, variant_component_2}); + EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(), + "Rank of values must match for all components")); +} + +TEST_F(RaggedTensorFromVariantKernelTest, ShapeFnTest) { + ShapeInferenceTestOp op("RaggedTensorFromVariant"); + + // Tests with input_ragged_rank == 0. + (*op.node_def.mutable_attr())["input_ragged_rank"].set_i(0); + (*op.node_def.mutable_attr())["output_ragged_rank"].set_i(1); + INFER_OK(op, "?", "[?];?"); + INFER_OK(op, "[?]", "[?];?"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[?,?]"); + + // Tests with input_ragged_rank == 1. + (*op.node_def.mutable_attr())["input_ragged_rank"].set_i(1); + + (*op.node_def.mutable_attr())["output_ragged_rank"].set_i(1); + INFER_OK(op, "?", "[?];?"); + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?]"); + INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[?,?]"); + + (*op.node_def.mutable_attr())["output_ragged_rank"].set_i(2); + INFER_OK(op, "?", "[?];[?];?"); + INFER_OK(op, "[?]", "[?];[?];?"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[?,?]"); + + (*op.node_def.mutable_attr())["output_ragged_rank"].set_i(3); + INFER_OK(op, "?", "[?];[?];[?];?"); + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[?]"); + INFER_OK(op, "[?,?]", "[?];[?];[?];?"); + INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[?,?,?]"); + + // Tests with input_ragged_rank == 3. 
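// Hedged sketch of the relationship these shape-inference tests exercise: the
// op emits output_ragged_rank rank-1 splits outputs plus one values output,
// and a fully known input shape must have rank output_ragged_rank -
// input_ragged_rank. Illustrative helper only, not the shape-inference API.
#include <string>

std::string ExpectedOutputSignatureSketch(int input_ragged_rank,
                                          int output_ragged_rank,
                                          int known_input_rank /* -1 if unknown */) {
  if (known_input_rank >= 0 &&
      known_input_rank != output_ragged_rank - input_ragged_rank) {
    return "error: input rank must be output_ragged_rank - input_ragged_rank";
  }
  std::string sig;
  for (int i = 0; i < output_ragged_rank; ++i) sig += "[?];";
  return sig + "?";  // splits shapes followed by the values shape
}
// ExpectedOutputSignatureSketch(1, 2, 1) returns "[?];[?];?", matching
// INFER_OK(op, "[?]", "[?];[?];?") above.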
+ (*op.node_def.mutable_attr())["input_ragged_rank"].set_i(3); + + (*op.node_def.mutable_attr())["output_ragged_rank"].set_i(3); + INFER_OK(op, "?", "[?];[?];[?];?"); + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?]"); + + (*op.node_def.mutable_attr())["output_ragged_rank"].set_i(4); + INFER_OK(op, "?", "[?];[?];[?];[?];?"); + INFER_OK(op, "[?]", "[?];[?];[?];[?];?"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[?,?]"); + + (*op.node_def.mutable_attr())["output_ragged_rank"].set_i(5); + INFER_OK(op, "?", "[?];[?];[?];[?];[?];?"); + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[?]"); + INFER_OK(op, "[?,?]", "[?];[?];[?];[?];[?];?"); + + (*op.node_def.mutable_attr())["output_ragged_rank"].set_i(6); + INFER_OK(op, "?", "[?];[?];[?];[?];[?];[?];?"); + INFER_ERROR("Shape must be rank 3 but is rank 1", op, "[?]"); + INFER_ERROR("Shape must be rank 3 but is rank 2", op, "[?,?]"); + INFER_OK(op, "[?,?,?]", "[?];[?];[?];[?];[?];[?];?"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc index 8cd4b8da858..39b530f4a15 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc @@ -26,21 +26,23 @@ namespace tensorflow { using errors::InvalidArgument; +template class RaggedTensorToSparseOp : public OpKernel { public: using OpKernel::OpKernel; + using ConstFlatSplits = typename TTypes::ConstFlat; void Compute(OpKernelContext* context) override { // Read the `rt_nested_splits` input & convert to Eigen tensors. OpInputList rt_nested_splits_in; OP_REQUIRES_OK( context, context->input_list("rt_nested_splits", &rt_nested_splits_in)); - const int64 rt_nested_splits_len = rt_nested_splits_in.size(); + const int rt_nested_splits_len = rt_nested_splits_in.size(); DCHECK_GT(rt_nested_splits_len, 0); // Enforced by REGISTER_OP. - std::vector::ConstFlat> rt_nested_splits; + std::vector rt_nested_splits; rt_nested_splits.reserve(rt_nested_splits_len); for (int i = 0; i < rt_nested_splits_len; ++i) { - rt_nested_splits.push_back(rt_nested_splits_in[i].flat()); + rt_nested_splits.push_back(rt_nested_splits_in[i].flat()); } // Read the `rt_dense_values` input. @@ -135,7 +137,7 @@ class RaggedTensorToSparseOp : public OpKernel { sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1; for (int dim = 0; dim < rt_nested_splits_len; ++dim) { const auto& splits = rt_nested_splits[dim]; - int64 max_width = 0; + SPLITS_TYPE max_width = 0; for (int i = 1; i < splits.size(); ++i) { max_width = std::max(max_width, splits(i) - splits(i - 1)); } @@ -150,7 +152,7 @@ class RaggedTensorToSparseOp : public OpKernel { private: // Validate `rt_nested_splits` to ensure we don't get any segfaults. static ::tensorflow::Status ValidateInputs( - std::vector::ConstFlat> rt_nested_splits, + std::vector rt_nested_splits, const Tensor& rt_dense_values_in) { for (int i = 0; i < rt_nested_splits.size(); ++i) { if (rt_nested_splits[i].size() == 0) { @@ -160,7 +162,7 @@ class RaggedTensorToSparseOp : public OpKernel { return InvalidArgument("First value of ragged splits must be 0."); } if (i > 0) { - int64 last_split = + SPLITS_TYPE last_split = rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1); if (rt_nested_splits[i].size() != last_split + 1) { return InvalidArgument( @@ -206,14 +208,21 @@ class RaggedTensorToSparseOp : public OpKernel { // values. 
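// Hedged sketch of what RaggedTensorToSparse computes for a single ragged
// dimension: one (row, column) index pair per value, plus a dense shape whose
// ragged dimension is the maximum row width found in the splits. Plain
// std::vector illustration; this is not the kernel's actual index-walking code.
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

void RaggedToSparseSketch(const std::vector<int64_t>& row_splits,
                          std::vector<std::pair<int64_t, int64_t>>* indices,
                          std::vector<int64_t>* dense_shape) {
  int64_t nrows = static_cast<int64_t>(row_splits.size()) - 1;
  int64_t max_width = 0;
  indices->clear();
  for (int64_t row = 0; row < nrows; ++row) {
    int64_t width = row_splits[row + 1] - row_splits[row];
    max_width = std::max(max_width, width);
    for (int64_t col = 0; col < width; ++col) indices->push_back({row, col});
  }
  *dense_shape = {nrows, max_width};
}
// row_splits {0, 3, 3, 5} gives indices (0,0) (0,1) (0,2) (2,0) (2,1) and
// dense_shape {3, 3}.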
static bool IsCompleted( const std::vector& pos, int dim, - const std::vector::ConstFlat>& rt_nested_splits) { + const std::vector& rt_nested_splits) { int64 current_child = pos[dim + 1]; int64 limit_child = rt_nested_splits[dim](pos[dim] + 1); return current_child >= limit_child; } }; -REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse").Device(DEVICE_CPU), - RaggedTensorToSparseOp); +REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse") + .Device(DEVICE_CPU) + .TypeConstraint("Tsplits"), + RaggedTensorToSparseOp); + +REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse") + .Device(DEVICE_CPU) + .TypeConstraint("Tsplits"), + RaggedTensorToSparseOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc new file mode 100644 index 00000000000..6923fd45f11 --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc @@ -0,0 +1,221 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +struct RaggedTensor { + Tensor values; + std::vector nested_splits; +}; + +Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) { + // Encode as a rank-1 Variant Tensor. + int ragged_rank = ragged.nested_splits.size(); + *encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1})); + auto encoded_vec = encoded_list->vec(); + for (int i = 0; i < ragged_rank; i++) { + encoded_vec(i) = ragged.nested_splits[i]; + } + encoded_vec(ragged_rank) = ragged.values; + return Status::OK(); +} + +template +Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged, + std::vector* ragged_components) { + // Set up the component Ragged Tensors. + int ragged_rank = batched_ragged.nested_splits.size(); + auto batched_splits_top_vec = + batched_ragged.nested_splits[0].vec(); + int num_components = batched_splits_top_vec.size() - 1; + int num_splits = ragged_rank - 1; + ragged_components->resize(num_components); + for (RaggedTensor ragged_component : *ragged_components) { + ragged_component.nested_splits.reserve(num_splits); + } + const auto& batched_flat = batched_ragged.values.flat(); + int num_inner_elems = batched_ragged.values.NumElements(); + if (batched_ragged.values.dim_size(0) > 1) { + num_inner_elems /= batched_ragged.values.dim_size(0); + } + TensorShape values_shape = batched_ragged.values.shape(); + + // Corner case: ragged_rank == 1, e.g. 
[[1, 2, 3], [4, 5]] + if (num_splits == 0) { + for (int i = 0; i < num_components; i++) { + int start = batched_splits_top_vec(i); + int limit = batched_splits_top_vec(i + 1); + int num_values = limit - start; + values_shape.set_dim(0, num_values); + (*ragged_components)[i].values = + Tensor(DataTypeToEnum::value, values_shape); + auto ragged_component_values_flat = + (*ragged_components)[i].values.flat(); + for (int j = 0; j < num_values * num_inner_elems; j++) { + ragged_component_values_flat(j) = + batched_flat(j + start * num_inner_elems); + } + } + return Status::OK(); + } + + // Unbatch nested splits. + std::vector::ConstVec> batched_splits_vec; + batched_splits_vec.reserve(ragged_rank); + for (int i = 0; i < ragged_rank; i++) { + batched_splits_vec.push_back( + batched_ragged.nested_splits[i].vec()); + } + std::vector index(num_splits, 1); + std::vector ragged_component_values_size(num_components, 0); + for (int i = 0; i < num_components; i++) { + std::vector::Vec> ragged_component_splits_vec; + ragged_component_splits_vec.reserve(num_splits); + int split_size = -1; + for (int j = 0; j < num_splits; j++) { + if (j == 0) { + split_size = + batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1; + } else { + // Update split size based on previous split. + int last_index = ragged_component_splits_vec[j - 1].size() - 1; + split_size = ragged_component_splits_vec[j - 1](last_index) + 1; + } + (*ragged_components)[i].nested_splits.push_back( + Tensor(DataTypeToEnum::value, TensorShape({split_size}))); + ragged_component_splits_vec.push_back( + (*ragged_components)[i].nested_splits[j].vec()); + SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1); + ragged_component_splits_vec[j](0) = 0; + for (int k = 1; k < split_size; k++, index[j]++) { + ragged_component_splits_vec[j](k) = + batched_splits_vec[j + 1](index[j]) - last_split_value; + } + } + int last_split_size = ragged_component_splits_vec[num_splits - 1].size(); + ragged_component_values_size[i] = + ragged_component_splits_vec[num_splits - 1](last_split_size - 1); + } + + // Unbatch values. + int value_index = 0; + for (int i = 0; i < num_components; i++) { + int num_values = ragged_component_values_size[i]; + values_shape.set_dim(0, num_values); + (*ragged_components)[i].values = + Tensor(DataTypeToEnum::value, values_shape); + auto ragged_component_values_flat = + (*ragged_components)[i].values.flat(); + for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) { + ragged_component_values_flat(j) = batched_flat(value_index); + } + } + + return Status::OK(); +} +} // namespace + +template +class RaggedTensorToVariantOp : public OpKernel { + public: + explicit RaggedTensorToVariantOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("batched_input", &batched_input_)); + } + + void Compute(OpKernelContext* context) override { + // Read ragged_splits inputs. + OpInputList ragged_nested_splits_in; + OP_REQUIRES_OK(context, context->input_list("rt_nested_splits", + &ragged_nested_splits_in)); + const int ragged_nested_splits_len = ragged_nested_splits_in.size(); + DCHECK_GT(ragged_nested_splits_len, 0); // Enforced by REGISTER_OP. + RaggedTensor batched_ragged_input; + // Read ragged_values input. 
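// Hedged sketch of the inner-splits re-basing done by UnbatchRaggedZerothDim
// above: each component's splits are the corresponding slice of the batched
// inner splits shifted so they start at 0 again. Plain std::vector
// illustration; the names are not part of the TensorFlow API.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<std::vector<int64_t>> UnbatchRowSplitsSketch(
    const std::vector<int64_t>& outer_splits,    // e.g. {0, 2, 3}
    const std::vector<int64_t>& inner_splits) {  // e.g. {0, 1, 3, 6}
  std::vector<std::vector<int64_t>> components;
  for (std::size_t i = 0; i + 1 < outer_splits.size(); ++i) {
    int64_t start = outer_splits[i];
    int64_t limit = outer_splits[i + 1];
    int64_t base = inner_splits[start];
    std::vector<int64_t> splits{0};
    for (int64_t j = start + 1; j <= limit; ++j) {
      splits.push_back(inner_splits[j] - base);  // shift back to a 0 origin
    }
    components.push_back(splits);
  }
  return components;
}
// With the example inputs above this yields {0, 1, 3} for the first component
// and {0, 3} for the second.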
+ batched_ragged_input.values = context->input(ragged_nested_splits_len); + batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len); + for (int i = 0; i < ragged_nested_splits_len; i++) { + batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]); + } + + if (!batched_input_) { + // Encode the input as is. + Tensor encoded_list; + OP_REQUIRES_OK(context, + RaggedToVariant(batched_ragged_input, &encoded_list)); + // Encode as a Scalar Variant Tensor. + Tensor* encoded_scalar; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), + &encoded_scalar)); + encoded_scalar->scalar()() = std::move(encoded_list); + return; + } + + // Unbatch the Ragged Tensor and encode the components. + std::vector ragged_components; + OP_REQUIRES_OK(context, UnbatchRaggedZerothDim( + batched_ragged_input, &ragged_components)); + std::vector encoded_components(ragged_components.size()); + for (int i = 0; i < ragged_components.size(); i++) { + OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i], + &encoded_components[i])); + } + + // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor. + Tensor* encoded_ragged; + int output_size = ragged_components.size(); + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({output_size}), + &encoded_ragged)); + auto encoded_ragged_vec = encoded_ragged->vec(); + for (int i = 0; i < output_size; i++) { + encoded_ragged_vec(i) = encoded_components[i]; + } + } + + private: + bool batched_input_; +}; + +#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type) \ + REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tvalues") \ + .TypeConstraint("Tsplits"), \ + RaggedTensorToVariantOp); +#define REGISTER_KERNELS(value_type) \ + REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \ + REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64) +TF_CALL_POD_TYPES(REGISTER_KERNELS); +TF_CALL_string(REGISTER_KERNELS); +TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); +TF_CALL_quint16(REGISTER_KERNELS); +TF_CALL_qint16(REGISTER_KERNELS); +TF_CALL_uint32(REGISTER_KERNELS); +TF_CALL_uint64(REGISTER_KERNELS); +#undef REGISTER_KERNELS +#undef REGISTER_KERNELS_WITH_SPLIT_TYPE +} // namespace tensorflow diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc new file mode 100644 index 00000000000..2854044d19a --- /dev/null +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc @@ -0,0 +1,610 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/strings/match.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class RaggedTensorToVariantKernelTest : public ::tensorflow::OpsTestBase { + protected: + // Builds the tensorflow test graph for the RaggedTensorToVariant op, and + // populates the `splits` input with the given values. + template + void BuildEncodeRaggedTensorGraph( + const std::vector>& ragged_splits, + const TensorShape& ragged_values_shape, + const std::vector& ragged_values, const bool batched) { + const auto values_dtype = DataTypeToEnum::v(); + const auto splits_dtype = DataTypeToEnum::v(); + int64 num_splits = ragged_splits.size(); + TF_ASSERT_OK( + NodeDefBuilder("tested_op", "RaggedTensorToVariant") + .Input(FakeInput(num_splits, splits_dtype)) // ragged_splits + .Input(FakeInput(values_dtype)) // ragged_values + .Attr("RAGGED_RANK", num_splits) + .Attr("Tvalues", values_dtype) + .Attr("Tsplits", splits_dtype) + .Attr("batched_input", batched) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + for (const auto& splits : ragged_splits) { + int64 splits_size = splits.size(); + AddInputFromArray(TensorShape({splits_size}), splits); + } + AddInputFromArray(ragged_values_shape, ragged_values); + } +}; + +TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) { + // ragged_tensor=[[[], []], [[]], []] + const std::vector batched_splits_1 = {0, 2, 3, 3}; + const std::vector batched_splits_2 = {0, 0, 0, 0}; + + const std::vector component_splits_1_1 = {0, 0, 0}; + const std::vector component_splits_2_1 = {0, 0}; + const std::vector component_splits_3_1 = {0}; + + Tensor expected_splits_1_1(DT_INT64, TensorShape({3})); + Tensor expected_splits_2_1(DT_INT64, TensorShape({2})); + Tensor expected_splits_3_1(DT_INT64, TensorShape({1})); + + test::FillValues(&expected_splits_1_1, component_splits_1_1); + test::FillValues(&expected_splits_2_1, component_splits_2_1); + test::FillValues(&expected_splits_3_1, component_splits_3_1); + + BuildEncodeRaggedTensorGraph({batched_splits_1, batched_splits_2}, + TensorShape({0}), {}, true); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_list = GetOutput(0)->vec(); + EXPECT_EQ(encoded_list.size(), 3); + + const Variant& encoded_splits_1_1 = + encoded_list(0).get()->vec()(0); + const Variant& encoded_values_1 = + encoded_list(0).get()->vec()(1); + const Variant& encoded_splits_2_1 = + encoded_list(1).get()->vec()(0); + const Variant& encoded_values_2 = + encoded_list(1).get()->vec()(1); + const Variant& encoded_splits_3_1 = + encoded_list(2).get()->vec()(0); + const Variant& encoded_values_3 = + encoded_list(2).get()->vec()(1); + + test::ExpectTensorEqual(*encoded_splits_1_1.get(), + expected_splits_1_1); + test::ExpectTensorEqual(*encoded_splits_2_1.get(), + expected_splits_2_1); + test::ExpectTensorEqual(*encoded_splits_3_1.get(), + expected_splits_3_1); + 
test::ExpectTensorEqual(*encoded_values_1.get(), + Tensor(DT_INT32, TensorShape({0}))); + test::ExpectTensorEqual(*encoded_values_2.get(), + Tensor(DT_INT32, TensorShape({0}))); + test::ExpectTensorEqual(*encoded_values_3.get(), + Tensor(DT_INT32, TensorShape({0}))); +} + +TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) { + // ragged_tensor= + // [ [x, x, x], + // [ ], + // [x, x ], + // [x ]] + const std::vector batched_splits = {0, 3, 3, 5, 6}; + const std::vector batched_values = {1, 2, 3, 4, 5, 6}; + + const std::vector component_values_1 = {1, 2, 3}; + const std::vector component_values_3 = {4, 5}; + const std::vector component_values_4 = {6}; + + Tensor expected_values_1(DT_INT32, TensorShape({3})); + Tensor expected_values_2(DT_INT32, TensorShape({0})); + Tensor expected_values_3(DT_INT32, TensorShape({2})); + Tensor expected_values_4(DT_INT32, TensorShape({1})); + + test::FillValues(&expected_values_1, component_values_1); + test::FillValues(&expected_values_3, component_values_3); + test::FillValues(&expected_values_4, component_values_4); + + BuildEncodeRaggedTensorGraph({batched_splits}, TensorShape({6}), + batched_values, true); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_list = GetOutput(0)->vec(); + EXPECT_EQ(encoded_list.size(), 4); + + const Variant& encoded_values_1 = + encoded_list(0).get()->vec()(0); + const Variant& encoded_values_2 = + encoded_list(1).get()->vec()(0); + const Variant& encoded_values_3 = + encoded_list(2).get()->vec()(0); + const Variant& encoded_values_4 = + encoded_list(3).get()->vec()(0); + + test::ExpectTensorEqual(*encoded_values_1.get(), + expected_values_1); + test::ExpectTensorEqual(*encoded_values_2.get(), + expected_values_2); + test::ExpectTensorEqual(*encoded_values_3.get(), + expected_values_3); + test::ExpectTensorEqual(*encoded_values_4.get(), + expected_values_4); +} + +TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) { + // ragged_tensor= + // [[x, x], + // [x, x], + // [x, x]] + const std::vector batched_splits = {0, 1, 2, 3}; + const std::vector batched_values = {1, 2, 4, 5, 6, 7}; + + const std::vector component_values_1 = {1, 2}; + const std::vector component_values_2 = {4, 5}; + const std::vector component_values_3 = {6, 7}; + + Tensor expected_values_1(DT_INT32, TensorShape({1, 2})); + Tensor expected_values_2(DT_INT32, TensorShape({1, 2})); + Tensor expected_values_3(DT_INT32, TensorShape({1, 2})); + + test::FillValues(&expected_values_1, component_values_1); + test::FillValues(&expected_values_2, component_values_2); + test::FillValues(&expected_values_3, component_values_3); + + BuildEncodeRaggedTensorGraph( + {batched_splits}, TensorShape({3, 2}), batched_values, true); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_list = GetOutput(0)->vec(); + EXPECT_EQ(encoded_list.size(), 3); + + const Variant& encoded_values_1 = + encoded_list(0).get()->vec()(0); + const Variant& encoded_values_2 = + encoded_list(1).get()->vec()(0); + const Variant& encoded_values_3 = + encoded_list(2).get()->vec()(0); + + test::ExpectTensorEqual(*encoded_values_1.get(), + expected_values_1); + test::ExpectTensorEqual(*encoded_values_2.get(), + expected_values_2); + test::ExpectTensorEqual(*encoded_values_3.get(), + expected_values_3); +} + +TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) { + // ragged_tensor=[ + // [ [[x, x], [x, x]], + // [[x, x] ] ] + const std::vector batched_splits_1 = {0, 1, 2}; + const std::vector batched_splits_2 = {0, 2, 3}; + const std::vector 
batched_values = {1, 2, 4, 5, 6, 7}; + + const std::vector component_splits_1_1 = {0, 2}; + const std::vector component_splits_2_1 = {0, 1}; + const std::vector component_values_1 = {1, 2, 4, 5}; + const std::vector component_values_2 = {6, 7}; + + Tensor expected_splits_1_1(DT_INT64, TensorShape({2})); + Tensor expected_splits_2_1(DT_INT64, TensorShape({2})); + Tensor expected_values_1(DT_INT32, TensorShape({2, 2})); + Tensor expected_values_2(DT_INT32, TensorShape({1, 2})); + + test::FillValues(&expected_splits_1_1, component_splits_1_1); + test::FillValues(&expected_splits_2_1, component_splits_2_1); + test::FillValues(&expected_values_1, component_values_1); + test::FillValues(&expected_values_2, component_values_2); + + BuildEncodeRaggedTensorGraph({batched_splits_1, batched_splits_2}, + TensorShape({3, 2}), batched_values, + true); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_list = GetOutput(0)->vec(); + EXPECT_EQ(encoded_list.size(), 2); + + const Variant& encoded_splits_1_1 = + encoded_list(0).get()->vec()(0); + const Variant& encoded_values_1 = + encoded_list(0).get()->vec()(1); + const Variant& encoded_splits_2_1 = + encoded_list(1).get()->vec()(0); + const Variant& encoded_values_2 = + encoded_list(1).get()->vec()(1); + + test::ExpectTensorEqual(*encoded_splits_1_1.get(), + expected_splits_1_1); + test::ExpectTensorEqual(*encoded_values_1.get(), + expected_values_1); + test::ExpectTensorEqual(*encoded_splits_2_1.get(), + expected_splits_2_1); + test::ExpectTensorEqual(*encoded_values_2.get(), + expected_values_2); +} + +TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) { + // ragged_tensor = + // [[ [x], [x x], [] ], + // [ ], + // [ [x x x x x], [x x x] ], + // [ [], [x x x x] ]] + const std::vector batched_splits_1 = {0, 3, 3, 5, 7}; + const std::vector batched_splits_2 = {0, 1, 3, 3, 8, 11, 11, 15}; + const std::vector batched_values = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15}; + const std::vector component_splits_1_1 = {0, 1, 3, 3}; + const std::vector component_splits_2_1 = {0}; + const std::vector component_splits_3_1 = {0, 5, 8}; + const std::vector component_splits_4_1 = {0, 0, 4}; + const std::vector component_values_1 = {1, 2, 3}; + const std::vector component_values_3 = {4, 5, 6, 7, 8, 9, 10, 11}; + const std::vector component_values_4 = {12, 13, 14, 15}; + + Tensor expected_splits_1_1(DT_INT64, TensorShape({4})); + Tensor expected_splits_2_1(DT_INT64, TensorShape({1})); + Tensor expected_splits_3_1(DT_INT64, TensorShape({3})); + Tensor expected_splits_4_1(DT_INT64, TensorShape({3})); + Tensor expected_values_1(DT_INT32, TensorShape({3})); + Tensor expected_values_2(DT_INT32, TensorShape({0})); + Tensor expected_values_3(DT_INT32, TensorShape({8})); + Tensor expected_values_4(DT_INT32, TensorShape({4})); + + test::FillValues(&expected_splits_1_1, component_splits_1_1); + test::FillValues(&expected_splits_2_1, component_splits_2_1); + test::FillValues(&expected_splits_3_1, component_splits_3_1); + test::FillValues(&expected_splits_4_1, component_splits_4_1); + test::FillValues(&expected_values_1, component_values_1); + test::FillValues(&expected_values_3, component_values_3); + test::FillValues(&expected_values_4, component_values_4); + + BuildEncodeRaggedTensorGraph({batched_splits_1, batched_splits_2}, + TensorShape({15}), batched_values, + true); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_list = GetOutput(0)->vec(); + EXPECT_EQ(encoded_list.size(), 4); + + const Variant& encoded_splits_1_1 = + 
encoded_list(0).get()->vec()(0); + const Variant& encoded_values_1 = + encoded_list(0).get()->vec()(1); + const Variant& encoded_splits_2_1 = + encoded_list(1).get()->vec()(0); + const Variant& encoded_values_2 = + encoded_list(1).get()->vec()(1); + const Variant& encoded_splits_3_1 = + encoded_list(2).get()->vec()(0); + const Variant& encoded_values_3 = + encoded_list(2).get()->vec()(1); + const Variant& encoded_splits_4_1 = + encoded_list(3).get()->vec()(0); + const Variant& encoded_values_4 = + encoded_list(3).get()->vec()(1); + + test::ExpectTensorEqual(*encoded_splits_1_1.get(), + expected_splits_1_1); + test::ExpectTensorEqual(*encoded_values_1.get(), + expected_values_1); + test::ExpectTensorEqual(*encoded_splits_2_1.get(), + expected_splits_2_1); + test::ExpectTensorEqual(*encoded_values_2.get(), + expected_values_2); + test::ExpectTensorEqual(*encoded_splits_3_1.get(), + expected_splits_3_1); + test::ExpectTensorEqual(*encoded_values_3.get(), + expected_values_3); + test::ExpectTensorEqual(*encoded_splits_4_1.get(), + expected_splits_4_1); + test::ExpectTensorEqual(*encoded_values_4.get(), + expected_values_4); +} + +TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) { + // ragged_tensor = + // [[ [ [x, x] ], + // [ [x], [x] ], + // [ [x] ], + // [ [x] ], + // [ [x] ]], + // [ [ [x] ], + // [ [x] ], + // [ [x, x, x] ], + // [ [x] ], + // [ [x] ] ]] + const std::vector batched_splits_1 = {0, 5, 10}; + const std::vector batched_splits_2 = {0, 1, 3, 4, 5, 6, + 7, 8, 9, 10, 11}; + const std::vector batched_splits_3 = {0, 2, 3, 4, 5, 6, + 7, 8, 9, 12, 13, 14}; + const std::vector batched_values = {0, 1, 1, 2, 2, 3, 4, + 5, 6, 7, 8, 9, 8, 9}; + const std::vector component_split_1_1 = {0, 1, 3, 4, 5, 6}; + const std::vector component_split_1_2 = {0, 2, 3, 4, 5, 6, 7}; + const std::vector component_split_2_1 = {0, 1, 2, 3, 4, 5}; + const std::vector component_split_2_2 = {0, 1, 2, 5, 6, 7}; + const std::vector component_values_1 = {0, 1, 1, 2, 2, 3, 4}; + const std::vector component_values_2 = {5, 6, 7, 8, 9, 8, 9}; + + Tensor expected_splits_1_1(DT_INT64, TensorShape({6})); + Tensor expected_splits_1_2(DT_INT64, TensorShape({7})); + Tensor expected_splits_2_1(DT_INT64, TensorShape({6})); + Tensor expected_splits_2_2(DT_INT64, TensorShape({6})); + Tensor expected_values_1(DT_INT32, TensorShape({7})); + Tensor expected_values_2(DT_INT32, TensorShape({7})); + + test::FillValues(&expected_splits_1_1, component_split_1_1); + test::FillValues(&expected_splits_1_2, component_split_1_2); + test::FillValues(&expected_splits_2_1, component_split_2_1); + test::FillValues(&expected_splits_2_2, component_split_2_2); + test::FillValues(&expected_values_1, component_values_1); + test::FillValues(&expected_values_2, component_values_2); + + BuildEncodeRaggedTensorGraph( + {batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}), + batched_values, true); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_list = GetOutput(0)->vec(); + EXPECT_EQ(encoded_list.size(), 2); + + const Variant& encoded_splits_1_1 = + encoded_list(0).get()->vec()(0); + const Variant& encoded_splits_1_2 = + encoded_list(0).get()->vec()(1); + const Variant& encoded_values_1 = + encoded_list(0).get()->vec()(2); + const Variant& encoded_splits_2_1 = + encoded_list(1).get()->vec()(0); + const Variant& encoded_splits_2_2 = + encoded_list(1).get()->vec()(1); + const Variant& encoded_values_2 = + encoded_list(1).get()->vec()(2); + + test::ExpectTensorEqual(*encoded_splits_1_1.get(), + expected_splits_1_1); 
+ test::ExpectTensorEqual(*encoded_splits_1_2.get(), + expected_splits_1_2); + test::ExpectTensorEqual(*encoded_splits_2_1.get(), + expected_splits_2_1); + test::ExpectTensorEqual(*encoded_splits_2_2.get(), + expected_splits_2_2); + test::ExpectTensorEqual(*encoded_values_1.get(), + expected_values_1); + test::ExpectTensorEqual(*encoded_values_2.get(), + expected_values_2); +} + +TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) { + // ragged_tensor = + // [[ [ [x, x] ], + // [ [x], [x] ], + // [ [x] ], + // [ [x] ], + // [ [x] ]], + // [ [ [x] ], + // [ [x] ], + // [ [x, x, x] ], + // [ [x] ], + // [ [x] ] ]] + const std::vector batched_splits_1 = {0, 5, 10}; + const std::vector batched_splits_2 = {0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + const std::vector batched_splits_3 = {0, 2, 3, 4, 5, 6, + 7, 8, 9, 12, 13, 14}; + const std::vector batched_values = {0, 1, 1, 2, 2, 3, 4, + 5, 6, 7, 8, 9, 8, 9}; + const std::vector component_split_1_1 = {0, 1, 3, 4, 5, 6}; + const std::vector component_split_1_2 = {0, 2, 3, 4, 5, 6, 7}; + const std::vector component_split_2_1 = {0, 1, 2, 3, 4, 5}; + const std::vector component_split_2_2 = {0, 1, 2, 5, 6, 7}; + const std::vector component_values_1 = {0, 1, 1, 2, 2, 3, 4}; + const std::vector component_values_2 = {5, 6, 7, 8, 9, 8, 9}; + + Tensor expected_splits_1_1(DT_INT32, TensorShape({6})); + Tensor expected_splits_1_2(DT_INT32, TensorShape({7})); + Tensor expected_splits_2_1(DT_INT32, TensorShape({6})); + Tensor expected_splits_2_2(DT_INT32, TensorShape({6})); + Tensor expected_values_1(DT_INT32, TensorShape({7})); + Tensor expected_values_2(DT_INT32, TensorShape({7})); + + test::FillValues(&expected_splits_1_1, component_split_1_1); + test::FillValues(&expected_splits_1_2, component_split_1_2); + test::FillValues(&expected_splits_2_1, component_split_2_1); + test::FillValues(&expected_splits_2_2, component_split_2_2); + test::FillValues(&expected_values_1, component_values_1); + test::FillValues(&expected_values_2, component_values_2); + + BuildEncodeRaggedTensorGraph( + {batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}), + batched_values, true); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_list = GetOutput(0)->vec(); + EXPECT_EQ(encoded_list.size(), 2); + + const Variant& encoded_splits_1_1 = + encoded_list(0).get()->vec()(0); + const Variant& encoded_splits_1_2 = + encoded_list(0).get()->vec()(1); + const Variant& encoded_values_1 = + encoded_list(0).get()->vec()(2); + const Variant& encoded_splits_2_1 = + encoded_list(1).get()->vec()(0); + const Variant& encoded_splits_2_2 = + encoded_list(1).get()->vec()(1); + const Variant& encoded_values_2 = + encoded_list(1).get()->vec()(2); + + test::ExpectTensorEqual(*encoded_splits_1_1.get(), + expected_splits_1_1); + test::ExpectTensorEqual(*encoded_splits_1_2.get(), + expected_splits_1_2); + test::ExpectTensorEqual(*encoded_splits_2_1.get(), + expected_splits_2_1); + test::ExpectTensorEqual(*encoded_splits_2_2.get(), + expected_splits_2_2); + test::ExpectTensorEqual(*encoded_values_1.get(), + expected_values_1); + test::ExpectTensorEqual(*encoded_values_2.get(), + expected_values_2); +} + +TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) { + // ragged_tensor = + // [[ [x], [x x], [] ], + // [ ], + // [ [x x x x x], [x x x] ], + // [ [], [x x x x] ]] + const std::vector batched_splits_1 = {0, 3, 3, 5, 7}; + const std::vector batched_splits_2 = {0, 1, 3, 3, 8, 11, 11, 15}; + const std::vector batched_values = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 
12, 13, 14, 15}; + + Tensor batched_ragged_splits_1(DT_INT64, TensorShape({5})); + Tensor batched_ragged_splits_2(DT_INT64, TensorShape({8})); + Tensor batched_ragged_values(DT_INT32, TensorShape({15})); + + test::FillValues(&batched_ragged_splits_1, batched_splits_1); + test::FillValues(&batched_ragged_splits_2, batched_splits_2); + test::FillValues(&batched_ragged_values, batched_values); + + BuildEncodeRaggedTensorGraph({batched_splits_1, batched_splits_2}, + TensorShape({15}), batched_values, + false); + TF_ASSERT_OK(RunOpKernel()); + + const auto& encoded_scalar = GetOutput(0)->scalar()(); + const Variant& encoded_splits_1 = + encoded_scalar.get()->vec()(0); + const Variant& encoded_splits_2 = + encoded_scalar.get()->vec()(1); + const Variant& encoded_values = + encoded_scalar.get()->vec()(2); + + test::ExpectTensorEqual(*encoded_splits_1.get(), + batched_ragged_splits_1); + test::ExpectTensorEqual(*encoded_splits_2.get(), + batched_ragged_splits_2); + test::ExpectTensorEqual(*encoded_values.get(), + batched_ragged_values); +} + +TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) { + ShapeInferenceTestOp op("RaggedTensorToVariant"); + (*op.node_def.mutable_attr())["batched_input"].set_b(true); + + // Tests with len(ragged_splits)==0. + (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0); + INFER_ERROR("Shape inference should have returned error", op, "?"); + + // Tests with len(ragged_splits)==1. + (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1); + INFER_OK(op, "?;?", "[?]"); + INFER_OK(op, "?;[?]", "[?]"); + INFER_OK(op, "?;[?,?]", "[?]"); + INFER_OK(op, "[?];[5]", "[?]"); + INFER_OK(op, "[?];[5,2]", "[?]"); + INFER_OK(op, "[5];[5,2]", "[4]"); + INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[5,5];?"); + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "?;[]"); + + // Tests with len(ragged_splits)==2 + (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(2); + INFER_OK(op, "?;?;?", "[?]"); + INFER_OK(op, "?;?;[?]", "[?]"); + INFER_OK(op, "?;?;[?,?]", "[?]"); + INFER_OK(op, "[?];[?];[5]", "[?]"); + INFER_OK(op, "[?];[?];[5,2]", "[?]"); + INFER_OK(op, "[6];[?];[5,2]", "[5]"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[5,5];?"); + + // Tests with len(ragged_splits)==3 + (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(3); + INFER_OK(op, "?;?;?;?", "[?]"); + INFER_OK(op, "?;?;?;[?]", "[?]"); + INFER_OK(op, "?;?;?;[5]", "[?]"); + INFER_OK(op, "[4];?;?;[5]", "[3]"); +} + +TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) { + ShapeInferenceTestOp op("RaggedTensorToVariant"); + (*op.node_def.mutable_attr())["batched_input"].set_b(false); + + // Tests with len(ragged_splits)==0. + (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0); + INFER_ERROR("Shape inference should have returned error", op, "?"); + + // Tests with len(ragged_splits)==1. 
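+  // (Note on the INFER_OK shorthand used in these tests: each ';'-separated
+  //  entry is an input shape -- "?" fully unknown, "[?]" a vector of unknown
+  //  length, "[5,2]" a concrete matrix -- and the final argument is the
+  //  expected output shape, which is "[]" here because batched_input=false
+  //  produces a scalar variant.)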
+ (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1); + INFER_OK(op, "?;?", "[]"); + INFER_OK(op, "?;[?]", "[]"); + INFER_OK(op, "?;[?,?]", "[]"); + INFER_OK(op, "[?];[5]", "[]"); + INFER_OK(op, "[?];[5,2]", "[]"); + INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[5,5];?"); + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "?;[]"); + + // Tests with len(ragged_splits)==2 + (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(2); + INFER_OK(op, "?;?;?", "[]"); + INFER_OK(op, "?;?;[?]", "[]"); + INFER_OK(op, "?;?;[?,?]", "[]"); + INFER_OK(op, "[?];[?];[5]", "[]"); + INFER_OK(op, "[?];[?];[5,2]", "[]"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[5,5];?"); + + // Tests with len(ragged_splits)==3 + (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(3); + INFER_OK(op, "?;?;?;?", "[]"); + INFER_OK(op, "?;?;?;[?]", "[]"); + INFER_OK(op, "?;?;?;[5]", "[]"); +} + +TEST_F(RaggedTensorToVariantKernelTest, NoSplits) { + const auto dtype = DataTypeToEnum::v(); + TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToVariant") + .Input(FakeInput(0)) + .Input(FakeInput(dtype)) + .Attr("RAGGED_RANK", 0) + .Attr("Tvalues", dtype) + .Attr("Tsplits", DT_INT64) + .Attr("batched_input", true) + .Finalize(node_def())); + EXPECT_TRUE(absl::StartsWith( + InitOp().error_message(), + "Value for attr 'RAGGED_RANK' of 0 must be at least minimum 1")); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc new file mode 100644 index 00000000000..6ed36605530 --- /dev/null +++ b/tensorflow/core/kernels/random_binomial_op.cc @@ -0,0 +1,447 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/random_ops.cc. +// NOTE: If the algorithm is changed, please run the test +// .../python/kernel_tests/random:random_binomial_test +// commenting out the "tf.set_random_seed(seed)" lines, and using the +// "--runs-per-test=1000" flag. This tests the statistical correctness of the +// op results. 
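+//
+// Overview: this file provides the CPU kernel for "StatefulRandomBinomial".
+// Samples are drawn with one of two methods, chosen per batch (see
+// random_binomial_op.h): summing geometric variates ("binomial inversion")
+// when count * min(prob, 1 - prob) is below ~10, and Hormann's BTRS
+// transformed-rejection sampler otherwise.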
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/random_binomial_op.h" + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h" +#include "tensorflow/core/kernels/training_op_helpers.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/util/work_sharder.h" + +#define UNIFORM(X) \ + if (uniform_remaining == 0) { \ + uniform_remaining = Uniform::kResultElementCount; \ + uniform_result = uniform(gen); \ + } \ + uniform_remaining--; \ + double X = uniform_result[uniform_remaining] + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace { + +typedef random::UniformDistribution Uniform; + +// Binomial inversion. Given prob, sum geometric random variables until they +// exceed count. The number of random variables used is binomially distributed. +// This is also known as binomial inversion, as this is equivalent to inverting +// the Binomial CDF. +double binomial_inversion(double count, double prob, + random::PhiloxRandom* gen) { + using Eigen::numext::ceil; + using Eigen::numext::log; + using Eigen::numext::log1p; + + double geom_sum = 0; + int num_geom = 0; + + Uniform uniform; + typename Uniform::ResultType uniform_result; + int16 uniform_remaining = 0; + + while (true) { + UNIFORM(u); + double geom = ceil(log(u) / log1p(-prob)); + geom_sum += geom; + if (geom_sum > count) { + break; + } + ++num_geom; + } + return num_geom; +} + +double stirling_approx_tail(double k) { + static double kTailValues[] = {0.0810614667953272, 0.0413406959554092, + 0.0276779256849983, 0.02079067210376509, + 0.0166446911898211, 0.0138761288230707, + 0.0118967099458917, 0.0104112652619720, + 0.00925546218271273, 0.00833056343336287}; + if (k <= 9) { + return kTailValues[static_cast(k)]; + } + double kp1sq = (k + 1) * (k + 1); + return (1 / 12 - (1 / 360 + 1 / 1260 / kp1sq) / kp1sq) / (k + 1); +} + +// We use a transformation-rejection algorithm from +// pairs of uniform random variables due to Hormann. +// https://www.tandfonline.com/doi/abs/10.1080/00949659308811496 +double btrs(double count, double prob, random::PhiloxRandom* gen) { + using Eigen::numext::abs; + using Eigen::numext::floor; + using Eigen::numext::log; + using Eigen::numext::log1p; + using Eigen::numext::sqrt; + + // This is spq in the paper. + const double stddev = sqrt(count * prob * (1 - prob)); + + // Other coefficients for Transformed Rejection sampling. + const double b = 1.15 + 2.53 * stddev; + const double a = -0.0873 + 0.0248 * b + 0.01 * prob; + const double c = count * prob + 0.5; + const double v_r = 0.92 - 4.2 / b; + const double r = prob / (1 - prob); + + Uniform uniform; + typename Uniform::ResultType uniform_result; + int16 uniform_remaining = 0; + + while (true) { + UNIFORM(u); + UNIFORM(v); + u = u - 0.5; + double us = 0.5 - abs(u); + double k = floor((2 * a / us + b) * u + c); + + // Region for which the box is tight, and we + // can return our calculated value This should happen + // 0.86 * v_r times. In the limit as n * p is large, + // the acceptance rate converges to ~79% (and in the lower + // regime it is ~24%). 
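+    // In other words, this is the "squeeze" test: when it passes, k is
+    // accepted immediately and the four log() evaluations in the exact
+    // acceptance test further down are skipped entirely.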
+ if (us >= 0.07 && v <= v_r) { + return k; + } + // Reject non-sensical answers. + if (k < 0 || k > count) { + continue; + } + + double alpha = (2.83 + 5.1 / b) * stddev; + double m = floor((count + 1) * prob); + // This deviates from Hormann's BRTS algorithm, as there is a log missing. + // For all (u, v) pairs outside of the bounding box, this calculates the + // transformed-reject ratio. + v = log(v * alpha / (a / (us * us) + b)); + double upperbound = + ((m + 0.5) * log((m + 1) / (r * (count - m + 1))) + + (count + 1) * log((count - m + 1) / (count - k + 1)) + + (k + 0.5) * log(r * (count - k + 1) / (k + 1)) + + stirling_approx_tail(m) + stirling_approx_tail(count - m) - + stirling_approx_tail(k) - stirling_approx_tail(count - k)); + if (v <= upperbound) { + return k; + } + } +} + +} // namespace + +namespace functor { + +template +struct RandomBinomialFunctor { + void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches, + int64 samples_per_batch, int64 num_elements, + typename TTypes::ConstFlat counts, + typename TTypes::ConstFlat probs, + const random::PhiloxRandom& gen, + typename TTypes::Flat output) { + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + + auto DoWork = [samples_per_batch, num_elements, &counts, &probs, &gen, + &output](int start_batch, int limit_batch) { + // Capturing "gen" by-value would only make a copy for the _shared_ + // lambda. Since we want to let each worker have its own copy, we pass + // "gen" by reference and explicitly do a copy assignment here. + random::PhiloxRandom gen_copy = gen; + // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to + // us using the same state in different batches. + // The sample from each iteration uses 2 random numbers. + gen_copy.Skip(start_batch * 2 * 3 * (samples_per_batch + 3) / 4); + + // Vectorized intermediate calculations for uniform rejection sampling. + // We always generate at most 4 samples. + Eigen::array z; + Eigen::array g; + + for (int64 b = start_batch; b < limit_batch; ++b) { + // We are passed a flat array for each of the parameter tensors. + // The input is either a scalar broadcasted to all batches or a vector + // with length num_batches, but the scalar becomes an array of length 1. + T count = counts((counts.dimension(0) == 1) ? 0 : b); + T prob = probs((probs.dimension(0) == 1) ? 0 : b); + + // The last batch can be short, if we adjusted num_batches and + // samples_per_batch. + const int64 limit_sample = + std::min((b + 1) * samples_per_batch, num_elements); + int64 sample = b * samples_per_batch; + + // Calculate normalized samples, then convert them. + // Determine the method to use. 
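+        // Selection used below, illustrated with the settings from the
+        // benchmarks in random_binomial_op_test.cc:
+        //   count=10,  prob=0.3 -> count * prob = 3  < 10  -> binomial_inversion
+        //   count=100, prob=0.3 -> count * prob = 30 >= 10 -> btrs
+        // For prob > 0.5 the same rule is applied to q = 1 - prob, and the
+        // result is returned as count - sample.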
+ double dcount = static_cast(count); + if (prob <= T(0.5)) { + double dp = static_cast(prob); + if (count * prob >= T(10)) { + while (sample < limit_sample) { + output(sample) = static_cast(btrs(dcount, dp, &gen_copy)); + sample++; + } + } else { + while (sample < limit_sample) { + output(sample) = + static_cast(binomial_inversion(dcount, dp, &gen_copy)); + sample++; + } + } + } else { + T q = T(1) - prob; + double dcount = static_cast(count); + double dq = static_cast(q); + if (count * q >= T(10)) { + while (sample < limit_sample) { + output(sample) = + static_cast(dcount - btrs(dcount, dq, &gen_copy)); + sample++; + } + } else { + while (sample < limit_sample) { + output(sample) = static_cast( + dcount - binomial_inversion(dcount, dq, &gen_copy)); + sample++; + } + } + } + } + }; + + const int64 batch_init_cost = + // normMin, normMax + (Eigen::TensorOpCost::AddCost() + + Eigen::TensorOpCost::MulCost()) * + 2 + // sqrtFactor + + Eigen::TensorOpCost::AddCost() + + Eigen::TensorOpCost::MulCost() + + Eigen::internal::functor_traits< + Eigen::internal::scalar_sqrt_op>::Cost + // cutoff + + Eigen::TensorOpCost::MulCost() * 4 + + Eigen::internal::functor_traits>::Cost + // diff + + Eigen::TensorOpCost::AddCost(); + // This will depend on count * p (or count * q). + // For n * p < 10, on average, O(n * p) calls to uniform are + // needed, with that + // many multiplies. ~10 uniform calls on average with ~200 cost op calls. + // + // Very roughly, for rate >= 10, the four calls to log + // occur for ~72 percent of samples. + // 4 x 100 (64-bit cycles per log) * 0.72 = ~288 + // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each: + // 40 * .72 = ~25. + // + // Finally, there are several other ops that are done every loop along with + // 2 uniform generations along with 5 other ops at 3-6 cycles each. + // ~15 / .89 = ~16 + // + // In total this should be ~529 + 2 * Uniform::kElementCost. + // We assume that half the tensor has rate < 10, so on average 6 + // uniform's + // will be needed. We will upper bound the other op cost by the one for + // rate > 10. + static const int kElementCost = 529 + 6 * Uniform::kElementCost + + 6 * random::PhiloxRandom::kElementCost; + // Assume we use uniform sampling, and accept the 2nd sample on average. + const int64 batch_cost = batch_init_cost + kElementCost * samples_per_batch; + Shard(worker_threads.num_threads, worker_threads.workers, num_batches, + batch_cost, DoWork); + } +}; + +} // namespace functor + +namespace { + +// Samples from a binomial distribution, using the given parameters. +template +class RandomBinomialOp : public OpKernel { + // Reshape batches so each batch is this size if possible. 
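+  // (Only done when every batch shares the same count/prob; Compute() then
+  //  re-chunks the flat output so the sharded loop gets enough batches of
+  //  reasonable size rather than a few huge or many tiny ones.)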
+ static const int32 kDesiredBatchSize = 100; + + public: + explicit RandomBinomialOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& alg_tensor = ctx->input(1); + const Tensor& shape_tensor = ctx->input(2); + const Tensor& counts_tensor = ctx->input(3); + const Tensor& probs_tensor = ctx->input(4); + + OP_REQUIRES(ctx, alg_tensor.dims() == 0, + errors::InvalidArgument("algorithm must be of shape [], not ", + alg_tensor.shape().DebugString())); + Algorithm alg = alg_tensor.flat()(0); + + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(shape_tensor.shape()), + errors::InvalidArgument("Input shape should be a vector, got shape: ", + shape_tensor.shape().DebugString())); + int32 num_batches = shape_tensor.flat()(0); + + int32 samples_per_batch = 1; + const int32 num_dims = shape_tensor.dim_size(0); + for (int32 i = 1; i < num_dims; i++) { + samples_per_batch *= shape_tensor.flat()(i); + } + const int32 num_elements = num_batches * samples_per_batch; + + // Allocate the output before fudging num_batches and samples_per_batch. + auto shape_vec = shape_tensor.flat(); + TensorShape tensor_shape; + OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( + shape_vec.data(), shape_vec.size(), &tensor_shape)); + Tensor* samples_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor)); + + // Parameters must be 0-d or 1-d. + OP_REQUIRES(ctx, counts_tensor.dims() <= 1, + errors::InvalidArgument( + "Input counts should be a scalar or vector, got shape: ", + counts_tensor.shape().DebugString())); + OP_REQUIRES(ctx, probs_tensor.dims() <= 1, + errors::InvalidArgument( + "Input probs should be a scalar or vector, got shape: ", + probs_tensor.shape().DebugString())); + + if ((counts_tensor.dims() == 0 || counts_tensor.dim_size(0) == 1) && + (probs_tensor.dims() == 0 || probs_tensor.dim_size(0) == 1)) { + // All batches have the same parameters, so we can update the batch size + // to a reasonable value to improve parallelism (ensure enough batches, + // and no very small batches which have high overhead). + int32 size = num_batches * samples_per_batch; + int32 adjusted_samples = kDesiredBatchSize; + // Ensure adjusted_batches * adjusted_samples >= size. + int32 adjusted_batches = Eigen::divup(size, adjusted_samples); + num_batches = adjusted_batches; + samples_per_batch = adjusted_samples; + } else { + // Parameters must be broadcastable to the shape [num_batches]. 
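+      // For example, with shape = [4, 100] (so num_batches = 4), counts and
+      // probs may each be a scalar, a length-1 vector, or a length-4 vector;
+      // any other shape is rejected by the checks below.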
+ OP_REQUIRES( + ctx, + TensorShapeUtils::IsScalar(counts_tensor.shape()) || + counts_tensor.dim_size(0) == 1 || + counts_tensor.dim_size(0) == num_batches, + errors::InvalidArgument( + "Input counts should have length 1 or shape[0], got shape: ", + counts_tensor.shape().DebugString())); + OP_REQUIRES( + ctx, + TensorShapeUtils::IsScalar(probs_tensor.shape()) || + probs_tensor.dim_size(0) == 1 || + probs_tensor.dim_size(0) == num_batches, + errors::InvalidArgument( + "Input probs should have length 1 or shape[0], got shape: ", + probs_tensor.shape().DebugString())); + } + Var* var = nullptr; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var)); + + ScopedUnlockUnrefVar var_guard(var); + Tensor* var_tensor = var->tensor(); + OP_REQUIRES( + ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE, + errors::InvalidArgument("dtype of RNG state variable must be ", + DataTypeString(STATE_ELEMENT_DTYPE), ", not ", + DataTypeString(var_tensor->dtype()))); + OP_REQUIRES(ctx, var_tensor->dims() == 1, + errors::InvalidArgument( + "RNG state must have one and only one dimension, not ", + var_tensor->dims())); + auto var_tensor_flat = var_tensor->flat(); + OP_REQUIRES(ctx, alg == RNG_ALG_PHILOX, + errors::InvalidArgument("Unsupported algorithm id: ", alg)); + static_assert(std::is_same::value, + "StateElementType must be int64"); + static_assert(std::is_same::value, + "PhiloxRandom::ResultElementType must be uint32"); + OP_REQUIRES(ctx, var_tensor_flat.size() >= PHILOX_MIN_STATE_SIZE, + errors::InvalidArgument( + "For Philox algorithm, the size of state must be at least ", + PHILOX_MIN_STATE_SIZE, "; got ", var_tensor_flat.size())); + + // Each worker has the fudge factor for samples_per_batch, so use it here. + OP_REQUIRES_OK(ctx, PrepareToUpdateVariable( + ctx, var_tensor, var->copy_on_read_mode.load())); + auto var_data = var_tensor_flat.data(); + auto philox = GetPhiloxRandomFromMem(var_data); + UpdateMemWithPhiloxRandom( + philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data); + var_guard.Release(); + + auto binomial_functor = functor::RandomBinomialFunctor(); + binomial_functor(ctx, ctx->eigen_device(), num_batches, + samples_per_batch, num_elements, counts_tensor.flat(), + probs_tensor.flat(), philox, samples_tensor->flat()); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomBinomialOp); +}; + +} // namespace + +#define REGISTER(RTYPE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatefulRandomBinomial") \ + .Device(DEVICE_CPU) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .HostMemory("counts") \ + .HostMemory("probs") \ + .TypeConstraint("dtype") \ + .TypeConstraint("T"), \ + RandomBinomialOp) + +#define REGISTER_ALL(RTYPE) \ + REGISTER(RTYPE, Eigen::half); \ + REGISTER(RTYPE, float); \ + REGISTER(RTYPE, double); + +REGISTER_ALL(Eigen::half); +REGISTER_ALL(float); +REGISTER_ALL(double); +REGISTER_ALL(int32); +REGISTER_ALL(int64); + +#undef REGISTER +#undef REGISTER_ALL + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/random_binomial_op.h b/tensorflow/core/kernels/random_binomial_op.h new file mode 100644 index 00000000000..05c489da83a --- /dev/null +++ b/tensorflow/core/kernels/random_binomial_op.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace functor { + +// Sample a binomial random variable, with probs and counts for each batch. +// Uses binomial inversion and a transformed rejection sampling method as +// described in +// https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf. +// Two different algorithms are employed, depending on the size of +// counts * probs (or counts * (1 - probs) if probs > 0.5. +// If counts * probs < 10, we simply sum up Geometric random variables until +// they exceed count, and the number we used is binomially distributed. +// In expectation, this will take O(counts * probs) time, and requiring in +// expectation the same number of random variates. +// This can be much cheaper than summing bernoulli random variates, as we +// will always need O(counts) bernoulli random variates (so this requires fewer +// uniform r.v.s as well as can be faster). +// +// If counts * probs > 10, we use a transformed-rejection algorithm based on +// pairs of uniform random variates due to Hormann. +// https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf +// This algorithm has higher acceptance rates for counts * probs large, as the +// proposal distribution becomes quite tight, requiring approximately two +// uniform random variates as counts * probs becomes large. +template +struct RandomBinomialFunctor { + void operator()(OpKernelContext* ctx, const Device& d, int64 num_batches, + int64 samples_per_batch, int64 num_elements, + typename TTypes::ConstFlat counts, + typename TTypes::ConstFlat probs, + const random::PhiloxRandom& gen, + typename TTypes::Flat output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ diff --git a/tensorflow/core/kernels/random_binomial_op_test.cc b/tensorflow/core/kernels/random_binomial_op_test.cc new file mode 100644 index 00000000000..9f8f47ef853 --- /dev/null +++ b/tensorflow/core/kernels/random_binomial_op_test.cc @@ -0,0 +1,107 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include
+#include
+#include
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+static Graph* RandomBinomialGraph(double count, double prob, int num_batches,
+                                  int samples_per_batch) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor shape_t(DT_INT32, TensorShape({2}));
+  shape_t.flat<int32>().setValues({num_batches, samples_per_batch});
+
+  Tensor counts_t(DT_FLOAT, TensorShape({num_batches}));
+  counts_t.flat<float>().setConstant(count);
+  Tensor probs_t(DT_FLOAT, TensorShape({num_batches}));
+  probs_t.flat<float>().setConstant(prob);
+
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("randombinomial"), "RandomBinomial")
+                  .Input(test::graph::Constant(g, shape_t))
+                  .Input(test::graph::Constant(g, counts_t))
+                  .Input(test::graph::Constant(g, probs_t))
+                  .Attr("dtype", DT_FLOAT)
+                  .Finalize(g, &ret));
+  return g;
+}
+
+static Graph* RandomBinomialInv(int num_batches, int samples_per_batch) {
+  // Because counts * probs < 10, we are guaranteed to use inversion.
+  return RandomBinomialGraph(10., 0.3, num_batches, samples_per_batch);
+}
+
+static Graph* RandomBinomialRej(int num_batches, int samples_per_batch) {
+  // Because counts * probs > 10, we are guaranteed to use rejection.
+  return RandomBinomialGraph(100., 0.3, num_batches, samples_per_batch);
+}
+
+static Graph* RandomBinomialInvComplement(int num_batches,
+                                          int samples_per_batch) {
+  // Because counts * (1 - probs) < 10, we are guaranteed to use inversion.
+  return RandomBinomialGraph(10., 0.8, num_batches, samples_per_batch);
+}
+
+static Graph* RandomBinomialRejComplement(int num_batches,
+                                          int samples_per_batch) {
+  // Because counts * (1 - probs) > 10, we are guaranteed to use rejection.
+ return RandomBinomialGraph(100., 0.2, num_batches, samples_per_batch); +} + +#define BM_RandomBinomialInv(DEVICE, B, S) \ + static void BM_RandomBinomialInv_##DEVICE##_##B##_##S(int iters) { \ + test::Benchmark(#DEVICE, RandomBinomialInv(B, S)).Run(iters); \ + testing::ItemsProcessed(static_cast(B) * S * iters); \ + } \ + BENCHMARK(BM_RandomBinomialInv_##DEVICE##_##B##_##S); + +#define BM_RandomBinomialRej(DEVICE, B, S) \ + static void BM_RandomBinomialRej_##DEVICE##_##B##_##S(int iters) { \ + test::Benchmark(#DEVICE, RandomBinomialRej(B, S)).Run(iters); \ + testing::ItemsProcessed(static_cast(B) * S * iters); \ + } \ + BENCHMARK(BM_RandomBinomialRej_##DEVICE##_##B##_##S); + +#define BM_RandomBinomialInvComplement(DEVICE, B, S) \ + static void BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S(int iters) { \ + test::Benchmark(#DEVICE, RandomBinomialInvComplement(B, S)).Run(iters); \ + testing::ItemsProcessed(static_cast(B) * S * iters); \ + } \ + BENCHMARK(BM_RandomBinomialInvComplement_##DEVICE##_##B##_##S); + +#define BM_RandomBinomialRejComplement(DEVICE, B, S) \ + static void BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S(int iters) { \ + test::Benchmark(#DEVICE, RandomBinomialRejComplement(B, S)).Run(iters); \ + testing::ItemsProcessed(static_cast(B) * S * iters); \ + } \ + BENCHMARK(BM_RandomBinomialRejComplement_##DEVICE##_##B##_##S); + +BM_RandomBinomialInv(cpu, 1000, 1000); +BM_RandomBinomialRej(cpu, 1000, 1000); +BM_RandomBinomialInvComplement(cpu, 1000, 1000); +BM_RandomBinomialRejComplement(cpu, 1000, 1000); +BM_RandomBinomialInv(gpu, 1000, 1000); +BM_RandomBinomialRej(gpu, 1000, 1000); +BM_RandomBinomialInvComplement(gpu, 1000, 1000); +BM_RandomBinomialRejComplement(gpu, 1000, 1000); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 3810d817ca9..e39e5f2eb3b 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -17,8 +17,6 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/random_op.h" - #include #include #include @@ -27,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/random_op_cpu.h" #include "tensorflow/core/lib/hash/crc32c.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/simple_philox.h" @@ -52,131 +51,6 @@ typedef Eigen::GpuDevice GPUDevice; typedef Eigen::SyclDevice SYCLDevice; #endif // TENSORFLOW_USE_SYCL -namespace functor { -using random::PhiloxRandom; -using random::SingleSampleAdapter; - -// The default implementation of the functor, which should never be invoked -// But we still need to provide implementation for now for the linker to work, -// since we do not support all the distributions yet. -template -struct FillPhiloxRandom { - typedef typename Distribution::ResultElementType T; - void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen, - T* data, int64 size, Distribution dist) { - LOG(FATAL) << "Default FillPhiloxRandom should not be executed."; - } -}; - -// A class to fill a specified range of random groups -template -struct FillPhiloxRandomTask; - -// Specialization for distribution that takes a fixed number of samples for -// each output. 
-template -struct FillPhiloxRandomTask { - typedef typename Distribution::ResultElementType T; - static void Run(random::PhiloxRandom gen, T* data, int64 size, - int64 start_group, int64 limit_group, Distribution dist) { - const int kGroupSize = Distribution::kResultElementCount; - - gen.Skip(start_group); - int64 offset = start_group * kGroupSize; - - // First fill all the full-size groups - int64 limit_group_full = std::min(limit_group, size / kGroupSize); - for (int64 index = start_group; index < limit_group_full; ++index) { - auto samples = dist(&gen); - std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); - offset += kGroupSize; - } - - // If there are any remaining elements that need to be filled, process them - if (limit_group_full < limit_group) { - int64 remaining_size = size - limit_group_full * kGroupSize; - auto samples = dist(&gen); - std::copy(&samples[0], &samples[0] + remaining_size, data + offset); - } - } -}; - -// Specialization for distribution that takes a variable number of samples for -// each output. This will be slower due to the generality. -template -struct FillPhiloxRandomTask { - typedef typename Distribution::ResultElementType T; - static const int64 kReservedSamplesPerOutput = 256; - - static void Run(random::PhiloxRandom base_gen, T* data, int64 size, - int64 start_group, int64 limit_group, Distribution dist) { - const int kGroupSize = Distribution::kResultElementCount; - - static const int kGeneratorSkipPerOutputGroup = - kGroupSize * kReservedSamplesPerOutput / - PhiloxRandom::kResultElementCount; - - int64 offset = start_group * kGroupSize; - - // First fill all the full-size groups - int64 limit_group_full = std::min(limit_group, size / kGroupSize); - int64 group_index; - for (group_index = start_group; group_index < limit_group_full; - ++group_index) { - // Reset the generator to the beginning of the output group region - // This is necessary if we want the results to be independent of order - // of work - PhiloxRandom gen = base_gen; - gen.Skip(group_index * kGeneratorSkipPerOutputGroup); - SingleSampleAdapter single_samples(&gen); - - auto samples = dist(&single_samples); - std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); - offset += kGroupSize; - } - - // If there are any remaining elements that need to be filled, process them - if (limit_group_full < limit_group) { - PhiloxRandom gen = base_gen; - gen.Skip(group_index * kGeneratorSkipPerOutputGroup); - SingleSampleAdapter single_samples(&gen); - - int64 remaining_size = size - limit_group_full * kGroupSize; - auto samples = dist(&single_samples); - std::copy(&samples[0], &samples[0] + remaining_size, data + offset); - } - } -}; - -// Partial specialization for CPU to fill the entire region with randoms -// It splits the work into several tasks and run them in parallel -template -void FillPhiloxRandom::operator()( - OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen, - typename Distribution::ResultElementType* data, int64 size, - Distribution dist) { - const int kGroupSize = Distribution::kResultElementCount; - - auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); - - int64 total_group_count = (size + kGroupSize - 1) / kGroupSize; - - const int kGroupCost = - random::PhiloxRandom::kResultElementCount * - (random::PhiloxRandom::kElementCost + Distribution::kElementCost); - Shard(worker_threads.num_threads, worker_threads.workers, total_group_count, - kGroupCost, - [&gen, data, size, dist](int64 start_group, int64 limit_group) { - 
FillPhiloxRandomTask< - Distribution, - Distribution::kVariableSamplesPerOutput>::Run(gen, data, size, - start_group, - limit_group, dist); - }); -} - -} // namespace functor - namespace { static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape, @@ -299,7 +173,7 @@ class RandomGammaOp : public OpKernel { Tensor* samples_t = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t)); - if (num_samples == 0) return; + if (samples_shape.num_elements() == 0) return; using random::PhiloxRandom; @@ -354,7 +228,7 @@ class RandomGammaOp : public OpKernel { const double alpha = static_cast(alpha_flat[alpha_idx]); DISABLE_FLOAT_EQUALITY_WARNING - if (alpha == double(1.0)) { + if (alpha == static_cast(1.0)) { ENABLE_FLOAT_EQUALITY_WARNING // Sample from an exponential distribution. for (int64 sample_idx = output_idx % num_samples; @@ -364,7 +238,7 @@ class RandomGammaOp : public OpKernel { // (including eventually on GPU), we skip on a per-sample basis. PhiloxRandom gen = rng; gen.Skip(kReservedSamplesPerOutput * output_idx); - short uniform_remaining = 0; + int16 uniform_remaining = 0; UNIFORM(u); const double res = -log(1.0 - u); samples_alpha_offset[sample_idx * num_alphas] = static_cast(res); @@ -392,8 +266,8 @@ class RandomGammaOp : public OpKernel { // (including eventually on GPU), we skip on a per-sample basis. PhiloxRandom gen = rng; gen.Skip(kReservedSamplesPerOutput * output_idx); - short norm_remaining = 0; - short uniform_remaining = 0; + int16 norm_remaining = 0; + int16 uniform_remaining = 0; // Keep trying until we don't reject a sample. In practice, we will // only reject ~5% at worst, for low alpha near 1. @@ -565,145 +439,6 @@ TF_CALL_int64(REGISTER_INT); #ifdef TENSORFLOW_USE_SYCL -namespace functor { - -using namespace cl; - -template -struct FillPhiloxRandomKernel; - -template -struct FillPhiloxRandomKernel { - typedef typename Distribution::ResultElementType T; - using write_accessor = sycl::accessor; - - FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, - Distribution& dist) - : data_(data), gen_(gen), dist_(dist) {} - - void operator()(sycl::nd_item<1> item) { - const size_t kGroupSize = Distribution::kResultElementCount; - - const size_t item_id = item.get_global(0); - const size_t total_item_count = item.get_global_range(); - size_t offset = item_id * kGroupSize; - gen_.Skip(item_id); - - const size_t size = data_.get_size() / sizeof(T); - T* data = ConvertToActualTypeSycl(T, data_); - - while (offset + kGroupSize <= size) { - const typename Distribution::ResultType samples = dist_(&gen_); - for (size_t i = 0; i < kGroupSize; ++i) { - data[offset + i] = samples[i]; - } - - offset += (total_item_count - 1) * kGroupSize; - gen_.Skip(total_item_count - 1); - } - - const typename Distribution::ResultType samples = dist_(&gen_); - for (size_t i = 0; i < kGroupSize; ++i) { - if (offset >= size) { - return; - } - data[offset] = samples[i]; - ++offset; - } - } - - private: - write_accessor data_; - random::PhiloxRandom gen_; - Distribution dist_; -}; - -template -struct FillPhiloxRandomKernel { - typedef typename Distribution::ResultElementType T; - using write_accessor = sycl::accessor; - - FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, - Distribution& dist) - : data_(data), gen_(gen), dist_(dist) {} - - void operator()(sycl::nd_item<1> item) { - using random::PhiloxRandom; - using random::SingleSampleAdapter; - - const size_t kReservedSamplesPerOutput = 256; - const size_t kGroupSize = 
Distribution::kResultElementCount; - const size_t kGeneratorSkipPerOutputGroup = - kGroupSize * kReservedSamplesPerOutput / - PhiloxRandom::kResultElementCount; - - const size_t item_id = item.get_global(0); - const size_t total_item_count = item.get_global_range(); - size_t group_index = item_id; - size_t offset = group_index * kGroupSize; - - T* data = ConvertToActualTypeSycl(T, data_); - const size_t size = data_.get_size() / sizeof(T); - - while (offset < size) { - // Since each output takes a variable number of samples, we need to - // realign the generator to the beginning for the current output group - PhiloxRandom gen = gen_; - gen.Skip(group_index * kGeneratorSkipPerOutputGroup); - SingleSampleAdapter single_samples(&gen); - - const typename Distribution::ResultType samples = dist_(&single_samples); - - for (size_t i = 0; i < kGroupSize; ++i) { - if (offset >= size) { - return; - } - data[offset] = samples[i]; - ++offset; - } - - offset += (total_item_count - 1) * kGroupSize; - group_index += total_item_count; - } - } - - private: - write_accessor data_; - random::PhiloxRandom gen_; - Distribution dist_; -}; - -template -class FillRandomKernel; -// Partial specialization for SYCL to fill the entire region with randoms -// It splits the work into several tasks and run them in parallel -template -void FillPhiloxRandom::operator()( - OpKernelContext* context, const SYCLDevice& device, - random::PhiloxRandom gen, typename Distribution::ResultElementType* data, - int64 size, Distribution dist) { - const size_t group_size = device.maxSyclThreadsPerBlock(); - const size_t group_count = (size + group_size - 1) / group_size; - - auto buffer = device.get_sycl_buffer(data); - - device.sycl_queue().submit([&](sycl::handler& cgh) { - auto access = buffer.template get_access(cgh); - - FillPhiloxRandomKernel - task(access, gen, dist); - cgh.parallel_for>( - sycl::nd_range<1>(sycl::range<1>(group_count * group_size), - sycl::range<1>(group_size)), - task); - }); -} - -} // namespace functor - #define REGISTER(TYPE) \ template struct functor::FillPhiloxRandom< \ SYCLDevice, random::UniformDistribution>; \ diff --git a/tensorflow/core/kernels/random_op_cpu.h b/tensorflow/core/kernels/random_op_cpu.h new file mode 100644 index 00000000000..cf8594e4752 --- /dev/null +++ b/tensorflow/core/kernels/random_op_cpu.h @@ -0,0 +1,328 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ +#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/random_op.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/util/work_sharder.h" + +#if EIGEN_COMP_GNUC && __cplusplus > 199711L +#define DISABLE_FLOAT_EQUALITY_WARNING \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") +#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") +#else +#define DISABLE_FLOAT_EQUALITY_WARNING +#define ENABLE_FLOAT_EQUALITY_WARNING +#endif + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + +namespace functor { +using random::PhiloxRandom; +using random::SingleSampleAdapter; + +// The default implementation of the functor, which should never be invoked +// But we still need to provide implementation for now for the linker to work, +// since we do not support all the distributions yet. +template +struct FillPhiloxRandom { + typedef typename Distribution::ResultElementType T; + void operator()(OpKernelContext* ctx, const Device&, random::PhiloxRandom gen, + T* data, int64 size, Distribution dist) { + OP_REQUIRES( + ctx, false, + errors::Internal( + "Default `FillPhiloxRandom` implementation should not be executed. " + "The cause of this error is probabily that `FillPhiloxRandom` does " + "not support this device or random distribution yet.")); + } +}; + +// A class to fill a specified range of random groups +template +struct FillPhiloxRandomTask; + +// Specialization for distribution that takes a fixed number of samples for +// each output. +template +struct FillPhiloxRandomTask { + typedef typename Distribution::ResultElementType T; + static void Run(random::PhiloxRandom gen, T* data, int64 size, + int64 start_group, int64 limit_group, Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + gen.Skip(start_group); + int64 offset = start_group * kGroupSize; + + // First fill all the full-size groups + int64 limit_group_full = std::min(limit_group, size / kGroupSize); + for (int64 index = start_group; index < limit_group_full; ++index) { + auto samples = dist(&gen); + std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); + offset += kGroupSize; + } + + // If there are any remaining elements that need to be filled, process them + if (limit_group_full < limit_group) { + int64 remaining_size = size - limit_group_full * kGroupSize; + auto samples = dist(&gen); + std::copy(&samples[0], &samples[0] + remaining_size, data + offset); + } + } +}; + +// Specialization for distribution that takes a variable number of samples for +// each output. This will be slower due to the generality. 
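+// (Slower because every output group restarts from a private copy of the
+//  generator, skipping ahead by kGeneratorSkipPerOutputGroup and drawing
+//  through a SingleSampleAdapter, so that results do not depend on how the
+//  work is sharded across threads.)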
+template +struct FillPhiloxRandomTask { + typedef typename Distribution::ResultElementType T; + static const int64 kReservedSamplesPerOutput = 256; + + static void Run(random::PhiloxRandom base_gen, T* data, int64 size, + int64 start_group, int64 limit_group, Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + static const int kGeneratorSkipPerOutputGroup = + kGroupSize * kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + int64 offset = start_group * kGroupSize; + + // First fill all the full-size groups + int64 limit_group_full = std::min(limit_group, size / kGroupSize); + int64 group_index; + for (group_index = start_group; group_index < limit_group_full; + ++group_index) { + // Reset the generator to the beginning of the output group region + // This is necessary if we want the results to be independent of order + // of work + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + auto samples = dist(&single_samples); + std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); + offset += kGroupSize; + } + + // If there are any remaining elements that need to be filled, process them + if (limit_group_full < limit_group) { + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + int64 remaining_size = size - limit_group_full * kGroupSize; + auto samples = dist(&single_samples); + std::copy(&samples[0], &samples[0] + remaining_size, data + offset); + } + } +}; + +// Partial specialization for CPU to fill the entire region with randoms +// It splits the work into several tasks and run them in parallel +template +void FillPhiloxRandom::operator()( + OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64 size, + Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + int64 total_group_count = (size + kGroupSize - 1) / kGroupSize; + + const int kGroupCost = + random::PhiloxRandom::kResultElementCount * + (random::PhiloxRandom::kElementCost + Distribution::kElementCost); + Shard(worker_threads.num_threads, worker_threads.workers, total_group_count, + kGroupCost, + [&gen, data, size, dist](int64 start_group, int64 limit_group) { + FillPhiloxRandomTask< + Distribution, + Distribution::kVariableSamplesPerOutput>::Run(gen, data, size, + start_group, + limit_group, dist); + }); +} + +} // namespace functor + +#ifdef TENSORFLOW_USE_SYCL + +namespace functor { + +template +struct FillPhiloxRandomKernel; + +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + using write_accessor = sycl::accessor; + + FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, + Distribution& dist) + : data_(data), gen_(gen), dist_(dist) {} + + void operator()(sycl::nd_item<1> item) { + const size_t kGroupSize = Distribution::kResultElementCount; + + const size_t item_id = item.get_global(0); + const size_t total_item_count = item.get_global_range(); + size_t offset = item_id * kGroupSize; + gen_.Skip(item_id); + + const size_t size = data_.get_size() / sizeof(T); + T* data = ConvertToActualTypeSycl(T, data_); + + while (offset + kGroupSize <= size) { + const typename Distribution::ResultType samples = dist_(&gen_); + for (size_t i = 0; i < kGroupSize; ++i) { + data[offset 
+ i] = samples[i]; + } + + offset += (total_item_count - 1) * kGroupSize; + gen_.Skip(total_item_count - 1); + } + + const typename Distribution::ResultType samples = dist_(&gen_); + for (size_t i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + } + + private: + write_accessor data_; + random::PhiloxRandom gen_; + Distribution dist_; +}; + +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + using write_accessor = sycl::accessor; + + FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen, + Distribution& dist) + : data_(data), gen_(gen), dist_(dist) {} + + void operator()(sycl::nd_item<1> item) { + using random::PhiloxRandom; + using random::SingleSampleAdapter; + + const size_t kReservedSamplesPerOutput = 256; + const size_t kGroupSize = Distribution::kResultElementCount; + const size_t kGeneratorSkipPerOutputGroup = + kGroupSize * kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + const size_t item_id = item.get_global(0); + const size_t total_item_count = item.get_global_range(); + size_t group_index = item_id; + size_t offset = group_index * kGroupSize; + + T* data = ConvertToActualTypeSycl(T, data_); + const size_t size = data_.get_size() / sizeof(T); + + while (offset < size) { + // Since each output takes a variable number of samples, we need to + // realign the generator to the beginning for the current output group + PhiloxRandom gen = gen_; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + const typename Distribution::ResultType samples = dist_(&single_samples); + + for (size_t i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + + offset += (total_item_count - 1) * kGroupSize; + group_index += total_item_count; + } + } + + private: + write_accessor data_; + random::PhiloxRandom gen_; + Distribution dist_; +}; + +template +class FillRandomKernel; +// Partial specialization for SYCL to fill the entire region with randoms +// It splits the work into several tasks and run them in parallel +template +void FillPhiloxRandom::operator()( + OpKernelContext* context, const SYCLDevice& device, + random::PhiloxRandom gen, typename Distribution::ResultElementType* data, + int64 size, Distribution dist) { + const size_t group_size = device.maxSyclThreadsPerBlock(); + const size_t group_count = (size + group_size - 1) / group_size; + + auto buffer = device.get_sycl_buffer(data); + + device.sycl_queue().submit([&](sycl::handler& cgh) { + auto access = buffer.template get_access(cgh); + + FillPhiloxRandomKernel + task(access, gen, dist); + cgh.parallel_for>( + sycl::nd_range<1>(sycl::range<1>(group_count * group_size), + sycl::range<1>(group_size)), + task); + }); +} + +} // namespace functor + +#endif // TENSORFLOW_USE_SYCL + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ diff --git a/tensorflow/core/kernels/random_op_gpu.cu.cc b/tensorflow/core/kernels/random_op_gpu.cu.cc index 55278d0480e..9c3db8742ba 100644 --- a/tensorflow/core/kernels/random_op_gpu.cu.cc +++ b/tensorflow/core/kernels/random_op_gpu.cu.cc @@ -17,17 +17,15 @@ limitations under the License. 
#define EIGEN_USE_GPU -#include "tensorflow/core/kernels/random_op.h" -#include "tensorflow/core/kernels/random_op_gpu.h" - #include #include +#include "tensorflow/core/kernels/random_op_gpu.h" + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { @@ -37,33 +35,6 @@ namespace functor { typedef Eigen::GpuDevice GPUDevice; -// A simple launch pad to call the correct function templates to fill the data -template -__global__ void __launch_bounds__(1024) - FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen, - typename Distribution::ResultElementType* data, - int64 size, Distribution dist) { - FillPhiloxRandomKernel() - .Run(base_gen, data, size, dist); -} - -// Partial specialization for GPU -template -void FillPhiloxRandom::operator()( - OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen, - typename Distribution::ResultElementType* data, int64 size, - Distribution dist) { - const int32 block_size = d.maxGpuThreadsPerBlock(); - const int32 num_blocks = - (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) / - block_size; - - TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch, - num_blocks, block_size, 0, d.stream(), gen, data, - size, dist)); -} - // Explicit instantiation of the GPU distributions functors // clang-format off // NVCC cannot handle ">>" properly diff --git a/tensorflow/core/kernels/random_op_gpu.h b/tensorflow/core/kernels/random_op_gpu.h index e32c755d782..bb7a0723800 100644 --- a/tensorflow/core/kernels/random_op_gpu.h +++ b/tensorflow/core/kernels/random_op_gpu.h @@ -18,8 +18,10 @@ limitations under the License. #if defined(__CUDACC__) +#include "tensorflow/core/kernels/random_op.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -31,15 +33,15 @@ struct FillPhiloxRandomKernel; template struct FillPhiloxRandomKernel { typedef typename Distribution::ResultElementType T; - PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size, - Distribution dist); + PHILOX_DEVICE_INLINE void Run(random::PhiloxRandom gen, T* data, int64 size, + Distribution dist); }; template struct FillPhiloxRandomKernel { typedef typename Distribution::ResultElementType T; - PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data, - int64 size, Distribution dist); + PHILOX_DEVICE_INLINE void Run(const random::PhiloxRandom& base_gen, T* data, + int64 size, Distribution dist); }; template @@ -128,7 +130,7 @@ class SampleCopier { // A cuda kernel to fill the data with random numbers from the specified // distribution. Each output takes a fixed number of samples. template -PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel::Run( +PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel::Run( random::PhiloxRandom gen, T* data, int64 size, Distribution dist) { const int kGroupSize = Distribution::kResultElementCount; @@ -159,7 +161,7 @@ PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel::Run( // A cuda kernel to fill the data with random numbers from the specified // distribution. Each output takes a variable number of samples. 
template -PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel::Run( +PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel::Run( const random::PhiloxRandom& base_gen, T* data, int64 size, Distribution dist) { using random::PhiloxRandom; @@ -198,6 +200,33 @@ PHILOX_DEVICE_FUNC void FillPhiloxRandomKernel::Run( } } +// A simple launch pad to call the correct function templates to fill the data +template +__global__ void __launch_bounds__(1024) + FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen, + typename Distribution::ResultElementType* data, + int64 size, Distribution dist) { + FillPhiloxRandomKernel() + .Run(base_gen, data, size, dist); +} + +// Partial specialization for GPU +template +void FillPhiloxRandom::operator()( + OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64 size, + Distribution dist) { + const int32 block_size = d.maxGpuThreadsPerBlock(); + const int32 num_blocks = + (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) / + block_size; + + TF_CHECK_OK(CudaLaunchKernel(FillPhiloxRandomKernelLaunch, + num_blocks, block_size, 0, d.stream(), gen, data, + size, dist)); +} + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc index 9522b1ac44b..b38bf1c0f6b 100644 --- a/tensorflow/core/kernels/range_sampler.cc +++ b/tensorflow/core/kernels/range_sampler.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/range_sampler.h" +#include #include #include @@ -69,7 +70,7 @@ static float ExpectedCountHelper(float p, int batch_size, int num_tries) { return p * batch_size; } // numerically stable version of (1 - (1-p)^num_tries) - return -expm1(num_tries * log1p(-p)); + return -std::expm1(num_tries * std::log1p(-p)); } } // namespace @@ -298,7 +299,7 @@ Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file, return errors::InvalidArgument("Wrong vocabulary format at line: ", line); } - w = pow(w, distortion); + w = std::pow(w, distortion); total_weight_ += w; weights_.push_back(w); } @@ -313,7 +314,7 @@ void FixedUnigramSampler::LoadFromUnigrams(const std::vector& unigrams, for (float w : unigrams) { // Skip entries that do not belong to this shard. if (word_id % num_shards_ == shard_) { - w = pow(w, distortion); + w = std::pow(w, distortion); total_weight_ += w; weights_.push_back(w); } diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index 0f08588ebac..84748742b36 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -20,30 +20,31 @@ limitations under the License. 
#define EIGEN_USE_GPU +#include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/cub/device/device_reduce.cuh" #include "third_party/cub/device/device_segmented_reduce.cuh" #include "third_party/cub/iterator/counting_input_iterator.cuh" #include "third_party/cub/iterator/transform_input_iterator.cuh" #include "third_party/cub/warp/warp_reduce.cuh" -#include "cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cuComplex.h" #include "tensorflow/core/kernels/reduction_ops.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_device_functions.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/permutation_input_iterator.h" #include "tensorflow/core/util/transform_output_iterator.h" -#include - namespace tensorflow { namespace functor { typedef Eigen::GpuDevice GPUDevice; template -struct Sqrt { +struct SqrtOfReal { __host__ __device__ T operator()(const T& a) const { - return Eigen::numext::sqrt(a); + return T(Eigen::numext::sqrt(Eigen::numext::real(a))); } }; @@ -54,28 +55,6 @@ struct Sum { } }; -// needed to work around a compiler bug in nvcc - it doesn't seem to like -// the overloaded addition op for std::complex -template <> -struct Sum> { - __host__ __device__ std::complex operator()( - const std::complex& a, const std::complex& b) const { - auto result = cuCaddf(make_cuComplex(a.real(), a.imag()), - make_cuComplex(b.real(), b.imag())); - return std::complex(result.x, result.y); - } -}; - -template <> -struct Sum> { - __host__ __device__ std::complex operator()( - const std::complex& a, const std::complex& b) const { - auto result = cuCadd(make_cuDoubleComplex(a.real(), a.imag()), - make_cuDoubleComplex(b.real(), b.imag())); - return std::complex(result.x, result.y); - } -}; - template struct Prod { __host__ __device__ T operator()(const T& a, const T& b) const { @@ -83,28 +62,6 @@ struct Prod { } }; -// needed to work around a compiler bug in nvcc - it doesn't seem to like -// the overloaded multiply op for std::complex -template <> -struct Prod> { - __host__ __device__ std::complex operator()( - const std::complex& a, const std::complex& b) const { - auto result = cuCmulf(make_cuComplex(a.real(), a.imag()), - make_cuComplex(b.real(), b.imag())); - return std::complex(result.x, result.y); - } -}; - -template <> -struct Prod> { - __host__ __device__ std::complex operator()( - const std::complex& a, const std::complex& b) const { - auto result = cuCmul(make_cuDoubleComplex(a.real(), a.imag()), - make_cuDoubleComplex(b.real(), b.imag())); - return std::complex(result.x, result.y); - } -}; - template struct Square { __host__ __device__ T operator()(const T& a) const { @@ -687,7 +644,6 @@ void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in, (T*)temp_storage.flat().data(), extent_x, extent_y, op, init)); dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1); - dim3 num_threads(128, 1, 1); TF_CHECK_OK(CudaLaunchKernel(CleanupSegments, new_grid_dim, block_dim, 0, cu_stream, (T*)temp_storage.flat().data(), out, @@ -918,8 +874,8 @@ struct ReduceFunctor> { const functor::EuclideanNormReducer& reducer) { typedef cub::TransformInputIterator, T*> inputIterType; inputIterType input_itr((T*)in.data(), Square()); - typedef TransformOutputIterator> outputIterType; - outputIterType output_itr((T*)out.data(), Sqrt()); + typedef TransformOutputIterator> outputIterType; + outputIterType 
output_itr((T*)out.data(), SqrtOfReal()); ReduceImpl, outputIterType, inputIterType, ReductionAxes>( ctx, output_itr, input_itr, in.rank(), in.dimension(0), in.rank() >= 2 ? in.dimension(1) : 1, diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h index 0a1568bdc25..164359f601a 100644 --- a/tensorflow/core/kernels/reduction_ops.h +++ b/tensorflow/core/kernels/reduction_ops.h @@ -117,8 +117,6 @@ struct Identity { FIX_MEAN_IDENTITY(Eigen::half) FIX_MEAN_IDENTITY(float) FIX_MEAN_IDENTITY(double) -FIX_MEAN_IDENTITY(complex64) -FIX_MEAN_IDENTITY(complex128) #undef FIX_MEAN_IDENTITY template diff --git a/tensorflow/core/kernels/redux_functor.h b/tensorflow/core/kernels/redux_functor.h new file mode 100644 index 00000000000..05a867ab007 --- /dev/null +++ b/tensorflow/core/kernels/redux_functor.h @@ -0,0 +1,330 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; + +namespace functor { + +// Compute reduction over outer dimensions. +// Example: +// input: [D1, D2, ... , DN] +// -> +// output: [Di, ... , DN] where i belongs to set [1,N] +template +struct ReduceOuterDimensions { + template + void operator()(const CPUDevice& device, + const Eigen::DSizes& input_dims, + const Tensor& input, Tensor* output) const { + // Compute inner and outer dim after reshaping into 2d tensor. + const int num_output_dims = output->dims(); + auto output_dims = output->template flat().dimensions(); + + Eigen::Index inner_dim = 1, outer_dim = 1; + for (int i = 0; i < num_dims - num_output_dims; ++i) + outer_dim *= input_dims[i]; + for (int i = num_dims - num_output_dims; i < num_dims; ++i) + inner_dim *= input_dims[i]; + + if (1 == outer_dim) { + // Nothing to do but passing input to output. + output->template flat() = + input.template flat().reshape(output_dims); + return; + } + + // Get device thread num. + const Eigen::Index num_threads = device.numThreads(); + + // If the inner dim parallelism is large enough + if (inner_dim > num_threads * 16) { + // Do not create more blocks than there are threads in a pool. + const Eigen::Index num_blocks = num_threads; + + // Block size along the outer dimension. + const Eigen::Index inner_block_size = Eigen::divup(inner_dim, num_blocks); + const T* input_data = input.template flat().data(); + + // Allocate temporary buffer for partial reductions. 
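// Editor's note (illustrative, not part of the patch): a rough worked example
// of this branch, assuming numThreads() == 16, inner_dim == 1024 and
// outer_dim == 50. Since 1024 > 16 * 16, the inner dimension itself is
// parallelized: num_blocks == 16 and inner_block_size == divup(1024, 16) == 64,
// so block b accumulates all 50 outer rows into its own 64-wide slice of the
// shared `buffer` (columns [64 * b, 64 * (b + 1))), and no two blocks ever
// touch the same accumulator.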
+ Eigen::Tensor buffer( + {inner_dim}); + buffer.setZero(); + AccumT* buffer_data = buffer.data(); + + using Buffer = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + + using Input = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + + const auto compute = [inner_dim, outer_dim, num_blocks, inner_block_size, + input_data, buffer_data]( + Eigen::Index start, Eigen::Index limit) -> void { + DCHECK(start >= 0 && limit <= num_blocks); + Eigen::Index inner_dim_start = start * inner_block_size; + Eigen::Index inner_dim_limit = limit * inner_block_size; + inner_dim_limit = std::min(inner_dim, inner_dim_limit); + Eigen::Index my_job_len = inner_dim_limit - inner_dim_start; + + const T* my_job_start = input_data + inner_dim_start; + Buffer buf(buffer_data + inner_dim_start, my_job_len); + + for (Eigen::Index i = 0; i < outer_dim; ++i) { + auto in = Input(my_job_start + i * inner_dim, my_job_len); + auto cast = in.template cast(); + buf = Eigen::TensorCwiseBinaryOp(buf, cast); + } + }; + + // Compute cost of reducing a single block. + const Eigen::Index compute_size = outer_dim * inner_block_size; + const Eigen::Index compute_input_bytes = compute_size * sizeof(T); + const Eigen::TensorOpCost cost( + compute_input_bytes, + 0, // We'll be mostly writing to L1, assume store cost is 0 + compute_size * Eigen::internal::functor_traits::Cost); + + device.parallelFor(num_blocks, cost, compute); + + // Write final result to the output. + output->template flat() = + buffer.template cast().reshape(output_dims); + } else { + // Compute block size along the outer dimension for efficiency. + const Eigen::Index parallel_cell_size = inner_dim; + const Eigen::Index total_workload = outer_dim * inner_dim; + const Eigen::Index max_parallelism = total_workload / parallel_cell_size; + + const Eigen::Index min_block_workload = 2000; + const Eigen::Index min_block_size = + Eigen::divup(min_block_workload, parallel_cell_size); + const Eigen::Index max_num_blocks = std::min( + max_parallelism, Eigen::divup(total_workload, min_block_size)); + + // Do not create more blocks than there are threads in a pool. + const Eigen::Index num_blocks = std::min(max_num_blocks, num_threads); + + // Block size along the outer dimension. + const Eigen::Index outer_block_size = Eigen::divup(outer_dim, num_blocks); + + const T* input_data = input.template flat().data(); + + // Allocate temporary buffer for partial reductions. + Tensor buffer(DataTypeToEnum::v(), {num_blocks, inner_dim}); + buffer.template flat().setZero(); + AccumT* buffer_data = buffer.template flat().data(); + + using Buffer = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + + using Input = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + + const auto compute = [inner_dim, num_blocks, outer_block_size, + buffer_data, input_data, outer_dim]( + Eigen::Index start, Eigen::Index limit) -> void { + DCHECK(start >= 0 && limit <= num_blocks); + Eigen::Index outer_dim_start = start * outer_block_size; + Eigen::Index outer_dim_limit = limit * outer_block_size; + outer_dim_limit = std::min(outer_dim, outer_dim_limit); + + Buffer buf(buffer_data + start * inner_dim, inner_dim); + for (Eigen::Index i = outer_dim_start; i < outer_dim_limit; ++i) { + auto in = Input(input_data + i * inner_dim, inner_dim); + auto cast = in.template cast(); + buf = Eigen::TensorCwiseBinaryOp(buf, cast); + } + }; + + // Compute cost of reducing a single block. 
+ const Eigen::Index compute_size = outer_block_size * inner_dim; + const Eigen::Index compute_input_bytes = compute_size * sizeof(T); + const Eigen::TensorOpCost cost( + compute_input_bytes, + 0, // We'll be mostly writing to L1, assume store cost is 0 + compute_size * Eigen::internal::functor_traits::Cost); + + device.parallelFor(num_blocks, cost, compute); + + // Aggregate partial results from temporary buffer into first block. + auto buf0 = Buffer(buffer_data, inner_dim); + // Just sum the buffer up, as inner dimensions is not large in this case. + for (int i = 1; i < num_blocks; ++i) { + auto buf = Buffer(buffer_data + i * inner_dim, inner_dim); + buf0 = Eigen::TensorCwiseBinaryOp(buf0, buf); + } + // Write final result to the output. + output->template flat() = buf0.template cast().reshape(output_dims); + } + } +}; + +// Compute reduction to some serial middle dimensions (like a axis). +// Example: +// input: [D1, D2, ... , DN] +// -> +// output: [Di, ... , Dj] where i & j belongs to set [1,N]. +template +struct ReduceMiddleDimensions { + template + void operator()(const CPUDevice& device, + const Eigen::DSizes& input_dims, + const Tensor& input, Tensor* output, + const int axis_begin_dim) const { + // Compute dims after reshaping into 3d tensor. + const int num_output_dims = output->dims(); + auto output_dims = output->template flat().dimensions(); + + Eigen::Index inner_dim = 1, middle_dim = 1, outer_dim = 1; + for (int i = 0; i < axis_begin_dim; ++i) outer_dim *= input_dims[i]; + for (int i = axis_begin_dim; i < axis_begin_dim + num_output_dims; ++i) + middle_dim *= input_dims[i]; + for (int i = axis_begin_dim + num_output_dims; i < num_dims; ++i) + inner_dim *= input_dims[i]; + + if ((1 == inner_dim * outer_dim)) { + // Nothing to do. + output->template flat() = + input.template flat().reshape(output_dims); + return; + } else if (1 == inner_dim) { + // Equivalent to ReduceOuterDimensions. + const ReduceOuterDimensions redux; + redux(device, input_dims, input, output); + return; + } + + // Compute block size along the outer dimension for efficiency. + const Eigen::Index parallel_cell_size = inner_dim; + const Eigen::Index max_parallelism = outer_dim * middle_dim; + const Eigen::Index total_workload = max_parallelism * inner_dim; + + const Eigen::Index min_block_workload = 2000; + const Eigen::Index min_block_size = + Eigen::divup(min_block_workload, parallel_cell_size); + const Eigen::Index max_num_blocks = + std::min(max_parallelism, Eigen::divup(total_workload, min_block_size)); + + // Do not create more blocks than there are threads in a pool. + const Eigen::Index num_threads = device.numThreads(); + const Eigen::Index num_blocks = std::min(max_num_blocks, num_threads); + + // Block size along the outer dimension. + const Eigen::Index outer_block_size = + Eigen::divup(total_workload, num_blocks); + + const T* input_data = input.template flat().data(); + + // Allocate temporary buffer for partial reductions. 
+ Eigen::Tensor buffer(num_blocks, middle_dim); + buffer.setZero(); + AccumT* buffer_data = buffer.data(); + + using Buffer = Eigen::TensorMap>; + using Input = Eigen::TensorMap>; + + Eigen::array reduction_axis = {0}; + const Reducer reducer; + const BinaryFunctor binary_op; + + const auto compute = [inner_dim, middle_dim, input_data, buffer_data, + total_workload, num_blocks, outer_block_size, + reduction_axis, reducer, binary_op]( + Eigen::Index start, Eigen::Index limit) -> void { + DCHECK(start >= 0 && limit <= num_blocks); + Eigen::Index block_start = start * outer_block_size; + Eigen::Index block_limit = limit * outer_block_size; + block_limit = std::min(total_workload, block_limit); + Buffer buf(buffer_data + start * middle_dim, middle_dim); + + const int align_start = + ((block_start + inner_dim - 1) / inner_dim) * inner_dim; + const int align_end = (block_limit / inner_dim) * inner_dim; + + Eigen::Index coordinate = block_start / inner_dim % middle_dim; + Eigen::Tensor reduced = + Input(&input_data[block_start], align_start - block_start) + .reduce(reduction_axis, reducer) + .template cast(); + + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + + coordinate = align_start / inner_dim % middle_dim; + for (int i = align_start; i < align_end; i += inner_dim) { + reduced = Input(&input_data[i], inner_dim) + .reduce(reduction_axis, reducer) + .template cast(); + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + ++coordinate; + if (middle_dim == coordinate) coordinate = 0; + } + + reduced = Input(&input_data[align_end], block_limit - align_end) + .reduce(reduction_axis, reducer) + .template cast(); + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + }; + + // Compute cost of reducing a single block. + const Eigen::Index compute_size = outer_block_size * inner_dim; + const Eigen::Index compute_input_bytes = compute_size * sizeof(T); + const Eigen::TensorOpCost cost( + compute_input_bytes, + 0, // We'll be mostly writing to L1, assume store cost is 0 + compute_size * Eigen::internal::functor_traits::Cost); + + device.parallelFor(num_blocks, cost, compute); + + using Output = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + // Aggregate partial results from temporary buffer into first block. + auto buf0 = Output(buffer_data, middle_dim); + // TODO(ezhulenev): Parallelize this loop for large inner dimensions? + for (int i = 1; i < num_blocks; ++i) { + auto buf = Output(buffer_data + i * middle_dim, middle_dim); + buf0 = Eigen::TensorCwiseBinaryOp(buf0, buf); + } + + // Write final result to the output. + output->template flat() = buf0.template cast().reshape(output_dims); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc index dd5f9495e2c..2ade89b7ff5 100644 --- a/tensorflow/core/kernels/relu_op_gpu.cu.cc +++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/relu_op_functor.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" -#include "tensorflow/core/util/cuda_launch_config.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" namespace tensorflow { @@ -104,11 +104,11 @@ struct ReluGrad { if (count == 0) return; int32 half2_count = Eigen::divup(count, 2); constexpr int32 kThreadInBlock = 512; - CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( + GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock); - ReluGradHalfKernel<<>>(gradient.data(), feature.data(), - backprop.data(), count); + TF_CHECK_OK(CudaLaunchKernel( + ReluGradHalfKernel, config.block_count, config.thread_per_block, 0, + d.stream(), gradient.data(), feature.data(), backprop.data(), count)); } }; @@ -133,12 +133,12 @@ struct Relu { int32 vect_count = Eigen::divup(count, 4); constexpr int32 kThreadInBlock = 512; - CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( + GpuLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( vect_count, d, Relu_int8x4_kernel, 0, kThreadInBlock); - Relu_int8x4_kernel<<>>( - vect_count, reinterpret_cast(input.data()), - reinterpret_cast(output.data())); + TF_CHECK_OK(CudaLaunchKernel( + Relu_int8x4_kernel, config.block_count, config.thread_per_block, 0, + d.stream(), vect_count, reinterpret_cast(input.data()), + reinterpret_cast(output.data()))); } }; diff --git a/tensorflow/core/kernels/requantize.cc b/tensorflow/core/kernels/requantize.cc index dce6f1a1852..3259e5ddd09 100644 --- a/tensorflow/core/kernels/requantize.cc +++ b/tensorflow/core/kernels/requantize.cc @@ -18,7 +18,6 @@ limitations under the License. #define EIGEN_USE_THREADS #include - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -100,4 +99,10 @@ REGISTER_KERNEL_BUILDER(Name("Requantize") .TypeConstraint("out_type"), RequantizeOp); +REGISTER_KERNEL_BUILDER(Name("Requantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("out_type"), + RequantizeOp); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/reshape_op.cc b/tensorflow/core/kernels/reshape_op.cc index 33c63e70500..9860448947a 100644 --- a/tensorflow/core/kernels/reshape_op.cc +++ b/tensorflow/core/kernels/reshape_op.cc @@ -86,7 +86,8 @@ REGISTER_KERNEL_BUILDER(Name("Reshape") #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. 
@@ -106,6 +107,6 @@ REGISTER_KERNEL_BUILDER(Name("Reshape") .TypeConstraint("T") .TypeConstraint("Tshape"), ReshapeOp); -#endif +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h index 7458ac75ca0..47cd219d8cf 100644 --- a/tensorflow/core/kernels/reshape_op.h +++ b/tensorflow/core/kernels/reshape_op.h @@ -45,14 +45,17 @@ class ReshapeOp : public OpKernel { TensorShape shape; int64 product = 1; int unknown_index = -1; + bool sizes_has_zero_dim; switch (sizes.dtype()) { case DT_INT32: - OP_REQUIRES_OK(context, ValidateSizes(sizes, &product, - &unknown_index, &shape)); + OP_REQUIRES_OK(context, + ValidateSizes(sizes, &product, &unknown_index, + &shape, &sizes_has_zero_dim)); break; case DT_INT64: - OP_REQUIRES_OK(context, ValidateSizes(sizes, &product, - &unknown_index, &shape)); + OP_REQUIRES_OK(context, + ValidateSizes(sizes, &product, &unknown_index, + &shape, &sizes_has_zero_dim)); break; default: context->CtxFailure(errors::InvalidArgument( @@ -61,18 +64,28 @@ class ReshapeOp : public OpKernel { return; } if (unknown_index != -1) { - OP_REQUIRES( - context, product > 0, - errors::InvalidArgument("Reshape cannot infer the missing input size " - "for an empty tensor unless all specified " - "input sizes are non-zero")); - const int64 missing = input.NumElements() / product; - OP_REQUIRES( - context, product * missing == input.NumElements(), - errors::InvalidArgument( - "Input to reshape is a tensor with ", input.NumElements(), - " values, but the requested shape requires a multiple of ", - product)); + int64 input_num_elements = 1; + bool input_has_zero_dim = false; + for (int dim = 0; dim < input.dims(); dim++) { + // For zero dimension, we don't count it into `input_num_elements` + // unless `sizes` has no zero dimension, so we are still able to + // infer shapes for other dimensions. + if (input.dim_size(dim) > 0 || !sizes_has_zero_dim) { + input_num_elements *= input.dim_size(dim); + } else { + input_has_zero_dim = true; + } + } + + const int64 missing = input_num_elements / product; + if (!input_has_zero_dim) { + OP_REQUIRES( + context, product * missing == input_num_elements, + errors::InvalidArgument( + "Input to reshape is a tensor with ", input_num_elements, + " values, but the requested shape requires a multiple of ", + product)); + } shape.set_dim(unknown_index, missing); } OP_REQUIRES(context, shape.num_elements() == input.NumElements(), @@ -92,9 +105,10 @@ class ReshapeOp : public OpKernel { private: template Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index, - TensorShape* shape) { + TensorShape* shape, bool* has_zero_dim) { *product = 1; *unknown_index = -1; + *has_zero_dim = false; const int64 num_dims = sizes.NumElements(); auto Svec = sizes.flat(); for (int d = 0; d < num_dims; ++d) { @@ -110,6 +124,12 @@ class ReshapeOp : public OpKernel { } else if (size < 0) { return errors::InvalidArgument("Size ", d, " must be non-negative, not ", size); + } else if (size == 0) { + // We don't include zero-sized dimension in product, so that we can + // still calculate number of elements for non-zero-sized dimensions and + // therefore infer their shapes. 
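// Editor's note (illustrative, not part of the patch): as a concrete example
// of the zero-dimension handling, reshaping a tensor of shape [4, 0] with
// sizes [0, -1] leaves product == 1 and sets has_zero_dim, so the caller
// computes input_num_elements == 4 (the input's own zero dimension is skipped
// because sizes also contains a zero), infers missing == 4 / 1 == 4, and
// produces the output shape [0, 4], which still matches the input's zero
// element count.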
+ shape->AddDim(size); + *has_zero_dim = true; } else { shape->AddDim(size); (*product) *= size; diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc index 38bb2a9a969..85afa37d5e4 100644 --- a/tensorflow/core/kernels/resize_area_op.cc +++ b/tensorflow/core/kernels/resize_area_op.cc @@ -165,14 +165,14 @@ class ResizeAreaOp : public OpKernel { const float in_x1 = (x + 1) * st.width_scale; // The start and end width indices of all the cells that could // contribute to the target cell. - int64 v = floor(in_x); + int64 v = std::floor(in_x); x_interp.start = v; // TODO(cwhipkey): simplify this logic. x_interp.start_scale = v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x) : (v + 1 > in_x1 ? in_x1 - v : 1.0); - v = ceil(in_x1); + v = std::ceil(in_x1); x_interp.end = v; v = x_interp.end - 1; x_interp.end_minus_one_scale = @@ -226,8 +226,8 @@ class ResizeAreaOp : public OpKernel { const float in_y1 = (y + 1) * st.height_scale; // The start and end height indices of all the cells that could // contribute to the target cell. - const int64 y_start = floor(in_y); - const int64 y_end = ceil(in_y1); + const int64 y_start = std::floor(in_y); + const int64 y_end = std::ceil(in_y1); y_scales.clear(); y_ptrs.clear(); for (int64 i = y_start; i < y_end; ++i) { diff --git a/tensorflow/core/kernels/resize_area_op_test.cc b/tensorflow/core/kernels/resize_area_op_test.cc index 84ff090b546..e57c06a546f 100644 --- a/tensorflow/core/kernels/resize_area_op_test.cc +++ b/tensorflow/core/kernels/resize_area_op_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -103,16 +104,16 @@ class ResizeAreaOpTest : public OpsTestBase { const float in_y1 = (y + 1) * height_scale; // The start and end height indices of all the cells that could // contribute to the target cell. - int64 y_start = floor(in_y); - int64 y_end = ceil(in_y1); + int64 y_start = std::floor(in_y); + int64 y_end = std::ceil(in_y1); for (int64 x = 0; x < out_width; ++x) { const float in_x = x * width_scale; const float in_x1 = (x + 1) * width_scale; // The start and end width indices of all the cells that could // contribute to the target cell. - int64 x_start = floor(in_x); - int64 x_end = ceil(in_x1); + int64 x_start = std::floor(in_x); + int64 x_end = std::ceil(in_x1); sum_data.setConstant(0.0); for (int64 i = y_start; i < y_end; ++i) { diff --git a/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc b/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc index 4da2b877df2..7c8ac7db359 100644 --- a/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc +++ b/tensorflow/core/kernels/resize_bilinear_op_gpu.cu.cc @@ -19,12 +19,11 @@ limitations under the License. 
#define EIGEN_USE_GPU -#include "tensorflow/core/kernels/resize_bilinear_op.h" - #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/resize_bilinear_op.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -279,7 +278,7 @@ struct ResizeBilinear { const int total_count = batch * out_height * out_width * channels; if (total_count == 0) return; - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); if (half_pixel_centers) { TF_CHECK_OK(CudaLaunchKernel( ResizeBilinearKernel, config.block_count, config.thread_per_block, @@ -313,19 +312,19 @@ struct ResizeBilinearGrad { const int resized_width = input_grad.dimension(2); int total_count; - CudaLaunchConfig config; + GpuLaunchConfig config; // Initialize output_grad with all zeros. total_count = batch * original_height * original_width * channels; if (total_count == 0) return; - config = GetCudaLaunchConfig(total_count, d); + config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(CudaLaunchKernel( SetZero, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, output_grad.data())); // Accumulate. total_count = batch * resized_height * resized_width * channels; - config = GetCudaLaunchConfig(total_count, d); + config = GetGpuLaunchConfig(total_count, d); if (half_pixel_centers) { TF_CHECK_OK(CudaLaunchKernel( ResizeBilinearGradKernel, config.block_count, diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.cu.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.cu.cc index d2494ea36b0..5ae1bfc92e1 100644 --- a/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.cu.cc +++ b/tensorflow/core/kernels/resize_nearest_neighbor_op_gpu.cu.cc @@ -19,12 +19,11 @@ limitations under the License. 
#include -#include "tensorflow/core/kernels/resize_nearest_neighbor_op.h" - #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/resize_nearest_neighbor_op.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -173,7 +172,7 @@ struct ResizeNearestNeighbor { const int output_size = batch_size * out_height * out_width * channels; if (output_size == 0) return true; - CudaLaunchConfig config = GetCudaLaunchConfig(output_size, d); + GpuLaunchConfig config = GetCudaLaunchConfig(output_size, d); if (half_pixel_centers) { TF_CHECK_OK(CudaLaunchKernel( ResizeNearestNeighborNHWC, config.block_count, @@ -219,15 +218,16 @@ struct ResizeNearestNeighborGrad>>(output_size, output.data()); + GpuLaunchConfig output_config = GetCudaLaunchConfig(output_size, d); + TF_CHECK_OK(CudaLaunchKernel(SetZero, output_config.block_count, + output_config.thread_per_block, 0, d.stream(), + output_size, output.data())); if (!d.ok()) return false; const int input_size = batch_size * channels * in_height * in_width; if (input_size == 0) return true; - CudaLaunchConfig input_config = GetCudaLaunchConfig(input_size, d); + GpuLaunchConfig input_config = GetCudaLaunchConfig(input_size, d); if (half_pixel_centers) { TF_CHECK_OK(CudaLaunchKernel( ResizeNearestNeighborBackwardNHWC, input_config.block_count, diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 8e3c52ba5b5..e5381a058b8 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -64,6 +64,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/kernels/dense_update_functor.h" #include "tensorflow/core/kernels/gather_functor.h" +#include "tensorflow/core/kernels/gather_nd_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/scatter_functor.h" #include "tensorflow/core/kernels/training_op_helpers.h" @@ -86,6 +87,7 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) { } namespace { + Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) { Tensor* output; Notification n; @@ -583,8 +585,34 @@ REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp") template class ResourceGatherOp : public OpKernel { + private: + int32 batch_dims_ = 0; + + // Add the batch offset derrived from params to each batch of indices. + // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]] + // If indexing into a params dimension of size 4, then the indices will become + // [0, 1, 2, 4, 5, 6] + void AddBatchOffsets(Tensor* indices, const Tensor& params) { + int64 batch_size = 1; // The size of all batch dimensions. 
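// Editor's note (illustrative, not part of the patch): continuing the example
// above with batch_dims_ == 1 and params of shape [2, 4, ...], batch_size
// becomes 2, index_inner_size becomes indices->NumElements() / 2, and
// batch_offset == params.dim_size(1) == 4, so every index in the second batch
// is shifted by 4 (the third, if any, by 8), yielding the flattened
// [0, 1, 2, 4, 5, 6] shown in the comment above.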
+ for (int idx = 0; idx < batch_dims_; ++idx) { + batch_size *= params.dim_size(idx); + } + + auto indices_flat = indices->flat(); + int64 const index_inner_size = indices->NumElements() / batch_size; + int64 const batch_offset = params.dim_size(batch_dims_); + for (int64 batch_idx = 0, dest_idx = 0; batch_idx < batch_size; + ++batch_idx) { + for (int64 idx = 0; idx < index_inner_size; ++idx) { + indices_flat(dest_idx++) += batch_offset * batch_idx; + } + } + } + public: - explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {} + explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_)); + } void Compute(OpKernelContext* c) override { Var* v = nullptr; @@ -612,9 +640,16 @@ class ResourceGatherOp : public OpKernel { " indexing: ", params.dim_size(0), " > ", std::numeric_limits::max())); - // The result shape is indices.shape + params.shape[1:]. - TensorShape result_shape = indices.shape(); - for (int i = 1; i < params.dims(); i++) { + // The result shape is params.shape[:batch_dims] + + // indices.shape[batch_dims:] + params.shape[batch_dims+1:]. + TensorShape result_shape; + for (int i = 0; i < batch_dims_; ++i) { + result_shape.AddDim(params.dim_size(i)); + } + for (int i = batch_dims_; i < indices.dims(); ++i) { + result_shape.AddDim(indices.dim_size(i)); + } + for (int i = batch_dims_ + 1; i < params.dims(); ++i) { result_shape.AddDim(params.dim_size(i)); } @@ -627,14 +662,33 @@ class ResourceGatherOp : public OpKernel { } else { OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out)); } + if (N > 0) { - const int64 gather_dim_size = params.dim_size(0); + Tensor tmp_indices; + + // Points to the original or updated (if batch_dims is set) indices. + const Tensor* op_indices = &indices; + if (batch_dims_ > 0) { + OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(), + &tmp_indices)); + functor::DenseUpdate copy_functor; + copy_functor(c->eigen_device(), tmp_indices.flat(), + indices.flat()); + + AddBatchOffsets(&tmp_indices, params); + op_indices = &tmp_indices; + } + + int64 gather_dim_size = 1; + for (int idx = 0; idx <= batch_dims_; ++idx) { + gather_dim_size *= params.dim_size(idx); + } int64 inner_size = 1; - for (int i = 1; i < params.dims(); i++) { + for (int i = batch_dims_ + 1; i < params.dims(); ++i) { inner_size *= params.dim_size(i); } auto params_flat = params.shaped({1, gather_dim_size, inner_size}); - auto indices_flat = indices.flat(); + const auto indices_flat = op_indices->flat(); auto out_flat = out->shaped({1, N, out->NumElements() / N}); functor::GatherFunctor functor; @@ -697,6 +751,62 @@ REGISTER_KERNEL_BUILDER(Name("ResourceGather") #undef REGISTER_GATHER_ALL_INDICES #undef REGISTER_GATHER_FULL +template +class ResourceGatherNdOp : public OpKernel { + public: + explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + Var* v = nullptr; + OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); + core::ScopedUnref su(v); + OP_REQUIRES_OK(c, EnsureSparseVariableAccess(c, v)); + // NOTE: We hold the lock for the whole gather operation instead + // of increasing the reference count of v->tensor() to avoid a + // situation where a write to the same variable will see a + // reference count greater than one and make a copy of the + // (potentially very large) tensor buffer. 
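// Editor's note (illustrative, not part of the patch): DoGatherNd performs a
// standard gather_nd lookup on the variable's tensor. For instance, with
// params of shape [P0, P1, P2] and indices of shape [N, 2], the output has
// shape [N, P2] and row i is params[indices(i, 0), indices(i, 1), :].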
+ tf_shared_lock ml(*v->mu()); + const Tensor& params = *v->tensor(); + const Tensor& indices = c->input(1); + + Tensor out; + OP_REQUIRES_OK( + c, functor::DoGatherNd(c, params, indices, &out)); + c->set_output(0, out); + } +}; + +#define REGISTER_GATHER_ND_FULL(dev, type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd") \ + .Device(DEVICE_##dev) \ + .HostMemory("resource") \ + .TypeConstraint("dtype") \ + .TypeConstraint("Tindices"), \ + ResourceGatherNdOp) + +#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \ + REGISTER_GATHER_ND_FULL(dev, type, int32); \ + REGISTER_GATHER_ND_FULL(dev, type, int64) + +#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type) + +// Registration of the CPU implementations. +TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); + +// Registers GPU kernels. +#if GOOGLE_CUDA +#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU); + +#endif // GOOGLE_CUDA + +#undef REGISTER_GATHER_ND_CPU +#undef REGISTER_GATHER_ND_GPU +#undef REGISTER_GATHER_ND_ALL_INDICES +#undef REGISTER_GATHER_ND_FULL + template class ResourceScatterUpdateOp : public OpKernel { public: diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc index aa2434da03f..c60ab60849f 100644 --- a/tensorflow/core/kernels/reverse_op.cc +++ b/tensorflow/core/kernels/reverse_op.cc @@ -317,7 +317,7 @@ TF_CALL_POD_TYPES(REGISTER_KERNELS); TF_CALL_string(REGISTER_KERNELS); #undef REGISTER_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the function specializations for GPU (to prevent // building the GPU versions here, they will be built compiling _gpu.cu.cc). @@ -407,7 +407,7 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2") .HostMemory("axis") .HostMemory("output"), ReverseV2Op); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNELS(T) \ diff --git a/tensorflow/core/kernels/reverse_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_op_gpu.cu.cc index 3ee49db669f..2917a0d5f11 100644 --- a/tensorflow/core/kernels/reverse_op_gpu.cu.cc +++ b/tensorflow/core/kernels/reverse_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -51,4 +51,4 @@ TF_CALL_complex128(DEFINE_REVERSE_ALL_DIMS); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc index cded417986b..0e112133915 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.cc +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -17,9 +17,9 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/reverse_sequence_op.h" @@ -177,7 +177,7 @@ class ReverseSequenceOp : public OpKernel { TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN); TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_LEN); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. 
namespace functor { @@ -222,6 +222,6 @@ TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_GPU_LEN); #undef REGISTER_REVERSE_SEQUENCE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc index 4a2136a2cd3..948a99a7d37 100644 --- a/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc +++ b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -43,4 +43,4 @@ TF_CALL_bool(DEFINE_GPU_SPECS); } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc index 0e68af867bd..9de850acd05 100644 --- a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc +++ b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ // See docs in ../ops/image_ops.cc. #include +#include #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -122,8 +123,8 @@ bool GenerateRandomCrop(int original_width, int original_height, const float max_area = max_relative_crop_area * original_width * original_height; - int height = static_cast(lrintf(sqrt(min_area / aspect_ratio))); - int max_height = static_cast(lrintf(sqrt(max_area / aspect_ratio))); + int height = static_cast(lrintf(std::sqrt(min_area / aspect_ratio))); + int max_height = static_cast(lrintf(std::sqrt(max_area / aspect_ratio))); if (lrintf(max_height * aspect_ratio) > original_width) { // We must find the smallest max_height satisfying diff --git a/tensorflow/core/kernels/scale_and_translate_op.cc b/tensorflow/core/kernels/scale_and_translate_op.cc index 92b458f2e75..fff457e55c7 100644 --- a/tensorflow/core/kernels/scale_and_translate_op.cc +++ b/tensorflow/core/kernels/scale_and_translate_op.cc @@ -82,10 +82,8 @@ Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel, const float col_f = x + 0.5f; const float sample_f = col_f * inv_scale + inv_translate; - // Don't sample when the sampling *kernel* is completely outside the - // source image. - if (sample_f < 0 - kernel.Radius() * kernel_scale || - sample_f > input_size + kernel.Radius() * kernel_scale) { + // Don't sample when the sampling location is outside the source image. + if (sample_f < 0 || sample_f > input_size) { // Add an empty span. 
starts_vec(x) = 0; continue; @@ -169,11 +167,15 @@ Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans, auto grad_weights_vec = grad_spans->weights.vec(); grad_weights_vec.setZero(); for (int input_index = 0; input_index < forward_input_size; ++input_index) { - const int start_span = grad_components[input_index].front().index; - grad_starts_vec(input_index) = start_span; - for (const GradComponent& gc : grad_components[input_index]) { - grad_weights_vec(input_index * grad_spans->span_size + gc.index - - start_span) += gc.weight; + if (!grad_components[input_index].empty()) { + const int start_span = grad_components[input_index].front().index; + grad_starts_vec(input_index) = start_span; + for (const GradComponent& gc : grad_components[input_index]) { + grad_weights_vec(input_index * grad_spans->span_size + gc.index - + start_span) += gc.weight; + } + } else { + grad_starts_vec(input_index) = 0; } } return Status::OK(); diff --git a/tensorflow/core/kernels/scale_and_translate_op_test.cc b/tensorflow/core/kernels/scale_and_translate_op_test.cc index 127f1641554..a17e3d83963 100644 --- a/tensorflow/core/kernels/scale_and_translate_op_test.cc +++ b/tensorflow/core/kernels/scale_and_translate_op_test.cc @@ -120,7 +120,8 @@ void Sample(const DynamicKernel& kernel, const bool antialias, 1; std::fill(dest, dest + channels, 0.0f); - if (y_span_end <= y_span_start || x_span_end <= x_span_start) { + if (sample_f.x() < 0.0f || sample_f.y() < 0.0f || sample_f.x() > in_width || + sample_f.y() > in_height) { return; } const Vector2f one_over_kernel_scale(1.0f / kernel_scale.x(), @@ -170,6 +171,8 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel, const int64 out_height = output.dimension(1); const int64 out_width = output.dimension(2); + const int64 in_height = images.dimension(1); + const int64 in_width = images.dimension(2); for (int b = 0; b < batch; ++b) { for (int64 y = 0; y < out_height; ++y) { @@ -178,8 +181,13 @@ void ScaleAndTranslateBaseline(const DynamicKernel& kernel, for (int64 x = 0; x < out_width; ++x) { const float out_x_f = static_cast(x) + 0.5; const float in_x_f = out_x_f * scale.x() + translate.x(); - Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f), - &output(b, y, x, 0)); + if (in_x_f < 0.0f || in_y_f < 0.0f || in_x_f > in_width || + in_y_f > in_height) { + std::fill(&output(b, y, x, 0), &output(b, y, x + 1, 0), 0.0f); + } else { + Sample(kernel, antialias, images, b, scale, Vector2f(in_x_f, in_y_f), + &output(b, y, x, 0)); + } } } } diff --git a/tensorflow/core/kernels/scan_ops_gpu.h b/tensorflow/core/kernels/scan_ops_gpu.h index 557b72000a7..685fe3bf950 100644 --- a/tensorflow/core/kernels/scan_ops_gpu.h +++ b/tensorflow/core/kernels/scan_ops_gpu.h @@ -29,15 +29,15 @@ limitations under the License. 
#include "third_party/cub/block/block_store.cuh" #include "third_party/cub/iterator/counting_input_iterator.cuh" #include "third_party/cub/iterator/transform_input_iterator.cuh" -#include "cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cuComplex.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/cuda_launch_config.h" +#include "tensorflow/core/kernels/scan_ops.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" #include "tensorflow/core/util/permutation_input_iterator.h" #include "tensorflow/core/util/permutation_output_iterator.h" -#include "tensorflow/core/kernels/scan_ops.h" - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h index 57344c1dd24..6c195e59e20 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/scatter_functor.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -126,7 +126,7 @@ struct ScatterFunctor { const Index first_dim_size = params.dimension(0); const Index indices_size = indices.size(); const Index updates_size = updates.size(); - CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d); + GpuLaunchConfig config = GetCudaLaunchConfig(updates_size, d); TF_CHECK_OK(CudaLaunchKernel( scatter_op_gpu::ScatterOpCustomKernel, config.block_count, config.thread_per_block, 0, d.stream(), params.data(), updates.data(), @@ -147,7 +147,7 @@ struct ScatterScalarFunctor { const Index first_dim_size = params.dimension(0); const Index indices_size = indices.size(); const Index synthesized_updates_size = indices_size * params.dimension(1); - CudaLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d); + GpuLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d); TF_CHECK_OK(CudaLaunchKernel( scatter_op_gpu::ScatterScalarOpCustomKernel, config.block_count, config.thread_per_block, 0, d.stream(), diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc index 9936b3f9b78..9152e71acb2 100644 --- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/scatter_nd_op.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -135,7 +135,7 @@ struct ScatterNdFunctor { } } - CudaLaunchConfig config = GetCudaLaunchConfig(Toutput.size(), d); + GpuLaunchConfig config = GetCudaLaunchConfig(Toutput.size(), d); TF_CHECK_OK(CudaLaunchKernel(ScatterNdOpKernel, config.block_count, config.thread_per_block, 0, diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc index 71580ff9a87..bd20793b078 100644 --- a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc +++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc @@ -17,15 +17,14 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/searchsorted_op.h" - #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/searchsorted_op.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -65,8 +64,8 @@ struct UpperBoundFunctor { int batch_size, int num_inputs, int num_values, typename TTypes::Tensor* output) { const cudaStream_t& stream = GetCudaStream(context); - CudaLaunchConfig config = - GetCudaLaunchConfig(values.size(), context->eigen_gpu_device()); + GpuLaunchConfig config = + GetGpuLaunchConfig(values.size(), context->eigen_gpu_device()); TF_CHECK_OK(CudaLaunchKernel( UpperBoundKernel, config.block_count, @@ -85,8 +84,8 @@ struct LowerBoundFunctor { int batch_size, int num_inputs, int num_values, typename TTypes::Tensor* output) { const cudaStream_t& stream = GetCudaStream(context); - CudaLaunchConfig config = - GetCudaLaunchConfig(values.size(), context->eigen_gpu_device()); + GpuLaunchConfig config = + GetGpuLaunchConfig(values.size(), context->eigen_gpu_device()); TF_CHECK_OK(CudaLaunchKernel( LowerBoundKernel, config.block_count, diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index 6e1a0d57a16..60c6a7dbcaf 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -20,12 +20,12 @@ limitations under the License. #define EIGEN_USE_GPU #endif // GOOGLE_CUDA -#include "third_party/eigen3/Eigen/Core" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - #include "tensorflow/core/kernels/segment_reduction_ops.h" + #include +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index d28e35157b2..ce584812968 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -17,22 +17,8 @@ limitations under the License. 
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ // This file requires the following include because it uses CudaAtomicMax: -// #include "tensorflow/core/util/cuda_kernel_helper.h" - -// Unfortunately we can't add the #include, since it breaks compilation for -// non-GPU targets. This only breaks in clang, because it's more strict for -// template code and CudaAtomicMax is used in template context. - -// This file requires the following include because it uses CudaAtomicMax: -// #include "tensorflow/core/util/cuda_kernel_helper.h" - -// Unfortunately we can't add the #include, since it breaks compilation for -// non-GPU targets. This only breaks in clang, because it's more strict for -// template code and CudaAtomicMax is used in template context. - -// This file requires the following include because it uses CudaAtomicMax: -// #include "tensorflow/core/util/cuda_kernel_helper.h" - +// #include "tensorflow/core/util/gpu_kernel_helper.h" +// // Unfortunately we can't add the #include, since it breaks compilation for // non-GPU targets. This only breaks in clang, because it's more strict for // template code and CudaAtomicMax is used in template context. diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc index 39406dd9a22..305673b56fc 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc @@ -19,12 +19,13 @@ limitations under the License. // We need to include cuda_kernel_helper.h before segment_reduction_ops.h // See comment in segment_reduction_ops.h for more details. -#include "tensorflow/core/util/cuda_kernel_helper.h" +// clang-format off +#include "tensorflow/core/util/gpu_kernel_helper.h" +// clang-format on #include "tensorflow/core/kernels/segment_reduction_ops.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/cuda_device_functions.h" - +#include "tensorflow/core/util/gpu_device_functions.h" namespace tensorflow { @@ -137,7 +138,7 @@ void SegmentSumFunctor::operator()( return; } // Set 'output' to zeros. - CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d); + GpuLaunchConfig config = GetCudaLaunchConfig(output.size(), d); TF_CHECK_OK(CudaLaunchKernel(SetZero, config.block_count, config.thread_per_block, 0, d.stream(), output.size(), output.data())); @@ -162,7 +163,7 @@ void SegmentSumFunctor::operator()( const Index total_stripe_count = input_inner_dim_size * input_outer_dim_num_stripe; - config = GetCudaLaunchConfig(total_stripe_count, d); + config = GetGpuLaunchConfig(total_stripe_count, d); TF_CHECK_OK(CudaLaunchKernel( SortedSegmentSumCustomKernel, config.block_count, config.thread_per_block, 0, d.stream(), @@ -183,7 +184,7 @@ struct UnsortedSegmentFunctor { } // Set 'output' to initial value. GPUDevice d = ctx->template eigen_device(); - CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d); + GpuLaunchConfig config = GetCudaLaunchConfig(output.size(), d); TF_CHECK_OK(CudaLaunchKernel( SetToValue, config.block_count, config.thread_per_block, 0, d.stream(), output.size(), output.data(), InitialValueF()())); @@ -197,7 +198,7 @@ struct UnsortedSegmentFunctor { // *) 'input_outer_dim_size' is the total number of segments to process. 
const Index input_outer_dim_size = segment_ids.dimension(0); const Index input_inner_dim_size = data_size / input_outer_dim_size; - config = GetCudaLaunchConfig(data_size, d); + config = GetGpuLaunchConfig(data_size, d); TF_CHECK_OK(CudaLaunchKernel( UnsortedSegmentCustomKernel, config.block_count, diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 6521dcf932a..91d6e9b2d39 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -116,7 +116,6 @@ REGISTER_KERNEL_BUILDER( Name("_HostSend").Device(DEVICE_SYCL).HostMemory("tensor"), SendOp); #endif // TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp); REGISTER_KERNEL_BUILDER( Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp); @@ -200,7 +199,6 @@ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp); REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_SYCL), RecvOp); #endif // TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp); REGISTER_KERNEL_BUILDER( Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp); @@ -209,4 +207,16 @@ REGISTER_KERNEL_BUILDER( Name("_HostRecv").Device(DEVICE_SYCL).HostMemory("tensor"), RecvOp); #endif // TENSORFLOW_USE_SYCL +// Environment variable `DISABLE_HOST_SEND_RECV_REGISTRATION` is used to disable +// hostSend and hostRecv registration on CPU device in the mock environment. +static bool InitModule() { + if (!std::getenv("DISABLE_HOST_SEND_RECV_REGISTRATION")) { + REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp); + REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp); + } + return true; +} + +static bool module_initialized = InitModule(); + } // end namespace tensorflow diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index 21c3b89f548..02dcc1e4dec 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -103,14 +103,14 @@ TF_CALL_double(REGISTER_CPU_KERNEL); TF_CALL_int32(REGISTER_CPU_KERNEL); TF_CALL_int64(REGISTER_CPU_KERNEL); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_CALL_float(REGISTER_GPU_KERNEL); TF_CALL_double(REGISTER_GPU_KERNEL); TF_CALL_int32(REGISTER_GPU_KERNEL); TF_CALL_int64(REGISTER_GPU_KERNEL); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #undef REGISTER_KERNEL #undef REGISTER_CPU_KERNEL diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc index db7357ca70e..86ccde9fb8c 100644 --- a/tensorflow/core/kernels/shape_ops.cc +++ b/tensorflow/core/kernels/shape_ops.cc @@ -68,7 +68,7 @@ REGISTER_KERNEL_BUILDER(Name("Shape") ShapeOp); #endif // TENSORFLOW_USE_SYCL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("Shape") \ .Device(DEVICE_GPU) \ @@ -106,7 +106,7 @@ REGISTER_KERNEL_BUILDER(Name("Shape") .TypeConstraint("out_type"), ShapeOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // ShapeN --------------------------------------- REGISTER_KERNEL_BUILDER(Name("ShapeN") @@ -120,7 +120,7 @@ REGISTER_KERNEL_BUILDER(Name("ShapeN") .TypeConstraint("out_type"), ShapeNOp); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("ShapeN") \ .Device(DEVICE_GPU) \ @@ -156,7 +156,7 @@ REGISTER_KERNEL_BUILDER(Name("ShapeN") 
.TypeConstraint("T") .TypeConstraint("out_type"), ShapeNOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(type) \ @@ -222,7 +222,7 @@ REGISTER_KERNEL_BUILDER(Name("Rank") RankOp); #endif // TENSORFLOW_USE_SYCL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("Rank") \ .Device(DEVICE_GPU) \ @@ -250,7 +250,7 @@ REGISTER_KERNEL_BUILDER(Name("Rank") .HostMemory("output"), RankOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Size ------------------------------------------ REGISTER_KERNEL_BUILDER(Name("Size") @@ -264,7 +264,7 @@ REGISTER_KERNEL_BUILDER(Name("Size") .TypeConstraint("out_type"), SizeOp); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("Size") \ .Device(DEVICE_GPU) \ @@ -301,7 +301,7 @@ REGISTER_KERNEL_BUILDER(Name("Size") .HostMemory("output"), SizeOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(type) \ @@ -349,7 +349,7 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .TypeConstraint("Tdim"), ExpandDimsOp); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER(Name("ExpandDims") \ .Device(DEVICE_GPU) \ @@ -383,7 +383,7 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .HostMemory("dim") .HostMemory("output"), ExpandDimsOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(type) \ @@ -424,7 +424,7 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") // Squeeze --------------------------------------- REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNEL(type) \ REGISTER_KERNEL_BUILDER( \ Name("Squeeze").Device(DEVICE_GPU).TypeConstraint("T"), \ @@ -442,7 +442,7 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze") .HostMemory("input") .HostMemory("output"), SqueezeOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(type) \ @@ -532,7 +532,7 @@ REGISTER_GPU_KERNEL(Variant); #undef REGISTER_GPU_KERNEL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // A special GPU kernel for int32 and bool. // TODO(b/25387198): Also enable int32 in device memory. This kernel // registration requires all int32 inputs and outputs to be in host memory. diff --git a/tensorflow/core/kernels/snapshot_op.cc b/tensorflow/core/kernels/snapshot_op.cc index fe04dcf72e2..95bcfd6b39d 100644 --- a/tensorflow/core/kernels/snapshot_op.cc +++ b/tensorflow/core/kernels/snapshot_op.cc @@ -51,7 +51,7 @@ class SnapshotOp : public OpKernel { TF_CALL_POD_TYPES(REGISTER_KERNEL); #undef REGISTER_KERNEL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_KERNEL(TYPE) \ REGISTER_KERNEL_BUILDER( \ Name("Snapshot").Device(DEVICE_GPU).TypeConstraint("T"), \ diff --git a/tensorflow/core/kernels/snapshot_op.h b/tensorflow/core/kernels/snapshot_op.h index 02d492988eb..a35233bb43c 100644 --- a/tensorflow/core/kernels/snapshot_op.h +++ b/tensorflow/core/kernels/snapshot_op.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_ #define TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #endif diff --git a/tensorflow/core/kernels/snapshot_op_gpu.cu.cc b/tensorflow/core/kernels/snapshot_op_gpu.cu.cc index e4e3bd52203..d4fee5b40e6 100644 --- a/tensorflow/core/kernels/snapshot_op_gpu.cu.cc +++ b/tensorflow/core/kernels/snapshot_op_gpu.cu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // See docs in ../ops/array_ops.cc. #include "tensorflow/core/kernels/snapshot_op.h" @@ -31,4 +31,4 @@ TF_CALL_POD_TYPES(DEFINE_GPU_KERNELS); } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc index 9b2f3a963bd..c11b59fe46a 100644 --- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc @@ -23,12 +23,10 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/platform/types.h" - -#include "tensorflow/core/util/cuda_kernel_helper.h" - #include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" #include "tensorflow/core/kernels/reduction_ops_common.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc index d3fc0e1461b..0c0f33093e3 100644 --- a/tensorflow/core/kernels/softplus_op.cc +++ b/tensorflow/core/kernels/softplus_op.cc @@ -87,7 +87,8 @@ void SoftplusGradOp::OperateNoTemplate(OpKernelContext* context, TF_CALL_FLOAT_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -119,6 +120,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/softplus_op_gpu.cu.cc b/tensorflow/core/kernels/softplus_op_gpu.cu.cc index 8df734588b8..0cf169da85e 100644 --- a/tensorflow/core/kernels/softplus_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softplus_op_gpu.cu.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU @@ -37,4 +38,4 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc index d691f156518..df1c61f4f22 100644 --- a/tensorflow/core/kernels/softsign_op.cc +++ b/tensorflow/core/kernels/softsign_op.cc @@ -88,7 +88,7 @@ void SoftsignGradOp::OperateNoTemplate(OpKernelContext* context, TF_CALL_FLOAT_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -120,6 +120,6 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/softsign_op_gpu.cu.cc b/tensorflow/core/kernels/softsign_op_gpu.cu.cc index b80cdf0d963..679f743ac18 100644 --- a/tensorflow/core/kernels/softsign_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softsign_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -37,4 +37,4 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/spacetobatch_functor_gpu.cu.cc b/tensorflow/core/kernels/spacetobatch_functor_gpu.cu.cc index ea6e076909b..4db5c6f30ad 100644 --- a/tensorflow/core/kernels/spacetobatch_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/spacetobatch_functor_gpu.cu.cc @@ -19,10 +19,9 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/spacetobatch_functor.h" - #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/kernels/spacetobatch_functor.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -139,8 +138,8 @@ struct SpaceToBatchFunctor { return errors::InvalidArgument( "number of batch_tensor elements exceeds 2^32-1"); } - CudaLaunchConfig config = - GetCudaLaunchConfig(static_cast(total_count), d); + GpuLaunchConfig config = + GetGpuLaunchConfig(static_cast(total_count), d); return CudaLaunchKernel(S2B, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, diff --git a/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc b/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc index 606ff89e742..55573208540 100644 --- a/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc +++ b/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc @@ -17,11 +17,10 @@ limitations under the License. 
#define EIGEN_USE_GPU -#include "tensorflow/core/kernels/spacetodepth_op.h" - #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/spacetodepth_op.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -157,7 +156,7 @@ struct SpaceToDepthOpFunctor { if (total_count == 0) { return; } - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); TF_CHECK_OK(CudaLaunchKernel( S2D_NHWC, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, input.data(), block_size, batch_size, @@ -191,7 +190,7 @@ struct SpaceToDepthOpFunctor { if (total_count == 0) { return; } - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); switch (block_size) { case 2: TF_CHECK_OK(CudaLaunchKernel( @@ -222,7 +221,7 @@ struct SpaceToDepthOpFunctor { if (total_count == 0) { return; } - CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); TF_CHECK_OK(CudaLaunchKernel( S2D_NCHW, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, input.data(), block_size, output_width, diff --git a/tensorflow/core/kernels/sparse_softmax_op.cc b/tensorflow/core/kernels/sparse_softmax_op.cc index 37664fe8df8..548080b8b13 100644 --- a/tensorflow/core/kernels/sparse_softmax_op.cc +++ b/tensorflow/core/kernels/sparse_softmax_op.cc @@ -62,12 +62,8 @@ class SparseSoftmaxOp : public OpKernel { errors::InvalidArgument( "Input should have rank >= 2, but received shape: ", shape_t->SummarizeValue(3))); - OP_REQUIRES(context, - indices_t->dim_size(0) < std::numeric_limits::max(), - errors::InvalidArgument( - "Number of non-zero elements exceeds int32 range")); - const int nnz = static_cast(indices_t->dim_size(0)); + const int64 nnz = indices_t->dim_size(0); const int rank = static_cast(indices_t->dim_size(1)); SparseTensor st; OP_REQUIRES_OK( diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc index f85f2a48a10..2b00549a9ea 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc @@ -17,11 +17,10 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h" - #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -79,7 +78,7 @@ struct SparseTensorDenseMatMulFunctor { // TODO(ebrevdo): Should this be alpha * nnz instead of // out.size()? Perhaps p * nnz ? 
- CudaLaunchConfig config = GetCudaLaunchConfig(p * nnz, d); + GpuLaunchConfig config = GetCudaLaunchConfig(p * nnz, d); TF_CHECK_OK(CudaLaunchKernel( SparseTensorDenseMatMulKernel, diff --git a/tensorflow/core/kernels/sparse_xent_op_test.cc b/tensorflow/core/kernels/sparse_xent_op_test.cc index afb0bf76267..f20af4f9217 100644 --- a/tensorflow/core/kernels/sparse_xent_op_test.cc +++ b/tensorflow/core/kernels/sparse_xent_op_test.cc @@ -49,6 +49,7 @@ static Graph* SparseXent(int batch_size, int num_classes) { BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE); /// The representative tests for ptb_word on GPU +#ifdef GOOGLE_CUDA BM_SparseXentDev(8, 1000000, gpu); BM_SparseXentDev(16, 10000, gpu); @@ -62,6 +63,7 @@ BM_SparseXentDev(32, 100000, gpu); BM_SparseXentDev(64, 10000, gpu); BM_SparseXentDev(64, 30000, gpu); BM_SparseXentDev(64, 100000, gpu); +#endif // GOOGLE_CUDA // CPU BM_SparseXentDev(8, 1000000, cpu); diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc index 3d42f2dc70b..368239477b1 100644 --- a/tensorflow/core/kernels/split_lib_gpu.cu.cc +++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc @@ -19,12 +19,12 @@ limitations under the License. #include -#include "tensorflow/core/kernels/split_lib.h" - #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/cuda_device_array_gpu.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/kernels/gpu_device_array_gpu.h" +#include "tensorflow/core/kernels/split_lib.h" +#include "tensorflow/core/kernels/split_lib_gpu.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace functor { @@ -56,6 +56,8 @@ TF_CALL_complex64(DEFINE_GPU_KERNELS); TF_CALL_complex128(DEFINE_GPU_KERNELS); TF_CALL_int64(DEFINE_GPU_KERNELS); TF_CALL_bfloat16(DEFINE_GPU_KERNELS); +TF_CALL_uint8(DEFINE_GPU_KERNELS); +TF_CALL_bool(DEFINE_GPU_KERNELS); #undef DEFINE_GPU_KERNELS #define DEFINE_GPU_KERNELS(T) template struct SplitCustom; @@ -74,9 +76,9 @@ namespace { template __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size, int32 split_dim_size, int32 suffix_dim_size, - CudaDeviceArrayStruct output_ptr_data) { + GpuDeviceArrayStruct output_ptr_data) { const int32 num_split = output_ptr_data.size; - T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data); + T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data); eigen_assert(blockDim.y == 1); eigen_assert(blockDim.z == 1); @@ -111,11 +113,11 @@ __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size, // is reversed template __global__ void split_v_kernel(const T* input_ptr, - CudaDeviceArrayStruct output_scan, + GpuDeviceArrayStruct output_scan, IntType total_rows, IntType total_cols, - CudaDeviceArrayStruct output_ptr_data) { - T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data); - IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan); + GpuDeviceArrayStruct output_ptr_data) { + T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data); + IntType* col_scan = GetGpuDeviceArrayOnDevice(&output_scan); // do upper_bound on col to find which pointer we should be using IntType gidx = blockIdx.x * blockDim.x + threadIdx.x; @@ -167,11 +169,11 @@ __global__ void split_v_kernel(const T* input_ptr, // different from the original split implementation due to 2D vs 3D // dimensions. This version is likely faster due to less integer math. 
template -__global__ void SplitVOpKernel_fixed( - const T* input, int32 prefix_dim_size, int32 suffix_dim_size, - CudaDeviceArrayStruct output_ptr_data) { +__global__ void SplitVOpKernel_fixed(const T* input, int32 prefix_dim_size, + int32 suffix_dim_size, + GpuDeviceArrayStruct output_ptr_data) { const int32 num_split = output_ptr_data.size; - T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data); + T** output_ptrs = GetGpuDeviceArrayOnDevice(&output_ptr_data); eigen_assert(blockDim.y == 1); eigen_assert(blockDim.z == 1); @@ -192,54 +194,53 @@ __global__ void SplitVOpKernel_fixed( } template -struct SplitOpGPULaunch { - void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size, - int32 split_dim_size, int32 suffix_dim_size, - const CudaDeviceArrayStruct& output_ptr_data) { - CudaLaunchConfig config = GetCudaLaunchConfig( - prefix_dim_size * split_dim_size * suffix_dim_size, d); +void SplitOpGPULaunch::Run(const Eigen::GpuDevice& d, const T* input, + int32 prefix_dim_size, int32 split_dim_size, + int32 suffix_dim_size, + const GpuDeviceArrayStruct& output_ptr_data) { + GpuLaunchConfig config = GetCudaLaunchConfig( + prefix_dim_size * split_dim_size * suffix_dim_size, d); - TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel, config.block_count, - config.thread_per_block, 0, d.stream(), input, - prefix_dim_size, split_dim_size, - suffix_dim_size, output_ptr_data)); - } -}; + TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel, config.block_count, + config.thread_per_block, 0, d.stream(), input, + prefix_dim_size, split_dim_size, suffix_dim_size, + output_ptr_data)); +} template -struct SplitVOpGPULaunch { - void Run(const Eigen::GpuDevice& gpu_device, bool fixed_size, - const T* input_ptr, int total_rows, int total_cols, - const CudaDeviceArrayStruct& output_scan, - const CudaDeviceArrayStruct& output_ptr_data) { - if (fixed_size) { - CudaLaunchConfig config = - GetCudaLaunchConfig(total_rows * total_cols, gpu_device); +void SplitVOpGPULaunch::Run( + const Eigen::GpuDevice& gpu_device, bool fixed_size, const T* input_ptr, + int total_rows, int total_cols, + const GpuDeviceArrayStruct& output_scan, + const GpuDeviceArrayStruct& output_ptr_data) { + if (fixed_size) { + GpuLaunchConfig config = + GetGpuLaunchConfig(total_rows * total_cols, gpu_device); - SplitVOpKernel_fixed<<>>( - input_ptr, total_rows, total_cols, output_ptr_data); - } else { - auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device); - IntType smem_max = gpu_device.sharedMemPerBlock(); - IntType smem_usage = output_scan.size * sizeof(IntType); - // performance crossover is less than using maximum available shared - // memory on most processors possibly due to decreasing occupancy - // 4096 inputs is a lot, most code will take the smem path - const int32 kMaxSmemBytesPerformance = 16384; - if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance) - split_v_kernel - <<>>(input_ptr, output_scan, total_rows, - total_cols, output_ptr_data); - else - split_v_kernel - <<>>(input_ptr, output_scan, total_rows, - total_cols, output_ptr_data); - } + TF_CHECK_OK(CudaLaunchKernel(SplitVOpKernel_fixed, config.block_count, + config.thread_per_block, 0, + gpu_device.stream(), input_ptr, total_rows, + total_cols, output_ptr_data)); + } else { + auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device); + IntType smem_max = gpu_device.sharedMemPerBlock(); + IntType smem_usage = output_scan.size * sizeof(IntType); + // performance crossover is less than using maximum available shared + // memory 
on most processors possibly due to decreasing occupancy + // 4096 inputs is a lot, most code will take the smem path + const int32 kMaxSmemBytesPerformance = 16384; + if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance) + TF_CHECK_OK(CudaLaunchKernel( + split_v_kernel, config.block_count, + config.thread_per_block, smem_usage, gpu_device.stream(), input_ptr, + output_scan, total_rows, total_cols, output_ptr_data)); + else + TF_CHECK_OK(CudaLaunchKernel( + split_v_kernel, config.block_count, + config.thread_per_block, 0, gpu_device.stream(), input_ptr, + output_scan, total_rows, total_cols, output_ptr_data)); } -}; +} #define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch; diff --git a/tensorflow/core/kernels/split_lib_gpu.h b/tensorflow/core/kernels/split_lib_gpu.h new file mode 100644 index 00000000000..20feb7df143 --- /dev/null +++ b/tensorflow/core/kernels/split_lib_gpu.h @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ + +#define EIGEN_USE_THREADS +#define EIGEN_USE_GPU + +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/gpu_device_array_gpu.h" +#include "tensorflow/core/kernels/split_lib.h" + +namespace tensorflow { + +template +struct SplitOpGPULaunch { + void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size, + int32 split_dim_size, int32 suffix_dim_size, + const GpuDeviceArrayStruct& output_ptr_data); +}; + +template +struct SplitVOpGPULaunch { + void Run(const Eigen::GpuDevice& d, bool fixed, const T* input, + int total_cols, int total_rows, + const GpuDeviceArrayStruct& output_scan, + const GpuDeviceArrayStruct& output_ptr_data); +}; + +// Explicit instantiations in split_lib_gpu.cu.cc. +#define REGISTER_GPU_KERNEL(T) \ + extern template struct SplitOpGPULaunch; \ + extern template struct SplitVOpGPULaunch; \ + extern template struct SplitVOpGPULaunch; + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); +TF_CALL_complex64(REGISTER_GPU_KERNEL); +TF_CALL_complex128(REGISTER_GPU_KERNEL); +TF_CALL_bfloat16(REGISTER_GPU_KERNEL); +TF_CALL_uint8(REGISTER_GPU_KERNEL); +TF_CALL_bool(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index ed3429ff5cb..a419eedb398 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -29,7 +29,8 @@ limitations under the License. 
#include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" -#include "tensorflow/core/kernels/cuda_device_array.h" +#include "tensorflow/core/kernels/gpu_device_array.h" +#include "tensorflow/core/kernels/split_lib_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -267,13 +268,6 @@ class SplitOpCPU : public SplitOpBase { #if GOOGLE_CUDA -template -struct SplitOpGPULaunch { - void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size, - int32 split_dim_size, int32 suffix_dim_size, - const CudaDeviceArrayStruct& output_ptr_data); -}; - // Partial specialization for GPU template class SplitOpGPU : public SplitOpBase { @@ -308,7 +302,7 @@ class SplitOpGPU : public SplitOpBase { TensorShape output_shape(input_shape); output_shape.set_dim(split_dim, split_dim_output_size); - CudaDeviceArrayOnHost ptrs(context, num_split); + GpuDeviceArrayOnHost ptrs(context, num_split); OP_REQUIRES_OK(context, ptrs.Init()); for (int i = 0; i < num_split; ++i) { diff --git a/tensorflow/core/kernels/split_v_op.cc b/tensorflow/core/kernels/split_v_op.cc index 0324ce9babc..8e53089af0d 100644 --- a/tensorflow/core/kernels/split_v_op.cc +++ b/tensorflow/core/kernels/split_v_op.cc @@ -35,7 +35,8 @@ limitations under the License. #include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" -#include "tensorflow/core/kernels/cuda_device_array.h" +#include "tensorflow/core/kernels/gpu_device_array.h" +#include "tensorflow/core/kernels/split_lib_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -329,14 +330,6 @@ class SplitVOpCPU : public SplitVOpBase { #if GOOGLE_CUDA -template -struct SplitVOpGPULaunch { - void Run(const Eigen::GpuDevice& d, bool fixed, const T* input, - int total_cols, int total_rows, - const CudaDeviceArrayStruct& output_scan, - const CudaDeviceArrayStruct& output_ptr_data); -}; - // Partial specialization for GPU template class SplitVOpGPU : public SplitVOpBase { @@ -373,10 +366,10 @@ class SplitVOpGPU : public SplitVOpBase { // reshape to 2D if (num_split > 16) { - CudaDeviceArrayOnHost ptrs(context, num_split); + GpuDeviceArrayOnHost ptrs(context, num_split); OP_REQUIRES_OK(context, ptrs.Init()); - CudaDeviceArrayOnHost offsets(context, num_split + 1); + GpuDeviceArrayOnHost offsets(context, num_split + 1); OP_REQUIRES_OK(context, offsets.Init()); Tlen offset = 0; diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index 65174e163c1..9c0f370de3b 100644 --- a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -216,7 +216,8 @@ class StageOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp); #endif #ifdef TENSORFLOW_USE_SYCL @@ -249,7 +250,8 @@ class UnstageOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp); #endif #ifdef TENSORFLOW_USE_SYCL @@ -284,7 +286,8 @@ class StagePeekOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), 
StagePeekOp); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) REGISTER_KERNEL_BUILDER( Name("StagePeek").HostMemory("index").Device(DEVICE_GPU), StagePeekOp); #endif @@ -314,7 +317,8 @@ class StageSizeOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) REGISTER_KERNEL_BUILDER(Name("StageSize").HostMemory("size").Device(DEVICE_GPU), StageSizeOp); #endif @@ -339,7 +343,8 @@ class StageClearOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_CPU), StageClearOp); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_GPU), StageClearOp); #endif #ifdef TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/stateful_random_ops.cc b/tensorflow/core/kernels/stateful_random_ops.cc index b664bf1c294..cbbce249a66 100644 --- a/tensorflow/core/kernels/stateful_random_ops.cc +++ b/tensorflow/core/kernels/stateful_random_ops.cc @@ -15,18 +15,20 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/random_op.h" +#include "tensorflow/core/kernels/random_op_cpu.h" #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h" #include "tensorflow/core/kernels/training_op_helpers.h" +#include "tensorflow/core/lib/random/random.h" namespace tensorflow { template struct UpdateVariableAndFill_Philox { void operator()(OpKernelContext* ctx, const CPUDevice& device, - int64 output_size, int64 alg_tag_skip, + Distribution dist, int64 output_size, int64 alg_tag_skip, ScopedUnlockUnrefVar* state_var_guard, Tensor* state_tensor, - typename Distribution::ResultElementType* output_data) { + typename Distribution::ResultElementType* output_data) + UNLOCK_FUNCTION() { auto state_tensor_flat = state_tensor->flat(); auto state_data = state_tensor_flat.data(); // Delegates to PhiloxRandom to do the actual increasing. @@ -35,14 +37,41 @@ struct UpdateVariableAndFill_Philox { // No longer needs the lock. 
state_var_guard->Release(); functor::FillPhiloxRandom()( - ctx, device, philox, output_data, output_size, Distribution()); + ctx, device, philox, output_data, output_size, dist); } }; +Status CheckState(const Tensor& state) { + if (state.dtype() != STATE_ELEMENT_DTYPE) { + return errors::InvalidArgument("dtype of RNG state variable must be ", + DataTypeString(STATE_ELEMENT_DTYPE), + ", not ", DataTypeString(state.dtype())); + } + if (state.dims() != 1) { + return errors::InvalidArgument( + "RNG state must have one and only one dimension, not ", state.dims()); + } + return Status::OK(); +} + +Status CheckPhiloxState(const Tensor& state, int64 alg_tag_skip = 0) { + static_assert(std::is_same::value, + "StateElementType must be int64"); + static_assert(std::is_same::value, + "PhiloxRandom::ResultElementType must be uint32"); + if (state.NumElements() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) { + return errors::InvalidArgument( + "For the Philox algorithm, the size of state" + " must be at least ", + alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ", state.NumElements()); + } + return Status::OK(); +} + template Status UpdateVariableAndFill( - OpKernelContext* ctx, int state_input_idx, bool read_alg_from_state, - Algorithm alg, int64 output_size, + OpKernelContext* ctx, Distribution dist, int state_input_idx, + bool read_alg_from_state, Algorithm alg, int64 output_size, typename Distribution::ResultElementType* output_data) { Var* var = nullptr; TF_RETURN_IF_ERROR( @@ -53,17 +82,7 @@ Status UpdateVariableAndFill( // filling. ScopedUnlockUnrefVar state_var_guard(var); Tensor* var_tensor = var->tensor(); - if (var_tensor->dtype() != STATE_ELEMENT_DTYPE) { - return errors::InvalidArgument("dtype of RNG state variable must be ", - DataTypeString(STATE_ELEMENT_DTYPE), - ", not ", - DataTypeString(var_tensor->dtype())); - } - if (var_tensor->dims() != 1) { - return errors::InvalidArgument( - "RNG state must have one and only one dimension, not ", - var_tensor->dims()); - } + TF_RETURN_IF_ERROR(CheckState(*var_tensor)); auto var_tensor_flat = var_tensor->flat(); int64 alg_tag_skip = 0; if (read_alg_from_state) { @@ -74,21 +93,11 @@ Status UpdateVariableAndFill( alg = var_tensor_flat(0); } if (alg == RNG_ALG_PHILOX) { - static_assert(std::is_same::value, - "StateElementType must be int64"); - static_assert(std::is_same::value, - "PhiloxRandom::ResultElementType must be uint32"); - if (var_tensor_flat.size() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) { - return errors::InvalidArgument( - "For the Philox algorithm, the size of state" - " must be at least ", - alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ", - var_tensor_flat.size()); - } + TF_RETURN_IF_ERROR(CheckPhiloxState(*var_tensor, alg_tag_skip)); TF_RETURN_IF_ERROR(PrepareToUpdateVariable( ctx, var_tensor, var->copy_on_read_mode.load())); UpdateVariableAndFill_Philox()( - ctx, ctx->eigen_device(), output_size, alg_tag_skip, + ctx, ctx->eigen_device(), dist, output_size, alg_tag_skip, &state_var_guard, var_tensor, output_data); return Status::OK(); } else { @@ -98,8 +107,9 @@ Status UpdateVariableAndFill( // Preconditon: input(0) is an existing resource. 
template -void ComputeImpl(OpKernelContext* ctx, int state_input_idx, int shape_input_idx, - bool read_alg_from_state, Algorithm alg) { +void StatefulRandomCompute(OpKernelContext* ctx, Distribution dist, + int state_input_idx, int shape_input_idx, + bool read_alg_from_state, Algorithm alg) { using T = typename Distribution::ResultElementType; const Tensor& shape_t = ctx->input(shape_input_idx); TensorShape shape; @@ -107,8 +117,8 @@ void ComputeImpl(OpKernelContext* ctx, int state_input_idx, int shape_input_idx, Tensor* output; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output)); auto output_flat = output->flat(); - OP_REQUIRES_OK(ctx, UpdateVariableAndFill( - ctx, state_input_idx, read_alg_from_state, alg, + OP_REQUIRES_OK(ctx, UpdateVariableAndFill( + ctx, dist, state_input_idx, read_alg_from_state, alg, output_flat.size(), output_flat.data())); } @@ -118,75 +128,310 @@ class StatefulRandomOp : public OpKernel { explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - ComputeImpl(ctx, 0, 1, true, 0); + StatefulRandomCompute(ctx, Distribution(), 0, 1, true, 0); } }; +template +Status GetScalar(const Tensor& tensor, int input_idx, T* result) { + auto dtype = DataTypeToEnum::v(); + if (tensor.dims() != 0) { + return errors::InvalidArgument("input ", std::to_string(input_idx), + " (0-based) must have shape [], not ", + tensor.shape().DebugString()); + } + if (tensor.dtype() != dtype) { + return errors::InvalidArgument("dtype of input ", std::to_string(input_idx), + " (0-based) must be ", DataTypeString(dtype), + ", not ", DataTypeString(tensor.dtype())); + } + *result = tensor.flat()(0); + return Status::OK(); +} + template class StatefulRandomOpV2 : public OpKernel { public: explicit StatefulRandomOpV2(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { - const Tensor& alg_tensor = ctx->input(1); - OP_REQUIRES(ctx, alg_tensor.dims() == 0, - errors::InvalidArgument("algorithm must be of shape [], not ", - alg_tensor.shape().DebugString())); - OP_REQUIRES( - ctx, alg_tensor.dtype() == ALGORITHM_DTYPE, - errors::InvalidArgument("algorithm's dtype must be ", - DataTypeString(ALGORITHM_DTYPE), ", not ", - DataTypeString(alg_tensor.dtype()))); - auto alg = alg_tensor.flat()(0); - ComputeImpl(ctx, 0, 2, false, alg); + Algorithm alg; + OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg)); + StatefulRandomCompute(ctx, Distribution(), /*state_input_idx=*/0, + /*shape_input_idx=*/2, + /*read_alg_from_state=*/false, alg); } }; +template +class StatefulUniformIntOp : public OpKernel { + public: + explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Algorithm alg; + OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg)); + const Tensor& minval = ctx->input(3); + const Tensor& maxval = ctx->input(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()), + errors::InvalidArgument("minval must be 0-D, got shape ", + minval.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()), + errors::InvalidArgument("maxval must be 0-D, got shape ", + maxval.shape().DebugString())); + + // Verify that minval < maxval. This check intentionally happens after the + // early exit for empty output. Zero impossible things are fine. 
+ IntType lo = minval.scalar()(); + IntType hi = maxval.scalar()(); + OP_REQUIRES( + ctx, lo < hi, + errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi)); + + // Build distribution + typedef random::UniformDistribution + Distribution; + Distribution dist(lo, hi); + + StatefulRandomCompute(ctx, dist, /*state_input_idx=*/0, + /*shape_input_idx=*/2, + /*read_alg_from_state=*/false, alg); + } +}; + +template +class StatefulUniformFullIntOp : public OpKernel { + public: + explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Algorithm alg; + OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg)); + StatefulRandomCompute( + ctx, + random::UniformFullIntDistribution(), + /*state_input_idx=*/0, /*shape_input_idx=*/2, + /*read_alg_from_state=*/false, alg); + } +}; + +template <> +struct RngSkip_Philox { + void operator()(const CPUDevice& device, int64 delta, Tensor* state_tensor) { + auto state_data = state_tensor->flat().data(); + // Delegates to PhiloxRandom to do the actual increasing. + auto philox = GetPhiloxRandomFromMem(state_data); + UpdateMemWithPhiloxRandom(philox, delta, state_data); + } +}; + +template +class RngSkipOp : public OpKernel { + public: + explicit RngSkipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + auto state_input_idx = 0; + Algorithm alg; + OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg)); + int64 delta; + OP_REQUIRES_OK(ctx, GetScalar(ctx->input(2), 2, &delta)); + Var* var = nullptr; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var)); + ScopedUnlockUnrefVar state_var_guard(var); + Tensor* var_tensor = var->tensor(); + OP_REQUIRES_OK(ctx, CheckState(*var_tensor)); + if (alg == RNG_ALG_PHILOX) { + OP_REQUIRES_OK(ctx, CheckPhiloxState(*var_tensor)); + OP_REQUIRES_OK(ctx, PrepareToUpdateVariable( + ctx, var_tensor, var->copy_on_read_mode.load())); + RngSkip_Philox()(ctx->eigen_device(), delta, var_tensor); + } else { + OP_REQUIRES(ctx, false, + errors::InvalidArgument("Unsupported algorithm id: ", alg)); + } + } +}; + +template +class NonDeterministicIntsOp : public OpKernel { + public: + explicit NonDeterministicIntsOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& shape_t = ctx->input(0); + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->op_kernel().MakeShape(shape_t, &shape)); + Tensor* output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output)); + if (shape.num_elements() == 0) return; + + switch (dtype_) { + case DT_INT32: + case DT_UINT32: + case DT_INT64: + case DT_UINT64: { + auto output_flat = output->flat(); + auto data = output_flat.data(); + for (int64 i = 0; i < output_flat.size(); ++i) { + data[i] = static_cast(random::New64()); + } + break; + } + default: + OP_REQUIRES(ctx, false, + errors::InvalidArgument("Unsupported dtype: ", + DataTypeString(dtype_))); + } + } + + private: + DataType dtype_; +}; + // So far the 'Distribution' type parameter is only used when the algorithm is // philox, so 'NormalDistribution' is fine for now. 
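[Editorial aside, not part of this change.] For readers unfamiliar with these Distribution functors, a minimal host-side sketch of how one is typically driven: each call consumes the counter-based Philox generator and returns a fixed-size group of samples. This assumes a TensorFlow source tree for the two includes; the seed value and function name are arbitrary.

    #include "tensorflow/core/lib/random/philox_random.h"
    #include "tensorflow/core/lib/random/random_distributions.h"

    // Draw one group of uniform floats in [0, 1) and sum them.
    float SumOfOneSampleGroup() {
      using tensorflow::random::PhiloxRandom;
      using Distribution =
          tensorflow::random::UniformDistribution<PhiloxRandom, float>;
      PhiloxRandom gen(/*seed=*/42);
      Distribution dist;
      const auto samples = dist(&gen);  // group of kResultElementCount floats
      float sum = 0.0f;
      for (int i = 0; i < Distribution::kResultElementCount; ++i) {
        sum += samples[i];  // each sample is uniform in [0, 1)
      }
      return sum;
    }
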
-#define REGISTER(DEVICE, TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatefulStandardNormalV2") \ - .Device(DEVICE_##DEVICE) \ - .HostMemory("resource") \ - .HostMemory("algorithm") \ - .HostMemory("shape") \ - .TypeConstraint("dtype"), \ - StatefulRandomOpV2 >); +#define REGISTER_FloatOps(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("StatefulStandardNormalV2") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + StatefulRandomOpV2 >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatefulUniform") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + StatefulRandomOpV2 >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatefulTruncatedNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + StatefulRandomOpV2< \ + DEVICE##Device, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE> >); -// CPU also has the old 'StatefulStandardNormal' op for backward compatibility. -#define REGISTER_CPU(TYPE) \ - REGISTER(CPU, TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatefulStandardNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("resource") \ - .HostMemory("shape") \ - .TypeConstraint("dtype"), \ - StatefulRandomOp("dtype"), \ + StatefulRandomOp >); -#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE) +#define REGISTER_FloatOps_GPU(TYPE) REGISTER_FloatOps(GPU, TYPE) -TF_CALL_half(REGISTER_CPU); -TF_CALL_bfloat16(REGISTER_CPU); -TF_CALL_float(REGISTER_CPU); -TF_CALL_double(REGISTER_CPU); +TF_CALL_half(REGISTER_FloatOps_CPU); +TF_CALL_bfloat16(REGISTER_FloatOps_CPU); +TF_CALL_float(REGISTER_FloatOps_CPU); +TF_CALL_double(REGISTER_FloatOps_CPU); + +#define REGISTER_StatefulUniformInt(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatefulUniformInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint("dtype"), \ + StatefulUniformIntOp); + +#define REGISTER_StatefulUniformInt_CPU(TYPE) \ + REGISTER_StatefulUniformInt(CPU, TYPE) +#define REGISTER_StatefulUniformInt_GPU(TYPE) \ + REGISTER_StatefulUniformInt(GPU, TYPE) + +TF_CALL_int32(REGISTER_StatefulUniformInt_CPU); +TF_CALL_int64(REGISTER_StatefulUniformInt_CPU); + +#define REGISTER_StatefulUniformFullInt(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatefulUniformFullInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + StatefulUniformFullIntOp); + +#define REGISTER_StatefulUniformFullInt_CPU(TYPE) \ + REGISTER_StatefulUniformFullInt(CPU, TYPE) +#define REGISTER_StatefulUniformFullInt_GPU(TYPE) \ + REGISTER_StatefulUniformFullInt(GPU, TYPE) + +TF_CALL_int32(REGISTER_StatefulUniformFullInt_CPU); +TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU); +TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU); +TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU); + +#define REGISTER_RngSkip(DEVICE) \ + REGISTER_KERNEL_BUILDER(Name("RngSkip") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("resource") \ + .HostMemory("algorithm") \ + .HostMemory("delta"), \ + RngSkipOp); + +REGISTER_RngSkip(CPU); #if GOOGLE_CUDA -TF_CALL_half(REGISTER_GPU); -TF_CALL_float(REGISTER_GPU); -TF_CALL_double(REGISTER_GPU); +TF_CALL_half(REGISTER_FloatOps_GPU); 
+TF_CALL_float(REGISTER_FloatOps_GPU); +TF_CALL_double(REGISTER_FloatOps_GPU); +TF_CALL_int32(REGISTER_StatefulUniformInt_GPU); +TF_CALL_int64(REGISTER_StatefulUniformInt_GPU); +TF_CALL_int32(REGISTER_StatefulUniformFullInt_GPU); +TF_CALL_int64(REGISTER_StatefulUniformFullInt_GPU); +TF_CALL_uint32(REGISTER_StatefulUniformFullInt_GPU); +TF_CALL_uint64(REGISTER_StatefulUniformFullInt_GPU); +REGISTER_RngSkip(GPU); #endif // GOOGLE_CUDA -#undef REGISTER_GPU -#undef REGISTER_CPU -#undef REGISTER +#undef REGISTER_StatefulUniformFullInt_GPU +#undef REGISTER_StatefulUniformFullInt_CPU +#undef REGISTER_StatefulUniformFullInt +#undef REGISTER_StatefulUniformInt_GPU +#undef REGISTER_StatefulUniformInt_CPU +#undef REGISTER_StatefulUniformInt +#undef REGISTER_FloatOps_GPU +#undef REGISTER_FloatOps_CPU +#undef REGISTER_FloatOps + +#define REGISTER_NonDeterministicInts(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("NonDeterministicInts") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("dtype"), \ + NonDeterministicIntsOp); + +TF_CALL_int32(REGISTER_NonDeterministicInts); +TF_CALL_uint32(REGISTER_NonDeterministicInts); +TF_CALL_int64(REGISTER_NonDeterministicInts); +TF_CALL_uint64(REGISTER_NonDeterministicInts); + +#undef REGISTER_NonDeterministicInts // TODO(wangpeng): Add RNG ops for other distributions. -// TODO(wangpeng): Add support for XLA. } // end namespace tensorflow diff --git a/tensorflow/core/kernels/stateful_random_ops.h b/tensorflow/core/kernels/stateful_random_ops.h index 25d0ce7dfe5..58ab41426f1 100644 --- a/tensorflow/core/kernels/stateful_random_ops.h +++ b/tensorflow/core/kernels/stateful_random_ops.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_ #define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_ -// #include "tensorflow/core/framework/resource_var.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/random/philox_random.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h b/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h index a07f48a71cb..f3d966b6d64 100644 --- a/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h +++ b/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h @@ -27,24 +27,21 @@ namespace tensorflow { // The following 2 functions use the contract "lower 32 bits for the first // uint32, higher 32 bits for the second". Note that this is endian-neutral, // unlike a direct memory copy `memcpy(output, &input, 8)`. 
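[Editorial aside, not part of this change.] The lower/higher-32-bit contract described in the comment above round-trips as in the following standalone sketch; the names here are hypothetical stand-ins for the helpers that follow, kept free of TensorFlow types so it compiles on its own.

    #include <cassert>
    #include <cstdint>

    // Pack an int64 into two uint32s: lower 32 bits first, higher 32 bits second.
    void PackInt64(int64_t input, uint32_t* lo, uint32_t* hi) {
      const uint64_t u = static_cast<uint64_t>(input);
      *lo = static_cast<uint32_t>(u);
      *hi = static_cast<uint32_t>(u >> 32);
    }

    // Reassemble the int64 from the two halves; independent of host endianness.
    int64_t UnpackInt64(uint32_t lo, uint32_t hi) {
      return static_cast<int64_t>(static_cast<uint64_t>(lo) |
                                  (static_cast<uint64_t>(hi) << 32));
    }

    int main() {
      uint32_t lo, hi;
      PackInt64(0x0123456789ABCDEFLL, &lo, &hi);
      assert(lo == 0x89ABCDEFu && hi == 0x01234567u);        // halves as specified
      assert(UnpackInt64(lo, hi) == 0x0123456789ABCDEFLL);   // round-trip holds
      return 0;
    }
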
-template -PHILOX_DEVICE_FUNC void Int64ToUint32s(INT64 input, uint32* output1, - uint32* output2) { +PHILOX_DEVICE_INLINE void Int64ToUint32s(int64 input, uint32* output1, + uint32* output2) { auto u64 = static_cast(input); *output1 = static_cast(u64); *output2 = static_cast(u64 >> 32); } -template -PHILOX_DEVICE_FUNC int64 Uint32sToInt64(UINT32 input1, UINT32 input2) { +PHILOX_DEVICE_INLINE int64 Uint32sToInt64(uint32 input1, uint32 input2) { auto u64_1 = static_cast(input1); auto u64_2 = static_cast(input2); return static_cast(u64_1 | (u64_2 << 32)); } -template -PHILOX_DEVICE_FUNC PhiloxRandom -GetPhiloxRandomFromMem(STATE_ELEMENT_TYPE const* ptr) { +PHILOX_DEVICE_INLINE PhiloxRandom +GetPhiloxRandomFromMem(StateElementType const* ptr) { PhiloxRandom::ResultType counter; PhiloxRandom::Key key; Int64ToUint32s(ptr[0], &counter[0], &counter[1]); @@ -53,9 +50,8 @@ GetPhiloxRandomFromMem(STATE_ELEMENT_TYPE const* ptr) { return PhiloxRandom(counter, key); } -template -PHILOX_DEVICE_FUNC void WritePhiloxRandomToMem(PHILOX_RANDOM const& philox, - StateElementType* ptr) { +PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox, + StateElementType* ptr) { PhiloxRandom::ResultType const& counter = philox.counter(); PhiloxRandom::Key const& key = philox.key(); ptr[0] = Uint32sToInt64(counter[0], counter[1]); @@ -63,10 +59,9 @@ PHILOX_DEVICE_FUNC void WritePhiloxRandomToMem(PHILOX_RANDOM const& philox, ptr[2] = Uint32sToInt64(key[0], key[1]); } -template -PHILOX_DEVICE_FUNC void UpdateMemWithPhiloxRandom(PHILOX_RANDOM const& philox, - int64 output_size, - StateElementType* ptr) { +PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox, + int64 output_size, + StateElementType* ptr) { auto new_philox = philox; // Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change // it just here. @@ -82,21 +77,29 @@ PHILOX_DEVICE_FUNC void UpdateMemWithPhiloxRandom(PHILOX_RANDOM const& philox, template struct UpdateVariableAndFill_Philox; +template +struct RngSkip_Philox; + using CPUDevice = Eigen::ThreadPoolDevice; #if GOOGLE_CUDA using GPUDevice = Eigen::GpuDevice; -// Declares the partially GPU-specialized functor struct. +// Declares the partially GPU-specialized functor structs. template struct UpdateVariableAndFill_Philox { void operator()(OpKernelContext* ctx, const GPUDevice& device, - int64 output_size, int64 alg_tag_skip, + Distribution dist, int64 output_size, int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor, typename Distribution::ResultElementType* output_data); }; +template <> +struct RngSkip_Philox { + void operator()(const GPUDevice& device, int64 delta, Tensor* state_tensor); +}; + #endif // GOOGLE_CUDA } // end namespace tensorflow diff --git a/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc b/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc index 99ce3e677d8..8d6e826d625 100644 --- a/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/stateful_random_ops_gpu.cu.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/random_op_gpu.h" #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h" -#include "tensorflow/core/util/cuda_launch_config.h" +#include "tensorflow/core/util/gpu_launch_config.h" namespace tensorflow { @@ -53,8 +53,9 @@ __global__ void FillKernel( template void UpdateVariableAndFill_Philox::operator()( - OpKernelContext* ctx, const GPUDevice& d, int64 output_size, - int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, Tensor* state_tensor, + OpKernelContext* ctx, const GPUDevice& d, Distribution dist, + int64 output_size, int64 alg_tag_skip, ScopedUnlockUnrefVar* not_used, + Tensor* state_tensor, typename Distribution::ResultElementType* output_data) { OP_REQUIRES( ctx, alg_tag_skip == 0, @@ -69,15 +70,26 @@ void UpdateVariableAndFill_Philox::operator()( // maximize occupancy const int kGroupSize = Distribution::kResultElementCount; int work_element_count = (output_size + kGroupSize - 1) / kGroupSize; - CudaLaunchConfig cfg = GetCudaLaunchConfig(work_element_count, d, - FillKernel, 0, 0); + GpuLaunchConfig cfg = GetCudaLaunchConfig(work_element_count, d, + FillKernel, 0, 0); int zero = 0; cudaMemcpyToSymbol(thread_counter, &zero, sizeof(int)); - TF_CHECK_OK(CudaLaunchKernel(FillKernel, cfg.block_count, - cfg.thread_per_block, 0, d.stream(), - Distribution(), state_size, output_size, - state_data, output_data)); + TF_CHECK_OK(CudaLaunchKernel( + FillKernel, cfg.block_count, cfg.thread_per_block, 0, + d.stream(), dist, state_size, output_size, state_data, output_data)); +} + +// Precondition: there is only 1 block and 1 thread. +__global__ void SkipKernel(int64 delta, StateElementType* state_data) { + auto philox = GetPhiloxRandomFromMem(state_data); + UpdateMemWithPhiloxRandom(philox, delta, state_data); +} + +void RngSkip_Philox::operator()(const GPUDevice& d, int64 delta, + Tensor* state_tensor) { + SkipKernel<<<1, 1, 0, d.stream()>>>( + delta, state_tensor->flat().data()); } // Explicit instantiation of the GPU distributions functors. 
@@ -90,6 +102,40 @@ template struct UpdateVariableAndFill_Philox< GPUDevice, random::NormalDistribution >; template struct UpdateVariableAndFill_Philox< GPUDevice, random::NormalDistribution >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::TruncatedNormalDistribution< + random::SingleSampleAdapter, + Eigen::half> >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::TruncatedNormalDistribution< + random::SingleSampleAdapter, + float> >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::TruncatedNormalDistribution< + random::SingleSampleAdapter, + double> >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformDistribution >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformDistribution >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformDistribution >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformDistribution >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformDistribution >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformFullIntDistribution< + random::PhiloxRandom, int32> >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformFullIntDistribution< + random::PhiloxRandom, int64> >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformFullIntDistribution< + random::PhiloxRandom, uint32> >; +template struct UpdateVariableAndFill_Philox< + GPUDevice, random::UniformFullIntDistribution< + random::PhiloxRandom, uint64> >; // clang-format on } // end namespace tensorflow diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 20bf42ccaa2..03169965b14 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -22,16 +22,17 @@ limitations under the License. #endif // GOOGLE_CUDA #include "tensorflow/core/kernels/strided_slice_op.h" -#include "tensorflow/core/kernels/dense_update_functor.h" -#include "tensorflow/core/kernels/slice_op.h" -#include "tensorflow/core/kernels/strided_slice_op_impl.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/dense_update_functor.h" +#include "tensorflow/core/kernels/inplace_ops_functor.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/slice_op.h" +#include "tensorflow/core/kernels/strided_slice_op_impl.h" #include "tensorflow/core/kernels/training_op_helpers.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/status.h" @@ -123,11 +124,8 @@ class StridedSliceOp : public OpKernel { "Input must have rank at least 1, got: ", input.dims())); // Otherwise, is_identity should be true. VLOG(1) << "Strided slice dim 0: " << input.shape().DebugString(); - OP_REQUIRES( - context, begin[0] <= end[0], - errors::InvalidArgument("begin[0] (", begin[0], - ") must less or equal to end[0] (", end[0])); - Tensor slice = input.Slice(begin[0], end[0]); + // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end). 
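The clamped Slice call appears just below; a standalone illustration of why taking std::min(begin, end) turns a begin > end request into an empty result instead of an error (plain vectors, not TensorFlow code):

```c++
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// A strided slice whose dim-0 bounds resolve to begin > end should produce a
// 0-row result, not an error.  Clamping begin to std::min(begin, end) makes a
// half-open range [begin, end) that is simply empty in that case.
std::vector<int> SliceDim0(const std::vector<int>& input, int64_t begin,
                           int64_t end) {
  begin = std::min(begin, end);
  return std::vector<int>(input.begin() + begin, input.begin() + end);
}

int main() {
  const std::vector<int> v = {1, 2, 3, 4};
  assert(SliceDim0(v, 1, 3).size() == 2);
  assert(SliceDim0(v, 3, 1).empty());  // begin > end -> zero elements
}
```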
+ Tensor slice = input.Slice(std::min(begin[0], end[0]), end[0]); Tensor tmp; OP_REQUIRES(context, tmp.CopyFrom(slice, final_shape), errors::Internal("Copy failed")); @@ -278,7 +276,7 @@ class StridedSliceGradOp : public OpKernel { int32 ellipsis_mask, new_axis_mask, shrink_axis_mask; }; -template +template class StridedSliceAssignOp : public OpKernel { public: explicit StridedSliceAssignOp(OpKernelConstruction* context) @@ -302,24 +300,47 @@ class StridedSliceAssignOp : public OpKernel { Tensor* old_lhs = nullptr; Tensor tmp; - if (context->input_dtype(0) == DT_RESOURCE) { - Var* v; - OP_REQUIRES_OK(context, - LookupResource(context, HandleFromInput(context, 0), &v)); - core::ScopedUnref scoped_unref(v); - OP_REQUIRES_OK(context, - EnsureSparseVariableAccess(context, v)); - mutex_lock ml(*v->mu()); - old_lhs = v->tensor(); - OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum::value, - errors::InvalidArgument( - "l-value dtype ", DataTypeString(old_lhs->dtype()), - " does not match r-value dtype ", - DataTypeString(DataTypeToEnum::value))); + if (isTensor) { + const Tensor& input = context->input(0); + TensorShape shape = input.shape(); + + std::unique_ptr forwarded_input = context->forward_input( + 0, 0, input.dtype(), shape, DEVICE_MEMORY, AllocatorAttributes()); + + if (forwarded_input == nullptr) { + Tensor* out; + // We were not able to forward the input, so we deep copy the tensor and + // set the output. + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &out)); + + OP_REQUIRES_OK(context, + tensorflow::functor::DoCopy( + context->eigen_device(), input, out)); + old_lhs = out; + } else { + old_lhs = forwarded_input.get(); + } } else { - context->forward_ref_input_to_ref_output(0, 0); - tmp = context->mutable_input(0, true); - old_lhs = &tmp; + if (context->input_dtype(0) == DT_RESOURCE) { + Var* v; + OP_REQUIRES_OK( + context, LookupResource(context, HandleFromInput(context, 0), &v)); + core::ScopedUnref scoped_unref(v); + OP_REQUIRES_OK(context, + EnsureSparseVariableAccess(context, v)); + mutex_lock ml(*v->mu()); + old_lhs = v->tensor(); + OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum::value, + errors::InvalidArgument( + "l-value dtype ", DataTypeString(old_lhs->dtype()), + " does not match r-value dtype ", + DataTypeString(DataTypeToEnum::value))); + } else { + context->forward_ref_input_to_ref_output(0, 0); + tmp = context->mutable_input(0, true); + old_lhs = &tmp; + } } OP_REQUIRES_OK( @@ -376,37 +397,44 @@ class StridedSliceAssignOp : public OpKernel { int32 ellipsis_mask, new_axis_mask, shrink_axis_mask; }; -#define REGISTER_STRIDED_SLICE(type) \ - REGISTER_KERNEL_BUILDER(Name("StridedSlice") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceOp) \ - REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("shape") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceGradOp) \ - REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceAssignOp) \ - REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("ref") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceAssignOp) +#define REGISTER_STRIDED_SLICE(type) \ + 
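In the new TensorStridedSliceUpdate path above, the kernel first asks the context to forward the input buffer into the output and only falls back to a deep copy when forwarding is refused, so the in-place update always mutates a buffer the caller cannot observe. A loose analogy of that forward-or-copy decision in terms of std::shared_ptr ownership (`ForwardOrCopy` is illustrative, not the TensorFlow API):

```c++
#include <cassert>
#include <memory>
#include <vector>

using Buffer = std::vector<float>;

// The buffer may be reused for the output only when nobody else can still
// read it; otherwise a private copy is made before mutation.
std::shared_ptr<Buffer> ForwardOrCopy(std::shared_ptr<Buffer> input) {
  if (input.use_count() == 1) {
    return input;                           // forwarded: no copy made
  }
  return std::make_shared<Buffer>(*input);  // deep copy, then mutate
}

int main() {
  auto a = std::make_shared<Buffer>(Buffer{1, 2, 3});
  auto out1 = ForwardOrCopy(std::move(a));  // sole owner -> forwarded

  auto b = std::make_shared<Buffer>(Buffer{1, 2, 3});
  auto still_alive = b;                     // another reader exists
  auto out2 = ForwardOrCopy(b);             // -> deep copy
  (*out2)[0] = 42;
  assert((*still_alive)[0] == 1);           // original left untouched
}
```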
REGISTER_KERNEL_BUILDER(Name("StridedSlice") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceOp) \ + REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("shape") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceGradOp) \ + REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) \ + REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("ref") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) \ + REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE); @@ -414,37 +442,44 @@ TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE); #if GOOGLE_CUDA -#define REGISTER_GPU(type) \ - REGISTER_KERNEL_BUILDER(Name("StridedSlice") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceOp) \ - REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .HostMemory("shape") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceGradOp) \ - REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceAssignOp) \ - REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T") \ - .HostMemory("ref") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceAssignOp) +#define REGISTER_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("StridedSlice") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceOp) \ + REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("shape") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceGradOp) \ + REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) \ + REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("ref") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) \ + REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); TF_CALL_bool(REGISTER_GPU); @@ -482,7 +517,7 @@ REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") .HostMemory("begin") .HostMemory("end") .HostMemory("strides"), - StridedSliceAssignOp) + StridedSliceAssignOp); REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") .Device(DEVICE_GPU) .TypeConstraint("T") @@ -490,43 
+525,58 @@ REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") .HostMemory("begin") .HostMemory("end") .HostMemory("strides"), - StridedSliceAssignOp) + StridedSliceAssignOp); +REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("input") + .HostMemory("begin") + .HostMemory("end") + .HostMemory("strides"), + StridedSliceAssignOp); #undef REGISTER_GPU #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL(type) \ - REGISTER_KERNEL_BUILDER(Name("StridedSlice") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceOp) \ - REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T") \ - .HostMemory("shape") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceGradOp) \ - REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceAssignOp) \ - REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T") \ - .HostMemory("ref") \ - .HostMemory("begin") \ - .HostMemory("end") \ - .HostMemory("strides"), \ - StridedSliceAssignOp) +#define REGISTER_SYCL(type) \ + REGISTER_KERNEL_BUILDER(Name("StridedSlice") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceOp) \ + REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("shape") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceGradOp) \ + REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) \ + REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("ref") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) \ + REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides"), \ + StridedSliceAssignOp) TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL); @@ -556,7 +606,7 @@ REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") .HostMemory("begin") .HostMemory("end") .HostMemory("strides"), - StridedSliceAssignOp) + StridedSliceAssignOp); REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") .Device(DEVICE_SYCL) .TypeConstraint("T") @@ -564,7 +614,14 @@ REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign") .HostMemory("begin") .HostMemory("end") .HostMemory("strides"), - StridedSliceAssignOp) + StridedSliceAssignOp); +REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .HostMemory("begin") + .HostMemory("end") + .HostMemory("strides"), + StridedSliceAssignOp) #undef REGISTER_SYCL #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/string_lower_op.cc b/tensorflow/core/kernels/string_lower_op.cc new file mode 100644 index 00000000000..e24eedcc3ae --- /dev/null +++ b/tensorflow/core/kernels/string_lower_op.cc @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/string_ops.cc. + +#include + +#include "absl/strings/ascii.h" +#include "unicode/unistr.h" // TF:icu +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +class StringLowerOp : public OpKernel { + public: + explicit StringLowerOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("encoding", &encoding_)); + OP_REQUIRES(context, encoding_.empty() || encoding_ == "utf-8", + errors::InvalidArgument( + "only utf-8 or '' (no encoding) is supported, received ", + encoding_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + Tensor* output_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor)); + + const auto input = input_tensor->flat(); + auto output = output_tensor->flat(); + + if (encoding_.empty()) { + for (int64 i = 0; i < input.size(); ++i) { + StringPiece entry(input(i)); + output(i) = absl::AsciiStrToLower(entry); + } + } else { + // The validation of utf-8 has already been done in GetAttr above. + for (int64 i = 0; i < input.size(); ++i) { + icu::UnicodeString us(input(i).c_str(), "UTF-8"); + us.toLower(); + us.toUTF8String(output(i)); + } + } + } + + private: + string encoding_; +}; + +REGISTER_KERNEL_BUILDER(Name("StringLower").Device(DEVICE_CPU), StringLowerOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/string_upper_op.cc b/tensorflow/core/kernels/string_upper_op.cc new file mode 100644 index 00000000000..f2a1d33e7a6 --- /dev/null +++ b/tensorflow/core/kernels/string_upper_op.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/string_ops.cc. 
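The StringLower kernel above (and the StringUpper kernel that follows) picks between a byte-wise ASCII path and an ICU path based on the `encoding` attribute: the ASCII path only maps the bytes A-Z/a-z, while the UTF-8 path applies full Unicode case mapping. A small standalone sketch of the two paths, assuming Abseil and ICU are available to link against:

```c++
#include <string>

#include "absl/strings/ascii.h"
#include "unicode/unistr.h"

// Byte-wise path: only [A-Z] <-> [a-z]; multi-byte UTF-8 sequences pass through.
std::string AsciiLower(const std::string& s) { return absl::AsciiStrToLower(s); }

// Unicode path: full case mapping via ICU.
std::string Utf8Lower(const std::string& s) {
  icu::UnicodeString us(s.c_str(), "UTF-8");
  us.toLower();
  std::string out;
  us.toUTF8String(out);
  return out;
}

// Utf8Lower("ÉCOLE") == "école", whereas AsciiLower("ÉCOLE") == "École":
// the non-ASCII 'É' is left untouched by the byte-wise path.
```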
+ +#include + +#include "absl/strings/ascii.h" +#include "unicode/unistr.h" // TF:icu +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +class StringUpperOp : public OpKernel { + public: + explicit StringUpperOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("encoding", &encoding_)); + OP_REQUIRES(context, encoding_.empty() || encoding_ == "utf-8", + errors::InvalidArgument( + "only utf-8 or '' (no encoding) is supported, received ", + encoding_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + Tensor* output_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor)); + + const auto input = input_tensor->flat(); + auto output = output_tensor->flat(); + if (encoding_.empty()) { + for (int64 i = 0; i < input.size(); ++i) { + StringPiece entry(input(i)); + output(i) = absl::AsciiStrToUpper(entry); + } + } else { + // The validation of utf-8 has already been done in GetAttr above. + for (int64 i = 0; i < input.size(); ++i) { + icu::UnicodeString us(input(i).c_str(), "UTF-8"); + us.toUpper(); + us.toUTF8String(output(i)); + } + } + } + + private: + string encoding_; +}; + +REGISTER_KERNEL_BUILDER(Name("StringUpper").Device(DEVICE_CPU), StringUpperOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/string_view_variant_wrapper.h b/tensorflow/core/kernels/string_view_variant_wrapper.h deleted file mode 100644 index dc4a8e95348..00000000000 --- a/tensorflow/core/kernels/string_view_variant_wrapper.h +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_KERNELS_STRING_VIEW_VARIANT_WRAPPER_H_ -#define TENSORFLOW_CORE_KERNELS_STRING_VIEW_VARIANT_WRAPPER_H_ - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/core/framework/variant_tensor_data.h" - -namespace tensorflow { - -// A wrapper class for storing an `absl::string_view` instance in a DT_VARIANT -// tensor. 
-class StringViewVariantWrapper { - public: - static constexpr const char kTypeName[] = - "tensorflow::StringViewVariantWrapper"; - - using value_type = absl::string_view; - - StringViewVariantWrapper() = default; - - explicit StringViewVariantWrapper(absl::string_view str_view) - : str_view_(str_view) {} - - StringViewVariantWrapper(const StringViewVariantWrapper& other) - : str_view_(other.str_view_) {} - - const absl::string_view* get() const { return &str_view_; } - - static string TypeName() { return kTypeName; } - - string DebugString() const { return string(str_view_); } - - void Encode(VariantTensorData* data) const { - data->add_tensor(string(str_view_)); - } - - // Decode assumes that the source VariantTensorData will have a longer - // lifetime than this StringViewVariantWrapper. - bool Decode(const VariantTensorData& data) { - if (data.tensors_size() != 1 || data.tensors(0).dtype() != DT_STRING) { - return false; - } - str_view_ = data.tensors(0).scalar()(); - return true; - } - - private: - absl::string_view str_view_; -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_STRING_VIEW_VARIANT_WRAPPER_H_ diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index d33c0cdb7f0..6ea9e19018e 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/summary/schema.h" @@ -147,6 +148,43 @@ class WriteSummaryOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU), WriteSummaryOp); +class WriteRawProtoSummaryOp : public OpKernel { + public: + explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->input("step", &tmp)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), + errors::InvalidArgument("step must be scalar, got shape ", + tmp->shape().DebugString())); + const int64 step = tmp->scalar()(); + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); + std::unique_ptr event{new Event}; + event->set_step(step); + event->set_wall_time(static_cast(ctx->env()->NowMicros()) / 1.0e6); + // Each Summary proto contains just one repeated field "value" of Value + // messages with the actual data, so repeated Merge() is equivalent to + // concatenating all the Value entries together into a single Event. 
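Because Summary carries a single repeated `value` field, merging each serialized tensor element into one message, as the loop just below does, is equivalent to concatenating all the Value entries. A hedged sketch of that merge step, assuming the generated tensorflow::Summary class from summary.proto is on the include path:

```c++
#include <string>
#include <vector>

#include "tensorflow/core/framework/summary.pb.h"

// Merges several serialized Summary protos into one.  Proto merge semantics
// append repeated fields, so the result carries every Value entry in order.
// Returns false if any element is not a valid binary Summary proto.
bool MergeSummaries(const std::vector<std::string>& serialized,
                    tensorflow::Summary* merged) {
  for (const std::string& s : serialized) {
    if (!merged->MergeFromString(s)) return false;
  }
  return true;
}
```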
+ const auto summary_pbs = t->flat(); + for (int i = 0; i < summary_pbs.size(); ++i) { + if (!event->mutable_summary()->MergeFromString(summary_pbs(i))) { + ctx->CtxFailureWithWarning(errors::DataLoss( + "Bad tf.compat.v1.Summary binary proto tensor string at index ", + i)); + return; + } + } + OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event))); + } +}; +REGISTER_KERNEL_BUILDER(Name("WriteRawProtoSummary").Device(DEVICE_CPU), + WriteRawProtoSummaryOp); + class ImportEventOp : public OpKernel { public: explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc index 9e308cfc023..3f51820cd55 100644 --- a/tensorflow/core/kernels/svd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc @@ -43,7 +43,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -74,7 +74,7 @@ __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m, // Extracts the sign of V // V[i] = V[i]>=0 ? 1 : 0 template -__global__ void ExtractSignOfVKernel(CudaLaunchConfig config, Scalar* V) { +__global__ void ExtractSignOfVKernel(GpuLaunchConfig config, Scalar* V) { CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) { V[i] = V[i] >= 0 ? Scalar(1) : Scalar(-1); } @@ -196,14 +196,16 @@ class SvdOpGpu : public AsyncOpKernel { const GPUDevice& d = context->eigen_device(); d.memset(outputV_ptr, 0, batch_size * sizeof(Scalar)); Cuda2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(batch_size, m, d); - ComputeValueOfVKernel<<>>( - cfg2D, m, full_matrices_ ? m : p, input_copy.flat().data(), - outputU_ptr, outputS_ptr, outputV_ptr); + TF_CHECK_OK(CudaLaunchKernel(ComputeValueOfVKernel, + cfg2D.block_count, cfg2D.thread_per_block, 0, + d.stream(), cfg2D, m, full_matrices_ ? m : p, + input_copy.flat().data(), + outputU_ptr, outputS_ptr, outputV_ptr)); // 2. clamp V to -1 or +1 - CudaLaunchConfig cfg1D = GetCudaLaunchConfig(batch_size, d); - ExtractSignOfVKernel<<>>(cfg1D, outputV_ptr); + GpuLaunchConfig cfg1D = GetCudaLaunchConfig(batch_size, d); + TF_CHECK_OK(CudaLaunchKernel(ExtractSignOfVKernel, + cfg1D.block_count, cfg1D.thread_per_block, 0, + d.stream(), cfg1D, outputV_ptr)); } if (compute_uv_) { diff --git a/tensorflow/core/kernels/tensor_flag_utils.h b/tensorflow/core/kernels/tensor_flag_utils.h index f406c73a297..ab59eecc256 100644 --- a/tensorflow/core/kernels/tensor_flag_utils.h +++ b/tensorflow/core/kernels/tensor_flag_utils.h @@ -36,7 +36,7 @@ std::vector ParseRowStartIndices( // Returns Status::OK() if and only if config is a float scalar or a matrix with // dimensions M x 3. If config is a scalar then config must be in the range -// [0, 1.0). If confix is a matrix then config must have shape M x 3, all of +// [0, 1.0). If config is a matrix then config must have shape M x 3, all of // its entries must be positive, and entries in the last column may not // exceed 1.0. If config is a matrix then it may not be empty. 
Status ValidateSparseMatrixShardingConfig(const Tensor& config); diff --git a/tensorflow/core/kernels/tile_functor.h b/tensorflow/core/kernels/tile_functor.h index 9a460d191fc..d41b17459e6 100644 --- a/tensorflow/core/kernels/tile_functor.h +++ b/tensorflow/core/kernels/tile_functor.h @@ -26,9 +26,21 @@ namespace tensorflow { namespace internal { -// Device-specific naive implementation for tile. -template -void TileSimple(const Device& d, Tensor* out, const Tensor& in); +// Device-specific naive implementation for Tile. + +template +void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out, + const Tensor& in); + +#if GOOGLE_CUDA +template +void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in); +#endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +template +void TileSimple(const Eigen::SyclDevice& d, Tensor* out, const Tensor& in); +#endif template void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in, @@ -99,7 +111,7 @@ struct Tile { broadcast_array); break; default: - internal::TileSimple(d, out, in); + internal::TileSimple(d, out, in); break; } } diff --git a/tensorflow/core/kernels/tile_functor_cpu.cc b/tensorflow/core/kernels/tile_functor_cpu.cc index 43fd0d20adb..5a8af3468fa 100644 --- a/tensorflow/core/kernels/tile_functor_cpu.cc +++ b/tensorflow/core/kernels/tile_functor_cpu.cc @@ -21,11 +21,11 @@ limitations under the License. #include "tensorflow/core/kernels/tile_functor.h" namespace tensorflow { - namespace internal { +namespace { template -void TileSimple(const Device& d, Tensor* out, const Tensor& in) { +void TileSimpleImpl(const Device& d, Tensor* out, const Tensor& in) { const int ndims = in.dims(); const int64 nelem = out->NumElements(); gtl::InlinedVector in_strides = ComputeStride(in.shape()); @@ -44,7 +44,21 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) { } } -} // end namespace internal +} // namespace + +template +void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out, + const Tensor& in) { + return TileSimpleImpl(d, out, in); +} +#ifdef TENSORFLOW_USE_SYCL +template +void TileSimple(const Eigen::SyclDevice& d, Tensor* out, const Tensor& in) { + return TileSimpleImpl(d, out, in); +} +#endif + +} // namespace internal namespace functor { @@ -60,6 +74,7 @@ TF_CALL_float(DEFINE_TYPE); TF_CALL_bfloat16(DEFINE_TYPE); TF_CALL_double(DEFINE_TYPE); TF_CALL_uint8(DEFINE_TYPE); +TF_CALL_int8(DEFINE_TYPE); TF_CALL_int32(DEFINE_TYPE); TF_CALL_int16(DEFINE_TYPE); TF_CALL_int64(DEFINE_TYPE); diff --git a/tensorflow/core/kernels/tile_functor_gpu.h b/tensorflow/core/kernels/tile_functor_gpu.h index 59bc2d3a008..7d45a9843fd 100644 --- a/tensorflow/core/kernels/tile_functor_gpu.h +++ b/tensorflow/core/kernels/tile_functor_gpu.h @@ -21,11 +21,10 @@ limitations under the License. #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/tile_functor.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { namespace internal { @@ -47,8 +46,8 @@ __global__ void TileKernel(int nthreads, const T* src, const int32* buf, } } -template -void TileSimple(const Device& d, Tensor* out, const Tensor& in) { +template +void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in) { // Ensures we can use 32-bit index. 
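The naive TileSimple implementations here map every flat output index back to a flat input index by wrapping each output coordinate modulo the corresponding input dimension. A self-contained 2-D sketch of that index arithmetic (row-major, simplified, not TensorFlow code):

```c++
#include <cassert>
#include <vector>

// Tiles a row-major rows x cols array `multiples_r` times along rows and
// `multiples_c` times along columns, using only flat-index arithmetic.
std::vector<int> Tile2D(const std::vector<int>& in, int rows, int cols,
                        int multiples_r, int multiples_c) {
  const int out_rows = rows * multiples_r;
  const int out_cols = cols * multiples_c;
  std::vector<int> out(out_rows * out_cols);
  for (int o = 0; o < out_rows * out_cols; ++o) {
    const int r = (o / out_cols) % rows;  // output coordinate, wrapped back
    const int c = (o % out_cols) % cols;  // into the input's extent
    out[o] = in[r * cols + c];
  }
  return out;
}

int main() {
  // [[1, 2], [3, 4]] tiled 1x2 -> [[1, 2, 1, 2], [3, 4, 3, 4]].
  const std::vector<int> out = Tile2D({1, 2, 3, 4}, 2, 2, 1, 2);
  assert(out == (std::vector<int>{1, 2, 1, 2, 3, 4, 3, 4}));
}
```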
const int64 in_nelem = in.NumElements(); CHECK_LT(in_nelem, kint32max) << "Tensor too large to transpose on GPU"; @@ -74,7 +73,7 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) { // Launch kernel to q[...] = p[...]. const T* p = in.flat().data(); T* q = out->flat().data(); - CudaLaunchConfig cfg = GetCudaLaunchConfig(out_nelem, d); + GpuLaunchConfig cfg = GetCudaLaunchConfig(out_nelem, d); TF_CHECK_OK( CudaLaunchKernel(TileKernel, cfg.block_count, cfg.thread_per_block, 0, d.stream(), cfg.virtual_thread_count, p, @@ -85,6 +84,7 @@ void TileSimple(const Device& d, Tensor* out, const Tensor& in) { } // end namespace internal } // namespace tensorflow + #endif // GOOGLE_CUDA #endif // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_GPU_H_ diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc index 2e01fa17630..34c6085ee1c 100644 --- a/tensorflow/core/kernels/tile_ops.cc +++ b/tensorflow/core/kernels/tile_ops.cc @@ -78,6 +78,102 @@ struct ReduceAndReshape { const Eigen::DSizes& reduce_dim, const Eigen::DSizes& reshape_dim) const; }; + +// Explicit instantiations are defined in tile_ops_{cpu,gpu}_impl.*, +// below are their declarations. + +#ifdef GOOGLE_CUDA +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +extern template struct Tile; +#define DECLARE_CUDA_DIM(T, NDIM) \ + extern template struct TileGrad; \ + extern template struct ReduceAndReshape +#else // GOOGLE_CUDA +#define DECLARE_CUDA_DIM(T, NDIM) +#endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +#define DECLARE_TYPE(T) \ + extern template struct Tile; \ + extern template struct Tile; +TF_CALL_bool(DECLARE_TYPE); +TF_CALL_float(DECLARE_TYPE); +TF_CALL_bfloat16(DECLARE_TYPE); +TF_CALL_double(DECLARE_TYPE); +TF_CALL_uint8(DECLARE_TYPE); +TF_CALL_int32(DECLARE_TYPE); +TF_CALL_int16(DECLARE_TYPE); +TF_CALL_int64(DECLARE_TYPE); +#undef DECLARE_TYPE +#define DECLARE_SYCL_DIM(T, NDIM) \ + extern template struct TileGrad; \ + extern template struct ReduceAndReshape +#else // TENSORFLOW_USE_SYCL +#define DECLARE_SYCL_DIM(T, NDIM) +#endif // TENSORFLOW_USE_SYCL + +#define DECLARE_TYPE(T) \ + extern template struct Tile; \ + extern template struct Tile; +TF_CALL_bool(DECLARE_TYPE); +TF_CALL_float(DECLARE_TYPE); +TF_CALL_bfloat16(DECLARE_TYPE); +TF_CALL_double(DECLARE_TYPE); +TF_CALL_uint8(DECLARE_TYPE); +TF_CALL_int32(DECLARE_TYPE); +TF_CALL_int16(DECLARE_TYPE); +TF_CALL_int64(DECLARE_TYPE); +TF_CALL_half(DECLARE_TYPE); +TF_CALL_complex64(DECLARE_TYPE); +TF_CALL_complex128(DECLARE_TYPE); +TF_CALL_string(DECLARE_TYPE); +#undef DECLARE_TYPE + +#define DECLARE_DIM(T, NDIM) \ + DECLARE_CUDA_DIM(T, NDIM); \ + DECLARE_SYCL_DIM(T, NDIM); \ + extern template struct TileGrad; \ + extern template struct ReduceAndReshape; + +#define DECLARE_TYPE(T) \ + DECLARE_DIM(T, 1) \ + DECLARE_DIM(T, 2) \ + DECLARE_DIM(T, 3) \ + DECLARE_DIM(T, 4) \ + DECLARE_DIM(T, 5) \ + DECLARE_DIM(T, 6) \ + DECLARE_DIM(T, 7) +TF_CALL_float(DECLARE_TYPE); +TF_CALL_bfloat16(DECLARE_TYPE); +TF_CALL_double(DECLARE_TYPE); +TF_CALL_int16(DECLARE_TYPE); +TF_CALL_int32(DECLARE_TYPE); 
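The extern template declarations above tell this translation unit not to instantiate Tile/TileGrad itself and to link against the instantiations compiled in tile_ops_{cpu,gpu}_impl.*. A minimal single-file illustration of the pattern with a hypothetical functor (in real code the declaration and the definition live in different files):

```c++
template <typename T>
struct TileLike {
  T Twice(T x) const { return x + x; }
};

// Explicit instantiation declaration: do not instantiate TileLike<float>
// here; exactly one implementation file provides it.
extern template struct TileLike<float>;

// ...and in that single implementation file:
template struct TileLike<float>;  // explicit instantiation definition

int main() { return TileLike<float>().Twice(2.0f) == 4.0f ? 0 : 1; }
```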
+TF_CALL_int64(DECLARE_TYPE); +TF_CALL_half(DECLARE_TYPE); +TF_CALL_complex64(DECLARE_TYPE); +TF_CALL_complex128(DECLARE_TYPE); +#undef DECLARE_TYPE + +#undef DECLARE_DIM +#undef DECLARE_SYCL_DIM +#undef DECLARE_CUDA_DIM + } // namespace functor // -------------------------------------------------------------------------- @@ -140,6 +236,7 @@ class TileOp : public OpKernel { TF_CALL_float(HANDLE_TYPE_NAME); TF_CALL_double(HANDLE_TYPE_NAME); TF_CALL_uint8(HANDLE_TYPE_NAME); + TF_CALL_int8(HANDLE_TYPE_NAME); TF_CALL_int32(HANDLE_TYPE_NAME); TF_CALL_int16(HANDLE_TYPE_NAME); TF_CALL_int64(HANDLE_TYPE_NAME); @@ -218,6 +315,7 @@ TF_CALL_float(HANDLE_TYPE_NAME_CPU); TF_CALL_bfloat16(HANDLE_TYPE_NAME_CPU); TF_CALL_double(HANDLE_TYPE_NAME_CPU); TF_CALL_uint8(HANDLE_TYPE_NAME_CPU); +TF_CALL_int8(HANDLE_TYPE_NAME_CPU); TF_CALL_int32(HANDLE_TYPE_NAME_CPU); TF_CALL_int16(HANDLE_TYPE_NAME_CPU); TF_CALL_int64(HANDLE_TYPE_NAME_CPU); diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc index 2f6fffed2fd..f51deb20196 100644 --- a/tensorflow/core/kernels/topk_op.cc +++ b/tensorflow/core/kernels/topk_op.cc @@ -134,7 +134,7 @@ struct TopKFunctor { return Status::OK(); } - auto SortIndices = [&, context](int start_batch, int limit_batch) { + auto SortIndices = [&](int start_batch, int limit_batch) { for (int32 b = start_batch; b < limit_batch; ++b) { const T* input_data = &input(b, 0); const auto stable_comp = [input_data](const int32 a, const int32 b) { diff --git a/tensorflow/core/kernels/topk_op_gpu.h b/tensorflow/core/kernels/topk_op_gpu.h index 1bcc0221b87..e0a813b3be2 100644 --- a/tensorflow/core/kernels/topk_op_gpu.h +++ b/tensorflow/core/kernels/topk_op_gpu.h @@ -21,6 +21,7 @@ limitations under the License. #include #include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/cub/device/device_segmented_radix_sort.cuh" #include "third_party/cub/iterator/counting_input_iterator.cuh" @@ -33,7 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/top_n.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" // Required for sorting Eigen::half namespace cub { diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 5594c998dd1..b6c37707acc 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -34,6 +34,7 @@ namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; using SYCLDevice = Eigen::SyclDevice; +using Index = Eigen::Index; namespace { template @@ -310,6 +311,19 @@ struct ApplyAdamNonCuda { typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, typename TTypes::ConstFlat grad, bool use_nesterov) { + // Get params length and check if they can be vectorized by packet size. 
+ Index length = var.size(); + Index packet_size = Eigen::internal::packet_traits::size; + if (length % packet_size == 0) { + length = length / packet_size; + } else { + packet_size = 1; + } + + T* var_ptr = var.data(); + T* m_ptr = m.data(); + T* v_ptr = v.data(); + const T* g_ptr = grad.data(); const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) / (T(1) - beta1_power()); // beta1 == μ @@ -317,14 +331,45 @@ struct ApplyAdamNonCuda { // v == n // var == θ - m.device(d) += (grad - m) * (T(1) - beta1()); - v.device(d) += (grad.square() - v) * (T(1) - beta2()); - if (use_nesterov) { - var.device(d) -= ((grad * (T(1) - beta1()) + beta1() * m) * alpha) / - (v.sqrt() + epsilon()); - } else { - var.device(d) -= (m * alpha) / (v.sqrt() + epsilon()); - } + auto shard = [this, var_ptr, m_ptr, v_ptr, g_ptr, alpha, beta1, beta2, + epsilon, use_nesterov, packet_size](int begin, int end) { + int t_size = (end - begin) * packet_size; + begin = begin * packet_size; + auto var = typename TTypes::UnalignedTensor(var_ptr + begin, t_size); + auto m = typename TTypes::UnalignedTensor(m_ptr + begin, t_size); + auto v = typename TTypes::UnalignedTensor(v_ptr + begin, t_size); + auto g = typename TTypes::UnalignedConstTensor(g_ptr + begin, t_size); + + if (use_nesterov) { + m += (g - m) * (T(1) - beta1()); + v += (g.square() - v) * (T(1) - beta2()); + var -= ((g * (T(1) - beta1()) + beta1() * m) * alpha) / + (v.sqrt() + epsilon()); + } else { + m += (g - m) * (T(1) - beta1()); + v += (g.square() - v) * (T(1) - beta2()); + var -= (m * alpha) / (v.sqrt() + epsilon()); + } + }; + + // Input data: var, v, m, grad. + // Output data: var, v, m. + const int input_bytes = length * packet_size * sizeof(T) * 4; + const int output_bytes = length * packet_size * sizeof(T) * 3; + const int compute_cycles = + // Consider Sub as Add + (Eigen::TensorOpCost::AddCost() * 5 + + Eigen::TensorOpCost::MulCost() * 2 + + Eigen::TensorOpCost::AddCost() * 10 + + Eigen::TensorOpCost::MulCost() * 6 + + Eigen::TensorOpCost::DivCost()) * + length; + const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles); + + // Eigen device must update 3 variables with 3 different expressions, + // which is bad for cache locality on CPU. Here use ParallelFor instead of + // "regular" tensor expressions to get better performance. + d.parallelFor(length, cost, shard); } }; @@ -562,7 +607,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -723,7 +768,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -1196,7 +1241,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -2441,7 +2486,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. 
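The rewritten ApplyAdamNonCuda above cuts the parameter vector into packet-aligned shards and updates m, v, and var shard by shard via parallelFor, because (as the comment notes) three separate full-tensor Eigen expressions are bad for cache locality on CPU. The per-shard arithmetic is unchanged; a plain scalar sketch of it, with alpha already folding in the beta1_power/beta2_power correction as computed earlier in the functor:

```c++
#include <cmath>
#include <cstddef>

// Applies the Adam update to elements [begin, end) of var/m/v given grad g.
// alpha = lr * sqrt(1 - beta2_power) / (1 - beta1_power).
void AdamUpdateShard(float* var, float* m, float* v, const float* g,
                     std::size_t begin, std::size_t end, float alpha,
                     float beta1, float beta2, float epsilon,
                     bool use_nesterov) {
  for (std::size_t i = begin; i < end; ++i) {
    m[i] += (g[i] - m[i]) * (1.0f - beta1);
    v[i] += (g[i] * g[i] - v[i]) * (1.0f - beta2);
    const float denom = std::sqrt(v[i]) + epsilon;
    if (use_nesterov) {
      var[i] -= (g[i] * (1.0f - beta1) + beta1 * m[i]) * alpha / denom;
    } else {
      var[i] -= m[i] * alpha / denom;
    }
  }
}
```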
namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -2659,7 +2704,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -3032,7 +3077,7 @@ TF_CALL_float(REGISTER_SYCL_KERNELS); TF_CALL_double(REGISTER_SYCL_KERNELS); #endif -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -3173,7 +3218,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -3304,7 +3349,7 @@ TF_CALL_half(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -3538,7 +3583,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -3957,7 +4002,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ @@ -4064,7 +4109,7 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNELS); TF_CALL_float(REGISTER_CPU_KERNELS); TF_CALL_double(REGISTER_CPU_KERNELS); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index f45b9ffca7c..e67ac07517f 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -390,4 +390,4 @@ template struct functor::ApplyPowerSign; } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/training_ops_test.cc b/tensorflow/core/kernels/training_ops_test.cc index 2dcc4a500e6..596c980b24a 100644 --- a/tensorflow/core/kernels/training_ops_test.cc +++ b/tensorflow/core/kernels/training_ops_test.cc @@ -183,16 +183,22 @@ static void Adam(int32 n, Graph** init_g, Graph** train_g) { } } -static void BM_Adam(int iters, int params) { +static void BM_Adam(int iters, int params, int is_multi_threaded) { const int64 tot = static_cast(iters) * params; testing::ItemsProcessed(tot); testing::BytesProcessed(tot * sizeof(float)); Graph* init; Graph* train; Adam(params, &init, &train); - test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + if (is_multi_threaded) { + // Use max thread number if test performance. + test::Benchmark("cpu", train, nullptr, init).Run(iters); + } else { + test::Benchmark("cpu", train, GetOptions(), init).Run(iters); + } } -BENCHMARK(BM_Adam)->Arg(128 << 10)->Arg(256 << 10); +BENCHMARK(BM_Adam)->ArgPair(128 << 10, 0)->ArgPair(256 << 10, 0); +BENCHMARK(BM_Adam)->ArgPair(256 << 5, 1)->ArgPair(256 << 16, 1); static void RMSProp(int32 n, Graph** init_g, Graph** train_g) { TensorShape shape({n}); diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc index b4e5d0ae58a..aa9e7196223 100644 --- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc @@ -20,7 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/transpose_functor.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" // TODO(yangzihao): Remove the dependency of conv_2d.h once we move all // GPU util functions and transpose kernels into separate files. @@ -79,7 +79,7 @@ void TransposeSimple(const GPUDevice& d, const Tensor& in, // Launch kernel to q[...] = p[...]. const T* p = reinterpret_cast(in.tensor_data().data()); T* q = reinterpret_cast(const_cast((out->tensor_data().data()))); - CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d); + GpuLaunchConfig cfg = GetCudaLaunchConfig(nelem, d); TF_CHECK_OK(CudaLaunchKernel( TransposeKernel, cfg.block_count, cfg.thread_per_block, 0, d.stream(), cfg.virtual_thread_count, p, @@ -168,60 +168,29 @@ struct TransposeUsingTile { } // namespace internal // Transpose kernel specialized for GPU Device. 
+#define HANDLE_DIM(DIM) \ + case DIM: \ + internal::TransposeUsingEigen(d, in, perm, conjugate, \ + out); \ + break + template struct Transpose { static void run(const GPUDevice& d, const Tensor& in, const gtl::ArraySlice perm, Tensor* out) { + if (in.dims() < 2) return; + if (internal::TransposeUsingTile::run(d, in, perm, out)) { + return; + } + switch (in.dims()) { - case 2: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 3: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 4: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 5: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 6: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 7: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; - case 8: - if (!internal::TransposeUsingTile::run(d, in, perm, - out)) { - internal::TransposeUsingEigen(d, in, perm, conjugate, - out); - } - break; + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + HANDLE_DIM(6); + HANDLE_DIM(7); + HANDLE_DIM(8); default: internal::TransposeSimple(d, in, perm, out); break; @@ -229,6 +198,8 @@ struct Transpose { } }; +#undef HANDLE_DIM + template struct Transpose { static void run(const GPUDevice& d, const Tensor& in, diff --git a/tensorflow/core/kernels/tridiagonal_matmul_op.cc b/tensorflow/core/kernels/tridiagonal_matmul_op.cc new file mode 100644 index 00000000000..3ddf22012de --- /dev/null +++ b/tensorflow/core/kernels/tridiagonal_matmul_op.cc @@ -0,0 +1,134 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/linalg_ops.cc. + +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// TODO(b/131583008): add broadcast support (for batch dimensions). 
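The HANDLE_DIM macro introduced above collapses the repeated switch cases of the GPU Transpose functor into one line per rank, each case forwarding to a template instantiated on that rank. A small standalone version of the same dispatch trick (illustrative names):

```c++
#include <cstdio>

template <int NDIMS>
void TransposeRankN() { std::printf("rank-%d path\n", NDIMS); }

#define HANDLE_DIM(DIM)    \
  case DIM:                \
    TransposeRankN<DIM>(); \
    break

// Maps a runtime rank onto compile-time specializations without repeating
// the case body for every supported rank.
void Dispatch(int ndims) {
  switch (ndims) {
    HANDLE_DIM(2);
    HANDLE_DIM(3);
    HANDLE_DIM(4);
    default:
      std::printf("fallback path\n");
      break;
  }
}
#undef HANDLE_DIM

int main() { Dispatch(3); }  // prints "rank-3 path"
```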
+template +class TridiagonalMatMulOp : public LinearAlgebraOp { + public: + INHERIT_LINALG_TYPEDEFS(Scalar); + + explicit TridiagonalMatMulOp(OpKernelConstruction* context) : Base(context) {} + + void ValidateInputMatrixShapes( + OpKernelContext* context, + const TensorShapes& input_matrix_shapes) const final { + auto num_inputs = input_matrix_shapes.size(); + OP_REQUIRES( + context, num_inputs == 4, + errors::InvalidArgument("Expected 4 inputs, got ", num_inputs, ".")); + + auto n = input_matrix_shapes[3].dim_size(0); + + OP_REQUIRES(context, + input_matrix_shapes[0].dim_size(0) == 1 && + input_matrix_shapes[0].dim_size(1) == n, + errors::InvalidArgument("Invalid superdiagonal shape.")); + + OP_REQUIRES(context, + input_matrix_shapes[1].dim_size(0) == 1 && + input_matrix_shapes[1].dim_size(1) == n, + errors::InvalidArgument("Invalid main diagonal shape.")); + + OP_REQUIRES(context, + input_matrix_shapes[2].dim_size(0) == 1 && + input_matrix_shapes[2].dim_size(1) == n, + errors::InvalidArgument("Invalid subdiagonal shape.")); + } + + TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + return TensorShapes({input_matrix_shapes[3]}); + } + + int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { + const int num_eqs = static_cast(input_matrix_shapes[0].dim_size(1)); + const int num_rhss = static_cast(input_matrix_shapes[3].dim_size(0)); + + const double add_cost = Eigen::TensorOpCost::AddCost(); + const double mult_cost = Eigen::TensorOpCost::MulCost(); + + const double cost = num_rhss * ((3 * num_eqs - 2) * mult_cost + + (2 * num_eqs - 2) * add_cost); + return cost >= static_cast(kint64max) ? kint64max + : static_cast(cost); + } + + // Needed to prevent writing result to the same location where input is. + bool EnableInputForwarding() const final { return false; } + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + // Superdiagonal elements. Must have length m. + // Last element is ignored. + const auto& superdiag = inputs[0].row(0); + + // Diagonal elements. Must have length m. + const auto& maindiag = inputs[1].row(0); + + // Subdiagonal elements. Must have length m. + // First element is ignored. + const auto& subdiag = inputs[2].row(0); + + // Right-hand matrix. Size m x n. 
+ const auto& rhs = inputs[3]; + + MatrixMap& result = outputs->at(0); + + const int m = rhs.rows(); + const int n = rhs.cols(); + + ConstVectorMap subdiag_map(subdiag.data() + 1, m - 1); + ConstVectorMap superdiag_map(superdiag.data(), m - 1); + ConstMatrixMap rhs_except_first_row(rhs.data() + n, m - 1, n); + ConstMatrixMap rhs_except_last_row(rhs.data(), m - 1, n); + + MatrixMap result_except_first_row(result.data() + n, m - 1, n); + MatrixMap result_except_last_row(result.data(), m - 1, n); + result.array() = rhs.array().colwise() * maindiag.transpose().array(); + result_except_first_row.noalias() += + (rhs_except_last_row.array().colwise() * + subdiag_map.transpose().array()) + .matrix(); + result_except_last_row.noalias() += + (rhs_except_first_row.array().colwise() * + superdiag_map.transpose().array()) + .matrix(); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalMatMulOp); +}; + +REGISTER_LINALG_OP_CPU("TridiagonalMatMul", (TridiagonalMatMulOp), + float); +REGISTER_LINALG_OP_CPU("TridiagonalMatMul", (TridiagonalMatMulOp), + double); +REGISTER_LINALG_OP_CPU("TridiagonalMatMul", (TridiagonalMatMulOp), + complex64); +REGISTER_LINALG_OP_CPU("TridiagonalMatMul", (TridiagonalMatMulOp), + complex128); +} // namespace tensorflow diff --git a/tensorflow/core/kernels/tridiagonal_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/tridiagonal_matmul_op_gpu.cu.cc new file mode 100644 index 00000000000..7b0d4ed8227 --- /dev/null +++ b/tensorflow/core/kernels/tridiagonal_matmul_op_gpu.cu.cc @@ -0,0 +1,99 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/linalg_ops.cc. 
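The CPU ComputeMatrix above builds the product from three column-wise scaled copies of the right-hand side; per row that is product[i] = subdiag[i]*rhs[i-1] + maindiag[i]*rhs[i] + superdiag[i]*rhs[i+1], with the out-of-range terms dropped at the boundaries, which is exactly what the GPU kernel below computes per element. A scalar sketch for a single right-hand-side column:

```c++
#include <vector>

// y = T * x for a tridiagonal T given by its three diagonals (each length m).
// subdiag[0] and superdiag[m-1] are ignored, matching the op's convention.
std::vector<double> TridiagMatVec(const std::vector<double>& superdiag,
                                  const std::vector<double>& maindiag,
                                  const std::vector<double>& subdiag,
                                  const std::vector<double>& x) {
  const int m = static_cast<int>(x.size());
  std::vector<double> y(m);
  for (int i = 0; i < m; ++i) {
    double r = maindiag[i] * x[i];
    if (i > 0) r += subdiag[i] * x[i - 1];
    if (i + 1 < m) r += superdiag[i] * x[i + 1];
    y[i] = r;
  }
  return y;
}
```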
+ +#ifdef GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/cuda_solvers.h" +#include "tensorflow/core/kernels/cuda_sparse.h" +#include "tensorflow/core/kernels/linalg_ops_common.h" +#include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/gpu_device_functions.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" + +namespace tensorflow { + +template +__global__ void TridiagonalMatMulKernel(int batch_size, int m, int n, + const Scalar* superdiag, + const Scalar* maindiag, + const Scalar* subdiag, + const Scalar* rhs, Scalar* product) { + for (int i : CudaGridRangeX(batch_size * m * n)) { + int row_id = i / n; + Scalar result = maindiag[row_id] * rhs[i]; + if (row_id % m != 0) { + result = result + subdiag[row_id] * rhs[i - n]; + } + if ((row_id + 1) % m != 0) { + result = result + superdiag[row_id] * rhs[i + n]; + } + product[i] = result; + } +} + +template +class TridiagonalMatMulOpGpu : public OpKernel { + public: + explicit TridiagonalMatMulOpGpu(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) final { + const Tensor& superdiag = context->input(0); + const Tensor& maindiag = context->input(1); + const Tensor& subdiag = context->input(2); + const Tensor& rhs = context->input(3); + + const int ndims = rhs.dims(); + int64 batch_size = 1; + for (int i = 0; i < ndims - 2; i++) { + batch_size *= rhs.dim_size(i); + } + const int m = rhs.dim_size(ndims - 2); + const int n = rhs.dim_size(ndims - 1); + + // Allocate output. + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output)); + + const Eigen::GpuDevice& device = context->eigen_device(); + CudaLaunchConfig cfg = GetCudaLaunchConfig(1, device); + TF_CHECK_OK(GpuLaunchKernel( + TridiagonalMatMulKernel, cfg.block_count, cfg.thread_per_block, + 0, device.stream(), batch_size, m, n, superdiag.flat().data(), + maindiag.flat().data(), subdiag.flat().data(), + rhs.flat().data(), output->flat().data())); + } +}; + +REGISTER_LINALG_OP_GPU("TridiagonalMatMul", (TridiagonalMatMulOpGpu), + float); +REGISTER_LINALG_OP_GPU("TridiagonalMatMul", (TridiagonalMatMulOpGpu), + double); +REGISTER_LINALG_OP_GPU("TridiagonalMatMul", (TridiagonalMatMulOpGpu), + complex64); +REGISTER_LINALG_OP_GPU("TridiagonalMatMul", + (TridiagonalMatMulOpGpu), complex128); +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/tridiagonal_solve_op.cc b/tensorflow/core/kernels/tridiagonal_solve_op.cc index 5884ffedfbc..88931ff3e66 100644 --- a/tensorflow/core/kernels/tridiagonal_solve_op.cc +++ b/tensorflow/core/kernels/tridiagonal_solve_op.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -25,14 +25,25 @@ limitations under the License. 
namespace tensorflow { -static const char kErrMsg[] = "The matrix is not invertible."; +static const char kNotInvertibleMsg[] = "The matrix is not invertible."; + +static const char kNotInvertibleScalarMsg[] = + "The matrix is not invertible: it is a scalar with value zero."; + +static const char kThomasFailedMsg[] = + "The matrix is either not invertible, or requires pivoting. " + "Try setting partial_pivoting = True."; template class TridiagonalSolveOp : public LinearAlgebraOp { public: INHERIT_LINALG_TYPEDEFS(Scalar); + using MatrixMapRow = + decltype(std::declval()[0].row(0)); - explicit TridiagonalSolveOp(OpKernelConstruction* context) : Base(context) {} + explicit TridiagonalSolveOp(OpKernelConstruction* context) : Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_)); + } void ValidateInputMatrixShapes( OpKernelContext* context, @@ -71,25 +82,32 @@ class TridiagonalSolveOp : public LinearAlgebraOp { const double mult_cost = Eigen::TensorOpCost::MulCost(); const double div_cost = Eigen::TensorOpCost::DivCost(); - // Assuming cases with and without row interchange are equiprobable. - const double cost = - num_eqs * (div_cost * (num_rhss + 1) + - (add_cost + mult_cost) * (2.5 * num_rhss + 1.5)); + double cost; + if (pivoting_) { + // Assuming cases with and without row interchange are equiprobable. + cost = num_eqs * (div_cost * (num_rhss + 1) + + (add_cost + mult_cost) * (2.5 * num_rhss + 1.5)); + } else { + cost = num_eqs * (div_cost * (num_rhss + 1) + + (add_cost + mult_cost) * (2 * num_rhss + 1)); + } return cost >= static_cast(kint64max) ? kint64max : static_cast(cost); } + bool EnableInputForwarding() const final { return false; } + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, MatrixMaps* outputs) final { const auto diagonals = inputs[0]; - // Subdiagonal elements, first is ignored. + // Superdiagonal elements, first is ignored. const auto& superdiag = diagonals.row(0); // Diagonal elements. const auto& diag = diagonals.row(1); - // Superdiagonal elements, n-th is ignored. + // Subdiagonal elements, n-th is ignored. const auto& subdiag = diagonals.row(2); - // Right-hand sides (transposed - necessary for GPU impl). + // Right-hand sides. const auto& rhs = inputs[1]; const int n = diag.size(); @@ -100,11 +118,32 @@ class TridiagonalSolveOp : public LinearAlgebraOp { return; } if (n == 1) { - OP_REQUIRES(context, diag(0) != zero, errors::InvalidArgument(kErrMsg)); + OP_REQUIRES(context, diag(0) != zero, + errors::InvalidArgument(kNotInvertibleScalarMsg)); x.row(0) = rhs.row(0) / diag(0); return; } + if (pivoting_) { + SolveWithGaussianEliminationWithPivoting(context, superdiag, diag, + subdiag, rhs, x); + } else { + SolveWithThomasAlgorithm(context, superdiag, diag, subdiag, rhs, x); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOp); + + void SolveWithGaussianEliminationWithPivoting(OpKernelContext* context, + const MatrixMapRow& superdiag, + const MatrixMapRow& diag, + const MatrixMapRow& subdiag, + const ConstMatrixMap& rhs, + MatrixMap& x) { + const int n = diag.size(); + const Scalar zero(0); + // The three columns in u are the diagonal, superdiagonal, and second // superdiagonal, respectively, of the U matrix in the LU decomposition of // the input matrix (subject to row exchanges due to pivoting). For pivoted @@ -119,7 +158,8 @@ class TridiagonalSolveOp : public LinearAlgebraOp { for (int i = 0; i < n - 1; ++i) { if (std::abs(u(i)) >= std::abs(subdiag(i + 1))) { // No row interchange. 
- OP_REQUIRES(context, u(i) != zero, errors::InvalidArgument(kErrMsg)); + OP_REQUIRES(context, u(i) != zero, + errors::InvalidArgument(kNotInvertibleMsg)); const Scalar factor = subdiag(i + 1) / u(i, 0); u(i + 1, 0) = diag(i + 1) - factor * u(i, 1); x.row(i + 1) = rhs.row(i + 1) - factor * x.row(i); @@ -141,6 +181,8 @@ class TridiagonalSolveOp : public LinearAlgebraOp { } } } + OP_REQUIRES(context, u(n - 1, 0) != zero, + errors::InvalidArgument(kNotInvertibleMsg)); x.row(n - 1) /= u(n - 1, 0); x.row(n - 2) = (x.row(n - 2) - u(n - 2, 1) * x.row(n - 1)) / u(n - 2, 0); for (int i = n - 3; i >= 0; --i) { @@ -149,8 +191,36 @@ class TridiagonalSolveOp : public LinearAlgebraOp { } } - private: - TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOp); + void SolveWithThomasAlgorithm(OpKernelContext* context, + const MatrixMapRow& superdiag, + const MatrixMapRow& diag, + const MatrixMapRow& subdiag, + const ConstMatrixMap& rhs, MatrixMap& x) { + const int n = diag.size(); + const Scalar zero(0); + + // The superdiagonal of the U matrix in the LU decomposition of the input + // matrix (in Thomas algorithm, the U matrix has ones on the diagonal and + // one superdiagonal). + Eigen::Matrix u(n); + + OP_REQUIRES(context, diag(0) != zero, + errors::InvalidArgument(kThomasFailedMsg)); + u(0) = superdiag(0) / diag(0); + x.row(0) = rhs.row(0) / diag(0); + for (int i = 1; i < n; ++i) { + auto denom = diag(i) - subdiag(i) * u(i - 1); + OP_REQUIRES(context, denom != zero, + errors::InvalidArgument(kThomasFailedMsg)); + u(i) = superdiag(i) / denom; + x.row(i) = (rhs.row(i) - subdiag(i) * x.row(i - 1)) / denom; + } + for (int i = n - 2; i >= 0; --i) { + x.row(i) -= u(i) * x.row(i + 1); + } + } + + bool pivoting_; }; REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp), float); diff --git a/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc b/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc new file mode 100644 index 00000000000..d70cc92f217 --- /dev/null +++ b/tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc @@ -0,0 +1,347 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/linalg_ops.cc. 
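The SolveWithThomasAlgorithm routine added above performs a forward elimination sweep followed by back substitution against a unit-diagonal U. A hypothetical standalone restatement for a single right-hand side, with the same zero-denominator checks (a zero denominator means the matrix is singular or requires pivoting):

#include <stdexcept>
#include <vector>

std::vector<double> ThomasSolve(const std::vector<double>& superdiag,
                                const std::vector<double>& diag,
                                const std::vector<double>& subdiag,
                                const std::vector<double>& rhs) {
  const int n = static_cast<int>(diag.size());
  std::vector<double> u(n), x(n);
  if (diag[0] == 0.0) throw std::runtime_error("Try partial_pivoting = True.");
  u[0] = superdiag[0] / diag[0];
  x[0] = rhs[0] / diag[0];
  // Forward sweep: eliminate the subdiagonal; u holds U's superdiagonal.
  for (int i = 1; i < n; ++i) {
    const double denom = diag[i] - subdiag[i] * u[i - 1];
    if (denom == 0.0) throw std::runtime_error("Try partial_pivoting = True.");
    u[i] = superdiag[i] / denom;
    x[i] = (rhs[i] - subdiag[i] * x[i - 1]) / denom;
  }
  // Back substitution.
  for (int i = n - 2; i >= 0; --i) {
    x[i] -= u[i] * x[i + 1];
  }
  return x;
}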
+ +#ifdef GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/cuda_solvers.h" +#include "tensorflow/core/kernels/cuda_sparse.h" +#include "tensorflow/core/kernels/linalg_ops_common.h" +#include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/gpu_device_functions.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" + +namespace tensorflow { + +static const char kNotInvertibleMsg[] = "The matrix is not invertible."; + +static const char kNotInvertibleScalarMsg[] = + "The matrix is not invertible: it is a scalar with value zero."; + +template +__global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* diags, + const Scalar* rhs, const int num_rhs, + Scalar* x, bool* not_invertible) { + if (m == 1) { + if (diags[1] == Scalar(0)) { + *not_invertible = true; + return; + } + for (int i : CudaGridRangeX(num_rhs)) { + x[i] = rhs[i] / diags[1]; + } + } else { + Scalar det = diags[2] * diags[3] - diags[0] * diags[5]; + if (det == Scalar(0)) { + *not_invertible = true; + return; + } + for (int i : CudaGridRangeX(num_rhs)) { + x[i] = (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det; + x[i + num_rhs] = (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det; + } + } +} + +template +se::DeviceMemory AsDeviceMemory(const Scalar* cuda_memory) { + se::DeviceMemoryBase wrapped(const_cast(cuda_memory)); + se::DeviceMemory typed(wrapped); + return typed; +} + +template +void CopyDeviceToDevice(OpKernelContext* context, const Scalar* src, + Scalar* dst, const int num_elements) { + auto src_device_mem = AsDeviceMemory(src); + auto dst_device_mem = AsDeviceMemory(dst); + auto* stream = context->op_device_context()->stream(); + bool copy_status = stream + ->ThenMemcpyD2D(&dst_device_mem, src_device_mem, + sizeof(Scalar) * num_elements) + .ok(); + + if (!copy_status) { + context->SetStatus(errors::Internal("Copying device-to-device failed.")); + } +} + +// This implementation is used in cases when the batching mechanism of +// LinearAlgebraOp is suitable. See TridiagonalSolveOpGpu below. 
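SolveForSizeOneOrTwoKernel above falls back to closed-form solutions because the cuSPARSE gtsv routines require m >= 3. For m == 2 this is Cramer's rule applied to the flattened diagonal layout; a hypothetical host-side sketch, not part of the patch:

#include <stdexcept>

// `diags` uses the op's layout for m == 2:
// [superdiag[0], superdiag[1], maindiag[0], maindiag[1], subdiag[0], subdiag[1]].
void SolveSizeTwo(const double diags[6], const double rhs[2], double x[2]) {
  // The 2x2 system is [[diags[2], diags[0]], [diags[5], diags[3]]] * x = rhs.
  const double det = diags[2] * diags[3] - diags[0] * diags[5];
  if (det == 0.0) throw std::runtime_error("The matrix is not invertible.");
  x[0] = (diags[3] * rhs[0] - diags[0] * rhs[1]) / det;
  x[1] = (diags[2] * rhs[1] - diags[5] * rhs[0]) / det;
}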
+template +class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp { + public: + INHERIT_LINALG_TYPEDEFS(Scalar); + + explicit TridiagonalSolveOpGpuLinalg(OpKernelConstruction* context) + : Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_)); + } + + void ValidateInputMatrixShapes( + OpKernelContext* context, + const TensorShapes& input_matrix_shapes) const final { + auto num_inputs = input_matrix_shapes.size(); + OP_REQUIRES(context, num_inputs == 2, + errors::InvalidArgument("Expected two input matrices, got ", + num_inputs, ".")); + + auto num_diags = input_matrix_shapes[0].dim_size(0); + OP_REQUIRES( + context, num_diags == 3, + errors::InvalidArgument("Expected diagonals to be provided as a " + "matrix with 3 columns, got ", + num_diags, " columns.")); + + auto num_rows1 = input_matrix_shapes[0].dim_size(1); + auto num_rows2 = input_matrix_shapes[1].dim_size(0); + OP_REQUIRES(context, num_rows1 == num_rows2, + errors::InvalidArgument("Expected same number of rows in both " + "arguments, got ", + num_rows1, " and ", num_rows2, ".")); + } + + bool EnableInputForwarding() const final { return false; } + + TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + return TensorShapes({input_matrix_shapes[1]}); + } + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + const auto diagonals = inputs[0]; + // Superdiagonal elements, first is ignored. + const auto& superdiag = diagonals.row(0); + // Diagonal elements. + const auto& diag = diagonals.row(1); + // Subdiagonal elements, last is ignored. + const auto& subdiag = diagonals.row(2); + // Right-hand sides. + const auto& rhs = inputs[1]; + MatrixMap& x = outputs->at(0); + const int m = diag.size(); + const int k = rhs.cols(); + + if (m == 0) { + return; + } + if (m < 3) { + // Cusparse gtsv routine requires m >= 3. Solving manually for m < 3. + SolveForSizeOneOrTwo(context, diagonals.data(), rhs.data(), x.data(), m, + k); + return; + } + std::unique_ptr cusparse_solver(new CudaSparse(context)); + OP_REQUIRES_OK(context, cusparse_solver->Initialize()); + if (k == 1) { + // rhs is copied into x, then gtsv replaces x with solution. + CopyDeviceToDevice(context, rhs.data(), x.data(), m); + SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(), + subdiag.data(), x.data(), m, 1); + } else { + // Gtsv expects rhs in column-major form, so we have to transpose. + // rhs is transposed into temp, gtsv replaces temp with solution, then + // temp is transposed into x. 
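The Geam transpositions described above exist only because gtsv expects column-major right-hand sides while TensorFlow tensors are row-major. A hypothetical host-side sketch of the equivalent layout change, for reference only:

#include <vector>

// Rewrites an m x k array from row-major (TensorFlow layout) to column-major
// (the layout the cuSPARSE gtsv routines consume).
std::vector<double> RowMajorToColumnMajor(const std::vector<double>& src,
                                          int m, int k) {
  std::vector<double> dst(src.size());
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < k; ++j) {
      dst[j * m + i] = src[i * k + j];
    }
  }
  return dst;
}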
+ std::unique_ptr cublas_solver(new CudaSolver(context)); + Tensor temp; + TensorShape temp_shape({k, m}); + OP_REQUIRES_OK(context, + cublas_solver->allocate_scoped_tensor( + DataTypeToEnum::value, temp_shape, &temp)); + TransposeWithGeam(context, cublas_solver, rhs.data(), + temp.flat().data(), m, k); + SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(), + subdiag.data(), temp.flat().data(), m, k); + TransposeWithGeam(context, cublas_solver, temp.flat().data(), + x.data(), k, m); + } + } + + private: + void TransposeWithGeam(OpKernelContext* context, + const std::unique_ptr& cublas_solver, + const Scalar* src, Scalar* dst, const int src_rows, + const int src_cols) const { + const Scalar zero(0), one(1); + OP_REQUIRES_OK(context, + cublas_solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, src_rows, + src_cols, &one, src, src_cols, &zero, + static_cast(nullptr), + src_rows, dst, src_rows)); + } + + void SolveWithGtsv(OpKernelContext* context, + std::unique_ptr& cusparse_solver, + const Scalar* superdiag, const Scalar* diag, + const Scalar* subdiag, Scalar* rhs, const int num_eqs, + const int num_rhs) const { + auto function = pivoting_ ? &CudaSparse::Gtsv + : &CudaSparse::GtsvNoPivot; + OP_REQUIRES_OK( + context, (cusparse_solver.get()->*function)( + num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs)); + } + + void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals, + const Scalar* rhs, Scalar* output, int m, int k) { + const Eigen::GpuDevice& device = context->eigen_device(); + GpuLaunchConfig cfg = GetCudaLaunchConfig(1, device); + bool* not_invertible_dev; + cudaMalloc(¬_invertible_dev, sizeof(bool)); + TF_CHECK_OK(CudaLaunchKernel(SolveForSizeOneOrTwoKernel, + cfg.block_count, cfg.thread_per_block, 0, + device.stream(), m, diagonals, rhs, k, output, + not_invertible_dev)); + bool not_invertible_host; + cudaMemcpy(¬_invertible_host, not_invertible_dev, sizeof(bool), + cudaMemcpyDeviceToHost); + cudaFree(not_invertible_dev); + OP_REQUIRES(context, !not_invertible_host, + errors::InvalidArgument(m == 1 ? kNotInvertibleScalarMsg + : kNotInvertibleMsg)); + } + + bool pivoting_; +}; + +template +class TridiagonalSolveOpGpu : public OpKernel { + public: + explicit TridiagonalSolveOpGpu(OpKernelConstruction* context) + : OpKernel(context), linalgOp_(context) { + OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_)); + } + + void Compute(OpKernelContext* context) final { + const Tensor& lhs = context->input(0); + const Tensor& rhs = context->input(1); + const int ndims = lhs.dims(); + const int64 num_rhs = rhs.dim_size(rhs.dims() - 1); + const int64 matrix_size = lhs.dim_size(ndims - 1); + int64 batch_size = 1; + for (int i = 0; i < ndims - 2; i++) { + batch_size *= lhs.dim_size(i); + } + + // The batching mechanism of LinearAlgebraOp is used when it's not + // possible or desirable to use GtsvBatched. 
+ const bool use_linalg_op = + pivoting_ // GtsvBatched doesn't do pivoting + || num_rhs > 1 // GtsvBatched doesn't support multiple rhs + || matrix_size < 3 // Not supported in cuSparse, use the custom kernel + || batch_size == 1; // No point to use GtsvBatched + + if (use_linalg_op) { + linalgOp_.Compute(context); + } else { + ComputeWithGtsvBatched(context, lhs, rhs, batch_size); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOpGpu); + + void ComputeWithGtsvBatched(OpKernelContext* context, const Tensor& lhs, + const Tensor& rhs, const int batch_size) { + const Scalar* rhs_data = rhs.flat().data(); + const int ndims = lhs.dims(); + + // To use GtsvBatched we need to transpose the left-hand side from shape + // [..., 3, M] into shape [3, ..., M]. With shape [..., 3, M] the stride + // between corresponding diagonal elements of consecutive batch components + // is 3 * M, while for the right-hand side the stride is M. Unfortunately, + // GtsvBatched requires the strides to be the same. For this reason we + // transpose into [3, ..., M], so that diagonals, superdiagonals, and + // and subdiagonals are separated from each other, and have stride M. + Tensor lhs_transposed; + TransposeLhsForGtsvBatched(context, lhs, lhs_transposed); + int matrix_size = lhs.dim_size(ndims - 1); + const Scalar* lhs_data = lhs_transposed.flat().data(); + const Scalar* superdiag = lhs_data; + const Scalar* diag = lhs_data + matrix_size * batch_size; + const Scalar* subdiag = lhs_data + 2 * matrix_size * batch_size; + + // Copy right-hand side into the output. GtsvBatched will replace it with + // the solution. + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output)); + CopyDeviceToDevice(context, rhs_data, output->flat().data(), + rhs.flat().size()); + Scalar* x = output->flat().data(); + + std::unique_ptr cusparse_solver(new CudaSparse(context)); + + OP_REQUIRES_OK(context, cusparse_solver->Initialize()); + OP_REQUIRES_OK(context, cusparse_solver->GtsvStridedBatch( + matrix_size, subdiag, diag, superdiag, x, + batch_size, matrix_size)); + } + + void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs, + Tensor& lhs_transposed) { + const int ndims = lhs.dims(); + + // Permutation of indices, transforming [..., 3, M] into [3, ..., M]. + // E.g. for ndims = 6, it is [4, 0, 1, 2, 3, 5]. 
+ std::vector perm(ndims); + perm[0] = ndims - 2; + for (int i = 0; i < ndims - 2; ++i) { + perm[i + 1] = i; + } + perm[ndims - 1] = ndims - 1; + + std::vector dims; + for (int index : perm) { + dims.push_back(lhs.dim_size(index)); + } + TensorShape lhs_transposed_shape( + gtl::ArraySlice(dims.data(), ndims)); + + std::unique_ptr cublas_solver(new CudaSolver(context)); + OP_REQUIRES_OK(context, cublas_solver->allocate_scoped_tensor( + DataTypeToEnum::value, + lhs_transposed_shape, &lhs_transposed)); + auto device = context->eigen_device(); + OP_REQUIRES_OK( + context, + DoTranspose(device, lhs, gtl::ArraySlice(perm.data(), ndims), + &lhs_transposed)); + } + + TridiagonalSolveOpGpuLinalg linalgOp_; + bool pivoting_; +}; + +REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu), + float); +REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu), + double); +REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu), + complex64); +REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu), + complex128); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/unicode_ops.cc b/tensorflow/core/kernels/unicode_ops.cc index c071db60648..59ebbedcd7f 100644 --- a/tensorflow/core/kernels/unicode_ops.cc +++ b/tensorflow/core/kernels/unicode_ops.cc @@ -350,6 +350,7 @@ class UnicodeTranscodeOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("UnicodeTranscode").Device(DEVICE_CPU), UnicodeTranscodeOp); +template class UnicodeDecodeBaseOp : public OpKernel { public: explicit UnicodeDecodeBaseOp(OpKernelConstruction* ctx, bool generate_offsets) @@ -369,8 +370,8 @@ class UnicodeDecodeBaseOp : public OpKernel { } void Decode(OpKernelContext* ctx, std::vector* char_values, - std::vector* offset_values, int* current_offset, - int64* next_row_split, UChar32 char_value, int char_length, + std::vector* offset_values, int* current_offset, + SPLITS_TYPE* next_row_split, UChar32 char_value, int char_length, bool found_any_format_error) { if (error_options_.error_on_malformatting && found_any_format_error) { ctx->CtxFailure( @@ -414,16 +415,16 @@ class UnicodeDecodeBaseOp : public OpKernel { input_encoding_)); std::vector char_values; - std::vector offset_values; + std::vector offset_values; Tensor* output_row_splits; OP_REQUIRES_OK(ctx, ctx->allocate_output("row_splits", {input_tensor->NumElements() + 1}, &output_row_splits)); - auto out_row_splits = output_row_splits->vec(); + auto out_row_splits = output_row_splits->vec(); int row_split_index = 0; - int64 next_row_split = 0; + SPLITS_TYPE next_row_split = 0; for (int i = 0; i < input_vec.size(); ++i) { const string& input = input_vec(i); // Convert input strings into unicode values. 
Output to a list of @@ -443,18 +444,18 @@ class UnicodeDecodeBaseOp : public OpKernel { Tensor* output_char_values; OP_REQUIRES_OK( - ctx, ctx->allocate_output("char_values", - {static_cast(char_values.size())}, - &output_char_values)); + ctx, ctx->allocate_output( + "char_values", {static_cast(char_values.size())}, + &output_char_values)); auto out_char_values = output_char_values->vec(); if (generate_offsets_) { DCHECK(offset_values.size() == char_values.size()); Tensor* output_offset_values; - OP_REQUIRES_OK( - ctx, ctx->allocate_output("char_to_byte_starts", - {static_cast(offset_values.size())}, - &output_offset_values)); - auto out_offset_values = output_offset_values->vec(); + OP_REQUIRES_OK(ctx, ctx->allocate_output( + "char_to_byte_starts", + {static_cast(offset_values.size())}, + &output_offset_values)); + auto out_offset_values = output_offset_values->vec(); // Load output tensors from intermediate value arrays. for (int i = 0; i < char_values.size(); ++i) { @@ -474,23 +475,36 @@ class UnicodeDecodeBaseOp : public OpKernel { bool generate_offsets_ = false; }; -class UnicodeDecodeOp : public UnicodeDecodeBaseOp { +template +class UnicodeDecodeOp : public UnicodeDecodeBaseOp { public: explicit UnicodeDecodeOp(OpKernelConstruction* ctx) - : UnicodeDecodeBaseOp(ctx, false) {} + : UnicodeDecodeBaseOp(ctx, false) {} }; -class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp { +template +class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp { public: explicit UnicodeDecodeWithOffsetsOp(OpKernelConstruction* ctx) - : UnicodeDecodeBaseOp(ctx, true) {} + : UnicodeDecodeBaseOp(ctx, true) {} }; -REGISTER_KERNEL_BUILDER(Name("UnicodeDecode").Device(DEVICE_CPU), - UnicodeDecodeOp); -REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets").Device(DEVICE_CPU), - UnicodeDecodeWithOffsetsOp); +REGISTER_KERNEL_BUILDER( + Name("UnicodeDecode").Device(DEVICE_CPU).TypeConstraint("Tsplits"), + UnicodeDecodeOp); +REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets") + .Device(DEVICE_CPU) + .TypeConstraint("Tsplits"), + UnicodeDecodeWithOffsetsOp); +REGISTER_KERNEL_BUILDER( + Name("UnicodeDecode").Device(DEVICE_CPU).TypeConstraint("Tsplits"), + UnicodeDecodeOp); +REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets") + .Device(DEVICE_CPU) + .TypeConstraint("Tsplits"), + UnicodeDecodeWithOffsetsOp); +template class UnicodeEncodeOp : public OpKernel { public: explicit UnicodeEncodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -515,7 +529,7 @@ class UnicodeEncodeOp : public OpKernel { const Tensor& input_tensor = context->input(0); const auto input_tensor_flat = input_tensor.flat(); const Tensor& input_splits = context->input(1); - const auto input_splits_flat = input_splits.flat(); + const auto input_splits_flat = input_splits.flat(); // Since we limit to a 2-D input (flat_values of rank 1 and a single splits // tensor), our output dimension will be 1 with it's size equal to the @@ -558,7 +572,11 @@ class UnicodeEncodeOp : public OpKernel { ErrorOptions error_options_; }; -REGISTER_KERNEL_BUILDER(Name("UnicodeEncode").Device(DEVICE_CPU), - UnicodeEncodeOp); +REGISTER_KERNEL_BUILDER( + Name("UnicodeEncode").Device(DEVICE_CPU).TypeConstraint("Tsplits"), + UnicodeEncodeOp); +REGISTER_KERNEL_BUILDER( + Name("UnicodeEncode").Device(DEVICE_CPU).TypeConstraint("Tsplits"), + UnicodeEncodeOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc index 8577ce7bf79..46906092795 100644 --- a/tensorflow/core/kernels/unpack_op.cc +++ 
b/tensorflow/core/kernels/unpack_op.cc @@ -144,6 +144,8 @@ TF_CALL_ALL_TYPES(REGISTER_UNPACK); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); TF_CALL_bfloat16(REGISTER_GPU); +TF_CALL_uint8(REGISTER_GPU); +TF_CALL_bool(REGISTER_GPU); #undef REGISTER_GPU // A special GPU kernel for int32. diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 00994bbe8e7..3865bdbb848 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -209,7 +209,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Only register 'Variable' on GPU for the subset of types also supported by // 'Assign' (see dense_update_ops.cc.) #define REGISTER_GPU_KERNELS(type) \ @@ -236,6 +236,6 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h index da9434efcb0..c4895cb95b5 100644 --- a/tensorflow/core/kernels/where_op_gpu.cu.h +++ b/tensorflow/core/kernels/where_op_gpu.cu.h @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/kernels/where_op.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { @@ -323,7 +323,7 @@ struct Where { const Eigen::array strides = CalculateStrides(input); const TIndex output_rows = output.dimension(0); - CudaLaunchConfig config = GetCudaLaunchConfig(output_rows, d); + GpuLaunchConfig config = GetCudaLaunchConfig(output_rows, d); TF_CHECK_OK(CudaLaunchKernel(PropagateWhereIndicesKernel, config.block_count, config.thread_per_block, 0, d.stream(), output_rows, strides, diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index 9a3612bd72c..8a7c16349a7 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -57,8 +57,8 @@ class SoftmaxXentWithLogitsOp : public OpKernel { shape_in = BCast::ToShape(bcast.output_shape()); } OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in), - errors::InvalidArgument("logits and labels must be beither " - "2-dimensional, or roadcasted to " + errors::InvalidArgument("logits and labels must be either " + "2-dimensional, or broadcasted to be " "2-dimensional")); // loss is 1-D (one per example), and size is batch_size. 
@@ -134,7 +134,8 @@ TF_CALL_half(REGISTER_CPU); TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") .Device(DEVICE_GPU) .TypeConstraint("T"), @@ -147,7 +148,7 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") .Device(DEVICE_GPU) .TypeConstraint("T"), SoftmaxXentWithLogitsOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") diff --git a/tensorflow/core/kernels/xent_op_gpu.cu.cc b/tensorflow/core/kernels/xent_op_gpu.cu.cc index 2c0c0b3a027..2b1ac45ab4c 100644 --- a/tensorflow/core/kernels/xent_op_gpu.cu.cc +++ b/tensorflow/core/kernels/xent_op_gpu.cu.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #define EIGEN_USE_GPU @@ -54,4 +55,4 @@ template struct functor::XentFunctor; } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index f8c06988cba..941e2bdf545 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -34,7 +34,7 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(); #include "include/libxsmm_cpuid.h" #include "include/libxsmm_malloc.h" -#include "third_party/libxsmm/src/libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/ +#include "src/libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/ namespace tensorflow { diff --git a/tensorflow/core/lib/core/status.cc b/tensorflow/core/lib/core/status.cc index 0b63f66f6da..c8b24df6270 100644 --- a/tensorflow/core/lib/core/status.cc +++ b/tensorflow/core/lib/core/status.cc @@ -140,6 +140,23 @@ string* TfCheckOpHelperOutOfLine(const ::tensorflow::Status& v, return new string(r); } +// kDerivedMarker is appended to the Status message string to indicate whether a +// Status object is the root cause of an error or if it has been triggered by +// cancelling/aborting a step. +static const char* kDerivedMarker = "[_Derived_]"; + +Status StatusGroup::MakeDerived(const Status& s) { + if (IsDerived(s)) { + return s; + } else { + return Status(s.code(), strings::StrCat(kDerivedMarker, s.error_message())); + } +} + +bool StatusGroup::IsDerived(const Status& s) { + return s.error_message().find(kDerivedMarker) != std::string::npos; +} + void StatusGroup::Update(const Status& s) { if (s.ok()) { ++num_ok_; @@ -149,91 +166,88 @@ void StatusGroup::Update(const Status& s) { } } -const int kMaxChildMessageSize = 2048; +static std::vector GetNonDerivedStatuses( + const std::vector& status) { + std::vector nonderived_statuses; + for (auto& s : status) { + if (!StatusGroup::IsDerived(s)) { + nonderived_statuses.push_back(s); + } + } + return nonderived_statuses; +} -Status StatusGroup::as_status() const { +static constexpr int kMaxAggregatedStatusMessageSize = 8 * 1024; + +// Summarize all the status objects in the StatusGroup. This is used when +// individual Status objects in the StatusGroup are not already summarized. 
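A hypothetical caller-side sketch of the aggregation API introduced in this change (Update, MakeDerived, as_summary_status, as_concatenated_status are the declarations added to status.h): derived statuses, such as cancellations triggered by a root failure, are folded into the group but are not reported as root errors.

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status Summarize() {
  tensorflow::StatusGroup group;
  group.Update(tensorflow::errors::Internal("Root cause."));
  group.Update(tensorflow::StatusGroup::MakeDerived(
      tensorflow::errors::Cancelled("Cancelled after the root failure.")));
  group.Update(tensorflow::Status::OK());  // Only bumps the success counter.
  // With a single non-derived child, as_summary_status() returns it directly;
  // as_concatenated_status() would join the ToString() of each root error.
  return group.as_summary_status();
}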
+Status StatusGroup::as_summary_status() const { if (ok_) { return Status::OK(); } - // Reduce verbosity when handling duplicate messages. If there is only a - // single message, or all messages have similar content, then return the - // longest status message. - std::vector sorted_children(children_); - std::sort(sorted_children.begin(), sorted_children.end(), - [](const Status& a, const Status& b) { - return a.error_message().length() > b.error_message().length(); - }); - bool single_status = true; - for (const auto& s : sorted_children) { - if (s.code() != sorted_children[0].code() || - sorted_children[0].error_message().find(s.error_message()) == - string::npos) { - single_status = false; - break; - } + std::vector nonderived_statuses = GetNonDerivedStatuses(children_); + + // If only one root status is found, return it directly. + if (nonderived_statuses.size() == 1) { + return nonderived_statuses[0]; } - if (single_status) { - return sorted_children[0]; - } + if (!nonderived_statuses.empty()) { + std::vector fmt; - std::vector fmt; + fmt.push_back(strings::Printf("%zu root error(s) found.", + nonderived_statuses.size())); - // Compute a final output string with status codes sorted by frequency in - // increasing order. This prefers more "interesting" messages over child - // messages that may come from cancellation. - std::map> code_to_status; - for (const Status& s : children_) { - code_to_status[s.code()].push_back(s); - } - - std::vector> count_vec; - count_vec.reserve(code_to_status.size()); - for (auto& p : code_to_status) { - count_vec.push_back(std::make_pair(p.first, p.second.size())); - } - - std::sort( - count_vec.begin(), count_vec.end(), - [](const std::pair& a, - const std::pair& b) { return a.second < b.second; }); - - fmt.push_back( - strings::Printf("Combined status information from %zu operations:\n", - num_ok_ + children_.size())); - - for (const auto& p : count_vec) { - // Deduplicate error messages - std::map child_errors; - for (const Status& s : code_to_status[p.first]) { - ++child_errors[s.error_message()]; + int index = 0; + for (auto& s : nonderived_statuses) { + fmt.emplace_back(strings::StrCat(" (", index, ") ", s.ToString())); + ++index; } - string child_fmt; - for (auto& m : child_errors) { - child_fmt.append(strings::Printf( - " %s [%dx]", - str_util::StringReplace(m.first, "\n", "\n ", true).c_str(), - m.second)); - child_fmt.append("\n"); - } - // Strip last newline. - child_fmt = child_fmt.substr(0, child_fmt.size() - 1); + fmt.push_back(strings::Printf("%zu successful operations.", num_ok_)); + fmt.push_back( + strings::Printf("%zu derived errors ignored.", + children_.size() - nonderived_statuses.size())); - if (child_fmt.size() > kMaxChildMessageSize) { - child_fmt = - strings::StrCat(child_fmt.substr(0, kMaxChildMessageSize), "..."); - } - fmt.push_back(strings::Printf("Status code: %s [%dx]\n%s", - error_name(p.first).c_str(), p.second, - child_fmt.c_str())); + return Status( + nonderived_statuses[0].code(), + absl::StrJoin(fmt, "\n").substr(0, kMaxAggregatedStatusMessageSize)); + } else { + // All statuses are derived. Pick the first available status to return. + return children_[0]; + } +} + +// Concatenate all the status objects in the StatusGroup. This is used when +// individual Status objects in the StatusGroup are already summarized Status. 
+Status StatusGroup::as_concatenated_status() const { + if (ok_) { + return Status::OK(); } - fmt.push_back(strings::Printf("(%zd successful operations.)", num_ok_)); + std::vector nonderived_statuses = GetNonDerivedStatuses(children_); - // TODO(power): use the least-frequently occurring status for the return code - return Status(children_[0].code(), str_util::Join(fmt, "\n")); + // If only one root status is found, return it directly. + if (nonderived_statuses.size() == 1) { + return nonderived_statuses[0]; + } + + if (!nonderived_statuses.empty()) { + std::vector fmt; + fmt.emplace_back("\n====================="); + for (auto& s : nonderived_statuses) { + fmt.emplace_back(s.ToString()); + } + fmt.emplace_back("=====================\n"); + return Status( + nonderived_statuses[0].code(), + absl::StrJoin(fmt, "\n").substr(0, kMaxAggregatedStatusMessageSize)); + } else { + // All statuses are derived. Pick the first available status to return. + // This should not happen in normal execution. + return children_[0]; + } } } // namespace tensorflow diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h index fe3eec1be00..48174cb65c8 100644 --- a/tensorflow/core/lib/core/status.h +++ b/tensorflow/core/lib/core/status.h @@ -100,11 +100,16 @@ class Status { // Helper class to manage multiple child status values. class StatusGroup { public: - // Return a merged status with combined child status messages. - // - // The status code returned is OK if all children were successful, otherwise - // the first non-OK child status code is reported. - Status as_status() const; + // Utility function to mark a Status as derived. By marking derived status, + // Derived status messages are ignored when reporting errors to end users. + static Status MakeDerived(const Status& s); + static bool IsDerived(const Status& s); + + // Return a merged status with combined child status messages with a summary. + Status as_summary_status() const; + // Return a merged status with combined child status messages with + // concatenation. + Status as_concatenated_status() const; bool ok() const { return ok_; } diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc index 7c281840804..c932458fc76 100644 --- a/tensorflow/core/lib/core/status_test.cc +++ b/tensorflow/core/lib/core/status_test.cc @@ -14,9 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/lib/core/status.h" + #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -98,72 +98,67 @@ TEST(Status, EqualsDifferentMessage) { ASSERT_NE(a, b); } -TEST(StatusGroup, AcceptsFirstCode) { +TEST(StatusGroup, OKStatusGroup) { StatusGroup c; - const Status internal(errors::Internal("Original error.")); - c.Update(internal); c.Update(Status::OK()); c.Update(Status::OK()); - c.Update(Status::OK()); - ASSERT_EQ(c.as_status().code(), internal.code()); - ASSERT_EQ(c.ok(), false); + ASSERT_EQ(c.as_summary_status(), Status::OK()); + ASSERT_EQ(c.as_concatenated_status(), Status::OK()); } -TEST(StatusGroup, ContainsChildMessages) { +TEST(StatusGroup, AggregateWithSingleErrorStatus) { + StatusGroup c; + const Status internal(errors::Internal("Original error.")); + + c.Update(internal); + ASSERT_EQ(c.as_summary_status(), internal); + + Status concat_status = c.as_concatenated_status(); + ASSERT_EQ(concat_status.code(), internal.code()); + ASSERT_TRUE(absl::StrContains(concat_status.error_message(), + internal.error_message())); + + // Add derived error status + const Status derived = + StatusGroup::MakeDerived(errors::Internal("Derived error.")); + c.Update(derived); + + ASSERT_EQ(c.as_summary_status(), internal); + + concat_status = c.as_concatenated_status(); + ASSERT_EQ(concat_status.code(), internal.code()); + ASSERT_TRUE(absl::StrContains(concat_status.error_message(), + internal.error_message())); +} + +TEST(StatusGroup, AggregateWithMultipleErrorStatus) { StatusGroup c; const Status internal(errors::Internal("Original error.")); const Status cancelled(errors::Cancelled("Cancelled after 10 steps.")); const Status aborted(errors::Aborted("Aborted after 10 steps.")); + c.Update(internal); - for (size_t i = 0; i < 5; ++i) { - c.Update(cancelled); - } - for (size_t i = 0; i < 10; ++i) { - c.Update(aborted); - } - for (size_t i = 0; i < 100; ++i) { - c.Update(Status::OK()); - } + c.Update(cancelled); + c.Update(aborted); - ASSERT_EQ(c.as_status().code(), internal.code()); - EXPECT_TRUE(str_util::StrContains(c.as_status().error_message(), - internal.error_message())); - EXPECT_TRUE(str_util::StrContains(c.as_status().error_message(), - cancelled.error_message())); - EXPECT_TRUE(str_util::StrContains(c.as_status().error_message(), - aborted.error_message())); - StatusGroup d; - d.Update(c.as_status()); - c.Update(errors::FailedPrecondition("Failed!")); - d.Update(c.as_status()); - c.Update(errors::DataLoss("Data loss!")); - d.Update(c.as_status()); - LOG(INFO) << d.as_status(); -} + Status summary = c.as_summary_status(); -TEST(StatusGroup, ContainsIdenticalMessage) { - StatusGroup sg; - const Status internal(errors::Internal("Original error")); - for (size_t i = 0; i < 10; i++) { - sg.Update(internal); - } - EXPECT_EQ(sg.as_status(), internal); -} + ASSERT_EQ(summary.code(), internal.code()); + ASSERT_TRUE( + absl::StrContains(summary.error_message(), internal.error_message())); + ASSERT_TRUE( + absl::StrContains(summary.error_message(), cancelled.error_message())); + ASSERT_TRUE( + absl::StrContains(summary.error_message(), aborted.error_message())); -TEST(StatusGroup, ContainsCommonPrefix) { - StatusGroup sg; - const Status a(errors::Internal("Original error")); - const Status 
b(errors::Internal("Original error is")); - const Status c(errors::Internal("Original error is invalid")); - sg.Update(a); - sg.Update(c); - sg.Update(c); - sg.Update(b); - sg.Update(c); - sg.Update(b); - sg.Update(a); - sg.Update(b); - EXPECT_EQ(sg.as_status(), c); + Status concat_status = c.as_concatenated_status(); + ASSERT_EQ(concat_status.code(), internal.code()); + ASSERT_TRUE(absl::StrContains(concat_status.error_message(), + internal.error_message())); + ASSERT_TRUE(absl::StrContains(concat_status.error_message(), + cancelled.error_message())); + ASSERT_TRUE(absl::StrContains(concat_status.error_message(), + aborted.error_message())); } static void BM_TF_CHECK_OK(int iters) { diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc index e929ff45a1f..83d1f48c6c0 100644 --- a/tensorflow/core/lib/core/threadpool.cc +++ b/tensorflow/core/lib/core/threadpool.cc @@ -208,5 +208,10 @@ void ThreadPool::SetStealPartitions( const std::vector>& partitions) { impl_->SetStealPartitions(partitions); } + +Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() { + DCHECK(impl_ != nullptr); + return impl_.get(); +} } // namespace thread } // namespace tensorflow diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h index 90c9f294472..8a2c76d8361 100644 --- a/tensorflow/core/lib/core/threadpool.h +++ b/tensorflow/core/lib/core/threadpool.h @@ -18,12 +18,14 @@ limitations under the License. #include #include + #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace Eigen { class Allocator; +class ThreadPoolInterface; } // namespace Eigen namespace tensorflow { namespace thread { @@ -120,6 +122,11 @@ class ThreadPool { // thread in the pool. Returns -1 otherwise. int CurrentThreadId() const; + // If ThreadPool implementation is compatible with Eigen::ThreadPoolInterface, + // returns a non-null pointer. The caller does not own the object the returned + // pointer points to, and should not attempt to delete. 
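A hypothetical sketch of how the new accessor might be used, assuming Eigen's ThreadPoolDevice(ThreadPoolInterface*, int) constructor: wrapping the pool's Eigen-compatible interface in a ThreadPoolDevice lets Eigen expressions run on the same threads.

#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

void EvaluateOnPool() {
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                      "eigen_demo", /*num_threads=*/4);
  Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());
  Eigen::Tensor<float, 1> a(1024), b(1024), c(1024);
  a.setConstant(1.0f);
  b.setConstant(2.0f);
  c.device(device) = a + b;  // Evaluated on the wrapped thread pool.
}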
+ Eigen::ThreadPoolInterface* AsEigenThreadPool(); + struct Impl; private: diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc index db996b783fd..0e16d5b0b9f 100644 --- a/tensorflow/core/lib/core/threadpool_test.cc +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -40,7 +40,7 @@ TEST(ThreadPool, DoWork) { for (int num_threads = 1; num_threads < kNumThreads; num_threads++) { fprintf(stderr, "Testing with %d threads\n", num_threads); const int kWorkItems = 15; - bool work[kWorkItems]; + std::atomic work[kWorkItems]; for (int i = 0; i < kWorkItems; i++) { work[i] = false; } @@ -50,8 +50,7 @@ TEST(ThreadPool, DoWork) { pool.Schedule([&outer_context, &work, i]() { Context inner_context(ContextKind::kThread); ASSERT_EQ(outer_context, inner_context); - ASSERT_FALSE(work[i]); - work[i] = true; + ASSERT_FALSE(work[i].exchange(true)); }); } } @@ -65,7 +64,10 @@ void RunSharding(int64 block_size, int64 total, ThreadPool* threads) { mutex mu; int64 num_shards = 0; int64 num_done_work = 0; - std::vector work(total, false); + std::vector> work(total); + for (int i = 0; i < total; i++) { + work[i] = false; + } threads->TransformRangeConcurrently( block_size, total, [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) { @@ -75,14 +77,16 @@ void RunSharding(int64 block_size, int64 total, ThreadPool* threads) { mutex_lock l(mu); ++num_shards; for (; start < end; ++start) { - EXPECT_FALSE(work[start]); // No duplicate + EXPECT_FALSE(work[start].exchange(true)); // No duplicate ++num_done_work; - work[start] = true; } }); LOG(INFO) << block_size << " " << total; - const int64 num_workers = (total + block_size - 1) / block_size; EXPECT_EQ(num_done_work, total); + for (int i = 0; i < total; i++) { + ASSERT_TRUE(work[i]); + } + const int64 num_workers = (total + block_size - 1) / block_size; if (num_workers < threads->NumThreads()) { // If the intention is to limit the parallelism explicitly, we'd // better honor it. 
Ideally, even if per_thread_max_parallelism > @@ -129,7 +133,7 @@ TEST(ThreadPool, ParallelFor) { for (int num_threads = 1; num_threads < kNumThreads; num_threads++) { fprintf(stderr, "Testing with %d threads\n", num_threads); const int kWorkItems = 15; - bool work[kWorkItems]; + std::atomic work[kWorkItems]; ThreadPool pool(Env::Default(), "test", num_threads); for (int i = 0; i < kWorkItems; i++) { work[i] = false; @@ -139,8 +143,7 @@ TEST(ThreadPool, ParallelFor) { Context inner_context(ContextKind::kThread); ASSERT_EQ(outer_context, inner_context); for (int64 i = begin; i < end; ++i) { - ASSERT_FALSE(work[i]); - work[i] = true; + ASSERT_FALSE(work[i].exchange(true)); } }); for (int i = 0; i < kWorkItems; i++) { @@ -155,19 +158,18 @@ TEST(ThreadPool, ParallelForWithWorkerId) { for (int num_threads = 1; num_threads < kNumThreads; num_threads++) { fprintf(stderr, "Testing with %d threads\n", num_threads); const int kWorkItems = 15; - volatile std::atomic work[kWorkItems]; + std::atomic work[kWorkItems]; ThreadPool pool(Env::Default(), "test", num_threads); for (int i = 0; i < kWorkItems; i++) { work[i] = false; } - volatile std::atomic threads_running[kNumThreads + 1]; + std::atomic threads_running[kNumThreads + 1]; for (int i = 0; i < num_threads + 1; i++) { threads_running[i] = false; } pool.ParallelForWithWorkerId( kWorkItems, kHugeCost, - [&threads_running, &work, num_threads](int64 begin, int64 end, - int64 id) { + [&threads_running, &work](int64 begin, int64 end, int64 id) { // Store true for the current thread, and assert that another thread // is not running with the same id. ASSERT_LE(0, id); diff --git a/tensorflow/core/lib/gif/gif_io.h b/tensorflow/core/lib/gif/gif_io.h index 0a7967a5a15..e46a7917398 100644 --- a/tensorflow/core/lib/gif/gif_io.h +++ b/tensorflow/core/lib/gif/gif_io.h @@ -15,7 +15,7 @@ limitations under the License. // Functions to read and write images in GIF format. // -// The advantage over image/codec/png{enc,dec}ocder.h is that this library +// The advantage over image/codec/png{enc,dec}oder.h is that this library // supports both 8 and 16 bit images. // // The decoding routine accepts binary image data as a StringPiece. These are diff --git a/tensorflow/core/lib/gtl/map_util.h b/tensorflow/core/lib/gtl/map_util.h index 356e5d64188..6a48d5566e0 100644 --- a/tensorflow/core/lib/gtl/map_util.h +++ b/tensorflow/core/lib/gtl/map_util.h @@ -21,11 +21,14 @@ limitations under the License. #define TENSORFLOW_LIB_GTL_MAP_UTIL_H_ #include + #include #include #include #include +#include "tensorflow/core/lib/gtl/subtle/map_traits.h" + namespace tensorflow { namespace gtl { @@ -155,6 +158,34 @@ typename Collection::value_type::second_type& LookupOrInsert( typename Collection::value_type(key, value)); } +// Erases the m item identified by the given key, and returns the value +// associated with that key. It is assumed that the value (i.e., the +// mapped_type) is a pointer. Returns null if the key was not found in the +// m. 
+// +// Examples: +// std::map my_map; +// +// One line cleanup: +// delete EraseKeyReturnValuePtr(&my_map, "abc"); +// +// Use returned value: +// std::unique_ptr value_ptr( +// EraseKeyReturnValuePtr(&my_map, "abc")); +// if (value_ptr.get()) +// value_ptr->DoSomething(); +// +template +typename Collection::value_type::second_type EraseKeyReturnValuePtr( + Collection* collection, + const typename Collection::value_type::first_type& key) { + auto it = collection->find(key); + if (it == collection->end()) return nullptr; + auto v = gtl::subtle::GetMapped(*it); + collection->erase(it); + return v; +} + } // namespace gtl } // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/subtle/map_traits.h b/tensorflow/core/lib/gtl/subtle/map_traits.h new file mode 100644 index 00000000000..96578ab19a4 --- /dev/null +++ b/tensorflow/core/lib/gtl/subtle/map_traits.h @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Traits classes for performing uniform lookup on different map value types. +// +// The access is computed as follows: +// +// 1. If T has a `first` or `second` field, use them. +// 2. Otherwise if it has `key()` or `value()` methods, use them. +// 3. Otherwise the program is ill-formed. +#ifndef TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ +#define TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ + +#include + +namespace tensorflow { +namespace gtl { +namespace subtle { +namespace internal_map_traits { +struct Rank1 {}; +struct Rank0 : Rank1 {}; + +template +auto GetKey(V&& v, Rank0) -> decltype((std::forward(v).first)) { + return std::forward(v).first; +} +template +auto GetKey(V&& v, Rank1) -> decltype(std::forward(v).key()) { + return std::forward(v).key(); +} + +template +auto GetMapped(V&& v, Rank0) -> decltype((std::forward(v).second)) { + return std::forward(v).second; +} +template +auto GetMapped(V&& v, Rank1) -> decltype(std::forward(v).value()) { + return std::forward(v).value(); +} + +} // namespace internal_map_traits + +// Accesses the `key_type` from a `value_type`. +template +auto GetKey(V&& v) + -> decltype(internal_map_traits::GetKey(std::forward(v), + internal_map_traits::Rank0())) { + return internal_map_traits::GetKey(std::forward(v), + internal_map_traits::Rank0()); +} + +// Accesses the `mapped_type` from a `value_type`. 
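A hypothetical usage sketch of the traits declared in this header; FakeEntry is an assumed example type (not from the codebase) illustrating the key()/value() fallback documented above.

#include <map>
#include <string>

#include "tensorflow/core/lib/gtl/subtle/map_traits.h"

struct FakeEntry {
  std::string key() const { return "k"; }
  int value() const { return 42; }
};

void MapTraitsDemo() {
  std::map<std::string, int> m = {{"a", 1}};
  auto it = m.begin();
  auto k1 = tensorflow::gtl::subtle::GetKey(*it);     // "a"  (uses .first)
  auto v1 = tensorflow::gtl::subtle::GetMapped(*it);  // 1    (uses .second)
  FakeEntry e;
  auto k2 = tensorflow::gtl::subtle::GetKey(e);       // "k"  (uses key())
  auto v2 = tensorflow::gtl::subtle::GetMapped(e);    // 42   (uses value())
  (void)k1; (void)v1; (void)k2; (void)v2;
}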
+template +auto GetMapped(V&& v) + -> decltype(internal_map_traits::GetMapped(std::forward(v), + internal_map_traits::Rank0())) { + return internal_map_traits::GetMapped(std::forward(v), + internal_map_traits::Rank0()); +} + +} // namespace subtle +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ diff --git a/tensorflow/core/lib/hash/crc32c.cc b/tensorflow/core/lib/hash/crc32c.cc index bd3b41e748c..9a3eba704ac 100644 --- a/tensorflow/core/lib/hash/crc32c.cc +++ b/tensorflow/core/lib/hash/crc32c.cc @@ -263,5 +263,16 @@ uint32 Extend(uint32 crc, const char *buf, size_t size) { return l ^ 0xffffffffu; } +#if defined(PLATFORM_GOOGLE) +uint32 Extend(uint32 crc, const absl::Cord &cord) { + absl::CordReader reader(cord); + absl::string_view fragment; + while (reader.ReadFragment(&fragment)) { + crc = Extend(crc, fragment.data(), fragment.size()); + } + return crc; +} +#endif + } // namespace crc32c } // namespace tensorflow diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h index 2718cd31b37..edf9eb05320 100644 --- a/tensorflow/core/lib/hash/crc32c.h +++ b/tensorflow/core/lib/hash/crc32c.h @@ -17,6 +17,11 @@ limitations under the License. #define TENSORFLOW_CORE_LIB_HASH_CRC32C_H_ #include + +// NOLINTNEXTLINE +#include "tensorflow/core/platform/platform.h" +// NOLINTNEXTLINE +#include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -30,6 +35,11 @@ extern uint32 Extend(uint32 init_crc, const char* data, size_t n); // Return the crc32c of data[0,n-1] inline uint32 Value(const char* data, size_t n) { return Extend(0, data, n); } +#if defined(PLATFORM_GOOGLE) +extern uint32 Extend(uint32 init_crc, const absl::Cord& cord); +inline uint32 Value(const absl::Cord& cord) { return Extend(0, cord); } +#endif + static const uint32 kMaskDelta = 0xa282ead8ul; // Return a masked representation of crc. 
diff --git a/tensorflow/core/lib/hash/crc32c_test.cc b/tensorflow/core/lib/hash/crc32c_test.cc index 5213e4c532f..1080b7b1613 100644 --- a/tensorflow/core/lib/hash/crc32c_test.cc +++ b/tensorflow/core/lib/hash/crc32c_test.cc @@ -70,6 +70,17 @@ TEST(CRC, Mask) { ASSERT_EQ(crc, Unmask(Unmask(Mask(Mask(crc))))); } +#if defined(PLATFORM_GOOGLE) +TEST(CRC, ValuesWithCord) { + ASSERT_NE(Value(absl::Cord("a")), Value(absl::Cord("foo"))); +} + +TEST(CRC, ExtendWithCord) { + ASSERT_EQ(Value(absl::Cord("hello world")), + Extend(Value(absl::Cord("hello ")), absl::Cord("world"))); +} +#endif + static void BM_CRC(int iters, int len) { std::string input(len, 'x'); uint32 h = 0; diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 2c6db2487ea..b0718a35101 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -107,6 +107,27 @@ Status RecordWriter::WriteRecord(StringPiece data) { return dest_->Append(StringPiece(footer, sizeof(footer))); } +#if defined(PLATFORM_GOOGLE) +Status RecordWriter::WriteRecord(const absl::Cord& data) { + if (dest_ == nullptr) { + return Status(::tensorflow::error::FAILED_PRECONDITION, + "Writer not initialized or previously closed"); + } + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + char header[kHeaderSize]; + char footer[kFooterSize]; + PopulateHeader(header, data); + PopulateFooter(footer, data); + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + TF_RETURN_IF_ERROR(dest_->Append(data)); + return dest_->Append(StringPiece(footer, sizeof(footer))); +} +#endif + Status RecordWriter::Close() { if (dest_ == nullptr) return Status::OK(); #if !defined(IS_SLIM_BUILD) diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 1212e1fafb4..dba4d75799e 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" #endif // IS_SLIM_BUILD +#include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -71,6 +72,10 @@ class RecordWriter { Status WriteRecord(StringPiece slice); +#if defined(PLATFORM_GOOGLE) + Status WriteRecord(const absl::Cord& data); +#endif + // Flushes any buffered data held by underlying containers of the // RecordWriter to the WritableFile. Does *not* flush the // WritableFile. @@ -90,6 +95,11 @@ class RecordWriter { // "footer[0,kFooterSize-1]". The record-footer is based on data[0, n-1]. 
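A hypothetical standalone sketch of the record framing described in the WriteRecord comment above (12-byte header of length plus masked CRC of the length bytes, then the payload, then a 4-byte footer of the masked payload CRC), assuming the EncodeFixed helpers from tensorflow/core/lib/core/coding.h and the crc32c routines touched by this change:

#include <string>

#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/hash/crc32c.h"

std::string FrameRecord(const std::string& data) {
  char header[12];  // assumes kHeaderSize == 12
  tensorflow::core::EncodeFixed64(header, data.size());
  tensorflow::core::EncodeFixed32(
      header + sizeof(tensorflow::uint64),
      tensorflow::crc32c::Mask(
          tensorflow::crc32c::Value(header, sizeof(tensorflow::uint64))));
  char footer[4];  // assumes kFooterSize == 4
  tensorflow::core::EncodeFixed32(
      footer, tensorflow::crc32c::Mask(
                  tensorflow::crc32c::Value(data.data(), data.size())));
  std::string framed;
  framed.append(header, sizeof(header));
  framed.append(data);
  framed.append(footer, sizeof(footer));
  return framed;
}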
inline static void PopulateFooter(char* footer, const char* data, size_t n); +#if defined(PLATFORM_GOOGLE) + inline static void PopulateHeader(char* header, const absl::Cord& data); + inline static void PopulateFooter(char* footer, const absl::Cord& data); +#endif + private: WritableFile* dest_; RecordWriterOptions options_; @@ -98,6 +108,12 @@ class RecordWriter { return crc32c::Mask(crc32c::Value(data, n)); } +#if defined(PLATFORM_GOOGLE) + inline static uint32 MaskedCrc(const absl::Cord& data) { + return crc32c::Mask(crc32c::Value(data)); + } +#endif + TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); }; @@ -111,6 +127,18 @@ void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) { core::EncodeFixed32(footer, MaskedCrc(data, n)); } +#if defined(PLATFORM_GOOGLE) +void RecordWriter::PopulateHeader(char* header, const absl::Cord& data) { + core::EncodeFixed64(header + 0, data.size()); + core::EncodeFixed32(header + sizeof(uint64), + MaskedCrc(header, sizeof(uint64))); +} + +void RecordWriter::PopulateFooter(char* footer, const absl::Cord& data) { + core::EncodeFixed32(footer, MaskedCrc(data)); +} +#endif + } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc index e6a2e4a0662..e3e69909a3a 100644 --- a/tensorflow/core/lib/io/recordio_test.cc +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -62,6 +62,12 @@ class StringDest : public WritableFile { contents_->append(slice.data(), slice.size()); return Status::OK(); } +#if defined(PLATFORM_GOOGLE) + Status Append(const absl::Cord& data) override { + contents_->append(data.ToString()); + return Status::OK(); + } +#endif Status Tell(int64* pos) override { *pos = contents_->size(); return Status::OK(); @@ -130,6 +136,13 @@ class RecordioTest : public ::testing::Test { TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg))); } +#if defined(PLATFORM_GOOGLE) + void Write(const absl::Cord& msg) { + ASSERT_TRUE(!reading_) << "Write() after starting to read"; + TF_ASSERT_OK(writer_->WriteRecord(msg)); + } +#endif + size_t WrittenBytes() const { return contents_.size(); } string Read() { @@ -191,6 +204,21 @@ TEST_F(RecordioTest, ReadWrite) { ASSERT_EQ("EOF", Read()); // Make sure reads at eof work } +#if defined(PLATFORM_GOOGLE) +TEST_F(RecordioTest, ReadWriteCords) { + Write(absl::Cord("foo")); + Write(absl::Cord("bar")); + Write(absl::Cord("")); + Write(absl::Cord("xxxx")); + ASSERT_EQ("foo", Read()); + ASSERT_EQ("bar", Read()); + ASSERT_EQ("", Read()); + ASSERT_EQ("xxxx", Read()); + ASSERT_EQ("EOF", Read()); + ASSERT_EQ("EOF", Read()); // Make sure reads at eof work +} +#endif + TEST_F(RecordioTest, ManyRecords) { for (int i = 0; i < 100000; i++) { Write(NumberString(i)); diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h index b9c6b8d9d23..788d192816f 100644 --- a/tensorflow/core/lib/io/table.h +++ b/tensorflow/core/lib/io/table.h @@ -17,16 +17,15 @@ limitations under the License. #define TENSORFLOW_CORE_LIB_IO_TABLE_H_ #include + #include "tensorflow/core/lib/io/iterator.h" namespace tensorflow { + class RandomAccessFile; namespace table { -class Block; -class BlockHandle; -class Footer; struct Options; // A Table is a sorted map from strings to strings. 
Tables are diff --git a/tensorflow/core/lib/io/zlib_inputstream.cc b/tensorflow/core/lib/io/zlib_inputstream.cc index d069db6d20b..a489d2e9d50 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.cc +++ b/tensorflow/core/lib/io/zlib_inputstream.cc @@ -197,24 +197,21 @@ Status ZlibInputStream::ReadNBytes(int64 bytes_to_read, string* result) { // Now that the cache is empty we need to inflate more data. - // Step 1. Fill up input buffer. - // We read from stream only after the previously read contents have been - // completely consumed. This is an optimization and can be removed if - // it causes problems. `ReadFromStream` is capable of handling partially - // filled up buffers. - if (z_stream_def_->stream->avail_in == 0) { - TF_RETURN_IF_ERROR(ReadFromStream()); - } - - // Step 2. Setup output stream. + // Step 1. Setup output stream. z_stream_def_->stream->next_out = z_stream_def_->output.get(); next_unread_byte_ = reinterpret_cast(z_stream_def_->output.get()); z_stream_def_->stream->avail_out = output_buffer_capacity_; - // Step 3. Inflate Inflate Inflate! + // Step 2. Try to inflate some input data. TF_RETURN_IF_ERROR(Inflate()); - bytes_to_read -= ReadBytesFromCache(bytes_to_read, result); + // Step 3. Read any data produced by inflate. If no progress was made by + // inflate, read more compressed data from the input stream. + if (NumUnreadBytes() == 0) { + TF_RETURN_IF_ERROR(ReadFromStream()); + } else { + bytes_to_read -= ReadBytesFromCache(bytes_to_read, result); + } } return Status::OK(); @@ -224,7 +221,11 @@ int64 ZlibInputStream::Tell() const { return bytes_read_; } Status ZlibInputStream::Inflate() { int error = inflate(z_stream_def_->stream.get(), zlib_options_.flush_mode); - if (error != Z_OK && error != Z_STREAM_END) { + // Source: http://zlib.net/manual.html + // Z_BUF_ERROR: `inflate` returns Z_BUF_ERROR if no progress was made. This is + // not fatal and `inflate` can be called again with more input and output + // space to continue inflating. 
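A hypothetical standalone sketch, using raw zlib rather than the ZlibInputStream class, of the refill-and-retry pattern that treating Z_BUF_ERROR as non-fatal enables: when inflate() makes no progress, the caller provides more input (or drains output) and calls it again instead of failing.

#include <cstring>
#include <string>

#include <zlib.h>

std::string InflateAll(const std::string& compressed) {
  z_stream strm;
  std::memset(&strm, 0, sizeof(strm));
  if (inflateInit(&strm) != Z_OK) return "";
  strm.next_in =
      reinterpret_cast<Bytef*>(const_cast<char*>(compressed.data()));
  strm.avail_in = static_cast<uInt>(compressed.size());
  std::string output;
  char buf[4096];
  int err = Z_OK;
  while (err != Z_STREAM_END) {
    strm.next_out = reinterpret_cast<Bytef*>(buf);
    strm.avail_out = sizeof(buf);
    err = inflate(&strm, Z_NO_FLUSH);
    if (err != Z_OK && err != Z_STREAM_END && err != Z_BUF_ERROR) break;
    output.append(buf, sizeof(buf) - strm.avail_out);
    // A streaming reader would fetch more compressed bytes here; with the
    // whole input already supplied, an empty input buffer means we are done.
    if (err == Z_BUF_ERROR && strm.avail_in == 0) break;
  }
  inflateEnd(&strm);
  return output;
}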
+ if (error != Z_OK && error != Z_STREAM_END && error != Z_BUF_ERROR) { string error_string = strings::StrCat("inflate() failed with error ", error); if (z_stream_def_->stream->msg != nullptr) { diff --git a/tensorflow/core/lib/math/math_util_test.cc b/tensorflow/core/lib/math/math_util_test.cc index cad5d0d8993..386a4a24b6d 100644 --- a/tensorflow/core/lib/math/math_util_test.cc +++ b/tensorflow/core/lib/math/math_util_test.cc @@ -252,7 +252,7 @@ void TestFloatIPow(const int max_exponent, const T start, const T end, const T step) { for (T f = start; f < end; f += step) { for (int i = 0; i < max_exponent; ++i) { - EXPECT_FLOAT_EQ(MathUtil::IPow(f, i), pow(f, i)); + EXPECT_FLOAT_EQ(MathUtil::IPow(f, i), std::pow(f, i)); } } } diff --git a/tensorflow/core/lib/monitoring/collection_registry.cc b/tensorflow/core/lib/monitoring/collection_registry.cc index fface033cb9..461b92a36a9 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.cc +++ b/tensorflow/core/lib/monitoring/collection_registry.cc @@ -74,8 +74,9 @@ CollectionRegistry::Register(const AbstractMetricDef* const metric_def, const auto found_it = registry_.find(metric_def->name()); if (found_it != registry_.end()) { - LOG(FATAL) << "Cannot register 2 metrics with the same name: " + LOG(ERROR) << "Cannot register 2 metrics with the same name: " << metric_def->name(); + return nullptr; } registry_.insert( {metric_def->name(), diff --git a/tensorflow/core/lib/monitoring/collection_registry_test.cc b/tensorflow/core/lib/monitoring/collection_registry_test.cc index ce87e4dcae6..52cdb840068 100644 --- a/tensorflow/core/lib/monitoring/collection_registry_test.cc +++ b/tensorflow/core/lib/monitoring/collection_registry_test.cc @@ -73,12 +73,9 @@ TEST(CollectionRegistryDeathTest, DuplicateRegistration) { auto handle = collection_registry->Register(&metric_def, EmptyCollectionFunction); - EXPECT_DEATH( - { - auto duplicate_handle = - collection_registry->Register(&metric_def, EmptyCollectionFunction); - }, - "/tensorflow/metric"); + auto duplicate_handle = + collection_registry->Register(&metric_def, EmptyCollectionFunction); + EXPECT_EQ(duplicate_handle, nullptr); } TEST(CollectMetricsTest, Counter) { @@ -374,7 +371,7 @@ class FakeClockEnv : public EnvWrapper { void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; } // Method that this environment specifically overrides. - uint64 NowMicros() override { return current_millis_ * 1000; } + uint64 NowMicros() const override { return current_millis_ * 1000; } private: uint64 current_millis_; diff --git a/tensorflow/core/lib/monitoring/counter.h b/tensorflow/core/lib/monitoring/counter.h index 8ff810db41d..20522192778 100644 --- a/tensorflow/core/lib/monitoring/counter.h +++ b/tensorflow/core/lib/monitoring/counter.h @@ -16,9 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ #define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +// clang-format on + // We replace this implementation with a null implementation for mobile // platforms. -#include "tensorflow/core/platform/platform.h" #ifdef IS_MOBILE_PLATFORM #include "tensorflow/core/lib/monitoring/mobile_counter.h" #else @@ -27,6 +31,7 @@ limitations under the License. 
#include #include +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/lib/monitoring/metric_def.h" #include "tensorflow/core/platform/logging.h" @@ -48,7 +53,7 @@ namespace monitoring { // This class is thread-safe. class CounterCell { public: - CounterCell(int64 value) : value_(value) {} + explicit CounterCell(int64 value) : value_(value) {} ~CounterCell() {} // Atomically increments the value by step. @@ -97,6 +102,8 @@ class Counter { template CounterCell* GetCell(const Labels&... labels) LOCKS_EXCLUDED(mu_); + Status GetStatus() { return status_; } + private: explicit Counter( const MetricDef& metric_def) @@ -109,10 +116,19 @@ class Counter { for (const auto& cell : cells_) { metric_collector.CollectValue(cell.first, cell.second.value()); } - })) {} + })) { + if (registration_handle_) { + status_ = Status::OK(); + } else { + status_ = Status(tensorflow::error::Code::ALREADY_EXISTS, + "Another metric with the same name already exists."); + } + } mutable mutex mu_; + Status status_; + // The metric definition. This will be used to identify the metric when we // register it for collection. const MetricDef metric_def_; diff --git a/tensorflow/core/lib/monitoring/counter_test.cc b/tensorflow/core/lib/monitoring/counter_test.cc index 75b9a19a6fd..1dec04df980 100644 --- a/tensorflow/core/lib/monitoring/counter_test.cc +++ b/tensorflow/core/lib/monitoring/counter_test.cc @@ -86,6 +86,14 @@ TEST(UnlabeledCounterDeathTest, DiesOnDecrement) { "decrement"); } +TEST(LabeledCounterTest, SameName) { + auto* same_counter = Counter<1>::New("/tensorflow/test/counter_with_labels", + "Counter with one label.", "MyLabel"); + EXPECT_TRUE(counter_with_labels->GetStatus().ok()); + EXPECT_FALSE(same_counter->GetStatus().ok()); + delete same_counter; +} + } // namespace } // namespace monitoring } // namespace tensorflow diff --git a/tensorflow/core/lib/monitoring/gauge.h b/tensorflow/core/lib/monitoring/gauge.h index ee9a862f40a..83edf68d1bf 100644 --- a/tensorflow/core/lib/monitoring/gauge.h +++ b/tensorflow/core/lib/monitoring/gauge.h @@ -16,9 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ #define TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +// clang-format on + // We replace this implementation with a null implementation for mobile // platforms. -#include "tensorflow/core/platform/platform.h" #ifdef IS_MOBILE_PLATFORM #include "tensorflow/core/lib/monitoring/mobile_gauge.h" #else @@ -27,6 +31,7 @@ limitations under the License. #include #include +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/lib/monitoring/metric_def.h" #include "tensorflow/core/platform/macros.h" @@ -149,6 +154,8 @@ class Gauge { template GaugeCell* GetCell(const Labels&... labels) LOCKS_EXCLUDED(mu_); + Status GetStatus() { return status_; } + private: explicit Gauge( const MetricDef& metric_def) @@ -161,10 +168,19 @@ class Gauge { for (const auto& cell : cells_) { metric_collector.CollectValue(cell.first, cell.second.value()); } - })) {} + })) { + if (registration_handle_) { + status_ = Status::OK(); + } else { + status_ = Status(tensorflow::error::Code::ALREADY_EXISTS, + "Another metric with the same name already exists."); + } + } mutable mutex mu_; + Status status_; + // The metric definition. 
This will be used to identify the metric when we // register it for collection. const MetricDef metric_def_; diff --git a/tensorflow/core/lib/monitoring/gauge_test.cc b/tensorflow/core/lib/monitoring/gauge_test.cc index c8f673db389..7bbe7596fe4 100644 --- a/tensorflow/core/lib/monitoring/gauge_test.cc +++ b/tensorflow/core/lib/monitoring/gauge_test.cc @@ -109,6 +109,14 @@ TEST(GaugeOfBoolValue, GetCell) { EXPECT_EQ(false, same_cell->value()); } +TEST(LabeledGaugeTest, SameName) { + auto* same_gauge = Gauge::New("/tensorflow/test/gauge_with_labels", + "Gauge with one label.", "MyLabel"); + EXPECT_TRUE(gauge_with_labels->GetStatus().ok()); + EXPECT_FALSE(same_gauge->GetStatus().ok()); + delete same_gauge; +} + } // namespace } // namespace monitoring } // namespace tensorflow diff --git a/tensorflow/core/lib/monitoring/mobile_counter.h b/tensorflow/core/lib/monitoring/mobile_counter.h index c297d843d2f..db46072a3ee 100644 --- a/tensorflow/core/lib/monitoring/mobile_counter.h +++ b/tensorflow/core/lib/monitoring/mobile_counter.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_ #define TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_ +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -53,6 +54,8 @@ class Counter { return &default_counter_cell_; } + Status GetStatus() { return Status::OK(); } + private: Counter() {} diff --git a/tensorflow/core/lib/monitoring/mobile_gauge.h b/tensorflow/core/lib/monitoring/mobile_gauge.h index a03b41aef33..bb86d253b11 100644 --- a/tensorflow/core/lib/monitoring/mobile_gauge.h +++ b/tensorflow/core/lib/monitoring/mobile_gauge.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ #define TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -58,6 +59,8 @@ class Gauge { return &default_gauge_cell_; } + Status GetStatus() { return Status::OK(); } + private: Gauge() {} diff --git a/tensorflow/core/lib/monitoring/mobile_sampler.h b/tensorflow/core/lib/monitoring/mobile_sampler.h index 77310dd619f..5233f0ff472 100644 --- a/tensorflow/core/lib/monitoring/mobile_sampler.h +++ b/tensorflow/core/lib/monitoring/mobile_sampler.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/monitoring/metric_def.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -86,6 +87,8 @@ class Sampler { return &default_sampler_cell_; } + Status GetStatus() { return Status::OK(); } + private: Sampler(std::unique_ptr buckets) : buckets_(std::move(buckets)) {} diff --git a/tensorflow/core/lib/monitoring/sampler.cc b/tensorflow/core/lib/monitoring/sampler.cc index 23d3668fbd1..20c5f1a73fe 100644 --- a/tensorflow/core/lib/monitoring/sampler.cc +++ b/tensorflow/core/lib/monitoring/sampler.cc @@ -15,9 +15,13 @@ limitations under the License. #include "tensorflow/core/lib/monitoring/sampler.h" +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +// clang-format on + // We replace this implementation with a null implementation for mobile // platforms. -#include "tensorflow/core/platform/platform.h" #ifdef IS_MOBILE_PLATFORM // Do nothing. 
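Because duplicate registration now surfaces as a null handle and an ALREADY_EXISTS status instead of LOG(FATAL), callers can probe for name clashes directly. A minimal sketch using Counter, mirroring the SameName tests; the metric name is illustrative only:

```cpp
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/logging.h"

void DetectDuplicateMetric() {
  using tensorflow::monitoring::Counter;
  auto* first = Counter<0>::New("/example/duplicate_metric", "First owner.");
  auto* second = Counter<0>::New("/example/duplicate_metric", "Second owner.");
  CHECK(first->GetStatus().ok());
  // The second registration fails (ALREADY_EXISTS) instead of aborting the
  // process, so the caller decides how to recover.
  CHECK(!second->GetStatus().ok());
  delete second;
}
```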
#else @@ -92,6 +96,12 @@ class ExponentialBuckets : public Buckets { } // namespace +// static +std::unique_ptr Buckets::Explicit(std::vector bucket_limits) { + return std::unique_ptr( + new ExplicitBuckets(std::move(bucket_limits))); +} + // static std::unique_ptr Buckets::Explicit( std::initializer_list bucket_limits) { diff --git a/tensorflow/core/lib/monitoring/sampler.h b/tensorflow/core/lib/monitoring/sampler.h index a4f397f5566..c6f32d46fa2 100644 --- a/tensorflow/core/lib/monitoring/sampler.h +++ b/tensorflow/core/lib/monitoring/sampler.h @@ -16,17 +16,23 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ #define TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +// clang-format on + // We replace this implementation with a null implementation for mobile // platforms. -#include "tensorflow/core/platform/platform.h" #ifdef IS_MOBILE_PLATFORM #include "tensorflow/core/lib/monitoring/mobile_sampler.h" #else #include + #include #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/histogram/histogram.h" #include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/lib/monitoring/metric_def.h" @@ -90,6 +96,11 @@ class Buckets { static std::unique_ptr Explicit( std::initializer_list bucket_limits); + // This alternative Explicit Buckets factory method is primarily meant to be + // used by the CLIF layer code paths that are incompatible with + // initialize_lists. + static std::unique_ptr Explicit(std::vector bucket_limits); + virtual const std::vector& explicit_bounds() const = 0; }; @@ -128,6 +139,8 @@ class Sampler { template SamplerCell* GetCell(const Labels&... labels) LOCKS_EXCLUDED(mu_); + Status GetStatus() { return status_; } + private: friend class SamplerCell; @@ -144,10 +157,19 @@ class Sampler { for (const auto& cell : cells_) { metric_collector.CollectValue(cell.first, cell.second.value()); } - })) {} + })) { + if (registration_handle_) { + status_ = Status::OK(); + } else { + status_ = Status(tensorflow::error::Code::ALREADY_EXISTS, + "Another metric with the same name already exists."); + } + } mutable mutex mu_; + Status status_; + // The metric definition. This will be used to identify the metric when we // register it for collection. 
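The new std::vector overload of Buckets::Explicit lets bucket boundaries be built at runtime, or passed from bindings that cannot form an initializer_list, per the CLIF note above. A minimal sketch; the metric name is illustrative:

```cpp
#include <utility>
#include <vector>

#include "tensorflow/core/lib/monitoring/sampler.h"

tensorflow::monitoring::Sampler<0>* MakeLatencySampler() {
  // Power-of-two bucket limits computed at runtime.
  std::vector<double> limits;
  for (double b = 1.0; b <= 1024.0; b *= 2.0) limits.push_back(b);
  return tensorflow::monitoring::Sampler<0>::New(
      {"/example/latency_ms", "Request latency in milliseconds."},
      tensorflow::monitoring::Buckets::Explicit(std::move(limits)));
}
```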
const MetricDef diff --git a/tensorflow/core/lib/monitoring/sampler_test.cc b/tensorflow/core/lib/monitoring/sampler_test.cc index d61d858b6b4..8be15f92185 100644 --- a/tensorflow/core/lib/monitoring/sampler_test.cc +++ b/tensorflow/core/lib/monitoring/sampler_test.cc @@ -61,7 +61,7 @@ TEST(LabeledSamplerTest, ExplicitBucketBoundaries) { auto* init_sampler_without_labels = Sampler<0>::New({"/tensorflow/test/init_sampler_without_labels", "Sampler without labels initialized as empty."}, - Buckets::Explicit({1.5, 2.8})); + Buckets::Explicit(std::vector{1.5, 2.8})); TEST(UnlabeledSamplerTest, InitializedEmpty) { Histogram empty; @@ -112,6 +112,15 @@ TEST(ExponentialSamplerTest, ExponentialBucketBoundaries) { EqHistograms(expected, cell->value()); } +TEST(ExplicitSamplerTest, SameName) { + auto* same_sampler = Sampler<1>::New({"/tensorflow/test/sampler_with_labels", + "Sampler with one label.", "MyLabel"}, + Buckets::Explicit({10.0, 20.0})); + EXPECT_TRUE(sampler_with_labels->GetStatus().ok()); + EXPECT_FALSE(same_sampler->GetStatus().ok()); + delete same_sampler; +} + } // namespace } // namespace monitoring } // namespace tensorflow diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h index c3801a04128..102f9ba7ea8 100644 --- a/tensorflow/core/lib/random/random_distributions.h +++ b/tensorflow/core/lib/random/random_distributions.h @@ -16,12 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#include + #define _USE_MATH_DEFINES #include #include #undef _USE_MATH_DEFINES -#include #include #include @@ -236,6 +237,73 @@ class UniformDistribution { uint64 range_; }; +// Similar to `UniformDistribution`, except that instead of generating numbers +// in the range [low, high), it generates numbers covering the whole range of +// the integer type. +template +class UniformFullIntDistribution; + +template +class UniformFullIntDistribution32 { + public: + // The number of elements that will be returned. + static const int kResultElementCount = Generator::kResultElementCount; + // Cost of generation of a single element (in cycles). + static const int kElementCost = 3; + // Indicate that this distribution may take variable number of samples + // during the runtime. + static const bool kVariableSamplesPerOutput = false; + typedef Array ResultType; + typedef IntType ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(Generator* gen) { + typename Generator::ResultType sample = (*gen)(); + ResultType result; + for (int i = 0; i < kResultElementCount; ++i) { + result[i] = sample[i]; + } + return result; + } +}; + +template +class UniformFullIntDistribution64 { + public: + // The number of elements that will be returned. + static const int kResultElementCount = Generator::kResultElementCount / 2; + // Cost of generation of a single element (in cycles). + static const int kElementCost = 3; + // Indicate that this distribution may take variable number of samples + // during the runtime. 
+ static const bool kVariableSamplesPerOutput = false; + typedef Array ResultType; + typedef IntType ResultElementType; + + PHILOX_DEVICE_INLINE + ResultType operator()(Generator* gen) { + typename Generator::ResultType sample = (*gen)(); + ResultType result; + for (int i = 0; i < kResultElementCount; ++i) { + result[i] = sample[2 * i] | static_cast(sample[2 * i + 1]) << 32; + } + return result; + } +}; + +template +class UniformFullIntDistribution + : public UniformFullIntDistribution32 {}; +template +class UniformFullIntDistribution + : public UniformFullIntDistribution32 {}; +template +class UniformFullIntDistribution + : public UniformFullIntDistribution64 {}; +template +class UniformFullIntDistribution + : public UniformFullIntDistribution64 {}; + // A class that adapts the underlying native multiple samples to return a single // sample at a time. template diff --git a/tensorflow/core/lib/strings/base64.h b/tensorflow/core/lib/strings/base64.h index 48a7f42b81d..cb8f50df11f 100644 --- a/tensorflow/core/lib/strings/base64.h +++ b/tensorflow/core/lib/strings/base64.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LIB_STRINGS_B64_H_ -#define TENSORFLOW_LIB_STRINGS_B64_H_ +#ifndef TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ #include #include "tensorflow/core/lib/core/status.h" @@ -34,4 +34,4 @@ Status Base64Decode(StringPiece data, string* decoded); } // namespace tensorflow -#endif // TENSORFLOW_LIB_STRINGS_B64_H_ +#endif // TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ diff --git a/tensorflow/core/lib/strings/proto_serialization.cc b/tensorflow/core/lib/strings/proto_serialization.cc index 2341d3e341d..5ffc845d098 100644 --- a/tensorflow/core/lib/strings/proto_serialization.cc +++ b/tensorflow/core/lib/strings/proto_serialization.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/proto_serialization.h" #include + #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -24,7 +25,35 @@ limitations under the License. namespace tensorflow { namespace { -static const int kInlinedBufferSize = 256; + +// Helper for deterministic serialization. +class DeterministicSerializer { + public: + explicit DeterministicSerializer(const protobuf::MessageLite& msg) + : DeterministicSerializer(msg, msg.ByteSizeLong()) {} + + DeterministicSerializer(const protobuf::MessageLite& msg, size_t size) + : size_(size) { + char* ptr = space_; + if (size_ > sizeof(space_)) { + ptr = new char[size_]; + alloc_.reset(ptr); + } + bool ok = SerializeToBufferDeterministic(msg, ptr, size_); + DCHECK(ok); + } + + size_t size() const { return size_; } + const char* data() const { return alloc_ == nullptr ? space_ : alloc_.get(); } + + private: + // Avoid InlinedVector since it causes 2x slowdown in the compilation + // of graphs containing large tensors in debug mode. 
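A sketch of driving the new full-range integer distribution added above directly with a Philox generator; the seed is arbitrary and the usage pattern is assumed to match the existing distributions in this header:

```cpp
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/types.h"

void SampleFullUint64Range() {
  using tensorflow::random::PhiloxRandom;
  using tensorflow::random::UniformFullIntDistribution;

  PhiloxRandom gen(/*seed=*/42);
  UniformFullIntDistribution<PhiloxRandom, tensorflow::uint64> dist;
  // For 64-bit results two 32-bit Philox outputs are packed per element, so
  // kResultElementCount is half of the generator's.
  const auto samples = dist(&gen);
  for (int i = 0; i < dist.kResultElementCount; ++i) {
    (void)samples[i];  // values cover the whole uint64 range, no rejection
  }
}
```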
+ static constexpr int kInlinedBufferSize = 256; + const size_t size_; + std::unique_ptr alloc_; + char space_[kInlinedBufferSize]; +}; } // namespace bool SerializeToStringDeterministic(const protobuf::MessageLite& msg, @@ -51,28 +80,20 @@ bool AreSerializedProtosEqual(const protobuf::MessageLite& x, const size_t size = x.ByteSizeLong(); if (size != y.ByteSizeLong()) return false; if (size == 0) return true; - gtl::InlinedVector x_serialized(size); - bool success_x = SerializeToBufferDeterministic(x, x_serialized.data(), size); - DCHECK(success_x); - gtl::InlinedVector y_serialized(size); - bool success_y = SerializeToBufferDeterministic(y, y_serialized.data(), size); - DCHECK(success_y); + DeterministicSerializer x_serialized(x, size); + DeterministicSerializer y_serialized(y, size); return memcmp(x_serialized.data(), y_serialized.data(), size) == 0; } uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto, uint64 seed) { - const size_t size = proto.ByteSizeLong(); - gtl::InlinedVector serialized(size); - SerializeToBufferDeterministic(proto, serialized.data(), size); - return Hash64(serialized.data(), size, seed); + DeterministicSerializer serialized(proto); + return Hash64(serialized.data(), serialized.size(), seed); } uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto) { - const size_t size = proto.ByteSizeLong(); - gtl::InlinedVector serialized(size); - SerializeToBufferDeterministic(proto, serialized.data(), size); - return Hash64(serialized.data(), size); + DeterministicSerializer serialized(proto); + return Hash64(serialized.data(), serialized.size()); } } // namespace tensorflow diff --git a/tensorflow/core/lib/strings/proto_text_util.h b/tensorflow/core/lib/strings/proto_text_util.h index 05dbda6e152..7fb99c400b3 100644 --- a/tensorflow/core/lib/strings/proto_text_util.h +++ b/tensorflow/core/lib/strings/proto_text_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ #define TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ +#include "absl/strings/str_cat.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -101,8 +102,8 @@ class ProtoTextOutput { private: void AppendFieldAndValue(const char field_name[], StringPiece value_text) { - StrAppend(output_, level_empty_ ? "" : field_separator_, indent_, - field_name, kColonSeparator, value_text); + absl::StrAppend(output_, level_empty_ ? "" : field_separator_, indent_, + field_name, kColonSeparator, value_text); level_empty_ = false; } diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc index 7584f6a2391..b2feadea9bc 100644 --- a/tensorflow/core/lib/strings/str_util.cc +++ b/tensorflow/core/lib/strings/str_util.cc @@ -19,6 +19,10 @@ limitations under the License. #include #include #include +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/strip.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -26,196 +30,10 @@ limitations under the License. 
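The DeterministicSerializer keeps the small-buffer fast path while avoiding the InlinedVector cost noted above; the public helpers keep their existing signatures. A small sketch of their intended use, with NodeDef as an arbitrary example message:

```cpp
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/logging.h"

void CompareAndHashProtos() {
  tensorflow::NodeDef a;
  a.set_name("example");
  tensorflow::NodeDef b = a;

  // Equal protos serialize to identical bytes under deterministic mode...
  CHECK(tensorflow::AreSerializedProtosEqual(a, b));
  // ...and therefore hash to the same value.
  CHECK_EQ(tensorflow::DeterministicProtoHash64(a),
           tensorflow::DeterministicProtoHash64(b));
}
```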
namespace tensorflow { namespace str_util { -static char hex_char[] = "0123456789abcdef"; - -string CEscape(StringPiece src) { - string dest; - - for (unsigned char c : src) { - switch (c) { - case '\n': - dest.append("\\n"); - break; - case '\r': - dest.append("\\r"); - break; - case '\t': - dest.append("\\t"); - break; - case '\"': - dest.append("\\\""); - break; - case '\'': - dest.append("\\'"); - break; - case '\\': - dest.append("\\\\"); - break; - default: - // Note that if we emit \xNN and the src character after that is a hex - // digit then that digit must be escaped too to prevent it being - // interpreted as part of the character code by C. - if ((c >= 0x80) || !isprint(c)) { - dest.append("\\"); - dest.push_back(hex_char[c / 64]); - dest.push_back(hex_char[(c % 64) / 8]); - dest.push_back(hex_char[c % 8]); - } else { - dest.push_back(c); - break; - } - } - } - - return dest; -} +string CEscape(StringPiece src) { return absl::CEscape(src); } namespace { // Private helpers for CUnescape(). -inline bool is_octal_digit(unsigned char c) { return c >= '0' && c <= '7'; } - -inline bool ascii_isxdigit(unsigned char c) { - return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || - (c >= 'A' && c <= 'F'); -} - -inline int hex_digit_to_int(char c) { - int x = static_cast(c); - if (x > '9') { - x += 9; - } - return x & 0xf; -} - -bool CUnescapeInternal(StringPiece source, string* dest, - string::size_type* dest_len, string* error) { - const char* p = source.data(); - const char* end = source.end(); - const char* last_byte = end - 1; - - // We are going to write the result to dest with its iterator. If our string - // implementation uses copy-on-write, this will trigger a copy-on-write of - // dest's buffer; that is, dest will be assigned a new buffer. - // - // Note that the following way is NOT a legal way to modify a string's - // content: - // - // char* d = const_cast(dest->data()); - // - // This won't trigger copy-on-write of the string, and so is dangerous when - // the buffer is shared. - auto d = dest->begin(); - - // Small optimization for case where source = dest and there's no escaping - if (source.data() == dest->data()) { - while (p < end && *p != '\\') { - p++; - d++; - } - } - - while (p < end) { - if (*p != '\\') { - *d++ = *p++; - } else { - if (++p > last_byte) { // skip past the '\\' - if (error) *error = "String cannot end with \\"; - return false; - } - switch (*p) { - case 'a': - *d++ = '\a'; - break; - case 'b': - *d++ = '\b'; - break; - case 'f': - *d++ = '\f'; - break; - case 'n': - *d++ = '\n'; - break; - case 'r': - *d++ = '\r'; - break; - case 't': - *d++ = '\t'; - break; - case 'v': - *d++ = '\v'; - break; - case '\\': - *d++ = '\\'; - break; - case '?': - *d++ = '\?'; - break; // \? Who knew? 
- case '\'': - *d++ = '\''; - break; - case '"': - *d++ = '\"'; - break; - case '0': - case '1': - case '2': - case '3': // octal digit: 1 to 3 digits - case '4': - case '5': - case '6': - case '7': { - const char* octal_start = p; - unsigned int ch = *p - '0'; - if (p < last_byte && is_octal_digit(p[1])) ch = ch * 8 + *++p - '0'; - if (p < last_byte && is_octal_digit(p[1])) - ch = ch * 8 + *++p - '0'; // now points at last digit - if (ch > 0xff) { - if (error) { - *error = "Value of \\" + - string(octal_start, p + 1 - octal_start) + - " exceeds 0xff"; - } - return false; - } - *d++ = ch; - break; - } - case 'x': - case 'X': { - if (p >= last_byte) { - if (error) *error = "String cannot end with \\x"; - return false; - } else if (!ascii_isxdigit(p[1])) { - if (error) *error = "\\x cannot be followed by a non-hex digit"; - return false; - } - unsigned int ch = 0; - const char* hex_start = p; - while (p < last_byte && ascii_isxdigit(p[1])) - // Arbitrarily many hex digits - ch = (ch << 4) + hex_digit_to_int(*++p); - if (ch > 0xFF) { - if (error) { - *error = "Value of \\" + string(hex_start, p + 1 - hex_start) + - " exceeds 0xff"; - } - return false; - } - *d++ = ch; - break; - } - default: { - if (error) *error = string("Unknown escape sequence: \\") + *p; - return false; - } - } - p++; // read past letter we escaped - } - } - *dest_len = d - dest->begin(); - return true; -} - template bool SplitAndParseAsInts(StringPiece text, char delim, std::function converter, @@ -233,39 +51,18 @@ bool SplitAndParseAsInts(StringPiece text, char delim, } // namespace bool CUnescape(StringPiece source, string* dest, string* error) { - dest->resize(source.size()); - string::size_type dest_size; - if (!CUnescapeInternal(source, dest, &dest_size, error)) { - return false; - } - dest->erase(dest_size); - return true; + return absl::CUnescape(source, dest, error); } void StripTrailingWhitespace(string* s) { - string::size_type i; - for (i = s->size(); i > 0 && isspace((*s)[i - 1]); --i) { - } - s->resize(i); + absl::StripTrailingAsciiWhitespace(s); } // Return lower-cased version of s. -string Lowercase(StringPiece s) { - string result(s.data(), s.size()); - for (char& c : result) { - c = tolower(c); - } - return result; -} +string Lowercase(StringPiece s) { return absl::AsciiStrToLower(s); } // Return upper-cased version of s. 
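Since these wrappers now simply forward to Abseil (and the header marks them deprecated later in this change), new call sites can use the absl routines directly. A minimal sketch of the equivalents:

```cpp
#include <string>

#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"

std::string AbslStringHelpers() {
  const std::string escaped = absl::CEscape("tab\tand\xff");  // C-style escapes
  std::string unescaped, error;
  if (!absl::CUnescape(escaped, &unescaped, &error)) {
    return error;  // e.g. a malformed \x escape
  }

  std::string padded = "  mixed Case  ";
  absl::StripAsciiWhitespace(&padded);  // in-place trim, both ends
  return absl::AsciiStrToLower(padded) + "/" +
         absl::AsciiStrToUpper(padded);  // "mixed case/MIXED CASE"
}
```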
-string Uppercase(StringPiece s) { - string result(s.data(), s.size()); - for (char& c : result) { - c = toupper(c); - } - return result; -} +string Uppercase(StringPiece s) { return absl::AsciiStrToUpper(s); } string ArgDefCase(StringPiece s) { const size_t n = s.size(); @@ -349,46 +146,32 @@ string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub, } size_t RemoveLeadingWhitespace(StringPiece* text) { - size_t count = 0; - const char* ptr = text->data(); - while (count < text->size() && isspace(*ptr)) { - count++; - ptr++; - } - text->remove_prefix(count); + absl::string_view new_text = absl::StripLeadingAsciiWhitespace(*text); + size_t count = text->size() - new_text.size(); + *text = new_text; return count; } size_t RemoveTrailingWhitespace(StringPiece* text) { - size_t count = 0; - const char* ptr = text->data() + text->size() - 1; - while (count < text->size() && isspace(*ptr)) { - ++count; - --ptr; - } - text->remove_suffix(count); + absl::string_view new_text = absl::StripTrailingAsciiWhitespace(*text); + size_t count = text->size() - new_text.size(); + *text = new_text; return count; } size_t RemoveWhitespaceContext(StringPiece* text) { - // use RemoveLeadingWhitespace() and RemoveTrailingWhitespace() to do the job - return (RemoveLeadingWhitespace(text) + RemoveTrailingWhitespace(text)); + absl::string_view new_text = absl::StripAsciiWhitespace(*text); + size_t count = text->size() - new_text.size(); + *text = new_text; + return count; } bool ConsumePrefix(StringPiece* s, StringPiece expected) { - if (StartsWith(*s, expected)) { - s->remove_prefix(expected.size()); - return true; - } - return false; + return absl::ConsumePrefix(s, expected); } bool ConsumeSuffix(StringPiece* s, StringPiece expected) { - if (EndsWith(*s, expected)) { - s->remove_suffix(expected.size()); - return true; - } - return false; + return absl::ConsumeSuffix(s, expected); } bool ConsumeLeadingDigits(StringPiece* s, uint64* val) { @@ -447,11 +230,12 @@ bool SplitAndParseAsInts(StringPiece text, char delim, bool SplitAndParseAsFloats(StringPiece text, char delim, std::vector* result) { - return SplitAndParseAsInts(text, delim, - [](StringPiece str, float* value) { - return strings::safe_strtof(str, value); - }, - result); + return SplitAndParseAsInts( + text, delim, + [](StringPiece str, float* value) { + return strings::safe_strtof(str, value); + }, + result); } size_t Strnlen(const char* str, const size_t string_max_len) { @@ -463,20 +247,15 @@ size_t Strnlen(const char* str, const size_t string_max_len) { } bool StrContains(StringPiece haystack, StringPiece needle) { - return std::search(haystack.begin(), haystack.end(), needle.begin(), - needle.end()) != haystack.end(); + return absl::StrContains(haystack, needle); } bool StartsWith(StringPiece text, StringPiece prefix) { - return prefix.empty() || - (text.size() >= prefix.size() && - memcmp(text.data(), prefix.data(), prefix.size()) == 0); + return absl::StartsWith(text, prefix); } bool EndsWith(StringPiece text, StringPiece suffix) { - return suffix.empty() || (text.size() >= suffix.size() && - memcmp(text.data() + (text.size() - suffix.size()), - suffix.data(), suffix.size()) == 0); + return absl::EndsWith(text, suffix); } } // namespace str_util diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h index 9f52cf29fc3..5441dfc25a5 100644 --- a/tensorflow/core/lib/strings/str_util.h +++ b/tensorflow/core/lib/strings/str_util.h @@ -19,6 +19,10 @@ limitations under the License. 
#include #include #include +#include "absl/base/macros.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" @@ -29,6 +33,7 @@ namespace str_util { // Returns a version of 'src' where unprintable characters have been // escaped using C-style escape sequences. +ABSL_DEPRECATED("Use absl::CEscape instead.") string CEscape(StringPiece src); // Copies "source" to "dest", rewriting C-style escape sequences -- @@ -38,21 +43,26 @@ string CEscape(StringPiece src); // 'error'. To disable error reporting, set 'error' to NULL. // // NOTE: Does not support \u or \U! +ABSL_DEPRECATED("Use absl::CUnescape instead.") bool CUnescape(StringPiece source, string* dest, string* error); // Removes any trailing whitespace from "*s". +ABSL_DEPRECATED("Use absl::StripTrailingAsciiWhitespace instead.") void StripTrailingWhitespace(string* s); // Removes leading ascii_isspace() characters. // Returns number of characters removed. +ABSL_DEPRECATED("Use absl::StripLeadingAsciiWhitespace instead.") size_t RemoveLeadingWhitespace(StringPiece* text); // Removes trailing ascii_isspace() characters. // Returns number of characters removed. +ABSL_DEPRECATED("Use absl::StripTrailingAsciiWhitespace instead.") size_t RemoveTrailingWhitespace(StringPiece* text); // Removes leading and trailing ascii_isspace() chars. // Returns number of chars removed. +ABSL_DEPRECATED("Use absl::StripAsciiWhitespace instead.") size_t RemoveWhitespaceContext(StringPiece* text); // Consume a leading positive integer value. If any digits were @@ -68,16 +78,20 @@ bool ConsumeNonWhitespace(StringPiece* s, StringPiece* val); // If "*s" starts with "expected", consume it and return true. // Otherwise, return false. +ABSL_DEPRECATED("Use absl::ConsumeSuffix instead.") bool ConsumePrefix(StringPiece* s, StringPiece expected); // If "*s" ends with "expected", remove it and return true. // Otherwise, return false. +ABSL_DEPRECATED("Use absl::ConsumePrefix instead.") bool ConsumeSuffix(StringPiece* s, StringPiece expected); // Return lower-cased version of s. +ABSL_DEPRECATED("Use absl::AsciiStrToLower instead.") string Lowercase(StringPiece s); // Return upper-cased version of s. +ABSL_DEPRECATED("Use absl::AsciiStrToUpper instead.") string Uppercase(StringPiece s); // Converts "^2ILoveYou!" to "i_love_you_". More specifically: @@ -102,12 +116,14 @@ string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub, // Join functionality template +ABSL_DEPRECATED("Use absl::StrJoin instead.") string Join(const T& s, const char* sep); // A variant of Join where for each element of "s", f(&dest_string, elem) // is invoked (f is often constructed with a lambda of the form: // [](string* result, ElemType elem) template +ABSL_DEPRECATED("Use absl::StrJoin instead.") string Join(const T& s, const char* sep, Formatter f); struct AllowEmpty { @@ -118,16 +134,17 @@ struct SkipEmpty { }; struct SkipWhitespace { bool operator()(StringPiece sp) const { - RemoveTrailingWhitespace(&sp); - return !sp.empty(); + return !absl::StripTrailingAsciiWhitespace(sp).empty(); } }; // Split strings using any of the supplied delimiters. For example: // Split("a,b.c,d", ".,") would return {"a", "b", "c", "d"}. 
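Split and Join are now thin shims over absl::StrSplit and absl::StrJoin. A sketch of the direct Abseil calls, including the empty-input behavior that the shims guard against: absl::StrSplit of an empty string yields one empty piece, whereas str_util::Split has always returned an empty vector.

```cpp
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"

void SplitAndJoin() {
  std::vector<std::string> parts =
      absl::StrSplit("a,b.c,d", absl::ByAnyChar(".,"));  // {"a","b","c","d"}
  const std::string joined = absl::StrJoin(parts, "-");  // "a-b-c-d"

  std::vector<std::string> empty_split = absl::StrSplit("", ',');
  // empty_split.size() == 1 (a single empty string), hence the text.empty()
  // checks in the wrappers.
  (void)joined;
  (void)empty_split;
}
```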
+ABSL_DEPRECATED("Use absl::StrSplit instead.") std::vector Split(StringPiece text, StringPiece delims); template +ABSL_DEPRECATED("Use absl::StrSplit instead.") std::vector Split(StringPiece text, StringPiece delims, Predicate p); // Split "text" at "delim" characters, and parse each component as @@ -143,29 +160,26 @@ bool SplitAndParseAsFloats(StringPiece text, char delim, // StartsWith() // // Returns whether a given string `text` begins with `prefix`. +ABSL_DEPRECATED("Use absl::StartsWith instead.") bool StartsWith(StringPiece text, StringPiece prefix); // EndsWith() // // Returns whether a given string `text` ends with `suffix`. +ABSL_DEPRECATED("Use absl::EndsWith instead.") bool EndsWith(StringPiece text, StringPiece suffix); // StrContains() // // Returns whether a given string `haystack` contains the substring `needle`. +ABSL_DEPRECATED("Use absl::StrContains instead.") bool StrContains(StringPiece haystack, StringPiece needle); // ------------------------------------------------------------------ // Implementation details below template string Join(const T& s, const char* sep) { - string result; - bool first = true; - for (const auto& x : s) { - tensorflow::strings::StrAppend(&result, (first ? "" : sep), x); - first = false; - } - return result; + return absl::StrJoin(s, sep); } template @@ -180,47 +194,29 @@ class Formatter { template string Join(const T& s, const char* sep, Formatter f) { - string result; - bool first = true; - for (const auto& x : s) { - if (!first) { - result.append(sep); - } - f(&result, x); - first = false; - } - return result; + return absl::StrJoin(s, sep, f); } inline std::vector Split(StringPiece text, StringPiece delims) { - return Split(text, delims, AllowEmpty()); + return text.empty() ? std::vector() + : absl::StrSplit(text, absl::ByAnyChar(delims)); } template std::vector Split(StringPiece text, StringPiece delims, Predicate p) { - std::vector result; - size_t token_start = 0; - if (!text.empty()) { - for (size_t i = 0; i < text.size() + 1; i++) { - if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) { - StringPiece token(text.data() + token_start, i - token_start); - if (p(token)) { - result.emplace_back(token); - } - token_start = i + 1; - } - } - } - return result; + return text.empty() ? std::vector() + : absl::StrSplit(text, absl::ByAnyChar(delims), p); } +ABSL_DEPRECATED("Use absl::StrSplit instead.") inline std::vector Split(StringPiece text, char delim) { - return Split(text, StringPiece(&delim, 1)); + return text.empty() ? std::vector() : absl::StrSplit(text, delim); } template -std::vector Split(StringPiece text, char delims, Predicate p) { - return Split(text, StringPiece(&delims, 1), p); +ABSL_DEPRECATED("Use absl::StrSplit instead.") +std::vector Split(StringPiece text, char delim, Predicate p) { + return text.empty() ? std::vector() : absl::StrSplit(text, delim, p); } // Returns the length of the given null-terminated byte string 'str'. diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h index a620f594476..ef308052767 100644 --- a/tensorflow/core/lib/strings/strcat.h +++ b/tensorflow/core/lib/strings/strcat.h @@ -52,7 +52,7 @@ limitations under the License. // You can convert to Hexadecimal output rather than Decimal output using Hex. // To do this, pass strings::Hex(my_int) as a parameter to StrCat. 
You may // specify a minimum field width using a separate parameter, so the equivalent -// of Printf("%04x", my_int) is StrCat(Hex(my_int, strings::ZERO_PAD_4)) +// of Printf("%04x", my_int) is StrCat(Hex(my_int, strings::kZeroPad4)) // // This class has implicit constructors. namespace tensorflow { diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc index a0b602f301c..ca6e64c34d1 100644 --- a/tensorflow/core/nccl/nccl_manager.cc +++ b/tensorflow/core/nccl/nccl_manager.cc @@ -18,6 +18,7 @@ limitations under the License. #ifdef GOOGLE_CUDA +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/cuda.h" #include "tensorflow/core/platform/env.h" @@ -44,14 +45,10 @@ using se::cuda::ScopedActivateExecutorContext; // Contains data for a single stream used for nccl communication; this includes // a background thread that calls NcclManager::LoopKernelLaunches. -struct NcclManager::NcclStream { +struct NcclManager::NcclStream : public core::RefCounted { public: - NcclStream() {} - ~NcclStream() { - mutex_lock l(mu); - shutdown_requested = true; - cv.notify_all(); - } + NcclStream() = default; + ~NcclStream() = default; se::StreamExecutor* executor = nullptr; @@ -59,11 +56,13 @@ struct NcclManager::NcclStream { // This is a different stream than the tensorflow compute stream. std::unique_ptr stream; - // See NcclManager::LoopKernelLaunches for information on these. - std::unique_ptr thread; + // `mu` protects access to `pending_launches_`, which is the list of + // collectives ready but whose kernels are yet to be launched. When the + // NcclManager object that owns this NcclStream object is destroyed, it + // signals `cv` to unblock the thread waiting on more collectives. mutex mu; condition_variable cv; - // Has collective,participant_idx pairs. + // Has (collective, participant_idx) pairs. std::deque> pending_launches_ GUARDED_BY(mu); bool shutdown_requested GUARDED_BY(mu) = false; }; @@ -74,9 +73,9 @@ struct NcclManager::CommunicatorMember { ~CommunicatorMember() { if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm); } - ncclComm_t nccl_comm; - // Owned by NcclManager::device_to_comm_streams_. + ncclComm_t nccl_comm = nullptr; + // Owned by NcclManager::device_to_comm_streams_ and LoopKernelLaunches. NcclStream* nccl_stream = nullptr; }; @@ -127,7 +126,7 @@ void StringToNcclUniqueId(const string& str_id, ncclUniqueId* nccl_id) { // have a single `Collective` per step. However, a collective that executes on // 3 nodes with 4 GPUs each would have a `Collective` per node, each of which is // tracking the 4 GPUs local to that node. 
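A tiny usage sketch matching the corrected enumerator name in the StrCat comment above:

```cpp
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"

// Equivalent of printf("%04x", my_int): zero-padded to at least four hex digits.
tensorflow::string ToHex4(tensorflow::uint32 my_int) {
  return tensorflow::strings::StrCat(
      tensorflow::strings::Hex(my_int, tensorflow::strings::kZeroPad4));
}
```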
-struct NcclManager::Collective { +struct NcclManager::Collective : public core::RefCounted { Collective(DataType data_type_in, CollectiveType type_in, ncclRedOp_t reduction_op_in, int num_local_devices_in, int num_global_devices_in, const string& communicator_key_in) @@ -137,8 +136,7 @@ struct NcclManager::Collective { num_local_devices(num_local_devices_in), num_global_devices(num_global_devices_in), single_node(num_local_devices_in == num_global_devices_in), - communicator_key(communicator_key_in), - remaining_participants(num_local_devices_in) { + communicator_key(communicator_key_in) { participants.reserve(num_local_devices_in); } @@ -174,13 +172,23 @@ struct NcclManager::Collective { int available_participants = 0; bool multi_node_ready = false; - mutable std::atomic_int_fast32_t remaining_participants; - Status status; }; -NcclManager::NcclManager() {} -NcclManager::~NcclManager() {} +NcclManager::NcclManager() { VLOG(2) << "New NcclManager " << this; } +NcclManager::~NcclManager() { + VLOG(2) << "~NcclManager " << this; + for (auto& it : device_to_comm_streams_) { + for (NcclStream* nccl_stream : it.second) { + { + mutex_lock l(nccl_stream->mu); + nccl_stream->shutdown_requested = true; + nccl_stream->cv.notify_all(); + } + nccl_stream->Unref(); + } + } +} NcclManager* NcclManager::instance() { static NcclManager* instance = new NcclManager(); return instance; @@ -203,11 +211,12 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective, mutex_lock l(mu_); - if (collective->single_node) { - // For single-node collectives, we identify a communicator uniquely by the - // set of devices participating in the collective. For example, if a - // collective is for GPUs 0, 1, and 2 then this will scan to find the - // communicator for GPUs 0, 1, and 2. + if (collective->communicator_key.empty()) { + // For single-node collectives, when the caller does not specify a + // `communicator_key`, we identify a communicator uniquely by the set of + // devices participating in the collective. For example, if a collective is + // for GPUs 0, 1, and 2 then this will scan to find the communicator for + // GPUs 0, 1, and 2. 
// // Note that each executor identifies a context on one device, so this is // the same as getting the communicator connecting the devices in the @@ -275,8 +284,8 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective, auto& streams = device_to_comm_streams_[executor]; NcclStream* nccl_stream = nullptr; for (const auto& s : streams) { - if (used_streams.insert(s.get()).second) { - nccl_stream = s.get(); + if (used_streams.insert(s).second) { + nccl_stream = s; break; } } @@ -289,9 +298,11 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective, streams.emplace_back(nccl_stream); used_streams.insert(nccl_stream); - nccl_stream->thread.reset(env->StartThread( - ThreadOptions(), "nccl_kernel_launch", - [this, nccl_stream] { LoopKernelLaunches(nccl_stream); })); + nccl_stream->Ref(); + env->SchedClosure([this, nccl_stream]() { + LoopKernelLaunches(nccl_stream); + nccl_stream->Unref(); + }); } members[i].nccl_stream = nccl_stream; @@ -383,9 +394,11 @@ void NcclManager::SignalMultiNodeReady(const string& collective_key) { mutex_lock l(mu_); auto collective_it = collectives_.find(collective_key); if (collective_it != collectives_.end()) { - Collective* collective = collective_it->second.get(); + Collective* collective = collective_it->second; collective->multi_node_ready = true; - to_run = CheckReady(collective_key, collective); + if (CheckReady(collective_key, collective)) { + to_run = collective; + } } } @@ -403,23 +416,22 @@ void NcclManager::AddParticipant(std::unique_ptr participant, auto collective_it = collectives_.find(context.collective_key); Collective* collective = nullptr; if (collective_it == collectives_.end()) { - auto collective_unique_ptr = absl::make_unique( + collective = new Collective( data_type, collective_type, reduction_op, context.num_local_devices, context.num_global_devices, context.communicator_key); - collective = collective_unique_ptr.get(); - collectives_.emplace(context.collective_key, - std::move(collective_unique_ptr)); + collectives_.emplace(context.collective_key, collective); } else { - collective = collective_it->second.get(); + collective = collective_it->second; } // Check `collective` is correct and consistent. 
- if (collective->status.ok() && collective->single_node && - !collective->communicator_key.empty()) { - collective->status = - errors::Internal("Collective ", reduction_op, - " is single node but has communicator_key of size ", - collective->communicator_key.size()); + if (collective->status.ok() && !collective->single_node && + collective->communicator_key.empty()) { + collective->status = errors::Internal( + "Collective ", reduction_op, " is multi node with num_local_devices=", + collective->num_local_devices, + " and num_global_devices=", collective->num_global_devices, + " but has an empty communicator_key"); } if (collective->status.ok() && collective->communicator_key.size() != context.communicator_key.size()) { @@ -463,26 +475,25 @@ void NcclManager::AddParticipant(std::unique_ptr participant, collective->participants.emplace_back(std::move(participant)); ++collective->available_participants; - to_run = CheckReady(context.collective_key, collective); + if (CheckReady(context.collective_key, collective)) { + to_run = collective; + } } if (to_run != nullptr) RunCollective(to_run); } -NcclManager::Collective* NcclManager::CheckReady(const string& collective_key, - Collective* collective) { - Collective* to_run = nullptr; +bool NcclManager::CheckReady(const string& collective_key, + Collective* collective) { if (collective->available_participants == collective->num_local_devices) { if (collective->num_global_devices == collective->num_local_devices || collective->multi_node_ready) { // Ownership transferred to callee. - to_run = collective; - auto collectives_it = collectives_.find(collective_key); - collectives_it->second.release(); - collectives_.erase(collectives_it); + collectives_.erase(collective_key); + return true; } } - return to_run; + return false; } void NcclManager::RunCollective(Collective* collective) { @@ -496,7 +507,7 @@ void NcclManager::RunCollective(Collective* collective) { for (int i = 0; i < collective->num_local_devices; ++i) { collective->participants[i]->done_callback(s); } - delete collective; + collective->Unref(); return; } @@ -533,9 +544,13 @@ void NcclManager::RunCollective(Collective* collective) { collective->communicator->members[i].nccl_stream; mutex_lock l(nccl_stream->mu); nccl_stream->pending_launches_.push_front(std::make_pair(collective, i)); + // Ownership is shared between LoopKernelLaunches for each stream in this + // collective. + collective->Ref(); nccl_stream->cv.notify_all(); } } + collective->Unref(); } void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { @@ -548,6 +563,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { // Find collective to run. std::pair next_launch; { + VLOG(2) << "Locking mutex nccl_stream " << nccl_stream; mutex_lock l(nccl_stream->mu); while (nccl_stream->pending_launches_.empty()) { if (nccl_stream->shutdown_requested) { @@ -624,15 +640,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { collective->participants[p_idx]->done_callback(errors::Unknown( "Error invoking NCCL: ", ncclGetErrorString(nccl_result))); } - - // TODO(cwhipkey): use RefCounted after figuring out how to use in a - // custom op library. - // See tensorflow/core/lib/core/refcount.h for details on this locking. 
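The deleted TODO is resolved by deriving Collective (and NcclStream) from core::RefCounted, so lifetime is managed with Ref/Unref instead of a hand-rolled atomic counter. A minimal sketch of that ownership pattern with a toy class:

```cpp
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"

class Work : public tensorflow::core::RefCounted {
 public:
  void Run() { LOG(INFO) << "running"; }
};

void ShareAcrossTwoOwners() {
  Work* work = new Work;  // constructed with a reference count of 1
  work->Ref();            // a second owner (e.g. a scheduled closure) takes a ref
  work->Run();
  work->Unref();          // first owner is done
  work->Unref();          // the last Unref() deletes the object
}
```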
- if (collective->remaining_participants.load(std::memory_order_acquire) == - 1 || - collective->remaining_participants.fetch_sub(1) == 1) { - delete collective; - } + collective->Unref(); }; p->event_mgr->ThenExecute(comm_stream, done_callback); } diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h index 7cf2c85f3e8..d968fac833b 100644 --- a/tensorflow/core/nccl/nccl_manager.h +++ b/tensorflow/core/nccl/nccl_manager.h @@ -189,7 +189,7 @@ class NcclManager { // the corresponding NCCL/CUDA error string. Status GetCommunicator(Collective* collective, Communicator** communicator); - // Adds a participant device to the local `Collective` instance correponding + // Adds a participant device to the local `Collective` instance corresponding // to `collective_key`. Launches the `Collective` if it is ready, which it // checks by calling `CheckReady()`. Also performs consistency and sanity // checks before launching. @@ -198,13 +198,13 @@ class NcclManager { ncclRedOp_t reduction_op); // If `collective` is ready to run, removes it from the `collectives_` map and - // returns the pointer. Otherwise returns `nullptr`. + // returns true. Otherwise returns false. // Assumes `collective_key` corresponds to `collective`. // // A collective is ready to run when all local participants have called Add* // function, and the collective is signalled globally ready via // `SetMultiNodeReady`. - Collective* CheckReady(const string& collective_key, Collective* collective) + bool CheckReady(const string& collective_key, Collective* collective) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Run . This calls takes ownership of . @@ -214,13 +214,12 @@ class NcclManager { mutex mu_; // Maps key to collectives currently being assembled or run. - std::unordered_map> collectives_ - GUARDED_BY(mu_); + std::unordered_map collectives_ GUARDED_BY(mu_); // Maps a device to the communication streams that make up its collective. // This is used to share the stream across different communicators that // include the same device. - std::map>> + std::map> device_to_comm_streams_ GUARDED_BY(mu_); std::vector> communicators_; diff --git a/tensorflow/core/nccl/nccl_manager_test.cc b/tensorflow/core/nccl/nccl_manager_test.cc index 420e143c837..06564ee8020 100644 --- a/tensorflow/core/nccl/nccl_manager_test.cc +++ b/tensorflow/core/nccl/nccl_manager_test.cc @@ -15,6 +15,8 @@ limitations under the License. #ifdef GOOGLE_CUDA +#include "tensorflow/core/nccl/nccl_manager.h" + #include #include #include @@ -23,7 +25,6 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/gpu/gpu_device.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/nccl/nccl_manager.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -69,6 +70,8 @@ class NcclManagerTest : public ::testing::Test { LOG(INFO) << "Running test with " << devices_->size() << " gpus"; } + void SetUp() override { ASSERT_GT(devices_->size(), 0) << "No GPUs found"; } + static int32 NumGPUs() { return static_cast(devices_->size()); } static void TearDownTestCase() { delete devices_; } @@ -220,6 +223,58 @@ class NcclManagerTest : public ::testing::Test { }; } + void RunMultiNodeTest(const int num_nodes, const int num_ranks_per_node) { + const int num_global_ranks = num_nodes * num_ranks_per_node; + std::vector nccl_managers(num_nodes); + const string collective_key = "allreduce"; + // The NcclManagers in this test synchronize in real-time, so we need to run + // each node's code in a separate thread. + // Specifically, the call to ncclGroupEnd() after calling ncclCommInitRank + // waits for all communicators before returning. + thread::ThreadPool pool(Env::Default(), "test_multi_node_nccl", num_nodes); + + // First, initialize the communicator_key used for this collective. + const string communicator_key = nccl_managers[0].GenerateCommunicatorKey(); + + for (int op = 0; op < 4; ++op) { + ncclRedOp_t reduction_op = static_cast(op); + std::unique_ptr test_case( + this->MakeReductionTestCase(num_nodes, num_ranks_per_node, + reduction_op, TensorShape({2, 3}), 0.0f)); + for (int node = 0; node < num_nodes; ++node) { + auto node_fn = [this, node, num_ranks_per_node, num_global_ranks, + &nccl_managers, &communicator_key, &collective_key, + reduction_op, &test_case] { + for (int local_rank = 0; local_rank < num_ranks_per_node; + ++local_rank) { + auto* device = this->GetDevice(local_rank); + auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr; + auto* stream = device->tensorflow_gpu_device_info()->stream; + const int global_rank = node * num_ranks_per_node + local_rank; + auto participant = absl::make_unique( + device->executor(), stream, event_mgr, device->gpu_id(), + &test_case->ins[global_rank], &test_case->outs[global_rank], + global_rank, this->CreateDoneCallback(test_case.get())); + nccl_managers[node].AddToAllReduce( + std::move(participant), + {collective_key, num_ranks_per_node, num_global_ranks, + communicator_key}, + reduction_op); + VLOG(1) << "AddToAllReduce node " << node << " global_rank " + << global_rank; + } + + // Signal collective ready to launch at this node. + nccl_managers[node].SignalMultiNodeReady(collective_key); + }; + pool.Schedule(node_fn); + } + + VLOG(2) << "Verifying results"; + this->VerifyResults(test_case.get()); + } + } + static BaseGPUDevice* GetDevice(size_t rank) { return devices_->at(rank % devices_->size()).get(); } @@ -405,59 +460,16 @@ TEST(NcclManagerTest, CommunicatorKey) { } // This test creates `num_nodes` NcclManagers to simulate a multi-node -// environment. It works on a single node and reuse GPUs. It enqueues NCCL ops -// on separate stream per rank. +// environment. It works on a single node and reuses GPUs. It enqueues NCCL +// kernels on separate stream per rank. 
TYPED_TEST(NcclManagerTest, MultiNode) { - const int num_nodes = 2; - const int num_ranks_per_node = 4; - const int num_global_ranks = num_nodes * num_ranks_per_node; - std::vector nccl_managers(num_nodes); - const string collective_key = "allreduce"; - // The NcclManagers in this test synchronize in real-time, so we need to run - // each node's code in a separate thread. - // Specifically, the call to ncclGroupEnd() after calling ncclCommInitRank - // waits for all communicators before returning. - thread::ThreadPool pool(Env::Default(), "test_multi_node_nccl", num_nodes); + this->RunMultiNodeTest(/*num_nodes=*/2, /*num_ranks_per_node=*/4); +} - // First, initialize the communicator_key used for this collective. - const string communicator_key = nccl_managers[0].GenerateCommunicatorKey(); - - for (int op = 0; op < 4; ++op) { - ncclRedOp_t reduction_op = static_cast(op); - std::unique_ptr test_case( - this->MakeReductionTestCase(num_nodes, num_ranks_per_node, reduction_op, - TensorShape({2, 3}), 0.0f)); - for (int node = 0; node < num_nodes; ++node) { - auto node_fn = [this, node, &nccl_managers, &communicator_key, - &collective_key, reduction_op, &test_case] { - for (int local_rank = 0; local_rank < num_ranks_per_node; - ++local_rank) { - auto* device = this->GetDevice(local_rank); - auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr; - auto* stream = device->tensorflow_gpu_device_info()->stream; - const int global_rank = node * num_ranks_per_node + local_rank; - auto participant = absl::make_unique( - device->executor(), stream, event_mgr, device->gpu_id(), - &test_case->ins[global_rank], &test_case->outs[global_rank], - global_rank, this->CreateDoneCallback(test_case.get())); - nccl_managers[node].AddToAllReduce( - std::move(participant), - {collective_key, num_ranks_per_node, num_global_ranks, - communicator_key}, - reduction_op); - VLOG(1) << "AddToAllReduce node " << node << " global_rank " - << global_rank; - } - - // Signal collective ready to launch at this node. - nccl_managers[node].SignalMultiNodeReady(collective_key); - }; - pool.Schedule(node_fn); - } - - VLOG(2) << "Verifying results"; - this->VerifyResults(test_case.get()); - } +// Tests that specifying `communicator_key` with a single node NCCL collective +// works well. 
+TYPED_TEST(NcclManagerTest, MultiNodeSingle) { + this->RunMultiNodeTest(/*num_nodes=*/1, /*num_ranks_per_node=*/4); } // Checks that we return error status if a collective_key is used for different diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index 3d03bc1d5fd..f64cf801f22 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -550,4 +550,30 @@ Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad); +Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) { + DataType itype; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype)); + if (itype != DT_INT32) { + return errors::Unimplemented( + "BroadcastToGrad for int64 index are not supported."); + } + std::vector nodes = { + {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}}, + {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}}, + {{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}}, + {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}}, + {{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}}; + *g = FDH::Define( + // Arg defs + {"x: T", "shape: int32", "dy: T"}, + // Ret val defs + {"dx: T", "dshape: Tidx"}, + // Attr defs + {{"T: type"}, {"Tidx: {int32, int64}"}}, + // Nodes + nodes); + return Status::OK(); +} +REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad); + } // end namespace tensorflow diff --git a/tensorflow/core/ops/array_grad_test.cc b/tensorflow/core/ops/array_grad_test.cc index 79d28a83cc4..bcef90c15e3 100644 --- a/tensorflow/core/ops/array_grad_test.cc +++ b/tensorflow/core/ops/array_grad_test.cc @@ -765,5 +765,40 @@ TEST(ArrayGradTest, StridedSliceGrad) { } } +std::vector BroadcastToGrad(const Tensor& x, const Tensor& shape, + const Tensor& dy) { + auto T = DT_FLOAT; + auto Tidx = DT_INT32; + auto gdef = test::function::GDef( + {f::NDef("x", "Placeholder", {}, {{"dtype", T}}), + f::NDef("shape", "Placeholder", {}, {{"dtype", Tidx}}), + f::NDef("dy", "Placeholder", {}, {{"dtype", T}}), + f::NDef( + "dx", "SymbolicGradient", {"x", "shape", "dy"}, + {{"f", FDH::FunctionRef("BroadcastTo", {{"T", T}, {"Tidx", Tidx}})}, + {"Tin", DataTypeSlice{T, Tidx, T}}, + {"Tout", DataTypeSlice{T, Tidx}}})}); + VLOG(1) << DebugStringWhole(gdef); + auto sess = NewSession(); + TF_CHECK_OK(sess->Create(gdef)); + std::vector out; + TF_CHECK_OK(sess->Run({{"x:0", x}, {"shape:0", shape}, {"dy:0", dy}}, + {"dx:0", "dx:1"}, {}, &out)); + CHECK_EQ(out.size(), 2); + TF_CHECK_OK(sess->Close()); + return out; +} + +TEST(ArrayGradTest, BroadcastToGrad) { + Tensor x(DT_FLOAT, {2, 2}); + x.flat().setZero(); + Tensor shape(DT_INT32, {3}); + test::FillValues(&shape, {2, 2, 2}); + Tensor dy(DT_FLOAT, {2, 2, 2}); + test::FillIota(&dy, 0); + auto dx = BroadcastToGrad(x, shape, dy); + test::ExpectClose(dx[0], test::AsTensor({4., 6., 8., 10.}, {2, 2})); + test::ExpectTensorEqual(dx[1], test::AsTensor({0, 0, 0}, {3})); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 8b6ee870799..ccbf4177b98 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -17,6 +17,9 @@ limitations under the License. 
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mirror_pad_mode.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/strided_slice_op.h" @@ -1178,34 +1181,7 @@ REGISTER_OP("GatherNd") .Output("output: Tparams") .Attr("Tparams: type") .Attr("Tindices: {int32,int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle params = c->input(0); - ShapeHandle indices; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices)); - DimensionHandle r_dim = c->Dim(indices, -1); - - if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) { - c->set_output(0, c->UnknownShape()); - return Status::OK(); - } - - if (c->Value(r_dim) > c->Rank(params)) { - return errors::InvalidArgument( - "indices.shape[-1] must be <= params.rank, but saw indices shape: ", - c->DebugString(indices), - " and params shape: ", c->DebugString(params)); - } - - // Remove r_dim from indices to get output. - ShapeHandle indices_slice; - ShapeHandle params_slice; - TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice)); - TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), ¶ms_slice)); - ShapeHandle out; - TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out)); - c->set_output(0, out); - return Status::OK(); - }); + .SetShapeFn(shape_inference::GatherNdShape); // -------------------------------------------------------------------------- REGISTER_OP("Identity") @@ -1674,6 +1650,22 @@ REGISTER_OP("ResourceStridedSliceAssign") .Attr("shrink_axis_mask: int = 0") .SetShapeFn(shape_inference::NoOutputs); +REGISTER_OP("TensorStridedSliceUpdate") + .Input("input: T") + .Input("begin: Index") + .Input("end: Index") + .Input("strides: Index") + .Input("value: T") + .Output("output: T") + .Attr("T: type") + .Attr("Index: {int32, int64}") + .Attr("begin_mask: int = 0") + .Attr("end_mask: int = 0") + .Attr("ellipsis_mask: int = 0") + .Attr("new_axis_mask: int = 0") + .Attr("shrink_axis_mask: int = 0") + .SetShapeFn(shape_inference::UnchangedShape); + REGISTER_OP("Tile") .Input("input: T") .Input("multiples: Tmultiples") @@ -3134,6 +3126,37 @@ REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient") return Status::OK(); }); +REGISTER_OP("Fingerprint") + .Input("data: T") + .Input("method: string") + .Output("fingerprint: uint8") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + + DimensionHandle fingerprint_size; + const Tensor* method = c->input_tensor(1); + if (method == nullptr) { + fingerprint_size = c->UnknownDim(); + } else { + if (method->dims() != 0) { + return errors::InvalidArgument("`method` must be rank 0: ", + method->shape()); + } + const string& method_string = method->scalar()(); + if (method_string != "farmhash64") { + return errors::InvalidArgument("Unsupported method: ", method_string); + } + fingerprint_size = c->MakeDim(sizeof(uint64)); + } + + DimensionHandle batch = c->Dim(c->input(0), 0); + c->set_output(0, c->MakeShape({batch, fingerprint_size})); + return Status::OK(); + }); + #ifdef INTEL_MKL REGISTER_OP("_MklConcat") .Input("concat_dim: int32") diff --git a/tensorflow/core/ops/bitwise_ops.cc b/tensorflow/core/ops/bitwise_ops.cc index 
39acf5f358b..8d04d97fd1e 100644 --- a/tensorflow/core/ops/bitwise_ops.cc +++ b/tensorflow/core/ops/bitwise_ops.cc @@ -27,6 +27,13 @@ REGISTER_OP("Invert") .SetShapeFn(shape_inference::UnchangedShape); #define BINARY_BITWISE() \ + Input("x: T") \ + .Input("y: T") \ + .Output("z: T") \ + .Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") \ + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + +#define BINARY_BITWISE_COMMUTATIVE() \ Input("x: T") \ .Input("y: T") \ .Output("z: T") \ @@ -40,11 +47,11 @@ REGISTER_OP("PopulationCount") .Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") .SetShapeFn(shape_inference::UnchangedShape); -REGISTER_OP("BitwiseAnd").BINARY_BITWISE(); +REGISTER_OP("BitwiseAnd").BINARY_BITWISE_COMMUTATIVE(); -REGISTER_OP("BitwiseOr").BINARY_BITWISE(); +REGISTER_OP("BitwiseOr").BINARY_BITWISE_COMMUTATIVE(); -REGISTER_OP("BitwiseXor").BINARY_BITWISE(); +REGISTER_OP("BitwiseXor").BINARY_BITWISE_COMMUTATIVE(); REGISTER_OP("LeftShift").BINARY_BITWISE(); diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index 852b8d326c1..c23c9e828a0 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -97,6 +97,41 @@ REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature") return Status::OK(); }); +REGISTER_OP("BoostedTreesCalculateBestFeatureSplit") + .Input("node_id_range: int32") + .Input("stats_summary: float32") + .Input("l1: float") + .Input("l2: float") + .Input("tree_complexity: float") + .Input("min_node_weight: float") + .Attr("logits_dimension: int >= 1") + .Attr("split_type: {'inequality'} = 'inequality'") + .Output("node_ids: int32") + .Output("gains: float32") + .Output("feature_dimensions: int32") + .Output("thresholds: int32") + .Output("left_node_contribs: float32") + .Output("right_node_contribs: float32") + .Output("split_with_default_directions: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle node_id_range_shape; + shape_inference::ShapeHandle unused_shape; + // node id range is rank 1 with 2 values. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape)); + TF_RETURN_IF_ERROR( + c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); + ShapeHandle output_shape = c->MakeShape({c->UnknownDim()}); + for (int i = 0; i < 7; ++i) { + c->set_output(i, output_shape); + } + return Status::OK(); + }); + REGISTER_OP("BoostedTreesCreateEnsemble") .Input("tree_ensemble_handle: resource") .Input("stamp_token: int64") @@ -181,6 +216,53 @@ REGISTER_OP("BoostedTreesMakeStatsSummary") return Status::OK(); }); +// V2 of BoostedTreesMakeStatsSummary. Supports multi-dim dense Tensor and +// multi class. +REGISTER_OP("BoostedTreesAggregateStats") + .Input("node_ids: int32") + .Input("gradients: float") + .Input("hessians: float") + .Input("feature: int32") + .Attr("max_splits: int >= 1") + .Attr("num_buckets: int >= 1") + .Output("stats_summary: float") + .SetShapeFn([](shape_inference::InferenceContext* c) { + // Sets the shape of the output as a Rank 4 Tensor. 
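The `BINARY_BITWISE` / `BINARY_BITWISE_COMMUTATIVE` macro split in bitwise_ops.cc above exists because `LeftShift` and `RightShift` are the only binary bitwise ops whose operands cannot be swapped; `BitwiseAnd`, `BitwiseOr`, and `BitwiseXor` keep the commutative marking. A quick Python check of the distinction, using the existing `tf.bitwise` wrappers:

```python
import tensorflow as tf

a = tf.constant([1, 2, 3], dtype=tf.int32)
b = tf.constant([2, 1, 0], dtype=tf.int32)

# Commutative: same result either way.
print(tf.bitwise.bitwise_and(a, b).numpy())  # [0 0 0]
print(tf.bitwise.bitwise_and(b, a).numpy())  # [0 0 0]

# Not commutative: shifting a by b differs from shifting b by a.
print(tf.bitwise.left_shift(a, b).numpy())   # [4 4 3]
print(tf.bitwise.left_shift(b, a).numpy())   # [4 4 0]
```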
+ int max_splits; + int num_buckets; + TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits)); + TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets)); + + shape_inference::ShapeHandle node_ids_shape; + shape_inference::ShapeHandle gradients_shape; + shape_inference::ShapeHandle hessians_shape; + shape_inference::ShapeHandle feature_shape; + + shape_inference::DimensionHandle batch_size = c->Dim(c->input(0), 0); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &feature_shape)); + + // Verify all three inputs have same first dimension, i.e., batch_size. + TF_RETURN_IF_ERROR(c->Merge(c->Dim(gradients_shape, 0), + c->Dim(node_ids_shape, 0), &batch_size)); + TF_RETURN_IF_ERROR(c->Merge(c->Dim(hessians_shape, 0), + c->Dim(node_ids_shape, 0), &batch_size)); + TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), + c->Dim(node_ids_shape, 0), &batch_size)); + + DimensionHandle logits_dim = c->Dim(c->input(1), 1); + DimensionHandle hessian_dim = c->Dim(c->input(2), 1); + DimensionHandle feature_dim = c->Dim(c->input(3), 1); + DimensionHandle stats_dim; + TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim)); + c->set_output( + 0, c->MakeShape({max_splits, num_buckets, feature_dim, stats_dim})); + return Status::OK(); + }); + // TODO(nponomareva): when/if creating the new op for unbucketized data, rename // bucketized_features to features. REGISTER_OP("BoostedTreesPredict") diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc index 06e5f14de76..4b50d62ee7b 100644 --- a/tensorflow/core/ops/collective_ops.cc +++ b/tensorflow/core/ops/collective_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { @@ -41,7 +42,38 @@ REGISTER_OP("CollectiveGather") .Attr("instance_key: int") .Attr("shape: shape") .SetIsStateful() - .SetShapeFn(shape_inference::ExplicitShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + // Scalar input is not supported. + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused)); + + shape_inference::ShapeHandle in_subshape; + TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 1, &in_subshape)); + + auto input_first_dim_value = c->Value(c->Dim(c->input(0), 0)); + + // This output should have the same shape as its input except the first + // dimension should be multiplied by group size. 
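`BoostedTreesAggregateStats` above emits a rank-4 summary of shape `[max_splits, num_buckets, feature_dim, logits_dim + hessian_dim]`, checking along the way that all inputs share a batch dimension. Likewise, as the preceding comment states, `CollectiveGather`'s output keeps every dimension of its input except the first, which is multiplied by `group_size` (and stays unknown when the input's first dimension is unknown); the shape function that follows implements exactly that rule. A tiny worked sketch of the gather rule, with a hypothetical helper name standing in for the C++ shape function:

```python
def collective_gather_output_shape(input_shape, group_size):
    """Hypothetical Python mirror of the CollectiveGather shape rule."""
    first = None if input_shape[0] is None else group_size * input_shape[0]
    return [first] + list(input_shape[1:])

print(collective_gather_output_shape([8, 3], group_size=4))     # [32, 3]
print(collective_gather_output_shape([None, 3], group_size=4))  # [None, 3]
```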
+ shape_inference::ShapeHandle output_first_dim_as_shape; + if (input_first_dim_value == + shape_inference::InferenceContext::kUnknownDim) { + output_first_dim_as_shape = + c->Vector(shape_inference::InferenceContext::kUnknownDim); + } else { + int group_size; + TF_CHECK_OK(c->GetAttr("group_size", &group_size)); + std::vector output_first_dim; + output_first_dim.push_back( + c->MakeDim(group_size * input_first_dim_value)); + output_first_dim_as_shape = c->MakeShape(output_first_dim); + } + + shape_inference::ShapeHandle out; + TF_RETURN_IF_ERROR( + c->Concatenate(output_first_dim_as_shape, in_subshape, &out)); + c->set_output(0, out); + return Status::OK(); + }); REGISTER_OP("CollectiveBcastSend") .Input("input: T") diff --git a/tensorflow/core/ops/compat/BUILD b/tensorflow/core/ops/compat/BUILD index c613ab144f8..5ffb8cf9a10 100644 --- a/tensorflow/core/ops/compat/BUILD +++ b/tensorflow/core/ops/compat/BUILD @@ -37,6 +37,7 @@ tf_cc_test( data = [ ":ops_history.v0.pbtxt", ":ops_history.v1.pbtxt", + ":ops_history.v2.pbtxt", "//tensorflow/core:ops/ops.pbtxt", ], deps = [ diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 97d1520b7ad..f3a9d101016 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -1693,6 +1693,30 @@ op { } is_stateful: true } +op { + name: "AnonymousIteratorV2" + output_arg { + name: "handle" + type: DT_RESOURCE + } + output_arg { + name: "deleter" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Any" input_arg { @@ -9050,6 +9074,44 @@ op { minimum: 1 } } +op { + name: "BatchDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "drop_remainder" + type: DT_BOOL + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "parallel_copy" + type: "bool" + default_value { + b: false + } + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "BatchFFT" input_arg { @@ -9394,6 +9456,51 @@ op { } } } +op { + name: "BatchMatMulV2" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + attr { + name: "adj_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "adj_y" + type: "bool" + default_value { + b: false + } + } +} op { name: "BatchMatrixBandPart" input_arg { @@ -11578,6 +11685,41 @@ op { } is_commutative: true } +op { + name: "BoostedTreesAggregateStats" + input_arg { + name: "node_ids" + type: DT_INT32 + } + input_arg { + name: "gradients" + type: DT_FLOAT + } + input_arg { + name: "hessians" + type: DT_FLOAT + } + input_arg { + name: "feature" + type: DT_INT32 + } + output_arg { + name: "stats_summary" + type: DT_FLOAT + } + attr { + name: "max_splits" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "num_buckets" + type: "int" + has_minimum: true + minimum: 1 
+ } +} op { name: "BoostedTreesBucketize" input_arg { @@ -11601,6 +11743,79 @@ op { has_minimum: true } } +op { + name: "BoostedTreesCalculateBestFeatureSplit" + input_arg { + name: "node_id_range" + type: DT_INT32 + } + input_arg { + name: "stats_summary" + type: DT_FLOAT + } + input_arg { + name: "l1" + type: DT_FLOAT + } + input_arg { + name: "l2" + type: DT_FLOAT + } + input_arg { + name: "tree_complexity" + type: DT_FLOAT + } + input_arg { + name: "min_node_weight" + type: DT_FLOAT + } + output_arg { + name: "node_ids" + type: DT_INT32 + } + output_arg { + name: "gains" + type: DT_FLOAT + } + output_arg { + name: "feature_dimensions" + type: DT_INT32 + } + output_arg { + name: "thresholds" + type: DT_INT32 + } + output_arg { + name: "left_node_contribs" + type: DT_FLOAT + } + output_arg { + name: "right_node_contribs" + type: DT_FLOAT + } + output_arg { + name: "split_with_default_directions" + type: DT_STRING + } + attr { + name: "logits_dimension" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "split_type" + type: "string" + default_value { + s: "inequality" + } + allowed_values { + list { + s: "inequality" + } + } + } +} op { name: "BoostedTreesCalculateBestGainsPerFeature" input_arg { @@ -12829,6 +13044,64 @@ op { } } } +op { + name: "ChooseFastestBranchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "ratio_numerator" + type: DT_INT64 + } + input_arg { + name: "ratio_denominator" + type: DT_INT64 + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "num_elements_per_branch" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "branches" + type: "list(func)" + has_minimum: true + minimum: 1 + } + attr { + name: "other_arguments_lengths" + type: "list(int)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "ClipByValue" input_arg { @@ -13224,6 +13497,63 @@ op { } } } +op { + name: "CombinedNonMaxSuppression" + input_arg { + name: "boxes" + type: DT_FLOAT + } + input_arg { + name: "scores" + type: DT_FLOAT + } + input_arg { + name: "max_output_size_per_class" + type: DT_INT32 + } + input_arg { + name: "max_total_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + input_arg { + name: "score_threshold" + type: DT_FLOAT + } + output_arg { + name: "nmsed_boxes" + type: DT_FLOAT + } + output_arg { + name: "nmsed_scores" + type: DT_FLOAT + } + output_arg { + name: "nmsed_classes" + type: DT_FLOAT + } + output_arg { + name: "valid_detections" + type: DT_INT32 + } + attr { + name: "pad_per_class" + type: "bool" + default_value { + b: false + } + } + attr { + name: "clip_boxes" + type: "bool" + default_value { + b: true + } + } +} op { name: "CompareAndBitpack" input_arg { @@ -17166,6 +17496,159 @@ op { } is_stateful: true } +op { + name: "CudnnRNNBackpropV3" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "input_h" + type_attr: "T" + } + input_arg { + name: "input_c" + type_attr: "T" + } + input_arg { + name: "params" + type_attr: "T" + } + input_arg { + name: "sequence_lengths" + type: DT_INT32 + } + input_arg { + name: "output" + type_attr: "T" + } + input_arg { + name: "output_h" + type_attr: "T" + } 
+ input_arg { + name: "output_c" + type_attr: "T" + } + input_arg { + name: "output_backprop" + type_attr: "T" + } + input_arg { + name: "output_h_backprop" + type_attr: "T" + } + input_arg { + name: "output_c_backprop" + type_attr: "T" + } + input_arg { + name: "reserve_space" + type_attr: "T" + } + input_arg { + name: "host_reserved" + type: DT_INT8 + } + output_arg { + name: "input_backprop" + type_attr: "T" + } + output_arg { + name: "input_h_backprop" + type_attr: "T" + } + output_arg { + name: "input_c_backprop" + type_attr: "T" + } + output_arg { + name: "params_backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "rnn_mode" + type: "string" + default_value { + s: "lstm" + } + allowed_values { + list { + s: "rnn_relu" + s: "rnn_tanh" + s: "lstm" + s: "gru" + } + } + } + attr { + name: "input_mode" + type: "string" + default_value { + s: "linear_input" + } + allowed_values { + list { + s: "linear_input" + s: "skip_input" + s: "auto_select" + } + } + } + attr { + name: "direction" + type: "string" + default_value { + s: "unidirectional" + } + allowed_values { + list { + s: "unidirectional" + s: "bidirectional" + } + } + } + attr { + name: "dropout" + type: "float" + default_value { + f: 0 + } + } + attr { + name: "seed" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "seed2" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "time_major" + type: "bool" + default_value { + b: true + } + } + is_stateful: true +} op { name: "CudnnRNNCanonicalToParams" input_arg { @@ -17733,6 +18216,138 @@ op { } is_stateful: true } +op { + name: "CudnnRNNV3" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "input_h" + type_attr: "T" + } + input_arg { + name: "input_c" + type_attr: "T" + } + input_arg { + name: "params" + type_attr: "T" + } + input_arg { + name: "sequence_lengths" + type: DT_INT32 + } + output_arg { + name: "output" + type_attr: "T" + } + output_arg { + name: "output_h" + type_attr: "T" + } + output_arg { + name: "output_c" + type_attr: "T" + } + output_arg { + name: "reserve_space" + type_attr: "T" + } + output_arg { + name: "host_reserved" + type: DT_INT8 + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "rnn_mode" + type: "string" + default_value { + s: "lstm" + } + allowed_values { + list { + s: "rnn_relu" + s: "rnn_tanh" + s: "lstm" + s: "gru" + } + } + } + attr { + name: "input_mode" + type: "string" + default_value { + s: "linear_input" + } + allowed_values { + list { + s: "linear_input" + s: "skip_input" + s: "auto_select" + } + } + } + attr { + name: "direction" + type: "string" + default_value { + s: "unidirectional" + } + allowed_values { + list { + s: "unidirectional" + s: "bidirectional" + } + } + } + attr { + name: "dropout" + type: "float" + default_value { + f: 0 + } + } + attr { + name: "seed" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "seed2" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } + attr { + name: "time_major" + type: "bool" + default_value { + b: true + } + } + is_stateful: true +} op { name: "Cumprod" input_arg { @@ -19283,6 +19898,45 @@ op { } } } +op { + name: "DecodePaddedRaw" + input_arg { + name: "input_bytes" + type: DT_STRING + } + input_arg { + name: "fixed_length" + type: DT_INT32 + 
} + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "out_type" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT16 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_INT64 + } + } + } + attr { + name: "little_endian" + type: "bool" + default_value { + b: true + } + } +} op { name: "DecodePng" input_arg { @@ -19469,6 +20123,44 @@ op { } } } +op { + name: "DecodeRaw" + input_arg { + name: "bytes" + type: DT_STRING + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "out_type" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT16 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_BOOL + } + } + } + attr { + name: "little_endian" + type: "bool" + default_value { + b: true + } + } +} op { name: "DecodeWav" input_arg { @@ -19514,6 +20206,18 @@ op { } is_stateful: true } +op { + name: "DeleteIterator" + input_arg { + name: "handle" + type: DT_RESOURCE + } + input_arg { + name: "deleter" + type: DT_VARIANT + } + is_stateful: true +} op { name: "DeleteSessionTensor" input_arg { @@ -21647,6 +22351,34 @@ op { } } } +op { + name: "DivNoNan" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "DrawBoundingBoxes" input_arg { @@ -21675,6 +22407,38 @@ op { } } } +op { + name: "DrawBoundingBoxesV2" + input_arg { + name: "images" + type_attr: "T" + } + input_arg { + name: "boxes" + type: DT_FLOAT + } + input_arg { + name: "colors" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } +} op { name: "DynamicPartition" input_arg { @@ -22099,6 +22863,21 @@ op { } } } +op { + name: "EncodeJpegVariableQuality" + input_arg { + name: "images" + type: DT_UINT8 + } + input_arg { + name: "quality" + type: DT_INT32 + } + output_arg { + name: "contents" + type: DT_STRING + } +} op { name: "EncodePng" input_arg { @@ -22469,6 +23248,101 @@ op { } is_stateful: true } +op { + name: "EnqueueTPUEmbeddingSparseTensorBatch" + input_arg { + name: "sample_indices" + type_attr: "T1" + number_attr: "N" + } + input_arg { + name: "embedding_indices" + type_attr: "T2" + number_attr: "N" + } + input_arg { + name: "aggregation_weights" + type_attr: "T3" + number_attr: "N" + } + input_arg { + name: "mode_override" + type: DT_STRING + } + attr { + name: "T1" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T2" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T3" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "device_ordinal" + type: "int" + default_value { + i: -1 + } + } + attr { + name: "combiners" + type: "list(string)" + default_value { + list { + } + } + } + 
attr { + name: "table_ids" + type: "list(int)" + } + attr { + name: "max_sequence_lengths" + type: "list(int)" + default_value { + list { + } + } + } + is_stateful: true +} op { name: "EnsureShape" input_arg { @@ -30182,6 +31056,20 @@ op { } is_stateful: true } +op { + name: "InfeedEnqueuePrelinearizedBuffer" + input_arg { + name: "input" + type: DT_VARIANT + } + attr { + name: "device_ordinal" + type: "int" + default_value { + i: -1 + } + } +} op { name: "InfeedEnqueueTuple" input_arg { @@ -32052,6 +32940,37 @@ op { } is_commutative: true } +op { + name: "LeftShift" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + type: DT_UINT32 + type: DT_UINT64 + } + } + } +} op { name: "Less" input_arg { @@ -34304,6 +35223,59 @@ op { type: "func" } } +op { + name: "MapDefun" + input_arg { + name: "arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "captured_inputs" + type_list_attr: "Tcaptured" + } + output_arg { + name: "output" + type_list_attr: "output_types" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tcaptured" + type: "list(type)" + default_value { + list { + } + } + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "f" + type: "func" + } + attr { + name: "max_intra_op_parallelism" + type: "int" + default_value { + i: 1 + } + } +} op { name: "MapIncompleteSize" output_arg { @@ -40278,6 +41250,36 @@ op { minimum: 1 } } +op { + name: "ModelDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "cpu_budget" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "Mul" input_arg { @@ -40411,6 +41413,63 @@ op { } is_commutative: true } +op { + name: "MulNoNan" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true +} +op { + name: "MulNoNan" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "MultiDeviceIterator" output_arg { @@ -41396,6 +42455,32 @@ op { op { name: "NoOp" } +op { + name: "NonDeterministicInts" + input_arg { + name: "shape" + type_attr: "shape_dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + attr { + name: "shape_dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + is_stateful: true +} op { name: "NonMaxSuppression" input_arg { @@ -42109,6 +43194,41 
@@ op { minimum: 1 } } +op { + name: "OptimizeDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "optimizations" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "optimization_configs" + type: "list(string)" + default_value { + list { + } + } + } +} op { name: "OptionalFromValue" input_arg { @@ -42803,6 +43923,59 @@ op { minimum: 1 } } +op { + name: "PaddedBatchDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "padded_shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "padding_values" + type_list_attr: "Toutput_types" + } + input_arg { + name: "drop_remainder" + type: DT_BOOL + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "parallel_copy" + type: "bool" + default_value { + b: false + } + } + attr { + name: "Toutput_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } +} op { name: "PaddingFIFOQueue" output_arg { @@ -44429,6 +45602,66 @@ op { minimum: 1 } } +op { + name: "Prelinearize" + input_arg { + name: "input" + type_attr: "dtype" + } + output_arg { + name: "output" + type: DT_VARIANT + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + } + } + } + attr { + name: "layout" + type: "list(int)" + default_value { + list { + } + } + } +} +op { + name: "PrelinearizeTuple" + input_arg { + name: "inputs" + type_list_attr: "dtypes" + } + output_arg { + name: "output" + type: DT_VARIANT + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "shapes" + type: "list(shape)" + } + attr { + name: "layouts" + type: "list(int)" + default_value { + list { + } + } + } +} op { name: "PreventGradient" input_arg { @@ -46194,6 +47427,87 @@ op { } is_commutative: true } +op { + name: "QuantizedAdd" + input_arg { + name: "x" + type_attr: "T1" + } + input_arg { + name: "y" + type_attr: "T2" + } + input_arg { + name: "min_x" + type: DT_FLOAT + } + input_arg { + name: "max_x" + type: DT_FLOAT + } + input_arg { + name: "min_y" + type: DT_FLOAT + } + input_arg { + name: "max_y" + type: DT_FLOAT + } + output_arg { + name: "z" + type_attr: "Toutput" + } + output_arg { + name: "min_z" + type: DT_FLOAT + } + output_arg { + name: "max_z" + type: DT_FLOAT + } + attr { + name: "T1" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "T2" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Toutput" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } +} op { name: "QuantizedAvgPool" input_arg { @@ -47735,6 +49049,113 @@ op { } } } +op { + name: "QuantizedConv2DPerChannel" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } 
+ input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} op { name: "QuantizedConv2DWithBias" input_arg { @@ -49597,6 +51018,464 @@ op { } } } +op { + name: "QuantizedDepthwiseConv2D" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +op { + name: "QuantizedDepthwiseConv2DWithBias" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "bias" + type: DT_FLOAT + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + 
type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +op { + name: "QuantizedDepthwiseConv2DWithBiasAndRelu" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "bias" + type: DT_FLOAT + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +op { + name: "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "bias" + type_attr: "Tbias" + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + input_arg { + name: "min_freezed_output" + type: DT_FLOAT + } + input_arg { + name: "max_freezed_output" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tbias" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_QINT32 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QUINT8 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + 
type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} op { name: "QuantizedInstanceNorm" input_arg { @@ -50249,6 +52128,87 @@ op { } is_commutative: true } +op { + name: "QuantizedMul" + input_arg { + name: "x" + type_attr: "T1" + } + input_arg { + name: "y" + type_attr: "T2" + } + input_arg { + name: "min_x" + type: DT_FLOAT + } + input_arg { + name: "max_x" + type: DT_FLOAT + } + input_arg { + name: "min_y" + type: DT_FLOAT + } + input_arg { + name: "max_y" + type: DT_FLOAT + } + output_arg { + name: "z" + type_attr: "Toutput" + } + output_arg { + name: "min_z" + type: DT_FLOAT + } + output_arg { + name: "max_z" + type: DT_FLOAT + } + attr { + name: "T1" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "T2" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Toutput" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } +} op { name: "QuantizedRelu" input_arg { @@ -58769,6 +60729,80 @@ op { } is_stateful: true } +op { + name: "ResourceGather" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "batch_dims" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "validate_indices" + type: "bool" + default_value { + b: true + } + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { + name: "ResourceGatherNd" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} op { name: "ResourceScatterAdd" input_arg { @@ -64101,6 +66135,37 @@ op { } is_commutative: true } +op { + name: "RightShift" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + type: DT_UINT32 + type: DT_UINT64 + } + } + } +} op { name: "Rint" input_arg { @@ -64700,6 +66765,41 @@ op { } is_stateful: true } +op { + name: "SamplingDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "rate" + type: DT_FLOAT + } + input_arg { + name: "seed" + type: DT_INT64 + } + input_arg { + name: "seed2" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "Save" input_arg { @@ -79910,6 +82010,76 @@ op { } is_stateful: true } +op { + name: "StatefulRandomBinomial" + 
input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "algorithm" + type: DT_INT64 + } + input_arg { + name: "shape" + type_attr: "S" + } + input_arg { + name: "counts" + type_attr: "T" + } + input_arg { + name: "probs" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "S" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_DOUBLE + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} op { name: "StatefulStandardNormal" input_arg { @@ -79984,6 +82154,39 @@ op { } is_stateful: true } +op { + name: "StatefulStandardNormal" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "shape" + type_attr: "shape_dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + } + attr { + name: "shape_dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + deprecation { + version: 29 + } + is_stateful: true +} op { name: "StatefulStandardNormalV2" input_arg { @@ -80018,6 +82221,74 @@ op { } is_stateful: true } +op { + name: "StatefulTruncatedNormal" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "algorithm" + type: DT_INT64 + } + input_arg { + name: "shape" + type_attr: "shape_dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + } + attr { + name: "shape_dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + is_stateful: true +} +op { + name: "StatefulUniform" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "algorithm" + type: DT_INT64 + } + input_arg { + name: "shape" + type_attr: "shape_dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + } + attr { + name: "shape_dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + is_stateful: true +} op { name: "StatefulUniformFullInt" input_arg { @@ -80776,6 +83047,40 @@ op { } } } +op { + name: "StatsAggregatorHandleV2" + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { + name: "StatsAggregatorSetSummaryWriter" + input_arg { + name: "stats_aggregator" + type: DT_RESOURCE + } + input_arg { + name: "summary" + type: DT_RESOURCE + } + is_stateful: true +} op { name: "StopGradient" input_arg { @@ -84832,6 +87137,82 @@ op { } is_stateful: true } +op { + name: "TensorStridedSliceUpdate" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "begin" + type_attr: "Index" + } + input_arg { + name: "end" + type_attr: "Index" + } + input_arg { + name: "strides" + type_attr: "Index" + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: 
DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "begin_mask" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "end_mask" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "ellipsis_mask" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "new_axis_mask" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "shrink_axis_mask" + type: "int" + default_value { + i: 0 + } + } +} op { name: "TensorSummary" input_arg { @@ -87985,6 +90366,22 @@ op { } is_stateful: true } +op { + name: "WriteRawProtoSummary" + input_arg { + name: "writer" + type: DT_RESOURCE + } + input_arg { + name: "step" + type: DT_INT64 + } + input_arg { + name: "tensor" + type: DT_STRING + } + is_stateful: true +} op { name: "WriteScalarSummary" input_arg { diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc index cd2e5c9d340..9b22ccdeeec 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops.cc @@ -167,6 +167,7 @@ REGISTER_OP("CudnnRNNV3") .Attr("seed: int = 0") .Attr("seed2: int = 0") .Attr("is_training: bool = true") + .Attr("time_major: bool = true") .SetShapeFn([](InferenceContext* c) { auto input_shape = c->input(0); auto input_h_shape = c->input(1); @@ -292,6 +293,7 @@ REGISTER_OP("CudnnRNNBackpropV3") .Attr("dropout: float = 0.0") .Attr("seed: int = 0") .Attr("seed2: int = 0") + .Attr("time_major: bool = true") .SetShapeFn([](InferenceContext* c) { auto input_shape = c->input(0); auto input_h_shape = c->input(1); diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index cc7ce542579..e98827a2528 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -167,6 +167,7 @@ REGISTER_OP("PrefetchDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("slack_period: int = 0") .SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle unused; // buffer_size should be a scalar. 
@@ -264,6 +265,7 @@ REGISTER_OP("BatchDatasetV2") .Input("batch_size: int64") .Input("drop_remainder: bool") .Output("handle: variant") + .Attr("parallel_copy: bool = false") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn([](shape_inference::InferenceContext* c) { @@ -280,6 +282,7 @@ REGISTER_OP("ShardDataset") .Input("num_shards: int64") .Input("index: int64") .Output("handle: variant") + .Attr("require_non_empty: bool = false") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn([](shape_inference::InferenceContext* c) { @@ -318,6 +321,7 @@ REGISTER_OP("PaddedBatchDatasetV2") .Input("padding_values: Toutput_types") .Input("drop_remainder: bool") .Output("handle: variant") + .Attr("parallel_copy: bool = false") .Attr("Toutput_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .Attr("N: int >= 1") @@ -502,6 +506,22 @@ REGISTER_OP("AnonymousIterator") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("AnonymousIteratorV2") + .Output("handle: resource") + .Output("deleter: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("DeleteIterator") + .Input("handle: resource") + .Input("deleter: variant") + .SetShapeFn(shape_inference::NoOutputs); + REGISTER_OP("MakeIterator") .Input("dataset: variant") .Input("iterator: resource") @@ -624,6 +644,7 @@ REGISTER_OP("OptimizeDataset") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("optimization_configs: list(string) = []") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("OptionalFromValue") @@ -658,6 +679,7 @@ REGISTER_OP("IteratorGetNextAsOptional") REGISTER_OP("ModelDataset") .Input("input_dataset: variant") .Output("handle: variant") + .Attr("cpu_budget: int = 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); @@ -673,6 +695,7 @@ REGISTER_OP("MapDefun") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") + .Attr("max_intra_op_parallelism: int = 1") .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 7d8c2a46fd4..239732ab2c2 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -17,6 +17,11 @@ limitations under the License. 
namespace tensorflow { +REGISTER_OP("StatsAggregatorSetSummaryWriter") + .Input("stats_aggregator: resource") + .Input("summary: resource") + .SetShapeFn(shape_inference::NoOutputs); + REGISTER_OP("ExperimentalAutoShardDataset") .Input("input_dataset: variant") .Input("num_workers: int64") @@ -38,6 +43,20 @@ REGISTER_OP("ExperimentalBytesProducedStatsDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("ChooseFastestBranchDataset") + .Input("input_dataset: variant") + .Input("ratio_numerator: int64") + .Input("ratio_denominator: int64") + .Input("other_arguments: Targuments") + .Output("handle: variant") + .Attr("Targuments: list(type) >= 0") + .Attr("num_elements_per_branch: int >= 1") + .Attr("branches: list(func) >= 1") + .Attr("other_arguments_lengths: list(int) >= 1") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("ExperimentalCSVDataset") .Input("filenames: string") .Input("compression_type: string") @@ -345,6 +364,19 @@ REGISTER_OP("ExperimentalSlidingWindowDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("SnapshotDataset") + .Input("input_dataset: variant") + .Input("path: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // snapshot_path should be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("ExperimentalSqlDataset") .Input("driver_name: string") .Input("data_source_name: string") @@ -369,6 +401,12 @@ REGISTER_OP("ExperimentalStatsAggregatorHandle") .Attr("container: string = ''") .Attr("shared_name: string = ''"); +REGISTER_OP("StatsAggregatorHandleV2") + .Output("handle: resource") + .SetShapeFn(shape_inference::ScalarShape) + .Attr("container: string = ''") + .Attr("shared_name: string = ''"); + REGISTER_OP("ExperimentalStatsAggregatorSummary") .Input("iterator: resource") .Output("summary: string") @@ -501,6 +539,23 @@ REGISTER_OP("ExperimentalIdentityIndexedDataset") .SetShapeFn( shape_inference::ScalarShape); // TODO(saeta): check input shapes. +REGISTER_OP("SamplingDataset") + .Input("input_dataset: variant") + .Input("rate: float32") + .Input("seed: int64") + .Input("seed2: int64") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // rate, seed, and seed2 should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return shape_inference::ScalarShape(c); + }); + /////////////////////////////////////////////////////////////////////////////// // IndexedDataset Internals /////////////////////////////////////////////////////////////////////////////// diff --git a/tensorflow/core/ops/fingerprint64_map_ops.cc b/tensorflow/core/ops/fingerprint64_map_ops.cc deleted file mode 100644 index 91b24b40178..00000000000 --- a/tensorflow/core/ops/fingerprint64_map_ops.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -using shape_inference::InferenceContext; - -REGISTER_OP("Fingerprint64Map") - .Output("table_handle: resource") - .Attr("heterogeneous_key_dtype: type") - .Attr("table_value_dtype: type = DT_INT64") - .Attr("num_oov_buckets: int >= 1") - .Attr("offset: int >= 0 = 0") - .Attr("use_node_name_sharing: bool = false") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }); - -} // namespace tensorflow diff --git a/tensorflow/core/ops/function_ops.cc b/tensorflow/core/ops/function_ops.cc index 8e86dd9f780..9f78e583237 100644 --- a/tensorflow/core/ops/function_ops.cc +++ b/tensorflow/core/ops/function_ops.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -25,7 +28,55 @@ REGISTER_SYSTEM_OP("_Arg") .Attr("index: int >= 0") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* context) { - context->set_output(0, context->UnknownShape()); + const AttrValue* dtype_attr = context->attrs().Find("T"); + if (!dtype_attr) { + return errors::InvalidArgument( + "_Arg node does not have attribute \"T\""); + } + + if (dtype_attr->type() == DT_RESOURCE) { + const AttrValue* dtype_attr = context->attrs().Find("_handle_dtypes"); + const AttrValue* shape_attr = context->attrs().Find("_handle_shapes"); + if (dtype_attr && shape_attr) { + if (dtype_attr->list().type().empty()) { + return errors::InvalidArgument( + "Invalid \"_handle_dtypes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + if (shape_attr->list().shape().empty()) { + return errors::InvalidArgument( + "Invalid \"_handle_shapes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + DataType dtype = dtype_attr->list().type(0); + const TensorShapeProto& shape_proto = shape_attr->list().shape(0); + shape_inference::ShapeHandle shape_handle; + TF_RETURN_IF_ERROR( + context->MakeShapeFromShapeProto(shape_proto, &shape_handle)); + context->set_output(0, shape_handle); + context->set_output_handle_shapes_and_types( + 0, std::vector{ + {shape_handle, dtype}}); + } else { + context->set_output(0, context->UnknownShape()); + } + } else { + const AttrValue* shape_attr = context->attrs().Find("_output_shapes"); + if (shape_attr && shape_attr->has_list()) { + if (shape_attr->list().shape().empty()) { + return errors::InvalidArgument( + "Invalid \"_output_shapes\" attribute value for _Arg node: ", + 
shape_attr->DebugString()); + } + const TensorShapeProto& shape_proto = shape_attr->list().shape(0); + shape_inference::ShapeHandle shape_handle; + TF_RETURN_IF_ERROR( + context->MakeShapeFromShapeProto(shape_proto, &shape_handle)); + context->set_output(0, shape_handle); + } else { + context->set_output(0, context->UnknownShape()); + } + } return Status::OK(); }) .Doc(R"doc( @@ -33,6 +84,19 @@ A graph node which represents an argument to a function. output: The argument. index: This argument is the index-th argument of the function. + +Attributes for shape inference: +1. _output_shapes: this attribute can be set on an _Arg node producing + non-resource output(s). If set, its value should contain a list of + TensorShapeProto describing the shape(s) of the tensor(s) this _Arg node will + produce. If set, _Arg node's shape inference function will use it as the + node's output shapes. +2. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg + node producing resource output(s). If set, value of _handle_dtypes should + contain the dtype(s) of the resource(s) and value of _handle_shapes should + contain the shape(s) of the resource(s). If both attributes are set, _Arg + node's shape inference function will use their values as the node's output + type(s) and shape(s). )doc"); REGISTER_SYSTEM_OP("_DeviceArg") diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index 4982ec6bd82..087e1190d7e 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -249,6 +249,8 @@ REGISTER_OP("For") .Attr("body: func") .SetShapeFn(shape_inference::UnknownShape); +// While no useful shape function is registered for function call ops directly, +// ShapeRefiner is run by default to perform shape inference. 
REGISTER_OP("PartitionedCall") .Input("args: Tin") .Output("output: Tout") diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 62d473847c0..3dd37bd97ce 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -449,6 +449,13 @@ REGISTER_OP("EncodeJpeg") .Output("contents: string") .SetShapeFn(EncodeImageShapeFn); +// -------------------------------------------------------------------------- +REGISTER_OP("EncodeJpegVariableQuality") + .Input("images: uint8") + .Input("quality: int32") + .Output("contents: string") + .SetShapeFn(EncodeImageShapeFn); + // -------------------------------------------------------------------------- REGISTER_OP("ExtractJpegShape") .Input("contents: string") @@ -594,6 +601,17 @@ REGISTER_OP("DrawBoundingBoxes") return shape_inference::UnchangedShape(c); }); +// -------------------------------------------------------------------------- +REGISTER_OP("DrawBoundingBoxesV2") + .Input("images: T") + .Input("boxes: float") + .Input("colors: float") + .Output("output: T") + .Attr("T: {float, half} = DT_FLOAT") + .SetShapeFn([](InferenceContext* c) { + return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); + }); + // -------------------------------------------------------------------------- REGISTER_OP("SampleDistortedBoundingBox") .Input("image_size: T") @@ -904,6 +922,7 @@ REGISTER_OP("CombinedNonMaxSuppression") .Output("nmsed_classes: float") .Output("valid_detections: int32") .Attr("pad_per_class: bool = false") + .Attr("clip_boxes: bool = true") .SetShapeFn(CombinedNMSShapeFn); } // namespace tensorflow diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 66594b3576e..51ab3e268c6 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -208,6 +208,54 @@ Status SvdShapeFn(InferenceContext* c) { return Status::OK(); } +// Inputs: [...,1,M], [...,1,M], [...,1,M],[...,M,N]. +// Output is [...,M,N]. +Status TridiagonalMatMulShapeFn(InferenceContext* c) { + ShapeHandle superdiag; + ShapeHandle maindiag; + ShapeHandle subdiag; + ShapeHandle rhs; + + // Check that rank is at least 2. + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &superdiag)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &maindiag)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 2, &subdiag)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 2, &rhs)); + + // Extract batch dimensions and check they are the same. + ShapeHandle superdiag_batch_shape; + ShapeHandle maindiag_batch_shape; + ShapeHandle subdiag_batch_shape; + ShapeHandle rhs_batch_shape; + TF_RETURN_IF_ERROR(c->Subshape(superdiag, 0, -2, &superdiag_batch_shape)); + TF_RETURN_IF_ERROR(c->Subshape(maindiag, 0, -2, &maindiag_batch_shape)); + TF_RETURN_IF_ERROR(c->Subshape(subdiag, 0, -2, &subdiag_batch_shape)); + TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); + TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &superdiag)); + TF_RETURN_IF_ERROR( + c->Merge(maindiag_batch_shape, rhs_batch_shape, &rhs_batch_shape)); + TF_RETURN_IF_ERROR( + c->Merge(subdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape)); + + // Check that diagonals have the same shape. + TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &maindiag)); + TF_RETURN_IF_ERROR(c->Merge(subdiag, maindiag, &maindiag)); + + // Check that size of tri-diagonal matrix is the same as height of matrix on + // the right. 
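  // For example, with maindiag of shape [..., 1, M] and rhs of shape
  // [..., M, N], m_lhs and m_rhs below both resolve to M and must agree.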
+ DimensionHandle m_lhs = c->Dim(maindiag, -1); + DimensionHandle m_rhs = c->Dim(rhs, -2); + TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs)); + + // Check that next-to-last dimension of diagonals is 1. + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(maindiag, -2), 1, &unused)); + + // The output shape is the same as rhs shape. + c->set_output(0, rhs); + return Status::OK(); +} + // The first input is [...,3,M] and second input is [...,M,K]. // Output is [...,M,K]. Status TridiagonalSolveShapeFn(InferenceContext* c) { @@ -409,10 +457,20 @@ REGISTER_OP("Svd") .Attr("T: {double, float, half, complex64, complex128}") .SetShapeFn(SvdShapeFn); +REGISTER_OP("TridiagonalMatMul") + .Input("superdiag: T") + .Input("maindiag: T") + .Input("subdiag: T") + .Input("rhs: T") + .Output("output: T") + .Attr("T: {double, float, complex64, complex128}") + .SetShapeFn(TridiagonalMatMulShapeFn); + REGISTER_OP("TridiagonalSolve") .Input("diagonals: T") .Input("rhs: T") .Output("output: T") + .Attr("partial_pivoting: bool = True") .Attr("T: {double, float, complex64, complex128}") .SetShapeFn(TridiagonalSolveShapeFn); diff --git a/tensorflow/core/ops/linalg_ops_test.cc b/tensorflow/core/ops/linalg_ops_test.cc index 93732f938a9..682a994e890 100644 --- a/tensorflow/core/ops/linalg_ops_test.cc +++ b/tensorflow/core/ops/linalg_ops_test.cc @@ -314,6 +314,43 @@ TEST(LinalgOpsTest, Lu_ShapeFn) { "[d0_0,d0_1,d0_2,d0_3,d0_5,d0_5];[d0_0,d0_1,d0_2,d0_3,d0_5]"); } +TEST(LinalgOpsTest, TridiagonalMatMul_ShapeFn) { + ShapeInferenceTestOp op("TridiagonalMatMul"); + INFER_OK(op, "?;?;?;?", "in3"); + INFER_OK(op, "[1,5];[1,5];[1,5];[?,1]", "in3"); + INFER_OK(op, "[1,5];[1,5];[1,5];[5,1]", "in3"); + + INFER_OK(op, "[?,1,?];[?,1,?];[?,1,?];[?,?,?]", "in3"); + INFER_OK(op, "[?,1,5];[?,1,5];[?,1,5];[7,5,2]", "in3"); + INFER_OK(op, "[7,1,5];[7,1,5];[7,1,5];[?,5,2]", "in3"); + INFER_OK(op, "[7,1,5];[7,1,5];[7,1,5];[7,5,2]", "in3"); + + INFER_OK(op, "[7,?,1,5];[7,?,1,5];[7,?,1,5];[7,8,5,2]", "in3"); + INFER_OK(op, "[7,8,1,5];[7,8,1,5];[7,8,1,5];[7,8,5,2]", "in3"); + + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, + "[3];[3];[3];[5,1]"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, + "[3,5];[3,5];[3,5];[5]"); + INFER_ERROR( + "Dimension 1 in both shapes must be equal, but are 4 and 8. " + "Shapes are [6,4] and [6,8].", + op, "[6,4,3,5];[6,4,3,5];[6,4,3,5];[6,8,5,2]"); + INFER_ERROR( + "Dimension 1 in both shapes must be equal, but are 4 and 8. " + "Shapes are [?,4] and [6,8].", + op, "[?,4,3,5];[?,4,3,5];[?,4,3,5];[6,8,5,2]"); + + // Diagonals must have the same length. + INFER_ERROR( + "Dimension 1 in both shapes must be equal, but are 5 and 6. " + "Shapes are [1,5] and [1,6]", + op, "[1,5];[1,6];[1,5];[6,2]"); + + // Diagonals must be 1-row matrices. + INFER_ERROR("Dimension must be 1 but is 3", op, "[3,5];[3,5];[3,5];[5,2]"); +} + TEST(LinalgOpsTest, TridiagonalSolve_ShapeFn) { ShapeInferenceTestOp op("TridiagonalSolve"); INFER_OK(op, "?;?", "in1"); diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index 123ffc493a9..7a0ccb11f1d 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -215,7 +215,7 @@ REGISTER_OP("TensorListStack") return errors::InvalidArgument( "Trying to read from list with wrong element dtype. 
List has " "type ", - DataTypeString(list_shape_type.dtype), " but expectec type ", + DataTypeString(list_shape_type.dtype), " but expected type ", DataTypeString(element_dtype)); } shape_inference::ShapeHandle ignored; @@ -223,6 +223,11 @@ REGISTER_OP("TensorListStack") c->Merge(element_shape, list_shape_type.shape, &ignored)); element_shape = list_shape_type.shape; } + shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + 1, &element_shape_input)); + TF_RETURN_IF_ERROR( + c->Merge(element_shape, element_shape_input, &element_shape)); int expected_num_elements = -1; TF_RETURN_IF_ERROR(c->GetAttr("num_elements", &expected_num_elements)); shape_inference::ShapeHandle num_elements; @@ -418,6 +423,11 @@ REGISTER_OP("TensorListGetItem") DataTypeString(list_shape_type.dtype)); } } + shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + 2, &element_shape_input)); + TF_RETURN_IF_ERROR( + c->Merge(element_shape, element_shape_input, &element_shape)); c->set_output(0, element_shape); return Status::OK(); }); @@ -486,6 +496,11 @@ REGISTER_OP("TensorListGather") DataTypeString(list_shape_type.dtype)); } } + shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + 2, &element_shape_input)); + TF_RETURN_IF_ERROR( + c->Merge(element_shape, element_shape_input, &element_shape)); shape_inference::ShapeHandle out; TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out)); c->set_output(0, out); diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc index 42a1b1d7e3f..da8b7d883f1 100644 --- a/tensorflow/core/ops/logging_ops.cc +++ b/tensorflow/core/ops/logging_ops.cc @@ -50,6 +50,7 @@ REGISTER_OP("PrintV2") .Input("input: string") .SetIsStateful() .Attr("output_stream: string = 'stderr'") + .Attr("end: string = '\n'") .SetShapeFn([](InferenceContext* c) { // Make sure that the input is a scalar. if (c->Rank(c->input(0)) != 0) { diff --git a/tensorflow/core/ops/lookup_table_ops.cc b/tensorflow/core/ops/lookup_table_ops.cc deleted file mode 100644 index 3ce08f6f2f9..00000000000 --- a/tensorflow/core/ops/lookup_table_ops.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -using shape_inference::InferenceContext; - -REGISTER_OP("LookupTableInsertOrAssignOp") - .Input("table_int64_args: num_int64_table_args * int64") - .Input("table_handle: resource") - .Input("keys: insert_key_tensor_dtype") - .Input("values: table_value_dtype") - .Attr("insert_key_tensor_dtype: type") - .Attr("table_value_dtype: type") - .Attr("num_int64_table_args: int >= 0") - .SetShapeFn([](InferenceContext* c) { - // Note that, by design, shape checks are implementation dependent so they - // must be deferred until runtime. - return Status::OK(); - }); - -REGISTER_OP("LookupTableFindOp") - .Input("table_int64_args: num_int64_table_args * int64") - .Input("table_handle: resource") - .Input("keys: lookup_key_tensor_dtype") - .Input("num_threads: int64") - .Output("values: table_value_dtype") - .Attr("table_value_dtype: type") - .Attr("lookup_key_tensor_dtype: type") - .Attr("num_int64_table_args: int >= 0") - .SetShapeFn([](InferenceContext* c) { - // The output shape cannot be inferred here because the key size - // cannot be inferred from the key tensor in general. - c->set_output(0, c->UnknownShape()); - return Status::OK(); - }); - -REGISTER_OP("ContainerSizeOp") - .Input("container_handle: resource") - .Output("size: int64") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }); - -} // namespace tensorflow diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc index a8d454038c9..c63e3be26b1 100644 --- a/tensorflow/core/ops/math_grad.cc +++ b/tensorflow/core/ops/math_grad.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include + #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -593,6 +596,20 @@ Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Xdivy", XdivyGrad); +Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + FDH::Const("c", 2LL), + {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"x_sub_y"}, "Sub", {"x", "y"}}, + {{"two_x_sub_y"}, "Mul", {"two", "x_sub_y"}}, // 2 * (x - y) + {{"gx"}, "Mul", {"two_x_sub_y", "dz"}}, + {{"gy"}, "Neg", {"gx"}} + }); + // clang-format on +} +REGISTER_OP_GRADIENT("SquaredDifference", SquaredDifferenceGrad); + Status MaximumMinimumGradHelper(const string& comparator, const AttrSlice& attrs, FunctionDef* g) { // clang-format off @@ -775,7 +792,47 @@ static Status MatMulGradHelper(FunctionDef* g, const string& opname, const string& attr_adj_y, const string& x0, bool ax0, const string& x1, bool ax1, const string& y0, bool ay0, const string& y1, - bool ay1) { + bool ay1, bool enable_broadcasting) { + // The final outputs are "dx" and "dy". If we're broadcasting compute + // intermediate nodes for now. + std::vector nodes = { + {{(enable_broadcasting ? "gx" : "dx")}, + opname, + {x0, x1}, + {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}}, + {{(enable_broadcasting ? 
"gy" : "dy")}, + opname, + {y0, y1}, + {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}}, + }; + // TODO(anudhyan): Figure out a way to inspect the static shapes of "x" and + // "y". If they have the same batch dimensions, then we can omit adding the + // broadcasting-specific ops. + if (enable_broadcasting) { + std::vector unbroadcast_gradients = { + FDH::Const("zero", gtl::ArraySlice{0}), + FDH::Const("one", gtl::ArraySlice{1}), + FDH::Const("minustwo", gtl::ArraySlice{-2}), + // Compute the batch shapes of the inputs (all but last two dims). + {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}}, + {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}}, + {{"batch_sx"}, + "StridedSlice", + {"sx", "zero", "minustwo", "one"}, + {{"T", DT_INT32}, {"Index", DT_INT32}}}, + {{"batch_sy"}, + "StridedSlice", + {"sy", "zero", "minustwo", "one"}, + {{"T", DT_INT32}, {"Index", DT_INT32}}}, + // Sum along dimensions that the inputs were broadcasted across. + {{"rx", "ry"}, "BroadcastGradientArgs", {"batch_sx", "batch_sy"}}, + {{"sum_gx"}, "Sum", {"gx", "rx"}, {{"T", "$T"}}}, + {{"sum_gy"}, "Sum", {"gy", "ry"}, {{"T", "$T"}}}, + {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}}, + {{"dy"}, "Reshape", {"sum_gy", "sy"}, {{"T", "$T"}}}}; + nodes.insert(nodes.end(), unbroadcast_gradients.begin(), + unbroadcast_gradients.end()); + } *g = FDH::Define( // Arg defs {"x: T", "y: T", "dz: T"}, @@ -784,22 +841,13 @@ static Status MatMulGradHelper(FunctionDef* g, const string& opname, // Attr defs {{"T: {half, float, double}"}}, // Nodes - { - {{"dx"}, - opname, - {x0, x1}, - {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}}, - {{"dy"}, - opname, - {y0, y1}, - {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}}, - }); + nodes); return Status::OK(); } Status MatMulGradCommon(const string& opname, const string& attr_adj_x, const string& attr_adj_y, const AttrSlice& attrs, - FunctionDef* g) { + FunctionDef* g, bool enable_broadcasting) { DataType T; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); if (T == DT_COMPLEX64 || T == DT_COMPLEX128) { @@ -812,31 +860,39 @@ Status MatMulGradCommon(const string& opname, const string& attr_adj_x, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb)); if (!ta && !tb) { return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y", - true, "x", true, "dz", false); + true, "x", true, "dz", false, enable_broadcasting); } if (!ta && tb) { return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y", - false, "dz", true, "x", false); + false, "dz", true, "x", false, enable_broadcasting); } if (ta && !tb) { return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz", - true, "x", false, "dz", false); + true, "x", false, "dz", false, enable_broadcasting); } CHECK(ta && tb); return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz", - true, "dz", true, "x", true); + true, "dz", true, "x", true, enable_broadcasting); } Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { - return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g); + return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g, + false /* enable_broadcasting */); } REGISTER_OP_GRADIENT("MatMul", MatMulGrad); Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) { - return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g); + return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g, + false /* enable_broadcasting */); } REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad); +Status BatchMatMulV2Grad(const 
AttrSlice& attrs, FunctionDef* g) { + return MatMulGradCommon("BatchMatMulV2", "adj_x", "adj_y", attrs, g, + true /* enable_broadcasting */); +} +REGISTER_OP_GRADIENT("BatchMatMulV2", BatchMatMulV2Grad); + // REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad); // Comparison ops. diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc index 9fc6b341479..115dbd27df6 100644 --- a/tensorflow/core/ops/math_grad_test.cc +++ b/tensorflow/core/ops/math_grad_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" @@ -224,6 +225,29 @@ class MathGradTest : public ::testing::Test { *di = outputs[1]; } + Tensor ReduceSum(const Tensor& x, gtl::ArraySlice axes) { + int num_axes = axes.length(); + Tensor y(DT_INT32, TensorShape({num_axes})); + for (size_t i = 0; i < axes.size(); ++i) { + y.flat()(i) = axes[i]; + } + auto T = x.dtype(); + auto gdef = test::function::GDef( + { + f::NDef("x", "Placeholder", {}, {{"dtype", T}}), + f::NDef("y", "Const", {}, {{"dtype", DT_INT32}, {"value", y}}), + f::NDef("z", "Sum", {"x", "y"}, {{"T", T}}), + }, + {}); + auto sess = NewSession(); + TF_CHECK_OK(sess->Create(gdef)); + std::vector outputs; + TF_CHECK_OK(sess->Run({{"x:0", x}}, {"z:0"}, {}, &outputs)); + CHECK_EQ(outputs.size(), 1); + TF_CHECK_OK(sess->Close()); + return outputs[0]; + } + Tensor MatMulCommon(const string& opname, const string& attr_adj_x, const string& attr_adj_y, const Tensor& x, bool ax, const Tensor& y, bool ay) { @@ -253,6 +277,10 @@ class MathGradTest : public ::testing::Test { return MatMulCommon("BatchMatMul", "adj_x", "adj_y", x, ax, y, ay); } + Tensor BatchMatMulV2(const Tensor& x, bool ax, const Tensor& y, bool ay) { + return MatMulCommon("BatchMatMulV2", "adj_x", "adj_y", x, ax, y, ay); + } + void MatMulGradCommon(const string& opname, const string& attr_adj_x, const string& attr_adj_y, const Tensor& x, bool ax, const Tensor& y, bool ay, Tensor* dx, Tensor* dy) { @@ -325,6 +353,12 @@ class MathGradTest : public ::testing::Test { dy); } + void BatchMatMulV2Grad(const Tensor& x, bool ax, const Tensor& y, bool ay, + Tensor* dx, Tensor* dy) { + return MatMulGradCommon("BatchMatMulV2", "adj_x", "adj_y", x, ax, y, ay, dx, + dy); + } + void SelectGrad(const Tensor& c, const Tensor& x, const Tensor& y, Tensor* dc, Tensor* dx, Tensor* dy) { auto T = DT_FLOAT; @@ -949,6 +983,25 @@ TEST_F(MathGradTest, Xdivy) { TensorShape({2, 1}))); } +TEST_F(MathGradTest, SquaredDifference) { + auto x = test::AsTensor({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f}, + TensorShape({2, 3})); + auto y = test::AsTensor({.5f, 2.f}, TensorShape({2, 1})); + Tensor dx; + Tensor dy; + auto g = [](float x, float y) -> float { return 2. * (x - y); }; + auto h = [](float x, float y) -> float { return 2. 
* (y - x); }; + SymGrad("SquaredDifference", x, y, &dx, &dy); + test::ExpectClose( + dx, test::AsTensor({g(-3.f, .5f), g(-2.f, .5f), g(-1.f, .5f), + g(1.f, 2.f), g(2.f, 2.f), g(3.f, 2.f)}, + TensorShape({2, 3}))); + test::ExpectClose( + dy, test::AsTensor({h(-3.f, .5f) + h(-2.f, .5f) + h(-1.f, .5f), + h(1.f, 2.f) + h(2.f, 2.f) + h(3.f, 2.f)}, + TensorShape({2, 1}))); +} + TEST_F(MathGradTest, Maximum) { auto x = test::AsTensor({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f}, TensorShape({2, 3})); @@ -1160,6 +1213,139 @@ TEST_F(MathGradTest, BatchMatMul_11) { } #endif // TENSORFLOW_USE_SYCL +TEST_F(MathGradTest, BatchMatMulV2_00) { + auto x = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, + TensorShape({1, 2, 3})); + auto y = test::AsTensor({-1.f, .5f, 2.f}, TensorShape({1, 3, 1})); + Tensor dx; + Tensor dy; + BatchMatMulV2Grad(x, false, y, false, &dx, &dy); + auto dz = test::AsTensor({1.f, 1.f}, TensorShape({1, 2, 1})); + test::ExpectClose(dx, BatchMatMulV2(dz, false, y, true)); + test::ExpectClose(dy, BatchMatMulV2(x, true, dz, false)); +} + +TEST_F(MathGradTest, BatchMatMulV2_01) { + auto x = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, + TensorShape({1, 2, 3})); + auto y = test::AsTensor({-1.f, .5f, 2.f}, TensorShape({1, 1, 3})); + Tensor dx; + Tensor dy; + BatchMatMulV2Grad(x, false, y, true, &dx, &dy); + auto dz = test::AsTensor({1.f, 1.f}, TensorShape({1, 2, 1})); + test::ExpectClose(dx, BatchMatMulV2(dz, false, y, false)); + test::ExpectClose(dy, BatchMatMulV2(dz, true, x, false)); +} + +TEST_F(MathGradTest, BatchMatMulV2_10) { + auto x = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, + TensorShape({1, 3, 2})); + auto y = test::AsTensor({-1.f, .5f, 2.f}, TensorShape({1, 3, 1})); + Tensor dx; + Tensor dy; + BatchMatMulV2Grad(x, true, y, false, &dx, &dy); + auto dz = test::AsTensor({1.f, 1.f}, TensorShape({1, 2, 1})); + test::ExpectClose(dx, BatchMatMulV2(y, false, dz, true)); + test::ExpectClose(dy, BatchMatMulV2(x, false, dz, false)); +} + +TEST_F(MathGradTest, BatchMatMulV2_11) { + auto x = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, + TensorShape({1, 3, 2})); + auto y = test::AsTensor({-1.f, .5f, 2.f}, TensorShape({1, 1, 3})); + Tensor dx; + Tensor dy; + BatchMatMulV2Grad(x, true, y, true, &dx, &dy); + auto dz = test::AsTensor({1.f, 1.f}, TensorShape({1, 2, 1})); + test::ExpectClose(dx, BatchMatMulV2(y, true, dz, true)); + test::ExpectClose(dy, BatchMatMulV2(dz, true, x, true)); +} + +TEST_F(MathGradTest, BatchMatMulV2_LhsBroadcasts) { + auto x = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, + TensorShape({2, 3})); + auto y = test::AsTensor( + {1.f, 2.4, 3.f, -1.f, .5f, 2.f, 3.f, 1.f, -1.f, 2.f, -.1f, 0}, + TensorShape({2, 3, 2})); + Tensor dx; + Tensor dy; + BatchMatMulV2Grad(x, false, y, false, &dx, &dy); + EXPECT_TRUE(dx.shape().IsSameSize(x.shape())); + EXPECT_TRUE(dy.shape().IsSameSize(y.shape())); + auto dz = test::AsTensor({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}, + TensorShape({2, 2, 2})); + Tensor ans_dx; + CHECK(ans_dx.CopyFrom(ReduceSum(BatchMatMulV2(dz, false, y, true), {0}), + dx.shape())); + Tensor ans_dy = BatchMatMulV2(x, true, dz, false); + test::ExpectClose(dx, ans_dx); + test::ExpectClose(dy, ans_dy); +} + +TEST_F(MathGradTest, BatchMatMulV2_RhsBroadcasts) { + auto x = test::AsTensor( + {1.f, 2.4, 3.f, -1.f, .5f, 2.f, 3.f, 1.f, -1.f, 2.f, -.1f, 0}, + TensorShape({2, 2, 3})); + auto y = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, + TensorShape({3, 2})); + Tensor dx; + Tensor dy; + BatchMatMulV2Grad(x, false, y, false, &dx, &dy); + auto dz = test::AsTensor({1.f, 1.f, 
1.f, 1.f, 1.f, 1.f, 1.f, 1.f}, + TensorShape({2, 2, 2})); + Tensor ans_dx = BatchMatMulV2(dz, false, y, true); + Tensor ans_dy; + CHECK(ans_dy.CopyFrom(ReduceSum(BatchMatMulV2(x, true, dz, false), {0}), + dy.shape())); + test::ExpectClose(dx, ans_dx); + test::ExpectClose(dy, ans_dy); +} + +TEST_F(MathGradTest, BatchMatMulV2_BothLhsAndRhsBroadcast) { + auto x = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, + TensorShape({2, 1, 1, 3})); + auto y = test::AsTensor({3.f, 1.f, -1.f, 2.f, -.1f, 0}, + TensorShape({1, 2, 3, 1})); + Tensor dx; + Tensor dy; + BatchMatMulV2Grad(x, false, y, false, &dx, &dy); + EXPECT_TRUE(dx.shape().IsSameSize(x.shape())); + EXPECT_TRUE(dy.shape().IsSameSize(y.shape())); + auto dz = + test::AsTensor({1.f, 1.f, 1.f, 1.f}, TensorShape({2, 2, 1, 1})); + Tensor ans_dx; + Tensor ans_dy; + CHECK(ans_dx.CopyFrom(ReduceSum(BatchMatMulV2(dz, false, y, true), {1}), + dx.shape())); + CHECK(ans_dy.CopyFrom(ReduceSum(BatchMatMulV2(x, true, dz, false), {0}), + dy.shape())); + test::ExpectClose(dx, ans_dx); + test::ExpectClose(dy, ans_dy); +} + +TEST_F(MathGradTest, BatchMatMulV2_BroadcastWhileAdjointed) { + auto x = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, + TensorShape({2, 1, 3, 1})); + auto y = test::AsTensor({3.f, 1.f, -1.f, 2.f, -.1f, 0}, + TensorShape({1, 2, 1, 3})); + Tensor dx; + Tensor dy; + BatchMatMulV2Grad(x, true, y, true, &dx, &dy); + EXPECT_TRUE(dx.shape().IsSameSize(x.shape())); + EXPECT_TRUE(dy.shape().IsSameSize(y.shape())); + + auto dz = + test::AsTensor({1.f, 1.f, 1.f, 1.f}, TensorShape({2, 2, 1, 1})); + Tensor ans_dx; + Tensor ans_dy; + CHECK(ans_dx.CopyFrom(ReduceSum(BatchMatMulV2(y, true, dz, true), {1}), + dx.shape())); + CHECK(ans_dy.CopyFrom(ReduceSum(BatchMatMulV2(dz, true, x, true), {0}), + dy.shape())); + test::ExpectClose(dx, ans_dx); + test::ExpectClose(dy, ans_dy); +} + TEST_F(MathGradTest, Sum_dim0) { auto x = test::AsTensor({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f}, TensorShape({2, 3})); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 1c2e9d40ae6..3ff9bc09853 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -158,6 +158,17 @@ REGISTER_OP("BatchMatMul") return Status::OK(); }); +REGISTER_OP("BatchMatMulV2") + .Input("x: T") + .Input("y: T") + .Output("output: T") + .Attr( + "T: {bfloat16, half, float, double, int32, int64, complex64, " + "complex128}") + .Attr("adj_x: bool = false") + .Attr("adj_y: bool = false") + .SetShapeFn(shape_inference::BatchMatMulV2Shape); + // -------------------------------------------------------------------------- // Casting Ops // @@ -403,7 +414,7 @@ REGISTER_OP("_MklAdd") .Output("mkl_z: uint8") .Attr( "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " - "complex128, string}") + "complex128, string, bfloat16}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns `x` + `y` element-wise. @@ -435,16 +446,15 @@ REGISTER_OP("MulNoNan") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {float, double}") - .SetIsCommutative() + .Attr("T: {half, float, double, complex64, complex128}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("_MklMul") .BINARY_MORE() .Input("mkl_x: uint8") .Input("mkl_y: uint8") .Output("mkl_z: uint8") - .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns x * y element-wise. 
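// The BatchMatMulV2 gradient above reduce-sums dx and dy over whichever batch
// dimensions were broadcast before reshaping them back to the input shapes.
// The following is a minimal standalone sketch (an illustration only, not
// TensorFlow's BroadcastGradientArgs implementation) of how the broadcast
// batch shape and those reduction axes can be derived.
#include <algorithm>
#include <iostream>
#include <vector>

// Aligns the two batch shapes from the right (missing leading dims act like 1),
// returns the broadcast shape, and records in rx/ry the output axes along
// which x and y were broadcast and therefore need a Sum in the gradient.
std::vector<long> BroadcastBatchShape(const std::vector<long>& sx,
                                      const std::vector<long>& sy,
                                      std::vector<int>* rx,
                                      std::vector<int>* ry) {
  const size_t rank = std::max(sx.size(), sy.size());
  std::vector<long> out(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    const long dx = i < rank - sx.size() ? 1 : sx[i - (rank - sx.size())];
    const long dy = i < rank - sy.size() ? 1 : sy[i - (rank - sy.size())];
    out[i] = std::max(dx, dy);
    if (dx == 1 && out[i] != 1) rx->push_back(static_cast<int>(i));
    if (dy == 1 && out[i] != 1) ry->push_back(static_cast<int>(i));
  }
  return out;
}

int main() {
  // Batch shapes from the BothLhsAndRhsBroadcast test above: [2, 1] and [1, 2].
  std::vector<int> rx, ry;
  const auto out = BroadcastBatchShape({2, 1}, {1, 2}, &rx, &ry);
  std::cout << "broadcast batch shape: [" << out[0] << ", " << out[1] << "]\n";
  std::cout << "sum dx over axis " << rx[0] << ", dy over axis " << ry[0]
            << "\n";
  return 0;
}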
@@ -460,7 +470,7 @@ REGISTER_OP("DivNoNan") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {float, double}") + .Attr("T: {half, float, double, complex64, complex128}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); REGISTER_OP("FloorDiv") @@ -479,12 +489,12 @@ REGISTER_OP("SquaredDifference") .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("_MklSquaredDifference") .BINARY_FEWER() .Input("mkl_x: uint8") .Input("mkl_y: uint8") .Output("mkl_z: uint8") - .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns (x - y)(x - y) element-wise. @@ -515,9 +525,9 @@ REGISTER_OP("Maximum") .Input("y: T") .Output("z: T") .Attr("T: {bfloat16, half, float, double, int32, int64}") - .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("_MklMaximum") .Input("x: T") .Input("y: T") @@ -525,8 +535,7 @@ REGISTER_OP("_MklMaximum") .Input("mkl_y: uint8") .Output("z: T") .Output("mkl_z: uint8") - .Attr("T: {half, float, double, int32, int64}") - .SetIsCommutative() + .Attr("T: {half, float, double, int32, int64, bfloat16}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns the max of x and y (i.e. x > y ? x : y) element-wise. @@ -540,7 +549,6 @@ REGISTER_OP("Minimum") .Input("y: T") .Output("z: T") .Attr("T: {bfloat16, half, float, double, int32, int64}") - .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); REGISTER_OP("Mod") @@ -820,6 +828,57 @@ REGISTER_OP("Select") return Status::OK(); }); +REGISTER_OP("SelectV2") + .Input("condition: bool") + .Input("t: T") + .Input("e: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + auto* handle_data_1 = c->input_handle_shapes_and_types(1); + auto* handle_data_2 = c->input_handle_shapes_and_types(2); + // Merge handle shape and dtype if applicable. + if (handle_data_1 != nullptr && handle_data_2 != nullptr) { + const auto size = handle_data_1->size(); + std::vector merged_handle_data(size); + if (size != handle_data_2->size()) { + return errors::InvalidArgument( + "Trying to merge handles pointing to different numbers of " + "tensors."); + } + + for (int i = 0; i < size; ++i) { + const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i]; + const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i]; + if (s1.dtype != s2.dtype) { + // TODO(apassos) resolve this in the manner of b/32476923 + return errors::InvalidArgument( + "Trying to merge handles pointing to different dtypes."); + } + merged_handle_data[i].dtype = s1.dtype; + TF_RETURN_IF_ERROR( + c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape)); + } + + c->set_output_handle_shapes_and_types(0, merged_handle_data); + } + + // The inputs 'cond', 'then', and 'else' must be broadcastable. + // TODO (yongtang): Consolidate 3-ary broadcast instead of + // multiple 2-ary broadcast. 
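        // Illustrative example: cond [2, 1], t [2, 3], e [1, 3] ->
        // broadcast(t, e) = [2, 3], broadcast(cond, [2, 3]) = [2, 3],
        // which becomes the output shape.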
+ ShapeHandle cond = c->input(0); + ShapeHandle then = c->input(1); + ShapeHandle else_ = c->input(2); + ShapeHandle other; + TF_RETURN_IF_ERROR( + BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, &other)); + ShapeHandle output; + TF_RETURN_IF_ERROR( + BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, &output)); + c->set_output(0, output); + return Status::OK(); + }); + // -------------------------------------------------------------------------- REGISTER_OP("MatMul") @@ -845,6 +904,25 @@ REGISTER_OP("SparseMatMul") .Attr("Tb: {float, bfloat16} = DT_FLOAT") .SetShapeFn(shape_inference::MatMulShape); +REGISTER_OP("_FusedMatMul") + .Input("a: T") + .Input("b: T") + .Input("args: num_args * T") + .Output("product: T") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("T: {float}") + .Attr("num_args: int >= 0") + .Attr("fused_ops: list(string) = []") + // Attributes for the FusedBatchNorm ----------- // + .Attr("epsilon: float = 0.0001") + // --------------------------------------------- // + .SetShapeFn(shape_inference::MatMulShape) + .Doc(R"doc( +*NOTE*: Do not invoke this operator directly in Python. Grappler is +expected to create these operators. +)doc"); + // -------------------------------------------------------------------------- // For operations where the output is a reduction function along some @@ -1589,6 +1667,7 @@ REGISTER_OP("QuantizedMatMul") return Status::OK(); }); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("QuantizedMul") .Input("x: T1") .Input("y: T2") @@ -1602,7 +1681,6 @@ REGISTER_OP("QuantizedMul") .Attr("T1: quantizedtype") .Attr("T2: quantizedtype") .Attr("Toutput: quantizedtype = DT_QINT32") - .SetIsCommutative() .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); c->set_output(1, c->Scalar()); @@ -1610,6 +1688,7 @@ REGISTER_OP("QuantizedMul") return Status::OK(); }); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("QuantizedAdd") .Input("x: T1") .Input("y: T2") @@ -1623,7 +1702,6 @@ REGISTER_OP("QuantizedAdd") .Attr("T1: quantizedtype") .Attr("T2: quantizedtype") .Attr("Toutput: quantizedtype = DT_QINT32") - .SetIsCommutative() .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); // min_x, max_x, min_y, max_y should be scalar. @@ -1740,6 +1818,7 @@ REGISTER_OP("ClipByValue") .SetShapeFn(shape_inference::UnchangedShape); #ifdef INTEL_MKL +// Note: This op is not commutative w.r.t. to all its inputs. 
REGISTER_OP("_MklAddN") .Input("inputs: N * T") .Input("mkl_input: N * uint8") @@ -1747,8 +1826,6 @@ REGISTER_OP("_MklAddN") .Output("mkl_sum: uint8") .Attr("N: int >= 1") .Attr("T: numbertype") - .SetIsCommutative() - .SetIsAggregate() .SetShapeFn([](InferenceContext* c) { ShapeHandle cur = c->input(c->num_inputs() - 1); for (int i = c->num_inputs() - 2; i >= 0; --i) { diff --git a/tensorflow/core/ops/mkl_array_ops.cc b/tensorflow/core/ops/mkl_array_ops.cc index e7ad3be6112..599ac038d64 100644 --- a/tensorflow/core/ops/mkl_array_ops.cc +++ b/tensorflow/core/ops/mkl_array_ops.cc @@ -87,6 +87,54 @@ REGISTER_OP("_MklQuantizedConcatV2") c->set_output(2, c->Scalar()); return Status::OK(); }); + +REGISTER_OP("_MklQuantizeV2") + .Input("input: float") + .Input("min_range: float") + .Input("max_range: float") + .Input("mkl_input: uint8") + .Input("mkl_min_range: uint8") + .Input("mkl_max_range: uint8") + .Output("output: T") + .Output("output_min: float") + .Output("output_max: float") + .Output("mkl_output: uint8") + .Output("mkl_output_min: uint8") + .Output("mkl_output_max: uint8") + .Attr("T: quantizedtype") + .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'SCALED'") + .Attr( + "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = " + "'HALF_TO_EVEN'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("_MklDequantize") + .Input("input: T") + .Input("min_range: float") + .Input("max_range: float") + .Input("mkl_input: uint8") + .Input("mkl_min_range: uint8") + .Input("mkl_max_range: uint8") + .Output("output: float") + .Output("mkl_output: uint8") + .Attr("T: quantizedtype") + .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'SCALED'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return Status::OK(); + }); + } // namespace tensorflow -#endif +#endif // INTEL_MKL diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc index 0e6ad9162a5..858849c1937 100644 --- a/tensorflow/core/ops/mkl_nn_ops.cc +++ b/tensorflow/core/ops/mkl_nn_ops.cc @@ -43,7 +43,7 @@ REGISTER_OP("_MklFusedConv2D") .Output("filter_output: T") .Output("mkl_output: uint8") .Output("mkl_filter_output: uint8") - .Attr("T: {float}") + .Attr("T: {bfloat16, float}") .Attr("num_args: int >= 0") .Attr("strides: list(int)") .Attr("is_filter_const: bool = false") @@ -69,7 +69,7 @@ REGISTER_OP("__MklDummyPadWithFusedConv2D") .Output("filter_output: T") .Output("mkl_output: uint8") .Output("mkl_filter_output: uint8") - .Attr("T: {float}") + .Attr("T: {bfloat16, float}") .Attr("num_args: int >= 0") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) @@ -99,7 +99,7 @@ REGISTER_OP("_MklPadWithFusedConv2D") .Output("filter_output: T") .Output("mkl_output: uint8") .Output("mkl_filter_output: uint8") - .Attr("T: {float}") + .Attr("T: {bfloat16, float}") .Attr("num_args: int >= 0") .Attr("strides: list(int)") .Attr("is_filter_const: bool = false") @@ -138,8 +138,7 @@ REGISTER_OP("_MklQuantizedMaxPool") .Doc(R"doc( MKL version of QuantizedMaxPool operator. 
Uses MKL DNN APIs to perform max pooling on the quantized input. - -NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); @@ -172,8 +171,7 @@ REGISTER_OP("_MklQuantizedAvgPool") .Doc(R"doc( MKL version of QuantizedAvgPool operator. Uses MKL DNN APIs to perform average pooling on the quantized input. - -NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); @@ -219,6 +217,9 @@ REGISTER_OP("_MklQuantizedConv2D") return Status::OK(); }); +// TODO(nammbash): Most of the TF_RETURN_IF_ERROR(c->WithRank) checks +// seems to be similar and hence can be moved into a single function +// with appropriate arguments for a cleaner design. REGISTER_OP("_MklQuantizedConv2DAndRequantize") .Input("input: Tinput") .Input("filter: Tfilter") @@ -258,8 +259,8 @@ REGISTER_OP("_MklQuantizedConv2DAndRequantize") ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); c->set_output(1, c->Scalar()); @@ -301,14 +302,14 @@ REGISTER_OP("_MklQuantizedConv2DWithBias") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); - c->set_output(1, c->Scalar()); - c->set_output(2, c->Scalar()); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); return Status::OK(); }); @@ -355,8 +356,8 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRequantize") TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused)); c->set_output(1, c->Scalar()); c->set_output(2, c->Scalar()); return Status::OK(); @@ -394,13 +395,13 @@ REGISTER_OP("_MklQuantizedConv2DAndRelu") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - c->set_output(1, c->Scalar()); - 
c->set_output(2, c->Scalar()); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); return Status::OK(); }); @@ -443,8 +444,8 @@ REGISTER_OP("_MklQuantizedConv2DAndReluAndRequantize") ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); c->set_output(1, c->Scalar()); @@ -486,14 +487,14 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRelu") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); - c->set_output(1, c->Scalar()); - c->set_output(2, c->Scalar()); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); return Status::OK(); }); @@ -540,8 +541,8 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndReluAndRequantize") TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); c->set_output(1, c->Scalar()); @@ -585,14 +586,14 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndRelu") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); - c->set_output(1, c->Scalar()); - c->set_output(2, c->Scalar()); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); return Status::OK(); }); @@ -646,8 +647,8 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize") TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, 
&unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); c->set_output(1, c->Scalar()); @@ -705,8 +706,8 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); c->set_output(1, c->Scalar()); @@ -714,6 +715,50 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") return Status::OK(); }); +REGISTER_OP("_MklQuantizedConv2DPerChannel") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("mkl_input: uint8") + .Input("mkl_filter: uint8") + .Input("mkl_min_input: uint8") + .Input("mkl_max_input: uint8") + .Input("mkl_min_filter: uint8") + .Input("mkl_max_filter: uint8") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Output("mkl_output: uint8") + .Output("mkl_min_output: uint8") + .Output("mkl_max_output: uint8") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("T: quantizedtype") // Additional attribute "T" for enabling MklToTf + // conversion + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("data_format: string = 'NHWC'") + .Attr("strides: list(int)") + .Attr("is_filter_const: bool = false") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused, channel; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); + return Status::OK(); + }) + .Doc(R"doc( +MKL-DNN implementation of QuantizedConv2D op. 
+)doc"); + REGISTER_OP("_MklDepthwiseConv2dNativeBackpropInput") .Input("input_sizes: int32") .Input("filter: T") @@ -758,6 +803,208 @@ REGISTER_OP("_MklDepthwiseConv2dNativeBackpropFilter") return Status::OK(); }); +REGISTER_OP("_MklQuantizedDepthwiseConv2D") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("mkl_input: uint8") + .Input("mkl_filter: uint8") + .Input("mkl_min_input: uint8") + .Input("mkl_max_input: uint8") + .Input("mkl_min_filter: uint8") + .Input("mkl_max_filter: uint8") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Output("mkl_output: uint8") + .Output("mkl_min_output: uint8") + .Output("mkl_max_output: uint8") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + // In order to enable MKL to TF conversion, _MklToTf op requires the + // attribute "T" to be specified. + .Attr("T: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("data_format: string = 'NHWC'") + .Attr("strides: list(int)") + .Attr("is_filter_const: bool = true") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + // TODO(bhavanis): Print an error message during the return. + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused, channel; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); + return Status::OK(); + }) + .Doc(R"doc( +MKL-DNN implementation of quantized depthwise Conv2D. +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke this operator. 
+)doc"); + +REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBias") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: float") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("mkl_input: uint8") + .Input("mkl_filter: uint8") + .Input("mkl_bias: uint8") + .Input("mkl_min_input: uint8") + .Input("mkl_max_input: uint8") + .Input("mkl_min_filter: uint8") + .Input("mkl_max_filter: uint8") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Output("mkl_output: uint8") + .Output("mkl_min_output: uint8") + .Output("mkl_max_output: uint8") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + // Additional attribute "T" for enabling MKL to TF conversion + .Attr("T: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("data_format: string = 'NHWC'") + .Attr("strides: list(int)") + .Attr("is_filter_const: bool = true") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused, channel; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); + return Status::OK(); + }) + .Doc(R"doc( +MKL-DNN implementation of quantized depthwise Conv2D with Bias. +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke this operator. +)doc"); + +REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBiasAndRelu") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: float") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("mkl_input: uint8") + .Input("mkl_filter: uint8") + .Input("mkl_bias: uint8") + .Input("mkl_min_input: uint8") + .Input("mkl_max_input: uint8") + .Input("mkl_min_filter: uint8") + .Input("mkl_max_filter: uint8") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Output("mkl_output: uint8") + .Output("mkl_min_output: uint8") + .Output("mkl_max_output: uint8") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + // Additional attribute "T" for enabling MKL to TF conversion + .Attr("T: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("data_format: string = 'NHWC'") + .Attr("strides: list(int)") + .Attr("is_filter_const: bool = true") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused, channel; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); + return Status::OK(); + }) + .Doc(R"doc( +MKL-DNN implementation of quantized depthwise Conv2D with Bias and Relu. +*NOTE*: Do not invoke this operator directly in Python. 
Graph rewrite pass is +expected to invoke this operator. +)doc"); + +REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: Tbias") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Input("mkl_input: uint8") + .Input("mkl_filter: uint8") + .Input("mkl_bias: uint8") + .Input("mkl_min_input: uint8") + .Input("mkl_max_input: uint8") + .Input("mkl_min_filter: uint8") + .Input("mkl_max_filter: uint8") + .Input("mkl_min_freezed_output: uint8") + .Input("mkl_max_freezed_output: uint8") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Output("mkl_output: uint8") + .Output("mkl_min_output: uint8") + .Output("mkl_max_output: uint8") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("Tbias: {float, qint32}") + // Additional attribute "T" for enabling MKL to TF conversion + .Attr("T: quantizedtype") + .Attr("out_type: quantizedtype = DT_QUINT8") + .Attr("data_format: string = 'NHWC'") + .Attr("strides: list(int)") + .Attr("is_filter_const: bool = true") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +MKL-DNN implementation of quantized depthwise Conv2D with Bias, Relu and Requantize. +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke this operator. +)doc"); + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/ops/nn_grad.cc b/tensorflow/core/ops/nn_grad.cc index 560b71a337e..7beaf57c10b 100644 --- a/tensorflow/core/ops/nn_grad.cc +++ b/tensorflow/core/ops/nn_grad.cc @@ -37,18 +37,41 @@ Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { { {{"softmax"}, "Softmax", {"x"}, {{"T", "$T"}}}, {{"n0"}, "Mul", {"grad_softmax", "softmax"}, {{"T", "$T"}}}, - FDH::Const("indices", {1}), - {{"n1"}, "Sum", {"n0", "indices"}, {{"T", "$T"}}}, - FDH::Const("newshape", {-1, 1}), - {{"n2"}, "Reshape", {"n1", "newshape"}, {{"T", "$T"}}}, - {{"n3"}, "Sub", {"grad_softmax", "n2"}, {{"T", "$T"}}}, - {{"grad_x"}, "Mul", {"n3", "softmax"}, {{"T", "$T"}}} + FDH::Const("indices", {-1}), + {{"n1"}, "Sum", {"n0", "indices"}, {{"keep_dims", true}, {"T", "$T"}}}, + {{"n2"}, "Sub", {"grad_softmax", "n1"}, {{"T", "$T"}}}, + {{"grad_x"}, "Mul", {"n2", "softmax"}, {{"T", "$T"}}} }); // clang-format on return Status::OK(); } REGISTER_OP_GRADIENT("Softmax", SoftmaxGrad); +Status LogSoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + "LogSoftmaxGrad", + // Arg defs + {"x: T", "grad_logsoftmax: T"}, + // Ret val defs + {"grad_x: T"}, + // Attr defs + {{"T: {float, double}"}}, + // Nodes + // Based on _LogSoftmaxGrad in nn_grad.py. 
+ { + {{"softmax"}, "Softmax", {"x"}, {{"T", "$T"}}}, + FDH::Const("indices", {-1}), + {{"n0"}, "Sum", {"grad_logsoftmax", "indices"}, + {{"keep_dims", true}, {"T", "$T"}}}, + {{"n1"}, "Mul", {"n0", "softmax"}, {{"T", "$T"}}}, + {{"grad_x"}, "Sub", {"grad_logsoftmax", "n1"}, {{"T", "$T"}}} + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("LogSoftmax", LogSoftmaxGrad); + Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 2b1d031be86..4d248b9f0ea 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op.h" @@ -45,7 +46,7 @@ Status FractionalPoolShapeFn(InferenceContext* c) { if (c->ValueKnown(d)) { // This must match the same logic in the kernel function in // core/kernels/fractional_max_pool_op.cc. - auto val = static_cast(floor(c->Value(d) / pooling_ratio[i])); + auto val = static_cast(std::floor(c->Value(d) / pooling_ratio[i])); if (val < 0) { return errors::InvalidArgument("Size computed for dim ", i, " is negative: ", val); @@ -326,7 +327,8 @@ REGISTER_OP("_FusedConv2D") .Attr("T: {float, double}") .Attr("num_args: int >= 0") .Attr("strides: list(int)") - .Attr(GetPaddingAttrString()) + .Attr(GetPaddingAttrStringWithExplicit()) + .Attr(GetExplicitPaddingsAttrString()) .Attr(GetConvnetDataFormatAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") .Attr("use_cudnn_on_gpu: bool = true") @@ -334,7 +336,7 @@ REGISTER_OP("_FusedConv2D") // Attributes for the FusedBatchNorm ------------------------------------ // .Attr("epsilon: float = 0.0001") // ---------------------------------------------------------------------- // - .SetShapeFn(shape_inference::Conv2DShape) + .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding) .Doc(R"doc( *NOTE*: Do not invoke this operator directly in Python. Grappler is expected to create these operators. 
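// As a reference for the SoftmaxGrad and LogSoftmaxGrad FunctionDefs
// registered in nn_grad.cc above, the following standalone sketch
// (illustrative only, not TensorFlow code) evaluates the same per-row
// formulas:
//   softmax:     dx = (dy - sum(dy * y)) * y,  where y = softmax(x)
//   log-softmax: dx = dy - y * sum(dy)
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

std::vector<double> Softmax(const std::vector<double>& x) {
  double max_x = x[0];
  for (double v : x) max_x = std::max(max_x, v);
  std::vector<double> y(x.size());
  double denom = 0.0;
  for (size_t i = 0; i < x.size(); ++i) {
    denom += (y[i] = std::exp(x[i] - max_x));
  }
  for (double& v : y) v /= denom;
  return y;
}

// dx_i = y_i * (dy_i - sum_j dy_j * y_j)
std::vector<double> SoftmaxGradRow(const std::vector<double>& x,
                                   const std::vector<double>& dy) {
  const std::vector<double> y = Softmax(x);
  double s = 0.0;
  for (size_t i = 0; i < y.size(); ++i) s += dy[i] * y[i];
  std::vector<double> dx(y.size());
  for (size_t i = 0; i < y.size(); ++i) dx[i] = (dy[i] - s) * y[i];
  return dx;
}

// dx_i = dy_i - y_i * sum_j dy_j
std::vector<double> LogSoftmaxGradRow(const std::vector<double>& x,
                                      const std::vector<double>& dy) {
  const std::vector<double> y = Softmax(x);
  double s = 0.0;
  for (double v : dy) s += v;
  std::vector<double> dx(y.size());
  for (size_t i = 0; i < y.size(); ++i) dx[i] = dy[i] - y[i] * s;
  return dx;
}

int main() {
  const std::vector<double> x = {0.5, -1.0, 2.0};
  const std::vector<double> dy = {1.0, 0.0, -1.0};
  for (double v : SoftmaxGradRow(x, dy)) std::cout << v << " ";
  std::cout << "\n";
  for (double v : LogSoftmaxGradRow(x, dy)) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}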
@@ -1572,7 +1574,7 @@ REGISTER_OP("_MklConv2D") .Output("filter_output: T") .Output("mkl_output: uint8") .Output("mkl_filter_output: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr("is_filter_const: bool = false") @@ -1592,7 +1594,7 @@ REGISTER_OP("__MklDummyConv2DWithBias") .Input("filter: T") .Input("bias: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr("is_filter_const: bool = false") @@ -1620,7 +1622,7 @@ REGISTER_OP("_MklConv2DWithBias") .Output("filter_output: T") .Output("mkl_output: uint8") .Output("mkl_filter_output: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr("is_filter_const: bool = false") @@ -1641,7 +1643,7 @@ REGISTER_OP("__MklDummyPadWithConv2D") .Input("filter: T") .Input("paddings: Tpaddings") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr("is_filter_const: bool = false") @@ -1670,7 +1672,7 @@ REGISTER_OP("_MklPadWithConv2D") .Output("filter_output: T") .Output("mkl_output: uint8") .Output("mkl_filter_output: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) @@ -1696,7 +1698,7 @@ REGISTER_OP("_MklConv2DBackpropFilter") .Input("mkl_out_backprop: uint8") .Output("output: T") .Output("mkl_output: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) @@ -1723,7 +1725,7 @@ REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias") .Input("out_backprop: T") .Output("output: T") .Output("bias_grad: T") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) @@ -1768,7 +1770,7 @@ REGISTER_OP("_MklConv2DBackpropFilterWithBias") .Output("bias_grad: T") .Output("mkl_output: uint8") .Output("mkl_bias_grad: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) @@ -1829,7 +1831,7 @@ REGISTER_OP("_MklConv2DBackpropInput") .Input("mkl_out_backprop: uint8") .Output("output: T") .Output("mkl_output: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) @@ -1859,7 +1861,7 @@ REGISTER_OP("_MklConv3D") .Output("filter_output: T") .Output("mkl_output: uint8") .Output("mkl_filter_output: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int) >= 5") .Attr("is_filter_const: bool = false") .Attr(GetPaddingAttrString()) @@ -1882,7 +1884,7 @@ REGISTER_OP("_MklConv3DBackpropInputV2") .Input("mkl_out_backprop: uint8") .Output("output: T") .Output("mkl_output: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int) >= 5") .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") .Attr("Tshape: {int32, int64} = DT_INT32") @@ -1912,7 +1914,7 @@ REGISTER_OP("_MklConv3DBackpropFilterV2") .Input("mkl_out_backprop: uint8") 
.Output("output: T") .Output("mkl_output: uint8") - .Attr("T: {half, float, double}") + .Attr("T: {bfloat16, float}") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) @@ -1937,7 +1939,7 @@ REGISTER_OP("_MklRelu") .Input("mkl_features: uint8") .Output("activations: T") .Output("mkl_activations: uint8") - .Attr("T: realnumbertype") + .Attr("T: {float, bfloat16} = DT_FLOAT") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator. @@ -1953,7 +1955,7 @@ REGISTER_OP("_MklReluGrad") .Input("mkl_features: uint8") .Output("backprops: T") .Output("mkl_backprops: uint8") - .Attr("T: realnumbertype") + .Attr("T: {float, bfloat16} = DT_FLOAT") .SetShapeFn(shape_inference::MergeBothInputsShapeFn) .Doc(R"doc( MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified @@ -1968,7 +1970,7 @@ REGISTER_OP("_MklRelu6") .Input("mkl_features: uint8") .Output("activations: T") .Output("mkl_activations: uint8") - .Attr("T: realnumbertype") + .Attr("T: {float, bfloat16} = DT_FLOAT") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( MKL version of Relu6 operator. Uses MKL DNN APIs to implement Relu6 operator. @@ -1984,7 +1986,7 @@ REGISTER_OP("_MklRelu6Grad") .Input("mkl_features: uint8") .Output("backprops: T") .Output("mkl_backprops: uint8") - .Attr("T: realnumbertype") + .Attr("T: {float, bfloat16} = DT_FLOAT") .SetShapeFn(shape_inference::MergeBothInputsShapeFn) .Doc(R"doc( MKL version of Relu6Grad operator. Uses MKL DNN APIs to compute rectified @@ -1999,7 +2001,7 @@ REGISTER_OP("_MklLeakyRelu") .Input("mkl_features: uint8") .Output("activations: T") .Output("mkl_activations: uint8") - .Attr("T: {half, float, double} = DT_FLOAT") + .Attr("T: {float, bfloat16} = DT_FLOAT") .Attr("alpha: float = 0.2") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( @@ -2017,7 +2019,7 @@ REGISTER_OP("_MklLeakyReluGrad") .Input("mkl_features: uint8") .Output("backprops: T") .Output("mkl_backprops: uint8") - .Attr("T: {half, float, double} = DT_FLOAT") + .Attr("T: {float, bfloat16} = DT_FLOAT") .Attr("alpha: float = 0.2") .SetShapeFn(shape_inference::MergeBothInputsShapeFn) .Doc(R"doc( @@ -2033,7 +2035,7 @@ REGISTER_OP("_MklElu") .Input("mkl_features: uint8") .Output("activations: T") .Output("mkl_activations: uint8") - .Attr("T: realnumbertype") + .Attr("T: {float, bfloat16} = DT_FLOAT") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator. @@ -2048,7 +2050,7 @@ REGISTER_OP("_MklEluGrad") .Input("mkl_features: uint8") .Output("backprops: T") .Output("mkl_backprops: uint8") - .Attr("T: realnumbertype") + .Attr("T: {float, bfloat16} = DT_FLOAT") .SetShapeFn(shape_inference::MergeBothInputsShapeFn) .Doc(R"doc( MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu @@ -2101,7 +2103,7 @@ expected to invoke these operators. )doc"); REGISTER_OP("_MklMaxPool") - .Attr("T: {float, half} = DT_FLOAT") + .Attr("T: {float, half, bfloat16} = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) @@ -2127,7 +2129,7 @@ expected to invoke these operators. 
)doc"); REGISTER_OP("_MklMaxPoolGrad") - .Attr("T: {float, half} = DT_FLOAT") + .Attr("T: {float, half, bfloat16} = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr("workspace_enabled: bool = false") @@ -2167,7 +2169,7 @@ REGISTER_OP("_MklAvgPool") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Attr("T: {float, half, double}") + .Attr("T: {float, half, double, bfloat16}") .SetShapeFn(shape_inference::AvgPoolShape) .Doc(R"doc( MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling @@ -2188,7 +2190,7 @@ REGISTER_OP("_MklAvgPoolGrad") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Attr("T: {float, half, double}") + .Attr("T: {float, half, double, bfloat16}") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); @@ -2213,7 +2215,7 @@ REGISTER_OP("_MklAvgPool3D") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) - .Attr("T: {float, half, double}") + .Attr("T: {float, half, double, bfloat16}") .SetShapeFn(shape_inference::Pool3DShape) .Doc(R"doc( MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling @@ -2234,7 +2236,7 @@ REGISTER_OP("_MklAvgPool3DGrad") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) - .Attr("T: {float, half, double}") + .Attr("T: {float, half, double, bfloat16}") .SetShapeFn([](InferenceContext* c) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); @@ -2506,11 +2508,67 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. 
)doc"); +REGISTER_OP("_MklFusedBatchNormV2") + .Input("x: T") + .Input("scale: U") + .Input("offset: U") + .Input("mean: U") + .Input("variance: U") + .Input("mkl_x: uint8") + .Input("mkl_scale: uint8") + .Input("mkl_offset: uint8") + .Input("mkl_mean: uint8") + .Input("mkl_variance: uint8") + .Output("y: T") + .Output("batch_mean: U") + .Output("batch_variance: U") + .Output("reserve_space_1: U") + .Output("reserve_space_2: U") + .Output("mkl_y: uint8") + .Output("mkl_batch_mean: uint8") + .Output("mkl_batch_variance: uint8") + .Output("mkl_reserve_space_1: uint8") + .Output("mkl_reserve_space_2: uint8") + .Attr("T: {bfloat16, float}") + .Attr("U: {float}") + .Attr("epsilon: float = 0.0001") + .Attr(GetConvnetDataFormatAttrString()) + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormShape); + +REGISTER_OP("_MklFusedBatchNormGradV2") + .Input("y_backprop: T") + .Input("x: T") + .Input("scale: float") + .Input("reserve_space_1: U") + .Input("reserve_space_2: U") + .Input("mkl_y_backprop: uint8") + .Input("mkl_x: uint8") + .Input("mkl_scale: uint8") + .Input("mkl_reserve_space_1: uint8") + .Input("mkl_reserve_space_2: uint8") + .Output("x_backprop: T") + .Output("scale_backprop: U") + .Output("offset_backprop: U") + .Output("reserve_space_3: U") + .Output("reserve_space_4: U") + .Output("mkl_x_backprop: uint8") + .Output("mkl_scale_backprop: uint8") + .Output("mkl_offset_backprop: uint8") + .Output("mkl_reserve_space_3: uint8") + .Output("mkl_reserve_space_4: uint8") + .Attr("T: {bfloat16, float}") + .Attr("U: {float}") + .Attr("epsilon: float = 0.0001") + .Attr(GetConvnetDataFormatAttrString()) + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormGradShape); + REGISTER_OP("_MklToTf") .Input("input: T") .Input("mkl_input: uint8") .Output("output: T") - .Attr("T: {half, float, double, qint8, quint8, qint32}") + .Attr("T: {half, float, double, bfloat16, qint8, quint8, qint32}") .Attr(GetConvnetDataFormat2D3DAttrString()) .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( @@ -2531,8 +2589,8 @@ REGISTER_OP("_MklInputConversion") .Output("mkl_output_1: uint8") // All datatypes supported by element-wise ops .Attr( - "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " - "complex64, complex128}") + "T: {half, float, bfloat16, double, uint8, int8, uint16, int16, int32, " + "int64, complex64, complex128}") .Attr(GetConvnetDataFormat2D3DAttrString()) .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( @@ -2599,14 +2657,14 @@ REGISTER_OP("QuantizedConv2DWithBias") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); - c->set_output(1, c->Scalar()); - c->set_output(2, c->Scalar()); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); return Status::OK(); }); @@ -2633,12 +2691,12 @@ REGISTER_OP("QuantizedConv2DWithBiasAndRequantize") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - 
ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); c->set_output(1, c->Scalar()); @@ -2666,13 +2724,13 @@ REGISTER_OP("QuantizedConv2DAndRelu") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - c->set_output(1, c->Scalar()); - c->set_output(2, c->Scalar()); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); return Status::OK(); }); @@ -2697,11 +2755,11 @@ REGISTER_OP("QuantizedConv2DAndReluAndRequantize") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); c->set_output(1, c->Scalar()); @@ -2730,14 +2788,14 @@ REGISTER_OP("QuantizedConv2DWithBiasAndRelu") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); - c->set_output(1, c->Scalar()); - c->set_output(2, c->Scalar()); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); return Status::OK(); }); @@ -2765,12 +2823,12 @@ REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + 
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); c->set_output(1, c->Scalar()); @@ -2800,14 +2858,14 @@ REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); - c->set_output(1, c->Scalar()); - c->set_output(2, c->Scalar()); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); return Status::OK(); }); @@ -2838,12 +2896,12 @@ REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); c->set_output(1, c->Scalar()); @@ -2878,17 +2936,125 @@ REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize") .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); - ShapeHandle unused; + ShapeHandle unused, channel; TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); + // Since activations are not requantized per channel, `min_output` + // and `max_output` are scalars. 
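+      // (For the non-requantized per-channel variants, `min_filter` and
+      // `max_filter` may instead be rank-1 tensors of length `channels`, and
+      // that shape is forwarded to `min_output`/`max_output`.)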
c->set_output(1, c->Scalar()); c->set_output(2, c->Scalar()); return Status::OK(); }); +REGISTER_OP("QuantizedConv2DPerChannel") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + ShapeHandle unused, channel; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); + c->set_output(1, channel); + c->set_output(2, channel); + return Status::OK(); + }); + +REGISTER_OP("QuantizedDepthwiseConv2D") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); + +REGISTER_OP("QuantizedDepthwiseConv2DWithBias") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: float") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); + +REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndRelu") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: float") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("out_type: quantizedtype = DT_QINT32") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); + +REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize") + .Input("input: Tinput") + .Input("filter: Tfilter") + .Input("bias: Tbias") + .Input("min_input: float") + .Input("max_input: float") + .Input("min_filter: float") + .Input("max_filter: float") + .Input("min_freezed_output: float") + .Input("max_freezed_output: float") + .Output("output: out_type") + .Output("min_output: float") + .Output("max_output: float") + .Attr("Tinput: quantizedtype") + .Attr("Tfilter: quantizedtype") + .Attr("Tbias: {float, qint32}") + .Attr("out_type: quantizedtype = DT_QUINT8") + .Attr("strides: list(int)") + 
.Attr(GetPaddingAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); + } // namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 285d39bd931..c1cc30da06b 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -761,6 +761,30 @@ op { } is_stateful: true } +op { + name: "AnonymousIteratorV2" + output_arg { + name: "handle" + type: DT_RESOURCE + } + output_arg { + name: "deleter" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Any" input_arg { @@ -3116,6 +3140,13 @@ op { name: "handle" type: DT_VARIANT } + attr { + name: "parallel_copy" + type: "bool" + default_value { + b: false + } + } attr { name: "output_types" type: "list(type)" @@ -3348,6 +3379,51 @@ op { } } } +op { + name: "BatchMatMulV2" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + attr { + name: "adj_x" + type: "bool" + default_value { + b: false + } + } + attr { + name: "adj_y" + type: "bool" + default_value { + b: false + } + } +} op { name: "BatchMatrixBandPart" input_arg { @@ -4364,6 +4440,41 @@ op { } is_commutative: true } +op { + name: "BoostedTreesAggregateStats" + input_arg { + name: "node_ids" + type: DT_INT32 + } + input_arg { + name: "gradients" + type: DT_FLOAT + } + input_arg { + name: "hessians" + type: DT_FLOAT + } + input_arg { + name: "feature" + type: DT_INT32 + } + output_arg { + name: "stats_summary" + type: DT_FLOAT + } + attr { + name: "max_splits" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "num_buckets" + type: "int" + has_minimum: true + minimum: 1 + } +} op { name: "BoostedTreesBucketize" input_arg { @@ -4387,6 +4498,79 @@ op { has_minimum: true } } +op { + name: "BoostedTreesCalculateBestFeatureSplit" + input_arg { + name: "node_id_range" + type: DT_INT32 + } + input_arg { + name: "stats_summary" + type: DT_FLOAT + } + input_arg { + name: "l1" + type: DT_FLOAT + } + input_arg { + name: "l2" + type: DT_FLOAT + } + input_arg { + name: "tree_complexity" + type: DT_FLOAT + } + input_arg { + name: "min_node_weight" + type: DT_FLOAT + } + output_arg { + name: "node_ids" + type: DT_INT32 + } + output_arg { + name: "gains" + type: DT_FLOAT + } + output_arg { + name: "feature_dimensions" + type: DT_INT32 + } + output_arg { + name: "thresholds" + type: DT_INT32 + } + output_arg { + name: "left_node_contribs" + type: DT_FLOAT + } + output_arg { + name: "right_node_contribs" + type: DT_FLOAT + } + output_arg { + name: "split_with_default_directions" + type: DT_STRING + } + attr { + name: "logits_dimension" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "split_type" + type: "string" + default_value { + s: "inequality" + } + allowed_values { + list { + s: "inequality" + } + } + } +} op { name: "BoostedTreesCalculateBestGainsPerFeature" input_arg { @@ -5360,6 +5544,64 @@ op { } } } +op { + name: "ChooseFastestBranchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "ratio_numerator" + 
type: DT_INT64 + } + input_arg { + name: "ratio_denominator" + type: DT_INT64 + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "num_elements_per_branch" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "branches" + type: "list(func)" + has_minimum: true + minimum: 1 + } + attr { + name: "other_arguments_lengths" + type: "list(int)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "ClipByValue" input_arg { @@ -5691,6 +5933,13 @@ op { b: false } } + attr { + name: "clip_boxes" + type: "bool" + default_value { + b: true + } + } } op { name: "CompareAndBitpack" @@ -7702,6 +7951,13 @@ op { i: 0 } } + attr { + name: "time_major" + type: "bool" + default_value { + b: true + } + } is_stateful: true } op { @@ -8269,6 +8525,13 @@ op { b: true } } + attr { + name: "time_major" + type: "bool" + default_value { + b: true + } + } is_stateful: true } op { @@ -8968,6 +9231,45 @@ op { } } } +op { + name: "DecodePaddedRaw" + input_arg { + name: "input_bytes" + type: DT_STRING + } + input_arg { + name: "fixed_length" + type: DT_INT32 + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "out_type" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT16 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_INT64 + } + } + } + attr { + name: "little_endian" + type: "bool" + default_value { + b: true + } + } +} op { name: "DecodePng" input_arg { @@ -9074,6 +9376,7 @@ op { type: DT_INT64 type: DT_COMPLEX64 type: DT_COMPLEX128 + type: DT_BOOL } } } @@ -9130,6 +9433,18 @@ op { } is_stateful: true } +op { + name: "DeleteIterator" + input_arg { + name: "handle" + type: DT_RESOURCE + } + input_arg { + name: "deleter" + type: DT_VARIANT + } + is_stateful: true +} op { name: "DeleteSessionTensor" input_arg { @@ -9947,8 +10262,11 @@ op { type: "type" allowed_values { list { + type: DT_HALF type: DT_FLOAT type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 } } } @@ -9981,6 +10299,38 @@ op { } } } +op { + name: "DrawBoundingBoxesV2" + input_arg { + name: "images" + type_attr: "T" + } + input_arg { + name: "boxes" + type: DT_FLOAT + } + input_arg { + name: "colors" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } +} op { name: "DynamicPartition" input_arg { @@ -10309,6 +10659,21 @@ op { } } } +op { + name: "EncodeJpegVariableQuality" + input_arg { + name: "images" + type: DT_UINT8 + } + input_arg { + name: "quality" + type: DT_INT32 + } + output_arg { + name: "contents" + type: DT_STRING + } +} op { name: "EncodePng" input_arg { @@ -10585,6 +10950,14 @@ op { name: "table_ids" type: "list(int)" } + attr { + name: "max_sequence_lengths" + type: "list(int)" + default_value { + list { + } + } + } is_stateful: true } op { @@ -15205,6 +15578,20 @@ op { } is_stateful: true } +op { + name: "InfeedEnqueuePrelinearizedBuffer" + input_arg { + name: "input" + type: DT_VARIANT + } + attr { + name: "device_ordinal" + type: "int" + default_value { + i: -1 + } + } +} op { name: 
"InfeedEnqueueTuple" input_arg { @@ -16250,7 +16637,6 @@ op { } } } - is_commutative: true } op { name: "Less" @@ -17823,6 +18209,13 @@ op { name: "f" type: "func" } + attr { + name: "max_intra_op_parallelism" + type: "int" + default_value { + i: 1 + } + } } op { name: "MapIncompleteSize" @@ -19849,6 +20242,13 @@ op { name: "handle" type: DT_VARIANT } + attr { + name: "cpu_budget" + type: "int" + default_value { + i: 0 + } + } attr { name: "output_types" type: "list(type)" @@ -19917,12 +20317,14 @@ op { type: "type" allowed_values { list { + type: DT_HALF type: DT_FLOAT type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 } } } - is_commutative: true } op { name: "MultiDeviceIterator" @@ -20698,6 +21100,32 @@ op { op { name: "NoOp" } +op { + name: "NonDeterministicInts" + input_arg { + name: "shape" + type_attr: "shape_dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + attr { + name: "shape_dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + is_stateful: true +} op { name: "NonMaxSuppression" input_arg { @@ -21104,6 +21532,14 @@ op { has_minimum: true minimum: 1 } + attr { + name: "optimization_configs" + type: "list(string)" + default_value { + list { + } + } + } } op { name: "OptionalFromValue" @@ -21737,6 +22173,13 @@ op { name: "handle" type: DT_VARIANT } + attr { + name: "parallel_copy" + type: "bool" + default_value { + b: false + } + } attr { name: "Toutput_types" type: "list(type)" @@ -22846,6 +23289,66 @@ op { minimum: 1 } } +op { + name: "Prelinearize" + input_arg { + name: "input" + type_attr: "dtype" + } + output_arg { + name: "output" + type: DT_VARIANT + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + } + } + } + attr { + name: "layout" + type: "list(int)" + default_value { + list { + } + } + } +} +op { + name: "PrelinearizeTuple" + input_arg { + name: "inputs" + type_list_attr: "dtypes" + } + output_arg { + name: "output" + type: DT_VARIANT + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "shapes" + type: "list(shape)" + } + attr { + name: "layouts" + type: "list(int)" + default_value { + list { + } + } + } +} op { name: "PreventGradient" input_arg { @@ -23538,7 +24041,6 @@ op { } } } - is_commutative: true } op { name: "QuantizedAvgPool" @@ -24297,6 +24799,113 @@ op { } } } +op { + name: "QuantizedConv2DPerChannel" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 
+ type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} op { name: "QuantizedConv2DWithBias" input_arg { @@ -25256,6 +25865,464 @@ op { } } } +op { + name: "QuantizedDepthwiseConv2D" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +op { + name: "QuantizedDepthwiseConv2DWithBias" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "bias" + type: DT_FLOAT + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +op { + name: "QuantizedDepthwiseConv2DWithBiasAndRelu" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "bias" + type: DT_FLOAT + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + 
} + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QINT32 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} +op { + name: "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize" + input_arg { + name: "input" + type_attr: "Tinput" + } + input_arg { + name: "filter" + type_attr: "Tfilter" + } + input_arg { + name: "bias" + type_attr: "Tbias" + } + input_arg { + name: "min_input" + type: DT_FLOAT + } + input_arg { + name: "max_input" + type: DT_FLOAT + } + input_arg { + name: "min_filter" + type: DT_FLOAT + } + input_arg { + name: "max_filter" + type: DT_FLOAT + } + input_arg { + name: "min_freezed_output" + type: DT_FLOAT + } + input_arg { + name: "max_freezed_output" + type: DT_FLOAT + } + output_arg { + name: "output" + type_attr: "out_type" + } + output_arg { + name: "min_output" + type: DT_FLOAT + } + output_arg { + name: "max_output" + type: DT_FLOAT + } + attr { + name: "Tinput" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tfilter" + type: "type" + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "Tbias" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_QINT32 + } + } + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_QUINT8 + } + allowed_values { + list { + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_QINT16 + type: DT_QUINT16 + } + } + } + attr { + name: "strides" + type: "list(int)" + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "dilations" + type: "list(int)" + default_value { + list { + i: 1 + i: 1 + i: 1 + i: 1 + } + } + } +} op { name: "QuantizedInstanceNorm" input_arg { @@ -25580,7 +26647,6 @@ op { } } } - is_commutative: true } op { name: "QuantizedRelu" @@ -29686,6 +30752,13 @@ op { name: "output" type_attr: "dtype" } + attr { + name: "batch_dims" + type: "int" + default_value { + i: 0 + } + } attr { name: "validate_indices" type: "bool" @@ -29709,6 +30782,36 @@ op { } is_stateful: true } +op { + name: "ResourceGatherNd" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: 
DT_INT64 + } + } + } + is_stateful: true +} op { name: "ResourceScatterAdd" input_arg { @@ -32023,7 +33126,6 @@ op { } } } - is_commutative: true } op { name: "Rint" @@ -32394,6 +33496,41 @@ op { } is_stateful: true } +op { + name: "SamplingDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "rate" + type: DT_FLOAT + } + input_arg { + name: "seed" + type: DT_INT64 + } + input_arg { + name: "seed2" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "Save" input_arg { @@ -38264,6 +39401,76 @@ op { } is_stateful: true } +op { + name: "StatefulRandomBinomial" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "algorithm" + type: DT_INT64 + } + input_arg { + name: "shape" + type_attr: "S" + } + input_arg { + name: "counts" + type_attr: "T" + } + input_arg { + name: "probs" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "S" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_DOUBLE + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} op { name: "StatefulStandardNormal" input_arg { @@ -38292,6 +39499,10 @@ op { type: DT_INT64 } } + deprecation { + version: 29 + explanation: "Use StatefulStandardNormalV2 instead" + } is_stateful: true } op { @@ -38328,6 +39539,74 @@ op { } is_stateful: true } +op { + name: "StatefulTruncatedNormal" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "algorithm" + type: DT_INT64 + } + input_arg { + name: "shape" + type_attr: "shape_dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + } + attr { + name: "shape_dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + is_stateful: true +} +op { + name: "StatefulUniform" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "algorithm" + type: DT_INT64 + } + input_arg { + name: "shape" + type_attr: "shape_dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + } + attr { + name: "shape_dtype" + type: "type" + default_value { + type: DT_INT64 + } + } + is_stateful: true +} op { name: "StatefulUniformFullInt" input_arg { @@ -38795,6 +40074,40 @@ op { } } } +op { + name: "StatsAggregatorHandleV2" + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { + name: "StatsAggregatorSetSummaryWriter" + input_arg { + name: "stats_aggregator" + type: DT_RESOURCE + } + input_arg { + name: "summary" + type: DT_RESOURCE + } + is_stateful: true +} op { name: "StopGradient" input_arg { @@ -41762,6 +43075,82 @@ op { } is_stateful: true } +op { + name: 
"TensorStridedSliceUpdate" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "begin" + type_attr: "Index" + } + input_arg { + name: "end" + type_attr: "Index" + } + input_arg { + name: "strides" + type_attr: "Index" + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "begin_mask" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "end_mask" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "ellipsis_mask" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "new_axis_mask" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "shrink_axis_mask" + type: "int" + default_value { + i: 0 + } + } +} op { name: "TensorSummary" input_arg { @@ -43726,6 +45115,22 @@ op { } is_stateful: true } +op { + name: "WriteRawProtoSummary" + input_arg { + name: "writer" + type: DT_RESOURCE + } + input_arg { + name: "step" + type: DT_INT64 + } + input_arg { + name: "tensor" + type: DT_STRING + } + is_stateful: true +} op { name: "WriteScalarSummary" input_arg { diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 169076a6f67..ff87544db81 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -20,6 +20,7 @@ limitations under the License. namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; @@ -29,7 +30,7 @@ REGISTER_OP("DecodeRaw") .Attr( "out_type: " "{half,float,double,int32,uint16,uint8,int16,int8,int64,complex64," - "complex128}") + "complex128,bool}") .Attr("little_endian: bool = true") .SetShapeFn([](InferenceContext* c) { // Note: last dimension is data dependent. 
@@ -40,6 +41,31 @@ REGISTER_OP("DecodeRaw") return Status::OK(); }); +REGISTER_OP("DecodePaddedRaw") + .Input("input_bytes: string") + .Input("fixed_length: int32") + .Output("output: out_type") + .Attr("out_type: {half,float,double,int32,uint16,uint8,int16,int8,int64}") + .Attr("little_endian: bool = true") + .SetShapeFn([](InferenceContext* c) { + DimensionHandle fixed_length; + TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &fixed_length)); + + DataType out_type; + TF_RETURN_IF_ERROR(c->GetAttr("out_type", &out_type)); + + int32 data_type_size = DataTypeSize(out_type); + + DimensionHandle width; + TF_RETURN_IF_ERROR(c->Divide(fixed_length, data_type_size, true, &width)); + + ShapeHandle out; + TF_RETURN_IF_ERROR(c->Concatenate(c->input(0), c->Vector(width), &out)); + + c->set_output(0, out); + return Status::OK(); + }); + REGISTER_OP("DecodeCompressed") .Input("bytes: string") .Output("output: string") diff --git a/tensorflow/core/ops/ragged_array_ops.cc b/tensorflow/core/ops/ragged_array_ops.cc index 46425799399..1e888907ae9 100644 --- a/tensorflow/core/ops/ragged_array_ops.cc +++ b/tensorflow/core/ops/ragged_array_ops.cc @@ -29,13 +29,14 @@ Status RaggedGatherShapeFn(InferenceContext* c); //============================================================================== REGISTER_OP("RaggedGather") - .Input("params_nested_splits: PARAMS_RAGGED_RANK * int64") + .Input("params_nested_splits: PARAMS_RAGGED_RANK * Tsplits") .Input("params_dense_values: Tvalues") .Input("indices: Tindices") - .Output("output_nested_splits: OUTPUT_RAGGED_RANK * int64") + .Output("output_nested_splits: OUTPUT_RAGGED_RANK * Tsplits") .Output("output_dense_values: Tvalues") .Attr("Tvalues: type") .Attr("Tindices: {int32, int64}") + .Attr("Tsplits: {int32, int64} = DT_INT64") .Attr("PARAMS_RAGGED_RANK: int >= 1") .Attr("OUTPUT_RAGGED_RANK: int >= 0") .SetShapeFn(RaggedGatherShapeFn); diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc index 90fd51717fa..5794b89a64e 100644 --- a/tensorflow/core/ops/ragged_conversion_ops.cc +++ b/tensorflow/core/ops/ragged_conversion_ops.cc @@ -23,21 +23,44 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; Status RaggedTensorToSparseShapeFn(InferenceContext* c); +Status RaggedTensorToVariantShapeFn(InferenceContext* c); +Status RaggedTensorFromVariantShapeFn(InferenceContext* c); //============================================================================== // Registered Ops //============================================================================== REGISTER_OP("RaggedTensorToSparse") - .Input("rt_nested_splits: RAGGED_RANK * int64") + .Input("rt_nested_splits: RAGGED_RANK * Tsplits") .Input("rt_dense_values: T") .Output("sparse_indices: int64") .Output("sparse_values: T") .Output("sparse_dense_shape: int64") .Attr("RAGGED_RANK: int >= 1") .Attr("T: type") + .Attr("Tsplits: {int32, int64} = DT_INT64") .SetShapeFn(RaggedTensorToSparseShapeFn); +REGISTER_OP("RaggedTensorToVariant") + .Input("rt_nested_splits: RAGGED_RANK * Tsplits") + .Input("rt_dense_values: Tvalues") + .Output("encoded_ragged: variant") + .Attr("RAGGED_RANK: int >= 1") + .Attr("Tvalues: type") + .Attr("Tsplits: {int32, int64}") + .Attr("batched_input: bool") + .SetShapeFn(RaggedTensorToVariantShapeFn); + +REGISTER_OP("RaggedTensorFromVariant") + .Input("encoded_ragged: variant") + .Output("output_nested_splits: output_ragged_rank * Tsplits") + .Output("output_dense_values: Tvalues") + .Attr("input_ragged_rank: int >= -1") + 
.Attr("output_ragged_rank: int >= 1") + .Attr("Tvalues: type") + .Attr("Tsplits: {int32, int64}") + .SetShapeFn(RaggedTensorFromVariantShapeFn); + //============================================================================== // Shape Functions //============================================================================== @@ -71,4 +94,46 @@ Status RaggedTensorToSparseShapeFn(InferenceContext* c) { return Status::OK(); } +Status RaggedTensorToVariantShapeFn(InferenceContext* c) { + int64 num_splits; + TF_RETURN_IF_ERROR(c->GetAttr("RAGGED_RANK", &num_splits)); + bool batched; + TF_RETURN_IF_ERROR(c->GetAttr("batched_input", &batched)); + shape_inference::ShapeHandle rt_dense_values = c->input(num_splits); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values)); + for (int64 i = 0; i < num_splits; ++i) { + shape_inference::ShapeHandle splits = c->input(i); + TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits)); + } + if (batched) { + auto num_first_splits = c->Dim(c->input(0), 0); + shape_inference::DimensionHandle num_rows; + TF_RETURN_IF_ERROR(c->Subtract(num_first_splits, 1, &num_rows)); + c->set_output(0, c->Vector(num_rows)); + } else { + c->set_output(0, c->Scalar()); + } + return Status::OK(); +} + +Status RaggedTensorFromVariantShapeFn(InferenceContext* c) { + int64 input_ragged_rank; + TF_RETURN_IF_ERROR( + c->GetAttr("input_ragged_rank", &input_ragged_rank)); + int64 output_ragged_rank; + TF_RETURN_IF_ERROR( + c->GetAttr("output_ragged_rank", &output_ragged_rank)); + shape_inference::ShapeHandle encoded_ragged = c->input(0); + if (c->RankKnown(encoded_ragged) && input_ragged_rank >= 0) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank( + encoded_ragged, output_ragged_rank - input_ragged_rank, &unused)); + } + for (int64 i = 0; i < output_ragged_rank; i++) { + c->set_output(i, c->UnknownShapeOfRank(1)); + } + c->set_output(output_ragged_rank, c->UnknownShape()); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/ops/ragged_math_ops.cc b/tensorflow/core/ops/ragged_math_ops.cc index d739c697981..5ceb31be3f0 100644 --- a/tensorflow/core/ops/ragged_math_ops.cc +++ b/tensorflow/core/ops/ragged_math_ops.cc @@ -32,9 +32,10 @@ REGISTER_OP("RaggedRange") .Input("starts: T") .Input("limits: T") .Input("deltas: T") - .Output("rt_nested_splits: int64") + .Output("rt_nested_splits: Tsplits") .Output("rt_dense_values: T") .Attr("T: {bfloat16, float, double, int32, int64} = DT_INT32") + .Attr("Tsplits: {int32, int64} = DT_INT64") .SetShapeFn(RaggedRangeShapeFn); //============================================================================== diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index f54ed52ea29..696a69eff80 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -29,29 +29,10 @@ namespace tensorflow { namespace { -Status ValidateVariableResourceHandle( - InferenceContext* c, std::vector* shape_and_type) { - auto* handle_data = c->input_handle_shapes_and_types(0); - if (handle_data == nullptr || handle_data->empty()) { - shape_and_type->emplace_back(c->UnknownShape(), DT_INVALID); - } else { - *shape_and_type = *handle_data; - DataType value_dtype; - TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype)); - if (shape_and_type->at(0).dtype != value_dtype) { - return errors::InvalidArgument( - "Trying to read variable with wrong dtype. 
" - "Expected ", - DataTypeString(shape_and_type->at(0).dtype), " got ", - DataTypeString(value_dtype)); - } - } - return Status::OK(); -} - Status ReadVariableShapeFn(InferenceContext* c) { std::vector shape_and_type; - TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &shape_and_type)); + TF_RETURN_IF_ERROR( + shape_inference::ValidateVariableResourceHandle(c, &shape_and_type)); c->set_output(0, shape_and_type[0].shape); if (shape_and_type[0].dtype == DT_VARIANT && shape_and_type.size() > 1) { std::vector variant_shape_and_type; @@ -186,7 +167,8 @@ REGISTER_OP("DestroyResourceOp") Status CreateAssignShapeFn(InferenceContext* c) { std::vector handle_shape_and_type; - TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type)); + TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle( + c, &handle_shape_and_type)); ShapeHandle value_shape = c->input(1); ShapeHandle unused; @@ -254,24 +236,50 @@ REGISTER_OP("VariableShape") REGISTER_OP("ResourceGather") .Input("resource: resource") .Input("indices: Tindices") + .Attr("batch_dims: int = 0") .Attr("validate_indices: bool = true") .Output("output: dtype") .Attr("dtype: type") .Attr("Tindices: {int32,int64}") .SetShapeFn([](InferenceContext* c) { std::vector handle_shape_and_type; - TF_RETURN_IF_ERROR( - ValidateVariableResourceHandle(c, &handle_shape_and_type)); + TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle( + c, &handle_shape_and_type)); + + ShapeHandle indices_shape = c->input(1); ShapeHandle unused; + int32 batch_dims; + TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims)); + if (batch_dims < 0) + return errors::InvalidArgument("batch_dims is negative (", batch_dims, + ")"); + + TF_RETURN_IF_ERROR(c->WithRankAtLeast(handle_shape_and_type[0].shape, + batch_dims + 1, &unused)); + TF_RETURN_IF_ERROR( - c->WithRankAtLeast(handle_shape_and_type[0].shape, 1, &unused)); - ShapeHandle params_subshape; + c->WithRankAtLeast(indices_shape, batch_dims, &unused)); + + ShapeHandle params_subshape1; + TF_RETURN_IF_ERROR(c->Subshape(handle_shape_and_type[0].shape, 0, + batch_dims, ¶ms_subshape1)); + + ShapeHandle params_subshape2; + TF_RETURN_IF_ERROR(c->Subshape(handle_shape_and_type[0].shape, + batch_dims + 1, ¶ms_subshape2)); + + ShapeHandle indices_subshape; TF_RETURN_IF_ERROR( - c->Subshape(handle_shape_and_type[0].shape, 1, ¶ms_subshape)); - ShapeHandle indices_shape = c->input(1); + c->Subshape(indices_shape, batch_dims, &indices_subshape)); + + // The out shape is params_shape[:batch_dims] + + // indices_shape[batch_dims:] + params_shape[batch_dims+1:]. 
ShapeHandle out; - TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out)); + TF_RETURN_IF_ERROR( + c->Concatenate(params_subshape1, indices_subshape, &out)); + TF_RETURN_IF_ERROR(c->Concatenate(out, params_subshape2, &out)); + c->set_output(0, out); if (handle_shape_and_type[0].dtype == DT_VARIANT && !handle_shape_and_type.empty()) { @@ -284,11 +292,20 @@ REGISTER_OP("ResourceGather") return Status::OK(); }); +REGISTER_OP("ResourceGatherNd") + .Input("resource: resource") + .Input("indices: Tindices") + .Output("output: dtype") + .Attr("dtype: type") + .Attr("Tindices: {int32,int64}") + .SetShapeFn(shape_inference::GatherNdShape); + namespace { Status ResourceScatterUpdateShape(InferenceContext* c) { std::vector handle_shape_and_type; - TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type)); + TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle( + c, &handle_shape_and_type)); ShapeHandle var_shape = handle_shape_and_type[0].shape; ShapeHandle indices_shape = c->input(1); diff --git a/tensorflow/core/ops/stateful_random_ops.cc b/tensorflow/core/ops/stateful_random_ops.cc index cf35eb78544..9537e614069 100644 --- a/tensorflow/core/ops/stateful_random_ops.cc +++ b/tensorflow/core/ops/stateful_random_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -40,8 +41,10 @@ Status StatefulRandomShape(shape_inference::InferenceContext* c) { .Attr("shape_dtype : type = DT_INT64") \ .SetShapeFn(StatefulRandomShape); +REGISTER_STATEFUL_OP("StatefulUniform", DT_FLOAT); REGISTER_STATEFUL_OP("StatefulUniformFullInt", DT_UINT64); REGISTER_STATEFUL_OP("StatefulStandardNormalV2", DT_FLOAT); +REGISTER_STATEFUL_OP("StatefulTruncatedNormal", DT_FLOAT); REGISTER_OP("StatefulUniformInt") .Input("resource: resource") @@ -66,10 +69,59 @@ REGISTER_OP("StatefulUniformInt") return Status::OK(); }); -// Register the old 'StatefulStandardNormal' op. 
This op is a short-lived +REGISTER_OP("RngSkip") + .Input("resource: resource") + .Input("algorithm: int64") + .Input("delta: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return Status::OK(); + }); + +REGISTER_OP("NonDeterministicInts") + .Input("shape: shape_dtype") + .SetIsStateful() + .Output("output: dtype") + .Attr("dtype : type = DT_INT64") + .Attr("shape_dtype : type = DT_INT64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + using shape_inference::ShapeHandle; + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); + return Status::OK(); + }); + +REGISTER_OP("StatefulRandomBinomial") + .Input("resource: resource") + .Input("algorithm: int64") + .Input("shape: S") + .Input("counts: T") + .Input("probs: T") + .Output("output: dtype") + .Attr("S: {int32, int64}") + .Attr("T: {half, float, double, int32, int64} = DT_DOUBLE") + .Attr("dtype: {half, float, double, int32, int64} = DT_INT64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + using shape_inference::ShapeHandle; + + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); + + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out)); + c->set_output(0, out); + return Status::OK(); + }); + +// Register the depracated 'StatefulStandardNormal' op. This op is a short-lived // version where the 'resource' variable also contains the algorithm tag. // It is deprecated in favor of 'StatefulStandardNormalV2'. REGISTER_OP("StatefulStandardNormal") + .Deprecated(29, "Use StatefulStandardNormalV2 instead") .Input("resource: resource") .Input("shape: shape_dtype") .Output("output: dtype") diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index d012ce67fd0..4aefaad90d0 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -206,6 +206,18 @@ REGISTER_OP("StringSplitV2") return Status::OK(); }); +REGISTER_OP("StringLower") + .Input("input: string") + .Output("output: string") + .Attr("encoding: string =''") + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("StringUpper") + .Input("input: string") + .Output("output: string") + .Attr("encoding: string =''") + .SetShapeFn(shape_inference::UnchangedShape); + REGISTER_OP("StringStrip") .Input("input: string") .Output("output: string") @@ -263,10 +275,11 @@ REGISTER_OP("UnicodeScript") REGISTER_OP("UnicodeEncode") .Input("input_values: int32") - .Input("input_splits: int64") + .Input("input_splits: Tsplits") .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'") .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}") .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char + .Attr("Tsplits: {int32, int64} = DT_INT64") .Output("output: string") .SetShapeFn([](InferenceContext* c) { // Check rank of inner values @@ -298,12 +311,13 @@ REGISTER_OP("UnicodeTranscode") REGISTER_OP("UnicodeDecode") .Input("input: string") - .Output("row_splits: int64") + .Output("row_splits: Tsplits") .Output("char_values: int32") .Attr("input_encoding: string") .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char .Attr("replace_control_characters: bool = false") + 
.Attr("Tsplits: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { // row_splits.shape == [input.size() + 1] DimensionHandle num_row_splits; @@ -319,13 +333,14 @@ REGISTER_OP("UnicodeDecode") REGISTER_OP("UnicodeDecodeWithOffsets") .Input("input: string") - .Output("row_splits: int64") + .Output("row_splits: Tsplits") .Output("char_values: int32") .Output("char_to_byte_starts: int64") .Attr("input_encoding: string") .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char .Attr("replace_control_characters: bool = false") + .Attr("Tsplits: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { // row_splits.shape == [input.size() + 1] DimensionHandle num_row_splits; diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc index 742a221adcb..7923d2436f9 100644 --- a/tensorflow/core/ops/summary_ops.cc +++ b/tensorflow/core/ops/summary_ops.cc @@ -57,6 +57,12 @@ REGISTER_OP("WriteSummary") .Attr("T: type") .SetShapeFn(shape_inference::NoOutputs); +REGISTER_OP("WriteRawProtoSummary") + .Input("writer: resource") + .Input("step: int64") + .Input("tensor: string") + .SetShapeFn(shape_inference::NoOutputs); + REGISTER_OP("ImportEvent") .Input("writer: resource") .Input("event: string") diff --git a/tensorflow/core/ops/tpu_embedding_ops.cc b/tensorflow/core/ops/tpu_embedding_ops.cc index 4eaab1e6c72..36345ba831c 100644 --- a/tensorflow/core/ops/tpu_embedding_ops.cc +++ b/tensorflow/core/ops/tpu_embedding_ops.cc @@ -73,251 +73,48 @@ class RegisterPerTableLoadAndRetrieveOpsOnConstruction { RegisterPerTableLoadAndRetrieveOpsOnConstruction register_per_table_load_and_retrieve_ops_var; -Status RegisterPerTableLoadOpsForAlgorithmBody( - tpu::OptimizationAlgorithm alg, bool is_debug_op, - OpRegistrationData* op_reg_data) { - tpu::GradientAccumulationSupport grad_accum_support; - TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support)); - - std::vector state_variable_specs; - TF_CHECK_OK(GetOptimizationAlgorithmStateVariables( - alg, - grad_accum_support == tpu::GradientAccumulationSupport::kSupported && - is_debug_op, - &state_variable_specs)); - auto* op_def = &op_reg_data->op_def; - op_def->set_name( - strings::StrCat("LoadTPUEmbedding", GetOptimizationAlgorithmName(alg), - "Parameters", (is_debug_op ? "GradAccumDebug" : ""))); - // It is important for the order of the inputs to the op defined here - // to match the order in input_names because the indexes are used in - // the combining transformation. 
- for (const auto& parameter : state_variable_specs) { - if (parameter.has_user_defined() || is_debug_op) { - auto* arg = op_def->add_input_arg(); - arg->set_name(parameter.name()); - arg->set_type(DT_FLOAT); - } - } - { - auto* table_id_attr = op_def->add_attr(); - table_id_attr->set_name("table_id"); - table_id_attr->set_type("int"); - table_id_attr->set_has_minimum(true); - table_id_attr->set_minimum(-1); - table_id_attr->mutable_default_value()->set_i(-1); - } - { - auto* table_name_attr = op_def->add_attr(); - table_name_attr->set_name("table_name"); - table_name_attr->set_type("string"); - table_name_attr->mutable_default_value()->set_s(""); - } - { - auto* num_shards_attr = op_def->add_attr(); - num_shards_attr->set_name("num_shards"); - num_shards_attr->set_type("int"); - } - { - auto* shard_id_attr = op_def->add_attr(); - shard_id_attr->set_name("shard_id"); - shard_id_attr->set_type("int"); - } - string parameter_descriptions; - for (const auto& parameter : state_variable_specs) { - if (parameter.has_user_defined() || is_debug_op) { - strings::Appendf(¶meter_descriptions, - R"( -%s: A tensor containing the initial embedding table %s to use in embedding -lookups using the %s optimization algorithm.)", - parameter.name().c_str(), parameter.name().c_str(), - GetOptimizationAlgorithmFriendlyName(alg).c_str()); - } - } - op_def->set_is_commutative(false); - op_def->set_is_aggregate(false); - op_def->set_is_stateful(true); - auto shape_inference_function = - [state_variable_specs, - is_debug_op](shape_inference::InferenceContext* c) -> Status { - int table_id; - TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); - string table_name; - TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); - // Exactly one must be non-default. - if ((table_id >= 0) == (!table_name.empty())) { - return errors::InvalidArgument( - "exactly one of table_id or table_name must be non-default"); - } - int num_shards; - TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); - int shard_id; - TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id)); - const int user_param_count = - std::count_if(state_variable_specs.begin(), state_variable_specs.end(), - [&](const tpu::StateVariableSpecification& sv) { - return sv.has_user_defined() || is_debug_op; - }); - std::vector inputs(user_param_count); - int input_index = 0; - for (int i = 0; i < state_variable_specs.size(); ++i) { - if (state_variable_specs[i].has_user_defined() || is_debug_op) { - std::vector input_temp; - TF_RETURN_IF_ERROR( - c->input(state_variable_specs[i].name(), &input_temp)); - if (input_temp.size() != 1) { - return errors::InvalidArgument("each input to be rank 1"); - } - inputs[input_index] = input_temp[0]; - ++input_index; - } - } - // Verify shapes have rank 2 and are compatible when they are - // required to be valid. 
- shape_inference::ShapeHandle parameter_shape; - TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, ¶meter_shape)); - for (int j = 1; j < user_param_count; ++j) { - shape_inference::ShapeHandle accumulator_j_shape; - TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape)); - shape_inference::ShapeHandle merged; - TF_RETURN_IF_ERROR( - c->Merge(parameter_shape, accumulator_j_shape, &merged)); - } - return Status::OK(); - }; - op_reg_data->shape_inference_fn = shape_inference_function; - return Status::OK(); -} - -Status RegisterPerTableRetrieveOpsForAlgorithmBody( - tpu::OptimizationAlgorithm alg, bool is_debug_op, - OpRegistrationData* op_reg_data) { - tpu::GradientAccumulationSupport grad_accum_support; - TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support)); - - std::vector state_variable_specs; - TF_CHECK_OK(GetOptimizationAlgorithmStateVariables( - alg, - grad_accum_support == tpu::GradientAccumulationSupport::kSupported && - is_debug_op, - &state_variable_specs)); - - auto* op_def = &op_reg_data->op_def; - op_def->set_name(strings::StrCat( - "RetrieveTPUEmbedding", tpu::GetOptimizationAlgorithmName(alg), - "Parameters", (is_debug_op ? "GradAccumDebug" : ""))); - // It is important for the order of the outputs of the op defined here - // to match the order in output_names because the indexes are used in - // the combining transformation. - for (const auto& parameter : state_variable_specs) { - if (parameter.has_user_defined() || is_debug_op) { - auto* arg = op_def->add_output_arg(); - arg->set_name(parameter.name()); - arg->set_type(DT_FLOAT); - } - } - { - auto* table_id_attr = op_def->add_attr(); - table_id_attr->set_name("table_id"); - table_id_attr->set_type("int"); - table_id_attr->set_has_minimum(true); - table_id_attr->set_minimum(-1); - table_id_attr->mutable_default_value()->set_i(-1); - } - { - auto* table_name_attr = op_def->add_attr(); - table_name_attr->set_name("table_name"); - table_name_attr->set_type("string"); - table_name_attr->mutable_default_value()->set_s(""); - } - { - auto* num_shards_attr = op_def->add_attr(); - num_shards_attr->set_name("num_shards"); - num_shards_attr->set_type("int"); - } - { - auto* shard_id_attr = op_def->add_attr(); - shard_id_attr->set_name("shard_id"); - shard_id_attr->set_type("int"); - } - string parameter_descriptions; - for (const auto& param : state_variable_specs) { - if (param.has_user_defined() || is_debug_op) { - strings::Appendf(¶meter_descriptions, - R"( -%s: A tensor containing the embedding table %s to store with the -parameters from embedding updates using the %s optimization algorithm.)", - param.name().c_str(), param.name().c_str(), - tpu::GetOptimizationAlgorithmFriendlyName(alg).c_str()); - } - } - op_def->set_is_commutative(false); - op_def->set_is_aggregate(false); - op_def->set_is_stateful(true); - auto shape_inference_function = - [state_variable_specs, - is_debug_op](shape_inference::InferenceContext* c) -> Status { - int table_id; - TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); - string table_name; - TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); - // Exactly one must be non-default. 
- if ((table_id >= 0) == (!table_name.empty())) { - return errors::InvalidArgument( - "exactly one of table_id or table_name must be non-default"); - } - int num_shards; - TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); - int shard_id; - TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id)); - for (int j = 0; j < state_variable_specs.size(); ++j) { - if (state_variable_specs[j].has_user_defined() || is_debug_op) { - auto shape = c->MakeShape( - std::vector(2, c->UnknownDim())); - TF_RETURN_IF_ERROR( - c->set_output(state_variable_specs[j].name(), - std::vector(1, shape))); - } - } - return Status::OK(); - }; - op_reg_data->shape_inference_fn = shape_inference_function; - return Status::OK(); -} - void RegisterPerTableLoadAndRetrieveOps() { // Load ops for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) { - OpRegistry::Global()->Register( - [alg](OpRegistrationData* op_reg_data) -> Status { - return RegisterPerTableLoadOpsForAlgorithmBody(alg, false, - op_reg_data); - }); - tpu::GradientAccumulationSupport grad_accum_support; - TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support)); - if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) { + bool internal; + TF_CHECK_OK(tpu::IsOptimizationAlgorithmInternal(alg, &internal)); + if (!internal) { OpRegistry::Global()->Register( [alg](OpRegistrationData* op_reg_data) -> Status { - return RegisterPerTableLoadOpsForAlgorithmBody(alg, true, - op_reg_data); + return tpu::RegisterPerTableLoadOpsForAlgorithmBody(alg, false, + op_reg_data); }); + tpu::GradientAccumulationSupport grad_accum_support; + TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support)); + if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) { + OpRegistry::Global()->Register( + [alg](OpRegistrationData* op_reg_data) -> Status { + return tpu::RegisterPerTableLoadOpsForAlgorithmBody(alg, true, + op_reg_data); + }); + } } } + // Retrieve ops for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) { - OpRegistry::Global()->Register( - [alg](OpRegistrationData* op_reg_data) -> Status { - return RegisterPerTableRetrieveOpsForAlgorithmBody(alg, false, - op_reg_data); - }); - tpu::GradientAccumulationSupport grad_accum_support; - TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support)); - if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) { + bool internal; + TF_CHECK_OK(tpu::IsOptimizationAlgorithmInternal(alg, &internal)); + if (!internal) { OpRegistry::Global()->Register( [alg](OpRegistrationData* op_reg_data) -> Status { - return RegisterPerTableRetrieveOpsForAlgorithmBody(alg, true, - op_reg_data); + return tpu::RegisterPerTableRetrieveOpsForAlgorithmBody( + alg, false, op_reg_data); }); + tpu::GradientAccumulationSupport grad_accum_support; + TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support)); + if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) { + OpRegistry::Global()->Register( + [alg](OpRegistrationData* op_reg_data) -> Status { + return tpu::RegisterPerTableRetrieveOpsForAlgorithmBody( + alg, true, op_reg_data); + }); + } } } } @@ -430,6 +227,7 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch") .Attr("device_ordinal: int = -1") .Attr("combiners: list(string) = []") .Attr("table_ids: list(int)") + .Attr("max_sequence_lengths: list(int) = []") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape); diff --git a/tensorflow/core/ops/tpu_infeed_ops.cc b/tensorflow/core/ops/tpu_infeed_ops.cc 
index 0090b761c48..2cab6f7f976 100644 --- a/tensorflow/core/ops/tpu_infeed_ops.cc +++ b/tensorflow/core/ops/tpu_infeed_ops.cc @@ -63,4 +63,25 @@ REGISTER_OP("InfeedDequeueTuple") return Status::OK(); }); +REGISTER_OP("Prelinearize") + .Input("input: dtype") + .Attr("dtype: type") + .Attr("shape: shape = {}") + .Attr("layout: list(int) = []") + .Output("output: variant") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("PrelinearizeTuple") + .Input("inputs: dtypes") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .Attr("layouts: list(int) = []") + .Output("output: variant") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("InfeedEnqueuePrelinearizedBuffer") + .Input("input: variant") + .Attr("device_ordinal: int = -1") + .SetShapeFn(shape_inference::NoOutputs); + } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc index eb9023d7089..e31901a7a0f 100644 --- a/tensorflow/core/platform/cloud/curl_http_request_test.cc +++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc @@ -29,7 +29,7 @@ class FakeEnv : public EnvWrapper { public: FakeEnv() : EnvWrapper(Env::Default()) {} - uint64 NowSeconds() override { return now_; } + uint64 NowSeconds() const override { return now_; } uint64 now_ = 10000; }; diff --git a/tensorflow/core/platform/cloud/gcs_throttle_test.cc b/tensorflow/core/platform/cloud/gcs_throttle_test.cc index 8f962b92b88..774855a3b04 100644 --- a/tensorflow/core/platform/cloud/gcs_throttle_test.cc +++ b/tensorflow/core/platform/cloud/gcs_throttle_test.cc @@ -24,7 +24,7 @@ namespace { class TestTime : public EnvTime { public: - uint64 NowNanos() override { return now_micros_ * kMicrosToNanos; } + uint64 NowNanos() const override { return now_micros_ * kMicrosToNanos; } void SetTime(uint64 now_micros) { now_micros_ = now_micros; } diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc index d2db59200ab..8c7e107037a 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc @@ -30,7 +30,7 @@ class FakeEnv : public EnvWrapper { public: FakeEnv() : EnvWrapper(Env::Default()) {} - uint64 NowSeconds() override { return now; } + uint64 NowSeconds() const override { return now; } uint64 now = 10000; }; diff --git a/tensorflow/core/platform/cloud/now_seconds_env.h b/tensorflow/core/platform/cloud/now_seconds_env.h index b3450bef990..8928587b4e0 100644 --- a/tensorflow/core/platform/cloud/now_seconds_env.h +++ b/tensorflow/core/platform/cloud/now_seconds_env.h @@ -28,7 +28,7 @@ class NowSecondsEnv : public EnvWrapper { NowSecondsEnv() : EnvWrapper(Env::Default()) {} /// The current (fake) timestamp. - uint64 NowSeconds() override { + uint64 NowSeconds() const override { mutex_lock lock(mu_); return now_; } @@ -40,7 +40,7 @@ class NowSecondsEnv : public EnvWrapper { } /// Guards access to now_. - mutex mu_; + mutable mutex mu_; /// The NowSeconds() value that this Env will return. 
uint64 now_ = 1; diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index a8657359a35..89b1056be7d 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -284,7 +284,7 @@ Status OAuthClient::ParseOAuthResponse(StringPiece response, return errors::FailedPrecondition("Unexpected Oauth token type: " + token_type); } - int64 expires_in; + int64 expires_in = 0; TF_RETURN_IF_ERROR(ReadJsonInt(root, "expires_in", &expires_in)); *expiration_timestamp_sec = request_timestamp_sec + expires_in; TF_RETURN_IF_ERROR(ReadJsonString(root, "access_token", token)); diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index ce3b9d79c8b..7b76e4c6c16 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -42,7 +42,7 @@ class FakeEnv : public EnvWrapper { public: FakeEnv() : EnvWrapper(Env::Default()) {} - uint64 NowSeconds() override { return now; } + uint64 NowSeconds() const override { return now; } uint64 now = 10000; }; diff --git a/tensorflow/core/platform/cloud/ram_file_block_cache.cc b/tensorflow/core/platform/cloud/ram_file_block_cache.cc index 82b692a9e39..5d924685c7d 100644 --- a/tensorflow/core/platform/cloud/ram_file_block_cache.cc +++ b/tensorflow/core/platform/cloud/ram_file_block_cache.cc @@ -104,7 +104,9 @@ Status RamFileBlockCache::MaybeFetch(const Key& key, mutex_lock l(mu_); // Do not update state if the block is already to be evicted. if (block->timestamp != 0) { - cache_size_ += block->data.size(); + // Use capacity() instead of size() to account for all memory + // used by the cache. + cache_size_ += block->data.capacity(); // Put to beginning of LRA list. lra_list_.erase(block->lra_iterator); lra_list_.push_front(key); @@ -132,7 +134,9 @@ Status RamFileBlockCache::MaybeFetch(const Key& key, block->mu.lock(); // Reacquire the lock immediately afterwards if (status.ok()) { block->data.resize(bytes_transferred, 0); - block->data.shrink_to_fit(); + // Shrink the data capacity to the actual size used. + // NOLINTNEXTLINE: shrink_to_fit() may not shrink the capacity. + std::vector(block->data).swap(block->data); downloaded_block = true; block->state = FetchState::FINISHED; } else { @@ -285,7 +289,7 @@ void RamFileBlockCache::RemoveBlock(BlockMap::iterator entry) { entry->second->timestamp = 0; lru_list_.erase(entry->second->lru_iterator); lra_list_.erase(entry->second->lra_iterator); - cache_size_ -= entry->second->data.size(); + cache_size_ -= entry->second->data.capacity(); block_map_.erase(entry); } diff --git a/tensorflow/core/platform/cloud/time_util.cc b/tensorflow/core/platform/cloud/time_util.cc index 0587a65c299..afd06efa854 100644 --- a/tensorflow/core/platform/cloud/time_util.cc +++ b/tensorflow/core/platform/cloud/time_util.cc @@ -41,14 +41,14 @@ Status ParseRfc3339Time(const string& time, int64* mtime_nsec) { return errors::Internal( strings::StrCat("Unrecognized RFC 3339 time format: ", time)); } - const int int_seconds = floor(seconds); + const int int_seconds = std::floor(seconds); parsed.tm_year -= 1900; // tm_year expects years since 1900. parsed.tm_mon -= 1; // month is zero-based. 
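  // Illustrative example: for a timestamp such as "2016-04-29T23:15:24.5Z",
  // seconds parses as 24.5, int_seconds is 24, and the fractional 0.5 second
  // becomes 500000000 nanoseconds in the computation below.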
parsed.tm_sec = int_seconds; *mtime_nsec = timegm(&parsed) * kNanosecondsPerSecond + - static_cast( - floor((seconds - int_seconds) * kNanosecondsPerSecond)); + static_cast(std::floor((seconds - int_seconds) * + kNanosecondsPerSecond)); return Status::OK(); } diff --git a/tensorflow/core/platform/cpu_info.h b/tensorflow/core/platform/cpu_info.h index c9208cc7553..fd76047edc5 100644 --- a/tensorflow/core/platform/cpu_info.h +++ b/tensorflow/core/platform/cpu_info.h @@ -36,6 +36,16 @@ namespace port { // value (e.g. `4`) may be returned. int NumSchedulableCPUs(); +// Returns an estimate for the maximum parallelism for this process. +// Applications should avoid running more than this number of threads with +// intensive workloads concurrently to avoid performance degradation and +// contention. +// This value is either the number of schedulable CPUs, or a value specific to +// the underlying cluster management. Applications should assume this value can +// change throughout the lifetime of the process. This function must not be +// called during initialization, i.e., before before main() has started. +int MaxParallelism(); + // Returns the total number of CPUs on the system. This number should // not change even if the underlying cluster management software may // change the number of schedulable CPUs. Unlike `NumSchedulableCPUs`, if the diff --git a/tensorflow/core/platform/cuda_libdevice_path.h b/tensorflow/core/platform/cuda_libdevice_path.h index f2dbff9043a..1f54730dd7d 100644 --- a/tensorflow/core/platform/cuda_libdevice_path.h +++ b/tensorflow/core/platform/cuda_libdevice_path.h @@ -25,6 +25,22 @@ namespace tensorflow { // the CUDA SDK, which contains sub-folders such as bin, lib64, and nvvm. std::vector CandidateCudaRoots(); +// A convenient wrapper for CandidateCudaRoots, which allows supplying a +// preferred location (inserted first in the output vector), and a flag whether +// the current working directory should be searched (inserted last). +inline std::vector CandidateCudaRoots( + string preferred_location, bool use_working_directory = true) { + std::vector candidates = CandidateCudaRoots(); + if (!preferred_location.empty()) { + candidates.insert(candidates.begin(), preferred_location); + } + + // "." is our last resort, even though it probably won't work. 
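  // Illustrative example (hypothetical path): CandidateCudaRoots("/usr/local/cuda-10.0")
  // returns {"/usr/local/cuda-10.0", <default candidates>, "."}, i.e. the
  // preferred location is probed first and the working directory last.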
+ candidates.push_back("."); + + return candidates; +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_ diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 40eb0f63d72..ba7fccdd6fd 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -317,7 +317,7 @@ def tf_proto_library_cc( srcs = [], has_services = None, protodeps = [], - visibility = [], + visibility = None, testonly = 0, cc_libs = [], cc_stubby_versions = None, @@ -387,7 +387,7 @@ def tf_proto_library_py( srcs = [], protodeps = [], deps = [], - visibility = [], + visibility = None, testonly = 0, srcs_version = "PY2AND3", use_grpc_plugin = False): @@ -433,7 +433,7 @@ def tf_proto_library( srcs = [], has_services = None, protodeps = [], - visibility = [], + visibility = None, testonly = 0, cc_libs = [], cc_api_version = 2, @@ -520,6 +520,14 @@ def tf_additional_lib_srcs(exclude = []): ], exclude = exclude), }) +def tf_additional_monitoring_hdrs(): + return [] + +def tf_additional_monitoring_srcs(): + return [ + "platform/default/monitoring.cc", + ] + def tf_additional_minimal_lib_srcs(): return [ "platform/default/integral_types.h", @@ -546,7 +554,11 @@ def tf_additional_all_protos(): return ["//tensorflow/core:protos_all"] def tf_protos_all_impl(): - return ["//tensorflow/core:protos_all_cc_impl"] + return [ + "//tensorflow/core:autotuning_proto_cc_impl", + "//tensorflow/core:conv_autotuning_proto_cc_impl", + "//tensorflow/core:protos_all_cc_impl", + ] def tf_protos_all(): return if_static( @@ -573,7 +585,14 @@ def tf_protos_grappler(): ) def tf_additional_cupti_wrapper_deps(): - return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"] + return [ + "//tensorflow/stream_executor/cuda:cupti_stub", + "@com_google_absl//absl/base", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:flat_hash_map", + ] def tf_additional_device_tracer_srcs(): return ["platform/default/device_tracer.cc"] @@ -590,6 +609,13 @@ def tf_additional_device_tracer_deps(): def tf_additional_device_tracer_test_flags(): return [] +def tf_additional_profiler_lib_deps(): + return [ + "//tensorflow/core/profiler/internal/cpu:host_tracer", + ] + if_cuda([ + "//tensorflow/core/profiler/internal/gpu:device_tracer", + ]) + def tf_additional_libdevice_data(): return [] diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD index 845fe0ec047..d917d442f5c 100644 --- a/tensorflow/core/platform/default/build_config/BUILD +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -8,8 +8,8 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "check_deps") -load("//tensorflow:tensorflow.bzl", "if_cuda") -load("//tensorflow:tensorflow.bzl", "if_rocm") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_cuda_library") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") @@ -34,7 +34,10 @@ cc_library( tf_cuda_library( name = "stream_executor", - cuda_deps = ["//tensorflow/stream_executor/cuda:cuda_activation"], + cuda_deps = [ + "//tensorflow/stream_executor/cuda:cuda_activation", + 
"//tensorflow/stream_executor/rocm:rocm_activation", + ], deps = [ "//tensorflow/stream_executor", "//tensorflow/stream_executor:dnn", @@ -62,9 +65,12 @@ cc_library( name = "stream_executor_cuda", deps = [ ":stream_executor_no_cuda", - ":cuda", ] + if_static( - ["//tensorflow/stream_executor/cuda:all_runtime"], + [ + "//tensorflow/stream_executor/cuda:all_runtime", + ":cuda", + ], + ["//tensorflow/stream_executor/cuda:cudart_stub"], ) + select({ "@local_config_cuda//cuda:darwin": ["IOKit"], "//conditions:default": [], @@ -284,20 +290,6 @@ cc_library( ], ) -# Check that libtensorflow_framework.so does not depend on cuda shared libraries. -check_deps( - name = "libtensorflow_cuda_check_deps", - disallowed_deps = [ - ":cuda", - "@local_config_cuda//cuda:cublas", - "@local_config_cuda//cuda:cuda_driver", - "@local_config_cuda//cuda:cudnn", - "@local_config_cuda//cuda:curand", - "@local_config_cuda//cuda:cusolver", - ], - deps = ["//tensorflow:libtensorflow_framework.so"], -) - cc_library( name = "rocm", data = [], diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl index ab05b25d682..72755341220 100644 --- a/tensorflow/core/platform/default/build_config_root.bzl +++ b/tensorflow/core/platform/default/build_config_root.bzl @@ -4,9 +4,13 @@ load("@local_config_remote_execution//:remote_execution.bzl", "gpu_test_tags") -def tf_cuda_tests_tags(): +def tf_gpu_tests_tags(): return ["requires-gpu", "gpu"] + gpu_test_tags() +# terminology changes: saving tf_cuda_* for compatibility +def tf_cuda_tests_tags(): + return tf_gpu_tests_tags() + def tf_sycl_tests_tags(): return ["requires-gpu", "gpu"] + gpu_test_tags() @@ -61,9 +65,25 @@ def tf_additional_gdr_deps(): "//conditions:default": [], }) -def if_static(extra_deps, otherwise = []): +# Include specific extra dependencies when building statically, or +# another set of dependencies otherwise. If "macos" is provided, that +# dependency list is used when using the framework_shared_object config +# on MacOS platforms. If "macos" is not provided, the "otherwise" list is +# used for all framework_shared_object platforms including MacOS. +def if_static(extra_deps, otherwise = [], macos = []): + ret = { + str(Label("//tensorflow:framework_shared_object")): otherwise, + "//conditions:default": extra_deps, + } + if macos: + ret[str(Label("//tensorflow:macos_with_framework_shared_object"))] = macos + return select(ret) + +def if_static_and_not_mobile(extra_deps, otherwise = []): return select({ str(Label("//tensorflow:framework_shared_object")): otherwise, + str(Label("//tensorflow:android")): otherwise, + str(Label("//tensorflow:ios")): otherwise, "//conditions:default": extra_deps, }) diff --git a/tensorflow/core/platform/default/cuda_build_defs.bzl b/tensorflow/core/platform/default/cuda_build_defs.bzl new file mode 100644 index 00000000000..8b0b3f55960 --- /dev/null +++ b/tensorflow/core/platform/default/cuda_build_defs.bzl @@ -0,0 +1,8 @@ +"""Open source build configurations for CUDA.""" + +load("@local_config_cuda//cuda:build_defs.bzl", _if_cuda_is_configured = "if_cuda_is_configured") + +# We perform this indirection so that the copybara tool can distinguish this +# macro from others provided by the same file. 
+def if_cuda_is_configured(x): + return _if_cuda_is_configured(x) diff --git a/tensorflow/core/platform/default/cuda_libdevice_path.cc b/tensorflow/core/platform/default/cuda_libdevice_path.cc index a8b2e7202ac..25eb6ab463b 100644 --- a/tensorflow/core/platform/default/cuda_libdevice_path.cc +++ b/tensorflow/core/platform/default/cuda_libdevice_path.cc @@ -19,7 +19,7 @@ limitations under the License. #include #if !defined(PLATFORM_GOOGLE) -#include "cuda/cuda_config.h" +#include "third_party/gpus/cuda/cuda_config.h" #endif #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc index ffcb38fdcd2..38cdb65c566 100644 --- a/tensorflow/core/platform/default/device_tracer.cc +++ b/tensorflow/core/platform/default/device_tracer.cc @@ -13,344 +13,344 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/platform/device_tracer.h" - #if GOOGLE_CUDA #include + #include +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/cupti_wrapper.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/tracing.h" -#include "tensorflow/core/profiler/internal/cpu/host_tracer.h" -#include "tensorflow/core/profiler/lib/traceme.h" - -namespace { - -// Maps a MemcpyKind enum to a const string. -const char *getMemcpyKindString(CUpti_ActivityMemcpyKind kind) { - switch (kind) { - case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD: - return "HtoD"; - case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH: - return "DtoH"; - case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA: - return "HtoA"; - case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH: - return "AtoH"; - case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA: - return "AtoA"; - case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD: - return "AtoD"; - case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA: - return "DtoA"; - case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD: - return "DtoD"; - case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH: - return "HtoH"; - case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP: - return "PtoP"; - default: - break; - } - return ""; -} - -// Maps a MemoryKind enum to a const string. -const char *getMemoryKindString(CUpti_ActivityMemoryKind kind) { - switch (kind) { - case CUPTI_ACTIVITY_MEMORY_KIND_UNKNOWN: - return "Unknown"; - case CUPTI_ACTIVITY_MEMORY_KIND_PAGEABLE: - return "Pageable"; - case CUPTI_ACTIVITY_MEMORY_KIND_PINNED: - return "Pinned"; - case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE: - return "Device"; - case CUPTI_ACTIVITY_MEMORY_KIND_ARRAY: - return "Array"; - default: - break; - } - return ""; -} - -// Maps an OverheadKind enum to a const string. 
-const char *getActivityOverheadKindString(CUpti_ActivityOverheadKind kind) { - switch (kind) { - case CUPTI_ACTIVITY_OVERHEAD_DRIVER_COMPILER: - return "COMPILER"; - case CUPTI_ACTIVITY_OVERHEAD_CUPTI_BUFFER_FLUSH: - return "BUFFER_FLUSH"; - case CUPTI_ACTIVITY_OVERHEAD_CUPTI_INSTRUMENTATION: - return "INSTRUMENTATION"; - case CUPTI_ACTIVITY_OVERHEAD_CUPTI_RESOURCE: - return "RESOURCE"; - default: - break; - } - return ""; -} - -} // namespace +#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { -namespace devicetracer { - -// Forward declaration. -class CUPTIManager; - -// Returns a pointer to the CUPTIManager singleton. -CUPTIManager *GetCUPTIManager(); - -// Callback interface for consumers of CUPTI tracing. -class CUPTIClient { - public: - virtual ~CUPTIClient() {} - - // Invoked for each CUPTI activity reported. - virtual void ActivityCallback(const CUpti_Activity &activity) = 0; -}; - -#define CUPTI_CALL(call) \ - do { \ - CUptiResult _status = cupti_wrapper_->call; \ - if (_status != CUPTI_SUCCESS) { \ - LOG(ERROR) << "cuda call " << #call << " failed " << _status; \ - } \ - } while (0) - -// Singleton class to manage registration of CUPTI callbacks. -class CUPTIManager { - public: - CUPTIManager() { - cupti_wrapper_.reset(new perftools::gputools::profiler::CuptiWrapper()); +namespace { +Status ToStatus(CUptiResult result) { + if (result == CUPTI_SUCCESS) { + return Status::OK(); } - - static CUPTIManager *Create() { - auto manager = absl::make_unique(); - CUptiResult status = manager->cupti_wrapper_->ActivityRegisterCallbacks( - BufferRequested, BufferCompleted); - if (status != CUPTI_SUCCESS) { - LOG(ERROR) << "Failed to initialize CUPTI: " << status; - return nullptr; - } - return manager.release(); - } - - // Enables tracing and delivers event callbacks to 'client'. - // Does not take ownership of client. Client's lifetime must persist - // until tracing is disabled. - Status EnableTrace(CUPTIClient *client); - - // Disable tracing. No further events will be delivered to 'client'. - Status DisableTrace(); - - private: - // Static functions which we can use as CUPTI callbacks. - static void BufferRequested(uint8_t **buffer, size_t *size, - size_t *maxNumRecords) { - GetCUPTIManager()->InternalBufferRequested(buffer, size, maxNumRecords); - } - static void BufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer, - size_t size, size_t validSize) { - GetCUPTIManager()->InternalBufferCompleted(ctx, streamId, buffer, size, - validSize); - } - // These methods are called by the static stubs above. - void InternalBufferRequested(uint8_t **buffer, size_t *size, - size_t *maxNumRecords); - void InternalBufferCompleted(CUcontext ctx, uint32_t streamId, - uint8_t *buffer, size_t size, size_t validSize); - - // Size of buffers used for CUPTI tracing. - static constexpr size_t kBufferSize = 32 * 1024; - // Required alignment of CUPTI buffers. - static constexpr size_t kBufferAlignment = 8; - - mutex mu_; - CUPTIClient *client_ GUARDED_BY(mu_); - std::unique_ptr cupti_wrapper_; - - TF_DISALLOW_COPY_AND_ASSIGN(CUPTIManager); -}; - -Status CUPTIManager::EnableTrace(CUPTIClient *client) { - mutex_lock l(mu_); - // TODO(pbar) Work out the minimal set to trace. - // We can currently manage without driver/runtime tracing. 
- // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_CONTEXT)); - // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER)); - // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME)); - // These might be useful for annotations but require NVTX API. - // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_NAME)); - // CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_MARKER)); - - CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_DEVICE)); - CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL)); - CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY)); - CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY2)); - CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET)); - CUPTI_CALL(ActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD)); - client_ = client; - return Status::OK(); + const char* str = nullptr; + cuptiGetResultString(result, &str); + return errors::Unavailable("CUPTI error: ", str ? str : ""); } -Status CUPTIManager::DisableTrace() { - // We turn off all tracing regardless. - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_NAME)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_MARKER)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_CONTEXT)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_DRIVER)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_DEVICE)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_KERNEL)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY2)); - CUPTI_CALL(ActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET)); - CUPTI_CALL(ActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED)); - { - // Don't acquire this lock until Flush returns, since Flush - // will potentially cause callbacks into BufferCompleted. - mutex_lock l(mu_); - client_ = nullptr; +Status ToStatus(CUresult result) { + if (result == CUDA_SUCCESS) { + return Status::OK(); } - return Status::OK(); + const char* str = nullptr; + cuGetErrorName(result, &str); + return errors::Unavailable("CUDA error: ", str ? str : ""); } -void CUPTIManager::InternalBufferRequested(uint8_t **buffer, size_t *size, - size_t *maxNumRecords) { - VLOG(2) << "BufferRequested"; - void *p = port::AlignedMalloc(kBufferSize, kBufferAlignment); - *size = kBufferSize; - *buffer = reinterpret_cast(p); - *maxNumRecords = 0; +void LogIfError(const Status& status) { + if (status.ok()) { + return; + } + LOG(ERROR) << status.error_message(); } -void CUPTIManager::InternalBufferCompleted(CUcontext ctx, uint32_t streamId, - uint8_t *buffer, size_t size, - size_t validSize) { - VLOG(2) << "BufferCompleted"; - CUptiResult status; - CUpti_Activity *record = nullptr; - mutex_lock l(mu_); // Hold mu_ while using client_. 
- if (client_ && validSize > 0) { - do { - status = - cupti_wrapper_->ActivityGetNextRecord(buffer, validSize, &record); - if (status == CUPTI_SUCCESS) { - client_->ActivityCallback(*record); - } else { - break; - } - } while (1); - - // report any records dropped from the queue - size_t dropped; - CUPTI_CALL(ActivityGetNumDroppedRecords(ctx, streamId, &dropped)); - if (dropped != 0) { - LOG(WARNING) << "Dropped " << dropped << " activity records"; +bool IsAscii(string& str) { + for (auto& ch : str) { + if (!absl::ascii_isascii(ch)) { + return false; } } - port::AlignedFree(buffer); + return true; } -CUPTIManager *GetCUPTIManager() { - static CUPTIManager *manager = CUPTIManager::Create(); - return manager; +struct KernelRecord { + const char* kernel_name; + // TODO(csigg): cuStreamGetCtx introduced in CUDA 9.2 would allow us to only + // record the stream and infer the context during collection. + CUcontext context; + CUstream stream; + CUevent start_event; + CUevent stop_event; + const std::string* annotation; +}; + +struct MemcpyRecord { + CUmemorytype src_type; + CUmemorytype dst_type; + size_t size_bytes; + CUcontext context; + CUstream stream; + CUevent start_event; + CUevent stop_event; + const std::string* annotation; +}; + +Status CreateAndRecordEvent(CUevent* event, CUstream stream) { + TF_RETURN_IF_ERROR(ToStatus(cuEventCreate(event, CU_EVENT_DEFAULT))); + return ToStatus(cuEventRecord(*event, stream)); } -#ifdef _MSC_VER -#define __thread __declspec(thread) -#endif - -// TODO(pbar) Move this to platform specific header file? -// Static thread local variable for POD types. -#define TF_STATIC_THREAD_LOCAL_POD(_Type_, _var_) \ - static __thread _Type_ s_obj_##_var_; \ - namespace { \ - class ThreadLocal_##_var_ { \ - public: \ - ThreadLocal_##_var_() {} \ - void Init() {} \ - inline _Type_ *pointer() const { return &s_obj_##_var_; } \ - inline _Type_ *safe_pointer() const { return &s_obj_##_var_; } \ - _Type_ &get() const { return s_obj_##_var_; } \ - bool is_native_tls() const { return true; } \ - \ - private: \ - TF_DISALLOW_COPY_AND_ASSIGN(ThreadLocal_##_var_); \ - } _var_; \ - } // namespace - // Thread-local state recording the most recent annotation (if any). // When non-null, this points to a string in the active annotation // of the current thread. The annotation is guaranteed to remain live // for the duration of the CUPTI API callback. -TF_STATIC_THREAD_LOCAL_POD(const char *, tls_current_annotation); +static thread_local const char* tls_current_annotation; + +// Stores a series of kernel and memcpy records. +class CudaEventRecorder { + public: + // Registers the start of a kernel launch. The returned index should be passed + // to StopKernel() after the kernel launch has completed. + size_t StartKernel(const char* kernel_name, CUcontext context, + CUstream stream) { + KernelRecord record = {kernel_name, context, stream}; + LogIfError(CreateAndRecordEvent(&record.start_event, stream)); + mutex_lock lock(mutex_); + if (tls_current_annotation) { + record.annotation = &*annotations_.emplace(tls_current_annotation).first; + } + kernel_records_.push_back(record); + return kernel_records_.size() - 1; + } + void StopKernel(size_t index) { + mutex_lock lock(mutex_); + auto& record = kernel_records_[index]; + LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream)); + } + + // Registers the start of a copy operation. The returned index should be + // passed to StopMemcpy() after the kernel launch has completed. 
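  // As with kernels, the driver-API enter callback below stores the returned
  // index in cbdata.correlationData and the exit callback passes it back to
  // StopMemcpy(), so the start/stop events bracket each copy on its stream.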
+ size_t StartMemcpy(CUmemorytype src_type, CUmemorytype dst_type, + size_t size_bytes, CUcontext context, CUstream stream) { + MemcpyRecord record = {src_type, dst_type, size_bytes, context, stream}; + LogIfError(CreateAndRecordEvent(&record.start_event, stream)); + mutex_lock lock(mutex_); + if (tls_current_annotation) { + record.annotation = &*annotations_.emplace(tls_current_annotation).first; + } + memcpy_records_.push_back(record); + return memcpy_records_.size() - 1; + } + void StopMemcpy(size_t index) { + mutex_lock lock(mutex_); + auto& record = memcpy_records_[index]; + LogIfError(CreateAndRecordEvent(&record.stop_event, record.stream)); + } + + std::vector ConsumeKernelRecords() { + mutex_lock lock(mutex_); + return std::move(kernel_records_); + } + std::vector ConsumeMemcpyRecords() { + mutex_lock lock(mutex_); + return std::move(memcpy_records_); + } + + private: + mutex mutex_; + std::unordered_set annotations_ GUARDED_BY(mutex_); + std::vector kernel_records_ GUARDED_BY(mutex_); + std::vector memcpy_records_ GUARDED_BY(mutex_); +}; + +// Instances register callbacks with CUPTI to notify the event recorder before +// and after kernel launches and memory copies. +class CuptiCallbackHook { + public: + CuptiCallbackHook() : subscriber_(nullptr) {} + + Status Enable(CudaEventRecorder* recorder) { + TF_RETURN_IF_ERROR( + ToStatus(cuptiSubscribe(&subscriber_, &CuptiCallback, recorder))); + for (auto cbid : {CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel, + CUPTI_DRIVER_TRACE_CBID_cuMemcpy, + CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync, + CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2, + CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2, + CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2, + CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2, + CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2, + CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2}) { + TF_RETURN_IF_ERROR(ToStatus(cuptiEnableCallback( + /*enable=*/1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid))); + } + return Status::OK(); + } + + ~CuptiCallbackHook() { LogIfError(ToStatus(cuptiUnsubscribe(subscriber_))); } + + private: + static void CUPTIAPI CuptiCallback(void* userdata, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid, + const void* cbdata) { + auto recorder = static_cast(userdata); + auto data = static_cast(cbdata); + DCHECK_EQ(domain, CUPTI_CB_DOMAIN_DRIVER_API); + + if (data->callbackSite == CUPTI_API_ENTER) { + DriverApiEnterCallback(cbid, *data, recorder); + } else { + DriverApiExitCallback(cbid, *data, recorder); + } + } + + static CUmemorytype GetMemoryType(CUdeviceptr ptr) { + CUmemorytype mem_type; + auto status = + cuPointerGetAttribute(&mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, ptr); + if (status == CUDA_ERROR_INVALID_VALUE) { + // Pointer not registered with CUDA, must be host memory. 
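      // cuPointerGetAttribute returns CUDA_ERROR_INVALID_VALUE for pointers the
      // CUDA driver does not know about (e.g. plain malloc'd buffers), so such
      // pointers are classified as host memory below rather than reported as an
      // error.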
+ return CU_MEMORYTYPE_HOST; + } + LogIfError(ToStatus(status)); + return mem_type; + } + + template + static void StartMemcpy(CUmemorytype src_type, CUmemorytype dst_type, + const CUpti_CallbackData& cbdata, + CudaEventRecorder* recorder) { + auto params = static_cast(cbdata.functionParams); + *cbdata.correlationData = recorder->StartMemcpy( + src_type, dst_type, params->ByteCount, cbdata.context, nullptr); + } + template + static void StartMemcpyAsync(CUmemorytype src_type, CUmemorytype dst_type, + const CUpti_CallbackData& cbdata, + CudaEventRecorder* recorder) { + auto params = static_cast(cbdata.functionParams); + *cbdata.correlationData = recorder->StartMemcpy( + src_type, dst_type, params->ByteCount, cbdata.context, params->hStream); + } + + static void DriverApiEnterCallback(CUpti_CallbackId cbid, + const CUpti_CallbackData& cbdata, + CudaEventRecorder* recorder) { + switch (cbid) { + case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: { + DCHECK_NE(cbdata.symbolName, nullptr); + auto params = + static_cast(cbdata.functionParams); + *cbdata.correlationData = recorder->StartKernel( + cbdata.symbolName, cbdata.context, params->hStream); + return; + } + + case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: { + auto params = + static_cast(cbdata.functionParams); + return StartMemcpy(GetMemoryType(params->src), + GetMemoryType(params->dst), cbdata, + recorder); + } + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: { + auto params = + static_cast(cbdata.functionParams); + return StartMemcpyAsync( + GetMemoryType(params->src), GetMemoryType(params->dst), cbdata, + recorder); + } + + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2: + return StartMemcpy( + CU_MEMORYTYPE_HOST, CU_MEMORYTYPE_DEVICE, cbdata, recorder); + + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2: + return StartMemcpyAsync( + CU_MEMORYTYPE_HOST, CU_MEMORYTYPE_DEVICE, cbdata, recorder); + + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2: + return StartMemcpy( + CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_HOST, cbdata, recorder); + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2: + return StartMemcpyAsync( + CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_HOST, cbdata, recorder); + + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2: + return StartMemcpy( + CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_DEVICE, cbdata, recorder); + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2: + return StartMemcpyAsync( + CU_MEMORYTYPE_DEVICE, CU_MEMORYTYPE_DEVICE, cbdata, recorder); + + default: + LOG(ERROR) << "Unexpected callback id: " << cbid; + } + } + + static void DriverApiExitCallback(CUpti_CallbackId cbid, + const CUpti_CallbackData& cbdata, + CudaEventRecorder* recorder) { + switch (cbid) { + case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: + recorder->StopKernel(*cbdata.correlationData); + break; + case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyAsync: + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2: + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2: + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2: + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2: + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2: + case CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2: + recorder->StopMemcpy(*cbdata.correlationData); + break; + default: + LOG(ERROR) << "Unexpected callback id: " << cbid; + } + } + + CUpti_SubscriberHandle subscriber_; +}; class TraceCollectorImpl : public tracing::TraceCollector { public: - class ActivityHandle : public Handle { - public: - ActivityHandle(string &&name, int level) - : trace_me_(std::move(name), level) {} - - private: - profiler::TraceMe 
trace_me_; - }; - TraceCollectorImpl() { tracing::SetTraceCollector(this); } + TraceCollectorImpl() : active_trace_session_(false) { + tracing::SetTraceCollector(this); + } ~TraceCollectorImpl() override { DCHECK(!active_trace_session_) - << "Unexpected active trace session detected. "; + << "Unexpected active trace session detected."; } // Note the method can be called after a call to Stop(). virtual std::unique_ptr CreateAnnotationHandle( StringPiece name_part1, StringPiece name_part2) const { struct Impl : public tracing::TraceCollector::Handle { - string annotation; - explicit Impl(string &&name_scope) : annotation(name_scope) { + std::string annotation; + explicit Impl(std::string&& name_scope) : annotation(name_scope) { VLOG(2) << "CreateAnnotationHandle " << annotation; // Remember the most recent ScopedAnnotation for each thread. - tls_current_annotation.get() = annotation.c_str(); + tls_current_annotation = annotation.c_str(); } - ~Impl() override { tls_current_annotation.get() = nullptr; } + ~Impl() override { tls_current_annotation = nullptr; } }; return absl::make_unique(ConcatenateNames(name_part1, name_part2)); } - virtual std::unique_ptr CreateActivityHandle( - StringPiece name_part1, StringPiece name_part2, bool is_expensive) const { - if (!IsEnabledForActivities(is_expensive)) { - return nullptr; - } - return absl::make_unique( - ConcatenateNames(name_part1, name_part2), GetLevel(is_expensive)); - } - bool IsEnabledForAnnotations() const override { return active_trace_session_.load(std::memory_order_relaxed); } - bool IsEnabledForActivities(bool is_expensive) const override { - return profiler::TraceMeRecorder::Active(GetLevel(is_expensive)); - } - void Start() { DCHECK(!active_trace_session_) - << "Unexpected active trace session detected. "; + << "Unexpected active trace session detected."; active_trace_session_ = true; } @@ -360,366 +360,370 @@ class TraceCollectorImpl : public tracing::TraceCollector { } private: - static int GetLevel(bool is_expensive) { - return profiler::GetTFTraceMeLevel(is_expensive); - } - std::atomic active_trace_session_; }; -TraceCollectorImpl *GlobalDefaultTraceCollector() { - static auto *instance = new TraceCollectorImpl(); +TraceCollectorImpl* GlobalDefaultTraceCollector() { + static auto* instance = new TraceCollectorImpl(); return instance; } -class DeviceTracerImpl : public DeviceTracer, public CUPTIClient { +// 'DeviceTracer' is an interface for collecting low-level execution timings +// of hardware accelerator (e.g. GPU) computation and DMA transfers. +class DeviceTracer : public profiler::ProfilerInterface { public: - DeviceTracerImpl(CUPTIManager *cupti_manager); - ~DeviceTracerImpl() override; + DeviceTracer(); + ~DeviceTracer() override; - // DeviceTracer interface: + // ProfilerInterface interface: Status Start() override; Status Stop() override; - Status Collect(StepStatsCollector *collector) override; - - protected: - // This callback is used exclusively by CUPTIManager. - friend class CUPTIManager; - void ActivityCallback(const CUpti_Activity &activity) override; + // Collect trace results. Results are added to the specified + // StepStatsCollector. Does not clear any existing stats. + // It is an error to call 'Collect' while a trace is running. + Status CollectData(RunMetadata* run_metadata) override; private: - // Internal struct to record kernel launches. 
- struct KernelRecord { - uint64_t start_timestamp; - uint64_t end_timestamp; - uint32 device_id; - uint32 stream_id; - uint32 correlation_id; - }; - // Internal struct to record memcpy operations. - struct MemcpyRecord { - uint64_t start_timestamp; - uint64_t end_timestamp; - uint32 device_id; - uint32 stream_id; - uint32 correlation_id; - uint8 copyKind; - uint8 srcKind; - uint8 dstKind; - uint64 bytes; - }; - - // This is the subscriber callback which is invoked directly by CUPTI. - // The 'userdata' argument will be a pointer to the active 'DeviceTracerImpl'. - static void CUPTIAPI ApiCallback(void *userdata, CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, const void *cbdata); - - // Records the mapping between correlation ID and kernel name. - void AddCorrelationId(uint32 correlation_id, const string &name); - - // Returns the current system time in microseconds. - inline int64 NowInUsec() { return Env::Default()->NowMicros(); } - - CUPTIManager *cupti_manager_; - std::unique_ptr cupti_wrapper_; - CUpti_SubscriberHandle subscriber_; - - mutex trace_mu_; - static constexpr size_t kMaxRecords = 1024 * 1024; - std::map correlations_ GUARDED_BY(trace_mu_); - std::vector kernel_records_ GUARDED_BY(trace_mu_); - std::vector memcpy_records_ GUARDED_BY(trace_mu_); + std::unique_ptr recorder_; + std::unique_ptr cupti_hook_; mutex mu_; bool enabled_ GUARDED_BY(mu_); - int64 start_walltime_us_ GUARDED_BY(mu_); - int64 end_walltime_us_ GUARDED_BY(mu_); - uint64_t start_timestamp_ GUARDED_BY(mu_); - uint64_t end_timestamp_ GUARDED_BY(mu_); - std::unique_ptr host_tracer_ GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(DeviceTracerImpl); }; -DeviceTracerImpl::DeviceTracerImpl(CUPTIManager *cupti_manager) - : cupti_manager_(cupti_manager) { +DeviceTracer::DeviceTracer() + : recorder_(new CudaEventRecorder()), enabled_(false) { VLOG(1) << "DeviceTracer created."; - cupti_wrapper_.reset(new perftools::gputools::profiler::CuptiWrapper()); - host_tracer_ = profiler::cpu::HostTracer::Create(2); - enabled_ = false; } -DeviceTracerImpl::~DeviceTracerImpl() { +DeviceTracer::~DeviceTracer() { // Unregister the CUPTI callbacks if needed to prevent them from accessing // freed memory. Stop().IgnoreError(); } -Status DeviceTracerImpl::Start() { +Status DeviceTracer::Start() { VLOG(1) << "DeviceTracer::Start"; mutex_lock l(mu_); if (enabled_) { return errors::FailedPrecondition("DeviceTracer is already enabled."); } - // There can only be one CUPTI subscriber. If we can't create one then - // there is another trace in progress (possibly by external code). - CUptiResult ret; - ret = cupti_wrapper_->Subscribe( - &subscriber_, static_cast(ApiCallback), this); - if (ret == CUPTI_ERROR_MAX_LIMIT_REACHED) { - return errors::Unavailable("CUPTI subcriber limit reached."); - } else if (ret != CUPTI_SUCCESS) { - return errors::Internal("Failed to create CUPTI subcriber."); - } + cupti_hook_.reset(new CuptiCallbackHook()); + TF_RETURN_IF_ERROR(cupti_hook_->Enable(recorder_.get())); // Register as a TraceEngine to receive ScopedAnnotations. GlobalDefaultTraceCollector()->Start(); - // Intercept launch and memcpy calls to capture the Op name annotation. - // TODO(pbar) Add callbacks for memcpy variants. 
- CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_, - CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel)); - CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_, - CUPTI_CB_DOMAIN_RUNTIME_API, - CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020)); - CUPTI_CALL(EnableCallback( - /*enable=*/1, subscriber_, CUPTI_CB_DOMAIN_RUNTIME_API, - CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020)); - - CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_, - CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2)); - CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_, - CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2)); - CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_, - CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2)); - CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_, - CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2)); - CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_, - CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2)); - CUPTI_CALL(EnableCallback(/*enable=*/1, subscriber_, - CUPTI_CB_DOMAIN_DRIVER_API, - CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2)); - - TF_RETURN_IF_ERROR(cupti_manager_->EnableTrace(this)); - - CUPTI_CALL(GetTimestamp(&start_timestamp_)); - start_walltime_us_ = NowInUsec(); - host_tracer_->Start().IgnoreError(); enabled_ = true; return Status::OK(); } -Status DeviceTracerImpl::Stop() { +Status DeviceTracer::Stop() { VLOG(1) << "DeviceTracer::Stop"; mutex_lock l(mu_); if (!enabled_) { return Status::OK(); } - CUPTI_CALL(Unsubscribe(subscriber_)); + cupti_hook_.reset(); GlobalDefaultTraceCollector()->Stop(); - TF_RETURN_IF_ERROR(cupti_manager_->DisableTrace()); - end_walltime_us_ = NowInUsec(); - CUPTI_CALL(GetTimestamp(&end_timestamp_)); enabled_ = false; - host_tracer_->Stop().IgnoreError(); return Status::OK(); } -void DeviceTracerImpl::AddCorrelationId(uint32 correlation_id, - const string &name) { - VLOG(2) << correlation_id << " : " << name; - mutex_lock l(trace_mu_); - if (correlations_.size() >= kMaxRecords) return; - correlations_.emplace(correlation_id, name); -} +class CudaEventCollector { + struct DeviceInfo { + int ordinal; + std::string name; + int num_contexts; + }; -/*static*/ void DeviceTracerImpl::ApiCallback(void *userdata, - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const void *cbdata) { - auto *cbInfo = reinterpret_cast(cbdata); - DeviceTracerImpl *tracer = reinterpret_cast(userdata); - VLOG(2) << "ApiCallback " << domain << ":" << cbid - << " func: " << cbInfo->functionName; + struct ContextInfo { + int index; + const DeviceInfo* dev_info; + int num_streams; + CUevent end_event; + }; - // API callbacks are invoked synchronously on the thread making the - // CUDA API call. If this pointer is non-null then the ScopedAnnotation - // must be valid. - const char *tls_annotation = tls_current_annotation.get(); + struct StreamInfo { + std::string name; + int index; // 0 is reserved for null stream. + const ContextInfo* ctx_info; + }; - if ((domain == CUPTI_CB_DOMAIN_DRIVER_API) && - (cbid == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel)) { - if (cbInfo->callbackSite == CUPTI_API_ENTER) { - auto *params = reinterpret_cast( - cbInfo->functionParams); - if (VLOG_IS_ON(2)) { - VLOG(2) << "LAUNCH stream " << params->hStream << " correllation " - << cbInfo->correlationId << " kernel " << cbInfo->symbolName; - } - const string annotation = - tls_annotation ? 
tls_annotation : cbInfo->symbolName; - tracer->AddCorrelationId(cbInfo->correlationId, annotation); - } - } else if ((domain == CUPTI_CB_DOMAIN_RUNTIME_API) && - (cbid == CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020 || - cbid == CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020)) { - if (cbInfo->callbackSite == CUPTI_API_ENTER) { - if (VLOG_IS_ON(2)) { - auto *funcParams = reinterpret_cast( - cbInfo->functionParams); - size_t count = funcParams->count; - enum cudaMemcpyKind kind = funcParams->kind; - VLOG(2) << "MEMCPY count " << count << " kind " << kind; - } - if (tls_annotation) { - const string annotation = tls_annotation; - tracer->AddCorrelationId(cbInfo->correlationId, annotation); - } - } - } else if ((domain == CUPTI_CB_DOMAIN_DRIVER_API) && - (cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoD_v2 || - cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoH_v2 || - cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoD_v2 || - cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyHtoDAsync_v2 || - cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoHAsync_v2 || - cbid == CUPTI_DRIVER_TRACE_CBID_cuMemcpyDtoDAsync_v2)) { - if (cbInfo->callbackSite == CUPTI_API_EXIT && tls_annotation) { - const string annotation = tls_annotation; - tracer->AddCorrelationId(cbInfo->correlationId, annotation); - } - } else { - VLOG(1) << "Unhandled API Callback for " << domain << " " << cbid; + // Include context in key to distinguish null streams. + using StreamKey = std::pair; + + CudaEventCollector(CudaEventRecorder* recorder, StepStatsCollector* collector) + : recorder_(recorder), collector_(collector) { + DCHECK(recorder != nullptr); + DCHECK(collector != nullptr); } -} -void DeviceTracerImpl::ActivityCallback(const CUpti_Activity &record) { - VLOG(2) << "ActivityCallback " << record.kind; - mutex_lock l(trace_mu_); - switch (record.kind) { - case CUPTI_ACTIVITY_KIND_MEMCPY: { - if (memcpy_records_.size() >= kMaxRecords) return; - auto *memcpy = reinterpret_cast(&record); - memcpy_records_.push_back(MemcpyRecord{ - memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId, - memcpy->correlationId, memcpy->copyKind, memcpy->srcKind, - memcpy->dstKind, memcpy->bytes}); - break; + // Populates device_infos_ from all devices. + Status InitializeDeviceInfos() { + int count; + TF_RETURN_IF_ERROR(ToStatus(cuDeviceGetCount(&count))); + for (int ordinal = 0; ordinal < count; ++ordinal) { + CUdevice device; + TF_RETURN_IF_ERROR(ToStatus(cuDeviceGet(&device, ordinal))); + char name[100]; + TF_RETURN_IF_ERROR(ToStatus(cuDeviceGetName(name, sizeof(name), device))); + device_infos_[device] = {ordinal, name}; } - case CUPTI_ACTIVITY_KIND_MEMCPY2: { - if (memcpy_records_.size() >= kMaxRecords) return; - auto *memcpy = reinterpret_cast(&record); - memcpy_records_.push_back(MemcpyRecord{ - memcpy->start, memcpy->end, memcpy->deviceId, memcpy->streamId, - memcpy->correlationId, memcpy->copyKind, memcpy->srcKind, - memcpy->dstKind, memcpy->bytes}); - break; - } - case CUPTI_ACTIVITY_KIND_KERNEL: - case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { - if (kernel_records_.size() >= kMaxRecords) return; - auto *kernel = reinterpret_cast(&record); - kernel_records_.push_back(KernelRecord{kernel->start, kernel->end, - kernel->deviceId, kernel->streamId, - kernel->correlationId}); - break; - } - default: - VLOG(1) << "ActivityCallback unhandled kind"; - break; + return Status::OK(); } -} -Status DeviceTracerImpl::Collect(StepStatsCollector *collector) { + // Returns element from context_infos_, adding it if not yet present. 
+ Status GetContextInfo(CUcontext context, ContextInfo** ctx_info_ptr) { + auto it = context_infos_.find(context); + + if (it == context_infos_.end()) { + TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(context))); + CUdevice device; + TF_RETURN_IF_ERROR(ToStatus(cuCtxGetDevice(&device))); + + auto& dev_info = device_infos_[device]; + ContextInfo ctx_info = {dev_info.num_contexts++, &dev_info}; + it = context_infos_.emplace(context, ctx_info).first; + } + + *ctx_info_ptr = &it->second; + return Status::OK(); + } + + // Adds element to stream_infos_ if not yet present. If present, clear name + // if it doesn't match parameter. + Status AddStreamInfo(CUcontext context, CUstream stream, + absl::string_view name) { + StreamKey key(context, stream); + auto it = stream_infos_.find(key); + if (it != stream_infos_.end()) { + if (it->second.name != name) { + it->second.name.clear(); // Stream with inconsistent names, clear it. + } + return Status::OK(); + } + + ContextInfo* ctx_info; + TF_RETURN_IF_ERROR(GetContextInfo(context, &ctx_info)); + int index = stream ? ++ctx_info->num_streams : 0; + StreamInfo stream_info = {static_cast(name), index, ctx_info}; + stream_infos_.emplace(key, stream_info); + return Status::OK(); + } + + // Returns string describing source and destination memory types. + static std::string GetMemcpyName(const MemcpyRecord& record) { + auto get_memory_type = [](CUmemorytype mem_type) { + switch (mem_type) { + case CU_MEMORYTYPE_HOST: + return 'H'; + case CU_MEMORYTYPE_DEVICE: + return 'D'; + case CU_MEMORYTYPE_ARRAY: + return 'A'; + case CU_MEMORYTYPE_UNIFIED: + return 'U'; + default: + LOG(ERROR) << "Unknown memory type: " << mem_type; + return '?'; + } + }; + return absl::StrFormat("Memcpy%cto%c", get_memory_type(record.src_type), + get_memory_type(record.dst_type)); + } + + // Returns time in microseconds between events recorded on the GPU. + static uint64_t GetElapsedTimeUs(CUevent start, CUevent stop) { + float elapsed_ms = 0.0f; + LogIfError(ToStatus(cuEventElapsedTime(&elapsed_ms, start, stop))); + return static_cast( + std::llroundf(1000 * std::max(elapsed_ms, 0.0f))); + } + + // Synchronizes all contexts. + Status Synchronize() const { + for (const auto& pair : context_infos_) { + TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first))); + TF_RETURN_IF_ERROR(ToStatus(cuCtxSynchronize())); + } + return Status::OK(); + } + + // Save stats to collector; + Status SaveStats(std::unique_ptr stats, + const StreamInfo& stream_info) const { + auto ctx_info = stream_info.ctx_info; + auto dev_info = ctx_info->dev_info; + // TODO(csigg): tfprof_node.cc, run_metadata_test.py, and timeline_test.py + // currently require this particular formatting. 
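// ----------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the KernelRecord and
// MemcpyRecord types consumed here are defined in the CudaEventRecorder,
// which is not shown in this hunk. The shape implied by how GetMemcpyName()
// and the SaveRecord() overloads use them is roughly the following (field
// names and types are inferred from this file and only indicative):
#include <cstddef>
#include <string>
#include <cuda.h>

namespace record_sketch {

struct KernelRecordSketch {
  CUevent start_event;            // recorded just before the launch
  CUevent stop_event;             // recorded just after the launch
  CUcontext context;
  CUstream stream;
  std::string kernel_name;
  const std::string* annotation;  // pointer-like/optional; set when a ScopedAnnotation was active
};

struct MemcpyRecordSketch {
  CUevent start_event;
  CUevent stop_event;
  CUcontext context;
  CUstream stream;
  CUmemorytype src_type;          // host / device / array / unified
  CUmemorytype dst_type;
  size_t size_bytes;
  const std::string* annotation;
};

}  // namespace record_sketch
// ----------------------------------------------------------------------------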
+ collector_->Save( + absl::StrFormat("/device:GPU:%d/stream:all", dev_info->ordinal), + new NodeExecStats(*stats)); + auto name = absl::StrFormat("/gpu:%d (%s)/context#%d/", dev_info->ordinal, + dev_info->name, ctx_info->index); + if (stream_info.index) { + absl::StrAppend(&name, "stream#", std::to_string(stream_info.index)); + } else { + absl::StrAppend(&name, "null stream"); + } + if (!stream_info.name.empty()) { + absl::StrAppend(&name, ":", stream_info.name); + } + collector_->Save(name, stats.release()); + return Status::OK(); + } + + Status SaveRecord(const KernelRecord& record) const { + if (!record.start_event || !record.stop_event) { + return Status::OK(); + } + const auto& stream_info = + stream_infos_.at(StreamKey(record.context, record.stream)); + auto start_us = + GetElapsedTimeUs(record.start_event, stream_info.ctx_info->end_event); + auto elapsed_us = GetElapsedTimeUs(record.start_event, record.stop_event); + + auto stats = absl::make_unique(); + std::string node_name = record.kernel_name; + // Sometimes CUPTI returns invalid characters. See b/129892466. + if (!IsAscii(node_name)) { + node_name = ""; + } + if (record.annotation) { + node_name = absl::StrCat(*record.annotation, "::", node_name); + } + stats->set_node_name(node_name); + // TODO(csigg): Report grid size? + std::string node_label; + stats->set_timeline_label(node_label); + stats->set_all_start_micros(end_walltime_us_ - start_us); + stats->set_op_end_rel_micros(elapsed_us); + stats->set_all_end_rel_micros(elapsed_us); + return SaveStats(std::move(stats), stream_info); + } + + Status SaveRecord(const MemcpyRecord& record) const { + if (!record.start_event || !record.stop_event) { + return Status::OK(); + } + const auto& stream_info = + stream_infos_.at(StreamKey(record.context, record.stream)); + auto start_us = + GetElapsedTimeUs(record.start_event, stream_info.ctx_info->end_event); + auto elapsed_us = GetElapsedTimeUs(record.start_event, record.stop_event); + + auto stats = absl::make_unique(); + std::string node_name = GetMemcpyName(record); + // Sometimes CUPTI returns invalid characters. See b/129892466. + if (!IsAscii(node_name)) { + node_name = ""; + } + if (record.annotation) { + node_name = absl::StrCat(*record.annotation, "::", node_name); + } + stats->set_node_name(node_name); + // TODO(csigg): Show label in Chrome trace viewer. + std::string node_label = absl::StrFormat("%d bytes", record.size_bytes); + stats->set_timeline_label(node_label); + stats->set_all_start_micros(end_walltime_us_ - start_us); + stats->set_op_end_rel_micros(elapsed_us); + stats->set_all_end_rel_micros(elapsed_us); + return SaveStats(std::move(stats), stream_info); + } + + Status Collect() { + TF_RETURN_IF_ERROR(InitializeDeviceInfos()); + + auto kernel_records = recorder_->ConsumeKernelRecords(); + auto memcpy_records = recorder_->ConsumeMemcpyRecords(); + LOG(INFO) << "Collecting " << kernel_records.size() << " kernel records, " + << memcpy_records.size() << " memcpy records."; + + // Gather all profiled streams and contexts. + for (const auto& record : kernel_records) { + TF_RETURN_IF_ERROR( + AddStreamInfo(record.context, record.stream, "Kernel")); + } + for (const auto& record : memcpy_records) { + TF_RETURN_IF_ERROR( + AddStreamInfo(record.context, record.stream, GetMemcpyName(record))); + } + + // Synchronize all contexts, record end events, synchronize again. 
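// ----------------------------------------------------------------------------
// Illustrative sketch (stand-alone, hypothetical names; not part of this
// patch): the collector never reads raw GPU timestamps. Collect() records one
// extra end event per context next to a CPU wall-clock reading, and each
// record is then placed on the host timeline by measuring how far *before*
// that end event it started:
//
//   start_walltime_us = end_walltime_us - elapsed(start_event, ctx_end_event)
//   duration_us       = elapsed(start_event, stop_event)
//
// cuEventElapsedTime() reports float milliseconds between two recorded events.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cuda.h>

namespace timing_sketch {

inline uint64_t ElapsedUs(CUevent start, CUevent stop) {
  float ms = 0.0f;
  cuEventElapsedTime(&ms, start, stop);  // error handling omitted in the sketch
  return static_cast<uint64_t>(std::llround(1000.0f * std::max(ms, 0.0f)));
}

struct PlacedRecord {
  uint64_t start_us;     // absolute host wall-clock time, microseconds
  uint64_t duration_us;  // time spent on the GPU, microseconds
};

inline PlacedRecord Place(CUevent start_event, CUevent stop_event,
                          CUevent ctx_end_event, uint64_t end_walltime_us) {
  return {end_walltime_us - ElapsedUs(start_event, ctx_end_event),
          ElapsedUs(start_event, stop_event)};
}

}  // namespace timing_sketch
// ----------------------------------------------------------------------------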
+ TF_RETURN_IF_ERROR(Synchronize()); + for (auto& pair : context_infos_) { + TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(pair.first))); + TF_RETURN_IF_ERROR(CreateAndRecordEvent(&pair.second.end_event, nullptr)); + } + TF_RETURN_IF_ERROR(Synchronize()); + end_walltime_us_ = Env::Default()->NowMicros(); + + for (const auto& record : kernel_records) { + TF_RETURN_IF_ERROR(SaveRecord(record)); + } + for (const auto& record : memcpy_records) { + TF_RETURN_IF_ERROR(SaveRecord(record)); + } + + return Status::OK(); + } + + public: + // Consumes the records in recorder and saves them to the collector. + static Status Collect(CudaEventRecorder* recorder, + StepStatsCollector* collector) { + CUcontext context; + TF_RETURN_IF_ERROR(ToStatus(cuCtxGetCurrent(&context))); + auto status = CudaEventCollector(recorder, collector).Collect(); + TF_RETURN_IF_ERROR(ToStatus(cuCtxSetCurrent(context))); + return status; + } + + private: + CudaEventRecorder* recorder_; + StepStatsCollector* collector_; + + absl::node_hash_map device_infos_; + absl::node_hash_map context_infos_; + absl::flat_hash_map> stream_infos_; + int64 end_walltime_us_; +}; + +Status DeviceTracer::CollectData(RunMetadata* run_metadata) { mutex_lock l(mu_); if (enabled_) { return errors::FailedPrecondition("DeviceTracer is still enabled."); } - // TODO(pbar) Handle device IDs and prefix properly. - const string prefix = ""; - const int id = 0; - const string stream_device = - strings::StrCat(prefix, "/device:GPU:", id, "/stream:"); - const string memcpy_device = - strings::StrCat(prefix, "/device:GPU:", id, "/memcpy"); - - mutex_lock l2(trace_mu_); - for (const auto &rec : kernel_records_) { - auto it = correlations_.find(rec.correlation_id); - const string name = (it != correlations_.cend()) ? it->second : "unknown"; - NodeExecStats *ns = new NodeExecStats; - ns->set_all_start_micros(start_walltime_us_ + - ((rec.start_timestamp - start_timestamp_) / 1000)); - ns->set_op_start_rel_micros(0); - auto elapsed_us = - std::max((rec.end_timestamp - rec.start_timestamp) / 1000, 1); - ns->set_op_end_rel_micros(elapsed_us); - ns->set_all_end_rel_micros(elapsed_us); - ns->set_node_name(name); - // TODO(pbar) Generate details based on the kernel activity record. - // ns->set_timeline_label(details); - auto nscopy = new NodeExecStats; - *nscopy = *ns; - collector->Save(strings::StrCat(stream_device, "all"), ns); - collector->Save(strings::StrCat(stream_device, rec.stream_id), nscopy); - } - for (const auto &rec : memcpy_records_) { - auto it = correlations_.find(rec.correlation_id); - const string name = (it != correlations_.cend()) ? 
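// ----------------------------------------------------------------------------
// Illustrative sketch (hypothetical helper; not part of this patch): with the
// rewrite above, DeviceTracer is driven through the generic
// profiler::ProfilerInterface lifecycle instead of the old
// Collect(StepStatsCollector*) call. Assuming the TensorFlow declarations used
// in this file are visible, the intended call sequence looks like this:
#include <functional>
#include <memory>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"  // RunMetadata

namespace tensorflow {

Status ProfileOnce(const std::function<void()>& run_computation,
                   RunMetadata* run_metadata) {
  // Returns nullptr on builds without GPU tracing support.
  auto tracer = CreateDeviceTracer(/*context=*/nullptr);
  if (!tracer) return Status::OK();
  TF_RETURN_IF_ERROR(tracer->Start());
  run_computation();  // run the kernels to be profiled
  TF_RETURN_IF_ERROR(tracer->Stop());
  // Device timings end up in run_metadata->step_stats().
  return tracer->CollectData(run_metadata);
}

}  // namespace tensorflow
// ----------------------------------------------------------------------------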
it->second : "unknown"; - NodeExecStats *ns = new NodeExecStats; - ns->set_all_start_micros(start_walltime_us_ + - ((rec.start_timestamp - start_timestamp_) / 1000)); - ns->set_op_start_rel_micros(0); - auto elapsed_us = - std::max((rec.end_timestamp - rec.start_timestamp) / 1000, 1); - ns->set_op_end_rel_micros(elapsed_us); - ns->set_all_end_rel_micros(elapsed_us); - auto copyKind = static_cast(rec.copyKind); - auto srcKind = static_cast(rec.srcKind); - auto dstKind = static_cast(rec.dstKind); - const string details = strings::Printf( - "MEMCPY%s %llu bytes (%s to %s)", getMemcpyKindString(copyKind), - rec.bytes, getMemoryKindString(srcKind), getMemoryKindString(dstKind)); - ns->set_node_name( - strings::StrCat(name, ":MEMCPY", getMemcpyKindString(copyKind))); - ns->set_timeline_label(details); - auto nscopy = new NodeExecStats; - *nscopy = *ns; - collector->Save(memcpy_device, ns); - collector->Save(strings::StrCat(stream_device, rec.stream_id), nscopy); - } - - host_tracer_->CollectDataToCollector(collector).IgnoreError(); + StepStatsCollector step_stats_collector(run_metadata->mutable_step_stats()); + TF_RETURN_IF_ERROR( + CudaEventCollector::Collect(recorder_.get(), &step_stats_collector)); + step_stats_collector.Finalize(); return Status::OK(); } +} // namespace -} // namespace devicetracer - -std::unique_ptr CreateDeviceTracer() { - devicetracer::CUPTIManager *cupti_manager = devicetracer::GetCUPTIManager(); - if (cupti_manager == nullptr) { +// Not in anonymous namespace for testing purposes. +std::unique_ptr CreateDeviceTracer( + const ProfilerContext*) { + auto status = cuInit(0); + if (status != CUDA_SUCCESS) { + LogIfError(ToStatus(status)); return nullptr; } - std::unique_ptr tracer( - new devicetracer::DeviceTracerImpl(cupti_manager)); - return tracer; + return absl::make_unique(); } -} // namespace tensorflow - -#else // GOOGLE_CUDA - -namespace tensorflow { - -std::unique_ptr CreateDeviceTracer() { return nullptr; } +auto register_device_tracer_factory = [] { + bool enable; + TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_OSS_GPU_PROFILER", true, &enable)); + if (enable) { + RegisterProfilerFactory(&CreateDeviceTracer); + } + return 0; +}(); } // namespace tensorflow - #endif // GOOGLE_CUDA diff --git a/tensorflow/core/platform/default/gpu/BUILD b/tensorflow/core/platform/default/gpu/BUILD deleted file mode 100644 index 3965c7d2ec6..00000000000 --- a/tensorflow/core/platform/default/gpu/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load( - "//tensorflow:tensorflow.bzl", - "tf_copts", - "tf_cuda_library", -) - -tf_cuda_library( - name = "cupti_wrapper", - srcs = [ - "cupti_wrapper.cc", - ], - hdrs = [ - "cupti_wrapper.h", - ], - copts = tf_copts(), - cuda_deps = [ - "//tensorflow/core:stream_executor", - "@local_config_cuda//cuda:cupti_headers", - ], - data = ["@local_config_cuda//cuda:cupti_dsos"], - visibility = ["//visibility:public"], -) diff --git a/tensorflow/core/platform/default/gpu/cupti_wrapper.cc b/tensorflow/core/platform/default/gpu/cupti_wrapper.cc deleted file mode 100644 index 481bbf9bae1..00000000000 --- a/tensorflow/core/platform/default/gpu/cupti_wrapper.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/platform/default/gpu/cupti_wrapper.h" - -#if GOOGLE_CUDA - -#include - -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/stream_executor.h" - -namespace perftools { -namespace gputools { -namespace profiler { - -namespace dynload { - -#define LIBCUPTI_WRAP(__name) \ - struct DynLoadShim__##__name { \ - static const char* kName; \ - using FuncPointerT = std::add_pointer::type; \ - template \ - CUptiResult operator()(Args... args) { \ - static auto fn = []() -> FuncPointerT { \ - auto handle_or = \ - stream_executor::internal::CachedDsoLoader::GetCuptiDsoHandle(); \ - if (!handle_or.ok()) return nullptr; \ - void* symbol; \ - stream_executor::port::Env::Default() \ - ->GetSymbolFromLibrary(handle_or.ValueOrDie(), kName, &symbol) \ - .IgnoreError(); \ - return reinterpret_cast(symbol); \ - }(); \ - if (fn == nullptr) return CUPTI_ERROR_UNKNOWN; \ - return fn(args...); \ - } \ - } __name; \ - const char* DynLoadShim__##__name::kName = #__name; - -LIBCUPTI_WRAP(cuptiActivityDisable); -LIBCUPTI_WRAP(cuptiActivityEnable); -LIBCUPTI_WRAP(cuptiActivityFlushAll); -LIBCUPTI_WRAP(cuptiActivityGetNextRecord); -LIBCUPTI_WRAP(cuptiActivityGetNumDroppedRecords); -LIBCUPTI_WRAP(cuptiActivityRegisterCallbacks); -LIBCUPTI_WRAP(cuptiGetTimestamp); -LIBCUPTI_WRAP(cuptiEnableCallback); -LIBCUPTI_WRAP(cuptiEnableDomain); -LIBCUPTI_WRAP(cuptiSubscribe); -LIBCUPTI_WRAP(cuptiUnsubscribe); - -} // namespace dynload - -CUptiResult CuptiWrapper::ActivityDisable(CUpti_ActivityKind kind) { - return dynload::cuptiActivityDisable(kind); -} - -CUptiResult CuptiWrapper::ActivityEnable(CUpti_ActivityKind kind) { - return dynload::cuptiActivityEnable(kind); -} - -CUptiResult CuptiWrapper::ActivityFlushAll(uint32_t flag) { - return dynload::cuptiActivityFlushAll(flag); -} - -CUptiResult CuptiWrapper::ActivityGetNextRecord(uint8_t* buffer, - size_t valid_buffer_size_bytes, - CUpti_Activity** record) { - return dynload::cuptiActivityGetNextRecord(buffer, valid_buffer_size_bytes, - record); -} - -CUptiResult CuptiWrapper::ActivityGetNumDroppedRecords(CUcontext context, - uint32_t stream_id, - size_t* dropped) { - return dynload::cuptiActivityGetNumDroppedRecords(context, stream_id, - dropped); -} - -CUptiResult CuptiWrapper::ActivityRegisterCallbacks( - CUpti_BuffersCallbackRequestFunc func_buffer_requested, - CUpti_BuffersCallbackCompleteFunc func_buffer_completed) { - return dynload::cuptiActivityRegisterCallbacks(func_buffer_requested, - func_buffer_completed); -} - -CUptiResult CuptiWrapper::GetTimestamp(uint64_t* timestamp) { - return dynload::cuptiGetTimestamp(timestamp); -} - -CUptiResult CuptiWrapper::EnableCallback(uint32_t enable, - CUpti_SubscriberHandle subscriber, - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid) { - return dynload::cuptiEnableCallback(enable, subscriber, domain, cbid); -} - -CUptiResult CuptiWrapper::EnableDomain(uint32_t enable, - CUpti_SubscriberHandle subscriber, - CUpti_CallbackDomain domain) { - return dynload::cuptiEnableDomain(enable, subscriber, domain); 
-} - -CUptiResult CuptiWrapper::Subscribe(CUpti_SubscriberHandle* subscriber, - CUpti_CallbackFunc callback, - void* userdata) { - return dynload::cuptiSubscribe(subscriber, callback, userdata); -} - -CUptiResult CuptiWrapper::Unsubscribe(CUpti_SubscriberHandle subscriber) { - return dynload::cuptiUnsubscribe(subscriber); -} - -} // namespace profiler -} // namespace gputools -} // namespace perftools - -#endif // GOOGLE_CUDA diff --git a/tensorflow/core/platform/default/gpu/cupti_wrapper.h b/tensorflow/core/platform/default/gpu/cupti_wrapper.h deleted file mode 100644 index e3ebe6ca1d0..00000000000 --- a/tensorflow/core/platform/default/gpu/cupti_wrapper.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ -#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ - -#if GOOGLE_CUDA - -#include -#include -#if defined(WIN32) -#include "extras/CUPTI/include/cupti.h" -#else -#include "cupti.h" -#endif -namespace perftools { -namespace gputools { -namespace profiler { - -// Wraps the CUPTI API so that we can dynamically load the library. -class CuptiWrapper { - public: - CuptiWrapper() {} - - // CUPTI activity API - CUptiResult ActivityDisable(CUpti_ActivityKind kind); - - CUptiResult ActivityEnable(CUpti_ActivityKind kind); - - CUptiResult ActivityFlushAll(uint32_t flag); - - CUptiResult ActivityGetNextRecord(uint8_t* buffer, - size_t valid_buffer_size_bytes, - CUpti_Activity** record); - - CUptiResult ActivityGetNumDroppedRecords(CUcontext context, - uint32_t stream_id, size_t* dropped); - - CUptiResult ActivityRegisterCallbacks( - CUpti_BuffersCallbackRequestFunc func_buffer_requested, - CUpti_BuffersCallbackCompleteFunc func_buffer_completed); - - CUptiResult GetDeviceId(CUcontext context, uint32_t* deviceId); - - CUptiResult GetTimestamp(uint64_t* timestamp); - - // CUPTI callback API - CUptiResult EnableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber, - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid); - - CUptiResult EnableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber, - CUpti_CallbackDomain domain); - - CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber, - CUpti_CallbackFunc callback, void* userdata); - - CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber); -}; - -} // namespace profiler -} // namespace gputools -} // namespace perftools - -#endif // GOOGLE_CUDA - -#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_ diff --git a/tensorflow/core/platform/default/human_readable_json.cc b/tensorflow/core/platform/default/human_readable_json.cc index 977ff1272ea..6f0b97d215f 100644 --- a/tensorflow/core/platform/default/human_readable_json.cc +++ b/tensorflow/core/platform/default/human_readable_json.cc @@ -20,8 +20,8 @@ limitations under the License. 
namespace tensorflow { -Status ProtoToHumanReadableJson(const protobuf::Message& proto, - string* result) { +Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result, + bool ignore_accuracy_loss) { #ifdef TENSORFLOW_LITE_PROTOS *result = "[human readable output not available on Android]"; return Status::OK(); diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java b/tensorflow/core/platform/default/monitoring.cc similarity index 77% rename from tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java rename to tensorflow/core/platform/default/monitoring.cc index 211d7e66fb2..71ece3e3c14 100644 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java +++ b/tensorflow/core/platform/default/monitoring.cc @@ -13,11 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -package org.tensorflow.demo; +#include "tensorflow/core/platform/monitoring.h" -import java.util.List; -import org.tensorflow.demo.Classifier.Recognition; +namespace tensorflow { +namespace monitoring { -public interface ResultsView { - public void setResults(final List results); -} +void StartExporter() {} + +void ExportMetrics() {} + +} // namespace monitoring +} // namespace tensorflow diff --git a/tensorflow/core/platform/default/stacktrace.h b/tensorflow/core/platform/default/stacktrace.h index b64bc159710..808ef25c430 100644 --- a/tensorflow/core/platform/default/stacktrace.h +++ b/tensorflow/core/platform/default/stacktrace.h @@ -16,7 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STACKTRACE_H_ #define TENSORFLOW_CORE_PLATFORM_DEFAULT_STACKTRACE_H_ +// clang-format off #include "tensorflow/core/platform/platform.h" +// clang-format on + #if !defined(IS_MOBILE_PLATFORM) && !defined(PLATFORM_WINDOWS) && \ defined(PLATFORM_POSIX) && (defined(__clang__) || defined(__GNUC__)) #define TF_GENERATE_BACKTRACE diff --git a/tensorflow/core/platform/device_tracer.h b/tensorflow/core/platform/device_tracer.h deleted file mode 100644 index d0f86a51030..00000000000 --- a/tensorflow/core/platform/device_tracer.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PLATFORM_DEVICE_TRACER_H_ -#define TENSORFLOW_CORE_PLATFORM_DEVICE_TRACER_H_ - -#include - -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -class StepStatsCollector; - -// 'DeviceTracer' is an interface for collecting low-level execution timings -// of hardware accelerator (e.g. GPU) computation and DMA transfers. -// -// Typical usage pattern is as follows: -// -// DeviceTracer* tracer = CreateDeviceTracer(); -// if (tracer) { -// tracer->Start(); -// -// ... 
perform some computations on a hardware accelerator. -// -// tracer->Stop(); -// -// StepStats stats; -// StepStatsCollector collector(&stats); -// tracer->Collect(&collector); -// } -// -// Notes: -// Tracing is not supported on all plaforms. On platforms -// with no tracing support, 'CreateDeviceTracer' will return 'nullptr'. -// On most plaforms, hardware tracing will be a system-wide activity and -// a single 'DeviceTracer' will collect activity from all devices. -// It is also common that only a single tracer may be active at any -// given time. The 'Start' method will return an error if tracing is -// already in progress elsewhere. -// -class DeviceTracer { - public: - virtual ~DeviceTracer() {} - - // Start device tracing. - // Note that only a single trace can be active, in which case this - // methods will return an 'Unavailable' error. - virtual Status Start() = 0; - - // Stop device tracing. - // It is safe to call 'Stop' on a tracer which is not enabled. - virtual Status Stop() = 0; - - // Collect trace results. Results are added to the specified - // StepStatsCollector. Does not clear any existing stats. - // It is an error to call 'Collect' while a trace is running. - virtual Status Collect(StepStatsCollector* collector) = 0; -}; - -// Creates a platform-specific DeviceTracer. -// Returns 'nullptr' on platforms where tracing is not supported. -std::unique_ptr CreateDeviceTracer(); - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PLATFORM_DEVICE_TRACER_H_ diff --git a/tensorflow/core/platform/device_tracer_test.cc b/tensorflow/core/platform/device_tracer_test.cc index 89f14e905af..d90e1265817 100644 --- a/tensorflow/core/platform/device_tracer_test.cc +++ b/tensorflow/core/platform/device_tracer_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/platform/device_tracer.h" - #include #include #include @@ -36,10 +34,24 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/profiler/internal/profiler_interface.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +struct ProfilerContext; + +#if GOOGLE_CUDA +std::unique_ptr CreateDeviceTracer( + const ProfilerContext*); +#else +// We don't have device tracer for non-cuda case. 
+std::unique_ptr CreateDeviceTracer( + const ProfilerContext*) { + return nullptr; +} +#endif + namespace { std::unique_ptr CreateSession() { @@ -99,42 +111,40 @@ class DeviceTracerTest : public ::testing::Test { }; TEST_F(DeviceTracerTest, StartStop) { - std::unique_ptr tracer(CreateDeviceTracer()); + auto tracer = CreateDeviceTracer(nullptr); if (!tracer) return; TF_EXPECT_OK(tracer->Start()); TF_EXPECT_OK(tracer->Stop()); } TEST_F(DeviceTracerTest, StopBeforeStart) { - std::unique_ptr tracer(CreateDeviceTracer()); + auto tracer = CreateDeviceTracer(nullptr); if (!tracer) return; TF_EXPECT_OK(tracer->Stop()); TF_EXPECT_OK(tracer->Stop()); } TEST_F(DeviceTracerTest, CollectBeforeStart) { - std::unique_ptr tracer(CreateDeviceTracer()); + auto tracer = CreateDeviceTracer(nullptr); if (!tracer) return; - StepStats stats; - StepStatsCollector collector(&stats); - TF_EXPECT_OK(tracer->Collect(&collector)); - EXPECT_EQ(stats.dev_stats_size(), 0); + RunMetadata run_metadata; + TF_EXPECT_OK(tracer->CollectData(&run_metadata)); + EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0); } TEST_F(DeviceTracerTest, CollectBeforeStop) { - std::unique_ptr tracer(CreateDeviceTracer()); + auto tracer = CreateDeviceTracer(nullptr); if (!tracer) return; TF_EXPECT_OK(tracer->Start()); - StepStats stats; - StepStatsCollector collector(&stats); - Status status = tracer->Collect(&collector); + RunMetadata run_metadata; + Status status = tracer->CollectData(&run_metadata); ExpectFailure(status, tensorflow::error::FAILED_PRECONDITION); TF_EXPECT_OK(tracer->Stop()); } TEST_F(DeviceTracerTest, StartTwoTracers) { - std::unique_ptr tracer1(CreateDeviceTracer()); - std::unique_ptr tracer2(CreateDeviceTracer()); + auto tracer1 = CreateDeviceTracer(nullptr); + auto tracer2 = CreateDeviceTracer(nullptr); if (!tracer1 || !tracer2) return; TF_EXPECT_OK(tracer1->Start()); @@ -147,7 +157,7 @@ TEST_F(DeviceTracerTest, StartTwoTracers) { TEST_F(DeviceTracerTest, RunWithTracer) { // On non-GPU platforms, we may not support DeviceTracer. - std::unique_ptr tracer(CreateDeviceTracer()); + auto tracer = CreateDeviceTracer(nullptr); if (!tracer) return; Initialize({3, 2, -1, 0}); @@ -174,7 +184,7 @@ TEST_F(DeviceTracerTest, RunWithTracer) { } TEST_F(DeviceTracerTest, TraceToStepStatsCollector) { - std::unique_ptr tracer(CreateDeviceTracer()); + auto tracer = CreateDeviceTracer(nullptr); if (!tracer) return; Initialize({3, 2, -1, 0}); @@ -193,13 +203,12 @@ TEST_F(DeviceTracerTest, TraceToStepStatsCollector) { TF_ASSERT_OK(s); TF_ASSERT_OK(tracer->Stop()); - StepStats stats; - StepStatsCollector collector(&stats); - TF_ASSERT_OK(tracer->Collect(&collector)); - collector.Finalize(); + RunMetadata run_metadata; + TF_ASSERT_OK(tracer->CollectData(&run_metadata)); // Depending on whether this runs on CPU or GPU, we will have a // different number of devices. - EXPECT_GE(stats.dev_stats_size(), 1) << "Saw stats: " << stats.DebugString(); + EXPECT_GE(run_metadata.step_stats().dev_stats_size(), 1) + << "Saw stats: " << run_metadata.DebugString(); } TEST_F(DeviceTracerTest, RunWithTraceOption) { diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index 59768bf92ae..1037b2b918b 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -517,7 +517,8 @@ Status ReadBinaryProto(Env* env, const string& fname, // respectively. 
coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - if (!proto->ParseFromCodedStream(&coded_stream)) { + if (!proto->ParseFromCodedStream(&coded_stream) || + !coded_stream.ConsumedEntireMessage()) { TF_RETURN_IF_ERROR(stream->status()); return errors::DataLoss("Can't parse ", fname, " as binary proto"); } diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index 280076e098d..a9876690966 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -251,13 +251,13 @@ class Env { // provide a routine to get the absolute time. /// \brief Returns the number of nano-seconds since the Unix epoch. - virtual uint64 NowNanos() { return envTime->NowNanos(); } + virtual uint64 NowNanos() const { return env_time_->NowNanos(); } /// \brief Returns the number of micro-seconds since the Unix epoch. - virtual uint64 NowMicros() { return envTime->NowMicros(); } + virtual uint64 NowMicros() const { return env_time_->NowMicros(); } /// \brief Returns the number of seconds since the Unix epoch. - virtual uint64 NowSeconds() { return envTime->NowSeconds(); } + virtual uint64 NowSeconds() const { return env_time_->NowSeconds(); } /// Sleeps/delays the thread for the prescribed number of micro-seconds. virtual void SleepForMicroseconds(int64 micros) = 0; @@ -327,7 +327,7 @@ class Env { private: std::unique_ptr file_system_registry_; TF_DISALLOW_COPY_AND_ASSIGN(Env); - EnvTime* envTime = EnvTime::Default(); + EnvTime* env_time_ = EnvTime::Default(); }; /// \brief An implementation of Env that forwards all calls to another Env. @@ -338,7 +338,7 @@ class EnvWrapper : public Env { public: /// Initializes an EnvWrapper that delegates all calls to *t explicit EnvWrapper(Env* t) : target_(t) {} - virtual ~EnvWrapper(); + ~EnvWrapper() override; /// Returns the target to which this Env forwards all calls Env* target() const { return target_; } @@ -361,7 +361,7 @@ class EnvWrapper : public Env { return target_->MatchPath(path, pattern); } - uint64 NowMicros() override { return target_->NowMicros(); } + uint64 NowMicros() const override { return target_->NowMicros(); } void SleepForMicroseconds(int64 micros) override { target_->SleepForMicroseconds(micros); } diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index ea1f1234247..593727e850d 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -51,6 +51,11 @@ GraphDef CreateTestProto() { return g; } +static void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(str_util::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + } // namespace string BaseDir() { return io::JoinPath(testing::TmpDir(), "base_dir"); } @@ -408,4 +413,19 @@ TEST_F(DefaultEnvTest, GetThreadInformation) { #endif } +TEST_F(DefaultEnvTest, GetChildThreadInformation) { + Env* env = Env::Default(); + Thread* child_thread = env->StartThread({}, "tf_child_thread", [env]() { + // TODO(fishx): Turn on this test for Apple. 
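// ----------------------------------------------------------------------------
// Illustrative sketch (stand-alone, hypothetical helper; not part of this
// patch): the ReadBinaryProto change above adds a ConsumedEntireMessage()
// check because ParseFromCodedStream() can report success even though parsing
// stopped before the end of the stream (for example at a stray end-group
// tag). Both conditions are needed to treat such files as data loss:
#include <cstdint>
#include <string>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/message.h"

inline bool ParseWholeBuffer(const std::string& data,
                             google::protobuf::Message* proto) {
  google::protobuf::io::CodedInputStream coded(
      reinterpret_cast<const uint8_t*>(data.data()),
      static_cast<int>(data.size()));
  // The parse must succeed *and* every byte of the buffer must be consumed.
  return proto->ParseFromCodedStream(&coded) && coded.ConsumedEntireMessage();
}
// ----------------------------------------------------------------------------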
+#if !defined(__APPLE__) + EXPECT_NE(env->GetCurrentThreadId(), 0); +#endif + string thread_name; + bool res = env->GetCurrentThreadName(&thread_name); + EXPECT_TRUE(res); + ExpectHasSubstr(thread_name, "tf_child_thread"); + }); + delete child_thread; +} + } // namespace tensorflow diff --git a/tensorflow/core/platform/env_time.h b/tensorflow/core/platform/env_time.h index c12b6ba6fb8..1b791cef374 100644 --- a/tensorflow/core/platform/env_time.h +++ b/tensorflow/core/platform/env_time.h @@ -43,13 +43,13 @@ class EnvTime { static EnvTime* Default(); /// \brief Returns the number of nano-seconds since the Unix epoch. - virtual uint64 NowNanos() = 0; + virtual uint64 NowNanos() const = 0; /// \brief Returns the number of micro-seconds since the Unix epoch. - virtual uint64 NowMicros() { return NowNanos() / kMicrosToNanos; } + virtual uint64 NowMicros() const { return NowNanos() / kMicrosToNanos; } /// \brief Returns the number of seconds since the Unix epoch. - virtual uint64 NowSeconds() { return NowNanos() / kSecondsToNanos; } + virtual uint64 NowSeconds() const { return NowNanos() / kSecondsToNanos; } }; } // namespace tensorflow diff --git a/tensorflow/core/platform/fingerprint.h b/tensorflow/core/platform/fingerprint.h index 720dc4c3d6b..ae41a8e541a 100644 --- a/tensorflow/core/platform/fingerprint.h +++ b/tensorflow/core/platform/fingerprint.h @@ -19,6 +19,16 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" +// The following line is used by copybara to set or unset the USE_OSS_FARMHASH +// preprocessor symbol as needed. Please do not remove. +#define USE_OSS_FARMHASH + +#ifdef USE_OSS_FARMHASH +#include +#else +#include "util/hash/farmhash_fingerprint.h" +#endif + namespace tensorflow { struct Fprint128 { @@ -37,13 +47,6 @@ struct Fprint128Hasher { } }; -// This is a portable fingerprint interface for strings that will never change. -// However, it is not suitable for cryptography. -uint64 Fingerprint64(StringPiece s); - -// 128-bit variant of Fingerprint64 above (same properties and caveats apply). -Fprint128 Fingerprint128(StringPiece s); - namespace internal { // Mixes some of the bits that got propagated to the high bits back into the // low bits. @@ -72,12 +75,33 @@ inline uint64 FingerprintCat64(const uint64 fp1, const uint64 fp2) { return result; } +// This is a portable fingerprint interface for strings that will never change. +// However, it is not suitable for cryptography. +inline uint64 Fingerprint64(StringPiece s) { +#ifdef USE_OSS_FARMHASH + return ::util::Fingerprint64(s.data(), s.size()); +#else + // Fingerprint op depends on the fact that Fingerprint64() is implemented by + // Farmhash. If the implementation ever changes, Fingerprint op should be + // modified to keep using Farmhash. + // LINT.IfChange + return farmhash::Fingerprint64(s.data(), s.size()); + // LINT.ThenChange(//third_party/tensorflow/core/kernels/fingerprint_op.cc) +#endif +} + +// 128-bit variant of Fingerprint64 above (same properties and caveats apply). 
+inline Fprint128 Fingerprint128(StringPiece s) { +#ifdef USE_OSS_FARMHASH + const auto fingerprint = ::util::Fingerprint128(s.data(), s.size()); + return {::util::Uint128Low64(fingerprint), + ::util::Uint128High64(fingerprint)}; +#else + const auto fingerprint = farmhash::Fingerprint128(s.data(), s.size()); + return {absl::Uint128Low64(fingerprint), absl::Uint128High64(fingerprint)}; +#endif +} + } // namespace tensorflow -#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) -#include "tensorflow/core/platform/google/fingerprint.h" -#else -#include "tensorflow/core/platform/default/fingerprint.h" -#endif - #endif // TENSORFLOW_CORE_PLATFORM_FINGERPRINT_H_ diff --git a/tensorflow/core/platform/human_readable_json.h b/tensorflow/core/platform/human_readable_json.h index c759e801e97..49908eac7c8 100644 --- a/tensorflow/core/platform/human_readable_json.h +++ b/tensorflow/core/platform/human_readable_json.h @@ -26,7 +26,11 @@ namespace tensorflow { // // This string may not be strictly JSON-compliant, but it must be parseable by // HumanReadableJSONToProto. -Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result); +// +// When ignore_accuracy_loss = true, this function may ignore JavaScript +// accuracy loss with large integers. +Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result, + bool ignore_accuracy_loss); // Converts a string produced by ProtoToHumanReadableJSON to a protobuf. Not // guaranteed to work for general JSON. diff --git a/tensorflow/core/platform/monitoring.h b/tensorflow/core/platform/monitoring.h new file mode 100644 index 00000000000..f01233933c3 --- /dev/null +++ b/tensorflow/core/platform/monitoring.h @@ -0,0 +1,38 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_MONITORING_H_ +#define TENSORFLOW_CORE_PLATFORM_MONITORING_H_ + +namespace tensorflow { +namespace monitoring { + +// Starts exporting metrics through a platform-specific monitoring API (if +// provided). For builds using "tensorflow/core/platform/default", this is +// currently a no-op. This function is idempotent. +// +// The TensorFlow runtime will call this the first time a new session is created +// using the NewSession() method or an Eager Context is created. +void StartExporter(); + +// Manually invokes a one time metrics export through a platform-specific +// monitoring API (if provided). For builds using +// "tensorflow/core/platform/default", this is currently a no-op. +void ExportMetrics(); + +} // namespace monitoring +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_MONITORING_H_ diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc index f2dff5a9b64..2700a269c4d 100644 --- a/tensorflow/core/platform/posix/env.cc +++ b/tensorflow/core/platform/posix/env.cc @@ -32,19 +32,37 @@ limitations under the License. 
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/load_library.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/posix/posix_file_system.h" namespace tensorflow { namespace { +mutex name_mutex(tensorflow::LINKER_INITIALIZED); + +std::map& GetThreadNameRegistry() + EXCLUSIVE_LOCKS_REQUIRED(name_mutex) { + static auto* thread_name_registry = new std::map(); + return *thread_name_registry; +} + class StdThread : public Thread { public: - // name and thread_options are both ignored. + // thread_options is ignored. StdThread(const ThreadOptions& thread_options, const string& name, std::function fn) - : thread_(fn) {} - ~StdThread() override { thread_.join(); } + : thread_(fn) { + mutex_lock l(name_mutex); + GetThreadNameRegistry().emplace(thread_.get_id(), name); + } + + ~StdThread() override { + std::thread::id thread_id = thread_.get_id(); + thread_.join(); + mutex_lock l(name_mutex); + GetThreadNameRegistry().erase(thread_id); + } private: std::thread thread_; @@ -102,6 +120,15 @@ class PosixEnv : public Env { } bool GetCurrentThreadName(string* name) override { + { + mutex_lock l(name_mutex); + auto thread_name = + GetThreadNameRegistry().find(std::this_thread::get_id()); + if (thread_name != GetThreadNameRegistry().end()) { + *name = thread_name->second; + return true; + } + } #if defined(__ANDROID__) || defined(__EMSCRIPTEN__) return false; #else diff --git a/tensorflow/core/platform/posix/env_time.cc b/tensorflow/core/platform/posix/env_time.cc index 59a67b17aab..e7658108654 100644 --- a/tensorflow/core/platform/posix/env_time.cc +++ b/tensorflow/core/platform/posix/env_time.cc @@ -26,7 +26,7 @@ class PosixEnvTime : public EnvTime { public: PosixEnvTime() {} - uint64 NowNanos() override { + uint64 NowNanos() const override { struct timespec ts; clock_gettime(CLOCK_REALTIME, &ts); return (static_cast(ts.tv_sec) * kSecondsToNanos + diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc index 1561632a49a..13a904295c1 100644 --- a/tensorflow/core/platform/posix/port.cc +++ b/tensorflow/core/platform/posix/port.cc @@ -80,6 +80,8 @@ int NumSchedulableCPUs() { return kDefaultCores; } +int MaxParallelism() { return NumSchedulableCPUs(); } + int NumTotalCPUs() { int count = absl::base_internal::NumCPUs(); return (count <= 0) ? kUnknownCPU : count; diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc index 083284c5ff9..10f0950c0df 100644 --- a/tensorflow/core/platform/posix/posix_file_system.cc +++ b/tensorflow/core/platform/posix/posix_file_system.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include #if defined(__linux__) @@ -62,7 +63,16 @@ class PosixRandomAccessFile : public RandomAccessFile { Status s; char* dst = scratch; while (n > 0 && s.ok()) { - ssize_t r = pread(fd_, dst, n, static_cast(offset)); + // Some platforms, notably macs, throw EINVAL if pread is asked to read + // more than fits in a 32-bit integer. 
+ size_t requested_read_length; + if (n > INT32_MAX) { + requested_read_length = INT32_MAX; + } else { + requested_read_length = n; + } + ssize_t r = + pread(fd_, dst, requested_read_length, static_cast(offset)); if (r > 0) { dst += r; n -= r; @@ -105,6 +115,9 @@ class PosixWritableFile : public WritableFile { } Status Close() override { + if (file_ == nullptr) { + return IOError(filename_, EBADF); + } Status result; if (fclose(file_) != 0) { result = IOError(filename_, errno); @@ -323,12 +336,13 @@ Status PosixFileSystem::CopyFile(const string& src, const string& target) { return IOError(src, errno); } string translated_target = TranslateName(target); - // O_WRONLY | O_CREAT: + // O_WRONLY | O_CREAT | O_TRUNC: // Open file for write and if file does not exist, create the file. - // S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH: - // Create the file with permission of 0644 - int target_fd = open(translated_target.c_str(), O_WRONLY | O_CREAT, - S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); + // If file exists, truncate its size to 0. + // When creating file, use the same permissions as original + mode_t mode = sbuf.st_mode & (S_IRWXU | S_IRWXG | S_IRWXO); + int target_fd = + open(translated_target.c_str(), O_WRONLY | O_CREAT | O_TRUNC, mode); if (target_fd < 0) { close(src_fd); return IOError(target, errno); diff --git a/tensorflow/core/platform/protobuf_util.cc b/tensorflow/core/platform/protobuf_util.cc index 5eccddfb15b..e46a77fa25f 100644 --- a/tensorflow/core/platform/protobuf_util.cc +++ b/tensorflow/core/platform/protobuf_util.cc @@ -19,15 +19,12 @@ namespace tensorflow { bool ParseProtoUnlimited(protobuf::MessageLite* proto, const string& serialized) { - return ParseProtoUnlimited(proto, serialized.data(), serialized.size()); + return proto->ParseFromString(serialized); } bool ParseProtoUnlimited(protobuf::MessageLite* proto, const void* serialized, size_t size) { - protobuf::io::CodedInputStream coded_stream( - reinterpret_cast(serialized), size); - coded_stream.SetTotalBytesLimit(INT_MAX, INT_MAX); - return proto->ParseFromCodedStream(&coded_stream); + return proto->ParseFromArray(serialized, size); } } // namespace tensorflow diff --git a/tensorflow/core/platform/cupti_wrapper.h b/tensorflow/core/platform/rocm.h similarity index 69% rename from tensorflow/core/platform/cupti_wrapper.h rename to tensorflow/core/platform/rocm.h index 9a17ab60c0d..1896cc3d84c 100644 --- a/tensorflow/core/platform/cupti_wrapper.h +++ b/tensorflow/core/platform/rocm.h @@ -13,15 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ -#define TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_ROCM_H_ +#define TENSORFLOW_CORE_PLATFORM_ROCM_H_ #include "tensorflow/core/platform/platform.h" +#include "tensorflow/stream_executor/rocm/rocm_activation.h" -#if defined(PLATFORM_GOOGLE) -#include "tensorflow/core/platform/google/cupti_wrapper.h" -#else -#include "tensorflow/core/platform/default/gpu/cupti_wrapper.h" -#endif - -#endif // TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_ +#endif // TENSORFLOW_CORE_PLATFORM_ROCM_H_ diff --git a/tensorflow/core/platform/strong_hash.h b/tensorflow/core/platform/strong_hash.h index 999fd2e4b30..cbd267f90ed 100644 --- a/tensorflow/core/platform/strong_hash.h +++ b/tensorflow/core/platform/strong_hash.h @@ -24,7 +24,7 @@ namespace tensorflow { // This is a strong keyed hash function interface for strings. // The hash function is deterministic on the content of the string within the // process. The key of the hash is an array of 2 uint64 elements. -// A strong hash make it dificult, if not infeasible, to compute inputs that +// A strong hash makes it difficult, if not infeasible, to compute inputs that // hash to the same bucket. // // Usage: diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h index aefbe64425a..9b2886f1c42 100644 --- a/tensorflow/core/platform/tracing.h +++ b/tensorflow/core/platform/tracing.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" @@ -151,18 +152,10 @@ class TraceCollector { virtual ~TraceCollector() {} virtual std::unique_ptr CreateAnnotationHandle( StringPiece name_part1, StringPiece name_part2) const = 0; - virtual std::unique_ptr CreateActivityHandle( - StringPiece name_part1, StringPiece name_part2, - bool is_expensive) const = 0; // Returns true if this annotation tracing is enabled for any op. virtual bool IsEnabledForAnnotations() const = 0; - // Returns true if this activity handle tracking is enabled for an op of the - // given expensiveness. - virtual bool IsEnabledForActivities(bool is_expensive) const = 0; - - protected: static string ConcatenateNames(StringPiece first, StringPiece second); private: @@ -200,34 +193,10 @@ class ScopedAnnotation { : nullptr; }()) {} - bool IsEnabled() const { return static_cast(handle_); } - - private: - std::unique_ptr handle_; -}; - -// Adds an activity through the currently registered TraceCollector. -// The activity starts when an object of this class is created and stops when -// the object is destroyed. -class ScopedActivity { - public: - explicit ScopedActivity(StringPiece name, bool is_expensive = true) - : ScopedActivity(name, StringPiece(), is_expensive) {} - - // If tracing is enabled, set up an activity with a label of - // ":". This can be cheaper than the - // single-argument constructor because the concatenation of the - // label string is only done if tracing is enabled. - ScopedActivity(StringPiece name_part1, StringPiece name_part2, - bool is_expensive = true) - : handle_([&] { - auto trace_collector = GetTraceCollector(); - return trace_collector ? 
trace_collector->CreateActivityHandle( - name_part1, name_part2, is_expensive) - : nullptr; - }()) {} - - bool IsEnabled() const { return static_cast(handle_); } + static bool IsEnabled() { + auto* trace_collector = GetTraceCollector(); + return trace_collector && trace_collector->IsEnabledForAnnotations(); + } private: std::unique_ptr handle_; diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h index a4fa790317f..b82d9cc3247 100644 --- a/tensorflow/core/platform/types.h +++ b/tensorflow/core/platform/types.h @@ -33,14 +33,8 @@ limitations under the License. namespace tensorflow { -// Define tensorflow::string to refer to appropriate platform specific type. -// TODO(josh11b): Move this into the platform/*/integral_types.h files -// above, and rename them platform/*/types.h. -#if defined(PLATFORM_GOOGLE) -using ::string; -#else +// Alias tensorflow::string to std::string. using std::string; -#endif static const uint8 kuint8max = ((uint8)0xFF); static const uint16 kuint16max = ((uint16)0xFFFF); diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc index e0e3dda7055..fedbd674d5f 100644 --- a/tensorflow/core/platform/windows/env.cc +++ b/tensorflow/core/platform/windows/env.cc @@ -40,13 +40,30 @@ namespace tensorflow { namespace { +mutex name_mutex(tensorflow::LINKER_INITIALIZED); + +std::map& GetThreadNameRegistry() + EXCLUSIVE_LOCKS_REQUIRED(name_mutex) { + static auto* thread_name_registry = new std::map(); + return *thread_name_registry; +} + class StdThread : public Thread { public: - // name and thread_options are both ignored. + // thread_options is ignored. StdThread(const ThreadOptions& thread_options, const string& name, std::function fn) - : thread_(fn) {} - ~StdThread() { thread_.join(); } + : thread_(fn) { + mutex_lock l(name_mutex); + GetThreadNameRegistry().emplace(thread_.get_id(), name); + } + + ~StdThread() override { + std::thread::id thread_id = thread_.get_id(); + thread_.join(); + mutex_lock l(name_mutex); + GetThreadNameRegistry().erase(thread_id); + } private: std::thread thread_; @@ -88,7 +105,16 @@ class WindowsEnv : public Env { return static_cast(::GetCurrentThreadId()); } - bool GetCurrentThreadName(string* name) override { return false; } + bool GetCurrentThreadName(string* name) override { + mutex_lock l(name_mutex); + auto thread_name = GetThreadNameRegistry().find(std::this_thread::get_id()); + if (thread_name != GetThreadNameRegistry().end()) { + *name = thread_name->second; + return true; + } else { + return false; + } + } static VOID CALLBACK SchedClosureCallback(PTP_CALLBACK_INSTANCE Instance, PVOID Context, PTP_WORK Work) { diff --git a/tensorflow/core/platform/windows/env_time.cc b/tensorflow/core/platform/windows/env_time.cc index b1713f695c5..f6d77dc5b6e 100644 --- a/tensorflow/core/platform/windows/env_time.cc +++ b/tensorflow/core/platform/windows/env_time.cc @@ -42,7 +42,7 @@ class WindowsEnvTime : public EnvTime { } } - uint64 NowNanos() { + uint64 NowNanos() const override { if (GetSystemTimePreciseAsFileTime_ != NULL) { // GetSystemTimePreciseAsFileTime function is only available in latest // versions of Windows, so we need to check for its existence here. 
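// ----------------------------------------------------------------------------
// Illustrative sketch (stand-alone, hypothetical names; not part of this
// patch): the POSIX and Windows Env changes above share one pattern for making
// GetCurrentThreadName() work for threads created via StartThread(): a
// mutex-guarded map from std::thread::id to the name given at creation time,
// populated in the StdThread constructor and erased in its destructor.
#include <map>
#include <mutex>
#include <string>
#include <thread>

class ThreadNameRegistry {
 public:
  void Register(std::thread::id id, std::string name) {
    std::lock_guard<std::mutex> lock(mu_);
    names_.emplace(id, std::move(name));
  }

  void Unregister(std::thread::id id) {
    std::lock_guard<std::mutex> lock(mu_);
    names_.erase(id);
  }

  // Returns true and fills *name if the calling thread was registered.
  bool GetCurrentThreadName(std::string* name) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = names_.find(std::this_thread::get_id());
    if (it == names_.end()) return false;
    *name = it->second;
    return true;
  }

 private:
  std::mutex mu_;
  std::map<std::thread::id, std::string> names_;
};
// ----------------------------------------------------------------------------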
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index b902c85cdcf..08d0fadcfce 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -55,6 +55,8 @@ int NumSchedulableCPUs() { return system_info.dwNumberOfProcessors; } +int MaxParallelism() { return NumSchedulableCPUs(); } + int NumTotalCPUs() { // TODO(ebrevdo): Make this more accurate. // diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index 4efc15b7e5f..0ffd170e310 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -44,7 +44,6 @@ tf_proto_library( cc_api_version = 2, cc_grpc_version = 1, protodeps = tf_profiler_all_protos() + tf_additional_all_protos(), - visibility = ["//visibility:public"], ) tf_proto_library( @@ -54,7 +53,6 @@ tf_proto_library( cc_api_version = 2, cc_grpc_version = 1, protodeps = [":profiler_service_proto"] + tf_additional_all_protos(), - visibility = ["//visibility:public"], ) tf_proto_library( @@ -68,5 +66,13 @@ tf_proto_library( ), cc_api_version = 2, protodeps = tf_additional_all_protos(), +) + +filegroup( + name = "mobile_srcs", + srcs = [ + "//tensorflow/core/profiler/internal:mobile_srcs", + "//tensorflow/core/profiler/lib:mobile_srcs", + ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/profiler/g3doc/options.md b/tensorflow/core/profiler/g3doc/options.md index 38a8e028511..8c4b45db689 100644 --- a/tensorflow/core/profiler/g3doc/options.md +++ b/tensorflow/core/profiler/g3doc/options.md @@ -104,11 +104,14 @@ accelerator_micros and cpu_micros. Note: cpu and accelerator can run in parallel `-start_name_regexes`: Show node starting from the node that matches the regexes, recursively. regexes are comma-separated. -`-trim_name_regexes`: Hide node starting from the node that matches the regexes, recursively, regexes are comma-seprated. +`-trim_name_regexes`: Hide node starting from the node that matches the regexes, +recursively, regexes are comma-separated. -`-show_name_regexes`: Show node that match the regexes. regexes are comma-seprated. +`-show_name_regexes`: Show node that match the regexes. regexes are +comma-separated. -`-hide_name_regexes`: Hide node that match the regexes. regexes are comma-seprated. +`-hide_name_regexes`: Hide node that match the regexes. regexes are +comma-separated. `-account_displayed_op_only`: If True, only account the statistics of ops eventually displayed. If False, account all op statistics matching -account_type_regexes recursively. 
diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD index da3039ae3ce..eb98c8dd31a 100644 --- a/tensorflow/core/profiler/internal/BUILD +++ b/tensorflow/core/profiler/internal/BUILD @@ -375,12 +375,12 @@ tf_cuda_library( visibility = [ "//learning/brain/runtime:__pkg__", # xprof_bridge "//perftools/accelerators/xprof/xprofilez:__pkg__", # alias xprof::TraceMeRecorder + "//tensorflow/core:__pkg__", # executor.cc "//tensorflow/core/profiler/internal/cpu:__pkg__", # host_tracer "//tensorflow/core/profiler/lib:__pkg__", # traceme ], deps = [ "//tensorflow/core:lib", - "//tensorflow/stream_executor/lib", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", ], @@ -399,11 +399,26 @@ tf_cuda_cc_test( tf_cuda_library( name = "profiler_interface", + srcs = [ + "profiler_interface.cc", + ], hdrs = [ "profiler_interface.h", ], deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/synchronization", ], ) + +filegroup( + name = "mobile_srcs", + srcs = [ + "profiler_interface.cc", + "profiler_interface.h", + "traceme_recorder.cc", + "traceme_recorder.h", + ], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD index b94453c0a4b..a07e51fe003 100644 --- a/tensorflow/core/profiler/internal/cpu/BUILD +++ b/tensorflow/core/profiler/internal/cpu/BUILD @@ -15,18 +15,17 @@ tf_cuda_library( srcs = [ "host_tracer.cc", ], - hdrs = [ - "host_tracer.h", - ], deps = [ "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/internal:profiler_interface", "//tensorflow/core/profiler/internal:traceme_recorder", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], + alwayslink = True, ) tf_cuda_cc_test( @@ -34,9 +33,11 @@ tf_cuda_cc_test( srcs = ["host_tracer_test.cc"], deps = [ ":host_tracer", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", + "//tensorflow/core/profiler/internal:profiler_interface", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc index 3fb29664688..6fddd5829ce 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc @@ -12,23 +12,52 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/profiler/internal/cpu/host_tracer.h" - #include #include "absl/container/flat_hash_map.h" #include "absl/strings/str_split.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env_time.h" +#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/internal/traceme_recorder.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { namespace profiler { namespace cpu { +namespace { +// Controls TraceMeRecorder and converts TraceMeRecorder::Events into +// RunMetadata messages. +// +// Thread-safety: This class is go/thread-compatible. +class HostTracer : public ProfilerInterface { + public: + explicit HostTracer(int host_trace_level); + ~HostTracer() override; + + // Starts recording TraceMes. + Status Start() override; + + // Stops recording TraceMes. + Status Stop() override; + + // Populates user traces and thread names in response. + // The user traces and thread names are in no particular order. + Status CollectData(RunMetadata* run_metadata) override; + + private: + // Level of host tracing. + const int host_trace_level_; + + // True if currently recording. + bool recording_ = false; + + // Container of all traced events. + TraceMeRecorder::Events events_; +}; -/* static */ std::unique_ptr HostTracer::Create( - int host_trace_level) { - return absl::WrapUnique(new HostTracer(host_trace_level)); -} HostTracer::HostTracer(int host_trace_level) : host_trace_level_(host_trace_level) {} @@ -57,15 +86,8 @@ Status HostTracer::Stop() { constexpr char kUserMetadataMarker = '#'; Status HostTracer::CollectData(RunMetadata* run_metadata) { - auto step_stats_collector = - absl::make_unique(run_metadata->mutable_step_stats()); - return CollectDataToCollector(step_stats_collector.get()); -} - -Status HostTracer::CollectDataToCollector( - StepStatsCollector* step_stats_collector) { - if (events_.empty() && recording_) { - events_ = TraceMeRecorder::Collect(); + if (recording_) { + return Status(error::INTERNAL, "TraceMeRecorder not stopped"); } // Pair up start and end events, and add complete events to trace_entries. absl::flat_hash_map end_times; @@ -77,10 +99,12 @@ Status HostTracer::CollectDataToCollector( } } + StepStatsCollector step_stats_collector(run_metadata->mutable_step_stats()); + const string cpu_name = "/host:CPU"; for (auto& thread : events_) { - step_stats_collector->SaveThreadName(cpu_name, thread.thread.tid, - thread.thread.name); + step_stats_collector.SaveThreadName(cpu_name, thread.thread.tid, + thread.thread.name); for (auto& event : thread.events) { if (!event.end_time) { auto it = end_times.find(event.activity_id); @@ -106,14 +130,31 @@ Status HostTracer::CollectDataToCollector( EnvTime::kMicrosToNanos); ns->set_thread_id(thread.thread.tid); // TODO(fishx): Add thread name to RunMetadata - step_stats_collector->Save(cpu_name, ns); + step_stats_collector.Save(cpu_name, ns); } } } events_.clear(); - step_stats_collector->Finalize(); + step_stats_collector.Finalize(); return Status::OK(); } +} // namespace + +// Not in anonymous namespace for testing purposes. 
+std::unique_ptr CreateHostTracer(const ProfilerContext*) { + int host_trace_level = 2; + return absl::make_unique(host_trace_level); +} + +auto register_host_tracer_factory = [] { + bool enable; + + TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_OSS_CPU_PROFILER", true, &enable)); + if (enable) { + RegisterProfilerFactory(&CreateHostTracer); + } + return 0; +}(); } // namespace cpu } // namespace profiler diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.h b/tensorflow/core/profiler/internal/cpu/host_tracer.h deleted file mode 100644 index c6340c2eddc..00000000000 --- a/tensorflow/core/profiler/internal/cpu/host_tracer.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_HOST_TRACER_H_ -#define TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_HOST_TRACER_H_ - -#include "tensorflow/core/common_runtime/step_stats_collector.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" -#include "tensorflow/core/profiler/internal/traceme_recorder.h" -#include "tensorflow/core/protobuf/config.pb.h" - -namespace tensorflow { -namespace profiler { -namespace cpu { - -// Controls TraceMeRecorder and converts TraceMeRecorder::Events into -// RunMetadata messages. -// -// Thread-safety: This class is go/thread-compatible. -class HostTracer : public ProfilerInterface { - public: - static std::unique_ptr Create(int host_trace_level); - - ~HostTracer(); - - // Starts recording TraceMes. - Status Start() override; - - // Stops recording TraceMes. - Status Stop() override; - - // Populates user traces and thread names in response. - // The user traces and thread names are in no particular order. - Status CollectData(RunMetadata* run_metadata) override; - - Status CollectDataToCollector(StepStatsCollector* step_stats_collector); - - private: - explicit HostTracer(int host_trace_level); - - // Level of host tracing. - const int host_trace_level_; - - // True if currently recording. - bool recording_ = false; - - // Container of all traced events. - TraceMeRecorder::Events events_; -}; - -} // namespace cpu -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_HOST_TRACER_H_ diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc index 51f9c6a8ca6..8b0e027bad5 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc @@ -12,24 +12,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/profiler/internal/cpu/host_tracer.h" - #include #include #include #include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/profiler/internal/profiler_interface.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { namespace profiler { namespace cpu { +std::unique_ptr CreateHostTracer(const ProfilerContext*); + namespace { +Status CollectData(ProfilerInterface* profiler, RunMetadata* run_metadata) { + return profiler->CollectData(run_metadata); +} + using ::testing::ElementsAre; using ::testing::Pair; using ::testing::UnorderedElementsAre; @@ -74,7 +80,7 @@ inline ::testing::PolymorphicMatcher EqualsNodeStats( TEST(HostTracerTest, CollectsTraceMeEvents) { uint32 thread_id = Env::Default()->GetCurrentThreadId(); - auto tracer = HostTracer::Create(/*host_trace_level=*/1); + auto tracer = CreateHostTracer(nullptr); TF_ASSERT_OK(tracer->Start()); { TraceMe traceme("hello"); } @@ -86,7 +92,7 @@ TEST(HostTracerTest, CollectsTraceMeEvents) { TF_ASSERT_OK(tracer->Stop()); RunMetadata run_metadata; - TF_ASSERT_OK(tracer->CollectData(&run_metadata)); + TF_ASSERT_OK(CollectData(tracer.get(), &run_metadata)); EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 1); EXPECT_EQ(run_metadata.step_stats().dev_stats(0).node_stats_size(), 6); @@ -103,30 +109,6 @@ TEST(HostTracerTest, CollectsTraceMeEvents) { MakeNodeStats("incomplete", thread_id, "key1=value1,key2")))); } -void ValidateResult(const RunMetadata& run_metadata, const string& trace_name) { - uint32 thread_id = Env::Default()->GetCurrentThreadId(); - - EXPECT_THAT( - run_metadata.step_stats().dev_stats(0).node_stats(), - ElementsAre(EqualsNodeStats(MakeNodeStats(trace_name, thread_id)))); -} - -TEST(HostTracerTest, CollectsTraceMeEventsBetweenTracing) { - auto tracer = HostTracer::Create(/*host_trace_level=*/1); - RunMetadata run_metadata; - RunMetadata run_metadata2; - - TF_ASSERT_OK(tracer->Start()); - { TraceMe traceme("hello"); } - TF_ASSERT_OK(tracer->CollectData(&run_metadata)); - { TraceMe traceme("world"); } - TF_ASSERT_OK(tracer->CollectData(&run_metadata2)); - TF_ASSERT_OK(tracer->Stop()); - - ValidateResult(run_metadata, "hello"); - ValidateResult(run_metadata2, "world"); -} - } // namespace } // namespace cpu } // namespace profiler diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD index 35f90e9bfc0..653ec045d00 100644 --- a/tensorflow/core/profiler/internal/gpu/BUILD +++ b/tensorflow/core/profiler/internal/gpu/BUILD @@ -4,22 +4,7 @@ package( licenses(["notice"]) # Apache 2.0 -load( - "//tensorflow:tensorflow.bzl", - "tf_cuda_library", -) - -tf_cuda_library( - name = "tracer", - srcs = [ - "tracer.cc", - ], - hdrs = [ - "tracer.h", - ], - deps = [ - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:device_tracer", - "//tensorflow/core/profiler/internal:profiler_interface", - ], +alias( + name = "device_tracer", + actual = "//tensorflow/core:device_tracer", ) diff --git a/tensorflow/core/profiler/internal/gpu/tracer.cc b/tensorflow/core/profiler/internal/gpu/tracer.cc deleted file mode 100644 index f1cb54161c7..00000000000 --- a/tensorflow/core/profiler/internal/gpu/tracer.cc +++ /dev/null 
@@ -1,59 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/internal/gpu/tracer.h" -#include "tensorflow/core/common_runtime/step_stats_collector.h" - -namespace tensorflow { -namespace profiler { -namespace gpu { - -/* static */ std::unique_ptr Tracer::Create() { - return absl::WrapUnique(new Tracer()); -} - -Status Tracer::Start() { - device_tracer_ = CreateDeviceTracer(); - if (!device_tracer_) { - return Status(tensorflow::error::Code::FAILED_PRECONDITION, - "Failed to create device tracer."); - } - return device_tracer_->Start(); -} - -Status Tracer::Stop() { - if (!device_tracer_) { - return Status(tensorflow::error::Code::FAILED_PRECONDITION, - "No running device tracer."); - } - return device_tracer_->Stop(); -} - -Status Tracer::CollectData(RunMetadata* run_metadata) { - if (!device_tracer_) { - return Status(tensorflow::error::Code::FAILED_PRECONDITION, - "No running device tracer."); - } - auto step_stats_collector = - absl::make_unique(run_metadata->mutable_step_stats()); - Status s = device_tracer_->Collect(step_stats_collector.get()); - step_stats_collector->Finalize(); - return s; -} - -Tracer::Tracer() {} - -} // namespace gpu -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/internal/gpu/tracer.h b/tensorflow/core/profiler/internal/profiler_interface.cc similarity index 52% rename from tensorflow/core/profiler/internal/gpu/tracer.h rename to tensorflow/core/profiler/internal/profiler_interface.cc index d7765432de9..2f48102318c 100644 --- a/tensorflow/core/profiler/internal/gpu/tracer.h +++ b/tensorflow/core/profiler/internal/profiler_interface.cc @@ -12,37 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_TRACER_H_ -#define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_TRACER_H_ - -#include "tensorflow/core/platform/device_tracer.h" #include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "absl/synchronization/mutex.h" + namespace tensorflow { -namespace profiler { -namespace gpu { +namespace { +std::vector* GetFactories() { + static auto factories = new std::vector(); + return factories; +} +absl::Mutex* GetMutex() { + static auto mutex = new absl::Mutex; + return mutex; +} +} // namespace -class Tracer : public ProfilerInterface { - public: - static std::unique_ptr Create(); +void RegisterProfilerFactory(ProfilerFactory factory) { + absl::MutexLock lock(GetMutex()); + GetFactories()->push_back(factory); +} - Status Start() override; - - Status Stop() override; - - Status CollectData(RunMetadata* run_metadata) override; - - private: - Tracer(); - - // Trace is neither copyable nor movable. 
- Tracer(const Tracer&) = delete; - Tracer& operator=(const Tracer&) = delete; - - std::unique_ptr device_tracer_; -}; - -} // namespace gpu -} // namespace profiler +void CreateProfilers( + const ProfilerContext* context, + std::vector>* result) { + absl::MutexLock lock(GetMutex()); + for (auto factory : *GetFactories()) { + if (auto profiler = factory(context)) { + result->push_back(std::move(profiler)); + } + } +} } // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_TRACER_H_ diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/internal/profiler_interface.h index 144c4bb44d7..4754f4f03a6 100644 --- a/tensorflow/core/profiler/internal/profiler_interface.h +++ b/tensorflow/core/profiler/internal/profiler_interface.h @@ -19,6 +19,11 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { +class EagerContext; +struct ProfilerContext { + EagerContext* eager_context = nullptr; +}; + namespace profiler { // Interface for tensorflow profiler plugins. @@ -39,11 +44,21 @@ class ProfilerInterface { // Stops profiling. virtual Status Stop() = 0; - // Moves collected profile data into run_metadata. + // Moves collected profile data into step_stats_collector. virtual Status CollectData(RunMetadata* run_metadata) = 0; }; } // namespace profiler + +using ProfilerFactory = + std::unique_ptr (*)(const ProfilerContext*); + +void RegisterProfilerFactory(ProfilerFactory factory); + +void CreateProfilers( + const ProfilerContext* context, + std::vector>* result); + } // namespace tensorflow #endif // TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_ diff --git a/tensorflow/core/profiler/internal/runtime/BUILD b/tensorflow/core/profiler/internal/runtime/BUILD index 2e383f1716f..085fed81578 100644 --- a/tensorflow/core/profiler/internal/runtime/BUILD +++ b/tensorflow/core/profiler/internal/runtime/BUILD @@ -14,11 +14,11 @@ tf_cuda_library( srcs = [ "eager_profiler.cc", ], - hdrs = [ - "eager_profiler.h", - ], deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/profiler/internal:profiler_interface", ], + alwayslink = True, ) diff --git a/tensorflow/core/profiler/internal/runtime/eager_profiler.cc b/tensorflow/core/profiler/internal/runtime/eager_profiler.cc index aad692b01f6..30182da5db1 100644 --- a/tensorflow/core/profiler/internal/runtime/eager_profiler.cc +++ b/tensorflow/core/profiler/internal/runtime/eager_profiler.cc @@ -12,11 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
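The profiler_interface.h and profiler_interface.cc changes above replace hard-coded profiler construction with a registration mechanism: a ProfilerFactory is a function from ProfilerContext* to a ProfilerInterface instance, RegisterProfilerFactory appends it to a global list, and CreateProfilers instantiates every registered factory. The sketch below shows how a hypothetical plugin would plug into that mechanism, following the same static-initializer idiom used by host_tracer.cc and eager_profiler.cc; MyProfiler, CreateMyProfiler, and register_my_profiler are made-up names, not TensorFlow code.

```cpp
// Hypothetical profiler plugin registering itself with the factory list.
#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace example {

class MyProfiler : public tensorflow::profiler::ProfilerInterface {
 public:
  tensorflow::Status Start() override { return tensorflow::Status::OK(); }
  tensorflow::Status Stop() override { return tensorflow::Status::OK(); }
  tensorflow::Status CollectData(
      tensorflow::RunMetadata* run_metadata) override {
    // A real profiler would fill run_metadata->mutable_step_stats() here.
    return tensorflow::Status::OK();
  }
};

std::unique_ptr<tensorflow::profiler::ProfilerInterface> CreateMyProfiler(
    const tensorflow::ProfilerContext*) {
  return absl::make_unique<MyProfiler>();
}

// Static-initializer registration, mirroring register_host_tracer_factory
// and register_eager_profiler_factory in the hunks above and below.
auto register_my_profiler = [] {
  tensorflow::RegisterProfilerFactory(&CreateMyProfiler);
  return 0;
}();

}  // namespace example
```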
==============================================================================*/ -#include "tensorflow/core/profiler/internal/runtime/eager_profiler.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { namespace profiler { namespace runtime { +namespace { +class TraceCollector : public RunMetadataListener { + public: + explicit TraceCollector(EagerContext* const eager_context); + + void BeforeClearRunMetadata() override; + + Status CollectData(RunMetadata* run_metadata); + + private: + RunMetadata run_metadata_; + EagerContext* const context_; +}; + +class EagerProfiler : public ProfilerInterface { + public: + explicit EagerProfiler(EagerContext* const eager_context); + + Status Start() override; + + Status Stop() override; + + Status CollectData(RunMetadata* run_metadata) override; + + EagerContext* const context_; + TraceCollector collector_; +}; TraceCollector::TraceCollector(EagerContext* const eager_context) : context_(eager_context) {} @@ -30,11 +60,6 @@ Status TraceCollector::CollectData(RunMetadata* run_metadata) { return Status::OK(); } -/* static */ std::unique_ptr EagerProfiler::Create( - EagerContext* const eager_context) { - return absl::WrapUnique(new EagerProfiler(eager_context)); -} - Status EagerProfiler::Start() { if (context_ == nullptr) { return Status(tensorflow::error::Code::FAILED_PRECONDITION, @@ -55,6 +80,25 @@ Status EagerProfiler::CollectData(RunMetadata* run_metadata) { EagerProfiler::EagerProfiler(EagerContext* const eager_context) : context_(eager_context), collector_(eager_context) {} +} // namespace + +std::unique_ptr CreateEagerProfiler( + const ProfilerContext* context) { + if (!context || !context->eager_context) { + return nullptr; + } + return absl::make_unique(context->eager_context); +} + +auto register_eager_profiler_factory = [] { + bool enable; + TF_CHECK_OK( + ReadBoolFromEnvVar("TF_ENABLE_EAGER_RUNTIME_PROFILER", true, &enable)); + if (enable) { + RegisterProfilerFactory(&CreateEagerProfiler); + } + return 0; +}(); } // namespace runtime } // namespace profiler diff --git a/tensorflow/core/profiler/internal/runtime/eager_profiler.h b/tensorflow/core/profiler/internal/runtime/eager_profiler.h deleted file mode 100644 index 7135355e6ff..00000000000 --- a/tensorflow/core/profiler/internal/runtime/eager_profiler.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_RUNTIME_EAGER_PROFILER_H_ -#define TENSORFLOW_CORE_PROFILER_INTERNAL_RUNTIME_EAGER_PROFILER_H_ - -#include "tensorflow/core/common_runtime/eager/context.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" - -namespace tensorflow { -namespace profiler { -namespace runtime { - -class TraceCollector : public RunMetadataListener { - public: - TraceCollector(EagerContext* const eager_context); - - void BeforeClearRunMetadata() override; - - Status CollectData(RunMetadata* run_metadata); - - private: - RunMetadata run_metadata_; - EagerContext* const context_; -}; - -class EagerProfiler : public ProfilerInterface { - public: - static std::unique_ptr Create( - EagerContext* const eager_context); - - Status Start() override; - - Status Stop() override; - - Status CollectData(RunMetadata* run_metadata) override; - - private: - EagerProfiler(EagerContext* const eager_context); - - // Trace is neither copyable nor movable. - EagerProfiler(const EagerProfiler&) = delete; - EagerProfiler& operator=(const EagerProfiler&) = delete; - - EagerContext* const context_; - TraceCollector collector_; -}; - -} // namespace runtime -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_RUNTIME_EAGER_PROFILER_H_ diff --git a/tensorflow/core/profiler/internal/tfprof_op.cc b/tensorflow/core/profiler/internal/tfprof_op.cc index 3dce1d85db3..6e9178c7164 100644 --- a/tensorflow/core/profiler/internal/tfprof_op.cc +++ b/tensorflow/core/profiler/internal/tfprof_op.cc @@ -182,7 +182,7 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, // TODO(xpan): Is it the right choice? root_->formatted_str = display_str; } - // Populate the chidren field. + // Populate the children field. auto* pre_pb = root_->mutable_proto(); for (auto& show_node : show_nodes) { pre_pb->clear_children(); diff --git a/tensorflow/core/profiler/internal/traceme_recorder.cc b/tensorflow/core/profiler/internal/traceme_recorder.cc index 0369e0b96de..b2a20c7955b 100644 --- a/tensorflow/core/profiler/internal/traceme_recorder.cc +++ b/tensorflow/core/profiler/internal/traceme_recorder.cc @@ -14,79 +14,70 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/internal/traceme_recorder.h" -// To avoid unneccesary synchronization between threads, each thread has a -// ThreadLocalRecorder that independently records its events. -// -// Events are stored in an EventQueue implemented as a linked-list of blocks, -// with start and end pointers: -// [ events........ | next-]--> [ events......... | next ] -// ^start_block ^start ^end_block ^end -// -// Record() writes at end, and then advances it, allocating a block if needed. -// Clear() takes ownership of events in the range [start, end). -// The end pointer is atomic so these can be concurrent. -// -// If a thread dies, the ThreadLocalRecorder's destructor hands its data off to -// the orphaned_events list. 
+#include -#include -#include "absl/container/flat_hash_map.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/stream_executor/lib/initialize.h" namespace tensorflow { namespace profiler { -// Default value for g_trace_level when tracing is disabled -constexpr static int kTracingDisabled = -1; +std::atomic TraceMeRecorder::trace_level_ = + ATOMIC_VAR_INIT(TraceMeRecorder::kTracingDisabled); -namespace internal { -std::atomic g_trace_level = ATOMIC_VAR_INIT(kTracingDisabled); -} // namespace internal +// Implementation of TraceMeRecorder::trace_level_ must be lock-free for faster +// execution of the TraceMe() public API. This can be commented (if compilation +// is failing) but execution might be slow (even when host tracing is disabled). +static_assert(ATOMIC_INT_LOCK_FREE == 2, "Assumed atomic was lock free"); namespace { -class ThreadLocalRecorder; - -struct Data { - // Lock for only rare events - start/stop, thread death. - mutex global_lock; - // Map of the static container instances (thread_local storage) for each - // thread, that store the trace events. - absl::flat_hash_map threads - GUARDED_BY(global_lock); - // Events traced from threads that died during tracing. - TraceMeRecorder::Events orphaned_events GUARDED_BY(global_lock); -}* g_data = nullptr; - // A single-producer single-consumer queue of Events. -// Only the owner thread can write events, writing is lock-free. -// Consume is also lock-free in this class. // -// Internally, we have a linked list of blocks containing numbered slots. -// start is the first occupied slot, end is the first unoccupied slot. +// Implemented as a linked-list of blocks containing numbered slots, with start +// and end pointers: +// +// [ events........ | next-]--> [ events......... | next ] +// ^start_block_ ^start_ ^end_block_ ^end_ +// +// start_ is the first occupied slot, end_ is the first unoccupied slot. +// +// Push writes at end_, and then advances it, allocating a block if needed. +// PopAll takes ownership of events in the range [start_, end_). +// The end_ pointer is atomic so Push and PopAll can be concurrent. +// +// Push and PopAll are lock free and each might be called from at most one +// thread. Push is only called by the owner thread. PopAll is called by the +// owner thread when it shuts down, or by the tracing control thread. +// +// Thus, PopAll might race with Push, so PopAll only removes events that were +// in the queue when it was invoked. If Push is called while PopAll is active, +// the new event remains in the queue. Thus, the tracing control thread should +// call PopAll when tracing stops to remove events created during tracing, but +// also when tracing starts again to clear any remaining events. class EventQueue { public: EventQueue() - : start_block_(new Block{0, nullptr}), end_block_(start_block_) {} + : start_block_(new Block{/*start=*/0, /*next=*/nullptr}), + start_(start_block_->start), + end_block_(start_block_), + end_(start_) {} - // REQUIRES: Consume() was called since the last Push(). + // REQUIRES: PopAll() was called since the last Push(). // Memory should be deallocated and trace events destroyed on destruction. // This doesn't require global lock as this discards all the stored trace - // events and we assume of destruction of this class only after the last + // events and we assume of destruction of this instance only after the last // Push() has been called. 
~EventQueue() { - DCHECK_EQ(start_, end_.load()) << "EventQueue destroyed without Consume()"; + DCHECK(Empty()) << "EventQueue destroyed without PopAll()"; delete end_block_; } // Add a new event to the back of the queue. Fast and lock-free. void Push(TraceMeRecorder::Event&& event) { - uint64 end = end_.load(std::memory_order_relaxed); + size_t end = end_.load(std::memory_order_relaxed); new (&end_block_->events[end++ - end_block_->start].event) TraceMeRecorder::Event(std::move(event)); - if (ABSL_PREDICT_FALSE(end - end_block_->start == Block::kLength)) { + if (ABSL_PREDICT_FALSE(end - end_block_->start == Block::kNumSlots)) { auto* new_block = new Block{end, nullptr}; end_block_->next = new_block; end_block_ = new_block; @@ -94,41 +85,53 @@ class EventQueue { end_.store(end, std::memory_order_release); // Write index after contents. } - // Retrieve and remove all events in the queue. - std::vector Consume() { + // Retrieve and remove all events in the queue at the time of invocation. + // If Push is called while PopAll is active, the new event will not be + // removed from the queue. + std::vector PopAll() { // Read index before contents. - uint64 end = end_.load(std::memory_order_acquire); + size_t end = end_.load(std::memory_order_acquire); std::vector result; result.reserve(end - start_); while (start_ != end) { - Shift(&result); + result.emplace_back(Pop()); } return result; } private: - // Shift one event off the front of the queue into *out. - void Shift(std::vector* out) { + // Returns true if the queue is empty at the time of invocation. + bool Empty() const { + return (start_ == end_.load(std::memory_order_acquire)); + } + + // Remove one event off the front of the queue and return it. + // REQUIRES: The queue must not be empty. + TraceMeRecorder::Event Pop() { + DCHECK(!Empty()); // Move the next event into the output. auto& event = start_block_->events[start_++ - start_block_->start].event; - out->push_back(std::move(event)); + TraceMeRecorder::Event out = std::move(event); event.~Event(); // Events must be individually destroyed. // If we reach the end of a block, we own it and should delete it. // The next block is present: end always points to something. - if (start_ - start_block_->start == Block::kLength) { + if (ABSL_PREDICT_FALSE(start_ - start_block_->start == Block::kNumSlots)) { auto* next_block = start_block_->next; delete start_block_; start_block_ = next_block; + DCHECK_EQ(start_, start_block_->start); } + return out; } - // The number of slots in a block. Chosen so that the block fits in 64k. struct Block { - static constexpr size_t kLength = - ((1 << 16) - (sizeof(uint64) + sizeof(std::atomic))) / + // The number of slots in a block is chosen so the block fits in 64 KiB. + static constexpr size_t kSize = 1 << 16; + static constexpr size_t kNumSlots = + (kSize - (sizeof(size_t) + sizeof(Block*))) / sizeof(TraceMeRecorder::Event); - const uint64 start; // The number of the first slot. + size_t start; // The number of the first slot. Block* next; // Defer construction of Event until the data is available. // Must also destroy manually, as the block may not fill entirely. @@ -136,113 +139,108 @@ class EventQueue { MaybeEvent() {} ~MaybeEvent() {} TraceMeRecorder::Event event; - } events[kLength]; + } events[kNumSlots]; }; + static_assert(sizeof(Block) <= Block::kSize, ""); + // Head of list for reading. Only accessed by consumer thread. Block* start_block_; - uint64 start_ = 0; + size_t start_; // Tail of list for writing. Accessed by producer thread. 
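The EventQueue above depends on one ordering rule for its lock-free hand-off: the producer fully constructs an event before publishing the new end index with a release store, and the consumer acquire-loads the index before reading any slot. Below is a minimal, self-contained toy version of that ordering (a fixed-size ring rather than the linked blocks used above, with overflow handling deliberately omitted), intended only to make the memory-ordering contract concrete.

```cpp
// Toy single-producer single-consumer queue showing release/acquire publishing.
#include <atomic>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

class SpscQueueSketch {
 public:
  // Push: called only by the owner (producer) thread.
  void Push(std::string event) {
    size_t end = end_.load(std::memory_order_relaxed);
    slots_[end % kNumSlots] = std::move(event);      // write contents first...
    end_.store(end + 1, std::memory_order_release);  // ...then publish index.
  }

  // PopAll: called by the consumer; removes only the events that were already
  // published when it started, matching the PopAll contract described above.
  std::vector<std::string> PopAll() {
    size_t end = end_.load(std::memory_order_acquire);  // read index first.
    std::vector<std::string> out;
    out.reserve(end - start_);
    while (start_ != end) {
      out.push_back(std::move(slots_[start_ % kNumSlots]));
      ++start_;
    }
    return out;
  }

 private:
  static constexpr size_t kNumSlots = 1 << 10;  // fixed capacity for the sketch
  std::string slots_[kNumSlots];
  size_t start_ = 0;               // consumer-only
  std::atomic<size_t> end_{0};     // written by producer, read by consumer
};
```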
Block* end_block_; - std::atomic end_ = {0}; // Atomic: also read by consumer thread. + std::atomic end_; // Atomic: also read by consumer thread. }; -class ThreadLocalRecorder { +} // namespace + +// To avoid unnecessary synchronization between threads, each thread has a +// ThreadLocalRecorder that independently records its events. +class TraceMeRecorder::ThreadLocalRecorder { public: - // The recorder is created the first time Record() is called on a thread. + // The recorder is created the first time TraceMeRecorder::Record() is called + // on a thread. ThreadLocalRecorder() { auto* env = Env::Default(); info_.tid = env->GetCurrentThreadId(); env->GetCurrentThreadName(&info_.name); - mutex_lock lock(g_data->global_lock); - g_data->threads.emplace(info_.tid, this); + TraceMeRecorder::Get()->RegisterThread(info_.tid, this); } // The destructor is called when the thread shuts down early. - // We unregister this thread, and move its events to orphaned_events. - ~ThreadLocalRecorder() { - mutex_lock lock(g_data->global_lock); - g_data->threads.erase(info_.tid); - g_data->orphaned_events.push_back(Clear()); - } + ~ThreadLocalRecorder() { TraceMeRecorder::Get()->UnregisterThread(Clear()); } - // This is the performance-critical part! + // Record is only called from the owner thread. void Record(TraceMeRecorder::Event&& event) { queue_.Push(std::move(event)); } - TraceMeRecorder::ThreadEvents Clear() - EXCLUSIVE_LOCKS_REQUIRED(g_data->global_lock) { - return {info_, queue_.Consume()}; - } + // Clear is called from the control thread when tracing starts/stops, or from + // the owner thread when it shuts down (see destructor). + TraceMeRecorder::ThreadEvents Clear() { return {info_, queue_.PopAll()}; } private: TraceMeRecorder::ThreadInfo info_; EventQueue queue_; }; -// Gather events from all active threads, and clear their buffers. The global -// lock is held, so no threads can be added/removed for the duration while we -// consume the collected trace entries. This will block any new thread and also -// the starting and stopping of TraceMeRecorder, hence, this is performance -// critical and should be kept fast. -TraceMeRecorder::Events Clear() EXCLUSIVE_LOCKS_REQUIRED(g_data->global_lock) { +/*static*/ TraceMeRecorder* TraceMeRecorder::Get() { + static TraceMeRecorder* singleton = new TraceMeRecorder; + return singleton; +} + +void TraceMeRecorder::RegisterThread(int32 tid, ThreadLocalRecorder* thread) { + mutex_lock lock(mutex_); + threads_.emplace(tid, thread); +} + +void TraceMeRecorder::UnregisterThread(TraceMeRecorder::ThreadEvents&& events) { + mutex_lock lock(mutex_); + threads_.erase(events.thread.tid); + orphaned_events_.push_back(std::move(events)); +} + +// This method is performance critical and should be kept fast. It is called +// when tracing starts/stops. The mutex is held, so no threads can be +// registered/unregistered. This prevents calling ThreadLocalRecorder::Clear +// from two different threads. 
+TraceMeRecorder::Events TraceMeRecorder::Clear() { TraceMeRecorder::Events result; - std::swap(g_data->orphaned_events, result); - for (const auto& entry : g_data->threads) { + std::swap(orphaned_events_, result); + for (const auto& entry : threads_) { auto* recorder = entry.second; result.push_back(recorder->Clear()); } return result; } -} // namespace - -bool TraceMeRecorder::Start(int level) { +bool TraceMeRecorder::StartRecording(int level) { level = std::max(0, level); - mutex_lock lock(g_data->global_lock); + mutex_lock lock(mutex_); + // Change trace_level_ while holding mutex_. int expected = kTracingDisabled; - if (!internal::g_trace_level.compare_exchange_strong( - expected, level, std::memory_order_acq_rel)) { - return false; + bool started = trace_level_.compare_exchange_strong( + expected, level, std::memory_order_acq_rel); + if (started) { + // We may have old events in buffers because Record() raced with Stop(). + Clear(); } - // We may have old events in buffers because Record() raced with Stop(). - Clear(); - return true; + return started; } - void TraceMeRecorder::Record(Event event) { static thread_local ThreadLocalRecorder thread_local_recorder; thread_local_recorder.Record(std::move(event)); } -// Only one thread is expected to call Stop() as first instance of XprofSession -// prevents another XprofSession from doing any profiling. -TraceMeRecorder::Events TraceMeRecorder::Stop() { - mutex_lock lock(g_data->global_lock); - if (internal::g_trace_level.exchange( - kTracingDisabled, std::memory_order_acq_rel) == kTracingDisabled) { - return {}; - } - return Clear(); -} - -TraceMeRecorder::Events TraceMeRecorder::Collect() { - mutex_lock lock(g_data->global_lock); - if (internal::g_trace_level.load(std::memory_order_acquire) == +TraceMeRecorder::Events TraceMeRecorder::StopRecording() { + TraceMeRecorder::Events events; + mutex_lock lock(mutex_); + // Change trace_level_ while holding mutex_. + if (trace_level_.exchange(kTracingDisabled, std::memory_order_acq_rel) != kTracingDisabled) { - return {}; + events = Clear(); } - return Clear(); + return events; } } // namespace profiler } // namespace tensorflow - -REGISTER_MODULE_INITIALIZER(traceme_recorder, { - tensorflow::profiler::g_data = new tensorflow::profiler::Data(); - - // Workaround for b/35097229, the first block-scoped thread_local can - // trigger false positives in the heap checker. Currently triggered by - // //perftools/accelerators/xprof/xprofilez/integration_tests:xla_hlo_trace_test - static thread_local tensorflow::string fix_deadlock ABSL_ATTRIBUTE_UNUSED; -}); diff --git a/tensorflow/core/profiler/internal/traceme_recorder.h b/tensorflow/core/profiler/internal/traceme_recorder.h index 1e66b1e5bb3..374029714a3 100644 --- a/tensorflow/core/profiler/internal/traceme_recorder.h +++ b/tensorflow/core/profiler/internal/traceme_recorder.h @@ -16,17 +16,18 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_ #include +#include +#include #include + #include "absl/base/optimization.h" +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { namespace profiler { -namespace internal { -extern std::atomic g_trace_level; -} // namespace internal - // TraceMeRecorder is a singleton repository of TraceMe events. // It can be safely and cheaply appended to by multiple threads. 
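The RegisterThread/UnregisterThread methods above support the idiom that Record() relies on: a thread-local recorder registers itself with the singleton the first time a thread records an event, and its destructor hands any remaining events back when the thread exits. Here is a self-contained sketch of that idiom with simplified, made-up names (Registry, PerThread) and std:: primitives instead of tensorflow::mutex.

```cpp
// Sketch of thread_local self-registration with a leaked singleton registry.
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace sketch {

class PerThread;

// Tracks live per-thread recorders plus events of threads that already exited.
class Registry {
 public:
  static Registry* Get() {
    static Registry* singleton = new Registry;  // leaked, outlives all threads
    return singleton;
  }
  void Register(std::thread::id tid, PerThread* recorder) {
    std::lock_guard<std::mutex> lock(mu_);
    threads_[tid] = recorder;
  }
  void Unregister(std::thread::id tid, std::vector<std::string> leftover) {
    std::lock_guard<std::mutex> lock(mu_);
    threads_.erase(tid);
    orphaned_.push_back(std::move(leftover));
  }

 private:
  std::mutex mu_;
  std::map<std::thread::id, PerThread*> threads_;
  std::vector<std::vector<std::string>> orphaned_;
};

// Registers itself on construction (first use on a thread) and hands its
// remaining events back on destruction (thread exit).
class PerThread {
 public:
  PerThread() { Registry::Get()->Register(std::this_thread::get_id(), this); }
  ~PerThread() {
    Registry::Get()->Unregister(std::this_thread::get_id(), std::move(events_));
  }
  void Record(std::string event) { events_.push_back(std::move(event)); }

 private:
  std::vector<std::string> events_;
};

// Mirrors the shape of TraceMeRecorder::Record: the recorder is constructed
// lazily, once per thread, the first time that thread records an event.
void Record(std::string event) {
  static thread_local PerThread per_thread;
  per_thread.Record(std::move(event));
}

}  // namespace sketch
```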
// @@ -49,45 +50,68 @@ class TraceMeRecorder { uint64 end_time; // 0 = missing }; struct ThreadInfo { - int64 tid; + int32 tid; string name; }; struct ThreadEvents { - const ThreadInfo thread; + ThreadInfo thread; std::vector events; }; using Events = std::vector; // Starts recording of TraceMe(). // Only traces <= level will be recorded. - // Level must be >= 0. - // If level is 0, no traces will be recorded. - static bool Start(int level); + // Level must be >= 0. If level is 0, no traces will be recorded. + static bool Start(int level) { return Get()->StartRecording(level); } // Stops recording and returns events recorded since Start(). - static Events Stop(); - - // Returns events recorded till now without stopping the recording. Empty - // container is returned if the recorder was already stopped. - static Events Collect(); + // Events passed to Record after Stop has started will be dropped. + static Events Stop() { return Get()->StopRecording(); } // Returns whether we're currently recording. Racy, but cheap! static inline bool Active(int level = 1) { - return ABSL_PREDICT_FALSE( - internal::g_trace_level.load(std::memory_order_acquire) >= level); + return ABSL_PREDICT_FALSE(trace_level_.load(std::memory_order_acquire) >= + level); } - static void Record(Event); + // Records an event. Non-blocking. + static void Record(Event event); private: + // Default value for trace_level_ when tracing is disabled + static constexpr int kTracingDisabled = -1; + + class ThreadLocalRecorder; + + // Returns singleton. + static TraceMeRecorder* Get(); + + TraceMeRecorder() = default; + // No copy and assignment TraceMeRecorder(const TraceMeRecorder&) = delete; TraceMeRecorder& operator=(const TraceMeRecorder&) = delete; - // Implementation of g_trace_level must be lock-free for faster execution - // of the TraceMe() public API. This can be commented (if compilation is - // failing) but execution might be slow (even when host tracing is disabled). - static_assert(ATOMIC_INT_LOCK_FREE == 2, "Assumed atomic was lock free"); + void RegisterThread(int32 tid, ThreadLocalRecorder* thread); + void UnregisterThread(ThreadEvents&& events); + + bool StartRecording(int level); + Events StopRecording(); + + // Gathers events from all active threads, and clears their buffers. + Events Clear() EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Current trace level. + // Static atomic so TraceMeRecorder::Active can be fast and non-blocking. + // Modified by TraceMeRecorder singleton when tracing starts/stops. + static std::atomic trace_level_; + + mutex mutex_; + // Map of the static container instances (thread_local storage) for each + // thread. While active, a ThreadLocalRecorder stores trace events. + absl::flat_hash_map threads_ GUARDED_BY(mutex_); + // Events from threads that died during recording. 
+ TraceMeRecorder::Events orphaned_events_ GUARDED_BY(mutex_); }; } // namespace profiler diff --git a/tensorflow/core/profiler/internal/traceme_recorder_test.cc b/tensorflow/core/profiler/internal/traceme_recorder_test.cc index ec588af1d60..6899658c0a2 100644 --- a/tensorflow/core/profiler/internal/traceme_recorder_test.cc +++ b/tensorflow/core/profiler/internal/traceme_recorder_test.cc @@ -46,31 +46,6 @@ TEST(RecorderTest, SingleThreaded) { ::testing::ElementsAre(Named("during1"), Named("during2"))); } -TEST(RecorderTest, CollectionBeforeStop) { - uint64 start_time = Env::Default()->NowNanos(); - uint64 end_time = start_time + kNanosInSec; - - TraceMeRecorder::Record({1, "ignored", start_time, end_time}); - TraceMeRecorder::Start(/*level=*/1); - TraceMeRecorder::Record({2, "during1", start_time, end_time}); - TraceMeRecorder::Record({3, "during2", start_time, end_time}); - auto collected_results = TraceMeRecorder::Collect(); - TraceMeRecorder::Record({4, "after_collect", start_time, end_time}); - auto stopped_results = TraceMeRecorder::Stop(); - TraceMeRecorder::Record({5, "after_stop", start_time, end_time}); - auto results_after_stop = TraceMeRecorder::Collect(); - - ASSERT_EQ(collected_results.size(), 1); - EXPECT_THAT(collected_results[0].events, - ::testing::ElementsAre(Named("during1"), Named("during2"))); - - ASSERT_EQ(stopped_results.size(), 1); - EXPECT_THAT(stopped_results[0].events, - ::testing::ElementsAre(Named("after_collect"))); - - ASSERT_EQ(results_after_stop.size(), 0); -} - void SpinNanos(int nanos) { uint64 deadline = Env::Default()->NowNanos() + nanos; while (Env::Default()->NowNanos() < deadline) { diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index f078099321e..14040345c4e 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -11,6 +11,10 @@ load( "//tensorflow:tensorflow.bzl", "tf_cuda_library", ) +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_profiler_lib_deps", +) tf_cuda_library( name = "profiler_session", @@ -22,15 +26,10 @@ tf_cuda_library( ], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/core/common_runtime/eager:context", - "//tensorflow/core/profiler/internal/gpu:tracer", - "//tensorflow/core/profiler/internal/runtime:eager_profiler", "//tensorflow/core/profiler/internal:profiler_interface", "//tensorflow/core/profiler:protos_all_cc", ] + select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", - ], + "//tensorflow:android": [], "//conditions:default": [ "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", @@ -39,11 +38,26 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", - "//tensorflow/core:device_tracer", ], }), ) +tf_cuda_library( + name = "profiler_graph_lib", + visibility = ["//tensorflow:internal"], + deps = tf_additional_profiler_lib_deps(), + alwayslink = 1, +) + +tf_cuda_library( + name = "profiler_eager_lib", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core/profiler/internal/runtime:eager_profiler", + ] + tf_additional_profiler_lib_deps(), + alwayslink = 1, +) + tf_cuda_library( name = "traceme", srcs = ["traceme.cc"], @@ -53,6 +67,11 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core/profiler/internal:traceme_recorder", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", ], ) + +filegroup( + name = "mobile_srcs", + srcs = glob(["*"]), + 
visibility = ["//visibility:public"], +) diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc index 86dd4c1e152..3913260360b 100644 --- a/tensorflow/core/profiler/lib/profiler_session.cc +++ b/tensorflow/core/profiler/lib/profiler_session.cc @@ -14,17 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/lib/profiler_session.h" + #include #include -#include "tensorflow/core/common_runtime/eager/context.h" + +#include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/internal/gpu/tracer.h" -#include "tensorflow/core/profiler/internal/runtime/eager_profiler.h" -#include "tensorflow/core/profiler/trace_events.pb.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/trace_events.pb.h" namespace tensorflow { @@ -68,7 +68,8 @@ void AssignLanes(RunMetadata* run_metadata) { void ConvertRunMetadataToTraceEvent(RunMetadata* run_metadata, profiler::Trace* trace, - const uint64 profile_start_time_micros) { + const uint64 profile_start_time_micros, + const uint64 profile_end_time_micros) { AssignLanes(run_metadata); auto trace_devices = trace->mutable_devices(); @@ -95,7 +96,9 @@ void ConvertRunMetadataToTraceEvent(RunMetadata* run_metadata, // Emit events. for (auto node : run_metadata->step_stats().dev_stats(device_id).node_stats()) { - if (node.all_start_micros() < profile_start_time_micros) { + if (node.all_start_micros() < profile_start_time_micros || + node.all_start_micros() + node.all_end_rel_micros() > + profile_end_time_micros) { continue; } auto* event = trace->add_trace_events(); @@ -114,7 +117,6 @@ void ConvertRunMetadataToTraceEvent(RunMetadata* run_metadata, // TODO(fishx): Convert allocation data as well. 
} - } // namespace /*static*/ std::unique_ptr ProfilerSession::Create( @@ -127,15 +129,15 @@ Status ProfilerSession::Status() { return status_; } -Status ProfilerSession::SerializeToString(string* content) { +Status ProfilerSession::CollectData(RunMetadata* run_metadata) { mutex_lock l(mutex_); if (!status_.ok()) return status_; for (auto& profiler : profilers_) { profiler->Stop().IgnoreError(); } - RunMetadata run_metadata; + for (auto& profiler : profilers_) { - profiler->CollectData(&run_metadata).IgnoreError(); + profiler->CollectData(run_metadata).IgnoreError(); } if (active_) { @@ -144,9 +146,17 @@ Status ProfilerSession::SerializeToString(string* content) { active_ = false; } - profiler::Trace trace; + return Status::OK(); +} - ConvertRunMetadataToTraceEvent(&run_metadata, &trace, start_time_micros_); +Status ProfilerSession::SerializeToString(string* content) { + RunMetadata run_metadata; + TF_RETURN_IF_ERROR(CollectData(&run_metadata)); + + profiler::Trace trace; + ConvertRunMetadataToTraceEvent( + &run_metadata, &trace, start_time_micros_, + Env::Default()->NowNanos() / EnvTime::kMicrosToNanos); trace.SerializeToString(content); return Status::OK(); @@ -156,23 +166,22 @@ ProfilerSession::ProfilerSession(ProfilerContext* const context) : active_(!session_active.exchange(true)), start_time_micros_(Env::Default()->NowNanos() / EnvTime::kMicrosToNanos) { if (!active_) { - status_ = tensorflow::Status(tensorflow::error::Code::UNAVAILABLE, - "Another profiling session is active."); + status_ = tensorflow::Status(error::UNAVAILABLE, + "Another profiler session is active."); return; } - LOG(INFO) << "Profile Session started."; - - if (context->eager_context != nullptr) { - profilers_.push_back(tensorflow::profiler::runtime::EagerProfiler::Create( - context->eager_context)); - } - profilers_.push_back(tensorflow::profiler::gpu::Tracer::Create()); + LOG(INFO) << "Profiler session started."; + CreateProfilers(context, &profilers_); status_ = Status::OK(); for (auto& profiler : profilers_) { - profiler->Start().IgnoreError(); + auto start_status = profiler->Start(); + if (!start_status.ok()) { + LOG(WARNING) << "Encountered error while starting profiler: " + << start_status.ToString(); + } } } @@ -186,5 +195,4 @@ ProfilerSession::~ProfilerSession() { session_active.store(false); } } - } // namespace tensorflow diff --git a/tensorflow/core/profiler/lib/profiler_session.h b/tensorflow/core/profiler/lib/profiler_session.h index 07276571244..b1a12336a57 100644 --- a/tensorflow/core/profiler/lib/profiler_session.h +++ b/tensorflow/core/profiler/lib/profiler_session.h @@ -15,17 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_SESSION_H_ #define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_SESSION_H_ -#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/internal/profiler_interface.h" namespace tensorflow { -struct ProfilerContext { - EagerContext* eager_context = nullptr; -}; - // A profiler which will start profiling when creating the object and will stop // when either the object is destroyed or SerializedToString is called. It will // profile all operations run under the given EagerContext. 
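For context on the profiler_session.h changes here, the following is a hedged usage sketch of the ProfilerSession API as it appears in these hunks (Create, Status, the newly public CollectData, and SerializeToString). It assumes Create takes the same ProfilerContext pointer that the constructor receives; error handling and EagerContext setup are omitted, and ProfileSomething is an illustrative name.

```cpp
// Sketch: profiling a block of work with ProfilerSession, per the hunks above.
#include <memory>
#include <string>

#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/protobuf/config.pb.h"

void ProfileSomething() {
  tensorflow::ProfilerContext context;  // eager_context left null in this sketch
  auto session = tensorflow::ProfilerSession::Create(&context);
  if (!session->Status().ok()) return;  // e.g. another session is active

  // ... run the workload to be profiled ...

  tensorflow::RunMetadata run_metadata;
  session->CollectData(&run_metadata).IgnoreError();

  // Alternatively, serialize directly to the trace_events wire format:
  // std::string content;
  // session->SerializeToString(&content).IgnoreError();
}
```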
@@ -44,6 +40,7 @@ class ProfilerSession { tensorflow::Status Status() LOCKS_EXCLUDED(mutex_); + tensorflow::Status CollectData(RunMetadata* run_metadata); tensorflow::Status SerializeToString(string* content) LOCKS_EXCLUDED(mutex_); private: @@ -54,8 +51,8 @@ class ProfilerSession { ProfilerSession(const ProfilerSession&) = delete; ProfilerSession& operator=(const ProfilerSession&) = delete; - std::vector> - profilers_ GUARDED_BY(mutex_); + std::vector> profilers_ + GUARDED_BY(mutex_); // True if the session is active. bool active_ GUARDED_BY(mutex_); diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h index b9fae3d37f0..5a5ba524856 100644 --- a/tensorflow/core/profiler/lib/traceme.h +++ b/tensorflow/core/profiler/lib/traceme.h @@ -18,7 +18,6 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -27,13 +26,25 @@ limitations under the License. namespace tensorflow { namespace profiler { -// This is specifically used in xprof_bridge for instrumenting Tensorflow ops. +// This is specifically used for instrumenting Tensorflow ops. // Takes input as whether a TF op is expensive or not and returns the TraceMe // level to be assigned to trace that particular op. Assigns level 2 for -// expensive ops (these are high-level details and shown by default in xprof +// expensive ops (these are high-level details and shown by default in profiler // UI). Assigns level 3 for cheap ops (low-level details not shown by default). inline int GetTFTraceMeLevel(bool is_expensive) { return is_expensive ? 2 : 3; } +// Predefined levels: +// - Level 1 (kCritical) is the default and used only for user instrumentation. +// - Level 2 (kInfo) is used by profiler for instrumenting high level program +// execution details (expensive TF ops, XLA ops, etc). +// - Level 3 (kVerbose) is also used by profiler to instrument more verbose +// (low-level) program execution details (cheap TF ops, etc). +enum TraceMeLevel { + kCritical = 1, + kInfo = 2, + kVerbose = 3, +}; + // This class permits user-specified (CPU) tracing activities. A trace activity // is started when an object of this class is created and stopped when the // object is destroyed. @@ -63,12 +74,8 @@ class TraceMe { // in the UI. Level defines the trace priority, used for filtering TraceMe // events. By default, traces with TraceMe level <= 2 are recorded. Levels: // - Must be a positive integer. - // - Level 1 is the default and used only for user instrumentation. - // - Level 2 is used by xprof for instrumenting high level program execution - // details (expensive TF ops, XLA ops, etc). - // - Level 3 is also used by xprof to instrument more verbose (low-level) - // program execution details (cheap TF ops, etc). - // Users are welcome to use level >= 2 in their code, if they wish to filter + // - Can be a value in enum TraceMeLevel. + // Users are welcome to use level > 3 in their code, if they wish to filter // out their host traces based on verbosity. explicit TraceMe(absl::string_view activity_name, int level = 1) { DCHECK_GE(level, 1); @@ -113,7 +120,7 @@ class TraceMe { // type that the string() constructor can take. // name_generator is templated, rather than a std::function to avoid // allocations std::function might make even if never called. 
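A short usage sketch of the TraceMe API as shown in this hunk: it uses only the constructors, levels, and Stop() visible in this file, and the surrounding function and variable names are illustrative, not TensorFlow code. The lambda form defers building the event name until tracing is known to be active.

```cpp
// Sketch: user instrumentation with TraceMe and the TraceMeLevel constants.
#include "absl/strings/str_cat.h"
#include "tensorflow/core/profiler/lib/traceme.h"

void Step(int step_id) {
  // Level defaults to kCritical (1): recorded whenever tracing is on.
  tensorflow::profiler::TraceMe trace_step("train_step");

  {
    // Verbose detail, filtered out unless the tracing level is >= 3.
    // The name is only concatenated if tracing is active.
    tensorflow::profiler::TraceMe detail(
        [&] { return absl::StrCat("load_batch:", step_id); },
        tensorflow::profiler::TraceMeLevel::kVerbose);
    // ... load data ...
  }  // `detail` ends here.

  // ... run the step ...

  trace_step.Stop();  // optional: end the activity before scope exit
}
```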
- // Usage: xprof::TraceMe([&]{ return StrCat(prefix, ":", postfix); }); + // Usage: profiler::TraceMe([&]{ return StrCat(prefix, ":", postfix); }); template explicit TraceMe(NameGeneratorT name_generator, int level = 1) { DCHECK_GE(level, 1); @@ -125,7 +132,10 @@ class TraceMe { } } - ~TraceMe() { + // Stop tracing the activity. Called by the destructor, but exposed to allow + // stopping tracing before the object goes out of scope. Only has an effect + // the first time it is called. + void Stop() { // We do not need to check the trace level again here. // - If tracing wasn't active to start with, we have kUntracedActivity. // - If tracing was active and was stopped, we have @@ -133,16 +143,19 @@ class TraceMe { // - If tracing was active and was restarted at a lower level, we may // spuriously record the event. This is extremely rare, and acceptable as // event will be discarded when its start timestamp fall outside of the - // start/stop session timestamp (recorded in XprofResponse). + // start/stop session timestamp. if (start_time_ != kUntracedActivity) { if (TraceMeRecorder::Active()) { TraceMeRecorder::Record({kCompleteActivity, std::move(no_init_.name), start_time_, Env::Default()->NowNanos()}); } no_init_.name.~string(); + start_time_ = kUntracedActivity; } } + ~TraceMe() { Stop(); } + // TraceMe is not movable or copyable. TraceMe(const TraceMe &) = delete; TraceMe &operator=(const TraceMe &) = delete; diff --git a/tensorflow/core/profiler/profiler_service.proto b/tensorflow/core/profiler/profiler_service.proto index 77702c3c900..be7ff7b5b19 100644 --- a/tensorflow/core/profiler/profiler_service.proto +++ b/tensorflow/core/profiler/profiler_service.proto @@ -1,19 +1,18 @@ syntax = "proto3"; + package tensorflow; import "tensorflow/core/framework/graph.proto"; -import "tensorflow/core/protobuf/config.proto"; import "tensorflow/core/profiler/op_profile.proto"; +import "tensorflow/core/protobuf/config.proto"; // The ProfilerService service retrieves performance information about // the programs running on connected devices over a period of time. service ProfilerService { // Starts a profiling session, blocks until it completes, and returns data. - rpc Profile(ProfileRequest) returns (ProfileResponse) { - } + rpc Profile(ProfileRequest) returns (ProfileResponse) {} // Collects profiling data and returns user-friendly metrics. - rpc Monitor(MonitorRequest) returns (MonitorResponse) { - } + rpc Monitor(MonitorRequest) returns (MonitorResponse) {} } message ProfileOptions { @@ -120,8 +119,10 @@ message MonitorRequest { // information, step time information, etc. Do not use this option if the TPU // host is being very heavily used. int32 monitoring_level = 2; + // True to display timestamp in monitoring result. 
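// Illustrative sketch, not part of this patch: a monitoring client could opt
// into timestamped output through the `timestamp` field added just below,
// using the generated proto setters. The duration and level values here are
// arbitrary examples, and the generated header path is assumed.

#include "tensorflow/core/profiler/profiler_service.pb.h"

tensorflow::MonitorRequest MakeMonitorRequest() {
  tensorflow::MonitorRequest request;
  request.set_duration_ms(1000);
  request.set_monitoring_level(1);
  request.set_timestamp(true);  // ask for a "%H:%M:%S" timestamp in the result
  return request;
}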
+ bool timestamp = 3; - // next-field: 3 + // next-field: 4 } message MonitorResponse { diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD index 3e5cdaa4984..cb6c1456ec2 100644 --- a/tensorflow/core/profiler/rpc/BUILD +++ b/tensorflow/core/profiler/rpc/BUILD @@ -15,6 +15,7 @@ tf_cuda_library( "//tensorflow/core:grpc_services", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/profiler:protos_all_cc", + "//tensorflow/core/profiler/lib:profiler_eager_lib", "//tensorflow/core/profiler/lib:profiler_session", ], alwayslink = 1, @@ -32,6 +33,7 @@ tf_cuda_library( "//tensorflow/core:grpc_services", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/profiler:protos_all_cc", + "//tensorflow/core/profiler/lib:profiler_eager_lib", "//tensorflow/core/profiler/lib:profiler_session", ], alwayslink = 1, diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD index ed0137f9b21..2ec88b5d61a 100644 --- a/tensorflow/core/profiler/rpc/client/BUILD +++ b/tensorflow/core/profiler/rpc/client/BUILD @@ -33,6 +33,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:grpc_services", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler:protos_all_cc", ], ) @@ -43,7 +44,7 @@ cc_library( hdrs = ["trace_events_to_json.h"], deps = [ "//tensorflow/core:lib", - "//tensorflow/core/profiler:protos_all_cc", + "//tensorflow/core:protos_all_cc", "@jsoncpp_git//:jsoncpp", ], ) @@ -54,6 +55,7 @@ tf_cc_test( deps = [ ":trace_events_to_json", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler:protos_all_cc", diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.cc b/tensorflow/core/profiler/rpc/client/capture_profile.cc index daada69e761..d1cc109eb63 100644 --- a/tensorflow/core/profiler/rpc/client/capture_profile.cc +++ b/tensorflow/core/profiler/rpc/client/capture_profile.cc @@ -74,6 +74,7 @@ ProfileRequest PopulateProfileRequest(int duration_ms, request.add_tools("input_pipeline"); request.add_tools("memory_viewer"); request.add_tools("overview_page"); + request.add_tools("pod_viewer"); *request.mutable_opts() = opts; return request; } @@ -164,9 +165,10 @@ Status NewSession(const string& service_addr, } // Creates an empty event file if not already exists, which indicates that we -// have a profile/plugin/ directory in the current logdir. +// have a plugins/profile/ directory in the current logdir. Status MaybeCreateEmptyEventFile(const tensorflow::string& logdir) { - // Suffix for an empty event file. + // Suffix for an empty event file. it should be kept in sync with + // _EVENT_FILE_SUFFIX in tensorflow/python/eager/profiler.py. constexpr char kProfileEmptySuffix[] = ".profile-empty"; std::vector children; TF_RETURN_IF_ERROR(Env::Default()->GetChildren(logdir, &children)); @@ -231,20 +233,20 @@ Status StartTracing(const tensorflow::string& service_addr, return status; } -MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) { +MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level, + bool timestamp) { MonitorRequest request; request.set_duration_ms(duration_ms); request.set_monitoring_level(monitoring_level); + request.set_timestamp(timestamp); return request; } -// Repeatedly collects profiles and shows user-friendly metrics for -// 'num_queries' time(s). 
void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, - int monitoring_level, int num_queries) { + int monitoring_level, bool timestamp, int num_queries) { for (int query = 0; query < num_queries; ++query) { MonitorRequest request = - PopulateMonitorRequest(duration_ms, monitoring_level); + PopulateMonitorRequest(duration_ms, monitoring_level, timestamp); ::grpc::ClientContext context; ::grpc::ChannelArguments channel_args; diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.h b/tensorflow/core/profiler/rpc/client/capture_profile.h index 98803672479..f8be2220193 100644 --- a/tensorflow/core/profiler/rpc/client/capture_profile.h +++ b/tensorflow/core/profiler/rpc/client/capture_profile.h @@ -26,9 +26,10 @@ namespace client { Status ValidateHostPortPair(const string& host_port); // Repeatedly collects profiles and shows user-friendly metrics for -// 'num_queries' time(s). +// 'num_queries' time(s). If timestamp flag is true, timestamp will be +// displayed in "%H:%M:%S" format. void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, - int monitoring_level, int num_queries); + int monitoring_level, bool timestamp, int num_queries); // Starts tracing on a single or multiple hosts and saves the result in the // given logdir. If no trace was collected, retries tracing for diff --git a/tensorflow/core/profiler/rpc/client/dump_tpu_profile.cc b/tensorflow/core/profiler/rpc/client/dump_tpu_profile.cc index ed65c110c9d..81b1ca83f8a 100644 --- a/tensorflow/core/profiler/rpc/client/dump_tpu_profile.cc +++ b/tensorflow/core/profiler/rpc/client/dump_tpu_profile.cc @@ -31,7 +31,7 @@ limitations under the License. #undef ERROR #include "tensorflow/core/profiler/op_profile.pb.h" #include "tensorflow/core/profiler/rpc/client/trace_events_to_json.h" -#include "tensorflow/core/profiler/trace_events.pb.h" +#include "tensorflow/core/protobuf/trace_events.pb.h" #include "tensorflow/core/util/events_writer.h" namespace tensorflow { @@ -46,7 +46,6 @@ using ::tensorflow::protobuf::util::MessageToJsonString; using ::tensorflow::str_util::EndsWith; using ::tensorflow::strings::StrCat; -constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; constexpr char kJsonOpProfileFileName[] = "op_profile.json"; constexpr char kJsonTraceFileName[] = "trace.json.gz"; constexpr char kProfilePluginDirectory[] = "plugins/profile/"; diff --git a/tensorflow/core/profiler/rpc/client/trace_events_to_json.cc b/tensorflow/core/profiler/rpc/client/trace_events_to_json.cc index 6adaec55460..e593f696e94 100644 --- a/tensorflow/core/profiler/rpc/client/trace_events_to_json.cc +++ b/tensorflow/core/profiler/rpc/client/trace_events_to_json.cc @@ -14,10 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/rpc/client/trace_events_to_json.h" + #include "include/json/json.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/profiler/trace_events.pb.h" +#include "tensorflow/core/protobuf/trace_events.pb.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/rpc/client/trace_events_to_json.h b/tensorflow/core/profiler/rpc/client/trace_events_to_json.h index d54cc3c619e..6625a12dd9d 100644 --- a/tensorflow/core/profiler/rpc/client/trace_events_to_json.h +++ b/tensorflow/core/profiler/rpc/client/trace_events_to_json.h @@ -17,7 +17,7 @@ limitations under the License. 
#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_TRACE_EVENTS_TO_JSON_H_ #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/trace_events.pb.h" +#include "tensorflow/core/protobuf/trace_events.pb.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/rpc/client/trace_events_to_json_test.cc b/tensorflow/core/profiler/rpc/client/trace_events_to_json_test.cc index 0f883b04dc8..1350d1e3a4b 100644 --- a/tensorflow/core/profiler/rpc/client/trace_events_to_json_test.cc +++ b/tensorflow/core/profiler/rpc/client/trace_events_to_json_test.cc @@ -14,10 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/rpc/client/trace_events_to_json.h" + #include "include/json/json.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/trace_events.pb.h" +#include "tensorflow/core/protobuf/trace_events.pb.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/rpc/profiler_server.h b/tensorflow/core/profiler/rpc/profiler_server.h index 4e8c715ac75..21898d491f0 100644 --- a/tensorflow/core/profiler/rpc/profiler_server.h +++ b/tensorflow/core/profiler/rpc/profiler_server.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/profiler/lib/profiler_session.h" namespace tensorflow { - +class Thread; std::unique_ptr StartProfilerServer( ProfilerContext* const profiler_context, int32 port); } // namespace tensorflow diff --git a/tensorflow/core/protobuf/autotuning.proto b/tensorflow/core/protobuf/autotuning.proto new file mode 100644 index 00000000000..2edc70b34c5 --- /dev/null +++ b/tensorflow/core/protobuf/autotuning.proto @@ -0,0 +1,76 @@ +// This file defines protos that store the results of autotuning various +// operations. +// +// They are in proto format because we want to log them structured. They offer +// tremendous statistical, testing, and debugging value. +syntax = "proto3"; + +package tensorflow; + +import "google/protobuf/any.proto"; +import "google/protobuf/duration.proto"; + +message CudnnVersion { + int32 major = 1; + int32 minor = 2; + int32 patch = 3; +} + +message ComputeCapability { + int32 major = 1; + int32 minor = 2; +} + +message AutotuneResult { + enum FailureKind { + UNKNOWN = 0; + REDZONE_MODIFIED = 1; + WRONG_RESULT = 2; + } + + message FailureResult { + FailureKind kind = 1; + string msg = 2; + + // For failure_kind == WRONG_RESULT, this field indicates the reference + // configuration that we compared against. + // + // Note that the reference algorithm isn't always correct. However, + // empirically it's more correct, as it's "algo 0", less fancy than the + // compared one. + oneof key { + ConvKey reference_conv = 11; + } + } + + message ConvKey { + int64 algorithm = 1; + bool tensor_ops_enabled = 2; + } + + int64 scratch_bytes = 8; + google.protobuf.Duration run_time = 9; + + FailureResult failure = 7; + + oneof key { + ConvKey conv = 5; + } + + // Next ID: 12 +} + +message AutotuningLog { + google.protobuf.Any instr = 1; + + // Records all auto-tuning results per algorithm. + repeated AutotuneResult results = 2; + + CudnnVersion cudnn_version = 3; + ComputeCapability compute_capability = 4; + + // stream_executor::DeviceDescription::pci_bus_id. 
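// Illustrative sketch, not part of this patch: how an autotuner might populate
// these messages, including the device_pci_bus_id field added just below. All
// numeric and string values are arbitrary examples, and the generated header
// path is assumed.

#include "tensorflow/core/protobuf/autotuning.pb.h"

tensorflow::AutotuningLog MakeExampleLog() {
  tensorflow::AutotuningLog log;
  log.mutable_cudnn_version()->set_major(7);
  log.mutable_cudnn_version()->set_minor(4);
  log.mutable_compute_capability()->set_major(7);
  log.mutable_compute_capability()->set_minor(0);
  log.set_device_pci_bus_id("0000:00:04.0");

  tensorflow::AutotuneResult* result = log.add_results();
  result->mutable_conv()->set_algorithm(1);        // ConvKey in the `key` oneof
  result->mutable_conv()->set_tensor_ops_enabled(true);
  result->set_scratch_bytes(1 << 20);
  result->mutable_run_time()->set_nanos(250000);   // google.protobuf.Duration
  return log;
}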
+ string device_pci_bus_id = 5; + + // Next ID: 6 +} diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 3e24235369a..4d6212422fd 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -1,16 +1,18 @@ syntax = "proto3"; package tensorflow; + option cc_enable_arenas = true; option java_outer_classname = "ConfigProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; + // add go_package externally with copybara import "tensorflow/core/framework/cost_graph.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/step_stats.proto"; -import "tensorflow/core/protobuf/debug.proto"; import "tensorflow/core/protobuf/cluster.proto"; +import "tensorflow/core/protobuf/debug.proto"; import "tensorflow/core/protobuf/rewriter_config.proto"; message GPUOptions { @@ -163,16 +165,33 @@ message GPUOptions { // is really not subject to pending use. bool timestamped_allocator = 5; - // If > 0 limit the number of pending kernels on any compute - // stream to this number. - int32 pending_cap = 6; + // reserved id: 6 + + // Parameters for GPUKernelTracker. By default no kernel tracking is done. + // Note that timestamped_allocator is only effective if some tracking is + // specified. + // + // If kernel_tracker_max_interval = n > 0, then a tracking event + // is inserted after every n kernels without an event. + int32 kernel_tracker_max_interval = 7; + // If kernel_tracker_max_bytes = n > 0, then a tracking event is + // inserted after every series of kernels allocating a sum of + // memory >= n. If one kernel allocates b * n bytes, then one + // event will be inserted after it, but it will count as b against + // the pending limit. + int32 kernel_tracker_max_bytes = 8; + // If kernel_tracker_max_pending > 0 then no more than this many + // tracking events can be outstanding at a time. An attempt to + // launch an additional kernel will stall until an event + // completes. + int32 kernel_tracker_max_pending = 9; } // Everything inside experimental is subject to change and is not subject // to API stability guarantees in // https://www.tensorflow.org/guide/version_compat. Experimental experimental = 9; -}; +} // Options passed to the graph optimizer message OptimizerOptions { @@ -267,7 +286,7 @@ message GraphOptions { // Not currently configurable via the public Python API (i.e. there is no API // stability guarantee if you import RewriterConfig explicitly). RewriterConfig rewrite_options = 10; -}; +} message ThreadPoolOptionProto { // The number of threads in the pool. @@ -292,7 +311,7 @@ message ThreadPoolOptionProto { // value as is specified on this call. // - threadpools created this way are never garbage collected. string global_name = 2; -}; +} message RPCOptions { // If true, always use RPC to contact the session target. @@ -308,7 +327,7 @@ message RPCOptions { // If compression_algorithm is set, the compression level to be used. // From 0 (no compression), up to 3. int32 compression_level = 3; -}; +} // Session configuration parameters. // The system picks appropriate values for fields that are not set. @@ -328,6 +347,7 @@ message ConfigProto { // inter_op_parallelism_threads available in each process. // // 0 means the system picks an appropriate number. + // Negative means all operations are performed in caller's thread. 
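// Illustrative sketch, not part of this patch: enabling the kernel-tracker
// fields added to GPUOptions.Experimental above from C++, via the generated
// proto setters. The particular thresholds are arbitrary examples; per the
// comments above, timestamped_allocator only takes effect when some form of
// kernel tracking is configured.

#include "tensorflow/core/protobuf/config.pb.h"

tensorflow::ConfigProto MakeConfigWithKernelTracking() {
  tensorflow::ConfigProto config;
  auto* gpu = config.mutable_gpu_options()->mutable_experimental();
  gpu->set_timestamped_allocator(true);
  gpu->set_kernel_tracker_max_interval(8);     // tracking event every 8 kernels
  gpu->set_kernel_tracker_max_bytes(1 << 20);  // or after ~1 MiB of allocations
  gpu->set_kernel_tracker_max_pending(4);      // at most 4 outstanding events
  return config;
}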
// // Note that the first Session created in the process sets the // number of threads for all future sessions unless use_per_session_threads is @@ -406,7 +426,8 @@ message ConfigProto { ClusterDef cluster_def = 14; // If true, any resources such as Variables used in the session will not be - // shared with other sessions. + // shared with other sessions. However, when clusterspec propagation is + // enabled, this field is ignored and sessions are always isolated. bool isolate_session_state = 15; // Everything inside Experimental is subject to change and is not subject @@ -443,12 +464,46 @@ message ConfigProto { // If true, use NCCL for CollectiveOps. This feature is highly // experimental. bool collective_nccl = 7; + + // In the following, session state means the value of a variable, elements + // in a hash table, or any other resource, accessible by worker sessions + // held by a TF server. + // + // When ClusterSpec propagation is enabled, the value of + // isolate_session_state is ignored when deciding whether to share session + // states in a TF server (for backwards compatibility reasons). + // - If share_session_state_in_clusterspec_propagation is true, the session + // states are shared. + // - If share_session_state_in_clusterspec_propagation is false, session + // states are isolated. + // + // When clusterspec propagation is not used, the value of + // share_session_state_in_clusterspec_propagation is ignored when deciding + // whether to share session states in a TF server. + // - If isolate_session_state is true, session states are isolated. + // - If isolate_session_state is false, session states are shared. + // + // TODO(b/129330037): Add a single API that consistently treats + // isolate_session_state and ClusterSpec propagation. + bool share_session_state_in_clusterspec_propagation = 8; + + // If using a direct session, disable spinning while waiting for work in + // the thread pool. This may result in higher latency for completing ops, + // but in the case where there is a lot of spinning may result in lower + // CPU usage. + bool disable_thread_spinning = 9; + + // When true, WorkerSessions are created with device attributes from the + // full cluster. + // This is helpful when a worker wants to partition a graph + // (for example during a PartitionedCallOp). + bool share_cluster_devices_in_session = 10; }; Experimental experimental = 16; // Next: 17 -}; +} // Options for a single Run() call. message RunOptions { diff --git a/tensorflow/core/protobuf/conv_autotuning.proto b/tensorflow/core/protobuf/conv_autotuning.proto new file mode 100644 index 00000000000..c75f530695b --- /dev/null +++ b/tensorflow/core/protobuf/conv_autotuning.proto @@ -0,0 +1,25 @@ +// This is used for convolution logging. Also see +// tensorflow/core/protobuf/autotuing.h +syntax = "proto3"; + +package tensorflow; + +import "tensorflow/stream_executor/dnn.proto"; + +// A convolution. Currently it's only used for logging. In the future, we may +// want to use it in the API as well. +message ConvolutionProto { + stream_executor.dnn.ConvolutionKind kind = 1; + stream_executor.dnn.TensorDescriptorProto input = 2; + stream_executor.dnn.TensorDescriptorProto filter = 3; + stream_executor.dnn.TensorDescriptorProto output = 4; + stream_executor.dnn.ConvolutionDescriptorProto conv_desc = 5; + + // result = conv_scale * conv(...) + side_value_scale * side_value. + // side_value is an arbitrary buffer if activation is not none. Otherwise, it + // has to be the result buffer (using its old values). 
+ double conv_scale = 6; + double side_value_scale = 7; + + stream_executor.dnn.ActivationMode activation = 8; +} diff --git a/tensorflow/core/protobuf/data/experimental/snapshot.proto b/tensorflow/core/protobuf/data/experimental/snapshot.proto new file mode 100644 index 00000000000..dde0ade9f75 --- /dev/null +++ b/tensorflow/core/protobuf/data/experimental/snapshot.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package tensorflow.data.experimental; + +import "tensorflow/core/framework/tensor.proto"; + +// Each SnapshotRecord represents one batch of pre-processed input data. A batch +// consists of a list of tensors that we encode as TensorProtos. This message +// doesn't store the structure of the batch. +message SnapshotRecord { + repeated .tensorflow.TensorProto tensor = 1; +} + +// This stores the metadata information present in each snapshot record. +message SnapshotMetadataRecord { + string graph_fingerprint = 1; + string run_id = 2; + int64 creation_timestamp = 3; + + bool finalized = 1000; +} diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index 63ba4eb173c..f93ca99c63c 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -5,10 +5,10 @@ package tensorflow.eager; import "tensorflow/core/framework/attr_value.proto"; import "tensorflow/core/framework/device_attributes.proto"; import "tensorflow/core/framework/function.proto"; +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/versions.proto"; import "tensorflow/core/protobuf/tensorflow_server.proto"; -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/tensor.proto"; message RemoteTensorHandle { // The ID of the operation that produced this tensor. @@ -71,6 +71,9 @@ message CreateContextRequest { // both ends use this ID for selecting a rendezvous to get everything to // match. 
int64 rendezvous_id = 5; + + // Device attributes in the cluster + repeated DeviceAttributes cluster_device_attributes = 6; } message CreateContextResponse { @@ -110,15 +113,13 @@ message KeepAliveRequest { fixed64 context_id = 1; } -message KeepAliveResponse { -} +message KeepAliveResponse {} message CloseContextRequest { fixed64 context_id = 1; } -message CloseContextResponse { -} +message CloseContextResponse {} message RegisterFunctionRequest { fixed64 context_id = 1; @@ -126,8 +127,7 @@ message RegisterFunctionRequest { FunctionDef function_def = 2; } -message RegisterFunctionResponse { -} +message RegisterFunctionResponse {} message SendTensorRequest { fixed64 context_id = 1; @@ -144,8 +144,7 @@ message SendTensorRequest { string device_name = 4; } -message SendTensorResponse { -} +message SendTensorResponse {} //////////////////////////////////////////////////////////////////////////////// // diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 0978a8257bd..e1701b075ef 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -1,10 +1,12 @@ syntax = "proto3"; package tensorflow; + option cc_enable_arenas = true; option java_outer_classname = "RewriterConfigProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; + // add go_package externally with copybara import "tensorflow/core/framework/attr_value.proto"; @@ -81,6 +83,11 @@ message RewriterConfig { // Enable the swap of kernel implementations based on the device placement // (default is ON). Toggle implementation_selector = 22; + // Optimize data types (default is OFF). + // e.g., This will try to use float16 on GPU which is faster. + // Note that this can change the numerical stability of the graph and may + // require the use of loss scaling to maintain model convergence. + Toggle auto_mixed_precision = 23; // Disable the entire meta optimizer (off by default). bool disable_meta_optimizer = 19; diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto index 48060b33dc4..720f54cdd77 100644 --- a/tensorflow/core/protobuf/saved_object_graph.proto +++ b/tensorflow/core/protobuf/saved_object_graph.proto @@ -5,6 +5,7 @@ import "tensorflow/core/protobuf/struct.proto"; import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; import "tensorflow/core/framework/versions.proto"; +import "tensorflow/core/framework/variable.proto"; option cc_enable_arenas = true; @@ -132,6 +133,9 @@ message SavedVariable { DataType dtype = 1; TensorShapeProto shape = 2; bool trainable = 3; + VariableSynchronization synchronization = 4; + VariableAggregation aggregation = 5; + string name = 6; } // Represents `FunctionSpec` used in `Function`. This represents a @@ -141,18 +145,18 @@ message FunctionSpec { StructuredValue fullargspec = 1; // Whether this represents a class method. bool is_method = 2; - // Which arguments to always prepend, in case the original function is based - // on a functools.partial. - StructuredValue args_to_prepend = 3; - // Which kwargs to always include, in case the original function is based on a - // functools.partial. - StructuredValue kwargs_to_include = 4; // The input signature, if specified. StructuredValue input_signature = 5; + + reserved 3, 4; } // A SavedResource represents a TF object that holds state during its lifetime. 
+// An object of this type can have a reference to a: +// create_resource() and an initialize() function. message SavedResource { - // An object of this type can have a reference to a: - // create_resource() and an initialize() function. + // A device specification indicating a required placement for the resource + // creation function, e.g. "CPU". An empty string allows the user to select a + // device. + string device = 1; } diff --git a/tensorflow/core/protobuf/tpu/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto index 7d3c105eec3..6ea9ce15656 100644 --- a/tensorflow/core/protobuf/tpu/optimization_parameters.proto +++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto @@ -60,10 +60,23 @@ message AdagradParameters { float initial_accumulator = 1; } +// Algorithm in http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf. +message BoundedAdagradParameters { + // Whether to use the updated or the old value of the accumulator when + // computing the effective learning rate. When update_accumulator_first is set + // to True, the updated value of the accumulator is used. + bool update_accumulator_first = 1; + // The max_var_update value to use. Set value to 0 (default) to disable using + // max_var_update to clip the gradient. + float max_var_update = 2; + // The maximum value of the accumulator. Set max_accumulator to 0 (default) + // to disable using max_accumulator to clip the accumulator. + float max_accumulator = 3; +} + // https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer // https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423 -message StochasticGradientDescentParameters { -} +message StochasticGradientDescentParameters {} // https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer // https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192 @@ -181,7 +194,7 @@ message GradientAccumulationStatus { ENABLED = 1; DISABLED = 2; } -}; +} // Configuration proto for hot ID optimization. This is an experimental feature // that is currently disabled (by default). @@ -252,6 +265,7 @@ message OptimizationParameters { // algorithm to use. oneof parameters { AdagradParameters adagrad = 3; + BoundedAdagradParameters bounded_adagrad = 19; StochasticGradientDescentParameters stochastic_gradient_descent = 4; FtrlParameters ftrl = 5; AdamParameters adam = 6; diff --git a/tensorflow/core/profiler/trace_events.proto b/tensorflow/core/protobuf/trace_events.proto similarity index 91% rename from tensorflow/core/profiler/trace_events.proto rename to tensorflow/core/protobuf/trace_events.proto index 69ec88ca9a7..76b7300aeab 100644 --- a/tensorflow/core/profiler/trace_events.proto +++ b/tensorflow/core/protobuf/trace_events.proto @@ -2,6 +2,11 @@ syntax = "proto3"; package tensorflow.profiler; +option cc_enable_arenas = true; +option java_outer_classname = "TraceEventsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + // A 'Trace' contains metadata for the individual traces of a system. message Trace { // The devices that this trace has information about. 
Maps from device_id to diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 4284dd119ed..0bea9aa4ee5 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -16,16 +16,18 @@ limitations under the License. syntax = "proto3"; package tensorflow; + option cc_enable_arenas = true; option java_outer_classname = "WorkerProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.distruntime"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf"; + +// add go_package externally with copybara import "google/protobuf/any.proto"; import "tensorflow/core/framework/cost_graph.proto"; -import "tensorflow/core/framework/step_stats.proto"; import "tensorflow/core/framework/device_attributes.proto"; import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/step_stats.proto"; import "tensorflow/core/framework/tensor.proto"; import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; @@ -41,8 +43,7 @@ import "tensorflow/core/protobuf/tensorflow_server.proto"; // //////////////////////////////////////////////////////////////////////////////// -message GetStatusRequest { -} +message GetStatusRequest {} message GetStatusResponse { repeated DeviceAttributes device_attributes = 1; @@ -66,10 +67,12 @@ message CreateWorkerSessionRequest { // If true, any resources such as Variables used in the session will not be // shared with other sessions. bool isolate_session_state = 3; + + // The device attributes of all the devices in the cluster. + repeated DeviceAttributes cluster_device_attributes = 4; } -message CreateWorkerSessionResponse { -} +message CreateWorkerSessionResponse {} //////////////////////////////////////////////////////////////////////////////// // @@ -84,8 +87,7 @@ message DeleteWorkerSessionRequest { string session_handle = 1; } -message DeleteWorkerSessionResponse { -} +message DeleteWorkerSessionResponse {} //////////////////////////////////////////////////////////////////////////////// // @@ -186,8 +188,7 @@ message CleanupAllRequest { repeated string container = 1; } -message CleanupAllResponse { -} +message CleanupAllResponse {} //////////////////////////////////////////////////////////////////////////////// // @@ -207,7 +208,7 @@ message ExecutorOpts { bool record_timeline = 3; bool record_partition_graphs = 4; bool report_tensor_allocations_upon_oom = 5; -}; +} message RunGraphRequest { // session_handle is the master-generated unique id for this session. @@ -253,7 +254,17 @@ message RunGraphRequest { // truncate long metadata messages. bool store_errors_in_response_body = 9; - // Next: 11 + // Unique identifier for this request. Every RunGraphRequest must have a + // unique request_id, and retried RunGraphRequests must have the same + // request_id. If request_id is zero, retry detection is disabled. + // + // Retried RunGraphRequests are problematic because they may issue a + // RecvTensor that will have no corresponding sender and will wait forever. + // Workers use request_ids to reject retried RunGraph requests instead of + // waiting forever. 
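// Illustrative sketch, not part of this patch: a client issuing RunGraph calls
// could stamp each attempt with the request_id field added just below, reusing
// the same id on retries so the worker can reject duplicates. random::New64()
// is used here as one way to obtain a practically unique, non-zero id.

#include <cstdint>
#include <string>

#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/worker.pb.h"

tensorflow::RunGraphRequest MakeRunGraphRequest(const std::string& session_handle) {
  tensorflow::RunGraphRequest request;
  request.set_session_handle(session_handle);
  uint64_t id = 0;
  while (id == 0) id = tensorflow::random::New64();  // zero disables retry detection
  request.set_request_id(static_cast<int64_t>(id));
  return request;
}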
+ int64 request_id = 11; + + // Next: 12 } message RunGraphResponse { @@ -295,8 +306,7 @@ message CleanupGraphRequest { int64 step_id = 1; } -message CleanupGraphResponse { -} +message CleanupGraphResponse {} //////////////////////////////////////////////////////////////////////////////// // @@ -332,7 +342,8 @@ message RecvTensorRequest { // Unique identifier for this request. Every RecvTensorRequest must have a // unique request_id, and retried RecvTensorRequests must have the same - // request_id. If request_id is zero, retry detection is disabled. + // request_id. If request_id is zero, retry detection and response cache + // are disabled. // // Retried RecvTensorRequests are problematic because a RecvTensor with no // corresponding sender will wait forever, and the tensor may have been @@ -355,8 +366,20 @@ message RecvTensorResponse { // Optional additional information about how to receive the tensor, // e.g. in the event that `RecvTensorRequest.dma_ok` was true. google.protobuf.Any transport_options = 4; + + // Whether the receiver should send a MarkRecvFinishedRequest to the sender + // to ack the message. + bool require_ack = 5; } +// Message for managing the response cache maintained on the sender side. +// Currently only used by the gRPC worker service. +message MarkRecvFinishedRequest { + int64 request_id = 1; +} + +message MarkRecvFinishedResponse {} + //////////////////////////////////////////////////////////////////////////////// // // Logging method request/response messages @@ -424,8 +447,7 @@ message TracingRequest { TraceOpts options = 1; } -message TracingResponse { -} +message TracingResponse {} //////////////////////////////////////////////////////////////////////////////// // @@ -484,6 +506,10 @@ message RecvBufResponse { google.protobuf.Any transport_options = 4; // Optional, for timeline. int64 send_start_micros = 5; + + // Whether the receiver should send a MarkRecvFinishedRequest to the sender + // to ack the message. + bool require_ack = 6; } //////////////////////////////////////////////////////////////////////////////// @@ -500,6 +526,7 @@ message CompleteGroupRequest { int32 group_size = 2; string device_type = 3; repeated string device_name = 4; + int32 collective_type = 5; } // Gives the complete membership of the group identified by group_key. @@ -510,6 +537,7 @@ message CompleteGroupResponse { int32 num_tasks = 4; // number of distinct tasks hosting the devices repeated string device_name = 5; repeated string task_name = 6; // task name prefixes of device_names + bytes communicator_key = 7; } // Supplies data about one collective op belonging to the instance identified @@ -535,7 +563,7 @@ message CompleteInstanceRequest { message CompleteInstanceResponse { int32 instance_key = 1; int32 source_rank = 2; - bytes communicator_key = 3; + reserved 3; } // Request for next agreed-upon step_id for the specified graph_keys. diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 40b101fb917..b8538b89c3d 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -18,6 +18,8 @@ limitations under the License. // TensorFlow uses semantic versioning, see http://semver.org/. +// Also update tensorflow/tensorflow.bzl and +// tensorflow/tools/pip_package/setup.py #define TF_MAJOR_VERSION 1 #define TF_MINOR_VERSION 13 #define TF_PATCH_VERSION 1 @@ -34,8 +36,6 @@ limitations under the License. (TF_STR(TF_MAJOR_VERSION) "." TF_STR(TF_MINOR_VERSION) "." 
TF_STR( \ TF_PATCH_VERSION) TF_VERSION_SUFFIX) -// TODO(josh11b): Public API functions for exporting the above. - // GraphDef compatibility versions (the versions field in graph.proto). // // Each graph has producer and min_consumer versions, and each @@ -98,10 +98,17 @@ limitations under the License. // deprecated in favor of V2 ops. (2018/01/23) // 28. Deprecate MatrixExponential op in favor of Python implementation. // (2018/08/21). +// (2019/02/15). Added `control_ret` field to FunctionDef proto, and +// `control_output` field to OpDef proto. +// 29. Deprecate StatefulStandardNormal op in favor of StatefulStandardNormalV2. +// (2019/03/25). +// (2019/04/17). Added `arg_attr` field to FunctionDefProto. +// 30. (2019/05/09) First date based GraphDef version. GraphDef +// versions advance by 1 each day after this point. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 27 +#define TF_GRAPH_DEF_VERSION 37 // Updated: 2019/5/16 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/summary/summary_db_writer_test.cc b/tensorflow/core/summary/summary_db_writer_test.cc index c4e9ddea2c5..a61d4bb701f 100644 --- a/tensorflow/core/summary/summary_db_writer_test.cc +++ b/tensorflow/core/summary/summary_db_writer_test.cc @@ -39,8 +39,8 @@ class FakeClockEnv : public EnvWrapper { public: FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {} void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; } - uint64 NowMicros() override { return current_millis_ * 1000; } - uint64 NowSeconds() override { return current_millis_ * 1000; } + uint64 NowMicros() const override { return current_millis_ * 1000; } + uint64 NowSeconds() const override { return current_millis_ * 1000; } private: uint64 current_millis_; diff --git a/tensorflow/core/summary/summary_file_writer_test.cc b/tensorflow/core/summary/summary_file_writer_test.cc index d3b19c3abdb..f650cb72021 100644 --- a/tensorflow/core/summary/summary_file_writer_test.cc +++ b/tensorflow/core/summary/summary_file_writer_test.cc @@ -32,8 +32,8 @@ class FakeClockEnv : public EnvWrapper { public: FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {} void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; } - uint64 NowMicros() override { return current_millis_ * 1000; } - uint64 NowSeconds() override { return current_millis_ * 1000; } + uint64 NowMicros() const override { return current_millis_ * 1000; } + uint64 NowSeconds() const override { return current_millis_ * 1000; } private: uint64 current_millis_; diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 5cbed402f75..f81cc6b4ee0 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -8,8 +8,10 @@ cc_library( hdrs = ["tpu_embedding_optimization_parameters_utils.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc", "@com_google_absl//absl/base", ], diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc index 71766f6f037..e39b1d6e1bd 100644 --- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc +++ 
b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc @@ -14,7 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { namespace tpu { @@ -23,6 +27,8 @@ string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) { switch (alg) { case OptimizationAlgorithm::kAdagrad: return "Adagrad"; + case OptimizationAlgorithm::kBoundedAdagrad: + return "BoundedAdagrad"; case OptimizationAlgorithm::kStochasticGradientDescent: return "StochasticGradientDescent"; case OptimizationAlgorithm::kFtrl: @@ -51,6 +57,8 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) { switch (alg) { case OptimizationAlgorithm::kAdagrad: return "Adagrad"; + case OptimizationAlgorithm::kBoundedAdagrad: + return "Bounded Adagrad"; case OptimizationAlgorithm::kStochasticGradientDescent: return "stochastic gradient descent"; case OptimizationAlgorithm::kFtrl: @@ -83,6 +91,9 @@ Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) { case OptimizationAlgorithm::kAdagrad: *count = 1; return Status::OK(); + case OptimizationAlgorithm::kBoundedAdagrad: + *count = 1; + return Status::OK(); case OptimizationAlgorithm::kStochasticGradientDescent: *count = 0; return Status::OK(); @@ -166,6 +177,11 @@ Status GetOptimizationAlgorithmStateVariables( MakeStandardStateVariableSpecification("accumulators", 0.1)); break; } + case OptimizationAlgorithm::kBoundedAdagrad: { + state_variables->push_back( + MakeStandardStateVariableSpecification("accumulators", 0.1)); + break; + } case OptimizationAlgorithm::kStochasticGradientDescent: { // None. break; @@ -251,6 +267,7 @@ Status GetOptimizationAlgorithmStateVariables( std::vector GetOptimizationAlgorithms() { return { OptimizationAlgorithm::kAdagrad, + OptimizationAlgorithm::kBoundedAdagrad, OptimizationAlgorithm::kStochasticGradientDescent, OptimizationAlgorithm::kFtrl, OptimizationAlgorithm::kAdam, @@ -263,5 +280,242 @@ std::vector GetOptimizationAlgorithms() { }; } +Status RegisterPerTableLoadOpsForAlgorithmBody( + OptimizationAlgorithm alg, bool is_debug_op, + OpRegistrationData* op_reg_data) { + GradientAccumulationSupport grad_accum_support; + TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support)); + + std::vector state_variable_specs; + TF_CHECK_OK(GetOptimizationAlgorithmStateVariables( + alg, + grad_accum_support == GradientAccumulationSupport::kSupported && + is_debug_op, + &state_variable_specs)); + auto* op_def = &op_reg_data->op_def; + op_def->set_name( + strings::StrCat("LoadTPUEmbedding", GetOptimizationAlgorithmName(alg), + "Parameters", (is_debug_op ? "GradAccumDebug" : ""))); + // It is important for the order of the inputs to the op defined here + // to match the order in input_names because the indexes are used in + // the combining transformation. 
+ for (const auto& parameter : state_variable_specs) { + if (parameter.has_user_defined() || is_debug_op) { + auto* arg = op_def->add_input_arg(); + arg->set_name(parameter.name()); + arg->set_type(DT_FLOAT); + } + } + { + auto* table_id_attr = op_def->add_attr(); + table_id_attr->set_name("table_id"); + table_id_attr->set_type("int"); + table_id_attr->set_has_minimum(true); + table_id_attr->set_minimum(-1); + table_id_attr->mutable_default_value()->set_i(-1); + } + { + auto* table_name_attr = op_def->add_attr(); + table_name_attr->set_name("table_name"); + table_name_attr->set_type("string"); + table_name_attr->mutable_default_value()->set_s(""); + } + { + auto* num_shards_attr = op_def->add_attr(); + num_shards_attr->set_name("num_shards"); + num_shards_attr->set_type("int"); + } + { + auto* shard_id_attr = op_def->add_attr(); + shard_id_attr->set_name("shard_id"); + shard_id_attr->set_type("int"); + } + string parameter_descriptions; + for (const auto& parameter : state_variable_specs) { + if (parameter.has_user_defined() || is_debug_op) { + strings::Appendf(¶meter_descriptions, + R"( +%s: A tensor containing the initial embedding table %s to use in embedding +lookups using the %s optimization algorithm.)", + parameter.name().c_str(), parameter.name().c_str(), + GetOptimizationAlgorithmFriendlyName(alg).c_str()); + } + } + op_def->set_is_commutative(false); + op_def->set_is_aggregate(false); + op_def->set_is_stateful(true); + auto shape_inference_function = + [state_variable_specs, + is_debug_op](shape_inference::InferenceContext* c) -> Status { + int table_id; + TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); + string table_name; + TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); + // Exactly one must be non-default. + if ((table_id >= 0) == (!table_name.empty())) { + return errors::InvalidArgument( + "exactly one of table_id or table_name must be non-default"); + } + int num_shards; + TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); + int shard_id; + TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id)); + const int user_param_count = + std::count_if(state_variable_specs.begin(), state_variable_specs.end(), + [&](const StateVariableSpecification& sv) { + return sv.has_user_defined() || is_debug_op; + }); + std::vector inputs(user_param_count); + int input_index = 0; + for (int i = 0; i < state_variable_specs.size(); ++i) { + if (state_variable_specs[i].has_user_defined() || is_debug_op) { + std::vector input_temp; + TF_RETURN_IF_ERROR( + c->input(state_variable_specs[i].name(), &input_temp)); + if (input_temp.size() != 1) { + return errors::InvalidArgument("each input to be rank 1"); + } + inputs[input_index] = input_temp[0]; + ++input_index; + } + } + // Verify shapes have rank 2 and are compatible when they are + // required to be valid. 
+ shape_inference::ShapeHandle parameter_shape; + TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, ¶meter_shape)); + for (int j = 1; j < user_param_count; ++j) { + shape_inference::ShapeHandle accumulator_j_shape; + TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape)); + shape_inference::ShapeHandle merged; + TF_RETURN_IF_ERROR( + c->Merge(parameter_shape, accumulator_j_shape, &merged)); + } + return Status::OK(); + }; + op_reg_data->shape_inference_fn = shape_inference_function; + return Status::OK(); +} + +Status RegisterPerTableRetrieveOpsForAlgorithmBody( + OptimizationAlgorithm alg, bool is_debug_op, + OpRegistrationData* op_reg_data) { + GradientAccumulationSupport grad_accum_support; + TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support)); + + std::vector state_variable_specs; + TF_CHECK_OK(GetOptimizationAlgorithmStateVariables( + alg, + grad_accum_support == GradientAccumulationSupport::kSupported && + is_debug_op, + &state_variable_specs)); + + auto* op_def = &op_reg_data->op_def; + op_def->set_name( + strings::StrCat("RetrieveTPUEmbedding", GetOptimizationAlgorithmName(alg), + "Parameters", (is_debug_op ? "GradAccumDebug" : ""))); + // It is important for the order of the outputs of the op defined here + // to match the order in output_names because the indexes are used in + // the combining transformation. + for (const auto& parameter : state_variable_specs) { + if (parameter.has_user_defined() || is_debug_op) { + auto* arg = op_def->add_output_arg(); + arg->set_name(parameter.name()); + arg->set_type(DT_FLOAT); + } + } + { + auto* table_id_attr = op_def->add_attr(); + table_id_attr->set_name("table_id"); + table_id_attr->set_type("int"); + table_id_attr->set_has_minimum(true); + table_id_attr->set_minimum(-1); + table_id_attr->mutable_default_value()->set_i(-1); + } + { + auto* table_name_attr = op_def->add_attr(); + table_name_attr->set_name("table_name"); + table_name_attr->set_type("string"); + table_name_attr->mutable_default_value()->set_s(""); + } + { + auto* num_shards_attr = op_def->add_attr(); + num_shards_attr->set_name("num_shards"); + num_shards_attr->set_type("int"); + } + { + auto* shard_id_attr = op_def->add_attr(); + shard_id_attr->set_name("shard_id"); + shard_id_attr->set_type("int"); + } + string parameter_descriptions; + for (const auto& param : state_variable_specs) { + if (param.has_user_defined() || is_debug_op) { + strings::Appendf(¶meter_descriptions, + R"( +%s: A tensor containing the embedding table %s to store with the +parameters from embedding updates using the %s optimization algorithm.)", + param.name().c_str(), param.name().c_str(), + GetOptimizationAlgorithmFriendlyName(alg).c_str()); + } + } + op_def->set_is_commutative(false); + op_def->set_is_aggregate(false); + op_def->set_is_stateful(true); + auto shape_inference_function = + [state_variable_specs, + is_debug_op](shape_inference::InferenceContext* c) -> Status { + int table_id; + TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); + string table_name; + TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); + // Exactly one must be non-default. 
+ if ((table_id >= 0) == (!table_name.empty())) { + return errors::InvalidArgument( + "exactly one of table_id or table_name must be non-default"); + } + int num_shards; + TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); + int shard_id; + TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id)); + for (int j = 0; j < state_variable_specs.size(); ++j) { + if (state_variable_specs[j].has_user_defined() || is_debug_op) { + auto shape = c->MakeShape( + std::vector(2, c->UnknownDim())); + TF_RETURN_IF_ERROR( + c->set_output(state_variable_specs[j].name(), + std::vector(1, shape))); + } + } + return Status::OK(); + }; + op_reg_data->shape_inference_fn = shape_inference_function; + return Status::OK(); +} + +Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg, + bool* internal) { + switch (alg) { + case OptimizationAlgorithm::kAdagrad: + case OptimizationAlgorithm::kStochasticGradientDescent: + case OptimizationAlgorithm::kFtrl: + case OptimizationAlgorithm::kAdam: + case OptimizationAlgorithm::kMomentum: + case OptimizationAlgorithm::kRmsProp: + case OptimizationAlgorithm::kCenteredRmsProp: + case OptimizationAlgorithm::kMdlAdagradLight: + case OptimizationAlgorithm::kAdadelta: + case OptimizationAlgorithm::kProximalAdagrad: { + *internal = false; + return Status::OK(); + } + case OptimizationAlgorithm::kBoundedAdagrad: { + *internal = true; + return Status::OK(); + } + case OptimizationAlgorithm::PARAMETERS_NOT_SET: + return errors::InvalidArgument("No optimization algorithm specified"); + } +} + } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h index ceb07ff3551..e7516da8f39 100644 --- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h +++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_ #include + #include "absl/base/casts.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h" @@ -84,6 +86,23 @@ static constexpr int kMaxAuxiliaryParameterCount = 3; // already been applied to the parameters and accumulators. const float kGradientAccumulatorInitialValue = absl::bit_cast(1); +// Computes registration data for per table load Op. Each load Op transfers +// the embedding parameters from the host memory to the TPU memory. +Status RegisterPerTableLoadOpsForAlgorithmBody(OptimizationAlgorithm alg, + bool is_debug_op, + OpRegistrationData *op_reg_data); + +// Computes registration data for per table retrieve Op. Each retrieve Op +// transfers the embedding parameters from the TPU memory to the host memory. +Status RegisterPerTableRetrieveOpsForAlgorithmBody( + OptimizationAlgorithm alg, bool is_debug_op, + OpRegistrationData *op_reg_data); + +// Returns whether an optimization algorithm is only supported internally. +// Returns an error if the algorithm is not recongized at all. 
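// Illustrative sketch, not part of this patch: one way these helpers might be
// combined, registering the per-table load op for every algorithm that is not
// internal-only (IsOptimizationAlgorithmInternal is declared just below).
// Registering through OpRegistry::Global()->Register() is an assumption made
// for illustration, not something this change prescribes.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"

void RegisterExternalLoadOps() {
  for (tensorflow::tpu::OptimizationAlgorithm alg :
       tensorflow::tpu::GetOptimizationAlgorithms()) {
    bool internal = false;
    TF_CHECK_OK(tensorflow::tpu::IsOptimizationAlgorithmInternal(alg, &internal));
    if (internal) continue;  // e.g. BoundedAdagrad stays internal
    tensorflow::OpRegistry::Global()->Register(
        [alg](tensorflow::OpRegistrationData* op_reg_data) {
          return tensorflow::tpu::RegisterPerTableLoadOpsForAlgorithmBody(
              alg, /*is_debug_op=*/false, op_reg_data);
        });
  }
}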
+Status IsOptimizationAlgorithmInternal(OptimizationAlgorithm alg, + bool *internal); + } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/util/ctc/ctc_loss_calculator.cc b/tensorflow/core/util/ctc/ctc_loss_calculator.cc index 8641e2c3e2d..a0ac5eec4bc 100644 --- a/tensorflow/core/util/ctc/ctc_loss_calculator.cc +++ b/tensorflow/core/util/ctc/ctc_loss_calculator.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/util/ctc/ctc_loss_calculator.h" +#include namespace tensorflow { namespace ctc { @@ -34,10 +35,10 @@ void CTCLossCalculator::CalculateForwardVariables( CHECK_EQ(U, log_alpha->rows()); // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6. - log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_)); + log_alpha->coeffRef(0, 0) = std::log(y(blank_index_, output_delay_)); // Below, l_prime[1] == labels[0] auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_; - log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_)); + log_alpha->coeffRef(1, 0) = std::log(y(label_0, output_delay_)); for (int t = 1; t < T; ++t) { // If there is not enough time to output the remaining labels or @@ -69,7 +70,7 @@ void CTCLossCalculator::CalculateForwardVariables( } // Multiply the summed alphas with the activation log probability. log_alpha->coeffRef(u, t) = - log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha; + std::log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha; } // End (GravesTh) Eq 7.9. } } @@ -102,7 +103,7 @@ void CTCLossCalculator::CalculateBackwardVariables( log_beta->coeffRef(u, t) = LogSumExp(log_beta->coeff(u, t), log_beta->coeff(u, t + 1) + - log(y(l_prime[u], output_delay_ + t + 1))); + std::log(y(l_prime[u], output_delay_ + t + 1))); } // Add in the u + 1, t + 1 term. @@ -110,7 +111,7 @@ void CTCLossCalculator::CalculateBackwardVariables( log_beta->coeffRef(u, t) = LogSumExp(log_beta->coeff(u, t), log_beta->coeff(u + 1, t + 1) + - log(y(l_prime[u + 1], output_delay_ + t + 1))); + std::log(y(l_prime[u + 1], output_delay_ + t + 1))); } // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2). @@ -122,7 +123,7 @@ void CTCLossCalculator::CalculateBackwardVariables( log_beta->coeffRef(u, t) = LogSumExp(log_beta->coeff(u, t), log_beta->coeff(u + 2, t + 1) + - log(y(l_prime[u + 2], output_delay_ + t + 1))); + std::log(y(l_prime[u + 2], output_delay_ + t + 1))); } } // End (GravesTh) Eq. 
7.15 } diff --git a/tensorflow/core/util/ctc/ctc_loss_calculator.h b/tensorflow/core/util/ctc/ctc_loss_calculator.h index dd1163310bf..5f4c4cd8a08 100644 --- a/tensorflow/core/util/ctc/ctc_loss_calculator.h +++ b/tensorflow/core/util/ctc/ctc_loss_calculator.h @@ -315,7 +315,8 @@ Status CTCLossCalculator::PopulateLPrimes( "Saw a non-null label (index >= num_classes - 1) " "following a ", "null label, batch: ", b, " num_classes: ", num_classes, - " labels: ", str_util::Join(l, ",")); + " labels: ", str_util::Join(label, ","), + " labels seen so far: ", str_util::Join(l, ",")); } l.push_back(label[i]); } diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc index 56e618872a7..35d34221a6c 100644 --- a/tensorflow/core/util/device_name_utils.cc +++ b/tensorflow/core/util/device_name_utils.cc @@ -327,10 +327,11 @@ bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern, return true; } -/* static */ -Status DeviceNameUtils::MergeDevNames(ParsedName* target, - const ParsedName& other, - bool allow_soft_placement) { +namespace { +Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target, + const DeviceNameUtils::ParsedName& other, + bool allow_soft_placement, bool override_conflicts) { + const auto& ParsedNameToString = DeviceNameUtils::ParsedNameToString; if (other.has_job) { if (target->has_job && target->job != other.job) { return errors::InvalidArgument( @@ -374,6 +375,8 @@ Status DeviceNameUtils::MergeDevNames(ParsedName* target, "Cannot merge devices with incompatible types: '", ParsedNameToString(*target), "' and '", ParsedNameToString(other), "'"); + } else if (override_conflicts) { + target->type = other.type; } else { target->has_id = false; target->has_type = false; @@ -392,6 +395,8 @@ Status DeviceNameUtils::MergeDevNames(ParsedName* target, "Cannot merge devices with incompatible ids: '", ParsedNameToString(*target), "' and '", ParsedNameToString(other), "'"); + } else if (override_conflicts) { + target->id = other.id; } else { target->has_id = false; return Status::OK(); @@ -405,6 +410,23 @@ Status DeviceNameUtils::MergeDevNames(ParsedName* target, return Status::OK(); } +} // namespace + +/* static */ +Status DeviceNameUtils::MergeDevNames(ParsedName* target, + const ParsedName& other, + bool allow_soft_placement) { + return MergeDevNamesImpl(target, other, allow_soft_placement, + /*override_conflicts=*/false); +} + +/* static */ +Status DeviceNameUtils::MergeOverrideDevNames(ParsedName* target, + const ParsedName& other) { + return MergeDevNamesImpl(target, other, /*allow_soft_placement=*/true, + /*override_conflicts=*/true); +} + /* static */ bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a, const ParsedName& b) { @@ -516,4 +538,10 @@ std::vector DeviceNameUtils::GetLocalNamesForDeviceMappings( return Status::OK(); } +std::ostream& operator<<(std::ostream& os, + const DeviceNameUtils::ParsedName& x) { + os << DeviceNameUtils::ParsedNameToString(x); + return os; +} + } // namespace tensorflow diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h index b047e814bd6..651231d4db8 100644 --- a/tensorflow/core/util/device_name_utils.h +++ b/tensorflow/core/util/device_name_utils.h @@ -86,6 +86,14 @@ class DeviceNameUtils { int id = 0; }; // Parses "fullname" into "*parsed". Returns true iff succeeds. + // Legacy names like "/cpu:0" that don't contain "device", + // are parsed to mean their current counterparts "/device:CPU:0". 
More + // specifically, the lower case "cpu" and "gpu" is capitalized and "device" + // is added. "/tpu:0" is not treated the same way - it has use the current + // full syntax. + // Also, note that lower case "cpu" and "gpu" device types in current syntax + // are not capitalized. For example, "/device:CPU:0" is different from + // "/device:cpu:0" static bool ParseFullName(StringPiece fullname, ParsedName* parsed); // Canonicalizes "fullname" into "*canonical_name". Uses a fully specified @@ -135,6 +143,10 @@ class DeviceNameUtils { } static Status MergeDevNames(ParsedName* target, const ParsedName& other, bool allow_soft_placement); + // Same as MergeDevNames with allow_soft_placement=true, but instead of + // clearing conflicting fields, overrides them with `other`'s values. + static Status MergeOverrideDevNames(ParsedName* target, + const ParsedName& other); // Returns true iff devices identified by 'src' and 'dst' are in the // same address space. @@ -181,6 +193,9 @@ class DeviceNameUtils { string* host_device_name); }; +std::ostream& operator<<(std::ostream& os, + const DeviceNameUtils::ParsedName& x); + } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_ diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc index dafb3b20b9e..49f7fe4ac20 100644 --- a/tensorflow/core/util/device_name_utils_test.cc +++ b/tensorflow/core/util/device_name_utils_test.cc @@ -412,6 +412,19 @@ static void MergeDevNamesError(const string& name_a, const string& name_b, << s; } +static void MergeOverrideHelper(const string& target, const string& name, + const string& expected_merge_name) { + DeviceNameUtils::ParsedName parsed_target = Name(target); + TF_EXPECT_OK( + DeviceNameUtils::MergeOverrideDevNames(&parsed_target, Name(name))); + DeviceNameUtils::ParsedName parsed_expected = Name(expected_merge_name); + + EXPECT_EQ(parsed_target, parsed_expected) + << "parsed_target: " << DeviceNameUtils::ParsedNameToString(parsed_target) + << " expected_name: " + << DeviceNameUtils::ParsedNameToString(parsed_expected); +} + TEST(DeviceNameUtilsTest, MergeDevNames) { DeviceNameUtils::ParsedName target; @@ -425,7 +438,7 @@ TEST(DeviceNameUtilsTest, MergeDevNames) { MergeDevNamesHelper("", "/job:foo", "/job:foo"); MergeDevNamesHelper("", "/replica:2", "/replica:2"); MergeDevNamesHelper("", "/task:7", "/task:7"); - // MergeDevNamesHelper("", "/device:GPU:1", "/device:GPU:1"); + MergeDevNamesHelper("", "/device:GPU:1", "/device:GPU:1"); // Combining disjoint names. MergeDevNamesHelper("/job:foo", "/task:7", "/job:foo/task:7"); @@ -455,6 +468,46 @@ TEST(DeviceNameUtilsTest, MergeDevNamesAllowSoftPlacement) { MergeDevNamesHelperAllowSoftPlacement("/device:GPU:1", "/device:GPU:2", "/device:GPU:*"); } + +TEST(DeviceNameUtilsTest, MergeOverrideDevNames) { + // Idempotence tests. + MergeOverrideHelper("", "", ""); + MergeOverrideHelper("/job:foo/replica:1/task:2/cpu:1", + "/job:foo/replica:1/task:2/cpu:1", + "/job:foo/replica:1/task:2/cpu:1"); + + // Merging with empty device has no effect. + MergeOverrideHelper("", "/job:foo", "/job:foo"); + MergeOverrideHelper("", "/replica:2", "/replica:2"); + MergeOverrideHelper("", "/task:7", "/task:7"); + MergeOverrideHelper("", "/device:GPU:1", "/device:GPU:1"); + + // Combining disjoint names. + MergeOverrideHelper("/job:foo", "/task:7", "/job:foo/task:7"); + MergeOverrideHelper("/job:foo", "/device:GPU:1", "/job:foo/device:GPU:1"); + + // Combining overlapping names. 
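The new `ParseFullName` comment above spells out how legacy device strings are upgraded. A short sketch paraphrasing those rules (illustrative only, not an exhaustive specification):

```c++
#include "tensorflow/core/util/device_name_utils.h"

// Illustrative only: paraphrases the legacy-name rules documented above.
void LegacyNameExamples() {
  tensorflow::DeviceNameUtils::ParsedName p;
  // Legacy short forms: lower-case "cpu"/"gpu" are upgraded to the current
  // syntax, i.e. parsed as if written "/device:CPU:0" and "/device:GPU:1".
  tensorflow::DeviceNameUtils::ParseFullName("/cpu:0", &p);
  tensorflow::DeviceNameUtils::ParseFullName("/gpu:1", &p);
  // "/tpu:0" gets no such upgrade; it has to use the full "/device:TPU:0".
  // In the full syntax the type is taken literally, so "/device:cpu:0" names
  // a different device type than "/device:CPU:0".
  tensorflow::DeviceNameUtils::ParseFullName("/device:cpu:0", &p);
}
```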
+ MergeOverrideHelper("/job:foo/replica:0", "/replica:0/task:1", + "/job:foo/replica:0/task:1"); + + // Wildcard tests. + MergeOverrideHelper("", "/gpu:*", "/gpu:*"); + MergeOverrideHelper("/gpu:*", "/gpu:*", "/gpu:*"); + MergeOverrideHelper("/device:GPU:1", "/gpu:*", "/device:GPU:1"); + + // Testing actual override functionality + MergeOverrideHelper("/gpu:0", "/cpu:1", "/cpu:1"); + MergeOverrideHelper("/gpu:*", "/cpu:1", "/cpu:1"); + MergeOverrideHelper("/cpu:*", "/device:GPU:1", "/gpu:1"); + MergeOverrideHelper("/device:GPU:1", "/device:GPU:2", "/device:GPU:2"); + + // Override with regular merging + MergeOverrideHelper("/job:foo/CPU:*", "/device:GPU:1", "/job:foo/GPU:1"); + MergeOverrideHelper("/cpu:*", "/job:foo/device:GPU:1", "/job:foo/GPU:1"); + MergeOverrideHelper("/task:0/cpu:*", "/device:GPU:1", "/task:0/GPU:1"); + MergeOverrideHelper("/cpu:*", "/task:0/device:GPU:1", "/task:0/GPU:1"); +} + TEST(DeviceNameUtilsTest, GetNamesForDeviceMappings) { DeviceNameUtils::ParsedName p = Name("/job:foo/replica:10/task:0/device:GPU:1"); diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto index 2d3ae627773..ee1040d7574 100644 --- a/tensorflow/core/util/event.proto +++ b/tensorflow/core/util/event.proto @@ -1,13 +1,14 @@ syntax = "proto3"; package tensorflow; + +import "tensorflow/core/framework/summary.proto"; + option cc_enable_arenas = true; option java_outer_classname = "EventProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.util"; -import "tensorflow/core/framework/summary.proto"; - // Protocol buffer representing an event that happened during // the execution of a Brain model. message Event { @@ -89,6 +90,7 @@ enum WorkerHealth { OK = 0; // By default a worker is healthy. RECEIVED_SHUTDOWN_SIGNAL = 1; INTERNAL_ERROR = 2; + SHUTTING_DOWN = 3; // Worker has been instructed to shutdown after a timeout. } // Indicates the behavior of the worker when an internal error or shutdown @@ -97,6 +99,7 @@ enum WorkerShutdownMode { DEFAULT = 0; NOT_CONFIGURED = 1; WAIT_FOR_COORDINATOR = 2; + SHUTDOWN_AFTER_TIMEOUT = 3; } message WatchdogConfig { diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index 4c29bd582e6..37c7166c5cf 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -20,9 +20,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb_text.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -157,22 +159,23 @@ class Feature { return false; } + constexpr int32 kNumFloatBytes = 4; if (peek_tag == kDelimitedTag(1)) { // packed if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag uint32 packed_length; if (!stream.ReadVarint32(&packed_length)) return false; auto packed_limit = stream.PushLimit(packed_length); + // Store the initial size to know the offset we have to start writing + // data from before resizing the output "vector". 
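The parsing change above sizes the output once from the packed payload length (four bytes per float) and then fills it in place, either with a single `memcpy` on little-endian hosts or element by element. A standalone sketch of that arithmetic, with an illustrative payload:

```c++
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Standalone sketch: a packed proto float list announces its byte length up
// front, so the output can be resized once and filled with one memcpy on a
// little-endian host instead of growing via repeated push_back calls.
int main() {
  constexpr int32_t kNumFloatBytes = 4;
  const std::vector<uint8_t> packed_payload(32, 0);  // 32 bytes on the wire
  const uint32_t packed_length = packed_payload.size();

  std::vector<float> float_list;
  const size_t initial_size = float_list.size();
  float_list.resize(initial_size + packed_length / kNumFloatBytes);  // 8 floats

  // Little-endian fast path: raw copy straight into the resized storage.
  std::memcpy(float_list.data() + initial_size, packed_payload.data(),
              packed_length);
  return float_list.size() == 8 ? 0 : 1;
}
```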
+ const size_t initial_size = float_list->size(); + float_list->resize(initial_size + packed_length / kNumFloatBytes); + // If the result data type is float and we are on a little endian // machine then we can simply memcpy the data from the proto into the // result vector. - constexpr int32 kNumFloatBytes = 4; if (port::kLittleEndian && sizeof(typename Result::value_type) == kNumFloatBytes) { - // Store the initial size to know the offset we have to start writing - // data from before resizing the output "vector". - const size_t initial_size = float_list->size(); - float_list->resize(initial_size + packed_length / kNumFloatBytes); // Calculate the length of the buffer available what can be less than // what we requested in resize in case of a LimitedArraySlice. const uint32 bytes_to_copy = @@ -182,20 +185,32 @@ class Feature { if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy)) return false; } else { + int64 index = initial_size; while (!stream.ExpectAtEnd()) { uint32 buffer32; if (!stream.ReadLittleEndian32(&buffer32)) return false; - float_list->push_back(absl::bit_cast(buffer32)); + if (index < float_list->size()) { + float_list->data()[index] = absl::bit_cast(buffer32); + ++index; + } } } stream.PopLimit(packed_limit); } else { // non-packed + const size_t initial_size = float_list->size(); + // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for + // the value. + const int64 num_elements = + stream.BytesUntilLimit() / (1 + kNumFloatBytes); + float_list->resize(initial_size + num_elements); + int64 index = initial_size; while (!stream.ExpectAtEnd()) { if (!stream.ExpectTag(kFixed32Tag(1))) return false; uint32 buffer32; if (!stream.ReadLittleEndian32(&buffer32)) return false; - float_list->push_back(absl::bit_cast(buffer32)); + float_list->data()[index] = absl::bit_cast(buffer32); + ++index; } } } @@ -949,6 +964,41 @@ void FillAndCopyVarLen( } } +// Thin vector like interface wrapper around a Tensor. This enable us to +// directly populate a tensor during parsing instead of having to first create a +// vactor and then copy the data over. +template +class TensorVector { + public: + using value_type = T; + + const Tensor& tensor() { + if (!tensor_.has_value()) { + resize(0); + } + return *tensor_; + } + + int64 size() const { + return tensor_.has_value() ? tensor_->NumElements() : 0; + } + void resize(int64 new_size) { + DCHECK(!tensor_.has_value()); + tensor_ = Tensor(DataTypeToEnum::v(), TensorShape({new_size})); + data_ = tensor_->flat().data(); + } + T* data() { return data_; } + const T* data() const { return data_; } + + private: + // Use absl::optional to avoid calling the default constructor of Tensor + // unnecessarily. + absl::optional tensor_; + + // Cached pointer to the raw data inside the tensor. + T* data_ = nullptr; +}; + } // namespace Status FastParseExample(const Config& config, @@ -1401,7 +1451,10 @@ Status FastParseSingleExample(const Config& config, const string& serialized, } } else { // if variable length - SparseBuffer out_temp; + SmallVector bytes_list; + TensorVector float_list; + SmallVector int64_list; + const size_t num_elements_divisor = is_dense ? config.dense[d].elements_per_stride : 1; size_t num_elements; @@ -1443,17 +1496,13 @@ Status FastParseSingleExample(const Config& config, const string& serialized, case DT_INT64: { // TODO(mrry): Use the fact that the `int64_list` is packed to read // out the length and pre-allocate the output tensor. 
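The `TensorVector` class introduced above exposes only the `size`/`resize`/`data` surface of a vector over a lazily created `Tensor`, so parsed values land directly in the output buffer instead of passing through an intermediate vector. A standalone analogue of the pattern, with `std::vector` standing in for `Tensor` (names are illustrative):

```c++
#include <cstdint>
#include <optional>
#include <vector>

// Standalone analogue of the TensorVector pattern: a minimal vector-like view
// over a lazily created backing store plus a cached raw data pointer.
template <typename T>
class VectorView {
 public:
  using value_type = T;

  int64_t size() const { return backing_ ? backing_->size() : 0; }

  void resize(int64_t new_size) {
    // Allocation happens on the first resize (TensorVector additionally
    // DCHECKs that it is called at most once).
    backing_.emplace(new_size);
    data_ = backing_->data();
  }

  T* data() { return data_; }
  const T* data() const { return data_; }

 private:
  std::optional<std::vector<T>> backing_;  // created lazily, like the Tensor
  T* data_ = nullptr;                      // cached pointer into the storage
};

int main() {
  VectorView<float> v;
  v.resize(4);
  v.data()[0] = 1.5f;
  return v.size() == 4 ? 0 : 1;
}
```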
- if (!feature.ParseInt64List(&out_temp.int64_list)) - return parse_error(); - num_elements = out_temp.int64_list.size(); + if (!feature.ParseInt64List(&int64_list)) return parse_error(); + num_elements = int64_list.size(); break; } case DT_FLOAT: { - // TODO(mrry): Use the fact that the `float_list` is packed to read - // out the length and pre-allocate the output tensor. - if (!feature.ParseFloatList(&out_temp.float_list)) - return parse_error(); - num_elements = out_temp.float_list.size(); + if (!feature.ParseFloatList(&float_list)) return parse_error(); + num_elements = float_list.size(); break; } case DT_STRING: { @@ -1461,10 +1510,9 @@ Status FastParseSingleExample(const Config& config, const string& serialized, if (!feature.GetNumElementsInBytesList(&actual_num_elements)) { return parse_error(); } - out_temp.bytes_list.reserve(actual_num_elements); - if (!feature.ParseBytesList(&out_temp.bytes_list)) - return parse_error(); - num_elements = out_temp.bytes_list.size(); + bytes_list.reserve(actual_num_elements); + if (!feature.ParseBytesList(&bytes_list)) return parse_error(); + num_elements = bytes_list.size(); break; } default: @@ -1480,20 +1528,19 @@ Status FastParseSingleExample(const Config& config, const string& serialized, } Tensor* out; + DataType out_dtype; + TensorShape out_shape; if (is_dense) { - TensorShape values_shape; - values_shape.AddDim(num_elements / num_elements_divisor); + out_shape.AddDim(num_elements / num_elements_divisor); for (int i = 1; i < config.dense[d].shape.dims(); ++i) { - values_shape.AddDim(config.dense[d].shape.dim_size(i)); + out_shape.AddDim(config.dense[d].shape.dim_size(i)); } out = &result->dense_values[d]; - *out = Tensor(config.dense[d].dtype, values_shape); - + out_dtype = config.dense[d].dtype; } else { Tensor* out_indices = &result->sparse_indices[d]; Tensor* out_dense_shape = &result->sparse_shapes[d]; - out = &result->sparse_values[d]; // TODO(mrry): Investigate the possibility of not materializing // the indices (and perhaps dense_shape) until they are needed. @@ -1508,24 +1555,27 @@ Status FastParseSingleExample(const Config& config, const string& serialized, auto shapes_shape_t = out_dense_shape->vec(); shapes_shape_t(0) = num_elements; - *out = Tensor(config.sparse[d].dtype, - TensorShape({static_cast(num_elements)})); + out = &result->sparse_values[d]; + out_dtype = config.sparse[d].dtype; + out_shape.AddDim(num_elements); } switch (example_dtype) { case DT_INT64: { - CopyOrMoveBlock(out_temp.int64_list.begin(), - out_temp.int64_list.end(), out->flat().data()); + *out = Tensor(out_dtype, out_shape); + CopyOrMoveBlock(int64_list.begin(), int64_list.end(), + out->flat().data()); break; } case DT_FLOAT: { - CopyOrMoveBlock(out_temp.float_list.begin(), - out_temp.float_list.end(), out->flat().data()); + if (!out->CopyFrom(float_list.tensor(), out_shape)) { + return parse_error(); + } break; } case DT_STRING: { - CopyOrMoveBlock(out_temp.bytes_list.begin(), - out_temp.bytes_list.end(), + *out = Tensor(out_dtype, out_shape); + CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(), out->flat().data()); break; } @@ -1998,6 +2048,10 @@ Status FastParseSequenceExample( feature_list_result->dense_values.resize(feature_list_config.dense.size()); dense_feature_lengths->resize(feature_list_config.dense.size()); + // NOTE(mrry): Cache the CPU allocator here and use it in Tensor construction, + // to avoid lock contention in `tensorflow::cpu_allocator()`. 
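As the note above says, building a `Tensor` without an explicit allocator resolves `tensorflow::cpu_allocator()` internally, so the hot path now looks the allocator up once and passes it to every tensor it constructs. A hedged sketch of the pattern (the `BuildOutputs` helper is illustrative); the real code below does the same inside its loops:

```c++
#include <vector>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

// Illustrative only: look the CPU allocator up once, then reuse it for every
// Tensor built in the loop instead of letting each constructor resolve it.
void BuildOutputs(std::vector<tensorflow::Tensor>* outputs, int n) {
  tensorflow::Allocator* allocator = tensorflow::cpu_allocator();
  for (int i = 0; i < n; ++i) {
    outputs->push_back(tensorflow::Tensor(allocator, tensorflow::DT_INT64,
                                          tensorflow::TensorShape({i + 1})));
  }
}
```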
+ Allocator* allocator = tensorflow::cpu_allocator(); + int t = 0; for (const auto& c : context_config.dense) { TensorShape dense_shape, example_shape; @@ -2014,7 +2068,7 @@ Status FastParseSequenceExample( for (const int dim : c.shape.dim_sizes()) { dense_shape.AddDim(dim); } - context_result->dense_values[t] = Tensor(dtype, dense_shape); + context_result->dense_values[t] = Tensor(allocator, dtype, dense_shape); // TODO(sundberg): Refactor to reduce code duplication, and add bounds // checking for the outputs. @@ -2122,9 +2176,11 @@ Status FastParseSequenceExample( indices_shape.AddDim(expected_num_elements); indices_shape.AddDim(2); values_shape.AddDim(expected_num_elements); - context_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape); - context_result->sparse_values[t] = Tensor(dtype, values_shape); - context_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({2})); + context_result->sparse_indices[t] = + Tensor(allocator, DT_INT64, indices_shape); + context_result->sparse_values[t] = Tensor(allocator, dtype, values_shape); + context_result->sparse_shapes[t] = + Tensor(allocator, DT_INT64, TensorShape({2})); // TODO(sundberg): Refactor to reduce code duplication, and add bounds // checking for the outputs. string* out_bytes = nullptr; @@ -2212,8 +2268,10 @@ Status FastParseSequenceExample( for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) { dense_shape.AddDim(dim); } - feature_list_result->dense_values[t] = Tensor(dtype, dense_shape); - (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape); + feature_list_result->dense_values[t] = + Tensor(allocator, dtype, dense_shape); + (*dense_feature_lengths)[t] = + Tensor(allocator, DT_INT64, dense_length_shape); int64* out_lengths = (*dense_feature_lengths)[t].flat().data(); string* out_bytes = nullptr; @@ -2320,9 +2378,12 @@ Status FastParseSequenceExample( indices_shape.AddDim(expected_num_elements); indices_shape.AddDim(3); values_shape.AddDim(expected_num_elements); - feature_list_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape); - feature_list_result->sparse_values[t] = Tensor(dtype, values_shape); - feature_list_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({3})); + feature_list_result->sparse_indices[t] = + Tensor(allocator, DT_INT64, indices_shape); + feature_list_result->sparse_values[t] = + Tensor(allocator, dtype, values_shape); + feature_list_result->sparse_shapes[t] = + Tensor(allocator, DT_INT64, TensorShape({3})); string* out_bytes = nullptr; float* out_float = nullptr; diff --git a/tensorflow/core/util/gpu_cuda_alias.h b/tensorflow/core/util/gpu_cuda_alias.h new file mode 100644 index 00000000000..0a15d15e04a --- /dev/null +++ b/tensorflow/core/util/gpu_cuda_alias.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_GPU_CUDA_ALIAS_H_ +#define TENSORFLOW_CORE_UTIL_GPU_CUDA_ALIAS_H_ + +// Several forwarding macros are defined in this file to serve for backward +// compatibility usage as we migrating from CUDA prefixed function to GPU +// prefixed functions. Both Cuda and ROCm can unify under the new GPU prefix +// naming scheme. In the migration period, we provide equivalent CUDA* and GPU* +// function. Over time, all CUDA* functions will be deprecated. + +namespace tensorflow { + +// CREATE_CUDA_HOST_FUNCTION_ALIAS forward the host function to its CUDA Alias. +#ifndef TENSORFLOW_USE_ROCM +#define CREATE_CUDA_HOST_FUNCTION_ALIAS(func, cuda_alias) \ + template \ + auto cuda_alias(Args&&... args) \ + ->decltype(func(std::forward(args)...)) { \ + return func(std::forward(args)...); \ + } +#else +#define CREATE_CUDA_HOST_FUNCTION_ALIAS(func, cuda_alias) +#endif + +// CREATE_CUDA_DEVICE_FUNCTION_ALIAS forward the device function to its CUDA +// Alias. +#ifndef TENSORFLOW_USE_ROCM +#define CREATE_CUDA_DEVICE_FUNCTION_ALIAS(func, cuda_alias) \ + template \ + __device__ auto cuda_alias(Args&&... args) \ + ->decltype(func(std::forward(args)...)) { \ + return func(std::forward(args)...); \ + } +#else +#define CREATE_CUDA_DEVICE_FUNCTION_ALIAS(func, cuda_alias) +#endif + +// CREATE_CUDA_TYPE_ALIAS forward the type to its CUDA Alias. +#ifndef TENSORFLOW_USE_ROCM +#define CREATE_CUDA_TYPE_ALIAS(type, cuda_alias) using cuda_alias = type; +#else +#define CREATE_CUDA_TYPE_ALIAS(type, cuda_alias) +#endif +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_GPU_CUDA_ALIAS_H_ diff --git a/tensorflow/core/util/cuda_device_functions.h b/tensorflow/core/util/gpu_device_functions.h similarity index 71% rename from tensorflow/core/util/cuda_device_functions.h rename to tensorflow/core/util/gpu_device_functions.h index b91f8bb8ef0..7230150c899 100644 --- a/tensorflow/core/util/cuda_device_functions.h +++ b/tensorflow/core/util/gpu_device_functions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_ -#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_ +#ifndef TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_ +#define TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_ /** * Wrappers and helpers for CUDA device code. @@ -24,12 +24,16 @@ limitations under the License. * Provides atomic operations on types that aren't natively supported. 
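The `CREATE_CUDA_*_ALIAS` macros in the new `gpu_cuda_alias.h` generate thin perfect-forwarding wrappers, so existing `Cuda*` call sites keep compiling while the implementations migrate to `Gpu*` names. A standalone sketch of the same technique, with illustrative function names:

```c++
#include <utility>

// Standalone sketch of a forwarding alias: the alias accepts any arguments,
// perfectly forwards them, and deduces its return type from the target, so
// the old spelling stays source-compatible with zero overhead.
#define CREATE_HOST_FUNCTION_ALIAS(func, alias)        \
  template <typename... Args>                          \
  auto alias(Args&&... args)                           \
      ->decltype(func(std::forward<Args>(args)...)) {  \
    return func(std::forward<Args>(args)...);          \
  }

int GpuAddOne(int x) { return x + 1; }
CREATE_HOST_FUNCTION_ALIAS(GpuAddOne, CudaAddOne)  // legacy spelling

int main() { return CudaAddOne(41) == 42 ? 0 : 1; }
```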
*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include #include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "cuda/include/cuda.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cuda.h" +#endif #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -113,7 +117,15 @@ const unsigned kCudaWarpAll = 0xffffffff; // Returns the warp lane ID of the calling thread __device__ inline unsigned CudaLaneId() { unsigned int lane_id; +#if GOOGLE_CUDA +#if __clang__ + return __nvvm_read_ptx_sreg_laneid(); +#else // __clang__ asm("mov.u32 %0, %%laneid;" : "=r"(lane_id)); +#endif // __clang__ +#elif TENSORFLOW_USE_ROCM + land_id = __lane_id(); +#endif return lane_id; } @@ -135,7 +147,12 @@ __device__ inline bool CudaValidateShuffleSyncMask(unsigned mask, #if CUDA_VERSION >= 9000 unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane); #else +#if GOOGLE_CUDA unsigned src_lane_mask = __shfl(mask, src_lane); +#elif TENSORFLOW_USE_ROCM + unsigned src_lane_mask = + __shfl(static_cast(mask), static_cast(src_lane)); +#endif #endif return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask; } @@ -248,12 +265,22 @@ __device__ T CudaShuffleSync(unsigned mask, T value, int src_lane, // See b/69446944. __device__ inline double CudaShuffleSync(unsigned mask, double value, int src_lane, int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); +#if GOOGLE_CUDA + auto tmp = __double_as_longlong(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); hi = CudaShuffleSync(mask, hi, src_lane, width); lo = CudaShuffleSync(mask, lo, src_lane, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; + return __longlong_as_double(static_cast(hi) << 32 | lo); +#elif TENSORFLOW_USE_ROCM + auto tmp = static_cast(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = __shfl(static_cast(hi), src_lane, width); + lo = __shfl(static_cast(lo), src_lane, width); + return static_cast(static_cast(hi) << 32 | + static_cast(lo)); +#endif } // Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in @@ -277,12 +304,22 @@ __device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta, __device__ inline double CudaShuffleUpSync(unsigned mask, double value, unsigned delta, int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); +#if GOOGLE_CUDA + auto tmp = __double_as_longlong(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); hi = CudaShuffleUpSync(mask, hi, delta, width); lo = CudaShuffleUpSync(mask, lo, delta, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; + return __longlong_as_double(static_cast(hi) << 32 | lo); +#elif TENSORFLOW_USE_ROCM + auto tmp = static_cast(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = __shfl_up(static_cast(hi), delta, width); + lo = __shfl_up(static_cast(lo), delta, width); + return static_cast(static_cast(hi) << 32 | + static_cast(lo)); +#endif } // Wrapper for __shfl_down_sync. 
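The double overloads of the shuffle wrappers above now split the value into two 32-bit words with `__double_as_longlong`, shuffle each half, and reassemble it with `__longlong_as_double`, instead of relying on inline PTX. A host-side sketch of the same bit-level round trip, with `std::memcpy` standing in for the CUDA intrinsics:

```c++
#include <cstring>

// Host-side sketch of the hi/lo split used by the double shuffle wrappers:
// reinterpret the double as a 64-bit integer, peel off two 32-bit words
// (these are what actually travel through the 32-bit shuffles), reassemble.
int main() {
  const double value = 3.141592653589793;

  unsigned long long tmp;
  std::memcpy(&tmp, &value, sizeof(tmp));  // stands in for __double_as_longlong
  auto lo = static_cast<unsigned>(tmp);
  auto hi = static_cast<unsigned>(tmp >> 32);

  // ... each 32-bit half would be shuffled across lanes here ...

  const unsigned long long rebuilt =
      static_cast<unsigned long long>(hi) << 32 | lo;
  double round_tripped;
  std::memcpy(&round_tripped, &rebuilt, sizeof(round_tripped));
  return round_tripped == value ? 0 : 1;  // lossless round trip
}
```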
All threads in 'mask' must call this function @@ -306,12 +343,22 @@ __device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta, __device__ inline double CudaShuffleDownSync(unsigned mask, double value, unsigned delta, int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); +#if GOOGLE_CUDA + auto tmp = __double_as_longlong(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); hi = CudaShuffleDownSync(mask, hi, delta, width); lo = CudaShuffleDownSync(mask, lo, delta, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; + return __longlong_as_double(static_cast(hi) << 32 | lo); +#elif TENSORFLOW_USE_ROCM + auto tmp = static_cast(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = __shfl_down(static_cast(hi), delta, width); + lo = __shfl_down(static_cast(lo), delta, width); + return static_cast(static_cast(hi) << 32 | + static_cast(lo)); +#endif } // Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in @@ -322,25 +369,54 @@ __device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask, assert(!(width & width - 1)); assert(detail::CudaValidateShuffleSyncMask( mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width))); +#if GOOGLE_CUDA #if CUDA_VERSION >= 9000 return __shfl_xor_sync(mask, value, lane_mask, width); #else return __shfl_xor(value, lane_mask, width); #endif +#elif TENSORFLOW_USE_ROCM + return __shfl_xor(static_cast(value), lane_mask, width); +#endif } +#if TENSORFLOW_USE_ROCM +__device__ inline Eigen::half GpuShuffleXorSync(unsigned mask, + Eigen::half value, + int lane_mask, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width))); + // TODO(rocm): This doesn't preserve NaN payload and flushes denorms to zero, + // maybe this should be implemented differently? + return static_cast( + __shfl_xor(static_cast(value), lane_mask, width)); +} +#endif + // Variant of the (undocumented) version from the CUDA SDK, but using unsigned // instead of float for lo and hi (which is incorrect with ftz, for example). // See b/69446944. __device__ inline double CudaShuffleXorSync(unsigned mask, double value, int lane_mask, int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); +#if GOOGLE_CUDA + auto tmp = __double_as_longlong(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); hi = CudaShuffleXorSync(mask, hi, lane_mask, width); lo = CudaShuffleXorSync(mask, lo, lane_mask, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; + return __longlong_as_double(static_cast(hi) << 32 | lo); +#elif TENSORFLOW_USE_ROCM + auto tmp = static_cast(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = __shfl_xor(static_cast(hi), lane_mask, width); + lo = __shfl_xor(static_cast(lo), lane_mask, width); + return static_cast(static_cast(hi) << 32 | + static_cast(lo)); +#endif } // Wrapper for __ldg. @@ -383,7 +459,8 @@ __host__ __device__ inline std::complex CudaLdg( template __global__ void SetZero(const int count, T* ptr) { // Check that the grid is one dimensional and index doesn't overflow. 
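These shuffle wrappers are the usual building blocks for intra-warp reductions. A hedged, CUDA-oriented device-code sketch (not part of the patch; it assumes a full warp is active so the all-lanes mask is valid, and `WarpReduceSum` is an illustrative name):

```c++
// Illustrative only: after the loop, lane 0 holds the sum of the warp's values.
__device__ float WarpReduceSum(float value) {
  for (int delta = warpSize / 2; delta > 0; delta /= 2) {
    value += tensorflow::CudaShuffleDownSync(tensorflow::kCudaWarpAll, value,
                                             delta);
  }
  return value;
}
```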
- assert(blockDim.y == 1 && blockDim.z == 1); + assert(blockDim.y == 1); + assert(blockDim.z == 1); assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x); for (int i : CudaGridRangeX(count)) { ptr[i] = T(0); @@ -394,7 +471,8 @@ __global__ void SetZero(const int count, T* ptr) { template __global__ void SetToValue(const int count, T* ptr, T value) { // Check that the grid is one dimensional and index doesn't overflow. - assert(blockDim.y == 1 && blockDim.z == 1); + assert(blockDim.y == 1); + assert(blockDim.z == 1); assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x); for (int i : CudaGridRangeX(count)) { ptr[i] = value; @@ -425,11 +503,24 @@ __device__ float CudaAtomicCasHelper(float* ptr, F accumulate) { } template __device__ double CudaAtomicCasHelper(double* ptr, F accumulate) { +#if TENSORFLOW_USE_ROCM + // FIXME: remove the workaround below once bug is fixed. + // HIP has a bug in the implementation of __longlong_as_double + // So workaround it by using reinterpret_cast. + uint64_t result = + CudaAtomicCasHelper(reinterpret_cast(ptr), + [accumulate](tensorflow::uint64 a) { + return __double_as_longlong( + accumulate(*(reinterpret_cast(&a)))); + }); + return *(reinterpret_cast(&result)); +#else return __longlong_as_double(CudaAtomicCasHelper( reinterpret_cast(ptr), [accumulate](tensorflow::uint64 a) { return __double_as_longlong(accumulate(__longlong_as_double(a))); })); +#endif } // Overload of above function for half. Note that we don't have @@ -493,30 +584,20 @@ __device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr, ptr, [value](Eigen::half a) { return a + value; }); } - #if __CUDA_ARCH__ < 600 __device__ inline double CudaAtomicAdd(double* ptr, double value) { return detail::CudaAtomicCasHelper(ptr, [value](double a) { return a + value; }); } -#elif __clang__ -// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX. -// see https://reviews.llvm.org/D39638 -__device__ inline double CudaAtomicAdd(double* ptr, double value) { - double result; - asm volatile("atom.add.f64 %0, [%1], %2;" - : "=d"(result) - : "l"(ptr), "d"(value) - : "memory"); - return result; -} #endif + // CudaAtomicAdd // Specializations of CudaAtomicAdd for complex types, which CudaAtomicAdd does // not support. We treat a std::complex* as a T* (the C++ standard section // 26.4.4 allows this explicitly) and atomic add the real and imaginary // components individually. The operation as a whole is not atomic, but we can // safely treat the components independently for the purpose of accumulating. +#if GOOGLE_CUDA __device__ inline std::complex CudaAtomicAdd(std::complex* ptr, std::complex value) { auto ptr_scalar = reinterpret_cast(ptr); @@ -530,6 +611,7 @@ __device__ inline std::complex CudaAtomicAdd( return std::complex(CudaAtomicAdd(ptr_scalar, value.real()), CudaAtomicAdd(ptr_scalar + 1, value.imag())); } +#endif // CudaAtomicSub template @@ -563,6 +645,33 @@ __device__ detail::ToTypeIfConvertible CudaAtomicMax(T* ptr, U value) { return atomicMax(ptr, value); } +#if TENSORFLOW_USE_ROCM + +/* + * CUDA runtime headers have the following defined + * __device__ int max(int, int) + * __device__ float max(float, float) + * __device__ double max(double, double) + * + * and many others, where as HIP runtime headers only have the "int" version + * + * Therefore need to special case ROCm version to call the correct underlying + * routines for float and double types. 
+ * + */ + +__device__ inline float CudaAtomicMax(float* ptr, float value) { + return detail::CudaAtomicCasHelper( + ptr, [value](float a) { return fmaxf(a, value); }); +} + +__device__ inline double CudaAtomicMax(double* ptr, double value) { + return detail::CudaAtomicCasHelper( + ptr, [value](double a) { return fmax(a, value); }); +} + +#else + __device__ inline float CudaAtomicMax(float* ptr, float value) { return detail::CudaAtomicCasHelper( ptr, [value](float a) { return max(a, value); }); @@ -573,6 +682,8 @@ __device__ inline double CudaAtomicMax(double* ptr, double value) { ptr, [value](double a) { return max(a, value); }); } +#endif + __device__ inline Eigen::half CudaAtomicMax(Eigen::half* ptr, Eigen::half value) { return detail::CudaAtomicCasHelper( @@ -593,6 +704,33 @@ __device__ detail::ToTypeIfConvertible CudaAtomicMin(T* ptr, U value) { return atomicMin(ptr, value); } +#if TENSORFLOW_USE_ROCM + +/* + * CUDA runtime headers have the following defined + * __device__ int min(int, int) + * __device__ float min(float, float) + * __device__ double min(double, double) + * + * and many others, where as HIP runtime headers only have the "int" version + * + * Therefore need to special case ROCm version to call the correct underlying + * routines for float and double types. + * + */ + +__device__ inline float CudaAtomicMin(float* ptr, float value) { + return detail::CudaAtomicCasHelper( + ptr, [value](float a) { return fminf(a, value); }); +} + +__device__ inline double CudaAtomicMin(double* ptr, double value) { + return detail::CudaAtomicCasHelper( + ptr, [value](double a) { return fmin(a, value); }); +} + +#else + __device__ inline float CudaAtomicMin(float* ptr, float value) { return detail::CudaAtomicCasHelper( ptr, [value](float a) { return min(a, value); }); @@ -603,6 +741,8 @@ __device__ inline double CudaAtomicMin(double* ptr, double value) { ptr, [value](double a) { return min(a, value); }); } +#endif + __device__ inline Eigen::half CudaAtomicMin(Eigen::half* ptr, Eigen::half value) { return detail::CudaAtomicCasHelper( @@ -629,7 +769,66 @@ __device__ detail::ToTypeIfConvertible CudaAtomicDiv(T* ptr, U value) { return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; }); } +// Operator overloads for complex numbers. 
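Both the CUDA and ROCm flavours of the float/double `CudaAtomicMax`/`CudaAtomicMin` overloads above go through `CudaAtomicCasHelper`, which emulates an arbitrary read-modify-write by retrying a compare-and-swap on the value's bit pattern. A host-side analogue of that loop using `std::atomic` (illustrative only):

```c++
#include <atomic>
#include <cstdint>
#include <cstring>

// Host-side analogue of CudaAtomicCasHelper for double: retry a CAS on the
// 64-bit bit pattern until no other thread modified the value in between.
// std::atomic<uint64_t> stands in for the device-side atomicCAS.
double AtomicAddDouble(std::atomic<uint64_t>* bits, double value) {
  uint64_t expected = bits->load();
  for (;;) {
    double current;
    std::memcpy(&current, &expected, sizeof(current));
    const double next = current + value;  // the "accumulate" functor
    uint64_t desired;
    std::memcpy(&desired, &next, sizeof(desired));
    // On failure, `expected` is refreshed with the latest value; retry.
    if (bits->compare_exchange_weak(expected, desired)) return next;
  }
}
```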
+#if GOOGLE_CUDA +__device__ inline std::complex operator+(const std::complex& a, + const std::complex& b) { + auto result = cuCaddf(make_cuComplex(a.real(), a.imag()), + make_cuComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); +} + +__device__ inline std::complex operator-(const std::complex& a, + const std::complex& b) { + auto result = cuCsubf(make_cuComplex(a.real(), a.imag()), + make_cuComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); +} + +__device__ inline std::complex operator*(const std::complex& a, + const std::complex& b) { + auto result = cuCmulf(make_cuComplex(a.real(), a.imag()), + make_cuComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); +} + +__device__ inline std::complex operator/(const std::complex& a, + const std::complex& b) { + auto result = cuCdivf(make_cuComplex(a.real(), a.imag()), + make_cuComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); +} + +__device__ inline std::complex operator+( + const std::complex& a, const std::complex& b) { + auto result = cuCadd(make_cuDoubleComplex(a.real(), a.imag()), + make_cuDoubleComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); +} + +__device__ inline std::complex operator-( + const std::complex& a, const std::complex& b) { + auto result = cuCsub(make_cuDoubleComplex(a.real(), a.imag()), + make_cuDoubleComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); +} + +__device__ inline std::complex operator*( + const std::complex& a, const std::complex& b) { + auto result = cuCmul(make_cuDoubleComplex(a.real(), a.imag()), + make_cuDoubleComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); +} + +__device__ inline std::complex operator/( + const std::complex& a, const std::complex& b) { + auto result = cuCdiv(make_cuDoubleComplex(a.real(), a.imag()), + make_cuDoubleComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); +} +#endif // GOOGLE_CUDA + } // namespace tensorflow -#endif // GOOGLE_CUDA -#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_ diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h similarity index 62% rename from tensorflow/core/util/cuda_kernel_helper.h rename to tensorflow/core/util/gpu_kernel_helper.h index f6f0408ccc1..71faa129e63 100644 --- a/tensorflow/core/util/cuda_kernel_helper.h +++ b/tensorflow/core/util/gpu_kernel_helper.h @@ -13,15 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ -#define TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ +#ifndef TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_ +#define TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda_fp16.h" +#endif +#include "tensorflow/core/util/gpu_device_functions.h" +#include "tensorflow/core/util/gpu_launch_config.h" -#include "tensorflow/core/util/cuda_device_functions.h" -#include "tensorflow/core/util/cuda_launch_config.h" - -#include "cuda/include/cuda_fp16.h" +#if GOOGLE_CUDA +#define TF_RED_WARPSIZE 32 +#elif TENSORFLOW_USE_ROCM +#define TF_RED_WARPSIZE 64 +#endif // Deprecated, use 'for(int i : CudaGridRangeX(n))' instead. 
#define CUDA_1D_KERNEL_LOOP(i, n) \ @@ -30,7 +37,54 @@ limitations under the License. #define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \ for (int i : ::tensorflow::CudaGridRange##axis(n)) +#if GOOGLE_CUDA +#define gpuSuccess cudaSuccess +using gpuStream_t = cudaStream_t; +using gpuError_t = cudaError_t; +#elif TENSORFLOW_USE_ROCM +#define gpuSuccess hipSuccess +using gpuStream_t = hipStream_t; +using gpuError_t = hipError_t; +#endif + +#define GetGPUStream(context) context->eigen_gpu_device().stream() + namespace tensorflow { +// Launches a GPU kernel through cudaLaunchKernel in CUDA environment, or +// hipLaunchKernel in ROCm environment with the given arguments. +// +// The kernel parameters 'Ts' must be constructible from the arguments 'Args'. +template +Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim, + size_t shared_memory_size_bytes, gpuStream_t stream, + Args... arguments) { + static_assert(detail::NoneIsReference(), + "Kernels with reference arguments have undefined behaviour."); +#if GOOGLE_CUDA + auto func_ptr = absl::bit_cast(function); + // Cast arguments and forward them as an array of pointers. + auto args_tuple = std::tuple(arguments...); + auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple); + auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(), + shared_memory_size_bytes, stream); + if (result != cudaSuccess) { + return errors::Internal(cudaGetErrorString(result)); + } +#elif TENSORFLOW_USE_ROCM + hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes, + stream, std::forward(arguments)...); +#endif + return Status::OK(); +} + +// Perfect forwarding to make CudaLaunchKernel available to both ROCm and CUDA +// builds +template +auto CudaLaunchKernel(Args&&... args) + -> decltype(GpuLaunchKernel(std::forward(args)...)) { + return GpuLaunchKernel(std::forward(args)...); +} + __host__ __device__ inline tensorflow::bfloat16 CudaLdg( const tensorflow::bfloat16* address) { tensorflow::bfloat16 return_value; @@ -136,5 +190,5 @@ __device__ OutType lower_bound(const T* first, OutType count, T val) { } // namespace cuda_helper } // namespace tensorflow -#endif // GOOGLE_CUDA -#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_ diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/gpu_kernel_helper_test.cu.cc similarity index 92% rename from tensorflow/core/util/cuda_kernel_helper_test.cu.cc rename to tensorflow/core/util/gpu_kernel_helper_test.cu.cc index 35f8d13f754..c3becb1509a 100644 --- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc +++ b/tensorflow/core/util/gpu_kernel_helper_test.cu.cc @@ -17,10 +17,11 @@ limitations under the License. 
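The new `GpuLaunchKernel` hides the `cudaLaunchKernel` / `hipLaunchKernelGGL` difference and reports launch failures as a `Status`. A hedged usage sketch that pairs it with `GetGpuLaunchConfig` (the `ScaleKernel`/`Scale` names are illustrative; the pattern mirrors the test file below):

```c++
// Illustrative only: scale an array on the GPU using the helpers above.
__global__ void ScaleKernel(tensorflow::GpuLaunchConfig config, float* data,
                            float factor) {
  CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) { data[i] *= factor; }
}

tensorflow::Status Scale(const Eigen::GpuDevice& d, float* data, int n,
                         float factor) {
  tensorflow::GpuLaunchConfig cfg = tensorflow::GetGpuLaunchConfig(n, d);
  return tensorflow::GpuLaunchKernel(
      ScaleKernel, cfg.block_count, cfg.thread_per_block,
      /*shared_memory_size_bytes=*/0, d.stream(), cfg, data, factor);
}
```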
#define EIGEN_USE_GPU #include + #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" -#include "tensorflow/core/util/cuda_launch_config.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" #define CUDA_EXPECT_SUCCESS \ { \ @@ -40,12 +41,12 @@ namespace tensorflow { namespace { -__global__ void SetOutbufZero(CudaLaunchConfig config, int* outbuf) { +__global__ void SetOutbufZero(GpuLaunchConfig config, int* outbuf) { CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { outbuf[x] = 0; } } // counting number of jobs by using atomic +1 -__global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) { +__global__ void Count1D(GpuLaunchConfig config, int bufsize, int* outbuf) { CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { if (x < 0) { // x might overflow when testing extreme case break; @@ -129,7 +130,7 @@ __global__ void CudaShuffleGetSrcLaneTest(unsigned* failure_count) { } // namespace -class CudaLaunchConfigTest : public ::testing::Test { +class GpuLaunchConfigTest : public ::testing::Test { protected: const int bufsize = 1024; int* outbuf = nullptr; @@ -148,28 +149,28 @@ class CudaLaunchConfigTest : public ::testing::Test { } }; -TEST_F(CudaLaunchConfigTest, GetCudaLaunchConfig) { - CudaLaunchConfig cfg; +TEST_F(GpuLaunchConfigTest, GetGpuLaunchConfig) { + GpuLaunchConfig cfg; // test valid inputs #define TEST_LAUNCH_PARAMETER(work_element_count) \ - cfg = GetCudaLaunchConfig(bufsize, d); \ + cfg = GetGpuLaunchConfig(bufsize, d); \ TF_CHECK_OK(CudaLaunchKernel(SetOutbufZero, cfg.block_count, \ cfg.thread_per_block, 0, d.stream(), cfg, \ outbuf)); \ CUDA_ASSERT_SUCCESS \ - cfg = GetCudaLaunchConfig(work_element_count, d); \ + cfg = GetGpuLaunchConfig(work_element_count, d); \ TF_CHECK_OK(CudaLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \ 0, d.stream(), cfg, bufsize, outbuf)); \ CUDA_EXPECT_SUCCESS \ EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)); \ \ - cfg = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + cfg = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ TF_CHECK_OK(CudaLaunchKernel(SetOutbufZero, cfg.block_count, \ cfg.thread_per_block, 0, d.stream(), cfg, \ outbuf)); \ CUDA_ASSERT_SUCCESS \ - cfg = GetCudaLaunchConfig(work_element_count, d, Count1D, 0, 0); \ + cfg = GetGpuLaunchConfig(work_element_count, d, Count1D, 0, 0); \ TF_CHECK_OK(CudaLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \ 0, d.stream(), cfg, bufsize, outbuf)); \ CUDA_EXPECT_SUCCESS \ @@ -200,13 +201,13 @@ bool operator==(const Cuda2DLaunchConfig& a, const Cuda2DLaunchConfig& b) { a.thread_per_block.z == b.thread_per_block.z; } -TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) { +TEST_F(GpuLaunchConfigTest, GetCuda2DLaunchConfig) { Cuda2DLaunchConfig cfg; - CudaLaunchConfig cfg1d; + GpuLaunchConfig cfg1d; // test valid inputs #define TEST_LAUNCH_PARAMETER(dimx, dimy) \ - cfg1d = GetCudaLaunchConfig(bufsize, d); \ + cfg1d = GetGpuLaunchConfig(bufsize, d); \ TF_EXPECT_OK(CudaLaunchKernel(SetOutbufZero, cfg1d.block_count, \ cfg1d.thread_per_block, 0, d.stream(), cfg1d, \ outbuf)); \ @@ -218,7 +219,7 @@ TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) { CUDA_EXPECT_SUCCESS \ EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0)); \ \ - cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + cfg1d = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 
0); \ TF_EXPECT_OK(CudaLaunchKernel(SetOutbufZero, cfg1d.block_count, \ cfg1d.thread_per_block, 0, d.stream(), cfg1d, \ outbuf)); \ @@ -244,13 +245,13 @@ TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) { #undef TEST_LAUNCH_PARAMETER } -TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) { +TEST_F(GpuLaunchConfigTest, GetCuda3DLaunchConfig) { Cuda3DLaunchConfig cfg; - CudaLaunchConfig cfg1d; + GpuLaunchConfig cfg1d; // test valid inputs #define TEST_LAUNCH_PARAMETER(dimx, dimy, dimz) \ - cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ + cfg1d = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \ TF_EXPECT_OK(CudaLaunchKernel(SetOutbufZero, cfg1d.block_count, \ cfg1d.thread_per_block, 0, d.stream(), cfg1d, \ outbuf)); \ diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/gpu_launch_config.h similarity index 56% rename from tensorflow/core/util/cuda_launch_config.h rename to tensorflow/core/util/gpu_launch_config.h index a46bd72c930..565fff8ed47 100644 --- a/tensorflow/core/util/cuda_launch_config.h +++ b/tensorflow/core/util/gpu_launch_config.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_ -#define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_ +#ifndef TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_ +#define TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include @@ -26,45 +26,46 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_cuda_alias.h" -// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and -// GetCuda3DLaunchConfig: +// Usage of GetGpuLaunchConfig, GetGpu2DLaunchConfig, and +// GetGpu3DLaunchConfig: // -// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one +// There are two versions of GetGpuLaunchConfig and GetGpu2DLaunchConfig, one // version uses heuristics without any knowledge of the device kernel, the other // version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical // launch parameters that maximize occupancy. Currently, only the maximum -// occupancy version of GetCuda3DLaunchConfig is available. +// occupancy version of GetGpu3DLaunchConfig is available. // // For large number of work elements, the convention is that each kernel would -// iterate through its assigned range. The return value of GetCudaLaunchConfig -// is struct CudaLaunchConfig, which contains all the information needed for the +// iterate through its assigned range. The return value of GetGpuLaunchConfig +// is struct GpuLaunchConfig, which contains all the information needed for the // kernel launch, including: virtual number of threads, the number of threads // per block and number of threads per block used inside <<< >>> of a kernel -// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing -// as CudaLaunchConfig. The only difference is the dimension. The macros -// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop. +// launch. GetGpu2DLaunchConfig and GetGpu3DLaunchConfig does the same thing +// as GpuLaunchConfig. The only difference is the dimension. The macros +// GPU_1D_KERNEL_LOOP and GPU_AXIS_KERNEL_LOOP might be used to do inner loop. 
// /* Sample code: -__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) { - CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { +__global__ void MyKernel1D(GpuLaunchConfig config, other_args...) { + GPU_1D_KERNEL_LOOP(x, config.virtual_thread_count) { do_your_job_here; } } -__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { +__global__ void MyKernel2D(Gpu2DLaunchConfig config, other_args...) { + GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { do_your_job_here; } } } -__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { - CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { +__global__ void MyKernel3D(Gpu3DLaunchConfig config, other_args...) { + GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + GPU_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { do_your_job_here; } } @@ -73,25 +74,25 @@ __global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) { void MyDriverFunc(const Eigen::GpuDevice &d) { // use heuristics - CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d); + GpuLaunchConfig cfg1 = GetGpuLaunchConfig(10240, d); MyKernel1D <<>> (cfg1, other_args...); - Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d); + Gpu2DLaunchConfig cfg2 = GetGpu2DLaunchConfig(10240, 10240, d); MyKernel2D <<>> (cfg2, other_args...); - Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d); + Gpu3DLaunchConfig cfg3 = GetGpu3DLaunchConfig(4096, 4096, 100, d); MyKernel3D <<>> (cfg3, other_args...); // maximize occupancy - CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 ); + GpuLaunchConfig cfg4 = GetGpuLaunchConfig(10240, d, MyKernel1D, 0, 0 ); MyKernel1D <<>> (cfg4, other_args...); - Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d, + Gpu2DLaunchConfig cfg5 = GetGpu2DLaunchConfig(10240, 10240, d, MyKernel1D, 0, 0); MyKernel2D <<>> (cfg5, other_args...); - Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d, + Gpu3DLaunchConfig cfg6 = GetGpu3DLaunchConfig(4096, 4096, 100, d, MyKernel1D, 0, 0); MyKernel3D <<>> (cfg6, other_args...); @@ -99,7 +100,7 @@ void MyDriverFunc(const Eigen::GpuDevice &d) { // See the test for this for more example: // -https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc +https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/gpu_kernel_helper_test.cu.cc */ @@ -107,25 +108,26 @@ namespace tensorflow { inline int DivUp(int a, int b) { return (a + b - 1) / b; } -struct CudaLaunchConfig { +struct GpuLaunchConfig { // Logical number of thread that works on the elements. If each logical // thread works on exactly a single element, this is the same as the working // element count. int virtual_thread_count = -1; // Number of threads per block. int thread_per_block = -1; - // Number of blocks for Cuda kernel launch. + // Number of blocks for GPU kernel launch. int block_count = -1; }; +CREATE_CUDA_TYPE_ALIAS(GpuLaunchConfig, CudaLaunchConfig); -// Calculate the Cuda launch config we should use for a kernel launch. +// Calculate the GPU launch config we should use for a kernel launch. 
// This is assuming the kernel is quite simple and will largely be // memory-limited. // REQUIRES: work_element_count > 0. -inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, - const Eigen::GpuDevice& d) { +inline GpuLaunchConfig GetGpuLaunchConfig(int work_element_count, + const Eigen::GpuDevice& d) { CHECK_GT(work_element_count, 0); - CudaLaunchConfig config; + GpuLaunchConfig config; const int virtual_thread_count = work_element_count; const int physical_thread_count = std::min( d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(), @@ -140,25 +142,48 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, config.block_count = block_count; return config; } +inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, + const Eigen::GpuDevice& d) { + return GetGpuLaunchConfig(work_element_count, d); +} -// Calculate the Cuda launch config we should use for a kernel launch. This +// Calculate the GPU launch config we should use for a kernel launch. This // variant takes the resource limits of func into account to maximize occupancy. // REQUIRES: work_element_count > 0. template -inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, - const Eigen::GpuDevice& d, - DeviceFunc func, - size_t dynamic_shared_memory_size, - int block_size_limit) { +GpuLaunchConfig GetGpuLaunchConfig(int work_element_count, + const Eigen::GpuDevice& d, DeviceFunc func, + size_t dynamic_shared_memory_size, + int block_size_limit) { CHECK_GT(work_element_count, 0); - CudaLaunchConfig config; + GpuLaunchConfig config; int block_count = 0; int thread_per_block = 0; +#if GOOGLE_CUDA cudaError_t err = cudaOccupancyMaxPotentialBlockSize( &block_count, &thread_per_block, func, dynamic_shared_memory_size, block_size_limit); CHECK_EQ(err, cudaSuccess); +#elif TENSORFLOW_USE_ROCM + // ROCM TODO re-enable this after hipOccupancyMaxPotentialBlockSize is + // implemented + // hipError_t err = hipOccupancyMaxPotentialBlockSize( + // &block_count, &thread_per_block, func, dynamic_shared_memory_size, + // block_size_limit); + // CHECK_EQ(err, hipSuccess); + + // Apply the heuristic in GetGpuLaunchConfig(int, const Eigen::GpuDevice&) + // that the kernel is quite simple and will largely be memory-limited. + const int physical_thread_count = std::min( + d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(), + work_element_count); + // Assume the kernel be simple enough that it is okay to use 1024 threads + // per workgroup. + thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock()); + block_count = std::min(DivUp(physical_thread_count, thread_per_block), + d.getNumGpuMultiProcessors()); +#endif block_count = std::min(block_count, DivUp(work_element_count, thread_per_block)); @@ -168,40 +193,64 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, config.block_count = block_count; return config; } +CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpuLaunchConfig, GetCudaLaunchConfig); -// Calculate the Cuda launch config we should use for a kernel launch. This +// Calculate the GPU launch config we should use for a kernel launch. This // variant takes the resource limits of func into account to maximize occupancy. // The returned launch config has thread_per_block set to fixed_block_size. // REQUIRES: work_element_count > 0. 
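A standalone walk-through of the heuristic arithmetic above, with made-up device limits (the real code reads them from `Eigen::GpuDevice`):

```c++
#include <algorithm>

// DivUp is the same ceiling division defined in gpu_launch_config.h; the
// device limits below are hypothetical.
int DivUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int work_element_count = 10240;
  const int multiprocessors = 20;
  const int max_threads_per_multiprocessor = 2048;
  const int max_threads_per_block = 1024;

  const int physical_thread_count =
      std::min(multiprocessors * max_threads_per_multiprocessor,
               work_element_count);                                    // 10240
  const int thread_per_block = std::min(1024, max_threads_per_block);  // 1024
  const int block_count =
      std::min(DivUp(physical_thread_count, thread_per_block),
               multiprocessors);                                       // 10
  return block_count == 10 ? 0 : 1;
}
```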
template -inline CudaLaunchConfig GetCudaLaunchConfigFixedBlockSize( +GpuLaunchConfig GetGpuLaunchConfigFixedBlockSize( int work_element_count, const Eigen::GpuDevice& d, DeviceFunc func, size_t dynamic_shared_memory_size, int fixed_block_size) { CHECK_GT(work_element_count, 0); - CudaLaunchConfig config; + GpuLaunchConfig config; int block_count = 0; +#if GOOGLE_CUDA cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &block_count, func, fixed_block_size, dynamic_shared_memory_size); CHECK_EQ(err, cudaSuccess); block_count = std::min(block_count * d.getNumGpuMultiProcessors(), DivUp(work_element_count, fixed_block_size)); +#elif TENSORFLOW_USE_ROCM + // ROCM TODO re-enable this after hipOccupancyMaxActiveBlocksPerMultiprocessor + // is implemented + // hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor( + // &block_count, &thread_per_block, func, dynamic_shared_memory_size, + // block_size_limit); + // CHECK_EQ(err, hipSuccess); + + // Apply the heuristic in GetGpuLaunchConfig(int, const Eigen::GpuDevice&) + // that the kernel is quite simple and will largely be memory-limited. + const int physical_thread_count = std::min( + d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(), + work_element_count); + // Assume the kernel be simple enough that it is okay to use 1024 threads + // per workgroup. + int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock()); + block_count = std::min(DivUp(physical_thread_count, thread_per_block), + d.getNumGpuMultiProcessors()); +#endif config.virtual_thread_count = work_element_count; config.thread_per_block = fixed_block_size; config.block_count = block_count; return config; } +CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpuLaunchConfigFixedBlockSize, + GetCudaLaunchConfigFixedBlockSize); -struct Cuda2DLaunchConfig { +struct Gpu2DLaunchConfig { dim3 virtual_thread_count = dim3(0, 0, 0); dim3 thread_per_block = dim3(0, 0, 0); dim3 block_count = dim3(0, 0, 0); }; +CREATE_CUDA_TYPE_ALIAS(Gpu2DLaunchConfig, Cuda2DLaunchConfig); -inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim, - const Eigen::GpuDevice& d) { - Cuda2DLaunchConfig config; +inline Gpu2DLaunchConfig GetGpu2DLaunchConfig(int xdim, int ydim, + const Eigen::GpuDevice& d) { + Gpu2DLaunchConfig config; if (xdim <= 0 || ydim <= 0) { return config; @@ -226,26 +275,39 @@ inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim, grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1); return config; } +inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim, + const Eigen::GpuDevice& d) { + return GetGpu2DLaunchConfig(xdim, ydim, d); +} -// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch. +// Calculate the GPU 2D and 3D launch config we should use for a kernel launch. // This variant takes the resource limits of func into account to maximize // occupancy. 
-using Cuda3DLaunchConfig = Cuda2DLaunchConfig; +using Gpu3DLaunchConfig = Gpu2DLaunchConfig; +CREATE_CUDA_TYPE_ALIAS(Gpu3DLaunchConfig, Cuda3DLaunchConfig); template -inline Cuda3DLaunchConfig GetCuda3DLaunchConfig( - int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func, - size_t dynamic_shared_memory_size, int block_size_limit) { - Cuda3DLaunchConfig config; +Gpu3DLaunchConfig GetGpu3DLaunchConfig(int xdim, int ydim, int zdim, + const Eigen::GpuDevice& d, + DeviceFunc func, + size_t dynamic_shared_memory_size, + int block_size_limit) { + Gpu3DLaunchConfig config; if (xdim <= 0 || ydim <= 0 || zdim <= 0) { return config; } int dev; +#if GOOGLE_CUDA cudaGetDevice(&dev); cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, dev); +#elif TENSORFLOW_USE_ROCM + hipGetDevice(&dev); + hipDeviceProp_t deviceProp; + hipGetDeviceProperties(&deviceProp, dev); +#endif int xthreadlimit = deviceProp.maxThreadsDim[0]; int ythreadlimit = deviceProp.maxThreadsDim[1]; int zthreadlimit = deviceProp.maxThreadsDim[2]; @@ -255,10 +317,26 @@ inline Cuda3DLaunchConfig GetCuda3DLaunchConfig( int block_count = 0; int thread_per_block = 0; + +#if GOOGLE_CUDA cudaError_t err = cudaOccupancyMaxPotentialBlockSize( &block_count, &thread_per_block, func, dynamic_shared_memory_size, block_size_limit); CHECK_EQ(err, cudaSuccess); +#elif TENSORFLOW_USE_ROCM + // ROCM TODO re-enable this after hipOccupancyMaxPotentialBlockSize is + // implemented + // hipError_t err = hipOccupancyMaxPotentialBlockSize( + // &block_count, &thread_per_block, func, dynamic_shared_memory_size, + // block_size_limit); + // CHECK_EQ(err, hipSuccess); + + const int physical_thread_count = + d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(); + thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock()); + block_count = std::min(DivUp(physical_thread_count, thread_per_block), + d.getNumGpuMultiProcessors()); +#endif int threadsx = std::min({xdim, thread_per_block, xthreadlimit}); int threadsy = @@ -278,15 +356,20 @@ inline Cuda3DLaunchConfig GetCuda3DLaunchConfig( config.block_count = dim3(blocksx, blocksy, blocksz); return config; } +CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpu3DLaunchConfig, GetCuda3DLaunchConfig); template -inline Cuda2DLaunchConfig GetCuda2DLaunchConfig( - int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func, - size_t dynamic_shared_memory_size, int block_size_limit) { - return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func, - dynamic_shared_memory_size, block_size_limit); +Gpu2DLaunchConfig GetGpu2DLaunchConfig(int xdim, int ydim, + const Eigen::GpuDevice& d, + DeviceFunc func, + size_t dynamic_shared_memory_size, + int block_size_limit) { + return GetGpu3DLaunchConfig(xdim, ydim, 1, d, func, + dynamic_shared_memory_size, block_size_limit); } +CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpu2DLaunchConfig, GetCuda2DLaunchConfig); +#if GOOGLE_CUDA // Returns a raw reference to the current cuda stream. Required by a // number of kernel calls (for which StreamInterface* does not work), i.e. // CUB and certain cublas primitives. 
@@ -298,6 +381,16 @@ inline const cudaStream_t& GetCudaStream(OpKernelContext* context) { ->GpuStreamMemberHack())); return *ptr; } +template +Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim, + const Eigen::GpuDevice& d, + DeviceFunc func, + size_t dynamic_shared_memory_size, + int block_size_limit) { + return GetGpu2DLaunchConfig(xdim, ydim, d, func, dynamic_shared_memory_size, + block_size_limit); +} +#endif // GOOGLE_CUDA namespace detail { template @@ -323,30 +416,6 @@ constexpr bool NoneIsReference() { return NoneTrue<(std::is_reference::value)...>::value; } } // namespace detail - -// Launches a CUDA kernel through cudaLaunchKernel with the given arguments. -// -// The kernel parameters 'Ts' must be constructible from the arguments 'Args'. -template -Status CudaLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim, - size_t shared_memory_size_bytes, cudaStream_t stream, - Args... arguments) { - static_assert(detail::NoneIsReference(), - "Kernels with reference arguments have undefined behaviour."); - // Cast arguments and forward them as an array of pointers. - auto args_tuple = std::tuple(arguments...); - auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple); - auto func_ptr = absl::bit_cast(function); - auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(), - shared_memory_size_bytes, stream); - if (result != cudaSuccess) { - return errors::Internal(cudaGetErrorString(result)); - } - return Status::OK(); -} - } // namespace tensorflow - -#endif // GOOGLE_CUDA - -#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_ diff --git a/tensorflow/core/util/matmul_bcast.cc b/tensorflow/core/util/matmul_bcast.cc new file mode 100644 index 00000000000..3e5c5cf1750 --- /dev/null +++ b/tensorflow/core/util/matmul_bcast.cc @@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/matmul_bcast.h" + +namespace tensorflow { +namespace { + +// Returns the mapping from the output batch indices to the corresponding +// input's batch indices, given the input's "reshape" and "bcast" shapes as +// returned by the BCast helper class. The i'th element denotes the (flattened) +// batch index of the input that must be used to compute the i'th batch output. +void ComputeBatchIndices(const int64 output_batch_size, + const MatMulBCast::Vec& reshape, + const MatMulBCast::Vec& bcast, + std::vector* out_indices) { + // Populates the mapping in out_indices. This algorithm is identical to + // the following steps: + // - Reshape {0, 1, ..., input_batch_size - 1} to the input shape. + // - Broadcast to the output shape. + // - Reshape back to a flat 1D vector. 
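To make the mapping concrete, here is a small worked case; the values mirror the BroadcastBothOperands test added further below, so nothing here is new behaviour.

```cpp
// x has batch dims [3, 1], y has batch dims [1, 4]; they broadcast to [3, 4],
// i.e. 12 output batches.
MatMulBCast bcast({3, 1, 5, 3}, {1, 4, 3, 7});
// bcast.x_batch_indices() -> {0,0,0,0, 1,1,1,1, 2,2,2,2}
// bcast.y_batch_indices() -> {0,1,2,3, 0,1,2,3, 0,1,2,3}
// The i'th output batch is matmul of x's batch x_batch_indices()[i]
// with y's batch y_batch_indices()[i].
```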
+ out_indices->resize(output_batch_size); + int64 num_output_elements = 1; + int64 num_input_elements = 1; + for (int64 i = reshape.size() - 1; i >= 0; --i) { + // Replicate the already populated mapping an additional (dim - 1) times. + // If we are broadcasting, just copy the existing mapping. + // Otherwise, add another dimension from the input shape. + const int64 dim = std::max(reshape[i], bcast[i]); + const int64 incr = bcast[i] > 1 ? 0 : num_input_elements; + for (int64 k = 0; k < (dim - 1) * num_output_elements; ++k) { + (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr; + } + num_output_elements *= dim; + num_input_elements *= reshape[i]; + } +} + +} // namespace + +MatMulBCast::MatMulBCast(Vec x, Vec y) { + if (x.size() < 2 || y.size() < 2) return; + x.resize(x.size() - 2); + y.resize(y.size() - 2); + + batch_bcast_ = absl::make_unique(std::move(x), std::move(y)); + if (!batch_bcast_->IsValid()) return; + + x_batch_size_ = TensorShape(batch_bcast_->x_reshape()).num_elements(); + y_batch_size_ = TensorShape(batch_bcast_->y_reshape()).num_elements(); + output_shape_ = TensorShape(batch_bcast_->output_shape()); + output_batch_size_ = output_shape_.num_elements(); + broadcasting_required_ = + std::min(x_batch_size_, y_batch_size_) != output_batch_size_; + + if (broadcasting_required_) { + ComputeBatchIndices(output_batch_size_, batch_bcast_->x_reshape(), + batch_bcast_->x_bcast(), &x_batch_indices_); + ComputeBatchIndices(output_batch_size_, batch_bcast_->y_reshape(), + batch_bcast_->y_bcast(), &y_batch_indices_); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/matmul_bcast.h b/tensorflow/core/util/matmul_bcast.h new file mode 100644 index 00000000000..611ef237de6 --- /dev/null +++ b/tensorflow/core/util/matmul_bcast.h @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_ +#define TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_ + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +// Simple wrapper over BCast specialized for MatMul. +// Provides utilities for broadcasting across batch dimensions for binary +// MatMul-like operations. +class MatMulBCast { + public: + using Vec = BCast::Vec; + + MatMulBCast(Vec x, Vec y); + + bool IsValid() const { return batch_bcast_ && batch_bcast_->IsValid(); } + bool IsBroadcastingRequired() const { return broadcasting_required_; } + + const int64 output_batch_size() const { return output_batch_size_; } + const int64 x_batch_size() const { return x_batch_size_; } + const int64 y_batch_size() const { return y_batch_size_; } + const TensorShape& output_batch_shape() const { return output_shape_; } + + // Returns the mapping from the flattened output batch indices to x's + // flattened batch indices. 
The result is a vector of length + // output_batch_size(). To compute the i'th batch output, a binary matmul-like + // operation should use the `x_batch_indices()[i]`th batch index of `x`. + // Note: Returns an empty vector if broadcasting is not required. Callers + // should only use this when IsBroadcastingRequired() returns true. + const std::vector& x_batch_indices() const { return x_batch_indices_; } + // Returns the mapping from the flattened output batch indices to y's + // flattened batch indices. Similar to x_batch_indices(). + // Note: Returns an empty vector if broadcasting is not required. Callers + // should only use this when IsBroadcastingRequired() returns true. + const std::vector& y_batch_indices() const { return y_batch_indices_; } + + private: + std::unique_ptr batch_bcast_; + bool broadcasting_required_ = false; + int64 x_batch_size_; + int64 y_batch_size_; + TensorShape output_shape_; + int64 output_batch_size_; + std::vector x_batch_indices_; + std::vector y_batch_indices_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_ diff --git a/tensorflow/core/util/matmul_bcast_test.cc b/tensorflow/core/util/matmul_bcast_test.cc new file mode 100644 index 00000000000..1de62297f70 --- /dev/null +++ b/tensorflow/core/util/matmul_bcast_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/util/matmul_bcast.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +string MatMulBCastToStr(const MatMulBCast& b) { + if (!b.IsValid()) { + return "invalid"; + } + string ret; + strings::StrAppend( + &ret, "[", str_util::Join(b.output_batch_shape().dim_sizes(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.x_batch_indices(), ","), "]"); + strings::StrAppend(&ret, "[", str_util::Join(b.y_batch_indices(), ","), "]"); + return ret; +} + +TEST(MatMulBCastTest, SimpleBroadcast) { + MatMulBCast bcast({1, 5, 3}, {4, 3, 7}); + + EXPECT_TRUE(bcast.IsValid()); + EXPECT_TRUE(bcast.IsBroadcastingRequired()); + + EXPECT_EQ(1, bcast.x_batch_size()); + EXPECT_EQ(4, bcast.y_batch_size()); + EXPECT_EQ(4, bcast.output_batch_size()); + + EXPECT_EQ("[4][0,0,0,0][0,1,2,3]", MatMulBCastToStr(bcast)); +} + +TEST(MatMulBCastTest, EmptyBatchBroadcast) { + MatMulBCast bcast({5, 3}, {3, 7}); + + EXPECT_TRUE(bcast.IsValid()); + EXPECT_FALSE(bcast.IsBroadcastingRequired()); + + EXPECT_EQ(1, bcast.x_batch_size()); + EXPECT_EQ(1, bcast.y_batch_size()); + EXPECT_EQ(1, bcast.output_batch_size()); + + EXPECT_EQ("[][][]", MatMulBCastToStr(bcast)); +} + +TEST(MatMulBCastTest, BroadcastingNotRequired) { + MatMulBCast bcast({2, 4, 6, 5, 3}, {2, 4, 6, 3, 7}); + + EXPECT_TRUE(bcast.IsValid()); + EXPECT_FALSE(bcast.IsBroadcastingRequired()); + + EXPECT_EQ(48, bcast.x_batch_size()); + EXPECT_EQ(48, bcast.y_batch_size()); + EXPECT_EQ(48, bcast.output_batch_size()); + + EXPECT_EQ("[2,4,6][][]", MatMulBCastToStr(bcast)); +} + +TEST(MatMulBCastTest, EmptyWithNonEmptyBatchBroadcast) { + MatMulBCast bcast1({5, 3}, {6, 3, 7}); + + EXPECT_TRUE(bcast1.IsValid()); + EXPECT_TRUE(bcast1.IsBroadcastingRequired()); + + EXPECT_EQ(1, bcast1.x_batch_size()); + EXPECT_EQ(6, bcast1.y_batch_size()); + EXPECT_EQ(6, bcast1.output_batch_size()); + EXPECT_EQ("[6][0,0,0,0,0,0][0,1,2,3,4,5]", MatMulBCastToStr(bcast1)); + + MatMulBCast bcast2({2, 5, 3}, {3, 7}); + EXPECT_TRUE(bcast2.IsValid()); + EXPECT_TRUE(bcast2.IsBroadcastingRequired()); + + EXPECT_EQ(2, bcast2.x_batch_size()); + EXPECT_EQ(1, bcast2.y_batch_size()); + EXPECT_EQ(2, bcast2.output_batch_size()); + EXPECT_EQ("[2][0,1][0,0]", MatMulBCastToStr(bcast2)); +} + +TEST(MatMulBCastTest, InvalidDimensions) { + // Too few dimensions. + MatMulBCast bcast1({3, 3}, {3}); + EXPECT_FALSE(bcast1.IsValid()); + + MatMulBCast bcast2({3}, {3, 3}); + EXPECT_FALSE(bcast2.IsValid()); + + // Batch dimensions not broadcastable. 
+ MatMulBCast bcast3({4, 5, 3}, {2, 3, 7}); + EXPECT_FALSE(bcast3.IsValid()); + + MatMulBCast bcast4({2, 1, 5, 3}, {1, 3, 1, 3, 7}); + EXPECT_FALSE(bcast4.IsValid()); +} + +TEST(MatMulBCastTest, BroadcastBothOperands) { + MatMulBCast bcast({3, 1, 5, 3}, {1, 4, 3, 7}); + EXPECT_TRUE(bcast.IsValid()); + + EXPECT_EQ(3, bcast.x_batch_size()); + EXPECT_EQ(4, bcast.y_batch_size()); + EXPECT_EQ(12, bcast.output_batch_size()); + + EXPECT_EQ("[3,4][0,0,0,0,1,1,1,1,2,2,2,2][0,1,2,3,0,1,2,3,0,1,2,3]", + MatMulBCastToStr(bcast)); +} + +TEST(MatMulBCastTest, DifferentRanks) { + MatMulBCast bcast({3, 1, 5, 3}, {2, 1, 2, 3, 7}); + EXPECT_TRUE(bcast.IsValid()); + + EXPECT_EQ(3, bcast.x_batch_size()); + EXPECT_EQ(4, bcast.y_batch_size()); + EXPECT_EQ(12, bcast.output_batch_size()); + + EXPECT_EQ("[2,3,2][0,0,1,1,2,2,0,0,1,1,2,2][0,1,0,1,0,1,2,3,2,3,2,3]", + MatMulBCastToStr(bcast)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/memmapped_file_system.cc b/tensorflow/core/util/memmapped_file_system.cc index b1773a25171..00c527f1af6 100644 --- a/tensorflow/core/util/memmapped_file_system.cc +++ b/tensorflow/core/util/memmapped_file_system.cc @@ -235,8 +235,7 @@ Status MemmappedFileSystem::InitializeFromFile(Env* env, if (!directory_ .insert(std::make_pair( element_iter->name(), - FileRegion(element_iter->offset(), - prev_element_offset - element_iter->offset()))) + FileRegion(element_iter->offset(), element_iter->length()))) .second) { return errors::DataLoss("Corrupted memmapped model file: ", filename, " Duplicate name of internal component ", diff --git a/tensorflow/core/util/memmapped_file_system.proto b/tensorflow/core/util/memmapped_file_system.proto index bf6cb4af296..a988b45b6f0 100644 --- a/tensorflow/core/util/memmapped_file_system.proto +++ b/tensorflow/core/util/memmapped_file_system.proto @@ -15,12 +15,14 @@ limitations under the License. syntax = "proto3"; package tensorflow; + option cc_enable_arenas = true; // A message that describes one region of memmapped file. message MemmappedFileSystemDirectoryElement { uint64 offset = 1; string name = 2; + uint64 length = 3; } // A directory of regions in a memmapped file. diff --git a/tensorflow/core/util/memmapped_file_system_writer.cc b/tensorflow/core/util/memmapped_file_system_writer.cc index 9556ee385f6..483c8be7933 100644 --- a/tensorflow/core/util/memmapped_file_system_writer.cc +++ b/tensorflow/core/util/memmapped_file_system_writer.cc @@ -47,7 +47,7 @@ Status MemmappedFileSystemWriter::SaveTensor(const Tensor& tensor, } // Adds pad for correct alignment after memmapping. 
TF_RETURN_IF_ERROR(AdjustAlignment(Allocator::kAllocatorAlignment)); - AddToDirectoryElement(element_name); + AddToDirectoryElement(element_name, tensor_data.size()); const auto result = output_file_->Append(tensor_data); if (result.ok()) { output_file_offset_ += tensor_data.size(); @@ -69,8 +69,8 @@ Status MemmappedFileSystemWriter::SaveProtobuf( MemmappedFileSystem::kMemmappedPackagePrefix, " and include [A-Za-z0-9_.]"); } - AddToDirectoryElement(element_name); const string encoded = message.SerializeAsString(); + AddToDirectoryElement(element_name, encoded.size()); const auto res = output_file_->Append(encoded); if (res.ok()) { output_file_offset_ += encoded.size(); @@ -124,11 +124,13 @@ Status MemmappedFileSystemWriter::AdjustAlignment(uint64 alignment) { return Status::OK(); } -void MemmappedFileSystemWriter::AddToDirectoryElement(const string& name) { +void MemmappedFileSystemWriter::AddToDirectoryElement(const string& name, + uint64 length) { MemmappedFileSystemDirectoryElement* new_directory_element = directory_.add_element(); new_directory_element->set_offset(output_file_offset_); new_directory_element->set_name(name); + new_directory_element->set_length(length); } } // namespace tensorflow diff --git a/tensorflow/core/util/memmapped_file_system_writer.h b/tensorflow/core/util/memmapped_file_system_writer.h index 2cebaa256da..884b4b8bc63 100644 --- a/tensorflow/core/util/memmapped_file_system_writer.h +++ b/tensorflow/core/util/memmapped_file_system_writer.h @@ -40,7 +40,7 @@ class MemmappedFileSystemWriter { private: Status AdjustAlignment(uint64 alignment); - void AddToDirectoryElement(const string& element_name); + void AddToDirectoryElement(const string& element_name, uint64 length); MemmappedFileSystemDirectory directory_; // The current offset in the file, to support alignment. uint64 output_file_offset_ = 0; diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 91f9bc03625..11c1ec1cc60 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -1442,6 +1442,12 @@ template <> memory::data_type MklDnnType() { return memory::data_type::s32; } +template <> +memory::data_type MklDnnType() { + // TODO(nhasabni): Enable MKL-DNN bfloat16 type later. + // Currently, falling back to f32 to get compilation working. + return memory::data_type::f32; +} /// Map TensorFlow's data format into MKL-DNN 3D data format /// @input: TensorFlow data format @@ -1581,7 +1587,7 @@ inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) { /// Function to calculate strides given tensor shape in Tensorflow order /// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention, -/// dimesion with size 1 is outermost dimension; while dimension with size 4 is +/// dimension with size 1 is outermost dimension; while dimension with size 4 is /// innermost dimension. So strides for this tensor would be {4 * 3 * 2, /// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}. 
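A short sketch of the strides rule described in that comment; the helper name and signature here are illustrative and not the exact ones used in mkl_util.h.

```cpp
#include <vector>

// Row-major strides for a shape given in TensorFlow (outermost-first) order,
// e.g. {1, 2, 3, 4} -> {24, 12, 4, 1}.
std::vector<int> StridesForTfShape(const std::vector<int>& dims_tf_order) {
  std::vector<int> strides(dims_tf_order.size());
  int stride = 1;
  for (int i = static_cast<int>(dims_tf_order.size()) - 1; i >= 0; --i) {
    strides[i] = stride;        // stride of dimension i
    stride *= dims_tf_order[i];  // accumulate sizes of inner dimensions
  }
  return strides;
}
```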
/// diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc index 7dc8ddda06a..0ec78153016 100644 --- a/tensorflow/core/util/port.cc +++ b/tensorflow/core/util/port.cc @@ -26,8 +26,17 @@ bool IsGoogleCudaEnabled() { #endif } -bool CudaSupportsHalfMatMulAndConv() { -#if GOOGLE_CUDA +bool IsBuiltWithROCm() { +#if TENSORFLOW_USE_ROCM + return true; +#else + return false; +#endif +} + +bool GpuSupportsHalfMatMulAndConv() { +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) return true; #else return false; diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h index e9b9cb1cd21..bfdede74212 100644 --- a/tensorflow/core/util/port.h +++ b/tensorflow/core/util/port.h @@ -21,9 +21,19 @@ namespace tensorflow { // Returns true if GOOGLE_CUDA is defined. bool IsGoogleCudaEnabled(); -// Returns true if GOOGLE_CUDA is defined, and the given CUDA version supports -// half-precision matrix multiplications and convolution operations. -bool CudaSupportsHalfMatMulAndConv(); +// Returns true if TENSORFLOW_USE_ROCM is defined. (i.e. TF is built with ROCm) +bool IsBuiltWithROCm(); + +// Returns true if either +// +// GOOGLE_CUDA is defined, and the given CUDA version supports +// half-precision matrix multiplications and convolution operations. +// +// OR +// +// TENSORFLOW_USE_ROCM is defined +// +bool GpuSupportsHalfMatMulAndConv(); // Returns true if INTEL_MKL is defined bool IsMklEnabled(); diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD index b990f0a7491..890bd837025 100644 --- a/tensorflow/core/util/proto/BUILD +++ b/tensorflow/core/util/proto/BUILD @@ -70,6 +70,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:platform_base", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@protobuf_archive//:protobuf_headers", ], ) diff --git a/tensorflow/core/util/proto/descriptors.cc b/tensorflow/core/util/proto/descriptors.cc index 271c85efd88..c3797f1a8a8 100644 --- a/tensorflow/core/util/proto/descriptors.cc +++ b/tensorflow/core/util/proto/descriptors.cc @@ -25,7 +25,7 @@ namespace { // Build a `DescriptorPool` from the named file or URI. The file or URI // must be available to the current TensorFlow environment. // -// The file must contiain a serialized `FileDescriptorSet`. See +// The file must contain a serialized `FileDescriptorSet`. See // `GetDescriptorPool()` for more information. Status GetDescriptorPoolFromFile( tensorflow::Env* env, const string& filename, diff --git a/tensorflow/core/util/proto/proto_utils.h b/tensorflow/core/util/proto/proto_utils.h index 9451e317a13..ba45f8a5b0e 100644 --- a/tensorflow/core/util/proto/proto_utils.h +++ b/tensorflow/core/util/proto/proto_utils.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_ #define TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_ +#include "google/protobuf/duration.pb.h" #include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/protobuf.h" @@ -58,6 +60,20 @@ class StringErrorCollector : public protobuf::io::ErrorCollector { const int index_offset_; }; +// Converts an absl::Duration to a google::protobuf::Duration. 
+inline google::protobuf::Duration ToDurationProto(absl::Duration duration) { + google::protobuf::Duration proto; + proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); + proto.set_nanos( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + return proto; +} + +// Converts a google::protobuf::Duration to an absl::Duration. +inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + } // namespace proto_utils } // namespace tensorflow diff --git a/tensorflow/core/util/reffed_status_callback.h b/tensorflow/core/util/reffed_status_callback.h index 4d9a851037c..1c552d45c42 100644 --- a/tensorflow/core/util/reffed_status_callback.h +++ b/tensorflow/core/util/reffed_status_callback.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_REFFED_STATUS_CALLBACK_H_ #define TENSORFLOW_CORE_UTIL_REFFED_STATUS_CALLBACK_H_ +#include "absl/strings/str_cat.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mutex.h" @@ -28,33 +29,30 @@ namespace tensorflow { // UpdateStatus(), or Status::OK() if no non-OK status was set. class ReffedStatusCallback : public core::RefCounted { public: - explicit ReffedStatusCallback(StatusCallback done) - : done_(std::move(done)), status_(Status::OK()) {} + explicit ReffedStatusCallback(StatusCallback done) : done_(std::move(done)) {} void UpdateStatus(const Status& s) { - if (!s.ok()) { - mutex_lock lock(mu_); - if (status_.ok()) status_.Update(s); - } + mutex_lock lock(mu_); + status_group_.Update(s); } bool ok() { - mutex_lock lock(mu_); - return status_.ok(); + tf_shared_lock lock(mu_); + return status_group_.ok(); } // Returns a copy of the current status. Status status() { - mutex_lock lock(mu_); - return status_; + tf_shared_lock lock(mu_); + return status_group_.as_summary_status(); } - ~ReffedStatusCallback() { done_(status_); } + ~ReffedStatusCallback() { done_(status_group_.as_summary_status()); } private: StatusCallback done_; mutex mu_; - Status status_ GUARDED_BY(mu_); + StatusGroup status_group_ GUARDED_BY(mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/util/reffed_status_callback_test.cc b/tensorflow/core/util/reffed_status_callback_test.cc index 7e776beb237..6799183dc1f 100644 --- a/tensorflow/core/util/reffed_status_callback_test.cc +++ b/tensorflow/core/util/reffed_status_callback_test.cc @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "tensorflow/core/util/reffed_status_callback.h" +#include + +#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -49,11 +50,16 @@ TEST(TestReffedStatusCallback, CallsBackFail) { }; auto* cb = new ReffedStatusCallback(std::move(done)); cb->UpdateStatus(errors::Internal("1")); - cb->UpdateStatus(errors::Internal("2")); // Will be ignored. + cb->UpdateStatus(errors::InvalidArgument("2")); EXPECT_FALSE(called); cb->Unref(); EXPECT_TRUE(called); - EXPECT_EQ(status.error_message(), "1"); + // Equal to the first error. + EXPECT_EQ(status.code(), error::INTERNAL); + // Both errors are reported. 
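Both errors end up in the summary because UpdateStatus() now feeds a StatusGroup rather than keeping only the first failure. A typical multi-holder usage looks like the sketch below; `RunAsync` and `DoWork` are hypothetical stand-ins for caller code, only `ReffedStatusCallback` itself comes from this patch.

```cpp
auto* cb = new ReffedStatusCallback(
    [](const Status& s) { LOG(INFO) << "all work finished: " << s; });
for (int i = 0; i < 3; ++i) {
  cb->Ref();  // one reference per outstanding unit of work
  RunAsync([cb, i]() {
    cb->UpdateStatus(DoWork(i));  // every non-OK status is recorded
    cb->Unref();
  });
}
cb->Unref();  // drop the constructor's reference; done_ fires at refcount
              // zero with StatusGroup::as_summary_status().
```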
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "Internal: 1")); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "Invalid argument: 2")); } TEST(TestReffedStatusCallback, RefMulti) { @@ -67,13 +73,15 @@ TEST(TestReffedStatusCallback, RefMulti) { cb->Ref(); cb->UpdateStatus(errors::Internal("1")); cb->Ref(); - cb->UpdateStatus(errors::Internal("2")); // Will be ignored. + cb->UpdateStatus(errors::Internal("2")); cb->Unref(); cb->Unref(); EXPECT_FALSE(called); cb->Unref(); // Created by constructor. EXPECT_TRUE(called); - EXPECT_EQ(status.error_message(), "1"); + // Both errors are reported. + EXPECT_TRUE(str_util::StrContains(status.error_message(), "Internal: 1")); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "Internal: 2")); } TEST(TestReffedStatusCallback, MultiThreaded) { @@ -104,7 +112,9 @@ TEST(TestReffedStatusCallback, MultiThreaded) { n.WaitForNotification(); EXPECT_EQ(num_called.load(), 1); - EXPECT_EQ(status.error_message(), "err"); + EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "Invalid argument: err")); } } // namespace diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 3709ee5ae30..111ccdc48f4 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/framework/versions.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -1033,6 +1034,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key, HANDLE_COPY(qint32) HANDLE_COPY(quint8) HANDLE_COPY(qint8) + HANDLE_COPY(bfloat16) default: return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype), " not supported."); diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 9567e4750b7..d6fab75662b 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -166,10 +166,10 @@ template void TestBasic() { { BundleWriter writer(Env::Default(), Prefix("foo")); - TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3(3))); - TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3(0))); - TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3(2))); - TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3(1))); + TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3(T(3)))); + TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3(T(0)))); + TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3(T(2)))); + TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3(T(1)))); TF_ASSERT_OK(writer.Finish()); } { @@ -178,28 +178,28 @@ void TestBasic() { EXPECT_EQ( AllTensorKeys(&reader), std::vector({"foo_000", "foo_001", "foo_002", "foo_003"})); - Expect(&reader, "foo_000", Constant_2x3(0)); - Expect(&reader, "foo_001", Constant_2x3(1)); - Expect(&reader, "foo_002", Constant_2x3(2)); - Expect(&reader, "foo_003", Constant_2x3(3)); + Expect(&reader, "foo_000", Constant_2x3(T(0))); + Expect(&reader, "foo_001", Constant_2x3(T(1))); + Expect(&reader, "foo_002", Constant_2x3(T(2))); + Expect(&reader, "foo_003", Constant_2x3(T(3))); } { BundleReader reader(Env::Default(), 
Prefix("foo")); TF_ASSERT_OK(reader.status()); - ExpectNext(&reader, Constant_2x3(0)); - ExpectNext(&reader, Constant_2x3(1)); - ExpectNext(&reader, Constant_2x3(2)); - ExpectNext(&reader, Constant_2x3(3)); + ExpectNext(&reader, Constant_2x3(T(0))); + ExpectNext(&reader, Constant_2x3(T(1))); + ExpectNext(&reader, Constant_2x3(T(2))); + ExpectNext(&reader, Constant_2x3(T(3))); EXPECT_TRUE(reader.Valid()); reader.Next(); EXPECT_FALSE(reader.Valid()); } { BundleWriter writer(Env::Default(), Prefix("bar")); - TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3(3))); - TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3(0))); - TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3(2))); - TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3(1))); + TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3(T(3)))); + TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3(T(0)))); + TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3(T(2)))); + TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3(T(1)))); TF_ASSERT_OK(writer.Finish()); } { @@ -208,18 +208,18 @@ void TestBasic() { EXPECT_EQ( AllTensorKeys(&reader), std::vector({"bar_000", "bar_001", "bar_002", "bar_003"})); - Expect(&reader, "bar_003", Constant_2x3(3)); - Expect(&reader, "bar_002", Constant_2x3(2)); - Expect(&reader, "bar_001", Constant_2x3(1)); - Expect(&reader, "bar_000", Constant_2x3(0)); + Expect(&reader, "bar_003", Constant_2x3(T(3))); + Expect(&reader, "bar_002", Constant_2x3(T(2))); + Expect(&reader, "bar_001", Constant_2x3(T(1))); + Expect(&reader, "bar_000", Constant_2x3(T(0))); } { BundleReader reader(Env::Default(), Prefix("bar")); TF_ASSERT_OK(reader.status()); - ExpectNext(&reader, Constant_2x3(0)); - ExpectNext(&reader, Constant_2x3(1)); - ExpectNext(&reader, Constant_2x3(2)); - ExpectNext(&reader, Constant_2x3(3)); + ExpectNext(&reader, Constant_2x3(T(0))); + ExpectNext(&reader, Constant_2x3(T(1))); + ExpectNext(&reader, Constant_2x3(T(2))); + ExpectNext(&reader, Constant_2x3(T(3))); EXPECT_TRUE(reader.Valid()); reader.Next(); EXPECT_FALSE(reader.Valid()); @@ -233,26 +233,26 @@ void TestBasic() { AllTensorKeys(&reader), std::vector({"bar_000", "bar_001", "bar_002", "bar_003", "foo_000", "foo_001", "foo_002", "foo_003"})); - Expect(&reader, "bar_000", Constant_2x3(0)); - Expect(&reader, "bar_001", Constant_2x3(1)); - Expect(&reader, "bar_002", Constant_2x3(2)); - Expect(&reader, "bar_003", Constant_2x3(3)); - Expect(&reader, "foo_000", Constant_2x3(0)); - Expect(&reader, "foo_001", Constant_2x3(1)); - Expect(&reader, "foo_002", Constant_2x3(2)); - Expect(&reader, "foo_003", Constant_2x3(3)); + Expect(&reader, "bar_000", Constant_2x3(T(0))); + Expect(&reader, "bar_001", Constant_2x3(T(1))); + Expect(&reader, "bar_002", Constant_2x3(T(2))); + Expect(&reader, "bar_003", Constant_2x3(T(3))); + Expect(&reader, "foo_000", Constant_2x3(T(0))); + Expect(&reader, "foo_001", Constant_2x3(T(1))); + Expect(&reader, "foo_002", Constant_2x3(T(2))); + Expect(&reader, "foo_003", Constant_2x3(T(3))); } { BundleReader reader(Env::Default(), Prefix("merged")); TF_ASSERT_OK(reader.status()); - ExpectNext(&reader, Constant_2x3(0)); - ExpectNext(&reader, Constant_2x3(1)); - ExpectNext(&reader, Constant_2x3(2)); - ExpectNext(&reader, Constant_2x3(3)); - ExpectNext(&reader, Constant_2x3(0)); - ExpectNext(&reader, Constant_2x3(1)); - ExpectNext(&reader, Constant_2x3(2)); - ExpectNext(&reader, Constant_2x3(3)); + ExpectNext(&reader, Constant_2x3(T(0))); + ExpectNext(&reader, Constant_2x3(T(1))); + ExpectNext(&reader, Constant_2x3(T(2))); + ExpectNext(&reader, Constant_2x3(T(3))); + 
ExpectNext(&reader, Constant_2x3(T(0))); + ExpectNext(&reader, Constant_2x3(T(1))); + ExpectNext(&reader, Constant_2x3(T(2))); + ExpectNext(&reader, Constant_2x3(T(3))); EXPECT_TRUE(reader.Valid()); reader.Next(); EXPECT_FALSE(reader.Valid()); @@ -263,20 +263,20 @@ template void TestNonStandardShapes() { { BundleWriter writer(Env::Default(), Prefix("nonstandard")); - TF_EXPECT_OK(writer.Add("scalar", Constant(0, TensorShape()))); + TF_EXPECT_OK(writer.Add("scalar", Constant(T(0), TensorShape()))); TF_EXPECT_OK( - writer.Add("non_standard0", Constant(0, TensorShape({0, 1618})))); + writer.Add("non_standard0", Constant(T(0), TensorShape({0, 1618})))); TF_EXPECT_OK( - writer.Add("non_standard1", Constant(0, TensorShape({16, 0, 18})))); + writer.Add("non_standard1", Constant(T(0), TensorShape({16, 0, 18})))); TF_ASSERT_OK(writer.Finish()); } { BundleReader reader(Env::Default(), Prefix("nonstandard")); TF_ASSERT_OK(reader.status()); - Expect(&reader, "scalar", Constant(0, TensorShape())); - Expect(&reader, "non_standard0", Constant(0, TensorShape({0, 1618}))); + Expect(&reader, "scalar", Constant(T(0), TensorShape())); + Expect(&reader, "non_standard0", Constant(T(0), TensorShape({0, 1618}))); Expect(&reader, "non_standard1", - Constant(0, TensorShape({16, 0, 18}))); + Constant(T(0), TensorShape({16, 0, 18}))); } } @@ -318,6 +318,7 @@ TEST(TensorBundleTest, Basic) { TestBasic(); TestBasic(); TestBasic(); + TestBasic(); } TEST(TensorBundleTest, PartitionedVariables) { @@ -461,6 +462,7 @@ TEST(TensorBundleTest, NonStandardShapes) { TestNonStandardShapes(); TestNonStandardShapes(); TestNonStandardShapes(); + TestNonStandardShapes(); } TEST(TensorBundleTest, StringTensorsOldFormat) { diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc index f331973f5ce..5dbd8ef318f 100644 --- a/tensorflow/core/util/tensor_format.cc +++ b/tensorflow/core/util/tensor_format.cc @@ -63,6 +63,8 @@ string ToString(FilterTensorFormat format) { return "HWIO"; case FORMAT_OIHW: return "OIHW"; + case FORMAT_OHWI: + return "OHWI"; case FORMAT_OIHW_VECT_I: return "OIHW_VECT_I"; default: diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index 643e14e0b56..82af5c545f7 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -80,6 +80,9 @@ enum FilterTensorFormat { // FORMAT_OIHW often improves performance on GPUs. FORMAT_OIHW = 1, + // FORMAT_OHWI used by cuDNN for NHWC convolutions. + FORMAT_OHWI = 2, + // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C // data format. It is laid out in the same order as OIHW, except that the size @@ -88,7 +91,7 @@ enum FilterTensorFormat { // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format. // A pre-condition of this format is that I must be a multiple of 4. - FORMAT_OIHW_VECT_I = 2, + FORMAT_OIHW_VECT_I = 3, }; // Parse tensor format from the given string. 
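Two quick illustrations of the filter-format changes above, assuming the ToString overload shown in tensor_format.cc: FORMAT_OHWI now maps to "OHWI", and FORMAT_OIHW_VECT_I keeps its string but moves to enum value 3. The vectorized layout itself is unchanged; per the comment, an OIHW filter of shape [O=64, I=32, H=3, W=3] is stored as [64, 32/4, 3, 3, 4] == [64, 8, 3, 3, 4] in OIHW_VECT_I.

```cpp
// Illustrative checks only, not part of the patch.
CHECK_EQ(ToString(FORMAT_OHWI), "OHWI");                // new enum value 2
CHECK_EQ(ToString(FORMAT_OIHW_VECT_I), "OIHW_VECT_I");  // renumbered from 2 to 3
```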
diff --git a/tensorflow/core/util/tensor_slice_set_test.cc b/tensorflow/core/util/tensor_slice_set_test.cc index 38ad6adf51a..8e12f7c7874 100644 --- a/tensorflow/core/util/tensor_slice_set_test.cc +++ b/tensorflow/core/util/tensor_slice_set_test.cc @@ -218,10 +218,18 @@ TEST(TensorSliceSetTest, QueryMetaTwoD) { std::vector> results; EXPECT_TRUE(tss.QueryMeta(s, &results)); EXPECT_EQ(2, results.size()); - EXPECT_EQ("2,2:0,3", results[0].first.DebugString()); - EXPECT_EQ("slice_2", results[0].second); - EXPECT_EQ("0,2:-", results[1].first.DebugString()); - EXPECT_EQ("slice_1", results[1].second); + // Allow results to be returned in either order + if (results[0].second == "slice_2") { + EXPECT_EQ("2,2:0,3", results[0].first.DebugString()); + EXPECT_EQ("slice_2", results[0].second); + EXPECT_EQ("0,2:-", results[1].first.DebugString()); + EXPECT_EQ("slice_1", results[1].second); + } else { + EXPECT_EQ("0,2:-", results[0].first.DebugString()); + EXPECT_EQ("slice_1", results[0].second); + EXPECT_EQ("2,2:0,3", results[1].first.DebugString()); + EXPECT_EQ("slice_2", results[1].second); + } } // Slice #4 includes the hole and so there is no match diff --git a/tensorflow/core/util/tensor_slice_util.h b/tensorflow/core/util/tensor_slice_util.h index 8f5a6f1d935..6d478349a78 100644 --- a/tensorflow/core/util/tensor_slice_util.h +++ b/tensorflow/core/util/tensor_slice_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ -#define TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ +#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_UTIL_H_ +#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_UTIL_H_ #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" @@ -188,4 +188,4 @@ static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape, } // namespace tensorflow -#endif // TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_ +#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_UTIL_H_ diff --git a/tensorflow/core/util/test_log.proto b/tensorflow/core/util/test_log.proto index 8ea59e10680..ddb0599388f 100644 --- a/tensorflow/core/util/test_log.proto +++ b/tensorflow/core/util/test_log.proto @@ -2,6 +2,7 @@ syntax = "proto3"; import "google/protobuf/any.proto"; +import "google/protobuf/wrappers.proto"; option cc_enable_arenas = true; option java_outer_classname = "TestLogProtos"; @@ -17,6 +18,20 @@ message EntryValue { } }; +message MetricEntry { + // Metric name + string name = 1; + + // Metric value + double value = 2; + + // The minimum acceptable value for the metric if specified + google.protobuf.DoubleValue min_value = 3; + + // The maximum acceptable value for the metric if specified + google.protobuf.DoubleValue max_value = 4; +} + // Each unit test or benchmark in a test or benchmark run provides // some set of information. Here we provide some reasonable keys // one would expect to see, with optional key/value pairs for things @@ -43,6 +58,10 @@ message BenchmarkEntry { // Generic map from result key to value. map extras = 6; + + // Metric name, value and expected range. This can include accuracy metrics + // typically used to determine whether the accuracy test has passed + repeated MetricEntry metrics = 7; }; message BenchmarkEntries { @@ -193,4 +212,8 @@ message TestResults { // * presubmit: results from oneshot requests. // * culprit: results from culprit finder rerun. 
string run_mode = 11; + + // TensorFlow version this benchmark runs against. + // This can be either set to full version or just the major version. + string tf_version = 12; }; diff --git a/tensorflow/examples/adding_an_op/BUILD b/tensorflow/examples/adding_an_op/BUILD index a4d6f204cd9..f65f3a5b933 100644 --- a/tensorflow/examples/adding_an_op/BUILD +++ b/tensorflow/examples/adding_an_op/BUILD @@ -68,8 +68,12 @@ py_test( name = "zero_out_1_test", size = "small", srcs = ["zero_out_1_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", - tags = ["notap"], + tags = [ + "no_pip", + "notap", + ], deps = [ ":zero_out_op_1", "//tensorflow:tensorflow_py", @@ -80,8 +84,12 @@ py_test( name = "zero_out_2_test", size = "small", srcs = ["zero_out_2_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", - tags = ["notap"], + tags = [ + "no_pip", + "notap", + ], deps = [ ":zero_out_grad_2", ":zero_out_op_2", @@ -93,8 +101,12 @@ py_test( name = "zero_out_3_test", size = "small", srcs = ["zero_out_3_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", - tags = ["notap"], + tags = [ + "no_pip", + "notap", + ], deps = [ ":zero_out_op_3", "//tensorflow:tensorflow_py", @@ -120,8 +132,12 @@ py_test( size = "small", srcs = ["cuda_op_test.py"], exec_compatible_with = tf_exec_compatible_with({"tags": tf_cuda_tests_tags()}), + python_version = "PY2", srcs_version = "PY2AND3", - tags = tf_cuda_tests_tags() + ["notap"], + tags = tf_cuda_tests_tags() + [ + "notap", + "no_pip", + ], deps = [ ":cuda_op", "//tensorflow:tensorflow_py", @@ -132,6 +148,7 @@ py_test( name = "fact_test", size = "small", srcs = ["fact_test.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = ["//tensorflow:tensorflow_py"], ) diff --git a/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc b/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc index 721da2a0bdb..1dcf23e4d03 100644 --- a/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc +++ b/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc @@ -16,7 +16,8 @@ limitations under the License. #if GOOGLE_CUDA #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/util/cuda_launch_config.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" __global__ void AddOneKernel(const int* in, const int N, int* out) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; diff --git a/tensorflow/examples/android/build.gradle b/tensorflow/examples/android/build.gradle index f771530eb9d..ba37ca841ac 100644 --- a/tensorflow/examples/android/build.gradle +++ b/tensorflow/examples/android/build.gradle @@ -94,7 +94,7 @@ android { } externalNativeBuild { cmake { - arguments '-DANDROID_TOOLCHAIN=gcc', '-DANDROID_STL=gnustl_static' + arguments '-DANDROID_STL=c++_static' } } } diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/examples/android/jni/object_tracking/jni_utils.h index 06048ecfd36..5f622a2e65f 100644 --- a/tensorflow/examples/android/jni/object_tracking/jni_utils.h +++ b/tensorflow/examples/android/jni/object_tracking/jni_utils.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_ +#include #include #include "tensorflow/examples/android/jni/object_tracking/utils.h" diff --git a/tensorflow/examples/android/jni/object_tracking/sprite.h b/tensorflow/examples/android/jni/object_tracking/sprite.h index b54a68458f1..964f1c30bfa 100755 --- a/tensorflow/examples/android/jni/object_tracking/sprite.h +++ b/tensorflow/examples/android/jni/object_tracking/sprite.h @@ -16,16 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ +#ifdef __RENDER_OPENGL__ + #include #include #include "tensorflow/examples/android/jni/object_tracking/image-inl.h" #include "tensorflow/examples/android/jni/object_tracking/image.h" -#ifndef __RENDER_OPENGL__ -#error sprite.h should not included if OpenGL is not enabled by platform.h -#endif - namespace tf_tracking { // This class encapsulates the logic necessary to load an render image data @@ -199,4 +197,6 @@ class Sprite { } // namespace tf_tracking +#endif // __RENDER_OPENGL__ + #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_ diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java index 7882d87c1cf..f778f3de425 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java @@ -157,7 +157,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE); cropSize = TF_OD_API_INPUT_SIZE; } catch (final IOException e) { - LOGGER.e("Exception initializing classifier!", e); + LOGGER.e(e, "Exception initializing classifier!"); Toast toast = Toast.makeText( getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT); diff --git a/tensorflow/examples/autograph/integration_tests/BUILD b/tensorflow/examples/autograph/integration_tests/BUILD deleted file mode 100644 index 2a4a0f75e7a..00000000000 --- a/tensorflow/examples/autograph/integration_tests/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "py_test") - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -py_test( - name = "keras_test", - srcs = [ - "keras_test.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_test( - name = "list_literals_test", - srcs = [ - "list_literals_test.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - ], -) diff --git a/tensorflow/examples/autograph/integration_tests/keras_test.py b/tensorflow/examples/autograph/integration_tests/keras_test.py deleted file mode 100644 index 72b62f1ad4d..00000000000 --- a/tensorflow/examples/autograph/integration_tests/keras_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras integration tests.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -from tensorflow.python import autograph -from tensorflow.python.framework import test_util - - -class MinimalKeras(tf.keras.Model): - - def call(self, x): - return x * 3 - - -class ModelWithStaticConditional(object): - - def __init__(self, initial): - self.initial = initial - if self.initial: - self.h = 15 - - @autograph.convert() - def call(self): - x = 10 - if self.initial: - x += self.h - return x - - -class BasicBlock(tf.keras.Model): - - def __init__(self): - super(BasicBlock, self).__init__() - self.conv1 = tf.keras.layers.Conv2D(8, 3) - self.pool = tf.keras.layers.GlobalAveragePooling2D() - self.dense = tf.keras.layers.Dense(3) - - def call(self, x): - x = self.conv1(x) - x = self.pool(x) - x = self.dense(x) - return x - - -class CompoundModel(tf.keras.Model): - - def __init__(self): - super(CompoundModel, self).__init__() - self.block = BasicBlock() - - @autograph.convert(recursive=True) - def call(self, x): - x = self.block(x) # pylint: disable=not-callable - return x - - -class KerasTest(tf.test.TestCase): - - def test_basic(self): - MinimalKeras() - - def test_conditional_attributes_False(self): - model = ModelWithStaticConditional(False) - self.assertEqual(model.call(), 10) - - def test_conditional_attributes_True(self): - model = ModelWithStaticConditional(True) - self.assertEqual(model.call(), 25) - - @test_util.run_deprecated_v1 - def test_recursive_true(self): - with tf.Graph().as_default(): - model = CompoundModel() - model.build(tf.TensorShape((None, 10, 10, 1))) - init = tf.global_variables_initializer() - - with tf.Session() as sess: - self.evaluate(init) - sample_input = tf.random_uniform((1, 10, 10, 1)) - output = model(sample_input) # pylint: disable=not-callable - self.assertEqual(self.evaluate(output).shape, (1, 3)) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/examples/get_started/regression/BUILD b/tensorflow/examples/get_started/regression/BUILD index bee94d7d90f..cdef25ce495 100644 --- a/tensorflow/examples/get_started/regression/BUILD +++ b/tensorflow/examples/get_started/regression/BUILD @@ -9,10 +9,10 @@ py_test( "custom_regression.py", "dnn_regression.py", "imports85.py", - "linear_regression.py", "linear_regression_categorical.py", "test.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "manual", diff --git a/tensorflow/examples/get_started/regression/test.py b/tensorflow/examples/get_started/regression/test.py index bb4db6700b8..1c37e4a671b 100644 --- a/tensorflow/examples/get_started/regression/test.py +++ b/tensorflow/examples/get_started/regression/test.py @@ -32,7 +32,6 @@ sys.modules["imports85"] = imports85 import tensorflow.data as data import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression -import tensorflow.examples.get_started.regression.linear_regression as linear_regression import 
tensorflow.examples.get_started.regression.linear_regression_categorical as linear_regression_categorical import tensorflow.examples.get_started.regression.custom_regression as custom_regression @@ -46,7 +45,8 @@ FOUR_LINES = "\n".join([ "1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500", "2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950", "2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450", - "2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250",]) + "2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250", +]) # pylint: enable=line-too-long @@ -54,8 +54,8 @@ FOUR_LINES = "\n".join([ def four_lines_dataframe(): text = StringIO(FOUR_LINES) - return pd.read_csv(text, names=imports85.types.keys(), - dtype=imports85.types, na_values="?") + return pd.read_csv( + text, names=imports85.types.keys(), dtype=imports85.types, na_values="?") def four_lines_dataset(*args, **kwargs): @@ -66,22 +66,13 @@ def four_lines_dataset(*args, **kwargs): class RegressionTest(googletest.TestCase): """Test the regression examples in this directory.""" - @test.mock.patch.dict(data.__dict__, - {"TextLineDataset": four_lines_dataset}) - @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)}) - @test.mock.patch.dict(linear_regression.__dict__, {"STEPS": 1}) - def test_linear_regression(self): - linear_regression.main([""]) - - @test.mock.patch.dict(data.__dict__, - {"TextLineDataset": four_lines_dataset}) + @test.mock.patch.dict(data.__dict__, {"TextLineDataset": four_lines_dataset}) @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)}) @test.mock.patch.dict(linear_regression_categorical.__dict__, {"STEPS": 1}) def test_linear_regression_categorical(self): linear_regression_categorical.main([""]) - @test.mock.patch.dict(data.__dict__, - {"TextLineDataset": four_lines_dataset}) + @test.mock.patch.dict(data.__dict__, {"TextLineDataset": four_lines_dataset}) @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)}) @test.mock.patch.dict(dnn_regression.__dict__, {"STEPS": 1}) def test_dnn_regression(self): diff --git a/tensorflow/examples/how_tos/reading_data/BUILD b/tensorflow/examples/how_tos/reading_data/BUILD index e846b291467..b8a6ee1026a 100644 --- a/tensorflow/examples/how_tos/reading_data/BUILD +++ b/tensorflow/examples/how_tos/reading_data/BUILD @@ -10,6 +10,7 @@ exports_files(["LICENSE"]) py_binary( name = "convert_to_records", srcs = ["convert_to_records.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", @@ -22,6 +23,7 @@ py_binary( srcs = [ "fully_connected_reader.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", diff --git a/tensorflow/examples/ios/camera/ios_image_load.h b/tensorflow/examples/ios/camera/ios_image_load.h index 991568751e9..8f2da481f46 100644 --- a/tensorflow/examples/ios/camera/ios_image_load.h +++ b/tensorflow/examples/ios/camera/ios_image_load.h @@ -20,8 +20,8 @@ #include "third_party/tensorflow/core/framework/types.h" std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); + int* out_width, + int* out_height, + int* 
out_channels); #endif // TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/examples/label_image/BUILD b/tensorflow/examples/label_image/BUILD index c50fd93d039..cc73163f3b5 100644 --- a/tensorflow/examples/label_image/BUILD +++ b/tensorflow/examples/label_image/BUILD @@ -57,6 +57,7 @@ py_binary( name = "label_image_py", srcs = ["label_image.py"], main = "label_image.py", + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", diff --git a/tensorflow/examples/learn/BUILD b/tensorflow/examples/learn/BUILD index a22d55e5af7..d98fe96f47a 100644 --- a/tensorflow/examples/learn/BUILD +++ b/tensorflow/examples/learn/BUILD @@ -12,6 +12,7 @@ exports_files(["LICENSE"]) py_binary( name = "iris_custom_decay_dnn", srcs = ["iris_custom_decay_dnn.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = ["//tensorflow:tensorflow_py"], ) @@ -19,6 +20,7 @@ py_binary( py_binary( name = "iris_custom_model", srcs = ["iris_custom_model.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = ["//tensorflow:tensorflow_py"], ) diff --git a/tensorflow/examples/multibox_detector/main.cc b/tensorflow/examples/multibox_detector/main.cc index 96ea525a4e7..82552a71740 100644 --- a/tensorflow/examples/multibox_detector/main.cc +++ b/tensorflow/examples/multibox_detector/main.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -228,7 +229,9 @@ void DecodeLocation(const float* encoded_location, const float* box_priors, } } -float DecodeScore(float encoded_score) { return 1 / (1 + exp(-encoded_score)); } +float DecodeScore(float encoded_score) { + return 1 / (1 + std::exp(-encoded_score)); +} void DrawBox(const int image_width, const int image_height, int left, int top, int right, int bottom, tensorflow::TTypes::Flat* image) { diff --git a/tensorflow/examples/saved_model/integration_tests/BUILD b/tensorflow/examples/saved_model/integration_tests/BUILD index 08415936879..5ade3c2dbea 100644 --- a/tensorflow/examples/saved_model/integration_tests/BUILD +++ b/tensorflow/examples/saved_model/integration_tests/BUILD @@ -2,64 +2,25 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") - -py_binary( - name = "export_text_rnn_model", - srcs = ["export_text_rnn_model.py"], - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "use_text_rnn_model", - srcs = ["use_text_rnn_model.py"], - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "export_rnn_cell", - srcs = ["export_rnn_cell.py"], - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "use_rnn_cell", - srcs = ["use_rnn_cell.py"], - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "export_simple_text_embedding", - srcs = ["export_simple_text_embedding.py"], - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "use_model_in_sequential_keras", - srcs = ["use_model_in_sequential_keras.py"], - visibility = ["//tensorflow:internal"], - deps = [ - ":util", - "//tensorflow:tensorflow_py", - ], -) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( - name = "util", - srcs = ["util.py"], + name = "integration_scripts", + srcs = [ + "export_mnist_cnn.py", + "export_rnn_cell.py", + "export_simple_text_embedding.py", + "export_text_rnn_model.py", + "integration_scripts.py", + "use_mnist_cnn.py", + 
"use_model_in_sequential_keras.py", + "use_rnn_cell.py", + "use_text_embedding_in_dataset.py", + "use_text_rnn_model.py", + ], + visibility = ["//tensorflow:internal"], deps = [ + ":mnist_util", "//tensorflow:tensorflow_py", ], ) @@ -67,54 +28,38 @@ py_library( py_library( name = "mnist_util", srcs = ["mnist_util.py"], - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "export_mnist_cnn", - srcs = ["export_mnist_cnn.py"], - deps = [ - ":mnist_util", - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "use_mnist_cnn", - srcs = ["use_mnist_cnn.py"], visibility = ["//tensorflow:internal"], deps = [ - ":mnist_util", - ":util", "//tensorflow:tensorflow_py", ], ) -py_test( +cuda_py_test( name = "saved_model_test", srcs = [ "saved_model_test.py", ], - data = [ - ":export_mnist_cnn", - ":export_rnn_cell", - ":export_simple_text_embedding", - ":export_text_rnn_model", - ":use_mnist_cnn", - ":use_model_in_sequential_keras", - ":use_rnn_cell", - ":use_text_rnn_model", + additional_deps = [ + ":integration_scripts", + "//tensorflow:tensorflow_py", ], shard_count = 4, - srcs_version = "PY2AND3", tags = [ + "no_pip", # b/131697937 and b/132196869 "noasan", # forge input size exceeded "nomsan", # forge input size exceeded "notsan", # forge input size exceeded ], - deps = [ - "//tensorflow:tensorflow_py", - ], +) + +# b/132234211: Target added to support internal test target that runs the test +# in an environment that has the extra dependencies required to test integration +# with non core tensorflow packages. +py_library( + name = "saved_model_test_lib", + srcs = [ + "saved_model_test.py", + ], + visibility = ["//tensorflow:internal"], + deps = [":integration_scripts"], ) diff --git a/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py index 7c1a356e661..1d36bc234ae 100644 --- a/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py +++ b/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py @@ -185,5 +185,4 @@ def main(argv): if __name__ == '__main__': - tf.enable_v2_behavior() app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/export_rnn_cell.py b/tensorflow/examples/saved_model/integration_tests/export_rnn_cell.py index bac1d4c35a1..876e3004bca 100644 --- a/tensorflow/examples/saved_model/integration_tests/export_rnn_cell.py +++ b/tensorflow/examples/saved_model/integration_tests/export_rnn_cell.py @@ -60,5 +60,4 @@ def main(argv): if __name__ == "__main__": - tf.enable_v2_behavior() app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py b/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py index 00829aab9a7..b8e76e895fc 100644 --- a/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py +++ b/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py @@ -72,7 +72,7 @@ class TextEmbeddingModel(tf.train.Checkpoint): normalized_sentences = tf.strings.regex_replace( input=sentences, pattern=r"\pP", rewrite="") normalized_sentences = tf.reshape(normalized_sentences, [-1]) - sparse_tokens = tf.string_split(normalized_sentences, " ") + sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse() # Deal with a corner case: there is one empty sentence. 
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant("")) @@ -102,5 +102,4 @@ def main(argv): if __name__ == "__main__": - tf.enable_v2_behavior() app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/export_text_rnn_model.py b/tensorflow/examples/saved_model/integration_tests/export_text_rnn_model.py index af89e4e9f19..9b9f5925588 100644 --- a/tensorflow/examples/saved_model/integration_tests/export_text_rnn_model.py +++ b/tensorflow/examples/saved_model/integration_tests/export_text_rnn_model.py @@ -50,7 +50,7 @@ class TextRnnModel(tf.train.Checkpoint): # splitting on spaces. normalized_sentences = tf.strings.regex_replace( input=sentences, pattern=r"\pP", rewrite="") - sparse_tokens = tf.string_split(normalized_sentences, " ") + sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse() # Deal with a corner case: there is one empty sentence. sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant("")) @@ -190,5 +190,4 @@ def main(argv): if __name__ == "__main__": - tf.enable_v2_behavior() app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/integration_scripts.py b/tensorflow/examples/saved_model/integration_tests/integration_scripts.py new file mode 100644 index 00000000000..0db91facd65 --- /dev/null +++ b/tensorflow/examples/saved_model/integration_tests/integration_scripts.py @@ -0,0 +1,65 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to write SavedModel integration tests. + +SavedModel testing requires isolation between the process that creates and +consumes it. This file helps do that by relaunching the same binary that +calls `assertCommandSucceeded` with an environment flag indicating what source +file to execute. That binary must start by calling `MaybeRunScriptInstead`. + +This makes it possible to wire the tests into existing build systems without +relying on data dependencies, which keeps the binary size fixed and allows +interop with GPU tests. 
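The relaunch pattern described in this docstring is easiest to see end to end. The sketch below mirrors how `saved_model_test.py` later in this patch drives these helpers; the `export_text_rnn_model` and `use_text_rnn_model` script names and the `export_dir`/`model_dir` flags come from this patch, while the test class itself is only illustrative:

```python
# A test binary built on the helpers above: each assertCommandSucceeded call
# relaunches this same binary in a fresh process with SCRIPT_NAME set, so the
# exporter and the consumer never share a TensorFlow runtime.
import tensorflow.compat.v2 as tf

from tensorflow.examples.saved_model.integration_tests import integration_scripts


class ExampleSavedModelTest(integration_scripts.TestCase):

  def test_export_then_use(self):
    export_dir = self.get_temp_dir()
    self.assertCommandSucceeded("export_text_rnn_model", export_dir=export_dir)
    self.assertCommandSucceeded("use_text_rnn_model", model_dir=export_dir)


if __name__ == "__main__":
  # In the child process SCRIPT_NAME is set, so this runs the named script's
  # main() and exits; in the parent it is a no-op and the tests run below.
  integration_scripts.MaybeRunScriptInstead()
  tf.test.main()
```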
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import os +import subprocess +import sys + +from absl import app +import tensorflow.compat.v2 as tf + +from tensorflow.python.platform import tf_logging as logging + + +class TestCase(tf.test.TestCase): + """Base class to write SavedModel integration tests.""" + + def assertCommandSucceeded(self, script_name, **flags): + """Runs an integration test script with given flags.""" + run_script = sys.argv[0] + if run_script.endswith(".py"): + command_parts = [sys.executable, run_script] + else: + command_parts = [run_script] + for flag_key, flag_value in flags.items(): + command_parts.append("--%s=%s" % (flag_key, flag_value)) + env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name) + logging.info("Running: %s with environment flags %s" % (command_parts, env)) + subprocess.check_call(command_parts, env=dict(os.environ, **env)) + + +def MaybeRunScriptInstead(): + if "SCRIPT_NAME" in os.environ: + # Append current path to import path and execute `SCRIPT_NAME` main. + sys.path.extend([os.path.dirname(__file__)]) + module_name = os.environ["SCRIPT_NAME"] + retval = app.run(importlib.import_module(module_name).main) + sys.exit(retval) diff --git a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py index 6ec387e2550..7cc8fde6167 100644 --- a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py +++ b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py @@ -18,69 +18,81 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import subprocess - import tensorflow.compat.v2 as tf -from tensorflow.python.framework import test_util -from tensorflow.python.platform import resource_loader -from tensorflow.python.platform import tf_logging as logging +from tensorflow.examples.saved_model.integration_tests import integration_scripts -class SavedModelTest(tf.test.TestCase): +class SavedModelTest(integration_scripts.TestCase): - def assertCommandSucceeded(self, binary, **flags): - command_parts = [binary] - for flag_key, flag_value in flags.items(): - command_parts.append("--%s=%s" % (flag_key, flag_value)) + def __init__(self, method_name="runTest", has_extra_deps=False): + super(SavedModelTest, self).__init__(method_name) + self.has_extra_deps = has_extra_deps - logging.info("Running: %s" % command_parts) - subprocess.check_call( - command_parts, env=dict(os.environ, TF2_BEHAVIOR="enabled")) + def skipIfMissingExtraDeps(self): + """Skip test if it requires extra dependencies. + + b/132234211: The extra dependencies are not available in all environments + that run the tests, e.g. "tensorflow_hub" is not available from tests + within "tensorflow" alone. Those tests are instead run by another + internal test target. 
+ """ + if not self.has_extra_deps: + self.skipTest("Missing extra dependencies") - @test_util.run_v2_only def test_text_rnn(self): export_dir = self.get_temp_dir() - export_binary = resource_loader.get_path_to_datafile( - "export_text_rnn_model") - self.assertCommandSucceeded(export_binary, export_dir=export_dir) + self.assertCommandSucceeded("export_text_rnn_model", export_dir=export_dir) + self.assertCommandSucceeded("use_text_rnn_model", model_dir=export_dir) - use_binary = resource_loader.get_path_to_datafile("use_text_rnn_model") - self.assertCommandSucceeded(use_binary, model_dir=export_dir) - - @test_util.run_v2_only def test_rnn_cell(self): export_dir = self.get_temp_dir() - export_binary = resource_loader.get_path_to_datafile( - "export_rnn_cell") - self.assertCommandSucceeded(export_binary, export_dir=export_dir) + self.assertCommandSucceeded("export_rnn_cell", export_dir=export_dir) + self.assertCommandSucceeded("use_rnn_cell", model_dir=export_dir) - use_binary = resource_loader.get_path_to_datafile("use_rnn_cell") - self.assertCommandSucceeded(use_binary, model_dir=export_dir) - - @test_util.run_v2_only def test_text_embedding_in_sequential_keras(self): + self.skipIfMissingExtraDeps() export_dir = self.get_temp_dir() - export_binary = resource_loader.get_path_to_datafile( - "export_simple_text_embedding") - self.assertCommandSucceeded(export_binary, export_dir=export_dir) + self.assertCommandSucceeded( + "export_simple_text_embedding", export_dir=export_dir) + self.assertCommandSucceeded( + "use_model_in_sequential_keras", model_dir=export_dir) - use_binary = resource_loader.get_path_to_datafile( - "use_model_in_sequential_keras") - self.assertCommandSucceeded(use_binary, model_dir=export_dir) + def test_text_embedding_in_dataset(self): + if tf.test.is_gpu_available(): + self.skipTest("b/132156097 - fails if there is a gpu available") + + export_dir = self.get_temp_dir() + self.assertCommandSucceeded( + "export_simple_text_embedding", export_dir=export_dir) + self.assertCommandSucceeded( + "use_text_embedding_in_dataset", model_dir=export_dir) - @test_util.run_v2_only def test_mnist_cnn(self): + self.skipIfMissingExtraDeps() export_dir = self.get_temp_dir() - export_binary = resource_loader.get_path_to_datafile("export_mnist_cnn") - self.assertCommandSucceeded(export_binary, export_dir=export_dir, - fast_test_mode="true") + self.assertCommandSucceeded( + "export_mnist_cnn", export_dir=export_dir, fast_test_mode="true") + self.assertCommandSucceeded( + "use_mnist_cnn", export_dir=export_dir, fast_test_mode="true") + + def test_mnist_cnn_with_mirrored_strategy(self): + self.skipIfMissingExtraDeps() + self.skipTest( + "b/129134185 - saved model and distribution strategy integration") + export_dir = self.get_temp_dir() + self.assertCommandSucceeded( + "export_mnist_cnn", + export_dir=export_dir, + fast_test_mode="true") + self.assertCommandSucceeded( + "use_mnist_cnn", + export_dir=export_dir, + fast_test_mode="true", + use_mirrored_strategy=True, + ) - use_binary = resource_loader.get_path_to_datafile("use_mnist_cnn") - self.assertCommandSucceeded(use_binary, export_dir=export_dir, - fast_test_mode="true") if __name__ == "__main__": - tf.enable_v2_behavior() + integration_scripts.MaybeRunScriptInstead() tf.test.main() diff --git a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py index 3f5455e47ba..957091f0e86 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py +++ 
b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py @@ -29,9 +29,9 @@ from __future__ import print_function from absl import app from absl import flags import tensorflow.compat.v2 as tf +import tensorflow_hub as hub from tensorflow.examples.saved_model.integration_tests import mnist_util -from tensorflow.examples.saved_model.integration_tests import util FLAGS = flags.FLAGS @@ -57,6 +57,30 @@ flags.DEFINE_bool( flags.DEFINE_bool( 'fast_test_mode', False, 'Shortcut training for running in unit tests.') +flags.DEFINE_bool( + 'use_mirrored_strategy', False, + 'Whether to use mirrored distribution strategy.') + + +def make_feature_extractor(saved_model_path, trainable, + regularization_loss_multiplier): + """Load a pre-trained feature extractor and wrap it for use in Keras.""" + obj = tf.saved_model.load(saved_model_path) + + # Optional: scale regularization losses to target problem. + if regularization_loss_multiplier: + def _scale_one_loss(l): # Separate def avoids lambda capture of loop var. + f = tf.function(lambda: tf.multiply(regularization_loss_multiplier, l())) + _ = f.get_concrete_function() + return f + obj.regularization_losses = [_scale_one_loss(l) + for l in obj.regularization_losses] + + arguments = {} + if FLAGS.dropout_rate is not None: + arguments['dropout_rate'] = FLAGS.dropout_rate + + return hub.KerasLayer(obj, trainable=trainable, arguments=arguments) def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5): @@ -70,53 +94,38 @@ def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5): return tf.keras.Model(inputs=inp, outputs=net) -def scale_regularization_losses(obj, multiplier): - """Scales obj.regularization_losses by multiplier if not None.""" - if multiplier is None: return - def _scale_one_loss(l): # Separate def avoids lambda capture of loop var. - f = tf.function(lambda: tf.multiply(multiplier, l())) - _ = f.get_concrete_function() - return f - obj.regularization_losses = [_scale_one_loss(l) - for l in obj.regularization_losses] - - def main(argv): del argv - # Load a pre-trained feature extractor and wrap it for use in Keras. - obj = tf.saved_model.load(FLAGS.export_dir) - scale_regularization_losses(obj, FLAGS.regularization_loss_multiplier) - arguments = {} - if FLAGS.dropout_rate is not None: - arguments['dropout_rate'] = FLAGS.dropout_rate - feature_extractor = util.CustomLayer(obj, output_shape=[10], - trainable=FLAGS.retrain, - arguments=arguments) + if FLAGS.use_mirrored_strategy: + strategy = tf.distribute.MirroredStrategy() + else: + strategy = tf.distribute.get_strategy() - # Build a classifier with it. - model = make_classifier(feature_extractor) + with strategy.scope(): + feature_extractor = make_feature_extractor( + FLAGS.export_dir, + FLAGS.retrain, + FLAGS.regularization_loss_multiplier) + model = make_classifier(feature_extractor) + + model.compile(loss=tf.keras.losses.categorical_crossentropy, + optimizer=tf.keras.optimizers.SGD(), + metrics=['accuracy']) # Train the classifier (possibly on a different dataset). (x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data( use_fashion_mnist=FLAGS.use_fashion_mnist, fake_tiny_data=FLAGS.fast_test_mode) - model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=tf.keras.optimizers.SGD(), - metrics=['accuracy'], - # TODO(arnoegw): Remove after investigating huge allocs. - run_eagerly=True) print('Training on %s with %d trainable and %d untrainable variables.' 
% ('Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST', len(model.trainable_variables), len(model.non_trainable_variables))) model.fit(x_train, y_train, batch_size=128, epochs=FLAGS.epochs, - steps_per_epoch=3, verbose=1, validation_data=(x_test, y_test)) if __name__ == '__main__': - tf.enable_v2_behavior() app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py b/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py index 2b6efb76f74..2446ff91fb0 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py +++ b/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py @@ -23,7 +23,7 @@ from absl import flags import numpy as np import tensorflow.compat.v2 as tf -from tensorflow.examples.saved_model.integration_tests import util +import tensorflow_hub as hub FLAGS = flags.FLAGS @@ -42,7 +42,8 @@ def train(fine_tuning): l = tf.keras.layers model = tf.keras.Sequential() model.add(l.Reshape((), batch_input_shape=[None, 1], dtype=tf.string)) - model.add(util.CustomLayer(module, output_shape=[10], trainable=fine_tuning)) + # TODO(b/124219898): output_shape should be optional. + model.add(hub.KerasLayer(module, output_shape=[10], trainable=fine_tuning)) model.add(l.Dense(100, activation="relu")) model.add(l.Dense(50, activation="relu")) model.add(l.Dense(1, activation="sigmoid")) @@ -65,5 +66,4 @@ def main(argv): if __name__ == "__main__": - tf.enable_v2_behavior() app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py b/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py index 14393795832..e9f251376ef 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py +++ b/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py @@ -41,5 +41,4 @@ def main(argv): if __name__ == "__main__": - tf.enable_v2_behavior() app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py b/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py new file mode 100644 index 00000000000..a21922219ce --- /dev/null +++ b/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py @@ -0,0 +1,67 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
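Every consumer script above swaps the deleted `util.CustomLayer` for `hub.KerasLayer`, which wraps the object returned by `tf.saved_model.load` as a Keras layer that can be trained through. A condensed sketch of that pattern, assuming a hypothetical text-embedding SavedModel at `/tmp/text_model` and illustrative layer sizes:

```python
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub

# Load a reusable SavedModel and wrap it so Keras can train through it.
# "/tmp/text_model" stands in for a model produced by one of the export_*
# scripts in this directory.
embed = tf.saved_model.load("/tmp/text_model")

model = tf.keras.Sequential([
    # Map a [batch, 1] string tensor to scalar strings, as the scripts do.
    tf.keras.layers.Reshape((), batch_input_shape=[None, 1], dtype=tf.string),
    # output_shape is still required for now (see TODO(b/124219898) above).
    hub.KerasLayer(embed, output_shape=[10], trainable=True),
    tf.keras.layers.Dense(50, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam",
              loss="binary_crossentropy",
              metrics=["accuracy"])
```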
+# ============================================================================== +"""Load and use text embedding module in a Dataset map function.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags + +import numpy as np +import tensorflow.compat.v2 as tf + +FLAGS = flags.FLAGS + +flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.") + + +def train(): + """Build a Keras model and train with mock data.""" + module = tf.saved_model.load(FLAGS.model_dir) + def _map_fn(features, labels): + features = tf.expand_dims(features, 0) + features = module(features) + features = tf.squeeze(features, 0) + return features, labels + + features = np.array(["my first sentence", "my second sentence"]) + labels = np.array([1, 0]) + dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(_map_fn) + + # Create the sequential keras model. + l = tf.keras.layers + model = tf.keras.Sequential() + model.add(l.Dense(10, activation="relu")) + model.add(l.Dense(1, activation="sigmoid")) + + model.compile( + optimizer="adam", + loss="binary_crossentropy", + metrics=["accuracy"]) + + model.fit_generator(generator=dataset.batch(10), epochs=5) + + +def main(argv): + del argv + + train() + + +if __name__ == "__main__": + tf.enable_v2_behavior() + app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py b/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py index 3811e3606c4..9178ff5581f 100644 --- a/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py +++ b/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py @@ -42,5 +42,4 @@ def main(argv): if __name__ == "__main__": - tf.enable_v2_behavior() app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/util.py b/tensorflow/examples/saved_model/integration_tests/util.py deleted file mode 100644 index 1b709fdf98c..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/util.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for integration tests.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -import tensorflow.compat.v2 as tf - -from tensorflow.python.framework import smart_cond -from tensorflow.python.util import tf_inspect - - -# TODO(vbardiovsky): We should just reuse Keras's Lambda layer, when that -# enables to get trainable variables. -class CustomLayer(tf.keras.layers.Layer): - """Wraps callable object as a `Layer` object. - - Args: - func: The callable object to wrap. Layer inputs are passed as the first - positional argument. If `func` accepts a `training` argument, a Python - boolean is passed for it. 
- If present, the following attributes of `func` have a special meaning: - * variables: a list of all tf.Variable objects that `func` depends on. - * trainable_variables: those elements of `variables` that are reported - as trainable variables of this Keras Layer. - * regularization_losses: a list of callables to be added as losses - of this Keras layer. Each one must accept zero arguments and return - a scalare tensor. - trainable: Boolean controlling whether the trainable variables of `func` - are reported as trainable variables of this layer. - arguments: optionally, a dict with additional keyword arguments passed - to `func`. - **kwargs: 'output_shape': A tuple with the (possibly partial) output - shape of the callable *without* leading batch size. Other arguments - are pass into the Layer constructor. - """ - - def __init__(self, func, trainable=False, arguments=None, **kwargs): - # Set self._{non,}_trainable_weights before calling Layer.__init__. - if hasattr(func, 'trainable_variables'): - self._trainable_weights = [v for v in func.trainable_variables] - trainable_variables_set = set(func.trainable_variables) - else: - self._trainable_weights = [] - trainable_variables_set = set() - if hasattr(func, 'variables'): - self._non_trainable_weights = [v for v in func.variables - if v not in trainable_variables_set] - else: - self._non_trainable_weights = [] # TODO(arnoegw): Infer from `func`. - - # TODO(b/124219898): We should be able to get the embedding dimension from - # the restored model. - if 'output_shape' in kwargs: - self._output_shape = tuple(kwargs.pop('output_shape')) - - super(CustomLayer, self).__init__(trainable=trainable, **kwargs) - # Prepare to call `func`. - self._func = func - self._func_fullargspec = tf_inspect.getfullargspec(func.__call__) - self._func_wants_training = ( - 'training' in self._func_fullargspec.args or - 'training' in self._func_fullargspec.kwonlyargs) - self._arguments = arguments or {} - # Forward the callable's regularization losses (if any). - if hasattr(func, 'regularization_losses'): - for l in func.regularization_losses: - if not callable(l): - raise ValueError( - 'CustomLayer(func) expects func.regularization_losses to be an ' - 'iterable of callables, each returning a scalar loss term.') - self.add_loss(l) # Supports callables. - - def call(self, x, training=None): - # We basically want to call this... - f = functools.partial(self._func, x, **self._arguments) - # ...but we may also have to pass a Python boolean for `training`. - if not self._func_wants_training: - result = f() - else: - if training is None: - training = tf.keras.backend.learning_phase() # Could be a tensor. - result = smart_cond.smart_cond(training, - lambda: f(training=True), - lambda: f(training=False)) - # TODO(b/124219898): Polymorphic function should return shaped tensor. 
- if hasattr(self, '_output_shape'): - result.set_shape((x.shape[0],) + self._output_shape) - return result diff --git a/tensorflow/examples/speech_commands/BUILD b/tensorflow/examples/speech_commands/BUILD index 88f7fe7faa6..f498f2a390d 100644 --- a/tensorflow/examples/speech_commands/BUILD +++ b/tensorflow/examples/speech_commands/BUILD @@ -35,6 +35,9 @@ tf_py_test( ":models", "//tensorflow/python:client_testlib", ], + tags = [ + "no_pip", # b/131330719 + ], ) py_library( @@ -59,11 +62,15 @@ tf_py_test( ":models", "//tensorflow/python:client_testlib", ], + tags = [ + "no_pip", # b/131330719 + ], ) py_binary( name = "train", srcs = ["train.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":train_main_lib"], ) @@ -88,14 +95,18 @@ tf_py_test( size = "small", srcs = ["train_test.py"], additional_deps = [ - ":train", + ":train_main_lib", "//tensorflow/python:client_testlib", ], + tags = [ + "no_pip", # b/131330719 + ], ) py_binary( name = "freeze", srcs = ["freeze.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":freeze_main_lib"], ) @@ -113,6 +124,9 @@ py_library( "freeze.py", ], srcs_version = "PY2AND3", + tags = [ + "no_pip", # b/131330719 + ], deps = [ ":input_data", ":models", @@ -127,14 +141,18 @@ tf_py_test( size = "small", srcs = ["freeze_test.py"], additional_deps = [ - ":freeze", + ":freeze_main_lib", "//tensorflow/python:client_testlib", ], + tags = [ + "no_pip", # b/131330719 + ], ) py_binary( name = "wav_to_features", srcs = ["wav_to_features.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":wav_to_features_main_lib"], ) @@ -166,14 +184,18 @@ tf_py_test( size = "small", srcs = ["wav_to_features_test.py"], additional_deps = [ - ":wav_to_features", + ":wav_to_features_main_lib", "//tensorflow/python:client_testlib", ], + tags = [ + "no_pip", # b/131330719 + ], ) py_binary( name = "generate_streaming_test_wav", srcs = ["generate_streaming_test_wav.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":generate_streaming_test_wav_main_lib"], ) @@ -205,9 +227,12 @@ tf_py_test( size = "small", srcs = ["generate_streaming_test_wav_test.py"], additional_deps = [ - ":generate_streaming_test_wav", + ":generate_streaming_test_wav_main_lib", "//tensorflow/python:client_testlib", ], + tags = [ + "no_pip", # b/131330719 + ], ) tf_cc_binary( @@ -228,6 +253,7 @@ tf_cc_binary( py_binary( name = "label_wav", srcs = ["label_wav.py"], + python_version = "PY2", srcs_version = "PY2AND3", deps = [":label_wav_main_lib"], ) @@ -255,9 +281,12 @@ tf_py_test( size = "medium", srcs = ["label_wav_test.py"], additional_deps = [ - ":label_wav", + ":label_wav_main_lib", "//tensorflow/python:client_testlib", ], + tags = [ + "no_pip", # b/131330719 + ], ) cc_library( @@ -344,3 +373,17 @@ tf_cc_binary( "//tensorflow/core:protos_all_cc", ], ) + +py_library( + name = "test_lib", + srcs_version = "PY2AND3", + deps = [ + ":freeze", + ":generate_streaming_test_wav", + ":input_data", + ":label_wav", + ":models", + ":train", + ":wav_to_features", + ], +) diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py index c63d4c3c7d1..1fd6a8eea17 100644 --- a/tensorflow/examples/speech_commands/models.py +++ b/tensorflow/examples/speech_commands/models.py @@ -527,6 +527,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, shape=[num_filters, batch, input_time_size], trainable=False, name='runtime-memory') + first_time_flag = tf.get_variable( + name="first_time_flag", + dtype=tf.int32, + 
initializer=1) # Determine the number of new frames in the input, such that we only operate # on those. For training we do not use the memory, and thus use all frames # provided in the input. @@ -537,9 +541,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, window_stride_ms = int(model_settings['window_stride_samples'] * 1000 / model_settings['sample_rate']) num_new_frames = tf.cond( - tf.equal(tf.count_nonzero(memory), 0), + tf.equal(first_time_flag, 1), lambda: input_time_size, lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms)) + first_time_flag = 0 new_fingerprint_input = fingerprint_input[ :, -num_new_frames*input_frequency_size:] # Expand to add input channels dimension. diff --git a/tensorflow/examples/tf2_showcase/BUILD b/tensorflow/examples/tf2_showcase/BUILD index 922bc96b25b..4fd62a15868 100644 --- a/tensorflow/examples/tf2_showcase/BUILD +++ b/tensorflow/examples/tf2_showcase/BUILD @@ -19,6 +19,7 @@ test_suite( py_test( name = "mnist", srcs = ["mnist.py"], + python_version = "PY2", tags = [ "manual", "no_oss", diff --git a/tensorflow/examples/tutorials/layers/BUILD b/tensorflow/examples/tutorials/layers/BUILD index aad78b18409..e4383d155b0 100644 --- a/tensorflow/examples/tutorials/layers/BUILD +++ b/tensorflow/examples/tutorials/layers/BUILD @@ -13,6 +13,7 @@ py_binary( srcs = [ "cnn_mnist.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD index 5f12374bdbd..6839c486144 100644 --- a/tensorflow/examples/tutorials/mnist/BUILD +++ b/tensorflow/examples/tutorials/mnist/BUILD @@ -50,6 +50,7 @@ py_binary( srcs = [ "fully_connected_feed.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = ["optonly"], deps = [ @@ -64,6 +65,7 @@ py_binary( srcs = [ "mnist_with_summaries.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":input_data", @@ -82,6 +84,7 @@ py_binary( srcs = [ "mnist_softmax_xla.py", ], + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":input_data", @@ -100,6 +103,7 @@ py_test( "--max_steps=10", ], main = "fully_connected_feed.py", + python_version = "PY2", srcs_version = "PY2AND3", deps = [ ":input_data", @@ -120,6 +124,7 @@ py_test( "--learning_rate=0.00", ], main = "mnist_with_summaries.py", + python_version = "PY2", srcs_version = "PY2AND3", tags = ["notsan"], # http://b/29184009 deps = [ diff --git a/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py b/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py index a9cb20fdfd3..2945660dad5 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py +++ b/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py @@ -46,12 +46,12 @@ def main(_): # The raw formulation of cross-entropy, # - # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), + # tf.reduce_mean(-tf.reduce_sum(y_ * tf.math.log(tf.nn.softmax(y)), # reduction_indices=[1])) # # can be numerically unstable. # - # So here we use tf.losses.sparse_softmax_cross_entropy on the raw + # So here we use tf.compat.v1.losses.sparse_softmax_cross_entropy on the raw # logit outputs of 'y', and then average across the batch. 
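The updated comment here (and its twin in `mnist_with_summaries.py`) is the reason the tutorial feeds raw logits to the loss instead of composing `log` with `softmax` by hand. A small sketch with made-up logit values shows the failure mode it refers to:

```python
import numpy as np

logits = np.array([1000.0, 0.0, -1000.0])
label = 1  # the "true" class, chosen only for this demo

# Naive route: compute softmax explicitly, then take its log. Even with the
# usual max-shift, every probability except the largest underflows to 0.0,
# so log() returns -inf and the loss (and its gradients) blow up.
shifted = logits - logits.max()
probs = np.exp(shifted) / np.exp(shifted).sum()
naive_loss = -np.log(probs[label])          # inf

# Staying in log space (roughly what the fused loss op does internally)
# keeps the same quantity finite: logsumexp(logits) - logits[label].
stable_loss = np.log(np.exp(shifted).sum()) + logits.max() - logits[label]
print(naive_loss, stable_loss)              # inf vs. 1000.0
```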
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py index 1854e84d490..3485e7afbf1 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py +++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py @@ -111,12 +111,12 @@ def train(): with tf.name_scope('cross_entropy'): # The raw formulation of cross-entropy, # - # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.softmax(y)), + # tf.reduce_mean(-tf.reduce_sum(y_ * tf.math.log(tf.softmax(y)), # reduction_indices=[1])) # # can be numerically unstable. # - # So here we use tf.losses.sparse_softmax_cross_entropy on the + # So here we use tf.compat.v1.losses.sparse_softmax_cross_entropy on the # raw logit outputs of the nn_layer above, and then average across # the batch. with tf.name_scope('total'): diff --git a/tensorflow/examples/tutorials/word2vec/BUILD b/tensorflow/examples/tutorials/word2vec/BUILD index 2e19c038bdf..5293f437dce 100644 --- a/tensorflow/examples/tutorials/word2vec/BUILD +++ b/tensorflow/examples/tutorials/word2vec/BUILD @@ -12,6 +12,7 @@ py_binary( srcs = [ "word2vec_basic.py", ], + python_version = "PY2", srcs_version = "PY2AND3", tags = [ "no-internal-py3", diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py index 805ec203b48..ebfaacb8a2c 100644 --- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py @@ -81,10 +81,10 @@ def word2vec_basic(log_dir): """Process raw inputs into a dataset.""" count = [['UNK', -1]] count.extend(collections.Counter(words).most_common(n_words - 1)) - dictionary = dict() + dictionary = {} for word, _ in count: dictionary[word] = len(dictionary) - data = list() + data = [] unk_count = 0 for word in words: index = dictionary.get(word, 0) @@ -337,8 +337,9 @@ def word2vec_basic(log_dir): print(ex) -# All functionality is run after tf.app.run() (b/122547914). This could be split -# up but the methods are laid sequentially with their usage for clarity. +# All functionality is run after tf.compat.v1.app.run() (b/122547914). This +# could be split up but the methods are laid sequentially with their usage for +# clarity. def main(unused_argv): # Give a folder path as an argument with '--log_dir' to save # TensorBoard summaries. Default is a log folder in current directory. 
@@ -353,5 +354,6 @@ def main(unused_argv): flags, unused_flags = parser.parse_known_args() word2vec_basic(flags.log_dir) + if __name__ == '__main__': tf.app.run() diff --git a/tensorflow/examples/udacity/1_notmnist.ipynb b/tensorflow/examples/udacity/1_notmnist.ipynb deleted file mode 100644 index dffe5d37c64..00000000000 --- a/tensorflow/examples/udacity/1_notmnist.ipynb +++ /dev/null @@ -1,800 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "version": "0.3.2", - "views": {}, - "default_view": {}, - "name": "1_notmnist.ipynb", - "provenance": [] - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "5hIbr52I7Z7U", - "colab_type": "text" - }, - "source": [ - "Deep Learning\n", - "=============\n", - "\n", - "Assignment 1\n", - "------------\n", - "\n", - "The objective of this assignment is to learn about simple data curation practices, and familiarize you with some of the data we'll be reusing later.\n", - "\n", - "This notebook uses the [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) dataset to be used with python experiments. This dataset is designed to look like the classic [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, while looking a little more like real data: it's a harder task, and the data is a lot less 'clean' than MNIST." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "apJbCsBHl-2A", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "# These are all the modules we'll be using later. Make sure you can import them\n", - "# before proceeding further.\n", - "from __future__ import print_function\n", - "import imageio\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import os\n", - "import sys\n", - "import tarfile\n", - "from IPython.display import display, Image\n", - "from sklearn.linear_model import LogisticRegression\n", - "from six.moves.urllib.request import urlretrieve\n", - "from six.moves import cPickle as pickle\n", - "\n", - "# Config the matplotlib backend as plotting inline in IPython\n", - "%matplotlib inline" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jNWGtZaXn-5j", - "colab_type": "text" - }, - "source": [ - "First, we'll download the dataset to our local machine. The data consists of characters rendered in a variety of fonts on a 28x28 image. The labels are limited to 'A' through 'J' (10 classes). The training set has about 500k and the testset 19000 labeled examples. Given these sizes, it should be possible to train models quickly on any machine." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "EYRJ4ICW6-da", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 186058, - "status": "ok", - "timestamp": 1444485672507, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2a0a5e044bb03b66", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "0d0f85df-155f-4a89-8e7e-ee32df36ec8d" - }, - "source": [ - "url = 'https://commondatastorage.googleapis.com/books1000/'\n", - "last_percent_reported = None\n", - "data_root = '.' # Change me to store data elsewhere\n", - "\n", - "def download_progress_hook(count, blockSize, totalSize):\n", - " \"\"\"A hook to report the progress of a download. This is mostly intended for users with\n", - " slow internet connections. Reports every 5% change in download progress.\n", - " \"\"\"\n", - " global last_percent_reported\n", - " percent = int(count * blockSize * 100 / totalSize)\n", - "\n", - " if last_percent_reported != percent:\n", - " if percent % 5 == 0:\n", - " sys.stdout.write(\"%s%%\" % percent)\n", - " sys.stdout.flush()\n", - " else:\n", - " sys.stdout.write(\".\")\n", - " sys.stdout.flush()\n", - " \n", - " last_percent_reported = percent\n", - " \n", - "def maybe_download(filename, expected_bytes, force=False):\n", - " \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n", - " dest_filename = os.path.join(data_root, filename)\n", - " if force or not os.path.exists(dest_filename):\n", - " print('Attempting to download:', filename) \n", - " filename, _ = urlretrieve(url + filename, dest_filename, reporthook=download_progress_hook)\n", - " print('\\nDownload Complete!')\n", - " statinfo = os.stat(dest_filename)\n", - " if statinfo.st_size == expected_bytes:\n", - " print('Found and verified', dest_filename)\n", - " else:\n", - " raise Exception(\n", - " 'Failed to verify ' + dest_filename + '. Can you get to it with a browser?')\n", - " return dest_filename\n", - "\n", - "train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)\n", - "test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Found and verified notMNIST_large.tar.gz\n", - "Found and verified notMNIST_small.tar.gz\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cC3p0oEyF8QT", - "colab_type": "text" - }, - "source": [ - "Extract the dataset from the compressed .tar.gz file.\n", - "This should give you a set of directories, labeled A through J." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "H8CBE-WZ8nmj", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 186055, - "status": "ok", - "timestamp": 1444485672525, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2a0a5e044bb03b66", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "ef6c790c-2513-4b09-962e-27c79390c762" - }, - "source": [ - "num_classes = 10\n", - "np.random.seed(133)\n", - "\n", - "def maybe_extract(filename, force=False):\n", - " root = os.path.splitext(os.path.splitext(filename)[0])[0] # remove .tar.gz\n", - " if os.path.isdir(root) and not force:\n", - " # You may override by setting force=True.\n", - " print('%s already present - Skipping extraction of %s.' % (root, filename))\n", - " else:\n", - " print('Extracting data for %s. This may take a while. Please wait.' % root)\n", - " tar = tarfile.open(filename)\n", - " sys.stdout.flush()\n", - " tar.extractall(data_root)\n", - " tar.close()\n", - " data_folders = [\n", - " os.path.join(root, d) for d in sorted(os.listdir(root))\n", - " if os.path.isdir(os.path.join(root, d))]\n", - " if len(data_folders) != num_classes:\n", - " raise Exception(\n", - " 'Expected %d folders, one per class. Found %d instead.' % (\n", - " num_classes, len(data_folders)))\n", - " print(data_folders)\n", - " return data_folders\n", - " \n", - "train_folders = maybe_extract(train_filename)\n", - "test_folders = maybe_extract(test_filename)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "['notMNIST_large/A', 'notMNIST_large/B', 'notMNIST_large/C', 'notMNIST_large/D', 'notMNIST_large/E', 'notMNIST_large/F', 'notMNIST_large/G', 'notMNIST_large/H', 'notMNIST_large/I', 'notMNIST_large/J']\n", - "['notMNIST_small/A', 'notMNIST_small/B', 'notMNIST_small/C', 'notMNIST_small/D', 'notMNIST_small/E', 'notMNIST_small/F', 'notMNIST_small/G', 'notMNIST_small/H', 'notMNIST_small/I', 'notMNIST_small/J']\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4riXK3IoHgx6", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 1\n", - "---------\n", - "\n", - "Let's take a peek at some of the data to make sure it looks sensible. Each exemplar should be an image of a character A through J rendered in a different font. Display a sample of the images that we just downloaded. Hint: you can use the package IPython.display.\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PBdkjESPK8tw", - "colab_type": "text" - }, - "source": [ - "Now let's load the data in a more manageable format. Since, depending on your computer setup you might not be able to fit it all in memory, we'll load each class into a separate dataset, store them on disk and curate them independently. Later we'll merge them into a single dataset of manageable size.\n", - "\n", - "We'll convert the entire dataset into a 3D array (image index, x, y) of floating point values, normalized to have approximately zero mean and standard deviation ~0.5 to make training easier down the road. 
\n", - "\n", - "A few images might not be readable, we'll just skip them." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "h7q0XhG3MJdf", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 30 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 399874, - "status": "ok", - "timestamp": 1444485886378, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2a0a5e044bb03b66", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "92c391bb-86ff-431d-9ada-315568a19e59" - }, - "source": [ - "image_size = 28 # Pixel width and height.\n", - "pixel_depth = 255.0 # Number of levels per pixel.\n", - "\n", - "def load_letter(folder, min_num_images):\n", - " \"\"\"Load the data for a single letter label.\"\"\"\n", - " image_files = os.listdir(folder)\n", - " dataset = np.ndarray(shape=(len(image_files), image_size, image_size),\n", - " dtype=np.float32)\n", - " print(folder)\n", - " num_images = 0\n", - " for image in image_files:\n", - " image_file = os.path.join(folder, image)\n", - " try:\n", - " image_data = (imageio.imread(image_file).astype(float) - \n", - " pixel_depth / 2) / pixel_depth\n", - " if image_data.shape != (image_size, image_size):\n", - " raise Exception('Unexpected image shape: %s' % str(image_data.shape))\n", - " dataset[num_images, :, :] = image_data\n", - " num_images = num_images + 1\n", - " except (IOError, ValueError) as e:\n", - " print('Could not read:', image_file, ':', e, '- it\\'s ok, skipping.')\n", - " \n", - " dataset = dataset[0:num_images, :, :]\n", - " if num_images < min_num_images:\n", - " raise Exception('Many fewer images than expected: %d < %d' %\n", - " (num_images, min_num_images))\n", - " \n", - " print('Full dataset tensor:', dataset.shape)\n", - " print('Mean:', np.mean(dataset))\n", - " print('Standard deviation:', np.std(dataset))\n", - " return dataset\n", - " \n", - "def maybe_pickle(data_folders, min_num_images_per_class, force=False):\n", - " dataset_names = []\n", - " for folder in data_folders:\n", - " set_filename = folder + '.pickle'\n", - " dataset_names.append(set_filename)\n", - " if os.path.exists(set_filename) and not force:\n", - " # You may override by setting force=True.\n", - " print('%s already present - Skipping pickling.' % set_filename)\n", - " else:\n", - " print('Pickling %s.' 
% set_filename)\n", - " dataset = load_letter(folder, min_num_images_per_class)\n", - " try:\n", - " with open(set_filename, 'wb') as f:\n", - " pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)\n", - " except Exception as e:\n", - " print('Unable to save data to', set_filename, ':', e)\n", - " \n", - " return dataset_names\n", - "\n", - "train_datasets = maybe_pickle(train_folders, 45000)\n", - "test_datasets = maybe_pickle(test_folders, 1800)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "notMNIST_large/A\n", - "Could not read: notMNIST_large/A/Um9tYW5hIEJvbGQucGZi.png : cannot identify image file - it's ok, skipping.\n", - "Could not read: notMNIST_large/A/RnJlaWdodERpc3BCb29rSXRhbGljLnR0Zg==.png : cannot identify image file - it's ok, skipping.\n", - "Could not read: notMNIST_large/A/SG90IE11c3RhcmQgQlROIFBvc3Rlci50dGY=.png : cannot identify image file - it's ok, skipping.\n", - "Full dataset tensor: (52909, 28, 28)\n", - "Mean: -0.12848\n", - "Standard deviation: 0.425576\n", - "notMNIST_large/B\n", - "Could not read: notMNIST_large/B/TmlraXNFRi1TZW1pQm9sZEl0YWxpYy5vdGY=.png : cannot identify image file - it's ok, skipping.\n", - "Full dataset tensor: (52911, 28, 28)\n", - "Mean: -0.00755947\n", - "Standard deviation: 0.417272\n", - "notMNIST_large/C\n", - "Full dataset tensor: (52912, 28, 28)\n", - "Mean: -0.142321\n", - "Standard deviation: 0.421305\n", - "notMNIST_large/D\n", - "Could not read: notMNIST_large/D/VHJhbnNpdCBCb2xkLnR0Zg==.png : cannot identify image file - it's ok, skipping.\n", - "Full dataset tensor: (52911, 28, 28)\n", - "Mean: -0.0574553\n", - "Standard deviation: 0.434072\n", - "notMNIST_large/E\n", - "Full dataset tensor: (52912, 28, 28)\n", - "Mean: -0.0701406\n", - "Standard deviation: 0.42882\n", - "notMNIST_large/F\n", - "Full dataset tensor: (52912, 28, 28)\n", - "Mean: -0.125914\n", - "Standard deviation: 0.429645\n", - "notMNIST_large/G\n", - "Full dataset tensor: (52912, 28, 28)\n", - "Mean: -0.0947771\n", - "Standard deviation: 0.421674\n", - "notMNIST_large/H\n", - "Full dataset tensor: (52912, 28, 28)\n", - "Mean: -0.0687667\n", - "Standard deviation: 0.430344\n", - "notMNIST_large/I\n", - "Full dataset tensor: (52912, 28, 28)\n", - "Mean: 0.0307405\n", - "Standard deviation: 0.449686\n", - "notMNIST_large/J\n", - "Full dataset tensor: (52911, 28, 28)\n", - "Mean: -0.153479\n", - "Standard deviation: 0.397169\n", - "notMNIST_small/A\n", - "Could not read: notMNIST_small/A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png : cannot identify image file - it's ok, skipping.\n", - "Full dataset tensor: (1872, 28, 28)\n", - "Mean: -0.132588\n", - "Standard deviation: 0.445923\n", - "notMNIST_small/B\n", - "Full dataset tensor: (1873, 28, 28)\n", - "Mean: 0.00535619\n", - "Standard deviation: 0.457054\n", - "notMNIST_small/C\n", - "Full dataset tensor: (1873, 28, 28)\n", - "Mean: -0.141489\n", - "Standard deviation: 0.441056\n", - "notMNIST_small/D\n", - "Full dataset tensor: (1873, 28, 28)\n", - "Mean: -0.0492094\n", - "Standard deviation: 0.460477\n", - "notMNIST_small/E\n", - "Full dataset tensor: (1873, 28, 28)\n", - "Mean: -0.0598952\n", - "Standard deviation: 0.456146\n", - "notMNIST_small/F\n", - "Could not read: notMNIST_small/F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png : cannot identify image file - it's ok, skipping.\n", - "Full dataset tensor: (1872, 28, 28)\n", - "Mean: -0.118148\n", - "Standard deviation: 0.451134\n", - "notMNIST_small/G\n", - "Full dataset tensor: (1872, 28, 28)\n", - "Mean: -0.092519\n", - "Standard deviation: 
0.448468\n", - "notMNIST_small/H\n", - "Full dataset tensor: (1872, 28, 28)\n", - "Mean: -0.0586729\n", - "Standard deviation: 0.457387\n", - "notMNIST_small/I\n", - "Full dataset tensor: (1872, 28, 28)\n", - "Mean: 0.0526481\n", - "Standard deviation: 0.472657\n", - "notMNIST_small/J\n", - "Full dataset tensor: (1872, 28, 28)\n", - "Mean: -0.15167\n", - "Standard deviation: 0.449521\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vUdbskYE2d87", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 2\n", - "---------\n", - "\n", - "Let's verify that the data still looks good. Displaying a sample of the labels and images from the ndarray. Hint: you can use matplotlib.pyplot.\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cYznx5jUwzoO", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 3\n", - "---------\n", - "Another check: we expect the data to be balanced across classes. Verify that.\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LA7M7K22ynCt", - "colab_type": "text" - }, - "source": [ - "Merge and prune the training data as needed. Depending on your computer setup, you might not be able to fit it all in memory, and you can tune `train_size` as needed. The labels will be stored into a separate array of integers 0 through 9.\n", - "\n", - "Also create a validation dataset for hyperparameter tuning." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "s3mWgZLpyuzq", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 411281, - "status": "ok", - "timestamp": 1444485897869, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2a0a5e044bb03b66", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "8af66da6-902d-4719-bedc-7c9fb7ae7948" - }, - "source": [ - "def make_arrays(nb_rows, img_size):\n", - " if nb_rows:\n", - " dataset = np.ndarray((nb_rows, img_size, img_size), dtype=np.float32)\n", - " labels = np.ndarray(nb_rows, dtype=np.int32)\n", - " else:\n", - " dataset, labels = None, None\n", - " return dataset, labels\n", - "\n", - "def merge_datasets(pickle_files, train_size, valid_size=0):\n", - " num_classes = len(pickle_files)\n", - " valid_dataset, valid_labels = make_arrays(valid_size, image_size)\n", - " train_dataset, train_labels = make_arrays(train_size, image_size)\n", - " vsize_per_class = valid_size // num_classes\n", - " tsize_per_class = train_size // num_classes\n", - " \n", - " start_v, start_t = 0, 0\n", - " end_v, end_t = vsize_per_class, tsize_per_class\n", - " end_l = vsize_per_class+tsize_per_class\n", - " for label, pickle_file in enumerate(pickle_files): \n", - " try:\n", - " with open(pickle_file, 'rb') as f:\n", - " letter_set = pickle.load(f)\n", - " # let's shuffle the letters to have random validation and training set\n", - " np.random.shuffle(letter_set)\n", - " if valid_dataset is not None:\n", - " valid_letter = letter_set[:vsize_per_class, :, :]\n", - " valid_dataset[start_v:end_v, :, :] = valid_letter\n", - " valid_labels[start_v:end_v] = label\n", - " start_v += 
vsize_per_class\n", - " end_v += vsize_per_class\n", - " \n", - " train_letter = letter_set[vsize_per_class:end_l, :, :]\n", - " train_dataset[start_t:end_t, :, :] = train_letter\n", - " train_labels[start_t:end_t] = label\n", - " start_t += tsize_per_class\n", - " end_t += tsize_per_class\n", - " except Exception as e:\n", - " print('Unable to process data from', pickle_file, ':', e)\n", - " raise\n", - " \n", - " return valid_dataset, valid_labels, train_dataset, train_labels\n", - " \n", - " \n", - "train_size = 200000\n", - "valid_size = 10000\n", - "test_size = 10000\n", - "\n", - "valid_dataset, valid_labels, train_dataset, train_labels = merge_datasets(\n", - " train_datasets, train_size, valid_size)\n", - "_, _, test_dataset, test_labels = merge_datasets(test_datasets, test_size)\n", - "\n", - "print('Training:', train_dataset.shape, train_labels.shape)\n", - "print('Validation:', valid_dataset.shape, valid_labels.shape)\n", - "print('Testing:', test_dataset.shape, test_labels.shape)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Training (200000, 28, 28) (200000,)\n", - "Validation (10000, 28, 28) (10000,)\n", - "Testing (10000, 28, 28) (10000,)\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GPTCnjIcyuKN", - "colab_type": "text" - }, - "source": [ - "Next, we'll randomize the data. It's important to have the labels well shuffled for the training and test distributions to match." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "6WZ2l2tN2zOL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "def randomize(dataset, labels):\n", - " permutation = np.random.permutation(labels.shape[0])\n", - " shuffled_dataset = dataset[permutation,:,:]\n", - " shuffled_labels = labels[permutation]\n", - " return shuffled_dataset, shuffled_labels\n", - "train_dataset, train_labels = randomize(train_dataset, train_labels)\n", - "test_dataset, test_labels = randomize(test_dataset, test_labels)\n", - "valid_dataset, valid_labels = randomize(valid_dataset, valid_labels)" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "puDUTe6t6USl", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 4\n", - "---------\n", - "Convince yourself that the data is still good after shuffling!\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tIQJaJuwg5Hw", - "colab_type": "text" - }, - "source": [ - "Finally, let's save the data for later reuse:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "QiR_rETzem6C", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "pickle_file = os.path.join(data_root, 'notMNIST.pickle')\n", - "\n", - "try:\n", - " f = open(pickle_file, 'wb')\n", - " save = {\n", - " 'train_dataset': train_dataset,\n", - " 'train_labels': train_labels,\n", - " 'valid_dataset': valid_dataset,\n", - " 'valid_labels': valid_labels,\n", - " 'test_dataset': test_dataset,\n", - " 'test_labels': test_labels,\n", - " }\n", - " pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)\n", - " f.close()\n", - "except Exception as e:\n", - " print('Unable to save data to', pickle_file, ':', e)\n", - " raise" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "hQbLjrW_iT39", - 
"colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 413065, - "status": "ok", - "timestamp": 1444485899688, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2a0a5e044bb03b66", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "b440efc6-5ee1-4cbc-d02d-93db44ebd956" - }, - "source": [ - "statinfo = os.stat(pickle_file)\n", - "print('Compressed pickle size:', statinfo.st_size)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Compressed pickle size: 718193801\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gE_cRAQB33lk", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 5\n", - "---------\n", - "\n", - "By construction, this dataset might contain a lot of overlapping samples, including training data that's also contained in the validation and test set! Overlap between training and test can skew the results if you expect to use your model in an environment where there is never an overlap, but are actually ok if you expect to see training samples recur when you use it.\n", - "Measure how much overlap there is between training, validation and test samples.\n", - "\n", - "Optional questions:\n", - "- What about near duplicates between datasets? (images that are almost identical)\n", - "- Create a sanitized validation and test set, and compare your accuracy on those in subsequent assignments.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L8oww1s4JMQx", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 6\n", - "---------\n", - "\n", - "Let's get an idea of what an off-the-shelf classifier can give you on this data. It's always good to check that there is something to learn, and that it's a problem that is not so trivial that a canned solution solves it.\n", - "\n", - "Train a simple model on this data using 50, 100, 1000 and 5000 training samples. Hint: you can use the LogisticRegression model from sklearn.linear_model.\n", - "\n", - "Optional question: train an off-the-shelf model on all the data!\n", - "\n", - "---" - ] - } - ] -} diff --git a/tensorflow/examples/udacity/2_fullyconnected.ipynb b/tensorflow/examples/udacity/2_fullyconnected.ipynb deleted file mode 100644 index a6a206307aa..00000000000 --- a/tensorflow/examples/udacity/2_fullyconnected.ipynb +++ /dev/null @@ -1,586 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "version": "0.3.2", - "views": {}, - "default_view": {}, - "name": "2_fullyconnected.ipynb", - "provenance": [] - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "kR-4eNdK6lYS", - "colab_type": "text" - }, - "source": [ - "Deep Learning\n", - "=============\n", - "\n", - "Assignment 2\n", - "------------\n", - "\n", - "Previously in `1_notmnist.ipynb`, we created a pickle with formatted datasets for training, development and testing on the [notMNIST dataset](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html).\n", - "\n", - "The goal of this assignment is to progressively train deeper and more accurate models using TensorFlow." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "JLpLa8Jt7Vu4", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "# These are all the modules we'll be using later. Make sure you can import them\n", - "# before proceeding further.\n", - "from __future__ import print_function\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "from six.moves import cPickle as pickle\n", - "from six.moves import range" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1HrCK6e17WzV", - "colab_type": "text" - }, - "source": [ - "First reload the data we generated in `1_notmnist.ipynb`." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y3-cj1bpmuxc", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 19456, - "status": "ok", - "timestamp": 1449847956073, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "0ddb1607-1fc4-4ddb-de28-6c7ab7fb0c33" - }, - "source": [ - "pickle_file = 'notMNIST.pickle'\n", - "\n", - "with open(pickle_file, 'rb') as f:\n", - " save = pickle.load(f)\n", - " train_dataset = save['train_dataset']\n", - " train_labels = save['train_labels']\n", - " valid_dataset = save['valid_dataset']\n", - " valid_labels = save['valid_labels']\n", - " test_dataset = save['test_dataset']\n", - " test_labels = save['test_labels']\n", - " del save # hint to help gc free up memory\n", - " print('Training set', train_dataset.shape, train_labels.shape)\n", - " print('Validation set', valid_dataset.shape, valid_labels.shape)\n", - " print('Test set', test_dataset.shape, test_labels.shape)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Training set (200000, 28, 28) (200000,)\n", - "Validation set (10000, 28, 28) (10000,)\n", - "Test set (18724, 28, 28) (18724,)\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L7aHrm6nGDMB", - "colab_type": "text" - }, - "source": [ - "Reformat into a shape that's more adapted to the models we're going to train:\n", - "- data as a flat matrix,\n", - "- labels as float 1-hot encodings." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IRSyYiIIGIzS", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 19723, - "status": "ok", - "timestamp": 1449847956364, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "2ba0fc75-1487-4ace-a562-cf81cae82793" - }, - "source": [ - "image_size = 28\n", - "num_labels = 10\n", - "\n", - "def reformat(dataset, labels):\n", - " dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)\n", - " # Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]\n", - " labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)\n", - " return dataset, labels\n", - "train_dataset, train_labels = reformat(train_dataset, train_labels)\n", - "valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)\n", - "test_dataset, test_labels = reformat(test_dataset, test_labels)\n", - "print('Training set', train_dataset.shape, train_labels.shape)\n", - "print('Validation set', valid_dataset.shape, valid_labels.shape)\n", - "print('Test set', test_dataset.shape, test_labels.shape)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Training set (200000, 784) (200000, 10)\n", - "Validation set (10000, 784) (10000, 10)\n", - "Test set (18724, 784) (18724, 10)\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nCLVqyQ5vPPH", - "colab_type": "text" - }, - "source": [ - "We're first going to train a multinomial logistic regression using simple gradient descent.\n", - "\n", - "TensorFlow works like this:\n", - "* First you describe the computation that you want to see performed: what the inputs, the variables, and the operations look like. These get created as nodes over a computation graph. This description is all contained within the block below:\n", - "\n", - " with graph.as_default():\n", - " ...\n", - "\n", - "* Then you can run the operations on this graph as many times as you want by calling `session.run()`, providing it outputs to fetch from the graph that get returned. This runtime operation is all contained in the block below:\n", - "\n", - " with tf.Session(graph=graph) as session:\n", - " ...\n", - "\n", - "Let's load all the data into TensorFlow and build the computation graph corresponding to our training:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Nfv39qvtvOl_", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "# With gradient descent training, even this much data is prohibitive.\n", - "# Subset the training data for faster turnaround.\n", - "train_subset = 10000\n", - "\n", - "graph = tf.Graph()\n", - "with graph.as_default():\n", - "\n", - " # Input data.\n", - " # Load the training, validation and test data into constants that are\n", - " # attached to the graph.\n", - " tf_train_dataset = tf.constant(train_dataset[:train_subset, :])\n", - " tf_train_labels = tf.constant(train_labels[:train_subset])\n", - " tf_valid_dataset = tf.constant(valid_dataset)\n", - " tf_test_dataset = tf.constant(test_dataset)\n", - " \n", - " # Variables.\n", - " # These are the parameters that we are going to be training. 
The weight\n", - " # matrix will be initialized using random values following a (truncated)\n", - " # normal distribution. The biases get initialized to zero.\n", - " weights = tf.Variable(\n", - " tf.truncated_normal([image_size * image_size, num_labels]))\n", - " biases = tf.Variable(tf.zeros([num_labels]))\n", - " \n", - " # Training computation.\n", - " # We multiply the inputs with the weight matrix, and add biases. We compute\n", - " # the softmax and cross-entropy (it's one operation in TensorFlow, because\n", - " # it's very common, and it can be optimized). We take the average of this\n", - " # cross-entropy across all training examples: that's our loss.\n", - " logits = tf.matmul(tf_train_dataset, weights) + biases\n", - " loss = tf.reduce_mean(\n", - " tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n", - " \n", - " # Optimizer.\n", - " # We are going to find the minimum of this loss using gradient descent.\n", - " optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)\n", - " \n", - " # Predictions for the training, validation, and test data.\n", - " # These are not part of training, but merely here so that we can report\n", - " # accuracy figures as we train.\n", - " train_prediction = tf.nn.softmax(logits)\n", - " valid_prediction = tf.nn.softmax(\n", - " tf.matmul(tf_valid_dataset, weights) + biases)\n", - " test_prediction = tf.nn.softmax(tf.matmul(tf_test_dataset, weights) + biases)" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KQcL4uqISHjP", - "colab_type": "text" - }, - "source": [ - "Let's run this computation and iterate:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "z2cjdenH869W", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 9 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 57454, - "status": "ok", - "timestamp": 1449847994134, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "4c037ba1-b526-4d8e-e632-91e2a0333267" - }, - "source": [ - "num_steps = 801\n", - "\n", - "def accuracy(predictions, labels):\n", - " return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))\n", - " / predictions.shape[0])\n", - "\n", - "with tf.Session(graph=graph) as session:\n", - " # This is a one-time operation which ensures the parameters get initialized as\n", - " # we described in the graph: random weights for the matrix, zeros for the\n", - " # biases. \n", - " tf.global_variables_initializer().run()\n", - " print('Initialized')\n", - " for step in range(num_steps):\n", - " # Run the computations. We tell .run() that we want to run the optimizer,\n", - " # and get the loss value and the training predictions returned as numpy\n", - " # arrays.\n", - " _, l, predictions = session.run([optimizer, loss, train_prediction])\n", - " if (step % 100 == 0):\n", - " print('Loss at step %d: %f' % (step, l))\n", - " print('Training accuracy: %.1f%%' % accuracy(\n", - " predictions, train_labels[:train_subset, :]))\n", - " # Calling .eval() on valid_prediction is basically like calling run(), but\n", - " # just to get that one numpy array. 
Note that it recomputes all its graph\n", - " # dependencies.\n", - " print('Validation accuracy: %.1f%%' % accuracy(\n", - " valid_prediction.eval(), valid_labels))\n", - " print('Test accuracy: %.1f%%' % accuracy(test_prediction.eval(), test_labels))" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Initialized\n", - "Loss at step 0 : 17.2939\n", - "Training accuracy: 10.8%\n", - "Validation accuracy: 13.8%\n", - "Loss at step 100 : 2.26903\n", - "Training accuracy: 72.3%\n", - "Validation accuracy: 71.6%\n", - "Loss at step 200 : 1.84895\n", - "Training accuracy: 74.9%\n", - "Validation accuracy: 73.9%\n", - "Loss at step 300 : 1.60701\n", - "Training accuracy: 76.0%\n", - "Validation accuracy: 74.5%\n", - "Loss at step 400 : 1.43912\n", - "Training accuracy: 76.8%\n", - "Validation accuracy: 74.8%\n", - "Loss at step 500 : 1.31349\n", - "Training accuracy: 77.5%\n", - "Validation accuracy: 75.0%\n", - "Loss at step 600 : 1.21501\n", - "Training accuracy: 78.1%\n", - "Validation accuracy: 75.4%\n", - "Loss at step 700 : 1.13515\n", - "Training accuracy: 78.6%\n", - "Validation accuracy: 75.4%\n", - "Loss at step 800 : 1.0687\n", - "Training accuracy: 79.2%\n", - "Validation accuracy: 75.6%\n", - "Test accuracy: 82.9%\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x68f-hxRGm3H", - "colab_type": "text" - }, - "source": [ - "Let's now switch to stochastic gradient descent training instead, which is much faster.\n", - "\n", - "The graph will be similar, except that instead of holding all the training data into a constant node, we create a `Placeholder` node which will be fed actual data at every call of `session.run()`." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "qhPMzWYRGrzM", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "batch_size = 128\n", - "\n", - "graph = tf.Graph()\n", - "with graph.as_default():\n", - "\n", - " # Input data. 
For the training data, we use a placeholder that will be fed\n", - " # at run time with a training minibatch.\n", - " tf_train_dataset = tf.placeholder(tf.float32,\n", - " shape=(batch_size, image_size * image_size))\n", - " tf_train_labels = tf.placeholder(tf.float32, shape=(batch_size, num_labels))\n", - " tf_valid_dataset = tf.constant(valid_dataset)\n", - " tf_test_dataset = tf.constant(test_dataset)\n", - " \n", - " # Variables.\n", - " weights = tf.Variable(\n", - " tf.truncated_normal([image_size * image_size, num_labels]))\n", - " biases = tf.Variable(tf.zeros([num_labels]))\n", - " \n", - " # Training computation.\n", - " logits = tf.matmul(tf_train_dataset, weights) + biases\n", - " loss = tf.reduce_mean(\n", - " tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n", - " \n", - " # Optimizer.\n", - " optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(loss)\n", - " \n", - " # Predictions for the training, validation, and test data.\n", - " train_prediction = tf.nn.softmax(logits)\n", - " valid_prediction = tf.nn.softmax(\n", - " tf.matmul(tf_valid_dataset, weights) + biases)\n", - " test_prediction = tf.nn.softmax(tf.matmul(tf_test_dataset, weights) + biases)" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XmVZESmtG4JH", - "colab_type": "text" - }, - "source": [ - "Let's run it:" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "FoF91pknG_YW", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 6 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 66292, - "status": "ok", - "timestamp": 1449848003013, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "d255c80e-954d-4183-ca1c-c7333ce91d0a" - }, - "source": [ - "num_steps = 3001\n", - "\n", - "with tf.Session(graph=graph) as session:\n", - " tf.global_variables_initializer().run()\n", - " print(\"Initialized\")\n", - " for step in range(num_steps):\n", - " # Pick an offset within the training data, which has been randomized.\n", - " # Note: we could use better randomization across epochs.\n", - " offset = (step * batch_size) % (train_labels.shape[0] - batch_size)\n", - " # Generate a minibatch.\n", - " batch_data = train_dataset[offset:(offset + batch_size), :]\n", - " batch_labels = train_labels[offset:(offset + batch_size), :]\n", - " # Prepare a dictionary telling the session where to feed the minibatch.\n", - " # The key of the dictionary is the placeholder node of the graph to be fed,\n", - " # and the value is the numpy array to feed to it.\n", - " feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels}\n", - " _, l, predictions = session.run(\n", - " [optimizer, loss, train_prediction], feed_dict=feed_dict)\n", - " if (step % 500 == 0):\n", - " print(\"Minibatch loss at step %d: %f\" % (step, l))\n", - " print(\"Minibatch accuracy: %.1f%%\" % accuracy(predictions, batch_labels))\n", - " print(\"Validation accuracy: %.1f%%\" % accuracy(\n", - " valid_prediction.eval(), valid_labels))\n", - " print(\"Test accuracy: %.1f%%\" % accuracy(test_prediction.eval(), test_labels))" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Initialized\n", - "Minibatch loss at step 0 : 16.8091\n", - "Minibatch accuracy: 12.5%\n", - "Validation accuracy: 
14.0%\n", - "Minibatch loss at step 500 : 1.75256\n", - "Minibatch accuracy: 77.3%\n", - "Validation accuracy: 75.0%\n", - "Minibatch loss at step 1000 : 1.32283\n", - "Minibatch accuracy: 77.3%\n", - "Validation accuracy: 76.6%\n", - "Minibatch loss at step 1500 : 0.944533\n", - "Minibatch accuracy: 83.6%\n", - "Validation accuracy: 76.5%\n", - "Minibatch loss at step 2000 : 1.03795\n", - "Minibatch accuracy: 78.9%\n", - "Validation accuracy: 77.8%\n", - "Minibatch loss at step 2500 : 1.10219\n", - "Minibatch accuracy: 80.5%\n", - "Validation accuracy: 78.0%\n", - "Minibatch loss at step 3000 : 0.758874\n", - "Minibatch accuracy: 82.8%\n", - "Validation accuracy: 78.8%\n", - "Test accuracy: 86.1%\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7omWxtvLLxik", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem\n", - "-------\n", - "\n", - "Turn the logistic regression example with SGD into a 1-hidden layer neural network with rectified linear units [nn.relu()](https://www.tensorflow.org/versions/r0.7/api_docs/python/nn.html#relu) and 1024 hidden nodes. This model should improve your validation / test accuracy.\n", - "\n", - "---" - ] - } - ] -} diff --git a/tensorflow/examples/udacity/3_regularization.ipynb b/tensorflow/examples/udacity/3_regularization.ipynb deleted file mode 100644 index 5dc6f148611..00000000000 --- a/tensorflow/examples/udacity/3_regularization.ipynb +++ /dev/null @@ -1,300 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "version": "0.3.2", - "views": {}, - "default_view": {}, - "name": "3_regularization.ipynb", - "provenance": [] - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "kR-4eNdK6lYS", - "colab_type": "text" - }, - "source": [ - "Deep Learning\n", - "=============\n", - "\n", - "Assignment 3\n", - "------------\n", - "\n", - "Previously in `2_fullyconnected.ipynb`, you trained a logistic regression and a neural network model.\n", - "\n", - "The goal of this assignment is to explore regularization techniques." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "JLpLa8Jt7Vu4", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "# These are all the modules we'll be using later. Make sure you can import them\n", - "# before proceeding further.\n", - "from __future__ import print_function\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "from six.moves import cPickle as pickle" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1HrCK6e17WzV", - "colab_type": "text" - }, - "source": [ - "First reload the data we generated in `1_notmnist.ipynb`." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y3-cj1bpmuxc", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 11777, - "status": "ok", - "timestamp": 1449849322348, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "e03576f1-ebbe-4838-c388-f1777bcc9873" - }, - "source": [ - "pickle_file = 'notMNIST.pickle'\n", - "\n", - "with open(pickle_file, 'rb') as f:\n", - " save = pickle.load(f)\n", - " train_dataset = save['train_dataset']\n", - " train_labels = save['train_labels']\n", - " valid_dataset = save['valid_dataset']\n", - " valid_labels = save['valid_labels']\n", - " test_dataset = save['test_dataset']\n", - " test_labels = save['test_labels']\n", - " del save # hint to help gc free up memory\n", - " print('Training set', train_dataset.shape, train_labels.shape)\n", - " print('Validation set', valid_dataset.shape, valid_labels.shape)\n", - " print('Test set', test_dataset.shape, test_labels.shape)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Training set (200000, 28, 28) (200000,)\n", - "Validation set (10000, 28, 28) (10000,)\n", - "Test set (18724, 28, 28) (18724,)\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L7aHrm6nGDMB", - "colab_type": "text" - }, - "source": [ - "Reformat into a shape that's more adapted to the models we're going to train:\n", - "- data as a flat matrix,\n", - "- labels as float 1-hot encodings." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IRSyYiIIGIzS", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 11728, - "status": "ok", - "timestamp": 1449849322356, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "3f8996ee-3574-4f44-c953-5c8a04636582" - }, - "source": [ - "image_size = 28\n", - "num_labels = 10\n", - "\n", - "def reformat(dataset, labels):\n", - " dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)\n", - " # Map 1 to [0.0, 1.0, 0.0 ...], 2 to [0.0, 0.0, 1.0 ...]\n", - " labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)\n", - " return dataset, labels\n", - "train_dataset, train_labels = reformat(train_dataset, train_labels)\n", - "valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)\n", - "test_dataset, test_labels = reformat(test_dataset, test_labels)\n", - "print('Training set', train_dataset.shape, train_labels.shape)\n", - "print('Validation set', valid_dataset.shape, valid_labels.shape)\n", - "print('Test set', test_dataset.shape, test_labels.shape)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Training set (200000, 784) (200000, 10)\n", - "Validation set (10000, 784) (10000, 10)\n", - "Test set (18724, 784) (18724, 10)\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "RajPLaL_ZW6w", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - 
"wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "def accuracy(predictions, labels):\n", - " return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))\n", - " / predictions.shape[0])" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sgLbUAQ1CW-1", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 1\n", - "---------\n", - "\n", - "Introduce and tune L2 regularization for both logistic and neural network models. Remember that L2 amounts to adding a penalty on the norm of the weights to the loss. In TensorFlow, you can compute the L2 loss for a tensor `t` using `nn.l2_loss(t)`. The right amount of regularization should improve your validation / test accuracy.\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "na8xX2yHZzNF", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 2\n", - "---------\n", - "Let's demonstrate an extreme case of overfitting. Restrict your training data to just a few batches. What happens?\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ww3SCBUdlkRc", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 3\n", - "---------\n", - "Introduce Dropout on the hidden layer of the neural network. Remember: Dropout should only be introduced during training, not evaluation, otherwise your evaluation results would be stochastic as well. TensorFlow provides `nn.dropout()` for that, but you have to make sure it's only inserted during training.\n", - "\n", - "What happens to our extreme overfitting case?\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-b1hTz3VWZjw", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 4\n", - "---------\n", - "\n", - "Try to get the best performance you can using a multi-layer model! The best reported test accuracy using a deep network is [97.1%](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html?showComment=1391023266211#c8758720086795711595).\n", - "\n", - "One avenue you can explore is to add multiple layers.\n", - "\n", - "Another one is to use learning rate decay:\n", - "\n", - " global_step = tf.Variable(0) # count the number of steps taken.\n", - " learning_rate = tf.train.exponential_decay(0.5, global_step, ...)\n", - " optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)\n", - " \n", - " ---\n" - ] - } - ] -} diff --git a/tensorflow/examples/udacity/4_convolutions.ipynb b/tensorflow/examples/udacity/4_convolutions.ipynb deleted file mode 100644 index d607dddbb2d..00000000000 --- a/tensorflow/examples/udacity/4_convolutions.ipynb +++ /dev/null @@ -1,465 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "version": "0.3.2", - "views": {}, - "default_view": {}, - "name": "4_convolutions.ipynb", - "provenance": [] - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "4embtkV0pNxM", - "colab_type": "text" - }, - "source": [ - "Deep Learning\n", - "=============\n", - "\n", - "Assignment 4\n", - "------------\n", - "\n", - "Previously in `2_fullyconnected.ipynb` and `3_regularization.ipynb`, we trained fully connected networks to classify [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) characters.\n", - "\n", - "The goal of this assignment is make the neural network convolutional." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "tm2CQN_Cpwj0", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "# These are all the modules we'll be using later. Make sure you can import them\n", - "# before proceeding further.\n", - "from __future__ import print_function\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "from six.moves import cPickle as pickle\n", - "from six.moves import range" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "y3-cj1bpmuxc", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 11948, - "status": "ok", - "timestamp": 1446658914837, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "016b1a51-0290-4b08-efdb-8c95ffc3cd01" - }, - "source": [ - "pickle_file = 'notMNIST.pickle'\n", - "\n", - "with open(pickle_file, 'rb') as f:\n", - " save = pickle.load(f)\n", - " train_dataset = save['train_dataset']\n", - " train_labels = save['train_labels']\n", - " valid_dataset = save['valid_dataset']\n", - " valid_labels = save['valid_labels']\n", - " test_dataset = save['test_dataset']\n", - " test_labels = save['test_labels']\n", - " del save # hint to help gc free up memory\n", - " print('Training set', train_dataset.shape, train_labels.shape)\n", - " print('Validation set', valid_dataset.shape, valid_labels.shape)\n", - " print('Test set', test_dataset.shape, test_labels.shape)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Training set (200000, 28, 28) (200000,)\n", - "Validation set (10000, 28, 28) (10000,)\n", - "Test set (18724, 28, 28) (18724,)\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L7aHrm6nGDMB", - "colab_type": "text" - }, - "source": [ - "Reformat into a TensorFlow-friendly shape:\n", - "- convolutions need the image data formatted as a cube (width by height by #channels)\n", - "- labels as float 1-hot encodings." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IRSyYiIIGIzS", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 11952, - "status": "ok", - "timestamp": 1446658914857, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "650a208c-8359-4852-f4f5-8bf10e80ef6c" - }, - "source": [ - "image_size = 28\n", - "num_labels = 10\n", - "num_channels = 1 # grayscale\n", - "\n", - "import numpy as np\n", - "\n", - "def reformat(dataset, labels):\n", - " dataset = dataset.reshape(\n", - " (-1, image_size, image_size, num_channels)).astype(np.float32)\n", - " labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)\n", - " return dataset, labels\n", - "train_dataset, train_labels = reformat(train_dataset, train_labels)\n", - "valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)\n", - "test_dataset, test_labels = reformat(test_dataset, test_labels)\n", - "print('Training set', train_dataset.shape, train_labels.shape)\n", - "print('Validation set', valid_dataset.shape, valid_labels.shape)\n", - "print('Test set', test_dataset.shape, test_labels.shape)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Training set (200000, 28, 28, 1) (200000, 10)\n", - "Validation set (10000, 28, 28, 1) (10000, 10)\n", - "Test set (18724, 28, 28, 1) (18724, 10)\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "AgQDIREv02p1", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "def accuracy(predictions, labels):\n", - " return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))\n", - " / predictions.shape[0])" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5rhgjmROXu2O", - "colab_type": "text" - }, - "source": [ - "Let's build a small network with two convolutional layers, followed by one fully connected layer. Convolutional networks are more expensive computationally, so we'll limit its depth and number of fully connected nodes." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IZYv70SvvOan", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "batch_size = 16\n", - "patch_size = 5\n", - "depth = 16\n", - "num_hidden = 64\n", - "\n", - "graph = tf.Graph()\n", - "\n", - "with graph.as_default():\n", - "\n", - " # Input data.\n", - " tf_train_dataset = tf.placeholder(\n", - " tf.float32, shape=(batch_size, image_size, image_size, num_channels))\n", - " tf_train_labels = tf.placeholder(tf.float32, shape=(batch_size, num_labels))\n", - " tf_valid_dataset = tf.constant(valid_dataset)\n", - " tf_test_dataset = tf.constant(test_dataset)\n", - " \n", - " # Variables.\n", - " layer1_weights = tf.Variable(tf.truncated_normal(\n", - " [patch_size, patch_size, num_channels, depth], stddev=0.1))\n", - " layer1_biases = tf.Variable(tf.zeros([depth]))\n", - " layer2_weights = tf.Variable(tf.truncated_normal(\n", - " [patch_size, patch_size, depth, depth], stddev=0.1))\n", - " layer2_biases = tf.Variable(tf.constant(1.0, shape=[depth]))\n", - " layer3_weights = tf.Variable(tf.truncated_normal(\n", - " [image_size // 4 * image_size // 4 * depth, num_hidden], stddev=0.1))\n", - " layer3_biases = tf.Variable(tf.constant(1.0, shape=[num_hidden]))\n", - " layer4_weights = tf.Variable(tf.truncated_normal(\n", - " [num_hidden, num_labels], stddev=0.1))\n", - " layer4_biases = tf.Variable(tf.constant(1.0, shape=[num_labels]))\n", - " \n", - " # Model.\n", - " def model(data):\n", - " conv = tf.nn.conv2d(data, layer1_weights, [1, 2, 2, 1], padding='SAME')\n", - " hidden = tf.nn.relu(conv + layer1_biases)\n", - " conv = tf.nn.conv2d(hidden, layer2_weights, [1, 2, 2, 1], padding='SAME')\n", - " hidden = tf.nn.relu(conv + layer2_biases)\n", - " shape = hidden.get_shape().as_list()\n", - " reshape = tf.reshape(hidden, [shape[0], shape[1] * shape[2] * shape[3]])\n", - " hidden = tf.nn.relu(tf.matmul(reshape, layer3_weights) + layer3_biases)\n", - " return tf.matmul(hidden, layer4_weights) + layer4_biases\n", - " \n", - " # Training computation.\n", - " logits = model(tf_train_dataset)\n", - " loss = tf.reduce_mean(\n", - " tf.nn.softmax_cross_entropy_with_logits(labels=tf_train_labels, logits=logits))\n", - " \n", - " # Optimizer.\n", - " optimizer = tf.train.GradientDescentOptimizer(0.05).minimize(loss)\n", - " \n", - " # Predictions for the training, validation, and test data.\n", - " train_prediction = tf.nn.softmax(logits)\n", - " valid_prediction = tf.nn.softmax(model(tf_valid_dataset))\n", - " test_prediction = tf.nn.softmax(model(tf_test_dataset))" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "noKFb2UovVFR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 37 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 63292, - "status": "ok", - "timestamp": 1446658966251, - "user": { - "color": "", - "displayName": "", - "isAnonymous": false, - "isMe": true, - "permissionId": "", - "photoUrl": "", - "sessionId": "0", - "userId": "" - }, - "user_tz": 480 - }, - "outputId": "28941338-2ef9-4088-8bd1-44295661e628" - }, - "source": [ - "num_steps = 1001\n", - "\n", - "with tf.Session(graph=graph) as session:\n", - " tf.global_variables_initializer().run()\n", - " print('Initialized')\n", - " for step in range(num_steps):\n", - " offset = (step * batch_size) % (train_labels.shape[0] - 
batch_size)\n", - " batch_data = train_dataset[offset:(offset + batch_size), :, :, :]\n", - " batch_labels = train_labels[offset:(offset + batch_size), :]\n", - " feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels}\n", - " _, l, predictions = session.run(\n", - " [optimizer, loss, train_prediction], feed_dict=feed_dict)\n", - " if (step % 50 == 0):\n", - " print('Minibatch loss at step %d: %f' % (step, l))\n", - " print('Minibatch accuracy: %.1f%%' % accuracy(predictions, batch_labels))\n", - " print('Validation accuracy: %.1f%%' % accuracy(\n", - " valid_prediction.eval(), valid_labels))\n", - " print('Test accuracy: %.1f%%' % accuracy(test_prediction.eval(), test_labels))" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Initialized\n", - "Minibatch loss at step 0 : 3.51275\n", - "Minibatch accuracy: 6.2%\n", - "Validation accuracy: 12.8%\n", - "Minibatch loss at step 50 : 1.48703\n", - "Minibatch accuracy: 43.8%\n", - "Validation accuracy: 50.4%\n", - "Minibatch loss at step 100 : 1.04377\n", - "Minibatch accuracy: 68.8%\n", - "Validation accuracy: 67.4%\n", - "Minibatch loss at step 150 : 0.601682\n", - "Minibatch accuracy: 68.8%\n", - "Validation accuracy: 73.0%\n", - "Minibatch loss at step 200 : 0.898649\n", - "Minibatch accuracy: 75.0%\n", - "Validation accuracy: 77.8%\n", - "Minibatch loss at step 250 : 1.3637\n", - "Minibatch accuracy: 56.2%\n", - "Validation accuracy: 75.4%\n", - "Minibatch loss at step 300 : 1.41968\n", - "Minibatch accuracy: 62.5%\n", - "Validation accuracy: 76.0%\n", - "Minibatch loss at step 350 : 0.300648\n", - "Minibatch accuracy: 81.2%\n", - "Validation accuracy: 80.2%\n", - "Minibatch loss at step 400 : 1.32092\n", - "Minibatch accuracy: 56.2%\n", - "Validation accuracy: 80.4%\n", - "Minibatch loss at step 450 : 0.556701\n", - "Minibatch accuracy: 81.2%\n", - "Validation accuracy: 79.4%\n", - "Minibatch loss at step 500 : 1.65595\n", - "Minibatch accuracy: 43.8%\n", - "Validation accuracy: 79.6%\n", - "Minibatch loss at step 550 : 1.06995\n", - "Minibatch accuracy: 75.0%\n", - "Validation accuracy: 81.2%\n", - "Minibatch loss at step 600 : 0.223684\n", - "Minibatch accuracy: 100.0%\n", - "Validation accuracy: 82.3%\n", - "Minibatch loss at step 650 : 0.619602\n", - "Minibatch accuracy: 87.5%\n", - "Validation accuracy: 81.8%\n", - "Minibatch loss at step 700 : 0.812091\n", - "Minibatch accuracy: 75.0%\n", - "Validation accuracy: 82.4%\n", - "Minibatch loss at step 750 : 0.276302\n", - "Minibatch accuracy: 87.5%\n", - "Validation accuracy: 82.3%\n", - "Minibatch loss at step 800 : 0.450241\n", - "Minibatch accuracy: 81.2%\n", - "Validation accuracy: 82.3%\n", - "Minibatch loss at step 850 : 0.137139\n", - "Minibatch accuracy: 93.8%\n", - "Validation accuracy: 82.3%\n", - "Minibatch loss at step 900 : 0.52664\n", - "Minibatch accuracy: 75.0%\n", - "Validation accuracy: 82.2%\n", - "Minibatch loss at step 950 : 0.623835\n", - "Minibatch accuracy: 87.5%\n", - "Validation accuracy: 82.1%\n", - "Minibatch loss at step 1000 : 0.243114\n", - "Minibatch accuracy: 93.8%\n", - "Validation accuracy: 82.9%\n", - "Test accuracy: 90.0%\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KedKkn4EutIK", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 1\n", - "---------\n", - "\n", - "The convolutional model above uses convolutions with stride 2 to reduce the dimensionality. 
Replace the strides by a max pooling operation (`nn.max_pool()`) of stride 2 and kernel size 2.\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "klf21gpbAgb-", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 2\n", - "---------\n", - "\n", - "Try to get the best performance you can using a convolutional net. Look for example at the classic [LeNet5](http://yann.lecun.com/exdb/lenet/) architecture, adding Dropout, and/or adding learning rate decay.\n", - "\n", - "---" - ] - } - ] -} diff --git a/tensorflow/examples/udacity/5_word2vec.ipynb b/tensorflow/examples/udacity/5_word2vec.ipynb deleted file mode 100644 index 3b43d1fb55e..00000000000 --- a/tensorflow/examples/udacity/5_word2vec.ipynb +++ /dev/null @@ -1,896 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "version": "0.3.2", - "views": {}, - "default_view": {}, - "name": "5_word2vec.ipynb", - "provenance": [] - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "D7tqLMoKF6uq", - "colab_type": "text" - }, - "source": [ - "Deep Learning\n", - "=============\n", - "\n", - "Assignment 5\n", - "------------\n", - "\n", - "The goal of this assignment is to train a Word2Vec skip-gram model over [Text8](http://mattmahoney.net/dc/textdata) data." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "0K1ZyLn04QZf", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "# These are all the modules we'll be using later. Make sure you can import them\n", - "# before proceeding further.\n", - "%matplotlib inline\n", - "from __future__ import print_function\n", - "import collections\n", - "import math\n", - "import numpy as np\n", - "import os\n", - "import random\n", - "import tensorflow as tf\n", - "import zipfile\n", - "from matplotlib import pylab\n", - "from six.moves import range\n", - "from six.moves.urllib.request import urlretrieve\n", - "from sklearn.manifold import TSNE" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aCjPJE944bkV", - "colab_type": "text" - }, - "source": [ - "Download the data from the source website if necessary." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "RJ-o3UBUFtCw", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 14640, - "status": "ok", - "timestamp": 1445964482948, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2f1ffade4c9f20de", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "c4ec222c-80b5-4298-e635-93ca9f79c3b7" - }, - "source": [ - "url = 'http://mattmahoney.net/dc/'\n", - "\n", - "def maybe_download(filename, expected_bytes):\n", - " \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n", - " if not os.path.exists(filename):\n", - " filename, _ = urlretrieve(url + filename, filename)\n", - " statinfo = os.stat(filename)\n", - " if statinfo.st_size == expected_bytes:\n", - " print('Found and verified %s' % filename)\n", - " else:\n", - " print(statinfo.st_size)\n", - " raise Exception(\n", - " 'Failed to verify ' + filename + '. Can you get to it with a browser?')\n", - " return filename\n", - "\n", - "filename = maybe_download('text8.zip', 31344016)" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Found and verified text8.zip\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Zqz3XiqI4mZT", - "colab_type": "text" - }, - "source": [ - "Read the data into a string." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Mvf09fjugFU_", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 28844, - "status": "ok", - "timestamp": 1445964497165, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2f1ffade4c9f20de", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "e3a928b4-1645-4fe8-be17-fcf47de5716d" - }, - "source": [ - "def read_data(filename):\n", - " \"\"\"Extract the first file enclosed in a zip file as a list of words\"\"\"\n", - " with zipfile.ZipFile(filename) as f:\n", - " data = tf.compat.as_str(f.read(f.namelist()[0])).split()\n", - " return data\n", - " \n", - "words = read_data(filename)\n", - "print('Data size %d' % len(words))" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Data size 17005207\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Zdw6i4F8glpp", - "colab_type": "text" - }, - "source": [ - "Build the dictionary and replace rare words with UNK token." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "gAL1EECXeZsD", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 28849, - "status": "ok", - "timestamp": 1445964497178, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2f1ffade4c9f20de", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "3fb4ecd1-df67-44b6-a2dc-2291730970b2" - }, - "source": [ - "vocabulary_size = 50000\n", - "\n", - "def build_dataset(words):\n", - " count = [['UNK', -1]]\n", - " count.extend(collections.Counter(words).most_common(vocabulary_size - 1))\n", - " dictionary = dict()\n", - " for word, _ in count:\n", - " dictionary[word] = len(dictionary)\n", - " data = list()\n", - " unk_count = 0\n", - " for word in words:\n", - " if word in dictionary:\n", - " index = dictionary[word]\n", - " else:\n", - " index = 0 # dictionary['UNK']\n", - " unk_count = unk_count + 1\n", - " data.append(index)\n", - " count[0][1] = unk_count\n", - " reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) \n", - " return data, count, dictionary, reverse_dictionary\n", - "\n", - "data, count, dictionary, reverse_dictionary = build_dataset(words)\n", - "print('Most common words (+UNK)', count[:5])\n", - "print('Sample data', data[:10])\n", - "del words # Hint to reduce memory." - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Most common words (+UNK) [['UNK', 418391], ('the', 1061396), ('of', 593677), ('and', 416629), ('one', 411764)]\n", - "Sample data [5243, 3083, 12, 6, 195, 2, 3136, 46, 59, 156]\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lFwoyygOmWsL", - "colab_type": "text" - }, - "source": [ - "Function to generate a training batch for the skip-gram model." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "w9APjA-zmfjV", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 113, - "status": "ok", - "timestamp": 1445964901989, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2f1ffade4c9f20de", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "67cccb02-cdaf-4e47-d489-43bcc8d57bb8" - }, - "source": [ - "data_index = 0\n", - "\n", - "def generate_batch(batch_size, num_skips, skip_window):\n", - " global data_index\n", - " assert batch_size % num_skips == 0\n", - " assert num_skips <= 2 * skip_window\n", - " batch = np.ndarray(shape=(batch_size), dtype=np.int32)\n", - " labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)\n", - " span = 2 * skip_window + 1 # [ skip_window target skip_window ]\n", - " buffer = collections.deque(maxlen=span)\n", - " for _ in range(span):\n", - " buffer.append(data[data_index])\n", - " data_index = (data_index + 1) % len(data)\n", - " for i in range(batch_size // num_skips):\n", - " target = skip_window # target label at the center of the buffer\n", - " targets_to_avoid = [ skip_window ]\n", - " for j in range(num_skips):\n", - " while target in targets_to_avoid:\n", - " target = random.randint(0, span - 1)\n", - " targets_to_avoid.append(target)\n", - " batch[i * num_skips + j] = buffer[skip_window]\n", - " labels[i * num_skips + j, 0] = buffer[target]\n", - " buffer.append(data[data_index])\n", - " data_index = (data_index + 1) % len(data)\n", - " return batch, labels\n", - "\n", - "print('data:', [reverse_dictionary[di] for di in data[:8]])\n", - "\n", - "for num_skips, skip_window in [(2, 1), (4, 2)]:\n", - " data_index = 0\n", - " batch, labels = generate_batch(batch_size=8, num_skips=num_skips, skip_window=skip_window)\n", - " print('\\nwith num_skips = %d and skip_window = %d:' % (num_skips, skip_window))\n", - " print(' batch:', [reverse_dictionary[bi] for bi in batch])\n", - " print(' labels:', [reverse_dictionary[li] for li in labels.reshape(8)])" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "data: ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first']\n", - "\n", - "with num_skips = 2 and skip_window = 1:\n", - " batch: ['originated', 'originated', 'as', 'as', 'a', 'a', 'term', 'term']\n", - " labels: ['as', 'anarchism', 'a', 'originated', 'term', 'as', 'a', 'of']\n", - "\n", - "with num_skips = 4 and skip_window = 2:\n", - " batch: ['as', 'as', 'as', 'as', 'a', 'a', 'a', 'a']\n", - " labels: ['anarchism', 'originated', 'term', 'a', 'as', 'of', 'originated', 'term']\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ofd1MbBuwiva", - "colab_type": "text" - }, - "source": [ - "Train a skip-gram model." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8pQKsV4Vwlzy", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "batch_size = 128\n", - "embedding_size = 128 # Dimension of the embedding vector.\n", - "skip_window = 1 # How many words to consider left and right.\n", - "num_skips = 2 # How many times to reuse an input to generate a label.\n", - "# We pick a random validation set to sample nearest neighbors. here we limit the\n", - "# validation samples to the words that have a low numeric ID, which by\n", - "# construction are also the most frequent. \n", - "valid_size = 16 # Random set of words to evaluate similarity on.\n", - "valid_window = 100 # Only pick dev samples in the head of the distribution.\n", - "valid_examples = np.array(random.sample(range(valid_window), valid_size))\n", - "num_sampled = 64 # Number of negative examples to sample.\n", - "\n", - "graph = tf.Graph()\n", - "\n", - "with graph.as_default(), tf.device('/cpu:0'):\n", - "\n", - " # Input data.\n", - " train_dataset = tf.placeholder(tf.int32, shape=[batch_size])\n", - " train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])\n", - " valid_dataset = tf.constant(valid_examples, dtype=tf.int32)\n", - " \n", - " # Variables.\n", - " embeddings = tf.Variable(\n", - " tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))\n", - " softmax_weights = tf.Variable(\n", - " tf.truncated_normal([vocabulary_size, embedding_size],\n", - " stddev=1.0 / math.sqrt(embedding_size)))\n", - " softmax_biases = tf.Variable(tf.zeros([vocabulary_size]))\n", - " \n", - " # Model.\n", - " # Look up embeddings for inputs.\n", - " embed = tf.nn.embedding_lookup(embeddings, train_dataset)\n", - " # Compute the softmax loss, using a sample of the negative labels each time.\n", - " loss = tf.reduce_mean(\n", - " tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases, inputs=embed,\n", - " labels=train_labels, num_sampled=num_sampled, num_classes=vocabulary_size))\n", - "\n", - " # Optimizer.\n", - " # Note: The optimizer will optimize the softmax_weights AND the embeddings.\n", - " # This is because the embeddings are defined as a variable quantity and the\n", - " # optimizer's `minimize` method will by default modify all variable quantities \n", - " # that contribute to the tensor it is passed.\n", - " # See docs on `tf.train.Optimizer.minimize()` for more details.\n", - " optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss)\n", - " \n", - " # Compute the similarity between minibatch examples and all embeddings.\n", - " # We use the cosine distance:\n", - " norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True))\n", - " normalized_embeddings = embeddings / norm\n", - " valid_embeddings = tf.nn.embedding_lookup(\n", - " normalized_embeddings, valid_dataset)\n", - " similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings))" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "1bQFGceBxrWW", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 23 - }, - { - "item_id": 48 - }, - { - "item_id": 61 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 436189, - "status": "ok", - "timestamp": 1445965429787, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - 
"permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2f1ffade4c9f20de", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "5ebd6d9a-33c6-4bcd-bf6d-252b0b6055e4" - }, - "source": [ - "num_steps = 100001\n", - "\n", - "with tf.Session(graph=graph) as session:\n", - " tf.global_variables_initializer().run()\n", - " print('Initialized')\n", - " average_loss = 0\n", - " for step in range(num_steps):\n", - " batch_data, batch_labels = generate_batch(\n", - " batch_size, num_skips, skip_window)\n", - " feed_dict = {train_dataset : batch_data, train_labels : batch_labels}\n", - " _, l = session.run([optimizer, loss], feed_dict=feed_dict)\n", - " average_loss += l\n", - " if step % 2000 == 0:\n", - " if step > 0:\n", - " average_loss = average_loss / 2000\n", - " # The average loss is an estimate of the loss over the last 2000 batches.\n", - " print('Average loss at step %d: %f' % (step, average_loss))\n", - " average_loss = 0\n", - " # note that this is expensive (~20% slowdown if computed every 500 steps)\n", - " if step % 10000 == 0:\n", - " sim = similarity.eval()\n", - " for i in range(valid_size):\n", - " valid_word = reverse_dictionary[valid_examples[i]]\n", - " top_k = 8 # number of nearest neighbors\n", - " nearest = (-sim[i, :]).argsort()[1:top_k+1]\n", - " log = 'Nearest to %s:' % valid_word\n", - " for k in range(top_k):\n", - " close_word = reverse_dictionary[nearest[k]]\n", - " log = '%s %s,' % (log, close_word)\n", - " print(log)\n", - " final_embeddings = normalized_embeddings.eval()" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Initialized\n", - "Average loss at step 0 : 8.58149623871\n", - "Nearest to been: unfavourably, marmara, ancestral, legal, bogart, glossaries, worst, rooms,\n", - "Nearest to time: conformist, strawberries, sindhi, waterfall, xia, nominates, psp, sensitivity,\n", - "Nearest to over: overlord, panda, golden, semigroup, rawlings, involved, shreveport, handling,\n", - "Nearest to not: hymenoptera, reintroducing, lamiaceae, because, davao, omnipotent, combustion, debilitating,\n", - "Nearest to three: catalog, koza, gn, braque, holstein, postgresql, luddite, justine,\n", - "Nearest to if: chilled, vince, fiddler, represented, sandinistas, happiness, lya, glands,\n", - "Nearest to there: coast, photosynthetic, kimmei, legally, inner, illyricum, formats, fullmetal,\n", - "Nearest to between: chuvash, prinz, suitability, wolfe, guideline, computability, diminutive, paulo,\n", - "Nearest to from: tanganyika, workshop, elphinstone, spearhead, resurrected, kevlar, shangri, loves,\n", - "Nearest to state: sextus, wuppertal, glaring, inches, unrounded, courageous, adler, connie,\n", - "Nearest to on: gino, phocas, rhine, jg, macrocosm, jackass, jays, theorie,\n", - "Nearest to and: standings, towed, reyes, willard, equality, juggling, wladislaus, faked,\n", - "Nearest to eight: gresham, dogg, moko, tennis, superseded, telegraphy, scramble, vinod,\n", - "Nearest to they: prisons, divisor, coder, ribeira, willingness, factional, nne, lotta,\n", - "Nearest to more: blues, fur, sterling, tangier, khwarizmi, discouraged, cal, deicide,\n", - "Nearest to other: enemies, bogged, brassicaceae, lascaux, dispense, alexandrians, crimea, dou,\n", - "Average loss at step 2000 : 4.39983723116\n", - "Average loss at step 4000 : 3.86921076906\n", - "Average loss at step 6000 : 3.72542127335\n", - "Average loss at step 8000 
: 3.57835536212\n", - "Average loss at step 10000 : 3.61056993055\n", - "Nearest to been: glossaries, legal, unfavourably, be, hadad, wore, scarcity, were,\n", - "Nearest to time: strawberries, conformist, gleichschaltung, waterfall, molality, nominates, baal, dole,\n", - "Nearest to over: golden, semigroup, catus, motorways, brick, shehri, mussolini, overlord,\n", - "Nearest to not: hinayana, it, often, they, boots, also, noaa, lindsey,\n", - "Nearest to three: four, seven, six, five, nine, eight, two, zero,\n", - "Nearest to if: glands, euros, wallpaper, redefine, toho, confuse, unsound, shepherd,\n", - "Nearest to there: it, they, fullmetal, pace, legally, harpsichord, mma, bug,\n", - "Nearest to between: chuvash, wandering, from, kirsch, pursuant, eurocents, suitability, jackie,\n", - "Nearest to from: into, in, workshop, to, at, misogynist, elphinstone, spearhead,\n", - "Nearest to state: sextus, glaring, connie, adler, esoteric, didactic, handedness, presidents,\n", - "Nearest to on: in, at, for, ruminants, wakefulness, torrey, foley, gino,\n", - "Nearest to and: or, who, but, zelda, of, for, thirst, chisel,\n", - "Nearest to eight: nine, six, seven, five, four, three, zero, two,\n", - "Nearest to they: he, prisons, there, we, hydrate, it, not, cumbersome,\n", - "Nearest to more: skye, blues, trypomastigotes, deicide, most, readable, used, sterling,\n", - "Nearest to other: trochaic, hush, surveyors, joachim, differentiation, attackers, reverence, attestation,\n", - "Average loss at step 12000 : 3.66169466591\n", - "Average loss at step 14000 : 3.60342905837\n", - "Average loss at step 16000 : 3.57761328053\n", - "Average loss at step 18000 : 3.57667332476\n", - "Average loss at step 20000 : 3.53310145146\n", - "Nearest to been: be, become, was, hadad, unfavourably, were, wore, partido,\n", - "Nearest to time: gleichschaltung, strawberries, year, nominates, conformist, etch, admittedly, treasuries,\n", - "Nearest to over: golden, semigroup, motorways, rawlings, triangle, trey, ustawa, mattingly,\n", - "Nearest to not: they, boots, often, dieppe, still, hinayana, nearly, be,\n", - "Nearest to three: two, four, five, seven, eight, six, nine, one,\n", - "Nearest to if: wallpaper, euros, before, toho, unsound, so, bg, pfc,\n", - "Nearest to there: they, it, he, usually, which, we, not, transactions,\n", - "Nearest to between: from, with, about, near, reactance, eurocents, wandering, voltaire,\n", - "Nearest to from: into, workshop, by, between, in, on, elphinstone, under,\n", - "Nearest to state: glaring, esoteric, succeeding, sextus, vorarlberg, presidents, depends, connie,\n", - "Nearest to on: in, at, upon, during, from, janis, foley, nubian,\n", - "Nearest to and: or, thirst, but, where, s, who, pfaff, including,\n", - "Nearest to eight: nine, seven, six, five, four, three, zero, one,\n", - "Nearest to they: there, he, we, not, it, you, prisons, who,\n", - "Nearest to more: less, most, deicide, skye, trypomastigotes, interventionism, toed, drummond,\n", - "Nearest to other: such, joachim, hush, attackers, surveyors, trochaic, differentiation, reverence,\n", - "Average loss at step 22000 : 3.59519316927\n", - "Average loss at step 24000 : 3.55378576797\n", - "Average loss at step 26000 : 3.56455037558\n", - "Average loss at step 28000 : 3.5040882225\n", - "Average loss at step 30000 : 3.39208897972\n", - "Nearest to been: become, be, were, was, spotless, hadad, by, hausdorff,\n", - "Nearest to time: gleichschaltung, year, day, nominates, jesus, strawberries, way, admittedly,\n", - 
"Nearest to over: golden, semigroup, motorways, rawlings, interventionism, counternarcotics, adaption, brick,\n", - "Nearest to not: often, they, it, never, still, nor, boots, pki,\n", - "Nearest to three: four, six, two, eight, five, seven, nine, zero,\n", - "Nearest to if: when, before, so, should, toho, where, bg, wallpaper,\n", - "Nearest to there: they, it, which, usually, he, that, also, now,\n", - "Nearest to between: with, from, in, panasonic, presupposes, churchmen, hijacking, where,\n", - "Nearest to from: into, elphinstone, workshop, between, through, speculates, sosa, in,\n", - "Nearest to state: esoteric, glaring, presidents, vorarlberg, atmosphere, succeeding, lute, connie,\n", - "Nearest to on: upon, in, janis, during, torrey, against, infield, catalans,\n", - "Nearest to and: or, thirst, in, but, of, sobib, cleaves, including,\n", - "Nearest to eight: nine, six, four, seven, three, zero, five, one,\n", - "Nearest to they: we, there, he, you, it, these, who, i,\n", - "Nearest to more: less, most, deicide, faster, toed, very, skye, tonic,\n", - "Nearest to other: different, attackers, joachim, various, such, many, differentiation, these,\n", - "Average loss at step 32000 : 3.49501452419\n", - "Average loss at step 34000 : 3.48593705952\n", - "Average loss at step 36000 : 3.50112806576\n", - "Average loss at step" - ], - "name": "stdout" - }, - { - "output_type": "stream", - "text": [ - " 38000 : 3.49244426501\n", - "Average loss at step 40000 : 3.3890105716\n", - "Nearest to been: become, be, were, was, jolie, hausdorff, spotless, had,\n", - "Nearest to time: year, way, gleichschaltung, period, day, stanislav, stage, outcome,\n", - "Nearest to over: through, semigroup, rawlings, golden, about, brick, on, motorways,\n", - "Nearest to not: they, radiated, never, pki, still, omnipotent, hinayana, really,\n", - "Nearest to three: four, six, five, two, seven, eight, one, nine,\n", - "Nearest to if: when, before, where, then, bg, because, can, should,\n", - "Nearest to there: they, it, he, usually, this, typically, still, often,\n", - "Nearest to between: with, in, from, about, against, churchmen, johansen, presupposes,\n", - "Nearest to from: into, through, elphinstone, in, workshop, between, suing, under,\n", - "Nearest to state: esoteric, presidents, atmosphere, vorarlberg, lute, succeeding, glaring, didactic,\n", - "Nearest to on: upon, at, in, during, unitarians, under, catalans, batavians,\n", - "Nearest to and: or, but, s, incapacitation, including, while, of, which,\n", - "Nearest to eight: nine, six, seven, four, five, three, one, two,\n", - "Nearest to they: we, he, there, you, she, i, not, it,\n", - "Nearest to more: less, most, deicide, toed, greater, faster, quite, longer,\n", - "Nearest to other: various, different, attackers, joachim, clutter, nz, trochaic, apulia,\n", - "Average loss at step 42000 : 3.45294014364\n", - "Average loss at step 44000 : 3.47660055941\n", - "Average loss at step 46000 : 3.47458503014\n", - "Average loss at step 48000 : 3.47261548793\n", - "Average loss at step 50000 : 3.45390708435\n", - "Nearest to been: become, be, had, was, were, hausdorff, prem, remained,\n", - "Nearest to time: way, year, period, stv, day, gleichschaltung, stage, outcome,\n", - "Nearest to over: through, golden, semigroup, about, brick, counternarcotics, theremin, mattingly,\n", - "Nearest to not: they, still, never, really, sometimes, it, kiwifruit, nearly,\n", - "Nearest to three: five, four, six, seven, two, eight, one, nine,\n", - "Nearest to if: when, before, 
where, because, connexion, though, so, whether,\n", - "Nearest to there: they, it, he, this, now, often, usually, still,\n", - "Nearest to between: with, from, fashioned, churchmen, panasonic, explores, within, racial,\n", - "Nearest to from: into, through, under, elphinstone, between, workshop, circumpolar, idiom,\n", - "Nearest to state: atmosphere, vorarlberg, esoteric, presidents, madhya, majority, moulin, bowmen,\n", - "Nearest to on: upon, in, catalans, tezuka, minotaurs, wakefulness, batavians, guglielmo,\n", - "Nearest to and: or, but, thirst, signifier, which, however, including, unattractive,\n", - "Nearest to eight: six, nine, seven, five, four, three, zero, two,\n", - "Nearest to they: we, there, he, you, it, she, these, not,\n", - "Nearest to more: less, most, quite, very, further, faster, toed, deicide,\n", - "Nearest to other: various, different, many, attackers, are, joachim, nihilo, reject,\n", - "Average loss at step 52000 : 3.43597227755\n", - "Average loss at step 54000 : 3.25126817495\n", - "Average loss at step 56000 : 3.35102432287\n", - "Average loss at step 58000 : 3.44654818082\n", - "Average loss at step 60000 : 3.4287913968\n", - "Nearest to been: become, be, was, prem, had, remained, hadad, stanislavsky,\n", - "Nearest to time: year, way, period, stv, barely, name, stage, restoring,\n", - "Nearest to over: about, through, golden, adaption, counternarcotics, up, mattingly, brick,\n", - "Nearest to not: still, never, nor, kiwifruit, they, nearly, therefore, rarely,\n", - "Nearest to three: two, five, four, six, seven, eight, one, nine,\n", - "Nearest to if: when, though, before, where, although, because, can, could,\n", - "Nearest to there: they, it, he, still, she, we, this, often,\n", - "Nearest to between: with, from, churchmen, among, ethical, within, vma, panasonic,\n", - "Nearest to from: through, into, under, during, between, in, suing, across,\n", - "Nearest to state: atmosphere, infringe, madhya, vorarlberg, government, bowmen, vargas, republic,\n", - "Nearest to on: upon, through, within, ridiculous, janis, in, under, over,\n", - "Nearest to and: or, while, including, but, of, like, whose, bannister,\n", - "Nearest to eight: nine, six, five, four, seven, zero, three, two,\n", - "Nearest to they: we, there, you, he, it, these, she, prisons,\n", - "Nearest to more: less, most, quite, further, toed, very, faster, rather,\n", - "Nearest to other: different, various, many, nihilo, these, amour, including, screenplays,\n", - "Average loss at step 62000 : 3.38358767056\n", - "Average loss at step 64000 : 3.41693099326\n", - "Average loss at step 66000 : 3.39588000977\n", - "Average loss at step 68000 : 3.35567189544\n", - "Average loss at step 70000 : 3.38878934443\n", - "Nearest to been: become, be, was, prem, remained, were, being, discounts,\n", - "Nearest to time: year, way, day, period, barely, ethos, stage, reason,\n", - "Nearest to over: about, through, fortunately, semigroup, theremin, off, loudest, up,\n", - "Nearest to not: still, nor, never, they, actually, nearly, unelected, therefore,\n", - "Nearest to three: five, two, four, six, seven, eight, nine, zero,\n", - "Nearest to if: when, though, before, where, because, then, after, since,\n", - "Nearest to there: they, it, he, often, she, we, usually, still,\n", - "Nearest to between: among, with, within, from, ethical, churchmen, racial, prentice,\n", - "Nearest to from: through, into, within, during, under, until, between, across,\n", - "Nearest to state: city, atmosphere, desks, surrounding, 
preservation, bohr, principal, republic,\n", - "Nearest to on: upon, tezuka, through, within, wakefulness, catalans, at, ingeborg,\n", - "Nearest to and: or, but, while, including, thirst, jerzy, massing, abadan,\n", - "Nearest to eight: seven, six, nine, five, four, three, two, zero,\n", - "Nearest to they: we, you, he, there, she, it, prisons, who,\n", - "Nearest to more: less, most, quite, very, faster, smaller, further, larger,\n", - "Nearest to other: various, different, some, screenplays, lab, many, including, debugging,\n", - "Average loss at step 72000 : 3.41103189731\n", - "Average loss at step 74000 : 3.44926435578\n", - "Average loss at step 76000 : 3.4423020488\n", - "Average loss at step 78000 : 3.41976813722\n", - "Average loss at step 80000 : 3.39511853886\n", - "Nearest to been: become, be, remained, was, grown, were, prem, already," - ], - "name": "stdout" - }, - { - "output_type": "stream", - "text": [ - "\n", - "Nearest to time: year, way, period, reason, barely, distance, stage, day,\n", - "Nearest to over: about, fortunately, through, semigroup, further, mattingly, rawlings, golden,\n", - "Nearest to not: still, they, nor, never, we, kiwifruit, noaa, really,\n", - "Nearest to three: five, two, seven, four, eight, six, nine, zero,\n", - "Nearest to if: when, where, though, before, since, because, although, follows,\n", - "Nearest to there: they, it, he, we, she, still, typically, actually,\n", - "Nearest to between: with, among, within, in, racial, around, from, serapeum,\n", - "Nearest to from: into, through, in, within, under, using, during, towards,\n", - "Nearest to state: city, atmosphere, ferro, vorarlberg, surrounding, republic, madhya, national,\n", - "Nearest to on: upon, poll, in, from, tezuka, janis, through, within,\n", - "Nearest to and: or, but, including, while, s, which, thirst, although,\n", - "Nearest to eight: nine, seven, six, five, four, three, zero, two,\n", - "Nearest to they: we, you, there, he, she, it, these, not,\n", - "Nearest to more: less, most, smaller, very, faster, quite, rather, larger,\n", - "Nearest to other: various, different, joachim, including, theos, smaller, individual, screenplays,\n", - "Average loss at step 82000 : 3.40933967865\n", - "Average loss at step 84000 : 3.41618054378\n", - "Average loss at step 86000 : 3.31485116804\n", - "Average loss at step 88000 : 3.37068593091\n", - "Average loss at step 90000 : 3.2785516749\n", - "Nearest to been: become, be, was, prem, remained, grown, recently, already,\n", - "Nearest to time: year, way, period, day, barely, battle, buds, name,\n", - "Nearest to over: through, about, fortunately, off, theremin, semigroup, extraterrestrial, mattingly,\n", - "Nearest to not: nor, still, never, otherwise, generally, separately, gown, hydrate,\n", - "Nearest to three: four, five, six, two, eight, seven, nine, zero,\n", - "Nearest to if: when, where, before, though, because, since, then, while,\n", - "Nearest to there: they, it, he, we, she, still, typically, fiorello,\n", - "Nearest to between: with, among, within, from, churchmen, prentice, racial, panasonic,\n", - "Nearest to from: through, into, across, during, towards, until, at, within,\n", - "Nearest to state: bohr, city, atmosphere, ferro, bowmen, republic, retaliation, vorarlberg,\n", - "Nearest to on: upon, in, tezuka, at, during, within, via, catalans,\n", - "Nearest to and: or, including, but, while, like, thirst, with, schuman,\n", - "Nearest to eight: seven, nine, six, five, four, three, zero, two,\n", - "Nearest to they: we, 
there, he, you, she, it, prisons, these,\n", - "Nearest to more: less, most, very, faster, larger, quite, smaller, better,\n", - "Nearest to other: different, various, tamara, prosthetic, including, individual, failing, restaurants,\n", - "Average loss at step 92000 : 3.40355363208\n", - "Average loss at step 94000 : 3.35647508007\n", - "Average loss at step 96000 : 3.34374570692\n", - "Average loss at step 98000 : 3.4230104093\n", - "Average loss at step 100000 : 3.36909827\n", - "Nearest to been: become, be, grown, was, being, already, remained, prem,\n", - "Nearest to time: way, year, day, period, years, days, mothersbaugh, separators,\n", - "Nearest to over: through, about, semigroup, further, fortunately, off, into, theremin,\n", - "Nearest to not: never, nor, still, dieppe, really, unelected, actually, now,\n", - "Nearest to three: four, two, five, seven, six, eight, nine, zero,\n", - "Nearest to if: when, though, where, before, is, abe, then, follows,\n", - "Nearest to there: they, it, he, we, still, she, typically, often,\n", - "Nearest to between: within, with, among, churchmen, around, explores, from, reactance,\n", - "Nearest to from: into, through, within, across, in, between, using, workshop,\n", - "Nearest to state: atmosphere, bohr, national, ferro, germ, desks, city, unpaid,\n", - "Nearest to on: upon, in, within, tezuka, janis, batavians, about, macrocosm,\n", - "Nearest to and: or, but, purview, thirst, sukkot, epr, including, honesty,\n", - "Nearest to eight: seven, nine, six, four, five, three, zero, one,\n", - "Nearest to they: we, there, you, he, she, prisons, it, these,\n", - "Nearest to more: less, most, very, quite, faster, larger, rather, smaller,\n", - "Nearest to other: various, different, tamara, theos, some, cope, many, others,\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "jjJXYA_XzV79", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "num_points = 400\n", - "\n", - "tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact')\n", - "two_d_embeddings = tsne.fit_transform(final_embeddings[1:num_points+1, :])" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "o_e0D_UezcDe", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 4763, - "status": "ok", - "timestamp": 1445965465525, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "2f1ffade4c9f20de", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "df22e4a5-e8ec-4e5e-d384-c6cf37c68c34" - }, - "source": [ - "def plot(embeddings, labels):\n", - " assert embeddings.shape[0] >= len(labels), 'More labels than embeddings'\n", - " pylab.figure(figsize=(15,15)) # in inches\n", - " for i, label in enumerate(labels):\n", - " x, y = embeddings[i,:]\n", - " pylab.scatter(x, y)\n", - " pylab.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points',\n", - " ha='right', va='bottom')\n", - " pylab.show()\n", - "\n", - "words = [reverse_dictionary[i] for i in range(1, 
num_points+1)]\n", - "plot(two_d_embeddings, words)" - ], - "outputs": [ - { - "output_type": "display_data", - "metadata": {}, - "data": { - "image/png": "iVBORw0KGgo... [base64-encoded PNG data for the 2-D t-SNE embedding scatter plot omitted; the blob is unreadable as text and is cut off in the source]
qzrXPMmBEMHz6KFy9esG/fbkxNK3Hy5DG0tbVfP4dsBg3ypHPnbqxbF8iZMwqP0JMnCfz8\n8wpq1rSjRYuGnDx5AZFIhIFBObp374mNjS1NmjTD2bkZEomExYsXcvHiBfz9l2JmVoWwsDOcOpWF\nnl42kyc3xcjI8LM9gy8JuVxO797bOHv2G0CEmlo08+ff+UcMuvf92d227XdmzlQlKckOC4vzLF9u\nSoMGNT7hDAU+hC/ld7PA+2NiovPefQTPnICAgIDAF83582cJCwshKyuLpKTnTJmykGXLnrN/v5gz\nZ0YzbFg/evTozf79e1iyxB8/P/9S5Nb/Or5//y4bN26jYkVTnjxJJDr6ISEhm//ZRZVCly492L9/\nL69evaBz525cvXqZgQM9S5RbiIi4wtWrlwkMDEFdXZ2RI78lNzcXADU1deXaJRIJQUGhXLlyiZMn\nj7Fz5zaWLl3NmDETlGNt3XqOKVNsyM62AuQ8eBDMb7/1/SJq/30KQkNPcfZsDrq6CsPW2LhswzY5\n+RV//GFNwXcrN9eCCxciPlv44pvo168xTZsmEhl5A0fHuhgYCAa7gMCXguC3FhAQEBD4omndui0h\nIZvp06c//ft/zeLFzzh/fiAPH7qRkvKSkycVm+n27Tvxxx/X3zpejRq1qFhRUSz9nw5e8fUdQlRU\nZInzBw7s5fr1q1y8eJ6oqEgaNWpCw4aN2L9/D1lZWQA8f/6MV69ekZmZgY6ODurq6vz5Zwy3b98q\n9V6ZmZmkp6fRuLEzI0eO5cGDeyXanD2b8dqQAxBx40ZNkpKSPtp6vyQ2bTrDDz/UYvfu3mzc6Mbg\nwcfe+P3Q1tbB0PD566N8QIaBQc4/Mte/Q6VKprRu3VAw5AQEvjAEz5yAgICAwH+CAsXMhISiYSqJ\nieVKtFVRUSE/X7ERz8/PRyr9K2dIQ6Nk+38CmUxWSi7XX4jFYurVq4+Oji4ikYj69RsRExPD0KFe\nAGhqajJt2hwaNmzCrl2/MnBgX6pU+Qo7u9rKMQqPnZGRwYQJY1577eSMHDm2xD2NjHIBKQXbBWPj\nRHR1rT/amr8kfv89i5wci9dHIv74w5YXL15gbGxcQnCmW7eedOrkQpMmzkgky0lNdUVffwUmJr0Z\nNGgjRkbGDB48jICA5Tx79pRRo8bRtGlzfH2H8N1347G2VpQBGTbsG8aPn4yVVbWyJyYgIPB/jWDM\nCQgICAj8JyhQzLSy0iUuTo5YnEJ2tj36+hFAxyKKmRUrmnL3biQuLm04e/Z0Ecn6wmhqapKZmfnW\ne2/evAE1NTX69BnAsmU/8fDhA5YuXc3Vq5fZv38PjRs7Exa2HrlcTuPGTRk2bCSgyNXr3r03V65c\nYuzYCUXG3L9/D2Fh69HW1qFateqoqkq4ffsP5s5dqGzTt+8A+vYtWTnc339ZqfMMDz+l/LeJiQlB\nQaFvXNfEiS48eBBCRMRX6OklM3Gi0evi4/9/GBrmUNiwNTJKRFdXIdBSXHCmZUsXsrOz6datDT/9\nNJ+srCw6dFhC48bOjB49nilTvmfdugCWLl1NdPQj5s2bQdOmzencuRsHD+7F2nocsbF/kpeXJxhy\nAgICb0Qw5gQEBAQE/hMUKGZu3LieOnW2AGY4OjYjO/ssHh6uRRQzu3XryaRJ4/D0dKNhw8aUK/dX\n0ebCjjE9PX1q17bH3b0/jRo5lymAYm/vyNatYfTpM4CoqEikUilSqZQbN65RpYo5AQErCA4OQ1tb\nh7FjfTlz5iTNmrUkOzubWrXs8PUdXWS8pKQkgoPXEBwchpaWNj4+HiQkxNOtWy8qVzb7oOckl8vZ\nvv0UmZlyWra0oWrVSmW21dTUZNOmAWRmZqKhofHZVAWLC9Z8DiZNas2jRyFcu2aOnl4ykyaZoKam\nBsD27VuUgjPPnj0jLi4OsVhMy5atEYlEaGlpoaqqqiyLYGVVDTU1NVRUVLC0tCIxMRGAVq3aEBq6\njuHDv2P//j106tT18yxWQEDgi0Ew5gQEBAQE/jN07NiFjh27FDs7sEQ7AwPDIrXLCjxljo5OODo6\nFWk7Y8bct97XxsaWu3cjyczMQE1NDVvbGkRFRXLz5nWcnZvj6OiEnp4+oKjhdv36NZo1a6nc8BdG\nLpdz584tHBzqKft07tyVuLhYRoz47q1zeRNyuZzRo39l69buyOWGWFjsJTg4m1q1LN/YT1NT843X\nPyVvCz/9p9DU1CQsbABZWVloaGgo51O64ExOEbEZABWVv7ZcIpEIiUQVUITPymSKotgaGho4OTXk\nzJmTnDhxlODgTf/gCgUEBL5EBAEUAQEBAQGBQty8+YAffjjIrFn7ePny1Tv1kUgkmJpW5sCBvdSu\nbU+dOnWJiLjM48fxmJqaFhPKkCs3+cU3/AUUP/WxdFiePXvGnj22yOUKkYvo6K5s2FBS+ORTcPjw\nAXx8PPDycmPRovnk5+fj7+/H4MHuDBrUj3XrApVt+/TpyurVy/H2HsjJk8cAhSEaEXGFyZPHK9td\nvnyBKVO+/0fmX0C5cuWKfGaFBWdiYqLLFJx5V7p27cGSJf7UqFFLWXZCQEBAoCwEY05AQEBAQOA1\nkZExeHk9Zc2afqxcOQA3t2PvlDMHYG9fly1bwqhb1xF7ewd27fqV6tVtqFGjFtevR5CSkoxMJuPo\n0XDq1nUscxyRSETNmnZcvx5BamoKUqmUEyeOfpT1icVixOL8Yvf79IqdMTHRHD9+hICAYEJCNiMS\niQkPP8iQISNYu3YD69dv4fr1CB49evB6TiL09PQJDg6jdet2ynOOjk7ExsaQkpIMwP79e+nSpfsn\nn/+baNiwCTKZjIED+xIYuFIpOFPcSC95XPo1GxtbtLW16dy526ebtICAwH8GIcxSQEBA4F/I/fv3\nSEp6TuPGzp97Kv9X7N4dRVxc39dHIiIiunP27BXatWv81r729g5s3BiCnV1t1NU1UFdXx97eASMj\nY4YO9WXUqKHI5XKaNGlG06aKwtFlhQ4aGRnj7T2Eb7/1Qltbh+rVbT4ozDA/Px+xWIyJiQl9+pxh\n40YLpNKKWFvvZPBgu7897rty9eol7t6NYvDgQQDk5uZiZGTE8ePh7NmzC5lMxosXSURHR2NpqRD8\naN26baljtW/ficOHD9CxY1du377F9OlzPvn834SqqmqpgjOFxWaKH3t7Dylx7eHDP4mMjMPaujz5\n+fk0aNDo00xYQEDgP4VgzAkICAj8C7l//y5370a+lzEnlUqRSIRf6x+Cjg5ADqBQbFRTe4qJid47\n9a1Xrz4nTvyuPN6yZafy323atKdNm/Yl+hTf8C9fHsi6dYHcuHGNfv1c6dSpK4GBKzE0NCIvLxcf\nH3dyc/No3rwl33zzLUCpsvhQVClz3LiJ1K5tD8CCBT1p2fIC6ekRNG/uRIUKRu/6eD6Ijh278O23\nI5THCQmPGTvWl7VrN6Ktrc38+bPIzf2rDlu5ckVLRBSEqnbq1I2
JE8egpqaGi0ubzybK8jHZsOEM\nc+fqI5MlU7HiTLy8/oWVxQUEBP6VCP/rCwgICJTCwYP72Lp1EyKRiGrVrBk8eCjz588iJSUFfX0D\npkyZToUKFZk3bybq6hrcv3+XV69eMmnSNA4c2EtU1B1q1rRTqie2bduMbt16cunSBQwNjZk1az76\n+vr4+g7B13cMtrY1SE5OxsfHnS1bdrJ2bQC5ubncvHmdQYO8adzYmcWLFxId/QiZTIq39xCaNm3B\ngQN7OXXqONnZ2eTn57N8eeBbVibwJnx8XDh/PoTjx5ujppaGh8cDHBz+2XC3jh27MGzYcB4+NKB1\na3OOHz/CkCEjuHr1EkFBG8jPz2fSpHHcuHENe3uHUmTxW6Orq1umUqZIJKJjx8aYmOjw/HnaP7Km\nevUaMGnSOPr1c8PAwIDU1BSePn2ChkY5tLS0ePnyBRcunMfBod5bxzI2NsbY2JjQ0GCWLl31D8z+\n03L27GnWrDlEcvJyjIyukpT0LWfPaiISBVK3riP16tVn27bNdO/eC3V1jc89XQEBgX8ZgjEnICAg\nUIxHjx6yYUMwgYEh6OrqkZqayty5M+jUqSsdOnRm//49LFnij5+fPwDp6WkEBoZw9uwpJk0aR0BA\nMBYWlgwe7M6DB/epVs2a7OxsbG1rMnLkWNavX0tIyBrGjJlQqkqfRCLBx2cYd+9GMnq0QtwhMHAl\nTk4NmDJlBmlpaQwZ4oGTU0NAEZIZGroVHZ2ixbIF3h81NTU2bnTlwYOHlCunQ5Uq/3ze0oIFvxMb\na0ZEhB1bt56mUaMKREXd4fLli3h5uQGQlZVNfHwc9vYObNoUSnj4QfT09Hn27Cnx8bHUrGlXqlLm\n56JqVQt8fIYxduwI8vPlqKqqMmbMBKpXt8HNrTfly1ekTh37N45R+OekbdsOpKSkYG5e9RPP/NPT\ntGlz8vJyAZDLFWvMy5MoPa8A27dvpX37Tu9lzBWE1goICPy3EYw5AQEBgWJERFzGxaUturqK8Dpd\nXV3u3PlDaby1b9+J1asVOTIikQhn52YAWFhYYWhohKWl1etjS548SaBaNWvEYrFSyKFdu45Mnfpm\nBT65XF5EAfHSpQucO3eaLVs2ApCXl8fTp08QiUQ4OTUQDLmPiFgspnp1689y7/T0dMLDzcjNHYCu\n7q/I5S/IzKyFXJ7PwIGedO/eq0j7All8LS0t1q/f/FoWX2EYlKWUWUB+fn6Z1z4FrVu3LZEHV6tW\n6fl627fvKXJc4OEu4ObN63Tt2uPjTvADWLcuEE1NLVxdi5bBuHHjGqNHD6dt2w5cvXoJNTUNxoz5\nnpCQNbx6lcyMGXOIjn6EhcUBYmObAKCunkD37pWUtfWSkp6TlPScUaOGoq9vwNKlq/H39yMqKpKc\nnGxatmytNPz69OlK69btuH79Cs7OLTh58jjBwWEAxMXFMmPGFOWxgIDAfwPBmBMQEBAohkgkKiYl\nr6C0c6AQQACFEaCmpqo8X7h+VPFxCjbZKioqyOWKTXXhfKHSmDdvEVWqmBc5d+fOrRK5RQJfFoVD\nei0sLFFXb4S6+jm0tBT5dKqqX9OwYQPmz5/NjRvXeP78GQkJj+nevReWllYkJT0jLS2Nr7/uQ3x8\nHPfv32Xz5o3K8X/++Udq1KhFx45dlJv9y5cv0rlzR/bvP/jFbPZlMhlz5x7k5Mk1qKlJGDjQ+4PG\nK/h5/hj16940Rl5eHgMGDGTy5OkMHuzOsWPhrF4dzNmzp9iwIYTmzVvSoIEFPXpc4fDh21hZGdOv\nnzPz5x9FJBLRp88AfvllM8uXBypfMA0ZMgJdXV1kMhmjRw/n0aMHWFpWU6qA7ty5k+fP07hy5RL3\n79/D2ro6Bw7sFRQyBQT+gwj+dwEBAYFiODrW58SJo6SmpgCQmpqCnV0djh0LByA8/CD29g7vNWZ+\nfr5SXv7IkUPUqaPob2paiaioOwDKeloAWlpaRSTxGzRoxI4dW5XH9+5FAWUbmAJfBgUhvcuXB7B+\n/WbGjJlAzZq/kZ3tTGpqF8TiuqSnH6F+/UZYWlpx+vQJUlKSMTQ0YsuWjTg5NaRKla/Iz8/H3Lwq\n9vYOypp0BQZGYUOjsOT/0KFD0dbW5v59RZ2599nsr1sXyJYtH8/oe5fx5s8/yMqV3bh9+wTXroUz\nevTp975PYmICrq69mDt3Bv36dWfhwnn4+Ljj4eGKp6crJ08eIzExATe33owf/x29enXmhx8mkpOT\nDSg8XwW/F6Ki7jBy5F+hkA8e3GPoUG8GDOjF3r27lOclElUsLa24du0qL1++wMmpAZmZmRw8uJ+L\nF88TFLSahIR4BgxoTrNm1bCyqvTWdRw/Ho6390C8vQcSHf2I6Oho5bXC3s8uXXpw4MBe8vPzOX78\nCG3bdnjvZyYgIPDvRjDmBAQEBIphYWGJu7s3vr5D8PR0Y8WKJYwePYEDB/bi4eFKePhBvvvur8LF\nxTfLpaGhUY47d27j7t6fa9ci8PIaDICr60B+++1XvL2/JiUlBVD0d3BwIibmEV5ebhw/fhRPz8FI\npVI8PAYUKbBcWs7d52LdukCuXLn0uafxRVFaSG9aWiwODsEYG++lSpWnSKV5ZGVlUatWbTw8vmHD\nhl8ICgrFyMiY9PQ0pkyZQZUq5vj5+bNsWQDVqilCRIsrZRbwMTb7H/s79y7j3bmjBmgV9ODePf2/\nda/Hj+Pp1asvtWvb8/DhA4KCNhASsonU1FRiYhRGUVxcLNWr29CsWQu0tLTYuXPHG+cpl8t5+PAB\ny5YFEBgYTEhIEC9eJJXoIxKJUFVVZf36tWhr62BmVgUfn2GYmJQv1ObN809IeMzWrZtYtiyA0NAt\nNGnStEwV0JYtXbhw4Rznz5/B1rYGurq67/ewBAQE/vUIYZYCAgICpdCxYxc6duxS5NzSpatLtCuc\ny2NqWonQ0K2lXgMYOXJMif7m5lUJDd2iPPbxGQYoNvVBQRuKtP3++ynvNM9PidoUXuUAACAASURB\nVEwmQ0VFpdRrhQUbBN6N0kJ6ZTIp+fky+vYdwIgR3xW5JpEUDeOVSouG8W7deo6tW2+TkfGU06dv\n0by5HTk5RcN3i2/2Q0LWUK+e01s3+6Gh6zh0aD8GBoaUL18BG5saPH4cz88/LyQ5+RUaGhpMnDgV\nQ0NjPD1d2bFjLwBZWVl8/XUftm/fw5MniSXaFxcxuX//LosW+ZGTk0PlymZMnjwdHR0dXr0KwMQk\nknLlriAS5aGpmY2n5zqSkp7z1VdVSUtLIzY2BiMjYzQ1tRCJwMSkAjExjwgJ2URCQgLz5s1ALpez\natUyIiNvk5eXR7Nm9dHW1iYjI4ONG9dz+PABxGIxu3f/ikgkRktLi0ePHpTIhyv+OTZr1gI1NTXU\n1NRwdHTizp1bSiO9OFevXsbXdzSRkbcAhfAOFOTKlmyvqalJRkYGurp6ZGRkvLMKqJqaGg0bNsbf\nfwGTJ08vc/4CAgJfLo
JnTkBAQOAf4GN6MtauPUn79ofp0OEQmzad/VtjZGVl8f333+Hp6Ya7e3+O\nHTtCVFQkvr5D+OabQYwdO5Lnz58D4Os7hGXLfmLwYHc2bAimT5+uSgMkKyuLXr06I5VKmTdvpjJU\nNDLyNsOGeePp6YaPjwdZWVnIZDJWrlyqDGvbvXtnmfP7J0hMTMDdvf87tz94cB9JSUnK423bNivD\n76BoCN674uBQr0RIb6NGzvTo0UdpyBWEQRZQPCRRU1OTzMxMTp/+gx9+MOXSJTdSUzMYPfoF9+8/\n5OrVK2Xev/Bmv1OnskMso6IiOX78COvXb8Hff6kyNHjhwvmMGfM969ZtZPjw7/jppx/R1tbG2ro6\nERGK+54/f4aGDZugoqLCwoXzSrQvoOBHZO7cGYwY8R2hoVuwsqpGSMgaACwtDTE3v4lY7ImeXmXU\n1V+xfv1mevbsg0wmIy0tlXnzFpGc/ApDQyO6dOmBuro6ubk5SKVSlixZxMCBnqirq9OzZx/KldOk\nefNW1K3rSJcuPWjfvhO2tjXQ1zfAwMAQsViFNm3aMWHCVExMKgCKHNf8fMV3Pycn942frUgkLrKu\nv84rThQ24guHxJb2q6Jbt56MGzeS774bhrV1daUK6KxZ096qAtqmTQfEYrFQhFxA4D+K4JkTEBAQ\n+AcoK+TtTWRlZTF9+iSeP39Ofr4MD4/BLF7sT1xcc9TUbpOfr8HcuV7UqnWX9PSnbNgQjFSah66u\nHjNmzMXAwJDMzEyWLFnE3buRgAhvbx9atHAhLGw9UVGRmJiUp0oVc+ztHZg2bQILFvyMnp4+x46F\ns3jxYsaMmYxIJEIqlbJ2rcJTeO9eFNeuXcXR0Um5UZdIJMqQz7y8PGbMmMLs2Qt49OgBt2/fRE1N\njX37dqOtrU1Q0AZyc3MZPnwwDRo0wtT07TlCoDCWgoPDyvR2fGoOHNiLhYUVxsbGAISErCUnJ4dB\ng7xYtuwnXr58ASi8Lvv370FTU4uoqDtlKg5evnyRr7/2UIb0isUqVK9uw+jR4/n55x/x8HBFJpNR\nt64j48dPAhSGQfEXA3p6+tSubY+f3yTU1LqSmvo9aWkd0dZeyowZGtjY2LxxXW3adOD06ZNv3Ozf\nvHmN5s1boa6uDqjj7Nyc3Nwcbt26wbRpE5Xt8vKkALi4tOX48SM4Ojpx9Gg4vXv3IzMzkz/+uFlq\n+wIyMtJJT09X5qR26NCZadMUa1dRUWHePB8cHZ2Ii7Nm4MC+LF36ExkZGdjY1EAikdC4sTNyuRx3\ndy927tyGlVU1rl+/yuPH8URHP2Tt2gBycnLYsCEYFRUVIiNvU6FCRVq0aMW2bZtp0qQZv/22nRcv\nkl6LE8k5cuQQ9vZ1AahY0ZSoqDs0atSEU6f+ynGVy+WcPXuKQYO8yMrK5Nq1qwwbNpLc3FzMzKoo\n21WrVp0WLVyIjLzzWgDlFwCaNm0BgLf3EGXbEyeOKr37vXv3p3fvv148FPf6F1BcBRQUdexyc3P/\nNeHYpSGUURAQ+PsIxpyAgIDAv5SLF89jbFyeRYuWAoqNbl6ejOxsCxITF6Cjswsdnd1cvtyaAQOa\nsmbNegD27t3Fpk0b8PUdzfr1a9HR0VGGf6alpZGcnMylS7+jrq5O/foNSU9PY8OGYB49esjo0cMB\nxebK1LSici4FZRWg9I16AXK5nNjYPzEyMsbWtgbR0Q+RSFRRUVHh8uULPHz4QOm9y8jIID4+7p2N\nuU+xGZXJZMyePY1796KoWtWSadNmsXnzRs6fP0NOTg52dnWYMGEqJ04cJSoqktmzf0BdXZ1OnbqR\nlZVJWFgoV65cIi8vD7lcjlQq5caNa6iqqnLnzi3k8nxq17bn2rWrJRQHC6tGFg+VnTXLr8hxaOg6\njhw5VCTEsW9fV6ZNm0BenhQzMzO8vL5n3DgbqlZtTUzMYaTSZkycCHPnTlGOV9pm/+bN63Tu3O0t\nz7fkNblcjra2DiEhm0tcc3Zuzpo1q0hNTeXevSjq1atPZmYGOjqlt39fqlQxx8jIGEtLS0JDgzE3\n/wpQhJ6qqEiUirEFf2SyfCwsrOjXz43582cRGrqV4OA13Lt3lytXLjF37gwyMtKxsamBTCbD3Pwr\nYmNjCQ8/RL169enRow8AXl5DWLBgNmvXauPgUK+IR83KyppRo4aSnJyMl9dgjIyMSUxMKJYzp/jb\nw+Mbfv75R9zd+yMWq+DtPYTmzVu+9Zm/D1FR0cyaNZ3s7GT09P5efuHHYvLk8Tx79pTc3Bz69nWl\nW7eetG3bjO7de3PlyiXGjp1AYmICO3b8glSaR82adowbNwmxWIy//4JSX4oICAgoEIw5AQEBgX8p\nVlbWrFy5lNWrl9OkSTPs7euipiYBLAFIS+tMhQpzaNjwW549e8r06ZN4+fIFeXl5VKpUGVB4iWbP\n/ssw0NHR4dy5MyQmJmBoaMTRo4dJTk7mq6+qYmFhRUBAsLKtiYkOz5+nAQoBlwKcnZsTGLiSMWNG\ncO3aVeLjY/H09CE5+RWBgasASEp6rlTjTEp6zrhxo7h58xqNGjkzZ84CQKHquWLFYuRyOY0bN2XY\nsJHK82Fh60uc/xTExv7J5MnTsbOrg5/fbHbu3EHv3v3x8vIBYM6c6Zw7d4ZWrdqwc+d2fH3HYGNj\nC8Avv2wCwM/PnylTvkdNTY379+9x8eLvqKmp0aVLN/bt28PJk8eRSvOIjo7G0rIaQIl6a2+icIij\nTCbF23sgtrY1aNGilbLWWlDQalRUnjJ8eAb791egRo35fPONEzExz2nZ0qVEnqNMJmP58gPs2LEG\niSSTNWuCS7u1krp1HZg3bxYDB3oik0k5d+4M3bv3olKlSpw4cZRWrdogl8t58OA+1tbV0dTUxNa2\nJkuXLsLZuRkikQgtLe0S7R8+fKAUbJHLQUtLGx0dXW7cuI69fV0OHdqvzAeTy+XKlwinT59EW1uH\nrl17cvXqFe7diyI3N5fHj+MBOHz4AHXrOpKamoqOji6XLv3O8+fPOHXqGDVq1GTECB8qVjTF2ro6\nGRnp+PqO4ddff0Ff3wBQeAFVVSU0bdqcqVNnKp+DvX1dtmwpGR5c2KNWmMJ5tI6OTjg6OgGKvMXC\n476JzMxMJk8eT1paKjKZFB+fYTRt2oLExATGjx9FnToO3Lp1AxOT8vj5/YS6ujoBAdvx9w8gL0+b\nvLxGmJmdf6d7fSomT56Orq4uOTnZ+Ph40LKlC9nZ2dSqZYev72hiYqLZtCmUgACFx9TffwHh4Qfp\n0KEzQ4YML1KG4eHDB1hZVfus6xEQ+DchGHMCAgIC/1KqVDEnOHgTv/9+lqCgVdSrVx8NDTVGj37K\nwYM7AClZWVCnjjW+vkNwdR2Es3Mzrl27SnDwGuU4pZUvqF27LjNnzkNdXZ1z586wa9cO4uLiuHXr\nD+zsaiOVSnnw4AF6ehVK9NXU1MTY2JiEhAS6devJ2LETychIZ8GCOXh6DqZ/fzd
cXXsRHf0IuVzO\n3btRhIRs4siRw6xcuYQnTxKRSCSsWLGENWvWY2xswtixvpw5c5IaNWoRELCC4OAwtLV1lOebNWv5\nSZ5x+fIVsLOrAyiKwW/fvhVTU1M2bdpAbm4OqampWFpaKQvDF89zKl++IgcO7KV2bXsePLjPjRvX\niIuLRUVFheXLF2Nu/hX6+vqoq2uUqTj4NkoLcZTLea3EuJqMjHQyM7No2LAxkydPoksXMzZv3oC7\n+xiGDvVm4sQfiownl8sZOnQ7u3e7AZ2oVOkI9+49wcmp7PDV6tVtad26LZ6erhgYGFKzZi1EIpg+\nfS7+/gsIDQ1GKpXSpk07rK2rAwqDdfr0ySxfHqgc5969e+zbt6dI+wJjTiRSFE1v3NiZVauWkp2d\nTeXKZsqQQpFIhJqaGt7eX5OWloZYrIKXlxuvXr3CxaUtzs7NmDZtItnZWaioqNCjRx82bAjGyakh\nO3b8gpaWFjdv3iA9PQ25XI6lZTVOnz5JQkI89+7dVd4DFN48FRWJUlF20CBvXFzavPNnVhY7d/7O\nlSspmJmJGTq0zTuFFqqrq+PntwhNTS2Sk5MZOtRLGZYZHx/HrFl+TJw4lenTJ3Pq1HHatevIypUr\nePz4J7KznTA2Xkhy8gdP/YPYvn0LZ84oQs2fPXtGXFwcYrGYli1bA3D16iXu3o1i8OBBAOTk5GBk\nZAQoyjDs2bMLmUzGixdJxMQ8Eow5AYFCCMacgICAwL+UpKQkdHR0aNeuI1pa2uzbtxsAHZ1X7Nnj\nyeHDBzhxwhGAzMwMjI1NAIVQRwH16zdk585tjBo1DlCEWdaqVRs/v9l4eX2NuroaKioquLt7Y2pa\nmaVL/UlPT0cmk/LNN960bFm6VH2bNh1YsmQRaWm1uHHjOtra2mhoqGNmZoZEImHOnAUsXryI58+f\nkZeXh6qqGj179mHr1jBGjRqKTCZ7HaanjYqKCm3bduD69WuIRCIcHOopw8IKzn8qY65wCFxBaN7P\nPy9k3bqNmJiUJzh4Dbm5uaW2B7Czs2PLljCmTJnB/v17OHBgL+XLl8fWtiZRUZGEhGzi1auXeHq6\nfcgsSz07f/5sFiz4CSurahw8uI9r164CULu2PYmJiUREXEEmk2FhYVmk39OnTwgPt6dA5j8hoS1b\nt27HyenNuXXu7t64u5cs1P3TT8tKbd+yZWtOny5aqkIsFpfavsCzlZiYwLlzp5W5ZAUkJiZw585t\ntLV1ycvLe12K4SdiY2NYtMiPq1cv8+RJIkuXBqCjo4Ov7xBWrVrKzZs3aN68JRKJBIlEFR0dbVas\nWMOCBXPQ1NRETU0NY2MTzM2/omvX7oAi5PVThPSuXXuC2bNrkZ1tBaQQHb2LRYt6vbWfXC4nIGAF\nN25cRywWkZT0nFevXgJgalpZaQzb2NiSmJhAeno6+fnZZGcrvICpqd0xMNhX5vifmoiIK1y9epnA\nwBDU1dUZOfJbcnNzUFNTL/KcO3bswrffjijSt6AMw9q1G9HW1mb+/FlFfh4FBAQENUsBAQGBfy2P\nHj1gyBBPvLzcWL9+LR4e3wAKg8zDw5UdO35h5MixgGIzPG3aRL75ZhD6+vrKTZKHxzekpaXh7t4f\nT083rl27ir6+PrNn+6GlpUl+vpy8PCkqKhKsrauzYsUa1q/fzMaN2+jbty8Ay5cHKkMLC+jTpz8H\nD56gcWNngoJWcerUcczMzGnRwgUAW9uaBAaG4OMzDBeXNmhoaCASiaha1ZLJk6czZsz3NGjQCC0t\n7dcjllX8XP5JhRuePn3CrVt/AAXF3BXKgLq6emRmZioLvUOBPHx6kWMrq+q8fPkCO7vaqKiooKam\nRqNGTbh+/RpffVUVN7feTJs2merV32wovYm6dR04ffokOTk5ZGZmcO7cGQCysjIwNDRCKpVy+PCB\nIn06dOjE7NnTSi0CrpDPzyh0Ro6qqrREu09JZmYm3303HG/vgXh4DODsWYXXJiBgOY8fx+Pl5caq\nVQqjb/PmDUyaNI6cnGzEYhEbN25DW1uHU6eOM3fuzFKVLwuL9ri7e9O0aXN8fb8jOHgTlSubAYq8\n0KCgUEaMGM20afMZNOggTZv+jKPjcZo1O8D27Rc+6pqPHs17bcgB6HHmzLsJ+YSHHyQlJZng4DBC\nQjZjYGCoVNJUUytcqkIFmUxRqkJbWwUTk98BUFePQ+/zaAYBKHMl1dXViYmJ5vbtWyXa1KvXgBMn\njvHq1StAoer65MkTMjMzS5RhEBAQKIrgmRMQEPi/4ODBfdSv30ipRPgl0KBBo1IVBr/+2r1EHlnT\npi2UoVeFKSs3x9HRqUQdu/ehuNdw164dvHz5gqioO9ja1iQzMwN1dQ1lWGJQ0AmOH5fy8uVTmjR5\nQrNmDVmyxJ+UlGS0tXU4ejScPn0GUKNGzVLPfwpEIhHm5l/x22/bWLBgNlWrWtKzZx+l8WtoaETN\nmnbK9p06dcXf3w8NDQ1Wrw6mW7eerF8fRJ06dVFX1wBg7doN6OrqYW1tS1hYCGpq6mRnZzFixCjl\nWKWJkLyJskIcBw8eypAhnujr61Orlp0yRxEUHs2goNW0bdu+xHiGhka4u//OmjVR5ORUplatPfj6\nNvw7j/BvU1bo4LBho4iOfqQUSbl06QLx8XEsWPATo0ePIC8vjxs3rmFjY8vjx/Gkp6eVqnwJRUV7\noGS4cYsWrQDYuTOOJ0/yuHBBExgHKF4wzJ59iDZtXmJgYPhR1qypmVfkWEvr3TxMGRkZGBgYoqKi\nQkTEFZ48SXxje21tbSpUMOabb2JISIgjNvYMiYnab+zzKWnYsAm7dv3KwIF9qVLlK+zsagNFvdxV\nq1rg4zOMsWNHkJ8vRyKRMG7cRGrWtFOWYShfvuJbyzAICPw/IhhzAgIC/xcUl5X/cvlwL9Xu3eeI\niUnDxaUatWv/vdyTR48esHLlUsRiERKJKuPHT0Yuz2fx4kXk5OSgoaHB4sUrEYlEPHz4hN27a5OT\nY0GlSuEsXhxN+/YuDB3qy6hRQ5HL5TRp0oymTZsDlHn+Y6y9MBUrmrJp044S5318himLtxemRQsX\npecRSsrFFzbSWrduW0Lk5NSpCB48eEbr1nWoWvXdFDwLKCvEsUBlsTg3b16nVas2hTyfRZk2rQtf\nfx3D7dtnaNmyFTo6Ou81nw+lrNDB4gbXpUsXuHz5IjdvXuf586eIRCLi4+MQi1VIT0974z0Ki/ZA\nyRBZVVVFoe6oKG1EonwgnwJDDuDpUysSEp5+NGNu7FhbHj7cRmRkPUxN7/Ldd2/+XVQw33btOjBx\n4lg8PAZgY1ODr76yKHNNBcd+fn5MmDAJkQjq12/EkyePPsoa/g6qqqr4+5cMrS1erqW0nxkouwyD\ngICAgg8y5hITE5kwYQIvX75EJBLRr18/3N3dSU5OZsyYMSQkJFC5cmWWLFmCrq7ux5qzgIDA/yGl\n1Vw7evQwfn7+AFy+fIHffvuVuXN/xM
9vNnfvRiISiejcuRvly1dQysoXeFWiox+xYsVisrKy0NPT\nZ+rUGRgZGePrOwQbG1tu3LhOVlYmP/wwiw0bQoiOfkTr1m3x8RlW6lzeR53wQ9i+ffcH9Z8xYy9B\nQa2QSk1Zu/YMy5bdoFWr93/bXZbXMDAwpMhxx45dOHlSlZwcxQY0ISGAhIS7xMbG0aZNe9q0Kek5\natOmPS1auJCYmKDMA4QPX/s/wfbtv7NxYyoAbm7aDBjgDMDChQdZscKB7OxmVKkSzurVKTRoUOOj\n3lsul7NmzTHCw3eRlfWAwMA1b2zfsGFtLC2rftQ5vCuFQwdVVFTo27dbmUW4Bw70pEGDRkycOEaZ\nS7dlSxhaWtro6paufAlFPXGKENm/QktzcnKYPn0Sv/yyCyOjLF68ADAG7gOKHLRata5gadn6o63Z\nzs6KgwdNefgwhipV7JTKmWVRYOzo6ekXUZktTIFSJoCr60Dlv2vVqsX69X+VgBg+fNSHTP2zcPz4\nTVaufExOjoQOHcDX95/5HSsg8KXxQcacRCJhypQp1KhRg4yMDHr16oWzszO//vorTZo0wcfHhzVr\n1rBmzRrGjx//seYsICDwf0hpNdeCgwNJSVHUUNq/fy9dunTn/v17JCU9V276MjLS0dLS5tdftyll\n5aVSKUuWLOLHH/8qkL1mzSomT56OSCRCVVWNtWs3sH37ViZNGkdIyCZ0dHTp378H/fu7ERFxpcRc\nvgSkUim7d2shlZoC8PRpMzZt2v63jLn3wcxMBKQBCu9P5cr3MDV1KLP93bt/MmLETSIja2Fmdp45\nc4xp167uJ53jx+Datbv88IMRr14pwvuioq5hZXUHR0cbtm5VIztbYSTExXVg7dptH92YmzlzLwEB\nHZDLe6Ki8oywsONMnmz+3uMkJiYUMZxAUR7h0KH9jB79cf4vLyt0UFNTs0i4aMOGjQgKCqBOnbqI\nRCKeP3+GRKLIExOJREyZMhN/f78SypcF1wto3bodP/44jx07flGWxijg++/t8fVdS4UKSWhoHKdK\nlaqULy9hzJja76U6+i5oampSu3bNjzpmYY4cucrNm89p1coSR8fqn+w+n5qkpBeMH59KfLyihuXN\nm9FUqXKB7t3LLmwvIPD/ygcJoJiYmFCjhuI/Iy0tLaysrHj69CnHjx+nZ8+eAPTs2ZOjR4++aRgB\nAQGBt2JlZc2VKxdZvXo5N25cR0tLm/btO3H48AHS0tK4ffsWjRo1wdS0EgkJj1myZBEXL/6OpqaW\ncoyCN/WxsTFERysKZHt5ubFhQzDPnz9XtisI67O0tMLS0gpDQyNUVVWpVKkyz549K3UuXwIikQix\nOL/IueLHn4IRI9owcOCvWFjsws7uF2bN0kJXt2xFhh9//IObN93Iy7MnOron/v7xn3yOH4OLFx/x\n6tVfnqHkZAcuX/4TuVxOfrHHnJ//8UVdzp3TRC5XyLnLZOU5c0bto41ta1vjoxhyhUMHo6Ii8fAY\nwKFD+5Whg3p6+tSubY+7e39WrVpG/fqNaNu2AzNnTgFg+vRJZGVl4uo6EC8vH6ytqxMYGEJo6Bbm\nz1+EtrbiZ7G4aE/t2vaEhW0jODiMypXNmDVrPqqqqvz44zwWLZqBk5MtFy82ZPPmDhgb/0ZWVhih\noStIS1OEcvr6DiEqKhKA5ORk+vZVCMs8evQQHx8PvLzc8PBwLVLnruD8okXzyS/+BfgErF59DB+f\nSvz4Y1/69NEmOPjkJ7/np6Jv367Ex9dHReUppqajyM624ObNlM89LQGBfyUfLWcuPj6eyMhI6tSp\nw4sXL5R5KcbGxrxQxC8ICAgI/G2K11xzcmpAly49mDhxDGpqari4KGo26erqEhq6lYsXz7Nr168c\nP36EyZOnA39tJOVyShTILkxBLk2Bl64AkUiETCYrdS6enoM/8RP4cFRUVPj6axnLlt0lK8sac/ND\n+PhYvr3jByIWi/n5597v3D49vagRkpb2cb0jnwpHR3P09G6QklKgiPkHDg6KUg09emQSFPSYvLzK\nVKx4mkGDKn/0++vo5BQ51tb+cAn3x4/jmTZtIm3adOD69QgWLlzMunWBPH36hMTEBJ4+fUK/fq5K\nkZr169cSHn4QfX0DypevgI1NjSLhf+8SOjhjxlwA0tJSefXqJX37DqBv348jgpOdnc2aNSd58eIV\ncXGxzJw5X1mj7ezZU2zatIGxYydgb+/AunWBhISsYdSocYhEolJVVXfv/pW+fV1p164DUqkUmUxG\nTEw0x48fKbUA9qdk1y4ZmZkKb1x6ug07d97Cu2Sa5Wfj7NnTxMQ8YuBAz7e2VVFRoWLFazx50pbE\nxGWoqcVRo8aX8dJMQOCf5qMYcxkZGYwaNYqpU6cq34oVUNYvQAEBAYH3obh64v79ezA2NsbY2JjQ\n0GCWLl0FQEpKMhKJhBYtXKhSxZy5cxVhV4Vl5c3NvyI5+VWRAtlxcbEl6nGVhlwuL7P+25fAuHHt\nadz4Bvfv36J16zqYmVX83FMqQbNmIs6dUxg+kEHDhq8+95TeiQYNajFt2mk2b76HXC7C1VWDxo0V\nCqMzZ3bF0fEcf/55DhcXG2rV+vhG9LhxFjx5spPoaBuqVYtk/Hirt3d6A7GxMcycOZWpU2eRmprC\n9esRymtxcbEsXx5IRkY6bm696dmzL/fuRXHq1HFCQ7eSl5eHt/dAbG3/XijpvHn7CQszJC9Pnfbt\nj/6PvfMOqKn/4/jrdttLQ0h2KDREZkY/hOxRZBXx8JgP2VtWZh57RzYRHnvzGI+RyCoy00BG2rfu\n7f7+uE+XFIoQz3n9wznne77ne86593Y+5/P5vt8sXuySK4PtTyGVSunRYzdnzvRCVfU5Zcv6I5Mp\nXhzkRh0zJ6ysbP7N7D+nYcNGlChR8pMG2N8SVdWs2T+xWPbNj5kX6tVr8J6Y0acRiUTMmqXG4sUr\nSEzcQps2g9DWljJu3EgkEglRUZE0aOConAt4+fJFpSdkZrltfpfICggUVL46mEtPT2fIkCG0adOG\nJk2aAGBsbExsbCwmJia8ePECI6PPK0GZmHxfJS2B74twf39tvsf9vXs3hNGj56CiooKqqire3t6Y\nmOjRsWN7Nm7cSPXqCrnrV6+iGDVqnLKsadSokZiY6OHm1glf39loaWmxbds2li5dwvTp00lISEAm\nk+Hh4UHNmraoqYkxNNTGxEQPQ0MdNDRUleenpibGyEiHV6+ilGNRU1NjypQpP9VnvG3berlu+yPO\na+rUDpQseZIrV4IpVUrO2LHuiMXi7z6OL2H48JYMH57ztt69czZgz4lt27ahqalJu3btctU+MjIS\nP7/p3LixnaioaEqUaJmrh9mc7q9EosPbt3FMmDCKJUuWYG5uzqVLl5TfBV1dTZycGmNqaggYUrhw\nYUQiCQ8fhtG8eTOKF1f8zXdyaoyOjkaeP0MXLoSwcqUdqakKb76AADsaN75Av365v345cf58MGf
O\nNAcUc+7S0ozYt+8B9epZoa+vzfPnCYjFKsrxpqTooKYmxsREDy0tDQoVEIHtwgAAIABJREFU0sTE\nRA+ZLAkVFREmJnp07epK/fq1OX36NGPGDMPb2xtdXU06duyAl5fXV403r/zxhynDhl0kNTUGE5NV\nqKtLWbz4FlOmTOHcuXP8+eefyGQyDA0NWb9+PXFxcYwbN47IyEi0tLSYOnUqFhYWLF68mOjoaCIj\nI4mJicHDw4MePRSB6bp16wgMDATAxcUFDw8PIiMj6dOnD3Z2dgQHB2NlZUX79u1ZsmQJb968Ye7c\nudjY2BAYGMjt27eZOHEiL1++ZPLkyURGKspSp0yZgp3du3m0IhF4eDjSuHF5+vc/wJw5nQgMDOTR\no/vs2bMHdXV1mjdvzu+/90FNTY2tW/3ZvHkjmpqarFq1in37Ahg4cGD2i/QL8TP9zRH4tnxVMCeX\nyxk/fjzm5ub07NlTub5Ro0bs3r2bvn37smfPHmWQ9yliYz8tMSzw82Jioifc31+Y73V/LSxsWbt2\nc5Z1sbEJnDv3D82bt1aOwdjYjJUr/bO1s7Orw8aNAQDEx6dhbGzGggXLs7Xz9V2m/H/ZspWYNm2u\nsu/MbUWKlMpxLL8aP/K7265dDTLjmNevkz/d+BdDJpPRuLGiJC+31//16ySkUhlJSTIMDIqSmCj9\nrHT/x+7v69dJaGvrULhwUU6fPo++fhHi4pKRSKTExiaQlCRBS0us3Fcuhxcv3pKUlEZiYqpyfXKy\nhMRESZ4/Q9evPyY19X2POH0ePkz46s+iVCpHVTUeqdIfXY5UmkJsbAKJiRJUVNTR0dHl+PGz2NpW\nZcuWHVhZVSU2NgFj4yJcvBhEsWJl2LVrLxkZcmJjE4iKisTMrATNm7fjwYMnBAffpEaNWqxbN5xW\nrVwwNDQkPv4tyckpFCv2bbPgTZvasmTJWdasWc+ffy7BxKQI8+fPZtOm7axevZxly9ZQrJgpCQmK\na7lgwXzKlq2At/dsgoODGD58BOvWbSEpSUJ4+IMsmVcnp9aEh98jIGAnq1atJyNDTt++HlSoUAVd\nXT0iIiLw9p7FsGFj6dPHncDAvSxevJpz586waNFSfHzmkZCQSkpKGrGxCUyaNAVr66pMmTKLjIwM\nUlKSs9xfuVzx2c/8XMfGJpCQkErVqvakpMhJSZFQsmRpbt26R0JCAuHh4bi4uAKQni7F2trml/xN\nzkR4rvp1+ZIg/auCuatXr/LXX39hYWGhfHvo5eVF3759GTp0KLt27VJaEwgICAjkN56e3dHW1mbI\nkI+kQvIZuVzO/v0XiI6Ox9m5KqVKmX6X434JP6NJ+q9ETEw0w4cPxtKyMvfuhVGmTDkmTvTm0aNH\nH7XEqFjRghs3QmjSpCnJycloaWnTpUt3wsPvMneuDxKJBDOzEowdOwk9PT3CwkLx8ZmKSCSiZs38\nNf1WU1Nj5sy5eHkNQktLC2Pjd5+jD73gFIiwsbFlzpyZ9OjRC6lUyoUL52jbtkOej92kSTUsLfcR\nFqZQMixe/CTNm1f40lNRUrlyRbp1C2TLFg3kcglaWnEMHPiu7O9T6phdunRn4sSx/PXXburUqUem\n5+HJk8c5evQgqqqqGBsXxt3dEz09vRwNsL91MAfw+nU0SUmvGDduOFKpDIlEwp07t7Czq0axYorf\nq0xPwZs3Q5gxYy4A1arZ8/btW5KTkxCJRNStWw9VVVUKFTLA0NCI169fcePGdRo0+B8aGpqAwncx\nJOQa9eo1xNTUjHLlFGW9ZcuWw96+5r//N+fZs+hs4wwODmLSpGmAYk5tbkWk1NXVlP9XUREjkylK\nSe3tazFlyow8Xy8BgV+Brwrm7O3tCQsLy3Hb+vXrv6ZrAQEBgc/i57fpux5v3Li9+Ps7IZUWZe3a\nA6xZk4SNzZeZbn9rfh2T9J+Xp08jGDduMlZWNvj4TGXXrh2cPXsaHx9fDAyyW2JIpVLWrNkAgJ/f\nKjKnm0+fPhkvr9HZRDl8fLzx8hqDrW1Vli1b+Nnx7NmzC01NzWxCHDlZEYhEIjQ1NZkz50+GDRuA\nh0cf5XgUc+Hf7f/2bRwSiQRLy8rUq9cADw83jIyMMTcvn20efW4wNDRg/Xprli3bjkymQufOpbCy\n+rr5f5nMnduBzp1vEReXSL16u9HUVAQm74u0fOiVCFCqVBn8/bcqlzNN5Xv06EmPHj2ztf+YAfb3\nwNm5FRMmjFFmbs6fP8uJE0dzbJtzYI7S/gEUwZZMJsumfyCXy5XrsgZZivLz9/fNy7HzgkgkokoV\na3x9ZyuzpCkpKbx8GUvJknm35RAQ+BnJNzVLAQEBgV+Z+Pi3BAaaIpUq3q4/ftyK9et34Ov7bYK5\ngmCS3rp1S7p16/1Nzu+/QJEiRbGysgGgWbMW+Pv78fDhA4YNGwBARkYGxsbvTNEbN26arY+kpEQS\nExOziXIkJmaur/pv/y25ePHCJ8fTrl3uFEVNTYsrzah1dXVZvVoRYGaKV3h69s3SXl1dAwMDhdVE\nly498PTsS2pq6r+frS8TQClXrgTz5pX4on0/h7291TfpFxSCcNu3n0NVVYSbmyPq6vlnD5Ebqlev\nyZgxwxkwoC+gRnz8W8zNyzN//ixiYqIxNS1OfPxb9PULYWNjx9Gjh+jZsw/BwUEYGBiira2DXC4n\nIeEt7u6d3wvwRdjaVmXGDG+6d/cgI0POqVPHcXZupQzKMr0IczfOGuzevZNOnbogk8lITU3Jkp17\nP3DM/P/HBPUMDAwYP34KU6aMIy0tHYC+fQcIwZzAfwYhmBMQEBAogBQEk/QuXdrTurUr+vr6P/JS\n/LS8/+Apl8vR0dH5pCWGpuY7wZLExAQCAwO4c+cWL1/GMmHCaCZO9GbgwN/IyJAxcOBvpKamKlX8\nEhLiefkylpSUFLS0tFi+fDHnz59FLBZTq1ZtBgz4g7VrV6KtrUOXLt0JCwuld+/pyGTyLCWaMpmM\nFSuWcP36VdLS0unQwZW2bTsQHByEn98qDAwMefToARYWlZg0aRoBAduIjY2lc+euiMW6lC9fjPj4\nl6SlpeHs3IoKFSy+3QUuYCQmJtKp00GCgjwAKfv3+7N5s6syS/U9KFOmLL/91h9PT0/S0qSoqqri\n5TWaUaPGM378SDIy5BgZGeHruwRPz774+EzFw6MLWlpaTJgwBcj83GYPmipWtKRFi1b89psHADVq\n1OLmzRs0adIMkUiEpWUlLC0rMXOm92eDsaFDRzBnzgwOHNiLiooKI0aMo0qVd0F2poXF+y8WnJ1b\n4ezcStlmzpwFyv9Xq2avfOkgIPBfQwjmBAQEBHKBvn4hOnaMYf36GKTSYpQuvZ+ePb9+Hs/HMDev\nwNKlC1m+fDF169bH1raq0iTd2bk1t2/fYtKkaSQmJipN0uvUqUfNmrWVfeRkkg7ZM0I5maQDlCxZ\nkufPnxWYYM7JqT7Hjp390cPINc+fP1PaXxw7dpgqVa
zYt29Pri0x4uLe0KlTVyIiIpBIUtm1K4DU\n1JR/Pxur6d69E8uWLWLlSj/Wrl3F4cP72b59Mx06uHL27Gm2bNkFoLTkeL880sfHm6lTvSld2iJL\nieb+/XuV2bi0tDQGDOij/Ezdv3+PTZsCMDYuTP/+vbl5M4TWrduxaNFqrl//i4wMQyIjT7JhgwG2\ntt/uu1FQ2bDhLEFBPQExoMrp093Yu/ckLi7/y/djfWxO5s2bN9i8eT0gx9KyEiNGjEVNTQ0Xl9Y0\nauTEpUsXSEh4J9yiq6tLr159cHRsDLz7jsXERHPunCKgmj17AdOmTSQlJQWAUaPGY2VlQ9++PYmI\neMy4cSNo2bINwcFBbNu2mTlzFhAf/5axY4cTHR2NpqYWDx7cx9m5FdHRUcyc6Z2jR+GX8PBhJMuX\nh5CRIaZLl7LY22d9efCz/WYICHwJX2faIiAgIPAfYsaMtqxaFcK0aTvZtcvim86XyzQmNzcvz+rV\ny1i/fg0tWrThyJFDnDhxJJtJup1ddfbs2cWsWdOUfXxokr5u3RbWrduCv/82fH0XK9t9yiQ90+Kh\nYPBzeZaWKlWa3bt30L27K4mJibi4uDFt2mxWrFhMz55d6dWrK7dv3/jo/rq6elhZ2TB+/BRiYqLx\n919DerqU4cMV3mdt2rTj4cP7tGjRmEOH9pGQkMjz58/Q0dFFXV0DH5+pnDlzSilYkUlmiaa9vT2g\nKNHM5MqVixw+fIBevbrSr19P4uPfEhn5FJFIRKVKVShc2ASRSET58hWJiYnh8uWbpKVpkXlvYmIa\nsW/fgxzPJzExkd27d37RtXRxaU18/Nsv2vdX5enTCDp0cGXTpgB0dHTYunUTM2d6M3XqLPbt24dM\nJlNeb5FIhJ6eHv7+2+jYsRMLF85Xrs9K9u+YkZERCxYsxc9vE97eM/nzT0Wpd//+g7GxsWPdui10\n6tQ1yz5r167EwqIS/v5b6ddvINOnTwIUnnuXLoVQpkxr5s5dxLp1qz86p+5zvH79hp49b+Lv78bG\nja707fuasLDHnz0fAYFfDSEzJyAgIJBLRCIRrVo5fJdjFRST9IJIcnIyY8eOICEhHplMym+/9ade\nvYZs2bIBdXV1XFzcWLRoPg8e3GfhwuVcvXqFAwf+UqrnfS/EYjETJ2Y9ZoUKFVmyZFW2tosXr8yy\n7Orahb//Pq3cZ9iwUezatYPw8LtKURszs5I0auSUo4rf6tX+BAVd5vTpEwQG7mDhwuXZ2mTyoRCF\nl9coatSonWVdcHBQlkBfLFZBJpNSpIghItH7D+MS9D6irJ2QEM/u3QG0b++SbZtUqigJ/BCZTIZY\nLM5xrlRBw929Pvv3r/+3zFKGo+MW2rbNfq75xYdzMtevX0Px4maUKFESUJQlBgbuoFOnLgA0adJM\n+e/ixb65Pk56upQFC2Zz/344KioqREY+BT4tYPKhUmZMTAxubu2JjIwhMbE2s2e7cfjwegoXLsSb\nN68pXPhdpcDBg/u4ezeUYcNGAXDkyEF27tyOVJrOw4cPOHnyAs2bO1K5cm2Sk6MpWXIX0dHLiIxs\nSkDASiIjvUlNTcHBIXcG5QICPztCMCcgICBQAHn48D5Lly5ERUWEqqoqI0aMA8DJqTlv376lVKky\nAMTGxjJzpjdyuSKD9vvvgwFo0aI18+b5KAVQpk2bzcKF80hMTEQmk9K5c9dswdyHKoUFFQ0NDXx8\n5qKtrUNcXBy//96LevUaYmtbjW3bNuHi4kZYWChSqRSpVEpIyDWqVq32RccaOfIPpkyZ8Unp9EGD\n+jJo0DAsLbOKfaSlpfHPP+epU+fLXgB8WKZpY2NLePhdbt9+xMaNMaSkpPDixZVsKn6FC5uQmppC\nnToOWFvb0rlzW0Dx8C2XK0RNdHX1uHr1KqVKVeTo0UPKY9asWYfAwJ3Y2dmjqqpKRMQTihQpCsCL\nF8/x8OiCSCQiLU1CqVKl2bRpFerqbyhduiOxsb9Tt+4btLTeZCmla9asBSdOHCUtTcLz589p3tyR\nFi1aU7p0WZYtW6hU8ty+fQ/z5vkQFHQFNTVVtLV1cHHpTJEixYiNfcGgQX3R1y/EkiWrkEgkzJ8/\ni7t3QxGLxQwaNIxq1ew5eHAf5879jUQiISoqkgYNHBkwYMgXXf+8oqurS0BAS7Zt24Oamgg3t46o\nqalx9uxpSpYsTZkyZfP1eB/OydTV1cuSvXxfbfJj+4rFYjIyFEFZRkYGUml6trbbt2/G2LgwEydO\nQyaT0ahR3VyN7/1gLyUlmebNf2PFCglyuS4gIiTEg5o11yGVZs3MvT/mx48fcfLkMVas8EMsFuPo\nWJujRw+RmpqKjY01u3aNQF9/N4UK7eD16y6EhR3C3b0LzZq1IDAwIFfjFBD42RGCOQEBAYECSM2a\ntbPMf8vkxo3rtG7dTrlcvnyFHC0aGjZsRMOGjZTLuckI2dlVx86uunJ548aNBdKYVi6Xs2LFEkJC\nrqOiIuLly1jevHmNhYUld++GkpychLq6OpaWlQgLC+XGjevKt/x5Pc6cOX9+NiuU03ZT0+K4u3ty\n8eKXB3OZZZqzZk2lTJlytG/vwo4dW/HyesqDB4qytqJFZYwc6YWamhhQqPhpa2szZsxw0tLSADmD\nB3spx5k51HHjJjN16lRksgxq1KitPIfWrdsRExNN797dkcvlGBoaMXPmXJ49iyEy8imBgfvR1y/E\n7NnTOXr0EAMHDsXOrjqbN29ER2caGzacZd261Tx9GqE0nXZza09iYiLTp89h7doVynLNnTu3IZPJ\n2LQpgNu3b/4ryjOZ8eNHUqpUaW7eDKFFizYMHtwXIyNjlixZhVisOM/AwABUVFTw999GRMRjhg0b\nxNatgYBibt/69VtQVVWja9eOuLq6YWJS5IvuQV7R0dGhd+9mWdb9/fdpHBzq53sw92Gwb2lZib17\nA4mKisTEpBJHjhzM8hLjxImjdO/ekxMnjiozesWKmXL3biiNGjXh3Lm/kb5zVFeSnJykvH6HDx9Q\nll5ra+uQnJyU49jeV8ocPXoYGRkZnDq1BS2tsqioSHjzxhOx+BlJSXGMGTMMNTU1hgwZjrW1bZZ+\nTp48yqVL/+DkVB8dHV1kMhkxMdGoqanRu7c7UVH72bkzA1XVqzRurM7jx9HKDGSzZs4sX744p+EJ\nCPxSCMGcgICAwE/CtzRJT09PZ8uW06SmyujUqTaGhgb5foz84ujRQ7x9G4ef3ybEYjGurm2QSNIw\nNFTF1NSMgwf3YW1ti7l5eYKDFZmr0qXL5KrvmJhovLwGUaWKNXfvhvL48SMOHDiOvn4h1q9fw9Gj\nhzAwMKRIkaJYWFRS+pOdOnWc+fNnkZiYwJgxk6hSxYo1a1aQlpbGjRvX6dHDk0aNmuTpPHMq0+ze\nfSS//95Yufz8uSfu7vqMHJk1gFi92j9bf+9bClhYWLJ3715lsJ6ZvRKJRPTrN5B+/QZm2Tc5OYnO\nnbuir6+wI
Bg9egKtWjmxYMEcAAoV0gfkSCSSbKbT+vqF0NTUomJFhThFpk1DTEw0IpGIsWOHK0V5\nAgK2EhZ2hxcvnvP2bRyRkRFYW9ty8OB+Dh8+oPTIu3kzBBeXzoDCA65YMVOePo1AJBJRvXpNtLV1\nAIW6Y0xMdJ6Duc+V7Do7t2Tt2lWkpaUpzcU/VBGtWbM2DRv+j/Pnz3L9+jX8/dcyffoczMzyx3Lh\nw2C/c+duVKlizcSJowE5FStWol27d2WeCQkJeHh0QV1dXVma26ZNe8aMGU7Pnl2pVasOWlrayvaZ\nAX779q6MHz+Kw4cPZmlTvnwFxGIxPXt2pUULhXJp5suCD5UyjY0L4+fnj5vbAF6+1AFeY239O4UK\nGTNrlkKVcsSIwWzaFJAlo3f69Elq1arL7Nm+BAYGsHz5Yjw9+7J1q+IF1qRJrahe/S8uXXrK1Kmd\naNkya7mygMB/ASGYExAQEPhJ+FYm6VKpFA+PAI4f9wDUCQjYyI4d/+Ps2WMEBV37oqzWtyQpKQlD\nQyPEYjHBwUE8exaj3GZrW5WtWzcxbtxkypUzZ9EiXypVqpyn/qOiIpk4cSqVK1vh6toGgNDQ25w5\ncxJ//22kp6fj6dk9S1llRkYGq1f7888/51m3bhV//rmM337rz927oQwdOvKLzjOnjF+FCsXR1b1H\nYqLCX05F5SVmZprZ2uU3IpFI+ZB94EAQhw69JjFRwqpVvhQvXixb+w9Np98n06ahWLHiFC9uppSY\nDw4OYs2aFVSsaMmQIcNZsmQBaWlpjBgxlrNnz/DyZSy9e/dg7dqNnxxrVgNr8ReJ+HyqZNfcvDz+\n/n78+ecyNDU12bRp/UdVRHV0dKlXrwEODvWzZMrzg5yC/erVa+DntxkTE71sWfVu3dzp339wlnWG\nhkZZTNIzt79vCVCiRMkshumZbVRVVbPNxczM7Ovr6ys9MQFcXdsgFosZMqQrx4+foVmzyyxd+gJd\nXRPGjlVkjpOTk5WKmZm8ePECiSSNN2/e/JtpW5Tl+w6K+cGZ5u/W1racOHGUpk2dOXr08EevnYDA\nr4SgZikgICDwH+fUqSCOH+8AqAMq3Ljhjp/fPz96WNnIDG6aNm1OWFgoHh5uHD58gNKl35Wv2dhU\n5fXrV1hZWWNoaISGhobScDu3FC1qSuXK7zyv5HI5N2+GUL++I2pqamhra+PgUD/LPg0bKuTnLSws\nlQ+bijlqHxeJ+BTvP0y/j7V1Rby8HlGy5G6KFduPu/sBunT59kIP1arV4NSp4+zde46hQw3YubMx\nr183wt19PunpinlW4eH3Prr/y5exREQ8Jjk5WWnTkJSUQEJCPKB4ofDo0QP09PRQUVEhJiaa27dv\nAYrgWl1dne7de2JgYMDz58+xta2qnOsXEfGE58+fUbp0mRyv95fcgw9Ldq2srJUluxoaGjx+/JD+\n/T3p1asrhw8f/KyK6Jd+Dj5F3kRhfvxk2MOHgwkMvM3bt2k0bVobkLNqlb9SZTcw8ABaWlpZzkss\nVqFPn9/x8hrI7797IpFIePXqVY4+dgB//DGCwMAAPDzcePky9qcQzhEQ+FqEzJyAgIDAL8DnysLq\n1HFg06b1yOVy6tSpp3y77uRUn6pV61Kq1J+8eDEVdfXHGBmt4syZdNTVv49yZ27JNBIuVMjgo8bb\n9vY1OXXqXSCaOY8qL2hp5ZTpEn3wQJ714TxT6VFFRfzFUuu5ITExETOzOC5fbk9s7AuWLPkTkajD\nNzteJmXLlsPd3ZMFC3woVMgQTc3KvHgxAZFoBO7unRGLValatRojRihsEz58hi5e3IyjRw+RkBDP\nqVPHadWqHZ6e/Vi+fBE9e3ZFJpPSsWNnZDIZoaG32bVrO1ZW1gAsW7aQ2NgX9O/fh1q1alOhQkVK\nly7DvHk+eHi4IRaLGT9+CqqqqlmMqTP5kgd6VdWPl+yampphb18rTyqi+R1UfCzY/xgBAXvz9fh5\nJTExlbFji6OiUgMNjet4ee2mRo3aBARso2vXHgCEh9+lQgWLLN8za2tb5PIM1q3bwu7dO1m2bBFV\nqlgpfwsAHB0bK33yTE2LZ/lt+O23/t/pDAUEfhxCMCcgICDwC/CpsrCSJUuxYsUS/Pw2oaurh5fX\nIM6ePU39+o6kpqbi7NyEFy9iiYoywtR0OAYGrqxf3xJv77GUK/dzmT9HRT1nzZqrZGRAr162lClj\n9tV9ikQibGxsmTNnJj169EIqlXLhwjnatv10EKWjo0NycvJXH/993pf3L1bMlOnTZ+dr/5/C2bkV\nISEifH07AopSxrQ0N5YsqYCxsbGy3ftz8wB8fZcwevSwbCWBQBYxH4B27Tpma5OTEqm6ujrjxk3O\ncYzOzq2Uy5klnF/Cx0p2q1Sxxtd3dq5VRBU2ITkLhfxXSEmRk5xcAV3dSECds2cNOXHCiz//nIuH\nRxdkMpnyZcD7Afkff4zA23sCmzf7U69ew08GxSdPhrBnzzM0NaX88Yc9ZmZFv9PZCQj8WIRgTkBA\nQOAX4FNKjg4ODahWzZ5ChRSiJk5Ozbl+/Rr16zuioqLC//7XhIYNM5g7dxl375qyeHEHdHV1adGi\nBaGhHy+dK2i8fv2Gbt0uc+dOF0DE8eMBBASoUbx43sQvsj4wKv5vaVmZevUa4OHhhpGRMebm5dHV\n/ZhdgWIfOzt7Nm1aT69eXb9IACUnVqxYTFRUJL16daVEiVI8efKIDRu2c/DgPs6ePU1qaiqRkU9x\nc+uGRJLG8eOHUVNTZ+7chejr6xMVFYmv7xzi4t6gp6eDl9cYpc1Fbhg6tDG3bq3j4kVLdHVfM2iQ\napZA7mPkJTPl53eaPXvSUFWV0aePCS1a2Odqv4sXb3Ht2lNq1CiDvX2lz+/wGWxt7di4cR1WVtZo\naGgqS3YNDAwYP34KU6aMIy1NUWL6KRXRxo2bMnv2DHbu3M60abPyTQDlZ8LC4g/u3DEgPr490J6K\nFbdjaGiEt7dPtrbvB+S5zbT9808oAweKefXKBZATHLyRv/5qhra2do7tBQR+JYRgTkBAQOAX4NNl\nYQr58Xe8859SV9dAJBIhFotxcLBGKn2pDFK+xTyfb8nevZe5c6czmcFUeLgLu3cHMHCgc677+LB8\n7f3ytC5deuDp2ZfU1FQGDeqLhYUiYHjf3qFQoUKsWOGHRCJBX1+f1as3fOVZZaV//yE8evSQdeu2\n8OxZDKNGDVVuy1wvkUjo3LktAwb8gZ/fZhYv9uXw4QN06tSFOXNmMHLkOEqUKEl09ENmz579SUPx\nD9HU1GTjRjfi4t6gpVVJKTzxKfJSEnjy5DWmTatAUpIlAOHhp6lcOZIyZT4dAPn7n2HatNLEx3fC\nwCCYqVPP4+b2dWXC1avX+GjJbrVq9jne25xURK2tbdm0acdXjeVnx8vLmnv3tnH7dk2KFLnPkCGG\nn93n9u1wbt9+SsOG1hQtavLJtseOPeHVK9d/l0TcuOHEtWuhODhU/+
R+AgK/AoIAioCAgMAvQmZZ\nWNWq1bC1tWPPnl1UrGhBpUpVuH49mLdv45DJZBw/fjTH0rVKlay4fj2Y+Pi3SKVSDh/+udTgChfW\nQSR6/d6aBAwN1fOt/zlzZtCrV1d69+6Oo2MjKlSwyLI9Pj4eV9cd1KwZhYPD3wQEXMy3Y2fyfoD9\nYbBtZ2ePlpYWBgYG6Orq4eCgEEYpV648z55Fk5KSws2bN5g4cTS9enVl8uTJvHr1Ks9jEIlEGBoa\n5SqQyyvBwS+UgRzA8+d1uHQp7LP7bd2aSny8Yo5dXFw1tmxJzPex5YV//gmlY8cDODkdZ/Lkv366\nFyP5jYVFaQ4ebMSxYy84c6YinTrV+WT71atP07ZtBoMGNaNVqztcvhz6yfaGhgDvlDB1dSMwMyuc\nDyMXECj4CJk5AQEBgV+Ej5WFGRsX5vffBzFkyO/I5XLq1q1PvXqKB/33y98KFy6Mp2df+vXrha6u\nHjY2VnxDLY98p1UrBzp33smuXfbI5WJatbqAm1unfOt/8uTpn9zu43OGv//2BFRISIDZswNp2zYN\ndfX8Cyg/RVZJfhXlsoqKCjKZDLk8Az09Pdat2wKQo3z9j8bKygiT9TckAAAgAElEQVRNzQekppoD\nULhwENWrf8m8zR8XPKWmpjJy5BPu3XMD4MaN1xQrdoL+/b+8zDYxMZFjxw7Tvr3LR9s8exbDzZsh\nODk1/2RfMTHRjB49jA0btn/xeL4ELS0tbG2rfLadXC7Hz09CfLyivPbJk1asWLGdmjU/Xjrbv38T\nQkI2ceaMBRoaifTtm0qZMk75NnYBgYKMEMwJCAgI/CJ8qiysSZNmNGnSLNs+76vCAbRo0ZoWLVoD\nBfNh/1OIRCIWLnRhyJCHZGRIqVCh83eVJo+P1+D9gpe4uMIkJSWirm6Ub8fQ1tbOs6hKZlZIW1uH\n4sWLc+rUcf73vybI5XLu3w+nfPmCI3LTvHkNRo48xt69IaiqSunTx5Dy5W0+u1/Xrlrcv3+D+Hgb\nDAyu0q2b/ncYbc48exbDw4fvsrZyuRH37+fd6+593he++RjR0VEcO3bks8Hcz4BUKs6ynJ4u/khL\nBaqqqqxe3Zk3b16jqaklzJUT+E8hBHMCAgIC/2Ey3/gXK1aF06efYGqqjofH/34qf6ZDh/azbdtm\nRCIR5ubladTICX//tUil6ejrF2Ly5OkYGhpx7dpVFi2aDygCv6VL16ClpcWWLRs4deo4aWnpNGjg\nSO/e/b5oHI6Oeuzff4+UlIpABtWq3cPAoGo+nqnClsHa2hZ3986ULl1WeZ+yS/Jn9eHK3DZp0nTm\nzZuFv78fkIGjY5MCFcwBDB7sxODBn2/3Pu7uDbC0vM3VqzuoWbMs1avX/TaDywXFiplSvvxpwsIU\nQaiKykssLD4djHyO94VvatSohVwOly5dQCQS4e7em8aNnVixYgkREY/p1asrrq4uVKtWh2nTJimN\nuL28RmFl9fnA+EcjEolo2TKFVaueIZUWw8AgCBeXQrnaz8jo82I8AgK/GiJ5ASnk/pne/grkjZ/t\n7b5A3hDu789NTEw0Awf+TmjoLF69qoVI9IZu3fbg6+uS473N/JNRUIK9hw8fMH78SFauXIe+fiHi\n4+MRiUTo6ekBsG/fHp48ecygQUMZPXoYPXr0wsrKhtTUVNTU1Lh69QqnT59g1KjxZGRkMGbMcLp1\nc8+z0XgmAQH/cOZMAoUKSRgzxlE5joKI8N39dly5cpc5c+6TmKhO3bqpTJjQ6qu+M5liNxs2bOf0\n6RPs3RuIr+8S4uLe0KePO6tWrSci4glbt25izpwFmJjoERkZi0ikgrq6Ok+fRuDtPYE1azb8sDLL\nvCCXy9m58zyPHiVSv34p6tSp/KOHVKAQvru/LiYmef+bIWTmBAQEBP7DrFixmNjYWHR0fBCJ6iKT\nGXPhwjbc3XfSokVz3Nx6EhMTjZfXIKpUseb27ZukpaWRmJiISARSqRQHh/oYGRVm//49SKVSrK1t\nmTv3TzQ0NJkxYwoaGpqEh9/lzZvXjBkzkYMH9xEWdofKla2UXmGXL1/Ez28VaWlpmJmVYNy4yWhp\naX12/MHBV2jUyAl9fcWbe319fR48uM+kSWN4/foV6enpFC+u8JqztrZl0SJfmjZtTsOGjTAxKcLl\nyxe5cuUSvXp1BSAlRSHt/6XBnKtrHVxdP9/ueyOVSvH1PcajR2IqVJAzdKgwn+hbUqOGBQEBFp9v\nmEvef+9+48Z1nJyaK4VoqlatRmjoHXR0dLLsk54uZcGC2dy/H46KigpPn0bk23i+NSKRCFfXej96\nGAICPwWCmqWAgIDAf5j+/YegoWFIRMQekpProqb2hPT0AaxZs4Hbt28TEnINgKioSDp0cMXXdwmx\nsS9ITU1h2bK11KlTjzt3bvP27RuOHTvLtGmziI2NZf9+haS/SCQiMTGBlSvXMWSIF2PGDKdrV3c2\nbtzBgwf3CQ+/R1xcHBs2+LFw4TL8/DZhYWHJ9u2bczV+kUiUTSlwwYI5uLi44e+/jZEjxyGRSADo\n3r0nY8ZMRCKR0L9/byIiHivXr1u3hXXrtrBtWyAtW7bJp6tbcBg7dh/z5rVi166OzJrlxJQp+3/0\nkAS+kJw+8zll/bZv34yxcWH8/bexZs1G0tPTv9cQBQQEviNCMCcgICDwH0Yul2NsrEGFCjvR1j6J\nru5JTE0X0a9fTx49ekRk5FMAihY1pXJlKwCKFCmGqakZ5cqZY2lZCV1dXUxNzRgwoA/Lli0iJiaa\nR48eKY/h4FAfgLJlzTEyMqZcOXNEIhFly5bj2bNobt++yePHD/n9d0969erK4cMHef78Wa7GX61a\nDU6dOk58/FsA4uPfkpycROHCCl+qQ4feBS1RUZGUK2dOt24eWFpWJiLiCbVq1ebAgb+U84piY1/w\n5s2br7yqBY/gYD0gUxSiEEFBn896ChQc3he+sbGpyokTx8jIyODNmzeEhFyjcuUqaGlpk5ycpNwn\nOTlJOYfs8OEDZGR8nQiLgIBAwUQosxQQEBD4j6Ohoc6BA7WYOvU4Fhbt6NdPIQCSOS8jJiYaLa13\nnmJqaqqoqWXK3iuEHfbu3c3ChcvQ1tZmwIA+pKVJ3mv/TiL/Q/l8mUyGiooYe/taTJky47NjXbt2\nJdraOnTp0h2AsmXL4e7uyaBBfVFREVOxogWenn2ZOHE0enr6VK9uz7NnMQAEBGwlODgIkUiFcuXM\nqV3bAVVVVR4/fszvv/cCFA/NEydOw9Dw86bGPxMGBikfLKd+t2P/DHO0voSzZ09TsmRpypQp+82P\n9b7wTe3adSlfvjw9e3ZBJBIxYMAfGBoaoaenj1gspmfPrnTq5EL79q6MHz+Kw4cPUqtWHbS03ik8\nFpQ5rwICAl+PEMwJCAgI/IfJfONvYGCAm1tH1qxZQUqKO1paWjx//py3byWf7wRIS5NgZGRMYmIC\niYm5n5gvEomoUsUaX9/ZR
EVFYmZWgpSUFF6+jKVEiZLKNu+3/xBn51Y4O7fKsq5evYYAyGQyxGJF\nwDl06Mhs+6anp9OmTXtcXd0+O9bZs6fj5tad0qXLsGGDH+7unkDuPMB+NOPHWzJy5GYiIkpStuwT\nxo8v+KqGBZ2//z6Ng0P97xLMQXafwwED/siyrKqqysKFy4F3L2L8/bcqt/fvr5AINTUtjr//tm88\nWgEBge+FEMwJCAgI/If58I2/k1NzZZZKX1+PsWOnZJO9zy6DD02aNKVv355oa2tnM8n+VDB27tzf\nrF27EhUVFYYM6YeGhhYxMVFYW9vy5s1r5s5dxJEjBzh8+ACGhkYUKVIUCwuFeXBUVCS+vnOIi3uD\npqYmo0ePp1SpMsyYMQV1dXXCw+9hY1OVQYOG5njuvr5H2bBBjFSqRsuWL5k1q/1HMxYZGRmMHj1B\nubxx43plMJcbD7AfTbVqFTh2zJz4+LcUKlT1izIz69ev4ejRQxgYGCrvg719DebO9UEikWBmVoKx\nYyehp6dHWFgoPj5TEYlE1KxZ6xuc0ddz5MhBdu7cjlSaTuXKVgwfPgZf39mEhYUikaTi6NhYaVOx\nfPlizp8/i1gspmbN2jRs+D/Onz/L9evX8Pdfy/TpczAzK/GDzyhnNm8+x7p1ichkYtq1gz/+EMRv\nBAR+JQRrAoFvjiCh+2sj3N9fl299bxUP/N7Mn78YVVU1Bg/uy6RJ0+jduwcrVvhRubKVss2qVf7I\nZFI8PbvTrl1H3Ny688cf/Rk5chwlSpTk9u1brFq1lIULlzNjxhTi498ya5bvR4OWoKDb9Ox5grS0\nUsTF9aBIkYlYWQWzbds2rl69wv79ezl37m/atu1AUNBlvLxGsWrVMgYNGsapU8fZtm0T5cqZU7as\nOTKZjHPnzlCqVGlq1KjNgAFDcvSui4mJZsSIIdjY2HHrVggmJkXw8ZmPhobGN7vGnyIv9zc09DZz\n5sxg1Sp/0tPT8fTsTtu2HTh8+ABeXqOwtbVj7dqVJCUlMmTIcDw83PDyGoOtbVWWLVvIxYsXClSZ\n5ePHj1i+fBEzZ85DLBYzb94srKysqVu3Pvr6+shkMoYOHcDQoSMpXLgw/fv3ZsuWXQAkJSWio6PL\nzJneODjUp2HDRj/4bLKTeW9DQx/Qtm0qcXG1AdDUfMTKlQ9xdq75g0co8DUIf3d/XQRrAgEBAQGB\n78alS6EsXPiIlBRVmjRRYeDAJnna//r1YOLji+Dg8AhNzQTq1ClLSMi1LGIrN25co0GD//0b8Gjg\n4NAAgJSUFG7evMHEiaOV/aWnSwFF9u9//2vyyezTvXvRxMc3w9BwI3FxPVBXf0BSUipSqZQbN65T\ntWo1jh8/QpUqVsrMXmZGsn//wQQGBrBu3RZA4QH26NED5fLlyxeJjHzK6tUblN51ISHXKFKkKJGR\nT/H29mH06PFMmjSWM2dO0rSpc56u24/g5s0Q6td3RE1NDTU1NRwc6pOamkJiYoLSxqF585ZMnDiG\nxMREEhMTsbVVGKY3a9aSixcv/MjhZ+Pq1cvcvRtGnz49AEhLS8PY2JiTJ4/y1197kMlkvHr1kseP\nH1GmTFnU1TXw8ZlK3br1lYI+QDZVyYLG1asPiIt7p86amlqWO3eCcC74HzkBAYFcIgRzAgICAgJ5\nJiEhnj/+iObhw84ABAU9plixC3TsWDfXfVy4cI+wMCvevGkMwD//nKN69bgsYivwYUCmeHiWyzPQ\n09NTBlAfoqmpmeP6TBo3tsPM7Boy2W1EokTU1FKoWtWasLBQQkKuMXToSFRUVHB0bPzZ8/jwgf5j\n3nVFihTF1NSM8uUrAGBhYUlMTPRn+y8YZJfDzy0FNeBxdm5Fv34DlcvR0VF4eQ1izZqN6OoqMm9p\naRLEYjGrV/sTFHSZ06dPEBi4Qzk3raALidSrV5lixc7x7Nn/ANDXv0WNGmY/eFQCAgL5iWBNICAg\nICCQZ0JDH/LwYXXlskRShpCQvJX9qKmVQVf3DCJRKiJRMurqtzA0NM3SpmpVO/7++zQSiYTk5CTO\nnz8HgLa2DsWLF+fUqeOAImC4fz8818cuWrQwK1eWxchIjQYNJuPkZE7jxo4EB18hKipKmY350of1\nj3nXva/mee/eXRISfo5SKRsbW86fP0taWhrJyclcuHAWTU0t9PT0CQm5Dijk7+3sqqOrq4uurh43\nbijWHz166EcOPUeqV6/JqVMnlDYU8fFvef78GZqaWujo6PD69StlNjElRZGBrFPHgcGDvbh//x6g\nEA9KSkr66DEKAmXKlGDePBUcHQOoV28X3t4RNGhg/aOHJSAgkI8ImTkBAQEBgTxTvnxJihe/RXS0\nQnFSRSWWsmXVP7NXVtq1q8mJE28oVcoVAG3tilSvbs2+fe8CqIoVLWnc2ImePbtgaGhE5cpVlNsm\nTZrOvHmz8Pf3QyqV0qRJU2XWKzdBmI1NeTp1asKBA3/Rvv1kypUzZ9EiXypVqvzZfVVVVZFKpaiq\nqmbxAAOoVas2q1evoGlTZ7S0tIiNfYGqqlq2PkJD76Cjo/PZYxUELC0rU69eAzw83DAyMsbcvDx6\nerqMHz+FefN8SE1NxcysBOPGTQZg3LjJ/wqgQI0atQtcBqtMmbL89lt/vLwGkpEhR01NjWHDRlGx\nogVdu3akSJFi2NjYAgq/tjFjhpOWlgbIGTzYC4DGjZsye/YMdu7czrRps76bAEr//p4sX+730e0u\nLq3Zu3cPoFBxbdq0Gk2bftmxnJzqc+zY2S/b+V/27NmFpqYmzZu35ODBfdSsWYfChQt/VZ8CAgLv\nEARQBL45wkTdXxvh/v66fO7eHjx4lUWLnpOSoo6jYzJTprTO80P74cNX2b//FWpqaQwbVp1SpUw/\nv1M+cvXqFUaMGMLhw6fQ0NCkS5cOtG/vQqdOXWnatCFHj55Rth08uB+DBg3DwsLyX3XDv7GwsGTi\nxGl4e0/gwYNw7O1r8fTpE8LD7xEf/xZDQyN0dfVQV1dHIkklJiaGbdsCuXHjOt7eE9HR0aZo0WIs\nX+733YVQ8vrdTUlJQUtLi9TUVAYN6svo0eOpUMEiWzu5XK4UCSloQdx/AVfXNuzZs5v0dPFX9+Xk\n1IBjx/7Oh1EpGDy4HwMHDsXSslK+9flfRPi7++vyJQIoQjAn8M0RfnR+bYT7++vys93bK1duEhn5\nmiZNqqGnp/iDGBYWyuHDBxg6dMRH9wsPv8fLl7HUqePw1WM4ffoEly5dZPTo8YBC+XDEiCHMmuVL\noUIGnDhxlMuXLzJ27KQsweGPIK/319t7Ao8fPyQtLQ1n51Z0794zW5t79yIYOjSIBw/MKFXqGXPm\nVMbOrkI+jvrH8ezZS0aO/JsnT/QpUyaeuXMbUrSo8XcfR2a27OXLl0yePJbk5CRkMhkjRozFxqZq\nlmBu7NgRvHjxnLQ0Ca6uXWjTpr2yD1fXLly4cA4NDQ1mzZqPoaER0dFReHtPIDU1BQ
eHBgQEbMtz\nMHfo0H62bduMSCTC3Lw8ZmYl0NLSxtTUlBkzvDExMUFDQ4O+fQfw11978PGZB8CVKxfZvXsXM2fO\nzfdr9qvxs/02C+QeQc1SQEBAQOA/yZQp+1mzphppabZYWe1h48ZamJkVxdKy0mezAOHhd7l7NzRf\ngjlz8wosXbqQ5csXU7duffT0dHn48AEDB/YlKiqJtDQRGhradO4cBRRccZCc+NC0OiemT79OUJAH\nAG/ewPTpW9i169cI5saMOcuRI+6AiLAwOaqqG/Hza//Njjdy5B9MmTIDHR3dD7Yosp3Hjh2mVq06\nuLt7kpGRQWpqarY+xo6dhL6+PhJJKr/95oGjY2P09fVJTU3FysqGvn0HsGzZIv76azceHr1ZuHAe\nHTq40qxZCwIDA/I85ocPH7Bhgx8rV65DX78Q8fHx7Ny5DZEIHB0bs2vXjiwvMJYs+ZO3b+MoVMiA\nAwf20apV2zwfU0Dgv44ggCIgICAgUGBJSUlh5Mg/6NmzK+7unTlx4hhBQZfx9OyGh4cbPj5TiYmJ\nYdMmM0QiCSVL9iY+fgd9+vQlOTmZ4OAgRo0apuxr5kxvfvvNA0/Pbpw7dwapVMqaNSs4ceIYnp7d\nOHHiGG5uHYiLiwMUZuFubu15+zYuV+MtWbIUfn6bMTcvz+rVyzh9+iRly5qjodGV27fPEB5+hlu3\nDjFhQhBQ8NUQ88rr11pZll+90vpIy5+PqCg93qmriv5d/nbMnbswWyAnl8uVLwAqV67CwYP78PNb\nxYMH99HW1s7WR0DAVnr27Eq/fp68ePGcyMgIANTU1Khbtx4AFhaVePYsBoBbt27QpEkzAJo1y7t/\nQXDwFRo1ckJfvxAA+vr6PHnymNevXyvb7Nmzk6Cgy/8eowVHjhwkISGB27dvUbt27tVwsx733fdc\nQOC/hpCZExAQEMhHZDIZYvHXz1URUHDp0gUKFy7C3LkLAUhMTMTdvTOLFq2gRImSTJ8+mQMH9pKW\nVh9TUy9iYv5EIrHC0XFztjloGzb4YW9fk3HjJpOQkEDfvh7Y29fit9/6c/duKEOHjgQgIuIxR48e\nolOnLgQFXaZ8+YoUKmSQq/G+fPkSPT09mjZ1RkdHlz17dhIXF8fr169QBALpqKs/ISZGl5IltUlK\nSszPy/XDqV49lcuX4wF9IBU7u/gfPaR8o2zZeEJCMlC8B8+gbNm3+dZ3TuWQLi6t8fPbRFJSEl5e\ng6hSxZq7d0OVwZytrR1Ll67mwoVzzJw5hc6du9G8eUtln8HBQVy9eoWVK9ehoaHB4MH9/hVxAbH4\n3eOfiooImUyWZTw7dmxRBnWfYseOLbRt2wENDYUViEiU3cLiyZPHqKm9EwBq185FmZlr0aINo0cP\nQ11dnUaNmqCiIuQYBATyihDMCQgICOSB9evXcPToIQwMDClSpCgWFpW4cOEsFSpU5MaNkH8VFSuy\nbNlCZDIZlpaVGTFiLGpqasqHM339QoSF3WHp0oUsXryStWtXEh0dSVRUFHFxcXTr5k7r1u1+9KkW\nCD4sW9TW1qZ4cTNKlFCoaDo7tyIwcAcNG6Zz544xEokVJUocoWtXy2xB9eXLFzl//m+2bt0IQHp6\nOs+fP8uS7QBo2bINY8YMp1OnLhw4sJeWLVvnerwPH95n6dKFqKiIUFVVY8SIsaioqDBkyBhKlTqJ\nSJTBmzfuVKiQQosWrZk3zwdNTc0fIoDyLZg0qQV6eke4e1eF0qXTGT0699euoDN/vhNqapuIiNCl\ndOkEfHy+UCIyB7KXQzbKkrWNiopk4sSpVK5shZNTAwCePXuGiYkJrVu3Iy1NQnj43SzBXHJyEnp6\nemhoaPDkyWNu37712XFYW9ty4sRRAgK2IZVKc2wTExPNiBFDsLGx49ChfZw7dwYrK1siIp4QEfGY\np08jePAgnKlTfThx4hgPHoTz4sUzunbtSGpqCitXLqFNm/Y4Ojbm8eOHREZGsGDBXOrXb0h6erry\nt9LZuRXnz59FJpMybdosSpUqw507t1i0yJe0NAkaGhqMHTuZUqVKf+XVFxD4uRGCOQEBAYFcEhp6\nmzNnTuLvv4309HQ8PbtjYaGYj6Uo19uARCKhS5cOWTJHu3fvpFOnLp8sqXv48AErV64nJSWZXr26\nUadOvWzy3R+TCf+Vpb8zyxb/+eccq1cvo3r1Glm2ZwZhkyY1Y/To07i5BdCypQWVK5fNsb8ZM+ZS\nsmSpLOvu3Mn6kFukSFGMjIy4evUKoaF3mDJlZq7HW7NmbWrWrJ1t/e7dW5g06QiPHmlRr14y06Y1\nRVdXl4YNG+XYz6BBfRk0aNhPp/onFosZMaL5jx7GN0FPT4+lS7/NHLmAgK2cPatQTn3x4gVPnz7N\nsr1oUVMqV7YC3pXmXrsWxNatG/+1x9BhwgTvLPvUqlWXPXt20b27KyVLlsbK6p2/3Pu/Renp6Vy5\ncomePbsikaSyYsUSXrx4ztatG5FIFPPw5s3zISwsFIkklerVaxIZ+RQHh4aIRCLu3r3Lw4cPcXHp\nRJs27Zk/fzb//HMBZ+dG1KpVBzU1NRwcGjB+/BTOnDnJzJnePH0agb19TWbO9MbTsx+nTh1HU1Mr\ny2+lgYEhfn6b2L17J1u3bmL06AmUKVOWpUtXIxaLuXLlEqtWLWX69Dn5f0MEBH4ihGBOQEBAIJfc\nvBlC/fqOqKmp/fuAUl+5rXFjxVv6iIgnOWaOOnXq8tF+RSIR9eo1RF1dHXV1dapVsyc09Bb16zt+\n2DLH/du166j8/6FD+ylXrvwPCeYSExM5duww7du7EBwcxLZtm5kzZ8FX9flh2WJgYADPnsUQFRWJ\nmVkJjhw5iJ1ddcqWLYeqqpTWrUthaVmW5OQkZelXJjVr1mbnzm0MGzYKgHv3wqhY0TKbTxxA69bt\nmDp1Is7OrfJlXpumpiY+Pq1YtuwEkZFaHDlyg44dPz4/SCQS/XLz6QRyJudySEmWNlpa7z7LmXYZ\nzs6tcHZula2/gIC/MDDQIz09gXnzFuV4zPctN9TV1ald2yGLAmv37p3Q0NCgRo3adO/uSokSJVmy\nZBXdu7ty8+Z1Chc2oVGjxuzcuZVOnboQEnKNq1eD8PNbjYqKGD09PczMSlC8uBnq6ho4ONQnODiI\nAwf20bBhI6pXr8H06ZOJj3/Lhg1rady4KQ0bNmLYsEHK38rMFx0VK1py5sxJABISEpg2bTJRUU8R\niUQfzR5msnbtSrS1dejSpfsn2wkI/MwIwZyAgIBArsk+HyQTTc2chR7kcrnyoVwsFpORodhfIknL\n1nbLlg2oqyuMt/fv/4udO7ezcOFyrl69wv79ewFYtWpZNjnxzAcWU1NTwsJCmTp1grJ079GjhyxZ\nsoCUlBQKFTJg/PjJGBt/m0AvISGe3bsDaN/eJd/6/LBs0ctrFMnJyUycOBqZTEalSlVo184FVVVV\npk71YcGCuUgkEjQ1NVmwYOm/QZGir549+
7Bo0Xw8PNzIyMigeHEzZs9egJ2dPZs2radXr650796L\nxo2dcHBowMyZ3rRokX9lgsOH72bLFldAl61bH/D27WmcnSsyfPhgLC0rc+9eGGXKlGPixKwZlnnz\nZhEWdgeJJBVHx8b07t0PUGSKFy2aT0pKKmpqaixatAJ1dXVWrFjC9etXSUtLp0MHV9q27ZBv5yCQ\n/7xfDvn48aNclUPmJ+bmFfD1nUvnzl4UKWLOpElugKK0c8IEb6ysbPj9d0/c3NoTF/eGxMTELGIr\nYrEKGRkZ3LwZgra2NkWKFOXx40c8fvwQU1OFb+SHLyZOnTpOZGQ86ekZFC9uSq9efXnwIJz3m6mr\nqyn7z5zTt2bNCuzta+DjM49nz2IYPLjfJ89NeCEi8F9ACOYEBAQEcomNjS1z5sykR49eSKVSLlw4\nS5s2igflzCCvVKnSxMREZ8kcVa1aDYBixUwJC7tD7dp1OXPmhLJfuVzOuXNnGDp0JFu3biQ8/B7G\nxsaA4s3zjRvXqVq1GsePH8lRTjwzYPlQ+lsqlfLnn3OZPfudx9mqVcsYO3bSN7k+K1YsJioqkl69\nuqKqqoqmphYTJozm0aMHWFhUYtKkaYDC+y2nADM8/C5z5/ogkUgwMyvB2LGTqFmzNhs2+FGxogU3\nboRw8eIFDh7cz9atu1BVVSUpKZEuXTqybVsglpaVWblyXZYx2dlVx86uOgAaGhqMHDku27j19fVZ\nvXoDoBCwCQq6QUzMU8qXr5iv83HOnzcEFOqEqanmHD9+HWdnePo0gnHjJmNlZcP/2TvzgBjzP46/\npmu6iwodEklFypH7vuW2ZLGI3NbNSm65rXWvYyMikZy5WbfcVO4rQqdC0TXVzPz+mF+jUVjk2n1e\n/3iO7/U8M43n83w+n/dnzhwfduzYptJvwIAhGBoaIpVKGTlyCJGRD7C2LsXUqRPw8ZmLg4Mj6enp\naGlpsXfvbvT19fH13UBWVhZDhvSjevWamJtbFNp1CBQu7w6HfGOIfEmjJClJQlTURNLScjAy2kqX\nLjMwMpJjamqGk5MzsbExxMfHYWdXjqioR5QpU5Z79+6ojCESiVBTU6NKlWr4+MyhU6fW2NiUYdCg\nYVy6dJG0tDQMDAyV7S9cuMb9+2uxtBzG1auehIU95ty5/TYq9kAAACAASURBVEoBFLlcztq1qwkL\nu4JEkoWmpuJxNS0tjVu3bhASspOXL18oX4qdPXuaa9fC6N27O1ZWVkye7JPPMy8g8G9FMOYEBAQE\n/iEODuWpW7c+Hh5dKVrUBFvbsujr66uExInFYiZMmJrPcwTQp88A5s71Yc0afSpXrqrsoyiua8eK\nFUu5c+c2w4eP4vTpk5QpY8udO7eJiAhj5Mjf8smJX758ocB15hqWT55E8ehRJCNHDgEUMvsmJmZf\n7P4MHjycR48esm5dIGFhV/D2HkNAQDAmJqYMHtyXa9fCKV/e6Z0G5syZUxk92gsXl8qsXbuadev+\nYvjwMcpwqjVrFAZXXFws586doV69hvz992EaNmxcKAqi2dnZ9OkTzKVLLzE23ouTU3MVz+rnoqcn\neWtf8SBarFhxnJycAYVUe3DwFpV2x44dJiRkF1KplOfPk4iKegiAiYmpMqcu11Ny6dJ5IiMfcOKE\n4mVBWloa0dFPBWPuO0ZTU7PAcMh16wLIzs7G3NwCf/8tBfQsHLZvDych4RfkcjEymQEyWQBGRmqA\n4nckLS0NLS2xUpHy2rVwdHS0kUiyUFNTVypkurhUJjT0FP369URf3wCZTEZcXCz6+voEBm4kJyeb\nYsWKk5WlS0aGBnK5AfHxszEzm8/s2c9p2rQmGhqKOTIzJTx8GIm//xYuXbrA+PGjef48CWfnSvz1\n159YW5eibduOHDy4//9zV+Hp0yfMm7cIX9+V7N27m06dfv5i90xA4HtCMOYEBAQEPoJu3Xri6TmA\nzMxMhg4dgIODYz7lyapVq+HntylfXxeXSmzevKPAcW1t7Zg0aTojRgxBLpdTsaILtrZluXr1EjEx\nMdjYlP6gnHguucaHXA6lS9uyapXfp17uR5E3BFUul+PoWAFTU4XxWLZsOeLj49DX1y/QwExLSyU1\nNRUXl8oAtGzZmsmTxyvHy81JBEU+W2DgBurVa8iBA3vx8ppUKOtft+4Yhw97ANq8fDmex49jOHbs\nIk2a1CgUQZJRo0yZOnUfsbFlqVAhjDFjVAUtgHzGY2xsDFu2bGLNmo3o6+sze/Z0srKyeJ99OXr0\nOKpVyy/CIvDjMGfOfjZsMCYnR0yLFn+zdGnnLybbL5XGY23dGblcHblck+TkrtSpc5Xdu7fRt29P\n1q7diJqaGteuhSOVyrCxKU27dh05efIopqamnDhxjOzsbPT19Zk/f7HSQ6+oxReDtrY2GzYEKfNo\nvbwmsXfvVIyNN5GYOIEnT3ZSp44f48e7c/ToEQCaN29B2bLlEIlEVK9ek0aNmnL79i2SkhIZOfI3\nqlRxxdDQkP79BwOgr6/P69ev8fDoSnp6BjVq1Poi90pA4HtEKOghICAg8BHMnz+LPn2607dvDxo2\nbIydnf1njbd790X8/e+zbNl9Jk3a/X+DL4BKlarg4lKZXbu2U65cufeOoZDWV2zr6r6pXWZtXYrk\n5JfcuHEdUChuPnr08LPW+zFoamopt/PmvZQubcu6dYGsWxeIv/8WFi5cxjtSEZXkzUmsWNGFuLg4\nrl69jFQqpXTpMoWy3rQ0OfAmNEsmK8rLlwphlM/xzuWKNLRvX52TJ505efI1+/c3xsHBBoCEhHjl\nZ3TkyEGcnV0AxeealpaGtrYOenp6vHjxnPPnzwJgbW3D8+dJ3LlzC1DkXUmlUqpXr8WOHduUcz55\n8pjMzMxPXvv3wODBnu8937lzW169Kpyab82a1ftwoy/MpUs3WLnShefPW5CS0pCtWzsTEHDii803\nffpAHBzakpAwlVevfqV/fyk9evSiVCkbbGxs6NHDnbJl7di+fR9z5y4kOfkl27dvRV1dg2LFihMY\nuB03tzbY2tqxbp0vGRmZjB07nk2bgnF1rabytyMSgY6ODt27N0BLK5JSpdywt69P/fqqnvWC6tUB\nZGVlsWrVJWrVklCr1k3WrFEIucyePZ0xY8bj778FT8/++QRkBAT+zQieOQEBAYGPYOrUmYU21vPn\nz5k0KYeEhFUAREY+Z9iw9bx48Rwnp4qIxdqIxWKlt0r1oUh1O3f37dplM2bMY8mSBaSmpiKV5vDz\nz90Lzfh5m4JUId/G2tpGaWA6OVUkJyeHp0+fULp0GQwMDImICMfFpRIHD+5T5roVRMuWrfDxmUzv\n3v1UjuetgXXjRgRmZsWYM+cPxowZpvSsJScn079/L4KDQ9i/fw+nT58gMzOTqKgo7OwukJTkiIHB\nXnR1k6lb9y/l2IcO7WfevBlIpVK8vafg6FiBjIwMFi2az6NHD5FKc/D0HEDdug3Yv38PJ08eIzMz\n
E5lMxrJlqwEwMjLOV4Dc2roUO3duZe5cH2xsytCxY2dCQ08jEomwsytHuXL2dO/eiWLFSigNvYIE\nXxYvXkHbth2Ii4ulb98eyOVyihQpyuzZv3/U5/i9sXLl+z3LhZtP9u0FM6KinpGZWSnPESOePcsv\nmFRYaGlpsWFDV2JiotHRscHEpCpxcbGoq6szefIMlbbvii6oWbMZCxfeJDOzDi1bimnTpgGASoho\nlSquVKniCsDo0a3p0KECT548o1q1Cujp6amM5+xcmd27d+Dm1oaUlBQiIsIYOnQkQUFXSUhIJCvL\nkcTEyixevIMuXVLIyEinaFETcnJyOHRoP8WKFQd4p2CVgMC/CcGYExAQEPhGREY+JSGhvHJfLjch\nM9Oa48fPKY/lfXDKKyfesGETGjZsAoCn5wDl8QYNGqvULrOzK8fy5W8Mki+JkZExFSu60KvXz4jF\nYooWNcnXRkND450G5sSJ01iwYA6ZmZlYWloxYcLUd87VrFlLfH1X0qxZi3znoqOfMn36HLy8JjJl\nijcnTx57r9R/bp6fRCLB3b0dDRqIKF36ZzQ0rnHq1HG6dOmGXC5HIslk3bpAIiLCmDPHhw0bgtiw\nwQ9X1+pMmDCV169fM2CAB66uNQC4f/8e/v5bMDAweO99K+ihOdf4A955H94WfJHL5aSkJNO370AG\nDvz1vXP+SOTWV0xKSmLqVG+lF3LsWG+cnSuptPX2HsuzZwlkZUlwd+9Gu3YdlWO4u3dTUYJNSEhg\n+/Ygnjx5zMuXL5QvTYB8c40Z442Li+pcoPAK+vkFYGhoVGjX27RpFRwc9nLnjjsA5ubHaNnSrtDG\nLwiRSKQsp5L32D8hPT2dwYNvcfu2oqTAqVP3MTa+QIcONfK1vXs3ipUrbyKTqdOtmzUNG1YvcM4G\nDRpx8+Y1evdW1JwbMmQERYoURUvLkdTU4lhbd0Iu10QisefVKwf69RvEgAG9MTY2pkIFJ+VLpbwv\nugQE/q0IxpyAgIDAN6J8eVvKlj3PgweKhyht7UhcXY0/0OufExQUyvnzqRQvLmX06KbKsgdfknd5\nLnNru8G7DUw7u3L51ChB1bDJ5dq1cBo1aoqenn6+c+bmlpQtq3j4tbd3IC4u9r1rrlzZFR0dHXR0\ndDA0NGT27CGYmpqyb5+UyMj7gOKhsGlTheHo4lKZtLQ0UlNTuXjxPKGhp9i8eSOgEFFJSIhHJBLh\n6lr9g4Zc7tify/PnL+jf/2+uX7fFxOQZkycXo3Xrd3s2fywU9+fIkYPUqFGLXr08kclkBYaPentP\nwdDQEIkkk/79PWjYsAmGhoZkZmYWqAT76lUKP/3kTnT0UyIjHyjHyTuXXC4nIyOj4JV9AUuhSBFj\n1q934s8/t5CTozB6nJxsC32e9/ExoiuRkVHcvv3GKMvMtOPixXA6qKYSk5T0Ak/Pu9y/ryh9cPLk\ncTZtilS5trwvrIYMGcGQISNUxmjRwow9e0rz+PEAQEq9en5YWFjSoUNnOnTonC/nNO+LLgGBfyuC\nMScgICDwkUil0kJRT9TXN2D58lIsXryZzEwtmjfXpEOHRoWwQli37iRTppRHIikDSHjwIIA1a7oU\nytjfkjNnIli79i9evoxi+fJVBbbJrU8FoKamjlQq+X+NP0XO3tv5NKrt1ZT7ampq7xSZAZRv/GfN\n+p2SJa1Vzt26dQMdnYJrD+alsJQKZ806w5kznoCIlBSYM2czrVoVnhLn90D58hWYM8eHnJwc6tVr\niJ2dai5pXFws/ft7KEV3oqOfsnLlUqKjFQWm163zZdmyhTRr5kZ8vCLn8sKFcwwfPoYVK5YAIjIz\nM7h2LZySJa2ZPn0iO3YEY2BgwLhxE6lY0YWUlGSmTZtIUlIiTk7OXyyMr0wZK/74w+qLjF3YWFmV\noHjx2yQk5IZvv8LSMr8kw5EjV7l/v51yPy6uEYcPB3+Uodq8eRWWL7/KoUPbMDDIZuzYNqipqXH7\ndhReXuFERxtgZ5fC4sX1MDf/csq9AgLfE4IxJyAgIPAW69ev4fDhAxgbF6FYseLY2zty9uxp7OzK\nce1aBE2bNqds2XKsWLEEqVSKg0N5xo71RlNTUyXs6s6dW/z55xKWLVvN2rWriY2NJiYmhuTkZH75\npRdt23bA2rooBga7UVdP49QpKdWqFSkwnOtjOXlS8n9DDkDMpUtm5OTkoKHx4/7sr1t3kpkzy/D6\n9QaMjcM5deox3buX/HBHFEbT3bu3cXSsoJTt/xBvq3MeO3aEKlVciYgIR1/fAD09fapXr8m2bVuU\nnsd79+5QrpwDERFh3L59E4BTp05gbV0KG5vSH3nF/5yUFG3y5nu9fGlEVlYWYrH4i835tXFxqcyf\nf/py9uwZZs+exs8//0LLlq2V52/evE5mZgarV69DLBbz888dlMa4mpoavr7+nDsXysqVy1RUSUuU\nMKd9+05oaKizcaM/zs6VmDZtIlOmzCQ5+SVBQYFMnuzFrl0HWbfOFxeXyvTu3Y9z586wd+/ur34f\nvjeKFCmKj48GixcHkZ4upn79ZAYP7pivXenSxdDWfkRmpkI0SiR6QYkSH//9bN68Cs2bqx6bNCmC\n8+d7AhAdLWfKlE34+rb/+IsREPgB+XH/VxcQEBD4Aty+fZOTJ4/h77+F7OxsPD17YG+vePDLrXUm\nkUjo1u0nli5dhZVVSWbOnMrOndvo0qXbez0hDx9Gsnr1ejIy0unT5xdq1ar7j8O5PhZ9fdUQNEPD\n9ELxJn5LtmzJ5PVrRY5hcnIltmx5QPfu+du9/RmIRCK6devB5MnehITspFatuuQaPm/n0snlb4y4\nvOdEIhFaWlp4ev6iFEAB6N27H0uX/oGHR1dkMhkWFpbMm7dIZczTp09Qp069L2rM1a6txaFDT8jK\nsgakVKoU+68y5ADi4+MxMzOjbdsOZGVJuH//rooxl56ejpqaGmKxmMePo4iPj1PmweV+9+3tHUhO\nfqnsY2BgyNGjhwG4e/dNIeyLF88TFfUQkUhEWloqaWnpZGRkEBERxuzZCwCoVauuSiHs/zIdO9ag\nY8f8pTXyUrOmM4MHHyAgIJKcHDGtWj2hW7dOhTJ/QkJeARURz57pFsq4AgI/AoIxJyAgIJCH69cj\nqFevIZqammhqalKnzhup8txaZ0+ePMbCwlIpGODm1oYdO7bSpUu3d44rEomoW7cBWlpaaGlpUaWK\nK7dv3/hg6Nin4uVVncjI9dy44UyxYk8YO9bshw+5e3v5IlH+ELe3wxa7deuh3Pb336zczq1P5ebW\nhkqVqtCt209UqFARHR1tduwI5uzZ02RlZVO/fkMA5s9fzJQp45HJ5MjlcmJjY3FwKM8vv3TO54kF\nRfkELS0tbty4RmjoacLDw/D3X8vMmfOxtCz88Lm+fRuioXGSCxcuUaSIhAkT2hT6HAXxJQRA3ib3\nexsWdpnNmzeioaGBrq4ekyZNV2lXtWo15HI5PXq4U7JkK
aWiYd4x1NTUkclkyuOlS9uyY0cwMTHR\n2NqWVbbLyclGKpWiqamJpaUVkyZNV4bNCgqJ7+ZDvzHe3m4MH56GTCbFwKD6e9t+DI6OKdy7JwXU\ngUwqVCicl2ICAj8CgjEnICAgoELB9Y1AtdZZXvK+jVbkZin6SyTvlxMXidQ+GDr2qZQsWYI9e9oT\nHx9H0aK10NX98d9U9+ihy8OHYSQnV8LE5DK9ehWeARETE83kyT6kpaVy/PhRfH03IJPJGD9+DBER\nYSQnv8TUtBi//64w1tLT0wDVh9f7959y40YSder8jYXFZapXF+Pk5EzduvWpU6eeisrol8DDowEe\nHl90iny8qx5YYZIriuHm1gY3t/xGanBwCAC6unqIxWJWrFiDtrYOw4YNpEQJc+LiYlm+3FfZXkdH\nhwkTpnL16mW0tbVZunQlW7YEkJaWxuLFKwCoU6c+dnb2dO+uCN27f/8eJUqY4+JShSNHDuLh0Zdz\n50J5/frVF732fyNvlyEoDBYvbomR0WZiY3Wws8ti0iS3Qp9DQOB7RSgaLiAgIJAHZ2cXQkNPk5WV\nRXp6OmfPnlaey31otbYuRVxcLDEx0YCi/lilSlUARf5NbiHnkyePqvQ9c+YkWVlZpKQkExZ2BUfH\n8sTHx2NsXIS2bTvQpk0H7t+/W2jXoqGhgZVVyX+FIQfQo0c9tmzJYcaMYIKC1OnUqVahjV28uDnl\nyztx4cJ5Ll26oCwM/+TJY6Kjn1KmTFkuX77AypXLiIgIR1c3/wPp8uWRpKQU5/79joSF1eLixafK\nc/8Gb05GRga//TaC3r2706vXzxw9egSAbduC8PTsgYdHV548iQLg1asUvL3H4OHRjYED+yiVIj08\nupKWlopcLqdVqyYcPLgPgBkzpnDp0oXPWp+Ghga9e/ejf38PRo8eSqlSNkD+UNq8uYW5h+vUqc/J\nk8dp3botw4atxtW1FXfv3sLDoxs9enRh925FiRBPz/5ERITRs2cXTp06QYkS5p+1ZoHCQU9PjwUL\n2hMY2Jzp09ugqan54U4CAv8SBM+cgICAQB4cHMpTt259PDy6UrSoCba2ZdHX11d5IBSLxUyYMJXJ\nk72QSqU4OlagQ4fOAPTpM4C5c31Ys0afypWrquRc2draMXz4IJKTk+nTpx8mJqYcOLD3vaFjPyJx\ncbF4eY1iw4agTx4jLOwKmpqaODk5qxyvUsWBKlUcPneJ+dDR0VZu9+jRm/btf8rXxs9vE+fOncHX\ndwWurtXp3buf0hMrk8lITMybo6ZOWtqbHMUfPcQV4MKFsyreybS0VFatWoaxcRH8/ALYuXMbmzcH\n4OU1ibVrV2Nv78icOX9w9eplZs6cwrp1gVSs6MK1a+EUL14CS0tLrl0Lp2XL1ty8eYNx4yZ89ho7\nd+5K585d33ne2NiY4GCFaEneItYlS1qjqdmBS5d6cumSNocOXWHx4k5Mn+6q0t/Q0IiFC5d/9joF\nBAQECgvBmBMQEBB4i27deuLpOYDMzEyGDh2Ag4MjbduqFk2qWrUafn6b8vV1camkUug7L7a2dvmM\ntXeFjv3XuXr1Mrq6evmMuS9NjRo18fVdRfPmbujo6JCY+AwNDU2kUikGBgY0b+6Gnp4++/YpQvty\nPbE1a9bG1PQySUm5RpuEIkUUuVm6urqkpaV91ev4Etja2vHnn0tYuXIZtWvXU6qu5oaPlivnwMmT\nxwBF7umsWb8DCqMpJSWF9PQ0nJ0rEx4eRokS5nTo0JmQkJ0kJSViYGCAWKxd8MRfgdevX3HmjC2g\nWENyclX27dtOq1Zv2oSEnOPu3RRq1LCgfv2v+70UEBAQeBeCMScgICDwFvPnzyIq6iFZWVm4ubXB\nzs6+UMbN65yZO3c/e/ZooaEhpW9fXXr1qvfujj8gUqkUH5/J3Lt3BxubMkyePJ1Hjx6xfPkiMjIy\nMDIyZuLEqZiYmBIcvIXdu3egrq5O6dJlGDRoKCEhO1BTU+fw4f2MHDmuUMo1vI9cz1m1ajWJiopi\n0KA+gMIQmzTJh5iYaP78cwlqaiI0NDQYO1bhRcrriW3SxJGjR69QqtQOzMzCcHCwARTCOfPmzWLb\ntiBmzJj7RQRQvgYlS1qreCerVq0GvKnTp66uWpcvf2ipiEqVKrNjx1YSEuIZMGAIp04d5/jxo8ow\n5W+FWKyNru4rXiqFLuVoa7/JeV2w4CBLllRHIimFoeF1fHxO0737v+tvVkBA4MdEMOYEBAQE3mLq\n1JmFPqan5wDl9s6doSxfXoesLMVD/YwZF6hWLRJHx39ePPd758mTx3h7T8HJyZk5c3zYvn0rp0+f\nYM6chRgbG3P06GH++msF3t5T2LTJn23b9qChoUFaWip6evq0b98JXV1dunbt8eHJPpO3FTDd3bvi\n7q4aqmdpaUX16jXz9X3bE+vllbvVjPT0dORyORUruhAQsPWz1zlv3kx+/vmX95Y4OH36BCVLfpma\ndklJSUrvpL6+AXv27HpnW2fnyhw+fIDevftx9epljI2LoKuri66uLsnJyUilOVhYWOLsXInNmzcy\nerTXO8f6GmhpaTF4sDoLFpwgObkUlSqdZsyYOsrzISHqSCSlAHj1qiK7dt0tsCyGwIc5cGAv1arV\nxNTUFPg6iqgCAv9mBAEUAQEBga/M/fuvlYYcQEqKC9euRX27BX0BihUrrgyRbNGiFRcunOfhw0hG\njRpCnz7d2bDBj8TEREARvjdt2kQOHz6AmtqbPLMfVTPk8eM42rbdQdWqETRtuodLlwpH1MbLa9IH\njbRTp04QFfWwUOZ7m4cPHzBgQG/69OnOunW+eHj0Ja+YCLzJK/X0HMDdu3fw8OjGX3+tYNKkacpW\nFSo4UbKkwjBydq7E8+dJODt/Wc/rP2HAgIacPm3F4cPxhIS4YWFRTHlOU1Om0lZdXfZ29/8cv/02\ngrS01I/qI5VK2b9/D0lJicpjX0MRVUDg34xI/p38BSUmvv7WSxD4QpiZGQif778Y4fP9eE6fvoGn\npw4pKYoHWCurQ+zZUwZLyxLfeGWqfOpnGxcXy7BhA9m2bQ8AV65cYvv2rbx48ZxVq/zytZfJZISH\nXyU09DQXLpzF338L/v5r0dHRVakT96l8qrdKEf65HXt7ByZPnvGP+/XtG8KePb8o92vWDCQkpG2+\ndnFxsYwZMwwHh/Iq4ajXr19jxYolSKVSHBzKM3asN5qamgwdOoBhw0Zjb+9As2b1cHfvxtmzZxCL\nxcyd+wfR0U/x8hqNnp4++vp6H6xpJ/zt/nMCAs7g41OE5OTKWFicYuFCHRo3dvnWy3onX/qzPXhw\nH9u2BSGV5lC+vBNjxoxn4cJ53LlzG4kkk4YNm9C370BA4Xlr0qQ5ly5doGvXX/j99zmYmZmhra3N\nihVr6dHDHTe3NoSGnkYqzWHGjLlYW9t8sbX/GxD+dv+9mJkZfHQfwTMnICAg8JWpV8+J2bMTaNx4\nGy1aBLN4seF3
mwYBn16jmzZs0KdHV1cXUdxPz5s3F2bsSR\nI4kEB99HW/s3TExM2Lt3xxdJfBcF0dFRjBs3isBA5YPmtm2byMjI4NCh/ezeraxPi4yMYPr0Kfj5\nbWX06GEMG/Y/7O0dePPmNSNHDiYwcG++2LJrVyDXroWhr1+MyZOVSnk3b17H338tcXGvycjI4Pnz\neKKiFiGRxGNk5IujY3Vq1dJnz54gVq/2ZcgQVzQ01FEoFJQsWYrMTBlbt+6kXbtmHDgQgkQiISUl\nmU6d2hIaejpf7Bb451y6dJeAgKeIxXKmTauDsbHx3w/6l3H27Cm8vVfi7j4Te/sqOb5bsSKU/ftB\nIpExaFAxevSoV0RWfh359Xc5OjqKnj07s2HDVsqXr8CMGe44Ozfk8uWL1K/fgDp16tOrVxfWrFmP\nubkFs2ZNIy0tlQULlv2VinmBlSvXkZKSTO/eXVV/CwS+DuG56t+LmZn+Px7zfTz1CAgI/OsJCtrO\nmTPKHlKvXr3ixYsXH5XM79ChE0uWLCMkZDUWFod4/nwHL16oc/LkVZo3r1lka/in5BRvkJCc/Ol/\nztk1NWKxRKU4mR+IRCIMDY24cOE83t4rqVevAb6+a6hVqwZnz/5BSkoKkIaGxkOyspTph2lp6iQm\nJmBgYICFhQX6+vr8/PNUDhzYq3orL/BpNm/2o3//gYV6zpo17ahZ0w747z4QOjs3wtk5t6BNcPBF\nFi92Ij1d2b5h5sw/cHJ6Snj4n+zfvxtbWzumT/+4Um5BolAoOHHiEi9fvqVt25oUK1as0M5tYVGK\n8uUrAMo2J9l1sgqFgufPn1KyZCmVkEvz5q04cED5kkkkElGvXgPU1NQoVswQIyNj4uPfqGoYBQQE\n8gdBAEVAQKDICQu7zJUrl1i71p+NGwOoUOEHMjKkH5XMr1LFkbi4ONTVnwFZZGSU/6v+53URrSB/\n0NXVw8DAgPBwpSjF77//hpNTwafBVatWkytXLuHl5Y2NTXl8fFby4sUzdu7cybx5C3Fyqo6amg4i\nUToAYnEGFStqqsbr6OhSsmRJrl9X2q1QKHj48AGgvFfHjoUAEBLye4Gv5Xtiy5aNRW3CN8vHhDcK\nklu33qgcOYA3b2pw5cpD9u3bxfLla4rMkQNwd99H374V+OmndnTqdJqoqFeFdu4PXzrlfJH0YZpn\nzmQvNbWcaqMyWf69hBIQEFAiROYEBASKnNTUFPT19dHU1OTp0yfcunXzb8e0bt2aN29GExs7HgBL\ny2Batfp8QZBvgbzEG6ZOncXixZ6kp6dTqlTpHHVAH4zONzusrKzp3NmFSZPGoqamjqGhISVKWHD3\n7m0mTRpH1apOqKtnUKPGTeTyGJKSMmnQoCL3798jKektkZERzJgxl+HDB5KWlka/fj1o3rwl5ctX\n4KefJjJ79i9s27YJZ+dG/1lRBHf3ibx69ZKMDCkuLr2IiookI0OKm1tvrK1titRR+KdER0cxYcJo\n7O0duHEjHDu7Sn+pPa4jPj6BmTOVa/HyWkJGhhRNTU3c3WdStmw5goMPcvRoCG/exJGeLqVhw8aM\nHDmGQ4f28/jxQ8aMmQDA0aNHeP06rlDXVb26OXp6t0lOVoowmZuf486dE0RFRTJhwmiaNWtJZGQE\njx8/IitLxsCBQ3F2bsTPP//E8OGjsbEpj5tbbxo1asqAAYNZv96HEiXMv7qWNyoqku3bbZHJygJw\n61YvvL134OGRf/0yv7RlSNmy5YiKiiQmJhpzcwv27dvN27dvAaGhuYBAYSE4cwICAkVO7dr12Ldv\nN337ulCmTDlVHcvHJPMBevXqye7dATRvnoGa2g4GD65QJM2cv5QPxRve79W2dq1/ruNXrPBRXQND\nQ8McTaHzg+zm4SKRshfWxInuXL58jgMHDnL37h2aNWuJubkFbm5DmD9fKcM+ZMgIHByqMmnSWDQ1\ntWjatAVRUREsWPAuzdLComSOptXZAhT/NdzdZ+SqCd29eyf+/gFFbdoXERkZwdy5C3F3n8Hgwf05\ndiwEb28/zp49xebN/kyfPofVq32RSCRcunSBdetWM3fuQgAePXpA+fIV8PRcQu/eXXFx6UmzZi3Z\nssWf//1vLBKJhJMnj6Gjo8ucOdO5f/8ulpbWTJ8+mydPnrBq1TLS0tIoVsyQadNmYmJiSkTECxYt\n8iQxMQGxWMzcuQswMjJmypQJJCW9JStLxpAhI3B2bkR0dBSTJ49j8+YdAAQEbCE9PY2BA4fStetc\nLl78A5FIxA8/lGXOHG+6deuAlZUNu3fvRF1dnbFjJ+LoWI2hQ12pUaM2jo5OhIdfxdzcHDU1NW7c\nuA7A9evX+PnnqV99rTMyMsnK0nxvj4isrPxOrPr4S5aPvYARiURoamoyYcIUJkwYjZaWNjo6Oqp0\n8Y+JuggICOQvgjMnICBQ5Kirq7N48Ypc+0NCTqk+N27cjMaNm6m2r1+/RrNmLfjll86FYmNRsXHj\nadavTyUjQ43WrVOZPbtDgUS3atWqQ61adXLsc3auSZ8+g3Idmx0tjIl5SVKSCC+vtZia5pTyj4h4\nyZQpfxARoY+19VuWLGmKkdG/r3fb55JXTej3woeRuHLlLDE2NmbRonnExydQpkwZKleugrv7BJ49\ne8arVzHcunWDPXuCePjwvirCNnBgXzp27ExWVha3bt1g2LABaGpqERMTTZUqxalWrSbnzp2hXDlL\nZDIZMTHRzJw5F3t7Bzw957B7907OnDmJp+dSDA0NOXYshHXr1uDuPoPZs3+hf383GjRoTGZmJnJ5\nFmpq6nh6LkJHR5eEhASGD3fLs1Yuu38iwIMH5zh27CBqamqkpCQDkJSURNWq1Xjx4jlSaTrTp0+h\nbFlLMjMzefUqBkdHJ3btCsTCoiR16zpz+fJFpNJ0oqOjKFOm7Fdf/3LlytGmzQ727y8P6GFpeZC+\nfW2/eL4Po8QdOyr/hq5cuZSLF//E2NiU2bPnY2hoSHJyEpqaWri69srVKqVECWWdnI1NeTIyMti0\nKZA2bZoA4ObWmz59Bqj6hgIq51lAQCB/EZw5AQGB74qQkDC8vPxITX3A2LGTi9qcAuX+/SfMm2dC\nYqIyncrX9yWVK5+hR4+i79e2f/9Fpk1T8OqVE2XKXGDpUiMaNXqnDDh58nlCQ/sDcPu2Ak3NLXh7\n/7sd74/xfk2opqYmo0cPIyPj402mv0Xej8S5uvYkPT1dFYlbtmwRCoUCe3sHxoyZwJgxw5kxw53B\ng4chl2cxatRYVq1azpo16zl69AgODo7IZFksXLiMSZPGqWqwOnT4kc2b/ShXzoqmTZuTnJyMvb0D\nAK1atWXTJj8eP37EuHEjAaVkvomJGampqbx+HUeDBo0B5cshUEcmk+Hjs4rw8GuIxSLi4mKJj3+T\n5/qyUwJtbCowa9Y0GjZsrJovI0NKUNB2VSqhiYkpHh6eqp6WMpmMu3fvULJkaWrWrE1iYgL79+/F\n1rZivlx7k
UiEj48L9euHkpAg48cfq2BlVeqL5/swSty4cVPS09Ows6vE6NHj2bhxPf7+6xg3bhJz\n585k/PjJebZKyXaAjxwJJjb2FQMG9MbGpgImJuaEh9di/Hh1rKz2sHx57a+yV0BA4NMIAigCAgLf\nDdevP2DcOHUuXdrErVt/MHu2Nk+fRhW1WQXG7dsvSEx811cqK6sET56kFqFF7/D2juPVq+aAKS9e\ntGP16pyRpogIvfe2RB9sf3u0aNGgwOb+WE2ompoaMpmswM6bn1hYlMLa2gaRSETp0mVVzeKtrcuT\nmprKs2dPadWqLQBaWlpkZEjR1zegShVHli1bREpKMklJbxGLcz92ZDtSlSrZ8+rVK0JDf6d+/Zz1\nlQqFAl1dXaysbPD3D8DfP4BNmwJZunQlH4puZBMScpjExAT8/Lbi7x+AkZExUmkGEokEufzdGKk0\nXfV50aLldOniwr17dxkypL/K0Zw2bTbdu/emRo1a7Np1kLJlLbl//y6gvI9mZsU5ceIo9vYOODg4\nERi4lapVnb7iiudE2buuOWPHtv5qxygoaDsDBvRm2LCBqiixWCymWTNlFK1lyzZcv36NlJRkkpOT\ncXRUrqN163Zcu3Y113w//tgVM7PibN26k44dO3P5cjwXL/YjJqYj58+7MmvW5a+yV0BA4NMIzpyA\ngMB3w4kTj4iNfdf3KSKiOcePXy9CiwqW+vUrY2l5QrVtaHiVevW+jTfcUqlGju2MDPUc21ZWibx7\nyM7C2jq5cAz7YgquuKd27XpkZWXRt68La9euVtWEduzYmQEDeuHhMb3Azp1f5FQ0FKscLZFI9Jcz\nJlI5ZWKxGB0dHfz81nH8eCj16jVAoYARIwbx5s1rPrzW7zttTZs2x8GhKrq6urx8GcPNmzcACA39\nncqV7UlIiFftk8lkPHnyGB0dXczMinPmzEkAMjIykErTSUlJwcjIGIlEQljYZWJiogEwNjYhIeEN\nb98mkpGRwR9/nFWt4+XLGKpVq8GIEaNJTk4mLS0NTU1NDh7cw4ABg5HJZPTs2Zl+/bqzYcNald1V\nq1bDyMgYDQ0NHB2rEhcXq3KCYmKiCQ39tJJrXi8TgoMPsmyZss4wOTmZvXt3fXKOz+FjysHwzqlW\nKBR/m8otkUhQKOQAuaLMaWk5k75iY3W+2m4BAYGPI6RZCggIfDfY2BRDQyOCjIzSAOjo3KNixZJF\nbFXBYWZmwurV5nh77yAzU0Lnzvo0bPhtNDFu1SqD+/ejyMwsibb2I9q2zdkIeMmSpmhobCEqSg8r\nqyQ8PVsXmC1paWnMmDGF2NhY5PIsXF0H4+Ozkg0btmBgUIy7d2+zerUXK1euJTU1leXLF3Hv3h1A\nxMCBQ2nUSFnns27dGv744yyampr8+usSjIzyp6H2x2pCnZyqM2LE6Hw5R2Gio6PD6NHjVNvFixen\nevVahIQcZsCAwfz000RWrVqOn99WIiMjKFWqNJMmTeWXXyZjaWlFzZp1WLlyKUCunoTXr4fTs2cf\nRCIRZcuW1bCnbgAAIABJREFUY+/enfz66xwsLa3p1q0ntWrVxctrMcnJyWRlyejRozdWVtZMnz6H\nRYvms379WtTU1Jg7dwEtW7Zm8uTxuLr2xNa2IuXKKdsOqKmpMWDAYIYMccXMrDiWlsr9WVlZeHjM\nICUlGYVCgYtLT/T09Ni//wgrVixh6FBX5HI55cpZ5hD5ARg8eDiDBw8HwNTUjNOnL6q+i4qKJDT0\nCC1a5P4dkMlkqKmpkdfLhPcdqqSkt+zdG0Tnzt0++z5lO2fvz/OxKLFcLufkyWM0a9aS0NDfcXBw\nQldXD319ZasUR8eqOVqlWFiU5O7d29jZVeLkyWOq+XV1dTE0TAYyAXUgHXv7b/1FjoDA943gzAkI\nCHw3tG9fl+HDD7FvnxZisZw+fUTUrduiqM0qUN5v8vwtMWVKW6ysznD//jmcnIxp375pju9NTIxY\nt65wauQuXPgDU9PiLFrkBUBKSjI+PivzPHbjxvXo6+urlESTkpTKe+npadjbOzB06EjWrFnBgQN7\ncXXNLf7yNWzbdpZt21IA6NlTm/79i7728XPIysrKs43G+5/d3Ibg6TkHV9deaGtr88svswBlSl9Y\n2GVEIjHW1jbUqVMfyE4b7E2bNu24eVOfK1dEpKWtompVW6pVqwHAtm25I1EVKvzAqlXrcu0vXboM\nXl7eufa/r6T6Pt269aRbt5659q9Zsz7XvqNHQ7h58wYikRhbW1sGDx7OmDHDSUxMxNDQiKlTZxAb\nm8qYMXNJTdVHR+cpRkZyRo8eS+PGzfDxWcXz509xc+tNmzbt0dc34OTJY6SnpyOXy5k3bxFSqRRX\n115oaGggEomQyWTEx79RNev28VlJZGQEbm69qVmzDiNHjiEgYDMnThwlIyOThg0bM2jQMKKjoxg/\nfhTVq1cjPPw6ixevoESJdyq/H1MO1tLS5vbtW2zatAEjIxPmzJkPwLRpebdK6dWrL9Onu3PgwF7q\n1nUm2xl1cqqBmZk/1ao1R0enCY6Otkyd2j7PeyAgIJA/iBTfSCOQ2NikojZBoIAwM9MX7u+/mKK4\nv3m9cRbIf76X390XL54zfvwomjZtQb16DXB0rIqLS8c8I3ODBvVjzhxPSpUqnWOOpk3rcfz4HwAc\nOxbK5csXmDz5l3yz8fLlO/Tpk4lc/ozExN4YGFxn1qzLXL/+R67o1NcQHR3FxIljcHBw4ubNcMzM\niuPpuYS4uFiWLl1IQkI8WlpaTJ48jerVq7B3729s3uyHTJaJgUExZs6ci5GRMRs2rCUqKoKoqCjM\nzS2YOXNuvtn4PgsWBLNkSTtAWVNZqVJbDh1aj56efoGc75/w9u1bBg7cSlTUAbS1BzJzZmWqVi3D\n3Lkzadq0Oa1bt+O33w5w9uxp7txpwKNHNxGL04mOXkaHDstJSfmNwMC9XL16he3bt6ruc3DwQdav\n92HTpkD09fVZtmwh+/fv5eTJ81y6dIFVq5axaVMgu3btwNfXmyNHThITE82kSWNVipAXL/7JyZPH\nmDRpGnK5nClTJtCnT3+KFy9Bjx6d2LFjBxYWVp9aXr6hUCh4/Pgx6urqlC379eqdAn/P9/K3WeCf\nY2b2z//2CTVzAgIC3x3vK6kJCJQpUxY/v23Y2JTH13cN/v6+OUQupNKMHMfn9Q5TInmXqCIWi1TC\nF/nFlStPSUoqh6HhdgDevnXg/v3Yr5rzYzZGRLyga9fubNmyEz09fU6dOs7ChfMZN+5nNmzYwsiR\nP7FkyQIAHB2dWLduI35+22jWrCXbtm1WzfPs2TO8vLy/2pGLjX3N//63j+7dQ/DwOJTD7ocP1VA6\ncgpAQUzMWESib+PRZM6cE1y/bkF8vAs3bgxi1qzHGBgYcPv2DVXKZKtWbblx4xoxMXqAiOTk5oCI\nxMSyvHmjVM788OdNJBJRo0Yt9PWVD203boQjkSjTlJ2cqhMVFUm/fj
3Yvn0z6elpxMe/yTXHxYt/\ncunSBdzcejNoUF+eP39GRIRShKhECQscHBwK8Mq8IysriyFDdtCggQRn5zR+/nmP0CxcQKCQEdIs\nBQQEBAS+a+Li4tDX16dlyzbo6upx6NB+VU1PnTr1OHXqXU1PzZq12bNnJ2PGTACUaZbZD9X5TWDg\nVoKDDwJQtWpdSpYMRl39OWXLdiIrqwJ2dpU5fz6cX36ZzJMnj7C1rciMGR4A3L17J8/m2KNGDeWH\nH2y5fj2c5s1bUry4ORs3+iIWS9DT02PatFlYWJRSpefZ2toRHR3FzZvhTJ/+rpVHZqZSRfPVq5fM\nmDGFN29ek5mZScmSSoEdkUiEs3NDNDRyCt38HT4+qyhevARdurgAsGHDWnbvvkd0dCYSyVuePJES\nF3cPL68JREdH8ezZEkqUOI+W1l0iI9dhZvYLMtnuXNevfftOdO/e65MNv4OCAtm/fw8SiQRLSytm\nz57/Rfctm7g4bSCNbCGf2Fh9lfrohw6Lre0bwsIUKBTqQBqVKkk5f/7jTo22tnae+0NCDiOXK1iz\nZj1nzpxk+fLFuV5GZNO37wB+/LELAIcPHyIwcBsBAVtISkokMjKSiRMn5UgFLVHCnHnzZqGpqcWD\nB/eIj3/DlCnTCQ4+yN27t6lUyV6VRtmiRQM6d+7G+fPnMDExZfDgEfj4rOTVq5eMGTMBZ+eGSKVS\nhg8fw82biZQsGUxs7BS2bWuKufkKXr+OQCqVEhkZQcOGjRk5csw/uPICAgL/BMGZExAQEBD4rnn8\n+CGrV3shFotQU1Nn4kR30tPT+fXXOaxfr4eTU3VVJNfVdRBLly6gf/8eiMUSBg4cSsOGjXPVgH0t\nd+/e4fDhQ/j6bkIuVzB0qCuurl0JCnqMsXFfevXSwc5Om82b77F1axAmJqaMGDGI69evUamSPcuX\nL2LBgqUUK5azOXZ2PdX69Zv/Wk9Pli5djampKSkpybx9+/YD5UkJb9++QU9PH3//gFx2Llu2kF69\n+lG/fgOuXr2Cn9+7ejRNTa1/vO5mzVrg5bVE5cydOHGUmJheREV1RaHQQyx+Q3h4e0DpTKenx1On\njgkPHw6iXr2TZGYqa8byun5OTtVypV++H6Xftm0Tu3blbPj9NTg5wbFjtlhYbCE+fgCVKsWSmpqC\nvb0Dx46F0KpVW0JCDuPo6MTkya3o1+8wCoU2Tk5vmD69DW3aLAJAR0eX1NQU1bwfOoIODk48fvwY\ngHv37qCtrY2+vj7Pnj1RjdPR0SE19V1bktq16+Dr60PLlm2Ijo7C39+XxYtXoK6uzsSJY/Dw8KBt\n2w6qVNDlyxfj6bkYgOTkJNau9efs2VNMmTIBHx8/rKysGTy4Pw8fPqB8+Qqkp6dTvXotRo78ialT\nf2bDBh+8vLx58uQx8+bNxNm5IXv2BJGZCc+e/Ya6+mNKlx7E06cHSExM4+HD+2zcGICamjq9e3fF\nxaUnZmbFv/qeCAgI5EZw5gQEBAQEvmtq1apDrVp1cu3fvn1Prn3a2tpMmzYr1/6QkFOqz40bN6Nx\n42ZfZdP169do2LCJyiFq1KgpxYopsLTUY/PmVoBSJr5ixcqYmpoBUL78D8TERKOnp8eTJ48YOzZn\nc+xssvuBAVSp4si8eTNp2rSFSpXzQ3R1dSlZshQnThylSZPmKBQKHj16iJlZNVJTU1TnP3z4kGrM\nl6bKVahgS0JCPHFxccTHv0Ff34DixUVkZCxFW/syCoUYufytqnl3iRIW+PqOUo13cfFFoVDkef3C\nw6/i7Nwo1zk/1fD7axgzpgUKRSinTtWgWLEOaGgYsGrVbcaOnYSn52wCArZgZGTE1KkzMTAwoHbt\nctSvX5VGjZRiQNlOZvnyFVSCL23bKgVQ3n9hMHDgUPbuDcLVtRfq6uqYmprh6tpTpSYJUKyYIVWq\nONK/fw/q1KnPyJFjePr0KcOHu5GYmIBIJP5LFVOp1nnt2jVmzfoVUKaCenuvUNlUv76yDYKVlQ3G\nxiZYW9v8tW1NTEwU5ctXQF1dndq16/51XcujoaGBRCLB2tqG6Ghli4cbN8Lp3bszDx4c4MmTjmRm\nlqRixXU4Olqjq5uFjo4uAJaWVkRHRwnOnIBAASE4cwICAgIC/1liY18zd+5Z3r7Vpm5dNYYObfr3\ngz6DvKJ7eQX81NXfpTFKJGJVPZmVlc1HlRi1tN6l6E2c6M7t2zc5f/4cgwb1Y/78RXkqT86Y4cHi\nxb+yaZMfMpmM5s1bUrduNQYOHMr06ZPR1zegevUaql5syojXP142AE2aNOfkyaO8fv2a5s1b8urV\na3buvEZKygAqVUokOXmjKnVQWzvv6N+Ha8juffZ3Db+vXQvj3LkzbN7sx6ZNgapatC9BJBIxdmxL\nxo5tCUzJ8V1eypnZKYrZZL8gUFNTy3V8mzbvFB4NDAw4derC39rzYe2ii0tPXFx6snv3Dl6/fq1K\nkd20KZAOHVp81CFXV1dGbsVica7+gdk/f+/XkIpEyoj3h8cAmJubsmmTKVu37uTatVimT6/Kmzev\nckWH5XL5365PQEDgy/g2qowFBAQEBAQKGYVCwZAhx9i+vQ+//daV2bOr4e9/6u8HfgaOjlU5ffok\nUmk6aWlpnD59gipVquZIlfsYZcta5tkcOy8iIyOoVMmeQYOGYWhoiEgkVrVdAKWEvJvbECwsSrJk\nyQo2bgxg69adDBgwGABn50bs3LlfJYyyYoUPoIwW9ezZ94vW3rRpC44eDeHkyWM0adIcU1M9fvzR\nnosXW/DTT5a8evXyk+NFIlGu63fmzEkcHJwwMjL+7Ibf6elpX2R/UbBo0WFatTpKx46HCQm5+o/G\nlitXnoCAfdSseZDOnfcRFnYLJycnjh0LAVClguY3jo5VCQk5jJ2dJUOHVkJLK4Pq1avl6UQKoigC\nAgWHEJkTEBAQEPhPkpAQz61b5cnukZWZWYZLly7h5vb1c//wgx1t27ZnyBBXADp06IytrV2OVLm6\ndevnGf1SU1PDw2NBns2xP2TNGi8iIl6gUCioUaOWSvjkSzh37jaHDr1ASyuT8eMbfbEwjJWVNWlp\nqRQvXgJjY5OPNu+GvCKYyu28rl+FCj8AfHbDb11dvS+yv7AJCjrH8uX1yMxUtst4/vww1au/xsTE\n5LPGr137ghcvJmFs7MeLF2ImT9YkOHg5EydOypEKms3n1Ifmju7m/q5zZxcWL/bE1bUnEomEadNm\noaamlqfasKA+LCBQcAh95gQKHKEfyr8b4f7+e/m331uZTEaDBkd59Mglew/DhgXh4fHfaHL8/v09\nf/4OgwdnEBvrDMipV8+PoKAuqpS8r6Vbtw74+W3FwKAYLVo0IDT0TL7M+29g1qzfWbPG5b09L9m1\n6xYNG9b8rPEdO4by559dVNt2dnu5c6fzv/p397/Ov/1v83+ZL+kzJ0TmBAQEBAT+k6ipqTFzZgnm\nzw8kPl4PJ6eXuLt/H45cb
OxrAgMvoqkpxtW1MZqaml8136FDz4iNzXYoxJw/35CHDx9TsaLt1xuL\nMjKTlpaOh8cM0tPT6d+/B66ugylWrBhr1niRlZWFnV0lJk50R11dnW7dOtCiRWv+/PMcYrGESZOm\n4eOzkqioSHr16kenTl0BCAjYTGjoEaKiEjAwqMygQT1o3bp6vthcWDg4GKCp+RypVNlwu3TpK9jb\nV/rs8XZ2Kfz5ZwagAcixs3tbMIZ+BgqFgt9/P8+rV0l06FALY2OjIrNFQOC/guDMCQgICAj8Z2nd\n2olWraoik8nyLQpV0Lx69Zru3c9y+3YfIJPQUH8CAly+yn59/SxARvZjgZ7eS4yMSnzRXO7uE3n1\n6iUZGVJcXHrRsWNnAK5cuYipaXG0tLTZvHkHycnJ9O/fgxUrfChdugxz585k795ddO/eC5FIRIkS\n5vj7B7By5VLmz5+Fj48/UqmU/v170KlTVy5e/JMXL54TF9eJq1ddKVnyf0yYEIWamoTmzat+8bUo\nbLp0qcfz5yGEhl5BSyuT//2vJMbGn5diCTB3blvU1IJ4+FCLUqVS8fBo+feDCoiJE/cQENCKrCwz\n/P2D2LatJqVKfdnPkYCAwOchOHMCAgICAv9pRCLRd+PIAWzefPEvR04EaHDqVDeOHbtE69b1vnjO\nMWOacOWKP+fO1UVXN44RI15jbv5lDpG7+wwMDAyQStMZMsSVxo2VCqFWVtb4+/uSmZlBePg1dHR0\nKFmyFKVLlwGUCo979uyke/deAKo2BNbW5UlLS0NbWxttbW3U1dVJTk7m4sU/OX/+HDEx4ZQtewCx\nOI3k5BaEhr6kefMvvhRFglI188vGamhoMH9+x/w16At4/vw5QUH2ZGVZAHD7di98fALx8GhXxJYJ\nCPy7EZw5AQEBgSIkLS2NGTOmEBsbi1yehavrYEqVKs2qVctIS0ujWDFDpk2biYmJKZGRESxdupCE\nhHi0tLSYPHkaZctaFvUSBAoZZTsxOZAtuy9FQ+PLJfhB2ZQ6MNCFp0+fYmBQFjOzL09VDArazpkz\nSlXQV69e8eLFCwBKlSqNn982fvyxFb6+a6hePWdNWHb7gWyy5e3FYnEOZ1spjy8DoHv33nh4/MDr\n19neWybFiu36YtsFvpysrCyysnK+FJHLBeETAYGCRmhNICAgIFCEXLjwB6amxdm4MYDNm3dQp05d\nvLwWMW/eQjZs2EK7dh1Yt24NAAsXzmPcuJ9VMvJLliwoYusFioLBgxtQs+ZGIB14TYcOh2jc+PPE\nMj6FRCLBxsYGMzOzvz/4I4SFXebKlUusXevPxo0BVKjwAxkZUgBev37zV/NpNXr16sfNmzeIiYkm\nMjICgCNHgqlatVquOfPSaROJRNSuXYeTJ48xalQC5uaH0NcPpnHjVYwb93UN3wW+DEtLS9q3vwQo\na/ZsbPbRv3/lojVKQOA/gBCZExAQEChCbGwqsHq1F97eK6lXrwH6+no8fvyIsWNHAiCXyzExMSMt\nLY0bN64zffpk1djMTFlRmS1QhOjp6REU1J4DB35HT0+Ttm17IBZ//N1sXjVsLVo0oHfv3hw/fgIT\nE1MGDx6Bj89KXr16yZgxE3B2bohUKmXJkl+5d+8OEomEUaPGUa1aDYKDD3L27GmkUimRkRE0bNiY\nkSPHAHDq1HHu37/HqFFDKVHCnPDwdz3Tnj59zKxZU0lPT2PjxvVMnOhOcnIS06dPJisri4oVK9Op\nU7e/js4pn59T2l75uWbNOjx9+pRDh/xxdJSjqanF7Nnz0dbWJjk5mdDQ3+ncWTlfWNhlAgO3sXDh\nsvy5CQK5EIlEeHu70KDBceLjM+jUyYkyZcyL2iwBgX89gjMnICAgUISUKVMWP79tnD9/Fl/fNVSr\nVgMrKxt8fPxyHJeSkoy+vj7+/gFFZKnAt4SOjg49e346AhUUFMj+/buxtrZhw4YtOWrY0tPTqVu3\nLm5uI5g69Wc2bPDBy8ubJ08eM2/eTJydG7JnTxBisbIJ+fPnTxk3bhTbt+8B4OHD+2zcGICamjq9\ne3fFxaUnIpGIc+fOYG9fhdjYV4SFXcbExPQva0RUr16DJk2a0bJlI3x9N6ns9PPbloft+1Wf27Rp\nT5s27fP8zsWlJy4uPXONT0p6y969QSpn7mvJyspCIvm6VNb85syZk5QpU07Vay8/+Nq2EWKxmL59\nhciogEBhIjhzAgICAkVIXFwc+vr6tGzZBl1dPfbt20VCQgI3b97A3r4KMpmMFy+eY2VlTcmSJTlx\n4ihNmjRHoVDw6NHDr2oSLfDvZt++XXh5ebN//x4GDOgNvKthU1dXp0GDBsTGJmFjU/6v9EcJ1tY2\nREdHA3DjRjjduvUAoGxZS8zNLXjx4jkikYjq1Wuho6MLgKWlFdHRUSQkJODkVJ1p02YBsGtXIC9e\nPMfJqXoOBywk5NRXr00mk7F//1mkUhldujizb98ugoMPAtC+fSdu3bpBZGQEbm69qVmzNnXrOpOW\nlsovv0zmyZNH2NpWZMYMDwDu3r2TZ43qqFFD+eEHW65fD6dFi1b06NHnq+3OT06fPkn9+g3+kTMn\nk8lQU/vUo59Q4yYg8L0hOHMCAgICRcjjxw9ZvdoLsViEmpo6Eye6IxaL8fJaTHJyMllZMnr06I2V\nlTUzZsxl8eJf2bTJD5lMRvPmLQVnTgCAwMCtOZyZ58+fEhUVyciRgwARW7bsRFNTk9Gjh5GRIUUi\neffvXyRS/uxBtrhI1t+eL1ucRDlGQlZWFqIP/ACZLIuwsGdMmHCEOnUMcHGp+/ULRRklGzBgJyEh\nvQF1tm1bhLHxH6xfvxm5XMHQoa7MmOHBkyePVJHssLDLPHhwj61bgzAxMWXEiEFcv36NSpXsWb58\nEQsWLKVYMUOOHQth3bo1uLvPQCQSIZPJWL9+c77Y/XdER0cxceIYHBycuHkzHDOz4nh6LuHIkWAO\nHtxLZqaM0qVLM336HO7fv8e5c2e4du0qmzf74eGxAE/POYwaNQ47u4q8efMGF5euBAUdIDj4IKdO\nHSc9PR25XM7ChcuZMmUCSUlvycqSMWTICJVyqICAwPeH4MwJCAgIFCG1atWhVq06ufavWrUux3Zc\nXBxZWVksXuz1Qf2QQF6MGjWU0aPHY2trR7duHfDz24qBQbGiNqtAuHv3DocPH8LXd1MOZ+bChfMM\nHjyCY8dC0NTU5OnTJ9y6dfOz53V0rEpIyGGqVavB8+fPePkyhnLlLLl3706uY0UiERUrVmbFiqUk\nJSWhra2Nn18QkZENiI3tRlDQY5KTT+Hm9vVOw8GD5wgJ6QnoA/DgQWmaNi2HpqYWAI0aNeXatau5\nxlWsWBlTU6W4S/nyPxATE42enh5PnuSuUc2mWbPC7dkWEfGC2bM9mTx5GjNmuHPq1HEaN26q6tXn\n6+vNoUP76dq1B87ODalfvwGNGilbP+SuLXzHgwf32bQpEH19fbKysvD0XISOji4JCQ
kMH+4mOHMC\nAt8xgjMnICAg8I0zb95vbNxojlRqQNOmO/D17fpd9UUrCt5/qM0v51cul39SaKSouH79Gg0bNsnT\nmalWrQYhIYfp29eFMmXKYW9fBch9Td7fzP6uc2cXFi/2xNW1JxKJhGnTZqGmpvZRp8HU1Ix+/dwY\nMsQVAwMDEhNLIJcrHej0dGtOnLiKm9vXr1cmyyLn44sYuTyn4mVet1xdXUP1WSJ5F4HMq0Y1Gy0t\n7a819x9hYVFKFW23tbUjOjqKR48e4uvrTUpKMqmpadSu/S7CmZfSZ17UrFkbfX191Rgfn1WEh19D\nLBYRFxdLfPwbjIyM839BAgICBY7gzAkICAh8w9y8eZ+1ayuRnu4AQHBwJdatO8j//te6iC0rHAIC\nNqOhoUG3bj1ZsWIJjx49xMvLmytXLvHbbwdo06YdGzasIyMjg1KlSjN16ky0tXM/gO/atQMDA4N/\nPE+3bh1o1qwlly5doE+f/ujrG+Dn9/fnK0zycqyyd2loaLB48Ypc379ftzZw4NA8v9PQ0GDq1Jm5\nxn4oSPK+QmSLFq3p2LEzmZmZNGkygPT0Kqrv9PTSP3NFn6Zjx/oEBGzj7Fk3QIK5eRCJiWlIpenI\n5QpOnz7BtGmzCQzMLawCykjmtWtXsbOrxK1bN3ny5HGeNapFQe70VSnz58/h11+XYGNTnsOHD3H1\n6hXVMe/fe4lEgkIhByAjIyPHvFpaWqrPISGHSUxMwM9vKxKJhLZtmzJixCAqV7YH4KefRpKUlEif\nPgO4fPkCPXr0+Whd3tmzp3n69DF9+w746JqCgw9y794dxo2b9PkXQkBA4LP59l4xCggICAioiI5+\nQ3p6qff2aJGYWDDnGjFiYMFM/BFiYqIJDf39k8c4OlYjPPwaoHwIT0tLQyaTER5+FRub8ixa9CsL\nFizDz28rtrZ27NiR9wO8vb3DJ+fZtMmP5cvX5JpHJBJRrJghfn5bqV69Fps3++Hllfu4osTRsSqn\nT59EKk0nLS2N06dP4OjoVGjnT0tLY8KEfXTpEkrfvj/j6tqTAQN6Ua1aSYyMUlBTu07VqluYNKlG\nvpxPQ0ODgIDOzJ27n5kzd3HwoA9durgwZIgrw4YNoEOHztja2lGliiP9+/dgzZoVf0UTlePt7CpS\ntary+qipqVGnTj18fFYyYEBv3Nx6c+vW9Y+eWyYr/HYgaWmpGBubIJPJOHIkWLVfR0eHlJQU1baF\nRUnu3r0NwO+/f/z3KiUlBSMjYyQSCWFhl3n79i0zZ85l+nQP5HI5IpFSYbRZsxZMnvzLJwVWnJ0b\nftKRg/yLjAsICOSNEJkTEBAQ+IapV68Kjo6HCA/vD4iwsDhOu3Y2+Tb/+/23vL3zTjUrCGQyGVFR\nkYSGHqFFi49HGW1t7bh37w6pqSloaGhgZ1eRu3fvcP36NZydGxITE8Xo0UORSCRkZsqoUsUhz3nK\nl6/wyXmePn2scmY/nKdZsxYA3Lp1g6dPHzN8eN7HFRU//GBH27btGTLEFYAOHTpToYJtoZ1/8uTD\nBAb2QflI8SOdO29h7dquACQnJxMXF0upUu3yJTU4LS2NGTOmEBsbi1yehavrYKZN+5nRo8fTo0cf\nWrRoQGzsS/r1646JiSmTJv2Cj89KTp48xpgxEwDlz3x0dDTjxk0iOPgghoaGzJ49n7NnT7N5sx97\n9gRx9GgIc+Z4YmRkzIYNa4mKiiAqKgpzcwtmzpz71ev4GHk5PoMHD2Po0AEYGhpSubI9qampgLKe\nb8GCeezatYO5cxfQq1dfpk9358CBvTRr1pRsZcr302IDA7dy8OA+YmKiOXHiGLq6SkVSD48ZtGvX\nkYwMKXfv3mbgwD65RFX+/PMP1q1bg1wux9DQkOXL1+SIumVfP5ksEwODYsycOVdI3RQQKAQEZ05A\nQEDgG0ZXV5etWxuyYkUgmZnquLiUw9GxYBQss3tMhYVdxs9vHSYmRty5c5cmTZpjZWXN7t07yMjI\nYP78xZQqVZp582ahoaHBvXt3SUlJZvTo8dSr5/zJZtPZqnpZWVlkZmby7NkT3Nx606ZNBxo2bIyH\nxwwN72x0AAAgAElEQVTS0tIAGD9+Evb2Dujp6ePm1gd1dXWePXtCWNhlUlJSePz4EQqFApFIhIFB\nMby8vD+6NjU1NSwsShEcfJAqVRyxsSlPWNglIiMjsLAoRY0atZk1a16eY99Po/zUcUVJjx596NGj\nDzKZjPj4eORyOUFBBwrl3A8e6PPucULCgwfvhGb09PTQ09NTbW/YsBYdHV1SU1NwdHSiRo1ahIdf\nZdEiTzQ01PH29mP9eh/+/PMcdes6q5qRZ3Phwh+YmhZn0SIvQNl/cd++Xarv09PTqV69FiNH/vTR\n/nkAWVky9u7dhaampmqso6MT7u7zWL78KlFR4cyZs4BlyxYA8OzZM9asWY+GhgYFhYVFSTZtClRt\n9+rVV/X5XTP1d1Sp4sjWrTtz7Nu0aTsAZmb69OkzCHiXFpstlOPntzWHUM7UqT/j4+OHgUExKlWy\nZ/v2rarU2WxHMD4+noUL57FmzXrMzS1ISkpSfZ+No6MT69ZtBODgwX1s27aZUaPGfnZdX0GQ/fP2\n/rV8n4Lo1ScgUNgIzpyAgIBAEfO5kuQeHnPQ1NRi3rxZaGpq8eDBPeLj3zBlynSCgw9y9+5tKlWy\nV9U5Xbz4Z571XX/++QcrVy5FU1MLB4eq71ny7sHs4cMHrF79OxkZYlxcOtKhQ6f/s3eWAVFlbxx+\nhu6yQLBAKSXEVuze1V1XxVpFReVv69qFhaBgIq4oJuiCK3Z3d6CirphgEAoiHQIz/w8jIwgoKpj3\n+TQz99xzzo2B+877nt+PVav8CQraxNat/8qyHM+fR7N6tT/Pnj1l5MjBbNq0/b1m07lV9a5du5rn\nwTEjI53Fi/9GSUmJp0+fMGvWNFav9sfEpCoHDuxl1ix3bG3t6Ny5AzVqWOHsPIw9e3YxZcoMzMws\nSEtLIzY2hgoVKhZ4nm1sbAkM3MiUKTMwNjZh6dJFWFhYUr26FYsWeRAR8QxDQ6NC+7G0rFGkdl+L\n8+dDmTjxAc+eVcLE5Aze3jUxN69c4uMaGCQDEnLuH+n7gsl5+B8w4H+yzw4d2o+jY3/atGkPwO7d\n29m//3iBWSoTk2r8/bcXPj7eNGzYGBsb2zzBgqKiokwgpDD/PJBmhrdvD6JHj7cP+U+ehDN4sCvp\n6QqIRJmEhWlx4kQIIpEIe/sm+QK5qKhIJk78C3//f4t4pr4uuYVyrl69R1xcBUaODEQsTpO1KSjw\nkkgk3L59E1tbO/T1DQBkYiq5efHiOdOnTyIu7iWZmZmUL2+Yr82X5kMlnp/i1Scg8K0hBHMCAgIC\n3wAfI0kuEolITk5i5cp1nDlzkkmTxrJixVqqVDFm4EBH7t+/R5kyZWXru5SVVdi4cT3//vsPPXv2\nwdPTDW/vlRgaGjF9+uQClf8sLCwpXbo0M
TFJGBlVkD0gGxubEBx8BZA+KLVoIS1BNDKqQPnyhjx+\nHP5es+natevmUdXLTWZmFosXe/DgwX3k5OR49uwpIH2AB2jUqDHKyiooKytTrlw5dHR00NHRwcNj\nDtnZUuEHZ+eh7wnmarJhwzpq1LCS9WNjUxMdHR2mTp3JzJlTeP06s9B+dHV1i9TuazF37gNCQ3sC\ncONGQ9zd/8Hfv3KJj+vu3pj09A2EhWlRsWIi7u55rTb8/NZw4MBedHX1KFu2HGZmFri7z6JhQ3uS\nk5M4fvwoly5d5MKFc6SmppCWloaT05/07t0fO7vaLFw4l+fPowEYOXIsa9f+w9y5s5k2bQKKiopk\nZmaSlJTEtGkTyMrKYtAgR0aOHItIJCI4+ApPnz4hKiqSlJRktmzZhLFxVR4/DiMpKYk1a1agra0D\ngIfHXJ49G0BKSldUVS9RqpQ3x49HUqkSMqXQ75mcwCY5OZmRI8OJi6tOdrYeenqX2bnzEn36tP7g\nvu9j8WJPevbsQ6NGjbl27Spr1/p+cJ+SoKD7bffuHezate2DXn1Xr17O5+n3I1x7gR8bIZgTEBAQ\n+Ab4WEnyRo0aA1JZdT29Uhgbm7x5b0x0dCQvXjwvcH3XkyePKV/eEENDIwDatGnPrl3b880nt4y7\nSCSSvReJRO81lf7QQ9/7lB///fcfSpUqjYuLK9nZ2bRo0RAAU1MzGjSwR0lJGYlEQrt2v2BubgmA\nsrIyS5Ysz+ch5+29UvY6p9ywVq06HD9+XvZ5TrYQpBL+q1blN4d+t1SxsHbfAvHxKu+8/zIqm/r6\npQkI+KPAbaGhdzh27DDr1weSnZ2Fk1NvzMwsAOm90qFDJ0JCbuTxS2vduonM7HvmzKl069YLa2tb\noqOjGT16CH5+mzAxqcqjRw+oVKkKqakpBAT406/fAC5evICrqyfjxo2Q/dDw9OkTvL1X0qqVPUuX\nLqJy5SpkZ4tRV9egQ4dObN8eRN++PYmOjkBV9RUpKaCjsw5l5dtcvfqE+/dVaNOmnex45s6djUgk\nom7deiV9aosVGxtb3NxmYW1dlwcPLKlYcT3R0Z5IJOsICUkqdD+RSET16lYsXDiPqKhIDAzKk5iY\ngJaWdp4fZFJTU2Q+fvv37ynx4ymIgu43c3MLmjZtTseOnYD3e/VpamoW+AOagMC3jBDMCQgI/NTk\nrBOLjY1hyZIFzJnjUaT2xc3HSpLniEnIycm9s6/UP0tOTr7A9V337997Z+RPX88ikUg4fvwI7dt3\nIDIygsjICCpVqpzHbNrHx5vbt28yc+YUTEyq8fTpY7ZtC6JzZwfU1TV49OgBgYEb6dmzN8HBl4mK\niuLixQsYGBggFouJiorE1XU6AI6O3fPJ7Oco+mlpaRMcfI9Fi+6TlqZEixYwbFjhmYZPOdbNm0/w\n8mU6v/1mh5FRuWLru7ioUyeR0NBUQA2RKI569TK+9pQICbn2prRPGVCmUaMmBbYrbF3VlSuXePw4\nTPY+KSmJQYMciY+PR15env79B/H330v4779bLF7sSUZGOpMnjyE1NZWsrCxEIhENG9pz//5dsrPF\nVKxYkV69+jJ/vhvq6hrs3r0DKytb3Nw8mTFjMqdOLUNHJxBIQE1NnZ07dzJq1BAuX75Iv34DmTt3\nFmPGTMLGxpbly71K4IyVHDlCOUuWuGFikk5s7AAyMiwAMZUrS/+G5Fb9zI2Ojg4TJkxl6tTxiMUS\n9PT0WLRoWR5xFScnZ1xcJqKpqUWtWrWJjo7K1eeXUbQs6H6TSCiyV1/+dvULGEVA4NtCCOYEBAR+\ncqQPGaVLl/lgIJe7/ZfgXUnysmWLFkDk/JJe0PquSpUqExUVKfv88OGDefZ7+7rwvnO2iUQiypXT\nZ9CgvqSkJDN+/GQUFRVlZtPdu/9BTMwLPD0XY2VlS8+enbG1rcmxY4fp3NkBE5OqxMe/Yu/eXURE\nPKVUqTIkJ6cgEkFY2COZOEVsbAxWVjYsX74633x+++0Pxo4dgZ5eKa5d68L9+z0AuHTpMWXLnsPB\noWGRztn7kEgkjBgRRFBQZyQSHTZs2Mn69emYmVX67L6LEw+P3yhXbg+PH8tjYSFi2LBfv/aU+Pzv\niwRfX798Sphr1/qiqqqGubkF3t4r6dChVaHtFBQUuXnzBv37D+TEiaPY2trRunU7zp49hUgkws3N\nEwBn52E8ffqUJUv+pl+/XmzbtheAyZOn4+IyieTkZJKTk7Gxka4zbdv2Vy5cOPeZx/dlyRHK2b8/\nmCVLokhL206TJgMZOlTqG1izZi1q1qwla587w12/fkPq18/7fcrtOWhv3xR7+6b5xnzXl7BkKfh+\nK6pXn7v7LObNW1RgOwGBbxXBZ05AQEAAqZiBo6O0nGbfvt1MmTKesWNH0qNHZ5Yvz2+6HB8fz+DB\nTpw/f5bY2FiGDRtE//69cHTsLvMz+xjeJ0k+ZMiAfAv08wZeeff181uLoqICU6fOZMqU8bRo0ZDB\ng51Ys2YFW7duZsKEqUyYMBonp97o6ZWSBWc5ZtF2drXx8HhrBO3tvRIzM3NA+rCXe1udOvVYvdqf\n2NgYGjSwB96aTXfp0o0//3Skbt0GqKqq0qHD71Svbk18/CtiY2MJC3tEtWpmbNy4GRUVVe7cuY2c\nnOhNwCgnMxnW1y+fJ5D7668JsofDLl26ExCwFWfnUdy/bydrk5FRiWvXEot49t9PZGQEu3bZIJHo\nAiIePuyEn9/tYum7OFFQUGDChPZ4eDTE0DBBtmZswoS/vtqcbG1rvvHAyyA1NYWzZ99mtYuiclin\nTn2Cgt4qPObPLBe1nSifUEpmZmaBfcnLy5OVlc0ff+yiUaMjTJ16ALFYnK/d11Rp/Fzat7fj4MFf\nOXWqFXPm/F5o5qwo3pM3blyjd+9uODn9SVzcS/76awfduh1i2rRd+czLi4tr165y69ZbP8AdO7Zy\n4MDeQu+3tLSUInn1FebpJyDwLSNk5gQEBAQK4MGDe6xfH4CCgiK9enXBwaEHZcqUBeDVqzgmThyD\ns/NQateuS2DgRurVa4CjoxMSiUQmrV9UPlaSPEetsqB9c2+zs6vNvHkLmTjxL/z8AmWCBPXqNeCf\nf97KuRfG+9bG5Sf/w+C7D4hSGwFo3rwVJ04c4eXLl7Rq1Ua2vXfvfvz+e+c8+0RFRaKqqiLbf/bs\nPZw9q4aGRhrjxpnQsKF0/VWlSoYYGt4gIsL4zdhxVKnyab5mmzZtZN++3QB06NAJMzNzDAxmkJJi\nj6rqNbKyyiEWN/ukvr8ESUmJbN8exB9/5L93vjSmpua0bNmafv16oqurh6Vlddm2wn6QyP169Ohx\nLFrkQd++PcnOzsbW1o5x4ya9aUeR21lb2+Dp6S77fl65cgkDA0MePw5j5sypzJzpxoEDe6lZsxbq\n6hrEx8sTHm5Genpt4uK8sLLSe2OzoElIyHWsraWlxJ9LUNAmdu7cipmZOS4urp/dX3FTFO/J3Gqk\n
AwduZdcuR0COEycyyMoKYt683/O0z8rKQkHh8x4/g4OvoKamTo0aUp/HTp26yLa9e7+JRDBw4OAP\nevW5us4rtJ2AwLeMEMwJCAgIFECtWnVRU5Ma6lauXIXo6CjKlClLVlYmo0YNYezYSdjY1ATA0rI6\nc+fOJisri8aNm1GtmmmRxti/fw+bNv2DSCSiatVqDBw4GHf3WSQkJKCjo8uUKdMpV04fN7eZqKtr\ncPfuf7x8+ZKhQ0fSrFlLYmNjmTFjMqmpKWRnZzNu3GSsrW3p2rUja9duREtLm61bN/P06ROGDh2I\nRCLHvXv3uXjxLhDH/ft3SU1NRUdHBy+v5VSsWJlhwwYRFvaQjIwMTE1NmTbNlalTJ/DyZSwVK1bC\n2tqWo0cPsWTJcu7fv8exY4e5ezeU9PQ01qxZydmzp8nOzsLVdZ5McKF3776IxRJOnz6Bi4srCgoK\neHjMISEhnr//XgVAvXr1WbVqBW3atEdVVZWYmBcoKOQNxnx9j7F8eVskklIAPH++hSNHKqOqqoq2\ntg7u7posWbKJ1FRlmjRJZuDA39895R8kx4tr1So/mRdXzZquKClFExXVnhcvZmFq6oCVVcqHO/tK\nrFjhTUTEM/r374WCggIqKqpMmzaRsLCHmJlZMH26NGg4f/487u5zyc7OxtzcknHjpGWyPj7enD17\nGnl5eerWrc+wYaN49epVPlVJKyubIs3H0dEJR8fCMzy5f4CAtxliAG1tHWbNmptvHycn5zzvi9LO\n3r4Jhw8fYMECd0xMqlK/fkPMzCxZsGAuffv2lNl3iMVikpJ6UqbMfOTk0nj9uiKKim1kc5UKoEiz\ngZ+7FmzHji14efnIhEOgeIKd4uJd70kdHd0899Hu3TtkaqQXL54jNLQppUvPR139DADXr0tLNoOD\nr7B69Qq0tLR4/DicCROmsmbNSjQ1NXn48EGhXpYFGZGnp6eza9c25OTkOXRoH6NHT+DKlYsyP7kG\nDRpx5swpUlJSSE9P59dff0dTU5MjRw5RvboVwcFXSE5O4saN69jY2Mq8+pKSErGwsMbff9N7hZoE\nBL41vo2/FgICAgLfGPkFSaRZKgUFBczNLblw4ZwsmLOxqcnff6/i3LkzuLvPpHv3P2nX7v3rlR49\neoi//1pWrlyHlpY2iYmJzJkzg19+6Ui7dr+yd+8ulixZwNy5CwCIi3uJj89awsPDmDRpDM2ateTw\n4QOyjKBYLCY9PR14m9kIDb3DuXOnMTQ0olmzznh5zebVK0ceP66AgcEc1qzxo3JlY7p27cjcua64\nukqNvq2ta+LpuZhJk0YzZco4xo+fgpfXAoYOHcXEiX9hYFCe+fPdsbCwpH79hrIHUR0dXdau3cj2\n7VsIDNzIxInT+OWXDgwa1BeAjh3/kAW6aWmplC1bDj09aWBWp059wsPDGTy4PyAtf3Jxcc0jnnD/\nfrYskAMICzMnOjqKKlWk2bj27e1o3/6TLreM3F5cAE2btuDGjWsYGhoxcWIi0dFbyc6uR1ZW+ucN\nVIIMGTKSsLBHrFsXwLVrV5k8eSwbNwZRqlRphgwZwM2bNzA1NWfy5MksXrwcI6MKzJkzg+3bt9Cu\n3S+cPn2CgICtgNSUG8DLa0EeVclx40awcWPQ1zzMj+b33zsTGlqOiAh5wsN96d27P9WqmbJy5bp8\nbY2NVTl9ehPSjHMKDg5SVVMzM3PWrw+QtXvX1PxjmD/fncjICMaOHcHz59E0atSEyMgI9PUNGDVq\nHAsWuOcLntPS0li82JOwsEdv1BqdC1ynVnzk9p68l+c+Cgm5TseOnbh5860a6cWL80hJieLx413I\ny79EQ6M9L19Kv//3799lw4bN6OsbEBx8hQcP7hMQsAVNTa1CvSwLMyL//fcuqKmpyXwCr169JMvU\nzpkzgzFjJmJjU5M1a1aybp2vzKpCLBazapUf58+fZd06X5YsWQ7A/v3BTJ2awLNnplhYHOfvv82o\nUcOkBM+rgEDxIQRzAgICAh+FiMmTpzNt2gT++cePP//sS3R0NGXKlKFjx068fv2a+/fvfjCYCw6+\nTIsWrWWS+lpaWvz3301Z8Na27S/4+EjX6olEIho3lj6wVa5chbi4OOD9GUGJREJIyDXq1WvAkSOH\n8PX1JjGxI1lZZRGJ5BCLJbi6zkBBQZ709HRiY2O4c+c2pUqVpk2bdigoKNCmTRvc3NxYuHAejx+H\n4+npRlpaKq1bt2PVKh/KldPHxqYmVlbWLF7sKZP3NjU15+TJY8BbwYV3yV0amoODQw8cHHoU2tbM\nTAE5uVjE4tIAGBvfwcCg2XvP88dSWKZFSUmRDh0aARAYuJG0tG+3/Cr3Wi6JRIKFRXVZwF21qilR\nUZGoqKhiZGSEkVEFQCpSsW3bZrp06YaSkjJz586mYcPGMguMd1UlU1NTSU9PR0Xl+/Hg6t17DDEx\nYkSi1yQmdiYo6CEuLhYFtvX2bszMmf8QG6uOlVUakya1JzLyBfv2BVOhgi5t236+LcH48VO4dOkC\n3t4r2bLlX86dO8Py5atRUlLKZ8mQEzz7+6+ldu26TJkyg6SkJJyd+1K7dr0vch3evY+io6Oxts7b\npk6ddLKz9VBU3ImxcTwmJnW4c+c/1NXVsbCoLjMdl/ZnKfsxpzAvy/cZkRe0ZDElJUekRvpDW7t2\nv+LiMkm2vWnT5oA0KM9R2wRYvDiKZ8+kf3vu3DFnwYJA1q8XgjmB7wMhmBMQEPipKWitzvuktHO2\nzZzpzsSJY1BTU0dFRYXAwA0oKCigpqbOtGmzijRuQQIKhYkq5Fbpy2nz4YygCIkENDQ0EIuVUFB4\nTkaG2ZttCvj4rEZDQ5MJE/6iZ8/esixMzoOhWCxGUVGJdesCmDfPlapVqxEaegdra1uys7O4dSuE\nkSPHyOaTk82Ul5f7yPV2+fH03M/u3QrIy2fj5KSOo2NjBg5szvPnezlzRhlNzdeMG1e12B9i3y0N\nPXXqOC4uswv04vteyO0ZmHNtClrPKN0uz6pVfly5cokTJ46ybdtmvLx8KExV8nvi1au+PHnSSfY+\nJGRroW3Lly+Lr+/bMt07d8JwcnrEw4ddUFSMwslpF66uvxXLvHLOvb19E5SUpNeqoOA5LS2NS5cu\ncPbsKQIDNwCQmZnJixfRVKxYuVjm8j7y30dZ+dro6WkxfHhVfv21JQCurjdk95qKSt7SxaJ4WRa3\nEXnOGLmrLQCSk/P+HUlJUUJA4HtBCOYEBAR+anLW5+QWEnlXStvTc3G+9oqKiixa5C37/GOlt+3s\n6jBlyjh69PjzTZllAjVqWHP06CHatv2FQ4f2y35dLoy8GcGMPBlBkUiErW1Ndu7ciry8PG5u0xk+\n3InsbG0yM01QUVHi8uWLNG/eColEQmRkBPXrNyQu7qVsDd6xY8coV64cx48feVPutJyOHTthamqG\nvLw8GRkZqKmps337h8VUPoZdu87j7V2fjIyKALi6XqJ27
QdYWlZl2rSSlTjP8eLKXRqqqamVL/j5\nUr5Zn4KamtoHhRsqVqxERESEzKLi4MF91KxZi7S0NNLT02jQoBFWVjZ07y4NaHLUInv16gPA0KED\nmT9/CRIJHD58QCa2Ehx8hU2b/snzncnBw2MO3bv/mU+Z9UtRunQKjx7lvJNQunTRs6urV4fy8GE3\nADIzDdm8WZ/x4xPymdV/DjmlvTnzKyx4dnObT4UKFYtt3OIgJyC1tq7Jzp3baN++AwkJCdy4cY3h\nw0cTFvboAz0UTGFG5O+qUErnAOrqGmhqasnWw+WI2nyIRo0SePAgAdBGWfkJzZsLYu8C3w9CMCcg\nICDwidy9G8a1aw9p2NCSihXLf9S+VaoY4+joxPDhzsjJyWNqasbo0ROYO3cWAQEb0NXVzSMMUVAG\n8dq1K+/NCJqamtOwoT1btvyLr+8iGjduxPXr52nXrgzJyc3Ys2cXfn5riYh4RqlSpfj1198wM7Ng\nzRpfNm36h1atWvLXX5NYsGAe0dGRxMS8IDk5GTk5OUxNzXj27Bl9+/YoYM3O55kE37uXIAvkABIS\nbLl+fQ+WllU/uc+PIXdpqEQi4fLlEAYNGicTpsitNvotoq2tg5WVDY6O3VFWVpaVsuVGSUkJd3d3\nXFwmkp2djYVFdTp16kp8fDyTJ499IykvYcSIMUDBapHq6hpERUUWWTlz4sRpxX2oH8Xs2RZMmbKR\nqChNTE1fMXNm0deavZswF4vlCrQrKC7eDZ7v379HtWqm1K1bny1bNslsO+7dC8XU1LzE5lEU78nc\n7Zo2bc7t2yH069cTkUjE0KGj0NXVIzw8LM/+hZmTv7utMCPyRo2aMG3aRM6ePcWoUePzzG/q1Jks\nWDCX9PR0mahNISPJXnl4dKJy5SM8fizBzk6DHj1aFX6wAgLfGCLJN2KUEhOT9LWnIFBClCmjKVzf\nH5if9foGBp5l1ixt4uLsKF/+DAsXKtGype3XnlaxUti1zcrKIj09DQ0NTeDtr/LFla06c+YW/fur\nkpAgPZ+GhofZtasSFSoYfGDP4kUikTB8eBBbtzZHLFahRYvd+Pt3lZXCfe+877sbEOCPkpISXbv2\nYOnShTx8+AAvLx+uXr3Mnj07uXUrhNWr/Vm0yIMzZ05RsWIl6tSpR4MG9qxd64u2tk4+9czhw50Z\nMWIMZmbmtG7dGAeHnpw7dwZlZWXmzVuIrq7eFznu7Oxs5OXlP2qf69fvM3BgFE+e/IqcXAy9e+9l\nwYIuH97xAzg4/M7q1X5s3bo5j6BHQkI8ixZ5EB4ensdqISMjg6VLF3LrVghisZjy5Q3z+D7m8LP+\nXf5ZEK7vj0uZMpofvY+QmRMQEBD4BNasSSIurh0AkZEt8fXd/MMFcwURGHiOhQuTSUrSoU6dhxgZ\nyXP0aCkUFTMZNEiV/v0/X1nP3r4Gc+acZevWLSgoiHF2NvzigRzAoUMX2LLlVyQSfQCOHevH+vW7\ncXZu+8Xn8qWxsbFj06aN3Lp1k5Mnr5CRocavv+6gQYNQbG3tuHUrBJFIlEc5E6Rllvfv382nnmll\nZZMn2E9PT6dGDWucnYeyfPlSdu3aTt++A77IsX1sIAdga1uNzZtV2b9/M/r66nTu3PnDOxWBoKCd\nQNGsFtLT03n27ClDhoyQ/ZAi8GHOn7/DiROPMTJSoXfvpt90ibSAwKcgBHMCAgICn0BWVt4HwszM\nj39A/N5ITk7CwyOTyEhpRuLQoSbAv4BUVMLd/QL29mFUq/b5a6K6d29E9+6f3c1nkZCQhkSSe02U\nEqmp30QxS4ljZmbO3bt3yMjQJTm5FKmpDXj2zIqkpH/o3bs7GzeuBwoW7MmvnhmVz5NOUVGRhg3t\n34xlwZUrF0v2gIoBY2Mjhg0z+ipj3779iOHD/+POHRsMDa/g6qrFL798eC3Yz87evZcZM0aDV68c\nEIlecvPmdjw9iycQFxD4VhBWeAoICAh8Ar/9JkZZ+TEAmpq36dz5+5Fo/1Ti4+OJjc39MKsIaMne\nJSRYEhr67IvPq6To0KE+tWoFAtK1Uaamm+nWze7rTuo9REVF0qtXF9zdZ9GzZ2dmzZrGpUsXGDzY\niR49OnPnzm3WrFlJYOBG2T59+nQjOlrqZbZ//x769u1Jv369mDfPFQMDQ5KS4hGLVVBTO0X58s5k\nZcV8UHyjKKqH8vJvf0uWkxN9tvrpj87Chbe5fbsnYrElT5/+xqJF0V97St8FO3a84tWrugBIJKU4\neFCXrKz896OAwPeMkJkTEBAQ+ATGjGlLtWoXCA29RJ06BjRr1uRrT6nEMTAoT82a27h40QYQoaJy\nCzm5RHKEE6tUOUGDBtbv7eN7Qk1NjX//bcvKlUFkZYno3duO8uXLfu1pvZeIiGfMmePJ5MnTGTjQ\nkaNHD7FixVrOnDmJv/+6PF6E8Had47sm9klJSQQFBXL9+jUkkrI8e/YPlSr9jpxcfJ4yxaIoZwp8\nPikpynnevyulL1AwCgp5AzclpUzk5IQ8hsCPhRDMCQgICHwiHTvWp2PHrz2LL4e8vDxr17bAwyOQ\nlBRlWrTQRF6+LDt2bEVRMYthw6pSunR+5cTvGS0tLcaP/+VrT6PIGBgYYmwsNTuuUsWY2rXrvvOR\nnUEAACAASURBVHltQnR0ZL5gTookn4m9pqYmNjY1EYtXU69eJV6+PEFsbHY+Vcfcypn16zeiQYNG\n71U9zKEgdVaBwmnWTI5z5568UXlNolGj+K89pe+CoUNNuX59Ow8ftkRT8x4DBigKwZzAD4cQzAkI\nCAgIFJkyZUqxYEHeCLaYtCAEioEc43YAOTk5mU+ZnJycTMVRInkrqS+1ICjYxL5WrTq0b/8rDRvW\np1mzlkAbWreWZqCDgnbJ2s2YMSfPfrl9vXIk9AG8vVfKXuf4NQI0a9byTf8ChTFkSCt0dc9w9eol\nKlYUMXRopw/vJICVVVX27SvFmTNnqFatPObmzb/2lAQEih3h5wkBAQEBAYH3MH78KFJSkklOTs5j\nkB4cfIUJE/76ijP7eAwMynP3bigAt2/fJioqEhBhZ1eH48ePkJiYAEBiYmKJjJ+QkMCkSbsYOvQQ\nmzadLZExflR69LBn/vy2jBjR5pMUOX9WdHV16dixMebmJl97KgICJYKQmRMQEBAQEHgP8+d7AVKB\nka1bNwHwxx9duX//Lnfu3P6icwkOvsKmTf/g6ZnfWwzylyy+W87YtGkLDhzYS58+3bCzq0mFCpWA\ngk3sc8yWi6skUiKR4OR0gNOnnQA59u69h5zcObp1a/jJfQoICAj87Aim4QIljmBu+WMjXN8fl5/l\n2n6MQfbp09LywK5du1O6dBnWr1+DnV3tfAbZuRGLxcW2TudDwdzH8KWvb2xsLHXrRpGc/DZ469Zt\nK8uWtflic/hZ+Fm+uz8rwvX9cfkU03ChzFJAQEBA4KfGxsaOGzeuAxAaeoe0tDSysrIICbmOra3U\niiDHIFtRUQmR
SMTlyxfZsWMrKSnJvH6dAcDFi+cJCZH207VrR3x8vHFy6s3x40c4fPgAffv2wNGx\nOz4+3rKxW7duLHt9/PgR3N1nAVJVSmfnfvTt2wNf3+WytWoAaWmpTJs2kT//7Mrs2S4lck7S09MZ\nP34nf/xxmBEjthdL2aWmpia6us9zfZKNru7rz+5XQEBA4GdGCOYEBAQEBH5qcgyyU1NTUFJSokYN\nK0JD73DjxjVsbGrK2kkkEvT09DA0NGLdugCsrW0Ri8VMmDCVjRuDkJeX4/Jlqfm1SCRCW1uHtWs3\nYmNTkxUrlrF06QrWrQsgNPQ/Tp8+8abXgksYvbwW0L17L/z8NlG2bLk8871//y6jR49j48YgIiMj\nZAFkcTJ16n78/Lpz9mxn/v23N3/9dfiz+1RWVmbSJE0qVdqKtvYJmjdfy8SJzT5/sgICAgI/McKa\nOQEBAQGBnxoFBQUMDAzZt283VlY2mJhUJTj4MhEREVSuXKXQ/apUMUZLS5vSpcsAoKOjS1zcS9n2\nli1bA3Dnzm3s7Gqjra0DQOvW7bh+/RqNGzcrtO/bt28yb96iN+3b8vffXrJtFhbVZWNWrWpKdHQU\n1ta2n3bwhXDvngZSU3gAOe7f1y6Wfh0c6tGpUyYpKcloa9sJtgSfyKZNG9m3bzcAHTp0okmTZowZ\nMxxzc0vu3QvFzMyUCRNcUFZWITT0DsuWLSYtLQ1tbR2mTp1BqVKlGT7cmerVrQgOvkJychKTJk3H\nxqZ47yMBAYGSRwjmBAQEBAR+emxsbAkM3Mj//jecVat8SEtLw8LCkoCADSQlJbFz5zb2799DZGQE\nKiqqgDRIS09PA8DNbSYvX8Zy6tQJLl++SFpaGqqqqojFYnbu3MbNmzeIjY1BQUGBcuX0ZX3kDmYy\nMjKKNFdFRSXZa3l5qeVAcVO+fAogISdzaGiYXGx9KyoqoqOjW2z9/WyEht5h//49rFrlh1gswdm5\nLzVr2vH06ROmTJlBjRrWLF48l23btuDg0IMlS+bj4bEIbW0djh49hK/vciZPno5IJEIsFrNqlR/n\nz59l3TpflixZ/rUPT0BA4CMRyiwFBAQEvnOioiJxdOz+tafxXRMf/4rnz6M5evQg8vLyKCsrY2NT\nUxZsbdmyCX//f7G3b0p6ehrLly8lPDxMtn9qaiqJiUnY2zfB03MJSUnSNWYnTx4jMzMTTU0tRo0a\nx61bN7l584ZsLZ6enh6PH4cjFos5deq4rL/q1a04fvwoAEeOHPrg/IOCNtG7twOursWzhs7NrQlt\n2/pTteoOmjXbwNy5dYulX4HPJyTkOk2aNEdZWQVVVVWaNm3B9evXKFu2HDVqWAPw22+/ERJynSdP\nHhMW9pDRo4fSv38v/P3XEhMTI+uraVOp75qZmTnR0VFf5XgEBAQ+DyEzJyAgICDwQ9G6dWMOHz5N\nbGwMS5YsYM4cD/bt283du3fymFjn5urVy+zYsZ/MzEwmTvyLwMBtAAQGbqRbt57cvn2LmTOn0rRp\nc+Tl5Th//gzh4WHo6+sD0gybiooy1ta2VK5cBbFYaswdEnKDdu1+RVFRkRkzJpOdnY2hYQXs7aWC\nJoMHD2fChNFoa+tgYWFJWpo00zdy5Fhmz3Zhw4Z11K1bHw0NDdlcC6pM3LFjC15ePrLyS4CsrCwU\nFD7t33zp0nps2CC4wX+LFFSaKhLl/Vwikbx5L6FKFRNWrFhbYF85WV45OfkSyfAKCAiUPEIwJyAg\nIPADIBaL8fBw49atG5QpU5a5cxdy8OA+du/eTmZmFkZGRri4zEZZWYVjx46wfv0q5OTk0dDQYNky\n3689/WJG+lBbunQZ5szxkH7ynrVZ8+e7ExkZwdixI4iKikJZWVm2LSDAj9at2zN27ESGD3cmPPwR\nERHPsLdvRlhYGLGxCYwfPxYVFQU0NDQJCblBQIA/IpEIZWUVRCKIi3tJcPBVFBQUUVRUwM6uFiAt\nzVRSUkJbWwdra1uGDx8tG7dMmTL4+q4H4MiRgzx9+gQAO7va2NnVlrX7668Jeeb//Hk0jRo1ITIy\nAn19A/73v2G4u88iISEBHR1dpkyZTrly+ri5zURHR5OQkFu8ehXHpEku7Nu3m9DQ/7C0rCHzmBMo\neUJD73DgwF5Gjx5XpPY2Nra4uc2id+++iMUSTp06jovLbLy8FnLr1k1q1LBiz5492NjYUrFiZeLj\nX8k+z8rK4unTJ1SpYlzCRyUgIPClEMosBQQEBH4Anj59Qpcu3diwYTMaGpqcPHmMZs1asGqVP+vX\nB1CpUhX27NkJgJ/fahYt+pv16wPw8Fj0lWdecuQuP81tqXru3BkGD3YiISGeS5cu8PDhAyQSMDAw\npHNnB9LSUklMTOD169ekpqYikUiIjY0hNjaGiROnoaWlzblzVojFirx40ZnTp38hOTmVly9jZddA\nJBJx8uQxrKxsCAjwZ/Toccyfv4TMzCx27dohm0tsbAwrV67LE8gBhIaG0q9fL/r27cmOHVtl2x8/\njmTDhoPcuHFP1nb8+CmULl0Gb++VdOvWi/DwMLy8fJgxYw6LFnnyyy8d8fMLpE2bdixZskC2X1JS\nEitXrmPkyDFMmjSWXr0c2bBhMw8fPuD+/XsIfBnMzS2KHMgBmJqa88svHRg0qC//+18/Onb8A01N\nLSpWrMT27Zvp3duBpKQkOnXqioKCAq6uHqxY4U2/fr3o378Xt2+HFNLzlxGjGT7cmdDQO19kLAGB\nnwEhMycgICDwA2BgYEjVqtUA6fqXqKhIHj58wKpVPqSkJJOamka9eg0AsLKywc1tBi1atJatmflZ\nOHnyOJs3B7BgwVKysrLw91+Ll9dyevfuRtWqVQkJuUGdOvUYNKgvZcqURVFRiezsbJYuXYicnBzz\n57tjbl6Xdev6UrXqMkDEo0edqFTJD11dPdk1kJOTIyoqEgeHniQlJeHo2ANFRQUkEkhNTQGk2cLm\nzVsVmDW0sbFl/fqAd+Z+k9GjU4iI6ISW1k0mTTrJwIFNZdtzAtbGjZuipCQtn/vvv5vMnSsN4Nq2\n/QUfn6W5xm4GQJUqJujplcLY2OTNe2OioyOpVs20mM76j09UVCRjx46gRg1rbt68gbm5Je3bd2Dd\nOl9evYpnxgypmbyX10Jev85AWVmZyZNnULFipTxG8GvWrOT582iioiJ5/jyabt160rVrj3zjde/+\nJ927/5lnfHl5eVxcpOPkNpWuVs20wOy7t/dK2WsdHR2CgnYWy7nIuQ8Ly4aLRCJBxVRAoBgRgjkB\nAQGBHwAlJUXZa+n6lwzc3Wczb95CTEyqsn//Hq5duwrAuHGT+e+/W5w/f5YBA/qwZs0GtLSKR3r+\nW+bq1SuEht5h8eK/UVNT4+zZ04SHP2LwYCdiYl5w9OgR1NTUqF+/IXPnLgSgR48/6NXLEYlEzMSJ\nf+Hv/y8nTlzG3z8GUCAmZhqQRv36v3Lt2jbZWM7Ow94oXUrQ0dFl166DJ
CTE4+zcL8/6JRUVlSLP\nf/XqCCIiHABITKyJn99DBg7M305ZOW+fubOSuVFUlN4zcnJy79w/JaOQCTBkiBM+PgWv3/reiYh4\nxpw5nkyePJ2BAx05evQQPj5rOXPmJP7+63Bxmc3ff69CXl6ey5cv4uv7N3PmeObr5+nTJ3h7ryQl\nJZlevbrwxx8OyMvLf3D8DwVI69ev5tCh/WhpaRMZmU1mZkVMTKqgqHiVxMREVFRUmDhxKhUrVsbN\nbSbq6hrcvfsfL1++ZOjQkTRr1hKAgAB/jh8/wuvXmTRp0owBA/5HVFQkY8YMp3p1K+7evcP8+UvZ\nuHE9oaH/kZGRTrNmLRkw4H+fdmIFBATei1BmKSAgIPCDkpaWip5eKbKysjh4cJ/s84iIZ1ha1mDA\ngP+ho6PDixcvvuIsvwwikQhDQ0PS0lJ58uSx7PPateuxbl0AZcqUZcWKNXTv3ou7d0MBuHs3lKio\nyHx9NW1amz//PIpEIkJO7jbt22+ga9cGedqIxdlcunSfxYtPk5qaSteuHRk2zJm+fQfw8uXLfH0W\nhezsvA/rWVkffsCvUcOao0elapiHDu3PY4L+NfhRAzmQZseNjU0QiURUqWJM7dpSBdAqVUyIjo4k\nOTmJadMm4ujYnWXLFhMW9ihfHyKRiIYN7VFQUEBbWwddXT1evYorwtjl8fPbVOj2O3duc/LkMfz8\nNiGRtCY6OplHj6w5ezaYlJSarFmzgaFDR7FwoYdsn7i4l/j4rMXTcwkrViwD4NKlCzx79pRVq/xZ\nt+4f7t4N5caNa4D070rnzg5s2LAZfX19nJ2Hsnq1P+vXB3L9ejAPHz74qPMpICBQNITMnICAgMAP\nQEG/yg8c+D+cnfuho6ND9eo1SE1NBWD5ci+ePXuKRCKhdu26stLAHxmJRIK+vgHDho1iypQJuLrO\nw9KyBosWeRAR8QwQkZ6egYlJNQ4c2EufPt2wtKxBhQqVZH3knGORSMSCBV2YO/c/goOHY2Jig6Ji\nM9l2iURCUFAIoaGWxMU5YGAAlSptJjs7nYCADbRq1UZ2zj+m3MzBQZsrV64SH18LZeUndOr0OtfW\nt/3k7nL06AnMnTuLgIAN6Orq5hE2yT32u/MoqTK4HKXR4OArrF3ri46OLmFhDzEzs2D6dNcSGfNL\n8W52M3fmMzs7m9WrV1C7dh3mzl1AdHQUI0YUnKlSUMjbT1bW52dJb968QePGzVBUVOThwzIkJ7dA\nJMpAVfU6oaGP6d//GACZmVmA9Po3biwt4a1cuQpxcdKA8tKlC1y+fJH+/XsBkJaWzrNnTylbthzl\nyhlgaVlDNuaxY4fYtWsH2dnZvHwZS3h4GCYmVT/7WAQEBPIiBHMCAgICH8GnlonlXhdTVNasWYma\nmjo9e/Z+b7t3f5XP3b5Tp6752ru5zS/yHL5HCgpSctbpVKxYmRkzXHFxmYSn52KmTp3JzJlTUFNT\nY8KE0Tg7D2XRomUF9vtu5mPyZJcCt0dHRxEcPIj09DoAREU5oKsrYsGCtoDU0y409C5jx07Ko5z5\nIf74oz76+rc5fz6IatW06djxF9m2nPVOTk7OefbR19fHy8snX19TpsyQratKTX1N375/kZycjIaG\nRgkrWb69Ng8e3GPjxiBKlSrNkCEDCAm5jrW1bQmO/fWQSCSkpCTLrCP27t1VaLuSQSTr28AgmchI\nADFisSbGxgNYt65Tvj1ygtF359W7dz9+/z2vbUVUVCSqqm/LeyMjI9i06R9Wr96AhoYG7u6zeP06\no3gPSUBAABCCOQEBAYGPojjLxD7kfVbc2ZGnT6M5ePA6lSrp0br1j2sCfejQSSBvkNu+fQfat+8A\nQLVqZmzcuBmA8uUNWbXKv1jHV1FRQUXlGenpOZ9IUFTMBCAo6AKzZ2fz4kUVrKwO4utrh7GxUZH7\nbtCgOg0aVC+2uS5bdoRFiwxJTrbB0vII69fXpHJlw2Lr/32Ymprj4TGHmJgYXryI5ujRQwQHX+Hs\n2VNkZGRQo4Y1EyZMBaQKiGZm5ty4cZ20tFSmTZuFv/86wsIe0bJlawYNGgLAwYP72LLlX7KyMrG0\nrMHYsZOQk/syK0rel92Uk5OjZ09H3Nxm4Oe3hgYN7Ckomyr90aH452ZtbYOnpzt9+vTHxcUGZ+fl\nZGQ0QFVVmQ4dpJk/iUTCw4cP3pupr1evPqtWraBNm/aoqqoSE/MiTyYxh5SUFFRUVFFXVycu7iUX\nLpyjZs1axX9gAgICQjAnICAg8DEUpUzszp3bLF26kLS0dBQVFfNlRnIybjo6OgD06dON+fOXoq+v\nj5/fGg4c2Iuurh5ly5bDzMwCkK5HWbTIk/j4V3mECopKSMh9BgyI4PHjrigpPWXgwN3MnNmxeE7K\nd0ZmZiZz5hzg4UNlKlRIZ/r01qiqqhZb/7q6ejg5vWT58uukpxthbb2PUaPskUgkLF4cx/PnUruE\nkBAzFiwIYPnyogdzxUlGRgarVsmRnCwN7P/7rydeXptYvPjLBHMpKSkYGlZg/nwvFi/2pEoVY1q0\naEO/flJVF1fX6Zw9e5pGjRojEolQVFRi9Wp/goI2MWnSWNat+wdNTS26d+9E9+5/Ehf3kmPHDrNi\nxVrk5eVZsGAehw7tp127X0v8WN7NjufObubelmNGD8gC0Nzege9mVv39/y2W+ZmbW2Jv34S+fXug\np1cKe3tb6tevSa1a/2PBgnns2bOVrKysQkuAc17XqVOf8PBwBg/uD4CamhouLq75FCqrVTPF1NSM\nXr26ULasPtbWNsVyHAICAvkRgjkBAQGBjyJvmdjChctwc5vBpUsX6Nz5V6ysbLh27Sq6unpkZGQw\natRYHj16yKJFnsTEPGfIECdMTc1RU1N/26NIxNWrl9m8+R+ys7MZPnw0fn5rOHnyGOHhj+jUqQue\nnm6MHz8FI6MK3L59i4ULPQosnyuMNWse8PhxNwBev65IUJAuEyakoqamVnyn5jth6tS9rF/vAKgA\nmSQl/cOyZV2KdYxJk9rz++/3iYi4RoMGbVBXVyc7O5uUlLxKk6mpRS+zLG4yMzNJT1fP89nr1/mz\nLCWFmpo6V65cxMfHm9jYGMzNLQkOvkxAwAYyMtJJTEzE2NiERo0aA2Bv3wQAY2MTjI2ldgogza4+\nfx5NSMg17t4NZeDAPoA0WC1VqtQXO55PRSKR4OGxj1OnlNDQyGDcuKrUrWte7OP07NkHJydn0tPT\n32Q6LTAwKM/ChUvztX231DYn2w3g4NADB4f8dgnvliEXVq6b2xJBQEDg8xGCOQEBAYFPxMKiOnp6\nekREPKNp0xY0atSYDRvWI5FI8PML5MyZk2zeHIiLy2xGjx7H5s2BODj0wMtrAb/++rtsHUpKSjI7\ndmyhVau2JCYmEhi4EW/vlfj6+vDkSTgbN67n5s0QXFwmysbOESooKhKJ6J33ciW4PqdgPmW94enT\nJ6hQoRKVK1cptnncvq2BNJAD
UOT27ZKxZbCwqIaFxduSNXl5eeztXxIUlAqooaZ2l9atiy8j+LFo\naGjQokU4W7cmAxro6V2gU6cyJTpm7uyNmpoqa9f+w/nzZzh4cB8SiYRbt26yZs0GypQpy9q1vrx+\n/VbkRVFRSdZHzuuc9zlWCu3bd+B//xtWosdQ3KxdewIvrxZkZ5cFIDIyiMOHKxVrthjA09ON8PBH\nvH79mvbtO1Ctmlmx9l8QZ8/e5syZpxgbq9O1q73gLycgUAIIwZyAgMA3QXJyMocPH+CPP/ILduQQ\nFRUp8/r6Fsh5oDQwMERXVxexWIyhoRHZ2dJAK7ck+dq1K3nw4D7Pn0cRHx+PRCIGpN5nr169YsEC\nby5fvkhMzP1c3mcxyMmJ0NHRQVNTk3XrAgqdy4dwdKzC2bMHefq0LQoK0XTq9AJ1dfUP70jxnfdP\nWW946tQJGjVqXKzBXNmyKe+8Ty22vj+El9cfVKu2j+fPRdSvr83vvzf+YmMXxLJlXbGxOUhMjJhW\nrSrRoEHJWhfkZHjs7GpTsWJllJSUaNOmPRoamuzevQORCLS0tElNTeX48SO0aNG6SP2KRCJq1arL\npElj6datF7q6uiQmJpCamoa+vn5JHtJnc+fOa1kgB/DggSVRUZEyE/fiYsaMOcXa34fYseMCEybo\nER/vgIJCNDdv7mb27N++6BwEBH4GhGBOQEDgmyApKZHt24PeG8x9q+SWJNfU1CQ5OZnQ0P/Q1tYh\nMzMTX9/lmJqao6GhxZgxExg4sA9374ZSv35DtLW1CQ9/RGRkJLa2Ndm8OQA7uzpMmjQNJ6c+dOrU\nmR49ejNkiBPHjx+hefNWRRIqeJfatc3YtOkJhw5txshIk99++70kTsV7ad26MZ6eSwgM3ChT9Vy0\nyAMLi+q0b98BHx9vzp49jby8PHXr1qdp0+acPXua69ev4ee3hjlzPDE0/Pz1ZTNn1iUhwY9Hj7Sp\nWDGJ2bO/nPeagoICo0e3/WLjfQh5eXkGD27zxcZLSkoiIOAsCgpymJur4Ovrg5ycCAUFRcaNm8yp\nU8dxdOyOnl6pPDL3uSlIJGTHjq0yIZQxY4YhFktQUFBg7NiJ30QwFxS0iZ07t2JmZo6LS14LhmrV\n5BGJ4pBI9ACoUuUe+voNAdi8OYDff++czwj+e2DbtkTi46XBeFaWPvv3qzNrlkTIzgkIFDNCMCcg\nIPBNsGKFNxERz+jfvxd2drV58OABSUmJZGdnMWjQEOztm+ZpHxHxDBeXiUyYMA1NTU2ZOIiioiJ1\n6tRjwICCPZzeR1HsA/KKAuTfLicnR48ef7J48XySk5OJjo7EyKgihoZGREVFsnfvLlRUVElKSmTl\nyr9RV9fA0NAIb+9FzJ27gNat2xEQ4P/G2Ls6mZmZPH36hOnT57BgwTz8/NbmEyooKtWqVaRatYof\ntU8OYrEYDw83bt26QZkyZZk7dyEHD+5j9+7tZGZmYWRkhIvLbDIzs+jXrydbtuwGIC0tjT//7EpQ\n0C4kEgnLly/lyZPHDBs2iIkTp8rOZ2JiAqdPnyAgYCsgLT1VV9fA3r4JjRo1pmnTFp8074KoVMmA\nbds6I5EID5ZfksTERBwcDnLtWj8gi8aN1xEY6IeS0tuSSTMzc5kwSG5yr7OqWbNWHmXEd9dgtWxZ\ntGzel2THji14efnIrAly4+zcksjI3Zw9q4aGRjpjx1aRrWUNCtpE27a/fJfBnKJi9nvfCwgIFA9C\nMCcgIPBNMGTISMLCHrFuXQDZ2dlkZKSjpqZOfHw8gwf3zxPMPXkSzsyZU5k6dRYmJlUZNWqITBzk\n5MnjuLpO/6RgrijkLhOzs6tNVFQkIpFIZi9w7dpVjIwqsHLlOqKiIpk0aQy9e/fDzW0GqqqqVK1q\nikgkx6JFy9i/fw93795h9Ojx3L9/V+Z9VqdOPXx8lvLw4QMePnyAsXFVGjVqXKBQQVH53FLJ8PAw\nZs50Z+LEqUyfPpmTJ4/RrFkLfvvtDwBWrfJhz56ddOnSnWrVTAkOvoKdXW3OnTtNvXoNkZeX5/Xr\n1zg49ODIkUP07TuAhQs9ZOWT6uoaKCkpM3fubBo2bCwTvYCS894SArkvi7//2TeBnBygxOnTf7J7\n93G6dGn2Uf2sX7+aQ4f2o6OjK1N8ffDgHklJWmRlKZOdfQ8vL28g7w80ly5dkK3DMzQ0YsoU6Xey\na9eOtGrVlu3bg8jMzKRcOX0GDhyCoaERy5YtJi0tDW1tHaZOnUGpUqULVZZ1c5uJuroGd+/+x8uX\nLxk6dCTNmrVk/nx3IiMjGDt2BG3atOf06ZO8fp2BsrIykyfPoGLFSkyf/is+Pt5cunSe1avlePGi\nExKJhNjYGEaOHIyOju5HCR59CwweXIWQkD08ftwCLa3bODkpCd85AYESQAjmBAQEvglyP7BLJBJW\nrFjGjRvXkZMTERsbw6tXcQC8evWKyZPH4e6+gEqVKpOamsqtW2/FQaKiosjISKd//17UqVMPiQQu\nXjyHSCTC0XEALVu2lmWI3v08N3fu3Gb+fHfmzPGkfPnCpdo/R5K8IO+z8PDHZGTI4+3ti4qKCsHB\n99iy5REHD+5hxIgGlCnzddT5RCKRLBNoZmZOVFQkDx8+YNUqH1JSkklNTaNevQYAtGjRmmPHDmNn\nV5sjRw7RpUs3UlNTEYvFrF3rS1xcHDExz8nMzJKdW3l5eVat8uPKlUucOHGUbds2yx5ei/IA+LHB\n6v79e6hTpz6lS5f+lNMh8AnIyYmA3IF5NvLy7/eACw29w4EDexk9ehxr1qwkOTmJ69eD8fPbRGZm\nJk5OvdHXN+DQoaNERMwjObktVas24sqV29SuXZ1jxw7TqlVb4uPj8fdfi5fXcpSVVdi4cT3//vsP\n/foNRCQSER//ihYt2mBqasa9e6HUr9+AceNGMm/eIrS1dTh69BC+vsuZPHn6e5Vl4+Je4uOzlvDw\nMCZNGkOzZi0ZP34Kly5dwNt7JQoKCvTo0Rt5eXkuX76Ir+/fzJnjya5d23n+PJr16wORk5MjMTER\nLS0t/v03AG/vlWhplYxIT0lSp445e/aU5syZI1hYVMDSstnXnpKAwA+JEMwJCAh8cxw6tJ+EhHjW\nrt2IvLw8Dg6/kZEhVbXT0NCgXDkDbty4RqVKlZFIxGhovBUHiY6OYsKE0axbF8CJE0fZZKTsEgAA\nIABJREFUuXMbfn6biI9/xcCBjtja1uTmzRs8eHAv3+c53Lx5gyVLFjBv3iLKli33xY570aJDeHsb\nkZJiQs2ae5k6tRKjRmUSEeEASLhwYT07dvzyyXYCRS2VVFZWITIyguHDZ5CUlIyNjV2efuTk5MnO\nzsDdfTbz5i3ExKQq+/fv4dq1qwA0atQEX9/lJCYmcu9eKLVq1SE1NQUQsXTpCoYNG4Svrx/p6ek4\nOfXG2tqWtLQ00tPTaNCgEVZWNnTvLl3Tp6amRkpKyruH8tns27ebKlVMhGDuC9K3b2P27
VvHpUuO\nQCatWgXSoUP39+5jbm6BubnUa1EkEhEVFUnjxs1QVFREUVGRRo0aExf3iqwsBaSPNPIkJrbB13cf\ntrZmnD9/lmHDRhMcfEUmLARSNVgrK2vZOL/88huuri68fp1BeHgYz59H8+jRQ0aPHgpIvzulSpUh\nLS2tUGVZkUhE48bSCoLKlasQFxeX73iSkpJwdZ1BRMTTPCqcV69eolOnrjKDcy0trY8/wd8g5cqV\npkuX5l97GgICPzRCMCcgIPBNoKamRmqqVFUwOTkZXV095OXlCQ6+QnR0lKydoqIi7u7zGTNmOKqq\nqrRu3Y7y5cvLxEHEYrFMzjwk5DqtW7dDJBKhq6uHra0dd+78x82bNwr8XF1dnfDwR8yf787ixX9T\nqtSXe9BPSIhn1SpNUlLqAXDtWl9cXecRETH5TQsR16//xrlz12jVqt4njfH06ZMil0p6eS2gV69e\nNGzYgvXrVxfYX1paKnp6pcjKyuLgwX2ywFdNTQ1zc0u8vObLDJ/V1TWQkxNx+/ZNmjdvRZ8+3dHR\n0cHMTCqPnpqawqRJY99cOwkjRowBoGXLNnh4uLFly7+4us57rwBKdnY2s2e7cO9eKJUrG+PiMouw\nsLB8pXIhIdcJDb3D7NnTUFZWZvToCWze/A9ubvM5ffoEM2dO5eDBk2RnZ9OnTzc2b95ZaGndq1ev\nWLhwLs+fRwMwcuRYrKxsWLNmJc+fRxMVFcnz59F069aTrl3ze3P9KERFRTJ27AjM/8/eWYdFlbZx\n+B6GlBIUMQETkEZsbF111bWwXQEDYy1s7Mbe1V0DXUEUY0WxVtfuTsDCVhqR7piZ749ZRhAwcY3v\n3Nfl5cw5b533zDDvc57n/T1mtQvM/+3bwaxZsxKJRIKZWW22bRvL/v1/c+3aEeLiXjBo0N/Ur9+A\nESPGcPLkcTZt2oCSkhgtLS3++GN9oX2sr1694sGDvRw9eph+/QYAoKKijNzjJwMkKCkl8fDhafr2\nPY+urq5C4t/BoT6zZy8ocvwmJiZ4e29l9+6/uHDhLKdPn6Rq1eqsW1dQgTUtLfWtyrIqKq/FkN4M\nD5bJZPz55zocHOri6bmMqKhIRo8eVmx5AQEBgfdBMOYEBAS+CnR1S2NlZcOAAb0wM6tNaOgLnJ17\nY2pqjrHxa1l6kUiEuro6S5b8hrv7CEqV0iwgDpKRkfGvF0hetrgF0pvH80L5ypY1ICcnm4cPQ2jY\n0PEzXW1h5OMunX9EiMXqQCZ5OdHU1aMwNCxdVPX3okKFSu8dKnnnTjAbNngRH59OkybN2bixcKLf\nwYOH4ubmQunSpbGwsFQY4yAXoZg500MhTpGUlEiZMmX5++/9xMW9QllZmQYNGuHiMlhRZ8MG30J9\nWFnZ4Oe3872uLzT0BR4eM7G0tMbTcy67d+/k3LnTeHquoHTpgqFyAQH+jBzpjqmpGbm5uTx69BCA\noKBAqlWrwf37d8nNzcXCwgqg2NC6lSuX0bNnX6ytbYmOjmbChFH4+fkDcuP599+9SEtLpW/f7nTt\n2gOxWPxe1/ItEhYWytSpsxTzv327H/v372HVqnVUrlyF+fNncfjwQX766UcOHVpXQOwGwNf3T1as\nWE3ZsmUVx/Ijk8lIS0tFV7c0S5euxM3NGRUVFVq1+gE1NSnKyuHo6m6ibNlUSpcuTc2apjx+/JCo\nqEhq17ZkxYrFRESEU6lSZTIyMnj1KpYqVeSCQPHxcVSoUIkGDRpx8uQx7t+/S2JiInfu3MbS0orc\n3FzCwkKpWrVagYdHH6osm5aWphBBOXTogOK4g0N99u0LwN7eAbFYrAizzPNMf4thlgICAv8NgjEn\nICDw1fA+eZDy9qBpaWmxYcNmxfE8cZCkpEQGDfoZAGtrW/bt20P79h1JSkoiKOgWI0eORSKRsG9f\nQKHjz549RUtLGw+PGYwd+wvq6hoFVPM+J4aG5Wne/DT//GMDqGFgcJ7Jkx3w9vbl5MmGqKkl4+IS\nhpVVx4/uI38KhXeFSuanfPnyqKu/TmDcp09/xesuXQqmkli8eD69evWjefNWnD17FYBXr2IZNWoo\n/fu70L17z7eOUSKRsHz5VqKjE3F2bo+NTU3Onz/L8+dP6d/fhY0bvShVSrPAGPIoV84QS0t56Fzb\ntj/i6+vN06dPcHcvGCqXR55Br6ysTKVKlXnx4jkhIffo3bsfgYG3kEol2NjYvjW07vr1q7x48Uxx\nPD09nYyMDEQiEY0aOaKsrIyubmn09PRJSIgvUs3we+HN+d+06U8qVqxE5cpVAPke0YCAnXTv3rNI\nsRsrKxsWLJhFy5ZtaNascGieSCSideu2iEQiRo8eSm5uLoaGhmhqaqKursK0aZmcP3+WlJQ4UlIy\nOXv2NOXLlyc8PIy6deszbdpsZs+eSnZ2DgBubiMUxtzz58+YNWsa2dnZxMW9Ytq0OSgpKbFy5TJS\nU1ORSHLp1asvVatWe6uybEG12/x7PUWIRCL69h3AggWz8PXd+O/DInmZTp26EBYWirNzH5SVlfnp\np65069aDn37qyvjxozAwKPfNCaAICAj8NwjGnICAwDfNzp0X2bUrBbFYypAhFWnZ0kbh4WvQoBE1\natTAxaUPIpGIESPGoKenT7NmLbh7N7jQ8efPnyESgZ6ePkuW/MqECaOZOnUW5uYWn/06RCIRGzZ0\nZ+3a/SQkQPv2xtSvb06zZtaEhoaioVEeQ0ObEu+3uFBJKysbDh48SMOGLTh69PB7tzd58vRCx3R0\ndLG1Hcy5c2Kys8/Tp0/RHk+ZTMYvv+zi+HE91NTiOX48mXXr7uHo2BRHx6bA28VQ8p+TyWRoamoW\nGSpXVHkbGzsuXTqPWKxMnTr1OHx4FlKpjF9+GYNUKnlLaJ2M9et9C4TX5aGsnN94ViI39/uWZn9z\n/rW0tElOTipwDIoXu5kwwYN79+5w6dIFBg36mY0btxTZT58+PzNwoBtz5kzn3r27VK9eg3LlyuHs\nPIBHj+7g5jaYunUbFKpnb+9Q4AEQyI3vKVPmU7myIb6+2wvV+eOP9YWOVahQsUhlWQ+PmQXmIE/5\nFsDffx8AlpZWRYohicViRo1yZ9Qod0D+4EEmk9G9ey+6d3/7vkIBAYH/bwRjTkBA4Jvl/Pk7TJtW\nnqQkeRLm+/dPsG9fZCEP34gRYwrVHTFiTKHj+fNXGRqWZ8uW9wvvKylUVVUZM6ZdgWNKSkqYmJiU\nSPtFGULFhUqOGTOBhQtnsW6dF46OzYqsm5GRwcyZU4iNjUUqleDsPJg9e/wZNWocpqZmtGnThB49\n+rBjxz4yM9N49uwku3encOGCB/Hxj4iNjUFNTZ3SpfWQSHJp2NCRmzcvYmgYjkwmJjv7Bl5e1iQm\nPuHBg/uK9A/FERMTrQiLO3bsMBYWlhw4sLfIUDl5+NrrUD4bGzvmzZvJjz92onTp0iQlJZGYmEC1\natUBig2tq1u3Af7+O+jbV+4NfvToITVr1vro
e/Qt8+b8m5mZs29fgCK08ciRQ9jZ1SlW7EaeW9GS\n2rUtuXz5Ai9fvizQvkwm4/z5M4SGvuDZsye8ePGcfv0GYGJSTVGmXr2GBATsws7OAWVlZUJDX1Cu\nnCHq6oXztD19Gs6gQbe4e9cRff3HTJ36kAEDmhQq9y4OHbrB0qXRJCerU69ePKtWdSnSuC+O/AnF\np0+fy+TJezh+XBdV1RyGDtXA1bXZuxt5C5+alkRAQODrRjDmBAQEvlkuXw4nKamH4n1kZFPOnz+A\nsXHFj2rv2LFb7N37ElXVHMaOrYOxcYWSGuoX580UCm8Llcwrv2PHDmJjUwCKTOR85cpFypYtx9Kl\nKwH53qe9e3cpzmdmZmJhYUliogGqquvR1d1JfPxw7t27jZ1dNcqUKYOpqTmuroM5fvwoXl6rSUmZ\nQ05OJnp63iQkuFCuXCYi0buFIUQiEUZGxuzZs5NFi+ZiYlINJ6fe1KvXsMhQuR9/7MSyZZ6oq6uz\nbp0PtWtbkJiYgI2NXNW0Ro2ainQYQLGhdWPHTmDFisU4O/dBIpFga2vPhAlT/h3TO4f9XfHm/Pfq\n1Q8LCytmzJiMRCLB3NyCLl2cSExMxMOjsNjNmjUrCQ8PQyaT4eBQjxo1anLr1g3FPIpEIqpXr0l4\neBjZ2TlMnOhBx45dFLkeQR6uGBUVyaBB/ZHJZOjp6bNw4dIix/vrr4HcvdsXgPh4I1av3kX//lKF\nouT7kJGRwaxZibx4IRe3CQvLxMRkH5Mn//jebeRPKO7jcxJf307IZPoAeHpeoGXLMIyNq7x3ewIC\nAv9fiGRfiXxS3oJB4PvDwEBbuL/fMV/y/v7992WGD69FVpYxALq6V9m7VwULixof3NbFi/cYNAji\n4uRKkZaWWzlwoBWampolOuavHX//S5w7l4yeXg5Ll/5Eerq02LJhYaGMGzeSli3b0KhRE2xsbBk1\naqhCWKRly0acPHmRJk38yM7+i/R0R2Ji5mNqaomzszNBQbcYOvQXLC2t2bzZmz//9EJLqxzx8dmI\nRNmoqNTBz288d+/eICTkHu7uk/D2Xo+GRqki98wJfBgl+d39Fr0/Q4YcZd++7or3hoaHuHGjPqqq\nqu/dRkREOA0aZJKV9Tq1Sb9+u/j117bvVX/p0oUcOnQAIyNj2rfvyM6dRwgLkyKVahATM5fs7LIM\nGTIPS0tTxWf+5597snTpKmQyKRMmjMba2q5AuhE1NTViYl4wadIURCIR9erV5/Lli9/UvRF4O8K6\n6vvFwED7g+u8/+MnAQEBga+Mjh0bMGbMVczMArCw8GfGjNiPMuQAjh8PVRhyAHfutOLWrfslNdRv\ngh07LjBhggk7djixdm1Pevb0f2v5KlWM8PbeSvXqNdiwYQ0+PhsKnBeL5cEfw4YZIhanIxa/xN7e\nF1XV1z89efvKlJREaGhocOjQAVxdu1O3ri2HDs3CyOjb8o7GxSUwbNgefvrpGO7ue8nIyPiodiZO\nHENaWiqpqans2fPa23nz5nUmTXIvkbFevXqVO3eCS6QteL/k7p+TR49CWb36Hw4cuPBe5du310ZH\nJ+/602jaNOqDDDmQh2PXrv16DlVVX+Dg8P55ICdOnErZsgb8/rsXUVGRWFmZkpg4l1ev3ClffjI1\napylcuWCojn55zk8PIzu3XuyZctOtLS0OXPmJAAeHh6MGzeZTZuKTqEgICDw/SCEWQoICHzTTJjQ\njgkTPr0dAwMlIB2QL8S0tZ9TuXK5T2/4G+L06TQyMvL2e4m5fNmYlJRktLWLTmD86tUrtLW1+eGH\n9mhqavH33/sKnM/KygSgV68W+PrOw9FRyuzZ7XByWs39+3cBkEolpKXJE5Nv2LCOhIQEVFVFhIbe\nJzU1DS0t7QJpJIoLJsmvcrlxoxc2NnY4ONT7xBn5cMaNO8k//zgDIi5fzkUk2s6KFV0+uJ280NWo\nqEj27PGna1d5KGx8fByBgYUVRz+GK1euIJMpKxQoP4U3w3j/ay5fvsfw4clERPRERSWSa9f2M3fu\nT2+t061bQ7S1b3H2rD+GhiKGD+/+1vJFoayszLp19Vm8eBupqeo0bSqmX7+WH9yOTCbj9u0gFixY\nio3NC/bvT+H580iWLDHkzp2nxdYrKt1Iaqr8IYCNjS0Abdt24PLlix88JgEBgW8DwTMnICAgAAwZ\n0pKuXbehp3eCChX2MW5cNCYmRl96WP8pOjpZyBMvy9HTS6BUqeLDTJ8+fYybmwuurn3ZtOlPnJ0H\nFTifl85AWVmZFi1aExh4g0mTxlKvXgOioiK5f/8e8+bN5Pnz51SpYoyOjg7jxv3CwYP7iY+Px919\nBCdOHEMkEim8EfLXhceS31sxaNDQL2LIATx9qkue3Dwo8/hx0fO3bdtmdu2SGz+rVi1nzBj5nsQb\nN64xZ850evT4iaSkRNat+52IiHBcXfuyZs1KQIRUKmX69Mn06+fE3LkzFG1ev36VgQP74ezcG0/P\nueTkyCX4nZw6KVQlQ0LuMWrUUKKjo/jrr7/YuXMbrq59CQoKLHKc/v476N+/B/PmzSjy/KFDB/j1\n1yWA3KDevt3vQ6arEHv37ubw4YMfXG/z5lAiItoAkJNTkYAAXbKyst5Zr00bO+bNa8fIkW0/Ogdg\n1aqVWLeuE35+bXBz+3BDLj8ymYxu3RqyadMP6OurY2lZDbFYjEz2OtxZvtdQTuF0I4UVU7+S3TQC\nAgKfCcEzJyAgIIDc4PDy6kVKSjIqKqpFqt9970yZ4sjDh94EBlpQpkw0c+YYvnWBW69eA+rVKygB\nn5ckHFAsQG/evE5ERDj16jXk2bMnlCtnqEgYff/+XVauXE5qaiqGhhVYtWodISH32LFjK0uW/Mqq\nVctRU1NnzBi5+/XUqeMsXSqXhff13cjhwwfR09OnXDlDzMzMAViwYDaNGzehefNWODl1on37jly4\ncA6JJJd58xZhZGRCQkICc+ZMIy7uFZaW1ly7dgVvb79PTs5cqVIKDx4oZoBKlQonvwawsbFnxw4/\nnJx6ExJyn9zcXHJzcwkODsTW1p47d4IRiUQMHz6aZ8+eKtIiHD9+hKysLKRS+dxevnyR7dv9CAy8\nwaNHD1m1ah1RURF4es5nz55d9OzZp8jwx/LlK9C7d29kMjG9exe//zC/OEdRFJ9X7ePo0uXDvWPy\nvgu+V1KSffGwzw/F2tqOo0f/wcVlMDdvXqd0aT1KldKkQoWKXLhwDoAHD0KIiop8aztaWlpoa2sT\nHByItbUtR4/+818MX0BA4AsheOYEBAQE8qGtrfNVG3Il4f0oDn19PQICnLhypSwXLjSmX7/Gn9ji\n68X048cPGTt2An5+/kRGRnD7dhA5OTlMnTqJJ0/qcOHCTEJCunL16pN/a8iYPHkPGzaks3JlFr/8\nshOpVKpYoIeE3OfkyWNs2rSdZctWEhJy73Wvb3jySpfWw9vbjy5dnBRz5+OzHgeHemzZspPmzVsR\nExP
9idcqZ/HiBrRqtQVz87106LAZT8+iPTWmpmY8eHCf9PQ0VFVVsbS0IiTkPkFBtxSKmlC0V0Um\nk9G3789s3bqL0qX1ePDgPk+fyo3kypWrcPDgAbp06UZQ0M13jvdtTpulSxcSGRnB+PGj2LHDDw+P\n8Tg792HoUFeePHn81nYfPXqAm5sLzs59mDp1IikpKSQkxDNo0OsUDk2a1OXlyxgAevXqQlZWZoHP\n98iRbqxd+ztDhjjTp083hfcwMzOTGTOm0L9/T6ZOnYibmwutW0OVKv8AMtTUXtCzZ9oH73/7csg/\nrwMHuvHgQQjOzn1Yv34N06fPBqBZs5akpCTz8889CQjYSZUqxq9rvmGw5r339PRkxYoluLr2LbKc\ngIDA94PgmRMQEBD4hvjcizIlJSUMDQ1LvF1zcwuFd6dGjVpERUVSqpQmKSmqBAfLpfwfPmzIkiXb\nmT+/PK9eJXL4cEt0dJKRyUqxa1d3Gjc+9W9rMoKDb9G0aQvU1NQANRo3blps382ayQ2qWrXMFAIR\nt28H4em5HID69RsWuy/wQzE2rsD27e/eI6esrEyFCpU4dOgAVlY2VK9eg5s3rxEREYGJSdW31lVV\nVVPsczMxMSE09AWNGzfh4sXzpKSkcPfuHTp06MTDhyGAPCG1VCq32rKysott900mTpzK1auX+f13\nLzZu9MLU1BxPz+XcvHmd+fNn4uOzrZCxmffxnD9/FuPGTcbGxo6NG73w8VnP6NHjyc7OIj09jeDg\nW5iZ1SYw8BbW1jbo6emjpqZeIIxWJJKHlG7Y4MulSxfw8VnPb7+tISDAH11dXfz8dvL06RNcXfsy\nfrwJu3frcOzYLkxM9GnTpsN7X+eXJi+hOICn57JC59XU1Fix4o8i6xaXbsTCwqKA+MmIEaNLYqgC\nAgJfIYJnTkBAQOArx9d3I336dGPEiMGEhr4AYNSooYSEyNU2ExMT6dFDLvYgkUhYvXolQ4YMwNm5\nD/v2BXyxcedHReW1l0QsVkIikSASgURS8GcoIUG+zy4zM4fc3PKAGJACusTF5eTbL/SmUVu8iylv\nX1Fev4oaX3gvkY2NLdu3+2Fra4+NjR179+6mVq2CCcdLlSqlSOSeR357XiaTGz29evXj5csYdu/e\nQcuWrTl69DC2tvaAPKQyz3N55swJRV1NTU3S09PeOc48cY62beW50+ztHUhKSiq2bp4KZ56HsV27\nDgQG3gLA0tKG4OAggoIC+flnV4KCbhIcHFjAG5mfZs1aAHJPZnR0FCA3xFu1+gGAatWqU726XADE\nxKQiQ4a0o02bL7Nf8mtAJpNx9ux1AgJOv9eeQQEBgW8fwZgTEBAQ+Ip5Vzjhm/z99z60tLTYsGEz\nGzb4cuDA3nfusflSGBmZoKKSjIbGWQBEojgcHOSJug0N9TA3DyAnpxJqavcwMTmAlVWpf69FhK2t\nHWfPyhes6elpXLhw/oP6trKy4eTJYwBcvXqZlJTkEr2298HGxo74+DgsLa3+9UypFTJqdHVLY2Vl\nw4ABvVizZhUgIisrizt3bgMQFvaCKlWMqFChIqam5mzatJGzZ08hFosVyeBdXd1YuXIZgwcPQCxW\nVnxuWrRowdmzp3F17UtwcNECKPkpbPx+uJfY1taOoKBbxMRE06RJMx49evhWYy7vIcCb4h5f2hD/\nGpHJZIwe7U/PntXp3t2WHj32kppa9J5NAQGB7wchzFJAQEDgK+ZDwgkBrl27zJMnjzl9Wu6BSUtL\nIzw8jAoVKv4Hoy1IQXGMwueVlZVZufI3PDxmkJqag7q6MgsW+PLkySNUVVXw9bVh3bqr3LnzgFKl\ngrl82U6xX6hWLTNatWqDi0sf9PT0qV3b4n1GpBiTq6sbs2dP48iRQ1hYWKOvX+atyp2fgzp16nLq\n1CXF++3bX3tR/f33K17PmjVf8To6OgpjYxP27NnJokVzMTGphofHTAB69OjNrl1/sW6dd4F+5B7A\nwh5aExMTfH23v9dYixbnKJhPTSaTIZOBpqYW2to6BAUFYmNjy+HDB7Gzq/PvWOzw8lqNnV0dRCIR\nOjo6XLp0gWHDRuVr5+1jkRvix7G3d+DZs6c8ffr2/XtfA6mpqRw7dliRYuJzcPVqMP7+LZBK5Sq8\nly+7sn79LsaN+/Gz9SkgIPDlEYw5AQEBga+aor0f8n1Qck9FdnbBcKpx4yZRt26Doqr9pxw9egaQ\nh+XZ2zsojru7T1K8Nje3YO/egoaGnV0dxeJ/0aIuQNF70AYMGMiAAQMLHZ86dZbidX6jyMzMnFWr\n1gFyxb8VK35HLBZz504wDx7cQ1n56/9JLF++Alu37ipwLCbmJbdvP+XKlct06vT2/Xo5OTns2XMO\niUTK0KHvs8h/Lc7h6TkXZ+c+aGhoKMQ5iksbMW3abJYt8yQyMgI9PT3WrNmoGD+gCAG1sbHj1atX\naGlpve6xWIefiI0bvVBVVSUxMYH+/XtibGxM1arVCtT/GKRSKUpK8mClkSPdGDnSXaGOWhKkpCQX\nyBf4OUhPz0Qqzf9AQkx2tiB8IiDwvSOSfWSswuLFizl9+jQqKioYGRnh6emJtrY2AF5eXuzevRsl\nJSWmT5+Oo6PjO9uLjU35mGEIfAMYGGgL9/c7Rri/n5eHD0NYsGAO69dvQiLJZeDAn+ncuRuhoc8x\nNTWjSxcndu7chr//Dvz997N//x4uXbrAvHmLUFZWJjT0BeXKGX6UQuf3eG/T0tJYuvQUL1+mEB29\nEx0dDVRUlBk/3qNEF+//FYcO3WDKlBxUVX9HRSWXuXMn0aZNnSLL5uTk0L+/P6dO9QfEtGixHV/f\njp9VvdXbez0aGqUKiHN8anvq6ho4OfVCVVWViIhwxo79he3bdxdrjEdFRTJ+/CjMzGrz8GEIJibV\nmDFjDv369aBVqx+4du0K/foNQFtbB2/v9Tx58hgLC0s8PZejoaHB2rW/c+HCOcRiMfXqNeCXX8aQ\nkJDA8uWeChXU0aPHY2Vlw8aNXsTERBMVFUlMTDQ9e/bByak3s2Z5cP78WYyMjKlbt8FnESTJycmh\nT59dnD3rCqhQs+Yutm+3xsioQon3JfBl+R7/NgvIMTDQ/uA6H/0Y0tHRkYkTJ6KkpMSyZcvw8vJi\nwoQJPH78mEOHDnHw4EFiYmJwdXXlyJEjiideAgICAv/PtGnThGPHzr13+XPnzmBgYFAgnFAkkivX\nzZjhwf79e2jY0JE8D16nTl2Iiopk0KD+yGQy9PT0Wbhw6We6mm8LmUzGwIEHOHVqICBGS6sRy5dH\n0bXrl/difixr1sQQHd0LkCfMXrfuL9q0Kbrs7t1nOXXqZ0AeHnnqVH/8/PYyeHC7Eh3Tm/n/TE3N\niYgIZ8WKJSQmJqCurs7kydPQ1y+Li0sfdu06AEBGRgb9+jnh77+f6OioQuWNjEzIzs4mKOgZ/v49\nUFMTExv7En19fWbO9MDDYyba2tqMHOlGzZqmBAbeQCKRMGTI
[... remainder of the base64-encoded PNG image data omitted ...]\n",
- "text/plain": [
- ""
- ]
- }
- }
- ],
- "execution_count": 0
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "QB5EFrBnpNnc",
- "colab_type": "text"
- },
- "source": [
- "---\n",
- "\n",
- "Problem\n",
- "-------\n",
- "\n",
- "An alternative to skip-gram is another Word2Vec model called [CBOW](http://arxiv.org/abs/1301.3781) (Continuous Bag of Words). In the CBOW model, instead of predicting a context word from a word vector, you predict a word from the sum of all the word vectors in its context. Implement and evaluate a CBOW model trained on the text8 dataset.\n",
- "\n",
- "---"
- ]
- }
- ]
-}
diff --git a/tensorflow/examples/udacity/6_lstm.ipynb b/tensorflow/examples/udacity/6_lstm.ipynb
deleted file mode 100644
index b17e70be95d..00000000000
--- a/tensorflow/examples/udacity/6_lstm.ipynb
+++ /dev/null
@@ -1,1069 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "name": "6_lstm.ipynb",
- "provenance": []
- }
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8tQJd2YSCfWR",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "D7tqLMoKF6uq",
- "colab_type": "text"
- },
- "source": [
- "Deep Learning\n",
- "=============\n",
- "\n",
- "Assignment 6\n",
- "------------\n",
- "\n",
- "After training a skip-gram model in `5_word2vec.ipynb`, the goal of this notebook is to train a LSTM character model over [Text8](http://mattmahoney.net/dc/textdata) data."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "MvEblsgEXxrd",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "both"
- },
- "source": [
- "# These are all the modules we'll be using later. Make sure you can import them\n",
- "# before proceeding further.\n",
- "from __future__ import print_function\n",
- "import os\n",
- "import numpy as np\n",
- "import random\n",
- "import string\n",
- "import tensorflow as tf\n",
- "import zipfile\n",
- "from six.moves import range\n",
- "from six.moves.urllib.request import urlretrieve"
- ],
- "outputs": [],
- "execution_count": 0
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "RJ-o3UBUFtCw",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "output_extras": [
- {
- "item_id": 1
- }
- ]
- },
- "cellView": "both",
- "executionInfo": {
- "elapsed": 5993,
- "status": "ok",
- "timestamp": 1445965582896,
- "user": {
- "color": "#1FA15D",
- "displayName": "Vincent Vanhoucke",
- "isAnonymous": false,
- "isMe": true,
- "permissionId": "05076109866853157986",
- "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg",
- "sessionId": "6f6f07b359200c46",
- "userId": "102167687554210253930"
- },
- "user_tz": 420
- },
- "outputId": "d530534e-0791-4a94-ca6d-1c8f1b908a9e"
- },
- "source": [
- "url = 'http://mattmahoney.net/dc/'\n",
- "\n",
- "def maybe_download(filename, expected_bytes):\n",
- "  \"\"\"Download a file if not present, and make sure it's the right size.\"\"\"\n",
- "  if not os.path.exists(filename):\n",
- "    filename, _ = urlretrieve(url + filename, filename)\n",
- "  statinfo = os.stat(filename)\n",
- "  if statinfo.st_size == expected_bytes:\n",
- "    print('Found and verified %s' % filename)\n",
- "  else:\n",
- "    print(statinfo.st_size)\n",
- "    raise Exception(\n",
- "      'Failed to verify ' + filename + '. Can you get to it with a browser?')\n",
- "  return filename\n",
- "\n",
- "filename = maybe_download('text8.zip', 31344016)"
- ],
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "Found and verified text8.zip\n"
- ],
- "name": "stdout"
- }
- ],
- "execution_count": 0
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "Mvf09fjugFU_",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "output_extras": [
- {
- "item_id": 1
- }
- ]
- },
- "cellView": "both",
- "executionInfo": {
- "elapsed": 5982,
- "status": "ok",
- "timestamp": 1445965582916,
- "user": {
- "color": "#1FA15D",
- "displayName": "Vincent Vanhoucke",
- "isAnonymous": false,
- "isMe": true,
- "permissionId": "05076109866853157986",
- "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg",
- "sessionId": "6f6f07b359200c46",
- "userId": "102167687554210253930"
- },
- "user_tz": 420
- },
- "outputId": "8f75db58-3862-404b-a0c3-799380597390"
- },
- "source": [
- "def read_data(filename):\n",
- "  with zipfile.ZipFile(filename) as f:\n",
- "    name = f.namelist()[0]\n",
- "    data = tf.compat.as_str(f.read(name))\n",
- "  return data\n",
- "  \n",
- "text = read_data(filename)\n",
- "print('Data size %d' % len(text))"
- ],
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "Data size 100000000\n"
- ],
- "name": "stdout"
- }
- ],
- "execution_count": 0
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ga2CYACE-ghb",
- "colab_type": "text"
- },
- "source": [
- "Create a small validation set."
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "w-oBpfFG-j43", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 6184, - "status": "ok", - "timestamp": 1445965583138, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "6f6f07b359200c46", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "bdb96002-d021-4379-f6de-a977924f0d02" - }, - "source": [ - "valid_size = 1000\n", - "valid_text = text[:valid_size]\n", - "train_text = text[valid_size:]\n", - "train_size = len(train_text)\n", - "print(train_size, train_text[:64])\n", - "print(valid_size, valid_text[:64])" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "99999000 ons anarchists advocate social relations based upon voluntary as\n", - "1000 anarchism originated as a term of abuse first used against earl\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Zdw6i4F8glpp", - "colab_type": "text" - }, - "source": [ - "Utility functions to map characters to vocabulary IDs and back." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "gAL1EECXeZsD", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 6276, - "status": "ok", - "timestamp": 1445965583249, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "6f6f07b359200c46", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "88fc9032-feb9-45ff-a9a0-a26759cc1f2e" - }, - "source": [ - "vocabulary_size = len(string.ascii_lowercase) + 1 # [a-z] + ' '\n", - "first_letter = ord(string.ascii_lowercase[0])\n", - "\n", - "def char2id(char):\n", - " if char in string.ascii_lowercase:\n", - " return ord(char) - first_letter + 1\n", - " elif char == ' ':\n", - " return 0\n", - " else:\n", - " print('Unexpected character: %s' % char)\n", - " return 0\n", - " \n", - "def id2char(dictid):\n", - " if dictid > 0:\n", - " return chr(dictid + first_letter - 1)\n", - " else:\n", - " return ' '\n", - "\n", - "print(char2id('a'), char2id('z'), char2id(' '), char2id('\u00ef'))\n", - "print(id2char(1), id2char(26), id2char(0))" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "1 26 0 Unexpected character: \u00ef\n", - "0\n", - "a z \n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lFwoyygOmWsL", - "colab_type": "text" - }, - "source": [ - "Function to generate a training batch for the LSTM model." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "d9wMtjy5hCj9", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 6473, - "status": "ok", - "timestamp": 1445965583467, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": "//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "6f6f07b359200c46", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "3dd79c80-454a-4be0-8b71-4a4a357b3367" - }, - "source": [ - "batch_size=64\n", - "num_unrollings=10\n", - "\n", - "class BatchGenerator(object):\n", - " def __init__(self, text, batch_size, num_unrollings):\n", - " self._text = text\n", - " self._text_size = len(text)\n", - " self._batch_size = batch_size\n", - " self._num_unrollings = num_unrollings\n", - " segment = self._text_size // batch_size\n", - " self._cursor = [ offset * segment for offset in range(batch_size)]\n", - " self._last_batch = self._next_batch()\n", - " \n", - " def _next_batch(self):\n", - " \"\"\"Generate a single batch from the current cursor position in the data.\"\"\"\n", - " batch = np.zeros(shape=(self._batch_size, vocabulary_size), dtype=np.float)\n", - " for b in range(self._batch_size):\n", - " batch[b, char2id(self._text[self._cursor[b]])] = 1.0\n", - " self._cursor[b] = (self._cursor[b] + 1) % self._text_size\n", - " return batch\n", - " \n", - " def next(self):\n", - " \"\"\"Generate the next array of batches from the data. The array consists of\n", - " the last batch of the previous array, followed by num_unrollings new ones.\n", - " \"\"\"\n", - " batches = [self._last_batch]\n", - " for step in range(self._num_unrollings):\n", - " batches.append(self._next_batch())\n", - " self._last_batch = batches[-1]\n", - " return batches\n", - "\n", - "def characters(probabilities):\n", - " \"\"\"Turn a 1-hot encoding or a probability distribution over the possible\n", - " characters back into its (most likely) character representation.\"\"\"\n", - " return [id2char(c) for c in np.argmax(probabilities, 1)]\n", - "\n", - "def batches2string(batches):\n", - " \"\"\"Convert a sequence of batches back into their (most likely) string\n", - " representation.\"\"\"\n", - " s = [''] * batches[0].shape[0]\n", - " for b in batches:\n", - " s = [''.join(x) for x in zip(s, characters(b))]\n", - " return s\n", - "\n", - "train_batches = BatchGenerator(train_text, batch_size, num_unrollings)\n", - "valid_batches = BatchGenerator(valid_text, 1, 1)\n", - "\n", - "print(batches2string(train_batches.next()))\n", - "print(batches2string(train_batches.next()))\n", - "print(batches2string(valid_batches.next()))\n", - "print(batches2string(valid_batches.next()))" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "['ons anarchi', 'when milita', 'lleria arch', ' abbeys and', 'married urr', 'hel and ric', 'y and litur', 'ay opened f', 'tion from t', 'migration t', 'new york ot', 'he boeing s', 'e listed wi', 'eber has pr', 'o be made t', 'yer who rec', 'ore signifi', 'a fierce cr', ' two six ei', 'aristotle s', 'ity can be ', ' and intrac', 'tion of the', 'dy to pass ', 'f certain d', 'at it will ', 'e convince ', 'ent told hi', 'ampaign and', 'rver side s', 'ious texts ', 'o capitaliz', 'a duplicate', 'gh ann es d', 
'ine january', 'ross zero t', 'cal theorie', 'ast instanc', ' dimensiona', 'most holy m', 't s support', 'u is still ', 'e oscillati', 'o eight sub', 'of italy la', 's the tower', 'klahoma pre', 'erprise lin', 'ws becomes ', 'et in a naz', 'the fabian ', 'etchy to re', ' sharman ne', 'ised empero', 'ting in pol', 'd neo latin', 'th risky ri', 'encyclopedi', 'fense the a', 'duating fro', 'treet grid ', 'ations more', 'appeal of d', 'si have mad']\n", - "['ists advoca', 'ary governm', 'hes nationa', 'd monasteri', 'raca prince', 'chard baer ', 'rgical lang', 'for passeng', 'the nationa', 'took place ', 'ther well k', 'seven six s', 'ith a gloss', 'robably bee', 'to recogniz', 'ceived the ', 'icant than ', 'ritic of th', 'ight in sig', 's uncaused ', ' lost as in', 'cellular ic', 'e size of t', ' him a stic', 'drugs confu', ' take to co', ' the priest', 'im to name ', 'd barred at', 'standard fo', ' such as es', 'ze on the g', 'e of the or', 'd hiver one', 'y eight mar', 'the lead ch', 'es classica', 'ce the non ', 'al analysis', 'mormons bel', 't or at lea', ' disagreed ', 'ing system ', 'btypes base', 'anguages th', 'r commissio', 'ess one nin', 'nux suse li', ' the first ', 'zi concentr', ' society ne', 'elatively s', 'etworks sha', 'or hirohito', 'litical ini', 'n most of t', 'iskerdoo ri', 'ic overview', 'air compone', 'om acnm acc', ' centerline', 'e than any ', 'devotional ', 'de such dev']\n", - "[' a']\n", - "['an']\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "KyVd8FxT5QBc", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "def logprob(predictions, labels):\n", - " \"\"\"Log-probability of the true labels in a predicted batch.\"\"\"\n", - " predictions[predictions < 1e-10] = 1e-10\n", - " return np.sum(np.multiply(labels, -np.log(predictions))) / labels.shape[0]\n", - "\n", - "def sample_distribution(distribution):\n", - " \"\"\"Sample one element from a distribution assumed to be an array of normalized\n", - " probabilities.\n", - " \"\"\"\n", - " r = random.uniform(0, 1)\n", - " s = 0\n", - " for i in range(len(distribution)):\n", - " s += distribution[i]\n", - " if s >= r:\n", - " return i\n", - " return len(distribution) - 1\n", - "\n", - "def sample(prediction):\n", - " \"\"\"Turn a (column) prediction into 1-hot encoded samples.\"\"\"\n", - " p = np.zeros(shape=[1, vocabulary_size], dtype=np.float)\n", - " p[0, sample_distribution(prediction[0])] = 1.0\n", - " return p\n", - "\n", - "def random_distribution():\n", - " \"\"\"Generate a random column of probabilities.\"\"\"\n", - " b = np.random.uniform(0.0, 1.0, size=[1, vocabulary_size])\n", - " return b/np.sum(b, 1)[:,None]" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K8f67YXaDr4C", - "colab_type": "text" - }, - "source": [ - "Simple LSTM Model." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Q5rxZK6RDuGe", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "both" - }, - "source": [ - "num_nodes = 64\n", - "\n", - "graph = tf.Graph()\n", - "with graph.as_default():\n", - " \n", - " # Parameters:\n", - " # Input gate: input, previous output, and bias.\n", - " ix = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n", - " im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n", - " ib = tf.Variable(tf.zeros([1, num_nodes]))\n", - " # Forget gate: input, previous output, and bias.\n", - " fx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n", - " fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n", - " fb = tf.Variable(tf.zeros([1, num_nodes]))\n", - " # Memory cell: input, state and bias. \n", - " cx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n", - " cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n", - " cb = tf.Variable(tf.zeros([1, num_nodes]))\n", - " # Output gate: input, previous output, and bias.\n", - " ox = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))\n", - " om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))\n", - " ob = tf.Variable(tf.zeros([1, num_nodes]))\n", - " # Variables saving state across unrollings.\n", - " saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n", - " saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)\n", - " # Classifier weights and biases.\n", - " w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))\n", - " b = tf.Variable(tf.zeros([vocabulary_size]))\n", - " \n", - " # Definition of the cell computation.\n", - " def lstm_cell(i, o, state):\n", - " \"\"\"Create a LSTM cell. 
See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf\n", - " Note that in this formulation, we omit the various connections between the\n", - " previous state and the gates.\"\"\"\n", - " input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)\n", - " forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)\n", - " update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb\n", - " state = forget_gate * state + input_gate * tf.tanh(update)\n", - " output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)\n", - " return output_gate * tf.tanh(state), state\n", - "\n", - " # Input data.\n", - " train_data = list()\n", - " for _ in range(num_unrollings + 1):\n", - " train_data.append(\n", - " tf.placeholder(tf.float32, shape=[batch_size,vocabulary_size]))\n", - " train_inputs = train_data[:num_unrollings]\n", - " train_labels = train_data[1:] # labels are inputs shifted by one time step.\n", - "\n", - " # Unrolled LSTM loop.\n", - " outputs = list()\n", - " output = saved_output\n", - " state = saved_state\n", - " for i in train_inputs:\n", - " output, state = lstm_cell(i, output, state)\n", - " outputs.append(output)\n", - "\n", - " # State saving across unrollings.\n", - " with tf.control_dependencies([saved_output.assign(output),\n", - " saved_state.assign(state)]):\n", - " # Classifier.\n", - " logits = tf.nn.xw_plus_b(tf.concat(outputs, 0), w, b)\n", - " loss = tf.reduce_mean(\n", - " tf.nn.softmax_cross_entropy_with_logits(\n", - " labels=tf.concat(train_labels, 0), logits=logits))\n", - "\n", - " # Optimizer.\n", - " global_step = tf.Variable(0)\n", - " learning_rate = tf.train.exponential_decay(\n", - " 10.0, global_step, 5000, 0.1, staircase=True)\n", - " optimizer = tf.train.GradientDescentOptimizer(learning_rate)\n", - " gradients, v = zip(*optimizer.compute_gradients(loss))\n", - " gradients, _ = tf.clip_by_global_norm(gradients, 1.25)\n", - " optimizer = optimizer.apply_gradients(\n", - " zip(gradients, v), global_step=global_step)\n", - "\n", - " # Predictions.\n", - " train_prediction = tf.nn.softmax(logits)\n", - " \n", - " # Sampling and validation eval: batch 1, no unrolling.\n", - " sample_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size])\n", - " saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))\n", - " saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))\n", - " reset_sample_state = tf.group(\n", - " saved_sample_output.assign(tf.zeros([1, num_nodes])),\n", - " saved_sample_state.assign(tf.zeros([1, num_nodes])))\n", - " sample_output, sample_state = lstm_cell(\n", - " sample_input, saved_sample_output, saved_sample_state)\n", - " with tf.control_dependencies([saved_sample_output.assign(sample_output),\n", - " saved_sample_state.assign(sample_state)]):\n", - " sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))" - ], - "outputs": [], - "execution_count": 0 - }, - { - "cell_type": "code", - "metadata": { - "id": "RD9zQCZTEaEm", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 41 - }, - { - "item_id": 80 - }, - { - "item_id": 126 - }, - { - "item_id": 144 - } - ] - }, - "cellView": "both", - "executionInfo": { - "elapsed": 199909, - "status": "ok", - "timestamp": 1445965877333, - "user": { - "color": "#1FA15D", - "displayName": "Vincent Vanhoucke", - "isAnonymous": false, - "isMe": true, - "permissionId": "05076109866853157986", - "photoUrl": 
"//lh6.googleusercontent.com/-cCJa7dTDcgQ/AAAAAAAAAAI/AAAAAAAACgw/r2EZ_8oYer4/s50-c-k-no/photo.jpg", - "sessionId": "6f6f07b359200c46", - "userId": "102167687554210253930" - }, - "user_tz": 420 - }, - "outputId": "5e868466-2532-4545-ce35-b403cf5d9de6" - }, - "source": [ - "num_steps = 7001\n", - "summary_frequency = 100\n", - "\n", - "with tf.Session(graph=graph) as session:\n", - " tf.global_variables_initializer().run()\n", - " print('Initialized')\n", - " mean_loss = 0\n", - " for step in range(num_steps):\n", - " batches = train_batches.next()\n", - " feed_dict = dict()\n", - " for i in range(num_unrollings + 1):\n", - " feed_dict[train_data[i]] = batches[i]\n", - " _, l, predictions, lr = session.run(\n", - " [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)\n", - " mean_loss += l\n", - " if step % summary_frequency == 0:\n", - " if step > 0:\n", - " mean_loss = mean_loss / summary_frequency\n", - " # The mean loss is an estimate of the loss over the last few batches.\n", - " print(\n", - " 'Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))\n", - " mean_loss = 0\n", - " labels = np.concatenate(list(batches)[1:])\n", - " print('Minibatch perplexity: %.2f' % float(\n", - " np.exp(logprob(predictions, labels))))\n", - " if step % (summary_frequency * 10) == 0:\n", - " # Generate some samples.\n", - " print('=' * 80)\n", - " for _ in range(5):\n", - " feed = sample(random_distribution())\n", - " sentence = characters(feed)[0]\n", - " reset_sample_state.run()\n", - " for _ in range(79):\n", - " prediction = sample_prediction.eval({sample_input: feed})\n", - " feed = sample(prediction)\n", - " sentence += characters(feed)[0]\n", - " print(sentence)\n", - " print('=' * 80)\n", - " # Measure validation set perplexity.\n", - " reset_sample_state.run()\n", - " valid_logprob = 0\n", - " for _ in range(valid_size):\n", - " b = valid_batches.next()\n", - " predictions = sample_prediction.eval({sample_input: b[0]})\n", - " valid_logprob = valid_logprob + logprob(predictions, b[1])\n", - " print('Validation set perplexity: %.2f' % float(np.exp(\n", - " valid_logprob / valid_size)))" - ], - "outputs": [ - { - "output_type": "stream", - "text": [ - "Initialized\n", - "Average loss at step 0 : 3.29904174805 learning rate: 10.0\n", - "Minibatch perplexity: 27.09\n", - "================================================================================\n", - "srk dwmrnuldtbbgg tapootidtu xsciu sgokeguw hi ieicjq lq piaxhazvc s fht wjcvdlh\n", - "lhrvallvbeqqquc dxd y siqvnle bzlyw nr rwhkalezo siie o deb e lpdg storq u nx o\n", - "meieu nantiouie gdys qiuotblci loc hbiznauiccb cqzed acw l tsm adqxplku gn oaxet\n", - "unvaouc oxchywdsjntdh zpklaejvxitsokeerloemee htphisb th eaeqseibumh aeeyj j orw\n", - "ogmnictpycb whtup otnilnesxaedtekiosqet liwqarysmt arj flioiibtqekycbrrgoysj\n", - "================================================================================\n", - "Validation set perplexity: 19.99\n", - "Average loss at step 100 : 2.59553678274 learning rate: 10.0\n", - "Minibatch perplexity: 9.57\n", - "Validation set perplexity: 10.60\n", - "Average loss at step 200 : 2.24747137785 learning rate: 10.0\n", - "Minibatch perplexity: 7.68\n", - "Validation set perplexity: 8.84\n", - "Average loss at step 300 : 2.09438110709 learning rate: 10.0\n", - "Minibatch perplexity: 7.41\n", - "Validation set perplexity: 8.13\n", - "Average loss at step 400 : 1.99440989017 learning rate: 10.0\n", - "Minibatch perplexity: 6.46\n", - "Validation set perplexity: 7.58\n", 
- "Average loss at step 500 : 1.9320810616 learning rate: 10.0\n", - "Minibatch perplexity: 6.30\n", - "Validation set perplexity: 6.88\n", - "Average loss at step 600 : 1.90935629249 learning rate: 10.0\n", - "Minibatch perplexity: 7.21\n", - "Validation set perplexity: 6.91\n", - "Average loss at step 700 : 1.85583009005 learning rate: 10.0\n", - "Minibatch perplexity: 6.13\n", - "Validation set perplexity: 6.60\n", - "Average loss at step 800 : 1.82152368546 learning rate: 10.0\n", - "Minibatch perplexity: 6.01\n", - "Validation set perplexity: 6.37\n", - "Average loss at step 900 : 1.83169809818 learning rate: 10.0\n", - "Minibatch perplexity: 7.20\n", - "Validation set perplexity: 6.23\n", - "Average loss at step 1000 : 1.82217029214 learning rate: 10.0\n", - "Minibatch perplexity: 6.73\n", - "================================================================================\n", - "le action b of the tert sy ofter selvorang previgned stischdy yocal chary the co\n", - "le relganis networks partucy cetinning wilnchan sics rumeding a fulch laks oftes\n", - "hian andoris ret the ecause bistory l pidect one eight five lack du that the ses\n", - "aiv dromery buskocy becomer worils resism disele retery exterrationn of hide in \n", - "mer miter y sught esfectur of the upission vain is werms is vul ugher compted by\n", - "================================================================================\n", - "Validation set perplexity: 6.07\n", - "Average loss at step 1100 : 1.77301145077 learning rate: 10.0\n", - "Minibatch perplexity: 6.03\n", - "Validation set perplexity: 5.89\n", - "Average loss at step 1200 : 1.75306463003 learning rate: 10.0\n", - "Minibatch perplexity: 6.50\n", - "Validation set perplexity: 5.61\n", - "Average loss at step 1300 : 1.72937195778 learning rate: 10.0\n", - "Minibatch perplexity: 5.00\n", - "Validation set perplexity: 5.60\n", - "Average loss at step 1400 : 1.74773373723 learning rate: 10.0\n", - "Minibatch perplexity: 6.48\n", - "Validation set perplexity: 5.66\n", - "Average loss at step 1500 : 1.7368799901 learning rate: 10.0\n", - "Minibatch perplexity: 5.22\n", - "Validation set perplexity: 5.44\n", - "Average loss at step 1600 : 1.74528762937 learning rate: 10.0\n", - "Minibatch perplexity: 5.85\n", - "Validation set perplexity: 5.33\n", - "Average loss at step 1700 : 1.70881183743 learning rate: 10.0\n", - "Minibatch perplexity: 5.33\n", - "Validation set perplexity: 5.56\n", - "Average loss at step 1800 : 1.67776108027 learning rate: 10.0\n", - "Minibatch perplexity: 5.33\n", - "Validation set perplexity: 5.29\n", - "Average loss at step 1900 : 1.64935536742 learning rate: 10.0\n", - "Minibatch perplexity: 5.29\n", - "Validation set perplexity: 5.15\n", - "Average loss at step" - ], - "name": "stdout" - }, - { - "output_type": "stream", - "text": [ - " 2000 : 1.69528644681 learning rate: 10.0\n", - "Minibatch perplexity: 5.13\n", - "================================================================================\n", - "vers soqually have one five landwing to docial page kagan lower with ther batern\n", - "ctor son alfortmandd tethre k skin the known purated to prooust caraying the fit\n", - "je in beverb is the sournction bainedy wesce tu sture artualle lines digra forme\n", - "m rousively haldio ourso ond anvary was for the seven solies hild buil s to te\n", - "zall for is it is one nine eight eight one neval to the kime typer oene where he\n", - "================================================================================\n", - "Validation set 
perplexity: 5.25\n", - "Average loss at step 2100 : 1.68808053017 learning rate: 10.0\n", - "Minibatch perplexity: 5.17\n", - "Validation set perplexity: 5.01\n", - "Average loss at step 2200 : 1.68322490931 learning rate: 10.0\n", - "Minibatch perplexity: 5.09\n", - "Validation set perplexity: 5.15\n", - "Average loss at step 2300 : 1.64465074301 learning rate: 10.0\n", - "Minibatch perplexity: 5.51\n", - "Validation set perplexity: 5.00\n", - "Average loss at step 2400 : 1.66408578038 learning rate: 10.0\n", - "Minibatch perplexity: 5.86\n", - "Validation set perplexity: 4.80\n", - "Average loss at step 2500 : 1.68515402555 learning rate: 10.0\n", - "Minibatch perplexity: 5.75\n", - "Validation set perplexity: 4.82\n", - "Average loss at step 2600 : 1.65405208349 learning rate: 10.0\n", - "Minibatch perplexity: 5.38\n", - "Validation set perplexity: 4.85\n", - "Average loss at step 2700 : 1.65706222177 learning rate: 10.0\n", - "Minibatch perplexity: 5.46\n", - "Validation set perplexity: 4.78\n", - "Average loss at step 2800 : 1.65204829812 learning rate: 10.0\n", - "Minibatch perplexity: 5.06\n", - "Validation set perplexity: 4.64\n", - "Average loss at step 2900 : 1.65107253551 learning rate: 10.0\n", - "Minibatch perplexity: 5.00\n", - "Validation set perplexity: 4.61\n", - "Average loss at step 3000 : 1.6495274055 learning rate: 10.0\n", - "Minibatch perplexity: 4.53\n", - "================================================================================\n", - "ject covered in belo one six six to finsh that all di rozial sime it a the lapse\n", - "ble which the pullic bocades record r to sile dric two one four nine seven six f\n", - " originally ame the playa ishaps the stotchational in a p dstambly name which as\n", - "ore volum to bay riwer foreal in nuily operety can and auscham frooripm however \n", - "kan traogey was lacous revision the mott coupofiteditey the trando insended frop\n", - "================================================================================\n", - "Validation set perplexity: 4.76\n", - "Average loss at step 3100 : 1.63705502152 learning rate: 10.0\n", - "Minibatch perplexity: 5.50\n", - "Validation set perplexity: 4.76\n", - "Average loss at step 3200 : 1.64740695596 learning rate: 10.0\n", - "Minibatch perplexity: 4.84\n", - "Validation set perplexity: 4.67\n", - "Average loss at step 3300 : 1.64711504817 learning rate: 10.0\n", - "Minibatch perplexity: 5.39\n", - "Validation set perplexity: 4.57\n", - "Average loss at step 3400 : 1.67113256454 learning rate: 10.0\n", - "Minibatch perplexity: 5.56\n", - "Validation set perplexity: 4.71\n", - "Average loss at step 3500 : 1.65637169957 learning rate: 10.0\n", - "Minibatch perplexity: 5.03\n", - "Validation set perplexity: 4.80\n", - "Average loss at step 3600 : 1.66601825476 learning rate: 10.0\n", - "Minibatch perplexity: 4.63\n", - "Validation set perplexity: 4.52\n", - "Average loss at step 3700 : 1.65021387935 learning rate: 10.0\n", - "Minibatch perplexity: 5.50\n", - "Validation set perplexity: 4.56\n", - "Average loss at step 3800 : 1.64481814981 learning rate: 10.0\n", - "Minibatch perplexity: 4.60\n", - "Validation set perplexity: 4.54\n", - "Average loss at step 3900 : 1.642069453 learning rate: 10.0\n", - "Minibatch perplexity: 4.91\n", - "Validation set perplexity: 4.54\n", - "Average loss at step 4000 : 1.65179730773 learning rate: 10.0\n", - "Minibatch perplexity: 4.77\n", - "================================================================================\n", - "k s rasbonish roctes the 
nignese at heacle was sito of beho anarchys and with ro\n", - "jusar two sue wletaus of chistical in causations d ow trancic bruthing ha laters\n", - "de and speacy pulted yoftret worksy zeatlating to eight d had to ie bue seven si" - ], - "name": "stdout" - }, - { - "output_type": "stream", - "text": [ - "\n", - "s fiction of the feelly constive suq flanch earlied curauking bjoventation agent\n", - "quen s playing it calana our seopity also atbellisionaly comexing the revideve i\n", - "================================================================================\n", - "Validation set perplexity: 4.58\n", - "Average loss at step 4100 : 1.63794238806 learning rate: 10.0\n", - "Minibatch perplexity: 5.47\n", - "Validation set perplexity: 4.79\n", - "Average loss at step 4200 : 1.63822438836 learning rate: 10.0\n", - "Minibatch perplexity: 5.30\n", - "Validation set perplexity: 4.54\n", - "Average loss at step 4300 : 1.61844664574 learning rate: 10.0\n", - "Minibatch perplexity: 4.69\n", - "Validation set perplexity: 4.54\n", - "Average loss at step 4400 : 1.61255454302 learning rate: 10.0\n", - "Minibatch perplexity: 4.67\n", - "Validation set perplexity: 4.54\n", - "Average loss at step 4500 : 1.61543365479 learning rate: 10.0\n", - "Minibatch perplexity: 4.83\n", - "Validation set perplexity: 4.69\n", - "Average loss at step 4600 : 1.61607327104 learning rate: 10.0\n", - "Minibatch perplexity: 5.18\n", - "Validation set perplexity: 4.64\n", - "Average loss at step 4700 : 1.62757282495 learning rate: 10.0\n", - "Minibatch perplexity: 4.24\n", - "Validation set perplexity: 4.66\n", - "Average loss at step 4800 : 1.63222063541 learning rate: 10.0\n", - "Minibatch perplexity: 5.30\n", - "Validation set perplexity: 4.53\n", - "Average loss at step 4900 : 1.63678096652 learning rate: 10.0\n", - "Minibatch perplexity: 5.43\n", - "Validation set perplexity: 4.64\n", - "Average loss at step 5000 : 1.610340662 learning rate: 1.0\n", - "Minibatch perplexity: 5.10\n", - "================================================================================\n", - "in b one onarbs revieds the kimiluge that fondhtic fnoto cre one nine zero zero \n", - " of is it of marking panzia t had wap ironicaghni relly deah the omber b h menba\n", - "ong messified it his the likdings ara subpore the a fames distaled self this int\n", - "y advante authors the end languarle meit common tacing bevolitione and eight one\n", - "zes that materly difild inllaring the fusts not panition assertian causecist bas\n", - "================================================================================\n", - "Validation set perplexity: 4.69\n", - "Average loss at step 5100 : 1.60593637228 learning rate: 1.0\n", - "Minibatch perplexity: 4.69\n", - "Validation set perplexity: 4.47\n", - "Average loss at step 5200 : 1.58993269444 learning rate: 1.0\n", - "Minibatch perplexity: 4.65\n", - "Validation set perplexity: 4.39\n", - "Average loss at step 5300 : 1.57930587292 learning rate: 1.0\n", - "Minibatch perplexity: 5.11\n", - "Validation set perplexity: 4.39\n", - "Average loss at step 5400 : 1.58022856832 learning rate: 1.0\n", - "Minibatch perplexity: 5.19\n", - "Validation set perplexity: 4.37\n", - "Average loss at step 5500 : 1.56654450059 learning rate: 1.0\n", - "Minibatch perplexity: 4.69\n", - "Validation set perplexity: 4.33\n", - "Average loss at step 5600 : 1.58013380885 learning rate: 1.0\n", - "Minibatch perplexity: 5.13\n", - "Validation set perplexity: 4.35\n", - "Average loss at step 5700 : 1.56974959254 learning rate: 
1.0\n", - "Minibatch perplexity: 5.00\n", - "Validation set perplexity: 4.34\n", - "Average loss at step 5800 : 1.5839582932 learning rate: 1.0\n", - "Minibatch perplexity: 4.88\n", - "Validation set perplexity: 4.31\n", - "Average loss at step 5900 : 1.57129439116 learning rate: 1.0\n", - "Minibatch perplexity: 4.66\n", - "Validation set perplexity: 4.32\n", - "Average loss at step 6000 : 1.55144061089 learning rate: 1.0\n", - "Minibatch perplexity: 4.55\n", - "================================================================================\n", - "utic clositical poopy stribe addi nixe one nine one zero zero eight zero b ha ex\n", - "zerns b one internequiption of the secordy way anti proble akoping have fictiona\n", - "phare united from has poporarly cities book ins sweden emperor a sass in origina\n", - "quulk destrebinist and zeilazar and on low and by in science over country weilti\n", - "x are holivia work missincis ons in the gages to starsle histon one icelanctrotu\n", - "================================================================================\n", - "Validation set perplexity: 4.30\n", - "Average loss at step 6100 : 1.56450940847 learning rate: 1.0\n", - "Minibatch perplexity: 4.77\n", - "Validation set perplexity: 4.27" - ], - "name": "stdout" - }, - { - "output_type": "stream", - "text": [ - "\n", - "Average loss at step 6200 : 1.53433164835 learning rate: 1.0\n", - "Minibatch perplexity: 4.77\n", - "Validation set perplexity: 4.27\n", - "Average loss at step 6300 : 1.54773445129 learning rate: 1.0\n", - "Minibatch perplexity: 4.76\n", - "Validation set perplexity: 4.25\n", - "Average loss at step 6400 : 1.54021131516 learning rate: 1.0\n", - "Minibatch perplexity: 4.56\n", - "Validation set perplexity: 4.24\n", - "Average loss at step 6500 : 1.56153374553 learning rate: 1.0\n", - "Minibatch perplexity: 5.43\n", - "Validation set perplexity: 4.27\n", - "Average loss at step 6600 : 1.59556478739 learning rate: 1.0\n", - "Minibatch perplexity: 4.92\n", - "Validation set perplexity: 4.28\n", - "Average loss at step 6700 : 1.58076951623 learning rate: 1.0\n", - "Minibatch perplexity: 4.77\n", - "Validation set perplexity: 4.30\n", - "Average loss at step 6800 : 1.6070714438 learning rate: 1.0\n", - "Minibatch perplexity: 4.98\n", - "Validation set perplexity: 4.28\n", - "Average loss at step 6900 : 1.58413293839 learning rate: 1.0\n", - "Minibatch perplexity: 4.61\n", - "Validation set perplexity: 4.29\n", - "Average loss at step 7000 : 1.57905534983 learning rate: 1.0\n", - "Minibatch perplexity: 5.08\n", - "================================================================================\n", - "jague are officiencinels ored by film voon higherise haik one nine on the iffirc\n", - "oshe provision that manned treatists on smalle bodariturmeristing the girto in s\n", - "kis would softwenn mustapultmine truativersakys bersyim by s of confound esc bub\n", - "ry of the using one four six blain ira mannom marencies g with fextificallise re\n", - " one son vit even an conderouss to person romer i a lebapter at obiding are iuse\n", - "================================================================================\n", - "Validation set perplexity: 4.25\n" - ], - "name": "stdout" - } - ], - "execution_count": 0 - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pl4vtmFfa5nn", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 1\n", - "---------\n", - "\n", - "You might have noticed that the definition of the LSTM cell involves 4 matrix multiplications with the 
input, and 4 matrix multiplications with the output. Simplify the expression by using a single matrix multiply for each, and variables that are 4 times larger.\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4eErTCTybtph", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 2\n", - "---------\n", - "\n", - "We want to train a LSTM over bigrams, that is pairs of consecutive characters like 'ab' instead of single characters like 'a'. Since the number of possible bigrams is large, feeding them directly to the LSTM using 1-hot encodings will lead to a very sparse representation that is very wasteful computationally.\n", - "\n", - "a- Introduce an embedding lookup on the inputs, and feed the embeddings to the LSTM cell instead of the inputs themselves.\n", - "\n", - "b- Write a bigram-based LSTM, modeled on the character LSTM above.\n", - "\n", - "c- Introduce Dropout. For best practices on how to use Dropout in LSTMs, refer to this [article](http://arxiv.org/abs/1409.2329).\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Y5tapX3kpcqZ", - "colab_type": "text" - }, - "source": [ - "---\n", - "Problem 3\n", - "---------\n", - "\n", - "(difficult!)\n", - "\n", - "Write a sequence-to-sequence LSTM which mirrors all the words in a sentence. For example, if your input is:\n", - "\n", - " the quick brown fox\n", - " \n", - "the model should attempt to output:\n", - "\n", - " eht kciuq nworb xof\n", - " \n", - "Refer to the lecture on how to put together a sequence-to-sequence model, as well as [this article](http://arxiv.org/abs/1409.3215) for best practices.\n", - "\n", - "---" - ] - } - ] -} diff --git a/tensorflow/examples/udacity/Dockerfile b/tensorflow/examples/udacity/Dockerfile deleted file mode 100644 index 00eb853e527..00000000000 --- a/tensorflow/examples/udacity/Dockerfile +++ /dev/null @@ -1,15 +0,0 @@ -FROM gcr.io/tensorflow/tensorflow:latest -LABEL maintainer="Vincent Vanhoucke " - -# Pillow needs libjpeg by default as of 3.0. -RUN apt-get update && apt-get install -y --no-install-recommends \ - libjpeg8-dev \ - && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -RUN pip install scikit-learn pyreadline Pillow imageio -RUN rm -rf /notebooks/* -ADD *.ipynb /notebooks/ -WORKDIR /notebooks -CMD ["/run_jupyter.sh", "--allow-root"] diff --git a/tensorflow/examples/udacity/README.md b/tensorflow/examples/udacity/README.md index b3bd73a08b2..c5883c48204 100644 --- a/tensorflow/examples/udacity/README.md +++ b/tensorflow/examples/udacity/README.md @@ -1,127 +1,5 @@ -Assignments for Udacity Deep Learning class with TensorFlow -=========================================================== -Course information can be found at https://www.udacity.com/course/deep-learning--ud730 +The contents of this folder have been moved to: -## Getting Started with Docker +[https://github.com/tensorflow/examples/tree/master/courses/udacity_deep_learning](https://github.com/tensorflow/examples/tree/master/courses/udacity_deep_learning) -If you are new to Docker, follow -[Docker document](https://docs.docker.com/machine/get-started/) to start a -docker instance. Kindly read the requirements of Windows and Mac carefully. 
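Problem 1 of the deleted `6_lstm.ipynb` above asks for an LSTM cell that uses a single matrix multiply for the input and a single one for the recurrent output, with weight variables four times as wide. A minimal sketch of that refactoring in the notebook's TF 1.x style is below; the names `x_all`, `m_all`, `b_all` and the `4 * num_nodes` shapes are illustrative assumptions, not the notebook's reference solution.

```python
import tensorflow as tf

num_nodes = 64
vocabulary_size = 27  # [a-z] + ' ', as in the notebook

# One input->gates matrix, one output->gates matrix, and one bias, each 4x wider.
# (Hypothetical names; the notebook leaves this refactoring as an exercise.)
x_all = tf.Variable(tf.truncated_normal([vocabulary_size, 4 * num_nodes], -0.1, 0.1))
m_all = tf.Variable(tf.truncated_normal([num_nodes, 4 * num_nodes], -0.1, 0.1))
b_all = tf.Variable(tf.zeros([1, 4 * num_nodes]))

def lstm_cell(i, o, state):
  """Single-matmul LSTM cell: compute all gate pre-activations at once,
  then split them into input gate, forget gate, cell update, and output gate."""
  all_gates = tf.matmul(i, x_all) + tf.matmul(o, m_all) + b_all
  input_gate, forget_gate, update, output_gate = tf.split(all_gates, 4, axis=1)
  state = tf.sigmoid(forget_gate) * state + tf.sigmoid(input_gate) * tf.tanh(update)
  return tf.sigmoid(output_gate) * tf.tanh(state), state
```

The gate math is unchanged; the eight per-step matrix multiplications of the original cell collapse into two, with `tf.split` recovering the four gate pre-activations.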
- -Running the Docker container from the Google Cloud repository -------------------------------------------------------------- - - docker run -p 8888:8888 --name tensorflow-udacity -it gcr.io/tensorflow/udacity-assignments:1.0.0 - -Note that if you ever exit the container, you can return to it using: - - docker start -ai tensorflow-udacity - -Accessing the Notebooks ------------------------ - -On linux, go to: http://127.0.0.1:8888 - -On mac, go to terminal and find the virtual machine's IP using: - - docker-machine ip default - -Then go to: http://(ip address received from the above command):8888 (likely -http://192.168.99.100:8888) - -On Windows, use powershell to find the virtual machine's IP using: - - docker-machine ip default - - -Then go to: http://(ip address received from the above command):8888 (likely -http://192.168.99.100:8888) - -FAQ ---- - -* **I'm getting a MemoryError when loading data in the first notebook.** - -If you're using a Mac, Docker works by running a VM locally (which -is controlled by `docker-machine`). It's quite likely that you'll -need to bump up the amount of RAM allocated to the VM beyond the -default (which is 1G). -[This Stack Overflow question](http://stackoverflow.com/questions/32834082/how-to-increase-docker-machine-memory-mac) -has two good suggestions; we recommend using 8G. - -In addition, you may need to pass `--memory=8g` as an extra argument to -`docker run`. - -* **I want to create a new virtual machine instead of the default one.** - -`docker-machine` is a tool to provision and manage docker hosts, it supports multiple platform (ex. aws, gce, azure, virtualbox, ...). To create a new virtual machine locally with built-in docker engine, you can use - - docker-machine create -d virtualbox --virtualbox-memory 8196 tensorflow - -`-d` means the driver for the cloud platform, supported drivers listed [here](https://docs.docker.com/machine/drivers/). Here we use virtualbox to create a new virtual machine locally. `tensorflow` means the name of the virtual machine, feel free to use whatever you like. You can use - - docker-machine ip tensorflow - -to get the ip of the new virtual machine. To switch from default virtual machine to a new one (here we use tensorflow), type - - eval $(docker-machine env tensorflow) - -Note that `docker-machine env tensorflow` outputs some environment variables such like `DOCKER_HOST`. Then your docker client is now connected to the docker host in virtual machine `tensorflow` - -* **I'm getting a TLS connection error.** - -If you get an error about the TLS connection of your docker, run the command below to confirm the problem. - - docker-machine ip tensorflow - -Then if it is the case use the instructions on [this page](https://docs.docker.com/toolbox/faqs/troubleshoot/) to solve the issue. - - -* **I'm getting the error - docker: Cannot connect to the Docker daemon. Is the docker daemon running on this host? - when I run 'docker run'.** - -This is a permissions issue, and a popular answer is provided for Linux and Max OSX [here](http://stackoverflow.com/questions/21871479/docker-cant-connect-to-docker-daemon) on StackOverflow. - -Notes for anyone needing to build their own containers (mostly instructors) -=========================================================================== - -Building a local Docker container ---------------------------------- - - cd tensorflow/examples/udacity - docker build --pull -t $USER/assignments . 
- -Running the local container ---------------------------- - -To run a disposable container: - - docker run -p 8888:8888 -it --rm $USER/assignments - -Note the above command will create an ephemeral container and all data stored in the container will be lost when the container stops. - -To avoid losing work between sessions in the container, it is recommended that you mount the `tensorflow/examples/udacity` directory into the container: - - docker run -p 8888:8888 -v :/notebooks -it --rm $USER/assignments - -This will allow you to save work and have access to generated files on the host filesystem. - -Pushing a Google Cloud release ------------------------------- - - V=1.0.0 - docker tag $USER/assignments gcr.io/tensorflow/udacity-assignments:$V - gcloud docker push gcr.io/tensorflow/udacity-assignments - docker tag $USER/assignments gcr.io/tensorflow/udacity-assignments:latest - gcloud docker push gcr.io/tensorflow/udacity-assignments - -History -------- - -* 0.1.0: Initial release. -* 0.2.0: Many fixes, including lower memory footprint and support for Python 3. -* 0.3.0: Use 0.7.1 release. -* 0.4.0: Move notMNIST data for Google Cloud. -* 0.5.0: Actually use 0.7.1 release. -* 0.6.0: Update to TF 0.10.0, add libjpeg (for Pillow). -* 1.0.0: Update to TF 1.0.0 release. diff --git a/tensorflow/go/BUILD b/tensorflow/go/BUILD index 62d6b4f57c2..e1909880697 100644 --- a/tensorflow/go/BUILD +++ b/tensorflow/go/BUILD @@ -9,17 +9,21 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load( + "//tensorflow:tensorflow.bzl", + "tf_shared_library_deps", +) + sh_test( name = "test", size = "small", srcs = ["test.sh"], data = [ ":all_files", # Go sources - "//tensorflow:libtensorflow.so", # C library "//tensorflow/c:headers", # C library header "//tensorflow/c/eager:headers", # Eager C library header "//tensorflow/cc/saved_model:saved_model_half_plus_two", # Testdata for LoadSavedModel - ], + ] + tf_shared_library_deps(), ) filegroup( diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 6ff41ca9169..31f087591b7 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -53,7 +53,7 @@ type Graph struct { c *C.TF_Graph } -// Graph execution options +// The GraphImportOptions struct holds parameters for the ImportWithOptions function. type GraphImportOptions struct { // Node prefix Prefix string @@ -170,7 +170,7 @@ func (g *Graph) Operation(name string) *Operation { // Operations returns a list of all operations in the graph func (g *Graph) Operations() []Operation { - var pos C.size_t = 0 + var pos C.size_t ops := []Operation{} for { cop := C.TF_GraphNextOperation(g.c, &pos) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 3ea479eeed0..a87ae742a31 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -96,174 +96,6 @@ func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output return op.Output(0), op.Output(1), op.Output(2) } -// FakeQuantWithMinMaxVarsPerChannelAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannel. -type FakeQuantWithMinMaxVarsPerChannelAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsPerChannelNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsPerChannelNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsPerChannelNarrowRange sets the optional narrow_range attribute to value. 
-// If not specified, defaults to false -func FakeQuantWithMinMaxVarsPerChannelNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`, -// -// `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]` -// to 'outputs' tensor of same shape as `inputs`. -// -// `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. -// -// This operation has a gradient and thus allows for training `min` and `max` -// values. -func FakeQuantWithMinMaxVarsPerChannel(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelAttr) (outputs tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVarsPerChannel", - Input: []tf.Input{ - inputs, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. -type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value. -// -// value: The bitwidth of the quantization; between 2 and 8, inclusive. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsGradientNarrowRange sets the optional narrow_range attribute to value. -// -// value: Whether to quantize into 2^num_bits - 1 distinct values. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxVars operation. -// -// Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. -// min, max: Quantization interval, scalar floats. -// -// -// -// Returns Backpropagated gradients w.r.t. inputs: -// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter: -// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter: -// `sum(gradients * (inputs > max))`. 
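The fake-quantization behaviour described in the doc comments above (clamp to `[min; max]`, quantize into `2^num_bits` levels, then de-quantize back to floats) can be illustrated with a small NumPy sketch. This only mirrors the arithmetic spelled out in the comments; it ignores the kernel's exact rounding mode and the min/max nudging, so treat it as an approximation rather than the op's implementation.

```python
import numpy as np

def fake_quant(x, min_val, max_val, num_bits=8, narrow_range=False):
    """Clamp x to [min_val, max_val], map it onto the integer quantization
    range, then map back to floats in [min_val, max_val]."""
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    clamped = np.clip(x, min_val, max_val)
    q = np.round((clamped - min_val) / scale) + quant_min
    return (q - quant_min) * scale + min_val

# Values outside [0, 1] are clamped; values inside snap to one of 256 levels.
print(fake_quant(np.array([-1.0, 0.1, 0.5, 3.0]), min_val=0.0, max_val=1.0))
```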
-func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVarsGradient", - Input: []tf.Input{ - gradients, inputs, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. -type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. -// If not specified, defaults to -6 -func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["min"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. -// If not specified, defaults to 6 -func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["max"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxArgs operation. -// -// Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. -// -// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: -// `gradients * (inputs >= min && inputs <= max)`. -func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxArgsGradient", - Input: []tf.Input{ - gradients, inputs, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs. type FakeQuantWithMinMaxArgsAttr func(optionalAttr) @@ -307,6 +139,15 @@ func FakeQuantWithMinMaxArgsNarrowRange(value bool) FakeQuantWithMinMaxArgsAttr // then de-quantized and output as floats in `[min; max]` interval. // `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. // +// Before quantization, `min` and `max` values are adjusted with the following +// logic. +// It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, +// the behavior can be unexpected: +// If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. +// If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. 
+// If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, +// `min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. +// // Quantization is called fake since the output is still in floating point. func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsAttr) (outputs tf.Output) { if scope.Err() != nil { @@ -327,12 +168,71 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua return op.Output(0) } -// Subtracts sparse `updates` from an existing tensor according to `indices`. +// Applies sparse addition to `input` using individual values or slices // -// This operation creates a new tensor by subtracting sparse `updates` from the -// passed in `tensor`. -// This operation is very similar to `tf.scatter_nd_sub`, except that the updates -// are subtracted from an existing tensor (as opposed to a variable). If the memory +// from `updates` according to indices `indices`. The updates are non-aliasing: +// `input` is only modified in-place if no other operations will use it. +// Otherwise, a copy of `input` is made. This operation has a gradient with +// respect to both `input` and `updates`. +// +// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `input`. +// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`. +// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or `(P-K)`-dimensional slices +// (if `K < P`) along the `K`th dimension of `input`. +// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$ +// +// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 +// elements. In Python, that addition would look like this: +// +// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) +// with tf.Session() as sess: +// print(sess.run(output)) +// +// The resulting value `output` would look like this: +// +// [1, 13, 3, 14, 14, 6, 7, 20] +// +// See `tf.scatter_nd` for more details about how to make updates to slices. +// +// Arguments: +// input: A Tensor. +// indices: A Tensor. Must be one of the following types: `int32`, `int64`. +// A tensor of indices into `input`. +// updates: A Tensor. Must have the same type as ref. A tensor of updated values +// to add to `input`. +// +// Returns A `Tensor` with the same shape as `input`, containing values of `input` +// updated with `updates`. +func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ScatterNdNonAliasingAdd", + Input: []tf.Input{ + input, indices, updates, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds sparse `updates` to an existing tensor according to `indices`. +// +// This operation creates a new tensor by adding sparse `updates` to the passed +// in `tensor`. +// This operation is very similar to `tf.scatter_nd_add`, except that the updates +// are added onto an existing tensor (as opposed to a variable). If the memory // for the existing tensor cannot be re-used, a copy is made and updated. 
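The min/max adjustment logic newly documented above for `FakeQuantWithMinMaxArgs` can be sketched in a few lines of Python; the apparent intent of the three cases is to keep an exact zero on the quantization grid, though the kernel's rounding mode may differ from Python's `round`, so this is a reading of the documented formulas rather than the op itself.

```python
def adjust_range(min_val, max_val, num_bits=8):
    """Nudge [min_val, max_val] following the three documented cases for
    FakeQuantWithMinMaxArgs before quantization."""
    if 0 < min_val < max_val:          # range entirely above zero
        return 0.0, max_val - min_val
    if min_val < max_val < 0:          # range entirely below zero
        return min_val - max_val, 0.0
    scale = (max_val - min_val) / (2 ** num_bits - 1)
    min_adj = scale * round(min_val / scale)
    return min_adj, max_val + min_adj - min_val

print(adjust_range(0.5, 2.5))    # positive-only range is shifted to start at zero
print(adjust_range(-1.0, 1.0))   # mixed-sign range: min is snapped onto the grid
```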
// // `indices` is an integer tensor containing indices into a new tensor of shape @@ -347,24 +247,24 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua // // indices.shape[:-1] + shape[indices.shape[-1]:] // -// The simplest form of tensor_scatter_sub is to subtract individual elements -// from a tensor by index. For example, say we want to insert 4 scattered elements -// in a rank-1 tensor with 8 elements. +// The simplest form of tensor_scatter_add is to add individual elements to a +// tensor by index. For example, say we want to add 4 elements in a rank-1 +// tensor with 8 elements. // -// In Python, this scatter subtract operation would look like this: +// In Python, this scatter add operation would look like this: // // ```python // indices = tf.constant([[4], [3], [1], [7]]) // updates = tf.constant([9, 10, 11, 12]) // tensor = tf.ones([8], dtype=tf.int32) -// updated = tf.tensor_scatter_sub(tensor, indices, updates) +// updated = tf.tensor_scatter_add(tensor, indices, updates) // with tf.Session() as sess: // print(sess.run(scatter)) // ``` // // The resulting tensor would look like this: // -// [1, -10, 1, -9, -8, 1, 1, -11] +// [1, 12, 1, 11, 10, 1, 1, 13] // // We can also, insert entire slices of a higher rank tensor all at once. For // example, if we wanted to insert two slices in the first dimension of a @@ -379,16 +279,16 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua // [[5, 5, 5, 5], [6, 6, 6, 6], // [7, 7, 7, 7], [8, 8, 8, 8]]]) // tensor = tf.ones([4, 4, 4]) -// updated = tf.tensor_scatter_sub(tensor, indices, updates) +// updated = tf.tensor_scatter_add(tensor, indices, updates) // with tf.Session() as sess: // print(sess.run(scatter)) // ``` // // The resulting tensor would look like this: // -// [[[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]], +// [[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], // [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], -// [[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]], +// [[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], // [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] // // Note that on CPU, if an out of bound index is found, an error is returned. @@ -399,13 +299,13 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua // indices: Index tensor. // updates: Updates to scatter into output. // -// Returns A new tensor copied from tensor and updates subtracted according to the indices. -func TensorScatterSub(scope *Scope, tensor tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { +// Returns A new tensor copied from tensor and updates added according to the indices. +func TensorScatterAdd(scope *Scope, tensor tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorScatterSub", + Type: "TensorScatterAdd", Input: []tf.Input{ tensor, indices, updates, }, @@ -414,232 +314,515 @@ func TensorScatterSub(scope *Scope, tensor tf.Output, indices tf.Output, updates return op.Output(0) } -// Scatter `updates` into an existing tensor according to `indices`. +// LowerBoundAttr is an optional argument to LowerBound. +type LowerBoundAttr func(optionalAttr) + +// LowerBoundOutType sets the optional out_type attribute to value. 
+// If not specified, defaults to DT_INT32 +func LowerBoundOutType(value tf.DataType) LowerBoundAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Applies lower_bound(sorted_search_values, values) along each row. // -// This operation creates a new tensor by applying sparse `updates` to the passed -// in `tensor`. -// This operation is very similar to `tf.scatter_nd`, except that the updates are -// scattered onto an existing tensor (as opposed to a zero-tensor). If the memory -// for the existing tensor cannot be re-used, a copy is made and updated. +// Each set of rows with the same index in (sorted_inputs, values) is treated +// independently. The resulting row is the equivalent of calling +// `np.searchsorted(sorted_inputs, values, side='left')`. // -// If `indices` contains duplicates, then their updates are accumulated (summed). +// The result is not a global index to the entire +// `Tensor`, but rather just the index in the last dimension. // -// **WARNING**: The order in which updates are applied is nondeterministic, so the -// output will be nondeterministic if `indices` contains duplicates -- because -// of some numerical approximation issues, numbers summed in different order -// may yield different results. +// A 2-D example: +// sorted_sequence = [[0, 3, 9, 9, 10], +// [1, 2, 3, 4, 5]] +// values = [[2, 4, 9], +// [0, 2, 6]] // -// `indices` is an integer tensor containing indices into a new tensor of shape -// `shape`. The last dimension of `indices` can be at most the rank of `shape`: +// result = LowerBound(sorted_sequence, values) // -// indices.shape[-1] <= shape.rank -// -// The last dimension of `indices` corresponds to indices into elements -// (if `indices.shape[-1] = shape.rank`) or slices -// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of -// `shape`. `updates` is a tensor with shape -// -// indices.shape[:-1] + shape[indices.shape[-1]:] -// -// The simplest form of scatter is to insert individual elements in a tensor by -// index. For example, say we want to insert 4 scattered elements in a rank-1 -// tensor with 8 elements. -// -//
-//
-//
-// -// In Python, this scatter operation would look like this: -// -// ```python -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// tensor = tf.ones([8], dtype=tf.int32) -// updated = tf.tensor_scatter_update(tensor, indices, updates) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [1, 11, 1, 10, 9, 1, 1, 12] -// -// We can also, insert entire slices of a higher rank tensor all at once. For -// example, if we wanted to insert two slices in the first dimension of a -// rank-3 tensor with two matrices of new values. -// -// In Python, this scatter operation would look like this: -// -// ```python -// indices = tf.constant([[0], [2]]) -// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]], -// [[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]]]) -// tensor = tf.ones([4, 4, 4]) -// updated = tf.tensor_scatter_update(tensor, indices, updates) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], -// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], -// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], -// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] -// -// Note that on CPU, if an out of bound index is found, an error is returned. -// On GPU, if an out of bound index is found, the index is ignored. +// result == [[1, 2, 2], +// [0, 1, 5]] // // Arguments: -// tensor: Tensor to copy/update. -// indices: Index tensor. -// updates: Updates to scatter into output. +// sorted_inputs: 2-D Tensor where each row is ordered. +// values: 2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains +// the values that will be searched for in `sorted_search_values`. // -// Returns A new tensor with the given shape and updates applied according -// to the indices. -func TensorScatterUpdate(scope *Scope, tensor tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { +// Returns A `Tensor` with the same shape as `values`. It contains the first scalar index +// into the last dimension where values can be inserted without changing the +// ordered property. +func LowerBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...LowerBoundAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorScatterUpdate", + Type: "LowerBound", Input: []tf.Input{ - tensor, indices, updates, + sorted_inputs, values, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Scatter `updates` into a new tensor according to `indices`. +// UpperBoundAttr is an optional argument to UpperBound. +type UpperBoundAttr func(optionalAttr) + +// UpperBoundOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func UpperBoundOutType(value tf.DataType) UpperBoundAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Applies upper_bound(sorted_search_values, values) along each row. // -// Creates a new tensor by applying sparse `updates` to individual values or -// slices within a tensor (initially zero for numeric, empty for string) of -// the given `shape` according to indices. 
This operator is the inverse of the -// `tf.gather_nd` operator which extracts values or slices from a given tensor. +// Each set of rows with the same index in (sorted_inputs, values) is treated +// independently. The resulting row is the equivalent of calling +// `np.searchsorted(sorted_inputs, values, side='right')`. // -// This operation is similar to tensor_scatter_add, except that the tensor is -// zero-initialized. Calling `tf.scatter_nd(indices, values, shape)` is identical -// to `tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values)` +// The result is not a global index to the entire +// `Tensor`, but rather just the index in the last dimension. // -// If `indices` contains duplicates, then their updates are accumulated (summed). +// A 2-D example: +// sorted_sequence = [[0, 3, 9, 9, 10], +// [1, 2, 3, 4, 5]] +// values = [[2, 4, 9], +// [0, 2, 6]] // -// **WARNING**: The order in which updates are applied is nondeterministic, so the -// output will be nondeterministic if `indices` contains duplicates -- because -// of some numerical approximation issues, numbers summed in different order -// may yield different results. +// result = UpperBound(sorted_sequence, values) // -// `indices` is an integer tensor containing indices into a new tensor of shape -// `shape`. The last dimension of `indices` can be at most the rank of `shape`: -// -// indices.shape[-1] <= shape.rank -// -// The last dimension of `indices` corresponds to indices into elements -// (if `indices.shape[-1] = shape.rank`) or slices -// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of -// `shape`. `updates` is a tensor with shape -// -// indices.shape[:-1] + shape[indices.shape[-1]:] -// -// The simplest form of scatter is to insert individual elements in a tensor by -// index. For example, say we want to insert 4 scattered elements in a rank-1 -// tensor with 8 elements. -// -//
-//
-//
-// -// In Python, this scatter operation would look like this: -// -// ```python -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// shape = tf.constant([8]) -// scatter = tf.scatter_nd(indices, updates, shape) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [0, 11, 0, 10, 9, 0, 0, 12] -// -// We can also, insert entire slices of a higher rank tensor all at once. For -// example, if we wanted to insert two slices in the first dimension of a -// rank-3 tensor with two matrices of new values. -// -//
-//
-//
-// -// In Python, this scatter operation would look like this: -// -// ```python -// indices = tf.constant([[0], [2]]) -// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]], -// [[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]]]) -// shape = tf.constant([4, 4, 4]) -// scatter = tf.scatter_nd(indices, updates, shape) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], -// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], -// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], -// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] -// -// Note that on CPU, if an out of bound index is found, an error is returned. -// On GPU, if an out of bound index is found, the index is ignored. +// result == [[1, 2, 4], +// [0, 2, 5]] // // Arguments: -// indices: Index tensor. -// updates: Updates to scatter into output. -// shape: 1-D. The shape of the resulting tensor. +// sorted_inputs: 2-D Tensor where each row is ordered. +// values: 2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains +// the values that will be searched for in `sorted_search_values`. // -// Returns A new tensor with the given shape and updates applied according -// to the indices. -func ScatterNd(scope *Scope, indices tf.Output, updates tf.Output, shape tf.Output) (output tf.Output) { +// Returns A `Tensor` with the same shape as `values`. It contains the last scalar index +// into the last dimension where values can be inserted without changing the +// ordered property. +func UpperBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...UpperBoundAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ScatterNd", + Type: "UpperBound", Input: []tf.Input{ - indices, updates, shape, + sorted_inputs, values, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Bitcasts a tensor from one type to another without copying data. +// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. +type QuantizedInstanceNormAttr func(optionalAttr) + +// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value. // -// Given a tensor `input`, this operation returns a tensor that has the same buffer -// data as `input` with datatype `type`. +// value: If True, `given_y_min` and `given_y_min` +// and `given_y_max` are used as the output range. Otherwise, +// the implementation computes the output range. +// If not specified, defaults to false +func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["output_range_given"] = value + } +} + +// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value. // -// If the input datatype `T` is larger than the output datatype `type` then the -// shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)]. +// value: Output in `y_min` if `output_range_given` is True. +// If not specified, defaults to 0 +func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["given_y_min"] = value + } +} + +// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value. 
// -// If `T` is smaller than `type`, the operator requires that the rightmost -// dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from -// [..., sizeof(`type`)/sizeof(`T`)] to [...]. +// value: Output in `y_max` if `output_range_given` is True. +// If not specified, defaults to 0 +func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["given_y_max"] = value + } +} + +// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value. // -// *NOTE*: Bitcast is implemented as a low-level cast, so machines with different -// endian orderings will give different results. -func Bitcast(scope *Scope, input tf.Output, type_ tf.DataType) (output tf.Output) { +// value: A small float number to avoid dividing by 0. +// If not specified, defaults to 1e-05 +func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["variance_epsilon"] = value + } +} + +// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value. +// +// value: Minimum value of `y_max - y_min` +// If not specified, defaults to 0.001 +func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr { + return func(m optionalAttr) { + m["min_separation"] = value + } +} + +// Quantized Instance normalization. +// +// Arguments: +// x: A 4D input Tensor. +// x_min: The value represented by the lowest quantized input. +// x_max: The value represented by the highest quantized input. +// +// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output. +func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"type": type_} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Bitcast", + Type: "QuantizedInstanceNorm", + Input: []tf.Input{ + x, x_min, x_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Concatenates quantized tensors along one dimension. +// +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// input_mins: The minimum scalar values for each of the input tensors. +// input_maxes: The maximum scalar values for each of the input tensors. +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. 
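+//
+// A minimal Go sketch (assuming `op.NewScope`, the `Const` helper from this
+// package, and the `tf.Quint8` data type from the Go binding), quantizing two
+// float vectors with `QuantizeV2` and concatenating them along dimension 0:
+//
+// ```go
+// s := op.NewScope()
+// lo, hi := op.Const(s, float32(0)), op.Const(s, float32(6))
+// qa, aMin, aMax := op.QuantizeV2(s, op.Const(s, []float32{1, 2, 3}), lo, hi, tf.Quint8)
+// qb, bMin, bMax := op.QuantizeV2(s, op.Const(s, []float32{4, 5, 6}), lo, hi, tf.Quint8)
+// out, outMin, outMax := op.QuantizedConcat(s, op.Const(s, int32(0)),
+//     []tf.Output{qa, qb}, []tf.Output{aMin, bMin}, []tf.Output{aMax, bMax})
+// _, _, _ = out, outMin, outMax
+// ```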
+func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "QuantizedConcat", + Input: []tf.Input{ + concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// DequantizeAttr is an optional argument to Dequantize. +type DequantizeAttr func(optionalAttr) + +// DequantizeMode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func DequantizeMode(value string) DequantizeAttr { + return func(m optionalAttr) { + m["mode"] = value + } +} + +// Dequantize the 'input' tensor into a float Tensor. +// +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. +// +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: +// +// ``` +// if T == qint8: in[i] += (range(T) + 1)/ 2.0 +// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) +// ``` +// here `range(T) = numeric_limits::max() - numeric_limits::min()` +// +// *MIN_COMBINED Mode Example* +// +// If the input comes from a QuantizedRelu6, the output type is +// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is +// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. +// Dequantize on quint8 will take each value, cast to float, and multiply +// by 6 / 255. +// Note that if quantizedtype is qint8, the operation will additionally add +// each value by 128 prior to casting. +// +// If the mode is 'MIN_FIRST', then this approach is used: +// +// ```c++ +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = range / num_discrete_values +// const double offset_input = static_cast(input) - lowest_quantized; +// result = range_min + ((input - numeric_limits::min()) * range_scale) +// ``` +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. +// +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` +// +// Our input tensor range is then `[-m, m]`. +// +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. 
+// If T is signed, this is +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// ``` +// +// Otherwise, if T is unsigned, the fixed-point range is +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` +// +// From this we compute our scaling factor, s: +// ```c++ +// s = (2 * m) / (max_fixed - min_fixed) +// ``` +// +// Now we can dequantize the elements of our tensor: +// ```c++ +// result = input * s +// ``` +// +// Arguments: +// +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, optional ...DequantizeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Dequantize", + Input: []tf.Input{ + input, min_range, max_range, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizeV2Attr is an optional argument to QuantizeV2. +type QuantizeV2Attr func(optionalAttr) + +// QuantizeV2Mode sets the optional mode attribute to value. +// If not specified, defaults to "MIN_COMBINED" +func QuantizeV2Mode(value string) QuantizeV2Attr { + return func(m optionalAttr) { + m["mode"] = value + } +} + +// QuantizeV2RoundMode sets the optional round_mode attribute to value. +// If not specified, defaults to "HALF_AWAY_FROM_ZERO" +func QuantizeV2RoundMode(value string) QuantizeV2Attr { + return func(m optionalAttr) { + m["round_mode"] = value + } +} + +// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. +// +// [min_range, max_range] are scalar floats that specify the range for +// the 'input' data. The 'mode' attribute controls exactly which calculations are +// used to convert the float values to their quantized equivalents. The +// 'round_mode' attribute controls which rounding tie-breaking algorithm is used +// when rounding float values to their quantized equivalents. +// +// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: +// +// ``` +// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) +// if T == qint8: out[i] -= (range(T) + 1) / 2.0 +// ``` +// +// here `range(T) = numeric_limits::max() - numeric_limits::min()` +// +// *MIN_COMBINED Mode Example* +// +// Assume the input is type float and has a possible range of [0.0, 6.0] and the +// output type is quint8 ([0, 255]). The min_range and max_range values should be +// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each +// value of the input by 255/6 and cast to quint8. +// +// If the output type was qint8 ([-128, 127]), the operation will additionally +// subtract each value by 128 prior to casting, so that the range of values aligns +// with the range of qint8. 
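+//
+// As a worked instance of the formula above (an illustrative value, not from the
+// op definition): with min_range = 0.0, max_range = 6.0 and T = quint8
+// (range(T) = 255), the input value 2.0 quantizes to
+//
+// ```
+// out = (2.0 - 0.0) * 255 / (6.0 - 0.0) = 85
+// ```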
+// +// If the mode is 'MIN_FIRST', then this approach is used: +// +// ``` +// num_discrete_values = 1 << (# of bits in T) +// range_adjust = num_discrete_values / (num_discrete_values - 1) +// range = (range_max - range_min) * range_adjust +// range_scale = num_discrete_values / range +// quantized = round(input * range_scale) - round(range_min * range_scale) + +// numeric_limits::min() +// quantized = max(quantized, numeric_limits::min()) +// quantized = min(quantized, numeric_limits::max()) +// ``` +// +// The biggest difference between this and MIN_COMBINED is that the minimum range +// is rounded first, before it's subtracted from the rounded value. With +// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing +// and dequantizing will introduce a larger and larger error. +// +// *SCALED mode Example* +// +// `SCALED` mode matches the quantization approach used in +// `QuantizeAndDequantize{V2|V3}`. +// +// If the mode is `SCALED`, we do not use the full range of the output type, +// choosing to elide the lowest possible value for symmetry (e.g., output range is +// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to +// 0. +// +// We first find the range of values in our tensor. The +// range we use is always centered on 0, so we find m such that +// +// ```c++ +// m = max(abs(input_min), abs(input_max)) +// ``` +// +// Our input tensor range is then `[-m, m]`. +// +// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. +// If T is signed, this is +// +// ``` +// num_bits = sizeof(T) * 8 +// [min_fixed, max_fixed] = +// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] +// ``` +// +// Otherwise, if T is unsigned, the fixed-point range is +// +// ``` +// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] +// ``` +// +// From this we compute our scaling factor, s: +// +// ```c++ +// s = (max_fixed - min_fixed) / (2 * m) +// ``` +// +// Now we can quantize the elements of our tensor: +// +// ```c++ +// result = round(input * s) +// ``` +// +// One thing to watch out for is that the operator may choose to adjust the +// requested minimum and maximum values slightly during the quantization process, +// so you should always use the output ports as the range for further calculations. +// For example, if the requested minimum and maximum values are close to equal, +// they will be separated by a small epsilon value to prevent ill-formed quantized +// buffers from being created. Otherwise, you can end up with buffers where all the +// quantized values map to the same float value, which causes problems for +// operations that have to perform further calculations on them. +// +// Arguments: +// +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +// +// +// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output. 
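+//
+// A minimal Go sketch (assuming `op.NewScope` and the `Const` helper from this
+// package, and `tf.Quint8` from the Go binding), quantizing four floats in the
+// range [0, 6] with the default MIN_COMBINED mode:
+//
+// ```go
+// s := op.NewScope()
+// in := op.Const(s, []float32{0, 1.5, 3, 6})
+// minRange, maxRange := op.Const(s, float32(0)), op.Const(s, float32(6))
+// q, actualMin, actualMax := op.QuantizeV2(s, in, minRange, maxRange, tf.Quint8)
+// _, _, _ = q, actualMin, actualMax // run with a tf.Session to read the quint8 values
+// ```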
+func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizeV2", + Input: []tf.Input{ + input, min_range, max_range, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Extract `patches` from `input` and put them in the "depth" output dimension. 3D extension of `extract_image_patches`. +// +// Arguments: +// input: 5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`. +// ksizes: The size of the sliding window for each dimension of `input`. +// strides: 1-D of length 5. How far the centers of two consecutive patches are in +// `input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`. +// padding: The type of padding algorithm to use. +// +// We specify the size-related attributes as: +// +// ```python +// ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1] +// strides = [1, stride_planes, strides_rows, strides_cols, 1] +// ``` +// +// Returns 5-D Tensor with shape `[batch, out_planes, out_rows, out_cols, +// ksize_planes * ksize_rows * ksize_cols * depth]` containing patches +// with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized +// in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols` +// are the dimensions of the output patches. +func ExtractVolumePatches(scope *Scope, input tf.Output, ksizes []int64, strides []int64, padding string) (patches tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "padding": padding} + opspec := tf.OpSpec{ + Type: "ExtractVolumePatches", Input: []tf.Input{ input, }, @@ -649,30 +832,76 @@ func Bitcast(scope *Scope, input tf.Output, type_ tf.DataType) (output tf.Output return op.Output(0) } -// SpaceToDepthAttr is an optional argument to SpaceToDepth. -type SpaceToDepthAttr func(optionalAttr) +// Extract `patches` from `images` and put them in the "depth" output dimension. +// +// Arguments: +// images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`. +// ksizes: The size of the sliding window for each dimension of `images`. +// strides: 1-D of length 4. How far the centers of two consecutive patches are in +// the images. Must be: `[1, stride_rows, stride_cols, 1]`. +// rates: 1-D of length 4. Must be: `[1, rate_rows, rate_cols, 1]`. This is the +// input stride, specifying how far two consecutive patch samples are in the +// input. Equivalent to extracting patches with +// `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by +// subsampling them spatially by a factor of `rates`. This is equivalent to +// `rate` in dilated (a.k.a. Atrous) convolutions. +// padding: The type of padding algorithm to use. +// +// We specify the size-related attributes as: +// +// ```python +// ksizes = [1, ksize_rows, ksize_cols, 1] +// strides = [1, strides_rows, strides_cols, 1] +// rates = [1, rates_rows, rates_cols, 1] +// ``` +// +// Returns 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows * +// ksize_cols * depth]` containing image patches with size +// `ksize_rows x ksize_cols x depth` vectorized in the "depth" dimension. Note +// `out_rows` and `out_cols` are the dimensions of the output patches. 
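+//
+// A minimal Go sketch (assuming `op.NewScope` and the `Const` helper from this
+// package), extracting non-overlapping 2x2 patches from a single 4x4, one-channel
+// image, which yields a tensor of shape `[1, 2, 2, 4]`:
+//
+// ```go
+// s := op.NewScope()
+// img := op.Const(s, [][][][]float32{{
+//     {{1}, {2}, {3}, {4}},
+//     {{5}, {6}, {7}, {8}},
+//     {{9}, {10}, {11}, {12}},
+//     {{13}, {14}, {15}, {16}},
+// }})
+// patches := op.ExtractImagePatches(s, img,
+//     []int64{1, 2, 2, 1}, // ksizes
+//     []int64{1, 2, 2, 1}, // strides
+//     []int64{1, 1, 1, 1}, // rates
+//     "VALID")
+// _ = patches
+// ```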
+func ExtractImagePatches(scope *Scope, images tf.Output, ksizes []int64, strides []int64, rates []int64, padding string) (patches tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "rates": rates, "padding": padding} + opspec := tf.OpSpec{ + Type: "ExtractImagePatches", + Input: []tf.Input{ + images, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// SpaceToDepthDataFormat sets the optional data_format attribute to value. +// DepthToSpaceAttr is an optional argument to DepthToSpace. +type DepthToSpaceAttr func(optionalAttr) + +// DepthToSpaceDataFormat sets the optional data_format attribute to value. // If not specified, defaults to "NHWC" -func SpaceToDepthDataFormat(value string) SpaceToDepthAttr { +func DepthToSpaceDataFormat(value string) DepthToSpaceAttr { return func(m optionalAttr) { m["data_format"] = value } } -// SpaceToDepth for tensors of type T. +// DepthToSpace for tensors of type T. // -// Rearranges blocks of spatial data, into depth. More specifically, -// this op outputs a copy of the input tensor where values from the `height` -// and `width` dimensions are moved to the `depth` dimension. -// The attr `block_size` indicates the input block size. +// Rearranges data from depth into blocks of spatial data. +// This is the reverse transformation of SpaceToDepth. More specifically, +// this op outputs a copy of the input tensor where values from the `depth` +// dimension are moved in spatial blocks to the `height` and `width` dimensions. +// The attr `block_size` indicates the input block size and how the data is moved. // -// * Non-overlapping blocks of size `block_size x block size` are rearranged -// into depth at each location. -// * The depth of the output tensor is `block_size * block_size * input_depth`. -// * The Y, X coordinates within each block of the input become the high order -// component of the output channel index. -// * The input tensor's height and width must be divisible by block_size. +// * Chunks of data of size `block_size * block_size` from depth are rearranged +// into non-overlapping blocks of size `block_size x block_size` +// * The width the output tensor is `input_depth * block_size`, whereas the +// height is `input_height * block_size`. +// * The Y, X coordinates within each block of the output image are determined +// by the high order component of the input channel index. +// * The depth of the input tensor must be divisible by +// `block_size * block_size`. // // The `data_format` attr specifies the layout of the input and output tensors // with the following options: @@ -685,71 +914,74 @@ func SpaceToDepthDataFormat(value string) SpaceToDepthAttr { // e.g. for data_format = NHWC, // Each element in the input tensor can be specified via 6 coordinates, // ordered by decreasing memory layout significance as: -// n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates -// within the output image, bX, bY means coordinates -// within the input block, iC means input channels). -// The output would be a transpose to the following layout: -// n,oY,oX,bY,bX,iC +// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates +// within the input image, bX, bY means coordinates +// within the output block, oC means output channels). 
+// The output would be the input transposed to the following layout: +// n,iY,bY,iX,bX,oC // // This operation is useful for resizing the activations between convolutions // (but keeping all data), e.g. instead of pooling. It is also useful for training // purely convolutional models. // -// For example, given an input of shape `[1, 2, 2, 1]`, data_format = "NHWC" and +// For example, given an input of shape `[1, 1, 1, 4]`, data_format = "NHWC" and // block_size = 2: // // ``` -// x = [[[[1], [2]], -// [[3], [4]]]] -// ``` -// -// This operation will output a tensor of shape `[1, 1, 1, 4]`: +// x = [[[[1, 2, 3, 4]]]] // // ``` -// [[[[1, 2, 3, 4]]]] -// ``` // -// Here, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`, -// the corresponding output will have a single element (i.e. width and height are -// both 1) and will have a depth of 4 channels (1 * block_size * block_size). -// The output element shape is `[1, 1, 4]`. -// -// For an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g. +// This operation will output a tensor of shape `[1, 2, 2, 1]`: // // ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] +// [[[[1], [2]], +// [[3], [4]]]] // ``` // -// This operation, for block_size of 2, will return the following tensor of shape -// `[1, 1, 1, 12]` +// Here, the input has a batch of 1 and each batch element has shape `[1, 1, 4]`, +// the corresponding output will have 2x2 elements and will have a depth of +// 1 channel (1 = `4 / (block_size * block_size)`). +// The output element shape is `[2, 2, 1]`. +// +// For an input tensor with larger depth, here of shape `[1, 1, 1, 12]`, e.g. // // ``` -// [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] +// x = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] // ``` // -// Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2: +// This operation, for block size of 2, will return the following tensor of shape +// `[1, 2, 2, 3]` // // ``` -// x = [[[[1], [2], [5], [6]], -// [[3], [4], [7], [8]], -// [[9], [10], [13], [14]], -// [[11], [12], [15], [16]]]] -// ``` -// -// the operator will return the following tensor of shape `[1 2 2 4]`: +// [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] // // ``` -// x = [[[[1, 2, 3, 4], +// +// Similarly, for the following input of shape `[1 2 2 4]`, and a block size of 2: +// +// ``` +// x = [[[[1, 2, 3, 4], // [5, 6, 7, 8]], // [[9, 10, 11, 12], // [13, 14, 15, 16]]]] // ``` // +// the operator will return the following tensor of shape `[1 4 4 1]`: +// +// ``` +// x = [[[ [1], [2], [5], [6]], +// [ [3], [4], [7], [8]], +// [ [9], [10], [13], [14]], +// [ [11], [12], [15], [16]]]] +// +// ``` +// // Arguments: // -// block_size: The size of the spatial block. -func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...SpaceToDepthAttr) (output tf.Output) { +// block_size: The size of the spatial block, same as in Space2Depth. +func DepthToSpace(scope *Scope, input tf.Output, block_size int64, optional ...DepthToSpaceAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -758,7 +990,7 @@ func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...S a(attrs) } opspec := tf.OpSpec{ - Type: "SpaceToDepth", + Type: "DepthToSpace", Input: []tf.Input{ input, }, @@ -768,77 +1000,93 @@ func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...S return op.Output(0) } -// SpaceToBatch for 4-D tensors of type T. +// BatchToSpace for N-D tensors of type T. 
// -// This is a legacy version of the more general SpaceToBatchND. -// -// Zero-pads and then rearranges (permutes) blocks of spatial data into batch. -// More specifically, this op outputs a copy of the input tensor where values from -// the `height` and `width` dimensions are moved to the `batch` dimension. After -// the zero-padding, both `height` and `width` of the input must be divisible by the -// block size. +// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape +// `block_shape + [batch]`, interleaves these blocks back into the grid defined by +// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as +// the input. The spatial dimensions of this intermediate result are then +// optionally cropped according to `crops` to produce the output. This is the +// reverse of SpaceToBatch. See below for a precise description. // // Arguments: -// input: 4-D with shape `[batch, height, width, depth]`. -// paddings: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies -// the padding of the input with zeros across the spatial dimensions as follows: +// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, +// where spatial_shape has M dimensions. +// block_shape: 1-D with shape `[M]`, all values must be >= 1. +// crops: 2-D with shape `[M, 2]`, all values must be >= 0. +// `crops[i] = [crop_start, crop_end]` specifies the amount to crop from input +// dimension `i + 1`, which corresponds to spatial dimension `i`. It is +// required that +// `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`. // -// paddings = [[pad_top, pad_bottom], [pad_left, pad_right]] +// This operation is equivalent to the following steps: // -// The effective spatial dimensions of the zero-padded input tensor will be: +// 1. Reshape `input` to `reshaped` of shape: +// [block_shape[0], ..., block_shape[M-1], +// batch / prod(block_shape), +// input_shape[1], ..., input_shape[N-1]] // -// height_pad = pad_top + height + pad_bottom -// width_pad = pad_left + width + pad_right +// 2. Permute dimensions of `reshaped` to produce `permuted` of shape +// [batch / prod(block_shape), // -// The attr `block_size` must be greater than one. It indicates the block size. +// input_shape[1], block_shape[0], +// ..., +// input_shape[M], block_shape[M-1], // -// * Non-overlapping blocks of size `block_size x block size` in the height and -// width dimensions are rearranged into the batch dimension at each location. -// * The batch of the output tensor is `batch * block_size * block_size`. -// * Both height_pad and width_pad must be divisible by block_size. +// input_shape[M+1], ..., input_shape[N-1]] // -// The shape of the output will be: +// 3. Reshape `permuted` to produce `reshaped_permuted` of shape +// [batch / prod(block_shape), // -// [batch*block_size*block_size, height_pad/block_size, width_pad/block_size, -// depth] +// input_shape[1] * block_shape[0], +// ..., +// input_shape[M] * block_shape[M-1], +// +// input_shape[M+1], +// ..., +// input_shape[N-1]] +// +// 4. 
Crop the start and end of dimensions `[1, ..., M]` of +// `reshaped_permuted` according to `crops` to produce the output of shape: +// [batch / prod(block_shape), +// +// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], +// ..., +// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], +// +// input_shape[M+1], ..., input_shape[N-1]] // // Some examples: // -// (1) For the following input of shape `[1, 2, 2, 1]` and block_size of 2: -// -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` -// -// The output tensor has shape `[4, 1, 1, 1]` and value: +// (1) For the following input of shape `[4, 1, 1, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: // // ``` // [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] // ``` // -// (2) For the following input of shape `[1, 2, 2, 3]` and block_size of 2: +// The output tensor has shape `[1, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// (2) For the following input of shape `[4, 1, 1, 3]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: +// +// ``` +// [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 3]` and value: // // ``` // x = [[[[1, 2, 3], [4, 5, 6]], // [[7, 8, 9], [10, 11, 12]]]] // ``` // -// The output tensor has shape `[4, 1, 1, 3]` and value: -// -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` -// -// (3) For the following input of shape `[1, 4, 4, 1]` and block_size of 2: -// -// ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] -// ``` -// -// The output tensor has shape `[4, 2, 2, 1]` and value: +// (3) For the following input of shape `[4, 2, 2, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [0, 0]]`: // // ``` // x = [[[[1], [3]], [[9], [11]]], @@ -847,7 +1095,26 @@ func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...S // [[[6], [8]], [[14], [16]]]] // ``` // -// (4) For the following input of shape `[2, 2, 4, 1]` and block_size of 2: +// The output tensor has shape `[1, 4, 4, 1]` and value: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +// +// (4) For the following input of shape `[8, 1, 3, 1]`, `block_shape = [2, 2]`, and +// `crops = [[0, 0], [2, 0]]`: +// +// ``` +// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], +// [[[0], [2], [4]]], [[[0], [10], [12]]], +// [[[0], [5], [7]]], [[[0], [13], [15]]], +// [[[0], [6], [8]]], [[[0], [14], [16]]]] +// ``` +// +// The output tensor has shape `[2, 2, 4, 1]` and value: // // ``` // x = [[[[1], [2], [3], [4]], @@ -855,28 +1122,15 @@ func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...S // [[[9], [10], [11], [12]], // [[13], [14], [15], [16]]]] // ``` -// -// The output tensor has shape `[8, 1, 2, 1]` and value: -// -// ``` -// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], -// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] -// ``` -// -// Among others, this operation is useful for reducing atrous convolution into -// regular convolution. 
-// -func SpaceToBatch(scope *Scope, input tf.Output, paddings tf.Output, block_size int64) (output tf.Output) { +func BatchToSpaceND(scope *Scope, input tf.Output, block_shape tf.Output, crops tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"block_size": block_size} opspec := tf.OpSpec{ - Type: "SpaceToBatch", + Type: "BatchToSpaceND", Input: []tf.Input{ - input, paddings, + input, block_shape, crops, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -962,7 +1216,7 @@ func SpaceToBatch(scope *Scope, input tf.Output, paddings tf.Output, block_size // The output tensor has shape `[4, 1, 1, 3]` and value: // // ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] +// [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] // ``` // // (3) For the following input of shape `[1, 4, 4, 1]`, `block_shape = [2, 2]`, and @@ -1019,84 +1273,6 @@ func SpaceToBatchND(scope *Scope, input tf.Output, block_shape tf.Output, paddin return op.Output(0) } -// Inserts a dimension of 1 into a tensor's shape. -// -// Given a tensor `input`, this operation inserts a dimension of 1 at the -// dimension index `axis` of `input`'s shape. The dimension index `axis` starts at -// zero; if you specify a negative number for `axis` it is counted backward from -// the end. -// -// This operation is useful if you want to add a batch dimension to a single -// element. For example, if you have a single image of shape `[height, width, -// channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, -// which will make the shape `[1, height, width, channels]`. -// -// Other examples: -// -// ``` -// # 't' is a tensor of shape [2] -// shape(expand_dims(t, 0)) ==> [1, 2] -// shape(expand_dims(t, 1)) ==> [2, 1] -// shape(expand_dims(t, -1)) ==> [2, 1] -// -// # 't2' is a tensor of shape [2, 3, 5] -// shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] -// shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] -// shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] -// ``` -// -// This operation requires that: -// -// `-1-input.dims() <= dim <= input.dims()` -// -// This operation is related to `squeeze()`, which removes dimensions of -// size 1. -// -// Arguments: -// -// axis: 0-D (scalar). Specifies the dimension index at which to -// expand the shape of `input`. Must be in the range -// `[-rank(input) - 1, rank(input)]`. -// -// Returns Contains the same data as `input`, but its shape has an additional -// dimension of size 1 added. -func ExpandDims(scope *Scope, input tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ExpandDims", - Input: []tf.Input{ - input, axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// A placeholder op that passes through `input` when its output is not fed. -// -// Arguments: -// input: The default value to produce when `output` is not fed. -// shape: The (possibly partial) shape of the tensor. -// -// Returns A placeholder tensor that defaults to `input` if it is not fed. -func PlaceholderWithDefault(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shape": shape} - opspec := tf.OpSpec{ - Type: "PlaceholderWithDefault", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // A placeholder op for a value that will be fed into the computation. 
// // DEPRECATED at GraphDef version 23: Placeholder now behaves the same as PlaceholderV2. @@ -1125,47 +1301,6 @@ func PlaceholderV2(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.O return op.Output(0) } -// PlaceholderAttr is an optional argument to Placeholder. -type PlaceholderAttr func(optionalAttr) - -// PlaceholderShape sets the optional shape attribute to value. -// -// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the -// shape is unconstrained. -// If not specified, defaults to -func PlaceholderShape(value tf.Shape) PlaceholderAttr { - return func(m optionalAttr) { - m["shape"] = value - } -} - -// A placeholder op for a value that will be fed into the computation. -// -// N.B. This operation will fail with an error if it is executed. It is -// intended as a way to represent a value that will always be fed, and to -// provide attrs that enable the fed value to be checked at runtime. -// -// Arguments: -// dtype: The type of elements in the tensor. -// -// Returns A placeholder tensor that must be replaced using the feed mechanism. -func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Placeholder", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor. // // This operation folds the padded areas of `input` by `MirrorPad` according to the @@ -1210,61 +1345,6 @@ func MirrorPadGrad(scope *Scope, input tf.Output, paddings tf.Output, mode strin return op.Output(0) } -// Pads a tensor with mirrored values. -// -// This operation pads a `input` with mirrored values according to the `paddings` -// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is -// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many values to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many values to add after the contents of `input` -// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater -// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true -// (if false, respectively). -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` -// -// For example: -// -// ``` -// # 't' is [[1, 2, 3], [4, 5, 6]]. -// # 'paddings' is [[1, 1]], [2, 2]]. -// # 'mode' is SYMMETRIC. -// # rank of 't' is 2. -// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] -// [2, 1, 1, 2, 3, 3, 2] -// [5, 4, 4, 5, 6, 6, 5] -// [5, 4, 4, 5, 6, 6, 5]] -// ``` -// -// Arguments: -// input: The input tensor to be padded. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions -// do not include the borders, while in symmetric mode the padded regions -// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings` -// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and -// it is `[1, 2, 3, 3, 2]` in symmetric mode. -// -// Returns The padded tensor. 
-func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode} - opspec := tf.OpSpec{ - Type: "MirrorPad", - Input: []tf.Input{ - input, paddings, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Pads a tensor. // // This operation pads `input` according to the `paddings` and `constant_values` @@ -1305,21 +1385,61 @@ func PadV2(scope *Scope, input tf.Output, paddings tf.Output, constant_values tf return op.Output(0) } -// Return the reduction indices for computing gradients of s0 op s1 with broadcast. +// Pads a tensor with zeros. // -// This is typically used by gradient computations for a broadcasting operation. -func BroadcastGradientArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output, r1 tf.Output) { +// This operation pads a `input` with zeros according to the `paddings` you +// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many zeros to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` +// in that dimension. +// +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 1], [2, 2]] +// # 'paddings' is [[1, 1], [2, 2]] +// # rank of 't' is 2 +// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] +// [0, 0, 1, 1, 0, 0] +// [0, 0, 2, 2, 0, 0] +// [0, 0, 0, 0, 0, 0]] +// ``` +// +func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BroadcastGradientArgs", + Type: "Pad", + Input: []tf.Input{ + input, paddings, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Return the shape of s0 op s1 with broadcast. +// +// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the +// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. +func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BroadcastArgs", Input: []tf.Input{ s0, s1, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } // Returns the gradient of `Tile`. @@ -1368,181 +1488,58 @@ func Tile(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) return op.Output(0) } -// StridedSliceAttr is an optional argument to StridedSlice. -type StridedSliceAttr func(optionalAttr) +// TensorStridedSliceUpdateAttr is an optional argument to TensorStridedSliceUpdate. +type TensorStridedSliceUpdateAttr func(optionalAttr) -// StridedSliceBeginMask sets the optional begin_mask attribute to value. -// -// value: a bitmask where a bit i being 1 means to ignore the begin -// value and instead use the largest interval possible. At runtime -// begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or -// `[-1, n-1]` if `stride[i] < 0` +// TensorStridedSliceUpdateBeginMask sets the optional begin_mask attribute to value. 
// If not specified, defaults to 0 -func StridedSliceBeginMask(value int64) StridedSliceAttr { +func TensorStridedSliceUpdateBeginMask(value int64) TensorStridedSliceUpdateAttr { return func(m optionalAttr) { m["begin_mask"] = value } } -// StridedSliceEndMask sets the optional end_mask attribute to value. -// -// value: analogous to `begin_mask` +// TensorStridedSliceUpdateEndMask sets the optional end_mask attribute to value. // If not specified, defaults to 0 -func StridedSliceEndMask(value int64) StridedSliceAttr { +func TensorStridedSliceUpdateEndMask(value int64) TensorStridedSliceUpdateAttr { return func(m optionalAttr) { m["end_mask"] = value } } -// StridedSliceEllipsisMask sets the optional ellipsis_mask attribute to value. -// -// value: a bitmask where bit `i` being 1 means the `i`th -// position is actually an ellipsis. One bit at most can be 1. -// If `ellipsis_mask == 0`, then an implicit ellipsis mask of `1 << (m+1)` -// is provided. This means that `foo[3:5] == foo[3:5, ...]`. An ellipsis -// implicitly creates as many range specifications as necessary to fully -// specify the sliced range for every dimension. For example for a 4-dimensional -// tensor `foo` the slice `foo[2, ..., 5:8]` implies `foo[2, :, :, 5:8]`. +// TensorStridedSliceUpdateEllipsisMask sets the optional ellipsis_mask attribute to value. // If not specified, defaults to 0 -func StridedSliceEllipsisMask(value int64) StridedSliceAttr { +func TensorStridedSliceUpdateEllipsisMask(value int64) TensorStridedSliceUpdateAttr { return func(m optionalAttr) { m["ellipsis_mask"] = value } } -// StridedSliceNewAxisMask sets the optional new_axis_mask attribute to value. -// -// value: a bitmask where bit `i` being 1 means the `i`th -// specification creates a new shape 1 dimension. For example -// `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor. +// TensorStridedSliceUpdateNewAxisMask sets the optional new_axis_mask attribute to value. // If not specified, defaults to 0 -func StridedSliceNewAxisMask(value int64) StridedSliceAttr { +func TensorStridedSliceUpdateNewAxisMask(value int64) TensorStridedSliceUpdateAttr { return func(m optionalAttr) { m["new_axis_mask"] = value } } -// StridedSliceShrinkAxisMask sets the optional shrink_axis_mask attribute to value. -// -// value: a bitmask where bit `i` implies that the `i`th -// specification should shrink the dimensionality. begin and end -// must imply a slice of size 1 in the dimension. For example in -// python one might do `foo[:, 3, :]` which would result in -// `shrink_axis_mask` being 2. +// TensorStridedSliceUpdateShrinkAxisMask sets the optional shrink_axis_mask attribute to value. // If not specified, defaults to 0 -func StridedSliceShrinkAxisMask(value int64) StridedSliceAttr { +func TensorStridedSliceUpdateShrinkAxisMask(value int64) TensorStridedSliceUpdateAttr { return func(m optionalAttr) { m["shrink_axis_mask"] = value } } -// Return a strided slice from `input`. +// Assign `value` to the sliced l-value reference of `input`. // -// Note, most python users will want to use the Python `Tensor.__getitem__` -// or `Variable.__getitem__` rather than this op directly. +// The values of `value` are assigned to the positions in the tensor `input` that +// are selected by the slice parameters. The slice parameters `begin` `end` +// `strides` etc. work exactly as in `StridedSlice`. // -// The goal of this op is to produce a new tensor with a subset of -// the elements from the `n` dimensional `input` tensor. 
The subset is chosen using -// a sequence of `m` sparse range specifications encoded into the arguments -// of this function. Note, in some cases -// `m` could be equal to `n`, but this need not be the case. Each -// range specification entry can be one of the following: -// -// - An ellipsis (...). Ellipses are used to imply zero or more -// dimensions of full-dimension selection and are produced using -// `ellipsis_mask`. For example, `foo[...]` is the identity slice. -// -// - A new axis. This is used to insert a new shape=1 dimension and is -// produced using `new_axis_mask`. For example, `foo[:, ...]` where -// `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. -// -// -// - A range `begin:end:stride`. This is used to specify how much to choose from -// a given dimension. `stride` can be any integer but 0. `begin` is an integer -// which represents the index of the first value to select while `end` represents -// the index of the last value to select. The number of values selected in each -// dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. -// `begin` and `end` can be negative where `-1` is the last element, `-2` is -// the second to last. `begin_mask` controls whether to replace the explicitly -// given `begin` with an implicit effective value of `0` if `stride > 0` and -// `-1` if `stride < 0`. `end_mask` is analogous but produces the number -// required to create the largest open interval. For example, given a shape -// `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do -// not assume this is equivalent to `foo[0:-1]` which has an effective `begin` -// and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the -// first dimension of a tensor while dropping the last two (in the original -// order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. -// -// - A single index. This is used to keep only elements that have a given -// index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a -// shape `(6,)` tensor. This is encoded in `begin` and `end` and -// `shrink_axis_mask`. -// -// Each conceptual range specification is encoded in the op's argument. This -// encoding is best understand by considering a non-trivial example. In -// particular, -// `foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as -// -// ``` -// begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0) -// end = [2, 4, x, x, -3, x] -// strides = [1, 1, x, x, -1, 1] -// begin_mask = 1<<4 | 1 << 5 = 48 -// end_mask = 1<<5 = 32 -// ellipsis_mask = 1<<3 = 8 -// new_axis_mask = 1<<2 4 -// shrink_axis_mask = 1<<0 -// ``` -// -// In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of -// the slice becomes (2, 1, 5, 5, 2, 5). -// Let us walk step by step through each argument specification. -// -// 1. The first argument in the example slice is turned into `begin = 1` and -// `end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we -// also set the appropriate bit in `shrink_axis_mask`. -// -// 2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have -// zero bits contributed. -// -// 3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1 -// dimension in the final shape. Dummy values are contributed to begin, -// end and stride, while the new_axis_mask bit is set. -// -// 4. `...` grab the full ranges from as many dimensions as needed to -// fully specify a slice for every dimension of the input shape. -// -// 5. 
`:-3:-1` shows the use of negative indices. A negative index `i` associated -// with a dimension that has shape `s` is converted to a positive index -// `s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion -// is done internally so begin, end and strides receive x, -3, and -1. -// The appropriate begin_mask bit is set to indicate the start range is the -// full range (ignoring the x). -// -// 6. `:` indicates that the entire contents of the corresponding dimension -// is selected. This is equivalent to `::` or `0::1`. begin, end, and strides -// receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and -// `end_mask` are also set. -// -// *Requirements*: -// `0 != strides[i] for i in [0, m)` -// `ellipsis_mask must be a power of two (only one ellipsis)` -// -// Arguments: -// -// begin: `begin[k]` specifies the offset into the `k`th range specification. -// The exact dimension this corresponds to will be determined by context. -// Out-of-bounds values will be silently clamped. If the `k`th bit of -// `begin_mask` then `begin[k]` is ignored and the full range of the -// appropriate dimension is used instead. Negative values causes indexing -// to start from the highest element e.g. If `foo==[1,2,3]` then `foo[-1]==3`. -// end: `end[i]` is like `begin` with the exception that `end_mask` is -// used to determine full ranges. -// strides: `strides[i]` specifies the increment in the `i`th specification -// after extracting a given element. Negative indices will reverse -// the original order. Out or range values are -// clamped to `[0,dim[i]) if slice[i]>0` or `[-1,dim[i]-1] if slice[i] < 0` -func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, strides tf.Output, optional ...StridedSliceAttr) (output tf.Output) { +// NOTE this op currently does not support broadcasting and so `value`'s shape +// must be exactly the shape produced by the slice of `input`. +func TensorStridedSliceUpdate(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, strides tf.Output, value tf.Output, optional ...TensorStridedSliceUpdateAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -1551,9 +1548,9 @@ func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "StridedSlice", + Type: "TensorStridedSliceUpdate", Input: []tf.Input{ - input, begin, end, strides, + input, begin, end, strides, value, }, Attrs: attrs, } @@ -1561,6 +1558,164 @@ func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, return op.Output(0) } +// StridedSliceGradAttr is an optional argument to StridedSliceGrad. +type StridedSliceGradAttr func(optionalAttr) + +// StridedSliceGradBeginMask sets the optional begin_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradBeginMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["begin_mask"] = value + } +} + +// StridedSliceGradEndMask sets the optional end_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradEndMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["end_mask"] = value + } +} + +// StridedSliceGradEllipsisMask sets the optional ellipsis_mask attribute to value. 
+// If not specified, defaults to 0 +func StridedSliceGradEllipsisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["ellipsis_mask"] = value + } +} + +// StridedSliceGradNewAxisMask sets the optional new_axis_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradNewAxisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["new_axis_mask"] = value + } +} + +// StridedSliceGradShrinkAxisMask sets the optional shrink_axis_mask attribute to value. +// If not specified, defaults to 0 +func StridedSliceGradShrinkAxisMask(value int64) StridedSliceGradAttr { + return func(m optionalAttr) { + m["shrink_axis_mask"] = value + } +} + +// Returns the gradient of `StridedSlice`. +// +// Since `StridedSlice` cuts out pieces of its `input` which is size +// `shape`, its gradient will have the same shape (which is passed here +// as `shape`). The gradient will be zero in any element that the slice +// does not select. +// +// Arguments are the same as StridedSliceGrad with the exception that +// `dy` is the input gradient to be propagated and `shape` is the +// shape of `StridedSlice`'s `input`. +func StridedSliceGrad(scope *Scope, shape tf.Output, begin tf.Output, end tf.Output, strides tf.Output, dy tf.Output, optional ...StridedSliceGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StridedSliceGrad", + Input: []tf.Input{ + shape, begin, end, strides, dy, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Return a slice from 'input'. +// +// The output tensor is a tensor with dimensions described by 'size' +// whose values are extracted from 'input' starting at the offsets in +// 'begin'. +// +// *Requirements*: +// 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) +// +// Arguments: +// +// begin: begin[i] specifies the offset into the 'i'th dimension of +// 'input' to slice from. +// size: size[i] specifies the number of elements of the 'i'th dimension +// of 'input' to slice. If size[i] is -1, all remaining elements in dimension +// i are included in the slice (i.e. this is equivalent to setting +// size[i] = input.dim_size(i) - begin[i]). +func Slice(scope *Scope, input tf.Output, begin tf.Output, size tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Slice", + Input: []tf.Input{ + input, begin, size, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. +type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value. +// +// value: The bitwidth of the quantization; between 2 and 8, inclusive. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxVarsGradientNarrowRange sets the optional narrow_range attribute to value. +// +// value: Whether to quantize into 2^num_bits - 1 distinct values. 
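// --- Illustrative sketch, not part of the diff: a minimal use of the Slice
// wrapper shown in this hunk, taking a 1x3 window out of a 3x4 matrix. The
// package and helper names are assumptions.
package sketch

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// sliceExample builds input[1:2, 0:3] with op.Slice.
func sliceExample(s *op.Scope) tf.Output {
	input := op.Const(s.SubScope("input"), [][]int32{
		{1, 2, 3, 4},
		{5, 6, 7, 8},
		{9, 10, 11, 12},
	})
	begin := op.Const(s.SubScope("begin"), []int32{1, 0})
	size := op.Const(s.SubScope("size"), []int32{1, 3}) // -1 would mean "to the end of that dimension"
	return op.Slice(s, input, begin, size)              // ==> [[5, 6, 7]]
}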
+// If not specified, defaults to false +func FakeQuantWithMinMaxVarsGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxVars operation. +// +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. +// min, max: Quantization interval, scalar floats. +// +// +// +// Returns Backpropagated gradients w.r.t. inputs: +// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter: +// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter: +// `sum(gradients * (inputs > max))`. +func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxVarsGradient", + Input: []tf.Input{ + gradients, inputs, min, max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // SizeAttr is an optional argument to Size. type SizeAttr func(optionalAttr) @@ -1631,119 +1786,37 @@ func Rank(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// ReverseSequenceAttr is an optional argument to ReverseSequence. -type ReverseSequenceAttr func(optionalAttr) +// ShapeAttr is an optional argument to Shape. +type ShapeAttr func(optionalAttr) -// ReverseSequenceBatchDim sets the optional batch_dim attribute to value. -// -// value: The dimension along which reversal is performed. -// If not specified, defaults to 0 -func ReverseSequenceBatchDim(value int64) ReverseSequenceAttr { +// ShapeOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func ShapeOutType(value tf.DataType) ShapeAttr { return func(m optionalAttr) { - m["batch_dim"] = value + m["out_type"] = value } } -// Reverses variable length slices. +// Returns the shape of a tensor. // -// This op first slices `input` along the dimension `batch_dim`, and for each -// slice `i`, reverses the first `seq_lengths[i]` elements along -// the dimension `seq_dim`. -// -// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`, -// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`. -// -// The output slice `i` along dimension `batch_dim` is then given by input -// slice `i`, with the first `seq_lengths[i]` slices along dimension -// `seq_dim` reversed. +// This operation returns a 1-D integer tensor representing the shape of `input`. // // For example: // // ``` -// # Given this: -// batch_dim = 0 -// seq_dim = 1 -// input.dims = (4, 8, ...) -// seq_lengths = [7, 2, 3, 5] -// -// # then slices of input are reversed on seq_dim, but only up to seq_lengths: -// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] -// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] -// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] -// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] -// -// # while entries past seq_lens are copied through: -// output[0, 7:, :, ...] = input[0, 7:, :, ...] -// output[1, 2:, :, ...] = input[1, 2:, :, ...] 
-// output[2, 3:, :, ...] = input[2, 3:, :, ...] -// output[3, 2:, :, ...] = input[3, 2:, :, ...] +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// shape(t) ==> [2, 2, 3] // ``` -// -// In contrast, if: -// -// ``` -// # Given this: -// batch_dim = 2 -// seq_dim = 0 -// input.dims = (8, ?, 4, ...) -// seq_lengths = [7, 2, 3, 5] -// -// # then slices of input are reversed on seq_dim, but only up to seq_lengths: -// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] -// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] -// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] -// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] -// -// # while entries past seq_lens are copied through: -// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] -// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] -// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] -// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] -// ``` -// -// Arguments: -// input: The input to reverse. -// seq_lengths: 1-D with length `input.dims(batch_dim)` and -// `max(seq_lengths) <= input.dims(seq_dim)` -// seq_dim: The dimension which is partially reversed. -// -// Returns The partially reversed input. It has the same shape as `input`. -func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_dim int64, optional ...ReverseSequenceAttr) (output tf.Output) { +func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"seq_dim": seq_dim} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ReverseSequence", - Input: []tf.Input{ - input, seq_lengths, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Ensures that the tensor's shape matches the expected shape. -// -// Raises an error if the input tensor's shape does not match the specified shape. -// Returns the input tensor otherwise. -// -// Arguments: -// input: A tensor, whose shape is to be validated. -// shape: The expected (possibly partially specified) shape of the input tensor. -// -// Returns A tensor with the same shape and contents as the input tensor or value. -func EnsureShape(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shape": shape} - opspec := tf.OpSpec{ - Type: "EnsureShape", + Type: "Shape", Input: []tf.Input{ input, }, @@ -1840,60 +1913,6 @@ func UniqueWithCountsV2(scope *Scope, x tf.Output, axis tf.Output, optional ...U return op.Output(0), op.Output(1), op.Output(2) } -// UniqueWithCountsAttr is an optional argument to UniqueWithCounts. -type UniqueWithCountsAttr func(optionalAttr) - -// UniqueWithCountsOutIdx sets the optional out_idx attribute to value. -// If not specified, defaults to DT_INT32 -func UniqueWithCountsOutIdx(value tf.DataType) UniqueWithCountsAttr { - return func(m optionalAttr) { - m["out_idx"] = value - } -} - -// Finds unique elements in a 1-D tensor. -// -// This operation returns a tensor `y` containing all of the unique elements of `x` -// sorted in the same order that they occur in `x`. This operation also returns a -// tensor `idx` the same size as `x` that contains the index of each value of `x` -// in the unique output `y`. Finally, it returns a third tensor `count` that -// contains the count of each element of `y` in `x`. 
In other words: -// -// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` -// -// For example: -// -// ``` -// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] -// y, idx, count = unique_with_counts(x) -// y ==> [1, 2, 4, 7, 8] -// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] -// count ==> [2, 1, 3, 1, 2] -// ``` -// -// Arguments: -// x: 1-D. -// -// Returns 1-D.1-D.1-D. -func UniqueWithCounts(scope *Scope, x tf.Output, optional ...UniqueWithCountsAttr) (y tf.Output, idx tf.Output, count tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UniqueWithCounts", - Input: []tf.Input{ - x, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // UniqueV2Attr is an optional argument to UniqueV2. type UniqueV2Attr func(optionalAttr) @@ -2048,6 +2067,60 @@ func ConjugateTranspose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) return op.Output(0) } +// Shuffle dimensions of x according to a permutation. +// +// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: +// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` +func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Transpose", + Input: []tf.Input{ + x, perm, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the inverse permutation of a tensor. +// +// This operation computes the inverse of an index permutation. It takes a 1-D +// integer tensor `x`, which represents the indices of a zero-based array, and +// swaps each value with its index position. In other words, for an output tensor +// `y` and an input tensor `x`, this operation computes the following: +// +// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]` +// +// The values must include 0. There can be no duplicate values or negative values. +// +// For example: +// +// ``` +// # tensor `x` is [3, 4, 0, 2, 1] +// invert_permutation(x) ==> [2, 4, 3, 0, 1] +// ``` +// +// Arguments: +// x: 1-D. +// +// Returns 1-D. +func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InvertPermutation", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Reshapes a tensor. // // Given `tensor`, this operation returns a tensor that has the same values @@ -2148,172 +2221,102 @@ func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Ou return op.Output(0) } -// Gather slices from `params` into a Tensor with shape specified by `indices`. +// Identity op for gradient debugging. // -// `indices` is an K-dimensional integer tensor, best thought of as a -// (K-1)-dimensional tensor of indices into `params`, where each element defines a -// slice of `params`: +// This op is hidden from public in Python. It is used by TensorFlow Debugger to +// register gradient tensors for gradient debugging. +// This op operates on non-reference-type tensors. +func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DebugGradientIdentity", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Gather slices from `params` axis `axis` according to `indices`. 
// -// output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]] +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `params.shape[:axis] + indices.shape + +// params.shape[axis + 1:]` where: // -// Whereas in `tf.gather` `indices` defines slices into the first -// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the -// first `N` dimensions of `params`, where `N = indices.shape[-1]`. +// ```python +// # Scalar indices (output is rank(params) - 1). +// output[a_0, ..., a_n, b_0, ..., b_n] = +// params[a_0, ..., a_n, indices, b_0, ..., b_n] // -// The last dimension of `indices` can be at most the rank of -// `params`: +// # Vector indices (output is rank(params)). +// output[a_0, ..., a_n, i, b_0, ..., b_n] = +// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] // -// indices.shape[-1] <= params.rank +// # Higher rank indices (output is rank(params) + rank(indices) - 1). +// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = +// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] +// ``` // -// The last dimension of `indices` corresponds to elements -// (if `indices.shape[-1] == params.rank`) or slices -// (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]` -// of `params`. The output tensor has shape -// -// indices.shape[:-1] + params.shape[indices.shape[-1]:] +//
+// +//
// // Note that on CPU, if an out of bound index is found, an error is returned. // On GPU, if an out of bound index is found, a 0 is stored in the // corresponding output value. // -// Some examples below. -// -// Simple indexing into a matrix: -// -// ```python -// indices = [[0, 0], [1, 1]] -// params = [['a', 'b'], ['c', 'd']] -// output = ['a', 'd'] -// ``` -// -// Slice indexing into a matrix: -// -// ```python -// indices = [[1], [0]] -// params = [['a', 'b'], ['c', 'd']] -// output = [['c', 'd'], ['a', 'b']] -// ``` -// -// Indexing into a 3-tensor: -// -// ```python -// indices = [[1]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [[['a1', 'b1'], ['c1', 'd1']]] -// -// -// indices = [[0, 1], [1, 0]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [['c0', 'd0'], ['a1', 'b1']] -// -// -// indices = [[0, 0, 1], [1, 0, 1]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = ['b0', 'b1'] -// ``` -// -// Batched indexing into a matrix: -// -// ```python -// indices = [[[0, 0]], [[0, 1]]] -// params = [['a', 'b'], ['c', 'd']] -// output = [['a'], ['b']] -// ``` -// -// Batched slice indexing into a matrix: -// -// ```python -// indices = [[[1]], [[0]]] -// params = [['a', 'b'], ['c', 'd']] -// output = [[['c', 'd']], [['a', 'b']]] -// ``` -// -// Batched indexing into a 3-tensor: -// -// ```python -// indices = [[[1]], [[0]]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [[[['a1', 'b1'], ['c1', 'd1']]], -// [[['a0', 'b0'], ['c0', 'd0']]]] -// -// indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [[['c0', 'd0'], ['a1', 'b1']], -// [['a0', 'b0'], ['c1', 'd1']]] -// -// -// indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]] -// params = [[['a0', 'b0'], ['c0', 'd0']], -// [['a1', 'b1'], ['c1', 'd1']]] -// output = [['b0', 'b1'], ['d0', 'c1']] -// ``` -// -// See also `tf.gather` and `tf.batch_gather`. +// See also `tf.batch_gather` and `tf.gather_nd`. // // Arguments: -// params: The tensor from which to gather values. -// indices: Index tensor. +// params: The tensor from which to gather values. Must be at least rank +// `axis + 1`. +// indices: Index tensor. Must be in range `[0, params.shape[axis])`. +// axis: The axis in `params` to gather `indices` from. Defaults to the first +// dimension. Supports negative indexes. // // Returns Values from `params` gathered from indices given by `indices`, with -// shape `indices.shape[:-1] + params.shape[indices.shape[-1]:]`. -func GatherNd(scope *Scope, params tf.Output, indices tf.Output) (output tf.Output) { +// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. +func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "GatherNd", + Type: "GatherV2", Input: []tf.Input{ - params, indices, + params, indices, axis, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// GatherAttr is an optional argument to Gather. -type GatherAttr func(optionalAttr) +// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. +type QuantizeAndDequantizeV3Attr func(optionalAttr) -// GatherValidateIndices sets the optional validate_indices attribute to value. +// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. 
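// --- Illustrative sketch, not part of the diff: the GatherV2 wrapper from this
// hunk used to pick rows of a matrix by index along axis 0. The package and
// function names are assumptions.
package sketch

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// gatherRows returns params rows 2 and 0, i.e. [[5, 6], [1, 2]].
func gatherRows(s *op.Scope) tf.Output {
	params := op.Const(s.SubScope("params"), [][]float32{{1, 2}, {3, 4}, {5, 6}})
	indices := op.Const(s.SubScope("indices"), []int32{2, 0})
	axis := op.Const(s.SubScope("axis"), int32(0)) // gather along the first dimension
	return op.GatherV2(s, params, indices, axis)
}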
// If not specified, defaults to true -func GatherValidateIndices(value bool) GatherAttr { +func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { return func(m optionalAttr) { - m["validate_indices"] = value + m["signed_input"] = value } } -// Gather slices from `params` according to `indices`. +// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// Quantizes then dequantizes a tensor. // -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: -// -// ```python -// # Scalar indices -// output[:, ..., :] = params[indices, :, ... :] -// -// # Vector indices -// output[i, :, ..., :] = params[indices[i], :, ... :] -// -// # Higher rank indices -// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] -// ``` -// -// If `indices` is a permutation and `len(indices) == params.shape[0]` then -// this operation will permute `params` accordingly. -// -// `validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in -// `indices` are always validated to be within range. If assigned to GPU, -// out-of-bound indices result in safe but unspecified behavior, which may include -// raising an error. -// -//
-// -//
-func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...GatherAttr) (output tf.Output) { +// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a +// tensor, so its value can change during training. +func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { if scope.Err() != nil { return } @@ -2322,67 +2325,9 @@ func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...Gathe a(attrs) } opspec := tf.OpSpec{ - Type: "Gather", + Type: "QuantizeAndDequantizeV3", Input: []tf.Input{ - params, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LowerBoundAttr is an optional argument to LowerBound. -type LowerBoundAttr func(optionalAttr) - -// LowerBoundOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func LowerBoundOutType(value tf.DataType) LowerBoundAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Applies lower_bound(sorted_search_values, values) along each row. -// -// Each set of rows with the same index in (sorted_inputs, values) is treated -// independently. The resulting row is the equivalent of calling -// `np.searchsorted(sorted_inputs, values, side='left')`. -// -// The result is not a global index to the entire -// `Tensor`, but rather just the index in the last dimension. -// -// A 2-D example: -// sorted_sequence = [[0, 3, 9, 9, 10], -// [1, 2, 3, 4, 5]] -// values = [[2, 4, 9], -// [0, 2, 6]] -// -// result = LowerBound(sorted_sequence, values) -// -// result == [[1, 2, 2], -// [0, 1, 5]] -// -// Arguments: -// sorted_inputs: 2-D Tensor where each row is ordered. -// values: 2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains -// the values that will be searched for in `sorted_search_values`. -// -// Returns A `Tensor` with the same shape as `values`. It contains the first scalar index -// into the last dimension where values can be inserted without changing the -// ordered property. -func LowerBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...LowerBoundAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LowerBound", - Input: []tf.Input{ - sorted_inputs, values, + input, input_min, input_max, num_bits, }, Attrs: attrs, } @@ -2499,95 +2444,98 @@ func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output) return op.Output(0) } -// Returns the batched diagonal part of a batched tensor. +// Copy a tensor setting everything outside a central band in each innermost matrix // -// This operation returns a tensor with the `diagonal` part -// of the batched `input`. The `diagonal` part is computed as follows: +// to zero. // +// The `band` part is computed as follows: // Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a -// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where: +// tensor with the same shape where // -// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`. +// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. // -// The input must be at least a matrix. +// The indicator function +// +// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && +// (num_upper < 0 || (n-m) <= num_upper)`. 
// // For example: // // ``` -// # 'input' is [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] +// # if 'input' is [[ 0, 1, 2, 3] +// [-1, 0, 1, 2] +// [-2, -1, 0, 1] +// [-3, -2, -1, 0]], // -// and input.shape = (2, 4, 4) +// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] +// [-1, 0, 1, 2] +// [ 0, -1, 0, 1] +// [ 0, 0, -1, 0]], // -// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]] +// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] +// [-1, 0, 1, 0] +// [-2, -1, 0, 1] +// [ 0, -2, -1, 0]] +// ``` // -// which has shape (2, 4) +// Useful special cases: +// +// ``` +// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. +// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. +// tf.matrix_band_part(input, 0, 0) ==> Diagonal. // ``` // // Arguments: -// input: Rank `k` tensor where `k >= 2`. +// input: Rank `k` tensor. +// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire +// lower triangle. +// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep +// entire upper triangle. // -// Returns The extracted diagonal(s) having shape -// `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`. -func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { +// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor. +func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixDiagPart", + Type: "MatrixBandPart", Input: []tf.Input{ - input, + input, num_lower, num_upper, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a batched diagonal tensor with a given batched diagonal values. +// Returns a diagonal tensor with a given diagonal values. // // Given a `diagonal`, this operation returns a tensor with the `diagonal` and // everything else padded with zeros. The diagonal is computed as follows: // -// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a -// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: +// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of +// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: // -// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. +// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. // // For example: // // ``` -// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] -// -// and diagonal.shape = (2, 4) -// -// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] -// -// which has shape (2, 4, 4) +// # 'diagonal' is [1, 2, 3, 4] +// tf.diag(diagonal) ==> [[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]] // ``` // // Arguments: -// diagonal: Rank `k`, where `k >= 1`. -// -// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. -func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { +// diagonal: Rank k tensor where k is at most 1. 
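// --- Illustrative sketch, not part of the diff: the MatrixBandPart wrapper
// above used to keep only the lower triangle of a matrix (num_lower = -1,
// num_upper = 0, one of the special cases listed in its doc). Names are
// assumptions.
package sketch

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// lowerTriangle zeroes everything above the main diagonal of m.
func lowerTriangle(s *op.Scope, m tf.Output) tf.Output {
	numLower := op.Const(s.SubScope("num_lower"), int64(-1)) // negative: keep the entire lower triangle
	numUpper := op.Const(s.SubScope("num_upper"), int64(0))  // keep no superdiagonals
	return op.MatrixBandPart(s, m, numLower, numUpper)
}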
+func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixDiag", + Type: "Diag", Input: []tf.Input{ diagonal, }, @@ -2596,121 +2544,20 @@ func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { return op.Output(0) } -// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. -type QuantizedInstanceNormAttr func(optionalAttr) - -// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value. -// -// value: If True, `given_y_min` and `given_y_min` -// and `given_y_max` are used as the output range. Otherwise, -// the implementation computes the output range. -// If not specified, defaults to false -func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["output_range_given"] = value - } -} - -// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value. -// -// value: Output in `y_min` if `output_range_given` is True. -// If not specified, defaults to 0 -func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["given_y_min"] = value - } -} - -// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value. -// -// value: Output in `y_max` if `output_range_given` is True. -// If not specified, defaults to 0 -func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["given_y_max"] = value - } -} - -// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value. -// -// value: A small float number to avoid dividing by 0. -// If not specified, defaults to 1e-05 -func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["variance_epsilon"] = value - } -} - -// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value. -// -// value: Minimum value of `y_max - y_min` -// If not specified, defaults to 0.001 -func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr { - return func(m optionalAttr) { - m["min_separation"] = value - } -} - -// Quantized Instance normalization. +// Returns a tensor of zeros with the same shape and type as x. // // Arguments: -// x: A 4D input Tensor. -// x_min: The value represented by the lowest quantized input. -// x_max: The value represented by the highest quantized input. +// x: a tensor of type T. // -// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output. -func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedInstanceNorm", - Input: []tf.Input{ - x, x_min, x_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Returns the diagonal part of the tensor. -// -// This operation returns a tensor with the `diagonal` part -// of the `input`. 
The `diagonal` part is computed as follows: -// -// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a -// tensor of rank `k` with dimensions `[D1,..., Dk]` where: -// -// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. -// -// For example: -// -// ``` -// # 'input' is [[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]] -// -// tf.diag_part(input) ==> [1, 2, 3, 4] -// ``` -// -// Arguments: -// input: Rank k tensor where k is even and not zero. -// -// Returns The extracted diagonal. -func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { +// Returns a tensor of the same shape and type as x but filled with zeros. +func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DiagPart", + Type: "ZerosLike", Input: []tf.Input{ - input, + x, }, } op := scope.AddOperation(opspec) @@ -2739,6 +2586,29 @@ func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// Returns immutable tensor from memory region. +// +// The current implementation memmaps the tensor from a file. +// +// Arguments: +// dtype: Type of the returned tensor. +// shape: Shape of the returned tensor. +// memory_region_name: Name of readonly memory region used by the tensor, see +// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. +func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} + opspec := tf.OpSpec{ + Type: "ImmutableConst", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns a constant tensor on the host. Only for writing C++ tests. // // Arguments: @@ -2758,45 +2628,6 @@ func HostConst(scope *Scope, value tf.Tensor, dtype tf.DataType) (output tf.Outp return op.Output(0) } -// Splits a tensor into `num_split` tensors along one dimension. -// -// Arguments: -// value: The tensor to split. -// size_splits: list containing the sizes of each output tensor along the split -// dimension. Must sum to the dimension of value along split_dim. -// Can contain one -1 indicating that dimension is to be inferred. -// axis: 0-D. The dimension along which to split. Must be in the range -// `[-rank(value), rank(value))`. -// -// -// Returns Tensors whose shape matches that of `value` -// except along `axis`, where their sizes are -// `size_splits[i]`. -func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_split": num_split} - opspec := tf.OpSpec{ - Type: "SplitV", - Input: []tf.Input{ - value, size_splits, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("SplitV", err) - return - } - return output -} - // Splits a tensor into `num_split` tensors along one dimension. // // Arguments: @@ -2834,31 +2665,6 @@ func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (outp return output } -// Concatenates tensors along one dimension. -// -// Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). 
-// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Concat", - Input: []tf.Input{ - concat_dim, tf.OutputList(values), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Broadcast an array for a compatible shape. // // Broadcasting is the process of making arrays to have compatible shapes @@ -2868,7 +2674,8 @@ func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.O // and works its way forward. // // For example, -// ``` +// +// ```python // >>> x = tf.constant([1, 2, 3]) // >>> y = tf.broadcast_to(x, [3, 3]) // >>> sess.run(y) @@ -2876,6 +2683,7 @@ func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.O // [1, 2, 3], // [1, 2, 3]], dtype=int32) // ``` +// // In the above example, the input Tensor with the shape of `[1, 3]` // is broadcasted to output Tensor with shape of `[3, 3]`. // @@ -2928,6 +2736,29 @@ func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Ou return op.Output(0) } +// A placeholder op that passes through `input` when its output is not fed. +// +// Arguments: +// input: The default value to produce when `output` is not fed. +// shape: The (possibly partial) shape of the tensor. +// +// Returns A placeholder tensor that defaults to `input` if it is not fed. +func PlaceholderWithDefault(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape": shape} + opspec := tf.OpSpec{ + Type: "PlaceholderWithDefault", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Subtracts `v` into specified rows of `x`. // // Computes y = x; y[i, :] -= v; return y. @@ -2952,6 +2783,30 @@ func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Outpu return op.Output(0) } +// Updates specified rows with values in `v`. +// +// Computes `x[i, :] = v; return x`. +// +// Arguments: +// x: A tensor of type `T`. +// i: A vector. Indices into the left-most dimension of `x`. +// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. +// +// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. +func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InplaceUpdate", + Input: []tf.Input{ + x, i, v, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Makes a copy of `x`. // // Arguments: @@ -3032,41 +2887,67 @@ func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Out return op.Output(0) } -// Concatenates a list of `N` tensors along the first dimension. +// AudioSpectrogramAttr is an optional argument to AudioSpectrogram. +type AudioSpectrogramAttr func(optionalAttr) + +// AudioSpectrogramMagnitudeSquared sets the optional magnitude_squared attribute to value. 
// -// The input tensors are all required to have size 1 in the first dimension. +// value: Whether to return the squared magnitude or just the +// magnitude. Using squared magnitude can avoid extra calculations. +// If not specified, defaults to false +func AudioSpectrogramMagnitudeSquared(value bool) AudioSpectrogramAttr { + return func(m optionalAttr) { + m["magnitude_squared"] = value + } +} + +// Produces a visualization of audio data over time. // -// For example: +// Spectrograms are a standard way of representing audio information as a series of +// slices of frequency information, one slice for each window of time. By joining +// these together into a sequence, they form a distinctive fingerprint of the sound +// over time. // -// ``` -// # 'x' is [[1, 4]] -// # 'y' is [[2, 5]] -// # 'z' is [[3, 6]] -// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -// ``` +// This op expects to receive audio data as an input, stored as floats in the range +// -1 to 1, together with a window width in samples, and a stride specifying how +// far to move the window between slices. From this it generates a three +// dimensional output. The first dimension is for the channels in the input, so a +// stereo audio input would have two here for example. The second dimension is time, +// with successive frequency slices. The third dimension has an amplitude value for +// each frequency during that time slice. // -// The difference between concat and parallel_concat is that concat requires all -// of the inputs be computed before the operation will begin but doesn't require -// that the input shapes be known during graph construction. Parallel concat -// will copy pieces of the input into the output as they become available, in -// some situations this can provide a performance benefit. +// This means the layout when converted and saved as an image is rotated 90 degrees +// clockwise from a typical spectrogram. Time is descending down the Y axis, and +// the frequency decreases from left to right. +// +// Each value in the result represents the square root of the sum of the real and +// imaginary parts of an FFT on the current window of samples. In this way, the +// lowest dimension represents the power of each frequency in the current window, +// and adjacent windows are concatenated in the next dimension. +// +// To get a more intuitive and visual look at what this operation does, you can run +// tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the +// resulting spectrogram as a PNG image. // // Arguments: -// values: Tensors to be concatenated. All must have size 1 in the first dimension -// and same shape. -// shape: the final shape of the result; should be equal to the shapes of any input -// but with the number of input values in the first dimension. +// input: Float representation of audio data. +// window_size: How wide the input window is in samples. For the highest efficiency +// this should be a power of two, but other values are accepted. +// stride: How widely apart the center of adjacent sample windows should be. // -// Returns The concatenated tensor. -func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { +// Returns 3D representation of the audio frequencies as an image. 
+func AudioSpectrogram(scope *Scope, input tf.Output, window_size int64, stride int64, optional ...AudioSpectrogramAttr) (spectrogram tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} + attrs := map[string]interface{}{"window_size": window_size, "stride": stride} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ParallelConcat", + Type: "AudioSpectrogram", Input: []tf.Input{ - tf.OutputList(values), + input, }, Attrs: attrs, } @@ -3074,122 +2955,30 @@ func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf return op.Output(0) } -// DecodeWavAttr is an optional argument to DecodeWav. -type DecodeWavAttr func(optionalAttr) - -// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. +// Encode audio data using the WAV file format. // -// value: Number of sample channels wanted. -// If not specified, defaults to -1 -func DecodeWavDesiredChannels(value int64) DecodeWavAttr { - return func(m optionalAttr) { - m["desired_channels"] = value - } -} - -// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. +// This operation will generate a string suitable to be saved out to create a .wav +// audio file. It will be encoded in the 16-bit PCM format. It takes in float +// values in the range -1.0f to 1.0f, and any outside that value will be clamped to +// that range. // -// value: Length of audio requested. -// If not specified, defaults to -1 -func DecodeWavDesiredSamples(value int64) DecodeWavAttr { - return func(m optionalAttr) { - m["desired_samples"] = value - } -} - -// Decode a 16-bit PCM WAV file to a float tensor. -// -// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. -// -// When desired_channels is set, if the input contains fewer channels than this -// then the last channel will be duplicated to give the requested number, else if -// the input has more channels than requested then the additional channels will be -// ignored. -// -// If desired_samples is set, then the audio will be cropped or padded with zeroes -// to the requested length. -// -// The first output contains a Tensor with the content of the audio samples. The -// lowest dimension will be the number of channels, and the second will be the -// number of samples. For example, a ten-sample-long stereo WAV file should give an -// output shape of [10, 2]. +// `audio` is a 2-D float Tensor of shape `[length, channels]`. +// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). // // Arguments: -// contents: The WAV-encoded audio, usually from a file. +// audio: 2-D with shape `[length, channels]`. +// sample_rate: Scalar containing the sample frequency. // -// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. -func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { +// Returns 0-D. WAV-encoded file contents. +func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "DecodeWav", + Type: "EncodeWav", Input: []tf.Input{ - contents, + audio, sample_rate, }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// UnbatchAttr is an optional argument to Unbatch. 
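// --- Illustrative sketch, not part of the diff: chaining the EncodeWav and
// AudioSpectrogram wrappers from this region. It assumes audio laid out as
// [length, channels] floats in -1.0..1.0 (the layout described for the WAV
// ops); the window/stride values and names are arbitrary assumptions.
package sketch

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func audioExample(s *op.Scope) (wav, spectrogram tf.Output) {
	samples := make([][]float32, 16000) // one second of mono silence, shape [16000, 1]
	for i := range samples {
		samples[i] = []float32{0}
	}
	audio := op.Const(s.SubScope("audio"), samples)
	sampleRate := op.Const(s.SubScope("sample_rate"), int32(16000))

	wav = op.EncodeWav(s, audio, sampleRate) // 0-D string tensor holding the WAV file contents
	spectrogram = op.AudioSpectrogram(s, audio, 512, 256, // window_size and stride in samples
		op.AudioSpectrogramMagnitudeSquared(true))
	return wav, spectrogram
}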
-type UnbatchAttr func(optionalAttr) - -// UnbatchContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnbatchContainer(value string) UnbatchAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// UnbatchSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnbatchSharedName(value string) UnbatchAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Reverses the operation of Batch for a single output Tensor. -// -// An instance of Unbatch either receives an empty batched_tensor, in which case it -// asynchronously waits until the values become available from a concurrently -// running instance of Unbatch with the same container and shared_name, or receives -// a non-empty batched_tensor in which case it finalizes all other concurrently -// running instances and outputs its own element from the batch. -// -// batched_tensor: The possibly transformed output of Batch. The size of the first -// dimension should remain unchanged by the transformations for the operation to -// work. -// batch_index: The matching batch_index obtained from Batch. -// id: The id scalar emitted by Batch. -// unbatched_tensor: The Tensor corresponding to this execution. -// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the -// batched input tensor associated with a given invocation of the op. -// container: Container to control resource sharing. -// shared_name: Instances of Unbatch with the same container and shared_name are -// assumed to possibly belong to the same batch. If left empty, the op name will -// be used as the shared name. -func Unbatch(scope *Scope, batched_tensor tf.Output, batch_index tf.Output, id tf.Output, timeout_micros int64, optional ...UnbatchAttr) (unbatched_tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"timeout_micros": timeout_micros} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Unbatch", - Input: []tf.Input{ - batched_tensor, batch_index, id, - }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -3213,16 +3002,16 @@ func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Elementwise computes the bitwise XOR of `x` and `y`. +// Elementwise computes the bitwise AND of `x` and `y`. // -// The result will have those bits set, that are different in `x` and `y`. The +// The result will have those bits set, that are set in both `x` and `y`. The // computation is performed on the underlying representations of `x` and `y`. -func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BitwiseXor", + Type: "BitwiseAnd", Input: []tf.Input{ x, y, }, @@ -3253,38 +3042,22 @@ func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Bucketize each feature based on bucket boundaries. +// Flips all bits elementwise. // -// An op that returns a list of float tensors, where each tensor represents the -// bucketized values for a single feature. -// -// Arguments: -// float_values: float; List of Rank 1 Tensor each containing float values for a single feature. -// bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a single -// feature. 
-// -// Returns int; List of Rank 1 Tensors each containing the bucketized values for a single feature. -func BoostedTreesBucketize(scope *Scope, float_values []tf.Output, bucket_boundaries []tf.Output) (buckets []tf.Output) { +// The result will have exactly those bits set, that are not set in `x`. The +// computation is performed on the underlying representation of x. +func Invert(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesBucketize", + Type: "Invert", Input: []tf.Input{ - tf.OutputList(float_values), tf.OutputList(bucket_boundaries), + x, }, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if buckets, idx, err = makeOutputList(op, idx, "buckets"); err != nil { - scope.UpdateErr("BoostedTreesBucketize", err) - return - } - return buckets + return op.Output(0) } // BoostedTreesQuantileStreamResourceFlushAttr is an optional argument to BoostedTreesQuantileStreamResourceFlush. @@ -3332,6 +3105,28 @@ func BoostedTreesQuantileStreamResourceFlush(scope *Scope, quantile_stream_resou return scope.AddOperation(opspec) } +// Deserialize bucket boundaries and ready flag into current QuantileAccumulator. +// +// An op that deserializes bucket boundaries and are boundaries ready flag into current QuantileAccumulator. +// +// Arguments: +// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. +// bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a feature. +// +// Returns the created operation. +func BoostedTreesQuantileStreamResourceDeserialize(scope *Scope, quantile_stream_resource_handle tf.Output, bucket_boundaries []tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesQuantileStreamResourceDeserialize", + Input: []tf.Input{ + quantile_stream_resource_handle, tf.OutputList(bucket_boundaries), + }, + } + return scope.AddOperation(opspec) +} + // Makes the summary of quantiles for the batch. // // An op that takes a list of tensors (one tensor per feature) and outputs the @@ -3367,45 +3162,6 @@ func BoostedTreesMakeQuantileSummaries(scope *Scope, float_values []tf.Output, e return summaries } -// BoostedTreesCreateQuantileStreamResourceAttr is an optional argument to BoostedTreesCreateQuantileStreamResource. -type BoostedTreesCreateQuantileStreamResourceAttr func(optionalAttr) - -// BoostedTreesCreateQuantileStreamResourceMaxElements sets the optional max_elements attribute to value. -// -// value: int; The maximum number of data points that can be fed to the stream. -// If not specified, defaults to 1099511627776 -func BoostedTreesCreateQuantileStreamResourceMaxElements(value int64) BoostedTreesCreateQuantileStreamResourceAttr { - return func(m optionalAttr) { - m["max_elements"] = value - } -} - -// Create the Resource for Quantile Streams. -// -// Arguments: -// quantile_stream_resource_handle: resource; Handle to quantile stream resource. -// epsilon: float; The required approximation error of the stream resource. -// num_streams: int; The number of streams managed by the resource that shares the same epsilon. -// -// Returns the created operation. 
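// --- Illustrative sketch, not part of the diff: combining the bitwise
// wrappers touched in this region (BitwiseAnd, Invert, LeftShift,
// PopulationCount) on small uint8 constants. Names are assumptions.
package sketch

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func bitwiseExample(s *op.Scope) (and, flipped, shifted, bits tf.Output) {
	x := op.Const(s.SubScope("x"), []uint8{0b1100, 0b1010})
	y := op.Const(s.SubScope("y"), []uint8{0b1010, 0b0110})

	and = op.BitwiseAnd(s, x, y)    // ==> [0b1000, 0b0010]
	flipped = op.Invert(s, x)       // bitwise NOT of each element
	shifted = op.LeftShift(s, x, y) // x << y, elementwise
	bits = op.PopulationCount(s, x) // number of set bits per element
	return and, flipped, shifted, bits
}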
-func BoostedTreesCreateQuantileStreamResource(scope *Scope, quantile_stream_resource_handle tf.Output, epsilon tf.Output, num_streams tf.Output, optional ...BoostedTreesCreateQuantileStreamResourceAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "BoostedTreesCreateQuantileStreamResource", - Input: []tf.Input{ - quantile_stream_resource_handle, epsilon, num_streams, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // Checks whether a quantile stream has been initialized. // // An Op that checks if quantile stream resource is initialized. @@ -3452,59 +3208,103 @@ func BoostedTreesCenterBias(scope *Scope, tree_ensemble_handle tf.Output, mean_g return op.Output(0) } -// Runs multiple additive regression ensemble predictors on input instances and +// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. +type FakeQuantWithMinMaxVarsAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` // -// computes the update to cached logits. It is designed to be used during training. -// It traverses the trees starting from cached tree id and cached node id and -// calculates the updates to be pushed to the cache. +// and `max` to 'outputs' tensor of same shape as `inputs`. // -// Arguments: +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. // -// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting -// tree of prediction. -// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting -// node of prediction. -// bucketized_features: A list of rank 1 Tensors containing bucket id for each -// feature. -// logits_dimension: scalar, dimension of the logits, to be used for partial logits -// shape. +// Before quantization, `min` and `max` values are adjusted with the following +// logic. +// It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, +// the behavior can be unexpected: +// If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. +// If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. +// If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, +// `min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. // -// Returns Rank 2 Tensor containing logits update (with respect to cached -// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids. 
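To make the quantize/de-quantize behaviour described above for `FakeQuantWithMinMaxVars` concrete, a small hedged sketch; the input values and the 8-bit setting are arbitrary, and the wrapper and its optional-attribute helpers are the generated ones in this file.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	inputs := op.Const(s, []float32{-0.7, 0.0, 0.3, 1.2})
	minVal := op.Const(s, float32(-0.5))
	maxVal := op.Const(s, float32(1.0))

	// Quantize into the 8-bit range over the (adjusted) [min, max] interval,
	// then de-quantize back to float, as described in the doc comment above.
	q := op.FakeQuantWithMinMaxVars(s, inputs, minVal, maxVal,
		op.FakeQuantWithMinMaxVarsNumBits(8))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{q}, nil)
	if err != nil {
		panic(err)
	}
	// Values are clamped to the adjusted range and snapped to the quantization grid.
	fmt.Println(out[0].Value())
}
```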
-func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) { +// This operation has a gradient and thus allows for training `min` and `max` +// values. +func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"logits_dimension": logits_dimension} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BoostedTreesTrainingPredict", + Type: "FakeQuantWithMinMaxVars", Input: []tf.Input{ - tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features), + inputs, min, max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Serializes the tree ensemble to a proto. +// Updates the tree ensemble by either adding a layer to the last tree being grown +// +// or by starting a new tree. // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. +// tree_ensemble_handle: Handle to the ensemble variable. +// feature_ids: Rank 1 tensor with ids for each feature. This is the real id of +// the feature that will be used in the split. +// node_ids: List of rank 1 tensors representing the nodes for which this feature +// has a split. +// gains: List of rank 1 tensors representing the gains for each of the feature's +// split. +// thresholds: List of rank 1 tensors representing the thesholds for each of the +// feature's split. +// left_node_contribs: List of rank 2 tensors with left leaf contribs for each of +// the feature's splits. Will be added to the previous node values to constitute +// the values of the left nodes. +// right_node_contribs: List of rank 2 tensors with right leaf contribs for each +// of the feature's splits. Will be added to the previous node values to constitute +// the values of the right nodes. +// max_depth: Max depth of the tree to build. +// learning_rate: shrinkage const for each new tree. +// pruning_mode: 0-No pruning, 1-Pre-pruning, 2-Post-pruning. // -// Returns Stamp token of the tree ensemble resource.Serialized proto of the ensemble. -func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) { +// Returns the created operation. 
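Before the generated wrapper itself (which follows), a graph-construction-only sketch of how `BoostedTreesUpdateEnsemble` might be wired up. All numeric values are placeholders, the ensemble handle is created but not initialized (a real pipeline would run `BoostedTreesCreateEnsemble` first), and no session is executed.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Resource handle for the ensemble; initialization is omitted in this sketch.
	ensemble := op.BoostedTreesEnsembleResourceHandleOp(s)

	// One feature with a single candidate split; values are placeholders only.
	featureIDs := op.Const(s, []int32{0})
	nodeIDs := []tf.Output{op.Const(s, []int32{0})}
	gains := []tf.Output{op.Const(s, []float32{1.5})}
	thresholds := []tf.Output{op.Const(s, []int32{2})}
	left := []tf.Output{op.Const(s, [][]float32{{-0.1}})}
	right := []tf.Output{op.Const(s, [][]float32{{0.2}})}
	maxDepth := op.Const(s, int32(3))
	learningRate := op.Const(s, float32(0.1))

	// pruning_mode 0 = no pruning.
	update := op.BoostedTreesUpdateEnsemble(s, ensemble, featureIDs, nodeIDs, gains,
		thresholds, left, right, maxDepth, learningRate, 0)

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	fmt.Println("graph contains op:", update.Name())
}
```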
+func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, feature_ids tf.Output, node_ids []tf.Output, gains []tf.Output, thresholds []tf.Output, left_node_contribs []tf.Output, right_node_contribs []tf.Output, max_depth tf.Output, learning_rate tf.Output, pruning_mode int64) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"pruning_mode": pruning_mode} opspec := tf.OpSpec{ - Type: "BoostedTreesSerializeEnsemble", + Type: "BoostedTreesUpdateEnsemble", Input: []tf.Input{ - tree_ensemble_handle, + tree_ensemble_handle, feature_ids, tf.OutputList(node_ids), tf.OutputList(gains), tf.OutputList(thresholds), tf.OutputList(left_node_contribs), tf.OutputList(right_node_contribs), max_depth, learning_rate, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } // Debugging/model interpretability outputs for each example. @@ -3537,28 +3337,46 @@ func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Outpu return op.Output(0) } -// Makes the summary of accumulated stats for the batch. +// Return the reduction indices for computing gradients of s0 op s1 with broadcast. // -// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example. +// This is typically used by gradient computations for a broadcasting operation. +func BroadcastGradientArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output, r1 tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BroadcastGradientArgs", + Input: []tf.Input{ + s0, s1, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Aggregates the summary of accumulated stats for the batch. +// +// The summary stats contains gradients and hessians accumulated for each node, feature dimension id and bucket. // // Arguments: -// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer. -// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients. -// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians. -// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column). +// node_ids: int32; Rank 1 Tensor containing node ids for each example, shape [batch_size]. +// gradients: float32; Rank 2 Tensor (shape=[batch_size, logits_dimension]) with gradients for each example. +// hessians: float32; Rank 2 Tensor (shape=[batch_size, hessian_dimension]) with hessians for each example. +// feature: int32; Rank 2 feature Tensors (shape=[batch_size, feature_dimension]). // max_splits: int; the maximum number of splits possible in the whole tree. // num_buckets: int; equals to the maximum possible value of bucketized feature. // -// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians. -func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) { +// Returns output Rank 4 Tensor (shape=[splits, feature_dimension, buckets, logits_dimension + hessian_dimension]) +// containing accumulated stats for each node, feature dimension and bucket. 
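A runnable sketch for `BoostedTreesAggregateStats`, whose wrapper follows: four examples spread over two nodes, a single logit/hessian dimension, and one bucketized feature column. The specific numbers are arbitrary.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	nodeIDs := op.Const(s, []int32{0, 0, 1, 1})                        // node each example falls into
	gradients := op.Const(s, [][]float32{{0.5}, {-0.2}, {1.0}, {0.3}}) // [batch, logits_dimension]
	hessians := op.Const(s, [][]float32{{1.0}, {1.0}, {1.0}, {1.0}})   // [batch, hessian_dimension]
	feature := op.Const(s, [][]int32{{0}, {2}, {1}, {2}})              // bucket ids, [batch, feature_dimension]

	// max_splits=4, num_buckets=3 (the bucket ids above stay below num_buckets).
	stats := op.BoostedTreesAggregateStats(s, nodeIDs, gradients, hessians, feature, 4, 3)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{stats}, nil)
	if err != nil {
		panic(err)
	}
	// The result is the Rank 4 stats summary described above.
	fmt.Println(out[0].Shape(), out[0].Value())
}
```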
+func BoostedTreesAggregateStats(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, feature tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "BoostedTreesMakeStatsSummary", + Type: "BoostedTreesAggregateStats", Input: []tf.Input{ - node_ids, gradients, hessians, tf.OutputList(bucketized_features_list), + node_ids, gradients, hessians, feature, }, Attrs: attrs, } @@ -3566,6 +3384,50 @@ func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf return op.Output(0) } +// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. +// +// Arguments: +// tree_ensemble_handle: Handle to the tree ensemble. +// +// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest +// layer. +func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesGetEnsembleStates", + Input: []tf.Input{ + tree_ensemble_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + +// Deserializes a serialized tree ensemble config and replaces current tree +// +// ensemble. +// +// Arguments: +// tree_ensemble_handle: Handle to the tree ensemble. +// stamp_token: Token to use as the new value of the resource stamp. +// tree_ensemble_serialized: Serialized proto of the ensemble. +// +// Returns the created operation. +func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesDeserializeEnsemble", + Input: []tf.Input{ + tree_ensemble_handle, stamp_token, tree_ensemble_serialized, + }, + } + return scope.AddOperation(opspec) +} + // Creates a tree ensemble model and returns a handle to it. // // Arguments: @@ -3587,6 +3449,59 @@ func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, st return scope.AddOperation(opspec) } +// BoostedTreesCalculateBestFeatureSplitAttr is an optional argument to BoostedTreesCalculateBestFeatureSplit. +type BoostedTreesCalculateBestFeatureSplitAttr func(optionalAttr) + +// BoostedTreesCalculateBestFeatureSplitSplitType sets the optional split_type attribute to value. +// +// value: A string indicating if this Op should perform inequality split or equality split. +// If not specified, defaults to "inequality" +func BoostedTreesCalculateBestFeatureSplitSplitType(value string) BoostedTreesCalculateBestFeatureSplitAttr { + return func(m optionalAttr) { + m["split_type"] = value + } +} + +// Calculates gains for each feature and returns the best possible split information for the feature. +// +// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature. +// +// It is possible that not all nodes can be split on each feature. 
Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that can be split using this feature.
+//
+// In this manner, the output is the best split per feature and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
+//
+// The output shapes are compatible in a way that the first dimension of all tensors is the same and equal to the number of possible split nodes for each feature.
+//
+// Arguments:
+// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as in `for node_id in range(node_id_range[0], node_id_range[1])` (note that the last index node_id_range[1] is exclusive).
+// stats_summary: A Rank 4 tensor (#shape=[max_splits, feature_dims, bucket, stats_dims]) for accumulated stats summary (gradient/hessian) per node, per dimension, per bucket for each feature.
+// The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
+// l1: l1 regularization factor on leaf weights, per instance based.
+// l2: l2 regularization factor on leaf weights, per instance based.
+// tree_complexity: adjustment to the gain, per leaf based.
+// min_node_weight: minimum avg of hessians in a node before the node is considered for splitting.
+// logits_dimension: The dimension of logit, i.e., number of classes.
+//
+// Returns A Rank 1 tensor indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has a different size as each feature provides different possible nodes. See above for details like shapes and sizes. A Rank 1 tensor indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes. A Rank 1 tensor indicating the best feature dimension for each feature to split for certain nodes if the feature is multi-dimensional. See above for details like shapes and sizes. A Rank 1 tensor indicating the bucket id to compare with (as a threshold) for the split in each node. See above for details like shapes and sizes. A Rank 2 tensor indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes. A Rank 2 tensor, with the same shape/conditions as left_node_contribs_list, but with the values for the right node. A Rank 1 tensor indicating which direction to go if data is missing. See above for details like shapes and sizes.
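As a rough, graph-construction-only sketch of calling the wrapper below: shapes follow the description above with max_splits=2, one feature dimension, three buckets, and stats_dims = 1 logit + 1 hessian. The values are placeholders and nothing is run.

```go
package main

import (
	"fmt"

	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Process node ids in [0, 1).
	nodeIDRange := op.Const(s, []int32{0, 1})

	// [max_splits=2, feature_dims=1, buckets=3, stats_dims=2]; placeholder numbers.
	statsSummary := op.Const(s, [][][][]float32{
		{{{1.0, 2.0}, {0.5, 1.0}, {0.2, 0.4}}},
		{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
	})

	l1 := op.Const(s, float32(0))
	l2 := op.Const(s, float32(0))
	treeComplexity := op.Const(s, float32(0))
	minNodeWeight := op.Const(s, float32(0))

	nodeIDs, gains, featureDims, thresholds, left, right, defaultDirs :=
		op.BoostedTreesCalculateBestFeatureSplit(s, nodeIDRange, statsSummary,
			l1, l2, treeComplexity, minNodeWeight, 1,
			op.BoostedTreesCalculateBestFeatureSplitSplitType("inequality"))

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	// A session run would fetch the seven outputs declared above.
	fmt.Println(nodeIDs, gains, featureDims, thresholds, left, right, defaultDirs)
}
```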
+func BoostedTreesCalculateBestFeatureSplit(scope *Scope, node_id_range tf.Output, stats_summary tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, logits_dimension int64, optional ...BoostedTreesCalculateBestFeatureSplitAttr) (node_ids tf.Output, gains tf.Output, feature_dimensions tf.Output, thresholds tf.Output, left_node_contribs tf.Output, right_node_contribs tf.Output, split_with_default_directions tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"logits_dimension": logits_dimension} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "BoostedTreesCalculateBestFeatureSplit", + Input: []tf.Input{ + node_id_range, stats_summary, l1, l2, tree_complexity, min_node_weight, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) +} + // Checks whether a tree ensemble has been initialized. // // Arguments: @@ -3644,30 +3559,6 @@ func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTrees return op.Output(0) } -// Output the logits for the given input data -// -// Arguments: -// tree_handle: Handle to the tree resource. -// dense_features: Rank 2 dense features tensor. -// logits_dimension: Scalar, dimension of the logits. -// -// Returns The logits predictions from the tree for each instance in the batch. -func TensorForestTreePredict(scope *Scope, tree_handle tf.Output, dense_features tf.Output, logits_dimension int64) (logits tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"logits_dimension": logits_dimension} - opspec := tf.OpSpec{ - Type: "TensorForestTreePredict", - Input: []tf.Input{ - tree_handle, dense_features, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Get the number of nodes in a tree // // Arguments: @@ -3688,6 +3579,26 @@ func TensorForestTreeSize(scope *Scope, tree_handle tf.Output) (tree_size tf.Out return op.Output(0) } +// Deserializes a proto into the tree handle +// +// Arguments: +// tree_handle: Handle to the tree resource to be restored. +// tree_config: Serialied proto string of the boosted_trees.Tree proto. +// +// Returns the created operation. +func TensorForestTreeDeserialize(scope *Scope, tree_handle tf.Output, tree_config tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorForestTreeDeserialize", + Input: []tf.Input{ + tree_handle, tree_config, + }, + } + return scope.AddOperation(opspec) +} + // Creates a tree resource and returns a handle to it. // // Arguments: @@ -3708,6 +3619,67 @@ func TensorForestCreateTreeVariable(scope *Scope, tree_handle tf.Output, tree_co return scope.AddOperation(opspec) } +// TensorForestTreeResourceHandleOpAttr is an optional argument to TensorForestTreeResourceHandleOp. +type TensorForestTreeResourceHandleOpAttr func(optionalAttr) + +// TensorForestTreeResourceHandleOpContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func TensorForestTreeResourceHandleOpContainer(value string) TensorForestTreeResourceHandleOpAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// TensorForestTreeResourceHandleOpSharedName sets the optional shared_name attribute to value. 
+// If not specified, defaults to "" +func TensorForestTreeResourceHandleOpSharedName(value string) TensorForestTreeResourceHandleOpAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a handle to a TensorForestTreeResource +func TensorForestTreeResourceHandleOp(scope *Scope, optional ...TensorForestTreeResourceHandleOpAttr) (resource tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorForestTreeResourceHandleOp", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds v into specified rows of x. +// +// Computes y = x; y[i, :] += v; return y. +// +// Arguments: +// x: A `Tensor` of type T. +// i: A vector. Indices into the left-most dimension of `x`. +// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. +// +// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. +func InplaceAdd(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InplaceAdd", + Input: []tf.Input{ + x, i, v, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. type ComputeAccidentalHitsAttr func(optionalAttr) @@ -3767,6 +3739,78 @@ func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candida return op.Output(0), op.Output(1), op.Output(2) } +// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. +type AllCandidateSamplerAttr func(optionalAttr) + +// AllCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Generates labels for candidate sampling with a learned unigram distribution. +// +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. +// +// Arguments: +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to produce. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. 
+// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AllCandidateSampler", + Input: []tf.Input{ + true_classes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler. type FixedUnigramCandidateSamplerAttr func(optionalAttr) @@ -3926,32 +3970,32 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true return op.Output(0), op.Output(1), op.Output(2) } -// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler. -type LogUniformCandidateSamplerAttr func(optionalAttr) +// ThreadUnsafeUnigramCandidateSamplerAttr is an optional argument to ThreadUnsafeUnigramCandidateSampler. +type ThreadUnsafeUnigramCandidateSamplerAttr func(optionalAttr) -// LogUniformCandidateSamplerSeed sets the optional seed attribute to value. +// ThreadUnsafeUnigramCandidateSamplerSeed sets the optional seed attribute to value. // // value: If either seed or seed2 are set to be non-zero, the random number // generator is seeded by the given seed. Otherwise, it is seeded by a // random seed. // If not specified, defaults to 0 -func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr { +func ThreadUnsafeUnigramCandidateSamplerSeed(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { return func(m optionalAttr) { m["seed"] = value } } -// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// ThreadUnsafeUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. // // value: An second seed to avoid seed collision. // If not specified, defaults to 0 -func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr { +func ThreadUnsafeUnigramCandidateSamplerSeed2(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { return func(m optionalAttr) { m["seed2"] = value } } -// Generates labels for candidate sampling with a log-uniform distribution. +// Generates labels for candidate sampling with a learned unigram distribution. // // See explanations of candidate sampling and the data formats at // go/candidate-sampling. @@ -3980,7 +4024,7 @@ func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr // candidate representing the number of times the candidate is expected // to occur in a batch of sampled candidates. If unique=true, then this is a // probability. 
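Stepping back to `InplaceAdd`, which is added earlier in this hunk, a small runnable sketch with arbitrary values:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	x := op.Const(s, [][]float32{{1, 1}, {2, 2}, {3, 3}})
	i := op.Const(s, []int32{0, 2})                   // rows of x to update
	v := op.Const(s, [][]float32{{10, 10}, {30, 30}}) // same trailing shape as x, first dim == len(i)

	// y = x; y[i, :] += v
	y := op.InplaceAdd(s, x, i, v)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{y}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [[11 11] [2 2] [33 33]]
}
```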
-func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +func ThreadUnsafeUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...ThreadUnsafeUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } @@ -3989,7 +4033,80 @@ func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true i a(attrs) } opspec := tf.OpSpec{ - Type: "LogUniformCandidateSampler", + Type: "ThreadUnsafeUnigramCandidateSampler", + Input: []tf.Input{ + true_classes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. +type LearnedUnigramCandidateSamplerAttr func(optionalAttr) + +// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Generates labels for candidate sampling with a learned unigram distribution. +// +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. +// +// Arguments: +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). +// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. 
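A runnable sketch for `LearnedUnigramCandidateSampler`, whose wrapper follows; the label ids, vocabulary size, and seeds are arbitrary, and the seed attributes are set only so the sample is reproducible.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// One row of true labels (ids 1 and 3) out of a vocabulary of 10 classes.
	trueClasses := op.Const(s, [][]int64{{1, 3}})

	sampled, trueExpected, sampledExpected := op.LearnedUnigramCandidateSampler(
		s, trueClasses,
		2,    // num_true
		4,    // num_sampled
		true, // unique
		10,   // range_max
		op.LearnedUnigramCandidateSamplerSeed(42),
		op.LearnedUnigramCandidateSamplerSeed2(7),
	)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{sampled, trueExpected, sampledExpected}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // 4 sampled candidate ids in [0, 10)
	fmt.Println(out[1].Value(), out[2].Value())
}
```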
+func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LearnedUnigramCandidateSampler", Input: []tf.Input{ true_classes, }, @@ -4072,62 +4189,152 @@ func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int6 return op.Output(0), op.Output(1), op.Output(2) } -// Broadcasts a tensor value to one or more other devices. -func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} - opspec := tf.OpSpec{ - Type: "CollectiveBcastSend", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// LoadAndRemapMatrixAttr is an optional argument to LoadAndRemapMatrix. +type LoadAndRemapMatrixAttr func(optionalAttr) -// Mutually accumulates multiple tensors of identical type and shape. -func CollectiveGather(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} - opspec := tf.OpSpec{ - Type: "CollectiveGather", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CollectiveReduceAttr is an optional argument to CollectiveReduce. -type CollectiveReduceAttr func(optionalAttr) - -// CollectiveReduceWaitFor sets the optional wait_for attribute to value. -// If not specified, defaults to <> -func CollectiveReduceWaitFor(value []int64) CollectiveReduceAttr { +// LoadAndRemapMatrixMaxRowsInMemory sets the optional max_rows_in_memory attribute to value. +// +// value: The maximum number of rows to load from the checkpoint at +// once. If less than or equal to 0, the entire matrix will be loaded into +// memory. Setting this arg trades increased disk reads for lower memory usage. +// If not specified, defaults to -1 +func LoadAndRemapMatrixMaxRowsInMemory(value int64) LoadAndRemapMatrixAttr { return func(m optionalAttr) { - m["wait_for"] = value + m["max_rows_in_memory"] = value } } -// Mutually reduces multiple tensors of identical type and shape. -func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64, optional ...CollectiveReduceAttr) (data tf.Output) { +// Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint +// +// at `ckpt_path` and potentially reorders its rows and columns using the +// specified remappings. +// +// Most users should use one of the wrapper initializers (such as +// `tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this +// function directly. 
+// +// The remappings are 1-D tensors with the following properties: +// +// * `row_remapping` must have exactly `num_rows` entries. Row `i` of the output +// matrix will be initialized from the row corresponding to index +// `row_remapping[i]` in the old `Tensor` from the checkpoint. +// * `col_remapping` must have either 0 entries (indicating that no column +// reordering is needed) or `num_cols` entries. If specified, column `j` of the +// output matrix will be initialized from the column corresponding to index +// `col_remapping[j]` in the old `Tensor` from the checkpoint. +// * A value of -1 in either of the remappings signifies a "missing" entry. In that +// case, values from the `initializing_values` tensor will be used to fill that +// missing row or column. If `row_remapping` has `r` missing entries and +// `col_remapping` has `c` missing entries, then the following condition must be +// true: +// +// `(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)` +// +// The remapping tensors can be generated using the GenerateVocabRemapping op. +// +// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], +// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing +// the value from row i, column j of the old tensor in the checkpoint, the output +// matrix will look like the following: +// +// [[w(1, 0), w(1, 2), 0.5], +// [w(0, 0), w(0, 2), -0.5], +// [0.25, -0.25, 42]] +// +// Arguments: +// ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from +// which the old matrix `Tensor` will be loaded. +// old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. +// row_remapping: An int `Tensor` of row remappings (generally created by +// `generate_vocab_remapping`). Even if no row remapping is needed, this must +// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted +// index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`). +// col_remapping: An int `Tensor` of column remappings (generally created by +// `generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping +// is to be done (e.g. column ordering is the same). +// initializing_values: A float `Tensor` containing values to fill in for cells +// in the output matrix that are not loaded from the checkpoint. Length must be +// exactly the same as the number of missing / new cells. +// num_rows: Number of rows (length of the 1st dimension) in the output matrix. +// num_cols: Number of columns (length of the 2nd dimension) in the output matrix. +// +// Returns Output matrix containing existing values loaded from the +// checkpoint, and with any missing values filled in from initializing_values. 
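A graph-construction-only sketch of the remapping rules above for `LoadAndRemapMatrix`, whose wrapper follows. The checkpoint path and tensor name here are hypothetical placeholders; actually running the graph would require a real TensorBundle (V2) checkpoint containing that tensor.

```go
package main

import (
	"fmt"

	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Hypothetical checkpoint path and tensor name.
	ckptPath := op.Const(s, "/tmp/old_model.ckpt")
	oldTensorName := op.Const(s, "old/embedding")

	// Build a 3x2 matrix: rows 0 and 1 come (swapped) from the checkpoint,
	// row 2 is "missing" (-1) and is filled from initializing_values.
	rowRemapping := op.Const(s, []int64{1, 0, -1})
	colRemapping := op.Const(s, []int64{0, 1}) // identity column remapping
	initializingValues := op.Const(s, []float32{0.0, 0.0})

	outputMatrix := op.LoadAndRemapMatrix(s, ckptPath, oldTensorName,
		rowRemapping, colRemapping, initializingValues,
		3, // num_rows
		2, // num_cols
		op.LoadAndRemapMatrixMaxRowsInMemory(-1))

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	fmt.Println("built", outputMatrix.Op.Name())
}
```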
+func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Output, row_remapping tf.Output, col_remapping tf.Output, initializing_values tf.Output, num_rows int64, num_cols int64, optional ...LoadAndRemapMatrixAttr) (output_matrix tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets} + attrs := map[string]interface{}{"num_rows": num_rows, "num_cols": num_cols} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CollectiveReduce", + Type: "LoadAndRemapMatrix", + Input: []tf.Input{ + ckpt_path, old_tensor_name, row_remapping, col_remapping, initializing_values, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize. +type QuantizeAndDequantizeAttr func(optionalAttr) + +// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["signed_input"] = value + } +} + +// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to false +func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value. +// If not specified, defaults to 0 +func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["input_min"] = value + } +} + +// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. +// If not specified, defaults to 0 +func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["input_max"] = value + } +} + +// Use QuantizeAndDequantizeV2 instead. +// +// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 +func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizeAndDequantize", Input: []tf.Input{ input, }, @@ -4137,6 +4344,33 @@ func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key return op.Output(0) } +// Returns the index of a data point that should be added to the seed set. +// +// Entries in distances are assumed to be squared distances of candidate points to +// the already sampled centers in the seed set. The op constructs one Markov chain +// of the k-MC^2 algorithm and returns the index of one candidate point to be added +// as an additional cluster center. +// +// Arguments: +// distances: Vector with squared distances to the closest previously sampled cluster center +// for each candidate point. +// seed: Scalar. Seed for initializing the random number generator. +// +// Returns Scalar with the index of the sampled point. 
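A small runnable sketch for `KMC2ChainInitialization`, whose wrapper follows; the squared distances and seed are arbitrary.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Squared distances from five candidate points to the current seed set.
	distances := op.Const(s, []float32{0.1, 4.0, 9.0, 0.5, 16.0})
	seed := op.Const(s, int64(42))

	index := op.KMC2ChainInitialization(s, distances, seed)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{index}, nil)
	if err != nil {
		panic(err)
	}
	// Index of the candidate chosen as the next cluster center; candidates with
	// larger squared distance are more likely to be picked.
	fmt.Println(out[0].Value())
}
```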
+func KMC2ChainInitialization(scope *Scope, distances tf.Output, seed tf.Output) (index tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "KMC2ChainInitialization", + Input: []tf.Input{ + distances, seed, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AbortAttr is an optional argument to Abort. type AbortAttr func(optionalAttr) @@ -4182,6 +4416,21 @@ func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) { return scope.AddOperation(opspec) } +// Does nothing. Serves as a control trigger for scheduling. +// +// Only useful as a placeholder for control edges. +// +// Returns the created operation. +func ControlTrigger(scope *Scope) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ControlTrigger", + } + return scope.AddOperation(opspec) +} + // Forwards the input to the output. // // This operator represents the loop termination condition used by the @@ -4205,39 +4454,85 @@ func LoopCond(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// Returns a tensor of zeros with the same shape and type as x. +// EnterAttr is an optional argument to Enter. +type EnterAttr func(optionalAttr) + +// EnterIsConstant sets the optional is_constant attribute to value. +// +// value: If true, the output is constant within the child frame. +// If not specified, defaults to false +func EnterIsConstant(value bool) EnterAttr { + return func(m optionalAttr) { + m["is_constant"] = value + } +} + +// EnterParallelIterations sets the optional parallel_iterations attribute to value. +// +// value: The number of iterations allowed to run in parallel. +// If not specified, defaults to 10 +func EnterParallelIterations(value int64) EnterAttr { + return func(m optionalAttr) { + m["parallel_iterations"] = value + } +} + +// Creates or finds a child frame, and makes `data` available to the child frame. +// +// This op is used together with `Exit` to create loops in the graph. +// The unique `frame_name` is used by the `Executor` to identify frames. If +// `is_constant` is true, `output` is a constant in the child frame; otherwise +// it may be changed in the child frame. At most `parallel_iterations` iterations +// are run in parallel in the child frame. // // Arguments: -// x: a tensor of type T. +// data: The tensor to be made available to the child frame. +// frame_name: The name of the child frame. // -// Returns a tensor of the same shape and type as x but filled with zeros. -func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { +// Returns The same tensor as `data`. +func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"frame_name": frame_name} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ZerosLike", + Type: "Enter", Input: []tf.Input{ - x, + data, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a copy of the input tensor. -func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { +// Forwards the value of an available tensor from `inputs` to `output`. +// +// `Merge` waits for at least one of the tensors in `inputs` to become available. +// It is usually combined with `Switch` to implement branching. +// +// `Merge` forwards the first tensor to become available to `output`, and sets +// `value_index` to its index in `inputs`. 
+// +// Arguments: +// inputs: The input tensors, exactly one of which will become available. +// +// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. +func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Snapshot", + Type: "Merge", Input: []tf.Input{ - input, + tf.OutputList(inputs), }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } // Forwards `data` to the output port determined by `pred`. @@ -4266,74 +4561,6 @@ func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Outpu return op.Output(0), op.Output(1) } -// AudioSpectrogramAttr is an optional argument to AudioSpectrogram. -type AudioSpectrogramAttr func(optionalAttr) - -// AudioSpectrogramMagnitudeSquared sets the optional magnitude_squared attribute to value. -// -// value: Whether to return the squared magnitude or just the -// magnitude. Using squared magnitude can avoid extra calculations. -// If not specified, defaults to false -func AudioSpectrogramMagnitudeSquared(value bool) AudioSpectrogramAttr { - return func(m optionalAttr) { - m["magnitude_squared"] = value - } -} - -// Produces a visualization of audio data over time. -// -// Spectrograms are a standard way of representing audio information as a series of -// slices of frequency information, one slice for each window of time. By joining -// these together into a sequence, they form a distinctive fingerprint of the sound -// over time. -// -// This op expects to receive audio data as an input, stored as floats in the range -// -1 to 1, together with a window width in samples, and a stride specifying how -// far to move the window between slices. From this it generates a three -// dimensional output. The lowest dimension has an amplitude value for each -// frequency during that time slice. The next dimension is time, with successive -// frequency slices. The final dimension is for the channels in the input, so a -// stereo audio input would have two here for example. -// -// This means the layout when converted and saved as an image is rotated 90 degrees -// clockwise from a typical spectrogram. Time is descending down the Y axis, and -// the frequency decreases from left to right. -// -// Each value in the result represents the square root of the sum of the real and -// imaginary parts of an FFT on the current window of samples. In this way, the -// lowest dimension represents the power of each frequency in the current window, -// and adjacent windows are concatenated in the next dimension. -// -// To get a more intuitive and visual look at what this operation does, you can run -// tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the -// resulting spectrogram as a PNG image. -// -// Arguments: -// input: Float representation of audio data. -// window_size: How wide the input window is in samples. For the highest efficiency -// this should be a power of two, but other values are accepted. -// stride: How widely apart the center of adjacent sample windows should be. -// -// Returns 3D representation of the audio frequencies as an image. 
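For the `Switch` and `Merge` control-flow wrappers above, a minimal cond-style sketch of the combination the `Merge` doc comment describes. `Neg` and `Square` are other generated wrappers from this file, used here only to give the two branches distinct work; the input values are arbitrary.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s, int32(7))
	pred := op.Const(s, true)

	// Switch routes x to exactly one of its two outputs depending on pred.
	xFalse, xTrue := op.Switch(s, x, pred)

	// Only the taken branch executes at run time.
	falseBranch := op.Neg(s.SubScope("false_branch"), xFalse)
	trueBranch := op.Square(s.SubScope("true_branch"), xTrue)

	// Merge forwards whichever branch produced a value, plus its index.
	merged, valueIndex := op.Merge(s, []tf.Output{falseBranch, trueBranch})

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{merged, valueIndex}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value(), out[1].Value()) // 49 1 when pred is true
}
```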
-func AudioSpectrogram(scope *Scope, input tf.Output, window_size int64, stride int64, optional ...AudioSpectrogramAttr) (spectrogram tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"window_size": window_size, "stride": stride} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AudioSpectrogram", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // CTCBeamSearchDecoderAttr is an optional argument to CTCBeamSearchDecoder. type CTCBeamSearchDecoderAttr func(optionalAttr) @@ -4406,228 +4633,61 @@ func CTCBeamSearchDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Out return decoded_indices, decoded_values, decoded_shape, log_probability } -// CTCGreedyDecoderAttr is an optional argument to CTCGreedyDecoder. -type CTCGreedyDecoderAttr func(optionalAttr) +// CudnnRNNCanonicalToParamsAttr is an optional argument to CudnnRNNCanonicalToParams. +type CudnnRNNCanonicalToParamsAttr func(optionalAttr) -// CTCGreedyDecoderMergeRepeated sets the optional merge_repeated attribute to value. -// -// value: If True, merge repeated classes in output. -// If not specified, defaults to false -func CTCGreedyDecoderMergeRepeated(value bool) CTCGreedyDecoderAttr { - return func(m optionalAttr) { - m["merge_repeated"] = value - } -} - -// Performs greedy decoding on the logits given in inputs. -// -// A note about the attribute merge_repeated: if enabled, when -// consecutive logits' maximum indices are the same, only the first of -// these is emitted. Labeling the blank '*', the sequence "A B B * B B" -// becomes "A B B" if merge_repeated = True and "A B B B B" if -// merge_repeated = False. -// -// Regardless of the value of merge_repeated, if the maximum index of a given -// time and batch corresponds to the blank, index `(num_classes - 1)`, no new -// element is emitted. -// -// Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// sequence_length: A vector containing sequence lengths, size `(batch_size)`. -// -// Returns Indices matrix, size `(total_decoded_outputs x 2)`, -// of a `SparseTensor`. The rows store: [batch, time].Values vector, size: `(total_decoded_outputs)`, -// of a `SparseTensor`. The vector stores the decoded classes.Shape vector, size `(2)`, of the decoded SparseTensor. -// Values are: `[batch_size, max_decoded_length]`.Matrix, size `(batch_size x 1)`, containing sequence -// log-probabilities. -func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, optional ...CTCGreedyDecoderAttr) (decoded_indices tf.Output, decoded_values tf.Output, decoded_shape tf.Output, log_probability tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CTCGreedyDecoder", - Input: []tf.Input{ - inputs, sequence_length, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// CTCLossAttr is an optional argument to CTCLoss. -type CTCLossAttr func(optionalAttr) - -// CTCLossPreprocessCollapseRepeated sets the optional preprocess_collapse_repeated attribute to value. -// -// value: Scalar, if true then repeated labels are -// collapsed prior to the CTC calculation. 
-// If not specified, defaults to false -func CTCLossPreprocessCollapseRepeated(value bool) CTCLossAttr { - return func(m optionalAttr) { - m["preprocess_collapse_repeated"] = value - } -} - -// CTCLossCtcMergeRepeated sets the optional ctc_merge_repeated attribute to value. -// -// value: Scalar. If set to false, *during* CTC calculation -// repeated non-blank labels will not be merged and are interpreted as -// individual labels. This is a simplified version of CTC. -// If not specified, defaults to true -func CTCLossCtcMergeRepeated(value bool) CTCLossAttr { - return func(m optionalAttr) { - m["ctc_merge_repeated"] = value - } -} - -// CTCLossIgnoreLongerOutputsThanInputs sets the optional ignore_longer_outputs_than_inputs attribute to value. -// -// value: Scalar. If set to true, during CTC -// calculation, items that have longer output sequences than input sequences -// are skipped: they don't contribute to the loss term and have zero-gradient. -// If not specified, defaults to false -func CTCLossIgnoreLongerOutputsThanInputs(value bool) CTCLossAttr { - return func(m optionalAttr) { - m["ignore_longer_outputs_than_inputs"] = value - } -} - -// Calculates the CTC Loss (log probability) for each batch entry. Also calculates -// -// the gradient. This class performs the softmax operation for you, so inputs -// should be e.g. linear projections of outputs by an LSTM. -// -// Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// labels_indices: The indices of a `SparseTensor`. -// `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for -// `(batch b, time t)`. -// labels_values: The values (labels) associated with the given batch and time. -// sequence_length: A vector containing sequence lengths (batch). -// -// Returns A vector (batch) containing log-probabilities.The gradient of `loss`. 3-D, shape: -// `(max_time x batch_size x num_classes)`. -func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_values tf.Output, sequence_length tf.Output, optional ...CTCLossAttr) (loss tf.Output, gradient tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CTCLoss", - Input: []tf.Input{ - inputs, labels_indices, labels_values, sequence_length, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// ShapeNAttr is an optional argument to ShapeN. -type ShapeNAttr func(optionalAttr) - -// ShapeNOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func ShapeNOutType(value tf.DataType) ShapeNAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Returns shape of tensors. -// -// This operation returns N 1-D integer tensors representing shape of `input[i]s`. -func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ShapeN", - Input: []tf.Input{ - tf.OutputList(input), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("ShapeN", err) - return - } - return output -} - -// CudnnRNNParamsToCanonicalAttr is an optional argument to CudnnRNNParamsToCanonical. 
-type CudnnRNNParamsToCanonicalAttr func(optionalAttr) - -// CudnnRNNParamsToCanonicalRnnMode sets the optional rnn_mode attribute to value. +// CudnnRNNCanonicalToParamsRnnMode sets the optional rnn_mode attribute to value. // If not specified, defaults to "lstm" -func CudnnRNNParamsToCanonicalRnnMode(value string) CudnnRNNParamsToCanonicalAttr { +func CudnnRNNCanonicalToParamsRnnMode(value string) CudnnRNNCanonicalToParamsAttr { return func(m optionalAttr) { m["rnn_mode"] = value } } -// CudnnRNNParamsToCanonicalInputMode sets the optional input_mode attribute to value. +// CudnnRNNCanonicalToParamsInputMode sets the optional input_mode attribute to value. // If not specified, defaults to "linear_input" -func CudnnRNNParamsToCanonicalInputMode(value string) CudnnRNNParamsToCanonicalAttr { +func CudnnRNNCanonicalToParamsInputMode(value string) CudnnRNNCanonicalToParamsAttr { return func(m optionalAttr) { m["input_mode"] = value } } -// CudnnRNNParamsToCanonicalDirection sets the optional direction attribute to value. +// CudnnRNNCanonicalToParamsDirection sets the optional direction attribute to value. // If not specified, defaults to "unidirectional" -func CudnnRNNParamsToCanonicalDirection(value string) CudnnRNNParamsToCanonicalAttr { +func CudnnRNNCanonicalToParamsDirection(value string) CudnnRNNCanonicalToParamsAttr { return func(m optionalAttr) { m["direction"] = value } } -// CudnnRNNParamsToCanonicalDropout sets the optional dropout attribute to value. +// CudnnRNNCanonicalToParamsDropout sets the optional dropout attribute to value. // If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalDropout(value float32) CudnnRNNParamsToCanonicalAttr { +func CudnnRNNCanonicalToParamsDropout(value float32) CudnnRNNCanonicalToParamsAttr { return func(m optionalAttr) { m["dropout"] = value } } -// CudnnRNNParamsToCanonicalSeed sets the optional seed attribute to value. +// CudnnRNNCanonicalToParamsSeed sets the optional seed attribute to value. // If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalSeed(value int64) CudnnRNNParamsToCanonicalAttr { +func CudnnRNNCanonicalToParamsSeed(value int64) CudnnRNNCanonicalToParamsAttr { return func(m optionalAttr) { m["seed"] = value } } -// CudnnRNNParamsToCanonicalSeed2 sets the optional seed2 attribute to value. +// CudnnRNNCanonicalToParamsSeed2 sets the optional seed2 attribute to value. // If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalSeed2(value int64) CudnnRNNParamsToCanonicalAttr { +func CudnnRNNCanonicalToParamsSeed2(value int64) CudnnRNNCanonicalToParamsAttr { return func(m optionalAttr) { m["seed2"] = value } } -// Retrieves CudnnRNN params in canonical form. +// Converts CudnnRNN params from canonical form to usable form. // -// Retrieves a set of weights from the opaque params buffer that can be saved and -// restored in a way compatible with future runs. +// Writes a set of weights into the opaque params buffer so they can be used in +// upcoming training or inferences. // // Note that the params buffer may not be compatible across different GPUs. So any // save and restoration should be converted to and from the canonical weights and @@ -4636,15 +4696,15 @@ func CudnnRNNParamsToCanonicalSeed2(value int64) CudnnRNNParamsToCanonicalAttr { // num_layers: Specifies the number of layers in the RNN model. // num_units: Specifies the size of the hidden state. // input_size: Specifies the size of the input state. -// num_params: number of parameter sets for all layers. 
-// Each layer may contain multiple parameter sets, with each set consisting of -// a weight matrix and a bias vector. // weights: the canonical form of weights that can be used for saving // and restoration. They are more likely to be compatible across different // generations. // biases: the canonical form of biases that can be used for saving // and restoration. They are more likely to be compatible across different // generations. +// num_params: number of parameter sets for all layers. +// Each layer may contain multiple parameter sets, with each set consisting of +// a weight matrix and a bias vector. // rnn_mode: Indicates the type of the RNN model. // input_mode: Indicate whether there is a linear projection between the input and // The actual computation before the first layer. 'skip_input' is only allowed @@ -4655,36 +4715,23 @@ func CudnnRNNParamsToCanonicalSeed2(value int64) CudnnRNNParamsToCanonicalAttr { // dropout: dropout probability. When set to 0., dropout is disabled. // seed: the 1st part of a seed to initialize dropout. // seed2: the 2nd part of a seed to initialize dropout. -func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, params tf.Output, num_params int64, optional ...CudnnRNNParamsToCanonicalAttr) (weights []tf.Output, biases []tf.Output) { +func CudnnRNNCanonicalToParams(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, weights []tf.Output, biases []tf.Output, optional ...CudnnRNNCanonicalToParamsAttr) (params tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_params": num_params} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNNParamsToCanonical", + Type: "CudnnRNNCanonicalToParams", Input: []tf.Input{ - num_layers, num_units, input_size, params, + num_layers, num_units, input_size, tf.OutputList(weights), tf.OutputList(biases), }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if weights, idx, err = makeOutputList(op, idx, "weights"); err != nil { - scope.UpdateErr("CudnnRNNParamsToCanonical", err) - return - } - if biases, idx, err = makeOutputList(op, idx, "biases"); err != nil { - scope.UpdateErr("CudnnRNNParamsToCanonical", err) - return - } - return weights, biases + return op.Output(0) } // CudnnRNNBackpropV3Attr is an optional argument to CudnnRNNBackpropV3. @@ -4738,6 +4785,14 @@ func CudnnRNNBackpropV3Seed2(value int64) CudnnRNNBackpropV3Attr { } } +// CudnnRNNBackpropV3TimeMajor sets the optional time_major attribute to value. +// If not specified, defaults to true +func CudnnRNNBackpropV3TimeMajor(value bool) CudnnRNNBackpropV3Attr { + return func(m optionalAttr) { + m["time_major"] = value + } +} + // Backprop step of CudnnRNNV3. // // Compute the backprop of both data and weights in a RNN. Takes an extra @@ -4753,9 +4808,12 @@ func CudnnRNNBackpropV3Seed2(value int64) CudnnRNNBackpropV3Attr { // dropout: Dropout probability. When set to 0., dropout is disabled. // seed: The 1st part of a seed to initialize dropout. // seed2: The 2nd part of a seed to initialize dropout. -// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, -// num_units]. +// input: If time_major is true, this is a 3-D tensor with the shape of +// [seq_length, batch_size, input_size]. 
If time_major is false, the shape is +// [batch_size, seq_length, input_size]. +// input_h: If time_major is true, this is a 3-D tensor with the shape of +// [num_layer * dir, batch_size, num_units]. If time_major is false, the shape +// is [batch_size, num_layer * dir, num_units]. // input_c: For LSTM, a 3-D tensor with the shape of // [num_layer * dir, batch, num_units]. For other models, it is ignored. // params: A 1-D tensor that contains the weights and biases in an opaque layout. @@ -4763,8 +4821,9 @@ func CudnnRNNBackpropV3Seed2(value int64) CudnnRNNBackpropV3Attr { // separately. Note that they might not be compatible across different // generations. So it is a good idea to save and restore // sequence_lengths: a vector of lengths of each input sequence. -// output: A 3-D tensor with the shape of [seq_length, batch_size, -// dir * num_units]. +// output: If time_major is true, this is a 3-D tensor with the shape of +// [seq_length, batch_size, dir * num_units]. If time_major is false, the +// shape is [batch_size, seq_length, dir * num_units]. // output_h: The same shape has input_h. // output_c: The same shape as input_c for LSTM. An empty tensor for other models. // output_backprop: A 3-D tensor with the same shape as output in the forward pass. @@ -4772,6 +4831,8 @@ func CudnnRNNBackpropV3Seed2(value int64) CudnnRNNBackpropV3Attr { // pass. // output_c_backprop: A 3-D tensor with the same shape as output_c in the forward // pass. +// time_major: Indicates whether the input/output format is time major or batch +// major. // reserve_space: The same reserve_space produced in the forward operation. // input_backprop: The backprop to input in the forward pass. Has the same shape // as input. @@ -4800,102 +4861,113 @@ func CudnnRNNBackpropV3(scope *Scope, input tf.Output, input_h tf.Output, input_ return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// CudnnRNNBackpropV2Attr is an optional argument to CudnnRNNBackpropV2. -type CudnnRNNBackpropV2Attr func(optionalAttr) +// CudnnRNNV3Attr is an optional argument to CudnnRNNV3. +type CudnnRNNV3Attr func(optionalAttr) -// CudnnRNNBackpropV2RnnMode sets the optional rnn_mode attribute to value. +// CudnnRNNV3RnnMode sets the optional rnn_mode attribute to value. // If not specified, defaults to "lstm" -func CudnnRNNBackpropV2RnnMode(value string) CudnnRNNBackpropV2Attr { +func CudnnRNNV3RnnMode(value string) CudnnRNNV3Attr { return func(m optionalAttr) { m["rnn_mode"] = value } } -// CudnnRNNBackpropV2InputMode sets the optional input_mode attribute to value. +// CudnnRNNV3InputMode sets the optional input_mode attribute to value. // If not specified, defaults to "linear_input" -func CudnnRNNBackpropV2InputMode(value string) CudnnRNNBackpropV2Attr { +func CudnnRNNV3InputMode(value string) CudnnRNNV3Attr { return func(m optionalAttr) { m["input_mode"] = value } } -// CudnnRNNBackpropV2Direction sets the optional direction attribute to value. +// CudnnRNNV3Direction sets the optional direction attribute to value. // If not specified, defaults to "unidirectional" -func CudnnRNNBackpropV2Direction(value string) CudnnRNNBackpropV2Attr { +func CudnnRNNV3Direction(value string) CudnnRNNV3Attr { return func(m optionalAttr) { m["direction"] = value } } -// CudnnRNNBackpropV2Dropout sets the optional dropout attribute to value. +// CudnnRNNV3Dropout sets the optional dropout attribute to value. 
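// Editorial sketch, not part of the generated wrappers: the new time_major
// attribute follows the same functional-option pattern as the other
// CudnnRNNBackpropV3 attrs. The forward-pass tensors are omitted and assumed
// to exist elsewhere in the graph.
//
//	opts := []op.CudnnRNNBackpropV3Attr{
//		op.CudnnRNNBackpropV3TimeMajor(false), // batch-major layout
//		op.CudnnRNNBackpropV3Dropout(0),
//	}
//	// ...passed as trailing arguments: op.CudnnRNNBackpropV3(s, /* forward tensors */, opts...)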
// If not specified, defaults to 0 -func CudnnRNNBackpropV2Dropout(value float32) CudnnRNNBackpropV2Attr { +func CudnnRNNV3Dropout(value float32) CudnnRNNV3Attr { return func(m optionalAttr) { m["dropout"] = value } } -// CudnnRNNBackpropV2Seed sets the optional seed attribute to value. +// CudnnRNNV3Seed sets the optional seed attribute to value. // If not specified, defaults to 0 -func CudnnRNNBackpropV2Seed(value int64) CudnnRNNBackpropV2Attr { +func CudnnRNNV3Seed(value int64) CudnnRNNV3Attr { return func(m optionalAttr) { m["seed"] = value } } -// CudnnRNNBackpropV2Seed2 sets the optional seed2 attribute to value. +// CudnnRNNV3Seed2 sets the optional seed2 attribute to value. // If not specified, defaults to 0 -func CudnnRNNBackpropV2Seed2(value int64) CudnnRNNBackpropV2Attr { +func CudnnRNNV3Seed2(value int64) CudnnRNNV3Attr { return func(m optionalAttr) { m["seed2"] = value } } -// Backprop step of CudnnRNN. +// CudnnRNNV3IsTraining sets the optional is_training attribute to value. +// If not specified, defaults to true +func CudnnRNNV3IsTraining(value bool) CudnnRNNV3Attr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// CudnnRNNV3TimeMajor sets the optional time_major attribute to value. +// If not specified, defaults to true +func CudnnRNNV3TimeMajor(value bool) CudnnRNNV3Attr { + return func(m optionalAttr) { + m["time_major"] = value + } +} + +// A RNN backed by cuDNN. // -// Compute the backprop of both data and weights in a RNN. Takes an extra -// "host_reserved" inupt than CudnnRNNBackprop, which is used to determine RNN -// cudnnRNNAlgo_t and cudnnMathType_t. +// Computes the RNN from the input and initial states, with respect to the params +// buffer. Accepts one extra input "sequence_lengths" than CudnnRNN. // // rnn_mode: Indicates the type of the RNN model. // input_mode: Indicates whether there is a linear projection between the input and -// the actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. +// the actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. // direction: Indicates whether a bidirectional model will be used. Should be // "unidirectional" or "bidirectional". // dropout: Dropout probability. When set to 0., dropout is disabled. // seed: The 1st part of a seed to initialize dropout. // seed2: The 2nd part of a seed to initialize dropout. -// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, -// num_units]. +// input: If time_major is true, this is a 3-D tensor with the shape of +// [seq_length, batch_size, input_size]. If time_major is false, the shape is +// [batch_size, seq_length, input_size]. +// input_h: If time_major is true, this is a 3-D tensor with the shape of +// [num_layer * dir, batch_size, num_units]. If time_major is false, the shape +// is [batch_size, num_layer * dir, num_units]. // input_c: For LSTM, a 3-D tensor with the shape of // [num_layer * dir, batch, num_units]. For other models, it is ignored. // params: A 1-D tensor that contains the weights and biases in an opaque layout. // The size must be created through CudnnRNNParamsSize, and initialized // separately. 
Note that they might not be compatible across different // generations. So it is a good idea to save and restore -// output: A 3-D tensor with the shape of [seq_length, batch_size, -// dir * num_units]. +// sequence_lengths: a vector of lengths of each input sequence. +// output: If time_major is true, this is a 3-D tensor with the shape of +// [seq_length, batch_size, dir * num_units]. If time_major is false, the +// shape is [batch_size, seq_length, dir * num_units]. // output_h: The same shape has input_h. // output_c: The same shape as input_c for LSTM. An empty tensor for other models. -// output_backprop: A 3-D tensor with the same shape as output in the forward pass. -// output_h_backprop: A 3-D tensor with the same shape as output_h in the forward -// pass. -// output_c_backprop: A 3-D tensor with the same shape as output_c in the forward -// pass. -// reserve_space: The same reserve_space produced in the forward operation. -// host_reserved: The same host_reserved produced in the forward operation. -// input_backprop: The backprop to input in the forward pass. Has the same shape -// as input. -// input_h_backprop: The backprop to input_h in the forward pass. Has the same -// shape as input_h. -// input_c_backprop: The backprop to input_c in the forward pass. Has the same -// shape as input_c. -// params_backprop: The backprop to the params buffer in the forward pass. Has the -// same shape as params. -func CudnnRNNBackpropV2(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, output tf.Output, output_h tf.Output, output_c tf.Output, output_backprop tf.Output, output_h_backprop tf.Output, output_c_backprop tf.Output, reserve_space tf.Output, host_reserved tf.Output, optional ...CudnnRNNBackpropV2Attr) (input_backprop tf.Output, input_h_backprop tf.Output, input_c_backprop tf.Output, params_backprop tf.Output) { +// is_training: Indicates whether this operation is used for inferenece or +// training. +// time_major: Indicates whether the input/output format is time major or batch +// major. +// reserve_space: An opaque tensor that can be used in backprop calculation. It +// is only produced if is_training is true. +func CudnnRNNV3(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, sequence_lengths tf.Output, optional ...CudnnRNNV3Attr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output, host_reserved tf.Output) { if scope.Err() != nil { return } @@ -4904,14 +4976,43 @@ func CudnnRNNBackpropV2(scope *Scope, input tf.Output, input_h tf.Output, input_ a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNNBackpropV2", + Type: "CudnnRNNV3", Input: []tf.Input{ - input, input_h, input_c, params, output, output_h, output_c, output_backprop, output_h_backprop, output_c_backprop, reserve_space, host_reserved, + input, input_h, input_c, params, sequence_lengths, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + +// Runs multiple additive regression ensemble predictors on input instances and +// +// computes the logits. It is designed to be used during prediction. +// It traverses all the trees and calculates the final score for each instance. +// +// Arguments: +// +// bucketized_features: A list of rank 1 Tensors containing bucket id for each +// feature. +// logits_dimension: scalar, dimension of the logits, to be used for partial logits +// shape. 
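// Editorial sketch, not part of the generated wrappers: a batch-major
// CudnnRNNV3 call using the attributes documented above. The input tensors
// (input, inputH, inputC, params, seqLengths) are assumptions built elsewhere.
//
//	output, outputH, outputC, reserveSpace, hostReserved := op.CudnnRNNV3(
//		s, input, inputH, inputC, params, seqLengths,
//		op.CudnnRNNV3TimeMajor(false),
//		op.CudnnRNNV3IsTraining(true),
//	)
//	// reserveSpace and hostReserved feed the matching CudnnRNNBackpropV3 call.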
+// +// Returns Output rank 2 Tensor containing logits for each example. +func BoostedTreesPredict(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (logits tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"logits_dimension": logits_dimension} + opspec := tf.OpSpec{ + Type: "BoostedTreesPredict", + Input: []tf.Input{ + tree_ensemble_handle, tf.OutputList(bucketized_features), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } // CudnnRNNV2Attr is an optional argument to CudnnRNNV2. @@ -5027,92 +5128,114 @@ func CudnnRNNV2(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Out return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// RecordInputAttr is an optional argument to RecordInput. -type RecordInputAttr func(optionalAttr) +// CudnnRNNAttr is an optional argument to CudnnRNN. +type CudnnRNNAttr func(optionalAttr) -// RecordInputFileRandomSeed sets the optional file_random_seed attribute to value. -// -// value: Random seeds used to produce randomized records. -// If not specified, defaults to 301 -func RecordInputFileRandomSeed(value int64) RecordInputAttr { +// CudnnRNNRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNRnnMode(value string) CudnnRNNAttr { return func(m optionalAttr) { - m["file_random_seed"] = value + m["rnn_mode"] = value } } -// RecordInputFileShuffleShiftRatio sets the optional file_shuffle_shift_ratio attribute to value. -// -// value: Shifts the list of files after the list is randomly -// shuffled. +// CudnnRNNInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNInputMode(value string) CudnnRNNAttr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNDirection(value string) CudnnRNNAttr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNDropout sets the optional dropout attribute to value. // If not specified, defaults to 0 -func RecordInputFileShuffleShiftRatio(value float32) RecordInputAttr { +func CudnnRNNDropout(value float32) CudnnRNNAttr { return func(m optionalAttr) { - m["file_shuffle_shift_ratio"] = value + m["dropout"] = value } } -// RecordInputFileBufferSize sets the optional file_buffer_size attribute to value. -// -// value: The randomization shuffling buffer. -// If not specified, defaults to 10000 -func RecordInputFileBufferSize(value int64) RecordInputAttr { +// CudnnRNNSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNSeed(value int64) CudnnRNNAttr { return func(m optionalAttr) { - m["file_buffer_size"] = value + m["seed"] = value } } -// RecordInputFileParallelism sets the optional file_parallelism attribute to value. -// -// value: How many sstables are opened and concurrently iterated over. -// If not specified, defaults to 16 -func RecordInputFileParallelism(value int64) RecordInputAttr { +// CudnnRNNSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNSeed2(value int64) CudnnRNNAttr { return func(m optionalAttr) { - m["file_parallelism"] = value + m["seed2"] = value } } -// RecordInputBatchSize sets the optional batch_size attribute to value. -// -// value: The batch size. 
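// Editorial sketch, not part of the generated wrappers: invoking
// BoostedTreesPredict with a single bucketized feature column. The ensemble
// handle and feature tensor are assumptions for illustration.
//
//	logits := op.BoostedTreesPredict(s, ensembleHandle,
//		[]tf.Output{bucketizedFeature}, 1 /* logits_dimension */)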
-// If not specified, defaults to 32 -func RecordInputBatchSize(value int64) RecordInputAttr { +// CudnnRNNIsTraining sets the optional is_training attribute to value. +// If not specified, defaults to true +func CudnnRNNIsTraining(value bool) CudnnRNNAttr { return func(m optionalAttr) { - m["batch_size"] = value + m["is_training"] = value } } -// RecordInputCompressionType sets the optional compression_type attribute to value. +// A RNN backed by cuDNN. // -// value: The type of compression for the file. Currently ZLIB and -// GZIP are supported. Defaults to none. -// If not specified, defaults to "" -func RecordInputCompressionType(value string) RecordInputAttr { - return func(m optionalAttr) { - m["compression_type"] = value - } -} - -// Emits randomized records. +// Computes the RNN from the input and initial states, with respect to the params +// buffer. // -// Arguments: -// file_pattern: Glob pattern for the data files. -// -// Returns A tensor of shape [batch_size]. -func RecordInput(scope *Scope, file_pattern string, optional ...RecordInputAttr) (records tf.Output) { +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// the actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, +// num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// output: A 3-D tensor with the shape of [seq_length, batch_size, +// dir * num_units]. +// output_h: The same shape has input_h. +// output_c: The same shape as input_c for LSTM. An empty tensor for other models. +// is_training: Indicates whether this operation is used for inferenece or +// training. +// reserve_space: An opaque tensor that can be used in backprop calculation. It +// is only produced if is_training is false. +func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNAttr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"file_pattern": file_pattern} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RecordInput", - + Type: "CudnnRNN", + Input: []tf.Input{ + input, input_h, input_c, params, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } // OrderedMapClearAttr is an optional argument to OrderedMapClear. 
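// Editorial sketch, not part of the generated wrappers: the plain time-major
// CudnnRNN forward op with optional attributes, as a caller of the op package
// might write it; the input tensors are assumptions.
//
//	output, outputH, outputC, reserveSpace := op.CudnnRNN(
//		s, input, inputH, inputC, params,
//		op.CudnnRNNRnnMode("lstm"),
//		op.CudnnRNNIsTraining(false), // inference: reserve_space goes unused
//	)
//	_ = reserveSpace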
@@ -5230,43 +5353,6 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or return op.Output(0) } -// BoostedTreesQuantileStreamResourceHandleOpAttr is an optional argument to BoostedTreesQuantileStreamResourceHandleOp. -type BoostedTreesQuantileStreamResourceHandleOpAttr func(optionalAttr) - -// BoostedTreesQuantileStreamResourceHandleOpContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func BoostedTreesQuantileStreamResourceHandleOpContainer(value string) BoostedTreesQuantileStreamResourceHandleOpAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// BoostedTreesQuantileStreamResourceHandleOpSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func BoostedTreesQuantileStreamResourceHandleOpSharedName(value string) BoostedTreesQuantileStreamResourceHandleOpAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a handle to a BoostedTreesQuantileStreamResource. -func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...BoostedTreesQuantileStreamResourceHandleOpAttr) (resource tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "BoostedTreesQuantileStreamResourceHandleOp", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // OrderedMapSizeAttr is an optional argument to OrderedMapSize. type OrderedMapSizeAttr func(optionalAttr) @@ -5324,25 +5410,61 @@ func OrderedMapSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapSi return op.Output(0) } -// Generate the bucket boundaries for each feature based on accumulated summaries. +// OrderedMapUnstageNoKeyAttr is an optional argument to OrderedMapUnstageNoKey. +type OrderedMapUnstageNoKeyAttr func(optionalAttr) + +// OrderedMapUnstageNoKeyCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// An op that returns a list of float tensors for a quantile stream resource. Each -// tensor is Rank 1 containing bucket boundaries for a single feature. +// REQUIRES: value >= 0 +func OrderedMapUnstageNoKeyCapacity(value int64) OrderedMapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// OrderedMapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. -// num_features: inferred int; number of features to get bucket boundaries for. +// REQUIRES: value >= 0 +func OrderedMapUnstageNoKeyMemoryLimit(value int64) OrderedMapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// OrderedMapUnstageNoKeyContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageNoKeyContainer(value string) OrderedMapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapUnstageNoKeySharedName sets the optional shared_name attribute to value. 
+// If not specified, defaults to "" +func OrderedMapUnstageNoKeySharedName(value string) OrderedMapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns the (key, value) element with the smallest // -// Returns float; List of Rank 1 Tensors each containing the bucket boundaries for a feature. -func BoostedTreesQuantileStreamResourceGetBucketBoundaries(scope *Scope, quantile_stream_resource_handle tf.Output, num_features int64) (bucket_boundaries []tf.Output) { +// key from the underlying container. If the underlying container +// does not contain elements, the op will block until it does. +func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_features": num_features} + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BoostedTreesQuantileStreamResourceGetBucketBoundaries", + Type: "OrderedMapUnstageNoKey", Input: []tf.Input{ - quantile_stream_resource_handle, + indices, }, Attrs: attrs, } @@ -5352,11 +5474,12 @@ func BoostedTreesQuantileStreamResourceGetBucketBoundaries(scope *Scope, quantil } var idx int var err error - if bucket_boundaries, idx, err = makeOutputList(op, idx, "bucket_boundaries"); err != nil { - scope.UpdateErr("BoostedTreesQuantileStreamResourceGetBucketBoundaries", err) + key = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapUnstageNoKey", err) return } - return bucket_boundaries + return key, values } // OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage. @@ -5502,47 +5625,66 @@ func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf. return values } -// MapIncompleteSizeAttr is an optional argument to MapIncompleteSize. -type MapIncompleteSizeAttr func(optionalAttr) +// OrderedMapStageAttr is an optional argument to OrderedMapStage. +type OrderedMapStageAttr func(optionalAttr) -// MapIncompleteSizeCapacity sets the optional capacity attribute to value. +// OrderedMapStageCapacity sets the optional capacity attribute to value. +// +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func MapIncompleteSizeCapacity(value int64) MapIncompleteSizeAttr { +func OrderedMapStageCapacity(value int64) OrderedMapStageAttr { return func(m optionalAttr) { m["capacity"] = value } } -// MapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// OrderedMapStageMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func MapIncompleteSizeMemoryLimit(value int64) MapIncompleteSizeAttr { +func OrderedMapStageMemoryLimit(value int64) OrderedMapStageAttr { return func(m optionalAttr) { m["memory_limit"] = value } } -// MapIncompleteSizeContainer sets the optional container attribute to value. +// OrderedMapStageContainer sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. 
// If not specified, defaults to "" -func MapIncompleteSizeContainer(value string) MapIncompleteSizeAttr { +func OrderedMapStageContainer(value string) OrderedMapStageAttr { return func(m optionalAttr) { m["container"] = value } } -// MapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// OrderedMapStageSharedName sets the optional shared_name attribute to value. +// +// value: It is necessary to match this name to the matching Unstage Op. // If not specified, defaults to "" -func MapIncompleteSizeSharedName(value string) MapIncompleteSizeAttr { +func OrderedMapStageSharedName(value string) OrderedMapStageAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// Op returns the number of incomplete elements in the underlying container. -func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncompleteSizeAttr) (size tf.Output) { +// Stage (key, values) in the underlying container which behaves like a ordered +// +// associative container. Elements are ordered by key. +// +// Arguments: +// key: int64 +// +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. +// +// +// Returns the created operation. +func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...OrderedMapStageAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -5551,12 +5693,13 @@ func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncomp a(attrs) } opspec := tf.OpSpec{ - Type: "MapIncompleteSize", - + Type: "OrderedMapStage", + Input: []tf.Input{ + key, indices, tf.OutputList(values), + }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } // MapSizeAttr is an optional argument to MapSize. @@ -5616,50 +5759,50 @@ func MapSize(scope *Scope, dtypes []tf.DataType, optional ...MapSizeAttr) (size return op.Output(0) } -// MapPeekAttr is an optional argument to MapPeek. -type MapPeekAttr func(optionalAttr) +// MapUnstageAttr is an optional argument to MapUnstage. +type MapUnstageAttr func(optionalAttr) -// MapPeekCapacity sets the optional capacity attribute to value. +// MapUnstageCapacity sets the optional capacity attribute to value. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func MapPeekCapacity(value int64) MapPeekAttr { +func MapUnstageCapacity(value int64) MapUnstageAttr { return func(m optionalAttr) { m["capacity"] = value } } -// MapPeekMemoryLimit sets the optional memory_limit attribute to value. +// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func MapPeekMemoryLimit(value int64) MapPeekAttr { +func MapUnstageMemoryLimit(value int64) MapUnstageAttr { return func(m optionalAttr) { m["memory_limit"] = value } } -// MapPeekContainer sets the optional container attribute to value. +// MapUnstageContainer sets the optional container attribute to value. // If not specified, defaults to "" -func MapPeekContainer(value string) MapPeekAttr { +func MapUnstageContainer(value string) MapUnstageAttr { return func(m optionalAttr) { m["container"] = value } } -// MapPeekSharedName sets the optional shared_name attribute to value. +// MapUnstageSharedName sets the optional shared_name attribute to value. 
// If not specified, defaults to "" -func MapPeekSharedName(value string) MapPeekAttr { +func MapUnstageSharedName(value string) MapUnstageAttr { return func(m optionalAttr) { m["shared_name"] = value } } -// Op peeks at the values at the specified key. If the +// Op removes and returns the values associated with the key // -// underlying container does not contain this key -// this op will block until it does. -func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { +// from the underlying container. If the underlying container +// does not contain this key, the op will block until it does. +func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } @@ -5668,7 +5811,7 @@ func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataTyp a(attrs) } opspec := tf.OpSpec{ - Type: "MapPeek", + Type: "MapUnstage", Input: []tf.Input{ key, indices, }, @@ -5681,202 +5824,12 @@ func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataTyp var idx int var err error if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapPeek", err) + scope.UpdateErr("MapUnstage", err) return } return values } -// MapStageAttr is an optional argument to MapStage. -type MapStageAttr func(optionalAttr) - -// MapStageCapacity sets the optional capacity attribute to value. -// -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapStageCapacity(value int64) MapStageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapStageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapStageMemoryLimit(value int64) MapStageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapStageContainer sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. -// If not specified, defaults to "" -func MapStageContainer(value string) MapStageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapStageSharedName sets the optional shared_name attribute to value. -// -// value: It is necessary to match this name to the matching Unstage Op. -// If not specified, defaults to "" -func MapStageSharedName(value string) MapStageAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Stage (key, values) in the underlying container which behaves like a hashtable. -// -// Arguments: -// key: int64 -// -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. -// -// -// Returns the created operation. -func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...MapStageAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapStage", - Input: []tf.Input{ - key, indices, tf.OutputList(values), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// StageClearAttr is an optional argument to StageClear. 
-type StageClearAttr func(optionalAttr) - -// StageClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageClearCapacity(value int64) StageClearAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// StageClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageClearMemoryLimit(value int64) StageClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// StageClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StageClearContainer(value string) StageClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// StageClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StageClearSharedName(value string) StageClearAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes all elements in the underlying container. -// -// Returns the created operation. -func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StageClear", - - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// StageSizeAttr is an optional argument to StageSize. -type StageSizeAttr func(optionalAttr) - -// StageSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageSizeCapacity(value int64) StageSizeAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// StageSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageSizeMemoryLimit(value int64) StageSizeAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// StageSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StageSizeContainer(value string) StageSizeAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// StageSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StageSizeSharedName(value string) StageSizeAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op returns the number of elements in the underlying container. -func StageSize(scope *Scope, dtypes []tf.DataType, optional ...StageSizeAttr) (size tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StageSize", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // StagePeekAttr is an optional argument to StagePeek. type StagePeekAttr func(optionalAttr) @@ -5949,75 +5902,6 @@ func StagePeek(scope *Scope, index tf.Output, dtypes []tf.DataType, optional ... return values } -// UnstageAttr is an optional argument to Unstage. -type UnstageAttr func(optionalAttr) - -// UnstageCapacity sets the optional capacity attribute to value. 
-// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func UnstageCapacity(value int64) UnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// UnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func UnstageMemoryLimit(value int64) UnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// UnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnstageContainer(value string) UnstageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// UnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnstageSharedName(value string) UnstageAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op is similar to a lightweight Dequeue. -// -// The basic functionality is similar to dequeue with many fewer -// capabilities and options. This Op is optimized for performance. -func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Unstage", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("Unstage", err) - return - } - return values -} - // StageAttr is an optional argument to Stage. type StageAttr func(optionalAttr) @@ -6096,18 +5980,72 @@ func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Opera return scope.AddOperation(opspec) } -// Delete the tensor specified by its handle in the session. +// Pads a tensor with mirrored values. +// +// This operation pads a `input` with mirrored values according to the `paddings` +// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is +// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many values to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many values to add after the contents of `input` +// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater +// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true +// (if false, respectively). +// +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 2, 3], [4, 5, 6]]. +// # 'paddings' is [[1, 1]], [2, 2]]. +// # 'mode' is SYMMETRIC. +// # rank of 't' is 2. +// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] +// [2, 1, 1, 2, 3, 3, 2] +// [5, 4, 4, 5, 6, 6, 5] +// [5, 4, 4, 5, 6, 6, 5]] +// ``` // // Arguments: -// handle: The handle for a tensor stored in the session state. +// input: The input tensor to be padded. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions +// do not include the borders, while in symmetric mode the padded regions +// do include the borders. 
For example, if `input` is `[1, 2, 3]` and `paddings` +// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and +// it is `[1, 2, 3, 3, 2]` in symmetric mode. +// +// Returns The padded tensor. +func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"mode": mode} + opspec := tf.OpSpec{ + Type: "MirrorPad", + Input: []tf.Input{ + input, paddings, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Use TensorArrayCloseV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 // // Returns the created operation. -func DeleteSessionTensor(scope *Scope, handle tf.Output) (o *tf.Operation) { +func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DeleteSessionTensor", + Type: "TensorArrayCloseV2", Input: []tf.Input{ handle, }, @@ -6115,22 +6053,42 @@ func DeleteSessionTensor(scope *Scope, handle tf.Output) (o *tf.Operation) { return scope.AddOperation(opspec) } -// Store the input tensor in the state of the current session. +// PlaceholderAttr is an optional argument to Placeholder. +type PlaceholderAttr func(optionalAttr) + +// PlaceholderShape sets the optional shape attribute to value. +// +// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the +// shape is unconstrained. +// If not specified, defaults to +func PlaceholderShape(value tf.Shape) PlaceholderAttr { + return func(m optionalAttr) { + m["shape"] = value + } +} + +// A placeholder op for a value that will be fed into the computation. +// +// N.B. This operation will fail with an error if it is executed. It is +// intended as a way to represent a value that will always be fed, and to +// provide attrs that enable the fed value to be checked at runtime. // // Arguments: -// value: The tensor to be stored. +// dtype: The type of elements in the tensor. // -// Returns The handle for the tensor stored in the session state, represented -// as a string. -func GetSessionHandle(scope *Scope, value tf.Output) (handle tf.Output) { +// Returns A placeholder tensor that must be replaced using the feed mechanism. +func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "GetSessionHandle", - Input: []tf.Input{ - value, - }, + Type: "Placeholder", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -6184,23 +6142,56 @@ func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtyp return op.Output(0), op.Output(1) } -// Deprecated. Use TensorArrayGradV3 +// Deprecated. 
Use TensorArrayScatterV3 // -// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 -func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 +func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayWriteV2", + Type: "TensorArrayScatterV2", Input: []tf.Input{ - handle, index, value, flow_in, + handle, indices, value, flow_in, }, } op := scope.AddOperation(opspec) return op.Output(0) } +// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. +type TensorArrayGatherV2Attr func(optionalAttr) + +// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Deprecated. Use TensorArrayGatherV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 +func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayGatherV2", + Input: []tf.Input{ + handle, indices, flow_in, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Deprecated. Use TensorArrayGradV3 // // DEPRECATED at GraphDef version 26: Use TensorArrayGradV3 @@ -6220,146 +6211,49 @@ func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source return op.Output(0) } -// TensorArrayV2Attr is an optional argument to TensorArrayV2. -type TensorArrayV2Attr func(optionalAttr) - -// TensorArrayV2ElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorArrayV2ElementShape(value tf.Shape) TensorArrayV2Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// TensorArrayV2DynamicSize sets the optional dynamic_size attribute to value. -// If not specified, defaults to false -func TensorArrayV2DynamicSize(value bool) TensorArrayV2Attr { - return func(m optionalAttr) { - m["dynamic_size"] = value - } -} - -// TensorArrayV2ClearAfterRead sets the optional clear_after_read attribute to value. -// If not specified, defaults to true -func TensorArrayV2ClearAfterRead(value bool) TensorArrayV2Attr { - return func(m optionalAttr) { - m["clear_after_read"] = value - } -} - -// TensorArrayV2TensorArrayName sets the optional tensor_array_name attribute to value. -// If not specified, defaults to "" -func TensorArrayV2TensorArrayName(value string) TensorArrayV2Attr { - return func(m optionalAttr) { - m["tensor_array_name"] = value - } -} - -// Deprecated. Use TensorArrayV3 +// Delete the TensorArray from its resource container. // -// DEPRECATED at GraphDef version 26: Use TensorArrayV3 -func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV2Attr) (handle tf.Output) { +// This enables the user to close and release the resource in the middle +// of a step/run. +// +// Arguments: +// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). 
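// Editorial sketch, not part of the generated wrappers: combining the
// Placeholder and MirrorPad ops documented above. The shape and padding
// values are assumptions for illustration.
//
//	t := op.Placeholder(s, tf.Float, op.PlaceholderShape(tf.MakeShape(2, 3)))
//	paddings := op.Const(s.SubScope("pad"), [][]int32{{1, 1}, {2, 2}})
//	padded := op.MirrorPad(s, t, paddings, "SYMMETRIC")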
+// +// Returns the created operation. +func TensorArrayCloseV3(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TensorArrayV2", + Type: "TensorArrayCloseV3", Input: []tf.Input{ - size, + handle, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Split the data from the input value into TensorArray elements. -// -// Assuming that `lengths` takes on values -// -// ```(n0, n1, ..., n(T-1))``` -// -// and that `value` has shape -// -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, -// -// this splits values into a TensorArray with T tensors. -// -// TensorArray index t will be the subtensor of values with starting position -// -// ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` -// -// and having size -// -// ```nt x d0 x d1 x ...``` +// Get the current size of the TensorArray. // // Arguments: -// handle: The handle to a TensorArray. -// value: The concatenated tensor to write to the TensorArray. -// lengths: The vector of lengths, how to split the rows of value into the -// TensorArray. +// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). // flow_in: A float scalar that enforces proper chaining of operations. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Returns The current size of the TensorArray. +func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArraySplitV3", + Type: "TensorArraySizeV3", Input: []tf.Input{ - handle, value, lengths, flow_in, + handle, flow_in, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// EmptyAttr is an optional argument to Empty. -type EmptyAttr func(optionalAttr) - -// EmptyInit sets the optional init attribute to value. -// -// value: If True, initialize the returned tensor with the default value of dtype. Otherwise, the implementation is free not to initializethe tensor's content. -// If not specified, defaults to false -func EmptyInit(value bool) EmptyAttr { - return func(m optionalAttr) { - m["init"] = value - } -} - -// Creates a tensor with the given shape. -// -// This operation creates a tensor of `shape` and `dtype`. -// -// Arguments: -// shape: 1-D. Represents the shape of the output tensor. -// -// -// Returns A `Tensor` of type `T`. -func Empty(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...EmptyAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Empty", - Input: []tf.Input{ - shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // TensorArrayConcatV3Attr is an optional argument to TensorArrayConcatV3. type TensorArrayConcatV3Attr func(optionalAttr) @@ -6443,47 +6337,24 @@ func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, val return op.Output(0) } -// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. -type TensorArrayGatherV3Attr func(optionalAttr) - -// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. 
-// -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// Gather specific elements from the TensorArray into output `value`. -// -// All elements selected by `indices` must have the same shape. +// Push an element onto the tensor_array. // // Arguments: // handle: The handle to a TensorArray. -// indices: The locations in the TensorArray from which to read tensor elements. +// index: The position to write to inside the TensorArray. +// value: The tensor to write to the TensorArray. // flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. // -// Returns All of the elements in the TensorArray, concatenated along a new -// axis (the new dimension 0). -func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { +// Returns A float scalar that enforces proper chaining of operations. +func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "TensorArrayGatherV3", + Type: "TensorArrayWriteV3", Input: []tf.Input{ - handle, indices, flow_in, + handle, index, value, flow_in, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -6580,151 +6451,98 @@ func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source return op.Output(0), op.Output(1) } -// Pop the element at the top of the stack. +// Delete the stack from its resource container. // // Arguments: // handle: The handle to a stack. -// elem_type: The type of the elem that is popped. // -// Returns The tensor that is popped from the top of the stack. -func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { +// Returns the created operation. +func StackCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StackCloseV2", + Input: []tf.Input{ + handle, + }, + } + return scope.AddOperation(opspec) +} + +// StackPushV2Attr is an optional argument to StackPushV2. +type StackPushV2Attr func(optionalAttr) + +// StackPushV2SwapMemory sets the optional swap_memory attribute to value. +// +// value: Swap `elem` to CPU. Default to false. +// If not specified, defaults to false +func StackPushV2SwapMemory(value bool) StackPushV2Attr { + return func(m optionalAttr) { + m["swap_memory"] = value + } +} + +// Push an element onto the stack. +// +// Arguments: +// handle: The handle to a stack. +// elem: The tensor to be pushed onto the stack. +// +// Returns The same tensor as the input 'elem'. 
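// Editorial sketch, not part of the generated wrappers: a first-in last-out
// stack built from StackV2 and StackPushV2 as documented here. The element
// value and stack name are assumptions.
//
//	handle := op.StackV2(s, op.Const(s.SubScope("max"), int32(-1)), tf.Float,
//		op.StackV2StackName("scratch"))
//	pushed := op.StackPushV2(s, handle, op.Const(s.SubScope("elem"), float32(3.14)),
//		op.StackPushV2SwapMemory(false))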
+func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...StackPushV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StackPushV2", + Input: []tf.Input{ + handle, elem, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StackV2Attr is an optional argument to StackV2. +type StackV2Attr func(optionalAttr) + +// StackV2StackName sets the optional stack_name attribute to value. +// +// value: Overrides the name used for the temporary stack resource. Default +// value is the name of the 'Stack' op (which is guaranteed unique). +// If not specified, defaults to "" +func StackV2StackName(value string) StackV2Attr { + return func(m optionalAttr) { + m["stack_name"] = value + } +} + +// A stack that produces elements in first-in last-out order. +// +// Arguments: +// max_size: The maximum size of the stack if non-negative. If negative, the stack +// size is unlimited. +// elem_type: The type of the elements on the stack. +// +// Returns The handle to the stack. +func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"elem_type": elem_type} - opspec := tf.OpSpec{ - Type: "StackPopV2", - Input: []tf.Input{ - handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// OneHotAttr is an optional argument to OneHot. -type OneHotAttr func(optionalAttr) - -// OneHotAxis sets the optional axis attribute to value. -// -// value: The axis to fill (default: -1, a new inner-most axis). -// If not specified, defaults to -1 -func OneHotAxis(value int64) OneHotAttr { - return func(m optionalAttr) { - m["axis"] = value - } -} - -// Returns a one-hot tensor. -// -// The locations represented by indices in `indices` take value `on_value`, -// while all other locations take value `off_value`. -// -// If the input `indices` is rank `N`, the output will have rank `N+1`, -// The new axis is created at dimension `axis` (default: the new axis is -// appended at the end). -// -// If `indices` is a scalar the output shape will be a vector of length `depth`. 
-// -// If `indices` is a vector of length `features`, the output shape will be: -// ``` -// features x depth if axis == -1 -// depth x features if axis == 0 -// ``` -// -// If `indices` is a matrix (batch) with shape `[batch, features]`, -// the output shape will be: -// ``` -// batch x features x depth if axis == -1 -// batch x depth x features if axis == 1 -// depth x batch x features if axis == 0 -// ``` -// -// -// Examples -// ========= -// -// Suppose that -// ``` -// indices = [0, 2, -1, 1] -// depth = 3 -// on_value = 5.0 -// off_value = 0.0 -// axis = -1 -// ``` -// -// Then output is `[4 x 3]`: -// ``` -// output = -// [5.0 0.0 0.0] // one_hot(0) -// [0.0 0.0 5.0] // one_hot(2) -// [0.0 0.0 0.0] // one_hot(-1) -// [0.0 5.0 0.0] // one_hot(1) -// ``` -// -// Suppose that -// ``` -// indices = [0, 2, -1, 1] -// depth = 3 -// on_value = 0.0 -// off_value = 3.0 -// axis = 0 -// ``` -// -// Then output is `[3 x 4]`: -// ``` -// output = -// [0.0 3.0 3.0 3.0] -// [3.0 3.0 3.0 0.0] -// [3.0 3.0 3.0 3.0] -// [3.0 0.0 3.0 3.0] -// // ^ one_hot(0) -// // ^ one_hot(2) -// // ^ one_hot(-1) -// // ^ one_hot(1) -// ``` -// -// Suppose that -// ``` -// indices = [[0, 2], [1, -1]] -// depth = 3 -// on_value = 1.0 -// off_value = 0.0 -// axis = -1 -// ``` -// -// Then output is `[2 x 2 x 3]`: -// ``` -// output = -// [ -// [1.0, 0.0, 0.0] // one_hot(0) -// [0.0, 0.0, 1.0] // one_hot(2) -// ][ -// [0.0, 1.0, 0.0] // one_hot(1) -// [0.0, 0.0, 0.0] // one_hot(-1) -// ] -// ``` -// -// Arguments: -// indices: A tensor of indices. -// depth: A scalar defining the depth of the one hot dimension. -// on_value: A scalar defining the value to fill in output when `indices[j] = i`. -// off_value: A scalar defining the value to fill in output when `indices[j] != i`. -// -// Returns The one-hot tensor. -func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OneHot", + Type: "StackV2", Input: []tf.Input{ - indices, depth, on_value, off_value, + max_size, }, Attrs: attrs, } @@ -6732,147 +6550,79 @@ func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output return op.Output(0) } -// Computes the number of elements in the given queue. +// Add the quantile summaries to each quantile stream resource. +// +// An op that adds a list of quantile summaries to a quantile stream resource. Each +// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank) +// for a single feature. // // Arguments: -// handle: The handle to a queue. +// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. +// summaries: string; List of Rank 2 Tensor each containing the summaries for a single feature. // -// Returns The number of elements in the given queue. -func QueueSizeV2(scope *Scope, handle tf.Output) (size tf.Output) { +// Returns the created operation. 
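// Editorial sketch, not part of the generated wrappers: adding per-feature
// quantile summaries to a quantile stream resource. The resource handle and
// summary tensor are assumed to come from the other BoostedTrees quantile ops.
//
//	addSummaries := op.BoostedTreesQuantileStreamResourceAddSummaries(
//		s, streamHandle, []tf.Output{featureSummary})
//	_ = addSummaries // *tf.Operation, run for its side effect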
+func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QueueSizeV2", + Type: "BoostedTreesQuantileStreamResourceAddSummaries", Input: []tf.Input{ - handle, + quantile_stream_resource_handle, tf.OutputList(summaries), }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. -type QueueDequeueManyV2Attr func(optionalAttr) +// EmptyAttr is an optional argument to Empty. +type EmptyAttr func(optionalAttr) -// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. +// EmptyInit sets the optional init attribute to value. // -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } -} - -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// If the queue is closed and there are fewer than `n` elements, then an -// OutOfRange error is returned. -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size `n` in the 0th dimension. -// -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. -// -// N.B. If the queue is empty, this operation will block until `n` elements -// have been dequeued (or 'timeout_ms' elapses, if specified). -// -// Arguments: -// handle: The handle to a queue. -// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. -// -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueDequeueManyV2", - Input: []tf.Input{ - handle, n, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueManyV2", err) - return - } - return components -} - -// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize. -type QuantizeAndDequantizeAttr func(optionalAttr) - -// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["signed_input"] = value - } -} - -// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value. 
-// If not specified, defaults to 8 -func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value. +// value: If True, initialize the returned tensor with the default value of dtype. Otherwise, the implementation is free not to initializethe tensor's content. // If not specified, defaults to false -func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr { +func EmptyInit(value bool) EmptyAttr { return func(m optionalAttr) { - m["range_given"] = value + m["init"] = value } } -// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["input_min"] = value - } -} - -// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["input_max"] = value - } -} - -// Use QuantizeAndDequantizeV2 instead. +// Creates a tensor with the given shape. // -// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 -func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { +// This operation creates a tensor of `shape` and `dtype`. +// +// Arguments: +// shape: 1-D. Represents the shape of the output tensor. +// +// +// Returns A `Tensor` of type `T`. +func Empty(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...EmptyAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeAndDequantize", + Type: "Empty", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Broadcasts a tensor value to one or more other devices. +func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} + opspec := tf.OpSpec{ + Type: "CollectiveBcastSend", Input: []tf.Input{ input, }, @@ -6882,79 +6632,48 @@ func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAn return op.Output(0) } -// Returns locations of nonzero / true values in a tensor. +// QueueCloseV2Attr is an optional argument to QueueCloseV2. +type QueueCloseV2Attr func(optionalAttr) + +// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. // -// This operation returns the coordinates of true elements in `condition`. The -// coordinates are returned in a 2-D tensor where the first dimension (rows) -// represents the number of true elements, and the second dimension (columns) -// represents the coordinates of the true elements. Keep in mind, the shape of -// the output tensor can vary depending on how many true values there are in -// `condition`. Indices are output in row-major order. +// value: If true, all pending enqueue requests that are +// blocked on the given queue will be canceled. 
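The `Empty`/`EmptyInit` pair above is a small illustration of how the generated optional attributes are passed; a minimal sketch:

```go
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    shape := op.Const(s, []int32{2, 3})
    // EmptyInit(true) asks for dtype's default value (zeros for float32);
    // without it the tensor's contents are unspecified.
    empty := op.Empty(s, shape, tf.Float, op.EmptyInit(true))

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{empty}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // [[0 0 0] [0 0 0]]
}
```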
+// If not specified, defaults to false +func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { + return func(m optionalAttr) { + m["cancel_pending_enqueues"] = value + } +} + +// Closes the given queue. // -// For example: +// This operation signals that no more elements will be enqueued in the +// given queue. Subsequent Enqueue(Many) operations will fail. +// Subsequent Dequeue(Many) operations will continue to succeed if +// sufficient elements remain in the queue. Subsequent Dequeue(Many) +// operations that would block will fail immediately. // -// ``` -// # 'input' tensor is [[True, False] -// # [True, False]] -// # 'input' has two true values, so output has two coordinates. -// # 'input' has rank of 2, so coordinates have two indices. -// where(input) ==> [[0, 0], -// [1, 0]] +// Arguments: +// handle: The handle to a queue. // -// # `condition` tensor is [[[True, False] -// # [True, False]] -// # [[False, True] -// # [False, True]] -// # [[False, False] -// # [False, True]]] -// # 'input' has 5 true values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] -// -// # `condition` tensor is [[[1.5, 0.0] -// # [-0.5, 0.0]] -// # [[0.0, 0.25] -// # [0.0, 0.75]] -// # [[0.0, 0.0] -// # [0.0, 0.01]]] -// # 'input' has 5 nonzero values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] -// -// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j] -// # [0.0 + 0.5j, 0.0 + 0.0j]] -// # [[0.0 + 0.0j, 0.25 + 1.5j] -// # [0.0 + 0.0j, 0.75 + 0.0j]] -// # [[0.0 + 0.0j, 0.0 + 0.0j] -// # [0.0 + 0.0j, 0.01 + 0.0j]]] -// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] -// ``` -func Where(scope *Scope, condition tf.Output) (index tf.Output) { +// Returns the created operation. +func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Where", - Input: []tf.Input{ - condition, - }, + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } - op := scope.AddOperation(opspec) - return op.Output(0) + opspec := tf.OpSpec{ + Type: "QueueCloseV2", + Input: []tf.Input{ + handle, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) } // QueueDequeueV2Attr is an optional argument to QueueDequeueV2. @@ -7060,133 +6779,150 @@ func QueueEnqueueV2(scope *Scope, handle tf.Output, components []tf.Output, opti return scope.AddOperation(opspec) } -// MfccAttr is an optional argument to Mfcc. -type MfccAttr func(optionalAttr) +// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. +type PriorityQueueV2Attr func(optionalAttr) -// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. +// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. // -// value: The highest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 4000 -func MfccUpperFrequencyLimit(value float32) MfccAttr { +// value: The type of each component in a value. 
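`QueueCloseV2` above returns a `*tf.Operation`, so it is run as a target rather than fetched. A sketch that builds a queue with the `FIFOQueueV2` wrapper defined later in this file, enqueues once, and then closes it with pending enqueues cancelled:

```go
package main

import (
    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    queue := op.FIFOQueueV2(s, []tf.DataType{tf.Float}, op.FIFOQueueV2Capacity(10))
    enq := op.QueueEnqueueV2(s, queue, []tf.Output{op.Const(s, float32(1))})
    closeOp := op.QueueCloseV2(s, queue, op.QueueCloseV2CancelPendingEnqueues(true))

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    // Enqueue first, then close; subsequent enqueues would fail, and any
    // blocked enqueues are cancelled because of the optional attribute above.
    if _, err := sess.Run(nil, nil, []*tf.Operation{enq}); err != nil {
        panic(err)
    }
    if _, err := sess.Run(nil, nil, []*tf.Operation{closeOp}); err != nil {
        panic(err)
    }
}
```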
+// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { return func(m optionalAttr) { - m["upper_frequency_limit"] = value + m["component_types"] = value } } -// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. +// PriorityQueueV2Capacity sets the optional capacity attribute to value. // -// value: The lowest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 20 -func MfccLowerFrequencyLimit(value float32) MfccAttr { +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { return func(m optionalAttr) { - m["lower_frequency_limit"] = value + m["capacity"] = value } } -// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. +// PriorityQueueV2Container sets the optional container attribute to value. // -// value: Resolution of the Mel bank used internally. -// If not specified, defaults to 40 -func MfccFilterbankChannelCount(value int64) MfccAttr { +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func PriorityQueueV2Container(value string) PriorityQueueV2Attr { return func(m optionalAttr) { - m["filterbank_channel_count"] = value + m["container"] = value } } -// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. +// PriorityQueueV2SharedName sets the optional shared_name attribute to value. // -// value: How many output channels to produce per time slice. -// If not specified, defaults to 13 -func MfccDctCoefficientCount(value int64) MfccAttr { +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { return func(m optionalAttr) { - m["dct_coefficient_count"] = value + m["shared_name"] = value } } -// Transforms a spectrogram into a form that's useful for speech recognition. +// A queue that produces elements sorted by the first component value. // -// Mel Frequency Cepstral Coefficients are a way of representing audio data that's -// been effective as an input feature for machine learning. They are created by -// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the -// higher frequencies that are less significant to the human ear. They have a long -// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum -// is a good resource to learn more. +// Note that the PriorityQueue requires the first component of any element +// to be a scalar int64, in addition to the other elements declared by +// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue +// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra +// entry in their input (resp. output) lists. // // Arguments: -// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared -// set to true. -// sample_rate: How many samples per second the source audio used. -func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { +// shapes: The shape of each component in a value. 
The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// +// Returns The handle to the queue. +func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"shapes": shapes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Mfcc", - Input: []tf.Input{ - spectrogram, sample_rate, - }, + Type: "PriorityQueueV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// PaddingFIFOQueueV2Attr is an optional argument to PaddingFIFOQueueV2. -type PaddingFIFOQueueV2Attr func(optionalAttr) +// Elementwise computes the bitwise XOR of `x` and `y`. +// +// The result will have those bits set, that are different in `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BitwiseXor", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// PaddingFIFOQueueV2Shapes sets the optional shapes attribute to value. +// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. +type FIFOQueueV2Attr func(optionalAttr) + +// FIFOQueueV2Shapes sets the optional shapes attribute to value. // // value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. -// Shapes of fixed rank but variable size are allowed by setting -// any shape dimension to -1. In this case, the inputs' shape may vary along -// the given dimension, and DequeueMany will pad the given dimension with -// zeros up to the maximum shape of all elements in the given batch. -// If the length of this attr is 0, different queue elements may have -// different ranks and shapes, but only one element may be dequeued at a time. +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. // If not specified, defaults to <> // // REQUIRES: len(value) >= 0 -func PaddingFIFOQueueV2Shapes(value []tf.Shape) PaddingFIFOQueueV2Attr { +func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { return func(m optionalAttr) { m["shapes"] = value } } -// PaddingFIFOQueueV2Capacity sets the optional capacity attribute to value. +// FIFOQueueV2Capacity sets the optional capacity attribute to value. // // value: The upper bound on the number of elements in this queue. // Negative numbers mean no limit. // If not specified, defaults to -1 -func PaddingFIFOQueueV2Capacity(value int64) PaddingFIFOQueueV2Attr { +func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { return func(m optionalAttr) { m["capacity"] = value } } -// PaddingFIFOQueueV2Container sets the optional container attribute to value. +// FIFOQueueV2Container sets the optional container attribute to value. // // value: If non-empty, this queue is placed in the given container. // Otherwise, a default container is used. 
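`BitwiseXor` above takes no attributes at all; a minimal sketch:

```go
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    x := op.Const(s, []int32{0, 5, 3, 14})
    y := op.Const(s, []int32{5, 0, 7, 11})
    z := op.BitwiseXor(s, x, y) // elementwise XOR of the underlying bits

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{z}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // [5 5 4 5]
}
```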
// If not specified, defaults to "" -func PaddingFIFOQueueV2Container(value string) PaddingFIFOQueueV2Attr { +func FIFOQueueV2Container(value string) FIFOQueueV2Attr { return func(m optionalAttr) { m["container"] = value } } -// PaddingFIFOQueueV2SharedName sets the optional shared_name attribute to value. +// FIFOQueueV2SharedName sets the optional shared_name attribute to value. // // value: If non-empty, this queue will be shared under the given name // across multiple sessions. // If not specified, defaults to "" -func PaddingFIFOQueueV2SharedName(value string) PaddingFIFOQueueV2Attr { +func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { return func(m optionalAttr) { m["shared_name"] = value } @@ -7194,15 +6930,11 @@ func PaddingFIFOQueueV2SharedName(value string) PaddingFIFOQueueV2Attr { // A queue that produces elements in first-in first-out order. // -// Variable-size shapes are allowed by setting the corresponding shape dimensions -// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum -// size of any given element in the minibatch. See below for details. -// // Arguments: // component_types: The type of each component in a value. // // Returns The handle to the queue. -func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...PaddingFIFOQueueV2Attr) (handle tf.Output) { +func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } @@ -7211,7 +6943,114 @@ func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional .. a(attrs) } opspec := tf.OpSpec{ - Type: "PaddingFIFOQueueV2", + Type: "FIFOQueueV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. +type RandomShuffleQueueV2Attr func(optionalAttr) + +// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. +// +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["shapes"] = value + } +} + +// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. +// +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// +// value: Dequeue will block unless there would be this +// many elements after the dequeue or the queue is closed. This +// ensures a minimum level of mixing of elements. +// If not specified, defaults to 0 +func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["min_after_dequeue"] = value + } +} + +// RandomShuffleQueueV2Seed sets the optional seed attribute to value. +// +// value: If either seed or seed2 is set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, a random seed is used. 
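The `FIFOQueueV2` options above compose as variadic arguments. A sketch that creates a bounded queue of scalar strings, enqueues, and dequeues; the `QueueDequeueV2` signature is assumed to follow the usual generated form (handle plus `component_types`), and the shared name is illustrative:

```go
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    queue := op.FIFOQueueV2(s,
        []tf.DataType{tf.String},
        op.FIFOQueueV2Capacity(100),
        op.FIFOQueueV2Shapes([]tf.Shape{tf.ScalarShape()}),
        op.FIFOQueueV2SharedName("fifo_example"))
    enq := op.QueueEnqueueV2(s, queue, []tf.Output{op.Const(s, "hello")})
    // Assumed signature of the QueueDequeueV2 wrapper generated elsewhere in this file.
    deq := op.QueueDequeueV2(s, queue, []tf.DataType{tf.String})

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    if _, err := sess.Run(nil, nil, []*tf.Operation{enq}); err != nil {
        panic(err)
    }
    out, err := sess.Run(nil, []tf.Output{deq[0]}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // hello
}
```

The enqueue and dequeue can be split across `Session.Run` calls because the queue is a stateful resource that persists for the lifetime of the session.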
+// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// RandomShuffleQueueV2Container sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A queue that randomizes the order of elements. +// +// Arguments: +// component_types: The type of each component in a value. +// +// Returns The handle to the queue. +func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomShuffleQueueV2", Attrs: attrs, } @@ -7244,9 +7083,10 @@ func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional .. // // merged.shape = [max(indices)] + constant // -// Values may be merged in parallel, so if an index appears in both `indices[m][i]` -// and `indices[n][j]`, the result may be invalid. This differs from the normal -// DynamicStitch operator that defines the behavior in that case. +// Values are merged in order, so if an index appears in both `indices[m][i]` and +// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the +// merged result. If you do not need this guarantee, ParallelDynamicStitch might +// perform better on some devices. // // For example: // @@ -7282,12 +7122,12 @@ func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional .. //
// //
-func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ParallelDynamicStitch", + Type: "DynamicStitch", Input: []tf.Input{ tf.OutputList(indices), tf.OutputList(data), }, @@ -7365,6 +7205,58 @@ func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_pa return outputs } +// MultiDeviceIteratorFromStringHandleAttr is an optional argument to MultiDeviceIteratorFromStringHandle. +type MultiDeviceIteratorFromStringHandleAttr func(optionalAttr) + +// MultiDeviceIteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. +// +// value: The type list for the return values. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func MultiDeviceIteratorFromStringHandleOutputTypes(value []tf.DataType) MultiDeviceIteratorFromStringHandleAttr { + return func(m optionalAttr) { + m["output_types"] = value + } +} + +// MultiDeviceIteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. +// +// value: The list of shapes being produced. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func MultiDeviceIteratorFromStringHandleOutputShapes(value []tf.Shape) MultiDeviceIteratorFromStringHandleAttr { + return func(m optionalAttr) { + m["output_shapes"] = value + } +} + +// Generates a MultiDeviceIterator resource from its provided string handle. +// +// Arguments: +// string_handle: String representing the resource. +// +// Returns A MultiDeviceIterator resource. +func MultiDeviceIteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...MultiDeviceIteratorFromStringHandleAttr) (multi_device_iterator tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MultiDeviceIteratorFromStringHandle", + Input: []tf.Input{ + string_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Produces a string handle for the given MultiDeviceIterator. // // Arguments: @@ -7385,61 +7277,26 @@ func MultiDeviceIteratorToStringHandle(scope *Scope, multi_device_iterator tf.Ou return op.Output(0) } -// Checks whether a tree has been initialized. +// Returns a tensor of ones with the same shape and type as x. // // Arguments: -// tree_handle: Handle to the tree. +// x: a tensor of type T. // -// Returns Whether the tree is initialized. -func TensorForestTreeIsInitializedOp(scope *Scope, tree_handle tf.Output) (is_initialized tf.Output) { +// Returns a tensor of the same shape and type as x but filled with ones. +func OnesLike(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorForestTreeIsInitializedOp", + Type: "OnesLike", Input: []tf.Input{ - tree_handle, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Gets next element for the provided shard number. -// -// Arguments: -// multi_device_iterator: A MultiDeviceIterator resource. -// shard_num: Integer representing which shard to fetch data for. -// incarnation_id: Which incarnation of the MultiDeviceIterator is running. -// output_types: The type list for the return values. -// output_shapes: The list of shapes being produced. -// -// Returns Result of the get_next on the dataset. 
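`DynamicStitch`, which replaces `ParallelDynamicStitch` at this position in the regenerated file, interleaves slices of `data` according to `indices`, with later indices winning on collisions. A sketch of a simple two-way interleave:

```go
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    indices := []tf.Output{
        op.Const(s, []int32{0, 2}),
        op.Const(s, []int32{1, 3}),
    }
    data := []tf.Output{
        op.Const(s, []float32{10, 30}),
        op.Const(s, []float32{20, 40}),
    }
    merged := op.DynamicStitch(s, indices, data)

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{merged}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // [10 20 30 40]
}
```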
-func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.Output, shard_num tf.Output, incarnation_id tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "MultiDeviceIteratorGetNextFromShard", - Input: []tf.Input{ - multi_device_iterator, shard_num, incarnation_id, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("MultiDeviceIteratorGetNextFromShard", err) - return - } - return components -} - // Initializes the multi device iterator with the given dataset. // // Arguments: @@ -7463,64 +7320,96 @@ func MultiDeviceIteratorInit(scope *Scope, dataset tf.Output, multi_device_itera return op.Output(0) } -// Copy a tensor setting everything outside a central band in each innermost matrix +// MapStageAttr is an optional argument to MapStage. +type MapStageAttr func(optionalAttr) + +// MapStageCapacity sets the optional capacity attribute to value. // -// to zero. +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. +// If not specified, defaults to 0 // -// The `band` part is computed as follows: -// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a -// tensor with the same shape where +// REQUIRES: value >= 0 +func MapStageCapacity(value int64) MapStageAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapStageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. +// REQUIRES: value >= 0 +func MapStageMemoryLimit(value int64) MapStageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapStageContainer sets the optional container attribute to value. // -// The indicator function +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func MapStageContainer(value string) MapStageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapStageSharedName sets the optional shared_name attribute to value. // -// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && -// (num_upper < 0 || (n-m) <= num_upper)`. -// -// For example: -// -// ``` -// # if 'input' is [[ 0, 1, 2, 3] -// [-1, 0, 1, 2] -// [-2, -1, 0, 1] -// [-3, -2, -1, 0]], -// -// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] -// [-1, 0, 1, 2] -// [ 0, -1, 0, 1] -// [ 0, 0, -1, 0]], -// -// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] -// [-1, 0, 1, 0] -// [-2, -1, 0, 1] -// [ 0, -2, -1, 0]] -// ``` -// -// Useful special cases: -// -// ``` -// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. -// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. -// tf.matrix_band_part(input, 0, 0) ==> Diagonal. -// ``` +// value: It is necessary to match this name to the matching Unstage Op. 
+// If not specified, defaults to "" +func MapStageSharedName(value string) MapStageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Stage (key, values) in the underlying container which behaves like a hashtable. // // Arguments: -// input: Rank `k` tensor. -// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire -// lower triangle. -// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep -// entire upper triangle. +// key: int64 // -// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor. -func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) { +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. +// +// +// Returns the created operation. +func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...MapStageAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MapStage", + Input: []tf.Input{ + key, indices, tf.OutputList(values), + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Store the input tensor in the state of the current session. +// +// Arguments: +// value: The tensor to be stored. +// +// Returns The handle for the tensor stored in the session state, represented +// as a string. +func GetSessionHandle(scope *Scope, value tf.Output) (handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixBandPart", + Type: "GetSessionHandle", Input: []tf.Input{ - input, num_lower, num_upper, + value, }, } op := scope.AddOperation(opspec) @@ -7585,64 +7474,38 @@ func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) { return op.Output(0) } -// Deserializes a proto into the tree handle +// OptimizeDatasetAttr is an optional argument to OptimizeDataset. +type OptimizeDatasetAttr func(optionalAttr) + +// OptimizeDatasetOptimizationConfigs sets the optional optimization_configs attribute to value. +// If not specified, defaults to <> +func OptimizeDatasetOptimizationConfigs(value []string) OptimizeDatasetAttr { + return func(m optionalAttr) { + m["optimization_configs"] = value + } +} + +// Creates a dataset by applying optimizations to `input_dataset`. +// +// Creates a dataset by applying optimizations to `input_dataset`. // // Arguments: -// tree_handle: Handle to the tree resource to be restored. -// tree_config: Serialied proto string of the boosted_trees.Tree proto. +// input_dataset: A variant tensor representing the input dataset. +// optimizations: A `tf.string` vector `tf.Tensor` identifying optimizations to use. // -// Returns the created operation. -func TensorForestTreeDeserialize(scope *Scope, tree_handle tf.Output, tree_config tf.Output) (o *tf.Operation) { +// +func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...OptimizeDatasetAttr) (handle tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "TensorForestTreeDeserialize", - Input: []tf.Input{ - tree_handle, tree_config, - }, - } - return scope.AddOperation(opspec) -} - -// Constructs an Optional variant from a tuple of tensors. 
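`MapStage` above stages a `(key, values)` pair into a hash-table-like staging area and, like the other side-effecting wrappers, returns a `*tf.Operation` to run as a target. A sketch, assuming the int32 `indices` input names which slots of `dtypes` the supplied values fill; the shared name is illustrative:

```go
package main

import (
    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    key := op.Const(s, int64(7))
    indices := op.Const(s, []int32{0}) // the staged value fills slot 0 of dtypes
    values := []tf.Output{op.Const(s, []float32{1, 2, 3})}
    stage := op.MapStage(s, key, indices, values, []tf.DataType{tf.Float},
        op.MapStageCapacity(16),
        op.MapStageSharedName("staging_example"))

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    // MapStage has no outputs; run it as a target for its side effect.
    if _, err := sess.Run(nil, nil, []*tf.Operation{stage}); err != nil {
        panic(err)
    }
}
```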
-func OptionalFromValue(scope *Scope, components []tf.Output) (optional tf.Output) { - if scope.Err() != nil { - return + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) } opspec := tf.OpSpec{ - Type: "OptionalFromValue", + Type: "OptimizeDataset", Input: []tf.Input{ - tf.OutputList(components), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs a tensor containing the reduction across all input tensors. -// -// Outputs a tensor containing the reduction across all input tensors passed to ops -// within the same `shared_name. -// -// The graph should be constructed so if one op runs with shared_name value `c`, -// then `num_devices` ops will run with shared_name value `c`. Failure to do so -// will cause the graph execution to fail to complete. -// -// input: the input to the reduction -// data: the value of the reduction across all `num_devices` devices. -// reduction: the reduction operation to perform. -// num_devices: The number of devices participating in this reduction. -// shared_name: Identifier that shared between ops of the same reduction. -func NcclAllReduce(scope *Scope, input tf.Output, reduction string, num_devices int64, shared_name string) (data tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"reduction": reduction, "num_devices": num_devices, "shared_name": shared_name} - opspec := tf.OpSpec{ - Type: "NcclAllReduce", - Input: []tf.Input{ - input, + input_dataset, optimizations, }, Attrs: attrs, } @@ -7650,31 +7513,85 @@ func NcclAllReduce(scope *Scope, input tf.Output, reduction string, num_devices return op.Output(0) } -// RegexReplaceAttr is an optional argument to RegexReplace. -type RegexReplaceAttr func(optionalAttr) - -// RegexReplaceReplaceGlobal sets the optional replace_global attribute to value. +// Returns a serialized GraphDef representing `input_dataset`. // -// value: If True, the replacement is global, otherwise the replacement -// is done only on the first match. -// If not specified, defaults to true -func RegexReplaceReplaceGlobal(value bool) RegexReplaceAttr { +// Returns a graph representation for `input_dataset`. +// +// Arguments: +// input_dataset: A variant tensor representing the dataset to return the graph representation for. +// +// Returns The graph representation of the dataset (as serialized GraphDef). +func DatasetToGraph(scope *Scope, input_dataset tf.Output) (graph tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DatasetToGraph", + Input: []tf.Input{ + input_dataset, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Converts the given variant tensor to an iterator and stores it in the given resource. +// +// Arguments: +// resource_handle: A handle to an iterator resource. +// serialized: A variant tensor storing the state of the iterator contained in the +// resource. +// +// Returns the created operation. +func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DeserializeIterator", + Input: []tf.Input{ + resource_handle, serialized, + }, + } + return scope.AddOperation(opspec) +} + +// IteratorFromStringHandleAttr is an optional argument to IteratorFromStringHandle. 
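`DatasetToGraph` above serializes a dataset's definition to a `GraphDef` string. A sketch that feeds it the `TFRecordDataset` wrapper defined later in this file; the file name is hypothetical, and the file is not read until the dataset is iterated:

```go
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    filenames := op.Const(s, []string{"data.tfrecord"}) // hypothetical path
    compression := op.Const(s, "")                      // no compression
    bufferSize := op.Const(s, int64(0))                 // default buffering
    dataset := op.TFRecordDataset(s, filenames, compression, bufferSize)
    serialized := op.DatasetToGraph(s, dataset) // scalar string holding a GraphDef

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{serialized}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(len(out[0].Value().(string)), "bytes of serialized GraphDef")
}
```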
+type IteratorFromStringHandleAttr func(optionalAttr) + +// IteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. +// +// value: If specified, defines the type of each tuple component in an +// element produced by the resulting iterator. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func IteratorFromStringHandleOutputTypes(value []tf.DataType) IteratorFromStringHandleAttr { return func(m optionalAttr) { - m["replace_global"] = value + m["output_types"] = value } } -// Replaces the match of pattern in input with rewrite. +// IteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. // -// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// value: If specified, defines the shape of each tuple component in an +// element produced by the resulting iterator. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func IteratorFromStringHandleOutputShapes(value []tf.Shape) IteratorFromStringHandleAttr { + return func(m optionalAttr) { + m["output_shapes"] = value + } +} + +// Converts the given string representing a handle to an iterator to a resource. // // Arguments: -// input: The text to be processed. -// pattern: The regular expression to match the input. -// rewrite: The rewrite to be applied to the matched expression. +// string_handle: A string representation of the given handle. // -// Returns The text after applying pattern and rewrite. -func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.Output, optional ...RegexReplaceAttr) (output tf.Output) { +// Returns A handle to an iterator resource. +func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...IteratorFromStringHandleAttr) (resource_handle tf.Output) { if scope.Err() != nil { return } @@ -7683,9 +7600,9 @@ func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.O a(attrs) } opspec := tf.OpSpec{ - Type: "RegexReplace", + Type: "IteratorFromStringHandle", Input: []tf.Input{ - input, pattern, rewrite, + string_handle, }, Attrs: attrs, } @@ -7693,47 +7610,286 @@ func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.O return op.Output(0) } -// Quantized Batch normalization. -// -// This op is deprecated and will be removed in the future. Prefer -// `tf.nn.batch_normalization`. +// Converts the given `resource_handle` representing an iterator to a string. // // Arguments: -// t: A 4D input Tensor. -// t_min: The value represented by the lowest quantized input. -// t_max: The value represented by the highest quantized input. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// m_min: The value represented by the lowest quantized mean. -// m_max: The value represented by the highest quantized mean. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// v_min: The value represented by the lowest quantized variance. -// v_max: The value represented by the highest quantized variance. -// beta: A 1D beta Tensor with size matching the last dimension of t. -// An offset to be added to the normalized tensor. -// beta_min: The value represented by the lowest quantized offset. -// beta_max: The value represented by the highest quantized offset. 
-// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this tensor will be multiplied -// with the normalized tensor. -// gamma_min: The value represented by the lowest quantized gamma. -// gamma_max: The value represented by the highest quantized gamma. +// resource_handle: A handle to an iterator resource. // -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -func QuantizedBatchNormWithGlobalNormalization(scope *Scope, t tf.Output, t_min tf.Output, t_max tf.Output, m tf.Output, m_min tf.Output, m_max tf.Output, v tf.Output, v_min tf.Output, v_max tf.Output, beta tf.Output, beta_min tf.Output, beta_max tf.Output, gamma tf.Output, gamma_min tf.Output, gamma_max tf.Output, out_type tf.DataType, variance_epsilon float32, scale_after_normalization bool) (result tf.Output, result_min tf.Output, result_max tf.Output) { +// Returns A string representation of the given handle. +func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type, "variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "QuantizedBatchNormWithGlobalNormalization", + Type: "IteratorToStringHandle", Input: []tf.Input{ - t, t_min, t_max, m, m_min, m_max, v, v_min, v_max, beta, beta_min, beta_max, gamma, gamma_min, gamma_max, + resource_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Outputs the single element from the given dataset. +// +// Arguments: +// dataset: A handle to a dataset that contains a single element. +// +// +// +// Returns The components of the single element of `input`. +func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "DatasetToSingleElement", + Input: []tf.Input{ + dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("DatasetToSingleElement", err) + return + } + return components +} + +// Computes offsets of concat inputs within its output. +// +// For example: +// +// ``` +// # 'x' is [2, 2, 7] +// # 'y' is [2, 3, 7] +// # 'z' is [2, 5, 7] +// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] +// ``` +// +// This is typically used by gradient computations for a concat operation. +// +// Arguments: +// concat_dim: The dimension along which to concatenate. +// shape: The `N` int32 vectors representing shape of tensors being concatenated. +// +// Returns The `N` int32 vectors representing the starting offset +// of input tensors within the concatenated output. 
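`IteratorToStringHandle` above and `IteratorFromStringHandle` from the previous hunk let an iterator resource travel through a string tensor. A construction-only sketch of the round trip, using the `Iterator` and `IteratorGetNext` wrappers that also appear in this file; the shared name is illustrative, and a `MakeIterator` op would still have to bind a dataset to the iterator before anything can be fetched:

```go
package main

import (
    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    types := []tf.DataType{tf.String}
    shapes := []tf.Shape{tf.ScalarShape()}

    iter := op.Iterator(s, "shared_iter", "", types, shapes)
    handleStr := op.IteratorToStringHandle(s, iter)
    // Recover a usable resource handle from the string form; the optional
    // output_types / output_shapes let the runtime re-check the element signature.
    restored := op.IteratorFromStringHandle(s, handleStr,
        op.IteratorFromStringHandleOutputTypes(types),
        op.IteratorFromStringHandleOutputShapes(shapes))
    next := op.IteratorGetNext(s, restored, types, shapes)
    _ = next // fetched only after a MakeIterator op binds a dataset to `iter`

    if _, err := s.Finalize(); err != nil {
        panic(err)
    }
}
```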
+func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatOffset", + Input: []tf.Input{ + concat_dim, tf.OutputList(shape), + }, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { + scope.UpdateErr("ConcatOffset", err) + return + } + return offset +} + +// Gets the next output from the given iterator. +// +// This operation is a synchronous version IteratorGetNext. It should only be used +// in situations where the iterator does not block the calling thread, or where +// the calling thread is not a member of the thread pool used to execute parallel +// operations (e.g. in eager mode). +func IteratorGetNextSync(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "IteratorGetNextSync", + Input: []tf.Input{ + iterator, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("IteratorGetNextSync", err) + return + } + return components +} + +// Gets the next output from the given iterator . +func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "IteratorGetNext", + Input: []tf.Input{ + iterator, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("IteratorGetNext", err) + return + } + return components +} + +// A container for an iterator resource. +// +// Arguments: +// handle: A handle to the iterator to delete. +// deleter: A variant deleter. +// +// Returns the created operation. +func DeleteIterator(scope *Scope, handle tf.Output, deleter tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DeleteIterator", + Input: []tf.Input{ + handle, deleter, + }, + } + return scope.AddOperation(opspec) +} + +// A container for an iterator resource. +// +// Returns A handle to the iterator that can be passed to a "MakeIterator" or +// "IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents +// resource sharing by name, and does not keep a reference to the resource +// container. +func AnonymousIterator(scope *Scope, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "AnonymousIterator", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// A container for an iterator resource. +// +// Returns A handle to the iterator that can be passed to a "MakeIterator" +// or "IteratorGetNext" op. 
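`ConcatOffset` above can be checked against the example in its comment: shapes `[2,2,7]`, `[2,3,7]`, `[2,5,7]` concatenated along axis 1 yield offsets `[0,0,0]`, `[0,2,0]`, `[0,5,0]`. A minimal sketch:

```go
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    concatDim := op.Const(s, int32(1))
    shapes := []tf.Output{
        op.Const(s, []int32{2, 2, 7}),
        op.Const(s, []int32{2, 3, 7}),
        op.Const(s, []int32{2, 5, 7}),
    }
    offsets := op.ConcatOffset(s, concatDim, shapes)

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, offsets, nil)
    if err != nil {
        panic(err)
    }
    for _, t := range out {
        fmt.Println(t.Value()) // [0 0 0], [0 2 0], [0 5 0]
    }
}
```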
+func Iterator(scope *Scope, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "Iterator", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that emits the records from one or more TFRecord files. +// +// Arguments: +// filenames: A scalar or vector containing the name(s) of the file(s) to be +// read. +// compression_type: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// buffer_size: A scalar representing the number of bytes to buffer. A value of +// 0 means no buffering will be performed. +func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TFRecordDataset", + Input: []tf.Input{ + filenames, compression_type, buffer_size, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// UniqueWithCountsAttr is an optional argument to UniqueWithCounts. +type UniqueWithCountsAttr func(optionalAttr) + +// UniqueWithCountsOutIdx sets the optional out_idx attribute to value. +// If not specified, defaults to DT_INT32 +func UniqueWithCountsOutIdx(value tf.DataType) UniqueWithCountsAttr { + return func(m optionalAttr) { + m["out_idx"] = value + } +} + +// Finds unique elements in a 1-D tensor. +// +// This operation returns a tensor `y` containing all of the unique elements of `x` +// sorted in the same order that they occur in `x`. This operation also returns a +// tensor `idx` the same size as `x` that contains the index of each value of `x` +// in the unique output `y`. Finally, it returns a third tensor `count` that +// contains the count of each element of `y` in `x`. In other words: +// +// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` +// +// For example: +// +// ``` +// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] +// y, idx, count = unique_with_counts(x) +// y ==> [1, 2, 4, 7, 8] +// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +// count ==> [2, 1, 3, 1, 2] +// ``` +// +// Arguments: +// x: 1-D. +// +// Returns 1-D.1-D.1-D. +func UniqueWithCounts(scope *Scope, x tf.Output, optional ...UniqueWithCountsAttr) (y tf.Output, idx tf.Output, count tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UniqueWithCounts", + Input: []tf.Input{ + x, }, Attrs: attrs, } @@ -7741,6 +7897,603 @@ func QuantizedBatchNormWithGlobalNormalization(scope *Scope, t tf.Output, t_min return op.Output(0), op.Output(1), op.Output(2) } +// Creates a dataset that emits the records from one or more binary files. +// +// Arguments: +// filenames: A scalar or a vector containing the name(s) of the file(s) to be +// read. +// header_bytes: A scalar representing the number of bytes to skip at the +// beginning of a file. +// record_bytes: A scalar representing the number of bytes in each record. +// footer_bytes: A scalar representing the number of bytes to skip at the end +// of a file. +// buffer_size: A scalar representing the number of bytes to buffer. Must be > 0. 
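`UniqueWithCounts` above can likewise be exercised against its documented example; a minimal sketch:

```go
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    x := op.Const(s, []int32{1, 1, 2, 4, 4, 4, 7, 8, 8})
    // out_idx defaults to DT_INT32; UniqueWithCountsOutIdx can switch it to int64.
    y, idx, count := op.UniqueWithCounts(s, x)

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{y, idx, count}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // [1 2 4 7 8]
    fmt.Println(out[1].Value()) // [0 0 1 2 2 2 3 4 4]
    fmt.Println(out[2].Value()) // [2 1 3 1 2]
}
```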
+func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf.Output, record_bytes tf.Output, footer_bytes tf.Output, buffer_size tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FixedLengthRecordDataset", + Input: []tf.Input{ + filenames, header_bytes, record_bytes, footer_bytes, buffer_size, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FakeQuantWithMinMaxVarsPerChannelAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannel. +type FakeQuantWithMinMaxVarsPerChannelAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsPerChannelNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsPerChannelNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxVarsPerChannelNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsPerChannelNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`, +// +// `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]` +// to 'outputs' tensor of same shape as `inputs`. +// +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. +// +// Before quantization, `min` and `max` values are adjusted with the following +// logic. +// It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, +// the behavior can be unexpected: +// If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. +// If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. +// If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, +// `min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. +// +// This operation has a gradient and thus allows for training `min` and `max` +// values. +func FakeQuantWithMinMaxVarsPerChannel(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelAttr) (outputs tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxVarsPerChannel", + Input: []tf.Input{ + inputs, min, max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ShuffleDatasetAttr is an optional argument to ShuffleDataset. +type ShuffleDatasetAttr func(optionalAttr) + +// ShuffleDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value. +// +// value: If true, each iterator over this dataset will be given +// a different pseudorandomly generated seed, based on a sequence seeded by the +// `seed` and `seed2` inputs. If false, each iterator will be given the same +// seed, and repeated iteration over this dataset will yield the exact same +// sequence of results. 
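`FakeQuantWithMinMaxVarsPerChannel` above clamps and snaps each channel independently using `min`/`max` vectors of shape `[d]`. A sketch over a `[b, d]` input; the concrete values are only illustrative:

```go
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    // Two examples, three channels.
    inputs := op.Const(s, [][]float32{{-1.2, 0.3, 6.7}, {0.4, 2.5, -0.1}})
    min := op.Const(s, []float32{-1, 0, 0})
    max := op.Const(s, []float32{1, 4, 8})
    outputs := op.FakeQuantWithMinMaxVarsPerChannel(s, inputs, min, max,
        op.FakeQuantWithMinMaxVarsPerChannelNumBits(8))

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{outputs}, nil)
    if err != nil {
        panic(err)
    }
    // Each channel is clamped to its [min; max] range and snapped to the 8-bit grid.
    fmt.Println(out[0].Value())
}
```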
+// If not specified, defaults to true +func ShuffleDatasetReshuffleEachIteration(value bool) ShuffleDatasetAttr { + return func(m optionalAttr) { + m["reshuffle_each_iteration"] = value + } +} + +// Creates a dataset that shuffles elements from `input_dataset` pseudorandomly. +// +// Arguments: +// +// buffer_size: The number of output elements to buffer in an iterator over +// this dataset. Compare with the `min_after_dequeue` attr when creating a +// `RandomShuffleQueue`. +// seed: A scalar seed for the random number generator. If either `seed` or +// `seed2` is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. +// +// +func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleDatasetAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ShuffleDataset", + Input: []tf.Input{ + input_dataset, buffer_size, seed, seed2, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// PaddedBatchDatasetV2Attr is an optional argument to PaddedBatchDatasetV2. +type PaddedBatchDatasetV2Attr func(optionalAttr) + +// PaddedBatchDatasetV2ParallelCopy sets the optional parallel_copy attribute to value. +// If not specified, defaults to false +func PaddedBatchDatasetV2ParallelCopy(value bool) PaddedBatchDatasetV2Attr { + return func(m optionalAttr) { + m["parallel_copy"] = value + } +} + +// Creates a dataset that batches and pads `batch_size` elements from the input. +// +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// padded_shapes: A list of int64 tensors representing the desired padded shapes +// of the corresponding output components. These shapes may be partially +// specified, using `-1` to indicate that a particular dimension should be +// padded to the maximum size of all batch elements. +// padding_values: A list of scalars containing the padding value to use for +// each of the outputs. +// drop_remainder: A scalar representing whether the last batch should be dropped in case its size +// is smaller than desired. +// +func PaddedBatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, drop_remainder tf.Output, output_shapes []tf.Shape, optional ...PaddedBatchDatasetV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "PaddedBatchDatasetV2", + Input: []tf.Input{ + input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), drop_remainder, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ShardDatasetAttr is an optional argument to ShardDataset. +type ShardDatasetAttr func(optionalAttr) + +// ShardDatasetRequireNonEmpty sets the optional require_non_empty attribute to value. 
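`ShuffleDataset` above chains onto any dataset variant. A construction-only sketch that shuffles a TFRecord stream with a fixed seed and `reshuffle_each_iteration` disabled; the file name is hypothetical:

```go
package main

import (
    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    types := []tf.DataType{tf.String}
    shapes := []tf.Shape{tf.ScalarShape()}

    records := op.TFRecordDataset(s,
        op.Const(s, []string{"data.tfrecord"}), // hypothetical file
        op.Const(s, ""),                        // no compression
        op.Const(s, int64(0)))                  // default read buffering
    shuffled := op.ShuffleDataset(s, records,
        op.Const(s, int64(1024)), // buffer_size
        op.Const(s, int64(42)),   // seed
        op.Const(s, int64(0)),    // seed2
        types, shapes,
        op.ShuffleDatasetReshuffleEachIteration(false))
    _ = shuffled // feed into an Iterator / MakeIterator pair to consume it

    if _, err := s.Finalize(); err != nil {
        panic(err)
    }
}
```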
+// If not specified, defaults to false +func ShardDatasetRequireNonEmpty(value bool) ShardDatasetAttr { + return func(m optionalAttr) { + m["require_non_empty"] = value + } +} + +// Creates a `Dataset` that includes only 1/`num_shards` of this dataset. +// +// Arguments: +// +// num_shards: An integer representing the number of shards operating in parallel. +// index: An integer representing the current worker index. +// +// +func ShardDataset(scope *Scope, input_dataset tf.Output, num_shards tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShardDatasetAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ShardDataset", + Input: []tf.Input{ + input_dataset, num_shards, index, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// BatchDatasetV2Attr is an optional argument to BatchDatasetV2. +type BatchDatasetV2Attr func(optionalAttr) + +// BatchDatasetV2ParallelCopy sets the optional parallel_copy attribute to value. +// If not specified, defaults to false +func BatchDatasetV2ParallelCopy(value bool) BatchDatasetV2Attr { + return func(m optionalAttr) { + m["parallel_copy"] = value + } +} + +// Creates a dataset that batches `batch_size` elements from `input_dataset`. +// +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a batch. +// drop_remainder: A scalar representing whether the last batch should be dropped in case its size +// is smaller than desired. +// +// +func BatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...BatchDatasetV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "BatchDatasetV2", + Input: []tf.Input{ + input_dataset, batch_size, drop_remainder, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that batches `batch_size` elements from `input_dataset`. +// +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// +// +func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "BatchDataset", + Input: []tf.Input{ + input_dataset, batch_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// A dataset that creates window datasets from the input dataset. +// +// Arguments: +// +// size: A scalar representing the number of elements to accumulate in a window. +// shift: A scalar representing the steps moving the sliding window forward in one +// iteration. It must be positive. +// stride: A scalar representing the stride of the input elements of the sliding window. +// It must be positive. +// drop_remainder: A scalar representing whether a window should be dropped in case its size is +// smaller than desired. 
+// +// +func WindowDataset(scope *Scope, input_dataset tf.Output, size tf.Output, shift tf.Output, stride tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "WindowDataset", + Input: []tf.Input{ + input_dataset, size, shift, stride, drop_remainder, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset containing elements of first component of `input_dataset` having true in the last component. +func FilterByLastComponentDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "FilterByLastComponentDataset", + Input: []tf.Input{ + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. +type QuantizedRelu6Attr func(optionalAttr) + +// QuantizedRelu6OutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` +// +// Arguments: +// +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. +// +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedRelu6", + Input: []tf.Input{ + features, min_features, max_features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Bucketize each feature based on bucket boundaries. +// +// An op that returns a list of float tensors, where each tensor represents the +// bucketized values for a single feature. +// +// Arguments: +// float_values: float; List of Rank 1 Tensor each containing float values for a single feature. +// bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a single +// feature. +// +// Returns int; List of Rank 1 Tensors each containing the bucketized values for a single feature. 
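+//
+// A hedged usage sketch (illustrative; `s` is an *op.Scope and the literal
+// values are invented for the example): one feature column is bucketized
+// against three boundaries.
+//
+//	values := []tf.Output{op.Const(s, []float32{-1.5, 0.2, 3.7, 10.0})}
+//	bounds := []tf.Output{op.Const(s, []float32{0.0, 1.0, 5.0})}
+//	buckets := op.BoostedTreesBucketize(s, values, bounds)
+//	// buckets[0] holds one bucket index per input value.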
+func BoostedTreesBucketize(scope *Scope, float_values []tf.Output, bucket_boundaries []tf.Output) (buckets []tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesBucketize", + Input: []tf.Input{ + tf.OutputList(float_values), tf.OutputList(bucket_boundaries), + }, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if buckets, idx, err = makeOutputList(op, idx, "buckets"); err != nil { + scope.UpdateErr("BoostedTreesBucketize", err) + return + } + return buckets +} + +// Set a summary_writer_interface to record statistics using given stats_aggregator. +// +// Returns the created operation. +func StatsAggregatorSetSummaryWriter(scope *Scope, stats_aggregator tf.Output, summary tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StatsAggregatorSetSummaryWriter", + Input: []tf.Input{ + stats_aggregator, summary, + }, + } + return scope.AddOperation(opspec) +} + +// QuantizedReluAttr is an optional argument to QuantizedRelu. +type QuantizedReluAttr func(optionalAttr) + +// QuantizedReluOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedReluOutType(value tf.DataType) QuantizedReluAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Computes Quantized Rectified Linear: `max(features, 0)` +// +// Arguments: +// +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. +// +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedRelu", + Input: []tf.Input{ + features, min_features, max_features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// ListDiffAttr is an optional argument to ListDiff. +type ListDiffAttr func(optionalAttr) + +// ListDiffOutIdx sets the optional out_idx attribute to value. +// If not specified, defaults to DT_INT32 +func ListDiffOutIdx(value tf.DataType) ListDiffAttr { + return func(m optionalAttr) { + m["out_idx"] = value + } +} + +// Computes the difference between two lists of numbers or strings. +// +// Given a list `x` and a list `y`, this operation returns a list `out` that +// represents all values that are in `x` but not in `y`. The returned list `out` +// is sorted in the same order that the numbers appear in `x` (duplicates are +// preserved). This operation also returns a list `idx` that represents the +// position of each `out` element in `x`. In other words: +// +// `out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` +// +// For example, given this input: +// +// ``` +// x = [1, 2, 3, 4, 5, 6] +// y = [1, 3, 5] +// ``` +// +// This operation would return: +// +// ``` +// out ==> [2, 4, 6] +// idx ==> [1, 3, 5] +// ``` +// +// Arguments: +// x: 1-D. Values to keep. +// y: 1-D. Values to remove. +// +// Returns 1-D. Values present in `x` but not in `y`.1-D. 
Positions of `x` values preserved in `out`. +func ListDiff(scope *Scope, x tf.Output, y tf.Output, optional ...ListDiffAttr) (out tf.Output, idx tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ListDiff", + Input: []tf.Input{ + x, y, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// EditDistanceAttr is an optional argument to EditDistance. +type EditDistanceAttr func(optionalAttr) + +// EditDistanceNormalize sets the optional normalize attribute to value. +// +// value: boolean (if true, edit distances are normalized by length of truth). +// +// The output is: +// If not specified, defaults to true +func EditDistanceNormalize(value bool) EditDistanceAttr { + return func(m optionalAttr) { + m["normalize"] = value + } +} + +// Computes the (possibly normalized) Levenshtein Edit Distance. +// +// The inputs are variable-length sequences provided by SparseTensors +// (hypothesis_indices, hypothesis_values, hypothesis_shape) +// and +// (truth_indices, truth_values, truth_shape). +// +// The inputs are: +// +// Arguments: +// hypothesis_indices: The indices of the hypothesis list SparseTensor. +// This is an N x R int64 matrix. +// hypothesis_values: The values of the hypothesis list SparseTensor. +// This is an N-length vector. +// hypothesis_shape: The shape of the hypothesis list SparseTensor. +// This is an R-length vector. +// truth_indices: The indices of the truth list SparseTensor. +// This is an M x R int64 matrix. +// truth_values: The values of the truth list SparseTensor. +// This is an M-length vector. +// truth_shape: truth indices, vector. +// +// Returns A dense float tensor with rank R - 1. +// +// For the example input: +// +// // hypothesis represents a 2x1 matrix with variable-length values: +// // (0,0) = ["a"] +// // (1,0) = ["b"] +// hypothesis_indices = [[0, 0, 0], +// [1, 0, 0]] +// hypothesis_values = ["a", "b"] +// hypothesis_shape = [2, 1, 1] +// +// // truth represents a 2x2 matrix with variable-length values: +// // (0,0) = [] +// // (0,1) = ["a"] +// // (1,0) = ["b", "c"] +// // (1,1) = ["a"] +// truth_indices = [[0, 1, 0], +// [1, 0, 0], +// [1, 0, 1], +// [1, 1, 0]] +// truth_values = ["a", "b", "c", "a"] +// truth_shape = [2, 2, 2] +// normalize = true +// +// The output will be: +// +// // output is a 2x2 matrix with edit distances normalized by truth lengths. +// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis +// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis +func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EditDistance", + Input: []tf.Input{ + hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Adds Tensor 'bias' to Tensor 'input' for Quantized types. // // Broadcasts the values of bias on dimensions 0..N-2 of 'input'. 
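+// A hedged, self-contained sketch of driving the ListDiff wrapper added above
+// from user code (illustrative only; it assumes the TensorFlow Go bindings are
+// importable as shown and simply reproduces the x/y example from the ListDiff
+// doc comment):
+//
+//	package main
+//
+//	import (
+//		"fmt"
+//
+//		tf "github.com/tensorflow/tensorflow/tensorflow/go"
+//		"github.com/tensorflow/tensorflow/tensorflow/go/op"
+//	)
+//
+//	func main() {
+//		s := op.NewScope()
+//		x := op.Const(s, []int32{1, 2, 3, 4, 5, 6})
+//		y := op.Const(s, []int32{1, 3, 5})
+//		out, idx := op.ListDiff(s, x, y)
+//
+//		graph, err := s.Finalize()
+//		if err != nil {
+//			panic(err)
+//		}
+//		sess, err := tf.NewSession(graph, nil)
+//		if err != nil {
+//			panic(err)
+//		}
+//		defer sess.Close()
+//
+//		res, err := sess.Run(nil, []tf.Output{out, idx}, nil)
+//		if err != nil {
+//			panic(err)
+//		}
+//		fmt.Println(res[0].Value(), res[1].Value()) // [2 4 6] and [1 3 5]
+//	}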
@@ -7771,89 +8524,27 @@ func QuantizedBiasAdd(scope *Scope, input tf.Output, bias tf.Output, min_input t return op.Output(0), op.Output(1), op.Output(2) } -// Produces the average pool of the input tensor for quantized types. +// Computes the reciprocal of x element-wise. // -// Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// ksize: The size of the window for each dimension of the input tensor. -// The length must be 4 to match the number of dimensions of the input. -// strides: The stride of the sliding window for each dimension of the input -// tensor. The length must be 4 to match the number of dimensions of the input. -// padding: The type of padding algorithm to use. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { +// I.e., \\(y = 1 / x\\). +func Inv(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "QuantizedAvgPool", + Type: "Inv", Input: []tf.Input{ - input, min_input, max_input, + x, }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Extract `patches` from `input` and put them in the "depth" output dimension. 3D extension of `extract_image_patches`. -// -// Arguments: -// input: 5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`. -// ksizes: The size of the sliding window for each dimension of `input`. -// strides: 1-D of length 5. How far the centers of two consecutive patches are in -// `input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`. -// padding: The type of padding algorithm to use. -// -// We specify the size-related attributes as: -// -// ```python -// ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1] -// strides = [1, stride_planes, strides_rows, strides_cols, 1] -// ``` -// -// Returns 5-D Tensor with shape `[batch, out_planes, out_rows, out_cols, -// ksize_planes * ksize_rows * ksize_cols * depth]` containing patches -// with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized -// in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols` -// are the dimensions of the output patches. -func ExtractVolumePatches(scope *Scope, input tf.Output, ksizes []int64, strides []int64, padding string) (patches tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "ExtractVolumePatches", - Input: []tf.Input{ - input, - }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// FractionalAvgPoolAttr is an optional argument to FractionalAvgPool. -type FractionalAvgPoolAttr func(optionalAttr) +// FractionalMaxPoolGradAttr is an optional argument to FractionalMaxPoolGrad. +type FractionalMaxPoolGradAttr func(optionalAttr) -// FractionalAvgPoolPseudoRandom sets the optional pseudo_random attribute to value. 
-// -// value: When set to True, generates the pooling sequence in a -// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin -// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for -// difference between pseudorandom and random. -// If not specified, defaults to false -func FractionalAvgPoolPseudoRandom(value bool) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["pseudo_random"] = value - } -} - -// FractionalAvgPoolOverlapping sets the optional overlapping attribute to value. +// FractionalMaxPoolGradOverlapping sets the optional overlapping attribute to value. // // value: When set to True, it means when pooling, the values at the boundary // of adjacent pooling cells are used by both cells. For example: @@ -7863,126 +8554,28 @@ func FractionalAvgPoolPseudoRandom(value bool) FractionalAvgPoolAttr { // `value 20 5 16 3 7` // // If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [41/3, 26/3] for fractional avg pooling. +// The result would be [20, 16] for fractional max pooling. // If not specified, defaults to false -func FractionalAvgPoolOverlapping(value bool) FractionalAvgPoolAttr { +func FractionalMaxPoolGradOverlapping(value bool) FractionalMaxPoolGradAttr { return func(m optionalAttr) { m["overlapping"] = value } } -// FractionalAvgPoolDeterministic sets the optional deterministic attribute to value. -// -// value: When set to True, a fixed pooling region will be used when -// iterating over a FractionalAvgPool node in the computation graph. Mainly used -// in unit test to make FractionalAvgPool deterministic. -// If not specified, defaults to false -func FractionalAvgPoolDeterministic(value bool) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["deterministic"] = value - } -} - -// FractionalAvgPoolSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func FractionalAvgPoolSeed(value int64) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// FractionalAvgPoolSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FractionalAvgPoolSeed2(value int64) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Performs fractional average pooling on the input. -// -// Fractional average pooling is similar to Fractional max pooling in the pooling -// region generation step. The only difference is that after pooling regions are -// generated, a mean operation is performed instead of a max operation in each -// pooling region. +// Computes gradient of the FractionalMaxPool function. // // Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// pooling_ratio: Pooling ratio for each dimension of `value`, currently only -// supports row and col dimension and should be >= 1.0. For example, a valid -// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements -// must be 1.0 because we don't allow pooling on batch and channels -// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions -// respectively. 
+// orig_input: Original input for `fractional_max_pool` +// orig_output: Original output for `fractional_max_pool` +// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients +// w.r.t. the output of `fractional_max_pool`. +// row_pooling_sequence: row pooling sequence, form pooling region with +// col_pooling_sequence. +// col_pooling_sequence: column pooling sequence, form pooling region with +// row_pooling sequence. // -// Returns output tensor after fractional avg pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. -func FractionalAvgPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalAvgPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FractionalAvgPool", - Input: []tf.Input{ - value, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// RandomCropAttr is an optional argument to RandomCrop. -type RandomCropAttr func(optionalAttr) - -// RandomCropSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomCropSeed(value int64) RandomCropAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomCropSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomCropSeed2(value int64) RandomCropAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Randomly crop `image`. -// -// DEPRECATED at GraphDef version 8: Random crop is now pure Python -// -// `size` is a 1-D int64 tensor with 2 elements representing the crop height and -// width. The values must be non negative. -// -// This Op picks a random location in `image` and crops a `height` by `width` -// rectangle from that location. The random location is picked so the cropped -// area will fit inside the original image. -// -// Arguments: -// image: 3-D of shape `[height, width, channels]`. -// size: 1-D of length 2 containing: `crop_height`, `crop_width`.. -// -// Returns 3-D of shape `[crop_height, crop_width, channels].` -func RandomCrop(scope *Scope, image tf.Output, size tf.Output, optional ...RandomCropAttr) (output tf.Output) { +// Returns 4-D. Gradients w.r.t. the input of `fractional_max_pool`. +func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalMaxPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -7991,9 +8584,9 @@ func RandomCrop(scope *Scope, image tf.Output, size tf.Output, optional ...Rando a(attrs) } opspec := tf.OpSpec{ - Type: "RandomCrop", + Type: "FractionalMaxPoolGrad", Input: []tf.Input{ - image, size, + orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence, }, Attrs: attrs, } @@ -8001,368 +8594,190 @@ func RandomCrop(scope *Scope, image tf.Output, size tf.Output, optional ...Rando return op.Output(0) } -// TopKV2Attr is an optional argument to TopKV2. 
-type TopKV2Attr func(optionalAttr) +// MaxPoolGradWithArgmaxAttr is an optional argument to MaxPoolGradWithArgmax. +type MaxPoolGradWithArgmaxAttr func(optionalAttr) -// TopKV2Sorted sets the optional sorted attribute to value. +// MaxPoolGradWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. // -// value: If true the resulting `k` elements will be sorted by the values in -// descending order. -// If not specified, defaults to true -func TopKV2Sorted(value bool) TopKV2Attr { +// value: Whether to include batch dimension in flattened index of `argmax`. +// If not specified, defaults to false +func MaxPoolGradWithArgmaxIncludeBatchInIndex(value bool) MaxPoolGradWithArgmaxAttr { return func(m optionalAttr) { - m["sorted"] = value + m["include_batch_in_index"] = value } } -// Finds values and indices of the `k` largest elements for the last dimension. -// -// If the input is a vector (rank-1), finds the `k` largest entries in the vector -// and outputs their values and indices as vectors. Thus `values[j]` is the -// `j`-th largest entry in `input`, and its index is `indices[j]`. -// -// For matrices (resp. higher rank input), computes the top `k` entries in each -// row (resp. vector along the last dimension). Thus, -// -// values.shape = indices.shape = input.shape[:-1] + [k] -// -// If two elements are equal, the lower-index element appears first. +// Computes gradients of the maxpooling function. // // Arguments: -// input: 1-D or higher with last dimension at least `k`. -// k: 0-D. Number of top elements to look for along the last dimension (along each -// row for matrices). +// input: The original input. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the +// output of `max_pool`. +// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. -func TopKV2(scope *Scope, input tf.Output, k tf.Output, optional ...TopKV2Attr) (values tf.Output, indices tf.Output) { +// Returns Gradients w.r.t. the input of `max_pool`. +func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradWithArgmaxAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TopKV2", + Type: "MaxPoolGradWithArgmax", Input: []tf.Input{ - input, k, + input, grad, argmax, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Provides the time since epoch in seconds. +// +// Returns the timestamp as a `float64` for seconds since the Unix epoch. +// +// Note: the timestamp is computed when the op is executed, not when it is added +// to the graph. +func Timestamp(scope *Scope) (ts tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Timestamp", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Says whether the targets are in the top `K` predictions. 
+// +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. +// +// More formally, let +// +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// +// Arguments: +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. +// +// Returns Computed Precision at `k` as a `bool Tensor`. +func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (precision tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"k": k} + opspec := tf.OpSpec{ + Type: "InTopK", + Input: []tf.Input{ + predictions, targets, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes softmax cross entropy cost and gradients to backpropagate. +// +// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept +// a matrix of label probabilities, but rather a single label per row +// of features. This label is considered to have probability 1.0 for the +// given row. +// +// Inputs are the logits, not probabilities. +// +// Arguments: +// features: batch_size x num_classes matrix +// labels: batch_size vector with values in [0, num_classes). +// This is the label for the given minibatch entry. +// +// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). +func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSoftmaxCrossEntropyWithLogits", + Input: []tf.Input{ + features, labels, + }, + } + op := scope.AddOperation(opspec) return op.Output(0), op.Output(1) } -// Returns x // y element-wise. -// -// *NOTE*: `FloorDiv` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FloorDiv", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// UnstageAttr is an optional argument to Unstage. +type UnstageAttr func(optionalAttr) -// Computes the inverse permutation of a tensor. -// -// This operation computes the inverse of an index permutation. It takes a 1-D -// integer tensor `x`, which represents the indices of a zero-based array, and -// swaps each value with its index position. In other words, for an output tensor -// `y` and an input tensor `x`, this operation computes the following: -// -// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]` -// -// The values must include 0. There can be no duplicate values or negative values. -// -// For example: -// -// ``` -// # tensor `x` is [3, 4, 0, 2, 1] -// invert_permutation(x) ==> [2, 4, 3, 0, 1] -// ``` -// -// Arguments: -// x: 1-D. 
-// -// Returns 1-D. -func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InvertPermutation", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes log softmax activations. -// -// For each batch `i` and class `j` we have -// -// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) -// -// Arguments: -// logits: 2-D with shape `[batch_size, num_classes]`. -// -// Returns Same shape as `logits`. -func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogSoftmax", - Input: []tf.Input{ - logits, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes softmax activations. -// -// For each batch `i` and class `j` we have -// -// $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$ -// -// Arguments: -// logits: 2-D with shape `[batch_size, num_classes]`. -// -// Returns Same shape as `logits`. -func Softmax(scope *Scope, logits tf.Output) (softmax tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Softmax", - Input: []tf.Input{ - logits, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DecodeBmpAttr is an optional argument to DecodeBmp. -type DecodeBmpAttr func(optionalAttr) - -// DecodeBmpChannels sets the optional channels attribute to value. +// UnstageCapacity sets the optional capacity attribute to value. // If not specified, defaults to 0 -func DecodeBmpChannels(value int64) DecodeBmpAttr { +// +// REQUIRES: value >= 0 +func UnstageCapacity(value int64) UnstageAttr { return func(m optionalAttr) { - m["channels"] = value + m["capacity"] = value } } -// Decode the first frame of a BMP-encoded image to a uint8 tensor. +// UnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// The attr `channels` indicates the desired number of color channels for the -// decoded image. +// REQUIRES: value >= 0 +func UnstageMemoryLimit(value int64) UnstageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// UnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnstageContainer(value string) UnstageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// UnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnstageSharedName(value string) UnstageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op is similar to a lightweight Dequeue. // -// Accepted values are: -// -// * 0: Use the number of channels in the BMP-encoded image. -// * 3: output an RGB image. -// * 4: output an RGBA image. -// -// Arguments: -// contents: 0-D. The BMP-encoded image. -// -// Returns 3-D with shape `[height, width, channels]`. RGB order -func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (image tf.Output) { +// The basic functionality is similar to dequeue with many fewer +// capabilities and options. This Op is optimized for performance. 
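+//
+// A hedged usage sketch (illustrative; `s` is an *op.Scope and the shared
+// name is invented): the dtypes list fixes the number and types of the
+// dequeued values, and a matching Stage op elsewhere in the graph is expected
+// to use the same shared_name.
+//
+//	vals := op.Unstage(s, []tf.DataType{tf.Float, tf.Int32},
+//		op.UnstageSharedName("staging_area_0"))
+//	// vals[0] is a float tensor, vals[1] an int32 tensor.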
+func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeBmp", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} + Type: "Unstage", -// Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. -// -// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) -// ](http://arxiv.org/abs/1511.07289) -func Elu(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Elu", - Input: []tf.Input{ - features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes square of x element-wise. -// -// I.e., \\(y = x * x = x^2\\). -func Square(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Square", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LeakyReluGradAttr is an optional argument to LeakyReluGrad. -type LeakyReluGradAttr func(optionalAttr) - -// LeakyReluGradAlpha sets the optional alpha attribute to value. -// If not specified, defaults to 0.2 -func LeakyReluGradAlpha(value float32) LeakyReluGradAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// Computes rectified linear gradients for a LeakyRelu operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding LeakyRelu operation. -// features: The features passed as input to the corresponding LeakyRelu operation, -// OR the outputs of that operation (both work equivalently). -// -// Returns `gradients * (features > 0) + alpha * gradients * (featurs <= 0)`. -func LeakyReluGrad(scope *Scope, gradients tf.Output, features tf.Output, optional ...LeakyReluGradAttr) (backprops tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LeakyReluGrad", - Input: []tf.Input{ - gradients, features, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes rectified linear 6: `min(max(features, 0), 6)`. -func Relu6(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Relu6", - Input: []tf.Input{ - features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SdcaOptimizerV2Attr is an optional argument to SdcaOptimizerV2. -type SdcaOptimizerV2Attr func(optionalAttr) - -// SdcaOptimizerV2Adaptive sets the optional adaptive attribute to value. -// -// value: Whether to use Adaptive SDCA for the inner loop. -// If not specified, defaults to true -func SdcaOptimizerV2Adaptive(value bool) SdcaOptimizerV2Attr { - return func(m optionalAttr) { - m["adaptive"] = value - } -} - -// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for -// -// linear models with L1 + L2 regularization. As global optimization objective is -// strongly-convex, the optimizer optimizes the dual objective at each step. The -// optimizer applies each update one example at a time. Examples are sampled -// uniformly, and the optimizer is learning rate free and enjoys linear convergence -// rate. 
-//
-// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
-// Shai Shalev-Shwartz, Tong Zhang. 2012
-//
-// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$
-//
-// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
-// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan,
-// Peter Richtarik, Martin Takac. 2015
-//
-// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
-// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 -// -// Arguments: -// sparse_example_indices: a list of vectors which contain example indices. -// sparse_feature_indices: a list of vectors which contain feature indices. -// sparse_feature_values: a list of vectors which contains feature value -// associated with each feature group. -// dense_features: a list of matrices which contains the dense feature values. -// example_weights: a vector which contains the weight associated with each -// example. -// example_labels: a vector which contains the label/target associated with each -// example. -// sparse_indices: a list of vectors where each value is the indices which has -// corresponding weights in sparse_weights. This field maybe omitted for the -// dense approach. -// sparse_weights: a list of vectors where each value is the weight associated with -// a sparse feature group. -// dense_weights: a list of vectors where the values are the weights associated -// with a dense feature group. -// example_state_data: a list of vectors containing the example state data. -// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, -// squared and hinge losses. -// l1: Symmetric l1 regularization strength. -// l2: Symmetric l2 regularization strength. -// num_loss_partitions: Number of partitions of the global loss function. -// num_inner_iterations: Number of iterations per mini-batch. -// -// Returns a list of vectors containing the updated example state -// data.a list of vectors where each value is the delta -// weights associated with a sparse feature group.a list of vectors where the values are the delta -// weights associated with a dense feature group. -func SdcaOptimizerV2(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerV2Attr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SdcaOptimizerV2", - Input: []tf.Input{ - tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, - }, Attrs: attrs, } op := scope.AddOperation(opspec) @@ -8371,117 +8786,232 @@ func SdcaOptimizerV2(scope *Scope, sparse_example_indices []tf.Output, sparse_fe } var idx int var err error - out_example_state_data = op.Output(idx) - if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { - scope.UpdateErr("SdcaOptimizerV2", err) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("Unstage", err) return } - if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { - scope.UpdateErr("SdcaOptimizerV2", err) - return - 
} - return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights + return values } -// Computes the minimum along segments of a tensor. +// Converts a `RaggedTensor` into a `SparseTensor` with the same values. // -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// This operator is similar to the unsorted segment sum operator found -// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). -// Instead of computing the sum over segments, it computes the minimum such that: -// -// \\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such -// that `segment_ids[j...] == i`. -// -// If the minimum is empty for a given segment ID `i`, it outputs the largest -// possible value for the specific numeric type, -// `output[i] = numeric_limits::max()`. -// -// For example: -// -// ``` python -// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) -// tf.unsorted_segment_min(c, tf.constant([0, 1, 0]), num_segments=2) -// # ==> [[ 1, 2, 2, 1], -// # [5, 6, 7, 8]] -// ``` -// -// If the given segment ID `i` is negative, then the corresponding value is -// dropped, and will not be included in the result. +// input=ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) +// output=SparseTensor(indices=sparse_indices, values=sparse_values, +// dense_shape=sparse_dense_shape) // // Arguments: +// rt_nested_splits: The `row_splits` for the `RaggedTensor`. +// rt_dense_values: The `flat_values` for the `RaggedTensor`. // -// segment_ids: A tensor whose shape is a prefix of `data.shape`. -// -// -// Returns Has same shape as data, except for the first `segment_ids.rank` -// dimensions, which are replaced with a single dimension which has size -// `num_segments`. -func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// Returns The indices for the `SparseTensor`.The values of the `SparseTensor`.`sparse_dense_shape` is a tight bounding box of the input `RaggedTensor`. +func RaggedTensorToSparse(scope *Scope, rt_nested_splits []tf.Output, rt_dense_values tf.Output) (sparse_indices tf.Output, sparse_values tf.Output, sparse_dense_shape tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "UnsortedSegmentMin", + Type: "RaggedTensorToSparse", Input: []tf.Input{ - data, segment_ids, num_segments, + tf.OutputList(rt_nested_splits), rt_dense_values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes gradients for the scaled exponential linear (Selu) operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Selu operation. +// outputs: The outputs of the corresponding Selu operation. +// +// Returns The gradients: `gradients * (outputs + scale * alpha)` +// if outputs < 0, `scale * gradients` otherwise. +func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SeluGrad", + Input: []tf.Input{ + gradients, outputs, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes rectified linear gradients for a Relu operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Relu operation. -// features: The features passed as input to the corresponding Relu operation, OR -// the outputs of that operation (both work equivalently). 
-// -// Returns `gradients * (features > 0)`. -func ReluGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReluGrad", - Input: []tf.Input{ - gradients, features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// RetrieveTPUEmbeddingMomentumParametersAttr is an optional argument to RetrieveTPUEmbeddingMomentumParameters. +type RetrieveTPUEmbeddingMomentumParametersAttr func(optionalAttr) -// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. -type TensorArrayGatherV2Attr func(optionalAttr) - -// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { +// RetrieveTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingMomentumParametersTableId(value int64) RetrieveTPUEmbeddingMomentumParametersAttr { return func(m optionalAttr) { - m["element_shape"] = value + m["table_id"] = value } } -// Deprecated. Use TensorArrayGatherV3 +// RetrieveTPUEmbeddingMomentumParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingMomentumParametersTableName(value string) RetrieveTPUEmbeddingMomentumParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve Momentum embedding parameters. // -// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 -func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the Momentum optimization algorithm.Parameter momenta updated by the Momentum optimization algorithm. +func RetrieveTPUEmbeddingMomentumParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersAttr) (parameters tf.Output, momenta tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayGatherV2", + Type: "RetrieveTPUEmbeddingMomentumParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` +// +// if < 0, `scale * features` otherwise. +// +// To be used together with +// `initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`. +// For correct dropout, use `tf.contrib.nn.alpha_dropout`. 
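+//
+// A hedged sketch (illustrative; `s` is an *op.Scope):
+//
+//	features := op.Const(s, []float32{-1.0, 0.0, 1.0})
+//	activations := op.Selu(s, features)
+//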
+// +// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) +func Selu(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Selu", Input: []tf.Input{ - handle, indices, flow_in, + features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. +type ResourceApplyAdamAttr func(optionalAttr) + +// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, uses the nesterov update. +// If not specified, defaults to false +func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the Adam algorithm. +// +// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ +// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ +// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ +// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// beta2_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdam", + Input: []tf.Input{ + var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// PreventGradientAttr is an optional argument to PreventGradient. +type PreventGradientAttr func(optionalAttr) + +// PreventGradientMessage sets the optional message attribute to value. +// +// value: Will be printed in the error when anyone tries to differentiate +// this operation. +// If not specified, defaults to "" +func PreventGradientMessage(value string) PreventGradientAttr { + return func(m optionalAttr) { + m["message"] = value + } +} + +// An identity op that triggers an error if a gradient is requested. +// +// When executed in a graph, this op outputs its input tensor as-is. +// +// When building ops to compute gradients, the TensorFlow gradient system +// will return an error when trying to lookup the gradient of this op, +// because no gradient must ever be registered for this function. This +// op exists to prevent subtle bugs from silently returning unimplemented +// gradients in some corner cases. +// +// Arguments: +// input: any tensor. 
+// +// Returns the same input tensor. +func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "PreventGradient", + Input: []tf.Input{ + input, }, Attrs: attrs, } @@ -8489,41 +9019,247 @@ func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow return op.Output(0) } -// Returns the truth value of (x == y) element-wise. +// DenseToDenseSetOperationAttr is an optional argument to DenseToDenseSetOperation. +type DenseToDenseSetOperationAttr func(optionalAttr) + +// DenseToDenseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func DenseToDenseSetOperationValidateIndices(value bool) DenseToDenseSetOperationAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Applies set operation along last dimension of 2 `Tensor` inputs. // -// *NOTE*: `Equal` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. +// +// Arguments: +// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. +// set2: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set1`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. +// +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func DenseToDenseSetOperation(scope *Scope, set1 tf.Output, set2 tf.Output, set_operation string, optional ...DenseToDenseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"set_operation": set_operation} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DenseToDenseSetOperation", + Input: []tf.Input{ + set1, set2, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes gradients for the exponential linear (Elu) operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Elu operation. +// outputs: The outputs of the corresponding Elu operation. +// +// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, +// `gradients` otherwise. 
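+//
+// A hedged sketch (illustrative; `s` is an *op.Scope): the formula above is
+// applied elementwise, so with outputs [-0.5, 2.0] and incoming gradients
+// [1.0, 1.0] the backprops are [1.0*(-0.5+1), 1.0] = [0.5, 1.0].
+//
+//	grads := op.Const(s, []float32{1.0, 1.0})
+//	outs := op.Const(s, []float32{-0.5, 2.0})
+//	backprops := op.EluGrad(s, grads, outs)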
+func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Equal", + Type: "EluGrad", Input: []tf.Input{ - x, y, + gradients, outputs, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Compute the polygamma function \\(\psi^{(n)}(x)\\). +// Computes the gradient of morphological 2-D dilation with respect to the filter. // -// The polygamma function is defined as: +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +// strides: 1-D of length 4. The stride of the sliding window for each dimension of +// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: 1-D of length 4. The input stride for atrous morphological dilation. +// Must be: `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. // -// -// \\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\) -// -// where \\(\psi(x)\\) is the digamma function. -func Polygamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { +// Returns 3-D with shape `[filter_height, filter_width, depth]`. +func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (filter_backprop tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "Polygamma", + Type: "Dilation2DBackpropFilter", Input: []tf.Input{ - a, x, + input, filter, out_backprop, }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. +type CropAndResizeGradBoxesAttr func(optionalAttr) + +// CropAndResizeGradBoxesMethod sets the optional method attribute to value. +// +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { + return func(m optionalAttr) { + m["method"] = value + } +} + +// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. +// +// Arguments: +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +// Both `image_height` and `image_width` need to be positive. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. 
+// +// Returns A 2-D tensor of shape `[num_boxes, 4]`. +func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CropAndResizeGradBoxes", + Input: []tf.Input{ + grads, image, boxes, box_ind, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolGradGradWithArgmaxAttr is an optional argument to MaxPoolGradGradWithArgmax. +type MaxPoolGradGradWithArgmaxAttr func(optionalAttr) + +// MaxPoolGradGradWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. +// +// value: Whether to include batch dimension in flattened index of `argmax`. +// If not specified, defaults to false +func MaxPoolGradGradWithArgmaxIncludeBatchInIndex(value bool) MaxPoolGradGradWithArgmaxAttr { + return func(m optionalAttr) { + m["include_batch_in_index"] = value + } +} + +// Computes second-order gradients of the maxpooling function. +// +// Arguments: +// input: The original input. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the +// input of `max_pool`. +// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input of `max_pool`. +func MaxPoolGradGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradWithArgmaxAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolGradGradWithArgmax", + Input: []tf.Input{ + input, grad, argmax, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AvgPoolGradAttr is an optional argument to AvgPoolGrad. +type AvgPoolGradAttr func(optionalAttr) + +// AvgPoolGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of the average pooling function. +// +// Arguments: +// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. +// the output of `avg_pool`. +// ksize: The size of the sliding window for each dimension of the input. +// strides: The stride of the sliding window for each dimension of the input. +// padding: The type of padding algorithm to use. +// +// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. 
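+//
+// One possible call from client code, sketched under the assumption that the
+// op package's NewScope/Const/Placeholder helpers and the tensorflow/go `tf`
+// package are imported; the shape, window, stride, and padding values are
+// illustrative only:
+//
+// ```
+// s := op.NewScope()
+// origShape := op.Const(s.SubScope("orig_shape"), []int32{1, 4, 4, 1})
+// grad := op.Placeholder(s.SubScope("grad"), tf.Float)
+// output := op.AvgPoolGrad(s, origShape, grad,
+// 	[]int64{1, 2, 2, 1}, []int64{1, 2, 2, 1}, "VALID",
+// 	op.AvgPoolGradDataFormat("NHWC"))
+// ```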
+func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AvgPoolGrad", + Input: []tf.Input{ + orig_input_shape, grad, + }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -8577,44 +9313,47 @@ func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output return op.Output(0) } -// MaxPoolGradWithArgmaxAttr is an optional argument to MaxPoolGradWithArgmax. -type MaxPoolGradWithArgmaxAttr func(optionalAttr) +// MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. +type MaxPoolGradV2Attr func(optionalAttr) -// MaxPoolGradWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. +// MaxPoolGradV2DataFormat sets the optional data_format attribute to value. // -// value: Whether to include batch dimension in flattened index of `argmax`. -// If not specified, defaults to false -func MaxPoolGradWithArgmaxIncludeBatchInIndex(value bool) MaxPoolGradWithArgmaxAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradV2DataFormat(value string) MaxPoolGradV2Attr { return func(m optionalAttr) { - m["include_batch_in_index"] = value + m["data_format"] = value } } // Computes gradients of the maxpooling function. // // Arguments: -// input: The original input. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the -// output of `max_pool`. -// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients w.r.t. the output of `max_pool`. // ksize: The size of the window for each dimension of the input tensor. // strides: The stride of the sliding window for each dimension of the // input tensor. // padding: The type of padding algorithm to use. // -// Returns Gradients w.r.t. the input of `max_pool`. -func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradWithArgmaxAttr) (output tf.Output) { +// Returns Gradients w.r.t. the input to `max_pool`. +func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPoolGradWithArgmax", + Type: "MaxPoolGradV2", Input: []tf.Input{ - input, grad, argmax, + orig_input, orig_output, grad, ksize, strides, }, Attrs: attrs, } @@ -8622,35 +9361,60 @@ func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax return op.Output(0) } -// MutexV2Attr is an optional argument to MutexV2. 
-type MutexV2Attr func(optionalAttr) +// CTCLossAttr is an optional argument to CTCLoss. +type CTCLossAttr func(optionalAttr) -// MutexV2Container sets the optional container attribute to value. +// CTCLossPreprocessCollapseRepeated sets the optional preprocess_collapse_repeated attribute to value. // -// value: If non-empty, this variable is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutexV2Container(value string) MutexV2Attr { +// value: Scalar, if true then repeated labels are +// collapsed prior to the CTC calculation. +// If not specified, defaults to false +func CTCLossPreprocessCollapseRepeated(value bool) CTCLossAttr { return func(m optionalAttr) { - m["container"] = value + m["preprocess_collapse_repeated"] = value } } -// MutexV2SharedName sets the optional shared_name attribute to value. +// CTCLossCtcMergeRepeated sets the optional ctc_merge_repeated attribute to value. // -// value: If non-empty, this variable is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func MutexV2SharedName(value string) MutexV2Attr { +// value: Scalar. If set to false, *during* CTC calculation +// repeated non-blank labels will not be merged and are interpreted as +// individual labels. This is a simplified version of CTC. +// If not specified, defaults to true +func CTCLossCtcMergeRepeated(value bool) CTCLossAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["ctc_merge_repeated"] = value } } -// Creates a Mutex resource that can be locked by `MutexLock`. +// CTCLossIgnoreLongerOutputsThanInputs sets the optional ignore_longer_outputs_than_inputs attribute to value. // -// Returns The mutex resource. -func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) { +// value: Scalar. If set to true, during CTC +// calculation, items that have longer output sequences than input sequences +// are skipped: they don't contribute to the loss term and have zero-gradient. +// If not specified, defaults to false +func CTCLossIgnoreLongerOutputsThanInputs(value bool) CTCLossAttr { + return func(m optionalAttr) { + m["ignore_longer_outputs_than_inputs"] = value + } +} + +// Calculates the CTC Loss (log probability) for each batch entry. Also calculates +// +// the gradient. This class performs the softmax operation for you, so inputs +// should be e.g. linear projections of outputs by an LSTM. +// +// Arguments: +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// labels_indices: The indices of a `SparseTensor`. +// `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for +// `(batch b, time t)`. +// labels_values: The values (labels) associated with the given batch and time. +// sequence_length: A vector containing sequence lengths (batch). +// +// Returns A vector (batch) containing log-probabilities.The gradient of `loss`. 3-D, shape: +// `(max_time x batch_size x num_classes)`. 
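+//
+// A usage sketch, assuming the op package's NewScope/Placeholder helpers and
+// the tensorflow/go `tf` package; the dtypes follow the usual SparseTensor
+// label convention (int64 indices, int32 values) and the names are
+// illustrative only:
+//
+// ```
+// s := op.NewScope()
+// logits := op.Placeholder(s.SubScope("logits"), tf.Float) // (max_time, batch_size, num_classes)
+// labelsIndices := op.Placeholder(s.SubScope("labels_indices"), tf.Int64)
+// labelsValues := op.Placeholder(s.SubScope("labels_values"), tf.Int32)
+// seqLen := op.Placeholder(s.SubScope("sequence_length"), tf.Int32)
+// loss, gradient := op.CTCLoss(s, logits, labelsIndices, labelsValues, seqLen,
+// 	op.CTCLossCtcMergeRepeated(true))
+// ```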
+func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_values tf.Output, sequence_length tf.Output, optional ...CTCLossAttr) (loss tf.Output, gradient tf.Output) { if scope.Err() != nil { return } @@ -8659,58 +9423,101 @@ func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) { a(attrs) } opspec := tf.OpSpec{ - Type: "MutexV2", + Type: "CTCLoss", + Input: []tf.Input{ + inputs, labels_indices, labels_values, sequence_length, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} +// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. +type CropAndResizeGradImageAttr func(optionalAttr) + +// CropAndResizeGradImageMethod sets the optional method attribute to value. +// +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { + return func(m optionalAttr) { + m["method"] = value + } +} + +// Computes the gradient of the crop_and_resize op wrt the input image tensor. +// +// Arguments: +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` +// containing the original image size. Both `image_height` and `image_width` need +// to be positive. +// +// +// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CropAndResizeGradImage", + Input: []tf.Input{ + grads, boxes, box_ind, image_size, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Connects N inputs to an N-way replicated TPU computation. -func TPUReplicatedInput(scope *Scope, inputs []tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TPUReplicatedInput", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// MaxPoolGradAttr is an optional argument to MaxPoolGrad. +type MaxPoolGradAttr func(optionalAttr) -// AvgPool3DAttr is an optional argument to AvgPool3D. 
-type AvgPool3DAttr func(optionalAttr) - -// AvgPool3DDataFormat sets the optional data_format attribute to value. +// MaxPoolGradDataFormat sets the optional data_format attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DDataFormat(value string) AvgPool3DAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradDataFormat(value string) MaxPoolGradAttr { return func(m optionalAttr) { m["data_format"] = value } } -// Performs 3D average pooling on the input. +// Computes gradients of the maxpooling function. // // Arguments: -// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients w.r.t. the output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. // padding: The type of padding algorithm to use. // -// Returns The average pooled output tensor. -func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DAttr) (output tf.Output) { +// Returns Gradients w.r.t. the input to `max_pool`. +func MaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -8719,9 +9526,9 @@ func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, pa a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool3D", + Type: "MaxPoolGrad", Input: []tf.Input{ - input, + orig_input, orig_output, grad, }, Attrs: attrs, } @@ -8729,122 +9536,127 @@ func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, pa return op.Output(0) } -// DepthToSpaceAttr is an optional argument to DepthToSpace. -type DepthToSpaceAttr func(optionalAttr) +// InfeedEnqueuePrelinearizedBufferAttr is an optional argument to InfeedEnqueuePrelinearizedBuffer. +type InfeedEnqueuePrelinearizedBufferAttr func(optionalAttr) -// DepthToSpaceDataFormat sets the optional data_format attribute to value. +// InfeedEnqueuePrelinearizedBufferDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. This should be -1 when the Op is running on a TPU device +// and = 0 when the Op is running on the CPU device. 
+// If not specified, defaults to -1 +func InfeedEnqueuePrelinearizedBufferDeviceOrdinal(value int64) InfeedEnqueuePrelinearizedBufferAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// An op which enqueues prelinearized buffer into TPU infeed. +// +// Arguments: +// input: A variant tensor representing linearized output. +// +// Returns the created operation. +func InfeedEnqueuePrelinearizedBuffer(scope *Scope, input tf.Output, optional ...InfeedEnqueuePrelinearizedBufferAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "InfeedEnqueuePrelinearizedBuffer", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// MaxPoolV2Attr is an optional argument to MaxPoolV2. +type MaxPoolV2Attr func(optionalAttr) + +// MaxPoolV2DataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. // If not specified, defaults to "NHWC" -func DepthToSpaceDataFormat(value string) DepthToSpaceAttr { +func MaxPoolV2DataFormat(value string) MaxPoolV2Attr { return func(m optionalAttr) { m["data_format"] = value } } -// DepthToSpace for tensors of type T. -// -// Rearranges data from depth into blocks of spatial data. -// This is the reverse transformation of SpaceToDepth. More specifically, -// this op outputs a copy of the input tensor where values from the `depth` -// dimension are moved in spatial blocks to the `height` and `width` dimensions. -// The attr `block_size` indicates the input block size and how the data is moved. -// -// * Chunks of data of size `block_size * block_size` from depth are rearranged -// into non-overlapping blocks of size `block_size x block_size` -// * The width the output tensor is `input_depth * block_size`, whereas the -// height is `input_height * block_size`. -// * The Y, X coordinates within each block of the output image are determined -// by the high order component of the input channel index. -// * The depth of the input tensor must be divisible by -// `block_size * block_size`. -// -// The `data_format` attr specifies the layout of the input and output tensors -// with the following options: -// "NHWC": `[ batch, height, width, channels ]` -// "NCHW": `[ batch, channels, height, width ]` -// "NCHW_VECT_C": -// `qint8 [ batch, channels / 4, height, width, 4 ]` -// -// It is useful to consider the operation as transforming a 6-D Tensor. -// e.g. for data_format = NHWC, -// Each element in the input tensor can be specified via 6 coordinates, -// ordered by decreasing memory layout significance as: -// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates -// within the input image, bX, bY means coordinates -// within the output block, oC means output channels). -// The output would be the input transposed to the following layout: -// n,iY,bY,iX,bX,oC -// -// This operation is useful for resizing the activations between convolutions -// (but keeping all data), e.g. instead of pooling. It is also useful for training -// purely convolutional models. 
-// -// For example, given an input of shape `[1, 1, 1, 4]`, data_format = "NHWC" and -// block_size = 2: -// -// ``` -// x = [[[[1, 2, 3, 4]]]] -// -// ``` -// -// This operation will output a tensor of shape `[1, 2, 2, 1]`: -// -// ``` -// [[[[1], [2]], -// [[3], [4]]]] -// ``` -// -// Here, the input has a batch of 1 and each batch element has shape `[1, 1, 4]`, -// the corresponding output will have 2x2 elements and will have a depth of -// 1 channel (1 = `4 / (block_size * block_size)`). -// The output element shape is `[2, 2, 1]`. -// -// For an input tensor with larger depth, here of shape `[1, 1, 1, 12]`, e.g. -// -// ``` -// x = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] -// ``` -// -// This operation, for block size of 2, will return the following tensor of shape -// `[1, 2, 2, 3]` -// -// ``` -// [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// -// ``` -// -// Similarly, for the following input of shape `[1 2 2 4]`, and a block size of 2: -// -// ``` -// x = [[[[1, 2, 3, 4], -// [5, 6, 7, 8]], -// [[9, 10, 11, 12], -// [13, 14, 15, 16]]]] -// ``` -// -// the operator will return the following tensor of shape `[1 4 4 1]`: -// -// ``` -// x = [[[ [1], [2], [5], [6]], -// [ [3], [4], [7], [8]], -// [ [9], [10], [13], [14]], -// [ [11], [12], [15], [16]]]] -// -// ``` +// Performs max pooling on the input. // // Arguments: +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// block_size: The size of the spatial block, same as in Space2Depth. -func DepthToSpace(scope *Scope, input tf.Output, block_size int64, optional ...DepthToSpaceAttr) (output tf.Output) { +// Returns The max pooled output tensor. +func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"block_size": block_size} + attrs := map[string]interface{}{"padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DepthToSpace", + Type: "MaxPoolV2", + Input: []tf.Input{ + input, ksize, strides, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolAttr is an optional argument to MaxPool. +type MaxPoolAttr func(optionalAttr) + +// MaxPoolDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolDataFormat(value string) MaxPoolAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs max pooling on the input. +// +// Arguments: +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns The max pooled output tensor. 
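+//
+// A minimal sketch of calling the wrapper below, assuming the op package's
+// NewScope/Placeholder helpers and the tensorflow/go `tf` package; the 2x2
+// window, stride, and padding values are illustrative only:
+//
+// ```
+// s := op.NewScope()
+// input := op.Placeholder(s.SubScope("input"), tf.Float) // 4-D, NHWC
+// pooled := op.MaxPool(s, input,
+// 	[]int64{1, 2, 2, 1}, []int64{1, 2, 2, 1}, "SAME",
+// 	op.MaxPoolDataFormat("NHWC"))
+// ```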
+func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool", Input: []tf.Input{ input, }, @@ -8854,10 +9666,350 @@ func DepthToSpace(scope *Scope, input tf.Output, block_size int64, optional ...D return op.Output(0) } -// Conv3DBackpropInputV2Attr is an optional argument to Conv3DBackpropInputV2. -type Conv3DBackpropInputV2Attr func(optionalAttr) +// Restore a Reader to its initial clean state. +// +// Arguments: +// reader_handle: Handle to a Reader. +// +// Returns the created operation. +func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderResetV2", + Input: []tf.Input{ + reader_handle, + }, + } + return scope.AddOperation(opspec) +} -// Conv3DBackpropInputV2DataFormat sets the optional data_format attribute to value. +// Computes softplus gradients for a softplus operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding softplus operation. +// features: The features passed as input to the corresponding softplus operation. +// +// Returns The gradients: `gradients / (1 + exp(-features))`. +func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SoftplusGrad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LRNGradAttr is an optional argument to LRNGrad. +type LRNGradAttr func(optionalAttr) + +// LRNGradDepthRadius sets the optional depth_radius attribute to value. +// +// value: A depth radius. +// If not specified, defaults to 5 +func LRNGradDepthRadius(value int64) LRNGradAttr { + return func(m optionalAttr) { + m["depth_radius"] = value + } +} + +// LRNGradBias sets the optional bias attribute to value. +// +// value: An offset (usually > 0 to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNGradBias(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["bias"] = value + } +} + +// LRNGradAlpha sets the optional alpha attribute to value. +// +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNGradAlpha(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNGradBeta sets the optional beta attribute to value. +// +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNGradBeta(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Gradients for Local Response Normalization. +// +// Arguments: +// input_grads: 4-D with shape `[batch, height, width, channels]`. +// input_image: 4-D with shape `[batch, height, width, channels]`. +// output_image: 4-D with shape `[batch, height, width, channels]`. +// +// Returns The gradients for LRN. 
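+//
+// A usage sketch, assuming the op package's NewScope/Placeholder helpers and
+// the tensorflow/go `tf` package; the attribute values simply restate the
+// defaults documented above and the tensor names are illustrative only:
+//
+// ```
+// s := op.NewScope()
+// inputGrads := op.Placeholder(s.SubScope("input_grads"), tf.Float)
+// inputImage := op.Placeholder(s.SubScope("input_image"), tf.Float)
+// outputImage := op.Placeholder(s.SubScope("output_image"), tf.Float)
+// grads := op.LRNGrad(s, inputGrads, inputImage, outputImage,
+// 	op.LRNGradDepthRadius(5), op.LRNGradBias(1),
+// 	op.LRNGradAlpha(1), op.LRNGradBeta(0.5))
+// ```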
+func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LRNGrad", + Input: []tf.Input{ + input_grads, input_image, output_image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the gradient of morphological 2-D dilation with respect to the input. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +// strides: 1-D of length 4. The stride of the sliding window for each dimension of +// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: 1-D of length 4. The input stride for atrous morphological dilation. +// Must be: `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape `[batch, in_height, in_width, depth]`. +func Dilation2DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (in_backprop tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + opspec := tf.OpSpec{ + Type: "Dilation2DBackpropInput", + Input: []tf.Input{ + input, filter, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LRNAttr is an optional argument to LRN. +type LRNAttr func(optionalAttr) + +// LRNDepthRadius sets the optional depth_radius attribute to value. +// +// value: 0-D. Half-width of the 1-D normalization window. +// If not specified, defaults to 5 +func LRNDepthRadius(value int64) LRNAttr { + return func(m optionalAttr) { + m["depth_radius"] = value + } +} + +// LRNBias sets the optional bias attribute to value. +// +// value: An offset (usually positive to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNBias(value float32) LRNAttr { + return func(m optionalAttr) { + m["bias"] = value + } +} + +// LRNAlpha sets the optional alpha attribute to value. +// +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNAlpha(value float32) LRNAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNBeta sets the optional beta attribute to value. +// +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNBeta(value float32) LRNAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Local Response Normalization. +// +// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last +// dimension), and each vector is normalized independently. Within a given vector, +// each component is divided by the weighted, squared sum of inputs within +// `depth_radius`. In detail, +// +// sqr_sum[a, b, c, d] = +// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) +// output = input / (bias + alpha * sqr_sum) ** beta +// +// For details, see [Krizhevsky et al., ImageNet classification with deep +// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). +// +// Arguments: +// input: 4-D. 
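+//
+// A minimal sketch of calling the wrapper below, assuming the op package's
+// NewScope/Placeholder helpers and the tensorflow/go `tf` package; the
+// attribute values are illustrative, not recommendations:
+//
+// ```
+// s := op.NewScope()
+// input := op.Placeholder(s.SubScope("input"), tf.Float) // 4-D, NHWC
+// normalized := op.LRN(s, input,
+// 	op.LRNDepthRadius(5), op.LRNBias(1),
+// 	op.LRNAlpha(1e-4), op.LRNBeta(0.75))
+// ```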
+func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LRN", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// EuclideanNormAttr is an optional argument to EuclideanNorm. +type EuclideanNormAttr func(optionalAttr) + +// EuclideanNormKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func EuclideanNormKeepDims(value bool) EuclideanNormAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the euclidean norm of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func EuclideanNorm(scope *Scope, input tf.Output, axis tf.Output, optional ...EuclideanNormAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EuclideanNorm", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x + y element-wise. +// +// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Add", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Receives a tensor value broadcast from another device. +func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T, "group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} + opspec := tf.OpSpec{ + Type: "CollectiveBcastRecv", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. +type ResourceSparseApplyAdadeltaAttr func(optionalAttr) + +// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// var: Should be from a Variable(). +// +// Arguments: +// +// accum: Should be from a Variable(). +// accum_update: : Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. +// grad: The gradient. 
+// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyAdadelta", + Input: []tf.Input{ + var_, accum, accum_update, lr, rho, epsilon, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. +type MaxPool3DGradAttr func(optionalAttr) + +// MaxPool3DGradDataFormat sets the optional data_format attribute to value. // // value: The data format of the input and output data. With the // default format "NDHWC", the data is stored in the order of: @@ -8865,51 +10017,35 @@ type Conv3DBackpropInputV2Attr func(optionalAttr) // Alternatively, the format could be "NCDHW", the data storage order is: // [batch, in_channels, in_depth, in_height, in_width]. // If not specified, defaults to "NDHWC" -func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { +func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { return func(m optionalAttr) { m["data_format"] = value } } -// Conv3DBackpropInputV2Dilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of 3-D convolution with respect to the input. +// Computes gradients of max pooling function. // // Arguments: -// input_sizes: An integer vector representing the tensor shape of `input`, -// where `input` is a 5-D -// `[batch, depth, rows, cols, in_channels]` tensor. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. // strides: 1-D tensor of length 5. The stride of the sliding window for each // dimension of `input`. Must have `strides[0] = strides[4] = 1`. // padding: The type of padding algorithm to use. 
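+//
+// A usage sketch for MaxPool3DGrad (defined below), assuming the op package's
+// NewScope/Placeholder helpers and the tensorflow/go `tf` package; the 5-D
+// NDHWC tensors, window, and strides are illustrative only:
+//
+// ```
+// s := op.NewScope()
+// origInput := op.Placeholder(s.SubScope("orig_input"), tf.Float) // 5-D, NDHWC
+// origOutput := op.Placeholder(s.SubScope("orig_output"), tf.Float)
+// grad := op.Placeholder(s.SubScope("grad"), tf.Float)
+// out := op.MaxPool3DGrad(s, origInput, origOutput, grad,
+// 	[]int64{1, 2, 2, 2, 1}, []int64{1, 2, 2, 2, 1}, "VALID",
+// 	op.MaxPool3DGradDataFormat("NDHWC"))
+// ```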
-func Conv3DBackpropInputV2(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputV2Attr) (output tf.Output) { +func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropInputV2", + Type: "MaxPool3DGrad", Input: []tf.Input{ - input_sizes, filter, out_backprop, + orig_input, orig_output, grad, }, Attrs: attrs, } @@ -8917,31 +10053,85 @@ func Conv3DBackpropInputV2(scope *Scope, input_sizes tf.Output, filter tf.Output return op.Output(0) } -// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. -type Conv3DBackpropInputAttr func(optionalAttr) +// Assigns sparse updates to the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] = updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] = updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] +// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterUpdate", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} -// Conv3DBackpropInputDilations sets the optional dilations attribute to value. +// Conv3DAttr is an optional argument to Conv3D. +type Conv3DAttr func(optionalAttr) + +// Conv3DDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DDataFormat(value string) Conv3DAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv3DDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. // If not specified, defaults to -func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { +func Conv3DDilations(value []int64) Conv3DAttr { return func(m optionalAttr) { m["dilations"] = value } } -// Computes the gradients of 3-D convolution with respect to the input. +// Computes a 3-D convolution given 5-D `input` and `filter` tensors. 
// -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 +// In signal processing, cross-correlation is a measure of similarity of +// two waveforms as a function of a time-lag applied to one of them. This +// is also known as a sliding dot product or sliding inner-product. +// +// Our Conv3D implements a form of cross-correlation. // // Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. +// input: Shape `[batch, in_depth, in_height, in_width, in_channels]`. +// filter: Shape `[filter_depth, filter_height, filter_width, in_channels, +// out_channels]`. `in_channels` must match between `input` and `filter`. // strides: 1-D tensor of length 5. The stride of the sliding window for each // dimension of `input`. Must have `strides[0] = strides[4] = 1`. // padding: The type of padding algorithm to use. -func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { +func Conv3D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv3DAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -8950,9 +10140,66 @@ func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_ba a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropInput", + Type: "Conv3D", Input: []tf.Input{ - input, filter, out_backprop, + input, filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that will write to / read from a snapshot. +// +// This dataset attempts to determine whether a valid snapshot exists at the +// `snapshot_path`, and reads from the snapshot in lieu of using `input_dataset`. +// If not, it will run the preprocessing pipeline as usual, and write out a +// snapshot of the data processed for future use. +// +// Arguments: +// input_dataset: A variant tensor representing the input dataset. +// path: The path we should write snapshots to / read snapshots from. +// +// +func SnapshotDataset(scope *Scope, input_dataset tf.Output, path tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "SnapshotDataset", + Input: []tf.Input{ + input_dataset, path, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Makes the summary of accumulated stats for the batch. +// +// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example. +// +// Arguments: +// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer. +// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients. +// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians. +// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column). +// max_splits: int; the maximum number of splits possible in the whole tree. +// num_buckets: int; equals to the maximum possible value of bucketized feature. 
+// +// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians. +func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets} + opspec := tf.OpSpec{ + Type: "BoostedTreesMakeStatsSummary", + Input: []tf.Input{ + node_ids, gradients, hessians, tf.OutputList(bucketized_features_list), }, Attrs: attrs, } @@ -9037,176 +10284,44 @@ func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, stri return op.Output(0) } -// MaxPoolGradAttr is an optional argument to MaxPoolGrad. -type MaxPoolGradAttr func(optionalAttr) - -// MaxPoolGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradDataFormat(value string) MaxPoolGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of the maxpooling function. +// Computes the number of elements in the given queue. // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients w.r.t. the output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// handle: The handle to a queue. // -// Returns Gradients w.r.t. the input to `max_pool`. -func MaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradAttr) (output tf.Output) { +// Returns The number of elements in the given queue. +func QueueSizeV2(scope *Scope, handle tf.Output) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MaxPoolGrad", + Type: "QueueSizeV2", Input: []tf.Input{ - orig_input, orig_output, grad, + handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CropAndResizeAttr is an optional argument to CropAndResize. -type CropAndResizeAttr func(optionalAttr) +// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput. +type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr) -// CropAndResizeMethod sets the optional method attribute to value. -// -// value: A string specifying the sampling method for resizing. It can be either -// `"bilinear"` or `"nearest"` and default to `"bilinear"`. Currently two sampling -// methods are supported: Bilinear and Nearest Neighbor. 
-// If not specified, defaults to "bilinear" -func CropAndResizeMethod(value string) CropAndResizeAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// CropAndResizeExtrapolationValue sets the optional extrapolation_value attribute to value. -// -// value: Value used for extrapolation, when applicable. -// If not specified, defaults to 0 -func CropAndResizeExtrapolationValue(value float32) CropAndResizeAttr { - return func(m optionalAttr) { - m["extrapolation_value"] = value - } -} - -// Extracts crops from the input image tensor and resizes them. -// -// Extracts crops from the input image tensor and resizes them using bilinear -// sampling or nearest neighbor sampling (possibly with aspect ratio change) to a -// common output size specified by `crop_size`. This is more general than the -// `crop_to_bounding_box` op which extracts a fixed size slice from the input image -// and does not allow resizing or aspect ratio change. -// -// Returns a tensor with `crops` from the input `image` at positions defined at the -// bounding box locations in `boxes`. The cropped boxes are all resized (with -// bilinear or nearest neighbor interpolation) to a fixed -// `size = [crop_height, crop_width]`. The result is a 4-D tensor -// `[num_boxes, crop_height, crop_width, depth]`. The resizing is corner aligned. -// In particular, if `boxes = [[0, 0, 1, 1]]`, the method will give identical -// results to using `tf.image.resize_bilinear()` or -// `tf.image.resize_nearest_neighbor()`(depends on the `method` argument) with -// `align_corners=True`. -// -// Arguments: -// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -// Both `image_height` and `image_width` need to be positive. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1]` in image height coordinates. We do allow `y1` > `y2`, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// crop_size: A 1-D tensor of 2 elements, `size = [crop_height, crop_width]`. All -// cropped image patches are resized to this size. The aspect ratio of the image -// content is not preserved. Both `crop_height` and `crop_width` need to be -// positive. -// -// Returns A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -func CropAndResize(scope *Scope, image tf.Output, boxes tf.Output, box_ind tf.Output, crop_size tf.Output, optional ...CropAndResizeAttr) (crops tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CropAndResize", - Input: []tf.Input{ - image, boxes, box_ind, crop_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. 
-type Conv2DBackpropFilterAttr func(optionalAttr) - -// Conv2DBackpropFilterUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DBackpropFilterUseCudnnOnGpu(value bool) Conv2DBackpropFilterAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} - -// Conv2DBackpropFilterExplicitPaddings sets the optional explicit_paddings attribute to value. -// -// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith -// dimension, the amount of padding inserted before and after the dimension is -// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If -// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. -// If not specified, defaults to <> -func Conv2DBackpropFilterExplicitPaddings(value []int64) Conv2DBackpropFilterAttr { - return func(m optionalAttr) { - m["explicit_paddings"] = value - } -} - -// Conv2DBackpropFilterDataFormat sets the optional data_format attribute to value. +// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value. // // value: Specify the data format of the input and output data. With the // default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. +// [batch, height, width, channels]. // Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. +// [batch, channels, height, width]. // If not specified, defaults to "NHWC" -func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { +func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["data_format"] = value } } -// Conv2DBackpropFilterDilations sets the optional dilations attribute to value. +// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value. // // value: 1-D tensor of length 4. The dilation factor for each dimension of // `input`. If set to k > 1, there will be k-1 skipped cells between each filter @@ -9214,30 +10329,33 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { // `data_format`, see above for details. Dilations in the batch and depth // dimensions must be 1. // If not specified, defaults to -func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { +func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value } } -// Computes the gradients of convolution with respect to the filter. +// Computes the gradients of depthwise convolution with respect to the input. // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 4-D -// `[filter_height, filter_width, in_channels, out_channels]` tensor. -// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// input_sizes: An integer vector representing the shape of `input`, based +// on `data_format`. For example, if `data_format` is 'NHWC' then +// `input` is a 4-D `[batch, height, width, channels]` tensor. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, depthwise_multiplier]`. +// out_backprop: 4-D with shape based on `data_format`. 
+// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. // Gradients w.r.t. the output of the convolution. // strides: The stride of the sliding window for each dimension of the input -// of the convolution. Must be in the same order as the dimension specified with -// format. +// of the convolution. // padding: The type of padding algorithm to use. // -// Returns 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. -// the `filter` input of the convolution. -func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropFilterAttr) (output tf.Output) { +// Returns 4-D with shape according to `data_format`. For example, if +// `data_format` is 'NHWC', output shape is `[batch, in_height, +// in_width, in_channels]`. Gradient w.r.t. the input of the +// convolution. +func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -9246,9 +10364,9 @@ func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "Conv2DBackpropFilter", + Type: "DepthwiseConv2dNativeBackpropInput", Input: []tf.Input{ - input, filter_sizes, out_backprop, + input_sizes, filter, out_backprop, }, Attrs: attrs, } @@ -9256,36 +10374,124 @@ func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, return op.Output(0) } -// Computes Psi, the derivative of Lgamma (the log of the absolute value of +// Performs a padding as a preprocess during a convolution. // -// `Gamma(x)`), element-wise. -func Digamma(scope *Scope, x tf.Output) (y tf.Output) { +// Similar to FusedResizeAndPadConv2d, this op allows for an optimized +// implementation where the spatial padding transformation stage is fused with the +// im2col lookup, but in this case without the bilinear filtering required for +// resizing. Fusing the padding prevents the need to write out the intermediate +// results as whole tensors, reducing memory pressure, and we can get some latency +// gains by merging the transformation calculations. +// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' +// order is used instead. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. This is because this +// operator is primarily an optimization to minimize memory usage. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. 
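+//
+// A usage sketch, assuming the op package's NewScope/Const/Placeholder helpers
+// and the tensorflow/go `tf` package; the "REFLECT" mode string, padding
+// matrix, and strides are assumptions made for illustration, since the mode
+// argument is not described above:
+//
+// ```
+// s := op.NewScope()
+// input := op.Placeholder(s.SubScope("input"), tf.Float) // 4-D, NHWC
+// paddings := op.Const(s.SubScope("paddings"), [][]int32{{0, 0}, {1, 1}, {1, 1}, {0, 0}})
+// filter := op.Placeholder(s.SubScope("filter"), tf.Float) // [filter_height, filter_width, in_channels, out_channels]
+// out := op.FusedPadConv2D(s, input, paddings, filter,
+// 	"REFLECT", []int64{1, 1, 1, 1}, "VALID")
+// ```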
+func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "Digamma", + Type: "FusedPadConv2D", Input: []tf.Input{ - x, + input, paddings, filter, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the number of work units this Reader has finished processing. +// Read an element from the TensorArray into output `value`. // // Arguments: -// reader_handle: Handle to a Reader. -func ReaderNumWorkUnitsCompletedV2(scope *Scope, reader_handle tf.Output) (units_completed tf.Output) { +// handle: The handle to a TensorArray. +// +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. +// +// Returns The tensor that is read from the TensorArray. +func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "ReaderNumWorkUnitsCompletedV2", + Type: "TensorArrayReadV3", Input: []tf.Input{ - reader_handle, + handle, index, flow_in, }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D. +type FusedResizeAndPadConv2DAttr func(optionalAttr) + +// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr { + return func(m optionalAttr) { + m["resize_align_corners"] = value + } +} + +// Performs a resize and padding as a preprocess during a convolution. +// +// It's often possible to do spatial transformations more efficiently as part of +// the packing stage of a convolution, so this op allows for an optimized +// implementation where these stages are fused together. This prevents the need to +// write out the intermediate results as whole tensors, reducing memory pressure, +// and we can get some latency gains by merging the transformation calculations. +// The data_format attribute for Conv2D isn't supported by this op, and defaults to +// 'NHWC' order. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. This is because this +// operator is primarily an optimization to minimize memory usage. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. 
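// Editor's note: a hedged, illustrative sketch of calling this wrapper, not
// part of the generated documentation. It assumes an *op.Scope `s` and
// existing tf.Output values `input` and `filter`; the resize size and padding
// amounts are arbitrary example values.
//
// ```go
// size := op.Const(s, []int32{224, 224})
// paddings := op.Const(s, [][]int32{{0, 0}, {2, 2}, {2, 2}, {0, 0}})
// out := op.FusedResizeAndPadConv2D(s, input, size, paddings, filter,
// 	"REFLECT", []int64{1, 1, 1, 1}, "VALID",
// 	op.FusedResizeAndPadConv2DResizeAlignCorners(true),
// )
// ```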
+func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedResizeAndPadConv2D", + Input: []tf.Input{ + input, size, paddings, filter, + }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -9398,2248 +10604,46 @@ func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, pa return op.Output(0) } -// Fills empty rows in the input 2-D `SparseTensor` with a default value. -// -// The input `SparseTensor` is represented via the tuple of inputs -// (`indices`, `values`, `dense_shape`). The output `SparseTensor` has the -// same `dense_shape` but with indices `output_indices` and values -// `output_values`. -// -// This op inserts a single entry for every row that doesn't have any values. -// The index is created as `[row, 0, ..., 0]` and the inserted value -// is `default_value`. -// -// For example, suppose `sp_input` has shape `[5, 6]` and non-empty values: -// -// [0, 1]: a -// [0, 3]: b -// [2, 0]: c -// [3, 1]: d -// -// Rows 1 and 4 are empty, so the output will be of shape `[5, 6]` with values: -// -// [0, 1]: a -// [0, 3]: b -// [1, 0]: default_value -// [2, 0]: c -// [3, 1]: d -// [4, 0]: default_value -// -// The output `SparseTensor` will be in row-major order and will have the -// same shape as the input. -// -// This op also returns an indicator vector shaped `[dense_shape[0]]` such that -// -// empty_row_indicator[i] = True iff row i was an empty row. -// -// And a reverse index map vector shaped `[indices.shape[0]]` that is used during -// backpropagation, -// -// reverse_index_map[j] = out_j s.t. indices[j, :] == output_indices[out_j, :] -// -// Arguments: -// indices: 2-D. the indices of the sparse tensor. -// values: 1-D. the values of the sparse tensor. -// dense_shape: 1-D. the shape of the sparse tensor. -// default_value: 0-D. default value to insert into location `[row, 0, ..., 0]` -// for rows missing from the input sparse tensor. -// output indices: 2-D. the indices of the filled sparse tensor. -// -// Returns 1-D. the values of the filled sparse tensor.1-D. whether the dense row was missing in the -// input sparse tensor.1-D. a map from the input indices to the output indices. -func SparseFillEmptyRows(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output, default_value tf.Output) (output_indices tf.Output, output_values tf.Output, empty_row_indicator tf.Output, reverse_index_map tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseFillEmptyRows", - Input: []tf.Input{ - indices, values, dense_shape, default_value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// LoadTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingADAMParametersGradAccumDebug. -type LoadTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) - -// LoadTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. 
-// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load ADAM embedding parameters with debug support. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the ADAM optimization algorithm. -// momenta: Value of momenta used in the ADAM optimization algorithm. -// velocities: Value of velocities used in the ADAM optimization algorithm. -// gradient_accumulators: Value of gradient_accumulators used in the ADAM optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingADAMParametersGradAccumDebugAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingADAMParametersGradAccumDebug", - Input: []tf.Input{ - parameters, momenta, velocities, gradient_accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// BiasAddAttr is an optional argument to BiasAdd. -type BiasAddAttr func(optionalAttr) - -// BiasAddDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the bias tensor will be added to the last dimension -// of the value tensor. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// The tensor will be added to "in_channels", the third-to-the-last -// dimension. -// If not specified, defaults to "NHWC" -func BiasAddDataFormat(value string) BiasAddAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Adds `bias` to `value`. -// -// This is a special case of `tf.add` where `bias` is restricted to be 1-D. -// Broadcasting is supported, so `value` may have any number of dimensions. -// -// Arguments: -// value: Any number of dimensions. -// bias: 1-D with size the last dimension of `value`. -// -// Returns Broadcasted sum of `value` and `bias`. 
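// Editor's note: a small illustrative sketch, not part of the generated
// documentation. With the default "NHWC" format the bias is added along the
// last (channels) dimension. Assumes an *op.Scope `s`.
//
// ```go
// value := op.Const(s, [][]float32{{1, 2, 3}, {4, 5, 6}})
// bias := op.Const(s, []float32{10, 20, 30})
// sum := op.BiasAdd(s, value, bias) // rows become {11, 22, 33} and {14, 25, 36}
// ```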
-func BiasAdd(scope *Scope, value tf.Output, bias tf.Output, optional ...BiasAddAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "BiasAdd", - Input: []tf.Input{ - value, bias, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SparseReduceSumSparseAttr is an optional argument to SparseReduceSumSparse. -type SparseReduceSumSparseAttr func(optionalAttr) - -// SparseReduceSumSparseKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceSumSparseKeepDims(value bool) SparseReduceSumSparseAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the sum of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a -// SparseTensor. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. -// -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. -// -// Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. -func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseReduceSumSparse", - Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// LoadTPUEmbeddingStochasticGradientDescentParametersAttr is an optional argument to LoadTPUEmbeddingStochasticGradientDescentParameters. -type LoadTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) LoadTPUEmbeddingStochasticGradientDescentParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingStochasticGradientDescentParametersTableName sets the optional table_name attribute to value. 
-// If not specified, defaults to "" -func LoadTPUEmbeddingStochasticGradientDescentParametersTableName(value string) LoadTPUEmbeddingStochasticGradientDescentParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load SGD embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the stochastic gradient descent optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, parameters tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingStochasticGradientDescentParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingStochasticGradientDescentParameters", - Input: []tf.Input{ - parameters, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Selects the k nearest centers for each point. -// -// Rows of points are assumed to be input points. Rows of centers are assumed to be -// the list of candidate centers. For each point, the k centers that have least L2 -// distance to it are computed. -// -// Arguments: -// points: Matrix of shape (n, d). Rows are assumed to be input points. -// centers: Matrix of shape (m, d). Rows are assumed to be centers. -// k: Number of nearest centers to return for each point. If k is larger than m, then -// only m centers are returned. -// -// Returns Matrix of shape (n, min(m, k)). Each row contains the indices of the centers -// closest to the corresponding point, ordered by increasing distance.Matrix of shape (n, min(m, k)). Each row contains the squared L2 distance to the -// corresponding center in nearest_center_indices. -func NearestNeighbors(scope *Scope, points tf.Output, centers tf.Output, k tf.Output) (nearest_center_indices tf.Output, nearest_center_distances tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NearestNeighbors", - Input: []tf.Input{ - points, centers, k, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Returns x * y element-wise. -// -// *NOTE*: `Multiply` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Mul(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Mul", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FusedBatchNormV2Attr is an optional argument to FusedBatchNormV2. -type FusedBatchNormV2Attr func(optionalAttr) - -// FusedBatchNormV2Epsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormV2Epsilon(value float32) FusedBatchNormV2Attr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormV2DataFormat sets the optional data_format attribute to value. -// -// value: The data format for x and y. 
Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormV2DataFormat(value string) FusedBatchNormV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormV2IsTraining sets the optional is_training attribute to value. -// -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormV2IsTraining(value bool) FusedBatchNormV2Attr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Batch normalization. -// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. -// -// Arguments: -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// offset: A 1D Tensor for offset, to shift to the normalized x. -// mean: A 1D Tensor for population mean. Used for inference only; -// must be empty for training. -// variance: A 1D Tensor for population variance. Used for inference only; -// must be empty for training. -// -// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow -// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by -// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused -// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance -// in the cuDNN case), to be reused in the gradient computation. -func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormV2Attr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedBatchNormV2", - Input: []tf.Input{ - x, scale, offset, mean, variance, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// Reverses specific dimensions of a tensor. -// -// NOTE `tf.reverse` has now changed behavior in preparation for 1.0. -// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0. -// -// Given a `tensor`, and a `int32` tensor `axis` representing the set of -// dimensions of `tensor` to reverse. This operation reverses each dimension -// `i` for which there exists `j` s.t. `axis[j] == i`. -// -// `tensor` can have up to 8 dimensions. The number of dimensions specified -// in `axis` may be 0 or more entries. If an index is specified more than -// once, a InvalidArgument error is raised. 
-// -// For example: -// -// ``` -// # tensor 't' is [[[[ 0, 1, 2, 3], -// # [ 4, 5, 6, 7], -// # [ 8, 9, 10, 11]], -// # [[12, 13, 14, 15], -// # [16, 17, 18, 19], -// # [20, 21, 22, 23]]]] -// # tensor 't' shape is [1, 2, 3, 4] -// -// # 'dims' is [3] or 'dims' is [-1] -// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], -// [ 7, 6, 5, 4], -// [ 11, 10, 9, 8]], -// [[15, 14, 13, 12], -// [19, 18, 17, 16], -// [23, 22, 21, 20]]]] -// -// # 'dims' is '[1]' (or 'dims' is '[-3]') -// reverse(t, dims) ==> [[[[12, 13, 14, 15], -// [16, 17, 18, 19], -// [20, 21, 22, 23] -// [[ 0, 1, 2, 3], -// [ 4, 5, 6, 7], -// [ 8, 9, 10, 11]]]] -// -// # 'dims' is '[2]' (or 'dims' is '[-2]') -// reverse(t, dims) ==> [[[[8, 9, 10, 11], -// [4, 5, 6, 7], -// [0, 1, 2, 3]] -// [[20, 21, 22, 23], -// [16, 17, 18, 19], -// [12, 13, 14, 15]]]] -// ``` -// -// Arguments: -// tensor: Up to 8-D. -// axis: 1-D. The indices of the dimensions to reverse. Must be in the range -// `[-rank(tensor), rank(tensor))`. -// -// Returns The same shape as `tensor`. -func ReverseV2(scope *Scope, tensor tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReverseV2", - Input: []tf.Input{ - tensor, axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adds `bias` to `value`. -// -// This is a deprecated version of BiasAdd and will be soon removed. -// -// This is a special case of `tf.add` where `bias` is restricted to be 1-D. -// Broadcasting is supported, so `value` may have any number of dimensions. -// -// Arguments: -// value: Any number of dimensions. -// bias: 1-D with size the last dimension of `value`. -// -// Returns Broadcasted sum of `value` and `bias`. -func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BiasAddV1", - Input: []tf.Input{ - value, bias, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Selects num_to_sample rows of input using the KMeans++ criterion. -// -// Rows of points are assumed to be input points. One row is selected at random. -// Subsequent rows are sampled with probability proportional to the squared L2 -// distance from the nearest row selected thus far till num_to_sample rows have -// been sampled. -// -// Arguments: -// points: Matrix of shape (n, d). Rows are assumed to be input points. -// num_to_sample: Scalar. The number of rows to sample. This value must not be larger than n. -// seed: Scalar. Seed for initializing the random number generator. -// num_retries_per_sample: Scalar. For each row that is sampled, this parameter -// specifies the number of additional points to draw from the current -// distribution before selecting the best. If a negative value is specified, a -// heuristic is used to sample O(log(num_to_sample)) additional points. -// -// Returns Matrix of shape (num_to_sample, d). The sampled rows. -func KmeansPlusPlusInitialization(scope *Scope, points tf.Output, num_to_sample tf.Output, seed tf.Output, num_retries_per_sample tf.Output) (samples tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "KmeansPlusPlusInitialization", - Input: []tf.Input{ - points, num_to_sample, seed, num_retries_per_sample, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Transforms a Tensor into a serialized TensorProto proto. -// -// Arguments: -// tensor: A Tensor of type `T`. 
-// -// Returns A serialized TensorProto proto of the input tensor. -func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SerializeTensor", - Input: []tf.Input{ - tensor, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// UnbatchGradAttr is an optional argument to UnbatchGrad. -type UnbatchGradAttr func(optionalAttr) - -// UnbatchGradContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnbatchGradContainer(value string) UnbatchGradAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// UnbatchGradSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnbatchGradSharedName(value string) UnbatchGradAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Gradient of Unbatch. -// -// Acts like Batch but using the given batch_index index of batching things as they -// become available. This ensures that the gradients are propagated back in the -// same session which did the forward pass. -// -// original_input: The input to the Unbatch operation this is the gradient of. -// batch_index: The batch_index given to the Unbatch operation this is the gradient -// of. -// grad: The downstream gradient. -// id: The id scalar emitted by Batch. -// batched_grad: The return value, either an empty tensor or the batched gradient. -// container: Container to control resource sharing. -// shared_name: Instances of UnbatchGrad with the same container and shared_name -// are assumed to possibly belong to the same batch. If left empty, the op name -// will be used as the shared name. -func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, grad tf.Output, id tf.Output, optional ...UnbatchGradAttr) (batched_grad tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UnbatchGrad", - Input: []tf.Input{ - original_input, batch_index, grad, id, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. -type AvgPool3DGradAttr func(optionalAttr) - -// AvgPool3DGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of average pooling function. -// -// Arguments: -// orig_input_shape: The original input dimensions. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -// -// Returns The backprop for input. 
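// Editor's note: a hedged, illustrative sketch, not part of the generated
// documentation. It assumes an *op.Scope `s` and a tf.Output `grad` holding
// the backpropagated values for the pooled output; the shapes are arbitrary
// example values.
//
// ```go
// origShape := op.Const(s, []int32{8, 16, 16, 16, 3}) // [batch, depth, rows, cols, channels]
// inGrad := op.AvgPool3DGrad(s, origShape, grad,
// 	[]int64{1, 2, 2, 2, 1}, // ksize
// 	[]int64{1, 2, 2, 2, 1}, // strides
// 	"VALID",
// 	op.AvgPool3DGradDataFormat("NDHWC"),
// )
// ```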
-func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AvgPool3DGrad", - Input: []tf.Input{ - orig_input_shape, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. -type ParseSingleSequenceExampleAttr func(optionalAttr) - -// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. -// -// value: A list of Ncontext_sparse types; the data types of data in -// each context Feature given in context_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["context_sparse_types"] = value - } -} - -// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["feature_list_dense_types"] = value - } -} - -// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. -// -// value: A list of Ncontext_dense shapes; the shapes of data in -// each context Feature given in context_dense_keys. -// The number of elements in the Feature corresponding to context_dense_key[j] -// must always equal context_dense_shapes[j].NumEntries(). -// The shape of context_dense_values[j] will match context_dense_shapes[j]. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["context_dense_shapes"] = value - } -} - -// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. -// -// value: A list of Nfeature_list_sparse types; the data types -// of data in each FeatureList given in feature_list_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["feature_list_sparse_types"] = value - } -} - -// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. -// -// value: A list of Nfeature_list_dense shapes; the shapes of -// data in each FeatureList given in feature_list_dense_keys. -// The shape of each Feature in the FeatureList corresponding to -// feature_list_dense_key[j] must always equal -// feature_list_dense_shapes[j].NumEntries(). 
-// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { - return func(m optionalAttr) { - m["feature_list_dense_shapes"] = value - } -} - -// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. -// -// Arguments: -// serialized: A scalar containing a binary serialized SequenceExample proto. -// feature_list_dense_missing_assumed_empty: A vector listing the -// FeatureList keys which may be missing from the SequenceExample. If the -// associated FeatureList is missing, it is treated as empty. By default, -// any FeatureList not listed in this vector must exist in the SequenceExample. -// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). -// The keys expected in the Examples' features associated with context_sparse -// values. -// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' context features associated with -// dense values. -// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors -// (scalars). The keys expected in the FeatureLists associated with sparse -// values. -// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' feature_lists associated -// with lists of dense values. -// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). -// context_dense_defaults[j] provides default values -// when the SequenceExample's context map lacks context_dense_key[j]. -// If an empty Tensor is provided for context_dense_defaults[j], -// then the Feature context_dense_keys[j] is required. -// The input type is inferred from context_dense_defaults[j], even when it's -// empty. If context_dense_defaults[j] is not empty, its shape must match -// context_dense_shapes[j]. -// debug_name: A scalar containing the name of the serialized proto. -// May contain, for example, table key (descriptive) name for the -// corresponding serialized proto. This is purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty scalar if no name is available. 
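// Editor's note: an illustrative sketch, not part of the generated
// documentation. Each of the optional *Types/*Shapes attributes above supplies
// one entry per corresponding key tensor passed to the op. Assuming an
// *op.Scope `s`, attributes are built like this and passed as the trailing
// variadic arguments of the call below:
//
// ```go
// ctxTypes := op.ParseSingleSequenceExampleContextSparseTypes([]tf.DataType{tf.String})
// flTypes := op.ParseSingleSequenceExampleFeatureListSparseTypes([]tf.DataType{tf.Int64})
// ```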
-func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ParseSingleSequenceExample", - Input: []tf.Input{ - serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values -} - -// SparseToDenseAttr is an optional argument to SparseToDense. -type SparseToDenseAttr func(optionalAttr) - -// SparseToDenseValidateIndices sets the optional validate_indices attribute to value. -// -// value: If true, indices are checked to make sure they are sorted in -// lexicographic order and that there are no repeats. -// If not specified, defaults to true -func SparseToDenseValidateIndices(value bool) SparseToDenseAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Converts a sparse representation into a dense tensor. 
-// -// Builds an array `dense` with shape `output_shape` such that -// -// ``` -// # If sparse_indices is scalar -// dense[i] = (i == sparse_indices ? sparse_values : default_value) -// -// # If sparse_indices is a vector, then for each i -// dense[sparse_indices[i]] = sparse_values[i] -// -// # If sparse_indices is an n by d matrix, then for each i in [0, n) -// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] -// ``` -// -// All other values in `dense` are set to `default_value`. If `sparse_values` is a -// scalar, all sparse indices are set to this single value. -// -// Indices should be sorted in lexicographic order, and indices must not -// contain any repeats. If `validate_indices` is true, these properties -// are checked during execution. -// -// Arguments: -// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete -// index where `sparse_values[i]` will be placed. -// output_shape: 1-D. Shape of the dense output tensor. -// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, -// or a scalar value to be used for all sparse indices. -// default_value: Scalar value to set for indices not specified in -// `sparse_indices`. -// -// Returns Dense output tensor of shape `output_shape`. -func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseToDense", - Input: []tf.Input{ - sparse_indices, output_shape, sparse_values, default_value, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// PreventGradientAttr is an optional argument to PreventGradient. -type PreventGradientAttr func(optionalAttr) - -// PreventGradientMessage sets the optional message attribute to value. -// -// value: Will be printed in the error when anyone tries to differentiate -// this operation. -// If not specified, defaults to "" -func PreventGradientMessage(value string) PreventGradientAttr { - return func(m optionalAttr) { - m["message"] = value - } -} - -// An identity op that triggers an error if a gradient is requested. -// -// When executed in a graph, this op outputs its input tensor as-is. -// -// When building ops to compute gradients, the TensorFlow gradient system -// will return an error when trying to lookup the gradient of this op, -// because no gradient must ever be registered for this function. This -// op exists to prevent subtle bugs from silently returning unimplemented -// gradients in some corner cases. -// -// Arguments: -// input: any tensor. -// -// Returns the same input tensor. -func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "PreventGradient", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes asin of x element-wise. -func Asin(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Asin", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sum along sparse segments of a tensor. 
-// -// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/sparse#Segmentation) -// for an explanation of segments. -// -// For example: -// -// ```python -// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) -// -// tf.sparse_segment_sum_with_num_segments( -// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) -// # => [[0 0 0 0] -// # [0 0 0 0] -// # [0 0 0 0]] -// -// tf.sparse_segment_sum_with_num_segments(c, -// tf.constant([0, 1]), -// tf.constant([0, 2], -// num_segments=4)) -// # => [[ 1 2 3 4] -// # [ 0 0 0 0] -// # [-1 -2 -3 -4] -// # [ 0 0 0 0]] -// ``` -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `num_segments`. -func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentSumWithNumSegments", - Input: []tf.Input{ - data, indices, segment_ids, num_segments, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SparseReduceMaxAttr is an optional argument to SparseReduceMax. -type SparseReduceMaxAttr func(optionalAttr) - -// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the max of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. -// -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. -// -// Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. -// -// Returns `R-K`-D. The reduced Tensor. 
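// Editor's note: a hedged, illustrative sketch, not part of the generated
// documentation. It builds a 3x4 SparseTensor with two non-empty values and
// reduces over the column axis. Assumes an *op.Scope `s`.
//
// ```go
// indices := op.Const(s, [][]int64{{0, 1}, {2, 3}})
// values := op.Const(s, []float32{5, 7})
// shape := op.Const(s, []int64{3, 4})
// axes := op.Const(s, []int32{1}) // reduce across columns
// rowMax := op.SparseReduceMax(s, indices, values, shape, axes,
// 	op.SparseReduceMaxKeepDims(false),
// )
// ```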
-func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseReduceMax", - Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DecodeRawAttr is an optional argument to DecodeRaw. -type DecodeRawAttr func(optionalAttr) - -// DecodeRawLittleEndian sets the optional little_endian attribute to value. -// -// value: Whether the input `bytes` are in little-endian order. -// Ignored for `out_type` values that are stored in a single byte like -// `uint8`. -// If not specified, defaults to true -func DecodeRawLittleEndian(value bool) DecodeRawAttr { - return func(m optionalAttr) { - m["little_endian"] = value - } -} - -// Reinterpret the bytes of a string as a vector of numbers. -// -// Arguments: -// bytes: All the elements must have the same length. -// -// -// Returns A Tensor with one more dimension than the input `bytes`. The -// added dimension will have size equal to the length of the elements -// of `bytes` divided by the number of bytes to represent `out_type`. -func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"out_type": out_type} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeRaw", - Input: []tf.Input{ - bytes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingADAMParametersAttr is an optional argument to RetrieveTPUEmbeddingADAMParameters. -type RetrieveTPUEmbeddingADAMParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingADAMParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingADAMParametersTableId(value int64) RetrieveTPUEmbeddingADAMParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingADAMParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingADAMParametersTableName(value string) RetrieveTPUEmbeddingADAMParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve ADAM embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the ADAM optimization algorithm.Parameter momenta updated by the ADAM optimization algorithm.Parameter velocities updated by the ADAM optimization algorithm. 
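// Editor's note: a hedged, illustrative sketch, not part of the generated
// documentation. It assumes an *op.Scope `s`, a single-shard embedding setup,
// and a placeholder table name; a real program must have configured the TPU
// embedding tables beforehand, as described above.
//
// ```go
// params, momenta, velocities := op.RetrieveTPUEmbeddingADAMParameters(s,
// 	1, // num_shards
// 	0, // shard_id
// 	op.RetrieveTPUEmbeddingADAMParametersTableName("my_embedding_table"),
// )
// ```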
-func RetrieveTPUEmbeddingADAMParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingADAMParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// FusedBatchNormAttr is an optional argument to FusedBatchNorm. -type FusedBatchNormAttr func(optionalAttr) - -// FusedBatchNormEpsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormDataFormat sets the optional data_format attribute to value. -// -// value: The data format for x and y. Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormIsTraining sets the optional is_training attribute to value. -// -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Batch normalization. -// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. -// -// Arguments: -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// offset: A 1D Tensor for offset, to shift to the normalized x. -// mean: A 1D Tensor for population mean. Used for inference only; -// must be empty for training. -// variance: A 1D Tensor for population variance. Used for inference only; -// must be empty for training. -// -// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow -// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by -// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused -// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance -// in the cuDNN case), to be reused in the gradient computation. -func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedBatchNorm", - Input: []tf.Input{ - x, scale, offset, mean, variance, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// Creates a dataset that shards the input dataset. -// -// Creates a dataset that shards the input dataset by num_workers, returning a -// sharded dataset for the index-th worker. 
This attempts to automatically shard -// a dataset by examining the Dataset graph and inserting a shard op before the -// inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset). -// -// This dataset will throw a NotFound error if we cannot shard the dataset -// automatically. -// -// Arguments: -// input_dataset: A variant tensor representing the input dataset. -// num_workers: A scalar representing the number of workers to distribute this dataset across. -// index: A scalar representing the index of the current worker out of num_workers. -// -// -func ExperimentalAutoShardDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalAutoShardDataset", - Input: []tf.Input{ - input_dataset, num_workers, index, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RandomStandardNormalAttr is an optional argument to RandomStandardNormal. -type RandomStandardNormalAttr func(optionalAttr) - -// RandomStandardNormalSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomStandardNormalSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. -// -// Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. -// -// Returns A tensor of the specified shape filled with random normal values. -func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomStandardNormal", - Input: []tf.Input{ - shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D. -type FusedResizeAndPadConv2DAttr func(optionalAttr) - -// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr { - return func(m optionalAttr) { - m["resize_align_corners"] = value - } -} - -// Performs a resize and padding as a preprocess during a convolution. 
-// -// It's often possible to do spatial transformations more efficiently as part of -// the packing stage of a convolution, so this op allows for an optimized -// implementation where these stages are fused together. This prevents the need to -// write out the intermediate results as whole tensors, reducing memory pressure, -// and we can get some latency gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and defaults to -// 'NHWC' order. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. -func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedResizeAndPadConv2D", - Input: []tf.Input{ - input, size, paddings, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RandomUniformAttr is an optional argument to RandomUniform. -type RandomUniformAttr func(optionalAttr) - -// RandomUniformSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomUniformSeed(value int64) RandomUniformAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomUniformSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomUniformSeed2(value int64) RandomUniformAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. -// -// Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. -// -// Returns A tensor of the specified shape filled with uniform random values. 
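// Editor's note: a small illustrative sketch, not part of the generated
// documentation. It assumes an *op.Scope `s`; the seed values are arbitrary.
//
// ```go
// shape := op.Const(s, []int32{2, 3})
// u := op.RandomUniform(s, shape, tf.Float,
// 	op.RandomUniformSeed(42),
// 	op.RandomUniformSeed2(7),
// )
// ```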
-func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomUniform", - Input: []tf.Input{ - shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. -type ResourceApplyFtrlAttr func(optionalAttr) +// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. +type ResourceApplyRMSPropAttr func(optionalAttr) -// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. +// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected +// value: If `True`, updating of the var, ms, and mom tensors is protected // by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { +func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { return func(m optionalAttr) { m["use_locking"] = value } } -// Update '*var' according to the Ftrl-proximal scheme. +// Update '*var' according to the RMSProp algorithm. // -// accum_new = accum + grad * grad -// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom // // Arguments: // var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). // lr: Scaling factor. Must be a scalar. -// l1: L1 regulariation. Must be a scalar. -// l2: L2 regulariation. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. // -// Returns the created operation. -func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyFtrl", - Input: []tf.Input{ - var_, accum, linear, grad, lr, l1, l2, lr_power, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Transforms a vector of brain.Example protos (as strings) into typed tensors. -// -// Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// names: A vector containing the names of the serialized protos. 
-// May contain, for example, table key (descriptive) names for the -// corresponding serialized protos. These are purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty vector if no names are available. -// If non-empty, this vector must be the same length as "serialized". -// sparse_keys: A list of Nsparse string Tensors (scalars). -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: A list of Ndense string Tensors (scalars). -// The keys expected in the Examples' features associated with dense values. -// dense_defaults: A list of Ndense Tensors (some may be empty). -// dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. -// The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// sparse_types: A list of Nsparse types; the data types of data in each Feature -// given in sparse_keys. -// Currently the ParseExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: A list of Ndense shapes; the shapes of data in each Feature -// given in dense_keys. -// The number of elements in the Feature corresponding to dense_key[j] -// must always equal dense_shapes[j].NumEntries(). -// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output -// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): -// The dense outputs are just the inputs row-stacked by batch. -// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case -// the shape of the output Tensor dense_values[j] will be -// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks -// of elements of length D1 * .... * DN, across all minibatch entries -// in the input. Any minibatch entry with less than M blocks of elements of -// length D1 * ... * DN will be padded with the corresponding default_value -// scalar element along the second dimension. 
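To make the list-valued arguments of `ParseExample` concrete, here is a hedged graph-construction sketch with one sparse and one dense feature. The key names (`"tags"`, `"age"`) are invented for illustration, the serialized protos are fed through placeholders, and the graph is only built, not run.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// A batch of serialized Example protos and optional debug names, fed at run time.
	serialized := op.Placeholder(s.SubScope("serialized"), tf.String)
	names := op.Placeholder(s.SubScope("names"), tf.String)

	sparseKeys := []tf.Output{op.Const(s.SubScope("sparse_key"), "tags")}
	denseKeys := []tf.Output{op.Const(s.SubScope("dense_key"), "age")}
	// The dense default also fixes the dense feature's dtype (int64 here).
	denseDefaults := []tf.Output{op.Const(s.SubScope("dense_default"), int64(0))}
	sparseTypes := []tf.DataType{tf.String}
	denseShapes := []tf.Shape{tf.ScalarShape()}

	_, sparseValues, _, denseValues := op.ParseExample(
		s, serialized, names, sparseKeys, denseKeys, denseDefaults, sparseTypes, denseShapes)

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	// One sparse feature and one dense feature were requested.
	fmt.Println(len(sparseValues), len(denseValues)) // 1 1
}
```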
-func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"sparse_types": sparse_types, "dense_shapes": dense_shapes} - opspec := tf.OpSpec{ - Type: "ParseExample", - Input: []tf.Input{ - serialized, names, tf.OutputList(sparse_keys), tf.OutputList(dense_keys), tf.OutputList(dense_defaults), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseExample", err) - return - } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseExample", err) - return - } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseExample", err) - return - } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseExample", err) - return - } - return sparse_indices, sparse_values, sparse_shapes, dense_values -} - -// Compute the pairwise cross product. -// -// `a` and `b` must be the same shape; they can either be simple 3-element vectors, -// or any shape where the innermost dimension is 3. In the latter case, each pair -// of corresponding 3-element vectors is cross-multiplied independently. -// -// Arguments: -// a: A tensor containing 3-element vectors. -// b: Another tensor, of same type and shape as `a`. -// -// Returns Pairwise cross product of the vectors in `a` and `b`. -func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cross", - Input: []tf.Input{ - a, b, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LeakyReluAttr is an optional argument to LeakyRelu. -type LeakyReluAttr func(optionalAttr) - -// LeakyReluAlpha sets the optional alpha attribute to value. -// If not specified, defaults to 0.2 -func LeakyReluAlpha(value float32) LeakyReluAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// Computes rectified linear: `max(features, features * alpha)`. -func LeakyRelu(scope *Scope, features tf.Output, optional ...LeakyReluAttr) (activations tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LeakyRelu", - Input: []tf.Input{ - features, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs random integers from a uniform distribution. -// -// The generated values are uniform integers in the range `[minval, maxval)`. -// The lower bound `minval` is included in the range, while the upper bound -// `maxval` is excluded. -// -// The random integers are slightly biased unless `maxval - minval` is an exact -// power of two. The bias is small for values of `maxval - minval` significantly -// smaller than the range of the output (either `2^32` or `2^64`). -// -// Arguments: -// resource: The handle of the resource variable that stores the state of the RNG. -// algorithm: The RNG algorithm. -// shape: The shape of the output tensor. 
-// minval: Minimum value (inclusive, scalar). -// maxval: Maximum value (exclusive, scalar). -// -// Returns Random values with specified shape. -func StatefulUniformInt(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "StatefulUniformInt", - Input: []tf.Input{ - resource, algorithm, shape, minval, maxval, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. -type DecodeAndCropJpegAttr func(optionalAttr) - -// DecodeAndCropJpegChannels sets the optional channels attribute to value. -// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodeAndCropJpegRatio sets the optional ratio attribute to value. -// -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value - } -} - -// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. -// -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value - } -} - -// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. -// -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value - } -} - -// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. -// -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["acceptable_fraction"] = value - } -} - -// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. -// -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) -// If not specified, defaults to "" -func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { - return func(m optionalAttr) { - m["dct_method"] = value - } -} - -// Decode and Crop a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. -// -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. 
This is much faster than -// downscaling the image later. -// -// -// It is equivalent to a combination of decode and crop, but much faster by only -// decoding partial jpeg image. -// -// Arguments: -// contents: 0-D. The JPEG-encoded image. -// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. -// -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeAndCropJpeg", - Input: []tf.Input{ - contents, crop_window, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StatefulStandardNormalV2Attr is an optional argument to StatefulStandardNormalV2. -type StatefulStandardNormalV2Attr func(optionalAttr) - -// StatefulStandardNormalV2Dtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatefulStandardNormalV2Dtype(value tf.DataType) StatefulStandardNormalV2Attr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs random values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. -// -// Arguments: -// resource: The handle of the resource variable that stores the state of the RNG. -// algorithm: The RNG algorithm. -// shape: The shape of the output tensor. -// -// Returns A tensor of the specified shape filled with random normal values. -func StatefulStandardNormalV2(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, optional ...StatefulStandardNormalV2Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatefulStandardNormalV2", - Input: []tf.Input{ - resource, algorithm, shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StatefulUniformFullIntAttr is an optional argument to StatefulUniformFullInt. -type StatefulUniformFullIntAttr func(optionalAttr) - -// StatefulUniformFullIntDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_UINT64 -func StatefulUniformFullIntDtype(value tf.DataType) StatefulUniformFullIntAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs random integers from a uniform distribution. -// -// The generated values are uniform integers covering the whole range of `dtype`. -// -// Arguments: -// resource: The handle of the resource variable that stores the state of the RNG. -// algorithm: The RNG algorithm. -// shape: The shape of the output tensor. -// -// Returns Random values with specified shape. -func StatefulUniformFullInt(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, optional ...StatefulUniformFullIntAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatefulUniformFullInt", - Input: []tf.Input{ - resource, algorithm, shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Locks a mutex resource. The output is the lock. 
So long as the lock tensor -// -// is alive, any other request to use `MutexLock` with this mutex will wait. -// -// This is particularly useful for creating a critical section when used in -// conjunction with `MutexLockIdentity`: -// -// ```python -// -// mutex = mutex_v2( -// shared_name=handle_name, container=container, name=name) -// -// def execute_in_critical_section(fn, *args, **kwargs): -// lock = gen_resource_variable_ops.mutex_lock(mutex) -// -// with ops.control_dependencies([lock]): -// r = fn(*args, **kwargs) -// -// with ops.control_dependencies(nest.flatten(r)): -// with ops.colocate_with(mutex): -// ensure_lock_exists = mutex_lock_identity(lock) -// -// # Make sure that if any element of r is accessed, all of -// # them are executed together. -// r = nest.map_structure(tf.identity, r) -// -// with ops.control_dependencies([ensure_lock_exists]): -// return nest.map_structure(tf.identity, r) -// ``` -// -// While `fn` is running in the critical section, no other functions which wish to -// use this critical section may run. -// -// Often the use case is that two executions of the same graph, in parallel, -// wish to run `fn`; and we wish to ensure that only one of them executes -// at a time. This is especially important if `fn` modifies one or more -// variables at a time. -// -// It is also useful if two separate functions must share a resource, but we -// wish to ensure the usage is exclusive. -// -// Arguments: -// mutex: The mutex resource to lock. -// -// Returns A tensor that keeps a shared pointer to a lock on the mutex; -// when the Tensor is destroyed, the use count on the shared pointer is decreased -// by 1. When it reaches 0, the lock is released. -func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MutexLock", - Input: []tf.Input{ - mutex, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Transforms a serialized tensorflow.TensorProto proto into a Tensor. -// -// Arguments: -// serialized: A scalar string containing a serialized TensorProto proto. -// out_type: The type of the serialized tensor. The provided type must match the -// type of the serialized tensor and no implicit conversion will take place. -// -// Returns A Tensor of type `out_type`. -func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"out_type": out_type} - opspec := tf.OpSpec{ - Type: "ParseTensor", - Input: []tf.Input{ - serialized, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. -type MaxPoolWithArgmaxAttr func(optionalAttr) - -// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. -// If not specified, defaults to DT_INT64 -func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { - return func(m optionalAttr) { - m["Targmax"] = value - } -} - -// MaxPoolWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. -// -// value: Whether to include batch dimension in flattened index of `argmax`. -// If not specified, defaults to false -func MaxPoolWithArgmaxIncludeBatchInIndex(value bool) MaxPoolWithArgmaxAttr { - return func(m optionalAttr) { - m["include_batch_in_index"] = value - } -} - -// Performs max pooling on the input and outputs both max values and indices. 
-// -// The indices in `argmax` are flattened, so that a maximum value at position -// `[b, y, x, c]` becomes flattened index: -// `(y * width + x) * channels + c` if `include_batch_in_index` is False; -// `((b * height + y) * width + x) * channels + c` if `include_batch_in_index` is True. -// -// The indices returned are always in `[0, height) x [0, width)` before flattening, -// even if padding is involved and the mathematically correct answer is outside -// (either negative or too large). This is a bug, but fixing it is difficult to do -// in a safe backwards compatible way, especially due to flattening. -// -// Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. -func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolWithArgmax", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Identity transformation that models performance. -// -// Identity transformation that models performance. -// -// Arguments: -// input_dataset: A variant tensor representing the input dataset. -// -// -func ModelDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ModelDataset", - Input: []tf.Input{ - input_dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Fast Fourier transform. -// -// Computes the 1-dimensional discrete Fourier transform over the inner-most -// dimension of `input`. -// -// Arguments: -// input: A complex tensor. -// -// Returns A complex tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its 1D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fft -// @end_compatibility -func FFT(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolAttr is an optional argument to MaxPool. -type MaxPoolAttr func(optionalAttr) - -// MaxPoolDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolDataFormat(value string) MaxPoolAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs max pooling on the input. 
-// -// Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor. -func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Multiplies sparse updates into the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] *= updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] *= updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions multiply. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// -//
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterMul(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterMul", - Input: []tf.Input{ - resource, indices, updates, - }, - } - return scope.AddOperation(opspec) -} - -// Subtracts sparse updates from the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] -= updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] -= updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions add. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// -//
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterSub", - Input: []tf.Input{ - resource, indices, updates, - }, - } - return scope.AddOperation(opspec) -} - -// Adds sparse updates to the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] += updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] += updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions add. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// -//
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterAdd(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterAdd", - Input: []tf.Input{ - resource, indices, updates, - }, - } - return scope.AddOperation(opspec) -} - -// Reads the value of a variable. -// -// The tensor returned by this operation is immutable. -// -// The value returned by this operation is guaranteed to be influenced by all the -// writes on which this operation depends directly or indirectly, and to not be -// influenced by any of the writes which depend directly or indirectly on this -// operation. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// dtype: the dtype of the value. -func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - opspec := tf.OpSpec{ - Type: "ReadVariableOp", - Input: []tf.Input{ - resource, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. -type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) - -// ResourceSparseApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyProximalAdagradUseLocking(value bool) ResourceSparseApplyProximalAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. -// -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// prox_v = var -// prox_v -= lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. // grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. // // Returns the created operation. 
-func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalAdagradAttr) (o *tf.Operation) { +func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -11648,1563 +10652,15 @@ func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.O a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalAdagrad", + Type: "ResourceApplyRMSProp", Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, indices, + var_, ms, mom, lr, rho, momentum, epsilon, grad, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// DecodeJpegAttr is an optional argument to DecodeJpeg. -type DecodeJpegAttr func(optionalAttr) - -// DecodeJpegChannels sets the optional channels attribute to value. -// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeJpegChannels(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodeJpegRatio sets the optional ratio attribute to value. -// -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeJpegRatio(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value - } -} - -// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. -// -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value - } -} - -// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. -// -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value - } -} - -// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. -// -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { - return func(m optionalAttr) { - m["acceptable_fraction"] = value - } -} - -// DecodeJpegDctMethod sets the optional dct_method attribute to value. -// -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) -// If not specified, defaults to "" -func DecodeJpegDctMethod(value string) DecodeJpegAttr { - return func(m optionalAttr) { - m["dct_method"] = value - } -} - -// Decode a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. 
-// -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. -// -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. -// -// -// This op also supports decoding PNGs and non-animated GIFs since the interface is -// the same, though it is cleaner to use `tf.image.decode_image`. -// -// Arguments: -// contents: 0-D. The JPEG-encoded image. -// -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeJpeg", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput. -type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr) - -// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of depthwise convolution with respect to the input. -// -// Arguments: -// input_sizes: An integer vector representing the shape of `input`, based -// on `data_format`. For example, if `data_format` is 'NHWC' then -// `input` is a 4-D `[batch, height, width, channels]` tensor. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, depthwise_multiplier]`. -// out_backprop: 4-D with shape based on `data_format`. -// For example, if `data_format` is 'NHWC' then -// out_backprop shape is `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. -// padding: The type of padding algorithm to use. -// -// Returns 4-D with shape according to `data_format`. For example, if -// `data_format` is 'NHWC', output shape is `[batch, in_height, -// in_width, in_channels]`. Gradient w.r.t. the input of the -// convolution. 
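A hedged construction-and-run sketch for `DepthwiseConv2dNativeBackpropInput`, assuming the standard TensorFlow Go bindings. Shapes follow the NHWC convention described above: input `[1, 4, 4, 2]`, filter `[3, 3, 2, 1]` (depthwise multiplier 1), so `out_backprop` is `[1, 4, 4, 2]` with `"SAME"` padding and unit strides; `Fill` is used only to fabricate illustrative data.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	inputSizes := op.Const(s.SubScope("input_sizes"), []int32{1, 4, 4, 2})
	// Filter: [filter_height, filter_width, in_channels, depthwise_multiplier].
	filter := op.Fill(s.SubScope("filter"),
		op.Const(s.SubScope("filter_shape"), []int32{3, 3, 2, 1}),
		op.Const(s.SubScope("filter_val"), float32(0.1)))
	// Gradients w.r.t. the convolution output, shaped like the NHWC output.
	outBackprop := op.Fill(s.SubScope("grad"),
		op.Const(s.SubScope("grad_shape"), []int32{1, 4, 4, 2}),
		op.Const(s.SubScope("grad_val"), float32(1.0)))

	dx := op.DepthwiseConv2dNativeBackpropInput(s, inputSizes, filter, outBackprop,
		[]int64{1, 1, 1, 1}, "SAME")

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{dx}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Shape()) // [1 4 4 2], the gradient w.r.t. the input
}
```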
-func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNativeBackpropInput", - Input: []tf.Input{ - input_sizes, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// EditDistanceAttr is an optional argument to EditDistance. -type EditDistanceAttr func(optionalAttr) - -// EditDistanceNormalize sets the optional normalize attribute to value. -// -// value: boolean (if true, edit distances are normalized by length of truth). -// -// The output is: -// If not specified, defaults to true -func EditDistanceNormalize(value bool) EditDistanceAttr { - return func(m optionalAttr) { - m["normalize"] = value - } -} - -// Computes the (possibly normalized) Levenshtein Edit Distance. -// -// The inputs are variable-length sequences provided by SparseTensors -// (hypothesis_indices, hypothesis_values, hypothesis_shape) -// and -// (truth_indices, truth_values, truth_shape). -// -// The inputs are: -// -// Arguments: -// hypothesis_indices: The indices of the hypothesis list SparseTensor. -// This is an N x R int64 matrix. -// hypothesis_values: The values of the hypothesis list SparseTensor. -// This is an N-length vector. -// hypothesis_shape: The shape of the hypothesis list SparseTensor. -// This is an R-length vector. -// truth_indices: The indices of the truth list SparseTensor. -// This is an M x R int64 matrix. -// truth_values: The values of the truth list SparseTensor. -// This is an M-length vector. -// truth_shape: truth indices, vector. -// -// Returns A dense float tensor with rank R - 1. -// -// For the example input: -// -// // hypothesis represents a 2x1 matrix with variable-length values: -// // (0,0) = ["a"] -// // (1,0) = ["b"] -// hypothesis_indices = [[0, 0, 0], -// [1, 0, 0]] -// hypothesis_values = ["a", "b"] -// hypothesis_shape = [2, 1, 1] -// -// // truth represents a 2x2 matrix with variable-length values: -// // (0,0) = [] -// // (0,1) = ["a"] -// // (1,0) = ["b", "c"] -// // (1,1) = ["a"] -// truth_indices = [[0, 1, 0], -// [1, 0, 0], -// [1, 0, 1], -// [1, 1, 0]] -// truth_values = ["a", "b", "c", "a"] -// truth_shape = [2, 2, 2] -// normalize = true -// -// The output will be: -// -// // output is a 2x2 matrix with edit distances normalized by truth lengths. -// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis -// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis -func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EditDistance", - Input: []tf.Input{ - hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns 0 if x == 0, and x * log(y) otherwise, elementwise. 
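The `Xlogy` comment above is terse, so here is a small hedged sketch that evaluates it on a pair of constants with the standard TensorFlow Go bindings; the expected result is `[0, 2*ln(4)] ≈ [0, 2.7726]`.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s.SubScope("x"), []float32{0, 2})
	y := op.Const(s.SubScope("y"), []float32{3, 4})
	z := op.Xlogy(s, x, y) // 0 where x == 0, otherwise x * log(y)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{z}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [0 2.7725887]
}
```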
-func Xlogy(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Xlogy", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Stops gradient computation. -// -// When executed in a graph, this op outputs its input tensor as-is. -// -// When building ops to compute gradients, this op prevents the contribution of -// its inputs to be taken into account. Normally, the gradient generator adds ops -// to a graph to compute the derivatives of a specified 'loss' by recursively -// finding out inputs that contributed to its computation. If you insert this op -// in the graph it inputs are masked from the gradient generator. They are not -// taken into account for computing gradients. -// -// This is useful any time you want to compute a value with TensorFlow but need -// to pretend that the value was a constant. Some examples include: -// -// * The *EM* algorithm where the *M-step* should not involve backpropagation -// through the output of the *E-step*. -// * Contrastive divergence training of Boltzmann machines where, when -// differentiating the energy function, the training must not backpropagate -// through the graph that generated the samples from the model. -// * Adversarial training, where no backprop should happen through the adversarial -// example generation process. -func StopGradient(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "StopGradient", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Eagerly executes a python function to compute func(input)->output. The -// -// semantics of the input, output, and attributes are the same as those for -// PyFunc. -func EagerPyFunc(scope *Scope, input []tf.Output, token string, Tout []tf.DataType) (output []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"token": token, "Tout": Tout} - opspec := tf.OpSpec{ - Type: "EagerPyFunc", - Input: []tf.Input{ - tf.OutputList(input), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("EagerPyFunc", err) - return - } - return output -} - -// Concats all tensors in the list along the 0th dimension. -// -// Requires that all tensors have the same shape except the first dimension. -// -// input_handle: The input list. -// element_shape: The shape of the uninitialized elements in the list. If the first -// dimension is not -1, it is assumed that all list elements have the same -// leading dim. -// leading_dims: The list of leading dims of uninitialized list elements. Used if -// the leading dim of input_handle.element_shape or the element_shape input arg -// is not already set. -// tensor: The concated result. -// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. 
-// -func TensorListConcatV2(scope *Scope, input_handle tf.Output, element_shape tf.Output, leading_dims tf.Output, element_dtype tf.DataType) (tensor tf.Output, lengths tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "TensorListConcatV2", - Input: []tf.Input{ - input_handle, element_shape, leading_dims, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve. -type MatrixTriangularSolveAttr func(optionalAttr) - -// MatrixTriangularSolveLower sets the optional lower attribute to value. -// -// value: Boolean indicating whether the innermost matrices in `matrix` are -// lower or upper triangular. -// If not specified, defaults to true -func MatrixTriangularSolveLower(value bool) MatrixTriangularSolveAttr { - return func(m optionalAttr) { - m["lower"] = value - } -} - -// MatrixTriangularSolveAdjoint sets the optional adjoint attribute to value. -// -// value: Boolean indicating whether to solve with `matrix` or its (block-wise) -// adjoint. -// -// @compatibility(numpy) -// Equivalent to scipy.linalg.solve_triangular -// @end_compatibility -// If not specified, defaults to false -func MatrixTriangularSolveAdjoint(value bool) MatrixTriangularSolveAttr { - return func(m optionalAttr) { - m["adjoint"] = value - } -} - -// Solves systems of linear equations with upper or lower triangular matrices by -// -// backsubstitution. -// -// `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form -// square matrices. If `lower` is `True` then the strictly upper triangular part -// of each inner-most matrix is assumed to be zero and not accessed. -// If `lower` is False then the strictly lower triangular part of each inner-most -// matrix is assumed to be zero and not accessed. -// `rhs` is a tensor of shape `[..., M, K]`. -// -// The output is a tensor of shape `[..., M, K]`. If `adjoint` is -// `True` then the innermost matrices in `output` satisfy matrix equations -// `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. -// If `adjoint` is `False` then the strictly then the innermost matrices in -// `output` satisfy matrix equations -// `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. -// -// Arguments: -// matrix: Shape is `[..., M, M]`. -// rhs: Shape is `[..., M, K]`. -// -// Returns Shape is `[..., M, K]`. -func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixTriangularSolveAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixTriangularSolve", - Input: []tf.Input{ - matrix, rhs, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Saves tensors in V2 checkpoint format. -// -// By default, saves the named tensors in full. If the caller wishes to save -// specific slices of full tensors, "shape_and_slices" should be non-empty strings -// and correspondingly well-formed. -// -// Arguments: -// prefix: Must have a single element. The prefix of the V2 checkpoint to which we -// write the tensors. -// tensor_names: shape {N}. The names of the tensors to be saved. -// shape_and_slices: shape {N}. The slice specs of the tensors to be saved. -// Empty strings indicate that they are non-partitioned tensors. 
-// tensors: `N` tensors to save. -// -// Returns the created operation. -func SaveV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, tensors []tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SaveV2", - Input: []tf.Input{ - prefix, tensor_names, shape_and_slices, tf.OutputList(tensors), - }, - } - return scope.AddOperation(opspec) -} - -// Concatenates quantized tensors along one dimension. -// -// Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// input_mins: The minimum scalar values for each of the input tensors. -// input_maxes: The maximum scalar values for each of the input tensors. -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "QuantizedConcat", - Input: []tf.Input{ - concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Slice a `SparseTensor` based on the `start` and `size`. -// -// For example, if the input is -// -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] -// -// Graphically the output tensors are: -// -// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] -// [ a ] -// [b c ] -// -// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] -// [ d e ] -// [ ] -// -// Arguments: -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// start: 1-D. tensor represents the start of the slice. -// size: 1-D. tensor represents the size of the slice. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. -// -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSlice", - Input: []tf.Input{ - indices, values, shape, start, size, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Runs multiple additive regression ensemble predictors on input instances and -// -// computes the logits. It is designed to be used during prediction. -// It traverses all the trees and calculates the final score for each instance. -// -// Arguments: -// -// bucketized_features: A list of rank 1 Tensors containing bucket id for each -// feature. 
-// logits_dimension: scalar, dimension of the logits, to be used for partial logits -// shape. -// -// Returns Output rank 2 Tensor containing logits for each example. -func BoostedTreesPredict(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (logits tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"logits_dimension": logits_dimension} - opspec := tf.OpSpec{ - Type: "BoostedTreesPredict", - Input: []tf.Input{ - tree_ensemble_handle, tf.OutputList(bucketized_features), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Pads a tensor with zeros. -// -// This operation pads a `input` with zeros according to the `paddings` you -// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the -// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many zeros to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` -// in that dimension. -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` -// -// For example: -// -// ``` -// # 't' is [[1, 1], [2, 2]] -// # 'paddings' is [[1, 1], [2, 2]] -// # rank of 't' is 2 -// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] -// [0, 0, 1, 1, 0, 0] -// [0, 0, 2, 2, 0, 0] -// [0, 0, 0, 0, 0, 0]] -// ``` -// -func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Pad", - Input: []tf.Input{ - input, paddings, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Checks whether a resource handle-based variable has been initialized. -// -// Arguments: -// resource: the input resource handle. -// -// Returns a scalar boolean which is true if the variable has been -// initialized. -func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "VarIsInitializedOp", - Input: []tf.Input{ - resource, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the min of x and y (i.e. x < y ? x : y) element-wise. -// -// *NOTE*: `Minimum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Minimum", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` -// -// if < 0, `scale * features` otherwise. -// -// To be used together with -// `initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`. -// For correct dropout, use `tf.contrib.nn.alpha_dropout`. -// -// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) -func Selu(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Selu", - Input: []tf.Input{ - features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SetSizeAttr is an optional argument to SetSize. -type SetSizeAttr func(optionalAttr) - -// SetSizeValidateIndices sets the optional validate_indices attribute to value. 
-// If not specified, defaults to true -func SetSizeValidateIndices(value bool) SetSizeAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Number of unique elements along last dimension of input `set`. -// -// Input `set` is a `SparseTensor` represented by `set_indices`, `set_values`, -// and `set_shape`. The last dimension contains values in a set, duplicates are -// allowed but ignored. -// -// If `validate_indices` is `True`, this op validates the order and range of `set` -// indices. -// -// Arguments: -// set_indices: 2D `Tensor`, indices of a `SparseTensor`. -// set_values: 1D `Tensor`, values of a `SparseTensor`. -// set_shape: 1D `Tensor`, shape of a `SparseTensor`. -// -// Returns For `set` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st -// `n-1` dimensions as `set`. Each value is the number of unique elements in -// the corresponding `[0...n-1]` dimension of `set`. -func SetSize(scope *Scope, set_indices tf.Output, set_values tf.Output, set_shape tf.Output, optional ...SetSizeAttr) (size tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SetSize", - Input: []tf.Input{ - set_indices, set_values, set_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adds sparse `updates` to an existing tensor according to `indices`. -// -// This operation creates a new tensor by adding sparse `updates` to the passed -// in `tensor`. -// This operation is very similar to `tf.scatter_nd_add`, except that the updates -// are added onto an existing tensor (as opposed to a variable). If the memory -// for the existing tensor cannot be re-used, a copy is made and updated. -// -// `indices` is an integer tensor containing indices into a new tensor of shape -// `shape`. The last dimension of `indices` can be at most the rank of `shape`: -// -// indices.shape[-1] <= shape.rank -// -// The last dimension of `indices` corresponds to indices into elements -// (if `indices.shape[-1] = shape.rank`) or slices -// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of -// `shape`. `updates` is a tensor with shape -// -// indices.shape[:-1] + shape[indices.shape[-1]:] -// -// The simplest form of tensor_scatter_add is to add individual elements to a -// tensor by index. For example, say we want to add 4 elements in a rank-1 -// tensor with 8 elements. -// -// In Python, this scatter add operation would look like this: -// -// ```python -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// tensor = tf.ones([8], dtype=tf.int32) -// updated = tf.tensor_scatter_add(tensor, indices, updates) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [1, 12, 1, 11, 10, 1, 1, 13] -// -// We can also, insert entire slices of a higher rank tensor all at once. For -// example, if we wanted to insert two slices in the first dimension of a -// rank-3 tensor with two matrices of new values. 
-// -// In Python, this scatter add operation would look like this: -// -// ```python -// indices = tf.constant([[0], [2]]) -// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]], -// [[5, 5, 5, 5], [6, 6, 6, 6], -// [7, 7, 7, 7], [8, 8, 8, 8]]]) -// tensor = tf.ones([4, 4, 4]) -// updated = tf.tensor_scatter_add(tensor, indices, updates) -// with tf.Session() as sess: -// print(sess.run(scatter)) -// ``` -// -// The resulting tensor would look like this: -// -// [[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], -// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], -// [[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]], -// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] -// -// Note that on CPU, if an out of bound index is found, an error is returned. -// On GPU, if an out of bound index is found, the index is ignored. -// -// Arguments: -// tensor: Tensor to copy/update. -// indices: Index tensor. -// updates: Updates to scatter into output. -// -// Returns A new tensor copied from tensor and updates added according to the indices. -func TensorScatterAdd(scope *Scope, tensor tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorScatterAdd", - Input: []tf.Input{ - tensor, indices, updates, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sign and the log of the absolute value of the determinant of -// -// one or more square matrices. -// -// The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions -// form square matrices. The outputs are two tensors containing the signs and -// absolute values of the log determinants for all N input submatrices -// `[..., :, :]` such that the determinant = sign*exp(log_abs_determinant). -// The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU -// is the LU decomposition of the input and P is the corresponding -// permutation matrix. -// -// Arguments: -// input: Shape is `[N, M, M]`. -// -// Returns The signs of the log determinants of the inputs. Shape is `[N]`.The logs of the absolute values of the determinants -// of the N input matrices. Shape is `[N]`. -func LogMatrixDeterminant(scope *Scope, input tf.Output) (sign tf.Output, log_abs_determinant tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogMatrixDeterminant", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. -// -// More formally, let -// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, -// -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ -// -// Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. 
-// k: Number of top elements to look at for computing precision. -// -// Returns Computed precision at `k` as a `bool Tensor`. -func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InTopKV2", - Input: []tf.Input{ - predictions, targets, k, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Check if the input matches the regex pattern. -// -// The input is a string tensor of any shape. The pattern is a scalar -// string tensor which is applied to every element of the input tensor. -// The boolean values (True or False) of the output tensor indicate -// if the input matches the regex pattern provided. -// -// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) -// -// Arguments: -// input: A string tensor of the text to be processed. -// pattern: A scalar string tensor containing the regular expression to match the input. -// -// Returns A bool tensor with the same shape as `input`. -func RegexFullMatch(scope *Scope, input tf.Output, pattern tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RegexFullMatch", - Input: []tf.Input{ - input, pattern, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Converts a `RaggedTensor` into a `SparseTensor` with the same values. -// -// input=ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) -// output=SparseTensor(indices=sparse_indices, values=sparse_values, -// dense_shape=sparse_dense_shape) -// -// Arguments: -// rt_nested_splits: The `row_splits` for the `RaggedTensor`. -// rt_dense_values: The `flat_values` for the `RaggedTensor`. -// -// Returns The indices for the `SparseTensor`.The values of the `SparseTensor`.`sparse_dense_shape` is a tight bounding box of the input `RaggedTensor`. -func RaggedTensorToSparse(scope *Scope, rt_nested_splits []tf.Output, rt_dense_values tf.Output) (sparse_indices tf.Output, sparse_values tf.Output, sparse_dense_shape tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RaggedTensorToSparse", - Input: []tf.Input{ - tf.OutputList(rt_nested_splits), rt_dense_values, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. -type FusedBatchNormGradV2Attr func(optionalAttr) - -// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. -// -// value: The data format for y_backprop, x, x_backprop. -// Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. -// -// value: A bool value to indicate the operation is for training (default) -// or inference. 
-// If not specified, defaults to true -func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Gradient for batch normalization. -// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. -// -// Arguments: -// y_backprop: A 4D Tensor for the gradient with respect to y. -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch -// mean to be reused in gradient computation. When is_training is -// False, a 1D Tensor for the population mean to be reused in both -// 1st and 2nd order gradient computation. -// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch -// variance (inverted variance in the cuDNN case) to be reused in -// gradient computation. When is_training is False, a 1D Tensor -// for the population variance to be reused in both 1st and 2nd -// order gradient computation. -// -// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input -// in FusedBatchNorm. -func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedBatchNormGradV2", - Input: []tf.Input{ - y_backprop, x, scale, reserve_space_1, reserve_space_2, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// Component-wise multiplies a SparseTensor by a dense Tensor. -// -// The output locations corresponding to the implicitly zero elements in the sparse -// tensor will be zero (i.e., will not take up storage space), regardless of the -// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). -// -// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not -// the other direction. -// -// Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. -// -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseDenseCwiseMul", - Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. -type MaxPool3DGradAttr func(optionalAttr) - -// MaxPool3DGradDataFormat sets the optional data_format attribute to value. 
-// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of max pooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool3DGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the name of the device on which `resource` has been placed. -func ExperimentalIteratorGetDevice(scope *Scope, resource tf.Output) (device tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ExperimentalIteratorGetDevice", - Input: []tf.Input{ - resource, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SparseReduceSumAttr is an optional argument to SparseReduceSum. -type SparseReduceSumAttr func(optionalAttr) - -// SparseReduceSumKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the sum of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. -// -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. -// -// Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. -// -// Returns `R-K`-D. 
The reduced Tensor. -func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseReduceSum", - Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Records the latency of producing `input_dataset` elements in a StatsAggregator. -func ExperimentalLatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalLatencyStatsDataset", - Input: []tf.Input{ - input_dataset, tag, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. -// -// This Op does not require `a_indices` be sorted in standard lexicographic order. -// -// Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. -// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. -// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. -// b: `ndims`-D Tensor. With shape `a_shape`. -func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseTensorDenseAdd", - Input: []tf.Input{ - a_indices, a_values, a_shape, b, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QuantizedReluAttr is an optional argument to QuantizedRelu. -type QuantizedReluAttr func(optionalAttr) - -// QuantizedReluOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedReluOutType(value tf.DataType) QuantizedReluAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Computes Quantized Rectified Linear: `max(features, 0)` -// -// Arguments: -// -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. -// -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedRelu", - Input: []tf.Input{ - features, min_features, max_features, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Reorders a SparseTensor into the canonical, row-major ordering. -// -// Note that by convention, all sparse ops preserve the canonical ordering along -// increasing dimension number. 
The only time ordering can be violated is during -// manual manipulation of the indices and values vectors to add entries. -// -// Reordering does not affect the shape of the SparseTensor. -// -// If the tensor has rank `R` and `N` non-empty values, `input_indices` has -// shape `[N, R]`, input_values has length `N`, and input_shape has length `R`. -// -// Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// -// Returns 2-D. `N x R` matrix with the same indices as input_indices, but -// in canonical row-major ordering.1-D. `N` non-empty values corresponding to `output_indices`. -func SparseReorder(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseReorder", - Input: []tf.Input{ - input_indices, input_values, input_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Split a `SparseTensor` into `num_split` tensors along one dimension. -// -// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices -// `[0 : shape[split_dim] % num_split]` gets one extra dimension. -// For example, if `split_dim = 1` and `num_split = 2` and the input is -// -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] -// -// Graphically the output tensors are: -// -// output_tensor[0] = shape = [2, 4] -// [ a ] -// [b c ] -// -// output_tensor[1] = shape = [2, 3] -// [ d e ] -// [ ] -// -// Arguments: -// split_dim: 0-D. The dimension along which to split. Must be in the range -// `[0, rank(shape))`. -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. -// num_split: The number of ways to split. -// -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_split": num_split} - opspec := tf.OpSpec{ - Type: "SparseSplit", - Input: []tf.Input{ - split_dim, indices, values, shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - return output_indices, output_values, output_shape -} - -// Applies sparse addition to `input` using individual values or slices -// -// from `updates` according to indices `indices`. 
The updates are non-aliasing: -// `input` is only modified in-place if no other operations will use it. -// Otherwise, a copy of `input` is made. This operation has a gradient with -// respect to both `input` and `updates`. -// -// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `input`. -// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or `(P-K)`-dimensional slices -// (if `K < P`) along the `K`th dimension of `input`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: -// -// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$ -// -// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 -// elements. In Python, that addition would look like this: -// -// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) -// with tf.Session() as sess: -// print(sess.run(output)) -// -// The resulting value `output` would look like this: -// -// [1, 13, 3, 14, 14, 6, 7, 20] -// -// See `tf.scatter_nd` for more details about how to make updates to slices. -// -// Arguments: -// input: A Tensor. -// indices: A Tensor. Must be one of the following types: `int32`, `int64`. -// A tensor of indices into `input`. -// updates: A Tensor. Must have the same type as ref. A tensor of updated values -// to add to `input`. -// -// Returns A `Tensor` with the same shape as `input`, containing values of `input` -// updated with `updates`. -func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ScatterNdNonAliasingAdd", - Input: []tf.Input{ - input, indices, updates, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a MultiDeviceIterator resource. -// -// Arguments: -// devices: A list of devices the iterator works across. -// shared_name: If non-empty, this resource will be shared under the given name -// across multiple sessions. -// container: If non-empty, this resource is placed in the given container. -// Otherwise, a default container is used. -// output_types: The type list for the return values. -// output_shapes: The list of shapes being produced. -// -// Returns Handle to the resource created. -func MultiDeviceIterator(scope *Scope, devices []string, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"devices": devices, "shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "MultiDeviceIterator", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. -type FractionalMaxPoolAttr func(optionalAttr) - -// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. -// -// value: When set to True, generates the pooling sequence in a -// pseudorandom fashion, otherwise, in a random fashion. 
Check paper [Benjamin -// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for -// difference between pseudorandom and random. -// If not specified, defaults to false -func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["pseudo_random"] = value - } -} - -// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: -// -// `index 0 1 2 3 4` -// -// `value 20 5 16 3 7` -// -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [20, 16] for fractional max pooling. -// If not specified, defaults to false -func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["overlapping"] = value - } -} - -// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. -// -// value: When set to True, a fixed pooling region will be used when -// iterating over a FractionalMaxPool node in the computation graph. Mainly used -// in unit test to make FractionalMaxPool deterministic. -// If not specified, defaults to false -func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["deterministic"] = value - } -} - -// FractionalMaxPoolSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Performs fractional max pooling on the input. -// -// Fractional max pooling is slightly different than regular max pooling. In -// regular max pooling, you downsize an input set by taking the maximum value of -// smaller N x N subsections of the set (often 2x2), and try to reduce the set by -// a factor of N, where N is an integer. Fractional max pooling, as you might -// expect from the word "fractional", means that the overall reduction ratio N -// does not have to be an integer. -// -// The sizes of the pooling regions are generated randomly but are fairly uniform. -// For example, let's look at the height dimension, and the constraints on the -// list of rows that will be pool boundaries. -// -// First we define the following: -// -// 1. input_row_length : the number of rows from the input set -// 2. output_row_length : which will be smaller than the input -// 3. alpha = input_row_length / output_row_length : our reduction ratio -// 4. K = floor(alpha) -// 5. row_pooling_sequence : this is the result list of pool boundary rows -// -// Then, row_pooling_sequence should satisfy: -// -// 1. a[0] = 0 : the first value of the sequence is 0 -// 2. a[end] = input_row_length : the last value of the sequence is the size -// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size -// 4. 
length(row_pooling_sequence) = output_row_length+1 -// -// For more details on fractional max pooling, see this paper: -// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) -// -// Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// pooling_ratio: Pooling ratio for each dimension of `value`, currently only -// supports row and col dimension and should be >= 1.0. For example, a valid -// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements -// must be 1.0 because we don't allow pooling on batch and channels -// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions -// respectively. -// -// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. -func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FractionalMaxPool", - Input: []tf.Input{ - value, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // Generates sparse cross from a list of sparse and dense tensors. // // The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each @@ -13276,273 +10732,18 @@ func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes [ return op.Output(0), op.Output(1), op.Output(2) } -// Inverse real-valued fast Fourier transform. -// -// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most dimension of `input`. -// -// The inner-most dimension of `input` is assumed to be the result of `RFFT`: the -// `fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If -// `fft_length` is not provided, it is computed from the size of the inner-most -// dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to -// compute `input` is odd, it should be provided since it cannot be inferred -// properly. -// -// Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller -// than the corresponding dimension of `input`, the dimension is cropped. If it is -// larger, the dimension is padded with zeros. -// -// Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [1]. The FFT length. -// -// Returns A float32 tensor of the same rank as `input`. The inner-most -// dimension of `input` is replaced with the `fft_length` samples of its inverse -// 1D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.irfft -// @end_compatibility -func IRFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IRFFT", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. +type MaxPoolGradGradAttr func(optionalAttr) -// Concatenates a list of `SparseTensor` along the specified dimension. +// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. // -// Concatenation is with respect to the dense versions of these sparse tensors. 
-// It is assumed that each input is a `SparseTensor` whose elements are ordered -// along increasing dimension number. -// -// All inputs' shapes must match, except for the concat dimension. The -// `indices`, `values`, and `shapes` lists must have the same length. -// -// The output shape is identical to the inputs', except along the concat -// dimension, where it is the sum of the inputs' sizes along that dimension. -// -// The output elements will be resorted to preserve the sort order along -// increasing dimension number. -// -// This op runs in `O(M log M)` time, where `M` is the total number of non-empty -// values across all inputs. This is due to the need for an internal sort in -// order to concatenate efficiently across an arbitrary dimension. -// -// For example, if `concat_dim = 1` and the inputs are -// -// sp_inputs[0]: shape = [2, 3] -// [0, 2]: "a" -// [1, 0]: "b" -// [1, 1]: "c" -// -// sp_inputs[1]: shape = [2, 4] -// [0, 1]: "d" -// [0, 2]: "e" -// -// then the output will be -// -// shape = [2, 7] -// [0, 2]: "a" -// [0, 4]: "d" -// [0, 5]: "e" -// [1, 0]: "b" -// [1, 1]: "c" -// -// Graphically this is equivalent to doing -// -// [ a] concat [ d e ] = [ a d e ] -// [b c ] [ ] [b c ] -// -// Arguments: -// indices: 2-D. Indices of each input `SparseTensor`. -// values: 1-D. Non-empty values of each `SparseTensor`. -// shapes: 1-D. Shapes of each `SparseTensor`. -// concat_dim: Dimension to concatenate along. Must be in range [-rank, rank), -// where rank is the number of dimensions in each input `SparseTensor`. -// -// Returns 2-D. Indices of the concatenated `SparseTensor`.1-D. Non-empty values of the concatenated `SparseTensor`.1-D. Shape of the concatenated `SparseTensor`. -func SparseConcat(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, concat_dim int64) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"concat_dim": concat_dim} - opspec := tf.OpSpec{ - Type: "SparseConcat", - Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Elementwise computes the bitwise AND of `x` and `y`. -// -// The result will have those bits set, that are set in both `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BitwiseAnd", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deserialize and concatenate `SparseTensors` from a serialized minibatch. -// -// The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where -// `N` is the minibatch size and the rows correspond to packed outputs of -// `SerializeSparse`. The ranks of the original `SparseTensor` objects -// must all match. When the final `SparseTensor` is created, it has rank one -// higher than the ranks of the incoming `SparseTensor` objects -// (they have been concatenated along a new row dimension). -// -// The output `SparseTensor` object's shape values for all dimensions but the -// first are the max across the input `SparseTensor` objects' shape values -// for the corresponding dimensions. Its first shape value is `N`, the minibatch -// size. 
-// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. -// -// For example, if the serialized input is a `[2 x 3]` matrix representing two -// original `SparseTensor` objects: -// -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// -// and -// -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] -// -// then the final deserialized `SparseTensor` will be: -// -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] -// -// Arguments: -// serialized_sparse: 2-D, The `N` serialized `SparseTensor` objects. -// Must have 3 columns. -// dtype: The `dtype` of the serialized `SparseTensor` objects. -func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - opspec := tf.OpSpec{ - Type: "DeserializeManySparse", - Input: []tf.Input{ - serialized_sparse, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Deserialize `SparseTensor` objects. -// -// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where -// the last dimension stores serialized `SparseTensor` objects and the other N -// dimensions (N >= 0) correspond to a batch. The ranks of the original -// `SparseTensor` objects must all match. When the final `SparseTensor` is -// created, its rank is the rank of the incoming `SparseTensor` objects plus N; -// the sparse tensors have been concatenated along new dimensions, one for each -// batch. -// -// The output `SparseTensor` object's shape values for the original dimensions -// are the max across the input `SparseTensor` objects' shape values for the -// corresponding dimensions. The new dimensions match the size of the batch. -// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. -// -// For example, if the serialized input is a `[2 x 3]` matrix representing two -// original `SparseTensor` objects: -// -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// -// and -// -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] -// -// then the final deserialized `SparseTensor` will be: -// -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] -// -// Arguments: -// serialized_sparse: The serialized `SparseTensor` objects. The last dimension -// must have 3 columns. -// dtype: The `dtype` of the serialized `SparseTensor` objects. -func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - opspec := tf.OpSpec{ - Type: "DeserializeSparse", - Input: []tf.Input{ - serialized_sparse, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. 
-type MaxPool3DGradGradAttr func(optionalAttr) - -// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { return func(m optionalAttr) { m["data_format"] = value } @@ -13553,15 +10754,14 @@ func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { // Arguments: // orig_input: The original input tensor. // orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. // padding: The type of padding algorithm to use. // // Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { +func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -13570,7 +10770,7 @@ func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPool3DGradGrad", + Type: "MaxPoolGradGrad", Input: []tf.Input{ orig_input, orig_output, grad, }, @@ -13580,51 +10780,385 @@ func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output return op.Output(0) } -// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. -type Conv3DBackpropFilterV2Attr func(optionalAttr) +// StatefulStandardNormalV2Attr is an optional argument to StatefulStandardNormalV2. +type StatefulStandardNormalV2Attr func(optionalAttr) -// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. +// StatefulStandardNormalV2Dtype sets the optional dtype attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. 
-// If not specified, defaults to "NDHWC" -func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatefulStandardNormalV2Dtype(value tf.DataType) StatefulStandardNormalV2Attr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random values from a normal distribution. +// +// The generated values will have mean 0 and standard deviation 1. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// shape: The shape of the output tensor. +// +// Returns A tensor of the specified shape filled with random normal values. +func StatefulStandardNormalV2(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, optional ...StatefulStandardNormalV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatefulStandardNormalV2", + Input: []tf.Input{ + resource, algorithm, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Generate the bucket boundaries for each feature based on accumulated summaries. +// +// An op that returns a list of float tensors for a quantile stream resource. Each +// tensor is Rank 1 containing bucket boundaries for a single feature. +// +// Arguments: +// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. +// num_features: inferred int; number of features to get bucket boundaries for. +// +// Returns float; List of Rank 1 Tensors each containing the bucket boundaries for a feature. +func BoostedTreesQuantileStreamResourceGetBucketBoundaries(scope *Scope, quantile_stream_resource_handle tf.Output, num_features int64) (bucket_boundaries []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_features": num_features} + opspec := tf.OpSpec{ + Type: "BoostedTreesQuantileStreamResourceGetBucketBoundaries", + Input: []tf.Input{ + quantile_stream_resource_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if bucket_boundaries, idx, err = makeOutputList(op, idx, "bucket_boundaries"); err != nil { + scope.UpdateErr("BoostedTreesQuantileStreamResourceGetBucketBoundaries", err) + return + } + return bucket_boundaries +} + +// Encodes a `RaggedTensor` into a `variant` Tensor. +// +// +// Encodes the given `RaggedTensor` and returns a `variant` Tensor. If +// `batched_input` is True, then input `RaggedTensor` is unbatched along the +// zero-th dimension, each component `RaggedTensor` is encoded into a scalar +// `variant` Tensor, and these are stacked to return a 1-D `variant` Tensor. +// If `batched_input` is False, then the input `RaggedTensor` is encoded as is and +// a scalar `variant` Tensor is returned. A `RaggedTensor` is encoded by first +// creating a 1-D `variant` Tensor with `ragged_rank + 1` elements, containing the +// splits and values Tensors of the `RaggedTensor`. Then the 1-D `variant` Tensor +// is wrapped in a scalar `variant` Tensor. See `RaggedTensorFromVariant` for the +// corresponding decoding logic. +// +// +// Arguments: +// rt_nested_splits: A list of one or more Tensors representing the splits of the input +// `RaggedTensor`. +// rt_dense_values: A Tensor representing the values of the input `RaggedTensor`. 
+// batched_input: A `bool` denoting whether the input is a batched `RaggedTensor`. +// +// Returns A `variant` Tensor that containing encoded `RaggedTensor`. +func RaggedTensorToVariant(scope *Scope, rt_nested_splits []tf.Output, rt_dense_values tf.Output, batched_input bool) (encoded_ragged tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"batched_input": batched_input} + opspec := tf.OpSpec{ + Type: "RaggedTensorToVariant", + Input: []tf.Input{ + tf.OutputList(rt_nested_splits), rt_dense_values, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the next representable value of `x1` in the direction of `x2`, element-wise. +// +// This operation returns the same result as the C++ std::nextafter function. +// +// It can also return a subnormal number. +// +// @compatibility(cpp) +// Equivalent to C++ std::nextafter function. +// @end_compatibility +func NextAfter(scope *Scope, x1 tf.Output, x2 tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NextAfter", + Input: []tf.Input{ + x1, x2, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Convert JSON-encoded Example records to binary protocol buffer strings. +// +// This op translates a tensor containing Example records, encoded using +// the [standard JSON +// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), +// into a tensor containing the same records encoded as binary protocol +// buffers. The resulting tensor can then be fed to any of the other +// Example-parsing ops. +// +// Arguments: +// json_examples: Each string is a JSON object serialized according to the JSON +// mapping of the Example proto. +// +// Returns Each string is a binary Example protocol buffer corresponding +// to the respective element of `json_examples`. +func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DecodeJSONExample", + Input: []tf.Input{ + json_examples, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AngleAttr is an optional argument to Angle. +type AngleAttr func(optionalAttr) + +// AngleTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func AngleTout(value tf.DataType) AngleAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Returns the argument of a complex number. +// +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the argument of each element in `input`. All elements in +// `input` must be complex numbers of the form \\(a + bj\\), where *a* +// is the real part and *b* is the imaginary part. +// +// The argument returned by this operation is of the form \\(atan2(b, a)\\). +// +// For example: +// +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.angle(input) ==> [2.0132, 1.056] +// ``` +// +// @compatibility(numpy) +// Equivalent to np.angle. +// @end_compatibility +func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Angle", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AvgPoolAttr is an optional argument to AvgPool. 
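+//
+// For illustration only, a minimal client-side sketch of average pooling with
+// the default "NHWC" format via the AvgPool wrapper defined below; the variable
+// names and values are hypothetical, and the usual `op`/`tf` package imports
+// from the TensorFlow Go API are assumed:
+//
+// ```
+// s := op.NewScope()
+// // One 4x4 single-channel image of ones, shape [1, 4, 4, 1].
+// img := op.Const(s, [][][][]float32{{
+// 	{{1}, {1}, {1}, {1}},
+// 	{{1}, {1}, {1}, {1}},
+// 	{{1}, {1}, {1}, {1}},
+// 	{{1}, {1}, {1}, {1}},
+// }})
+// // 2x2 window, stride 2, VALID padding -> a [1, 2, 2, 1] output of ones.
+// pooled := op.AvgPool(s, img, []int64{1, 2, 2, 1}, []int64{1, 2, 2, 1}, "VALID")
+// _ = pooled
+// ```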
+type AvgPoolAttr func(optionalAttr) + +// AvgPoolDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func AvgPoolDataFormat(value string) AvgPoolAttr { return func(m optionalAttr) { m["data_format"] = value } } -// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. +// Performs average pooling on the input. // -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. +// Each entry in `output` is the mean of the corresponding size `ksize` +// window in `value`. +// +// Arguments: +// value: 4-D with shape `[batch, height, width, channels]`. +// ksize: The size of the sliding window for each dimension of `value`. +// strides: The stride of the sliding window for each dimension of `value`. +// padding: The type of padding algorithm to use. +// +// Returns The average pooled output tensor. +func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AvgPool", + Input: []tf.Input{ + value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that skips `count` elements from the `input_dataset`. +// +// Arguments: +// +// count: A scalar representing the number of elements from the `input_dataset` +// that should be skipped. If count is -1, skips everything. +// +// +func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "SkipDataset", + Input: []tf.Input{ + input_dataset, count, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Gather ragged slices from `params` axis `0` according to `indices`. +// +// Outputs a `RaggedTensor` output composed from `output_dense_values` and +// `output_nested_splits`, such that: +// +// ```python +// output.shape = indices.shape + params.shape[1:] +// output.ragged_rank = indices.shape.ndims + params.ragged_rank +// output[i...j, d0...dn] = params[indices[i...j], d0...dn] +// ``` +// +// where +// +// * `params = +// ragged.from_nested_row_splits(params_dense_values, params_nested_splits)` +// provides the values that should be gathered. +// * `indices` ia a dense tensor with dtype `int32` or `int64`, indicating which +// values should be gathered. +// * `output = +// ragged.from_nested_row_splits(output_dense_values, output_nested_splits)` +// is the output tensor. 
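+//
+// For illustration only, a minimal client-side sketch of gathering rows from a
+// params RaggedTensor with a single `row_splits` tensor (so OUTPUT_RAGGED_RANK
+// is 1); the variable names and values are hypothetical, and the usual `op`/`tf`
+// package imports from the TensorFlow Go API are assumed:
+//
+// ```
+// s := op.NewScope()
+// // params = [[1.0, 2.0], [], [3.0]], encoded as (row_splits, flat_values).
+// paramsSplits := op.Const(s, []int64{0, 2, 2, 3})
+// paramsValues := op.Const(s, []float32{1, 2, 3})
+// indices := op.Const(s, []int32{2, 0})
+// // Gathers rows 2 and 0, producing [[3.0], [1.0, 2.0]].
+// outSplits, outValues := op.RaggedGather(s, []tf.Output{paramsSplits},
+// 	paramsValues, indices, 1)
+// _, _ = outSplits, outValues
+// ```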
+// +// (Note: This c++ op is used to implement the higher-level python +// `tf.ragged.gather` op, which also supports ragged indices.) +// +// +// Arguments: +// params_nested_splits: The `nested_row_splits` tensors that define the row-partitioning for the +// `params` RaggedTensor input. +// params_dense_values: The `flat_values` for the `params` RaggedTensor. There was a terminology change +// at the python level from dense_values to flat_values, so dense_values is the +// deprecated name. +// indices: Indices in the outermost dimension of `params` of the values that should be +// gathered. +// OUTPUT_RAGGED_RANK: The ragged rank of the output RaggedTensor. `output_nested_splits` will contain +// this number of `row_splits` tensors. This value should equal +// `indices.shape.ndims + params.ragged_rank - 1`. +// +// Returns The `nested_row_splits` tensors that define the row-partitioning for the +// returned RaggedTensor.The `flat_values` for the returned RaggedTensor. +func RaggedGather(scope *Scope, params_nested_splits []tf.Output, params_dense_values tf.Output, indices tf.Output, OUTPUT_RAGGED_RANK int64) (output_nested_splits []tf.Output, output_dense_values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"OUTPUT_RAGGED_RANK": OUTPUT_RAGGED_RANK} + opspec := tf.OpSpec{ + Type: "RaggedGather", + Input: []tf.Input{ + tf.OutputList(params_nested_splits), params_dense_values, indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output_nested_splits, idx, err = makeOutputList(op, idx, "output_nested_splits"); err != nil { + scope.UpdateErr("RaggedGather", err) + return + } + output_dense_values = op.Output(idx) + return output_nested_splits, output_dense_values +} + +// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. +type Conv3DBackpropInputAttr func(optionalAttr) + +// Conv3DBackpropInputDilations sets the optional dilations attribute to value. // If not specified, defaults to -func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { +func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { m["dilations"] = value } } -// Computes the gradients of 3-D convolution with respect to the filter. +// Computes the gradients of 3-D convolution with respect to the input. +// +// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 // // Arguments: // input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 5-D -// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` -// tensor. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. // out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, // out_channels]`. // strides: 1-D tensor of length 5. The stride of the sliding window for each // dimension of `input`. Must have `strides[0] = strides[4] = 1`. // padding: The type of padding algorithm to use. 
-func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { +func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -13633,1953 +11167,9 @@ func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Outpu a(attrs) } opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilterV2", + Type: "Conv3DBackpropInput", Input: []tf.Input{ - input, filter_sizes, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Execute a sub graph on a remote processor. -// -// The graph specifications(such as graph itself, input tensors and output names) -// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo -// as serialized_remote_fused_graph_execute_info. -// The specifications will be passed to a dedicated registered -// remote fused graph executor. The executor will send the graph specifications -// to a remote processor and execute that graph. The execution results -// will be passed to consumer nodes as outputs of this node. -// -// Arguments: -// inputs: Arbitrary number of tensors with arbitrary data types -// -// serialized_remote_fused_graph_execute_info: Serialized protocol buffer -// of RemoteFusedGraphExecuteInfo which contains graph specifications. -// -// Returns Arbitrary number of tensors with arbitrary data types -func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} - opspec := tf.OpSpec{ - Type: "RemoteFusedGraphExecute", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("RemoteFusedGraphExecute", err) - return - } - return outputs -} - -// SerializeManySparseAttr is an optional argument to SerializeManySparse. -type SerializeManySparseAttr func(optionalAttr) - -// SerializeManySparseOutType sets the optional out_type attribute to value. -// -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeManySparseOutType(value tf.DataType) SerializeManySparseAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object. -// -// The `SparseTensor` must have rank `R` greater than 1, and the first dimension -// is treated as the minibatch dimension. Elements of the `SparseTensor` -// must be sorted in increasing order of this first dimension. The serialized -// `SparseTensor` objects going into each row of `serialized_sparse` will have -// rank `R-1`. -// -// The minibatch size `N` is extracted from `sparse_shape[0]`. -// -// Arguments: -// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. -// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. -// sparse_shape: 1-D. 
The `shape` of the minibatch `SparseTensor`. -func SerializeManySparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeManySparseAttr) (serialized_sparse tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SerializeManySparse", - Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes inverse hyperbolic cosine of x element-wise. -func Acosh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Acosh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes rectified linear 6 gradients for a Relu6 operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Relu6 operation. -// features: The features passed as input to the corresponding Relu6 operation, or -// its output; using either one produces the same result. -// -// Returns The gradients: -// `gradients * (features > 0) * (features < 6)`. -func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Relu6Grad", - Input: []tf.Input{ - gradients, features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes natural logarithm of (1 + x) element-wise. -// -// I.e., \\(y = \log_e (1 + x)\\). -func Log1p(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Log1p", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBicubicAttr is an optional argument to ResizeBicubic. -type ResizeBicubicAttr func(optionalAttr) - -// ResizeBicubicAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeBicubicAlignCorners(value bool) ResizeBicubicAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// ResizeBicubicHalfPixelCenters sets the optional half_pixel_centers attribute to value. -// If not specified, defaults to false -func ResizeBicubicHalfPixelCenters(value bool) ResizeBicubicAttr { - return func(m optionalAttr) { - m["half_pixel_centers"] = value - } -} - -// Resize `images` to `size` using bicubic interpolation. -// -// Input images can be of different types but output images are always float. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. 
-func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBicubicAttr) (resized_images tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBicubic", - Input: []tf.Input{ - images, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SparseTensorDenseMatMulAttr is an optional argument to SparseTensorDenseMatMul. -type SparseTensorDenseMatMulAttr func(optionalAttr) - -// SparseTensorDenseMatMulAdjointA sets the optional adjoint_a attribute to value. -// -// value: Use the adjoint of A in the matrix multiply. If A is complex, this -// is transpose(conj(A)). Otherwise it's transpose(A). -// If not specified, defaults to false -func SparseTensorDenseMatMulAdjointA(value bool) SparseTensorDenseMatMulAttr { - return func(m optionalAttr) { - m["adjoint_a"] = value - } -} - -// SparseTensorDenseMatMulAdjointB sets the optional adjoint_b attribute to value. -// -// value: Use the adjoint of B in the matrix multiply. If B is complex, this -// is transpose(conj(B)). Otherwise it's transpose(B). -// If not specified, defaults to false -func SparseTensorDenseMatMulAdjointB(value bool) SparseTensorDenseMatMulAttr { - return func(m optionalAttr) { - m["adjoint_b"] = value - } -} - -// Multiply SparseTensor (of rank 2) "A" by dense matrix "B". -// -// No validity checking is performed on the indices of A. However, the following -// input format is recommended for optimal behavior: -// -// if adjoint_a == false: -// A should be sorted in lexicographically increasing order. Use SparseReorder -// if you're not sure. -// if adjoint_a == true: -// A should be sorted in order of increasing dimension 1 (i.e., "column major" -// order instead of "row major" order). -// -// Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, size `[nnz, 2]` Matrix. -// a_values: 1-D. The `values` of the `SparseTensor`, size `[nnz]` Vector. -// a_shape: 1-D. The `shape` of the `SparseTensor`, size `[2]` Vector. -// b: 2-D. A dense Matrix. -func SparseTensorDenseMatMul(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output, optional ...SparseTensorDenseMatMulAttr) (product tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseTensorDenseMatMul", - Input: []tf.Input{ - a_indices, a_values, a_shape, b, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adds two `SparseTensor` objects to produce another `SparseTensor`. -// -// The input `SparseTensor` objects' indices are assumed ordered in standard -// lexicographic order. If this is not the case, before this step run -// `SparseReorder` to restore index ordering. -// -// By default, if two values sum to zero at some index, the output `SparseTensor` -// would still include that particular location in its index, storing a zero in the -// corresponding value slot. To override this, callers can specify `thresh`, -// indicating that if the sum has a magnitude strictly smaller than `thresh`, its -// corresponding value and index would then not be included. In particular, -// `thresh == 0` (default) means everything is kept and actual thresholding happens -// only for a positive value. -// -// In the following shapes, `nnz` is the count after taking `thresh` into account. 
-// -// Arguments: -// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. -// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. -// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. -// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. -// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. -// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. -// thresh: 0-D. The magnitude threshold that determines if an output value/index -// pair takes space. -func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseAdd", - Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// EnqueueTPUEmbeddingSparseTensorBatchAttr is an optional argument to EnqueueTPUEmbeddingSparseTensorBatch. -type EnqueueTPUEmbeddingSparseTensorBatchAttr func(optionalAttr) - -// EnqueueTPUEmbeddingSparseTensorBatchDeviceOrdinal sets the optional device_ordinal attribute to value. -// -// value: The TPU device to use. Should be >= 0 and less than the number -// of TPU cores in the task on which the node is placed. -// If not specified, defaults to -1 -func EnqueueTPUEmbeddingSparseTensorBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingSparseTensorBatchAttr { - return func(m optionalAttr) { - m["device_ordinal"] = value - } -} - -// EnqueueTPUEmbeddingSparseTensorBatchCombiners sets the optional combiners attribute to value. -// -// value: A list of string scalars, one for each embedding table that specify -// how to normalize the embedding activations after weighted summation. -// Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have -// the sum of the weights be 0 for 'mean' or the sum of the squared weights be -// 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for -// all tables. -// If not specified, defaults to <> -func EnqueueTPUEmbeddingSparseTensorBatchCombiners(value []string) EnqueueTPUEmbeddingSparseTensorBatchAttr { - return func(m optionalAttr) { - m["combiners"] = value - } -} - -// Eases the porting of code that uses tf.nn.embedding_lookup_sparse(). -// -// sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond -// to the ith feature. table_ids[i] indicates which embedding table to look up ith -// feature. -// -// The tensors at corresponding positions in the three input lists (sample_indices, -// embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1 -// with dim_size() equal to the total number of lookups into the table described by -// the corresponding feature. -// -// Arguments: -// sample_indices: A list of rank 1 Tensors specifying the training example to -// which the corresponding embedding_indices and aggregation_weights values -// belong. It corresponds to sp_ids.indices[:,0] in embedding_lookup_sparse(). -// embedding_indices: A list of rank 1 Tensors, indices into the embedding tables. -// It corresponds to sp_ids.values in embedding_lookup_sparse(). 
-// aggregation_weights: A list of rank 1 Tensors containing per training example -// aggregation weights. It corresponds to sp_weights.values in -// embedding_lookup_sparse(). -// mode_override: A string input that overrides the mode specified in the -// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', -// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set -// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. -// table_ids: A list of integers specifying the identifier of the embedding table -// (offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the -// corresponding input. The ith input is looked up using table_ids[i]. The size -// of the table_ids list must be equal to that of sample_indices, -// embedding_indices and aggregation_weights. -// -// Returns the created operation. -func EnqueueTPUEmbeddingSparseTensorBatch(scope *Scope, sample_indices []tf.Output, embedding_indices []tf.Output, aggregation_weights []tf.Output, mode_override tf.Output, table_ids []int64, optional ...EnqueueTPUEmbeddingSparseTensorBatchAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"table_ids": table_ids} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EnqueueTPUEmbeddingSparseTensorBatch", - Input: []tf.Input{ - tf.OutputList(sample_indices), tf.OutputList(embedding_indices), tf.OutputList(aggregation_weights), mode_override, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// The gradient operator for the SparseAdd op. -// -// The SparseAdd op calculates A + B, where A, B, and the sum are all represented -// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. -// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty -// values of A and B. -// -// Arguments: -// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to -// the non-empty values of the sum. -// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. -// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. -// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size -// `[nnz(sum), ndims]`. -// -// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the -// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the -// non-empty values of B. -func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseAddGrad", - Input: []tf.Input{ - backprop_val_grad, a_indices, b_indices, sum_indices, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// This op consumes a lock created by `MutexLock`. -// -// This op exists to consume a tensor created by `MutexLock` (other than -// direct control dependencies). It should be the only that consumes the tensor, -// and will raise an error if it is not. Its only purpose is to keep the -// mutex lock tensor alive until it is consumed by this op. -// -// **NOTE**: This operation must run on the same device as its input. This may -// be enforced via the `colocate_with` mechanism. -// -// Arguments: -// mutex_lock: A tensor returned by `MutexLock`. -// -// Returns the created operation. 
-func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConsumeMutexLock", - Input: []tf.Input{ - mutex_lock, - }, - } - return scope.AddOperation(opspec) -} - -// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd. -type ResourceScatterNdAddAttr func(optionalAttr) - -// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value. -// -// value: An optional bool. Defaults to True. If True, the assignment will -// be protected by a lock; otherwise the behavior is undefined, -// but may exhibit less contention. -// If not specified, defaults to true -func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Applies sparse addition to individual values or slices in a Variable. -// -// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `ref`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th -// dimension of `ref`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: -// -// ``` -// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] -// ``` -// -// For example, say we want to add 4 scattered elements to a rank-1 tensor to -// 8 elements. In Python, that addition would look like this: -// -// ```python -// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// add = tf.scatter_nd_add(ref, indices, updates) -// with tf.Session() as sess: -// print sess.run(add) -// ``` -// -// The resulting update to ref would look like this: -// -// [1, 13, 3, 14, 14, 6, 7, 20] -// -// See `tf.scatter_nd` for more details about how to make updates to -// slices. -// -// Arguments: -// ref: A resource handle. Must be from a VarHandleOp. -// indices: A Tensor. Must be one of the following types: int32, int64. -// A tensor of indices into ref. -// updates: A Tensor. Must have the same type as ref. A tensor of -// values to add to ref. -// -// Returns the created operation. -func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceScatterNdAdd", - Input: []tf.Input{ - ref, indices, updates, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Replaces the contents of the table with the specified keys and values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. -// -// Returns the created operation. 
-func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableImportV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// Extract `patches` from `images` and put them in the "depth" output dimension. -// -// Arguments: -// images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`. -// ksizes: The size of the sliding window for each dimension of `images`. -// strides: 1-D of length 4. How far the centers of two consecutive patches are in -// the images. Must be: `[1, stride_rows, stride_cols, 1]`. -// rates: 1-D of length 4. Must be: `[1, rate_rows, rate_cols, 1]`. This is the -// input stride, specifying how far two consecutive patch samples are in the -// input. Equivalent to extracting patches with -// `patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by -// subsampling them spatially by a factor of `rates`. This is equivalent to -// `rate` in dilated (a.k.a. Atrous) convolutions. -// padding: The type of padding algorithm to use. -// -// We specify the size-related attributes as: -// -// ```python -// ksizes = [1, ksize_rows, ksize_cols, 1] -// strides = [1, strides_rows, strides_cols, 1] -// rates = [1, rates_rows, rates_cols, 1] -// ``` -// -// Returns 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows * -// ksize_cols * depth]` containing image patches with size -// `ksize_rows x ksize_cols x depth` vectorized in the "depth" dimension. Note -// `out_rows` and `out_cols` are the dimensions of the output patches. -func ExtractImagePatches(scope *Scope, images tf.Output, ksizes []int64, strides []int64, rates []int64, padding string) (patches tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "rates": rates, "padding": padding} - opspec := tf.OpSpec{ - Type: "ExtractImagePatches", - Input: []tf.Input{ - images, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the mean along sparse segments of a tensor. -// -// See `tf.sparse.segment_sum` for usage examples. -// -// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first -// dimension, selecting a subset of dimension 0, specified by `indices`. -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentMean", - Input: []tf.Input{ - data, indices, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deserializes a serialized tree ensemble config and replaces current tree -// -// ensemble. -// -// Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. -// stamp_token: Token to use as the new value of the resource stamp. -// tree_ensemble_serialized: Serialized proto of the ensemble. -// -// Returns the created operation. 
-func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BoostedTreesDeserializeEnsemble", - Input: []tf.Input{ - tree_ensemble_handle, stamp_token, tree_ensemble_serialized, - }, - } - return scope.AddOperation(opspec) -} - -// Transforms a tf.Example proto (as a string) into typed tensors. -// -// Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// dense_defaults: A list of Tensors (some may be empty), whose length matches -// the length of `dense_keys`. dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. -// The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// num_sparse: The number of sparse features to be parsed from the example. This -// must match the lengths of `sparse_keys` and `sparse_types`. -// sparse_keys: A list of `num_sparse` strings. -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: The keys expected in the Examples' features associated with dense -// values. -// sparse_types: A list of `num_sparse` types; the data types of data in each -// Feature given in sparse_keys. -// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: The shapes of data in each Feature given in dense_keys. -// The length of this list must match the length of `dense_keys`. The -// number of elements in the Feature corresponding to dense_key[j] must -// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == -// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] -// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, -// ..., DN), the shape of the output Tensor dense_values[j] will be (M, -// D1, .., DN), where M is the number of blocks of elements of length -// D1 * .... * DN, in the input. 
-func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} - opspec := tf.OpSpec{ - Type: "ParseSingleExample", - Input: []tf.Input{ - serialized, tf.OutputList(dense_defaults), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - return sparse_indices, sparse_values, sparse_shapes, dense_values -} - -// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. -type WholeFileReaderV2Attr func(optionalAttr) - -// WholeFileReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A Reader that outputs the entire contents of a file as a value. -// -// To use, enqueue filenames in a Queue. The output of ReaderRead will -// be a filename (key) and the contents of that file (value). -// -// Returns The handle to reference the Reader. -func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "WholeFileReaderV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Says whether the targets are in the top `K` predictions. -// -// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the -// prediction for the target class is among the top `k` predictions among -// all predictions for example `i`. Note that the behavior of `InTopK` differs -// from the `TopK` op in its handling of ties; if multiple classes have the -// same prediction value and straddle the top-`k` boundary, all of those -// classes are considered to be in the top `k`. 
-// -// More formally, let -// -// \\(predictions_i\\) be the predictions for all classes for example `i`, -// \\(targets_i\\) be the target class for example `i`, -// \\(out_i\\) be the output for example `i`, -// -// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ -// -// Arguments: -// predictions: A `batch_size` x `classes` tensor. -// targets: A `batch_size` vector of class ids. -// k: Number of top elements to look at for computing precision. -// -// Returns Computed Precision at `k` as a `bool Tensor`. -func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (precision tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"k": k} - opspec := tf.OpSpec{ - Type: "InTopK", - Input: []tf.Input{ - predictions, targets, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingAdagradParametersGradAccumDebug. -type RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve Adagrad embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the Adagrad optimization algorithm.Parameter accumulators updated by the Adagrad optimization algorithm.Parameter gradient_accumulators updated by the Adagrad optimization algorithm. -func RetrieveTPUEmbeddingAdagradParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingAdagradParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Serializes the tree handle to a proto -// -// Arguments: -// tree_handle: Handle to the tree resource to be serialized. -// -// Returns Serialied proto string of the tree resource. 
-func TensorForestTreeSerialize(scope *Scope, tree_handle tf.Output) (tree_config tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorForestTreeSerialize", - Input: []tf.Input{ - tree_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SparseMatMulAttr is an optional argument to SparseMatMul. -type SparseMatMulAttr func(optionalAttr) - -// SparseMatMulTransposeA sets the optional transpose_a attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeA(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// SparseMatMulTransposeB sets the optional transpose_b attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeB(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["a_is_sparse"] = value - } -} - -// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["b_is_sparse"] = value - } -} - -// Multiply matrix "a" by matrix "b". -// -// The inputs must be two-dimensional matrices and the inner dimension of "a" must -// match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not -// `SparseTensor`s. This op is optimized for the case where at least one of "a" or -// "b" is sparse, in the sense that they have a large proportion of zero values. -// The breakeven for using this versus a dense matrix multiply on one platform was -// 30% zero values in the sparse matrix. -// -// The gradient computation of this operation will only take advantage of sparsity -// in the input gradient when that gradient comes from a Relu. -func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseMatMul", - Input: []tf.Input{ - a, b, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ExperimentalThreadPoolHandleAttr is an optional argument to ExperimentalThreadPoolHandle. -type ExperimentalThreadPoolHandleAttr func(optionalAttr) - -// ExperimentalThreadPoolHandleMaxIntraOpParallelism sets the optional max_intra_op_parallelism attribute to value. -// -// value: The maximum degree of parallelism to use within operations that execute on this -// threadpool. -// If not specified, defaults to 1 -func ExperimentalThreadPoolHandleMaxIntraOpParallelism(value int64) ExperimentalThreadPoolHandleAttr { - return func(m optionalAttr) { - m["max_intra_op_parallelism"] = value - } -} - -// ExperimentalThreadPoolHandleContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func ExperimentalThreadPoolHandleContainer(value string) ExperimentalThreadPoolHandleAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// ExperimentalThreadPoolHandleSharedName sets the optional shared_name attribute to value. 
-// If not specified, defaults to "" -func ExperimentalThreadPoolHandleSharedName(value string) ExperimentalThreadPoolHandleAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a dataset that uses a custom thread pool to compute `input_dataset`. -// -// Arguments: -// num_threads: The number of threads in the thread pool. -// display_name: A human-readable name for the threads that may be visible in some -// visualizations. -// threadpool. -// -// Returns A resource that can be consumed by one or more ExperimentalThreadPoolDataset -// ops. -func ExperimentalThreadPoolHandle(scope *Scope, num_threads int64, display_name string, optional ...ExperimentalThreadPoolHandleAttr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_threads": num_threads, "display_name": display_name} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ExperimentalThreadPoolHandle", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug. -type LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAttr) - -// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load proximal Adagrad embedding parameters with debug support. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the proximal Adagrad optimization algorithm. -// accumulators: Value of accumulators used in the proximal Adagrad optimization algorithm. -// gradient_accumulators: Value of gradient_accumulators used in the proximal Adagrad optimization algorithm. -// -// -// -// Returns the created operation. 
-func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug", - Input: []tf.Input{ - parameters, accumulators, gradient_accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// LoadTPUEmbeddingProximalAdagradParametersAttr is an optional argument to LoadTPUEmbeddingProximalAdagradParameters. -type LoadTPUEmbeddingProximalAdagradParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingProximalAdagradParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingProximalAdagradParametersTableId(value int64) LoadTPUEmbeddingProximalAdagradParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingProximalAdagradParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingProximalAdagradParametersTableName(value string) LoadTPUEmbeddingProximalAdagradParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load proximal Adagrad embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the proximal Adagrad optimization algorithm. -// accumulators: Value of accumulators used in the proximal Adagrad optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingProximalAdagradParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingProximalAdagradParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingProximalAdagradParameters", - Input: []tf.Input{ - parameters, accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Get the current size of the TensorArray. -// -// Arguments: -// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). -// flow_in: A float scalar that enforces proper chaining of operations. -// -// Returns The current size of the TensorArray. -func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArraySizeV3", - Input: []tf.Input{ - handle, flow_in, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes gradients for the scaled exponential linear (Selu) operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Selu operation. -// outputs: The outputs of the corresponding Selu operation. 
-// -// Returns The gradients: `gradients * (outputs + scale * alpha)` -// if outputs < 0, `scale * gradients` otherwise. -func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SeluGrad", - Input: []tf.Input{ - gradients, outputs, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. -type ResourceSparseApplyFtrlV2Attr func(optionalAttr) - -// ResourceSparseApplyFtrlV2UseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2Attr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. -// -// That is for rows we have grad for, we update var, accum and linear as follows: -// grad_with_shrinkage = grad + 2 * l2_shrinkage * var -// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage -// linear += grad_with_shrinkage + -// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 shrinkage regulariation. Must be a scalar. -// -// lr_power: Scaling factor. Must be a scalar. -// -// Returns the created operation. -func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrlV2", - Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, l2_shrinkage, lr_power, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// SumAttr is an optional argument to Sum. -type SumAttr func(optionalAttr) - -// SumKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SumKeepDims(value bool) SumAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the sum of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. 
-func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Sum", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. -type SparseToSparseSetOperationAttr func(optionalAttr) - -// SparseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func SparseToSparseSetOperationValidateIndices(value bool) SparseToSparseSetOperationAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Applies set operation along last dimension of 2 `SparseTensor` inputs. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. -// -// If `validate_indices` is `True`, `SparseToSparseSetOperation` validates the -// order and range of `set1` and `set2` indices. -// -// Input `set1` is a `SparseTensor` represented by `set1_indices`, `set1_values`, -// and `set1_shape`. For `set1` ranked `n`, 1st `n-1` dimensions must be the same -// as `set2`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. -// -// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, -// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same -// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. -// -// If `validate_indices` is `True`, this op validates the order and range of `set1` -// and `set2` indices. -// -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. -// -// Arguments: -// set1_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set1_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set1_shape: 1D `Tensor`, shape of a `SparseTensor`. `set1_shape[0...n-1]` must -// be the same as `set2_shape[0...n-1]`, `set1_shape[n]` is the -// max set size across `0...n-1` dimensions. -// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must -// be the same as `set1_shape[0...n-1]`, `set2_shape[n]` is the -// max set size across `0...n-1` dimensions. -// -// -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. 
-func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_values tf.Output, set1_shape tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...SparseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"set_operation": set_operation} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseToSparseSetOperation", - Input: []tf.Input{ - set1_indices, set1_values, set1_shape, set2_indices, set2_values, set2_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Computes softmax cross entropy cost and gradients to backpropagate. -// -// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept -// a matrix of label probabilities, but rather a single label per row -// of features. This label is considered to have probability 1.0 for the -// given row. -// -// Inputs are the logits, not probabilities. -// -// Arguments: -// features: batch_size x num_classes matrix -// labels: batch_size vector with values in [0, num_classes). -// This is the label for the given minibatch entry. -// -// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). -func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSoftmaxCrossEntropyWithLogits", - Input: []tf.Input{ - features, labels, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// StridedSliceGradAttr is an optional argument to StridedSliceGrad. -type StridedSliceGradAttr func(optionalAttr) - -// StridedSliceGradBeginMask sets the optional begin_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradBeginMask(value int64) StridedSliceGradAttr { - return func(m optionalAttr) { - m["begin_mask"] = value - } -} - -// StridedSliceGradEndMask sets the optional end_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradEndMask(value int64) StridedSliceGradAttr { - return func(m optionalAttr) { - m["end_mask"] = value - } -} - -// StridedSliceGradEllipsisMask sets the optional ellipsis_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradEllipsisMask(value int64) StridedSliceGradAttr { - return func(m optionalAttr) { - m["ellipsis_mask"] = value - } -} - -// StridedSliceGradNewAxisMask sets the optional new_axis_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradNewAxisMask(value int64) StridedSliceGradAttr { - return func(m optionalAttr) { - m["new_axis_mask"] = value - } -} - -// StridedSliceGradShrinkAxisMask sets the optional shrink_axis_mask attribute to value. -// If not specified, defaults to 0 -func StridedSliceGradShrinkAxisMask(value int64) StridedSliceGradAttr { - return func(m optionalAttr) { - m["shrink_axis_mask"] = value - } -} - -// Returns the gradient of `StridedSlice`. -// -// Since `StridedSlice` cuts out pieces of its `input` which is size -// `shape`, its gradient will have the same shape (which is passed here -// as `shape`). The gradient will be zero in any element that the slice -// does not select. 
-// -// Arguments are the same as StridedSliceGrad with the exception that -// `dy` is the input gradient to be propagated and `shape` is the -// shape of `StridedSlice`'s `input`. -func StridedSliceGrad(scope *Scope, shape tf.Output, begin tf.Output, end tf.Output, strides tf.Output, dy tf.Output, optional ...StridedSliceGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StridedSliceGrad", - Input: []tf.Input{ - shape, begin, end, strides, dy, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LoadTPUEmbeddingRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingRMSPropParameters. -type LoadTPUEmbeddingRMSPropParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingRMSPropParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingRMSPropParametersTableId(value int64) LoadTPUEmbeddingRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingRMSPropParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingRMSPropParametersTableName(value string) LoadTPUEmbeddingRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load RMSProp embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the RMSProp optimization algorithm. -// ms: Value of ms used in the RMSProp optimization algorithm. -// mom: Value of mom used in the RMSProp optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingRMSPropParameters(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingRMSPropParameters", - Input: []tf.Input{ - parameters, ms, mom, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes the gradient for the inverse of `x` wrt its input. -// -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReciprocalGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// EuclideanNormAttr is an optional argument to EuclideanNorm. -type EuclideanNormAttr func(optionalAttr) - -// EuclideanNormKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. 
-// If not specified, defaults to false -func EuclideanNormKeepDims(value bool) EuclideanNormAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the euclidean norm of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func EuclideanNorm(scope *Scope, input tf.Output, axis tf.Output, optional ...EuclideanNormAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EuclideanNorm", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the element-wise min of two SparseTensors. -// -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. -// -// Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. -// -// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. -func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSparseMinimum", - Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. -type ResourceSparseApplyAdagradDAAttr func(optionalAttr) - -// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. -// -// Arguments: -// var_: Should be from a Variable(). -// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. -// -// Returns the created operation. 
-func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagradDA", - Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// EncodeJpegAttr is an optional argument to EncodeJpeg. -type EncodeJpegAttr func(optionalAttr) - -// EncodeJpegFormat sets the optional format attribute to value. -// -// value: Per pixel image format. -// If not specified, defaults to "" -func EncodeJpegFormat(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["format"] = value - } -} - -// EncodeJpegQuality sets the optional quality attribute to value. -// -// value: Quality of the compression from 0 to 100 (higher is better and slower). -// If not specified, defaults to 95 -func EncodeJpegQuality(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["quality"] = value - } -} - -// EncodeJpegProgressive sets the optional progressive attribute to value. -// -// value: If True, create a JPEG that loads progressively (coarse to fine). -// If not specified, defaults to false -func EncodeJpegProgressive(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["progressive"] = value - } -} - -// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. -// -// value: If True, spend CPU/RAM to reduce size with no quality change. -// If not specified, defaults to false -func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["optimize_size"] = value - } -} - -// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. -// -// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. -// If not specified, defaults to true -func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["chroma_downsampling"] = value - } -} - -// EncodeJpegDensityUnit sets the optional density_unit attribute to value. -// -// value: Unit used to specify `x_density` and `y_density`: -// pixels per inch (`'in'`) or centimeter (`'cm'`). -// If not specified, defaults to "in" -func EncodeJpegDensityUnit(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["density_unit"] = value - } -} - -// EncodeJpegXDensity sets the optional x_density attribute to value. -// -// value: Horizontal pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegXDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["x_density"] = value - } -} - -// EncodeJpegYDensity sets the optional y_density attribute to value. -// -// value: Vertical pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegYDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["y_density"] = value - } -} - -// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. -// -// value: If not empty, embed this XMP metadata in the image header. 
-// If not specified, defaults to "" -func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["xmp_metadata"] = value - } -} - -// JPEG-encode an image. -// -// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. -// -// The attr `format` can be used to override the color format of the encoded -// output. Values can be: -// -// * `''`: Use a default format based on the number of channels in the image. -// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension -// of `image` must be 1. -// * `rgb`: Output an RGB JPEG image. The `channels` dimension -// of `image` must be 3. -// -// If `format` is not specified or is the empty string, a default format is picked -// in function of the number of channels in `image`: -// -// * 1: Output a grayscale image. -// * 3: Output an RGB image. -// -// Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. JPEG-encoded image. -func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodeJpeg", - Input: []tf.Input{ - image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MultinomialAttr is an optional argument to Multinomial. -type MultinomialAttr func(optionalAttr) - -// MultinomialSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 is set to be non-zero, the internal random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func MultinomialSeed(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// MultinomialSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func MultinomialSeed2(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// MultinomialOutputDtype sets the optional output_dtype attribute to value. -// If not specified, defaults to DT_INT64 -func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { - return func(m optionalAttr) { - m["output_dtype"] = value - } -} - -// Draws samples from a multinomial distribution. -// -// Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. -// -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Multinomial", - Input: []tf.Input{ - logits, num_samples, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingRMSPropParametersAttr is an optional argument to RetrieveTPUEmbeddingRMSPropParameters. -type RetrieveTPUEmbeddingRMSPropParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingRMSPropParametersTableId sets the optional table_id attribute to value. 
-// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingRMSPropParametersTableId(value int64) RetrieveTPUEmbeddingRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingRMSPropParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingRMSPropParametersTableName(value string) RetrieveTPUEmbeddingRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve RMSProp embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the RMSProp optimization algorithm.Parameter ms updated by the RMSProp optimization algorithm.Parameter mom updated by the RMSProp optimization algorithm. -func RetrieveTPUEmbeddingRMSPropParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingRMSPropParametersAttr) (parameters tf.Output, ms tf.Output, mom tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingRMSPropParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// QuantizedRelu6Attr is an optional argument to QuantizedRelu6. -type QuantizedRelu6Attr func(optionalAttr) - -// QuantizedRelu6OutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedRelu6OutType(value tf.DataType) QuantizedRelu6Attr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` -// -// Arguments: -// -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. -// -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedRelu6(scope *Scope, features tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedRelu6Attr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedRelu6", - Input: []tf.Input{ - features, min_features, max_features, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// BatchMatMulAttr is an optional argument to BatchMatMul. -type BatchMatMulAttr func(optionalAttr) - -// BatchMatMulAdjX sets the optional adj_x attribute to value. -// -// value: If `True`, adjoint the slices of `x`. Defaults to `False`. -// If not specified, defaults to false -func BatchMatMulAdjX(value bool) BatchMatMulAttr { - return func(m optionalAttr) { - m["adj_x"] = value - } -} - -// BatchMatMulAdjY sets the optional adj_y attribute to value. -// -// value: If `True`, adjoint the slices of `y`. Defaults to `False`. 
-// If not specified, defaults to false -func BatchMatMulAdjY(value bool) BatchMatMulAttr { - return func(m optionalAttr) { - m["adj_y"] = value - } -} - -// Multiplies slices of two tensors in batches. -// -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be adjointed (to adjoint a matrix -// means to transpose and conjugate it) before multiplication by setting -// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. -// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if adj_x else r_x -// c_o = r_y if adj_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -// -// Arguments: -// x: 2-D or higher with shape `[..., r_x, c_x]`. -// y: 2-D or higher with shape `[..., r_y, c_y]`. -// -// Returns 3-D or higher with shape `[..., r_o, c_o]` -func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "BatchMatMul", - Input: []tf.Input{ - x, y, + input, filter, out_backprop, }, Attrs: attrs, } @@ -15795,175 +11385,2039 @@ func ParseSequenceExample(scope *Scope, serialized tf.Output, debug_name tf.Outp return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values, feature_list_dense_lengths } -// LoadTPUEmbeddingADAMParametersAttr is an optional argument to LoadTPUEmbeddingADAMParameters. -type LoadTPUEmbeddingADAMParametersAttr func(optionalAttr) +// RandomPoissonAttr is an optional argument to RandomPoisson. +type RandomPoissonAttr func(optionalAttr) -// LoadTPUEmbeddingADAMParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingADAMParametersTableId(value int64) LoadTPUEmbeddingADAMParametersAttr { +// RandomPoissonSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed(value int64) RandomPoissonAttr { return func(m optionalAttr) { - m["table_id"] = value + m["seed"] = value } } -// LoadTPUEmbeddingADAMParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingADAMParametersTableName(value string) LoadTPUEmbeddingADAMParametersAttr { +// RandomPoissonSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed2(value int64) RandomPoissonAttr { return func(m optionalAttr) { - m["table_name"] = value + m["seed2"] = value } } -// Load ADAM embedding parameters. +// Use RandomPoissonV2 instead. // -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. 
-// -// Arguments: -// parameters: Value of parameters used in the ADAM optimization algorithm. -// momenta: Value of momenta used in the ADAM optimization algorithm. -// velocities: Value of velocities used in the ADAM optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingADAMParameters(scope *Scope, parameters tf.Output, momenta tf.Output, velocities tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingADAMParametersAttr) (o *tf.Operation) { +// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 +func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingADAMParameters", + Type: "RandomPoisson", Input: []tf.Input{ - parameters, momenta, velocities, + shape, rate, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Inverse 2D real-valued fast Fourier transform. -// -// Computes the inverse 2-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most 2 dimensions of `input`. -// -// The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`: -// The inner-most dimension contains the `fft_length / 2 + 1` unique components of -// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed -// from the size of the inner-most 2 dimensions of `input`. If the FFT length used -// to compute `input` is odd, it should be provided since it cannot be inferred -// properly. -// -// Along each axis `IRFFT2D` is computed on, if `fft_length` (or -// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. -// -// Returns A float32 tensor of the same rank as `input`. The inner-most 2 -// dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.irfft2 -// @end_compatibility -func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { +// Computes the derivative of a Gamma random sample w.r.t. `alpha`. +func RandomGammaGrad(scope *Scope, alpha tf.Output, sample tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IRFFT2D", + Type: "RandomGammaGrad", Input: []tf.Input{ - input, fft_length, + alpha, sample, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// InfeedEnqueueTupleAttr is an optional argument to InfeedEnqueueTuple. -type InfeedEnqueueTupleAttr func(optionalAttr) - -// InfeedEnqueueTupleLayouts sets the optional layouts attribute to value. +// Returns a batched diagonal tensor with a given batched diagonal values. // -// value: A vector holding the requested layout in minor-to-major sequence for -// all the tuple shapes, in the order the shapes appear in the "shapes" input. -// The layout elements for a sub-shape can be set to -1, in which case the -// corresponding layout will be computed by the infeed operation. 
-// If not specified, defaults to <> -func InfeedEnqueueTupleLayouts(value []int64) InfeedEnqueueTupleAttr { - return func(m optionalAttr) { - m["layouts"] = value - } -} - -// InfeedEnqueueTupleDeviceOrdinal sets the optional device_ordinal attribute to value. +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: // -// value: The TPU device to use. This should be -1 when the Op -// is running on a TPU device, and >= 0 when the Op is running on the CPU -// device. -// If not specified, defaults to -1 -func InfeedEnqueueTupleDeviceOrdinal(value int64) InfeedEnqueueTupleAttr { - return func(m optionalAttr) { - m["device_ordinal"] = value - } -} - -// Feeds multiple Tensor values into the computation as an XLA tuple. +// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a +// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: +// +// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. +// +// For example: +// +// ``` +// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] +// +// and diagonal.shape = (2, 4) +// +// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] +// +// which has shape (2, 4, 4) +// ``` // // Arguments: -// inputs: A list of tensors that will be provided using the infeed mechanism. -// shapes: The shapes of each tensor in `inputs`. +// diagonal: Rank `k`, where `k >= 1`. // -// Returns the created operation. -func InfeedEnqueueTuple(scope *Scope, inputs []tf.Output, shapes []tf.Shape, optional ...InfeedEnqueueTupleAttr) (o *tf.Operation) { +// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. +func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shapes": shapes} + opspec := tf.OpSpec{ + Type: "MatrixDiag", + Input: []tf.Input{ + diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomGammaAttr is an optional argument to RandomGamma. +type RandomGammaAttr func(optionalAttr) + +// RandomGammaSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomGammaSeed(value int64) RandomGammaAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomGammaSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomGammaSeed2(value int64) RandomGammaAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from the Gamma distribution(s) described by alpha. +// +// This op uses the algorithm by Marsaglia et al. to acquire samples via +// transformation-rejection from pairs of uniform and normal random variables. +// See http://dl.acm.org/citation.cfm?id=358414 +// +// Arguments: +// shape: 1-D integer tensor. Shape of independent samples to draw from each +// distribution described by the shape parameters given in alpha. +// alpha: A tensor in which each scalar is a "shape" parameter describing the +// associated gamma distribution. +// +// Returns A tensor with shape `shape + shape(alpha)`. 
Each slice +// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for +// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. +func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "InfeedEnqueueTuple", + Type: "RandomGamma", + Input: []tf.Input{ + shape, alpha, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomShuffleAttr is an optional argument to RandomShuffle. +type RandomShuffleAttr func(optionalAttr) + +// RandomShuffleSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomShuffleSeed(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleSeed2(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Randomly shuffles a tensor along its first dimension. +// +// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped +// to one and only one `output[i]`. For example, a mapping that might occur for a +// 3x2 tensor is: +// +// ``` +// [[1, 2], [[5, 6], +// [3, 4], ==> [1, 2], +// [5, 6]] [3, 4]] +// ``` +// +// Arguments: +// value: The tensor to be shuffled. +// +// Returns A tensor of same shape and type as `value`, shuffled along its first +// dimension. +func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomShuffle", + Input: []tf.Input{ + value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. +type ResourceSparseApplyFtrlV2Attr func(optionalAttr) + +// ResourceSparseApplyFtrlV2UseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2Attr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// +// That is for rows we have grad for, we update var, accum and linear as follows: +// grad_with_shrinkage = grad + 2 * l2_shrinkage * var +// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +// linear += grad_with_shrinkage + +// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. 
+// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 shrinkage regulariation. Must be a scalar. +// +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. +func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyFtrlV2", + Input: []tf.Input{ + var_, accum, linear, grad, indices, lr, l1, l2, l2_shrinkage, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// MaxPool3DAttr is an optional argument to MaxPool3D. +type MaxPool3DAttr func(optionalAttr) + +// MaxPool3DDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DDataFormat(value string) MaxPool3DAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs 3D max pooling on the input. +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +// +// Returns The max pooled output tensor. +func MaxPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool3D", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Concatenates tensors along one dimension. +// +// Arguments: +// values: List of `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// axis: 0-D. The dimension along which to concatenate. Must be in the +// range [-rank(values), rank(values)). +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatV2", + Input: []tf.Input{ + tf.OutputList(values), axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// A dataset that splits the elements of its input into multiple elements. 
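Before continuing with the dataset wrappers below, here is a minimal end-to-end sketch of how a generated wrapper such as the `ConcatV2` function above is typically driven from Go. It is illustrative only and not part of this change; it assumes the standard `tensorflow/go` and `tensorflow/go/op` import paths and the `op.NewScope`/`op.Const` helpers from this package.

```go
// Illustrative sketch only; not part of the generated wrappers.go change.
// Builds a tiny graph with the ConcatV2 wrapper shown above and runs it.
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	root := op.NewScope()

	// Two 1x2 constants; SubScope keeps the generated node names unique.
	a := op.Const(root.SubScope("a"), [][]float32{{1, 2}})
	b := op.Const(root.SubScope("b"), [][]float32{{3, 4}})
	axis := op.Const(root.SubScope("axis"), int32(0))

	// ConcatV2 stacks the two slices along axis 0, yielding a 2x2 tensor.
	out := op.ConcatV2(root, []tf.Output{a, b}, axis)

	graph, err := root.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	res, err := sess.Run(nil, []tf.Output{out}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(res[0].Value()) // [[1 2] [3 4]]
}
```

Every wrapper in this file follows the same pattern: build outputs against a Scope, finalize the graph, and fetch the outputs through a Session; only the op-specific inputs and optional attributes differ.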
+func ExperimentalUnbatchDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalUnbatchDataset", + Input: []tf.Input{ + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomUniformIntAttr is an optional argument to RandomUniformInt. +type RandomUniformIntAttr func(optionalAttr) + +// RandomUniformIntSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomUniformIntSeed(value int64) RandomUniformIntAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomUniformIntSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomUniformIntSeed2(value int64) RandomUniformIntAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random integers from a uniform distribution. +// +// The generated values are uniform integers in the range `[minval, maxval)`. +// The lower bound `minval` is included in the range, while the upper bound +// `maxval` is excluded. +// +// The random integers are slightly biased unless `maxval - minval` is an exact +// power of two. The bias is small for values of `maxval - minval` significantly +// smaller than the range of the output (either `2^32` or `2^64`). +// +// Arguments: +// shape: The shape of the output tensor. +// minval: 0-D. Inclusive lower bound on the generated integers. +// maxval: 0-D. Exclusive upper bound on the generated integers. +// +// Returns A tensor of the specified shape filled with uniform random integers. +func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomUniformInt", + Input: []tf.Input{ + shape, minval, maxval, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ArgMinAttr is an optional argument to ArgMin. +type ArgMinAttr func(optionalAttr) + +// ArgMinOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMinOutputType(value tf.DataType) ArgMinAttr { + return func(m optionalAttr) { + m["output_type"] = value + } +} + +// Returns the index with the smallest value across dimensions of a tensor. +// +// Note that in case of ties the identity of the return value is not guaranteed. +// +// Usage: +// ```python +// import tensorflow as tf +// a = [1, 10, 26.9, 2.8, 166.32, 62.3] +// b = tf.math.argmin(input = a) +// c = tf.keras.backend.eval(b) +// # c = 0 +// # here a[0] = 1 which is the smallest element of a across axis 0 +// ``` +// +// Arguments: +// +// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. +// Describes which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. 
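A hedged Go counterpart of the Python snippet in the `ArgMin` documentation above, using the wrapper defined immediately below. It is not part of this change and reuses the scope and session plumbing from the earlier `ConcatV2` sketch.

```go
// Illustrative sketch only; not part of the generated change.
package sketches

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// smallestIndex mirrors the Python example above: for the input
// [1, 10, 26.9, 2.8, 166.32, 62.3] the fetched value is 0.
func smallestIndex(s *op.Scope) tf.Output {
	a := op.Const(s.SubScope("a"), []float32{1, 10, 26.9, 2.8, 166.32, 62.3})
	dim := op.Const(s.SubScope("dim"), int32(0)) // reduce across axis 0
	// The default output_type is DT_INT64; op.ArgMinOutputType can override it.
	return op.ArgMin(s, a, dim)
}
```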
+func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ArgMin", + Input: []tf.Input{ + input, dimension, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ExperimentalParseExampleDatasetAttr is an optional argument to ExperimentalParseExampleDataset. +type ExperimentalParseExampleDatasetAttr func(optionalAttr) + +// ExperimentalParseExampleDatasetSloppy sets the optional sloppy attribute to value. +// If not specified, defaults to false +func ExperimentalParseExampleDatasetSloppy(value bool) ExperimentalParseExampleDatasetAttr { + return func(m optionalAttr) { + m["sloppy"] = value + } +} + +// Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features. +// +// Arguments: +// +// +// dense_defaults: A dict mapping string keys to `Tensor`s. +// The keys of the dict must match the dense_keys of the feature. +// sparse_keys: A list of string keys in the examples features. +// The results for these keys will be returned as `SparseTensor` objects. +// dense_keys: A list of Ndense string Tensors (scalars). +// The keys expected in the Examples features associated with dense values. +// sparse_types: A list of `DTypes` of the same length as `sparse_keys`. +// Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), +// and `tf.string` (`BytesList`) are supported. +// dense_shapes: List of tuples with the same length as `dense_keys`. +// The shape of the data for each dense feature referenced by `dense_keys`. +// Required for any input tensors identified by `dense_keys`. Must be +// either fully defined, or may contain an unknown first dimension. +// An unknown first dimension means the feature is treated as having +// a variable number of blocks, and the output shape along this dimension +// is considered unknown at graph build time. Padding is applied for +// minibatch elements smaller than the maximum number of blocks for the +// given feature along this dimension. +// output_types: The type list for the return values. +// output_shapes: The list of shapes being produced. +func ExperimentalParseExampleDataset(scope *Scope, input_dataset tf.Output, num_parallel_calls tf.Output, dense_defaults []tf.Output, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ExperimentalParseExampleDatasetAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes, "output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ExperimentalParseExampleDataset", + Input: []tf.Input{ + input_dataset, num_parallel_calls, tf.OutputList(dense_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Transforms a tf.Example proto (as a string) into typed tensors. +// +// Arguments: +// serialized: A vector containing a batch of binary serialized Example protos. +// dense_defaults: A list of Tensors (some may be empty), whose length matches +// the length of `dense_keys`. 
dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// num_sparse: The number of sparse features to be parsed from the example. This +// must match the lengths of `sparse_keys` and `sparse_types`. +// sparse_keys: A list of `num_sparse` strings. +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: The keys expected in the Examples' features associated with dense +// values. +// sparse_types: A list of `num_sparse` types; the data types of data in each +// Feature given in sparse_keys. +// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: The shapes of data in each Feature given in dense_keys. +// The length of this list must match the length of `dense_keys`. The +// number of elements in the Feature corresponding to dense_key[j] must +// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == +// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] +// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, +// ..., DN), the shape of the output Tensor dense_values[j] will be (M, +// D1, .., DN), where M is the number of blocks of elements of length +// D1 * .... * DN, in the input. +func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} + opspec := tf.OpSpec{ + Type: "ParseSingleExample", + Input: []tf.Input{ + serialized, tf.OutputList(dense_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + return sparse_indices, sparse_values, sparse_shapes, dense_values +} + +// Computes the maximum along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. 
+//
+// This operator is similar to the unsorted segment sum operator found
+// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+// Instead of computing the sum over segments, it computes the maximum such that:
+//
+// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+// that `segment_ids[j...] == i`.
+//
+// If the maximum is empty for a given segment ID `i`, it outputs the smallest
+// possible value for the specific numeric type,
+// `output[i] = numeric_limits<T>::lowest()`.
+//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
+// (figure: UnsortedSegmentMax illustration; original HTML div/img markup omitted)
+// +// For example: +// +// ``` python +// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) +// tf.unsorted_segment_max(c, tf.constant([0, 1, 0]), num_segments=2) +// # ==> [[ 4, 3, 3, 4], +// # [5, 6, 7, 8]] +// ``` +// +// +// Arguments: +// +// segment_ids: A tensor whose shape is a prefix of `data.shape`. +// +// +// Returns Has same shape as data, except for the first `segment_ids.rank` +// dimensions, which are replaced with a single dimension which has size +// `num_segments`. +func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "UnsortedSegmentMax", + Input: []tf.Input{ + data, segment_ids, num_segments, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Advance the counter of a counter-based RNG. +// +// The state of the RNG after +// `rng_skip(n)` will be the same as that after `stateful_uniform([n])` +// (or any other distribution). The actual increment added to the +// counter is an unspecified implementation detail. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// delta: The amount of advancement. +// +// Returns the created operation. +func RngSkip(scope *Scope, resource tf.Output, algorithm tf.Output, delta tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RngSkip", + Input: []tf.Input{ + resource, algorithm, delta, + }, + } + return scope.AddOperation(opspec) +} + +// Converts each string in the input Tensor to its hash mod by a number of buckets. +// +// The hash function is deterministic on the content of the string within the +// process. The hash function is a keyed hash function, where attribute `key` +// defines the key of the hash function. `key` is an array of 2 elements. +// +// A strong hash is important when inputs may be malicious, e.g. URLs with +// additional components. Adversaries could try to make their inputs hash to the +// same bucket for a denial-of-service attack or to skew the results. A strong +// hash can be used to make it difficult to find inputs with a skewed hash value +// distribution over buckets. This requires that the hash function is +// seeded by a high-entropy (random) "key" unknown to the adversary. +// +// The additional robustness comes at a cost of roughly 4x higher compute +// time than `tf.string_to_hash_bucket_fast`. +// +// Arguments: +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. +// key: The key used to seed the hash function, passed as a list of two uint64 +// elements. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} + opspec := tf.OpSpec{ + Type: "StringToHashBucketStrong", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. +type AvgPool3DGradAttr func(optionalAttr) + +// AvgPool3DGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. 
With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of average pooling function. +// +// Arguments: +// orig_input_shape: The original input dimensions. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +// +// Returns The backprop for input. +func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AvgPool3DGrad", + Input: []tf.Input{ + orig_input_shape, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Execute a sub graph on a remote processor. +// +// The graph specifications(such as graph itself, input tensors and output names) +// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo +// as serialized_remote_fused_graph_execute_info. +// The specifications will be passed to a dedicated registered +// remote fused graph executor. The executor will send the graph specifications +// to a remote processor and execute that graph. The execution results +// will be passed to consumer nodes as outputs of this node. +// +// Arguments: +// inputs: Arbitrary number of tensors with arbitrary data types +// +// serialized_remote_fused_graph_execute_info: Serialized protocol buffer +// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// +// Returns Arbitrary number of tensors with arbitrary data types +func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} + opspec := tf.OpSpec{ + Type: "RemoteFusedGraphExecute", Input: []tf.Input{ tf.OutputList(inputs), }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("RemoteFusedGraphExecute", err) + return + } + return outputs } -// Returns which elements of x are finite. +// Locks a mutex resource. The output is the lock. So long as the lock tensor // -// @compatibility(numpy) -// Equivalent to np.isfinite -// @end_compatibility -func IsFinite(scope *Scope, x tf.Output) (y tf.Output) { +// is alive, any other request to use `MutexLock` with this mutex will wait. 
+// +// This is particularly useful for creating a critical section when used in +// conjunction with `MutexLockIdentity`: +// +// ```python +// +// mutex = mutex_v2( +// shared_name=handle_name, container=container, name=name) +// +// def execute_in_critical_section(fn, *args, **kwargs): +// lock = gen_resource_variable_ops.mutex_lock(mutex) +// +// with ops.control_dependencies([lock]): +// r = fn(*args, **kwargs) +// +// with ops.control_dependencies(nest.flatten(r)): +// with ops.colocate_with(mutex): +// ensure_lock_exists = mutex_lock_identity(lock) +// +// # Make sure that if any element of r is accessed, all of +// # them are executed together. +// r = nest.map_structure(tf.identity, r) +// +// with ops.control_dependencies([ensure_lock_exists]): +// return nest.map_structure(tf.identity, r) +// ``` +// +// While `fn` is running in the critical section, no other functions which wish to +// use this critical section may run. +// +// Often the use case is that two executions of the same graph, in parallel, +// wish to run `fn`; and we wish to ensure that only one of them executes +// at a time. This is especially important if `fn` modifies one or more +// variables at a time. +// +// It is also useful if two separate functions must share a resource, but we +// wish to ensure the usage is exclusive. +// +// Arguments: +// mutex: The mutex resource to lock. +// +// Returns A tensor that keeps a shared pointer to a lock on the mutex; +// when the Tensor is destroyed, the use count on the shared pointer is decreased +// by 1. When it reaches 0, the lock is released. +func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IsFinite", + Type: "MutexLock", Input: []tf.Input{ - x, + mutex, }, } op := scope.AddOperation(opspec) return op.Output(0) } +// QuantizedReluXAttr is an optional argument to QuantizedReluX. +type QuantizedReluXAttr func(optionalAttr) + +// QuantizedReluXOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QUINT8 +func QuantizedReluXOutType(value tf.DataType) QuantizedReluXAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)` +// +// Arguments: +// +// +// min_features: The float value that the lowest quantized value represents. +// max_features: The float value that the highest quantized value represents. +// +// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. +func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluXAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedReluX", + Input: []tf.Input{ + features, max_value, min_features, max_features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// FractionalAvgPoolAttr is an optional argument to FractionalAvgPool. +type FractionalAvgPoolAttr func(optionalAttr) + +// FractionalAvgPoolPseudoRandom sets the optional pseudo_random attribute to value. 
+// +// value: When set to True, generates the pooling sequence in a +// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin +// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for +// difference between pseudorandom and random. +// If not specified, defaults to false +func FractionalAvgPoolPseudoRandom(value bool) FractionalAvgPoolAttr { + return func(m optionalAttr) { + m["pseudo_random"] = value + } +} + +// FractionalAvgPoolOverlapping sets the optional overlapping attribute to value. +// +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [41/3, 26/3] for fractional avg pooling. +// If not specified, defaults to false +func FractionalAvgPoolOverlapping(value bool) FractionalAvgPoolAttr { + return func(m optionalAttr) { + m["overlapping"] = value + } +} + +// FractionalAvgPoolDeterministic sets the optional deterministic attribute to value. +// +// value: When set to True, a fixed pooling region will be used when +// iterating over a FractionalAvgPool node in the computation graph. Mainly used +// in unit test to make FractionalAvgPool deterministic. +// If not specified, defaults to false +func FractionalAvgPoolDeterministic(value bool) FractionalAvgPoolAttr { + return func(m optionalAttr) { + m["deterministic"] = value + } +} + +// FractionalAvgPoolSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FractionalAvgPoolSeed(value int64) FractionalAvgPoolAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// FractionalAvgPoolSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func FractionalAvgPoolSeed2(value int64) FractionalAvgPoolAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Performs fractional average pooling on the input. +// +// Fractional average pooling is similar to Fractional max pooling in the pooling +// region generation step. The only difference is that after pooling regions are +// generated, a mean operation is performed instead of a max operation in each +// pooling region. +// +// Arguments: +// value: 4-D with shape `[batch, height, width, channels]`. +// pooling_ratio: Pooling ratio for each dimension of `value`, currently only +// supports row and col dimension and should be >= 1.0. For example, a valid +// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements +// must be 1.0 because we don't allow pooling on batch and channels +// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions +// respectively. +// +// Returns output tensor after fractional avg pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. 
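The closures above (`FractionalAvgPoolOverlapping`, `FractionalAvgPoolDeterministic`, and friends) are passed as trailing variadic options to the `FractionalAvgPool` wrapper defined immediately below. A small illustrative sketch, not part of this change, assuming the caller supplies a 4-D `[batch, height, width, channels]` input:

```go
// Illustrative sketch only; not part of the generated change.
package sketches

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// fractionalAvgPool2x2 pools height and width by a ratio of 2; the batch and
// channel ratios must stay 1.0, as the documentation above requires.
func fractionalAvgPool2x2(s *op.Scope, value tf.Output) (out, rowSeq, colSeq tf.Output) {
	return op.FractionalAvgPool(s, value,
		[]float32{1.0, 2.0, 2.0, 1.0},
		op.FractionalAvgPoolOverlapping(true),   // share boundary values between cells
		op.FractionalAvgPoolDeterministic(true), // fixed pooling regions across runs
	)
}
```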
+func FractionalAvgPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalAvgPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FractionalAvgPool", + Input: []tf.Input{ + value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// HashTableV2Attr is an optional argument to HashTableV2. +type HashTableV2Attr func(optionalAttr) + +// HashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func HashTableV2Container(value string) HashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// HashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func HashTableV2SharedName(value string) HashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates a non-initialized hash table. +// +// This op creates a hash table, specifying the type of its keys and values. +// Before using the table you will have to initialize it. After initialization the +// table will be immutable. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "HashTableV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reduces sparse updates into the variable referenced by `resource` using the `max` operation. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] = max(ref[indices, ...], updates[...]) +// +// # Vector indices (for each i) +// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions are combined. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// (figure illustrating the scatter-max update; original HTML div/img markup omitted)
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterMax(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterMax", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// CollectiveReduceAttr is an optional argument to CollectiveReduce. +type CollectiveReduceAttr func(optionalAttr) + +// CollectiveReduceWaitFor sets the optional wait_for attribute to value. +// If not specified, defaults to <> +func CollectiveReduceWaitFor(value []int64) CollectiveReduceAttr { + return func(m optionalAttr) { + m["wait_for"] = value + } +} + +// Mutually reduces multiple tensors of identical type and shape. +func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64, optional ...CollectiveReduceAttr) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CollectiveReduce", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Split the data from the input value into TensorArray elements. +// +// Assuming that `lengths` takes on values +// +// ```(n0, n1, ..., n(T-1))``` +// +// and that `value` has shape +// +// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, +// +// this splits values into a TensorArray with T tensors. +// +// TensorArray index t will be the subtensor of values with starting position +// +// ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` +// +// and having size +// +// ```nt x d0 x d1 x ...``` +// +// Arguments: +// handle: The handle to a TensorArray. +// value: The concatenated tensor to write to the TensorArray. +// lengths: The vector of lengths, how to split the rows of value into the +// TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// +// Returns A float scalar that enforces proper chaining of operations. +func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArraySplitV3", + Input: []tf.Input{ + handle, value, lengths, flow_in, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Divides sparse updates into the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] /= updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] /= updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions multiply. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// (figure illustrating the scatter-div update; original HTML div/img markup omitted)
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterDiv", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// ResourceGatherAttr is an optional argument to ResourceGather. +type ResourceGatherAttr func(optionalAttr) + +// ResourceGatherBatchDims sets the optional batch_dims attribute to value. +// If not specified, defaults to 0 +func ResourceGatherBatchDims(value int64) ResourceGatherAttr { + return func(m optionalAttr) { + m["batch_dims"] = value + } +} + +// ResourceGatherValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func ResourceGatherValidateIndices(value bool) ResourceGatherAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Gather slices from the variable pointed to by `resource` according to `indices`. +// +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: +// +// ```python +// # Scalar indices +// output[:, ..., :] = params[indices, :, ... :] +// +// # Vector indices +// output[i, :, ..., :] = params[indices[i], :, ... :] +// +// # Higher rank indices +// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] +// ``` +func ResourceGather(scope *Scope, resource tf.Output, indices tf.Output, dtype tf.DataType, optional ...ResourceGatherAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceGather", + Input: []tf.Input{ + resource, indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// EncodeJpegAttr is an optional argument to EncodeJpeg. +type EncodeJpegAttr func(optionalAttr) + +// EncodeJpegFormat sets the optional format attribute to value. +// +// value: Per pixel image format. +// If not specified, defaults to "" +func EncodeJpegFormat(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["format"] = value + } +} + +// EncodeJpegQuality sets the optional quality attribute to value. +// +// value: Quality of the compression from 0 to 100 (higher is better and slower). +// If not specified, defaults to 95 +func EncodeJpegQuality(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["quality"] = value + } +} + +// EncodeJpegProgressive sets the optional progressive attribute to value. +// +// value: If True, create a JPEG that loads progressively (coarse to fine). +// If not specified, defaults to false +func EncodeJpegProgressive(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["progressive"] = value + } +} + +// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. +// +// value: If True, spend CPU/RAM to reduce size with no quality change. 
+// If not specified, defaults to false +func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["optimize_size"] = value + } +} + +// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. +// +// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. +// If not specified, defaults to true +func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["chroma_downsampling"] = value + } +} + +// EncodeJpegDensityUnit sets the optional density_unit attribute to value. +// +// value: Unit used to specify `x_density` and `y_density`: +// pixels per inch (`'in'`) or centimeter (`'cm'`). +// If not specified, defaults to "in" +func EncodeJpegDensityUnit(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["density_unit"] = value + } +} + +// EncodeJpegXDensity sets the optional x_density attribute to value. +// +// value: Horizontal pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegXDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["x_density"] = value + } +} + +// EncodeJpegYDensity sets the optional y_density attribute to value. +// +// value: Vertical pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegYDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["y_density"] = value + } +} + +// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. +// +// value: If not empty, embed this XMP metadata in the image header. +// If not specified, defaults to "" +func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["xmp_metadata"] = value + } +} + +// JPEG-encode an image. +// +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// +// The attr `format` can be used to override the color format of the encoded +// output. Values can be: +// +// * `''`: Use a default format based on the number of channels in the image. +// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension +// of `image` must be 1. +// * `rgb`: Output an RGB JPEG image. The `channels` dimension +// of `image` must be 3. +// +// If `format` is not specified or is the empty string, a default format is picked +// in function of the number of channels in `image`: +// +// * 1: Output a grayscale image. +// * 3: Output an RGB image. +// +// Arguments: +// image: 3-D with shape `[height, width, channels]`. +// +// Returns 0-D. JPEG-encoded image. +func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EncodeJpeg", + Input: []tf.Input{ + image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Checks whether a resource handle-based variable has been initialized. +// +// Arguments: +// resource: the input resource handle. +// +// Returns a scalar boolean which is true if the variable has been +// initialized. 
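A brief illustrative sketch for the `VarIsInitializedOp` wrapper defined immediately below. It is not part of this change and assumes the `VarHandleOp` wrapper and the `tf.MakeShape` helper that live elsewhere in these bindings.

```go
// Illustrative sketch only; not part of the generated change.
package sketches

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// isVarInitialized creates a 2x2 float resource-variable handle and returns
// an op that reports whether anything has been assigned to it yet; fetching
// it before any AssignVariableOp runs should yield false.
func isVarInitialized(s *op.Scope) tf.Output {
	// VarHandleOp is assumed to be the wrapper defined elsewhere in this file.
	handle := op.VarHandleOp(s, tf.Float, tf.MakeShape(2, 2))
	return op.VarIsInitializedOp(s, handle)
}
```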
+func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "VarIsInitializedOp", + Input: []tf.Input{ + resource, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizedDepthwiseConv2DWithBiasAndReluAttr is an optional argument to QuantizedDepthwiseConv2DWithBiasAndRelu. +type QuantizedDepthwiseConv2DWithBiasAndReluAttr func(optionalAttr) + +// QuantizedDepthwiseConv2DWithBiasAndReluOutType sets the optional out_type attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_QINT32 +func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) QuantizedDepthwiseConv2DWithBiasAndReluAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value. +// +// value: List of dilation values. +// If not specified, defaults to +func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes quantized depthwise Conv2D with Bias and Relu. +// +// Arguments: +// input: The original input tensor. +// filter: The original filter tensor. +// bias: The original bias tensor. +// min_input: The float value that the minimum quantized input value represents. +// max_input: The float value that the maximum quantized input value represents. +// min_filter: The float value that the minimum quantized filter value represents. +// max_filter: The float value that the maximum quantized filter value represents. +// strides: List of stride values. +// +// +// Returns The output tensor.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizedDepthwiseConv2DWithBiasAndRelu(scope *Scope, input tf.Output, filter tf.Output, bias tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedDepthwiseConv2DWithBiasAndReluAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedDepthwiseConv2DWithBiasAndRelu", + Input: []tf.Input{ + input, filter, bias, min_input, max_input, min_filter, max_filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Draw bounding boxes on a batch of images. +// +// Outputs a copy of `images` but draws on top of the pixels zero or more bounding +// boxes specified by the locations in `boxes`. The coordinates of the each +// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. +// +// For example, if an image is 100 x 200 pixels (height x width) and the bounding +// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of +// the bounding box will be `(40, 10)` to `(100, 50)` (in (x,y) coordinates). +// +// Parts of the bounding box may fall outside the image. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. 
+// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding +// boxes. +// colors: 2-D. A list of RGBA colors to cycle through for the boxes. +// +// Returns 4-D with the same shape as `images`. The batch of input images with +// bounding boxes drawn on the images. +func DrawBoundingBoxesV2(scope *Scope, images tf.Output, boxes tf.Output, colors tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DrawBoundingBoxesV2", + Input: []tf.Input{ + images, boxes, colors, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inverse fast Fourier transform. +// +// Computes the inverse 1-dimensional discrete Fourier transform over the +// inner-most dimension of `input`. +// +// Arguments: +// input: A complex tensor. +// +// Returns A complex tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its inverse 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft +// @end_compatibility +func IFFT(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IFFT", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. +type ResourceApplyMomentumAttr func(optionalAttr) + +// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, the tensor passed to compute grad will be +// var - lr * momentum * accum, so in the end, the var you get is actually +// var - lr * momentum * accum. +// If not specified, defaults to false +func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the momentum scheme. Set use_nesterov = True if you +// +// want to use Nesterov momentum. +// +// accum = accum * momentum + grad +// var -= lr * accum +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// grad: The gradient. +// momentum: Momentum. Must be a scalar. +// +// Returns the created operation. +func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyMomentum", + Input: []tf.Input{ + var_, accum, lr, grad, momentum, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Creates an Optional variant with no value. +func OptionalNone(scope *Scope) (optional tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "OptionalNone", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reads the value of a variable. 
+// +// The tensor returned by this operation is immutable. +// +// The value returned by this operation is guaranteed to be influenced by all the +// writes on which this operation depends directly or indirectly, and to not be +// influenced by any of the writes which depend directly or indirectly on this +// operation. +// +// Arguments: +// resource: handle to the resource in which to store the variable. +// dtype: the dtype of the value. +func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "ReadVariableOp", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. +type DataFormatDimMapAttr func(optionalAttr) + +// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. +// +// value: source data format. +// If not specified, defaults to "NHWC" +func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { + return func(m optionalAttr) { + m["src_format"] = value + } +} + +// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. +// +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { + return func(m optionalAttr) { + m["dst_format"] = value + } +} + +// Returns the dimension index in the destination data format given the one in +// +// the source data format. +// +// Arguments: +// x: A Tensor with each element as a dimension index in source data format. +// Must be in the range [-4, 4). +// +// Returns A Tensor with each element as a dimension index in destination data format. +func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DataFormatDimMap", + Input: []tf.Input{ + x, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Writes the given dataset to the given file using the TFRecord format. +// +// Arguments: +// input_dataset: A variant tensor representing the dataset to write. +// filename: A scalar string tensor representing the filename to use. +// compression_type: A scalar string tensor containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// +// Returns the created operation. +func ExperimentalDatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ExperimentalDatasetToTFRecord", + Input: []tf.Input{ + input_dataset, filename, compression_type, + }, + } + return scope.AddOperation(opspec) +} + +// QuantizedConv2DAttr is an optional argument to QuantizedConv2D. +type QuantizedConv2DAttr func(optionalAttr) + +// QuantizedConv2DOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// QuantizedConv2DDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. 
If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes a 2D convolution given quantized 4D input and filter tensors. +// +// The inputs are quantized tensors where the lowest value represents the real +// number of the associated minimum, and the highest represents the maximum. +// This means that you can only interpret the quantized output in the same way, by +// taking the returned minimum and maximum values into account. +// +// Arguments: +// +// filter: filter's input_depth dimension must match input's depth dimensions. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// min_filter: The float value that the lowest quantized filter value represents. +// max_filter: The float value that the highest quantized filter value represents. +// strides: The stride of the sliding window for each dimension of the input +// tensor. +// padding: The type of padding algorithm to use. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedConv2D", + Input: []tf.Input{ + input, filter, min_input, max_input, min_filter, max_filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// DecodePaddedRawAttr is an optional argument to DecodePaddedRaw. +type DecodePaddedRawAttr func(optionalAttr) + +// DecodePaddedRawLittleEndian sets the optional little_endian attribute to value. +// +// value: Whether the input `input_bytes` is in little-endian order. Ignored for +// `out_type` values that are stored in a single byte, like `uint8` +// If not specified, defaults to true +func DecodePaddedRawLittleEndian(value bool) DecodePaddedRawAttr { + return func(m optionalAttr) { + m["little_endian"] = value + } +} + +// Reinterpret the bytes of a string as a vector of numbers. +// +// Arguments: +// input_bytes: Tensor of string to be decoded. +// fixed_length: Length in bytes for each element of the decoded output. Must be a multiple +// of the size of the output type. +// +// +// Returns A Tensor with one more dimension than the input `bytes`. The added dimension +// will have size equal to the length of the elements of `bytes` divided by the +// number of bytes to represent `out_type`. 
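As a sketch of the `DecodePaddedRaw` wrapper documented above (assuming the standard `tensorflow/go` client API and made-up record bytes), the snippet below decodes two 8-byte records into rows of two little-endian int32 values each:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Two records, each holding two little-endian int32 values (8 bytes).
	records := op.Const(s, []string{
		"\x01\x00\x00\x00\x02\x00\x00\x00",
		"\x03\x00\x00\x00\x04\x00\x00\x00",
	})
	// fixed_length must be a multiple of the output element size (4 bytes here).
	fixedLength := op.Const(s, int32(8))

	decoded := op.DecodePaddedRaw(s, records, fixedLength, tf.Int32,
		op.DecodePaddedRawLittleEndian(true))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{decoded}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [[1 2] [3 4]]
}
```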
+func DecodePaddedRaw(scope *Scope, input_bytes tf.Output, fixed_length tf.Output, out_type tf.DataType, optional ...DecodePaddedRawAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodePaddedRaw", + Input: []tf.Input{ + input_bytes, fixed_length, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the truth value of x OR y element-wise. +// +// *NOTE*: `LogicalOr` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LogicalOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LogicalOr", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes fingerprints of the input strings. +// +// Arguments: +// input: vector of strings to compute fingerprints on. +// +// Returns a (N,2) shaped matrix where N is the number of elements in the input +// vector. Each row contains the low and high parts of the fingerprint. +func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SdcaFprint", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TopKV2Attr is an optional argument to TopKV2. +type TopKV2Attr func(optionalAttr) + +// TopKV2Sorted sets the optional sorted attribute to value. +// +// value: If true the resulting `k` elements will be sorted by the values in +// descending order. +// If not specified, defaults to true +func TopKV2Sorted(value bool) TopKV2Attr { + return func(m optionalAttr) { + m["sorted"] = value + } +} + +// Finds values and indices of the `k` largest elements for the last dimension. +// +// If the input is a vector (rank-1), finds the `k` largest entries in the vector +// and outputs their values and indices as vectors. Thus `values[j]` is the +// `j`-th largest entry in `input`, and its index is `indices[j]`. +// +// For matrices (resp. higher rank input), computes the top `k` entries in each +// row (resp. vector along the last dimension). Thus, +// +// values.shape = indices.shape = input.shape[:-1] + [k] +// +// If two elements are equal, the lower-index element appears first. +// +// Arguments: +// input: 1-D or higher with last dimension at least `k`. +// k: 0-D. Number of top elements to look for along the last dimension (along each +// row for matrices). +// +// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. +func TopKV2(scope *Scope, input tf.Output, k tf.Output, optional ...TopKV2Attr) (values tf.Output, indices tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TopKV2", + Input: []tf.Input{ + input, k, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// SdcaOptimizerAttr is an optional argument to SdcaOptimizer. +type SdcaOptimizerAttr func(optionalAttr) + +// SdcaOptimizerAdaptative sets the optional adaptative attribute to value. +// +// value: Whether to use Adaptive SDCA for the inner loop. 
+// If not specified, defaults to true +func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { + return func(m optionalAttr) { + m["adaptative"] = value + } +} + +// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for +// +// linear models with L1 + L2 regularization. As global optimization objective is +// strongly-convex, the optimizer optimizes the dual objective at each step. The +// optimizer applies each update one example at a time. Examples are sampled +// uniformly, and the optimizer is learning rate free and enjoys linear convergence +// rate. +// +// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
+// Shai Shalev-Shwartz, Tong Zhang. 2012 +// +// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ +// +// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
+// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, +// Peter Richtarik, Martin Takac. 2015 +// +// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
+// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 +// +// Arguments: +// sparse_example_indices: a list of vectors which contain example indices. +// sparse_feature_indices: a list of vectors which contain feature indices. +// sparse_feature_values: a list of vectors which contains feature value +// associated with each feature group. +// dense_features: a list of matrices which contains the dense feature values. +// example_weights: a vector which contains the weight associated with each +// example. +// example_labels: a vector which contains the label/target associated with each +// example. +// sparse_indices: a list of vectors where each value is the indices which has +// corresponding weights in sparse_weights. This field maybe omitted for the +// dense approach. +// sparse_weights: a list of vectors where each value is the weight associated with +// a sparse feature group. +// dense_weights: a list of vectors where the values are the weights associated +// with a dense feature group. +// example_state_data: a list of vectors containing the example state data. +// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, +// squared and hinge losses. +// l1: Symmetric l1 regularization strength. +// l2: Symmetric l2 regularization strength. +// num_loss_partitions: Number of partitions of the global loss function. +// num_inner_iterations: Number of iterations per mini-batch. +// +// Returns a list of vectors containing the updated example state +// data.a list of vectors where each value is the delta +// weights associated with a sparse feature group.a list of vectors where the values are the delta +// weights associated with a dense feature group. +func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SdcaOptimizer", + Input: []tf.Input{ + tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + out_example_state_data = op.Output(idx) + if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return + } + if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return + } + return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights +} + // ResourceStridedSliceAssignAttr is an optional argument to 
ResourceStridedSliceAssign. type ResourceStridedSliceAssignAttr func(optionalAttr) @@ -16035,27 +13489,111 @@ func ResourceStridedSliceAssign(scope *Scope, ref tf.Output, begin tf.Output, en return scope.AddOperation(opspec) } -// ArgMaxAttr is an optional argument to ArgMax. -type ArgMaxAttr func(optionalAttr) +// Checks whether a tree has been initialized. +// +// Arguments: +// tree_handle: Handle to the tree. +// +// Returns Whether the tree is initialized. +func TensorForestTreeIsInitializedOp(scope *Scope, tree_handle tf.Output) (is_initialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorForestTreeIsInitializedOp", + Input: []tf.Input{ + tree_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// ArgMaxOutputType sets the optional output_type attribute to value. -// If not specified, defaults to DT_INT64 -func ArgMaxOutputType(value tf.DataType) ArgMaxAttr { +// Returns the batched diagonal part of a batched tensor. +// +// This operation returns a tensor with the `diagonal` part +// of the batched `input`. The `diagonal` part is computed as follows: +// +// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a +// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where: +// +// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`. +// +// The input must be at least a matrix. +// +// For example: +// +// ``` +// # 'input' is [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] +// +// and input.shape = (2, 4, 4) +// +// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]] +// +// which has shape (2, 4) +// ``` +// +// Arguments: +// input: Rank `k` tensor where `k >= 2`. +// +// Returns The extracted diagonal(s) having shape +// `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`. +func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixDiagPart", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. +type ResourceApplyFtrlAttr func(optionalAttr) + +// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { return func(m optionalAttr) { - m["output_type"] = value + m["use_locking"] = value } } -// Returns the index with the largest value across dimensions of a tensor. +// Update '*var' according to the Ftrl-proximal scheme. // -// Note that in case of ties the identity of the return value is not guaranteed. +// accum_new = accum + grad * grad +// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // // Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regulariation. Must be a scalar. +// l2: L2 regulariation. Must be a scalar. 
+// lr_power: Scaling factor. Must be a scalar. // -// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. -// Describes which dimension of the input Tensor to reduce across. For vectors, -// use dimension = 0. -func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) { +// Returns the created operation. +func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -16064,9 +13602,99 @@ func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgM a(attrs) } opspec := tf.OpSpec{ - Type: "ArgMax", + Type: "ResourceApplyFtrl", Input: []tf.Input{ - input, dimension, + var_, accum, linear, grad, lr, l1, l2, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Computes the product along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Computes a tensor such that +// \\(output_i = \prod_j data_j\\) where the product is over `j` such +// that `segment_ids[j] == i`. +// +// If the product is empty for a given segment ID `i`, `output[i] = 1`. +// +//
+// +//
+// +// For example: +// +// ``` +// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_prod(c, tf.constant([0, 0, 1])) +// # ==> [[4, 6, 6, 4], +// # [5, 6, 7, 8]] +// ``` +// +// +// Arguments: +// +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SegmentProd", + Input: []tf.Input{ + data, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Transforms a Tensor into a serialized TensorProto proto. +// +// Arguments: +// tensor: A Tensor of type `T`. +// +// Returns A serialized TensorProto proto of the input tensor. +func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SerializeTensor", + Input: []tf.Input{ + tensor, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Transforms a serialized tensorflow.TensorProto proto into a Tensor. +// +// Arguments: +// serialized: A scalar string containing a serialized TensorProto proto. +// out_type: The type of the serialized tensor. The provided type must match the +// type of the serialized tensor and no implicit conversion will take place. +// +// Returns A Tensor of type `out_type`. +func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + opspec := tf.OpSpec{ + Type: "ParseTensor", + Input: []tf.Input{ + serialized, }, Attrs: attrs, } @@ -16074,21 +13702,695 @@ func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgM return op.Output(0) } -// Fetches multiple values from infeed as an XLA tuple. +// ResourceScatterNdSubAttr is an optional argument to ResourceScatterNdSub. +type ResourceScatterNdSubAttr func(optionalAttr) + +// ResourceScatterNdSubUseLocking sets the optional use_locking attribute to value. +// +// value: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, +// but may exhibit less contention. +// If not specified, defaults to true +func ResourceScatterNdSubUseLocking(value bool) ResourceScatterNdSubAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Applies sparse subtraction to individual values or slices in a Variable. +// +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. +// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. +// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// ``` +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] +// ``` +// +// For example, say we want to subtract 4 scattered elements from a rank-1 tensor +// with 8 elements. 
In Python, that subtraction would look like this: +// +// ```python +// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// sub = tf.scatter_nd_sub(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(sub) +// ``` +// +// The resulting update to ref would look like this: +// +// [1, -9, 3, -6, -4, 6, 7, -4] +// +// See `tf.scatter_nd` for more details about how to make updates to +// slices. // // Arguments: -// dtypes: The element types of each element in `outputs`. -// shapes: The shapes of each tensor in `outputs`. +// ref: A resource handle. Must be from a VarHandleOp. +// indices: A Tensor. Must be one of the following types: int32, int64. +// A tensor of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of +// values to add to ref. // -// Returns A list of tensors that will be provided using the infeed mechanism. -func InfeedDequeueTuple(scope *Scope, dtypes []tf.DataType, shapes []tf.Shape) (outputs []tf.Output) { +// Returns the created operation. +func ResourceScatterNdSub(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdSubAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes, "shapes": shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "InfeedDequeueTuple", + Type: "ResourceScatterNdSub", + Input: []tf.Input{ + ref, indices, updates, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} +// SparseReduceSumAttr is an optional argument to SparseReduceSum. +type SparseReduceSumAttr func(optionalAttr) + +// SparseReduceSumKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the sum of elements across dimensions of a SparseTensor. +// +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// +// Returns `R-K`-D. The reduced Tensor. 
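A minimal sketch of the `SparseReduceSum` wrapper documented above, assuming the standard `tensorflow/go` client API; it sums the sparse 2x3 matrix [[1, 0, 2], [0, 3, 0]] over its column axis and yields a dense result:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Sparse representation of [[1, 0, 2], [0, 3, 0]].
	indices := op.Const(s, [][]int64{{0, 0}, {0, 2}, {1, 1}})
	values := op.Const(s, []float32{1, 2, 3})
	shape := op.Const(s, []int64{2, 3})
	axes := op.Const(s, []int32{1}) // reduce over the column axis

	sum := op.SparseReduceSum(s, indices, values, shape, axes,
		op.SparseReduceSumKeepDims(false))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{sum}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [3 3]
}
```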
+func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseReduceSum", + Input: []tf.Input{ + input_indices, input_values, input_shape, reduction_axes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RealAttr is an optional argument to Real. +type RealAttr func(optionalAttr) + +// RealTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func RealTout(value tf.DataType) RealAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Returns the real part of a complex number. +// +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the real part of each element in `input`. All elements in +// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real +// part returned by this operation and *b* is the imaginary part. +// +// For example: +// +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.real(input) ==> [-2.25, 3.25] +// ``` +func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Real", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MutexV2Attr is an optional argument to MutexV2. +type MutexV2Attr func(optionalAttr) + +// MutexV2Container sets the optional container attribute to value. +// +// value: If non-empty, this variable is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutexV2Container(value string) MutexV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutexV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this variable is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func MutexV2SharedName(value string) MutexV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a Mutex resource that can be locked by `MutexLock`. +// +// Returns The mutex resource. +func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutexV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FusedBatchNormV2Attr is an optional argument to FusedBatchNormV2. +type FusedBatchNormV2Attr func(optionalAttr) + +// FusedBatchNormV2Epsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormV2Epsilon(value float32) FusedBatchNormV2Attr { + return func(m optionalAttr) { + m["epsilon"] = value + } +} + +// FusedBatchNormV2DataFormat sets the optional data_format attribute to value. +// +// value: The data format for x and y. Either "NHWC" (default) or "NCHW". 
+// If not specified, defaults to "NHWC" +func FusedBatchNormV2DataFormat(value string) FusedBatchNormV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormV2IsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormV2IsTraining(value bool) FusedBatchNormV2Attr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// Batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// offset: A 1D Tensor for offset, to shift to the normalized x. +// mean: A 1D Tensor for population mean. Used for inference only; +// must be empty for training. +// variance: A 1D Tensor for population variance. Used for inference only; +// must be empty for training. +// +// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow +// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by +// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused +// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance +// in the cuDNN case), to be reused in the gradient computation. +func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormV2Attr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedBatchNormV2", + Input: []tf.Input{ + x, scale, offset, mean, variance, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + +// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] = min(ref[indices, ...], updates[...]) +// +// # Vector indices (for each i) +// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions are combined. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterMin", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// Returns the element-wise min of two SparseTensors. +// +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// +// Arguments: +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. +// +// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. +func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSparseMinimum", + Input: []tf.Input{ + a_indices, a_values, a_shape, b_indices, b_values, b_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Component-wise multiplies a SparseTensor by a dense Tensor. +// +// The output locations corresponding to the implicitly zero elements in the sparse +// tensor will be zero (i.e., will not take up storage space), regardless of the +// contents of the dense tensor (even if it's +/-INF and that INF*0 == NaN). +// +// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not +// the other direction. +// +// Arguments: +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. +// +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseDenseCwiseMul", + Input: []tf.Input{ + sp_indices, sp_values, sp_shape, dense, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. +type Conv2DBackpropInputAttr func(optionalAttr) + +// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["use_cudnn_on_gpu"] = value + } +} + +// Conv2DBackpropInputExplicitPaddings sets the optional explicit_paddings attribute to value. +// +// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. 
For the ith +// dimension, the amount of padding inserted before and after the dimension is +// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If +// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. +// If not specified, defaults to <> +func Conv2DBackpropInputExplicitPaddings(value []int64) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["explicit_paddings"] = value + } +} + +// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv2DBackpropInputDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of convolution with respect to the input. +// +// Arguments: +// input_sizes: An integer vector representing the shape of `input`, +// where `input` is a 4-D `[batch, height, width, channels]` tensor. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient +// w.r.t. the input of the convolution. +func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv2DBackpropInput", + Input: []tf.Input{ + input_sizes, filter, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds `bias` to `value`. +// +// This is a deprecated version of BiasAdd and will be soon removed. +// +// This is a special case of `tf.add` where `bias` is restricted to be 1-D. +// Broadcasting is supported, so `value` may have any number of dimensions. +// +// Arguments: +// value: Any number of dimensions. +// bias: 1-D with size the last dimension of `value`. +// +// Returns Broadcasted sum of `value` and `bias`. 
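The `Conv2DBackpropInput` wrapper above follows the same functional-option pattern as the rest of this file: required attributes (`strides`, `padding`) are ordinary arguments, while optional ones are trailing `Conv2DBackpropInputAttr` values. A graph-construction sketch, assuming the standard `tensorflow/go` client API and constant inputs chosen only for shape consistency:

```go
package main

import (
	"fmt"

	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Shapes chosen so that SAME padding with stride 1 is consistent:
	// input [1, 4, 4, 1], filter [3, 3, 1, 1], out_backprop [1, 4, 4, 1].
	inputSizes := op.Const(s, []int32{1, 4, 4, 1})
	filter := op.Fill(s, op.Const(s, []int32{3, 3, 1, 1}), op.Const(s, float32(1)))
	outBackprop := op.Fill(s, op.Const(s, []int32{1, 4, 4, 1}), op.Const(s, float32(1)))

	grad := op.Conv2DBackpropInput(s, inputSizes, filter, outBackprop,
		[]int64{1, 1, 1, 1}, "SAME",
		op.Conv2DBackpropInputDataFormat("NHWC"),
		op.Conv2DBackpropInputDilations([]int64{1, 1, 1, 1}),
	)

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	fmt.Println("input gradient op:", grad.Op.Name())
}
```

Running the finalized graph then follows the usual `tf.NewSession`/`Session.Run` pattern shown in the earlier sketches.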
+func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BiasAddV1", + Input: []tf.Input{ + value, bias, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Selects elements from `x` or `y`, depending on `condition`. +// +// The `x`, and `y` tensors must all have the same shape, and the +// output will also have that shape. +// +// The `condition` tensor must be a scalar if `x` and `y` are scalars. +// If `x` and `y` are vectors or higher rank, then `condition` must be either a +// scalar, a vector with size matching the first dimension of `x`, or must have +// the same shape as `x`. +// +// The `condition` tensor acts as a mask that chooses, based on the value at each +// element, whether the corresponding element / row in the output should be +// taken from `x` (if true) or `y` (if false). +// +// If `condition` is a vector and `x` and `y` are higher rank matrices, then +// it chooses which row (outer dimension) to copy from `x` and `y`. +// If `condition` has the same shape as `x` and `y`, then it chooses which +// element to copy from `x` and `y`. +// +// For example: +// +// ```python +// # 'condition' tensor is [[True, False] +// # [False, True]] +// # 't' is [[1, 2], +// # [3, 4]] +// # 'e' is [[5, 6], +// # [7, 8]] +// select(condition, t, e) # => [[1, 6], [7, 4]] +// +// +// # 'condition' tensor is [True, False] +// # 't' is [[1, 2], +// # [3, 4]] +// # 'e' is [[5, 6], +// # [7, 8]] +// select(condition, t, e) ==> [[1, 2], +// [7, 8]] +// +// ``` +// +// Arguments: +// +// x: = A `Tensor` which may have the same shape as `condition`. +// If `condition` is rank 1, `x` may have higher rank, +// but its first dimension must match the size of `condition`. +// y: = A `Tensor` with the same type and shape as `x`. +// +// Returns = A `Tensor` with the same type and shape as `x` and `y`. +func Select(scope *Scope, condition tf.Output, x tf.Output, y tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Select", + Input: []tf.Input{ + condition, x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. +type ParseSingleSequenceExampleAttr func(optionalAttr) + +// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. +// +// value: A list of Ncontext_sparse types; the data types of data in +// each context Feature given in context_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_sparse_types"] = value + } +} + +// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_types"] = value + } +} + +// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. 
+// +// value: A list of Ncontext_dense shapes; the shapes of data in +// each context Feature given in context_dense_keys. +// The number of elements in the Feature corresponding to context_dense_key[j] +// must always equal context_dense_shapes[j].NumEntries(). +// The shape of context_dense_values[j] will match context_dense_shapes[j]. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_dense_shapes"] = value + } +} + +// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. +// +// value: A list of Nfeature_list_sparse types; the data types +// of data in each FeatureList given in feature_list_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_sparse_types"] = value + } +} + +// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. +// +// value: A list of Nfeature_list_dense shapes; the shapes of +// data in each FeatureList given in feature_list_dense_keys. +// The shape of each Feature in the FeatureList corresponding to +// feature_list_dense_key[j] must always equal +// feature_list_dense_shapes[j].NumEntries(). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_shapes"] = value + } +} + +// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. +// +// Arguments: +// serialized: A scalar containing a binary serialized SequenceExample proto. +// feature_list_dense_missing_assumed_empty: A vector listing the +// FeatureList keys which may be missing from the SequenceExample. If the +// associated FeatureList is missing, it is treated as empty. By default, +// any FeatureList not listed in this vector must exist in the SequenceExample. +// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). +// The keys expected in the Examples' features associated with context_sparse +// values. +// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' context features associated with +// dense values. +// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors +// (scalars). The keys expected in the FeatureLists associated with sparse +// values. +// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' feature_lists associated +// with lists of dense values. +// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). +// context_dense_defaults[j] provides default values +// when the SequenceExample's context map lacks context_dense_key[j]. +// If an empty Tensor is provided for context_dense_defaults[j], +// then the Feature context_dense_keys[j] is required. +// The input type is inferred from context_dense_defaults[j], even when it's +// empty. 
If context_dense_defaults[j] is not empty, its shape must match +// context_dense_shapes[j]. +// debug_name: A scalar containing the name of the serialized proto. +// May contain, for example, table key (descriptive) name for the +// corresponding serialized proto. This is purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty scalar if no name is available. +func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ParseSingleSequenceExample", + Input: []tf.Input{ + serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, + }, Attrs: attrs, } op := scope.AddOperation(opspec) @@ -16097,69 +14399,375 @@ func InfeedDequeueTuple(scope *Scope, dtypes []tf.DataType, shapes []tf.Shape) ( } var idx int var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("InfeedDequeueTuple", err) + if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) return } - return outputs + if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values } -// Enqueue multiple Tensor values on the computation outfeed. 
+// Deprecated. Use TensorArrayReadV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 +func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "TensorArrayReadV2", + Input: []tf.Input{ + handle, index, flow_in, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Calculate product with tridiagonal matrix. +// +// Calculates product of two matrices, where left matrix is a tridiagonal matrix. // // Arguments: -// inputs: A list of tensors that will be inserted into the outfeed queue as an -// XLA tuple. +// superdiag: Tensor of shape `[..., 1, M]`, representing superdiagonals of +// tri-diagonal matrices to the left of multiplication. Last element is ingored. +// maindiag: Tensor of shape `[..., 1, M]`, representing main diagonals of tri-diagonal +// matrices to the left of multiplication. +// subdiag: Tensor of shape `[..., 1, M]`, representing subdiagonals of tri-diagonal +// matrices to the left of multiplication. First element is ingored. +// rhs: Tensor of shape `[..., M, N]`, representing MxN matrices to the right of +// multiplication. // -// Returns the created operation. -func OutfeedEnqueueTuple(scope *Scope, inputs []tf.Output) (o *tf.Operation) { +// Returns Tensor of shape `[..., M, N]` containing the product. +func TridiagonalMatMul(scope *Scope, superdiag tf.Output, maindiag tf.Output, subdiag tf.Output, rhs tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "OutfeedEnqueueTuple", + Type: "TridiagonalMatMul", Input: []tf.Input{ - tf.OutputList(inputs), + superdiag, maindiag, subdiag, rhs, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Component-wise divides a SparseTensor by a dense Tensor. +// +// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not +// the other direction. +// +// Arguments: +// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. +// +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseDenseCwiseDiv", + Input: []tf.Input{ + sp_indices, sp_values, sp_shape, dense, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. +// +// This Op does not require `a_indices` be sorted in standard lexicographic order. +// +// Arguments: +// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. +// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. +// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. +// b: `ndims`-D Tensor. With shape `a_shape`. 
+func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseTensorDenseAdd", + Input: []tf.Input{ + a_indices, a_values, a_shape, b, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes softmax activations. +// +// For each batch `i` and class `j` we have +// +// $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$ +// +// Arguments: +// logits: 2-D with shape `[batch_size, num_classes]`. +// +// Returns Same shape as `logits`. +func Softmax(scope *Scope, logits tf.Output) (softmax tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Softmax", + Input: []tf.Input{ + logits, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inverse 3D real-valued fast Fourier transform. +// +// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most 3 dimensions of `input`. +// +// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: +// The inner-most dimension contains the `fft_length / 2 + 1` unique components of +// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed +// from the size of the inner-most 3 dimensions of `input`. If the FFT length used +// to compute `input` is odd, it should be provided since it cannot be inferred +// properly. +// +// Along each axis `IRFFT3D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// +// Arguments: +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. +// +// Returns A float32 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the `fft_length` samples of their +// inverse 3D real Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.irfftn with 3 dimensions. +// @end_compatibility +func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IRFFT3D", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. +type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) + +// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. 
If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of depthwise convolution with respect to the filter. +// +// Arguments: +// input: 4-D with shape based on `data_format`. For example, if +// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, +// in_width, in_channels]` tensor. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. +// out_backprop: 4-D with shape based on `data_format`. +// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DepthwiseConv2dNativeBackpropFilter", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// L2 Loss. +// +// Computes half the L2 norm of a tensor without the `sqrt`: +// +// output = sum(t ** 2) / 2 +// +// Arguments: +// t: Typically 2-D, but may have any dimensions. +// +// Returns 0-D. +func L2Loss(scope *Scope, t tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "L2Loss", + Input: []tf.Input{ + t, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Replaces the contents of the table with the specified keys and values. +// +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableImportV2", + Input: []tf.Input{ + table_handle, keys, values, }, } return scope.AddOperation(opspec) } -// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. -type ResourceApplyAdagradAttr func(optionalAttr) - -// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. 
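// --- Editor's sketch (illustrative, not part of the patch) ------------------
// A minimal example of how a generated wrapper such as Softmax above is
// typically driven from the tensorflow/go and tensorflow/go/op packages:
// build a graph on a Scope, finalize it, and run it in a Session. The feed
// value and variable names here are illustrative only.
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	logits := op.Placeholder(s, tf.Float) // fed at run time, shape [batch_size, num_classes]
	soft := op.Softmax(s, logits)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	in, err := tf.NewTensor([][]float32{{1, 2, 3}})
	if err != nil {
		panic(err)
	}
	out, err := sess.Run(map[tf.Output]*tf.Tensor{logits: in}, []tf.Output{soft}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // each row of the softmax output sums to 1
}
// -----------------------------------------------------------------------------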
-// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value. -// If not specified, defaults to true -func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr { - return func(m optionalAttr) { - m["update_slots"] = value - } -} - -// Update '*var' according to the adagrad scheme. -// -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// Computes rectified linear 6 gradients for a Relu6 operation. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. +// gradients: The backpropagated gradients to the corresponding Relu6 operation. +// features: The features passed as input to the corresponding Relu6 operation, or +// its output; using either one produces the same result. // -// Returns the created operation. -func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { +// Returns The gradients: +// `gradients * (features > 0) * (features < 6)`. +func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Relu6Grad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeCompressedAttr is an optional argument to DecodeCompressed. +type DecodeCompressedAttr func(optionalAttr) + +// DecodeCompressedCompressionType sets the optional compression_type attribute to value. +// +// value: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// If not specified, defaults to "" +func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { + return func(m optionalAttr) { + m["compression_type"] = value + } +} + +// Decompress strings. +// +// This op decompresses each element of the `bytes` input `Tensor`, which +// is assumed to be compressed using the given `compression_type`. +// +// The `output` is a string `Tensor` of the same shape as `bytes`, +// each element containing the decompressed data from the corresponding +// element in `bytes`. +// +// Arguments: +// bytes: A Tensor of string which is compressed. +// +// Returns A Tensor with the same shape as input `bytes`, uncompressed +// from bytes. +func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -16168,108 +14776,132 @@ func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.O a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyAdagrad", + Type: "DecodeCompressed", Input: []tf.Input{ - var_, accum, lr, grad, + bytes, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// CudnnRNNV3Attr is an optional argument to CudnnRNNV3. -type CudnnRNNV3Attr func(optionalAttr) +// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. 
+type GenerateVocabRemappingAttr func(optionalAttr) -// CudnnRNNV3RnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNV3RnnMode(value string) CudnnRNNV3Attr { +// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. +// +// value: Number of entries in the old vocab file to consider. If -1, +// use the entire old vocabulary. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { return func(m optionalAttr) { - m["rnn_mode"] = value + m["old_vocab_size"] = value } } -// CudnnRNNV3InputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNV3InputMode(value string) CudnnRNNV3Attr { - return func(m optionalAttr) { - m["input_mode"] = value +// Given a path to new and old vocabulary files, returns a remapping Tensor of +// +// length `num_new_vocab`, where `remapping[i]` contains the row number in the old +// vocabulary that corresponds to row `i` in the new vocabulary (starting at line +// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` +// in the new vocabulary is not in the old vocabulary. The old vocabulary is +// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the +// default value of -1. +// +// `num_vocab_offset` enables +// use in the partitioned variable case, and should generally be set through +// examining partitioning info. The format of the files should be a text file, +// with each line containing a single entity within the vocabulary. +// +// For example, with `new_vocab_file` a text file containing each of the following +// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], +// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be +// `[0, -1, 2]`. +// +// The op also returns a count of how many entries in the new vocabulary +// were present in the old vocabulary, which is used to calculate the number of +// values to initialize in a weight matrix remapping +// +// This functionality can be used to remap both row vocabularies (typically, +// features) and column vocabularies (typically, classes) from TensorFlow +// checkpoints. Note that the partitioning logic relies on contiguous vocabularies +// corresponding to div-partitioned variables. Moreover, the underlying remapping +// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should +// use the corresponding index_table_from_file() as the FeatureColumn framework +// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). +// +// Arguments: +// new_vocab_file: Path to the new vocab file. +// old_vocab_file: Path to the old vocab file. +// new_vocab_offset: How many entries into the new vocab file to start reading. +// num_new_vocab: Number of entries in the new vocab file to remap. +// +// Returns A Tensor of length num_new_vocab where the element at index i +// is equal to the old ID that maps to the new ID i. This element is -1 for any +// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. 
+func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "GenerateVocabRemapping", + Input: []tf.Input{ + new_vocab_file, old_vocab_file, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// CudnnRNNV3Direction sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNV3Direction(value string) CudnnRNNV3Attr { - return func(m optionalAttr) { - m["direction"] = value - } -} +// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal. +type ParameterizedTruncatedNormalAttr func(optionalAttr) -// CudnnRNNV3Dropout sets the optional dropout attribute to value. +// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. // If not specified, defaults to 0 -func CudnnRNNV3Dropout(value float32) CudnnRNNV3Attr { - return func(m optionalAttr) { - m["dropout"] = value - } -} - -// CudnnRNNV3Seed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNV3Seed(value int64) CudnnRNNV3Attr { +func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr { return func(m optionalAttr) { m["seed"] = value } } -// CudnnRNNV3Seed2 sets the optional seed2 attribute to value. +// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. // If not specified, defaults to 0 -func CudnnRNNV3Seed2(value int64) CudnnRNNV3Attr { +func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr { return func(m optionalAttr) { m["seed2"] = value } } -// CudnnRNNV3IsTraining sets the optional is_training attribute to value. -// If not specified, defaults to true -func CudnnRNNV3IsTraining(value bool) CudnnRNNV3Attr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// A RNN backed by cuDNN. +// Outputs random values from a normal distribution. The parameters may each be a // -// Computes the RNN from the input and initial states, with respect to the params -// buffer. Accepts one extra input "sequence_lengths" than CudnnRNN. +// scalar which applies to the entire output, or a vector of length shape[0] which +// stores the parameters for each batch. // -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicates whether there is a linear projection between the input and -// the actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. Should be -// "unidirectional" or "bidirectional". -// dropout: Dropout probability. When set to 0., dropout is disabled. -// seed: The 1st part of a seed to initialize dropout. -// seed2: The 2nd part of a seed to initialize dropout. 
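// --- Editor's sketch (illustrative, not part of the patch) ------------------
// Graph construction for the GenerateVocabRemapping wrapper above, mirroring
// the example in its doc comment. Assumes the same imports as the Softmax
// sketch earlier (tf, op); "new_vocab.txt" and "old_vocab.txt" are hypothetical
// files containing [f0, f1, f2, f3] and [f1, f0, f3], one entry per line.
func exampleGenerateVocabRemapping(s *op.Scope) (remapping, numPresent tf.Output) {
	newVocab := op.Const(s.SubScope("new_vocab"), "new_vocab.txt")
	oldVocab := op.Const(s.SubScope("old_vocab"), "old_vocab.txt")
	// With new_vocab_offset=1 and num_new_vocab=3, running the graph should
	// yield remapping [0, -1, 2], per the documentation above.
	return op.GenerateVocabRemapping(s, newVocab, oldVocab, 1, 3,
		op.GenerateVocabRemappingOldVocabSize(-1))
}
// -----------------------------------------------------------------------------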
-// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, -// num_units]. -// input_c: For LSTM, a 3-D tensor with the shape of -// [num_layer * dir, batch, num_units]. For other models, it is ignored. -// params: A 1-D tensor that contains the weights and biases in an opaque layout. -// The size must be created through CudnnRNNParamsSize, and initialized -// separately. Note that they might not be compatible across different -// generations. So it is a good idea to save and restore -// sequence_lengths: a vector of lengths of each input sequence. -// output: A 3-D tensor with the shape of [seq_length, batch_size, -// dir * num_units]. -// output_h: The same shape has input_h. -// output_c: The same shape as input_c for LSTM. An empty tensor for other models. -// is_training: Indicates whether this operation is used for inferenece or -// training. -// reserve_space: An opaque tensor that can be used in backprop calculation. It -// is only produced if is_training is true. -func CudnnRNNV3(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, sequence_lengths tf.Output, optional ...CudnnRNNV3Attr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output, host_reserved tf.Output) { +// Arguments: +// shape: The shape of the output tensor. Batches are indexed by the 0th dimension. +// means: The mean parameter of each batch. +// stdevs: The standard deviation parameter of each batch. Must be greater than 0. +// minvals: The minimum cutoff. May be -infinity. +// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval +// for each batch. +// +// Returns A matrix of shape num_batches x samples_per_batch, filled with random +// truncated normal values using the parameters for each row. +func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -16278,9 +14910,828 @@ func CudnnRNNV3(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Out a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNNV3", + Type: "ParameterizedTruncatedNormal", Input: []tf.Input{ - input, input_h, input_c, params, sequence_lengths, + shape, means, stdevs, minvals, maxvals, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Outputs random integers from a uniform distribution. +// +// The generated values are uniform integers in the range `[minval, maxval)`. +// The lower bound `minval` is included in the range, while the upper bound +// `maxval` is excluded. +// +// The random integers are slightly biased unless `maxval - minval` is an exact +// power of two. The bias is small for values of `maxval - minval` significantly +// smaller than the range of the output (either `2^32` or `2^64`). +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// shape: The shape of the output tensor. +// minval: Minimum value (inclusive, scalar). +// maxval: Maximum value (exclusive, scalar). +// +// Returns Random values with specified shape. 
+func StatefulUniformInt(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StatefulUniformInt", + Input: []tf.Input{ + resource, algorithm, shape, minval, maxval, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatefulTruncatedNormalAttr is an optional argument to StatefulTruncatedNormal. +type StatefulTruncatedNormalAttr func(optionalAttr) + +// StatefulTruncatedNormalDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatefulTruncatedNormalDtype(value tf.DataType) StatefulTruncatedNormalAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random values from a truncated normal distribution. +// +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// shape: The shape of the output tensor. +// +// Returns Random values with specified shape. +func StatefulTruncatedNormal(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, optional ...StatefulTruncatedNormalAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatefulTruncatedNormal", + Input: []tf.Input{ + resource, algorithm, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DenseToSparseSetOperationAttr is an optional argument to DenseToSparseSetOperation. +type DenseToSparseSetOperationAttr func(optionalAttr) + +// DenseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func DenseToSparseSetOperationValidateIndices(value bool) DenseToSparseSetOperationAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Applies set operation along last dimension of `Tensor` and `SparseTensor`. +// +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// +// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, +// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same +// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set2` +// indices. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. +// +// Arguments: +// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. +// Dimension `n` contains values in a set, duplicates are allowed but ignored. +// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. 
+// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must +// be the same as the 1st `n-1` dimensions of `set1`, `result_shape[n]` is the +// max set size across `n-1` dimensions. +// +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func DenseToSparseSetOperation(scope *Scope, set1 tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...DenseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"set_operation": set_operation} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DenseToSparseSetOperation", + Input: []tf.Input{ + set1, set2_indices, set2_values, set2_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// SkipgramAttr is an optional argument to Skipgram. +type SkipgramAttr func(optionalAttr) + +// SkipgramWindowSize sets the optional window_size attribute to value. +// +// value: The number of words to predict to the left and right of the target. +// If not specified, defaults to 5 +func SkipgramWindowSize(value int64) SkipgramAttr { + return func(m optionalAttr) { + m["window_size"] = value + } +} + +// SkipgramMinCount sets the optional min_count attribute to value. +// +// value: The minimum number of word occurrences for it to be included in the +// vocabulary. +// If not specified, defaults to 5 +func SkipgramMinCount(value int64) SkipgramAttr { + return func(m optionalAttr) { + m["min_count"] = value + } +} + +// SkipgramSubsample sets the optional subsample attribute to value. +// +// value: Threshold for word occurrence. Words that appear with higher +// frequency will be randomly down-sampled. Set to 0 to disable. +// If not specified, defaults to 0.001 +func SkipgramSubsample(value float32) SkipgramAttr { + return func(m optionalAttr) { + m["subsample"] = value + } +} + +// Parses a text file and creates a batch of examples. +// +// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result +// +// Arguments: +// filename: The corpus's text file name. +// batch_size: The size of produced batch. +// +// Returns A vector of words in the corpus.Frequencies of words. Sorted in the non-ascending order.Number of words per epoch in the data file.The current epoch number.The total number of words processed so far.A vector of word ids.A vector of word ids. +func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Skipgram", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) +} + +// Deserialize `SparseTensor` objects. 
+// +// The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where +// the last dimension stores serialized `SparseTensor` objects and the other N +// dimensions (N >= 0) correspond to a batch. The ranks of the original +// `SparseTensor` objects must all match. When the final `SparseTensor` is +// created, its rank is the rank of the incoming `SparseTensor` objects plus N; +// the sparse tensors have been concatenated along new dimensions, one for each +// batch. +// +// The output `SparseTensor` object's shape values for the original dimensions +// are the max across the input `SparseTensor` objects' shape values for the +// corresponding dimensions. The new dimensions match the size of the batch. +// +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. +// +// For example, if the serialized input is a `[2 x 3]` matrix representing two +// original `SparseTensor` objects: +// +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// +// and +// +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// +// then the final deserialized `SparseTensor` will be: +// +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// +// Arguments: +// serialized_sparse: The serialized `SparseTensor` objects. The last dimension +// must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` objects. +func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "DeserializeSparse", + Input: []tf.Input{ + serialized_sparse, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// RegexReplaceAttr is an optional argument to RegexReplace. +type RegexReplaceAttr func(optionalAttr) + +// RegexReplaceReplaceGlobal sets the optional replace_global attribute to value. +// +// value: If True, the replacement is global (that is, all matches of the `pattern` regular +// expression in each input string are rewritten), otherwise the `rewrite` +// substitution is only made for the first `pattern` match. +// If not specified, defaults to true +func RegexReplaceReplaceGlobal(value bool) RegexReplaceAttr { + return func(m optionalAttr) { + m["replace_global"] = value + } +} + +// Replaces matches of the `pattern` regular expression in `input` with the +// replacement string provided in `rewrite`. +// +// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// +// Arguments: +// input: The text to be processed. +// pattern: The regular expression to be matched in the `input` strings. +// rewrite: The rewrite string to be substituted for the `pattern` expression where it is +// matched in the `input` strings. +// +// Returns The text after applying pattern match and rewrite substitution. 
+func RegexReplace(scope *Scope, input tf.Output, pattern tf.Output, rewrite tf.Output, optional ...RegexReplaceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RegexReplace", + Input: []tf.Input{ + input, pattern, rewrite, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Says whether the targets are in the top `K` predictions. +// +// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +// prediction for the target class is among the top `k` predictions among +// all predictions for example `i`. Note that the behavior of `InTopK` differs +// from the `TopK` op in its handling of ties; if multiple classes have the +// same prediction value and straddle the top-`k` boundary, all of those +// classes are considered to be in the top `k`. +// +// More formally, let +// +// \\(predictions_i\\) be the predictions for all classes for example `i`, +// \\(targets_i\\) be the target class for example `i`, +// \\(out_i\\) be the output for example `i`, +// +// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ +// +// Arguments: +// predictions: A `batch_size` x `classes` tensor. +// targets: A `batch_size` vector of class ids. +// k: Number of top elements to look at for computing precision. +// +// Returns Computed precision at `k` as a `bool Tensor`. +func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InTopKV2", + Input: []tf.Input{ + predictions, targets, k, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Exits the current frame to its parent frame. +// +// Exit makes its input `data` available to the parent frame. +// +// Arguments: +// data: The tensor to be made available to the parent frame. +// +// Returns The same tensor as `data`. +func Exit(scope *Scope, data tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Exit", + Input: []tf.Input{ + data, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// The gradient operator for the SparseAdd op. +// +// The SparseAdd op calculates A + B, where A, B, and the sum are all represented +// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t. +// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty +// values of A and B. +// +// Arguments: +// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to +// the non-empty values of the sum. +// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. +// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. +// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size +// `[nnz(sum), ndims]`. +// +// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the +// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the +// non-empty values of B. 
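// --- Editor's sketch (illustrative, not part of the patch) ------------------
// Wiring the RegexReplace wrapper above with constant inputs. Assumes the same
// imports as the Softmax sketch earlier; the RE2 pattern and strings are
// arbitrary examples.
func exampleRegexReplace(s *op.Scope) tf.Output {
	text := op.Const(s.SubScope("text"), []string{"version 1.13", "version 2.0"})
	pattern := op.Const(s.SubScope("pattern"), "[0-9]+\\.[0-9]+")
	rewrite := op.Const(s.SubScope("rewrite"), "X.Y")
	// replace_global defaults to true, so every match in each element is rewritten.
	return op.RegexReplace(s, text, pattern, rewrite,
		op.RegexReplaceReplaceGlobal(true))
}
// -----------------------------------------------------------------------------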
+func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseAddGrad", + Input: []tf.Input{ + backprop_val_grad, a_indices, b_indices, sum_indices, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// RetrieveTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to RetrieveTPUEmbeddingMDLAdagradLightParameters. +type RetrieveTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingMDLAdagradLightParametersTableId(value int64) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingMDLAdagradLightParametersTableName(value string) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve MDL Adagrad Light embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the MDL Adagrad Light optimization algorithm.Parameter accumulators updated by the MDL Adagrad Light optimization algorithm.Parameter weights updated by the MDL Adagrad Light optimization algorithm.Parameter benefits updated by the MDL Adagrad Light optimization algorithm. +func RetrieveTPUEmbeddingMDLAdagradLightParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMDLAdagradLightParametersAttr) (parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingMDLAdagradLightParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// Fast Fourier transform. +// +// Computes the 1-dimensional discrete Fourier transform over the inner-most +// dimension of `input`. +// +// Arguments: +// input: A complex tensor. +// +// Returns A complex tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft +// @end_compatibility +func FFT(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// BoostedTreesCreateQuantileStreamResourceAttr is an optional argument to BoostedTreesCreateQuantileStreamResource. 
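// --- Editor's sketch (illustrative, not part of the patch) ------------------
// Building a complex input for the FFT wrapper above from two real tensors.
// Assumes the same imports as the Softmax sketch earlier, and that the Complex
// wrapper generated elsewhere in this file is available.
func exampleFFT(s *op.Scope) tf.Output {
	re := op.Const(s.SubScope("re"), []float32{1, 0, 0, 0}) // a unit impulse
	im := op.Const(s.SubScope("im"), []float32{0, 0, 0, 0})
	// The DFT of a unit impulse is flat: every output bin should be 1+0i.
	return op.FFT(s, op.Complex(s, re, im))
}
// -----------------------------------------------------------------------------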
+type BoostedTreesCreateQuantileStreamResourceAttr func(optionalAttr) + +// BoostedTreesCreateQuantileStreamResourceMaxElements sets the optional max_elements attribute to value. +// +// value: int; The maximum number of data points that can be fed to the stream. +// If not specified, defaults to 1099511627776 +func BoostedTreesCreateQuantileStreamResourceMaxElements(value int64) BoostedTreesCreateQuantileStreamResourceAttr { + return func(m optionalAttr) { + m["max_elements"] = value + } +} + +// Create the Resource for Quantile Streams. +// +// Arguments: +// quantile_stream_resource_handle: resource; Handle to quantile stream resource. +// epsilon: float; The required approximation error of the stream resource. +// num_streams: int; The number of streams managed by the resource that shares the same epsilon. +// +// Returns the created operation. +func BoostedTreesCreateQuantileStreamResource(scope *Scope, quantile_stream_resource_handle tf.Output, epsilon tf.Output, num_streams tf.Output, optional ...BoostedTreesCreateQuantileStreamResourceAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "BoostedTreesCreateQuantileStreamResource", + Input: []tf.Input{ + quantile_stream_resource_handle, epsilon, num_streams, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. +type Conv3DBackpropFilterV2Attr func(optionalAttr) + +// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv3DBackpropFilterV2Dilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of 3-D convolution with respect to the filter. +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 5-D +// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` +// tensor. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. 
+func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv3DBackpropFilterV2", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatefulStandardNormalAttr is an optional argument to StatefulStandardNormal. +type StatefulStandardNormalAttr func(optionalAttr) + +// StatefulStandardNormalDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatefulStandardNormalDtype(value tf.DataType) StatefulStandardNormalAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random values from a normal distribution. This op is deprecated in favor of op 'StatefulStandardNormalV2' +// +// DEPRECATED at GraphDef version 29: Use StatefulStandardNormalV2 instead +// +// The generated values will have mean 0 and standard deviation 1. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// shape: The shape of the output tensor. +// +// Returns A tensor of the specified shape filled with random normal values. +func StatefulStandardNormal(scope *Scope, resource tf.Output, shape tf.Output, optional ...StatefulStandardNormalAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatefulStandardNormal", + Input: []tf.Input{ + resource, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the gradient for the tanh of `x` wrt its input. +// +// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` +// is the corresponding input gradient. +func TanhGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TanhGrad", + Input: []tf.Input{ + y, dy, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Sets the index-th position of the list to contain the given tensor. +// +// input_handle: the list +// index: the position in the list to which the tensor will be assigned +// item: the element to be assigned to that position +// output_handle: the new list, with the element in the proper position +// +func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, item tf.Output) (output_handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorListSetItem", + Input: []tf.Input{ + input_handle, index, item, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeCSVAttr is an optional argument to DecodeCSV. +type DecodeCSVAttr func(optionalAttr) + +// DecodeCSVFieldDelim sets the optional field_delim attribute to value. +// +// value: char delimiter to separate fields in a record. +// If not specified, defaults to "," +func DecodeCSVFieldDelim(value string) DecodeCSVAttr { + return func(m optionalAttr) { + m["field_delim"] = value + } +} + +// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value. 
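// --- Editor's sketch (illustrative, not part of the patch) ------------------
// The TanhGrad wrapper above takes y = tanh(x) (not x itself) plus the upstream
// gradient dy. Assumes the same imports as the Softmax sketch earlier, and that
// the Tanh wrapper generated elsewhere in this file is available.
func exampleTanhGrad(s *op.Scope) tf.Output {
	x := op.Const(s.SubScope("x"), []float32{-1, 0, 1})
	y := op.Tanh(s, x) // forward activation; TanhGrad consumes y, not x
	dy := op.Const(s.SubScope("dy"), []float32{1, 1, 1})
	// Result is dy * (1 - y*y); at x = 0 the gradient is exactly 1.
	return op.TanhGrad(s, y, dy)
}
// -----------------------------------------------------------------------------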
+// +// value: If false, treats double quotation marks as regular +// characters inside of the string fields (ignoring RFC 4180, Section 2, +// Bullet 5). +// If not specified, defaults to true +func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr { + return func(m optionalAttr) { + m["use_quote_delim"] = value + } +} + +// DecodeCSVNaValue sets the optional na_value attribute to value. +// +// value: Additional string to recognize as NA/NaN. +// If not specified, defaults to "" +func DecodeCSVNaValue(value string) DecodeCSVAttr { + return func(m optionalAttr) { + m["na_value"] = value + } +} + +// DecodeCSVSelectCols sets the optional select_cols attribute to value. +// If not specified, defaults to <> +func DecodeCSVSelectCols(value []int64) DecodeCSVAttr { + return func(m optionalAttr) { + m["select_cols"] = value + } +} + +// Convert CSV records to tensors. Each column maps to one tensor. +// +// RFC 4180 format is expected for the CSV records. +// (https://tools.ietf.org/html/rfc4180) +// Note that we allow leading and trailing spaces with int or float field. +// +// Arguments: +// records: Each string is a record/row in the csv and all records should have +// the same format. +// record_defaults: One tensor per column of the input record, with either a +// scalar default value for that column or an empty vector if the column is +// required. +// +// Returns Each tensor will have the same shape as records. +func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeCSV", + Input: []tf.Input{ + records, tf.OutputList(record_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("DecodeCSV", err) + return + } + return output +} + +// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. +type FusedBatchNormGradV2Attr func(optionalAttr) + +// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["epsilon"] = value + } +} + +// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. +// +// value: The data format for y_backprop, x, x_backprop. +// Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// Gradient for batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// y_backprop: A 4D Tensor for the gradient with respect to y. 
+// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch +// mean to be reused in gradient computation. When is_training is +// False, a 1D Tensor for the population mean to be reused in both +// 1st and 2nd order gradient computation. +// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch +// variance (inverted variance in the cuDNN case) to be reused in +// gradient computation. When is_training is False, a 1D Tensor +// for the population variance to be reused in both 1st and 2nd +// order gradient computation. +// +// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input +// in FusedBatchNorm. +func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedBatchNormGradV2", + Input: []tf.Input{ + y_backprop, x, scale, reserve_space_1, reserve_space_2, }, Attrs: attrs, } @@ -16288,152 +15739,198 @@ func CudnnRNNV3(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Out return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Applies softmax to a batched N-D `SparseTensor`. +// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. +type DestroyResourceOpAttr func(optionalAttr) + +// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. // -// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` -// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. -// -// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost -// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly -// zero elements do not participate*. Specifically, the algorithm is equivalent -// to the following: -// -// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix -// with shape `[B, C]`, along the size-C dimension; -// (2) Masks out the original implicitly-zero locations; -// (3) Renormalizes the remaining elements. -// -// Hence, the `SparseTensor` result has exactly the same non-zero indices and -// shape. -// -// Arguments: -// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a -// SparseTensor, in canonical ordering. -// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// -// Returns 1-D. The `NNZ` values for the result `SparseTensor`. -func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { - if scope.Err() != nil { - return +// value: whether to ignore the error when the resource +// doesn't exist. 
+// If not specified, defaults to true +func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { + return func(m optionalAttr) { + m["ignore_lookup_error"] = value } - opspec := tf.OpSpec{ - Type: "SparseSoftmax", - Input: []tf.Input{ - sp_indices, sp_values, sp_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Creates a Tensor by indexing into the TensorList. +// Deletes the resource specified by the handle. // -// Each row in the produced Tensor corresponds to the element in the TensorList -// specified by the given index (see `tf.gather`). +// All subsequent operations using the resource will result in a NotFound +// error status. // -// input_handle: The input tensor list. -// indices: The indices used to index into the list. -// values: The tensor. -func TensorListGather(scope *Scope, input_handle tf.Output, indices tf.Output, element_shape tf.Output, element_dtype tf.DataType) (values tf.Output) { +// Arguments: +// resource: handle to the resource to delete. +// +// Returns the created operation. +func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorListGather", + Type: "DestroyResourceOp", Input: []tf.Input{ - input_handle, indices, element_shape, + resource, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. +type SparseToSparseSetOperationAttr func(optionalAttr) + +// SparseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func SparseToSparseSetOperationValidateIndices(value bool) SparseToSparseSetOperationAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Applies set operation along last dimension of 2 `SparseTensor` inputs. +// +// See SetOperationOp::SetOperationFromContext for values of `set_operation`. +// +// If `validate_indices` is `True`, `SparseToSparseSetOperation` validates the +// order and range of `set1` and `set2` indices. +// +// Input `set1` is a `SparseTensor` represented by `set1_indices`, `set1_values`, +// and `set1_shape`. For `set1` ranked `n`, 1st `n-1` dimensions must be the same +// as `set2`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, +// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same +// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but +// ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set1` +// and `set2` indices. +// +// Output `result` is a `SparseTensor` represented by `result_indices`, +// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this +// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` +// dimension contains the result of `set_operation` applied to the corresponding +// `[0...n-1]` dimension of `set`. +// +// Arguments: +// set1_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set1_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. 
+// set1_shape: 1D `Tensor`, shape of a `SparseTensor`. `set1_shape[0...n-1]` must +// be the same as `set2_shape[0...n-1]`, `set1_shape[n]` is the +// max set size across `0...n-1` dimensions. +// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major +// order. +// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major +// order. +// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must +// be the same as `set1_shape[0...n-1]`, `set2_shape[n]` is the +// max set size across `0...n-1` dimensions. +// +// +// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is +// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` +// is the max result set size across all `0...n-1` dimensions. +func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_values tf.Output, set1_shape tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...SparseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"set_operation": set_operation} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseToSparseSetOperation", + Input: []tf.Input{ + set1_indices, set1_values, set1_shape, set2_indices, set2_values, set2_shape, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. -type FixedLengthRecordReaderV2Attr func(optionalAttr) +// PaddingFIFOQueueV2Attr is an optional argument to PaddingFIFOQueueV2. +type PaddingFIFOQueueV2Attr func(optionalAttr) -// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. +// PaddingFIFOQueueV2Shapes sets the optional shapes attribute to value. // -// value: Number of bytes in the header, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. +// Shapes of fixed rank but variable size are allowed by setting +// any shape dimension to -1. In this case, the inputs' shape may vary along +// the given dimension, and DequeueMany will pad the given dimension with +// zeros up to the maximum shape of all elements in the given batch. +// If the length of this attr is 0, different queue elements may have +// different ranks and shapes, but only one element may be dequeued at a time. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func PaddingFIFOQueueV2Shapes(value []tf.Shape) PaddingFIFOQueueV2Attr { return func(m optionalAttr) { - m["header_bytes"] = value + m["shapes"] = value } } -// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. +// PaddingFIFOQueueV2Capacity sets the optional capacity attribute to value. // -// value: Number of bytes in the footer, defaults to 0. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. 
+// If not specified, defaults to -1 +func PaddingFIFOQueueV2Capacity(value int64) PaddingFIFOQueueV2Attr { return func(m optionalAttr) { - m["footer_bytes"] = value + m["capacity"] = value } } -// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. +// PaddingFIFOQueueV2Container sets the optional container attribute to value. // -// value: Number of bytes to hop before each read. Default of 0 means using -// record_bytes. -// If not specified, defaults to 0 -func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["hop_bytes"] = value - } -} - -// FixedLengthRecordReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. +// value: If non-empty, this queue is placed in the given container. // Otherwise, a default container is used. // If not specified, defaults to "" -func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { +func PaddingFIFOQueueV2Container(value string) PaddingFIFOQueueV2Attr { return func(m optionalAttr) { m["container"] = value } } -// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. +// PaddingFIFOQueueV2SharedName sets the optional shared_name attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. // If not specified, defaults to "" -func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { +func PaddingFIFOQueueV2SharedName(value string) PaddingFIFOQueueV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. +// A queue that produces elements in first-in first-out order. // -// value: The type of encoding for the file. Currently ZLIB and GZIP -// are supported. Defaults to none. -// If not specified, defaults to "" -func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { - return func(m optionalAttr) { - m["encoding"] = value - } -} - -// A Reader that outputs fixed-length records from a file. +// Variable-size shapes are allowed by setting the corresponding shape dimensions +// to 0 in the shape attr. In this case DequeueMany will pad up to the maximum +// size of any given element in the minibatch. See below for details. // // Arguments: -// record_bytes: Number of bytes in the record. +// component_types: The type of each component in a value. // -// Returns The handle to reference the Reader. -func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { +// Returns The handle to the queue. 
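+//
+// Illustrative sketch (not generated from the op registry): one way to build a
+// padding queue for variable-length float vectors, assuming a Scope `s` from
+// op.NewScope() and the tf package of the Go bindings:
+//
+//    q := PaddingFIFOQueueV2(s, []tf.DataType{tf.Float},
+//        PaddingFIFOQueueV2Shapes([]tf.Shape{tf.MakeShape(-1)}),
+//        PaddingFIFOQueueV2Capacity(32))
+//    // q is the queue handle; dequeuing many elements from this queue pads
+//    // each batch with zeros up to the longest vector in that batch.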
+func PaddingFIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...PaddingFIFOQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"record_bytes": record_bytes} + attrs := map[string]interface{}{"component_types": component_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FixedLengthRecordReaderV2", + Type: "PaddingFIFOQueueV2", Attrs: attrs, } @@ -16441,37 +15938,102 @@ func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...Fix return op.Output(0) } -// CompilationResultProto indicating the status of the TPU compilation. -func TPUCompilationResult(scope *Scope) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TPUCompilationResult", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. +// Decodes a `variant` Tensor into a `RaggedTensor`. +// +// Decodes the given `variant` Tensor and returns a `RaggedTensor`. The input +// could be a scalar, meaning it encodes a single `RaggedTensor` with ragged_rank +// `output_ragged_rank`. It could also have an arbitrary rank, in which case each +// element is decoded into a `RaggedTensor` with ragged_rank `input_ragged_rank` +// and these are then stacked according to the input shape to output a single +// `RaggedTensor` with ragged_rank `output_ragged_rank`. Each `variant` element in +// the input Tensor is decoded by retrieving from the element a 1-D `variant` +// Tensor with `input_ragged_rank + 1` Tensors, corresponding to the splits and +// values of the decoded `RaggedTensor`. If `input_ragged_rank` is -1, then it is +// inferred as `output_ragged_rank` - `rank(encoded_ragged)`. See +// `RaggedTensorToVariant` for the corresponding encoding logic. +// // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. +// encoded_ragged: A `variant` Tensor containing encoded `RaggedTensor`s. +// input_ragged_rank: The ragged rank of each encoded `RaggedTensor` component in the input. If set to +// -1, this is inferred as `output_ragged_rank` - `rank(encoded_ragged)` +// output_ragged_rank: The expected ragged rank of the output `RaggedTensor`. The following must hold: +// `output_ragged_rank = rank(encoded_ragged) + input_ragged_rank`. // -// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest -// layer. -func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) { +// +// +// Returns A list of one or more Tensors representing the splits of the output +// `RaggedTensor`.A Tensor representing the values of the output `RaggedTensor`. 
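+//
+// Illustrative sketch: decoding a rank-1 ragged tensor of int64 values, where
+// `encoded` is assumed to be a scalar variant produced by the matching
+// RaggedTensorToVariant wrapper on the same scope `s`:
+//
+//    splits, values := RaggedTensorFromVariant(s, encoded, 1, 1, tf.Int64, tf.Int64)
+//    // splits holds one row-splits tensor per ragged dimension; values holds
+//    // the flat values tensor.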
+func RaggedTensorFromVariant(scope *Scope, encoded_ragged tf.Output, input_ragged_rank int64, output_ragged_rank int64, Tvalues tf.DataType, Tsplits tf.DataType) (output_nested_splits []tf.Output, output_dense_values tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"input_ragged_rank": input_ragged_rank, "output_ragged_rank": output_ragged_rank, "Tvalues": Tvalues, "Tsplits": Tsplits} opspec := tf.OpSpec{ - Type: "BoostedTreesGetEnsembleStates", + Type: "RaggedTensorFromVariant", Input: []tf.Input{ - tree_ensemble_handle, + encoded_ragged, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + if scope.Err() != nil { + return + } + var idx int + var err error + if output_nested_splits, idx, err = makeOutputList(op, idx, "output_nested_splits"); err != nil { + scope.UpdateErr("RaggedTensorFromVariant", err) + return + } + output_dense_values = op.Output(idx) + return output_nested_splits, output_dense_values +} + +// RetrieveTPUEmbeddingADAMParametersAttr is an optional argument to RetrieveTPUEmbeddingADAMParameters. +type RetrieveTPUEmbeddingADAMParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingADAMParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingADAMParametersTableId(value int64) RetrieveTPUEmbeddingADAMParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingADAMParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingADAMParametersTableName(value string) RetrieveTPUEmbeddingADAMParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve ADAM embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the ADAM optimization algorithm.Parameter momenta updated by the ADAM optimization algorithm.Parameter velocities updated by the ADAM optimization algorithm. +func RetrieveTPUEmbeddingADAMParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingADAMParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } // ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign. @@ -16523,87 +16085,28 @@ func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Out return scope.AddOperation(opspec) } -// Deprecated. 
Use TensorArraySplitV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 -func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArraySplitV2", - Input: []tf.Input{ - handle, value, lengths, flow_in, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reshapes a SparseTensor to represent values in a new dense shape. -// -// This operation has the same semantics as reshape on the represented dense -// tensor. The `input_indices` are recomputed based on the requested `new_shape`. -// -// If one component of `new_shape` is the special value -1, the size of that -// dimension is computed so that the total dense size remains constant. At -// most one component of `new_shape` can be -1. The number of dense elements -// implied by `new_shape` must be the same as the number of dense elements -// originally implied by `input_shape`. -// -// Reshaping does not affect the order of values in the SparseTensor. -// -// If the input tensor has rank `R_in` and `N` non-empty values, and `new_shape` -// has length `R_out`, then `input_indices` has shape `[N, R_in]`, -// `input_shape` has length `R_in`, `output_indices` has shape `[N, R_out]`, and -// `output_shape` has length `R_out`. -// -// Arguments: -// input_indices: 2-D. `N x R_in` matrix with the indices of non-empty values in a -// SparseTensor. -// input_shape: 1-D. `R_in` vector with the input SparseTensor's dense shape. -// new_shape: 1-D. `R_out` vector with the requested new dense shape. -// -// Returns 2-D. `N x R_out` matrix with the updated indices of non-empty -// values in the output SparseTensor.1-D. `R_out` vector with the full dense shape of the output -// SparseTensor. This is the same as `new_shape` but with any -1 dimensions -// filled in. -func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, new_shape tf.Output) (output_indices tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseReshape", - Input: []tf.Input{ - input_indices, input_shape, new_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Computes the product along segments of a tensor. +// Computes the sum along segments of a tensor. // // Read // [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) // for an explanation of segments. // // Computes a tensor such that -// \\(output_i = \prod_j data_j\\) where the product is over `j` such +// \\(output_i = \sum_j data_j\\) where sum is over `j` such // that `segment_ids[j] == i`. // -// If the product is empty for a given segment ID `i`, `output[i] = 1`. +// If the sum is empty for a given segment ID `i`, `output[i] = 0`. // //
-// +// //
// // For example: // // ``` // c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) -// tf.segment_prod(c, tf.constant([0, 0, 1])) -// # ==> [[4, 6, 6, 4], +// tf.segment_sum(c, tf.constant([0, 0, 1])) +// # ==> [[5, 5, 5, 5], // # [5, 6, 7, 8]] // ``` // @@ -16615,12 +16118,12 @@ func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, // // Returns Has same shape as data, except for dimension 0 which // has size `k`, the number of segments. -func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { +func SegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SegmentProd", + Type: "SegmentSum", Input: []tf.Input{ data, segment_ids, }, @@ -16629,36 +16132,121 @@ func SegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf return op.Output(0) } -// RetrieveTPUEmbeddingFTRLParametersAttr is an optional argument to RetrieveTPUEmbeddingFTRLParameters. -type RetrieveTPUEmbeddingFTRLParametersAttr func(optionalAttr) +// Interleave the values from the `data` tensors into a single tensor. +// +// Builds a merged tensor such that +// +// ```python +// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] +// ``` +// +// For example, if each `indices[m]` is scalar or vector, we have +// +// ```python +// # Scalar indices: +// merged[indices[m], ...] = data[m][...] +// +// # Vector indices: +// merged[indices[m][i], ...] = data[m][i, ...] +// ``` +// +// Each `data[i].shape` must start with the corresponding `indices[i].shape`, +// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +// must have `data[i].shape = indices[i].shape + constant`. In terms of this +// `constant`, the output shape is +// +// merged.shape = [max(indices)] + constant +// +// Values may be merged in parallel, so if an index appears in both `indices[m][i]` +// and `indices[n][j]`, the result may be invalid. This differs from the normal +// DynamicStitch operator that defines the behavior in that case. +// +// For example: +// +// ```python +// indices[0] = 6 +// indices[1] = [4, 1] +// indices[2] = [[5, 2], [0, 3]] +// data[0] = [61, 62] +// data[1] = [[41, 42], [11, 12]] +// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] +// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], +// [51, 52], [61, 62]] +// ``` +// +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: +// +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). +// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. +// ``` +// +//
+// +//
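+//
+// A minimal Go sketch of the same pattern, assuming a scope `s` (in practice
+// the indices and data usually come from DynamicPartition):
+//
+//    indices := []tf.Output{Const(s, int32(0)), Const(s, []int32{1, 2})}
+//    data := []tf.Output{Const(s, int64(10)), Const(s, []int64{20, 30})}
+//    merged := ParallelDynamicStitch(s, indices, data)
+//    // merged evaluates to [10, 20, 30]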
+func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ParallelDynamicStitch", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(data), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// RetrieveTPUEmbeddingFTRLParametersTableId sets the optional table_id attribute to value. +// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug. +type LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 // // REQUIRES: value >= -1 -func RetrieveTPUEmbeddingFTRLParametersTableId(value int64) RetrieveTPUEmbeddingFTRLParametersAttr { +func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_id"] = value } } -// RetrieveTPUEmbeddingFTRLParametersTableName sets the optional table_name attribute to value. +// LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. // If not specified, defaults to "" -func RetrieveTPUEmbeddingFTRLParametersTableName(value string) RetrieveTPUEmbeddingFTRLParametersAttr { +func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { return func(m optionalAttr) { m["table_name"] = value } } -// Retrieve FTRL embedding parameters. +// Load proximal Adagrad embedding parameters with debug support. // -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // -// Returns Parameter parameters updated by the FTRL optimization algorithm.Parameter accumulators updated by the FTRL optimization algorithm.Parameter linears updated by the FTRL optimization algorithm. -func RetrieveTPUEmbeddingFTRLParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output) { +// Arguments: +// parameters: Value of parameters used in the proximal Adagrad optimization algorithm. +// accumulators: Value of accumulators used in the proximal Adagrad optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the proximal Adagrad optimization algorithm. +// +// +// +// Returns the created operation. 
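+//
+// Illustrative sketch (names are placeholders): pushing restored values to a
+// single-shard configuration, assuming params, accums and gradAccums are
+// float tensors already built on scope `s`:
+//
+//    LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug(s,
+//        params, accums, gradAccums, 1, 0,
+//        LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName("table_0"))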
+func LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -16667,40 +16255,3501 @@ func RetrieveTPUEmbeddingFTRLParameters(scope *Scope, num_shards int64, shard_id a(attrs) } opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingFTRLParameters", + Type: "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug", + Input: []tf.Input{ + parameters, accumulators, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} +// LoadTPUEmbeddingAdagradParametersAttr is an optional argument to LoadTPUEmbeddingAdagradParameters. +type LoadTPUEmbeddingAdagradParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingAdagradParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingAdagradParametersTableId(value int64) LoadTPUEmbeddingAdagradParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingAdagradParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingAdagradParametersTableName(value string) LoadTPUEmbeddingAdagradParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load Adagrad embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the Adagrad optimization algorithm. +// accumulators: Value of accumulators used in the Adagrad optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingAdagradParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdagradParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingAdagradParameters", + Input: []tf.Input{ + parameters, accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// LoadTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to LoadTPUEmbeddingMDLAdagradLightParameters. +type LoadTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingMDLAdagradLightParametersTableId(value int64) LoadTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func LoadTPUEmbeddingMDLAdagradLightParametersTableName(value string) LoadTPUEmbeddingMDLAdagradLightParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load MDL Adagrad Light embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the MDL Adagrad Light optimization algorithm. +// accumulators: Value of accumulators used in the MDL Adagrad Light optimization algorithm. +// weights: Value of weights used in the MDL Adagrad Light optimization algorithm. +// benefits: Value of benefits used in the MDL Adagrad Light optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingMDLAdagradLightParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMDLAdagradLightParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingMDLAdagradLightParameters", + Input: []tf.Input{ + parameters, accumulators, weights, benefits, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingFTRLParametersGradAccumDebug. +type LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingFTRLParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingFTRLParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load FTRL embedding parameters with debug support. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the FTRL optimization algorithm. +// accumulators: Value of accumulators used in the FTRL optimization algorithm. +// linears: Value of linears used in the FTRL optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the FTRL optimization algorithm. +// +// +// +// Returns the created operation. 
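+//
+// Illustrative sketch (placeholder names): as with the other loaders, the call
+// site supplies the state tensors plus the shard layout:
+//
+//    LoadTPUEmbeddingFTRLParametersGradAccumDebug(s,
+//        params, accums, linears, gradAccums, 1, 0)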
+func LoadTPUEmbeddingFTRLParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, linears tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingFTRLParametersGradAccumDebug", + Input: []tf.Input{ + parameters, accumulators, linears, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Computes the maximum along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Computes a tensor such that +// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such +// that `segment_ids[j] == i`. +// +// If the max is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +//
+// +// For example: +// +// ``` +// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_max(c, tf.constant([0, 0, 1])) +// # ==> [[4, 3, 3, 4], +// # [5, 6, 7, 8]] +// ``` +// +// +// Arguments: +// +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SegmentMax", + Input: []tf.Input{ + data, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingStochasticGradientDescentParametersAttr is an optional argument to LoadTPUEmbeddingStochasticGradientDescentParameters. +type LoadTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) LoadTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingStochasticGradientDescentParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingStochasticGradientDescentParametersTableName(value string) LoadTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load SGD embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the stochastic gradient descent optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, parameters tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingStochasticGradientDescentParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingStochasticGradientDescentParameters", + Input: []tf.Input{ + parameters, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Conv3DBackpropInputV2Attr is an optional argument to Conv3DBackpropInputV2. +type Conv3DBackpropInputV2Attr func(optionalAttr) + +// Conv3DBackpropInputV2DataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. 
+// If not specified, defaults to "NDHWC" +func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv3DBackpropInputV2Dilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 5. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of 3-D convolution with respect to the input. +// +// Arguments: +// input_sizes: An integer vector representing the tensor shape of `input`, +// where `input` is a 5-D +// `[batch, depth, rows, cols, in_channels]` tensor. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropInputV2(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv3DBackpropInputV2", + Input: []tf.Input{ + input_sizes, filter, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns a copy of the input tensor. +func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Snapshot", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a TensorList by indexing into a Tensor. +// +// Each member of the TensorList corresponds to one row of the input tensor, +// specified by the given index (see `tf.gather`). +// +// tensor: The input tensor. +// indices: The indices used to index into the list. +// element_shape: The shape of the elements in the list (can be less specified than +// the shape of the tensor). +// num_elements: The size of the output list. Must be large enough to accommodate +// the largest index in indices. If -1, the list is just large enough to include +// the largest index in indices. +// output_handle: The TensorList. +func TensorListScatterV2(scope *Scope, tensor tf.Output, indices tf.Output, element_shape tf.Output, num_elements tf.Output) (output_handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorListScatterV2", + Input: []tf.Input{ + tensor, indices, element_shape, num_elements, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Multiplies sparse updates into the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] *= updates[...] 
+// +// # Vector indices (for each i) +// ref[indices[i], ...] *= updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions multiply. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterMul(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterMul", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// PrefetchDatasetAttr is an optional argument to PrefetchDataset. +type PrefetchDatasetAttr func(optionalAttr) + +// PrefetchDatasetSlackPeriod sets the optional slack_period attribute to value. +// If not specified, defaults to 0 +func PrefetchDatasetSlackPeriod(value int64) PrefetchDatasetAttr { + return func(m optionalAttr) { + m["slack_period"] = value + } +} + +// Creates a dataset that asynchronously prefetches elements from `input_dataset`. +// +// Arguments: +// +// buffer_size: The maximum number of elements to buffer in an iterator over +// this dataset. +// +// +func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...PrefetchDatasetAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "PrefetchDataset", + Input: []tf.Input{ + input_dataset, buffer_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. +// +// DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices, with the same constraints as the single matrix +// SelfAdjointEig. +// +// The result is a [..., M+1, M] matrix with [..., 0,:] containing the +// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. The eigenvalues +// are sorted in non-decreasing order. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M+1, M]`. +func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SelfAdjointEig", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. +type WholeFileReaderV2Attr func(optionalAttr) + +// WholeFileReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A Reader that outputs the entire contents of a file as a value. 
+// +// To use, enqueue filenames in a Queue. The output of ReaderRead will +// be a filename (key) and the contents of that file (value). +// +// Returns The handle to reference the Reader. +func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "WholeFileReaderV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug. +type RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve RMSProp embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the RMSProp optimization algorithm.Parameter ms updated by the RMSProp optimization algorithm.Parameter mom updated by the RMSProp optimization algorithm.Parameter gradient_accumulators updated by the RMSProp optimization algorithm. +func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr) (parameters tf.Output, ms tf.Output, mom tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. +// +// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the +// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each +// input channel is processed independently of the others with its own structuring +// function. The `output` tensor has shape +// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output +// tensor depend on the `padding` algorithm. We currently only support the default +// "NHWC" `data_format`. 
+// +// In detail, the grayscale morphological 2-D dilation is the max-sum correlation +// (for consistency with `conv2d`, we use unmirrored filters): +// +// output[b, y, x, c] = +// max_{dy, dx} input[b, +// strides[1] * y + rates[1] * dy, +// strides[2] * x + rates[2] * dx, +// c] + +// filter[dy, dx, c] +// +// Max-pooling is a special case when the filter has size equal to the pooling +// kernel size and contains all zeros. +// +// Note on duality: The dilation of `input` by the `filter` is equal to the +// negation of the erosion of `-input` by the reflected `filter`. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// strides: The stride of the sliding window for each dimension of the input +// tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: The input stride for atrous morphological dilation. Must be: +// `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape `[batch, out_height, out_width, depth]`. +func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, rates []int64, padding string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + opspec := tf.OpSpec{ + Type: "Dilation2D", + Input: []tf.Input{ + input, filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Constructs an Optional variant from a tuple of tensors. +func OptionalFromValue(scope *Scope, components []tf.Output) (optional tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "OptionalFromValue", + Input: []tf.Input{ + tf.OutputList(components), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MatrixSolveAttr is an optional argument to MatrixSolve. +type MatrixSolveAttr func(optionalAttr) + +// MatrixSolveAdjoint sets the optional adjoint attribute to value. +// +// value: Boolean indicating whether to solve with `matrix` or its (block-wise) +// adjoint. +// If not specified, defaults to false +func MatrixSolveAdjoint(value bool) MatrixSolveAttr { + return func(m optionalAttr) { + m["adjoint"] = value + } +} + +// Solves systems of linear equations. +// +// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is +// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix +// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +// If `adjoint` is `True` then each output matrix satisfies +// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. +// +// Arguments: +// matrix: Shape is `[..., M, M]`. +// rhs: Shape is `[..., M, K]`. +// +// Returns Shape is `[..., M, K]`. +func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatrixSolve", + Input: []tf.Input{ + matrix, rhs, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Slice a `SparseTensor` based on the `start` and `size`. 
+// +// For example, if the input is +// +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] +// +// Graphically the output tensors are: +// +// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] +// [ a ] +// [b c ] +// +// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] +// [ d e ] +// [ ] +// +// Arguments: +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// start: 1-D. tensor represents the start of the slice. +// size: 1-D. tensor represents the size of the slice. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. +// +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. +func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSlice", + Input: []tf.Input{ + indices, values, shape, start, size, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. +// +// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) +// ](http://arxiv.org/abs/1511.07289) +func Elu(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Elu", + Input: []tf.Input{ + features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// InfeedEnqueueTupleAttr is an optional argument to InfeedEnqueueTuple. +type InfeedEnqueueTupleAttr func(optionalAttr) + +// InfeedEnqueueTupleLayouts sets the optional layouts attribute to value. +// +// value: A vector holding the requested layout in minor-to-major sequence for +// all the tuple shapes, in the order the shapes appear in the "shapes" input. +// The layout elements for a sub-shape can be set to -1, in which case the +// corresponding layout will be computed by the infeed operation. +// If not specified, defaults to <> +func InfeedEnqueueTupleLayouts(value []int64) InfeedEnqueueTupleAttr { + return func(m optionalAttr) { + m["layouts"] = value + } +} + +// InfeedEnqueueTupleDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. +// If not specified, defaults to -1 +func InfeedEnqueueTupleDeviceOrdinal(value int64) InfeedEnqueueTupleAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// Feeds multiple Tensor values into the computation as an XLA tuple. +// +// Arguments: +// inputs: A list of tensors that will be provided using the infeed mechanism. +// shapes: The shapes of each tensor in `inputs`. +// +// Returns the created operation. 
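+//
+// Illustrative sketch, assuming a scope `s`: enqueue two vectors as one XLA
+// tuple (the shapes listed must match the inputs, in order):
+//
+//    a := Const(s, []float32{1, 2, 3})
+//    b := Const(s, []int32{7, 8})
+//    InfeedEnqueueTuple(s, []tf.Output{a, b},
+//        []tf.Shape{tf.MakeShape(3), tf.MakeShape(2)},
+//        InfeedEnqueueTupleDeviceOrdinal(0))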
+func InfeedEnqueueTuple(scope *Scope, inputs []tf.Output, shapes []tf.Shape, optional ...InfeedEnqueueTupleAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shapes": shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "InfeedEnqueueTuple", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Computes log softmax activations. +// +// For each batch `i` and class `j` we have +// +// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) +// +// Arguments: +// logits: 2-D with shape `[batch_size, num_classes]`. +// +// Returns Same shape as `logits`. +func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LogSoftmax", + Input: []tf.Input{ + logits, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAdadeltaAttr is an optional argument to ResourceApplyAdadelta. +type ResourceApplyAdadeltaAttr func(optionalAttr) + +// ResourceApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var, accum and update_accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyAdadeltaUseLocking(value bool) ResourceApplyAdadeltaAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the adadelta scheme. +// +// accum = rho() * accum + (1 - rho()) * grad.square(); +// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; +// update_accum = rho() * update_accum + (1 - rho()) * update.square(); +// var -= update; +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// accum_update: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdadeltaAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdadelta", + Input: []tf.Input{ + var_, accum, accum_update, lr, rho, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// InfeedEnqueueAttr is an optional argument to InfeedEnqueue. +type InfeedEnqueueAttr func(optionalAttr) + +// InfeedEnqueueShape sets the optional shape attribute to value. +// +// value: The shape of the tensor. +// If not specified, defaults to <> +func InfeedEnqueueShape(value tf.Shape) InfeedEnqueueAttr { + return func(m optionalAttr) { + m["shape"] = value + } +} + +// InfeedEnqueueLayout sets the optional layout attribute to value. +// +// value: A vector holding the requested layout in minor-to-major sequence. +// If a layout attribute is passed, but its values are all -1, the layout will +// be computed by the infeed operation. 
+// If not specified, defaults to <> +func InfeedEnqueueLayout(value []int64) InfeedEnqueueAttr { + return func(m optionalAttr) { + m["layout"] = value + } +} + +// InfeedEnqueueDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. +// If not specified, defaults to -1 +func InfeedEnqueueDeviceOrdinal(value int64) InfeedEnqueueAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// An op which feeds a single Tensor value into the computation. +// +// Arguments: +// input: A tensor that will be provided using the infeed mechanism. +// +// Returns the created operation. +func InfeedEnqueue(scope *Scope, input tf.Output, optional ...InfeedEnqueueAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "InfeedEnqueue", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Inverse 3D fast Fourier transform. +// +// Computes the inverse 3-dimensional discrete Fourier transform over the +// inner-most 3 dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their inverse 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifftn with 3 dimensions. +// @end_compatibility +func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IFFT3D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SparseReduceSumSparseAttr is an optional argument to SparseReduceSumSparse. +type SparseReduceSumSparseAttr func(optionalAttr) + +// SparseReduceSumSparseKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceSumSparseKeepDims(value bool) SparseReduceSumSparseAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the sum of elements across dimensions of a SparseTensor. +// +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a +// SparseTensor. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. 
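+//
+// Illustrative sketch, assuming a scope `s`: sum a 2x3 SparseTensor with two
+// non-empty values over axis 1, keeping the reduced dimension:
+//
+//    indices := Const(s, [][]int64{{0, 0}, {1, 2}})
+//    values := Const(s, []float32{3, 4})
+//    shape := Const(s, []int64{2, 3})
+//    axes := Const(s, []int32{1})
+//    outIdx, outVals, outShape := SparseReduceSumSparse(s, indices, values, shape, axes,
+//        SparseReduceSumSparseKeepDims(true))
+//    // outIdx/outVals/outShape describe a [2, 1] SparseTensor with values [3, 4].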
+func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseReduceSumSparse", + Input: []tf.Input{ + input_indices, input_values, input_shape, reduction_axes, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0), op.Output(1), op.Output(2) } -// Connects outputs of an N-way replicated computation to N outputs. -func TPUReplicatedOutput(scope *Scope, input tf.Output, num_replicas int64) (outputs []tf.Output) { +// ExtractJpegShapeAttr is an optional argument to ExtractJpegShape. +type ExtractJpegShapeAttr func(optionalAttr) + +// ExtractJpegShapeOutputType sets the optional output_type attribute to value. +// +// value: (Optional) The output type of the operation (int32 or int64). +// Defaults to int32. +// If not specified, defaults to DT_INT32 +func ExtractJpegShapeOutputType(value tf.DataType) ExtractJpegShapeAttr { + return func(m optionalAttr) { + m["output_type"] = value + } +} + +// Extract the shape information of a JPEG-encoded image. +// +// This op only parses the image header, so it is much faster than DecodeJpeg. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// +// Returns 1-D. The image shape with format [height, width, channels]. +func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegShapeAttr) (image_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ExtractJpegShape", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Enqueue multiple Tensor values on the computation outfeed. +// +// Arguments: +// inputs: A list of tensors that will be inserted into the outfeed queue as an +// XLA tuple. +// +// Returns the created operation. +func OutfeedEnqueueTuple(scope *Scope, inputs []tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "OutfeedEnqueueTuple", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + } + return scope.AddOperation(opspec) +} + +// StageClearAttr is an optional argument to StageClear. +type StageClearAttr func(optionalAttr) + +// StageClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageClearCapacity(value int64) StageClearAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// StageClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageClearMemoryLimit(value int64) StageClearAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// StageClearContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StageClearContainer(value string) StageClearAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// StageClearSharedName sets the optional shared_name attribute to value. 
+// If not specified, defaults to "" +func StageClearSharedName(value string) StageClearAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes all elements in the underlying container. +// +// Returns the created operation. +func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StageClear", + + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Updates the table to associates keys with values. +// +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableInsertV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// SparseTensorDenseMatMulAttr is an optional argument to SparseTensorDenseMatMul. +type SparseTensorDenseMatMulAttr func(optionalAttr) + +// SparseTensorDenseMatMulAdjointA sets the optional adjoint_a attribute to value. +// +// value: Use the adjoint of A in the matrix multiply. If A is complex, this +// is transpose(conj(A)). Otherwise it's transpose(A). +// If not specified, defaults to false +func SparseTensorDenseMatMulAdjointA(value bool) SparseTensorDenseMatMulAttr { + return func(m optionalAttr) { + m["adjoint_a"] = value + } +} + +// SparseTensorDenseMatMulAdjointB sets the optional adjoint_b attribute to value. +// +// value: Use the adjoint of B in the matrix multiply. If B is complex, this +// is transpose(conj(B)). Otherwise it's transpose(B). +// If not specified, defaults to false +func SparseTensorDenseMatMulAdjointB(value bool) SparseTensorDenseMatMulAttr { + return func(m optionalAttr) { + m["adjoint_b"] = value + } +} + +// Multiply SparseTensor (of rank 2) "A" by dense matrix "B". +// +// No validity checking is performed on the indices of A. However, the following +// input format is recommended for optimal behavior: +// +// if adjoint_a == false: +// A should be sorted in lexicographically increasing order. Use SparseReorder +// if you're not sure. +// if adjoint_a == true: +// A should be sorted in order of increasing dimension 1 (i.e., "column major" +// order instead of "row major" order). +// +// Arguments: +// a_indices: 2-D. The `indices` of the `SparseTensor`, size `[nnz, 2]` Matrix. +// a_values: 1-D. The `values` of the `SparseTensor`, size `[nnz]` Vector. +// a_shape: 1-D. The `shape` of the `SparseTensor`, size `[2]` Vector. +// b: 2-D. A dense Matrix. 
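+//
+// Illustrative usage sketch (not part of this change; the inputs shown are
+// hypothetical), assuming a scope `s` obtained from op.NewScope():
+//
+//	aIndices := op.Const(s, [][]int64{{0, 0}, {1, 1}}) // 2x2 sparse A with two non-zeros
+//	aValues := op.Const(s, []float32{1, 2})
+//	aShape := op.Const(s, []int64{2, 2})
+//	b := op.Const(s, [][]float32{{3, 4}, {5, 6}})
+//	product := op.SparseTensorDenseMatMul(s, aIndices, aValues, aShape, b,
+//		op.SparseTensorDenseMatMulAdjointB(false))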
+func SparseTensorDenseMatMul(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output, optional ...SparseTensorDenseMatMulAttr) (product tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "SparseTensorDenseMatMul",
+		Input: []tf.Input{
+			a_indices, a_values, a_shape, b,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve.
+type MatrixTriangularSolveAttr func(optionalAttr)
+
+// MatrixTriangularSolveLower sets the optional lower attribute to value.
+//
+// value: Boolean indicating whether the innermost matrices in `matrix` are
+// lower or upper triangular.
+// If not specified, defaults to true
+func MatrixTriangularSolveLower(value bool) MatrixTriangularSolveAttr {
+	return func(m optionalAttr) {
+		m["lower"] = value
+	}
+}
+
+// MatrixTriangularSolveAdjoint sets the optional adjoint attribute to value.
+//
+// value: Boolean indicating whether to solve with `matrix` or its (block-wise)
+// adjoint.
+//
+// @compatibility(numpy)
+// Equivalent to scipy.linalg.solve_triangular
+// @end_compatibility
+// If not specified, defaults to false
+func MatrixTriangularSolveAdjoint(value bool) MatrixTriangularSolveAttr {
+	return func(m optionalAttr) {
+		m["adjoint"] = value
+	}
+}
+
+// Solves systems of linear equations with upper or lower triangular matrices by backsubstitution.
+//
+// `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
+// square matrices. If `lower` is `True` then the strictly upper triangular part
+// of each inner-most matrix is assumed to be zero and not accessed.
+// If `lower` is `False` then the strictly lower triangular part of each inner-most
+// matrix is assumed to be zero and not accessed.
+// `rhs` is a tensor of shape `[..., M, K]`.
+//
+// The output is a tensor of shape `[..., M, K]`. If `adjoint` is
+// `False` then the innermost matrices in `output` satisfy matrix equations
+// `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
+// If `adjoint` is `True` then the innermost matrices in
+// `output` satisfy matrix equations
+// `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
+//
+// Example:
+// ```python
+//
+// a = tf.constant([[3, 0, 0, 0],
+// [2, 1, 0, 0],
+// [1, 0, 1, 0],
+// [1, 1, 1, 1]], dtype=tf.float32)
+//
+// b = tf.constant([[4],
+// [2],
+// [4],
+// [2]], dtype=tf.float32)
+//
+// x = tf.linalg.triangular_solve(a, b, lower=True)
+// x
+// #
+//
+// # in python3 one can use `a@x`
+// tf.matmul(a, x)
+// #
+// ```
+//
+// Arguments:
+// matrix: Shape is `[..., M, M]`.
+// rhs: Shape is `[..., M, K]`.
+//
+// Returns Shape is `[..., M, K]`.
+func MatrixTriangularSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixTriangularSolveAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "MatrixTriangularSolve",
+		Input: []tf.Input{
+			matrix, rhs,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// SetSizeAttr is an optional argument to SetSize.
+type SetSizeAttr func(optionalAttr)
+
+// SetSizeValidateIndices sets the optional validate_indices attribute to value.
+// If not specified, defaults to true +func SetSizeValidateIndices(value bool) SetSizeAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Number of unique elements along last dimension of input `set`. +// +// Input `set` is a `SparseTensor` represented by `set_indices`, `set_values`, +// and `set_shape`. The last dimension contains values in a set, duplicates are +// allowed but ignored. +// +// If `validate_indices` is `True`, this op validates the order and range of `set` +// indices. +// +// Arguments: +// set_indices: 2D `Tensor`, indices of a `SparseTensor`. +// set_values: 1D `Tensor`, values of a `SparseTensor`. +// set_shape: 1D `Tensor`, shape of a `SparseTensor`. +// +// Returns For `set` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st +// `n-1` dimensions as `set`. Each value is the number of unique elements in +// the corresponding `[0...n-1]` dimension of `set`. +func SetSize(scope *Scope, set_indices tf.Output, set_values tf.Output, set_shape tf.Output, optional ...SetSizeAttr) (size tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SetSize", + Input: []tf.Input{ + set_indices, set_values, set_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds two `SparseTensor` objects to produce another `SparseTensor`. +// +// The input `SparseTensor` objects' indices are assumed ordered in standard +// lexicographic order. If this is not the case, before this step run +// `SparseReorder` to restore index ordering. +// +// By default, if two values sum to zero at some index, the output `SparseTensor` +// would still include that particular location in its index, storing a zero in the +// corresponding value slot. To override this, callers can specify `thresh`, +// indicating that if the sum has a magnitude strictly smaller than `thresh`, its +// corresponding value and index would then not be included. In particular, +// `thresh == 0` (default) means everything is kept and actual thresholding happens +// only for a positive value. +// +// In the following shapes, `nnz` is the count after taking `thresh` into account. +// +// Arguments: +// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix. +// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector. +// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector. +// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix. +// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector. +// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector. +// thresh: 0-D. The magnitude threshold that determines if an output value/index +// pair takes space. +func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseAdd", + Input: []tf.Input{ + a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Strip leading and trailing whitespaces from the Tensor. +// +// Arguments: +// input: A string `Tensor` of any shape. 
+// +// Returns A string `Tensor` of the same shape as the input. +func StringStrip(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StringStrip", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Serializes the tree ensemble to a proto. +// +// Arguments: +// tree_ensemble_handle: Handle to the tree ensemble. +// +// Returns Stamp token of the tree ensemble resource.Serialized proto of the ensemble. +func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesSerializeEnsemble", + Input: []tf.Input{ + tree_ensemble_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Outputs a `Summary` protocol buffer with scalar values. +// +// The input `tags` and `values` must have the same shape. The generated summary +// has a summary value for each tag-value pair in `tags` and `values`. +// +// Arguments: +// tags: Tags for the summary. +// values: Same shape as `tags. Values for the summary. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ScalarSummary", + Input: []tf.Input{ + tags, values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// CompilationResultProto indicating the status of the TPU compilation. +func TPUCompilationResult(scope *Scope) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TPUCompilationResult", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeBicubicAttr is an optional argument to ResizeBicubic. +type ResizeBicubicAttr func(optionalAttr) + +// ResizeBicubicAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeBicubicAlignCorners(value bool) ResizeBicubicAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// ResizeBicubicHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeBicubicHalfPixelCenters(value bool) ResizeBicubicAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value + } +} + +// Resize `images` to `size` using bicubic interpolation. +// +// Input images can be of different types but output images are always float. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. 
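+//
+// Illustrative usage sketch (not part of this change; the placeholder and
+// target size are hypothetical), assuming a scope `s` from op.NewScope():
+//
+//	images := op.Placeholder(s, tf.Float) // 4-D [batch, height, width, channels]
+//	size := op.Const(s, []int32{224, 224})
+//	resized := op.ResizeBicubic(s, images, size,
+//		op.ResizeBicubicHalfPixelCenters(true))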
+func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBicubicAttr) (resized_images tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBicubic", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingADAMParametersAttr is an optional argument to LoadTPUEmbeddingADAMParameters. +type LoadTPUEmbeddingADAMParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingADAMParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingADAMParametersTableId(value int64) LoadTPUEmbeddingADAMParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingADAMParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingADAMParametersTableName(value string) LoadTPUEmbeddingADAMParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load ADAM embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the ADAM optimization algorithm. +// momenta: Value of momenta used in the ADAM optimization algorithm. +// velocities: Value of velocities used in the ADAM optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingADAMParameters(scope *Scope, parameters tf.Output, momenta tf.Output, velocities tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingADAMParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingADAMParameters", + Input: []tf.Input{ + parameters, momenta, velocities, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// JPEG encode input image with provided compression quality. +// +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// `quality` is an int32 jpeg compression quality value between 0 and 100. +// +// +// Arguments: +// images: Images to adjust. At least 3-D. +// quality: An int quality to encode to. +// +// Returns 0-D. JPEG-encoded image. +func EncodeJpegVariableQuality(scope *Scope, images tf.Output, quality tf.Output) (contents tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "EncodeJpegVariableQuality", + Input: []tf.Input{ + images, quality, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TPUReplicateMetadataAttr is an optional argument to TPUReplicateMetadata. +type TPUReplicateMetadataAttr func(optionalAttr) + +// TPUReplicateMetadataNumCoresPerReplica sets the optional num_cores_per_replica attribute to value. +// +// value: Number of cores per replica. Used for model parallelism. 
+// If not specified, defaults to 1 +func TPUReplicateMetadataNumCoresPerReplica(value int64) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["num_cores_per_replica"] = value + } +} + +// TPUReplicateMetadataTopology sets the optional topology attribute to value. +// +// value: TopologyProto indicating the topology of the TPU pod slice. +// If not specified, defaults to "" +func TPUReplicateMetadataTopology(value string) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["topology"] = value + } +} + +// TPUReplicateMetadataUseTpu sets the optional use_tpu attribute to value. +// +// value: Whether to place the computation on the TPU. +// If not specified, defaults to true +func TPUReplicateMetadataUseTpu(value bool) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["use_tpu"] = value + } +} + +// TPUReplicateMetadataDeviceAssignment sets the optional device_assignment attribute to value. +// +// value: The assignment of devices for the computation. +// If not specified, defaults to <> +func TPUReplicateMetadataDeviceAssignment(value []int64) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["device_assignment"] = value + } +} + +// TPUReplicateMetadataComputationShape sets the optional computation_shape attribute to value. +// +// value: DEPRECATED. Use num_cores_per_replica instead. +// If not specified, defaults to <> +func TPUReplicateMetadataComputationShape(value []int64) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["computation_shape"] = value + } +} + +// TPUReplicateMetadataHostComputeCore sets the optional host_compute_core attribute to value. +// If not specified, defaults to <> +func TPUReplicateMetadataHostComputeCore(value []string) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["host_compute_core"] = value + } +} + +// TPUReplicateMetadataPaddingMap sets the optional padding_map attribute to value. +// If not specified, defaults to <> +func TPUReplicateMetadataPaddingMap(value []string) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["padding_map"] = value + } +} + +// TPUReplicateMetadataStepMarkerLocation sets the optional step_marker_location attribute to value. +// If not specified, defaults to "STEP_MARK_AT_ENTRY" +func TPUReplicateMetadataStepMarkerLocation(value string) TPUReplicateMetadataAttr { + return func(m optionalAttr) { + m["step_marker_location"] = value + } +} + +// Metadata indicaitng how the TPU computation should be replicated. +// +// Arguments: +// num_replicas: Number of replicas of the computation +// +// Returns the created operation. +func TPUReplicateMetadata(scope *Scope, num_replicas int64, optional ...TPUReplicateMetadataAttr) (o *tf.Operation) { if scope.Err() != nil { return } attrs := map[string]interface{}{"num_replicas": num_replicas} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TPUReplicatedOutput", + Type: "TPUReplicateMetadata", + + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// LuAttr is an optional argument to Lu. +type LuAttr func(optionalAttr) + +// LuOutputIdxType sets the optional output_idx_type attribute to value. +// If not specified, defaults to DT_INT32 +func LuOutputIdxType(value tf.DataType) LuAttr { + return func(m optionalAttr) { + m["output_idx_type"] = value + } +} + +// Computes the LU decomposition of one or more square matrices. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. 
+// +// The input has to be invertible. +// +// The output consists of two tensors LU and P containing the LU decomposition +// of all input submatrices `[..., :, :]`. LU encodes the lower triangular and +// upper triangular factors. +// +// For each input submatrix of shape `[M, M]`, L is a lower triangular matrix of +// shape `[M, M]` with unit diagonal whose entries correspond to the strictly lower +// triangular part of LU. U is a upper triangular matrix of shape `[M, M]` whose +// entries correspond to the upper triangular part, including the diagonal, of LU. +// +// P represents a permutation matrix encoded as a list of indices each between `0` +// and `M-1`, inclusive. If P_mat denotes the permutation matrix corresponding to +// P, then the L, U and P satisfies P_mat * input = L * U. +// +// Arguments: +// input: A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form matrices of +// size `[M, M]`. +// +// Returns A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the +// lower triangular factor `L` with unit diagonal, and whose upper triangular part +// denotes the upper triangular factor `U`.Permutation of the rows encoded as a list of indices in `0..M-1`. Shape is +// `[..., M]`. +// @compatibility(scipy) +// Similar to `scipy.linalg.lu`, except the triangular factors `L` and `U` are +// packed into a single tensor, the permutation is applied to `input` instead of +// the right hand side and the permutation `P` is returned as a list of indices +// instead of a permutation matrix. +// @end_compatibility +func Lu(scope *Scope, input tf.Output, optional ...LuAttr) (lu tf.Output, p tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Lu", Input: []tf.Input{ input, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. +type Conv2DBackpropFilterAttr func(optionalAttr) + +// Conv2DBackpropFilterUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DBackpropFilterUseCudnnOnGpu(value bool) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["use_cudnn_on_gpu"] = value + } +} + +// Conv2DBackpropFilterExplicitPaddings sets the optional explicit_paddings attribute to value. +// +// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith +// dimension, the amount of padding inserted before and after the dimension is +// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If +// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. +// If not specified, defaults to <> +func Conv2DBackpropFilterExplicitPaddings(value []int64) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["explicit_paddings"] = value + } +} + +// Conv2DBackpropFilterDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. 
+// If not specified, defaults to "NHWC" +func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv2DBackpropFilterDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of convolution with respect to the filter. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, out_channels]` tensor. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropFilterAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv2DBackpropFilter", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingRMSPropParameters. +type LoadTPUEmbeddingRMSPropParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingRMSPropParametersTableId(value int64) LoadTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingRMSPropParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingRMSPropParametersTableName(value string) LoadTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load RMSProp embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the RMSProp optimization algorithm. +// ms: Value of ms used in the RMSProp optimization algorithm. 
+// mom: Value of mom used in the RMSProp optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingRMSPropParameters(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingRMSPropParameters", + Input: []tf.Input{ + parameters, ms, mom, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Returns which elements of x are Inf. +// +// @compatibility(numpy) +// Equivalent to np.isinf +// @end_compatibility +func IsInf(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsInf", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomPoissonV2Attr is an optional argument to RandomPoissonV2. +type RandomPoissonV2Attr func(optionalAttr) + +// RandomPoissonV2Seed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomPoissonV2Seed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// RandomPoissonV2Dtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_INT64 +func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random values from the Poisson distribution(s) described by rate. +// +// This op uses two algorithms, depending on rate. If rate >= 10, then +// the algorithm by Hormann is used to acquire samples via +// transformation-rejection. +// See http://www.sciencedirect.com/science/article/pii/0167668793909974. +// +// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform +// random variables. +// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer +// Programming, Volume 2. Addison Wesley +// +// Arguments: +// shape: 1-D integer tensor. Shape of independent samples to draw from each +// distribution described by the shape parameters given in rate. +// rate: A tensor in which each scalar is a "rate" parameter describing the +// associated poisson distribution. +// +// Returns A tensor with shape `shape + shape(rate)`. Each slice +// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for +// `rate[i0, i1, ...iN]`. +func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomPoissonV2", + Input: []tf.Input{ + shape, rate, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the determinant of one or more square matrices. 
+// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor containing the determinants +// for all input submatrices `[..., :, :]`. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[...]`. +func MatrixDeterminant(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixDeterminant", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. +type ResizeNearestNeighborGradAttr func(optionalAttr) + +// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. +// If not specified, defaults to false +func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// ResizeNearestNeighborGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeNearestNeighborGradHalfPixelCenters(value bool) ResizeNearestNeighborGradAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value + } +} + +// Computes the gradient of nearest neighbor interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The +// original input size. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients +// with respect to the input image. +func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeNearestNeighborGrad", + Input: []tf.Input{ + grads, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. +type MaxPoolWithArgmaxAttr func(optionalAttr) + +// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. +// If not specified, defaults to DT_INT64 +func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { + return func(m optionalAttr) { + m["Targmax"] = value + } +} + +// MaxPoolWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. +// +// value: Whether to include batch dimension in flattened index of `argmax`. +// If not specified, defaults to false +func MaxPoolWithArgmaxIncludeBatchInIndex(value bool) MaxPoolWithArgmaxAttr { + return func(m optionalAttr) { + m["include_batch_in_index"] = value + } +} + +// Performs max pooling on the input and outputs both max values and indices. +// +// The indices in `argmax` are flattened, so that a maximum value at position +// `[b, y, x, c]` becomes flattened index: +// `(y * width + x) * channels + c` if `include_batch_in_index` is False; +// `((b * height + y) * width + x) * channels + c` if `include_batch_in_index` is True. 
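+//
+// As a small worked example (values chosen only for illustration): with
+// height=4, width=5 and channels=3, the maximum at position [b=1, y=2, x=3, c=0]
+// is reported as (2*5 + 3)*3 + 0 = 39 when `include_batch_in_index` is False,
+// and as ((1*4 + 2)*5 + 3)*3 + 0 = 99 when it is True.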
+// +// The indices returned are always in `[0, height) x [0, width)` before flattening, +// even if padding is involved and the mathematically correct answer is outside +// (either negative or too large). This is a bug, but fixing it is difficult to do +// in a safe backwards compatible way, especially due to flattening. +// +// Arguments: +// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. +func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolWithArgmax", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// ShapeNAttr is an optional argument to ShapeN. +type ShapeNAttr func(optionalAttr) + +// ShapeNOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func ShapeNOutType(value tf.DataType) ShapeNAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Returns shape of tensors. +// +// This operation returns N 1-D integer tensors representing shape of `input[i]s`. +func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ShapeN", + Input: []tf.Input{ + tf.OutputList(input), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("ShapeN", err) + return + } + return output +} + +// Creates a dataset that changes the batch size. +// +// Creates a dataset that changes the batch size of the dataset to current batch +// size // num_workers. +// +// Arguments: +// input_dataset: A variant tensor representing the input dataset. +// num_workers: A scalar representing the number of workers to distribute this batch across. As +// a result of this transformation the current batch size would end up being +// divided by this parameter. +// +// +func ExperimentalRebatchDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalRebatchDataset", + Input: []tf.Input{ + input_dataset, num_workers, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// NonDeterministicIntsAttr is an optional argument to NonDeterministicInts. +type NonDeterministicIntsAttr func(optionalAttr) + +// NonDeterministicIntsDtype sets the optional dtype attribute to value. +// +// value: The type of the output. 
+// If not specified, defaults to DT_INT64 +func NonDeterministicIntsDtype(value tf.DataType) NonDeterministicIntsAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Non-deterministically generates some integers. +// +// This op may use some OS-provided source of non-determinism (e.g. an RNG), so each execution will give different results. +// +// Arguments: +// shape: The shape of the output tensor. +// +// Returns Non-deterministic integer values with specified shape. +func NonDeterministicInts(scope *Scope, shape tf.Output, optional ...NonDeterministicIntsAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "NonDeterministicInts", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Convert one or more images from HSV to RGB. +// +// Outputs a tensor of the same shape as the `images` tensor, containing the RGB +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// See `rgb_to_hsv` for a description of the HSV encoding. +// +// Arguments: +// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// +// Returns `images` converted to RGB. +func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "HSVToRGB", + Input: []tf.Input{ + images, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TruncatedNormalAttr is an optional argument to TruncatedNormal. +type TruncatedNormalAttr func(optionalAttr) + +// TruncatedNormalSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func TruncatedNormalSeed(value int64) TruncatedNormalAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// TruncatedNormalSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func TruncatedNormalSeed2(value int64) TruncatedNormalAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a truncated normal distribution. +// +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. +// +// Arguments: +// shape: The shape of the output tensor. +// dtype: The type of the output. +// +// Returns A tensor of the specified shape filled with random truncated normal +// values. +func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TruncatedNormal", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes softsign gradients for a softsign operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding softsign operation. +// features: The features passed as input to the corresponding softsign operation. 
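+//
+// (Derivation note: softsign(x) = x / (1 + |x|), so d softsign(x)/dx = 1 / (1 + |x|)^2,
+// which is the factor applied to the incoming gradients in the formula below.)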
+//
+// Returns The gradients: `gradients / (1 + abs(features)) ** 2`.
+func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SoftsignGrad",
+		Input: []tf.Input{
+			gradients, features,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Calculates gains for each feature and returns the best possible split information for the feature.
+//
+// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
+//
+// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
+//
+// In this manner, the output is the best split per feature and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
+//
+// The output lists all have the same length, `num_features`.
+// The output shapes are compatible in a way that the first dimension of all tensors of all lists is the same and equal to the number of possible split nodes for each feature.
+//
+// Arguments:
+// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as in `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive).
+// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
+// l1: l1 regularization factor on leaf weights, per instance based.
+// l2: l2 regularization factor on leaf weights, per instance based.
+// tree_complexity: adjustment to the gain, per leaf based.
+// min_node_weight: minimum average of hessians in a node required before the node is considered for splitting.
+// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
+//
+// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems.
See above for details like shapes and sizes.A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node. +func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"max_splits": max_splits} + opspec := tf.OpSpec{ + Type: "BoostedTreesCalculateBestGainsPerFeature", + Input: []tf.Input{ + node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil { + scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) + return + } + return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list +} + +// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2. +type ResourceApplyFtrlV2Attr func(optionalAttr) + +// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the Ftrl-proximal scheme. +// +// grad_with_shrinkage = grad + 2 * l2_shrinkage * var +// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +// linear += grad_with_shrinkage + +// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regulariation. Must be a scalar. +// l2: L2 shrinkage regulariation. Must be a scalar. +// +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. 
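+//
+// Illustrative call pattern (a sketch, not part of this change; the resource
+// handles and scalar hyperparameters are hypothetical tf.Output values already
+// present in scope `s`):
+//
+//	applyFtrl := op.ResourceApplyFtrlV2(s, varHandle, accum, linear, grad,
+//		lr, l1, l2, l2Shrinkage, lrPower,
+//		op.ResourceApplyFtrlV2UseLocking(true))
+//	// applyFtrl is a *tf.Operation; pass it in the targets of Session.Run to apply the update.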
+func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyFtrlV2", + Input: []tf.Input{ + var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// ResourceApplyAdamWithAmsgradAttr is an optional argument to ResourceApplyAdamWithAmsgrad. +type ResourceApplyAdamWithAmsgradAttr func(optionalAttr) + +// ResourceApplyAdamWithAmsgradUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdamWithAmsgradUseLocking(value bool) ResourceApplyAdamWithAmsgradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the Adam algorithm. +// +// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ +// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ +// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ +// $$vhat_t := max{vhat_{t-1}, v_t}$$ +// $$variable := variable - lr_t * m_t / (\sqrt{vhat_t} + \epsilon)$$ +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// vhat: Should be from a Variable(). +// beta1_power: Must be a scalar. +// beta2_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdamWithAmsgrad(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, vhat tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamWithAmsgradAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdamWithAmsgrad", + Input: []tf.Input{ + var_, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Quantized Batch normalization. +// +// This op is deprecated and will be removed in the future. Prefer +// `tf.nn.batch_normalization`. +// +// Arguments: +// t: A 4D input Tensor. +// t_min: The value represented by the lowest quantized input. +// t_max: The value represented by the highest quantized input. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// m_min: The value represented by the lowest quantized mean. +// m_max: The value represented by the highest quantized mean. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// v_min: The value represented by the lowest quantized variance. 
+// v_max: The value represented by the highest quantized variance. +// beta: A 1D beta Tensor with size matching the last dimension of t. +// An offset to be added to the normalized tensor. +// beta_min: The value represented by the lowest quantized offset. +// beta_max: The value represented by the highest quantized offset. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this tensor will be multiplied +// with the normalized tensor. +// gamma_min: The value represented by the lowest quantized gamma. +// gamma_max: The value represented by the highest quantized gamma. +// +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +func QuantizedBatchNormWithGlobalNormalization(scope *Scope, t tf.Output, t_min tf.Output, t_max tf.Output, m tf.Output, m_min tf.Output, m_max tf.Output, v tf.Output, v_min tf.Output, v_max tf.Output, beta tf.Output, beta_min tf.Output, beta_max tf.Output, gamma tf.Output, gamma_min tf.Output, gamma_max tf.Output, out_type tf.DataType, variance_epsilon float32, scale_after_normalization bool) (result tf.Output, result_min tf.Output, result_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type, "variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + opspec := tf.OpSpec{ + Type: "QuantizedBatchNormWithGlobalNormalization", + Input: []tf.Input{ + t, t_min, t_max, m, m_min, m_max, v, v_min, v_max, beta, beta_min, beta_max, gamma, gamma_min, gamma_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Output the logits for the given input data +// +// Arguments: +// tree_handle: Handle to the tree resource. +// dense_features: Rank 2 dense features tensor. +// logits_dimension: Scalar, dimension of the logits. +// +// Returns The logits predictions from the tree for each instance in the batch. +func TensorForestTreePredict(scope *Scope, tree_handle tf.Output, dense_features tf.Output, logits_dimension int64) (logits tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"logits_dimension": logits_dimension} + opspec := tf.OpSpec{ + Type: "TensorForestTreePredict", + Input: []tf.Input{ + tree_handle, dense_features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingRMSPropParametersAttr is an optional argument to RetrieveTPUEmbeddingRMSPropParameters. +type RetrieveTPUEmbeddingRMSPropParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingRMSPropParametersTableId(value int64) RetrieveTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingRMSPropParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingRMSPropParametersTableName(value string) RetrieveTPUEmbeddingRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve RMSProp embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. 
Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the RMSProp optimization algorithm.Parameter ms updated by the RMSProp optimization algorithm.Parameter mom updated by the RMSProp optimization algorithm. +func RetrieveTPUEmbeddingRMSPropParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingRMSPropParametersAttr) (parameters tf.Output, ms tf.Output, mom tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingRMSPropParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. +type ResourceSparseApplyAdagradAttr func(optionalAttr) + +// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value. +// If not specified, defaults to true +func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr { + return func(m optionalAttr) { + m["update_slots"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. +// +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyAdagrad", + Input: []tf.Input{ + var_, accum, lr, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// TensorArrayV3Attr is an optional argument to TensorArrayV3. +type TensorArrayV3Attr func(optionalAttr) + +// TensorArrayV3ElementShape sets the optional element_shape attribute to value. +// +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. 
+// +// value: A boolean that determines whether writes to the TensorArray +// are allowed to grow the size. By default, this is not allowed. +// If not specified, defaults to false +func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["dynamic_size"] = value + } +} + +// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. +// +// value: If true (default), Tensors in the TensorArray are cleared +// after being read. This disables multiple read semantics but allows early +// release of memory. +// If not specified, defaults to true +func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["clear_after_read"] = value + } +} + +// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. +// +// value: If true (default is false), then all +// elements in the TensorArray will be expected to have have identical shapes. +// This allows certain behaviors, like dynamically checking for +// consistent shapes on write, and being able to fill in properly +// shaped zero tensors on stack -- even if the element_shape attribute +// is not fully defined. +// If not specified, defaults to false +func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["identical_element_shapes"] = value + } +} + +// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. +// +// value: Overrides the name used for the temporary tensor_array +// resource. Default value is the name of the 'TensorArray' op (which +// is guaranteed unique). +// If not specified, defaults to "" +func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr { + return func(m optionalAttr) { + m["tensor_array_name"] = value + } +} + +// An array of Tensors of given size. +// +// Write data via Write and read via Read or Pack. +// +// Arguments: +// size: The size of the array. +// dtype: The type of the elements on the tensor_array. +// +// Returns The handle to the TensorArray.A scalar used to control gradient flow. +func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayV3", + Input: []tf.Input{ + size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Creates a dataset that caches elements from `input_dataset`. +// +// A CacheDataset will iterate over the input_dataset, and store tensors. If the +// cache already exists, the cache will be used. If the cache is inappropriate +// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error +// will the returned when used. +// +// Arguments: +// +// filename: A path on the filesystem where we should cache the dataset. Note: this +// will be a directory. 
+// +// +func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "CacheDataset", + Input: []tf.Input{ + input_dataset, filename, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Generate a sharded filename. The filename is printf formatted as +// +// %s-%05d-of-%05d, basename, shard, num_shards. +func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shards tf.Output) (filename tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ShardedFilename", + Input: []tf.Input{ + basename, shard, num_shards, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Get the value of the tensor specified by its handle. +// +// Arguments: +// handle: The handle for a tensor stored in the session state. +// dtype: The type of the output value. +// +// Returns The tensor for the given handle. +func GetSessionTensor(scope *Scope, handle tf.Output, dtype tf.DataType) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "GetSessionTensor", + Input: []tf.Input{ + handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds sparse updates to the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] += updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] += updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] += updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions add. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterAdd(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterAdd", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// LoadTPUEmbeddingCenteredRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingCenteredRMSPropParameters. +type LoadTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingCenteredRMSPropParametersTableId(value int64) LoadTPUEmbeddingCenteredRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingCenteredRMSPropParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingCenteredRMSPropParametersTableName(value string) LoadTPUEmbeddingCenteredRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load centered RMSProp embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the centered RMSProp optimization algorithm. +// ms: Value of ms used in the centered RMSProp optimization algorithm. +// mom: Value of mom used in the centered RMSProp optimization algorithm. +// mg: Value of mg used in the centered RMSProp optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingCenteredRMSPropParameters(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, mg tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingCenteredRMSPropParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingCenteredRMSPropParameters", + Input: []tf.Input{ + parameters, ms, mom, mg, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// CTCGreedyDecoderAttr is an optional argument to CTCGreedyDecoder. +type CTCGreedyDecoderAttr func(optionalAttr) + +// CTCGreedyDecoderMergeRepeated sets the optional merge_repeated attribute to value. +// +// value: If True, merge repeated classes in output. +// If not specified, defaults to false +func CTCGreedyDecoderMergeRepeated(value bool) CTCGreedyDecoderAttr { + return func(m optionalAttr) { + m["merge_repeated"] = value + } +} + +// Performs greedy decoding on the logits given in inputs. +// +// A note about the attribute merge_repeated: if enabled, when +// consecutive logits' maximum indices are the same, only the first of +// these is emitted. 
Labeling the blank '*', the sequence "A B B * B B" +// becomes "A B B" if merge_repeated = True and "A B B B B" if +// merge_repeated = False. +// +// Regardless of the value of merge_repeated, if the maximum index of a given +// time and batch corresponds to the blank, index `(num_classes - 1)`, no new +// element is emitted. +// +// Arguments: +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// sequence_length: A vector containing sequence lengths, size `(batch_size)`. +// +// Returns Indices matrix, size `(total_decoded_outputs x 2)`, +// of a `SparseTensor`. The rows store: [batch, time].Values vector, size: `(total_decoded_outputs)`, +// of a `SparseTensor`. The vector stores the decoded classes.Shape vector, size `(2)`, of the decoded SparseTensor. +// Values are: `[batch_size, max_decoded_length]`.Matrix, size `(batch_size x 1)`, containing sequence +// log-probabilities. +func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, optional ...CTCGreedyDecoderAttr) (decoded_indices tf.Output, decoded_values tf.Output, decoded_shape tf.Output, log_probability tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CTCGreedyDecoder", + Input: []tf.Input{ + inputs, sequence_length, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// An op that receives embedding activations on the TPU. +// +// The TPU system performs the embedding lookups and aggregations specified by +// the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The +// results of these aggregations are visible to the Tensorflow Graph as the +// outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing +// one Tensor of activations per table specified in the model. There can be at +// most one RecvTPUEmbeddingActivations op in the TPU graph. +// +// Arguments: +// num_outputs: The number of output activation tensors, equal to the number of +// embedding tables in the model. +// config: Serialized TPUEmbeddingConfiguration proto. +// +// Returns A TensorList of embedding activations containing one Tensor per +// embedding table in the model. +func RecvTPUEmbeddingActivations(scope *Scope, num_outputs int64, config string) (outputs []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_outputs": num_outputs, "config": config} + opspec := tf.OpSpec{ + Type: "RecvTPUEmbeddingActivations", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) if scope.Err() != nil { return } var idx int var err error if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("TPUReplicatedOutput", err) + scope.UpdateErr("RecvTPUEmbeddingActivations", err) return } return outputs } +// RecordInputAttr is an optional argument to RecordInput. +type RecordInputAttr func(optionalAttr) + +// RecordInputFileRandomSeed sets the optional file_random_seed attribute to value. +// +// value: Random seeds used to produce randomized records. +// If not specified, defaults to 301 +func RecordInputFileRandomSeed(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_random_seed"] = value + } +} + +// RecordInputFileShuffleShiftRatio sets the optional file_shuffle_shift_ratio attribute to value. +// +// value: Shifts the list of files after the list is randomly +// shuffled. 
+// If not specified, defaults to 0 +func RecordInputFileShuffleShiftRatio(value float32) RecordInputAttr { + return func(m optionalAttr) { + m["file_shuffle_shift_ratio"] = value + } +} + +// RecordInputFileBufferSize sets the optional file_buffer_size attribute to value. +// +// value: The randomization shuffling buffer. +// If not specified, defaults to 10000 +func RecordInputFileBufferSize(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_buffer_size"] = value + } +} + +// RecordInputFileParallelism sets the optional file_parallelism attribute to value. +// +// value: How many sstables are opened and concurrently iterated over. +// If not specified, defaults to 16 +func RecordInputFileParallelism(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_parallelism"] = value + } +} + +// RecordInputBatchSize sets the optional batch_size attribute to value. +// +// value: The batch size. +// If not specified, defaults to 32 +func RecordInputBatchSize(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["batch_size"] = value + } +} + +// RecordInputCompressionType sets the optional compression_type attribute to value. +// +// value: The type of compression for the file. Currently ZLIB and +// GZIP are supported. Defaults to none. +// If not specified, defaults to "" +func RecordInputCompressionType(value string) RecordInputAttr { + return func(m optionalAttr) { + m["compression_type"] = value + } +} + +// Emits randomized records. +// +// Arguments: +// file_pattern: Glob pattern for the data files. +// +// Returns A tensor of shape [batch_size]. +func RecordInput(scope *Scope, file_pattern string, optional ...RecordInputAttr) (records tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"file_pattern": file_pattern} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RecordInput", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LeakyReluAttr is an optional argument to LeakyRelu. +type LeakyReluAttr func(optionalAttr) + +// LeakyReluAlpha sets the optional alpha attribute to value. +// If not specified, defaults to 0.2 +func LeakyReluAlpha(value float32) LeakyReluAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// Computes rectified linear: `max(features, features * alpha)`. +func LeakyRelu(scope *Scope, features tf.Output, optional ...LeakyReluAttr) (activations tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LeakyRelu", + Input: []tf.Input{ + features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyCenteredRMSPropAttr is an optional argument to ResourceApplyCenteredRMSProp. +type ResourceApplyCenteredRMSPropAttr func(optionalAttr) + +// ResourceApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, mg, ms, and mom tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyCenteredRMSPropUseLocking(value bool) ResourceApplyCenteredRMSPropAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the centered RMSProp algorithm. 
+// +// The centered RMSProp algorithm uses an estimate of the centered second moment +// (i.e., the variance) for normalization, as opposed to regular RMSProp, which +// uses the (uncentered) second moment. This often helps with training, but is +// slightly more expensive in terms of computation and memory. +// +// Note that in dense implementation of this algorithm, mg, ms, and mom will +// update even if the grad is zero, but in this sparse implementation, mg, ms, +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// mean_grad = decay * mean_grad + (1-decay) * gradient +// +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) +// +// mg <- rho * mg_{t-1} + (1-rho) * grad +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) +// var <- var - mom +// +// Arguments: +// var_: Should be from a Variable(). +// mg: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyCenteredRMSPropAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyCenteredRMSProp", + Input: []tf.Input{ + var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// MultinomialAttr is an optional argument to Multinomial. +type MultinomialAttr func(optionalAttr) + +// MultinomialSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 is set to be non-zero, the internal random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func MultinomialSeed(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// MultinomialSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func MultinomialSeed2(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// MultinomialOutputDtype sets the optional output_dtype attribute to value. +// If not specified, defaults to DT_INT64 +func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { + return func(m optionalAttr) { + m["output_dtype"] = value + } +} + +// Draws samples from a multinomial distribution. +// +// Arguments: +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. +// +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. 
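+//
+// A minimal usage sketch (illustrative only; it assumes a Scope from
+// NewScope in this package, and the logits values below are made up):
+//
+//	s := NewScope()
+//	logits := Const(s, [][]float32{{-1.0, 0.0, 1.0}}) // shape [1, 3]
+//	samples := Multinomial(s, logits, Const(s, int32(5)))
+//	// samples has shape [1, 5] and dtype int64 by default.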
+func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Multinomial", + Input: []tf.Input{ + logits, num_samples, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RequantizePerChannelAttr is an optional argument to RequantizePerChannel. +type RequantizePerChannelAttr func(optionalAttr) + +// RequantizePerChannelOutType sets the optional out_type attribute to value. +// +// value: The quantized type of output tensor that needs to be converted. +// If not specified, defaults to DT_QUINT8 +func RequantizePerChannelOutType(value tf.DataType) RequantizePerChannelAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Requantizes input with min and max values known per channel. +// +// Arguments: +// input: The original input tensor. +// input_min: The minimum value of the input tensor +// input_max: The maximum value of the input tensor. +// requested_output_min: The minimum value of the output tensor requested. +// requested_output_max: The maximum value of the output tensor requested. +// +// Returns Output tensor.The minimum value of the final output tensorThe maximum value of the final output tensor. +func RequantizePerChannel(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, optional ...RequantizePerChannelAttr) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RequantizePerChannel", + Input: []tf.Input{ + input, input_min, input_max, requested_output_min, requested_output_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// CudnnRNNBackpropV2Attr is an optional argument to CudnnRNNBackpropV2. +type CudnnRNNBackpropV2Attr func(optionalAttr) + +// CudnnRNNBackpropV2RnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNBackpropV2RnnMode(value string) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["rnn_mode"] = value + } +} + +// CudnnRNNBackpropV2InputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNBackpropV2InputMode(value string) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNBackpropV2Direction sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNBackpropV2Direction(value string) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNBackpropV2Dropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNBackpropV2Dropout(value float32) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["dropout"] = value + } +} + +// CudnnRNNBackpropV2Seed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNBackpropV2Seed(value int64) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// CudnnRNNBackpropV2Seed2 sets the optional seed2 attribute to value. 
+// If not specified, defaults to 0 +func CudnnRNNBackpropV2Seed2(value int64) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Backprop step of CudnnRNN. +// +// Compute the backprop of both data and weights in a RNN. Takes an extra +// "host_reserved" inupt than CudnnRNNBackprop, which is used to determine RNN +// cudnnRNNAlgo_t and cudnnMathType_t. +// +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicates whether there is a linear projection between the input and +// the actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, +// num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// output: A 3-D tensor with the shape of [seq_length, batch_size, +// dir * num_units]. +// output_h: The same shape has input_h. +// output_c: The same shape as input_c for LSTM. An empty tensor for other models. +// output_backprop: A 3-D tensor with the same shape as output in the forward pass. +// output_h_backprop: A 3-D tensor with the same shape as output_h in the forward +// pass. +// output_c_backprop: A 3-D tensor with the same shape as output_c in the forward +// pass. +// reserve_space: The same reserve_space produced in the forward operation. +// host_reserved: The same host_reserved produced in the forward operation. +// input_backprop: The backprop to input in the forward pass. Has the same shape +// as input. +// input_h_backprop: The backprop to input_h in the forward pass. Has the same +// shape as input_h. +// input_c_backprop: The backprop to input_c in the forward pass. Has the same +// shape as input_c. +// params_backprop: The backprop to the params buffer in the forward pass. Has the +// same shape as params. 
+func CudnnRNNBackpropV2(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, output tf.Output, output_h tf.Output, output_c tf.Output, output_backprop tf.Output, output_h_backprop tf.Output, output_c_backprop tf.Output, reserve_space tf.Output, host_reserved tf.Output, optional ...CudnnRNNBackpropV2Attr) (input_backprop tf.Output, input_h_backprop tf.Output, input_c_backprop tf.Output, params_backprop tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CudnnRNNBackpropV2", + Input: []tf.Input{ + input, input_h, input_c, params, output, output_h, output_c, output_backprop, output_h_backprop, output_c_backprop, reserve_space, host_reserved, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// Inverse 2D fast Fourier transform. +// +// Computes the inverse 2-dimensional discrete Fourier transform over the +// inner-most 2 dimensions of `input`. +// +// Arguments: +// input: A complex tensor. +// +// Returns A complex tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their inverse 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft2 +// @end_compatibility +func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IFFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // LoadTPUEmbeddingFTRLParametersAttr is an optional argument to LoadTPUEmbeddingFTRLParameters. type LoadTPUEmbeddingFTRLParametersAttr func(optionalAttr) @@ -16756,16 +19805,898 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula return scope.AddOperation(opspec) } -// Returns (x - y)(x - y) element-wise. +// Restores tensors from a V2 checkpoint. // -// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// For backward compatibility with the V1 format, this Op currently allows +// restoring from a V1 checkpoint as well: +// - This Op first attempts to find the V2 index file pointed to by "prefix", and +// if found proceed to read it as a V2 checkpoint; +// - Otherwise the V1 read path is invoked. +// Relying on this behavior is not recommended, as the ability to fall back to read +// V1 might be deprecated and eventually removed. +// +// By default, restores the named tensors in full. If the caller wishes to restore +// specific slices of stored tensors, "shape_and_slices" should be non-empty +// strings and correspondingly well-formed. +// +// Callers must ensure all the named tensors are indeed stored in the checkpoint. +// +// Arguments: +// prefix: Must have a single element. The prefix of a V2 checkpoint. +// tensor_names: shape {N}. The names of the tensors to be restored. +// shape_and_slices: shape {N}. The slice specs of the tensors to be restored. +// Empty strings indicate that they are non-partitioned tensors. +// dtypes: shape {N}. The list of expected dtype for the tensors. Must match +// those stored in the checkpoint. +// +// Returns shape {N}. The restored tensors, whose shapes are read from the +// checkpoint directly. 
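+//
+// A minimal usage sketch (illustrative only; the checkpoint prefix and
+// tensor names below are hypothetical, and a Scope from NewScope is assumed):
+//
+//	s := NewScope()
+//	prefix := Const(s, "/tmp/model.ckpt")
+//	names := Const(s, []string{"w", "b"})
+//	slices := Const(s, []string{"", ""}) // empty specs: restore tensors in full
+//	tensors := RestoreV2(s, prefix, names, slices, []tf.DataType{tf.Float, tf.Float})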
+func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + opspec := tf.OpSpec{ + Type: "RestoreV2", + Input: []tf.Input{ + prefix, tensor_names, shape_and_slices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil { + scope.UpdateErr("RestoreV2", err) + return + } + return tensors +} + +// OneHotAttr is an optional argument to OneHot. +type OneHotAttr func(optionalAttr) + +// OneHotAxis sets the optional axis attribute to value. +// +// value: The axis to fill (default: -1, a new inner-most axis). +// If not specified, defaults to -1 +func OneHotAxis(value int64) OneHotAttr { + return func(m optionalAttr) { + m["axis"] = value + } +} + +// Returns a one-hot tensor. +// +// The locations represented by indices in `indices` take value `on_value`, +// while all other locations take value `off_value`. +// +// If the input `indices` is rank `N`, the output will have rank `N+1`, +// The new axis is created at dimension `axis` (default: the new axis is +// appended at the end). +// +// If `indices` is a scalar the output shape will be a vector of length `depth`. +// +// If `indices` is a vector of length `features`, the output shape will be: +// ``` +// features x depth if axis == -1 +// depth x features if axis == 0 +// ``` +// +// If `indices` is a matrix (batch) with shape `[batch, features]`, +// the output shape will be: +// ``` +// batch x features x depth if axis == -1 +// batch x depth x features if axis == 1 +// depth x batch x features if axis == 0 +// ``` +// +// +// Examples +// ========= +// +// Suppose that +// ``` +// indices = [0, 2, -1, 1] +// depth = 3 +// on_value = 5.0 +// off_value = 0.0 +// axis = -1 +// ``` +// +// Then output is `[4 x 3]`: +// ``` +// output = +// [5.0 0.0 0.0] // one_hot(0) +// [0.0 0.0 5.0] // one_hot(2) +// [0.0 0.0 0.0] // one_hot(-1) +// [0.0 5.0 0.0] // one_hot(1) +// ``` +// +// Suppose that +// ``` +// indices = [0, 2, -1, 1] +// depth = 3 +// on_value = 0.0 +// off_value = 3.0 +// axis = 0 +// ``` +// +// Then output is `[3 x 4]`: +// ``` +// output = +// [0.0 3.0 3.0 3.0] +// [3.0 3.0 3.0 0.0] +// [3.0 3.0 3.0 3.0] +// [3.0 0.0 3.0 3.0] +// // ^ one_hot(0) +// // ^ one_hot(2) +// // ^ one_hot(-1) +// // ^ one_hot(1) +// ``` +// +// Suppose that +// ``` +// indices = [[0, 2], [1, -1]] +// depth = 3 +// on_value = 1.0 +// off_value = 0.0 +// axis = -1 +// ``` +// +// Then output is `[2 x 2 x 3]`: +// ``` +// output = +// [ +// [1.0, 0.0, 0.0] // one_hot(0) +// [0.0, 0.0, 1.0] // one_hot(2) +// ][ +// [0.0, 1.0, 0.0] // one_hot(1) +// [0.0, 0.0, 0.0] // one_hot(-1) +// ] +// ``` +// +// Arguments: +// indices: A tensor of indices. +// depth: A scalar defining the depth of the one hot dimension. +// on_value: A scalar defining the value to fill in output when `indices[j] = i`. +// off_value: A scalar defining the value to fill in output when `indices[j] != i`. +// +// Returns The one-hot tensor. 
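+//
+// A minimal usage sketch of the first example above (illustrative only;
+// a Scope from NewScope is assumed):
+//
+//	s := NewScope()
+//	indices := Const(s, []int32{0, 2, -1, 1})
+//	out := OneHot(s, indices, Const(s, int32(3)), Const(s, float32(5)), Const(s, float32(0)))
+//	// out has shape [4, 3], matching the `[4 x 3]` output shown above.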
+func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OneHot", + Input: []tf.Input{ + indices, depth, on_value, off_value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Split a `SparseTensor` into `num_split` tensors along one dimension. +// +// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices +// `[0 : shape[split_dim] % num_split]` gets one extra dimension. +// For example, if `split_dim = 1` and `num_split = 2` and the input is +// +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] +// +// Graphically the output tensors are: +// +// output_tensor[0] = shape = [2, 4] +// [ a ] +// [b c ] +// +// output_tensor[1] = shape = [2, 3] +// [ d e ] +// [ ] +// +// Arguments: +// split_dim: 0-D. The dimension along which to split. Must be in the range +// `[0, rank(shape))`. +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. +// num_split: The number of ways to split. +// +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. +func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_split": num_split} + opspec := tf.OpSpec{ + Type: "SparseSplit", + Input: []tf.Input{ + split_dim, indices, values, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil { + scope.UpdateErr("SparseSplit", err) + return + } + if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil { + scope.UpdateErr("SparseSplit", err) + return + } + if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil { + scope.UpdateErr("SparseSplit", err) + return + } + return output_indices, output_values, output_shape +} + +// UnicodeDecodeAttr is an optional argument to UnicodeDecode. +type UnicodeDecodeAttr func(optionalAttr) + +// UnicodeDecodeErrors sets the optional errors attribute to value. +// +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. 
+// If not specified, defaults to "replace" +func UnicodeDecodeErrors(value string) UnicodeDecodeAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeDecodeReplacementChar sets the optional replacement_char attribute to value. +// +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character is +// 0xFFFD or U+65533.) +// If not specified, defaults to 65533 +func UnicodeDecodeReplacementChar(value int64) UnicodeDecodeAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// UnicodeDecodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. +// +// value: Whether to replace the C0 control characters (00-1F) with the +// `replacement_char`. Default is false. +// If not specified, defaults to false +func UnicodeDecodeReplaceControlCharacters(value bool) UnicodeDecodeAttr { + return func(m optionalAttr) { + m["replace_control_characters"] = value + } +} + +// UnicodeDecodeTsplits sets the optional Tsplits attribute to value. +// If not specified, defaults to DT_INT64 +func UnicodeDecodeTsplits(value tf.DataType) UnicodeDecodeAttr { + return func(m optionalAttr) { + m["Tsplits"] = value + } +} + +// Decodes each string in `input` into a sequence of Unicode code points. +// +// The character codepoints for all strings are returned using a single vector +// `char_values`, with strings expanded to characters in row-major order. +// +// The `row_splits` tensor indicates where the codepoints for +// each input string begin and end within the `char_values` tensor. +// In particular, the values for the `i`th +// string (in row-major order) are stored in the slice +// `[row_splits[i]:row_splits[i+1]]`. Thus: +// +// * `char_values[row_splits[i]+j]` is the Unicode codepoint for the `j`th +// character in the `i`th string (in row-major order). +// * `row_splits[i+1] - row_splits[i]` is the number of characters in the `i`th +// string (in row-major order). +// +// Arguments: +// input: The text to be decoded. Can have any shape. Note that the output is flattened +// to a vector of char values. +// input_encoding: Text encoding of the input strings. This is any of the encodings supported +// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. +// +// Returns A 1D int32 tensor containing the row splits.A 1D int32 Tensor containing the decoded codepoints. +func UnicodeDecode(scope *Scope, input tf.Output, input_encoding string, optional ...UnicodeDecodeAttr) (row_splits tf.Output, char_values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"input_encoding": input_encoding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UnicodeDecode", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// OutfeedDequeueAttr is an optional argument to OutfeedDequeue. +type OutfeedDequeueAttr func(optionalAttr) + +// OutfeedDequeueDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. 
+// If not specified, defaults to -1 +func OutfeedDequeueDeviceOrdinal(value int64) OutfeedDequeueAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// Retrieves a single tensor from the computation outfeed. +// +// This operation will block indefinitely until data is available. +// +// Arguments: +// dtype: The type of elements in the tensor. +// shape: The shape of the tensor. +// +// Returns A tensor that will be read from the device outfeed. +func OutfeedDequeue(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...OutfeedDequeueAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OutfeedDequeue", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingProximalAdagradParametersAttr is an optional argument to LoadTPUEmbeddingProximalAdagradParameters. +type LoadTPUEmbeddingProximalAdagradParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingProximalAdagradParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingProximalAdagradParametersTableId(value int64) LoadTPUEmbeddingProximalAdagradParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingProximalAdagradParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingProximalAdagradParametersTableName(value string) LoadTPUEmbeddingProximalAdagradParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load proximal Adagrad embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the proximal Adagrad optimization algorithm. +// accumulators: Value of accumulators used in the proximal Adagrad optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingProximalAdagradParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingProximalAdagradParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingProximalAdagradParameters", + Input: []tf.Input{ + parameters, accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// ConfigureDistributedTPUAttr is an optional argument to ConfigureDistributedTPU. +type ConfigureDistributedTPUAttr func(optionalAttr) + +// ConfigureDistributedTPUEmbeddingConfig sets the optional embedding_config attribute to value. +// +// value: Reserved. Do not use. +// If not specified, defaults to "" +func ConfigureDistributedTPUEmbeddingConfig(value string) ConfigureDistributedTPUAttr { + return func(m optionalAttr) { + m["embedding_config"] = value + } +} + +// ConfigureDistributedTPUTpuEmbeddingConfig sets the optional tpu_embedding_config attribute to value. 
+// +// value: Serialized tensorflow.tpu.TPUEmbeddingConfiguration that +// describes the embedding lookups of the program. +// If not specified, defaults to "" +func ConfigureDistributedTPUTpuEmbeddingConfig(value string) ConfigureDistributedTPUAttr { + return func(m optionalAttr) { + m["tpu_embedding_config"] = value + } +} + +// ConfigureDistributedTPUIsGlobalInit sets the optional is_global_init attribute to value. +// +// value: Reserved. Do not use. +// If not specified, defaults to false +func ConfigureDistributedTPUIsGlobalInit(value bool) ConfigureDistributedTPUAttr { + return func(m optionalAttr) { + m["is_global_init"] = value + } +} + +// Sets up the centralized structures for a distributed TPU system. +// +// Returns A serialized tensorflow.tpu.TopologyProto that describes the TPU +// topology. +func ConfigureDistributedTPU(scope *Scope, optional ...ConfigureDistributedTPUAttr) (topology tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ConfigureDistributedTPU", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomStandardNormalAttr is an optional argument to RandomStandardNormal. +type RandomStandardNormalAttr func(optionalAttr) + +// RandomStandardNormalSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomStandardNormalSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a normal distribution. +// +// The generated values will have mean 0 and standard deviation 1. +// +// Arguments: +// shape: The shape of the output tensor. +// dtype: The type of the output. +// +// Returns A tensor of the specified shape filled with random normal values. +func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomStandardNormal", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. +type ResourceApplyAdagradAttr func(optionalAttr) + +// ResourceApplyAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value. 
+// If not specified, defaults to true +func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr { + return func(m optionalAttr) { + m["update_slots"] = value + } +} + +// Update '*var' according to the adagrad scheme. +// +// accum += grad * grad +// var -= lr * grad * (1 / sqrt(accum)) +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, optional ...ResourceApplyAdagradAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdagrad", + Input: []tf.Input{ + var_, accum, lr, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. +type ResourceApplyProximalAdagradAttr func(optionalAttr) + +// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. +// +// accum += grad * grad +// prox_v = var - lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyProximalAdagrad", + Input: []tf.Input{ + var_, accum, lr, l1, l2, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// ApproximateEqualAttr is an optional argument to ApproximateEqual. +type ApproximateEqualAttr func(optionalAttr) + +// ApproximateEqualTolerance sets the optional tolerance attribute to value. +// If not specified, defaults to 1e-05 +func ApproximateEqualTolerance(value float32) ApproximateEqualAttr { + return func(m optionalAttr) { + m["tolerance"] = value + } +} + +// Returns the truth value of abs(x-y) < tolerance element-wise. +func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...ApproximateEqualAttr) (z tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ApproximateEqual", + Input: []tf.Input{ + x, y, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Pop the element at the top of the stack. +// +// Arguments: +// handle: The handle to a stack. 
+// elem_type: The type of the elem that is popped. +// +// Returns The tensor that is popped from the top of the stack. +func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"elem_type": elem_type} + opspec := tf.OpSpec{ + Type: "StackPopV2", + Input: []tf.Input{ + handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Fills empty rows in the input 2-D `SparseTensor` with a default value. +// +// The input `SparseTensor` is represented via the tuple of inputs +// (`indices`, `values`, `dense_shape`). The output `SparseTensor` has the +// same `dense_shape` but with indices `output_indices` and values +// `output_values`. +// +// This op inserts a single entry for every row that doesn't have any values. +// The index is created as `[row, 0, ..., 0]` and the inserted value +// is `default_value`. +// +// For example, suppose `sp_input` has shape `[5, 6]` and non-empty values: +// +// [0, 1]: a +// [0, 3]: b +// [2, 0]: c +// [3, 1]: d +// +// Rows 1 and 4 are empty, so the output will be of shape `[5, 6]` with values: +// +// [0, 1]: a +// [0, 3]: b +// [1, 0]: default_value +// [2, 0]: c +// [3, 1]: d +// [4, 0]: default_value +// +// The output `SparseTensor` will be in row-major order and will have the +// same shape as the input. +// +// This op also returns an indicator vector shaped `[dense_shape[0]]` such that +// +// empty_row_indicator[i] = True iff row i was an empty row. +// +// And a reverse index map vector shaped `[indices.shape[0]]` that is used during +// backpropagation, +// +// reverse_index_map[j] = out_j s.t. indices[j, :] == output_indices[out_j, :] +// +// Arguments: +// indices: 2-D. the indices of the sparse tensor. +// values: 1-D. the values of the sparse tensor. +// dense_shape: 1-D. the shape of the sparse tensor. +// default_value: 0-D. default value to insert into location `[row, 0, ..., 0]` +// for rows missing from the input sparse tensor. +// output indices: 2-D. the indices of the filled sparse tensor. +// +// Returns 1-D. the values of the filled sparse tensor.1-D. whether the dense row was missing in the +// input sparse tensor.1-D. a map from the input indices to the output indices. +func SparseFillEmptyRows(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output, default_value tf.Output) (output_indices tf.Output, output_values tf.Output, empty_row_indicator tf.Output, reverse_index_map tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SquaredDifference", + Type: "SparseFillEmptyRows", + Input: []tf.Input{ + indices, values, dense_shape, default_value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// EnqueueTPUEmbeddingSparseBatchAttr is an optional argument to EnqueueTPUEmbeddingSparseBatch. +type EnqueueTPUEmbeddingSparseBatchAttr func(optionalAttr) + +// EnqueueTPUEmbeddingSparseBatchDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. Should be >= 0 and less than the number +// of TPU cores in the task on which the node is placed. +// If not specified, defaults to -1 +func EnqueueTPUEmbeddingSparseBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingSparseBatchAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// EnqueueTPUEmbeddingSparseBatchCombiners sets the optional combiners attribute to value. 
+// +// value: A list of string scalars, one for each embedding table that specify +// how to normalize the embedding activations after weighted summation. +// Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have +// the sum of the weights be 0 for 'mean' or the sum of the squared weights be +// 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for +// all tables. +// If not specified, defaults to <> +func EnqueueTPUEmbeddingSparseBatchCombiners(value []string) EnqueueTPUEmbeddingSparseBatchAttr { + return func(m optionalAttr) { + m["combiners"] = value + } +} + +// An op that enqueues TPUEmbedding input indices from a SparseTensor. +// +// This Op eases the porting of code that uses embedding_lookup_sparse(), +// although some Python preprocessing of the SparseTensor arguments to +// embedding_lookup_sparse() is required to produce the arguments to this Op, +// since only a single EnqueueTPUEmbeddingSparseBatch Op is allowed per training +// step. +// +// The tensors at corresponding positions in the three input lists +// must have the same shape, i.e. rank 1 with dim_size() equal to the total +// number of lookups into the table described by the corresponding table_id. +// +// Arguments: +// sample_indices: A list of rank 1 Tensors specifying the training example and +// feature to which the corresponding embedding_indices and aggregation_weights +// values belong. sample_indices[i] must equal b * nf + f, where nf is the +// number of features from the corresponding table, f is in [0, nf), and +// b is in [0, batch size). +// embedding_indices: A list of rank 1 Tensors, indices into the embedding tables. +// aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e. per +// (training example, feature) -- aggregation weights. +// mode_override: A string input that overrides the mode specified in the +// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', +// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set +// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. +// +// Returns the created operation. +func EnqueueTPUEmbeddingSparseBatch(scope *Scope, sample_indices []tf.Output, embedding_indices []tf.Output, aggregation_weights []tf.Output, mode_override tf.Output, optional ...EnqueueTPUEmbeddingSparseBatchAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EnqueueTPUEmbeddingSparseBatch", + Input: []tf.Input{ + tf.OutputList(sample_indices), tf.OutputList(embedding_indices), tf.OutputList(aggregation_weights), mode_override, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr is an optional argument to RetrieveTPUEmbeddingStochasticGradientDescentParameters. +type RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve SGD embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the stochastic gradient descent optimization algorithm. +func RetrieveTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr) (parameters tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingStochasticGradientDescentParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns element-wise remainder of division. When `x < 0` xor `y < 0` is +// +// true, this follows Python semantics in that the result here is consistent +// with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. +// +// *NOTE*: `FloorMod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func FloorMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FloorMod", Input: []tf.Input{ x, y, }, @@ -16774,23 +20705,1521 @@ func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Push an element onto the tensor_array. -// -// Arguments: -// handle: The handle to a TensorArray. -// index: The position to write to inside the TensorArray. -// value: The tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Returns 0 if x == 0, and x / y otherwise, elementwise. +func Xdivy(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayWriteV3", + Type: "Xdivy", Input: []tf.Input{ - handle, index, value, flow_in, + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// The gradient operator for the SparseSlice op. +// +// This op takes in the upstream gradient w.r.t. non-empty values of +// the sliced `SparseTensor`, and outputs the gradients w.r.t. +// the non-empty values of input `SparseTensor`. +// +// Arguments: +// backprop_val_grad: 1-D. The gradient with respect to +// the non-empty values of the sliced `SparseTensor`. +// input_indices: 2-D. The `indices` of the input `SparseTensor`. +// input_start: 1-D. tensor represents the start of the slice. +// output_indices: 2-D. The `indices` of the sliced `SparseTensor`. +// +// Returns 1-D. The gradient with respect to the non-empty values of input `SparseTensor`. 
+func SparseSliceGrad(scope *Scope, backprop_val_grad tf.Output, input_indices tf.Output, input_start tf.Output, output_indices tf.Output) (val_grad tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSliceGrad", + Input: []tf.Input{ + backprop_val_grad, input_indices, input_start, output_indices, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug. +type RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve Adadelta embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the Adadelta optimization algorithm.Parameter accumulators updated by the Adadelta optimization algorithm.Parameter updates updated by the Adadelta optimization algorithm.Parameter gradient_accumulators updated by the Adadelta optimization algorithm. +func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, updates tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// ReduceJoinAttr is an optional argument to ReduceJoin. +type ReduceJoinAttr func(optionalAttr) + +// ReduceJoinKeepDims sets the optional keep_dims attribute to value. +// +// value: If `True`, retain reduced dimensions with length `1`. +// If not specified, defaults to false +func ReduceJoinKeepDims(value bool) ReduceJoinAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// ReduceJoinSeparator sets the optional separator attribute to value. +// +// value: The separator to use when joining. +// If not specified, defaults to "" +func ReduceJoinSeparator(value string) ReduceJoinAttr { + return func(m optionalAttr) { + m["separator"] = value + } +} + +// Joins a string Tensor across the given dimensions. +// +// Computes the string join across dimensions in the given string Tensor of shape +// `[\\(d_0, d_1, ..., d_{n-1}\\)]`. 
Returns a new Tensor created by joining the input +// strings with the given separator (default: empty string). Negative indices are +// counted backwards from the end, with `-1` being equivalent to `n - 1`. If +// indices are not specified, joins across all dimensions beginning from `n - 1` +// through `0`. +// +// For example: +// +// ```python +// # tensor `a` is [["a", "b"], ["c", "d"]] +// tf.reduce_join(a, 0) ==> ["ac", "bd"] +// tf.reduce_join(a, 1) ==> ["ab", "cd"] +// tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==> ["ac", "bd"] +// tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"] +// tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]] +// tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]] +// tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"] +// tf.reduce_join(a, [0, 1]) ==> "acbd" +// tf.reduce_join(a, [1, 0]) ==> "abcd" +// tf.reduce_join(a, []) ==> [["a", "b"], ["c", "d"]] +// tf.reduce_join(a) = tf.reduce_join(a, [1, 0]) ==> "abcd" +// ``` +// +// Arguments: +// inputs: The input to be joined. All reduced indices must have non-zero size. +// reduction_indices: The dimensions to reduce over. Dimensions are reduced in the +// order specified. Omitting `reduction_indices` is equivalent to passing +// `[n-1, n-2, ..., 0]`. Negative indices from `-n` to `-1` are supported. +// +// Returns Has shape equal to that of the input with reduced dimensions removed or +// set to `1` depending on `keep_dims`. +func ReduceJoin(scope *Scope, inputs tf.Output, reduction_indices tf.Output, optional ...ReduceJoinAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ReduceJoin", + Input: []tf.Input{ + inputs, reduction_indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the max of x and y (i.e. x > y ? x : y) element-wise. +// +// *NOTE*: `Maximum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Maximum", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that shuffles and repeats elements from `input_dataset` +// +// pseudorandomly. +// +// Arguments: +// +// buffer_size: The number of output elements to buffer in an iterator over +// this dataset. Compare with the `min_after_dequeue` attr when creating a +// `RandomShuffleQueue`. +// seed: A scalar seed for the random number generator. If either `seed` or +// `seed2` is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. +// count: A scalar representing the number of times the underlying dataset +// should be repeated. The default is `-1`, which results in infinite repetition. 
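A sketch of the `ReduceJoin` wrapper from Go, mirroring the `tf.reduce_join(a, 0, separator=".")` case in its doc comment above; it assumes the same Go bindings as the previous sketch and that `op.Const` accepts nested string slices in this build:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Join the rows of a 2x2 string tensor along axis 0 with "." as separator.
	a := op.Const(s, [][]string{{"a", "b"}, {"c", "d"}})
	axis := op.Const(s, int32(0))
	joined := op.ReduceJoin(s, a, axis, op.ReduceJoinSeparator("."))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{joined}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // expected: [a.c b.d]
}
```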
+// +// +func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ShuffleAndRepeatDataset", + Input: []tf.Input{ + input_dataset, buffer_size, seed, seed2, count, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. +type ResourceSparseApplyAdagradDAAttr func(optionalAttr) + +// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. +// +// Arguments: +// var_: Should be from a Variable(). +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. +// +// Returns the created operation. +func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyAdagradDA", + Input: []tf.Input{ + var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Applies softmax to a batched N-D `SparseTensor`. +// +// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` +// (where `N >= 2`), and with indices sorted in the canonical lexicographic order. +// +// This op is equivalent to applying the normal `tf.nn.softmax()` to each innermost +// logical submatrix with shape `[B, C]`, but with the catch that *the implicitly +// zero elements do not participate*. Specifically, the algorithm is equivalent +// to the following: +// +// (1) Applies `tf.nn.softmax()` to a densified view of each innermost submatrix +// with shape `[B, C]`, along the size-C dimension; +// (2) Masks out the original implicitly-zero locations; +// (3) Renormalizes the remaining elements. +// +// Hence, the `SparseTensor` result has exactly the same non-zero indices and +// shape. +// +// Arguments: +// sp_indices: 2-D. `NNZ x R` matrix with the indices of non-empty values in a +// SparseTensor, in canonical ordering. +// sp_values: 1-D. `NNZ` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. 
+// +// Returns 1-D. The `NNZ` values for the result `SparseTensor`. +func SparseSoftmax(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSoftmax", + Input: []tf.Input{ + sp_indices, sp_values, sp_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAdagradDAAttr is an optional argument to ResourceApplyAdagradDA. +type ResourceApplyAdagradDAAttr func(optionalAttr) + +// ResourceApplyAdagradDAUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyAdagradDAUseLocking(value bool) ResourceApplyAdagradDAAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the proximal adagrad scheme. +// +// Arguments: +// var_: Should be from a Variable(). +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. +// +// Returns the created operation. +func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceApplyAdagradDAAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdagradDA", + Input: []tf.Input{ + var_, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// TensorListConcatAttr is an optional argument to TensorListConcat. +type TensorListConcatAttr func(optionalAttr) + +// TensorListConcatElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorListConcatElementShape(value tf.Shape) TensorListConcatAttr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Concats all tensors in the list along the 0th dimension. +// +// Requires that all tensors have the same shape except the first dimension. +// +// input_handle: The input list. +// tensor: The concated result. +// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. +// +func TensorListConcat(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListConcatAttr) (tensor tf.Output, lengths tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorListConcat", + Input: []tf.Input{ + input_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Returns the diagonal part of the tensor. +// +// This operation returns a tensor with the `diagonal` part +// of the `input`. 
The `diagonal` part is computed as follows: +// +// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a +// tensor of rank `k` with dimensions `[D1,..., Dk]` where: +// +// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`. +// +// For example: +// +// ``` +// # 'input' is [[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]] +// +// tf.diag_part(input) ==> [1, 2, 3, 4] +// ``` +// +// Arguments: +// input: Rank k tensor where k is even and not zero. +// +// Returns The extracted diagonal. +func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DiagPart", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that contains `count` elements from the `input_dataset`. +// +// Arguments: +// +// count: A scalar representing the number of elements from the `input_dataset` +// that should be taken. A value of `-1` indicates that all of `input_dataset` +// is taken. +// +// +func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "TakeDataset", + Input: []tf.Input{ + input_dataset, count, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Concats all tensors in the list along the 0th dimension. +// +// Requires that all tensors have the same shape except the first dimension. +// +// input_handle: The input list. +// element_shape: The shape of the uninitialized elements in the list. If the first +// dimension is not -1, it is assumed that all list elements have the same +// leading dim. +// leading_dims: The list of leading dims of uninitialized list elements. Used if +// the leading dim of input_handle.element_shape or the element_shape input arg +// is not already set. +// tensor: The concated result. +// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. +// +func TensorListConcatV2(scope *Scope, input_handle tf.Output, element_shape tf.Output, leading_dims tf.Output, element_dtype tf.DataType) (tensor tf.Output, lengths tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + opspec := tf.OpSpec{ + Type: "TensorListConcatV2", + Input: []tf.Input{ + input_handle, element_shape, leading_dims, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// CropAndResizeAttr is an optional argument to CropAndResize. +type CropAndResizeAttr func(optionalAttr) + +// CropAndResizeMethod sets the optional method attribute to value. +// +// value: A string specifying the sampling method for resizing. It can be either +// `"bilinear"` or `"nearest"` and default to `"bilinear"`. Currently two sampling +// methods are supported: Bilinear and Nearest Neighbor. +// If not specified, defaults to "bilinear" +func CropAndResizeMethod(value string) CropAndResizeAttr { + return func(m optionalAttr) { + m["method"] = value + } +} + +// CropAndResizeExtrapolationValue sets the optional extrapolation_value attribute to value. +// +// value: Value used for extrapolation, when applicable. 
+// If not specified, defaults to 0 +func CropAndResizeExtrapolationValue(value float32) CropAndResizeAttr { + return func(m optionalAttr) { + m["extrapolation_value"] = value + } +} + +// Extracts crops from the input image tensor and resizes them. +// +// Extracts crops from the input image tensor and resizes them using bilinear +// sampling or nearest neighbor sampling (possibly with aspect ratio change) to a +// common output size specified by `crop_size`. This is more general than the +// `crop_to_bounding_box` op which extracts a fixed size slice from the input image +// and does not allow resizing or aspect ratio change. +// +// Returns a tensor with `crops` from the input `image` at positions defined at the +// bounding box locations in `boxes`. The cropped boxes are all resized (with +// bilinear or nearest neighbor interpolation) to a fixed +// `size = [crop_height, crop_width]`. The result is a 4-D tensor +// `[num_boxes, crop_height, crop_width, depth]`. The resizing is corner aligned. +// In particular, if `boxes = [[0, 0, 1, 1]]`, the method will give identical +// results to using `tf.image.resize_bilinear()` or +// `tf.image.resize_nearest_neighbor()`(depends on the `method` argument) with +// `align_corners=True`. +// +// Arguments: +// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +// Both `image_height` and `image_width` need to be positive. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1]` in image height coordinates. We do allow `y1` > `y2`, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// crop_size: A 1-D tensor of 2 elements, `size = [crop_height, crop_width]`. All +// cropped image patches are resized to this size. The aspect ratio of the image +// content is not preserved. Both `crop_height` and `crop_width` need to be +// positive. +// +// Returns A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +func CropAndResize(scope *Scope, image tf.Output, boxes tf.Output, box_ind tf.Output, crop_size tf.Output, optional ...CropAndResizeAttr) (crops tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CropAndResize", + Input: []tf.Input{ + image, boxes, box_ind, crop_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a MultiDeviceIterator resource. +// +// Arguments: +// devices: A list of devices the iterator works across. +// shared_name: If non-empty, this resource will be shared under the given name +// across multiple sessions. +// container: If non-empty, this resource is placed in the given container. +// Otherwise, a default container is used. +// output_types: The type list for the return values. 
+// output_shapes: The list of shapes being produced. +// +// Returns Handle to the resource created. +func MultiDeviceIterator(scope *Scope, devices []string, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"devices": devices, "shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "MultiDeviceIterator", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes rectified linear 6: `min(max(features, 0), 6)`. +func Relu6(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Relu6", + Input: []tf.Input{ + features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Merges summaries. +// +// This op creates a +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// protocol buffer that contains the union of all the values in the input +// summaries. +// +// When the Op is run, it reports an `InvalidArgument` error if multiple values +// in the summaries to merge use the same tag. +// +// Arguments: +// inputs: Can be of any shape. Each must contain serialized `Summary` protocol +// buffers. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MergeSummary", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 3D real-valued fast Fourier transform. +// +// Computes the 3-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 3 dimensions of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. +// +// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// +// Arguments: +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. +// +// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the their 3D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfftn with 3 dimensions. +// @end_compatibility +func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RFFT3D", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. +type ResourceSparseApplyRMSPropAttr func(optionalAttr) + +// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. 
+// +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the RMSProp algorithm. +// +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +// +// Arguments: +// var_: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. +// +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var, ms and mom. +// +// Returns the created operation. +func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyRMSProp", + Input: []tf.Input{ + var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Eagerly executes a python function to compute func(input)->output. The +// +// semantics of the input, output, and attributes are the same as those for +// PyFunc. +func EagerPyFunc(scope *Scope, input []tf.Output, token string, Tout []tf.DataType) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"token": token, "Tout": Tout} + opspec := tf.OpSpec{ + Type: "EagerPyFunc", + Input: []tf.Input{ + tf.OutputList(input), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("EagerPyFunc", err) + return + } + return output +} + +// Computes softmax cross entropy cost and gradients to backpropagate. +// +// Inputs are the logits, not probabilities. +// +// Arguments: +// features: batch_size x num_classes matrix +// labels: batch_size x num_classes matrix +// The caller must ensure that each batch of labels represents a valid +// probability distribution. +// +// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). +func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SoftmaxCrossEntropyWithLogits", + Input: []tf.Input{ + features, labels, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. 
+type MutableHashTableOfTensorsV2Attr func(optionalAttr) + +// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. +// If not specified, defaults to <> +func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a vector. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableOfTensorsV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. +type ExtractGlimpseAttr func(optionalAttr) + +// ExtractGlimpseCentered sets the optional centered attribute to value. +// +// value: indicates if the offset coordinates are centered relative to +// the image, in which case the (0, 0) offset is relative to the center +// of the input images. If false, the (0,0) offset corresponds to the +// upper left corner of the input images. +// If not specified, defaults to true +func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["centered"] = value + } +} + +// ExtractGlimpseNormalized sets the optional normalized attribute to value. +// +// value: indicates if the offset coordinates are normalized. +// If not specified, defaults to true +func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["normalized"] = value + } +} + +// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. +// +// value: indicates if the noise should be generated using a +// uniform distribution or a Gaussian distribution. 
+// If not specified, defaults to true +func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["uniform_noise"] = value + } +} + +// ExtractGlimpseNoise sets the optional noise attribute to value. +// +// value: indicates if the noise should `uniform`, `gaussian`, or +// `zero`. The default is `uniform` which means the the noise type +// will be decided by `uniform_noise`. +// If not specified, defaults to "uniform" +func ExtractGlimpseNoise(value string) ExtractGlimpseAttr { + return func(m optionalAttr) { + m["noise"] = value + } +} + +// Extracts a glimpse from the input tensor. +// +// Returns a set of windows called glimpses extracted at location +// `offsets` from the input tensor. If the windows only partially +// overlaps the inputs, the non overlapping areas will be filled with +// random noise. +// +// The result is a 4-D tensor of shape `[batch_size, glimpse_height, +// glimpse_width, channels]`. The channels and batch dimensions are the +// same as that of the input tensor. The height and width of the output +// windows are specified in the `size` parameter. +// +// The argument `normalized` and `centered` controls how the windows are built: +// +// * If the coordinates are normalized but not centered, 0.0 and 1.0 +// correspond to the minimum and maximum of each height and width +// dimension. +// * If the coordinates are both normalized and centered, they range from +// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper +// left corner, the lower right corner is located at (1.0, 1.0) and the +// center is at (0, 0). +// * If the coordinates are not normalized they are interpreted as +// numbers of pixels. +// +// Arguments: +// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. +// size: A 1-D tensor of 2 elements containing the size of the glimpses +// to extract. The glimpse height must be specified first, following +// by the glimpse width. +// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing +// the y, x locations of the center of each window. +// +// Returns A tensor representing the glimpses `[batch_size, +// glimpse_height, glimpse_width, channels]`. +func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ExtractGlimpse", + Input: []tf.Input{ + input, size, offsets, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax. +type ResourceApplyAdaMaxAttr func(optionalAttr) + +// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the AdaMax algorithm. +// +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// v_t <- max(beta2 * v_{t-1}, abs(g)) +// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). 
+// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdaMax", + Input: []tf.Input{ + var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// EncodeBase64Attr is an optional argument to EncodeBase64. +type EncodeBase64Attr func(optionalAttr) + +// EncodeBase64Pad sets the optional pad attribute to value. +// +// value: Bool whether padding is applied at the ends. +// If not specified, defaults to false +func EncodeBase64Pad(value bool) EncodeBase64Attr { + return func(m optionalAttr) { + m["pad"] = value + } +} + +// Encode strings into web-safe base64 format. +// +// Refer to the following article for more information on base64 format: +// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the +// end so that the encoded has length multiple of 4. See Padding section of the +// link above. +// +// Web-safe means that the encoder uses - and _ instead of + and /. +// +// Arguments: +// input: Strings to be encoded. +// +// Returns Input strings encoded in base64. +func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EncodeBase64", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inverse 2D real-valued fast Fourier transform. +// +// Computes the inverse 2-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most 2 dimensions of `input`. +// +// The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`: +// The inner-most dimension contains the `fft_length / 2 + 1` unique components of +// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed +// from the size of the inner-most 2 dimensions of `input`. If the FFT length used +// to compute `input` is odd, it should be provided since it cannot be inferred +// properly. +// +// Along each axis `IRFFT2D` is computed on, if `fft_length` (or +// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// +// Arguments: +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. +// +// Returns A float32 tensor of the same rank as `input`. The inner-most 2 +// dimensions of `input` are replaced with the `fft_length` samples of their +// inverse 2D Fourier transform. 
+// +// @compatibility(numpy) +// Equivalent to np.fft.irfft2 +// @end_compatibility +func IRFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IRFFT2D", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// An Op to sum inputs across replicated TPU instances. +// +// Each instance supplies its own input. +// +// For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`. +// Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, +// and `B, D, F, H` as group 1. Thus we get the outputs: +// `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. +// +// Arguments: +// input: The local input to the sum. +// group_assignment: An int32 tensor with shape +// [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the +// replica ids in the ith subgroup. +// +// Returns The sum of all the distributed inputs. +func CrossReplicaSum(scope *Scope, input tf.Output, group_assignment tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CrossReplicaSum", + Input: []tf.Input{ + input, group_assignment, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the mean along sparse segments of a tensor. +// +// See `tf.sparse.segment_sum` for usage examples. +// +// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first +// dimension, selecting a subset of dimension 0, specified by `indices`. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentMean", + Input: []tf.Input{ + data, indices, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. +type ResourceApplyProximalGradientDescentAttr func(optionalAttr) + +// ResourceApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalGradientDescentUseLocking(value bool) ResourceApplyProximalGradientDescentAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' as FOBOS algorithm with fixed learning rate. +// +// prox_v = var - alpha * delta +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// delta: The change. +// +// Returns the created operation. 
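A sketch of `SparseSegmentMean` as defined above: `indices` selects rows 0 and 1 of `data`, and both are mapped to segment 0, so the single output row is their element-wise mean. The bindings and the toy values are the same assumptions as in the earlier sketches:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Average rows 0 and 1 of `data` into segment 0: mean([1 2], [3 4]) == [2 3].
	data := op.Const(s, [][]float32{{1, 2}, {3, 4}, {5, 6}})
	indices := op.Const(s, []int32{0, 1})
	segmentIDs := op.Const(s, []int32{0, 0})
	mean := op.SparseSegmentMean(s, data, indices, segmentIDs)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{mean}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // expected: [[2 3]]
}
```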
+func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, delta tf.Output, optional ...ResourceApplyProximalGradientDescentAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyProximalGradientDescent", + Input: []tf.Input{ + var_, alpha, l1, l2, delta, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Decode web-safe base64-encoded strings. +// +// Input may or may not have padding at the end. See EncodeBase64 for padding. +// Web-safe means that input must use - and _ instead of + and /. +// +// Arguments: +// input: Base64 strings to decode. +// +// Returns Decoded strings. +func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DecodeBase64", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// CastAttr is an optional argument to Cast. +type CastAttr func(optionalAttr) + +// CastTruncate sets the optional Truncate attribute to value. +// If not specified, defaults to false +func CastTruncate(value bool) CastAttr { + return func(m optionalAttr) { + m["Truncate"] = value + } +} + +// Cast x of type SrcT to y of DstT. +func Cast(scope *Scope, x tf.Output, DstT tf.DataType, optional ...CastAttr) (y tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"DstT": DstT} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Cast", + Input: []tf.Input{ + x, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// EnqueueTPUEmbeddingIntegerBatchAttr is an optional argument to EnqueueTPUEmbeddingIntegerBatch. +type EnqueueTPUEmbeddingIntegerBatchAttr func(optionalAttr) + +// EnqueueTPUEmbeddingIntegerBatchDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. Should be >= 0 and less than the number +// of TPU cores in the task on which the node is placed. +// If not specified, defaults to -1 +func EnqueueTPUEmbeddingIntegerBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingIntegerBatchAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// An op that enqueues a list of input batch tensors to TPUEmbedding. +// +// Arguments: +// batch: A list of 1D tensors, one for each embedding table, containing the +// indices into the tables. +// mode_override: A string input that overrides the mode specified in the +// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', +// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set +// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. +// +// Returns the created operation. +func EnqueueTPUEmbeddingIntegerBatch(scope *Scope, batch []tf.Output, mode_override tf.Output, optional ...EnqueueTPUEmbeddingIntegerBatchAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EnqueueTPUEmbeddingIntegerBatch", + Input: []tf.Input{ + tf.OutputList(batch), mode_override, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// StringJoinAttr is an optional argument to StringJoin. 
+type StringJoinAttr func(optionalAttr) + +// StringJoinSeparator sets the optional separator attribute to value. +// +// value: string, an optional join separator. +// If not specified, defaults to "" +func StringJoinSeparator(value string) StringJoinAttr { + return func(m optionalAttr) { + m["separator"] = value + } +} + +// Joins the strings in the given list of string tensors into one tensor; +// +// with the given separator (default is an empty separator). +// +// Arguments: +// inputs: A list of string tensors. The tensors must all have the same shape, +// or be scalars. Scalars may be mixed in; these will be broadcast to the shape +// of non-scalar inputs. +func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StringJoin", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyKerasMomentumAttr is an optional argument to ResourceApplyKerasMomentum. +type ResourceApplyKerasMomentumAttr func(optionalAttr) + +// ResourceApplyKerasMomentumUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyKerasMomentumUseLocking(value bool) ResourceApplyKerasMomentumAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceApplyKerasMomentumUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, the tensor passed to compute grad will be +// var + momentum * accum, so in the end, the var you get is actually +// var + momentum * accum. +// If not specified, defaults to false +func ResourceApplyKerasMomentumUseNesterov(value bool) ResourceApplyKerasMomentumAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the momentum scheme. Set use_nesterov = True if you +// +// want to use Nesterov momentum. +// +// accum = accum * momentum - lr * grad +// var += accum +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// grad: The gradient. +// momentum: Momentum. Must be a scalar. +// +// Returns the created operation. +func ResourceApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyKerasMomentumAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyKerasMomentum", + Input: []tf.Input{ + var_, accum, lr, grad, momentum, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. +type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) + +// ResourceSparseApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. 
+// If not specified, defaults to false +func ResourceSparseApplyProximalAdagradUseLocking(value bool) ResourceSparseApplyProximalAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Sparse update entries in '*var' and '*accum' according to FOBOS algorithm. +// +// That is for rows we have grad for, we update var and accum as follows: +// accum += grad * grad +// prox_v = var +// prox_v -= lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalAdagradAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyProximalAdagrad", + Input: []tf.Input{ + var_, accum, lr, l1, l2, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// AnyAttr is an optional argument to Any. +type AnyAttr func(optionalAttr) + +// AnyKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AnyKeepDims(value bool) AnyAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the "logical or" of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Any", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes natural logarithm of (1 + x) element-wise. +// +// I.e., \\(y = \log_e (1 + x)\\). +func Log1p(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Log1p", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated, use python implementation tf.linalg.matrix_exponential. +// +// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead. 
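A sketch that round-trips a string through the `EncodeBase64` and `DecodeBase64` wrappers defined above, and shows how the generated `...Attr` functional options (here `EncodeBase64Pad`) are passed; same binding assumptions as the earlier sketches:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Encode to web-safe base64 with '=' padding enabled, then decode back.
	in := op.Const(s, "hello gopher")
	enc := op.EncodeBase64(s, in, op.EncodeBase64Pad(true))
	dec := op.DecodeBase64(s, enc)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{enc, dec}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // web-safe base64 of "hello gopher"
	fmt.Println(out[1].Value()) // "hello gopher"
}
```

The optional-attribute closures are the same pattern used throughout this file (compare `ReduceJoinSeparator`, `CastTruncate`, the `Resource*UseLocking` setters, and so on): each setter writes one key into the attrs map before the op is added to the graph.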
+func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixExponential", + Input: []tf.Input{ + input, }, } op := scope.AddOperation(opspec) @@ -16843,45 +22272,8659 @@ func RetrieveTPUEmbeddingAdagradParameters(scope *Scope, num_shards int64, shard return op.Output(0), op.Output(1) } -// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. +// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. +type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) + +// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. // -// Each comparison returns a boolean `true` (if `input_value > threshold`) -// or and `false` otherwise. +// value: If `True`, updating of the var, mg, ms, and mom tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the centered RMSProp algorithm. // -// This operation is useful for Locality-Sensitive-Hashing (LSH) and other -// algorithms that use hashing approximations of cosine and `L2` distances; -// codes can be generated from an input via: +// The centered RMSProp algorithm uses an estimate of the centered second moment +// (i.e., the variance) for normalization, as opposed to regular RMSProp, which +// uses the (uncentered) second moment. This often helps with training, but is +// slightly more expensive in terms of computation and memory. // -// ```python -// codebook_size = 50 -// codebook_bits = codebook_size * 32 -// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], -// dtype=x.dtype, -// initializer=tf.orthogonal_initializer()) -// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) -// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 -// # now codes has shape x.shape[:-1] + [codebook_size] -// ``` +// Note that in dense implementation of this algorithm, mg, ms, and mom will +// update even if the grad is zero, but in this sparse implementation, mg, ms, +// and mom will not update in iterations during which the grad is zero. // -// **NOTE**: Currently, the innermost dimension of the tensor must be divisible -// by 8. +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// mean_grad = decay * mean_grad + (1-decay) * gradient +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) // -// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is -// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom // // Arguments: -// input: Values to compare against `threshold` and bitpack. -// threshold: Threshold to compare against. +// var_: Should be from a Variable(). +// mg: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. // -// Returns The bitpacked comparisons. -func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { +// epsilon: Ridge term. Must be a scalar. 
+// grad: The gradient. +// indices: A vector of indices into the first dimension of var, ms and mom. +// +// Returns the created operation. +func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyCenteredRMSProp", + Input: []tf.Input{ + var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. +type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) + +// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Sparse update '*var' as FOBOS algorithm with fixed learning rate. +// +// That is for rows we have grad for, we update var as follows: +// prox_v = var - alpha * grad +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyProximalGradientDescent", + Input: []tf.Input{ + var_, alpha, l1, l2, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Assigns a new value to a variable. +// +// Any ReadVariableOp with a control dependency on this op is guaranteed to return +// this value or a subsequent newer value of the variable. +// +// Arguments: +// resource: handle to the resource in which to store the variable. +// value: the value to set the new tensor to use. +// +// Returns the created operation. +func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "CompareAndBitpack", + Type: "AssignVariableOp", Input: []tf.Input{ - input, threshold, + resource, value, + }, + } + return scope.AddOperation(opspec) +} + +// Check if the input matches the regex pattern. +// +// The input is a string tensor of any shape. The pattern is the +// regular expression to be matched with every element of the input tensor. 
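// Sketch of AssignVariableOp in context (illustrative only): it pairs the op with
// VarHandleOp, defined further down in this file, and ReadVariableOp, which is
// assumed to be generated elsewhere in this package with the signature
// ReadVariableOp(scope, resource, dtype).
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    // A handle to a scalar float32 resource variable.
    v := op.VarHandleOp(s, tf.Float, tf.ScalarShape(), op.VarHandleOpSharedName("v"))
    assign := op.AssignVariableOp(s, v, op.Const(s, float32(3)))
    read := op.ReadVariableOp(s, v, tf.Float)

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    // Run the assignment first; a later ReadVariableOp then observes the value.
    if _, err := sess.Run(nil, nil, []*tf.Operation{assign}); err != nil {
        panic(err)
    }
    out, err := sess.Run(nil, []tf.Output{read}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // expected: 3
}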
+// The boolean values (True or False) of the output tensor indicate +// if the input matches the regex pattern provided. +// +// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// +// Arguments: +// input: A string tensor of the text to be processed. +// pattern: The regular expression to match the input. +// +// Returns A bool tensor with the same shape as `input`. +func StaticRegexFullMatch(scope *Scope, input tf.Output, pattern string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"pattern": pattern} + opspec := tf.OpSpec{ + Type: "StaticRegexFullMatch", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Produces a summary of any statistics recorded by the given statistics manager. +func ExperimentalStatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ExperimentalStatsAggregatorSummary", + Input: []tf.Input{ + iterator, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Makes its input available to the next iteration. +// +// Arguments: +// data: The tensor to be made available to the next iteration. +// +// Returns The same tensor as `data`. +func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NextIteration", + Input: []tf.Input{ + data, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Does nothing. Only useful as a placeholder for control edges. +// +// Returns the created operation. +func NoOp(scope *Scope) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NoOp", + } + return scope.AddOperation(opspec) +} + +// AsStringAttr is an optional argument to AsString. +type AsStringAttr func(optionalAttr) + +// AsStringPrecision sets the optional precision attribute to value. +// +// value: The post-decimal precision to use for floating point numbers. +// Only used if precision > -1. +// If not specified, defaults to -1 +func AsStringPrecision(value int64) AsStringAttr { + return func(m optionalAttr) { + m["precision"] = value + } +} + +// AsStringScientific sets the optional scientific attribute to value. +// +// value: Use scientific notation for floating point numbers. +// If not specified, defaults to false +func AsStringScientific(value bool) AsStringAttr { + return func(m optionalAttr) { + m["scientific"] = value + } +} + +// AsStringShortest sets the optional shortest attribute to value. +// +// value: Use shortest representation (either scientific or standard) for +// floating point numbers. +// If not specified, defaults to false +func AsStringShortest(value bool) AsStringAttr { + return func(m optionalAttr) { + m["shortest"] = value + } +} + +// AsStringWidth sets the optional width attribute to value. +// +// value: Pad pre-decimal numbers to this width. +// Applies to both floating point and integer numbers. +// Only used if width > -1. +// If not specified, defaults to -1 +func AsStringWidth(value int64) AsStringAttr { + return func(m optionalAttr) { + m["width"] = value + } +} + +// AsStringFill sets the optional fill attribute to value. +// +// value: The value to pad if width > -1. If empty, pads with spaces. +// Another typical value is '0'. String cannot be longer than 1 character. 
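// Sketch showing StaticRegexFullMatch applied to a small batch of strings
// (illustrative; same import-path assumptions as the earlier sketches).
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    lines := op.Const(s, []string{"abc123", "xyz", "a1"})
    // RE2 pattern: one or more letters followed by one or more digits.
    matched := op.StaticRegexFullMatch(s, lines, `[a-z]+[0-9]+`)

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{matched}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // expected: [true false true]
}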
+// If not specified, defaults to "" +func AsStringFill(value string) AsStringAttr { + return func(m optionalAttr) { + m["fill"] = value + } +} + +// Converts each entry in the given tensor to strings. Supports many numeric +// +// types and boolean. +func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AsString", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. +type ResourceSparseApplyFtrlAttr func(optionalAttr) + +// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// +// That is for rows we have grad for, we update var, accum and linear as follows: +// accum_new = accum + grad * grad +// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. +func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyFtrl", + Input: []tf.Input{ + var_, accum, linear, grad, indices, lr, l1, l2, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// SdcaOptimizerV2Attr is an optional argument to SdcaOptimizerV2. +type SdcaOptimizerV2Attr func(optionalAttr) + +// SdcaOptimizerV2Adaptive sets the optional adaptive attribute to value. +// +// value: Whether to use Adaptive SDCA for the inner loop. +// If not specified, defaults to true +func SdcaOptimizerV2Adaptive(value bool) SdcaOptimizerV2Attr { + return func(m optionalAttr) { + m["adaptive"] = value + } +} + +// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for +// +// linear models with L1 + L2 regularization. As global optimization objective is +// strongly-convex, the optimizer optimizes the dual objective at each step. The +// optimizer applies each update one example at a time. Examples are sampled +// uniformly, and the optimizer is learning rate free and enjoys linear convergence +// rate. 
+// +// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
+// Shai Shalev-Shwartz, Tong Zhang. 2012 +// +// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ +// +// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
+// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, +// Peter Richtarik, Martin Takac. 2015 +// +// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
+// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 +// +// Arguments: +// sparse_example_indices: a list of vectors which contain example indices. +// sparse_feature_indices: a list of vectors which contain feature indices. +// sparse_feature_values: a list of vectors which contains feature value +// associated with each feature group. +// dense_features: a list of matrices which contains the dense feature values. +// example_weights: a vector which contains the weight associated with each +// example. +// example_labels: a vector which contains the label/target associated with each +// example. +// sparse_indices: a list of vectors where each value is the indices which has +// corresponding weights in sparse_weights. This field maybe omitted for the +// dense approach. +// sparse_weights: a list of vectors where each value is the weight associated with +// a sparse feature group. +// dense_weights: a list of vectors where the values are the weights associated +// with a dense feature group. +// example_state_data: a list of vectors containing the example state data. +// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, +// squared and hinge losses. +// l1: Symmetric l1 regularization strength. +// l2: Symmetric l2 regularization strength. +// num_loss_partitions: Number of partitions of the global loss function. +// num_inner_iterations: Number of iterations per mini-batch. +// +// Returns a list of vectors containing the updated example state +// data.a list of vectors where each value is the delta +// weights associated with a sparse feature group.a list of vectors where the values are the delta +// weights associated with a dense feature group. +func SdcaOptimizerV2(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerV2Attr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SdcaOptimizerV2", + Input: []tf.Input{ + tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + out_example_state_data = op.Output(idx) + if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { + scope.UpdateErr("SdcaOptimizerV2", err) + return + } + if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { + scope.UpdateErr("SdcaOptimizerV2", err) + return + } + return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights +} + +// LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr is an optional 
argument to LoadTPUEmbeddingAdagradParametersGradAccumDebug. +type LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingAdagradParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingAdagradParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load Adagrad embedding parameters with debug support. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the Adagrad optimization algorithm. +// accumulators: Value of accumulators used in the Adagrad optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the Adagrad optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingAdagradParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingAdagradParametersGradAccumDebug", + Input: []tf.Input{ + parameters, accumulators, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. +type MaxPool3DGradGradAttr func(optionalAttr) + +// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes second-order gradients of the maxpooling function. +// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. 
the input to `max_pool`. +func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool3DGradGrad", + Input: []tf.Input{ + orig_input, orig_output, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// VarHandleOpAttr is an optional argument to VarHandleOp. +type VarHandleOpAttr func(optionalAttr) + +// VarHandleOpContainer sets the optional container attribute to value. +// +// value: the container this variable is placed in. +// If not specified, defaults to "" +func VarHandleOpContainer(value string) VarHandleOpAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// VarHandleOpSharedName sets the optional shared_name attribute to value. +// +// value: the name by which this variable is referred to. +// If not specified, defaults to "" +func VarHandleOpSharedName(value string) VarHandleOpAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a handle to a Variable resource. +// +// Arguments: +// dtype: the type of this variable. Must agree with the dtypes +// of all ops using this variable. +// shape: The (possibly partially specified) shape of this variable. +func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "VarHandleOp", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StridedSliceAttr is an optional argument to StridedSlice. +type StridedSliceAttr func(optionalAttr) + +// StridedSliceBeginMask sets the optional begin_mask attribute to value. +// +// value: a bitmask where a bit i being 1 means to ignore the begin +// value and instead use the largest interval possible. At runtime +// begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or +// `[-1, n-1]` if `stride[i] < 0` +// If not specified, defaults to 0 +func StridedSliceBeginMask(value int64) StridedSliceAttr { + return func(m optionalAttr) { + m["begin_mask"] = value + } +} + +// StridedSliceEndMask sets the optional end_mask attribute to value. +// +// value: analogous to `begin_mask` +// If not specified, defaults to 0 +func StridedSliceEndMask(value int64) StridedSliceAttr { + return func(m optionalAttr) { + m["end_mask"] = value + } +} + +// StridedSliceEllipsisMask sets the optional ellipsis_mask attribute to value. +// +// value: a bitmask where bit `i` being 1 means the `i`th +// position is actually an ellipsis. One bit at most can be 1. +// If `ellipsis_mask == 0`, then an implicit ellipsis mask of `1 << (m+1)` +// is provided. This means that `foo[3:5] == foo[3:5, ...]`. An ellipsis +// implicitly creates as many range specifications as necessary to fully +// specify the sliced range for every dimension. For example for a 4-dimensional +// tensor `foo` the slice `foo[2, ..., 5:8]` implies `foo[2, :, :, 5:8]`. 
+// If not specified, defaults to 0 +func StridedSliceEllipsisMask(value int64) StridedSliceAttr { + return func(m optionalAttr) { + m["ellipsis_mask"] = value + } +} + +// StridedSliceNewAxisMask sets the optional new_axis_mask attribute to value. +// +// value: a bitmask where bit `i` being 1 means the `i`th +// specification creates a new shape 1 dimension. For example +// `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor. +// If not specified, defaults to 0 +func StridedSliceNewAxisMask(value int64) StridedSliceAttr { + return func(m optionalAttr) { + m["new_axis_mask"] = value + } +} + +// StridedSliceShrinkAxisMask sets the optional shrink_axis_mask attribute to value. +// +// value: a bitmask where bit `i` implies that the `i`th +// specification should shrink the dimensionality. begin and end +// must imply a slice of size 1 in the dimension. For example in +// python one might do `foo[:, 3, :]` which would result in +// `shrink_axis_mask` being 2. +// If not specified, defaults to 0 +func StridedSliceShrinkAxisMask(value int64) StridedSliceAttr { + return func(m optionalAttr) { + m["shrink_axis_mask"] = value + } +} + +// Return a strided slice from `input`. +// +// Note, most python users will want to use the Python `Tensor.__getitem__` +// or `Variable.__getitem__` rather than this op directly. +// +// The goal of this op is to produce a new tensor with a subset of +// the elements from the `n` dimensional `input` tensor. The subset is chosen using +// a sequence of `m` sparse range specifications encoded into the arguments +// of this function. Note, in some cases +// `m` could be equal to `n`, but this need not be the case. Each +// range specification entry can be one of the following: +// +// - An ellipsis (...). Ellipses are used to imply zero or more +// dimensions of full-dimension selection and are produced using +// `ellipsis_mask`. For example, `foo[...]` is the identity slice. +// +// - A new axis. This is used to insert a new shape=1 dimension and is +// produced using `new_axis_mask`. For example, `foo[:, ...]` where +// `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. +// +// +// - A range `begin:end:stride`. This is used to specify how much to choose from +// a given dimension. `stride` can be any integer but 0. `begin` is an integer +// which represents the index of the first value to select while `end` represents +// the index of the last value to select. The number of values selected in each +// dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. +// `begin` and `end` can be negative where `-1` is the last element, `-2` is +// the second to last. `begin_mask` controls whether to replace the explicitly +// given `begin` with an implicit effective value of `0` if `stride > 0` and +// `-1` if `stride < 0`. `end_mask` is analogous but produces the number +// required to create the largest open interval. For example, given a shape +// `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do +// not assume this is equivalent to `foo[0:-1]` which has an effective `begin` +// and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the +// first dimension of a tensor while dropping the last two (in the original +// order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. +// +// - A single index. This is used to keep only elements that have a given +// index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a +// shape `(6,)` tensor. 
This is encoded in `begin` and `end` and +// `shrink_axis_mask`. +// +// Each conceptual range specification is encoded in the op's argument. This +// encoding is best understand by considering a non-trivial example. In +// particular, +// `foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as +// +// ``` +// begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0) +// end = [2, 4, x, x, -3, x] +// strides = [1, 1, x, x, -1, 1] +// begin_mask = 1<<4 | 1 << 5 = 48 +// end_mask = 1<<5 = 32 +// ellipsis_mask = 1<<3 = 8 +// new_axis_mask = 1<<2 4 +// shrink_axis_mask = 1<<0 +// ``` +// +// In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of +// the slice becomes (2, 1, 5, 5, 2, 5). +// Let us walk step by step through each argument specification. +// +// 1. The first argument in the example slice is turned into `begin = 1` and +// `end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we +// also set the appropriate bit in `shrink_axis_mask`. +// +// 2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have +// zero bits contributed. +// +// 3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1 +// dimension in the final shape. Dummy values are contributed to begin, +// end and stride, while the new_axis_mask bit is set. +// +// 4. `...` grab the full ranges from as many dimensions as needed to +// fully specify a slice for every dimension of the input shape. +// +// 5. `:-3:-1` shows the use of negative indices. A negative index `i` associated +// with a dimension that has shape `s` is converted to a positive index +// `s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion +// is done internally so begin, end and strides receive x, -3, and -1. +// The appropriate begin_mask bit is set to indicate the start range is the +// full range (ignoring the x). +// +// 6. `:` indicates that the entire contents of the corresponding dimension +// is selected. This is equivalent to `::` or `0::1`. begin, end, and strides +// receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and +// `end_mask` are also set. +// +// *Requirements*: +// `0 != strides[i] for i in [0, m)` +// `ellipsis_mask must be a power of two (only one ellipsis)` +// +// Arguments: +// +// begin: `begin[k]` specifies the offset into the `k`th range specification. +// The exact dimension this corresponds to will be determined by context. +// Out-of-bounds values will be silently clamped. If the `k`th bit of +// `begin_mask` then `begin[k]` is ignored and the full range of the +// appropriate dimension is used instead. Negative values causes indexing +// to start from the highest element e.g. If `foo==[1,2,3]` then `foo[-1]==3`. +// end: `end[i]` is like `begin` with the exception that `end_mask` is +// used to determine full ranges. +// strides: `strides[i]` specifies the increment in the `i`th specification +// after extracting a given element. Negative indices will reverse +// the original order. 
Out or range values are +// clamped to `[0,dim[i]) if slice[i]>0` or `[-1,dim[i]-1] if slice[i] < 0` +func StridedSlice(scope *Scope, input tf.Output, begin tf.Output, end tf.Output, strides tf.Output, optional ...StridedSliceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StridedSlice", + Input: []tf.Input{ + input, begin, end, strides, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Increments variable pointed to by 'resource' until it reaches 'limit'. +// +// Arguments: +// resource: Should be from a scalar `Variable` node. +// limit: If incrementing ref would bring it above limit, instead generates an +// 'OutOfRange' error. +// +// +// Returns A copy of the input before increment. If nothing else modifies the +// input, the values produced will all be distinct. +func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataType) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"limit": limit, "T": T} + opspec := tf.OpSpec{ + Type: "ResourceCountUpTo", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that batches input elements into a SparseTensor. +// +// Arguments: +// input_dataset: A handle to an input dataset. Must have a single component. +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// row_shape: A vector representing the dense shape of each row in the produced +// SparseTensor. The shape may be partially specified, using `-1` to indicate +// that a particular dimension should use the maximum size of all batch elements. +// +// +func ExperimentalDenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalDenseToSparseBatchDataset", + Input: []tf.Input{ + input_dataset, batch_size, row_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes square root of x element-wise. +// +// I.e., \\(y = \sqrt{x} = x^{1/2}\\). +func Sqrt(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sqrt", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingAdadeltaParametersGradAccumDebug. +type LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableName sets the optional table_name attribute to value. 
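// Sketch of the mask encoding described above: the Go equivalent of the Python
// slice foo[1:, 0:2], using end_mask to request the full range in dimension 0.
// Illustrative only; the expected value in the final comment is unverified.
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    foo := op.Const(s, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}})
    sl := op.StridedSlice(s, foo,
        op.Const(s, []int32{1, 0}), // begin
        op.Const(s, []int32{0, 2}), // end (end[0] ignored via end_mask)
        op.Const(s, []int32{1, 1}), // strides
        op.StridedSliceEndMask(1),  // bit 0 set: take the full range in dimension 0
    )

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{sl}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // expected: [[5 6] [9 10]]
}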
+// If not specified, defaults to "" +func LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load Adadelta parameters with debug support. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the Adadelta optimization algorithm. +// accumulators: Value of accumulators used in the Adadelta optimization algorithm. +// updates: Value of updates used in the Adadelta optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the Adadelta optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingAdadeltaParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, updates tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingAdadeltaParametersGradAccumDebug", + Input: []tf.Input{ + parameters, accumulators, updates, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// RpcAttr is an optional argument to Rpc. +type RpcAttr func(optionalAttr) + +// RpcProtocol sets the optional protocol attribute to value. +// +// value: RPC protocol to use. Empty string means use the default protocol. +// Options include 'grpc'. +// If not specified, defaults to "" +func RpcProtocol(value string) RpcAttr { + return func(m optionalAttr) { + m["protocol"] = value + } +} + +// RpcFailFast sets the optional fail_fast attribute to value. +// +// value: `boolean`. If `true` (default), then failures to connect +// (i.e., the server does not immediately respond) cause an RPC failure. +// If not specified, defaults to true +func RpcFailFast(value bool) RpcAttr { + return func(m optionalAttr) { + m["fail_fast"] = value + } +} + +// RpcTimeoutInMs sets the optional timeout_in_ms attribute to value. +// +// value: `int`. If `0` (default), then the kernel will run the RPC +// request and only time out if the RPC deadline passes or the session times out. +// If this value is greater than `0`, then the op will raise an exception if +// the RPC takes longer than `timeout_in_ms`. +// If not specified, defaults to 0 +func RpcTimeoutInMs(value int64) RpcAttr { + return func(m optionalAttr) { + m["timeout_in_ms"] = value + } +} + +// Perform batches of RPC requests. +// +// This op asynchronously performs either a single RPC request, or a batch +// of requests. RPC requests are defined by three main parameters: +// +// - `address` (the host+port or BNS address of the request) +// - `method` (the RPC method name for the request) +// - `request` (the serialized proto string, or vector of strings, +// of the RPC request argument). 
+// +// For example, if you have an RPC service running on port localhost:2345, +// and its interface is configured with the following proto declaration: +// +// ``` +// service MyService { +// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { +// } +// }; +// ``` +// +// then call this op with arguments: +// +// ``` +// address = "localhost:2345" +// method = "MyService/MyMethod" +// ``` +// +// The `request` tensor is a string tensor representing serialized `MyRequestProto` +// strings; and the output string tensor `response` will have the same shape +// and contain (upon successful completion) corresponding serialized +// `MyResponseProto` strings. +// +// For example, to send a single, empty, `MyRequestProto`, call +// this op with `request = ""`. To send 5 **parallel** empty requests, +// call this op with `request = ["", "", "", "", ""]`. +// +// More generally, one can create a batch of `MyRequestProto` serialized protos +// from regular batched tensors using the `encode_proto` op, and convert +// the response `MyResponseProto` serialized protos to batched tensors +// using the `decode_proto` op. +// +// **NOTE** Working with serialized proto strings is faster than instantiating +// actual proto objects in memory, so no performance degradation is expected +// compared to writing custom kernels for this workflow. +// +// If the connection fails or the remote worker returns an error +// status, the op reraises this exception locally. +// +// See the `TryRpc` op if you prefer to handle RPC failures manually in the graph. +// +// Arguments: +// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `method` and `request`. +// method: `0-D` or `1-D`. The method address on the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `request`. +// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `method`. +// +// Returns Same shape as `request`. Serialized proto strings: the rpc responses. +func Rpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...RpcAttr) (response tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Rpc", + Input: []tf.Input{ + address, method, request, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Rounds the values of a tensor to the nearest integer, element-wise. +// +// Rounds half to even. Also known as bankers rounding. If you want to round +// according to the current system rounding mode use std::cint. +func Round(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Round", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Records the latency of producing `input_dataset` elements in a StatsAggregator. 
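// Sketch of Round's half-to-even ("bankers") behaviour (illustrative).
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    x := op.Const(s, []float32{0.5, 1.5, 2.5, -2.5, 2.7})
    y := op.Round(s, x)

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{y}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // expected (half to even): [0 2 2 -2 3]
}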
+func ExperimentalLatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalLatencyStatsDataset", + Input: []tf.Input{ + input_dataset, tag, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyKerasMomentumAttr is an optional argument to ResourceSparseApplyKerasMomentum. +type ResourceSparseApplyKerasMomentumAttr func(optionalAttr) + +// ResourceSparseApplyKerasMomentumUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyKerasMomentumUseLocking(value bool) ResourceSparseApplyKerasMomentumAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceSparseApplyKerasMomentumUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, the tensor passed to compute grad will be +// var + momentum * accum, so in the end, the var you get is actually +// var + momentum * accum. +// If not specified, defaults to false +func ResourceSparseApplyKerasMomentumUseNesterov(value bool) ResourceSparseApplyKerasMomentumAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the momentum scheme. +// +// Set use_nesterov = True if you want to use Nesterov momentum. +// +// That is for rows we have grad for, we update var and accum as follows: +// +// accum = accum * momentum - lr * grad +// var += accum +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// momentum: Momentum. Must be a scalar. +// +// Returns the created operation. +func ResourceSparseApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyKerasMomentumAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyKerasMomentum", + Input: []tf.Input{ + var_, accum, lr, grad, indices, momentum, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// MapUnstageNoKeyAttr is an optional argument to MapUnstageNoKey. +type MapUnstageNoKeyAttr func(optionalAttr) + +// MapUnstageNoKeyCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapUnstageNoKeyCapacity(value int64) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapUnstageNoKeyMemoryLimit(value int64) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapUnstageNoKeyContainer sets the optional container attribute to value. 
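// Sketch wiring ResourceSparseApplyKerasMomentum against two resource variables
// (illustrative; assumes ReadVariableOp as in the earlier variable sketch, and that
// the init ops are run before the update). The expected result follows the update
// rule quoted above: accum = accum*momentum - lr*grad; var += accum.
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    shape := tf.MakeShape(4)
    v := op.VarHandleOp(s, tf.Float, shape, op.VarHandleOpSharedName("var"))
    accum := op.VarHandleOp(s, tf.Float, shape, op.VarHandleOpSharedName("accum"))
    initV := op.AssignVariableOp(s, v, op.Const(s, []float32{1, 1, 1, 1}))
    initA := op.AssignVariableOp(s, accum, op.Const(s, []float32{0, 0, 0, 0}))
    // Sparse momentum step touching rows 0 and 2 only.
    apply := op.ResourceSparseApplyKerasMomentum(s, v, accum,
        op.Const(s, float32(0.1)),         // lr
        op.Const(s, []float32{0.5, 0.25}), // grad for the selected rows
        op.Const(s, []int32{0, 2}),        // indices into dimension 0 of var
        op.Const(s, float32(0.9)),         // momentum
        op.ResourceSparseApplyKerasMomentumUseNesterov(false),
    )
    readV := op.ReadVariableOp(s, v, tf.Float)

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    for _, target := range []*tf.Operation{initV, initA, apply} {
        if _, err := sess.Run(nil, nil, []*tf.Operation{target}); err != nil {
            panic(err)
        }
    }
    out, err := sess.Run(nil, []tf.Output{readV}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // expected (approximately): [0.95 1 0.975 1]
}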
+// If not specified, defaults to "" +func MapUnstageNoKeyContainer(value string) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapUnstageNoKeySharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapUnstageNoKeySharedName(value string) MapUnstageNoKeyAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns a random (key, value) +// +// from the underlying container. If the underlying container +// does not contain elements, the op will block until it does. +func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MapUnstageNoKey", + Input: []tf.Input{ + indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + key = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapUnstageNoKey", err) + return + } + return key, values +} + +// Outputs deterministic pseudorandom random integers from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[minval, maxval)`. +// +// The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// minval: Minimum value (inclusive, scalar). +// maxval: Maximum value (exclusive, scalar). +// +// Returns Random values with specified shape. +func StatelessRandomUniformInt(scope *Scope, shape tf.Output, seed tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StatelessRandomUniformInt", + Input: []tf.Input{ + shape, seed, minval, maxval, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingAdadeltaParametersAttr is an optional argument to RetrieveTPUEmbeddingAdadeltaParameters. +type RetrieveTPUEmbeddingAdadeltaParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingAdadeltaParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingAdadeltaParametersTableId(value int64) RetrieveTPUEmbeddingAdadeltaParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingAdadeltaParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingAdadeltaParametersTableName(value string) RetrieveTPUEmbeddingAdadeltaParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve Adadelta embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the Adadelta optimization algorithm.Parameter accumulators updated by the Adadelta optimization algorithm.Parameter updates updated by the Adadelta optimization algorithm. 
+func RetrieveTPUEmbeddingAdadeltaParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdadeltaParametersAttr) (parameters tf.Output, accumulators tf.Output, updates tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingAdadeltaParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// ImagAttr is an optional argument to Imag. +type ImagAttr func(optionalAttr) + +// ImagTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func ImagTout(value tf.DataType) ImagAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Returns the imaginary part of a complex number. +// +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// type `float` that is the imaginary part of each element in `input`. All +// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* +// is the real part and *b* is the imaginary part returned by this operation. +// +// For example: +// +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.imag(input) ==> [4.75, 5.75] +// ``` +func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Imag", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Concatenates tensors along one dimension. +// +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Concat", + Input: []tf.Input{ + concat_dim, tf.OutputList(values), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// UnicodeEncodeAttr is an optional argument to UnicodeEncode. +type UnicodeEncodeAttr func(optionalAttr) + +// UnicodeEncodeErrors sets the optional errors attribute to value. +// +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeEncodeErrors(value string) UnicodeEncodeAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeEncodeReplacementChar sets the optional replacement_char attribute to value. 
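// Sketch of Concat along dimension 0 (illustrative; concat_dim is an int32 scalar,
// and all values must match in every other dimension).
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    a := op.Const(s, [][]float32{{1, 2}, {3, 4}})
    b := op.Const(s, [][]float32{{5, 6}})
    c := op.Concat(s, op.Const(s, int32(0)), []tf.Output{a, b})

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{c}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // expected: [[1 2] [3 4] [5 6]]
}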
+// +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character is +// 0xFFFD (U+65533). +// If not specified, defaults to 65533 +func UnicodeEncodeReplacementChar(value int64) UnicodeEncodeAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// Encode a tensor of ints into unicode strings. +// +// Returns a vector of strings, where `output[i]` is constructed by encoding the +// Unicode codepoints in `input_values[input_splits[i]:input_splits[i+1]]` +// using `output_encoding`. +// +// --- +// +// Example: +// +// ``` +// input_values = [72, 101, 108, 108, 111, 87, 111, 114, 108, 100] +// input_splits = [0, 5, 10] +// output_encoding = 'UTF-8' +// +// output = ['Hello', 'World'] +// ``` +// +// Arguments: +// input_values: A 1D tensor containing the unicode codepoints that should be encoded. +// input_splits: A 1D tensor specifying how the unicode codepoints should be split into strings. +// In particular, `output[i]` is constructed by encoding the codepoints in the +// slice `input_values[input_splits[i]:input_splits[i+1]]`. +// output_encoding: Unicode encoding of the output strings. Valid encodings are: `"UTF-8", +// "UTF-16-BE", and "UTF-32-BE"`. +// +// Returns The 1-D Tensor of strings encoded from the provided unicode codepoints. +func UnicodeEncode(scope *Scope, input_values tf.Output, input_splits tf.Output, output_encoding string, optional ...UnicodeEncodeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_encoding": output_encoding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UnicodeEncode", + Input: []tf.Input{ + input_values, input_splits, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingADAMParametersGradAccumDebug. +type RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve ADAM embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the ADAM optimization algorithm.Parameter momenta updated by the ADAM optimization algorithm.Parameter velocities updated by the ADAM optimization algorithm.Parameter gradient_accumulators updated by the ADAM optimization algorithm. 
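// Sketch of UnicodeEncode reproducing the example above: the codepoints plus the
// splits [0, 5, 10] encode to two UTF-8 strings (illustrative; input_values is
// assumed to be int32 and input_splits int64, the default Tsplits).
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    values := op.Const(s, []int32{72, 101, 108, 108, 111, 87, 111, 114, 108, 100})
    splits := op.Const(s, []int64{0, 5, 10})
    encoded := op.UnicodeEncode(s, values, splits, "UTF-8")

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{encoded}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // expected: [Hello World]
}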
+func RetrieveTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// Returns 0 if the denominator is zero. +// +// +// *NOTE*: `DivNoNan` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func DivNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DivNoNan", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// UnicodeDecodeWithOffsetsAttr is an optional argument to UnicodeDecodeWithOffsets. +type UnicodeDecodeWithOffsetsAttr func(optionalAttr) + +// UnicodeDecodeWithOffsetsErrors sets the optional errors attribute to value. +// +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeDecodeWithOffsetsErrors(value string) UnicodeDecodeWithOffsetsAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeDecodeWithOffsetsReplacementChar sets the optional replacement_char attribute to value. +// +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character is +// 0xFFFD or U+65533.) +// If not specified, defaults to 65533 +func UnicodeDecodeWithOffsetsReplacementChar(value int64) UnicodeDecodeWithOffsetsAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// UnicodeDecodeWithOffsetsReplaceControlCharacters sets the optional replace_control_characters attribute to value. +// +// value: Whether to replace the C0 control characters (00-1F) with the +// `replacement_char`. Default is false. +// If not specified, defaults to false +func UnicodeDecodeWithOffsetsReplaceControlCharacters(value bool) UnicodeDecodeWithOffsetsAttr { + return func(m optionalAttr) { + m["replace_control_characters"] = value + } +} + +// UnicodeDecodeWithOffsetsTsplits sets the optional Tsplits attribute to value. +// If not specified, defaults to DT_INT64 +func UnicodeDecodeWithOffsetsTsplits(value tf.DataType) UnicodeDecodeWithOffsetsAttr { + return func(m optionalAttr) { + m["Tsplits"] = value + } +} + +// Decodes each string in `input` into a sequence of Unicode code points. 
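// Sketch of DivNoNan (illustrative): element-wise division that yields 0 wherever
// the denominator is 0.
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    x := op.Const(s, []float32{3, 4, 5})
    y := op.Const(s, []float32{2, 0, 0.5})
    z := op.DivNoNan(s, x, y)

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{z}, nil)
    if err != nil {
        panic(err)
    }
    fmt.Println(out[0].Value()) // expected: [1.5 0 10]
}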
+// +// The character codepoints for all strings are returned using a single vector +// `char_values`, with strings expanded to characters in row-major order. +// Similarly, the character start byte offsets are returned using a single vector +// `char_to_byte_starts`, with strings expanded in row-major order. +// +// The `row_splits` tensor indicates where the codepoints and start offsets for +// each input string begin and end within the `char_values` and +// `char_to_byte_starts` tensors. In particular, the values for the `i`th +// string (in row-major order) are stored in the slice +// `[row_splits[i]:row_splits[i+1]]`. Thus: +// +// * `char_values[row_splits[i]+j]` is the Unicode codepoint for the `j`th +// character in the `i`th string (in row-major order). +// * `char_to_bytes_starts[row_splits[i]+j]` is the start byte offset for the `j`th +// character in the `i`th string (in row-major order). +// * `row_splits[i+1] - row_splits[i]` is the number of characters in the `i`th +// string (in row-major order). +// +// Arguments: +// input: The text to be decoded. Can have any shape. Note that the output is flattened +// to a vector of char values. +// input_encoding: Text encoding of the input strings. This is any of the encodings supported +// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. +// +// Returns A 1D int32 tensor containing the row splits.A 1D int32 Tensor containing the decoded codepoints.A 1D int32 Tensor containing the byte index in the input string where each +// character in `char_values` starts. +func UnicodeDecodeWithOffsets(scope *Scope, input tf.Output, input_encoding string, optional ...UnicodeDecodeWithOffsetsAttr) (row_splits tf.Output, char_values tf.Output, char_to_byte_starts tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"input_encoding": input_encoding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UnicodeDecodeWithOffsets", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes the reciprocal of x element-wise. +// +// I.e., \\(y = 1 / x\\). +func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Reciprocal", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. +type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. +// If not specified, defaults to -6 +func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["min"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. +// If not specified, defaults to 6 +func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["max"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. 
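// Sketch of UnicodeDecodeWithOffsets on two strings (illustrative; the commented
// values are indicative only: 'é' occupies two bytes in UTF-8, so the byte starts
// skip an index).
package main

import (
    "fmt"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
    s := op.NewScope()
    in := op.Const(s, []string{"héllo", "→"})
    rowSplits, charValues, byteStarts := op.UnicodeDecodeWithOffsets(s, in, "UTF-8")

    graph, err := s.Finalize()
    if err != nil {
        panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
        panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{rowSplits, charValues, byteStarts}, nil)
    if err != nil {
        panic(err)
    }
    // roughly: row splits [0 5 6]; codepoints for h é l l o →; byte starts [0 1 3 4 5 0]
    fmt.Println(out[0].Value(), out[1].Value(), out[2].Value())
}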
+// If not specified, defaults to false +func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxArgs operation. +// +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. +// +// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: +// `gradients * (inputs >= min && inputs <= max)`. +func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxArgsGradient", + Input: []tf.Input{ + gradients, inputs, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the truth value of NOT x element-wise. +func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LogicalNot", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Produces the max pool of the input tensor for quantized types. +// +// Arguments: +// input: The 4D (batch x rows x cols x depth) Tensor to MaxReduce over. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// ksize: The size of the window for each dimension of the input tensor. +// The length must be 4 to match the number of dimensions of the input. +// strides: The stride of the sliding window for each dimension of the input +// tensor. The length must be 4 to match the number of dimensions of the input. +// padding: The type of padding algorithm to use. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedMaxPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + opspec := tf.OpSpec{ + Type: "QuantizedMaxPool", + Input: []tf.Input{ + input, min_input, max_input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingMomentumParametersGradAccumDebug. +type RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName sets the optional table_name attribute to value. 
+// If not specified, defaults to "" +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve Momentum embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the Momentum optimization algorithm.Parameter momenta updated by the Momentum optimization algorithm.Parameter gradient_accumulators updated by the Momentum optimization algorithm. +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// GatherAttr is an optional argument to Gather. +type GatherAttr func(optionalAttr) + +// GatherValidateIndices sets the optional validate_indices attribute to value. +// If not specified, defaults to true +func GatherValidateIndices(value bool) GatherAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Gather slices from `params` according to `indices`. +// +// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). +// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: +// +// ```python +// # Scalar indices +// output[:, ..., :] = params[indices, :, ... :] +// +// # Vector indices +// output[i, :, ..., :] = params[indices[i], :, ... :] +// +// # Higher rank indices +// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] +// ``` +// +// If `indices` is a permutation and `len(indices) == params.shape[0]` then +// this operation will permute `params` accordingly. +// +// `validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in +// `indices` are always validated to be within range. If assigned to GPU, +// out-of-bound indices result in safe but unspecified behavior, which may include +// raising an error. +// +//
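+// A minimal usage sketch, assuming a Scope `s` created with NewScope() and
+// illustrative constant inputs:
+//
+//     params := Const(s, [][]float32{{1, 2}, {3, 4}, {5, 6}})
+//     indices := Const(s, []int32{2, 0})
+//     rows := Gather(s, params, indices) // gathers rows 2 and 0 of params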
+func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...GatherAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Gather", + Input: []tf.Input{ + params, indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SparseReduceMaxAttr is an optional argument to SparseReduceMax. +type SparseReduceMaxAttr func(optionalAttr) + +// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the max of elements across dimensions of a SparseTensor. +// +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// +// Returns `R-K`-D. The reduced Tensor. +func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseReduceMax", + Input: []tf.Input{ + input_indices, input_values, input_shape, reduction_axes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. +type StatelessTruncatedNormalAttr func(optionalAttr) + +// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom values from a truncated normal distribution. +// +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. +// +// The outputs are a deterministic function of `shape` and `seed`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. 
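+//
+// A minimal usage sketch, assuming a Scope `s` created with NewScope(); the
+// shape, seed, and dtype values below are illustrative:
+//
+//     shape := Const(s, []int32{2, 3})
+//     seed := Const(s, []int64{1, 2})
+//     noise := StatelessTruncatedNormal(s, shape, seed,
+//         StatelessTruncatedNormalDtype(tf.Float))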
+func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessTruncatedNormal", + Input: []tf.Input{ + shape, seed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Concatenates a list of `SparseTensor` along the specified dimension. +// +// Concatenation is with respect to the dense versions of these sparse tensors. +// It is assumed that each input is a `SparseTensor` whose elements are ordered +// along increasing dimension number. +// +// All inputs' shapes must match, except for the concat dimension. The +// `indices`, `values`, and `shapes` lists must have the same length. +// +// The output shape is identical to the inputs', except along the concat +// dimension, where it is the sum of the inputs' sizes along that dimension. +// +// The output elements will be resorted to preserve the sort order along +// increasing dimension number. +// +// This op runs in `O(M log M)` time, where `M` is the total number of non-empty +// values across all inputs. This is due to the need for an internal sort in +// order to concatenate efficiently across an arbitrary dimension. +// +// For example, if `concat_dim = 1` and the inputs are +// +// sp_inputs[0]: shape = [2, 3] +// [0, 2]: "a" +// [1, 0]: "b" +// [1, 1]: "c" +// +// sp_inputs[1]: shape = [2, 4] +// [0, 1]: "d" +// [0, 2]: "e" +// +// then the output will be +// +// shape = [2, 7] +// [0, 2]: "a" +// [0, 4]: "d" +// [0, 5]: "e" +// [1, 0]: "b" +// [1, 1]: "c" +// +// Graphically this is equivalent to doing +// +// [ a] concat [ d e ] = [ a d e ] +// [b c ] [ ] [b c ] +// +// Arguments: +// indices: 2-D. Indices of each input `SparseTensor`. +// values: 1-D. Non-empty values of each `SparseTensor`. +// shapes: 1-D. Shapes of each `SparseTensor`. +// concat_dim: Dimension to concatenate along. Must be in range [-rank, rank), +// where rank is the number of dimensions in each input `SparseTensor`. +// +// Returns 2-D. Indices of the concatenated `SparseTensor`.1-D. Non-empty values of the concatenated `SparseTensor`.1-D. Shape of the concatenated `SparseTensor`. +func SparseConcat(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, concat_dim int64) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"concat_dim": concat_dim} + opspec := tf.OpSpec{ + Type: "SparseConcat", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Performs gradient updates of embedding tables. +// +// Arguments: +// inputs: A TensorList of gradients with which to update embedding tables. +// This argument has the same length and shapes as the return value of +// RecvTPUEmbeddingActivations, but contains gradients of the model's loss +// with respect to the embedding activations. The embedding tables are updated +// from these gradients via the optimizer specified in the TPU embedding +// configuration given to tpu.initialize_system. 
+// learning_rates: A TensorList of float32 scalars, one for each dynamic learning +// rate tag: see the comments in +// //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto. +// Multiple tables can share the same dynamic learning rate tag as specified +// in the configuration. If the learning rates for all tables are constant, +// this list should be empty. +// config: Serialized TPUEmbeddingConfiguration proto. +// +// Returns the created operation. +func SendTPUEmbeddingGradients(scope *Scope, inputs []tf.Output, learning_rates []tf.Output, config string) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"config": config} + opspec := tf.OpSpec{ + Type: "SendTPUEmbeddingGradients", + Input: []tf.Input{ + tf.OutputList(inputs), tf.OutputList(learning_rates), + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Returns a list of tensors with the same shapes and contents as the input +// +// tensors. +// +// This op can be used to override the gradient for complicated functions. For +// example, suppose y = f(x) and we wish to apply a custom function g for backprop +// such that dx = g(dy). In Python, +// +// ```python +// with tf.get_default_graph().gradient_override_map( +// {'IdentityN': 'OverrideGradientWithG'}): +// y, _ = identity_n([f(x), x]) +// +// @tf.RegisterGradient('OverrideGradientWithG') +// def ApplyG(op, dy, _): +// return [None, g(dy)] # Do not backprop to f(x). +// ``` +func IdentityN(scope *Scope, input []tf.Output) (output []tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IdentityN", + Input: []tf.Input{ + tf.OutputList(input), + }, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("IdentityN", err) + return + } + return output +} + +// RetrieveTPUEmbeddingProximalAdagradParametersAttr is an optional argument to RetrieveTPUEmbeddingProximalAdagradParameters. +type RetrieveTPUEmbeddingProximalAdagradParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingProximalAdagradParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingProximalAdagradParametersTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingProximalAdagradParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingProximalAdagradParametersTableName(value string) RetrieveTPUEmbeddingProximalAdagradParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve proximal Adagrad embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the proximal Adagrad optimization algorithm.Parameter accumulators updated by the proximal Adagrad optimization algorithm. 
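+//
+// A hedged usage sketch, assuming a Scope `s`, a TPU embedding system already
+// set up by ConfigureTPUEmbeddingHost, and illustrative shard and table values:
+//
+//     params, accums := RetrieveTPUEmbeddingProximalAdagradParameters(s, 4, 0,
+//         RetrieveTPUEmbeddingProximalAdagradParametersTableName("table_0"))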
+func RetrieveTPUEmbeddingProximalAdagradParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingProximalAdagradParametersAttr) (parameters tf.Output, accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingProximalAdagradParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// List of the given size with empty elements. +// +// element_shape: the shape of the future elements of the list +// num_elements: the number of elements to reserve +// handle: the output list +// element_dtype: the desired type of elements in the list. +func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + opspec := tf.OpSpec{ + Type: "TensorListReserve", + Input: []tf.Input{ + element_shape, num_elements, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ExperimentalStatsAggregatorHandleAttr is an optional argument to ExperimentalStatsAggregatorHandle. +type ExperimentalStatsAggregatorHandleAttr func(optionalAttr) + +// ExperimentalStatsAggregatorHandleContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func ExperimentalStatsAggregatorHandleContainer(value string) ExperimentalStatsAggregatorHandleAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// ExperimentalStatsAggregatorHandleSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func ExperimentalStatsAggregatorHandleSharedName(value string) ExperimentalStatsAggregatorHandleAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a statistics manager resource. +func ExperimentalStatsAggregatorHandle(scope *Scope, optional ...ExperimentalStatsAggregatorHandleAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ExperimentalStatsAggregatorHandle", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FusedBatchNormAttr is an optional argument to FusedBatchNorm. +type FusedBatchNormAttr func(optionalAttr) + +// FusedBatchNormEpsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { + return func(m optionalAttr) { + m["epsilon"] = value + } +} + +// FusedBatchNormDataFormat sets the optional data_format attribute to value. +// +// value: The data format for x and y. Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormIsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. 
+// If not specified, defaults to true +func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// Batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// offset: A 1D Tensor for offset, to shift to the normalized x. +// mean: A 1D Tensor for population mean. Used for inference only; +// must be empty for training. +// variance: A 1D Tensor for population variance. Used for inference only; +// must be empty for training. +// +// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow +// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by +// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused +// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance +// in the cuDNN case), to be reused in the gradient computation. +func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedBatchNorm", + Input: []tf.Input{ + x, scale, offset, mean, variance, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + +// TopKAttr is an optional argument to TopK. +type TopKAttr func(optionalAttr) + +// TopKSorted sets the optional sorted attribute to value. +// +// value: If true the resulting `k` elements will be sorted by the values in +// descending order. +// If not specified, defaults to true +func TopKSorted(value bool) TopKAttr { + return func(m optionalAttr) { + m["sorted"] = value + } +} + +// Finds values and indices of the `k` largest elements for the last dimension. +// +// DEPRECATED at GraphDef version 7: Use TopKV2 instead +// +// If the input is a vector (rank-1), finds the `k` largest entries in the vector +// and outputs their values and indices as vectors. Thus `values[j]` is the +// `j`-th largest entry in `input`, and its index is `indices[j]`. +// +// For matrices (resp. higher rank input), computes the top `k` entries in each +// row (resp. vector along the last dimension). Thus, +// +// values.shape = indices.shape = input.shape[:-1] + [k] +// +// If two elements are equal, the lower-index element appears first. +// +// If `k` varies dynamically, use `TopKV2` below. +// +// Arguments: +// input: 1-D or higher with last dimension at least `k`. +// k: Number of top elements to look for along the last dimension (along each +// row for matrices). +// +// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. 
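+//
+// A minimal usage sketch, assuming a Scope `s` created with NewScope() and an
+// illustrative input vector:
+//
+//     input := Const(s, []float32{1, 5, 3, 4})
+//     values, indices := TopK(s, input, 2) // values = [5 4], indices = [1 3]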
+func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values tf.Output, indices tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"k": k} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TopK", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// RaggedRangeAttr is an optional argument to RaggedRange. +type RaggedRangeAttr func(optionalAttr) + +// RaggedRangeTsplits sets the optional Tsplits attribute to value. +// If not specified, defaults to DT_INT64 +func RaggedRangeTsplits(value tf.DataType) RaggedRangeAttr { + return func(m optionalAttr) { + m["Tsplits"] = value + } +} + +// Returns a `RaggedTensor` containing the specified sequences of numbers. +// +// +// Returns a `RaggedTensor` `result` composed from `rt_dense_values` and +// `rt_nested_splits`, such that +// `result[i] = range(starts[i], limits[i], deltas[i])`. +// +// ```python +// >>> (rt_nested_splits, rt_dense_values) = gen_ragged_ops.ragged_range( +// ... starts=[2, 5, 8], limits=[3, 5, 12], deltas=1) +// >>> result = ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) +// >>> print result.eval().tolist() +// [[2], # result[0] = range(2, 3) +// [], # result[1] = range(5, 5) +// [8, 9, 10, 11]] # result[2] = range(8, 12) +// ``` +// +// The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors. +// The vector inputs must all have the same size. Scalar inputs are broadcast +// to match the size of the vector inputs. +// +// Arguments: +// starts: The starts of each range. +// limits: The limits of each range. +// deltas: The deltas of each range. +// +// Returns The `row_splits` for the returned `RaggedTensor`.The `flat_values` for the returned `RaggedTensor`. +func RaggedRange(scope *Scope, starts tf.Output, limits tf.Output, deltas tf.Output, optional ...RaggedRangeAttr) (rt_nested_splits tf.Output, rt_dense_values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RaggedRange", + Input: []tf.Input{ + starts, limits, deltas, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// RetrieveTPUEmbeddingCenteredRMSPropParametersAttr is an optional argument to RetrieveTPUEmbeddingCenteredRMSPropParameters. +type RetrieveTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingCenteredRMSPropParametersTableId(value int64) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingCenteredRMSPropParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingCenteredRMSPropParametersTableName(value string) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve centered RMSProp embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. 
For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the centered RMSProp optimization algorithm.Parameter ms updated by the centered RMSProp optimization algorithm.Parameter mom updated by the centered RMSProp optimization algorithm.Parameter mg updated by the centered RMSProp optimization algorithm. +func RetrieveTPUEmbeddingCenteredRMSPropParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingCenteredRMSPropParametersAttr) (parameters tf.Output, ms tf.Output, mom tf.Output, mg tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingCenteredRMSPropParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug. +type RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve proximal Adagrad embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the proximal Adagrad optimization algorithm.Parameter accumulators updated by the proximal Adagrad optimization algorithm.Parameter gradient_accumulators updated by the proximal Adagrad optimization algorithm. +func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Adds a value to the current value of a variable. +// +// Any ReadVariableOp with a control dependency on this op is guaranteed to +// see the incremented value or a subsequent newer one. 
+// +// Arguments: +// resource: handle to the resource in which to store the variable. +// value: the value by which the variable will be incremented. +// +// Returns the created operation. +func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AssignAddVariableOp", + Input: []tf.Input{ + resource, value, + }, + } + return scope.AddOperation(opspec) +} + +// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingAdagradParametersGradAccumDebug. +type RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingAdagradParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve Adagrad embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the Adagrad optimization algorithm.Parameter accumulators updated by the Adagrad optimization algorithm.Parameter gradient_accumulators updated by the Adagrad optimization algorithm. +func RetrieveTPUEmbeddingAdagradParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdagradParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingAdagradParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes inverse hyperbolic tangent of x element-wise. +func Atanh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Atanh", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the gradient for the inverse of `x` wrt its input. +// +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InvGrad", + Input: []tf.Input{ + y, dy, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compute the pairwise cross product. +// +// `a` and `b` must be the same shape; they can either be simple 3-element vectors, +// or any shape where the innermost dimension is 3. 
In the latter case, each pair +// of corresponding 3-element vectors is cross-multiplied independently. +// +// Arguments: +// a: A tensor containing 3-element vectors. +// b: Another tensor, of same type and shape as `a`. +// +// Returns Pairwise cross product of the vectors in `a` and `b`. +func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cross", + Input: []tf.Input{ + a, b, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes rectified linear gradients for a Relu operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Relu operation. +// features: The features passed as input to the corresponding Relu operation, OR +// the outputs of that operation (both work equivalently). +// +// Returns `gradients * (features > 0)`. +func ReluGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReluGrad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Check if the input matches the regex pattern. +// +// The input is a string tensor of any shape. The pattern is a scalar +// string tensor which is applied to every element of the input tensor. +// The boolean values (True or False) of the output tensor indicate +// if the input matches the regex pattern provided. +// +// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// +// Arguments: +// input: A string tensor of the text to be processed. +// pattern: A scalar string tensor containing the regular expression to match the input. +// +// Returns A bool tensor with the same shape as `input`. +func RegexFullMatch(scope *Scope, input tf.Output, pattern tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RegexFullMatch", + Input: []tf.Input{ + input, pattern, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that contains the unique elements of `input_dataset`. +func ExperimentalUniqueDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalUniqueDataset", + Input: []tf.Input{ + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodePngAttr is an optional argument to DecodePng. +type DecodePngAttr func(optionalAttr) + +// DecodePngChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodePngChannels(value int64) DecodePngAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodePngDtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_UINT8 +func DecodePngDtype(value tf.DataType) DecodePngAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Decode a PNG-encoded image to a uint8 or uint16 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the PNG-encoded image. +// * 1: output a grayscale image. 
+// * 3: output an RGB image. +// * 4: output an RGBA image. +// +// If needed, the PNG-encoded image is transformed to match the requested number +// of color channels. +// +// This op also supports decoding JPEGs and non-animated GIFs since the interface +// is the same, though it is cleaner to use `tf.image.decode_image`. +// +// Arguments: +// contents: 0-D. The PNG-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`. +func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodePng", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that contains `rate` elements from the `input_dataset`. +// +// Arguments: +// +// rate: A scalar representing the sample rate of elements from the `input_dataset` +// that should be taken. +// seed: A scalar representing seed of random number generator. +// seed2: A scalar representing seed2 of random number generator. +// +// +func SamplingDataset(scope *Scope, input_dataset tf.Output, rate tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "SamplingDataset", + Input: []tf.Input{ + input_dataset, rate, seed, seed2, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Scatter `updates` into an existing tensor according to `indices`. +// +// This operation creates a new tensor by applying sparse `updates` to the passed +// in `tensor`. +// This operation is very similar to `tf.scatter_nd`, except that the updates are +// scattered onto an existing tensor (as opposed to a zero-tensor). If the memory +// for the existing tensor cannot be re-used, a copy is made and updated. +// +// If `indices` contains duplicates, then their updates are accumulated (summed). +// +// **WARNING**: The order in which updates are applied is nondeterministic, so the +// output will be nondeterministic if `indices` contains duplicates -- because +// of some numerical approximation issues, numbers summed in different order +// may yield different results. +// +// `indices` is an integer tensor containing indices into a new tensor of shape +// `shape`. The last dimension of `indices` can be at most the rank of `shape`: +// +// indices.shape[-1] <= shape.rank +// +// The last dimension of `indices` corresponds to indices into elements +// (if `indices.shape[-1] = shape.rank`) or slices +// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of +// `shape`. `updates` is a tensor with shape +// +// indices.shape[:-1] + shape[indices.shape[-1]:] +// +// The simplest form of scatter is to insert individual elements in a tensor by +// index. For example, say we want to insert 4 scattered elements in a rank-1 +// tensor with 8 elements. +// +//
+// +// In Python, this scatter operation would look like this: +// +// ```python +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// tensor = tf.ones([8], dtype=tf.int32) +// updated = tf.tensor_scatter_update(tensor, indices, updates) +// with tf.Session() as sess: +// print(sess.run(scatter)) +// ``` +// +// The resulting tensor would look like this: +// +// [1, 11, 1, 10, 9, 1, 1, 12] +// +// We can also, insert entire slices of a higher rank tensor all at once. For +// example, if we wanted to insert two slices in the first dimension of a +// rank-3 tensor with two matrices of new values. +// +// In Python, this scatter operation would look like this: +// +// ```python +// indices = tf.constant([[0], [2]]) +// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], +// [7, 7, 7, 7], [8, 8, 8, 8]], +// [[5, 5, 5, 5], [6, 6, 6, 6], +// [7, 7, 7, 7], [8, 8, 8, 8]]]) +// tensor = tf.ones([4, 4, 4]) +// updated = tf.tensor_scatter_update(tensor, indices, updates) +// with tf.Session() as sess: +// print(sess.run(scatter)) +// ``` +// +// The resulting tensor would look like this: +// +// [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], +// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], +// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], +// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]] +// +// Note that on CPU, if an out of bound index is found, an error is returned. +// On GPU, if an out of bound index is found, the index is ignored. +// +// Arguments: +// tensor: Tensor to copy/update. +// indices: Index tensor. +// updates: Updates to scatter into output. +// +// Returns A new tensor with the given shape and updates applied according +// to the indices. +func TensorScatterUpdate(scope *Scope, tensor tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorScatterUpdate", + Input: []tf.Input{ + tensor, indices, updates, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x // y element-wise. +// +// *NOTE*: `FloorDiv` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FloorDiv", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Shuts down a running distributed TPU system. +// +// The op returns an error if no system is running. +// +// Returns the created operation. +func ShutdownDistributedTPU(scope *Scope) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ShutdownDistributedTPU", + } + return scope.AddOperation(opspec) +} + +// QuantizedDepthwiseConv2DWithBiasAttr is an optional argument to QuantizedDepthwiseConv2DWithBias. +type QuantizedDepthwiseConv2DWithBiasAttr func(optionalAttr) + +// QuantizedDepthwiseConv2DWithBiasOutType sets the optional out_type attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_QINT32 +func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwiseConv2DWithBiasAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value. +// +// value: List of dilation values. 
+// If not specified, defaults to +func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes quantized depthwise Conv2D with Bias. +// +// Arguments: +// input: The original input tensor. +// filter: The original filter tensor. +// bias: The original bias tensor. +// min_input: The float value that the minimum quantized input value represents. +// max_input: The float value that the maximum quantized input value represents. +// min_filter: The float value that the minimum quantized filter value represents. +// max_filter: The float value that the maximum quantized filter value represents. +// strides: List of stride values. +// +// +// Returns The output tensor.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizedDepthwiseConv2DWithBias(scope *Scope, input tf.Output, filter tf.Output, bias tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedDepthwiseConv2DWithBiasAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedDepthwiseConv2DWithBias", + Input: []tf.Input{ + input, filter, bias, min_input, max_input, min_filter, max_filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// StaticRegexReplaceAttr is an optional argument to StaticRegexReplace. +type StaticRegexReplaceAttr func(optionalAttr) + +// StaticRegexReplaceReplaceGlobal sets the optional replace_global attribute to value. +// +// value: If True, the replacement is global, otherwise the replacement +// is done only on the first match. +// If not specified, defaults to true +func StaticRegexReplaceReplaceGlobal(value bool) StaticRegexReplaceAttr { + return func(m optionalAttr) { + m["replace_global"] = value + } +} + +// Replaces the match of pattern in input with rewrite. +// +// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// +// Arguments: +// input: The text to be processed. +// pattern: The regular expression to match the input. +// rewrite: The rewrite to be applied to the matched expression. +// +// Returns The text after applying pattern and rewrite. +func StaticRegexReplace(scope *Scope, input tf.Output, pattern string, rewrite string, optional ...StaticRegexReplaceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"pattern": pattern, "rewrite": rewrite} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StaticRegexReplace", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Subtracts sparse updates from the variable referenced by `resource`. +// +// This operation computes +// +// # Scalar indices +// ref[indices, ...] -= updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] -= updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions add. 
+// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +// Arguments: +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. +// +// Returns the created operation. +func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ResourceScatterSub", + Input: []tf.Input{ + resource, indices, updates, + }, + } + return scope.AddOperation(opspec) +} + +// LoadTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingADAMParametersGradAccumDebug. +type LoadTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load ADAM embedding parameters with debug support. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the ADAM optimization algorithm. +// momenta: Value of momenta used in the ADAM optimization algorithm. +// velocities: Value of velocities used in the ADAM optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the ADAM optimization algorithm. +// +// +// +// Returns the created operation. +func LoadTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingADAMParametersGradAccumDebugAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingADAMParametersGradAccumDebug", + Input: []tf.Input{ + parameters, momenta, velocities, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// SerializeManySparseAttr is an optional argument to SerializeManySparse. +type SerializeManySparseAttr func(optionalAttr) + +// SerializeManySparseOutType sets the optional out_type attribute to value. +// +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`. +// If not specified, defaults to DT_STRING +func SerializeManySparseOutType(value tf.DataType) SerializeManySparseAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor` object. 
+// +// The `SparseTensor` must have rank `R` greater than 1, and the first dimension +// is treated as the minibatch dimension. Elements of the `SparseTensor` +// must be sorted in increasing order of this first dimension. The serialized +// `SparseTensor` objects going into each row of `serialized_sparse` will have +// rank `R-1`. +// +// The minibatch size `N` is extracted from `sparse_shape[0]`. +// +// Arguments: +// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. +// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. +func SerializeManySparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeManySparseAttr) (serialized_sparse tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SerializeManySparse", + Input: []tf.Input{ + sparse_indices, sparse_values, sparse_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the Bessel i0e function of `x` element-wise. +// +// Exponentially scaled modified Bessel function of order 0 defined as +// `bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. +// +// This function is faster and numerically stabler than `bessel_i0(x)`. +func BesselI0e(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BesselI0e", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the element-wise max of two SparseTensors. +// +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// +// Arguments: +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. +// +// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. +func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSparseMaximum", + Input: []tf.Input{ + a_indices, a_values, a_shape, b_indices, b_values, b_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. +type StatelessRandomNormalAttr func(optionalAttr) + +// StatelessRandomNormalDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom values from a normal distribution. +// +// The generated values will have mean 0 and standard deviation 1. +// +// The outputs are a deterministic function of `shape` and `seed`. 
+// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. +func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessRandomNormal", + Input: []tf.Input{ + shape, seed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TensorArrayV2Attr is an optional argument to TensorArrayV2. +type TensorArrayV2Attr func(optionalAttr) + +// TensorArrayV2ElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorArrayV2ElementShape(value tf.Shape) TensorArrayV2Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// TensorArrayV2DynamicSize sets the optional dynamic_size attribute to value. +// If not specified, defaults to false +func TensorArrayV2DynamicSize(value bool) TensorArrayV2Attr { + return func(m optionalAttr) { + m["dynamic_size"] = value + } +} + +// TensorArrayV2ClearAfterRead sets the optional clear_after_read attribute to value. +// If not specified, defaults to true +func TensorArrayV2ClearAfterRead(value bool) TensorArrayV2Attr { + return func(m optionalAttr) { + m["clear_after_read"] = value + } +} + +// TensorArrayV2TensorArrayName sets the optional tensor_array_name attribute to value. +// If not specified, defaults to "" +func TensorArrayV2TensorArrayName(value string) TensorArrayV2Attr { + return func(m optionalAttr) { + m["tensor_array_name"] = value + } +} + +// Deprecated. Use TensorArrayV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayV3 +func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayV2", + Input: []tf.Input{ + size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Converts each string in the input Tensor to its hash mod by a number of buckets. +// +// The hash function is deterministic on the content of the string within the +// process and will never change. However, it is not suitable for cryptography. +// This function may be used when CPU time is scarce and inputs are trusted or +// unimportant. There is a risk of adversaries constructing inputs that all hash +// to the same bucket. To prevent this problem, use a strong hash function with +// `tf.string_to_hash_bucket_strong`. +// +// Arguments: +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_buckets": num_buckets} + opspec := tf.OpSpec{ + Type: "StringToHashBucketFast", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// The gradient of SparseFillEmptyRows. +// +// Takes vectors reverse_index_map, shaped `[N]`, and grad_values, +// shaped `[N_full]`, where `N_full >= N` and copies data into either +// `d_values` or `d_default_value`. 
Here `d_values` is shaped `[N]` and +// `d_default_value` is a scalar. +// +// d_values[j] = grad_values[reverse_index_map[j]] +// d_default_value = sum_{k : 0 .. N_full - 1} ( +// grad_values[k] * 1{k not in reverse_index_map}) +// +// Arguments: +// reverse_index_map: 1-D. The reverse index map from SparseFillEmptyRows. +// grad_values: 1-D. The gradients from backprop. +// +// Returns 1-D. The backprop into values.0-D. The backprop into default_value. +func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_values tf.Output) (d_values tf.Output, d_default_value tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseFillEmptyRowsGrad", + Input: []tf.Input{ + reverse_index_map, grad_values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Converts each string in the input Tensor to its hash mod by a number of buckets. +// +// The hash function is deterministic on the content of the string within the +// process. +// +// Note that the hash function may change from time to time. +// This functionality will be deprecated and it's recommended to use +// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. +// +// Arguments: +// +// num_buckets: The number of buckets. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_buckets": num_buckets} + opspec := tf.OpSpec{ + Type: "StringToHashBucket", + Input: []tf.Input{ + string_tensor, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Scatters tensor at indices in an input list. +// +// Each member of the TensorList corresponds to one row of the input tensor, +// specified by the given index (see `tf.gather`). +// +// input_handle: The list to scatter into. +// tensor: The input tensor. +// indices: The indices used to index into the list. +// output_handle: The TensorList. +func TensorListScatterIntoExistingList(scope *Scope, input_handle tf.Output, tensor tf.Output, indices tf.Output) (output_handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorListScatterIntoExistingList", + Input: []tf.Input{ + input_handle, tensor, indices, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingFTRLParametersGradAccumDebug. +type RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve FTRL embedding parameters with debug support. 
+// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the FTRL optimization algorithm.Parameter accumulators updated by the FTRL optimization algorithm.Parameter linears updated by the FTRL optimization algorithm.Parameter gradient_accumulators updated by the FTRL optimization algorithm. +func RetrieveTPUEmbeddingFTRLParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// CudnnRNNParamsToCanonicalAttr is an optional argument to CudnnRNNParamsToCanonical. +type CudnnRNNParamsToCanonicalAttr func(optionalAttr) + +// CudnnRNNParamsToCanonicalRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNParamsToCanonicalRnnMode(value string) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["rnn_mode"] = value + } +} + +// CudnnRNNParamsToCanonicalInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNParamsToCanonicalInputMode(value string) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNParamsToCanonicalDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNParamsToCanonicalDirection(value string) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNParamsToCanonicalDropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsToCanonicalDropout(value float32) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["dropout"] = value + } +} + +// CudnnRNNParamsToCanonicalSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsToCanonicalSeed(value int64) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// CudnnRNNParamsToCanonicalSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsToCanonicalSeed2(value int64) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Retrieves CudnnRNN params in canonical form. +// +// Retrieves a set of weights from the opaque params buffer that can be saved and +// restored in a way compatible with future runs. +// +// Note that the params buffer may not be compatible across different GPUs. So any +// save and restoration should be converted to and from the canonical weights and +// biases. +// +// num_layers: Specifies the number of layers in the RNN model. +// num_units: Specifies the size of the hidden state. +// input_size: Specifies the size of the input state. 
+// num_params: number of parameter sets for all layers. +// Each layer may contain multiple parameter sets, with each set consisting of +// a weight matrix and a bias vector. +// weights: the canonical form of weights that can be used for saving +// and restoration. They are more likely to be compatible across different +// generations. +// biases: the canonical form of biases that can be used for saving +// and restoration. They are more likely to be compatible across different +// generations. +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// The actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. +// dir = (direction == bidirectional) ? 2 : 1 +// dropout: dropout probability. When set to 0., dropout is disabled. +// seed: the 1st part of a seed to initialize dropout. +// seed2: the 2nd part of a seed to initialize dropout. +func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, params tf.Output, num_params int64, optional ...CudnnRNNParamsToCanonicalAttr) (weights []tf.Output, biases []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_params": num_params} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CudnnRNNParamsToCanonical", + Input: []tf.Input{ + num_layers, num_units, input_size, params, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if weights, idx, err = makeOutputList(op, idx, "weights"); err != nil { + scope.UpdateErr("CudnnRNNParamsToCanonical", err) + return + } + if biases, idx, err = makeOutputList(op, idx, "biases"); err != nil { + scope.UpdateErr("CudnnRNNParamsToCanonical", err) + return + } + return weights, biases +} + +// Returns the number of records this Reader has produced. +// +// This is the same as the number of ReaderRead executions that have +// succeeded. +// +// Arguments: +// reader_handle: Handle to a Reader. +func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_produced tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderNumRecordsProducedV2", + Input: []tf.Input{ + reader_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd. +type ResourceScatterNdAddAttr func(optionalAttr) + +// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value. +// +// value: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, +// but may exhibit less contention. +// If not specified, defaults to true +func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Applies sparse addition to individual values or slices in a Variable. +// +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. 
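+//
+// A Go-level sketch of the same kind of update shown in the Python snippet
+// further down in this comment (it assumes `s` is an `op.Scope` and `ref` is a
+// resource variable handle created elsewhere, e.g. with `op.VarHandleOp`, and
+// already initialized to [1, 2, ..., 8]):
+//
+// ```go
+// indices := op.Const(s, [][]int32{{4}, {3}, {1}, {7}})
+// updates := op.Const(s, []int32{9, 10, 11, 12})
+// addOp := op.ResourceScatterNdAdd(s, ref, indices, updates)
+// // addOp is a *tf.Operation; pass it as a target to Session.Run to apply it.
+// ```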
+// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. +// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// ``` +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] +// ``` +// +// For example, say we want to add 4 scattered elements to a rank-1 tensor to +// 8 elements. In Python, that addition would look like this: +// +// ```python +// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// add = tf.scatter_nd_add(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(add) +// ``` +// +// The resulting update to ref would look like this: +// +// [1, 13, 3, 14, 14, 6, 7, 20] +// +// See `tf.scatter_nd` for more details about how to make updates to +// slices. +// +// Arguments: +// ref: A resource handle. Must be from a VarHandleOp. +// indices: A Tensor. Must be one of the following types: int32, int64. +// A tensor of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of +// values to add to ref. +// +// Returns the created operation. +func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceScatterNdAdd", + Input: []tf.Input{ + ref, indices, updates, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Decode the frame(s) of a GIF-encoded image to a uint8 tensor. +// +// GIF images with frame or transparency compression are not supported. +// On Linux and MacOS systems, convert animated GIFs from compressed to +// uncompressed by running: +// +// convert $src.gif -coalesce $dst.gif +// +// This op also supports decoding JPEGs and PNGs, though it is cleaner to use +// `tf.image.decode_image`. +// +// Arguments: +// contents: 0-D. The GIF-encoded image. +// +// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB channel order. +func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DecodeGif", + Input: []tf.Input{ + contents, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset with a range of values. Corresponds to python's xrange. +// +// Arguments: +// start: corresponds to start in python's xrange(). +// stop: corresponds to stop in python's xrange(). +// step: corresponds to step in python's xrange(). +// +// +func RangeDataset(scope *Scope, start tf.Output, stop tf.Output, step tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "RangeDataset", + Input: []tf.Input{ + start, stop, step, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StringSplitV2Attr is an optional argument to StringSplitV2. +type StringSplitV2Attr func(optionalAttr) + +// StringSplitV2Maxsplit sets the optional maxsplit attribute to value. +// +// value: An `int`. If `maxsplit > 0`, limit of the split of the result. 
+// If not specified, defaults to -1 +func StringSplitV2Maxsplit(value int64) StringSplitV2Attr { + return func(m optionalAttr) { + m["maxsplit"] = value + } +} + +// Split elements of `source` based on `sep` into a `SparseTensor`. +// +// Let N be the size of source (typically N will be the batch size). Split each +// element of `source` based on `sep` and return a `SparseTensor` +// containing the split tokens. Empty tokens are ignored. +// +// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', +// then the output will be +// ``` +// st.indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// st.shape = [2, 3] +// st.values = ['hello', 'world', 'a', 'b', 'c'] +// ``` +// +// If `sep` is given, consecutive delimiters are not grouped together and are +// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and +// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty +// string, consecutive whitespace are regarded as a single separator, and the +// result will contain no empty strings at the startor end if the string has +// leading or trailing whitespace. +// +// Note that the above mentioned behavior matches python's str.split. +// +// Arguments: +// input: `1-D` string `Tensor`, the strings to split. +// sep: `0-D` string `Tensor`, the delimiter character. +func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StringSplitV2", + Input: []tf.Input{ + input, sep, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Generates fingerprint values. +// +// Generates fingerprint values of `data`. +// +// Fingerprint op considers the first dimension of `data` as the batch dimension, +// and `output[i]` contains the fingerprint value generated from contents in +// `data[i, ...]` for all `i`. +// +// Fingerprint op writes fingerprint values as byte arrays. For example, the +// default method `farmhash64` generates a 64-bit fingerprint value at a time. +// This 8-byte value is written out as an `uint8` array of size 8, in little-endian +// order. +// +// For example, suppose that `data` has data type `DT_INT32` and shape (2, 3, 4), +// and that the fingerprint method is `farmhash64`. In this case, the output shape +// is (2, 8), where 2 is the batch dimension size of `data`, and 8 is the size of +// each fingerprint value in bytes. `output[0, :]` is generated from 12 integers in +// `data[0, :, :]` and similarly `output[1, :]` is generated from other 12 integers +// in `data[1, :, :]`. +// +// Note that this op fingerprints the raw underlying buffer, and it does not +// fingerprint Tensor's metadata such as data type and/or shape. For example, the +// fingerprint values are invariant under reshapes and bitcasts as long as the +// batch dimension remain the same: +// +// ``` +// Fingerprint(data) == Fingerprint(Reshape(data, ...)) +// Fingerprint(data) == Fingerprint(Bitcast(data, ...)) +// ``` +// +// For string data, one should expect `Fingerprint(data) != +// Fingerprint(ReduceJoin(data))` in general. +// +// Arguments: +// data: Must have rank 1 or higher. +// method: Fingerprint method used by this op. Currently available method is +// `farmhash::fingerprint64`. 
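+//
+// A short Go sketch (the method string follows the `farmhash64` default named
+// above; `op.NewScope` and `op.Const` are assumed from the standard Go bindings):
+//
+// ```go
+// s := op.NewScope()
+// data := op.Const(s, [][]int32{{1, 2, 3}, {4, 5, 6}})      // batch of 2
+// fp := op.Fingerprint(s, data, op.Const(s, "farmhash64"))
+// // fp has shape (2, 8): one 8-byte fingerprint per batch element.
+// ```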
+// +// Returns A two-dimensional `Tensor` of type `tf.uint8`. The first dimension equals to +// `data`'s first dimension, and the second dimension size depends on the +// fingerprint algorithm. +func Fingerprint(scope *Scope, data tf.Output, method tf.Output) (fingerprint tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Fingerprint", + Input: []tf.Input{ + data, method, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AvgPool3DAttr is an optional argument to AvgPool3D. +type AvgPool3DAttr func(optionalAttr) + +// AvgPool3DDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DDataFormat(value string) AvgPool3DAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs 3D average pooling on the input. +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +// +// Returns The average pooled output tensor. +func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AvgPool3D", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StringLengthAttr is an optional argument to StringLength. +type StringLengthAttr func(optionalAttr) + +// StringLengthUnit sets the optional unit attribute to value. +// +// value: The unit that is counted to compute string length. One of: `"BYTE"` (for +// the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8 +// encoded Unicode code points in each string). Results are undefined +// if `unit=UTF8_CHAR` and the `input` strings do not contain structurally +// valid UTF-8. +// If not specified, defaults to "BYTE" +func StringLengthUnit(value string) StringLengthAttr { + return func(m optionalAttr) { + m["unit"] = value + } +} + +// String lengths of `input`. +// +// Computes the length of each string given in the input tensor. +// +// Arguments: +// input: The string for which to compute the length. +// +// Returns Integer tensor that has the same shape as `input`. The output contains the +// element-wise string lengths of `input`. +func StringLength(scope *Scope, input tf.Output, optional ...StringLengthAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StringLength", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SubstrAttr is an optional argument to Substr. 
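+//
+// A `SubstrAttr` such as `SubstrUnit` (defined below) is passed as a trailing
+// argument to `Substr`. A minimal, illustrative sketch:
+//
+// ```go
+// s := op.NewScope()
+// input := op.Const(s, []string{"héllo", "wörld"})
+// pos := op.Const(s, int32(1))
+// length := op.Const(s, int32(3))
+// out := op.Substr(s, input, pos, length, op.SubstrUnit("UTF8_CHAR"))
+// _ = out // positions and lengths are counted in Unicode code points here
+// ```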
+type SubstrAttr func(optionalAttr) + +// SubstrUnit sets the optional unit attribute to value. +// +// value: The unit that is used to create the substring. One of: `"BYTE"` (for +// defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8 +// encoded Unicode code points). The default is `"BYTE"`. Results are undefined if +// `unit=UTF8_CHAR` and the `input` strings do not contain structurally valid +// UTF-8. +// If not specified, defaults to "BYTE" +func SubstrUnit(value string) SubstrAttr { + return func(m optionalAttr) { + m["unit"] = value + } +} + +// Return substrings from `Tensor` of strings. +// +// For each string in the input `Tensor`, creates a substring starting at index +// `pos` with a total length of `len`. +// +// If `len` defines a substring that would extend beyond the length of the input +// string, then as many characters as possible are used. +// +// A negative `pos` indicates distance within the string backwards from the end. +// +// If `pos` specifies an index which is out of range for any of the input strings, +// then an `InvalidArgumentError` is thrown. +// +// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on +// Op creation. +// +// *NOTE*: `Substr` supports broadcasting up to two dimensions. More about +// broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +// +// --- +// +// Examples +// +// Using scalar `pos` and `len`: +// +// ```python +// input = [b'Hello', b'World'] +// position = 1 +// length = 3 +// +// output = [b'ell', b'orl'] +// ``` +// +// Using `pos` and `len` with same shape as `input`: +// +// ```python +// input = [[b'ten', b'eleven', b'twelve'], +// [b'thirteen', b'fourteen', b'fifteen'], +// [b'sixteen', b'seventeen', b'eighteen']] +// position = [[1, 2, 3], +// [1, 2, 3], +// [1, 2, 3]] +// length = [[2, 3, 4], +// [4, 3, 2], +// [5, 5, 5]] +// +// output = [[b'en', b'eve', b'lve'], +// [b'hirt', b'urt', b'te'], +// [b'ixtee', b'vente', b'hteen']] +// ``` +// +// Broadcasting `pos` and `len` onto `input`: +// +// ``` +// input = [[b'ten', b'eleven', b'twelve'], +// [b'thirteen', b'fourteen', b'fifteen'], +// [b'sixteen', b'seventeen', b'eighteen'], +// [b'nineteen', b'twenty', b'twentyone']] +// position = [1, 2, 3] +// length = [1, 2, 3] +// +// output = [[b'e', b'ev', b'lve'], +// [b'h', b'ur', b'tee'], +// [b'i', b've', b'hte'], +// [b'i', b'en', b'nty']] +// ``` +// +// Broadcasting `input` onto `pos` and `len`: +// +// ``` +// input = b'thirteen' +// position = [1, 5, 7] +// length = [3, 2, 1] +// +// output = [b'hir', b'ee', b'n'] +// ``` +// +// Arguments: +// input: Tensor of strings +// pos: Scalar defining the position of first character in each substring +// len: Scalar defining the number of characters to include in each substring +// +// Returns Tensor of substrings +func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output, optional ...SubstrAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Substr", + Input: []tf.Input{ + input, pos, len, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MapPeekAttr is an optional argument to MapPeek. +type MapPeekAttr func(optionalAttr) + +// MapPeekCapacity sets the optional capacity attribute to value. 
+// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapPeekCapacity(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapPeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapPeekMemoryLimit(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapPeekContainer(value string) MapPeekAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapPeekSharedName(value string) MapPeekAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op peeks at the values at the specified key. If the +// +// underlying container does not contain this key +// this op will block until it does. +func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MapPeek", + Input: []tf.Input{ + key, indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapPeek", err) + return + } + return values +} + +// Determine the script codes of a given tensor of Unicode integer code points. +// +// This operation converts Unicode code points to script codes corresponding to +// each code point. Script codes correspond to International Components for +// Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html. +// Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will +// match input shape. +// +// Arguments: +// input: A Tensor of int32 Unicode code points. +// +// Returns A Tensor of int32 script codes corresponding to each input code point. +func UnicodeScript(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "UnicodeScript", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. +type StatelessRandomUniformAttr func(optionalAttr) + +// StatelessRandomUniformDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// The outputs are a deterministic function of `shape` and `seed`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. 
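+//
+// A minimal usage sketch (illustrative; `op.Const` and `tf.Double` are assumed
+// from the standard Go bindings). Because the op is stateless, building and
+// running this twice with the same `shape` and `seed` yields identical values:
+//
+// ```go
+// s := op.NewScope()
+// u := op.StatelessRandomUniform(s,
+//     op.Const(s, []int32{4}),      // shape
+//     op.Const(s, []int64{1, 2}),   // seed, shape [2]
+//     op.StatelessRandomUniformDtype(tf.Double))
+// _ = u
+// ```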
+func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessRandomUniform", + Input: []tf.Input{ + shape, seed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// A placeholder op for a value that will be fed into the computation. +// +// Arguments: +// dtype: The type of elements in the tensor. +// shape: The shape of the tensor. +// +// Returns A tensor that will be provided using the infeed mechanism. +func InfeedDequeue(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + opspec := tf.OpSpec{ + Type: "InfeedDequeue", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatelessMultinomialAttr is an optional argument to StatelessMultinomial. +type StatelessMultinomialAttr func(optionalAttr) + +// StatelessMultinomialOutputDtype sets the optional output_dtype attribute to value. +// If not specified, defaults to DT_INT64 +func StatelessMultinomialOutputDtype(value tf.DataType) StatelessMultinomialAttr { + return func(m optionalAttr) { + m["output_dtype"] = value + } +} + +// Draws samples from a multinomial distribution. +// +// Arguments: +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. +// seed: 2 seeds (shape [2]). +// +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. +func StatelessMultinomial(scope *Scope, logits tf.Output, num_samples tf.Output, seed tf.Output, optional ...StatelessMultinomialAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessMultinomial", + Input: []tf.Input{ + logits, num_samples, seed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Gather slices from `params` into a Tensor with shape specified by `indices`. +// +// `indices` is an K-dimensional integer tensor, best thought of as a +// (K-1)-dimensional tensor of indices into `params`, where each element defines a +// slice of `params`: +// +// output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]] +// +// Whereas in `tf.gather` `indices` defines slices into the first +// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the +// first `N` dimensions of `params`, where `N = indices.shape[-1]`. +// +// The last dimension of `indices` can be at most the rank of +// `params`: +// +// indices.shape[-1] <= params.rank +// +// The last dimension of `indices` corresponds to elements +// (if `indices.shape[-1] == params.rank`) or slices +// (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]` +// of `params`. The output tensor has shape +// +// indices.shape[:-1] + params.shape[indices.shape[-1]:] +// +// Note that on CPU, if an out of bound index is found, an error is returned. 
+// On GPU, if an out of bound index is found, a 0 is stored in the +// corresponding output value. +// +// Some examples below. +// +// Simple indexing into a matrix: +// +// ```python +// indices = [[0, 0], [1, 1]] +// params = [['a', 'b'], ['c', 'd']] +// output = ['a', 'd'] +// ``` +// +// Slice indexing into a matrix: +// +// ```python +// indices = [[1], [0]] +// params = [['a', 'b'], ['c', 'd']] +// output = [['c', 'd'], ['a', 'b']] +// ``` +// +// Indexing into a 3-tensor: +// +// ```python +// indices = [[1]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [[['a1', 'b1'], ['c1', 'd1']]] +// +// +// indices = [[0, 1], [1, 0]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [['c0', 'd0'], ['a1', 'b1']] +// +// +// indices = [[0, 0, 1], [1, 0, 1]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = ['b0', 'b1'] +// ``` +// +// Batched indexing into a matrix: +// +// ```python +// indices = [[[0, 0]], [[0, 1]]] +// params = [['a', 'b'], ['c', 'd']] +// output = [['a'], ['b']] +// ``` +// +// Batched slice indexing into a matrix: +// +// ```python +// indices = [[[1]], [[0]]] +// params = [['a', 'b'], ['c', 'd']] +// output = [[['c', 'd']], [['a', 'b']]] +// ``` +// +// Batched indexing into a 3-tensor: +// +// ```python +// indices = [[[1]], [[0]]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [[[['a1', 'b1'], ['c1', 'd1']]], +// [[['a0', 'b0'], ['c0', 'd0']]]] +// +// indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [[['c0', 'd0'], ['a1', 'b1']], +// [['a0', 'b0'], ['c1', 'd1']]] +// +// +// indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]] +// params = [[['a0', 'b0'], ['c0', 'd0']], +// [['a1', 'b1'], ['c1', 'd1']]] +// output = [['b0', 'b1'], ['d0', 'c1']] +// ``` +// +// See also `tf.gather` and `tf.batch_gather`. +// +// Arguments: +// params: The tensor from which to gather values. +// indices: Index tensor. +// +// Returns Values from `params` gathered from indices given by `indices`, with +// shape `indices.shape[:-1] + params.shape[indices.shape[-1]:]`. +func GatherNd(scope *Scope, params tf.Output, indices tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "GatherNd", + Input: []tf.Input{ + params, indices, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns a batched matrix tensor with new batched diagonal values. +// +// Given `input` and `diagonal`, this operation returns a tensor with the +// same shape and values as `input`, except for the main diagonal of the +// innermost matrices. These will be overwritten by the values in `diagonal`. +// +// The output is computed as follows: +// +// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has +// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a +// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: +// +// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. +// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. +// +// Arguments: +// input: Rank `k+1`, where `k >= 1`. +// diagonal: Rank `k`, where `k >= 1`. +// +// Returns Rank `k+1`, with `output.shape = input.shape`. 
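+//
+// A small illustrative Go sketch (values chosen arbitrarily; `op.Const` is
+// assumed from the standard Go bindings):
+//
+// ```go
+// s := op.NewScope()
+// input := op.Const(s, [][]float32{{1, 2}, {3, 4}})
+// diag := op.Const(s, []float32{9, 10})
+// out := op.MatrixSetDiag(s, input, diag)
+// // out evaluates to [[9, 2], [3, 10]]: the main diagonal is replaced,
+// // all other entries are taken from `input`.
+// ```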
+func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixSetDiag", + Input: []tf.Input{ + input, diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Store the input tensor in the state of the current session. +// +// Arguments: +// value: The tensor to be stored. +// +// Returns The handle for the tensor stored in the session state, represented +// as a ResourceHandle object. +func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "GetSessionHandleV2", + Input: []tf.Input{ + value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Connects outputs of an N-way replicated computation to N outputs. +func TPUReplicatedOutput(scope *Scope, input tf.Output, num_replicas int64) (outputs []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_replicas": num_replicas} + opspec := tf.OpSpec{ + Type: "TPUReplicatedOutput", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("TPUReplicatedOutput", err) + return + } + return outputs +} + +// Returns true if queue is closed. +// +// This operation returns true if the queue is closed and false if the queue +// is open. +// +// Arguments: +// handle: The handle to a queue. +func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "QueueIsClosedV2", + Input: []tf.Input{ + handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Connects N inputs to an N-way replicated TPU computation. +func TPUReplicatedInput(scope *Scope, inputs []tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TPUReplicatedInput", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inserts a dimension of 1 into a tensor's shape. +// +// Given a tensor `input`, this operation inserts a dimension of 1 at the +// dimension index `axis` of `input`'s shape. The dimension index `axis` starts at +// zero; if you specify a negative number for `axis` it is counted backward from +// the end. +// +// This operation is useful if you want to add a batch dimension to a single +// element. For example, if you have a single image of shape `[height, width, +// channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, +// which will make the shape `[1, height, width, channels]`. +// +// Other examples: +// +// ``` +// # 't' is a tensor of shape [2] +// shape(expand_dims(t, 0)) ==> [1, 2] +// shape(expand_dims(t, 1)) ==> [2, 1] +// shape(expand_dims(t, -1)) ==> [2, 1] +// +// # 't2' is a tensor of shape [2, 3, 5] +// shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] +// shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] +// shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] +// ``` +// +// This operation requires that: +// +// `-1-input.dims() <= dim <= input.dims()` +// +// This operation is related to `squeeze()`, which removes dimensions of +// size 1. +// +// Arguments: +// +// axis: 0-D (scalar). Specifies the dimension index at which to +// expand the shape of `input`. 
Must be in the range +// `[-rank(input) - 1, rank(input)]`. +// +// Returns Contains the same data as `input`, but its shape has an additional +// dimension of size 1 added. +func ExpandDims(scope *Scope, input tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ExpandDims", + Input: []tf.Input{ + input, axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inverse real-valued fast Fourier transform. +// +// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued +// signal over the inner-most dimension of `input`. +// +// The inner-most dimension of `input` is assumed to be the result of `RFFT`: the +// `fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If +// `fft_length` is not provided, it is computed from the size of the inner-most +// dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to +// compute `input` is odd, it should be provided since it cannot be inferred +// properly. +// +// Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller +// than the corresponding dimension of `input`, the dimension is cropped. If it is +// larger, the dimension is padded with zeros. +// +// Arguments: +// input: A complex64 tensor. +// fft_length: An int32 tensor of shape [1]. The FFT length. +// +// Returns A float32 tensor of the same rank as `input`. The inner-most +// dimension of `input` is replaced with the `fft_length` samples of its inverse +// 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.irfft +// @end_compatibility +func IRFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IRFFT", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ModelDatasetAttr is an optional argument to ModelDataset. +type ModelDatasetAttr func(optionalAttr) + +// ModelDatasetCpuBudget sets the optional cpu_budget attribute to value. +// If not specified, defaults to 0 +func ModelDatasetCpuBudget(value int64) ModelDatasetAttr { + return func(m optionalAttr) { + m["cpu_budget"] = value + } +} + +// Identity transformation that models performance. +// +// Identity transformation that models performance. +// +// Arguments: +// input_dataset: A variant tensor representing the input dataset. +// +// +func ModelDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ModelDatasetAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ModelDataset", + Input: []tf.Input{ + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Produce a string tensor that encodes the state of a Reader. +// +// Not all Readers support being serialized, so this can produce an +// Unimplemented error. +// +// Arguments: +// reader_handle: Handle to a Reader. 
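+//
+// A hedged sketch (it assumes `s` is an `op.Scope` and `readerHandle` is a
+// reader resource produced elsewhere in the graph by one of the package's
+// reader ops, which are not part of this excerpt):
+//
+// ```go
+// state := op.ReaderSerializeStateV2(s, readerHandle)
+// _ = state // a string scalar; unsupported readers fail at run time as noted above
+// ```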
+func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderSerializeStateV2", + Input: []tf.Input{ + reader_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum. +type ResourceSparseApplyMomentumAttr func(optionalAttr) + +// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, the tensor passed to compute grad will be +// var - lr * momentum * accum, so in the end, the var you get is actually +// var - lr * momentum * accum. +// If not specified, defaults to false +func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update relevant entries in '*var' and '*accum' according to the momentum scheme. +// +// Set use_nesterov = True if you want to use Nesterov momentum. +// +// That is for rows we have grad for, we update var and accum as follows: +// +// accum = accum * momentum + grad +// var -= lr * accum +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// momentum: Momentum. Must be a scalar. +// +// Returns the created operation. +func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyMomentum", + Input: []tf.Input{ + var_, accum, lr, grad, indices, momentum, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Returns element-wise remainder of division. This emulates C semantics in that +// +// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * +// y + truncate_mod(x, y) = x`. +// +// *NOTE*: `TruncateMod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TruncateMod", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatefulUniformFullIntAttr is an optional argument to StatefulUniformFullInt. +type StatefulUniformFullIntAttr func(optionalAttr) + +// StatefulUniformFullIntDtype sets the optional dtype attribute to value. +// +// value: The type of the output. 
+// If not specified, defaults to DT_UINT64 +func StatefulUniformFullIntDtype(value tf.DataType) StatefulUniformFullIntAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random integers from a uniform distribution. +// +// The generated values are uniform integers covering the whole range of `dtype`. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// shape: The shape of the output tensor. +// +// Returns Random values with specified shape. +func StatefulUniformFullInt(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, optional ...StatefulUniformFullIntAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatefulUniformFullInt", + Input: []tf.Input{ + resource, algorithm, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the trignometric inverse sine of x element-wise. +// +// The `tf.math.asin` operation returns the inverse of `tf.math.sin`, such that +// if `y = tf.math.sin(x)` then, `x = tf.math.asin(y)`. +// +// **Note**: The output of `tf.math.asin` will lie within the invertible range +// of sine, i.e [-pi/2, pi/2]. +// +// For example: +// +// ```python +// # Note: [1.047, 0.785] ~= [(pi/3), (pi/4)] +// x = tf.constant([1.047, 0.785]) +// y = tf.math.sin(x) # [0.8659266, 0.7068252] +// +// tf.math.asin(y) # [1.047, 0.785] = x +// ``` +// +func Asin(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Asin", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// OutfeedDequeueTupleAttr is an optional argument to OutfeedDequeueTuple. +type OutfeedDequeueTupleAttr func(optionalAttr) + +// OutfeedDequeueTupleDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. This should be -1 when the Op +// is running on a TPU device, and >= 0 when the Op is running on the CPU +// device. +// If not specified, defaults to -1 +func OutfeedDequeueTupleDeviceOrdinal(value int64) OutfeedDequeueTupleAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// Retrieve multiple values from the computation outfeed. +// +// This operation will block indefinitely until data is available. Output `i` +// corresponds to XLA tuple element `i`. +// +// Arguments: +// dtypes: The element types of each element in `outputs`. +// shapes: The shapes of each tensor in `outputs`. +// +// Returns A list of tensors that will be read from the outfeed. +func OutfeedDequeueTuple(scope *Scope, dtypes []tf.DataType, shapes []tf.Shape, optional ...OutfeedDequeueTupleAttr) (outputs []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes, "shapes": shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OutfeedDequeueTuple", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("OutfeedDequeueTuple", err) + return + } + return outputs +} + +// Gets next element for the provided shard number. +// +// Arguments: +// multi_device_iterator: A MultiDeviceIterator resource. 
+// shard_num: Integer representing which shard to fetch data for. +// incarnation_id: Which incarnation of the MultiDeviceIterator is running. +// output_types: The type list for the return values. +// output_shapes: The list of shapes being produced. +// +// Returns Result of the get_next on the dataset. +func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.Output, shard_num tf.Output, incarnation_id tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "MultiDeviceIteratorGetNextFromShard", + Input: []tf.Input{ + multi_device_iterator, shard_num, incarnation_id, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("MultiDeviceIteratorGetNextFromShard", err) + return + } + return components +} + +// This op consumes a lock created by `MutexLock`. +// +// This op exists to consume a tensor created by `MutexLock` (other than +// direct control dependencies). It should be the only that consumes the tensor, +// and will raise an error if it is not. Its only purpose is to keep the +// mutex lock tensor alive until it is consumed by this op. +// +// **NOTE**: This operation must run on the same device as its input. This may +// be enforced via the `colocate_with` mechanism. +// +// Arguments: +// mutex_lock: A tensor returned by `MutexLock`. +// +// Returns the created operation. +func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConsumeMutexLock", + Input: []tf.Input{ + mutex_lock, + }, + } + return scope.AddOperation(opspec) +} + +// StringFormatAttr is an optional argument to StringFormat. +type StringFormatAttr func(optionalAttr) + +// StringFormatTemplate sets the optional template attribute to value. +// +// value: A string, the template to format tensor summaries into. +// If not specified, defaults to "%s" +func StringFormatTemplate(value string) StringFormatAttr { + return func(m optionalAttr) { + m["template"] = value + } +} + +// StringFormatPlaceholder sets the optional placeholder attribute to value. +// +// value: A string, at each placeholder in the template a subsequent tensor summary will be inserted. +// If not specified, defaults to "%s" +func StringFormatPlaceholder(value string) StringFormatAttr { + return func(m optionalAttr) { + m["placeholder"] = value + } +} + +// StringFormatSummarize sets the optional summarize attribute to value. +// +// value: When formatting the tensor summaries print the first and last summarize entries of each tensor dimension. +// If not specified, defaults to 3 +func StringFormatSummarize(value int64) StringFormatAttr { + return func(m optionalAttr) { + m["summarize"] = value + } +} + +// Formats a string template using a list of tensors. +// +// Formats a string template using a list of tensors, pretty-printing tensor summaries. +// +// Arguments: +// inputs: The list of tensors to format into the placeholder string. +// +// Returns = The resulting string scalar. 
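+//
+// A minimal illustrative sketch combining the optional attributes above
+// (`op.NewScope` and `op.Const` are assumed from the standard Go bindings):
+//
+// ```go
+// s := op.NewScope()
+// t := op.Const(s, []int32{1, 2, 3})
+// msg := op.StringFormat(s, []tf.Output{t},
+//     op.StringFormatTemplate("tensor is: %s"),
+//     op.StringFormatSummarize(2))
+// _ = msg // a string scalar built from the template and a summary of `t`
+// ```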
+func StringFormat(scope *Scope, inputs []tf.Output, optional ...StringFormatAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StringFormat", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Real-valued fast Fourier transform. +// +// Computes the 1-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most dimension of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the +// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term, +// followed by the `fft_length / 2` positive-frequency terms. +// +// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// +// Arguments: +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [1]. The FFT length. +// +// Returns A complex64 tensor of the same rank as `input`. The inner-most +// dimension of `input` is replaced with the `fft_length / 2 + 1` unique +// frequency components of its 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfft +// @end_compatibility +func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RFFT", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingMomentumParametersGradAccumDebug. +type LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) + +// LoadTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingMomentumParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingMomentumParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load Momentum embedding parameters with debug support. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the Momentum optimization algorithm. +// momenta: Value of momenta used in the Momentum optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the Momentum optimization algorithm. +// +// +// +// Returns the created operation. 
+func LoadTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, parameters tf.Output, momenta tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingMomentumParametersGradAccumDebug", + Input: []tf.Input{ + parameters, momenta, gradient_accumulators, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// RestoreAttr is an optional argument to Restore. +type RestoreAttr func(optionalAttr) + +// RestorePreferredShard sets the optional preferred_shard attribute to value. +// +// value: Index of file to open first if multiple files match +// `file_pattern`. +// If not specified, defaults to -1 +func RestorePreferredShard(value int64) RestoreAttr { + return func(m optionalAttr) { + m["preferred_shard"] = value + } +} + +// Restores a tensor from checkpoint files. +// +// Reads a tensor stored in one or several files. If there are several files (for +// instance because a tensor was saved as slices), `file_pattern` may contain +// wildcard symbols (`*` and `?`) in the filename portion only, not in the +// directory portion. +// +// If a `file_pattern` matches several files, `preferred_shard` can be used to hint +// in which file the requested tensor is likely to be found. This op will first +// open the file at index `preferred_shard` in the list of matching files and try +// to restore tensors from that file. Only if some tensors or tensor slices are +// not found in that first file, then the Op opens all the files. Setting +// `preferred_shard` to match the value passed as the `shard` input +// of a matching `Save` Op may speed up Restore. This attribute only affects +// performance, not correctness. The default value -1 means files are processed in +// order. +// +// See also `RestoreSlice`. +// +// Arguments: +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// dt: The type of the tensor to be restored. +// +// Returns The restored tensor. +func Restore(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, dt tf.DataType, optional ...RestoreAttr) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dt": dt} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Restore", + Input: []tf.Input{ + file_pattern, tensor_name, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FusedBatchNormGradAttr is an optional argument to FusedBatchNormGrad. +type FusedBatchNormGradAttr func(optionalAttr) + +// FusedBatchNormGradEpsilon sets the optional epsilon attribute to value. +// +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormGradEpsilon(value float32) FusedBatchNormGradAttr { + return func(m optionalAttr) { + m["epsilon"] = value + } +} + +// FusedBatchNormGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format for y_backprop, x, x_backprop. +// Either "NHWC" (default) or "NCHW". 
+// If not specified, defaults to "NHWC" +func FusedBatchNormGradDataFormat(value string) FusedBatchNormGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormGradIsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormGradIsTraining(value bool) FusedBatchNormGradAttr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// Gradient for batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// y_backprop: A 4D Tensor for the gradient with respect to y. +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch +// mean to be reused in gradient computation. When is_training is +// False, a 1D Tensor for the population mean to be reused in both +// 1st and 2nd order gradient computation. +// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch +// variance (inverted variance in the cuDNN case) to be reused in +// gradient computation. When is_training is False, a 1D Tensor +// for the population variance to be reused in both 1st and 2nd +// order gradient computation. +// +// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input +// in FusedBatchNorm. +func FusedBatchNormGrad(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradAttr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FusedBatchNormGrad", + Input: []tf.Input{ + y_backprop, x, scale, reserve_space_1, reserve_space_2, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + +// RetrieveTPUEmbeddingFTRLParametersAttr is an optional argument to RetrieveTPUEmbeddingFTRLParameters. +type RetrieveTPUEmbeddingFTRLParametersAttr func(optionalAttr) + +// RetrieveTPUEmbeddingFTRLParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func RetrieveTPUEmbeddingFTRLParametersTableId(value int64) RetrieveTPUEmbeddingFTRLParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingFTRLParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingFTRLParametersTableName(value string) RetrieveTPUEmbeddingFTRLParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Retrieve FTRL embedding parameters. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. 
For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns Parameter parameters updated by the FTRL optimization algorithm.Parameter accumulators updated by the FTRL optimization algorithm.Parameter linears updated by the FTRL optimization algorithm. +func RetrieveTPUEmbeddingFTRLParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingFTRLParameters", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// SparseReduceMaxSparseAttr is an optional argument to SparseReduceMaxSparse. +type SparseReduceMaxSparseAttr func(optionalAttr) + +// SparseReduceMaxSparseKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceMaxSparseKeepDims(value bool) SparseReduceMaxSparseAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the max of elements across dimensions of a SparseTensor. +// +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_max()`. In contrast to SparseReduceMax, this Op returns a +// SparseTensor. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +func SparseReduceMaxSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseReduceMaxSparse", + Input: []tf.Input{ + input_indices, input_values, input_shape, reduction_axes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Output a fact about factorials. +func Fact(scope *Scope) (fact tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Fact", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the minimum along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. 
+// +// This operator is similar to the unsorted segment sum operator found +// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). +// Instead of computing the sum over segments, it computes the minimum such that: +// +// \\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such +// that `segment_ids[j...] == i`. +// +// If the minimum is empty for a given segment ID `i`, it outputs the largest +// possible value for the specific numeric type, +// `output[i] = numeric_limits::max()`. +// +// For example: +// +// ``` python +// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) +// tf.unsorted_segment_min(c, tf.constant([0, 1, 0]), num_segments=2) +// # ==> [[ 1, 2, 2, 1], +// # [5, 6, 7, 8]] +// ``` +// +// If the given segment ID `i` is negative, then the corresponding value is +// dropped, and will not be included in the result. +// +// Arguments: +// +// segment_ids: A tensor whose shape is a prefix of `data.shape`. +// +// +// Returns Has same shape as data, except for the first `segment_ids.rank` +// dimensions, which are replaced with a single dimension which has size +// `num_segments`. +func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "UnsortedSegmentMin", + Input: []tf.Input{ + data, segment_ids, num_segments, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent. +type ResourceApplyGradientDescentAttr func(optionalAttr) + +// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' by subtracting 'alpha' * 'delta' from it. +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// delta: The change. +// +// Returns the created operation. +func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyGradientDescent", + Input: []tf.Input{ + var_, alpha, delta, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// QuantizedDepthwiseConv2DAttr is an optional argument to QuantizedDepthwiseConv2D. +type QuantizedDepthwiseConv2DAttr func(optionalAttr) + +// QuantizedDepthwiseConv2DOutType sets the optional out_type attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_QINT32 +func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2DAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value. +// +// value: List of dilation values. 
+// If not specified, defaults to +func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes quantized depthwise Conv2D. +// +// Arguments: +// input: The original input tensor. +// filter: The original filter tensor. +// min_input: The float value that the minimum quantized input value represents. +// max_input: The float value that the maximum quantized input value represents. +// min_filter: The float value that the minimum quantized filter value represents. +// max_filter: The float value that the maximum quantized filter value represents. +// strides: List of stride values. +// +// +// Returns The output tensor.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizedDepthwiseConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedDepthwiseConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedDepthwiseConv2D", + Input: []tf.Input{ + input, filter, min_input, max_input, min_filter, max_filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Compute the polygamma function \\(\psi^{(n)}(x)\\). +// +// The polygamma function is defined as: +// +// +// \\(\psi^{(a)}(x) = \frac{d^a}{dx^a} \psi(x)\\) +// +// where \\(\psi(x)\\) is the digamma function. +// The polygamma function is defined only for non-negative integer orders \\a\\. +func Polygamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Polygamma", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes softplus: `log(exp(features) + 1)`. +func Softplus(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Softplus", + Input: []tf.Input{ + features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomUniformAttr is an optional argument to RandomUniform. +type RandomUniformAttr func(optionalAttr) + +// RandomUniformSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomUniformSeed(value int64) RandomUniformAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomUniformSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomUniformSeed2(value int64) RandomUniformAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// Arguments: +// shape: The shape of the output tensor. +// dtype: The type of the output. 
+// +// Returns A tensor of the specified shape filled with uniform random values. +func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomUniform", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the complementary error function of `x` element-wise. +func Erfc(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Erfc", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Subtracts a value from the current value of a variable. +// +// Any ReadVariableOp with a control dependency on this op is guaranteed to +// see the decremented value or a subsequent newer one. +// +// Arguments: +// resource: handle to the resource in which to store the variable. +// value: the value by which the variable will be incremented. +// +// Returns the created operation. +func AssignSubVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AssignSubVariableOp", + Input: []tf.Input{ + resource, value, + }, + } + return scope.AddOperation(opspec) +} + +// An op enabling differentiation of TPU Embeddings. +// +// This op simply returns its first input, which is assumed to have been sliced +// from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of +// this op, and its first argument being a trainable Variable, enables automatic +// differentiation of graphs containing embeddings via the TPU Embedding Python +// libraries. +// +// Arguments: +// embedding_variable: A trainable variable, enabling optimizers to find this op. +// sliced_activations: The embedding activations Tensor to return. +// table_id: The id of the table in the embedding layer configuration from which +// these activations were computed. +// lookup_id: Identifier of the set of embedding indices which produced these +// activations. +func TPUEmbeddingActivations(scope *Scope, embedding_variable tf.Output, sliced_activations tf.Output, table_id int64, lookup_id int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"table_id": table_id, "lookup_id": lookup_id} + opspec := tf.OpSpec{ + Type: "TPUEmbeddingActivations", + Input: []tf.Input{ + embedding_variable, sliced_activations, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// PrelinearizeTupleAttr is an optional argument to PrelinearizeTuple. +type PrelinearizeTupleAttr func(optionalAttr) + +// PrelinearizeTupleLayouts sets the optional layouts attribute to value. +// +// value: A vector holding the requested layout in minor-to-major sequence for all the +// tuple shapes in the order the shapes appear in the "shapes" input. The layout +// elements for a sub-shape can be set to -1 in which case the corresponding layout +// will be computed by the infeed operation. +// If not specified, defaults to <> +func PrelinearizeTupleLayouts(value []int64) PrelinearizeTupleAttr { + return func(m optionalAttr) { + m["layouts"] = value + } +} + +// An op which linearizes multiple Tensor values to an opaque variant tensor. 
+// +// Arguments: +// inputs: A list of tensors that will be provided using the infeed mechanism. +// shapes: The shapes of each tensor in `inputs`. +func PrelinearizeTuple(scope *Scope, inputs []tf.Output, shapes []tf.Shape, optional ...PrelinearizeTupleAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shapes": shapes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "PrelinearizeTuple", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 3D fast Fourier transform. +// +// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 +// dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fftn with 3 dimensions. +// @end_compatibility +func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT3D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. +type ResizeNearestNeighborAttr func(optionalAttr) + +// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// ResizeNearestNeighborHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeNearestNeighborHalfPixelCenters(value bool) ResizeNearestNeighborAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value + } +} + +// Resize `images` to `size` using nearest neighbor interpolation. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeNearestNeighbor", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxAttr is an optional argument to Max. +type MaxAttr func(optionalAttr) + +// MaxKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func MaxKeepDims(value bool) MaxAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the maximum of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. 
If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Max(scope *Scope, input tf.Output, axis tf.Output, optional ...MaxAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Max", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. +type QueueDequeueUpToV2Attr func(optionalAttr) + +// QueueDequeueUpToV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue has fewer than n elements, this operation +// will block for up to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueUpToV2TimeoutMs(value int64) QueueDequeueUpToV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Dequeues `n` tuples of one or more tensors from the given queue. +// +// This operation is not supported by all queues. If a queue does not support +// DequeueUpTo, then an Unimplemented error is returned. +// +// If the queue is closed and there are more than 0 but less than `n` +// elements remaining, then instead of returning an OutOfRange error like +// QueueDequeueMany, less than `n` elements are returned immediately. If +// the queue is closed and there are 0 elements left in the queue, then +// an OutOfRange error is returned just like in QueueDequeueMany. +// Otherwise the behavior is identical to QueueDequeueMany: +// +// This operation concatenates queue-element component tensors along the +// 0th dimension to make a single component tensor. All of the components +// in the dequeued tuple will have size n in the 0th dimension. +// +// This operation has `k` outputs, where `k` is the number of components in +// the tuples stored in the given queue, and output `i` is the ith +// component of the dequeued tuple. +// +// Arguments: +// handle: The handle to a queue. +// n: The number of tuples to dequeue. +// component_types: The type of each component in a tuple. +// +// Returns One or more tensors that were dequeued as a tuple. +func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueUpToV2Attr) (components []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QueueDequeueUpToV2", + Input: []tf.Input{ + handle, n, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("QueueDequeueUpToV2", err) + return + } + return components +} + +// Computes rectified linear: `max(features, 0)`. +func Relu(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Relu", + Input: []tf.Input{ + features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Conv3DBackpropFilterAttr is an optional argument to Conv3DBackpropFilter. 
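+//
+// Optional attributes such as Conv3DBackpropFilterDilations are passed as trailing
+// arguments. A minimal sketch, assuming a Scope `s` and tf.Output values `input`,
+// `filter` and `outBackprop` defined elsewhere (illustrative names and values only):
+//
+//	grad := Conv3DBackpropFilter(s, input, filter, outBackprop,
+//		[]int64{1, 1, 1, 1, 1}, "SAME",
+//		Conv3DBackpropFilterDilations([]int64{1, 1, 1, 1, 1}))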
+type Conv3DBackpropFilterAttr func(optionalAttr) + +// Conv3DBackpropFilterDilations sets the optional dilations attribute to value. +// If not specified, defaults to +func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of 3-D convolution with respect to the filter. +// +// DEPRECATED at GraphDef version 10: Use Conv3DBackpropFilterV2 +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv3DBackpropFilter", + Input: []tf.Input{ + input, filter, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes hyperbolic sine of x element-wise. +func Sinh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sinh", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 2D real-valued fast Fourier transform. +// +// Computes the 2-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 2 dimensions of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. +// +// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the +// corresponding dimension of `input`, the dimension is cropped. If it is larger, +// the dimension is padded with zeros. +// +// Arguments: +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension. +// +// Returns A complex64 tensor of the same rank as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfft2 +// @end_compatibility +func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RFFT2D", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute. +type DataFormatVecPermuteAttr func(optionalAttr) + +// DataFormatVecPermuteSrcFormat sets the optional src_format attribute to value. +// +// value: source data format. 
+// If not specified, defaults to "NHWC" +func DataFormatVecPermuteSrcFormat(value string) DataFormatVecPermuteAttr { + return func(m optionalAttr) { + m["src_format"] = value + } +} + +// DataFormatVecPermuteDstFormat sets the optional dst_format attribute to value. +// +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatVecPermuteDstFormat(value string) DataFormatVecPermuteAttr { + return func(m optionalAttr) { + m["dst_format"] = value + } +} + +// Returns the permuted vector/tensor in the destination data format given the +// +// one in the source data format. +// +// Arguments: +// x: Vector of size 4 or Tensor of shape (4, 2) in source data format. +// +// Returns Vector of size 4 or Tensor of shape (4, 2) in destination data format. +func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPermuteAttr) (y tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DataFormatVecPermute", + Input: []tf.Input{ + x, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the element-wise sum of a list of tensors. +// +// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not +// wait for all of its inputs to be ready before beginning to sum. This can +// save memory if inputs are ready at different times, since minimum temporary +// storage is proportional to the output size rather than the inputs size. +// +// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. +// +// Returns a `Tensor` of same shape and type as the elements of `inputs`. +// +// Arguments: +// inputs: A list of `Tensor` objects, each with same shape and type. +// shape: Shape of elements of `inputs`. +func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape": shape} + opspec := tf.OpSpec{ + Type: "AccumulateNV2", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TridiagonalSolveAttr is an optional argument to TridiagonalSolve. +type TridiagonalSolveAttr func(optionalAttr) + +// TridiagonalSolvePartialPivoting sets the optional partial_pivoting attribute to value. +// +// value: Whether to apply partial pivoting. Partial pivoting makes the procedure more +// stable, but slower. +// If not specified, defaults to true +func TridiagonalSolvePartialPivoting(value bool) TridiagonalSolveAttr { + return func(m optionalAttr) { + m["partial_pivoting"] = value + } +} + +// Solves tridiagonal systems of equations. +// +// Solves tridiagonal systems of equations. +// Supports batch dimensions and multiple right-hand sides per each left-hand +// side. +// On CPU, solution is computed via Gaussian elimination with or without partial +// pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE +// library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv +// +// Arguments: +// diagonals: Tensor of shape `[..., 3, M]` whose innermost 2 dimensions represent the +// tridiagonal matrices with three rows being the superdiagonal, diagonals, and +// subdiagonals, in order. The last element of the superdiagonal and the first +// element of the subdiagonal is ignored. 
+// rhs: Tensor of shape `[..., M, K]`, representing K right-hand sides per each +// left-hand side. +// +// Returns Tensor of shape `[..., M, K]` containing the solutions +func TridiagonalSolve(scope *Scope, diagonals tf.Output, rhs tf.Output, optional ...TridiagonalSolveAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TridiagonalSolve", + Input: []tf.Input{ + diagonals, rhs, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg. +type DecodeAndCropJpegAttr func(optionalAttr) + +// DecodeAndCropJpegChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodeAndCropJpegRatio sets the optional ratio attribute to value. +// +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["ratio"] = value + } +} + +// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["fancy_upscaling"] = value + } +} + +// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value. +// +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } +} + +// Decode and Crop a JPEG-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. 
+// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// +// It is equivalent to a combination of decode and crop, but much faster by only +// decoding partial jpeg image. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width]. +// +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeAndCropJpeg", + Input: []tf.Input{ + contents, crop_window, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Elementwise computes the bitwise right-shift of `x` and `y`. +// +// Performs a logical shift for unsigned integer types, and an arithmetic shift +// for signed integer types. +// +// If `y` is negative, or greater than or equal to than the width of `x` in bits +// the result is implementation defined. +func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RightShift", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Mutually accumulates multiple tensors of identical type and shape. +func CollectiveGather(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} + opspec := tf.OpSpec{ + Type: "CollectiveGather", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr is an optional argument to QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize. +type QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr func(optionalAttr) + +// QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType sets the optional out_type attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_QUINT8 +func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataType) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value. +// +// value: List of dilation values. +// If not specified, defaults to +func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes quantized depthwise Conv2D with Bias, Relu and Requantize. +// +// Arguments: +// input: The original input tensor. +// filter: The original filter tensor. +// bias: The original bias tensor. +// min_input: The float value that the minimum quantized input value represents. +// max_input: The float value that the maximum quantized input value represents. +// min_filter: The float value that the minimum quantized filter value represents. 
+// max_filter: The float value that the maximum quantized filter value represents.
+// min_freezed_output: The minimum float value of the output tensor.
+// max_freezed_output: The maximum float value of the output tensor.
+// strides: List of stride values.
+//
+//
+// Returns The output tensor. The float value that the minimum quantized output value represents. The float value that the maximum quantized output value represents.
+func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize(scope *Scope, input tf.Output, filter tf.Output, bias tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, min_freezed_output tf.Output, max_freezed_output tf.Output, strides []int64, padding string, optional ...QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"strides": strides, "padding": padding}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize",
+		Input: []tf.Input{
+			input, filter, bias, min_input, max_input, min_filter, max_filter, min_freezed_output, max_freezed_output,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Outputs a tensor containing the reduction across all input tensors.
+//
+// Outputs a tensor containing the reduction across all input tensors passed to ops
+// within the same `shared_name`.
+//
+// The graph should be constructed so that if one op runs with shared_name value `c`,
+// then `num_devices` ops will run with shared_name value `c`. Failure to do so
+// will cause the graph execution to fail to complete.
+//
+// input: the input to the reduction
+// data: the value of the reduction across all `num_devices` devices.
+// reduction: the reduction operation to perform.
+// num_devices: The number of devices participating in this reduction.
+// shared_name: Identifier that is shared between ops of the same reduction.
+func NcclAllReduce(scope *Scope, input tf.Output, reduction string, num_devices int64, shared_name string) (data tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"reduction": reduction, "num_devices": num_devices, "shared_name": shared_name}
+	opspec := tf.OpSpec{
+		Type: "NcclAllReduce",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Reduces `input` from `num_devices` using `reduction` to a single device.
+//
+// Reduces `input` from `num_devices` using `reduction` to a single device.
+//
+// The graph should be constructed so that all inputs have a valid device
+// assignment, and the op itself is assigned one of these devices.
+//
+// input: The input to the reduction.
+// data: the value of the reduction across all `num_devices` devices.
+// reduction: the reduction operation to perform.
+func NcclReduce(scope *Scope, input []tf.Output, reduction string) (data tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"reduction": reduction}
+	opspec := tf.OpSpec{
+		Type: "NcclReduce",
+		Input: []tf.Input{
+			tf.OutputList(input),
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Ensures that the tensor's shape matches the expected shape.
+//
+// Raises an error if the input tensor's shape does not match the specified shape.
+// Returns the input tensor otherwise. +// +// Arguments: +// input: A tensor, whose shape is to be validated. +// shape: The expected (possibly partially specified) shape of the input tensor. +// +// Returns A tensor with the same shape and contents as the input tensor or value. +func EnsureShape(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape": shape} + opspec := tf.OpSpec{ + Type: "EnsureShape", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SparseToDenseAttr is an optional argument to SparseToDense. +type SparseToDenseAttr func(optionalAttr) + +// SparseToDenseValidateIndices sets the optional validate_indices attribute to value. +// +// value: If true, indices are checked to make sure they are sorted in +// lexicographic order and that there are no repeats. +// If not specified, defaults to true +func SparseToDenseValidateIndices(value bool) SparseToDenseAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Converts a sparse representation into a dense tensor. +// +// Builds an array `dense` with shape `output_shape` such that +// +// ``` +// # If sparse_indices is scalar +// dense[i] = (i == sparse_indices ? sparse_values : default_value) +// +// # If sparse_indices is a vector, then for each i +// dense[sparse_indices[i]] = sparse_values[i] +// +// # If sparse_indices is an n by d matrix, then for each i in [0, n) +// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] +// ``` +// +// All other values in `dense` are set to `default_value`. If `sparse_values` is a +// scalar, all sparse indices are set to this single value. +// +// Indices should be sorted in lexicographic order, and indices must not +// contain any repeats. If `validate_indices` is true, these properties +// are checked during execution. +// +// Arguments: +// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete +// index where `sparse_values[i]` will be placed. +// output_shape: 1-D. Shape of the dense output tensor. +// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, +// or a scalar value to be used for all sparse indices. +// default_value: Scalar value to set for indices not specified in +// `sparse_indices`. +// +// Returns Dense output tensor of shape `output_shape`. +func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseToDense", + Input: []tf.Input{ + sparse_indices, output_shape, sparse_values, default_value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adds up a SparseTensor and a dense Tensor, using these special rules: +// +// (1) Broadcasts the dense side to have the same shape as the sparse side, if +// eligible; +// (2) Then, only the dense values pointed to by the indices of the SparseTensor +// participate in the cwise addition. +// +// By these rules, the result is a logical SparseTensor with exactly the same +// indices and shape, but possibly with different non-zero values. The output of +// this Op is the resultant non-zero values. +// +// Arguments: +// sp_indices: 2-D. 
`N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. +// sp_shape: 1-D. Shape of the input SparseTensor. +// dense: `R`-D. The dense Tensor operand. +// +// Returns 1-D. The `N` values that are operated on. +func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseDenseCwiseAdd", + Input: []tf.Input{ + sp_indices, sp_values, sp_shape, dense, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatefulUniformAttr is an optional argument to StatefulUniform. +type StatefulUniformAttr func(optionalAttr) + +// StatefulUniformDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatefulUniformDtype(value tf.DataType) StatefulUniformAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// +// Arguments: +// resource: The handle of the resource variable that stores the state of the RNG. +// algorithm: The RNG algorithm. +// shape: The shape of the output tensor. +// +// Returns Random values with specified shape. +func StatefulUniform(scope *Scope, resource tf.Output, algorithm tf.Output, shape tf.Output, optional ...StatefulUniformAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatefulUniform", + Input: []tf.Input{ + resource, algorithm, shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Saves input tensors slices to disk. +// +// This is like `Save` except that tensors can be listed in the saved file as being +// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the +// larger tensor and the slice that this tensor covers. `shapes_and_slices` must +// have as many elements as `tensor_names`. +// +// Elements of the `shapes_and_slices` input must either be: +// +// * The empty string, in which case the corresponding tensor is +// saved normally. +// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the +// `dimI` are the dimensions of the larger tensor and `slice-spec` +// specifies what part is covered by the tensor to save. +// +// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` +// where each `sliceI` is either: +// +// * The string `-` meaning that the slice covers all indices of this dimension +// * `start,length` where `start` and `length` are integers. In that +// case the slice covers `length` indices starting at `start`. +// +// See also `Save`. +// +// Arguments: +// filename: Must have a single element. The name of the file to which we write the +// tensor. +// tensor_names: Shape `[N]`. The names of the tensors to be saved. +// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when +// saving the tensors. +// data: `N` tensors to save. +// +// Returns the created operation. 
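+//
+// A minimal usage sketch, assuming a Scope `s` and a tf.Output `t` holding the
+// tensor to save (the file path and tensor name are illustrative only):
+//
+//	saveOp := SaveSlices(s,
+//		Const(s.SubScope("filename"), "/tmp/model.ckpt"),
+//		Const(s.SubScope("names"), []string{"my_tensor"}),
+//		Const(s.SubScope("slices"), []string{""}), // empty string: save the whole tensor
+//		[]tf.Output{t})
+//	// saveOp can then be passed as a target to Session.Run.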
+func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SaveSlices", + Input: []tf.Input{ + filename, tensor_names, shapes_and_slices, tf.OutputList(data), + }, + } + return scope.AddOperation(opspec) +} + +// Sends `input` to all devices that are connected to the output. +// +// Sends `input` to all devices that are connected to the output. +// +// The graph should be constructed so that all ops connected to the output have a +// valid device assignment, and the op itself is assigned one of these devices. +// +// input: The input to the broadcast. +// output: The same as input. +// shape: The shape of the input tensor. +// +func NcclBroadcast(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape": shape} + opspec := tf.OpSpec{ + Type: "NcclBroadcast", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizedConv2DPerChannelAttr is an optional argument to QuantizedConv2DPerChannel. +type QuantizedConv2DPerChannelAttr func(optionalAttr) + +// QuantizedConv2DPerChannelOutType sets the optional out_type attribute to value. +// +// value: The quantized type of output tensor that needs to be converted. +// If not specified, defaults to DT_QINT32 +func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChannelAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value. +// +// value: list of dilation values. +// If not specified, defaults to +func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes QuantizedConv2D per channel. +// +// Arguments: +// input: The original input tensor. +// filter: The original filter tensor. +// min_input: The minimum value of the input tensor +// max_input: The maximum value of the input tensor. +// min_filter: The minimum value of the filter tensor. +// max_filter: The maximum value of the filter tensor. +// strides: list of stride values. +// +// +// Returns The output tensor.The minimum value of the final output tensor.The maximum value of the final output tensor. +func QuantizedConv2DPerChannel(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DPerChannelAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedConv2DPerChannel", + Input: []tf.Input{ + input, filter, min_input, max_input, min_filter, max_filter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Add all input tensors element wise. +// +// Arguments: +// inputs: Must all be the same size and shape. 
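+//
+// A minimal usage sketch, assuming a Scope `s` (the constant values are
+// illustrative only):
+//
+//	a := Const(s.SubScope("a"), []float32{1, 2, 3})
+//	b := Const(s.SubScope("b"), []float32{4, 5, 6})
+//	sum := AddN(s, []tf.Output{a, b})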
+func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AddN", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ComplexAttr is an optional argument to Complex. +type ComplexAttr func(optionalAttr) + +// ComplexTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_COMPLEX64 +func ComplexTout(value tf.DataType) ComplexAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Converts two real numbers to a complex number. +// +// Given a tensor `real` representing the real part of a complex number, and a +// tensor `imag` representing the imaginary part of a complex number, this +// operation returns complex numbers elementwise of the form \\(a + bj\\), where +// *a* represents the `real` part and *b* represents the `imag` part. +// +// The input tensors `real` and `imag` must have the same shape. +// +// For example: +// +// ``` +// # tensor 'real' is [2.25, 3.25] +// # tensor `imag` is [4.75, 5.75] +// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] +// ``` +func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Complex", + Input: []tf.Input{ + real, imag, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// BatchMatMulAttr is an optional argument to BatchMatMul. +type BatchMatMulAttr func(optionalAttr) + +// BatchMatMulAdjX sets the optional adj_x attribute to value. +// +// value: If `True`, adjoint the slices of `x`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulAdjX(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["adj_x"] = value + } +} + +// BatchMatMulAdjY sets the optional adj_y attribute to value. +// +// value: If `True`, adjoint the slices of `y`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulAdjY(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["adj_y"] = value + } +} + +// Multiplies slices of two tensors in batches. +// +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. Each of the +// individual slices can optionally be adjointed (to adjoint a matrix +// means to transpose and conjugate it) before multiplication by setting +// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if adj_x else r_x +// c_o = r_y if adj_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// +// Arguments: +// x: 2-D or higher with shape `[..., r_x, c_x]`. +// y: 2-D or higher with shape `[..., r_y, c_y]`. 
+// +// Returns 3-D or higher with shape `[..., r_o, c_o]` +func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "BatchMatMul", + Input: []tf.Input{ + x, y, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// BatchMatMulV2Attr is an optional argument to BatchMatMulV2. +type BatchMatMulV2Attr func(optionalAttr) + +// BatchMatMulV2AdjX sets the optional adj_x attribute to value. +// +// value: If `True`, adjoint the slices of `x`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulV2AdjX(value bool) BatchMatMulV2Attr { + return func(m optionalAttr) { + m["adj_x"] = value + } +} + +// BatchMatMulV2AdjY sets the optional adj_y attribute to value. +// +// value: If `True`, adjoint the slices of `y`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulV2AdjY(value bool) BatchMatMulV2Attr { + return func(m optionalAttr) { + m["adj_y"] = value + } +} + +// Multiplies slices of two tensors in batches. +// +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. Each of the +// individual slices can optionally be adjointed (to adjoint a matrix +// means to transpose and conjugate it) before multiplication by setting +// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if adj_x else r_x +// c_o = r_y if adj_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// +// *NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More +// about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). +// +// +// Arguments: +// x: 2-D or higher with shape `[..., r_x, c_x]`. +// y: 2-D or higher with shape `[..., r_y, c_y]`. +// +// Returns 3-D or higher with shape `[..., r_o, c_o]` +func BatchMatMulV2(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "BatchMatMulV2", + Input: []tf.Input{ + x, y, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the absolute value of a tensor. +// +// Given a tensor `x`, this operation returns a tensor containing the absolute +// value of each element in `x`. For example, if x is an input element and y is +// an output element, this operation computes \\(y = |x|\\). +func Abs(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Abs", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ComplexAbsAttr is an optional argument to ComplexAbs. +type ComplexAbsAttr func(optionalAttr) + +// ComplexAbsTout sets the optional Tout attribute to value. 
+// If not specified, defaults to DT_FLOAT +func ComplexAbsTout(value tf.DataType) ComplexAbsAttr { + return func(m optionalAttr) { + m["Tout"] = value + } +} + +// Computes the complex absolute value of a tensor. +// +// Given a tensor `x` of complex numbers, this operation returns a tensor of type +// `float` or `double` that is the absolute value of each element in `x`. All +// elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute +// value is computed as \\( \sqrt{a^2 + b^2}\\). +func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ComplexAbs", + Input: []tf.Input{ + x, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// BatchAttr is an optional argument to Batch. +type BatchAttr func(optionalAttr) + +// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value. +// If not specified, defaults to 10 +func BatchMaxEnqueuedBatches(value int64) BatchAttr { + return func(m optionalAttr) { + m["max_enqueued_batches"] = value + } +} + +// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value. +// If not specified, defaults to <> +func BatchAllowedBatchSizes(value []int64) BatchAttr { + return func(m optionalAttr) { + m["allowed_batch_sizes"] = value + } +} + +// BatchContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func BatchContainer(value string) BatchAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// BatchSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func BatchSharedName(value string) BatchAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// BatchBatchingQueue sets the optional batching_queue attribute to value. +// If not specified, defaults to "" +func BatchBatchingQueue(value string) BatchAttr { + return func(m optionalAttr) { + m["batching_queue"] = value + } +} + +// Batches all input tensors nondeterministically. +// +// When many instances of this Op are being run concurrently with the same +// container/shared_name in the same device, some will output zero-shaped Tensors +// and others will output Tensors of size up to max_batch_size. +// +// All Tensors in in_tensors are batched together (so, for example, labels and +// features should be batched with a single instance of this operation. +// +// Each invocation of batch emits an `id` scalar which will be used to identify +// this particular invocation when doing unbatch or its gradient. +// +// Each op which emits a non-empty batch will also emit a non-empty batch_index +// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, +// start, and length of elements of each set of Tensors present in batched_tensors. +// +// Batched tensors are concatenated along the first dimension, and all tensors in +// in_tensors must have the first dimension of the same size. +// +// in_tensors: The tensors to be batched. +// num_batch_threads: Number of scheduling threads for processing batches of work. +// Determines the number of batches processed in parallel. +// max_batch_size: Batch sizes will never be bigger than this. +// batch_timeout_micros: Maximum number of microseconds to wait before outputting +// an incomplete batch. 
+// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does +// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad +// batches up to one of those sizes. The entries must increase monotonically, and +// the final entry must equal max_batch_size. +// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. +// batched_tensors: Either empty tensors or a batch of concatenated Tensors. +// batch_index: If out_tensors is non-empty, has information to invert it. +// container: Controls the scope of sharing of this batch. +// id: always contains a scalar with a unique ID for this invocation of Batch. +// shared_name: Concurrently running instances of batch in the same device with the +// same container and shared_name will batch their elements together. If left +// empty, the op name will be used as the shared name. +// T: the types of tensors to be batched. +func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Batch", + Input: []tf.Input{ + tf.OutputList(in_tensors), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil { + scope.UpdateErr("Batch", err) + return + } + batch_index = op.Output(idx) + id = op.Output(idx) + return batched_tensors, batch_index, id +} + +// Computes numerical negative value element-wise. +// +// I.e., \\(y = -x\\). +func Neg(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Neg", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the gradient for the inverse of `x` wrt its input. +// +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReciprocalGrad", + Input: []tf.Input{ + y, dy, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the item in the list with the given index. +// +// input_handle: the list +// index: the position in the list from which an element will be retrieved +// item: the element at that position +// +// +func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, element_shape tf.Output, element_dtype tf.DataType) (item tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + opspec := tf.OpSpec{ + Type: "TensorListGetItem", + Input: []tf.Input{ + input_handle, index, element_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the gradient for the sqrt of `x` wrt its input. +// +// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` +// is the corresponding input gradient. 
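+//
+// For example, if `y = sqrt(x) = [2, 3]` and the upstream gradient is
+// `dy = [1, 1]`, the result is `dy * 0.5 / y`, i.e. approximately `[0.25, 0.1667]`.
+// A minimal usage sketch within this package (illustrative only, not generated
+// from the op registry; it assumes the `Const` helper defined elsewhere in this
+// file):
+//
+// ```go
+// s := NewScope()
+// y := Const(s, []float32{2, 3})   // y = sqrt(x)
+// dy := Const(s, []float32{1, 1})  // upstream gradient
+// grad := SqrtGrad(s, y, dy)       // evaluates to ~[0.25, 0.1667] when the graph is run
+// ```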
+func SqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SqrtGrad",
+		Input: []tf.Input{
+			y, dy,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes reciprocal of square root of x element-wise.
+//
+// I.e., \\(y = 1 / \sqrt{x}\\).
+func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Rsqrt",
+		Input: []tf.Input{
+			x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes the gradient for the rsqrt of `x` wrt its input.
+//
+// Specifically, `grad = dy * -0.5 * y^3`, where `y = rsqrt(x)`, and `dy`
+// is the corresponding input gradient.
+func RsqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "RsqrtGrad",
+		Input: []tf.Input{
+			y, dy,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes exponential of x element-wise. \\(y = e^x\\).
+func Exp(scope *Scope, x tf.Output) (y tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Exp",
+		Input: []tf.Input{
+			x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes exponential of x - 1 element-wise.
+//
+// I.e., \\(y = (\exp x) - 1\\).
+func Expm1(scope *Scope, x tf.Output) (y tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Expm1",
+		Input: []tf.Input{
+			x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes the Cholesky decomposition of one or more square matrices.
+//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices.
+//
+// The input has to be symmetric and positive definite. Only the lower-triangular
+// part of the input will be used for this operation. The upper-triangular part
+// will not be read.
+//
+// The output is a tensor of the same shape as the input
+// containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
+//
+// **Note**: The gradient computation on GPU is faster for large matrices but
+// not for large batch dimensions when the submatrices are small. In this
+// case it might be faster to use the CPU.
+//
+// Arguments:
+// input: Shape is `[..., M, M]`.
+//
+// Returns Shape is `[..., M, M]`.
+func Cholesky(scope *Scope, input tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Cholesky",
+		Input: []tf.Input{
+			input,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// LeakyReluGradAttr is an optional argument to LeakyReluGrad.
+type LeakyReluGradAttr func(optionalAttr)
+
+// LeakyReluGradAlpha sets the optional alpha attribute to value.
+// If not specified, defaults to 0.2
+func LeakyReluGradAlpha(value float32) LeakyReluGradAttr {
+	return func(m optionalAttr) {
+		m["alpha"] = value
+	}
+}
+
+// Computes rectified linear gradients for a LeakyRelu operation.
+//
+// Arguments:
+// gradients: The backpropagated gradients to the corresponding LeakyRelu operation.
+// features: The features passed as input to the corresponding LeakyRelu operation,
+// OR the outputs of that operation (both work equivalently).
+//
+// Returns `gradients * (features > 0) + alpha * gradients * (features <= 0)`.
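+//
+// For example, with `features = [-2, 3]`, `gradients = [1, 1]`, and
+// `alpha = 0.2`, the result is `[0.2, 1]`. A minimal usage sketch showing how
+// the optional attribute setter above is passed (illustrative only; it assumes
+// the `Const` helper defined elsewhere in this file):
+//
+// ```go
+// s := NewScope()
+// grads := Const(s, []float32{1, 1})
+// feats := Const(s, []float32{-2, 3})
+// backprops := LeakyReluGrad(s, grads, feats, LeakyReluGradAlpha(0.2)) // ~[0.2, 1]
+// ```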
+func LeakyReluGrad(scope *Scope, gradients tf.Output, features tf.Output, optional ...LeakyReluGradAttr) (backprops tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LeakyReluGrad", + Input: []tf.Input{ + gradients, features, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes hyperbolic cosine of x element-wise. +func Cosh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cosh", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes hyperbolic tangent of `x` element-wise. +func Tanh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Tanh", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns element-wise integer closest to x. +// +// If the result is midway between two representable values, +// the even representable is chosen. +// For example: +// +// ``` +// rint(-1.5) ==> -2.0 +// rint(0.5000001) ==> 1.0 +// rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] +// ``` +func Rint(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Rint", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes inverse hyperbolic cosine of x element-wise. +func Acosh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Acosh", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize. +type CudnnRNNParamsSizeAttr func(optionalAttr) + +// CudnnRNNParamsSizeRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNParamsSizeRnnMode(value string) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["rnn_mode"] = value + } +} + +// CudnnRNNParamsSizeInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNParamsSizeInputMode(value string) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNParamsSizeDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNParamsSizeDirection(value string) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNParamsSizeDropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsSizeDropout(value float32) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["dropout"] = value + } +} + +// CudnnRNNParamsSizeSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsSizeSeed(value int64) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// CudnnRNNParamsSizeSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsSizeSeed2(value int64) CudnnRNNParamsSizeAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Computes size of weights that can be used by a Cudnn RNN model. 
+// +// Return the params size that can be used by the Cudnn RNN model. Subsequent +// weight allocation and initialization should use this size. +// +// num_layers: Specifies the number of layers in the RNN model. +// num_units: Specifies the size of the hidden state. +// input_size: Specifies the size of the input state. +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// The actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. +// dir = (direction == bidirectional) ? 2 : 1 +// dropout: dropout probability. When set to 0., dropout is disabled. +// seed: the 1st part of a seed to initialize dropout. +// seed2: the 2nd part of a seed to initialize dropout. +// params_size: The size of the params buffer that should be allocated and +// initialized for this RNN model. Note that this params buffer may not be +// compatible across GPUs. Please use CudnnRNNParamsWeights and +// CudnnRNNParamsBiases to save and restore them in a way that is compatible +// across different runs. +func CudnnRNNParamsSize(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, T tf.DataType, S tf.DataType, optional ...CudnnRNNParamsSizeAttr) (params_size tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T, "S": S} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CudnnRNNParamsSize", + Input: []tf.Input{ + num_layers, num_units, input_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the log of the absolute value of `Gamma(x)` element-wise. +func Lgamma(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Lgamma", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. +type QueueDequeueManyV2Attr func(optionalAttr) + +// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue has fewer than n elements, this operation +// will block for up to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Dequeues `n` tuples of one or more tensors from the given queue. +// +// If the queue is closed and there are fewer than `n` elements, then an +// OutOfRange error is returned. +// +// This operation concatenates queue-element component tensors along the +// 0th dimension to make a single component tensor. All of the components +// in the dequeued tuple will have size `n` in the 0th dimension. +// +// This operation has `k` outputs, where `k` is the number of components in +// the tuples stored in the given queue, and output `i` is the ith +// component of the dequeued tuple. +// +// N.B. If the queue is empty, this operation will block until `n` elements +// have been dequeued (or 'timeout_ms' elapses, if specified). +// +// Arguments: +// handle: The handle to a queue. +// n: The number of tuples to dequeue. 
+// component_types: The type of each component in a tuple.
+//
+// Returns One or more tensors that were dequeued as a tuple.
+func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"component_types": component_types}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "QueueDequeueManyV2",
+		Input: []tf.Input{
+			handle, n,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
+	var idx int
+	var err error
+	if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+		scope.UpdateErr("QueueDequeueManyV2", err)
+		return
+	}
+	return components
+}
+
+// Computes Psi, the derivative of Lgamma (the log of the absolute value of
+//
+// `Gamma(x)`), element-wise.
+func Digamma(scope *Scope, x tf.Output) (y tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Digamma",
+		Input: []tf.Input{
+			x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Adjust the saturation of one or more images.
+//
+// `images` is a tensor of at least 3 dimensions. The last dimension is
+// interpreted as channels, and must be three.
+//
+// The input image is considered in the RGB colorspace. Conceptually, the RGB
+// colors are first mapped into HSV. A scale is then applied to all the saturation
+// values, and the result is then remapped back to RGB colorspace.
+//
+// Arguments:
+// images: Images to adjust. At least 3-D.
+// scale: A float scale to add to the saturation.
+//
+// Returns The saturation-adjusted image or images.
+func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "AdjustSaturation",
+		Input: []tf.Input{
+			images, scale,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes the Gauss error function of `x` element-wise.
+func Erf(scope *Scope, x tf.Output) (y tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Erf",
+		Input: []tf.Input{
+			x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Greedily selects a subset of bounding boxes in descending order of score,
+//
+// pruning away boxes that have high intersection-over-union (IOU) overlap
+// with previously selected boxes. Bounding boxes are supplied as
+// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
+// diagonal pair of box corners and the coordinates can be provided as normalized
+// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
+// is agnostic to where the origin is in the coordinate system. Note that this
+// algorithm is invariant to orthogonal transformations and translations
+// of the coordinate system; thus translating or reflections of the coordinate
+// system result in the same boxes being selected by the algorithm.
+//
+// The output of this operation is a set of integers indexing into the input
+// collection of bounding boxes representing the selected boxes. The bounding
+// box coordinates corresponding to the selected indices can then be obtained
+// using the `tf.gather` operation. For example:
+//
+//   selected_indices = tf.image.non_max_suppression_v2(
+//       boxes, scores, max_output_size, iou_threshold)
+//   selected_boxes = tf.gather(boxes, selected_indices)
+//
+// Arguments:
+// boxes: A 2-D float tensor of shape `[num_boxes, 4]`.
+// scores: A 1-D float tensor of shape `[num_boxes]` representing a single
+// score corresponding to each box (each row of boxes).
+// max_output_size: A scalar integer tensor representing the maximum number of
+// boxes to be selected by non max suppression.
+// iou_threshold: A 0-D float tensor representing the threshold for deciding whether
+// boxes overlap too much with respect to IOU.
+//
+// Returns A 1-D integer tensor of shape `[M]` representing the selected
+// indices from the boxes tensor, where `M <= max_output_size`.
+func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output) (selected_indices tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "NonMaxSuppressionV2",
+		Input: []tf.Input{
+			boxes, scores, max_output_size, iou_threshold,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Greedily selects a subset of bounding boxes in descending order of score,
+//
+// pruning away boxes that have high intersection-over-union (IOU) overlap
+// with previously selected boxes. Bounding boxes with score less than
+// `score_threshold` are removed. Bounding boxes are supplied as
+// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
+// diagonal pair of box corners and the coordinates can be provided as normalized
+// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
+// is agnostic to where the origin is in the coordinate system and more
+// generally is invariant to orthogonal transformations and translations
+// of the coordinate system; thus translating or reflections of the coordinate
+// system result in the same boxes being selected by the algorithm.
+// The output of this operation is a set of integers indexing into the input
+// collection of bounding boxes representing the selected boxes. The bounding
+// box coordinates corresponding to the selected indices can then be obtained
+// using the `tf.gather` operation. For example:
+//   selected_indices = tf.image.non_max_suppression_v2(
+//       boxes, scores, max_output_size, iou_threshold, score_threshold)
+//   selected_boxes = tf.gather(boxes, selected_indices)
+//
+// Arguments:
+// boxes: A 2-D float tensor of shape `[num_boxes, 4]`.
+// scores: A 1-D float tensor of shape `[num_boxes]` representing a single
+// score corresponding to each box (each row of boxes).
+// max_output_size: A scalar integer tensor representing the maximum number of
+// boxes to be selected by non max suppression.
+// iou_threshold: A 0-D float tensor representing the threshold for deciding whether
+// boxes overlap too much with respect to IOU.
+// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove
+// boxes based on score.
+//
+// Returns A 1-D integer tensor of shape `[M]` representing the selected
+// indices from the boxes tensor, where `M <= max_output_size`.
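+//
+// A minimal Go sketch of the call from this package (illustrative only; it
+// assumes the `Const` helper defined elsewhere in this file and uses small
+// hand-written boxes and scores):
+//
+// ```go
+// s := NewScope()
+// boxes := Const(s, [][]float32{{0, 0, 1, 1}, {0, 0, 0.9, 0.9}})
+// scores := Const(s, []float32{0.9, 0.8})
+// selected := NonMaxSuppressionV3(s, boxes, scores,
+// 	Const(s, int32(10)),      // max_output_size
+// 	Const(s, float32(0.5)),   // iou_threshold
+// 	Const(s, float32(0.1)))   // score_threshold
+// ```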
+func NonMaxSuppressionV3(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NonMaxSuppressionV3", + Input: []tf.Input{ + boxes, scores, max_output_size, iou_threshold, score_threshold, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Enqueue a Tensor on the computation outfeed. +// +// Arguments: +// input: A tensor that will be inserted into the outfeed queue. +// +// Returns the created operation. +func OutfeedEnqueue(scope *Scope, input tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "OutfeedEnqueue", + Input: []tf.Input{ + input, + }, + } + return scope.AddOperation(opspec) +} + +// Computes the gradient of the sigmoid of `x` wrt its input. +// +// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and +// `dy` is the corresponding input gradient. +func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SigmoidGrad", + Input: []tf.Input{ + y, dy, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// An Op to permute tensors across replicated TPU instances. +// +// Each instance supplies its own input. +// +// For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +// source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: +// `[D, A, B, C]`. +// +// Arguments: +// input: The local input to be permuted. Currently only supports float and +// bfloat16. +// source_target_pairs: A tensor with shape [num_pairs, 2]. +// +// Returns The permuted input. +func CollectivePermute(scope *Scope, input tf.Output, source_target_pairs tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CollectivePermute", + Input: []tf.Input{ + input, source_target_pairs, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Splits a tensor into a list. +// +// list[i] corresponds to lengths[i] tensors from the input tensor. +// The tensor must have rank at least 1 and contain exactly sum(lengths) elements. +// +// tensor: The input tensor. +// element_shape: A shape compatible with that of elements in the tensor. +// lengths: Vector of sizes of the 0th dimension of tensors in the list. +// output_handle: The list. +func TensorListSplit(scope *Scope, tensor tf.Output, element_shape tf.Output, lengths tf.Output) (output_handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorListSplit", + Input: []tf.Input{ + tensor, element_shape, lengths, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes cos of x element-wise. +func Cos(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cos", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes tan of x element-wise. +func Tan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Tan", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SumAttr is an optional argument to Sum. +type SumAttr func(optionalAttr) + +// SumKeepDims sets the optional keep_dims attribute to value. 
+// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SumKeepDims(value bool) SumAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the sum of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Sum", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// EnqueueTPUEmbeddingSparseTensorBatchAttr is an optional argument to EnqueueTPUEmbeddingSparseTensorBatch. +type EnqueueTPUEmbeddingSparseTensorBatchAttr func(optionalAttr) + +// EnqueueTPUEmbeddingSparseTensorBatchDeviceOrdinal sets the optional device_ordinal attribute to value. +// +// value: The TPU device to use. Should be >= 0 and less than the number +// of TPU cores in the task on which the node is placed. +// If not specified, defaults to -1 +func EnqueueTPUEmbeddingSparseTensorBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingSparseTensorBatchAttr { + return func(m optionalAttr) { + m["device_ordinal"] = value + } +} + +// EnqueueTPUEmbeddingSparseTensorBatchCombiners sets the optional combiners attribute to value. +// +// value: A list of string scalars, one for each embedding table that specify +// how to normalize the embedding activations after weighted summation. +// Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have +// the sum of the weights be 0 for 'mean' or the sum of the squared weights be +// 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for +// all tables. +// If not specified, defaults to <> +func EnqueueTPUEmbeddingSparseTensorBatchCombiners(value []string) EnqueueTPUEmbeddingSparseTensorBatchAttr { + return func(m optionalAttr) { + m["combiners"] = value + } +} + +// EnqueueTPUEmbeddingSparseTensorBatchMaxSequenceLengths sets the optional max_sequence_lengths attribute to value. +// If not specified, defaults to <> +func EnqueueTPUEmbeddingSparseTensorBatchMaxSequenceLengths(value []int64) EnqueueTPUEmbeddingSparseTensorBatchAttr { + return func(m optionalAttr) { + m["max_sequence_lengths"] = value + } +} + +// Eases the porting of code that uses tf.nn.embedding_lookup_sparse(). +// +// sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond +// to the ith feature. table_ids[i] indicates which embedding table to look up ith +// feature. +// +// The tensors at corresponding positions in the three input lists (sample_indices, +// embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1 +// with dim_size() equal to the total number of lookups into the table described by +// the corresponding feature. +// +// Arguments: +// sample_indices: A list of rank 1 Tensors specifying the training example to +// which the corresponding embedding_indices and aggregation_weights values +// belong. 
It corresponds to sp_ids.indices[:,0] in embedding_lookup_sparse(). +// embedding_indices: A list of rank 1 Tensors, indices into the embedding tables. +// It corresponds to sp_ids.values in embedding_lookup_sparse(). +// aggregation_weights: A list of rank 1 Tensors containing per training example +// aggregation weights. It corresponds to sp_weights.values in +// embedding_lookup_sparse(). +// mode_override: A string input that overrides the mode specified in the +// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', +// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set +// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. +// table_ids: A list of integers specifying the identifier of the embedding table +// (offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the +// corresponding input. The ith input is looked up using table_ids[i]. The size +// of the table_ids list must be equal to that of sample_indices, +// embedding_indices and aggregation_weights. +// +// Returns the created operation. +func EnqueueTPUEmbeddingSparseTensorBatch(scope *Scope, sample_indices []tf.Output, embedding_indices []tf.Output, aggregation_weights []tf.Output, mode_override tf.Output, table_ids []int64, optional ...EnqueueTPUEmbeddingSparseTensorBatchAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"table_ids": table_ids} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EnqueueTPUEmbeddingSparseTensorBatch", + Input: []tf.Input{ + tf.OutputList(sample_indices), tf.OutputList(embedding_indices), tf.OutputList(aggregation_weights), mode_override, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Computes the trignometric inverse tangent of x element-wise. +// +// The `tf.math.atan` operation returns the inverse of `tf.math.tan`, such that +// if `y = tf.math.tan(x)` then, `x = tf.math.atan(y)`. +// +// **Note**: The output of `tf.math.atan` will lie within the invertible range +// of tan, i.e (-pi/2, pi/2). +// +// For example: +// +// ```python +// # Note: [1.047, 0.785] ~= [(pi/3), (pi/4)] +// x = tf.constant([1.047, 0.785]) +// y = tf.math.tan(x) # [1.731261, 0.99920404] +// +// tf.math.atan(y) # [1.047, 0.785] = x +// ``` +// +func Atan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Atan", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAddSignAttr is an optional argument to ResourceApplyAddSign. +type ResourceApplyAddSignAttr func(optionalAttr) + +// ResourceApplyAddSignUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and m tensors is +// protected by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAddSignUseLocking(value bool) ResourceApplyAddSignAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the AddSign update. +// +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// update <- (alpha + sign_decay * sign(g) *sign(m)) * g +// variable <- variable - lr_t * update +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// alpha: Must be a scalar. +// sign_decay: Must be a scalar. +// beta: Must be a scalar. 
+// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, alpha tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyAddSignAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAddSign", + Input: []tf.Input{ + var_, m, lr, alpha, sign_decay, beta, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// ResizeAreaAttr is an optional argument to ResizeArea. +type ResizeAreaAttr func(optionalAttr) + +// ResizeAreaAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeAreaAlignCorners(value bool) ResizeAreaAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Resize `images` to `size` using area interpolation. +// +// Input images can be of different types but output images are always float. +// +// The range of pixel values for the output image might be slightly different +// from the range for the input image because of limited numerical precision. +// To guarantee an output range, for example `[0.0, 1.0]`, apply +// `tf.clip_by_value` to the output. +// +// Each output pixel is computed by first transforming the pixel's footprint into +// the input tensor and then averaging the pixels that intersect the footprint. An +// input pixel's contribution to the average is weighted by the fraction of its +// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeArea", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the Bessel i1e function of `x` element-wise. +// +// Exponentially scaled modified Bessel function of order 0 defined as +// `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. +// +// This function is faster and numerically stabler than `bessel_i1(x)`. +func BesselI1e(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BesselI1e", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// VariableShapeAttr is an optional argument to VariableShape. +type VariableShapeAttr func(optionalAttr) + +// VariableShapeOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func VariableShapeOutType(value tf.DataType) VariableShapeAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Returns the shape of the variable pointed to by `resource`. +// +// This operation returns a 1-D integer tensor representing the shape of `input`. 
+// +// For example: +// +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// shape(t) ==> [2, 2, 3] +// ``` +func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "VariableShape", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns which elements of x are NaN. +// +// @compatibility(numpy) +// Equivalent to np.isnan +// @end_compatibility +func IsNan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsNan", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TakeManySparseFromTensorsMapAttr is an optional argument to TakeManySparseFromTensorsMap. +type TakeManySparseFromTensorsMapAttr func(optionalAttr) + +// TakeManySparseFromTensorsMapContainer sets the optional container attribute to value. +// +// value: The container name for the `SparseTensorsMap` read by this op. +// If not specified, defaults to "" +func TakeManySparseFromTensorsMapContainer(value string) TakeManySparseFromTensorsMapAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// TakeManySparseFromTensorsMapSharedName sets the optional shared_name attribute to value. +// +// value: The shared name for the `SparseTensorsMap` read by this op. +// It should not be blank; rather the `shared_name` or unique Operation name +// of the Op that created the original `SparseTensorsMap` should be used. +// If not specified, defaults to "" +func TakeManySparseFromTensorsMapSharedName(value string) TakeManySparseFromTensorsMapAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. +// +// The input `sparse_handles` must be an `int64` matrix of shape `[N, 1]` where +// `N` is the minibatch size and the rows correspond to the output handles of +// `AddSparseToTensorsMap` or `AddManySparseToTensorsMap`. The ranks of the +// original `SparseTensor` objects that went into the given input ops must all +// match. When the final `SparseTensor` is created, it has rank one +// higher than the ranks of the incoming `SparseTensor` objects +// (they have been concatenated along a new row dimension on the left). +// +// The output `SparseTensor` object's shape values for all dimensions but the +// first are the max across the input `SparseTensor` objects' shape values +// for the corresponding dimensions. Its first shape value is `N`, the minibatch +// size. +// +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. +// +// For example, if the handles represent an input, which is a `[2, 3]` matrix +// representing two original `SparseTensor` objects: +// +// ``` +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// ``` +// +// and +// +// ``` +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// ``` +// +// then the final `SparseTensor` will be: +// +// ``` +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// ``` +// +// Arguments: +// sparse_handles: 1-D, The `N` serialized `SparseTensor` objects. 
+// Shape: `[N]`. +// dtype: The `dtype` of the `SparseTensor` objects stored in the +// `SparseTensorsMap`. +// +// Returns 2-D. The `indices` of the minibatch `SparseTensor`.1-D. The `values` of the minibatch `SparseTensor`.1-D. The `shape` of the minibatch `SparseTensor`. +func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype tf.DataType, optional ...TakeManySparseFromTensorsMapAttr) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TakeManySparseFromTensorsMap", + Input: []tf.Input{ + sparse_handles, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Returns which elements of x are finite. +// +// @compatibility(numpy) +// Equivalent to np.isfinite +// @end_compatibility +func IsFinite(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsFinite", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns an element-wise indication of the sign of a number. +// +// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`. +// +// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. +func Sign(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sign", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns element-wise largest integer not greater than x. +func Floor(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Floor", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SpaceToBatch for 4-D tensors of type T. +// +// This is a legacy version of the more general SpaceToBatchND. +// +// Zero-pads and then rearranges (permutes) blocks of spatial data into batch. +// More specifically, this op outputs a copy of the input tensor where values from +// the `height` and `width` dimensions are moved to the `batch` dimension. After +// the zero-padding, both `height` and `width` of the input must be divisible by the +// block size. +// +// Arguments: +// input: 4-D with shape `[batch, height, width, depth]`. +// paddings: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies +// the padding of the input with zeros across the spatial dimensions as follows: +// +// paddings = [[pad_top, pad_bottom], [pad_left, pad_right]] +// +// The effective spatial dimensions of the zero-padded input tensor will be: +// +// height_pad = pad_top + height + pad_bottom +// width_pad = pad_left + width + pad_right +// +// The attr `block_size` must be greater than one. It indicates the block size. +// +// * Non-overlapping blocks of size `block_size x block size` in the height and +// width dimensions are rearranged into the batch dimension at each location. +// * The batch of the output tensor is `batch * block_size * block_size`. +// * Both height_pad and width_pad must be divisible by block_size. 
+// +// The shape of the output will be: +// +// [batch*block_size*block_size, height_pad/block_size, width_pad/block_size, +// depth] +// +// Some examples: +// +// (1) For the following input of shape `[1, 2, 2, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// The output tensor has shape `[4, 1, 1, 1]` and value: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// (2) For the following input of shape `[1, 2, 2, 3]` and block_size of 2: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// The output tensor has shape `[4, 1, 1, 3]` and value: +// +// ``` +// [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] +// ``` +// +// (3) For the following input of shape `[1, 4, 4, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +// +// The output tensor has shape `[4, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// (4) For the following input of shape `[2, 2, 4, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]]], +// [[[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +// +// The output tensor has shape `[8, 1, 2, 1]` and value: +// +// ``` +// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], +// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] +// ``` +// +// Among others, this operation is useful for reducing atrous convolution into +// regular convolution. +// +func SpaceToBatch(scope *Scope, input tf.Output, paddings tf.Output, block_size int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"block_size": block_size} + opspec := tf.OpSpec{ + Type: "SpaceToBatch", + Input: []tf.Input{ + input, paddings, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// The shape of the elements of the given list, as a tensor. +// +// input_handle: the list +// element_shape: the shape of elements of the list +func TensorListElementShape(scope *Scope, input_handle tf.Output, shape_type tf.DataType) (element_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape_type": shape_type} + opspec := tf.OpSpec{ + Type: "TensorListElementShape", + Input: []tf.Input{ + input_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns element-wise smallest integer not less than x. +func Ceil(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Ceil", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MinAttr is an optional argument to Min. +type MinAttr func(optionalAttr) + +// MinKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func MinKeepDims(value bool) MinAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the minimum of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. 
If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Min", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x + y element-wise. +// +// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AddV2", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x * y element-wise. +// +// *NOTE*: `Multiply` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Mul(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Mul", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Concatenates a list of `N` tensors along the first dimension. +// +// The input tensors are all required to have size 1 in the first dimension. +// +// For example: +// +// ``` +// # 'x' is [[1, 4]] +// # 'y' is [[2, 5]] +// # 'z' is [[3, 6]] +// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +// ``` +// +// The difference between concat and parallel_concat is that concat requires all +// of the inputs be computed before the operation will begin but doesn't require +// that the input shapes be known during graph construction. Parallel concat +// will copy pieces of the input into the output as they become available, in +// some situations this can provide a performance benefit. +// +// Arguments: +// values: Tensors to be concatenated. All must have size 1 in the first dimension +// and same shape. +// shape: the final shape of the result; should be equal to the shapes of any input +// but with the number of input values in the first dimension. +// +// Returns The concatenated tensor. +func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape": shape} + opspec := tf.OpSpec{ + Type: "ParallelConcat", + Input: []tf.Input{ + tf.OutputList(values), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adjust the contrast of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are +// interpreted as `[height, width, channels]`. The other dimensions only +// represent a collection of images, such as `[batch, height, width, channels].` +// +// Contrast is adjusted independently for each channel of each image. +// +// For each channel, the Op first computes the mean of the image pixels in the +// channel and then adjusts each component of each pixel to +// `(x - mean) * contrast_factor + mean`. +// +// Arguments: +// images: Images to adjust. At least 3-D. 
+// contrast_factor: A float multiplier for adjusting contrast.
+//
+// Returns The contrast-adjusted image or images.
+func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "AdjustContrastv2",
+		Input: []tf.Input{
+			images, contrast_factor,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Returns x * y element-wise. Returns zero if y is zero, even if x is infinite or NaN.
+//
+// *NOTE*: `MulNoNan` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func MulNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "MulNoNan",
+		Input: []tf.Input{
+			x, y,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// AddSparseToTensorsMapAttr is an optional argument to AddSparseToTensorsMap.
+type AddSparseToTensorsMapAttr func(optionalAttr)
+
+// AddSparseToTensorsMapContainer sets the optional container attribute to value.
+//
+// value: The container name for the `SparseTensorsMap` created by this op.
+// If not specified, defaults to ""
+func AddSparseToTensorsMapContainer(value string) AddSparseToTensorsMapAttr {
+	return func(m optionalAttr) {
+		m["container"] = value
+	}
+}
+
+// AddSparseToTensorsMapSharedName sets the optional shared_name attribute to value.
+//
+// value: The shared name for the `SparseTensorsMap` created by this op.
+// If blank, the new Operation's unique name is used.
+// If not specified, defaults to ""
+func AddSparseToTensorsMapSharedName(value string) AddSparseToTensorsMapAttr {
+	return func(m optionalAttr) {
+		m["shared_name"] = value
+	}
+}
+
+// Add a `SparseTensor` to a `SparseTensorsMap` and return its handle.
+//
+// A `SparseTensor` is represented by three tensors: `sparse_indices`,
+// `sparse_values`, and `sparse_shape`.
+//
+// This operator takes the given `SparseTensor` and adds it to a container
+// object (a `SparseTensorsMap`). A unique key within this container is generated
+// in the form of an `int64`, and this is the value that is returned.
+//
+// The `SparseTensor` can then be read out as part of a minibatch by passing
+// the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure
+// the correct `SparseTensorsMap` is accessed, ensure that the same
+// `container` and `shared_name` are passed to that Op. If no `shared_name`
+// is provided here, instead use the *name* of the Operation created by calling
+// `AddSparseToTensorsMap` as the `shared_name` passed to
+// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated.
+//
+// Arguments:
+// sparse_indices: 2-D. The `indices` of the `SparseTensor`.
+// sparse_values: 1-D. The `values` of the `SparseTensor`.
+// sparse_shape: 1-D. The `shape` of the `SparseTensor`.
+//
+// Returns 0-D. The handle of the `SparseTensor` now stored in the
+// `SparseTensorsMap`.
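+//
+// A minimal sketch of the shared_name pairing described above (illustrative
+// only; `indices`, `values`, `shape`, and `handles` are assumed to be
+// tf.Output values built elsewhere, and the stored values are assumed to be
+// float32):
+//
+// ```go
+// s := NewScope()
+// handle := AddSparseToTensorsMap(s, indices, values, shape,
+// 	AddSparseToTensorsMapSharedName("sparse_map"))
+// // Later, read entries back as one minibatch; `handles` is a 1-D int64
+// // vector of keys previously returned by this op.
+// ix, vals, shp := TakeManySparseFromTensorsMap(s, handles, tf.Float,
+// 	TakeManySparseFromTensorsMapSharedName("sparse_map"))
+// ```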
+func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddSparseToTensorsMapAttr) (sparse_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AddSparseToTensorsMap", + Input: []tf.Input{ + sparse_indices, sparse_values, sparse_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes a range that covers the actual values present in a quantized tensor. +// +// Given a quantized tensor described by `(input, input_min, input_max)`, outputs a +// range that covers the actual values present in that tensor. This op is typically +// used to produce the `requested_output_min` and `requested_output_max` for +// `Requantize`. +// +// Arguments: +// +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// +// Returns The computed min output.the computed max output. +func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RequantizationRange", + Input: []tf.Input{ + input, input_min, input_max, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Returns x / y element-wise. +// +// *NOTE*: `Div` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Div(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Div", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x / y element-wise for real types. +// +// If `x` and `y` are reals, this will return the floating-point division. +// +// *NOTE*: `Div` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RealDiv", + Input: []tf.Input{ + x, y, }, } op := scope.AddOperation(opspec) @@ -17021,94 +31064,26 @@ func QuantizeAndDequantizeV2(scope *Scope, input tf.Output, input_min tf.Output, return op.Output(0) } -// A TPU core selector Op. -// -// This Op produces a set of TPU cores (for warm-up) or a single TPU core -// (for regular inference) to execute the TPU program on. The output is -// consumed by TPUPartitionedCall. -// -// Returns A vector 1 or more TPU cores. -func TPUOrdinalSelector(scope *Scope) (device_ordinals tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TPUOrdinalSelector", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// StringToNumberAttr is an optional argument to StringToNumber. +type StringToNumberAttr func(optionalAttr) -// Looks up keys in a table, outputs the corresponding values. +// StringToNumberOutType sets the optional out_type attribute to value. // -// The tensor `keys` must of the same type as the keys of the table. -// The output `values` is of the type of the table values. -// -// The scalar `default_value` is the value output for keys not present in the -// table. It must also be of the same type as the table values. 
-// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// -// -// Returns Same shape as `keys`. Values found in the table, or `default_values` -// for missing keys. -func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableFindV2", - Input: []tf.Input{ - table_handle, keys, default_value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. -type ResourceSparseApplyRMSPropAttr func(optionalAttr) - -// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr { +// value: The numeric type to interpret each string in `string_tensor` as. +// If not specified, defaults to DT_FLOAT +func StringToNumberOutType(value tf.DataType) StringToNumberAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["out_type"] = value } } -// Update '*var' according to the RMSProp algorithm. +// Converts each string in the input Tensor to the specified numeric type. // -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms -// and mom will not update in iterations during which the grad is zero. +// (Note that int32 overflow results in an error while float overflow +// results in a rounded value.) // -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) -// -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom -// -// Arguments: -// var_: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. -// -// Returns the created operation. -func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) { +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -17117,15 +31092,579 @@ func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyRMSProp", + Type: "StringToNumber", Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, indices, + string_tensor, }, Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns 0 if x == 0, and x * log(y) otherwise, elementwise. 
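+//
+// A small illustrative sketch (same `tf`/`op` imports as the
+// AddSparseToTensorsMap example above; the values are arbitrary):
+//
+// ```
+// func xlogyExample() tf.Output {
+//     s := op.NewScope()
+//     x := op.Const(s, []float32{0, 2})
+//     y := op.Const(s, []float32{0, 3})
+//     // Element 0 is 0 rather than 0*log(0) = NaN; element 1 is 2*log(3).
+//     return op.Xlogy(s, x, y)
+// }
+// ```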
+func Xlogy(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Xlogy", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the min of x and y (i.e. x < y ? x : y) element-wise. +// +// *NOTE*: `Minimum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Minimum", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the power of one value to another. +// +// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for +// corresponding elements in `x` and `y`. For example: +// +// ``` +// # tensor 'x' is [[2, 2]], [3, 3]] +// # tensor 'y' is [[8, 16], [2, 3]] +// tf.pow(x, y) ==> [[256, 65536], [9, 27]] +// ``` +func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Pow", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compute the upper regularized incomplete Gamma function `Q(a, x)`. +// +// The upper regularized incomplete Gamma function is defined as: +// +// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) +// +// where +// +// \\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) +// +// is the upper incomplete Gama function. +// +// Note, above `P(a, x)` (`Igamma`) is the lower regularized complete +// Gamma function. +func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Igammac", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reshapes a SparseTensor to represent values in a new dense shape. +// +// This operation has the same semantics as reshape on the represented dense +// tensor. The `input_indices` are recomputed based on the requested `new_shape`. +// +// If one component of `new_shape` is the special value -1, the size of that +// dimension is computed so that the total dense size remains constant. At +// most one component of `new_shape` can be -1. The number of dense elements +// implied by `new_shape` must be the same as the number of dense elements +// originally implied by `input_shape`. +// +// Reshaping does not affect the order of values in the SparseTensor. +// +// If the input tensor has rank `R_in` and `N` non-empty values, and `new_shape` +// has length `R_out`, then `input_indices` has shape `[N, R_in]`, +// `input_shape` has length `R_in`, `output_indices` has shape `[N, R_out]`, and +// `output_shape` has length `R_out`. +// +// Arguments: +// input_indices: 2-D. `N x R_in` matrix with the indices of non-empty values in a +// SparseTensor. +// input_shape: 1-D. `R_in` vector with the input SparseTensor's dense shape. +// new_shape: 1-D. `R_out` vector with the requested new dense shape. +// +// Returns 2-D. `N x R_out` matrix with the updated indices of non-empty +// values in the output SparseTensor.1-D. `R_out` vector with the full dense shape of the output +// SparseTensor. This is the same as `new_shape` but with any -1 dimensions +// filled in. 
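+//
+// A small worked sketch (same `tf`/`op` imports as the AddSparseToTensorsMap
+// example above; shapes and values are illustrative):
+//
+// ```
+// func sparseReshapeExample() (tf.Output, tf.Output) {
+//     s := op.NewScope()
+//     // Two non-zero values in a dense shape of [3, 4] (12 elements).
+//     indices := op.Const(s, [][]int64{{0, 1}, {1, 2}})
+//     shape := op.Const(s, []int64{3, 4})
+//     // The -1 is inferred: [2, -1] becomes [2, 6].
+//     newShape := op.Const(s, []int64{2, -1})
+//     // Linear positions 1 and 6 map to [0, 1] and [1, 0] in the new shape.
+//     return op.SparseReshape(s, indices, shape, newShape)
+// }
+// ```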
+func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, new_shape tf.Output) (output_indices tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseReshape", + Input: []tf.Input{ + input_indices, input_shape, new_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Computes the gradient of `igamma(a, x)` wrt `a`. +func IgammaGradA(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IgammaGradA", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// BatchToSpace for 4-D tensors of type T. +// +// This is a legacy version of the more general BatchToSpaceND. +// +// Rearranges (permutes) data from batch into blocks of spatial data, followed by +// cropping. This is the reverse transformation of SpaceToBatch. More specifically, +// this op outputs a copy of the input tensor where values from the `batch` +// dimension are moved in spatial blocks to the `height` and `width` dimensions, +// followed by cropping along the `height` and `width` dimensions. +// +// Arguments: +// input: 4-D tensor with shape +// `[batch*block_size*block_size, height_pad/block_size, width_pad/block_size, +// depth]`. Note that the batch size of the input tensor must be divisible by +// `block_size * block_size`. +// crops: 2-D tensor of non-negative integers with shape `[2, 2]`. It specifies +// how many elements to crop from the intermediate result across the spatial +// dimensions as follows: +// +// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] +// +// +// Returns 4-D with shape `[batch, height, width, depth]`, where: +// +// height = height_pad - crop_top - crop_bottom +// width = width_pad - crop_left - crop_right +// +// The attr `block_size` must be greater than one. It indicates the block size. 
+// +// Some examples: +// +// (1) For the following input of shape `[4, 1, 1, 1]` and block_size of 2: +// +// ``` +// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 1]` and value: +// +// ``` +// x = [[[[1], [2]], [[3], [4]]]] +// ``` +// +// (2) For the following input of shape `[4, 1, 1, 3]` and block_size of 2: +// +// ``` +// [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] +// ``` +// +// The output tensor has shape `[1, 2, 2, 3]` and value: +// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// (3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [3]], [[9], [11]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +// +// The output tensor has shape `[1, 4, 4, 1]` and value: +// +// ``` +// x = [[[[1], [2], [3], [4]], +// [[5], [6], [7], [8]], +// [[9], [10], [11], [12]], +// [[13], [14], [15], [16]]]] +// ``` +// +// (4) For the following input of shape `[8, 1, 2, 1]` and block_size of 2: +// +// ``` +// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], +// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] +// ``` +// +// The output tensor has shape `[2, 2, 4, 1]` and value: +// +// ``` +// x = [[[[1], [3]], [[5], [7]]], +// [[[2], [4]], [[10], [12]]], +// [[[5], [7]], [[13], [15]]], +// [[[6], [8]], [[14], [16]]]] +// ``` +func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"block_size": block_size} + opspec := tf.OpSpec{ + Type: "BatchToSpace", + Input: []tf.Input{ + input, crops, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compute the Hurwitz zeta function \\(\zeta(x, q)\\). +// +// The Hurwitz zeta function is defined as: +// +// +// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) +func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Zeta", + Input: []tf.Input{ + x, q, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeRawAttr is an optional argument to DecodeRaw. +type DecodeRawAttr func(optionalAttr) + +// DecodeRawLittleEndian sets the optional little_endian attribute to value. +// +// value: Whether the input `bytes` are in little-endian order. +// Ignored for `out_type` values that are stored in a single byte like +// `uint8`. +// If not specified, defaults to true +func DecodeRawLittleEndian(value bool) DecodeRawAttr { + return func(m optionalAttr) { + m["little_endian"] = value + } +} + +// Reinterpret the bytes of a string as a vector of numbers. +// +// Arguments: +// bytes: All the elements must have the same length. +// +// +// Returns A Tensor with one more dimension than the input `bytes`. The +// added dimension will have size equal to the length of the elements +// of `bytes` divided by the number of bytes to represent `out_type`. 
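+//
+// A small illustrative sketch (same `tf`/`op` imports as the
+// AddSparseToTensorsMap example above):
+//
+// ```
+// func decodeRawExample() tf.Output {
+//     s := op.NewScope()
+//     // One string of 8 raw bytes, i.e. two little-endian int32 values (1 and 2).
+//     raw := op.Const(s, []string{"\x01\x00\x00\x00\x02\x00\x00\x00"})
+//     // The result has shape [1, 2]: one input string, two decoded int32s each.
+//     return op.DecodeRaw(s, raw, tf.Int32, op.DecodeRawLittleEndian(true))
+// }
+// ```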
+func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeRaw", + Input: []tf.Input{ + bytes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. +// +// This is the angle \( \theta \in [-\pi, \pi] \) such that +// \[ x = r \cos(\theta) \] +// and +// \[ y = r \sin(\theta) \] +// where \(r = \sqrt(x^2 + y^2) \). +func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Atan2", + Input: []tf.Input{ + y, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// PrelinearizeAttr is an optional argument to Prelinearize. +type PrelinearizeAttr func(optionalAttr) + +// PrelinearizeShape sets the optional shape attribute to value. +// +// value: The shape of the tensor. +// If not specified, defaults to <> +func PrelinearizeShape(value tf.Shape) PrelinearizeAttr { + return func(m optionalAttr) { + m["shape"] = value + } +} + +// PrelinearizeLayout sets the optional layout attribute to value. +// +// value: A vector holding the requested layout in minor-to-major sequence. If a layout +// attribute is passed but its values are all -1 the layout will be computed by +// the infeed operation. +// If not specified, defaults to <> +func PrelinearizeLayout(value []int64) PrelinearizeAttr { + return func(m optionalAttr) { + m["layout"] = value + } +} + +// An op which linearizes one Tensor value to an opaque variant tensor. +// +// Arguments: +// input: A tensor that will be linearized. +func Prelinearize(scope *Scope, input tf.Output, optional ...PrelinearizeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Prelinearize", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compute the regularized incomplete beta integral \\(I_x(a, b)\\). +// +// The regularized incomplete beta integral is defined as: +// +// +// \\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) +// +// where +// +// +// \\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) +// +// +// is the incomplete beta function and \\(B(a, b)\\) is the *complete* +// beta function. +func Betainc(scope *Scope, a tf.Output, b tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Betainc", + Input: []tf.Input{ + a, b, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// An Op to exchange data across TPU replicas. +// +// On each replica, the input is split into `split_count` blocks along +// `split_dimension` and send to the other replicas given group_assignment. After +// receiving `split_count` - 1 blocks from other replicas, we concatenate the +// blocks along `concat_dimension` as the output. 
+// +// For example, suppose there are 2 TPU replicas: +// replica 0 receives input: `[[A, B]]` +// replica 1 receives input: `[[C, D]]` +// +// group_assignment=`[[0, 1]]` +// concat_dimension=0 +// split_dimension=1 +// split_count=2 +// +// replica 0's output: `[[A], [C]]` +// replica 1's output: `[[B], [D]]` +// +// Arguments: +// input: The local input to the sum. +// group_assignment: An int32 tensor with shape +// [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the +// replica ids in the ith subgroup. +// concat_dimension: The dimension number to concatenate. +// split_dimension: The dimension number to split. +// split_count: The number of splits, this number must equal to the sub-group +// size(group_assignment.get_shape()[1]) +// +// Returns The exchanged result. +func AllToAll(scope *Scope, input tf.Output, group_assignment tf.Output, concat_dimension int64, split_dimension int64, split_count int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"concat_dimension": concat_dimension, "split_dimension": split_dimension, "split_count": split_count} + opspec := tf.OpSpec{ + Type: "AllToAll", + Input: []tf.Input{ + input, group_assignment, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the truth value of (x < y) element-wise. +// +// *NOTE*: `Less` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Less", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MeanAttr is an optional argument to Mean. +type MeanAttr func(optionalAttr) + +// MeanKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func MeanKeepDims(value bool) MeanAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the mean of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func Mean(scope *Scope, input tf.Output, axis tf.Output, optional ...MeanAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Mean", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Saves tensors in V2 checkpoint format. +// +// By default, saves the named tensors in full. If the caller wishes to save +// specific slices of full tensors, "shape_and_slices" should be non-empty strings +// and correspondingly well-formed. +// +// Arguments: +// prefix: Must have a single element. The prefix of the V2 checkpoint to which we +// write the tensors. +// tensor_names: shape {N}. The names of the tensors to be saved. +// shape_and_slices: shape {N}. The slice specs of the tensors to be saved. 
+// Empty strings indicate that they are non-partitioned tensors. +// tensors: `N` tensors to save. +// +// Returns the created operation. +func SaveV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, tensors []tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SaveV2", + Input: []tf.Input{ + prefix, tensor_names, shape_and_slices, tf.OutputList(tensors), + }, + } return scope.AddOperation(opspec) } +// Returns the truth value of (x <= y) element-wise. +// +// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LessEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LessEqual", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns the truth value of (x > y) element-wise. // // *NOTE*: `Greater` supports broadcasting. More about broadcasting @@ -17144,27 +31683,967 @@ func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Creates a TensorList by indexing into a Tensor. +// Returns the truth value of (x >= y) element-wise. // -// Each member of the TensorList corresponds to one row of the input tensor, -// specified by the given index (see `tf.gather`). -// -// tensor: The input tensor. -// indices: The indices used to index into the list. -// element_shape: The shape of the elements in the list (can be less specified than -// the shape of the tensor). -// num_elements: The size of the output list. Must be large enough to accommodate -// the largest index in indices. If -1, the list is just large enough to include -// the largest index in indices. -// output_handle: The TensorList. -func TensorListScatterV2(scope *Scope, tensor tf.Output, indices tf.Output, element_shape tf.Output, num_elements tf.Output) (output_handle tf.Output) { +// *NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func GreaterEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListScatterV2", + Type: "GreaterEqual", Input: []tf.Input{ - tensor, indices, element_shape, num_elements, + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the truth value of (x == y) element-wise. +// +// *NOTE*: `Equal` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Equal", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// EncodeProtoAttr is an optional argument to EncodeProto. +type EncodeProtoAttr func(optionalAttr) + +// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value. +// If not specified, defaults to "local://" +func EncodeProtoDescriptorSource(value string) EncodeProtoAttr { + return func(m optionalAttr) { + m["descriptor_source"] = value + } +} + +// The op serializes protobuf messages provided in the input tensors. +// +// The types of the tensors in `values` must match the schema for the +// fields specified in `field_names`. 
All the tensors in `values` must +// have a common shape prefix, *batch_shape*. +// +// The `sizes` tensor specifies repeat counts for each field. The repeat +// count (last dimension) of a each tensor in `values` must be greater +// than or equal to corresponding repeat count in `sizes`. +// +// A `message_type` name must be provided to give context for the field +// names. The actual message descriptor can be looked up either in the +// linked-in descriptor pool or a filename provided by the caller using +// the `descriptor_source` attribute. +// +// The `descriptor_source` attribute selects a source of protocol +// descriptors to consult when looking up `message_type`. This may be a +// filename containing a serialized `FileDescriptorSet` message, +// or the special value `local://`, in which case only descriptors linked +// into the code will be searched; the filename can be on any filesystem +// accessible to TensorFlow. +// +// You can build a `descriptor_source` file using the `--descriptor_set_out` +// and `--include_imports` options to the protocol compiler `protoc`. +// +// The `local://` database only covers descriptors linked into the +// code via C++ libraries, not Python imports. You can link in a proto descriptor +// by creating a cc_library target with alwayslink=1. +// +// There are a few special cases in the value mapping: +// +// Submessage and group fields must be pre-serialized as TensorFlow strings. +// +// TensorFlow lacks support for unsigned int64s, so they must be +// represented as `tf.int64` with the same twos-complement bit pattern +// (the obvious way). +// +// Unsigned int32 values can be represented exactly with `tf.int64`, or +// with sign wrapping if the input is of type `tf.int32`. +// +// Arguments: +// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`. +// values: List of tensors containing values for the corresponding field. +// field_names: List of strings containing proto field names. +// message_type: Name of the proto message type to decode. +// +// Returns Tensor of serialized protos with shape `batch_shape`. +func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EncodeProto", + Input: []tf.Input{ + sizes, tf.OutputList(values), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Splits a tensor into `num_split` tensors along one dimension. +// +// Arguments: +// value: The tensor to split. +// size_splits: list containing the sizes of each output tensor along the split +// dimension. Must sum to the dimension of value along split_dim. +// Can contain one -1 indicating that dimension is to be inferred. +// axis: 0-D. The dimension along which to split. Must be in the range +// `[-rank(value), rank(value))`. +// +// +// Returns Tensors whose shape matches that of `value` +// except along `axis`, where their sizes are +// `size_splits[i]`. 
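+//
+// A small illustrative sketch (same `tf`/`op` imports as the
+// AddSparseToTensorsMap example above):
+//
+// ```
+// func splitVExample() []tf.Output {
+//     s := op.NewScope()
+//     value := op.Const(s, [][]float32{{1, 2, 3, 4, 5, 6}}) // shape [1, 6]
+//     sizes := op.Const(s, []int64{1, 2, 3})                // must sum to 6
+//     axis := op.Const(s, int32(1))                         // split along dim 1
+//     // Yields three tensors with shapes [1, 1], [1, 2] and [1, 3].
+//     return op.SplitV(s, value, sizes, axis, 3)
+// }
+// ```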
+func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_split": num_split} + opspec := tf.OpSpec{ + Type: "SplitV", + Input: []tf.Input{ + value, size_splits, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("SplitV", err) + return + } + return output +} + +// UnbatchGradAttr is an optional argument to UnbatchGrad. +type UnbatchGradAttr func(optionalAttr) + +// UnbatchGradContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnbatchGradContainer(value string) UnbatchGradAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// UnbatchGradSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnbatchGradSharedName(value string) UnbatchGradAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Gradient of Unbatch. +// +// Acts like Batch but using the given batch_index index of batching things as they +// become available. This ensures that the gradients are propagated back in the +// same session which did the forward pass. +// +// original_input: The input to the Unbatch operation this is the gradient of. +// batch_index: The batch_index given to the Unbatch operation this is the gradient +// of. +// grad: The downstream gradient. +// id: The id scalar emitted by Batch. +// batched_grad: The return value, either an empty tensor or the batched gradient. +// container: Container to control resource sharing. +// shared_name: Instances of UnbatchGrad with the same container and shared_name +// are assumed to possibly belong to the same batch. If left empty, the op name +// will be used as the shared name. +func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, grad tf.Output, id tf.Output, optional ...UnbatchGradAttr) (batched_grad tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UnbatchGrad", + Input: []tf.Input{ + original_input, batch_index, grad, id, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// BiasAddAttr is an optional argument to BiasAdd. +type BiasAddAttr func(optionalAttr) + +// BiasAddDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the bias tensor will be added to the last dimension +// of the value tensor. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// The tensor will be added to "in_channels", the third-to-the-last +// dimension. +// If not specified, defaults to "NHWC" +func BiasAddDataFormat(value string) BiasAddAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Adds `bias` to `value`. +// +// This is a special case of `tf.add` where `bias` is restricted to be 1-D. +// Broadcasting is supported, so `value` may have any number of dimensions. +// +// Arguments: +// value: Any number of dimensions. +// bias: 1-D with size the last dimension of `value`. +// +// Returns Broadcasted sum of `value` and `bias`. 
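+//
+// A small illustrative sketch (same `tf`/`op` imports as the
+// AddSparseToTensorsMap example above):
+//
+// ```
+// func biasAddExample() tf.Output {
+//     s := op.NewScope()
+//     value := op.Const(s, [][]float32{{1, 2, 3}, {4, 5, 6}}) // shape [2, 3]
+//     bias := op.Const(s, []float32{10, 20, 30})              // size of the last dim
+//     // With the default "NHWC" format the bias is added to the last dimension,
+//     // giving [[11, 22, 33], [14, 25, 36]].
+//     return op.BiasAdd(s, value, bias)
+// }
+// ```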
+func BiasAdd(scope *Scope, value tf.Output, bias tf.Output, optional ...BiasAddAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "BiasAdd", + Input: []tf.Input{ + value, bias, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the truth value of (x != y) element-wise. +// +// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NotEqual", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeWavAttr is an optional argument to DecodeWav. +type DecodeWavAttr func(optionalAttr) + +// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. +// +// value: Number of sample channels wanted. +// If not specified, defaults to -1 +func DecodeWavDesiredChannels(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_channels"] = value + } +} + +// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. +// +// value: Length of audio requested. +// If not specified, defaults to -1 +func DecodeWavDesiredSamples(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_samples"] = value + } +} + +// Decode a 16-bit PCM WAV file to a float tensor. +// +// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. +// +// When desired_channels is set, if the input contains fewer channels than this +// then the last channel will be duplicated to give the requested number, else if +// the input has more channels than requested then the additional channels will be +// ignored. +// +// If desired_samples is set, then the audio will be cropped or padded with zeroes +// to the requested length. +// +// The first output contains a Tensor with the content of the audio samples. The +// lowest dimension will be the number of channels, and the second will be the +// number of samples. For example, a ten-sample-long stereo WAV file should give an +// output shape of [10, 2]. +// +// Arguments: +// contents: The WAV-encoded audio, usually from a file. +// +// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. +func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeWav", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Returns the truth value of x AND y element-wise. +// +// *NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LogicalAnd", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MatMulAttr is an optional argument to MatMul. 
+type MatMulAttr func(optionalAttr) + +// MatMulTransposeA sets the optional transpose_a attribute to value. +// +// value: If true, "a" is transposed before multiplication. +// If not specified, defaults to false +func MatMulTransposeA(value bool) MatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// MatMulTransposeB sets the optional transpose_b attribute to value. +// +// value: If true, "b" is transposed before multiplication. +// If not specified, defaults to false +func MatMulTransposeB(value bool) MatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// Multiply the matrix "a" by the matrix "b". +// +// The inputs must be two-dimensional matrices and the inner dimension of +// "a" (after being transposed if transpose_a is true) must match the +// outer dimension of "b" (after being transposed if transposed_b is +// true). +// +// *Note*: The default kernel implementation for MatMul on GPUs uses +// cublas. +func MatMul(scope *Scope, a tf.Output, b tf.Output, optional ...MatMulAttr) (product tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatMul", + Input: []tf.Input{ + a, b, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Selects the k nearest centers for each point. +// +// Rows of points are assumed to be input points. Rows of centers are assumed to be +// the list of candidate centers. For each point, the k centers that have least L2 +// distance to it are computed. +// +// Arguments: +// points: Matrix of shape (n, d). Rows are assumed to be input points. +// centers: Matrix of shape (m, d). Rows are assumed to be centers. +// k: Number of nearest centers to return for each point. If k is larger than m, then +// only m centers are returned. +// +// Returns Matrix of shape (n, min(m, k)). Each row contains the indices of the centers +// closest to the corresponding point, ordered by increasing distance.Matrix of shape (n, min(m, k)). Each row contains the squared L2 distance to the +// corresponding center in nearest_center_indices. +func NearestNeighbors(scope *Scope, points tf.Output, centers tf.Output, k tf.Output) (nearest_center_indices tf.Output, nearest_center_distances tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NearestNeighbors", + Input: []tf.Input{ + points, centers, k, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. +type ResizeBicubicGradAttr func(optionalAttr) + +// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. +// If not specified, defaults to false +func ResizeBicubicGradAlignCorners(value bool) ResizeBicubicGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// ResizeBicubicGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeBicubicGradHalfPixelCenters(value bool) ResizeBicubicGradAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value + } +} + +// Computes the gradient of bicubic interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. 
+// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBicubicGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBicubicGrad", + Input: []tf.Input{ + grads, original_image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. +type InitializeTableFromTextFileV2Attr func(optionalAttr) + +// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. +// +// value: Number of elements of the file, use -1 if unknown. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["vocab_size"] = value + } +} + +// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. +// +// value: Delimiter to separate fields in a line. +// If not specified, defaults to "\t" +func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["delimiter"] = value + } +} + +// Initializes a table from a text file. +// +// It inserts one key-value pair into the table for each line of the file. +// The key and value is extracted from the whole line content, elements from the +// split line based on `delimiter` or the line number (starting from zero). +// Where to extract the key and value from a line is specified by `key_index` and +// `value_index`. +// +// - A value of -1 means use the line number(starting from zero), expects `int64`. +// - A value of -2 means use the whole line content, expects `string`. +// - A value >= 0 means use the index (starting at zero) of the split line based +// on `delimiter`. +// +// Arguments: +// table_handle: Handle to a table which will be initialized. +// filename: Filename of a vocabulary text file. +// key_index: Column index in a line to get the table `key` values from. +// value_index: Column index that represents information of a line to get the table +// `value` values from. +// +// Returns the created operation. +func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "InitializeTableFromTextFileV2", + Input: []tf.Input{ + table_handle, filename, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// SparseMatMulAttr is an optional argument to SparseMatMul. +type SparseMatMulAttr func(optionalAttr) + +// SparseMatMulTransposeA sets the optional transpose_a attribute to value. 
+// If not specified, defaults to false +func SparseMatMulTransposeA(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// SparseMatMulTransposeB sets the optional transpose_b attribute to value. +// If not specified, defaults to false +func SparseMatMulTransposeB(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["a_is_sparse"] = value + } +} + +// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["b_is_sparse"] = value + } +} + +// Multiply matrix "a" by matrix "b". +// +// The inputs must be two-dimensional matrices and the inner dimension of "a" must +// match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not +// `SparseTensor`s. This op is optimized for the case where at least one of "a" or +// "b" is sparse, in the sense that they have a large proportion of zero values. +// The breakeven for using this versus a dense matrix multiply on one platform was +// 30% zero values in the sparse matrix. +// +// The gradient computation of this operation will only take advantage of sparsity +// in the input gradient when that gradient comes from a Relu. +func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseMatMul", + Input: []tf.Input{ + a, b, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the minimum along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Computes a tensor such that +// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such +// that `segment_ids[j] == i`. +// +// If the min is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +// For example: +// +// ``` +// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_min(c, tf.constant([0, 0, 1])) +// # ==> [[1, 2, 2, 1], +// # [5, 6, 7, 8]] +// ``` +// +// Arguments: +// +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SegmentMin", + Input: []tf.Input{ + data, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Batch normalization. +// +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// +// This op is deprecated. Prefer `tf.nn.batch_normalization`. +// +// Arguments: +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// beta: A 1D beta Tensor with size matching the last dimension of t. +// An offset to be added to the normalized tensor. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this tensor will be multiplied +// with the normalized tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + opspec := tf.OpSpec{ + Type: "BatchNormWithGlobalNormalization", + Input: []tf.Input{ + t, m, v, beta, gamma, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ProdAttr is an optional argument to Prod. +type ProdAttr func(optionalAttr) + +// ProdKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func ProdKeepDims(value bool) ProdAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the product of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. 
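+//
+// A small illustrative sketch (same `tf`/`op` imports as the
+// AddSparseToTensorsMap example above):
+//
+// ```
+// func prodExample() tf.Output {
+//     s := op.NewScope()
+//     input := op.Const(s, [][]float32{{1, 2, 3}, {4, 5, 6}})
+//     axis := op.Const(s, int32(0))
+//     // Reduces over rows to [4, 10, 18]; keep_dims keeps the shape as [1, 3].
+//     return op.Prod(s, input, axis, op.ProdKeepDims(true))
+// }
+// ```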
+func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Prod", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// A TPU core selector Op. +// +// This Op produces a set of TPU cores (for warm-up) or a single TPU core +// (for regular inference) to execute the TPU program on. The output is +// consumed by TPUPartitionedCall. +// +// Returns A vector 1 or more TPU cores. +func TPUOrdinalSelector(scope *Scope) (device_ordinals tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TPUOrdinalSelector", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes gradients for SparseSegmentMean. +// +// Returns tensor "output" with same shape as grad, except for dimension 0 whose +// value is output_dim0. +// +// Arguments: +// grad: gradient propagated to the SparseSegmentMean op. +// indices: indices passed to the corresponding SparseSegmentMean op. +// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. +// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. +func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentMeanGrad", + Input: []tf.Input{ + grad, indices, segment_ids, output_dim0, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 2D fast Fourier transform. +// +// Computes the 2-dimensional discrete Fourier transform over the inner-most +// 2 dimensions of `input`. +// +// Arguments: +// input: A complex tensor. +// +// Returns A complex tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft2 +// @end_compatibility +func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ArgMaxAttr is an optional argument to ArgMax. +type ArgMaxAttr func(optionalAttr) + +// ArgMaxOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMaxOutputType(value tf.DataType) ArgMaxAttr { + return func(m optionalAttr) { + m["output_type"] = value + } +} + +// Returns the index with the largest value across dimensions of a tensor. +// +// Note that in case of ties the identity of the return value is not guaranteed. +// +// Usage: +// ```python +// import tensorflow as tf +// a = [1, 10, 26.9, 2.8, 166.32, 62.3] +// b = tf.math.argmax(input = a) +// c = tf.keras.backend.eval(b) +// # c = 4 +// # here a[4] = 166.32 which is the largest element of a across axis 0 +// ``` +// +// Arguments: +// +// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. +// Describes which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. 
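+//
+// The equivalent call through this Go wrapper (same `tf`/`op` imports as the
+// AddSparseToTensorsMap example above; the values mirror the Python usage shown):
+//
+// ```
+// func argMaxExample() tf.Output {
+//     s := op.NewScope()
+//     a := op.Const(s, []float32{1, 10, 26.9, 2.8, 166.32, 62.3})
+//     dim := op.Const(s, int32(0))
+//     // The result is 4 (a[4] = 166.32), returned here as int32 instead of the
+//     // default int64.
+//     return op.ArgMax(s, a, dim, op.ArgMaxOutputType(tf.Int32))
+// }
+// ```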
+func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ArgMax", + Input: []tf.Input{ + input, dimension, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FractionalAvgPoolGradAttr is an optional argument to FractionalAvgPoolGrad. +type FractionalAvgPoolGradAttr func(optionalAttr) + +// FractionalAvgPoolGradOverlapping sets the optional overlapping attribute to value. +// +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [41/3, 26/3] for fractional avg pooling. +// If not specified, defaults to false +func FractionalAvgPoolGradOverlapping(value bool) FractionalAvgPoolGradAttr { + return func(m optionalAttr) { + m["overlapping"] = value + } +} + +// Computes gradient of the FractionalAvgPool function. +// +// Unlike FractionalMaxPoolGrad, we don't need to find arg_max for +// FractionalAvgPoolGrad, we just need to evenly back-propagate each element of +// out_backprop to those indices that form the same pooling cell. Therefore, we +// just need to know the shape of original input tensor, instead of the whole +// tensor. +// +// Arguments: +// orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool` +// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients +// w.r.t. the output of `fractional_avg_pool`. +// row_pooling_sequence: row pooling sequence, form pooling region with +// col_pooling_sequence. +// col_pooling_sequence: column pooling sequence, form pooling region with +// row_pooling sequence. +// +// Returns 4-D. Gradients w.r.t. the input of `fractional_avg_pool`. +func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalAvgPoolGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FractionalAvgPoolGrad", + Input: []tf.Input{ + orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the mean along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Computes a tensor such that +// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is +// over `j` such that `segment_ids[j] == i` and `N` is the total number of +// values summed. +// +// If the mean is empty for a given segment ID `i`, `output[i] = 0`. +// +//
+// +// For example: +// +// ``` +// c = tf.constant([[1.0,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +// tf.segment_mean(c, tf.constant([0, 0, 1])) +// # ==> [[2.5, 2.5, 2.5, 2.5], +// # [5, 6, 7, 8]] +// ``` +// +// +// Arguments: +// +// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s +// first dimension. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SegmentMean", + Input: []tf.Input{ + data, segment_ids, }, } op := scope.AddOperation(opspec) @@ -17324,6 +32803,220 @@ func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_box return op.Output(0), op.Output(1), op.Output(2) } +// MapIncompleteSizeAttr is an optional argument to MapIncompleteSize. +type MapIncompleteSizeAttr func(optionalAttr) + +// MapIncompleteSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapIncompleteSizeCapacity(value int64) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapIncompleteSizeMemoryLimit(value int64) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapIncompleteSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapIncompleteSizeContainer(value string) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapIncompleteSizeSharedName(value string) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of incomplete elements in the underlying container. +func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncompleteSizeAttr) (size tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MapIncompleteSize", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the sum along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Computes a tensor such that +// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such +// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` +// need not be sorted and need not cover all values in the full +// range of valid values. +// +// If the sum is empty for a given segment ID `i`, `output[i] = 0`. +// If the given segment ID `i` is negative, the value is dropped and will not be +// added to the sum of the segment. +// +// `num_segments` should equal the number of distinct segment IDs. +// +//
+// +// ``` python +// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) +// tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2) +// # ==> [[ 5, 5, 5, 5], +// # [5, 6, 7, 8]] +// ``` +// +// +// Arguments: +// +// segment_ids: A tensor whose shape is a prefix of `data.shape`. +// +// +// Returns Has same shape as data, except for the first `segment_ids.rank` +// dimensions, which are replaced with a single dimension which has size +// `num_segments`. +func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "UnsortedSegmentSum", + Input: []tf.Input{ + data, segment_ids, num_segments, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the sum along sparse segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first +// dimension, selecting a subset of dimension 0, specified by `indices`. +// +// For example: +// +// ```python +// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) +// +// # Select two rows, one segment. +// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) +// # => [[0 0 0 0]] +// +// # Select two rows, two segment. +// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) +// # => [[ 1 2 3 4] +// # [-1 -2 -3 -4]] +// +// # Select all rows, two segments. +// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) +// # => [[0 0 0 0] +// # [5 6 7 8]] +// +// # Which is equivalent to: +// tf.segment_sum(c, tf.constant([0, 0, 1])) +// ``` +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentSum", + Input: []tf.Input{ + data, indices, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the sum along sparse segments of a tensor. +// +// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is +// misisng, the `output` tensor at that position will be zeroed. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/sparse#Segmentation) +// for an explanation of segments. +// +// For example: +// +// ```python +// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) +// +// tf.sparse_segment_sum_with_num_segments( +// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) +// # => [[0 0 0 0] +// # [0 0 0 0] +// # [0 0 0 0]] +// +// tf.sparse_segment_sum_with_num_segments(c, +// tf.constant([0, 1]), +// tf.constant([0, 2], +// num_segments=4)) +// # => [[ 1 2 3 4] +// # [ 0 0 0 0] +// # [-1 -2 -3 -4] +// # [ 0 0 0 0]] +// ``` +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `num_segments`. 
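+//
+// The same idea through this Go wrapper (same `tf`/`op` imports as the
+// AddSparseToTensorsMap example above; the values mirror the second Python
+// example):
+//
+// ```
+// func sparseSegmentSumExample() tf.Output {
+//     s := op.NewScope()
+//     data := op.Const(s, [][]float32{{1, 2, 3, 4}, {-1, -2, -3, -4}, {5, 6, 7, 8}})
+//     indices := op.Const(s, []int32{0, 1})    // select rows 0 and 1
+//     segmentIDs := op.Const(s, []int32{0, 2}) // place them in segments 0 and 2
+//     numSegments := op.Const(s, int32(4))
+//     // Segments 1 and 3 are missing, so those output rows are zeroed; the
+//     // result has shape [4, 4].
+//     return op.SparseSegmentSumWithNumSegments(s, data, indices, segmentIDs, numSegments)
+// }
+// ```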
+func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentSumWithNumSegments", + Input: []tf.Input{ + data, indices, segment_ids, num_segments, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate. type ResourceScatterNdUpdateAttr func(optionalAttr) @@ -17403,276 +33096,83 @@ func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } -// UnicodeDecodeWithOffsetsAttr is an optional argument to UnicodeDecodeWithOffsets. -type UnicodeDecodeWithOffsetsAttr func(optionalAttr) - -// UnicodeDecodeWithOffsetsErrors sets the optional errors attribute to value. +// Computes the mean along sparse segments of a tensor. // -// value: Error handling policy when there is invalid formatting found in the input. -// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. -// If not specified, defaults to "replace" -func UnicodeDecodeWithOffsetsErrors(value string) UnicodeDecodeWithOffsetsAttr { - return func(m optionalAttr) { - m["errors"] = value - } -} - -// UnicodeDecodeWithOffsetsReplacementChar sets the optional replacement_char attribute to value. +// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is +// misisng, the `output` tensor at that position will be zeroed. // -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD or U+65533.) -// If not specified, defaults to 65533 -func UnicodeDecodeWithOffsetsReplacementChar(value int64) UnicodeDecodeWithOffsetsAttr { - return func(m optionalAttr) { - m["replacement_char"] = value - } -} - -// UnicodeDecodeWithOffsetsReplaceControlCharacters sets the optional replace_control_characters attribute to value. -// -// value: Whether to replace the C0 control characters (00-1F) with the -// `replacement_char`. Default is false. -// If not specified, defaults to false -func UnicodeDecodeWithOffsetsReplaceControlCharacters(value bool) UnicodeDecodeWithOffsetsAttr { - return func(m optionalAttr) { - m["replace_control_characters"] = value - } -} - -// Decodes each string in `input` into a sequence of Unicode code points. -// -// The character codepoints for all strings are returned using a single vector -// `char_values`, with strings expanded to characters in row-major order. -// Similarly, the character start byte offsets are returned using a single vector -// `char_to_byte_starts`, with strings expanded in row-major order. -// -// The `row_splits` tensor indicates where the codepoints and start offsets for -// each input string begin and end within the `char_values` and -// `char_to_byte_starts` tensors. In particular, the values for the `i`th -// string (in row-major order) are stored in the slice -// `[row_splits[i]:row_splits[i+1]]`. 
Thus: -// -// * `char_values[row_splits[i]+j]` is the Unicode codepoint for the `j`th -// character in the `i`th string (in row-major order). -// * `char_to_bytes_starts[row_splits[i]+j]` is the start byte offset for the `j`th -// character in the `i`th string (in row-major order). -// * `row_splits[i+1] - row_splits[i]` is the number of characters in the `i`th -// string (in row-major order). +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. // // Arguments: -// input: The text to be decoded. Can have any shape. Note that the output is flattened -// to a vector of char values. -// input_encoding: Text encoding of the input strings. This is any of the encodings supported -// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. // -// Returns A 1D int32 tensor containing the row splits.A 1D int32 Tensor containing the decoded codepoints.A 1D int32 Tensor containing the byte index in the input string where each -// character in `char_values` starts. -func UnicodeDecodeWithOffsets(scope *Scope, input tf.Output, input_encoding string, optional ...UnicodeDecodeWithOffsetsAttr) (row_splits tf.Output, char_values tf.Output, char_to_byte_starts tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"input_encoding": input_encoding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UnicodeDecodeWithOffsets", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Returns x - y element-wise. +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. // -// *NOTE*: `Subtract` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Returns Has same shape as data, except for dimension 0 which has size +// `num_segments`. +func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Sub", + Type: "SparseSegmentMeanWithNumSegments", Input: []tf.Input{ - x, y, + data, indices, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// LRNAttr is an optional argument to LRN. -type LRNAttr func(optionalAttr) - -// LRNDepthRadius sets the optional depth_radius attribute to value. -// -// value: 0-D. Half-width of the 1-D normalization window. -// If not specified, defaults to 5 -func LRNDepthRadius(value int64) LRNAttr { - return func(m optionalAttr) { - m["depth_radius"] = value - } -} - -// LRNBias sets the optional bias attribute to value. -// -// value: An offset (usually positive to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNBias(value float32) LRNAttr { - return func(m optionalAttr) { - m["bias"] = value - } -} - -// LRNAlpha sets the optional alpha attribute to value. -// -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNAlpha(value float32) LRNAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// LRNBeta sets the optional beta attribute to value. -// -// value: An exponent. 
-// If not specified, defaults to 0.5 -func LRNBeta(value float32) LRNAttr { - return func(m optionalAttr) { - m["beta"] = value - } -} - -// Local Response Normalization. -// -// The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last -// dimension), and each vector is normalized independently. Within a given vector, -// each component is divided by the weighted, squared sum of inputs within -// `depth_radius`. In detail, -// -// sqr_sum[a, b, c, d] = -// sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) -// output = input / (bias + alpha * sqr_sum) ** beta -// -// For details, see [Krizhevsky et al., ImageNet classification with deep -// convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks). -// -// Arguments: -// input: 4-D. -func LRN(scope *Scope, input tf.Output, optional ...LRNAttr) (output tf.Output) { +// Computes inverse hyperbolic sine of x element-wise. +func Asinh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "LRN", + Type: "Asinh", Input: []tf.Input{ - input, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug. -type RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr func(optionalAttr) +// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints. +type MergeV2CheckpointsAttr func(optionalAttr) -// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 +// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value. // -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve proximal Adagrad embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the proximal Adagrad optimization algorithm.Parameter accumulators updated by the proximal Adagrad optimization algorithm.Parameter gradient_accumulators updated by the proximal Adagrad optimization algorithm. 
-func RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ResourceSparseApplyAdagradAttr is an optional argument to ResourceSparseApplyAdagrad. -type ResourceSparseApplyAdagradAttr func(optionalAttr) - -// ResourceSparseApplyAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value. +// value: see above. // If not specified, defaults to true -func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr { +func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr { return func(m optionalAttr) { - m["update_slots"] = value + m["delete_old_dirs"] = value } } -// Update relevant entries in '*var' and '*accum' according to the adagrad scheme. +// V2 format specific: merges the metadata files of sharded checkpoints. The // -// That is for rows we have grad for, we update var and accum as follows: -// accum += grad * grad -// var -= lr * grad * (1 / sqrt(accum)) +// result is one logical checkpoint, with one physical metadata file and renamed +// data files. +// +// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. +// +// If delete_old_dirs is true, attempts to delete recursively the dirname of each +// path in the input checkpoint_prefixes. This is useful when those paths are non +// user-facing temporary locations. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// checkpoint_prefixes: prefixes of V2 checkpoints to merge. +// destination_prefix: scalar. The desired final prefix. Allowed to be the same +// as one of the checkpoint_prefixes. // // Returns the created operation. 
-func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdagradAttr) (o *tf.Operation) { +func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -17681,99 +33181,387 @@ func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, l a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagrad", + Type: "MergeV2Checkpoints", Input: []tf.Input{ - var_, accum, lr, grad, indices, + checkpoint_prefixes, destination_prefix, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// LoadTPUEmbeddingMomentumParametersAttr is an optional argument to LoadTPUEmbeddingMomentumParameters. -type LoadTPUEmbeddingMomentumParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. // -// REQUIRES: value >= -1 -func LoadTPUEmbeddingMomentumParametersTableId(value int64) LoadTPUEmbeddingMomentumParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingMomentumParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingMomentumParametersTableName(value string) LoadTPUEmbeddingMomentumParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load Momentum embedding parameters. +// N is the size of the segment being reduced. +// +// See `tf.sparse.segment_sum` for usage examples. // -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. // // Arguments: -// parameters: Value of parameters used in the Momentum optimization algorithm. -// momenta: Value of momenta used in the Momentum optimization algorithm. // +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. // -// -// Returns the created operation. -func LoadTPUEmbeddingMomentumParameters(scope *Scope, parameters tf.Output, momenta tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMomentumParametersAttr) (o *tf.Operation) { +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + opspec := tf.OpSpec{ + Type: "SparseSegmentSqrtN", + Input: []tf.Input{ + data, indices, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. +type FractionalMaxPoolAttr func(optionalAttr) + +// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. +// +// value: When set to True, generates the pooling sequence in a +// pseudorandom fashion, otherwise, in a random fashion. 
Check paper [Benjamin +// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for +// difference between pseudorandom and random. +// If not specified, defaults to false +func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["pseudo_random"] = value + } +} + +// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. +// +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [20, 16] for fractional max pooling. +// If not specified, defaults to false +func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["overlapping"] = value + } +} + +// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. +// +// value: When set to True, a fixed pooling region will be used when +// iterating over a FractionalMaxPool node in the computation graph. Mainly used +// in unit test to make FractionalMaxPool deterministic. +// If not specified, defaults to false +func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["deterministic"] = value + } +} + +// FractionalMaxPoolSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Performs fractional max pooling on the input. +// +// Fractional max pooling is slightly different than regular max pooling. In +// regular max pooling, you downsize an input set by taking the maximum value of +// smaller N x N subsections of the set (often 2x2), and try to reduce the set by +// a factor of N, where N is an integer. Fractional max pooling, as you might +// expect from the word "fractional", means that the overall reduction ratio N +// does not have to be an integer. +// +// The sizes of the pooling regions are generated randomly but are fairly uniform. +// For example, let's look at the height dimension, and the constraints on the +// list of rows that will be pool boundaries. +// +// First we define the following: +// +// 1. input_row_length : the number of rows from the input set +// 2. output_row_length : which will be smaller than the input +// 3. alpha = input_row_length / output_row_length : our reduction ratio +// 4. K = floor(alpha) +// 5. row_pooling_sequence : this is the result list of pool boundary rows +// +// Then, row_pooling_sequence should satisfy: +// +// 1. a[0] = 0 : the first value of the sequence is 0 +// 2. a[end] = input_row_length : the last value of the sequence is the size +// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size +// 4. 
length(row_pooling_sequence) = output_row_length+1 +// +// For more details on fractional max pooling, see this paper: +// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) +// +// Arguments: +// value: 4-D with shape `[batch, height, width, channels]`. +// pooling_ratio: Pooling ratio for each dimension of `value`, currently only +// supports row and col dimension and should be >= 1.0. For example, a valid +// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements +// must be 1.0 because we don't allow pooling on batch and channels +// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions +// respectively. +// +// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. +func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingMomentumParameters", + Type: "FractionalMaxPool", Input: []tf.Input{ - parameters, momenta, + value, }, Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// AllAttr is an optional argument to All. +type AllAttr func(optionalAttr) + +// AllKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AllKeepDims(value bool) AllAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the "logical and" of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. +// +// Returns The reduced tensor. +func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "All", + Input: []tf.Input{ + input, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MapClearAttr is an optional argument to MapClear. +type MapClearAttr func(optionalAttr) + +// MapClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapClearCapacity(value int64) MapClearAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapClearMemoryLimit(value int64) MapClearAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapClearContainer sets the optional container attribute to value. 
+// If not specified, defaults to "" +func MapClearContainer(value string) MapClearAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapClearSharedName(value string) MapClearAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes all elements in the underlying container. +// +// Returns the created operation. +func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MapClear", + + Attrs: attrs, + } return scope.AddOperation(opspec) } -// Assigns sparse updates to the variable referenced by `resource`. +// Creates a sequence of numbers. // -// This operation computes +// This operation creates a sequence of numbers that begins at `start` and +// extends by increments of `delta` up to but not including `limit`. // -// # Scalar indices -// ref[indices, ...] = updates[...] +// For example: // -// # Vector indices (for each i) -// ref[indices[i], ...] = updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = updates[i, ..., j, ...] +// ``` +// # 'start' is 3 +// # 'limit' is 18 +// # 'delta' is 3 +// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] +// ``` // // Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. +// start: 0-D (scalar). First entry in the sequence. +// limit: 0-D (scalar). Upper limit of sequence, exclusive. +// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. // -// Returns the created operation. -func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +// Returns 1-D. +func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ResourceScatterUpdate", + Type: "Range", Input: []tf.Input{ - resource, indices, updates, + start, limit, delta, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Generates values in an interval. +// +// A sequence of `num` evenly-spaced values are generated beginning at `start`. +// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, +// so that the last one is exactly `stop`. +// +// For example: +// +// ``` +// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] +// ``` +// +// Arguments: +// start: 0-D tensor. First entry in the range. +// stop: 0-D tensor. Last entry in the range. +// num: 0-D tensor. Number of values to generate. +// +// Returns 1-D. The generated values. +func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LinSpace", + Input: []tf.Input{ + start, stop, num, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the complex conjugate of a complex number. +// +// Given a tensor `input` of complex numbers, this operation returns a tensor of +// complex numbers that are the complex conjugate of each element in `input`. 
The +// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the +// real part and *b* is the imaginary part. +// +// The complex conjugate returned by this operation is of the form \\(a - bj\\). +// +// For example: +// +// ``` +// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] +// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] +// ``` +func Conj(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Conj", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// A container for an iterator resource. +// +// Returns A handle to the iterator that can be passed to a "MakeIterator" or +// "IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents +// resource sharing by name, and does not keep a reference to the resource +// container.A variant deleter that should be passed into the op that deletes the iterator. +func AnonymousIteratorV2(scope *Scope, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output, deleter tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "AnonymousIteratorV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } // HistogramFixedWidthAttr is an optional argument to HistogramFixedWidth. @@ -17832,112 +33620,75 @@ func HistogramFixedWidth(scope *Scope, values tf.Output, value_range tf.Output, return op.Output(0) } -// Elementwise computes the bitwise right-shift of `x` and `y`. +// Counts the number of occurrences of each value in an integer array. // -// Performs a logical shift for unsigned integer types, and an arithmetic shift -// for signed integer types. +// Outputs a vector with length `size` and the same dtype as `weights`. If +// `weights` are empty, then index `i` stores the number of times the value `i` is +// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of +// the value in `weights` at each index where the corresponding value in `arr` is +// `i`. // -// If `y` is negative, or greater than or equal to than the width of `x` in bits -// the result is implementation defined. -func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Values in `arr` outside of the range [0, size) are ignored. +// +// Arguments: +// arr: int32 `Tensor`. +// size: non-negative int32 scalar `Tensor`. +// weights: is an int32, int64, float32, or float64 `Tensor` with the same +// shape as `arr`, or a length-0 `Tensor`, in which case it acts as all weights +// equal to 1. +// +// Returns 1D `Tensor` with length equal to `size`. The counts or summed weights for +// each value in the range [0, size). +func Bincount(scope *Scope, arr tf.Output, size tf.Output, weights tf.Output) (bins tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RightShift", + Type: "Bincount", Input: []tf.Input{ - x, y, + arr, size, weights, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorListStackAttr is an optional argument to TensorListStack. -type TensorListStackAttr func(optionalAttr) +// QrAttr is an optional argument to Qr. +type QrAttr func(optionalAttr) -// TensorListStackNumElements sets the optional num_elements attribute to value. 
-// If not specified, defaults to -1 -func TensorListStackNumElements(value int64) TensorListStackAttr { +// QrFullMatrices sets the optional full_matrices attribute to value. +// +// value: If true, compute full-sized `q` and `r`. If false +// (the default), compute only the leading `P` columns of `q`. +// If not specified, defaults to false +func QrFullMatrices(value bool) QrAttr { return func(m optionalAttr) { - m["num_elements"] = value + m["full_matrices"] = value } } -// Stacks all tensors in the list. +// Computes the QR decompositions of one or more matrices. // -// Requires that all tensors have the same shape. +// Computes the QR decomposition of each inner matrix in `tensor` such that +// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` // -// input_handle: the input list -// tensor: the gathered result -// num_elements: optional. If not -1, the number of elements in the list. -// -func TensorListStack(scope *Scope, input_handle tf.Output, element_shape tf.Output, element_dtype tf.DataType, optional ...TensorListStackAttr) (tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorListStack", - Input: []tf.Input{ - input_handle, element_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// A placeholder op for a value that will be fed into the computation. +// ```python +// # a is a tensor. +// # q is a tensor of orthonormal matrices. +// # r is a tensor of upper triangular matrices. +// q, r = qr(a) +// q_full, r_full = qr(a, full_matrices=True) +// ``` // // Arguments: -// dtype: The type of elements in the tensor. -// shape: The shape of the tensor. +// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. // -// Returns A tensor that will be provided using the infeed mechanism. -func InfeedDequeue(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} - opspec := tf.OpSpec{ - Type: "InfeedDequeue", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. -type StatelessRandomUniformAttr func(optionalAttr) - -// StatelessRandomUniformDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomUniformDtype(value tf.DataType) StatelessRandomUniformAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// -// Returns Random values with specified shape. -func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformAttr) (output tf.Output) { +// Returns Orthonormal basis for range of `a`. 
If `full_matrices` is `False` then +// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is +// `[..., M, M]`.Triangular factor. If `full_matrices` is `False` then shape is +// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`. +func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) { if scope.Err() != nil { return } @@ -17946,116 +33697,9 @@ func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optio a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessRandomUniform", + Type: "Qr", Input: []tf.Input{ - shape, seed, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Makes its input available to the next iteration. -// -// Arguments: -// data: The tensor to be made available to the next iteration. -// -// Returns The same tensor as `data`. -func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NextIteration", - Input: []tf.Input{ - data, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Output a fact about factorials. -func Fact(scope *Scope) (fact tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Fact", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping. -type GenerateVocabRemappingAttr func(optionalAttr) - -// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value. -// -// value: Number of entries in the old vocab file to consider. If -1, -// use the entire old vocabulary. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr { - return func(m optionalAttr) { - m["old_vocab_size"] = value - } -} - -// Given a path to new and old vocabulary files, returns a remapping Tensor of -// -// length `num_new_vocab`, where `remapping[i]` contains the row number in the old -// vocabulary that corresponds to row `i` in the new vocabulary (starting at line -// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i` -// in the new vocabulary is not in the old vocabulary. The old vocabulary is -// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the -// default value of -1. -// -// `num_vocab_offset` enables -// use in the partitioned variable case, and should generally be set through -// examining partitioning info. The format of the files should be a text file, -// with each line containing a single entity within the vocabulary. -// -// For example, with `new_vocab_file` a text file containing each of the following -// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3], -// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be -// `[0, -1, 2]`. -// -// The op also returns a count of how many entries in the new vocabulary -// were present in the old vocabulary, which is used to calculate the number of -// values to initialize in a weight matrix remapping -// -// This functionality can be used to remap both row vocabularies (typically, -// features) and column vocabularies (typically, classes) from TensorFlow -// checkpoints. Note that the partitioning logic relies on contiguous vocabularies -// corresponding to div-partitioned variables. 
Moreover, the underlying remapping -// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should -// use the corresponding index_table_from_file() as the FeatureColumn framework -// does (as opposed to tf.feature_to_id(), which uses a CuckooTable). -// -// Arguments: -// new_vocab_file: Path to the new vocab file. -// old_vocab_file: Path to the old vocab file. -// new_vocab_offset: How many entries into the new vocab file to start reading. -// num_new_vocab: Number of entries in the new vocab file to remap. -// -// Returns A Tensor of length num_new_vocab where the element at index i -// is equal to the old ID that maps to the new ID i. This element is -1 for any -// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab. -func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "GenerateVocabRemapping", - Input: []tf.Input{ - new_vocab_file, old_vocab_file, + input, }, Attrs: attrs, } @@ -18063,73 +33707,67 @@ func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_fi return op.Output(0), op.Output(1) } -// Worker heartbeat op. -// -// Heartbeats may be sent periodically to indicate the coordinator is still active, -// to retrieve the current worker status and to expedite shutdown when necessary. -// -// Arguments: -// request: A string tensor containing a serialized WorkerHeartbeatRequest -// -// Returns A string tensor containing a serialized WorkerHeartbeatResponse -func WorkerHeartbeat(scope *Scope, request tf.Output) (response tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "WorkerHeartbeat", - Input: []tf.Input{ - request, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// CumsumAttr is an optional argument to Cumsum. +type CumsumAttr func(optionalAttr) -// Returns the truth value of (x <= y) element-wise. +// CumsumExclusive sets the optional exclusive attribute to value. // -// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LessEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LessEqual", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// EnqueueTPUEmbeddingIntegerBatchAttr is an optional argument to EnqueueTPUEmbeddingIntegerBatch. -type EnqueueTPUEmbeddingIntegerBatchAttr func(optionalAttr) - -// EnqueueTPUEmbeddingIntegerBatchDeviceOrdinal sets the optional device_ordinal attribute to value. -// -// value: The TPU device to use. Should be >= 0 and less than the number -// of TPU cores in the task on which the node is placed. -// If not specified, defaults to -1 -func EnqueueTPUEmbeddingIntegerBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingIntegerBatchAttr { +// value: If `True`, perform exclusive cumsum. 
+// If not specified, defaults to false +func CumsumExclusive(value bool) CumsumAttr { return func(m optionalAttr) { - m["device_ordinal"] = value + m["exclusive"] = value } } -// An op that enqueues a list of input batch tensors to TPUEmbedding. +// CumsumReverse sets the optional reverse attribute to value. +// +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumsumReverse(value bool) CumsumAttr { + return func(m optionalAttr) { + m["reverse"] = value + } +} + +// Compute the cumulative sum of the tensor `x` along `axis`. +// +// By default, this op performs an inclusive cumsum, which means that the first +// element of the input is identical to the first element of the output: +// +// ```python +// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] +// ``` +// +// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is +// performed instead: +// +// ```python +// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] +// ``` +// +// By setting the `reverse` kwarg to `True`, the cumsum is performed in the +// opposite direction: +// +// ```python +// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] +// ``` +// +// This is more efficient than using separate `tf.reverse` ops. +// +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] +// ``` // // Arguments: -// batch: A list of 1D tensors, one for each embedding table, containing the -// indices into the tables. -// mode_override: A string input that overrides the mode specified in the -// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', -// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set -// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. -// -// Returns the created operation. -func EnqueueTPUEmbeddingIntegerBatch(scope *Scope, batch []tf.Output, mode_override tf.Output, optional ...EnqueueTPUEmbeddingIntegerBatchAttr) (o *tf.Operation) { +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. +func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) { if scope.Err() != nil { return } @@ -18138,39 +33776,266 @@ func EnqueueTPUEmbeddingIntegerBatch(scope *Scope, batch []tf.Output, mode_overr a(attrs) } opspec := tf.OpSpec{ - Type: "EnqueueTPUEmbeddingIntegerBatch", + Type: "Cumsum", Input: []tf.Input{ - tf.OutputList(batch), mode_override, + x, axis, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// An op that receives embedding activations on the TPU. -// -// The TPU system performs the embedding lookups and aggregations specified by -// the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The -// results of these aggregations are visible to the Tensorflow Graph as the -// outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing -// one Tensor of activations per table specified in the model. There can be at -// most one RecvTPUEmbeddingActivations op in the TPU graph. +// Creates a dataset that emits the lines of one or more text files. 
// // Arguments: -// num_outputs: The number of output activation tensors, equal to the number of -// embedding tables in the model. -// config: Serialized TPUEmbeddingConfiguration proto. -// -// Returns A TensorList of embedding activations containing one Tensor per -// embedding table in the model. -func RecvTPUEmbeddingActivations(scope *Scope, num_outputs int64, config string) (outputs []tf.Output) { +// filenames: A scalar or a vector containing the name(s) of the file(s) to be +// read. +// compression_type: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// buffer_size: A scalar containing the number of bytes to buffer. +func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_outputs": num_outputs, "config": config} opspec := tf.OpSpec{ - Type: "RecvTPUEmbeddingActivations", + Type: "TextLineDataset", + Input: []tf.Input{ + filenames, compression_type, buffer_size, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} +// UnicodeTranscodeAttr is an optional argument to UnicodeTranscode. +type UnicodeTranscodeAttr func(optionalAttr) + +// UnicodeTranscodeErrors sets the optional errors attribute to value. +// +// value: Error handling policy when there is invalid formatting found in the input. +// The value of 'strict' will cause the operation to produce a InvalidArgument +// error on any invalid input formatting. A value of 'replace' (the default) will +// cause the operation to replace any invalid formatting in the input with the +// `replacement_char` codepoint. A value of 'ignore' will cause the operation to +// skip any invalid formatting in the input and produce no corresponding output +// character. +// If not specified, defaults to "replace" +func UnicodeTranscodeErrors(value string) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["errors"] = value + } +} + +// UnicodeTranscodeReplacementChar sets the optional replacement_char attribute to value. +// +// value: The replacement character codepoint to be used in place of any invalid +// formatting in the input when `errors='replace'`. Any valid unicode codepoint may +// be used. The default value is the default unicode replacement character is +// 0xFFFD or U+65533.) +// +// Note that for UTF-8, passing a replacement character expressible in 1 byte, such +// as ' ', will preserve string alignment to the source since invalid bytes will be +// replaced with a 1-byte replacement. For UTF-16-BE and UTF-16-LE, any 1 or 2 byte +// replacement character will preserve byte alignment to the source. +// If not specified, defaults to 65533 +func UnicodeTranscodeReplacementChar(value int64) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["replacement_char"] = value + } +} + +// UnicodeTranscodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. +// +// value: Whether to replace the C0 control characters (00-1F) with the +// `replacement_char`. Default is false. +// If not specified, defaults to false +func UnicodeTranscodeReplaceControlCharacters(value bool) UnicodeTranscodeAttr { + return func(m optionalAttr) { + m["replace_control_characters"] = value + } +} + +// Transcode the input text from a source encoding to a destination encoding. +// +// The input is a string tensor of any shape. 
The output is a string tensor of +// the same shape containing the transcoded strings. Output strings are always +// valid unicode. If the input contains invalid encoding positions, the +// `errors` attribute sets the policy for how to deal with them. If the default +// error-handling policy is used, invalid formatting will be substituted in the +// output by the `replacement_char`. If the errors policy is to `ignore`, any +// invalid encoding positions in the input are skipped and not included in the +// output. If it set to `strict` then any invalid formatting will result in an +// InvalidArgument error. +// +// This operation can be used with `output_encoding = input_encoding` to enforce +// correct formatting for inputs even if they are already in the desired encoding. +// +// If the input is prefixed by a Byte Order Mark needed to determine encoding +// (e.g. if the encoding is UTF-16 and the BOM indicates big-endian), then that +// BOM will be consumed and not emitted into the output. If the input encoding +// is marked with an explicit endianness (e.g. UTF-16-BE), then the BOM is +// interpreted as a non-breaking-space and is preserved in the output (including +// always for UTF-8). +// +// The end result is that if the input is marked as an explicit endianness the +// transcoding is faithful to all codepoints in the source. If it is not marked +// with an explicit endianness, the BOM is not considered part of the string itself +// but as metadata, and so is not preserved in the output. +// +// Arguments: +// input: The text to be processed. Can have any shape. +// input_encoding: Text encoding of the input strings. This is any of the encodings supported +// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. +// output_encoding: The unicode encoding to use in the output. Must be one of +// `"UTF-8", "UTF-16-BE", "UTF-32-BE"`. Multi-byte encodings will be big-endian. +// +// Returns A string tensor containing unicode text encoded using `output_encoding`. +func UnicodeTranscode(scope *Scope, input tf.Output, input_encoding string, output_encoding string, optional ...UnicodeTranscodeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"input_encoding": input_encoding, "output_encoding": output_encoding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "UnicodeTranscode", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// CumprodAttr is an optional argument to Cumprod. +type CumprodAttr func(optionalAttr) + +// CumprodExclusive sets the optional exclusive attribute to value. +// +// value: If `True`, perform exclusive cumprod. +// If not specified, defaults to false +func CumprodExclusive(value bool) CumprodAttr { + return func(m optionalAttr) { + m["exclusive"] = value + } +} + +// CumprodReverse sets the optional reverse attribute to value. +// +// value: A `bool` (default: False). +// If not specified, defaults to false +func CumprodReverse(value bool) CumprodAttr { + return func(m optionalAttr) { + m["reverse"] = value + } +} + +// Compute the cumulative product of the tensor `x` along `axis`. 
+// +// By default, this op performs an inclusive cumprod, which means that the first +// element of the input is identical to the first element of the output: +// +// ```python +// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] +// ``` +// +// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +// performed instead: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +// ``` +// +// By setting the `reverse` kwarg to `True`, the cumprod is performed in the +// opposite direction: +// +// ```python +// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +// ``` +// +// This is more efficient than using separate `tf.reverse` ops. +// +// The `reverse` and `exclusive` kwargs can also be combined: +// +// ```python +// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +// ``` +// +// Arguments: +// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, +// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, +// `complex128`, `qint8`, `quint8`, `qint32`, `half`. +// axis: A `Tensor` of type `int32` (default: 0). Must be in the range +// `[-rank(x), rank(x))`. +func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Cumprod", + Input: []tf.Input{ + x, axis, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// UnpackAttr is an optional argument to Unpack. +type UnpackAttr func(optionalAttr) + +// UnpackAxis sets the optional axis attribute to value. +// +// value: Dimension along which to unpack. Negative values wrap around, so the +// valid range is `[-R, R)`. +// If not specified, defaults to 0 +func UnpackAxis(value int64) UnpackAttr { + return func(m optionalAttr) { + m["axis"] = value + } +} + +// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. +// +// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. +// For example, given a tensor of shape `(A, B, C, D)`; +// +// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` +// and each tensor in `output` will have shape `(B, C, D)`. (Note that the +// dimension unpacked along is gone, unlike `split`). +// +// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` +// and each tensor in `output` will have shape `(A, C, D)`. +// Etc. +// +// This is the opposite of `pack`. +// +// Arguments: +// value: 1-D or higher, with `axis` dimension size equal to `num`. +// +// +// Returns The list of tensors unpacked from `value`. 
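The `Cumsum` and `Cumprod` wrappers above also show the generated optional-attribute pattern: each optional attr is exposed as a closure (`CumsumExclusive`, `CumsumReverse`, and so on) passed as a trailing variadic argument. A hedged sketch under the same import-path assumptions as the earlier `UnsortedSegmentSum` example:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s.SubScope("x"), []float32{1, 2, 3, 4})
	axis := op.Const(s.SubScope("axis"), int32(0))

	// Optional attrs are supplied as variadic closures.
	inclusive := op.Cumsum(s.SubScope("inc"), x, axis)                           // [1 3 6 10]
	exclusive := op.Cumsum(s.SubScope("exc"), x, axis, op.CumsumExclusive(true)) // [0 1 3 6]
	reversed := op.Cumprod(s.SubScope("rev"), x, axis, op.CumprodReverse(true))  // [24 24 12 4]

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{inclusive, exclusive, reversed}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value(), out[1].Value(), out[2].Value())
}
```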
+func Unpack(scope *Scope, value tf.Output, num int64, optional ...UnpackAttr) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num": num} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Unpack", + Input: []tf.Input{ + value, + }, Attrs: attrs, } op := scope.AddOperation(opspec) @@ -18179,94 +34044,1162 @@ func RecvTPUEmbeddingActivations(scope *Scope, num_outputs int64, config string) } var idx int var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("RecvTPUEmbeddingActivations", err) + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("Unpack", err) return } - return outputs + return output } -// Selects elements from `x` or `y`, depending on `condition`. +// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. +type QuantizedMatMulAttr func(optionalAttr) + +// QuantizedMatMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Toutput"] = value + } +} + +// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. // -// The `x`, and `y` tensors must all have the same shape, and the -// output will also have that shape. +// value: If true, `a` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. // -// The `condition` tensor must be a scalar if `x` and `y` are scalars. -// If `x` and `y` are vectors or higher rank, then `condition` must be either a -// scalar, a vector with size matching the first dimension of `x`, or must have -// the same shape as `x`. +// value: If true, `b` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. // -// The `condition` tensor acts as a mask that chooses, based on the value at each -// element, whether the corresponding element / row in the output should be -// taken from `x` (if true) or `y` (if false). +// value: The type of output produced by activation function +// following this operation. +// If not specified, defaults to DT_QUINT8 +func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Tactivation"] = value + } +} + +// Perform a quantized matrix multiplication of `a` by the matrix `b`. // -// If `condition` is a vector and `x` and `y` are higher rank matrices, then -// it chooses which row (outer dimension) to copy from `x` and `y`. -// If `condition` has the same shape as `x` and `y`, then it chooses which -// element to copy from `x` and `y`. 
-// -// For example: -// -// ```python -// # 'condition' tensor is [[True, False] -// # [False, True]] -// # 't' is [[1, 2], -// # [3, 4]] -// # 'e' is [[5, 6], -// # [7, 8]] -// select(condition, t, e) # => [[1, 6], [7, 4]] -// -// -// # 'condition' tensor is [True, False] -// # 't' is [[1, 2], -// # [3, 4]] -// # 'e' is [[5, 6], -// # [7, 8]] -// select(condition, t, e) ==> [[1, 2], -// [7, 8]] -// -// ``` +// The inputs must be two-dimensional matrices and the inner dimension of +// `a` (after being transposed if `transpose_a` is non-zero) must match the +// outer dimension of `b` (after being transposed if `transposed_b` is +// non-zero). // // Arguments: +// a: Must be a two-dimensional tensor. +// b: Must be a two-dimensional tensor. +// min_a: The float value that the lowest quantized `a` value represents. +// max_a: The float value that the highest quantized `a` value represents. +// min_b: The float value that the lowest quantized `b` value represents. +// max_b: The float value that the highest quantized `b` value represents. // -// x: = A `Tensor` which may have the same shape as `condition`. -// If `condition` is rank 1, `x` may have higher rank, -// but its first dimension must match the size of `condition`. -// y: = A `Tensor` with the same type and shape as `x`. +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedMatMul", + Input: []tf.Input{ + a, b, min_a, max_a, min_b, max_b, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Subtracts sparse `updates` from an existing tensor according to `indices`. // -// Returns = A `Tensor` with the same type and shape as `x` and `y`. -func Select(scope *Scope, condition tf.Output, x tf.Output, y tf.Output) (output tf.Output) { +// This operation creates a new tensor by subtracting sparse `updates` from the +// passed in `tensor`. +// This operation is very similar to `tf.scatter_nd_sub`, except that the updates +// are subtracted from an existing tensor (as opposed to a variable). If the memory +// for the existing tensor cannot be re-used, a copy is made and updated. +// +// `indices` is an integer tensor containing indices into a new tensor of shape +// `shape`. The last dimension of `indices` can be at most the rank of `shape`: +// +// indices.shape[-1] <= shape.rank +// +// The last dimension of `indices` corresponds to indices into elements +// (if `indices.shape[-1] = shape.rank`) or slices +// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of +// `shape`. `updates` is a tensor with shape +// +// indices.shape[:-1] + shape[indices.shape[-1]:] +// +// The simplest form of tensor_scatter_sub is to subtract individual elements +// from a tensor by index. For example, say we want to insert 4 scattered elements +// in a rank-1 tensor with 8 elements. 
+//
+// In Python, this scatter subtract operation would look like this:
+//
+// ```python
+// indices = tf.constant([[4], [3], [1], [7]])
+// updates = tf.constant([9, 10, 11, 12])
+// tensor = tf.ones([8], dtype=tf.int32)
+// updated = tf.tensor_scatter_sub(tensor, indices, updates)
+// with tf.Session() as sess:
+//   print(sess.run(updated))
+// ```
+//
+// The resulting tensor would look like this:
+//
+// [1, -10, 1, -9, -8, 1, 1, -11]
+//
+// We can also insert entire slices of a higher rank tensor all at once. For
+// example, we can insert two slices in the first dimension of a
+// rank-3 tensor with two matrices of new values.
+//
+// In Python, this scatter subtract operation would look like this:
+//
+// ```python
+// indices = tf.constant([[0], [2]])
+// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
+// [7, 7, 7, 7], [8, 8, 8, 8]],
+// [[5, 5, 5, 5], [6, 6, 6, 6],
+// [7, 7, 7, 7], [8, 8, 8, 8]]])
+// tensor = tf.ones([4, 4, 4])
+// updated = tf.tensor_scatter_sub(tensor, indices, updates)
+// with tf.Session() as sess:
+//   print(sess.run(updated))
+// ```
+//
+// The resulting tensor would look like this:
+//
+// [[[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]],
+// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
+// [[-4, -4, -4, -4], [-5, -5, -5, -5], [-6, -6, -6, -6], [-7, -7, -7, -7]],
+// [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]
+//
+// Note that on CPU, if an out of bound index is found, an error is returned.
+// On GPU, if an out of bound index is found, the index is ignored.
+//
+// Arguments:
+// tensor: Tensor to copy/update.
+// indices: Index tensor.
+// updates: Updates to scatter into output.
+//
+// Returns A new tensor copied from tensor and updates subtracted according to the indices.
+func TensorScatterSub(scope *Scope, tensor tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) {
 if scope.Err() != nil {
 return
 }
 opspec := tf.OpSpec{
- Type: "Select",
+ Type: "TensorScatterSub",
 Input: []tf.Input{
- condition, x, y,
+ tensor, indices, updates,
 },
 }
 op := scope.AddOperation(opspec)
 return op.Output(0)
}

-// Returns the set of files matching one or more glob patterns.
+// IdentityReaderV2Attr is an optional argument to IdentityReaderV2.
+type IdentityReaderV2Attr func(optionalAttr)
+
+// IdentityReaderV2Container sets the optional container attribute to value.
 //
-// Note that this routine only supports wildcard characters in the
-// basename portion of the pattern, not in the directory portion.
-// Note also that the order of filenames returned can be non-deterministic.
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func IdentityReaderV2Container(value string) IdentityReaderV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// IdentityReaderV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func IdentityReaderV2SharedName(value string) IdentityReaderV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A Reader that outputs the queued work as both the key and value.
+//
+// To use, enqueue strings in a Queue. ReaderRead will take the front
+// work string and output (work, work).
+//
+// Returns The handle to reference the Reader.
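+//
+// For example, a minimal construction sketch using the Go wrappers (the shared
+// name is an illustrative placeholder, and the queue that actually feeds the
+// reader is omitted):
+//
+// ```go
+// s := op.NewScope()
+// reader := op.IdentityReaderV2(s, op.IdentityReaderV2SharedName("id_reader"))
+// // Pair `reader` with a string queue and a ReaderRead-style op to obtain
+// // (work, work) key/value pairs.
+// ```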
+func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "IdentityReaderV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizedMulAttr is an optional argument to QuantizedMul. +type QuantizedMulAttr func(optionalAttr) + +// QuantizedMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { + return func(m optionalAttr) { + m["Toutput"] = value + } +} + +// Returns x * y element-wise, working on quantized buffers. // // Arguments: -// pattern: Shell wildcard pattern(s). Scalar or vector of type string. // -// Returns A vector of matching filenames. -func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { +// +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// +// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedMul", + Input: []tf.Input{ + x, y, min_x, max_x, min_y, max_y, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// LoadTPUEmbeddingMomentumParametersAttr is an optional argument to LoadTPUEmbeddingMomentumParameters. +type LoadTPUEmbeddingMomentumParametersAttr func(optionalAttr) + +// LoadTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingMomentumParametersTableId(value int64) LoadTPUEmbeddingMomentumParametersAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingMomentumParametersTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func LoadTPUEmbeddingMomentumParametersTableName(value string) LoadTPUEmbeddingMomentumParametersAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// Load Momentum embedding parameters. +// +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. +// +// Arguments: +// parameters: Value of parameters used in the Momentum optimization algorithm. +// momenta: Value of momenta used in the Momentum optimization algorithm. +// +// +// +// Returns the created operation. 
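+//
+// For example, a minimal construction sketch (the parameter values, the
+// single-shard layout, and the table name below are illustrative placeholders,
+// not values mandated by this op):
+//
+// ```go
+// s := op.NewScope()
+// params := op.Const(s, [][]float32{{0.1, 0.2}, {0.3, 0.4}})
+// momenta := op.Const(s, [][]float32{{0.0, 0.0}, {0.0, 0.0}})
+// load := op.LoadTPUEmbeddingMomentumParameters(s, params, momenta, 1, 0,
+//   op.LoadTPUEmbeddingMomentumParametersTableName("table_0"))
+// // Run `load` as a target once, before entering the training loop.
+// ```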
+func LoadTPUEmbeddingMomentumParameters(scope *Scope, parameters tf.Output, momenta tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMomentumParametersAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LoadTPUEmbeddingMomentumParameters", + Input: []tf.Input{ + parameters, momenta, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// QuantizedAddAttr is an optional argument to QuantizedAdd. +type QuantizedAddAttr func(optionalAttr) + +// QuantizedAddToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { + return func(m optionalAttr) { + m["Toutput"] = value + } +} + +// Returns x + y element-wise, working on quantized buffers. +// +// Arguments: +// +// +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// +// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedAdd", + Input: []tf.Input{ + x, y, min_x, max_x, min_y, max_y, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Convert the quantized 'input' tensor into a lower-precision 'output', using the +// +// actual distribution of the values to maximize the usage of the lower bit depth +// and adjusting the output min and max ranges accordingly. +// +// [input_min, input_max] are scalar floats that specify the range for the float +// interpretation of the 'input' data. For example, if input_min is -1.0f and +// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 +// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. +// +// This operator tries to squeeze as much precision as possible into an output with +// a lower bit depth by calculating the actual min and max values found in the +// data. For example, maybe that quint16 input has no values lower than 16,384 and +// none higher than 49,152. That means only half the range is actually needed, all +// the float interpretations are between -0.5f and 0.5f, so if we want to compress +// the data into a quint8 output, we can use that range rather than the theoretical +// -1.0f to 1.0f that is suggested by the input min and max. 
+// +// In practice, this is most useful for taking output from operations like +// QuantizedMatMul that can produce higher bit-depth outputs than their inputs and +// may have large potential output ranges, but in practice have a distribution of +// input values that only uses a small fraction of the possible range. By feeding +// that output into this operator, we can reduce it from 32 bits down to 8 with +// minimal loss of accuracy. +// +// Arguments: +// +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// out_type: The type of the output. Should be a lower bit depth than Tinput. +// +// Returns The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + opspec := tf.OpSpec{ + Type: "QuantizeDownAndShrinkRange", + Input: []tf.Input{ + input, input_min, input_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Returns the name of the device on which `resource` has been placed. +func ExperimentalIteratorGetDevice(scope *Scope, resource tf.Output) (device tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatchingFiles", + Type: "ExperimentalIteratorGetDevice", Input: []tf.Input{ - pattern, + resource, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x - y element-wise. +// +// *NOTE*: `Subtract` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sub", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. +// +// Each comparison returns a boolean `true` (if `input_value > threshold`) +// or and `false` otherwise. +// +// This operation is useful for Locality-Sensitive-Hashing (LSH) and other +// algorithms that use hashing approximations of cosine and `L2` distances; +// codes can be generated from an input via: +// +// ```python +// codebook_size = 50 +// codebook_bits = codebook_size * 32 +// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], +// dtype=x.dtype, +// initializer=tf.orthogonal_initializer()) +// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) +// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 +// # now codes has shape x.shape[:-1] + [codebook_size] +// ``` +// +// **NOTE**: Currently, the innermost dimension of the tensor must be divisible +// by 8. +// +// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is +// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. +// +// Arguments: +// input: Values to compare against `threshold` and bitpack. +// threshold: Threshold to compare against. +// +// Returns The bitpacked comparisons. 
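+//
+// For example, a minimal sketch using the Go wrappers (the input values and
+// threshold are illustrative only):
+//
+// ```go
+// s := op.NewScope()
+// vals := op.Const(s, []float32{0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6})
+// bits := op.CompareAndBitpack(s, vals, op.Const(s, float32(0.5)))
+// // `bits` has shape [1]: the eight comparisons are packed into one uint8.
+// ```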
+func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CompareAndBitpack", + Input: []tf.Input{ + input, threshold, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Bucketizes 'input' based on 'boundaries'. +// +// For example, if the inputs are +// boundaries = [0, 10, 100] +// input = [[-5, 10000] +// [150, 10] +// [5, 100]] +// +// then the output will be +// output = [[0, 3] +// [3, 2] +// [1, 3]] +// +// Arguments: +// input: Any shape of Tensor contains with int or float type. +// boundaries: A sorted list of floats gives the boundary of the buckets. +// +// Returns Same shape with 'input', each value of input replaced with bucket index. +// +// @compatibility(numpy) +// Equivalent to np.digitize. +// @end_compatibility +func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"boundaries": boundaries} + opspec := tf.OpSpec{ + Type: "Bucketize", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StageSizeAttr is an optional argument to StageSize. +type StageSizeAttr func(optionalAttr) + +// StageSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageSizeCapacity(value int64) StageSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// StageSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageSizeMemoryLimit(value int64) StageSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// StageSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StageSizeContainer(value string) StageSizeAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// StageSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StageSizeSharedName(value string) StageSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of elements in the underlying container. +func StageSize(scope *Scope, dtypes []tf.DataType, optional ...StageSizeAttr) (size tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StageSize", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Clips tensor values to a specified min and max. +// +// Given a tensor `t`, this operation returns a tensor of the same type and +// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. +// Any values less than `clip_value_min` are set to `clip_value_min`. Any values +// greater than `clip_value_max` are set to `clip_value_max`. +// +// Arguments: +// t: A `Tensor`. +// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape +// as `t`. The minimum value to clip by. +// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape +// as `t`. The maximum value to clip by. +// +// Returns A clipped `Tensor` with the same shape as input 't'. 
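+//
+// For example, a minimal sketch using the Go wrappers (the values are
+// illustrative only):
+//
+// ```go
+// s := op.NewScope()
+// t := op.Const(s, []float32{-3, -1, 0, 2, 5})
+// clipped := op.ClipByValue(s, t, op.Const(s, float32(-1)), op.Const(s, float32(2)))
+// // Evaluating `clipped` yields [-1, -1, 0, 2, 2].
+// ```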
+func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ClipByValue",
+ Input: []tf.Input{
+ t, clip_value_min, clip_value_max,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes requantization range per channel.
+//
+// Arguments:
+// input: The original input tensor.
+// input_min: The minimum value of the input tensor.
+// input_max: The maximum value of the input tensor.
+// clip_value_max: The maximum value of the output that needs to be clipped.
+// Example: set this to 6 for Relu6.
+//
+// Returns The minimum value of the final output tensor.The maximum value of the final output tensor.
+func RequantizationRangePerChannel(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, clip_value_max float32) (output_min tf.Output, output_max tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"clip_value_max": clip_value_max}
+ opspec := tf.OpSpec{
+ Type: "RequantizationRangePerChannel",
+ Input: []tf.Input{
+ input, input_min, input_max,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Rolls the elements of a tensor along an axis.
+//
+// The elements are shifted positively (towards larger indices) by the offset of
+// `shift` along the dimension of `axis`. Negative `shift` values will shift
+// elements in the opposite direction. Elements that roll past the last position
+// will wrap around to the first and vice versa. Multiple shifts along multiple
+// axes may be specified.
+//
+// For example:
+//
+// ```
+// # 't' is [0, 1, 2, 3, 4]
+// roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2]
+//
+// # shifting along multiple dimensions
+// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+// roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]]
+//
+// # shifting along the same axis multiple times
+// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
+// roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
+// ```
+//
+// Arguments:
+//
+// shift: Dimension must be 0-D or 1-D. `shift[i]` specifies the number of places by which
+// elements are shifted positively (towards larger indices) along the dimension
+// specified by `axis[i]`. Negative shifts will roll the elements in the opposite
+// direction.
+// axis: Dimension must be 0-D or 1-D. `axis[i]` specifies the dimension in which the shift
+// `shift[i]` should occur. If the same axis is referenced more than once, the
+// total shift for that axis will be the sum of all the shifts that belong to that
+// axis.
+//
+// Returns Has the same shape and size as the input. The elements are shifted
+// positively (towards larger indices) by the offsets of `shift` along the
+// dimensions of `axis`.
+func Roll(scope *Scope, input tf.Output, shift tf.Output, axis tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Roll",
+ Input: []tf.Input{
+ input, shift, axis,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Looks up keys in a table, outputs the corresponding values.
+//
+// The tensor `keys` must be of the same type as the keys of the table.
+// The output `values` is of the type of the table values.
+//
+// The scalar `default_value` is the value output for keys not present in the
+// table. It must also be of the same type as the table values.
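+//
+// For example, a minimal sketch using the Go wrappers (the table contents and
+// default value are illustrative, and LookupTableInsertV2 is assumed from
+// elsewhere in this package):
+//
+// ```go
+// s := op.NewScope()
+// table := op.MutableHashTableV2(s, tf.String, tf.Int64)
+// insert := op.LookupTableInsertV2(s, table,
+//   op.Const(s, []string{"apple", "banana"}),
+//   op.Const(s, []int64{1, 2}))
+// values := op.LookupTableFindV2(s, table,
+//   op.Const(s, []string{"banana", "cherry"}),
+//   op.Const(s, int64(-1)))
+// // After running `insert`, fetching `values` yields [2, -1].
+// ```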
+// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// +// +// Returns Same shape as `keys`. Values found in the table, or `default_values` +// for missing keys. +func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableFindV2", + Input: []tf.Input{ + table_handle, keys, default_value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeBmpAttr is an optional argument to DecodeBmp. +type DecodeBmpAttr func(optionalAttr) + +// DecodeBmpChannels sets the optional channels attribute to value. +// If not specified, defaults to 0 +func DecodeBmpChannels(value int64) DecodeBmpAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// Decode the first frame of a BMP-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the BMP-encoded image. +// * 3: output an RGB image. +// * 4: output an RGBA image. +// +// Arguments: +// contents: 0-D. The BMP-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`. RGB order +func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeBmp", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Removes keys and its associated values from a table. +// +// The tensor `keys` must of the same type as the keys of the table. Keys not +// already in the table are silently ignored. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys of the elements to remove. +// +// Returns the created operation. +func LookupTableRemoveV2(scope *Scope, table_handle tf.Output, keys tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableRemoveV2", + Input: []tf.Input{ + table_handle, keys, + }, + } + return scope.AddOperation(opspec) +} + +// SerializeSparseAttr is an optional argument to SerializeSparse. +type SerializeSparseAttr func(optionalAttr) + +// SerializeSparseOutType sets the optional out_type attribute to value. +// +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`. +// If not specified, defaults to DT_STRING +func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { + return func(m optionalAttr) { + m["out_type"] = value + } +} + +// Serialize a `SparseTensor` into a `[3]` `Tensor` object. +// +// Arguments: +// sparse_indices: 2-D. The `indices` of the `SparseTensor`. +// sparse_values: 1-D. The `values` of the `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the `SparseTensor`. 
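+//
+// For example, a minimal sketch using the Go wrappers (the sparse tensor
+// contents are illustrative only):
+//
+// ```go
+// s := op.NewScope()
+// indices := op.Const(s, [][]int64{{0, 0}, {1, 2}})
+// values := op.Const(s, []int64{7, 8})
+// shape := op.Const(s, []int64{3, 4})
+// serialized := op.SerializeSparse(s, indices, values, shape)
+// // `serialized` is a 3-element DT_STRING vector; use SerializeSparseOutType
+// // to request a variant-typed serialization instead.
+// ```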
+func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SerializeSparse", + Input: []tf.Input{ + sparse_indices, sparse_values, sparse_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the number of elements in the given table. +// +// Arguments: +// table_handle: Handle to the table. +// +// Returns Scalar that contains number of elements in the table. +func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableSizeV2", + Input: []tf.Input{ + table_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. +type TensorArrayGatherV3Attr func(optionalAttr) + +// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Gather specific elements from the TensorArray into output `value`. +// +// All elements selected by `indices` must have the same shape. +// +// Arguments: +// handle: The handle to a TensorArray. +// indices: The locations in the TensorArray from which to read tensor elements. +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. +// +// Returns All of the elements in the TensorArray, concatenated along a new +// axis (the new dimension 0). +func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayGatherV3", + Input: []tf.Input{ + handle, indices, flow_in, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Converts the quantized `input` tensor into a lower-precision `output`. +// +// Converts the quantized `input` tensor into a lower-precision `output`, using the +// output range specified with `requested_output_min` and `requested_output_max`. +// +// `[input_min, input_max]` are scalar floats that specify the range for the float +// interpretation of the `input` data. For example, if `input_min` is -1.0f and +// `input_max` is 1.0f, and we are dealing with `quint16` quantized data, then a 0 +// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. +// +// Arguments: +// +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// requested_output_min: The float value that the minimum quantized output value represents. +// requested_output_max: The float value that the maximum quantized output value represents. 
+// out_type: The type of the output. Should be a lower bit depth than Tinput. +// +// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output. +func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + opspec := tf.OpSpec{ + Type: "Requantize", + Input: []tf.Input{ + input, input_min, input_max, requested_output_min, requested_output_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Reads and outputs the entire contents of the input filename. +func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReadFile", + Input: []tf.Input{ + filename, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// NonMaxSuppressionV4Attr is an optional argument to NonMaxSuppressionV4. +type NonMaxSuppressionV4Attr func(optionalAttr) + +// NonMaxSuppressionV4PadToMaxOutputSize sets the optional pad_to_max_output_size attribute to value. +// +// value: If true, the output `selected_indices` is padded to be of length +// `max_output_size`. Defaults to false. +// If not specified, defaults to false +func NonMaxSuppressionV4PadToMaxOutputSize(value bool) NonMaxSuppressionV4Attr { + return func(m optionalAttr) { + m["pad_to_max_output_size"] = value + } +} + +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes with score less than +// `score_threshold` are removed. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system and more +// generally is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// selected_indices = tf.image.non_max_suppression_v2( +// boxes, scores, max_output_size, iou_threshold, score_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) +// +// Arguments: +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. 
+// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove +// boxes based on score. +// +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`.A 0-D integer tensor representing the number of valid elements in +// `selected_indices`, with the valid elements appearing first. +func NonMaxSuppressionV4(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output, optional ...NonMaxSuppressionV4Attr) (selected_indices tf.Output, valid_outputs tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "NonMaxSuppressionV4", + Input: []tf.Input{ + boxes, scores, max_output_size, iou_threshold, score_threshold, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Computes the product along segments of a tensor. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// This operator is similar to the unsorted segment sum operator found +// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). +// Instead of computing the sum over segments, it computes the product of all +// entries belonging to a segment such that: +// +// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples +// `j...` such that `segment_ids[j...] == i`. +// +// For example: +// +// ``` python +// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) +// tf.unsorted_segment_prod(c, tf.constant([0, 1, 0]), num_segments=2) +// # ==> [[ 4, 6, 6, 4], +// # [5, 6, 7, 8]] +// ``` +// +// If there is no entry for a given segment ID `i`, it outputs 1. +// +// If the given segment ID `i` is negative, then the corresponding value is +// dropped, and will not be included in the result. +// +// Arguments: +// +// segment_ids: A tensor whose shape is a prefix of `data.shape`. +// +// +// Returns Has same shape as data, except for the first `segment_ids.rank` +// dimensions, which are replaced with a single dimension which has size +// `num_segments`. +func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "UnsortedSegmentProd", + Input: []tf.Input{ + data, segment_ids, num_segments, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Outputs all keys and values in the table. +// +// Arguments: +// table_handle: Handle to the table. +// +// +// +// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. +func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} + opspec := tf.OpSpec{ + Type: "LookupTableExportV2", + Input: []tf.Input{ + table_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. +type MutableHashTableV2Attr func(optionalAttr) + +// MutableHashTableV2Container sets the optional container attribute to value. 
+// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableV2Container(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes sigmoid of `x` element-wise. +// +// Specifically, `y = 1 / (1 + exp(-x))`. +func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sigmoid", + Input: []tf.Input{ + x, }, } op := scope.AddOperation(opspec) @@ -18335,2844 +35268,99 @@ func Squeeze(scope *Scope, input tf.Output, optional ...SqueezeAttr) (output tf. return op.Output(0) } -// ResourceApplyAdadeltaAttr is an optional argument to ResourceApplyAdadelta. -type ResourceApplyAdadeltaAttr func(optionalAttr) +// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2. +type MutableDenseHashTableV2Attr func(optionalAttr) -// ResourceApplyAdadeltaUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var, accum and update_accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyAdadeltaUseLocking(value bool) ResourceApplyAdadeltaAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the adadelta scheme. -// -// accum = rho() * accum + (1 - rho()) * grad.square(); -// update = (update_accum + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad; -// update_accum = rho() * update_accum + (1 - rho()) * update.square(); -// var -= update; -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// accum_update: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. 
-// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdadeltaAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAdadelta", - Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. -type NonMaxSuppressionAttr func(optionalAttr) - -// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. -// -// value: A float representing the threshold for deciding whether boxes -// overlap too much with respect to IOU. -// If not specified, defaults to 0.5 -func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { - return func(m optionalAttr) { - m["iou_threshold"] = value - } -} - -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) -// -// Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "NonMaxSuppression", - Input: []tf.Input{ - boxes, scores, max_output_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that emits `components` as a tuple of tensors once. 
-func TensorDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "TensorDataset", - Input: []tf.Input{ - tf.OutputList(components), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// VariableShapeAttr is an optional argument to VariableShape. -type VariableShapeAttr func(optionalAttr) - -// VariableShapeOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func VariableShapeOutType(value tf.DataType) VariableShapeAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Returns the shape of the variable pointed to by `resource`. -// -// This operation returns a 1-D integer tensor representing the shape of `input`. -// -// For example: -// -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// shape(t) ==> [2, 2, 3] -// ``` -func VariableShape(scope *Scope, input tf.Output, optional ...VariableShapeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "VariableShape", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Store the input tensor in the state of the current session. -// -// Arguments: -// value: The tensor to be stored. -// -// Returns The handle for the tensor stored in the session state, represented -// as a ResourceHandle object. -func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "GetSessionHandleV2", - Input: []tf.Input{ - value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. -type ResourceApplyAdamAttr func(optionalAttr) - -// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, uses the nesterov update. -// If not specified, defaults to false -func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update '*var' according to the Adam algorithm. -// -// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ -// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ -// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ -// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ -// -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// beta1_power: Must be a scalar. -// beta2_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. 
-func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAdam", - Input: []tf.Input{ - var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// SdcaOptimizerAttr is an optional argument to SdcaOptimizer. -type SdcaOptimizerAttr func(optionalAttr) - -// SdcaOptimizerAdaptative sets the optional adaptative attribute to value. -// -// value: Whether to use Adaptive SDCA for the inner loop. -// If not specified, defaults to true -func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { - return func(m optionalAttr) { - m["adaptative"] = value - } -} - -// Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for -// -// linear models with L1 + L2 regularization. As global optimization objective is -// strongly-convex, the optimizer optimizes the dual objective at each step. The -// optimizer applies each update one example at a time. Examples are sampled -// uniformly, and the optimizer is learning rate free and enjoys linear convergence -// rate. -// -// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).
-// Shai Shalev-Shwartz, Tong Zhang. 2012 -// -// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$ -// -// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
-// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, -// Peter Richtarik, Martin Takac. 2015 -// -// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
-// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 -// -// Arguments: -// sparse_example_indices: a list of vectors which contain example indices. -// sparse_feature_indices: a list of vectors which contain feature indices. -// sparse_feature_values: a list of vectors which contains feature value -// associated with each feature group. -// dense_features: a list of matrices which contains the dense feature values. -// example_weights: a vector which contains the weight associated with each -// example. -// example_labels: a vector which contains the label/target associated with each -// example. -// sparse_indices: a list of vectors where each value is the indices which has -// corresponding weights in sparse_weights. This field maybe omitted for the -// dense approach. -// sparse_weights: a list of vectors where each value is the weight associated with -// a sparse feature group. -// dense_weights: a list of vectors where the values are the weights associated -// with a dense feature group. -// example_state_data: a list of vectors containing the example state data. -// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, -// squared and hinge losses. -// l1: Symmetric l1 regularization strength. -// l2: Symmetric l2 regularization strength. -// num_loss_partitions: Number of partitions of the global loss function. -// num_inner_iterations: Number of iterations per mini-batch. -// -// Returns a list of vectors containing the updated example state -// data.a list of vectors where each value is the delta -// weights associated with a sparse feature group.a list of vectors where the values are the delta -// weights associated with a dense feature group. -func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SdcaOptimizer", - Input: []tf.Input{ - tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - out_example_state_data = op.Output(idx) - if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return - } - if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return - } - return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights -} - -// ExperimentalParseExampleDatasetAttr is an optional argument to 
ExperimentalParseExampleDataset. -type ExperimentalParseExampleDatasetAttr func(optionalAttr) - -// ExperimentalParseExampleDatasetSloppy sets the optional sloppy attribute to value. -// If not specified, defaults to false -func ExperimentalParseExampleDatasetSloppy(value bool) ExperimentalParseExampleDatasetAttr { - return func(m optionalAttr) { - m["sloppy"] = value - } -} - -// Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features. -// -// Arguments: -// -// -// dense_defaults: A dict mapping string keys to `Tensor`s. -// The keys of the dict must match the dense_keys of the feature. -// sparse_keys: A list of string keys in the examples features. -// The results for these keys will be returned as `SparseTensor` objects. -// dense_keys: A list of Ndense string Tensors (scalars). -// The keys expected in the Examples features associated with dense values. -// sparse_types: A list of `DTypes` of the same length as `sparse_keys`. -// Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), -// and `tf.string` (`BytesList`) are supported. -// dense_shapes: List of tuples with the same length as `dense_keys`. -// The shape of the data for each dense feature referenced by `dense_keys`. -// Required for any input tensors identified by `dense_keys`. Must be -// either fully defined, or may contain an unknown first dimension. -// An unknown first dimension means the feature is treated as having -// a variable number of blocks, and the output shape along this dimension -// is considered unknown at graph build time. Padding is applied for -// minibatch elements smaller than the maximum number of blocks for the -// given feature along this dimension. -// output_types: The type list for the return values. -// output_shapes: The list of shapes being produced. -func ExperimentalParseExampleDataset(scope *Scope, input_dataset tf.Output, num_parallel_calls tf.Output, dense_defaults []tf.Output, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ExperimentalParseExampleDatasetAttr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes, "output_types": output_types, "output_shapes": output_shapes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ExperimentalParseExampleDataset", - Input: []tf.Input{ - input_dataset, num_parallel_calls, tf.OutputList(dense_defaults), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 2D real-valued fast Fourier transform. -// -// Computes the 2-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 2 dimensions of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. -// -// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [2]. 
The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. -// -// @compatibility(numpy) -// Equivalent to np.fft.rfft2 -// @end_compatibility -func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RFFT2D", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. -type ResourceSparseApplyFtrlAttr func(optionalAttr) - -// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. -// -// That is for rows we have grad for, we update var, accum and linear as follows: -// accum_new = accum + grad * grad -// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. -// -// Returns the created operation. -func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrl", - Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, lr_power, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Returns which elements of x are Inf. -// -// @compatibility(numpy) -// Equivalent to np.isinf -// @end_compatibility -func IsInf(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IsInf", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Gather ragged slices from `params` axis `0` according to `indices`. 
-// -// Outputs a `RaggedTensor` output composed from `output_dense_values` and -// `output_nested_splits`, such that: -// -// ```python -// output.shape = indices.shape + params.shape[1:] -// output.ragged_rank = indices.shape.ndims + params.ragged_rank -// output[i...j, d0...dn] = params[indices[i...j], d0...dn] -// ``` -// -// where -// -// * `params = -// ragged.from_nested_row_splits(params_dense_values, params_nested_splits)` -// provides the values that should be gathered. -// * `indices` ia a dense tensor with dtype `int32` or `int64`, indicating which -// values should be gathered. -// * `output = -// ragged.from_nested_row_splits(output_dense_values, output_nested_splits)` -// is the output tensor. -// -// (Note: This c++ op is used to implement the higher-level python -// `tf.ragged.gather` op, which also supports ragged indices.) -// -// -// Arguments: -// params_nested_splits: The `nested_row_splits` tensors that define the row-partitioning for the -// `params` RaggedTensor input. -// params_dense_values: The `flat_values` for the `params` RaggedTensor. There was a terminology change -// at the python level from dense_values to flat_values, so dense_values is the -// deprecated name. -// indices: Indices in the outermost dimension of `params` of the values that should be -// gathered. -// OUTPUT_RAGGED_RANK: The ragged rank of the output RaggedTensor. `output_nested_splits` will contain -// this number of `row_splits` tensors. This value should equal -// `indices.shape.ndims + params.ragged_rank - 1`. -// -// Returns The `nested_row_splits` tensors that define the row-partitioning for the -// returned RaggedTensor.The `flat_values` for the returned RaggedTensor. -func RaggedGather(scope *Scope, params_nested_splits []tf.Output, params_dense_values tf.Output, indices tf.Output, OUTPUT_RAGGED_RANK int64) (output_nested_splits []tf.Output, output_dense_values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"OUTPUT_RAGGED_RANK": OUTPUT_RAGGED_RANK} - opspec := tf.OpSpec{ - Type: "RaggedGather", - Input: []tf.Input{ - tf.OutputList(params_nested_splits), params_dense_values, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output_nested_splits, idx, err = makeOutputList(op, idx, "output_nested_splits"); err != nil { - scope.UpdateErr("RaggedGather", err) - return - } - output_dense_values = op.Output(idx) - return output_nested_splits, output_dense_values -} - -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system. Note that this -// algorithm is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. 
The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// -// selected_indices = tf.image.non_max_suppression_v2( -// boxes, scores, max_output_size, iou_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) -// -// Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// iou_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too much with respect to IOU. -// -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output) (selected_indices tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NonMaxSuppressionV2", - Input: []tf.Input{ - boxes, scores, max_output_size, iou_threshold, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TruncatedNormalAttr is an optional argument to TruncatedNormal. -type TruncatedNormalAttr func(optionalAttr) - -// TruncatedNormalSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func TruncatedNormalSeed(value int64) TruncatedNormalAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// TruncatedNormalSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func TruncatedNormalSeed2(value int64) TruncatedNormalAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from a truncated normal distribution. -// -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. -// -// Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. -// -// Returns A tensor of the specified shape filled with random truncated normal -// values. -func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TruncatedNormal", - Input: []tf.Input{ - shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StringToNumberAttr is an optional argument to StringToNumber. -type StringToNumberAttr func(optionalAttr) - -// StringToNumberOutType sets the optional out_type attribute to value. -// -// value: The numeric type to interpret each string in `string_tensor` as. -// If not specified, defaults to DT_FLOAT -func StringToNumberOutType(value tf.DataType) StringToNumberAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Converts each string in the input Tensor to the specified numeric type. 
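
// Illustrative usage sketch for the NonMaxSuppressionV2 wrapper above (not
// part of the generated file). The threshold values and helper name are
// assumptions for the example; boxes is a [num_boxes, 4] float32 tensor and
// scores is a [num_boxes] float32 tensor, as described in the op comment.
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// selectBoxIndices returns the indices kept by non-max suppression; the
// indices can then be passed to a gather op, as in the Python snippet in the
// op comment above.
func selectBoxIndices(s *op.Scope, boxes, scores tf.Output) tf.Output {
	maxOutputSize := op.Const(s, int32(10))
	iouThreshold := op.Const(s, float32(0.5))
	return op.NonMaxSuppressionV2(s, boxes, scores, maxOutputSize, iouThreshold)
}
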
-// -// (Note that int32 overflow results in an error while float overflow -// results in a rounded value.) -// -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StringToNumber", - Input: []tf.Input{ - string_tensor, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2. -type ResourceApplyFtrlV2Attr func(optionalAttr) - -// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the Ftrl-proximal scheme. -// -// grad_with_shrinkage = grad + 2 * l2_shrinkage * var -// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage -// linear += grad_with_shrinkage + -// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regulariation. Must be a scalar. -// l2: L2 shrinkage regulariation. Must be a scalar. -// -// lr_power: Scaling factor. Must be a scalar. -// -// Returns the created operation. -func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyFtrlV2", - Input: []tf.Input{ - var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// SkipgramAttr is an optional argument to Skipgram. -type SkipgramAttr func(optionalAttr) - -// SkipgramWindowSize sets the optional window_size attribute to value. -// -// value: The number of words to predict to the left and right of the target. -// If not specified, defaults to 5 -func SkipgramWindowSize(value int64) SkipgramAttr { - return func(m optionalAttr) { - m["window_size"] = value - } -} - -// SkipgramMinCount sets the optional min_count attribute to value. -// -// value: The minimum number of word occurrences for it to be included in the -// vocabulary. -// If not specified, defaults to 5 -func SkipgramMinCount(value int64) SkipgramAttr { - return func(m optionalAttr) { - m["min_count"] = value - } -} - -// SkipgramSubsample sets the optional subsample attribute to value. -// -// value: Threshold for word occurrence. Words that appear with higher -// frequency will be randomly down-sampled. Set to 0 to disable. 
-// If not specified, defaults to 0.001 -func SkipgramSubsample(value float32) SkipgramAttr { - return func(m optionalAttr) { - m["subsample"] = value - } -} - -// Parses a text file and creates a batch of examples. -// -// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result -// -// Arguments: -// filename: The corpus's text file name. -// batch_size: The size of produced batch. -// -// Returns A vector of words in the corpus.Frequencies of words. Sorted in the non-ascending order.Number of words per epoch in the data file.The current epoch number.The total number of words processed so far.A vector of word ids.A vector of word ids. -func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Skipgram", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) -} - -// ThreadUnsafeUnigramCandidateSamplerAttr is an optional argument to ThreadUnsafeUnigramCandidateSampler. -type ThreadUnsafeUnigramCandidateSamplerAttr func(optionalAttr) - -// ThreadUnsafeUnigramCandidateSamplerSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ThreadUnsafeUnigramCandidateSamplerSeed(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// ThreadUnsafeUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func ThreadUnsafeUnigramCandidateSamplerSeed2(value int64) ThreadUnsafeUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Generates labels for candidate sampling with a learned unigram distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. -// -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. -// -// Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). 
-// -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func ThreadUnsafeUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...ThreadUnsafeUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ThreadUnsafeUnigramCandidateSampler", - Input: []tf.Input{ - true_classes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// MaxPoolV2Attr is an optional argument to MaxPoolV2. -type MaxPoolV2Attr func(optionalAttr) - -// MaxPoolV2DataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolV2DataFormat(value string) MaxPoolV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs max pooling on the input. -// -// Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor. -func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolV2Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolV2", - Input: []tf.Input{ - input, ksize, strides, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Does nothing. Serves as a control trigger for scheduling. -// -// Only useful as a placeholder for control edges. -// -// Returns the created operation. -func ControlTrigger(scope *Scope) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ControlTrigger", - } - return scope.AddOperation(opspec) -} - -// Deprecated. 
Use TensorArrayReadV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 -func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - opspec := tf.OpSpec{ - Type: "TensorArrayReadV2", - Input: []tf.Input{ - handle, index, flow_in, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Batch normalization. -// -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() -// -// This op is deprecated. Prefer `tf.nn.batch_normalization`. -// -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// beta: A 1D beta Tensor with size matching the last dimension of t. -// An offset to be added to the normalized tensor. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this tensor will be multiplied -// with the normalized tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} - opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalization", - Input: []tf.Input{ - t, m, v, beta, gamma, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap. -type AddManySparseToTensorsMapAttr func(optionalAttr) - -// AddManySparseToTensorsMapContainer sets the optional container attribute to value. -// -// value: The container name for the `SparseTensorsMap` created by this op. -// If not specified, defaults to "" -func AddManySparseToTensorsMapContainer(value string) AddManySparseToTensorsMapAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// AddManySparseToTensorsMapSharedName sets the optional shared_name attribute to value. -// -// value: The shared name for the `SparseTensorsMap` created by this op. -// If blank, the new Operation's unique name is used. -// If not specified, defaults to "" -func AddManySparseToTensorsMapSharedName(value string) AddManySparseToTensorsMapAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. -// -// A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`, -// `sparse_values`, and `sparse_shape`, where -// -// ```sparse_indices.shape[1] == sparse_shape.shape[0] == R``` -// -// An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor` -// having a first `sparse_indices` column taking values between `[0, N)`, where -// the minibatch size `N == sparse_shape[0]`. 
-// -// The input `SparseTensor` must have rank `R` greater than 1, and the first -// dimension is treated as the minibatch dimension. Elements of the `SparseTensor` -// must be sorted in increasing order of this first dimension. The stored -// `SparseTensor` objects pointed to by each row of the output `sparse_handles` -// will have rank `R-1`. -// -// The `SparseTensor` values can then be read out as part of a minibatch by passing -// the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure -// the correct `SparseTensorsMap` is accessed, ensure that the same -// `container` and `shared_name` are passed to that Op. If no `shared_name` -// is provided here, instead use the *name* of the Operation created by calling -// `AddManySparseToTensorsMap` as the `shared_name` passed to -// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. -// -// Arguments: -// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. -// `sparse_indices[:, 0]` must be ordered values in `[0, N)`. -// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. -// The minibatch size `N == sparse_shape[0]`. -// -// Returns 1-D. The handles of the `SparseTensor` now stored in the -// `SparseTensorsMap`. Shape: `[N]`. -func AddManySparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddManySparseToTensorsMapAttr) (sparse_handles tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AddManySparseToTensorsMap", - Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TPUReplicateMetadataAttr is an optional argument to TPUReplicateMetadata. -type TPUReplicateMetadataAttr func(optionalAttr) - -// TPUReplicateMetadataNumCoresPerReplica sets the optional num_cores_per_replica attribute to value. -// -// value: Number of cores per replica. Used for model parallelism. -// If not specified, defaults to 1 -func TPUReplicateMetadataNumCoresPerReplica(value int64) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["num_cores_per_replica"] = value - } -} - -// TPUReplicateMetadataTopology sets the optional topology attribute to value. -// -// value: TopologyProto indicating the topology of the TPU pod slice. -// If not specified, defaults to "" -func TPUReplicateMetadataTopology(value string) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["topology"] = value - } -} - -// TPUReplicateMetadataUseTpu sets the optional use_tpu attribute to value. -// -// value: Whether to place the computation on the TPU. -// If not specified, defaults to true -func TPUReplicateMetadataUseTpu(value bool) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["use_tpu"] = value - } -} - -// TPUReplicateMetadataDeviceAssignment sets the optional device_assignment attribute to value. -// -// value: The assignment of devices for the computation. -// If not specified, defaults to <> -func TPUReplicateMetadataDeviceAssignment(value []int64) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["device_assignment"] = value - } -} - -// TPUReplicateMetadataComputationShape sets the optional computation_shape attribute to value. -// -// value: DEPRECATED. Use num_cores_per_replica instead. 
-// If not specified, defaults to <> -func TPUReplicateMetadataComputationShape(value []int64) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["computation_shape"] = value - } -} - -// TPUReplicateMetadataHostComputeCore sets the optional host_compute_core attribute to value. -// If not specified, defaults to <> -func TPUReplicateMetadataHostComputeCore(value []string) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["host_compute_core"] = value - } -} - -// TPUReplicateMetadataPaddingMap sets the optional padding_map attribute to value. -// If not specified, defaults to <> -func TPUReplicateMetadataPaddingMap(value []string) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["padding_map"] = value - } -} - -// TPUReplicateMetadataStepMarkerLocation sets the optional step_marker_location attribute to value. -// If not specified, defaults to "STEP_MARK_AT_ENTRY" -func TPUReplicateMetadataStepMarkerLocation(value string) TPUReplicateMetadataAttr { - return func(m optionalAttr) { - m["step_marker_location"] = value - } -} - -// Metadata indicaitng how the TPU computation should be replicated. -// -// Arguments: -// num_replicas: Number of replicas of the computation -// -// Returns the created operation. -func TPUReplicateMetadata(scope *Scope, num_replicas int64, optional ...TPUReplicateMetadataAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_replicas": num_replicas} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TPUReplicateMetadata", - - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingFTRLParametersGradAccumDebug. -type LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) - -// LoadTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingFTRLParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingFTRLParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load FTRL embedding parameters with debug support. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the FTRL optimization algorithm. -// accumulators: Value of accumulators used in the FTRL optimization algorithm. -// linears: Value of linears used in the FTRL optimization algorithm. -// gradient_accumulators: Value of gradient_accumulators used in the FTRL optimization algorithm. -// -// -// -// Returns the created operation. 
-func LoadTPUEmbeddingFTRLParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, linears tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingFTRLParametersGradAccumDebugAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingFTRLParametersGradAccumDebug", - Input: []tf.Input{ - parameters, accumulators, linears, gradient_accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Concatenates tensors along one dimension. -// -// Arguments: -// values: List of `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// axis: 0-D. The dimension along which to concatenate. Must be in the -// range [-rank(values), rank(values)). -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConcatV2", - Input: []tf.Input{ - tf.OutputList(values), axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reads and outputs the entire contents of the input filename. -func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReadFile", - Input: []tf.Input{ - filename, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AvgPoolGradAttr is an optional argument to AvgPoolGrad. -type AvgPoolGradAttr func(optionalAttr) - -// AvgPoolGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of the average pooling function. -// -// Arguments: -// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. -// the output of `avg_pool`. -// ksize: The size of the sliding window for each dimension of the input. -// strides: The stride of the sliding window for each dimension of the input. -// padding: The type of padding algorithm to use. -// -// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. 
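
// Illustrative usage sketch for the ConcatV2 and ReadFile wrappers above (not
// part of the generated file). The helper names are assumptions; each takes
// an already-created *op.Scope from the same package as the wrappers.
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// concatRows stacks two tensors along axis 0; their shapes must match in all
// other dimensions, as the ConcatV2 comment above notes.
func concatRows(s *op.Scope, a, b tf.Output) tf.Output {
	axis := op.Const(s, int32(0))
	return op.ConcatV2(s, []tf.Output{a, b}, axis)
}

// fileContents reads an entire file into a scalar string tensor via ReadFile.
func fileContents(s *op.Scope, path string) tf.Output {
	return op.ReadFile(s, op.Const(s, path))
}
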
-func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AvgPoolGrad", - Input: []tf.Input{ - orig_input_shape, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high overlaps -// with previously selected boxes. Bounding boxes with score less than -// `score_threshold` are removed. N-by-n overlap values are supplied as square matrix, -// which allows for defining a custom overlap criterium (eg. intersection over union, -// intersection over area, etc.). -// -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// -// selected_indices = tf.image.non_max_suppression_with_overlaps( -// overlaps, scores, max_output_size, overlap_threshold, score_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) -// -// Arguments: -// overlaps: A 2-D float tensor of shape `[num_boxes, num_boxes]` representing -// the n-by-n box overlap values. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// overlap_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too. -// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove -// boxes based on score. -// -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppressionWithOverlaps(scope *Scope, overlaps tf.Output, scores tf.Output, max_output_size tf.Output, overlap_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NonMaxSuppressionWithOverlaps", - Input: []tf.Input{ - overlaps, scores, max_output_size, overlap_threshold, score_threshold, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FractionalAvgPoolGradAttr is an optional argument to FractionalAvgPoolGrad. -type FractionalAvgPoolGradAttr func(optionalAttr) - -// FractionalAvgPoolGradOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: -// -// `index 0 1 2 3 4` -// -// `value 20 5 16 3 7` -// -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [41/3, 26/3] for fractional avg pooling. -// If not specified, defaults to false -func FractionalAvgPoolGradOverlapping(value bool) FractionalAvgPoolGradAttr { - return func(m optionalAttr) { - m["overlapping"] = value - } -} - -// Computes gradient of the FractionalAvgPool function. 
-// -// Unlike FractionalMaxPoolGrad, we don't need to find arg_max for -// FractionalAvgPoolGrad, we just need to evenly back-propagate each element of -// out_backprop to those indices that form the same pooling cell. Therefore, we -// just need to know the shape of original input tensor, instead of the whole -// tensor. -// -// Arguments: -// orig_input_tensor_shape: Original input tensor shape for `fractional_avg_pool` -// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients -// w.r.t. the output of `fractional_avg_pool`. -// row_pooling_sequence: row pooling sequence, form pooling region with -// col_pooling_sequence. -// col_pooling_sequence: column pooling sequence, form pooling region with -// row_pooling sequence. -// -// Returns 4-D. Gradients w.r.t. the input of `fractional_avg_pool`. -func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalAvgPoolGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FractionalAvgPoolGrad", - Input: []tf.Input{ - orig_input_tensor_shape, out_backprop, row_pooling_sequence, col_pooling_sequence, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StaticRegexReplaceAttr is an optional argument to StaticRegexReplace. -type StaticRegexReplaceAttr func(optionalAttr) - -// StaticRegexReplaceReplaceGlobal sets the optional replace_global attribute to value. -// -// value: If True, the replacement is global, otherwise the replacement -// is done only on the first match. -// If not specified, defaults to true -func StaticRegexReplaceReplaceGlobal(value bool) StaticRegexReplaceAttr { - return func(m optionalAttr) { - m["replace_global"] = value - } -} - -// Replaces the match of pattern in input with rewrite. -// -// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) -// -// Arguments: -// input: The text to be processed. -// pattern: The regular expression to match the input. -// rewrite: The rewrite to be applied to the matched expression. -// -// Returns The text after applying pattern and rewrite. -func StaticRegexReplace(scope *Scope, input tf.Output, pattern string, rewrite string, optional ...StaticRegexReplaceAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"pattern": pattern, "rewrite": rewrite} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StaticRegexReplace", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes gradients for the exponential linear (Elu) operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Elu operation. -// outputs: The outputs of the corresponding Elu operation. -// -// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, -// `gradients` otherwise. -func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "EluGrad", - Input: []tf.Input{ - gradients, outputs, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Converts each string in the input Tensor to its hash mod by a number of buckets. 
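
// Illustrative usage sketch for the StaticRegexReplace wrapper above (not part
// of the generated file). The pattern and rewrite are made-up examples; the op
// follows RE2 syntax, as its comment notes.
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// maskDigits replaces every run of digits in each input string with "#".
func maskDigits(s *op.Scope, lines tf.Output) tf.Output {
	return op.StaticRegexReplace(s, lines, `[0-9]+`, "#",
		op.StaticRegexReplaceReplaceGlobal(true))
}
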
-// -// The hash function is deterministic on the content of the string within the -// process. -// -// Note that the hash function may change from time to time. -// This functionality will be deprecated and it's recommended to use -// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. -// -// Arguments: -// -// num_buckets: The number of buckets. -// -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_buckets": num_buckets} - opspec := tf.OpSpec{ - Type: "StringToHashBucket", - Input: []tf.Input{ - string_tensor, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that batches `batch_size` elements from `input_dataset`. -// -// Arguments: -// -// batch_size: A scalar representing the number of elements to accumulate in a batch. -// drop_remainder: A scalar representing whether the last batch should be dropped in case its size -// is smaller than desired. -// -// -func BatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "BatchDatasetV2", - Input: []tf.Input{ - input_dataset, batch_size, drop_remainder, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient of `igamma(a, x)` wrt `a`. -func IgammaGradA(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IgammaGradA", - Input: []tf.Input{ - a, x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that contains `count` elements from the `input_dataset`. -// -// Arguments: -// -// count: A scalar representing the number of elements from the `input_dataset` -// that should be taken. A value of `-1` indicates that all of `input_dataset` -// is taken. -// -// -func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "TakeDataset", - Input: []tf.Input{ - input_dataset, count, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FakeQuantWithMinMaxVarsAttr is an optional argument to FakeQuantWithMinMaxVars. -type FakeQuantWithMinMaxVarsAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsNumBits(value int64) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsNarrowRange sets the optional narrow_range attribute to value. 
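
// Illustrative usage sketch for the StringToHashBucket wrapper above (not part
// of the generated file). The bucket count is arbitrary; note the op comment's
// advice to prefer the *_fast or *_strong variants in new code.
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// hashWords maps each input string to one of 1000 buckets.
func hashWords(s *op.Scope, words tf.Output) tf.Output {
	return op.StringToHashBucket(s, words, 1000)
}
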
-// If not specified, defaults to false -func FakeQuantWithMinMaxVarsNarrowRange(value bool) FakeQuantWithMinMaxVarsAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` -// -// and `max` to 'outputs' tensor of same shape as `inputs`. -// -// `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. -// -// This operation has a gradient and thus allows for training `min` and `max` -// values. -func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsAttr) (outputs tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVars", - Input: []tf.Input{ - inputs, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingMomentumParametersAttr is an optional argument to RetrieveTPUEmbeddingMomentumParameters. -type RetrieveTPUEmbeddingMomentumParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingMomentumParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingMomentumParametersTableId(value int64) RetrieveTPUEmbeddingMomentumParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingMomentumParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingMomentumParametersTableName(value string) RetrieveTPUEmbeddingMomentumParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve Momentum embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the Momentum optimization algorithm.Parameter momenta updated by the Momentum optimization algorithm. -func RetrieveTPUEmbeddingMomentumParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersAttr) (parameters tf.Output, momenta tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingMomentumParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Forwards the value of an available tensor from `inputs` to `output`. -// -// `Merge` waits for at least one of the tensors in `inputs` to become available. -// It is usually combined with `Switch` to implement branching. -// -// `Merge` forwards the first tensor to become available to `output`, and sets -// `value_index` to its index in `inputs`. 
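
// Illustrative usage sketch for the FakeQuantWithMinMaxVars wrapper above (not
// part of the generated file). The clamping range [-6, 6] and the 8-bit width
// are assumptions chosen for the example.
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// fakeQuant8 quantizes and de-quantizes x into the [-6, 6] range with 8 bits.
func fakeQuant8(s *op.Scope, x tf.Output) tf.Output {
	lo := op.Const(s, float32(-6))
	hi := op.Const(s, float32(6))
	return op.FakeQuantWithMinMaxVars(s, x, lo, hi,
		op.FakeQuantWithMinMaxVarsNumBits(8))
}
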
-// -// Arguments: -// inputs: The input tensors, exactly one of which will become available. -// -// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. -func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Merge", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// QueueCloseV2Attr is an optional argument to QueueCloseV2. -type QueueCloseV2Attr func(optionalAttr) - -// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. -// -// value: If true, all pending enqueue requests that are -// blocked on the given queue will be canceled. -// If not specified, defaults to false -func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { - return func(m optionalAttr) { - m["cancel_pending_enqueues"] = value - } -} - -// Closes the given queue. -// -// This operation signals that no more elements will be enqueued in the -// given queue. Subsequent Enqueue(Many) operations will fail. -// Subsequent Dequeue(Many) operations will continue to succeed if -// sufficient elements remain in the queue. Subsequent Dequeue(Many) -// operations that would block will fail immediately. -// -// Arguments: -// handle: The handle to a queue. -// -// Returns the created operation. -func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueCloseV2", - Input: []tf.Input{ - handle, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Writes the given dataset to the given file using the TFRecord format. -// -// Arguments: -// input_dataset: A variant tensor representing the dataset to write. -// filename: A scalar string tensor representing the filename to use. -// compression_type: A scalar string tensor containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// -// Returns the created operation. -func ExperimentalDatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ExperimentalDatasetToTFRecord", - Input: []tf.Input{ - input_dataset, filename, compression_type, - }, - } - return scope.AddOperation(opspec) -} - -// BiasAddGradAttr is an optional argument to BiasAddGrad. -type BiasAddGradAttr func(optionalAttr) - -// BiasAddGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the bias tensor will be added to the last dimension -// of the value tensor. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// The tensor will be added to "in_channels", the third-to-the-last -// dimension. -// If not specified, defaults to "NHWC" -func BiasAddGradDataFormat(value string) BiasAddGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// The backward operation for "BiasAdd" on the "bias" tensor. -// -// It accumulates all the values from out_backprop into the feature dimension. -// For NHWC data format, the feature dimension is the last. 
For NCHW data format, -// the feature dimension is the third-to-last. -// -// Arguments: -// out_backprop: Any number of dimensions. -// -// Returns 1-D with size the feature dimension of `out_backprop`. -func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "BiasAddGrad", - Input: []tf.Input{ - out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reduces `input` from `num_devices` using `reduction` to a single device. -// -// Reduces `input` from `num_devices` using `reduction` to a single device. -// -// The graph should be constructed so that all inputs have a valid device -// assignment, and the op itself is assigned one of these devices. -// -// input: The input to the reduction. -// data: the value of the reduction across all `num_devices` devices. -// reduction: the reduction operation to perform. -func NcclReduce(scope *Scope, input []tf.Output, reduction string) (data tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"reduction": reduction} - opspec := tf.OpSpec{ - Type: "NcclReduce", - Input: []tf.Input{ - tf.OutputList(input), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient of morphological 2-D dilation with respect to the input. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. -// strides: 1-D of length 4. The stride of the sliding window for each dimension of -// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: 1-D of length 4. The input stride for atrous morphological dilation. -// Must be: `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. -// -// Returns 4-D with shape `[batch, in_height, in_width, depth]`. -func Dilation2DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (in_backprop tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} - opspec := tf.OpSpec{ - Type: "Dilation2DBackpropInput", - Input: []tf.Input{ - input, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// An Op to sum inputs across replicated TPU instances. -// -// Each instance supplies its own input. -// -// For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`. -// Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, -// and `B, D, F, H` as group 1. Thus we get the outputs: -// `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. -// -// Arguments: -// input: The local input to the sum. -// group_assignment: An int32 tensor with shape -// [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the -// replica ids in the ith subgroup. -// -// Returns The sum of all the distributed inputs. 
-func CrossReplicaSum(scope *Scope, input tf.Output, group_assignment tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CrossReplicaSum", - Input: []tf.Input{ - input, group_assignment, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum. -type ResourceSparseApplyMomentumAttr func(optionalAttr) - -// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var - lr * momentum * accum, so in the end, the var you get is actually -// var - lr * momentum * accum. -// If not specified, defaults to false -func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update relevant entries in '*var' and '*accum' according to the momentum scheme. -// -// Set use_nesterov = True if you want to use Nesterov momentum. -// -// That is for rows we have grad for, we update var and accum as follows: -// -// accum = accum * momentum + grad -// var -= lr * accum -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// momentum: Momentum. Must be a scalar. -// -// Returns the created operation. -func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyMomentum", - Input: []tf.Input{ - var_, accum, lr, grad, indices, momentum, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// An Op to permute tensors across replicated TPU instances. -// -// Each instance supplies its own input. -// -// For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing -// source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: -// `[D, A, B, C]`. -// -// Arguments: -// input: The local input to be permuted. Currently only supports float and -// bfloat16. -// source_target_pairs: A tensor with shape [num_pairs, 2]. -// -// Returns The permuted input. -func CollectivePermute(scope *Scope, input tf.Output, source_target_pairs tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CollectivePermute", - Input: []tf.Input{ - input, source_target_pairs, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the complex conjugate of a complex number. 
-// -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// complex numbers that are the complex conjugate of each element in `input`. The -// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the -// real part and *b* is the imaginary part. -// -// The complex conjugate returned by this operation is of the form \\(a - bj\\). -// -// For example: -// -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] -// ``` -func Conj(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Conj", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingCenteredRMSPropParametersAttr is an optional argument to RetrieveTPUEmbeddingCenteredRMSPropParameters. -type RetrieveTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingCenteredRMSPropParametersTableId(value int64) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingCenteredRMSPropParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingCenteredRMSPropParametersTableName(value string) RetrieveTPUEmbeddingCenteredRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve centered RMSProp embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the centered RMSProp optimization algorithm.Parameter ms updated by the centered RMSProp optimization algorithm.Parameter mom updated by the centered RMSProp optimization algorithm.Parameter mg updated by the centered RMSProp optimization algorithm. -func RetrieveTPUEmbeddingCenteredRMSPropParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingCenteredRMSPropParametersAttr) (parameters tf.Output, ms tf.Output, mom tf.Output, mg tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingCenteredRMSPropParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// StringSplitAttr is an optional argument to StringSplit. -type StringSplitAttr func(optionalAttr) - -// StringSplitSkipEmpty sets the optional skip_empty attribute to value. -// -// value: A `bool`. If `True`, skip the empty strings from the result. -// If not specified, defaults to true -func StringSplitSkipEmpty(value bool) StringSplitAttr { - return func(m optionalAttr) { - m["skip_empty"] = value - } -} - -// Split elements of `input` based on `delimiter` into a `SparseTensor`. -// -// Let N be the size of source (typically N will be the batch size). 
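
// Illustrative usage sketch for the Conj wrapper above (not part of the
// generated file). It mirrors the example in the op comment and assumes the
// Go bindings accept []complex64 values through op.Const.
package example

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// conjugateExample builds [-2.25+4.75i, 3.25+5.75i] and returns its
// element-wise complex conjugate, [-2.25-4.75i, 3.25-5.75i].
func conjugateExample(s *op.Scope) tf.Output {
	x := op.Const(s, []complex64{complex(-2.25, 4.75), complex(3.25, 5.75)})
	return op.Conj(s, x)
}
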
Split each -// element of `input` based on `delimiter` and return a `SparseTensor` -// containing the splitted tokens. Empty tokens are ignored. -// -// `delimiter` can be empty, or a string of split characters. If `delimiter` is an -// empty string, each element of `input` is split into individual single-byte -// character strings, including splitting of UTF-8 multibyte sequences. Otherwise -// every character of `delimiter` is a potential split point. -// -// For example: -// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output -// will be -// -// indices = [0, 0; -// 0, 1; -// 1, 0; -// 1, 1; -// 1, 2] -// shape = [2, 3] -// values = ['hello', 'world', 'a', 'b', 'c'] -// -// Arguments: -// input: 1-D. Strings to split. -// delimiter: 0-D. Delimiter characters (bytes), or empty string. -// -// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse -// tensor, where the first value is N and the second value is the maximum number -// of tokens in a single input entry. -func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StringSplit", - Input: []tf.Input{ - input, delimiter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// MaxPool3DAttr is an optional argument to MaxPool3D. -type MaxPool3DAttr func(optionalAttr) - -// MaxPool3DDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DDataFormat(value string) MaxPool3DAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs 3D max pooling on the input. -// -// Arguments: -// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor. -func MaxPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool3D", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Convert JSON-encoded Example records to binary protocol buffer strings. 
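The StringSplit wrapper above returns the three pieces of a `SparseTensor` (indices, values, dense shape). As a rough, non-authoritative sketch of how it could be driven from the Go client, assuming the usual `tensorflow/go` session plumbing (`op.NewScope`, `op.Const`, `Scope.Finalize`, `tf.NewSession`, `Session.Run`) that is not shown in this excerpt:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// N = 2 input strings, split on a single-space delimiter as in the example above.
	input := op.Const(s, []string{"hello world", "a b c"})
	delimiter := op.Const(s, " ")
	indices, values, shape := op.StringSplit(s, input, delimiter)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{indices, values, shape}, nil)
	if err != nil {
		panic(err)
	}
	// Expected per the docs above: shape [2 3], values [hello world a b c].
	fmt.Println(out[0].Value()) // sparse indices
	fmt.Println(out[1].Value()) // split values
	fmt.Println(out[2].Value()) // dense shape
}
```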
-// -// This op translates a tensor containing Example records, encoded using -// the [standard JSON -// mapping](https://developers.google.com/protocol-buffers/docs/proto3#json), -// into a tensor containing the same records encoded as binary protocol -// buffers. The resulting tensor can then be fed to any of the other -// Example-parsing ops. -// -// Arguments: -// json_examples: Each string is a JSON object serialized according to the JSON -// mapping of the Example proto. -// -// Returns Each string is a binary Example protocol buffer corresponding -// to the respective element of `json_examples`. -func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DecodeJSONExample", - Input: []tf.Input{ - json_examples, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2. -type QueueEnqueueManyV2Attr func(optionalAttr) - -// QueueEnqueueManyV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue is too full, this operation will block for up -// to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueEnqueueManyV2TimeoutMs(value int64) QueueEnqueueManyV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } -} - -// Enqueues zero or more tuples of one or more tensors in the given queue. -// -// This operation slices each component tensor along the 0th dimension to -// make multiple queue elements. All of the tuple components must have the -// same size in the 0th dimension. -// -// The components input has k elements, which correspond to the components of -// tuples stored in the given queue. -// -// N.B. If the queue is full, this operation will block until the given -// elements have been enqueued (or 'timeout_ms' elapses, if specified). -// -// Arguments: -// handle: The handle to a queue. -// components: One or more tensors from which the enqueued tensors should -// be taken. -// -// Returns the created operation. -func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueManyV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueEnqueueManyV2", - Input: []tf.Input{ - handle, tf.OutputList(components), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// PrintV2Attr is an optional argument to PrintV2. -type PrintV2Attr func(optionalAttr) - -// PrintV2OutputStream sets the optional output_stream attribute to value. -// -// value: A string specifying the output stream or logging level to print to. -// If not specified, defaults to "stderr" -func PrintV2OutputStream(value string) PrintV2Attr { - return func(m optionalAttr) { - m["output_stream"] = value - } -} - -// Prints a string scalar. -// -// Prints a string scalar to the desired output_stream. -// -// Arguments: -// input: The string scalar to print. -// -// Returns the created operation. 
-func PrintV2(scope *Scope, input tf.Output, optional ...PrintV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "PrintV2", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// The gradient operator for the SparseSlice op. -// -// This op takes in the upstream gradient w.r.t. non-empty values of -// the sliced `SparseTensor`, and outputs the gradients w.r.t. -// the non-empty values of input `SparseTensor`. -// -// Arguments: -// backprop_val_grad: 1-D. The gradient with respect to -// the non-empty values of the sliced `SparseTensor`. -// input_indices: 2-D. The `indices` of the input `SparseTensor`. -// input_start: 1-D. tensor represents the start of the slice. -// output_indices: 2-D. The `indices` of the sliced `SparseTensor`. -// -// Returns 1-D. The gradient with respect to the non-empty values of input `SparseTensor`. -func SparseSliceGrad(scope *Scope, backprop_val_grad tf.Output, input_indices tf.Output, input_start tf.Output, output_indices tf.Output) (val_grad tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSliceGrad", - Input: []tf.Input{ - backprop_val_grad, input_indices, input_start, output_indices, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset by applying optimizations to `input_dataset`. -// -// Creates a dataset by applying optimizations to `input_dataset`. -// -// Arguments: -// input_dataset: A variant tensor representing the input dataset. -// optimizations: A `tf.string` vector `tf.Tensor` identifying optimizations to use. -// -// -func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "OptimizeDataset", - Input: []tf.Input{ - input_dataset, optimizations, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. -type ResourceApplyProximalAdagradAttr func(optionalAttr) - -// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. -// -// accum += grad * grad -// prox_v = var - lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. 
-func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyProximalAdagrad", - Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. -type MutableHashTableOfTensorsV2Attr func(optionalAttr) - -// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. +// MutableDenseHashTableV2Container sets the optional container attribute to value. // // value: If non-empty, this table is placed in the given container. // Otherwise, a default container is used. // If not specified, defaults to "" -func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { +func MutableDenseHashTableV2Container(value string) MutableDenseHashTableV2Attr { return func(m optionalAttr) { m["container"] = value } } -// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. +// MutableDenseHashTableV2SharedName sets the optional shared_name attribute to value. // // value: If non-empty, this table is shared under the given name across // multiple sessions. // If not specified, defaults to "" -func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { +func MutableDenseHashTableV2SharedName(value string) MutableDenseHashTableV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// MutableDenseHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. // If not specified, defaults to false -func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { +func MutableDenseHashTableV2UseNodeNameSharing(value bool) MutableDenseHashTableV2Attr { return func(m optionalAttr) { m["use_node_name_sharing"] = value } } -// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. +// MutableDenseHashTableV2ValueShape sets the optional value_shape attribute to value. +// +// value: The shape of each value. // If not specified, defaults to <> -func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { +func MutableDenseHashTableV2ValueShape(value tf.Shape) MutableDenseHashTableV2Attr { return func(m optionalAttr) { m["value_shape"] = value } } -// Creates an empty hash table. +// MutableDenseHashTableV2InitialNumBuckets sets the optional initial_num_buckets attribute to value. +// +// value: The initial number of hash table buckets. Must be a power +// to 2. +// If not specified, defaults to 131072 +func MutableDenseHashTableV2InitialNumBuckets(value int64) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["initial_num_buckets"] = value + } +} + +// MutableDenseHashTableV2MaxLoadFactor sets the optional max_load_factor attribute to value. +// +// value: The maximum ratio between number of entries and number of +// buckets before growing the table. Must be between 0 and 1. 
+// If not specified, defaults to 0.8 +func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2Attr { + return func(m optionalAttr) { + m["max_load_factor"] = value + } +} + +// Creates an empty hash table that uses tensors as the backing store. +// +// It uses "open addressing" with quadratic reprobing to resolve +// collisions. // // This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a vector. Data can be inserted into the table using +// values. Each value must be a scalar. Data can be inserted into the table using // the insert operations. It does not support the initialization operation. // // Arguments: -// key_dtype: Type of the table keys. +// empty_key: The key used to represent empty key buckets internally. Must not +// be used in insert or lookup operations. +// // value_dtype: Type of the table values. // // Returns Handle to a table. -func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { +func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, deleted_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + attrs := map[string]interface{}{"value_dtype": value_dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MutableHashTableOfTensorsV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. -type ResourceApplyProximalGradientDescentAttr func(optionalAttr) - -// ResourceApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyProximalGradientDescentUseLocking(value bool) ResourceApplyProximalGradientDescentAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' as FOBOS algorithm with fixed learning rate. -// -// prox_v = var - alpha * delta -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// delta: The change. -// -// Returns the created operation. -func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, delta tf.Output, optional ...ResourceApplyProximalGradientDescentAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyProximalGradientDescent", + Type: "MutableDenseHashTableV2", Input: []tf.Input{ - var_, alpha, l1, l2, delta, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Returns 0 if the denominator is zero. -// -// -// *NOTE*: `DivNoNan` supports broadcasting. 
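Since MutableDenseHashTableV2 and its optional attributes are the new surface introduced by this change, a minimal construction sketch may help. It is only an illustration under the assumption that the standard `tensorflow/go` packages are available; the wrapper and attribute setters it calls are the ones defined in the diff above.

```go
package main

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Sentinel keys that must never be inserted or looked up; their dtype
	// (int64 here) determines the key dtype of the table.
	emptyKey := op.Const(s, int64(-1))
	deletedKey := op.Const(s, int64(-2))

	// Build the dense hash table handle, exercising the optional attributes
	// added in this change: the starting bucket count (a power of 2) and the
	// load factor at which the table grows.
	table := op.MutableDenseHashTableV2(s, emptyKey, deletedKey, tf.Float,
		op.MutableDenseHashTableV2InitialNumBuckets(131072),
		op.MutableDenseHashTableV2MaxLoadFactor(0.8),
	)
	_ = table // the handle would be fed to the lookup-table insert/find wrappers

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
}
```

Inserts and lookups would then go through the corresponding lookup-table insert/find wrappers elsewhere in this generated package, using the returned handle.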
More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func DivNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DivNoNan", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Subtracts a value from the current value of a variable. -// -// Any ReadVariableOp with a control dependency on this op is guaranteed to -// see the decremented value or a subsequent newer one. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value by which the variable will be incremented. -// -// Returns the created operation. -func AssignSubVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AssignSubVariableOp", - Input: []tf.Input{ - resource, value, - }, - } - return scope.AddOperation(opspec) -} - -// RestoreAttr is an optional argument to Restore. -type RestoreAttr func(optionalAttr) - -// RestorePreferredShard sets the optional preferred_shard attribute to value. -// -// value: Index of file to open first if multiple files match -// `file_pattern`. -// If not specified, defaults to -1 -func RestorePreferredShard(value int64) RestoreAttr { - return func(m optionalAttr) { - m["preferred_shard"] = value - } -} - -// Restores a tensor from checkpoint files. -// -// Reads a tensor stored in one or several files. If there are several files (for -// instance because a tensor was saved as slices), `file_pattern` may contain -// wildcard symbols (`*` and `?`) in the filename portion only, not in the -// directory portion. -// -// If a `file_pattern` matches several files, `preferred_shard` can be used to hint -// in which file the requested tensor is likely to be found. This op will first -// open the file at index `preferred_shard` in the list of matching files and try -// to restore tensors from that file. Only if some tensors or tensor slices are -// not found in that first file, then the Op opens all the files. Setting -// `preferred_shard` to match the value passed as the `shard` input -// of a matching `Save` Op may speed up Restore. This attribute only affects -// performance, not correctness. The default value -1 means files are processed in -// order. -// -// See also `RestoreSlice`. -// -// Arguments: -// file_pattern: Must have a single element. The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// dt: The type of the tensor to be restored. -// -// Returns The restored tensor. -func Restore(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, dt tf.DataType, optional ...RestoreAttr) (tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dt": dt} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Restore", - Input: []tf.Input{ - file_pattern, tensor_name, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QuantizedResizeBilinearAttr is an optional argument to QuantizedResizeBilinear. -type QuantizedResizeBilinearAttr func(optionalAttr) - -// QuantizedResizeBilinearAlignCorners sets the optional align_corners attribute to value. 
-// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func QuantizedResizeBilinearAlignCorners(value bool) QuantizedResizeBilinearAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// QuantizedResizeBilinearHalfPixelCenters sets the optional half_pixel_centers attribute to value. -// If not specified, defaults to false -func QuantizedResizeBilinearHalfPixelCenters(value bool) QuantizedResizeBilinearAttr { - return func(m optionalAttr) { - m["half_pixel_centers"] = value - } -} - -// Resize quantized `images` to `size` using quantized bilinear interpolation. -// -// Input images and output images must be quantized types. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min tf.Output, max tf.Output, optional ...QuantizedResizeBilinearAttr) (resized_images tf.Output, out_min tf.Output, out_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedResizeBilinear", - Input: []tf.Input{ - images, size, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Creates a dataset that uses a custom thread pool to compute `input_dataset`. -// -// Arguments: -// -// num_threads: Identifies the number of threads to use for the private threadpool. -// -// -func ExperimentalPrivateThreadPoolDataset(scope *Scope, input_dataset tf.Output, num_threads tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalPrivateThreadPoolDataset", - Input: []tf.Input{ - input_dataset, num_threads, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DenseToSparseSetOperationAttr is an optional argument to DenseToSparseSetOperation. -type DenseToSparseSetOperationAttr func(optionalAttr) - -// DenseToSparseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func DenseToSparseSetOperationValidateIndices(value bool) DenseToSparseSetOperationAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Applies set operation along last dimension of `Tensor` and `SparseTensor`. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. -// -// Input `set2` is a `SparseTensor` represented by `set2_indices`, `set2_values`, -// and `set2_shape`. For `set2` ranked `n`, 1st `n-1` dimensions must be the same -// as `set1`. Dimension `n` contains values in a set, duplicates are allowed but -// ignored. -// -// If `validate_indices` is `True`, this op validates the order and range of `set2` -// indices. -// -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. 
For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. -// -// Arguments: -// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// set2_indices: 2D `Tensor`, indices of a `SparseTensor`. Must be in row-major -// order. -// set2_values: 1D `Tensor`, values of a `SparseTensor`. Must be in row-major -// order. -// set2_shape: 1D `Tensor`, shape of a `SparseTensor`. `set2_shape[0...n-1]` must -// be the same as the 1st `n-1` dimensions of `set1`, `result_shape[n]` is the -// max set size across `n-1` dimensions. -// -// -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func DenseToSparseSetOperation(scope *Scope, set1 tf.Output, set2_indices tf.Output, set2_values tf.Output, set2_shape tf.Output, set_operation string, optional ...DenseToSparseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"set_operation": set_operation} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DenseToSparseSetOperation", - Input: []tf.Input{ - set1, set2_indices, set2_values, set2_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// L2 Loss. -// -// Computes half the L2 norm of a tensor without the `sqrt`: -// -// output = sum(t ** 2) / 2 -// -// Arguments: -// t: Typically 2-D, but may have any dimensions. -// -// Returns 0-D. -func L2Loss(scope *Scope, t tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "L2Loss", - Input: []tf.Input{ - t, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StackV2Attr is an optional argument to StackV2. -type StackV2Attr func(optionalAttr) - -// StackV2StackName sets the optional stack_name attribute to value. -// -// value: Overrides the name used for the temporary stack resource. Default -// value is the name of the 'Stack' op (which is guaranteed unique). -// If not specified, defaults to "" -func StackV2StackName(value string) StackV2Attr { - return func(m optionalAttr) { - m["stack_name"] = value - } -} - -// A stack that produces elements in first-in last-out order. -// -// Arguments: -// max_size: The maximum size of the stack if non-negative. If negative, the stack -// size is unlimited. -// elem_type: The type of the elements on the stack. -// -// Returns The handle to the stack. 
-func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"elem_type": elem_type} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StackV2", - Input: []tf.Input{ - max_size, + empty_key, deleted_key, }, Attrs: attrs, } @@ -21291,12723 +35479,6 @@ func CudnnRNNBackprop(scope *Scope, input tf.Output, input_h tf.Output, input_c return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// InfeedEnqueueAttr is an optional argument to InfeedEnqueue. -type InfeedEnqueueAttr func(optionalAttr) - -// InfeedEnqueueShape sets the optional shape attribute to value. -// -// value: The shape of the tensor. -// If not specified, defaults to <> -func InfeedEnqueueShape(value tf.Shape) InfeedEnqueueAttr { - return func(m optionalAttr) { - m["shape"] = value - } -} - -// InfeedEnqueueLayout sets the optional layout attribute to value. -// -// value: A vector holding the requested layout in minor-to-major sequence. -// If a layout attribute is passed, but its values are all -1, the layout will -// be computed by the infeed operation. -// If not specified, defaults to <> -func InfeedEnqueueLayout(value []int64) InfeedEnqueueAttr { - return func(m optionalAttr) { - m["layout"] = value - } -} - -// InfeedEnqueueDeviceOrdinal sets the optional device_ordinal attribute to value. -// -// value: The TPU device to use. This should be -1 when the Op -// is running on a TPU device, and >= 0 when the Op is running on the CPU -// device. -// If not specified, defaults to -1 -func InfeedEnqueueDeviceOrdinal(value int64) InfeedEnqueueAttr { - return func(m optionalAttr) { - m["device_ordinal"] = value - } -} - -// An op which feeds a single Tensor value into the computation. -// -// Arguments: -// input: A tensor that will be provided using the infeed mechanism. -// -// Returns the created operation. -func InfeedEnqueue(scope *Scope, input tf.Output, optional ...InfeedEnqueueAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "InfeedEnqueue", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes softmax cross entropy cost and gradients to backpropagate. -// -// Inputs are the logits, not probabilities. -// -// Arguments: -// features: batch_size x num_classes matrix -// labels: batch_size x num_classes matrix -// The caller must ensure that each batch of labels represents a valid -// probability distribution. -// -// Returns Per example loss (batch_size vector).backpropagated gradients (batch_size x num_classes matrix). -func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SoftmaxCrossEntropyWithLogits", - Input: []tf.Input{ - features, labels, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// ReduceJoinAttr is an optional argument to ReduceJoin. -type ReduceJoinAttr func(optionalAttr) - -// ReduceJoinKeepDims sets the optional keep_dims attribute to value. -// -// value: If `True`, retain reduced dimensions with length `1`. 
-// If not specified, defaults to false -func ReduceJoinKeepDims(value bool) ReduceJoinAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// ReduceJoinSeparator sets the optional separator attribute to value. -// -// value: The separator to use when joining. -// If not specified, defaults to "" -func ReduceJoinSeparator(value string) ReduceJoinAttr { - return func(m optionalAttr) { - m["separator"] = value - } -} - -// Joins a string Tensor across the given dimensions. -// -// Computes the string join across dimensions in the given string Tensor of shape -// `[\\(d_0, d_1, ..., d_{n-1}\\)]`. Returns a new Tensor created by joining the input -// strings with the given separator (default: empty string). Negative indices are -// counted backwards from the end, with `-1` being equivalent to `n - 1`. If -// indices are not specified, joins across all dimensions beginning from `n - 1` -// through `0`. -// -// For example: -// -// ```python -// # tensor `a` is [["a", "b"], ["c", "d"]] -// tf.reduce_join(a, 0) ==> ["ac", "bd"] -// tf.reduce_join(a, 1) ==> ["ab", "cd"] -// tf.reduce_join(a, -2) = tf.reduce_join(a, 0) ==> ["ac", "bd"] -// tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"] -// tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]] -// tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]] -// tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"] -// tf.reduce_join(a, [0, 1]) ==> "acbd" -// tf.reduce_join(a, [1, 0]) ==> "abcd" -// tf.reduce_join(a, []) ==> [["a", "b"], ["c", "d"]] -// tf.reduce_join(a) = tf.reduce_join(a, [1, 0]) ==> "abcd" -// ``` -// -// Arguments: -// inputs: The input to be joined. All reduced indices must have non-zero size. -// reduction_indices: The dimensions to reduce over. Dimensions are reduced in the -// order specified. Omitting `reduction_indices` is equivalent to passing -// `[n-1, n-2, ..., 0]`. Negative indices from `-n` to `-1` are supported. -// -// Returns Has shape equal to that of the input with reduced dimensions removed or -// set to `1` depending on `keep_dims`. -func ReduceJoin(scope *Scope, inputs tf.Output, reduction_indices tf.Output, optional ...ReduceJoinAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ReduceJoin", - Input: []tf.Input{ - inputs, reduction_indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TopKAttr is an optional argument to TopK. -type TopKAttr func(optionalAttr) - -// TopKSorted sets the optional sorted attribute to value. -// -// value: If true the resulting `k` elements will be sorted by the values in -// descending order. -// If not specified, defaults to true -func TopKSorted(value bool) TopKAttr { - return func(m optionalAttr) { - m["sorted"] = value - } -} - -// Finds values and indices of the `k` largest elements for the last dimension. -// -// DEPRECATED at GraphDef version 7: Use TopKV2 instead -// -// If the input is a vector (rank-1), finds the `k` largest entries in the vector -// and outputs their values and indices as vectors. Thus `values[j]` is the -// `j`-th largest entry in `input`, and its index is `indices[j]`. -// -// For matrices (resp. higher rank input), computes the top `k` entries in each -// row (resp. vector along the last dimension). Thus, -// -// values.shape = indices.shape = input.shape[:-1] + [k] -// -// If two elements are equal, the lower-index element appears first. 
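To mirror the `tf.reduce_join(a, 0, separator=".")` case quoted in the ReduceJoin docs above, here is a hedged Go sketch under the same `tensorflow/go` client-API assumptions as the earlier example:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// tensor `a` is [["a", "b"], ["c", "d"]], as in the examples above.
	a := op.Const(s, [][]string{{"a", "b"}, {"c", "d"}})
	axis := op.Const(s, []int32{0})
	// Join down axis 0 with a "." separator: expected ["a.c", "b.d"].
	joined := op.ReduceJoin(s, a, axis, op.ReduceJoinSeparator("."))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{joined}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value())
}
```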
-// -// If `k` varies dynamically, use `TopKV2` below. -// -// Arguments: -// input: 1-D or higher with last dimension at least `k`. -// k: Number of top elements to look for along the last dimension (along each -// row for matrices). -// -// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. -func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values tf.Output, indices tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"k": k} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TopK", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// BatchToSpace for N-D tensors of type T. -// -// This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape -// `block_shape + [batch]`, interleaves these blocks back into the grid defined by -// the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as -// the input. The spatial dimensions of this intermediate result are then -// optionally cropped according to `crops` to produce the output. This is the -// reverse of SpaceToBatch. See below for a precise description. -// -// Arguments: -// input: N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`, -// where spatial_shape has M dimensions. -// block_shape: 1-D with shape `[M]`, all values must be >= 1. -// crops: 2-D with shape `[M, 2]`, all values must be >= 0. -// `crops[i] = [crop_start, crop_end]` specifies the amount to crop from input -// dimension `i + 1`, which corresponds to spatial dimension `i`. It is -// required that -// `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`. -// -// This operation is equivalent to the following steps: -// -// 1. Reshape `input` to `reshaped` of shape: -// [block_shape[0], ..., block_shape[M-1], -// batch / prod(block_shape), -// input_shape[1], ..., input_shape[N-1]] -// -// 2. Permute dimensions of `reshaped` to produce `permuted` of shape -// [batch / prod(block_shape), -// -// input_shape[1], block_shape[0], -// ..., -// input_shape[M], block_shape[M-1], -// -// input_shape[M+1], ..., input_shape[N-1]] -// -// 3. Reshape `permuted` to produce `reshaped_permuted` of shape -// [batch / prod(block_shape), -// -// input_shape[1] * block_shape[0], -// ..., -// input_shape[M] * block_shape[M-1], -// -// input_shape[M+1], -// ..., -// input_shape[N-1]] -// -// 4. 
Crop the start and end of dimensions `[1, ..., M]` of -// `reshaped_permuted` according to `crops` to produce the output of shape: -// [batch / prod(block_shape), -// -// input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], -// ..., -// input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], -// -// input_shape[M+1], ..., input_shape[N-1]] -// -// Some examples: -// -// (1) For the following input of shape `[4, 1, 1, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 1]` and value: -// -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` -// -// (2) For the following input of shape `[4, 1, 1, 3]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 3]` and value: -// -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` -// -// (3) For the following input of shape `[4, 2, 2, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [0, 0]]`: -// -// ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -// -// The output tensor has shape `[1, 4, 4, 1]` and value: -// -// ``` -// x = [[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]] -// ``` -// -// (4) For the following input of shape `[8, 1, 3, 1]`, `block_shape = [2, 2]`, and -// `crops = [[0, 0], [2, 0]]`: -// -// ``` -// x = [[[[0], [1], [3]]], [[[0], [9], [11]]], -// [[[0], [2], [4]]], [[[0], [10], [12]]], -// [[[0], [5], [7]]], [[[0], [13], [15]]], -// [[[0], [6], [8]]], [[[0], [14], [16]]]] -// ``` -// -// The output tensor has shape `[2, 2, 4, 1]` and value: -// -// ``` -// x = [[[[1], [2], [3], [4]], -// [[5], [6], [7], [8]]], -// [[[9], [10], [11], [12]], -// [[13], [14], [15], [16]]]] -// ``` -func BatchToSpaceND(scope *Scope, input tf.Output, block_shape tf.Output, crops tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BatchToSpaceND", - Input: []tf.Input{ - input, block_shape, crops, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// UnpackAttr is an optional argument to Unpack. -type UnpackAttr func(optionalAttr) - -// UnpackAxis sets the optional axis attribute to value. -// -// value: Dimension along which to unpack. Negative values wrap around, so the -// valid range is `[-R, R)`. -// If not specified, defaults to 0 -func UnpackAxis(value int64) UnpackAttr { - return func(m optionalAttr) { - m["axis"] = value - } -} - -// Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. -// -// Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. -// For example, given a tensor of shape `(A, B, C, D)`; -// -// If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` -// and each tensor in `output` will have shape `(B, C, D)`. (Note that the -// dimension unpacked along is gone, unlike `split`). -// -// If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` -// and each tensor in `output` will have shape `(A, C, D)`. -// Etc. -// -// This is the opposite of `pack`. -// -// Arguments: -// value: 1-D or higher, with `axis` dimension size equal to `num`. -// -// -// Returns The list of tensors unpacked from `value`. 
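Example (1) from the BatchToSpaceND description above can be reproduced from Go roughly as follows, again assuming the standard `tensorflow/go` graph and session plumbing:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Example (1) above: input of shape [4, 1, 1, 1] holding 1..4.
	input := op.Const(s, [][][][]float32{{{{1}}}, {{{2}}}, {{{3}}}, {{{4}}}})
	blockShape := op.Const(s, []int32{2, 2})
	crops := op.Const(s, [][]int32{{0, 0}, {0, 0}})
	// Expected output shape [1, 2, 2, 1] with value [[[[1], [2]], [[3], [4]]]].
	out := op.BatchToSpaceND(s, input, blockShape, crops)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	res, err := sess.Run(nil, []tf.Output{out}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(res[0].Shape(), res[0].Value())
}
```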
-func Unpack(scope *Scope, value tf.Output, num int64, optional ...UnpackAttr) (output []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num": num} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Unpack", - Input: []tf.Input{ - value, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("Unpack", err) - return - } - return output -} - -// Delete the stack from its resource container. -// -// Arguments: -// handle: The handle to a stack. -// -// Returns the created operation. -func StackCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "StackCloseV2", - Input: []tf.Input{ - handle, - }, - } - return scope.AddOperation(opspec) -} - -// Increments variable pointed to by 'resource' until it reaches 'limit'. -// -// Arguments: -// resource: Should be from a scalar `Variable` node. -// limit: If incrementing ref would bring it above limit, instead generates an -// 'OutOfRange' error. -// -// -// Returns A copy of the input before increment. If nothing else modifies the -// input, the values produced will all be distinct. -func ResourceCountUpTo(scope *Scope, resource tf.Output, limit int64, T tf.DataType) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"limit": limit, "T": T} - opspec := tf.OpSpec{ - Type: "ResourceCountUpTo", - Input: []tf.Input{ - resource, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes softsign gradients for a softsign operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding softsign operation. -// features: The features passed as input to the corresponding softsign operation. -// -// Returns The gradients: `gradients / (1 + abs(features)) ** 2`. -func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SoftsignGrad", - Input: []tf.Input{ - gradients, features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Provides the time since epoch in seconds. -// -// Returns the timestamp as a `float64` for seconds since the Unix epoch. -// -// Note: the timestamp is computed when the op is executed, not when it is added -// to the graph. -func Timestamp(scope *Scope) (ts tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Timestamp", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns immutable tensor from memory region. -// -// The current implementation memmaps the tensor from a file. -// -// Arguments: -// dtype: Type of the returned tensor. -// shape: Shape of the returned tensor. -// memory_region_name: Name of readonly memory region used by the tensor, see -// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. 
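A small sketch of the Unpack wrapper described above, splitting a `[2, 3]` matrix into its rows along axis 0 (same client-API assumptions as the earlier sketches):

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// A [2, 3] matrix; unpacking along axis 0 yields two [3] vectors.
	value := op.Const(s, [][]float32{{1, 2, 3}, {4, 5, 6}})
	rows := op.Unpack(s, value, 2, op.UnpackAxis(0))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, rows, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value(), out[1].Value()) // [1 2 3] [4 5 6]
}
```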
-func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} - opspec := tf.OpSpec{ - Type: "ImmutableConst", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StringJoinAttr is an optional argument to StringJoin. -type StringJoinAttr func(optionalAttr) - -// StringJoinSeparator sets the optional separator attribute to value. -// -// value: string, an optional join separator. -// If not specified, defaults to "" -func StringJoinSeparator(value string) StringJoinAttr { - return func(m optionalAttr) { - m["separator"] = value - } -} - -// Joins the strings in the given list of string tensors into one tensor; -// -// with the given separator (default is an empty separator). -// -// Arguments: -// inputs: A list of string tensors. The tensors must all have the same shape, -// or be scalars. Scalars may be mixed in; these will be broadcast to the shape -// of non-scalar inputs. -func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StringJoin", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates and returns an empty tensor list. -// -// All list elements must be tensors of dtype element_dtype and shape compatible -// with element_shape. -// -// handle: an empty tensor list. -// element_dtype: the type of elements in the list. -// element_shape: a shape compatible with that of elements in the list. -func EmptyTensorList(scope *Scope, element_shape tf.Output, max_num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "EmptyTensorList", - Input: []tf.Input{ - element_shape, max_num_elements, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns a list of tensors with the same shapes and contents as the input -// -// tensors. -// -// This op can be used to override the gradient for complicated functions. For -// example, suppose y = f(x) and we wish to apply a custom function g for backprop -// such that dx = g(dy). In Python, -// -// ```python -// with tf.get_default_graph().gradient_override_map( -// {'IdentityN': 'OverrideGradientWithG'}): -// y, _ = identity_n([f(x), x]) -// -// @tf.RegisterGradient('OverrideGradientWithG') -// def ApplyG(op, dy, _): -// return [None, g(dy)] # Do not backprop to f(x). -// ``` -func IdentityN(scope *Scope, input []tf.Output) (output []tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IdentityN", - Input: []tf.Input{ - tf.OutputList(input), - }, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("IdentityN", err) - return - } - return output -} - -// ResourceApplyCenteredRMSPropAttr is an optional argument to ResourceApplyCenteredRMSProp. 
-type ResourceApplyCenteredRMSPropAttr func(optionalAttr) - -// ResourceApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, mg, ms, and mom tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyCenteredRMSPropUseLocking(value bool) ResourceApplyCenteredRMSPropAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the centered RMSProp algorithm. -// -// The centered RMSProp algorithm uses an estimate of the centered second moment -// (i.e., the variance) for normalization, as opposed to regular RMSProp, which -// uses the (uncentered) second moment. This often helps with training, but is -// slightly more expensive in terms of computation and memory. -// -// Note that in dense implementation of this algorithm, mg, ms, and mom will -// update even if the grad is zero, but in this sparse implementation, mg, ms, -// and mom will not update in iterations during which the grad is zero. -// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// mean_grad = decay * mean_grad + (1-decay) * gradient -// -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) -// -// mg <- rho * mg_{t-1} + (1-rho) * grad -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon) -// var <- var - mom -// -// Arguments: -// var_: Should be from a Variable(). -// mg: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyCenteredRMSPropAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyCenteredRMSProp", - Input: []tf.Input{ - var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// ResourceSparseApplyCenteredRMSPropAttr is an optional argument to ResourceSparseApplyCenteredRMSProp. -type ResourceSparseApplyCenteredRMSPropAttr func(optionalAttr) - -// ResourceSparseApplyCenteredRMSPropUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, mg, ms, and mom tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyCenteredRMSPropUseLocking(value bool) ResourceSparseApplyCenteredRMSPropAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the centered RMSProp algorithm. -// -// The centered RMSProp algorithm uses an estimate of the centered second moment -// (i.e., the variance) for normalization, as opposed to regular RMSProp, which -// uses the (uncentered) second moment. This often helps with training, but is -// slightly more expensive in terms of computation and memory. 
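The dense update rule quoted above for ResourceApplyCenteredRMSProp can be sanity-checked with plain Go arithmetic. This scalar sketch only illustrates the formulas, not the op's implementation, which applies them element-wise on the device:

```go
package main

import (
	"fmt"
	"math"
)

// centeredRMSPropStep applies one dense update per the formulas quoted above,
// with float64 scalars standing in for the var/mg/ms/mom tensors.
func centeredRMSPropStep(v, mg, ms, mom, lr, rho, momentum, epsilon, grad float64) (float64, float64, float64, float64) {
	mg = rho*mg + (1-rho)*grad
	ms = rho*ms + (1-rho)*grad*grad
	mom = momentum*mom + lr*grad/math.Sqrt(ms-mg*mg+epsilon)
	v = v - mom
	return v, mg, ms, mom
}

func main() {
	v, mg, ms, mom := 1.0, 0.0, 0.0, 0.0
	for i := 0; i < 3; i++ {
		grad := 0.5 // stand-in gradient
		v, mg, ms, mom = centeredRMSPropStep(v, mg, ms, mom, 0.1, 0.9, 0.9, 1e-10, grad)
		fmt.Printf("step %d: var=%.6f\n", i+1, v)
	}
}
```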
-// -// Note that in dense implementation of this algorithm, mg, ms, and mom will -// update even if the grad is zero, but in this sparse implementation, mg, ms, -// and mom will not update in iterations during which the grad is zero. -// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// mean_grad = decay * mean_grad + (1-decay) * gradient -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2) -// -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom -// -// Arguments: -// var_: Should be from a Variable(). -// mg: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var, ms and mom. -// -// Returns the created operation. -func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyCenteredRMSPropAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyCenteredRMSProp", - Input: []tf.Input{ - var_, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Creates a dataset that batches `batch_size` elements from `input_dataset`. -// -// Arguments: -// -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// -// -func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "BatchDataset", - Input: []tf.Input{ - input_dataset, batch_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LoadTPUEmbeddingAdadeltaParametersAttr is an optional argument to LoadTPUEmbeddingAdadeltaParameters. -type LoadTPUEmbeddingAdadeltaParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingAdadeltaParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingAdadeltaParametersTableId(value int64) LoadTPUEmbeddingAdadeltaParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingAdadeltaParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingAdadeltaParametersTableName(value string) LoadTPUEmbeddingAdadeltaParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load Adadelta embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. 
-// -// Arguments: -// parameters: Value of parameters used in the Adadelta optimization algorithm. -// accumulators: Value of accumulators used in the Adadelta optimization algorithm. -// updates: Value of updates used in the Adadelta optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingAdadeltaParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, updates tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdadeltaParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingAdadeltaParameters", - Input: []tf.Input{ - parameters, accumulators, updates, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process and will never change. However, it is not suitable for cryptography. -// This function may be used when CPU time is scarce and inputs are trusted or -// unimportant. There is a risk of adversaries constructing inputs that all hash -// to the same bucket. To prevent this problem, use a strong hash function with -// `tf.string_to_hash_bucket_strong`. -// -// Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. -// -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_buckets": num_buckets} - opspec := tf.OpSpec{ - Type: "StringToHashBucketFast", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RealAttr is an optional argument to Real. -type RealAttr func(optionalAttr) - -// RealTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func RealTout(value tf.DataType) RealAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Returns the real part of a complex number. -// -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the real part of each element in `input`. All elements in -// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real -// part returned by this operation and *b* is the imaginary part. -// -// For example: -// -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.real(input) ==> [-2.25, 3.25] -// ``` -func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Real", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AudioSummaryAttr is an optional argument to AudioSummary. -type AudioSummaryAttr func(optionalAttr) - -// AudioSummaryMaxOutputs sets the optional max_outputs attribute to value. -// -// value: Max number of batch elements to generate audio for. 
-// If not specified, defaults to 3 -// -// REQUIRES: value >= 1 -func AudioSummaryMaxOutputs(value int64) AudioSummaryAttr { - return func(m optionalAttr) { - m["max_outputs"] = value - } -} - -// Outputs a `Summary` protocol buffer with audio. -// -// DEPRECATED at GraphDef version 15: Use AudioSummaryV2. -// -// The summary has up to `max_outputs` summary values containing audio. The -// audio is built from `tensor` which must be 3-D with shape `[batch_size, -// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are -// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. -// -// The `tag` argument is a scalar `Tensor` of type `string`. It is used to -// build the `tag` of the summary values: -// -// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. -// * If `max_outputs` is greater than 1, the summary value tags are -// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. -// -// Arguments: -// tag: Scalar. Used to build the `tag` attribute of the summary values. -// tensor: 2-D of shape `[batch_size, frames]`. -// sample_rate: The sample rate of the signal in hertz. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate float32, optional ...AudioSummaryAttr) (summary tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"sample_rate": sample_rate} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AudioSummary", - Input: []tf.Input{ - tag, tensor, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QrAttr is an optional argument to Qr. -type QrAttr func(optionalAttr) - -// QrFullMatrices sets the optional full_matrices attribute to value. -// -// value: If true, compute full-sized `q` and `r`. If false -// (the default), compute only the leading `P` columns of `q`. -// If not specified, defaults to false -func QrFullMatrices(value bool) QrAttr { - return func(m optionalAttr) { - m["full_matrices"] = value - } -} - -// Computes the QR decompositions of one or more matrices. -// -// Computes the QR decomposition of each inner matrix in `tensor` such that -// `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` -// -// ```python -// # a is a tensor. -// # q is a tensor of orthonormal matrices. -// # r is a tensor of upper triangular matrices. -// q, r = qr(a) -// q_full, r_full = qr(a, full_matrices=True) -// ``` -// -// Arguments: -// input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`. -// -// Returns Orthonormal basis for range of `a`. If `full_matrices` is `False` then -// shape is `[..., M, P]`; if `full_matrices` is `True` then shape is -// `[..., M, M]`.Triangular factor. If `full_matrices` is `False` then shape is -// `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`. -func Qr(scope *Scope, input tf.Output, optional ...QrAttr) (q tf.Output, r tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Qr", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// TensorArrayV3Attr is an optional argument to TensorArrayV3. 
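A hedged Go counterpart to the Python `q, r = qr(a)` snippet in the Qr docs above, using the default `full_matrices=false` and the same assumed `tensorflow/go` session plumbing:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// A single 3x2 matrix; with full_matrices=false, q is 3x2 and r is 2x2.
	a := op.Const(s, [][]float32{{1, 2}, {3, 4}, {5, 6}})
	q, r := op.Qr(s, a)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{q, r}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println("q:", out[0].Value())
	fmt.Println("r:", out[1].Value())
}
```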
-type TensorArrayV3Attr func(optionalAttr) - -// TensorArrayV3ElementShape sets the optional element_shape attribute to value. -// -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. -// -// value: A boolean that determines whether writes to the TensorArray -// are allowed to grow the size. By default, this is not allowed. -// If not specified, defaults to false -func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["dynamic_size"] = value - } -} - -// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. -// -// value: If true (default), Tensors in the TensorArray are cleared -// after being read. This disables multiple read semantics but allows early -// release of memory. -// If not specified, defaults to true -func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["clear_after_read"] = value - } -} - -// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. -// -// value: If true (default is false), then all -// elements in the TensorArray will be expected to have have identical shapes. -// This allows certain behaviors, like dynamically checking for -// consistent shapes on write, and being able to fill in properly -// shaped zero tensors on stack -- even if the element_shape attribute -// is not fully defined. -// If not specified, defaults to false -func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["identical_element_shapes"] = value - } -} - -// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. -// -// value: Overrides the name used for the temporary tensor_array -// resource. Default value is the name of the 'TensorArray' op (which -// is guaranteed unique). -// If not specified, defaults to "" -func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr { - return func(m optionalAttr) { - m["tensor_array_name"] = value - } -} - -// An array of Tensors of given size. -// -// Write data via Write and read via Read or Pack. -// -// Arguments: -// size: The size of the array. -// dtype: The type of the elements on the tensor_array. -// -// Returns The handle to the TensorArray.A scalar used to control gradient flow. -func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorArrayV3", - Input: []tf.Input{ - size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Returns the truth value of NOT x element-wise. -func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogicalNot", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 3D real-valued fast Fourier transform. 
-// -// Computes the 3-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 3 dimensions of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. -// -// Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the their 3D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. -// -// @compatibility(numpy) -// Equivalent to np.fft.rfftn with 3 dimensions. -// @end_compatibility -func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RFFT3D", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes rectified linear: `max(features, 0)`. -func Relu(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Relu", - Input: []tf.Input{ - features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyAddSignAttr is an optional argument to ResourceApplyAddSign. -type ResourceApplyAddSignAttr func(optionalAttr) - -// ResourceApplyAddSignUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and m tensors is -// protected by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAddSignUseLocking(value bool) ResourceApplyAddSignAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the AddSign update. -// -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// update <- (alpha + sign_decay * sign(g) *sign(m)) * g -// variable <- variable - lr_t * update -// -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// alpha: Must be a scalar. -// sign_decay: Must be a scalar. -// beta: Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyAddSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Output, alpha tf.Output, sign_decay tf.Output, beta tf.Output, grad tf.Output, optional ...ResourceApplyAddSignAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAddSign", - Input: []tf.Input{ - var_, m, lr, alpha, sign_decay, beta, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Divides sparse updates into the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] /= updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] /= updates[i, ...] 
-// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions multiply. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// -//
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterDiv", - Input: []tf.Input{ - resource, indices, updates, - }, - } - return scope.AddOperation(opspec) -} - -// ListDiffAttr is an optional argument to ListDiff. -type ListDiffAttr func(optionalAttr) - -// ListDiffOutIdx sets the optional out_idx attribute to value. -// If not specified, defaults to DT_INT32 -func ListDiffOutIdx(value tf.DataType) ListDiffAttr { - return func(m optionalAttr) { - m["out_idx"] = value - } -} - -// Computes the difference between two lists of numbers or strings. -// -// Given a list `x` and a list `y`, this operation returns a list `out` that -// represents all values that are in `x` but not in `y`. The returned list `out` -// is sorted in the same order that the numbers appear in `x` (duplicates are -// preserved). This operation also returns a list `idx` that represents the -// position of each `out` element in `x`. In other words: -// -// `out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` -// -// For example, given this input: -// -// ``` -// x = [1, 2, 3, 4, 5, 6] -// y = [1, 3, 5] -// ``` -// -// This operation would return: -// -// ``` -// out ==> [2, 4, 6] -// idx ==> [1, 3, 5] -// ``` -// -// Arguments: -// x: 1-D. Values to keep. -// y: 1-D. Values to remove. -// -// Returns 1-D. Values present in `x` but not in `y`.1-D. Positions of `x` values preserved in `out`. -func ListDiff(scope *Scope, x tf.Output, y tf.Output, optional ...ListDiffAttr) (out tf.Output, idx tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ListDiff", - Input: []tf.Input{ - x, y, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingAdadeltaParametersGradAccumDebug. -type LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr func(optionalAttr) - -// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingAdadeltaParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load Adadelta parameters with debug support. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. 
-// -// Arguments: -// parameters: Value of parameters used in the Adadelta optimization algorithm. -// accumulators: Value of accumulators used in the Adadelta optimization algorithm. -// updates: Value of updates used in the Adadelta optimization algorithm. -// gradient_accumulators: Value of gradient_accumulators used in the Adadelta optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingAdadeltaParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, updates tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdadeltaParametersGradAccumDebugAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingAdadeltaParametersGradAccumDebug", - Input: []tf.Input{ - parameters, accumulators, updates, gradient_accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Return a tensor with the same shape and contents as the input tensor or value. -func Identity(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Identity", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. -// -// This is the angle \( \theta \in [-\pi, \pi] \) such that -// \[ x = r \cos(\theta) \] -// and -// \[ y = r \sin(\theta) \] -// where \(r = \sqrt(x^2 + y^2) \). -func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Atan2", - Input: []tf.Input{ - y, x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Updates specified rows with values in `v`. -// -// Computes `x[i, :] = v; return x`. -// -// Arguments: -// x: A tensor of type `T`. -// i: A vector. Indices into the left-most dimension of `x`. -// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. -// -// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. -func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InplaceUpdate", - Input: []tf.Input{ - x, i, v, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// OutfeedDequeueTupleAttr is an optional argument to OutfeedDequeueTuple. -type OutfeedDequeueTupleAttr func(optionalAttr) - -// OutfeedDequeueTupleDeviceOrdinal sets the optional device_ordinal attribute to value. -// -// value: The TPU device to use. This should be -1 when the Op -// is running on a TPU device, and >= 0 when the Op is running on the CPU -// device. -// If not specified, defaults to -1 -func OutfeedDequeueTupleDeviceOrdinal(value int64) OutfeedDequeueTupleAttr { - return func(m optionalAttr) { - m["device_ordinal"] = value - } -} - -// Retrieve multiple values from the computation outfeed. -// -// This operation will block indefinitely until data is available. Output `i` -// corresponds to XLA tuple element `i`. -// -// Arguments: -// dtypes: The element types of each element in `outputs`. -// shapes: The shapes of each tensor in `outputs`. 
-// -// Returns A list of tensors that will be read from the outfeed. -func OutfeedDequeueTuple(scope *Scope, dtypes []tf.DataType, shapes []tf.Shape, optional ...OutfeedDequeueTupleAttr) (outputs []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes, "shapes": shapes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OutfeedDequeueTuple", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("OutfeedDequeueTuple", err) - return - } - return outputs -} - -// Identity op for gradient debugging. -// -// This op is hidden from public in Python. It is used by TensorFlow Debugger to -// register gradient tensors for gradient debugging. -// This op operates on non-reference-type tensors. -func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DebugGradientIdentity", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. -type ResourceSparseApplyAdadeltaAttr func(optionalAttr) - -// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// var: Should be from a Variable(). -// -// Arguments: -// -// accum: Should be from a Variable(). -// accum_update: : Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// -// Returns the created operation. -func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdadelta", - Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Returns which elements of x are NaN. -// -// @compatibility(numpy) -// Equivalent to np.isnan -// @end_compatibility -func IsNan(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IsNan", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. -type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) - -// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. 
With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of depthwise convolution with respect to the filter. -// -// Arguments: -// input: 4-D with shape based on `data_format`. For example, if -// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, -// in_width, in_channels]` tensor. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 4-D -// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. -// out_backprop: 4-D with shape based on `data_format`. -// For example, if `data_format` is 'NHWC' then -// out_backprop shape is `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. -// padding: The type of padding algorithm to use. -// -// Returns 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. -// the `filter` input of the convolution. -func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNativeBackpropFilter", - Input: []tf.Input{ - input, filter_sizes, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MapUnstageAttr is an optional argument to MapUnstage. -type MapUnstageAttr func(optionalAttr) - -// MapUnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapUnstageCapacity(value int64) MapUnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapUnstageMemoryLimit(value int64) MapUnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapUnstageContainer sets the optional container attribute to value. 
-// If not specified, defaults to "" -func MapUnstageContainer(value string) MapUnstageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapUnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapUnstageSharedName(value string) MapUnstageAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes and returns the values associated with the key -// -// from the underlying container. If the underlying container -// does not contain this key, the op will block until it does. -func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapUnstage", - Input: []tf.Input{ - key, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapUnstage", err) - return - } - return values -} - -// An op enabling differentiation of TPU Embeddings. -// -// This op simply returns its first input, which is assumed to have been sliced -// from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of -// this op, and its first argument being a trainable Variable, enables automatic -// differentiation of graphs containing embeddings via the TPU Embedding Python -// libraries. -// -// Arguments: -// embedding_variable: A trainable variable, enabling optimizers to find this op. -// sliced_activations: The embedding activations Tensor to return. -// table_id: The id of the table in the embedding layer configuration from which -// these activations were computed. -// lookup_id: Identifier of the set of embedding indices which produced these -// activations. -func TPUEmbeddingActivations(scope *Scope, embedding_variable tf.Output, sliced_activations tf.Output, table_id int64, lookup_id int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"table_id": table_id, "lookup_id": lookup_id} - opspec := tf.OpSpec{ - Type: "TPUEmbeddingActivations", - Input: []tf.Input{ - embedding_variable, sliced_activations, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// BatchToSpace for 4-D tensors of type T. -// -// This is a legacy version of the more general BatchToSpaceND. -// -// Rearranges (permutes) data from batch into blocks of spatial data, followed by -// cropping. This is the reverse transformation of SpaceToBatch. More specifically, -// this op outputs a copy of the input tensor where values from the `batch` -// dimension are moved in spatial blocks to the `height` and `width` dimensions, -// followed by cropping along the `height` and `width` dimensions. -// -// Arguments: -// input: 4-D tensor with shape -// `[batch*block_size*block_size, height_pad/block_size, width_pad/block_size, -// depth]`. Note that the batch size of the input tensor must be divisible by -// `block_size * block_size`. -// crops: 2-D tensor of non-negative integers with shape `[2, 2]`. 
It specifies -// how many elements to crop from the intermediate result across the spatial -// dimensions as follows: -// -// crops = [[crop_top, crop_bottom], [crop_left, crop_right]] -// -// -// Returns 4-D with shape `[batch, height, width, depth]`, where: -// -// height = height_pad - crop_top - crop_bottom -// width = width_pad - crop_left - crop_right -// -// The attr `block_size` must be greater than one. It indicates the block size. -// -// Some examples: -// -// (1) For the following input of shape `[4, 1, 1, 1]` and block_size of 2: -// -// ``` -// [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 1]` and value: -// -// ``` -// x = [[[[1], [2]], [[3], [4]]]] -// ``` -// -// (2) For the following input of shape `[4, 1, 1, 3]` and block_size of 2: -// -// ``` -// [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]] -// ``` -// -// The output tensor has shape `[1, 2, 2, 3]` and value: -// -// ``` -// x = [[[[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]] -// ``` -// -// (3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2: -// -// ``` -// x = [[[[1], [3]], [[9], [11]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -// -// The output tensor has shape `[1, 4, 4, 1]` and value: -// -// ``` -// x = [[[1], [2], [3], [4]], -// [[5], [6], [7], [8]], -// [[9], [10], [11], [12]], -// [[13], [14], [15], [16]]] -// ``` -// -// (4) For the following input of shape `[8, 1, 2, 1]` and block_size of 2: -// -// ``` -// x = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], -// [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] -// ``` -// -// The output tensor has shape `[2, 2, 4, 1]` and value: -// -// ``` -// x = [[[[1], [3]], [[5], [7]]], -// [[[2], [4]], [[10], [12]]], -// [[[5], [7]], [[13], [15]]], -// [[[6], [8]], [[14], [16]]]] -// ``` -func BatchToSpace(scope *Scope, input tf.Output, crops tf.Output, block_size int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"block_size": block_size} - opspec := tf.OpSpec{ - Type: "BatchToSpace", - Input: []tf.Input{ - input, crops, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Produces a summary of any statistics recorded by the given statistics manager. -func ExperimentalStatsAggregatorSummary(scope *Scope, iterator tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ExperimentalStatsAggregatorSummary", - Input: []tf.Input{ - iterator, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Makes a new iterator from the given `dataset` and stores it in `iterator`. -// -// This operation may be executed multiple times. Each execution will reset the -// iterator in `iterator` to the first element of `dataset`. -// -// Returns the created operation. -func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MakeIterator", - Input: []tf.Input{ - dataset, iterator, - }, - } - return scope.AddOperation(opspec) -} - -// Component-wise divides a SparseTensor by a dense Tensor. -// -// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not -// the other direction. -// -// Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. 
-// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. -// -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseDiv(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseDenseCwiseDiv", - Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that batches and pads `batch_size` elements from the input. -// -// Arguments: -// -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// padded_shapes: A list of int64 tensors representing the desired padded shapes -// of the corresponding output components. These shapes may be partially -// specified, using `-1` to indicate that a particular dimension should be -// padded to the maximum size of all batch elements. -// padding_values: A list of scalars containing the padding value to use for -// each of the outputs. -// -func PaddedBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "PaddedBatchDataset", - Input: []tf.Input{ - input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. -type ResourceApplyMomentumAttr func(optionalAttr) - -// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var - lr * momentum * accum, so in the end, the var you get is actually -// var - lr * momentum * accum. -// If not specified, defaults to false -func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update '*var' according to the momentum scheme. Set use_nesterov = True if you -// -// want to use Nesterov momentum. -// -// accum = accum * momentum + grad -// var -= lr * accum -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. -// momentum: Momentum. Must be a scalar. -// -// Returns the created operation. 
-func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyMomentum", - Input: []tf.Input{ - var_, accum, lr, grad, momentum, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. -type MaxPoolGradGradAttr func(optionalAttr) - -// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes second-order gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the last element of the input list as well as a list with all but that element. -// -// Fails if the list is empty. -// -// input_handle: the input list -// tensor: the withdrawn last element of the list -// element_dtype: the type of elements in the list -// element_shape: the shape of the output tensor -func TensorListPopBack(scope *Scope, input_handle tf.Output, element_shape tf.Output, element_dtype tf.DataType) (output_handle tf.Output, tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "TensorListPopBack", - Input: []tf.Input{ - input_handle, element_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Determine the script codes of a given tensor of Unicode integer code points. -// -// This operation converts Unicode code points to script codes corresponding to -// each code point. Script codes correspond to International Components for -// Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html. -// Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will -// match input shape. -// -// Arguments: -// input: A Tensor of int32 Unicode code points. 
-// -// Returns A Tensor of int32 script codes corresponding to each input code point. -func UnicodeScript(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "UnicodeScript", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a sequence of numbers. -// -// This operation creates a sequence of numbers that begins at `start` and -// extends by increments of `delta` up to but not including `limit`. -// -// For example: -// -// ``` -// # 'start' is 3 -// # 'limit' is 18 -// # 'delta' is 3 -// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] -// ``` -// -// Arguments: -// start: 0-D (scalar). First entry in the sequence. -// limit: 0-D (scalar). Upper limit of sequence, exclusive. -// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. -// -// Returns 1-D. -func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Range", - Input: []tf.Input{ - start, limit, delta, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolGradGradWithArgmaxAttr is an optional argument to MaxPoolGradGradWithArgmax. -type MaxPoolGradGradWithArgmaxAttr func(optionalAttr) - -// MaxPoolGradGradWithArgmaxIncludeBatchInIndex sets the optional include_batch_in_index attribute to value. -// -// value: Whether to include batch dimension in flattened index of `argmax`. -// If not specified, defaults to false -func MaxPoolGradGradWithArgmaxIncludeBatchInIndex(value bool) MaxPoolGradGradWithArgmaxAttr { - return func(m optionalAttr) { - m["include_batch_in_index"] = value - } -} - -// Computes second-order gradients of the maxpooling function. -// -// Arguments: -// input: The original input. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the -// input of `max_pool`. -// argmax: The indices of the maximum values chosen for each output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients of gradients w.r.t. the input of `max_pool`. -func MaxPoolGradGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradWithArgmaxAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradGradWithArgmax", - Input: []tf.Input{ - input, grad, argmax, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Return a slice from 'input'. -// -// The output tensor is a tensor with dimensions described by 'size' -// whose values are extracted from 'input' starting at the offsets in -// 'begin'. -// -// *Requirements*: -// 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) -// -// Arguments: -// -// begin: begin[i] specifies the offset into the 'i'th dimension of -// 'input' to slice from. -// size: size[i] specifies the number of elements of the 'i'th dimension -// of 'input' to slice. If size[i] is -1, all remaining elements in dimension -// i are included in the slice (i.e. 
this is equivalent to setting -// size[i] = input.dim_size(i) - begin[i]). -func Slice(scope *Scope, input tf.Output, begin tf.Output, size tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Slice", - Input: []tf.Input{ - input, begin, size, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Compute the Hurwitz zeta function \\(\zeta(x, q)\\). -// -// The Hurwitz zeta function is defined as: -// -// -// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\) -func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Zeta", - Input: []tf.Input{ - x, q, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the cardinality of `input_dataset`. -// -// Returns the cardinality of `input_dataset`. -// -// Arguments: -// input_dataset: A variant tensor representing the dataset to return cardinality for. -// -// Returns The cardinality of `input_dataset`. Named constants are used to represent -// infinite and unknown cardinality. -func ExperimentalDatasetCardinality(scope *Scope, input_dataset tf.Output) (cardinality tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ExperimentalDatasetCardinality", - Input: []tf.Input{ - input_dataset, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TakeManySparseFromTensorsMapAttr is an optional argument to TakeManySparseFromTensorsMap. -type TakeManySparseFromTensorsMapAttr func(optionalAttr) - -// TakeManySparseFromTensorsMapContainer sets the optional container attribute to value. -// -// value: The container name for the `SparseTensorsMap` read by this op. -// If not specified, defaults to "" -func TakeManySparseFromTensorsMapContainer(value string) TakeManySparseFromTensorsMapAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// TakeManySparseFromTensorsMapSharedName sets the optional shared_name attribute to value. -// -// value: The shared name for the `SparseTensorsMap` read by this op. -// It should not be blank; rather the `shared_name` or unique Operation name -// of the Op that created the original `SparseTensorsMap` should be used. -// If not specified, defaults to "" -func TakeManySparseFromTensorsMapSharedName(value string) TakeManySparseFromTensorsMapAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Read `SparseTensors` from a `SparseTensorsMap` and concatenate them. -// -// The input `sparse_handles` must be an `int64` matrix of shape `[N, 1]` where -// `N` is the minibatch size and the rows correspond to the output handles of -// `AddSparseToTensorsMap` or `AddManySparseToTensorsMap`. The ranks of the -// original `SparseTensor` objects that went into the given input ops must all -// match. When the final `SparseTensor` is created, it has rank one -// higher than the ranks of the incoming `SparseTensor` objects -// (they have been concatenated along a new row dimension on the left). -// -// The output `SparseTensor` object's shape values for all dimensions but the -// first are the max across the input `SparseTensor` objects' shape values -// for the corresponding dimensions. Its first shape value is `N`, the minibatch -// size. -// -// The input `SparseTensor` objects' indices are assumed ordered in -// standard lexicographic order. If this is not the case, after this -// step run `SparseReorder` to restore index ordering. 
-// -// For example, if the handles represent an input, which is a `[2, 3]` matrix -// representing two original `SparseTensor` objects: -// -// ``` -// index = [ 0] -// [10] -// [20] -// values = [1, 2, 3] -// shape = [50] -// ``` -// -// and -// -// ``` -// index = [ 2] -// [10] -// values = [4, 5] -// shape = [30] -// ``` -// -// then the final `SparseTensor` will be: -// -// ``` -// index = [0 0] -// [0 10] -// [0 20] -// [1 2] -// [1 10] -// values = [1, 2, 3, 4, 5] -// shape = [2 50] -// ``` -// -// Arguments: -// sparse_handles: 1-D, The `N` serialized `SparseTensor` objects. -// Shape: `[N]`. -// dtype: The `dtype` of the `SparseTensor` objects stored in the -// `SparseTensorsMap`. -// -// Returns 2-D. The `indices` of the minibatch `SparseTensor`.1-D. The `values` of the minibatch `SparseTensor`.1-D. The `shape` of the minibatch `SparseTensor`. -func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype tf.DataType, optional ...TakeManySparseFromTensorsMapAttr) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TakeManySparseFromTensorsMap", - Input: []tf.Input{ - sparse_handles, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ResourceSparseApplyKerasMomentumAttr is an optional argument to ResourceSparseApplyKerasMomentum. -type ResourceSparseApplyKerasMomentumAttr func(optionalAttr) - -// ResourceSparseApplyKerasMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyKerasMomentumUseLocking(value bool) ResourceSparseApplyKerasMomentumAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceSparseApplyKerasMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var + momentum * accum, so in the end, the var you get is actually -// var + momentum * accum. -// If not specified, defaults to false -func ResourceSparseApplyKerasMomentumUseNesterov(value bool) ResourceSparseApplyKerasMomentumAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update relevant entries in '*var' and '*accum' according to the momentum scheme. -// -// Set use_nesterov = True if you want to use Nesterov momentum. -// -// That is for rows we have grad for, we update var and accum as follows: -// -// accum = accum * momentum - lr * grad -// var += accum -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// momentum: Momentum. Must be a scalar. -// -// Returns the created operation. 
-func ResourceSparseApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyKerasMomentumAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyKerasMomentum", - Input: []tf.Input{ - var_, accum, lr, grad, indices, momentum, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// ResourceApplyAdamWithAmsgradAttr is an optional argument to ResourceApplyAdamWithAmsgrad. -type ResourceApplyAdamWithAmsgradAttr func(optionalAttr) - -// ResourceApplyAdamWithAmsgradUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdamWithAmsgradUseLocking(value bool) ResourceApplyAdamWithAmsgradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the Adam algorithm. -// -// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ -// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ -// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ -// $$vhat_t := max{vhat_{t-1}, v_t}$$ -// $$variable := variable - lr_t * m_t / (\sqrt{vhat_t} + \epsilon)$$ -// -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// vhat: Should be from a Variable(). -// beta1_power: Must be a scalar. -// beta2_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyAdamWithAmsgrad(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, vhat tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamWithAmsgradAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAdamWithAmsgrad", - Input: []tf.Input{ - var_, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// MapUnstageNoKeyAttr is an optional argument to MapUnstageNoKey. -type MapUnstageNoKeyAttr func(optionalAttr) - -// MapUnstageNoKeyCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapUnstageNoKeyCapacity(value int64) MapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapUnstageNoKeyMemoryLimit(value int64) MapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapUnstageNoKeyContainer sets the optional container attribute to value. 
-// If not specified, defaults to "" -func MapUnstageNoKeyContainer(value string) MapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapUnstageNoKeySharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapUnstageNoKeySharedName(value string) MapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes and returns a random (key, value) -// -// from the underlying container. If the underlying container -// does not contain elements, the op will block until it does. -func MapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapUnstageNoKey", - Input: []tf.Input{ - indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - key = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapUnstageNoKey", err) - return - } - return key, values -} - -// HashTableV2Attr is an optional argument to HashTableV2. -type HashTableV2Attr func(optionalAttr) - -// HashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func HashTableV2Container(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// HashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func HashTableV2SharedName(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates a non-initialized hash table. -// -// This op creates a hash table, specifying the type of its keys and values. -// Before using the table you will have to initialize it. After initialization the -// table will be immutable. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "HashTableV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingMomentumParametersGradAccumDebug. 
-type RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve Momentum embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the Momentum optimization algorithm.Parameter momenta updated by the Momentum optimization algorithm.Parameter gradient_accumulators updated by the Momentum optimization algorithm. -func RetrieveTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Enqueue a Tensor on the computation outfeed. -// -// Arguments: -// input: A tensor that will be inserted into the outfeed queue. -// -// Returns the created operation. -func OutfeedEnqueue(scope *Scope, input tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "OutfeedEnqueue", - Input: []tf.Input{ - input, - }, - } - return scope.AddOperation(opspec) -} - -// Outputs a `Summary` protocol buffer with a histogram. -// -// The generated -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// has one summary value containing a histogram for `values`. -// -// This op reports an `InvalidArgument` error if any value is not finite. -// -// Arguments: -// tag: Scalar. Tag to use for the `Summary.Value`. -// values: Any shape. Values to use to build the histogram. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "HistogramSummary", - Input: []tf.Input{ - tag, values, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2. -type MutableDenseHashTableV2Attr func(optionalAttr) - -// MutableDenseHashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. 
-// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableDenseHashTableV2Container(value string) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableDenseHashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableDenseHashTableV2SharedName(value string) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableDenseHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// If not specified, defaults to false -func MutableDenseHashTableV2UseNodeNameSharing(value bool) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// MutableDenseHashTableV2ValueShape sets the optional value_shape attribute to value. -// -// value: The shape of each value. -// If not specified, defaults to <> -func MutableDenseHashTableV2ValueShape(value tf.Shape) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["value_shape"] = value - } -} - -// MutableDenseHashTableV2InitialNumBuckets sets the optional initial_num_buckets attribute to value. -// -// value: The initial number of hash table buckets. Must be a power -// to 2. -// If not specified, defaults to 131072 -func MutableDenseHashTableV2InitialNumBuckets(value int64) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["initial_num_buckets"] = value - } -} - -// MutableDenseHashTableV2MaxLoadFactor sets the optional max_load_factor attribute to value. -// -// value: The maximum ratio between number of entries and number of -// buckets before growing the table. Must be between 0 and 1. -// If not specified, defaults to 0.8 -func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2Attr { - return func(m optionalAttr) { - m["max_load_factor"] = value - } -} - -// Creates an empty hash table that uses tensors as the backing store. -// -// It uses "open addressing" with quadratic reprobing to resolve -// collisions. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. -// -// Arguments: -// empty_key: The key used to represent empty key buckets internally. Must not -// be used in insert or lookup operations. -// -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, deleted_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MutableDenseHashTableV2", - Input: []tf.Input{ - empty_key, deleted_key, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingADAMParametersGradAccumDebug. -type RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. 
-// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve ADAM embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the ADAM optimization algorithm.Parameter momenta updated by the ADAM optimization algorithm.Parameter velocities updated by the ADAM optimization algorithm.Parameter gradient_accumulators updated by the ADAM optimization algorithm. -func RetrieveTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// CudnnRNNAttr is an optional argument to CudnnRNN. -type CudnnRNNAttr func(optionalAttr) - -// CudnnRNNRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNRnnMode(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} - -// CudnnRNNInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNInputMode(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} - -// CudnnRNNDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNDirection(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["direction"] = value - } -} - -// CudnnRNNDropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNDropout(value float32) CudnnRNNAttr { - return func(m optionalAttr) { - m["dropout"] = value - } -} - -// CudnnRNNSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNSeed(value int64) CudnnRNNAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// CudnnRNNSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNSeed2(value int64) CudnnRNNAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// CudnnRNNIsTraining sets the optional is_training attribute to value. 
-// If not specified, defaults to true
-func CudnnRNNIsTraining(value bool) CudnnRNNAttr {
-	return func(m optionalAttr) {
-		m["is_training"] = value
-	}
-}
-
-// An RNN backed by cuDNN.
-//
-// Computes the RNN from the input and initial states, with respect to the params
-// buffer.
-//
-// rnn_mode: Indicates the type of the RNN model.
-// input_mode: Indicates whether there is a linear projection between the input and
-// the actual computation before the first layer. 'skip_input' is only allowed
-// when input_size == num_units; 'auto_select' implies 'skip_input' when
-// input_size == num_units; otherwise, it implies 'linear_input'.
-// direction: Indicates whether a bidirectional model will be used. Should be
-// "unidirectional" or "bidirectional".
-// dropout: Dropout probability. When set to 0., dropout is disabled.
-// seed: The 1st part of a seed to initialize dropout.
-// seed2: The 2nd part of a seed to initialize dropout.
-// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
-// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
-// num_units].
-// input_c: For LSTM, a 3-D tensor with the shape of
-// [num_layer * dir, batch, num_units]. For other models, it is ignored.
-// params: A 1-D tensor that contains the weights and biases in an opaque layout.
-// The size must be created through CudnnRNNParamsSize, and initialized
-// separately. Note that they might not be compatible across different
-// generations. So it is a good idea to save and restore
-// output: A 3-D tensor with the shape of [seq_length, batch_size,
-// dir * num_units].
-// output_h: The same shape as input_h.
-// output_c: The same shape as input_c for LSTM. An empty tensor for other models.
-// is_training: Indicates whether this operation is used for inference or
-// training.
-// reserve_space: An opaque tensor that can be used in backprop calculation. It
-// is only produced if is_training is true.
-func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNAttr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "CudnnRNN",
-		Input: []tf.Input{
-			input, input_h, input_c, params,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
-}
-
-// DecodeCompressedAttr is an optional argument to DecodeCompressed.
-type DecodeCompressedAttr func(optionalAttr)
-
-// DecodeCompressedCompressionType sets the optional compression_type attribute to value.
-//
-// value: A scalar containing either (i) the empty string (no
-// compression), (ii) "ZLIB", or (iii) "GZIP".
-// If not specified, defaults to ""
-func DecodeCompressedCompressionType(value string) DecodeCompressedAttr {
-	return func(m optionalAttr) {
-		m["compression_type"] = value
-	}
-}
-
-// Decompress strings.
-//
-// This op decompresses each element of the `bytes` input `Tensor`, which
-// is assumed to be compressed using the given `compression_type`.
-//
-// The `output` is a string `Tensor` of the same shape as `bytes`,
-// each element containing the decompressed data from the corresponding
-// element in `bytes`.
-//
-// Arguments:
-// bytes: A Tensor of string which is compressed.
-//
-// Returns A Tensor with the same shape as input `bytes`, uncompressed
-// from bytes.
-func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeCompressed", - Input: []tf.Input{ - bytes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to RetrieveTPUEmbeddingMDLAdagradLightParameters. -type RetrieveTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingMDLAdagradLightParametersTableId(value int64) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingMDLAdagradLightParametersTableName(value string) RetrieveTPUEmbeddingMDLAdagradLightParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve MDL Adagrad Light embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the MDL Adagrad Light optimization algorithm.Parameter accumulators updated by the MDL Adagrad Light optimization algorithm.Parameter weights updated by the MDL Adagrad Light optimization algorithm.Parameter benefits updated by the MDL Adagrad Light optimization algorithm. -func RetrieveTPUEmbeddingMDLAdagradLightParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMDLAdagradLightParametersAttr) (parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingMDLAdagradLightParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug. -type RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableName sets the optional table_name attribute to value. 
-// If not specified, defaults to "" -func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve Adadelta embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the Adadelta optimization algorithm.Parameter accumulators updated by the Adadelta optimization algorithm.Parameter updates updated by the Adadelta optimization algorithm.Parameter gradient_accumulators updated by the Adadelta optimization algorithm. -func RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, updates tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// MapClearAttr is an optional argument to MapClear. -type MapClearAttr func(optionalAttr) - -// MapClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapClearCapacity(value int64) MapClearAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapClearMemoryLimit(value int64) MapClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapClearContainer(value string) MapClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapClearSharedName(value string) MapClearAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes all elements in the underlying container. -// -// Returns the created operation. -func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapClear", - - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// DecodeCSVAttr is an optional argument to DecodeCSV. -type DecodeCSVAttr func(optionalAttr) - -// DecodeCSVFieldDelim sets the optional field_delim attribute to value. -// -// value: char delimiter to separate fields in a record. -// If not specified, defaults to "," -func DecodeCSVFieldDelim(value string) DecodeCSVAttr { - return func(m optionalAttr) { - m["field_delim"] = value - } -} - -// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value. 
-// -// value: If false, treats double quotation marks as regular -// characters inside of the string fields (ignoring RFC 4180, Section 2, -// Bullet 5). -// If not specified, defaults to true -func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr { - return func(m optionalAttr) { - m["use_quote_delim"] = value - } -} - -// DecodeCSVNaValue sets the optional na_value attribute to value. -// -// value: Additional string to recognize as NA/NaN. -// If not specified, defaults to "" -func DecodeCSVNaValue(value string) DecodeCSVAttr { - return func(m optionalAttr) { - m["na_value"] = value - } -} - -// DecodeCSVSelectCols sets the optional select_cols attribute to value. -// If not specified, defaults to <> -func DecodeCSVSelectCols(value []int64) DecodeCSVAttr { - return func(m optionalAttr) { - m["select_cols"] = value - } -} - -// Convert CSV records to tensors. Each column maps to one tensor. -// -// RFC 4180 format is expected for the CSV records. -// (https://tools.ietf.org/html/rfc4180) -// Note that we allow leading and trailing spaces with int or float field. -// -// Arguments: -// records: Each string is a record/row in the csv and all records should have -// the same format. -// record_defaults: One tensor per column of the input record, with either a -// scalar default value for that column or an empty vector if the column is -// required. -// -// Returns Each tensor will have the same shape as records. -func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeCSV", - Input: []tf.Input{ - records, tf.OutputList(record_defaults), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("DecodeCSV", err) - return - } - return output -} - -// Produces the max pool of the input tensor for quantized types. -// -// Arguments: -// input: The 4D (batch x rows x cols x depth) Tensor to MaxReduce over. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// ksize: The size of the window for each dimension of the input tensor. -// The length must be 4 to match the number of dimensions of the input. -// strides: The stride of the sliding window for each dimension of the input -// tensor. The length must be 4 to match the number of dimensions of the input. -// padding: The type of padding algorithm to use. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedMaxPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "QuantizedMaxPool", - Input: []tf.Input{ - input, min_input, max_input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// RandomShuffleAttr is an optional argument to RandomShuffle. 
-type RandomShuffleAttr func(optionalAttr) - -// RandomShuffleSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomShuffleSeed(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomShuffleSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleSeed2(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Randomly shuffles a tensor along its first dimension. -// -// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped -// to one and only one `output[i]`. For example, a mapping that might occur for a -// 3x2 tensor is: -// -// ``` -// [[1, 2], [[5, 6], -// [3, 4], ==> [1, 2], -// [5, 6]] [3, 4]] -// ``` -// -// Arguments: -// value: The tensor to be shuffled. -// -// Returns A tensor of same shape and type as `value`, shuffled along its first -// dimension. -func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomShuffle", - Input: []tf.Input{ - value, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// EnqueueTPUEmbeddingSparseBatchAttr is an optional argument to EnqueueTPUEmbeddingSparseBatch. -type EnqueueTPUEmbeddingSparseBatchAttr func(optionalAttr) - -// EnqueueTPUEmbeddingSparseBatchDeviceOrdinal sets the optional device_ordinal attribute to value. -// -// value: The TPU device to use. Should be >= 0 and less than the number -// of TPU cores in the task on which the node is placed. -// If not specified, defaults to -1 -func EnqueueTPUEmbeddingSparseBatchDeviceOrdinal(value int64) EnqueueTPUEmbeddingSparseBatchAttr { - return func(m optionalAttr) { - m["device_ordinal"] = value - } -} - -// EnqueueTPUEmbeddingSparseBatchCombiners sets the optional combiners attribute to value. -// -// value: A list of string scalars, one for each embedding table that specify -// how to normalize the embedding activations after weighted summation. -// Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have -// the sum of the weights be 0 for 'mean' or the sum of the squared weights be -// 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for -// all tables. -// If not specified, defaults to <> -func EnqueueTPUEmbeddingSparseBatchCombiners(value []string) EnqueueTPUEmbeddingSparseBatchAttr { - return func(m optionalAttr) { - m["combiners"] = value - } -} - -// An op that enqueues TPUEmbedding input indices from a SparseTensor. -// -// This Op eases the porting of code that uses embedding_lookup_sparse(), -// although some Python preprocessing of the SparseTensor arguments to -// embedding_lookup_sparse() is required to produce the arguments to this Op, -// since only a single EnqueueTPUEmbeddingSparseBatch Op is allowed per training -// step. -// -// The tensors at corresponding positions in the three input lists -// must have the same shape, i.e. rank 1 with dim_size() equal to the total -// number of lookups into the table described by the corresponding table_id. 
-// -// Arguments: -// sample_indices: A list of rank 1 Tensors specifying the training example and -// feature to which the corresponding embedding_indices and aggregation_weights -// values belong. sample_indices[i] must equal b * nf + f, where nf is the -// number of features from the corresponding table, f is in [0, nf), and -// b is in [0, batch size). -// embedding_indices: A list of rank 1 Tensors, indices into the embedding tables. -// aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e. per -// (training example, feature) -- aggregation weights. -// mode_override: A string input that overrides the mode specified in the -// TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference', -// 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set -// in TPUEmbeddingConfiguration is used, otherwise mode_override is used. -// -// Returns the created operation. -func EnqueueTPUEmbeddingSparseBatch(scope *Scope, sample_indices []tf.Output, embedding_indices []tf.Output, aggregation_weights []tf.Output, mode_override tf.Output, optional ...EnqueueTPUEmbeddingSparseBatchAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EnqueueTPUEmbeddingSparseBatch", - Input: []tf.Input{ - tf.OutputList(sample_indices), tf.OutputList(embedding_indices), tf.OutputList(aggregation_weights), mode_override, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. -type StatelessRandomNormalAttr func(optionalAttr) - -// StatelessRandomNormalDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// -// Returns Random values with specified shape. -func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatelessRandomNormal", - Input: []tf.Input{ - shape, seed, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// An Op to exchange data across TPU replicas. -// -// On each replica, the input is split into `split_count` blocks along -// `split_dimension` and send to the other replicas given group_assignment. After -// receiving `split_count` - 1 blocks from other replicas, we concatenate the -// blocks along `concat_dimension` as the output. -// -// For example, suppose there are 2 TPU replicas: -// replica 0 receives input: `[[A, B]]` -// replica 1 receives input: `[[C, D]]` -// -// group_assignment=`[[0, 1]]` -// concat_dimension=0 -// split_dimension=1 -// split_count=2 -// -// replica 0's output: `[[A], [C]]` -// replica 1's output: `[[B], [D]]` -// -// Arguments: -// input: The local input to the sum. 
-// group_assignment: An int32 tensor with shape -// [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the -// replica ids in the ith subgroup. -// concat_dimension: The dimension number to concatenate. -// split_dimension: The dimension number to split. -// split_count: The number of splits, this number must equal to the sub-group -// size(group_assignment.get_shape()[1]) -// -// Returns The exchanged result. -func AllToAll(scope *Scope, input tf.Output, group_assignment tf.Output, concat_dimension int64, split_dimension int64, split_count int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"concat_dimension": concat_dimension, "split_dimension": split_dimension, "split_count": split_count} - opspec := tf.OpSpec{ - Type: "AllToAll", - Input: []tf.Input{ - input, group_assignment, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adds a value to the current value of a variable. -// -// Any ReadVariableOp with a control dependency on this op is guaranteed to -// see the incremented value or a subsequent newer one. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value by which the variable will be incremented. -// -// Returns the created operation. -func AssignAddVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AssignAddVariableOp", - Input: []tf.Input{ - resource, value, - }, - } - return scope.AddOperation(opspec) -} - -// Real-valued fast Fourier transform. -// -// Computes the 1-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most dimension of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the -// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term, -// followed by the `fft_length / 2` positive-frequency terms. -// -// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [1]. The FFT length. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most -// dimension of `input` is replaced with the `fft_length / 2 + 1` unique -// frequency components of its 1D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.rfft -// @end_compatibility -func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RFFT", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingAdadeltaParametersAttr is an optional argument to RetrieveTPUEmbeddingAdadeltaParameters. -type RetrieveTPUEmbeddingAdadeltaParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingAdadeltaParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingAdadeltaParametersTableId(value int64) RetrieveTPUEmbeddingAdadeltaParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingAdadeltaParametersTableName sets the optional table_name attribute to value. 
-// If not specified, defaults to "" -func RetrieveTPUEmbeddingAdadeltaParametersTableName(value string) RetrieveTPUEmbeddingAdadeltaParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve Adadelta embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the Adadelta optimization algorithm.Parameter accumulators updated by the Adadelta optimization algorithm.Parameter updates updated by the Adadelta optimization algorithm. -func RetrieveTPUEmbeddingAdadeltaParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingAdadeltaParametersAttr) (parameters tf.Output, accumulators tf.Output, updates tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingAdadeltaParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// UpperBoundAttr is an optional argument to UpperBound. -type UpperBoundAttr func(optionalAttr) - -// UpperBoundOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func UpperBoundOutType(value tf.DataType) UpperBoundAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Applies upper_bound(sorted_search_values, values) along each row. -// -// Each set of rows with the same index in (sorted_inputs, values) is treated -// independently. The resulting row is the equivalent of calling -// `np.searchsorted(sorted_inputs, values, side='right')`. -// -// The result is not a global index to the entire -// `Tensor`, but rather just the index in the last dimension. -// -// A 2-D example: -// sorted_sequence = [[0, 3, 9, 9, 10], -// [1, 2, 3, 4, 5]] -// values = [[2, 4, 9], -// [0, 2, 6]] -// -// result = UpperBound(sorted_sequence, values) -// -// result == [[1, 2, 4], -// [0, 2, 5]] -// -// Arguments: -// sorted_inputs: 2-D Tensor where each row is ordered. -// values: 2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains -// the values that will be searched for in `sorted_search_values`. -// -// Returns A `Tensor` with the same shape as `values`. It contains the last scalar index -// into the last dimension where values can be inserted without changing the -// ordered property. -func UpperBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...UpperBoundAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UpperBound", - Input: []tf.Input{ - sorted_inputs, values, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FractionalMaxPoolGradAttr is an optional argument to FractionalMaxPoolGrad. -type FractionalMaxPoolGradAttr func(optionalAttr) - -// FractionalMaxPoolGradOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. 
For example: -// -// `index 0 1 2 3 4` -// -// `value 20 5 16 3 7` -// -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [20, 16] for fractional max pooling. -// If not specified, defaults to false -func FractionalMaxPoolGradOverlapping(value bool) FractionalMaxPoolGradAttr { - return func(m optionalAttr) { - m["overlapping"] = value - } -} - -// Computes gradient of the FractionalMaxPool function. -// -// Arguments: -// orig_input: Original input for `fractional_max_pool` -// orig_output: Original output for `fractional_max_pool` -// out_backprop: 4-D with shape `[batch, height, width, channels]`. Gradients -// w.r.t. the output of `fractional_max_pool`. -// row_pooling_sequence: row pooling sequence, form pooling region with -// col_pooling_sequence. -// col_pooling_sequence: column pooling sequence, form pooling region with -// row_pooling sequence. -// -// Returns 4-D. Gradients w.r.t. the input of `fractional_max_pool`. -func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, out_backprop tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output, optional ...FractionalMaxPoolGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FractionalMaxPoolGrad", - Input: []tf.Input{ - orig_input, orig_output, out_backprop, row_pooling_sequence, col_pooling_sequence, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SparseReduceMaxSparseAttr is an optional argument to SparseReduceMaxSparse. -type SparseReduceMaxSparseAttr func(optionalAttr) - -// SparseReduceMaxSparseKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceMaxSparseKeepDims(value bool) SparseReduceMaxSparseAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the max of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_max()`. In contrast to SparseReduceMax, this Op returns a -// SparseTensor. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. -// -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. -// -// Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. 
-func SparseReduceMaxSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseReduceMaxSparse", - Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Convert one or more images from HSV to RGB. -// -// Outputs a tensor of the same shape as the `images` tensor, containing the RGB -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. -// -// See `rgb_to_hsv` for a description of the HSV encoding. -// -// Arguments: -// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. -// -// Returns `images` converted to RGB. -func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "HSVToRGB", - Input: []tf.Input{ - images, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient of the sigmoid of `x` wrt its input. -// -// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and -// `dy` is the corresponding input gradient. -func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SigmoidGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that changes the batch size. -// -// Creates a dataset that changes the batch size of the dataset to current batch -// size // num_workers. -// -// Arguments: -// input_dataset: A variant tensor representing the input dataset. -// num_workers: A scalar representing the number of workers to distribute this batch across. As -// a result of this transformation the current batch size would end up being -// divided by this parameter. -// -// -func ExperimentalRebatchDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalRebatchDataset", - Input: []tf.Input{ - input_dataset, num_workers, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that emits the outputs of `input_dataset` `count` times. -// -// Arguments: -// -// count: A scalar representing the number of times that `input_dataset` should -// be repeated. A value of `-1` indicates that it should be repeated infinitely. 
-// -// -func RepeatDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "RepeatDataset", - Input: []tf.Input{ - input_dataset, count, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyAdagradDAAttr is an optional argument to ResourceApplyAdagradDA. -type ResourceApplyAdagradDAAttr func(optionalAttr) - -// ResourceApplyAdagradDAUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyAdagradDAUseLocking(value bool) ResourceApplyAdagradDAAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the proximal adagrad scheme. -// -// Arguments: -// var_: Should be from a Variable(). -// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. -// -// Returns the created operation. -func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceApplyAdagradDAAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAdagradDA", - Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, lr, l1, l2, global_step, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Creates a TensorList which, when stacked, has the value of `tensor`. -// -// Each tensor in the result list corresponds to one row of the input tensor. -// -// tensor: The input tensor. -// output_handle: The list. -func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Output) (output_handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorListFromTensor", - Input: []tf.Input{ - tensor, element_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ConfigureDistributedTPUAttr is an optional argument to ConfigureDistributedTPU. -type ConfigureDistributedTPUAttr func(optionalAttr) - -// ConfigureDistributedTPUEmbeddingConfig sets the optional embedding_config attribute to value. -// -// value: Reserved. Do not use. -// If not specified, defaults to "" -func ConfigureDistributedTPUEmbeddingConfig(value string) ConfigureDistributedTPUAttr { - return func(m optionalAttr) { - m["embedding_config"] = value - } -} - -// ConfigureDistributedTPUTpuEmbeddingConfig sets the optional tpu_embedding_config attribute to value. -// -// value: Serialized tensorflow.tpu.TPUEmbeddingConfiguration that -// describes the embedding lookups of the program. 
-// If not specified, defaults to "" -func ConfigureDistributedTPUTpuEmbeddingConfig(value string) ConfigureDistributedTPUAttr { - return func(m optionalAttr) { - m["tpu_embedding_config"] = value - } -} - -// ConfigureDistributedTPUIsGlobalInit sets the optional is_global_init attribute to value. -// -// value: Reserved. Do not use. -// If not specified, defaults to false -func ConfigureDistributedTPUIsGlobalInit(value bool) ConfigureDistributedTPUAttr { - return func(m optionalAttr) { - m["is_global_init"] = value - } -} - -// Sets up the centralized structures for a distributed TPU system. -// -// Returns A serialized tensorflow.tpu.TopologyProto that describes the TPU -// topology. -func ConfigureDistributedTPU(scope *Scope, optional ...ConfigureDistributedTPUAttr) (topology tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ConfigureDistributedTPU", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reshapes a quantized tensor as per the Reshape op. -// -// ``` -// -// Arguments: -// -// shape: Defines the shape of the output tensor. -// input_min: The minimum value of the input. -// input_max: The maximum value of the input. -// -// Returns This value is copied from input_min.This value is copied from input_max. -func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "QuantizedReshape", - Input: []tf.Input{ - tensor, shape, input_min, input_max, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. -type PriorityQueueV2Attr func(optionalAttr) - -// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. -// -// value: The type of each component in a value. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["component_types"] = value - } -} - -// PriorityQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// PriorityQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func PriorityQueueV2Container(value string) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// PriorityQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that produces elements sorted by the first component value. 
-// -// Note that the PriorityQueue requires the first component of any element -// to be a scalar int64, in addition to the other elements declared by -// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue -// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra -// entry in their input (resp. output) lists. -// -// Arguments: -// shapes: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// -// Returns The handle to the queue. -func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shapes": shapes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "PriorityQueueV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. -type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) - -// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Sparse update '*var' as FOBOS algorithm with fixed learning rate. -// -// That is for rows we have grad for, we update var as follows: -// prox_v = var - alpha * grad -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// -// Returns the created operation. -func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalGradientDescent", - Input: []tf.Input{ - var_, alpha, l1, l2, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Check if the input matches the regex pattern. -// -// The input is a string tensor of any shape. The pattern is the -// regular expression to be matched with every element of the input tensor. -// The boolean values (True or False) of the output tensor indicate -// if the input matches the regex pattern provided. -// -// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) -// -// Arguments: -// input: A string tensor of the text to be processed. -// pattern: The regular expression to match the input. -// -// Returns A bool tensor with the same shape as `input`. 
-func StaticRegexFullMatch(scope *Scope, input tf.Output, pattern string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"pattern": pattern} - opspec := tf.OpSpec{ - Type: "StaticRegexFullMatch", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// OutfeedDequeueAttr is an optional argument to OutfeedDequeue. -type OutfeedDequeueAttr func(optionalAttr) - -// OutfeedDequeueDeviceOrdinal sets the optional device_ordinal attribute to value. -// -// value: The TPU device to use. This should be -1 when the Op -// is running on a TPU device, and >= 0 when the Op is running on the CPU -// device. -// If not specified, defaults to -1 -func OutfeedDequeueDeviceOrdinal(value int64) OutfeedDequeueAttr { - return func(m optionalAttr) { - m["device_ordinal"] = value - } -} - -// Retrieves a single tensor from the computation outfeed. -// -// This operation will block indefinitely until data is available. -// -// Arguments: -// dtype: The type of elements in the tensor. -// shape: The shape of the tensor. -// -// Returns A tensor that will be read from the device outfeed. -func OutfeedDequeue(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...OutfeedDequeueAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OutfeedDequeue", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RandomPoissonV2Attr is an optional argument to RandomPoissonV2. -type RandomPoissonV2Attr func(optionalAttr) - -// RandomPoissonV2Seed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomPoissonV2Seed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// RandomPoissonV2Dtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_INT64 -func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs random values from the Poisson distribution(s) described by rate. -// -// This op uses two algorithms, depending on rate. If rate >= 10, then -// the algorithm by Hormann is used to acquire samples via -// transformation-rejection. -// See http://www.sciencedirect.com/science/article/pii/0167668793909974. -// -// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform -// random variables. -// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer -// Programming, Volume 2. Addison Wesley -// -// Arguments: -// shape: 1-D integer tensor. Shape of independent samples to draw from each -// distribution described by the shape parameters given in rate. -// rate: A tensor in which each scalar is a "rate" parameter describing the -// associated poisson distribution. -// -// Returns A tensor with shape `shape + shape(rate)`. 
Each slice -// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for -// `rate[i0, i1, ...iN]`. -func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomPoissonV2", - Input: []tf.Input{ - shape, rate, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug. -type RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve RMSProp embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the RMSProp optimization algorithm.Parameter ms updated by the RMSProp optimization algorithm.Parameter mom updated by the RMSProp optimization algorithm.Parameter gradient_accumulators updated by the RMSProp optimization algorithm. -func RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugAttr) (parameters tf.Output, ms tf.Output, mom tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// Computes the gradient for the rsqrt of `x` wrt its input. -// -// Specifically, `grad = dy * -0.5 * y^3`, where `y = rsqrt(x)`, and `dy` -// is the corresponding input gradient. -func RsqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RsqrtGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Encode audio data using the WAV file format. -// -// This operation will generate a string suitable to be saved out to create a .wav -// audio file. It will be encoded in the 16-bit PCM format. It takes in float -// values in the range -1.0f to 1.0f, and any outside that value will be clamped to -// that range. 
-// -// `audio` is a 2-D float Tensor of shape `[length, channels]`. -// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). -// -// Arguments: -// audio: 2-D with shape `[length, channels]`. -// sample_rate: Scalar containing the sample frequency. -// -// Returns 0-D. WAV-encoded file contents. -func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "EncodeWav", - Input: []tf.Input{ - audio, sample_rate, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax. -type ResourceApplyAdaMaxAttr func(optionalAttr) - -// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the AdaMax algorithm. -// -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g -// v_t <- max(beta2 * v_{t-1}, abs(g)) -// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) -// -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// beta1_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAdaMax", - Input: []tf.Input{ - var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes atan of x element-wise. -func Atan(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Atan", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AssertAttr is an optional argument to Assert. -type AssertAttr func(optionalAttr) - -// AssertSummarize sets the optional summarize attribute to value. -// -// value: Print this many entries of each tensor. -// If not specified, defaults to 3 -func AssertSummarize(value int64) AssertAttr { - return func(m optionalAttr) { - m["summarize"] = value - } -} - -// Asserts that the given condition is true. -// -// If `condition` evaluates to false, print the list of tensors in `data`. -// `summarize` determines how many entries of the tensors to print. -// -// Arguments: -// condition: The condition to evaluate. -// data: The tensors to print out when condition is false. -// -// Returns the created operation. 
-func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Assert", - Input: []tf.Input{ - condition, tf.OutputList(data), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingAdagradParametersGradAccumDebug. -type LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr func(optionalAttr) - -// LoadTPUEmbeddingAdagradParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingAdagradParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingAdagradParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingAdagradParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load Adagrad embedding parameters with debug support. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the Adagrad optimization algorithm. -// accumulators: Value of accumulators used in the Adagrad optimization algorithm. -// gradient_accumulators: Value of gradient_accumulators used in the Adagrad optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingAdagradParametersGradAccumDebug(scope *Scope, parameters tf.Output, accumulators tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdagradParametersGradAccumDebugAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingAdagradParametersGradAccumDebug", - Input: []tf.Input{ - parameters, accumulators, gradient_accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingFTRLParametersGradAccumDebug. -type RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName sets the optional table_name attribute to value. 
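Editorial note: a hedged sketch of using the Assert wrapper above, assuming the standard Go binding import paths and that the Greater and All wrappers (generated elsewhere in this file) are available; the function and variable names are illustrative only.

package assertexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// checkPositive adds an Assert node that fails the step when any element of
// the float32 vector x is not strictly positive, printing up to five entries
// of x in the error message. Pass the returned *tf.Operation in the targets
// argument of Session.Run to force the check.
func checkPositive(s *op.Scope, x tf.Output) *tf.Operation {
	zero := op.Const(s, float32(0))
	// Reduce the element-wise comparison to a scalar bool condition.
	cond := op.All(s, op.Greater(s, x, zero), op.Const(s, []int32{0}))
	return op.Assert(s, cond, []tf.Output{x}, op.AssertSummarize(5))
}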
-// If not specified, defaults to "" -func RetrieveTPUEmbeddingFTRLParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve FTRL embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the FTRL optimization algorithm.Parameter accumulators updated by the FTRL optimization algorithm.Parameter linears updated by the FTRL optimization algorithm.Parameter gradient_accumulators updated by the FTRL optimization algorithm. -func RetrieveTPUEmbeddingFTRLParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingFTRLParametersGradAccumDebugAttr) (parameters tf.Output, accumulators tf.Output, linears tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// A dataset that splits the elements of its input into multiple elements. -func ExperimentalUnbatchDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalUnbatchDataset", - Input: []tf.Input{ - input_dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StringFormatAttr is an optional argument to StringFormat. -type StringFormatAttr func(optionalAttr) - -// StringFormatTemplate sets the optional template attribute to value. -// -// value: A string, the template to format tensor summaries into. -// If not specified, defaults to "%s" -func StringFormatTemplate(value string) StringFormatAttr { - return func(m optionalAttr) { - m["template"] = value - } -} - -// StringFormatPlaceholder sets the optional placeholder attribute to value. -// -// value: A string, at each placeholder in the template a subsequent tensor summary will be inserted. -// If not specified, defaults to "%s" -func StringFormatPlaceholder(value string) StringFormatAttr { - return func(m optionalAttr) { - m["placeholder"] = value - } -} - -// StringFormatSummarize sets the optional summarize attribute to value. -// -// value: When formatting the tensor summaries print the first and last summarize entries of each tensor dimension. -// If not specified, defaults to 3 -func StringFormatSummarize(value int64) StringFormatAttr { - return func(m optionalAttr) { - m["summarize"] = value - } -} - -// Formats a string template using a list of tensors. -// -// Formats a string template using a list of tensors, pretty-printing tensor summaries. -// -// Arguments: -// inputs: The list of tensors to format into the placeholder string. -// -// Returns = The resulting string scalar. 
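Editorial note: a minimal sketch of the StringFormat wrapper documented just above (its function body follows below); the package name and attribute values are assumptions.

package formatexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// describeTensor builds a string scalar that embeds a pretty-printed summary
// of t into the given template, keeping at most three leading and trailing
// entries per dimension.
func describeTensor(s *op.Scope, t tf.Output) tf.Output {
	return op.StringFormat(s, []tf.Output{t},
		op.StringFormatTemplate("values: %v"),
		op.StringFormatPlaceholder("%v"),
		op.StringFormatSummarize(3))
}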
-func StringFormat(scope *Scope, inputs []tf.Output, optional ...StringFormatAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StringFormat", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns true if queue is closed. -// -// This operation returns true if the queue is closed and false if the queue -// is open. -// -// Arguments: -// handle: The handle to a queue. -func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "QueueIsClosedV2", - Input: []tf.Input{ - handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes inverse hyperbolic tangent of x element-wise. -func Atanh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Atanh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the reverse mode backpropagated gradient of the Cholesky algorithm. -// -// For an explanation see "Differentiation of the Cholesky algorithm" by -// Iain Murray http://arxiv.org/abs/1602.07527. -// -// Arguments: -// l: Output of batch Cholesky algorithm l = cholesky(A). Shape is `[..., M, M]`. -// Algorithm depends only on lower triangular part of the innermost matrices of -// this tensor. -// grad: df/dl where f is some scalar function. Shape is `[..., M, M]`. -// Algorithm depends only on lower triangular part of the innermost matrices of -// this tensor. -// -// Returns Symmetrized version of df/dA . Shape is `[..., M, M]` -func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "CholeskyGrad", - Input: []tf.Input{ - l, grad, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Assigns a new value to a variable. -// -// Any ReadVariableOp with a control dependency on this op is guaranteed to return -// this value or a subsequent newer value of the variable. -// -// Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value to set the new tensor to use. -// -// Returns the created operation. -func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AssignVariableOp", - Input: []tf.Input{ - resource, value, - }, - } - return scope.AddOperation(opspec) -} - -// Returns a tensor of ones with the same shape and type as x. -// -// Arguments: -// x: a tensor of type T. -// -// Returns a tensor of the same shape and type as x but filled with ones. -func OnesLike(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "OnesLike", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// The gradient of SparseFillEmptyRows. -// -// Takes vectors reverse_index_map, shaped `[N]`, and grad_values, -// shaped `[N_full]`, where `N_full >= N` and copies data into either -// `d_values` or `d_default_value`. Here `d_values` is shaped `[N]` and -// `d_default_value` is a scalar. -// -// d_values[j] = grad_values[reverse_index_map[j]] -// d_default_value = sum_{k : 0 .. 
N_full - 1} ( -// grad_values[k] * 1{k not in reverse_index_map}) -// -// Arguments: -// reverse_index_map: 1-D. The reverse index map from SparseFillEmptyRows. -// grad_values: 1-D. The gradients from backprop. -// -// Returns 1-D. The backprop into values.0-D. The backprop into default_value. -func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_values tf.Output) (d_values tf.Output, d_default_value tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseFillEmptyRowsGrad", - Input: []tf.Input{ - reverse_index_map, grad_values, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Creates a dataset that zips together `input_datasets`. -func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ZipDataset", - Input: []tf.Input{ - tf.OutputList(input_datasets), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LoadTPUEmbeddingAdagradParametersAttr is an optional argument to LoadTPUEmbeddingAdagradParameters. -type LoadTPUEmbeddingAdagradParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingAdagradParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingAdagradParametersTableId(value int64) LoadTPUEmbeddingAdagradParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingAdagradParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingAdagradParametersTableName(value string) LoadTPUEmbeddingAdagradParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load Adagrad embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the Adagrad optimization algorithm. -// accumulators: Value of accumulators used in the Adagrad optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingAdagradParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdagradParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingAdagradParameters", - Input: []tf.Input{ - parameters, accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Strip leading and trailing whitespaces from the Tensor. -// -// Arguments: -// input: A string `Tensor` of any shape. -// -// Returns A string `Tensor` of the same shape as the input. 
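Editorial note: a minimal sketch of the StringStrip wrapper documented just above (its function body follows below); the sample strings are arbitrary.

package stripexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// stripAll trims leading and trailing whitespace from every element of a
// string tensor, e.g. {"  hello ", "\tworld\n"} becomes {"hello", "world"}.
func stripAll(s *op.Scope) tf.Output {
	raw := op.Const(s, []string{"  hello ", "\tworld\n"})
	return op.StringStrip(s, raw)
}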
-func StringStrip(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "StringStrip", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. The hash function is a keyed hash function, where attribute `key` -// defines the key of the hash function. `key` is an array of 2 elements. -// -// A strong hash is important when inputs may be malicious, e.g. URLs with -// additional components. Adversaries could try to make their inputs hash to the -// same bucket for a denial-of-service attack or to skew the results. A strong -// hash prevents this by making it difficult, if not infeasible, to compute inputs -// that hash to the same bucket. This comes at a cost of roughly 4x higher compute -// time than `tf.string_to_hash_bucket_fast`. -// -// Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. -// key: The key for the keyed hash function passed as a list of two uint64 -// elements. -// -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, key []int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_buckets": num_buckets, "key": key} - opspec := tf.OpSpec{ - Type: "StringToHashBucketStrong", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StringLengthAttr is an optional argument to StringLength. -type StringLengthAttr func(optionalAttr) - -// StringLengthUnit sets the optional unit attribute to value. -// -// value: The unit that is counted to compute string length. One of: `"BYTE"` (for -// the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8 -// encoded Unicode code points in each string). Results are undefined -// if `unit=UTF8_CHAR` and the `input` strings do not contain structurally -// valid UTF-8. -// If not specified, defaults to "BYTE" -func StringLengthUnit(value string) StringLengthAttr { - return func(m optionalAttr) { - m["unit"] = value - } -} - -// String lengths of `input`. -// -// Computes the length of each string given in the input tensor. -// -// Arguments: -// input: The string for which to compute the length. -// -// Returns Integer tensor that has the same shape as `input`. The output contains the -// element-wise string lengths of `input`. -func StringLength(scope *Scope, input tf.Output, optional ...StringLengthAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StringLength", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Performs gradient updates of embedding tables. -// -// Arguments: -// inputs: A TensorList of gradients with which to update embedding tables. -// This argument has the same length and shapes as the return value of -// RecvTPUEmbeddingActivations, but contains gradients of the model's loss -// with respect to the embedding activations. 
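Editorial note: a hedged sketch of the StringToHashBucketStrong wrapper defined earlier in this hunk; the bucket count and the two-element key are arbitrary example values, not recommendations.

package hashexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// bucketize maps each input string to one of 1024 buckets using the keyed
// hash described above; the key is supplied as a []int64 attribute holding
// the two uint64 halves.
func bucketize(s *op.Scope, urls tf.Output) tf.Output {
	return op.StringToHashBucketStrong(s, urls, 1024, []int64{0x1234, 0x5678})
}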
The embedding tables are updated -// from these gradients via the optimizer specified in the TPU embedding -// configuration given to tpu.initialize_system. -// learning_rates: A TensorList of float32 scalars, one for each dynamic learning -// rate tag: see the comments in -// //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto. -// Multiple tables can share the same dynamic learning rate tag as specified -// in the configuration. If the learning rates for all tables are constant, -// this list should be empty. -// config: Serialized TPUEmbeddingConfiguration proto. -// -// Returns the created operation. -func SendTPUEmbeddingGradients(scope *Scope, inputs []tf.Output, learning_rates []tf.Output, config string) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"config": config} - opspec := tf.OpSpec{ - Type: "SendTPUEmbeddingGradients", - Input: []tf.Input{ - tf.OutputList(inputs), tf.OutputList(learning_rates), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes numerical negative value element-wise. -// -// I.e., \\(y = -x\\). -func Neg(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Neg", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Receives a tensor value broadcast from another device. -func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"T": T, "group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} - opspec := tf.OpSpec{ - Type: "CollectiveBcastRecv", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Decode web-safe base64-encoded strings. -// -// Input may or may not have padding at the end. See EncodeBase64 for padding. -// Web-safe means that input must use - and _ instead of + and /. -// -// Arguments: -// input: Base64 strings to decode. -// -// Returns Decoded strings. -func DecodeBase64(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DecodeBase64", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// SubstrAttr is an optional argument to Substr. -type SubstrAttr func(optionalAttr) - -// SubstrUnit sets the optional unit attribute to value. -// -// value: The unit that is used to create the substring. One of: `"BYTE"` (for -// defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8 -// encoded Unicode code points). The default is `"BYTE"`. Results are undefined if -// `unit=UTF8_CHAR` and the `input` strings do not contain structurally valid -// UTF-8. -// If not specified, defaults to "BYTE" -func SubstrUnit(value string) SubstrAttr { - return func(m optionalAttr) { - m["unit"] = value - } -} - -// Return substrings from `Tensor` of strings. -// -// For each string in the input `Tensor`, creates a substring starting at index -// `pos` with a total length of `len`. -// -// If `len` defines a substring that would extend beyond the length of the input -// string, then as many characters as possible are used. -// -// A negative `pos` indicates distance within the string backwards from the end. 
-// -// If `pos` specifies an index which is out of range for any of the input strings, -// then an `InvalidArgumentError` is thrown. -// -// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on -// Op creation. -// -// *NOTE*: `Substr` supports broadcasting up to two dimensions. More about -// broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -// -// --- -// -// Examples -// -// Using scalar `pos` and `len`: -// -// ```python -// input = [b'Hello', b'World'] -// position = 1 -// length = 3 -// -// output = [b'ell', b'orl'] -// ``` -// -// Using `pos` and `len` with same shape as `input`: -// -// ```python -// input = [[b'ten', b'eleven', b'twelve'], -// [b'thirteen', b'fourteen', b'fifteen'], -// [b'sixteen', b'seventeen', b'eighteen']] -// position = [[1, 2, 3], -// [1, 2, 3], -// [1, 2, 3]] -// length = [[2, 3, 4], -// [4, 3, 2], -// [5, 5, 5]] -// -// output = [[b'en', b'eve', b'lve'], -// [b'hirt', b'urt', b'te'], -// [b'ixtee', b'vente', b'hteen']] -// ``` -// -// Broadcasting `pos` and `len` onto `input`: -// -// ``` -// input = [[b'ten', b'eleven', b'twelve'], -// [b'thirteen', b'fourteen', b'fifteen'], -// [b'sixteen', b'seventeen', b'eighteen'], -// [b'nineteen', b'twenty', b'twentyone']] -// position = [1, 2, 3] -// length = [1, 2, 3] -// -// output = [[b'e', b'ev', b'lve'], -// [b'h', b'ur', b'tee'], -// [b'i', b've', b'hte'], -// [b'i', b'en', b'nty']] -// ``` -// -// Broadcasting `input` onto `pos` and `len`: -// -// ``` -// input = b'thirteen' -// position = [1, 5, 7] -// length = [3, 2, 1] -// -// output = [b'hir', b'ee', b'n'] -// ``` -// -// Arguments: -// input: Tensor of strings -// pos: Scalar defining the position of first character in each substring -// len: Scalar defining the number of characters to include in each substring -// -// Returns Tensor of substrings -func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output, optional ...SubstrAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Substr", - Input: []tf.Input{ - input, pos, len, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Exits the current frame to its parent frame. -// -// Exit makes its input `data` available to the parent frame. -// -// Arguments: -// data: The tensor to be made available to the parent frame. -// -// Returns The same tensor as `data`. -func Exit(scope *Scope, data tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Exit", - Input: []tf.Input{ - data, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingProximalAdagradParametersAttr is an optional argument to RetrieveTPUEmbeddingProximalAdagradParameters. -type RetrieveTPUEmbeddingProximalAdagradParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingProximalAdagradParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingProximalAdagradParametersTableId(value int64) RetrieveTPUEmbeddingProximalAdagradParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingProximalAdagradParametersTableName sets the optional table_name attribute to value. 
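Editorial note: a minimal sketch mirroring the scalar pos/len example from the Substr documentation above; package and function names are assumptions.

package substrexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// middleThree returns the length-3 substring starting at byte 1 of each
// input element, so ["Hello" "World"] becomes ["ell" "orl"].
func middleThree(s *op.Scope) tf.Output {
	input := op.Const(s, []string{"Hello", "World"})
	pos := op.Const(s, int32(1))
	length := op.Const(s, int32(3))
	return op.Substr(s, input, pos, length, op.SubstrUnit("BYTE"))
}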
-// If not specified, defaults to "" -func RetrieveTPUEmbeddingProximalAdagradParametersTableName(value string) RetrieveTPUEmbeddingProximalAdagradParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve proximal Adagrad embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns Parameter parameters updated by the proximal Adagrad optimization algorithm.Parameter accumulators updated by the proximal Adagrad optimization algorithm. -func RetrieveTPUEmbeddingProximalAdagradParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingProximalAdagradParametersAttr) (parameters tf.Output, accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingProximalAdagradParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Produce a string tensor that encodes the state of a Reader. -// -// Not all Readers support being serialized, so this can produce an -// Unimplemented error. -// -// Arguments: -// reader_handle: Handle to a Reader. -func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderSerializeStateV2", - Input: []tf.Input{ - reader_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the number of tensors in the input tensor list. -// -// input_handle: the input list -// length: the number of tensors in the list -func TensorListLength(scope *Scope, input_handle tf.Output) (length tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorListLength", - Input: []tf.Input{ - input_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset with a range of values. Corresponds to python's xrange. -// -// Arguments: -// start: corresponds to start in python's xrange(). -// stop: corresponds to stop in python's xrange(). -// step: corresponds to step in python's xrange(). -// -// -func RangeDataset(scope *Scope, start tf.Output, stop tf.Output, step tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "RangeDataset", - Input: []tf.Input{ - start, stop, step, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes inverse hyperbolic sine of x element-wise. -func Asinh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Asinh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// UnicodeTranscodeAttr is an optional argument to UnicodeTranscode. -type UnicodeTranscodeAttr func(optionalAttr) - -// UnicodeTranscodeErrors sets the optional errors attribute to value. -// -// value: Error handling policy when there is invalid formatting found in the input. 
-// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. -// If not specified, defaults to "replace" -func UnicodeTranscodeErrors(value string) UnicodeTranscodeAttr { - return func(m optionalAttr) { - m["errors"] = value - } -} - -// UnicodeTranscodeReplacementChar sets the optional replacement_char attribute to value. -// -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD or U+65533.) -// -// Note that for UTF-8, passing a replacement character expressible in 1 byte, such -// as ' ', will preserve string alignment to the source since invalid bytes will be -// replaced with a 1-byte replacement. For UTF-16-BE and UTF-16-LE, any 1 or 2 byte -// replacement character will preserve byte alignment to the source. -// If not specified, defaults to 65533 -func UnicodeTranscodeReplacementChar(value int64) UnicodeTranscodeAttr { - return func(m optionalAttr) { - m["replacement_char"] = value - } -} - -// UnicodeTranscodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. -// -// value: Whether to replace the C0 control characters (00-1F) with the -// `replacement_char`. Default is false. -// If not specified, defaults to false -func UnicodeTranscodeReplaceControlCharacters(value bool) UnicodeTranscodeAttr { - return func(m optionalAttr) { - m["replace_control_characters"] = value - } -} - -// Transcode the input text from a source encoding to a destination encoding. -// -// The input is a string tensor of any shape. The output is a string tensor of -// the same shape containing the transcoded strings. Output strings are always -// valid unicode. If the input contains invalid encoding positions, the -// `errors` attribute sets the policy for how to deal with them. If the default -// error-handling policy is used, invalid formatting will be substituted in the -// output by the `replacement_char`. If the errors policy is to `ignore`, any -// invalid encoding positions in the input are skipped and not included in the -// output. If it set to `strict` then any invalid formatting will result in an -// InvalidArgument error. -// -// This operation can be used with `output_encoding = input_encoding` to enforce -// correct formatting for inputs even if they are already in the desired encoding. -// -// If the input is prefixed by a Byte Order Mark needed to determine encoding -// (e.g. if the encoding is UTF-16 and the BOM indicates big-endian), then that -// BOM will be consumed and not emitted into the output. If the input encoding -// is marked with an explicit endianness (e.g. UTF-16-BE), then the BOM is -// interpreted as a non-breaking-space and is preserved in the output (including -// always for UTF-8). -// -// The end result is that if the input is marked as an explicit endianness the -// transcoding is faithful to all codepoints in the source. 
If it is not marked -// with an explicit endianness, the BOM is not considered part of the string itself -// but as metadata, and so is not preserved in the output. -// -// Arguments: -// input: The text to be processed. Can have any shape. -// input_encoding: Text encoding of the input strings. This is any of the encodings supported -// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. -// output_encoding: The unicode encoding to use in the output. Must be one of -// `"UTF-8", "UTF-16-BE", "UTF-32-BE"`. Multi-byte encodings will be big-endian. -// -// Returns A string tensor containing unicode text encoded using `output_encoding`. -func UnicodeTranscode(scope *Scope, input tf.Output, input_encoding string, output_encoding string, optional ...UnicodeTranscodeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"input_encoding": input_encoding, "output_encoding": output_encoding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UnicodeTranscode", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// UnicodeDecodeAttr is an optional argument to UnicodeDecode. -type UnicodeDecodeAttr func(optionalAttr) - -// UnicodeDecodeErrors sets the optional errors attribute to value. -// -// value: Error handling policy when there is invalid formatting found in the input. -// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. -// If not specified, defaults to "replace" -func UnicodeDecodeErrors(value string) UnicodeDecodeAttr { - return func(m optionalAttr) { - m["errors"] = value - } -} - -// UnicodeDecodeReplacementChar sets the optional replacement_char attribute to value. -// -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD or U+65533.) -// If not specified, defaults to 65533 -func UnicodeDecodeReplacementChar(value int64) UnicodeDecodeAttr { - return func(m optionalAttr) { - m["replacement_char"] = value - } -} - -// UnicodeDecodeReplaceControlCharacters sets the optional replace_control_characters attribute to value. -// -// value: Whether to replace the C0 control characters (00-1F) with the -// `replacement_char`. Default is false. -// If not specified, defaults to false -func UnicodeDecodeReplaceControlCharacters(value bool) UnicodeDecodeAttr { - return func(m optionalAttr) { - m["replace_control_characters"] = value - } -} - -// Decodes each string in `input` into a sequence of Unicode code points. -// -// The character codepoints for all strings are returned using a single vector -// `char_values`, with strings expanded to characters in row-major order. -// -// The `row_splits` tensor indicates where the codepoints for -// each input string begin and end within the `char_values` tensor. -// In particular, the values for the `i`th -// string (in row-major order) are stored in the slice -// `[row_splits[i]:row_splits[i+1]]`. 
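Editorial note: a hedged sketch of the UnicodeTranscode wrapper defined above, converting UTF-16-BE input to UTF-8 with the replacement-character policy; the encodings and attribute values are example choices.

package transcodeexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// toUTF8 re-encodes UTF-16-BE strings as UTF-8, substituting U+FFFD for any
// invalid byte sequences, per the errors/replacement_char attributes above.
func toUTF8(s *op.Scope, utf16be tf.Output) tf.Output {
	return op.UnicodeTranscode(s, utf16be, "UTF-16-BE", "UTF-8",
		op.UnicodeTranscodeErrors("replace"),
		op.UnicodeTranscodeReplacementChar(0xFFFD))
}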
Thus: -// -// * `char_values[row_splits[i]+j]` is the Unicode codepoint for the `j`th -// character in the `i`th string (in row-major order). -// * `row_splits[i+1] - row_splits[i]` is the number of characters in the `i`th -// string (in row-major order). -// -// Arguments: -// input: The text to be decoded. Can have any shape. Note that the output is flattened -// to a vector of char values. -// input_encoding: Text encoding of the input strings. This is any of the encodings supported -// by ICU ucnv algorithmic converters. Examples: `"UTF-16", "US ASCII", "UTF-8"`. -// -// Returns A 1D int32 tensor containing the row splits.A 1D int32 Tensor containing the decoded codepoints. -func UnicodeDecode(scope *Scope, input tf.Output, input_encoding string, optional ...UnicodeDecodeAttr) (row_splits tf.Output, char_values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"input_encoding": input_encoding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UnicodeDecode", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Adds up a SparseTensor and a dense Tensor, using these special rules: -// -// (1) Broadcasts the dense side to have the same shape as the sparse side, if -// eligible; -// (2) Then, only the dense values pointed to by the indices of the SparseTensor -// participate in the cwise addition. -// -// By these rules, the result is a logical SparseTensor with exactly the same -// indices and shape, but possibly with different non-zero values. The output of -// this Op is the resultant non-zero values. -// -// Arguments: -// sp_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// sp_values: 1-D. `N` non-empty values corresponding to `sp_indices`. -// sp_shape: 1-D. Shape of the input SparseTensor. -// dense: `R`-D. The dense Tensor operand. -// -// Returns 1-D. The `N` values that are operated on. -func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output, sp_shape tf.Output, dense tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseDenseCwiseAdd", - Input: []tf.Input{ - sp_indices, sp_values, sp_shape, dense, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. -type ResourceApplyRMSPropAttr func(optionalAttr) - -// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the RMSProp algorithm. -// -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms -// and mom will not update in iterations during which the grad is zero. 
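Editorial note: a minimal sketch of the UnicodeDecode wrapper defined above; it simply surfaces the row_splits/char_values pair described in the documentation.

package decodeexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// codepoints splits each string in text into Unicode code points. charValues
// is the flat vector of code points and rowSplits marks where each string's
// code points begin and end within it.
func codepoints(s *op.Scope, text tf.Output) (rowSplits, charValues tf.Output) {
	return op.UnicodeDecode(s, text, "UTF-8")
}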
-// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) -// -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom -// -// Arguments: -// var_: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyRMSProp", - Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. -type StatelessTruncatedNormalAttr func(optionalAttr) - -// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom values from a truncated normal distribution. -// -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// -// Returns Random values with specified shape. -func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatelessTruncatedNormal", - Input: []tf.Input{ - shape, seed, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RestoreSliceAttr is an optional argument to RestoreSlice. -type RestoreSliceAttr func(optionalAttr) - -// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. -// -// value: Index of file to open first if multiple files match -// `file_pattern`. See the documentation for `Restore`. -// If not specified, defaults to -1 -func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { - return func(m optionalAttr) { - m["preferred_shard"] = value - } -} - -// Restores a tensor from checkpoint files. -// -// This is like `Restore` except that restored tensor can be listed as filling -// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the -// larger tensor and the slice that the restored tensor covers. -// -// The `shape_and_slice` input has the same format as the -// elements of the `shapes_and_slices` input of the `SaveSlices` op. -// -// Arguments: -// file_pattern: Must have a single element. 
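Editorial note: a hedged sketch of the StatelessTruncatedNormal wrapper defined above; the shape and seed values are arbitrary, and the point is only that the output is a pure function of them.

package statelessexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// truncatedNoise draws a deterministic [2, 3] float tensor: re-running the
// same graph with the same seed pair always yields the same values.
func truncatedNoise(s *op.Scope) tf.Output {
	shape := op.Const(s, []int32{2, 3})
	seed := op.Const(s, []int64{42, 7})
	return op.StatelessTruncatedNormal(s, shape, seed)
}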
The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// shape_and_slice: Scalar. The shapes and slice specifications to use when -// restoring a tensors. -// dt: The type of the tensor to be restored. -// -// Returns The restored tensor. -func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dt": dt} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RestoreSlice", - Input: []tf.Input{ - file_pattern, tensor_name, shape_and_slice, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Convert the quantized 'input' tensor into a lower-precision 'output', using the -// -// actual distribution of the values to maximize the usage of the lower bit depth -// and adjusting the output min and max ranges accordingly. -// -// [input_min, input_max] are scalar floats that specify the range for the float -// interpretation of the 'input' data. For example, if input_min is -1.0f and -// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 -// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. -// -// This operator tries to squeeze as much precision as possible into an output with -// a lower bit depth by calculating the actual min and max values found in the -// data. For example, maybe that quint16 input has no values lower than 16,384 and -// none higher than 49,152. That means only half the range is actually needed, all -// the float interpretations are between -0.5f and 0.5f, so if we want to compress -// the data into a quint8 output, we can use that range rather than the theoretical -// -1.0f to 1.0f that is suggested by the input min and max. -// -// In practice, this is most useful for taking output from operations like -// QuantizedMatMul that can produce higher bit-depth outputs than their inputs and -// may have large potential output ranges, but in practice have a distribution of -// input values that only uses a small fraction of the possible range. By feeding -// that output into this operator, we can reduce it from 32 bits down to 8 with -// minimal loss of accuracy. -// -// Arguments: -// -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// out_type: The type of the output. Should be a lower bit depth than Tinput. -// -// Returns The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"out_type": out_type} - opspec := tf.OpSpec{ - Type: "QuantizeDownAndShrinkRange", - Input: []tf.Input{ - input, input_min, input_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// RandomGammaAttr is an optional argument to RandomGamma. -type RandomGammaAttr func(optionalAttr) - -// RandomGammaSeed sets the optional seed attribute to value. 
-// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomGammaSeed(value int64) RandomGammaAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomGammaSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomGammaSeed2(value int64) RandomGammaAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from the Gamma distribution(s) described by alpha. -// -// This op uses the algorithm by Marsaglia et al. to acquire samples via -// transformation-rejection from pairs of uniform and normal random variables. -// See http://dl.acm.org/citation.cfm?id=358414 -// -// Arguments: -// shape: 1-D integer tensor. Shape of independent samples to draw from each -// distribution described by the shape parameters given in alpha. -// alpha: A tensor in which each scalar is a "shape" parameter describing the -// associated gamma distribution. -// -// Returns A tensor with shape `shape + shape(alpha)`. Each slice -// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for -// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha. -func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomGamma", - Input: []tf.Input{ - shape, alpha, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceScatterNdSubAttr is an optional argument to ResourceScatterNdSub. -type ResourceScatterNdSubAttr func(optionalAttr) - -// ResourceScatterNdSubUseLocking sets the optional use_locking attribute to value. -// -// value: An optional bool. Defaults to True. If True, the assignment will -// be protected by a lock; otherwise the behavior is undefined, -// but may exhibit less contention. -// If not specified, defaults to true -func ResourceScatterNdSubUseLocking(value bool) ResourceScatterNdSubAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Applies sparse subtraction to individual values or slices in a Variable. -// -// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `ref`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th -// dimension of `ref`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: -// -// ``` -// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]] -// ``` -// -// For example, say we want to subtract 4 scattered elements from a rank-1 tensor -// with 8 elements. 
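Editorial note: a minimal sketch of the RandomGamma wrapper defined above; the shape parameters and seed are example values only.

package gammaexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// gammaSamples draws 10 samples from each of two Gamma distributions with
// shape parameters 2.0 and 5.0, producing a [10, 2] tensor; fixing the seed
// attribute makes the draw reproducible.
func gammaSamples(s *op.Scope) tf.Output {
	shape := op.Const(s, []int32{10})
	alpha := op.Const(s, []float32{2.0, 5.0})
	return op.RandomGamma(s, shape, alpha, op.RandomGammaSeed(1234))
}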
In Python, that subtraction would look like this: -// -// ```python -// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// sub = tf.scatter_nd_sub(ref, indices, updates) -// with tf.Session() as sess: -// print sess.run(sub) -// ``` -// -// The resulting update to ref would look like this: -// -// [1, -9, 3, -6, -4, 6, 7, -4] -// -// See `tf.scatter_nd` for more details about how to make updates to -// slices. -// -// Arguments: -// ref: A resource handle. Must be from a VarHandleOp. -// indices: A Tensor. Must be one of the following types: int32, int64. -// A tensor of indices into ref. -// updates: A Tensor. Must have the same type as ref. A tensor of -// values to add to ref. -// -// Returns the created operation. -func ResourceScatterNdSub(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdSubAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceScatterNdSub", - Input: []tf.Input{ - ref, indices, updates, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Outputs deterministic pseudorandom random integers from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[minval, maxval)`. -// -// The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// minval: Minimum value (inclusive, scalar). -// maxval: Maximum value (exclusive, scalar). -// -// Returns Random values with specified shape. -func StatelessRandomUniformInt(scope *Scope, shape tf.Output, seed tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "StatelessRandomUniformInt", - Input: []tf.Input{ - shape, seed, minval, maxval, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QuantizedConv2DAttr is an optional argument to QuantizedConv2D. -type QuantizedConv2DAttr func(optionalAttr) - -// QuantizedConv2DOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// QuantizedConv2DDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes a 2D convolution given quantized 4D input and filter tensors. -// -// The inputs are quantized tensors where the lowest value represents the real -// number of the associated minimum, and the highest represents the maximum. -// This means that you can only interpret the quantized output in the same way, by -// taking the returned minimum and maximum values into account. 
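Editorial note: a hedged sketch of the StatelessRandomUniformInt wrapper defined above; the bounds and seed pair are arbitrary example values.

package uniformexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// diceRolls produces five deterministic integers in [1, 7), i.e. dice rolls
// that are a pure function of the seed pair.
func diceRolls(s *op.Scope) tf.Output {
	shape := op.Const(s, []int32{5})
	seed := op.Const(s, []int64{3, 9})
	lo := op.Const(s, int64(1))
	hi := op.Const(s, int64(7))
	return op.StatelessRandomUniformInt(s, shape, seed, lo, hi)
}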
-// -// Arguments: -// -// filter: filter's input_depth dimension must match input's depth dimensions. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// min_filter: The float value that the lowest quantized filter value represents. -// max_filter: The float value that the highest quantized filter value represents. -// strides: The stride of the sliding window for each dimension of the input -// tensor. -// padding: The type of padding algorithm to use. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedConv2D(scope *Scope, input tf.Output, filter tf.Output, min_input tf.Output, max_input tf.Output, min_filter tf.Output, max_filter tf.Output, strides []int64, padding string, optional ...QuantizedConv2DAttr) (output tf.Output, min_output tf.Output, max_output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedConv2D", - Input: []tf.Input{ - input, filter, min_input, max_input, min_filter, max_filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// ResourceGatherAttr is an optional argument to ResourceGather. -type ResourceGatherAttr func(optionalAttr) - -// ResourceGatherValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func ResourceGatherValidateIndices(value bool) ResourceGatherAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Gather slices from the variable pointed to by `resource` according to `indices`. -// -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -// Produces an output tensor with shape `indices.shape + params.shape[1:]` where: -// -// ```python -// # Scalar indices -// output[:, ..., :] = params[indices, :, ... :] -// -// # Vector indices -// output[i, :, ..., :] = params[indices[i], :, ... :] -// -// # Higher rank indices -// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] -// ``` -func ResourceGather(scope *Scope, resource tf.Output, indices tf.Output, dtype tf.DataType, optional ...ResourceGatherAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceGather", - Input: []tf.Input{ - resource, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StatelessMultinomialAttr is an optional argument to StatelessMultinomial. -type StatelessMultinomialAttr func(optionalAttr) - -// StatelessMultinomialOutputDtype sets the optional output_dtype attribute to value. -// If not specified, defaults to DT_INT64 -func StatelessMultinomialOutputDtype(value tf.DataType) StatelessMultinomialAttr { - return func(m optionalAttr) { - m["output_dtype"] = value - } -} - -// Draws samples from a multinomial distribution. -// -// Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. -// seed: 2 seeds (shape [2]). 
-// -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func StatelessMultinomial(scope *Scope, logits tf.Output, num_samples tf.Output, seed tf.Output, optional ...StatelessMultinomialAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatelessMultinomial", - Input: []tf.Input{ - logits, num_samples, seed, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns a batched matrix tensor with new batched diagonal values. -// -// Given `input` and `diagonal`, this operation returns a tensor with the -// same shape and values as `input`, except for the main diagonal of the -// innermost matrices. These will be overwritten by the values in `diagonal`. -// -// The output is computed as follows: -// -// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has -// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a -// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: -// -// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. -// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. -// -// Arguments: -// input: Rank `k+1`, where `k >= 1`. -// diagonal: Rank `k`, where `k >= 1`. -// -// Returns Rank `k+1`, with `output.shape = input.shape`. -func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixSetDiag", - Input: []tf.Input{ - input, diagonal, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the element-wise max of two SparseTensors. -// -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. -// -// Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. -// -// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. -func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSparseMaximum", - Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// List of the given size with empty elements. -// -// element_shape: the shape of the future elements of the list -// num_elements: the number of elements to reserve -// handle: the output list -// element_dtype: the desired type of elements in the list. 
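Editorial note: a minimal sketch of the MatrixSetDiag wrapper defined above; the matrix and diagonal values are arbitrary.

package diagexample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// zeroDiagonal overwrites the main diagonal of a 3x3 matrix with zeros while
// leaving the off-diagonal entries untouched, per the rules above.
func zeroDiagonal(s *op.Scope) tf.Output {
	m := op.Const(s, [][]float32{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})
	d := op.Const(s, []float32{0, 0, 0})
	return op.MatrixSetDiag(s, m, d)
}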
-func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "TensorListReserve", - Input: []tf.Input{ - element_shape, num_elements, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LoadTPUEmbeddingMDLAdagradLightParametersAttr is an optional argument to LoadTPUEmbeddingMDLAdagradLightParameters. -type LoadTPUEmbeddingMDLAdagradLightParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingMDLAdagradLightParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingMDLAdagradLightParametersTableId(value int64) LoadTPUEmbeddingMDLAdagradLightParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingMDLAdagradLightParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingMDLAdagradLightParametersTableName(value string) LoadTPUEmbeddingMDLAdagradLightParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load MDL Adagrad Light embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the MDL Adagrad Light optimization algorithm. -// accumulators: Value of accumulators used in the MDL Adagrad Light optimization algorithm. -// weights: Value of weights used in the MDL Adagrad Light optimization algorithm. -// benefits: Value of benefits used in the MDL Adagrad Light optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingMDLAdagradLightParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, weights tf.Output, benefits tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMDLAdagradLightParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingMDLAdagradLightParameters", - Input: []tf.Input{ - parameters, accumulators, weights, benefits, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes the gradient for the inverse of `x` wrt its input. -// -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InvGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reduces sparse updates into the variable referenced by `resource` using the `min` operation. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] = min(ref[indices, ...], updates[...]) -// -// # Vector indices (for each i) -// ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] 
= min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions are combined. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// (inline illustration of the scatter update omitted)
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterMin(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterMin", - Input: []tf.Input{ - resource, indices, updates, - }, - } - return scope.AddOperation(opspec) -} - -// Elementwise computes the bitwise OR of `x` and `y`. -// -// The result will have those bits set, that are set in `x`, `y` or both. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BitwiseOr", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. -type MatrixSolveLsAttr func(optionalAttr) - -// MatrixSolveLsFast sets the optional fast attribute to value. -// If not specified, defaults to true -func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { - return func(m optionalAttr) { - m["fast"] = value - } -} - -// Solves one or more linear least-squares problems. -// -// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same -// type as `matrix` and shape `[..., M, K]`. -// The output is a tensor shape `[..., N, K]` where each output matrix solves -// each of the equations -// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` -// in the least squares sense. -// -// We use the following notation for (complex) matrix and right-hand sides -// in the batch: -// -// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), -// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), -// `output`=\\(X \in \mathbb{C}^{n \times k}\\), -// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). -// -// If `fast` is `True`, then the solution is computed by solving the normal -// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then -// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares -// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). -// If \\(m \lt n\\) then `output` is computed as -// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the -// minimum-norm solution to the under-determined linear system, i.e. -// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), -// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable -// when \\(A\\) is numerically full rank and has a condition number -// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is -// sufficiently large. -// -// If `fast` is `False` an algorithm based on the numerically robust complete -// orthogonal decomposition is used. This computes the minimum-norm -// least-squares solution, even when \\(A\\) is rank deficient. This path is -// typically 6-7 times slower than the fast path. If `fast` is `False` then -// `l2_regularizer` is ignored. -// -// Arguments: -// matrix: Shape is `[..., M, N]`. -// rhs: Shape is `[..., M, K]`. -// l2_regularizer: Scalar tensor. 
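To make the least-squares discussion above concrete, here is a hedged sketch that solves one small over-determined system with `MatrixSolveLs`. Everything is kept in double precision so `matrix`, `rhs`, and `l2_regularizer` share a dtype; the numbers are illustrative only:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Over-determined system: 3 equations, 2 unknowns.
	matrix := op.Const(s, [][]float64{{1, 0}, {0, 1}, {1, 1}})
	rhs := op.Const(s, [][]float64{{1}, {2}, {3}})
	l2 := op.Const(s, float64(0)) // no regularization
	x := op.MatrixSolveLs(s, matrix, rhs, l2)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	res, err := sess.Run(nil, []tf.Output{x}, nil)
	if err != nil {
		panic(err)
	}
	// For this system the least-squares solution is exact: approximately [[1] [2]]
	fmt.Println(res[0].Value())
}
```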
-// -// @compatibility(numpy) -// Equivalent to np.linalg.lstsq -// @end_compatibility -// -// Returns Shape is `[..., N, K]`. -func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixSolveLs", - Input: []tf.Input{ - matrix, rhs, l2_regularizer, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Interleave the values from the `data` tensors into a single tensor. -// -// Builds a merged tensor such that -// -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] -// ``` -// -// For example, if each `indices[m]` is scalar or vector, we have -// -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] -// -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] -// ``` -// -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is -// -// merged.shape = [max(indices)] + constant -// -// Values are merged in order, so if an index appears in both `indices[m][i]` and -// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the -// merged result. If you do not need this guarantee, ParallelDynamicStitch might -// perform better on some devices. -// -// For example: -// -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] -// ``` -// -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: -// -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. -// ``` -// -//
-// (inline DynamicStitch illustration omitted)
-func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DynamicStitch", - Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Performs a padding as a preprocess during a convolution. -// -// Similar to FusedResizeAndPadConv2d, this op allows for an optimized -// implementation where the spatial padding transformation stage is fused with the -// im2col lookup, but in this case without the bilinear filtering required for -// resizing. Fusing the padding prevents the need to write out the intermediate -// results as whole tensors, reducing memory pressure, and we can get some latency -// gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' -// order is used instead. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. -func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "FusedPadConv2D", - Input: []tf.Input{ - input, paddings, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. -type Conv2DBackpropInputAttr func(optionalAttr) - -// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} - -// Conv2DBackpropInputExplicitPaddings sets the optional explicit_paddings attribute to value. -// -// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith -// dimension, the amount of padding inserted before and after the dimension is -// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If -// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. -// If not specified, defaults to <> -func Conv2DBackpropInputExplicitPaddings(value []int64) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["explicit_paddings"] = value - } -} - -// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. 
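The interleaving rule in the `DynamicStitch` comment above is easiest to see by rerunning its own worked example from Go (standard `tensorflow/go` session boilerplate, values copied from the comment):

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	indices := []tf.Output{
		op.Const(s, int32(6)),
		op.Const(s, []int32{4, 1}),
		op.Const(s, [][]int32{{5, 2}, {0, 3}}),
	}
	data := []tf.Output{
		op.Const(s, []int32{61, 62}),
		op.Const(s, [][]int32{{41, 42}, {11, 12}}),
		op.Const(s, [][][]int32{{{51, 52}, {21, 22}}, {{1, 2}, {31, 32}}}),
	}
	merged := op.DynamicStitch(s, indices, data)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	res, err := sess.Run(nil, []tf.Output{merged}, nil)
	if err != nil {
		panic(err)
	}
	// [[1 2] [11 12] [21 22] [31 32] [41 42] [51 52] [61 62]]
	fmt.Println(res[0].Value())
}
```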
-// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Conv2DBackpropInputDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of convolution with respect to the input. -// -// Arguments: -// input_sizes: An integer vector representing the shape of `input`, -// where `input` is a 4-D `[batch, height, width, channels]` tensor. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. Must be in the same order as the dimension specified with -// format. -// padding: The type of padding algorithm to use. -// -// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient -// w.r.t. the input of the convolution. -func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Conv2DBackpropInput", - Input: []tf.Input{ - input_sizes, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that executes a SQL query and emits rows of the result set. -// -// Arguments: -// driver_name: The database type. Currently, the only supported type is 'sqlite'. -// data_source_name: A connection string to connect to the database. -// query: A SQL query to execute. -// -// -func ExperimentalSqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalSqlDataset", - Input: []tf.Input{ - driver_name, data_source_name, query, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LoadTPUEmbeddingCenteredRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingCenteredRMSPropParameters. -type LoadTPUEmbeddingCenteredRMSPropParametersAttr func(optionalAttr) - -// LoadTPUEmbeddingCenteredRMSPropParametersTableId sets the optional table_id attribute to value. 
-// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingCenteredRMSPropParametersTableId(value int64) LoadTPUEmbeddingCenteredRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingCenteredRMSPropParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingCenteredRMSPropParametersTableName(value string) LoadTPUEmbeddingCenteredRMSPropParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load centered RMSProp embedding parameters. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the centered RMSProp optimization algorithm. -// ms: Value of ms used in the centered RMSProp optimization algorithm. -// mom: Value of mom used in the centered RMSProp optimization algorithm. -// mg: Value of mg used in the centered RMSProp optimization algorithm. -// -// -// -// Returns the created operation. -func LoadTPUEmbeddingCenteredRMSPropParameters(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, mg tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingCenteredRMSPropParametersAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingCenteredRMSPropParameters", - Input: []tf.Input{ - parameters, ms, mom, mg, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute. -type DataFormatVecPermuteAttr func(optionalAttr) - -// DataFormatVecPermuteSrcFormat sets the optional src_format attribute to value. -// -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatVecPermuteSrcFormat(value string) DataFormatVecPermuteAttr { - return func(m optionalAttr) { - m["src_format"] = value - } -} - -// DataFormatVecPermuteDstFormat sets the optional dst_format attribute to value. -// -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatVecPermuteDstFormat(value string) DataFormatVecPermuteAttr { - return func(m optionalAttr) { - m["dst_format"] = value - } -} - -// Returns the permuted vector/tensor in the destination data format given the -// -// one in the source data format. -// -// Arguments: -// x: Vector of size 4 or Tensor of shape (4, 2) in source data format. -// -// Returns Vector of size 4 or Tensor of shape (4, 2) in destination data format. -func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPermuteAttr) (y tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DataFormatVecPermute", - Input: []tf.Input{ - x, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns x / y element-wise. -// -// *NOTE*: `Div` supports broadcasting. 
More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Div(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Div", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeAreaAttr is an optional argument to ResizeArea. -type ResizeAreaAttr func(optionalAttr) - -// ResizeAreaAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeAreaAlignCorners(value bool) ResizeAreaAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Resize `images` to `size` using area interpolation. -// -// Input images can be of different types but output images are always float. -// -// The range of pixel values for the output image might be slightly different -// from the range for the input image because of limited numerical precision. -// To guarantee an output range, for example `[0.0, 1.0]`, apply -// `tf.clip_by_value` to the output. -// -// Each output pixel is computed by first transforming the pixel's footprint into -// the input tensor and then averaging the pixels that intersect the footprint. An -// input pixel's contribution to the average is weighted by the fraction of its -// area that intersects the footprint. This is the same as OpenCV's INTER_AREA. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeArea", - Input: []tf.Input{ - images, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Sends `input` to all devices that are connected to the output. -// -// Sends `input` to all devices that are connected to the output. -// -// The graph should be constructed so that all ops connected to the output have a -// valid device assignment, and the op itself is assigned one of these devices. -// -// input: The input to the broadcast. -// output: The same as input. -// shape: The shape of the input tensor. -// -func NcclBroadcast(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shape": shape} - opspec := tf.OpSpec{ - Type: "NcclBroadcast", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient of morphological 2-D dilation with respect to the filter. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. -// strides: 1-D of length 4. The stride of the sliding window for each dimension of -// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: 1-D of length 4. 
The input stride for atrous morphological dilation. -// Must be: `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. -// -// Returns 3-D with shape `[filter_height, filter_width, depth]`. -func Dilation2DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (filter_backprop tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} - opspec := tf.OpSpec{ - Type: "Dilation2DBackpropFilter", - Input: []tf.Input{ - input, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AddSparseToTensorsMapAttr is an optional argument to AddSparseToTensorsMap. -type AddSparseToTensorsMapAttr func(optionalAttr) - -// AddSparseToTensorsMapContainer sets the optional container attribute to value. -// -// value: The container name for the `SparseTensorsMap` created by this op. -// If not specified, defaults to "" -func AddSparseToTensorsMapContainer(value string) AddSparseToTensorsMapAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// AddSparseToTensorsMapSharedName sets the optional shared_name attribute to value. -// -// value: The shared name for the `SparseTensorsMap` created by this op. -// If blank, the new Operation's unique name is used. -// If not specified, defaults to "" -func AddSparseToTensorsMapSharedName(value string) AddSparseToTensorsMapAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Add a `SparseTensor` to a `SparseTensorsMap` return its handle. -// -// A `SparseTensor` is represented by three tensors: `sparse_indices`, -// `sparse_values`, and `sparse_shape`. -// -// This operator takes the given `SparseTensor` and adds it to a container -// object (a `SparseTensorsMap`). A unique key within this container is generated -// in the form of an `int64`, and this is the value that is returned. -// -// The `SparseTensor` can then be read out as part of a minibatch by passing -// the key as a vector element to `TakeManySparseFromTensorsMap`. To ensure -// the correct `SparseTensorsMap` is accessed, ensure that the same -// `container` and `shared_name` are passed to that Op. If no `shared_name` -// is provided here, instead use the *name* of the Operation created by calling -// `AddSparseToTensorsMap` as the `shared_name` passed to -// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. -// -// Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. -// -// Returns 0-D. The handle of the `SparseTensor` now stored in the -// `SparseTensorsMap`. -func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddSparseToTensorsMapAttr) (sparse_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AddSparseToTensorsMap", - Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns a list list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. -// -// tensor: The tensor to put on the list. 
-// input_handle: The old list. -// output_handle: A list with the elements of the old list followed by tensor. -// element_dtype: the type of elements in the list. -// element_shape: a shape compatible with that of elements in the list. -func TensorListPushBack(scope *Scope, input_handle tf.Output, tensor tf.Output) (output_handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorListPushBack", - Input: []tf.Input{ - input_handle, tensor, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CudnnRNNCanonicalToParamsAttr is an optional argument to CudnnRNNCanonicalToParams. -type CudnnRNNCanonicalToParamsAttr func(optionalAttr) - -// CudnnRNNCanonicalToParamsRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNCanonicalToParamsRnnMode(value string) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} - -// CudnnRNNCanonicalToParamsInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNCanonicalToParamsInputMode(value string) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} - -// CudnnRNNCanonicalToParamsDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNCanonicalToParamsDirection(value string) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["direction"] = value - } -} - -// CudnnRNNCanonicalToParamsDropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNCanonicalToParamsDropout(value float32) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["dropout"] = value - } -} - -// CudnnRNNCanonicalToParamsSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNCanonicalToParamsSeed(value int64) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// CudnnRNNCanonicalToParamsSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNCanonicalToParamsSeed2(value int64) CudnnRNNCanonicalToParamsAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Converts CudnnRNN params from canonical form to usable form. -// -// Writes a set of weights into the opaque params buffer so they can be used in -// upcoming training or inferences. -// -// Note that the params buffer may not be compatible across different GPUs. So any -// save and restoration should be converted to and from the canonical weights and -// biases. -// -// num_layers: Specifies the number of layers in the RNN model. -// num_units: Specifies the size of the hidden state. -// input_size: Specifies the size of the input state. -// weights: the canonical form of weights that can be used for saving -// and restoration. They are more likely to be compatible across different -// generations. -// biases: the canonical form of biases that can be used for saving -// and restoration. They are more likely to be compatible across different -// generations. -// num_params: number of parameter sets for all layers. -// Each layer may contain multiple parameter sets, with each set consisting of -// a weight matrix and a bias vector. -// rnn_mode: Indicates the type of the RNN model. 
-// input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -func CudnnRNNCanonicalToParams(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, weights []tf.Output, biases []tf.Output, optional ...CudnnRNNCanonicalToParamsAttr) (params tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CudnnRNNCanonicalToParams", - Input: []tf.Input{ - num_layers, num_units, input_size, tf.OutputList(weights), tf.OutputList(biases), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset containing elements of first component of `input_dataset` having true in the last component. -func FilterByLastComponentDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "FilterByLastComponentDataset", - Input: []tf.Input{ - input_dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the absolute value of a tensor. -// -// Given a tensor `x`, this operation returns a tensor containing the absolute -// value of each element in `x`. For example, if x is an input element and y is -// an output element, this operation computes \\(y = |x|\\). -func Abs(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Abs", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. -type MaxPoolGradV2Attr func(optionalAttr) - -// MaxPoolGradV2DataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradV2DataFormat(value string) MaxPoolGradV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients w.r.t. the output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients w.r.t. the input to `max_pool`. 
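As a baseline for the simple elementwise wrappers in this file (`Abs` above, and `Erf`, `Floor`, `Inv`, `Reciprocal` further down), here is a minimal sketch of building and running one of them, assuming the standard `tensorflow/go` client API:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s, []float32{-1.5, 0, 2.25})
	y := op.Abs(s, x)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	res, err := sess.Run(nil, []tf.Output{y}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(res[0].Value()) // [1.5 0 2.25]
}
```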
-func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradV2Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradV2", - Input: []tf.Input{ - orig_input, orig_output, grad, ksize, strides, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Restore a reader to a previously saved state. -// -// Not all Readers support being restored, so this can produce an -// Unimplemented error. -// -// Arguments: -// reader_handle: Handle to a Reader. -// state: Result of a ReaderSerializeState of a Reader with type -// matching reader_handle. -// -// Returns the created operation. -func ReaderRestoreStateV2(scope *Scope, reader_handle tf.Output, state tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderRestoreStateV2", - Input: []tf.Input{ - reader_handle, state, - }, - } - return scope.AddOperation(opspec) -} - -// Inverse fast Fourier transform. -// -// Computes the inverse 1-dimensional discrete Fourier transform over the -// inner-most dimension of `input`. -// -// Arguments: -// input: A complex tensor. -// -// Returns A complex tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its inverse 1D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.ifft -// @end_compatibility -func IFFT(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IFFT", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 2D fast Fourier transform. -// -// Computes the 2-dimensional discrete Fourier transform over the inner-most -// 2 dimensions of `input`. -// -// Arguments: -// input: A complex tensor. -// -// Returns A complex tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fft2 -// @end_compatibility -func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Inverse 2D fast Fourier transform. -// -// Computes the inverse 2-dimensional discrete Fourier transform over the -// inner-most 2 dimensions of `input`. -// -// Arguments: -// input: A complex tensor. -// -// Returns A complex tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their inverse 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.ifft2 -// @end_compatibility -func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IFFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Inverse 3D real-valued fast Fourier transform. -// -// Computes the inverse 3-dimensional discrete Fourier transform of a real-valued -// signal over the inner-most 3 dimensions of `input`. 
-// -// The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`: -// The inner-most dimension contains the `fft_length / 2 + 1` unique components of -// the DFT of a real-valued signal. If `fft_length` is not provided, it is computed -// from the size of the inner-most 3 dimensions of `input`. If the FFT length used -// to compute `input` is odd, it should be provided since it cannot be inferred -// properly. -// -// Along each axis `IRFFT3D` is computed on, if `fft_length` (or -// `fft_length / 2 + 1` for the inner-most dimension) is smaller than the -// corresponding dimension of `input`, the dimension is cropped. If it is larger, -// the dimension is padded with zeros. -// -// Arguments: -// input: A complex64 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A float32 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the `fft_length` samples of their -// inverse 3D real Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.irfftn with 3 dimensions. -// @end_compatibility -func IRFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IRFFT3D", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the truth value of (x != y) element-wise. -// -// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NotEqual", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingMomentumParametersGradAccumDebug. -type LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) - -// LoadTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingMomentumParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingMomentumParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load Momentum embedding parameters with debug support. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the Momentum optimization algorithm. -// momenta: Value of momenta used in the Momentum optimization algorithm. -// gradient_accumulators: Value of gradient_accumulators used in the Momentum optimization algorithm. -// -// -// -// Returns the created operation. 
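The broadcasting note on `NotEqual` above can be exercised by comparing a scalar against a vector; a hedged sketch with made-up values:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s, []int32{1, 2, 3})
	y := op.Const(s, int32(2)) // the scalar broadcasts against the vector
	z := op.NotEqual(s, x, y)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	res, err := sess.Run(nil, []tf.Output{z}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(res[0].Value()) // [true false true]
}
```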
-func LoadTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, parameters tf.Output, momenta tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingMomentumParametersGradAccumDebugAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingMomentumParametersGradAccumDebug", - Input: []tf.Input{ - parameters, momenta, gradient_accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// StatefulStandardNormalAttr is an optional argument to StatefulStandardNormal. -type StatefulStandardNormalAttr func(optionalAttr) - -// StatefulStandardNormalDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatefulStandardNormalDtype(value tf.DataType) StatefulStandardNormalAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs random values from a normal distribution. -// -// The generated values will have mean 0 and standard deviation 1. -// -// Arguments: -// resource: The handle of the resource variable that stores the state of the RNG. -// shape: The shape of the output tensor. -// -// Returns A tensor of the specified shape filled with random normal values. -func StatefulStandardNormal(scope *Scope, resource tf.Output, shape tf.Output, optional ...StatefulStandardNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatefulStandardNormal", - Input: []tf.Input{ - resource, shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the Gauss error function of `x` element-wise. -func Erf(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Erf", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns element-wise largest integer not greater than x. -func Floor(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Floor", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the number of records this Reader has produced. -// -// This is the same as the number of ReaderRead executions that have -// succeeded. -// -// Arguments: -// reader_handle: Handle to a Reader. -func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_produced tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderNumRecordsProducedV2", - Input: []tf.Input{ - reader_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorListConcatAttr is an optional argument to TensorListConcat. -type TensorListConcatAttr func(optionalAttr) - -// TensorListConcatElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorListConcatElementShape(value tf.Shape) TensorListConcatAttr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// Concats all tensors in the list along the 0th dimension. -// -// Requires that all tensors have the same shape except the first dimension. -// -// input_handle: The input list. 
-// tensor: The concated result. -// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. -// -func TensorListConcat(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListConcatAttr) (tensor tf.Output, lengths tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorListConcat", - Input: []tf.Input{ - input_handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Conv3DAttr is an optional argument to Conv3D. -type Conv3DAttr func(optionalAttr) - -// Conv3DDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func Conv3DDataFormat(value string) Conv3DAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Conv3DDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 5. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv3DDilations(value []int64) Conv3DAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes a 3-D convolution given 5-D `input` and `filter` tensors. -// -// In signal processing, cross-correlation is a measure of similarity of -// two waveforms as a function of a time-lag applied to one of them. This -// is also known as a sliding dot product or sliding inner-product. -// -// Our Conv3D implements a form of cross-correlation. -// -// Arguments: -// input: Shape `[batch, in_depth, in_height, in_width, in_channels]`. -// filter: Shape `[filter_depth, filter_height, filter_width, in_channels, -// out_channels]`. `in_channels` must match between `input` and `filter`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv3DAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Conv3D", - Input: []tf.Input{ - input, filter, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QuantizeV2Attr is an optional argument to QuantizeV2. -type QuantizeV2Attr func(optionalAttr) - -// QuantizeV2Mode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func QuantizeV2Mode(value string) QuantizeV2Attr { - return func(m optionalAttr) { - m["mode"] = value - } -} - -// QuantizeV2RoundMode sets the optional round_mode attribute to value. 
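A hedged sketch of the `Conv3D` wrapper documented above, run on a tiny `[1, 2, 2, 2, 1]` volume with a `1x1x1x1x1` filter. The shapes and values are invented purely to show how `strides` and `padding` are passed:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Shape [batch=1, depth=2, height=2, width=2, channels=1].
	input := op.Const(s, [][][][][]float32{{
		{{{1}, {2}}, {{3}, {4}}},
		{{{5}, {6}}, {{7}, {8}}},
	}})
	// Shape [1, 1, 1, in_channels=1, out_channels=1]: scales every voxel by 2.
	filter := op.Const(s, [][][][][]float32{{{{{2}}}}})
	out := op.Conv3D(s, input, filter, []int64{1, 1, 1, 1, 1}, "SAME")

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	res, err := sess.Run(nil, []tf.Output{out}, nil)
	if err != nil {
		panic(err)
	}
	// Same shape as the input, every value doubled.
	fmt.Println(res[0].Value())
}
```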
-// If not specified, defaults to "HALF_AWAY_FROM_ZERO" -func QuantizeV2RoundMode(value string) QuantizeV2Attr { - return func(m optionalAttr) { - m["round_mode"] = value - } -} - -// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. -// -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. The -// 'round_mode' attribute controls which rounding tie-breaking algorithm is used -// when rounding float values to their quantized equivalents. -// -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: -// -// ``` -// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) -// if T == qint8: out[i] -= (range(T) + 1) / 2.0 -// ``` -// -// here `range(T) = numeric_limits::max() - numeric_limits::min()` -// -// *MIN_COMBINED Mode Example* -// -// Assume the input is type float and has a possible range of [0.0, 6.0] and the -// output type is quint8 ([0, 255]). The min_range and max_range values should be -// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each -// value of the input by 255/6 and cast to quint8. -// -// If the output type was qint8 ([-128, 127]), the operation will additionally -// subtract each value by 128 prior to casting, so that the range of values aligns -// with the range of qint8. -// -// If the mode is 'MIN_FIRST', then this approach is used: -// -// ``` -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = num_discrete_values / range -// quantized = round(input * range_scale) - round(range_min * range_scale) + -// numeric_limits::min() -// quantized = max(quantized, numeric_limits::min()) -// quantized = min(quantized, numeric_limits::max()) -// ``` -// -// The biggest difference between this and MIN_COMBINED is that the minimum range -// is rounded first, before it's subtracted from the rounded value. With -// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing -// and dequantizing will introduce a larger and larger error. -// -// *SCALED mode Example* -// -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. -// -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. -// -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` -// -// Our input tensor range is then `[-m, m]`. -// -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. 
-// If T is signed, this is -// -// ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] -// ``` -// -// Otherwise, if T is unsigned, the fixed-point range is -// -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` -// -// From this we compute our scaling factor, s: -// -// ```c++ -// s = (max_fixed - min_fixed) / (2 * m) -// ``` -// -// Now we can quantize the elements of our tensor: -// -// ```c++ -// result = round(input * s) -// ``` -// -// One thing to watch out for is that the operator may choose to adjust the -// requested minimum and maximum values slightly during the quantization process, -// so you should always use the output ports as the range for further calculations. -// For example, if the requested minimum and maximum values are close to equal, -// they will be separated by a small epsilon value to prevent ill-formed quantized -// buffers from being created. Otherwise, you can end up with buffers where all the -// quantized values map to the same float value, which causes problems for -// operations that have to perform further calculations on them. -// -// Arguments: -// -// min_range: The minimum scalar value possibly produced for the input. -// max_range: The maximum scalar value possibly produced for the input. -// -// -// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output. -func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"T": T} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizeV2", - Input: []tf.Input{ - input, min_range, max_range, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// EnterAttr is an optional argument to Enter. -type EnterAttr func(optionalAttr) - -// EnterIsConstant sets the optional is_constant attribute to value. -// -// value: If true, the output is constant within the child frame. -// If not specified, defaults to false -func EnterIsConstant(value bool) EnterAttr { - return func(m optionalAttr) { - m["is_constant"] = value - } -} - -// EnterParallelIterations sets the optional parallel_iterations attribute to value. -// -// value: The number of iterations allowed to run in parallel. -// If not specified, defaults to 10 -func EnterParallelIterations(value int64) EnterAttr { - return func(m optionalAttr) { - m["parallel_iterations"] = value - } -} - -// Creates or finds a child frame, and makes `data` available to the child frame. -// -// This op is used together with `Exit` to create loops in the graph. -// The unique `frame_name` is used by the `Executor` to identify frames. If -// `is_constant` is true, `output` is a constant in the child frame; otherwise -// it may be changed in the child frame. At most `parallel_iterations` iterations -// are run in parallel in the child frame. -// -// Arguments: -// data: The tensor to be made available to the child frame. -// frame_name: The name of the child frame. -// -// Returns The same tensor as `data`. 
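Since the MIN_COMBINED description above is easiest to follow with numbers, here is a plain-Go rendering of the quoted formula for a `quint8` output. This only walks through the arithmetic in the comment; it is not the kernel, and the final cast below is an assumption about rounding rather than a statement of what the op does:

```go
package main

import "fmt"

// quantizeMinCombined walks through the MIN_COMBINED formula quoted above for a
// quint8 output: out = (in - min_range) * range(T) / (max_range - min_range),
// followed by a cast. The real kernel's rounding and tie-breaking (see the
// round_mode attribute) may differ; this is illustration only.
func quantizeMinCombined(in, minRange, maxRange float32) uint8 {
	const rangeT = 255 // numeric_limits<quint8>::max() - numeric_limits<quint8>::min()
	return uint8((in - minRange) * rangeT / (maxRange - minRange))
}

func main() {
	// The comment's example range: inputs in [0.0, 6.0] quantized to quint8.
	for _, x := range []float32{0, 1.5, 3, 6} {
		fmt.Printf("%.1f -> %d\n", x, quantizeMinCombined(x, 0, 6))
	}
	// Prints: 0.0 -> 0, 1.5 -> 63, 3.0 -> 127, 6.0 -> 255
}
```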
-func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"frame_name": frame_name} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Enter", - Input: []tf.Input{ - data, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TryRpcAttr is an optional argument to TryRpc. -type TryRpcAttr func(optionalAttr) - -// TryRpcProtocol sets the optional protocol attribute to value. -// -// value: RPC protocol to use. Empty string means use the default protocol. -// Options include 'grpc'. -// If not specified, defaults to "" -func TryRpcProtocol(value string) TryRpcAttr { - return func(m optionalAttr) { - m["protocol"] = value - } -} - -// TryRpcFailFast sets the optional fail_fast attribute to value. -// -// value: `boolean`. If `true` (default), then failures to connect -// (i.e., the server does not immediately respond) cause an RPC failure. -// If not specified, defaults to true -func TryRpcFailFast(value bool) TryRpcAttr { - return func(m optionalAttr) { - m["fail_fast"] = value - } -} - -// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value. -// -// value: `int`. If `0` (default), then the kernel will run the RPC -// request and only time out if the RPC deadline passes or the session times out. -// If this value is greater than `0`, then the op will raise an exception if -// the RPC takes longer than `timeout_in_ms`. -// If not specified, defaults to 0 -func TryRpcTimeoutInMs(value int64) TryRpcAttr { - return func(m optionalAttr) { - m["timeout_in_ms"] = value - } -} - -// Perform batches of RPC requests. -// -// This op asynchronously performs either a single RPC request, or a batch -// of requests. RPC requests are defined by three main parameters: -// -// - `address` (the host+port or BNS address of the request) -// - `method` (the method name for the request) -// - `request` (the serialized proto string, or vector of strings, -// of the RPC request argument). -// -// For example, if you have an RPC service running on port localhost:2345, -// and its interface is configured with the following proto declaration: -// -// ``` -// service MyService { -// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { -// } -// }; -// ``` -// -// then call this op with arguments: -// -// ``` -// address = "localhost:2345" -// method = "MyService/MyMethod" -// ``` -// -// The `request` tensor is a string tensor representing serialized `MyRequestProto` -// strings; and the output string tensor `response` will have the same shape -// and contain (upon successful completion) corresponding serialized -// `MyResponseProto` strings. -// -// For example, to send a single, empty, `MyRequestProto`, call -// this op with `request = ""`. To send 5 **parallel** empty requests, -// call this op with `request = ["", "", "", "", ""]`. -// -// More generally, one can create a batch of `MyRequestProto` serialized protos -// from regular batched tensors using the `encode_proto` op, and convert -// the response `MyResponseProto` serialized protos to batched tensors -// using the `decode_proto` op. -// -// **NOTE** Working with serialized proto strings is faster than instantiating -// actual proto objects in memory, so no performance degradation is expected -// compared to writing custom kernels for this workflow. 
-// -// Unlike the standard `Rpc` op, if the connection fails or the remote worker -// returns an error status, this op does **not** reraise the exception. -// Instead, the `status_code` and `status_message` entry for the corresponding RPC -// call is set with the error returned from the RPC call. The `response` tensor -// will contain valid response values for those minibatch entries whose RPCs did -// not fail; the rest of the entries will have empty strings. -// -// Arguments: -// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `method` and `request`. -// method: `0-D` or `1-D`. The method address on the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `request`. -// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `method`. -// -// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages -// returned from the RPC calls. -func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TryRpc", - Input: []tf.Input{ - address, method, request, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Add all input tensors element wise. -// -// Arguments: -// inputs: Must all be the same size and shape. -func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AddN", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the element-wise sum of a list of tensors. -// -// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not -// wait for all of its inputs to be ready before beginning to sum. This can -// save memory if inputs are ready at different times, since minimum temporary -// storage is proportional to the output size rather than the inputs size. -// -// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. -// -// Returns a `Tensor` of same shape and type as the elements of `inputs`. -// -// Arguments: -// inputs: A list of `Tensor` objects, each with same shape and type. -// shape: Shape of elements of `inputs`. -func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shape": shape} - opspec := tf.OpSpec{ - Type: "AccumulateNV2", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ComplexAbsAttr is an optional argument to ComplexAbs. -type ComplexAbsAttr func(optionalAttr) - -// ComplexAbsTout sets the optional Tout attribute to value. 
-// If not specified, defaults to DT_FLOAT -func ComplexAbsTout(value tf.DataType) ComplexAbsAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Computes the complex absolute value of a tensor. -// -// Given a tensor `x` of complex numbers, this operation returns a tensor of type -// `float` or `double` that is the absolute value of each element in `x`. All -// elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute -// value is computed as \\( \sqrt{a^2 + b^2}\\). -func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ComplexAbs", - Input: []tf.Input{ - x, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the truth value of x AND y element-wise. -// -// *NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogicalAnd", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the reciprocal of x element-wise. -// -// I.e., \\(y = 1 / x\\). -func Inv(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Inv", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that batches input elements into a SparseTensor. -// -// Arguments: -// input_dataset: A handle to an input dataset. Must have a single component. -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// row_shape: A vector representing the dense shape of each row in the produced -// SparseTensor. The shape may be partially specified, using `-1` to indicate -// that a particular dimension should use the maximum size of all batch elements. -// -// -func ExperimentalDenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalDenseToSparseBatchDataset", - Input: []tf.Input{ - input_dataset, batch_size, row_shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the reciprocal of x element-wise. -// -// I.e., \\(y = 1 / x\\). -func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Reciprocal", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Conv3DBackpropFilterAttr is an optional argument to Conv3DBackpropFilter. -type Conv3DBackpropFilterAttr func(optionalAttr) - -// Conv3DBackpropFilterDilations sets the optional dilations attribute to value. -// If not specified, defaults to -func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of 3-D convolution with respect to the filter. 
-// -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropFilterV2 -// -// Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropFilter(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilter", - Input: []tf.Input{ - input, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes square root of x element-wise. -// -// I.e., \\(y = \sqrt{x} = x^{1/2}\\). -func Sqrt(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sqrt", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Get the value of the tensor specified by its handle. -// -// Arguments: -// handle: The handle for a tensor stored in the session state. -// dtype: The type of the output value. -// -// Returns The tensor for the given handle. -func GetSessionTensor(scope *Scope, handle tf.Output, dtype tf.DataType) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - opspec := tf.OpSpec{ - Type: "GetSessionTensor", - Input: []tf.Input{ - handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient for the sqrt of `x` wrt its input. -// -// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` -// is the corresponding input gradient. -func SqrtGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SqrtGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MatrixInverseAttr is an optional argument to MatrixInverse. -type MatrixInverseAttr func(optionalAttr) - -// MatrixInverseAdjoint sets the optional adjoint attribute to value. -// If not specified, defaults to false -func MatrixInverseAdjoint(value bool) MatrixInverseAttr { - return func(m optionalAttr) { - m["adjoint"] = value - } -} - -// Computes the inverse of one or more square invertible matrices or their -// -// adjoints (conjugate transposes). -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor of the same shape as the input -// containing the inverse for all input submatrices `[..., :, :]`. -// -// The op uses LU decomposition with partial pivoting to compute the inverses. -// -// If a matrix is not invertible there is no guarantee what the op does. It -// may detect the condition and raise an exception or it may simply return a -// garbage result. -// -// Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[..., M, M]`. 
-// -// @compatibility(numpy) -// Equivalent to np.linalg.inv -// @end_compatibility -func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixInverse", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes reciprocal of square root of x element-wise. -// -// I.e., \\(y = 1 / \sqrt{x}\\). -func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Rsqrt", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Rounds the values of a tensor to the nearest integer, element-wise. -// -// Rounds half to even. Also known as bankers rounding. If you want to round -// according to the current system rounding mode use std::cint. -func Round(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Round", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Delete the TensorArray from its resource container. -// -// This enables the user to close and release the resource in the middle -// of a step/run. -// -// Arguments: -// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). -// -// Returns the created operation. -func TensorArrayCloseV3(scope *Scope, handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArrayCloseV3", - Input: []tf.Input{ - handle, - }, - } - return scope.AddOperation(opspec) -} - -// Computes exponential of x element-wise. \\(y = e^x\\). -func Exp(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Exp", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// NthElementAttr is an optional argument to NthElement. -type NthElementAttr func(optionalAttr) - -// NthElementReverse sets the optional reverse attribute to value. -// -// value: When set to True, find the nth-largest value in the vector and vice -// versa. -// If not specified, defaults to false -func NthElementReverse(value bool) NthElementAttr { - return func(m optionalAttr) { - m["reverse"] = value - } -} - -// Finds values of the `n`-th order statistic for the last dimension. -// -// If the input is a vector (rank-1), finds the entries which is the nth-smallest -// value in the vector and outputs their values as scalar tensor. -// -// For matrices (resp. higher rank input), computes the entries which is the -// nth-smallest value in each row (resp. vector along the last dimension). Thus, -// -// values.shape = input.shape[:-1] -// -// Arguments: -// input: 1-D or higher with last dimension at least `n+1`. -// n: 0-D. Position of sorted vector to select along the last dimension (along -// each row for matrices). Valid range of n is `[0, input.shape[:-1])` -// -// Returns The `n`-th order statistic along each last dimensional slice. 
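A small end-to-end sketch of the `NthElement` wrapper that follows (the input values and the session plumbing are assumptions for illustration): with input `[5, 2, 9, 1]` and `n = 1`, the ascending order is `[1, 2, 5, 9]`, so the result is `2`.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	input := op.Const(s, []float32{5, 2, 9, 1})
	n := op.Const(s, int32(1)) // 0-based position in ascending order
	values := op.NthElement(s, input, n)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{values}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // 2 (the second-smallest element)
}
```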
-func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "NthElement", - Input: []tf.Input{ - input, n, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the maximum along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// This operator is similar to the unsorted segment sum operator found -// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). -// Instead of computing the sum over segments, it computes the maximum such that: -// -// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such -// that `segment_ids[j...] == i`. -// -// If the maximum is empty for a given segment ID `i`, it outputs the smallest -// possible value for the specific numeric type, -// `output[i] = numeric_limits::lowest()`. -// -// If the given segment ID `i` is negative, then the corresponding value is -// dropped, and will not be included in the result. -// -//
-// -// For example: -// -// ``` python -// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) -// tf.unsorted_segment_max(c, tf.constant([0, 1, 0]), num_segments=2) -// # ==> [[ 4, 3, 3, 4], -// # [5, 6, 7, 8]] -// ``` -// -// -// Arguments: -// -// segment_ids: A tensor whose shape is a prefix of `data.shape`. -// -// -// Returns Has same shape as data, except for the first `segment_ids.rank` -// dimensions, which are replaced with a single dimension which has size -// `num_segments`. -func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "UnsortedSegmentMax", - Input: []tf.Input{ - data, segment_ids, num_segments, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes softplus: `log(exp(features) + 1)`. -func Softplus(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Softplus", - Input: []tf.Input{ - features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes exponential of x - 1 element-wise. -// -// I.e., \\(y = (\exp x) - 1\\). -func Expm1(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Expm1", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes natural logarithm of x element-wise. -// -// I.e., \\(y = \log_e x\\). -func Log(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Log", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the index of a data point that should be added to the seed set. -// -// Entries in distances are assumed to be squared distances of candidate points to -// the already sampled centers in the seed set. The op constructs one Markov chain -// of the k-MC^2 algorithm and returns the index of one candidate point to be added -// as an additional cluster center. -// -// Arguments: -// distances: Vector with squared distances to the closest previously sampled cluster center -// for each candidate point. -// seed: Scalar. Seed for initializing the random number generator. -// -// Returns Scalar with the index of the sampled point. -func KMC2ChainInitialization(scope *Scope, distances tf.Output, seed tf.Output) (index tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "KMC2ChainInitialization", - Input: []tf.Input{ - distances, seed, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes hyperbolic sine of x element-wise. -func Sinh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sinh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sum along sparse segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first -// dimension, selecting a subset of dimension 0, specified by `indices`. -// -// For example: -// -// ```python -// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) -// -// # Select two rows, one segment. 
-// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) -// # => [[0 0 0 0]] -// -// # Select two rows, two segment. -// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) -// # => [[ 1 2 3 4] -// # [-1 -2 -3 -4]] -// -// # Select all rows, two segments. -// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) -// # => [[0 0 0 0] -// # [5 6 7 8]] -// -// # Which is equivalent to: -// tf.segment_sum(c, tf.constant([0, 0, 1])) -// ``` -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentSum", - Input: []tf.Input{ - data, indices, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CastAttr is an optional argument to Cast. -type CastAttr func(optionalAttr) - -// CastTruncate sets the optional Truncate attribute to value. -// If not specified, defaults to false -func CastTruncate(value bool) CastAttr { - return func(m optionalAttr) { - m["Truncate"] = value - } -} - -// Cast x of type SrcT to y of DstT. -func Cast(scope *Scope, x tf.Output, DstT tf.DataType, optional ...CastAttr) (y tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"DstT": DstT} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Cast", - Input: []tf.Input{ - x, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the log of the absolute value of `Gamma(x)` element-wise. -func Lgamma(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Lgamma", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// UnicodeEncodeAttr is an optional argument to UnicodeEncode. -type UnicodeEncodeAttr func(optionalAttr) - -// UnicodeEncodeErrors sets the optional errors attribute to value. -// -// value: Error handling policy when there is invalid formatting found in the input. -// The value of 'strict' will cause the operation to produce a InvalidArgument -// error on any invalid input formatting. A value of 'replace' (the default) will -// cause the operation to replace any invalid formatting in the input with the -// `replacement_char` codepoint. A value of 'ignore' will cause the operation to -// skip any invalid formatting in the input and produce no corresponding output -// character. -// If not specified, defaults to "replace" -func UnicodeEncodeErrors(value string) UnicodeEncodeAttr { - return func(m optionalAttr) { - m["errors"] = value - } -} - -// UnicodeEncodeReplacementChar sets the optional replacement_char attribute to value. -// -// value: The replacement character codepoint to be used in place of any invalid -// formatting in the input when `errors='replace'`. Any valid unicode codepoint may -// be used. The default value is the default unicode replacement character is -// 0xFFFD (U+65533). -// If not specified, defaults to 65533 -func UnicodeEncodeReplacementChar(value int64) UnicodeEncodeAttr { - return func(m optionalAttr) { - m["replacement_char"] = value - } -} - -// Encode a tensor of ints into unicode strings. 
-// -// Returns a vector of strings, where `output[i]` is constructed by encoding the -// Unicode codepoints in `input_values[input_splits[i]:input_splits[i+1]]` -// using `output_encoding`. -// -// --- -// -// Example: -// -// ``` -// input_values = [72, 101, 108, 108, 111, 87, 111, 114, 108, 100] -// input_splits = [0, 5, 10] -// output_encoding = 'UTF-8' -// -// output = ['Hello', 'World'] -// ``` -// -// Arguments: -// input_values: A 1D tensor containing the unicode codepoints that should be encoded. -// input_splits: A 1D tensor specifying how the unicode codepoints should be split into strings. -// In particular, `output[i]` is constructed by encoding the codepoints in the -// slice `input_values[input_splits[i]:input_splits[i+1]]`. -// output_encoding: Unicode encoding of the output strings. Valid encodings are: `"UTF-8", -// "UTF-16-BE", and "UTF-32-BE"`. -// -// Returns The 1-D Tensor of strings encoded from the provided unicode codepoints. -func UnicodeEncode(scope *Scope, input_values tf.Output, input_splits tf.Output, output_encoding string, optional ...UnicodeEncodeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_encoding": output_encoding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UnicodeEncode", - Input: []tf.Input{ - input_values, input_splits, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the complementary error function of `x` element-wise. -func Erfc(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Erfc", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes sigmoid of `x` element-wise. -// -// Specifically, `y = 1 / (1 + exp(-x))`. -func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sigmoid", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes sin of x element-wise. -func Sin(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sin", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FusedBatchNormGradAttr is an optional argument to FusedBatchNormGrad. -type FusedBatchNormGradAttr func(optionalAttr) - -// FusedBatchNormGradEpsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormGradEpsilon(value float32) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format for y_backprop, x, x_backprop. -// Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormGradDataFormat(value string) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormGradIsTraining sets the optional is_training attribute to value. -// -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormGradIsTraining(value bool) FusedBatchNormGradAttr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Gradient for batch normalization. 
-// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. -// -// Arguments: -// y_backprop: A 4D Tensor for the gradient with respect to y. -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch -// mean to be reused in gradient computation. When is_training is -// False, a 1D Tensor for the population mean to be reused in both -// 1st and 2nd order gradient computation. -// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch -// variance (inverted variance in the cuDNN case) to be reused in -// gradient computation. When is_training is False, a 1D Tensor -// for the population variance to be reused in both 1st and 2nd -// order gradient computation. -// -// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input -// in FusedBatchNorm. -func FusedBatchNormGrad(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradAttr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedBatchNormGrad", - Input: []tf.Input{ - y_backprop, x, scale, reserve_space_1, reserve_space_2, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// Computes cos of x element-wise. -func Cos(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cos", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the determinant of one or more square matrices. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor containing the determinants -// for all input submatrices `[..., :, :]`. -// -// Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[...]`. -func MatrixDeterminant(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixDeterminant", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Updates the tree ensemble by either adding a layer to the last tree being grown -// -// or by starting a new tree. -// -// Arguments: -// tree_ensemble_handle: Handle to the ensemble variable. -// feature_ids: Rank 1 tensor with ids for each feature. This is the real id of -// the feature that will be used in the split. -// node_ids: List of rank 1 tensors representing the nodes for which this feature -// has a split. -// gains: List of rank 1 tensors representing the gains for each of the feature's -// split. -// thresholds: List of rank 1 tensors representing the thesholds for each of the -// feature's split. -// left_node_contribs: List of rank 2 tensors with left leaf contribs for each of -// the feature's splits. 
Will be added to the previous node values to constitute -// the values of the left nodes. -// right_node_contribs: List of rank 2 tensors with right leaf contribs for each -// of the feature's splits. Will be added to the previous node values to constitute -// the values of the right nodes. -// max_depth: Max depth of the tree to build. -// learning_rate: shrinkage const for each new tree. -// pruning_mode: 0-No pruning, 1-Pre-pruning, 2-Post-pruning. -// -// Returns the created operation. -func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, feature_ids tf.Output, node_ids []tf.Output, gains []tf.Output, thresholds []tf.Output, left_node_contribs []tf.Output, right_node_contribs []tf.Output, max_depth tf.Output, learning_rate tf.Output, pruning_mode int64) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"pruning_mode": pruning_mode} - opspec := tf.OpSpec{ - Type: "BoostedTreesUpdateEnsemble", - Input: []tf.Input{ - tree_ensemble_handle, feature_ids, tf.OutputList(node_ids), tf.OutputList(gains), tf.OutputList(thresholds), tf.OutputList(left_node_contribs), tf.OutputList(right_node_contribs), max_depth, learning_rate, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes tan of x element-wise. -func Tan(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Tan", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that emits each dim-0 slice of `components` once. -func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "TensorSliceDataset", - Input: []tf.Input{ - tf.OutputList(components), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes acos of x element-wise. -func Acos(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Acos", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the Bessel i0e function of `x` element-wise. -// -// Exponentially scaled modified Bessel function of order 0 defined as -// `bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. -// -// This function is faster and numerically stabler than `bessel_i0(x)`. -func BesselI0e(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BesselI0e", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Shuffle dimensions of x according to a permutation. -// -// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: -// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` -func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Transpose", - Input: []tf.Input{ - x, perm, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MinAttr is an optional argument to Min. -type MinAttr func(optionalAttr) - -// MinKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. 
-// If not specified, defaults to false -func MinKeepDims(value bool) MinAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the minimum of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Min", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the Bessel i1e function of `x` element-wise. -// -// Exponentially scaled modified Bessel function of order 0 defined as -// `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. -// -// This function is faster and numerically stabler than `bessel_i1(x)`. -func BesselI1e(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BesselI1e", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns an element-wise indication of the sign of a number. -// -// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`. -// -// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. -func Sign(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sign", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that passes a sliding window over `input_dataset`. -// -// Arguments: -// -// window_size: A scalar representing the number of elements in the -// sliding window. -// window_shift: A scalar representing the steps moving the sliding window -// forward in one iteration. It must be positive. -// window_stride: A scalar representing the stride of the input elements of the sliding window. -// It must be positive. -// -// -func ExperimentalSlidingWindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, window_shift tf.Output, window_stride tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalSlidingWindowDataset", - Input: []tf.Input{ - input_dataset, window_size, window_shift, window_stride, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// OrderedMapUnstageNoKeyAttr is an optional argument to OrderedMapUnstageNoKey. -type OrderedMapUnstageNoKeyAttr func(optionalAttr) - -// OrderedMapUnstageNoKeyCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapUnstageNoKeyCapacity(value int64) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapUnstageNoKeyMemoryLimit sets the optional memory_limit attribute to value. 
-// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapUnstageNoKeyMemoryLimit(value int64) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapUnstageNoKeyContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageNoKeyContainer(value string) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapUnstageNoKeySharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageNoKeySharedName(value string) OrderedMapUnstageNoKeyAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes and returns the (key, value) element with the smallest -// -// key from the underlying container. If the underlying container -// does not contain elements, the op will block until it does. -func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageNoKeyAttr) (key tf.Output, values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OrderedMapUnstageNoKey", - Input: []tf.Input{ - indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - key = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapUnstageNoKey", err) - return - } - return key, values -} - -// Returns element-wise integer closest to x. -// -// If the result is midway between two representable values, -// the even representable is chosen. -// For example: -// -// ``` -// rint(-1.5) ==> -2.0 -// rint(0.5000001) ==> 1.0 -// rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] -// ``` -func Rint(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Rint", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the derivative of a Gamma random sample w.r.t. `alpha`. -func RandomGammaGrad(scope *Scope, alpha tf.Output, sample tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RandomGammaGrad", - Input: []tf.Input{ - alpha, sample, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns x + y element-wise. -// -// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Add", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns x + y element-wise. -// -// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AddV2", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. 
-type AllCandidateSamplerAttr func(optionalAttr) - -// AllCandidateSamplerSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Generates labels for candidate sampling with a learned unigram distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. -// -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. -// -// Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to produce. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AllCandidateSampler", - Input: []tf.Input{ - true_classes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Returns element-wise remainder of division. When `x < 0` xor `y < 0` is -// -// true, this follows Python semantics in that the result here is consistent -// with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. -// -// *NOTE*: `FloorMod` supports broadcasting. 
More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func FloorMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FloorMod", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Saves the input tensors to disk. -// -// The size of `tensor_names` must match the number of tensors in `data`. `data[i]` -// is written to `filename` with name `tensor_names[i]`. -// -// See also `SaveSlices`. -// -// Arguments: -// filename: Must have a single element. The name of the file to which we write -// the tensor. -// tensor_names: Shape `[N]`. The names of the tensors to be saved. -// data: `N` tensors to save. -// -// Returns the created operation. -func Save(scope *Scope, filename tf.Output, tensor_names tf.Output, data []tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Save", - Input: []tf.Input{ - filename, tensor_names, tf.OutputList(data), - }, - } - return scope.AddOperation(opspec) -} - -// Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or NaN. -// -// *NOTE*: `Mul` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func MulNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MulNoNan", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns x / y element-wise for integer types. -// -// Truncation designates that negative numbers will round fractional quantities -// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different -// than Python semantics. See `FloorDiv` for a division function that matches -// Python Semantics. -// -// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TruncateDiv", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RequantizePerChannelAttr is an optional argument to RequantizePerChannel. -type RequantizePerChannelAttr func(optionalAttr) - -// RequantizePerChannelOutType sets the optional out_type attribute to value. -// -// value: The quantized type of output tensor that needs to be converted. -// If not specified, defaults to DT_QUINT8 -func RequantizePerChannelOutType(value tf.DataType) RequantizePerChannelAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Requantizes input with min and max values known per channel. -// -// Arguments: -// input: The original input tensor. -// input_min: The minimum value of the input tensor -// input_max: The maximum value of the input tensor. -// requested_output_min: The minimum value of the output tensor requested. -// requested_output_max: The maximum value of the output tensor requested. -// -// Returns Output tensor.The minimum value of the final output tensorThe maximum value of the final output tensor. 
-func RequantizePerChannel(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, optional ...RequantizePerChannelAttr) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RequantizePerChannel", - Input: []tf.Input{ - input, input_min, input_max, requested_output_min, requested_output_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Restores tensors from a V2 checkpoint. -// -// For backward compatibility with the V1 format, this Op currently allows -// restoring from a V1 checkpoint as well: -// - This Op first attempts to find the V2 index file pointed to by "prefix", and -// if found proceed to read it as a V2 checkpoint; -// - Otherwise the V1 read path is invoked. -// Relying on this behavior is not recommended, as the ability to fall back to read -// V1 might be deprecated and eventually removed. -// -// By default, restores the named tensors in full. If the caller wishes to restore -// specific slices of stored tensors, "shape_and_slices" should be non-empty -// strings and correspondingly well-formed. -// -// Callers must ensure all the named tensors are indeed stored in the checkpoint. -// -// Arguments: -// prefix: Must have a single element. The prefix of a V2 checkpoint. -// tensor_names: shape {N}. The names of the tensors to be restored. -// shape_and_slices: shape {N}. The slice specs of the tensors to be restored. -// Empty strings indicate that they are non-partitioned tensors. -// dtypes: shape {N}. The list of expected dtype for the tensors. Must match -// those stored in the checkpoint. -// -// Returns shape {N}. The restored tensors, whose shapes are read from the -// checkpoint directly. -func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and_slices tf.Output, dtypes []tf.DataType) (tensors []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - opspec := tf.OpSpec{ - Type: "RestoreV2", - Input: []tf.Input{ - prefix, tensor_names, shape_and_slices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if tensors, idx, err = makeOutputList(op, idx, "tensors"); err != nil { - scope.UpdateErr("RestoreV2", err) - return - } - return tensors -} - -// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. -type FIFOQueueV2Attr func(optionalAttr) - -// FIFOQueueV2Shapes sets the optional shapes attribute to value. -// -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value - } -} - -// FIFOQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. 
-// If not specified, defaults to -1 -func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// FIFOQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func FIFOQueueV2Container(value string) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// FIFOQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that produces elements in first-in first-out order. -// -// Arguments: -// component_types: The type of each component in a value. -// -// Returns The handle to the queue. -func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FIFOQueueV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that contains the elements of `input_dataset` ignoring errors. -func ExperimentalIgnoreErrorsDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalIgnoreErrorsDataset", - Input: []tf.Input{ - input_dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns 0 if x == 0, and x / y otherwise, elementwise. -func Xdivy(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Xdivy", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Bucketizes 'input' based on 'boundaries'. -// -// For example, if the inputs are -// boundaries = [0, 10, 100] -// input = [[-5, 10000] -// [150, 10] -// [5, 100]] -// -// then the output will be -// output = [[0, 3] -// [3, 2] -// [1, 3]] -// -// Arguments: -// input: Any shape of Tensor contains with int or float type. -// boundaries: A sorted list of floats gives the boundary of the buckets. -// -// Returns Same shape with 'input', each value of input replaced with bucket index. -// -// @compatibility(numpy) -// Equivalent to np.digitize. -// @end_compatibility -func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"boundaries": boundaries} - opspec := tf.OpSpec{ - Type: "Bucketize", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Calculates gains for each feature and returns the best possible split information for the feature. -// -// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature. -// -// It is possible that not all nodes can be split on each feature. 
Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
-//
-// In this manner, the output is the best split per feature and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
-//
-// The output lists all have the same length, `num_features`.
-// The output shapes are compatible in that the first dimension of all tensors in all lists is the same and equal to the number of possible split nodes for each feature.
-//
-// Arguments:
-// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as in `for node_id in range(node_id_range[0], node_id_range[1])` (note that the last index, node_id_range[1], is exclusive).
-// stats_summary_list: A list of Rank 3 tensors (shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per bucket for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all of its elements will be used; only the indexes specified by node_ids will be used.
-// l1: l1 regularization factor on leaf weights, per instance based.
-// l2: l2 regularization factor on leaf weights, per instance based.
-// tree_complexity: adjustment to the gain, per leaf based.
-// min_node_weight: minimum avg of hessians in a node required for the node to be considered for splitting.
-// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
-//
-// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has a different size as each feature provides different possible nodes. See above for details like shapes and sizes. An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes. An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for the split in each node. See above for details like shapes and sizes. A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes. A list of Rank 2 tensors with the same shape/conditions as left_node_contribs_list, but the value is for the right node.
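A graph-construction-only sketch of calling the wrapper that follows, mainly to show that the five returned slices are parallel and indexed by feature (the shapes, constants, and regularization values here are illustrative assumptions, not values taken from this change):

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	const maxSplits = 2

	// Process node ids in [0, 1).
	nodeIDRange := op.Const(s, []int32{0, 1})
	// One feature; stats summary of shape [max_splits, bucket, 2] (gradient, hessian).
	stats := []tf.Output{op.Const(s, [][][]float32{
		{{1, 2}, {3, 4}},
		{{0, 0}, {0, 0}},
	})}
	zero := op.Const(s, float32(0))

	nodeIDs, gains, thresholds, left, right := op.BoostedTreesCalculateBestGainsPerFeature(
		s, nodeIDRange, stats, zero, zero, zero, zero, maxSplits)

	// The five returned slices are parallel: index f describes feature f.
	for f := range nodeIDs {
		fmt.Println(f, nodeIDs[f], gains[f], thresholds[f], left[f], right[f])
	}

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
}
```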
-func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"max_splits": max_splits} - opspec := tf.OpSpec{ - Type: "BoostedTreesCalculateBestGainsPerFeature", - Input: []tf.Input{ - node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil { - scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err) - return - } - return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list -} - -// EncodePngAttr is an optional argument to EncodePng. -type EncodePngAttr func(optionalAttr) - -// EncodePngCompression sets the optional compression attribute to value. -// -// value: Compression level. -// If not specified, defaults to -1 -func EncodePngCompression(value int64) EncodePngAttr { - return func(m optionalAttr) { - m["compression"] = value - } -} - -// PNG-encode an image. -// -// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` -// where `channels` is: -// -// * 1: for grayscale. -// * 2: for grayscale + alpha. -// * 3: for RGB. -// * 4: for RGBA. -// -// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder -// default or a value from 0 to 9. 9 is the highest compression level, generating -// the smallest output, but is slower. -// -// Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. PNG-encoded image. -func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodePng", - Input: []tf.Input{ - image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. -type QueueDequeueUpToV2Attr func(optionalAttr) - -// QueueDequeueUpToV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. 
-// If not specified, defaults to -1 -func QueueDequeueUpToV2TimeoutMs(value int64) QueueDequeueUpToV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } -} - -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// This operation is not supported by all queues. If a queue does not support -// DequeueUpTo, then an Unimplemented error is returned. -// -// If the queue is closed and there are more than 0 but less than `n` -// elements remaining, then instead of returning an OutOfRange error like -// QueueDequeueMany, less than `n` elements are returned immediately. If -// the queue is closed and there are 0 elements left in the queue, then -// an OutOfRange error is returned just like in QueueDequeueMany. -// Otherwise the behavior is identical to QueueDequeueMany: -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size n in the 0th dimension. -// -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. -// -// Arguments: -// handle: The handle to a queue. -// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. -// -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueUpToV2Attr) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueDequeueUpToV2", - Input: []tf.Input{ - handle, n, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueUpToV2", err) - return - } - return components -} - -// Returns the max of x and y (i.e. x > y ? x : y) element-wise. -// -// *NOTE*: `Maximum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Maximum", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns element-wise remainder of division. This emulates C semantics in that -// -// the result here is consistent with a truncating divide. E.g. -// `tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`. -// -// *NOTE*: `Mod` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Mod", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns element-wise remainder of division. This emulates C semantics in that -// -// the result here is consistent with a truncating divide. E.g. `truncate(x / y) * -// y + truncate_mod(x, y) = x`. -// -// *NOTE*: `TruncateMod` supports broadcasting. 
More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func TruncateMod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TruncateMod", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes offsets of concat inputs within its output. -// -// For example: -// -// ``` -// # 'x' is [2, 2, 7] -// # 'y' is [2, 3, 7] -// # 'z' is [2, 5, 7] -// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0] -// ``` -// -// This is typically used by gradient computations for a concat operation. -// -// Arguments: -// concat_dim: The dimension along which to concatenate. -// shape: The `N` int32 vectors representing shape of tensors being concatenated. -// -// Returns The `N` int32 vectors representing the starting offset -// of input tensors within the concatenated output. -func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConcatOffset", - Input: []tf.Input{ - concat_dim, tf.OutputList(shape), - }, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil { - scope.UpdateErr("ConcatOffset", err) - return - } - return offset -} - -// LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingRMSPropParametersGradAccumDebug. -type LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) - -// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Load RMSProp embedding parameters with debug support. -// -// An op that loads optimization parameters into HBM for embedding. Must be -// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct -// embedding table configuration. For example, this op is used to install -// parameters that are loaded from a checkpoint before a training loop is -// executed. -// -// Arguments: -// parameters: Value of parameters used in the RMSProp optimization algorithm. -// ms: Value of ms used in the RMSProp optimization algorithm. -// mom: Value of mom used in the RMSProp optimization algorithm. -// gradient_accumulators: Value of gradient_accumulators used in the RMSProp optimization algorithm. -// -// -// -// Returns the created operation. 
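-//
-// Since this wrapper returns a *tf.Operation rather than tensor outputs, it is
-// typically run for its side effect as a session target. A rough sketch (the
-// `loadOp` value and the surrounding session setup are assumed, not part of this
-// file):
-//
-// ``` go
-// graph, err := s.Finalize()
-// if err != nil { /* handle */ }
-// sess, err := tf.NewSession(graph, nil)
-// if err != nil { /* handle */ }
-// _, err = sess.Run(nil, nil, []*tf.Operation{loadOp})
-// ```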
-func LoadTPUEmbeddingRMSPropParametersGradAccumDebug(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LoadTPUEmbeddingRMSPropParametersGradAccumDebug", - Input: []tf.Input{ - parameters, ms, mom, gradient_accumulators, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Compute the lower regularized incomplete Gamma function `P(a, x)`. -// -// The lower regularized incomplete Gamma function is defined as: -// -// -// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) -// -// where -// -// \\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) -// -// is the lower incomplete Gamma function. -// -// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete -// Gamma function. -func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Igamma", - Input: []tf.Input{ - a, x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Compute the regularized incomplete beta integral \\(I_x(a, b)\\). -// -// The regularized incomplete beta integral is defined as: -// -// -// \\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) -// -// where -// -// -// \\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) -// -// -// is the incomplete beta function and \\(B(a, b)\\) is the *complete* -// beta function. -func Betainc(scope *Scope, a tf.Output, b tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Betainc", - Input: []tf.Input{ - a, b, x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ShapeAttr is an optional argument to Shape. -type ShapeAttr func(optionalAttr) - -// ShapeOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func ShapeOutType(value tf.DataType) ShapeAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Returns the shape of a tensor. -// -// This operation returns a 1-D integer tensor representing the shape of `input`. -// -// For example: -// -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// shape(t) ==> [2, 2, 3] -// ``` -func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Shape", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes fingerprints of the input strings. -// -// Arguments: -// input: vector of strings to compute fingerprints on. -// -// Returns a (N,2) shaped matrix where N is the number of elements in the input -// vector. Each row contains the low and high parts of the fingerprint. -func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SdcaFprint", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the power of one value to another. 
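-//
-// A Go-side sketch of building this op (values are illustrative; `NewScope` and
-// `Const` are helpers from this package):
-//
-// ``` go
-// s := NewScope()
-// x := Const(s, []float32{2, 3})
-// y := Const(s, []float32{8, 2})
-// z := Pow(s, x, y) // evaluates to [256, 9] when the graph is run
-// ```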
-// -// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for -// corresponding elements in `x` and `y`. For example: -// -// ``` -// # tensor 'x' is [[2, 2]], [3, 3]] -// # tensor 'y' is [[8, 16], [2, 3]] -// tf.pow(x, y) ==> [[256, 65536], [9, 27]] -// ``` -func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Pow", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QuantizedReluXAttr is an optional argument to QuantizedReluX. -type QuantizedReluXAttr func(optionalAttr) - -// QuantizedReluXOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_QUINT8 -func QuantizedReluXOutType(value tf.DataType) QuantizedReluXAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)` -// -// Arguments: -// -// -// min_features: The float value that the lowest quantized value represents. -// max_features: The float value that the highest quantized value represents. -// -// Returns Has the same output shape as "features".The float value that the lowest quantized value represents.The float value that the highest quantized value represents. -func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_features tf.Output, max_features tf.Output, optional ...QuantizedReluXAttr) (activations tf.Output, min_activations tf.Output, max_activations tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedReluX", - Input: []tf.Input{ - features, max_value, min_features, max_features, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Returns the truth value of (x < y) element-wise. -// -// *NOTE*: `Less` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Less", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RandomPoissonAttr is an optional argument to RandomPoisson. -type RandomPoissonAttr func(optionalAttr) - -// RandomPoissonSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomPoissonSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed2(value int64) RandomPoissonAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Use RandomPoissonV2 instead. -// -// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 -func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomPoisson", - Input: []tf.Input{ - shape, rate, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Gets the next output from the given iterator. -// -// This operation is a synchronous version IteratorGetNext. 
It should only be used -// in situations where the iterator does not block the calling thread, or where -// the calling thread is not a member of the thread pool used to execute parallel -// operations (e.g. in eager mode). -func IteratorGetNextSync(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "IteratorGetNextSync", - Input: []tf.Input{ - iterator, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("IteratorGetNextSync", err) - return - } - return components -} - -// Returns the truth value of (x >= y) element-wise. -// -// *NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func GreaterEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "GreaterEqual", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ApproximateEqualAttr is an optional argument to ApproximateEqual. -type ApproximateEqualAttr func(optionalAttr) - -// ApproximateEqualTolerance sets the optional tolerance attribute to value. -// If not specified, defaults to 1e-05 -func ApproximateEqualTolerance(value float32) ApproximateEqualAttr { - return func(m optionalAttr) { - m["tolerance"] = value - } -} - -// Returns the truth value of abs(x-y) < tolerance element-wise. -func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...ApproximateEqualAttr) (z tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ApproximateEqual", - Input: []tf.Input{ - x, y, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the truth value of x OR y element-wise. -// -// *NOTE*: `LogicalOr` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LogicalOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogicalOr", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MatMulAttr is an optional argument to MatMul. -type MatMulAttr func(optionalAttr) - -// MatMulTransposeA sets the optional transpose_a attribute to value. -// -// value: If true, "a" is transposed before multiplication. -// If not specified, defaults to false -func MatMulTransposeA(value bool) MatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// MatMulTransposeB sets the optional transpose_b attribute to value. -// -// value: If true, "b" is transposed before multiplication. -// If not specified, defaults to false -func MatMulTransposeB(value bool) MatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// Multiply the matrix "a" by the matrix "b". 
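-//
-// A minimal construction sketch (illustrative values; `NewScope` and `Const` are
-// helpers from this package):
-//
-// ``` go
-// s := NewScope()
-// a := Const(s, [][]float32{{1, 2}, {3, 4}})
-// b := Const(s, [][]float32{{5, 6}, {7, 8}})
-// product := MatMul(s, a, b, MatMulTransposeB(true))
-// ```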
-// -// The inputs must be two-dimensional matrices and the inner dimension of -// "a" (after being transposed if transpose_a is true) must match the -// outer dimension of "b" (after being transposed if transposed_b is -// true). -// -// *Note*: The default kernel implementation for MatMul on GPUs uses -// cublas. -func MatMul(scope *Scope, a tf.Output, b tf.Output, optional ...MatMulAttr) (product tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatMul", - Input: []tf.Input{ - a, b, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. -type InitializeTableFromTextFileV2Attr func(optionalAttr) - -// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. -// -// value: Number of elements of the file, use -1 if unknown. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { - return func(m optionalAttr) { - m["vocab_size"] = value - } -} - -// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. -// -// value: Delimiter to separate fields in a line. -// If not specified, defaults to "\t" -func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { - return func(m optionalAttr) { - m["delimiter"] = value - } -} - -// Initializes a table from a text file. -// -// It inserts one key-value pair into the table for each line of the file. -// The key and value is extracted from the whole line content, elements from the -// split line based on `delimiter` or the line number (starting from zero). -// Where to extract the key and value from a line is specified by `key_index` and -// `value_index`. -// -// - A value of -1 means use the line number(starting from zero), expects `int64`. -// - A value of -2 means use the whole line content, expects `string`. -// - A value >= 0 means use the index (starting at zero) of the split line based -// on `delimiter`. -// -// Arguments: -// table_handle: Handle to a table which will be initialized. -// filename: Filename of a vocabulary text file. -// key_index: Column index in a line to get the table `key` values from. -// value_index: Column index that represents information of a line to get the table -// `value` values from. -// -// Returns the created operation. -func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "InitializeTableFromTextFileV2", - Input: []tf.Input{ - table_handle, filename, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// MeanAttr is an optional argument to Mean. -type MeanAttr func(optionalAttr) - -// MeanKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func MeanKeepDims(value bool) MeanAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the mean of elements across dimensions of a tensor. 
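-//
-// A construction sketch (illustrative values; `NewScope` and `Const` are helpers
-// from this package):
-//
-// ``` go
-// s := NewScope()
-// x := Const(s, [][]float32{{1, 2}, {3, 4}})
-// axis := Const(s, []int32{0})
-// m := Mean(s, x, axis, MeanKeepDims(true)) // shape [1, 2], values [[2, 3]]
-// ```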
-// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Mean(scope *Scope, input tf.Output, axis tf.Output, optional ...MeanAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Mean", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ProdAttr is an optional argument to Prod. -type ProdAttr func(optionalAttr) - -// ProdKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func ProdKeepDims(value bool) ProdAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the product of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Prod", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBilinearAttr is an optional argument to ResizeBilinear. -type ResizeBilinearAttr func(optionalAttr) - -// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// ResizeBilinearHalfPixelCenters sets the optional half_pixel_centers attribute to value. -// If not specified, defaults to false -func ResizeBilinearHalfPixelCenters(value bool) ResizeBilinearAttr { - return func(m optionalAttr) { - m["half_pixel_centers"] = value - } -} - -// Resize `images` to `size` using bilinear interpolation. -// -// Input images can be of different types but output images are always float. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. 
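-//
-// A construction sketch (illustrative 1x2x2x1 image; `NewScope` and `Const` are
-// helpers from this package):
-//
-// ``` go
-// s := NewScope()
-// img := Const(s, [][][][]float32{{{{0}, {1}}, {{2}, {3}}}}) // [1, 2, 2, 1]
-// size := Const(s, []int32{4, 4})
-// resized := ResizeBilinear(s, img, size, ResizeBilinearAlignCorners(true))
-// ```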
-func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBilinear", - Input: []tf.Input{ - images, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxAttr is an optional argument to Max. -type MaxAttr func(optionalAttr) - -// MaxKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func MaxKeepDims(value bool) MaxAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the maximum of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Max(scope *Scope, input tf.Output, axis tf.Output, optional ...MaxAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Max", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that contains the unique elements of `input_dataset`. -func ExperimentalUniqueDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ExperimentalUniqueDataset", - Input: []tf.Input{ - input_dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ArgMinAttr is an optional argument to ArgMin. -type ArgMinAttr func(optionalAttr) - -// ArgMinOutputType sets the optional output_type attribute to value. -// If not specified, defaults to DT_INT64 -func ArgMinOutputType(value tf.DataType) ArgMinAttr { - return func(m optionalAttr) { - m["output_type"] = value - } -} - -// Returns the index with the smallest value across dimensions of a tensor. -// -// Note that in case of ties the identity of the return value is not guaranteed. -// -// Arguments: -// -// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. -// Describes which dimension of the input Tensor to reduce across. For vectors, -// use dimension = 0. -func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ArgMin", - Input: []tf.Input{ - input, dimension, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Converts the quantized `input` tensor into a lower-precision `output`. -// -// Converts the quantized `input` tensor into a lower-precision `output`, using the -// output range specified with `requested_output_min` and `requested_output_max`. 
-// -// `[input_min, input_max]` are scalar floats that specify the range for the float -// interpretation of the `input` data. For example, if `input_min` is -1.0f and -// `input_max` is 1.0f, and we are dealing with `quint16` quantized data, then a 0 -// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. -// -// Arguments: -// -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// requested_output_min: The float value that the minimum quantized output value represents. -// requested_output_max: The float value that the maximum quantized output value represents. -// out_type: The type of the output. Should be a lower bit depth than Tinput. -// -// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output. -func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"out_type": out_type} - opspec := tf.OpSpec{ - Type: "Requantize", - Input: []tf.Input{ - input, input_min, input_max, requested_output_min, requested_output_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Creates a dataset that emits the lines of one or more text files. -// -// Arguments: -// filenames: A scalar or a vector containing the name(s) of the file(s) to be -// read. -// compression_type: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// buffer_size: A scalar containing the number of bytes to buffer. -func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TextLineDataset", - Input: []tf.Input{ - filenames, compression_type, buffer_size, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sum along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \sum_j data_j\\) where sum is over `j` such -// that `segment_ids[j] == i`. -// -// If the sum is empty for a given segment ID `i`, `output[i] = 0`. -// -//
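-// A Go construction sketch of the same computation as the example below
-// (`NewScope` and `Const` are helpers from this package; values are illustrative):
-//
-// ``` go
-// s := NewScope()
-// c := Const(s, [][]int32{{1, 2, 3, 4}, {4, 3, 2, 1}, {5, 6, 7, 8}})
-// out := SegmentSum(s, c, Const(s, []int32{0, 0, 1})) // => [[5 5 5 5] [5 6 7 8]]
-// ```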
-// -// For example: -// -// ``` -// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) -// tf.segment_sum(c, tf.constant([0, 0, 1])) -// # ==> [[5, 5, 5, 5], -// # [5, 6, 7, 8]] -// ``` -// -// -// Arguments: -// -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SegmentSum", - Input: []tf.Input{ - data, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the mean along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is -// over `j` such that `segment_ids[j] == i` and `N` is the total number of -// values summed. -// -// If the mean is empty for a given segment ID `i`, `output[i] = 0`. -// -//
-// -// For example: -// -// ``` -// c = tf.constant([[1.0,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) -// tf.segment_mean(c, tf.constant([0, 0, 1])) -// # ==> [[2.5, 2.5, 2.5, 2.5], -// # [5, 6, 7, 8]] -// ``` -// -// -// Arguments: -// -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMean(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SegmentMean", - Input: []tf.Input{ - data, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the minimum along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such -// that `segment_ids[j] == i`. -// -// If the min is empty for a given segment ID `i`, `output[i] = 0`. -// -//
-// -// For example: -// -// ``` -// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) -// tf.segment_min(c, tf.constant([0, 0, 1])) -// # ==> [[1, 2, 2, 1], -// # [5, 6, 7, 8]] -// ``` -// -// Arguments: -// -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SegmentMin", - Input: []tf.Input{ - data, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sum along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such -// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` -// need not be sorted and need not cover all values in the full -// range of valid values. -// -// If the sum is empty for a given segment ID `i`, `output[i] = 0`. -// If the given segment ID `i` is negative, the value is dropped and will not be -// added to the sum of the segment. -// -// `num_segments` should equal the number of distinct segment IDs. -// -//
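-// A Go construction sketch matching the Python example below (`NewScope` and
-// `Const` are helpers from this package; values are illustrative):
-//
-// ``` go
-// s := NewScope()
-// c := Const(s, [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}, {4, 3, 2, 1}})
-// out := UnsortedSegmentSum(s, c, Const(s, []int32{0, 1, 0}), Const(s, int32(2)))
-// // => [[5 5 5 5] [5 6 7 8]]
-// ```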
-// -// ``` python -// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) -// tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2) -// # ==> [[ 5, 5, 5, 5], -// # [5, 6, 7, 8]] -// ``` -// -// -// Arguments: -// -// segment_ids: A tensor whose shape is a prefix of `data.shape`. -// -// -// Returns Has same shape as data, except for the first `segment_ids.rank` -// dimensions, which are replaced with a single dimension which has size -// `num_segments`. -func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "UnsortedSegmentSum", - Input: []tf.Input{ - data, segment_ids, num_segments, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the product along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// This operator is similar to the unsorted segment sum operator found -// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). -// Instead of computing the sum over segments, it computes the product of all -// entries belonging to a segment such that: -// -// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples -// `j...` such that `segment_ids[j...] == i`. -// -// For example: -// -// ``` python -// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]]) -// tf.unsorted_segment_prod(c, tf.constant([0, 1, 0]), num_segments=2) -// # ==> [[ 4, 6, 6, 4], -// # [5, 6, 7, 8]] -// ``` -// -// If there is no entry for a given segment ID `i`, it outputs 1. -// -// If the given segment ID `i` is negative, then the corresponding value is -// dropped, and will not be included in the result. -// -// Arguments: -// -// segment_ids: A tensor whose shape is a prefix of `data.shape`. -// -// -// Returns Has same shape as data, except for the first `segment_ids.rank` -// dimensions, which are replaced with a single dimension which has size -// `num_segments`. -func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "UnsortedSegmentProd", - Input: []tf.Input{ - data, segment_ids, num_segments, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes hyperbolic cosine of x element-wise. -func Cosh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cosh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the mean along sparse segments of a tensor. -// -// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. -// -// Returns Has same shape as data, except for dimension 0 which has size -// `num_segments`. 
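-//
-// A construction sketch (illustrative values; `NewScope` and `Const` are helpers
-// from this package). Segment id 1 is missing, so row 1 of the output is zeroed:
-//
-// ``` go
-// s := NewScope()
-// data := Const(s, [][]float32{{1, 2}, {3, 4}, {5, 6}})
-// indices := Const(s, []int32{0, 2})
-// segmentIDs := Const(s, []int32{0, 2})
-// out := SparseSegmentMeanWithNumSegments(s, data, indices, segmentIDs, Const(s, int32(3)))
-// // => [[1 2] [0 0] [5 6]]
-// ```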
-func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentMeanWithNumSegments", - Input: []tf.Input{ - data, indices, segment_ids, num_segments, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize. -type CudnnRNNParamsSizeAttr func(optionalAttr) - -// CudnnRNNParamsSizeRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNParamsSizeRnnMode(value string) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} - -// CudnnRNNParamsSizeInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNParamsSizeInputMode(value string) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} - -// CudnnRNNParamsSizeDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNParamsSizeDirection(value string) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["direction"] = value - } -} - -// CudnnRNNParamsSizeDropout sets the optional dropout attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsSizeDropout(value float32) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["dropout"] = value - } -} - -// CudnnRNNParamsSizeSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsSizeSeed(value int64) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// CudnnRNNParamsSizeSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsSizeSeed2(value int64) CudnnRNNParamsSizeAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Computes size of weights that can be used by a Cudnn RNN model. -// -// Return the params size that can be used by the Cudnn RNN model. Subsequent -// weight allocation and initialization should use this size. -// -// num_layers: Specifies the number of layers in the RNN model. -// num_units: Specifies the size of the hidden state. -// input_size: Specifies the size of the input state. -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -// params_size: The size of the params buffer that should be allocated and -// initialized for this RNN model. Note that this params buffer may not be -// compatible across GPUs. Please use CudnnRNNParamsWeights and -// CudnnRNNParamsBiases to save and restore them in a way that is compatible -// across different runs. 
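-//
-// A construction sketch (illustrative values; requires a CUDA-enabled build at
-// run time; `NewScope` and `Const` are helpers from this package):
-//
-// ``` go
-// s := NewScope()
-// paramsSize := CudnnRNNParamsSize(s,
-// 	Const(s, int32(2)),   // num_layers
-// 	Const(s, int32(256)), // num_units
-// 	Const(s, int32(128)), // input_size
-// 	tf.Float, tf.Int32, CudnnRNNParamsSizeRnnMode("gru"))
-// ```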
-func CudnnRNNParamsSize(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, T tf.DataType, S tf.DataType, optional ...CudnnRNNParamsSizeAttr) (params_size tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"T": T, "S": S} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CudnnRNNParamsSize", - Input: []tf.Input{ - num_layers, num_units, input_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes gradients for SparseSegmentMean. -// -// Returns tensor "output" with same shape as grad, except for dimension 0 whose -// value is output_dim0. -// -// Arguments: -// grad: gradient propagated to the SparseSegmentMean op. -// indices: indices passed to the corresponding SparseSegmentMean op. -// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. -// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. -func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentMeanGrad", - Input: []tf.Input{ - grad, indices, segment_ids, output_dim0, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. -// -// N is the size of the segment being reduced. -// -// See `tf.sparse.segment_sum` for usage examples. -// -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtN", - Input: []tf.Input{ - data, indices, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Compute the upper regularized incomplete Gamma function `Q(a, x)`. -// -// The upper regularized incomplete Gamma function is defined as: -// -// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) -// -// where -// -// \\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) -// -// is the upper incomplete Gama function. -// -// Note, above `P(a, x)` (`Igamma`) is the lower regularized complete -// Gamma function. -func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Igammac", - Input: []tf.Input{ - a, x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. -// -// N is the size of the segment being reduced. -// -// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. 
-// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSqrtNWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtNWithNumSegments", - Input: []tf.Input{ - data, indices, segment_ids, num_segments, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes gradients for SparseSegmentSqrtN. -// -// Returns tensor "output" with same shape as grad, except for dimension 0 whose -// value is output_dim0. -// -// Arguments: -// grad: gradient propagated to the SparseSegmentSqrtN op. -// indices: indices passed to the corresponding SparseSegmentSqrtN op. -// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op. -// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. -func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtNGrad", - Input: []tf.Input{ - grad, indices, segment_ids, output_dim0, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LRNGradAttr is an optional argument to LRNGrad. -type LRNGradAttr func(optionalAttr) - -// LRNGradDepthRadius sets the optional depth_radius attribute to value. -// -// value: A depth radius. -// If not specified, defaults to 5 -func LRNGradDepthRadius(value int64) LRNGradAttr { - return func(m optionalAttr) { - m["depth_radius"] = value - } -} - -// LRNGradBias sets the optional bias attribute to value. -// -// value: An offset (usually > 0 to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNGradBias(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["bias"] = value - } -} - -// LRNGradAlpha sets the optional alpha attribute to value. -// -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNGradAlpha(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["alpha"] = value - } -} - -// LRNGradBeta sets the optional beta attribute to value. -// -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNGradBeta(value float32) LRNGradAttr { - return func(m optionalAttr) { - m["beta"] = value - } -} - -// Gradients for Local Response Normalization. -// -// Arguments: -// input_grads: 4-D with shape `[batch, height, width, channels]`. -// input_image: 4-D with shape `[batch, height, width, channels]`. -// output_image: 4-D with shape `[batch, height, width, channels]`. -// -// Returns The gradients for LRN. -func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LRNGrad", - Input: []tf.Input{ - input_grads, input_image, output_image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AnyAttr is an optional argument to Any. -type AnyAttr func(optionalAttr) - -// AnyKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. 
-// If not specified, defaults to false -func AnyKeepDims(value bool) AnyAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the "logical or" of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Any", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. -type DestroyResourceOpAttr func(optionalAttr) - -// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. -// -// value: whether to ignore the error when the resource -// doesn't exist. -// If not specified, defaults to true -func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { - return func(m optionalAttr) { - m["ignore_lookup_error"] = value - } -} - -// Deletes the resource specified by the handle. -// -// All subsequent operations using the resource will result in a NotFound -// error status. -// -// Arguments: -// resource: handle to the resource to delete. -// -// Returns the created operation. -func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DestroyResourceOp", - Input: []tf.Input{ - resource, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Generates values in an interval. -// -// A sequence of `num` evenly-spaced values are generated beginning at `start`. -// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, -// so that the last one is exactly `stop`. -// -// For example: -// -// ``` -// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] -// ``` -// -// Arguments: -// start: 0-D tensor. First entry in the range. -// stop: 0-D tensor. Last entry in the range. -// num: 0-D tensor. Number of values to generate. -// -// Returns 1-D. The generated values. -func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LinSpace", - Input: []tf.Input{ - start, stop, num, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ComplexAttr is an optional argument to Complex. -type ComplexAttr func(optionalAttr) - -// ComplexTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_COMPLEX64 -func ComplexTout(value tf.DataType) ComplexAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Converts two real numbers to a complex number. 
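-//
-// A construction sketch (illustrative values; `NewScope` and `Const` are helpers
-// from this package):
-//
-// ``` go
-// s := NewScope()
-// re := Const(s, []float32{2.25, 3.25})
-// im := Const(s, []float32{4.75, 5.75})
-// c := Complex(s, re, im) // DT_COMPLEX64 output by default
-// ```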
-// -// Given a tensor `real` representing the real part of a complex number, and a -// tensor `imag` representing the imaginary part of a complex number, this -// operation returns complex numbers elementwise of the form \\(a + bj\\), where -// *a* represents the `real` part and *b* represents the `imag` part. -// -// The input tensors `real` and `imag` must have the same shape. -// -// For example: -// -// ``` -// # tensor 'real' is [2.25, 3.25] -// # tensor `imag` is [4.75, 5.75] -// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]] -// ``` -func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Complex", - Input: []tf.Input{ - real, imag, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ImagAttr is an optional argument to Imag. -type ImagAttr func(optionalAttr) - -// ImagTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func ImagTout(value tf.DataType) ImagAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Returns the imaginary part of a complex number. -// -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the imaginary part of each element in `input`. All -// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* -// is the real part and *b* is the imaginary part returned by this operation. -// -// For example: -// -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.imag(input) ==> [4.75, 5.75] -// ``` -func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Imag", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes hyperbolic tangent of `x` element-wise. -func Tanh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Tanh", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the maximum along segments of a tensor. -// -// Read -// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) -// for an explanation of segments. -// -// Computes a tensor such that -// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such -// that `segment_ids[j] == i`. -// -// If the max is empty for a given segment ID `i`, `output[i] = 0`. -// -//
-// -// For example: -// -// ``` -// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) -// tf.segment_max(c, tf.constant([0, 0, 1])) -// # ==> [[4, 3, 3, 4], -// # [5, 6, 7, 8]] -// ``` -// -// -// Arguments: -// -// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s -// first dimension. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SegmentMax", - Input: []tf.Input{ - data, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that skips `count` elements from the `input_dataset`. -// -// Arguments: -// -// count: A scalar representing the number of elements from the `input_dataset` -// that should be skipped. If count is -1, skips everything. -// -// -func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "SkipDataset", - Input: []tf.Input{ - input_dataset, count, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// VarHandleOpAttr is an optional argument to VarHandleOp. -type VarHandleOpAttr func(optionalAttr) - -// VarHandleOpContainer sets the optional container attribute to value. -// -// value: the container this variable is placed in. -// If not specified, defaults to "" -func VarHandleOpContainer(value string) VarHandleOpAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// VarHandleOpSharedName sets the optional shared_name attribute to value. -// -// value: the name by which this variable is referred to. -// If not specified, defaults to "" -func VarHandleOpSharedName(value string) VarHandleOpAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a handle to a Variable resource. -// -// Arguments: -// dtype: the type of this variable. Must agree with the dtypes -// of all ops using this variable. -// shape: The (possibly partially specified) shape of this variable. -func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "VarHandleOp", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AngleAttr is an optional argument to Angle. -type AngleAttr func(optionalAttr) - -// AngleTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func AngleTout(value tf.DataType) AngleAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Returns the argument of a complex number. -// -// Given a tensor `input` of complex numbers, this operation returns a tensor of -// type `float` that is the argument of each element in `input`. All elements in -// `input` must be complex numbers of the form \\(a + bj\\), where *a* -// is the real part and *b* is the imaginary part. -// -// The argument returned by this operation is of the form \\(atan2(b, a)\\). 
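-//
-// A Go construction sketch (illustrative; the complex input is built with the
-// Complex wrapper from this package, and `NewScope`/`Const` are package helpers):
-//
-// ``` go
-// s := NewScope()
-// c := Complex(s, Const(s, []float32{-2.25, 3.25}), Const(s, []float32{4.75, 5.75}))
-// theta := Angle(s, c)
-// ```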
-// -// For example: -// -// ``` -// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -// tf.angle(input) ==> [2.0132, 1.056] -// ``` -// -// @compatibility(numpy) -// Equivalent to np.angle. -// @end_compatibility -func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Angle", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Clips tensor values to a specified min and max. -// -// Given a tensor `t`, this operation returns a tensor of the same type and -// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. -// Any values less than `clip_value_min` are set to `clip_value_min`. Any values -// greater than `clip_value_max` are set to `clip_value_max`. -// -// Arguments: -// t: A `Tensor`. -// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape -// as `t`. The minimum value to clip by. -// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape -// as `t`. The maximum value to clip by. -// -// Returns A clipped `Tensor` with the same shape as input 't'. -func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ClipByValue", - Input: []tf.Input{ - t, clip_value_min, clip_value_max, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Counts the number of occurrences of each value in an integer array. -// -// Outputs a vector with length `size` and the same dtype as `weights`. If -// `weights` are empty, then index `i` stores the number of times the value `i` is -// counted in `arr`. If `weights` are non-empty, then index `i` stores the sum of -// the value in `weights` at each index where the corresponding value in `arr` is -// `i`. -// -// Values in `arr` outside of the range [0, size) are ignored. -// -// Arguments: -// arr: int32 `Tensor`. -// size: non-negative int32 scalar `Tensor`. -// weights: is an int32, int64, float32, or float64 `Tensor` with the same -// shape as `arr`, or a length-0 `Tensor`, in which case it acts as all weights -// equal to 1. -// -// Returns 1D `Tensor` with length equal to `size`. The counts or summed weights for -// each value in the range [0, size). -func Bincount(scope *Scope, arr tf.Output, size tf.Output, weights tf.Output) (bins tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Bincount", - Input: []tf.Input{ - arr, size, weights, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CumsumAttr is an optional argument to Cumsum. -type CumsumAttr func(optionalAttr) - -// CumsumExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumsum. -// If not specified, defaults to false -func CumsumExclusive(value bool) CumsumAttr { - return func(m optionalAttr) { - m["exclusive"] = value - } -} - -// CumsumReverse sets the optional reverse attribute to value. -// -// value: A `bool` (default: False). -// If not specified, defaults to false -func CumsumReverse(value bool) CumsumAttr { - return func(m optionalAttr) { - m["reverse"] = value - } -} - -// Compute the cumulative sum of the tensor `x` along `axis`. 
-// -// By default, this op performs an inclusive cumsum, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is -// performed instead: -// -// ```python -// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b] -// ``` -// -// By setting the `reverse` kwarg to `True`, the cumsum is performed in the -// opposite direction: -// -// ```python -// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c] -// ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: -// -// ```python -// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] -// ``` -// -// Arguments: -// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Cumsum", - Input: []tf.Input{ - x, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Return the shape of s0 op s1 with broadcast. -// -// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the -// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. -func BroadcastArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BroadcastArgs", - Input: []tf.Input{ - s0, s1, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DataFormatDimMapAttr is an optional argument to DataFormatDimMap. -type DataFormatDimMapAttr func(optionalAttr) - -// DataFormatDimMapSrcFormat sets the optional src_format attribute to value. -// -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatDimMapSrcFormat(value string) DataFormatDimMapAttr { - return func(m optionalAttr) { - m["src_format"] = value - } -} - -// DataFormatDimMapDstFormat sets the optional dst_format attribute to value. -// -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatDimMapDstFormat(value string) DataFormatDimMapAttr { - return func(m optionalAttr) { - m["dst_format"] = value - } -} - -// Returns the dimension index in the destination data format given the one in -// -// the source data format. -// -// Arguments: -// x: A Tensor with each element as a dimension index in source data format. -// Must be in the range [-4, 4). -// -// Returns A Tensor with each element as a dimension index in destination data format. -func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAttr) (y tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DataFormatDimMap", - Input: []tf.Input{ - x, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CumprodAttr is an optional argument to Cumprod. 
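The Cumsum wrapper above (and the Cumprod wrapper that follows) only adds nodes to a graph; producing values still requires the usual build/finalize/run cycle of the Go bindings. A minimal sketch of that cycle, assuming the standard `tensorflow/go` and `tensorflow/go/op` import paths and session API rather than anything introduced by this diff:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s, []float32{1, 2, 3, 4})
	axis := op.Const(s, int32(0))

	// Exclusive + reverse cumulative sum: for [a, b, c, d] this yields
	// [b+c+d, c+d, d, 0].
	out := op.Cumsum(s, x, axis, op.CumsumExclusive(true), op.CumsumReverse(true))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	res, err := sess.Run(nil, []tf.Output{out}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(res[0].Value()) // expected: [9 7 4 0]
}
```

Cumprod is driven the same way, with CumprodExclusive and CumprodReverse in place of the Cumsum attributes.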
-type CumprodAttr func(optionalAttr) - -// CumprodExclusive sets the optional exclusive attribute to value. -// -// value: If `True`, perform exclusive cumprod. -// If not specified, defaults to false -func CumprodExclusive(value bool) CumprodAttr { - return func(m optionalAttr) { - m["exclusive"] = value - } -} - -// CumprodReverse sets the optional reverse attribute to value. -// -// value: A `bool` (default: False). -// If not specified, defaults to false -func CumprodReverse(value bool) CumprodAttr { - return func(m optionalAttr) { - m["reverse"] = value - } -} - -// Compute the cumulative product of the tensor `x` along `axis`. -// -// By default, this op performs an inclusive cumprod, which means that the first -// element of the input is identical to the first element of the output: -// -// ```python -// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] -// ``` -// -// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is -// performed instead: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] -// ``` -// -// By setting the `reverse` kwarg to `True`, the cumprod is performed in the -// opposite direction: -// -// ```python -// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] -// ``` -// -// This is more efficient than using separate `tf.reverse` ops. -// -// The `reverse` and `exclusive` kwargs can also be combined: -// -// ```python -// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] -// ``` -// -// Arguments: -// x: A `Tensor`. Must be one of the following types: `float32`, `float64`, -// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, -// `complex128`, `qint8`, `quint8`, `qint32`, `half`. -// axis: A `Tensor` of type `int32` (default: 0). Must be in the range -// `[-rank(x), rank(x))`. -func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Cumprod", - Input: []tf.Input{ - x, axis, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr is an optional argument to RetrieveTPUEmbeddingStochasticGradientDescentParameters. -type RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr func(optionalAttr) - -// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableId(value int64) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingStochasticGradientDescentParametersTableName(value string) RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// Retrieve SGD embedding parameters. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. 
-// -// Returns Parameter parameters updated by the stochastic gradient descent optimization algorithm. -func RetrieveTPUEmbeddingStochasticGradientDescentParameters(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingStochasticGradientDescentParametersAttr) (parameters tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingStochasticGradientDescentParameters", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. -type QuantizedMatMulAttr func(optionalAttr) - -// QuantizedMatMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } -} - -// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. -// -// value: If true, `a` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. -// -// value: If true, `b` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. -// -// value: The type of output produced by activation function -// following this operation. -// If not specified, defaults to DT_QUINT8 -func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["Tactivation"] = value - } -} - -// Perform a quantized matrix multiplication of `a` by the matrix `b`. -// -// The inputs must be two-dimensional matrices and the inner dimension of -// `a` (after being transposed if `transpose_a` is non-zero) must match the -// outer dimension of `b` (after being transposed if `transposed_b` is -// non-zero). -// -// Arguments: -// a: Must be a two-dimensional tensor. -// b: Must be a two-dimensional tensor. -// min_a: The float value that the lowest quantized `a` value represents. -// max_a: The float value that the highest quantized `a` value represents. -// min_b: The float value that the lowest quantized `b` value represents. -// max_b: The float value that the highest quantized `b` value represents. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedMatMul", - Input: []tf.Input{ - a, b, min_a, max_a, min_b, max_b, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// QuantizedMulAttr is an optional argument to QuantizedMul. 
-type QuantizedMulAttr func(optionalAttr) - -// QuantizedMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } -} - -// Returns x * y element-wise, working on quantized buffers. -// -// Arguments: -// -// -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -// -// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedMul", - Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// QuantizedAddAttr is an optional argument to QuantizedAdd. -type QuantizedAddAttr func(optionalAttr) - -// QuantizedAddToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } -} - -// Returns x + y element-wise, working on quantized buffers. -// -// Arguments: -// -// -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. -// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -// -// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedAdd", - Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Scatters tensor at indices in an input list. -// -// Each member of the TensorList corresponds to one row of the input tensor, -// specified by the given index (see `tf.gather`). -// -// input_handle: The list to scatter into. -// tensor: The input tensor. -// indices: The indices used to index into the list. -// output_handle: The TensorList. 
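A rough end-to-end sketch for the quantized arithmetic wrappers above (QuantizedAdd is shown; QuantizedMul is wired up the same way): the float inputs are quantized, added in the quantized domain, and dequantized before being fetched, so the Go program only ever materializes float32 tensors. QuantizeV2 and the tf.Quint8 constant are assumed from the standard op set and Go bindings; they do not appear in this hunk.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Two float inputs and the range they should be quantized into.
	a := op.Const(s, []float32{1.0, 2.0, 3.0})
	b := op.Const(s, []float32{0.5, 0.5, 0.5})
	lo := op.Const(s, float32(0.0))
	hi := op.Const(s, float32(4.0))

	// Quantize to quint8, add in the quantized domain, then dequantize so
	// the fetched tensor is an ordinary float32 tensor.
	qa, minA, maxA := op.QuantizeV2(s, a, lo, hi, tf.Quint8)
	qb, minB, maxB := op.QuantizeV2(s, b, lo, hi, tf.Quint8)
	z, minZ, maxZ := op.QuantizedAdd(s, qa, qb, minA, maxA, minB, maxB)
	sum := op.Dequantize(s, z, minZ, maxZ)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{sum}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // approximately [1.5 2.5 3.5]
}
```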
-func TensorListScatterIntoExistingList(scope *Scope, input_handle tf.Output, tensor tf.Output, indices tf.Output) (output_handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorListScatterIntoExistingList", - Input: []tf.Input{ - input_handle, tensor, indices, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes a range that covers the actual values present in a quantized tensor. -// -// Given a quantized tensor described by `(input, input_min, input_max)`, outputs a -// range that covers the actual values present in that tensor. This op is typically -// used to produce the `requested_output_min` and `requested_output_max` for -// `Requantize`. -// -// Arguments: -// -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// -// Returns The computed min output.the computed max output. -func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RequantizationRange", - Input: []tf.Input{ - input, input_min, input_max, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Rolls the elements of a tensor along an axis. -// -// The elements are shifted positively (towards larger indices) by the offset of -// `shift` along the dimension of `axis`. Negative `shift` values will shift -// elements in the opposite direction. Elements that roll passed the last position -// will wrap around to the first and vice versa. Multiple shifts along multiple -// axes may be specified. -// -// For example: -// -// ``` -// # 't' is [0, 1, 2, 3, 4] -// roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2] -// -// # shifting along multiple dimensions -// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] -// roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] -// -// # shifting along the same axis multiple times -// # 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] -// roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] -// ``` -// -// Arguments: -// -// shift: Dimension must be 0-D or 1-D. `shift[i]` specifies the number of places by which -// elements are shifted positively (towards larger indices) along the dimension -// specified by `axis[i]`. Negative shifts will roll the elements in the opposite -// direction. -// axis: Dimension must be 0-D or 1-D. `axis[i]` specifies the dimension that the shift -// `shift[i]` should occur. If the same axis is referenced more than once, the -// total shift for that axis will be the sum of all the shifts that belong to that -// axis. -// -// Returns Has the same shape and size as the input. The elements are shifted -// positively (towards larger indices) by the offsets of `shift` along the -// dimensions of `axis`. -func Roll(scope *Scope, input tf.Output, shift tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Roll", - Input: []tf.Input{ - input, shift, axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Updates the table to associates keys with values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. 
-// values: Values to associate with keys. -// -// Returns the created operation. -func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableInsertV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// Creates a `Dataset` that includes only 1/`num_shards` of this dataset. -// -// Arguments: -// -// num_shards: An integer representing the number of shards operating in parallel. -// index: An integer representing the current worker index. -// -// -func ShardDataset(scope *Scope, input_dataset tf.Output, num_shards tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ShardDataset", - Input: []tf.Input{ - input_dataset, num_shards, index, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that batches and pads `batch_size` elements from the input. -// -// Arguments: -// -// batch_size: A scalar representing the number of elements to accumulate in a -// batch. -// padded_shapes: A list of int64 tensors representing the desired padded shapes -// of the corresponding output components. These shapes may be partially -// specified, using `-1` to indicate that a particular dimension should be -// padded to the maximum size of all batch elements. -// padding_values: A list of scalars containing the padding value to use for -// each of the outputs. -// drop_remainder: A scalar representing whether the last batch should be dropped in case its size -// is smaller than desired. -// -func PaddedBatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, drop_remainder tf.Output, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "PaddedBatchDatasetV2", - Input: []tf.Input{ - input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), drop_remainder, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns element-wise smallest integer not less than x. -func Ceil(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Ceil", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the number of elements in the given table. -// -// Arguments: -// table_handle: Handle to the table. -// -// Returns Scalar that contains number of elements in the table. -func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableSizeV2", - Input: []tf.Input{ - table_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. -type ResizeBilinearGradAttr func(optionalAttr) - -// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. 
-// If not specified, defaults to false -func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// ResizeBilinearGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. -// If not specified, defaults to false -func ResizeBilinearGradHalfPixelCenters(value bool) ResizeBilinearGradAttr { - return func(m optionalAttr) { - m["half_pixel_centers"] = value - } -} - -// Computes the gradient of bilinear interpolation. -// -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. -// -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBilinearGrad", - Input: []tf.Input{ - grads, original_image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs all keys and values in the table. -// -// Arguments: -// table_handle: Handle to the table. -// -// -// -// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. -func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} - opspec := tf.OpSpec{ - Type: "LookupTableExportV2", - Input: []tf.Input{ - table_handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// MultiDeviceIteratorFromStringHandleAttr is an optional argument to MultiDeviceIteratorFromStringHandle. -type MultiDeviceIteratorFromStringHandleAttr func(optionalAttr) - -// MultiDeviceIteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. -// -// value: The type list for the return values. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func MultiDeviceIteratorFromStringHandleOutputTypes(value []tf.DataType) MultiDeviceIteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_types"] = value - } -} - -// MultiDeviceIteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. -// -// value: The list of shapes being produced. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func MultiDeviceIteratorFromStringHandleOutputShapes(value []tf.Shape) MultiDeviceIteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_shapes"] = value - } -} - -// Generates a MultiDeviceIterator resource from its provided string handle. -// -// Arguments: -// string_handle: String representing the resource. -// -// Returns A MultiDeviceIterator resource. 
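Taken together, the lookup-table wrappers in this file (MutableHashTableV2, LookupTableInsertV2, LookupTableSizeV2, LookupTableExportV2) cover the usual create/populate/query lifecycle. A hedged sketch, assuming the standard Go session API; the insert is run as a target in its own Run call so the later size query is guaranteed to observe it:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// A string -> int64 mutable hash table plus the ops that populate and query it.
	table := op.MutableHashTableV2(s, tf.String, tf.Int64)
	keys := op.Const(s, []string{"apple", "banana"})
	vals := op.Const(s, []int64{1, 2})
	insert := op.LookupTableInsertV2(s, table, keys, vals)
	size := op.LookupTableSizeV2(s, table)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	// Populate the table first, then query its size in a second step.
	if _, err := sess.Run(nil, nil, []*tf.Operation{insert}); err != nil {
		panic(err)
	}
	out, err := sess.Run(nil, []tf.Output{size}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println("table size:", out[0].Value()) // 2
}
```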
-func MultiDeviceIteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...MultiDeviceIteratorFromStringHandleAttr) (multi_device_iterator tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MultiDeviceIteratorFromStringHandle", - Input: []tf.Input{ - string_handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. -type MutableHashTableV2Attr func(optionalAttr) - -// MutableHashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableV2Container(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates an empty hash table. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MutableHashTableV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DequantizeAttr is an optional argument to Dequantize. -type DequantizeAttr func(optionalAttr) - -// DequantizeMode sets the optional mode attribute to value. -// If not specified, defaults to "MIN_COMBINED" -func DequantizeMode(value string) DequantizeAttr { - return func(m optionalAttr) { - m["mode"] = value - } -} - -// Dequantize the 'input' tensor into a float Tensor. -// -// [min_range, max_range] are scalar floats that specify the range for -// the 'input' data. The 'mode' attribute controls exactly which calculations are -// used to convert the float values to their quantized equivalents. 
-// -// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following: -// -// ``` -// if T == qint8: in[i] += (range(T) + 1)/ 2.0 -// out[i] = min_range + (in[i]* (max_range - min_range) / range(T)) -// ``` -// here `range(T) = numeric_limits::max() - numeric_limits::min()` -// -// *MIN_COMBINED Mode Example* -// -// If the input comes from a QuantizedRelu6, the output type is -// quint8 (range of 0-255) but the possible range of QuantizedRelu6 is -// 0-6. The min_range and max_range values are therefore 0.0 and 6.0. -// Dequantize on quint8 will take each value, cast to float, and multiply -// by 6 / 255. -// Note that if quantizedtype is qint8, the operation will additionally add -// each value by 128 prior to casting. -// -// If the mode is 'MIN_FIRST', then this approach is used: -// -// ```c++ -// num_discrete_values = 1 << (# of bits in T) -// range_adjust = num_discrete_values / (num_discrete_values - 1) -// range = (range_max - range_min) * range_adjust -// range_scale = range / num_discrete_values -// const double offset_input = static_cast(input) - lowest_quantized; -// result = range_min + ((input - numeric_limits::min()) * range_scale) -// ``` -// -// *SCALED mode Example* -// -// `SCALED` mode matches the quantization approach used in -// `QuantizeAndDequantize{V2|V3}`. -// -// If the mode is `SCALED`, we do not use the full range of the output type, -// choosing to elide the lowest possible value for symmetry (e.g., output range is -// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to -// 0. -// -// We first find the range of values in our tensor. The -// range we use is always centered on 0, so we find m such that -// ```c++ -// m = max(abs(input_min), abs(input_max)) -// ``` -// -// Our input tensor range is then `[-m, m]`. -// -// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. -// If T is signed, this is -// ``` -// num_bits = sizeof(T) * 8 -// [min_fixed, max_fixed] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1] -// ``` -// -// Otherwise, if T is unsigned, the fixed-point range is -// ``` -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] -// ``` -// -// From this we compute our scaling factor, s: -// ```c++ -// s = (2 * m) / (max_fixed - min_fixed) -// ``` -// -// Now we can dequantize the elements of our tensor: -// ```c++ -// result = input * s -// ``` -// -// Arguments: -// -// min_range: The minimum scalar value possibly produced for the input. -// max_range: The maximum scalar value possibly produced for the input. -func Dequantize(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, optional ...DequantizeAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Dequantize", - Input: []tf.Input{ - input, min_range, max_range, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Flips all bits elementwise. -// -// The result will have exactly those bits set, that are not set in `x`. The -// computation is performed on the underlying representation of x. -func Invert(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Invert", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deserialize bucket boundaries and ready flag into current QuantileAccumulator. 
-// -// An op that deserializes bucket boundaries and are boundaries ready flag into current QuantileAccumulator. -// -// Arguments: -// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. -// bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a feature. -// -// Returns the created operation. -func BoostedTreesQuantileStreamResourceDeserialize(scope *Scope, quantile_stream_resource_handle tf.Output, bucket_boundaries []tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BoostedTreesQuantileStreamResourceDeserialize", - Input: []tf.Input{ - quantile_stream_resource_handle, tf.OutputList(bucket_boundaries), - }, - } - return scope.AddOperation(opspec) -} - -// Inverse 3D fast Fourier transform. -// -// Computes the inverse 3-dimensional discrete Fourier transform over the -// inner-most 3 dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their inverse 3D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.ifftn with 3 dimensions. -// @end_compatibility -func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IFFT3D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Shuts down a running distributed TPU system. -// -// The op returns an error if no system is running. -// -// Returns the created operation. -func ShutdownDistributedTPU(scope *Scope) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ShutdownDistributedTPU", - } - return scope.AddOperation(opspec) -} - -// Deprecated. Disallowed in GraphDef version >= 2. -// -// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead -func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustContrast", - Input: []tf.Input{ - images, contrast_factor, min_value, max_value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Table initializer that takes two tensors for keys and values respectively. // // Arguments: @@ -34090,6 +35561,122 @@ func Print(scope *Scope, input tf.Output, data []tf.Output, optional ...PrintAtt return op.Output(0) } +// PrintV2Attr is an optional argument to PrintV2. +type PrintV2Attr func(optionalAttr) + +// PrintV2OutputStream sets the optional output_stream attribute to value. +// +// value: A string specifying the output stream or logging level to print to. +// If not specified, defaults to "stderr" +func PrintV2OutputStream(value string) PrintV2Attr { + return func(m optionalAttr) { + m["output_stream"] = value + } +} + +// PrintV2End sets the optional end attribute to value. +// If not specified, defaults to "\n" +func PrintV2End(value string) PrintV2Attr { + return func(m optionalAttr) { + m["end"] = value + } +} + +// Prints a string scalar. +// +// Prints a string scalar to the desired output_stream. +// +// Arguments: +// input: The string scalar to print. +// +// Returns the created operation. 
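A short usage sketch for the PrintV2 wrapper added in this hunk: the op has no outputs, so it is passed to Session.Run as a target rather than fetched. A constant string keeps the sketch self-contained; in practice the scalar is usually produced by a string-formatting op. The session plumbing is the standard Go binding and is assumed, not taken from this diff.

```go
package main

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	msg := op.Const(s, "hello from PrintV2")
	// Route the message to stdout instead of the default stderr stream.
	printOp := op.PrintV2(s, msg, op.PrintV2OutputStream("stdout"))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	// PrintV2 has no outputs; run it as a target to trigger the side effect.
	if _, err := sess.Run(nil, nil, []*tf.Operation{printOp}); err != nil {
		panic(err)
	}
}
```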
+func PrintV2(scope *Scope, input tf.Output, optional ...PrintV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "PrintV2", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Reverses specific dimensions of a tensor. +// +// NOTE `tf.reverse` has now changed behavior in preparation for 1.0. +// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0. +// +// Given a `tensor`, and a `int32` tensor `axis` representing the set of +// dimensions of `tensor` to reverse. This operation reverses each dimension +// `i` for which there exists `j` s.t. `axis[j] == i`. +// +// `tensor` can have up to 8 dimensions. The number of dimensions specified +// in `axis` may be 0 or more entries. If an index is specified more than +// once, a InvalidArgument error is raised. +// +// For example: +// +// ``` +// # tensor 't' is [[[[ 0, 1, 2, 3], +// # [ 4, 5, 6, 7], +// # [ 8, 9, 10, 11]], +// # [[12, 13, 14, 15], +// # [16, 17, 18, 19], +// # [20, 21, 22, 23]]]] +// # tensor 't' shape is [1, 2, 3, 4] +// +// # 'dims' is [3] or 'dims' is [-1] +// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], +// [ 7, 6, 5, 4], +// [ 11, 10, 9, 8]], +// [[15, 14, 13, 12], +// [19, 18, 17, 16], +// [23, 22, 21, 20]]]] +// +// # 'dims' is '[1]' (or 'dims' is '[-3]') +// reverse(t, dims) ==> [[[[12, 13, 14, 15], +// [16, 17, 18, 19], +// [20, 21, 22, 23] +// [[ 0, 1, 2, 3], +// [ 4, 5, 6, 7], +// [ 8, 9, 10, 11]]]] +// +// # 'dims' is '[2]' (or 'dims' is '[-2]') +// reverse(t, dims) ==> [[[[8, 9, 10, 11], +// [4, 5, 6, 7], +// [0, 1, 2, 3]] +// [[20, 21, 22, 23], +// [16, 17, 18, 19], +// [12, 13, 14, 15]]]] +// ``` +// +// Arguments: +// tensor: Up to 8-D. +// axis: 1-D. The indices of the dimensions to reverse. Must be in the range +// `[-rank(tensor), rank(tensor))`. +// +// Returns The same shape as `tensor`. +func ReverseV2(scope *Scope, tensor tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReverseV2", + Input: []tf.Input{ + tensor, axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Outputs a `Summary` protocol buffer with a tensor and per-plugin data. // // Arguments: @@ -34111,25 +35698,21 @@ func TensorSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, serialized_s return op.Output(0) } -// Creates a dataset that asynchronously prefetches elements from `input_dataset`. +// Serializes the tree handle to a proto // // Arguments: +// tree_handle: Handle to the tree resource to be serialized. // -// buffer_size: The maximum number of elements to buffer in an iterator over -// this dataset. -// -// -func PrefetchDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns Serialied proto string of the tree resource. 
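A usage sketch for the ReverseV2 wrapper added in this hunk, mirroring the doc comment's example on a smaller tensor (standard Go binding session plumbing assumed):

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	t := op.Const(s, [][]int32{{0, 1, 2, 3}, {4, 5, 6, 7}})
	// Reverse along dimension 1 (the columns).
	axis := op.Const(s, []int32{1})
	rev := op.ReverseV2(s, t, axis)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{rev}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [[3 2 1 0] [7 6 5 4]]
}
```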
+func TensorForestTreeSerialize(scope *Scope, tree_handle tf.Output) (tree_config tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "PrefetchDataset", + Type: "TensorForestTreeSerialize", Input: []tf.Input{ - input_dataset, buffer_size, + tree_handle, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -34195,24 +35778,47 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr return op.Output(0) } -// Read an element from the TensorArray into output `value`. +// BiasAddGradAttr is an optional argument to BiasAddGrad. +type BiasAddGradAttr func(optionalAttr) + +// BiasAddGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the bias tensor will be added to the last dimension +// of the value tensor. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// The tensor will be added to "in_channels", the third-to-the-last +// dimension. +// If not specified, defaults to "NHWC" +func BiasAddGradDataFormat(value string) BiasAddGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// The backward operation for "BiasAdd" on the "bias" tensor. +// +// It accumulates all the values from out_backprop into the feature dimension. +// For NHWC data format, the feature dimension is the last. For NCHW data format, +// the feature dimension is the third-to-last. // // Arguments: -// handle: The handle to a TensorArray. +// out_backprop: Any number of dimensions. // -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. -// -// Returns The tensor that is read from the TensorArray. -func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { +// Returns 1-D with size the feature dimension of `out_backprop`. +func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayReadV3", + Type: "BiasAddGrad", Input: []tf.Input{ - handle, index, flow_in, + out_backprop, }, Attrs: attrs, } @@ -34220,89 +35826,6 @@ func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in return op.Output(0) } -// Reduces sparse updates into the variable referenced by `resource` using the `max` operation. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] = max(ref[indices, ...], updates[...]) -// -// # Vector indices (for each i) -// ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) -// -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions are combined. -// -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. -// -//
-// (illustration omitted)
-// -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. -// -// Returns the created operation. -func ResourceScatterMax(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ResourceScatterMax", - Input: []tf.Input{ - resource, indices, updates, - }, - } - return scope.AddOperation(opspec) -} - -// Computes the gradient for the tanh of `x` wrt its input. -// -// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy` -// is the corresponding input gradient. -func TanhGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TanhGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs a `Summary` protocol buffer with scalar values. -// -// The input `tags` and `values` must have the same shape. The generated summary -// has a summary value for each tag-value pair in `tags` and `values`. -// -// Arguments: -// tags: Tags for the summary. -// values: Same shape as `tags. Values for the summary. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ScalarSummary", - Input: []tf.Input{ - tags, values, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ImageSummaryAttr is an optional argument to ImageSummary. type ImageSummaryAttr func(optionalAttr) @@ -34389,6 +35912,43 @@ func ImageSummary(scope *Scope, tag tf.Output, tensor tf.Output, optional ...Ima return op.Output(0) } +// BoostedTreesQuantileStreamResourceHandleOpAttr is an optional argument to BoostedTreesQuantileStreamResourceHandleOp. +type BoostedTreesQuantileStreamResourceHandleOpAttr func(optionalAttr) + +// BoostedTreesQuantileStreamResourceHandleOpContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func BoostedTreesQuantileStreamResourceHandleOpContainer(value string) BoostedTreesQuantileStreamResourceHandleOpAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// BoostedTreesQuantileStreamResourceHandleOpSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func BoostedTreesQuantileStreamResourceHandleOpSharedName(value string) BoostedTreesQuantileStreamResourceHandleOpAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a handle to a BoostedTreesQuantileStreamResource. +func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...BoostedTreesQuantileStreamResourceHandleOpAttr) (resource tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "BoostedTreesQuantileStreamResourceHandleOp", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AudioSummaryV2Attr is an optional argument to AudioSummaryV2. type AudioSummaryV2Attr func(optionalAttr) @@ -34443,70 +36003,24 @@ func AudioSummaryV2(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate t return op.Output(0) } -// Splits a tensor into a list. 
-// -// list[i] corresponds to lengths[i] tensors from the input tensor. -// The tensor must have rank at least 1 and contain exactly sum(lengths) elements. -// -// tensor: The input tensor. -// element_shape: A shape compatible with that of elements in the tensor. -// lengths: Vector of sizes of the 0th dimension of tensors in the list. -// output_handle: The list. -func TensorListSplit(scope *Scope, tensor tf.Output, element_shape tf.Output, lengths tf.Output) (output_handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorListSplit", - Input: []tf.Input{ - tensor, element_shape, lengths, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AvgPoolAttr is an optional argument to AvgPool. -type AvgPoolAttr func(optionalAttr) - -// AvgPoolDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func AvgPoolDataFormat(value string) AvgPoolAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs average pooling on the input. -// -// Each entry in `output` is the mean of the corresponding size `ksize` -// window in `value`. +// Creates a Dataset that returns pseudorandom numbers. // // Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// ksize: The size of the sliding window for each dimension of `value`. -// strides: The stride of the sliding window for each dimension of `value`. -// padding: The type of padding algorithm to use. +// seed: A scalar seed for the random number generator. If either seed or +// seed2 is set to be non-zero, the random number generator is seeded +// by the given seed. Otherwise, a random seed is used. +// seed2: A second scalar seed to avoid seed collision. // -// Returns The average pooled output tensor. -func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolAttr) (output tf.Output) { +// +func ExperimentalRandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "AvgPool", + Type: "ExperimentalRandomDataset", Input: []tf.Input{ - value, + seed, seed2, }, Attrs: attrs, } @@ -34514,71 +36028,23 @@ func AvgPool(scope *Scope, value tf.Output, ksize []int64, strides []int64, padd return op.Output(0) } -// Merges summaries. +// Creates and returns an empty tensor list. // -// This op creates a -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// protocol buffer that contains the union of all the values in the input -// summaries. +// All list elements must be tensors of dtype element_dtype and shape compatible +// with element_shape. // -// When the Op is run, it reports an `InvalidArgument` error if multiple values -// in the summaries to merge use the same tag. -// -// Arguments: -// inputs: Can be of any shape. 
Each must contain serialized `Summary` protocol -// buffers. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MergeSummary", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// The shape of the elements of the given list, as a tensor. -// -// input_handle: the list -// element_shape: the shape of elements of the list -func TensorListElementShape(scope *Scope, input_handle tf.Output, shape_type tf.DataType) (element_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shape_type": shape_type} - opspec := tf.OpSpec{ - Type: "TensorListElementShape", - Input: []tf.Input{ - input_handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns the item in the list with the given index. -// -// input_handle: the list -// index: the position in the list from which an element will be retrieved -// item: the element at that position -// -// -func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, element_shape tf.Output, element_dtype tf.DataType) (item tf.Output) { +// handle: an empty tensor list. +// element_dtype: the type of elements in the list. +// element_shape: a shape compatible with that of elements in the list. +func EmptyTensorList(scope *Scope, element_shape tf.Output, max_num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "TensorListGetItem", + Type: "EmptyTensorList", Input: []tf.Input{ - input_handle, index, element_shape, + element_shape, max_num_elements, }, Attrs: attrs, } @@ -34586,6 +36052,352 @@ func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, el return op.Output(0) } +// Returns element-wise remainder of division. This emulates C semantics in that +// +// the result here is consistent with a truncating divide. E.g. +// `tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`. +// +// *NOTE*: `Mod` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Mod", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns a list list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. +// +// tensor: The tensor to put on the list. +// input_handle: The old list. +// output_handle: A list with the elements of the old list followed by tensor. +// element_dtype: the type of elements in the list. +// element_shape: a shape compatible with that of elements in the list. +func TensorListPushBack(scope *Scope, input_handle tf.Output, tensor tf.Output) (output_handle tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorListPushBack", + Input: []tf.Input{ + input_handle, tensor, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MfccAttr is an optional argument to Mfcc. +type MfccAttr func(optionalAttr) + +// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. 
+// +// value: The highest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 4000 +func MfccUpperFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["upper_frequency_limit"] = value + } +} + +// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. +// +// value: The lowest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 20 +func MfccLowerFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["lower_frequency_limit"] = value + } +} + +// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. +// +// value: Resolution of the Mel bank used internally. +// If not specified, defaults to 40 +func MfccFilterbankChannelCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["filterbank_channel_count"] = value + } +} + +// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. +// +// value: How many output channels to produce per time slice. +// If not specified, defaults to 13 +func MfccDctCoefficientCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["dct_coefficient_count"] = value + } +} + +// Transforms a spectrogram into a form that's useful for speech recognition. +// +// Mel Frequency Cepstral Coefficients are a way of representing audio data that's +// been effective as an input feature for machine learning. They are created by +// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the +// higher frequencies that are less significant to the human ear. They have a long +// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum +// is a good resource to learn more. +// +// Arguments: +// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared +// set to true. +// sample_rate: How many samples per second the source audio used. +func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Mfcc", + Input: []tf.Input{ + spectrogram, sample_rate, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the number of tensors in the input tensor list. +// +// input_handle: the input list +// length: the number of tensors in the list +func TensorListLength(scope *Scope, input_handle tf.Output) (length tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorListLength", + Input: []tf.Input{ + input_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the last element of the input list as well as a list with all but that element. +// +// Fails if the list is empty. 
+// +// input_handle: the input list +// tensor: the withdrawn last element of the list +// element_dtype: the type of elements in the list +// element_shape: the shape of the output tensor +func TensorListPopBack(scope *Scope, input_handle tf.Output, element_shape tf.Output, element_dtype tf.DataType) (output_handle tf.Output, tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + opspec := tf.OpSpec{ + Type: "TensorListPopBack", + Input: []tf.Input{ + input_handle, element_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap. +type AddManySparseToTensorsMapAttr func(optionalAttr) + +// AddManySparseToTensorsMapContainer sets the optional container attribute to value. +// +// value: The container name for the `SparseTensorsMap` created by this op. +// If not specified, defaults to "" +func AddManySparseToTensorsMapContainer(value string) AddManySparseToTensorsMapAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// AddManySparseToTensorsMapSharedName sets the optional shared_name attribute to value. +// +// value: The shared name for the `SparseTensorsMap` created by this op. +// If blank, the new Operation's unique name is used. +// If not specified, defaults to "" +func AddManySparseToTensorsMapSharedName(value string) AddManySparseToTensorsMapAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Add an `N`-minibatch `SparseTensor` to a `SparseTensorsMap`, return `N` handles. +// +// A `SparseTensor` of rank `R` is represented by three tensors: `sparse_indices`, +// `sparse_values`, and `sparse_shape`, where +// +// ```sparse_indices.shape[1] == sparse_shape.shape[0] == R``` +// +// An `N`-minibatch of `SparseTensor` objects is represented as a `SparseTensor` +// having a first `sparse_indices` column taking values between `[0, N)`, where +// the minibatch size `N == sparse_shape[0]`. +// +// The input `SparseTensor` must have rank `R` greater than 1, and the first +// dimension is treated as the minibatch dimension. Elements of the `SparseTensor` +// must be sorted in increasing order of this first dimension. The stored +// `SparseTensor` objects pointed to by each row of the output `sparse_handles` +// will have rank `R-1`. +// +// The `SparseTensor` values can then be read out as part of a minibatch by passing +// the given keys as vector elements to `TakeManySparseFromTensorsMap`. To ensure +// the correct `SparseTensorsMap` is accessed, ensure that the same +// `container` and `shared_name` are passed to that Op. If no `shared_name` +// is provided here, instead use the *name* of the Operation created by calling +// `AddManySparseToTensorsMap` as the `shared_name` passed to +// `TakeManySparseFromTensorsMap`. Ensure the Operations are colocated. +// +// Arguments: +// sparse_indices: 2-D. The `indices` of the minibatch `SparseTensor`. +// `sparse_indices[:, 0]` must be ordered values in `[0, N)`. +// sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. +// The minibatch size `N == sparse_shape[0]`. +// +// Returns 1-D. The handles of the `SparseTensor` now stored in the +// `SparseTensorsMap`. Shape: `[N]`. 
+func AddManySparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...AddManySparseToTensorsMapAttr) (sparse_handles tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AddManySparseToTensorsMap", + Input: []tf.Input{ + sparse_indices, sparse_values, sparse_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compute the lower regularized incomplete Gamma function `P(a, x)`. +// +// The lower regularized incomplete Gamma function is defined as: +// +// +// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\) +// +// where +// +// \\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) +// +// is the lower incomplete Gamma function. +// +// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete +// Gamma function. +func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Igamma", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TensorListStackAttr is an optional argument to TensorListStack. +type TensorListStackAttr func(optionalAttr) + +// TensorListStackNumElements sets the optional num_elements attribute to value. +// If not specified, defaults to -1 +func TensorListStackNumElements(value int64) TensorListStackAttr { + return func(m optionalAttr) { + m["num_elements"] = value + } +} + +// Stacks all tensors in the list. +// +// Requires that all tensors have the same shape. +// +// input_handle: the input list +// tensor: the gathered result +// num_elements: optional. If not -1, the number of elements in the list. +// +func TensorListStack(scope *Scope, input_handle tf.Output, element_shape tf.Output, element_dtype tf.DataType, optional ...TensorListStackAttr) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorListStack", + Input: []tf.Input{ + input_handle, element_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AssertAttr is an optional argument to Assert. +type AssertAttr func(optionalAttr) + +// AssertSummarize sets the optional summarize attribute to value. +// +// value: Print this many entries of each tensor. +// If not specified, defaults to 3 +func AssertSummarize(value int64) AssertAttr { + return func(m optionalAttr) { + m["summarize"] = value + } +} + +// Asserts that the given condition is true. +// +// If `condition` evaluates to false, print the list of tensors in `data`. +// `summarize` determines how many entries of the tensors to print. +// +// Arguments: +// condition: The condition to evaluate. +// data: The tensors to print out when condition is false. +// +// Returns the created operation. +func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Assert", + Input: []tf.Input{ + condition, tf.OutputList(data), + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // Resizes the list. 
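// Illustrative usage only, not part of the generated bindings or of this patch:
// a hedged sketch of the Assert wrapper defined above. It shows the
// functional-options pattern used for optional attributes (AssertSummarize)
// and how an op that "returns the created operation" is passed to Session.Run
// as a target rather than fetched. The constant inputs are made up.
package main

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Const(s, []float32{1, 2, 3})
	cond := op.Const(s, true) // in real code this would be a computed predicate
	// Print at most 10 entries of each tensor in `data` if the assertion fails.
	assert := op.Assert(s, cond, []tf.Output{x}, op.AssertSummarize(10))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	// Assert produces no outputs; run it as a target alongside normal fetches.
	if _, err := sess.Run(nil, nil, []*tf.Operation{assert}); err != nil {
		panic(err)
	}
}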
// // @@ -34606,94 +36418,23 @@ func TensorListResize(scope *Scope, input_handle tf.Output, size tf.Output) (out return op.Output(0) } -// Returns a diagonal tensor with a given diagonal values. +// Creates a Tensor by indexing into the TensorList. // -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: +// Each row in the produced Tensor corresponds to the element in the TensorList +// specified by the given index (see `tf.gather`). // -// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of -// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: -// -// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. -// -// For example: -// -// ``` -// # 'diagonal' is [1, 2, 3, 4] -// tf.diag(diagonal) ==> [[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]] -// ``` -// -// Arguments: -// diagonal: Rank k tensor where k is at most 1. -func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) { +// input_handle: The input tensor list. +// indices: The indices used to index into the list. +// values: The tensor. +func TensorListGather(scope *Scope, input_handle tf.Output, indices tf.Output, element_shape tf.Output, element_dtype tf.DataType) (values tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "Diag", + Type: "TensorListGather", Input: []tf.Input{ - diagonal, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal. -type ParameterizedTruncatedNormalAttr func(optionalAttr) - -// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random values from a normal distribution. The parameters may each be a -// -// scalar which applies to the entire output, or a vector of length shape[0] which -// stores the parameters for each batch. -// -// Arguments: -// shape: The shape of the output tensor. Batches are indexed by the 0th dimension. -// means: The mean parameter of each batch. -// stdevs: The standard deviation parameter of each batch. Must be greater than 0. -// minvals: The minimum cutoff. May be -infinity. -// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval -// for each batch. -// -// Returns A matrix of shape num_batches x samples_per_batch, filled with random -// truncated normal values using the parameters for each row. 
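// Illustrative usage only, not part of the generated bindings or of this patch:
// a hedged sketch chaining the TensorList wrappers in this file
// (EmptyTensorList, TensorListPushBack, TensorListLength, TensorListPopBack).
// It assumes length-2 float32 elements and, per the TensorList convention,
// that max_num_elements of -1 means an unbounded list; the values are made up.
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	elemShape := op.Const(s, []int32{2}) // every element is a length-2 vector
	maxElems := op.Const(s, int32(-1))   // -1: no fixed capacity
	list := op.EmptyTensorList(s, elemShape, maxElems, tf.Float)
	list = op.TensorListPushBack(s, list, op.Const(s, []float32{1, 2}))
	list = op.TensorListPushBack(s, list, op.Const(s, []float32{3, 4}))
	length := op.TensorListLength(s, list)                        // 2
	_, last := op.TensorListPopBack(s, list, elemShape, tf.Float) // [3, 4]

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{length, last}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value(), out[1].Value()) // 2 [3 4]
}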
-func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ParameterizedTruncatedNormal", - Input: []tf.Input{ - shape, means, stdevs, minvals, maxvals, + input_handle, indices, element_shape, }, Attrs: attrs, } @@ -34701,21 +36442,35 @@ func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output return op.Output(0) } -// Sets the index-th position of the list to contain the given tensor. +// Stops gradient computation. // -// input_handle: the list -// index: the position in the list to which the tensor will be assigned -// item: the element to be assigned to that position -// output_handle: the new list, with the element in the proper position +// When executed in a graph, this op outputs its input tensor as-is. // -func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, item tf.Output) (output_handle tf.Output) { +// When building ops to compute gradients, this op prevents the contribution of +// its inputs to be taken into account. Normally, the gradient generator adds ops +// to a graph to compute the derivatives of a specified 'loss' by recursively +// finding out inputs that contributed to its computation. If you insert this op +// in the graph it inputs are masked from the gradient generator. They are not +// taken into account for computing gradients. +// +// This is useful any time you want to compute a value with TensorFlow but need +// to pretend that the value was a constant. Some examples include: +// +// * The *EM* algorithm where the *M-step* should not involve backpropagation +// through the output of the *E-step*. +// * Contrastive divergence training of Boltzmann machines where, when +// differentiating the energy function, the training must not backpropagate +// through the graph that generated the samples from the model. +// * Adversarial training, where no backprop should happen through the adversarial +// example generation process. +func StopGradient(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorListSetItem", + Type: "StopGradient", Input: []tf.Input{ - input_handle, index, item, + input, }, } op := scope.AddOperation(opspec) @@ -34746,187 +36501,98 @@ func TensorListScatter(scope *Scope, tensor tf.Output, indices tf.Output, elemen return op.Output(0) } -// Deprecated. Use TensorArrayScatterV3 +// Produces the average pool of the input tensor for quantized types. // -// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 -func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Arguments: +// input: 4-D with shape `[batch, height, width, channels]`. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// ksize: The size of the window for each dimension of the input tensor. +// The length must be 4 to match the number of dimensions of the input. +// strides: The stride of the sliding window for each dimension of the input +// tensor. The length must be 4 to match the number of dimensions of the input. +// padding: The type of padding algorithm to use. 
+// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "TensorArrayScatterV2", + Type: "QuantizedAvgPool", Input: []tf.Input{ - handle, indices, value, flow_in, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AsStringAttr is an optional argument to AsString. -type AsStringAttr func(optionalAttr) - -// AsStringPrecision sets the optional precision attribute to value. -// -// value: The post-decimal precision to use for floating point numbers. -// Only used if precision > -1. -// If not specified, defaults to -1 -func AsStringPrecision(value int64) AsStringAttr { - return func(m optionalAttr) { - m["precision"] = value - } -} - -// AsStringScientific sets the optional scientific attribute to value. -// -// value: Use scientific notation for floating point numbers. -// If not specified, defaults to false -func AsStringScientific(value bool) AsStringAttr { - return func(m optionalAttr) { - m["scientific"] = value - } -} - -// AsStringShortest sets the optional shortest attribute to value. -// -// value: Use shortest representation (either scientific or standard) for -// floating point numbers. -// If not specified, defaults to false -func AsStringShortest(value bool) AsStringAttr { - return func(m optionalAttr) { - m["shortest"] = value - } -} - -// AsStringWidth sets the optional width attribute to value. -// -// value: Pad pre-decimal numbers to this width. -// Applies to both floating point and integer numbers. -// Only used if width > -1. -// If not specified, defaults to -1 -func AsStringWidth(value int64) AsStringAttr { - return func(m optionalAttr) { - m["width"] = value - } -} - -// AsStringFill sets the optional fill attribute to value. -// -// value: The value to pad if width > -1. If empty, pads with spaces. -// Another typical value is '0'. String cannot be longer than 1 character. -// If not specified, defaults to "" -func AsStringFill(value string) AsStringAttr { - return func(m optionalAttr) { - m["fill"] = value - } -} - -// Converts each entry in the given tensor to strings. Supports many numeric -// -// types and boolean. -func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AsString", - Input: []tf.Input{ - input, + input, min_input, max_input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Returns a `RaggedTensor` containing the specified sequences of numbers. +// Computes the sign and the log of the absolute value of the determinant of // +// one or more square matrices. // -// Returns a `RaggedTensor` `result` composed from `rt_dense_values` and -// `rt_nested_splits`, such that -// `result[i] = range(starts[i], limits[i], deltas[i])`. -// -// ```python -// >>> (rt_nested_splits, rt_dense_values) = gen_ragged_ops.ragged_range( -// ... 
starts=[2, 5, 8], limits=[3, 5, 12], deltas=1) -// >>> result = ragged.from_nested_row_splits(rt_dense_values, rt_nested_splits) -// >>> print result.eval().tolist() -// [[2], # result[0] = range(2, 3) -// [], # result[1] = range(5, 5) -// [8, 9, 10, 11]] # result[2] = range(8, 12) -// ``` -// -// The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors. -// The vector inputs must all have the same size. Scalar inputs are broadcast -// to match the size of the vector inputs. +// The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions +// form square matrices. The outputs are two tensors containing the signs and +// absolute values of the log determinants for all N input submatrices +// `[..., :, :]` such that the determinant = sign*exp(log_abs_determinant). +// The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU +// is the LU decomposition of the input and P is the corresponding +// permutation matrix. // // Arguments: -// starts: The starts of each range. -// limits: The limits of each range. -// deltas: The deltas of each range. +// input: Shape is `[N, M, M]`. // -// Returns The `row_splits` for the returned `RaggedTensor`.The `flat_values` for the returned `RaggedTensor`. -func RaggedRange(scope *Scope, starts tf.Output, limits tf.Output, deltas tf.Output) (rt_nested_splits tf.Output, rt_dense_values tf.Output) { +// Returns The signs of the log determinants of the inputs. Shape is `[N]`.The logs of the absolute values of the determinants +// of the N input matrices. Shape is `[N]`. +func LogMatrixDeterminant(scope *Scope, input tf.Output) (sign tf.Output, log_abs_determinant tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RaggedRange", + Type: "LogMatrixDeterminant", Input: []tf.Input{ - starts, limits, deltas, + input, }, } op := scope.AddOperation(opspec) return op.Output(0), op.Output(1) } -// Deprecated, use python implementation tf.linalg.matrix_exponential. +// Computes the matrix logarithm of one or more square matrices: // -// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead. -func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixExponential", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the Cholesky decomposition of one or more square matrices. +// +// \\(log(exp(A)) = A\\) +// +// This op is only defined for complex matrices. If A is positive-definite and +// real, then casting to a complex matrix, taking the logarithm and casting back +// to a real matrix will give the correct result. +// +// This function computes the matrix logarithm using the Schur-Parlett algorithm. +// Details of the algorithm can be found in Section 11.6.2 of: +// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008. +// ISBN 978-0-898716-46-7. // // The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. -// -// The input has to be symmetric and positive definite. Only the lower-triangular -// part of the input will be used for this operation. The upper-triangular part -// will not be read. -// -// The output is a tensor of the same shape as the input -// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. 
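// Illustrative usage only, not part of the generated bindings or of this patch:
// a hedged sketch of the LogMatrixDeterminant wrapper above, which returns two
// outputs per call (the signs and the log absolute determinants of a batch of
// square matrices). The sample matrix is invented for the example.
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// A batch of one 2x2 matrix with determinant -2.
	m := op.Const(s, [][][]float32{{{1, 2}, {3, 4}}})
	sign, logAbsDet := op.LogMatrixDeterminant(s, m)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{sign, logAbsDet}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value(), out[1].Value()) // [-1] and roughly [0.693] (log 2)
}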
-// -// **Note**: The gradient computation on GPU is faster for large matrices but -// not for large batch dimensions when the submatrices are small. In this -// case it might be faster to use the CPU. +// form square matrices. The output is a tensor of the same shape as the input +// containing the exponential for all input submatrices `[..., :, :]`. // // Arguments: // input: Shape is `[..., M, M]`. // // Returns Shape is `[..., M, M]`. -func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { +// +// @compatibility(scipy) +// Equivalent to scipy.linalg.logm +// @end_compatibility +func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Cholesky", + Type: "MatrixLogarithm", Input: []tf.Input{ input, }, @@ -34935,66 +36601,31 @@ func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// Writes contents to the file at input filename. Creates file and recursively -// -// creates directory if not existing. -// -// Arguments: -// filename: scalar. The name of the file to which we write the contents. -// contents: scalar. The content to be written to the output file. -// -// Returns the created operation. -func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { +// Computes acos of x element-wise. +func Acos(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "WriteFile", + Type: "Acos", Input: []tf.Input{ - filename, contents, + x, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// AllAttr is an optional argument to All. -type AllAttr func(optionalAttr) - -// AllKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AllKeepDims(value bool) AllAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the "logical and" of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. -// -// Returns The reduced tensor. -func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (output tf.Output) { +// Creates a dataset that zips together `input_datasets`. +func ZipDataset(scope *Scope, input_datasets []tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "All", + Type: "ZipDataset", Input: []tf.Input{ - input, axis, + tf.OutputList(input_datasets), }, Attrs: attrs, } @@ -35002,28 +36633,13 @@ func All(scope *Scope, input tf.Output, axis tf.Output, optional ...AllAttr) (ou return op.Output(0) } -// Computes the Eigen Decomposition of a batch of square self-adjoint matrices. -// -// DEPRECATED at GraphDef version 11: Use SelfAdjointEigV2 instead. 
-// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices, with the same constraints as the single matrix -// SelfAdjointEig. -// -// The result is a [..., M+1, M] matrix with [..., 0,:] containing the -// eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. The eigenvalues -// are sorted in non-decreasing order. -// -// Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[..., M+1, M]`. -func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { +// Return a tensor with the same shape and contents as the input tensor or value. +func Identity(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SelfAdjointEig", + Type: "Identity", Input: []tf.Input{ input, }, @@ -35032,50 +36648,28 @@ func SelfAdjointEig(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// Computes softplus gradients for a softplus operation. +// Computes the reverse mode backpropagated gradient of the Cholesky algorithm. +// +// For an explanation see "Differentiation of the Cholesky algorithm" by +// Iain Murray http://arxiv.org/abs/1602.07527. // // Arguments: -// gradients: The backpropagated gradients to the corresponding softplus operation. -// features: The features passed as input to the corresponding softplus operation. +// l: Output of batch Cholesky algorithm l = cholesky(A). Shape is `[..., M, M]`. +// Algorithm depends only on lower triangular part of the innermost matrices of +// this tensor. +// grad: df/dl where f is some scalar function. Shape is `[..., M, M]`. +// Algorithm depends only on lower triangular part of the innermost matrices of +// this tensor. // -// Returns The gradients: `gradients / (1 + exp(-features))`. -func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { +// Returns Symmetrized version of df/dA . Shape is `[..., M, M]` +func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SoftplusGrad", + Type: "CholeskyGrad", Input: []tf.Input{ - gradients, features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Solves tridiagonal systems of equations. -// -// `diagonals` is a tensor of shape `[..., 3, M]` whose inner-most 2 dimensions -// represent matrices with three rows being the superdiagonal, diagonals, and -// subdiagonals, in order. The last element of the superdiagonal and the first -// element of the subdiagonal is ignored. -// `rhs` is a tensor of shape `[..., M, K]`, representing K right-hand sides per -// each left-hand side. -// The output is a tensor of shape `[..., M, K]` containing the solutions. -// -// Arguments: -// diagonals: Shape is `[..., 3, M]`. -// rhs: Shape is `[..., M, K]`. -// -// Returns Shape is `[..., M, K]`. -func TridiagonalSolve(scope *Scope, diagonals tf.Output, rhs tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TridiagonalSolve", - Input: []tf.Input{ - diagonals, rhs, + l, grad, }, } op := scope.AddOperation(opspec) @@ -35133,319 +36727,6 @@ func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV return op.Output(0), op.Output(1) } -// Adjust the saturation of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. 
-// -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A scale is then applied all the saturation -// values, and then remapped back to RGB colorspace. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// scale: A float scale to add to the saturation. -// -// Returns The hue-adjusted image or images. -func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustSaturation", - Input: []tf.Input{ - images, scale, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MatrixSolveAttr is an optional argument to MatrixSolve. -type MatrixSolveAttr func(optionalAttr) - -// MatrixSolveAdjoint sets the optional adjoint attribute to value. -// -// value: Boolean indicating whether to solve with `matrix` or its (block-wise) -// adjoint. -// If not specified, defaults to false -func MatrixSolveAdjoint(value bool) MatrixSolveAttr { - return func(m optionalAttr) { - m["adjoint"] = value - } -} - -// Solves systems of linear equations. -// -// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is -// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix -// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. -// If `adjoint` is `True` then each output matrix satisfies -// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. -// -// Arguments: -// matrix: Shape is `[..., M, M]`. -// rhs: Shape is `[..., M, K]`. -// -// Returns Shape is `[..., M, K]`. -func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixSolve", - Input: []tf.Input{ - matrix, rhs, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyKerasMomentumAttr is an optional argument to ResourceApplyKerasMomentum. -type ResourceApplyKerasMomentumAttr func(optionalAttr) - -// ResourceApplyKerasMomentumUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyKerasMomentumUseLocking(value bool) ResourceApplyKerasMomentumAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyKerasMomentumUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, the tensor passed to compute grad will be -// var + momentum * accum, so in the end, the var you get is actually -// var + momentum * accum. -// If not specified, defaults to false -func ResourceApplyKerasMomentumUseNesterov(value bool) ResourceApplyKerasMomentumAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update '*var' according to the momentum scheme. Set use_nesterov = True if you -// -// want to use Nesterov momentum. -// -// accum = accum * momentum - lr * grad -// var += accum -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// grad: The gradient. 
-// momentum: Momentum. Must be a scalar. -// -// Returns the created operation. -func ResourceApplyKerasMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyKerasMomentumAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyKerasMomentum", - Input: []tf.Input{ - var_, accum, lr, grad, momentum, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Returns a serialized GraphDef representing `input_dataset`. -// -// Returns a graph representation for `input_dataset`. -// -// Arguments: -// input_dataset: A variant tensor representing the dataset to return the graph representation for. -// -// Returns The graph representation of the dataset (as serialized GraphDef). -func DatasetToGraph(scope *Scope, input_dataset tf.Output) (graph tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DatasetToGraph", - Input: []tf.Input{ - input_dataset, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LuAttr is an optional argument to Lu. -type LuAttr func(optionalAttr) - -// LuOutputIdxType sets the optional output_idx_type attribute to value. -// If not specified, defaults to DT_INT32 -func LuOutputIdxType(value tf.DataType) LuAttr { - return func(m optionalAttr) { - m["output_idx_type"] = value - } -} - -// Computes the LU decomposition of one or more square matrices. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. -// -// The input has to be invertible. -// -// The output consists of two tensors LU and P containing the LU decomposition -// of all input submatrices `[..., :, :]`. LU encodes the lower triangular and -// upper triangular factors. -// -// For each input submatrix of shape `[M, M]`, L is a lower triangular matrix of -// shape `[M, M]` with unit diagonal whose entries correspond to the strictly lower -// triangular part of LU. U is a upper triangular matrix of shape `[M, M]` whose -// entries correspond to the upper triangular part, including the diagonal, of LU. -// -// P represents a permutation matrix encoded as a list of indices each between `0` -// and `M-1`, inclusive. If P_mat denotes the permutation matrix corresponding to -// P, then the L, U and P satisfies P_mat * input = L * U. -// -// Arguments: -// input: A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form matrices of -// size `[M, M]`. -// -// Returns A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the -// lower triangular factor `L` with unit diagonal, and whose upper triangular part -// denotes the upper triangular factor `U`.Permutation of the rows encoded as a list of indices in `0..M-1`. Shape is -// `[..., M]`. -// @compatibility(scipy) -// Similar to `scipy.linalg.lu`, except the triangular factors `L` and `U` are -// packed into a single tensor, the permutation is applied to `input` instead of -// the right hand side and the permutation `P` is returned as a list of indices -// instead of a permutation matrix. 
-// @end_compatibility -func Lu(scope *Scope, input tf.Output, optional ...LuAttr) (lu tf.Output, p tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Lu", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Deprecated. Use TensorArrayCloseV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 -// -// Returns the created operation. -func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArrayCloseV2", - Input: []tf.Input{ - handle, - }, - } - return scope.AddOperation(opspec) -} - -// EncodeBase64Attr is an optional argument to EncodeBase64. -type EncodeBase64Attr func(optionalAttr) - -// EncodeBase64Pad sets the optional pad attribute to value. -// -// value: Bool whether padding is applied at the ends. -// If not specified, defaults to false -func EncodeBase64Pad(value bool) EncodeBase64Attr { - return func(m optionalAttr) { - m["pad"] = value - } -} - -// Encode strings into web-safe base64 format. -// -// Refer to the following article for more information on base64 format: -// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the -// end so that the encoded has length multiple of 4. See Padding section of the -// link above. -// -// Web-safe means that the encoder uses - and _ instead of + and /. -// -// Arguments: -// input: Strings to be encoded. -// -// Returns Input strings encoded in base64. -func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodeBase64", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// A dataset that creates window datasets from the input dataset. -// -// Arguments: -// -// size: A scalar representing the number of elements to accumulate in a window. -// shift: A scalar representing the steps moving the sliding window forward in one -// iteration. It must be positive. -// stride: A scalar representing the stride of the input elements of the sliding window. -// It must be positive. -// drop_remainder: A scalar representing whether a window should be dropped in case its size is -// smaller than desired. -// -// -func WindowDataset(scope *Scope, input_dataset tf.Output, size tf.Output, shift tf.Output, stride tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "WindowDataset", - Input: []tf.Input{ - input_dataset, size, shift, stride, drop_remainder, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the matrix square root of one or more square matrices: // // matmul(sqrtm(A), sqrtm(A)) = A @@ -35555,204 +36836,170 @@ func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf. return op.Output(0), op.Output(1), op.Output(2) } -// Converts one or more images from RGB to HSV. +// Reorders a SparseTensor into the canonical, row-major ordering. 
// -// Outputs a tensor of the same shape as the `images` tensor, containing the HSV -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. +// Note that by convention, all sparse ops preserve the canonical ordering along +// increasing dimension number. The only time ordering can be violated is during +// manual manipulation of the indices and values vectors to add entries. // -// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and -// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 -// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. +// Reordering does not affect the shape of the SparseTensor. +// +// If the tensor has rank `R` and `N` non-empty values, `input_indices` has +// shape `[N, R]`, input_values has length `N`, and input_shape has length `R`. // // Arguments: -// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. // -// Returns `images` converted to HSV. -func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { +// Returns 2-D. `N x R` matrix with the same indices as input_indices, but +// in canonical row-major ordering.1-D. `N` non-empty values corresponding to `output_indices`. +func SparseReorder(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RGBToHSV", + Type: "SparseReorder", Input: []tf.Input{ - images, + input_indices, input_values, input_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Deprecated. Use TensorArraySplitV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 +func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArraySplitV2", + Input: []tf.Input{ + handle, value, lengths, flow_in, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Does nothing. Only useful as a placeholder for control edges. +// Gradients for batch normalization. // -// Returns the created operation. -func NoOp(scope *Scope) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NoOp", - } - return scope.AddOperation(opspec) -} - -// MergeV2CheckpointsAttr is an optional argument to MergeV2Checkpoints. -type MergeV2CheckpointsAttr func(optionalAttr) - -// MergeV2CheckpointsDeleteOldDirs sets the optional delete_old_dirs attribute to value. +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() // -// value: see above. -// If not specified, defaults to true -func MergeV2CheckpointsDeleteOldDirs(value bool) MergeV2CheckpointsAttr { - return func(m optionalAttr) { - m["delete_old_dirs"] = value - } -} - -// V2 format specific: merges the metadata files of sharded checkpoints. The -// -// result is one logical checkpoint, with one physical metadata file and renamed -// data files. -// -// Intended for "grouping" multiple checkpoints in a sharded checkpoint setup. 
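// Illustrative usage only, not part of the generated bindings or of this patch:
// a hedged sketch of the SparseReorder wrapper above, feeding a SparseTensor
// whose two entries are deliberately out of row-major order. The concrete
// indices, values, and dense shape are made up for the example.
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	indices := op.Const(s, [][]int64{{0, 1}, {0, 0}}) // not in canonical order
	values := op.Const(s, []float32{20, 10})
	shape := op.Const(s, []int64{2, 3})
	outIdx, outVals := op.SparseReorder(s, indices, values, shape)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{outIdx, outVals}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value(), out[1].Value()) // [[0 0] [0 1]] [10 20]
}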
-// -// If delete_old_dirs is true, attempts to delete recursively the dirname of each -// path in the input checkpoint_prefixes. This is useful when those paths are non -// user-facing temporary locations. +// This op is deprecated. See `tf.nn.batch_normalization`. // // Arguments: -// checkpoint_prefixes: prefixes of V2 checkpoints to merge. -// destination_prefix: scalar. The desired final prefix. Allowed to be the same -// as one of the checkpoint_prefixes. +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this Tensor will be multiplied +// with the normalized Tensor. +// backprop: 4D backprop Tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. // -// Returns the created operation. -func MergeV2Checkpoints(scope *Scope, checkpoint_prefixes tf.Output, destination_prefix tf.Output, optional ...MergeV2CheckpointsAttr) (o *tf.Operation) { +// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. +func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "MergeV2Checkpoints", + Type: "BatchNormWithGlobalNormalizationGrad", Input: []tf.Input{ - checkpoint_prefixes, destination_prefix, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Saves input tensors slices to disk. -// -// This is like `Save` except that tensors can be listed in the saved file as being -// a slice of a larger tensor. `shapes_and_slices` specifies the shape of the -// larger tensor and the slice that this tensor covers. `shapes_and_slices` must -// have as many elements as `tensor_names`. -// -// Elements of the `shapes_and_slices` input must either be: -// -// * The empty string, in which case the corresponding tensor is -// saved normally. -// * A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the -// `dimI` are the dimensions of the larger tensor and `slice-spec` -// specifies what part is covered by the tensor to save. -// -// `slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1` -// where each `sliceI` is either: -// -// * The string `-` meaning that the slice covers all indices of this dimension -// * `start,length` where `start` and `length` are integers. In that -// case the slice covers `length` indices starting at `start`. -// -// See also `Save`. -// -// Arguments: -// filename: Must have a single element. The name of the file to which we write the -// tensor. -// tensor_names: Shape `[N]`. The names of the tensors to be saved. 
-// shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when -// saving the tensors. -// data: `N` tensors to save. -// -// Returns the created operation. -func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes_and_slices tf.Output, data []tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SaveSlices", - Input: []tf.Input{ - filename, tensor_names, shapes_and_slices, tf.OutputList(data), - }, - } - return scope.AddOperation(opspec) -} - -// DenseToDenseSetOperationAttr is an optional argument to DenseToDenseSetOperation. -type DenseToDenseSetOperationAttr func(optionalAttr) - -// DenseToDenseSetOperationValidateIndices sets the optional validate_indices attribute to value. -// If not specified, defaults to true -func DenseToDenseSetOperationValidateIndices(value bool) DenseToDenseSetOperationAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Applies set operation along last dimension of 2 `Tensor` inputs. -// -// See SetOperationOp::SetOperationFromContext for values of `set_operation`. -// -// Output `result` is a `SparseTensor` represented by `result_indices`, -// `result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this -// has rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth` -// dimension contains the result of `set_operation` applied to the corresponding -// `[0...n-1]` dimension of `set`. -// -// Arguments: -// set1: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set2`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// set2: `Tensor` with rank `n`. 1st `n-1` dimensions must be the same as `set1`. -// Dimension `n` contains values in a set, duplicates are allowed but ignored. -// -// -// Returns 2D indices of a `SparseTensor`.1D values of a `SparseTensor`.1D `Tensor` shape of a `SparseTensor`. `result_shape[0...n-1]` is -// the same as the 1st `n-1` dimensions of `set1` and `set2`, `result_shape[n]` -// is the max result set size across all `0...n-1` dimensions. -func DenseToDenseSetOperation(scope *Scope, set1 tf.Output, set2 tf.Output, set_operation string, optional ...DenseToDenseSetOperationAttr) (result_indices tf.Output, result_values tf.Output, result_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"set_operation": set_operation} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DenseToDenseSetOperation", - Input: []tf.Input{ - set1, set2, + t, m, v, gamma, backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Generate a sharded filename. The filename is printf formatted as +// Saves the input tensors to disk. // -// %s-%05d-of-%05d, basename, shard, num_shards. -func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shards tf.Output) (filename tf.Output) { +// The size of `tensor_names` must match the number of tensors in `data`. `data[i]` +// is written to `filename` with name `tensor_names[i]`. +// +// See also `SaveSlices`. +// +// Arguments: +// filename: Must have a single element. The name of the file to which we write +// the tensor. +// tensor_names: Shape `[N]`. The names of the tensors to be saved. +// data: `N` tensors to save. +// +// Returns the created operation. 
+func Save(scope *Scope, filename tf.Output, tensor_names tf.Output, data []tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ShardedFilename", + Type: "Save", Input: []tf.Input{ - basename, shard, num_shards, + filename, tensor_names, tf.OutputList(data), }, } + return scope.AddOperation(opspec) +} + +// RestoreSliceAttr is an optional argument to RestoreSlice. +type RestoreSliceAttr func(optionalAttr) + +// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. +// +// value: Index of file to open first if multiple files match +// `file_pattern`. See the documentation for `Restore`. +// If not specified, defaults to -1 +func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { + return func(m optionalAttr) { + m["preferred_shard"] = value + } +} + +// Restores a tensor from checkpoint files. +// +// This is like `Restore` except that restored tensor can be listed as filling +// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the +// larger tensor and the slice that the restored tensor covers. +// +// The `shape_and_slice` input has the same format as the +// elements of the `shapes_and_slices` input of the `SaveSlices` op. +// +// Arguments: +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// shape_and_slice: Scalar. The shapes and slice specifications to use when +// restoring a tensors. +// dt: The type of the tensor to be restored. +// +// Returns The restored tensor. +func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dt": dt} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RestoreSlice", + Input: []tf.Input{ + file_pattern, tensor_name, shape_and_slice, + }, + Attrs: attrs, + } op := scope.AddOperation(opspec) return op.Output(0) } @@ -35772,6 +37019,84 @@ func ShardedFilespec(scope *Scope, basename tf.Output, num_shards tf.Output) (fi return op.Output(0) } +// Converts the given `resource_handle` representing an iterator to a variant tensor. +// +// Arguments: +// resource_handle: A handle to an iterator resource. +// +// Returns A variant tensor storing the state of the iterator contained in the +// resource. +func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SerializeIterator", + Input: []tf.Input{ + resource_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ExperimentalThreadPoolHandleAttr is an optional argument to ExperimentalThreadPoolHandle. +type ExperimentalThreadPoolHandleAttr func(optionalAttr) + +// ExperimentalThreadPoolHandleMaxIntraOpParallelism sets the optional max_intra_op_parallelism attribute to value. +// +// value: The maximum degree of parallelism to use within operations that execute on this +// threadpool. +// If not specified, defaults to 1 +func ExperimentalThreadPoolHandleMaxIntraOpParallelism(value int64) ExperimentalThreadPoolHandleAttr { + return func(m optionalAttr) { + m["max_intra_op_parallelism"] = value + } +} + +// ExperimentalThreadPoolHandleContainer sets the optional container attribute to value. 
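// Illustrative usage only, not part of the generated bindings or of this patch:
// a hedged sketch of the Save wrapper defined above. The checkpoint path and
// tensor name are hypothetical; like Assert and WriteFile, Save yields no
// outputs, so it is executed as a Session.Run target.
package main

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	filename := op.Const(s, "/tmp/example-checkpoint") // hypothetical path
	names := op.Const(s, []string{"weights"})          // one name per tensor in data
	weights := op.Const(s, []float32{0.1, 0.2, 0.3})
	save := op.Save(s, filename, names, []tf.Output{weights})

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	if _, err := sess.Run(nil, nil, []*tf.Operation{save}); err != nil {
		panic(err)
	}
}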
+// If not specified, defaults to "" +func ExperimentalThreadPoolHandleContainer(value string) ExperimentalThreadPoolHandleAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// ExperimentalThreadPoolHandleSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func ExperimentalThreadPoolHandleSharedName(value string) ExperimentalThreadPoolHandleAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a dataset that uses a custom thread pool to compute `input_dataset`. +// +// Arguments: +// num_threads: The number of threads in the thread pool. +// display_name: A human-readable name for the threads that may be visible in some +// visualizations. +// threadpool. +// +// Returns A resource that can be consumed by one or more ExperimentalThreadPoolDataset +// ops. +func ExperimentalThreadPoolHandle(scope *Scope, num_threads int64, display_name string, optional ...ExperimentalThreadPoolHandleAttr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_threads": num_threads, "display_name": display_name} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ExperimentalThreadPoolHandle", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // TextLineReaderV2Attr is an optional argument to TextLineReaderV2. type TextLineReaderV2Attr func(optionalAttr) @@ -35827,96 +37152,161 @@ func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_ha return op.Output(0) } -// LoadAndRemapMatrixAttr is an optional argument to LoadAndRemapMatrix. -type LoadAndRemapMatrixAttr func(optionalAttr) - -// LoadAndRemapMatrixMaxRowsInMemory sets the optional max_rows_in_memory attribute to value. -// -// value: The maximum number of rows to load from the checkpoint at -// once. If less than or equal to 0, the entire matrix will be loaded into -// memory. Setting this arg trades increased disk reads for lower memory usage. -// If not specified, defaults to -1 -func LoadAndRemapMatrixMaxRowsInMemory(value int64) LoadAndRemapMatrixAttr { - return func(m optionalAttr) { - m["max_rows_in_memory"] = value - } -} - -// Loads a 2-D (matrix) `Tensor` with name `old_tensor_name` from the checkpoint -// -// at `ckpt_path` and potentially reorders its rows and columns using the -// specified remappings. -// -// Most users should use one of the wrapper initializers (such as -// `tf.contrib.framework.load_and_remap_matrix_initializer`) instead of this -// function directly. -// -// The remappings are 1-D tensors with the following properties: -// -// * `row_remapping` must have exactly `num_rows` entries. Row `i` of the output -// matrix will be initialized from the row corresponding to index -// `row_remapping[i]` in the old `Tensor` from the checkpoint. -// * `col_remapping` must have either 0 entries (indicating that no column -// reordering is needed) or `num_cols` entries. If specified, column `j` of the -// output matrix will be initialized from the column corresponding to index -// `col_remapping[j]` in the old `Tensor` from the checkpoint. -// * A value of -1 in either of the remappings signifies a "missing" entry. In that -// case, values from the `initializing_values` tensor will be used to fill that -// missing row or column. 
If `row_remapping` has `r` missing entries and -// `col_remapping` has `c` missing entries, then the following condition must be -// true: -// -// `(r * num_cols) + (c * num_rows) - (r * c) == len(initializing_values)` -// -// The remapping tensors can be generated using the GenerateVocabRemapping op. -// -// As an example, with row_remapping = [1, 0, -1], col_remapping = [0, 2, -1], -// initializing_values = [0.5, -0.5, 0.25, -0.25, 42], and w(i, j) representing -// the value from row i, column j of the old tensor in the checkpoint, the output -// matrix will look like the following: -// -// [[w(1, 0), w(1, 2), 0.5], -// [w(0, 0), w(0, 2), -0.5], -// [0.25, -0.25, 42]] +// Fetches multiple values from infeed as an XLA tuple. // // Arguments: -// ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from -// which the old matrix `Tensor` will be loaded. -// old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. -// row_remapping: An int `Tensor` of row remappings (generally created by -// `generate_vocab_remapping`). Even if no row remapping is needed, this must -// still be an index-valued Tensor (e.g. [0, 1, 2, ...]), or a shifted -// index-valued `Tensor` (e.g. [8, 9, 10, ...], for partitioned `Variables`). -// col_remapping: An int `Tensor` of column remappings (generally created by -// `generate_vocab_remapping`). May be a size-0 `Tensor` if only row remapping -// is to be done (e.g. column ordering is the same). -// initializing_values: A float `Tensor` containing values to fill in for cells -// in the output matrix that are not loaded from the checkpoint. Length must be -// exactly the same as the number of missing / new cells. -// num_rows: Number of rows (length of the 1st dimension) in the output matrix. -// num_cols: Number of columns (length of the 2nd dimension) in the output matrix. +// dtypes: The element types of each element in `outputs`. +// shapes: The shapes of each tensor in `outputs`. // -// Returns Output matrix containing existing values loaded from the -// checkpoint, and with any missing values filled in from initializing_values. -func LoadAndRemapMatrix(scope *Scope, ckpt_path tf.Output, old_tensor_name tf.Output, row_remapping tf.Output, col_remapping tf.Output, initializing_values tf.Output, num_rows int64, num_cols int64, optional ...LoadAndRemapMatrixAttr) (output_matrix tf.Output) { +// Returns A list of tensors that will be provided using the infeed mechanism. +func InfeedDequeueTuple(scope *Scope, dtypes []tf.DataType, shapes []tf.Shape) (outputs []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_rows": num_rows, "num_cols": num_cols} + attrs := map[string]interface{}{"dtypes": dtypes, "shapes": shapes} + opspec := tf.OpSpec{ + Type: "InfeedDequeueTuple", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("InfeedDequeueTuple", err) + return + } + return outputs +} + +// FixedLengthRecordReaderV2Attr is an optional argument to FixedLengthRecordReaderV2. +type FixedLengthRecordReaderV2Attr func(optionalAttr) + +// FixedLengthRecordReaderV2HeaderBytes sets the optional header_bytes attribute to value. +// +// value: Number of bytes in the header, defaults to 0. 
+// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HeaderBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["header_bytes"] = value + } +} + +// FixedLengthRecordReaderV2FooterBytes sets the optional footer_bytes attribute to value. +// +// value: Number of bytes in the footer, defaults to 0. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2FooterBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["footer_bytes"] = value + } +} + +// FixedLengthRecordReaderV2HopBytes sets the optional hop_bytes attribute to value. +// +// value: Number of bytes to hop before each read. Default of 0 means using +// record_bytes. +// If not specified, defaults to 0 +func FixedLengthRecordReaderV2HopBytes(value int64) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["hop_bytes"] = value + } +} + +// FixedLengthRecordReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Container(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// FixedLengthRecordReaderV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2SharedName(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// FixedLengthRecordReaderV2Encoding sets the optional encoding attribute to value. +// +// value: The type of encoding for the file. Currently ZLIB and GZIP +// are supported. Defaults to none. +// If not specified, defaults to "" +func FixedLengthRecordReaderV2Encoding(value string) FixedLengthRecordReaderV2Attr { + return func(m optionalAttr) { + m["encoding"] = value + } +} + +// A Reader that outputs fixed-length records from a file. +// +// Arguments: +// record_bytes: Number of bytes in the record. +// +// Returns The handle to reference the Reader. +func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...FixedLengthRecordReaderV2Attr) (reader_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"record_bytes": record_bytes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LoadAndRemapMatrix", - Input: []tf.Input{ - ckpt_path, old_tensor_name, row_remapping, col_remapping, initializing_values, - }, + Type: "FixedLengthRecordReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } +// Runs multiple additive regression ensemble predictors on input instances and +// +// computes the update to cached logits. It is designed to be used during training. +// It traverses the trees starting from cached tree id and cached node id and +// calculates the updates to be pushed to the cache. +// +// Arguments: +// +// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting +// tree of prediction. +// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting +// node of prediction. +// bucketized_features: A list of rank 1 Tensors containing bucket id for each +// feature. 
+// logits_dimension: scalar, dimension of the logits, to be used for partial logits +// shape. +// +// Returns Rank 2 Tensor containing logits update (with respect to cached +// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids. +func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"logits_dimension": logits_dimension} + opspec := tf.OpSpec{ + Type: "BoostedTreesTrainingPredict", + Input: []tf.Input{ + tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // TFRecordReaderV2Attr is an optional argument to TFRecordReaderV2. type TFRecordReaderV2Attr func(optionalAttr) @@ -35970,41 +37360,27 @@ func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_ha return op.Output(0) } -// QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. -type QuantizeAndDequantizeV3Attr func(optionalAttr) - -// QuantizeAndDequantizeV3SignedInput sets the optional signed_input attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3SignedInput(value bool) QuantizeAndDequantizeV3Attr { - return func(m optionalAttr) { - m["signed_input"] = value - } -} - -// QuantizeAndDequantizeV3RangeGiven sets the optional range_given attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeV3RangeGiven(value bool) QuantizeAndDequantizeV3Attr { - return func(m optionalAttr) { - m["range_given"] = value - } -} - -// Quantizes then dequantizes a tensor. +// Creates a dataset that passes a sliding window over `input_dataset`. // -// This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a -// tensor, so its value can change during training. -func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, num_bits tf.Output, optional ...QuantizeAndDequantizeV3Attr) (output tf.Output) { +// Arguments: +// +// window_size: A scalar representing the number of elements in the +// sliding window. +// window_shift: A scalar representing the steps moving the sliding window +// forward in one iteration. It must be positive. +// window_stride: A scalar representing the stride of the input elements of the sliding window. +// It must be positive. 
+// +// +func ExperimentalSlidingWindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, window_shift tf.Output, window_stride tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "QuantizeAndDequantizeV3", + Type: "ExperimentalSlidingWindowDataset", Input: []tf.Input{ - input, input_min, input_max, num_bits, + input_dataset, window_size, window_shift, window_stride, }, Attrs: attrs, } @@ -36012,117 +37388,102 @@ func QuantizeAndDequantizeV3(scope *Scope, input tf.Output, input_min tf.Output, return op.Output(0) } -// IdentityReaderV2Attr is an optional argument to IdentityReaderV2. -type IdentityReaderV2Attr func(optionalAttr) +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr is an optional argument to LoadTPUEmbeddingRMSPropParametersGradAccumDebug. +type LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr func(optionalAttr) -// IdentityReaderV2Container sets the optional container attribute to value. +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. +// REQUIRES: value >= -1 +func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableId(value int64) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName sets the optional table_name attribute to value. // If not specified, defaults to "" -func IdentityReaderV2Container(value string) IdentityReaderV2Attr { +func LoadTPUEmbeddingRMSPropParametersGradAccumDebugTableName(value string) LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr { return func(m optionalAttr) { - m["container"] = value + m["table_name"] = value } } -// IdentityReaderV2SharedName sets the optional shared_name attribute to value. +// Load RMSProp embedding parameters with debug support. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func IdentityReaderV2SharedName(value string) IdentityReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A Reader that outputs the queued work as both the key and value. -// -// To use, enqueue strings in a Queue. ReaderRead will take the front -// work string and output (work, work). -// -// Returns The handle to reference the Reader. -func IdentityReaderV2(scope *Scope, optional ...IdentityReaderV2Attr) (reader_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "IdentityReaderV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyGradientDescentAttr is an optional argument to ResourceApplyGradientDescent. -type ResourceApplyGradientDescentAttr func(optionalAttr) - -// ResourceApplyGradientDescentUseLocking sets the optional use_locking attribute to value. 
-// -// value: If `True`, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyGradientDescentUseLocking(value bool) ResourceApplyGradientDescentAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' by subtracting 'alpha' * 'delta' from it. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// delta: The change. +// parameters: Value of parameters used in the RMSProp optimization algorithm. +// ms: Value of ms used in the RMSProp optimization algorithm. +// mom: Value of mom used in the RMSProp optimization algorithm. +// gradient_accumulators: Value of gradient_accumulators used in the RMSProp optimization algorithm. +// +// // // Returns the created operation. -func ResourceApplyGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, delta tf.Output, optional ...ResourceApplyGradientDescentAttr) (o *tf.Operation) { +func LoadTPUEmbeddingRMSPropParametersGradAccumDebug(scope *Scope, parameters tf.Output, ms tf.Output, mom tf.Output, gradient_accumulators tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingRMSPropParametersGradAccumDebugAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyGradientDescent", + Type: "LoadTPUEmbeddingRMSPropParametersGradAccumDebug", Input: []tf.Input{ - var_, alpha, delta, + parameters, ms, mom, gradient_accumulators, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Returns the next record (key, value pair) produced by a Reader. -// -// Will dequeue from the input queue if necessary (e.g. when the -// Reader needs to start reading from a new file since it has finished -// with the previous file). +// Creates a dataset that uses a custom thread pool to compute `input_dataset`. // // Arguments: -// reader_handle: Handle to a Reader. -// queue_handle: Handle to a Queue, with string work items. // -// Returns A scalar.A scalar. -func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { +// thread_pool: A resource produced by the ThreadPoolHandle op. +// +// +func ExperimentalThreadPoolDataset(scope *Scope, input_dataset tf.Output, thread_pool tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalThreadPoolDataset", + Input: []tf.Input{ + input_dataset, thread_pool, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Delete the tensor specified by its handle in the session. +// +// Arguments: +// handle: The handle for a tensor stored in the session state. +// +// Returns the created operation. 
+func DeleteSessionTensor(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderReadV2", + Type: "DeleteSessionTensor", Input: []tf.Input{ - reader_handle, queue_handle, + handle, }, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } // Returns up to `num_records` (key, value) pairs produced by a Reader. @@ -36152,158 +37513,795 @@ func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Out return op.Output(0), op.Output(1) } -// Adds v into specified rows of x. -// -// Computes y = x; y[i, :] += v; return y. +// Returns the number of work units this Reader has finished processing. // // Arguments: -// x: A `Tensor` of type T. -// i: A vector. Indices into the left-most dimension of `x`. -// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. -// -// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. -func InplaceAdd(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { +// reader_handle: Handle to a Reader. +func ReaderNumWorkUnitsCompletedV2(scope *Scope, reader_handle tf.Output) (units_completed tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "InplaceAdd", + Type: "ReaderNumWorkUnitsCompletedV2", Input: []tf.Input{ - x, i, v, + reader_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Restore a Reader to its initial clean state. +// Computes gradients for SparseSegmentSqrtN. +// +// Returns tensor "output" with same shape as grad, except for dimension 0 whose +// value is output_dim0. // // Arguments: -// reader_handle: Handle to a Reader. -// -// Returns the created operation. -func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { +// grad: gradient propagated to the SparseSegmentSqrtN op. +// indices: indices passed to the corresponding SparseSegmentSqrtN op. +// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op. +// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. +func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReaderResetV2", + Type: "SparseSegmentSqrtNGrad", Input: []tf.Input{ - reader_handle, + grad, indices, segment_ids, output_dim0, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reshapes a quantized tensor as per the Reshape op. +// +// ``` +// +// Arguments: +// +// shape: Defines the shape of the output tensor. +// input_min: The minimum value of the input. +// input_max: The maximum value of the input. +// +// Returns This value is copied from input_min.This value is copied from input_max. +func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min tf.Output, input_max tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "QuantizedReshape", + Input: []tf.Input{ + tensor, shape, input_min, input_max, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Outputs a `Summary` protocol buffer with a histogram. 
+// +// The generated +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// has one summary value containing a histogram for `values`. +// +// This op reports an `InvalidArgument` error if any value is not finite. +// +// Arguments: +// tag: Scalar. Tag to use for the `Summary.Value`. +// values: Any shape. Values to use to build the histogram. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "HistogramSummary", + Input: []tf.Input{ + tag, values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Restore a reader to a previously saved state. +// +// Not all Readers support being restored, so this can produce an +// Unimplemented error. +// +// Arguments: +// reader_handle: Handle to a Reader. +// state: Result of a ReaderSerializeState of a Reader with type +// matching reader_handle. +// +// Returns the created operation. +func ReaderRestoreStateV2(scope *Scope, reader_handle tf.Output, state tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderRestoreStateV2", + Input: []tf.Input{ + reader_handle, state, }, } return scope.AddOperation(opspec) } -// BatchAttr is an optional argument to Batch. -type BatchAttr func(optionalAttr) - -// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value. -// If not specified, defaults to 10 -func BatchMaxEnqueuedBatches(value int64) BatchAttr { - return func(m optionalAttr) { - m["max_enqueued_batches"] = value - } -} - -// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value. -// If not specified, defaults to <> -func BatchAllowedBatchSizes(value []int64) BatchAttr { - return func(m optionalAttr) { - m["allowed_batch_sizes"] = value - } -} - -// BatchContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func BatchContainer(value string) BatchAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// BatchSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func BatchSharedName(value string) BatchAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// BatchBatchingQueue sets the optional batching_queue attribute to value. -// If not specified, defaults to "" -func BatchBatchingQueue(value string) BatchAttr { - return func(m optionalAttr) { - m["batching_queue"] = value - } -} - -// Batches all input tensors nondeterministically. +// Elementwise computes the bitwise OR of `x` and `y`. // -// When many instances of this Op are being run concurrently with the same -// container/shared_name in the same device, some will output zero-shaped Tensors -// and others will output Tensors of size up to max_batch_size. -// -// All Tensors in in_tensors are batched together (so, for example, labels and -// features should be batched with a single instance of this operation. -// -// Each invocation of batch emits an `id` scalar which will be used to identify -// this particular invocation when doing unbatch or its gradient. -// -// Each op which emits a non-empty batch will also emit a non-empty batch_index -// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, -// start, and length of elements of each set of Tensors present in batched_tensors. 
-// -// Batched tensors are concatenated along the first dimension, and all tensors in -// in_tensors must have the first dimension of the same size. -// -// in_tensors: The tensors to be batched. -// num_batch_threads: Number of scheduling threads for processing batches of work. -// Determines the number of batches processed in parallel. -// max_batch_size: Batch sizes will never be bigger than this. -// batch_timeout_micros: Maximum number of microseconds to wait before outputting -// an incomplete batch. -// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does -// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad -// batches up to one of those sizes. The entries must increase monotonically, and -// the final entry must equal max_batch_size. -// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. -// batched_tensors: Either empty tensors or a batch of concatenated Tensors. -// batch_index: If out_tensors is non-empty, has information to invert it. -// container: Controls the scope of sharing of this batch. -// id: always contains a scalar with a unique ID for this invocation of Batch. -// shared_name: Concurrently running instances of batch in the same device with the -// same container and shared_name will batch their elements together. If left -// empty, the op name will be used as the shared name. -// T: the types of tensors to be batched. -func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) { +// The result will have those bits set, that are set in `x`, `y` or both. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros} - for _, a := range optional { - a(attrs) + opspec := tf.OpSpec{ + Type: "BitwiseOr", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Makes a new iterator from the given `dataset` and stores it in `iterator`. +// +// This operation may be executed multiple times. Each execution will reset the +// iterator in `iterator` to the first element of `dataset`. +// +// Returns the created operation. +func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return } opspec := tf.OpSpec{ - Type: "Batch", + Type: "MakeIterator", Input: []tf.Input{ - tf.OutputList(in_tensors), + dataset, iterator, + }, + } + return scope.AddOperation(opspec) +} + +// Selects num_to_sample rows of input using the KMeans++ criterion. +// +// Rows of points are assumed to be input points. One row is selected at random. +// Subsequent rows are sampled with probability proportional to the squared L2 +// distance from the nearest row selected thus far till num_to_sample rows have +// been sampled. +// +// Arguments: +// points: Matrix of shape (n, d). Rows are assumed to be input points. +// num_to_sample: Scalar. The number of rows to sample. This value must not be larger than n. +// seed: Scalar. Seed for initializing the random number generator. +// num_retries_per_sample: Scalar. 
For each row that is sampled, this parameter +// specifies the number of additional points to draw from the current +// distribution before selecting the best. If a negative value is specified, a +// heuristic is used to sample O(log(num_to_sample)) additional points. +// +// Returns Matrix of shape (num_to_sample, d). The sampled rows. +func KmeansPlusPlusInitialization(scope *Scope, points tf.Output, num_to_sample tf.Output, seed tf.Output, num_retries_per_sample tf.Output) (samples tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "KmeansPlusPlusInitialization", + Input: []tf.Input{ + points, num_to_sample, seed, num_retries_per_sample, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the next record (key, value pair) produced by a Reader. +// +// Will dequeue from the input queue if necessary (e.g. when the +// Reader needs to start reading from a new file since it has finished +// with the previous file). +// +// Arguments: +// reader_handle: Handle to a Reader. +// queue_handle: Handle to a Queue, with string work items. +// +// Returns A scalar.A scalar. +func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output) (key tf.Output, value tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderReadV2", + Input: []tf.Input{ + reader_handle, queue_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Writes contents to the file at input filename. Creates file and recursively +// +// creates directory if not existing. +// +// Arguments: +// filename: scalar. The name of the file to which we write the contents. +// contents: scalar. The content to be written to the output file. +// +// Returns the created operation. +func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "WriteFile", + Input: []tf.Input{ + filename, contents, + }, + } + return scope.AddOperation(opspec) +} + +// Creates a dataset that emits `components` as a tuple of tensors once. +func TensorDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "TensorDataset", + Input: []tf.Input{ + tf.OutputList(components), }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the set of files matching one or more glob patterns. +// +// Note that this routine only supports wildcard characters in the +// basename portion of the pattern, not in the directory portion. +// Note also that the order of filenames returned can be non-deterministic. +// +// Arguments: +// pattern: Shell wildcard pattern(s). Scalar or vector of type string. +// +// Returns A vector of matching filenames. +func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil { - scope.UpdateErr("Batch", err) + opspec := tf.OpSpec{ + Type: "MatchingFiles", + Input: []tf.Input{ + pattern, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeBilinearAttr is an optional argument to ResizeBilinear. 
+type ResizeBilinearAttr func(optionalAttr) + +// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// ResizeBilinearHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeBilinearHalfPixelCenters(value bool) ResizeBilinearAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value + } +} + +// Resize `images` to `size` using bilinear interpolation. +// +// Input images can be of different types but output images are always float. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { + if scope.Err() != nil { return } - batch_index = op.Output(idx) - id = op.Output(idx) - return batched_tensors, batch_index, id + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBilinear", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QuantizedResizeBilinearAttr is an optional argument to QuantizedResizeBilinear. +type QuantizedResizeBilinearAttr func(optionalAttr) + +// QuantizedResizeBilinearAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func QuantizedResizeBilinearAlignCorners(value bool) QuantizedResizeBilinearAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// QuantizedResizeBilinearHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func QuantizedResizeBilinearHalfPixelCenters(value bool) QuantizedResizeBilinearAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value + } +} + +// Resize quantized `images` to `size` using quantized bilinear interpolation. +// +// Input images and output images must be quantized types. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. 
+func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min tf.Output, max tf.Output, optional ...QuantizedResizeBilinearAttr) (resized_images tf.Output, out_min tf.Output, out_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedResizeBilinear", + Input: []tf.Input{ + images, size, min, max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. +type ResizeBilinearGradAttr func(optionalAttr) + +// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, the centers of the 4 corner pixels of the input and grad tensors are +// aligned. Defaults to false. +// If not specified, defaults to false +func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// ResizeBilinearGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. +// If not specified, defaults to false +func ResizeBilinearGradHalfPixelCenters(value bool) ResizeBilinearGradAttr { + return func(m optionalAttr) { + m["half_pixel_centers"] = value + } +} + +// Computes the gradient of bilinear interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBilinearGrad", + Input: []tf.Input{ + grads, original_image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deserialize and concatenate `SparseTensors` from a serialized minibatch. +// +// The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where +// `N` is the minibatch size and the rows correspond to packed outputs of +// `SerializeSparse`. The ranks of the original `SparseTensor` objects +// must all match. When the final `SparseTensor` is created, it has rank one +// higher than the ranks of the incoming `SparseTensor` objects +// (they have been concatenated along a new row dimension). +// +// The output `SparseTensor` object's shape values for all dimensions but the +// first are the max across the input `SparseTensor` objects' shape values +// for the corresponding dimensions. Its first shape value is `N`, the minibatch +// size. +// +// The input `SparseTensor` objects' indices are assumed ordered in +// standard lexicographic order. If this is not the case, after this +// step run `SparseReorder` to restore index ordering. 
+// +// For example, if the serialized input is a `[2 x 3]` matrix representing two +// original `SparseTensor` objects: +// +// index = [ 0] +// [10] +// [20] +// values = [1, 2, 3] +// shape = [50] +// +// and +// +// index = [ 2] +// [10] +// values = [4, 5] +// shape = [30] +// +// then the final deserialized `SparseTensor` will be: +// +// index = [0 0] +// [0 10] +// [0 20] +// [1 2] +// [1 10] +// values = [1, 2, 3, 4, 5] +// shape = [2 50] +// +// Arguments: +// serialized_sparse: 2-D, The `N` serialized `SparseTensor` objects. +// Must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` objects. +func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "DeserializeManySparse", + Input: []tf.Input{ + serialized_sparse, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// RandomCropAttr is an optional argument to RandomCrop. +type RandomCropAttr func(optionalAttr) + +// RandomCropSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomCropSeed(value int64) RandomCropAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomCropSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomCropSeed2(value int64) RandomCropAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Randomly crop `image`. +// +// DEPRECATED at GraphDef version 8: Random crop is now pure Python +// +// `size` is a 1-D int64 tensor with 2 elements representing the crop height and +// width. The values must be non negative. +// +// This Op picks a random location in `image` and crops a `height` by `width` +// rectangle from that location. The random location is picked so the cropped +// area will fit inside the original image. +// +// Arguments: +// image: 3-D of shape `[height, width, channels]`. +// size: 1-D of length 2 containing: `crop_height`, `crop_width`.. +// +// Returns 3-D of shape `[crop_height, crop_width, channels].` +func RandomCrop(scope *Scope, image tf.Output, size tf.Output, optional ...RandomCropAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomCrop", + Input: []tf.Input{ + image, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns locations of nonzero / true values in a tensor. +// +// This operation returns the coordinates of true elements in `condition`. The +// coordinates are returned in a 2-D tensor where the first dimension (rows) +// represents the number of true elements, and the second dimension (columns) +// represents the coordinates of the true elements. Keep in mind, the shape of +// the output tensor can vary depending on how many true values there are in +// `condition`. Indices are output in row-major order. 
+// +// For example: +// +// ``` +// # 'input' tensor is [[True, False] +// # [True, False]] +// # 'input' has two true values, so output has two coordinates. +// # 'input' has rank of 2, so coordinates have two indices. +// where(input) ==> [[0, 0], +// [1, 0]] +// +// # `condition` tensor is [[[True, False] +// # [True, False]] +// # [[False, True] +// # [False, True]] +// # [[False, False] +// # [False, True]]] +// # 'input' has 5 true values, so output has 5 coordinates. +// # 'input' has rank of 3, so coordinates have three indices. +// where(input) ==> [[0, 0, 0], +// [0, 1, 0], +// [1, 0, 1], +// [1, 1, 1], +// [2, 1, 1]] +// +// # `condition` tensor is [[[1.5, 0.0] +// # [-0.5, 0.0]] +// # [[0.0, 0.25] +// # [0.0, 0.75]] +// # [[0.0, 0.0] +// # [0.0, 0.01]]] +// # 'input' has 5 nonzero values, so output has 5 coordinates. +// # 'input' has rank of 3, so coordinates have three indices. +// where(input) ==> [[0, 0, 0], +// [0, 1, 0], +// [1, 0, 1], +// [1, 1, 1], +// [2, 1, 1]] +// +// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j] +// # [0.0 + 0.5j, 0.0 + 0.0j]] +// # [[0.0 + 0.0j, 0.25 + 1.5j] +// # [0.0 + 0.0j, 0.75 + 0.0j]] +// # [[0.0 + 0.0j, 0.0 + 0.0j] +// # [0.0 + 0.0j, 0.01 + 0.0j]]] +// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates. +// # 'input' has rank of 3, so coordinates have three indices. +// where(input) ==> [[0, 0, 0], +// [0, 1, 0], +// [1, 0, 1], +// [1, 1, 1], +// [2, 1, 1]] +// ``` +func Where(scope *Scope, condition tf.Output) (index tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Where", + Input: []tf.Input{ + condition, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeJpegAttr is an optional argument to DecodeJpeg. +type DecodeJpegAttr func(optionalAttr) + +// DecodeJpegChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodeJpegChannels(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodeJpegRatio sets the optional ratio attribute to value. +// +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeJpegRatio(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["ratio"] = value + } +} + +// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["fancy_upscaling"] = value + } +} + +// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeJpegDctMethod sets the optional dct_method attribute to value. 
+// +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeJpegDctMethod(value string) DecodeJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } +} + +// Decode a JPEG-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. +// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// +// This op also supports decoding PNGs and non-animated GIFs since the interface is +// the same, though it is cleaner to use `tf.image.decode_image`. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeJpeg", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Disallowed in GraphDef version >= 2. +// +// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead +func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustContrast", + Input: []tf.Input{ + images, contrast_factor, min_value, max_value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that batches and pads `batch_size` elements from the input. +// +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// padded_shapes: A list of int64 tensors representing the desired padded shapes +// of the corresponding output components. These shapes may be partially +// specified, using `-1` to indicate that a particular dimension should be +// padded to the maximum size of all batch elements. +// padding_values: A list of scalars containing the padding value to use for +// each of the outputs. +// +func PaddedBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "PaddedBatchDataset", + Input: []tf.Input{ + input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } // Adjust the hue of one or more images. 
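The hunk above adds Go wrappers for several image ops, including `DecodeJpeg` and `ResizeBilinear`. As a rough illustration of how these generated wrappers compose into a graph — a minimal sketch, not part of the generated file, assuming the usual `op`/`tf` package layout of the Go binding, with `Placeholder`, `Const`, and `ExpandDims` taken from elsewhere in the same package and a hypothetical `input.jpg` on disk — decoding and resizing a JPEG might look like:

```go
package main

import (
	"fmt"
	"io/ioutil"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildResizeGraph wires DecodeJpeg -> ExpandDims -> ResizeBilinear and returns
// the graph together with the placeholder to feed and the output to fetch.
func buildResizeGraph() (*tf.Graph, tf.Output, tf.Output, error) {
	s := op.NewScope()
	// JPEG bytes are fed at run time as a scalar string.
	contents := op.Placeholder(s.SubScope("jpeg"), tf.String)
	// Decode to a 3-D uint8 tensor [height, width, 3].
	img := op.DecodeJpeg(s, contents, op.DecodeJpegChannels(3))
	// ResizeBilinear expects a 4-D batch, so add a leading batch dimension.
	batched := op.ExpandDims(s, img, op.Const(s.SubScope("axis"), int32(0)))
	// Resize to 224x224; the output is float regardless of the input type.
	resized := op.ResizeBilinear(s, batched, op.Const(s.SubScope("size"), []int32{224, 224}))
	graph, err := s.Finalize()
	return graph, contents, resized, err
}

func main() {
	graph, input, output, err := buildResizeGraph()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	// Hypothetical input file; any JPEG would do.
	jpegBytes, err := ioutil.ReadFile("input.jpg")
	if err != nil {
		panic(err)
	}
	feed, err := tf.NewTensor(string(jpegBytes))
	if err != nil {
		panic(err)
	}
	results, err := sess.Run(
		map[tf.Output]*tf.Tensor{input: feed},
		[]tf.Output{output},
		nil,
	)
	if err != nil {
		panic(err)
	}
	fmt.Println("resized shape:", results[0].Shape()) // e.g. [1 224 224 3]
}
```

The sketch only exercises wrappers whose signatures appear in this diff (plus the assumed helpers noted above); it is meant to show the scope/feed/fetch pattern these generated functions follow, not to document an official example.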
@@ -36334,39 +38332,53 @@ func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Outpu return op.Output(0) } -// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad. -type ResizeBicubicGradAttr func(optionalAttr) +// Computes sin of x element-wise. +func Sin(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sin", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// ResizeBicubicGradAlignCorners sets the optional align_corners attribute to value. +// EncodePngAttr is an optional argument to EncodePng. +type EncodePngAttr func(optionalAttr) + +// EncodePngCompression sets the optional compression attribute to value. // -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. -// If not specified, defaults to false -func ResizeBicubicGradAlignCorners(value bool) ResizeBicubicGradAttr { +// value: Compression level. +// If not specified, defaults to -1 +func EncodePngCompression(value int64) EncodePngAttr { return func(m optionalAttr) { - m["align_corners"] = value + m["compression"] = value } } -// ResizeBicubicGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. -// If not specified, defaults to false -func ResizeBicubicGradHalfPixelCenters(value bool) ResizeBicubicGradAttr { - return func(m optionalAttr) { - m["half_pixel_centers"] = value - } -} - -// Computes the gradient of bicubic interpolation. +// PNG-encode an image. +// +// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` +// where `channels` is: +// +// * 1: for grayscale. +// * 2: for grayscale + alpha. +// * 3: for RGB. +// * 4: for RGBA. +// +// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder +// default or a value from 0 to 9. 9 is the highest compression level, generating +// the smallest output, but is slower. // // Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. +// image: 3-D with shape `[height, width, channels]`. // -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBicubicGradAttr) (output tf.Output) { +// Returns 0-D. PNG-encoded image. +func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { if scope.Err() != nil { return } @@ -36375,9 +38387,9 @@ func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeBicubicGrad", + Type: "EncodePng", Input: []tf.Input{ - grads, original_image, + image, }, Attrs: attrs, } @@ -36385,263 +38397,32 @@ func ResizeBicubicGrad(scope *Scope, grads tf.Output, original_image tf.Output, return op.Output(0) } -// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. -type ResizeNearestNeighborAttr func(optionalAttr) +// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler. +type LogUniformCandidateSamplerAttr func(optionalAttr) -// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. 
-// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// ResizeNearestNeighborHalfPixelCenters sets the optional half_pixel_centers attribute to value. -// If not specified, defaults to false -func ResizeNearestNeighborHalfPixelCenters(value bool) ResizeNearestNeighborAttr { - return func(m optionalAttr) { - m["half_pixel_centers"] = value - } -} - -// Resize `images` to `size` using nearest neighbor interpolation. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeNearestNeighbor", - Input: []tf.Input{ - images, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. -type ResizeNearestNeighborGradAttr func(optionalAttr) - -// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and grad tensors are -// aligned. Defaults to false. -// If not specified, defaults to false -func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// ResizeNearestNeighborGradHalfPixelCenters sets the optional half_pixel_centers attribute to value. -// If not specified, defaults to false -func ResizeNearestNeighborGradHalfPixelCenters(value bool) ResizeNearestNeighborGradAttr { - return func(m optionalAttr) { - m["half_pixel_centers"] = value - } -} - -// Computes the gradient of nearest neighbor interpolation. -// -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The -// original input size. -// -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients -// with respect to the input image. -func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeNearestNeighborGrad", - Input: []tf.Input{ - grads, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ExtractJpegShapeAttr is an optional argument to ExtractJpegShape. -type ExtractJpegShapeAttr func(optionalAttr) - -// ExtractJpegShapeOutputType sets the optional output_type attribute to value. -// -// value: (Optional) The output type of the operation (int32 or int64). -// Defaults to int32. 
-// If not specified, defaults to DT_INT32 -func ExtractJpegShapeOutputType(value tf.DataType) ExtractJpegShapeAttr { - return func(m optionalAttr) { - m["output_type"] = value - } -} - -// Extract the shape information of a JPEG-encoded image. -// -// This op only parses the image header, so it is much faster than DecodeJpeg. -// -// Arguments: -// contents: 0-D. The JPEG-encoded image. -// -// Returns 1-D. The image shape with format [height, width, channels]. -func ExtractJpegShape(scope *Scope, contents tf.Output, optional ...ExtractJpegShapeAttr) (image_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ExtractJpegShape", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DecodePngAttr is an optional argument to DecodePng. -type DecodePngAttr func(optionalAttr) - -// DecodePngChannels sets the optional channels attribute to value. -// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodePngChannels(value int64) DecodePngAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodePngDtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_UINT8 -func DecodePngDtype(value tf.DataType) DecodePngAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Decode a PNG-encoded image to a uint8 or uint16 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the PNG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// * 4: output an RGBA image. -// -// If needed, the PNG-encoded image is transformed to match the requested number -// of color channels. -// -// This op also supports decoding JPEGs and non-animated GIFs since the interface -// is the same, though it is cleaner to use `tf.image.decode_image`. -// -// Arguments: -// contents: 0-D. The PNG-encoded image. -// -// Returns 3-D with shape `[height, width, channels]`. -func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodePng", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Decode the first frame of a GIF-encoded image to a uint8 tensor. -// -// GIF with frame or transparency compression are not supported -// convert animated GIF from compressed to uncompressed by: -// -// convert $src.gif -coalesce $dst.gif -// -// This op also supports decoding JPEGs and PNGs, though it is cleaner to use -// `tf.image.decode_image`. -// -// Arguments: -// contents: 0-D. The GIF-encoded image. -// -// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order -func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DecodeGif", - Input: []tf.Input{ - contents, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. 
-type LearnedUnigramCandidateSamplerAttr func(optionalAttr) - -// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// LogUniformCandidateSamplerSeed sets the optional seed attribute to value. // // value: If either seed or seed2 are set to be non-zero, the random number // generator is seeded by the given seed. Otherwise, it is seeded by a // random seed. // If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { +func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr { return func(m optionalAttr) { m["seed"] = value } } -// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value. // // value: An second seed to avoid seed collision. // If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { +func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr { return func(m optionalAttr) { m["seed2"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. +// Generates labels for candidate sampling with a log-uniform distribution. // // See explanations of candidate sampling and the data formats at // go/candidate-sampling. @@ -36670,7 +38451,7 @@ func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSam // candidate representing the number of times the candidate is expected // to occur in a batch of sampled candidates. If unique=true, then this is a // probability. -func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } @@ -36679,7 +38460,7 @@ func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_tr a(attrs) } opspec := tf.OpSpec{ - Type: "LearnedUnigramCandidateSampler", + Type: "LogUniformCandidateSampler", Input: []tf.Input{ true_classes, }, @@ -36689,147 +38470,104 @@ func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_tr return op.Output(0), op.Output(1), op.Output(2) } -// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. -type RandomShuffleQueueV2Attr func(optionalAttr) +// LoadTPUEmbeddingAdadeltaParametersAttr is an optional argument to LoadTPUEmbeddingAdadeltaParameters. +type LoadTPUEmbeddingAdadeltaParametersAttr func(optionalAttr) -// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. -// -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. 
-// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value - } -} - -// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. +// LoadTPUEmbeddingAdadeltaParametersTableId sets the optional table_id attribute to value. // If not specified, defaults to -1 -func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { +// +// REQUIRES: value >= -1 +func LoadTPUEmbeddingAdadeltaParametersTableId(value int64) LoadTPUEmbeddingAdadeltaParametersAttr { return func(m optionalAttr) { - m["capacity"] = value + m["table_id"] = value } } -// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. -// -// value: Dequeue will block unless there would be this -// many elements after the dequeue or the queue is closed. This -// ensures a minimum level of mixing of elements. -// If not specified, defaults to 0 -func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["min_after_dequeue"] = value - } -} - -// RandomShuffleQueueV2Seed sets the optional seed attribute to value. -// -// value: If either seed or seed2 is set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// RandomShuffleQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. +// LoadTPUEmbeddingAdadeltaParametersTableName sets the optional table_name attribute to value. // If not specified, defaults to "" -func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { +func LoadTPUEmbeddingAdadeltaParametersTableName(value string) LoadTPUEmbeddingAdadeltaParametersAttr { return func(m optionalAttr) { - m["container"] = value + m["table_name"] = value } } -// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. +// Load Adadelta embedding parameters. // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that randomizes the order of elements. +// An op that loads optimization parameters into HBM for embedding. Must be +// preceded by a ConfigureTPUEmbeddingHost op that sets up the correct +// embedding table configuration. For example, this op is used to install +// parameters that are loaded from a checkpoint before a training loop is +// executed. // // Arguments: -// component_types: The type of each component in a value. +// parameters: Value of parameters used in the Adadelta optimization algorithm. 
+// accumulators: Value of accumulators used in the Adadelta optimization algorithm. +// updates: Value of updates used in the Adadelta optimization algorithm. // -// Returns The handle to the queue. -func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { +// +// +// Returns the created operation. +func LoadTPUEmbeddingAdadeltaParameters(scope *Scope, parameters tf.Output, accumulators tf.Output, updates tf.Output, num_shards int64, shard_id int64, optional ...LoadTPUEmbeddingAdadeltaParametersAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomShuffleQueueV2", - + Type: "LoadTPUEmbeddingAdadeltaParameters", + Input: []tf.Input{ + parameters, accumulators, updates, + }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Converts one or more images from RGB to HSV. +// +// Outputs a tensor of the same shape as the `images` tensor, containing the HSV +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and +// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 +// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. +// +// Arguments: +// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. +// +// Returns `images` converted to HSV. +func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RGBToHSV", + Input: []tf.Input{ + images, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } -// SerializeSparseAttr is an optional argument to SerializeSparse. -type SerializeSparseAttr func(optionalAttr) - -// SerializeSparseOutType sets the optional out_type attribute to value. +// Creates a TensorList which, when stacked, has the value of `tensor`. // -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Serialize a `SparseTensor` into a `[3]` `Tensor` object. +// Each tensor in the result list corresponds to one row of the input tensor. // -// Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. -func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { +// tensor: The input tensor. +// output_handle: The list. 
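+//
+// A minimal usage sketch, not part of the generated op documentation; it
+// assumes a Scope `s` obtained from op.NewScope() and uses the Const helper
+// from this package:
+//
+// ```
+// t := Const(s, [][]float32{{1, 2}, {3, 4}})
+// elementShape := Const(s, []int32{2}) // shape of each list element (one row of t)
+// list := TensorListFromTensor(s, t, elementShape)
+// _ = list // each of the two rows of t becomes one element of the list
+// ```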
+func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Output) (output_handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SerializeSparse", + Type: "TensorListFromTensor", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + tensor, element_shape, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -37014,392 +38752,156 @@ func SampleDistortedBoundingBoxV2(scope *Scope, image_size tf.Output, bounding_b return op.Output(0), op.Output(1), op.Output(2) } -// Computes requantization range per channel. +// SpaceToDepthAttr is an optional argument to SpaceToDepth. +type SpaceToDepthAttr func(optionalAttr) + +// SpaceToDepthDataFormat sets the optional data_format attribute to value. +// If not specified, defaults to "NHWC" +func SpaceToDepthDataFormat(value string) SpaceToDepthAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// SpaceToDepth for tensors of type T. +// +// Rearranges blocks of spatial data, into depth. More specifically, +// this op outputs a copy of the input tensor where values from the `height` +// and `width` dimensions are moved to the `depth` dimension. +// The attr `block_size` indicates the input block size. +// +// * Non-overlapping blocks of size `block_size x block size` are rearranged +// into depth at each location. +// * The depth of the output tensor is `block_size * block_size * input_depth`. +// * The Y, X coordinates within each block of the input become the high order +// component of the output channel index. +// * The input tensor's height and width must be divisible by block_size. +// +// The `data_format` attr specifies the layout of the input and output tensors +// with the following options: +// "NHWC": `[ batch, height, width, channels ]` +// "NCHW": `[ batch, channels, height, width ]` +// "NCHW_VECT_C": +// `qint8 [ batch, channels / 4, height, width, 4 ]` +// +// It is useful to consider the operation as transforming a 6-D Tensor. +// e.g. for data_format = NHWC, +// Each element in the input tensor can be specified via 6 coordinates, +// ordered by decreasing memory layout significance as: +// n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates +// within the output image, bX, bY means coordinates +// within the input block, iC means input channels). +// The output would be a transpose to the following layout: +// n,oY,oX,bY,bX,iC +// +// This operation is useful for resizing the activations between convolutions +// (but keeping all data), e.g. instead of pooling. It is also useful for training +// purely convolutional models. +// +// For example, given an input of shape `[1, 2, 2, 1]`, data_format = "NHWC" and +// block_size = 2: +// +// ``` +// x = [[[[1], [2]], +// [[3], [4]]]] +// ``` +// +// This operation will output a tensor of shape `[1, 1, 1, 4]`: +// +// ``` +// [[[[1, 2, 3, 4]]]] +// ``` +// +// Here, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`, +// the corresponding output will have a single element (i.e. width and height are +// both 1) and will have a depth of 4 channels (1 * block_size * block_size). +// The output element shape is `[1, 1, 4]`. +// +// For an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g. 
+// +// ``` +// x = [[[[1, 2, 3], [4, 5, 6]], +// [[7, 8, 9], [10, 11, 12]]]] +// ``` +// +// This operation, for block_size of 2, will return the following tensor of shape +// `[1, 1, 1, 12]` +// +// ``` +// [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] +// ``` +// +// Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2: +// +// ``` +// x = [[[[1], [2], [5], [6]], +// [[3], [4], [7], [8]], +// [[9], [10], [13], [14]], +// [[11], [12], [15], [16]]]] +// ``` +// +// the operator will return the following tensor of shape `[1 2 2 4]`: +// +// ``` +// x = [[[[1, 2, 3, 4], +// [5, 6, 7, 8]], +// [[9, 10, 11, 12], +// [13, 14, 15, 16]]]] +// ``` // // Arguments: -// input: The original input tensor. -// input_min: The minimum value of the input tensor -// input_max: The maximum value of the input tensor. -// clip_value_max: The maximum value of the output that needs to be clipped. -// Example: set this to 6 for Relu6. // -// Returns The minimum value of the final output tensorThe maximum value of the final output tensor. -func RequantizationRangePerChannel(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, clip_value_max float32) (output_min tf.Output, output_max tf.Output) { +// block_size: The size of the spatial block. +func SpaceToDepth(scope *Scope, input tf.Output, block_size int64, optional ...SpaceToDepthAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"clip_value_max": clip_value_max} - opspec := tf.OpSpec{ - Type: "RequantizationRangePerChannel", - Input: []tf.Input{ - input, input_min, input_max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// ExtractGlimpseAttr is an optional argument to ExtractGlimpse. -type ExtractGlimpseAttr func(optionalAttr) - -// ExtractGlimpseCentered sets the optional centered attribute to value. -// -// value: indicates if the offset coordinates are centered relative to -// the image, in which case the (0, 0) offset is relative to the center -// of the input images. If false, the (0,0) offset corresponds to the -// upper left corner of the input images. -// If not specified, defaults to true -func ExtractGlimpseCentered(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["centered"] = value - } -} - -// ExtractGlimpseNormalized sets the optional normalized attribute to value. -// -// value: indicates if the offset coordinates are normalized. -// If not specified, defaults to true -func ExtractGlimpseNormalized(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["normalized"] = value - } -} - -// ExtractGlimpseUniformNoise sets the optional uniform_noise attribute to value. -// -// value: indicates if the noise should be generated using a -// uniform distribution or a Gaussian distribution. -// If not specified, defaults to true -func ExtractGlimpseUniformNoise(value bool) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["uniform_noise"] = value - } -} - -// ExtractGlimpseNoise sets the optional noise attribute to value. -// -// value: indicates if the noise should `uniform`, `gaussian`, or -// `zero`. The default is `uniform` which means the the noise type -// will be decided by `uniform_noise`. -// If not specified, defaults to "uniform" -func ExtractGlimpseNoise(value string) ExtractGlimpseAttr { - return func(m optionalAttr) { - m["noise"] = value - } -} - -// Extracts a glimpse from the input tensor. 
-// -// Returns a set of windows called glimpses extracted at location -// `offsets` from the input tensor. If the windows only partially -// overlaps the inputs, the non overlapping areas will be filled with -// random noise. -// -// The result is a 4-D tensor of shape `[batch_size, glimpse_height, -// glimpse_width, channels]`. The channels and batch dimensions are the -// same as that of the input tensor. The height and width of the output -// windows are specified in the `size` parameter. -// -// The argument `normalized` and `centered` controls how the windows are built: -// -// * If the coordinates are normalized but not centered, 0.0 and 1.0 -// correspond to the minimum and maximum of each height and width -// dimension. -// * If the coordinates are both normalized and centered, they range from -// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper -// left corner, the lower right corner is located at (1.0, 1.0) and the -// center is at (0, 0). -// * If the coordinates are not normalized they are interpreted as -// numbers of pixels. -// -// Arguments: -// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. -// size: A 1-D tensor of 2 elements containing the size of the glimpses -// to extract. The glimpse height must be specified first, following -// by the glimpse width. -// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing -// the y, x locations of the center of each window. -// -// Returns A tensor representing the glimpses `[batch_size, -// glimpse_height, glimpse_width, channels]`. -func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseAttr) (glimpse tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"block_size": block_size} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ExtractGlimpse", - Input: []tf.Input{ - input, size, offsets, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// A container for an iterator resource. -// -// Returns A handle to the iterator that can be passed to a "MakeIterator" -// or "IteratorGetNext" op. -func Iterator(scope *Scope, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "Iterator", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorForestTreeResourceHandleOpAttr is an optional argument to TensorForestTreeResourceHandleOp. -type TensorForestTreeResourceHandleOpAttr func(optionalAttr) - -// TensorForestTreeResourceHandleOpContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func TensorForestTreeResourceHandleOpContainer(value string) TensorForestTreeResourceHandleOpAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// TensorForestTreeResourceHandleOpSharedName sets the optional shared_name attribute to value. 
-// If not specified, defaults to "" -func TensorForestTreeResourceHandleOpSharedName(value string) TensorForestTreeResourceHandleOpAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a handle to a TensorForestTreeResource -func TensorForestTreeResourceHandleOp(scope *Scope, optional ...TensorForestTreeResourceHandleOpAttr) (resource tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorForestTreeResourceHandleOp", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. -type CropAndResizeGradImageAttr func(optionalAttr) - -// CropAndResizeGradImageMethod sets the optional method attribute to value. -// -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input image tensor. -// -// Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` -// containing the original image size. Both `image_height` and `image_width` need -// to be positive. -// -// -// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"T": T} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CropAndResizeGradImage", - Input: []tf.Input{ - grads, boxes, box_ind, image_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ShuffleDatasetAttr is an optional argument to ShuffleDataset. -type ShuffleDatasetAttr func(optionalAttr) - -// ShuffleDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value. -// -// value: If true, each iterator over this dataset will be given -// a different pseudorandomly generated seed, based on a sequence seeded by the -// `seed` and `seed2` inputs. 
If false, each iterator will be given the same -// seed, and repeated iteration over this dataset will yield the exact same -// sequence of results. -// If not specified, defaults to true -func ShuffleDatasetReshuffleEachIteration(value bool) ShuffleDatasetAttr { - return func(m optionalAttr) { - m["reshuffle_each_iteration"] = value - } -} - -// Creates a dataset that shuffles elements from `input_dataset` pseudorandomly. -// -// Arguments: -// -// buffer_size: The number of output elements to buffer in an iterator over -// this dataset. Compare with the `min_after_dequeue` attr when creating a -// `RandomShuffleQueue`. -// seed: A scalar seed for the random number generator. If either `seed` or -// `seed2` is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. -// -// -func ShuffleDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleDatasetAttr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ShuffleDataset", - Input: []tf.Input{ - input_dataset, buffer_size, seed, seed2, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 3D fast Fourier transform. -// -// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 -// dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their 3D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fftn with 3 dimensions. -// @end_compatibility -func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT3D", + Type: "SpaceToDepth", Input: []tf.Input{ input, }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CropAndResizeGradBoxesAttr is an optional argument to CropAndResizeGradBoxes. -type CropAndResizeGradBoxesAttr func(optionalAttr) - -// CropAndResizeGradBoxesMethod sets the optional method attribute to value. -// -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradBoxesMethod(value string) CropAndResizeGradBoxesAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input boxes tensor. -// -// Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -// Both `image_height` and `image_width` need to be positive. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. 
We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// -// Returns A 2-D tensor of shape `[num_boxes, 4]`. -func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxes tf.Output, box_ind tf.Output, optional ...CropAndResizeGradBoxesAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CropAndResizeGradBoxes", - Input: []tf.Input{ - grads, image, boxes, box_ind, - }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } +// NonMaxSuppressionAttr is an optional argument to NonMaxSuppression. +type NonMaxSuppressionAttr func(optionalAttr) + +// NonMaxSuppressionIouThreshold sets the optional iou_threshold attribute to value. +// +// value: A float representing the threshold for deciding whether boxes +// overlap too much with respect to IOU. +// If not specified, defaults to 0.5 +func NonMaxSuppressionIouThreshold(value float32) NonMaxSuppressionAttr { + return func(m optionalAttr) { + m["iou_threshold"] = value + } +} + // Greedily selects a subset of bounding boxes in descending order of score, // // pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes with score less than -// `score_threshold` are removed. Bounding boxes are supplied as +// with previously selected boxes. Bounding boxes are supplied as // [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any // diagonal pair of box corners and the coordinates can be provided as normalized // (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system and more -// generally is invariant to orthogonal transformations and translations +// is agnostic to where the origin is in the coordinate system. Note that this +// algorithm is invariant to orthogonal transformations and translations // of the coordinate system; thus translating or reflections of the coordinate // system result in the same boxes being selected by the algorithm. // The output of this operation is a set of integers indexing into the input // collection of bounding boxes representing the selected boxes. The bounding // box coordinates corresponding to the selected indices can then be obtained // using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression_v2( -// boxes, scores, max_output_size, iou_threshold, score_threshold) +// selected_indices = tf.image.non_max_suppression( +// boxes, scores, max_output_size, iou_threshold) // selected_boxes = tf.gather(boxes, selected_indices) // // Arguments: @@ -37408,117 +38910,291 @@ func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxe // score corresponding to each box (each row of boxes). // max_output_size: A scalar integer tensor representing the maximum number of // boxes to be selected by non max suppression. 
-// iou_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too much with respect to IOU. +// +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, optional ...NonMaxSuppressionAttr) (selected_indices tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "NonMaxSuppression", + Input: []tf.Input{ + boxes, scores, max_output_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Transforms a vector of brain.Example protos (as strings) into typed tensors. +// +// Arguments: +// serialized: A vector containing a batch of binary serialized Example protos. +// names: A vector containing the names of the serialized protos. +// May contain, for example, table key (descriptive) names for the +// corresponding serialized protos. These are purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty vector if no names are available. +// If non-empty, this vector must be the same length as "serialized". +// sparse_keys: A list of Nsparse string Tensors (scalars). +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: A list of Ndense string Tensors (scalars). +// The keys expected in the Examples' features associated with dense values. +// dense_defaults: A list of Ndense Tensors (some may be empty). +// dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// sparse_types: A list of Nsparse types; the data types of data in each Feature +// given in sparse_keys. +// Currently the ParseExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: A list of Ndense shapes; the shapes of data in each Feature +// given in dense_keys. +// The number of elements in the Feature corresponding to dense_key[j] +// must always equal dense_shapes[j].NumEntries(). +// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output +// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): +// The dense outputs are just the inputs row-stacked by batch. +// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case +// the shape of the output Tensor dense_values[j] will be +// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks +// of elements of length D1 * .... * DN, across all minibatch entries +// in the input. Any minibatch entry with less than M blocks of elements of +// length D1 * ... * DN will be padded with the corresponding default_value +// scalar element along the second dimension. 
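+//
+// A hedged usage sketch, not part of the generated op documentation. It assumes
+// a Scope `s` and a hypothetical `exampleBytes` value holding one serialized
+// Example proto, and parses a single int64 feature named "age" with a scalar
+// default of -1 and no sparse features:
+//
+// ```
+// serialized := Const(s, []string{string(exampleBytes)})
+// names := Const(s, []string{""})
+// denseKeys := []tf.Output{Const(s, "age")}
+// denseDefaults := []tf.Output{Const(s, int64(-1))}
+// _, _, _, dense := ParseExample(s, serialized, names,
+// 	nil, denseKeys, denseDefaults, nil, []tf.Shape{tf.ScalarShape()})
+// _ = dense // dense[0] holds the batched "age" values, shape [1]
+// ```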
+func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"sparse_types": sparse_types, "dense_shapes": dense_shapes} + opspec := tf.OpSpec{ + Type: "ParseExample", + Input: []tf.Input{ + serialized, names, tf.OutputList(sparse_keys), tf.OutputList(dense_keys), tf.OutputList(dense_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + return sparse_indices, sparse_values, sparse_shapes, dense_values +} + +// TryRpcAttr is an optional argument to TryRpc. +type TryRpcAttr func(optionalAttr) + +// TryRpcProtocol sets the optional protocol attribute to value. +// +// value: RPC protocol to use. Empty string means use the default protocol. +// Options include 'grpc'. +// If not specified, defaults to "" +func TryRpcProtocol(value string) TryRpcAttr { + return func(m optionalAttr) { + m["protocol"] = value + } +} + +// TryRpcFailFast sets the optional fail_fast attribute to value. +// +// value: `boolean`. If `true` (default), then failures to connect +// (i.e., the server does not immediately respond) cause an RPC failure. +// If not specified, defaults to true +func TryRpcFailFast(value bool) TryRpcAttr { + return func(m optionalAttr) { + m["fail_fast"] = value + } +} + +// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value. +// +// value: `int`. If `0` (default), then the kernel will run the RPC +// request and only time out if the RPC deadline passes or the session times out. +// If this value is greater than `0`, then the op will raise an exception if +// the RPC takes longer than `timeout_in_ms`. +// If not specified, defaults to 0 +func TryRpcTimeoutInMs(value int64) TryRpcAttr { + return func(m optionalAttr) { + m["timeout_in_ms"] = value + } +} + +// Perform batches of RPC requests. +// +// This op asynchronously performs either a single RPC request, or a batch +// of requests. RPC requests are defined by three main parameters: +// +// - `address` (the host+port or BNS address of the request) +// - `method` (the method name for the request) +// - `request` (the serialized proto string, or vector of strings, +// of the RPC request argument). 
+// +// For example, if you have an RPC service running on port localhost:2345, +// and its interface is configured with the following proto declaration: +// +// ``` +// service MyService { +// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { +// } +// }; +// ``` +// +// then call this op with arguments: +// +// ``` +// address = "localhost:2345" +// method = "MyService/MyMethod" +// ``` +// +// The `request` tensor is a string tensor representing serialized `MyRequestProto` +// strings; and the output string tensor `response` will have the same shape +// and contain (upon successful completion) corresponding serialized +// `MyResponseProto` strings. +// +// For example, to send a single, empty, `MyRequestProto`, call +// this op with `request = ""`. To send 5 **parallel** empty requests, +// call this op with `request = ["", "", "", "", ""]`. +// +// More generally, one can create a batch of `MyRequestProto` serialized protos +// from regular batched tensors using the `encode_proto` op, and convert +// the response `MyResponseProto` serialized protos to batched tensors +// using the `decode_proto` op. +// +// **NOTE** Working with serialized proto strings is faster than instantiating +// actual proto objects in memory, so no performance degradation is expected +// compared to writing custom kernels for this workflow. +// +// Unlike the standard `Rpc` op, if the connection fails or the remote worker +// returns an error status, this op does **not** reraise the exception. +// Instead, the `status_code` and `status_message` entry for the corresponding RPC +// call is set with the error returned from the RPC call. The `response` tensor +// will contain valid response values for those minibatch entries whose RPCs did +// not fail; the rest of the entries will have empty strings. +// +// Arguments: +// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `method` and `request`. +// method: `0-D` or `1-D`. The method address on the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `request`. +// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `method`. +// +// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages +// returned from the RPC calls. +func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TryRpc", + Input: []tf.Input{ + address, method, request, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes natural logarithm of x element-wise. +// +// I.e., \\(y = \log_e x\\). 
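+//
+// A minimal sketch, not part of the generated op documentation (assumes a
+// Scope `s` from op.NewScope()):
+//
+// ```
+// x := Const(s, []float32{1, 2.718281828})
+// y := Log(s, x) // approximately [0, 1]
+// ```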
+func Log(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Log", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high overlaps +// with previously selected boxes. Bounding boxes with score less than +// `score_threshold` are removed. N-by-n overlap values are supplied as square matrix, +// which allows for defining a custom overlap criterium (eg. intersection over union, +// intersection over area, etc.). +// +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// +// selected_indices = tf.image.non_max_suppression_with_overlaps( +// overlaps, scores, max_output_size, overlap_threshold, score_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) +// +// Arguments: +// overlaps: A 2-D float tensor of shape `[num_boxes, num_boxes]` representing +// the n-by-n box overlap values. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// overlap_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too. // score_threshold: A 0-D float tensor representing the threshold for deciding when to remove // boxes based on score. // // Returns A 1-D integer tensor of shape `[M]` representing the selected // indices from the boxes tensor, where `M <= max_output_size`. -func NonMaxSuppressionV3(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { +func NonMaxSuppressionWithOverlaps(scope *Scope, overlaps tf.Output, scores tf.Output, max_output_size tf.Output, overlap_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NonMaxSuppressionV3", + Type: "NonMaxSuppressionWithOverlaps", Input: []tf.Input{ - boxes, scores, max_output_size, iou_threshold, score_threshold, + overlaps, scores, max_output_size, overlap_threshold, score_threshold, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// NonMaxSuppressionV4Attr is an optional argument to NonMaxSuppressionV4. -type NonMaxSuppressionV4Attr func(optionalAttr) - -// NonMaxSuppressionV4PadToMaxOutputSize sets the optional pad_to_max_output_size attribute to value. -// -// value: If true, the output `selected_indices` is padded to be of length -// `max_output_size`. Defaults to false. -// If not specified, defaults to false -func NonMaxSuppressionV4PadToMaxOutputSize(value bool) NonMaxSuppressionV4Attr { - return func(m optionalAttr) { - m["pad_to_max_output_size"] = value - } -} - -// Greedily selects a subset of bounding boxes in descending order of score, -// -// pruning away boxes that have high intersection-over-union (IOU) overlap -// with previously selected boxes. Bounding boxes with score less than -// `score_threshold` are removed. 
Bounding boxes are supplied as -// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any -// diagonal pair of box corners and the coordinates can be provided as normalized -// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm -// is agnostic to where the origin is in the coordinate system and more -// generally is invariant to orthogonal transformations and translations -// of the coordinate system; thus translating or reflections of the coordinate -// system result in the same boxes being selected by the algorithm. -// The output of this operation is a set of integers indexing into the input -// collection of bounding boxes representing the selected boxes. The bounding -// box coordinates corresponding to the selected indices can then be obtained -// using the `tf.gather operation`. For example: -// selected_indices = tf.image.non_max_suppression_v2( -// boxes, scores, max_output_size, iou_threshold, score_threshold) -// selected_boxes = tf.gather(boxes, selected_indices) -// -// Arguments: -// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. -// scores: A 1-D float tensor of shape `[num_boxes]` representing a single -// score corresponding to each box (each row of boxes). -// max_output_size: A scalar integer tensor representing the maximum number of -// boxes to be selected by non max suppression. -// iou_threshold: A 0-D float tensor representing the threshold for deciding whether -// boxes overlap too much with respect to IOU. -// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove -// boxes based on score. -// -// Returns A 1-D integer tensor of shape `[M]` representing the selected -// indices from the boxes tensor, where `M <= max_output_size`.A 0-D integer tensor representing the number of valid elements in -// `selected_indices`, with the valid elements appearing first. -func NonMaxSuppressionV4(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output, optional ...NonMaxSuppressionV4Attr) (selected_indices tf.Output, valid_outputs tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "NonMaxSuppressionV4", - Input: []tf.Input{ - boxes, scores, max_output_size, iou_threshold, score_threshold, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Removes keys and its associated values from a table. -// -// The tensor `keys` must of the same type as the keys of the table. Keys not -// already in the table are silently ignored. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys of the elements to remove. -// -// Returns the created operation. -func LookupTableRemoveV2(scope *Scope, table_handle tf.Output, keys tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableRemoveV2", - Input: []tf.Input{ - table_handle, keys, - }, - } - return scope.AddOperation(opspec) -} - // CombinedNonMaxSuppressionAttr is an optional argument to CombinedNonMaxSuppression. type CombinedNonMaxSuppressionAttr func(optionalAttr) @@ -37536,6 +39212,18 @@ func CombinedNonMaxSuppressionPadPerClass(value bool) CombinedNonMaxSuppressionA } } +// CombinedNonMaxSuppressionClipBoxes sets the optional clip_boxes attribute to value. 
+// +// value: If true, assume the box coordinates are between [0, 1] and clip the output boxes +// if they fall beyond [0, 1]. If false, do not do clipping and output the box +// coordinates as it is. +// If not specified, defaults to true +func CombinedNonMaxSuppressionClipBoxes(value bool) CombinedNonMaxSuppressionAttr { + return func(m optionalAttr) { + m["clip_boxes"] = value + } +} + // Greedily selects a subset of bounding boxes in descending order of score, // // This operation performs non_max_suppression on the inputs per batch, across @@ -37592,40 +39280,160 @@ func CombinedNonMaxSuppression(scope *Scope, boxes tf.Output, scores tf.Output, return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Computes the matrix logarithm of one or more square matrices: +// UnbatchAttr is an optional argument to Unbatch. +type UnbatchAttr func(optionalAttr) + +// UnbatchContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnbatchContainer(value string) UnbatchAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// UnbatchSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnbatchSharedName(value string) UnbatchAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Reverses the operation of Batch for a single output Tensor. // +// An instance of Unbatch either receives an empty batched_tensor, in which case it +// asynchronously waits until the values become available from a concurrently +// running instance of Unbatch with the same container and shared_name, or receives +// a non-empty batched_tensor in which case it finalizes all other concurrently +// running instances and outputs its own element from the batch. // -// \\(log(exp(A)) = A\\) +// batched_tensor: The possibly transformed output of Batch. The size of the first +// dimension should remain unchanged by the transformations for the operation to +// work. +// batch_index: The matching batch_index obtained from Batch. +// id: The id scalar emitted by Batch. +// unbatched_tensor: The Tensor corresponding to this execution. +// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the +// batched input tensor associated with a given invocation of the op. +// container: Container to control resource sharing. +// shared_name: Instances of Unbatch with the same container and shared_name are +// assumed to possibly belong to the same batch. If left empty, the op name will +// be used as the shared name. +func Unbatch(scope *Scope, batched_tensor tf.Output, batch_index tf.Output, id tf.Output, timeout_micros int64, optional ...UnbatchAttr) (unbatched_tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"timeout_micros": timeout_micros} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Unbatch", + Input: []tf.Input{ + batched_tensor, batch_index, id, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MatrixInverseAttr is an optional argument to MatrixInverse. +type MatrixInverseAttr func(optionalAttr) + +// MatrixInverseAdjoint sets the optional adjoint attribute to value. +// If not specified, defaults to false +func MatrixInverseAdjoint(value bool) MatrixInverseAttr { + return func(m optionalAttr) { + m["adjoint"] = value + } +} + +// Computes the inverse of one or more square invertible matrices or their // -// This op is only defined for complex matrices. 
If A is positive-definite and -// real, then casting to a complex matrix, taking the logarithm and casting back -// to a real matrix will give the correct result. -// -// This function computes the matrix logarithm using the Schur-Parlett algorithm. -// Details of the algorithm can be found in Section 11.6.2 of: -// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008. -// ISBN 978-0-898716-46-7. +// adjoints (conjugate transposes). // // The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions // form square matrices. The output is a tensor of the same shape as the input -// containing the exponential for all input submatrices `[..., :, :]`. +// containing the inverse for all input submatrices `[..., :, :]`. +// +// The op uses LU decomposition with partial pivoting to compute the inverses. +// +// If a matrix is not invertible there is no guarantee what the op does. It +// may detect the condition and raise an exception or it may simply return a +// garbage result. // // Arguments: // input: Shape is `[..., M, M]`. // // Returns Shape is `[..., M, M]`. // -// @compatibility(scipy) -// Equivalent to scipy.linalg.logm +// @compatibility(numpy) +// Equivalent to np.linalg.inv // @end_compatibility -func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) { +func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatrixInverse", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. +// +// N is the size of the segment being reduced. +// +// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is +// misisng, the `output` tensor at that position will be zeroed. +// +// Read +// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +// for an explanation of segments. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSqrtNWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixLogarithm", + Type: "SparseSegmentSqrtNWithNumSegments", Input: []tf.Input{ - input, + data, indices, segment_ids, num_segments, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Use TensorArrayGradV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 +func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArrayWriteV2", + Input: []tf.Input{ + handle, index, value, flow_in, }, } op := scope.AddOperation(opspec) @@ -37657,127 +39465,37 @@ func FakeParam(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Outpu return op.Output(0) } -// Returns the next representable value of `x1` in the direction of `x2`, element-wise. 
-// -// This operation returns the same result as the C++ std::nextafter function. -// -// It can also return a subnormal number. -// -// @compatibility(cpp) -// Equivalent to C++ std::nextafter function. -// @end_compatibility -func NextAfter(scope *Scope, x1 tf.Output, x2 tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "NextAfter", - Input: []tf.Input{ - x1, x2, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} +// NthElementAttr is an optional argument to NthElement. +type NthElementAttr func(optionalAttr) -// OrderedMapStageAttr is an optional argument to OrderedMapStage. -type OrderedMapStageAttr func(optionalAttr) - -// OrderedMapStageCapacity sets the optional capacity attribute to value. +// NthElementReverse sets the optional reverse attribute to value. // -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapStageCapacity(value int64) OrderedMapStageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// OrderedMapStageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapStageMemoryLimit(value int64) OrderedMapStageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapStageContainer sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. -// If not specified, defaults to "" -func OrderedMapStageContainer(value string) OrderedMapStageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapStageSharedName sets the optional shared_name attribute to value. -// -// value: It is necessary to match this name to the matching Unstage Op. -// If not specified, defaults to "" -func OrderedMapStageSharedName(value string) OrderedMapStageAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Stage (key, values) in the underlying container which behaves like a ordered -// -// associative container. Elements are ordered by key. -// -// Arguments: -// key: int64 -// -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. -// -// -// Returns the created operation. -func OrderedMapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...OrderedMapStageAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OrderedMapStage", - Input: []tf.Input{ - key, indices, tf.OutputList(values), - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// StackPushV2Attr is an optional argument to StackPushV2. -type StackPushV2Attr func(optionalAttr) - -// StackPushV2SwapMemory sets the optional swap_memory attribute to value. -// -// value: Swap `elem` to CPU. Default to false. +// value: When set to True, find the nth-largest value in the vector and vice +// versa. // If not specified, defaults to false -func StackPushV2SwapMemory(value bool) StackPushV2Attr { +func NthElementReverse(value bool) NthElementAttr { return func(m optionalAttr) { - m["swap_memory"] = value + m["reverse"] = value } } -// Push an element onto the stack. 
+// Finds values of the `n`-th order statistic for the last dimension. +// +// If the input is a vector (rank-1), finds the entries which is the nth-smallest +// value in the vector and outputs their values as scalar tensor. +// +// For matrices (resp. higher rank input), computes the entries which is the +// nth-smallest value in each row (resp. vector along the last dimension). Thus, +// +// values.shape = input.shape[:-1] // // Arguments: -// handle: The handle to a stack. -// elem: The tensor to be pushed onto the stack. +// input: 1-D or higher with last dimension at least `n+1`. +// n: 0-D. Position of sorted vector to select along the last dimension (along +// each row for matrices). Valid range of n is `[0, input.shape[:-1])` // -// Returns The same tensor as the input 'elem'. -func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...StackPushV2Attr) (output tf.Output) { +// Returns The `n`-th order statistic along each last dimensional slice. +func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) { if scope.Err() != nil { return } @@ -37786,9 +39504,9 @@ func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...Sta a(attrs) } opspec := tf.OpSpec{ - Type: "StackPushV2", + Type: "NthElement", Input: []tf.Input{ - handle, elem, + input, n, }, Attrs: attrs, } @@ -37796,118 +39514,31 @@ func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...Sta return op.Output(0) } -// RpcAttr is an optional argument to Rpc. -type RpcAttr func(optionalAttr) - -// RpcProtocol sets the optional protocol attribute to value. +// Creates a dataset that shards the input dataset. // -// value: RPC protocol to use. Empty string means use the default protocol. -// Options include 'grpc'. -// If not specified, defaults to "" -func RpcProtocol(value string) RpcAttr { - return func(m optionalAttr) { - m["protocol"] = value - } -} - -// RpcFailFast sets the optional fail_fast attribute to value. +// Creates a dataset that shards the input dataset by num_workers, returning a +// sharded dataset for the index-th worker. This attempts to automatically shard +// a dataset by examining the Dataset graph and inserting a shard op before the +// inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset). // -// value: `boolean`. If `true` (default), then failures to connect -// (i.e., the server does not immediately respond) cause an RPC failure. -// If not specified, defaults to true -func RpcFailFast(value bool) RpcAttr { - return func(m optionalAttr) { - m["fail_fast"] = value - } -} - -// RpcTimeoutInMs sets the optional timeout_in_ms attribute to value. -// -// value: `int`. If `0` (default), then the kernel will run the RPC -// request and only time out if the RPC deadline passes or the session times out. -// If this value is greater than `0`, then the op will raise an exception if -// the RPC takes longer than `timeout_in_ms`. -// If not specified, defaults to 0 -func RpcTimeoutInMs(value int64) RpcAttr { - return func(m optionalAttr) { - m["timeout_in_ms"] = value - } -} - -// Perform batches of RPC requests. -// -// This op asynchronously performs either a single RPC request, or a batch -// of requests. RPC requests are defined by three main parameters: -// -// - `address` (the host+port or BNS address of the request) -// - `method` (the RPC method name for the request) -// - `request` (the serialized proto string, or vector of strings, -// of the RPC request argument). 
-// -// For example, if you have an RPC service running on port localhost:2345, -// and its interface is configured with the following proto declaration: -// -// ``` -// service MyService { -// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { -// } -// }; -// ``` -// -// then call this op with arguments: -// -// ``` -// address = "localhost:2345" -// method = "MyService/MyMethod" -// ``` -// -// The `request` tensor is a string tensor representing serialized `MyRequestProto` -// strings; and the output string tensor `response` will have the same shape -// and contain (upon successful completion) corresponding serialized -// `MyResponseProto` strings. -// -// For example, to send a single, empty, `MyRequestProto`, call -// this op with `request = ""`. To send 5 **parallel** empty requests, -// call this op with `request = ["", "", "", "", ""]`. -// -// More generally, one can create a batch of `MyRequestProto` serialized protos -// from regular batched tensors using the `encode_proto` op, and convert -// the response `MyResponseProto` serialized protos to batched tensors -// using the `decode_proto` op. -// -// **NOTE** Working with serialized proto strings is faster than instantiating -// actual proto objects in memory, so no performance degradation is expected -// compared to writing custom kernels for this workflow. -// -// If the connection fails or the remote worker returns an error -// status, the op reraises this exception locally. -// -// See the `TryRpc` op if you prefer to handle RPC failures manually in the graph. +// This dataset will throw a NotFound error if we cannot shard the dataset +// automatically. // // Arguments: -// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `method` and `request`. -// method: `0-D` or `1-D`. The method address on the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `request`. -// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `method`. +// input_dataset: A variant tensor representing the input dataset. +// num_workers: A scalar representing the number of workers to distribute this dataset across. +// index: A scalar representing the index of the current worker out of num_workers. // -// Returns Same shape as `request`. Serialized proto strings: the rpc responses. 
-func Rpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...RpcAttr) (response tf.Output) { +// +func ExperimentalAutoShardDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Rpc", + Type: "ExperimentalAutoShardDataset", Input: []tf.Input{ - address, method, request, + input_dataset, num_workers, index, }, Attrs: attrs, } @@ -37932,6 +39563,230 @@ func ExperimentalBytesProducedStatsDataset(scope *Scope, input_dataset tf.Output return op.Output(0) } +// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. +type MatrixSolveLsAttr func(optionalAttr) + +// MatrixSolveLsFast sets the optional fast attribute to value. +// If not specified, defaults to true +func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { + return func(m optionalAttr) { + m["fast"] = value + } +} + +// Solves one or more linear least-squares problems. +// +// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same +// type as `matrix` and shape `[..., M, K]`. +// The output is a tensor shape `[..., N, K]` where each output matrix solves +// each of the equations +// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` +// in the least squares sense. +// +// We use the following notation for (complex) matrix and right-hand sides +// in the batch: +// +// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), +// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), +// `output`=\\(X \in \mathbb{C}^{n \times k}\\), +// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). +// +// If `fast` is `True`, then the solution is computed by solving the normal +// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares +// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). +// If \\(m \lt n\\) then `output` is computed as +// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +// minimum-norm solution to the under-determined linear system, i.e. +// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), +// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable +// when \\(A\\) is numerically full rank and has a condition number +// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is +// sufficiently large. +// +// If `fast` is `False` an algorithm based on the numerically robust complete +// orthogonal decomposition is used. This computes the minimum-norm +// least-squares solution, even when \\(A\\) is rank deficient. This path is +// typically 6-7 times slower than the fast path. If `fast` is `False` then +// `l2_regularizer` is ignored. +// +// Arguments: +// matrix: Shape is `[..., M, N]`. +// rhs: Shape is `[..., M, K]`. +// l2_regularizer: Scalar tensor. +// +// @compatibility(numpy) +// Equivalent to np.linalg.lstsq +// @end_compatibility +// +// Returns Shape is `[..., N, K]`. 
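+//
+// Editorial sketch (not generated output): a minimal Go call of this wrapper,
+// assuming the package's usual `op.NewScope`/`op.Const` helpers and purely
+// illustrative values. Here `a` is 3x2, `b` is 3x1, and `l2_regularizer` is a
+// scalar double, so the solution `x` has shape [2, 1].
+//
+// ```
+// s := op.NewScope()
+// a := op.Const(s, [][]float64{{1, 2}, {3, 4}, {5, 6}}) // shape [3, 2]
+// b := op.Const(s, [][]float64{{7}, {8}, {9}})          // shape [3, 1]
+// x := op.MatrixSolveLs(s, a, b, op.Const(s, float64(0)))
+// ```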
+func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatrixSolveLs", + Input: []tf.Input{ + matrix, rhs, l2_regularizer, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Scatter `updates` into a new tensor according to `indices`. +// +// Creates a new tensor by applying sparse `updates` to individual values or +// slices within a tensor (initially zero for numeric, empty for string) of +// the given `shape` according to indices. This operator is the inverse of the +// `tf.gather_nd` operator which extracts values or slices from a given tensor. +// +// This operation is similar to tensor_scatter_add, except that the tensor is +// zero-initialized. Calling `tf.scatter_nd(indices, values, shape)` is identical +// to `tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values)` +// +// If `indices` contains duplicates, then their updates are accumulated (summed). +// +// **WARNING**: The order in which updates are applied is nondeterministic, so the +// output will be nondeterministic if `indices` contains duplicates -- because +// of some numerical approximation issues, numbers summed in different order +// may yield different results. +// +// `indices` is an integer tensor containing indices into a new tensor of shape +// `shape`. The last dimension of `indices` can be at most the rank of `shape`: +// +// indices.shape[-1] <= shape.rank +// +// The last dimension of `indices` corresponds to indices into elements +// (if `indices.shape[-1] = shape.rank`) or slices +// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of +// `shape`. `updates` is a tensor with shape +// +// indices.shape[:-1] + shape[indices.shape[-1]:] +// +// The simplest form of scatter is to insert individual elements in a tensor by +// index. For example, say we want to insert 4 scattered elements in a rank-1 +// tensor with 8 elements. +// +//
+// [figure omitted: diagram of scattering four individual elements into a rank-1 tensor of length 8]
+// +// In Python, this scatter operation would look like this: +// +// ```python +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// shape = tf.constant([8]) +// scatter = tf.scatter_nd(indices, updates, shape) +// with tf.Session() as sess: +// print(sess.run(scatter)) +// ``` +// +// The resulting tensor would look like this: +// +// [0, 11, 0, 10, 9, 0, 0, 12] +// +// We can also, insert entire slices of a higher rank tensor all at once. For +// example, if we wanted to insert two slices in the first dimension of a +// rank-3 tensor with two matrices of new values. +// +//
+// [figure omitted: diagram of scattering two matrix slices into the first dimension of a rank-3 tensor]
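+//
+// Editorial sketch (not generated output): the same slice insertion expressed
+// through this Go wrapper, assuming the package's usual `op.NewScope`/`op.Const`
+// helpers. `indices` has shape [2, 1], `updates` has shape [2, 4, 4], and the
+// result has shape [4, 4, 4].
+//
+// ```
+// s := op.NewScope()
+// indices := op.Const(s, [][]int32{{0}, {2}})
+// updates := op.Const(s, [][][]int32{
+//     {{5, 5, 5, 5}, {6, 6, 6, 6}, {7, 7, 7, 7}, {8, 8, 8, 8}},
+//     {{5, 5, 5, 5}, {6, 6, 6, 6}, {7, 7, 7, 7}, {8, 8, 8, 8}},
+// })
+// shape := op.Const(s, []int32{4, 4, 4})
+// scatter := op.ScatterNd(s, indices, updates, shape)
+// ```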
+// +// In Python, this scatter operation would look like this: +// +// ```python +// indices = tf.constant([[0], [2]]) +// updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], +// [7, 7, 7, 7], [8, 8, 8, 8]], +// [[5, 5, 5, 5], [6, 6, 6, 6], +// [7, 7, 7, 7], [8, 8, 8, 8]]]) +// shape = tf.constant([4, 4, 4]) +// scatter = tf.scatter_nd(indices, updates, shape) +// with tf.Session() as sess: +// print(sess.run(scatter)) +// ``` +// +// The resulting tensor would look like this: +// +// [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], +// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], +// [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], +// [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] +// +// Note that on CPU, if an out of bound index is found, an error is returned. +// On GPU, if an out of bound index is found, the index is ignored. +// +// Arguments: +// indices: Index tensor. +// updates: Updates to scatter into output. +// shape: 1-D. The shape of the resulting tensor. +// +// Returns A new tensor with the given shape and updates applied according +// to the indices. +func ScatterNd(scope *Scope, indices tf.Output, updates tf.Output, shape tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ScatterNd", + Input: []tf.Input{ + indices, updates, shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns (x - y)(x - y) element-wise. +// +// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SquaredDifference", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the cardinality of `input_dataset`. +// +// Returns the cardinality of `input_dataset`. +// +// Arguments: +// input_dataset: A variant tensor representing the dataset to return cardinality for. +// +// Returns The cardinality of `input_dataset`. Named constants are used to represent +// infinite and unknown cardinality. +func ExperimentalDatasetCardinality(scope *Scope, input_dataset tf.Output) (cardinality tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ExperimentalDatasetCardinality", + Input: []tf.Input{ + input_dataset, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // A substitute for `InterleaveDataset` on a fixed list of `N` datasets. // // Arguments: @@ -37957,59 +39812,66 @@ func ExperimentalDirectedInterleaveDataset(scope *Scope, selector_input_dataset return op.Output(0) } -// RandomUniformIntAttr is an optional argument to RandomUniformInt. -type RandomUniformIntAttr func(optionalAttr) - -// RandomUniformIntSeed sets the optional seed attribute to value. +// Bitcasts a tensor from one type to another without copying data. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomUniformIntSeed(value int64) RandomUniformIntAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomUniformIntSeed2 sets the optional seed2 attribute to value. 
+// Given a tensor `input`, this operation returns a tensor that has the same buffer +// data as `input` with datatype `type`. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomUniformIntSeed2(value int64) RandomUniformIntAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Outputs random integers from a uniform distribution. +// If the input datatype `T` is larger than the output datatype `type` then the +// shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)]. // -// The generated values are uniform integers in the range `[minval, maxval)`. -// The lower bound `minval` is included in the range, while the upper bound -// `maxval` is excluded. +// If `T` is smaller than `type`, the operator requires that the rightmost +// dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from +// [..., sizeof(`type`)/sizeof(`T`)] to [...]. // -// The random integers are slightly biased unless `maxval - minval` is an exact -// power of two. The bias is small for values of `maxval - minval` significantly -// smaller than the range of the output (either `2^32` or `2^64`). +// tf.bitcast() and tf.cast() work differently when real dtype is casted as a complex dtype +// (e.g. tf.complex64 or tf.complex128) as tf.cast() make imaginary part 0 while tf.bitcast() +// gives module error. +// For example, // -// Arguments: -// shape: The shape of the output tensor. -// minval: 0-D. Inclusive lower bound on the generated integers. -// maxval: 0-D. Exclusive upper bound on the generated integers. +// Example 1: +// ```python +// >>> a = [1., 2., 3.] +// >>> equality_bitcast = tf.bitcast(a,tf.complex128) +// tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast] +// >>> equality_cast = tf.cast(a,tf.complex128) +// >>> print(equality_cast) +// tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128) +// ``` +// Example 2: +// ```python +// >>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8) +// +// ``` +// Example 3: +// ```python +// >>> x = [1., 2., 3.] +// >>> y = [0., 2., 3.] +// >>> equality= tf.equal(x,y) +// >>> equality_cast = tf.cast(equality,tf.float32) +// >>> equality_bitcast = tf.bitcast(equality_cast,tf.uint8) +// >>> print(equality) +// tf.Tensor([False True True], shape=(3,), dtype=bool) +// >>> print(equality_cast) +// tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32) +// >>> print(equality_bitcast) +// tf.Tensor( +// [[ 0 0 0 0] +// [ 0 0 128 63] +// [ 0 0 128 63]], shape=(3, 4), dtype=uint8) +// ``` // -// Returns A tensor of the specified shape filled with uniform random integers. -func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) { +// *NOTE*: Bitcast is implemented as a low-level cast, so machines with different +// endian orderings will give different results. 
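+//
+// Editorial sketch (not generated output): calling this wrapper from Go,
+// assuming the package's usual `op.NewScope`/`op.Const` helpers. float32 and
+// int32 have the same width, so the output keeps the input shape and only the
+// bit pattern is reinterpreted.
+//
+// ```
+// s := op.NewScope()
+// x := op.Const(s, []float32{1, 2, 3})
+// bits := op.Bitcast(s, x, tf.Int32) // shape [3], dtype int32
+// ```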
+func Bitcast(scope *Scope, input tf.Output, type_ tf.DataType) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"type": type_} opspec := tf.OpSpec{ - Type: "RandomUniformInt", + Type: "Bitcast", Input: []tf.Input{ - shape, minval, maxval, + input, }, Attrs: attrs, } @@ -38017,48 +39879,230 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf return op.Output(0) } -// Add the quantile summaries to each quantile stream resource. +// Returns x / y element-wise for integer types. // -// An op that adds a list of quantile summaries to a quantile stream resource. Each -// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank) -// for a single feature. +// Truncation designates that negative numbers will round fractional quantities +// toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different +// than Python semantics. See `FloorDiv` for a division function that matches +// Python Semantics. // -// Arguments: -// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource. -// summaries: string; List of Rank 2 Tensor each containing the summaries for a single feature. -// -// Returns the created operation. -func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) { +// *NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func TruncateDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesQuantileStreamResourceAddSummaries", + Type: "TruncateDiv", Input: []tf.Input{ - quantile_stream_resource_handle, tf.OutputList(summaries), + x, y, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Creates a Dataset that returns pseudorandom numbers. +// ReverseSequenceAttr is an optional argument to ReverseSequence. +type ReverseSequenceAttr func(optionalAttr) + +// ReverseSequenceBatchDim sets the optional batch_dim attribute to value. +// +// value: The dimension along which reversal is performed. +// If not specified, defaults to 0 +func ReverseSequenceBatchDim(value int64) ReverseSequenceAttr { + return func(m optionalAttr) { + m["batch_dim"] = value + } +} + +// Reverses variable length slices. +// +// This op first slices `input` along the dimension `batch_dim`, and for each +// slice `i`, reverses the first `seq_lengths[i]` elements along +// the dimension `seq_dim`. +// +// The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`, +// and `seq_lengths` must be a vector of length `input.dims[batch_dim]`. +// +// The output slice `i` along dimension `batch_dim` is then given by input +// slice `i`, with the first `seq_lengths[i]` slices along dimension +// `seq_dim` reversed. +// +// For example: +// +// ``` +// # Given this: +// batch_dim = 0 +// seq_dim = 1 +// input.dims = (4, 8, ...) +// seq_lengths = [7, 2, 3, 5] +// +// # then slices of input are reversed on seq_dim, but only up to seq_lengths: +// output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...] +// output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...] +// output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...] +// output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...] 
+// +// # while entries past seq_lens are copied through: +// output[0, 7:, :, ...] = input[0, 7:, :, ...] +// output[1, 2:, :, ...] = input[1, 2:, :, ...] +// output[2, 3:, :, ...] = input[2, 3:, :, ...] +// output[3, 2:, :, ...] = input[3, 2:, :, ...] +// ``` +// +// In contrast, if: +// +// ``` +// # Given this: +// batch_dim = 2 +// seq_dim = 0 +// input.dims = (8, ?, 4, ...) +// seq_lengths = [7, 2, 3, 5] +// +// # then slices of input are reversed on seq_dim, but only up to seq_lengths: +// output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...] +// output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...] +// output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...] +// output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...] +// +// # while entries past seq_lens are copied through: +// output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...] +// output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...] +// output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...] +// output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...] +// ``` // // Arguments: -// seed: A scalar seed for the random number generator. If either seed or -// seed2 is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. +// input: The input to reverse. +// seq_lengths: 1-D with length `input.dims(batch_dim)` and +// `max(seq_lengths) <= input.dims(seq_dim)` +// seq_dim: The dimension which is partially reversed. // -// -func ExperimentalRandomDataset(scope *Scope, seed tf.Output, seed2 tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +// Returns The partially reversed input. It has the same shape as `input`. +func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_dim int64, optional ...ReverseSequenceAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"seq_dim": seq_dim} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ReverseSequence", + Input: []tf.Input{ + input, seq_lengths, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that contains the elements of `input_dataset` ignoring errors. +func ExperimentalIgnoreErrorsDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ExperimentalRandomDataset", + Type: "ExperimentalIgnoreErrorsDataset", Input: []tf.Input{ - seed, seed2, + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AudioSummaryAttr is an optional argument to AudioSummary. +type AudioSummaryAttr func(optionalAttr) + +// AudioSummaryMaxOutputs sets the optional max_outputs attribute to value. +// +// value: Max number of batch elements to generate audio for. +// If not specified, defaults to 3 +// +// REQUIRES: value >= 1 +func AudioSummaryMaxOutputs(value int64) AudioSummaryAttr { + return func(m optionalAttr) { + m["max_outputs"] = value + } +} + +// Outputs a `Summary` protocol buffer with audio. +// +// DEPRECATED at GraphDef version 15: Use AudioSummaryV2. +// +// The summary has up to `max_outputs` summary values containing audio. 
The +// audio is built from `tensor` which must be 3-D with shape `[batch_size, +// frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +// assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. +// +// The `tag` argument is a scalar `Tensor` of type `string`. It is used to +// build the `tag` of the summary values: +// +// * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. +// * If `max_outputs` is greater than 1, the summary value tags are +// generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. +// +// Arguments: +// tag: Scalar. Used to build the `tag` attribute of the summary values. +// tensor: 2-D of shape `[batch_size, frames]`. +// sample_rate: The sample rate of the signal in hertz. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate float32, optional ...AudioSummaryAttr) (summary tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"sample_rate": sample_rate} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AudioSummary", + Input: []tf.Input{ + tag, tensor, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes softsign: `features / (abs(features) + 1)`. +func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Softsign", + Input: []tf.Input{ + features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that executes a SQL query and emits rows of the result set. +// +// Arguments: +// driver_name: The database type. Currently, the only supported type is 'sqlite'. +// data_source_name: A connection string to connect to the database. +// query: A SQL query to execute. +// +// +func ExperimentalSqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ExperimentalSqlDataset", + Input: []tf.Input{ + driver_name, data_source_name, query, }, Attrs: attrs, } @@ -38089,84 +40133,22 @@ func ExperimentalMaxIntraOpParallelismDataset(scope *Scope, input_dataset tf.Out return op.Output(0) } -// StringSplitV2Attr is an optional argument to StringSplitV2. -type StringSplitV2Attr func(optionalAttr) - -// StringSplitV2Maxsplit sets the optional maxsplit attribute to value. -// -// value: An `int`. If `maxsplit > 0`, limit of the split of the result. -// If not specified, defaults to -1 -func StringSplitV2Maxsplit(value int64) StringSplitV2Attr { - return func(m optionalAttr) { - m["maxsplit"] = value - } -} - -// Split elements of `source` based on `sep` into a `SparseTensor`. -// -// Let N be the size of source (typically N will be the batch size). Split each -// element of `source` based on `sep` and return a `SparseTensor` -// containing the split tokens. Empty tokens are ignored. 
-// -// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', -// then the output will be -// ``` -// st.indices = [0, 0; -// 0, 1; -// 1, 0; -// 1, 1; -// 1, 2] -// st.shape = [2, 3] -// st.values = ['hello', 'world', 'a', 'b', 'c'] -// ``` -// -// If `sep` is given, consecutive delimiters are not grouped together and are -// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and -// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty -// string, consecutive whitespace are regarded as a single separator, and the -// result will contain no empty strings at the startor end if the string has -// leading or trailing whitespace. -// -// Note that the above mentioned behavior matches python's str.split. -// -// Arguments: -// input: `1-D` string `Tensor`, the strings to split. -// sep: `0-D` string `Tensor`, the delimiter character. -func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StringSplitV2", - Input: []tf.Input{ - input, sep, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // Creates a dataset that uses a custom thread pool to compute `input_dataset`. // // Arguments: // -// thread_pool: A resource produced by the ThreadPoolHandle op. +// num_threads: Identifies the number of threads to use for the private threadpool. // // -func ExperimentalThreadPoolDataset(scope *Scope, input_dataset tf.Output, thread_pool tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { +func ExperimentalPrivateThreadPoolDataset(scope *Scope, input_dataset tf.Output, num_threads tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "ExperimentalThreadPoolDataset", + Type: "ExperimentalPrivateThreadPoolDataset", Input: []tf.Input{ - input_dataset, thread_pool, + input_dataset, num_threads, }, Attrs: attrs, } @@ -38174,110 +40156,6 @@ func ExperimentalThreadPoolDataset(scope *Scope, input_dataset tf.Output, thread return op.Output(0) } -// Computes softsign: `features / (abs(features) + 1)`. -func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Softsign", - Input: []tf.Input{ - features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// EncodeProtoAttr is an optional argument to EncodeProto. -type EncodeProtoAttr func(optionalAttr) - -// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value. -// If not specified, defaults to "local://" -func EncodeProtoDescriptorSource(value string) EncodeProtoAttr { - return func(m optionalAttr) { - m["descriptor_source"] = value - } -} - -// The op serializes protobuf messages provided in the input tensors. -// -// The types of the tensors in `values` must match the schema for the -// fields specified in `field_names`. All the tensors in `values` must -// have a common shape prefix, *batch_shape*. -// -// The `sizes` tensor specifies repeat counts for each field. 
The repeat -// count (last dimension) of a each tensor in `values` must be greater -// than or equal to corresponding repeat count in `sizes`. -// -// A `message_type` name must be provided to give context for the field -// names. The actual message descriptor can be looked up either in the -// linked-in descriptor pool or a filename provided by the caller using -// the `descriptor_source` attribute. -// -// The `descriptor_source` attribute selects a source of protocol -// descriptors to consult when looking up `message_type`. This may be a -// filename containing a serialized `FileDescriptorSet` message, -// or the special value `local://`, in which case only descriptors linked -// into the code will be searched; the filename can be on any filesystem -// accessible to TensorFlow. -// -// You can build a `descriptor_source` file using the `--descriptor_set_out` -// and `--include_imports` options to the protocol compiler `protoc`. -// -// The `local://` database only covers descriptors linked into the -// code via C++ libraries, not Python imports. You can link in a proto descriptor -// by creating a cc_library target with alwayslink=1. -// -// There are a few special cases in the value mapping: -// -// Submessage and group fields must be pre-serialized as TensorFlow strings. -// -// TensorFlow lacks support for unsigned int64s, so they must be -// represented as `tf.int64` with the same twos-complement bit pattern -// (the obvious way). -// -// Unsigned int32 values can be represented exactly with `tf.int64`, or -// with sign wrapping if the input is of type `tf.int32`. -// -// Arguments: -// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`. -// values: List of tensors containing values for the corresponding field. -// field_names: List of strings containing proto field names. -// message_type: Name of the proto message type to decode. -// -// Returns Tensor of serialized protos with shape `batch_shape`. -func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodeProto", - Input: []tf.Input{ - sizes, tf.OutputList(values), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates an Optional variant with no value. -func OptionalNone(scope *Scope) (optional tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "OptionalNone", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // DecodeProtoV2Attr is an optional argument to DecodeProtoV2. type DecodeProtoV2Attr func(optionalAttr) @@ -38405,6 +40283,85 @@ func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_nam return sizes, values } +// StringSplitAttr is an optional argument to StringSplit. +type StringSplitAttr func(optionalAttr) + +// StringSplitSkipEmpty sets the optional skip_empty attribute to value. +// +// value: A `bool`. If `True`, skip the empty strings from the result. +// If not specified, defaults to true +func StringSplitSkipEmpty(value bool) StringSplitAttr { + return func(m optionalAttr) { + m["skip_empty"] = value + } +} + +// Split elements of `input` based on `delimiter` into a `SparseTensor`. +// +// Let N be the size of source (typically N will be the batch size). 
Split each +// element of `input` based on `delimiter` and return a `SparseTensor` +// containing the splitted tokens. Empty tokens are ignored. +// +// `delimiter` can be empty, or a string of split characters. If `delimiter` is an +// empty string, each element of `input` is split into individual single-byte +// character strings, including splitting of UTF-8 multibyte sequences. Otherwise +// every character of `delimiter` is a potential split point. +// +// For example: +// N = 2, input[0] is 'hello world' and input[1] is 'a b c', then the output +// will be +// +// indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// shape = [2, 3] +// values = ['hello', 'world', 'a', 'b', 'c'] +// +// Arguments: +// input: 1-D. Strings to split. +// delimiter: 0-D. Delimiter characters (bytes), or empty string. +// +// Returns A dense matrix of int64 representing the indices of the sparse tensor.A vector of strings corresponding to the splited values.a length-2 vector of int64 representing the shape of the sparse +// tensor, where the first value is N and the second value is the maximum number +// of tokens in a single input entry. +func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ...StringSplitAttr) (indices tf.Output, values tf.Output, shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StringSplit", + Input: []tf.Input{ + input, delimiter, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Creates a dataset that emits each dim-0 slice of `components` once. +func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "TensorSliceDataset", + Input: []tf.Input{ + tf.OutputList(components), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Creates a dataset that splits a SparseTensor into elements row-wise. func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) { if scope.Err() != nil { @@ -38420,26 +40377,6 @@ func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, return op.Output(0) } -// Returns x / y element-wise for real types. -// -// If `x` and `y` are reals, this will return the floating-point division. -// -// *NOTE*: `Div` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RealDiv", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Creates a dataset that concatenates `input_dataset` with `another_dataset`. func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { @@ -38457,51 +40394,23 @@ func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset t return op.Output(0) } -// Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. 
-// -// The `input` tensor has shape `[batch, in_height, in_width, depth]` and the -// `filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each -// input channel is processed independently of the others with its own structuring -// function. The `output` tensor has shape -// `[batch, out_height, out_width, depth]`. The spatial dimensions of the output -// tensor depend on the `padding` algorithm. We currently only support the default -// "NHWC" `data_format`. -// -// In detail, the grayscale morphological 2-D dilation is the max-sum correlation -// (for consistency with `conv2d`, we use unmirrored filters): -// -// output[b, y, x, c] = -// max_{dy, dx} input[b, -// strides[1] * y + rates[1] * dy, -// strides[2] * x + rates[2] * dx, -// c] + -// filter[dy, dx, c] -// -// Max-pooling is a special case when the filter has size equal to the pooling -// kernel size and contains all zeros. -// -// Note on duality: The dilation of `input` by the `filter` is equal to the -// negation of the erosion of `-input` by the reflected `filter`. +// Creates a dataset that emits the outputs of `input_dataset` `count` times. // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// strides: The stride of the sliding window for each dimension of the input -// tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: The input stride for atrous morphological dilation. Must be: -// `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. // -// Returns 4-D with shape `[batch, out_height, out_width, depth]`. -func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, rates []int64, padding string) (output tf.Output) { +// count: A scalar representing the number of times that `input_dataset` should +// be repeated. A value of `-1` indicates that it should be repeated infinitely. +// +// +func RepeatDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Dilation2D", + Type: "RepeatDataset", Input: []tf.Input{ - input, filter, + input_dataset, count, }, Attrs: attrs, } @@ -38509,465 +40418,91 @@ func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64 return op.Output(0) } -// Converts the given variant tensor to an iterator and stores it in the given resource. +// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2. +type QueueEnqueueManyV2Attr func(optionalAttr) + +// QueueEnqueueManyV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue is too full, this operation will block for up +// to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueEnqueueManyV2TimeoutMs(value int64) QueueEnqueueManyV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Enqueues zero or more tuples of one or more tensors in the given queue. +// +// This operation slices each component tensor along the 0th dimension to +// make multiple queue elements. All of the tuple components must have the +// same size in the 0th dimension. 
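+//
+// Editorial sketch (not generated output): enqueueing three scalar elements
+// from one batched component. `s` is assumed to be an `op.Scope` and `queue`
+// a handle produced elsewhere by a queue-creating op; both names are
+// illustrative only.
+//
+// ```
+// batch := op.Const(s, []int32{1, 2, 3}) // sliced into 3 queue elements
+// op.QueueEnqueueManyV2(s, queue, []tf.Output{batch})
+// ```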
+// +// The components input has k elements, which correspond to the components of +// tuples stored in the given queue. +// +// N.B. If the queue is full, this operation will block until the given +// elements have been enqueued (or 'timeout_ms' elapses, if specified). // // Arguments: -// resource_handle: A handle to an iterator resource. -// serialized: A variant tensor storing the state of the iterator contained in the -// resource. +// handle: The handle to a queue. +// components: One or more tensors from which the enqueued tensors should +// be taken. // // Returns the created operation. -func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { +func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueManyV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DeserializeIterator", + Type: "QueueEnqueueManyV2", Input: []tf.Input{ - resource_handle, serialized, + handle, tf.OutputList(components), }, + Attrs: attrs, } return scope.AddOperation(opspec) } -// Creates a dataset that shuffles and repeats elements from `input_dataset` +// Worker heartbeat op. // -// pseudorandomly. +// Heartbeats may be sent periodically to indicate the coordinator is still active, +// to retrieve the current worker status and to expedite shutdown when necessary. // // Arguments: +// request: A string tensor containing a serialized WorkerHeartbeatRequest // -// buffer_size: The number of output elements to buffer in an iterator over -// this dataset. Compare with the `min_after_dequeue` attr when creating a -// `RandomShuffleQueue`. -// seed: A scalar seed for the random number generator. If either `seed` or -// `seed2` is set to be non-zero, the random number generator is seeded -// by the given seed. Otherwise, a random seed is used. -// seed2: A second scalar seed to avoid seed collision. -// count: A scalar representing the number of times the underlying dataset -// should be repeated. The default is `-1`, which results in infinite repetition. -// -// -func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ShuffleAndRepeatDataset", - Input: []tf.Input{ - input_dataset, buffer_size, seed, seed2, count, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that caches elements from `input_dataset`. -// -// A CacheDataset will iterate over the input_dataset, and store tensors. If the -// cache already exists, the cache will be used. If the cache is inappropriate -// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error -// will the returned when used. -// -// Arguments: -// -// filename: A path on the filesystem where we should cache the dataset. Note: this -// will be a directory. 
-// -// -func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "CacheDataset", - Input: []tf.Input{ - input_dataset, filename, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that emits the records from one or more binary files. -// -// Arguments: -// filenames: A scalar or a vector containing the name(s) of the file(s) to be -// read. -// header_bytes: A scalar representing the number of bytes to skip at the -// beginning of a file. -// record_bytes: A scalar representing the number of bytes in each record. -// footer_bytes: A scalar representing the number of bytes to skip at the end -// of a file. -// buffer_size: A scalar representing the number of bytes to buffer. Must be > 0. -func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf.Output, record_bytes tf.Output, footer_bytes tf.Output, buffer_size tf.Output) (handle tf.Output) { +// Returns A string tensor containing a serialized WorkerHeartbeatResponse +func WorkerHeartbeat(scope *Scope, request tf.Output) (response tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FixedLengthRecordDataset", + Type: "WorkerHeartbeat", Input: []tf.Input{ - filenames, header_bytes, record_bytes, footer_bytes, buffer_size, + request, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Gradients for batch normalization. +// Computes square of x element-wise. // -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() -// -// This op is deprecated. See `tf.nn.batch_normalization`. -// -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this Tensor will be multiplied -// with the normalized Tensor. -// backprop: 4D backprop Tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -// -// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. -func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} - opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalizationGrad", - Input: []tf.Input{ - t, m, v, gamma, backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// Creates a dataset that emits the records from one or more TFRecord files. 
-// -// Arguments: -// filenames: A scalar or vector containing the name(s) of the file(s) to be -// read. -// compression_type: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// buffer_size: A scalar representing the number of bytes to buffer. A value of -// 0 means no buffering will be performed. -func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Output, buffer_size tf.Output) (handle tf.Output) { +// I.e., \\(y = x * x = x^2\\). +func Square(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TFRecordDataset", + Type: "Square", Input: []tf.Input{ - filenames, compression_type, buffer_size, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ExperimentalStatsAggregatorHandleAttr is an optional argument to ExperimentalStatsAggregatorHandle. -type ExperimentalStatsAggregatorHandleAttr func(optionalAttr) - -// ExperimentalStatsAggregatorHandleContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func ExperimentalStatsAggregatorHandleContainer(value string) ExperimentalStatsAggregatorHandleAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// ExperimentalStatsAggregatorHandleSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func ExperimentalStatsAggregatorHandleSharedName(value string) ExperimentalStatsAggregatorHandleAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a statistics manager resource. -func ExperimentalStatsAggregatorHandle(scope *Scope, optional ...ExperimentalStatsAggregatorHandleAttr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ExperimentalStatsAggregatorHandle", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// A container for an iterator resource. -// -// Returns A handle to the iterator that can be passed to a "MakeIterator" or -// "IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents -// resource sharing by name, and does not keep a reference to the resource -// container. -func AnonymousIterator(scope *Scope, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "AnonymousIterator", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Adjust the contrast of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are -// interpreted as `[height, width, channels]`. The other dimensions only -// represent a collection of images, such as `[batch, height, width, channels].` -// -// Contrast is adjusted independently for each channel of each image. -// -// For each channel, the Op first computes the mean of the image pixels in the -// channel and then adjusts each component of each pixel to -// `(x - mean) * contrast_factor + mean`. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// contrast_factor: A float multiplier for adjusting contrast. -// -// Returns The contrast-adjusted image or images. 
-func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustContrastv2", - Input: []tf.Input{ - images, contrast_factor, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Gets the next output from the given iterator . -func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "IteratorGetNext", - Input: []tf.Input{ - iterator, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("IteratorGetNext", err) - return - } - return components -} - -// Outputs the single element from the given dataset. -// -// Arguments: -// dataset: A handle to a dataset that contains a single element. -// -// -// -// Returns The components of the single element of `input`. -func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "DatasetToSingleElement", - Input: []tf.Input{ - dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("DatasetToSingleElement", err) - return - } - return components -} - -// Converts the given `resource_handle` representing an iterator to a string. -// -// Arguments: -// resource_handle: A handle to an iterator resource. -// -// Returns A string representation of the given handle. -func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_handle tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IteratorToStringHandle", - Input: []tf.Input{ - resource_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// IteratorFromStringHandleAttr is an optional argument to IteratorFromStringHandle. -type IteratorFromStringHandleAttr func(optionalAttr) - -// IteratorFromStringHandleOutputTypes sets the optional output_types attribute to value. -// -// value: If specified, defines the type of each tuple component in an -// element produced by the resulting iterator. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func IteratorFromStringHandleOutputTypes(value []tf.DataType) IteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_types"] = value - } -} - -// IteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value. -// -// value: If specified, defines the shape of each tuple component in an -// element produced by the resulting iterator. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func IteratorFromStringHandleOutputShapes(value []tf.Shape) IteratorFromStringHandleAttr { - return func(m optionalAttr) { - m["output_shapes"] = value - } -} - -// Converts the given string representing a handle to an iterator to a resource. 
-// -// Arguments: -// string_handle: A string representation of the given handle. -// -// Returns A handle to an iterator resource. -func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...IteratorFromStringHandleAttr) (resource_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "IteratorFromStringHandle", - Input: []tf.Input{ - string_handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Gather slices from `params` axis `axis` according to `indices`. -// -// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -// Produces an output tensor with shape `params.shape[:axis] + indices.shape + -// params.shape[axis + 1:]` where: -// -// ```python -// # Scalar indices (output is rank(params) - 1). -// output[a_0, ..., a_n, b_0, ..., b_n] = -// params[a_0, ..., a_n, indices, b_0, ..., b_n] -// -// # Vector indices (output is rank(params)). -// output[a_0, ..., a_n, i, b_0, ..., b_n] = -// params[a_0, ..., a_n, indices[i], b_0, ..., b_n] -// -// # Higher rank indices (output is rank(params) + rank(indices) - 1). -// output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = -// params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] -// ``` -// -//
-// [figure omitted: diagram of the gather operation]
-// -// Note that on CPU, if an out of bound index is found, an error is returned. -// On GPU, if an out of bound index is found, a 0 is stored in the -// corresponding output value. -// -// See also `tf.batch_gather` and `tf.gather_nd`. -// -// Arguments: -// params: The tensor from which to gather values. Must be at least rank -// `axis + 1`. -// indices: Index tensor. Must be in range `[0, params.shape[axis])`. -// axis: The axis in `params` to gather `indices` from. Defaults to the first -// dimension. Supports negative indexes. -// -// Returns Values from `params` gathered from indices given by `indices`, with -// shape `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`. -func GatherV2(scope *Scope, params tf.Output, indices tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "GatherV2", - Input: []tf.Input{ - params, indices, axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Converts the given `resource_handle` representing an iterator to a variant tensor. -// -// Arguments: -// resource_handle: A handle to an iterator resource. -// -// Returns A variant tensor storing the state of the iterator contained in the -// resource. -func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SerializeIterator", - Input: []tf.Input{ - resource_handle, + x, }, } op := scope.AddOperation(opspec) diff --git a/tensorflow/go/test.sh b/tensorflow/go/test.sh index 47c3a683791..b75076563f6 100755 --- a/tensorflow/go/test.sh +++ b/tensorflow/go/test.sh @@ -40,6 +40,7 @@ fi # Setup a GOPATH that includes just the TensorFlow Go API. export GOPATH="${TEST_TMPDIR}/go" +export GOCACHE="${TEST_TMPDIR}/cache" mkdir -p "${GOPATH}/src/github.com/tensorflow" ln -s "${PWD}" "${GOPATH}/src/github.com/tensorflow/tensorflow" diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index d70e0d6c0ab..6a71cd1e9da 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -23,7 +23,7 @@ java_library( ":java_op_sources", ":java_sources", ], - data = [":libtensorflow_jni"], + data = [":libtensorflow_jni"] + tf_binary_additional_srcs(), javacopts = JAVACOPTS, plugins = [":processor"], visibility = ["//visibility:public"], @@ -133,6 +133,45 @@ java_library( deps = [":tensorflow"], ) +tf_java_test( + name = "EagerSessionTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/EagerSessionTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.EagerSessionTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + +tf_java_test( + name = "EagerOperationBuilderTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/EagerOperationBuilderTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.EagerOperationBuilderTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + +tf_java_test( + name = "EagerOperationTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/EagerOperationTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.EagerOperationTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + tf_java_test( name = "GraphTest", size = "small", @@ -147,11 +186,11 @@ tf_java_test( ) tf_java_test( - name = "OperationBuilderTest", + name = "GraphOperationBuilderTest", size = "small", - srcs = ["src/test/java/org/tensorflow/OperationBuilderTest.java"], + srcs = 
["src/test/java/org/tensorflow/GraphOperationBuilderTest.java"], javacopts = JAVACOPTS, - test_class = "org.tensorflow.OperationBuilderTest", + test_class = "org.tensorflow.GraphOperationBuilderTest", deps = [ ":tensorflow", ":testutil", @@ -160,11 +199,11 @@ tf_java_test( ) tf_java_test( - name = "OperationTest", + name = "GraphOperationTest", size = "small", - srcs = ["src/test/java/org/tensorflow/OperationTest.java"], + srcs = ["src/test/java/org/tensorflow/GraphOperationTest.java"], javacopts = JAVACOPTS, - test_class = "org.tensorflow.OperationTest", + test_class = "org.tensorflow.GraphOperationTest", deps = [ ":tensorflow", ":testutil", diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index db6116bd5c8..3eb8c5c7129 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -160,12 +160,11 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class, } Method factory = Method::Create("create", return_type); Javadoc factory_doc = Javadoc::Create( - "Factory method to create a class to wrap a new " + op_class.name() + - " operation to the graph, using " - "default output types."); + "Factory method to create a class wrapping a new " + op_class.name() + + " operation using default output types."); Variable scope = Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); - AddArgument(scope, "current graph scope", &factory, &factory_doc); + AddArgument(scope, "current scope", &factory, &factory_doc); std::stringstream factory_statement; factory_statement << "return create(scope"; for (const ArgumentSpec& input : op.inputs()) { @@ -202,11 +201,11 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, SourceWriter* writer) { Method factory = Method::Create("create", op_class); Javadoc factory_doc = - Javadoc::Create("Factory method to create a class to wrap a new " + - op_class.name() + " operation to the graph."); + Javadoc::Create("Factory method to create a class wrapping a new " + + op_class.name() + " operation."); Variable scope = Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); - AddArgument(scope, "current graph scope", &factory, &factory_doc); + AddArgument(scope, "current scope", &factory, &factory_doc); for (const ArgumentSpec& input : op.inputs()) { AddArgument(input.var(), input.description(), &factory, &factory_doc); } @@ -229,7 +228,7 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, factory_doc.add_tag("return", "a new instance of " + op_class.name()); writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc); - writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + + writer->Append("OperationBuilder opBuilder = scope.env().opBuilder(\"" + op.graph_op_name() + "\", scope.makeOpName(\"" + op_class.name() + "\"));"); writer->EndLine(); @@ -244,6 +243,10 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, writer->EndLine(); } } + // Add control dependencies, if any. 
+ writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);"); + writer->EndLine(); + for (const AttributeSpec& attribute : op.attributes()) { WriteSetAttrDirective(attribute, false, writer); } diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index b46721a93dc..3db5bd75d7f 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -49,16 +49,15 @@ def tf_java_op_gen_srcjar( # Generate a source archive containing generated code for these ops. gen_srcjar = out_dir + name + ".srcjar" - gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"] + gen_cmds += ["$(JAVABASE)/bin/jar cMf $(location :" + gen_srcjar + ") -C $(@D) src"] native.genrule( name = name, srcs = srcs, outs = [gen_srcjar], tools = [ - "@local_jdk//:jar", - "@local_jdk//:jdk", gen_tool, ] + tf_binary_additional_srcs(), + toolchains = ["@bazel_tools//tools/jdk:current_host_java_runtime"], cmd = " && ".join(gen_cmds), ) diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index df1426ad751..e3e40a17df5 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,18 +15,6 @@ limitations under the License. package org.tensorflow.processor; -import com.google.common.base.CaseFormat; -import com.google.common.base.Strings; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.JavaFile; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; import java.io.IOException; import java.util.Collection; import java.util.Collections; @@ -35,6 +23,7 @@ import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; + import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; @@ -55,6 +44,21 @@ import javax.lang.model.util.ElementFilter; import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.ParameterizedTypeName; +import com.squareup.javapoet.WildcardTypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; + /** * A compile-time Processor that aggregates classes annotated with {@link * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. 
Please @@ -72,7 +76,7 @@ public final class OperatorProcessor extends AbstractProcessor { @Override public SourceVersion getSupportedSourceVersion() { - return SourceVersion.latestSupported(); + return SourceVersion.latest(); } @Override @@ -148,12 +152,21 @@ public final class OperatorProcessor extends AbstractProcessor { private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final TypeName T_OP = ClassName.get("org.tensorflow.op", "Op"); private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator"); private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); - private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); + private static final TypeName T_EXEC_ENV = + ClassName.get("org.tensorflow", "ExecutionEnvironment"); private static final TypeName T_STRING = ClassName.get(String.class); + // Operand + private static final TypeName T_OPERAND = + ParameterizedTypeName.get( + ClassName.get("org.tensorflow", "Operand"), WildcardTypeName.subtypeOf(Object.class)); + // Iterable> + private static final TypeName T_ITERABLE_OPERAND = + ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OPERAND); private Filer filer; private Messager messager; @@ -271,10 +284,7 @@ public final class OperatorProcessor extends AbstractProcessor { private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { StringBuilder javadoc = new StringBuilder(); - javadoc - .append("Adds an {@link ") - .append(opClassName.simpleName()) - .append("} operation to the graph\n\n"); + javadoc.append("Builds an {@link ").append(opClassName.simpleName()).append("} operation\n\n"); // Add all javadoc tags found in the operator factory method but the first one, which should be // in all cases the @@ -305,10 +315,10 @@ public final class OperatorProcessor extends AbstractProcessor { TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") .addModifiers(Modifier.PUBLIC, Modifier.FINAL) .addJavadoc( - "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + "An API for building {@code $L} operations as {@link $T Op}s\n\n" + "@see {@link $T}\n", group, - T_GRAPH, + T_OP, T_OPS) .addMethods(methods) .addMethod(ctorBuilder.build()); @@ -335,7 +345,7 @@ public final class OperatorProcessor extends AbstractProcessor { TypeSpec.classBuilder("Ops") .addModifiers(Modifier.PUBLIC, Modifier.FINAL) .addJavadoc( - "An API for building a {@link $T} with operation wrappers\n

\n" + "An API for building operations as {@link $T Op}s\n

\n" + "Any operation wrapper found in the classpath properly annotated as an" + "{@link $T @Operator} is exposed\n" + "by this API or one of its subgroup.\n

Example usage:\n

{@code\n"
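The javadoc strings in this hunk and the one below now describe the generated {@code Ops} entry point in terms of {@code Op}s and an execution environment rather than a {@code Graph}. As a rough, hedged sketch of what the regenerated API looks like to a caller, the example below assumes only names that appear in this change ({@code Ops.create(ExecutionEnvironment)}, {@code withName}, {@code withSubScope}, and the {@code constant} wrapper used in the javadoc example); the class and variable names are illustrative:

```java
import org.tensorflow.EagerSession;
import org.tensorflow.Graph;
import org.tensorflow.op.Ops;

public class OpsApiSketch {
  public static void main(String[] args) {
    // Graph mode: wrappers add nodes to the graph, to be run later in a Session.
    try (Graph g = new Graph()) {
      Ops ops = Ops.create(g);             // Graph implements ExecutionEnvironment
      ops.withName("c1").constant(42);     // node named "c1"
      Ops sub = ops.withSubScope("layer1");
      sub.withName("c2").constant(7);      // node named "layer1/c2"
    }

    // Eager mode: the same entry point accepts an EagerSession after this change,
    // so wrappers built from it execute immediately instead of adding graph nodes
    // (assuming the constant wrapper supports eager execution at this point).
    try (EagerSession session = EagerSession.create()) {
      Ops ops = Ops.create(session);
      ops.withName("x").constant(3);
    }
  }
}
```

Switching the generated API from {@code Graph} to {@code ExecutionEnvironment} is what lets the same wrappers target either graph construction or the new {@code EagerSession} without a separate code path.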
@@ -357,13 +367,13 @@ public final class OperatorProcessor extends AbstractProcessor {
                     + "  // Optional attributes\n"
                     + "  ops.math().matMul(a, b, MatMul.transposeA(true));\n"
                     + "  // Naming operators\n"
-                    + "  ops.withName(“foo”).constant(5); // name “foo”\n"
+                    + "  ops.withName(\"foo\").constant(5); // name \"foo\"\n"
                     + "  // Names can exist in a hierarchy\n"
-                    + "  Ops sub = ops.withSubScope(“sub”);\n"
-                    + "  sub.withName(“bar”).constant(4); // “sub/bar”\n"
+                    + "  Ops sub = ops.withSubScope(\"sub\");\n"
+                    + "  sub.withName(\"bar\").constant(4); // \"sub/bar\"\n"
                     + "}\n"
                     + "}
\n", - T_GRAPH, + T_OP, T_OPERATOR) .addMethods(methods) .addMethod(ctorBuilder.build()); @@ -375,7 +385,7 @@ public final class OperatorProcessor extends AbstractProcessor { .returns(T_OPS) .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) .addJavadoc( - "Returns an API that adds operations to the graph with the provided name prefix.\n" + "Returns an API that builds operations with the provided name prefix.\n" + "\n@see {@link $T#withSubScope(String)}\n", T_SCOPE) .build()); @@ -392,6 +402,18 @@ public final class OperatorProcessor extends AbstractProcessor { T_SCOPE) .build()); + opsBuilder.addMethod( + MethodSpec.methodBuilder("withControlDependencies") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_ITERABLE_OPERAND, "controls") + .returns(T_OPS) + .addStatement("return new Ops(scope.withControlDependencies(controls))") + .addJavadoc( + "Returns an API that adds operations to the graph with the provided control dependencies.\n\n" + + "@see {@link $T#withControlDependencies(Iterable>)}\n", + T_SCOPE) + .build()); + opsBuilder.addField( FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); @@ -414,18 +436,17 @@ public final class OperatorProcessor extends AbstractProcessor { .addModifiers(Modifier.PUBLIC, Modifier.FINAL) .returns(entry.getValue()) .addStatement("return $L", entry.getKey()) - .addJavadoc( - "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .addJavadoc("Returns an API for building {@code $L} operations\n", entry.getKey()) .build()); } opsBuilder.addMethod( MethodSpec.methodBuilder("create") .addModifiers(Modifier.PUBLIC, Modifier.STATIC) - .addParameter(T_GRAPH, "graph") + .addParameter(T_EXEC_ENV, "env") .returns(T_OPS) - .addStatement("return new Ops(new $T(graph))", T_SCOPE) - .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .addStatement("return new Ops(new $T(env))", T_SCOPE) + .addJavadoc("Creates an API for building operations in the provided environment\n") .build()); return opsBuilder.build(); diff --git a/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java new file mode 100644 index 00000000000..23b1753e2cb --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/AbstractOperation.java @@ -0,0 +1,87 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** + * Base class for {@link Operation} implementations. + * + *

As opposed to {@link Operation} itself, this class is package private and therefore its usage + * is limited to internal purposes only. + */ +abstract class AbstractOperation implements Operation { + + @Override + public Output[] outputList(int idx, int length) { + Output[] outputs = new Output[length]; + for (int i = 0; i < length; ++i) { + outputs[i] = output(idx + i); + } + return outputs; + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Output output(int idx) { + return new Output(this, idx); + } + + @Override + public String toString() { + return String.format("<%s '%s'>", type(), name()); + } + + /** + * Returns the native handle of the {@code outputIdx}th output of this operation. + * + *

The nature of the returned value varies depending on current the execution environment. + * + *

    + *
  • In eager mode, the value is a handle to the tensor returned at this output. + *
  • In graph mode, the value is a handle to the operation itself, which should be paired with + * the index of the output when calling the native layer. + *
+ * + * @param outputIdx index of the output in this operation + * @return a native handle, see method description for more details + */ + abstract long getUnsafeNativeHandle(int outputIdx); + + /** + * Returns the shape of the tensor of the {@code outputIdx}th output of this operation. + * + * @param outputIdx index of the output of this operation + * @return output tensor shape + */ + abstract long[] shape(int outputIdx); + + /** + * Returns the datatype of the tensor of the {@code outputIdx}th output of this operation. + * + * @param outputIdx index of the output of this operation + * @return output tensor datatype + */ + abstract DataType dtype(int outputIdx); + + /** + * Returns the tensor of the {@code outputIdx}th output of this operation. + * + *

This is only supported in an eager execution environment. + * + * @param outputIdx index of the output of this operation + * @return output tensor + */ + abstract Tensor tensor(int outputIdx); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java new file mode 100644 index 00000000000..b6e1fa6db6c --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperation.java @@ -0,0 +1,172 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import java.util.concurrent.atomic.AtomicReferenceArray; + +/** + * Implementation of an {@link Operation} executed eagerly. + * + *

EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of + * is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the + * EagerOperation instance may fail with an {@code IllegalStateException}. + * + *

EagerOperation instances are thread-safe. + */ +class EagerOperation extends AbstractOperation { + + EagerOperation( + EagerSession session, + long opNativeHandle, + long[] outputNativeHandles, + String type, + String name) { + this.session = session; + this.type = type; + this.name = name; + this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles); + this.outputTensors = new AtomicReferenceArray>(outputNativeHandles.length); + } + + @Override + public String name() { + return name; + } + + @Override + public String type() { + return type; + } + + @Override + public int numOutputs() { + return nativeRef.outputHandles.length; + } + + @Override + public int outputListLength(final String name) { + return outputListLength(nativeRef.opHandle, name); + } + + @Override + public int inputListLength(final String name) { + return inputListLength(nativeRef.opHandle, name); + } + + @Override + public long getUnsafeNativeHandle(int outputIndex) { + return nativeRef.outputHandles[outputIndex]; + } + + @Override + public long[] shape(int outputIndex) { + // If the tensor of this output has already been resolved, return its shape. + // Otherwise, retrieve the tensor shape from the native library. + Tensor tensor = outputTensors.get(outputIndex); + if (tensor != null) { + return tensor.shape(); + } + long outputNativeHandle = getUnsafeNativeHandle(outputIndex); + long[] shape = new long[numDims(outputNativeHandle)]; + for (int i = 0; i < shape.length; ++i) { + shape[i] = dim(outputNativeHandle, i); + } + return shape; + } + + @Override + public DataType dtype(int outputIndex) { + // If the tensor of this output has already been resolved, return its datatype. + // Otherwise, retrieve the tensor datatype from the native library. + Tensor tensor = outputTensors.get(outputIndex); + if (tensor != null) { + return tensor.dataType(); + } + long outputNativeHandle = getUnsafeNativeHandle(outputIndex); + return DataType.fromC(dataType(outputNativeHandle)); + } + + @Override + public Tensor tensor(int outputIndex) { + Tensor tensor = outputTensors.get(outputIndex); + if (tensor == null) { + tensor = resolveTensor(outputIndex); + } + return tensor; + } + + private final EagerSession session; + private final NativeReference nativeRef; + private final String type; + private final String name; + private final AtomicReferenceArray> outputTensors; + + private Tensor resolveTensor(int outputIndex) { + // Take an optimistic approach, where we attempt to resolve the output tensor without locking. + // If another thread has resolved it meanwhile, release our copy and reuse the existing one + // instead. 
+ long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex)); + Tensor tensor = Tensor.fromHandle(tensorNativeHandle, session); + if (!outputTensors.compareAndSet(outputIndex, null, tensor)) { + tensor.close(); + tensor = outputTensors.get(outputIndex); + } + return tensor; + } + + private static class NativeReference extends EagerSession.NativeReference { + + NativeReference( + EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) { + super(session, operation); + this.opHandle = opHandle; + this.outputHandles = outputHandles; + } + + @Override + void delete() { + if (opHandle != 0L) { + for (int i = 0; i < outputHandles.length; ++i) { + if (outputHandles[i] != 0L) { + EagerOperation.deleteTensorHandle(outputHandles[i]); + outputHandles[i] = 0L; + } + } + EagerOperation.delete(opHandle); + opHandle = 0L; + } + } + + private long opHandle; + private final long[] outputHandles; + } + + private static native void delete(long handle); + + private static native void deleteTensorHandle(long handle); + + private static native long resolveTensorHandle(long handle); + + private static native int outputListLength(long handle, String name); + + private static native int inputListLength(long handle, String name); + + private static native int dataType(long handle); + + private static native int numDims(long handle); + + private static native long dim(long handle, int index); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java new file mode 100644 index 00000000000..7e5a9a778a4 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -0,0 +1,258 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +/** + * An {@link OperationBuilder} for building {@link Operation Operations} that are executed eagerly. 
+ */ +final class EagerOperationBuilder implements OperationBuilder { + + EagerOperationBuilder(EagerSession session, String type, String name) { + this.session = session; + this.type = type; + this.name = name; + this.nativeRef = new NativeReference(session, this, allocate(session.nativeHandle(), type)); + } + + @Override + public EagerOperation build() { + long[] tensorHandles = execute(nativeRef.opHandle); + EagerOperation operation = + new EagerOperation(session, nativeRef.opHandle, tensorHandles, type, name); + // Release our reference to the native op handle now that we transferred its + // ownership to the EagerOperation + nativeRef.clear(); + return operation; + } + + @Override + public EagerOperationBuilder addInput(Output input) { + addInput(nativeRef.opHandle, input.getUnsafeNativeHandle()); + return this; + } + + @Override + public EagerOperationBuilder addInputList(Output[] inputs) { + long[] inputHandles = new long[inputs.length]; + for (int i = 0; i < inputs.length; ++i) { + inputHandles[i] = inputs[i].getUnsafeNativeHandle(); + } + addInputList(nativeRef.opHandle, inputHandles); + return this; + } + + @Override + public OperationBuilder addControlInput(Operation control) { + throw new UnsupportedOperationException( + "Control inputs are not supported in an eager execution environment"); + } + + @Override + public EagerOperationBuilder setDevice(String device) { + setDevice(nativeRef.opHandle, device); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, String value) { + return setAttr(name, value.getBytes(StandardCharsets.UTF_8)); + } + + @Override + public EagerOperationBuilder setAttr(String name, String[] values) { + Charset utf8 = StandardCharsets.UTF_8; + Object[] objects = new Object[values.length]; + for (int i = 0; i < values.length; ++i) { + objects[i] = values[i].getBytes(utf8); + } + setAttrStringList(nativeRef.opHandle, name, values); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, byte[] values) { + setAttrString(nativeRef.opHandle, name, values); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, long value) { + setAttrInt(nativeRef.opHandle, name, value); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, long[] values) { + setAttrIntList(nativeRef.opHandle, name, values); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, float value) { + setAttrFloat(nativeRef.opHandle, name, value); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, float[] values) { + setAttrFloatList(nativeRef.opHandle, name, values); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, boolean value) { + setAttrBool(nativeRef.opHandle, name, value); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, boolean[] values) { + setAttrBoolList(nativeRef.opHandle, name, values); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, DataType value) { + setAttrType(nativeRef.opHandle, name, value.c()); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, DataType[] values) { + int[] c = new int[values.length]; + for (int i = 0; i < values.length; ++i) { + c[i] = values[i].c(); + } + setAttrTypeList(nativeRef.opHandle, name, c); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, Tensor value) { + setAttrTensor(nativeRef.opHandle, name, 
value.getNativeHandle()); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, Tensor[] values) { + // TODO (karllessard) could be supported by adding this attribute type in the eager C API + throw new UnsupportedOperationException( + "Tensor list attributes are not supported in eager mode"); + } + + @Override + public EagerOperationBuilder setAttr(String name, Shape value) { + setAttrShape(nativeRef.opHandle, name, value.asArray(), value.numDimensions()); + return this; + } + + @Override + public EagerOperationBuilder setAttr(String name, Shape[] values) { + int[] numDimensions = new int[values.length]; + int totalNumDimensions = 0; + for (int idx = 0; idx < values.length; ++idx) { + int n = values[idx].numDimensions(); + numDimensions[idx] = n; + if (n > 0) { + totalNumDimensions += n; + } + } + // Flatten the shapes into a single array to avoid too much overhead in the + // native part + long[] shapes = new long[totalNumDimensions]; + int shapeIdx = 0; + for (Shape shape : values) { + if (shape.numDimensions() > 0) { + for (long dim : shape.asArray()) { + shapes[shapeIdx++] = dim; + } + } + } + setAttrShapeList(nativeRef.opHandle, name, shapes, numDimensions); + return this; + } + + private static class NativeReference extends EagerSession.NativeReference { + + NativeReference(EagerSession session, EagerOperationBuilder operation, long opHandle) { + super(session, operation); + this.opHandle = opHandle; + } + + @Override + public void clear() { + super.clear(); + opHandle = 0L; + } + + @Override + synchronized void delete() { + if (opHandle != 0L) { + EagerOperationBuilder.delete(opHandle); + opHandle = 0L; + } + } + + private long opHandle; + } + + private final EagerSession session; + private final String type; + private final String name; + private final NativeReference nativeRef; + + private static native long allocate(long ctxHandle, String type); + + private static native void delete(long opHandle); + + private static native long[] execute(long opHandle); + + private static native void addInput(long opHandle, long tensorHandle); + + private static native void addInputList(long opHandle, long[] tensorHandles); + + private static native void setDevice(long opHandle, String device); + + private static native void setAttrString(long opHandle, String name, byte[] value); + + private static native void setAttrStringList(long opHandle, String name, Object[] value); + + private static native void setAttrInt(long opHandle, String name, long value); + + private static native void setAttrIntList(long opHandle, String name, long[] values); + + private static native void setAttrFloat(long opHandle, String name, float value); + + private static native void setAttrFloatList(long opHandle, String name, float[] values); + + private static native void setAttrBool(long opHandle, String name, boolean value); + + private static native void setAttrBoolList(long opHandle, String name, boolean[] values); + + private static native void setAttrType(long opHandle, String name, int type); + + private static native void setAttrTypeList(long opHandle, String name, int[] types); + + private static native void setAttrTensor(long opHandle, String name, long tensorHandle); + + private static native void setAttrShape(long opHandle, String name, long[] shape, int numDims); + + private static native void setAttrShapeList( + long opHandle, String name, long[] shapes, int[] numDims); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java 
b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java new file mode 100644 index 00000000000..7f36da173e6 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/EagerSession.java @@ -0,0 +1,417 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import java.lang.ref.PhantomReference; +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; +import java.util.IdentityHashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * An environment for executing TensorFlow operations eagerly. + * + *

Eager execution is an imperative programming environment that evaluates operations + * immediately, without building graphs. Operations return concrete values instead of constructing a + * computational graph to run later, as with {@link Graph}s and {@link Session}s. + * + *

This makes it easy to develop with TensorFlow and debug models, as it behaves more like a + * standard programming library. + * + *

Instances of a {@code EagerSession} are thread-safe. + * + *

WARNING: Resources consumed by an {@code EagerSession} object must be explicitly freed + * by invoking the {@link #close()} method when it is no longer needed. This could be achieved using + * the `try-with-resources` technique, as in the example below: + *

<pre>{@code
+ * try (EagerSession s = EagerSession.create()) {
+ *    // execute operations eagerly
+ * }
+ * }</pre>
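Further down, this file adds an {@code Options} builder and implements the {@code opBuilder} method from {@code ExecutionEnvironment}. The sketch below is an illustration under assumptions rather than part of this change: it builds a single {@code Const} op eagerly through that low-level interface, using the conventional {@code dtype} and {@code value} attribute names for {@code Const} and the pre-existing {@code Tensors.create} helper; in practice most callers would go through the generated {@code Ops} API instead:

```java
import org.tensorflow.EagerSession;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

public class EagerConstSketch {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create();            // default options
         Tensor<Integer> value = Tensors.create(new int[] {1, 2, 3})) {
      // opBuilder comes from ExecutionEnvironment; for an EagerSession it returns an
      // EagerOperationBuilder, and build() executes the op immediately.
      Operation op =
          session.opBuilder("Const", "my_const")
              .setAttr("dtype", value.dataType())
              .setAttr("value", value)
              .build();
      Output<Integer> out = op.output(0);
      System.out.println(op.type() + " produced " + op.numOutputs()
          + " output(s) with shape " + out.shape());
    }
    // A session with custom options could instead be built as, e.g.:
    // EagerSession.options().async(true).build()
  }
}
```

Because {@code build()} runs the op right away, the returned {@code EagerOperation} holds handles to already-computed tensors; in graph mode the same builder chain would only add a node to be run later by a {@code Session}.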
+ * + * In addition, {@code EagerSession} objects clean up unused resources during the session, working + * in pair with the JVM garbage collector. See {@link ResourceCleanupStrategy} for more details. + */ +public final class EagerSession implements ExecutionEnvironment, AutoCloseable { + + /** + * Controls how to act when we try to run an operation on a given device but some input tensors + * are not on that device. + */ + public static enum DevicePlacementPolicy { + + /** Running operations with input tensors on the wrong device will fail. */ + EXPLICIT(0), + + /** Copy the tensor to the right device but log a warning. */ + WARN(1), + + /** + * Silently copy the tensor, which has a performance cost since the operation will be blocked + * till the copy completes. This is the default placement policy. + */ + SILENT(2), + + /** Placement policy which silently copies int32 tensors but not other dtypes. */ + SILENT_FOR_INT32(3); + + private DevicePlacementPolicy(int code) { + this.code = code; + } + + private final int code; + } + + /** + * Controls how TensorFlow resources are cleaned up when they are no longer needed. + * + *

All resources allocated during an {@code EagerSession} are deleted when the session is + * closed. To prevent out-of-memory errors, it is also strongly suggested to clean up those resources + * during the session. For example, executing n operations in a loop of m iterations will allocate + * a minimum of n*m resources while, in most cases, only the resources of the last iteration are still + * being used. + * + *

{@code EagerSession} instances can be notified in different ways when TensorFlow objects are + * no longer referenced, so they can proceed with the cleanup of any resources they own. + */ + public static enum ResourceCleanupStrategy { + + /** + * Monitor and delete unused resources from a new thread running in the background. + * + *

This is the most reliable approach to cleanup TensorFlow resources, at the cost of + * starting and running an additional thread dedicated to this task. Each {@code EagerSession} + * instance has its own thread, which is stopped only when the session is closed. + * + *

This strategy is used by default. + */ + IN_BACKGROUND, + + /** + * Monitor and delete unused resources from existing threads, before or after they complete + * another task. + * + *

Unused resources are released when a call to the TensorFlow library reaches a safe point + * for cleanup. This is done synchronously and might briefly block the thread that triggered + * that call. + * + *

This strategy should be used only if, for some reason, no additional thread should be + * allocated for cleanup. Otherwise, {@link #IN_BACKGROUND} should be preferred. + */ + ON_SAFE_POINTS, + + /** + * Only delete resources when the session is closed. + * + *

All resources allocated during the session will remain in memory until the session is + * explicitly closed (or via the traditional `try-with-resources` technique). No extra task for + * resource cleanup will be attempted. + * + *

This strategy can lead to out-of-memory errors and its usage is not recommended, unless + * the scope of the session is limited to executing only a small number of operations. + */ + ON_SESSION_CLOSE, + } + + public static class Options { + + /** + * Controls how dispatched operations are actually executed. + * + *

When set to true, each operation is executed asynchronously (in which case some + * operations might return "non-ready" outputs). When set to false, all operations are executed + * synchronously. + * + *

Synchronous execution is used by default. + * + * @param value true for asynchronous execution, false for synchronous. + */ + public Options async(boolean value) { + async = value; + return this; + } + + /** + * Controls how to act when we try to run an operation on a given device but some input tensors + * are not on that device. + * + *

{@link DevicePlacementPolicy#SILENT} is used by default. + * + * @param value policy to apply + * @see {@link DevicePlacementPolicy} + */ + public Options devicePlacementPolicy(DevicePlacementPolicy value) { + devicePlacementPolicy = value; + return this; + } + + /** + * Controls how TensorFlow resources are cleaned up when no longer needed. + * + *

{@link ResourceCleanupStrategy#IN_BACKGROUND} is used by default. + * + * @param value strategy to use + * @see {@link ResourceCleanupStrategy} + */ + public Options resourceCleanupStrategy(ResourceCleanupStrategy value) { + resourceCleanupStrategy = value; + return this; + } + + /** + * Configures the session based on the data found in the provided buffer, which is serialized + * TensorFlow config proto. + * + *

Warning: the support of this feature is subject to changes since TensorFlow protos might + * not be supported on public endpoints in the future. + * + * @param value a serialized config proto + * @see + * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto + */ + public Options config(byte[] value) { + config = value; + return this; + } + + /** Builds an eager session with the selected options. */ + public EagerSession build() { + return new EagerSession(this); + } + + private boolean async; + private DevicePlacementPolicy devicePlacementPolicy; + private ResourceCleanupStrategy resourceCleanupStrategy; + private byte[] config; + + private Options() { + async = false; + devicePlacementPolicy = DevicePlacementPolicy.SILENT; + resourceCleanupStrategy = ResourceCleanupStrategy.IN_BACKGROUND; + config = null; + } + } + + /** Returns an object that configures and builds a {@code EagerSession} with custom options. */ + public static EagerSession.Options options() { + return new Options(); + } + + /** Returns an {@code EagerSession} configured with default options. */ + public static EagerSession create() { + return options().build(); + } + + private EagerSession(Options options) { + this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); + this.resourceCleanupStrategy = options.resourceCleanupStrategy; + + if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) { + nativeResources.startCleanupThread(); + } + } + + @Override + public synchronized void close() { + if (nativeHandle != 0L) { + if (resourceCleanupStrategy == ResourceCleanupStrategy.IN_BACKGROUND) { + nativeResources.stopCleanupThread(); + } + nativeResources.deleteAll(); + delete(nativeHandle); + nativeHandle = 0L; + } + } + + @Override + public OperationBuilder opBuilder(String type, String name) { + if (resourceCleanupStrategy == ResourceCleanupStrategy.ON_SAFE_POINTS) { + nativeResources.tryCleanup(); + } + checkSession(); + return new EagerOperationBuilder(this, type, name); + } + + long nativeHandle() { + checkSession(); + return nativeHandle; + } + + /** + * A reference to one or more allocated native resources. + * + *

Any Java objects owning native resources must declare a reference to those resources in a + * subclass that extends from {@code NativeReference}. When {@link NativeReference#delete()} is + * invoked, the resources must be freed. For example: + * + *

<pre>{@code
+   * private static class NativeReference extends EagerSession.NativeReference {
+   *
+   *    NativeReference(EagerSession session, MyClass referent, long handle) {
+   *        super(session, referent);
+   *        this.handle = handle;
+   *    }
+   *
+   *    @Override
+   *    void delete() {
+   *        MyClass.nativeDelete(handle);
+   *    }
+   *
+   *    private final long handle;
+   * }
+   * }</pre>
+ * + * A Java object "owns" a native resource if this resource should not survive beyond the lifetime + * of this object. + * + *

IMPORTANT: All nested subclasses of {@code NativeReference} must be declared as + * static, otherwise their instances will hold an implicit reference to their enclosing object, + * preventing the garbage collector from releasing them when they are no longer needed. + */ + abstract static class NativeReference extends PhantomReference<Object> { + + /** Attach a new phantom reference of {@code referent} to {@code session}. */ + public NativeReference(EagerSession session, Object referent) { + super(referent, session.nativeResources.garbageQueue); + session.checkSession(); + nativeResources = session.nativeResources; + nativeResources.attach(this); + } + + /** + * Detach this reference from its current session. + * + *

Clearing a NativeReference does not invoke {@link #delete()}, thus won't release the + * native resources it refers to. It can be used when passing the ownership of those resources + * to another object. + * + *

If native resources needs to be deleted as well, call {@link #delete()} explicitly. + */ + @Override + public void clear() { + nativeResources.detach(this); + super.clear(); + } + + /** Releases all native resources owned by the referred object, now deleted. */ + abstract void delete(); + + private final NativeResourceCollector nativeResources; + } + + /** + * Collects native references attached to this session and releases their resources if they are no + * longer needed. + */ + private static class NativeResourceCollector { + + void attach(NativeReference nativeRef) { + synchronized (nativeRefs) { + nativeRefs.put(nativeRef, null); + } + } + + void detach(NativeReference nativeRef) { + synchronized (nativeRefs) { + nativeRefs.remove(nativeRef); + } + } + + void delete(NativeReference nativeRef) { + synchronized (nativeRefs) { + if (!nativeRefs.keySet().remove(nativeRef)) { + return; // safety check + } + } + nativeRef.delete(); + } + + void deleteAll() { + synchronized (nativeRefs) { + for (NativeReference nativeRef : nativeRefs.keySet()) { + nativeRef.delete(); + } + nativeRefs.clear(); + } + } + + void tryCleanup() { + Reference nativeRef; + synchronized (nativeRefs) { + while ((nativeRef = garbageQueue.poll()) != null) { + delete((NativeReference) nativeRef); + } + } + } + + synchronized void startCleanupThread() { + if (cleanupInBackground) { + return; // ignore if cleanup thread is already running + } + try { + cleanupInBackground = true; + cleanupService.execute( + new Runnable() { + @Override + public void run() { + try { + while (cleanupInBackground) { + NativeReference nativeRef = (NativeReference) garbageQueue.remove(); + delete(nativeRef); + } + } catch (InterruptedException e) { + // exit + } + } + }); + } catch (Exception e) { + cleanupInBackground = false; + throw e; + } + } + + void stopCleanupThread() { + cleanupInBackground = false; + cleanupService.shutdownNow(); // returns without waiting for the thread to stop + } + + private final ExecutorService cleanupService = Executors.newSingleThreadExecutor(); + private final Map nativeRefs = new IdentityHashMap<>(); + private final ReferenceQueue garbageQueue = new ReferenceQueue<>(); + private volatile boolean cleanupInBackground = false; + } + + private final NativeResourceCollector nativeResources = new NativeResourceCollector(); + private final ResourceCleanupStrategy resourceCleanupStrategy; + private long nativeHandle; + + private void checkSession() { + if (nativeHandle == 0L) { + throw new IllegalStateException("Eager session has been closed"); + } + } + + private static native long allocate(boolean async, int devicePlacementPolicy, byte[] config); + + private static native void delete(long handle); + + static { + TensorFlow.init(); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow/java/src/main/java/org/tensorflow/ExecutionEnvironment.java new file mode 100644 index 00000000000..5cc1930ba5e --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/ExecutionEnvironment.java @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */ +public interface ExecutionEnvironment { + + /** + * Returns a builder to create a new {@link Operation}. + * + * @param type of the Operation (i.e., identifies the computation to be performed) + * @param name to refer to the created Operation in this environment scope. + * @return an {@link OperationBuilder} to create an Operation when {@link + * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, + * then some resources may leak. + */ + OperationBuilder opBuilder(String type, String name); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index d5dae187197..a0e14f1512c 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -25,7 +25,7 @@ import java.util.Iterator; *

WARNING: Resources consumed by the Graph object must be explicitly freed by invoking * the {@link #close()} method then the Graph object is no longer needed. */ -public final class Graph implements AutoCloseable { +public final class Graph implements ExecutionEnvironment, AutoCloseable { /** Create an empty Graph. */ public Graph() { @@ -68,13 +68,13 @@ public final class Graph implements AutoCloseable { * *

Or {@code null} if no such operation exists in the Graph. */ - public Operation operation(String name) { + public GraphOperation operation(String name) { synchronized (nativeHandleLock) { long oph = operation(nativeHandle, name); if (oph == 0) { return null; } - return new Operation(this, oph); + return new GraphOperation(this, oph); } } @@ -97,8 +97,9 @@ public final class Graph implements AutoCloseable { * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, * then some resources may leak. */ - public OperationBuilder opBuilder(String type, String name) { - return new OperationBuilder(this, type, name); + @Override + public GraphOperationBuilder opBuilder(String type, String name) { + return new GraphOperationBuilder(this, type, name); } /** @@ -177,11 +178,11 @@ public final class Graph implements AutoCloseable { try (Reference ref = ref()) { for (int i = 0; i < y.length; ++i) { - yHandles[i] = y[i].op().getUnsafeNativeHandle(); + yHandles[i] = y[i].getUnsafeNativeHandle(); yIndices[i] = y[i].index(); } for (int i = 0; i < x.length; ++i) { - xHandles[i] = x[i].op().getUnsafeNativeHandle(); + xHandles[i] = x[i].getUnsafeNativeHandle(); xIndices[i] = x[i].index(); } if (dx != null && dx.length > 0) { @@ -189,7 +190,7 @@ public final class Graph implements AutoCloseable { dxIndices = new int[dx.length]; for (int i = 0; i < dx.length; ++i) { - dxHandles[i] = dx[i].op().getUnsafeNativeHandle(); + dxHandles[i] = dx[i].getUnsafeNativeHandle(); dxIndices[i] = dx[i].index(); } } @@ -214,7 +215,7 @@ public final class Graph implements AutoCloseable { + " were expected"); } for (int i = 0, j = ndy; i < ndy; ++i, ++j) { - Operation op = new Operation(this, dyHandlesAndIndices[i]); + GraphOperation op = new GraphOperation(this, dyHandlesAndIndices[i]); dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]); } } @@ -286,19 +287,19 @@ public final class Graph implements AutoCloseable { try (Reference ref = subgraph.ref()) { for (int i = 0; i < ninputs; i++) { - Operation op = new Operation(subgraph, inputHandles[i]); - inputs[i] = new Output<>(op, inputIndices[i]); + Operation op = new GraphOperation(subgraph, inputHandles[i]); + inputs[i] = op.output(inputIndices[i]); } for (int i = 0; i < noutputs; i++) { - Operation op = new Operation(subgraph, outputHandles[i]); - outputs[i] = new Output<>(op, outputIndices[i]); + Operation op = new GraphOperation(subgraph, outputHandles[i]); + outputs[i] = op.output(outputIndices[i]); } subgraphBuilder.buildSubgraph(subgraph, inputs, outputs); for (int i = 0, j = noutputs; i < noutputs; i++, j++) { - outputHandlesAndIndices[i] = outputs[i].op().getUnsafeNativeHandle(); + outputHandlesAndIndices[i] = outputs[i].getUnsafeNativeHandle(); outputHandlesAndIndices[j] = (long) outputs[i].index(); } } @@ -329,7 +330,7 @@ public final class Graph implements AutoCloseable { try (Reference ref = ref()) { for (int i = 0; i < ninputs; i++) { - inputHandles[i] = inputs[i].op().getUnsafeNativeHandle(); + inputHandles[i] = inputs[i].getUnsafeNativeHandle(); inputIndices[i] = inputs[i].index(); } @@ -337,8 +338,8 @@ public final class Graph implements AutoCloseable { whileLoop(nativeHandle, inputHandles, inputIndices, name, cgBuilder, bgBuilder); for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { - Operation op = new Operation(this, outputHandlesAndIndices[i]); - outputs[i] = new Output<>(op, (int) outputHandlesAndIndices[j]); + Operation op = new GraphOperation(this, outputHandlesAndIndices[i]); + outputs[i] = op.output((int) 
outputHandlesAndIndices[j]); } } return outputs; @@ -411,7 +412,7 @@ public final class Graph implements AutoCloseable { long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position); if ((nativeReturn != null) && (nativeReturn[0] != 0)) { - this.operation = new Operation(this.graph, nativeReturn[0]); + this.operation = new GraphOperation(this.graph, nativeReturn[0]); this.position = (int) nativeReturn[1]; } } finally { diff --git a/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java new file mode 100644 index 00000000000..be56ac889c1 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/GraphOperation.java @@ -0,0 +1,168 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** + * Implementation for an {@link Operation} added as a node to a {@link Graph}. + * + *

GraphOperation instances are valid only as long as the {@link Graph} they are a part of is + * valid. Thus, if {@link Graph#close()} has been invoked, then methods on the GraphOperation + * instance may fail with an {@code IllegalStateException}. + * + *

GraphOperation instances are immutable and thread-safe. + */ +public final class GraphOperation extends AbstractOperation { + + // Create an GraphOperation instance referring to an operation in g, with the given handle to the + // C + // TF_Operation object. The handle is valid only as long as g has not been closed, hence it is + // called unsafeHandle. Graph.ref() is used to safely use the unsafeHandle. + GraphOperation(Graph g, long unsafeNativeHandle) { + this.graph = g; + this.unsafeNativeHandle = unsafeNativeHandle; + } + + @Override + public String name() { + Graph.Reference r = graph.ref(); + try { + return name(getUnsafeNativeHandle()); + } finally { + r.close(); + } + } + + @Override + public String type() { + Graph.Reference r = graph.ref(); + try { + return type(getUnsafeNativeHandle()); + } finally { + r.close(); + } + } + + @Override + public int numOutputs() { + Graph.Reference r = graph.ref(); + try { + return numOutputs(getUnsafeNativeHandle()); + } finally { + r.close(); + } + } + + @Override + public int outputListLength(final String name) { + Graph.Reference r = graph.ref(); + try { + return outputListLength(getUnsafeNativeHandle(), name); + } finally { + r.close(); + } + } + + @Override + public int hashCode() { + return Long.valueOf(getUnsafeNativeHandle()).hashCode(); + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (!(o instanceof GraphOperation)) { + return false; + } + GraphOperation that = (GraphOperation) o; + if (graph != that.graph) { + return false; + } + + // The graph object is known to be identical here, so this one + // reference is sufficient to validate the use of native pointers + // in both objects. + Graph.Reference r = graph.ref(); + try { + return getUnsafeNativeHandle() == that.getUnsafeNativeHandle(); + } finally { + r.close(); + } + } + + @Override + public int inputListLength(final String name) { + Graph.Reference r = graph.ref(); + try { + return inputListLength(getUnsafeNativeHandle(), name); + } finally { + r.close(); + } + } + + @Override + long getUnsafeNativeHandle(int outputIdx) { + return getUnsafeNativeHandle(); + } + + @Override + long[] shape(int outputIdx) { + Graph.Reference r = graph.ref(); + try { + return shape(r.nativeHandle(), getUnsafeNativeHandle(), outputIdx); + } finally { + r.close(); + } + } + + @Override + DataType dtype(int outputIdx) { + Graph.Reference r = graph.ref(); + try { + return DataType.fromC(dtype(r.nativeHandle(), getUnsafeNativeHandle(), outputIdx)); + } finally { + r.close(); + } + } + + @Override + Tensor tensor(int outputIdx) { + throw new IllegalStateException("Graph tensors must be fetched by running a session"); + } + + long getUnsafeNativeHandle() { + return unsafeNativeHandle; + } + + private final Graph graph; + + private final long unsafeNativeHandle; + + private static native String name(long handle); + + private static native String type(long handle); + + private static native int numOutputs(long handle); + + private static native int outputListLength(long handle, String name); + + private static native int inputListLength(long handle, String name); + + private static native long[] shape(long graphHandle, long opHandle, int output); + + private static native int dtype(long graphHandle, long opHandle, int output); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/GraphOperationBuilder.java new file mode 100644 index 00000000000..7567e1e7251 --- /dev/null +++ 
b/tensorflow/java/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -0,0 +1,344 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import java.nio.charset.Charset; + +/** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */ +public final class GraphOperationBuilder implements OperationBuilder { + + GraphOperationBuilder(Graph graph, String type, String name) { + this.graph = graph; + Graph.Reference r = graph.ref(); + try { + this.unsafeNativeHandle = allocate(r.nativeHandle(), type, name); + } finally { + r.close(); + } + } + + /** + * Add the {@link GraphOperation} being built to the {@link Graph}. + * + *

The OperationBuilder is not usable after build() returns. + */ + @Override + public GraphOperation build() { + Graph.Reference r = graph.ref(); + try { + GraphOperation op = new GraphOperation(graph, finish(unsafeNativeHandle)); + unsafeNativeHandle = 0; + return op; + } finally { + r.close(); + } + } + + @Override + public GraphOperationBuilder addControlInput(Operation control) { + if (!(control instanceof GraphOperation)) { + throw new IllegalArgumentException( + "Only GraphOperation instances can be used as control inputs"); + } + Graph.Reference r = graph.ref(); + try { + addControlInput(unsafeNativeHandle, ((GraphOperation) control).getUnsafeNativeHandle()); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder addInput(Output input) { + Graph.Reference r = graph.ref(); + try { + addInput(unsafeNativeHandle, input.getUnsafeNativeHandle(), input.index()); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder addInputList(Output[] inputs) { + Graph.Reference r = graph.ref(); + try { + long[] opHandles = new long[inputs.length]; + int[] indices = new int[inputs.length]; + for (int i = 0; i < inputs.length; ++i) { + opHandles[i] = inputs[i].getUnsafeNativeHandle(); + indices[i] = inputs[i].index(); + } + addInputList(unsafeNativeHandle, opHandles, indices); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setDevice(String device) { + Graph.Reference r = graph.ref(); + try { + setDevice(unsafeNativeHandle, device); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, String value) { + setAttr(name, value.getBytes(Charset.forName("UTF-8"))); + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, byte[] value) { + Graph.Reference r = graph.ref(); + try { + setAttrString(unsafeNativeHandle, name, value); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, long value) { + Graph.Reference r = graph.ref(); + try { + setAttrInt(unsafeNativeHandle, name, value); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, long[] value) { + Graph.Reference r = graph.ref(); + try { + setAttrIntList(unsafeNativeHandle, name, value); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, float value) { + Graph.Reference r = graph.ref(); + try { + setAttrFloat(unsafeNativeHandle, name, value); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, float[] value) { + Graph.Reference r = graph.ref(); + try { + setAttrFloatList(unsafeNativeHandle, name, value); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, boolean value) { + Graph.Reference r = graph.ref(); + try { + setAttrBool(unsafeNativeHandle, name, value); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, boolean[] value) { + Graph.Reference r = graph.ref(); + try { + setAttrBoolList(unsafeNativeHandle, name, value); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, DataType value) { + Graph.Reference r = graph.ref(); + try { + setAttrType(unsafeNativeHandle, name, value.c()); + } finally { + 
r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, DataType[] value) { + int[] ctypes = new int[value.length]; + for (int i = 0; i < value.length; ++i) { + ctypes[i] = value[i].c(); + } + Graph.Reference r = graph.ref(); + try { + setAttrTypeList(unsafeNativeHandle, name, ctypes); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, Tensor value) { + Graph.Reference r = graph.ref(); + try { + setAttrTensor(unsafeNativeHandle, name, value.getNativeHandle()); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, Tensor[] value) { + long[] handles = new long[value.length]; + int idx = 0; + for (Tensor t : value) { + handles[idx++] = t.getNativeHandle(); + } + Graph.Reference r = graph.ref(); + try { + setAttrTensorList(unsafeNativeHandle, name, handles); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, Shape value) { + Graph.Reference r = graph.ref(); + try { + setAttrShape(unsafeNativeHandle, name, value.asArray(), value.numDimensions()); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, Shape[] value) { + int[] numDimensions = new int[value.length]; + int totalNumDimensions = 0; + for (int idx = 0; idx < value.length; ++idx) { + int n = value[idx].numDimensions(); + numDimensions[idx] = n; + if (n > 0) { + totalNumDimensions += n; + } + } + // Flatten the shapes into a single array to avoid too much overhead in the + // native part + long[] shapes = new long[totalNumDimensions]; + int shapeIdx = 0; + for (Shape shape : value) { + if (shape.numDimensions() > 0) { + for (long dim : shape.asArray()) { + shapes[shapeIdx++] = dim; + } + } + } + Graph.Reference r = graph.ref(); + try { + setAttrShapeList(unsafeNativeHandle, name, shapes, numDimensions); + } finally { + r.close(); + } + return this; + } + + @Override + public GraphOperationBuilder setAttr(String name, String[] value) { + Charset utf8 = Charset.forName("UTF-8"); + Object[] objects = new Object[value.length]; + for (int i = 0; i < value.length; ++i) { + objects[i] = value[i].getBytes(utf8); + } + Graph.Reference r = graph.ref(); + try { + setAttrStringList(unsafeNativeHandle, name, objects); + } finally { + r.close(); + } + return this; + } + + private long unsafeNativeHandle; + private Graph graph; + + private static native long allocate(long graphHandle, String type, String name); + + private static native long finish(long handle); + + private static native void addInput(long handle, long opHandle, int index); + + private static native void addInputList(long handle, long[] opHandles, int[] indices); + + private static native void addControlInput(long handle, long opHandle); + + private static native void setDevice(long handle, String device); + + // The names of all the setAttr* family functions below correspond to the C library types, not the + // Java library types. Roughly, setAttrFoo calls the TensorFlow C library function: TF_SetAttrFoo. 
+ + private static native void setAttrString(long handle, String name, byte[] value); + + private static native void setAttrInt(long handle, String name, long value); + + private static native void setAttrIntList(long handle, String name, long[] value); + + private static native void setAttrFloat(long handle, String name, float value); + + private static native void setAttrFloatList(long handle, String name, float[] value); + + private static native void setAttrBool(long handle, String name, boolean value); + + private static native void setAttrBoolList(long handle, String name, boolean[] value); + + private static native void setAttrType(long handle, String name, int type); + + private static native void setAttrTypeList(long handle, String name, int[] type); + + private static native void setAttrTensor(long handle, String name, long tensorHandle); + + private static native void setAttrTensorList(long handle, String name, long[] tensorHandle); + + private static native void setAttrShape(long handle, String name, long[] shape, int numDims); + + private static native void setAttrShapeList( + long handle, String name, long[] shapes, int[] numDims); + + private static native void setAttrStringList(long handle, String name, Object[] value); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java index 6b82e5780b0..7dae6c263f3 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Operation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java @@ -16,59 +16,24 @@ limitations under the License. package org.tensorflow; /** - * A Graph node that performs computation on Tensors. + * Performs computation on Tensors. * - *

An Operation is a node in a {@link Graph} that takes zero or more {@link Tensor}s (produced by - * other Operations in the Graph) as input, and produces zero or more {@link Tensor}s as output. - * - *

Operation instances are valid only as long as the Graph they are a part of is valid. Thus, if - * {@link Graph#close()} has been invoked, then methods on the Operation instance may fail with an - * {@code IllegalStateException}. - * - *

Operation instances are immutable and thread-safe. + *

An Operation takes zero or more {@link Tensor}s (produced by other Operations) as input, and + * produces zero or more {@link Tensor}s as output. */ -public final class Operation { - - // Create an Operation instance referring to an operation in g, with the given handle to the C - // TF_Operation object. The handle is valid only as long as g has not been closed, hence it is - // called unsafeHandle. Graph.ref() is used to safely use the unsafeHandle. - Operation(Graph g, long unsafeNativeHandle) { - this.graph = g; - this.unsafeNativeHandle = unsafeNativeHandle; - } +public interface Operation { /** Returns the full name of the Operation. */ - public String name() { - Graph.Reference r = graph.ref(); - try { - return name(unsafeNativeHandle); - } finally { - r.close(); - } - } + String name(); /** * Returns the type of the operation, i.e., the name of the computation performed by the * operation. */ - public String type() { - Graph.Reference r = graph.ref(); - try { - return type(unsafeNativeHandle); - } finally { - r.close(); - } - } + String type(); /** Returns the number of tensors produced by this operation. */ - public int numOutputs() { - Graph.Reference r = graph.ref(); - try { - return numOutputs(unsafeNativeHandle); - } finally { - r.close(); - } - } + int numOutputs(); /** * Returns the size of the list of Tensors produced by this operation. @@ -82,14 +47,7 @@ public final class Operation { * @return the size of the list of Tensors produced by this named output. * @throws IllegalArgumentException if this operation has no output with the provided name. */ - public int outputListLength(final String name) { - Graph.Reference r = graph.ref(); - try { - return outputListLength(unsafeNativeHandle, name); - } finally { - r.close(); - } - } + int outputListLength(final String name); /** * Returns symbolic handles to a list of tensors produced by this operation. @@ -98,13 +56,7 @@ public final class Operation { * @param length number of tensors in the list * @return array of {@code Output} */ - public Output[] outputList(int idx, int length) { - Output[] outputs = new Output[length]; - for (int i = 0; i < length; ++i) { - outputs[i] = output(idx + i); - } - return outputs; - } + Output[] outputList(int idx, int length); /** * Returns a symbolic handle to one of the tensors produced by this operation. @@ -116,44 +68,7 @@ public final class Operation { * @param The expected element type of the tensors produced by this output. * @param idx The index of the output among the outputs produced by this operation. */ - @SuppressWarnings({"rawtypes", "unchecked"}) - public Output output(int idx) { - return new Output(this, idx); - } - - @Override - public int hashCode() { - return Long.valueOf(unsafeNativeHandle).hashCode(); - } - - @Override - public boolean equals(Object o) { - if (o == this) { - return true; - } - if (!(o instanceof Operation)) { - return false; - } - Operation that = (Operation) o; - if (graph != that.graph) { - return false; - } - - // The graph object is known to be identical here, so this one - // reference is sufficient to validate the use of native pointers - // in both objects. - Graph.Reference r = graph.ref(); - try { - return unsafeNativeHandle == that.unsafeNativeHandle; - } finally { - r.close(); - } - } - - @Override - public String toString() { - return String.format("<%s '%s'>", type(), name()); - } + Output output(int idx); /** * Returns the size of the given inputs list of Tensors for this operation. 
@@ -167,54 +82,5 @@ public final class Operation { * @return the size of the list of Tensors produced by this named input. * @throws IllegalArgumentException if this operation has no input with the provided name. */ - public int inputListLength(final String name) { - Graph.Reference r = graph.ref(); - try { - return inputListLength(unsafeNativeHandle, name); - } finally { - r.close(); - } - } - - long getUnsafeNativeHandle() { - return unsafeNativeHandle; - } - - // Package private, meant primarily for the public Output.shape() method. - long[] shape(int output) { - Graph.Reference r = graph.ref(); - try { - return shape(r.nativeHandle(), unsafeNativeHandle, output); - } finally { - r.close(); - } - } - - // Package private, meant primarily for the public Output.dataType() method. - DataType dtype(int output) { - Graph.Reference r = graph.ref(); - try { - return DataType.fromC(dtype(r.nativeHandle(), unsafeNativeHandle, output)); - } finally { - r.close(); - } - } - - private final long unsafeNativeHandle; - - private final Graph graph; - - private static native String name(long handle); - - private static native String type(long handle); - - private static native int numOutputs(long handle); - - private static native int outputListLength(long handle, String name); - - private static native int inputListLength(long handle, String name); - - private static native long[] shape(long graphHandle, long opHandle, int output); - - private static native int dtype(long graphHandle, long opHandle, int output); + int inputListLength(final String name); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java index a24150484e8..d78f404fb16 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java @@ -15,72 +15,54 @@ limitations under the License. package org.tensorflow; -import java.nio.charset.Charset; - /** - * A builder for {@link Operation}s in a {@link Graph}. + * A builder for {@link Operation}s. * - *

Instances of an OperationBuilder are not thread-safe. - * - *

A builder for adding {@link Operation}s to a {@link Graph}. For example, the following uses - * the builder to create an operation that produces the constant "3" as its output: + *

For example, the following uses the builder to create an operation that produces the constant + * "3" as its output: * *

{@code
- * // g is a Graph instance.
+ * // env is an ExecutionEnvironment, such as a Graph instance.
  * try (Tensor c1 = Tensor.create(3.0f)) {
- *   g.opBuilder("Const", "MyConst")
+ *   env.opBuilder("Const", "MyConst")
  *       .setAttr("dtype", c1.dataType())
  *       .setAttr("value", c1)
  *       .build();
  * }
  * }
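For reference, a minimal end-to-end sketch of this pattern against a Graph, the graph-backed ExecutionEnvironment the updated Javadoc refers to; it uses only classes already present in this package, the class name ConstantSketch is purely illustrative, and the Session fetch at the end is just one way to read the constant back:

    import org.tensorflow.Graph;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;

    public final class ConstantSketch {
      public static void main(String[] args) {
        try (Graph g = new Graph();
             Tensor<?> c1 = Tensor.create(3.0f)) {
          // Graph plays the role of "env" in the Javadoc example above.
          g.opBuilder("Const", "MyConst")
              .setAttr("dtype", c1.dataType())
              .setAttr("value", c1)
              .build();
          // In graph mode the node is only recorded; a Session evaluates it.
          try (Session s = new Session(g);
               Tensor<?> result = s.runner().fetch("MyConst").run().get(0)) {
            System.out.println(result.floatValue()); // prints 3.0
          }
        }
      }
    }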
*/ -public final class OperationBuilder { - - OperationBuilder(Graph graph, String type, String name) { - this.graph = graph; - Graph.Reference r = graph.ref(); - try { - this.unsafeNativeHandle = allocate(r.nativeHandle(), type, name); - } finally { - r.close(); - } - } +public interface OperationBuilder { /** - * Add the {@link Operation} being built to the {@link Graph}. + * Build the {@link Operation}. + * + *

One of the following actions will also be performed, depending on the current execution environment. + * + *

    + *
  • In eager mode, the result of the operation will be computed immediately. + *
  • In graph mode, the operation will be added as a node to the graph to be executed later, + * when running a {@link Session}. + *
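A sketch of the eager case follows (the graph case is sketched earlier, after the constant-building example). It assumes an EagerSession that implements ExecutionEnvironment, is AutoCloseable, and is obtained through an EagerSession.create() factory; those details belong to the companion eager-execution changes rather than this hunk, so treat them as assumptions, not the final API:

    import org.tensorflow.EagerSession;
    import org.tensorflow.Output;
    import org.tensorflow.Tensor;

    public final class EagerBuildSketch {
      public static void main(String[] args) {
        // EagerSession.create() is an assumed factory for the eager environment.
        try (EagerSession env = EagerSession.create();
             Tensor<?> c = Tensor.create(3.0f)) {
          Output<?> out = env.opBuilder("Const", "MyConst")
              .setAttr("dtype", c.dataType())
              .setAttr("value", c)
              .build()      // executes immediately in eager mode
              .output(0);
          // No Session needed: the result is already materialized, so the
          // Output.tensor() accessor added in this change can read it directly.
          System.out.println(out.tensor().floatValue()); // 3.0
        }
      }
    }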
* *

The OperationBuilder is not usable after build() returns. */ - public Operation build() { - Graph.Reference r = graph.ref(); - try { - Operation op = new Operation(graph, finish(unsafeNativeHandle)); - unsafeNativeHandle = 0; - return op; - } finally { - r.close(); - } - } + public Operation build(); /** - * Returns the builder to create an operation. + * Add the output of another operation as the next input of the operation being built. * - *

Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is - * used to add a input to a {@link OperationBuilder}. - * - * @param input {@link Output} supposed to be the input of the OperationBuilder. + * @param input {@link Output} supposed to be the input of the operation being built. * @return the OperationBuilder instance for chaining. */ - public OperationBuilder addInput(Output input) { - Graph.Reference r = graph.ref(); - try { - addInput(unsafeNativeHandle, input.op().getUnsafeNativeHandle(), input.index()); - } finally { - r.close(); - } - return this; - } + public OperationBuilder addInput(Output input); + + /** + * Add the outputs of another operation as the next inputs of the operation being built. + * + * @param inputs list of {@link Output} supposed to be the inputs of the operation being built. + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder addInputList(Output[] inputs); /** * Ensure that the operation does not execute before the control operation does. @@ -95,265 +77,148 @@ public final class OperationBuilder { * @param control operation that must be executed before running this operation. * @return the OperationBuilder instance for chaining. */ - public OperationBuilder addControlInput(Operation control) { - Graph.Reference r = graph.ref(); - try { - addControlInput(unsafeNativeHandle, control.getUnsafeNativeHandle()); - } finally { - r.close(); - } - return this; - } + public OperationBuilder addControlInput(Operation control); - public OperationBuilder addInputList(Output[] inputs) { - Graph.Reference r = graph.ref(); - try { - long[] opHandles = new long[inputs.length]; - int[] indices = new int[inputs.length]; - for (int i = 0; i < inputs.length; ++i) { - opHandles[i] = inputs[i].op().getUnsafeNativeHandle(); - indices[i] = inputs[i].index(); - } - addInputList(unsafeNativeHandle, opHandles, indices); - } finally { - r.close(); - } - return this; - } + /** + * Set the device requested for computing the operation being built. + * + * @param device the requested device, as a string + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setDevice(String device); - public OperationBuilder setDevice(String device) { - Graph.Reference r = graph.ref(); - try { - setDevice(unsafeNativeHandle, device); - } finally { - r.close(); - } - return this; - } + /** + * Set the string values of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute values + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, String[] value); - public OperationBuilder setAttr(String name, String value) { - setAttr(name, value.getBytes(Charset.forName("UTF-8"))); - return this; - } + /** + * Set the string value of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, String value); - public OperationBuilder setAttr(String name, byte[] value) { - Graph.Reference r = graph.ref(); - try { - setAttrString(unsafeNativeHandle, name, value); - } finally { - r.close(); - } - return this; - } + /** + * Set the byte values of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute values + * @return the OperationBuilder instance for chaining. 
+ */ + public OperationBuilder setAttr(String name, byte[] value); - public OperationBuilder setAttr(String name, long value) { - Graph.Reference r = graph.ref(); - try { - setAttrInt(unsafeNativeHandle, name, value); - } finally { - r.close(); - } - return this; - } + /** + * Set the long value of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, long value); - public OperationBuilder setAttr(String name, long[] value) { - Graph.Reference r = graph.ref(); - try { - setAttrIntList(unsafeNativeHandle, name, value); - } finally { - r.close(); - } - return this; - } + /** + * Set the long values of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute values + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, long[] value); - public OperationBuilder setAttr(String name, float value) { - Graph.Reference r = graph.ref(); - try { - setAttrFloat(unsafeNativeHandle, name, value); - } finally { - r.close(); - } - return this; - } + /** + * Set the float value of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, float value); - public OperationBuilder setAttr(String name, float[] value) { - Graph.Reference r = graph.ref(); - try { - setAttrFloatList(unsafeNativeHandle, name, value); - } finally { - r.close(); - } - return this; - } + /** + * Set the float values of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute values + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, float[] value); - public OperationBuilder setAttr(String name, boolean value) { - Graph.Reference r = graph.ref(); - try { - setAttrBool(unsafeNativeHandle, name, value); - } finally { - r.close(); - } - return this; - } + /** + * Set the boolean value of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, boolean value); - public OperationBuilder setAttr(String name, boolean[] value) { - Graph.Reference r = graph.ref(); - try { - setAttrBoolList(unsafeNativeHandle, name, value); - } finally { - r.close(); - } - return this; - } + /** + * Set the boolean values of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute values + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, boolean[] value); - public OperationBuilder setAttr(String name, DataType value) { - Graph.Reference r = graph.ref(); - try { - setAttrType(unsafeNativeHandle, name, value.c()); - } finally { - r.close(); - } - return this; - } + /** + * Set the type value of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. 
+ */ + public OperationBuilder setAttr(String name, DataType value); - public OperationBuilder setAttr(String name, DataType[] value) { - int[] ctypes = new int[value.length]; - for (int i = 0; i < value.length; ++i) { - ctypes[i] = value[i].c(); - } - Graph.Reference r = graph.ref(); - try { - setAttrTypeList(unsafeNativeHandle, name, ctypes); - } finally { - r.close(); - } - return this; - } + /** + * Set the type values of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute values + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, DataType[] value); - public OperationBuilder setAttr(String name, Tensor value) { - Graph.Reference r = graph.ref(); - try { - setAttrTensor(unsafeNativeHandle, name, value.getNativeHandle()); - } finally { - r.close(); - } - return this; - } + /** + * Set the tensor value of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, Tensor value); - public OperationBuilder setAttr(String name, Tensor[] value) { - long[] handles = new long[value.length]; - int idx = 0; - for (Tensor t : value) { - handles[idx++] = t.getNativeHandle(); - } - Graph.Reference r = graph.ref(); - try { - setAttrTensorList(unsafeNativeHandle, name, handles); - } finally { - r.close(); - } - return this; - } + /** + * Set the tensor values of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute values + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder setAttr(String name, Tensor[] value); - public OperationBuilder setAttr(String name, Shape value) { - Graph.Reference r = graph.ref(); - try { - setAttrShape(unsafeNativeHandle, name, value.asArray(), value.numDimensions()); - } finally { - r.close(); - } - return this; - } + /** + * Set the shape value of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute value + * @return the OperationBuilder instance for chaining. 
+ */ + public OperationBuilder setAttr(String name, Shape value); - public OperationBuilder setAttr(String name, Shape[] value) { - int[] numDimensions = new int[value.length]; - int totalNumDimensions = 0; - for (int idx = 0; idx < value.length; ++idx) { - int n = value[idx].numDimensions(); - numDimensions[idx] = n; - if (n > 0) { - totalNumDimensions += n; - } - } - // Flatten the shapes into a single array to avoid too much overhead in the - // native part - long[] shapes = new long[totalNumDimensions]; - int shapeIdx = 0; - for (Shape shape : value) { - if (shape.numDimensions() > 0) { - for (long dim : shape.asArray()) { - shapes[shapeIdx++] = dim; - } - } - } - Graph.Reference r = graph.ref(); - try { - setAttrShapeList(unsafeNativeHandle, name, shapes, numDimensions); - } finally { - r.close(); - } - return this; - } - - public OperationBuilder setAttr(String name, String[] value) { - Charset utf8 = Charset.forName("UTF-8"); - Object[] objects = new Object[value.length]; - for (int i = 0; i < value.length; ++i) { - objects[i] = value[i].getBytes(utf8); - } - Graph.Reference r = graph.ref(); - try { - setAttrStringList(unsafeNativeHandle, name, objects); - } finally { - r.close(); - } - return this; - } - - private long unsafeNativeHandle; - private Graph graph; - - private static native long allocate(long graphHandle, String type, String name); - - private static native long finish(long handle); - - private static native void addInput(long handle, long opHandle, int index); - - private static native void addInputList(long handle, long[] opHandles, int[] indices); - - private static native void addControlInput(long handle, long opHandle); - - private static native void setDevice(long handle, String device); - - // The names of all the setAttr* family functions below correspond to the C library types, not the - // Java library types. Roughly, setAttrFoo calls the TensorFlow C library function: TF_SetAttrFoo. - - private static native void setAttrString(long handle, String name, byte[] value); - - private static native void setAttrInt(long handle, String name, long value); - - private static native void setAttrIntList(long handle, String name, long[] value); - - private static native void setAttrFloat(long handle, String name, float value); - - private static native void setAttrFloatList(long handle, String name, float[] value); - - private static native void setAttrBool(long handle, String name, boolean value); - - private static native void setAttrBoolList(long handle, String name, boolean[] value); - - private static native void setAttrType(long handle, String name, int type); - - private static native void setAttrTypeList(long handle, String name, int[] type); - - private static native void setAttrTensor(long handle, String name, long tensorHandle); - - private static native void setAttrTensorList(long handle, String name, long[] tensorHandle); - - private static native void setAttrShape(long handle, String name, long[] shape, int numDims); - - private static native void setAttrShapeList( - long handle, String name, long[] shapes, int[] numDims); - - private static native void setAttrStringList(long handle, String name, Object[] value); + /** + * Set the shape values of an attribute of the operation being built. + * + * @param name attribute name + * @param value attribute values + * @return the OperationBuilder instance for chaining. 
+ */ + public OperationBuilder setAttr(String name, Shape[] value); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java index 479dc8574c2..f6fc1ac8cfe 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Output.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java @@ -28,12 +28,6 @@ import java.util.Objects; */ public final class Output implements Operand { - /** Handle to the idx-th output of the Operation {@code op}. */ - public Output(Operation op, int idx) { - operation = op; - index = idx; - } - /** Returns the Operation that will produce the tensor referred to by this Output. */ public Operation op() { return operation; @@ -54,6 +48,22 @@ public final class Output implements Operand { return operation.dtype(index); } + /** + * Returns the tensor at this output. + * + *

This operation is only supported on the outputs of an operation executed eagerly. For graph + * environments, output tensors must be fetched by running a session, using {@link + * Session.Runner#fetch(Output)}. + * + * @return tensor + * @throws IllegalStateException if this output results from a graph + * @see EagerSession + */ + @SuppressWarnings("unchecked") + public Tensor tensor() { + return (Tensor) operation.tensor(index); + } + @Override public Output asOutput() { return this; @@ -83,6 +93,16 @@ public final class Output implements Operand { operation.type(), operation.name(), index, shape().toString(), dataType()); } - private final Operation operation; + /** Handle to the idx-th output of the Operation {@code op}. */ + Output(AbstractOperation op, int idx) { + operation = op; + index = idx; + } + + long getUnsafeNativeHandle() { + return operation.getUnsafeNativeHandle(index); + } + + private final AbstractOperation operation; private final int index; } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java index 8cc23e2991b..b5e0f7ac508 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Session.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java @@ -204,7 +204,7 @@ public final class Session implements AutoCloseable { * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s. */ public Runner addTarget(String operation) { - Operation op = operationByName(operation); + GraphOperation op = operationByName(operation); if (op != null) { targets.add(op); } @@ -213,9 +213,17 @@ public final class Session implements AutoCloseable { /** * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s. 
+ * + * @throws execption if the operation is not a {@link GraphOperation} */ public Runner addTarget(Operation operation) { - targets.add(operation); + if (!(operation instanceof GraphOperation)) { + throw new IllegalArgumentException( + "Operation of type " + + operation.getClass().getName() + + " is not supported in graph sessions"); + } + targets.add((GraphOperation) operation); return this; } @@ -293,18 +301,18 @@ public final class Session implements AutoCloseable { } idx = 0; for (Output o : inputs) { - inputOpHandles[idx] = o.op().getUnsafeNativeHandle(); + inputOpHandles[idx] = o.getUnsafeNativeHandle(); inputOpIndices[idx] = o.index(); idx++; } idx = 0; for (Output o : outputs) { - outputOpHandles[idx] = o.op().getUnsafeNativeHandle(); + outputOpHandles[idx] = o.getUnsafeNativeHandle(); outputOpIndices[idx] = o.index(); idx++; } idx = 0; - for (Operation op : targets) { + for (GraphOperation op : targets) { targetOpHandles[idx++] = op.getUnsafeNativeHandle(); } Reference runRef = new Reference(); @@ -366,8 +374,8 @@ public final class Session implements AutoCloseable { } } - private Operation operationByName(String opName) { - Operation op = graph.operation(opName); + private GraphOperation operationByName(String opName) { + GraphOperation op = graph.operation(opName); if (op == null) { throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph"); } @@ -392,7 +400,7 @@ public final class Session implements AutoCloseable { private ArrayList> inputs = new ArrayList>(); private ArrayList> inputTensors = new ArrayList>(); private ArrayList> outputs = new ArrayList>(); - private ArrayList targets = new ArrayList(); + private ArrayList targets = new ArrayList(); private byte[] runOptions = null; } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 89872537689..ebc5b01ee85 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -140,15 +140,17 @@ public final class Tensor implements AutoCloseable { Tensor t = new Tensor(dtype); t.shapeCopy = new long[numDimensions(obj, dtype)]; fillShape(obj, 0, t.shapeCopy); + long nativeHandle; if (t.dtype != DataType.STRING) { int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy); - t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize); - setValue(t.nativeHandle, obj); + nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize); + setValue(nativeHandle, obj); } else if (t.shapeCopy.length != 0) { - t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj); + nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj); } else { - t.nativeHandle = allocateScalarBytes((byte[]) obj); + nativeHandle = allocateScalarBytes((byte[]) obj); } + t.nativeRef = new NativeReference(nativeHandle); return t; } @@ -314,23 +316,22 @@ public final class Tensor implements AutoCloseable { } Tensor t = new Tensor(dataType); t.shapeCopy = Arrays.copyOf(shape, shape.length); - t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes); + long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes); + t.nativeRef = new NativeReference(nativeHandle); return t; } /** * Release resources associated with the Tensor. * - *

WARNING:If not invoked, memory will be leaked. + *

WARNING:This must be invoked for all tensors that have not been produced by an eager + * operation or memory will be leaked. + * *

The Tensor object is no longer usable after {@code close} returns. */ @Override public void close() { - if (nativeHandle != 0) { - delete(nativeHandle); - nativeHandle = 0; - } + nativeRef.release(); } /** Returns the {@link DataType} of elements stored in the Tensor. */ @@ -374,7 +375,7 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException if the Tensor does not represent a float scalar. */ public float floatValue() { - return scalarFloat(nativeHandle); + return scalarFloat(getNativeHandle()); } /** @@ -383,7 +384,7 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException if the Tensor does not represent a double scalar. */ public double doubleValue() { - return scalarDouble(nativeHandle); + return scalarDouble(getNativeHandle()); } /** @@ -392,7 +393,7 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException if the Tensor does not represent a int scalar. */ public int intValue() { - return scalarInt(nativeHandle); + return scalarInt(getNativeHandle()); } /** @@ -401,7 +402,7 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException if the Tensor does not represent a long scalar. */ public long longValue() { - return scalarLong(nativeHandle); + return scalarLong(getNativeHandle()); } /** @@ -410,7 +411,7 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. */ public boolean booleanValue() { - return scalarBoolean(nativeHandle); + return scalarBoolean(getNativeHandle()); } /** @@ -419,7 +420,7 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar. */ public byte[] bytesValue() { - return scalarBytes(nativeHandle); + return scalarBytes(getNativeHandle()); } /** @@ -448,7 +449,7 @@ public final class Tensor implements AutoCloseable { */ public U copyTo(U dst) { throwExceptionIfTypeIsIncompatible(dst); - readNDArray(nativeHandle, dst); + readNDArray(getNativeHandle(), dst); return dst; } @@ -553,16 +554,27 @@ public final class Tensor implements AutoCloseable { @SuppressWarnings("rawtypes") Tensor t = new Tensor(DataType.fromC(dtype(handle))); t.shapeCopy = shape(handle); - t.nativeHandle = handle; + t.nativeRef = new NativeReference(handle); + return t; + } + + /** + * Create an eager Tensor object from a handle to the C TF_Tensor object. + * + *

Takes ownership of the handle. + */ + static Tensor fromHandle(long handle, EagerSession session) { + Tensor t = fromHandle(handle); + t.nativeRef.eager(session, t); return t; } long getNativeHandle() { - return nativeHandle; + return nativeRef.tensorHandle; } - private long nativeHandle; - private DataType dtype; + private NativeReference nativeRef = null; + private final DataType dtype; private long[] shapeCopy = null; private Tensor(DataType t) { @@ -570,7 +582,7 @@ public final class Tensor implements AutoCloseable { } private ByteBuffer buffer() { - return buffer(nativeHandle).order(ByteOrder.nativeOrder()); + return buffer(getNativeHandle()).order(ByteOrder.nativeOrder()); } private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) { @@ -609,6 +621,65 @@ public final class Tensor implements AutoCloseable { } } + /** + * Reference to the underlying native tensor + * + *

Tensors are commonly allocated in a `try-with-resources` statement, where they are + * released automatically when the `try` block they were declared in exits. + * + *
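A concrete illustration of that lifetime, as a sketch against the public Tensor API only (the eager-session case described next is omitted because the session releases those tensors itself):

    import org.tensorflow.Tensor;

    public final class TensorLifetimeSketch {
      public static void main(String[] args) {
        // Released automatically when the try-with-resources block exits,
        // whether it completes normally or with an exception.
        try (Tensor<?> t = Tensor.create(new float[] {1f, 2f, 3f})) {
          System.out.println(t.numElements()); // 3
        }
        // Or released explicitly with close(); skipping this leaks native memory
        // for any tensor that was not produced by an eager operation.
        Tensor<?> u = Tensor.create(42);
        try {
          System.out.println(u.intValue()); // 42
        } finally {
          u.close();
        }
      }
    }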

They can also be attached to an eager session, in which case their lifetime ends either + * when that session is closed or when the Tensor instance is no longer referenced and has been + * garbage-collected. + * + *

This helper class wraps the tensor native handle and support both situations; If an eager + * reference to the tensor exists, it will take care of releasing the tensor at the end of its + * life. If the tensor is being explicetly closed before this happens, it will take cake of + * clearing its association with any eager session before cleaning up the resources. + */ + private static class NativeReference { + + /** Attaches this reference to an eager session */ + private class EagerReference extends EagerSession.NativeReference { + + EagerReference(EagerSession session, Tensor tensor) { + super(session, tensor); + } + + @Override + void delete() { + // Mark this eager reference as cleared since it has been deleted by the session + NativeReference.this.eagerRef = null; + NativeReference.this.release(); + } + } + + NativeReference(long tensorHandle) { + this.tensorHandle = tensorHandle; + } + + void eager(EagerSession session, Tensor tensor) { + if (eagerRef != null) { + throw new IllegalStateException("The tensor is already attached to an eager session"); + } + eagerRef = new EagerReference(session, tensor); + } + + synchronized void release() { + if (tensorHandle != 0L) { + // Clear any remaining eager reference to this tensor + if (eagerRef != null) { + eagerRef.clear(); + eagerRef = null; + } + Tensor.delete(tensorHandle); + tensorHandle = 0L; + } + } + + private long tensorHandle; + private EagerReference eagerRef; + } + private static HashMap, DataType> classDataTypes = new HashMap<>(); static { diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Op.java b/tensorflow/java/src/main/java/org/tensorflow/op/Op.java index aa6db404571..9ac2158fe6f 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Op.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Op.java @@ -18,8 +18,8 @@ package org.tensorflow.op; /** * A marker interface for all operation wrappers. * - *

Operation wrappers provide strongly typed interfaces for building operations and linking them - * into a graph without the use of literals and indexes required by the core classes. + *

Operation wrappers provide strongly typed interfaces for building and executing operations + * without the use of the literals and indexes required by the core classes. + *

This interface allows keeping references to any operation wrapper using a common type. * diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index 5a233bcc984..ccbf776cbe8 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -15,7 +15,11 @@ limitations under the License. package org.tensorflow.op; -import org.tensorflow.Graph; +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Operand; +import org.tensorflow.OperationBuilder; + +import java.util.ArrayList; /** * Manages groups of related properties when creating Tensorflow Operations, such as a common name @@ -78,15 +82,15 @@ public final class Scope { /** * Create a new top-level scope. * - * @param graph The graph instance to be managed by the scope. + * @param env The execution environment used by the scope. */ - public Scope(Graph graph) { - this(graph, new NameScope()); + public Scope(ExecutionEnvironment env) { + this(env, new NameScope(), new ArrayList>()); } - /** Returns the graph managed by this scope. */ - public Graph graph() { - return graph; + /** Returns the execution environment used by this scope. */ + public ExecutionEnvironment env() { + return env; } /** @@ -103,7 +107,7 @@ public final class Scope { * @throws IllegalArgumentException if the name is invalid */ public Scope withSubScope(String childScopeName) { - return new Scope(graph, nameScope.withSubScope(childScopeName)); + return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies); } /** @@ -119,7 +123,7 @@ public final class Scope { * @throws IllegalArgumentException if the name is invalid */ public Scope withName(String opName) { - return new Scope(graph, nameScope.withName(opName)); + return new Scope(env, nameScope.withName(opName), controlDependencies); } /** @@ -131,12 +135,12 @@ public final class Scope { * instance. Typical operator building code might look like * *

{@code
-   * scope.graph().opBuilder("Const", scope.makeOpName("Const"))...
+   * scope.env().opBuilder("Const", scope.makeOpName("Const"))...
    * }
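A slightly fuller sketch of this pattern, combining withSubScope, makeOpName and the new env() accessor with the control-dependency helpers added further down in this file; it assumes Graph serves as the ExecutionEnvironment behind the scope (consistent with the Constant and Gradients changes in this diff), and the class name ScopeSketch is illustrative:

    import java.util.Collections;
    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.Output;
    import org.tensorflow.Tensor;
    import org.tensorflow.op.Scope;

    public final class ScopeSketch {
      public static void main(String[] args) {
        try (Graph g = new Graph();
             Tensor<?> c = Tensor.create(1.0f)) {
          Scope root = new Scope(g);                  // top-level scope over an ExecutionEnvironment
          Scope layer = root.withSubScope("layer1");  // ops created here are named "layer1/..."
          Output<?> first = layer.env()
              .opBuilder("Const", layer.makeOpName("Const"))  // e.g. "layer1/Const"
              .setAttr("dtype", c.dataType())
              .setAttr("value", c)
              .build()
              .output(0);
          // New in this change: a scope that carries control dependencies. Builders
          // passed through applyControlDependencies get a control edge from `first`.
          Scope gated = root.withControlDependencies(
              Collections.<Operand<?>>singletonList(first));
          gated.applyControlDependencies(
                  gated.env().opBuilder("Const", gated.makeOpName("Const")))
              .setAttr("dtype", c.dataType())
              .setAttr("value", c)
              .build();
        }
      }
    }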
* - *

Note: if you provide a composite operator building class (i.e, a class that adds a - * set of related operations to the graph by calling other operator building code), the provided - * name will act as a subscope to all underlying operators. + *

Note: if you provide a composite operator building class (i.e, a class that creates a + * set of related operations by calling other operator building code), the provided name will act + * as a subscope to all underlying operators. * * @param defaultName name for the underlying operator. * @return unique name for the operator. @@ -146,11 +150,39 @@ public final class Scope { return nameScope.makeOpName(defaultName); } - private Scope(Graph graph, NameScope nameScope) { - this.graph = graph; + private Scope( + ExecutionEnvironment env, NameScope nameScope, Iterable> controlDependencies) { + this.env = env; this.nameScope = nameScope; + this.controlDependencies = controlDependencies; } - private final Graph graph; + /** + * Returns a new scope where added operations will have the provided control dependencies. + * + *

Ops created with this scope will have a control edge from each of the provided controls. All + * other properties are inherited from the current scope. + * + * @param controls control dependencies for ops created with the returned scope + * @return a new scope with the provided control dependencies + */ + public Scope withControlDependencies(Iterable> controls) { + return new Scope(env, nameScope, controls); + } + + /** + * Adds each Operand in controlDependencies as a control input to the provided builder. + * + * @param builder OperationBuilder to add control inputs to + */ + public OperationBuilder applyControlDependencies(OperationBuilder builder) { + for (Operand control : controlDependencies) { + builder = builder.addControlInput(control.asOutput().op()); + } + return builder; + } + + private final ExecutionEnvironment env; + private final Iterable> controlDependencies; private final NameScope nameScope; } diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index 00b6726be34..ee4301f1159 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -523,7 +523,7 @@ public final class Constant extends PrimitiveOp implements Operand { */ public static Constant create(Scope scope, String data, Charset charset) { try (Tensor value = Tensor.create(data.getBytes(charset), String.class)) { - return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class)); + return createWithTensor(scope, value); } } @@ -640,7 +640,7 @@ public final class Constant extends PrimitiveOp implements Operand { private static Constant createWithTensor(Scope scope, Tensor value) { return new Constant( scope - .graph() + .env() .opBuilder("Const", scope.makeOpName("Const")) .setAttr("value", value) .setAttr("dtype", value.dataType()) diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java index eea9dc1c47c..ab574066837 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -18,6 +18,8 @@ package org.tensorflow.op.core; import java.util.Arrays; import java.util.Iterator; import java.util.List; + +import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; @@ -77,12 +79,18 @@ public class Gradients implements Op, Iterable> { * @param x inputs of the function for which partial derivatives are computed * @param options carries optional attributes values * @return a new instance of {@code Gradients} + * @throws IllegalArgumentException if execution environment is not a graph */ public static Gradients create( Scope scope, Iterable> y, Iterable> x, Options... 
options) { + if (!(scope.env() instanceof Graph)) { + throw new IllegalArgumentException( + "Gradients can be computed only in a graph execution environment"); + } + Graph graph = (Graph) scope.env(); Output[] dx = null; if (options != null) { for (Options opts : options) { @@ -92,10 +100,8 @@ public class Gradients implements Op, Iterable> { } } Output[] dy = - scope - .graph() - .addGradients( - scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx); + graph.addGradients( + scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx); return new Gradients(Arrays.asList(dy)); } @@ -110,6 +116,7 @@ public class Gradients implements Op, Iterable> { * @param x inputs of the function for which partial derivatives are computed * @param options carries optional attributes values * @return a new instance of {@code Gradients} + * @throws IllegalArgumentException if execution environment is not a graph */ @SuppressWarnings({"unchecked", "rawtypes"}) public static Gradients create( diff --git a/tensorflow/java/src/main/java/org/tensorflow/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/package-info.java index f353ee31459..983cda5260c 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/package-info.java +++ b/tensorflow/java/src/main/java/org/tensorflow/package-info.java @@ -17,10 +17,9 @@ limitations under the License. * Defines classes to build, save, load and execute TensorFlow models. * *

WARNING: The API is currently experimental and is not covered by TensorFlow API stability - * guarantees. See README.md for installation - * instructions. + * href="https://www.tensorflow.org/guide/version_compat">API stability guarantees. See README.md + * for installation instructions. * *

The LabelImage diff --git a/tensorflow/java/src/main/native/BUILD b/tensorflow/java/src/main/native/BUILD index 4eb62b14bc7..97071009d11 100644 --- a/tensorflow/java/src/main/native/BUILD +++ b/tensorflow/java/src/main/native/BUILD @@ -33,13 +33,13 @@ tf_cuda_library( "//tensorflow:android": [], "//conditions:default": ["."], }), - deps = [ - "//tensorflow/c:c_api", - ] + select({ + deps = select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib", ], "//conditions:default": [ + "//tensorflow/c:c_api", + "//tensorflow/c/eager:c_api", "//tensorflow/core:all_kernels", "//tensorflow/core:direct_session", "//tensorflow/core:ops", diff --git a/tensorflow/java/src/main/native/eager_operation_builder_jni.cc b/tensorflow/java/src/main/native/eager_operation_builder_jni.cc new file mode 100644 index 00000000000..f8ed2072ba0 --- /dev/null +++ b/tensorflow/java/src/main/native/eager_operation_builder_jni.cc @@ -0,0 +1,335 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/eager_operation_builder_jni.h" + +#include +#include +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/java/src/main/native/exception_jni.h" + +// This value should be >= to the maximum number of outputs in any op +#define MAX_OUTPUTS_PER_OP 8 + +namespace { + +TFE_Op* requireOp(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "Operation has already been built"); + return nullptr; + } + return reinterpret_cast(handle); +} + +TFE_Context* requireContext(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, "Context has been deleted"); + return nullptr; + } + return reinterpret_cast(handle); +} + +TF_Tensor* requireTensor(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "close() has been called on the Tensor"); + return nullptr; + } + return reinterpret_cast(handle); +} + +TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "Tensor handle has been deleted"); + return nullptr; + } + return reinterpret_cast(handle); +} + +} // namespace + +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate( + JNIEnv* env, jclass clazz, jlong context_handle, jstring name) { + TFE_Context* context = requireContext(env, context_handle); + if (context == nullptr) return 0; + const char* op_or_function_name = env->GetStringUTFChars(name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(context, op_or_function_name, status); + env->ReleaseStringUTFChars(name, op_or_function_name); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + static_assert(sizeof(jlong) >= sizeof(TFE_Op*), + "Cannot represent a C TFE_Op as a Java 
long"); + return reinterpret_cast(op); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_delete( + JNIEnv* env, jclass clazz, jlong op_handle) { + if (op_handle == 0) return; + TFE_DeleteOp(reinterpret_cast(op_handle)); +} + +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute( + JNIEnv* env, jclass clazz, jlong op_handle) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return 0; + int num_retvals = MAX_OUTPUTS_PER_OP; + std::unique_ptr retvals( + new TFE_TensorHandle*[num_retvals]); + TF_Status* status = TF_NewStatus(); + TFE_Execute(op, retvals.get(), &num_retvals, status); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return nullptr; + } + TF_DeleteStatus(status); + jlongArray rethandles = env->NewLongArray(num_retvals); + if (num_retvals > 0) { + jlong* retval = env->GetLongArrayElements(rethandles, nullptr); + for (int i = 0; i < num_retvals; ++i) { + retval[i] = reinterpret_cast(retvals[i]); + } + env->ReleaseLongArrayElements(rethandles, retval, 0); + } + return rethandles; +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice( + JNIEnv* env, jclass clazz, jlong op_handle, jstring device_name) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + const char* cname = env->GetStringUTFChars(device_name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_OpSetDevice(op, cname, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + env->ReleaseStringUTFChars(device_name, cname); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput( + JNIEnv* env, jclass clazz, jlong op_handle, jlong input_handle) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, input_handle); + if (tensor_handle == nullptr) return; + TF_Status* status = TF_NewStatus(); + TFE_OpAddInput(op, tensor_handle, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList( + JNIEnv* env, jclass clazz, jlong op_handle, jlongArray input_handles) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + jlong* cinput_handles = env->GetLongArrayElements(input_handles, nullptr); + size_t num_inputs = static_cast(env->GetArrayLength(input_handles)); + std::unique_ptr tensor_handles( + new TFE_TensorHandle*[num_inputs]); + for (int i = 0; i < num_inputs; ++i) { + tensor_handles[i] = requireTensorHandle(env, cinput_handles[i]); + if (tensor_handles[i] == nullptr) { + env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT); + return; + } + } + env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT); + TF_Status* status = TF_NewStatus(); + TFE_OpAddInputList(op, tensor_handles.get(), num_inputs, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString( + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, + jbyteArray value) { + static_assert(sizeof(jbyte) == 1, + "Require Java byte to be represented as a single byte"); + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + jbyte* cvalue = env->GetByteArrayElements(value, nullptr); + TFE_OpSetAttrString(op, cname, cvalue, env->GetArrayLength(value)); + 
env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT); + env->ReleaseStringUTFChars(attr_name, cname); +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrStringList( + JNIEnv* env, jclass object, jlong op_handle, jstring attr_name, + jobjectArray values) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + int num_values = env->GetArrayLength(values); + static_assert(sizeof(jbyte) == 1, + "Require Java byte to be represented as a single byte"); + std::unique_ptr jarrays(new jbyteArray[num_values]); + std::unique_ptr jvalues(new jbyte*[num_values]); + std::unique_ptr cvalues(new void*[num_values]); + std::unique_ptr lengths(new size_t[num_values]); + + for (int i = 0; i < num_values; ++i) { + jbyteArray v = + static_cast(env->GetObjectArrayElement(values, i)); + jarrays[i] = v; + jvalues[i] = env->GetByteArrayElements(v, nullptr); + cvalues[i] = jvalues[i]; + lengths[i] = static_cast(env->GetArrayLength(v)); + } + TFE_OpSetAttrStringList(op, cname, cvalues.get(), lengths.get(), num_values); + for (int i = 0; i < num_values; ++i) { + env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT); + } + env->ReleaseStringUTFChars(attr_name, cname); +} + +#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \ + JNIEXPORT void JNICALL \ + Java_org_tensorflow_EagerOperationBuilder_setAttr##name( \ + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \ + jtype value) { \ + static_assert( \ + sizeof(ctype) >= sizeof(jtype), \ + "Information loss when converting between Java and C types"); \ + TFE_Op* op = requireOp(env, op_handle); \ + if (op == nullptr) return; \ + const char* cname = env->GetStringUTFChars(attr_name, nullptr); \ + TFE_OpSetAttr##name(op, cname, static_cast(value)); \ + env->ReleaseStringUTFChars(attr_name, cname); \ + } + +#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \ + JNIEXPORT void JNICALL \ + Java_org_tensorflow_EagerOperationBuilder_setAttr##name##List( \ + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \ + jtype##Array value) { \ + TFE_Op* op = requireOp(env, op_handle); \ + if (op == nullptr) return; \ + const char* cname = env->GetStringUTFChars(attr_name, nullptr); \ + /* Make a copy of the array to paper over any differences */ \ + /* in byte representations of the jtype and ctype */ \ + /* For example, jint vs TF_DataType. */ \ + /* If this copy turns out to be a problem in practice */ \ + /* can avoid it for many types. 
*/ \ + const int n = env->GetArrayLength(value); \ + std::unique_ptr cvalue(new ctype[n]); \ + jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \ + for (int i = 0; i < n; ++i) { \ + cvalue[i] = static_cast(elems[i]); \ + } \ + TFE_OpSetAttr##name##List(op, cname, cvalue.get(), n); \ + env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \ + env->ReleaseStringUTFChars(attr_name, cname); \ + } + +#define DEFINE_SET_ATTR(name, jname, jtype, ctype) \ + DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \ + DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) + +DEFINE_SET_ATTR(Int, Long, jlong, int64_t); +DEFINE_SET_ATTR(Float, Float, jfloat, float); +DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char); +DEFINE_SET_ATTR(Type, Int, jint, TF_DataType); +#undef DEFINE_SET_ATTR +#undef DEFINE_SET_ATTR_LIST +#undef DEFINE_SET_ATTR_SCALAR + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor( + JNIEnv* env, jclass clazz, jlong handle, jstring attr_name, + jlong tensor_handle) { + TFE_Op* op = requireOp(env, handle); + if (op == nullptr) return; + TF_Tensor* t = requireTensor(env, tensor_handle); + if (t == nullptr) return; + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_OpSetAttrTensor(op, cname, t, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + env->ReleaseStringUTFChars(attr_name, cname); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape( + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, + jlongArray shape, jint num_dims) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + std::unique_ptr cvalue; + // num_dims and env->GetArrayLength(shape) are assumed to be consistent. + // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape). 
+ if (num_dims > 0) { + cvalue.reset(new int64_t[num_dims]); + jlong* elems = env->GetLongArrayElements(shape, nullptr); + for (int i = 0; i < num_dims; ++i) { + cvalue[i] = static_cast(elems[i]); + } + env->ReleaseLongArrayElements(shape, elems, JNI_ABORT); + } + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_OpSetAttrShape(op, cname, cvalue.get(), static_cast(num_dims), + status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + env->ReleaseStringUTFChars(attr_name, cname); +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList( + JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, + jlongArray shapes, jintArray num_dims) { + TFE_Op* op = requireOp(env, op_handle); + if (op == nullptr) return; + std::unique_ptr cshapes; + std::unique_ptr cdims; + std::unique_ptr cnum_dims; + const int num_dims_length = env->GetArrayLength(num_dims); + if (num_dims_length > 0) { + const int shapes_length = env->GetArrayLength(shapes); + cshapes.reset(new int64_t[shapes_length]); + cdims.reset(new const int64_t*[num_dims_length]); + cnum_dims.reset(new int[num_dims_length]); + jlong* shapes_elems = + static_cast(env->GetPrimitiveArrayCritical(shapes, nullptr)); + std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3); + env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT); + int64_t* cshapes_ptr = cshapes.get(); + jint* num_dims_elems = + static_cast(env->GetPrimitiveArrayCritical(num_dims, nullptr)); + for (int i = 0; i < num_dims_length; ++i) { + cnum_dims[i] = static_cast(num_dims_elems[i]); + cdims[i] = cshapes_ptr; + if (cnum_dims[i] > 0) { + cshapes_ptr += cnum_dims[i]; + } + } + env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT); + } + const char* cname = env->GetStringUTFChars(attr_name, nullptr); + TF_Status* status = TF_NewStatus(); + TFE_OpSetAttrShapeList(op, cname, cdims.get(), cnum_dims.get(), + num_dims_length, status); + throwExceptionIfNotOK(env, status); + TF_DeleteStatus(status); + env->ReleaseStringUTFChars(attr_name, cname); +} diff --git a/tensorflow/java/src/main/native/eager_operation_builder_jni.h b/tensorflow/java/src/main/native/eager_operation_builder_jni.h new file mode 100644 index 00000000000..6da891d7ae2 --- /dev/null +++ b/tensorflow/java/src/main/native/eager_operation_builder_jni.h @@ -0,0 +1,191 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_ +#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: allocate + * Signature: (JLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate( + JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: delete + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_delete(JNIEnv *, jclass, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: execute + * Signature: (J)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_EagerOperationBuilder_execute(JNIEnv *, jclass, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: addInput + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput( + JNIEnv *, jclass, jlong, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: addInputList + * Signature: (J[J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList( + JNIEnv *, jclass, jlong, jlongArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setDevice + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice( + JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrString + * Signature: (JLjava/lang/String;[B)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString( + JNIEnv *, jclass, jlong, jstring, jbyteArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrStringList + * Signature: (JLjava/lang/String;[L)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(JNIEnv *, jclass, + jlong, jstring, + jobjectArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrInt + * Signature: (JLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrInt( + JNIEnv *, jclass, jlong, jstring, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrIntList + * Signature: (JLjava/lang/String;[J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrIntList( + JNIEnv *, jclass, jlong, jstring, jlongArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrFloat + * Signature: (JLjava/lang/String;F)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrFloat( + JNIEnv *, jclass, jlong, jstring, jfloat); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrFloatList + * Signature: (JLjava/lang/String;[F)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrFloatList(JNIEnv *, jclass, + jlong, jstring, + jfloatArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrBool + * Signature: (JLjava/lang/String;Z)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrBool( + JNIEnv *, jclass, jlong, jstring, jboolean); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrBoolList + * Signature: (JLjava/lang/String;[Z)V + */ +JNIEXPORT void JNICALL 
+Java_org_tensorflow_EagerOperationBuilder_setAttrBoolList(JNIEnv *, jclass, + jlong, jstring, + jbooleanArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrType + * Signature: (JLjava/lang/String;I)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrType( + JNIEnv *, jclass, jlong, jstring, jint); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrTypeList + * Signature: (JLjava/lang/String;[I)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrTypeList(JNIEnv *, jclass, + jlong, jstring, + jintArray); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrTensor + * Signature: (JLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor( + JNIEnv *, jclass, jlong, jstring, jlong); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrShape + * Signature: (JLjava/lang/String;[JI)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape( + JNIEnv *, jclass, jlong, jstring, jlongArray, jint); + +/* + * Class: org_tensorflow_EagerOperationBuilder + * Method: setAttrShapeList + * Signature: (JLjava/lang/String;[J[I)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(JNIEnv *, jclass, + jlong, jstring, + jlongArray, + jintArray); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_BUILDER_JNI_H_ diff --git a/tensorflow/java/src/main/native/eager_operation_jni.cc b/tensorflow/java/src/main/native/eager_operation_jni.cc new file mode 100644 index 00000000000..2dbe81efd35 --- /dev/null +++ b/tensorflow/java/src/main/native/eager_operation_jni.cc @@ -0,0 +1,146 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/java/src/main/native/eager_operation_jni.h" + +#include +#include +#include + +#include +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/java/src/main/native/exception_jni.h" + +namespace { + +TFE_Op* requireOp(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, + "Eager session has been closed"); + return nullptr; + } + return reinterpret_cast(handle); +} + +TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalStateException, "EagerSession has been closed"); + return nullptr; + } + return reinterpret_cast(handle); +} + +} // namespace + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv* env, + jclass clazz, + jlong handle) { + if (handle == 0) return; + TFE_DeleteOp(reinterpret_cast(handle)); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_deleteTensorHandle( + JNIEnv* env, jclass clazz, jlong handle) { + if (handle == 0) return; + TFE_DeleteTensorHandle(reinterpret_cast(handle)); +} + +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_resolveTensorHandle( + JNIEnv* env, jclass clazz, jlong handle) { + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle); + if (tensor_handle == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + TF_Tensor* tensor = TFE_TensorHandleResolve(tensor_handle, status); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + static_assert(sizeof(jlong) >= sizeof(TF_Tensor*), + "Cannot represent a C TF_Tensor as a Java long"); + return reinterpret_cast(tensor); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength( + JNIEnv* env, jclass clazz, jlong handle, jstring name) { + TFE_Op* op = requireOp(env, handle); + if (op == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + const char* cname = env->GetStringUTFChars(name, nullptr); + int length = TFE_OpGetOutputLength(op, cname, status); + env->ReleaseStringUTFChars(name, cname); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return static_cast(length); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength( + JNIEnv* env, jclass clazz, jlong handle, jstring name) { + TFE_Op* op = requireOp(env, handle); + if (op == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + const char* cname = env->GetStringUTFChars(name, nullptr); + int length = TFE_OpGetInputLength(op, cname, status); + env->ReleaseStringUTFChars(name, cname); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return static_cast(length); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType( + JNIEnv* env, jclass clazz, jlong handle) { + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle); + if (tensor_handle == nullptr) return 0; + TF_DataType data_type = TFE_TensorHandleDataType(tensor_handle); + return static_cast(data_type); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims( + JNIEnv* env, jclass clazz, jlong handle) { + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle); + if (tensor_handle == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + int num_dims = TFE_TensorHandleNumDims(tensor_handle, status); + 
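+  // Surface any C-level error as a Java exception before num_dims is used.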
if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return static_cast(num_dims); +} + +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv* env, + jclass clazz, + jlong handle, + jint dim_index) { + TFE_TensorHandle* tensor_handle = requireTensorHandle(env, handle); + if (tensor_handle == nullptr) return 0; + TF_Status* status = TF_NewStatus(); + int64_t dim = TFE_TensorHandleDim(tensor_handle, dim_index, status); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + return static_cast(dim); +} diff --git a/tensorflow/java/src/main/native/eager_operation_jni.h b/tensorflow/java/src/main/native/eager_operation_jni.h new file mode 100644 index 00000000000..ef38ed038c9 --- /dev/null +++ b/tensorflow/java/src/main/native/eager_operation_jni.h @@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_ +#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_tensorflow_EagerOperation + * Method: delete + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperation_delete(JNIEnv *, + jclass, jlong); + +/* + * Class: org_tensorflow_EagerOperation + * Method: deleteTensorHandle + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_EagerOperation_deleteTensorHandle(JNIEnv *, jclass, jlong); + +/** + * Class: org_tensorflow_EagerOperation + * Method: resolveTensorHandle + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_EagerOperation_resolveTensorHandle(JNIEnv *, jclass, jlong); + +/** + * Class: org_tensorflow_EagerOperation + * Method: outputListLength + * Signature: (JLjava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_outputListLength( + JNIEnv *, jclass, jlong, jstring); + +/** + * Class: org_tensorflow_EagerOperation + * Method: inputListLength + * Signature: (JLjava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_inputListLength( + JNIEnv *, jclass, jlong, jstring); + +/** + * Class: org_tensorflow_EagerOperation + * Method: dataType + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_dataType(JNIEnv *, + jclass, + jlong); + +/** + * Class: org_tensorflow_EagerOperation + * Method: numDims + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_EagerOperation_numDims(JNIEnv *, + jclass, + jlong); + +/** + * Class: org_tensorflow_EagerOperation + * Method: dim + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperation_dim(JNIEnv *, jclass, + jlong, jint); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_OPERATION_JNI_H_ diff 
--git a/tensorflow/java/src/main/native/eager_session_jni.cc b/tensorflow/java/src/main/native/eager_session_jni.cc new file mode 100644 index 00000000000..58905205c94 --- /dev/null +++ b/tensorflow/java/src/main/native/eager_session_jni.cc @@ -0,0 +1,64 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/eager_session_jni.h" + +#include +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/java/src/main/native/exception_jni.h" + +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate( + JNIEnv* env, jclass clazz, jboolean async, jint dpp, jbyteArray config) { + TFE_ContextOptions* opts = TFE_NewContextOptions(); + jbyte* cconfig = nullptr; + TF_Status* status = TF_NewStatus(); + if (config != nullptr) { + cconfig = env->GetByteArrayElements(config, nullptr); + TFE_ContextOptionsSetConfig( + opts, cconfig, static_cast(env->GetArrayLength(config)), + status); + if (!throwExceptionIfNotOK(env, status)) { + env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT); + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + return 0; + } + } + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy( + opts, static_cast(dpp)); + TFE_Context* context = TFE_NewContext(opts, status); + TFE_DeleteContextOptions(opts); + if (config != nullptr) { + env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT); + } + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return 0; + } + TF_DeleteStatus(status); + static_assert(sizeof(jlong) >= sizeof(TFE_Context*), + "Cannot represent a C TFE_Op as a Java long"); + return reinterpret_cast(context); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv* env, + jclass clazz, + jlong handle) { + if (handle == 0) return; + TFE_DeleteContext(reinterpret_cast(handle)); +} diff --git a/tensorflow/lite/java/src/main/native/init_tensorflow_jni.h b/tensorflow/java/src/main/native/eager_session_jni.h similarity index 54% rename from tensorflow/lite/java/src/main/native/init_tensorflow_jni.h rename to tensorflow/java/src/main/native/eager_session_jni.h index 1454d6d4633..9f7bdaccd36 100644 --- a/tensorflow/lite/java/src/main/native/init_tensorflow_jni.h +++ b/tensorflow/java/src/main/native/eager_session_jni.h @@ -12,25 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ -#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ + +#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_SESSION_JNI_H_ +#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_SESSION_JNI_H_ #include #ifdef __cplusplus extern "C" { -#endif // __cplusplus +#endif /* - * Class: org_tensorflow_lite_TensorFlowLite - * Method: initTensorFlow - * Signature: ()V + * Class: org_tensorflow_EagerSession + * Method: allocate + * Signature: (ZI[B)J */ -JNIEXPORT void JNICALL Java_org_tensorflow_lite_TensorFlowLite_initTensorFlow( - JNIEnv* env, jclass clazz); +JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerSession_allocate( + JNIEnv *env, jclass clazz, jboolean async, jint dpp, jbyteArray config); + +/* + * Class: org_tensorflow_EagerSession + * Method: delete + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_EagerSession_delete(JNIEnv *, jclass, + jlong); #ifdef __cplusplus } // extern "C" #endif // __cplusplus - -#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_INIT_TENSORFLOW_JNI_H_ +#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EAGER_SESSION_JNI_H_ diff --git a/tensorflow/java/src/main/native/operation_builder_jni.cc b/tensorflow/java/src/main/native/graph_operation_builder_jni.cc similarity index 85% rename from tensorflow/java/src/main/native/operation_builder_jni.cc rename to tensorflow/java/src/main/native/graph_operation_builder_jni.cc index 55d214a7c4b..28c06bbf6c7 100644 --- a/tensorflow/java/src/main/native/operation_builder_jni.cc +++ b/tensorflow/java/src/main/native/graph_operation_builder_jni.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/java/src/main/native/operation_builder_jni.h" - +#include "tensorflow/java/src/main/native/graph_operation_builder_jni.h" #include #include #include "tensorflow/c/c_api.h" @@ -51,7 +50,7 @@ TF_Tensor* requireTensor(JNIEnv* env, jlong handle) { } } // namespace -JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_allocate( +JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_allocate( JNIEnv* env, jclass clazz, jlong graph_handle, jstring type, jstring name) { if (graph_handle == 0) { throwException(env, kIllegalStateException, @@ -69,7 +68,7 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_allocate( return reinterpret_cast(d); } -JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_finish( +JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_finish( JNIEnv* env, jclass clazz, jlong handle) { TF_OperationDescription* d = requireHandle(env, handle); if (d == nullptr) return 0; @@ -83,7 +82,7 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_finish( return 0; } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInput( +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInput( JNIEnv* env, jclass clazz, jlong handle, jlong op_handle, jint index) { TF_Output out; if (!resolveOutput(env, op_handle, index, &out)) return; @@ -92,7 +91,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInput( TF_AddInput(d, out); } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInputList( +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInputList( JNIEnv* env, jclass clazz, jlong handle, jlongArray op_handles, jintArray indices) { TF_OperationDescription* d = requireHandle(env, handle); @@ -118,8 +117,11 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInputList( TF_AddInputList(d, o.get(), n); } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addControlInput( - JNIEnv* env, jclass clazz, jlong handle, jlong op_handle) { +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_addControlInput(JNIEnv* env, + jclass clazz, + jlong handle, + jlong op_handle) { if (op_handle == 0) { throwException(env, kIllegalStateException, "control input is not valid, " @@ -132,7 +134,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addControlInput( TF_AddControlInput(d, control); } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setDevice( +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setDevice( JNIEnv* env, jclass clazz, jlong handle, jstring device) { TF_OperationDescription* d = requireHandle(env, handle); if (d == nullptr) return; @@ -141,7 +143,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setDevice( env->ReleaseStringUTFChars(device, cdevice); } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrString( +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrString( JNIEnv* env, jclass clazz, jlong handle, jstring name, jbyteArray value) { static_assert(sizeof(jbyte) == 1, "Require Java byte to be represented as a single byte"); @@ -154,41 +156,43 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrString( env->ReleaseStringUTFChars(name, cname); } -#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \ - JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttr##name( \ - JNIEnv* env, jclass clazz, jlong 
handle, jstring name, jtype value) { \ - static_assert( \ - sizeof(ctype) >= sizeof(jtype), \ - "Information loss when converting between Java and C types"); \ - TF_OperationDescription* d = requireHandle(env, handle); \ - if (d == nullptr) return; \ - const char* cname = env->GetStringUTFChars(name, nullptr); \ - TF_SetAttr##name(d, cname, static_cast(value)); \ - env->ReleaseStringUTFChars(name, cname); \ +#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \ + JNIEXPORT void JNICALL \ + Java_org_tensorflow_GraphOperationBuilder_setAttr##name( \ + JNIEnv* env, jclass clazz, jlong handle, jstring name, \ + jtype value) { \ + static_assert( \ + sizeof(ctype) >= sizeof(jtype), \ + "Information loss when converting between Java and C types"); \ + TF_OperationDescription* d = requireHandle(env, handle); \ + if (d == nullptr) return; \ + const char* cname = env->GetStringUTFChars(name, nullptr); \ + TF_SetAttr##name(d, cname, static_cast(value)); \ + env->ReleaseStringUTFChars(name, cname); \ } -#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \ - JNIEXPORT void JNICALL \ - Java_org_tensorflow_OperationBuilder_setAttr##name##List( \ - JNIEnv* env, jclass clazz, jlong handle, jstring name, \ - jtype##Array value) { \ - TF_OperationDescription* d = requireHandle(env, handle); \ - if (d == nullptr) return; \ - const char* cname = env->GetStringUTFChars(name, nullptr); \ - /* Make a copy of the array to paper over any differences */ \ - /* in byte representations of the jtype and ctype */ \ - /* For example, jint vs TF_DataType. */ \ - /* If this copy turns out to be a problem in practice */ \ - /* can avoid it for many types. */ \ - const int n = env->GetArrayLength(value); \ - std::unique_ptr cvalue(new ctype[n]); \ - jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \ - for (int i = 0; i < n; ++i) { \ - cvalue[i] = static_cast(elems[i]); \ - } \ - TF_SetAttr##name##List(d, cname, cvalue.get(), n); \ - env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \ - env->ReleaseStringUTFChars(name, cname); \ +#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \ + JNIEXPORT void JNICALL \ + Java_org_tensorflow_GraphOperationBuilder_setAttr##name##List( \ + JNIEnv* env, jclass clazz, jlong handle, jstring name, \ + jtype##Array value) { \ + TF_OperationDescription* d = requireHandle(env, handle); \ + if (d == nullptr) return; \ + const char* cname = env->GetStringUTFChars(name, nullptr); \ + /* Make a copy of the array to paper over any differences */ \ + /* in byte representations of the jtype and ctype */ \ + /* For example, jint vs TF_DataType. */ \ + /* If this copy turns out to be a problem in practice */ \ + /* can avoid it for many types. 
*/ \ + const int n = env->GetArrayLength(value); \ + std::unique_ptr cvalue(new ctype[n]); \ + jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \ + for (int i = 0; i < n; ++i) { \ + cvalue[i] = static_cast(elems[i]); \ + } \ + TF_SetAttr##name##List(d, cname, cvalue.get(), n); \ + env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \ + env->ReleaseStringUTFChars(name, cname); \ } #define DEFINE_SET_ATTR(name, jname, jtype, ctype) \ @@ -203,7 +207,7 @@ DEFINE_SET_ATTR(Type, Int, jint, TF_DataType); #undef DEFINE_SET_ATTR_LIST #undef DEFINE_SET_ATTR_SCALAR -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensor( +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrTensor( JNIEnv* env, jclass clazz, jlong handle, jstring name, jlong tensor_handle) { TF_OperationDescription* d = requireHandle(env, handle); @@ -218,13 +222,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensor( env->ReleaseStringUTFChars(name, cname); } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensorList( +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrTensorList( JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray tensor_handles) { TF_OperationDescription* d = requireHandle(env, handle); if (d == nullptr) return; const int n = env->GetArrayLength(tensor_handles); - std::unique_ptr tensors(new TF_Tensor*[n]); + std::unique_ptr tensors(new TF_Tensor*[n]); jlong* jhandles = env->GetLongArrayElements(tensor_handles, nullptr); bool ok = true; for (int i = 0; i < n && ok; ++i) { @@ -242,7 +247,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensorList( env->ReleaseStringUTFChars(name, cname); } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShape( +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrShape( JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shape, jint num_dims) { TF_OperationDescription* d = requireHandle(env, handle); @@ -263,13 +268,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShape( env->ReleaseStringUTFChars(name, cname); } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList( +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrShapeList( JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shapes, jintArray num_dims) { TF_OperationDescription* d = requireHandle(env, handle); if (d == nullptr) return; std::unique_ptr cshapes; - std::unique_ptr cdims; + std::unique_ptr cdims; std::unique_ptr cnum_dims; const int num_dims_length = env->GetArrayLength(num_dims); if (num_dims_length > 0) { @@ -298,7 +304,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList( env->ReleaseStringUTFChars(name, cname); } -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList( +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrStringList( JNIEnv* env, jclass object, jlong handle, jstring name, jobjectArray values) { TF_OperationDescription* d = requireHandle(env, handle); @@ -308,8 +315,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList( static_assert(sizeof(jbyte) == 1, "Require Java byte to be represented as a single byte"); std::unique_ptr jarrays(new jbyteArray[num_values]); - std::unique_ptr jvalues(new jbyte*[num_values]); - std::unique_ptr cvalues(new void*[num_values]); + std::unique_ptr jvalues(new 
jbyte*[num_values]); + std::unique_ptr cvalues(new void*[num_values]); std::unique_ptr lengths(new size_t[num_values]); for (int i = 0; i < num_values; ++i) { diff --git a/tensorflow/java/src/main/native/graph_operation_builder_jni.h b/tensorflow/java/src/main/native/graph_operation_builder_jni.h new file mode 100644 index 00000000000..fe76fcf28e7 --- /dev/null +++ b/tensorflow/java/src/main/native/graph_operation_builder_jni.h @@ -0,0 +1,202 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_BUILDER_JNI_H_ +#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_BUILDER_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: allocate + * Signature: (JLjava/lang/String;Ljava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_allocate( + JNIEnv *, jclass, jlong, jstring, jstring); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: finish + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_GraphOperationBuilder_finish(JNIEnv *, jclass, jlong); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: addInput + * Signature: (JJI)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInput( + JNIEnv *, jclass, jlong, jlong, jint); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: addInputList + * Signature: (J[J[I)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInputList( + JNIEnv *, jclass, jlong, jlongArray, jintArray); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: addControlInput + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_addControlInput(JNIEnv *, jclass, + jlong, jlong); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setDevice + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setDevice( + JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrString + * Signature: (JLjava/lang/String;[B)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrString( + JNIEnv *, jclass, jlong, jstring, jbyteArray); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrInt + * Signature: (JLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrInt( + JNIEnv *, jclass, jlong, jstring, jlong); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrIntList + * Signature: (JLjava/lang/String;[J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrIntList( + JNIEnv *, jclass, jlong, jstring, jlongArray); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrFloat + * 
Signature: (JLjava/lang/String;F)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrFloat( + JNIEnv *, jclass, jlong, jstring, jfloat); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrFloatList + * Signature: (JLjava/lang/String;[F)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrFloatList(JNIEnv *, jclass, + jlong, jstring, + jfloatArray); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrBool + * Signature: (JLjava/lang/String;Z)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrBool( + JNIEnv *, jclass, jlong, jstring, jboolean); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrBoolList + * Signature: (JLjava/lang/String;[Z)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrBoolList(JNIEnv *, jclass, + jlong, jstring, + jbooleanArray); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrType + * Signature: (JLjava/lang/String;I)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrType( + JNIEnv *, jclass, jlong, jstring, jint); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrTypeList + * Signature: (JLjava/lang/String;[I)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrTypeList(JNIEnv *, jclass, + jlong, jstring, + jintArray); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrTensor + * Signature: (JLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrTensor( + JNIEnv *, jclass, jlong, jstring, jlong); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrTensorList + * Signature: (JLjava/lang/String;[J)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrTensorList(JNIEnv *, jclass, + jlong, jstring, + jlongArray); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrShape + * Signature: (JLjava/lang/String;[JI)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrShape( + JNIEnv *, jclass, jlong, jstring, jlongArray, jint); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrShapeList + * Signature: (JLjava/lang/String;[J[I)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrShapeList(JNIEnv *, jclass, + jlong, jstring, + jlongArray, + jintArray); + +/* + * Class: org_tensorflow_GraphOperationBuilder + * Method: setAttrStringList + * Signature: (JLjava/lang/String;[L)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_GraphOperationBuilder_setAttrStringList(JNIEnv *, jclass, + jlong, jstring, + jobjectArray); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_BUILDER_JNI_H_ diff --git a/tensorflow/java/src/main/native/operation_jni.cc b/tensorflow/java/src/main/native/graph_operation_jni.cc similarity index 71% rename from tensorflow/java/src/main/native/operation_jni.cc rename to tensorflow/java/src/main/native/graph_operation_jni.cc index ccc44d91c00..9c5fe786416 100644 --- a/tensorflow/java/src/main/native/operation_jni.cc +++ b/tensorflow/java/src/main/native/graph_operation_jni.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/java/src/main/native/operation_jni.h" - +#include "tensorflow/java/src/main/native/graph_operation_jni.h" #include #include "tensorflow/c/c_api.h" #include "tensorflow/java/src/main/native/exception_jni.h" @@ -42,34 +41,29 @@ TF_Graph* requireGraphHandle(JNIEnv* env, jlong handle) { } } // namespace -JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv* env, - jclass clazz, - jlong handle) { +JNIEXPORT jstring JNICALL Java_org_tensorflow_GraphOperation_name( + JNIEnv* env, jclass clazz, jlong handle) { TF_Operation* op = requireHandle(env, handle); if (op == nullptr) return nullptr; return env->NewStringUTF(TF_OperationName(op)); } -JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv* env, - jclass clazz, - jlong handle) { +JNIEXPORT jstring JNICALL Java_org_tensorflow_GraphOperation_type( + JNIEnv* env, jclass clazz, jlong handle) { TF_Operation* op = requireHandle(env, handle); if (op == nullptr) return nullptr; return env->NewStringUTF(TF_OperationOpType(op)); } -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env, - jclass clazz, - jlong handle) { +JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_numOutputs( + JNIEnv* env, jclass clazz, jlong handle) { TF_Operation* op = requireHandle(env, handle); if (op == nullptr) return 0; return TF_OperationNumOutputs(op); } -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv* env, - jclass clazz, - jlong handle, - jstring name) { +JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_outputListLength( + JNIEnv* env, jclass clazz, jlong handle, jstring name) { TF_Operation* op = requireHandle(env, handle); if (op == nullptr) return 0; @@ -84,7 +78,7 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv* en return result; } -JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape( +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_GraphOperation_shape( JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle, jint output_index) { TF_Graph* graph = requireGraphHandle(env, graph_handle); @@ -135,11 +129,9 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape( return ret; } -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv* env, - jclass clazz, - jlong graph_handle, - jlong op_handle, - jint output_index) { +JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_dtype( + JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle, + jint output_index) { TF_Graph* graph = requireGraphHandle(env, graph_handle); if (graph == nullptr) return 0; TF_Operation* op = requireHandle(env, op_handle); @@ -157,10 +149,8 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv* env, return static_cast(TF_OperationOutputType(TF_Output{op, output_index})); } -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv* env, - jclass clazz, - jlong handle, - jstring name) { +JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_inputListLength( + JNIEnv* env, jclass clazz, jlong handle, jstring name) { TF_Operation* op = requireHandle(env, handle); if (op == nullptr) return 0; diff --git 
a/tensorflow/java/src/main/native/operation_jni.h b/tensorflow/java/src/main/native/graph_operation_jni.h similarity index 53% rename from tensorflow/java/src/main/native/operation_jni.h rename to tensorflow/java/src/main/native/graph_operation_jni.h index 56da2ebaee3..bad4ada9cea 100644 --- a/tensorflow/java/src/main/native/operation_jni.h +++ b/tensorflow/java/src/main/native/graph_operation_jni.h @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_JNI_H_ -#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_JNI_H_ +#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_JNI_H_ +#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_JNI_H_ #include @@ -23,68 +23,66 @@ extern "C" { #endif /* - * Class: org_tensorflow_Operation + * Class: org_tensorflow_GraphOperation * Method: name * Signature: (J)Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv *, jclass, - jlong); +JNIEXPORT jstring JNICALL Java_org_tensorflow_GraphOperation_name(JNIEnv *, + jclass, + jlong); /* - * Class: org_tensorflow_Operation + * Class: org_tensorflow_GraphOperation * Method: type * Signature: (J)Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv *, jclass, - jlong); +JNIEXPORT jstring JNICALL Java_org_tensorflow_GraphOperation_type(JNIEnv *, + jclass, + jlong); /* - * Class: org_tensorflow_Operation + * Class: org_tensorflow_GraphOperation * Method: numOutputs * Signature: (J)I */ -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv *, - jclass, jlong); +JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_numOutputs(JNIEnv *, + jclass, + jlong); /* - * Class: org_tensorflow_Operation + * Class: org_tensorflow_GraphOperation * Method: outputListLength * Signature: (JLjava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv *, - jclass, - jlong, - jstring); +JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_outputListLength( + JNIEnv *, jclass, jlong, jstring); /* - * Class: org_tensorflow_Operation + * Class: org_tensorflow_GraphOperation * Method: shape * Signature: (JJI)[J */ -JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape(JNIEnv *, - jclass, jlong, - jlong, jint); +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_GraphOperation_shape(JNIEnv *, jclass, jlong, jlong, jint); /* - * Class: org_tensorflow_Operation + * Class: org_tensorflow_GraphOperation * Method: dtype * Signature: (JJI)I */ -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv *, jclass, - jlong, jlong, jint); - +JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_dtype(JNIEnv *, + jclass, jlong, + jlong, jint); /* - * Class: org_tensorflow_Operation + * Class: org_tensorflow_GraphOperation * Method: inputListLength * Signature: (JLjava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv *, - jclass, - jlong, - jstring); +JNIEXPORT jint JNICALL Java_org_tensorflow_GraphOperation_inputListLength( + JNIEnv *, jclass, jlong, jstring); #ifdef __cplusplus } 
// extern "C" #endif // __cplusplus -#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_JNI_H_ +#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_OPERATION_JNI_H_ diff --git a/tensorflow/java/src/main/native/operation_builder_jni.h b/tensorflow/java/src/main/native/operation_builder_jni.h deleted file mode 100644 index 1cda7acea88..00000000000 --- a/tensorflow/java/src/main/native/operation_builder_jni.h +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_BUILDER_JNI_H_ -#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_BUILDER_JNI_H_ - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -/* - * Class: org_tensorflow_OperationBuilder - * Method: allocate - * Signature: (JLjava/lang/String;Ljava/lang/String;)J - */ -JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_allocate( - JNIEnv *, jclass, jlong, jstring, jstring); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: finish - * Signature: (J)J - */ -JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_finish(JNIEnv *, - jclass, - jlong); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: addInput - * Signature: (JJI)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInput( - JNIEnv *, jclass, jlong, jlong, jint); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: addInputList - * Signature: (J[J[I)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInputList( - JNIEnv *, jclass, jlong, jlongArray, jintArray); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: addControlInput - * Signature: (JJ)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addControlInput( - JNIEnv *, jclass, jlong, jlong); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setDevice - * Signature: (JLjava/lang/String;)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setDevice(JNIEnv *, - jclass, - jlong, - jstring); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrString - * Signature: (JLjava/lang/String;[B)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrString( - JNIEnv *, jclass, jlong, jstring, jbyteArray); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrInt - * Signature: (JLjava/lang/String;J)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrInt( - JNIEnv *, jclass, jlong, jstring, jlong); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrIntList - * Signature: (JLjava/lang/String;[J)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrIntList( - JNIEnv *, jclass, jlong, jstring, jlongArray); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrFloat - * Signature: (JLjava/lang/String;F)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrFloat( - 
JNIEnv *, jclass, jlong, jstring, jfloat); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrFloatList - * Signature: (JLjava/lang/String;[F)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrFloatList( - JNIEnv *, jclass, jlong, jstring, jfloatArray); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrBool - * Signature: (JLjava/lang/String;Z)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrBool( - JNIEnv *, jclass, jlong, jstring, jboolean); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrBoolList - * Signature: (JLjava/lang/String;[Z)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrBoolList( - JNIEnv *, jclass, jlong, jstring, jbooleanArray); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrType - * Signature: (JLjava/lang/String;I)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrType( - JNIEnv *, jclass, jlong, jstring, jint); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrTypeList - * Signature: (JLjava/lang/String;[I)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTypeList( - JNIEnv *, jclass, jlong, jstring, jintArray); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrTensor - * Signature: (JLjava/lang/String;J)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensor( - JNIEnv *, jclass, jlong, jstring, jlong); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrTensorList - * Signature: (JLjava/lang/String;[J)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensorList( - JNIEnv *, jclass, jlong, jstring, jlongArray); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrShape - * Signature: (JLjava/lang/String;[JI)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShape( - JNIEnv *, jclass, jlong, jstring, jlongArray, jint); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrShapeList - * Signature: (JLjava/lang/String;[J[I)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList( - JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray); - -/* - * Class: org_tensorflow_OperationBuilder - * Method: setAttrStringList - * Signature: (JLjava/lang/String;[L)V - */ -JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList( - JNIEnv *, jclass, jlong, jstring, jobjectArray); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus -#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_BUILDER_JNI_H_ diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java new file mode 100644 index 00000000000..0f00a26dba4 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -0,0 +1,145 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link EagerOperationBuilder} class. */ +@RunWith(JUnit4.class) +public class EagerOperationBuilderTest { + + @Test + public void failToCreateIfSessionIsClosed() { + EagerSession session = EagerSession.create(); + session.close(); + try { + new EagerOperationBuilder(session, "Add", "add"); + fail(); + } catch (IllegalStateException e) { + // expected + } + } + + @Test + public void failToBuildOpIfSessionIsClosed() { + EagerOperationBuilder opBuilder; + try (EagerSession session = EagerSession.create()) { + opBuilder = new EagerOperationBuilder(session, "Empty", "empty"); + } + try { + opBuilder.setAttr("dtype", DataType.FLOAT); + fail(); + } catch (IllegalStateException e) { + // expected + } + } + + @Test + public void addInputs() { + try (EagerSession session = EagerSession.create()) { + Operation asrt = + opBuilder(session, "Assert", "assert") + .addInput(TestUtil.constant(session, "Cond", true)) + .addInputList(new Output[] {TestUtil.constant(session, "Error", -1)}) + .build(); + try { + opBuilder(session, "Const", "var").addControlInput(asrt); + fail(); + } catch (UnsupportedOperationException e) { + // expected + } + } + } + + @Test + public void setDevice() { + try (EagerSession session = EagerSession.create()) { + opBuilder(session, "Add", "SetDevice") + .setDevice("/job:localhost/replica:0/task:0/device:CPU:0") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); + } + } + + @Test + public void setAttrs() { + // The effect of setting an attribute may not easily be visible from the other parts of this + // package's API. Thus, for now, the test simply executes the various setAttr variants to see + // that there are no exceptions. + // + // This is a bit of an awkward test since it has to find operations with attributes of specific + // types that aren't inferred from the input arguments. + try (EagerSession session = EagerSession.create()) { + // dtype, tensor attributes. + try (Tensor t = Tensors.create(1)) { + opBuilder(session, "Const", "DataTypeAndTensor") + .setAttr("dtype", DataType.INT32) + .setAttr("value", t) + .build(); + } + // type, int (TF "int" attributes are 64-bit signed, so a Java long). 
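+      // Here "dtype" exercises the type attribute and "seed" the 64-bit int attribute.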
+ opBuilder(session, "RandomUniform", "DataTypeAndInt") + .addInput(TestUtil.constant(session, "RandomUniformShape", new int[] {1})) + .setAttr("seed", 10) + .setAttr("dtype", DataType.FLOAT) + .build(); + // list(int), string + opBuilder(session, "MaxPool", "IntListAndString") + .addInput(TestUtil.constant(session, "MaxPoolInput", new float[2][2][2][2])) + .setAttr("ksize", new long[] {1, 1, 1, 1}) + .setAttr("strides", new long[] {1, 1, 1, 1}) + .setAttr("padding", "SAME") + .build(); + // list(float), device + opBuilder(session, "FractionalMaxPool", "FloatList") + .addInput(TestUtil.constant(session, "FractionalMaxPoolInput", new float[2][2][2][2])) + .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) + .build(); + // shape + opBuilder(session, "EnsureShape", "ShapeAttr") + .addInput(TestUtil.constant(session, "Const", new int[2][2])) + .setAttr("shape", Shape.make(2, 2)) + .build(); + // list(shape) + opBuilder(session, "FIFOQueue", "queue") + .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32}) + .setAttr("shapes", new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)}) + .build(); + // bool + opBuilder(session, "All", "Bool") + .addInput(TestUtil.constant(session, "Const", new boolean[] {true, true, false})) + .addInput(TestUtil.constant(session, "Axis", 0)) + .setAttr("keep_dims", false) + .build(); + // float + opBuilder(session, "ApproximateEqual", "Float") + .addInput(TestUtil.constant(session, "Const1", 10.00001f)) + .addInput(TestUtil.constant(session, "Const2", 10.00000f)) + .setAttr("tolerance", 0.1f) + .build(); + // Missing tests: list(string), list(byte), list(bool), list(type) + } + } + + private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) { + return new EagerOperationBuilder(session, type, name); + } +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java new file mode 100644 index 00000000000..228676f28c3 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerOperationTest.java @@ -0,0 +1,180 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link EagerOperation} class. 
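+ * Each test drives a small op through the eager bindings and verifies the metadata
+ * (output shapes, data types, list lengths) or the error handling surfaced to Java.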
*/ +@RunWith(JUnit4.class) +public class EagerOperationTest { + + @Test + public void failToCreateIfSessionIsClosed() { + EagerSession session = EagerSession.create(); + session.close(); + try { + new EagerOperation(session, 1L, new long[] {1L}, "Add", "add"); + fail(); + } catch (IllegalStateException e) { + // expected + } + } + + @Test + public void outputDataTypeAndShape() { + try (EagerSession session = EagerSession.create(); + Tensor t = Tensors.create(new int[2][3])) { + EagerOperation op = + opBuilder(session, "Const", "OutputAttrs") + .setAttr("dtype", DataType.INT32) + .setAttr("value", t) + .build(); + assertEquals(DataType.INT32, op.dtype(0)); + assertEquals(2, op.shape(0)[0]); + assertEquals(3, op.shape(0)[1]); + } + } + + @Test + public void outputTensor() { + try (EagerSession session = EagerSession.create()) { + EagerOperation add = + opBuilder(session, "Add", "CompareResult") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); + assertEquals(6, add.tensor(0).intValue()); + + // Validate that we retrieve the right shape and datatype from the tensor + // that has been resolved + assertEquals(0, add.shape(0).length); + assertEquals(DataType.INT32, add.dtype(0)); + } + } + + @Test + public void inputAndOutputListLengths() { + try (EagerSession session = EagerSession.create()) { + Output c1 = TestUtil.constant(session, "Const1", new float[] {1f, 2f}); + Output c2 = TestUtil.constant(session, "Const2", new float[] {3f, 4f}); + + EagerOperation acc = + opBuilder(session, "AddN", "InputListLength") + .addInputList(new Output[] {c1, c2}) + .build(); + assertEquals(2, acc.inputListLength("inputs")); + assertEquals(1, acc.outputListLength("sum")); + + EagerOperation split = + opBuilder(session, "Split", "OutputListLength") + .addInput(TestUtil.constant(session, "Axis", 0)) + .addInput(c1) + .setAttr("num_split", 2) + .build(); + assertEquals(1, split.inputListLength("split_dim")); + assertEquals(2, split.outputListLength("output")); + + try { + split.inputListLength("no_such_input"); + fail(); + } catch (IllegalArgumentException e) { + // expected + } + + try { + split.outputListLength("no_such_output"); + fail(); + } catch (IllegalArgumentException e) { + // expected + } + } + } + + @Test + public void numOutputs() { + try (EagerSession session = EagerSession.create()) { + EagerOperation op = + opBuilder(session, "UniqueWithCountsV2", "unq") + .addInput(TestUtil.constant(session, "Const1", new int[] {1, 2, 1})) + .addInput(TestUtil.constant(session, "Axis", new int[] {0})) + .setAttr("out_idx", DataType.INT32) + .build(); + assertEquals(3, op.numOutputs()); + } + } + + @Test + public void opNotAccessibleIfSessionIsClosed() { + EagerSession session = EagerSession.create(); + EagerOperation add = + opBuilder(session, "Add", "SessionClosed") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); + assertEquals(1, add.outputListLength("z")); + session.close(); + try { + add.outputListLength("z"); + fail(); + } catch (IllegalStateException e) { + // expected + } + } + + @Test + public void outputIndexOutOfBounds() { + try (EagerSession session = EagerSession.create()) { + EagerOperation add = + opBuilder(session, "Add", "OutOfRange") + .addInput(TestUtil.constant(session, "Const1", 2)) + .addInput(TestUtil.constant(session, "Const2", 4)) + .build(); + try { + add.getUnsafeNativeHandle(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } + 
try { + add.shape(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } + try { + add.dtype(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } + try { + add.tensor(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } + } + } + + private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) { + return new EagerOperationBuilder(session, type, name); + } +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java new file mode 100644 index 00000000000..77f38bb6160 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/EagerSessionTest.java @@ -0,0 +1,173 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.EagerSession.ResourceCleanupStrategy; + +@RunWith(JUnit4.class) +public class EagerSessionTest { + + @Test + public void closeSessionTwiceDoesNotFail() { + try (EagerSession s = EagerSession.create()) { + s.close(); + } + } + + @Test + public void cleanupResourceOnSessionClose() { + AtomicBoolean deleted = new AtomicBoolean(); + + try (EagerSession s = + EagerSession.options() + .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SESSION_CLOSE) + .build()) { + + new TestReference(s, new Object(), deleted); + + assertFalse(deleted.get()); + runGC(); + assertFalse(deleted.get()); + + buildOp(s); + assertFalse(deleted.get()); // reaching safe point did not release resources + } + assertTrue(deleted.get()); + } + + @Test + public void cleanupResourceOnSafePoints() { + AtomicBoolean deleted = new AtomicBoolean(); + + try (EagerSession s = + EagerSession.options() + .resourceCleanupStrategy(ResourceCleanupStrategy.ON_SAFE_POINTS) + .build()) { + + new TestReference(s, new Object(), deleted); + + assertFalse(deleted.get()); + runGC(); + assertFalse(deleted.get()); + + buildOp(s); + assertTrue(deleted.get()); // reaching safe point released resources + } + } + + @Test + public void cleanupResourceInBackground() { + AtomicBoolean deleted = new AtomicBoolean(); + + try (EagerSession s = + EagerSession.options() + .resourceCleanupStrategy(ResourceCleanupStrategy.IN_BACKGROUND) + .build()) { + + new TestReference(s, new Object(), deleted); + + assertFalse(deleted.get()); + runGC(); + sleep(50); // allow some time to the background thread for cleaning up resources + assertTrue(deleted.get()); + } + } + + @Test + public void clearedResourcesAreNotCleanedUp() { + AtomicBoolean deleted = new AtomicBoolean(); + + try (EagerSession s = EagerSession.create()) { + TestReference 
ref = new TestReference(s, new Object(), deleted); + ref.clear(); + } + assertFalse(deleted.get()); + } + + @Test + public void buildingOpWithClosedSessionFails() { + EagerSession s = EagerSession.create(); + s.close(); + try { + buildOp(s); + fail(); + } catch (IllegalStateException e) { + // ok + } + } + + @Test + public void addingReferenceToClosedSessionFails() { + EagerSession s = EagerSession.create(); + s.close(); + try { + new TestReference(s, new Object(), new AtomicBoolean()); + fail(); + } catch (IllegalStateException e) { + // ok + } + } + + private static class TestReference extends EagerSession.NativeReference { + + TestReference(EagerSession session, Object referent, AtomicBoolean deleted) { + super(session, referent); + this.deleted = deleted; + } + + @Override + void delete() { + if (!deleted.compareAndSet(false, true)) { + fail("Reference was deleted more than once"); + } + } + + private final AtomicBoolean deleted; + } + + private static void buildOp(EagerSession s) { + // Creating an operation is a safe point for resource cleanup + try { + s.opBuilder("Const", "Const"); + } catch (UnsupportedOperationException e) { + // TODO (karlllessard) remove this exception catch when EagerOperationBuilder is implemented + } + } + + private static void runGC() { + // Warning: There is no way to force the garbage collector to run, so here we simply to our best + // to get it triggered but it might be sufficient on some platforms. Adjust accordingly if some + // cleanup tests start to fail. + System.gc(); + System.runFinalization(); + } + + private static void sleep(int millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + } + } +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationBuilderTest.java similarity index 96% rename from tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java rename to tensorflow/java/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index 0a4a8cf4e3f..a0fbe80ed30 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -1,4 +1,4 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -24,9 +24,9 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link org.tensorflow.OperationBuilder}. */ +/** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. */ @RunWith(JUnit4.class) -public class OperationBuilderTest { +public class GraphOperationBuilderTest { // TODO(ashankar): Restore this test once the C API gracefully handles mixing graphs and // operations instead of segfaulting. 
@Test @@ -136,7 +136,8 @@ public class OperationBuilderTest { assertEquals(-1, n.shape().numDimensions()); assertEquals(DataType.FLOAT, n.dataType()); - n = g.opBuilder("Placeholder", "batch_of_vectors") + n = + g.opBuilder("Placeholder", "batch_of_vectors") .setAttr("dtype", DataType.FLOAT) .setAttr("shape", Shape.make(-1, 784)) .build() @@ -168,7 +169,7 @@ public class OperationBuilderTest { Tensor yes = Tensors.create(true); Tensor no = Tensors.create(false)) { Output placeholder = TestUtil.placeholder(g, "boolean", Boolean.class); - Operation check = + GraphOperation check = g.opBuilder("Assert", "assert") .addInput(placeholder) .addInputList(new Output[] {placeholder}) diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java similarity index 86% rename from tensorflow/java/src/test/java/org/tensorflow/OperationTest.java rename to tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java index 6fe3b3c3278..8fb67b90ce0 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphOperationTest.java @@ -24,13 +24,14 @@ import static org.junit.Assert.fail; import java.util.Arrays; import java.util.HashSet; import java.util.Set; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link org.tensorflow.Operation}. */ +/** Unit tests for {@link org.tensorflow.GraphOperation}. */ @RunWith(JUnit4.class) -public class OperationTest { +public class GraphOperationTest { @Test public void outputListLengthFailsOnInvalidName() { @@ -53,12 +54,12 @@ public class OperationTest { @Test public void operationEquality() { - Operation op1; + GraphOperation op1; try (Graph g = new Graph()) { - op1 = TestUtil.constant(g, "op1", 1).op(); - Operation op2 = TestUtil.constant(g, "op2", 2).op(); - Operation op3 = new Operation(g, op1.getUnsafeNativeHandle()); - Operation op4 = g.operation("op1"); + op1 = TestUtil.constantOp(g, "op1", 1); + GraphOperation op2 = TestUtil.constantOp(g, "op2", 2); + GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle()); + GraphOperation op4 = g.operation("op1"); assertEquals(op1, op1); assertNotEquals(op1, op2); assertEquals(op1, op3); @@ -78,10 +79,10 @@ public class OperationTest { @Test public void operationCollection() { try (Graph g = new Graph()) { - Operation op1 = TestUtil.constant(g, "op1", 1).op(); - Operation op2 = TestUtil.constant(g, "op2", 2).op(); - Operation op3 = new Operation(g, op1.getUnsafeNativeHandle()); - Operation op4 = g.operation("op1"); + GraphOperation op1 = TestUtil.constantOp(g, "op1", 1); + GraphOperation op2 = TestUtil.constantOp(g, "op2", 2); + GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle()); + GraphOperation op4 = g.operation("op1"); Set ops = new HashSet<>(); ops.addAll(Arrays.asList(op1, op2, op3, op4)); assertEquals(2, ops.size()); @@ -166,6 +167,18 @@ public class OperationTest { } } + @Test + public void outputTensorNotSupported() { + try (Graph g = new Graph()) { + Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3); + try { + split.output(0).tensor(); + fail(); + } catch (IllegalStateException e) { + } + } + } + private static int split(int[] values, int num_split) { try (Graph g = new Graph()) { return g.opBuilder("Split", "Split") diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java 
b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java index 3229cce2776..3a75f3cb5c8 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java @@ -18,6 +18,7 @@ package org.tensorflow; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -28,6 +29,7 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -520,6 +522,25 @@ public class TensorTest { } } + @Test + public void eagerTensorIsReleasedAfterSessionIsClosed() { + Tensor sum; + try (EagerSession session = EagerSession.create()) { + Output x = TestUtil.constant(session, "Const1", 10); + Output y = TestUtil.constant(session, "Const2", 20); + sum = TestUtil.addN(session, x, y).tensor(); + assertNotEquals(0L, sum.getNativeHandle()); + assertEquals(30, sum.intValue()); + } + assertEquals(0L, sum.getNativeHandle()); + try { + sum.intValue(); + fail(); + } catch (NullPointerException e) { + // expected. + } + } + @Test public void fromHandle() { // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index f984c508ee9..6e24d88a310 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -44,9 +44,15 @@ public class TestUtil { } } - public static Output constant(Graph g, String name, Object value) { + public static GraphOperation constantOp(Graph g, String name, Object value) { try (Tensor t = Tensor.create(value)) { - return g.opBuilder("Const", name) + return g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build(); + } + } + + public static Output constant(ExecutionEnvironment env, String name, Object value) { + try (Tensor t = Tensor.create(value)) { + return env.opBuilder("Const", name) .setAttr("dtype", t.dataType()) .setAttr("value", t) .build() @@ -61,8 +67,8 @@ public class TestUtil { .output(0); } - public static Output addN(Graph g, Output... inputs) { - return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); + public static Output addN(ExecutionEnvironment env, Output... 
inputs) { + return env.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); } public static Output matmul( diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java index 125de73554e..81918a81ac8 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.fail; import java.util.HashMap; import java.util.Map; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -196,7 +197,7 @@ public class ScopeTest { static Const create(Scope s, Tensor value) { return new Const( - s.graph() + s.env() .opBuilder("Const", s.makeOpName("Const")) .setAttr("dtype", value.dataType()) .setAttr("value", value) @@ -207,7 +208,7 @@ public class ScopeTest { static Const create(Scope s, Object v, Class type) { try (Tensor value = Tensor.create(v, type)) { return new Const( - s.graph() + s.env() .opBuilder("Const", s.makeOpName("Const")) .setAttr("dtype", value.dataType()) .setAttr("value", value) @@ -230,7 +231,7 @@ public class ScopeTest { static Mean create(Scope s, Output input, Output reductionIndices) { return new Mean( - s.graph() + s.env() .opBuilder("Mean", s.makeOpName("Mean")) .addInput(input) .addInput(reductionIndices) @@ -252,7 +253,7 @@ public class ScopeTest { static SquaredDifference create(Scope s, Output x, Output y) { return new SquaredDifference( - s.graph() + s.env() .opBuilder("SquaredDifference", s.makeOpName("SquaredDifference")) .addInput(x) .addInput(y) diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java index 49c4ff639ec..daafd6b9503 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java @@ -25,6 +25,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.Shape; import org.tensorflow.op.Ops; @RunWith(JUnit4.class) @@ -57,4 +58,29 @@ public final class GeneratedOperationsTest { } } } + + /** + * Test for Ops.withControlDependencies. + * + *
Creates an add node with a control dependency to an assign node. In other words, the assign + * node is a control input to the add node. When the add node is run, the assign node is expected + * to have run beforehand due to the control dependency. + */ + @Test + public void testControlDependencies() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Ops ops = Ops.create(g); + Operand variable = ops.variable(Shape.scalar(), Integer.class); + Operand initVariable = ops.assign(variable, ops.constant(0)); + ArrayList> controls = new ArrayList>(); + controls.add(ops.assign(variable, ops.constant(3))); + Operand x = + ops.withControlDependencies(controls).math().add(variable, ops.constant(0)); + sess.runner().addTarget(initVariable).run(); + try (Tensor result = sess.runner().fetch(x).run().get(0).expect(Integer.class); ) { + assertEquals(3, result.intValue()); + } + } + } } diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 1b55f967413..f43b8fd4c17 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -41,12 +41,21 @@ config_setting( TFLITE_DEFAULT_COPTS = if_not_windows([ "-Wall", "-Wno-comment", + "-Wno-extern-c-compat", ]) cc_library( - name = "schema_fbs_version", + name = "version", hdrs = ["version.h"], copts = TFLITE_DEFAULT_COPTS, + # Note that we only use the header defines from :version_lib. + deps = ["//tensorflow/core:version_lib"], +) + +# TODO(b/128420794): Migrate clients to use :version directly. +alias( + name = "schema_fbs_version", + actual = ":version", ) cc_library( @@ -62,10 +71,13 @@ cc_library( ], ) -tf_cc_test( +cc_test( name = "arena_planner_test", size = "small", srcs = ["arena_planner_test.cc"], + tags = [ + "tflite_not_portable_android", + ], deps = [ ":arena_planner", "//tensorflow/core:tflite_portable_logging", @@ -147,15 +159,12 @@ cc_library( "stderr_reporter.cc", ] + select({ "//tensorflow:android": [ - "nnapi_delegate.cc", "mmap_allocation.cc", ], "//tensorflow:windows": [ - "nnapi_delegate_disabled.cc", "mmap_allocation_disabled.cc", ], "//conditions:default": [ - "nnapi_delegate_disabled.cc", "mmap_allocation.cc", ], }), @@ -169,7 +178,6 @@ cc_library( "interpreter.h", "model.h", "mutable_op_resolver.h", - "nnapi_delegate.h", "op_resolver.h", "optional_debug_tools.h", "stderr_reporter.h", @@ -180,14 +188,14 @@ cc_library( ":graph_info", ":memory_planner", ":minimal_logging", - ":schema_fbs_version", ":simple_memory_arena", ":string", ":util", + ":version", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/nnapi:nnapi_implementation", - "//tensorflow/lite/profiling:profiler", "//tensorflow/lite/schema:schema_fbs", ] + select({ ":with_select_tf_ops": [ @@ -243,6 +251,7 @@ cc_test( "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/testing:util", + "//third_party/eigen3", "@com_google_googletest//:gtest", ], ) @@ -390,7 +399,9 @@ cc_library( "//tensorflow:android": ["-llog"], "//conditions:default": [], }), - visibility = ["//visibility:private"], + visibility = [ + "//tensorflow/lite:__subpackages__", + ], ) cc_test( diff --git a/tensorflow/lite/arena_planner.cc b/tensorflow/lite/arena_planner.cc index 8a5ef113128..e695c43f13a 100644 --- a/tensorflow/lite/arena_planner.cc +++ b/tensorflow/lite/arena_planner.cc @@ -108,8 +108,8 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { refcounts[tensor_index]++; } - // Variable tensors should are also never overwritten and need 
to be alive all - // the time. + // Variable tensors also should be ensured to be never overwritten and need to + // be alive all the time. for (int tensor_index : graph_info_->variables()) { refcounts[tensor_index]++; } @@ -135,7 +135,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { } // Count references to node input tensors. - for (int i = 0; i < graph_info_->num_nodes(); ++i) { + for (size_t i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); TfLiteIntArray* node_inputs = node.inputs; for (int j = 0; j < node_inputs->size; ++j) { diff --git a/tensorflow/lite/arena_planner.h b/tensorflow/lite/arena_planner.h index beaadaf4eff..e70d29be034 100644 --- a/tensorflow/lite/arena_planner.h +++ b/tensorflow/lite/arena_planner.h @@ -97,11 +97,11 @@ class ArenaPlanner : public MemoryPlanner { // Stores allocation data for all tensors. std::vector allocs_; - // A chronological list of instructions to allocated and deallocate tensors, + // A chronological list of instructions to allocate and deallocate tensors, // reflecting the way they are used in the graph. std::vector alloc_queue_; - // Raw memory buffer that is allocated for all temporary and graph outputs. + // Raw memory buffer that is allocated for all temporary and graph outputs // that are declared kTfLiteArenaRw. SimpleMemoryArena arena_; @@ -114,7 +114,7 @@ class ArenaPlanner : public MemoryPlanner { // unpredictable results. bool preserve_inputs_; - // If true, then no overlapping of memory areas is done, meaning intermediates + // If true, then no overlapping of memory areas is done, meaning intermediate // results can be queried after running (modulo running delegates). bool preserve_intermediates_; diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index d02d8b34c06..a65312517e5 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -391,7 +391,7 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithDynamicTensor) { }, {3}); - // Make #1 dynaic so it does not get allocated. + // Make #1 dynamic so it does not get allocated. (*graph.tensors())[1].allocation_type = kTfLiteDynamic; SetGraph(&graph); diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 764af75ab93..582ec7144b5 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -86,25 +86,27 @@ def tflite_jni_linkopts_unstripped(): "//conditions:default": [], }) -def tflite_linkopts(): - """Defines linker flags to reduce size of TFLite binary.""" - return tflite_linkopts_unstripped() + select({ +def tflite_symbol_opts(): + """Defines linker flags whether to include symbols or not.""" + return select({ + "//tensorflow:android": [ + "-latomic", # Required for some uses of ISO C++11 in x86. + ], + "//conditions:default": [], + }) + select({ "//tensorflow:debug": [], "//conditions:default": [ "-s", # Omit symbol table, for all non debug builds ], }) +def tflite_linkopts(): + """Defines linker flags to reduce size of TFLite binary.""" + return tflite_linkopts_unstripped() + tflite_symbol_opts() + def tflite_jni_linkopts(): """Defines linker flags to reduce size of TFLite binary with JNI.""" - return tflite_jni_linkopts_unstripped() + [ - "-latomic", # Required for some uses of ISO C++11 in x86.] 
- ] + select({ - "//tensorflow:debug": [], - "//conditions:default": [ - "-s", # Omit symbol table, for all non debug builds - ], - }) + return tflite_jni_linkopts_unstripped() + tflite_symbol_opts() def tflite_jni_binary( name, @@ -246,6 +248,7 @@ def generated_test_models(): "equal", "exp", "expand_dims", + "eye", "fill", "floor", "floor_div", @@ -258,6 +261,7 @@ def generated_test_models(): "global_batch_norm", "greater", "greater_equal", + "identity", "sum", "l2norm", "l2norm_shared_epsilon", @@ -272,6 +276,8 @@ def generated_test_models(): "logical_or", "logical_xor", "lstm", + "matrix_diag", + "matrix_set_diag", "max_pool", "maximum", "mean", @@ -301,6 +307,7 @@ def generated_test_models(): "resolve_constant_strided_slice", "reverse_sequence", "reverse_v2", + "round", "rsqrt", "shape", "sigmoid", @@ -323,6 +330,9 @@ def generated_test_models(): "topk", "transpose", "transpose_conv", + "unfused_gru", + "unidirectional_sequence_lstm", + "unidirectional_sequence_rnn", "unique", "unpack", "unroll_batch_matmul", @@ -337,7 +347,8 @@ def generated_test_models_failing(conversion_mode): if conversion_mode == "toco-flex": return [ "lstm", # TODO(b/117510976): Restore when lstm flex conversion works. - "unroll_batch_matmul", # TODO(b/123030774): Fails in 1.13 tests. + "unidirectional_sequence_lstm", + "unidirectional_sequence_rnn", ] return [] @@ -392,7 +403,7 @@ def gen_zip_test(name, test_name, conversion_mode, **kwargs): # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050. # if conversion_mode == "pb2lite": # toco = "//tensorflow/lite/experimental/pb2lite:pb2lite" - flags = "--ignore_toco_errors --run_with_flex" + flags = "--ignore_converter_errors --run_with_flex" gen_zipped_test_file( name = "zip_%s" % test_name, diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 7b4efdf4a36..4e86e4bdf27 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -138,6 +138,10 @@ typedef enum { kTfLiteBuiltinRank = 110, kTfLiteBuiltinElu = 111, kTfLiteBuiltinReverseSequence = 112, + kTfLiteBuiltinMatrixDiag = 113, + kTfLiteBuiltinQuantize = 114, + kTfLiteBuiltinMatrixSetDiag = 115, + kTfLiteBuiltinRound = 116, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index 5d1c92d36f5..a760eba1dce 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -46,9 +46,12 @@ typedef enum { kTfLiteMirrorPaddingSymmetric, } TfLiteMirrorPaddingMode; +// TODO(b/130259536): We should move this out of builtin_op_data. 
typedef struct { int width; int height; + int width_offset; + int height_offset; } TfLitePaddingValues; typedef struct { @@ -334,6 +337,7 @@ typedef struct { } TfLiteShapeParams; typedef struct { + EmptyStructPlaceholder placeholder; } TfLiteRankParams; typedef struct { @@ -373,6 +377,14 @@ typedef struct { int batch_dim; } TfLiteReverseSequenceParams; +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixDiagParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixSetDiagParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/c/c_api_internal.c b/tensorflow/lite/c/c_api_internal.c index f20ee23bd81..926d992011f 100644 --- a/tensorflow/lite/c/c_api_internal.c +++ b/tensorflow/lite/c/c_api_internal.c @@ -172,6 +172,8 @@ const char* TfLiteTypeGetName(TfLiteType type) { return "COMPLEX64"; case kTfLiteString: return "STRING"; + case kTfLiteFloat16: + return "FLOAT16"; } return "Unknown type"; } diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h index 83e2be69076..1948e1ba106 100644 --- a/tensorflow/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -44,12 +44,15 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; // need. Access to the external contexts is controled by one of the // corresponding support files. typedef enum { - kTfLiteEigenContext = 0, // include eigen_support.h to use. - kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. - kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. - kTfLiteMaxExternalContexts = 3 + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. + kTfLiteCpuBackendContext = 3, // include cpu_backend_support.h to use. + kTfLiteMaxExternalContexts = 4 } TfLiteExternalContextType; +struct TfLiteContext; + // An external context is a collection of information unrelated to the TF Lite // framework, but useful to a subset of the ops. TF Lite knows very little // about about the actual contexts, but it keeps a list of them, and is able to @@ -192,6 +195,11 @@ typedef struct { float re, im; // real and imaginary parts, respectively. } TfLiteComplex64; +// Half precision data type compatible with the C99 definition. +typedef struct { + uint16_t data; +} TfLiteFloat16; + // Types supported by tensor typedef enum { kTfLiteNoType = 0, @@ -204,6 +212,7 @@ typedef enum { kTfLiteInt16 = 7, kTfLiteComplex64 = 8, kTfLiteInt8 = 9, + kTfLiteFloat16 = 10, } TfLiteType; // Return the name of a given type, for error reporting purposes. @@ -253,9 +262,11 @@ typedef struct { // A union of pointers that points to memory for a given tensor. typedef union { - int* i32; + int32_t* i32; int64_t* i64; float* f; + // Placeholder for 16b float type. Use uint16* in the pointer union for now. + TfLiteFloat16* f16; char* raw; const char* raw_const; uint8_t* uint8; @@ -279,7 +290,9 @@ typedef enum { // The delegates should use zero or positive integers to represent handles. // -1 is reserved from unallocated status. typedef int TfLiteBufferHandle; -const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; +enum { + kTfLiteNullBufferHandle = -1, +}; // An tensor in the interpreter system which is a wrapper around a buffer of // data including a dimensionality (or NULL if not currently defined). 
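An aside on the FLOAT16 additions above (the `TfLiteFloat16` bit-pattern wrapper, the new `kTfLiteFloat16` enum value, and its `TfLiteTypeGetName` entry): the sketch below shows how the new type looks from client code. The standalone `main`, the printing, and the hard-coded `0x3C00` half-precision bit pattern for 1.0 are illustrative assumptions, not part of this change; only the header path, struct, enum value, and `TfLiteTypeGetName` come from the patch.

```cpp
#include <cstdint>
#include <cstdio>

#include "tensorflow/lite/c/c_api_internal.h"

int main() {
  // The new type is a plain 16-bit bit-pattern wrapper, so per-element
  // storage is two bytes, which is what Subgraph::BytesRequired accounts for.
  static_assert(sizeof(TfLiteFloat16) == sizeof(uint16_t),
                "TfLiteFloat16 must be 16 bits wide");

  // 1.0f in IEEE 754 binary16 is 0x3C00; the struct only carries the bits.
  TfLiteFloat16 one;
  one.data = 0x3C00;

  // The enum value and its printable name added in this change.
  std::printf("%s (%d), element holds bits 0x%04X\n",
              TfLiteTypeGetName(kTfLiteFloat16),
              static_cast<int>(kTfLiteFloat16), one.data);
  return 0;
}
```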
diff --git a/tensorflow/lite/c/c_api_internal_test.cc b/tensorflow/lite/c/c_api_internal_test.cc index d01cf63a3e0..9a37cd9552f 100644 --- a/tensorflow/lite/c/c_api_internal_test.cc +++ b/tensorflow/lite/c/c_api_internal_test.cc @@ -78,6 +78,7 @@ TEST(Types, TestTypeNames) { }; EXPECT_EQ(type_name(kTfLiteNoType), "NOTYPE"); EXPECT_EQ(type_name(kTfLiteFloat32), "FLOAT32"); + EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16"); EXPECT_EQ(type_name(kTfLiteInt16), "INT16"); EXPECT_EQ(type_name(kTfLiteInt32), "INT32"); EXPECT_EQ(type_name(kTfLiteUInt8), "UINT8"); diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index db6b4a2d18e..17eeed6a687 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -17,6 +17,7 @@ cc_library( "error_reporter.h", "flatbuffer_conversions.h", "op_resolver.h", + "profiler.h", ], copts = tflite_copts(), deps = [ diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 2ba64f51d9a..92ea8837be9 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -55,10 +55,14 @@ TfLiteStatus FlatBufferIntVectorToArray( TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, ErrorReporter* error_reporter) { + *type = kTfLiteNoType; switch (tensor_type) { case TensorType_FLOAT32: *type = kTfLiteFloat32; break; + case TensorType_FLOAT16: + *type = kTfLiteFloat16; + break; case TensorType_INT16: *type = kTfLiteInt16; break; @@ -83,10 +87,10 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_COMPLEX64: *type = kTfLiteComplex64; break; - default: - error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", - EnumNameTensorType(tensor_type), tensor_type); - return kTfLiteError; + } + if (*type == kTfLiteNoType) { + error_reporter->Report("Unsupported data type %d in tensor\n", tensor_type); + return kTfLiteError; } return kTfLiteOk; } @@ -167,7 +171,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_CAST: { TfLiteCastParams* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_CastOptions()) { + if (const auto* schema_params = op->builtin_options_as_CastOptions()) { auto in_status = ConvertTensorType(schema_params->in_data_type(), ¶ms->in_data_type, error_reporter); @@ -185,7 +189,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = allocator->AllocatePOD(); - if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { + if (const auto* lshParams = + op->builtin_options_as_LSHProjectionOptions()) { params->type = parseLSHProjectionType(lshParams->type()); } *builtin_data = reinterpret_cast(params); @@ -195,7 +200,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_MAX_POOL_2D: case BuiltinOperator_L2_POOL_2D: { TfLitePoolParams* params = allocator->AllocatePOD(); - if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { + if (const auto* pool_params = op->builtin_options_as_Pool2DOptions()) { params->padding = parse_padding(pool_params->padding()); params->stride_width = pool_params->stride_w(); params->stride_height = pool_params->stride_h(); @@ -210,7 +215,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_DEPTHWISE_CONV_2D: { TfLiteDepthwiseConvParams* 
params = allocator->AllocatePOD(); - if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { + if (const auto* conv_params = + op->builtin_options_as_DepthwiseConv2DOptions()) { params->padding = parse_padding(conv_params->padding()); params->stride_width = conv_params->stride_w(); params->stride_height = conv_params->stride_h(); @@ -226,7 +232,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_SVDF: { TfLiteSVDFParams* params = allocator->AllocatePOD(); - if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { + if (const auto* svdf_params = op->builtin_options_as_SVDFOptions()) { params->rank = svdf_params->rank(); params->activation = parse_activation(svdf_params->fused_activation_function()); @@ -236,7 +242,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { auto params = allocator->AllocatePOD(); - if (auto* sequence_rnn_params = + if (const auto* sequence_rnn_params = op->builtin_options_as_SequenceRNNOptions()) { params->activation = parse_activation(sequence_rnn_params->fused_activation_function()); @@ -248,7 +254,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { auto params = allocator->AllocatePOD(); - if (auto* bidi_sequence_rnn_params = + if (const auto* bidi_sequence_rnn_params = op->builtin_options_as_BidirectionalSequenceRNNOptions()) { params->activation = parse_activation( bidi_sequence_rnn_params->fused_activation_function()); @@ -260,7 +266,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_RNN: { TfLiteRNNParams* params = allocator->AllocatePOD(); - if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { + if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { params->activation = parse_activation(rnn_params->fused_activation_function()); } @@ -270,7 +276,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { TfLiteEmbeddingLookupSparseParams* params = allocator->AllocatePOD(); - if (auto* embedding_params = + if (const auto* embedding_params = op->builtin_options_as_EmbeddingLookupSparseOptions()) { params->combiner = parseCombinerType(embedding_params->combiner()); } @@ -280,7 +286,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_FULLY_CONNECTED: { TfLiteFullyConnectedParams* params = allocator->AllocatePOD(); - if (auto* fully_connected_params = + if (const auto* fully_connected_params = op->builtin_options_as_FullyConnectedOptions()) { params->activation = parse_activation( fully_connected_params->fused_activation_function()); @@ -306,7 +312,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SOFTMAX: { TfLiteSoftmaxParams* params = allocator->AllocatePOD(); - if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { + if (const auto* softmax_params = + op->builtin_options_as_SoftmaxOptions()) { params->beta = softmax_params->beta(); } *builtin_data = reinterpret_cast(params); @@ -315,7 +322,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_CONCATENATION: { TfLiteConcatenationParams* params = allocator->AllocatePOD(); - if (auto* concatenation_params = + if (const auto* concatenation_params = op->builtin_options_as_ConcatenationOptions()) { params->activation = 
parse_activation(concatenation_params->fused_activation_function()); @@ -326,7 +333,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_MUL: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_MulOptions()) { + if (const auto* schema_params = op->builtin_options_as_MulOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); } @@ -335,7 +342,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_ADD: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_AddOptions()) { + if (const auto* schema_params = op->builtin_options_as_AddOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); } @@ -344,7 +351,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_DIV: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_DivOptions()) { + if (const auto* schema_params = op->builtin_options_as_DivOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); } @@ -353,7 +360,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_SUB: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_SubOptions()) { + if (const auto* schema_params = op->builtin_options_as_SubOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); } @@ -362,7 +369,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_L2_NORMALIZATION: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { + if (const auto* schema_params = op->builtin_options_as_L2NormOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); } @@ -371,7 +378,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = + if (const auto* schema_params = op->builtin_options_as_LocalResponseNormalizationOptions()) { params->radius = schema_params->radius(); params->bias = schema_params->bias(); @@ -383,7 +390,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_LSTM: { auto params = allocator->AllocatePOD(); - if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) { params->activation = parse_activation(lstm_params->fused_activation_function()); params->cell_clip = lstm_params->cell_clip(); @@ -395,6 +402,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case LSTMKernelType_BASIC: params->kernel_type = kTfLiteLSTMBasicKernel; break; + default: + error_reporter->Report("Unhandled LSTM kernel type: %d", + lstm_params->kernel_type()); + return kTfLiteError; } } *builtin_data = reinterpret_cast(params); @@ -403,7 +414,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { auto* params = allocator->AllocatePOD(); - if (auto* seq_lstm_params = + if (const auto* seq_lstm_params = op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { params->activation = 
parse_activation(seq_lstm_params->fused_activation_function()); @@ -417,7 +428,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { auto params = allocator->AllocatePOD(); - if (auto* bidi_lstm_params = + if (const auto* bidi_lstm_params = op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { params->activation = parse_activation(bidi_lstm_params->fused_activation_function()); @@ -431,7 +442,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_RESIZE_BILINEAR: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = + if (const auto* schema_params = op->builtin_options_as_ResizeBilinearOptions()) { params->align_corners = schema_params->align_corners(); } @@ -445,7 +456,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, [&]() { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = + if (const auto* schema_params = op->builtin_options_as_ResizeNearestNeighborOptions()) { params->align_corners = schema_params->align_corners(); } @@ -455,7 +466,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_RESHAPE: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { + if (const auto* schema_params = op->builtin_options_as_ReshapeOptions()) { auto* new_shape = schema_params->new_shape(); TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( sizeof(params->shape), new_shape, params->shape, error_reporter, @@ -468,7 +479,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SKIP_GRAM: { TfLiteSkipGramParams* params = allocator->AllocatePOD(); - if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { + if (const auto* skip_gram_params = + op->builtin_options_as_SkipGramOptions()) { params->ngram_size = skip_gram_params->ngram_size(); params->max_skip_size = skip_gram_params->max_skip_size(); params->include_all_ngrams = skip_gram_params->include_all_ngrams(); @@ -478,7 +490,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_SPACE_TO_DEPTH: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { + if (const auto* schema_params = + op->builtin_options_as_SpaceToDepthOptions()) { params->block_size = schema_params->block_size(); } *builtin_data = reinterpret_cast(params); @@ -487,7 +500,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_GATHER: { TfLiteGatherParams* params = allocator->AllocatePOD(); params->axis = 0; - if (auto* gather_params = op->builtin_options_as_GatherOptions()) { + if (const auto* gather_params = op->builtin_options_as_GatherOptions()) { params->axis = gather_params->axis(); } @@ -501,7 +514,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_REDUCE_ANY: case BuiltinOperator_SUM: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { + if (const auto* schema_params = op->builtin_options_as_ReducerOptions()) { params->keep_dims = schema_params->keep_dims(); } *builtin_data = reinterpret_cast(params); @@ -509,7 +522,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_SPLIT: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = 
op->builtin_options_as_SplitOptions()) { + if (const auto* schema_params = op->builtin_options_as_SplitOptions()) { params->num_splits = schema_params->num_splits(); } *builtin_data = reinterpret_cast(params); @@ -517,7 +530,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_SPLIT_V: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_SplitVOptions()) { + if (const auto* schema_params = op->builtin_options_as_SplitVOptions()) { params->num_splits = schema_params->num_splits(); } *builtin_data = reinterpret_cast(params); @@ -525,7 +538,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_SQUEEZE: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { + if (const auto* schema_params = op->builtin_options_as_SqueezeOptions()) { const auto& squeeze_dims = schema_params->squeeze_dims(); TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( sizeof(params->squeeze_dims), squeeze_dims, params->squeeze_dims, @@ -537,7 +550,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_STRIDED_SLICE: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { + if (const auto* schema_params = + op->builtin_options_as_StridedSliceOptions()) { params->begin_mask = schema_params->begin_mask(); params->end_mask = schema_params->end_mask(); params->ellipsis_mask = schema_params->ellipsis_mask(); @@ -549,7 +563,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_ARG_MAX: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { + if (const auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { ConvertTensorType(schema_params->output_type(), ¶ms->output_type, error_reporter); } @@ -568,7 +582,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_TRANSPOSE_CONV: { TfLiteTransposeConvParams* params = allocator->AllocatePOD(); - if (auto* transpose_conv_params = + if (const auto* transpose_conv_params = op->builtin_options_as_TransposeConvOptions()) { params->padding = parse_padding(transpose_conv_params->padding()); params->stride_width = transpose_conv_params->stride_w(); @@ -580,7 +594,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SPARSE_TO_DENSE: { TfLiteSparseToDenseParams* params = allocator->AllocatePOD(); - if (auto* sparse_to_dense_params = + if (const auto* sparse_to_dense_params = op->builtin_options_as_SparseToDenseOptions()) { params->validate_indices = sparse_to_dense_params->validate_indices(); } @@ -589,7 +603,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_SHAPE: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { + if (const auto* schema_params = op->builtin_options_as_ShapeOptions()) { ConvertTensorType(schema_params->out_type(), ¶ms->out_type, error_reporter); } @@ -598,7 +612,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_PACK: { TfLitePackParams* params = allocator->AllocatePOD(); - if (auto* pack_params = op->builtin_options_as_PackOptions()) { + if (const auto* pack_params = op->builtin_options_as_PackOptions()) { params->values_count = 
pack_params->values_count(); params->axis = pack_params->axis(); } @@ -612,7 +626,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_FAKE_QUANT: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) { + if (const auto* schema_params = + op->builtin_options_as_FakeQuantOptions()) { params->min = schema_params->min(); params->max = schema_params->max(); params->num_bits = schema_params->num_bits(); @@ -623,7 +638,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_ONE_HOT: { auto* params = allocator->AllocatePOD(); - if (auto* schema_params = op->builtin_options_as_OneHotOptions()) { + if (const auto* schema_params = op->builtin_options_as_OneHotOptions()) { params->axis = schema_params->axis(); } *builtin_data = static_cast(params); @@ -631,7 +646,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_UNPACK: { TfLiteUnpackParams* params = allocator->AllocatePOD(); - if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) { + if (const auto* unpack_params = op->builtin_options_as_UnpackOptions()) { params->num = unpack_params->num(); params->axis = unpack_params->axis(); } @@ -641,7 +656,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_LEAKY_RELU: { TfLiteLeakyReluParams* params = allocator->AllocatePOD(); - if (auto* leaky_relu_params = op->builtin_options_as_LeakyReluOptions()) { + if (const auto* leaky_relu_params = + op->builtin_options_as_LeakyReluOptions()) { params->alpha = leaky_relu_params->alpha(); } *builtin_data = reinterpret_cast(params); @@ -650,7 +666,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_MIRROR_PAD: { TfLiteMirrorPaddingParams* params = allocator->AllocatePOD(); - auto* mirror_pad_params = op->builtin_options_as_MirrorPadOptions(); + const auto* mirror_pad_params = op->builtin_options_as_MirrorPadOptions(); if (mirror_pad_params != nullptr) { params->mode = mirror_pad_params->mode() == tflite::MirrorPadMode_REFLECT @@ -662,7 +678,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_UNIQUE: { TfLiteUniqueParams* params = allocator->AllocatePOD(); - auto* unique_params = op->builtin_options_as_UniqueOptions(); + const auto* unique_params = op->builtin_options_as_UniqueOptions(); if (unique_params != nullptr) { params->index_out_type = unique_params->idx_out_type() == tflite::TensorType_INT64 @@ -675,7 +691,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_REVERSE_SEQUENCE: { TfLiteReverseSequenceParams* params = allocator->AllocatePOD(); - if (auto* reverse_seq_params = + if (const auto* reverse_seq_params = op->builtin_options_as_ReverseSequenceOptions()) { params->seq_dim = reverse_seq_params->seq_dim(); params->batch_dim = reverse_seq_params->batch_dim(); @@ -683,8 +699,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - - // Below are the ops with no builtin_data strcture. + // Below are the ops with no builtin_data structure. 
case BuiltinOperator_ABS: case BuiltinOperator_BATCH_TO_SPACE_ND: // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are @@ -708,6 +723,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_LOG: case BuiltinOperator_LOGISTIC: case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MATRIX_DIAG: + case BuiltinOperator_MATRIX_SET_DIAG: case BuiltinOperator_MAXIMUM: case BuiltinOperator_MINIMUM: case BuiltinOperator_NEG: @@ -718,6 +735,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_RELU: case BuiltinOperator_RELU6: case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_ROUND: case BuiltinOperator_RSQRT: case BuiltinOperator_SELECT: case BuiltinOperator_SIN: @@ -744,6 +762,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_GATHER_ND: case BuiltinOperator_WHERE: case BuiltinOperator_RANK: + case BuiltinOperator_QUANTIZE: break; } return kTfLiteOk; diff --git a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc index 4a5de48302c..c7f8c1ad66e 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc @@ -141,6 +141,13 @@ TEST_F(FlatbufferConversionsTest, TestConvertTensorType) { EXPECT_EQ(kTfLiteFloat32, type); } +TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeFloat16) { + TfLiteType type; + EXPECT_EQ(kTfLiteOk, + ConvertTensorType(TensorType_FLOAT16, &type, &mock_reporter_)); + EXPECT_EQ(kTfLiteFloat16, type); +} + } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/lite/core/api/profiler.h b/tensorflow/lite/core/api/profiler.h new file mode 100644 index 00000000000..f36f8e13c3c --- /dev/null +++ b/tensorflow/lite/core/api/profiler.h @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_CORE_API_PROFILER_H_ +#define TENSORFLOW_LITE_CORE_API_PROFILER_H_ + +#include + +namespace tflite { + +// A simple utility for enabling profiled event tracing in TensorFlow Lite. +class Profiler { + public: + enum class EventType { + // Default event type, the metadata field has no special significance. + DEFAULT = 0, + // The event is an operator invocation and the event_metadata field is the + // index of operator node. + OPERATOR_INVOKE_EVENT = 1 + }; + + virtual ~Profiler() {} + + // Signals the beginning of an event, returning a handle to the profile event. + virtual uint32_t BeginEvent(const char* tag, EventType event_type, + uint32_t event_metadata) = 0; + + // Signals an end to the specified profile event. + virtual void EndEvent(uint32_t event_handle) = 0; +}; + +// Adds a profile event to `profiler` that begins with the construction +// of the object and ends when the object goes out of scope. 
+// The lifetime of tag should be at least the lifetime of `profiler`. +// `profiler` may be null, in which case nothing is profiled. +class ScopedProfile { + public: + ScopedProfile(Profiler* profiler, const char* tag, + Profiler::EventType event_type = Profiler::EventType::DEFAULT, + uint32_t event_metadata = 0) + : profiler_(profiler), event_handle_(0) { + if (profiler) { + event_handle_ = profiler_->BeginEvent(tag, event_type, event_metadata); + } + } + + ~ScopedProfile() { + if (profiler_) { + profiler_->EndEvent(event_handle_); + } + } + + private: + Profiler* const profiler_; + uint32_t event_handle_; +}; + +class ScopedOperatorProfile : public ScopedProfile { + public: + ScopedOperatorProfile(Profiler* profiler, const char* tag, int node_index) + : ScopedProfile(profiler, tag, Profiler::EventType::OPERATOR_INVOKE_EVENT, + static_cast(node_index)) {} +}; + +} // namespace tflite + +#define TFLITE_VARNAME_UNIQ(name, ctr) name##ctr +#define TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE(profiler, tag, node_index) \ + tflite::ScopedOperatorProfile TFLITE_VARNAME_UNIQ(_profile_, __COUNTER__)( \ + (profiler), (tag), (node_index)) +#define TFLITE_SCOPED_OPERATOR_PROFILE(profiler, node_index) \ + TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE((profiler), "OpInvoke", (node_index)) + +#endif // TENSORFLOW_LITE_CORE_API_PROFILER_H_ diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index ec6762b16c9..afa2d63f64f 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -14,16 +14,29 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/core/subgraph.h" + +#include + #include "tensorflow/lite/arena_planner.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/graph_info.h" -#include "tensorflow/lite/nnapi_delegate.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace { + +struct TfLiteQuantizationDeleter { + void operator()(TfLiteQuantization* q) { + if (q) TfLiteQuantizationFree(q); + } +}; + +using ScopedTfLiteQuantization = + std::unique_ptr; + TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node, const TfLiteRegistration& registration, int node_index, const char* message) { @@ -38,10 +51,10 @@ TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node, } // Stub method which returns kTfLiteError when the function is forbidden. -// We're registrating this function to several different function to save +// We're registering this function to several different function to save // compiled binary size. Please note the restrictions: // * The type of first parameter have to be `TfLiteContext*`. -// * All paramteters must be trivailly destructible. (E.g. No C++ class) +// * All parameters must be trivially destructible. (E.g. No C++ class) TfLiteStatus ForbiddenContextFunction(TfLiteContext* context, ...) { context->ReportError(context, "The function is forbidden if not calling in delegate."); @@ -283,7 +296,7 @@ TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels( execution_plan_.clear(); for (auto& node_subset : node_subsets) { - // Subsets calimed by the delegate should have a "macro" op created, the + // Subsets claimed by the delegate should have a "macro" op created, the // other node_subsets (kTfNonPartition) just have their nodes added back to // the execution plan. 
switch (node_subset.type) { @@ -324,7 +337,7 @@ TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels( TfLiteExternalContext* Subgraph::GetExternalContext( TfLiteExternalContextType type) { - if (type >= 0 && type < kTfLiteMaxExternalContexts) { + if (static_cast(type) >= 0 && type < kTfLiteMaxExternalContexts) { return external_contexts_[type]; } return nullptr; @@ -337,7 +350,7 @@ TfLiteExternalContext* Subgraph::GetExternalContext( void Subgraph::SetExternalContext(TfLiteExternalContextType type, TfLiteExternalContext* ctx) { - if (type >= 0 && type < kTfLiteMaxExternalContexts) { + if (static_cast(type) >= 0 && type < kTfLiteMaxExternalContexts) { external_contexts_[type] = ctx; } } @@ -456,6 +469,9 @@ TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims, case kTfLiteInt8: *bytes = sizeof(int8_t) * count; break; + case kTfLiteFloat16: + *bytes = sizeof(TfLiteFloat16) * count; + break; default: ReportError( "Only float32, int8, int16, int32, int64, uint8, bool, complex64 " @@ -490,7 +506,7 @@ TfLiteStatus Subgraph::AllocateTensors() { // Reset the variable tensors to zero after (re)allocating the tensors. // Developers shouldn't rely on the side effect of this function to reset - // variable tesnsors. They should call `ResetVariableTensors` directly + // variable tensors. They should call `ResetVariableTensors` directly // instead. ResetVariableTensors(); @@ -519,15 +535,14 @@ TfLiteStatus Subgraph::AddNodeWithParameters( const std::vector& inputs, const std::vector& outputs, const char* init_data, size_t init_data_size, void* builtin_data, const TfLiteRegistration* registration, int* node_index) { + std::unique_ptr builtin_data_deleter(builtin_data, + free); if (state_ == kStateInvokableAndImmutable) { ReportError("AddNodeWithParameters is disallowed when graph is immutable."); return kTfLiteError; } state_ = kStateUninvokable; - std::unique_ptr builtin_data_deleter(builtin_data, - free); - TF_LITE_ENSURE_OK(context_, CheckTensorIndices("node inputs", inputs.data(), inputs.size())); TF_LITE_ENSURE_OK( @@ -666,18 +681,11 @@ TfLiteStatus Subgraph::Invoke() { return kTfLiteError; } - if (nnapi_delegate_) { - if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) { - TF_LITE_ENSURE_OK(context_, nnapi_delegate_->Invoke(this)); - return kTfLiteOk; - } else { - // TODO(aselle): In the future, we would like this to be an - // automatic tflite CPU fallback. - ReportError( - "NNAPI was requested, but dependent sized tensors " - "being used.\n"); - return kTfLiteError; - } + // This is only needed for UseNNAPI(true); + if (should_apply_nnapi_delegate_ && !applied_nnapi_delegate_) { + TF_LITE_ENSURE_OK(context_, ModifyGraphWithDelegate(NnApiDelegate())); + // only need to modify the graph once upon the first invocation. + applied_nnapi_delegate_ = true; } // Invocations are always done in node order. 
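A short sketch of the scoped-cleanup pattern `subgraph.cc` now uses for `TfLiteQuantization` and `builtin_data`, with the `std::unique_ptr` template arguments spelled out; `HypotheticalAddNode` is only a stand-in for the validation flow in `AddNodeWithParameters`:

```c++
#include <cstdlib>
#include <memory>

#include "tensorflow/lite/c/c_api_internal.h"

// Deleter matching the one added to subgraph.cc: releases the quantization
// contents via TfLiteQuantizationFree().
struct TfLiteQuantizationDeleter {
  void operator()(TfLiteQuantization* q) {
    if (q) TfLiteQuantizationFree(q);
  }
};
using ScopedTfLiteQuantization =
    std::unique_ptr<TfLiteQuantization, TfLiteQuantizationDeleter>;

// builtin_data is typically malloc-allocated by the op parser, so plain
// free() is the matching deleter; any early return now releases it.
TfLiteStatus HypotheticalAddNode(void* builtin_data) {
  std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
                                                              free);
  // ... validation that may return early without leaking builtin_data ...
  return kTfLiteOk;
}
```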
@@ -695,7 +703,7 @@ TfLiteStatus Subgraph::Invoke() { TfLiteNode& node = nodes_and_registration_[node_index].first; const TfLiteRegistration& registration = nodes_and_registration_[node_index].second; - SCOPED_OPERATOR_PROFILE(profiler_, node_index); + TFLITE_SCOPED_OPERATOR_PROFILE(profiler_, node_index); // TODO(ycling): This is an extra loop through inputs to check if the data // need to be copied from Delegate buffer to raw memory, which is often not @@ -820,6 +828,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantization quantization, const char* buffer, size_t bytes, const Allocation* allocation) { + // Ensure quantization cleanup on failure. + ScopedTfLiteQuantization scoped_quantization(&quantization); if (state_ == kStateInvokableAndImmutable) { ReportError( "SetTensorParametersReadOnly is disallowed when graph is immutable."); @@ -847,7 +857,7 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( tensor.data.raw = const_cast(buffer); if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims); tensor.params = GetLegacyQuantization(quantization); - tensor.quantization = quantization; + tensor.quantization = *scoped_quantization.release(); tensor.allocation_type = kTfLiteMmapRo; tensor.allocation = allocation; } else { @@ -858,7 +868,7 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( allocation, false, &tensor); // TODO(suharshs): Update TfLiteTensorReset to include the new quantization // if there are other required callers. - tensor.quantization = quantization; + tensor.quantization = *scoped_quantization.release(); } return kTfLiteOk; } @@ -870,6 +880,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( TfLiteStatus Subgraph::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantization quantization, bool is_variable) { + // Ensure quantization cleanup on failure. + ScopedTfLiteQuantization scoped_quantization(&quantization); if (state_ == kStateInvokableAndImmutable) { ReportError( "SetTensorParametersReadWrite is disallowed when graph is immutable."); @@ -906,7 +918,7 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite( nullptr, is_variable, &tensor); // TODO(suharshs): Update TfLiteTensorReset to include the new quantization // if there are other required callers. - tensor.quantization = quantization; + tensor.quantization = *scoped_quantization.release(); return kTfLiteOk; } @@ -957,14 +969,12 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor, } void Subgraph::UseNNAPI(bool enable) { - // TODO(aselle): This is a workaround for finding if NNAPI exists. - // We also need to make sure getLibraryHandle() is renamed to be NNAPI - // prefixed. - if (!NNAPIDelegate::IsSupported()) enable = false; - if (!enable) { - nnapi_delegate_.reset(); - } else if (!nnapi_delegate_) { - nnapi_delegate_.reset(new NNAPIDelegate); + // Note that there is no way to disable the delegate once it modified the + // graph. + if (applied_nnapi_delegate_ && !enable) { + ReportError("Attempting to disable NNAPI delegate after it's applied."); + } else { + should_apply_nnapi_delegate_ = enable; } } diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 5db15a177ef..b20cd06d686 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -20,8 +20,9 @@ limitations under the License. 
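From the application side, the reworked `UseNNAPI()` only records the request and the delegate is applied on the first `Invoke()`. A sketch under that assumption, using the standard TFLite 1.x C++ headers; the model path handling and error checks are placeholders:

```c++
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

void RunWithNnapi(const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);

  interpreter->UseNNAPI(true);   // Only records the request now.
  interpreter->AllocateTensors();

  // The first Invoke() applies the NNAPI delegate through
  // ModifyGraphWithDelegate(); afterwards it can no longer be disabled.
  interpreter->Invoke();
}
```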
#include "tensorflow/lite/allocation.h" #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/profiler.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/memory_planner.h" -#include "tensorflow/lite/profiling/profiler.h" #include "tensorflow/lite/util.h" namespace tflite { @@ -84,7 +85,7 @@ class Subgraph { // Set description of inputs/outputs/data/fptrs for node `node_index`. // This variant assumes an external buffer has been allocated of size // bytes. The lifetime of buffer must be ensured to be greater or equal - // to Interpreter. + // to Interpreter. `quantization` ownership is passed to the subgraph. inline TfLiteStatus SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantization quantization, @@ -102,7 +103,7 @@ class Subgraph { // Set description of inputs/outputs/data/fptrs for node `node_index`. // This variant assumes an external buffer has been allocated of size // bytes. The lifetime of buffer must be ensured to be greater or equal - // to Interpreter. + // to Interpreter. `quantization` ownership is passed to the subgraph. inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantization quantization, @@ -163,10 +164,10 @@ class Subgraph { // Return the number of ops in the model. size_t nodes_size() const { return nodes_and_registration_.size(); } - // Read only access to list of variable tensors. + // Return vector of node indices in the order of execution. std::vector& execution_plan() { return execution_plan_; } - // Read only access to list of variable tensors. + // Return read-only vector of node indices in the order of execution. const std::vector& execution_plan() const { return execution_plan_; } // Mutable form of tensors (TEMPORARY for refactor). @@ -275,12 +276,12 @@ class Subgraph { // WARNING: This is an experimental API and subject to change. TfLiteStatus ResetVariableTensors(); - void SetProfiler(profiling::Profiler* profiler) { + void SetProfiler(Profiler* profiler) { profiler_ = profiler; context_->profiler = profiler; } - profiling::Profiler* GetProfiler() { return profiler_; } + Profiler* GetProfiler() { return profiler_; } // Returns a pointer to vector of subgraphs. // WARNING: This is an experimental API and subject to change. @@ -511,8 +512,9 @@ class Subgraph { // TODO(aselle): replace execution_plan_ with this. std::unique_ptr plan_cache_; - // Whether to delegate to NN API - std::unique_ptr nnapi_delegate_; + // Whether to use delegate to modify the graph. + bool should_apply_nnapi_delegate_ = false; + bool applied_nnapi_delegate_ = false; std::unique_ptr memory_planner_; @@ -525,7 +527,7 @@ class Subgraph { TfLiteExternalContext** external_contexts_; // Profiler for this interpreter instance. - profiling::Profiler* profiler_ = nullptr; + Profiler* profiler_ = nullptr; // A pointer to vector of subgraphs. The vector is owned by the interpreter. 
std::vector>* subgraphs_ = nullptr; diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index bca8e514fe4..43c3d5f6eb0 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -15,7 +15,6 @@ cc_library( hdrs = ["buffer_map.h"], deps = [ ":util", - "//tensorflow/c:c_api_internal", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite:string", "//tensorflow/lite:string_util", @@ -24,6 +23,7 @@ cc_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + "//tensorflow/c:c_api_internal", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", ], @@ -55,13 +55,13 @@ cc_library( deps = [ ":delegate_data", ":delegate_only_runtime", - "//tensorflow/lite/c:c_api_internal", ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib", ], "//conditions:default": [ "//tensorflow/core:tensorflow", + "//tensorflow/lite/c:c_api_internal", ], }), alwayslink = 1, @@ -86,6 +86,7 @@ cc_library( "@com_google_absl//absl/strings:strings", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite:kernel_api", + "//tensorflow/lite:minimal_logging", "//tensorflow/lite:string_util", "//tensorflow/lite:util", ] + select({ @@ -119,12 +120,12 @@ cc_library( deps = [ ":buffer_map", "@com_google_absl//absl/memory", - "//tensorflow/core/common_runtime/eager:context", ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core:core_cpu", "//tensorflow/core:lib", ], @@ -153,14 +154,11 @@ cc_library( ":delegate_data", ":util", "@flatbuffers", + "//tensorflow/lite/core/api", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite:kernel_api", "//tensorflow/lite:string", "//tensorflow/lite/kernels:kernel_util", - "//tensorflow/lite/profiling:profiler", - "//tensorflow/core/common_runtime/eager:context", - "//tensorflow/core/common_runtime/eager:execute", - "//tensorflow/core/common_runtime/eager:tensor_handle", ] + select({ # TODO(b/111881878): The android_tensorflow_lib target pulls in the full # set of core TensorFlow kernels. We may want to revisit this dependency @@ -169,6 +167,9 @@ cc_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:framework", @@ -215,7 +216,6 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ - "//tensorflow/c:c_api_internal", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite:kernel_api", ] + select({ @@ -223,6 +223,7 @@ cc_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + "//tensorflow/c:c_api_internal", "//tensorflow/core:lib", "//tensorflow/core:framework", ], diff --git a/tensorflow/lite/delegates/flex/buffer_map.cc b/tensorflow/lite/delegates/flex/buffer_map.cc index 0d0c9536366..1f6df9ada73 100644 --- a/tensorflow/lite/delegates/flex/buffer_map.cc +++ b/tensorflow/lite/delegates/flex/buffer_map.cc @@ -15,11 +15,12 @@ limitations under the License. 
#include "tensorflow/lite/delegates/flex/buffer_map.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/typed_allocator.h" #include "tensorflow/lite/delegates/flex/util.h" #include "tensorflow/lite/string.h" #include "tensorflow/lite/string_util.h" -#include "tensorflow/core/framework/allocation_description.pb.h" -#include "tensorflow/core/framework/log_memory.h" namespace tflite { namespace flex { @@ -99,8 +100,9 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer { ~StringTfLiteTensorBuffer() override { LogDeallocation(); - tensorflow::cpu_allocator()->Deallocate( - static_cast(data()), num_strings_); + tensorflow::TypedAllocator::Deallocate( + tensorflow::cpu_allocator(), static_cast(data()), + num_strings_); } size_t size() const override { return num_strings_ * sizeof(string); } @@ -109,7 +111,9 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer { StringTfLiteTensorBuffer(const TfLiteTensor* tensor, int num_strings) : BaseTfLiteTensorBuffer( num_strings != 0 - ? tensorflow::cpu_allocator()->Allocate(num_strings) + ? tensorflow::TypedAllocator::Allocate( + tensorflow::cpu_allocator(), num_strings, + tensorflow::AllocationAttributes()) : nullptr), num_strings_(num_strings) { LogAllocation(); diff --git a/tensorflow/lite/delegates/flex/delegate.cc b/tensorflow/lite/delegates/flex/delegate.cc index dcf5b795d82..9400ef697e0 100644 --- a/tensorflow/lite/delegates/flex/delegate.cc +++ b/tensorflow/lite/delegates/flex/delegate.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/delegates/flex/buffer_map.h" #include "tensorflow/lite/delegates/flex/kernel.h" #include "tensorflow/lite/delegates/flex/util.h" +#include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/util.h" @@ -133,6 +134,8 @@ AcquireFlexDelegate() { } std::unique_ptr FlexDelegate::Create() { + TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO, + "Created TensorFlow Lite delegate for select TF ops."); return std::unique_ptr(new FlexDelegate()); } diff --git a/tensorflow/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc index 87f37697468..1c036c2ebd7 100644 --- a/tensorflow/lite/delegates/flex/delegate_data.cc +++ b/tensorflow/lite/delegates/flex/delegate_data.cc @@ -22,7 +22,9 @@ namespace tflite { namespace flex { DelegateData::DelegateData() {} -DelegateData::~DelegateData() {} +DelegateData::~DelegateData() { + if (eager_context_) eager_context_->Unref(); +} tensorflow::Status DelegateData::Prepare( const tensorflow::SessionOptions& session_options) { @@ -40,10 +42,10 @@ tensorflow::Status DelegateData::Prepare( // Note that Rendezvous is ref-counted so it will be automatically deleted. 
tensorflow::Rendezvous* rendezvous = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - eager_context_.reset(new tensorflow::EagerContext( + eager_context_ = new tensorflow::EagerContext( session_options, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - /*async=*/false, std::move(device_mgr), rendezvous)); + /*async=*/false, std::move(device_mgr), rendezvous); return tensorflow::Status(); } diff --git a/tensorflow/lite/delegates/flex/delegate_data.h b/tensorflow/lite/delegates/flex/delegate_data.h index 20d6b40a5d2..5f88cfbf444 100644 --- a/tensorflow/lite/delegates/flex/delegate_data.h +++ b/tensorflow/lite/delegates/flex/delegate_data.h @@ -39,7 +39,7 @@ class DelegateData { // The EagerContext that is required for execution of Flex Ops. // Note: The context is lazily created after the first call to |Prepare()|. - tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } + tensorflow::EagerContext* GetEagerContext() { return eager_context_; } // Map from TF Lite tensor index to TensorFlow tensor for a given context. BufferMap* GetBufferMap(const TfLiteContext* context) { @@ -48,7 +48,7 @@ class DelegateData { private: // Will be null until Prepare() is called and completes successfully. - std::unique_ptr eager_context_; + tensorflow::EagerContext* eager_context_ = nullptr; // TODO(b/112439500): Clean up stale BufferMap instances after adding the // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate. std::unordered_map buffer_map_; diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index ceb9918f6fa..4f3d0f1dde6 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -24,10 +24,10 @@ limitations under the License. #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/delegates/flex/delegate_data.h" #include "tensorflow/lite/delegates/flex/util.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/profiling/profiler.h" #include "tensorflow/lite/string.h" // Note: this is part of TF Lite's Flex delegation code which is to be @@ -529,8 +529,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Execute the TensorFlow Ops sequentially. 
for (auto& node_data : op_data->nodes) { - SCOPED_TAGGED_OPERATOR_PROFILE( - reinterpret_cast(context->profiler), + TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE( + reinterpret_cast(context->profiler), node_data->name().c_str(), node_data->index()); auto status = ExecuteFlexOp(context, buffer_map, node_data.get()); @@ -558,9 +558,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace kernel TfLiteRegistration GetKernel() { - TfLiteRegistration registration{&kernel::Init, &kernel::Free, - &kernel::Prepare, &kernel::Eval, - nullptr, kTfLiteBuiltinDelegate}; + TfLiteRegistration registration{ + &kernel::Init, + &kernel::Free, + &kernel::Prepare, + &kernel::Eval, + nullptr, // .profiling_string + kTfLiteBuiltinDelegate, // .builtin_code + "TfLiteFlexDelegate", // .custom_name + 1, // .version + }; return registration; } diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc index c995b360f9d..4279f4ae397 100644 --- a/tensorflow/lite/delegates/flex/util.cc +++ b/tensorflow/lite/delegates/flex/util.cc @@ -60,6 +60,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { return TF_FLOAT; case kTfLiteFloat32: return TF_FLOAT; + case kTfLiteFloat16: + return TF_HALF; case kTfLiteInt16: return TF_INT16; case kTfLiteInt32: @@ -83,6 +85,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) { switch (type) { case TF_FLOAT: return kTfLiteFloat32; + case TF_HALF: + return kTfLiteFloat16; case TF_INT16: return kTfLiteInt16; case TF_INT32: diff --git a/tensorflow/lite/delegates/flex/util_test.cc b/tensorflow/lite/delegates/flex/util_test.cc index 87104751b81..69bba405055 100644 --- a/tensorflow/lite/delegates/flex/util_test.cc +++ b/tensorflow/lite/delegates/flex/util_test.cc @@ -101,9 +101,9 @@ TEST(UtilTest, CopyShapeAndType) { EXPECT_EQ( CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst), - kTfLiteError); - EXPECT_EQ(context.error, - "TF Lite does not support TensorFlow data type: half"); + kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + EXPECT_EQ(dst.type, kTfLiteFloat16); } TEST(UtilTest, TypeConversionsFromTFLite) { diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD new file mode 100644 index 00000000000..33f5a86c422 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -0,0 +1,117 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +# Primary purpose of this config is to replace ::util::Status with our custom +# light implementation ::tflite::gpu::StatusLite to reduce binary size. Besides +# that, certain features that were hard to communicate without full open source +# were hidden away too such as compiled models, serialization, and metadata. +# While the latter will be fully available with the open source release, the +# former will have to stay until absl::Status is released. 
+config_setting( + name = "tflite_gpu_binary_release", + values = {"copt": "-DTFLITE_GPU_BINARY_RELEASE"}, +) + +cc_library( + name = "gl_delegate", + srcs = ["gl_delegate.cc"], + hdrs = ["gl_delegate.h"], + linkopts = select({ + "//tensorflow:android": [ + "-lEGL", + "-lGLESv3", + ], + "//conditions:default": [], + }), + deps = [ + "@com_google_absl//absl/types:span", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_builder", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", + "//tensorflow/lite/delegates/gpu/gl:api", + "//tensorflow/lite/delegates/gpu/gl:command_queue", + "//tensorflow/lite/delegates/gpu/gl:compiler", + "//tensorflow/lite/delegates/gpu/gl:egl_environment", + "//tensorflow/lite/delegates/gpu/gl:gl_call", + "//tensorflow/lite/delegates/gpu/gl/converters:bhwc_to_phwc4", + "//tensorflow/lite/delegates/gpu/gl/converters:phwc4_to_bhwc", + "//tensorflow/lite/delegates/gpu/gl/kernels:registry", + "//tensorflow/lite/delegates/gpu/gl/workgroups:best_effort_calculator", + ] + select({ + "//conditions:default": [ + "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs", + "//tensorflow/lite/delegates/gpu/gl:metadata_cc_fbs", + "//tensorflow/lite/delegates/gpu/gl:workgroups_cc_fbs", + "@flatbuffers", + "//tensorflow/lite/schema:schema_fbs", + ], + ":tflite_gpu_binary_release": [], + }), +) + +objc_library( + name = "metal_delegate", + srcs = ["metal_delegate.mm"], + hdrs = ["metal_delegate.h"], + copts = ["-std=c++11"], + sdk_frameworks = ["Metal"], + deps = [ + "//tensorflow/lite:kernel_api", + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_builder", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations", + "//tensorflow/lite/delegates/gpu/metal:api", + "//tensorflow/lite/delegates/gpu/metal:buffer_convert", + "//tensorflow/lite/delegates/gpu/metal:compiled_model", + "//tensorflow/lite/delegates/gpu/metal:inference_context", + "@com_google_absl//absl/types:span", + ], +) + +# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtensorflowlite_gpu_gl.so +cc_binary( + name = "libtensorflowlite_gpu_gl.so", + linkopts = select({ + "//tensorflow:android": [ + "-lEGL", + "-lGLESv3", + ], + "//conditions:default": [], + }), + linkshared = 1, + linkstatic = 1, + tags = [ + "nobuilder", + "notap", + ], + deps = [":gl_delegate"], +) + +# build -c opt --config ios_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always :libtensorflowlite_gpu_metal.so +cc_binary( + name = "libtensorflowlite_gpu_metal.so", + linkshared = 1, + linkstatic = 1, + 
tags = [ + "nobuilder", + "notap", + ], + deps = [":metal_delegate"], +) diff --git a/tensorflow/lite/delegates/gpu/README.md b/tensorflow/lite/delegates/gpu/README.md new file mode 100644 index 00000000000..5c173af86fb --- /dev/null +++ b/tensorflow/lite/delegates/gpu/README.md @@ -0,0 +1,209 @@ +# TFLite on GPU + +TensorFlow Lite (TFLite) supports several hardware accelerators. This document +describes how to use the GPU backend using the TFLite delegate APIs on Android +and iOS. + +GPUs are designed to have high throughput for massively parallelizable +workloads. Thus, they are well-suited for deep neural nets which consists of a +huge number of operators, each working on some input tensor(s) that can be +easily divided into smaller workloads and carried out in parallel, typically +resulting in lower latency. In the best scenario, inference on the GPU may now +run fast enough and now become suitable for real-time applications if it was not +before. + +GPUs do their computation with 16-bit or 32-bit floating point numbers and do +not require quantization for optimal performance unlike the CPUs. If +quantization of your neural network was not an option due to lower accuracy +caused by lost precision, such concern can be discarded when running deep neural +net models on the GPU. + +Another benefit that comes with GPU inference is its power efficiency. GPUs +carry out the computations in a very efficient and optimized way, so that they +consume less power and generate less heat than when the same task is run on the +CPUs. + +TFLite on GPU supports the following ops in 16-bit and 32-bit float precision: + +* `ADD v1` +* `AVERAGE_POOL_2D v1` +* `CONCATENATION v1` +* `CONV_2D v1` +* `DEPTHWISE_CONV_2D v1-2` +* `FULLY_CONNECTED v1` +* `LOGISTIC v1` +* `LSTM v2 (Basic LSTM only)` +* `MAX_POOL_2D v1` +* `MUL v1` +* `PAD v1` +* `PRELU v1` +* `RELU v1` +* `RELU6 v1` +* `RESHAPE v1` +* `RESIZE_BILINEAR v1` +* `SOFTMAX v1` +* `STRIDED_SLICE v1` +* `SUB v1` +* `TRANSPOSE_CONV v1` + +## Basic Usage + +Using TFLite on GPU is as simple as getting the GPU delegate via +`TfLiteGpuDelegateCreate()` and then passing it to +`Interpreter::ModifyGraphWithDelegate()` instead of calling +`Interpreter::AllocateTensors()`: + +```c++ +//////// +// Set up interpreter. +auto model = FlatBufferModel::BuildFromFile(model_path); +ops::builtin::BuiltinOpResolver op_resolver; +std::unique_ptr interpreter; +InterpreterBuilder(*model, op_resolver)(&interpreter); + +//////// +// NEW: Prepare GPU delegate. +auto* delegate = TfLiteGpuDelegateCreate(/*options=*/nullptr); +if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return; + +//////// +// Run inference. +WriteToInputTensor(interpreter->typed_input_tensor(0)); +if (interpreter->Invoke() != kTfLiteOk) return; +ReadFromOutputTensor(interpreter->typed_output_tensor(0)); + +//////// +// Clean up. +TfLiteGpuDelegateDelete(delegate); +``` + +*IMPORTANT:* When calling `Interpreter::ModifyGraphWithDelegate()` or +`Interpreter::Invoke()`, the caller must have a `EGLContext` in the current +thread and `Interpreter::Invoke()` must be called from the same `EGLContext`. +If such `EGLContext` does not exist, the delegate will internally create one, +but then the developer must ensure that `Interpreter::Invoke()` is always called +from the same thread `Interpreter::ModifyGraphWithDelegate()` was called. + +## Building and Runtime + +TFLite GPU backend uses OpenGL compute shaders and thus requires OpenGL ES 3.1 +or higher. 
+ +```sh +bazel build --config android_arm64 //path/to/your:project +``` + +Metal shaders are used for iOS, which were introduced with iOS 8. Thus, +compilation flags should look like: + +```sh +bazel build --config ios_arm64 //path/to/your:project +``` + +## Advanced Usage: Delegate Options + +There are GPU options that can be set and passed on to +`TfLiteGpuDelegateCreate()`. When option is set to `nullptr` as shown in the +Basic Usage, it translates to: + +```c++ +const TfLiteGpuDelegateOptions kDefaultOptions = { + .metadata = nullptr, + .compile_options = { + .precision_loss_allowed = 0, // false + .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST, + .dynamic_batch_enabled = 0, // false + }, +}; +``` + +Similar for `NewTfLiteMetalDelgate()`: + +```c++ +const TfLiteMetalDelegateOptions kDefaultOptions = { + .precision_loss_allowed = 0, // false + .wait_type = TFLITE_METAL_WAIT_TYPE_SLEEP, +}; +``` + +While it is convenient to just supply `nullptr`, it is recommended to explicitly +set the options to avoid any unexpected artifacts in case default values are +changed. + +*IMPORTANT:* Note that the default option does not allow precision loss, and +thus may not be the fastest. For faster execution, you may want to set +`precision_loss_allowed` to `1` for FP16 execution. + +## Advanced Usage: Input/Output Buffers (C++) + +To do computation on the GPU, data must be made available to the GPU which often +translates to performing a memory copy. It is desirable not to cross the +CPU/GPU memory boundary if possible, as this can take up a significant amount of +time. Usually, such crossing is inevitable, but in some special cases, one or +the other can be omitted. + +If the network's input is an image already loaded in the GPU memory, e.g. a GPU +texture containing the camera feed, it can stay in the GPU memory without ever +entering the CPU memory. Similarly, if the network's output is in the form of a +renderable image, e.g. +[image style transfer](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf), +it can be directly displayed on the screen. + +To let users achieve best performance, TFLite makes it possible for them to +directly read from/write to the delegate's hardware buffer and bypass avoidable +memory copies. + +Assuming the camera input is in the GPU memory as `GL_TEXTURE_2D`, it must be +first converted to a shader storage buffer object (SSBO) for OpenGL or to a +`MTLBuffer` object for Metal. One can associate a TfLiteTensor with a +user-prepared SSBO or `MTLBuffer` with `TfLiteGpuDelegateBindBufferToTensor()` +or `TfLiteMetalDelegateBindBufferToTensor()`, respectively. + +*IMPORTANT:* These must be called before +`Interpreter::ModifyGraphWithDelegate()`. + +*IMPORTANT:* By default, the inference output is copied from GPU memory to CPU +memory implicitly by the framework. This behavior can be turned off by calling +`Interpreter::SetAllowBufferHandleOutput(true)` during initialization. To copy +the inference output from GPU memory to CPU memory, explicit +`Interpreter::EnsureTensorDataIsReadable()` calls are required for each output +tensor. + +```c++ +//////// +// Prepare GPU delegate. 
+auto* delegate = TfLiteGpuDelegateCreate(nullptr); +interpreter->SetAllowBufferHandleOutput(true); // disable default gpu->cpu copy +#if defined(__ANDROID__) +if (TfLiteGpuDelegateBindBufferToTensor(delegate, user_provided_input_buffer, interpreter->inputs()[0]) != kTfLiteOk) return; +if (TfLiteGpuDelegateBindBufferToTensor(delegate, user_provided_output_buffer, interpreter->outputs()[0]) != kTfLiteOk) return; +#elif defined(__APPLE__) +if (TfLiteMetalDelegateBindBufferToTensor(delegate, user_provided_input_buffer, interpreter->inputs()[0]) != kTfLiteOk) return; +if (TfLiteMetalDelegateBindBufferToTensor(delegate, user_provided_output_buffer, interpreter->outputs()[0]) != kTfLiteOk) return; +#endif +if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return; + +//////// +// Run inference. +if (interpreter->Invoke() != kTfLiteOk) return; +``` + +## Tips and Tricks + +* Some operations that are trivial on CPU side may be high cost in GPU land. + One class of such operation is various forms of reshape operations (including + `BATCH_TO_SPACE`, `SPACE_TO_BATCH`, `SPACE_TO_DEPTH`, etc.). If those ops + are inserted into the network just for the network architect's logical + thinking, it is worth removing them for performance. + +* On GPU, tensor data is sliced into 4-channels. Thus, a computation on a + tensor of shape `[B, H, W, 5]` will perform about the same on a tensor of + shape `[B, H, W, 8]`, but significantly worse than `[B, H, W, 4]`. + +* In that sense, if the camera hardware supports image frames in RGBA, feeding + that 4-channel input is significantly faster as a memory copy (from 3-channel + RGB to 4-channel RGBX) can be avoided. + +* For performance [best practices](https://www.tensorflow.org/lite/performance/best_practices), do not hesitate to re-train your classifier with + mobile-optimized network architecture. That is a significant part of + optimization for on-device inference. 
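As a follow-up to the "Advanced Usage: Delegate Options" section above, a sketch of passing explicit options rather than `nullptr`; the field names and values mirror the quoted defaults, with only `precision_loss_allowed` changed to enable FP16 execution:

```c++
// Options sketch based on the defaults shown in "Advanced Usage: Delegate
// Options"; only precision_loss_allowed differs from its default.
TfLiteGpuDelegateOptions options = {
    .metadata = nullptr,
    .compile_options = {
        .precision_loss_allowed = 1,  // FP16 execution.
        .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,
        .dynamic_batch_enabled = 0,
    },
};
auto* delegate = TfLiteGpuDelegateCreate(&options);
if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return;
```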
diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD new file mode 100644 index 00000000000..c71ec763012 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -0,0 +1,181 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "convert", + srcs = ["convert.cc"], + hdrs = ["convert.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@FP16", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "data_type", + srcs = ["data_type.cc"], + hdrs = ["data_type.h"], +) + +cc_library( + name = "memory_management", + srcs = ["memory_management.cc"], + hdrs = ["memory_management.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "model", + hdrs = ["model.h"], + deps = [ + ":data_type", + ":shape", + ":status", + ":tensor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:any", + ], +) + +cc_test( + name = "model_test", + srcs = ["model_test.cc"], + deps = [ + ":model", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "model_builder", + srcs = ["model_builder.cc"], + hdrs = ["model_builder.h"], + deps = [ + ":data_type", + ":model", + ":operations", + ":shape", + ":status", + ":tensor", + "//tensorflow/lite:context", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "model_builder_test", + srcs = ["model_builder_test.cc"], + deps = [ + ":model_builder", + "//tensorflow/lite/c:c_api_internal", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "model_transformer", + srcs = ["model_transformer.cc"], + hdrs = ["model_transformer.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "@com_google_absl//absl/strings", + ], +) + +# TODO(impjdi): Add unit test for model_transformer. + +cc_library( + name = "operations", + srcs = ["operations.cc"], + hdrs = ["operations.h"], + deps = [ + ":data_type", + ":model", + ":shape", + ":status", + "@com_google_absl//absl/types:variant", + ], +) + +# TODO(impjdi): Add unit test for operations. 
+ +cc_library( + name = "shape", + srcs = ["shape.cc"], + hdrs = ["shape.h"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "shape_test", + srcs = ["shape_test.cc"], + deps = [ + ":shape", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "status", + hdrs = ["status.h"], +) + +cc_library( + name = "tensor", + hdrs = ["tensor.h"], + deps = [ + ":data_type", + ":shape", + ], +) + +cc_library( + name = "types", + hdrs = ["types.h"], + deps = [ + "@FP16", + ], +) + +cc_library( + name = "util", + hdrs = ["util.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:types", + ], +) + +cc_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "memory_management_test", + srcs = ["memory_management_test.cc"], + deps = [ + ":memory_management", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/delegates/gpu/common/convert.cc b/tensorflow/lite/delegates/gpu/common/convert.cc new file mode 100644 index 00000000000..53db297571d --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/convert.cc @@ -0,0 +1,506 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/convert.h" + +#include +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace { + +constexpr int kPhwc4ChannelsInPlane = 4; +constexpr int kPhwo4i4ChannelsInPlane = 4; +constexpr int kPiohw4ChannelsInPlane = 4; + +} // namespace + +uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape) { + return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) * + AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w; +} + +uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape) { + return AlignByN(shape.i, kPhwo4i4ChannelsInPlane) * + AlignByN(shape.o, kPhwo4i4ChannelsInPlane) * shape.h * shape.w; +} + +// Layout is Po,H,W,OI4x4. 
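+// Example (illustrative shape, not taken from the code): converting an OHWI
+// kernel of shape {o = 6, h = 3, w = 3, i = 5} pads both o and i up to
+// multiples of 4, so the PHWO4I4 output holds 8 * 8 * 3 * 3 = 576 floats;
+// entries outside the original o/i range are zero-filled.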
+Status ConvertToPHWO4I4(absl::Span in, const OHWI& shape, + absl::Span out) { + if (in.size() != shape.DimensionsProduct()) { + return InvalidArgumentError(absl::StrCat( + "ConvertToPHWO4I4: Input data size does not match expected size: ", + in.size(), " != ", shape.DimensionsProduct())); + } + if (out.size() != GetElementsSizeForPHWO4I4(shape)) { + return InvalidArgumentError(absl::StrCat( + "ConvertToPHWO4I4: Output data size does not match expected size: ", + out.size(), " != ", GetElementsSizeForPHWO4I4(shape))); + } + + float* output = out.data(); + for (int p = 0; p < IntegralDivideRoundUp(shape.o, kPhwo4i4ChannelsInPlane); + ++p) { + for (int h = 0; h < shape.h; ++h) { + for (int w = 0; w < shape.w; ++w) { + for (int c = 0; + c < IntegralDivideRoundUp(shape.i, kPhwo4i4ChannelsInPlane); ++c) { + for (int co = 0; co < kPhwo4i4ChannelsInPlane; ++co) { + for (int ci = 0; ci < kPhwo4i4ChannelsInPlane; ++ci) { + float value = 0; + if (c * kPhwo4i4ChannelsInPlane + ci < shape.i && + p * kPhwo4i4ChannelsInPlane + co < shape.o) { + // tensor is in OHWI + int tensor_o = p * kPhwo4i4ChannelsInPlane + co; + int tensor_i = c * kPhwo4i4ChannelsInPlane + ci; + value = in[shape.LinearIndex({tensor_o, h, w, tensor_i})]; + } + (*output++) = value; + } + } + } + } + } + } + return OkStatus(); +} + +std::vector ConvertToPHWO4I4( + const Tensor& tensor) { + std::vector transposed(GetElementsSizeForPHWO4I4(tensor.shape)); + ConvertToPHWO4I4(tensor.data, tensor.shape, + absl::MakeSpan(transposed.data(), transposed.size())) + .IgnoreError(); + return transposed; +} + +uint3 Get3DSizeForPHWO4I4(const OHWI& shape) { + return uint3(AlignByN(shape.i, 4), shape.h * shape.w, + IntegralDivideRoundUp(shape.o, 4)); +} + +// Layout is Po,H,W,OI4x4. +Status ConvertToPHWO4I4(absl::Span in, const IHWO& shape, + absl::Span out) { + if (in.size() != shape.DimensionsProduct()) { + return InvalidArgumentError(absl::StrCat( + "ConvertToPHWO4I4: Input data size does not match expected size: ", + in.size(), " != ", shape.DimensionsProduct())); + } + if (out.size() != GetElementsSizeForPHWO4I4(shape)) { + return InvalidArgumentError(absl::StrCat( + "ConvertToPHWO4I4: Output data size does not match expected size: ", + out.size(), " != ", GetElementsSizeForPHWO4I4(shape))); + } + + const int dst_depth = IntegralDivideRoundUp(shape.o, 4); + const int src_depth = IntegralDivideRoundUp(shape.i, 4); + + float* output = out.data(); + for (int f = 0; f < dst_depth; ++f) { + for (int y = 0; y < shape.h; ++y) { + for (int x = 0; x < shape.w; ++x) { + for (int ch = 0; ch < src_depth; ++ch) { + for (int co = 0; co < 4; ++co) { + for (int ci = 0; ci < 4; ++ci) { + const int src_channel = ch * 4 + ci; + const int dst_channel = f * 4 + co; + float value = 0; + if (src_channel < shape.i && dst_channel < shape.o) { + // tensor is in IHWO + value = in[shape.LinearIndex({src_channel, y, x, dst_channel})]; + } + (*output++) = value; + } + } + } + } + } + } + return OkStatus(); +} + +std::vector ConvertToPHWO4I4( + const Tensor& tensor) { + std::vector transposed(GetElementsSizeForPHWO4I4(tensor.shape)); + ConvertToPHWO4I4(tensor.data, tensor.shape, + absl::MakeSpan(transposed.data(), transposed.size())) + .IgnoreError(); + return transposed; +} + +uint32_t GetElementsSizeForPIOHW4(const OHWI& shape) { + return AlignByN(shape.o * shape.i, kPiohw4ChannelsInPlane) * shape.h * + shape.w; +} + +Status ConvertToPIOHW4(absl::Span in, const OHWI& shape, + absl::Span out) { + if (in.size() != shape.DimensionsProduct()) { + return 
InvalidArgumentError(absl::StrCat( + "ConvertToPIOHW4: Input data size does not match expected size: ", + in.size(), " != ", shape.DimensionsProduct())); + } + if (out.size() != GetElementsSizeForPIOHW4(shape)) { + return InvalidArgumentError(absl::StrCat( + "ConvertToPIOHW4: Output data size does not match expected size: ", + out.size(), " != ", GetElementsSizeForPIOHW4(shape))); + } + + int32_t output_channels = shape.o * shape.i; + int32_t num_planes = + IntegralDivideRoundUp(output_channels, kPiohw4ChannelsInPlane); + float* output = out.data(); + for (int p = 0; p < num_planes; ++p) { + for (int h = 0; h < shape.h; ++h) { + for (int w = 0; w < shape.w; ++w) { + for (int c = 0; c < kPiohw4ChannelsInPlane; ++c) { + int output_c = p * kPiohw4ChannelsInPlane + c; + (*output++) = output_c >= output_channels + ? 0 + : in[shape.LinearIndex({output_c % shape.o, h, w, + output_c / shape.o})]; + } + } + } + } + return OkStatus(); +} + +std::vector ConvertToPIOHW4( + const Tensor& tensor) { + std::vector transposed(GetElementsSizeForPIOHW4(tensor.shape)); + ConvertToPIOHW4(tensor.data, tensor.shape, + absl::MakeSpan(transposed.data(), transposed.size())) + .IgnoreError(); + return transposed; +} + +template +Status ValidateConvertToPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { + if (in.size() != shape.DimensionsProduct()) { + return InvalidArgumentError(absl::StrCat( + "ConvertToPHWC4: Input data size does not match expected size: ", + in.size(), " != ", shape.DimensionsProduct())); + } + if (out.size() != GetElementsSizeForPHWC4(shape)) { + return InvalidArgumentError(absl::StrCat( + "ConvertToPHWC4: Output data size does not match expected size: ", + out.size(), " != ", GetElementsSizeForPHWC4(shape))); + } + return OkStatus(); +} + +// Layout is Pc,H,W,C4 where P - is a plane based on channels. +Status ConvertToPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { + RETURN_IF_ERROR(ValidateConvertToPHWC4(in, shape, out)); + if (shape.c == 4) { + std::memcpy(out.data(), in.data(), + shape.DimensionsProduct() * sizeof(float)); + return OkStatus(); + } + // Layout is Pc,H,W,C4 where P - is a plane based on channels. + int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); + const int num_pixels = shape.h * shape.w; + // A layer is a set of kPhwc4ChannelsInPlane channels images. + const int num_full_planes = shape.c / kPhwc4ChannelsInPlane; + for (int b = 0; b < shape.b; b++) { + float* dest = + out.data() + b * num_pixels * num_planes * kPhwc4ChannelsInPlane; + for (int p = 0; p < num_full_planes; p++) { + const float* src = + in.data() + shape.LinearIndex({b, 0, 0, p * kPhwc4ChannelsInPlane}); + for (int i = 0; i < num_pixels; i++) { + std::memcpy(dest, src, kPhwc4ChannelsInPlane * sizeof(float)); + src += shape.c; + dest += kPhwc4ChannelsInPlane; + } + } + } + + // Padding last kPhwc4ChannelsInPlane-channel layer to multiple of + // kPhwc4ChannelsInPlane. 
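+  // Example (illustrative): with shape.c = 6, num_full_planes is 1 and
+  // remaining_channels is 2, so each pixel's last plane stores channels 4..5
+  // followed by two zero floats.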
+ const int padded_size = num_pixels * num_planes * kPhwc4ChannelsInPlane; + const int remaining_channels = + shape.c - num_full_planes * kPhwc4ChannelsInPlane; + if (remaining_channels == 0) { + return OkStatus(); + } + for (int b = 0; b < shape.b; b++) { + const float* src = + in.data() + + shape.LinearIndex({b, 0, 0, num_full_planes * kPhwc4ChannelsInPlane}); + float* dest = out.data() + b * padded_size + + num_pixels * num_full_planes * kPhwc4ChannelsInPlane; + for (int p = 0; p < num_pixels; p++) { + std::memcpy(dest, src, remaining_channels * sizeof(float)); + std::memset(dest + remaining_channels, 0, + (4 - remaining_channels) * sizeof(float)); + src += shape.c; + dest += kPhwc4ChannelsInPlane; + } + } + return OkStatus(); +} + +// Layout is Pc,H,W,C4 where P - is a plane based on channels. +Status ConvertToPHWC4Half(absl::Span in, const BHWC& shape, + absl::Span out) { + RETURN_IF_ERROR(ValidateConvertToPHWC4(in, shape, out)); + + // Layout is Pc,H,W,C4 where P - is a plane based on channels. + int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); + const int num_pixels = shape.h * shape.w; + // A layer is a set of kPhwc4ChannelsInPlane channels images. + const int num_full_planes = shape.c / kPhwc4ChannelsInPlane; + for (int b = 0; b < shape.b; b++) { + HalfBits* dest = + out.data() + b * num_pixels * num_planes * kPhwc4ChannelsInPlane; + for (int p = 0; p < num_full_planes; p++) { + const float* src = + in.data() + shape.LinearIndex({b, 0, 0, p * kPhwc4ChannelsInPlane}); + for (int i = 0; i < num_pixels; i++) { + dest[0] = fp16_ieee_from_fp32_value(src[0]); + dest[1] = fp16_ieee_from_fp32_value(src[1]); + dest[2] = fp16_ieee_from_fp32_value(src[2]); + dest[3] = fp16_ieee_from_fp32_value(src[3]); + src += shape.c; + dest += kPhwc4ChannelsInPlane; + } + } + } + + // Padding last kPhwc4ChannelsInPlane-channel layer to multiple of + // kPhwc4ChannelsInPlane. 
+ const int padded_size = num_pixels * num_planes * kPhwc4ChannelsInPlane; + const int remaining_channels = + shape.c - num_full_planes * kPhwc4ChannelsInPlane; + if (remaining_channels == 0) { + return OkStatus(); + } + + for (int b = 0; b < shape.b; b++) { + const float* src = + in.data() + + shape.LinearIndex({b, 0, 0, num_full_planes * kPhwc4ChannelsInPlane}); + HalfBits* dest = out.data() + b * padded_size + + num_pixels * num_full_planes * kPhwc4ChannelsInPlane; + switch (remaining_channels) { + case 1: + for (int p = 0; p < num_pixels; p++) { + dest[0] = fp16_ieee_from_fp32_value(src[0]); + dest[1] = 0; + dest[2] = 0; + dest[3] = 0; + src += shape.c; + dest += kPhwc4ChannelsInPlane; + } + break; + case 2: + for (int p = 0; p < num_pixels; p++) { + dest[0] = fp16_ieee_from_fp32_value(src[0]); + dest[1] = fp16_ieee_from_fp32_value(src[1]); + dest[2] = 0; + dest[3] = 0; + src += shape.c; + dest += kPhwc4ChannelsInPlane; + } + break; + case 3: + for (int p = 0; p < num_pixels; p++) { + dest[0] = fp16_ieee_from_fp32_value(src[0]); + dest[1] = fp16_ieee_from_fp32_value(src[1]); + dest[2] = fp16_ieee_from_fp32_value(src[2]); + dest[3] = 0; + src += shape.c; + dest += kPhwc4ChannelsInPlane; + } + break; + default: + return UnimplementedError( + "ConvertToPHWC4Half: Unsupported channels per planes count."); + } + } + return OkStatus(); +} + +std::vector ConvertToPHWC4( + const Tensor& tensor) { + std::vector transposed(GetElementsSizeForPHWC4(tensor.shape)); + ConvertToPHWC4(tensor.data, tensor.shape, + absl::MakeSpan(transposed.data(), transposed.size())) + .IgnoreError(); + // TODO(akulik): Maybe safer to return Status. + return transposed; +} + +std::vector ConvertToPHWC4( + const Tensor& tensor) { + const BHWC batched_shape = + BHWC(1, tensor.shape.h, tensor.shape.w, tensor.shape.c); + std::vector transposed(GetElementsSizeForPHWC4(batched_shape)); + ConvertToPHWC4(tensor.data, batched_shape, + absl::MakeSpan(transposed.data(), transposed.size())) + .IgnoreError(); + // TODO(akulik): Maybe safer to return Status. + return transposed; +} + +uint32_t GetElementsSizeForPHWC4(const BHWC& shape) { + return shape.b * shape.h * shape.w * AlignByN(shape.c, kPhwc4ChannelsInPlane); +} + +template +Status ValidateConvertFromPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { + if (in.size() != GetElementsSizeForPHWC4(shape)) { + return InvalidArgumentError(absl::StrCat( + "ConvertFromPHWC4: Input data size does not match expected size: ", + in.size(), " != ", GetElementsSizeForPHWC4(shape))); + } + if (out.size() != shape.DimensionsProduct()) { + return InvalidArgumentError(absl::StrCat( + "ConvertFromPHWC4: Output data size does not match expected size: ", + out.size(), " != ", shape.DimensionsProduct())); + } + return OkStatus(); +} + +Status ConvertFromPHWC4(absl::Span in, const BHWC& shape, + absl::Span out) { + RETURN_IF_ERROR(ValidateConvertFromPHWC4(in, shape, out)); + if (shape.c == 4) { + std::memcpy(out.data(), in.data(), + shape.DimensionsProduct() * sizeof(float)); + return OkStatus(); + } + + int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); + const int num_pixels = shape.h * shape.w; + const int padded_size = num_pixels * num_planes * kPhwc4ChannelsInPlane; + // A layer is a set of kPhwc4ChannelsInPlane channels images. 
+ const int num_full_planes = shape.c / kPhwc4ChannelsInPlane; + for (int b = 0; b < shape.b; b++) { + const float* src = in.data() + b * padded_size; + for (int p = 0; p < num_full_planes; p++) { + float* dest = + out.data() + shape.LinearIndex({b, 0, 0, p * kPhwc4ChannelsInPlane}); + for (int i = 0; i < num_pixels; i++) { + std::memcpy(dest, src, kPhwc4ChannelsInPlane * sizeof(float)); + src += kPhwc4ChannelsInPlane; + dest += shape.c; + } + } + } + + // Unpadding last kPhwc4ChannelsInPlane-channel plane + const int remaining_channels = + shape.c - num_full_planes * kPhwc4ChannelsInPlane; + if (remaining_channels == 0) { + return OkStatus(); + } + for (int b = 0; b < shape.b; b++) { + const float* src = in.data() + b * padded_size + + num_pixels * num_full_planes * kPhwc4ChannelsInPlane; + float* dest = + out.data() + + shape.LinearIndex({b, 0, 0, num_full_planes * kPhwc4ChannelsInPlane}); + for (int p = 0; p < num_pixels; p++) { + std::memcpy(dest, src, remaining_channels * sizeof(float)); + src += kPhwc4ChannelsInPlane; + dest += shape.c; + } + } + return OkStatus(); +} + +Status ConvertFromPHWC4Half(absl::Span in, const BHWC& shape, + absl::Span out) { + RETURN_IF_ERROR(ValidateConvertFromPHWC4(in, shape, out)); + int num_planes = IntegralDivideRoundUp(shape.c, kPhwc4ChannelsInPlane); + const int num_pixels = shape.h * shape.w; + const int padded_size = num_pixels * num_planes * kPhwc4ChannelsInPlane; + // A layer is a set of kPhwc4ChannelsInPlane channels images. + const int num_full_planes = shape.c / kPhwc4ChannelsInPlane; + for (int b = 0; b < shape.b; b++) { + const HalfBits* src = in.data() + b * padded_size; + for (int p = 0; p < num_full_planes; p++) { + float* dest = + out.data() + shape.LinearIndex({b, 0, 0, p * kPhwc4ChannelsInPlane}); + for (int i = 0; i < num_pixels; i++) { + dest[0] = fp16_ieee_to_fp32_value(src[0]); + dest[1] = fp16_ieee_to_fp32_value(src[1]); + dest[2] = fp16_ieee_to_fp32_value(src[2]); + dest[3] = fp16_ieee_to_fp32_value(src[3]); + src += kPhwc4ChannelsInPlane; + dest += shape.c; + } + } + } + + // Unpadding last kPhwc4ChannelsInPlane-channel plane + const int remaining_channels = + shape.c - num_full_planes * kPhwc4ChannelsInPlane; + if (remaining_channels == 0) { + return OkStatus(); + } + for (int b = 0; b < shape.b; b++) { + const HalfBits* src = in.data() + b * padded_size + + num_pixels * num_full_planes * kPhwc4ChannelsInPlane; + float* dest = + out.data() + + shape.LinearIndex({b, 0, 0, num_full_planes * kPhwc4ChannelsInPlane}); + switch (remaining_channels) { + case 1: + for (int p = 0; p < num_pixels; p++) { + dest[0] = fp16_ieee_to_fp32_value(src[0]); + src += kPhwc4ChannelsInPlane; + dest += shape.c; + } + break; + case 2: + for (int p = 0; p < num_pixels; p++) { + dest[0] = fp16_ieee_to_fp32_value(src[0]); + dest[1] = fp16_ieee_to_fp32_value(src[1]); + src += kPhwc4ChannelsInPlane; + dest += shape.c; + } + break; + case 3: + for (int p = 0; p < num_pixels; p++) { + dest[0] = fp16_ieee_to_fp32_value(src[0]); + dest[1] = fp16_ieee_to_fp32_value(src[1]); + dest[2] = fp16_ieee_to_fp32_value(src[2]); + src += kPhwc4ChannelsInPlane; + dest += shape.c; + } + break; + default: + return UnimplementedError( + "ConvertToPHWC4Half: Unsupported channels per planes count."); + } + } + return OkStatus(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/convert.h b/tensorflow/lite/delegates/gpu/common/convert.h new file mode 100644 index 00000000000..1907bc83675 --- /dev/null +++ 
b/tensorflow/lite/delegates/gpu/common/convert.h @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { + +// PHWC4 layout is where channels are grouped by 4 in a row and P stands for +// a plane that was derived by dividing channels by 4. +::tflite::gpu::Status ConvertToPHWC4(absl::Span in, + const BHWC& shape, absl::Span out); +::tflite::gpu::Status ConvertToPHWC4Half( + absl::Span in, const BHWC& shape, + absl::Span<::tflite::gpu::HalfBits> out); + +// @return number of elements when shape is converted into PHWC4. +uint32_t GetElementsSizeForPHWC4(const BHWC& shape); + +// Operation is opposite to ConvertToPHWC4. +::tflite::gpu::Status ConvertFromPHWC4(absl::Span in, + const BHWC& shape, + absl::Span out); +::tflite::gpu::Status ConvertFromPHWC4Half( + absl::Span in, const BHWC& shape, + absl::Span out); + +// Convenience wrapper around a method above. +std::vector ConvertToPHWC4( + const Tensor& tensor); +std::vector ConvertToPHWC4(const Tensor& tensor); + +// @return number of elements when shape is converted into PIOHW4. +uint32_t GetElementsSizeForPIOHW4(const OHWI& shape); + +// PIOHW4 layout re-arranges weights in groups by 4, where outer dimension is +// P which is OxI/4. +::tflite::gpu::Status ConvertToPIOHW4(absl::Span in, + const OHWI& shape, absl::Span out); + +// Convenience wrapper around a method above. +std::vector ConvertToPIOHW4( + const Tensor& tensor); + +// @return number of elements when shape is converted into PHWO4I4. +uint32_t GetElementsSizeForPHWO4I4(const OHWI& shape); + +// Layout is Po,H,W,OI4x4. +::tflite::gpu::Status ConvertToPHWO4I4(absl::Span in, + const OHWI& shape, + absl::Span out); + +// Convenience wrapper around a method above. +std::vector ConvertToPHWO4I4( + const Tensor& tensor); + +// @return (x,y,z) size for PHWO4I4 to access elements where each element +// consists of 4 values. +::tflite::gpu::uint3 Get3DSizeForPHWO4I4(const OHWI& shape); + +// @return number of elements when shape is converted into PHWO4I4. +uint32_t GetElementsSizeForPHWO4I4(const IHWO& shape); + +// Layout is Po,H,W,OI4x4. +::tflite::gpu::Status ConvertToPHWO4I4(absl::Span in, + const IHWO& shape, + absl::Span out); + +// Convenience wrapper around a method above. 
+std::vector ConvertToPHWO4I4( + const Tensor& tensor); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_ diff --git a/tensorflow/lite/delegates/gpu/common/data_type.cc b/tensorflow/lite/delegates/gpu/common/data_type.cc new file mode 100644 index 00000000000..b157a4ce338 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/data_type.cc @@ -0,0 +1,78 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/data_type.h" + +#include +#include + +namespace tflite { +namespace gpu { + +size_t SizeOf(DataType data_type) { + switch (data_type) { + case DataType::UINT8: + case DataType::INT8: + return 1; + case DataType::FLOAT16: + case DataType::INT16: + case DataType::UINT16: + return 2; + case DataType::FLOAT32: + case DataType::INT32: + case DataType::UINT32: + return 4; + case DataType::FLOAT64: + case DataType::INT64: + case DataType::UINT64: + return 8; + case DataType::UNKNOWN: + return 0; + } + return 0; +} + +std::string ToString(DataType data_type) { + switch (data_type) { + case DataType::FLOAT16: + return "float16"; + case DataType::FLOAT32: + return "float32"; + case DataType::FLOAT64: + return "float64"; + case DataType::INT16: + return "int16"; + case DataType::INT32: + return "int32"; + case DataType::INT64: + return "int64"; + case DataType::INT8: + return "int8"; + case DataType::UINT16: + return "uint16"; + case DataType::UINT32: + return "uint32"; + case DataType::UINT64: + return "uint64"; + case DataType::UINT8: + return "uint8"; + case DataType::UNKNOWN: + return "unknown"; + } + return "undefined"; +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/data_type.h b/tensorflow/lite/delegates/gpu/common/data_type.h new file mode 100644 index 00000000000..e589820e913 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/data_type.h @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_DATA_TYPE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_DATA_TYPE_H_ + +#include +#include + +namespace tflite { +namespace gpu { + +enum class DataType { + UNKNOWN = 0, + FLOAT16 = 1, + FLOAT32 = 2, + FLOAT64 = 3, + UINT8 = 4, + INT8 = 5, + UINT16 = 6, + INT16 = 7, + UINT32 = 8, + INT32 = 9, + UINT64 = 10, + INT64 = 11, +}; + +size_t SizeOf(DataType type); + +std::string ToString(DataType t); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_DATA_TYPE_H_ diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.cc b/tensorflow/lite/delegates/gpu/common/memory_management.cc new file mode 100644 index 00000000000..73a27a3c4ea --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/memory_management.cc @@ -0,0 +1,366 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/memory_management.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace { + +struct PoolRecord { + PoolRecord(uint32_t size, size_t obj_id) + : object_size(size), object_id(obj_id) {} + + // Objects in pool are ordered by size. + bool operator<(const PoolRecord& other) const { + return (object_size < other.object_size) || + (object_size == other.object_size && object_id < other.object_id); + } + + uint32_t object_size; + size_t object_id; +}; + +struct QueueRecord { + QueueRecord(TaskId task_id, size_t obj_id) + : last_task(task_id), object_id(obj_id) {} + + // Objects in queue are ordered by last_task. + bool operator<(const QueueRecord& other) const { + return (last_task > other.last_task) || + (last_task == other.last_task && object_id > other.object_id); + } + + // Last task, where shared object is used. + TaskId last_task; + size_t object_id; +}; + +// Implements memory management with a naive algorithm. +// +// The problem of memory management is NP-complete. This implements a +// naive algorithm that assigns each tensor to a separate object in memory. +Status NaiveAssignment(const std::vector& usage_records, + ObjectsAssignment* assignment) { + assignment->object_sizes.resize(usage_records.size()); + assignment->object_ids.resize(usage_records.size()); + for (size_t i = 0; i < usage_records.size(); i++) { + auto& record = usage_records[i]; + assignment->object_ids[i] = i; + assignment->object_sizes[i] = record.tensor_size; + } + return OkStatus(); +} + +// Implements memory management with a greedy algorithm. +// +// The problem of memory management is NP-complete. This implements a +// greedy algorithm that approximates an optimal solution with following +// heuristic: +// +// 1. 
Iterates through all tensor usage records and for every object reference +// assigns shared object from the pool. When object reference is used +// for the last time, corresponding shared object is returned back to +// the pool. +// +// 2. Shared object pool grows when there are no free shared object +// available. +// +// 3. Shared object size may increase when tensor requests larger size. +Status GreedyAssignment(const std::vector& usage_records, + ObjectsAssignment* assignment) { + assignment->object_sizes.clear(); + assignment->object_ids.resize(usage_records.size()); + + // Pool of free shared objects is ordered by object size, because we perform + // lower_bound search in it. + std::set pool; + // Queue of shared objects in use, ordered by their last_task. + std::priority_queue objects_in_use; + for (size_t i = 0; i < usage_records.size(); i++) { + // Pop from the queue and add to the pool all objects that are no longer + // in use at the time of execution of the first_task of i-th intermediate + // tensor. + while (!objects_in_use.empty() && + objects_in_use.top().last_task < usage_records[i].first_task) { + auto object_id = objects_in_use.top().object_id; + pool.insert({assignment->object_sizes[object_id], object_id}); + objects_in_use.pop(); + } + uint32_t tensor_size = usage_records[i].tensor_size; + if (pool.empty()) { + // No free shared object, creating a new one, assign i-th tensor to + // it and add to the queue of objects in use. + assignment->object_ids[i] = assignment->object_sizes.size(); + assignment->object_sizes.push_back(tensor_size); + objects_in_use.push( + {usage_records[i].last_task, assignment->object_ids[i]}); + } else { + auto best_it = pool.end(); + // Find shared object from pool, that will waste the least possible + // amount of memory when reused for current tensor. + auto pool_it = pool.lower_bound({tensor_size, 0}); + uint32_t size_diff = 0; + if (pool_it != pool.end()) { + // Try smallest shared object from pool with size >= tensor_size. + size_diff = pool_it->object_size - tensor_size; + best_it = pool_it; + } + if (pool_it != pool.begin()) { + // Try largest shared object from pool with size < tensor_size. + pool_it--; + if (best_it == pool.end() || + tensor_size - pool_it->object_size < size_diff) { + size_diff = tensor_size - pool_it->object_size; + best_it = pool_it; + } + } + // best_it can't be equal to pool.end(), because pool is not empty + if (best_it == pool.end()) { + return InternalError( + "No shared object is found in non-empty pool in GreedyAssignment."); + } + size_t shared_id = best_it->object_id; + pool.erase(best_it); + assignment->object_ids[i] = shared_id; + assignment->object_sizes[shared_id] = + std::max(assignment->object_sizes[shared_id], tensor_size); + objects_in_use.push( + {usage_records[i].last_task, assignment->object_ids[i]}); + } + } + return OkStatus(); +} + +// This class build flow graph and solves Minimum-cost flow problem in it. +class MinCostFlowSolver { + public: + // Build auxiliary flow graph, based on information about intermediate + // tensors. + void Build(const std::vector& usage_records) { + usage_records_ = &usage_records; + num_tensors_ = usage_records.size(); + source_ = 2 * num_tensors_; + sink_ = source_ + 1; + edges_from_.resize(sink_ + 1); + std::vector old_record_ids; + std::priority_queue objects_in_use; + for (size_t i = 0; i < usage_records.size(); i++) { + // Pop from the queue all objects that are no longer in use at the time of + // execution of the first_task of i-th intermediate tensor. 
+ while (!objects_in_use.empty() && + objects_in_use.top().last_task < usage_records[i].first_task) { + old_record_ids.push_back(objects_in_use.top().object_id); + objects_in_use.pop(); + } + objects_in_use.push({usage_records[i].last_task, i}); + AddEdge(source_, i, 1, 0); + AddEdge(RightPartTwin(i), sink_, 1, 0); + + // Edge from source_ to i-th vertex in the right part of flow graph + // are added for the case of allocation of new shared object for i-th + // tensor. Cost of these edges is equal to the size of i-th tensor. + AddEdge(source_, RightPartTwin(i), 1, usage_records[i].tensor_size); + + // Edges from vertices of the left part of flow graph, corresponding to + // old_record_ids, to i-th vertex in the right part of flow graph are + // added for the case of reusing previously created shared objects for + // i-th tensor. Cost of these edges is an approximation of the size of new + // allocated memory. + for (auto record_id : old_record_ids) { + int cost = 0; + if (usage_records[i].tensor_size > + usage_records[record_id].tensor_size) { + cost = usage_records[i].tensor_size - + usage_records[record_id].tensor_size; + } + AddEdge(record_id, RightPartTwin(i), 1, cost); + } + } + } + + // Solve Minimum-cost flow problem with Shortest Path Faster Algorithm. + void Solve() { + const int kInf = std::numeric_limits::max(); + std::vector prev_edge(sink_ + 1); + while (true) { + std::queue cur_queue, next_queue; + std::vector last_it_in_queue(sink_ + 1); + std::vector dist(sink_ + 1, kInf); + size_t it = 1; + cur_queue.push(source_); + last_it_in_queue[source_] = it; + dist[source_] = 0; + // Find shortest path from source_ to sink_, using only edges with + // positive capacity. + while (!cur_queue.empty()) { + ++it; + while (!cur_queue.empty()) { + auto v = cur_queue.front(); + cur_queue.pop(); + for (const auto& edge_id : edges_from_[v]) { + const Edge& edge = edges_[edge_id]; + if (edge.cap > 0) { + auto u = edge.dst; + int new_dist = dist[v] + edge.cost; + if (new_dist < dist[u]) { + dist[u] = new_dist; + prev_edge[u] = edge_id; + if (last_it_in_queue[u] != it) { + next_queue.push(u); + last_it_in_queue[u] = it; + } + } + } + } + } + std::swap(cur_queue, next_queue); + } + // If path is not found, final result is ready. + if (dist[sink_] == kInf) break; + + // If path is found, we need to decrease the capacity of its edges, and + // increase the capacity of its reversed edges. + for (size_t v = sink_; v != source_;) { + --edges_[prev_edge[v]].cap; + Edge& rev_edge = edges_[prev_edge[v] ^ 1]; + ++rev_edge.cap; + v = rev_edge.dst; + } + } + } + + void CalculateAssignment(ObjectsAssignment* assignment) { + assignment->object_sizes.clear(); + assignment->object_ids.resize(num_tensors_); + is_tensor_assigned_.resize(num_tensors_); + for (const auto& edge_id : edges_from_[source_]) { + const Edge& edge = edges_[edge_id]; + if (edge.cap == 0 && IsRightPartVertex(edge.dst)) { + assignment->object_sizes.push_back( + AssignTensorsToNewSharedObject(LeftPartTwin(edge.dst), assignment)); + } + } + } + + private: + struct Edge { + Edge(size_t dst, int cap, int cost) : dst(dst), cap(cap), cost(cost) {} + + size_t dst; + int cap; + int cost; + }; + + // Add edge from vertex src to vertex dst with given capacity and cost and its + // reversed edge to the flow graph. If some edge has index idx, its reversed + // edge has index idx^1. 
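// [Editor's aside, not part of this patch.] A short sketch of the pairing
// convention mentioned above: AddEdge below always appends the forward edge
// and then its zero-capacity reverse, so the k-th forward edge lands at index
// 2*k and its reverse at 2*k + 1. Flipping the lowest bit therefore maps any
// edge to its twin, which is what Solve() relies on when it walks a shortest
// augmenting path backwards:
//
//   const Edge& fwd = edges_[idx];      // capacity 1, cost >= 0
//   const Edge& rev = edges_[idx ^ 1];  // capacity 0, cost == -fwd.cost
//
// Build() wires source_ to every left vertex, left vertices of expired
// tensors to the right twins of later tensors (reuse), source_ directly to
// right twins (fresh allocation), and right twins to sink_, so each unit of
// flow chooses either "reuse an old shared object" or "allocate a new one"
// for a tensor.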
+ void AddEdge(size_t src, size_t dst, int cap, int cost) { + edges_from_[src].push_back(edges_.size()); + edges_.emplace_back(dst, cap, cost); + edges_from_[dst].push_back(edges_.size()); + edges_.push_back({src, 0, -cost}); + } + + // Check, if vertex_id belongs to right part of the flow graph. + bool IsRightPartVertex(size_t vertex_id) const { + return vertex_id >= num_tensors_ && vertex_id < 2 * num_tensors_; + } + + // Return vertex from another part of the graph, that corresponds to the same + // intermediate tensor. + size_t LeftPartTwin(size_t vertex_id) const { + return vertex_id - num_tensors_; + } + size_t RightPartTwin(size_t vertex_id) const { + return vertex_id + num_tensors_; + } + + // This function uses recursive implementation of depth-first search and + // returns maximum size from tensor tensor_id and all tensors, that will be + // allocated at the same place with it after all operations that use tensor_id + // are executed. Next tensor to be allocated at the same place with tensor_id + // is a left part twin of such vertex v, that the edge tensor_id->v is + // saturated (has zero residual capacity). + uint32_t AssignTensorsToNewSharedObject(size_t tensor_id, + ObjectsAssignment* assignment) { + uint32_t cost = (*usage_records_)[tensor_id].tensor_size; + is_tensor_assigned_[tensor_id] = true; + assignment->object_ids[tensor_id] = assignment->object_sizes.size(); + for (const auto& edge_id : edges_from_[tensor_id]) { + const Edge& edge = edges_[edge_id]; + size_t v = edge.dst; + size_t left_twin = LeftPartTwin(v); + if (edge.cap == 0 && IsRightPartVertex(v) && + !is_tensor_assigned_[left_twin]) { + cost = std::max(cost, + AssignTensorsToNewSharedObject(left_twin, assignment)); + } + } + return cost; + } + + size_t source_; + size_t sink_; + size_t num_tensors_; + const std::vector* usage_records_; + std::vector edges_; + std::vector> edges_from_; + std::vector is_tensor_assigned_; +}; + +// Implements memory management with a Minimum-cost flow matching algorithm. +// +// The problem of memory management is NP-complete. This function creates +// auxiliary flow graph, find minimum-cost flow in it and calculates the +// assignment of shared objects to tensors, using the result of the flow +// algorithm. +Status MinCostFlowAssignment( + const std::vector& usage_records, + ObjectsAssignment* assignment) { + MinCostFlowSolver solver; + solver.Build(usage_records); + solver.Solve(); + solver.CalculateAssignment(assignment); + return OkStatus(); +} + +} // namespace + +Status AssignObjectsToTensors( + const std::vector& usage_records, + const MemoryStrategy& strategy, ObjectsAssignment* assignment) { + switch (strategy) { + case MemoryStrategy::NAIVE: + return NaiveAssignment(usage_records, assignment); + case MemoryStrategy::GREEDY: + return GreedyAssignment(usage_records, assignment); + case MemoryStrategy::MINCOSTFLOW: + return MinCostFlowAssignment(usage_records, assignment); + } + return OkStatus(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.h b/tensorflow/lite/delegates/gpu/common/memory_management.h new file mode 100644 index 00000000000..4b8023b8d54 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/memory_management.h @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_H_ + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { + +using TaskId = size_t; + +// Record, containing tensor size and IDs of the first and the last task, that +// use this tensor as input or output. +// For example: tensor #3 with size tensor_size=65536 is first introduced in +// program #2 (first_task=2) and used for the last time in program #7 +// (last_task=7). +struct TensorUsageRecord { + uint32_t tensor_size; + TaskId first_task; + TaskId last_task; + + TensorUsageRecord(uint32_t size, TaskId first, TaskId last) + : tensor_size(size), first_task(first), last_task(last) {} + + // Default order of tensor usage records is increasing order of first_task. + bool operator<(const TensorUsageRecord& other) const { + return first_task < other.first_task; + } +}; + +// Information about assignment of tensors to shared objects +struct ObjectsAssignment { + // shared_object_ids_[i] is ID of shared object, that tensor i will be using. + std::vector object_ids; + // shared_object_sizes_[i] is a size of shared object with ID equal to i. + std::vector object_sizes; +}; + +enum class MemoryStrategy { + // Naive strategy is to allocate each object separately. + // Can be useful for debugging to see all intermediate outputs. + NAIVE, + + // Greedy strategy uses greedy algorithm to reuse memory from tensors, that + // won't be used anymore, for new ones. + GREEDY, + + // Mincostflow strategy consists of building auxiliary flow graph and solving + // the minimum-cost flow problem in it. In the end edges with zero residual + // capacity determine assignment of shared objects to tensors. + MINCOSTFLOW, +}; + +// Calculates the assignement of shared objects to given tensors, including +// objects' sizes. +Status AssignObjectsToTensors( + const std::vector& usage_records, + const MemoryStrategy& strategy, ObjectsAssignment* assignment); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_H_ diff --git a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc new file mode 100644 index 00000000000..a1484cd0e55 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/memory_management.h" + +#include +#include + +namespace tflite { +namespace gpu { +namespace { + +using ::testing::ElementsAre; + +TEST(Model, EmptyRecords) { + ObjectsAssignment assignment; + ASSERT_TRUE( + AssignObjectsToTensors({}, MemoryStrategy::NAIVE, &assignment).ok()); + EXPECT_TRUE(assignment.object_ids.empty()); + EXPECT_TRUE(assignment.object_sizes.empty()); + ASSERT_TRUE( + AssignObjectsToTensors({}, MemoryStrategy::GREEDY, &assignment).ok()); + EXPECT_TRUE(assignment.object_ids.empty()); + EXPECT_TRUE(assignment.object_sizes.empty()); + ASSERT_TRUE( + AssignObjectsToTensors({}, MemoryStrategy::MINCOSTFLOW, &assignment) + .ok()); + EXPECT_TRUE(assignment.object_ids.empty()); + EXPECT_TRUE(assignment.object_sizes.empty()); +} + +TEST(Model, OneRecord) { + std::vector usage_records{ + {/*size=*/16, /*first=*/0, /*last=*/1}}; + ObjectsAssignment assignment; + ASSERT_TRUE( + AssignObjectsToTensors(usage_records, MemoryStrategy::NAIVE, &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0)); + EXPECT_THAT(assignment.object_sizes, ElementsAre(16)); + ASSERT_TRUE( + AssignObjectsToTensors(usage_records, MemoryStrategy::GREEDY, &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0)); + EXPECT_THAT(assignment.object_sizes, ElementsAre(16)); + ASSERT_TRUE(AssignObjectsToTensors(usage_records, MemoryStrategy::MINCOSTFLOW, + &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0)); + EXPECT_THAT(assignment.object_sizes, ElementsAre(16)); +} + +TEST(Model, ChainRecords) { + std::vector usage_records{ + {/*size=*/16, /*first=*/0, /*last=*/1}, + {/*size=*/8, /*first=*/1, /*last=*/2}, + {/*size=*/64, /*first=*/2, /*last=*/3}, + {/*size=*/32, /*first=*/3, /*last=*/4}, + {/*size=*/8, /*first=*/4, /*last=*/5}, + }; + ObjectsAssignment assignment; + ASSERT_TRUE( + AssignObjectsToTensors(usage_records, MemoryStrategy::NAIVE, &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 2, 3, 4)); + EXPECT_THAT(assignment.object_sizes, ElementsAre(16, 8, 64, 32, 8)); + ASSERT_TRUE( + AssignObjectsToTensors(usage_records, MemoryStrategy::GREEDY, &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 0, 1, 0)); + EXPECT_THAT(assignment.object_sizes, ElementsAre(64, 32)); + ASSERT_TRUE(AssignObjectsToTensors(usage_records, MemoryStrategy::MINCOSTFLOW, + &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 0, 1, 0)); + EXPECT_THAT(assignment.object_sizes, ElementsAre(64, 32)); +} + +TEST(Model, ComplexRecords) { + std::vector usage_records{ + {/*size=*/32, /*first=*/0, /*last=*/1}, + {/*size=*/32, /*first=*/1, /*last=*/4}, + {/*size=*/8, /*first=*/2, /*last=*/5}, + {/*size=*/16, /*first=*/3, /*last=*/5}, + {/*size=*/8, /*first=*/4, /*last=*/5}, + {/*size=*/64, /*first=*/5, /*last=*/7}, + {/*size=*/8, /*first=*/6, /*last=*/8}, + {/*size=*/8, /*first=*/7, /*last=*/8}, + {/*size=*/16, /*first=*/8, /*last=*/9}}; + ObjectsAssignment assignment; + ASSERT_TRUE( + AssignObjectsToTensors(usage_records, MemoryStrategy::NAIVE, &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8)); + EXPECT_THAT(assignment.object_sizes, + ElementsAre(32, 32, 8, 16, 8, 64, 8, 8, 16)); + ASSERT_TRUE( + AssignObjectsToTensors(usage_records, 
MemoryStrategy::GREEDY, &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 0, 2, 3, 1, 3, 2, 0)); + EXPECT_THAT(assignment.object_sizes, ElementsAre(32, 64, 16, 8)); + ASSERT_TRUE(AssignObjectsToTensors(usage_records, MemoryStrategy::MINCOSTFLOW, + &assignment) + .ok()); + EXPECT_THAT(assignment.object_ids, ElementsAre(0, 1, 2, 0, 3, 1, 3, 2, 0)); + EXPECT_THAT(assignment.object_sizes, ElementsAre(32, 64, 8, 8)); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/model.h b/tensorflow/lite/delegates/gpu/common/model.h new file mode 100644 index 00000000000..19a25dbb3dc --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model.h @@ -0,0 +1,585 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_ + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" + +namespace tflite { +namespace gpu { + +// There is yet another representation of CNN graph. The primary purpose of this +// representation is to simplify graph manipulation. + +using ValueId = uint32_t; + +using NodeId = uint32_t; + +// Connects tensor's producer and operation that depends on this tensor. +template +struct Value { + using TensorType = TensorT; + + const ValueId id; + + TensorType tensor; +}; + +struct Operation { + std::string type; + + absl::any attributes; +}; + +struct Node { + const NodeId id; + + Operation operation; +}; + +// Graph is DAG that consists of nodes and values. Each value may have a single +// producer node and multiple consumer nodes. Therefore, each node may have +// multiple input and output values. +// +// Value that does not have a producer is a graph's input. Value that does not +// have a consumer is a graph's output. +// +// Interface provides methods for graph introspection and manipulation. Abstract +// interface makes allows subgraphs representation to ensure safe manipulations. +template +class Graph { + public: + virtual ~Graph() = default; + + // @return a collection of nodes in this graph. + virtual std::vector nodes() const = 0; + + // @return a collection of values in this graph. + virtual std::vector*> values() const = 0; + + // @return graph inputs, that are values without producers. + virtual std::vector*> inputs() const = 0; + + // @return graph outputs, that are values without consumers. + virtual std::vector*> outputs() const = 0; + + // @return inputs into the given node. Returns empty vector for deleted node. 
+ virtual std::vector*> FindInputs(NodeId id) const = 0; + + // @return outputs from the given node. Returns empty vector for deleted node. + virtual std::vector*> FindOutputs(NodeId id) const = 0; + + virtual bool IsGraphInput(ValueId id) const = 0; + + virtual bool IsGraphOutput(ValueId id) const = 0; + + // @return producer of the given value. Returns nullptr for deleted value. + virtual Node* FindProducer(ValueId id) const = 0; + + // @return consumers of the given value. Returns empty vector for deleted + // value. + virtual std::vector FindConsumers(ValueId id) const = 0; + + // @return a node or nullptr if node with the given id is not present. + virtual Node* GetNode(NodeId id) const = 0; + + // @return a value or nullptr if value with the given id is not present. + virtual Value* GetValue(ValueId id) const = 0; + + ////////////////////////////////////////////////////////////////////////////// + // Graph manipulation functions are below + ////////////////////////////////////////////////////////////////////////////// + + // @return new node created in this graph + // NOTE: nodes should be created in the topological order, e.g. node A that + // depends on a value from node B should be created after node B. + virtual Node* NewNode() = 0; + + // @return new value created in this graph + virtual Value* NewValue() = 0; + + // Sets a producer for the given value. There could be a single producer + // for a value. If a value had another producer, it will reassign producer + // appropriately. If a value didn't have a producer, it will be removed + // from a graph's input. + virtual Status SetProducer(NodeId producer, ValueId value) = 0; + + // Removes a producer for the given value. Value becomes producer-less and + // therefore becomes graph's input. + virtual Status RemoveProducer(ValueId value) = 0; + + // Sets a consumer for the given value. There could be multiple consumers + // for a value. + virtual Status AddConsumer(NodeId consumer, ValueId value) = 0; + + // Removes a consumer for the given value. If value does not have any + // consumers it becomes graph's output. + virtual Status RemoveConsumer(NodeId consumer, ValueId value) = 0; + + // Removes node from this graph. For all input values this node will be + // removed from consumers and for all output values a producer will be + // removed. + virtual Status DeleteNode(NodeId id) = 0; + + // Removes value from this graph. It will be removed from inputs for all + // dependent nodes. A node that was a producer of this value will loose its + // output. + virtual Status DeleteValue(ValueId id) = 0; +}; + +// Implementation of a Graph interface. It keeps values and nodes referenced by +// their index in a vector. Therefore, nodes and values are never deleted, but +// rather erased, where corresponding index remains. +// +// It is possible to re-use removed indices, but it is not implemented yet. 
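// [Editor's aside, not part of this patch.] A minimal usage sketch of the
// interface above, written against the GraphFloat32 alias defined at the end
// of this header; the operation type string is illustrative and error
// handling is elided:
//
//   GraphFloat32 graph;
//   Node* conv = graph.NewNode();
//   conv->operation.type = "convolution_2d";
//   auto* input = graph.NewValue();
//   auto* output = graph.NewValue();
//   graph.AddConsumer(conv->id, input->id).IgnoreError();
//   graph.SetProducer(conv->id, output->id).IgnoreError();
//   // "input" has no producer and "output" has no consumers, so
//   // graph.inputs() contains input and graph.outputs() contains output.
//
// Note the topological-order requirement on NewNode(): a node that produces a
// value must be created before any node that consumes it.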
+template +class Model : public Graph { + public: + const std::string& name() const { return name_; } + + void set_name(std::string name) { name_ = std::move(name); } + + std::vector*> values() const final { + return FilterValues([](const ValueDef&) { return true; }); + } + + std::vector nodes() const final { + return FilterNodes([](const NodeDef&) { return true; }); + } + + std::vector*> inputs() const final { + return FilterValues( + [](const ValueDef& v) { return v.producer == nullptr; }); + } + + std::vector*> outputs() const final { + return FilterValues([](const ValueDef& v) { return v.consumers.empty(); }); + } + + bool IsGraphInput(ValueId id) const final { + if (id >= values_.size()) { + return false; + } + return values_[id].producer == nullptr; + } + + bool IsGraphOutput(ValueId id) const final { + if (id >= values_.size()) { + return false; + } + return values_[id].consumers.empty(); + } + + Node* GetNode(NodeId id) const final { + if (id >= nodes_.size()) { + return {}; + } + return nodes_[id].node.get(); + } + + Value* GetValue(ValueId id) const final { + if (id >= values_.size()) { + return nullptr; + } + return values_[id].value.get(); + } + + Node* NewNode() final { + NodeDef def; + def.node = + absl::make_unique(Node{static_cast(nodes_.size()), {}}); + Node* node = def.node.get(); + nodes_.push_back(std::move(def)); + return node; + } + + Value* NewValue() final { + ValueDef def; + def.value = absl::make_unique>( + Value{static_cast(values_.size()), {}}); + Value* value = def.value.get(); + values_.push_back(std::move(def)); + return value; + } + + std::vector*> FindInputs(NodeId id) const final { + if (id >= nodes_.size()) { + return {}; + } + return nodes_[id].inputs; + } + + std::vector*> FindOutputs(NodeId id) const final { + if (id >= nodes_.size()) { + return {}; + } + return nodes_[id].outputs; + } + + Node* FindProducer(ValueId id) const final { + if (id >= values_.size()) { + return nullptr; + } + return values_[id].producer; + } + + std::vector FindConsumers(ValueId id) const final { + if (id >= values_.size()) { + return {}; + } + return values_[id].consumers; + } + + Status SetProducer(NodeId producer, ValueId value) final { + ValueDef* v; + RETURN_IF_ERROR(LookupValue(value, &v)); + Value* value_ptr = v->value.get(); + NodeDef* n; + RETURN_IF_ERROR(LookupNode(producer, &n)); + Node* node_ptr = n->node.get(); + + // check if this value has the same producer already + if (node_ptr == v->producer) { + return InvalidArgumentError("Node is already a producer of the value"); + } + + // Check if the node is a consumer of this value. + if (std::find(n->inputs.begin(), n->inputs.end(), value_ptr) != + n->inputs.end()) { + return InvalidArgumentError("Node is a consumer of the value"); + } + // TODO(akulik): detect circular dependency? + + if (v->producer != nullptr) { + // value is no longer produced by it's previous producer. 
+ Erase(&nodes_[v->producer->id].outputs, value_ptr); + } + v->producer = node_ptr; + n->outputs.push_back(value_ptr); + return OkStatus(); + } + + Status RemoveProducer(ValueId value) final { + ValueDef* v; + RETURN_IF_ERROR(LookupValue(value, &v)); + Value* value_ptr = v->value.get(); + if (v->producer == nullptr) { + return InvalidArgumentError("Value does not have a producer"); + } + Erase(&nodes_[v->producer->id].outputs, value_ptr); + v->producer = nullptr; + return OkStatus(); + } + + Status AddConsumer(NodeId consumer, ValueId value) final { + ValueDef* v; + RETURN_IF_ERROR(LookupValue(value, &v)); + Value* value_ptr = v->value.get(); + NodeDef* n; + RETURN_IF_ERROR(LookupNode(consumer, &n)); + Node* node_ptr = n->node.get(); + + // check if this value has the same producer already + if (node_ptr == v->producer) { + return InvalidArgumentError("Node is a producer of the value"); + } + + // check if this value has the same consumer already + if (std::find(n->inputs.begin(), n->inputs.end(), value_ptr) != + n->inputs.end()) { + return InvalidArgumentError("Node is already a consumer of the value"); + } + + n->inputs.push_back(value_ptr); + v->consumers.push_back(node_ptr); + return OkStatus(); + } + + Status RemoveConsumer(NodeId consumer, ValueId value) final { + ValueDef* v; + RETURN_IF_ERROR(LookupValue(value, &v)); + Value* value_ptr = v->value.get(); + NodeDef* n; + RETURN_IF_ERROR(LookupNode(consumer, &n)); + Node* node_ptr = n->node.get(); + if (std::find(n->inputs.begin(), n->inputs.end(), value_ptr) == + n->inputs.end()) { + return InvalidArgumentError("Node is not a consumer of the value"); + } + Erase(&n->inputs, value_ptr); + Erase(&v->consumers, node_ptr); + return OkStatus(); + } + + Status DeleteNode(NodeId id) final { + NodeDef* n; + RETURN_IF_ERROR(LookupNode(id, &n)); + Node* node_ptr = n->node.get(); + for (auto value : n->inputs) { + Erase(&values_[value->id].consumers, node_ptr); + } + for (auto value : n->outputs) { + values_[value->id].producer = nullptr; + } + n->inputs.clear(); + n->outputs.clear(); + n->node.reset(); + return OkStatus(); + } + + Status DeleteValue(ValueId id) final { + ValueDef* v; + RETURN_IF_ERROR(LookupValue(id, &v)); + Value* value_ptr = v->value.get(); + if (v->producer != nullptr) { + Erase(&nodes_[v->producer->id].outputs, value_ptr); + } + if (!v->consumers.empty()) { + for (auto node : v->consumers) { + Erase(&nodes_[node->id].inputs, value_ptr); + } + } + v->producer = nullptr; + v->consumers.clear(); + v->value.reset(); + return OkStatus(); + } + + Status MakeExactCopy(Model* model) const { + model->nodes_.clear(); + model->values_.clear(); + model->name_ = name_; + for (auto& value_def : values_) { + model->values_.push_back({}); + if (value_def.value) { + model->values_.back().value = + absl::make_unique>(*value_def.value); + } + } + for (auto& node_def : nodes_) { + model->nodes_.push_back({}); + if (node_def.node) { + model->nodes_.back().node = absl::make_unique(*node_def.node); + for (auto output : node_def.outputs) { + RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id)); + } + for (auto input : node_def.inputs) { + RETURN_IF_ERROR(model->AddConsumer(node_def.node->id, input->id)); + } + } + } + return OkStatus(); + } + + private: + struct NodeDef { + std::vector*> inputs; + std::vector*> outputs; + std::unique_ptr node; + }; + + struct ValueDef { + Node* producer = nullptr; + std::vector consumers; + std::unique_ptr> value; + }; + + template + static void Erase(std::vector* values, T value) { + 
values->erase(std::find(values->begin(), values->end(), value)); + } + + // @return non-nullptr NodeDef that has valid Node or an error + Status LookupNode(NodeId id, NodeDef** node_def) { + if (id >= nodes_.size()) { + return OutOfRangeError("NodeId is out of range"); + } + auto& n = nodes_[id]; + if (!n.node) { + return OutOfRangeError("Node is already deleted"); + } + *node_def = &n; + return OkStatus(); + } + + // @return non-nullptr ValueDef that has valid Value or an error + Status LookupValue(ValueId id, ValueDef** value_def) { + if (id >= values_.size()) { + return OutOfRangeError("ValueId is out of range"); + } + auto& v = values_[id]; + if (!v.value) { + return OutOfRangeError("Value is already deleted"); + } + *value_def = &v; + return OkStatus(); + } + + template + std::vector*> FilterValues(const Pred& predicate) const { + std::vector*> values; + values.reserve(values_.size()); + for (auto& v : values_) { + if (v.value != nullptr && predicate(v)) { + values.push_back(v.value.get()); + } + } + return values; + } + + template + std::vector FilterNodes(const Pred& predicate) const { + std::vector nodes; + nodes.reserve(nodes_.size()); + for (auto& n : nodes_) { + if (n.node != nullptr && predicate(n)) { + nodes.push_back(n.node.get()); + } + } + return nodes; + } + + std::string name_; + + // There are two approaches possible: wrap entire NodeDef and ValueDef into + // unique_ptr and store it in values_ and nodes_ or store it by value. + // We store it by value here to make introspection calls cheaper. + std::vector values_; + std::vector nodes_; +}; + +// Removes to_remove node that precedes to_keep node only if to_remove has +// outputs that are consumed only by to_keep. In such case to_keep inherits all +// to_remove inputs. +template +Status RemovePrecedingNode(Graph* graph, const Node* to_remove, + const Node* to_keep) { + // Make sure all outputs from to_remove are consumed by to_keep. + for (auto output : graph->FindOutputs(to_remove->id)) { + auto consumers = graph->FindConsumers(output->id); + if (consumers.size() > 1 || + (consumers.size() == 1 && consumers[0] != to_keep)) { + return InvalidArgumentError( + "Output from to_remove node has other consumers"); + } + } + + // Update all references + for (auto input : graph->FindInputs(to_remove->id)) { + RETURN_IF_ERROR(graph->AddConsumer(to_keep->id, input->id)); + } + for (auto output : graph->FindOutputs(to_remove->id)) { + RETURN_IF_ERROR(graph->DeleteValue(output->id)); + } + return graph->DeleteNode(to_remove->id); +} + +// Removes to_remove node that follows to_keep node only if to_remove has inputs +// that are produced by to_keep. to_keep inherits all to_remove inputs. +template +Status RemoveFollowingNode(Graph* graph, const Node* to_remove, + const Node* to_keep) { + // Make sure all inputs to to_remove are produced by to_keep. + for (auto input : graph->FindInputs(to_remove->id)) { + Node* producer = graph->FindProducer(input->id); + if (producer->id != to_keep->id) { + return InvalidArgumentError("To_remove node has other inputs"); + } + } + + for (auto input : graph->FindInputs(to_remove->id)) { + RETURN_IF_ERROR(graph->DeleteValue(input->id)); + } + for (auto output : graph->FindOutputs(to_remove->id)) { + RETURN_IF_ERROR(graph->SetProducer(to_keep->id, output->id)); + } + return graph->DeleteNode(to_remove->id); +} + +// Removes to_remove node. 
+// Requires that node has one input and one output; +// If to_remove doesn't have producer, all consumers of to_remove, will use +// to_remove input as input +// If to_remove doesn't have consumers, producer of to_remove will have output +// of to_remove. +// If to_remove has producer and consumer(s), consumer(s) will have as input +// output of producer +template +Status RemoveOneInputOneOutputNode(Graph* graph, + const Node* to_remove) { + auto inputs = graph->FindInputs(to_remove->id); + auto outputs = graph->FindOutputs(to_remove->id); + if (inputs.size() != 1 || outputs.size() != 1) { + return InvalidArgumentError( + "To_remove node must have 1 input and 1 output"); + } + auto input_id = inputs[0]->id; + auto output_id = outputs[0]->id; + Node* producer = graph->FindProducer(input_id); + auto consumers = graph->FindConsumers(output_id); + if (!producer && consumers.empty()) { // degenerate case + RETURN_IF_ERROR(graph->DeleteNode(to_remove->id)); + RETURN_IF_ERROR(graph->DeleteValue(input_id)); + return graph->DeleteValue(output_id); + } + if (!producer) { + RETURN_IF_ERROR(graph->DeleteNode(to_remove->id)); + RETURN_IF_ERROR(graph->DeleteValue(output_id)); + for (auto& consumer : consumers) { + RETURN_IF_ERROR(graph->AddConsumer(consumer->id, input_id)); + } + return OkStatus(); + } + return RemoveFollowingNode(graph, to_remove, producer); + return OkStatus(); +} + +template +Status AddOutput(Graph* graph, const Node* from_node, + Value** output) { + auto link = graph->NewValue(); + RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id)); + *output = link; + return OkStatus(); +} + +template +Status ConnectTwoNodes(Graph* graph, const Node* from_node, + const Node* to_node, Value** output) { + Value* link; + RETURN_IF_ERROR(AddOutput(graph, from_node, &link)); + RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id)); + *output = link; + return OkStatus(); +} + +using GraphFloat32 = Model>; + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_H_ diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc new file mode 100644 index 00000000000..b616a72378c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -0,0 +1,2075 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/model_builder.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/builtin_op_data.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace gpu { +namespace { + +using ::absl::make_unique; +using ::absl::StrCat; + +// Creates a node that consumes output from the given node. Because output need +// to stay the same, newly created node will inherit the output from the given +// node, which will in turn get newly created copy of output. This is necessary +// to preserve reference consistency if another node was pointing at that +// output: +// node(output) +// will turn into: +// node(copy(output)) <- passthrough_node(output) +Status NewPassthroughNode(GraphFloat32* graph, Node* node, + const Value* output, + Node** passthru_node) { + *passthru_node = graph->NewNode(); + // Make copies for every output in the original node. + RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id)); + Value* copy_output = graph->NewValue(); + RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id)); + RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id)); + copy_output->tensor = output->tensor; + copy_output->tensor.ref = -1; + return OkStatus(); +} + +template +Status CreateVectorCopyData(const TfLiteTensor& tensor, + std::vector* tensor_data) { + if (tensor.bytes % sizeof(T) != 0) { + return InvalidArgumentError( + StrCat("Input data size ", tensor.bytes, + " is not aligned to expected type: ", sizeof(T))); + } + tensor_data->resize(tensor.bytes / sizeof(T)); + std::memcpy(&(*tensor_data)[0], tensor.data.uint8, tensor.bytes); + return OkStatus(); +} + +template +Status SetAllDimensions(const TfLiteIntArray* dimensions, ShapeT* shape); + +template <> +Status SetAllDimensions(const TfLiteIntArray* dimensions, + Scalar* shape) { + if (dimensions->size < 0) { + return InvalidArgumentError("Invalid Scalar dimensions"); + } + for (int i = 0; i < dimensions->size; ++i) { + if (dimensions->data[i] != 1) { + return InvalidArgumentError("Dimension can not be reduced to scalar."); + } + } + shape->v = 1; + return OkStatus(); +} + +template <> +Status SetAllDimensions(const TfLiteIntArray* dimensions, + Linear* shape) { + if (dimensions->size <= 0) { + return InvalidArgumentError("Dimension is empty."); + } + for (int i = 0; i < dimensions->size - 1; ++i) { + if (dimensions->data[i] != 1) { + return InvalidArgumentError("Dimension can not be reduced to linear."); + } + } + shape->v = dimensions->data[dimensions->size - 1]; + return OkStatus(); +} + +template <> +Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) { + if (dimensions->size != 4) { + return InvalidArgumentError("Dimensions are not 
HWC"); + } + if (dimensions->data[0] != 1) { + return UnimplementedError("Batch size is not equal to 1."); + } + shape->h = dimensions->data[1]; + shape->w = dimensions->data[2]; + shape->c = dimensions->data[3]; + return OkStatus(); +} + +template <> +Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) { + if (dimensions->size != 2) { + return InvalidArgumentError("Dimensions are not HW"); + } + shape->h = dimensions->data[0]; + shape->w = dimensions->data[1]; + return OkStatus(); +} + +template <> +Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) { + if (dimensions->size != 4) { + return InvalidArgumentError( + StrCat("Dimensions are not OHWI: ", dimensions->size)); + } + shape->o = dimensions->data[0]; + shape->h = dimensions->data[1]; + shape->w = dimensions->data[2]; + shape->i = dimensions->data[3]; + return OkStatus(); +} + +template <> +Status SetAllDimensions(const TfLiteIntArray* dimensions, IHWO* shape) { + if (dimensions->size != 4) { + return InvalidArgumentError( + StrCat("Dimensions are not IHWO: ", dimensions->size)); + } + shape->i = dimensions->data[0]; + shape->h = dimensions->data[1]; + shape->w = dimensions->data[2]; + shape->o = dimensions->data[3]; + return OkStatus(); +} + +template <> +Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) { + if (dimensions->size != 4) { + return InvalidArgumentError("Dimensions are not BHWC"); + } + shape->b = dimensions->data[0]; + shape->h = dimensions->data[1]; + shape->w = dimensions->data[2]; + shape->c = dimensions->data[3]; + return OkStatus(); +} + +DataType ToDataType(TfLiteType type) { + switch (type) { + case kTfLiteFloat32: + return DataType::FLOAT32; + case kTfLiteInt32: + return DataType::INT32; + case kTfLiteInt64: + return DataType::INT64; + case kTfLiteUInt8: + return DataType::UINT8; + default: + return DataType::UNKNOWN; + } +} + +int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context, + const TfLiteNode* tflite_node) { + int number_of_runtime_inputs = 0; + for (int i = 0; i < tflite_node->inputs->size; i++) { + if (!IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) { + number_of_runtime_inputs++; + } + } + return number_of_runtime_inputs; +} + +int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context, + const TfLiteNode* tflite_node) { + int number_of_runtime_outputs = 0; + for (int i = 0; i < tflite_node->outputs->size; i++) { + if (!IsConstantTensor(&context->tensors[tflite_node->outputs->data[i]])) { + number_of_runtime_outputs++; + } + } + return number_of_runtime_outputs; +} + +Status CheckTensorIsAvailable(const TfLiteContext* context, + const TfLiteNode* tflite_node, int idx) { + // If tensor id is in range, it's guaranteed that it'll be available. 
+ if (idx >= tflite_node->inputs->size) { + return OutOfRangeError( + absl::StrFormat("Requested index goes beyond array size (%d vs %d).", + idx, tflite_node->inputs->data[idx])); + } + return OkStatus(); +} + +class ObjectReader { + public: + ObjectReader(GraphFloat32* graph, TfLiteContext* context, + const TfLiteNode* tflite_node, + std::vector*>* tensor_to_value) + : graph_(graph), + context_(context), + tflite_node_(tflite_node), + tensor_to_value_(tensor_to_value) {} + + Status ReadValue(uint32_t idx, Value** value) { + if (idx >= tflite_node_->inputs->size) { + return OutOfRangeError(StrCat("ReadValue: input tensor index: ", idx)); + } + RETURN_IF_ERROR( + ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value)); + return OkStatus(); + } + + int GetNumberOfRuntimeInputs() { + return GetNumberOfRuntimeInputsForNode(context_, tflite_node_); + } + + Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) { + if (idx >= tflite_node_->inputs->size) { + return OutOfRangeError(StrCat("Input tensor index: ", idx)); + } + int32_t tensor_idx = tflite_node_->inputs->data[idx]; + if (tensor_idx < 0 || tensor_idx > context_->tensors_size) { + return OutOfRangeError(StrCat("Tensor index: ", tensor_idx)); + } + const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx]; + *dimensions = *tflite_tensor.dims; + return OkStatus(); + } + + template + Status ReadTensor(uint32_t idx, TensorT* t) const { + RETURN_IF_ERROR(CheckTensorIsAvailable(context_, tflite_node_, idx)); + int32_t tensor_idx = tflite_node_->inputs->data[idx]; + const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx]; + RETURN_IF_ERROR(CreateVectorCopyData(tflite_tensor, &t->data)); + + // Axis and data layout depend on operation this tensor is used in. So, + // postpone resolutions until operations are parsed. 
+ t->id = tensor_idx; + return SetAllDimensions(tflite_tensor.dims, &t->shape); + } + + Status AddOutput(const Node* node, int id) { + if (tflite_node_->outputs->size <= id) { + return InvalidArgumentError( + StrCat("Data id ", id, " must be less than tflite node outputs size ", + tflite_node_->outputs->size)); + } + int output_tensor_idx = tflite_node_->outputs->data[id]; + Value* value; + RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value)); + RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id)); + return OkStatus(); + } + + Status AddOutputs(const Node* node) { + for (int i = 0; i < tflite_node_->outputs->size; ++i) { + RETURN_IF_ERROR(AddOutput(node, i)); + } + return OkStatus(); + } + + Status AddInput(const Node* node, uint32_t idx) { + Value* input; + RETURN_IF_ERROR(ReadValue(idx, &input)); + return graph_->AddConsumer(node->id, input->id); + } + + Status ReadValueByTensorIdx(uint32_t tensor_idx, + Value** value) { + if (tensor_idx >= tensor_to_value_->size()) { + return OutOfRangeError( + StrCat("ReadValue: input tensor index: ", tensor_idx)); + } + if ((*tensor_to_value_)[tensor_idx] == nullptr) { + const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx]; + if (tflite::IsConstantTensor(&tflite_tensor)) { + return NotFoundError( + StrCat("ReadValue: value is a constant tensor: ", tensor_idx)); + } + Value* value = graph_->NewValue(); + RETURN_IF_ERROR( + ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor)); + value->tensor.ref = tensor_idx; + (*tensor_to_value_)[tensor_idx] = value; + } + *value = (*tensor_to_value_)[tensor_idx]; + return OkStatus(); + } + + private: + GraphFloat32* graph_ = nullptr; + const TfLiteContext* context_ = nullptr; + const TfLiteNode* tflite_node_ = nullptr; + std::vector*>* tensor_to_value_; +}; + +Status CheckInputsOutputs(const TfLiteContext* context, + const TfLiteNode* tflite_node, int inputs, + int outputs) { + int runtime_inputs = GetNumberOfRuntimeInputsForNode(context, tflite_node); + if (runtime_inputs != inputs) { + return InternalError( + absl::StrFormat("Expected %d input tensor(s), but node has %d runtime " + "input(s).", + inputs, runtime_inputs)); + } + int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node); + if (runtime_outputs != outputs) { + return InternalError( + absl::StrFormat("Expected %d output tensor(s), but node has %d runtime " + "output(s).", + outputs, runtime_outputs)); + } + return OkStatus(); +} + +// A parser responsible for parsing TFLite operation and adding it to a graph. +class TFLiteOperationParser { + public: + virtual ~TFLiteOperationParser() {} + + // Parses TFLite operation. This method allows expanding fused operations + // into more than one node. + virtual Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) = 0; + + // Verifies whether passed tflite node may be built by GPU delegate or not. 
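// [Editor's aside, not part of this patch.] The two virtual methods of this
// class form a two-phase contract: a delegate is expected to call
// IsSupported() while deciding which nodes to take, and to call Parse() only
// for nodes that passed the check. A hypothetical caller (the surrounding
// loop is assumed, not shown in this patch) would look roughly like:
//
//   if (parser->IsSupported(context, tflite_node, registration).ok()) {
//     RETURN_IF_ERROR(
//         parser->Parse(tflite_node, registration, graph, &reader));
//   }
//
// Parse() may expand one TFLite op into several graph nodes, e.g. when a
// fused activation is split out via MaybeFuseActivation() further below.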
+ virtual Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) = 0; +}; + +Status CheckActivationSupported(TfLiteFusedActivation fused_activation) { + if (fused_activation == kTfLiteActNone) { + return OkStatus(); + } + switch (fused_activation) { + case kTfLiteActRelu: + case kTfLiteActRelu1: + case kTfLiteActRelu6: + case kTfLiteActTanh: + return OkStatus(); + default: + return NotFoundError(absl::StrFormat("Unsupported fused activation: %d.", + fused_activation)); + } +} + +// If there is fused activation present, then there will be another node created +// that will have identical output as the given node. New operation node will +// depend on the given node output. +Status MaybeFuseActivation(TfLiteFusedActivation fused_activation, + const std::vector& output_indices, + GraphFloat32* graph, Node* node) { + if (fused_activation == kTfLiteActNone) { + return OkStatus(); + } + const auto& outputs = graph->FindOutputs(node->id); + if (outputs.empty()) { + return InternalError("Empty outputs in fused node"); + } + switch (fused_activation) { + case kTfLiteActRelu: + case kTfLiteActRelu1: + case kTfLiteActRelu6: { + ReLUAttributes attr; + attr.clip = fused_activation == kTfLiteActRelu + ? 0.0f + : (fused_activation == kTfLiteActRelu1 ? 1.0f : 6.0f); + for (auto index : output_indices) { + Node* activation_node; + RETURN_IF_ERROR( + NewPassthroughNode(graph, node, outputs[index], &activation_node)); + activation_node->operation.type = ToString(OperationType::RELU); + activation_node->operation.attributes = attr; + } + break; + } + case kTfLiteActTanh: + for (auto index : output_indices) { + Node* activation_node; + RETURN_IF_ERROR( + NewPassthroughNode(graph, node, outputs[index], &activation_node)); + activation_node->operation.type = ToString(OperationType::TANH); + } + break; + default: + return NotFoundError( + StrCat("Unsupported fused activation: ", fused_activation)); + } + return OkStatus(); +} + +Status MaybeFuseActivationToTheSingleOutput( + TfLiteFusedActivation fused_activation, GraphFloat32* graph, Node* node) { + if (graph->FindOutputs(node->id).size() != 1) { + return InternalError("Number of outputs exceeds 1"); + } + return MaybeFuseActivation(fused_activation, {0}, graph, node); +} + +HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? 
w : 1); } + +template +void UpdatePadding(const TfLitePadding& padding, const BHWC& input_shape, + AttrT* attr) { + if (padding == kTfLitePaddingSame) { + attr->padding = CalculateSamePadding(input_shape, *attr); + } else { + attr->padding.prepended = HW(0, 0); + attr->padding.appended = HW(0, 0); + } +} + +Status GetFullyConnectedAttributes(int weights_tensor_id, int bias_tensor_id, + ObjectReader* reader, + FullyConnectedAttributes* attr) { + Tensor weights; + RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights)); + attr->weights.data = std::move(weights.data); + attr->weights.id = weights.id; + attr->weights.shape.h = 1; + attr->weights.shape.w = 1; + attr->weights.shape.o = weights.shape.h; + attr->weights.shape.i = weights.shape.w; + reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError(); // optional + + return OkStatus(); +} + +template +Status RetrieveBuiltinData(const TfLiteNode* tflite_node, + ParamsType** tf_options) { + const auto* params = + reinterpret_cast(tflite_node->builtin_data); + if (!params) { + return InternalError("Unable to retrieve builtin_data."); + } + *tf_options = const_cast(params); + return OkStatus(); +} + +template +Status RetrieveCustomInitialData(const TfLiteNode* tflite_node, + ParamsType** tf_options) { + const auto* params = + reinterpret_cast(tflite_node->custom_initial_data); + if (!params) { + return InternalError("Unable to retrieve custom_initial_data."); + } + *tf_options = const_cast(params); + return OkStatus(); +} + +Status CheckMaxSupportedOpVersion(const TfLiteRegistration* registration, + int max_version) { + const int op_version = registration->version; + if (op_version > max_version) { + return UnimplementedError( + absl::StrFormat("Max version supported: %d. Requested version %d.", + max_version, op_version)); + } + return OkStatus(); +} + +Status CheckExactSupportedOpVersion(const TfLiteRegistration* registration, + int expected_version) { + int op_version = registration->version; + if (op_version != expected_version) { + return UnimplementedError( + absl::StrFormat("Only version %d is supported. 
Requested version %d.", + expected_version, op_version)); + } + return OkStatus(); +} + +Status CheckKernels(int kernel_h, int kernel_w) { + if (kernel_h <= 0 || kernel_w <= 0) { + return InvalidArgumentError(absl::StrFormat( + "Incorrect kernel values: kernel_height = %d, kernel_width = %d.", + kernel_h, kernel_w)); + } + return OkStatus(); +} + +Status CheckStrides(int strides_h, int strides_w) { + if (strides_h <= 0 || strides_w <= 0) { + return InvalidArgumentError(absl::StrFormat( + "Incorrect stride values: stride_height = %d, stride_width = %d.", + strides_h, strides_w)); + } + return OkStatus(); +} + +Status CheckDilation(int dilation_h, int dilation_w) { + if (dilation_h <= 0 || dilation_w <= 0) { + return InvalidArgumentError( + absl::StrFormat("Incorrect dilation values: dilation_factor = %d, " + "dilation_factor = %d.", + dilation_h, dilation_w)); + } + return OkStatus(); +} + +Status CheckStridesAndDilation(int strides_h, int strides_w, int dilation_h, + int dilation_w) { + RETURN_IF_ERROR(CheckStrides(strides_h, strides_w)); + RETURN_IF_ERROR(CheckDilation(dilation_h, dilation_w)); + return OkStatus(); +} + +Status CheckKernelsAndStrides(int kernel_h, int kernel_w, int strides_h, + int strides_w) { + RETURN_IF_ERROR(CheckKernels(kernel_h, kernel_w)); + RETURN_IF_ERROR(CheckStrides(strides_h, strides_w)); + return OkStatus(); +} + +class Conv2DOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + RETURN_IF_ERROR( + CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + TfLiteConvParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckStridesAndDilation( + tf_options->stride_height, tf_options->stride_width, + tf_options->dilation_height_factor, tf_options->dilation_width_factor)); + RETURN_IF_ERROR(CheckActivationSupported(tf_options->activation)); + return OkStatus(); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::CONVOLUTION_2D); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + Convolution2DAttributes attr; + RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional + + const auto* tf_options = + reinterpret_cast(tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing tflite params"); + } + attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); + attr.dilations = HW(tf_options->dilation_height_factor, + tf_options->dilation_width_factor); + UpdatePadding(tf_options->padding, + graph->FindInputs(node->id)[0]->tensor.shape, &attr); + RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, + graph, node)); + node->operation.attributes = std::move(attr); + return OkStatus(); + } +}; + +// Creates a simple node that holds tensor value. 
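+// The tensor is wrapped into a CONST operation node and a fresh graph value is
+// created to carry its output; the value's ref, type, and shape mirror the
+// node's attributes so the constant data stays reachable from downstream
+// consumers.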
+Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, + Value** value) { + ConstTensorAttributes attr; + attr.tensor = std::move(t); + Node* node = graph->NewNode(); + node->operation.attributes = attr; + node->operation.type = ToString(OperationType::CONST); + *value = graph->NewValue(); + RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id)); + // Keep data inside this tensor. + (*value)->tensor.ref = attr.tensor.id; + (*value)->tensor.type = attr.tensor.kType; + (*value)->tensor.shape = attr.tensor.shape; + return OkStatus(); +} + +class ConcatenationOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + + // TODO(eignasheva): add proper tensor availability checking + // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) { + // RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx)); + // } + // TODO(eignasheva): add axis checking. + TfLiteConcatenationParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + return OkStatus(); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + ConcatAttributes attr; + // Read inputs first to make sure const node is added to a graph before + // concat node to ensure topological order. + std::vector*> inputs; + for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) { + Value* value; + const auto status = reader->ReadValue(idx, &value); + if (status.ok()) { + inputs.push_back(value); + } else { + TensorFloat32 tensor; + RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor)); + Value* value; + RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value)); + inputs.push_back(value); + } + } + + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::CONCAT); + RETURN_IF_ERROR(reader->AddOutputs(node)); + for (const Value* input : inputs) { + RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); + } + + std::vector input_shapes; + for (auto input : graph->FindInputs(node->id)) { + input_shapes.push_back(input->tensor.shape); + } + RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis)); + + // Guess axis. 
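+ // Heuristic: if any input's H, W, or C differs from the output shape,
+ // concatenation must happen along that axis; otherwise the axis picked by
+ // SetAxis() above is kept.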
+ BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape; + for (auto input : graph->FindInputs(node->id)) { + if (input->tensor.shape.h != output_shape.h) { + attr.axis = Axis::HEIGHT; + break; + } + if (input->tensor.shape.w != output_shape.w) { + attr.axis = Axis::WIDTH; + break; + } + if (input->tensor.shape.c != output_shape.c) { + attr.axis = Axis::CHANNELS; + break; + } + } + const auto* tf_options = reinterpret_cast( + tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing tflite params"); + } + RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, + graph, node)); + node->operation.attributes = attr; + return OkStatus(); + } + + private: + Status SetAxis(const std::vector& input_shapes, Axis* axis) { + *axis = Axis::BATCH; + for (int i = 1; i < input_shapes.size(); i++) { + if (input_shapes[0].h != input_shapes[i].h && + input_shapes[0].w != input_shapes[i].w && + input_shapes[0].c != input_shapes[i].c) { + *axis = Axis::HEIGHT; + break; + } + } + if (*axis == Axis::BATCH) return OkStatus(); + for (int i = 1; i < input_shapes.size(); i++) { + if (input_shapes[0].b != input_shapes[i].b && + input_shapes[0].w != input_shapes[i].w && + input_shapes[0].c != input_shapes[i].c) { + *axis = Axis::WIDTH; + break; + } + } + if (*axis == Axis::HEIGHT) return OkStatus(); + for (int i = 1; i < input_shapes.size(); i++) { + if (input_shapes[0].b != input_shapes[i].b && + input_shapes[0].h != input_shapes[i].h && + input_shapes[0].c != input_shapes[i].c) { + *axis = Axis::CHANNELS; + break; + } + } + if (*axis == Axis::WIDTH) return OkStatus(); + for (int i = 1; i < input_shapes.size(); i++) { + if (input_shapes[0].b != input_shapes[i].b && + input_shapes[0].w != input_shapes[i].w && + input_shapes[0].h != input_shapes[i].h) { + return UnimplementedError( + "Can concatenate tensors only by batch, height, width, or " + "channels."); + } + } + return OkStatus(); + } +}; + +class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); + RETURN_IF_ERROR( + CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + TfLiteDepthwiseConvParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckStridesAndDilation( + tf_options->stride_height, tf_options->stride_width, + tf_options->dilation_height_factor, tf_options->dilation_width_factor)); + RETURN_IF_ERROR(CheckActivationSupported(tf_options->activation)); + return OkStatus(); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + DepthwiseConvolution2DAttributes attr; + RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional + const auto* tf_options = reinterpret_cast( + tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing tflite params"); + } + attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); + attr.dilations = 
HW(std::max(1, tf_options->dilation_height_factor), + std::max(1, tf_options->dilation_width_factor)); + UpdatePadding(tf_options->padding, + graph->FindInputs(node->id)[0]->tensor.shape, &attr); + RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, + graph, node)); + node->operation.attributes = std::move(attr); + return OkStatus(); + } +}; + +class ReshapeOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + RETURN_IF_ERROR( + CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + // TODO(eignasheva): add shape checking + return OkStatus(); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::RESHAPE); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + // Here we may have extra inputs. Other tensors were supposed to + // define new shape, but in TFLite these are ignored. + // TODO(akulik): check that shapes match? + + // New shape comes from output shape. + ReshapeAttributes attr; + attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape; + node->operation.attributes = attr; + return OkStatus(); + } +}; + +Status ParsePoolingAttributes(const TfLitePoolParams* tf_options, + const BHWC& input_shape, + Pooling2DAttributes* attr) { + attr->kernel = ToHW(tf_options->filter_height, tf_options->filter_width); + attr->strides = ToHW(tf_options->stride_height, tf_options->stride_width); + UpdatePadding(tf_options->padding, input_shape, attr); + return OkStatus(); +} + +class Pooling2DOperationParser : public TFLiteOperationParser { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + TfLitePoolParams* tf_options = nullptr; + auto status = RetrieveCustomInitialData(tflite_node, &tf_options); + if (status.ok()) { // custom case with indices as a second output + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1, + /*outputs=*/2)); + } else { // common pooling with 1 output + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1, + /*outputs=*/1)); + } + RETURN_IF_ERROR(CheckKernelsAndStrides( + tf_options->filter_height, tf_options->filter_width, + tf_options->stride_height, tf_options->stride_width)); + RETURN_IF_ERROR(CheckActivationSupported(tf_options->activation)); + return OkStatus(); + } + + public: + explicit Pooling2DOperationParser(PoolingType type) : type_(type) {} + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::POOLING_2D); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutput(node, 0)); + + Pooling2DAttributes attr; + attr.type = type_; + + auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; + + // check whether there are custom options encoded. It happens if operation + // is MaxPoolingWithArgmax2D. 
There is no way to read + // tflite_node->builtin_code, so, simply check whether custom data is + // available. + auto* tf_options = reinterpret_cast( + tflite_node->custom_initial_data); + if (!tf_options) { + tf_options = + reinterpret_cast(tflite_node->builtin_data); + } + if (!tf_options) { + return InternalError("Missing tflite params"); + } + + std::vector max_tensor_id{0}; + RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, max_tensor_id, + graph, node)); + // Second output is optional. It is not required, it but must be added after + // MaybeAddFusedActivation function is called + reader->AddOutput(node, 1).IgnoreError(); + + // First output is the result of pooling operation, while second output is + // indices used for pooling. + auto outputs = graph->FindOutputs(node->id); + attr.output_indices = outputs.size() == 2; + if (attr.output_indices) { + // Fix data type for output indices. In the model it is set as float32. + outputs[1]->tensor.type = DataType::INT32; + } + RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr)); + node->operation.attributes = attr; + return OkStatus(); + } + + private: + const PoolingType type_; +}; + +class Unpooling2DOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + TfLitePoolParams* tf_options = nullptr; + RETURN_IF_ERROR( + CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1)); + RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckKernelsAndStrides( + tf_options->filter_height, tf_options->filter_width, + tf_options->stride_height, tf_options->stride_width)); + return OkStatus(); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddInput(node, 1)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; + MaxUnpooling2DAttributes attr; + const auto* tf_options = reinterpret_cast( + tflite_node->custom_initial_data); + if (!tf_options) { + return InternalError("Missing tflite params"); + } + attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width); + attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width); + UpdatePadding(tf_options->padding, input_shape, &attr); + + node->operation.attributes = attr; + + auto output_value = graph->FindOutputs(node->id)[0]; + output_value->tensor.shape = CalculateOutputShape(input_shape, attr); + return OkStatus(); + } +}; + +class SoftMaxOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + RETURN_IF_ERROR( + CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + TfLiteSoftmaxParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + if (tf_options->beta != 1) { + // TODO(eignasheva): figure out, what's wrong with softmax. 
+ return UnimplementedError("Softmax.beta != 1 is not supported."); + } + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::SOFT_MAX); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + const auto* tf_options = + reinterpret_cast(tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing tflite params"); + } + if (tf_options->beta != 1) { + // there is multiply by scalar operation fused in SoftMax. Make a layer + // out of it before SoftMax. + return UnimplementedError("Softmax.beta != 1 is not supported."); + // auto mul_node = reader->NewPassthroughNode(node); + // mul_node->operation.type = ToString(OperationType::MUL); + } + SoftMaxAttributes attr; + attr.axis = Axis::CHANNELS; // always by channels + node->operation.attributes = attr; + return OkStatus(); + } +}; + +class AddOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + // TODO(eignasheva): add shapes check. + TfLiteAddParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::ADD); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + AddAttributes attr; + for (int idx = 0; idx < tflite_node->inputs->size; ++idx) { + if (!reader->AddInput(node, idx).ok()) { + if (tflite_node->inputs->size != 2) { + return InvalidArgumentError( + "Broadcast Add should accept 2 inputs, one input tensor and " + "broadcasted tensor"); + } + TfLiteIntArray dims; + RETURN_IF_ERROR(reader->GetTensorDims(1, &dims)); + if (dims.size <= 0) { + Tensor tensor; + RETURN_IF_ERROR(reader->ReadTensor(1, &tensor)); + attr.param = tensor.data[0]; + } else { + Tensor tensor; + RETURN_IF_ERROR(reader->ReadTensor(1, &tensor)); + attr.param = std::move(tensor); + } + } + } + node->operation.attributes = std::move(attr); + + const auto* tf_options = + reinterpret_cast(tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing tflite params"); + } + RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, + graph, node)); + return OkStatus(); + } +}; + +// Basic LSTM Cell: +// +// 1name = name is at input index 1 +// name1 = name is at output index 1 +// +// 0input 1prev_activ +// \ / +// [[concat]] +// \ +// concat_temp2 2weights 3biases +// \ / / +// [[fully-connected]] +// \ +// activ_temp3 4prev_state +// \ / +// [[LSTM]] +// / \ +// new_state1 activation0 +// +class LstmOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2)); + // TODO(eignasheva): Fix bad check. 
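+ // The input/output count check below is disabled; Parse() still rejects
+ // nodes that do not have exactly 5 inputs and 4 outputs.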
+ // RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/5, + // /*outputs=*/4)); + TfLiteLSTMParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckParameters(tf_options)); + return OkStatus(); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + if (tflite_node->inputs->size != 5) { + return InvalidArgumentError("LSTM should have 5 input tensors"); + } + if (tflite_node->outputs->size != 4) { + return InvalidArgumentError("LSTM should have 4 output tensors"); + } + + const auto* params = + reinterpret_cast(tflite_node->builtin_data); + if (!params) { + return InternalError("Missing tflite params"); + } + RETURN_IF_ERROR(CheckParameters(params)); + + Node* concat_node = graph->NewNode(); + concat_node->operation.type = ToString(OperationType::CONCAT); + ConcatAttributes concat_attr; + concat_attr.axis = Axis::CHANNELS; + concat_node->operation.attributes = concat_attr; + + Node* fc_node = graph->NewNode(); + fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED); + FullyConnectedAttributes fc_attr; + RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr)); + fc_node->operation.attributes = std::move(fc_attr); + + Node* lstm_node = graph->NewNode(); + lstm_node->operation.type = ToString(OperationType::LSTM); + LstmAttributes lstm_attr; + lstm_attr.kernel_type = LstmKernelType::BASIC; + lstm_node->operation.attributes = lstm_attr; + + Value* concat_temp; + int concat_tensor_idx = tflite_node->outputs->data[2]; + RETURN_IF_ERROR( + reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp)); + Value* activ_temp; + int activ_tensor_idx = tflite_node->outputs->data[3]; + RETURN_IF_ERROR( + reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp)); + + RETURN_IF_ERROR(reader->AddInput(concat_node, 0)); // input + RETURN_IF_ERROR(reader->AddInput(concat_node, 1)); // prev_activ + RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id)); + + RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id)); + RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id)); + + RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id)); + RETURN_IF_ERROR(reader->AddInput(lstm_node, 4)); // prev_state + RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state + RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation + + return OkStatus(); + } + + private: + Status CheckParameters(const TfLiteLSTMParams* tf_options) { + if (tf_options->kernel_type != + TfLiteLSTMKernelType::kTfLiteLSTMBasicKernel) { + return UnimplementedError("Only kTfLiteLSTMBasicKernel is supported."); + } + if (tf_options->activation != kTfLiteActTanh) { + return UnimplementedError("Only TANH activation is supported."); + } + if (tf_options->cell_clip != 0.0f) { + return UnimplementedError("cell_clip is not supported."); + } + if (tf_options->proj_clip != 0.0f) { + return UnimplementedError("proj_clip is not supported."); + } + return OkStatus(); + } +}; + +class ResizeBilinearOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + RETURN_IF_ERROR( + CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + + // TODO(eignasheva): check shapes. 
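+ // Only align_corners is consumed from the params; the target size is taken
+ // from the output tensor shape in Parse().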
+ TfLiteResizeBilinearParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::UPSAMPLE_2D); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + // Here we may have extra inputs. Other tensors were supposed to + // define new shape, but in TFLite these are ignored. + + const auto* tf_options = + reinterpret_cast( + tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing tflite params"); + } + Upsample2DAttributes attr; + attr.align_corners = tf_options->align_corners; + attr.type = UpsamplingType::BILINEAR; + attr.new_shape.CopyAllDefinedAxis( + graph->FindOutputs(node->id)[0]->tensor.shape); + node->operation.attributes = attr; + return OkStatus(); + } +}; + +class PadOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + RETURN_IF_ERROR( + CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1)); + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::PAD); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + PadAttributes attr; + attr.type = PaddingContentType::ZEROS; + Tensor paddings; + RETURN_IF_ERROR(reader->ReadTensor(1, &paddings)); + + // 4x2 tensor with paddings. 
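+ // Rows are ordered B, H, W, C; column 0 is the padding prepended before a
+ // dimension and column 1 the padding appended after it.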
+ if (paddings.shape.h != 4 || paddings.shape.w != 2) { + return InvalidArgumentError("Paddings tensor has unexpected shape."); + } + if (paddings.data[0] != 0 || paddings.data[1] != 0) { + return UnimplementedError("Padding for BATCH channel is not supported."); + } + attr.prepended = HWC(paddings.data[2], paddings.data[4], paddings.data[6]); + attr.appended = HWC(paddings.data[3], paddings.data[5], paddings.data[7]); + node->operation.attributes = attr; + return OkStatus(); + } +}; + +class ElementwiseOperationParser : public TFLiteOperationParser { + public: + explicit ElementwiseOperationParser(OperationType operation_type) + : operation_type_(operation_type) {} + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + if (IsTwoArgumentOperation()) { + RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/2, + /*outputs=*/1)); + TfLiteSubParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckActivationSupported(tf_options->activation)); + } else if (!IsOneArgumentOperation()) { + return InvalidArgumentError("Incorrect operation type passed"); + } + + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(operation_type_); + + if (IsOneArgumentOperation()) { + RETURN_IF_ERROR(reader->AddInput(node, 0)); + } else if (IsTwoArgumentOperation()) { + if (tflite_node->inputs->size != 2) { + return InvalidArgumentError("Applies only two input tensors"); + } + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddInput(node, 1)); + + TfLiteFusedActivation activation = kTfLiteActNone; + switch (operation_type_) { + case OperationType::SUB: { + const auto* tf_options = reinterpret_cast( + tflite_node->builtin_data); + if (tf_options != nullptr) { + activation = tf_options->activation; + } + break; + } + case OperationType::DIV: { + const auto* tf_options = reinterpret_cast( + tflite_node->builtin_data); + if (tf_options != nullptr) { + activation = tf_options->activation; + } + break; + } + default: + // No activation expected. 
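+ // POW and SQUARED_DIFF are the remaining two-argument ops; no activation
+ // is read for them here.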
+ activation = kTfLiteActNone; + } + + if (activation) { + RETURN_IF_ERROR( + MaybeFuseActivationToTheSingleOutput(activation, graph, node)); + } + } else { + return InvalidArgumentError("Incorrect operation type passed"); + } + + return reader->AddOutputs(node); + } + + private: + bool IsOneArgumentOperation() const { + switch (operation_type_) { + case OperationType::ABS: + case OperationType::SIN: + case OperationType::COS: + case OperationType::LOG: + case OperationType::SQRT: + case OperationType::RSQRT: + case OperationType::SQUARE: + case OperationType::SIGMOID: + case OperationType::TANH: + return true; + default: + return false; + } + } + + bool IsTwoArgumentOperation() const { + switch (operation_type_) { + case OperationType::SUB: + case OperationType::DIV: + case OperationType::POW: + case OperationType::SQUARED_DIFF: + return true; + default: + return false; + } + } + + OperationType operation_type_; +}; + +class PReLuOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + // TODO(eignasheva): add params check + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::PRELU); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; + + PReLUAttributes attr; + Tensor linear_alpha; + Status status = reader->ReadTensor(1, &linear_alpha); + if (status.ok()) { + if (linear_alpha.shape.v != input_shape.c) { + return InvalidArgumentError( + "Linear alpha shape does not match the number of input channels."); + } + attr.alpha = std::move(linear_alpha); + } else { + Tensor hwc_alpha; + RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha)); + if (hwc_alpha.shape.h != input_shape.h || + hwc_alpha.shape.w != input_shape.w || + hwc_alpha.shape.c != input_shape.c) { + return InvalidArgumentError("Alpha shape does not match input shape."); + } + attr.alpha = std::move(hwc_alpha); + } + node->operation.attributes = std::move(attr); + return reader->AddOutputs(node); + } +}; + +class ReLuOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + return OkStatus(); + } + explicit ReLuOperationParser(int clip) : clip_(clip) {} + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::RELU); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + + ReLUAttributes attr; + TfLiteLeakyReluParams* tf_options = nullptr; + RetrieveBuiltinData(tflite_node, &tf_options).IgnoreError(); + attr.alpha = tf_options ? 
tf_options->alpha : 0; + attr.clip = clip_; + node->operation.attributes = attr; + return reader->AddOutputs(node); + } + + private: + int clip_; +}; + +class MulOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + // TODO(eignasheva): add params check + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + if (reader->GetNumberOfRuntimeInputs() == 2) { + // ApplyMask operation + node->operation.type = ToString(OperationType::APPLY_MASK); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddInput(node, 1)); + } else { + node->operation.type = ToString(OperationType::MULTIPLY_SCALAR); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + MultiplyScalarAttributes attr; + TfLiteIntArray dims; + RETURN_IF_ERROR(reader->GetTensorDims(1, &dims)); + if (dims.size <= 0) { + Tensor tensor; + RETURN_IF_ERROR(reader->ReadTensor(1, &tensor)); + attr.param = tensor.data[0]; + } else { + Tensor tensor; + RETURN_IF_ERROR(reader->ReadTensor(1, &tensor)); + attr.param = std::move(tensor); + } + node->operation.attributes = std::move(attr); + } + return reader->AddOutputs(node); + } +}; + +class FullyConnectedOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + TfLiteFullyConnectedParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + if (tf_options->weights_format != + kTfLiteFullyConnectedWeightsFormatDefault) { + return UnimplementedError("Unsupported FullyConnected weights format."); + } + // TODO(eignasheva): check input shape + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + + const auto* tf_options = + reinterpret_cast( + tflite_node->builtin_data); + if (tf_options->weights_format != + kTfLiteFullyConnectedWeightsFormatDefault) { + return UnimplementedError("Unsupported FullyConnected weights format."); + } + + FullyConnectedAttributes attr; + RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr)); + + Tensor weights; + RETURN_IF_ERROR(reader->ReadTensor(1, &weights)); + auto input = graph->FindInputs(node->id)[0]; + int batch_size = input->tensor.shape.b; + if (input->tensor.shape.DimensionsProduct() / batch_size != + weights.shape.w) { + return UnimplementedError( + "Amount of input data should match weights width"); + } + + Node* conv = node; + if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) { + auto& reshape = node; + conv = graph->NewNode(); // reset conv pointer! 
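+ // The original node is repurposed as a RESHAPE that flattens the input to
+ // BHWC(1, 1, 1, weights width); `conv` now points at a new node that gets
+ // the FULLY_CONNECTED attributes below.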
+ Value* reshaped_value = graph->NewValue(); + reshaped_value->tensor.shape = BHWC(1, 1, 1, weights.shape.w); + RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id)); + reshape->operation.type = ToString(OperationType::RESHAPE); + ReshapeAttributes attr; + attr.new_shape = reshaped_value->tensor.shape; + reshape->operation.attributes = attr; + RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id)); + } + + conv->operation.type = ToString(OperationType::FULLY_CONNECTED); + conv->operation.attributes = std::move(attr); + Status result = reader->AddOutputs(conv); + RETURN_IF_ERROR(MaybeFuseActivationToTheSingleOutput(tf_options->activation, + graph, conv)); + + return result; + } +}; + +class StridedSliceOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + TfLiteStridedSliceParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + Node* node = graph->NewNode(); + node->operation.type = ToString(OperationType::SLICE); + RETURN_IF_ERROR(reader->AddOutputs(node)); + Value* input; + RETURN_IF_ERROR(reader->ReadValue(0, &input)); + RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); + + Tensor tmp; + RETURN_IF_ERROR(reader->ReadTensor(1, &tmp)); + + bool read_without_batch = tmp.data.size() == 3; + bool read_with_batch = tmp.data.size() == 4; + if (!read_without_batch && !read_with_batch) { + return UnimplementedError( + "Slicing is supported for 3 or 4 dimensional tensors only."); + } + + const auto* tf_options = reinterpret_cast( + tflite_node->builtin_data); + auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; + if (!tf_options) { + return InternalError("Missing tflite params"); + } + RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); + + SliceAttributes attr; + if (read_without_batch) { + RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options, + input->tensor.shape, &attr)); + } + if (read_with_batch) { + RETURN_IF_ERROR( + ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr)); + } + if (attr.strides.h < 0 || attr.strides.w < 0 || attr.strides.c < 0) { + return UnimplementedError("Reverse slices are not supported."); + } + if (attr.ends.h - attr.starts.h != out_shape.h) { + return UnimplementedError("Output height doesn't match"); + } + if (attr.ends.w - attr.starts.w != out_shape.w) { + return UnimplementedError("Output width doesn't match"); + } + if (attr.ends.c - attr.starts.c != out_shape.c) { + return UnimplementedError("Output channels don't match"); + } + node->operation.attributes = attr; + return OkStatus(); + } + + private: + Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, int ignore_h, int ignore_w, + int ignore_c, SliceAttributes* attr) { + if (tf_options->begin_mask & ignore_h) { + attr->starts.h = 0; + } + if (tf_options->begin_mask & ignore_w) { + attr->starts.w = 0; + } + if (tf_options->begin_mask & ignore_c) { + attr->starts.c = 0; + } + + if (tf_options->end_mask & ignore_h) { + attr->ends.h = input_shape.h; + } + if (tf_options->end_mask & ignore_w) { + attr->ends.w = input_shape.w; + } + if 
(tf_options->end_mask & ignore_c) { + attr->ends.c = input_shape.c; + } + return OkStatus(); + } + + Status UpdateIfNegative(const BHWC& input_shape, SliceAttributes* attr) { + if (attr->ends.h < 0) { + attr->ends.h = input_shape.h + attr->ends.h; + } + if (attr->ends.w < 0) { + attr->ends.w = input_shape.w + attr->ends.w; + } + if (attr->ends.c < 0) { + attr->ends.c = input_shape.c + attr->ends.c; + } + return OkStatus(); + } + + Status ReadAttribsWithBatch(const ObjectReader* reader, + const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, SliceAttributes* attr) { + auto read_hwc = [&](int tensor_index, HWC* hwc) -> Status { + Tensor t; + RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); + if (t.data[0] != 1 && t.data[0] != 0) { + return UnimplementedError( + "Slicing for BATCH channel is not supported. If you use batch it " + "should be 0 or 1"); + } + *hwc = HWC(t.data[1], t.data[2], t.data[3]); + return OkStatus(); + }; + + RETURN_IF_ERROR(read_hwc(1, &attr->starts)); + RETURN_IF_ERROR(read_hwc(2, &attr->ends)); + RETURN_IF_ERROR(read_hwc(3, &attr->strides)); + RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr)); + RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 2, 4, 8, attr)); + return OkStatus(); + } + + Status ReadAttribsWithoutBatch(const ObjectReader* reader, + const TfLiteStridedSliceParams* tf_options, + const BHWC& input_shape, + SliceAttributes* attr) { + auto read_hwc = [&](int tensor_index, HWC* hwc) -> Status { + Tensor t; + RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); + *hwc = HWC(t.data[0], t.data[1], t.data[2]); + return OkStatus(); + }; + + RETURN_IF_ERROR(read_hwc(1, &attr->starts)); + RETURN_IF_ERROR(read_hwc(2, &attr->ends)); + RETURN_IF_ERROR(read_hwc(3, &attr->strides)); + RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr)); + RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, attr)); + return OkStatus(); + } + Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) { + if (tf_options->ellipsis_mask) { + return UnimplementedError("Slice does not support ellipsis_mask."); + } + if (tf_options->new_axis_mask) { + return UnimplementedError("Slice does not support new_axis_mask."); + } + if (tf_options->shrink_axis_mask) { + return UnimplementedError( + "Slice does not support shrink_axis_mask parameter. "); + } + return OkStatus(); + } +}; + +class TransposeConvOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1)); + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + TfLiteTransposeConvParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); + RETURN_IF_ERROR( + CheckStrides(tf_options->stride_height, tf_options->stride_width)); + return OkStatus(); + } + // TFLite's TRANSPOSE_CONV expects 3 input (output shape, weights, and input) + // and allows configurable padding & stride. + // TODO(impjdi): Translate output_shape to attr.adjacent. 
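+ // Input order is: 0 = output shape, 1 = weights, 2 = input tensor, which is
+ // why Parse() reads the graph input from index 2.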
+ Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + auto* node = graph->NewNode(); + node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); + Value* input; + RETURN_IF_ERROR(reader->ReadValue(2, &input)); + RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + const auto* tf_options = reinterpret_cast( + tflite_node->builtin_data); + if (!tf_options) { + return InternalError("Missing tflite options."); + } + ConvolutionTransposedAttributes attr; + attr.stride = tf_options + ? HW(tf_options->stride_height, tf_options->stride_width) + : HW(1, 1); + RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + + // TFLite does not support bias. + + UpdatePadding(tf_options->padding, + graph->FindInputs(node->id)[0]->tensor.shape, &attr); + node->operation.attributes = std::move(attr); + return OkStatus(); + } +}; + +class Convolution2DTransposeBiasParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + TfLiteTransposeConvParams* tf_options = nullptr; + RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); + RETURN_IF_ERROR( + CheckStrides(tf_options->stride_height, tf_options->stride_width)); + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + auto* node = graph->NewNode(); + node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + const auto* params = reinterpret_cast( + tflite_node->custom_initial_data); + ConvolutionTransposedAttributes attr; + attr.stride = + params ? 
HW(params->stride_height, params->stride_width) : HW(1, 1); + + RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional + + UpdatePadding(params->padding, graph->FindInputs(node->id)[0]->tensor.shape, + &attr); + + node->operation.attributes = std::move(attr); + return OkStatus(); + } +}; + +class SpaceToBatchOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + auto* node = graph->NewNode(); + node->operation.type = ToString(OperationType::SPACE_TO_BATCH); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + SpaceToBatchAttributes sb_attr; + Tensor block; + RETURN_IF_ERROR(reader->ReadTensor(1, &block)); + if (block.shape.v != 2) { + return InternalError("Space has to be HxW."); + } + sb_attr.block.h = block.data[0]; + sb_attr.block.w = block.data[1]; + + Tensor padding; + RETURN_IF_ERROR(reader->ReadTensor(2, &padding)); + auto padding_shape = padding.shape; + + if (padding_shape.h != 2 && padding_shape.w != 2) { + return InternalError("Space has to be HxW."); + } + + sb_attr.padding.prepended.h = padding.data[0]; + sb_attr.padding.prepended.w = padding.data[2]; + + sb_attr.padding.appended.h = padding.data[1]; + sb_attr.padding.appended.w = padding.data[3]; + + node->operation.attributes = std::move(sb_attr); + return OkStatus(); + } +}; + +class BatchToSpaceOperationParser : public TFLiteOperationParser { + public: + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return OkStatus(); + } + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + auto* node = graph->NewNode(); + node->operation.type = ToString(OperationType::BATCH_TO_SPACE); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + BatchToSpaceAttributes bs_attr; + Tensor block; + RETURN_IF_ERROR(reader->ReadTensor(1, &block)); + if (block.shape.v != 2) { + return InternalError("Space has to be HxW."); + } + bs_attr.block.h = block.data[0]; + bs_attr.block.w = block.data[1]; + + Tensor crop; + RETURN_IF_ERROR(reader->ReadTensor(2, &crop)); + auto crop_shape = crop.shape; + if (crop_shape.h != 2 && crop_shape.w != 2) { + return InternalError("Space has to be HxW."); + } + + bs_attr.crop.prepended.h = crop.data[0]; + bs_attr.crop.prepended.w = crop.data[2]; + + bs_attr.crop.appended.h = crop.data[1]; + bs_attr.crop.appended.w = crop.data[3]; + + node->operation.attributes = std::move(bs_attr); + return OkStatus(); + } +}; + +class UnsupportedOperationParser : public TFLiteOperationParser { + Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + return UnimplementedError("Operation is not supported."); + } + + Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, GraphFloat32* graph, + ObjectReader* reader) final { + return UnimplementedError("Operation is not supported."); + } +}; + +std::unique_ptr NewOperationParser( + const TfLiteRegistration* registration) { + const auto builtin_code = 
registration->builtin_code; + const absl::string_view custom_name = registration->custom_name; + switch (builtin_code) { + case kTfLiteBuiltinAbs: + return make_unique(OperationType::ABS); + case kTfLiteBuiltinAdd: + return make_unique(); + case kTfLiteBuiltinAveragePool2d: + return make_unique(PoolingType::AVERAGE); + case kTfLiteBuiltinConcatenation: + return make_unique(); + case kTfLiteBuiltinConv2d: + return make_unique(); + case kTfLiteBuiltinCos: + return make_unique(OperationType::COS); + case kTfLiteBuiltinDepthwiseConv2d: + return make_unique(); + case kTfLiteBuiltinDiv: + return make_unique(OperationType::DIV); + case kTfLiteBuiltinFullyConnected: + return make_unique(); + case kTfLiteBuiltinLogistic: + return make_unique(OperationType::SIGMOID); + case kTfLiteBuiltinLog: + return make_unique(OperationType::LOG); + case kTfLiteBuiltinLstm: + return make_unique(); + case kTfLiteBuiltinMaxPool2d: + return make_unique(PoolingType::MAX); + case kTfLiteBuiltinMul: + return make_unique(); + case kTfLiteBuiltinPad: + return make_unique(); + case kTfLiteBuiltinPow: + return make_unique(OperationType::POW); + case kTfLiteBuiltinRelu: + return make_unique(0); + case kTfLiteBuiltinRelu6: + return make_unique(6); + case kTfLiteBuiltinLeakyRelu: + return make_unique(0); + case kTfLiteBuiltinPrelu: + return make_unique(); + case kTfLiteBuiltinReshape: + return make_unique(); + case kTfLiteBuiltinResizeBilinear: + return make_unique(); + case kTfLiteBuiltinRsqrt: + return make_unique(OperationType::RSQRT); + case kTfLiteBuiltinSin: + return make_unique(OperationType::SIN); + case kTfLiteBuiltinSoftmax: + return make_unique(); + case kTfLiteBuiltinStridedSlice: + return make_unique(); + case kTfLiteBuiltinSqrt: + return make_unique(OperationType::SQRT); + case kTfLiteBuiltinSquare: + return make_unique(OperationType::SQUARE); + case kTfLiteBuiltinSquaredDifference: + return make_unique( + OperationType::SQUARED_DIFF); + case kTfLiteBuiltinSub: + return make_unique(OperationType::SUB); + case kTfLiteBuiltinTanh: + return make_unique(OperationType::TANH); + case kTfLiteBuiltinTransposeConv: + return make_unique(); + + case kTfLiteBuiltinCustom: + if (custom_name == "Convolution2DTransposeBias") { + return make_unique(); + } + if (custom_name == "MaxPoolingWithArgmax2D") { + return make_unique(PoolingType::MAX); + } + if (custom_name == "MaxUnpooling2D") { + return make_unique(); + } + break; + } + return make_unique(); +} + +} // namespace + +Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, + TensorRefFloat32* tensor_ref) { + tensor_ref->type = ToDataType(tflite_tensor.type); + const TfLiteIntArray* dims = tflite_tensor.dims; + switch (dims->size) { + case 1: + tensor_ref->shape = BHWC(dims->data[0], 1, 1, 1); + break; + case 2: + tensor_ref->shape = BHWC(dims->data[0], 1, 1, dims->data[1]); + break; + case 3: + tensor_ref->shape = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]); + break; + case 4: + tensor_ref->shape = + BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]); + break; + default: + return InvalidArgumentError(StrCat( + "Tensor ref has unsupported number of dimensions: ", dims->size)); + } + return OkStatus(); +} + +Status IsSupported(const TfLiteContext* context, TfLiteNode* node, + const TfLiteRegistration* registration) { + return NewOperationParser(registration) + ->IsSupported(context, node, registration); +} + +bool IsAllFloatTensors(const TfLiteContext* context, + const TfLiteIntArray* array) { + for (int i = 0; i < array->size; ++i) { 
+ const TfLiteTensor* t = context->tensors + array->data[i]; + if (t->allocation_type == kTfLiteArenaRw && t->type != kTfLiteFloat32) { + return false; + } + } + return true; +} + +std::string GetOpNameByRegistration(const TfLiteRegistration* registration) { + auto op = registration->builtin_code; + std::string result = + EnumNameBuiltinOperator(static_cast(op)); + if (op == kTfLiteBuiltinCustom) { + result += " " + std::string(registration->custom_name); + } + return result; +} + +Status GetNodeAndRegistration(TfLiteContext* context, int node_id, + TfLiteNode** tflite_node, + TfLiteRegistration** registration) { + if (context->GetNodeAndRegistration(context, node_id, tflite_node, + registration) != kTfLiteOk) { + return InvalidArgumentError( + StrCat("Couldn't get node and registration info for op: ", node_id)); + } + return OkStatus(); +} + +// TODO(impjdi): Check number of input/output tensors and their dimensions. +// TODO(impjdi): Check ops' parameters. +TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { + TfLiteIntArray* execution_plan = nullptr; + if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) { + context->ReportError(context, "Unable to get graph execution plan."); + return nullptr; + } + TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size); + subgraph->size = 0; + std::set errors; + for (int i = 0; i < execution_plan->size; ++i) { + TfLiteNode* node = nullptr; + TfLiteRegistration* registration = nullptr; + auto status = GetNodeAndRegistration(context, i, &node, ®istration); + if (!status.ok()) { + context->ReportError(context, status.error_message().c_str()); + return nullptr; + } + status = IsSupported(context, node, registration); + if (status.ok() && + // TODO(eignasheva): resolve sub operation support for metal delegate + // registration->builtin_code != kTfLiteBuiltinSub && + IsAllFloatTensors(context, node->inputs) && + IsAllFloatTensors(context, node->outputs)) { + if (errors.empty()) subgraph->data[subgraph->size++] = i; + } else { + errors.insert(GetOpNameByRegistration(registration) + ": " + + status.error_message()); + } + } + if (!errors.empty()) { + std::string unsupported = absl::StrJoin(errors, "\n"); + std::string error_message = + "Next operations are not supported by GPU delegate:\n" + unsupported + + "\nFirst " + std::to_string(subgraph->size) + + " operations will run on the GPU, and the remaining " + + std::to_string(execution_plan->size - subgraph->size) + " on the CPU."; + context->ReportError(context, error_message.c_str()); + } + return subgraph; +} + +Status BuildModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph) { + std::vector> operations; + for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) { + TfLiteNode* tflite_node = nullptr; + TfLiteRegistration* registration = nullptr; + RETURN_IF_ERROR(GetNodeAndRegistration( + context, delegate_params->nodes_to_replace->data[i], &tflite_node, + ®istration)); + auto op_parser = NewOperationParser(registration); + if (!op_parser) { + return UnimplementedError( + StrCat("Operation ", registration->builtin_code, "(", + registration->custom_name, + ") is not supported by TFLite GPU Delegate.")); + } + operations.push_back(std::move(op_parser)); + } + std::vector*> tensor_to_value(context->tensors_size, + nullptr); + for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) { + TfLiteNode* tflite_node = nullptr; + TfLiteRegistration* registration = nullptr; + RETURN_IF_ERROR(GetNodeAndRegistration( + 
context, delegate_params->nodes_to_replace->data[i], &tflite_node, + ®istration)); + ObjectReader reader(graph, context, tflite_node, &tensor_to_value); + RETURN_IF_ERROR( + operations[i]->Parse(tflite_node, registration, graph, &reader)); + } + return OkStatus(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h new file mode 100644 index 00000000000..09026b89af9 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model_builder.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ + +#include +#include + +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { + +// Validates which operations are supported and returns array of operations to +// replace with GPU kernels. The caller must free the pointer on TfLiteIntArray. +TfLiteIntArray* GetOpsToReplace(TfLiteContext* context); + +// Extracts TFLite delegate execution plan from the input TFLite context and +// converts it into generic graph format. +Status BuildModel(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params, + GraphFloat32* graph); + +// Module-internal converter, exposed for unit testing purpose only. +Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor, + TensorRefFloat32* tensor_ref); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_ diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc new file mode 100644 index 00000000000..584cadc6d5d --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/model_builder.h" + +#include +#include +#include "tensorflow/lite/c/c_api_internal.h" + +namespace tflite { +namespace gpu { +namespace { + +TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank0) { + TfLiteTensor tflite_tensor; + tflite_tensor.type = TfLiteType::kTfLiteFloat32; + tflite_tensor.dims = TfLiteIntArrayCreate(1); + tflite_tensor.dims->data[0] = 4; + TensorRefFloat32 tensor_ref; + const auto status = + ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); + TfLiteIntArrayFree(tflite_tensor.dims); + ASSERT_TRUE(status.ok()); + EXPECT_EQ(tensor_ref.type, DataType::FLOAT32); + EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 1, 1)); +} + +TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank1) { + TfLiteTensor tflite_tensor; + tflite_tensor.type = TfLiteType::kTfLiteInt32; + tflite_tensor.dims = TfLiteIntArrayCreate(2); + tflite_tensor.dims->data[0] = 4; + tflite_tensor.dims->data[1] = 5; + TensorRefFloat32 tensor_ref; + const auto status = + ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); + TfLiteIntArrayFree(tflite_tensor.dims); + ASSERT_TRUE(status.ok()); + EXPECT_EQ(tensor_ref.type, DataType::INT32); + EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 1, 5)); +} + +TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank2) { + TfLiteTensor tflite_tensor; + tflite_tensor.type = TfLiteType::kTfLiteInt64; + tflite_tensor.dims = TfLiteIntArrayCreate(3); + tflite_tensor.dims->data[0] = 4; + tflite_tensor.dims->data[1] = 5; + tflite_tensor.dims->data[2] = 6; + TensorRefFloat32 tensor_ref; + const auto status = + ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); + TfLiteIntArrayFree(tflite_tensor.dims); + ASSERT_TRUE(status.ok()); + EXPECT_EQ(tensor_ref.type, DataType::INT64); + EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 5, 6)); +} + +TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank3) { + TfLiteTensor tflite_tensor; + tflite_tensor.type = TfLiteType::kTfLiteUInt8; + tflite_tensor.dims = TfLiteIntArrayCreate(4); + tflite_tensor.dims->data[0] = 4; + tflite_tensor.dims->data[1] = 5; + tflite_tensor.dims->data[2] = 6; + tflite_tensor.dims->data[3] = 7; + TensorRefFloat32 tensor_ref; + const auto status = + ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); + TfLiteIntArrayFree(tflite_tensor.dims); + ASSERT_TRUE(status.ok()); + EXPECT_EQ(tensor_ref.type, DataType::UINT8); + EXPECT_EQ(tensor_ref.shape, BHWC(4, 5, 6, 7)); +} + +TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankLT0) { + TfLiteTensor tflite_tensor; + tflite_tensor.type = TfLiteType::kTfLiteFloat32; + tflite_tensor.dims = TfLiteIntArrayCreate(0); + TensorRefFloat32 tensor_ref; + const auto status = + ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); + TfLiteIntArrayFree(tflite_tensor.dims); + // TODO(b/130054481): Cover scalar. 
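+ // A zero-length dims array is currently rejected rather than treated as a
+ // rank-0 scalar.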
+ EXPECT_FALSE(status.ok()); +} + +TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) { + TfLiteTensor tflite_tensor; + tflite_tensor.type = TfLiteType::kTfLiteFloat32; + tflite_tensor.dims = TfLiteIntArrayCreate(5); + TensorRefFloat32 tensor_ref; + const auto status = + ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref); + TfLiteIntArrayFree(tflite_tensor.dims); + EXPECT_FALSE(status.ok()); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/model_test.cc b/tensorflow/lite/delegates/gpu/common/model_test.cc new file mode 100644 index 00000000000..ff591469675 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model_test.cc @@ -0,0 +1,365 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/model.h" + +#include +#include + +#include +#include + +namespace tflite { +namespace gpu { +namespace { + +using ::testing::UnorderedElementsAre; + +TEST(Model, SingleNode) { + // graph_input -> node -> graph_output + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output = graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); + + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node)); + EXPECT_THAT(graph.values(), UnorderedElementsAre(graph_input, graph_output)); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.FindInputs(node->id), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.FindOutputs(node->id), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.FindConsumers(graph_input->id), UnorderedElementsAre(node)); + EXPECT_THAT(graph.FindProducer(graph_output->id), ::testing::Eq(node)); + EXPECT_THAT(graph.FindConsumers(graph_output->id), UnorderedElementsAre()); + EXPECT_THAT(graph.FindProducer(graph_input->id), ::testing::Eq(nullptr)); +} + +TEST(Model, SingleNodeMultipleOutputs) { + // graph_input -> node -> (graph_output1, graph_output2) + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output1 = graph.NewValue(); + Value* graph_output2 = graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(node->id, graph_output1->id).ok()); + ASSERT_TRUE(graph.SetProducer(node->id, graph_output2->id).ok()); + EXPECT_THAT(graph.FindOutputs(node->id), + UnorderedElementsAre(graph_output1, graph_output2)); + EXPECT_THAT(graph.FindProducer(graph_output1->id), ::testing::Eq(node)); + EXPECT_THAT(graph.FindProducer(graph_output2->id), ::testing::Eq(node)); +} + +TEST(Model, SetSameConsumer) { + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* graph_input = 
graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); + EXPECT_FALSE(graph.AddConsumer(node->id, graph_input->id).ok()); +} + +TEST(Model, RemoveConsumer) { + // (graph_input1, graph_input2) -> node + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* graph_input1 = graph.NewValue(); + Value* graph_input2 = graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node->id, graph_input1->id).ok()); + ASSERT_TRUE(graph.AddConsumer(node->id, graph_input2->id).ok()); + EXPECT_THAT(graph.FindConsumers(graph_input1->id), + UnorderedElementsAre(node)); + EXPECT_THAT(graph.FindConsumers(graph_input2->id), + UnorderedElementsAre(node)); + EXPECT_THAT(graph.FindInputs(node->id), + UnorderedElementsAre(graph_input1, graph_input2)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre()); + + // Now remove graph_input1 + ASSERT_TRUE(graph.RemoveConsumer(node->id, graph_input1->id).ok()); + EXPECT_THAT(graph.FindConsumers(graph_input1->id), UnorderedElementsAre()); + EXPECT_THAT(graph.FindInputs(node->id), UnorderedElementsAre(graph_input2)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_input1)); + + // Can not remove it twice + ASSERT_FALSE(graph.RemoveConsumer(node->id, graph_input1->id).ok()); +} + +TEST(Model, SetSameProducer) { + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* graph_output = graph.NewValue(); + ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); + EXPECT_FALSE(graph.SetProducer(node->id, graph_output->id).ok()); +} + +TEST(Model, RemoveProducer) { + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* graph_output = graph.NewValue(); + + ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre()); + EXPECT_THAT(graph.FindProducer(graph_output->id), ::testing::Eq(node)); + + ASSERT_TRUE(graph.RemoveProducer(graph_output->id).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.FindProducer(graph_output->id), ::testing::Eq(nullptr)); + + // Can not remove producer twice + ASSERT_FALSE(graph.RemoveProducer(graph_output->id).ok()); +} + +TEST(Model, RemoveSimpleNodeDegenerateCase) { + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output = graph.NewValue(); + + ASSERT_TRUE(graph.AddConsumer(node->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(node->id, graph_output->id).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node)); + + ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, node).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre()); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre()); + EXPECT_THAT(graph.nodes(), UnorderedElementsAre()); +} + +TEST(Model, RemoveSimpleNodeNoPreviousNode) { + GraphFloat32 graph; + Node* simple_node = graph.NewNode(); + Node* consumer_node = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output = graph.NewValue(); + Value* value = graph.NewValue(); + + ASSERT_TRUE(graph.AddConsumer(simple_node->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(simple_node->id, value->id).ok()); + ASSERT_TRUE(graph.AddConsumer(consumer_node->id, value->id).ok()); + ASSERT_TRUE(graph.SetProducer(consumer_node->id, graph_output->id).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), 
UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(simple_node, consumer_node)); + + ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(consumer_node)); +} + +TEST(Model, RemoveSimpleNodeNoAfterNodes) { + GraphFloat32 graph; + Node* simple_node = graph.NewNode(); + Node* producer_node = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output = graph.NewValue(); + Value* value = graph.NewValue(); + + ASSERT_TRUE(graph.AddConsumer(simple_node->id, value->id).ok()); + ASSERT_TRUE(graph.SetProducer(simple_node->id, graph_output->id).ok()); + ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(producer_node->id, value->id).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(simple_node, producer_node)); + + ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(producer_node)); +} + +TEST(Model, RemoveSimpleNodeGeneralCase) { + GraphFloat32 graph; + Node* simple_node = graph.NewNode(); + Node* producer_node = graph.NewNode(); + Node* consumer_node = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output = graph.NewValue(); + Value* value0 = graph.NewValue(); + Value* value1 = graph.NewValue(); + + ASSERT_TRUE(graph.AddConsumer(producer_node->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(producer_node->id, value0->id).ok()); + ASSERT_TRUE(graph.AddConsumer(simple_node->id, value0->id).ok()); + ASSERT_TRUE(graph.SetProducer(simple_node->id, value1->id).ok()); + ASSERT_TRUE(graph.AddConsumer(consumer_node->id, value1->id).ok()); + ASSERT_TRUE(graph.SetProducer(consumer_node->id, graph_output->id).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.nodes(), + UnorderedElementsAre(simple_node, producer_node, consumer_node)); + + ASSERT_TRUE(RemoveOneInputOneOutputNode(&graph, simple_node).ok()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.nodes(), + UnorderedElementsAre(producer_node, consumer_node)); +} + +TEST(Model, CircularDependency) { + { + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* value = graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node->id, value->id).ok()); + EXPECT_FALSE(graph.SetProducer(node->id, value->id).ok()); + } + { + GraphFloat32 graph; + Node* node = graph.NewNode(); + Value* value = graph.NewValue(); + ASSERT_TRUE(graph.SetProducer(node->id, value->id).ok()); + EXPECT_FALSE(graph.AddConsumer(node->id, value->id).ok()); + } +} + +TEST(Model, ReassignValue) { + // Before: + // graph_input -> node1 -> graph_output + // \ -> node2 + GraphFloat32 graph; + Node* node1 = graph.NewNode(); + Node* node2 = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output = graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok()); + 
ASSERT_TRUE(graph.SetProducer(node1->id, graph_output->id).ok()); + ASSERT_TRUE(graph.AddConsumer(node2->id, graph_input->id).ok()); + + // After: + // graph_input -> node1 + // \ -> node2 -> graph_output + ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok()); + + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2)); + EXPECT_THAT(graph.FindInputs(node1->id), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.FindOutputs(node1->id), UnorderedElementsAre()); + EXPECT_THAT(graph.FindOutputs(node2->id), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.FindConsumers(graph_input->id), + UnorderedElementsAre(node1, node2)); + EXPECT_THAT(graph.FindProducer(graph_output->id), ::testing::Eq(node2)); + EXPECT_THAT(graph.FindConsumers(graph_output->id), UnorderedElementsAre()); +} + +TEST(Model, DeleteValue) { + // graph_input -> node1 -> value -> node2 -> graph_output + GraphFloat32 graph; + Node* node1 = graph.NewNode(); + Node* node2 = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output = graph.NewValue(); + Value* value = graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok()); + ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok()); + ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok()); + + EXPECT_THAT(graph.values(), + UnorderedElementsAre(graph_input, graph_output, value)); + EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2)); + EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(node1)); + EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(value)); + EXPECT_THAT(graph.FindOutputs(node1->id), UnorderedElementsAre(value)); + + ASSERT_TRUE(graph.DeleteValue(value->id).ok()); + value = nullptr; + EXPECT_THAT(graph.values(), UnorderedElementsAre(graph_input, graph_output)); + EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre()); + EXPECT_THAT(graph.FindOutputs(node1->id), UnorderedElementsAre()); + + ASSERT_TRUE(graph.DeleteValue(graph_input->id).ok()); + graph_input = nullptr; + EXPECT_THAT(graph.values(), UnorderedElementsAre(graph_output)); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre()); + EXPECT_THAT(graph.FindInputs(node1->id), UnorderedElementsAre()); + + ASSERT_TRUE(graph.DeleteValue(graph_output->id).ok()); + graph_output = nullptr; + EXPECT_THAT(graph.values(), UnorderedElementsAre()); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre()); + EXPECT_THAT(graph.FindOutputs(node2->id), UnorderedElementsAre()); +} + +TEST(Model, DeleteNode) { + // graph_input -> node1 -> value -> node2 -> graph_output + // \-> node3 -> graph_output2 + GraphFloat32 graph; + Node* node1 = graph.NewNode(); + Node* node2 = graph.NewNode(); + Node* node3 = graph.NewNode(); + Value* graph_input = graph.NewValue(); + Value* graph_output = graph.NewValue(); + Value* graph_output2 = graph.NewValue(); + Value* value = graph.NewValue(); + ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok()); + ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok()); + ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok()); + ASSERT_TRUE(graph.AddConsumer(node3->id, value->id).ok()); + ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok()); + ASSERT_TRUE(graph.SetProducer(node3->id, graph_output2->id).ok()); + + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2, node3)); + EXPECT_THAT(graph.inputs(), 
UnorderedElementsAre(graph_input)); + EXPECT_THAT(graph.outputs(), + UnorderedElementsAre(graph_output, graph_output2)); + EXPECT_THAT(graph.FindConsumers(value->id), + UnorderedElementsAre(node2, node3)); + EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(node1)); + EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(value)); + EXPECT_THAT(graph.FindInputs(node3->id), UnorderedElementsAre(value)); + + // graph_input -> node1 -> value -> node2 -> graph_output + // graph_output2 + ASSERT_TRUE(graph.DeleteNode(node3->id).ok()); + node3 = nullptr; + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node1, node2)); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input, graph_output2)); + EXPECT_THAT(graph.outputs(), + UnorderedElementsAre(graph_output, graph_output2)); + EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2)); + + // value -> node2 -> graph_output + // graph_input + // graph_output2 + ASSERT_TRUE(graph.DeleteNode(node1->id).ok()); + node1 = nullptr; + EXPECT_THAT(graph.nodes(), UnorderedElementsAre(node2)); + EXPECT_THAT(graph.inputs(), + UnorderedElementsAre(value, graph_output2, graph_input)); + EXPECT_THAT(graph.outputs(), + UnorderedElementsAre(graph_input, graph_output, graph_output2)); + EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2)); + EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(nullptr)); + + ASSERT_TRUE(graph.DeleteNode(node2->id).ok()); + node2 = nullptr; + EXPECT_THAT(graph.nodes(), UnorderedElementsAre()); + EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_output, graph_output2, + graph_input, value)); + EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output, graph_output2, + graph_input, value)); + EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre()); + EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(nullptr)); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/model_transformer.cc b/tensorflow/lite/delegates/gpu/common/model_transformer.cc new file mode 100644 index 00000000000..81287dd61e5 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model_transformer.cc @@ -0,0 +1,197 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" + +namespace tflite { +namespace gpu { + +bool ModelTransformer::Apply(const std::string& name, + SequenceTransformation* transformation) { + // Seed transformations with starting node. Each node may start a chain of + // transformations. 
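+  // For example, for a chain graph_input -> A -> B -> graph_output, only A
+  // (the consumer of a graph input) is seeded here; B is visited later by the
+  // sliding-window walk in ApplyStartingWithNode.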
+ for (auto input : graph_->inputs()) { + for (auto node : graph_->FindConsumers(input->id)) { + AddNodeToProcess(node); + } + } + while (!to_process_.empty()) { + auto node = graph_->GetNode(to_process_.front()); + if (node) { + if (!ApplyStartingWithNode(name, transformation, node)) { + return false; + } + } + to_process_.pop_front(); + } + processed_.clear(); + return true; +} + +bool ModelTransformer::Apply(const std::string& name, + NodeTransformation* transformation) { + // Apply a transformation only to nodes that are present in the graph before + // transformation. + std::vector nodes; + for (auto node : graph_->nodes()) { + nodes.push_back(node->id); + } + for (auto node_id : nodes) { + auto node = graph_->GetNode(node_id); + if (!node) { + continue; + } + auto result = transformation->ApplyToNode(node, graph_); + if (result.status == TransformStatus::INVALID) { + return false; + } + if (reporter_) { + if (result.status == TransformStatus::APPLIED) { + reporter_->AppliedTransformation(name, std::to_string(node_id), + result.message); + } + if (result.status == TransformStatus::DECLINED) { + reporter_->DeclinedTransformation(name, std::to_string(node_id), + result.message); + } + } + } + return true; +} + +bool ModelTransformer::ApplyStartingWithNode( + const std::string& name, SequenceTransformation* transformation, + Node* begin) { + int expected_sequence_length = transformation->ExpectedSequenceLength(); + + std::deque sequence; + std::vector nodes; + nodes.reserve(transformation->ExpectedSequenceLength()); + sequence.push_back(begin->id); + + // Go over nodes with sequence sliding window of size + // expected_sequence_length until a node with multiple dependents is found. + while (true) { + // Apply transformation if possible. + if (sequence.size() == expected_sequence_length) { + nodes.clear(); + for (NodeId id : sequence) { + // Nodes present in sequence should be present in a graph. If they are + // not, then this transformation changes a graph but didn't say it. + Node* node = graph_->GetNode(id); + if (node == nullptr) { + return false; + } + nodes.push_back(node); + } + + NodeId first_in_sequence = sequence.front(); + auto preceding_node = + graph_->FindProducer(graph_->FindInputs(first_in_sequence)[0]->id); + auto result = transformation->ApplyToNodesSequence(nodes, graph_); + if (result.status == TransformStatus::INVALID) { + // graph is broken now. + return false; + } + if (result.status == TransformStatus::DECLINED) { + if (reporter_) { + reporter_->DeclinedTransformation(name, absl::StrJoin(sequence, "+"), + result.message); + } + } else if (result.status == TransformStatus::APPLIED) { + if (reporter_) { + reporter_->AppliedTransformation(name, absl::StrJoin(sequence, "+"), + result.message); + } + // Also remove first node of a sequence from a set of processed node. + // Out of all nodes in a sequence only first one may have been added + // to "processed" set because other nodes do not have more than one + // dependent. However, if a sequence is changed, then processing needs + // to be restarted again. + processed_.erase(first_in_sequence); + // Transformation was successful. Restart sequence from the node that + // precedes current sequence. + if (preceding_node) { + processed_.erase(preceding_node->id); + AddNodeToProcess(preceding_node); + } else { + // This is the first node in the graph. Re-seed transformation. 
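+          // The applied change may have replaced or removed that first node,
+          // so the walk restarts from the (possibly new) consumers of the
+          // graph inputs.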
+ for (auto input : graph_->inputs()) { + for (auto node : graph_->FindConsumers(input->id)) { + AddNodeToProcess(node); + } + } + } + return true; + } + } + + // Try to extend current sequence. + Node* next_node_in_sequence = nullptr; + bool has_multiple_children = false; + + // Check that all outputs from last node are consumed by a single node. + for (auto output_value : graph_->FindOutputs(sequence.back())) { + for (auto dependent : graph_->FindConsumers(output_value->id)) { + if (has_multiple_children) { + AddNodeToProcess(dependent); + } else if (next_node_in_sequence == nullptr) { + next_node_in_sequence = dependent; + } else if (next_node_in_sequence != dependent) { + // There are more than two nodes depend on the output from end node, + // therefore here a sequence stops and new will start. Push all such + // nodes. + has_multiple_children = true; + AddNodeToProcess(dependent); + AddNodeToProcess(next_node_in_sequence); + } + } + } + + // Now check that next node has inputs only produced by the last node. + if (!has_multiple_children && next_node_in_sequence) { + for (auto input : graph_->FindInputs(next_node_in_sequence->id)) { + auto producer = graph_->FindProducer(input->id); + if (producer == nullptr || producer->id != sequence.back()) { + has_multiple_children = true; + AddNodeToProcess(next_node_in_sequence); + break; + } + } + } + + if (has_multiple_children || next_node_in_sequence == nullptr) { + // reached end of this transformation sequence. + return true; + } + + sequence.push_back(next_node_in_sequence->id); + // Decrease sequence until it matches expected length. + if (sequence.size() > expected_sequence_length) { + sequence.pop_front(); + } + } + return true; +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/model_transformer.h b/tensorflow/lite/delegates/gpu/common/model_transformer.h new file mode 100644 index 00000000000..d82a6a687ca --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/model_transformer.h @@ -0,0 +1,146 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_ + +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" + +namespace tflite { +namespace gpu { + +class TransformationReporter; + +struct TransformationContext { + GraphFloat32* graph; + TransformationReporter* reporter; +}; + +enum class TransformStatus { + // Transformation was not applied due to trivial conditions mismatch. + // + // This is different from DECLINED code below that provides in-depth + // explanation why a transformation that could have been applied but was not + // due to some issues. + SKIPPED, + + // Transformation was declined, therefore, a model was not modified. 
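+  // For example, a transformation might match the expected node types but
+  // decline because an intermediate tensor is also consumed by another node.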
+  DECLINED,
+
+  // Transformation was applied successfully.
+  APPLIED,
+
+  // Transformation may be partially applied, leaving the model in an invalid
+  // state. This error should be considered unrecoverable.
+  INVALID,
+};
+
+struct TransformResult {
+  TransformStatus status;
+  std::string message;
+};
+
+// Class responsible for applying a transformation to a single node.
+class NodeTransformation {
+ public:
+  virtual ~NodeTransformation() = default;
+
+  virtual TransformResult ApplyToNode(Node* node, GraphFloat32* graph) = 0;
+};
+
+// Class responsible for applying a transformation to a sequence of nodes.
+// Nodes are guaranteed to depend on each other in a chain, with no other
+// nodes consuming the intermediate values.
+class SequenceTransformation {
+ public:
+  virtual ~SequenceTransformation() = default;
+
+  // @return the number of nodes in a sequence to which this transformation
+  // applies.
+  virtual int ExpectedSequenceLength() const = 0;
+
+  // Applies transformations to a sequence of nodes. The transformation
+  // implementation is free to manipulate the sequence nodes, including adding
+  // and/or deleting nodes. If nodes at the end and/or beginning of the
+  // sequence were updated, referential consistency should be maintained by
+  // updating the relevant references in nodes that precede this sequence or
+  // depend on the last node of the sequence.
+  virtual TransformResult ApplyToNodesSequence(
+      const std::vector<Node*>& sequence, GraphFloat32* graph) = 0;
+};
+
+// A class that accumulates decisions or updates made by transformations.
+class TransformationReporter {
+ public:
+  virtual ~TransformationReporter() = default;
+
+  virtual void DeclinedTransformation(const std::string& transformation,
+                                      const std::string& node_ids,
+                                      const std::string& message) = 0;
+
+  virtual void AppliedTransformation(const std::string& transformation,
+                                     const std::string& node_ids,
+                                     const std::string& message) = 0;
+};
+
+// A class designed to perform model transformations.
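+//
+// A minimal usage sketch (illustrative; FuseBiasIntoConv is a hypothetical
+// NodeTransformation, not defined in this library):
+//
+//   NullTransformationReporter reporter;
+//   ModelTransformer transformer(&graph, &reporter);
+//   FuseBiasIntoConv fusion;
+//   if (!transformer.Apply("fuse_bias_into_conv", &fusion)) {
+//     // The graph is in a broken state and must not be used any more.
+//   }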
+class ModelTransformer { + public: + ModelTransformer(GraphFloat32* graph, TransformationReporter* reporter) + : graph_(graph), reporter_(reporter) {} + + // @return false if a graph is in the broken states can not be used any more + bool Apply(const std::string& name, SequenceTransformation* transformation); + + // @return false if a graph is in the broken states can not be used any more + bool Apply(const std::string& name, NodeTransformation* transformation); + + private: + bool ApplyStartingWithNode(const std::string& name, + SequenceTransformation* transformation, + Node* begin); + + void AddNodeToProcess(Node* node) { + if (node && processed_.insert(node->id).second) { + to_process_.push_back(node->id); + } + } + + GraphFloat32* graph_; + TransformationReporter* reporter_; + + std::deque to_process_; + std::unordered_set processed_; +}; + +class NullTransformationReporter : public TransformationReporter { + public: + void DeclinedTransformation(const std::string& transformation, + const std::string& nodes_id, + const std::string& message) override {} + + void AppliedTransformation(const std::string& transformation, + const std::string& nodes_id, + const std::string& message) override {} +}; + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_TRANSFORMER_H_ diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc new file mode 100644 index 00000000000..f7f9d1b7351 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/operations.cc @@ -0,0 +1,414 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/operations.h" + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { + +Padding2D& Padding2D::operator=(const Padding2D& value) { + prepended = value.prepended; + appended = value.appended; + return *this; +} + +bool Padding2D::operator==(const Padding2D& value) { + return this->prepended == value.prepended && this->appended == value.appended; +} + +bool Padding2D::operator!=(const Padding2D& value) { return !(*this == value); } + +Padding2D& Padding2D::operator-(const Padding2D& value) { + prepended.h -= value.prepended.h; + prepended.w -= value.prepended.w; + appended.h -= value.appended.h; + appended.w -= value.appended.w; + return *this; +} + +std::string ToString(enum OperationType op) { + switch (op) { + case OperationType::UNKNOWN: + break; + case OperationType::ABS: + return "abs"; + case OperationType::ADD: + return "add"; + case OperationType::APPLY_MASK: + return "apply_mask"; + case OperationType::BATCH_TO_SPACE: + return "batch_to_space"; + case OperationType::POOLING_2D: + return "pooling_2d"; + case OperationType::MAX_UNPOOLING_2D: + return "max_unpooling"; + case OperationType::BATCH_NORMALIZATION: + return "batch_normalization"; + case OperationType::CONCAT: + return "concat"; + case OperationType::CONST: + return "const"; + case OperationType::CONVOLUTION_2D: + return "convolution_2d"; + case OperationType::COS: + return "cos"; + case OperationType::DEPTHWISE_CONVOLUTION: + return "depthwise_convolution"; + case OperationType::DIV: + return "div"; + case OperationType::LOG: + return "log"; + case OperationType::MUL: + return "mul"; + case OperationType::PAD: + return "pad"; + case OperationType::POW: + return "pow"; + case OperationType::PRELU: + return "prelu"; + case OperationType::RELU: + return "relu"; + case OperationType::RESIZE: + return "resize"; + case OperationType::RESHAPE: + return "reshape"; + case OperationType::RSQRT: + return "rsqrt"; + case OperationType::SIGMOID: + return "sigmoid"; + case OperationType::SIN: + return "sin"; + case OperationType::SLICE: + return "slice"; + case OperationType::SOFT_MAX: + return "soft_max"; + case OperationType::SPACE_TO_BATCH: + return "space_to_batch"; + case OperationType::SQRT: + return "sqrt"; + case OperationType::SQUARE: + return "square"; + case OperationType::SQUARED_DIFF: + return "squared_diff"; + case OperationType::SUB: + return "subtract"; + case OperationType::UPSAMPLE_2D: + return "upsample_2d"; + case OperationType::CONVOLUTION_TRANSPOSED: + return "convolution_transposed"; + case OperationType::MULTIPLY_SCALAR: + return "multiply_scalar"; + case OperationType::FULLY_CONNECTED: + return "fully_connected"; + case OperationType::TANH: + return "tanh"; + case OperationType::LSTM: + return "lstm"; + } + return "unknown_operation"; +} + +OperationType OperationTypeFromString(const std::string& name) { + static const auto operations = + new std::unordered_map({ + {"abs", OperationType::ABS}, + {"add", OperationType::ADD}, + {"apply_mask", OperationType::APPLY_MASK}, + {"batch_normalization", OperationType::BATCH_NORMALIZATION}, + {"concat", OperationType::CONCAT}, + {"const", OperationType::CONST}, + {"convolution_2d", OperationType::CONVOLUTION_2D}, + {"convolution_transposed", OperationType::CONVOLUTION_TRANSPOSED}, + {"cos", OperationType::COS}, + 
{"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION}, + {"fully_connected", OperationType::FULLY_CONNECTED}, + {"log", OperationType::LOG}, + {"lstm", OperationType::LSTM}, + {"max_unpooling", OperationType::MAX_UNPOOLING_2D}, + {"mul", OperationType::MUL}, + {"multiply_scalar", OperationType::MULTIPLY_SCALAR}, + {"pad", OperationType::PAD}, + {"pooling_2d", OperationType::POOLING_2D}, + {"prelu", OperationType::PRELU}, + {"relu", OperationType::RELU}, + {"resize", OperationType::RESIZE}, + {"reshape", OperationType::RESHAPE}, + {"rsqrt", OperationType::RSQRT}, + {"sigmoid", OperationType::SIGMOID}, + {"sin", OperationType::SIN}, + {"slice", OperationType::SLICE}, + {"soft_max", OperationType::SOFT_MAX}, + {"sqrt", OperationType::SQRT}, + {"square", OperationType::SQUARE}, + {"subtract", OperationType::SUB}, + {"tanh", OperationType::TANH}, + {"upsample_2d", OperationType::UPSAMPLE_2D}, + }); + auto op = operations->find(name); + return op == operations->end() ? OperationType::UNKNOWN : op->second; +} + +namespace { + +template +T IntegralDivideRoundUp(T n, T divisor) { + return (n - 1) / divisor + 1; +} + +int32_t CalculateOutputSizeBeforeStrides(int32_t input, int32_t kernel, + int32_t padding, int32_t dilation) { + const int32_t dilated_kernel = (kernel - 1) * dilation + 1; + return input + padding - dilated_kernel + 1; +} + +template +int32_t CalculateOutputWithoutStrides(const BHWC& input, + const Convolution2DAttributes& attr) { + return CalculateOutputSizeBeforeStrides( + input.get(), attr.weights.shape.get(), + attr.padding.prepended.get() + attr.padding.appended.get(), + attr.dilations.get()); +} + +template +int32_t CalculateOutputWithoutStrides(const BHWC& input, + const Pooling2DAttributes& attr) { + return CalculateOutputSizeBeforeStrides( + input.get(), attr.kernel.get(), + attr.padding.prepended.get() + attr.padding.appended.get(), + /*dilation=*/1); +} + +template +int32_t CalculateOutput(const BHWC& input, + const ConvolutionTransposedAttributes& attr) { + return (input.get() - 1) * attr.stride.get() - + (attr.padding.prepended.get() + attr.padding.appended.get()) + + attr.weights.shape.get() + attr.adjacent.get(); +} + +inline int32_t StridedSize(int32_t size, int32_t stride) { + return stride == 0 ? -1 : IntegralDivideRoundUp(size, stride); +} + +template +int32_t CalculateOutput(const BHWC& input, const AttrT& attr) { + return StridedSize(CalculateOutputWithoutStrides(input, attr), + attr.strides.template get()); +} + +int32_t CalculateSamePadding(int32_t input, int32_t kernel, int32_t dilation, + int32_t stride) { + const int32_t dilated_kernel = (kernel - 1) * dilation + 1; + return std::max(0, dilated_kernel - (input - 1) % stride - 1); +} + +// Returns a padding that should be present to make sure image size stays +// the same. 
+template +int32_t CalculateSamePadding(const BHWC& input, + const Convolution2DAttributes& attr) { + return CalculateSamePadding( + input.get(), attr.weights.shape.get(), + attr.dilations.get(), attr.strides.get()); +} + +template +int32_t CalculateSamePadding(const BHWC& input, + const ConvolutionTransposedAttributes& attr) { + return CalculateSamePadding(input.get(), + attr.weights.shape.get(), + /*dilation=*/1, attr.stride.get()); +} + +template +int32_t CalculateSamePadding(const BHWC& input, + const Pooling2DAttributes& attr) { + return CalculateSamePadding(input.get(), attr.kernel.get(), + /*dilation=*/1, attr.strides.get()); +} + +template +int32_t CalculateSamePadding(const BHWC& input, + const MaxUnpooling2DAttributes& attr) { + return CalculateSamePadding(input.get(), attr.kernel.get(), + /*dilation=*/1, attr.strides.get()); +} + +Padding2D MakeSamePadding(const BHWC& input, + const ConvolutionTransposedAttributes& attr) { + int32_t padding_height = CalculateSamePadding(input, attr); + int32_t padding_width = CalculateSamePadding(input, attr); + Padding2D padding; + padding.prepended = HW(padding_height / 2, padding_width / 2); + padding.appended = HW(padding_height - padding_height / 2, + padding_width - padding_width / 2); + return padding; +} + +// If padding depends on input, convert it into fixed padding. +template +Padding2D MakeSamePadding(const BHWC& input, const AttrT& attr) { + int32_t padding_height = CalculateSamePadding(input, attr); + int32_t padding_width = CalculateSamePadding(input, attr); + Padding2D padding; + padding.prepended = HW(padding_height / 2, padding_width / 2); + padding.appended = HW(padding_height - padding_height / 2, + padding_width - padding_width / 2); + return padding; +} + +} // namespace + +BHWC CalculateOutputShape(const BHWC& input, + const MaxUnpooling2DAttributes& attr) { + return BHWC(input.b, + input.h * attr.strides.h - attr.padding.prepended.h - + attr.padding.appended.h, + input.w * attr.strides.w - attr.padding.prepended.w - + attr.padding.appended.w, + input.c); +} + +BHWC CalculateOutputShape(const BHWC& input, const Pooling2DAttributes& attr) { + return BHWC(input.b, CalculateOutput(input, attr), + CalculateOutput(input, attr), input.c); +} + +BHWC CalculateOutputShape(const BHWC& input, + const Convolution2DAttributes& attr) { + return BHWC(input.b, CalculateOutput(input, attr), + CalculateOutput(input, attr), + attr.weights.shape.get()); +} + +BHWC CalculateOutputShape(const BHWC& input, + const ConvolutionTransposedAttributes& attr) { + return BHWC(input.b, CalculateOutput(input, attr), + CalculateOutput(input, attr), + attr.weights.shape.get()); +} + +BHWC CalculateOutputShape(const BHWC& input, + const DepthwiseConvolution2DAttributes& attr) { + return BHWC(input.b, CalculateOutput(input, attr), + CalculateOutput(input, attr), + attr.weights.shape.get() * + attr.weights.shape.get()); +} + +BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) { + return BHWC(input.b, StridedSize(attr.ends.h - attr.starts.h, attr.strides.h), + StridedSize(attr.ends.w - attr.starts.w, attr.strides.w), + StridedSize(attr.ends.c - attr.starts.c, attr.strides.c)); +} + +BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr) { + return BHWC(input.b, attr.appended.h + attr.prepended.h + input.h, + attr.appended.w + attr.prepended.w + input.w, + attr.appended.c + attr.prepended.c + input.c); +} + +BHWC CalculateOutputShape(const BHWC& input, + const FullyConnectedAttributes& attr) { + return BHWC(input.b, 1, 1, 
attr.weights.shape.o); +} + +Status CalculateOutputShape(const std::vector& input, + const ConcatAttributes& attr, BHWC* output_shape) { + BHWC new_shape = input[0]; + switch (attr.axis) { + case Axis::CHANNELS: + for (int i = 1; i < input.size(); i++) { + if (input[i].h != new_shape.h || input[i].w != new_shape.w) { + return InvalidArgumentError( + "Height and Width must be the same when concatenating " + "by channels axis"); + } + new_shape.c += input[i].c; + } + break; + case Axis::HEIGHT: + for (int i = 1; i < input.size(); i++) { + if (input[i].w != new_shape.w || input[i].c != new_shape.c) { + return InvalidArgumentError( + "Channels and Width must be the same when concatenating " + "by height axis"); + } + new_shape.h += input[i].h; + } + break; + case Axis::WIDTH: + for (int i = 1; i < input.size(); i++) { + if (input[i].h != new_shape.h || input[i].c != new_shape.c) { + return InvalidArgumentError( + "Height and Channels must be the same when concatenating " + "by width axis"); + } + new_shape.w += input[i].w; + } + break; + default: + return InvalidArgumentError("Invalid axis"); + break; + } + *output_shape = new_shape; + return OkStatus(); +} + +Padding2D CalculateSamePadding(const BHWC& input, + const Convolution2DAttributes& attr) { + return MakeSamePadding(input, attr); +} + +Padding2D CalculateSamePadding(const BHWC& input, + const ConvolutionTransposedAttributes& attr) { + return MakeSamePadding(input, attr); +} + +Padding2D CalculateSamePadding(const BHWC& input, + const DepthwiseConvolution2DAttributes& attr) { + return MakeSamePadding(input, attr); +} + +Padding2D CalculateSamePadding(const BHWC& input, + const Pooling2DAttributes& attr) { + return MakeSamePadding(input, attr); +} + +Padding2D CalculateSamePadding(const BHWC& input, + const MaxUnpooling2DAttributes& attr) { + return MakeSamePadding(input, attr); +} + +float CalculateResizeScale(int32_t input_size, int32_t output_size, + const Upsample2DAttributes& attr) { + return attr.align_corners && input_size > 1 && output_size > 1 + ? static_cast(input_size - 1) / (output_size - 1) + : static_cast(input_size) / output_size; +} + +BHWC CalculateOutputShape(const BHWC& input, const Upsample2DAttributes& attr) { + return BHWC(input.b, attr.new_shape.h, attr.new_shape.w, input.c); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h new file mode 100644 index 00000000000..ef825376b31 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/operations.h @@ -0,0 +1,337 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATIONS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATIONS_H_ + +#include +#include +#include + +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { + +// Non exhaustive list of operations. +enum class OperationType { + UNKNOWN = 0, + ABS, + ADD, + // TODO(eignasheva): remove APPLY_MASK operation, is should be just MUL + APPLY_MASK, + BATCH_TO_SPACE, + BATCH_NORMALIZATION, + CONCAT, + CONST, + CONVOLUTION_2D, + CONVOLUTION_TRANSPOSED, + COS, + DEPTHWISE_CONVOLUTION, + DIV, + FULLY_CONNECTED, + LOG, + LSTM, + MAX_UNPOOLING_2D, + MUL, + MULTIPLY_SCALAR, + POOLING_2D, + POW, + PAD, + PRELU, + RELU, + RESHAPE, + RESIZE, + RSQRT, + SIGMOID, + SIN, + SLICE, + SOFT_MAX, + SPACE_TO_BATCH, + SQRT, + SQUARE, + SQUARED_DIFF, + SUB, + TANH, + UPSAMPLE_2D, +}; + +std::string ToString(enum OperationType op); + +OperationType OperationTypeFromString(const std::string& name); + +struct Padding2D { + Padding2D() = default; + Padding2D& operator=(const Padding2D& value); + bool operator==(const Padding2D& value); + bool operator!=(const Padding2D& value); + Padding2D& operator-(const Padding2D& value); + + // Padding values for every axis (if needed), where 'prepended' defines + // padding for the beginning of each axis and 'appended' represents end part + // of the corresponding axis. + HW prepended = HW(-1, -1); + HW appended = HW(-1, -1); +}; + +struct Crop2D : public Padding2D {}; + +struct SpaceToBatchAttributes { + HW block; + Padding2D padding; +}; + +struct BatchToSpaceAttributes { + HW block; + Crop2D crop; +}; + +enum class PoolingType { + UNDEFINED = 0, + + // average pooling + AVERAGE = 1, + + // max pooling + MAX = 2, +}; + +struct Pooling2DAttributes { + PoolingType type = PoolingType::UNDEFINED; + // Strides for every axis. + HW strides = HW(-1, -1); + HW kernel = HW(-1, -1); + Padding2D padding; + // NOTE(akulik): technically the number of outputs from Pooling node indicates + // whether indices are needed or not, but I decided to keep it inside + // attributes to simplify processing. + bool output_indices = false; +}; + +struct MaxUnpooling2DAttributes { + // Strides for every axis. + HW strides = HW(-1, -1); + HW kernel = HW(-1, -1); + Padding2D padding; +}; + +struct ConcatAttributes { + // Defines axis by which to concat on. + Axis axis = Axis::UNKNOWN; +}; + +// @return shape of a tensor after MaxUnpooling2D operation is applied to +// the given input. +BHWC CalculateOutputShape(const BHWC& input, + const MaxUnpooling2DAttributes& attr); + +// @return shape of a tensor after Pooling2D operation is applied to the given +// input. +BHWC CalculateOutputShape(const BHWC& input, const Pooling2DAttributes& attr); + +// @return shape of a tensor after Concat operation is applied to the given +// input. +Status CalculateOutputShape(const std::vector& input, + const ConcatAttributes& attr, BHWC* output_shape); + +// @return padding for pooling operation to make sure output keep the same shape +// as the given input. 
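+// For example, a 3x3 pooling kernel with stride 1 over an 8x8 input needs a
+// total padding of 2 along each spatial axis, returned as prepended = (1, 1)
+// and appended = (1, 1).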
+Padding2D CalculateSamePadding(const BHWC& input, + const Pooling2DAttributes& attr); + +// @return padding for max unpooling operation to make sure output keep the same +// shape as the given input. +Padding2D CalculateSamePadding(const BHWC& input, + const MaxUnpooling2DAttributes& attr); + +struct Convolution2DAttributes { + HW strides = HW(1, 1); // Along each axis. + HW dilations = HW(1, 1); // Along each axis. + Padding2D padding; + + Tensor weights; + Tensor bias; // optional +}; + +// @return shape of a tensor after Convolution2D operation is applied to +// the given input. +BHWC CalculateOutputShape(const BHWC& input, + const Convolution2DAttributes& attr); + +// @return padding for convolution operation to make sure output keep the same +// shape as the given input. +Padding2D CalculateSamePadding(const BHWC& input, + const Convolution2DAttributes& attr); + +struct ConvolutionTransposedAttributes { + HW stride = HW(1, 1); // Along each axis. + HW adjacent; // TODO(sorokin): No op on Flow. + Padding2D padding; + + Tensor weights; + Tensor bias; // optional +}; + +Padding2D CalculateSamePadding(const BHWC& input, + const ConvolutionTransposedAttributes& attr); + +// @return shape of a tensor after ConvolutionTransposed operation is applied to +// the given input. +BHWC CalculateOutputShape(const BHWC& input, + const ConvolutionTransposedAttributes& attr); + +struct DepthwiseConvolution2DAttributes : public Convolution2DAttributes {}; + +// @return shape of a tensor after DepthwiseConvolution2D operation is applied +// to the given input. +BHWC CalculateOutputShape(const BHWC& input, + const DepthwiseConvolution2DAttributes& attr); + +// @return padding for depthwise convolution operation to make sure output keep +// the same shape as the given input. +Padding2D CalculateSamePadding(const BHWC& input, + const DepthwiseConvolution2DAttributes& attr); + +BHWC CalculateOutputShape(const BHWC& input, + const DepthwiseConvolution2DAttributes& attr); + +// f(x):= { +// if x < 0 : x -> alpha * x +// if x >= 0 : x -> min(clip, x) +// } +// +// Examples: +// - ReLU: clip = 0, alpha = 0 +// - ReLU6: clip = 6, alpha = 0 +// - Leaky ReLU: clip = 0, alpha = a +struct ReLUAttributes { + // clip <= 0 mean it is not set. + float clip = 0; + + float alpha = 0; +}; + +struct PReLUAttributes { + // clip <= 0 mean it is not set. + float clip = 0; + + // If alpha is linear, then it is sharded across CHANNELS axis, otherwise + // full shape alpha is required. + absl::variant, + Tensor> + alpha; +}; + +struct SoftMaxAttributes { + Axis axis = Axis::UNKNOWN; +}; + +enum LstmKernelType { + FULL = 0, + BASIC = 1, // Currently, only basic is supported. +}; + +struct LstmAttributes { + LstmKernelType kernel_type = LstmKernelType::BASIC; +}; + +struct MultiplyScalarAttributes { + absl::variant, float> + param; +}; + +enum class UpsamplingType { + NEAREST = 0, + BILINEAR = 1, +}; + +struct Upsample2DAttributes { + HW new_shape; + + UpsamplingType type = UpsamplingType::NEAREST; + + // If true, the centers of the 4 corner pixels of the input and output tensors + // are aligned, preserving the values at the corner pixels. Defaults to false. + bool align_corners = false; +}; + +float CalculateResizeScale(int32_t input_size, int32_t output_size, + const Upsample2DAttributes& attr); + +// @return shape of a tensor after upscale operation is applied to the given +// input. 
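+// For example, upsampling a BHWC(1, 4, 4, 3) input with new_shape = HW(8, 8)
+// yields BHWC(1, 8, 8, 3); the matching CalculateResizeScale(4, 8, attr) is
+// (4 - 1) / (8 - 1) = 3/7 when align_corners is true and 4 / 8 = 0.5 otherwise.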
+BHWC CalculateOutputShape(const BHWC& input, const Upsample2DAttributes& attr); + +enum class PaddingContentType { + ZEROS = 0, + REFLECT = 1, + EDGE = 2, +}; + +struct PadAttributes { + PaddingContentType type = PaddingContentType::ZEROS; + + HWC prepended; + HWC appended; +}; + +// @return shape of a tensor after Pad operation is applied to the given input. +BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr); + +struct ConstTensorAttributes { + Tensor tensor; +}; + +// Simple slicing without advanced support for shrinking, reverse slicing etc. +struct SliceAttributes { + // Specifies start and end dimensions for slicing. + HWC starts; + HWC ends; + + // Stride should be >= 1. + HWC strides; +}; + +// @return shape of a tensor after Slice2D operation is applied to the given +// input. +BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr); + +struct AddAttributes { + absl::variant, float> + param; +}; + +struct FullyConnectedAttributes { + Tensor weights; + Tensor bias; +}; + +// @return shape of a tensor after FullyConnected operation is applied to +// the given input. +BHWC CalculateOutputShape(const BHWC& input, + const FullyConnectedAttributes& attr); + +struct ReshapeAttributes { + BHWC new_shape; +}; + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATIONS_H_ diff --git a/tensorflow/lite/delegates/gpu/common/shape.cc b/tensorflow/lite/delegates/gpu/common/shape.cc new file mode 100644 index 00000000000..df34076313c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/shape.cc @@ -0,0 +1,125 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" + +namespace tflite { +namespace gpu { +namespace { + +struct GetAxisByIndexFunc { + template + Axis operator()() const { + return GetAxis(index); + } + int32_t index; +}; + +struct GetIndexByAxisFunc { + template + int operator()() const { + return GetAxisIndex(axis); + } + Axis axis; +}; + +struct NumAxisFunc { + template + int operator()() const { + return Size(); + } +}; + +} // namespace + +std::string ToString(Axis axis) { + switch (axis) { + case Axis::BATCH: + return "batch"; + case Axis::CHANNELS: + return "channels"; + case Axis::INPUT_CHANNELS: + return "input_channels"; + case Axis::OUTPUT_CHANNELS: + return "output_channels"; + case Axis::HEIGHT: + return "height"; + case Axis::WIDTH: + return "width"; + case Axis::VALUE: + return "value"; + case Axis::UNKNOWN: + return "unknown"; + } + return "undefined"; +} + +std::string ToString(Layout layout) { + switch (layout) { + case Layout::SCALAR: + return "scalar"; + case Layout::LINEAR: + return "linear"; + case Layout::HW: + return "hw"; + case Layout::CHW: + return "chw"; + case Layout::HWC: + return "hwc"; + case Layout::OHWI: + return "ohwi"; + case Layout::IHWO: + return "ihwo"; + case Layout::OIHW: + return "oihw"; + case Layout::IOHW: + return "iohw"; + case Layout::BHWC: + return "bhwc"; + case Layout::UNKNOWN: + return "unknown"; + } + return "undefined"; +} + +Axis GetAxis(Layout layout, int32_t index) { + return DispatchByLayout(layout, GetAxisByIndexFunc{index}); +} + +int GetAxisIndex(Layout layout, Axis axis) { + return DispatchByLayout(layout, GetIndexByAxisFunc{axis}); +} + +int Size(Layout layout) { return DispatchByLayout(layout, NumAxisFunc()); } + +std::string ToString(const Shape& s) { + return absl::StrCat("{", ToString(s.layout), ", {", + absl::StrJoin(s.dimensions, ", "), "}}"); +} + +template <> +int64_t StrongShape::LinearIndex( + const std::array& coordinates) const { + int64_t index = coordinates[0]; + index = index * StrongShape::get(1) + coordinates[1]; + index = index * StrongShape::get(2) + coordinates[2]; + index = index * StrongShape::get(3) + coordinates[3]; + return index; +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/shape.h b/tensorflow/lite/delegates/gpu/common/shape.h new file mode 100644 index 00000000000..f18e696517e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/shape.h @@ -0,0 +1,612 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tflite { +namespace gpu { + +enum class Axis { + UNKNOWN = 0, + CHANNELS = 1, + INPUT_CHANNELS = 2, + OUTPUT_CHANNELS = 3, + HEIGHT = 4, + WIDTH = 5, + BATCH = 6, + VALUE = 7, +}; + +std::string ToString(Axis t); + +// Layout represents axis order. +enum class Layout { + UNKNOWN = 0, + SCALAR = 1, + LINEAR = 2, + HW = 3, + CHW = 4, + HWC = 5, + OIHW = 6, + OHWI = 7, + IHWO = 8, + IOHW = 9, + BHWC = 10, +}; + +std::string ToString(Layout l); + +// Returns number of axis for the fixed layout. +template +constexpr int Size(); + +// Returns number of axis for the given layout. +int Size(Layout layout); + +// Returns Axis for the given index and fixed layout. +template +constexpr Axis GetAxis(int index); + +// Returns axis for the given layout and index. +Axis GetAxis(Layout layout, int32_t index); + +// Returns axis index for the given axis and fixed layout. +template +constexpr int GetAxisIndex(Axis axis); + +// Returns axis index for the given layout and axis. +int GetAxisIndex(Layout layout, Axis axis); + +// Stores Layout(axis set and order) and value for dimensions. +struct Shape { + Shape() : layout(Layout::UNKNOWN), dimensions() {} + + explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {} + + Shape(Layout t, std::vector d) + : layout(t), dimensions(std::move(d)) {} + + bool operator==(const Shape& other) const { + return (layout == other.layout) && (dimensions == other.dimensions); + } + + bool operator!=(const Shape& other) const { return !operator==(other); } + + // All methods below are matching same methods defined in StrongShape to + // make sure generic algorithms work both ways. + + // Returns back a dimension or -1 if it is not found. + template + int32_t get() const; + int32_t get(Axis d) const; + + template + bool set(int32_t t); + bool set(Axis d, int32_t t); + + Axis axis(int index) const { return GetAxis(layout, index); } + + int index(Axis d) const { return GetAxisIndex(layout, d); } + + int64_t DimensionsProduct() const { + return std::accumulate(dimensions.begin(), dimensions.end(), 1ll, + std::multiplies()); + } + + Layout layout = Layout::UNKNOWN; + + std::vector dimensions; +}; + +std::string ToString(const Shape& s); + +// StrongShape provides convenient explicit access to dimensions stored in +// shape, e.g. StrongShape s; provides s.h and s.w accessors. +// +// There is a conversion possible both ways between Shape and StrongShape. +// +// OIHW oihw; // specific shape +// Shape l = oihw.ToShape(); +// +// OHWI other; // notice not the same but compatible shape. +// if (!other.Adopt(l)) { +// // error handling +// } +// +// StrongShape supports the following set of operations: +// +// // Returns number of axis in the shape class. +// static constexpr int size(); +// +// // Returns Axis for the given index or Axis::UNKNOWN if index +// // falls outside of the defined range in this shape. +// static constexpr Axis axis(int index); +// +// // Returns index for the given axis or -1 if axis is not defined in this +// // shape. +// static constexpr int index(Axis d); +// +// // Getters +// int32_t get(int index) const; +// int32_t get(Axis d) const; +// int32_t get() const; +// +// // Setters that return false if set was not successful. 
+// bool set(int index, int32_t v); +// bool set(Axis d, int32_t v); +// bool set(int32_t v); +// +// // Returns shape's layout. +// static const Layout layout; +// +// // Turns specific shape into generic shape. +// Shape ToShape() const; +// +// // Copies all dimensions from the given shape. +// bool Adopt(const Shape&); +// +template +struct StrongShape; + +using Scalar = StrongShape; +using Linear = StrongShape; +using HW = StrongShape; + +// Common tensor shape for CNN models working with images. +using CHW = StrongShape; +using HWC = StrongShape; +using BHWC = StrongShape; + +// Tensor shape used in convolution_2d weights. +using OIHW = StrongShape; +using OHWI = StrongShape; +using IHWO = StrongShape; +using IOHW = StrongShape; + +// ----------------------------------------------------------------------------- +// Everything below are internal implementation details. +// ----------------------------------------------------------------------------- + +namespace internal_shape { + +template +struct AxisTraits; + +#define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName) \ + template <> \ + struct AxisTraits { \ + struct Holder { \ + int32_t HolderName; \ + \ + protected: \ + int32_t operator()() const { return HolderName; } \ + void operator()(int32_t v) { HolderName = v; } \ + }; \ + \ + using dimension_holder_type = Holder; \ + } + +TFLITE_GPU_AXIS_TRAITS(CHANNELS, c); +TFLITE_GPU_AXIS_TRAITS(HEIGHT, h); +TFLITE_GPU_AXIS_TRAITS(WIDTH, w); +TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i); +TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o); +TFLITE_GPU_AXIS_TRAITS(BATCH, b); +TFLITE_GPU_AXIS_TRAITS(VALUE, v); + +#undef TFLITE_GPU_AXIS_TRAITS + +template +struct StrongShapeImpl; + +template +struct StrongShapeImpl { + static constexpr int size() { return N; } + + static constexpr Axis axis(int) { return Axis::UNKNOWN; } + + static constexpr int index(Axis) { return -1; } + + int32_t get(Axis) const { return -1; } + + int32_t get(int) const { return -1; } + + template + int32_t get() const { + return -1; + } + + bool set(Axis, int32_t) { return false; } + + bool set(int, int32_t) { return false; } + + template + bool set(int32_t) { + return false; + } +}; + +// Used to deduce number of axis, and to be a child of a proper holder to +// provide access to the dimension by name +template +struct StrongShapeImpl + : public AxisTraits::dimension_holder_type, + public StrongShapeImpl { + using dimension_holder_type = typename AxisTraits::dimension_holder_type; + + using rest_type = StrongShapeImpl; + + StrongShapeImpl() : dimension_holder_type{0}, rest_type() {} + + template + explicit StrongShapeImpl(int32_t t, Ts... ts) + : dimension_holder_type{t}, rest_type(ts...) {} + + static constexpr Axis axis(int index) { + return index == N ? A : rest_type::axis(index); + } + + static constexpr int index(Axis d) { + return d == A ? N : rest_type::index(d); + } + + int32_t get(Axis d) const { + return d == A ? dimension_holder_type::operator()() : rest_type::get(d); + } + + template + int32_t get() const { + return B == A ? dimension_holder_type::operator()() + : rest_type::template get(); + } + + int32_t get(int index) const { + return index == N ? 
dimension_holder_type::operator()() + : rest_type::get(index); + } + + bool set(Axis d, int32_t t) { + if (d == A) { + dimension_holder_type::operator()(t); + return true; + } + return rest_type::set(d, t); + } + + bool set(int index, int32_t t) { + if (index == N) { + dimension_holder_type::operator()(t); + return true; + } + return rest_type::set(index, t); + } + + template + bool set(int32_t t) { + if (A == B) { + dimension_holder_type::operator()(t); + return true; + } + return rest_type::template set(t); + } +}; + +template +struct LayoutTraits; + +#define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...) \ + template <> \ + struct LayoutTraits { \ + using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \ + } + +TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH); +TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH, + Axis::INPUT_CHANNELS); +TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS, + Axis::HEIGHT, Axis::WIDTH); +TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS, + Axis::HEIGHT, Axis::WIDTH); +TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH, + Axis::OUTPUT_CHANNELS); +TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH); +TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS); +TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE); +TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE); +TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, + Axis::CHANNELS); + +#undef TFLITE_GPU_LAYOUT_TRAITS + +template <> +struct LayoutTraits { + using strong_shape_type = StrongShapeImpl<0>; +}; + +template +struct DimensionGetterFixedAxisFunc { + template + int32_t operator()() const { + constexpr int i = GetAxisIndex(A); + return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1; + } + const Shape* l; +}; + +struct DimensionGetterFunc { + template + int32_t operator()() const { + int i = GetAxisIndex(d); + return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1; + } + Axis d; + const Shape* l; +}; + +template +struct DimensionSetterFixedAxisFunc { + template + bool operator()() const { + constexpr int i = GetAxisIndex(A); + if (i >= 0 && i < l->dimensions.size()) { + l->dimensions[i] = v; + return true; + } + return false; + } + Shape* l; + int32_t v; +}; + +struct DimensionSetterFunc { + template + bool operator()() const { + int i = GetAxisIndex(d); + if (i >= 0 && i < l->dimensions.size()) { + l->dimensions[i] = v; + return true; + } + return false; + } + Axis d; + Shape* l; + int32_t v; +}; + +template +struct ToShapeFunc { + template + bool operator()() const { + for (int i = 0; i < StrongShape::size(); ++i) { + int index = GetAxisIndex(StrongShape::axis(i)); + if (index < 0) return false; + shape->set(i, l.dimensions[index]); + } + return true; + } + + StrongShape* shape; + const Shape& l; +}; + +} // namespace internal_shape + +// template +template +struct StrongShape : public internal_shape::LayoutTraits::strong_shape_type { + using strong_shape_type = + typename internal_shape::LayoutTraits::strong_shape_type; + StrongShape() = default; + + template + explicit StrongShape(Ts... t) : strong_shape_type(t...) {} + + constexpr static Layout layout = L; + + bool operator==(const StrongShape& shape) const { + // TODO(akulik): implement better alternative. + return this->ToShape() == shape.ToShape(); + } + + bool operator!=(const StrongShape& shape) const { + // TODO(akulik): implement better alternative. 
+ return this->ToShape() != shape.ToShape(); + } + bool empty() const { return DimensionsProduct() == 0; } + + // Turns StrongShape into generic shape. + Shape ToShape() const { + std::vector dimensions(StrongShape::size()); + for (int i = 0; i < StrongShape::size(); ++i) { + dimensions[i] = StrongShape::get(i); + } + return Shape(L, std::move(dimensions)); + } + + // @return all dimensions multiplied + int64_t DimensionsProduct() const { + int64_t product = 1; + for (int i = 0; i < StrongShape::size(); ++i) { + product *= StrongShape::get(i); + } + return product; + } + + // Translates given coordinates of the layout into a linear index assuming + // dimensions are sorted in tensor access order e.g. if you access + // foobar[i][j][k] order of coordinates should be i,j,k. + int64_t LinearIndex( + const std::array& coordinates) const { + int64_t index = coordinates[0]; + for (int i = 1; i < StrongShape::size(); ++i) { + index = index * StrongShape::get(i) + coordinates[i]; + } + return index; + } + + // Copies all dimensions from the given generic shape into specific shape. + // It requires shape to have all axis defined in the given + // StrongShape. For example: + // - If this shape is OHWI but given shape is OIHW, Adopt will copy all + // dimensions and return true. + // - If this shape is OIHW but input shape is HW, Adopt will copy H and W + // dimensions and return true, but if this shape is HW and given shape + // OIHW, then Adopt will return false because not all axis are present in + // the input shape. + // + // @return false if generic shape is not compatible. + bool Adopt(const Shape& shape) { + return DispatchByLayout(shape.layout, + internal_shape::ToShapeFunc{this, shape}); + } + + // For all axis defined in a given shape copies values to this shape. + // Therefore, it is possible to copy dimensions from CHW to BCHW, but not + // the other way around. + // + // BCHW bchw; + // CHW chw; + // bchw.CopyAllGivenAxis(chw); --> true + // chw.CopyAllGivenAxis(bchw); --> false + // + // @return false if axis in source shape is not defined here, thus value + // was not copied. + template + bool CopyAllGivenAxis(const StrongShape& source) { + for (int i = 0; i < source.size(); ++i) { + if (!StrongShape::set(source.axis(i), source.get(i))) { + return false; + } + } + return true; + } + + // For all axis defined in this shape copies values from the given shape. + // + // BCHW bchw; + // CHW chw; + // bchw.CopyAllDefinedAxis(chw); --> false + // chw.CopyAllDefinedAxis(bchw); --> true + // + // @return false if given shape does not have axis defined here, + // therefore a value was not copied. + template + bool CopyAllDefinedAxis(const StrongShape& source) { + for (int i = 0; i < StrongShape::size(); ++i) { + int source_index = source.index(StrongShape::axis(i)); + if (source_index < 0) { + return false; + } + StrongShape::set(i, source.get(source_index)); // always true + } + return true; + } + + // Copies values only for matching axis. 
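// ----------------------------------------------------------------------------
// Illustrative sketch (not part of this header). LinearIndex() above assumes
// the coordinates follow the layout's axis order (row-major), and Adopt()
// succeeds only when every axis of the strong shape is present in the given
// generic shape (see ToShapeFunc above). Assuming namespace tflite::gpu:
//
//   HWC hwc(4, 4, 8);
//   // (y, x, c) -> (y * 4 + x) * 8 + c, e.g. (1, 2, 3) -> (1*4 + 2)*8 + 3 = 51
//   int64_t idx = hwc.LinearIndex({1, 2, 3});
//
//   OIHW oihw(16, 8, 3, 3);
//   HW hw;
//   bool ok = hw.Adopt(oihw.ToShape());    // true: H and W are both in OIHW
//   bool bad = oihw.Adopt(hw.ToShape());   // false: O and I are missing in HW
// ----------------------------------------------------------------------------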
+ template + void CopyMatchingAxis(const StrongShape& source) { + for (int i = 0; i < StrongShape::size(); ++i) { + StrongShape::set(source.axis(i), source.get(i)); + } + } +}; + +template +inline std::string ToString(const StrongShape& s) { + return ToString(s.ToShape()); +} + +template +constexpr Layout StrongShape::layout; + +template +auto DispatchByLayout(Layout type, F f) + -> decltype(f.template operator()()) { + switch (type) { + case Layout::HW: + return f.template operator()(); + case Layout::HWC: + return f.template operator()(); + case Layout::CHW: + return f.template operator()(); + case Layout::OIHW: + return f.template operator()(); + case Layout::IOHW: + return f.template operator()(); + case Layout::OHWI: + return f.template operator()(); + case Layout::IHWO: + return f.template operator()(); + case Layout::LINEAR: + return f.template operator()(); + case Layout::SCALAR: + return f.template operator()(); + case Layout::BHWC: + return f.template operator()(); + case Layout::UNKNOWN: + return f.template operator()(); + } +} + +template +constexpr int Size() { + return StrongShape::size(); +} + +template +constexpr Axis GetAxis(int index) { + return StrongShape::axis(index); +} + +template +constexpr int GetAxisIndex(Axis axis) { + return StrongShape::index(axis); +} + +template +inline int32_t Shape::get() const { + return DispatchByLayout( + layout, internal_shape::DimensionGetterFixedAxisFunc{this}); +} + +inline int32_t Shape::get(Axis d) const { + return DispatchByLayout(layout, internal_shape::DimensionGetterFunc{d, this}); +} + +template +inline bool Shape::set(int32_t t) { + return DispatchByLayout( + layout, internal_shape::DimensionSetterFixedAxisFunc{this, t}); +} + +inline bool Shape::set(Axis d, int32_t t) { + return DispatchByLayout(layout, + internal_shape::DimensionSetterFunc{d, this, t}); +} + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_ diff --git a/tensorflow/lite/delegates/gpu/common/shape_test.cc b/tensorflow/lite/delegates/gpu/common/shape_test.cc new file mode 100644 index 00000000000..a55311cad8a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/shape_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +#include +#include + +#include +#include + +namespace tflite { +namespace gpu { +namespace { + +TEST(OIHW, Smoke) { + OIHW OIHW; + + // Test 4 different versions of setters. + OIHW.i = 1; + ASSERT_TRUE(OIHW.set(2)); + ASSERT_TRUE(OIHW.set(Axis::HEIGHT, 3)); + ASSERT_TRUE(OIHW.set(3, 4)); + + // Make sure invalid setters return false. 
+ ASSERT_FALSE(OIHW.set(5, 10)); + ASSERT_FALSE(OIHW.set(Axis::CHANNELS, 10)); + ASSERT_FALSE(OIHW.set(10)); + + // Test 4 different versions of getters + EXPECT_EQ(1, OIHW.get(Axis::INPUT_CHANNELS)); + EXPECT_EQ(2, OIHW.o); + EXPECT_EQ(3, OIHW.get(2)); + EXPECT_EQ(4, OIHW.get()); + + // Make sure getters that fall outside of a range return invalid axis. + EXPECT_EQ(-1, OIHW.get(5)); + EXPECT_EQ(-1, OIHW.get(Axis::CHANNELS)); + EXPECT_EQ(-1, OIHW.get()); + + // Check axis indices are all correct. + ASSERT_EQ(4, OIHW.size()); + std::vector expected = {Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS, + Axis::HEIGHT, Axis::WIDTH}; + for (int i = 0; i < OIHW.size(); ++i) { + Axis axis = OIHW.axis(i); + ASSERT_EQ(expected[i], axis); + ASSERT_EQ(i, OIHW.index(axis)); + } + + // Check equivalent conversions. + OHWI ohwi; + ASSERT_TRUE(ohwi.CopyAllDefinedAxis(OIHW)); + EXPECT_EQ(ohwi.o, OIHW.o); + EXPECT_EQ(ohwi.i, OIHW.i); + EXPECT_EQ(ohwi.h, OIHW.h); + EXPECT_EQ(ohwi.w, OIHW.w); + + ohwi = OHWI(10, 20, 30, 40); + ASSERT_TRUE(OIHW.CopyAllGivenAxis(ohwi)); + EXPECT_EQ(ohwi.o, OIHW.o); + EXPECT_EQ(ohwi.i, OIHW.i); + EXPECT_EQ(ohwi.h, OIHW.h); + EXPECT_EQ(ohwi.w, OIHW.w); +} + +TEST(Layout, Smoke) { + EXPECT_EQ(4, Size()); + EXPECT_EQ(4, Size(Layout::OIHW)); + std::vector expected = {Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS, + Axis::HEIGHT, Axis::WIDTH}; + for (int i = 0; i < Size(); ++i) { + Axis axis = GetAxis(i); + ASSERT_EQ(expected[i], axis); + ASSERT_EQ(axis, GetAxis(Layout::OIHW, i)); + ASSERT_EQ(i, GetAxisIndex(axis)); + ASSERT_EQ(i, GetAxisIndex(Layout::OIHW, axis)); + } + EXPECT_EQ(Axis::UNKNOWN, GetAxis(Layout::OIHW, 5)); + EXPECT_EQ(-1, GetAxisIndex(Axis::CHANNELS)); + EXPECT_EQ(-1, GetAxisIndex(Axis::CHANNELS)); +} + +TEST(Shape, Smoke) { + Shape s(Layout::OIHW, {1, 2, 3, 4}); + EXPECT_TRUE(s.set(Axis::HEIGHT, 10)); + EXPECT_TRUE(s.set(20)); + EXPECT_FALSE(s.set(Axis::BATCH, 10)); + EXPECT_FALSE(s.set(20)); + + ASSERT_EQ(10, s.get()); + ASSERT_EQ(20, s.get(Axis::WIDTH)); + EXPECT_EQ(20, s.dimensions[3]); + + OIHW oihw(1, 2, 10, 20); + Shape s2 = oihw.ToShape(); + EXPECT_EQ(s2.layout, oihw.layout); + EXPECT_EQ(s.layout, s2.layout); + EXPECT_EQ(s.dimensions, s2.dimensions); + + // Convert layout into compatible shape. + OHWI ohwi; + ASSERT_TRUE(ohwi.Adopt(s2)); + EXPECT_EQ(1, ohwi.o); + EXPECT_EQ(2, ohwi.i); + EXPECT_EQ(10, ohwi.h); + EXPECT_EQ(20, ohwi.w); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/status.h b/tensorflow/lite/delegates/gpu/common/status.h new file mode 100644 index 00000000000..250a3b5e3eb --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/status.h @@ -0,0 +1,124 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ + +#include + +namespace tflite { +namespace gpu { + +enum class StatusCode { + kOk = 0, + kCancelled = 1, + kUnknown = 2, + kInvalidArgument = 3, + kDeadlineExceeded = 4, + kNotFound = 5, + kAlreadyExists = 6, + kPermissionDenied = 7, + kResourceExhausted = 8, + kFailedPrecondition = 9, + kAborted = 10, + kOutOfRange = 11, + kUnimplemented = 12, + kInternal = 13, + kUnavailable = 14, + kDataLoss = 15, + kUnauthenticated = 16, + kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 +}; + +// Lite version of Status without dependency on protobuf. +// TODO(b/128867901): Migrate to absl::Status. +class Status { + public: + Status() = default; + Status(StatusCode code) : code_(code) {} + Status(StatusCode code, const std::string& error_message) + : code_(code), error_message_(error_message) {} + + const std::string& error_message() const { return error_message_; } + StatusCode code() const { return code_; } + bool ok() const { return code_ == StatusCode::kOk; } + + void IgnoreError() const {} + + private: + StatusCode code_ = StatusCode::kOk; + std::string error_message_; +}; + +#define RETURN_IF_ERROR(status) \ + { \ + const auto status2 = (status); \ + if (!status2.ok()) return status2; \ + } + +inline Status OkStatus() { return Status(); } + +inline Status AlreadyExistsError(const std::string& message) { + return Status(StatusCode::kAlreadyExists, message); +} + +inline Status DeadlineExceededError(const std::string& message) { + return Status(StatusCode::kDeadlineExceeded, message); +} + +inline Status FailedPreconditionError(const std::string& message) { + return Status(StatusCode::kFailedPrecondition, message); +} + +inline Status InternalError(const std::string& message) { + return Status(StatusCode::kInternal, message); +} + +inline Status InvalidArgumentError(const std::string& message) { + return Status(StatusCode::kInvalidArgument, message); +} + +inline Status NotFoundError(const std::string& message) { + return Status(StatusCode::kNotFound, message); +} + +inline Status OutOfRangeError(const std::string& message) { + return Status(StatusCode::kOutOfRange, message); +} + +inline Status PermissionDeniedError(const std::string& message) { + return Status(StatusCode::kPermissionDenied, message); +} + +inline Status ResourceExhaustedError(const std::string& message) { + return Status(StatusCode::kResourceExhausted, message); +} + +inline Status UnavailableError(const std::string& message) { + return Status(StatusCode::kUnavailable, message); +} + +inline Status UnimplementedError(const std::string& message) { + return Status(StatusCode::kUnimplemented, message); +} + +inline Status UnknownError(const std::string& message) { + return Status(StatusCode::kUnknown, message); +} + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_STATUS_H_ diff --git a/tensorflow/lite/delegates/gpu/common/tensor.h b/tensorflow/lite/delegates/gpu/common/tensor.h new file mode 100644 index 00000000000..2e11b191a47 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/tensor.h @@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
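// ----------------------------------------------------------------------------
// Illustrative usage sketch (not part of this patch): Status and the
// RETURN_IF_ERROR macro defined in status.h above propagate errors without
// exceptions. Assuming namespace tflite::gpu:
//
//   Status CheckChannels(int channels) {
//     if (channels <= 0) {
//       return InvalidArgumentError("channels must be positive");
//     }
//     return OkStatus();
//   }
//
//   Status Prepare(int channels) {
//     RETURN_IF_ERROR(CheckChannels(channels));  // early-returns on error
//     return OkStatus();
//   }
// ----------------------------------------------------------------------------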
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +namespace tflite { +namespace gpu { +namespace internal_tensor { + +// Meta function given element type returns a type for Tensor data container. +template +struct StorageType; + +template <> +struct StorageType { + using value = std::vector; +}; + +template <> +struct StorageType { + using value = std::vector; +}; + +} // namespace internal_tensor + +template +struct Tensor { + using ShapeType = ShapeT; + + constexpr static DataType kType = Type; + + using TensorStorageType = typename internal_tensor::StorageType::value; + + // Opaque id of a tensor. + int64_t id = -1; + + ShapeType shape; + + TensorStorageType data; +}; + +// TensorRef is a reference to another tensor. If an object should never hold +// tensor data, then TensorRef should be used instead. +template +struct TensorRef { + using ShapeType = ShapeT; + + DataType type = DataType::UNKNOWN; + + ShapeT shape; + + // Opaque reference to a tensor. Upstream component is responsible for + // resolving this reference into an actual tensor. + int64_t ref = -1; +}; + +template +constexpr DataType Tensor::kType; + +template +Tensor MakeZeroTensor(const ShapeT& shape) { + Tensor tensor; + tensor.shape = shape; + tensor.data = typename Tensor::TensorStorageType( + shape.DimensionsProduct(), 0); + return tensor; +} + +using TensorFloat32 = Tensor; +using TensorRefFloat32 = TensorRef; + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/BUILD b/tensorflow/lite/delegates/gpu/common/transformations/BUILD new file mode 100644 index 00000000000..8fa03687adc --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/BUILD @@ -0,0 +1,232 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "add_bias", + srcs = ["add_bias.cc"], + hdrs = ["add_bias.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + ], +) + +cc_library( + name = "fuse_add_to_conv", + srcs = ["fuse_add_to_conv.cc"], + hdrs = ["fuse_add_to_conv.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "fuse_add_to_conv_test", + srcs = ["fuse_add_to_conv_test.cc"], + deps = [ + 
":fuse_add_to_conv", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "fuse_mul_to_conv", + srcs = ["fuse_mul_to_conv.cc"], + hdrs = ["fuse_mul_to_conv.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", + ], +) + +cc_test( + name = "fuse_mul_to_conv_test", + srcs = ["fuse_mul_to_conv_test.cc"], + deps = [ + ":fuse_mul_to_conv", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "general_transformations", + srcs = ["general_transformations.cc"], + hdrs = ["general_transformations.h"], + deps = [ + ":fuse_add_to_conv", + ":fuse_mul_to_conv", + ":make_fully_connected", + ":make_padding", + ":match_dilated_convolution", + ":merge_padding_with", + ":remove_noop", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + ], +) + +cc_library( + name = "make_fully_connected", + srcs = ["make_fully_connected.cc"], + hdrs = ["make_fully_connected.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:any", + ], +) + +cc_test( + name = "make_fully_connected_test", + srcs = ["make_fully_connected_test.cc"], + deps = [ + ":make_fully_connected", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "@com_google_absl//absl/types:any", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "make_padding", + srcs = ["make_padding.cc"], + hdrs = ["make_padding.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:any", + ], +) + +cc_test( + name = "make_padding_test", + srcs = ["make_padding_test.cc"], + deps = [ + ":make_padding", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "@com_google_absl//absl/types:any", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "matching", + hdrs = ["matching.h"], + deps = ["//tensorflow/lite/delegates/gpu/common:model"], +) + +cc_library( + name = "match_dilated_convolution", + srcs = ["match_dilated_convolution.cc"], + hdrs = ["match_dilated_convolution.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:any", + ], +) + +cc_test( + 
name = "match_dilated_convolution_test", + srcs = ["match_dilated_convolution_test.cc"], + deps = [ + ":match_dilated_convolution", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "@com_google_absl//absl/types:any", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "merge_padding_with", + srcs = ["merge_padding_with.cc"], + hdrs = ["merge_padding_with.h"], + deps = [ + ":matching", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + ], +) + +cc_test( + name = "merge_padding_with_test", + srcs = ["merge_padding_with_test.cc"], + deps = [ + ":merge_padding_with", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "@com_google_absl//absl/types:any", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "remove_noop", + srcs = ["remove_noop.cc"], + hdrs = ["remove_noop.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + ], +) + +cc_test( + name = "remove_noop_test", + srcs = ["remove_noop_test.cc"], + deps = [ + ":remove_noop", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc new file mode 100644 index 00000000000..7feac824ef7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc @@ -0,0 +1,74 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+namespace {
+
+template <typename Attr>
+TransformResult FillBias(Node* node) {
+  auto& attr = absl::any_cast<Attr&>(node->operation.attributes);
+  if (attr.bias.data.empty()) {
+    const int dst_channels = attr.weights.shape.o;
+    attr.bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
+        Linear(dst_channels));
+    return {TransformStatus::APPLIED, "Added bias"};
+  }
+  return {TransformStatus::SKIPPED, ""};
+}
+
+template TransformResult FillBias<Convolution2DAttributes>(Node* node);
+template TransformResult FillBias<ConvolutionTransposedAttributes>(Node* node);
+template TransformResult FillBias<DepthwiseConvolution2DAttributes>(Node* node);
+template TransformResult FillBias<FullyConnectedAttributes>(Node* node);
+
+class AddBias : public NodeTransformation {
+ public:
+  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
+    if (node->operation.type == ToString(OperationType::CONVOLUTION_2D)) {
+      return FillBias<Convolution2DAttributes>(node);
+    }
+    if (node->operation.type ==
+        ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
+      return FillBias<ConvolutionTransposedAttributes>(node);
+    }
+    if (node->operation.type ==
+        ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
+      return FillBias<DepthwiseConvolution2DAttributes>(node);
+    }
+    if (node->operation.type == ToString(OperationType::FULLY_CONNECTED)) {
+      return FillBias<FullyConnectedAttributes>(node);
+    }
+    return {TransformStatus::SKIPPED, ""};
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<NodeTransformation> NewAddBias() {
+  return absl::make_unique<AddBias>();
+}
+
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.h b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.h
new file mode 100644
index 00000000000..1523c413c54
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.h
@@ -0,0 +1,32 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
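// ----------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the AddBias transformation
// above is applied through ModelTransformer, mirroring the pattern used in
// the *_test.cc files in this directory. Assuming namespace tflite::gpu and a
// populated GraphFloat32 named graph:
//
//   auto transformation = NewAddBias();
//   ModelTransformer transformer(&graph, nullptr);
//   transformer.Apply("add_bias", transformation.get());
//   // Every convolution/fully-connected node now carries an explicit bias.
// ----------------------------------------------------------------------------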
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_BIAS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_BIAS_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +// Makes optional bias(Conv/Deconv and etc) as not optional(always present) +std::unique_ptr NewAddBias(); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_BIAS_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc new file mode 100644 index 00000000000..cf7bbc1dd4a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -0,0 +1,235 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h" + +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace { + +void FuseBiasWithAddAttributes(const AddAttributes& add_attr, + const int channels, + Tensor* bias) { + auto add = absl::get_if>(&add_attr.param); + auto add_scalar = absl::get_if(&add_attr.param); + if (bias->data.empty()) { + *bias = MakeZeroTensor(Linear(channels)); + } + for (int d = 0; d < channels; ++d) { + bias->data[d] += add ? 
add->data[d] : *add_scalar; + } +} + +class MergeConvolutionWithAdd : public SequenceTransformation { + public: + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final { + auto& conv_node = *sequence[0]; + auto& add_node = *sequence[1]; + if (add_node.operation.type != ToString(OperationType::ADD)) { + return {TransformStatus::SKIPPED, ""}; + } + AddAttributes add_attr = + absl::any_cast(add_node.operation.attributes); + if (!absl::get_if>(&add_attr.param) && + !absl::get_if(&add_attr.param)) { + return {TransformStatus::DECLINED, + "This fuse applicable only for broadcast or scalar addition."}; + } + + if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { + Convolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseConvolution2DWithAdd(add_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::CONVOLUTION_TRANSPOSED)) { + ConvolutionTransposedAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseConvolutionTransposedWithAdd(add_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::DEPTHWISE_CONVOLUTION)) { + DepthwiseConvolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseDepthwiseConvolution2DWithAdd(add_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::FULLY_CONNECTED)) { + FullyConnectedAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseFullyConnectedWithAdd(add_attr, conv_attr); + } else { + return {TransformStatus::SKIPPED, ""}; + } + + Status status = RemoveFollowingNode(graph, &add_node, &conv_node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove add node after convolution: " + + status.error_message()}; + } + return {TransformStatus::APPLIED, ""}; + } +}; + +class MergeAddWithConvolution : public SequenceTransformation { + public: + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final { + auto& conv_node = *sequence[1]; + auto& add_node = *sequence[0]; + if (add_node.operation.type != ToString(OperationType::ADD)) { + return {TransformStatus::SKIPPED, ""}; + } + AddAttributes add_attr = + absl::any_cast(add_node.operation.attributes); + if (!absl::get_if>(&add_attr.param) && + !absl::get_if(&add_attr.param)) { + return {TransformStatus::DECLINED, + "This fuse applicable only for broadcast or scalar addition."}; + } + + if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { + Convolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseAddWithConvolution2D(add_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::DEPTHWISE_CONVOLUTION)) { + DepthwiseConvolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseAddWithDepthwiseConvolution2D(add_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::FULLY_CONNECTED)) { + FullyConnectedAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseAddWithFullyConnected(add_attr, conv_attr); + } else { + return {TransformStatus::SKIPPED, ""}; + } + + Status status = RemovePrecedingNode(graph, &add_node, &conv_node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable 
to remove add node after convolution: " + + status.error_message()}; + } + return {TransformStatus::APPLIED, ""}; + } +}; +} // namespace + +std::unique_ptr NewMergeConvolutionWithAdd() { + return absl::make_unique(); +} + +std::unique_ptr NewMergeAddWithConvolution() { + return absl::make_unique(); +} + +void FuseConvolution2DWithAdd(const AddAttributes& add_attr, + Convolution2DAttributes* attr) { + FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias); +} + +void FuseDepthwiseConvolution2DWithAdd(const AddAttributes& add_attr, + DepthwiseConvolution2DAttributes* attr) { + FuseBiasWithAddAttributes( + add_attr, attr->weights.shape.o * attr->weights.shape.i, &attr->bias); +} + +void FuseConvolutionTransposedWithAdd(const AddAttributes& add_attr, + ConvolutionTransposedAttributes* attr) { + FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias); +} + +void FuseFullyConnectedWithAdd(const AddAttributes& add_attr, + FullyConnectedAttributes* attr) { + FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias); +} + +void FuseAddWithConvolution2D(const AddAttributes& add_attr, + Convolution2DAttributes* attr) { + auto add = absl::get_if>(&add_attr.param); + auto add_scalar = absl::get_if(&add_attr.param); + if (attr->bias.data.empty()) { + attr->bias = MakeZeroTensor( + Linear(attr->weights.shape.o)); + } + for (int d = 0; d < attr->weights.shape.o; ++d) { + for (int s = 0; s < attr->weights.shape.i; ++s) { + const float add_value = add ? add->data[s] : *add_scalar; + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s}); + attr->bias.data[d] += attr->weights.data[index] * add_value; + } + } + } + } +} + +void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr, + DepthwiseConvolution2DAttributes* attr) { + auto add = absl::get_if>(&add_attr.param); + auto add_scalar = absl::get_if(&add_attr.param); + if (attr->bias.data.empty()) { + attr->bias = MakeZeroTensor( + Linear(attr->weights.shape.o * attr->weights.shape.i)); + } + for (int s = 0; s < attr->weights.shape.i; ++s) { + const float add_value = add ? add->data[s] : *add_scalar; + for (int g = 0; g < attr->weights.shape.o; ++g) { + const int d = s * attr->weights.shape.o + g; + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s}); + attr->bias.data[d] += attr->weights.data[index] * add_value; + } + } + } + } +} + +void FuseAddWithFullyConnected(const AddAttributes& add_attr, + FullyConnectedAttributes* attr) { + auto add = absl::get_if>(&add_attr.param); + auto add_scalar = absl::get_if(&add_attr.param); + if (attr->bias.data.empty()) { + attr->bias = MakeZeroTensor( + Linear(attr->weights.shape.o)); + } + for (int d = 0; d < attr->weights.shape.o; ++d) { + for (int s = 0; s < attr->weights.shape.i; ++s) { + const float add_value = add ? 
add->data[s] : *add_scalar; + const int index = attr->weights.shape.LinearIndex({d, 0, 0, s}); + attr->bias.data[d] += attr->weights.data[index] * add_value; + } + } +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h new file mode 100644 index 00000000000..49871a815da --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h @@ -0,0 +1,83 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_ADD_TO_CONV_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_ADD_TO_CONV_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { + +// Fuse Add Scalar or Add Broadcast after Convolution(Convolution2D, +// DepthWise, TransposedConvolution, FullyConnected) into biases of +// convolution. +std::unique_ptr NewMergeConvolutionWithAdd(); + +// Fuse Add Scalar or Add Broadcast before Convolution(Convolution2D, +// DepthWise, FullyConnected) into biases of +// convolution. +std::unique_ptr NewMergeAddWithConvolution(); + +// Modify Convolution2DAttributes so that after making convolution with +// modified attributes we will have the same result as convolution +// with old attributes and following add operation. +void FuseConvolution2DWithAdd(const AddAttributes& add_attr, + Convolution2DAttributes* attr); + +// Modify DepthwiseConvolution2DAttributes so that after making depth wise +// convolution with modified attributes we will have the same result as depth +// wise convolution with old attributes and following add operation. +void FuseDepthwiseConvolution2DWithAdd(const AddAttributes& add_attr, + DepthwiseConvolution2DAttributes* attr); + +// Modify ConvolutionTransposedAttributes so that after making convolution +// transposed with modified attributes we will have the same result as +// convolution transposed with old attributes and following add operation. +void FuseConvolutionTransposedWithAdd(const AddAttributes& add_attr, + ConvolutionTransposedAttributes* attr); + +// Modify FullyConnectedAttributes so that after making fully connected with +// modified attributes we will have the same result as fully connected +// with old attributes and following add operation. 
+void FuseFullyConnectedWithAdd(const AddAttributes& add_attr,
+                               FullyConnectedAttributes* attr);
+
+// Modify Convolution2DAttributes so that after making convolution with
+// modified attributes we will have the same result as add operation and
+// convolution with old attributes
+void FuseAddWithConvolution2D(const AddAttributes& add_attr,
+                              Convolution2DAttributes* attr);
+
+// Modify DepthwiseConvolution2DAttributes so that after making depth wise
+// convolution with modified attributes we will have the same result as add
+// operation and depth wise convolution with old attributes
+void FuseAddWithDepthwiseConvolution2D(const AddAttributes& add_attr,
+                                       DepthwiseConvolution2DAttributes* attr);
+
+// Modify FullyConnectedAttributes so that after making fully connected
+// with modified attributes we will have the same result as add operation and
+// fully connected with old attributes
+void FuseAddWithFullyConnected(const AddAttributes& add_attr,
+                               FullyConnectedAttributes* attr);
+
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_ADD_TO_CONV_H_
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc
new file mode 100644
index 00000000000..33bdb77c05a
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc
@@ -0,0 +1,281 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
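// ----------------------------------------------------------------------------
// Worked example (illustrative, not part of this patch) of the two fusion
// directions implemented in fuse_add_to_conv.cc and exercised by the tests
// below:
//
//   conv(x) + add  ==>  bias'[d] = bias[d] + add[d]          (weights unchanged)
//   conv(x + add)  ==>  bias'[d] = bias[d] +
//                         sum over s, ky, kx of w[d, ky, kx, s] * add[s]
//
// For FuseAddAfterConvolution2DTest, bias {1.1, 1.2} and add {0.3, 0.7} give
// {1.4, 1.9}, which is exactly what the test expects.
// ----------------------------------------------------------------------------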
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h" + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +using ::testing::FloatNear; +using ::testing::Pointwise; + +namespace tflite { +namespace gpu { +namespace { + +TEST(MergeConvolutionWithAddTest, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 8); + + Convolution2DAttributes conv_attr; + conv_attr.padding.prepended = HW(0, 0); + conv_attr.padding.appended = HW(0, 0); + conv_attr.strides = HW(1, 1); + conv_attr.dilations = HW(1, 1); + conv_attr.weights.shape = OHWI(16, 3, 2, 8); + conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct()); + conv_attr.bias.shape = Linear(16); + conv_attr.bias.data.resize(16); + + Tensor add_tensor; + add_tensor.shape = Linear(16); + add_tensor.data.resize(16); + AddAttributes add_attr; + add_attr.param = add_tensor; + + auto conv_node = graph.NewNode(); + conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); + conv_node->operation.attributes = conv_attr; + auto add_node = graph.NewNode(); + add_node->operation.type = ToString(OperationType::ADD); + add_node->operation.attributes = add_attr; + + ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok()); + + Value* output; + ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); + output->tensor.shape = BHWC(1, 4, 4, 16); + + Value* link1; + ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, add_node, &link1).ok()); + link1->tensor.shape = BHWC(1, 4, 4, 16); + + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + + auto transformation = NewMergeConvolutionWithAdd(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("merge_convolution_with_add", transformation.get()); + + EXPECT_EQ(1, graph.nodes().size()); + EXPECT_EQ(2, graph.values().size()); + EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D), + graph.nodes()[0]->operation.type); +} + +TEST(MergeAddWithConvolutionTest, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 8); + + Convolution2DAttributes conv_attr; + conv_attr.padding.prepended = HW(0, 0); + conv_attr.padding.appended = HW(0, 0); + conv_attr.strides = HW(1, 1); + conv_attr.dilations = HW(1, 1); + conv_attr.weights.shape = OHWI(16, 3, 2, 8); + conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct()); + conv_attr.bias.shape = Linear(16); + conv_attr.bias.data.resize(16); + + Tensor add_tensor; + add_tensor.shape = Linear(8); + add_tensor.data.resize(8); + AddAttributes add_attr; + add_attr.param = add_tensor; + + auto conv_node = graph.NewNode(); + conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); + conv_node->operation.attributes = conv_attr; + auto add_node = graph.NewNode(); + add_node->operation.type = ToString(OperationType::ADD); + add_node->operation.attributes = add_attr; + + ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok()); + + Value* output; + ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok()); + output->tensor.shape = BHWC(1, 4, 4, 16); + + Value* link1; + ASSERT_TRUE(ConnectTwoNodes(&graph, add_node, conv_node, &link1).ok()); + link1->tensor.shape = BHWC(1, 4, 4, 16); + + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + + auto transformation = 
NewMergeAddWithConvolution(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("merge_add_with_convolution", transformation.get()); + + EXPECT_EQ(1, graph.nodes().size()); + EXPECT_EQ(2, graph.values().size()); + EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D), + graph.nodes()[0]->operation.type); +} + +TEST(FuseAddAfterConvolution2DTest, Smoke) { + Convolution2DAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.1f, 1.2f}; + + Tensor add_tensor; + add_tensor.shape = Linear(2); + add_tensor.data = {0.3f, 0.7f}; + AddAttributes add_attr; + add_attr.param = add_tensor; + + FuseConvolution2DWithAdd(add_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f})); +} + +TEST(FuseAddAfterDepthwiseConvolution2DTest, Smoke) { + DepthwiseConvolution2DAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(4); + attr.bias.data = {1.1f, 1.2f, 1.3f, 1.4f}; + + Tensor add_tensor; + add_tensor.shape = Linear(4); + add_tensor.data = {0.3f, 0.7f, 0.5f, 0.1f}; + AddAttributes add_attr; + add_attr.param = add_tensor; + + FuseDepthwiseConvolution2DWithAdd(add_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f})); + EXPECT_THAT(attr.bias.data, + Pointwise(FloatNear(1e-6), {1.4f, 1.9f, 1.8f, 1.5f})); +} + +TEST(FuseAddAfterConvolutionTransposedTest, Smoke) { + ConvolutionTransposedAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.1f, 1.2f}; + + Tensor add_tensor; + add_tensor.shape = Linear(2); + add_tensor.data = {0.3f, 0.7f}; + AddAttributes add_attr; + add_attr.param = add_tensor; + + FuseConvolutionTransposedWithAdd(add_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f})); +} + +TEST(FuseAddAfterFullyConnectedTest, Smoke) { + FullyConnectedAttributes attr; + attr.weights.shape = OHWI(2, 1, 1, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.1f, 1.2f}; + + Tensor add_tensor; + add_tensor.shape = Linear(2); + add_tensor.data = {0.3f, 0.7f}; + AddAttributes add_attr; + add_attr.param = add_tensor; + + FuseFullyConnectedWithAdd(add_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), {0.1f, 0.2f, 0.3f, 0.4f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.4f, 1.9f})); +} + +TEST(FuseAddBeforeConvolution2DTest, Smoke) { + Convolution2DAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.1f, 1.2f}; + + Tensor add_tensor; + add_tensor.shape = Linear(2); + add_tensor.data = {2.0f, 0.5f}; + AddAttributes add_attr; + add_attr.param = add_tensor; + + FuseAddWithConvolution2D(add_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f})); + EXPECT_THAT(attr.bias.data, 
Pointwise(FloatNear(1e-6), {2.2f, 4.3f})); +} + +TEST(FuseAddBeforeDepthwiseConvolution2DTest, Smoke) { + DepthwiseConvolution2DAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(4); + attr.bias.data = {1.1f, 1.2f, 1.3f, 1.4f}; + + Tensor add_tensor; + add_tensor.shape = Linear(4); + add_tensor.data = {0.3f, 0.7f, 0.5f, 0.1f}; + AddAttributes add_attr; + add_attr.param = add_tensor; + + FuseAddWithDepthwiseConvolution2D(add_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f})); + EXPECT_THAT(attr.bias.data, + Pointwise(FloatNear(1e-6), {1.22f, 1.56f, 1.72f, 2.38f})); +} + +TEST(FuseAddBeforeFullyConnectedTest, Smoke) { + FullyConnectedAttributes attr; + attr.weights.shape = OHWI(2, 1, 1, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.1f, 1.2f}; + + Tensor add_tensor; + add_tensor.shape = Linear(2); + add_tensor.data = {0.5f, 2.0f}; + AddAttributes add_attr; + add_attr.param = add_tensor; + + FuseAddWithFullyConnected(add_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), {0.1f, 0.2f, 0.3f, 0.4f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.55f, 2.15f})); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc new file mode 100644 index 00000000000..3090c3f71be --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc @@ -0,0 +1,304 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
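// ----------------------------------------------------------------------------
// Sanity check (illustrative, not part of this patch) of the expected values
// in FuseAddBeforeFullyConnectedTest above, using the pre-convolution fusion
// formula bias'[d] = bias[d] + sum_s w[d, s] * add[s]:
//
//   d = 0:  1.1 + 0.1 * 0.5 + 0.2 * 2.0 = 1.55
//   d = 1:  1.2 + 0.3 * 0.5 + 0.4 * 2.0 = 2.15
//
// matching the expected bias {1.55f, 2.15f}.
// ----------------------------------------------------------------------------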
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h" + +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" + +namespace tflite { +namespace gpu { +namespace { + +class MergeConvolutionWithMul : public SequenceTransformation { + public: + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final { + auto& conv_node = *sequence[0]; + auto& mul_node = *sequence[1]; + if (mul_node.operation.type != ToString(OperationType::MUL) && + mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) { + return {TransformStatus::SKIPPED, ""}; + } + + MultiplyScalarAttributes mul_attr = + absl::any_cast(mul_node.operation.attributes); + if (!absl::get_if>( + &mul_attr.param) && + !absl::get_if(&mul_attr.param)) { + return { + TransformStatus::DECLINED, + "This fuse applicable only for broadcast or scalar multiplication."}; + } + + if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { + Convolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseConvolution2DWithMultiply(mul_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::CONVOLUTION_TRANSPOSED)) { + ConvolutionTransposedAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseConvolutionTransposedWithMultiply(mul_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::DEPTHWISE_CONVOLUTION)) { + DepthwiseConvolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseDepthwiseConvolution2DWithMultiply(mul_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::FULLY_CONNECTED)) { + FullyConnectedAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseFullyConnectedWithMultiply(mul_attr, conv_attr); + } else { + return {TransformStatus::SKIPPED, ""}; + } + + Status status = RemoveFollowingNode(graph, &mul_node, &conv_node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove mul node after convolution: " + + status.error_message()}; + } + return {TransformStatus::APPLIED, ""}; + } +}; + +class MergeMulWithConvolution : public SequenceTransformation { + public: + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final { + auto& conv_node = *sequence[1]; + auto& mul_node = *sequence[0]; + if (mul_node.operation.type != ToString(OperationType::MUL) && + mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) { + return {TransformStatus::SKIPPED, ""}; + } + + MultiplyScalarAttributes mul_attr = + absl::any_cast(mul_node.operation.attributes); + if (!absl::get_if>( + &mul_attr.param) && + !absl::get_if(&mul_attr.param)) { + return { + TransformStatus::DECLINED, + "This fuse applicable only for broadcast or scalar multiplication."}; + } + + if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { + Convolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + 
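      // Illustrative note (not part of this patch): fusing a preceding
      // multiply scales the weights along the input-channel axis,
      //   w'[o, ky, kx, i] = w[o, ky, kx, i] * mul[i],
      // whereas the post-convolution fusion above scales along the output
      // axis and also scales the bias: w'[o, ...] *= mul[o], bias'[o] *= mul[o].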
FuseMultiplyWithConvolution2D(mul_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::CONVOLUTION_TRANSPOSED)) { + ConvolutionTransposedAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseMultiplyWithConvolutionTransposed(mul_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::DEPTHWISE_CONVOLUTION)) { + DepthwiseConvolution2DAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseMultiplyWithDepthwiseConvolution2D(mul_attr, conv_attr); + } else if (conv_node.operation.type == + ToString(OperationType::FULLY_CONNECTED)) { + FullyConnectedAttributes* conv_attr = + absl::any_cast( + &conv_node.operation.attributes); + FuseMultiplyWithFullyConnected(mul_attr, conv_attr); + } else { + return {TransformStatus::SKIPPED, ""}; + } + + Status status = RemovePrecedingNode(graph, &mul_node, &conv_node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove mul node after convolution: " + + status.error_message()}; + } + return {TransformStatus::APPLIED, ""}; + } +}; + +} // namespace + +std::unique_ptr NewMergeConvolutionWithMul() { + return absl::make_unique(); +} + +std::unique_ptr NewMergeMulWithConvolution() { + return absl::make_unique(); +} + +void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr, + Convolution2DAttributes* attr) { + auto mul = absl::get_if>(&mul_attr.param); + auto mul_scalar = absl::get_if(&mul_attr.param); + for (int d = 0; d < attr->weights.shape.o; ++d) { + const float multiplier = mul ? mul->data[d] : *mul_scalar; + for (int s = 0; s < attr->weights.shape.i; ++s) { + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s}); + attr->weights.data[index] *= multiplier; + } + } + } + if (!attr->bias.data.empty()) { + attr->bias.data[d] *= multiplier; + } + } +} + +void FuseDepthwiseConvolution2DWithMultiply( + const MultiplyScalarAttributes& mul_attr, + DepthwiseConvolution2DAttributes* attr) { + auto mul = absl::get_if>(&mul_attr.param); + auto mul_scalar = absl::get_if(&mul_attr.param); + for (int g = 0; g < attr->weights.shape.o; ++g) { + for (int s = 0; s < attr->weights.shape.i; ++s) { + const int d = s * attr->weights.shape.o + g; + const float multiplier = mul ? mul->data[d] : *mul_scalar; + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s}); + attr->weights.data[index] *= multiplier; + } + } + if (!attr->bias.data.empty()) { + attr->bias.data[d] *= multiplier; + } + } + } +} + +void FuseConvolutionTransposedWithMultiply( + const MultiplyScalarAttributes& mul_attr, + ConvolutionTransposedAttributes* attr) { + auto mul = absl::get_if>(&mul_attr.param); + auto mul_scalar = absl::get_if(&mul_attr.param); + for (int d = 0; d < attr->weights.shape.o; ++d) { + const float multiplier = mul ? 
mul->data[d] : *mul_scalar; + for (int s = 0; s < attr->weights.shape.i; ++s) { + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s}); + attr->weights.data[index] *= multiplier; + } + } + } + if (!attr->bias.data.empty()) { + attr->bias.data[d] *= multiplier; + } + } +} + +void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr, + FullyConnectedAttributes* attr) { + auto mul = absl::get_if>(&mul_attr.param); + auto mul_scalar = absl::get_if(&mul_attr.param); + for (int d = 0; d < attr->weights.shape.o; ++d) { + const float multiplier = mul ? mul->data[d] : *mul_scalar; + for (int s = 0; s < attr->weights.shape.i; ++s) { + const int index = attr->weights.shape.LinearIndex({d, 0, 0, s}); + attr->weights.data[index] *= multiplier; + } + if (!attr->bias.data.empty()) { + attr->bias.data[d] *= multiplier; + } + } +} + +void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr, + Convolution2DAttributes* attr) { + auto mul = absl::get_if>(&mul_attr.param); + auto mul_scalar = absl::get_if(&mul_attr.param); + for (int s = 0; s < attr->weights.shape.i; ++s) { + const float multiplier = mul ? mul->data[s] : *mul_scalar; + for (int d = 0; d < attr->weights.shape.o; ++d) { + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s}); + attr->weights.data[index] *= multiplier; + } + } + } + } +} + +void FuseMultiplyWithDepthwiseConvolution2D( + const MultiplyScalarAttributes& mul_attr, + DepthwiseConvolution2DAttributes* attr) { + auto mul = absl::get_if>(&mul_attr.param); + auto mul_scalar = absl::get_if(&mul_attr.param); + for (int s = 0; s < attr->weights.shape.i; ++s) { + const float multiplier = mul ? mul->data[s] : *mul_scalar; + for (int g = 0; g < attr->weights.shape.o; ++g) { + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({g, k_y, k_x, s}); + attr->weights.data[index] *= multiplier; + } + } + } + } +} + +void FuseMultiplyWithConvolutionTransposed( + const MultiplyScalarAttributes& mul_attr, + ConvolutionTransposedAttributes* attr) { + auto mul = absl::get_if>(&mul_attr.param); + auto mul_scalar = absl::get_if(&mul_attr.param); + for (int s = 0; s < attr->weights.shape.i; ++s) { + const float multiplier = mul ? mul->data[s] : *mul_scalar; + for (int d = 0; d < attr->weights.shape.o; ++d) { + for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) { + for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) { + const int index = attr->weights.shape.LinearIndex({d, k_y, k_x, s}); + attr->weights.data[index] *= multiplier; + } + } + } + } +} + +void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr, + FullyConnectedAttributes* attr) { + auto mul = absl::get_if>(&mul_attr.param); + auto mul_scalar = absl::get_if(&mul_attr.param); + for (int s = 0; s < attr->weights.shape.i; ++s) { + const float multiplier = mul ? 
mul->data[s] : *mul_scalar;
+    for (int d = 0; d < attr->weights.shape.o; ++d) {
+      const int index = attr->weights.shape.LinearIndex({d, 0, 0, s});
+      attr->weights.data[index] *= multiplier;
+    }
+  }
+}
+
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h
new file mode 100644
index 00000000000..0227bfcb69c
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h
@@ -0,0 +1,93 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_MUL_TO_CONV_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_MUL_TO_CONV_H_
+
+#include <memory>
+
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+
+// Fuses a scalar or broadcast Multiply that follows a convolution
+// (Convolution2D, DepthwiseConvolution, TransposedConvolution, FullyConnected)
+// into the weights and biases of that convolution.
+std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithMul();
+
+// Fuses a scalar or broadcast Multiply that precedes a convolution
+// (Convolution2D, DepthwiseConvolution, TransposedConvolution, FullyConnected)
+// into the weights of that convolution.
+std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution();
+
+// Modifies Convolution2DAttributes so that a convolution with the modified
+// attributes produces the same result as the convolution with the old
+// attributes followed by the multiply operation.
+void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
+                                   Convolution2DAttributes* attr);
+
+// Modifies DepthwiseConvolution2DAttributes so that a depthwise convolution
+// with the modified attributes produces the same result as the depthwise
+// convolution with the old attributes followed by the multiply operation.
+void FuseDepthwiseConvolution2DWithMultiply(
+    const MultiplyScalarAttributes& mul_attr,
+    DepthwiseConvolution2DAttributes* attr);
+
+// Modifies ConvolutionTransposedAttributes so that a transposed convolution
+// with the modified attributes produces the same result as the transposed
+// convolution with the old attributes followed by the multiply operation.
+void FuseConvolutionTransposedWithMultiply(
+    const MultiplyScalarAttributes& mul_attr,
+    ConvolutionTransposedAttributes* attr);
+
+// Modifies FullyConnectedAttributes so that a fully connected layer with the
+// modified attributes produces the same result as the fully connected layer
+// with the old attributes followed by the multiply operation.
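+//
+// For example, given a per-output-channel multiplier m (or a single scalar m),
+// the fusion rescales, for every output channel o and input channel i,
+//   new_weights[o, i] = old_weights[o, i] * m[o]
+//   new_bias[o]       = old_bias[o] * m[o]
+// so that FullyConnected'(x) == FullyConnected(x) * m for any input x.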
+void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr, + FullyConnectedAttributes* attr); + +// Modify Convolution2DAttributes so that after making convolution with +// modified attributes we will have the same result as multiply operation and +// convolution with old attributes +void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr, + Convolution2DAttributes* attr); + +// Modify DepthwiseConvolution2DAttributes so that after making depth wise +// convolution with modified attributes we will have the same result as multiply +// operation and depth wise convolution with old attributes +void FuseMultiplyWithDepthwiseConvolution2D( + const MultiplyScalarAttributes& mul_attr, + DepthwiseConvolution2DAttributes* attr); + +// Modify ConvolutionTransposedAttributes so that after making convolution +// transposed with modified attributes we will have the same result as multiply +// operation and convolution transposed with old attributes +void FuseMultiplyWithConvolutionTransposed( + const MultiplyScalarAttributes& mul_attr, + ConvolutionTransposedAttributes* attr); + +// Modify FullyConnectedAttributes so that after making fully connected +// with modified attributes we will have the same result as multiply +// operation and fully connected with old attributes +void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr, + FullyConnectedAttributes* attr); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_FUSE_MUL_TO_CONV_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc new file mode 100644 index 00000000000..32c61935630 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc @@ -0,0 +1,303 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h" + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +using ::testing::FloatNear; +using ::testing::Pointwise; + +namespace tflite { +namespace gpu { +namespace { + +TEST(MergeConvolutionWithMulTest, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 8); + + Convolution2DAttributes conv_attr; + conv_attr.padding.prepended = HW(0, 0); + conv_attr.padding.appended = HW(0, 0); + conv_attr.strides = HW(1, 1); + conv_attr.dilations = HW(1, 1); + conv_attr.weights.shape = OHWI(16, 3, 2, 8); + conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct()); + conv_attr.bias.shape = Linear(16); + conv_attr.bias.data.resize(16); + + Tensor mul_tensor; + mul_tensor.shape = Linear(16); + mul_tensor.data.resize(16); + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + auto conv_node = graph.NewNode(); + conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); + conv_node->operation.attributes = conv_attr; + auto mul_node = graph.NewNode(); + mul_node->operation.type = ToString(OperationType::MUL); + mul_node->operation.attributes = mul_attr; + + ASSERT_TRUE(graph.AddConsumer(conv_node->id, input->id).ok()); + + Value* output; + ASSERT_TRUE(AddOutput(&graph, mul_node, &output).ok()); + output->tensor.shape = BHWC(1, 4, 4, 16); + + Value* link1; + ASSERT_TRUE(ConnectTwoNodes(&graph, conv_node, mul_node, &link1).ok()); + link1->tensor.shape = BHWC(1, 4, 4, 16); + + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + + auto transformation = NewMergeConvolutionWithMul(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("merge_convolution_with_mul", transformation.get()); + + EXPECT_EQ(1, graph.nodes().size()); + EXPECT_EQ(2, graph.values().size()); + EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D), + graph.nodes()[0]->operation.type); +} + +TEST(MergeMulWithConvolutionTest, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 8); + + Tensor mul_tensor; + mul_tensor.shape = Linear(8); + mul_tensor.data.resize(8); + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + Convolution2DAttributes conv_attr; + conv_attr.padding.prepended = HW(0, 0); + conv_attr.padding.appended = HW(0, 0); + conv_attr.strides = HW(1, 1); + conv_attr.dilations = HW(1, 1); + conv_attr.weights.shape = OHWI(16, 3, 2, 8); + conv_attr.weights.data.resize(conv_attr.weights.shape.DimensionsProduct()); + conv_attr.bias.shape = Linear(16); + conv_attr.bias.data.resize(16); + + auto conv_node = graph.NewNode(); + conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); + conv_node->operation.attributes = conv_attr; + auto mul_node = graph.NewNode(); + mul_node->operation.type = ToString(OperationType::MUL); + mul_node->operation.attributes = mul_attr; + + ASSERT_TRUE(graph.AddConsumer(mul_node->id, input->id).ok()); + + Value* output; + ASSERT_TRUE(AddOutput(&graph, conv_node, &output).ok()); + output->tensor.shape = BHWC(1, 4, 4, 16); + + Value* link1; + ASSERT_TRUE(ConnectTwoNodes(&graph, mul_node, conv_node, &link1).ok()); + link1->tensor.shape = BHWC(1, 4, 4, 16); + + ASSERT_EQ(2, graph.nodes().size()); + 
ASSERT_EQ(3, graph.values().size()); + + auto transformation = NewMergeMulWithConvolution(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("merge_mul_with_convolution", transformation.get()); + + EXPECT_EQ(1, graph.nodes().size()); + EXPECT_EQ(2, graph.values().size()); + EXPECT_EQ(ToString(OperationType::CONVOLUTION_2D), + graph.nodes()[0]->operation.type); +} + +TEST(FuseMulAfterConvolution2DTest, Smoke) { + Convolution2DAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.5f, 2.5f}; + + Tensor mul_tensor; + mul_tensor.shape = Linear(2); + mul_tensor.data = {0.5f, 2.0f}; + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + FuseConvolution2DWithMultiply(mul_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.05f, 0.1f, 0.15f, 0.2f, 1.0f, 1.2f, 1.4f, 1.6f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {0.75f, 5.0f})); +} + +TEST(FuseMulAfterDepthwiseConvolution2DTest, Smoke) { + DepthwiseConvolution2DAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(4); + attr.bias.data = {1.5f, 2.5f, 1.0f, 2.0f}; + + Tensor mul_tensor; + mul_tensor.shape = Linear(4); + mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f}; + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + FuseDepthwiseConvolution2DWithMultiply(mul_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.05f, 0.8f, 0.15f, 1.6f, 1.0f, 0.15f, 1.4f, 0.2f})); + EXPECT_THAT(attr.bias.data, + Pointwise(FloatNear(1e-6), {0.75f, 5.0f, 4.0f, 0.5f})); +} + +TEST(FuseMulAfterConvolutionTransposedTest, Smoke) { + ConvolutionTransposedAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.5f, 2.5f}; + + Tensor mul_tensor; + mul_tensor.shape = Linear(2); + mul_tensor.data = {0.5f, 2.0f}; + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + FuseConvolutionTransposedWithMultiply(mul_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.05f, 0.1f, 0.15f, 0.2f, 1.0f, 1.2f, 1.4f, 1.6f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {0.75f, 5.0f})); +} + +TEST(FuseMulAfterFullyConnectedTest, Smoke) { + FullyConnectedAttributes attr; + attr.weights.shape = OHWI(2, 1, 1, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.5f, 2.5f}; + + Tensor mul_tensor; + mul_tensor.shape = Linear(2); + mul_tensor.data = {0.5f, 2.0f}; + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + FuseFullyConnectedWithMultiply(mul_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), {0.05f, 0.1f, 0.6f, 0.8f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {0.75f, 5.0f})); +} + +TEST(FuseMulBeforeConvolution2DTest, Smoke) { + Convolution2DAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.5f, 2.5f}; + + Tensor mul_tensor; + mul_tensor.shape = Linear(2); + mul_tensor.data = {0.5f, 2.0f}; + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + FuseMultiplyWithConvolution2D(mul_attr, &attr); + + 
EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.05f, 0.4f, 0.15f, 0.8f, 0.25f, 1.2f, 0.35f, 1.6f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.5f, 2.5f})); +} + +TEST(FuseMulBeforeDepthwiseConvolution2DTest, Smoke) { + DepthwiseConvolution2DAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(4); + attr.bias.data = {1.5f, 2.5f, 1.0f, 2.0f}; + + Tensor mul_tensor; + mul_tensor.shape = Linear(4); + mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f}; + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + FuseMultiplyWithDepthwiseConvolution2D(mul_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.05f, 0.4f, 0.15f, 0.8f, 0.25f, 1.2f, 0.35f, 1.6f})); + EXPECT_THAT(attr.bias.data, + Pointwise(FloatNear(1e-6), {1.5f, 2.5f, 1.0f, 2.0f})); +} + +TEST(FuseMulBeforeConvolutionTransposedTest, Smoke) { + ConvolutionTransposedAttributes attr; + attr.weights.shape = OHWI(2, 1, 2, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.5f, 2.5f}; + + Tensor mul_tensor; + mul_tensor.shape = Linear(2); + mul_tensor.data = {0.5f, 2.0f}; + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + FuseMultiplyWithConvolutionTransposed(mul_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), + {0.05f, 0.4f, 0.15f, 0.8f, 0.25f, 1.2f, 0.35f, 1.6f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.5f, 2.5f})); +} + +TEST(FuseMulBeforeFullyConnectedTest, Smoke) { + FullyConnectedAttributes attr; + attr.weights.shape = OHWI(2, 1, 1, 2); + attr.weights.data = {0.1f, 0.2f, 0.3f, 0.4f}; + attr.bias.shape = Linear(2); + attr.bias.data = {1.5f, 2.5f}; + + Tensor mul_tensor; + mul_tensor.shape = Linear(2); + mul_tensor.data = {0.5f, 2.0f}; + MultiplyScalarAttributes mul_attr; + mul_attr.param = mul_tensor; + + FuseMultiplyWithFullyConnected(mul_attr, &attr); + + EXPECT_THAT(attr.weights.data, + Pointwise(FloatNear(1e-6), {0.05f, 0.4f, 0.15f, 0.8f})); + EXPECT_THAT(attr.bias.data, Pointwise(FloatNear(1e-6), {1.5f, 2.5f})); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc b/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc new file mode 100644 index 00000000000..0e1273f8d4a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" + +#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h" + +namespace tflite { +namespace gpu { + +bool ApplyGeneralTransformations(ModelTransformer* transformer) { + // whenever any of these transforms return false, that means that a graph + // is in the broken state and processing should not continue. + return transformer->Apply("remove_degenerate_upsampling", + NewRemoveDegenerateUpsampling().get()) && + transformer->Apply("remove_single_input_add", + NewRemoveSingleInputAdd().get()) && + transformer->Apply("remove_single_input_concat", + NewRemoveSingleInputConcat().get()) && + transformer->Apply("remove_identity_reshape", + NewRemoveIdentityReshape().get()) && + transformer->Apply("make_padding_from_concat", + NewMakePaddingFromConcat().get()) && + transformer->Apply("make_fully_connected_from_convolution", + NewMakeFullyConnectedFromConvolution().get()) && + transformer->Apply("merge_padding_with_convolution", + NewMergePaddingWithConvolution2D().get()) && + transformer->Apply("merge_padding_with_pooling", + NewMergePaddingWithPooling().get()) && + transformer->Apply("merge_padding_with_depthwise_convolution", + NewMergePaddingWithDepthwiseConvolution().get()) && + transformer->Apply("merge_convolution_with_mul", + NewMergeConvolutionWithMul().get()) && + transformer->Apply("merge_convolution_with_add", + NewMergeConvolutionWithAdd().get()) && + transformer->Apply("merge_mul_with_convolution", + NewMergeMulWithConvolution().get()) && + transformer->Apply("merge_add_with_convolution", + NewMergeAddWithConvolution().get()); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h b/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h new file mode 100644 index 00000000000..ffc5bba4f1a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GENERAL_TRANSFORMATIONS_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GENERAL_TRANSFORMATIONS_H_
+
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
+
+namespace tflite {
+namespace gpu {
+
+// @return false when something went wrong and left the graph in a broken state.
+bool ApplyGeneralTransformations(ModelTransformer* transformer);
+
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GENERAL_TRANSFORMATIONS_H_
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc
new file mode 100644
index 00000000000..1236cdec214
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc
@@ -0,0 +1,77 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h"
+
+#include "absl/memory/memory.h"
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+namespace {
+
+bool IsConvEquivalentToFullyConnected(const Convolution2DAttributes& attr) {
+  return attr.weights.shape.w == 1 &&           //
+         attr.weights.shape.h == 1 &&           //
+         attr.strides == HW(1, 1) &&            //
+         attr.dilations == HW(1, 1) &&          //
+         attr.padding.prepended == HW(0, 0) &&  //
+         attr.padding.appended == HW(0, 0);
+}
+
+class MakeFullyConnectedFromConvolution : public NodeTransformation {
+ public:
+  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
+    if (node->operation.type != ToString(OperationType::CONVOLUTION_2D)) {
+      return {TransformStatus::SKIPPED, ""};
+    }
+    auto inputs = graph->FindInputs(node->id);
+    if (inputs.size() != 1) {
+      return {TransformStatus::SKIPPED, ""};
+    }
+
+    const auto& input_shape = inputs[0]->tensor.shape;
+    if (input_shape.w != 1 || input_shape.h != 1) {
+      return {TransformStatus::SKIPPED, ""};
+    }
+
+    const auto& conv_attr = absl::any_cast<Convolution2DAttributes>(
+        node->operation.attributes);
+    if (!IsConvEquivalentToFullyConnected(conv_attr)) {
+      return {TransformStatus::SKIPPED, ""};
+    }
+
+    FullyConnectedAttributes fc_attr;
+    fc_attr.weights = conv_attr.weights;
+    fc_attr.bias = conv_attr.bias;
+
+    node->operation.attributes = fc_attr;
+    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
+    return {TransformStatus::APPLIED,
+            "Replaced convolution with fully connected."};
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<NodeTransformation> NewMakeFullyConnectedFromConvolution() {
+  return absl::make_unique<MakeFullyConnectedFromConvolution>();
+}
+
+}  // namespace gpu
+}  // namespace tflite
diff
--git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h new file mode 100644 index 00000000000..9a62d0bf12e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_FULLY_CONNECTED_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_FULLY_CONNECTED_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +// Turns convolution with kernel 1x1 and input tensor with h=1 and w=1 into +// fully connected operation +std::unique_ptr NewMakeFullyConnectedFromConvolution(); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_FULLY_CONNECTED_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc new file mode 100644 index 00000000000..dd5a1183fe8 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" + +#include +#include +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +namespace tflite { +namespace gpu { +namespace { + +TEST(MakeFullyConnected, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 8); + + Convolution2DAttributes attr0; + attr0.padding.prepended = HW(0, 0); + attr0.padding.appended = HW(0, 0); + attr0.strides = HW(1, 1); + attr0.dilations = HW(1, 1); + attr0.weights.shape = OHWI(16, 1, 1, 8); + attr0.bias.shape = Linear(16); + + Convolution2DAttributes attr1; + attr1.padding.prepended = HW(0, 0); + attr1.padding.appended = HW(0, 0); + attr1.strides = HW(4, 4); + attr1.dilations = HW(1, 1); + attr1.weights.shape = OHWI(16, 4, 4, 16); + attr1.bias.shape = Linear(16); + + Convolution2DAttributes attr2; + attr2.padding.prepended = HW(0, 0); + attr2.padding.appended = HW(0, 0); + attr2.strides = HW(1, 1); + attr2.dilations = HW(1, 1); + attr2.weights.shape = OHWI(32, 1, 1, 16); + attr2.bias.shape = Linear(32); + + auto conv1x1_node0 = graph.NewNode(); + conv1x1_node0->operation.type = ToString(OperationType::CONVOLUTION_2D); + conv1x1_node0->operation.attributes = attr0; + auto conv4x4_node1 = graph.NewNode(); + conv4x4_node1->operation.type = ToString(OperationType::CONVOLUTION_2D); + conv4x4_node1->operation.attributes = attr1; + auto conv1x1_node2 = graph.NewNode(); + conv1x1_node2->operation.type = ToString(OperationType::CONVOLUTION_2D); + conv1x1_node2->operation.attributes = attr2; + + ASSERT_TRUE(graph.AddConsumer(conv1x1_node0->id, input->id).ok()); + + Value* output; + ASSERT_TRUE(AddOutput(&graph, conv1x1_node2, &output).ok()); + output->tensor.shape = BHWC(1, 1, 1, 32); + + Value* link1; + ASSERT_TRUE( + ConnectTwoNodes(&graph, conv1x1_node0, conv4x4_node1, &link1).ok()); + link1->tensor.shape = BHWC(1, 4, 4, 16); + + Value* link2; + ASSERT_TRUE( + ConnectTwoNodes(&graph, conv4x4_node1, conv1x1_node2, &link2).ok()); + link2->tensor.shape = BHWC(1, 1, 1, 16); + + ASSERT_EQ(3, graph.nodes().size()); + ASSERT_EQ(4, graph.values().size()); + + auto transformation = NewMakeFullyConnectedFromConvolution(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("make_fully_connected", transformation.get()); + + ASSERT_EQ(3, graph.nodes().size()); + ASSERT_EQ(4, graph.values().size()); + ASSERT_EQ(ToString(OperationType::CONVOLUTION_2D), + graph.nodes()[0]->operation.type); + ASSERT_EQ(ToString(OperationType::CONVOLUTION_2D), + graph.nodes()[1]->operation.type); + ASSERT_EQ(ToString(OperationType::FULLY_CONNECTED), + graph.nodes()[2]->operation.type); + auto fc_attr = absl::any_cast( + graph.nodes()[2]->operation.attributes); + EXPECT_EQ(OHWI(32, 1, 1, 16), fc_attr.weights.shape); + EXPECT_EQ(Linear(32), fc_attr.bias.shape); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc new file mode 100644 index 00000000000..f1087c23fc1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" + +#include "absl/memory/memory.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace { + +bool IsConstZeros(const Node& node) { + if (node.operation.type != ToString(OperationType::CONST)) { + return false; + } + auto& attr = + absl::any_cast(node.operation.attributes); + for (auto f : attr.tensor.data) { + if (f != 0) { + return false; + } + } + return true; +} + +class MakePaddingFromZerosConcat : public NodeTransformation { + public: + TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final { + if (node->operation.type != ToString(OperationType::CONCAT)) { + return {TransformStatus::SKIPPED, ""}; + } + auto inputs = graph->FindInputs(node->id); + if (inputs.size() != 2) { + return {TransformStatus::SKIPPED, ""}; + } + + bool first = true; + for (auto input : inputs) { + auto dep = graph->FindProducer(input->id); + if (dep != nullptr && IsConstZeros(*dep)) { + auto& concat_attr = + absl::any_cast(node->operation.attributes); + PadAttributes pad_attr; + pad_attr.type = PaddingContentType::ZEROS; + pad_attr.appended = HWC(0, 0, 0); + pad_attr.prepended = HWC(0, 0, 0); + HWC* p = first ? &pad_attr.prepended : &pad_attr.appended; + switch (concat_attr.axis) { + case Axis::HEIGHT: + p->h = input->tensor.shape.h; + break; + case Axis::WIDTH: + p->w = input->tensor.shape.w; + break; + case Axis::CHANNELS: + p->c = input->tensor.shape.c; + break; + default: + return {TransformStatus::DECLINED, + "Padding for concat axis is unsupported: " + + ToString(concat_attr.axis)}; + } + Status status = RemovePrecedingNode(graph, dep, node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove const node: " + status.error_message()}; + } + node->operation.attributes = pad_attr; + node->operation.type = ToString(OperationType::PAD); + return {TransformStatus::APPLIED, "Replaced concat with padding"}; + } + first = false; + } + return {TransformStatus::SKIPPED, ""}; + } +}; + +} // namespace + +std::unique_ptr NewMakePaddingFromConcat() { + return absl::make_unique(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.h b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.h new file mode 100644 index 00000000000..c7774eb32ba --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_PADDING_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_PADDING_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +// Turns concat that handles only two tensors, where one tensor is zeros, into +// padding operation. +std::unique_ptr NewMakePaddingFromConcat(); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MAKE_PADDING_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc new file mode 100644 index 00000000000..cbfbfedcf48 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" + +#include +#include +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" + +namespace tflite { +namespace gpu { +namespace { + +TEST(MakePadding, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 2, 3, 5); + + auto concat_node = graph.NewNode(); + ASSERT_TRUE(graph.AddConsumer(concat_node->id, input->id).ok()); + concat_node->operation.type = ToString(OperationType::CONCAT); + ConcatAttributes attr; + attr.axis = Axis::HEIGHT; + concat_node->operation.attributes = attr; + + Value* output; + ASSERT_TRUE(AddOutput(&graph, concat_node, &output).ok()); + output->tensor.shape = BHWC(1, 7, 3, 5); + + auto const_node = graph.NewNode(); + const_node->operation.type = ToString(OperationType::CONST); + ConstTensorAttributes const_attr; + const_attr.tensor.shape = BHWC(1, 5, 3, 5); + const_attr.tensor.data = + std::vector(const_attr.tensor.shape.DimensionsProduct(), 0); + const_node->operation.attributes = const_attr; + + Value* const_link; + ASSERT_TRUE( + ConnectTwoNodes(&graph, const_node, concat_node, &const_link).ok()); + const_link->tensor.shape = const_attr.tensor.shape; + + ASSERT_EQ(2, graph.nodes().size()); + + auto transformation = NewMakePaddingFromConcat(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("make_padding", transformation.get()); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + auto pad_node = graph.nodes()[0]; + ASSERT_EQ(ToString(OperationType::PAD), pad_node->operation.type); + auto pad_attr = absl::any_cast(pad_node->operation.attributes); + EXPECT_EQ(HWC(0, 0, 0), pad_attr.prepended); + EXPECT_EQ(HWC(5, 0, 0), pad_attr.appended); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc new file mode 100644 index 00000000000..5257ba44f0e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.cc @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h"
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+namespace {
+
+class MatchDilatedConvolution : public SequenceTransformation {
+ public:
+  int ExpectedSequenceLength() const final { return 3; }
+
+  // TODO(eignasheva): use span instead of const reference b/131628066.
+  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
+                                       GraphFloat32* graph) final {
+    auto& sb_node = *sequence[0];
+    auto& conv_node = *sequence[1];
+    auto& bs_node = *sequence[2];
+    if (sb_node.operation.type != ToString(OperationType::SPACE_TO_BATCH) &&
+        bs_node.operation.type != ToString(OperationType::BATCH_TO_SPACE)) {
+      return {TransformStatus::SKIPPED, ""};
+    }
+    if (conv_node.operation.type !=
+            ToString(OperationType::DEPTHWISE_CONVOLUTION) &&
+        conv_node.operation.type != ToString(OperationType::CONVOLUTION_2D)) {
+      return {TransformStatus::SKIPPED, ""};
+    }
+
+    auto sb_attr =
+        absl::any_cast<SpaceToBatchAttributes>(sb_node.operation.attributes);
+
+    auto bs_attr =
+        absl::any_cast<BatchToSpaceAttributes>(bs_node.operation.attributes);
+
+    if (sb_attr.block != bs_attr.block) {
+      return {TransformStatus::INVALID, "Invalid block size"};
+    }
+
+    if (conv_node.operation.type ==
+        ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
+      auto dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
+          conv_node.operation.attributes);
+      dw_attr.padding = sb_attr.padding - bs_attr.crop;
+      dw_attr.dilations = sb_attr.block;
+      conv_node.operation.attributes = std::move(dw_attr);
+    } else {
+      auto conv2d_attr = absl::any_cast<Convolution2DAttributes>(
+          conv_node.operation.attributes);
+      conv2d_attr.padding = sb_attr.padding - bs_attr.crop;
+      conv2d_attr.dilations = sb_attr.block;
+      conv_node.operation.attributes = std::move(conv2d_attr);
+    }
+
+    Status status = RemoveFollowingNode(graph, &bs_node, &conv_node);
+    if (!status.ok()) {
+      return {TransformStatus::INVALID,
+              "Unable to remove batch_to_space node after convolution."};
+    }
+    status = RemovePrecedingNode(graph, &sb_node, &conv_node);
+    if (!status.ok()) {
+      return {TransformStatus::INVALID,
+              "Unable to remove space_to_batch node before convolution."};
+    }
+
+    return {TransformStatus::APPLIED, ""};
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<SequenceTransformation> NewMatchDilatedConvolution() {
+  return absl::make_unique<MatchDilatedConvolution>();
+}
+
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h
new file mode 100644
index 00000000000..38b87d855ec
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h
@@ -0,0 +1,35 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +// TF->TFLite converter converts convolution with dilation into the chain of +// SpaceToBatch->Convolution->BatchToSpace. Our GPU backend natively supports +// dilation in convolutions, so we try to skip this inefficiency. For more +// information see b/131436214. +std::unique_ptr NewMatchDilatedConvolution(); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution_test.cc new file mode 100644 index 00000000000..74c385b6421 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h" + +#include +#include +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +namespace tflite { +namespace gpu { +namespace { + +TEST(MatchDilatedConvolutionTest, MakesDilatedConvolution) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 95, 1, 17); + + SpaceToBatchAttributes sb_attr; + sb_attr.block = HW(128, 1); + sb_attr.padding.prepended = HW(128, 0); + sb_attr.padding.appended = HW(161, 0); + + DepthwiseConvolution2DAttributes dw_attr; + dw_attr.padding.prepended = HW(0, 0); + dw_attr.padding.appended = HW(0, 0); + dw_attr.strides = HW(1, 1); + dw_attr.dilations = HW(1, 1); + dw_attr.weights.shape = OHWI(1, 3, 1, 17); + dw_attr.bias.shape = Linear(96); + + BatchToSpaceAttributes bs_attr; + bs_attr.block = HW(128, 1); + bs_attr.crop.prepended = HW(0, 0); + bs_attr.crop.appended = HW(33, 0); + + auto sb_node = graph.NewNode(); + sb_node->operation.type = ToString(OperationType::SPACE_TO_BATCH); + sb_node->operation.attributes = sb_attr; + auto dw_node = graph.NewNode(); + dw_node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION); + dw_node->operation.attributes = dw_attr; + auto bs_node = graph.NewNode(); + bs_node->operation.type = ToString(OperationType::BATCH_TO_SPACE); + bs_node->operation.attributes = bs_attr; + + ASSERT_TRUE(graph.AddConsumer(sb_node->id, input->id).ok()); + + Value* output; + ASSERT_TRUE(AddOutput(&graph, bs_node, &output).ok()); + output->tensor.shape = BHWC(1, 95, 1, 17); + + Value* sb_link; + ASSERT_TRUE(ConnectTwoNodes(&graph, sb_node, dw_node, &sb_link).ok()); + sb_link->tensor.shape = BHWC(21, 128, 1, 17); + + Value* bs_link; + ASSERT_TRUE(ConnectTwoNodes(&graph, dw_node, bs_node, &bs_link).ok()); + bs_link->tensor.shape = BHWC(1, 95, 1, 17); + + ASSERT_EQ(graph.nodes().size(), 3); + ASSERT_EQ(graph.values().size(), 4); + + auto transformation = NewMatchDilatedConvolution(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("match_dilated_convolution", transformation.get()); + + ASSERT_EQ(graph.nodes().size(), 1); + ASSERT_EQ(graph.values().size(), 2); + ASSERT_EQ(graph.nodes()[0]->operation.type, + ToString(OperationType::DEPTHWISE_CONVOLUTION)); + + auto updated_dw_attr = absl::any_cast( + graph.nodes()[0]->operation.attributes); + EXPECT_EQ(updated_dw_attr.padding.prepended, HW(128, 0)); + EXPECT_EQ(updated_dw_attr.padding.appended, HW(128, 0)); + EXPECT_EQ(updated_dw_attr.dilations, HW(128, 1)); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/matching.h b/tensorflow/lite/delegates/gpu/common/transformations/matching.h new file mode 100644 index 00000000000..0dfd21e50ba --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/matching.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCHING_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCHING_H_
+
+// This file provides predicates to match subgraphs.
+
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+
+namespace tflite {
+namespace gpu {
+
+// Returns true if a container of nodes contains nodes that all match the given
+// operation_types.
+template <class T>
+bool MatchesByOperationType(const T& nodes,
+                            const std::vector<std::string>& types) {
+  if (nodes.size() != types.size()) return false;
+  return std::mismatch(nodes.begin(), nodes.end(), types.begin(),
+                       [&](typename T::value_type a, const std::string& b) {
+                         return a->operation.type == b;
+                       })
+             .first == nodes.end();
+}
+
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCHING_H_
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc
new file mode 100644
index 00000000000..cd0b282e1cc
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc
@@ -0,0 +1,171 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/matching.h" + +namespace tflite { +namespace gpu { +namespace { + +template +class MergePaddingWith2DOperation : public SequenceTransformation { + public: + explicit MergePaddingWith2DOperation(OperationType operation_type) + : operations_to_match_( + {ToString(OperationType::PAD), ToString(operation_type)}) {} + + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final { + if (!MatchesByOperationType(sequence, operations_to_match_)) { + return {TransformStatus::SKIPPED, ""}; + } + + Node* pad_node = sequence.front(); + Node* op_node = sequence.back(); + + PadAttributes pad_attr = + absl::any_cast(pad_node->operation.attributes); + + if (pad_attr.type != PaddingContentType::ZEROS) { + return {TransformStatus::DECLINED, "Only Zero padding is supported."}; + } + if (pad_attr.appended.c != 0 || pad_attr.prepended.c != 0) { + return {TransformStatus::DECLINED, + "Pad has non-zero padding on non HW axis."}; + } + + Attr* node_attr = absl::any_cast(&op_node->operation.attributes); + Status status = RemovePrecedingNode(graph, pad_node, op_node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove Pad node with Operation node: " + + status.error_message()}; + } + + node_attr->padding.appended.h += pad_attr.appended.h; + node_attr->padding.appended.w += pad_attr.appended.w; + node_attr->padding.prepended.h += pad_attr.prepended.h; + node_attr->padding.prepended.w += pad_attr.prepended.w; + return { + TransformStatus::APPLIED, + absl::StrCat("Added padding: prepended = {h = ", pad_attr.prepended.h, + ", w = ", pad_attr.prepended.w, "}, appended = { h = ", + pad_attr.appended.h, ", w = ", pad_attr.appended.w, "}")}; + } + + private: + const std::vector operations_to_match_; +}; + +} // namespace + +std::unique_ptr NewMergePaddingWithPooling() { + return absl::make_unique>( + OperationType::POOLING_2D); +} + +std::unique_ptr NewMergePaddingWithConvolution2D() { + return absl::make_unique< + MergePaddingWith2DOperation>( + OperationType::CONVOLUTION_2D); +} + +std::unique_ptr +NewMergePaddingWithDepthwiseConvolution() { + return absl::make_unique< + MergePaddingWith2DOperation>( + OperationType::DEPTHWISE_CONVOLUTION); +} + +class MergePaddingWithAddOperation : public NodeTransformation { + public: + TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final { + if (node->operation.type != ToString(OperationType::PAD)) { + return {TransformStatus::SKIPPED, ""}; + } + auto inputs = graph->FindInputs(node->id); + if (inputs.size() != 1) { + return {TransformStatus::SKIPPED, ""}; + } + + const auto& input_shape = graph->FindInputs(node->id)[0]->tensor.shape; + if (input_shape.c % 4 != 0) { + return {TransformStatus::DECLINED, + "Pad with input where src_channels % 4 != 0"}; + } + + PadAttributes pad_attr = + absl::any_cast(node->operation.attributes); + + if (pad_attr.type != 
PaddingContentType::ZEROS) { + return {TransformStatus::DECLINED, "Only Zero padding is supported."}; + } + if (pad_attr.prepended != HWC(0, 0, 0) || pad_attr.appended.h != 0 || + pad_attr.appended.w != 0) { + return {TransformStatus::DECLINED, + "Pad has padding not only in appended channels axis."}; + } + + auto pad_output = graph->FindOutputs(node->id)[0]; + auto consumer_nodes = graph->FindConsumers(pad_output->id); + if (consumer_nodes.size() != 1) { + return {TransformStatus::SKIPPED, ""}; + } + auto add_node = consumer_nodes[0]; + auto consumer_type = OperationTypeFromString(add_node->operation.type); + if (consumer_type != OperationType::ADD) { + return {TransformStatus::SKIPPED, ""}; + } + + AddAttributes add_attr = + absl::any_cast(add_node->operation.attributes); + auto add_broadcated_vector = + absl::get_if>(&add_attr.param); + if (add_broadcated_vector) { + return {TransformStatus::SKIPPED, + "Can not remove padding when this broadcasted ADD"}; + } + + Status status = RemovePrecedingNode(graph, node, add_node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove Pad node " + status.error_message()}; + } + + return {TransformStatus::APPLIED, + "Removed padding with zeroes in appended channels dimension"}; + } +}; + +std::unique_ptr NewMergePaddingWithAdd() { + return absl::make_unique(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h new file mode 100644 index 00000000000..d28cdfb70cd --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MERGE_PADDING_WITH_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MERGE_PADDING_WITH_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +std::unique_ptr NewMergePaddingWithPooling(); + +std::unique_ptr NewMergePaddingWithConvolution2D(); + +std::unique_ptr +NewMergePaddingWithDepthwiseConvolution(); + +// This transform requires Add operation support of unequal tensors on input. +// Padding should be with zeroes, and only appended in Z axis. +// Also input tensor channels should be divisible by 4(aligned). +// It should replace following pattern: +// 1) some tensor padded with zeroes in Z dim, for example from 24 to 32 +// channels +// 2) than this tensor used only in Add operation and Add operation +// adds this useless zeroes on 24-32 channels. +// It removes this useless addition +// by using Add with unequal tensors on input. 
Instead of filling with zeroes +// and adding this part in Add operation, Add operation makes additional check +// for this tensor: +// if (channels < src_channels) { +// result += tensor_from_pad_operation.data[index]; +// } +std::unique_ptr NewMergePaddingWithAdd(); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MERGE_PADDING_WITH_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc new file mode 100644 index 00000000000..2ba7feaf35a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h" + +#include +#include +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +namespace tflite { +namespace gpu { +namespace { + +TEST(MergePaddingWith, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + + auto pad_node = graph.NewNode(); + ASSERT_TRUE(graph.AddConsumer(pad_node->id, input->id).ok()); + pad_node->operation.type = ToString(OperationType::PAD); + PadAttributes attr; + attr.prepended = HWC(1, 1, 0); + attr.appended = HWC(2, 2, 0); + pad_node->operation.attributes = attr; + + auto conv_node = graph.NewNode(); + Value* temp; + ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node, conv_node, &temp).ok()); + ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok()); + conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); + Convolution2DAttributes conv_attr; + conv_attr.padding.appended = HW(0, 0); + conv_attr.padding.prepended = HW(0, 0); + conv_node->operation.attributes = conv_attr; + + ASSERT_EQ(2, graph.nodes().size()); + + auto transformation = NewMergePaddingWithConvolution2D(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("merge_padding", transformation.get()); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + ASSERT_EQ(conv_node, graph.nodes()[0]); + conv_attr = + absl::any_cast(conv_node->operation.attributes); + EXPECT_EQ(HW(1, 1), conv_attr.padding.prepended); + EXPECT_EQ(HW(2, 2), conv_attr.padding.appended); +} + +TEST(MergePaddingWith, MergeTwo) { + GraphFloat32 graph; + auto input = graph.NewValue(); + + auto pad_node1 = graph.NewNode(); + ASSERT_TRUE(graph.AddConsumer(pad_node1->id, input->id).ok()); + pad_node1->operation.type = ToString(OperationType::PAD); + PadAttributes attr; + attr.prepended = HWC(1, 1, 0); + attr.appended = HWC(0, 0, 0); + pad_node1->operation.attributes = attr; + + 
auto pad_node2 = graph.NewNode(); + Value* temp; + ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node1, pad_node2, &temp).ok()); + pad_node2->operation.type = ToString(OperationType::PAD); + attr.prepended = HWC(0, 0, 0); + attr.appended = HWC(2, 2, 0); + pad_node2->operation.attributes = attr; + + auto conv_node = graph.NewNode(); + ASSERT_TRUE(ConnectTwoNodes(&graph, pad_node2, conv_node, &temp).ok()); + ASSERT_TRUE(AddOutput(&graph, conv_node, &temp).ok()); + conv_node->operation.type = ToString(OperationType::CONVOLUTION_2D); + Convolution2DAttributes conv_attr; + conv_attr.padding.appended = HW(0, 0); + conv_attr.padding.prepended = HW(0, 0); + conv_node->operation.attributes = conv_attr; + + ASSERT_EQ(3, graph.nodes().size()); + + auto transformation = NewMergePaddingWithConvolution2D(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("merge_padding", transformation.get()); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + ASSERT_EQ(conv_node, graph.nodes()[0]); + conv_attr = + absl::any_cast(conv_node->operation.attributes); + EXPECT_EQ(HW(1, 1), conv_attr.padding.prepended); + EXPECT_EQ(HW(2, 2), conv_attr.padding.appended); +} + +TEST(MergePaddingWithAdd, MergeOne) { + GraphFloat32 graph; + auto input0 = graph.NewValue(); + input0->tensor.shape = BHWC(1, 4, 4, 8); + auto input1 = graph.NewValue(); + auto padded = graph.NewValue(); + auto output = graph.NewValue(); + + auto pad_node = graph.NewNode(); + pad_node->operation.type = ToString(OperationType::PAD); + PadAttributes pad_attr; + pad_attr.prepended = HWC(0, 0, 0); + pad_attr.appended = HWC(0, 0, 32); + pad_node->operation.attributes = pad_attr; + + ASSERT_TRUE(graph.AddConsumer(pad_node->id, input0->id).ok()); + ASSERT_TRUE(graph.SetProducer(pad_node->id, padded->id).ok()); + + auto add_node = graph.NewNode(); + AddAttributes add_attr; + ASSERT_TRUE(graph.AddConsumer(add_node->id, padded->id).ok()); + ASSERT_TRUE(graph.AddConsumer(add_node->id, input1->id).ok()); + ASSERT_TRUE(graph.SetProducer(add_node->id, output->id).ok()); + add_node->operation.type = ToString(OperationType::ADD); + add_node->operation.attributes = add_attr; + + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(4, graph.values().size()); + + auto transformation = NewMergePaddingWithAdd(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("merge_padding", transformation.get()); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + EXPECT_EQ(add_node, graph.nodes()[0]); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc new file mode 100644 index 00000000000..7e8727ad1c2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace { + +using ShouldRemoveOperation = std::function; + +class RemoveOperation : public SequenceTransformation { + public: + explicit RemoveOperation(ShouldRemoveOperation remove_predicate) + : remove_predicate_(std::move(remove_predicate)) {} + + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final { + Node* prev_op_node = sequence.front(); + Node* op_node = sequence.back(); + if (!remove_predicate_(graph, op_node)) { + return {TransformStatus::SKIPPED, ""}; + } + Status status = RemoveFollowingNode(graph, op_node, prev_op_node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove a node: " + status.error_message()}; + } + return {TransformStatus::APPLIED, ""}; + } + + private: + ShouldRemoveOperation remove_predicate_; +}; + +} // namespace + +std::unique_ptr NewRemoveSingleInputConcat() { + // Using SequenceTransformation implies that CONCAT has a single input. + auto type = ToString(OperationType::CONCAT); + return absl::make_unique( + [type](GraphFloat32* graph, Node* node) { + return type == node->operation.type; + }); +} + +std::unique_ptr NewRemoveSingleInputAdd() { + // Using SequenceTransformation implies that ADD has a single input. 
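+  // A matched two-node sequence guarantees this ADD consumes exactly one
+  // value: the output of the preceding node. The predicate below further
+  // requires that the attributes carry no per-channel tensor operand, so the
+  // ADD is treated as an identity and folded away (e.g. CONV_2D -> ADD
+  // collapses to just CONV_2D).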
+ auto type = ToString(OperationType::ADD); + return absl::make_unique( + [type](GraphFloat32* graph, Node* node) { + if (node->operation.type != type) { + return false; + } + auto& attr = + absl::any_cast(node->operation.attributes); + return absl::get_if>(&attr.param) == + nullptr; + }); +} + +std::unique_ptr NewRemoveDegenerateUpsampling() { + auto type = ToString(OperationType::UPSAMPLE_2D); + return absl::make_unique( + [type](GraphFloat32* graph, Node* node) { + if (node->operation.type != type) { + return false; + } + auto inputs = graph->FindInputs(node->id); + auto outputs = graph->FindOutputs(node->id); + return inputs.size() == 1 && outputs.size() == 1 && + inputs[0]->tensor.shape == outputs[0]->tensor.shape; + }); +} + +class RemoveIdentityReshape : public NodeTransformation { + public: + TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final { + if (node->operation.type != ToString(OperationType::RESHAPE)) { + return {TransformStatus::SKIPPED, ""}; + } + auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape; + const auto& reshape_attr = + absl::any_cast(node->operation.attributes); + if (input_shape != reshape_attr.new_shape) { + return {TransformStatus::SKIPPED, ""}; + } + Status status = RemoveOneInputOneOutputNode(graph, node); + if (!status.ok()) { + return {TransformStatus::INVALID, + "Unable to remove a node: " + status.error_message()}; + } + return {TransformStatus::APPLIED, + "Removed reshape with input_shape == output_shape."}; + } +}; + +std::unique_ptr NewRemoveIdentityReshape() { + return absl::make_unique(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h new file mode 100644 index 00000000000..ef1939bc24e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_REMOVE_NOOP_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_REMOVE_NOOP_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +std::unique_ptr NewRemoveSingleInputConcat(); + +std::unique_ptr NewRemoveSingleInputAdd(); + +std::unique_ptr NewRemoveDegenerateUpsampling(); + +// Removes reshape with input shape == output shape +std::unique_ptr NewRemoveIdentityReshape(); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_REMOVE_NOOP_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc new file mode 100644 index 00000000000..0c7760de18b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc @@ -0,0 +1,178 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h" + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" + +namespace tflite { +namespace gpu { +namespace { + +TEST(RemoveSingleInputAdd, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + auto first_node = graph.NewNode(); + ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); + + auto add_node = graph.NewNode(); + Value* output; + ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); + add_node->operation.type = ToString(OperationType::ADD); + add_node->operation.attributes = AddAttributes(); + + Value* temp; + ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok()); + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + + auto transformation = NewRemoveSingleInputAdd(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("noop", transformation.get()); + + EXPECT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + ASSERT_EQ(first_node, graph.nodes()[0]); + ASSERT_EQ(input, graph.values()[0]); + ASSERT_EQ(output, graph.values()[1]); +} + +TEST(RemoveSingleInputAdd, DoNotTrigger_Tensor) { + GraphFloat32 graph; + auto input = graph.NewValue(); + auto first_node = graph.NewNode(); + ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); + + auto add_node = graph.NewNode(); + Value* output; + ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); + add_node->operation.type = ToString(OperationType::ADD); + AddAttributes attr; + attr.param = Tensor(); + add_node->operation.attributes = 
attr; + + Value* temp; + ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, add_node, &temp).ok()); + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + + auto transformation = NewRemoveSingleInputAdd(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("noop", transformation.get()); + + EXPECT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); +} + +TEST(RemoveSingleInputAdd, DoNotTrigger_Multiple) { + GraphFloat32 graph; + auto input = graph.NewValue(); + auto node_a = graph.NewNode(); + auto node_b = graph.NewNode(); + ASSERT_TRUE(graph.AddConsumer(node_a->id, input->id).ok()); + ASSERT_TRUE(graph.AddConsumer(node_b->id, input->id).ok()); + + auto add_node = graph.NewNode(); + Value* output; + ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok()); + add_node->operation.type = ToString(OperationType::ADD); + + Value* temp; + ASSERT_TRUE(ConnectTwoNodes(&graph, node_a, add_node, &temp).ok()); + ASSERT_TRUE(ConnectTwoNodes(&graph, node_b, add_node, &temp).ok()); + ASSERT_EQ(3, graph.nodes().size()); + ASSERT_EQ(4, graph.values().size()); + + auto transformation = NewRemoveSingleInputAdd(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("noop", transformation.get()); + + ASSERT_EQ(3, graph.nodes().size()); + ASSERT_EQ(4, graph.values().size()); +} + +TEST(RemoveDegenerateUpsampling, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + auto first_node = graph.NewNode(); + ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); + + auto node_to_remove = graph.NewNode(); + Value* output; + ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok()); + output->tensor.shape = BHWC(1, 5, 5, 1); + node_to_remove->operation.type = ToString(OperationType::UPSAMPLE_2D); + Upsample2DAttributes attr; + attr.new_shape = HW(5, 5); + attr.type = UpsamplingType::BILINEAR; + node_to_remove->operation.attributes = attr; + + Value* link; + ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok()); + link->tensor.shape = output->tensor.shape; + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + + auto transformation = NewRemoveDegenerateUpsampling(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("noop", transformation.get()); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + EXPECT_EQ(first_node, graph.nodes()[0]); + EXPECT_EQ(input, graph.values()[0]); + EXPECT_EQ(output, graph.values()[1]); +} + +TEST(RemoveIdentityReshape, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + auto first_node = graph.NewNode(); + ASSERT_TRUE(graph.AddConsumer(first_node->id, input->id).ok()); + + auto node_to_remove = graph.NewNode(); + Value* output; + ASSERT_TRUE(AddOutput(&graph, node_to_remove, &output).ok()); + output->tensor.shape = BHWC(1, 1, 1, 11); + node_to_remove->operation.type = ToString(OperationType::RESHAPE); + ReshapeAttributes attr; + attr.new_shape = BHWC(1, 1, 1, 11); + node_to_remove->operation.attributes = attr; + + Value* link; + ASSERT_TRUE(ConnectTwoNodes(&graph, first_node, node_to_remove, &link).ok()); + link->tensor.shape = output->tensor.shape; + ASSERT_EQ(2, graph.nodes().size()); + ASSERT_EQ(3, graph.values().size()); + + auto transformation = NewRemoveIdentityReshape(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("noop", transformation.get()); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + EXPECT_EQ(first_node, 
graph.nodes()[0]); + EXPECT_EQ(input, graph.values()[0]); + EXPECT_EQ(output, graph.values()[1]); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/types.h b/tensorflow/lite/delegates/gpu/common/types.h new file mode 100644 index 00000000000..8725b4234fe --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/types.h @@ -0,0 +1,208 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TYPES_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TYPES_H_ + +#include +#include +#include +#include + +#include + +namespace tflite { +namespace gpu { + +// TODO(akulik): make these types Google-style compliant. + +using HalfBits = uint16_t; + +class alignas(2) half { + public: + HalfBits bits; + + half() = default; + + half(const half& f) : bits(f.bits) {} + + explicit half(float other) { bits = fp16_ieee_from_fp32_value(other); } + + void operator=(float f) { *this = half(f); } + + operator float() const { return fp16_ieee_to_fp32_value(bits); } +}; + +template +struct alignas(sizeof(T)) Vec4 { + union { + struct { + T x, y, z, w; + }; + std::array data_; + }; + + Vec4() : Vec4(T(0.0f)) {} + + template + Vec4(S x_, S y_, S z_, S w_) : x(x_), y(y_), z(z_), w(w_) {} + explicit Vec4(T v) : x(v), y(v), z(v), w(v) {} + + template + explicit Vec4(S v) : x(v), y(v), z(v), w(v) {} + + Vec4(const Vec4& f) : x(f.x), y(f.y), z(f.z), w(f.w) {} + + template + Vec4(const Vec4& f) : x(f.x), y(f.y), z(f.z), w(f.w) {} + + Vec4& operator=(const Vec4& other) { + x = other.x; + y = other.y; + z = other.z; + w = other.w; + return *this; + } + + static constexpr int size() { return 4; } + + T& operator[](size_t n) { return data_[n]; } + T operator[](size_t n) const { return data_[n]; } + + bool operator==(const Vec4& value) const { + return data_[0] == value[0] && data_[1] == value[1] && + data_[2] == value[2] && data_[3] == value[3]; + } + bool operator!=(const Vec4& value) const { + return !(this->operator==(value)); + } +}; + +template +struct alignas(sizeof(T)) Vec3 { + union { + struct { + T x, y, z; + }; + std::array data_; + }; + + Vec3() : Vec3(T(0.0f)) {} + + template + constexpr Vec3(S x_, S y_, S z_) : x(x_), y(y_), z(z_) {} + explicit Vec3(T v) : x(v), y(v), z(v) {} + + template + explicit Vec3(S v) : x(v), y(v), z(v) {} + + Vec3(const Vec3& f) : x(f.x), y(f.y), z(f.z) {} + + template + Vec3(const Vec3& f) : x(f.x), y(f.y), z(f.z) {} + + Vec3& operator=(const Vec3& other) { + x = other.x; + y = other.y; + z = other.z; + return *this; + } + + static constexpr int size() { return 3; } + + T& operator[](size_t n) { return data_[n]; } + T operator[](size_t n) const { return data_[n]; } + bool operator==(const Vec3& value) const { + return data_[0] == value[0] && data_[1] == value[1] && data_[2] == value[2]; + } + bool operator!=(const Vec3& value) const { + return 
!(this->operator==(value)); + } +}; + +template +struct alignas(sizeof(T)) Vec2 { + union { + struct { + T x, y; + }; + std::array data_; + }; + + Vec2() : Vec2(T(0.0f)) {} + + template + Vec2(S x_, S y_) : x(x_), y(y_) {} + explicit Vec2(T v) : x(v), y(v) {} + + template + explicit Vec2(S v) : x(v), y(v) {} + + Vec2(const Vec2& f) : x(f.x), y(f.y) {} + + template + Vec2(const Vec2& f) : x(f.x), y(f.y) {} + + Vec2& operator=(const Vec2& other) { + x = other.x; + y = other.y; + return *this; + } + + bool operator==(const Vec2& value) const { + return data_[0] == value[0] && data_[1] == value[1]; + } + + bool operator!=(const Vec2& value) const { + return !(this->operator==(value)); + } + + static constexpr int size() { return 2; } + + T& operator[](size_t n) { return data_[n]; } + T operator[](size_t n) const { return data_[n]; } +}; + +using float2 = Vec2; +using half2 = Vec2; +using byte2 = Vec2; +using ubyte2 = Vec2; +using short2 = Vec2; +using ushort2 = Vec2; +using int2 = Vec2; +using uint2 = Vec2; + +using float3 = Vec3; +using half3 = Vec3; +using byte3 = Vec3; +using ubyte3 = Vec3; +using short3 = Vec3; +using ushort3 = Vec3; +using int3 = Vec3; +using uint3 = Vec3; + +using float4 = Vec4; +using half4 = Vec4; +using byte4 = Vec4; +using ubyte4 = Vec4; +using short4 = Vec4; +using ushort4 = Vec4; +using int4 = Vec4; +using uint4 = Vec4; + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TYPES_H_ diff --git a/tensorflow/lite/delegates/gpu/common/util.h b/tensorflow/lite/delegates/gpu/common/util.h new file mode 100644 index 00000000000..168641eccf8 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/util.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_UTIL_H_ + +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { + +// @param n must be non negative +// @param divisor must be greater than zero +template +T IntegralDivideRoundUp(T n, N divisor) { + const T div = static_cast(divisor); + const T q = n / div; + return n % div == 0 ? 
q : q + 1; +} + +template <> +inline ::tflite::gpu::uint3 IntegralDivideRoundUp( + ::tflite::gpu::uint3 n, ::tflite::gpu::uint3 divisor) { + return ::tflite::gpu::uint3(IntegralDivideRoundUp(n.x, divisor.x), + IntegralDivideRoundUp(n.y, divisor.y), + IntegralDivideRoundUp(n.z, divisor.z)); +} + +// @param number or its components must be greater than zero +// @param n must be greater than zero +template +T AlignByN(T number, N n) { + return IntegralDivideRoundUp(number, n) * n; +} + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_UTIL_H_ diff --git a/tensorflow/lite/delegates/gpu/common/util_test.cc b/tensorflow/lite/delegates/gpu/common/util_test.cc new file mode 100644 index 00000000000..7c8cb81d156 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/util_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/util.h" + +#include +#include + +namespace tflite { +namespace gpu { +namespace { + +using testing::Eq; + +TEST(UtilTest, IntegralDivideRoundUp) { + EXPECT_THAT(IntegralDivideRoundUp(0, 256), Eq(0)); + EXPECT_THAT(IntegralDivideRoundUp(2u, 256), Eq(1)); + EXPECT_THAT(IntegralDivideRoundUp(2, 256), Eq(1)); + EXPECT_THAT(IntegralDivideRoundUp(255u, 256), Eq(1)); + EXPECT_THAT(IntegralDivideRoundUp(255, 256), Eq(1)); + EXPECT_THAT(IntegralDivideRoundUp(256u, 256), Eq(1)); + EXPECT_THAT(IntegralDivideRoundUp(256, 256), Eq(1)); + EXPECT_THAT(IntegralDivideRoundUp(257u, 256), Eq(2)); + EXPECT_THAT(IntegralDivideRoundUp(257, 256), Eq(2)); +} + +TEST(UtilTest, AlignByN) { + EXPECT_THAT(AlignByN(0u, 256), Eq(0)); + EXPECT_THAT(AlignByN(1u, 256), Eq(256)); + EXPECT_THAT(AlignByN(255u, 256), Eq(256)); + EXPECT_THAT(AlignByN(256u, 256), Eq(256)); + EXPECT_THAT(AlignByN(257u, 256), Eq(512)); + + EXPECT_THAT(AlignByN(1, 4), Eq(4)); + EXPECT_THAT(AlignByN(80, 4), Eq(80)); + EXPECT_THAT(AlignByN(81, 4), Eq(84)); +} + +} // namespace +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/BUILD b/tensorflow/lite/delegates/gpu/gl/BUILD new file mode 100644 index 00000000000..409581b9a87 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/BUILD @@ -0,0 +1,438 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") + +cc_library( + name = "api", + srcs = ["api.cc"], + hdrs = ["api.h"], + deps = [ + ":command_queue", + ":compiler", + ":compiler_options", + ":gl_call", + ":gpu_info", + ":node_shader", + ":object", + ":object_manager", + ":portable", + ":runtime", + ":runtime_options", + ":stats", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/lite/delegates/gpu/common:model", + 
"//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/gl/workgroups:calculator", + ] + select({ + "//tensorflow/lite/delegates/gpu:tflite_gpu_binary_release": [], + "//conditions:default": [ + ":serialization", + ], + }), +) + +cc_library( + name = "command_queue", + srcs = ["command_queue.cc"], + hdrs = ["command_queue.h"], + deps = [ + ":gl_call", + ":gl_program", + ":gl_sync", + ":gpu_info", + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/memory", + ], +) + +flatbuffer_cc_library( + name = "common_cc_fbs", + srcs = ["common.fbs"], +) + +# Generic schema for inference on GPU device. +flatbuffer_cc_library( + name = "compiled_model_cc_fbs", + srcs = ["compiled_model.fbs"], + flatc_args = [ + "--scoped-enums", + ], + includes = [ + "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs_includes", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":compiler_options", + ":float16_conversions", + ":gpu_info", + ":node_shader", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl/compiler:compiled_node", + "//tensorflow/lite/delegates/gpu/gl/compiler:fuse_auto_input", + "//tensorflow/lite/delegates/gpu/gl/compiler:fuse_inline", + "//tensorflow/lite/delegates/gpu/gl/compiler:fuse_inplace", + "//tensorflow/lite/delegates/gpu/gl/compiler:shader_code", + "//tensorflow/lite/delegates/gpu/gl/compiler:shader_codegen", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:any", + ], +) + +cc_library( + name = "compiler_options", + hdrs = ["compiler_options.h"], + deps = [ + ":gpu_info", + ":object", + ], +) + +cc_library( + name = "egl_context", + srcs = ["egl_context.cc"], + hdrs = ["egl_context.h"], + deps = [ + ":gl_call", + ":gl_errors", + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + +cc_library( + name = "egl_environment", + srcs = ["egl_environment.cc"], + hdrs = ["egl_environment.h"], + deps = [ + ":egl_context", + ":egl_surface", + ":gl_call", + ":gpu_info", + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "egl_surface", + srcs = ["egl_surface.cc"], + hdrs = ["egl_surface.h"], + deps = [ + ":gl_call", + ":gl_errors", + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + +cc_library( + name = "float16_conversions", + srcs = ["float16_conversions.cc"], + hdrs = ["float16_conversions.h"], + deps = [ + ":object", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@FP16", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "gl_buffer", + srcs = ["gl_buffer.cc"], + hdrs = ["gl_buffer.h"], + deps = [ + ":gl_call", + ":gl_errors", + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "gl_buffer_test", + srcs = ["gl_buffer_test.cc"], + linkopts = [ + "-lGLESv3", + "-lEGL", + ], + tags = [ + "local", + "nobuilder", + "notap", + 
"tflite_not_portable_ios", + ], + deps = [ + ":egl_environment", + ":gl_buffer", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "gl_call", + hdrs = ["gl_call.h"], + deps = [ + ":gl_errors", + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + +cc_library( + name = "gl_errors", + srcs = ["gl_errors.cc"], + hdrs = ["gl_errors.h"], + deps = [ + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "gl_program", + srcs = ["gl_program.cc"], + hdrs = ["gl_program.h"], + deps = [ + ":gl_call", + ":gl_errors", + ":gl_shader", + ":portable", + ":uniform_parameter", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "gl_shader", + srcs = ["gl_shader.cc"], + hdrs = ["gl_shader.h"], + deps = [ + ":gl_call", + ":gl_errors", + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + +cc_library( + name = "gl_texture", + srcs = ["gl_texture.cc"], + hdrs = ["gl_texture.h"], + deps = [ + ":gl_call", + ":gl_errors", + ":portable", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "gl_sync", + srcs = ["gl_sync.cc"], + hdrs = ["gl_sync.h"], + deps = [ + ":gl_call", + ":gl_errors", + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + +cc_library( + name = "gpu_info", + srcs = ["gpu_info.cc"], + hdrs = ["gpu_info.h"], + deps = [ + ":gl_errors", + ":portable", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/strings", + ], +) + +flatbuffer_cc_library( + name = "metadata_cc_fbs", + srcs = ["metadata.fbs"], + includes = [ + "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs_includes", + "//tensorflow/lite/delegates/gpu/gl:workgroups_cc_fbs_includes", + ], +) + +cc_library( + name = "node_shader", + hdrs = ["node_shader.h"], + deps = [ + ":compiler_options", + ":gpu_info", + ":object", + ":uniform_parameter", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + ], +) + +cc_library( + name = "object", + hdrs = ["object.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "object_manager", + srcs = ["object_manager.cc"], + hdrs = ["object_manager.h"], + deps = [ + ":gl_buffer", + ":gl_texture", + ":stats", + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "portable", + hdrs = [ + "portable_egl.h", + "portable_gl31.h", + ], +) + +cc_library( + name = "runtime", + srcs = ["runtime.cc"], + hdrs = ["runtime.h"], + deps = [ + ":command_queue", + ":gl_buffer", + ":gl_call", + ":gl_errors", + ":gl_program", + ":gl_shader", + ":gl_texture", + ":gpu_info", + ":object", + ":object_manager", + ":portable", + 
":runtime_options", + ":stats", + ":uniform_parameter", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl/runtime:shared_buffer", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "runtime_options", + hdrs = ["runtime_options.h"], +) + +cc_library( + name = "serialization", + srcs = ["serialization.cc"], + hdrs = ["serialization.h"], + deps = [ + ":common_cc_fbs", + ":compiled_model_cc_fbs", + ":object", + ":uniform_parameter", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@flatbuffers", + ], +) + +cc_test( + name = "serialization_test", + srcs = ["serialization_test.cc"], + tags = [ + "local", + "nobuilder", + "notap", + "tflite_not_portable_ios", + ], + deps = [ + ":object", + ":serialization", + ":uniform_parameter", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "stats", + hdrs = ["stats.h"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "uniform_parameter", + hdrs = ["uniform_parameter.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/types:variant", + ], +) + +flatbuffer_cc_library( + name = "workgroups_cc_fbs", + srcs = ["workgroups.fbs"], + includes = [ + "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs_includes", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/lite/delegates/gpu/gl/api.cc b/tensorflow/lite/delegates/gpu/gl/api.cc new file mode 100644 index 00000000000..8dea7235b13 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/api.cc @@ -0,0 +1,418 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/api.h" + +#include +#include +#include +#include // NOLINT +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" +#include "tensorflow/lite/delegates/gpu/gl/runtime.h" + +#ifndef TFLITE_GPU_BINARY_RELEASE +#include "tensorflow/lite/delegates/gpu/gl/serialization.h" +#endif // TFLITE_GPU_BINARY_RELEASE + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +using ObjectsSizes = std::unordered_map; + +enum class InferenceContextState { + NOT_STARTED, + IN_PROGRESS, +}; + +class InferenceContextImpl : public InferenceContext { + public: + explicit InferenceContextImpl(std::unique_ptr runtime) + : runtime_(std::move(runtime)) {} + + Status Execute() final { + std::lock_guard lock(guard_); + if (state_ != InferenceContextState::NOT_STARTED) { + return FailedPreconditionError("InferenceContext is not reset"); + } + state_ = InferenceContextState::IN_PROGRESS; + return runtime_->Execute(); + } + + Status Reset() final { + std::lock_guard lock(guard_); + // TODO(akulik): should Reset not return Status? + state_ = InferenceContextState::NOT_STARTED; + return OkStatus(); + } + + RuntimeStats stats() const final { return runtime_->stats(); } + + private: + std::unique_ptr runtime_; + + mutable std::mutex guard_; + InferenceContextState state_ = InferenceContextState::NOT_STARTED; +}; + +class InferenceContextWithBatchImpl : public InferenceContext { + public: + InferenceContextWithBatchImpl(const ObjectsSizes& sizes, + const ObjectManager* objects, + std::unique_ptr refs, + std::unique_ptr runtime) + : sizes_(sizes), + objects_(objects), + refs_(std::move(refs)), + runtime_(std::move(runtime)) {} + + Status Execute() final { + std::lock_guard lock(guard_); + if (state_ != InferenceContextState::NOT_STARTED) { + return FailedPreconditionError("InferenceContext is not reset"); + } + state_ = InferenceContextState::IN_PROGRESS; + + // Calculate expected number of batches and check that all external objects + // match that number. + int num_batches = 0; + for (const auto& s : sizes_) { + const ValueId id = s.first; + const size_t byte_size = s.second; + + auto buffer = objects_->FindBuffer(id); + if (!buffer) continue; + + if (buffer->bytes_size() % byte_size) { + return InvalidArgumentError(absl::StrCat( + "Object ", id, " does not match expected byte size: ", byte_size)); + } + + const size_t b = buffer->bytes_size() / byte_size; + if (num_batches == 0) { + num_batches = b; + } else if (num_batches != b) { + return InvalidArgumentError(absl::StrCat( + "Object ", id, " size does not match expected batch size: ", b, + " vs ", num_batches)); + } + } + + for (size_t b = 0; b < num_batches; ++b) { + // slice external objects by batch. 
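+      // Each external buffer is expected to hold num_batches tensors laid
+      // out back to back; the MakeView call below re-points the per-batch
+      // reference buffer at the b-th slice before the programs are executed.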
+ for (const auto& s : sizes_) { + const ValueId id = s.first; + const size_t byte_size = s.second; + auto buffer = objects_->FindBuffer(id); + if (buffer) { + auto ref = refs_->FindBuffer(id); + if (!ref) { + return InvalidArgumentError( + absl::StrCat("Reference to ", id, " is not found")); + } + RETURN_IF_ERROR(buffer->MakeView(b * byte_size, byte_size, ref)); + } + } + RETURN_IF_ERROR(runtime_->Execute()); + } + return OkStatus(); + } + + Status Reset() final { + std::lock_guard lock(guard_); + state_ = InferenceContextState::NOT_STARTED; + // TODO(akulik): should Reset not return Status? + return OkStatus(); + } + + RuntimeStats stats() const final { return runtime_->stats(); } + + private: + const ObjectsSizes sizes_; + const ObjectManager* objects_; + + // view over external objects provided by a user. + std::unique_ptr refs_; + std::unique_ptr runtime_; + + mutable std::mutex guard_; + InferenceContextState state_ = InferenceContextState::NOT_STARTED; +}; + +struct ProgramParameters { + // A list of uniform parameters to be set. + std::vector parameters; + + // A list of objects to bind to opengl program. + std::vector objects; + + uint3 workgroup_size; + uint3 num_workgroups; + + size_t shader_idx; +}; + +std::string GetShaderHeader(uint3 localsize) { + return absl::StrCat("#version 310 es\nlayout(local_size_x = ", localsize.x, + ", local_size_y = ", localsize.y, + ", local_size_z = ", localsize.z, ") in;\n"); +} + +class CompiledModelImpl +#ifndef TFLITE_GPU_BINARY_RELEASE + : public CompiledModel, + public DeserializationHandler { +#else + : public CompiledModel { +#endif // TFLITE_GPU_BINARY_RELEASE + public: + explicit CompiledModelImpl(const GpuInfo& gpu_info) : gpu_info_(gpu_info) {} + + // Called while compiling shaders from scratch + Status Add(const WorkgroupsCalculator& workgroup_calculator, + ShaderCode code) { + // Calculate workgroup size. + uint3 workgroup_size = workgroup_calculator.Calculate(code); + uint3 num_workgroups = IntegralDivideRoundUp(code.workload, workgroup_size); + + for (const auto& object : code.objects) { + if (IsRef(object)) { + object_sizes_[GetRef(object)] = ByteSizeOf(object); + } + } + + // Store full shader and compile it if necessary. + size_t shader_idx; + RETURN_IF_ERROR( + AddFullShader(code.source_code, workgroup_size, &shader_idx)); + programs_.push_back({ + std::move(code.parameters), + std::move(code.objects), + workgroup_size, + num_workgroups, + shader_idx, + }); + return OkStatus(); + } + + // Store full shader and compile it if necessary. + // Returns full_shader_index + Status AddFullShader(const std::string& partial_shader, + const uint3& workgroup_size, size_t* size) { + std::string shader_src = GetShaderHeader(workgroup_size) + partial_shader; + auto it = shader_to_index_.find(shader_src); + if (it == shader_to_index_.end()) { + GlShader shader; + RETURN_IF_ERROR( + GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src, &shader)); + shaders_.push_back(std::move(shader)); + shader_to_index_.insert({shader_src, shader_to_index_.size()}); + *size = shader_to_index_.size() - 1; + } else { + *size = it->second; + } + return OkStatus(); + } + + Status NewRun( + const RuntimeOptions& options, const ObjectManager* objects, + CommandQueue* command_queue, + std::unique_ptr* inference_context) const final { + std::unique_ptr refs; + if (dynamic_batch_) { + // Runtime is using objects from refs that will point to provided objects. + // At this point just create 0 batch slice references. 
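+      // These zero-offset views are only placeholders; for every batch,
+      // InferenceContextWithBatchImpl::Execute() re-points each view at the
+      // corresponding slice of the user-provided buffer.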
+ refs = absl::make_unique(); + for (const auto& s : object_sizes_) { + auto buffer = objects->FindBuffer(s.first); + if (!buffer) continue; + GlBuffer ref; + RETURN_IF_ERROR(buffer->MakeView(0, s.second, &ref)); + RETURN_IF_ERROR(refs->RegisterBuffer(s.first, std::move(ref))); + } + } + auto runtime = absl::make_unique(options, gpu_info_, command_queue, + refs ? refs.get() : objects); + for (auto& c : programs_) { + RETURN_IF_ERROR(runtime->AddProgram(shaders_[c.shader_idx], c.parameters, + c.objects, c.num_workgroups)); + } + RETURN_IF_ERROR(runtime->PrepareForExecution()); + if (dynamic_batch_) { + *inference_context = absl::make_unique( + object_sizes_, objects, std::move(refs), std::move(runtime)); + } else { + *inference_context = + absl::make_unique(std::move(runtime)); + } + return OkStatus(); + } + +#ifndef TFLITE_GPU_BINARY_RELEASE + // Called on deserialization + Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, const uint3& num_workgroups, + size_t partial_shader_index) final { + for (auto& object : objects) { + if (IsRef(object)) { + object_sizes_[GetRef(object)] = ByteSizeOf(object); + } + } + + size_t shader_idx; + RETURN_IF_ERROR(AddFullShader(partial_shaders_[partial_shader_index], + workgroup_size, &shader_idx)); + programs_.push_back({ + parameters, + objects, + workgroup_size, + num_workgroups, + shader_idx, + }); + return OkStatus(); + } + + Status Serialize( + std::vector* serialized_compiled_model) const final { + SerializedCompiledModelBuilder builder; + + // sort shaders first. They need to be serialized in order. + std::vector full_shaders(shaders_.size()); + for (const auto& shader : shader_to_index_) { + full_shaders[shader.second] = shader.first; + } + + std::unordered_map partial_shader_to_index; + std::vector partial_shaders; + for (const auto& program : programs_) { + // Remove a header from a shader. + std::string shader_without_header = full_shaders[program.shader_idx]; + shader_without_header.erase(0, shader_without_header.find("in;") + 3); + + // Insert shader into partial shaders array. + auto it = partial_shader_to_index.find(shader_without_header); + size_t shader_idx; + if (it == partial_shader_to_index.end()) { + shader_idx = partial_shaders.size(); + partial_shaders.push_back(shader_without_header); + builder.AddShader(shader_without_header); + partial_shader_to_index.insert({shader_without_header, shader_idx}); + } else { + shader_idx = it->second; + } + builder.AddProgram(program.parameters, program.objects, + program.workgroup_size, program.num_workgroups, + shader_idx); + } + CompiledModelOptions options; + options.dynamic_batch = dynamic_batch_; + auto data = builder.Finalize(options); + serialized_compiled_model->insert(serialized_compiled_model->end(), + data.begin(), data.end()); + return OkStatus(); + } + + Status OnShader(absl::Span shader_src) final { + std::string source(shader_src.data(), shader_src.size()); + partial_shaders_.push_back(source); + return OkStatus(); + } + + void OnOptions(const CompiledModelOptions& options) final { + dynamic_batch_ = options.dynamic_batch; + } +#endif // TFLITE_GPU_BINARY_RELEASE + + CompilerStats stats() const final { return stats_; } + + void set_dynamic_batch(bool dynamic_batch) { dynamic_batch_ = dynamic_batch; } + + private: + const GpuInfo gpu_info_; + bool dynamic_batch_ = false; + + std::vector partial_shaders_; + std::vector shaders_; + + // Shaders are serialized in order of their indices. 
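+  // Maps a full shader source (header + body) to its index in shaders_, so
+  // identical shaders are compiled once and shared between programs.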
+ std::unordered_map shader_to_index_; + std::deque programs_; + std::unordered_map object_sizes_; + CompilerStats stats_; +}; + +// @return true if all tensors have same batch value. +bool IsBatchMatchesForAllValues(const GraphFloat32& model) { + const int32_t b = model.values()[0]->tensor.shape.b; + for (auto value : model.values()) { + if (value->tensor.shape.b != b) { + return false; + } + } + return true; +} + +} // namespace + +Status Compile(const CompilationOptions& options, const GraphFloat32& model, + const NodeShader& node_shader, + const WorkgroupsCalculator& workgroup_calculator, + std::unique_ptr* compiled_model) { + if (!IsBatchMatchesForAllValues(model)) { + return InvalidArgumentError("Only identical batch dimension is supported"); + } + GpuInfo gpu_info; + RETURN_IF_ERROR(RequestGpuInfo(&gpu_info)); + auto compiled_model_impl = absl::make_unique(gpu_info); + compiled_model_impl->set_dynamic_batch(options.dynamic_batch); + auto compiler = NewCompiler(&node_shader, &gpu_info, options); + RETURN_IF_ERROR(compiler->Compile(model, [&](ShaderCode code) -> Status { + return compiled_model_impl->Add(workgroup_calculator, std::move(code)); + })); + *compiled_model = std::move(compiled_model_impl); + return OkStatus(); +} + +#ifndef TFLITE_GPU_BINARY_RELEASE +Status ReadSerializedModel(const std::vector& serialized_model, + std::unique_ptr* compiled_model) { + GpuInfo gpu_info; + RETURN_IF_ERROR(RequestGpuInfo(&gpu_info)); + auto compiled_model_impl = absl::make_unique(gpu_info); + RETURN_IF_ERROR(DeserializeCompiledModel( + absl::MakeConstSpan(serialized_model), compiled_model_impl.get())); + *compiled_model = std::move(compiled_model_impl); + return OkStatus(); +} +#endif // TFLITE_GPU_BINARY_RELEASE + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/api.h b/tensorflow/lite/delegates/gpu/gl/api.h new file mode 100644 index 00000000000..3258d162110 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/api.h @@ -0,0 +1,103 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_API_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_API_H_ + +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/command_queue.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/object_manager.h" +#include "tensorflow/lite/delegates/gpu/gl/runtime_options.h" +#include "tensorflow/lite/delegates/gpu/gl/stats.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" + +namespace tflite { +namespace gpu { +namespace gl { + +class InferenceContext; + +// Represents a model that was prepared for execution. 
It is stored in a format +// most suitable for execution and optionally may include pre-generated or +// pre-compiled GPU shaders or whatever is needed for efficient execution. +class CompiledModel { + public: + virtual ~CompiledModel() = default; + + virtual CompilerStats stats() const = 0; + + // Creates new inference context. Result can outlive @this. + // + // NewRun call as well as subsequent calls to InferenceContext methods should + // be done from the same EGL context. + virtual Status NewRun( + const RuntimeOptions& options, const ObjectManager* objects, + CommandQueue* command_queue, + std::unique_ptr* inference_context) const = 0; + +#ifndef TFLITE_GPU_BINARY_RELEASE + // Serializes compiled model to a string. + // @return true if serialization finished successfully. + virtual Status Serialize( + std::vector* serialized_compiled_model) const = 0; +#endif // TFLITE_GPU_BINARY_RELEASE +}; + +// Turns the given model into "compiled" form that is suitable for inference. +Status Compile(const CompilationOptions& options, const GraphFloat32& model, + const NodeShader& node_shader, + const WorkgroupsCalculator& workgroup_calculator, + std::unique_ptr* compiled_model); + +#ifndef TFLITE_GPU_BINARY_RELEASE +// Reads serialized representation previously created with +// CompiledModel::Serialize call. +Status ReadSerializedModel(const std::vector& serialized_model, + std::unique_ptr* compiled_model); +#endif // TFLITE_GPU_BINARY_RELEASE + +// Encapsulates everything needed for one or more inference executions done +// sequentially. +// +// Thread-safe. +class InferenceContext { + public: + virtual ~InferenceContext() = default; + + virtual RuntimeStats stats() const = 0; + + // Executes inference. + virtual Status Execute() = 0; + + // Asks context to reset it for another round. Keep in mind that does not + // affect inputs nor outputs which are not cleared, so it is possible to + // re-use them. + // It is an error to call Reset while previous run is still in progress. + virtual Status Reset() = 0; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_API_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.cc b/tensorflow/lite/delegates/gpu/gl/command_queue.cc new file mode 100644 index 00000000000..8e0e085da28 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/command_queue.cc @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/command_queue.h" + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_sync.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class DefaultCommandQueue : public CommandQueue { + public: + Status Dispatch(const GlProgram& program, const uint3& workgroups) override { + RETURN_IF_ERROR(program.Dispatch(workgroups)); + return TFLITE_GPU_CALL_GL(glMemoryBarrier, GL_ALL_BARRIER_BITS); + } + + Status WaitForCompletion() override { + // TODO(akulik): may be let a user to choose what wait method to use. + return GlActiveSyncWait(); + } +}; + +// On Adreno do flush periodically as this affects performance. Command queue +// needs to be manually managed to ensure that accumulated work goes to GPU as +// fast as it can. +// +// Also, on older Adreno devices glFlush is required after every memory barrier +// to avoid hitting GPU driver bug. +class AdrenoCommandQueue : public DefaultCommandQueue { + public: + explicit AdrenoCommandQueue(int flush_every_n) + : flush_every_n_(flush_every_n) {} + + Status Dispatch(const GlProgram& program, const uint3& workgroups) final { + RETURN_IF_ERROR(DefaultCommandQueue::Dispatch(program, workgroups)); + if ((++program_counter_ % flush_every_n_) == 0) { + glFlush(); + } + return OkStatus(); + } + + private: + const int flush_every_n_; + int program_counter_ = 0; +}; + +} // namespace + +std::unique_ptr NewCommandQueue(const GpuInfo& gpu_info) { + if (gpu_info.type == GpuType::ADRENO) { + int flush_every_n = 1; + // On Adreno 630 and Adreno 505 there is up to 2x performance boost when + // glFlush happens not so often. + if (gpu_info.gpu_model == GpuModel::ADRENO630 || + gpu_info.gpu_model == GpuModel::ADRENO505) { + flush_every_n = 10; + } + return absl::make_unique(flush_every_n); + } + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.h b/tensorflow/lite/delegates/gpu/gl/command_queue.h new file mode 100644 index 00000000000..bf313b495a3 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/command_queue.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMMAND_QUEUE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMMAND_QUEUE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// GL programs can be executed directly via dispatch call or using a queue +// abstraction similar to one in OpenCL and Vulkan. +// CommandQueue executes given programs in order as they come. +class CommandQueue { + public: + virtual ~CommandQueue() = default; + + // Dispatches a program. It may or may not call glFlush. + virtual Status Dispatch(const GlProgram& program, + const uint3& workgroups) = 0; + + // Waits until all programs dispatched prior this call are completed. + virtual Status WaitForCompletion() = 0; +}; + +// By default memory barrier is inserted after every dispatch. +std::unique_ptr NewCommandQueue(const GpuInfo& gpu_info); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMMAND_QUEUE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/common.fbs b/tensorflow/lite/delegates/gpu/gl/common.fbs new file mode 100644 index 00000000000..b07123455ac --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/common.fbs @@ -0,0 +1,16 @@ +namespace tflite.gpu.gl.data; + +table Uint3 { + x:uint32; + y:uint32; + z:uint32; +} + +table Uint2 { + x:uint32; + y:uint32; +} + +table Uint1 { + x:uint32; +} diff --git a/tensorflow/lite/delegates/gpu/gl/compiled_model.fbs b/tensorflow/lite/delegates/gpu/gl/compiled_model.fbs new file mode 100644 index 00000000000..1b47bf4d20b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiled_model.fbs @@ -0,0 +1,155 @@ +include "common.fbs"; + +namespace tflite.gpu.gl.data; + +file_identifier "AFCM"; + +file_extension "flow"; + +// Encapsulates entire OpenGL program with all necessary dependencies and +// parameters. +table Program { + // A collection of objects this program refers to. + objects:[Object]; + + // Uniform parameters to be set before execution. + parameters:[UniformParameter]; + + // Defines the number of work groups. + number_workgroups:Uint3; + + // Defines the size of a workgroup. + workgroup_size:Uint3; + + // Reference to a shader in this compiled model. + shader_index:uint32; + + // Contains binary code that was once created after successful shader + // compilation. Normally it is much faster to instantiate a program from + // compiled binary. + binary:ProgramBinary; +} + +// Compiled binary representation of a program. +table ProgramBinary { + format:uint32; // GLenum + + // Compiled binary shader blob extracted from GL. + binary:[ubyte]; +} + +enum ParameterType : byte { + INT32 = 0, + UINT32 = 1, + FLOAT32 = 2, + INT32_2 = 3, +} + +enum DataType : byte { + UNKNOWN = 0, + FLOAT32 = 1, + FLOAT16 = 2, + INT32 = 3, + INT16 = 4, +} + +union DataVariant { + DataInt32, + DataFloat, + DataUint32, +} + +table DataFloat { + data:[float]; +} + +table DataInt32 { + data:[int32]; +} + +table DataUint32 { + data:[uint32]; +} + +table UniformParameter { + name:string; + + type:ParameterType; + + // Data is optional. If it is known in advance, it is encoded here, otherwise + // a parameter will be set in runtime. 
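+  // (Illustrative note) A value that is fixed at compile time can be stored
+  // here as DataInt32, DataUint32 or DataFloat; a value that is only known at
+  // runtime leaves this field unset and is bound before execution.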
+ data:DataVariant; +} + +enum AccessType : byte { + READ = 0, + WRITE = 1, + READ_WRITE = 2, +} + +enum ObjectType : byte { + UNKNOWN = 0, + BUFFER = 1, + TEXTURE = 2, +} + +union ObjectVariant { + ObjectData, + ObjectRef, +} + +union ObjectSize { + Uint1, + Uint2, + Uint3, +} + +table Object { + access:AccessType; + + binding:uint32; + + data_type:DataType; + + type:ObjectType; + + size:ObjectSize; + + object:ObjectVariant; +} + +// Represents a reference to another object provided by object manager. +table ObjectRef { + // Unique global identifier to be used by an object manager to lookup this + // buffer. + global_id:uint32; +} + +table ObjectData { + data:[uint8]; +} + +// Represents entire model as a collection of programs, inputs and outputs. +table CompiledModel { + parameters:Parameters; + + // A collection of shaders used by programs. + shaders:[string]; + + // A collection of programs that need to be executed in the same order. + programs:[Program]; +} + +table Parameters { + // indicated flow engine version that compiled this model. If engine version + // does not match compiled model, then a model need to be recompiled. + // version:uint32; // not implemented + + // Could potentially be used to track environment when a model was compiled + // and detect whether it was changed and model recompilation is needed. + // environment_hash:uint32; // not implemented + + dynamic_batch:bool; +} + +root_type CompiledModel; diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.cc b/tensorflow/lite/delegates/gpu/gl/compiler.cc new file mode 100644 index 00000000000..c6d5cf5b370 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler.cc @@ -0,0 +1,293 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h" +#include "tensorflow/lite/delegates/gpu/gl/float16_conversions.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +struct ExceedSizeChecker { + bool operator()(uint32_t v) const { return v > max_size; } + + bool operator()(const uint2& v) const { + return v.x > max_size || v.y > max_size; + } + + bool operator()(const uint3& v) const { + return v.x > max_size || v.y > max_size || v.z > max_z_size; + } + + int max_size; + int max_z_size; +}; + +// Returns true if any size variable exceeds the given limit +bool ExceedsMaxSize(const Object& object, const GpuInfo& gpu_info) { + return absl::visit(ExceedSizeChecker{gpu_info.max_texture_size, + gpu_info.max_array_texture_layers}, + object.size); +} + +ObjectType ChooseFastestObjectType(const GpuInfo& gpu_info) { + return gpu_info.type == GpuType::ADRENO ? ObjectType::TEXTURE + : ObjectType::BUFFER; +} + +ObjectType ChooseFastestRefObjectType(const GpuInfo& gpu_info, + const CompilationOptions& options) { + if (gpu_info.type != GpuType::ADRENO) { + return ObjectType::BUFFER; + } + switch (gpu_info.gpu_model) { + case GpuModel::ADRENO630: + return ObjectType::TEXTURE; + default: + return options.allow_precision_loss ? ObjectType::TEXTURE + : ObjectType::BUFFER; + } +} + +// Compiler executes the following steps: +// 1. Runs NodeShader for every node in the input graph. +// 2. Creates a compiled graph that mirrors the input graph and keeps +// GeneratedCode in operation's attributes. +// 3. Fuses nodes in the compiled graph. +// 4. Generates the full shader code using the nodes in the compiled graph. +class CompilerImpl : public Compiler { + public: + // We use const GpuInfo* because it doesn't let you assign temporary object + CompilerImpl(const NodeShader* node_shader, const GpuInfo* gpu_info, + const CompilationOptions& options) + : node_shader_(*node_shader), gpu_info_(*gpu_info), options_(options) { + if (options_.preferred_obj_type == ObjectType::UNKNOWN) { + options_.preferred_obj_type = ChooseFastestObjectType(*gpu_info); + } + if (options_.ref_obj_type == ObjectType::UNKNOWN) { + options_.ref_obj_type = ChooseFastestRefObjectType(*gpu_info, options); + } + } + + Status Compile(const GraphFloat32& graph, + const ShaderCodeCallback& callback) final { + // It is important to have ids in a compiled graph identical to the given + // graph. + RETURN_IF_ERROR(graph.MakeExactCopy(&compiled_graph_)); + + // Clear out batch dimension for dynamic batch support. + if (options_.dynamic_batch) { + for (auto value : compiled_graph_.values()) { + value->tensor.shape.b = 1; + } + } + + // Generate a shader for a node and all input/output objects. 
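+    // Each NodeShader invocation fills a GeneratedCode value with the shader
+    // source plus its objects, uniform parameters and workload/workgroup
+    // hints; storing it in the node attributes lets the fusion passes below
+    // inspect, rewrite and merge it.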
+ for (auto node : compiled_graph_.nodes()) { + CompiledNodeAttributes attr; + attr.node_indices.push_back(node->id); + RETURN_IF_ERROR(node_shader_.GenerateCode( + {&compiled_graph_, &gpu_info_, node, options_}, &attr.code)); + node->operation.attributes = std::move(attr); + } + + ModelTransformer transformer(&compiled_graph_, nullptr); + if (options_.fuse_operations) { + FuseAutoOutputWithInline fuse_inline; + if (!transformer.Apply("fuse_auto_with_inline", &fuse_inline)) { + return InternalError("fuse_auto_with_inline failed"); + } + FuseInplaceUpdate fuse_inplace; + if (!transformer.Apply("fuse_inplace_update", &fuse_inplace)) { + return InternalError("fuse_inplace failed"); + } + if (options_.auto_input_fusion) { + FuseAutoInput fuse_auto_input; + if (!transformer.Apply("fuse_auto_input", &fuse_auto_input)) { + return InternalError("fuse_auto_input failed"); + } + } + } + RemoveUnusedInplaceUpdates remove_inplace_updates; + if (!transformer.Apply("remove_inplace_updates", &remove_inplace_updates)) { + return InternalError("remove_inplace_updates failed"); + } + + // Prepare internal objects. + std::unordered_map objects; + for (auto value : compiled_graph_.values()) { + Object object = MakePHWC4Ref(value->id, value->tensor.shape); + object.data_type = value->tensor.type; + // External references may not be upgraded to f16 nor be represented as + // textures. + const bool is_external = + graph.IsGraphInput(value->id) || graph.IsGraphOutput(value->id); + if (is_external) { + object.object_type = ObjectType::BUFFER; + } else if (options_.allow_precision_loss) { + MaybeConvertToFloat16(&object); + } + objects[value->id] = std::move(object); + } + + // Prepare readonly objects and check whether object types are supported. + for (auto node : compiled_graph_.nodes()) { + auto& attr = + absl::any_cast(node->operation.attributes); + + // Set workload explicitly. + if (attr.code.workload == uint3()) { + auto outputs = compiled_graph_.FindOutputs(node->id); + auto shape = outputs[0]->tensor.shape; + for (auto output : outputs) { + if (shape != output->tensor.shape) { + return FailedPreconditionError( + "Workload uint3() requires all output sizes to match"); + } + } + attr.code.workload = + uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)); + } + + int num_textures = 0; + // Counts number of used textures and chooses ObjectType for an object. + auto set_object_type = [&](Object* object) { + if (object->object_type == ObjectType::BUFFER) { + // Don't change from buffer once it is set. + return; + } + bool is_ref = IsRef(*object); + if (num_textures < gpu_info_.max_image_units && + !ExceedsMaxSize(*object, gpu_info_) && + (object->object_type == ObjectType::TEXTURE || + (is_ref && options_.ref_obj_type == ObjectType::TEXTURE) || + (!is_ref && options_.preferred_obj_type == ObjectType::TEXTURE))) { + object->object_type = ObjectType::TEXTURE; + num_textures++; + } else { + object->object_type = ObjectType::BUFFER; + } + }; + + for (auto& object : attr.code.objects) { + // Downgrade readonly objects to F16 is requested. + if (options_.allow_precision_loss) { + MaybeConvertToFloat16(&object.second); + } + set_object_type(&object.second); + } + + for (auto ref : compiled_graph_.FindInputs(node->id)) { + set_object_type(&objects[ref->id]); + } + for (auto ref : compiled_graph_.FindOutputs(node->id)) { + set_object_type(&objects[ref->id]); + } + } + + // Generate shaders from the transformed graph. 
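+    // Roughly (an illustrative sketch, not literal codegen output), a node
+    // whose source is "value_0 = max(value_0, 0.0);" with AUTO input/output
+    // becomes:
+    //   vec4 value_0 = $input_data_0[gid.x, gid.y, gid.z]$;
+    //   value_0 = max(value_0, 0.0);
+    //   $output_data_0[gid.x, gid.y, gid.z] = value_0$;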
+ ShaderCodegen codegen(options_, gpu_info_); + for (auto node : compiled_graph_.nodes()) { + auto& attr = + absl::any_cast(node->operation.attributes); + if (attr.code.source_code.empty()) { + // noop. Skip this node. + continue; + } + + // Declare inputs and outputs explicitly. + for (auto ref : compiled_graph_.FindInputs(node->id)) { + auto object = objects[ref->id]; + object.access = AccessType::READ; + attr.inputs.push_back(object); + } + for (auto ref : compiled_graph_.FindOutputs(node->id)) { + auto object = objects[ref->id]; + object.access = AccessType::WRITE; + attr.outputs.push_back(object); + } + + // Allocate bindings. Textures must be bound first. max_image_units also + // defines max binding number for a texture. + uint32_t binding = 0; + auto set_binding = [&](ObjectType type, Object& object) { + if (object.object_type == type) { + object.binding = binding++; + } + }; + for (auto& object : attr.inputs) { + set_binding(ObjectType::TEXTURE, object); + } + for (auto& object : attr.outputs) { + set_binding(ObjectType::TEXTURE, object); + } + for (auto& object : attr.code.objects) { + set_binding(ObjectType::TEXTURE, object.second); + } + for (auto& object : attr.inputs) { + set_binding(ObjectType::BUFFER, object); + } + for (auto& object : attr.outputs) { + set_binding(ObjectType::BUFFER, object); + } + for (auto& object : attr.code.objects) { + set_binding(ObjectType::BUFFER, object.second); + } + + // Generate source code. + ShaderCode shader_code; + RETURN_IF_ERROR(codegen.Build(std::move(attr), &shader_code)); + RETURN_IF_ERROR(callback(std::move(shader_code))); + } + return OkStatus(); + } + + private: + const NodeShader& node_shader_; + const GpuInfo& gpu_info_; + CompilationOptions options_; + GraphFloat32 compiled_graph_; +}; + +} // namespace + +std::unique_ptr NewCompiler(const NodeShader* node_shader, + const GpuInfo* gpu_info, + const CompilationOptions& options) { + return absl::make_unique(node_shader, gpu_info, options); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler.h b/tensorflow/lite/delegates/gpu/gl/compiler.h new file mode 100644 index 00000000000..b0f1f452610 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +using ShaderCodeCallback = std::function; + +class Compiler { + public: + virtual ~Compiler() = default; + + // Goes over a graph and generates OpenGL shaders for the given graph. + // Callback is called for every generated shader. Callback may execute shaders + // as they come or store them elsewhere to execute later. + virtual Status Compile(const GraphFloat32& graph, + const ShaderCodeCallback& callback) = 0; +}; + +std::unique_ptr NewCompiler( + const NodeShader* node_shader, const GpuInfo* gpu_info, + const CompilationOptions& options); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD new file mode 100644 index 00000000000..da200306a76 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD @@ -0,0 +1,198 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") + +cc_library( + name = "preprocessor", + srcs = ["preprocessor.cc"], + hdrs = ["preprocessor.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "preprocessor_test", + srcs = ["preprocessor_test.cc"], + tags = [ + "local", + "tflite_not_portable_ios", + ], + deps = [ + ":preprocessor", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "parameter_accessor", + srcs = ["parameter_accessor.cc"], + hdrs = ["parameter_accessor.h"], + deps = [ + ":preprocessor", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:uniform_parameter", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "parameter_accessor_test", + srcs = ["parameter_accessor_test.cc"], + tags = [ + "local", + "tflite_not_portable_ios", + ], + deps = [ + ":parameter_accessor", + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "object_accessor", + srcs = ["object_accessor.cc"], + hdrs = ["object_accessor.h"], + deps = [ + ":parameter_accessor", + ":preprocessor", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:object", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "object_accessor_test", + srcs = ["object_accessor_test.cc"], + tags = [ + "local", + ], + deps = [ + ":object_accessor", + ":parameter_accessor", + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/types:variant", + 
"@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "shader_code", + hdrs = ["shader_code.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:object", + "//tensorflow/lite/delegates/gpu/gl:uniform_parameter", + ], +) + +cc_library( + name = "shader_codegen", + srcs = ["shader_codegen.cc"], + hdrs = ["shader_codegen.h"], + deps = [ + ":compiled_node", + ":object_accessor", + ":parameter_accessor", + ":preprocessor", + ":shader_code", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/gl:compiler_options", + "//tensorflow/lite/delegates/gpu/gl:gpu_info", + "//tensorflow/lite/delegates/gpu/gl:object", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "compiled_node", + srcs = ["compiled_node.cc"], + hdrs = ["compiled_node.h"], + deps = [ + ":rename", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "//tensorflow/lite/delegates/gpu/gl:object", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "fuse_inplace", + srcs = ["fuse_inplace.cc"], + hdrs = ["fuse_inplace.h"], + deps = [ + ":compiled_node", + ":preprocessor", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + ], +) + +cc_library( + name = "fuse_inline", + srcs = ["fuse_inline.cc"], + hdrs = ["fuse_inline.h"], + deps = [ + ":compiled_node", + ":shader_code", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + ], +) + +cc_library( + name = "rename", + srcs = ["rename.cc"], + hdrs = ["rename.h"], + deps = [ + ":object_accessor", + ":parameter_accessor", + ":preprocessor", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "//tensorflow/lite/delegates/gpu/gl:object", + "//tensorflow/lite/delegates/gpu/gl:uniform_parameter", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "fuse_auto_input", + srcs = ["fuse_auto_input.cc"], + hdrs = ["fuse_auto_input.h"], + deps = [ + ":compiled_node", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:any", + "@com_google_absl//absl/types:variant", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc new file mode 100644 index 00000000000..40584738922 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" + +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h" + +namespace tflite { +namespace gpu { +namespace gl { + +Status MergeCode(CompiledNodeAttributes* attr, + CompiledNodeAttributes* merged_attr) { + // build a map of known names. + std::unordered_set known_names; + for (const auto& parameter : merged_attr->code.parameters) { + known_names.insert(parameter.name); + } + for (const auto& object : merged_attr->code.objects) { + known_names.insert(object.first); + } + + // Rewrite parameters with unique names. + int index = + merged_attr->code.parameters.size() + merged_attr->code.objects.size(); + RETURN_IF_ERROR(Rename( + [&](absl::string_view name) -> std::string { + std::string n(name.begin(), name.end()); + // if a name is unique, then keep it as is. Otherwise append an unique + // index. + if (known_names.find(n) == known_names.end()) { + return n; + } + return absl::StrCat(n, index++); + }, + &attr->code)); + std::move(attr->code.objects.begin(), attr->code.objects.end(), + std::back_inserter(merged_attr->code.objects)); + std::move(attr->code.parameters.begin(), attr->code.parameters.end(), + std::back_inserter(merged_attr->code.parameters)); + std::move(attr->node_indices.begin(), attr->node_indices.end(), + std::back_inserter(merged_attr->node_indices)); + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h new file mode 100644 index 00000000000..d41a734f4e2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_COMPILED_NODE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_COMPILED_NODE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Contains compiler internal attributes for each node after it was processed by +// NodeShader. +struct CompiledNodeAttributes { + std::vector inputs; + std::vector outputs; + + GeneratedCode code; + + // nodes that are covered by the provided shader. + std::vector node_indices; +}; + +// Moves all code objects, parameters and node indices from attr to merged_attr. +// Parameters and objects in attr.code.source_code are renamed to ensure +// uniqueness. +Status MergeCode(CompiledNodeAttributes* attr, + CompiledNodeAttributes* merged_attr); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_COMPILED_NODE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc new file mode 100644 index 00000000000..045c4feef78 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc @@ -0,0 +1,233 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/types/any.h" +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +std::pair MakeValueReplacement(int n, int k) { + return {absl::StrCat("value_", n), absl::StrCat("value_", k)}; +} + +std::pair MakeDataReplacement(int n, int k) { + return {absl::StrCat("input_data_", n), absl::StrCat("input_data_", k)}; +} + +} // namespace + +TransformResult FuseAutoInput::ApplyToNode(Node* node, GraphFloat32* graph) { + auto& node_attr = + absl::any_cast(node->operation.attributes); + auto& node_code = node_attr.code; + + if (node_code.input != IOStructure::AUTO) { + return {TransformStatus::SKIPPED, ""}; + } + uint3 workgroup = node_code.workgroup; + + auto node_outputs = graph->FindOutputs(node->id); + + // Check which inputs could be fused into the current node. 
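+  // An input qualifies for fusion only when its producer has a single output
+  // that is consumed solely by this node, that output is marked
+  // IOStructure::AUTO, and the producer's workload matches this node's (or is
+  // empty); at most one fused producer may carry an explicit workgroup. See
+  // the checks below.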
+ std::vector> nodes_to_fuse; + std::vector> input_values; + int input_num = -1; + for (auto input_value : graph->FindInputs(node->id)) { + input_num++; + const ValueId input_id = input_value->id; + input_values.push_back({input_id, input_num}); + + if (graph->FindConsumers(input_id).size() > 1) { + continue; // input is consumed by >1 nodes + } + Node* input_producer = graph->FindProducer(input_id); + if (input_producer == nullptr) { + continue; // graph's input + } + if (graph->FindOutputs(input_producer->id).size() != 1) { + continue; // input node has more than one output + } + auto& input_producer_attr = absl::any_cast( + input_producer->operation.attributes); + if (input_producer_attr.code.output != IOStructure::AUTO) { + continue; + } + if (input_producer_attr.code.workload != node_code.workload && + uint3() != input_producer_attr.code.workload) { + continue; + } + if (input_producer_attr.code.workgroup != uint3()) { + // New fused node should fuse only a single shader that has pre-defined + // workgroup. Such shader is considered "heavy". Do not fuse two heavy + // shaders into one. + // TODO(eignasheva): make sure it still works. + if (workgroup != uint3()) { + continue; + } + workgroup = input_producer_attr.code.workgroup; + } + nodes_to_fuse.push_back({input_producer, input_num}); + input_values.pop_back(); // this value will not be used as input. + } + if (nodes_to_fuse.empty()) { + return {TransformStatus::SKIPPED, ""}; + } + + // Break connections between current node and its inputs. + for (auto value : graph->FindInputs(node->id)) { + if (!graph->RemoveConsumer(node->id, value->id).ok()) { + return {TransformStatus::INVALID, ""}; + } + } + + std::string operation_type; + std::string source_code; + std::string values; + + // Node source code need to be appended later to the end. + std::swap(source_code, node_code.source_code); + + // Indicates value_k that is beyond originally declared [0..n] values, + // therefore, it can be used by newly added dependencies. + int extra_input_num = input_num; + input_num = 0; + + // Fuse all nodes into one. + for (auto input_and_num : nodes_to_fuse) { + auto& input = input_and_num.first; + auto& attr = + absl::any_cast(input->operation.attributes); + auto super_inputs = graph->FindInputs(input->id); + + // Replace all internal references in the input source code. For example: + // source code "value_0 = max(0, value_0);" will be rewritten into + // "value_2 = max(0, value_2);" + std::vector> replacements; + for (int i = 0; i < super_inputs.size(); ++i) { + // Node source code uses value_N to access output value from the fused + // node. Use correct reference. + // + // Here value_N does not correspond to input_N anymore. Instead it tracks + // value_n and input_m independently. Value_index uses an index needed + // for the "final" shader, while input_num preserves the order of inputs. + // For example: + // Shader A: input_0, input_1 + // value_0 = value_0 > value_1 ? value_0 : value_1; + // + // Shader B: input_0 + // value_0 = max(0, value_0); + // + // AddShader: input_0, input_1 + // value_0 = value_0 + value_1; + // + // Fused shader is going to have 3 inputs: input_0 (A), input_1 (A), + // input_2 (B). But Shader B need to store result in value_1, because + // AddShader refers to it as 'value_1'. So, fused shader will look as + // follows: + // + // // Shader A + // vec4 value_0 = input_data_0.data[gid.x, gid.y, gid.z]; + // vec4 value_2 = input_data_1.data[gid.x, gid.y, gid.z]; + // value_0 = value_0 > value_2 ? 
value_0 : value_2; + // + // // Shader B + // vec4 value_1 = input_data_2.data[gid.x, gid.y, gid.z]; + // value_1 = max(0, value_1); + // + // // AddShader + // value_0 = value_0 + value_1; + // + // output_data_0.data[gid.x, gid.y, gid.z] = value_0; + int value_index = i == 0 ? input_and_num.second : ++extra_input_num; + replacements.push_back(MakeValueReplacement(i, value_index)); + replacements.push_back(MakeDataReplacement(i, input_num)); + + // Declare input values based on the input structure of the merged node. + // This code copies what shader_codegen would do automatically. + if (attr.code.input == IOStructure::AUTO) { + absl::StrAppend(&values, " value_", value_index, " = $input_data_", + input_num, "[gid.x, gid.y, gid.z]$;\n"); + } + + if (!graph->AddConsumer(node->id, super_inputs[i]->id).ok()) { + return {TransformStatus::INVALID, ""}; + } + input_num++; + } + + // Also rename all _h and _w parameters to the new names. + for (auto& param : attr.code.parameters) { + param.name = absl::StrReplaceAll(param.name, replacements); + } + attr.code.source_code = + absl::StrReplaceAll(attr.code.source_code, replacements); + + // Merge all objects, parameters and source code. + if (!MergeCode(&attr, &node_attr).ok()) { + return {TransformStatus::INVALID, "Unable to merge the code"}; + } + absl::StrAppend(&node_attr.code.source_code, "{\n", attr.code.source_code, + "\n}"); + + if (!operation_type.empty()) { + operation_type += ","; + } + operation_type += input->operation.type; + + if (!graph->DeleteNode(input->id).ok()) { + return {TransformStatus::INVALID, ""}; + } + } + + // Add back all inputs that are used directly by the fused node. + for (int i = 0; i < input_values.size(); i++) { + if (node_code.input == IOStructure::AUTO) { + absl::StrAppend(&values, " value_", input_values[i].second, + " = $input_data_", input_num, + "[gid.x, gid.y, gid.z]$;\n"); + } + if (!graph->AddConsumer(node->id, input_values[i].first).ok()) { + return {TransformStatus::INVALID, ""}; + } + input_num++; + } + + node_code.input = IOStructure::ONLY_DEFINITIONS; + + absl::StrAppend(&node->operation.type, "(", operation_type, ")"); + node_code.source_code = + absl::StrCat(values, node_code.source_code, "{//FUSED", + node->operation.type, "\n", source_code, "\n}"); + + return {TransformStatus::APPLIED, ""}; +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h new file mode 100644 index 00000000000..ff5ac5b8488 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_AUTO_INPUT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_AUTO_INPUT_H_ + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Fuses nodes that have auto output with auto input node using the following +// rules. +// +// Source graph: +// A B C +// \ | / +// D +// +// - A, B and C each have a single output marked as AUTO +// - Each output is used only by D +// - D has all inputs marked as AUTO +// +// Result: in the best case a single node that does (A,B,C)+D operations. +// +class FuseAutoInput : public NodeTransformation { + public: + TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_AUTO_INPUT_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc new file mode 100644 index 00000000000..64fab74c7ca --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc @@ -0,0 +1,78 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +TransformResult FuseAutoOutputWithInline::ApplyToNodesSequence( + const std::vector& sequence, GraphFloat32* graph) { + Node* node1 = sequence.front(); + Node* node2 = sequence.back(); + auto& attr1 = + absl::any_cast(node1->operation.attributes); + auto& attr2 = + absl::any_cast(node2->operation.attributes); + + if (attr1.code.output != IOStructure::AUTO || + graph->FindInputs(node2->id).size() != 1 || + graph->FindOutputs(node2->id).size() != 1 || + attr2.code.output != IOStructure::AUTO || + attr2.code.input != IOStructure::AUTO || + (attr1.code.workload != attr2.code.workload && + uint3() != attr2.code.workload) || + graph->FindOutputs(node1->id).size() != + graph->FindInputs(node2->id).size()) { + return {TransformStatus::SKIPPED, ""}; + } + + // Check if the code was not fused yet, and wrap source code into {}. 
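+  // A '+' in the operation type means node1 already went through a fusion
+  // (the type is extended with "+<type>" below), so its source code is
+  // already wrapped in braces and must not be wrapped again.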
+ if (node1->operation.type.find('+') == std::string::npos) { + attr1.code.source_code = + absl::StrCat("\n{\n", attr1.code.source_code, "\n}\n"); + } + if (!MergeCode(&attr2, &attr1).ok()) { + return {TransformStatus::INVALID, "Unable to merge two nodes"}; + } + absl::StrAppend(&attr1.code.source_code, "{\n", attr2.code.source_code, + "\n}"); + node1->operation.type += "+" + node2->operation.type; + + if (!RemoveFollowingNode(graph, node2, node1).ok()) { + return {TransformStatus::INVALID, + "Unable to remove node " + std::to_string(node2->id)}; + } + return {TransformStatus::APPLIED, ""}; +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.h b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.h new file mode 100644 index 00000000000..09e2cc52712 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.h @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INLINE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INLINE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Fuses every two nodes where first node does default output and second node +// is INLINE. +// +// Generates code as follows: +// 1. all uniforms are inlined +// 2. source code is wrapped into {} +// For example: +// value = clamp(value, 0.0, clip); +// + +// value = 1.0 / (1.0 + exp(-1.0 * value)); +// will turn into: +// { +// value = clamp(value, 0.0, clip); +// } +// { +// value = 1.0 / (1.0 + exp(-1.0 * value)); +// } +class FuseAutoOutputWithInline : public SequenceTransformation { + public: + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INLINE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc new file mode 100644 index 00000000000..e6f0902f054 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc @@ -0,0 +1,151 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +static const char* kInplacePrefix = "inplace_update:\0"; + +class EmptyInplaceRewrite : public InlineRewrite { + public: + RewriteStatus Rewrite(absl::string_view input, std::string* output) final { + if (input.compare(0, strlen(kInplacePrefix), kInplacePrefix) == 0) { + num_rewrites_++; + return RewriteStatus::SUCCESS; + } + return RewriteStatus::NOT_RECOGNIZED; + } + + int num_rewrites() const { return num_rewrites_; } + + private: + int num_rewrites_ = 0; +}; + +// Takes a code as an input. Replaces 'value_0' in the code with a value that +// comes in a rewrite. For example: +// code: value_0 = max(value_0, 0); +// rewrite: inplace_update:result_12 -> result_12 = max(result_12, 0); +// +class InplaceCodeRewrite : public InlineRewrite { + public: + explicit InplaceCodeRewrite(const std::string& code) : code_(code) {} + + RewriteStatus Rewrite(absl::string_view input, std::string* output) final { + int len = strlen(kInplacePrefix); + if (input.compare(0, len, kInplacePrefix) == 0) { + auto variable_name = input.substr(len); + absl::StrAppend(output, + absl::StrReplaceAll(code_, {{"value_0", variable_name}})); + return RewriteStatus::SUCCESS; + } + return RewriteStatus::NOT_RECOGNIZED; + } + + private: + std::string code_; +}; + +} // namespace + +TransformResult RemoveUnusedInplaceUpdates::ApplyToNode(Node* node, + GraphFloat32* graph) { + auto& attr = + absl::any_cast(node->operation.attributes); + // Remove inplace block by rewriting to empty string. + EmptyInplaceRewrite rewrite; + TextPreprocessor preprocessor('$', true); + preprocessor.AddRewrite(&rewrite); + if (!preprocessor.Rewrite(attr.code.source_code, &attr.code.source_code) + .ok()) { + return {TransformStatus::INVALID, ""}; + } + return {rewrite.num_rewrites() > 0 ? TransformStatus::APPLIED + : TransformStatus::SKIPPED, + ""}; +} + +TransformResult FuseInplaceUpdate::ApplyToNodesSequence( + const std::vector& sequence, GraphFloat32* graph) { + Node* node1 = sequence.front(); + Node* node2 = sequence.back(); + auto& attr1 = + absl::any_cast(node1->operation.attributes); + auto& attr2 = + absl::any_cast(node2->operation.attributes); + + if (graph->FindInputs(node2->id).size() != 1 || + graph->FindOutputs(node2->id).size() != 1 || + attr2.code.output != IOStructure::AUTO || + attr2.code.input != IOStructure::AUTO || + (attr1.code.workload != attr2.code.workload && + uint3() != attr2.code.workload)) { + return {TransformStatus::SKIPPED, ""}; + } + + // First count of replaces that would happen to check whether rewrite is + // needed. 
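+  // The dry run below rewrites node1's source into a throwaway string; the
+  // merge proceeds only if at least one $inplace_update:...$ marker is found.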
+ { + EmptyInplaceRewrite counting_rewrite; + TextPreprocessor preprocessor('$', true); + preprocessor.AddRewrite(&counting_rewrite); + std::string temp; + if (!preprocessor.Rewrite(attr1.code.source_code, &temp).ok()) { + return {TransformStatus::INVALID, ""}; + } + // no rewrites in the source code. skip it. + if (counting_rewrite.num_rewrites() == 0) { + return {TransformStatus::SKIPPED, ""}; + } + } + if (!MergeCode(&attr2, &attr1).ok()) { + return {TransformStatus::INVALID, "Unable to merge two nodes"}; + } + TextPreprocessor preprocessor('$', true); + InplaceCodeRewrite rewrite(attr2.code.source_code); + preprocessor.AddRewrite(&rewrite); + if (!preprocessor.Rewrite(attr1.code.source_code, &attr1.code.source_code) + .ok()) { + return {TransformStatus::INVALID, ""}; + } + node1->operation.type += "+" + node2->operation.type; + + if (!RemoveFollowingNode(graph, node2, node1).ok()) { + return {TransformStatus::INVALID, + "Unable to remove node " + std::to_string(node2->id)}; + } + return {TransformStatus::APPLIED, ""}; +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h new file mode 100644 index 00000000000..7b334d27a49 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.h @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INPLACE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INPLACE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Fuse two shaders where second shader is inline shader with the first. +// First shader should have a special symbol that defines a place where such +// fusion should be made and what variable needs to be changed. +// Second shader needs to operation with 'value_0' variable. +// Example: +// +// First shader: +// vec4 result = input_data_0.data[gid.x, gid.y, gid.z]; +// $inplace_update:result$ +// ... +// output_data_0.data[1,2,3] = result; +// +// Second shader: +// value_0 = max(value_0, 0); +// +// Fused shader: +// vec4 result = input_data_0.data[gid.x, gid.y, gid.z]; +// result = max(result, 0); +// ... +// output_data_0.data[1,2,3] = result; +// +class FuseInplaceUpdate : public SequenceTransformation { + public: + int ExpectedSequenceLength() const final { return 2; } + + TransformResult ApplyToNodesSequence(const std::vector& sequence, + GraphFloat32* graph) final; +}; + +// Removes all %inplace_update:XXX% strings from the code. 
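+// For example, a node that was never fused may still carry a marker such as
+// "$inplace_update:result$" (the '$'-delimited form consumed by
+// TextPreprocessor); this transformation rewrites it to an empty string so
+// the generated shader stays valid.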
+class RemoveUnusedInplaceUpdates : public NodeTransformation { + public: + TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_FUSE_INPLACE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc new file mode 100644 index 00000000000..e6dbb4fea2a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc @@ -0,0 +1,546 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h" + +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace object_accessor_internal { + +// Splits name[index1, index2...] into 'name' and {'index1', 'index2'...}. 
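+// For example (illustrative), ParseElement("data_0[i,j]") yields object_name
+// "data_0" and indices {"i", "j"}; input without '[' or a trailing ']' yields
+// an empty element.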
+IndexedElement ParseElement(absl::string_view input) { + auto i = input.find('['); + if (i == std::string::npos || input.back() != ']') { + return {}; + } + return {input.substr(0, i), + absl::StrSplit(input.substr(i + 1, input.size() - i - 2), ',', + absl::SkipWhitespace())}; +} + +} // namespace object_accessor_internal + +namespace { + +void MaybeConvertToHalf(DataType data_type, absl::string_view value, + std::string* output) { + if (data_type == DataType::FLOAT16) { + absl::StrAppend(output, "Vec4ToHalf(", value, ")"); + } else { + absl::StrAppend(output, value); + } +} + +void MaybeConvertFromHalf(DataType data_type, absl::string_view value, + std::string* output) { + if (data_type == DataType::FLOAT16) { + absl::StrAppend(output, "Vec4FromHalf(", value, ")"); + } else { + absl::StrAppend(output, value); + } +} + +struct ReadFromTextureGenerator { + RewriteStatus operator()(uint32_t) const { + if (element.indices.size() != 1) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + // 1D textures are emulated as 2D textures + absl::StrAppend(result, "imageLoad(", element.object_name, ", ivec2(", + element.indices[0], ", 0))"); + return RewriteStatus::SUCCESS; + } + + template + RewriteStatus operator()(const Shape&) const { + if (element.indices.size() != Shape::size()) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + absl::StrAppend(result, "imageLoad(", element.object_name, ", ivec", + Shape::size(), "(", absl::StrJoin(element.indices, ", "), + "))"); + return RewriteStatus::SUCCESS; + } + + const object_accessor_internal::IndexedElement& element; + std::string* result; +}; + +struct ReadFromBufferGenerator { + RewriteStatus operator()(uint32_t) const { + if (element.indices.size() != 1) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + MaybeConvertFromHalf( + data_type, + absl::StrCat(element.object_name, ".data[", element.indices[0], "]"), + result); + return RewriteStatus::SUCCESS; + } + + RewriteStatus operator()(const uint2& size) const { + if (element.indices.size() == 1) { + // access by linear index. Use method above to generate accessor. + return (*this)(1U); + } + if (element.indices.size() != 2) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + MaybeConvertFromHalf( + data_type, + absl::StrCat(element.object_name, ".data[", element.indices[0], " + $", + element.object_name, "_w$ * (", element.indices[1], ")]"), + result); + *requires_sizes = true; + return RewriteStatus::SUCCESS; + } + + RewriteStatus operator()(const uint3& size) const { + if (element.indices.size() == 1) { + // access by linear index. Use method above to generate accessor. + return (*this)(1U); + } + if (element.indices.size() != 3) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + MaybeConvertFromHalf( + data_type, + absl::StrCat(element.object_name, ".data[", element.indices[0], " + $", + element.object_name, "_w$ * (", element.indices[1], " + $", + element.object_name, "_h$ * (", element.indices[2], "))]"), + result); + *requires_sizes = true; + return RewriteStatus::SUCCESS; + } + + DataType data_type; + const object_accessor_internal::IndexedElement& element; + std::string* result; + + // indicates that generated code accessed _w and/or _h index variables. + bool* requires_sizes; +}; + +// Generates code for reading an element from an object. 
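+// For example (names are illustrative), reading src[x,y] from a 2D FLOAT32
+// buffer produces "src.data[x + $src_w$ * (y)]" (wrapped in Vec4FromHalf() for
+// FLOAT16 data), while the same read from a texture becomes
+// "imageLoad(src, ivec2(x, y))".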
+RewriteStatus GenerateReadAccessor( + const Object& object, + const object_accessor_internal::IndexedElement& element, + std::string* result, bool* requires_sizes) { + switch (object.object_type) { + case ObjectType::BUFFER: + return absl::visit(ReadFromBufferGenerator{object.data_type, element, + result, requires_sizes}, + object.size); + case ObjectType::TEXTURE: + return absl::visit(ReadFromTextureGenerator{element, result}, + object.size); + case ObjectType::UNKNOWN: + return RewriteStatus::ERROR; + } +} + +struct WriteToBufferGenerator { + RewriteStatus operator()(uint32_t) const { + if (element.indices.size() != 1) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + absl::StrAppend(result, element.object_name, ".data[", element.indices[0], + "] = "); + MaybeConvertToHalf(data_type, value, result); + return RewriteStatus::SUCCESS; + } + + RewriteStatus operator()(const uint2& size) const { + if (element.indices.size() == 1) { + // access by linear index. Use method above to generate accessor. + return (*this)(1U); + } + if (element.indices.size() != 2) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + absl::StrAppend(result, element.object_name, ".data[", element.indices[0], + " + $", element.object_name, "_w$ * (", element.indices[1], + ")] = "); + MaybeConvertToHalf(data_type, value, result); + *requires_sizes = true; + return RewriteStatus::SUCCESS; + } + + RewriteStatus operator()(const uint3& size) const { + if (element.indices.size() == 1) { + // access by linear index. Use method above to generate accessor. + return (*this)(1U); + } + if (element.indices.size() != 3) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + absl::StrAppend(result, element.object_name, ".data[", element.indices[0], + " + $", element.object_name, "_w$ * (", element.indices[1], + " + $", element.object_name, "_h$ * (", element.indices[2], + "))] = "); + MaybeConvertToHalf(data_type, value, result); + *requires_sizes = true; + return RewriteStatus::SUCCESS; + } + + DataType data_type; + const object_accessor_internal::IndexedElement& element; + absl::string_view value; + std::string* result; + + // indicates that generated code accessed _w and/or _h index variables. + bool* requires_sizes; +}; + +struct WriteToTextureGenerator { + RewriteStatus operator()(uint32_t) const { + if (element.indices.size() != 1) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + // 1D textures are emulated as 2D textures + absl::StrAppend(result, "imageStore(", element.object_name, ", ivec2(", + element.indices[0], ", 0), ", value, ")"); + return RewriteStatus::SUCCESS; + } + + template + RewriteStatus operator()(const Shape&) const { + if (element.indices.size() != Shape::size()) { + result->append("WRONG_NUMBER_OF_INDICES"); + return RewriteStatus::ERROR; + } + absl::StrAppend(result, "imageStore(", element.object_name, ", ivec", + Shape::size(), "(", absl::StrJoin(element.indices, ", "), + "), ", value, ")"); + return RewriteStatus::SUCCESS; + } + + const object_accessor_internal::IndexedElement& element; + absl::string_view value; + std::string* result; +}; + +// Generates code for writing value an element in an object. 
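+// For example (illustrative): obj[i] = value on a 1D BUFFER is expected to
+// lower to obj.data[i] = value, with the value wrapped in Vec4ToHalf(...)
+// when the object's data type is FLOAT16; TEXTURE writes lower to
+// imageStore(...) calls.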
+RewriteStatus GenerateWriteAccessor( + const Object& object, + const object_accessor_internal::IndexedElement& element, + absl::string_view value, std::string* result, bool* requires_sizes) { + switch (object.object_type) { + case ObjectType::BUFFER: + return absl::visit(WriteToBufferGenerator{object.data_type, element, + value, result, requires_sizes}, + object.size); + case ObjectType::TEXTURE: + return absl::visit(WriteToTextureGenerator{element, value, result}, + object.size); + case ObjectType::UNKNOWN: + return RewriteStatus::ERROR; + } +} + +std::string ToAccessModifier(AccessType access, bool use_readonly_modifier) { + switch (access) { + case AccessType::READ: + return use_readonly_modifier ? " readonly" : ""; + case AccessType::WRITE: + return " writeonly"; + case AccessType::READ_WRITE: + return " restrict"; + } + return " unknown_access"; +} + +std::string ToBufferType(DataType data_type) { + switch (data_type) { + case DataType::UINT8: + case DataType::UINT16: + case DataType::UINT32: + return "uvec4"; + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: + return "ivec4"; + case DataType::FLOAT16: + return "uvec2"; + case DataType::FLOAT32: + return "vec4"; + default: + return "unknown"; + } +} + +struct TextureImageTypeGetter { + std::string operator()(uint32_t) const { + // 1D textures are emulated as 2D textures + return (*this)(uint2()); + } + + std::string operator()(const uint2&) const { + switch (type) { + case DataType::UINT16: + case DataType::UINT32: + return "uimage2D"; + case DataType::INT16: + case DataType::INT32: + return "iimage2D"; + case DataType::FLOAT16: + case DataType::FLOAT32: + return "image2D"; + default: + return "unknown"; + } + } + + std::string operator()(const uint3&) const { + switch (type) { + case DataType::UINT16: + case DataType::UINT32: + return "uimage2DArray"; + case DataType::INT16: + case DataType::INT32: + return "iimage2DArray"; + case DataType::FLOAT16: + case DataType::FLOAT32: + return "image2DArray"; + default: + return "unknown"; + } + } + + DataType type; +}; + +std::string ToImageType(const Object& object) { + return absl::visit(TextureImageTypeGetter{object.data_type}, object.size); +} + +std::string ToImageLayoutQualifier(DataType type) { + switch (type) { + case DataType::UINT16: + return "rgba16ui"; + case DataType::UINT32: + return "rgba32ui"; + case DataType::INT16: + return "rgba16i"; + case DataType::INT32: + return "rgba32i"; + case DataType::FLOAT16: + return "rgba16f"; + case DataType::FLOAT32: + return "rgba32f"; + default: + return "unknown"; + } +} + +std::string ToImagePrecision(DataType type) { + switch (type) { + case DataType::UINT16: + case DataType::INT16: + case DataType::FLOAT16: + return "mediump"; + case DataType::UINT32: + case DataType::INT32: + case DataType::FLOAT32: + return "highp"; + default: + return "unknown"; + } +} + +struct SizeParametersAdder { + void operator()(uint32_t) const {} + + void operator()(const uint2& size) const { + parameters->AddParameter( + {absl::StrCat(object_name, "_w"), static_cast(size.x)}); + } + + // p1 and p2 are padding. For some reason buffer does not map correctly + // without it. 
+ void operator()(const uint3& size) const { + parameters->AddParameter( + {absl::StrCat(object_name, "_w"), static_cast(size.x)}); + parameters->AddParameter( + {absl::StrCat(object_name, "_h"), static_cast(size.y)}); + } + + absl::string_view object_name; + ParameterAccessor* parameters; +}; + +// Adds necessary parameters to parameter accessor that represent object size +// needed for indexed access. +// - 1D : empty +// - 2D : 'int object_name_w' +// - 3D : 'int object_name_w' + 'int object_name_h' +void AddSizeParameters(absl::string_view object_name, const Object& object, + ParameterAccessor* parameters) { + absl::visit(SizeParametersAdder{object_name, parameters}, object.size); +} + +void GenerateObjectDeclaration(absl::string_view name, const Object& object, + std::string* declaration, bool is_mali) { + switch (object.object_type) { + case ObjectType::BUFFER: + // readonly modifier used to fix shader compilation for Mali on Android 8, + // see b/111601761 + absl::StrAppend(declaration, "layout(binding = ", object.binding, ")", + ToAccessModifier(object.access, !is_mali), " buffer B", + object.binding, " { ", ToBufferType(object.data_type), + " data[]; } ", name, ";\n"); + break; + case ObjectType::TEXTURE: + absl::StrAppend(declaration, "layout(", + ToImageLayoutQualifier(object.data_type), + ", binding = ", object.binding, ")", + ToAccessModifier(object.access, true), " uniform ", + ToImagePrecision(object.data_type), " ", + ToImageType(object), " ", name, ";\n"); + break; + case ObjectType::UNKNOWN: + // do nothing. + break; + } +} + +} // namespace + +RewriteStatus ObjectAccessor::Rewrite(absl::string_view input, + std::string* output) { + // Splits 'a =b' into {'a','b'}. + std::pair n = + absl::StrSplit(input, absl::MaxSplits('=', 1), absl::SkipWhitespace()); + if (n.first.empty()) { + return RewriteStatus::NOT_RECOGNIZED; + } + if (n.second.empty()) { + return RewriteRead(absl::StripAsciiWhitespace(n.first), output); + } + return RewriteWrite(absl::StripAsciiWhitespace(n.first), + absl::StripAsciiWhitespace(n.second), output); +} + +RewriteStatus ObjectAccessor::RewriteRead(absl::string_view location, + std::string* output) { + auto element = object_accessor_internal::ParseElement(location); + if (element.object_name.empty()) { + return RewriteStatus::NOT_RECOGNIZED; + } + auto it = name_to_object_.find( + std::string(element.object_name.data(), element.object_name.size())); + if (it == name_to_object_.end()) { + return RewriteStatus::NOT_RECOGNIZED; + } + bool requires_sizes = false; + auto status = + GenerateReadAccessor(it->second, element, output, &requires_sizes); + if (requires_sizes) { + AddSizeParameters(it->first, it->second, parameter_accessor_); + } + return status; +} + +RewriteStatus ObjectAccessor::RewriteWrite(absl::string_view location, + absl::string_view value, + std::string* output) { + // name[index1, index2...] 
= value + auto element = object_accessor_internal::ParseElement(location); + if (element.object_name.empty()) { + return RewriteStatus::NOT_RECOGNIZED; + } + auto it = name_to_object_.find( + std::string(element.object_name.data(), element.object_name.size())); + if (it == name_to_object_.end()) { + return RewriteStatus::NOT_RECOGNIZED; + } + bool requires_sizes = false; + auto status = GenerateWriteAccessor(it->second, element, value, output, + &requires_sizes); + if (requires_sizes) { + AddSizeParameters(it->first, it->second, parameter_accessor_); + } + return status; +} + +bool ObjectAccessor::AddObject(const std::string& name, Object object) { + if (object.object_type == ObjectType::UNKNOWN) { + return false; + } + return name_to_object_.insert({name, std::move(object)}).second; +} + +std::string ObjectAccessor::GetObjectDeclarations() const { + std::string declarations; + for (auto& o : name_to_object_) { + GenerateObjectDeclaration(o.first, o.second, &declarations, is_mali_); + } + return declarations; +} + +std::string ObjectAccessor::GetFunctionsDeclarations() const { + std::string modifier = ""; + // Mali compiler does not want to compile a function without readonly + // modifier. See b/111601761 for the context. + if (is_mali_) { + modifier = "readonly "; + } + // If there is a single object SSBO with F16, then we need to output functions + // as well. + for (const auto& o : name_to_object_) { + if (o.second.data_type == DataType::FLOAT16 && + o.second.object_type == ObjectType::BUFFER) { + return absl::StrCat("vec4 Vec4FromHalf(in ", modifier, + "uvec2 v) { return vec4(unpackHalf2x16(v.x), " + "unpackHalf2x16(v.y)); }\n" + "uvec2 Vec4ToHalf(in ", + modifier, + "vec4 v) { return uvec2(packHalf2x16(v.xy), " + "packHalf2x16(v.zw)); }\n"); + } + } + return ""; +} + +std::vector ObjectAccessor::GetObjects() const { + std::vector objects; + for (auto& o : name_to_object_) { + objects.push_back(o.second); + } + return objects; +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h new file mode 100644 index 00000000000..5c9fe52098a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h @@ -0,0 +1,105 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_ + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// This rewrite handles access to objects both reads and writes. 
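+// Accesses are quoted with the preprocessor's inline delimiter, and the
+// rewrite replaces them with the generated GLSL, e.g. (illustrative)
+// $data[i]$ on a buffer becomes data.data[i] and on a texture becomes an
+// imageLoad(...) call.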
+// +// The following syntax is supported to access objects: +// +// READ: +// vec4 value = $data[i]$; +// where data is a buffer or 1D texture +// vec4 value = $data[i,j]$; +// where data is 2D texture +// vec4 value = $data[i,j,k]$; +// where data is 3D texture +// +// WRITE: +// $data[i] = value$; +// where data is a buffer or 1D texture +// $data[i,j] = value$; +// where data is 2D texture +// $data[i,j,k] = value$; +// where data is 3D texture +// +// Accessor supports all types (gvecN) as well as float16. +// +// TODO(akulik): support field in data[x,y,z].x +// +class ObjectAccessor : public InlineRewrite { + public: + ObjectAccessor(bool is_mali, ParameterAccessor* parameter_accessor) + : is_mali_(is_mali), parameter_accessor_(parameter_accessor) {} + + RewriteStatus Rewrite(absl::string_view input, std::string* output) final; + + // Return true if object was successfully added. + bool AddObject(const std::string& name, Object object); + + // Returns objects declarations that need to be added in a shader's code. + std::string GetObjectDeclarations() const; + + // Returns functions declarations that need to be added in a shader's code. + // These functions are used by code accessing objects. + std::string GetFunctionsDeclarations() const; + + // Returns a collection of registered objects + std::vector GetObjects() const; + + private: + RewriteStatus RewriteRead(absl::string_view location, std::string* output); + + RewriteStatus RewriteWrite(absl::string_view location, + absl::string_view value, std::string* output); + + std::unordered_map name_to_object_; + + const bool is_mali_; + ParameterAccessor* parameter_accessor_; +}; + +// Implementation details below. + +namespace object_accessor_internal { + +// Refers to an element in an object. +struct IndexedElement { + absl::string_view object_name; + std::vector indices; +}; + +// Splits name[index1, index2...] into 'name' and {'index1', 'index2'...}. +IndexedElement ParseElement(absl::string_view input); + +} // namespace object_accessor_internal +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OBJECT_ACCESSOR_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc new file mode 100644 index 00000000000..2ee6d9de461 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h" + +#include +#include + +#include +#include +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h" + +namespace tflite { +namespace gpu { +namespace gl { + +struct ParameterComparator { + template + bool operator()(const T& t) const { + const T* v = absl::get_if(&p.value); + return v && t == *v; + } + const UniformParameter& p; +}; + +// partially equal +bool operator==(const UniformParameter& l, const UniformParameter& r) { + return l.name == r.name && absl::visit(ParameterComparator{l}, r.value); +} + +namespace { + +TEST(Preprocessor, CornerCases) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + std::string result; + ASSERT_EQ(accessor.Rewrite("", &result), RewriteStatus::NOT_RECOGNIZED); + ASSERT_EQ(accessor.Rewrite("=", &result), RewriteStatus::NOT_RECOGNIZED); +} + +TEST(Preprocessor, ReadFromBuffer) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE( + accessor.AddObject("obj", MakeReadonlyBuffer(std::vector{1.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS); + EXPECT_TRUE(parameters.GetUniformParameters().empty()); + ASSERT_EQ(result, "obj.data[i]"); +} + +TEST(Preprocessor, ReadFromBufferLinear) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE(accessor.AddObject( + "obj", MakeReadonlyBuffer(uint3(1, 2, 3), std::vector{1.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS); + EXPECT_TRUE(parameters.GetUniformParameters().empty()); + ASSERT_EQ(result, "obj.data[i]"); +} + +TEST(Preprocessor, ReadFromBufferByIndex) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE(accessor.AddObject( + "obj", MakeReadonlyBuffer(uint3(1, 2, 3), std::vector{1.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite("obj[x,y + 5,z]", &result), + RewriteStatus::SUCCESS); + EXPECT_THAT(parameters.GetUniformParameters(), + testing::UnorderedElementsAre(UniformParameter{"obj_w", 1}, + UniformParameter{"obj_h", 2})); + ASSERT_EQ(result, "obj.data[x + $obj_w$ * (y + 5 + $obj_h$ * (z))]"); +} + +TEST(Preprocessor, ReadFromTexture) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE(accessor.AddObject( + "obj", MakeReadonlyTexture(uint3(1, 2, 3), {1.0, 2.0, 3.0, 4.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite("obj[i,j,k]", &result), RewriteStatus::SUCCESS); + // textures don't need extra variables to be stored for indexed access + EXPECT_TRUE(parameters.GetUniformParameters().empty()); + ASSERT_EQ(result, "imageLoad(obj, ivec3(i, j, k))"); +} + +TEST(Preprocessor, ReadFromTexture1D) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE( + accessor.AddObject("obj", MakeReadonlyTexture({1.0, 2.0, 3.0, 4.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite("obj[i]", &result), RewriteStatus::SUCCESS); + EXPECT_TRUE(parameters.GetUniformParameters().empty()); + ASSERT_EQ(result, "imageLoad(obj, ivec2(i, 0))"); +} + +TEST(Preprocessor, WriteToBuffer) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE( + 
accessor.AddObject("obj", MakeReadonlyBuffer(std::vector{1.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite(" obj[i] =value", &result), + RewriteStatus::SUCCESS); + EXPECT_TRUE(parameters.GetUniformParameters().empty()); + ASSERT_EQ(result, "obj.data[i] = value"); +} + +TEST(Preprocessor, WriteToBufferByIndex) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE(accessor.AddObject( + "obj", MakeReadonlyBuffer(uint3(1, 2, 3), {1.0, 2.0, 3.0, 4.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite(" obj[i,j,k] =value", &result), + RewriteStatus::SUCCESS); + EXPECT_THAT(parameters.GetUniformParameters(), + testing::UnorderedElementsAre(UniformParameter{"obj_w", 1}, + UniformParameter{"obj_h", 2})); + ASSERT_EQ(result, "obj.data[i + $obj_w$ * (j + $obj_h$ * (k))] = value"); +} + +TEST(Preprocessor, WriteToTexture) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE(accessor.AddObject( + "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite("obj[i,j,k]= value ", &result), + RewriteStatus::SUCCESS); + ASSERT_EQ(result, "imageStore(obj, ivec3(i, j, k), value)"); +} + +TEST(Preprocessor, WriteToTexture1D) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE( + accessor.AddObject("obj", MakeReadonlyTexture({1.0, 2.0, 3.0, 4.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite("obj[i]= value ", &result), + RewriteStatus::SUCCESS); + EXPECT_TRUE(parameters.GetUniformParameters().empty()); + ASSERT_EQ(result, "imageStore(obj, ivec2(i, 0), value)"); +} + +TEST(Preprocessor, FailedWriteToBuffer) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE( + accessor.AddObject("obj", MakeReadonlyBuffer(std::vector{1.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite(" obj[i,j] =value", &result), + RewriteStatus::ERROR); + ASSERT_EQ(result, "WRONG_NUMBER_OF_INDICES"); +} + +TEST(Preprocessor, FailedWriteToTexture) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE(accessor.AddObject( + "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0}))); + std::string result; + EXPECT_EQ(accessor.Rewrite("obj[i]= value ", &result), RewriteStatus::ERROR); + ASSERT_EQ(result, "WRONG_NUMBER_OF_INDICES"); +} + +TEST(Preprocessor, DeclareTexture) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(false, ¶meters); + ASSERT_TRUE(accessor.AddObject( + "obj", MakeReadonlyTexture(uint3(1, 1, 1), {1.0, 2.0, 3.0, 4.0}))); + ASSERT_EQ(accessor.GetObjectDeclarations(), + "layout(rgba32f, binding = 0) readonly uniform highp image2DArray " + "obj;\n"); +} + +TEST(Preprocessor, DeclareBuffer) { + ParameterAccessor parameters(false); + ObjectAccessor accessor(true, ¶meters); + ASSERT_TRUE( + accessor.AddObject("obj", MakeReadonlyBuffer(std::vector{1.0}))); + ASSERT_EQ(accessor.GetObjectDeclarations(), + "layout(binding = 0) buffer B0 { vec4 data[]; } obj;\n"); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.cc new file mode 100644 index 00000000000..f3f442bc218 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.cc @@ -0,0 +1,368 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace parameter_accessor_internal { + +// Parse the following regex manually +// name(\[index\])?(\.field)? +ParameterReference Parse(absl::string_view input) { + ParameterReference ref; + auto start_index = input.find('['); + if (start_index != std::string::npos) { + auto end_index = input.rfind(']'); + if (end_index == std::string::npos) { + return ref; + } + ref.index = input.substr(start_index + 1, end_index - start_index - 1); + ref.name = input.substr(0, start_index); + ref.field = input.substr(end_index + 1); + } else { + auto dot = input.find('.'); + if (dot != std::string::npos) { + ref.name = input.substr(0, dot); + ref.field = input.substr(dot); + } else { + ref.name = input; + } + } + return ref; +} + +} // namespace parameter_accessor_internal + +namespace { + +struct UniformTypeGetter { + std::string operator()(int) const { return "int"; } + std::string operator()(const int2&) const { return "ivec2"; } + std::string operator()(const std::vector&) const { return "ivec2"; } + std::string operator()(const int4&) const { return "ivec4"; } + std::string operator()(unsigned int) const { return "uint"; } + std::string operator()(const uint4&) const { return "uvec4"; } + std::string operator()(float) const { return "float"; } + std::string operator()(const float2&) const { return "vec2"; } + std::string operator()(const float4&) const { return "vec4"; } +}; + +// Returns GLSL uniform type of the given parameter. +std::string GetUniformType(const UniformParameter::ValueType& value) { + return absl::visit(UniformTypeGetter(), value); +} + +template +void FormatValue(std::string* result, T t) { + absl::StrAppend(result, t); +} + +template <> +void FormatValue(std::string* result, float t) { + absl::StrAppend(result, absl::StrFormat("%.9ff", t)); +} + +// Unfortunately absl::StrJoin with custom formatter requires formatter to use +// string, not std::string. Therefore, due to this compatibility issue data +// needs to be converted to string representation first and then joined. 
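+// For example (illustrative), inlining an int2(1, 2) parameter produces
+// "ivec2(1,2)", and float components are printed with "%.9f" plus an 'f'
+// suffix so the inlined literal keeps enough precision.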
+template +std::vector ToString(const std::array& data) { + std::vector result(N); + for (int i = 0; i < N; ++i) { + FormatValue(&result[i], data[i]); + } + return result; +} + +struct ConstGenerator { + template + void operator()(T t) const { + FormatValue(result, t); + } + + template + void operator()(const Vec2& v) const { + absl::StrAppend(result, UniformTypeGetter()(v), "(", + absl::StrJoin(ToString(v.data_), ","), ")"); + } + + template + void operator()(const Vec3& v) const { + absl::StrAppend(result, UniformTypeGetter()(v), "(", + absl::StrJoin(ToString(v.data_), ","), ")"); + } + + template + void operator()(const Vec4& v) const { + absl::StrAppend(result, UniformTypeGetter()(v), "(", + absl::StrJoin(ToString(v.data_), ","), ")"); + } + + template + void operator()(const std::vector& v) const { + std::string type = UniformTypeGetter()(v); + absl::StrAppend(result, type, "[", v.size(), "]("); + bool first = true; + for (const auto& i : v) { + if (first) { + first = false; + } else { + absl::StrAppend(result, ","); + } + (*this)(i); + } + absl::StrAppend(result, ")"); + } + + std::string* result; +}; + +// Appends string representation of a parameter value. +void GetValue(const UniformParameter::ValueType& value, std::string* result) { + absl::visit(ConstGenerator{result}, value); +} + +struct UniformDeclarationGenerator { + template + void operator()(const T&) const { + absl::StrAppend(result, "uniform ", GetUniformType(param.value), " ", + param.name, ";\n"); + } + + template + void operator()(const std::vector& v) const { + absl::StrAppend(result, "uniform ", GetUniformType(param.value), " ", + param.name, "[", v.size(), "];\n"); + } + + const UniformParameter& param; + std::string* result; +}; + +void GenerateUniformDeclaration(const UniformParameter& parameter, + std::string* result) { + absl::visit(UniformDeclarationGenerator{parameter, result}, parameter.value); +} + +struct VariableLengthGetter { + template + bool operator()(const T&) const { + return false; + } + template + bool operator()(const std::vector&) const { + return true; + } +}; + +// Returns true if value is a vector +bool IsVariableLength(const UniformParameter::ValueType& value) { + return absl::visit(VariableLengthGetter(), value); +} + +enum Field : uint8_t { UNKNOWN = 4, X = 0, Y = 1, Z = 2, W = 3 }; + +Field ToField(absl::string_view field_name) { + if (field_name.size() == 2 && field_name[0] == '.') { + switch (field_name[1]) { + case 'x': + return Field::X; + case 'y': + return Field::Y; + case 'z': + return Field::Z; + case 'w': + return Field::W; + } + } + return Field::UNKNOWN; +} + +struct FieldAccessor { + template + void operator()(const T&) const {} + + template + void operator()(const Vec2& v) const { + FormatValue(result, v[field]); + } + + template + void operator()(const Vec3& v) const { + FormatValue(result, v[field]); + } + + template + void operator()(const Vec4& v) const { + FormatValue(result, v[field]); + } + + Field field; + std::string* result; +}; + +// Appends formatted value of the given field. +void GetValue(const UniformParameter::ValueType& value, Field field, + std::string* result) { + absl::visit(FieldAccessor{field, result}, value); +} + +struct FieldChecker { + // For trivial as well as variable-length types indexed access is not allowed. 
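+  // For example (illustrative), $var.y$ on a plain float parameter fails the
+  // check and surfaces as INVALID_ACCESS_BY_FIELD from Rewrite() below.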
+ template + bool operator()(const T&) const { + return false; + } + + template + bool operator()(const Vec2& v) const { + return field < v.size(); + } + + template + bool operator()(const Vec3& v) const { + return field < v.size(); + } + + template + bool operator()(const Vec4& v) const { + return field < v.size(); + } + + template + bool operator()(const std::vector&) const { + // technically accessing [0] element of an empty vector is UB, but we need + // only type information for this check. Therefore, construct default T and + // use it instead. + T t; + return (*this)(t); + } + + Field field; +}; + +// Returns true if field has field access and field is not out of bounds. +bool HasField(const UniformParameter::ValueType& value, Field field) { + return absl::visit(FieldChecker{field}, value); +} + +void AssembleAccessor(absl::string_view name, absl::string_view index, + absl::string_view field, std::string* result) { + if (index.empty()) { + absl::StrAppend(result, name, field); + } else { + absl::StrAppend(result, name, "[", index, "]", field); + } +} + +} // namespace + +RewriteStatus ParameterAccessor::Rewrite(absl::string_view input, + std::string* output) { + auto ref = parameter_accessor_internal::Parse(input); + if (ref.name.empty()) { + absl::StrAppend(output, "INVALID_SYNTAX"); + return RewriteStatus::ERROR; + } + + auto it = name_to_param_.find(std::string(ref.name.data(), ref.name.size())); + if (it == name_to_param_.end()) { + // Uniform with this name is not registered. + return RewriteStatus::NOT_RECOGNIZED; + } + const auto& value = it->second.value; + + if (!ref.index.empty() && !IsVariableLength(value)) { + // Trying to access parameter by index, but it is not variable-length. + absl::StrAppend(output, "INVALID_ACCESS_BY_INDEX"); + return RewriteStatus::ERROR; + } + + Field f = ToField(ref.field); + if (!ref.field.empty() && !HasField(value, f)) { + // Trying to access a parameter by field, but it does not have it. + absl::StrAppend(output, "INVALID_ACCESS_BY_FIELD"); + return RewriteStatus::ERROR; + } + + // Error checks are complete now. + + // All variable-length parameters are encoded as-is without inlining. + if (!inline_values_ || IsVariableLength(value)) { + AssembleAccessor(it->second.name, ref.index, ref.field, output); + } else { + // Parameter + field is replaced with field value. + if (f != Field::UNKNOWN) { + GetValue(value, f, output); + } else { + // Parameter is accessed directly. + GetValue(value, output); + } + } + return RewriteStatus::SUCCESS; +} + +bool ParameterAccessor::AddParameter(UniformParameter param) { + std::string name = param.name; + return name_to_param_.insert({name, std::move(param)}).second; +} + +std::string ParameterAccessor::GetConstDeclarations() const { + // Variable length parameters are declared as const and accessed via variable + // with index. 
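+  // For example (illustrative), a parameter "var" holding a single int2(1, 2)
+  // is emitted as: const ivec2 var[] = ivec2[1](ivec2(1,2));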
+ std::string declarations; + for (auto& param : name_to_param_) { + const auto& value = param.second.value; + if (IsVariableLength(value)) { + absl::StrAppend(&declarations, "const ", GetUniformType(value), " ", + param.second.name, "[] = "); + GetValue(value, &declarations); + absl::StrAppend(&declarations, ";\n"); + } + } + return declarations; +} + +std::string ParameterAccessor::GetUniformDeclarations() const { + std::string declarations; + if (!inline_values_) { + for (auto& param : name_to_param_) { + GenerateUniformDeclaration(param.second, &declarations); + } + } + return declarations; +} + +std::vector ParameterAccessor::GetUniformParameters() const { + std::vector params; + if (!inline_values_) { + for (auto& param : name_to_param_) { + params.push_back(param.second); + } + } + return params; +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h new file mode 100644 index 00000000000..e6efed0124f --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h @@ -0,0 +1,92 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_ + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// This rewrite handles access to parameters. It may rewrite a parameter with +// actual values if inline_values is set to true. +// +// The following syntax is supported to access parameters: +// - simple parameter: name +// - parameter with field: name.(x|y|z|w) +// - parameter with index: name[i] +// - parameter with index and field: name[i].(x|y|z|w) +// +// If 'inline_values' is set to true, non variable-length parameters will be +// inlined. For example, 'base.x' will be replaced with value of 'x' field from +// 'base'. Variable-length are declared as const and accessed via index. +// These declarations are returned by GetConstDeclarations. +// +// If 'inline_values' is set to false, all parameters will be declared as +// uniforms. Uniform declarations are returned by GetUniformDeclarations. +class ParameterAccessor : public InlineRewrite { + public: + explicit ParameterAccessor(bool inline_values) + : inline_values_(inline_values) {} + + RewriteStatus Rewrite(absl::string_view input, std::string* output) final; + + // Return true if parameter was successfully added. + bool AddParameter(UniformParameter param); + + // Returns const parameters that need to be inlined in the a shader's code. 
+ std::string GetConstDeclarations() const; + + // Returns uniforms declarations that need to be inlined in a shader's code. + std::string GetUniformDeclarations() const; + + // Returns a collection of uniform parameters. + std::vector GetUniformParameters() const; + + private: + const bool inline_values_; + // Unique parameter index used for obfuscation. + uint32_t unique_param_index_ = 0; + + std::unordered_map name_to_param_; +}; + +// Implementation details below. + +namespace parameter_accessor_internal { + +struct ParameterReference { + absl::string_view name; + absl::string_view index; + absl::string_view field; +}; + +// Parse the following regex manually +// name(\[index\])?(\.field)? +ParameterReference Parse(absl::string_view input); + +} // namespace parameter_accessor_internal +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PARAMETER_ACCESSOR_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor_test.cc new file mode 100644 index 00000000000..96182751b9b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h" + +#include +#include + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +TEST(Preprocessor, CornerCases) { + ParameterAccessor accessor(true); + std::string result; + ASSERT_EQ(accessor.Rewrite("unknown", &result), + RewriteStatus::NOT_RECOGNIZED); +} + +TEST(Preprocessor, Value) { + ParameterAccessor accessor(true); + ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", int32_t(1)})); + std::string result; + EXPECT_EQ(accessor.Rewrite("var", &result), RewriteStatus::SUCCESS); + ASSERT_EQ(result, "1"); +} + +TEST(Preprocessor, ValueVec) { + ParameterAccessor accessor(true); + ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", int2(1, 2)})); + std::string result; + EXPECT_EQ(accessor.Rewrite("var", &result), RewriteStatus::SUCCESS); + ASSERT_EQ(result, "ivec2(1,2)"); +} + +TEST(Preprocessor, Field) { + ParameterAccessor accessor(true); + ASSERT_TRUE( + accessor.AddParameter(UniformParameter{"var", float2(1.0, 2.1234567)})); + std::string result; + EXPECT_EQ(accessor.Rewrite("var.y", &result), RewriteStatus::SUCCESS); + ASSERT_EQ(result, "2.123456717f"); +} + +TEST(Preprocessor, FieldFail) { + ParameterAccessor accessor(true); + ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", 1.0f})); + ASSERT_TRUE(accessor.AddParameter(UniformParameter{"vec", float2(1.0, 1.0)})); + std::string result; + EXPECT_EQ(accessor.Rewrite("var.y", &result), RewriteStatus::ERROR); + ASSERT_EQ(result, "INVALID_ACCESS_BY_FIELD"); + + result.clear(); + EXPECT_EQ(accessor.Rewrite("vec.z", &result), RewriteStatus::ERROR); + ASSERT_EQ(result, "INVALID_ACCESS_BY_FIELD"); +} + +TEST(Preprocessor, Variable) { + ParameterAccessor accessor(true); + std::vector v; + v.push_back(int2(1, 2)); + ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", v})); + std::string result; + EXPECT_EQ(accessor.Rewrite("var[i].y", &result), RewriteStatus::SUCCESS); + ASSERT_EQ(result, "var[i].y"); + ASSERT_EQ(accessor.GetConstDeclarations(), + "const ivec2 var[] = ivec2[1](ivec2(1,2));\n"); +} + +TEST(Preprocessor, InlineVariableFail) { + ParameterAccessor accessor(true); + ASSERT_TRUE(accessor.AddParameter(UniformParameter{"var", 1})); + std::string result; + EXPECT_EQ(accessor.Rewrite("var[i]", &result), RewriteStatus::ERROR); + ASSERT_EQ(result, "INVALID_ACCESS_BY_INDEX"); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc new file mode 100644 index 00000000000..01ea764b0b0 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc @@ -0,0 +1,95 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +// Given input string and a delimiter returns back a substring including +// delimiters. If there was only starting delimiter found, returns single char. +absl::string_view FindInlineBlock(absl::string_view s, char delimiter) { + size_t start = s.find(delimiter); + if (start != absl::string_view::npos) { + size_t end = s.find(delimiter, start + 1); + if (end != std::string::npos) { + return s.substr(start, end - start + 1); + } + // Special case to indicate that we didn't find the end. + return s.substr(start, 1); + } + return s.substr(s.size(), 0); +} + +// For the given 's' and its substring 'subs' returns new substring of 's' that +// begins past 'subs'. +absl::string_view PastSubstr(absl::string_view s, absl::string_view subs) { + return s.substr(subs.data() + subs.size() - s.data()); +} + +} // namespace + +Status TextPreprocessor::Rewrite(const std::string& input, + std::string* output) { + absl::string_view s = input; + std::string result; + while (true) { + absl::string_view inline_block = FindInlineBlock(s, inline_delimiter_); + result.append(s.data(), inline_block.data() - s.data()); + if (inline_block.empty()) { + break; + } + if (inline_block.size() == 1) { + return NotFoundError("Unable to find end of inline block"); + } + s = PastSubstr(s, inline_block); + bool processed = false; + for (auto& rewrite : inline_rewrites_) { + if (processed) { + break; + } + switch (rewrite->Rewrite(inline_block.substr(1, inline_block.size() - 2), + &result)) { + case RewriteStatus::NOT_RECOGNIZED: + // try another rewrite. + break; + case RewriteStatus::SUCCESS: + processed = true; + break; + case RewriteStatus::ERROR: + return InternalError(absl::StrCat("Error while rewriting '", + inline_block, "': ", result)); + } + } + if (!processed) { + if (!keep_unknown_rewrites_) { + return NotFoundError(absl::StrCat("Didn't find inline rewrite for '", + inline_block, "'")); + } + absl::StrAppend(&result, inline_block); + } + } + *output = std::move(result); + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h new file mode 100644 index 00000000000..f01698e784f --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h @@ -0,0 +1,74 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PREPROCESSOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PREPROCESSOR_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace gl { + +enum class RewriteStatus { + SUCCESS = 0, + NOT_RECOGNIZED = 1, + ERROR = 2, +}; + +// Inline rewrite matches a string and rewrites it. +class InlineRewrite { + public: + virtual ~InlineRewrite() = default; + + virtual RewriteStatus Rewrite(absl::string_view input, + std::string* output) = 0; +}; + +// Text preprocessor runs a collection of registered rewrites. +// It uses a single character prefix as inline delimiter that needs to quote +// text to be rewritten. +class TextPreprocessor { + public: + // @param keep_unknown_rewrites if true, will keep unhandled rewrites as is + // instead of reporting an error. + TextPreprocessor(char inline_delimiter, bool keep_unknown_rewrites) + : inline_delimiter_(inline_delimiter), + keep_unknown_rewrites_(keep_unknown_rewrites) {} + + void AddRewrite(InlineRewrite* rewrite) { + inline_rewrites_.push_back(rewrite); + } + + // input and output may point to the same object. + Status Rewrite(const std::string& input, std::string* output); + + private: + const char inline_delimiter_; + const bool keep_unknown_rewrites_; + + std::vector inline_rewrites_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_PREPROCESSOR_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc new file mode 100644 index 00000000000..95fcf624460 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc @@ -0,0 +1,129 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" + +#include +#include + +#include +#include + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class AccuInlineRewrite : public InlineRewrite { + public: + explicit AccuInlineRewrite(std::vector* blocks) + : blocks_(blocks) {} + + RewriteStatus Rewrite(absl::string_view input, std::string* output) final { + blocks_->push_back(std::string(input.data(), input.size())); + output->append("r:"); + output->append(input.data(), input.size()); + return RewriteStatus::SUCCESS; + } + + std::vector* blocks_; +}; + +std::vector ParseInlines(const std::string& text) { + std::vector blocks; + TextPreprocessor preprocessor('$', false); + AccuInlineRewrite rewrite(&blocks); + preprocessor.AddRewrite(&rewrite); + std::string discard; + preprocessor.Rewrite(text, &discard).IgnoreError(); + return blocks; +} + +TEST(Preprocessor, CornerCases) { + EXPECT_THAT(ParseInlines(""), testing::ElementsAre()); + EXPECT_THAT(ParseInlines("text text"), testing::ElementsAre()); + EXPECT_THAT(ParseInlines("$$"), testing::ElementsAre("")); +} + +TEST(Preprocessor, One) { + EXPECT_THAT(ParseInlines("$text$"), testing::ElementsAre("text")); + EXPECT_THAT(ParseInlines(" $text$ "), testing::ElementsAre("text")); +} + +TEST(Preprocessor, More) { + EXPECT_THAT(ParseInlines("Test $inline1$\n$inline2$ test $inline3$ "), + testing::ElementsAre("inline1", "inline2", "inline3")); +} + +std::string RewriteInlines(const std::string& text) { + std::vector blocks; + TextPreprocessor preprocessor('$', false); + AccuInlineRewrite rewrite(&blocks); + preprocessor.AddRewrite(&rewrite); + std::string out; + preprocessor.Rewrite(text, &out).IgnoreError(); + return out; +} + +TEST(Preprocessor, RewriteCornerCases) { + EXPECT_EQ(RewriteInlines(""), ""); + EXPECT_EQ(RewriteInlines("text text"), "text text"); + EXPECT_EQ(RewriteInlines("$$"), "r:"); +} + +TEST(Preprocessor, RewriteOne) { + EXPECT_EQ(RewriteInlines("$text$"), "r:text"); + EXPECT_EQ(RewriteInlines(" $text$ "), " r:text "); +} + +TEST(Preprocessor, RewriteMore) { + EXPECT_EQ(RewriteInlines("Test $inline1$\n$inline2$ test $inline3$ "), + "Test r:inline1\nr:inline2 test r:inline3 "); +} + +class SingleRewrite : public InlineRewrite { + public: + RewriteStatus Rewrite(absl::string_view input, std::string* output) final { + if (input == "foo") { + output->append("bla"); + return RewriteStatus::SUCCESS; + } + return RewriteStatus::NOT_RECOGNIZED; + } + + std::vector* blocks_; +}; + +TEST(Preprocessor, KeepUnknownRewrites) { + TextPreprocessor preprocessor('$', true); + SingleRewrite rewrite; + preprocessor.AddRewrite(&rewrite); + std::string out; + ASSERT_TRUE(preprocessor.Rewrite("Good morning, $name$! $foo$", &out).ok()); + EXPECT_EQ("Good morning, $name$! bla", out); +} + +TEST(Preprocessor, KeepUnknownRewrites_Fail) { + TextPreprocessor preprocessor('$', false); + SingleRewrite rewrite; + preprocessor.AddRewrite(&rewrite); + std::string out; + EXPECT_FALSE(preprocessor.Rewrite("Good morning, $name$! $foo$", &out).ok()); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc new file mode 100644 index 00000000000..1c81ebff6b2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc @@ -0,0 +1,203 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +// Rewrites names of all parameters according to returned values from the +// given NameFunctor. +class ParameterRewriter : public InlineRewrite { + public: + ParameterRewriter(const std::string& inline_delimiter, + const NameFunctor& name_func) + : inline_delimiter_(inline_delimiter), name_func_(name_func) {} + + RewriteStatus Rewrite(absl::string_view input, std::string* output) final { + auto ref = parameter_accessor_internal::Parse(input); + if (ref.name.empty()) { + absl::StrAppend(output, "INVALID_SYNTAX"); + return RewriteStatus::ERROR; + } + + auto it = + name_to_param_.find(std::string(ref.name.data(), ref.name.size())); + if (it == name_to_param_.end()) { + return RewriteStatus::NOT_RECOGNIZED; + } + + // reconstruct access using the new name. + absl::StrAppend(output, inline_delimiter_, it->second.name); + if (!ref.index.empty()) { + absl::StrAppend(output, "[", ref.index, "]"); + } + absl::StrAppend(output, ref.field, inline_delimiter_); + return RewriteStatus::SUCCESS; + } + + // Return true if parameter was successfully added. + bool AddParameter(UniformParameter param) { + std::string old_name = param.name; + param.name = name_func_(old_name); + return name_to_param_.insert({old_name, std::move(param)}).second; + } + + // Returns a collection of uniform parameters with updated names. + std::vector GetUniformParameters() const { + std::vector params; + params.reserve(name_to_param_.size()); + for (auto& param : name_to_param_) { + params.push_back(param.second); + } + return params; + } + + private: + const std::string inline_delimiter_; + const NameFunctor name_func_; + + std::unordered_map name_to_param_; +}; + +// Rewrites names of all objects according to returned values from the +// given NameFunctor. +class ObjectRewriter : public InlineRewrite { + public: + ObjectRewriter(const std::string& inline_delimiter, + const NameFunctor& name_func) + : inline_delimiter_(inline_delimiter), name_func_(name_func) {} + + RewriteStatus Rewrite(absl::string_view input, std::string* output) final { + // Splits 'a = b' into {'a','b'}. 
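+    // A write rewrite has the form "obj[i, j] = value"; a read has no '=' and
+    // leaves the second part empty, which is how the two cases are told apart
+    // below.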
+ std::pair n = + absl::StrSplit(input, absl::MaxSplits('=', 1), absl::SkipWhitespace()); + if (n.first.empty()) { + return RewriteStatus::NOT_RECOGNIZED; + } + + if (n.second.empty()) { + return RewriteRead(absl::StripAsciiWhitespace(n.first), output); + } + return RewriteWrite(absl::StripAsciiWhitespace(n.first), + absl::StripAsciiWhitespace(n.second), output); + } + + // Return true if an object was successfully added. + bool AddObject(const std::string& name, Object object) { + std::string new_name = name_func_(name); + return name_to_object_.insert({name, {new_name, std::move(object)}}).second; + } + + // Returns a collection of registered objects with updated names. + std::vector> GetObjects() const { + std::vector> objects; + objects.reserve(name_to_object_.size()); + for (auto& o : name_to_object_) { + objects.push_back(o.second); + } + return objects; + } + + private: + RewriteStatus RewriteRead(absl::string_view location, std::string* output) { + auto element = object_accessor_internal::ParseElement(location); + if (element.object_name.empty()) { + absl::StrAppend(output, "UNABLE_TO_PARSE_INDEXED_ELEMENT"); + return RewriteStatus::ERROR; + } + auto it = name_to_object_.find( + std::string(element.object_name.data(), element.object_name.size())); + if (it == name_to_object_.end()) { + return RewriteStatus::NOT_RECOGNIZED; + } + absl::StrAppend(output, inline_delimiter_, it->second.first, "[", + absl::StrJoin(element.indices, ","), "]", + inline_delimiter_); + return RewriteStatus::SUCCESS; + } + + RewriteStatus RewriteWrite(absl::string_view location, + absl::string_view value, std::string* output) { + // name[index1, index2...] = value + auto element = object_accessor_internal::ParseElement(location); + if (element.object_name.empty()) { + absl::StrAppend(output, "UNABLE_TO_PARSE_INDEXED_ELEMENT"); + return RewriteStatus::ERROR; + } + auto it = name_to_object_.find( + std::string(element.object_name.data(), element.object_name.size())); + if (it == name_to_object_.end()) { + return RewriteStatus::NOT_RECOGNIZED; + } + absl::StrAppend(output, inline_delimiter_, it->second.first, "[", + absl::StrJoin(element.indices, ","), "] = ", value, + inline_delimiter_); + return RewriteStatus::SUCCESS; + } + + const std::string inline_delimiter_; + const NameFunctor name_func_; + + std::unordered_map> + name_to_object_; +}; + +} // namespace + +Status Rename(const NameFunctor& name_func, GeneratedCode* code) { + ParameterRewriter param_rewriter("$", name_func); + ObjectRewriter object_rewriter("$", name_func); + for (auto&& param : code->parameters) { + if (!param_rewriter.AddParameter(std::move(param))) { + return InternalError("Parameter name already exists"); + } + } + for (auto&& object : code->objects) { + if (!object_rewriter.AddObject(object.first, std::move(object.second))) { + return InternalError("Object name already exists"); + } + } + TextPreprocessor preprocessor('$', /* keep_unknown_rewrites = */ true); + preprocessor.AddRewrite(¶m_rewriter); + preprocessor.AddRewrite(&object_rewriter); + std::string source_code; + RETURN_IF_ERROR(preprocessor.Rewrite(code->source_code, &source_code)); + code->source_code = source_code; + code->parameters = param_rewriter.GetUniformParameters(); + code->objects = object_rewriter.GetObjects(); + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.h b/tensorflow/lite/delegates/gpu/gl/compiler/rename.h new file mode 100644 index 00000000000..06921dbe3da 
--- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.h @@ -0,0 +1,41 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_RENAME_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_RENAME_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Functor takes old name and returns new name. +using NameFunctor = std::function; + +// Rewrites source code, objects and parameters with the new names supplied +// by the given functor. +Status Rename(const NameFunctor& name_func, GeneratedCode* code); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_RENAME_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h b/tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h new file mode 100644 index 00000000000..8d6d52b002a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODE_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { + +struct ShaderCode { + ShaderCode() = default; + ShaderCode(const std::vector& in_parameters, + const std::vector& in_objects, const uint3& in_workload, + const uint3& in_recommended_workgroup, + const std::string& in_source_code, + const std::vector& in_node_indices) + : parameters(in_parameters), + objects(in_objects), + workload(in_workload), + recommended_workgroup(in_recommended_workgroup), + source_code(in_source_code), + node_indices(in_node_indices) {} + + // A list of uniform parameters to be set. + std::vector parameters; + + // A list of objects to bind to opengl program. 
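A hypothetical usage sketch of Rename() as declared in rename.h above: the lambda functor, the node-id prefix, and the helper name are illustrative only, and GeneratedCode (from node_shader.h) is assumed to be already populated by a NodeShader.

```cpp
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h"

namespace tflite {
namespace gpu {
namespace gl {

// Prefixes every parameter and object name with a node id so that code
// generated for different nodes can be merged into one shader without
// name collisions.
Status MakeNamesUnique(int node_id, GeneratedCode* code) {
  return Rename(
      [node_id](absl::string_view name) {
        return absl::StrCat("node_", node_id, "_", name);
      },
      code);
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
```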
+ std::vector objects; + + uint3 workload; + + // operation may specify recommended workgroup size + uint3 recommended_workgroup; + + // Generated source code does not set local size, therefore it needs to be set + // elsewhere. + std::string source_code; + + // nodes of the graph that are covered by the shader. + std::vector node_indices; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc new file mode 100644 index 00000000000..95e6a1f976a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.cc @@ -0,0 +1,148 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h" + +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" + +namespace tflite { +namespace gpu { +namespace gl { + +ShaderCodegen::ShaderCodegen(const CompilationOptions& options, + const GpuInfo& gpu_info) + : options_(options), gpu_type_(gpu_info.type) {} + +Status ShaderCodegen::Build(CompiledNodeAttributes attr, + ShaderCode* shader_code) const { + ParameterAccessor parameters(options_.inline_parameters); + ObjectAccessor objects(gpu_type_ == GpuType::MALI, ¶meters); + + auto add_object = [&](const std::string& name, Object&& object) { + if (!objects.AddObject(name, std::forward(object))) { + return InternalError("There is an object with the same name"); + } + return OkStatus(); + }; + + auto add_parameter = [&](UniformParameter&& param) { + if (!parameters.AddParameter(std::forward(param))) { + return InternalError("There is a parameter with the same name"); + } + return OkStatus(); + }; + + for (auto&& param : attr.code.parameters) { + RETURN_IF_ERROR(add_parameter(std::move(param))); + } + + for (auto&& object : attr.code.objects) { + RETURN_IF_ERROR(add_object(object.first, std::move(object.second))); + } + + int index = 0; + for (auto&& input : attr.inputs) { + RETURN_IF_ERROR( + add_object(absl::StrCat("input_data_", index++), std::move(input))); + } + index = 0; + for (auto&& output : attr.outputs) { + RETURN_IF_ERROR( + add_object(absl::StrCat("output_data_", index++), std::move(output))); + } + + // TODO(akulik): workload params need to go away and be replaced with + // output_data_0_w + RETURN_IF_ERROR(add_parameter( + {"workload_x", static_cast(attr.code.workload.x)})); + RETURN_IF_ERROR(add_parameter( + {"workload_y", static_cast(attr.code.workload.y)})); + RETURN_IF_ERROR(add_parameter( + {"workload_z", static_cast(attr.code.workload.z)})); + + std::string source_code = R"( + ivec3 gid = ivec3(gl_GlobalInvocationID.xyz); + if (gid.x >= $workload_x$ || gid.y >= 
$workload_y$ || gid.z >= $workload_z$) { + return; + } +)"; + + switch (attr.code.input) { + case IOStructure::ONLY_DEFINITIONS: + for (int i = 0; i < attr.inputs.size(); ++i) { + absl::StrAppend(&source_code, " highp vec4 value_", i, + " = vec4(0);\n"); + } + break; + case IOStructure::AUTO: { + for (int i = 0; i < attr.inputs.size(); ++i) { + absl::StrAppend(&source_code, " highp vec4 value_", i, + " = $input_data_", i, "[gid.x, gid.y, gid.z]$;\n"); + } + break; + } + } + + source_code.append(attr.code.source_code); + + if (attr.code.output == IOStructure::AUTO) { + for (int i = 0; i < attr.outputs.size(); ++i) { + absl::StrAppend(&source_code, " $output_data_", i, + "[gid.x, gid.y, gid.z] = value_", i, "$;\n"); + } + } + + // At this point main function is already generated. Now we need to process + // object and parameter accessors. + + // process objects first. Object accessor may introduce new uniform + // parameters that need to be rewritten in the subsequent pass. + { + TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/true); + preprocessor.AddRewrite(&objects); + RETURN_IF_ERROR(preprocessor.Rewrite(source_code, &source_code)); + } + + { + TextPreprocessor preprocessor('$', /*keep_unknown_rewrites=*/false); + preprocessor.AddRewrite(¶meters); + RETURN_IF_ERROR(preprocessor.Rewrite(source_code, &source_code)); + } + + if (options_.inline_parameters) { + source_code = absl::StrCat(parameters.GetConstDeclarations(), source_code); + } + + std::string declarations = absl::StrCat( + objects.GetFunctionsDeclarations(), "\n", objects.GetObjectDeclarations(), + "\n", parameters.GetUniformDeclarations()); + *shader_code = ShaderCode( + parameters.GetUniformParameters(), objects.GetObjects(), + attr.code.workload, attr.code.workgroup, + absl::StrCat("layout(std430) buffer;\nprecision ", + (options_.allow_precision_loss ? "mediump" : "highp"), + " float;\n", declarations, "\nvoid main() {\n", source_code, + "\n}"), + attr.node_indices); + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h new file mode 100644 index 00000000000..06e4cf8f002 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
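For context, a hypothetical sketch of how the codegen above might be driven; the wrapper name and the option choices are illustrative, and `attrs` is assumed to come from the compiler's earlier fusion passes.

```cpp
#include <utility>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_codegen.h"

namespace tflite {
namespace gpu {
namespace gl {

Status GenerateShader(const GpuInfo& gpu_info, CompiledNodeAttributes attrs,
                      ShaderCode* shader_code) {
  CompilationOptions options;
  options.allow_precision_loss = true;  // assembled shader uses mediump floats
  options.inline_parameters = false;    // keep parameters as GL uniforms
  ShaderCodegen codegen(options, gpu_info);
  // Produces a complete compute shader: header, declarations and main().
  return codegen.Build(std::move(attrs), shader_code);
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
```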
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODEGEN_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODEGEN_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/parameter_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// This class is responsible for assembling a shader by putting together +// objects, parameters declarations and main function. +class ShaderCodegen { + public: + ShaderCodegen(const CompilationOptions& options, const GpuInfo& gpu_info); + + // Builds final program representation. + Status Build(CompiledNodeAttributes attr, ShaderCode* shader_code) const; + + private: + const CompilationOptions options_; + const GpuType gpu_type_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_SHADER_CODEGEN_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/compiler_options.h b/tensorflow/lite/delegates/gpu/gl/compiler_options.h new file mode 100644 index 00000000000..6f32be234fe --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/compiler_options.h @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OPTIONS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OPTIONS_H_ + +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Default constructor for options turns on all optimizations. +struct CompilationOptions { + // Allows to quantify tensors, downcast values, process in float16 etc. + bool allow_precision_loss = false; + + // When set few operations are fused into a single shader. Therefore, there + // will be less shaders, but each shader will become larger. + bool fuse_operations = true; + + // Parameters will be inlined into a shader. This in turn will generated more + // unique shaders where each will need to be compiled. + bool inline_parameters = false; + + // If true, shaders, that have auto-input and auto-output, will use a single + // object for reading and writing. 
+ bool inline_objects = true; // TODO(akulik): unsupported + + // Can be only Textures or Buffers + ObjectType preferred_obj_type = ObjectType::UNKNOWN; + // User has an option to choose between textures and buffers. Textures work + // better on Adreno and buffers are better for Mali. + + // Chooses object type to represent intermediate tensors. Buffers have more + // efficient memory usage because they represent opaque memory blob, but + // textures work better on Adreno. + // TODO(akulik): may be better name? + ObjectType ref_obj_type = ObjectType::UNKNOWN; + + // If true, a user may change BATCH dimension at runtime. Otherwise, static + // batch size will be fixed during compile time. + // Dynamic mode uses less memory, while static mode may yield better + // performance for small models. + bool dynamic_batch = false; + + // Fuses consequent nodes which have auto output and auto input. + bool auto_input_fusion = true; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_COMPILER_OPTIONS_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/converters/BUILD b/tensorflow/lite/delegates/gpu/gl/converters/BUILD new file mode 100644 index 00000000000..8f6b5618dd6 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/converters/BUILD @@ -0,0 +1,108 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") + +cc_library( + name = "util", + hdrs = ["util.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "bhwc_to_phwc4", + srcs = ["bhwc_to_phwc4.cc"], + hdrs = ["bhwc_to_phwc4.h"], + deps = [ + ":util", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/gl:command_queue", + "//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "//tensorflow/lite/delegates/gpu/gl:gl_program", + "//tensorflow/lite/delegates/gpu/gl:gl_shader", + "//tensorflow/lite/delegates/gpu/gl:uniform_parameter", + ], +) + +cc_test( + name = "bhwc_to_phwc4_test", + size = "small", + srcs = ["bhwc_to_phwc4_test.cc"], + linkopts = [ + "-lGLESv3", + "-lEGL", + ], + tags = [ + "local", + "nobuilder", + "notap", + "tflite_not_portable_ios", + ], + deps = [ + ":bhwc_to_phwc4", + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/gl:egl_environment", + "//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "//tensorflow/lite/delegates/gpu/gl:portable", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "phwc4_to_bhwc", + srcs = ["phwc4_to_bhwc.cc"], + hdrs = ["phwc4_to_bhwc.h"], + deps = [ + ":util", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/gl:command_queue", + "//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "//tensorflow/lite/delegates/gpu/gl:gl_program", + "//tensorflow/lite/delegates/gpu/gl:gl_shader", + 
"//tensorflow/lite/delegates/gpu/gl:uniform_parameter", + ], +) + +cc_test( + name = "phwc4_to_bhwc_test", + size = "small", + srcs = ["phwc4_to_bhwc_test.cc"], + linkopts = [ + "-lGLESv3", + "-lEGL", + ], + tags = [ + "local", + "nobuilder", + "notap", + "tflite_not_portable_ios", + ], + deps = [ + ":phwc4_to_bhwc", + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/gl:egl_environment", + "//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "//tensorflow/lite/delegates/gpu/gl:portable", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc new file mode 100644 index 00000000000..d48d9544025 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.cc @@ -0,0 +1,106 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h" + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/gl/converters/util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { + +Status ConverterBhwcToPhwc4::Create(ConverterBhwcToPhwc4* converter) { + uint3 workgroup_size = uint3(4, 4, 4); + std::string shader_source = GetShaderHeader(workgroup_size) + R"( + layout(std430) buffer; + + precision highp float; + + layout(binding = 0) readonly buffer B0 { + float elements[]; + } input_data; + + layout(binding = 1) writeonly buffer B1 { + vec4 elements[]; + } output_data; + + uniform ivec4 sizes_; + + void main() { + ivec3 gid = ivec3(gl_GlobalInvocationID.xyz); + if (gid.x >= sizes_.x || gid.y >= sizes_.y || gid.z >= sizes_.z) { + return; + } + vec4 v = vec4(0); + int dst_channel = gid.z * 4; + int index = (gid.y * sizes_.x + gid.x) * sizes_.w + dst_channel; + for (int i = 0; i < 4; ++i, ++index, ++dst_channel) { + if (dst_channel >= sizes_.w) break; + v[i] = input_data.elements[index]; + } + output_data.elements[(gid.z * sizes_.y + gid.y) * sizes_.x + gid.x] = v; + })"; + + GlShader shader; + RETURN_IF_ERROR( + GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, &shader)); + GlProgram program; + RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); + *converter = ConverterBhwcToPhwc4(std::move(program), workgroup_size); + return OkStatus(); +} + +Status ConverterBhwcToPhwc4::Convert(const BHWC& shape, 
const GlBuffer& source, + CommandQueue* command_queue, + GlBuffer* destination) { + if (source.bytes_size() < BytesForBHWC(shape)) { + return InvalidArgumentError( + "BhwcToPhwc4: Input data size does not match expected size."); + } + if (destination->bytes_size() < BytesForPHWC4(shape)) { + return InvalidArgumentError( + "BhwcToPhwc4: output data size does not match expected size."); + } + if (shape.b != 1) { + return UnimplementedError("BhwcToPhwc4: Batch size is not equal to 1."); + } + uint3 workload = uint3(shape.w, shape.h, shape.c); + uint3 num_workgroups = IntegralDivideRoundUp(workload, workgroup_size_); + + RETURN_IF_ERROR(program_.SetParameter(UniformParameter{ + "sizes_", + int4(static_cast(workload.x), static_cast(workload.y), + static_cast(workload.z), static_cast(shape.c))})); + RETURN_IF_ERROR(source.BindToIndex(0)); + RETURN_IF_ERROR(destination->BindToIndex(1)); + if (command_queue) { + return command_queue->Dispatch(program_, num_workgroups); + } + return program_.Dispatch(num_workgroups); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h new file mode 100644 index 00000000000..9d9e6402ffa --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_BHWC_TO_PHWC4_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_BHWC_TO_PHWC4_H_ + +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/command_queue.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" + +namespace tflite { +namespace gpu { +namespace gl { + +class ConverterBhwcToPhwc4 { + public: + // Creates invalid object. 
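The shader above packs every four channels of a pixel into one vec4 and stores the slices slice-major. A host-side sketch of the same index math follows; batch is assumed to be 1, matching the converter's check, and the helper name is illustrative.

```cpp
#include <cstddef>

// Channel c of pixel (h, w) lives in slice c / 4, lane c % 4; vec4 slices are
// laid out exactly like the shader's output index (gid.z * H + gid.y) * W + gid.x.
inline size_t Phwc4FloatOffset(int H, int W, int h, int w, int c) {
  const int slice = c / 4;
  const int lane = c % 4;
  const size_t vec4_index =
      (static_cast<size_t>(slice) * H + h) * W + w;
  return vec4_index * 4 + lane;  // offset in floats within the PHWC4 buffer
}
```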
+ ConverterBhwcToPhwc4() : program_(), workgroup_size_() {} + + static Status Create(ConverterBhwcToPhwc4* converter); + + Status Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue /* optional */, + GlBuffer* destination); + + private: + explicit ConverterBhwcToPhwc4(GlProgram program, const uint3& workgroup_size) + : program_(std::move(program)), workgroup_size_(workgroup_size) {} + + GlProgram program_; + uint3 workgroup_size_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_BHWC_TO_PHWC4_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc new file mode 100644 index 00000000000..6fc424047a1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h" + +#include +#include + +#include +#include +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +inline std::vector GenerateFloats(float multiplier, int size) { + std::vector v(size); + for (int i = 0; i < size; ++i) { + v[i] = multiplier * i * (i % 2 == 0 ? -1 : 1); + } + return v; +} + +Status RunTest(const BHWC& shape) { + // Create random input and calculate expected output for it. + std::vector input = GenerateFloats(0.01, shape.DimensionsProduct()); + std::vector output(GetElementsSizeForPHWC4(shape), 0); + RETURN_IF_ERROR( + ConvertToPHWC4(absl::MakeConstSpan(input.data(), input.size()), shape, + absl::MakeSpan(output.data(), output.size()))); + + std::unique_ptr env; + RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env)); + + // Create input and output buffers + GlBuffer input_buffer; + RETURN_IF_ERROR(CreateReadOnlyShaderStorageBuffer( + absl::MakeConstSpan(input.data(), input.size()), &input_buffer)); + + GlBuffer output_buffer; + RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( + GetElementsSizeForPHWC4(shape), &output_buffer)); + + // Create converter and run it. 
+ ConverterBhwcToPhwc4 converter; + RETURN_IF_ERROR(ConverterBhwcToPhwc4::Create(&converter)); + RETURN_IF_ERROR( + converter.Convert(shape, input_buffer, nullptr, &output_buffer)); + + std::vector converted_output(output.size(), 0); + RETURN_IF_ERROR(output_buffer.Read( + absl::MakeSpan(converted_output.data(), converted_output.size()))); + if (output != converted_output) { + return InternalError("Outputs don't match"); + } + return OkStatus(); +} + +TEST(HwcToPhwc4, Smoke) { + for (int32_t h : {1, 2, 3, 7, 20}) { + for (int32_t w : {1, 2, 4, 5, 11}) { + for (int32_t c : {1, 2, 4, 5, 8, 9}) { + BHWC shape(1, h, w, c); + EXPECT_TRUE(RunTest(shape).ok()) + << shape.h << " " << shape.w << " " << shape.c; + } + } + } +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc new file mode 100644 index 00000000000..65f19d4513d --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.cc @@ -0,0 +1,102 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h" + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/gl/converters/util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { + +Status ConverterPhwc4ToBhwc::Create(ConverterPhwc4ToBhwc* converter) { + uint3 workgroup_size = uint3(4, 4, 4); + std::string shader_source = GetShaderHeader(workgroup_size) + R"( + layout(std430) buffer; + + precision highp float; + + layout(binding = 0) readonly buffer B0 { + vec4 elements[]; + } input_data; + + layout(binding = 1) writeonly buffer B1 { + float elements[]; + } output_data; + + uniform ivec4 sizes_; + + void main() { + ivec3 gid = ivec3(gl_GlobalInvocationID.xyz); + if (gid.x >= sizes_.x || gid.y >= sizes_.y || gid.z >= sizes_.z) { + return; + } + output_data.elements[(gid.y * sizes_.x + gid.x) * sizes_.z + gid.z] = input_data.elements[(gid.z / 4 * sizes_.y + gid.y) * sizes_.x + gid.x][gid.z % 4]; + })"; + + GlShader shader; + RETURN_IF_ERROR( + GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, &shader)); + GlProgram program; + RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); + *converter = ConverterPhwc4ToBhwc(std::move(program), workgroup_size); + return OkStatus(); +} + +Status ConverterPhwc4ToBhwc::Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue, + GlBuffer* destination) { + if (source.bytes_size() < 
BytesForPHWC4(shape)) { + return InvalidArgumentError( + "Phwc4ToBhwc: Input data size does not match expected size."); + } + if (destination->bytes_size() < BytesForBHWC(shape)) { + return InvalidArgumentError( + "Phwc4ToBhwc: output data size does not match expected size."); + } + if (shape.b != 1) { + return UnimplementedError("Phwc4ToBhwc: Batch size is not equal to 1."); + } + + uint3 workload = uint3(shape.w, shape.h, shape.c); + uint3 num_workgroups = IntegralDivideRoundUp(workload, workgroup_size_); + + // TODO(akulik): simply pass workload as soon as UniformParameter + // supports uint3 + RETURN_IF_ERROR(program_.SetParameter(UniformParameter{ + "sizes_", + int4(static_cast(workload.x), static_cast(workload.y), + static_cast(workload.z), 0)})); + RETURN_IF_ERROR(source.BindToIndex(0)); + RETURN_IF_ERROR(destination->BindToIndex(1)); + if (command_queue) { + return command_queue->Dispatch(program_, num_workgroups); + } + return program_.Dispatch(num_workgroups); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h new file mode 100644 index 00000000000..c8b181223ae --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_PHWC4_TO_BHWC_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_PHWC4_TO_BHWC_H_ + +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/command_queue.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" + +namespace tflite { +namespace gpu { +namespace gl { + +class ConverterPhwc4ToBhwc { + public: + // Creates invalid object. 
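Both converters dispatch ceil(workload / workgroup) groups per axis via IntegralDivideRoundUp and rely on the shader's early-out guard to discard the overshoot. A minimal sketch of that arithmetic, with made-up example dimensions:

```cpp
#include <cstdint>

inline uint32_t DivideRoundUp(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

// Example: a 640x480x3 workload with the fixed (4, 4, 4) workgroup dispatches
//   DivideRoundUp(640, 4) x DivideRoundUp(480, 4) x DivideRoundUp(3, 4)
// = 160 x 120 x 1 workgroups.
```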
+ ConverterPhwc4ToBhwc() : program_(), workgroup_size_() {} + + static Status Create(ConverterPhwc4ToBhwc* converter); + + Status Convert(const BHWC& shape, const GlBuffer& source, + CommandQueue* command_queue /* optional */, + GlBuffer* destination); + + private: + explicit ConverterPhwc4ToBhwc(GlProgram program, const uint3& workgroup_size) + : program_(std::move(program)), workgroup_size_(workgroup_size) {} + + GlProgram program_; + uint3 workgroup_size_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_PHWC4_TO_BHWC_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc new file mode 100644 index 00000000000..6f969bb7801 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h" + +#include +#include + +#include +#include +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +inline std::vector GenerateFloats(float multiplier, int size) { + std::vector v(size); + for (int i = 0; i < size; ++i) { + v[i] = multiplier * i * (i % 2 == 0 ? -1 : 1); + } + return v; +} + +Status RunTest(const BHWC& shape) { + // Create random input and calculate expected output for it. + std::vector input = + GenerateFloats(0.01, GetElementsSizeForPHWC4(shape)); + std::vector output(shape.DimensionsProduct(), 0); + RETURN_IF_ERROR( + ConvertFromPHWC4(absl::MakeConstSpan(input.data(), input.size()), shape, + absl::MakeSpan(output.data(), output.size()))); + + std::unique_ptr env; + RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env)); + + // Create input and output buffers + GlBuffer input_buffer; + RETURN_IF_ERROR(CreateReadOnlyShaderStorageBuffer( + absl::MakeConstSpan(input.data(), input.size()), &input_buffer)); + + GlBuffer output_buffer; + RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( + shape.DimensionsProduct(), &output_buffer)); + + // Create converter and run it. 
+ ConverterPhwc4ToBhwc converter; + RETURN_IF_ERROR(ConverterPhwc4ToBhwc::Create(&converter)); + RETURN_IF_ERROR( + converter.Convert(shape, input_buffer, nullptr, &output_buffer)); + + std::vector converted_output(output.size(), 0); + RETURN_IF_ERROR(output_buffer.Read( + absl::MakeSpan(converted_output.data(), converted_output.size()))); + if (output != converted_output) { + return InternalError("Outputs don't match"); + } + return OkStatus(); +} + +TEST(Phwc4ToHwc, Smoke) { + for (int32_t h : {1, 2, 3, 7, 20}) { + for (int32_t w : {1, 2, 4, 5, 11}) { + for (int32_t c : {1, 2, 4, 5, 8, 9}) { + BHWC shape(1, h, w, c); + EXPECT_TRUE(RunTest(shape).ok()) + << shape.h << " " << shape.w << " " << shape.c; + } + } + } +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/converters/util.h b/tensorflow/lite/delegates/gpu/gl/converters/util.h new file mode 100644 index 00000000000..67f35497e88 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/converters/util.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_UTIL_H_ + +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace gl { + +inline std::string GetShaderHeader(const uint3& localsize) { + return absl::StrCat("#version 310 es\nlayout(local_size_x = ", localsize.x, + ", local_size_y = ", localsize.y, + ", local_size_z = ", localsize.z, ") in;\n"); +} + +inline uint32_t BytesForPHWC4(const BHWC& shape) { + return shape.b * shape.h * shape.w * AlignByN(shape.c, 4) * sizeof(float); +} + +inline uint32_t BytesForBHWC(const BHWC& shape) { + return shape.DimensionsProduct() * sizeof(float); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_CONVERTERS_UTIL_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/egl_context.cc b/tensorflow/lite/delegates/gpu/gl/egl_context.cc new file mode 100644 index 00000000000..8d714e27d8b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/egl_context.cc @@ -0,0 +1,143 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
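A quick, hypothetical numeric check of the two size helpers in util.h above; float tensors and the usual tflite::gpu namespaces are assumed.

```cpp
#include <iostream>

#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/gl/converters/util.h"

int main() {
  const tflite::gpu::BHWC shape(1, 8, 8, 3);
  // BytesForBHWC:  1 * 8 * 8 * 3 floats             =  768 bytes
  // BytesForPHWC4: 1 * 8 * 8 * AlignByN(3, 4) floats = 1024 bytes (channels
  // padded up to a multiple of four)
  std::cout << tflite::gpu::gl::BytesForBHWC(shape) << "\n"
            << tflite::gpu::gl::BytesForPHWC4(shape) << "\n";
}
```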
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/egl_context.h" + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +Status GetConfig(EGLDisplay display, const EGLint* attributes, + EGLConfig* config) { + EGLint config_count; + bool chosen = eglChooseConfig(display, attributes, config, 1, &config_count); + RETURN_IF_ERROR(GetOpenGlErrors()); + if (!chosen || config_count == 0) { + return InternalError("No EGL error, but eglChooseConfig failed."); + } + return OkStatus(); +} + +Status CreateContext(EGLDisplay display, EGLContext shared_context, + EGLConfig config, EglContext* egl_context) { + static const EGLint attributes[] = {EGL_CONTEXT_CLIENT_VERSION, 3, +#ifdef _DEBUG // Add debugging bit + EGL_CONTEXT_FLAGS_KHR, + EGL_CONTEXT_OPENGL_DEBUG_BIT_KHR, +#endif + EGL_NONE}; + EGLContext context = + eglCreateContext(display, config, shared_context, attributes); + RETURN_IF_ERROR(GetOpenGlErrors()); + if (context == EGL_NO_CONTEXT) { + return InternalError("No EGL error, but eglCreateContext failed."); + } + *egl_context = EglContext(context, display, config); + return OkStatus(); +} + +bool HasExtension(EGLDisplay display, const char* name) { + return strstr(eglQueryString(display, EGL_EXTENSIONS), name); +} + +} // namespace + +void EglContext::Invalidate() { + if (context_ != EGL_NO_CONTEXT) { + eglMakeCurrent(display_, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); + eglDestroyContext(display_, context_); + context_ = EGL_NO_CONTEXT; + } +} + +EglContext::EglContext(EglContext&& other) + : context_(other.context_), + display_(other.display_), + config_(other.config_) { + other.context_ = EGL_NO_CONTEXT; +} + +EglContext& EglContext::operator=(EglContext&& other) { + if (this != &other) { + Invalidate(); + std::swap(context_, other.context_); + display_ = other.display_; + config_ = other.config_; + } + return *this; +} + +Status EglContext::MakeCurrent(EGLSurface read, EGLSurface write) { + bool is_made_current = eglMakeCurrent(display_, write, read, context_); + RETURN_IF_ERROR(GetOpenGlErrors()); + if (!is_made_current) { + return InternalError("No EGL error, but eglMakeCurrent failed."); + } + return OkStatus(); +} + +bool EglContext::IsCurrent() const { + return context_ == eglGetCurrentContext(); +} + +Status CreateConfiglessContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context) { + if (!HasExtension(display, "EGL_KHR_no_config_context")) { + return UnavailableError("EGL_KHR_no_config_context not supported"); + } + return CreateContext(display, shared_context, EGL_NO_CONFIG_KHR, egl_context); +} + +Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context) { + if (!HasExtension(display, "EGL_KHR_create_context")) { + return UnavailableError("EGL_KHR_create_context not supported"); + } + if (!HasExtension(display, "EGL_KHR_surfaceless_context")) { + return UnavailableError("EGL_KHR_surfaceless_context not supported"); + } + const EGLint attributes[] = {EGL_RENDERABLE_TYPE, EGL_OPENGL_ES3_BIT_KHR, + EGL_NONE}; + EGLConfig config; + RETURN_IF_ERROR(GetConfig(display, attributes, &config)); + return CreateContext(display, shared_context, config, 
egl_context); +} + +Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context) { + const EGLint attributes[] = {EGL_SURFACE_TYPE, + EGL_PBUFFER_BIT, + EGL_BLUE_SIZE, + 8, + EGL_GREEN_SIZE, + 8, + EGL_RED_SIZE, + 8, + EGL_RENDERABLE_TYPE, + EGL_OPENGL_ES3_BIT_KHR, + EGL_NONE}; + EGLConfig config; + RETURN_IF_ERROR(GetConfig(display, attributes, &config)); + return CreateContext(display, shared_context, config, egl_context); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/egl_context.h b/tensorflow/lite/delegates/gpu/gl/egl_context.h new file mode 100644 index 00000000000..532d2d856aa --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/egl_context.h @@ -0,0 +1,93 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_CONTEXT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_CONTEXT_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_egl.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// EglContext is an RAII wrapper for an EGLContext. +// +// EglContext is moveable but not copyable. +// +// See https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglIntro.xhtml for +// more info. +class EglContext { + public: + // Creates an invalid EglContext. + EglContext() + : context_(EGL_NO_CONTEXT), + display_(EGL_NO_DISPLAY), + config_(EGL_NO_CONFIG_KHR) {} + + EglContext(EGLContext context, EGLDisplay display, EGLConfig config) + : context_(context), display_(display), config_(config) {} + + // Move only + EglContext(EglContext&& other); + EglContext& operator=(EglContext&& other); + EglContext(const EglContext&) = delete; + EglContext& operator=(const EglContext&) = delete; + + ~EglContext() { Invalidate(); } + + EGLContext context() const { return context_; } + + EGLDisplay display() const { return display_; } + + EGLConfig config() const { return config_; } + + // Make this EglContext the current EGL context on this thread, replacing + // the existing current. + Status MakeCurrent(EGLSurface read, EGLSurface write); + + Status MakeCurrentSurfaceless() { + return MakeCurrent(EGL_NO_SURFACE, EGL_NO_SURFACE); + } + + // Returns true if this is the currently bound EGL context. + bool IsCurrent() const; + + private: + void Invalidate(); + + EGLContext context_; + EGLDisplay display_; + EGLConfig config_; +}; + +// It uses the EGL_KHR_no_config_context extension to create a no config context +// since most modern hardware supports the extension. 
+Status CreateConfiglessContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context); + +Status CreateSurfacelessContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context); + +Status CreatePBufferContext(EGLDisplay display, EGLContext shared_context, + EglContext* egl_context); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_CONTEXT_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/egl_environment.cc b/tensorflow/lite/delegates/gpu/gl/egl_environment.cc new file mode 100644 index 00000000000..3ef6601fbed --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/egl_environment.cc @@ -0,0 +1,149 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +// TODO(akulik): detect power management event when all contexts are destroyed +// and OpenGL ES is reinitialized. See eglMakeCurrent + +Status InitDisplay(EGLDisplay* egl_display) { + RETURN_IF_ERROR( + TFLITE_GPU_CALL_EGL(eglGetDisplay, egl_display, EGL_DEFAULT_DISPLAY)); + if (*egl_display == EGL_NO_DISPLAY) { + return UnavailableError("eglGetDisplay returned nullptr"); + } + bool is_initialized; + RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglInitialize, &is_initialized, + *egl_display, nullptr, nullptr)); + if (!is_initialized) { + return InternalError("No EGL error, but eglInitialize failed"); + } + return OkStatus(); +} + +} // namespace + +Status EglEnvironment::NewEglEnvironment( + std::unique_ptr* egl_environment) { + *egl_environment = absl::make_unique(); + RETURN_IF_ERROR((*egl_environment)->Init()); + return OkStatus(); +} + +EglEnvironment::~EglEnvironment() { + if (dummy_framebuffer_ != GL_INVALID_INDEX) { + glDeleteFramebuffers(1, &dummy_framebuffer_); + } + if (dummy_texture_ != GL_INVALID_INDEX) { + glDeleteTextures(1, &dummy_texture_); + } +} + +Status EglEnvironment::Init() { + bool is_bound; + RETURN_IF_ERROR( + TFLITE_GPU_CALL_EGL(eglBindAPI, &is_bound, EGL_OPENGL_ES_API)); + if (!is_bound) { + return InternalError("No EGL error, but eglBindAPI failed"); + } + + // Re-use context and display if it was created on this thread. 
+ if (eglGetCurrentContext() != EGL_NO_CONTEXT) { + display_ = eglGetCurrentDisplay(); + context_ = EglContext(eglGetCurrentContext(), display_, EGL_NO_CONFIG_KHR); + } else { + RETURN_IF_ERROR(InitDisplay(&display_)); + + Status status = InitConfiglessContext(); + if (!status.ok()) { + status = InitSurfacelessContext(); + } + if (!status.ok()) { + status = InitPBufferContext(); + } + if (!status.ok()) { + return status; + } + } + + if (gpu_info_.type == GpuType::UNKNOWN) { + RETURN_IF_ERROR(RequestGpuInfo(&gpu_info_)); + } + // TODO(akulik): when do we need ForceSyncTurning? + ForceSyncTurning(); + return OkStatus(); +} + +Status EglEnvironment::InitConfiglessContext() { + RETURN_IF_ERROR(CreateConfiglessContext(display_, EGL_NO_CONTEXT, &context_)); + return context_.MakeCurrentSurfaceless(); +} + +Status EglEnvironment::InitSurfacelessContext() { + RETURN_IF_ERROR( + CreateSurfacelessContext(display_, EGL_NO_CONTEXT, &context_)); + Status status = context_.MakeCurrentSurfaceless(); + if (!status.ok()) { + return status; + } + + // PowerVR support EGL_KHR_surfaceless_context, but glFenceSync crashes on + // PowerVR when it is surface-less. + RETURN_IF_ERROR(RequestGpuInfo(&gpu_info_)); + if (gpu_info_.type == GpuType::POWERVR) { + return UnavailableError( + "Surface-less context is not properly supported on powervr."); + } + return OkStatus(); +} + +Status EglEnvironment::InitPBufferContext() { + RETURN_IF_ERROR(CreatePBufferContext(display_, EGL_NO_CONTEXT, &context_)); + RETURN_IF_ERROR(CreatePbufferRGBSurface(context_.config(), display_, 1, 1, + &surface_read_)); + RETURN_IF_ERROR(CreatePbufferRGBSurface(context_.config(), display_, 1, 1, + &surface_draw_)); + return context_.MakeCurrent(surface_read_.surface(), surface_draw_.surface()); +} + +void EglEnvironment::ForceSyncTurning() { + glGenFramebuffers(1, &dummy_framebuffer_); + glBindFramebuffer(GL_FRAMEBUFFER, dummy_framebuffer_); + + glGenTextures(1, &dummy_texture_); + glBindTexture(GL_TEXTURE_2D, dummy_texture_); + glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA8, 4, 4); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + dummy_texture_, 0); + + GLenum draw_buffers[1] = {GL_COLOR_ATTACHMENT0}; + glDrawBuffers(1, draw_buffers); + + glViewport(0, 0, 4, 4); + glClear(GL_COLOR_BUFFER_BIT); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/egl_environment.h b/tensorflow/lite/delegates/gpu/gl/egl_environment.h new file mode 100644 index 00000000000..86d93c35043 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/egl_environment.h @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_ENVIRONMENT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_ENVIRONMENT_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/egl_context.h" +#include "tensorflow/lite/delegates/gpu/gl/egl_surface.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_egl.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Class encapsulates creation of OpenGL objects needed before starting working +// with OpenGL: binds OpenGL ES API, creates new EGL context, binds it to EGL +// display and creates surfaces if needed. +// +// EGL environment needs to be created once per thread. +class EglEnvironment { + public: + static Status NewEglEnvironment( + std::unique_ptr* egl_environment); + + EglEnvironment() = default; + ~EglEnvironment(); + + const EglContext& context() const { return context_; } + EGLDisplay display() const { return display_; } + const GpuInfo& gpu_info() const { return gpu_info_; } + + private: + Status Init(); + Status InitConfiglessContext(); + Status InitSurfacelessContext(); + Status InitPBufferContext(); + + EGLDisplay display_ = EGL_NO_DISPLAY; + EglContext context_; + EglSurface surface_draw_; + EglSurface surface_read_; + GpuInfo gpu_info_; + + // Strange hack that helps on Mali GPUs + // without it glFinish and glFenceSync don't work + void ForceSyncTurning(); + GLuint dummy_framebuffer_ = GL_INVALID_INDEX; + GLuint dummy_texture_ = GL_INVALID_INDEX; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_ENVIRONMENT_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/egl_surface.cc b/tensorflow/lite/delegates/gpu/gl/egl_surface.cc new file mode 100644 index 00000000000..eaccea6411e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/egl_surface.cc @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
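A hypothetical usage sketch mirroring the converter tests above: one EglEnvironment is created per thread before any GL buffers or programs are touched (the function name is illustrative).

```cpp
#include <memory>

#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h"

namespace tflite {
namespace gpu {
namespace gl {

Status RunOnGpuThread() {
  std::unique_ptr<EglEnvironment> env;
  RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env));
  // env->context() is now current on this thread; env->gpu_info() reports the
  // detected vendor (e.g. Mali, Adreno, PowerVR).
  return OkStatus();
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
```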
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/egl_surface.h" + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" + +namespace tflite { +namespace gpu { +namespace gl { + +EglSurface::EglSurface(EglSurface&& other) + : surface_(other.surface_), display_(other.display_) { + other.surface_ = EGL_NO_SURFACE; +} + +EglSurface& EglSurface::operator=(EglSurface&& other) { + if (this != &other) { + display_ = other.display_; + Invalidate(); + std::swap(surface_, other.surface_); + } + return *this; +} + +void EglSurface::Invalidate() { + if (surface_ != EGL_NO_SURFACE) { + eglDestroySurface(display_, surface_); + surface_ = EGL_NO_SURFACE; + } +} + +Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, + uint32_t height, uint32_t width, + EglSurface* egl_surface) { + const EGLint pbuffer_attributes[] = {EGL_WIDTH, + static_cast(width), + EGL_HEIGHT, + static_cast(height), + EGL_TEXTURE_FORMAT, + EGL_TEXTURE_RGB, + EGL_TEXTURE_TARGET, + EGL_TEXTURE_2D, + EGL_NONE}; + EGLSurface surface = + eglCreatePbufferSurface(display, config, pbuffer_attributes); + RETURN_IF_ERROR(GetOpenGlErrors()); + if (surface == EGL_NO_SURFACE) { + return InternalError("No EGL error, but eglCreatePbufferSurface failed"); + } + *egl_surface = EglSurface(surface, display); + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/egl_surface.h b/tensorflow/lite/delegates/gpu/gl/egl_surface.h new file mode 100644 index 00000000000..793dc7a9dc6 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/egl_surface.h @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_SURFACE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_SURFACE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_egl.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// An RAII wrapper for EGLSurface. +// See https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglIntro.xhtml for +// an introduction to the concepts. +// +// EglSurface is moveable but not copyable. +class EglSurface { + public: + // Creates an invalid EglSurface. 
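A hypothetical sketch of CreatePbufferRGBSurface defined above, matching the 1x1 surfaces created by EglEnvironment::InitPBufferContext; note the (height, width) parameter order. The wrapper name is illustrative.

```cpp
#include "tensorflow/lite/delegates/gpu/gl/egl_surface.h"

namespace tflite {
namespace gpu {
namespace gl {

Status MakeTinySurface(EGLConfig config, EGLDisplay display,
                       EglSurface* surface) {
  // A 1x1 off-screen pbuffer is enough to satisfy drivers that refuse to run
  // compute work without a bound surface.
  return CreatePbufferRGBSurface(config, display, /*height=*/1, /*width=*/1,
                                 surface);
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite
```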
+ EglSurface() : surface_(EGL_NO_SURFACE), display_(EGL_NO_DISPLAY) {} + + EglSurface(EGLSurface surface, EGLDisplay display) + : surface_(surface), display_(display) {} + + // Move-only + EglSurface(EglSurface&& other); + EglSurface& operator=(EglSurface&& other); + EglSurface(const EglSurface&) = delete; + EglSurface& operator=(const EglSurface&) = delete; + + ~EglSurface() { Invalidate(); } + + EGLSurface surface() const { return surface_; } + + private: + void Invalidate(); + + EGLSurface surface_; + EGLDisplay display_; +}; + +// Creates off-screen pbuffer-based surface of the given height and width. +Status CreatePbufferRGBSurface(EGLConfig config, EGLDisplay display, + uint32_t height, uint32_t width, + EglSurface* egl_surface); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_EGL_SURFACE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/float16_conversions.cc b/tensorflow/lite/delegates/gpu/gl/float16_conversions.cc new file mode 100644 index 00000000000..139c81491e4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/float16_conversions.cc @@ -0,0 +1,73 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/float16_conversions.h" + +#include +#include + +#include +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +// Performs in-place conversion of float32 into float16 +bool ToFloat16(std::vector* values) { + if (values->size() % sizeof(float) != 0) { + return false; + } + + uint16_t* store_f16 = reinterpret_cast(values->data()); + const float* load_f32 = reinterpret_cast(values->data()); + const float* end_load_f32 = + reinterpret_cast(values->data() + values->size()); + + while (load_f32 != end_load_f32) { + *store_f16++ = fp16_ieee_from_fp32_value(*load_f32++); + } + + values->resize(values->size() / 2); + return true; +} + +struct ConverterToFloat16 { + bool operator()(ObjectData& data) const { // NOLINT + return ToFloat16(&data); + } + + bool operator()(ObjectRef& buffer) const { // NOLINT + return true; + } +}; + +} // namespace + +bool MaybeConvertToFloat16(Object* object) { + if (object->data_type == DataType::FLOAT32 && + absl::visit(ConverterToFloat16(), object->object)) { + object->data_type = DataType::FLOAT16; + return true; + } + return false; +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/float16_conversions.h b/tensorflow/lite/delegates/gpu/gl/float16_conversions.h new file mode 100644 index 00000000000..304c2a23fc1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/float16_conversions.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
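// Minimal sketch of the float16 helper implemented above. `object` is assumed
// to be a populated gl::Object (see gl/object.h); the call is a no-op both for
// non-float32 data and for objects that only hold an ObjectRef.
#include "tensorflow/lite/delegates/gpu/gl/float16_conversions.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"

void ShrinkIfPossible(tflite::gpu::gl::Object* object) {
  if (tflite::gpu::gl::MaybeConvertToFloat16(object)) {
    // object->data_type is now DataType::FLOAT16 and the backing ObjectData
    // holds half as many bytes as before.
  }
}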
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_FLOAT16_CONVERSIONS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_FLOAT16_CONVERSIONS_H_ + +#include "tensorflow/lite/delegates/gpu/gl/object.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// If an object is float32, converts it to float16 representation. +bool MaybeConvertToFloat16(Object* object); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_FLOAT16_CONVERSIONS_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc b/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc new file mode 100644 index 00000000000..6e5e8afa364 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc @@ -0,0 +1,89 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
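// Sketch of how CreatePbufferRGBSurface, declared in egl_surface.h earlier in
// this change, might be used; `config` and `display` are assumed to come from
// an already initialized EGL setup such as EglEnvironment.
#include "tensorflow/lite/delegates/gpu/gl/egl_surface.h"

tflite::gpu::Status MakeOffscreenSurface(EGLConfig config, EGLDisplay display,
                                         tflite::gpu::gl::EglSurface* surface) {
  RETURN_IF_ERROR(tflite::gpu::gl::CreatePbufferRGBSurface(
      config, display, /*height=*/1, /*width=*/1, surface));
  // The EGLSurface is released automatically when *surface is destroyed;
  // surface->surface() can be handed to eglMakeCurrent.
  return tflite::gpu::OkStatus();
}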
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" + +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace gl { + +Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer) { + if (read_buffer.bytes_size() != write_buffer.bytes_size()) { + return InvalidArgumentError( + "Read buffer does not match write buffer size."); + } + gl_buffer_internal::BufferBinder read_buffer_binder(GL_COPY_READ_BUFFER, + read_buffer.id()); + gl_buffer_internal::BufferBinder write_buffer_binder(GL_COPY_WRITE_BUFFER, + write_buffer.id()); + return TFLITE_GPU_CALL_GL(glCopyBufferSubData, GL_COPY_READ_BUFFER, + GL_COPY_WRITE_BUFFER, read_buffer.offset(), + write_buffer.offset(), read_buffer.bytes_size()); +} + +GlBuffer::GlBuffer(GlBuffer&& buffer) + : GlBuffer(buffer.target_, buffer.id_, buffer.bytes_size_, buffer.offset_, + buffer.has_ownership_) { + buffer.has_ownership_ = false; +} + +GlBuffer& GlBuffer::operator=(GlBuffer&& buffer) { + if (this != &buffer) { + Invalidate(); + + target_ = buffer.target_; + bytes_size_ = buffer.bytes_size_; + offset_ = buffer.offset_; + has_ownership_ = buffer.has_ownership_; + id_ = buffer.id_; + buffer.has_ownership_ = false; + } + return *this; +} + +GlBuffer::~GlBuffer() { Invalidate(); } + +void GlBuffer::Invalidate() { + if (has_ownership_ && id_ != GL_INVALID_INDEX) { + TFLITE_GPU_CALL_GL(glDeleteBuffers, 1, &id_).IgnoreError(); + id_ = GL_INVALID_INDEX; + } +} + +Status GlBuffer::BindToIndex(uint32_t index) const { + return TFLITE_GPU_CALL_GL(glBindBufferRange, target_, index, id_, offset_, + bytes_size_); +} + +Status GlBuffer::MakeView(size_t offset, size_t bytes_size, + GlBuffer* gl_buffer) { + if (offset + bytes_size > bytes_size_) { + return OutOfRangeError("GlBuffer view is out of range."); + } + *gl_buffer = GlBuffer(target_, id_, bytes_size, offset_ + offset, + /*has_ownership=*/false); + return OkStatus(); +} + +GlBuffer GlBuffer::MakeRef() { + return GlBuffer(target_, id_, bytes_size_, offset_, + /* has_ownership = */ false); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer.h b/tensorflow/lite/delegates/gpu/gl/gl_buffer.h new file mode 100644 index 00000000000..5897499598c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer.h @@ -0,0 +1,298 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
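// Sketch of the view/copy helpers implemented above. `buffer` is assumed to be
// a valid GlBuffer, for example one created with
// CreateReadWriteShaderStorageBuffer as declared in gl_buffer.h below.
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"

tflite::gpu::Status CopyFirstHalfToSecondHalf(
    tflite::gpu::gl::GlBuffer* buffer) {
  tflite::gpu::gl::GlBuffer first_half, second_half;
  const size_t half = buffer->bytes_size() / 2;
  RETURN_IF_ERROR(buffer->MakeView(/*offset=*/0, half, &first_half));
  RETURN_IF_ERROR(buffer->MakeView(/*offset=*/half, half, &second_half));
  // Views do not own the GL object; CopyBuffer issues glCopyBufferSubData
  // between the two ranges.
  return tflite::gpu::gl::CopyBuffer(first_half, second_half);
}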
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_BUFFER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_BUFFER_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Buffer is an RAII wrapper for OpenGL buffer object. +// See https://www.khronos.org/opengl/wiki/Buffer_Object for more information. +// +// Buffer is moveable but not copyable. +class GlBuffer { + public: + // @param has_ownership indicates that GlBuffer is responsible for + // corresponding GL buffer deletion. + GlBuffer(GLenum target, GLuint id, size_t bytes_size, size_t offset, + bool has_ownership) + : target_(target), + id_(id), + bytes_size_(bytes_size), + offset_(offset), + has_ownership_(has_ownership) {} + + // Creates invalid buffer. + GlBuffer() : GlBuffer(GL_INVALID_ENUM, GL_INVALID_INDEX, 0, 0, false) {} + + // Move-only + GlBuffer(GlBuffer&& buffer); + GlBuffer& operator=(GlBuffer&& buffer); + GlBuffer(const GlBuffer&) = delete; + GlBuffer& operator=(const GlBuffer&) = delete; + + ~GlBuffer(); + + // Reads data from buffer into CPU memory. Data should point to a region that + // has at least bytes_size available. + template + Status Read(absl::Span data) const; + + // Writes data to a buffer. + template + Status Write(absl::Span data); + + // Maps GPU memory to CPU address space and calls reader that may read from + // that memory. + template + Status MappedRead( + const std::function)>& reader) const; + + // Maps GPU memory to CPU address space and calls writer that may write into + // that memory. + template + Status MappedWrite(const std::function)>& writer); + + Status MakeView(size_t offset, size_t bytes_size, GlBuffer* gl_buffer); + + // Makes a copy without ownership of the buffer. + GlBuffer MakeRef(); + + // Binds a buffer to an index. + Status BindToIndex(uint32_t index) const; + + // Releases the ownership of the buffer object. + void Release() { has_ownership_ = false; } + + size_t bytes_size() const { return bytes_size_; } + + const GLenum target() const { return target_; } + + const GLuint id() const { return id_; } + + bool is_valid() const { return id_ != GL_INVALID_INDEX; } + + size_t offset() const { return offset_; } + + // @return true if this object actually owns corresponding GL buffer + // and manages it's lifetime. + bool has_ownership() const { return has_ownership_; } + + private: + void Invalidate(); + + GLenum target_; + GLuint id_; + size_t bytes_size_; + size_t offset_; + bool has_ownership_; +}; + +Status CopyBuffer(const GlBuffer& read_buffer, const GlBuffer& write_buffer); + +// Creates new shader storage buffer that will be modified and used many +// times. +// +// See https://www.khronos.org/opengl/wiki/Shader_Storage_Buffer_Object for +// details. +template +Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, + GlBuffer* gl_buffer); + +// Creates new shader storage buffer that will be filled with data once which +// will be used many times. +template +Status CreateReadOnlyShaderStorageBuffer(absl::Span data, + GlBuffer* gl_buffer); + +// Adapts raw Buffer::Read method to read data into a vector. 
+template +Status AppendFromBuffer(const GlBuffer& buffer, std::vector* data) { + if (buffer.bytes_size() % sizeof(T) != 0) { + return InvalidArgumentError("Buffer is not aligned"); + } + size_t num_elements = buffer.bytes_size() / sizeof(T); + data->resize(data->size() + num_elements); + return buffer.Read( + absl::MakeSpan(data->data() + data->size() - num_elements, num_elements)); +} + +//////////////////////////////////////////////////////////////////////////////// +// Implementation details are below. + +namespace gl_buffer_internal { + +// RAII for creating and/or owning buffer id. +class BufferId { + public: + BufferId() : id_(GL_INVALID_INDEX) { + TFLITE_GPU_CALL_GL(glGenBuffers, 1 /* number of buffers */, &id_) + .IgnoreError(); + // only possible error here is when a number of buffers is negative. + } + + explicit BufferId(GLuint id) : id_(id) {} + + ~BufferId() { + if (id_ != GL_INVALID_INDEX) { + TFLITE_GPU_CALL_GL(glDeleteBuffers, 1, &id_).IgnoreError(); + } + } + + GLuint id() const { return id_; } + + GLuint Release() { + GLuint id = GL_INVALID_INDEX; + std::swap(id, id_); + return id; + } + + private: + GLuint id_; +}; + +// RAII for binding and unbinding a buffer. +class BufferBinder { + public: + BufferBinder(GLenum target, GLuint id) : target_(target) { + TFLITE_GPU_CALL_GL(glBindBuffer, target_, id).IgnoreError(); + } + + ~BufferBinder() { + TFLITE_GPU_CALL_GL(glBindBuffer, target_, 0).IgnoreError(); + } + + private: + const GLenum target_; +}; + +// RAII for mapping and unmapping a buffer. +class BufferMapper { + public: + BufferMapper(GLenum target, size_t offset, size_t bytes, GLbitfield access) + : target_(target), + data_(glMapBufferRange(target_, offset, bytes, access)) {} + + ~BufferMapper() { TFLITE_GPU_CALL_GL(glUnmapBuffer, target_).IgnoreError(); } + + void* data() { return data_; } + + private: + const GLenum target_; + void* data_; +}; + +} // namespace gl_buffer_internal + +template +Status CreateReadWriteShaderStorageBuffer(uint32_t num_elements, + GlBuffer* gl_buffer) { + gl_buffer_internal::BufferId id; + gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id()); + // TODO(akulik): benchmark DYNAMIC vs STREAM buffer + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glBufferData, GL_SHADER_STORAGE_BUFFER, + num_elements * sizeof(T), nullptr, + GL_STREAM_COPY)); + *gl_buffer = GlBuffer{GL_SHADER_STORAGE_BUFFER, id.Release(), + num_elements * sizeof(T), 0, true}; + return OkStatus(); +} + +template +Status CreateReadOnlyShaderStorageBuffer(absl::Span data, + GlBuffer* gl_buffer) { + gl_buffer_internal::BufferId id; + gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, id.id()); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glBufferData, GL_SHADER_STORAGE_BUFFER, + data.size() * sizeof(T), data.data(), + GL_STATIC_READ)); + *gl_buffer = GlBuffer{GL_SHADER_STORAGE_BUFFER, id.Release(), + data.size() * sizeof(T), 0, true}; + return OkStatus(); +} + +template +Status GlBuffer::Read(absl::Span data) const { + if (data.size() * sizeof(T) < bytes_size()) { + return InvalidArgumentError( + "Read from buffer failed. Destination data is shorter than buffer."); + } + // TODO(akulik): glCopyBufferSubData is actually available in ES 3.1, try it. + return MappedRead([this, data](absl::Span src) { + std::memcpy(data.data(), src.data(), bytes_size()); + return OkStatus(); + }); +} + +template +Status GlBuffer::Write(absl::Span data) { + if (data.size() * sizeof(T) > bytes_size_) { + return InvalidArgumentError( + "Write to buffer failed. 
Source data is larger than buffer."); + } + gl_buffer_internal::BufferBinder binder(target_, id_); + return TFLITE_GPU_CALL_GL(glBufferSubData, target_, offset_, bytes_size_, + data.data()); +} + +template +Status GlBuffer::MappedRead( + const std::function d)>& reader) const { + if (bytes_size_ % sizeof(T) != 0) { + return InvalidArgumentError("Buffer is not aligned"); + } + gl_buffer_internal::BufferBinder binder(target_, id_); + gl_buffer_internal::BufferMapper mapper(target_, offset_, bytes_size_, + GL_MAP_READ_BIT); + if (!mapper.data()) { + return GetOpenGlErrors(); + } + return reader(absl::MakeSpan(reinterpret_cast(mapper.data()), + bytes_size_ / sizeof(T))); +} + +template +Status GlBuffer::MappedWrite( + const std::function d)>& writer) { + if (bytes_size_ % sizeof(T) != 0) { + return InvalidArgumentError("Buffer is not aligned"); + } + gl_buffer_internal::BufferBinder binder(target_, id_); + gl_buffer_internal::BufferMapper mapper(target_, offset_, bytes_size_, + GL_MAP_WRITE_BIT); + if (!mapper.data()) { + return GetOpenGlErrors(); + } + return writer(absl::MakeSpan(reinterpret_cast(mapper.data()), + bytes_size_ / sizeof(T))); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_BUFFER_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc b/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc new file mode 100644 index 00000000000..1d8031fcf39 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
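// Sketch of the typical pattern for constant data: upload once into a
// read-only shader storage buffer and expose it to a compute shader at a
// binding point. The element type and binding index are illustrative only.
#include <vector>

#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"

tflite::gpu::Status UploadWeights(const std::vector<float>& weights,
                                  tflite::gpu::gl::GlBuffer* out) {
  RETURN_IF_ERROR(
      tflite::gpu::gl::CreateReadOnlyShaderStorageBuffer<float>(weights, out));
  // Matches `layout(std430, binding = 0) readonly buffer ...` in GLSL.
  return out->BindToIndex(0);
}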
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" + +#include + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +TEST(Buffer, Read) { + std::unique_ptr env; + ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok()); + std::vector test = {0, 1, 2, 3}; + GlBuffer buffer; + ASSERT_TRUE(CreateReadOnlyShaderStorageBuffer(test, &buffer).ok()); + std::vector from_buffer; + ASSERT_TRUE(AppendFromBuffer(buffer, &from_buffer).ok()); + EXPECT_EQ(test, from_buffer); +} + +TEST(Buffer, Write) { + std::unique_ptr env; + ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok()); + GlBuffer buffer; + ASSERT_TRUE(CreateReadWriteShaderStorageBuffer(4, &buffer).ok()); + std::vector test = {0, 1, 2, 3}; + ASSERT_TRUE(buffer.Write(test).ok()); + std::vector from_buffer; + ASSERT_TRUE(AppendFromBuffer(buffer, &from_buffer).ok()); + EXPECT_EQ(test, from_buffer); +} + +TEST(Buffer, View) { + std::unique_ptr env; + ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok()); + GlBuffer buffer; + ASSERT_TRUE(CreateReadWriteShaderStorageBuffer(6, &buffer).ok()); + EXPECT_TRUE(buffer.has_ownership()); + EXPECT_EQ(24, buffer.bytes_size()); + EXPECT_EQ(0, buffer.offset()); + + // Create view and write data there. + GlBuffer view; + ASSERT_TRUE(buffer.MakeView(4, 16, &view).ok()); + EXPECT_FALSE(view.has_ownership()); + EXPECT_EQ(16, view.bytes_size()); + EXPECT_EQ(4, view.offset()); + std::vector test = {1, 2, 3, 4}; + ASSERT_TRUE(view.Write(test).ok()); + + // Check that data indeed landed in a buffer with proper offset. + std::vector from_buffer; + ASSERT_TRUE(AppendFromBuffer(buffer, &from_buffer).ok()); + EXPECT_THAT(from_buffer, testing::ElementsAre(0, 1, 2, 3, 4, 0)); + + std::vector from_view; + ASSERT_TRUE(AppendFromBuffer(view, &from_view).ok()); + EXPECT_THAT(from_view, testing::ElementsAre(1, 2, 3, 4)); +} + +TEST(Buffer, SubView) { + std::unique_ptr env; + ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok()); + GlBuffer buffer; + ASSERT_TRUE(CreateReadWriteShaderStorageBuffer(6, &buffer).ok()); + + // Create view and another view over that view. + + GlBuffer view1; + ASSERT_TRUE(buffer.MakeView(4, 16, &view1).ok()); + GlBuffer view2; + EXPECT_NE(view1.MakeView(1, 16, &view2), OkStatus()); + ASSERT_TRUE(view1.MakeView(2, 2, &view2).ok()); + + EXPECT_FALSE(view2.has_ownership()); + EXPECT_EQ(2, view2.bytes_size()); + EXPECT_EQ(6, view2.offset()); +} + +TEST(Buffer, Copy) { + std::unique_ptr env; + ASSERT_TRUE(EglEnvironment::NewEglEnvironment(&env).ok()); + GlBuffer buffer; + ASSERT_TRUE(CreateReadWriteShaderStorageBuffer(4, &buffer).ok()); + + // Create view and write data there. + GlBuffer view1; + ASSERT_TRUE(buffer.MakeView(4, 4, &view1).ok()); + + GlBuffer view2; + ASSERT_TRUE(buffer.MakeView(8, 4, &view2).ok()); + + // Copy data from one view to another + ASSERT_TRUE(view1.Write({1}).ok()); + ASSERT_TRUE(CopyBuffer(view1, view2).ok()); + + // Check that data indeed landed correctly. 
+ std::vector from_buffer; + ASSERT_TRUE(AppendFromBuffer(buffer, &from_buffer).ok()); + EXPECT_THAT(from_buffer, testing::ElementsAre(0, 1, 1, 0)); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/gl_call.h b/tensorflow/lite/delegates/gpu/gl/gl_call.h new file mode 100644 index 00000000000..a8a81bae608 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_call.h @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_CALL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_CALL_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Primary purpose of this file is to provide useful macro for calling GL +// functions and checking errors. It also attaches a context to status in case +// of a GL error. +// +// Use TFLITE_GPU_CALL_GL as follows: +// +// For GL functions with a return value: +// Before: +// GLint result = glFunc(...); +// RETURN_IF_ERROR(GetOpenGlErrors()); +// After: +// GLint result; +// RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glFunc, &result, ...)); +// +// For GL functions without a return value: +// Before: +// glFunc(...); +// RETURN_IF_ERROR(GetOpenGlErrors()); +// After: +// RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glFunc, ...)); + +namespace gl_call_internal { + +// For GL functions with a return value. +template +struct Caller { + template + Status operator()(const std::string& context, F func, ErrorF error_func, + T* result, Params&&... params) { + *result = func(std::forward(params)...); + const auto status = error_func(); + if (status.ok()) return OkStatus(); + return Status(status.code(), status.error_message() + ": " + context); + } +}; + +// For GL functions without a return value. +template<> +struct Caller { + template + Status operator()(const std::string& context, F func, ErrorF error_func, + Params&&... params) { + func(std::forward(params)...); + const auto status = error_func(); + if (status.ok()) return OkStatus(); + return Status(status.code(), status.error_message() + ": " + context); + } +}; + +template +Status CallAndCheckError(const std::string& context, F func, ErrorF error_func, + ResultT* result, ParamsT&&... params) { + return Caller()(context, func, error_func, result, + std::forward(params)...); +} + +template +Status CallAndCheckError(const std::string& context, F func, ErrorF error_func, + Params&&... params) { + return Caller()(context, func, error_func, + std::forward(params)...); +} + +} // namespace gl_call_internal + +// XX_STRINGIFY is a helper macro to effectively apply # operator to an +// arbitrary value. 
+#define TFLITE_GPU_INTERNAL_STRINGIFY_HELPER(x) #x +#define TFLITE_GPU_INTERNAL_STRINGIFY(x) TFLITE_GPU_INTERNAL_STRINGIFY_HELPER(x) +#define TFLITE_GPU_FILE_LINE \ + __FILE__ ":" TFLITE_GPU_INTERNAL_STRINGIFY(__LINE__) + +#define TFLITE_GPU_CALL_GL(method, ...) \ + ::tflite::gpu::gl::gl_call_internal::CallAndCheckError( \ + #method " in " TFLITE_GPU_FILE_LINE, method, \ + ::tflite::gpu::gl::GetOpenGlErrors, __VA_ARGS__) + +#define TFLITE_GPU_CALL_EGL(method, ...) \ + ::tflite::gpu::gl::gl_call_internal::CallAndCheckError( \ + #method " in " TFLITE_GPU_FILE_LINE, method, \ + ::tflite::gpu::gl::GetEglError, __VA_ARGS__) + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_CALL_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/gl_errors.cc b/tensorflow/lite/delegates/gpu/gl/gl_errors.cc new file mode 100644 index 00000000000..2c29127839d --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_errors.cc @@ -0,0 +1,142 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" + +#include +#include + +#include "absl/strings/str_join.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_egl.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +const char* ErrorToString(GLenum error) { + switch (error) { + case GL_INVALID_ENUM: + return "[GL_INVALID_ENUM]: An unacceptable value is specified for an " + "enumerated argument."; + case GL_INVALID_VALUE: + return "[GL_INVALID_VALUE]: A numeric argument is out of range."; + case GL_INVALID_OPERATION: + return "[GL_INVALID_OPERATION]: The specified operation is not allowed " + "in the current state."; + case GL_INVALID_FRAMEBUFFER_OPERATION: + return "[GL_INVALID_FRAMEBUFFER_OPERATION]: The framebuffer object is " + "not complete."; + case GL_OUT_OF_MEMORY: + return "[GL_OUT_OF_MEMORY]: There is not enough memory left to execute " + "the command."; + } + return "[UNKNOWN_GL_ERROR]"; +} + +struct ErrorFormatter { + void operator()(std::string* out, GLenum error) const { + absl::StrAppend(out, ErrorToString(error)); + } +}; + +} // namespace + +// TODO(akulik): create new error space for GL error. 
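// Two concrete uses of the TFLITE_GPU_CALL_GL macro from gl_call.h above,
// mirroring how the rest of this change uses it: a GL function that returns a
// value receives a result pointer right after the function name, while a void
// GL function just gets its own arguments.
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"

tflite::gpu::Status MakeShaderId(GLenum type, GLuint* shader_id) {
  // glCreateShader returns a GLuint, stored through shader_id.
  return TFLITE_GPU_CALL_GL(glCreateShader, shader_id, type);
}

tflite::gpu::Status DeleteShaderId(GLuint shader_id) {
  // glDeleteShader returns void, so only its arguments are forwarded.
  return TFLITE_GPU_CALL_GL(glDeleteShader, shader_id);
}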
+ +Status GetOpenGlErrors() { + auto error = glGetError(); + if (error == GL_NO_ERROR) { + return OkStatus(); + } + auto error2 = glGetError(); + if (error2 == GL_NO_ERROR) { + return InternalError(ErrorToString(error)); + } + std::vector errors = {error, error2}; + for (error = glGetError(); error != GL_NO_ERROR; error = glGetError()) { + errors.push_back(error); + } + return InternalError(absl::StrJoin(errors, ",", ErrorFormatter())); +} + +Status GetEglError() { + EGLint error = eglGetError(); + switch (error) { + case EGL_SUCCESS: + return OkStatus(); + case EGL_NOT_INITIALIZED: + return InternalError( + "EGL is not initialized, or could not be initialized, for the " + "specified EGL display connection."); + case EGL_BAD_ACCESS: + return InternalError( + "EGL cannot access a requested resource (for example a context is " + "bound in another thread)."); + case EGL_BAD_ALLOC: + return InternalError( + "EGL failed to allocate resources for the requested operation."); + case EGL_BAD_ATTRIBUTE: + return InternalError( + "An unrecognized attribute or attribute value was passed in the " + "attribute list."); + case EGL_BAD_CONTEXT: + return InternalError( + "An EGLContext argument does not name a valid EGL rendering " + "context."); + case EGL_BAD_CONFIG: + return InternalError( + "An EGLConfig argument does not name a valid EGL frame buffer " + "configuration."); + case EGL_BAD_CURRENT_SURFACE: + return InternalError( + "The current surface of the calling thread is a window, pixel buffer " + "or pixmap that is no longer valid."); + case EGL_BAD_DISPLAY: + return InternalError( + "An EGLDisplay argument does not name a valid EGL display " + "connection."); + case EGL_BAD_SURFACE: + return InternalError( + "An EGLSurface argument does not name a valid surface (window, pixel " + "buffer or pixmap) configured for GL rendering."); + case EGL_BAD_MATCH: + return InternalError( + "Arguments are inconsistent (for example, a valid context requires " + "buffers not supplied by a valid surface)."); + case EGL_BAD_PARAMETER: + return InternalError("One or more argument values are invalid."); + case EGL_BAD_NATIVE_PIXMAP: + return InternalError( + "A NativePixmapType argument does not refer to a valid native " + "pixmap."); + case EGL_BAD_NATIVE_WINDOW: + return InternalError( + "A NativeWindowType argument does not refer to a valid native " + "window."); + case EGL_CONTEXT_LOST: + return InternalError( + "A power management event has occurred. The application must destroy " + "all contexts and reinitialise OpenGL ES state and objects to " + "continue rendering."); + } + return UnknownError("EGL error: " + std::to_string(error)); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/gl_errors.h b/tensorflow/lite/delegates/gpu/gl/gl_errors.h new file mode 100644 index 00000000000..978e642abaa --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_errors.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_ERRORS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_ERRORS_H_ + +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// @return recent opengl errors and packs them into Status. +Status GetOpenGlErrors(); + +// @return the error of the last called EGL function in the current thread. +Status GetEglError(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_ERRORS_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/gl_program.cc b/tensorflow/lite/delegates/gpu/gl/gl_program.cc new file mode 100644 index 00000000000..9b0cf3c07db --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_program.cc @@ -0,0 +1,201 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" + +#include + +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +Status CreateNewProgramId(GLuint* program_id) { + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glCreateProgram, program_id)); + if (!*program_id) { + return UnknownError("Can't create opengl program: 0 program_id"); + } + return OkStatus(); +} + +Status CheckProgramLinked(GLuint program_id) { + GLint linked; + glGetProgramiv(program_id, GL_LINK_STATUS, &linked); + if (linked == GL_TRUE) { + return OkStatus(); + } + GLint info_size; + glGetProgramiv(program_id, GL_INFO_LOG_LENGTH, &info_size); + std::string errors; + errors.resize(info_size + 1 /* plus \0 */); + glGetProgramInfoLog(program_id, info_size + 1, nullptr, &errors[0]); + // TODO(akulik): use glValidateProgram to gather more info. 
+ return UnavailableError("Program is not properly linked: " + errors); +} + +struct ParameterSetter { + Status operator()(int value) { + return TFLITE_GPU_CALL_GL(glProgramUniform1i, program_id, uniform_id, + value); + } + Status operator()(const int2& value) { + return TFLITE_GPU_CALL_GL(glProgramUniform2i, program_id, uniform_id, + value.x, value.y); + } + Status operator()(const int4& value) { + return TFLITE_GPU_CALL_GL(glProgramUniform4i, program_id, uniform_id, + value.x, value.y, value.z, value.w); + } + Status operator()(const std::vector& value) { + std::vector ints(value.size() * 2, 0); + for (int i = 0; i < value.size(); ++i) { + ints[i * 2] = value[i].x; + ints[i * 2 + 1] = value[i].y; + } + return TFLITE_GPU_CALL_GL(glProgramUniform2iv, program_id, uniform_id, + ints.size(), ints.data()); + } + Status operator()(unsigned int value) { + return TFLITE_GPU_CALL_GL(glProgramUniform1ui, program_id, uniform_id, + value); + } + Status operator()(const uint4& value) { + return TFLITE_GPU_CALL_GL(glProgramUniform4ui, program_id, uniform_id, + value.x, value.y, value.z, value.w); + } + Status operator()(float value) { + return TFLITE_GPU_CALL_GL(glProgramUniform1f, program_id, uniform_id, + value); + } + Status operator()(const float2& value) { + return TFLITE_GPU_CALL_GL(glProgramUniform2f, program_id, uniform_id, + value.x, value.y); + } + Status operator()(const float4& value) { + return TFLITE_GPU_CALL_GL(glProgramUniform4f, program_id, uniform_id, + value.x, value.y, value.z, value.w); + } + + const GLuint program_id; + const GLint uniform_id; +}; + +} // namespace + +Status GlProgram::CreateWithShader(const GlShader& shader, + GlProgram* gl_program) { + GLuint program_id; + RETURN_IF_ERROR(CreateNewProgramId(&program_id)); + + // program_id needs to be properly deleted if there will be an error, hense + // wrap program_id into Program. + GlProgram program(program_id); + + RETURN_IF_ERROR( + TFLITE_GPU_CALL_GL(glAttachShader, program.id(), shader.id())); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glLinkProgram, program.id())); + RETURN_IF_ERROR(CheckProgramLinked(program.id())); + + *gl_program = std::move(program); + return OkStatus(); +} + +Status GlProgram::CreateWithBinaryShader(const BinaryShader& shader, + GlProgram* gl_program) { + GLuint program_id; + RETURN_IF_ERROR(CreateNewProgramId(&program_id)); + + // program_id needs to be properly deleted if there will be an error, hense + // wrap program_id into Program. + GlProgram program(program_id); + + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glProgramBinary, program.id(), + shader.format(), shader.binary().data(), + shader.binary().size())); + RETURN_IF_ERROR(CheckProgramLinked(program.id())); + + *gl_program = std::move(program); + return OkStatus(); +} + +Status GlProgram::GetBinary(BinaryShader* binary_shader) { + GLint size = 0; + RETURN_IF_ERROR( + TFLITE_GPU_CALL_GL(glGetProgramiv, id_, GL_PROGRAM_BINARY_LENGTH, &size)); + if (!size) { + return InternalError("Getting binary size failed."); + } + // TODO(akulik): call + // glProgramParameteri(id_, GL_PROGRAM_BINARY_RETRIEVABLE_HINT, GL_TRUE) + // before linking a program to increase chances of retrieving a binary. 
+ std::vector binary(size); + GLsizei returned_size; + GLenum format; + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetProgramBinary, id_, size, + &returned_size, &format, + reinterpret_cast(&binary[0]))); + if (size != returned_size) { + return InternalError("Getting binary is failed."); + } + *binary_shader = BinaryShader(format, std::move(binary)); + return OkStatus(); +} + +GlProgram::GlProgram(GlProgram&& program) : id_(program.id_) { + program.id_ = 0; +} + +void GlProgram::Invalidate() { + if (id_) { + glDeleteProgram(id_); + id_ = 0; + } +} + +GlProgram& GlProgram::operator=(GlProgram&& program) { + if (this != &program) { + Invalidate(); + std::swap(id_, program.id_); + } + return *this; +} + +GlProgram::~GlProgram() { Invalidate(); } + +Status GlProgram::SetParameter(const UniformParameter& param) { + GLint uniform_location; + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetUniformLocation, &uniform_location, + id_, param.name.c_str())); + return absl::visit(ParameterSetter{id_, uniform_location}, param.value); +} + +Status GlProgram::Dispatch(const uint3& workgroups) const { + if (workgroups.x == 0 || workgroups.y == 0 || workgroups.z == 0) { + return InvalidArgumentError("Invalid workgroups"); + } + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glUseProgram, id_)); + return TFLITE_GPU_CALL_GL(glDispatchCompute, workgroups.x, workgroups.y, + workgroups.z); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/gl_program.h b/tensorflow/lite/delegates/gpu/gl/gl_program.h new file mode 100644 index 00000000000..ff176344d19 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_program.h @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_PROGRAM_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_PROGRAM_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// A wrapper around opengl program id that needs to be recycled when not needed. +// Encapsulates logic needed to bind parameters, link a program and execute it. +class GlProgram { + public: + // Creates invalid program. + GlProgram() : id_(0) {} + + // Creates new program, initializes it, attaches the given shader and links + // a program. Thus, if this call returns a program, one may set parameters and + // finally execute a program. + // therefore it needs to be handled elsewhere. + static Status CreateWithShader(const GlShader& shader, GlProgram* gl_program); + + // Same as CreateWithShader but takes compiled shader in a binary form, + // therefore compilation step is avoided. 
+ static Status CreateWithBinaryShader(const BinaryShader& shader, + GlProgram* gl_program); + + // move-only + GlProgram(GlProgram&& program); + GlProgram& operator=(GlProgram&& program); + GlProgram(const GlProgram&) = delete; + GlProgram& operator=(const GlProgram&) = delete; + + ~GlProgram(); + + GLuint id() const { return id_; } + + // Returns a binary representation for a shader currently attached and linked + // into this program. + Status GetBinary(BinaryShader* binary_shader); + + Status SetParameter(const UniformParameter& param); + + // Executes program + Status Dispatch(const uint3& workgroups) const; + + bool is_valid() const { return id_ != 0; } + + private: + explicit GlProgram(GLuint program_id) : id_(program_id) {} + + void Invalidate(); + + GLint GetUniformId(const std::string& name); + + GLuint id_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_PROGRAM_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/gl_shader.cc b/tensorflow/lite/delegates/gpu/gl/gl_shader.cc new file mode 100644 index 00000000000..32391749985 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_shader.cc @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" + +namespace tflite { +namespace gpu { +namespace gl { + +GlShader::GlShader(GlShader&& shader) : id_(shader.id_) { shader.id_ = 0; } + +void GlShader::Invalidate() { + if (id_) { + glDeleteShader(id_); + id_ = 0; + } +} + +GlShader& GlShader::operator=(GlShader&& shader) { + if (this != &shader) { + Invalidate(); + std::swap(id_, shader.id_); + } + return *this; +} + +GlShader::~GlShader() { Invalidate(); } + +Status GlShader::CompileShader(GLenum shader_type, + const std::string& shader_source, + GlShader* gl_shader) { + // NOTE: code compilation can fail due to gl errors happened before + GLuint shader_id; + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glCreateShader, &shader_id, shader_type)); + GlShader shader(shader_id); + + const char* src = shader_source.c_str(); + RETURN_IF_ERROR( + TFLITE_GPU_CALL_GL(glShaderSource, shader.id(), 1, &src, nullptr)); + + glCompileShader(shader.id()); + // Didn't check for opengl errors here because we want to get better logs + // if it didn't compile. 
+ GLint compiled = GL_FALSE; + glGetShaderiv(shader.id(), GL_COMPILE_STATUS, &compiled); + if (!compiled) { + GLint info_log_len = 0; + glGetShaderiv(shader.id(), GL_INFO_LOG_LENGTH, &info_log_len); + std::string errors(info_log_len, 0); + glGetShaderInfoLog(shader.id(), info_log_len, nullptr, &errors[0]); + return InternalError("Shader compilation failed: " + errors + + "\nProblem shader is:\n" + shader_source); + } + + *gl_shader = std::move(shader); + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/gl_shader.h b/tensorflow/lite/delegates/gpu/gl/gl_shader.h new file mode 100644 index 00000000000..d0ec421bb16 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_shader.h @@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_SHADER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_SHADER_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// A wrapper around opengl shader id that needs to be recycled when not needed. +class GlShader { + public: + // Creates and compiles a shader. + // + // @param shader_type is one of GL_VERTEX_SHADER, GL_FRAGMENT_SHADER, or + // GL_COMPUTE_SHADER. + static Status CompileShader(GLenum shader_type, + const std::string& shader_source, + GlShader* gl_shader); + + GlShader() : id_(0) {} + + // move-only + GlShader(GlShader&& shader); + GlShader& operator=(GlShader&& shader); + GlShader(const GlShader&) = delete; + GlShader& operator=(const GlShader&) = delete; + + ~GlShader(); + + GLuint id() const { return id_; } + + private: + explicit GlShader(GLuint id) : id_(id) {} + + void Invalidate(); + + GLuint id_; +}; + +// Holds binary blob for compiled shader. It can be used to instantiate +// a program instead of plain Shader that will need to be compiled first. +// +// Some OpenGL implementations allow to extract binary representation once it +// is compiled. Call Program::GetBinary after program is successfully created +// with a shader from sources. +class BinaryShader { + public: + BinaryShader(GLenum format, std::vector binary) + : format_(format), binary_(std::move(binary)) {} + + GLenum format() const { return format_; } + + const std::vector& binary() const { return binary_; } + + private: + GLenum format_; + std::vector binary_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_SHADER_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/gl_sync.cc b/tensorflow/lite/delegates/gpu/gl/gl_sync.cc new file mode 100644 index 00000000000..889e8dda428 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_sync.cc @@ -0,0 +1,83 @@ +/* Copyright 2019 The TensorFlow Authors. 
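// End-to-end sketch of the shader and program classes added above: compile a
// compute shader, link it into a program and dispatch it. The GLSL source and
// workgroup count are illustrative; uint3 is assumed to be constructible from
// three values as in common/types.h.
#include <string>

#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"

tflite::gpu::Status RunComputeShader(const std::string& source) {
  tflite::gpu::gl::GlShader shader;
  RETURN_IF_ERROR(tflite::gpu::gl::GlShader::CompileShader(GL_COMPUTE_SHADER,
                                                           source, &shader));
  tflite::gpu::gl::GlProgram program;
  RETURN_IF_ERROR(
      tflite::gpu::gl::GlProgram::CreateWithShader(shader, &program));
  // Uniforms would be bound here via program.SetParameter(...).
  return program.Dispatch(tflite::gpu::uint3(8, 8, 1));
}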
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/gl_sync.h" + +#ifdef __ARM_ACLE +#include +#endif // __ARM_ACLE + +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" + +namespace tflite { +namespace gpu { +namespace gl { + +Status GlSyncWait() { + GlSync sync; + RETURN_IF_ERROR(GlSync::NewSync(&sync)); + // Flush sync and loop afterwards without it. + GLenum status = glClientWaitSync(sync.sync(), GL_SYNC_FLUSH_COMMANDS_BIT, + /* timeout ns = */ 0); + while (true) { + switch (status) { + case GL_TIMEOUT_EXPIRED: + break; + case GL_CONDITION_SATISFIED: + case GL_ALREADY_SIGNALED: + return OkStatus(); + case GL_WAIT_FAILED: + return GetOpenGlErrors(); + } + status = glClientWaitSync(sync.sync(), 0, /* timeout ns = */ 10000000); + } + return OkStatus(); +} + +Status GlActiveSyncWait() { + GlSync sync; + RETURN_IF_ERROR(GlSync::NewSync(&sync)); + // Since creating a Sync object is itself a GL command it *must* be flushed. + // Otherwise glGetSynciv may never succeed. Perform a flush with + // glClientWaitSync call. + GLenum status = glClientWaitSync(sync.sync(), GL_SYNC_FLUSH_COMMANDS_BIT, + /* timeout ns = */ 0); + switch (status) { + case GL_TIMEOUT_EXPIRED: + break; + case GL_CONDITION_SATISFIED: + case GL_ALREADY_SIGNALED: + return OkStatus(); + case GL_WAIT_FAILED: + return GetOpenGlErrors(); + } + + // Start active loop. + GLint result = GL_UNSIGNALED; + while (true) { + glGetSynciv(sync.sync(), GL_SYNC_STATUS, sizeof(GLint), nullptr, &result); + if (result == GL_SIGNALED) { + return OkStatus(); + } +#ifdef __ARM_ACLE + // Try to save CPU power by yielding CPU to another thread. + __yield(); +#endif + } +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/gl_sync.h b/tensorflow/lite/delegates/gpu/gl/gl_sync.h new file mode 100644 index 00000000000..a00a0c2b048 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_sync.h @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
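// Typical pairing of a dispatch with the sync helpers implemented above:
// submit the work, then block until the GPU has finished. GlActiveSyncWait is
// the spinning alternative declared in gl_sync.h below.
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_sync.h"

tflite::gpu::Status DispatchAndWait(const tflite::gpu::gl::GlProgram& program,
                                    const tflite::gpu::uint3& workgroups) {
  RETURN_IF_ERROR(program.Dispatch(workgroups));
  return tflite::gpu::gl::GlSyncWait();
}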
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_SYNC_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_SYNC_H_ + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// RAII wrapper for OpenGL GLsync object. +// See https://www.khronos.org/opengl/wiki/Sync_Object for more information. +// +// GlSync is moveable but not copyable. +class GlSync { + public: + static Status NewSync(GlSync* gl_sync) { + GLsync sync; + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glFenceSync, &sync, + GL_SYNC_GPU_COMMANDS_COMPLETE, 0)); + *gl_sync = GlSync(sync); + return OkStatus(); + } + + // Creates invalid object. + GlSync() : GlSync(nullptr) {} + + // Move-only + GlSync(GlSync&& sync) : sync_(sync.sync_) { sync.sync_ = nullptr; } + + GlSync& operator=(GlSync&& sync) { + if (this != &sync) { + Invalidate(); + std::swap(sync_, sync.sync_); + } + return *this; + } + + GlSync(const GlSync&) = delete; + GlSync& operator=(const GlSync&) = delete; + + ~GlSync() { Invalidate(); } + + const GLsync sync() const { return sync_; } + + private: + explicit GlSync(GLsync sync) : sync_(sync) {} + + void Invalidate() { + if (sync_) { + glDeleteSync(sync_); + sync_ = nullptr; + } + } + + GLsync sync_; +}; + +// Waits until GPU is done with processing. +Status GlSyncWait(); + +// Performs active waiting by spinning a thread and checking sync status. It +// leads to shorter wait time (up to tens of ms) but consumes more CPU. +Status GlActiveSyncWait(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_SYNC_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/gl_texture.cc b/tensorflow/lite/delegates/gpu/gl/gl_texture.cc new file mode 100644 index 00000000000..eb20deca758 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_texture.cc @@ -0,0 +1,313 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" + +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" + +namespace tflite { +namespace gpu { +namespace gl { + +GLenum ToTextureFormat(DataType type) { + switch (type) { + case DataType::INT8: + case DataType::UINT16: + case DataType::UINT32: + case DataType::INT16: + case DataType::INT32: + return GL_RGBA_INTEGER; + case DataType::FLOAT16: + case DataType::FLOAT32: + case DataType::UINT8: // this requires GL_RGBA8 internal format + return GL_RGBA; + default: + return 0; + } +} + +GLenum ToTextureInternalFormat(DataType type) { + switch (type) { + case DataType::UINT8: + return GL_RGBA8; // this requires GL_RGBA format + case DataType::INT8: + return GL_RGBA8I; + case DataType::UINT16: + return GL_RGBA16UI; + case DataType::UINT32: + return GL_RGBA32UI; + case DataType::INT16: + return GL_RGBA16I; + case DataType::INT32: + return GL_RGBA32I; + case DataType::FLOAT16: + return GL_RGBA16F; + case DataType::FLOAT32: + return GL_RGBA32F; + default: + return 0; + } +} + +GLenum ToTextureDataType(DataType type) { + switch (type) { + case DataType::UINT8: + return GL_UNSIGNED_BYTE; + case DataType::INT8: + return GL_BYTE; + case DataType::UINT16: + return GL_UNSIGNED_SHORT; + case DataType::UINT32: + return GL_UNSIGNED_INT; + case DataType::INT16: + return GL_SHORT; + case DataType::INT32: + return GL_INT; + case DataType::FLOAT16: + return GL_HALF_FLOAT; + case DataType::FLOAT32: + return GL_FLOAT; + default: + return 0; + } +} + +GlTexture::GlTexture(GlTexture&& texture) + : GlTexture(texture.target_, texture.id_, texture.format_, + texture.bytes_size_, texture.layer_, texture.owned_) { + texture.owned_ = false; +} + +GlTexture& GlTexture::operator=(GlTexture&& texture) { + if (this != &texture) { + Invalidate(); + + target_ = texture.target_; + format_ = texture.format_; + bytes_size_ = texture.bytes_size_; + layer_ = texture.layer_; + owned_ = texture.owned_; + id_ = texture.id_; + texture.owned_ = false; + } + return *this; +} + +GlTexture::~GlTexture() { + Invalidate(); +} + +void GlTexture::Invalidate() { + if (owned_ && id_ != GL_INVALID_INDEX) { + TFLITE_GPU_CALL_GL(glDeleteTextures, 1, &id_).IgnoreError(); + id_ = GL_INVALID_INDEX; + } +} + +Status GlTexture::BindImage(uint32_t index, GLenum access) const { + return TFLITE_GPU_CALL_GL(glBindImageTexture, index, id_, /* level = */ 0, + /* layered = */ GL_TRUE, layer_, access, format_); +} + +Status GlTexture::BindAsReadonlyImage(uint32_t index) const { + return BindImage(index, GL_READ_ONLY); +} + +Status GlTexture::BindAsWriteonlyImage(uint32_t index) const { + return BindImage(index, GL_WRITE_ONLY); +} + +Status GlTexture::BindAsReadWriteImage(uint32_t index) const { + return BindImage(index, GL_READ_WRITE); +} + +Status GlTexture::BindAsSampler2D(uint32_t index) const { + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glActiveTexture, GL_TEXTURE0 + index)); + return TFLITE_GPU_CALL_GL(glBindTexture, GL_TEXTURE_2D, id_); +} + +namespace { + +Status SetTextureWrapAndFilter(GLenum target, GLenum texture_format) { + if (texture_format == GL_RGBA32F) { + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_WRAP_S, GL_REPEAT)); + 
RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_WRAP_T, GL_REPEAT)); + if (target == GL_TEXTURE_2D_ARRAY || target == GL_TEXTURE_3D) { + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_WRAP_R, GL_REPEAT)); + } + // Texture filtering is not available for GL_RGBA32F, hence explicitly + // specifying GL_NEAREST param for texture (Otherwise, we can end up + // sampling some incorrect values from texture.) + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_MAG_FILTER, GL_NEAREST)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_MIN_FILTER, GL_NEAREST)); + } else if (texture_format == GL_RGBA16F) { + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_WRAP_S, GL_REPEAT)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_WRAP_T, GL_REPEAT)); + if (target == GL_TEXTURE_2D_ARRAY || target == GL_TEXTURE_3D) { + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_WRAP_R, GL_REPEAT)); + } + // Texture filtering is available for GL_RGBA16F, specifying that + // explicitly improves quality for some operations like texture upscaling + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_MAG_FILTER, GL_LINEAR)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexParameteri, target, + GL_TEXTURE_MIN_FILTER, GL_LINEAR)); + } + return OkStatus(); +} + +Status CreateReadOnlyRgba2dImageTexture(DataType data_type, const uint2& size, + const void* data, size_t byte_size, + GlTexture* gl_texture) { + if (byte_size != /* RGBA=*/4 * SizeOf(data_type) * size.x * size.y) { + return InvalidArgumentError( + "Creating image texture failed. Source data size is not matching " + "expected dimensions."); + } + const GLenum kTarget = GL_TEXTURE_2D; + GLenum internal_format = ToTextureInternalFormat(data_type); + GLenum format = ToTextureFormat(data_type); + GLenum type = ToTextureDataType(data_type); + gl_texture_internal::TextureId id; + gl_texture_internal::TextureBinder binder(kTarget, id.id()); + RETURN_IF_ERROR(SetTextureWrapAndFilter(kTarget, internal_format)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexStorage2D, kTarget, + /* num_levels = */ 1, internal_format, + size.x, size.y)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexSubImage2D, kTarget, /* level = */ 0, + 0, 0, size.x, size.y, format, type, data)); + *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, 0, + /*owned=*/true); + return OkStatus(); +} + +Status CreateReadOnlyRgba3dImageTexture(DataType data_type, const uint3& size, + const void* data, size_t byte_size, + GlTexture* gl_texture) { + if (byte_size != /* RGBA=*/4 * SizeOf(data_type) * size.x * size.y * size.z) { + return InvalidArgumentError( + "Creating image texture failed. 
Source data is larger than dimensions " + "product."); + } + const GLenum kTarget = GL_TEXTURE_2D_ARRAY; + GLenum internal_format = ToTextureInternalFormat(data_type); + GLenum format = ToTextureFormat(data_type); + GLenum type = ToTextureDataType(data_type); + gl_texture_internal::TextureId id; + gl_texture_internal::TextureBinder binder(kTarget, id.id()); + RETURN_IF_ERROR(SetTextureWrapAndFilter(kTarget, internal_format)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexStorage3D, kTarget, + /* num_levels = */ 1, internal_format, + size.x, size.y, size.z)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexSubImage3D, kTarget, /* level = */ 0, + 0, 0, 0, size.x, size.y, size.z, format, + type, data)); + *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, 0, + /*owned=*/true); + return OkStatus(); +} + +} // namespace + +Status CreateReadOnlyImageTexture(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { + return CreateReadOnlyRgba2dImageTexture(DataType::FLOAT32, size, data.data(), + data.size() * sizeof(float), + gl_texture); +} + +Status CreateReadOnlyImageTexture(const uint3& size, + absl::Span data, + GlTexture* gl_texture) { + return CreateReadOnlyRgba3dImageTexture(DataType::FLOAT32, size, data.data(), + data.size() * sizeof(float), + gl_texture); +} + +Status CreateReadOnlyImageTextureU8(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { + return CreateReadOnlyRgba2dImageTexture(DataType::UINT8, size, data.data(), + data.size() * sizeof(uint8_t), + gl_texture); +} + +Status CreateReadOnlyImageTextureF16(const uint2& size, + absl::Span data, + GlTexture* gl_texture) { + return CreateReadOnlyRgba2dImageTexture(DataType::FLOAT16, size, data.data(), + data.size() * sizeof(uint16_t), + gl_texture); +} + +Status CreateReadOnlyImageTextureF16(const uint3& size, + absl::Span data, + GlTexture* gl_texture) { + return CreateReadOnlyRgba3dImageTexture(DataType::FLOAT16, size, data.data(), + data.size() * sizeof(uint16_t), + gl_texture); +} + +Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint2& size, + GlTexture* gl_texture) { + const GLenum kTarget = GL_TEXTURE_2D; + const GLenum internal_format = ToTextureInternalFormat(data_type); + gl_texture_internal::TextureId id; + gl_texture_internal::TextureBinder binder(kTarget, id.id()); + RETURN_IF_ERROR(SetTextureWrapAndFilter(kTarget, internal_format)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexStorage2D, kTarget, + /* num_levels = */ 1, internal_format, + size.x, size.y)); + size_t byte_size = /* RGBA = */ 4 * SizeOf(data_type) * size.x * size.y; + *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, + /* layer = */ 0, + /* owned = */ true); + return OkStatus(); +} + +Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint3& size, + GlTexture* gl_texture) { + const GLenum kTarget = GL_TEXTURE_2D_ARRAY; + GLenum internal_format = ToTextureInternalFormat(data_type); + gl_texture_internal::TextureId id; + gl_texture_internal::TextureBinder binder(kTarget, id.id()); + RETURN_IF_ERROR(SetTextureWrapAndFilter(kTarget, internal_format)); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glTexStorage3D, kTarget, + /* num_levels = */ 1, internal_format, + size.x, size.y, size.z)); + size_t byte_size = + /* RGBA = */ 4 * SizeOf(data_type) * size.x * size.y * size.z; + *gl_texture = GlTexture(kTarget, id.Release(), internal_format, byte_size, + /* layer = */ 0, + /* owned = */ true); + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff 
--git a/tensorflow/lite/delegates/gpu/gl/gl_texture.h b/tensorflow/lite/delegates/gpu/gl/gl_texture.h new file mode 100644 index 00000000000..951b22f23f1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gl_texture.h @@ -0,0 +1,208 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_TEXTURE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_TEXTURE_H_ + +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// GlTexture is an RAII wrapper around an OpenGL texture object. +// See https://www.khronos.org/opengl/wiki/Texture for more information. +// +// GlTexture is movable but not copyable. +class GlTexture { + public: + // Creates an invalid texture. + GlTexture() + : GlTexture(GL_INVALID_ENUM, GL_INVALID_INDEX, GL_INVALID_ENUM, 0, 0, + false) {} + + GlTexture(GLenum target, GLuint id, GLenum format, size_t bytes_size, + GLint layer, bool owned) + : id_(id), + target_(target), + format_(format), + bytes_size_(bytes_size), + layer_(layer), + owned_(owned) {} + + // Move-only + GlTexture(GlTexture&& texture); + GlTexture& operator=(GlTexture&& texture); + GlTexture(const GlTexture&) = delete; + GlTexture& operator=(const GlTexture&) = delete; + + ~GlTexture(); + + // Binds the texture as an image for read-only access at the given index. + Status BindAsReadonlyImage(uint32_t index) const; + + // Binds the texture as an image for write access at the given index. + Status BindAsWriteonlyImage(uint32_t index) const; + + // Binds the texture as an image for read-write access at the given index. + Status BindAsReadWriteImage(uint32_t index) const; + + // Binds the texture as a sampler at the given index. + Status BindAsSampler2D(uint32_t index) const; + + GLenum target() const { return target_; } + + GLuint id() const { return id_; } + + GLenum format() const { return format_; } + + GLint layer() const { return layer_; } + + bool is_valid() const { return id_ != GL_INVALID_INDEX; } + + size_t bytes_size() const { return bytes_size_; } + + // @return true if this object actually owns the corresponding GL texture + // and manages its lifetime. + bool has_ownership() const { return owned_; } + + private: + void Invalidate(); + + Status BindImage(uint32_t index, GLenum access) const; + + GLuint id_; + GLenum target_; + GLenum format_; + size_t bytes_size_; + GLint layer_; + bool owned_; +}; + +// Creates a new 2D image texture that is filled once with float32 data and +// then used for reading. +// +// @param size defines 2D image texture size where each pixel is RGBA.
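+// Example usage (illustrative sketch, not from the original header): for a
+// uint2(8, 8) texture the span must hold 8 * 8 * 4 floats, e.g.
+//
+//   std::vector<float> pixels(8 * 8 * 4, 0.0f);
+//   GlTexture texture;
+//   RETURN_IF_ERROR(CreateReadOnlyImageTexture(uint2(8, 8),
+//                                              absl::MakeConstSpan(pixels),
+//                                              &texture));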
+Status CreateReadOnlyImageTexture(const uint2& size, + absl::Span data, + GlTexture* gl_texture); + +// Creates new 2D image texture that will be filled with float16 data once which +// will be used for reading. +// +// @param size defines 2D image texture size where each pixel is RGBA. +Status CreateReadOnlyImageTextureF16(const uint2& size, + absl::Span data, + GlTexture* gl_texture); + +// Creates new 2D image texture that will be filled with uint8 data once which +// will be used for reading. +// +// @param size defines 2D image texture size where each pixel is RGBA. +Status CreateReadOnlyImageTextureU8(const uint2& size, + absl::Span data, + GlTexture* gl_texture); + +// Creates new 3D RGBA image texture that will be filled with float32 data once +// which will be used for reading. +// +// @param size defines 3D image texture size where each pixel is RGBA. +Status CreateReadOnlyImageTexture(const uint3& size, + absl::Span data, + GlTexture* gl_texture); + +// Creates new 3D RGBA image texture that will be filled with float16 data once +// which will be used for reading. +// +// @param size defines 3D image texture size where each pixel is RGBA. +Status CreateReadOnlyImageTextureF16(const uint3& size, + absl::Span data, + GlTexture* gl_texture); + +// Creates new RGBA 2D image texture +// +// @param size defines 2D image texture size where each pixel is RGBA. +Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint2& size, + GlTexture* gl_texture); + +// Creates new RGBA 3D image texture +// +// @param size defines 3D image texture size where each pixel is RGBA. +Status CreateReadWriteRgbaImageTexture(DataType data_type, const uint3& size, + GlTexture* gl_texture); + +GLenum ToTextureFormat(DataType type); + +GLenum ToTextureInternalFormat(DataType type); + +GLenum ToTextureDataType(DataType type); + +namespace gl_texture_internal { + +// RAII for creating and/or owning texture id. +class TextureId { + public: + TextureId() : id_(GL_INVALID_INDEX) { + TFLITE_GPU_CALL_GL(glGenTextures, 1 /* number of textures*/, &id_) + .IgnoreError(); + } + + explicit TextureId(GLuint id) : id_(id) {} + + ~TextureId() { + if (id_ != GL_INVALID_INDEX) { + TFLITE_GPU_CALL_GL(glDeleteTextures, 1, &id_).IgnoreError(); + } + } + + GLuint id() const { return id_; } + + GLuint Release() { + GLuint id = GL_INVALID_INDEX; + std::swap(id, id_); + return id; + } + + private: + GLuint id_; +}; + +// RAII for binding and unbinding a texture. +class TextureBinder { + public: + TextureBinder(GLenum target, GLuint id) : target_(target) { + TFLITE_GPU_CALL_GL(glBindTexture, target_, id).IgnoreError(); + } + + ~TextureBinder() { + TFLITE_GPU_CALL_GL(glBindTexture, target_, 0).IgnoreError(); + } + + private: + const GLenum target_; +}; + +} // namespace gl_texture_internal +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_GL_TEXTURE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/gpu_info.cc b/tensorflow/lite/delegates/gpu/gl/gpu_info.cc new file mode 100644 index 00000000000..d40910c3357 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gpu_info.cc @@ -0,0 +1,155 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" + +#include +#include +#include + +#include "absl/strings/ascii.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +GpuType GetGpuType(const std::string& renderer) { + if (renderer.find("mali") != renderer.npos) { + return GpuType::MALI; + } + if (renderer.find("adreno") != renderer.npos) { + return GpuType::ADRENO; + } + if (renderer.find("powervr") != renderer.npos) { + return GpuType::POWERVR; + } + if (renderer.find("intel") != renderer.npos) { + return GpuType::INTEL; + } + if (renderer.find("nvidia") != renderer.npos) { + return GpuType::NVIDIA; + } + return GpuType::UNKNOWN; +} + +GpuModel GetGpuModel(const std::string& renderer) { + auto found_model = [&](std::string model) -> bool { + return renderer.find(model) != renderer.npos; + }; + // Adreno 6xx series + if (found_model("640")) return GpuModel::ADRENO640; + if (found_model("630")) return GpuModel::ADRENO630; + if (found_model("616")) return GpuModel::ADRENO616; + if (found_model("615")) return GpuModel::ADRENO615; + if (found_model("612")) return GpuModel::ADRENO612; + if (found_model("605")) return GpuModel::ADRENO605; + // Adreno 5xx series + if (found_model("540")) return GpuModel::ADRENO540; + if (found_model("530")) return GpuModel::ADRENO530; + if (found_model("512")) return GpuModel::ADRENO512; + if (found_model("510")) return GpuModel::ADRENO510; + if (found_model("509")) return GpuModel::ADRENO509; + if (found_model("508")) return GpuModel::ADRENO508; + if (found_model("506")) return GpuModel::ADRENO506; + if (found_model("505")) return GpuModel::ADRENO505; + if (found_model("504")) return GpuModel::ADRENO504; + // Adreno 4xx series + if (found_model("430")) return GpuModel::ADRENO430; + if (found_model("420")) return GpuModel::ADRENO420; + if (found_model("418")) return GpuModel::ADRENO418; + if (found_model("405")) return GpuModel::ADRENO405; + // Adreno 3xx series + if (found_model("330")) return GpuModel::ADRENO330; + if (found_model("320")) return GpuModel::ADRENO320; + if (found_model("308")) return GpuModel::ADRENO308; + if (found_model("306")) return GpuModel::ADRENO306; + if (found_model("305")) return GpuModel::ADRENO305; + if (found_model("304")) return GpuModel::ADRENO304; + // Adreno 2xx series + if (found_model("225")) return GpuModel::ADRENO225; + if (found_model("220")) return GpuModel::ADRENO220; + if (found_model("205")) return GpuModel::ADRENO205; + if (found_model("203")) return GpuModel::ADRENO203; + if (found_model("200")) return GpuModel::ADRENO200; + // Adreno 1xx series + if (found_model("130")) return GpuModel::ADRENO130; + return GpuModel::UNKNOWN; +} + +} // namespace + +void GetGpuModelAndType(const std::string& renderer, GpuModel* gpu_model, + GpuType* gpu_type) { + std::string lowered = renderer; + absl::AsciiStrToLower(&lowered); + *gpu_type = GetGpuType(lowered); + *gpu_model = + *gpu_type == GpuType::ADRENO ? 
GetGpuModel(lowered) : GpuModel::UNKNOWN; +} + +Status RequestGpuInfo(GpuInfo* gpu_info) { + GpuInfo info; + + const GLubyte* renderer_name = glGetString(GL_RENDERER); + if (renderer_name) { + info.renderer_name = reinterpret_cast(renderer_name); + GetGpuModelAndType(info.renderer_name, &info.gpu_model, &info.type); + } + + const GLubyte* vendor_name = glGetString(GL_VENDOR); + if (vendor_name) { + info.vendor_name = reinterpret_cast(vendor_name); + } + + const GLubyte* version_name = glGetString(GL_VERSION); + if (version_name) { + info.version = reinterpret_cast(version_name); + } + + glGetIntegerv(GL_MAJOR_VERSION, &info.major_version); + glGetIntegerv(GL_MINOR_VERSION, &info.minor_version); + + GLint extensions_count; + glGetIntegerv(GL_NUM_EXTENSIONS, &extensions_count); + info.extensions.resize(extensions_count); + for (int i = 0; i < extensions_count; ++i) { + info.extensions[i] = std::string( + reinterpret_cast(glGetStringi(GL_EXTENSIONS, i))); + } + glGetIntegerv(GL_MAX_COMPUTE_SHADER_STORAGE_BLOCKS, &info.max_ssbo_bindings); + glGetIntegerv(GL_MAX_COMPUTE_IMAGE_UNIFORMS, &info.max_image_bindings); + info.max_work_group_size.resize(3); + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 0, + &info.max_work_group_size[0]); + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, + &info.max_work_group_size[1]); + glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 2, + &info.max_work_group_size[2]); + glGetIntegerv(GL_MAX_COMPUTE_WORK_GROUP_INVOCATIONS, + &info.max_work_group_invocations); + glGetIntegerv(GL_MAX_TEXTURE_SIZE, &info.max_texture_size); + glGetIntegerv(GL_MAX_IMAGE_UNITS, &info.max_image_units); + glGetIntegerv(GL_MAX_ARRAY_TEXTURE_LAYERS, &info.max_array_texture_layers); + RETURN_IF_ERROR(GetOpenGlErrors()); + *gpu_info = info; + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/gpu_info.h b/tensorflow/lite/delegates/gpu/gl/gpu_info.h new file mode 100644 index 00000000000..ba7e0a5f3dc --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/gpu_info.h @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_GPU_INFO_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_GPU_INFO_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace gl { + +enum class GpuType { UNKNOWN, MALI, ADRENO, POWERVR, INTEL, NVIDIA }; +enum class GpuModel { + UNKNOWN, + // Adreno 6xx series + ADRENO640, + ADRENO630, + ADRENO616, + ADRENO615, + ADRENO612, + ADRENO605, + // Adreno 5xx series + ADRENO540, + ADRENO530, + ADRENO512, + ADRENO510, + ADRENO509, + ADRENO508, + ADRENO506, + ADRENO505, + ADRENO504, + // Adreno 4xx series + ADRENO430, + ADRENO420, + ADRENO418, + ADRENO405, + // Adreno 3xx series + ADRENO330, + ADRENO320, + ADRENO308, + ADRENO306, + ADRENO305, + ADRENO304, + // Adreno 2xx series + ADRENO225, + ADRENO220, + ADRENO205, + ADRENO203, + ADRENO200, + // Adreno 1xx series + ADRENO130, +}; + +struct GpuInfo { + GpuType type = GpuType::UNKNOWN; + std::string renderer_name; + std::string vendor_name; + std::string version; + GpuModel gpu_model; + int major_version = -1; + int minor_version = -1; + std::vector extensions; + int max_ssbo_bindings = 0; + int max_image_bindings = 0; + std::vector max_work_group_size; + int max_work_group_invocations; + int max_texture_size = 0; + int max_image_units = 0; + int max_array_texture_layers = 0; +}; + +// Analyzes `renderer` and returns matching `GpuType` and `GpuModel`. +void GetGpuModelAndType(const std::string& renderer, GpuModel* gpu_model, + GpuType* gpu_type); + +// This method performs multiple GL calls, therefore, egl context needs to be +// created upfront. +Status RequestGpuInfo(GpuInfo* gpu_info); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_GPU_INFO_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD new file mode 100644 index 00000000000..21ebf6cc744 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD @@ -0,0 +1,387 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") + +cc_library( + name = "add", + srcs = ["add.cc"], + hdrs = ["add.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "add_test", + srcs = ["add_test.cc"], + linkopts = [ + "-lEGL", + "-lGLESv3", + ], + tags = [ + "notap", + "tflite_not_portable_ios", + ], + deps = [ + ":add", + ":test_util", + "//tensorflow/lite/delegates/gpu/common:operations", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "concat", + srcs = ["concat.cc"], + hdrs = ["concat.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "conv", + srcs = ["conv.cc"], + hdrs = ["conv.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:convert", + 
"//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "//tensorflow/lite/delegates/gpu/gl/workgroups:ideal_workgroup_picker", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "depthwise_conv", + srcs = ["depthwise_conv.cc"], + hdrs = ["depthwise_conv.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "//tensorflow/lite/delegates/gpu/gl/workgroups:ideal_workgroup_picker", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "elementwise", + srcs = ["elementwise.cc"], + hdrs = ["elementwise.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_test( + name = "elementwise_test", + srcs = ["elementwise_test.cc"], + linkopts = [ + "-lEGL", + "-lGLESv3", + ], + tags = [ + "notap", + "tflite_not_portable_ios", + ], + deps = [ + ":elementwise", + ":test_util", + "//tensorflow/lite/delegates/gpu/common:operations", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "fully_connected", + srcs = ["fully_connected.cc"], + hdrs = ["fully_connected.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "lstm", + srcs = ["lstm.cc"], + hdrs = ["lstm.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "max_unpooling", + srcs = ["max_unpooling.cc"], + hdrs = ["max_unpooling.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "mul", + srcs = ["mul.cc"], + hdrs = ["mul.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "pad", + srcs = ["pad.cc"], + hdrs = ["pad.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + 
"//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "prelu", + srcs = ["prelu.cc"], + hdrs = ["prelu.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "relu", + srcs = ["relu.cc"], + hdrs = ["relu.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_test( + name = "relu_test", + srcs = ["relu_test.cc"], + linkopts = [ + "-lEGL", + "-lGLESv3", + ], + tags = [ + "notap", + "tflite_not_portable_ios", + ], + deps = [ + ":relu", + ":test_util", + "//tensorflow/lite/delegates/gpu/common:operations", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "reshape", + srcs = ["reshape.cc"], + hdrs = ["reshape.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "slice", + srcs = ["slice.cc"], + hdrs = ["slice.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "softmax", + srcs = ["softmax.cc"], + hdrs = ["softmax.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/gl:api", + "//tensorflow/lite/delegates/gpu/gl:compiler_options", + "//tensorflow/lite/delegates/gpu/gl:egl_environment", + "//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "//tensorflow/lite/delegates/gpu/gl:gpu_info", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "//tensorflow/lite/delegates/gpu/gl:object_manager", + "//tensorflow/lite/delegates/gpu/gl:runtime_options", + 
"//tensorflow/lite/delegates/gpu/gl/workgroups:default_calculator", + ], +) + +cc_library( + name = "transpose_conv", + srcs = ["transpose_conv.cc"], + hdrs = ["transpose_conv.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "upsampling_bilinear", + srcs = ["upsampling_bilinear.cc"], + hdrs = ["upsampling_bilinear.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + ], +) + +TFLITE_GPU_BINARY_RELEASE_OPERATORS = [ + "add", + "concat", + "conv", + "depthwise_conv", + "elementwise", + "fully_connected", + "lstm", + "mul", + "pad", + "pooling", + "prelu", + "relu", + "reshape", + "slice", + "softmax", + "transpose_conv", + "upsampling_bilinear", +] + +NON_TFLITE_GPU_BINARY_RELEASE_OPERATORS = [ + "max_unpooling", +] + +cc_library( + name = "registry", + srcs = ["registry.cc"], + hdrs = ["registry.h"], + visibility = ["//visibility:public"], + deps = [":" + op_name for op_name in TFLITE_GPU_BINARY_RELEASE_OPERATORS] + + select({ + "//tensorflow/lite/delegates/gpu:tflite_gpu_binary_release": [], + "//conditions:default": NON_TFLITE_GPU_BINARY_RELEASE_OPERATORS, + }) + [ + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/gl:node_shader", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/add.cc b/tensorflow/lite/delegates/gpu/gl/kernels/add.cc new file mode 100644 index 00000000000..e1073299ecd --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/add.cc @@ -0,0 +1,121 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/add.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class Add : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto attr = absl::any_cast(ctx.node->operation.attributes); + auto adds = absl::get_if>(&attr.param); + auto scalar = absl::get_if(&attr.param); + auto inputs = ctx.graph->FindInputs(ctx.node->id); + + if (!adds && !scalar) { + // check if it is a broadcast + if (inputs.size() == 2 && + inputs[0]->tensor.shape != inputs[1]->tensor.shape && + inputs[1]->tensor.shape.h == 1 && inputs[1]->tensor.shape.w == 1 && + inputs[0]->tensor.shape.c == inputs[1]->tensor.shape.c) { + *generated_code = { + /*parameters=*/{}, + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/ + "value_0 = $input_data_1[gid.z]$ + $input_data_0[gid.x, gid.y, " + "gid.z]$;", + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } + + std::string code = "value_0 = value_0"; + for (int index = 1; index < inputs.size(); ++index) { + if (inputs[index]->tensor.shape != inputs[0]->tensor.shape) { + return InvalidArgumentError("Shapes are not equal"); + } + absl::StrAppend(&code, " + value_", index); + } + absl::StrAppend(&code, ";"); + *generated_code = { + /*parameters=*/{}, + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(code), + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } + + if (scalar) { + *generated_code = { + /*parameters=*/{{"scalar", *scalar}}, + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/"value_0 += $scalar$;", + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + }; + } else { + auto shape = inputs[0]->tensor.shape; + *generated_code = { + /*parameters=*/{}, + /*objects=*/{{"add_buffer", MakeReadonlyObject(adds->data)}}, + // Declare workload explicitly because shader depends on gid.z. + /*workload=*/ + uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)), + /*workgroup=*/uint3(), + /*source_code=*/"value_0 += $add_buffer[gid.z]$;", + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + }; + } + + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewAddNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/add.h b/tensorflow/lite/delegates/gpu/gl/kernels/add.h new file mode 100644 index 00000000000..cfd6ce84533 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/add.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_ADD_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_ADD_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewAddNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_ADD_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/add_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/add_test.cc new file mode 100644 index 00000000000..22a7a74176f --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/add_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/add.h" + +#include + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/test_util.h" + +using ::testing::FloatNear; +using ::testing::Pointwise; + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +TEST(AddTest, AddsTwoInputTensors) { + TensorRefFloat32 augend, addend, output; + augend.type = DataType::FLOAT32; + augend.ref = 0; + augend.shape = BHWC(1, 2, 2, 1); + + addend.type = DataType::FLOAT32; + addend.ref = 1; + addend.shape = BHWC(1, 2, 2, 1); + + output.type = DataType::FLOAT32; + output.ref = 2; + output.shape = BHWC(1, 2, 2, 1); + + AddAttributes attr; + SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, + {augend, addend}, {output}); + ASSERT_TRUE(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8})); + ASSERT_TRUE(model.PopulateTensor(1, {0.1, 0.2, 0.3, 0.5})); + ASSERT_TRUE(model.Invoke(*NewAddNodeShader())); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {-1.9, 0.4, 1.0, 1.3})); +} + +TEST(AddTest, AddsOneInputTensorWithBroadcast) { + AddAttributes attr; + attr.param = 0.1f; + TensorRefFloat32 input, output; + input.type = DataType::FLOAT32; + input.ref = 0; + input.shape = BHWC(1, 3, 1, 2); + + output.type = DataType::FLOAT32; + output.ref = 1; + output.shape = BHWC(1, 3, 1, 2); + + SingleOpModel model({ToString(OperationType::ADD), std::move(attr)}, {input}, + {output}); + ASSERT_TRUE(model.PopulateTensor(0, {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0})); + ASSERT_TRUE(model.Invoke(*NewAddNodeShader())); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {-1.9, 0.3, 0.8, 0.9, 1.2, 2.1})); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc b/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc new file mode 100644 index 00000000000..5d2afe76c46 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/concat.cc @@ -0,0 +1,485 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class AlignedConcatByChannels : public NodeShader { + public: + static bool IsSupported(const GenerationContext& ctx) { + auto attr = + absl::any_cast(ctx.node->operation.attributes); + auto inputs = ctx.graph->FindInputs(ctx.node->id); + + // Implementation supports concatenation by channels only. 
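+    // Example (illustrative sketch): two 2x2x8 inputs pass these checks and
+    // produce a 2x2x16 output; because both channel counts are multiples of
+    // 4, output RGBA slices 0..1 are read from input 0 and slices 2..3 from
+    // input 1, which is exactly the split the generated shader performs via
+    // the $border$ parameter.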
+ if (attr.axis != Axis::CHANNELS) { + return false; + } + + // Implementation supports concatenation of 2 tensors only. + if (inputs.size() != 2) { + return false; + } + + // H and W must be the same for every concatenated tensor. + auto shape0 = inputs[0]->tensor.shape; + for (int i = 1; i < inputs.size(); i++) { + auto current_shape = inputs[i]->tensor.shape; + if (shape0.h != current_shape.h || shape0.w != current_shape.w) { + return false; + } + } + + // Channels must be aligned by 4 for every concatenated tensor. + for (int i = 0; i < inputs.size(); i++) { + if (inputs[i]->tensor.shape.c % 4 != 0) { + return false; + } + } + + return true; + } + + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + if (!IsSupported(ctx)) { + return InvalidArgumentError( + "This case is not supported by aligned concat"); + } + auto inputs = ctx.graph->FindInputs(ctx.node->id); + + // The shader below concatenates 2 tensors whose channels are aligned by 4 + std::string source = R"( + if (gid.z < $border$) { + value_0 = $input_data_0[gid.x, gid.y, gid.z]$; + } else { + int z = gid.z - $border$; + value_0 = $input_data_1[gid.x, gid.y, z]$; + } +)"; + *generated_code = { + /*parameters=*/{{"border", inputs[0]->tensor.shape.c / 4}}, + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +class ConcatByAnyChannel : public NodeShader { + public: + static bool IsSupported(const GenerationContext& ctx) { + auto attr = + absl::any_cast<ConcatAttributes>(ctx.node->operation.attributes); + auto inputs = ctx.graph->FindInputs(ctx.node->id); + + // Implementation supports concatenation by channels only. + if (attr.axis != ::tflite::gpu::Axis::CHANNELS) { + return false; + } + + // Implementation supports concatenation of more than one tensor only. + if (inputs.size() <= 1) { + return false; + } + + // H and W must be the same for every concatenated tensor. + auto shape0 = inputs[0]->tensor.shape; + for (int i = 1; i < inputs.size(); i++) { + auto current_shape = inputs[i]->tensor.shape; + if (shape0.h != current_shape.h || shape0.w != current_shape.w) { + return false; + } + } + + return true; + } + + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + if (!IsSupported(ctx)) { + return UnimplementedError("This case is not supported by concat"); + } + + auto inputs = ctx.graph->FindInputs(ctx.node->id); + auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; + + std::string code = DeclareVariables(); + + // "already_written" keeps the number of channels that have been joined so far + int already_written = 0; + // "t" is the id of the next temp* variable. + // Generally, temp* variables are used in the macro + // READ_BUFFER_VEC4(buff, addr, var). + // This macro instantiates the variable "var" and + // reads the value from buffer "buff" at address "addr" + int t = 0; + for (int current_input_id = 0; current_input_id < inputs.size(); + current_input_id++) { + // Start joining the next input tensor + + // Grab the number of channels + int in_ch = inputs[current_input_id]->tensor.shape.c; + code += PrintStartMessage(current_input_id, in_ch, already_written); + + // Construct the buffer name associated with this tensor + std::string input = "input_data_" + std::to_string(current_input_id); + + // "reminder" shows how many cells of the last 4-element vector were + // already filled by previous writes.
As an example, if we join two tensors that both have + // 3 channels, after joining the first one we come to this line again + // and, when joining the second tensor, the reminder value + // will be equal to 3, i.e. only one cell of the last vec4 is still free + int reminder = already_written % 4; + + if (reminder == 0) { + code += AlignedCase(in_ch, input); + } else { + code += UnalignedCase(reminder, in_ch, input, &t); + } + already_written += in_ch; + } + + *generated_code = { + /*parameters=*/{}, + /*objects=*/{}, + /*workload=*/uint3(output->tensor.shape.w, output->tensor.shape.h, 1), + /*workgroup=*/uint3(), + /*source_code=*/std::move(code), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::ONLY_DEFINITIONS, + }; + return OkStatus(); + } + + private: + // Utility function + std::string temp(int t) const { return "temp" + std::to_string(t); } + + std::string DeclareVariables() const { + // "val" collects the values to be stored by the next + // upcoming write. + return R"( +int z = gid.z; +vec4 val = vec4(0.0f); + +)"; + } + + std::string PrintStartMessage(int current_input_id, int in_ch, + int already_written) const { + return "// Joining " + std::to_string(current_input_id) + + " tensor with " + std::to_string(in_ch) + + " channels\n// * * * *\\n// Already wrote " + + std::to_string(already_written) + " elements\n\n"; + } + + std::string AlignedCase(int in_ch, const std::string& input) const { + std::string code; + // This branch is for aligned reading and writing, when we can copy + // all 4 components at once. The address of the first element to write + // must be aligned. + // Visual examples: + // 1) when copying input_data_0 + // + // | * * * * | * * * @ | @ @ . . . + // ^ + // 2) when in the middle of the joining process: + // + // | X X X X | * * * @ | @ @ . . . + // ^ + // Note that the amount of * equals in_ch + // + // X - cells were written before + // * - you are going to write into these cells + // @ - you will fill these cells in the next cycles + // ^ - the first element you start writing from + int blocks_amount = IntegralDivideRoundUp(in_ch, 4); + code += "// Aligned case\n"; + code += "// I'm going to make " + std::to_string(blocks_amount) + + " write(s)\n\n"; + for (int block = 0; block < blocks_amount; block++) { + // Copy full 4-element vector + code += "val = $" + input + "[gid.x, gid.y, " + std::to_string(block) + + "]$;\n" + + "$output_data_0[gid.x, gid.y, z] = val$;\n" + // calculate next address to write + + "z++; \n\n"; + } + return code; + } + + std::string UnalignedCase(int reminder, int in_ch, const std::string& input, + int* t) const { + // This branch is for copying cell-by-cell. It will never start from the + // first tensor input_data_0. This function is split into two stages: + // 1) Copy the "leftovers" into the previously started vec4 + // 2) Copy everything else + // Visual examples: + // + // Stage 1 Stage 2 + // ----------- ------------------------- + // . . X | X X X *1 | *2 *2 *2 @ | @ @ . . . + // ^ + // . . X | X X *1 *1 | *2 *2 *2 *2 | *2 *2 . . . + // ^ + // . . X | X *1 *1 *1 | *2 @ @ @ | @ @ . . . + // ^ + // Note that the amount of * equals in_ch + // + // X - cells were written before + // *1 - written at Stage 1 + // *2 - written at Stage 2 + // @ - you will fill these cells in the next cycles + // ^ - the first element you start writing from + + std::string code = "// Unaligned case\n"; + + // Variable "shift" shows how many "empty" cells are left after the previous + // write. Remember that this case is unaligned.
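+    // Worked example (illustrative): joining a 3-channel tensor followed by
+    // a 6-channel one. After input 0, already_written == 3, so reminder == 3
+    // and shift == 4 - 3 == 1. Stage 1 loads temp0 from input 1 and fills
+    // val[3], rewriting the previously written, partially filled output
+    // slice. Stage 2 then copies the remaining 6 - 1 == 5 channels into two
+    // more output slices.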
+ // shift now can only be 1, 2 or 3 + int shift = 4 - reminder; + if (shift > in_ch) { + shift = in_ch; + } + code += "\n// Stage 1\n"; + code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, 0]$;\n"; + for (int i = 0; i < shift; i++) { + // Note that reminder + i implicitly adds 1, because + // reminder by its nature is an amount, not an index + code += "val[" + std::to_string(reminder + i) + "] = " + temp(*t) + "[" + + std::to_string(i) + "];\n"; + } + // Rewrite previous value with updated last cells + code += "$output_data_0[gid.x, gid.y, z - 1] = val$;\n"; + (*t)++; + + // "left_blocks" is the number of WRITE_BUFFER_VEC4 calls + // that are left before this input is fully copied + int left_blocks = (in_ch - shift) / 4; + if ((in_ch - shift) % 4 != 0) { + left_blocks++; + } + if (left_blocks) { + code += "\n// Stage 2\n"; + for (int block = 0; block < left_blocks; block++) { + for (int elem = 0; elem < 4; elem++) { + if (shift % 4 == 0) { + code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, " + + std::to_string(block + 1) + "]$;\n"; + (*t)++; + } + code += "val[" + std::to_string(elem) + "] = " + temp(*t - 1) + "[" + + std::to_string(shift % 4) + "];\n"; + if (shift == in_ch) { + break; + } + shift++; + } + code += "$output_data_0[gid.x, gid.y, z] = val$;\n"; + code += "z++;\n"; + } + } else { + code += "// No Stage 2\n"; + } + return code; + } +}; + +class FlatConcatByHeight : public NodeShader { + public: + static bool IsSupported(const GenerationContext& ctx) { + auto attr = + absl::any_cast<ConcatAttributes>(ctx.node->operation.attributes); + auto inputs = ctx.graph->FindInputs(ctx.node->id); + + // Implementation supports concatenation by height only. + if (attr.axis != ::tflite::gpu::Axis::HEIGHT) { + return false; + } + + // Implementation supports concatenation of more than one tensor only. + if (inputs.size() <= 1) { + return false; + } + + // C and W must be the same for every concatenated tensor.
+ auto shape0 = inputs[0]->tensor.shape; + for (int i = 1; i < inputs.size(); i++) { + auto current_shape = inputs[i]->tensor.shape; + if (shape0.c != current_shape.c || shape0.w != current_shape.w) { + return false; + } + } + + return true; + } + + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto inputs = ctx.graph->FindInputs(ctx.node->id); + std::string code; + std::vector params; + for (int i = 0, shift = 0; i < inputs.size(); + shift += inputs[i]->tensor.shape.h, i++) { + code += "if ("; + if (i != 0) { + code += "$input_data_" + std::to_string(i - 1) + "_h$ <= gid.y && "; + } + code += "gid.y < " + std::to_string(shift + inputs[i]->tensor.shape.h) + + ") {\n"; + code += "if (gid.y - " + std::to_string(shift) + " >= $input_data_" + + std::to_string(i) + "_h$) return;\n"; + code += "value_0 = $input_data_" + std::to_string(i) + + "[gid.x, gid.y - " + std::to_string(shift) + ", gid.z]$;\n}\n"; + if (i != inputs.size() - 1) { + code += " else "; + } + params.push_back({"input_data_" + std::to_string(i) + "_h", + inputs[i]->tensor.shape.h}); + } + + *generated_code = { + /*parameters=*/std::move(params), + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(code), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +class FlatConcatByWidth : public NodeShader { + public: + static bool IsSupported(const GenerationContext& ctx) { + auto attr = + absl::any_cast(ctx.node->operation.attributes); + auto inputs = ctx.graph->FindInputs(ctx.node->id); + + // Implementation supports concatenation by width only. + if (attr.axis != ::tflite::gpu::Axis::WIDTH) { + return false; + } + + // Implementation supports concatenation of more that 1 tensors only. + if (inputs.size() <= 1) { + return false; + } + + // C and H must be the same for every concatenated tensor. 
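+    // Rough sketch (illustrative) of what GenerateCode() below emits for two
+    // inputs of widths W0 and W1:
+    //   if (gid.x < W0) {
+    //     if (gid.x - 0 >= $input_data_0_w$) return;
+    //     value_0 = $input_data_0[gid.x - 0, gid.y, gid.z]$;
+    //   } else if ($input_data_0_w$ <= gid.x && gid.x < W0 + W1) {
+    //     if (gid.x - W0 >= $input_data_1_w$) return;
+    //     value_0 = $input_data_1[gid.x - W0, gid.y, gid.z]$;
+    //   }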
+ auto shape0 = inputs[0]->tensor.shape; + for (int i = 1; i < inputs.size(); i++) { + auto current_shape = inputs[i]->tensor.shape; + if (shape0.c != current_shape.c || shape0.h != current_shape.h) { + return false; + } + } + + return true; + } + + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto inputs = ctx.graph->FindInputs(ctx.node->id); + std::string code; + std::vector params; + for (int i = 0, shift = 0; i < inputs.size(); + shift += inputs[i]->tensor.shape.w, i++) { + code += "if ("; + if (i != 0) { + code += "$input_data_" + std::to_string(i - 1) + "_w$ <= gid.x && "; + } + code += "gid.x < " + std::to_string(shift + inputs[i]->tensor.shape.w) + + ") {\n"; + code += "if (gid.x - " + std::to_string(shift) + " >= $input_data_" + + std::to_string(i) + "_w$) return;\n"; + code += "value_0 = $input_data_" + std::to_string(i) + "[gid.x - " + + std::to_string(shift) + ", gid.y, gid.z]$;\n}\n"; + if (i != inputs.size() - 1) { + code += " else "; + } + params.push_back({"input_data_" + std::to_string(i) + "_w", + inputs[i]->tensor.shape.w}); + } + + *generated_code = { + /*parameters=*/std::move(params), + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(code), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +class FlatConcat : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + if (FlatConcatByHeight::IsSupported(ctx)) { + return flat_concat_by_height_.GenerateCode(ctx, generated_code); + } + if (FlatConcatByWidth::IsSupported(ctx)) { + return flat_concat_by_width_.GenerateCode(ctx, generated_code); + } + return InvalidArgumentError("This case is not supported by flat concat"); + } + + private: + FlatConcatByHeight flat_concat_by_height_; + FlatConcatByWidth flat_concat_by_width_; +}; + +} // namespace + +std::unique_ptr NewAlignedConcatNodeShader() { + return absl::make_unique(); +} + +std::unique_ptr NewConcatNodeShader() { + return absl::make_unique(); +} + +std::unique_ptr NewFlatConcatNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/concat.h b/tensorflow/lite/delegates/gpu/gl/kernels/concat.h new file mode 100644 index 00000000000..34c027da88b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/concat.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CONCAT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CONCAT_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewAlignedConcatNodeShader(); +std::unique_ptr NewConcatNodeShader(); +std::unique_ptr NewFlatConcatNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CONCAT_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc new file mode 100644 index 00000000000..2c19fcc24e4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/conv.cc @@ -0,0 +1,274 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/conv.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class Convolution : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto input = ctx.graph->FindInputs(ctx.node->id)[0]; + auto attr = absl::any_cast( + ctx.node->operation.attributes); + auto weights = attr.weights.shape; + const int offsets_count = weights.h * weights.w; + std::vector offsets; + for (int h = 0; h < weights.h; ++h) { + for (int w = 0; w < weights.w; ++w) { + offsets.emplace_back(w * attr.dilations.w - attr.padding.prepended.w, + h * attr.dilations.h - attr.padding.prepended.h); + } + } + std::vector parameters = { + {"input_data_0_h", input->tensor.shape.h}, + {"input_data_0_w", input->tensor.shape.w}, + {"offsets_count", offsets_count}, + {"offsets", offsets}, + {"src_depth", IntegralDivideRoundUp(weights.i, 4)}, + {"stride", int2(attr.strides.w, attr.strides.h)}, + }; + + // at least one padding is not empty + bool non_empty_padding = + attr.padding.appended.h != 0 || attr.padding.appended.w != 0 || + attr.padding.prepended.h != 0 || attr.padding.prepended.w != 0; + + std::vector> objects = { + {"weights", MakeReadonlyObject(Get3DSizeForPHWO4I4(attr.weights.shape), + ConvertToPHWO4I4(attr.weights))}}; + + std::string source = R"( + 
for (int i = 0; i < $offsets_count$; ++i) { + ivec2 coord = gid.xy * $stride$ + $offsets[i]$;)"; + if (non_empty_padding) { + source += R"( + if (coord.x < 0 || coord.y < 0 || coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$) { + continue; + })"; + } + source += R"( + for (int l = 0; l < $src_depth$; ++l) { + highp vec4 input_ = $input_data_0[coord.x, coord.y, l]$; + value_0.x += dot(input_, $weights[l * 4 + 0, i, gid.z]$); + value_0.y += dot(input_, $weights[l * 4 + 1, i, gid.z]$); + value_0.z += dot(input_, $weights[l * 4 + 2, i, gid.z]$); + value_0.w += dot(input_, $weights[l * 4 + 3, i, gid.z]$); + } + } + )"; + if (!attr.bias.data.empty()) { + source += "value_0 += $bias[gid.z]$;\n"; + objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)}); + } + + // This is a hotfix for special convolution, which worked 10ms on + // textures16. With this fix it works 4ms. + // TODO(eignasheva): fix this problem in the proper way + uint3 workgroup = uint3(0, 0, 0); + if (weights.h == 7 && weights.w == 7 && attr.strides.h == 4 && + attr.strides.w == 4) { + workgroup = uint3(8, 8, 8); + } + + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/std::move(objects), + /*workload=*/uint3(), + /*workgroup=*/ + GetIdealWorkgroupIfPossible( + ctx.gpu_info->gpu_model, OperationType::CONVOLUTION_2D, + HW(weights.h, weights.w), attr.strides, workgroup, + OHWI(weights.o, input->tensor.shape.h, input->tensor.shape.w, + input->tensor.shape.c)), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +int SelectMultiplier(int32_t input_width, + const NodeShader::GenerationContext& ctx) { + std::vector multipliers = {4, 2}; + if (!ctx.compiler_options.allow_precision_loss && + ctx.gpu_info->type == GpuType::MALI) { + multipliers = {2}; + } + for (int i : multipliers) { + if (input_width % i == 0) { + return i; + } + } + return 1; +} + +class Convolution1x1 : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto input = ctx.graph->FindInputs(ctx.node->id)[0]; + auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; + auto attr = absl::any_cast( + ctx.node->operation.attributes); + if (attr.weights.shape.h != 1 || attr.weights.shape.w != 1) { + return UnimplementedError("Height and width should be 1."); + } + if (attr.dilations.h != 1 || attr.dilations.w != 1) { + return UnimplementedError("Dilations are not supported."); + } + if (attr.strides.h != 1 || attr.strides.w != 1) { + return UnimplementedError("Strides are not supported."); + } + if (attr.padding.appended.h != 0 || attr.padding.appended.w != 0 || + attr.padding.prepended.h != 0 || attr.padding.prepended.w != 0) { + return UnimplementedError("Padding is not supported."); + } + + int multiplier = SelectMultiplier(input->tensor.shape.w, ctx); + + std::vector parameters = { + {"src_depth", IntegralDivideRoundUp(input->tensor.shape.c, 4)}, + }; + + std::vector> objects = { + {"weights", MakeReadonlyObject( + uint3(4, IntegralDivideRoundUp(attr.weights.shape.i, 4), + IntegralDivideRoundUp(attr.weights.shape.o, 4)), + ConvertToPHWO4I4(attr.weights))}}; + std::string source; + for (int i = 0; i < multiplier; i++) { + absl::StrAppend(&source, "highp vec4 result", i, " = vec4(0);\n"); + } + absl::StrAppend(&source, "vec4 f;\n"); + absl::StrAppend(&source, "for (int l = 0; l < $src_depth$; ++l) {\n"); + for (int i = 0; i < multiplier; i++) { + 
absl::StrAppend(&source, " vec4 input", i, " = $input_data_0[gid.x * ", + multiplier, " + ", i, ",gid.y,l]$;\n"); + } + for (int k = 0; k < 4; k++) { + absl::StrAppend(&source, " f = $weights[", k, ", l, gid.z]$;\n"); + for (int i = 0; i < multiplier; i++) { + absl::StrAppend(&source, " result", i, "[", k, "] += dot(input", i, + ", f);\n"); + } + } + absl::StrAppend(&source, "}\n"); + if (!attr.bias.data.empty()) { + objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)}); + absl::StrAppend(&source, "vec4 b = $bias[gid.z]$;\n"); + for (int i = 0; i < multiplier; i++) { + absl::StrAppend(&source, "result", i, " += b;\n"); + } + } + if (multiplier != 1) { + for (int i = 0; i < multiplier; i++) { + absl::StrAppend(&source, "$inplace_update:result", i, "$\n"); + absl::StrAppend(&source, "$output_data_0[gid.x * ", multiplier, " + ", + i, ",gid.y,gid.z] = result", i, "$;\n"); + } + } else { + absl::StrAppend(&source, "value_0 = result0;\n"); + } + + auto dst_depth = IntegralDivideRoundUp(output->tensor.shape.c, 4); + uint3 workgroup = uint3(16, 16, 1); + if (ctx.gpu_info->type == GpuType::ADRENO) { + if (dst_depth >= 2) { + workgroup = uint3(8, 8, 2); + } + if (dst_depth >= 4) { + workgroup = uint3(4, 8, 4); + } + if (dst_depth >= 8) { + workgroup = uint3(4, 4, 8); + } + if (dst_depth >= 32) { + workgroup = uint3(4, 4, 16); + } + if (dst_depth >= 64) { + workgroup = uint3(2, 8, 16); + } + } else { + if (dst_depth >= 2) { + workgroup = uint3(16, 8, 2); + } + if (dst_depth >= 4) { + workgroup = uint3(16, 4, 4); + } + if (dst_depth >= 8) { + workgroup = uint3(8, 4, 8); + } + if (dst_depth >= 32) { + workgroup = uint3(8, 4, 8); + } + if (dst_depth >= 64) { + workgroup = uint3(8, 4, 8); + } + } + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/std::move(objects), + /*workload=*/ + uint3(output->tensor.shape.w / multiplier, output->tensor.shape.h, + IntegralDivideRoundUp(output->tensor.shape.c, 4)), + /*workgroup=*/ + GetIdealWorkgroupIfPossible( + ctx.gpu_info->gpu_model, OperationType::CONVOLUTION_2D, + HW(attr.weights.shape.h, attr.weights.shape.w), attr.strides, + workgroup, + OHWI(attr.weights.shape.o, input->tensor.shape.h, + input->tensor.shape.w, input->tensor.shape.c)), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/multiplier == 1 ? IOStructure::AUTO + : IOStructure::ONLY_DEFINITIONS, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewConvolutionNodeShader() { + return absl::make_unique(); +} + +std::unique_ptr NewConvolution1x1NodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/conv.h b/tensorflow/lite/delegates/gpu/gl/kernels/conv.h new file mode 100644 index 00000000000..c2f2d217493 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/conv.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CONV_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CONV_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewConvolutionNodeShader(); + +// Specialization for 1x1 convolutions. +std::unique_ptr NewConvolution1x1NodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_CONV_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc new file mode 100644 index 00000000000..ac381e1ca08 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc @@ -0,0 +1,123 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class DepthwiseConvolution : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto input = ctx.graph->FindInputs(ctx.node->id)[0]; + auto attr = absl::any_cast( + ctx.node->operation.attributes); + auto weights = attr.weights.shape; + const int offsets_count = weights.h * weights.w; + std::vector offsets; + for (int h = 0; h < weights.h; ++h) { + for (int w = 0; w < weights.w; ++w) { + offsets.emplace_back(w * attr.dilations.w - attr.padding.prepended.w, + h * attr.dilations.h - attr.padding.prepended.h); + } + } + std::vector parameters = { + {"input_data_0_h", input->tensor.shape.h}, + {"input_data_0_w", input->tensor.shape.w}, + {"offsets_count", offsets_count}, + {"offsets", offsets}, + {"src_depth", IntegralDivideRoundUp(weights.i, 4)}, + {"channel_multiplier", weights.o}, + {"stride", int2(attr.strides.w, attr.strides.h)}, + }; + + bool non_empty_padding = + attr.padding.appended.h != 0 || attr.padding.appended.w != 0 || + attr.padding.prepended.h != 0 || attr.padding.prepended.w != 0; + + std::vector> objects = { + {"weights", MakeReadonlyObject(ConvertToPIOHW4(attr.weights))}}; + + std::string source = R"( + int src_layer_offset = (gid.z % 
$channel_multiplier$) * 4; + int filter_offset = gid.z * $src_depth$ * $offsets_count$ * 4; + for (int i = 0; i < $offsets_count$; ++i) { + ivec2 coord = gid.xy * $stride$ + $offsets[i]$;)"; + if (non_empty_padding) { + source += R"( + if (coord.x < 0 || coord.y < 0 || + coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$) { + continue; + })"; + } + source += R"( + int src_layer = gid.z / $channel_multiplier$; + vec4 input_ = $input_data_0[coord.x, coord.y, src_layer]$; + highp vec4 input_shifted; + input_shifted[0] = input_[(src_layer_offset + 0) / $channel_multiplier$]; + input_shifted[1] = input_[(src_layer_offset + 1) / $channel_multiplier$]; + input_shifted[2] = input_[(src_layer_offset + 2) / $channel_multiplier$]; + input_shifted[3] = input_[(src_layer_offset + 3) / $channel_multiplier$]; + int filter_offset = gid.z * $offsets_count$ + i; + value_0 += input_shifted * $weights[filter_offset]$; + } +)"; + if (!attr.bias.data.empty()) { + source += "value_0 += $bias[gid.z]$;\n"; + objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)}); + } + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/std::move(objects), + /*workload=*/uint3(), + /*workgroup=*/ + GetIdealWorkgroupIfPossible( + ctx.gpu_info->gpu_model, OperationType::DEPTHWISE_CONVOLUTION, + HW(attr.weights.shape.h, attr.weights.shape.w), attr.strides, + OHWI(attr.weights.shape.o, input->tensor.shape.h, + input->tensor.shape.w, input->tensor.shape.c)), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewDepthwiseConvolutionNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h new file mode 100644 index 00000000000..a953010e4bd --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_DEPTHWISE_CONV_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_DEPTHWISE_CONV_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewDepthwiseConvolutionNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_DEPTHWISE_CONV_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc new file mode 100644 index 00000000000..37ee322ac8a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc @@ -0,0 +1,207 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h" + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class ElementwiseOneArgument : public NodeShader { + public: + explicit ElementwiseOneArgument(OperationType operation_type) + : operation_type_(operation_type) {} + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + std::string source; + switch (operation_type_) { + case OperationType::ABS: { + source = "value_0 = abs(value_0);"; + break; + } + case OperationType::SIN: { + source = "value_0 = sin(value_0);"; + break; + } + case OperationType::COS: { + source = "value_0 = cos(value_0);"; + break; + } + case OperationType::LOG: { + source = R"( + const float nan = normalize(vec4(0,0,0,0)).x; + value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan; + value_0.y = value_0.y > 0.0 ? log(value_0.y) : nan; + value_0.z = value_0.z > 0.0 ? log(value_0.z) : nan; + value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan; + )"; + break; + } + case OperationType::SQRT: { + source = R"( + const float nan = normalize(vec4(0,0,0,0)).x; + value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan; + value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan; + value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan; + value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan; + )"; + break; + } + case OperationType::RSQRT: { + source = R"( + const float nan = normalize(vec4(0,0,0,0)).x; + value_0.x = value_0.x >= 0.0 ? 1.0 / sqrt(value_0.x) : nan; + value_0.y = value_0.y >= 0.0 ? 1.0 / sqrt(value_0.y) : nan; + value_0.z = value_0.z >= 0.0 ? 1.0 / sqrt(value_0.z) : nan; + value_0.w = value_0.w >= 0.0 ? 
1.0 / sqrt(value_0.w) : nan;
+        )";
+        break;
+      }
+      case OperationType::SQUARE: {
+        source = "value_0 = value_0 * value_0;";
+        break;
+      }
+      case OperationType::SIGMOID: {
+        source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
+        break;
+      }
+      case OperationType::TANH: {
+        source = "value_0 = tanh(value_0);";
+        break;
+      }
+      default:
+        return InvalidArgumentError("Incorrect elementwise operation type.");
+    }
+    *generated_code = {
+        /*parameters=*/{},
+        /*objects=*/{},
+        /*workload=*/uint3(),
+        /*workgroup=*/uint3(),
+        source,
+        /*input=*/IOStructure::AUTO,
+        /*output=*/IOStructure::AUTO,
+    };
+    return OkStatus();
+  }
+
+ private:
+  OperationType operation_type_;
+};
+
+class ElementwiseTwoArguments : public NodeShader {
+ public:
+  explicit ElementwiseTwoArguments(OperationType operation_type)
+      : operation_type_(operation_type) {}
+  static bool IsSupported(const GenerationContext& ctx) {
+    auto inputs = ctx.graph->FindInputs(ctx.node->id);
+
+    // The implementation supports elementwise operations on exactly 2 input
+    // tensors.
+    if (inputs.size() != 2) {
+      return false;
+    }
+
+    auto shape0 = inputs[0]->tensor.shape;
+    auto shape1 = inputs[1]->tensor.shape;
+
+    // Shapes must be the same.
+    if (shape0 != shape1) {
+      return false;
+    }
+
+    return true;
+  }
+
+  Status GenerateCode(const GenerationContext& ctx,
+                      GeneratedCode* generated_code) const final {
+    if (!IsSupported(ctx)) {
+      return InvalidArgumentError(
+          "This case is not supported by the two-argument elementwise "
+          "operation");
+    }
+    std::string source;
+    switch (operation_type_) {
+      case OperationType::SUB: {
+        source = "value_0 -= value_1;";
+        break;
+      }
+      case OperationType::DIV: {
+        source = "value_0 /= value_1;";
+        break;
+      }
+      case OperationType::POW: {
+        // From the documentation:
+        // The result is undefined if x<0 or if x=0 and y≤0.
+        source = "value_0 = pow(value_0, value_1);";
+        break;
+      }
+      case OperationType::SQUARED_DIFF: {
+        source = "value_0 = (value_0 - value_1) * (value_0 - value_1);";
+        break;
+      }
+
+      default:
+        return InvalidArgumentError(
+            "Incorrect two-argument elementwise operation type.");
+    }
+    *generated_code = {
+        /*parameters=*/{},
+        /*objects=*/{},
+        /*workload=*/uint3(),
+        /*workgroup=*/uint3(),
+        /*source_code=*/source,
+        /*input=*/IOStructure::AUTO,
+        /*output=*/IOStructure::AUTO,
+    };
+    return OkStatus();
+  }
+
+ private:
+  OperationType operation_type_;
+};
+
+}  // namespace
+
+std::unique_ptr<NodeShader> NewElementwiseNodeShader(
+    OperationType operation_type) {
+  switch (operation_type) {
+    case OperationType::ABS:
+    case OperationType::SIN:
+    case OperationType::COS:
+    case OperationType::LOG:
+    case OperationType::SQRT:
+    case OperationType::RSQRT:
+    case OperationType::SQUARE:
+    case OperationType::SIGMOID:
+    case OperationType::TANH:
+      return absl::make_unique<ElementwiseOneArgument>(operation_type);
+    case OperationType::SUB:
+    case OperationType::DIV:
+    case OperationType::POW:
+    case OperationType::SQUARED_DIFF:
+      return absl::make_unique<ElementwiseTwoArguments>(operation_type);
+    default:
+      return nullptr;
+  }
+}
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h
new file mode 100644
index 00000000000..42109d91779
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h
@@ -0,0 +1,35 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_ELEMENTWISE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_ELEMENTWISE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewElementwiseNodeShader( + OperationType operation_type); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_ELEMENTWISE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc new file mode 100644 index 00000000000..e1835bcb761 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc @@ -0,0 +1,196 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h" + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/test_util.h" + +using ::testing::FloatNear; +using ::testing::Pointwise; + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class ElementwiseOneArgumentTest : public ::testing::Test { + public: + ElementwiseOneArgumentTest() = default; + ~ElementwiseOneArgumentTest() override = default; + + TensorRefFloat32 GetTensorRef(int ref) { + TensorRefFloat32 tensor_ref; + tensor_ref.type = DataType::FLOAT32; + tensor_ref.ref = ref; + tensor_ref.shape = BHWC(1, 2, 2, 1); + return tensor_ref; + } +}; + +TEST_F(ElementwiseOneArgumentTest, Abs) { + OperationType op_type = OperationType::ABS; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0})); +} + +TEST_F(ElementwiseOneArgumentTest, Sin) { + OperationType op_type = OperationType::SIN; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471})); +} + +TEST_F(ElementwiseOneArgumentTest, Cos) { + OperationType op_type = OperationType::COS; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302})); +} + +TEST_F(ElementwiseOneArgumentTest, Log) { + OperationType op_type = OperationType::LOG; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {1.0, 3.1415926, 1.0, 1.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0})); +} + +TEST_F(ElementwiseOneArgumentTest, Sqrt) { + OperationType op_type = OperationType::SQRT; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0})); +} + +TEST_F(ElementwiseOneArgumentTest, Rsqrt) { + OperationType op_type = OperationType::RSQRT; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333})); +} + +TEST_F(ElementwiseOneArgumentTest, Square) { + OperationType op_type = OperationType::SQUARE; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0})); + 
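+  // Expect each element squared: {1.0, 4.0, 0.25, 9.0}.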
ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0})); +} + +TEST_F(ElementwiseOneArgumentTest, Sigmoid) { + OperationType op_type = OperationType::SIGMOID; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014})); +} + +TEST_F(ElementwiseOneArgumentTest, Tanh) { + OperationType op_type = OperationType::TANH; + SingleOpModel model({ToString(op_type), {}}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329})); +} + +class ElementwiseTwoArgumentsTest : public ::testing::Test { + public: + ElementwiseTwoArgumentsTest() = default; + ~ElementwiseTwoArgumentsTest() override = default; + + TensorRefFloat32 GetTensorRef(int ref) { + TensorRefFloat32 tensor_ref; + tensor_ref.type = DataType::FLOAT32; + tensor_ref.ref = ref; + tensor_ref.shape = BHWC(1, 2, 2, 1); + return tensor_ref; + } +}; + +TEST_F(ElementwiseTwoArgumentsTest, Sub) { + OperationType op_type = OperationType::SUB; + SingleOpModel model({ToString(op_type), {}}, + {GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); + ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0})); +} + +TEST_F(ElementwiseTwoArgumentsTest, Div) { + OperationType op_type = OperationType::DIV; + SingleOpModel model({ToString(op_type), {}}, + {GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0})); + ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0})); +} + +TEST_F(ElementwiseTwoArgumentsTest, Pow) { + OperationType op_type = OperationType::POW; + SingleOpModel model({ToString(op_type), {}}, + {GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0})); + ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, 3.0, 4.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0})); +} + +TEST_F(ElementwiseTwoArgumentsTest, SquaredDiff) { + OperationType op_type = OperationType::SQUARED_DIFF; + SingleOpModel model({ToString(op_type), {}}, + {GetTensorRef(0), GetTensorRef(1)}, {GetTensorRef(2)}); + ASSERT_TRUE(model.PopulateTensor(0, {0.0, 2.0, 2.0, 4.0})); + ASSERT_TRUE(model.PopulateTensor(1, {1.0, 1.0, 5.0, 4.0})); + ASSERT_TRUE(model.Invoke(*NewElementwiseNodeShader(op_type))); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0})); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc new file mode 100644 index 
00000000000..487db2b5d86 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.cc @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class FullyConnectedBuffers : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto attr = absl::any_cast( + ctx.node->operation.attributes); + + // TODO(akulik): check that input has h,w == 1,1 + std::vector parameters = { + {"src_depth", IntegralDivideRoundUp(attr.weights.shape.i, 4)}, + }; + + // TODO(akulik): refactor indexed access to weights. + std::vector> objects = { + {"weights", MakeReadonlyObject(ConvertToPHWO4I4(attr.weights))}}; + + std::string source = R"( + int offset = gid.z * $src_depth$ * 4; + for (int d = 0; d < $src_depth$; ++d, offset += 4) { + vec4 src = $input_data_0[0, 0, d]$; + value_0.x += dot(src, $weights[offset]$); + value_0.y += dot(src, $weights[offset + 1]$); + value_0.z += dot(src, $weights[offset + 2]$); + value_0.w += dot(src, $weights[offset + 3]$); + } +)"; + if (!attr.bias.data.empty()) { + source += " value_0 += $bias[gid.z]$;\n"; + objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)}); + } + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/std::move(objects), + /*workload=*/ + uint3(1, 1, IntegralDivideRoundUp(attr.weights.shape.o, 4)), + /*workgroup=*/uint3(), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewFullyConnectedNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h new file mode 100644 index 00000000000..3a137f4ffa3 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_FULLY_CONNECTED_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_FULLY_CONNECTED_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewFullyConnectedNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_FULLY_CONNECTED_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc new file mode 100644 index 00000000000..696d5257598 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/lstm.cc @@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/lstm.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +// Basic LSTMCell gates. +// +// inputs: 0 1 +// activ_temp prev_state +// \ / +// [[LSTM gates]] +// / \ +// new_state activation +// outputs: 0 1 +// +// The size of activ_temp should be 4x size of new_state. +// The size of prev_state == new_state == activation. 
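+//
+// Rough sketch of the math emitted below (a standard LSTMCell update; the
+// gate order input/new/forget/output follows the code):
+//
+//   input_gate  = sigmoid(gate_0)      new_input   = tanh(gate_1)
+//   forget_gate = sigmoid(gate_2)      output_gate = sigmoid(gate_3)
+//   new_state   = input_gate * new_input + forget_gate * prev_state
+//   activation  = output_gate * tanh(new_state)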
+// +class LstmNodeShader : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + std::string code = R"( + vec4 prev_state = $input_data_1[gid.x, gid.y, gid.z]$; + + int c0 = 0 * $workload_z$; + int c1 = 1 * $workload_z$; + int c2 = 2 * $workload_z$; + int c3 = 3 * $workload_z$; + + // input, new, forget, output + vec4 gate_0 = $input_data_0[gid.x, gid.y, gid.z + c0]$; + vec4 gate_1 = $input_data_0[gid.x, gid.y, gid.z + c1]$; + vec4 gate_2 = $input_data_0[gid.x, gid.y, gid.z + c2]$; + vec4 gate_3 = $input_data_0[gid.x, gid.y, gid.z + c3]$; + + vec4 input_gate = 1.0f / (1.0f + exp(-1.0 * gate_0)); // sig(x) + vec4 new_input = tanh(gate_1); // tanh(x) + vec4 forget_gate = 1.0f / (1.0f + exp(-1.0 * gate_2)); // sig(x) + vec4 output_gate = 1.0f / (1.0f + exp(-1.0 * gate_3)); // sig(x) + + vec4 new_state = input_gate * new_input + forget_gate * prev_state; + vec4 activation = output_gate * tanh(new_state); + + value_0 = new_state; + value_1 = activation; + )"; + *generated_code = { + /*parameters=*/{}, + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(code), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewLstmNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/lstm.h b/tensorflow/lite/delegates/gpu/gl/kernels/lstm.h new file mode 100644 index 00000000000..fcc5acdd7b7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/lstm.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_LSTM_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_LSTM_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewLstmNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_LSTM_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc new file mode 100644 index 00000000000..610679df2ca --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.cc @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class MaxUnpooling : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto attr = absl::any_cast( + ctx.node->operation.attributes); + std::vector parameters = { + {"stride", int2(attr.strides.w, attr.strides.h)}, + {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)}, + {"window_h", attr.kernel.h}, + {"window_w", attr.kernel.w}, + }; + + std::string source = R"( + ivec2 coord = (gid.xy + $offset$) / $stride$; + ivec4 indices = $input_data_1[coord.x, coord.y, gid.z]$; + vec4 input_ = $input_data_0[coord.x, coord.y, gid.z]$; + coord = coord * $stride$ - $offset$; + for (int i = 0; i < 4; ++i) { + ivec2 t = coord + ivec2(indices[i] % $window_w$, indices[i] / $window_w$); + if (t.x == gid.x && t.y == gid.y) { + value_0[i] = input_[i]; + } + } + )"; + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewMaxUnpoolingNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h b/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h new file mode 100644 index 00000000000..f4deb739c14 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MAX_UNPOOLING_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MAX_UNPOOLING_H_
+
+#include <memory>
+
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+
+std::unique_ptr<NodeShader> NewMaxUnpoolingNodeShader();
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MAX_UNPOOLING_H_
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc
new file mode 100644
index 00000000000..eb94013937b
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/mul.cc
@@ -0,0 +1,153 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+namespace {
+
+class ApplyMask : public NodeShader {
+ public:
+  static bool IsSupported(const GenerationContext& ctx) {
+    auto inputs = ctx.graph->FindInputs(ctx.node->id);
+
+    // Implementation requires 2 input tensors: source and mask.
+    if (inputs.size() != 2) {
+      return false;
+    }
+
+    auto src_shape = inputs[0]->tensor.shape;
+    auto mask_shape = inputs[1]->tensor.shape;
+
+    // Height and width dimensions of the two input tensors must be the same.
+    if (src_shape.h != mask_shape.h || src_shape.w != mask_shape.w) {
+      return false;
+    }
+
+    // Broadcast is used if the mask tensor has a single channel.
+    if (mask_shape.c == 1) {
+      return true;
+    }
+
+    // Element-wise multiplication is used if the mask tensor has the same
+    // number of channels as the source tensor.
+    if (src_shape.c == mask_shape.c) {
+      return true;
+    }
+
+    // Other cases are not supported.
+    return false;
+  }
+
+  Status GenerateCode(const GenerationContext& ctx,
+                      GeneratedCode* generated_code) const final {
+    if (!IsSupported(ctx)) {
+      return InvalidArgumentError(
+          "This case is not supported by the apply mask operation");
+    }
+    auto inputs = ctx.graph->FindInputs(ctx.node->id);
+
+    std::string source;
+    if (inputs[1]->tensor.shape.c == 1) {
+      // Broadcast case, mask channels size == 1.
+      source =
+          "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * "
+          "$input_data_1[gid.x, gid.y, 0]$.x;";
+    } else {
+      // Element-wise multiplication case, src channels size == mask channels size.
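+      // The source vec4 at (gid.x, gid.y, gid.z) is scaled component-wise by a
+      // mask vec4 read at the same spatial location.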
+ source = + "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * " + "$input_data_1[gid.x, gid.y, 0]$;"; + } + + *generated_code = { + /*parameters=*/{}, + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +class MultiplyScalar : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto attr = absl::any_cast( + ctx.node->operation.attributes); + auto muls = absl::get_if>(&attr.param); + auto scalar = absl::get_if(&attr.param); + + if (scalar) { + *generated_code = { + /*parameters=*/{{"scalar", *scalar}}, + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/"value_0 *= $scalar$;", + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + }; + } else { + if (!muls) { + return InvalidArgumentError("Empty parameters for Multiplication."); + } + auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape; + *generated_code = { + /*parameters=*/{}, + /*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}}, + // Declare workload explicitly because shader depends on gid.z. + /*workload=*/ + uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)), + /*workgroup=*/uint3(), + /*source_code=*/"value_0 *= $mul_buffer[gid.z]$;", + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + }; + } + + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewApplyMaskNodeShader() { + return absl::make_unique(); +} + +std::unique_ptr NewMultiplyScalarNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mul.h b/tensorflow/lite/delegates/gpu/gl/kernels/mul.h new file mode 100644 index 00000000000..5868d0e6f8f --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/mul.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MUL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MUL_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewApplyMaskNodeShader(); + +std::unique_ptr NewMultiplyScalarNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MUL_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc b/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc new file mode 100644 index 00000000000..6d6662c9a54 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/pad.cc @@ -0,0 +1,92 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/gl/kernels/pad.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+namespace {
+
+class Pad : public NodeShader {
+ public:
+  Status GenerateCode(const GenerationContext& ctx,
+                      GeneratedCode* generated_code) const final {
+    auto input = ctx.graph->FindInputs(ctx.node->id)[0];
+    auto attr =
+        absl::any_cast<const PadAttributes&>(ctx.node->operation.attributes);
+
+    if (attr.type != PaddingContentType::ZEROS) {
+      return UnimplementedError(
+          "Padding with content type != ZEROS is not supported.");
+    }
+    if (attr.appended.h < 0 || attr.appended.w < 0 || attr.appended.c < 0 ||
+        attr.prepended.h < 0 || attr.prepended.w < 0 || attr.prepended.c < 0) {
+      return UnimplementedError("Negative padding is not supported.");
+    }
+    std::vector parameters = {
+        {"input_data_0_h", input->tensor.shape.h},
+        {"input_data_0_w", input->tensor.shape.w},
+        {"prepended",
+         int4(attr.prepended.w, attr.prepended.h, attr.prepended.c, 0)},
+        {"src_channels", input->tensor.shape.c},
+    };
+
+    std::string source = R"(
+  int src_x = gid.x - $prepended.x$;
+  int src_y = gid.y - $prepended.y$;
+  if (src_x >= 0 && src_x < $input_data_0_w$ && src_y >= 0 && src_y < $input_data_0_h$) {
+    int start_channel = gid.z * 4;
+    for (int i = 0; i < 4; ++i) {
+      int channel = start_channel + i;
+      int src_z = channel - $prepended.z$;
+      if (src_z >= 0 && src_z < $src_channels$) {
+        value_0[i] = $input_data_0[src_x, src_y, src_z / 4]$[src_z % 4];
+      }
+    }
+  }
+)";
+    *generated_code = {
+        /*parameters=*/std::move(parameters),
+        /*objects=*/{},
+        /*workload=*/uint3(),
+        /*workgroup=*/uint3(),
+        /*source_code=*/std::move(source),
+        /*input=*/IOStructure::ONLY_DEFINITIONS,
+        /*output=*/IOStructure::AUTO,
+    };
+    return OkStatus();
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<NodeShader> NewPadNodeShader() {
+  return absl::make_unique<Pad>();
+}
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/pad.h b/tensorflow/lite/delegates/gpu/gl/kernels/pad.h
new file mode 100644
index 00000000000..c6840df1cf9
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/pad.h
@@ -0,0 +1,33 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_PAD_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_PAD_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewPadNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_PAD_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc new file mode 100644 index 00000000000..aac3823f2f2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/pooling.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr, + const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { + auto input = ctx.graph->FindInputs(ctx.node->id)[0]; + + if (attr.padding.prepended.h > attr.kernel.h || + attr.padding.prepended.w > attr.kernel.w) { + return InvalidArgumentError("Padding is bigger than kernel."); + } + + std::vector parameters = { + {"input_data_0_h", input->tensor.shape.h}, + {"input_data_0_w", input->tensor.shape.w}, + {"stride", int2(attr.strides.w, attr.strides.h)}, + {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)}, + {"window_h", attr.kernel.h}, + {"window_w", attr.kernel.w}, + }; + + // Per GLSL_ES 3.1 spec in Issue 13.4 + // "Floating Point Representation and Functionality" highp floats are + // expected to behave as defined in IEEE 754. In particular, signed + // infinities are mandated and defined as a number divided by 0. 
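+  // Seeding the running maximum with -inf (see `inf` below) keeps the
+  // reduction correct even when every value in the pooling window is negative.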
+ std::string source = R"( + const highp float inf = -(1.0f / 0.0f); + value_0 = vec4(inf);)"; + if (attr.output_indices) { + source += R"( + ivec4 value_1; +)"; + } + source += R"( + ivec2 base_coord = gid.xy * $stride$ - $offset$; + for (int a = 0; a < $window_h$; ++a) { + for (int b = 0; b < $window_w$; ++b) { + ivec2 coord = base_coord + ivec2(b, a); + if (coord.x < 0 || coord.y < 0 || coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$) { + continue; + } + vec4 input_ = $input_data_0[coord.x, coord.y, gid.z]$;)"; + if (attr.output_indices) { + source += R"( + int window_index = a * $window_w$ + b; + if (input_.x > value_0.x) value_1.x = window_index; + if (input_.y > value_0.y) value_1.y = window_index; + if (input_.z > value_0.z) value_1.z = window_index; + if (input_.w > value_0.w) value_1.w = window_index;)"; + } + source += R"( + value_0 = max(value_0, input_); + } + } +)"; + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); +} + +Status GenerateAveragePoolingCode(const Pooling2DAttributes& attr, + const NodeShader::GenerationContext& ctx, + GeneratedCode* generated_code) { + auto input = ctx.graph->FindInputs(ctx.node->id)[0]; + + std::vector parameters = { + {"input_data_0_h", input->tensor.shape.h}, + {"input_data_0_w", input->tensor.shape.w}, + {"stride", int2(attr.strides.w, attr.strides.h)}, + {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)}, + {"window_h", attr.kernel.h}, + {"window_w", attr.kernel.w}, + {"multiplier", 1.0f / static_cast(attr.kernel.h * attr.kernel.w)}, + }; + + std::string source = R"( + for (int a = 0; a < $window_h$; ++a) { + for (int b = 0; b < $window_w$; ++b) { + ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a); + if (coord.x >= 0 && coord.y >= 0 && coord.x < $input_data_0_w$ && coord.y < $input_data_0_h$) { + value_0 += $input_data_0[coord.x, coord.y, gid.z]$; + } + } + } + value_0 *= $multiplier$; +)"; + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); +} + +class Pooling : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + const auto& attr = + absl::any_cast(ctx.node->operation.attributes); + switch (attr.type) { + case PoolingType::AVERAGE: + return GenerateAveragePoolingCode(attr, ctx, generated_code); + case PoolingType::MAX: + return GenerateMaxPoolingCode(attr, ctx, generated_code); + default: + return InvalidArgumentError("Incorrect attributes' type."); + } + } +}; + +} // namespace + +std::unique_ptr NewPoolingNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/pooling.h b/tensorflow/lite/delegates/gpu/gl/kernels/pooling.h new file mode 100644 index 00000000000..c4f650cfe59 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/pooling.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_POOLING_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_POOLING_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewPoolingNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_POOLING_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc new file mode 100644 index 00000000000..9aaeceebaf7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/prelu.cc @@ -0,0 +1,164 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class PReLULinearAlpha : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; + auto attr = + absl::any_cast(ctx.node->operation.attributes); + auto alpha = absl::get_if>(&attr.alpha); + if (!alpha) { + return InvalidArgumentError("Alpha is missing"); + } + if (alpha->shape.v != output->tensor.shape.c) { + return InvalidArgumentError( + "Alpha shape does not match the number of channels."); + } + + auto shape = output->tensor.shape; + + *generated_code = + attr.clip + ? GeneratedCode{ + /*parameters=*/{{"clip", attr.clip}}, + /*objects=*/{{"alpha", MakeReadonlyObject(alpha->data)}}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + "value_0 = clamp(value_0, 0.0, $clip$) + $alpha[gid.z]$ * " + "min(value_0, 0.0);", + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + } + : GeneratedCode{ + /*parameters=*/{}, + /*objects=*/{{"alpha", MakeReadonlyBuffer(alpha->data)}}, + // Declare workload explicitly because shader depends on + // gid.z. 
+ /*workload=*/ + uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)), + /*workgroup=*/uint3(), + /*source_code=*/ + "value_0 = max(value_0, 0.0) + $alpha[gid.z]$ * min(value_0, " + "0.0);", + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +class PReLUFull : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; + auto attr = + absl::any_cast(ctx.node->operation.attributes); + auto alpha = absl::get_if>( + &attr.alpha); + if (!alpha) { + return InvalidArgumentError("Alpha is missing"); + } + if (alpha->shape.h != output->tensor.shape.h || + alpha->shape.w != output->tensor.shape.w || + alpha->shape.c != output->tensor.shape.c) { + return InvalidArgumentError("Alpha shape does not match input shape."); + } + + auto shape = output->tensor.shape; + + *generated_code = + attr.clip + ? GeneratedCode{ + /*parameters=*/{{"clip", attr.clip}}, + /*objects=*/ + {{"alpha", MakeReadonlyObject(ConvertToPHWC4(*alpha))}}, + // Declare workload explicitly because shader + // depends on gid.z. + /*workload=*/ + uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)), + /*workgroup=*/uint3(), + /*source_code=*/ + "value_0 = clamp(value_0, 0.0, $clip$) + " + "$alpha[gid.x, gid.y, gid.z]$ * min(value_0, 0.0);", + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + } + : GeneratedCode{ + /*parameters=*/{}, + /*objects=*/ + {{"alpha", MakeReadonlyObject(ConvertToPHWC4(*alpha))}}, + // Declare workload explicitly because shader depends on + // gid.z. + /*workload=*/ + uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)), + /*workgroup=*/uint3(), + /*source_code=*/ + "value_0 = max(value_0, 0.0) + $alpha[gid.x, gid.y, gid.z]$ " + "* min(value_0, 0.0);", + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +class PReLU : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto attr = + absl::any_cast(ctx.node->operation.attributes); + auto alpha = absl::get_if>(&attr.alpha); + return alpha ? full_.GenerateCode(ctx, generated_code) + : linear_.GenerateCode(ctx, generated_code); + } + + private: + PReLULinearAlpha linear_; + PReLUFull full_; +}; + +} // namespace + +std::unique_ptr NewPReLUNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/prelu.h b/tensorflow/lite/delegates/gpu/gl/kernels/prelu.h new file mode 100644 index 00000000000..30d30198f41 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/prelu.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_PRELU_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_PRELU_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewPReLUNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_PRELU_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc new file mode 100644 index 00000000000..2201d0018dd --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc @@ -0,0 +1,141 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/registry.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/add.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/conv.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/lstm.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/pad.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/relu.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/reshape.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/slice.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.h" + +#ifndef TFLITE_GPU_BINARY_RELEASE +#include "tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h" +#endif // TFLITE_GPU_BINARY_RELEASE + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class Registry : public NodeShader { + public: + Registry() { + using Type = OperationType; + using NewShaderFunc = std::function()>; + + auto insert_op = [&](Type type, NewShaderFunc func) { + shaders_[ToString(type)].push_back(func()); + }; + auto insert_elementwise_op = [&](Type operation_type) { + shaders_[ToString(operation_type)].push_back( + NewElementwiseNodeShader(operation_type)); + }; 
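+    // Multiple shaders may be registered for the same operation type. During
+    // code generation they are tried in registration order and the first one
+    // that succeeds is used (see GenerateCode below).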
+ + insert_op(Type::ADD, NewAddNodeShader); + insert_op(Type::APPLY_MASK, NewApplyMaskNodeShader); + insert_op(Type::CONCAT, NewAlignedConcatNodeShader); + insert_op(Type::CONCAT, NewFlatConcatNodeShader); + insert_op(Type::CONCAT, NewConcatNodeShader); + insert_op(Type::CONVOLUTION_2D, NewConvolution1x1NodeShader); + insert_op(Type::CONVOLUTION_2D, NewConvolutionNodeShader); + insert_op(Type::CONVOLUTION_TRANSPOSED, NewConvolutionTransposedNodeShader); + insert_op(Type::DEPTHWISE_CONVOLUTION, NewDepthwiseConvolutionNodeShader); + insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader); + insert_op(Type::LSTM, NewLstmNodeShader); + insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader); + insert_op(Type::PAD, NewPadNodeShader); + insert_op(Type::POOLING_2D, NewPoolingNodeShader); + insert_op(Type::RELU, NewReLUNodeShader); + insert_op(Type::RESHAPE, NewReshapeNodeShader); + insert_op(Type::PRELU, NewPReLUNodeShader); + insert_op(Type::SLICE, NewSliceNodeShader); + insert_op(Type::SOFT_MAX, NewSoftMaxNodeShader); + insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader); + + insert_elementwise_op(Type::ABS); + insert_elementwise_op(Type::COS); + insert_elementwise_op(Type::LOG); + insert_elementwise_op(Type::RSQRT); + insert_elementwise_op(Type::SIGMOID); + insert_elementwise_op(Type::SIN); + insert_elementwise_op(Type::SQRT); + insert_elementwise_op(Type::SQUARE); + insert_elementwise_op(Type::TANH); + insert_elementwise_op(Type::SUB); + insert_elementwise_op(Type::DIV); + insert_elementwise_op(Type::POW); + insert_elementwise_op(Type::SQUARED_DIFF); + +#ifndef TFLITE_GPU_BINARY_RELEASE + insert_op(Type::MAX_UNPOOLING_2D, NewMaxUnpoolingNodeShader); +#endif // TFLITE_GPU_BINARY_RELEASE + } + + ~Registry() final = default; + + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + std::vector errors; + auto it = shaders_.find(ctx.node->operation.type); + if (it != shaders_.end()) { + for (auto& shader : it->second) { + const auto status = shader->GenerateCode(ctx, generated_code); + if (status.ok()) return status; + errors.push_back(status.error_message()); + } + } + return NotFoundError(absl::StrCat("Suitable node shader is not found: ", + absl::StrJoin(errors, ", "))); + } + + private: + std::unordered_map>> + shaders_; +}; + +} // namespace + +std::unique_ptr NewNodeShaderRegistry() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/registry.h b/tensorflow/lite/delegates/gpu/gl/kernels/registry.h new file mode 100644 index 00000000000..009a9283afa --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/registry.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_REGISTRY_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_REGISTRY_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewNodeShaderRegistry(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_REGISTRY_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc b/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc new file mode 100644 index 00000000000..c00b9f616f5 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/relu.cc @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/relu.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class ReLU : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto attr = absl::any_cast(ctx.node->operation.attributes); + // clamp(value, min(0, alpha * value), clip) + std::vector params; + std::string min; + if (attr.alpha == 0) { + min = "vec4(0.0)"; + } else { + min = "min($alpha$ * value_0, 0.0)"; + params.push_back({"alpha", attr.alpha}); + } + std::string code; + if (attr.clip == 0) { + code = "value_0 = max(value_0, " + min + ");"; + } else { + code = "value_0 = clamp(value_0, " + min + ", vec4($clip$));"; + params.push_back({"clip", attr.clip}); + } + *generated_code = { + /*parameters=*/std::move(params), + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(code), + /*input=*/IOStructure::AUTO, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewReLUNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/relu.h b/tensorflow/lite/delegates/gpu/gl/kernels/relu.h new file mode 100644 index 00000000000..fdc812ba662 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/relu.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RELU_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RELU_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewReLUNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RELU_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/relu_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/relu_test.cc new file mode 100644 index 00000000000..8807b228486 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/relu_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/relu.h" + +#include +#include +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/test_util.h" + +using ::testing::FloatNear; +using ::testing::Pointwise; + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class ReluTest : public ::testing::Test { + public: + ReluTest() = default; + ~ReluTest() override = default; + + TensorRefFloat32 GetTensorRef(int ref) { + TensorRefFloat32 tensor_ref; + tensor_ref.type = DataType::FLOAT32; + tensor_ref.ref = ref; + tensor_ref.shape = BHWC(1, 2, 2, 1); + return tensor_ref; + } +}; + +TEST_F(ReluTest, Smoke) { + OperationType op_type = OperationType::RELU; + ReLUAttributes attr; + attr.clip = 0; + attr.alpha = 0; + SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); + ASSERT_TRUE(model.Invoke(*NewReLUNodeShader())); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 0.0, 2.0, 8.0})); +} + +TEST_F(ReluTest, ClipOnly) { + OperationType op_type = OperationType::RELU; + ReLUAttributes attr; + attr.clip = 6; + attr.alpha = 0; + SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); + ASSERT_TRUE(model.Invoke(*NewReLUNodeShader())); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {0.0, 0.0, 2.0, 6.0})); +} + +TEST_F(ReluTest, AlphaOnly) { + OperationType op_type = OperationType::RELU; + ReLUAttributes attr; + attr.clip = 0; + attr.alpha = 0.5; + SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); + ASSERT_TRUE(model.Invoke(*NewReLUNodeShader())); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {-3.0, 0.0, 2.0, 8.0})); +} + +TEST_F(ReluTest, ClipAndAlpha) { + OperationType op_type = OperationType::RELU; + ReLUAttributes attr; + attr.clip = 6; + attr.alpha = 0.5; + SingleOpModel model({ToString(op_type), attr}, {GetTensorRef(0)}, + {GetTensorRef(1)}); + ASSERT_TRUE(model.PopulateTensor(0, {-6.0, 0.0, 2.0, 8.0})); + ASSERT_TRUE(model.Invoke(*NewReLUNodeShader())); + EXPECT_THAT(model.GetOutput(0), + Pointwise(FloatNear(1e-6), {-3.0, 0.0, 2.0, 6.0})); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc b/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc new file mode 100644 index 00000000000..f2c0dc50e0b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/reshape.cc @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/gl/kernels/reshape.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+namespace {
+
+class Reshape : public NodeShader {
+ public:
+  Status GenerateCode(const GenerationContext& ctx,
+                      GeneratedCode* generated_code) const final {
+    auto input = ctx.graph->FindInputs(ctx.node->id)[0];
+    auto output = ctx.graph->FindOutputs(ctx.node->id)[0];
+    if (input->tensor.shape.DimensionsProduct() !=
+        output->tensor.shape.DimensionsProduct()) {
+      return InvalidArgumentError(
+          "Dimensions products of input and output tensors don't match.");
+    }
+    auto attr =
+        absl::any_cast<ReshapeAttributes>(ctx.node->operation.attributes);
+    if (attr.new_shape != output->tensor.shape) {
+      return InvalidArgumentError(
+          "Dimensions for output do not match new_shape attribute.");
+    }
+
+    // For every output element, compute its linear index in HWC order and map
+    // that index back to (x, y, channel) coordinates of the input.
+    std::string code = R"(
+    int input_ch_w = $input_channels$ * $input_data_0_w$;
+    int output_ch_w = $output_channels$ * $output_data_0_w$;
+    for (int i = 0; i < 4; ++i) {
+      int dst_channel = gid.z * 4 + i;
+      if (dst_channel >= $output_channels$) {
+        continue;
+      }
+      int p = dst_channel + $output_channels$ * gid.x + output_ch_w * gid.y;
+      int src_y = p / input_ch_w;
+      int src_x = (p % input_ch_w) / $input_channels$;
+      int src_z = (p % input_ch_w) % $input_channels$;
+      int src_layer = src_z / 4;
+      int src_channel = src_z % 4;
+      value_0[i] = $input_data_0[src_x, src_y, src_layer]$[src_channel];
+    }
+    )";
+    *generated_code = {
+        /*parameters=*/{
+            {"output_data_0_w", output->tensor.shape.w},
+            {"input_data_0_w", input->tensor.shape.w},
+            {"input_channels", input->tensor.shape.c},
+            {"output_channels", output->tensor.shape.c},
+        },
+        /*objects=*/{},
+        /*workload=*/uint3(),
+        /*workgroup=*/uint3(),
+        /*source_code=*/std::move(code),
+        /*input=*/IOStructure::ONLY_DEFINITIONS,
+        /*output=*/IOStructure::AUTO,
+    };
+    return OkStatus();
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<NodeShader> NewReshapeNodeShader() {
+  return absl::make_unique<Reshape>();
+}
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/reshape.h b/tensorflow/lite/delegates/gpu/gl/kernels/reshape.h
new file mode 100644
index 00000000000..b2b0914a6d7
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/reshape.h
@@ -0,0 +1,34 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RESHAPE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RESHAPE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewReshapeNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_RESHAPE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc b/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc new file mode 100644 index 00000000000..66f9abb6b90 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/slice.cc @@ -0,0 +1,120 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/slice.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class Slice : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; + + auto attr = + absl::any_cast(ctx.node->operation.attributes); + + const int4 channels(attr.starts.c, attr.strides.c, attr.ends.c, 0); + const int4 heights(attr.starts.h, attr.strides.h, attr.ends.h, 0); + const int4 widths(attr.starts.w, attr.strides.w, attr.ends.w, 0); + + std::vector parameters = { + {"channels", channels}, + {"heights", heights}, + {"widths", widths}, + {"dst_size", output->tensor.shape.c}, + }; + + std::string code; + code += " ivec2 offset;\n"; + if (attr.strides.w > 0) { + code += " offset.x = $widths.x$;\n"; + } else { + if (attr.ends.w > 0) { + code += " offset.x = $widths.z$;\n"; + } else { + code += " offset.x = $src_size.x$ + $widths.z$;\n"; + } + } + if (attr.strides.h > 0) { + code += " offset.y = $heights.x$;\n"; + } else { + if (attr.ends.h > 0) { + code += " offset.y = $heights.z$;\n"; + } else { + code += " offset.y = src_height + $heights.z$;\n"; + } + } + code += " ivec2 stride = ivec2($widths.y$, $heights.y$);\n"; + code += " ivec2 coord = offset + ivec2(gid.xy) * stride;\n"; + code += " bool outside = false;\n"; + code += " int step = gid.z * 4;\n"; + code += " int buffer_index = 0;\n"; + code += " int addr = 0;\n"; + for (int i = 0; i < 4; i++) { + code += " addr = step * $channels.y$;\n"; + if (attr.strides.c > 0) { + code += " addr += $channels.x$;\n"; + } else { + if (attr.ends.c > 0) { + code += " addr += $channels.z$;\n"; + } else { + code += " addr += src_channels + $channels.z$;\n"; + } + } + code += " if (step < $dst_size$) {\n value_0[" + 
+ std::to_string(i) + + "] = $input_data_0[coord.x, coord.y, addr / 4]$[addr % 4];\n " + " }\n"; + if (i != 3) { + code += " step++;\n"; + } + } + + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(code), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewSliceNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/slice.h b/tensorflow/lite/delegates/gpu/gl/kernels/slice.h new file mode 100644 index 00000000000..bf93043f578 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/slice.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SLICE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SLICE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewSliceNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SLICE_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc new file mode 100644 index 00000000000..000f2b00c5a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.cc @@ -0,0 +1,96 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+
+namespace tflite {
+namespace gpu {
+namespace gl {
+namespace {
+
+class SoftMax : public NodeShader {
+ public:
+  Status GenerateCode(const GenerationContext& ctx,
+                      GeneratedCode* generated_code) const final {
+    auto input = ctx.graph->FindInputs(ctx.node->id)[0];
+    auto output = ctx.graph->FindOutputs(ctx.node->id)[0];
+    auto attr =
+        absl::any_cast<SoftMaxAttributes>(ctx.node->operation.attributes);
+    if (input->tensor.shape != output->tensor.shape) {
+      return InvalidArgumentError("Input and output shapes do not match.");
+    }
+    if (attr.axis != Axis::CHANNELS) {
+      return UnimplementedError("Softmax is only supported for channels axis.");
+    }
+
+    // Channels are packed four per vec4; the last vec4 may be only partially
+    // used, so mask out the unused lanes when accumulating the sum.
+    float4 mask(0.0f);
+    const int channels = output->tensor.shape.c;
+    const int remainder = (channels % 4 == 0) ? 4 : channels % 4;
+    for (int i = 0; i < remainder; ++i) {
+      mask[i] = 1.0f;
+    }
+    std::vector<UniformParameter> parameters = {
+        {"src_depth", IntegralDivideRoundUp(output->tensor.shape.c, 4)},
+        {"mask", mask},
+    };
+
+    std::string source = R"(
+  highp float sum = 0.0;
+  for (int d = 0; d < $src_depth$ - 1; ++d) {
+    sum += dot(vec4(1.0), exp($input_data_0[gid.x, gid.y, d]$));
+  }
+  {
+    int d = $src_depth$ - 1;
+    sum += dot($mask$, exp($input_data_0[gid.x, gid.y, d]$));
+  }
+  for (int d = 0; d < $src_depth$; ++d) {
+    vec4 temp_sum = exp($input_data_0[gid.x, gid.y, d]$) / sum;
+    $output_data_0[gid.x, gid.y, d] = temp_sum$;
+  }
+)";
+    *generated_code = {
+        /*parameters=*/std::move(parameters),
+        /*objects=*/{},
+        /*workload=*/uint3(output->tensor.shape.w, output->tensor.shape.h, 1),
+        /*workgroup=*/uint3(),
+        /*source_code=*/std::move(source),
+        /*input=*/IOStructure::ONLY_DEFINITIONS,
+        /*output=*/IOStructure::ONLY_DEFINITIONS,
+    };
+    return OkStatus();
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<NodeShader> NewSoftMaxNodeShader() {
+  return absl::make_unique<SoftMax>();
+}
+
+}  // namespace gl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/softmax.h b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.h
new file mode 100644
index 00000000000..2eaf91b6157
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/softmax.h
@@ -0,0 +1,34 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SOFTMAX_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SOFTMAX_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewSoftMaxNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_SOFTMAX_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc new file mode 100644 index 00000000000..53be4fc4df7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.cc @@ -0,0 +1,132 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/test_util.h" + +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/gl/api.h" +#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/object_manager.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h" + +namespace tflite { +namespace gpu { +namespace gl { + +SingleOpModel::SingleOpModel(Operation&& operation, + const std::vector& inputs, + const std::vector& outputs) { + auto node = graph_.NewNode(); + node->operation = std::move(operation); + + for (int i = 0; i < inputs.size(); ++i) { + auto input = graph_.NewValue(); + input->tensor = inputs[i]; + graph_.AddConsumer(node->id, input->id).IgnoreError(); + TensorFloat32 tensor; + tensor.id = input->tensor.ref; + tensor.shape = input->tensor.shape; + inputs_.emplace_back(std::move(tensor)); + } + + for (int i = 0; i < outputs.size(); ++i) { + auto output = graph_.NewValue(); + output->tensor = outputs[i]; + graph_.SetProducer(node->id, output->id).IgnoreError(); + } +} + +bool SingleOpModel::PopulateTensor(int index, std::vector&& data) { + if (index >= inputs_.size() || + inputs_[index].shape.DimensionsProduct() != data.size()) { + return false; + } + inputs_[index].data = std::move(data); + return true; +} + +Status SingleOpModel::InvokeInternal(const CompilationOptions& compile_options, + const RuntimeOptions& runtime_options, + const NodeShader& shader) { + std::unique_ptr env; + RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env)); + + ObjectManager objects; + + // Create buffers for input tensors. 
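+  // Inputs are converted to PHWC4 layout and registered under the value id
+  // that the graph assigned to the corresponding input tensor.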
+ { + std::unordered_map tensor_to_id; + for (const auto* input : graph_.inputs()) { + tensor_to_id[input->tensor.ref] = input->id; + } + for (const auto& input : inputs_) { + GlBuffer buffer; + RETURN_IF_ERROR(CreatePHWC4BufferFromTensor(input, &buffer)); + RETURN_IF_ERROR( + objects.RegisterBuffer(tensor_to_id[input.id], std::move(buffer))); + } + } + + // Create buffers for output tensors. + for (const auto* output : graph_.outputs()) { + GlBuffer buffer; + RETURN_IF_ERROR(CreatePHWC4BufferFromTensorRef(output->tensor, &buffer)); + RETURN_IF_ERROR(objects.RegisterBuffer(output->id, std::move(buffer))); + } + + // Compile model. + GpuInfo gpu_info; + RETURN_IF_ERROR(RequestGpuInfo(&gpu_info)); + std::unique_ptr compiled_model; + RETURN_IF_ERROR(Compile(compile_options, graph_, shader, + *NewDefaultWorkgroupsCalculator(gpu_info), + &compiled_model)); + + // Get inference context. + auto command_queue = NewCommandQueue(gpu_info); + std::unique_ptr inference_context; + RETURN_IF_ERROR(compiled_model->NewRun( + runtime_options, &objects, command_queue.get(), &inference_context)); + RETURN_IF_ERROR(inference_context->Reset()); + + // Run inference. + RETURN_IF_ERROR(inference_context->Execute()); + + // Copy output tensors to `output_`. + for (const auto* output : graph_.outputs()) { + TensorFloat32 tensor; + tensor.id = output->tensor.ref; + tensor.shape = output->tensor.shape; + tensor.data.reserve(output->tensor.shape.DimensionsProduct()); + RETURN_IF_ERROR( + CopyFromPHWC4Buffer(*objects.FindBuffer(output->id), &tensor)); + outputs_.push_back(std::move(tensor)); + } + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h new file mode 100644 index 00000000000..b928402d263 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/test_util.h @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_TEST_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_TEST_UTIL_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace gl { + +class SingleOpModel { + public: + SingleOpModel() = delete; + SingleOpModel(Operation&& operation, + const std::vector& inputs, + const std::vector& outputs); + + virtual ~SingleOpModel() = default; + + bool PopulateTensor(int index, std::vector&& data); + + bool Invoke(const NodeShader& shader) { + return InvokeInternal(CompilationOptions(), RuntimeOptions(), shader).ok(); + } + + const std::vector& GetOutput(int index) const { + return outputs_[index].data; + } + + protected: + GraphFloat32 graph_; + std::vector inputs_; + std::vector outputs_; + + private: + Status InvokeInternal(const CompilationOptions& compile_options, + const RuntimeOptions& runtime_options, + const NodeShader& shader); +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc new file mode 100644 index 00000000000..f3719c5751c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class ConvolutionTransposedBuffers : public NodeShader { + public: + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto input = ctx.graph->FindInputs(ctx.node->id)[0]; + auto attr = absl::any_cast( + ctx.node->operation.attributes); + auto weights = attr.weights.shape; + const int32_t inner_size_w = (weights.w - 1) / attr.stride.w + 1; + const int32_t inner_size_h = (weights.h - 1) / attr.stride.h + 1; + + std::vector parameters = { + {"input_data_0_h", input->tensor.shape.h}, + {"input_data_0_w", input->tensor.shape.w}, + {"src_depth", IntegralDivideRoundUp(weights.i, 4)}, + {"kernel_size", int2(weights.w, weights.h)}, + {"stride", int2(attr.stride.w, attr.stride.h)}, + {"padding", int2(attr.padding.prepended.w, attr.padding.prepended.h)}, + {"inner_size", int2(inner_size_w, inner_size_h)}, + }; + + std::vector> objects = { + {"weights", MakeReadonlyObject(Get3DSizeForPHWO4I4(attr.weights.shape), + ConvertToPHWO4I4(attr.weights))}}; + + std::string source = R"( + ivec2 kernel_offset = $kernel_size$ - ivec2(1,1); + ivec2 offset = gid.xy + $padding$ - kernel_offset; + offset %= $stride$; + offset += $stride$; + offset %= $stride$; + ivec2 f_offset; + f_offset.x = offset.x == 0 ? 0 : ($stride.x$ - offset.x); + f_offset.y = offset.y == 0 ? 
0 : ($stride.y$ - offset.y); + for (int ky = 0; ky < $inner_size.y$; ++ky) { + for (int kx = 0; kx < $inner_size.x$; ++kx) { + ivec2 index = ivec2(kx, ky) * $stride$ + f_offset; + bool inside_kernel = index.x < $kernel_size.x$ && index.y < $kernel_size.y$; + ivec2 coord = (gid.xy + index + $padding$ - kernel_offset) / $stride$; + bool outside = coord.x < 0 || coord.y < 0 || + coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$; + if (inside_kernel && !outside) { + index = kernel_offset - index; + int i = index.y * $kernel_size.x$ + index.x; + for (int l = 0; l < $src_depth$; ++l) { + vec4 src_color = $input_data_0[coord.x, coord.y, l]$; + value_0.x += dot(src_color, $weights[l * 4 + 0, i, gid.z]$); + value_0.y += dot(src_color, $weights[l * 4 + 1, i, gid.z]$); + value_0.z += dot(src_color, $weights[l * 4 + 2, i, gid.z]$); + value_0.w += dot(src_color, $weights[l * 4 + 3, i, gid.z]$); + } + } + } + } +)"; + if (!attr.bias.data.empty()) { + source += "value_0 += $bias[gid.z]$;\n"; + objects.push_back({"bias", MakeReadonlyObject(attr.bias.data)}); + } + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/std::move(objects), + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/source, + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewConvolutionTransposedNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h new file mode 100644 index 00000000000..553704bf68b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_TRANSPOSE_CONV_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_TRANSPOSE_CONV_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewConvolutionTransposedNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_TRANSPOSE_CONV_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.cc b/tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.cc new file mode 100644 index 00000000000..baca806b79a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.cc @@ -0,0 +1,120 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class UpsamplingBilinear : public NodeShader { + public: + UpsamplingBilinear() {} + + Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const final { + auto input = ctx.graph->FindInputs(ctx.node->id)[0]; + auto output = ctx.graph->FindOutputs(ctx.node->id)[0]; + auto attr = + absl::any_cast(ctx.node->operation.attributes); + + if (input->tensor.shape.w > output->tensor.shape.w || + input->tensor.shape.h > output->tensor.shape.h) { + return InvalidArgumentError("Output size is less than input size."); + } + if (output->tensor.shape.w != attr.new_shape.w || + output->tensor.shape.h != attr.new_shape.h) { + return InvalidArgumentError( + "Output size does not match new_size in attributes."); + } + if (input->tensor.shape.c != output->tensor.shape.c) { + return InvalidArgumentError("Input/output channels mismatch."); + } + if (attr.type != UpsamplingType::BILINEAR) { + return UnimplementedError("Upsample2D supports only bilinear type."); + } + if (input->tensor.shape.h == 1 && input->tensor.shape.w == 1) { + // Copy a single element from input. 
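+      // With a 1x1 input there is nothing to interpolate, so every output
+      // pixel is a copy of that single value.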
+ *generated_code = { + /*parameters=*/{}, + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/"value_0 = $input_data_0[0, 0, gid.z]$;", + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } + std::vector parameters = { + {"input_data_0_h", input->tensor.shape.h}, + {"input_data_0_w", input->tensor.shape.w}, + {"scale_factor", + float2(CalculateResizeScale(input->tensor.shape.w, + output->tensor.shape.w, attr), + CalculateResizeScale(input->tensor.shape.h, + output->tensor.shape.h, attr))}, + }; + + std::string source = R"( + vec2 coord = vec2(gid.xy) * $scale_factor$; + + ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1); + ivec4 st; + st.xy = ivec2(coord); + st.zw = min(st.xy + ivec2(1, 1), borders); + + vec2 t = coord - vec2(st.xy); //interpolating factors + + vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$; + vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$; + vec4 tex12 = $input_data_0[st.x, st.w, gid.z]$; + vec4 tex22 = $input_data_0[st.z, st.w, gid.z]$; + + value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y); +)"; + *generated_code = { + /*parameters=*/std::move(parameters), + /*objects=*/{}, + /*workload=*/uint3(), + /*workgroup=*/uint3(), + /*source_code=*/std::move(source), + /*input=*/IOStructure::ONLY_DEFINITIONS, + /*output=*/IOStructure::AUTO, + }; + return OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewUpsamplingNodeShader() { + return absl::make_unique(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.h b/tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.h new file mode 100644 index 00000000000..702110ba7d8 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/kernels/upsampling_bilinear.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_UPSAMPLING_BILINEAR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_UPSAMPLING_BILINEAR_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr NewUpsamplingNodeShader(); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_UPSAMPLING_BILINEAR_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/metadata.fbs b/tensorflow/lite/delegates/gpu/gl/metadata.fbs new file mode 100644 index 00000000000..088d824a8ea --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/metadata.fbs @@ -0,0 +1,11 @@ +include "workgroups.fbs"; + +namespace tflite.gpu.gl.data; + +file_identifier "AFFL"; + +table FlowMetadata { + workgroups:CustomWorkgroups; +} + +root_type FlowMetadata; diff --git a/tensorflow/lite/delegates/gpu/gl/node_shader.h b/tensorflow/lite/delegates/gpu/gl/node_shader.h new file mode 100644 index 00000000000..20491272e35 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/node_shader.h @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_NODE_SHADER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_NODE_SHADER_H_ + +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { + +enum class IOStructure { + // Source code uses standard inputs or outputs that should be generated from + // node inputs/outputs. Compiler will generate them automatically as + // 'input_data_N'/'output_data_N', where N is an index of the input/output. + // + // Generated code should not return input objects. + ONLY_DEFINITIONS, + + // For inputs: + // Source code runs computations using 'vec4 value_N' declared by + // the compiler, where where N is an index of the input. Each value comes + // from inputs using coordinates set by GlobalInvocationID and a dispatch + // method, therefore, source code should not explicitly read values. + // + // For outputs: + // Source code runs computations and leaves results in 'vec4 value_N' + // declared by the compiler, where N is an index of the output. Value will + // be written to the output using coordinates set by GlobalInvocationID and + // a dispatch method. 
Therefore, source code should not explicitly write + // results. + AUTO, +}; + +struct GeneratedCode { + // A list of parameters to be set as uniform or hardcoded in a shader. + std::vector parameters; + + // A list of objects to bind before shader could be executed. + std::vector> objects; + + // Compute shader operate on an abstract concept of work groups, each + // three-dimensional. The number of work groups to be executed is defined by + // workload tuple. Therefore, + // workload[x,y,z] := workgroup_size[x,y,z] X workgroup_count[x,y,z] + // where 'X' is element-wise multiplication. + // + // Zero workload is calculated as PHWC4 based on output tensor. + uint3 workload; + + // operation may specify recommended workgroup size. If not set, runtime will + // figure it out automatically. + uint3 workgroup; + + std::string source_code; + + // Parameters below reveal additional information about source_code. + + IOStructure input; + IOStructure output; +}; + +// A class handles shader generation and setting runtime shader parameters. +class NodeShader { + public: + virtual ~NodeShader() = default; + + // A context for generating a code. + struct GenerationContext { + const GraphFloat32* graph; + const GpuInfo* gpu_info; + const Node* node; + CompilationOptions compiler_options; + }; + + // Generates shader code for a node. The code should be just a function body. + virtual Status GenerateCode(const GenerationContext& ctx, + GeneratedCode* generated_code) const = 0; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_NODE_SHADER_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/object.h b/tensorflow/lite/delegates/gpu/gl/object.h new file mode 100644 index 00000000000..15c83d690a3 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/object.h @@ -0,0 +1,190 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_OBJECT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_OBJECT_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace gl { + +enum class AccessType { + READ, + WRITE, + READ_WRITE, +}; + +using ObjectData = std::vector; + +// Generic identifier to be used to lookup an object. +using ObjectRef = uint32_t; + +constexpr ObjectRef kInvalidObjectRef = ~0; + +enum class ObjectType : int { + UNKNOWN = 0, + TEXTURE = 1, + BUFFER = 2, +}; + +using ObjectSize = absl::variant; + +// An object represents a reference to or pre-defined constant OpenGL Buffer or +// Texture. 
NodeShader is supposed to set all fields but leave binding = 0 +// that will be set later by a compiler. +struct Object { + AccessType access; + + DataType data_type; + + ObjectType object_type; + + // OpenGL-specific binding information + uint32_t binding; + + // Indicates size of 1D, 2D or 3D object in elements, where single element + // consists of 4 values. + ObjectSize size; + + absl::variant object; +}; + +// @return true if object is a reference. +inline bool IsRef(const Object& object) { + return !absl::get_if(&object.object); +} + +inline ObjectRef GetRef(const Object& object) { + auto ref = absl::get_if(&object.object); + return ref ? *ref : kInvalidObjectRef; +} + +inline const ObjectData* GetData(const Object& object) { + return absl::get_if(&object.object); +} + +inline size_t ByteSizeOf(const Object& object); + +// @return object that references an object created externally. +template +inline Object MakeObjectRef(ObjectRef unique_id, const SizeT& size, + AccessType access_type) { + return Object{access_type, DataType::FLOAT32, ObjectType::UNKNOWN, 0, + size, unique_id}; +} + +namespace internal_object { + +template +std::vector ToBytesVector(const std::vector& data, + size_t alignment) { + std::vector t(AlignByN(data.size() * sizeof(T), alignment)); + std::memcpy(t.data(), data.data(), data.size() * sizeof(T)); + return t; +} + +struct ObjectSizer { + size_t operator()(const uint3& size) const { + return size.x * size.y * size.z; + } + + size_t operator()(const uint2& size) const { return size.x * size.y; } + + size_t operator()(uint32_t size) const { return size; } +}; + +} // namespace internal_object + +inline size_t NumElements(const ObjectSize& size) { + return absl::visit(internal_object::ObjectSizer{}, size); +} + +inline size_t ByteSizeOf(const Object& object) { + return SizeOf(object.data_type) * /* vec4 */ 4 * NumElements(object.size); +} + +template +Object MakeReadonlyObject(const SizeT& size, const std::vector& data) { + return Object{AccessType::READ, + DataType::FLOAT32, + ObjectType::UNKNOWN, + 0, + size, + internal_object::ToBytesVector(data, 16)}; +} + +template +Object MakeReadonlyTexture(const SizeT& size, const std::vector& data) { + return Object{AccessType::READ, + DataType::FLOAT32, + ObjectType::TEXTURE, + 0, + size, + internal_object::ToBytesVector(data, 16)}; +} + +template +Object MakeReadonlyBuffer(const SizeT& size, const std::vector& data) { + return Object{AccessType::READ, + DataType::FLOAT32, + ObjectType::BUFFER, + 0, + size, + internal_object::ToBytesVector(data, 16)}; +} + +inline Object MakeReadonlyObject(const std::vector& data) { + return MakeReadonlyObject(IntegralDivideRoundUp(data.size(), 4U), data); +} + +inline Object MakeReadonlyTexture(const std::vector& data) { + return MakeReadonlyTexture(IntegralDivideRoundUp(data.size(), 4U), data); +} + +inline Object MakeReadonlyBuffer(const std::vector& data) { + return MakeReadonlyBuffer(IntegralDivideRoundUp(data.size(), 4U), data); +} + +// TODO(akulik): find better place for functions below. 
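+// A minimal usage sketch for the helpers above, assuming float32 host data
+// (the concrete values and ids below are hypothetical):
+//
+//   std::vector<float> weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+//   // 5 floats are padded up to 2 vec4 elements.
+//   Object buf = MakeReadonlyBuffer(weights);
+//   size_t bytes = ByteSizeOf(buf);  // 4 (vec4) * sizeof(float) * 2 = 32
+//
+//   // Reference to an externally provided 16x16 2D object with id 7.
+//   Object ref = MakeObjectRef(/*unique_id=*/7, uint2(16, 16),
+//                              AccessType::READ);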
+ +inline uint3 GetPHWC4Size(const BHWC& shape) { + uint3 size; + size.x = shape.w; + size.y = shape.h; + size.z = shape.b * IntegralDivideRoundUp(shape.c, 4); + return size; +} + +inline Object MakePHWC4Ref(uint32_t global_id, const BHWC& shape) { + return MakeObjectRef(global_id, GetPHWC4Size(shape), AccessType::READ_WRITE); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_OBJECT_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/object_manager.cc b/tensorflow/lite/delegates/gpu/gl/object_manager.cc new file mode 100644 index 00000000000..49d5ee9b56c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/object_manager.cc @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/object_manager.h" + +#include "absl/memory/memory.h" +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace gl { + +Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, + GlBuffer* gl_buffer) { + std::vector transposed(GetElementsSizeForPHWC4(tensor.shape)); + RETURN_IF_ERROR( + ConvertToPHWC4(tensor.data, tensor.shape, absl::MakeSpan(transposed))); + return CreateReadOnlyShaderStorageBuffer(transposed, gl_buffer); +} + +Status CreatePHWC4BufferFromTensorRef(const TensorRefFloat32& tensor_ref, + GlBuffer* gl_buffer) { + return CreateReadWriteShaderStorageBuffer( + GetElementsSizeForPHWC4(tensor_ref.shape), gl_buffer); +} + +Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor) { + return buffer.MappedRead( + [tensor, &buffer](absl::Span data) { + tensor->data.resize(tensor->shape.DimensionsProduct()); + return ConvertFromPHWC4(absl::MakeConstSpan(data), tensor->shape, + absl::MakeSpan(tensor->data)); + }); +} + +Status ObjectManager::RegisterBuffer(uint32_t id, GlBuffer buffer) { + if (id < buffers_.size()) { + if (buffers_[id]) { + return AlreadyExistsError( + "Buffer with the same id is already registered: " + + std::to_string(id)); + } + } else { + buffers_.resize(id + 1); + } + buffers_[id] = absl::make_unique(std::move(buffer)); + return OkStatus(); +} + +void ObjectManager::RemoveBuffer(uint32_t id) { + if (id < buffers_.size()) { + buffers_[id].reset(nullptr); + } +} + +GlBuffer* ObjectManager::FindBuffer(uint32_t id) const { + return id >= buffers_.size() ? 
nullptr : buffers_[id].get(); +} + +Status ObjectManager::RegisterTexture(uint32_t id, GlTexture texture) { + if (id < textures_.size()) { + if (textures_[id]) { + return AlreadyExistsError( + "Texture with the same id is already registered: " + + std::to_string(id)); + } + } else { + textures_.resize(id + 1); + } + textures_[id] = absl::make_unique(std::move(texture)); + return OkStatus(); +} + +void ObjectManager::RemoveTexture(uint32_t id) { + if (id < textures_.size()) { + textures_[id].reset(nullptr); + } +} + +GlTexture* ObjectManager::FindTexture(uint32_t id) const { + return id >= textures_.size() ? nullptr : textures_[id].get(); +} + +ObjectsStats ObjectManager::stats() const { + ObjectsStats stats; + for (auto& texture : textures_) { + if (!texture || !texture->has_ownership()) continue; + stats.textures.count++; + stats.textures.total_bytes += texture->bytes_size(); + } + for (auto& buffer : buffers_) { + if (!buffer || !buffer->has_ownership()) continue; + stats.buffers.count++; + stats.buffers.total_bytes += buffer->bytes_size(); + } + return stats; +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/object_manager.h b/tensorflow/lite/delegates/gpu/gl/object_manager.h new file mode 100644 index 00000000000..6aee0a699ea --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/object_manager.h @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_OBJECT_MANAGER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_OBJECT_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" +#include "tensorflow/lite/delegates/gpu/gl/stats.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// ObjectManager is a registry that owns corresponding objects and provides +// discovery functionality. All objects are kept until manager is destroyed. +// +// All buffers and textures share the same id space, therefore, it is an error +// to register two objects with the same id. +// TODO(akulik): make ObjectManager templated by object type. +class ObjectManager { + public: + // Moves ownership over the given buffer to the manager. + Status RegisterBuffer(uint32_t id, GlBuffer buffer); + + void RemoveBuffer(uint32_t id); + + // Return a permanent pointer to a buffer for the given id or nullptr. + GlBuffer* FindBuffer(uint32_t id) const; + + // Moves ownership over the given texture to the manager. + Status RegisterTexture(uint32_t id, GlTexture texture); + + void RemoveTexture(uint32_t id); + + // Return a permanent pointer to a texture for the given id or nullptr. 
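+  // The pointer stays valid until the texture is removed or the manager is
+  // destroyed.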
+ GlTexture* FindTexture(uint32_t id) const; + + ObjectsStats stats() const; + + private: + std::vector> buffers_; + std::vector> textures_; +}; + +// TODO(akulik): find better place for functions below. + +// Creates read-only buffer from the given tensor. Tensor data is converted to +// PHWC4 layout. +Status CreatePHWC4BufferFromTensor(const TensorFloat32& tensor, + GlBuffer* gl_buffer); + +// Creates read-write buffer for the given tensor shape, where data layout is +// supposed to be PHWC4. +Status CreatePHWC4BufferFromTensorRef(const TensorRefFloat32& tensor_ref, + GlBuffer* gl_buffer); + +// Copies data from a buffer that holds data in PHWC4 layout to the given +// tensor. +Status CopyFromPHWC4Buffer(const GlBuffer& buffer, TensorFloat32* tensor); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_OBJECT_MANAGER_H_ diff --git a/tensorflow/compiler/xla/python/xla.i b/tensorflow/lite/delegates/gpu/gl/portable_egl.h similarity index 68% rename from tensorflow/compiler/xla/python/xla.i rename to tensorflow/lite/delegates/gpu/gl/portable_egl.h index 1c4021a558d..7be19851758 100644 --- a/tensorflow/compiler/xla/python/xla.i +++ b/tensorflow/lite/delegates/gpu/gl/portable_egl.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -/* XLA-wide SWIG wrapper */ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_PORTABLE_EGL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_PORTABLE_EGL_H_ -%include "tensorflow/compiler/xla/python/local_computation_builder.i" +#include +#include + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_PORTABLE_EGL_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/portable_gl31.h b/tensorflow/lite/delegates/gpu/gl/portable_gl31.h new file mode 100644 index 00000000000..a3d03bf1058 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/portable_gl31.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_PORTABLE_GL31_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_PORTABLE_GL31_H_ + +#define HAS_EGL 1 + +#include +#include +#include + +#ifdef __ANDROID__ +// Weak-link all GL APIs included from this point on. +// TODO(camillol): Annotate these with availability attributes for the +// appropriate versions of Android, by including gl{3,31,31}.h and resetting +// GL_APICALL for each. 
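+// (Weak linking keeps the binary loadable on devices where some of these
+// GLES 3.1 entry points are unavailable.)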
+#undef GL_APICALL +#define GL_APICALL __attribute__((weak_import)) KHRONOS_APICALL +#endif // __ANDROID__ + +#include + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_PORTABLE_GL31_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.cc b/tensorflow/lite/delegates/gpu/gl/runtime.cc new file mode 100644 index 00000000000..c09512c55e7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/runtime.cc @@ -0,0 +1,605 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/runtime.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_errors.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +struct TextureF16Maker { + Status operator()(const uint3& size) const { + return CreateReadOnlyImageTextureF16(size, data, gl_texture); + } + Status operator()(const uint2& size) const { + return CreateReadOnlyImageTextureF16(size, data, gl_texture); + } + Status operator()(const uint32_t& size) const { + return CreateReadOnlyImageTextureF16(uint2(size, 1U), data, gl_texture); + } + absl::Span data; + GlTexture* gl_texture; +}; + +struct TextureF32Maker { + Status operator()(const uint3& size) const { + return CreateReadOnlyImageTexture(size, data, gl_texture); + } + Status operator()(const uint2& size) const { + return CreateReadOnlyImageTexture(size, data, gl_texture); + } + Status operator()(const uint32_t& size) const { + return CreateReadOnlyImageTexture(uint2(size, 1U), data, gl_texture); + } + absl::Span data; + GlTexture* gl_texture; +}; + +Status MakeGlTexture(const Object& object, const ObjectData& data, + GlTexture* gl_texture) { + if (object.access == AccessType::READ_WRITE || + object.access == AccessType::WRITE) { + return InvalidArgumentError("Read-write textures are not supported"); + } + if (object.data_type != DataType::FLOAT16 && + object.data_type != DataType::FLOAT32) { + return InvalidArgumentError("Textures support float16 or float32 only."); + } + switch (object.data_type) { + case DataType::FLOAT16: { + if (data.size() % 2 != 0) { + return InvalidArgumentError("Texture size is not aligned"); + } + return absl::visit( + TextureF16Maker{ + .data = absl::MakeConstSpan( + reinterpret_cast(data.data()), + data.size() / 2), + .gl_texture = gl_texture, + }, + object.size); + } + case DataType::FLOAT32: { + if (data.size() % sizeof(float) != 0) { + return InvalidArgumentError("Texture size is not aligned"); + } + return absl::visit( + 
TextureF32Maker{ + .data = absl::MakeConstSpan( + reinterpret_cast(data.data()), + data.size() / sizeof(float)), + .gl_texture = gl_texture, + }, + object.size); + } + default: + return InvalidArgumentError("Unsupported textures data type."); + } +} + +struct TextureRefMaker { + Status operator()(const uint3& size) const { + return CreateReadWriteRgbaImageTexture(type, size, gl_texture); + } + Status operator()(const uint2& size) const { + return CreateReadWriteRgbaImageTexture(type, size, gl_texture); + } + Status operator()(const uint32_t& size) const { + return CreateReadWriteRgbaImageTexture(type, uint2(size, 1U), gl_texture); + } + DataType type; + GlTexture* gl_texture; +}; + +// Makes read-write gl texture +Status MakeGlTextureRef(const Object& object, GlTexture* gl_texture) { + return absl::visit(TextureRefMaker{object.data_type, gl_texture}, + object.size); +} + +Status MakeGlBuffer(const Object& object, const ObjectData& data, + GlBuffer* gl_buffer) { + if (data.size() % SizeOf(object.data_type) != 0) { + return InvalidArgumentError("Buffer size is not aligned"); + } + return CreateReadOnlyShaderStorageBuffer(absl::MakeConstSpan(data), + gl_buffer); +} + +// Looks up an object with the given id. If found, makes a binding function. +Status MakeBindingFunc(const Object& object, uint32_t id, + const ObjectManager& objects, + std::function* binding_func) { + const uint32_t binding = object.binding; + switch (object.object_type) { + case ObjectType::BUFFER: { + auto ptr = objects.FindBuffer(id); + if (!ptr) { + return NotFoundError(absl::StrCat("Buffer ", id, " is not found")); + } + + // Validate buffer. + size_t size_in_bytes = ByteSizeOf(object); + // TODO(akulik): make comparison != instead of < + if (ptr->bytes_size() < size_in_bytes) { + return FailedPreconditionError( + absl::StrCat("Buffer ", id, " size in bytes ", ptr->bytes_size(), + " < requested size_in_bytes ", size_in_bytes)); + } + *binding_func = [=]() { return ptr->BindToIndex(binding); }; + break; + } + case ObjectType::TEXTURE: { + auto ptr = objects.FindTexture(id); + if (!ptr) { + return NotFoundError(absl::StrCat("Texture ", id, " is not found")); + } + *binding_func = [=]() { return ptr->BindAsReadWriteImage(binding); }; + break; + } + case ObjectType::UNKNOWN: + return InvalidArgumentError("Unknown object type"); + } + return OkStatus(); +} + +} // namespace + +Runtime::Runtime(const RuntimeOptions& options, const GpuInfo& gpu_info, + CommandQueue* command_queue, + const ObjectManager* external_objects) + : options_(options), + gpu_info_(gpu_info), + external_objects_(external_objects), + command_queue_(command_queue) { + programs_.reserve(256); + if (options_.bundle_readonly_objects) { + shared_readonly_buffer_ = absl::make_unique(); + } +} + +Status Runtime::AddProgram(const GlShader& shader, + const std::vector& parameters, + const std::vector& objects, + const uint3& num_workgroups) { + GlProgram program; + RETURN_IF_ERROR(GlProgram::CreateWithShader(shader, &program)); + + for (auto& parameter : parameters) { + RETURN_IF_ERROR(program.SetParameter(parameter)); + } + + programs_.emplace_back( + CompiledProgramDescriptor{std::move(program), num_workgroups, {}}); + + // Create const buffers, resolve external references and collect internal + // buffer references. + for (auto& object : objects) { + auto& program = programs_.back(); + BindFunc binding_func; + if (IsRef(object)) { + // Reference object could be provided externally as a model input/output + // but also for debugging purposes. 
Otherwise all references are collected + // and allocated later. + Status status = MakeBindingFunc(object, GetRef(object), + *external_objects_, &binding_func); + if (!status.ok()) { + if (status.code() == StatusCode::kNotFound) { + program.refs.push_back(object); + continue; // don't add to binding. + } + return status; + } + } else { + // Allocate const object. + uint32_t id; + RETURN_IF_ERROR(AllocateConstObject(object, &id)); + RETURN_IF_ERROR( + MakeBindingFunc(object, id, const_objects_, &binding_func)); + } + program.bindings.push_back(std::move(binding_func)); + } + + // All parameters once set stay with program, therefore, we only need to keep + // program and bindings for execution. + return OkStatus(); +} + +Status Runtime::AllocateInternalObject(const Object& object) { + const ObjectRef ref = GetRef(object); + switch (object.object_type) { + case ObjectType::BUFFER: { + GlBuffer gl_buffer; + RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( + ByteSizeOf(object), &gl_buffer)); + RETURN_IF_ERROR( + internal_objects_.RegisterBuffer(ref, std::move(gl_buffer))); + break; + } + case ObjectType::TEXTURE: { + GlTexture gl_texture; + RETURN_IF_ERROR(MakeGlTextureRef(object, &gl_texture)); + RETURN_IF_ERROR( + internal_objects_.RegisterTexture(ref, std::move(gl_texture))); + break; + } + default: + return InternalError("Unexpected internal object type"); + } + return OkStatus(); +} + +Status Runtime::AllocateConstObject(const Object& object, uint32_t* id) { + const ObjectData* data = GetData(object); + if (data == nullptr) { + return InternalError("Unable to allocate reference as a const object"); + } + *id = next_const_id_++; + switch (object.object_type) { + case ObjectType::BUFFER: { + GlBuffer gl_buffer; + if (!shared_readonly_buffer_ || + !shared_readonly_buffer_->Add(*data, &gl_buffer)) { + RETURN_IF_ERROR(MakeGlBuffer(object, *data, &gl_buffer)); + } + RETURN_IF_ERROR(const_objects_.RegisterBuffer(*id, std::move(gl_buffer))); + break; + } + case ObjectType::TEXTURE: { + GlTexture gl_texture; + RETURN_IF_ERROR(MakeGlTexture(object, *data, &gl_texture)); + RETURN_IF_ERROR( + const_objects_.RegisterTexture(*id, std::move(gl_texture))); + break; + } + case ObjectType::UNKNOWN: + return InternalError("Unknown object type"); + } + return OkStatus(); +} + +Status Runtime::PrepareForExecution() { + if (shared_readonly_buffer_ && !shared_readonly_buffer_->empty()) { + GlBuffer shared_buffer; + RETURN_IF_ERROR( + shared_readonly_buffer_->CreateSharedGlBuffer(&shared_buffer)); + shared_readonly_buffer_.reset(nullptr); + RETURN_IF_ERROR(const_objects_.RegisterBuffer(next_const_id_++, + std::move(shared_buffer))); + } + + if (options_.reuse_internal_objects) { + // Analyze internal objects and make a pool of shared objects to be re-used + // by them. These shared objects need to be allocated upfront. + std::vector shared_objects; + RETURN_IF_ERROR(AssignInternalObjects(&shared_objects)); + for (const Object& object : shared_objects) { + RETURN_IF_ERROR(AllocateInternalObject(object)); + } + } + + // Allocate all internal objects and create bindings for them. + for (auto& program : programs_) { + for (auto& object : program.refs) { + // Check whether it is created already. 
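+      // A NOT_FOUND status below means the reference has no allocated object
+      // yet; it is allocated on demand and the binding is retried.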
+ BindFunc binding; + ObjectRef ref = GetRef(object); + Status status = MakeBindingFunc(object, ref, internal_objects_, &binding); + if (!status.ok()) { + if (status.code() != StatusCode::kNotFound) { + return status; + } + RETURN_IF_ERROR(AllocateInternalObject(object)); + RETURN_IF_ERROR( + MakeBindingFunc(object, ref, internal_objects_, &binding)); + } + program.bindings.push_back(std::move(binding)); + } + program.refs.clear(); + } + return OkStatus(); +} + +namespace { + +struct FitSizeFunc { + bool operator()(const uint3& size) const { + auto s = absl::get_if(&b); + if (!s) return false; + *result = uint3(std::max(s->x, size.x), std::max(s->y, size.y), + std::max(s->z, size.z)); + return true; + } + + bool operator()(const uint2& size) const { + auto s = absl::get_if(&b); + if (!s) return false; + *result = uint2(std::max(s->x, size.x), std::max(s->y, size.y)); + return true; + } + + bool operator()(uint32_t size) const { + auto s = absl::get_if(&b); + if (!s) return false; + *result = std::max(*s, size); + return true; + } + + const ObjectSize& b; + ObjectSize* result; +}; + +// Makes new size which combines largest dimensions of both given sizes. +// +// @return false if sizes have different number of dimensions +bool FitSize(const ObjectSize& a, const ObjectSize& b, ObjectSize* result) { + return absl::visit(FitSizeFunc{b, result}, a); +} + +// Texture fitting policy is: +// - 1D: source texture will always fit into target because it is linear +// - 2D: source texture should fit without growing target texture +// - 3D: source texture should fit without growing target texture +// +struct TextureFitPolicy { + bool operator()(const uint3& size) const { + auto s = absl::get_if(&target); + return s && size.x <= s->x && size.y <= s->y && size.z <= s->z; + } + + bool operator()(const uint2& size) const { + auto s = absl::get_if(&target); + return s && size.x <= s->x && size.y <= s->y; + } + + bool operator()(uint32_t size) const { + return absl::get_if(&target); + } + + const ObjectSize& target; +}; + +// Makes new size which combines largest dimensions of both given sizes. +// +// @return false if sizes have different number of dimensions +bool WillTextureFit(const ObjectSize& source, const ObjectSize& target) { + return absl::visit(TextureFitPolicy{target}, source); +} + +struct TextureNumElementsFunc { + size_t operator()(const uint3& size) const { + auto s = absl::get_if(&target); + return s ? size.z * s->x * s->y + size.y * s->x + size.x : 0; + } + + size_t operator()(const uint2& size) const { + auto s = absl::get_if(&target); + return s ? size.y * s->x + size.x : 0; + } + + size_t operator()(uint32_t size) const { + auto s = absl::get_if(&target); + return s ? size : 0; + } + + const ObjectSize& target; +}; + +// @return estimated number of elements if target texture is used to keep source +// texture data assuming XYZ layout. +size_t TextureNumElements(const ObjectSize& source, const ObjectSize& target) { + return absl::visit(TextureNumElementsFunc{target}, source); +} + +// Checks whether the given object fits into 'to' object. Returns number of +// bytes used if an object fits, or 0 otherwise. +// +// Fitting policy: +// - buffer will always fit into another buffer because they all are linear. 
+// - textures are handles by the policy above +// +size_t WillItFit(const Object& object, const Object& to) { + if (object.object_type != to.object_type || + object.data_type != to.data_type) { + return 0; + } + switch (object.object_type) { + case ObjectType::BUFFER: + return ByteSizeOf(object); + case ObjectType::TEXTURE: { + if (!WillTextureFit(object.size, to.size)) return 0; + // Expand 'to' dimensions to ensure an object fits. + ObjectSize new_texture_size; + if (!FitSize(object.size, to.size, &new_texture_size)) return 0; + return /* RGBA = */ 4 * SizeOf(object.data_type) * + TextureNumElements(object.size, new_texture_size); + } + default: + return 0; + } +} + +} // namespace + +// Algorithm works as follows: +// +// 1. First it collects usage intervals for each object reference. +// For example: buffer #3 is introduced in program #2 and used for the +// last time in program #7. +// +// 2. Iterates through all programs where for every object reference +// assigns shared object from the pool. When object reference is used +// for the last time, corresponding shared object is returned back to +// the pool. +// +// 3. Shared object pool grows when there are no free shared object +// available. +// +// 4. Shared object size may increase when object reference requests bigger +// size. +// +// Therefore, in the end all references are remapped to ids in the range +// [0..num_shared_objects]. To avoid ref space collision with global reference +// all shared objects are allocated in internal_objects_. +Status Runtime::AssignInternalObjects(std::vector* shared_objects) { + // Build interval set for objects to know where each object is introduced + // and used for the last time. + std::vector> usage_intervals; + for (int32_t i = 0; i < programs_.size(); ++i) { + for (auto& object : programs_[i].refs) { + auto ref = GetRef(object); + if (ref >= usage_intervals.size()) { + usage_intervals.resize(ref + 1, std::make_pair(programs_.size(), -1)); + } + auto& it = usage_intervals[ref]; + it.first = std::min(it.first, i); + it.second = std::max(it.second, i); + } + } + + std::vector is_used_shared_object; + std::vector global_ref_to_shared_ref(usage_intervals.size(), + kInvalidObjectRef); + + for (size_t i = 0; i < programs_.size(); ++i) { + auto& program = programs_[i]; + // list of object indices to return to the pool. + std::vector object_refs_to_return; + + // Assign to every internal buffer, that is not yet allocated, appropriate + // shared buffer from a heap of unused. + for (auto& object : program.refs) { + const ObjectRef ref = GetRef(object); + ObjectRef shared_ref = global_ref_to_shared_ref[ref]; + const auto& usage = usage_intervals[ref]; + + if (usage.first == i) { + // First time a reference is introduced. Assign shared object. + if (shared_ref != kInvalidObjectRef) { + return InternalError( + "Internal object is introduced for the first time but is already " + "assigned"); + } + + // Try to find a free shared object that is as close as possible by + // size. Here we assume that number of shared objects is relatively + // small (< 100), therefore, search linearly over all of them. + size_t selected_waste_bytes = 0; + for (int32_t b = 0; b < shared_objects->size(); ++b) { + // Check whether shared object is available. + if (is_used_shared_object[b]) continue; + auto& shared_object = (*shared_objects)[b]; + + // Bytes needed to fit object in the shared object. 
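+          // WillItFit returns 0 when the object cannot be placed into
+          // shared_object (type mismatch or the texture does not fit).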
+ size_t alloc_bytes = WillItFit(object, shared_object); + if (alloc_bytes == 0) continue; + + // Prefer shared object that will waste less memory. + size_t shared_byte_size = ByteSizeOf(shared_object); + // sizes are unsigned, therefore '-' may undeflow. Take smallest. + size_t waste_bytes = std::min(shared_byte_size - alloc_bytes, + alloc_bytes - shared_byte_size); + if (shared_ref == kInvalidObjectRef || + waste_bytes < selected_waste_bytes) { + selected_waste_bytes = waste_bytes; + shared_ref = b; + } + } + + if (shared_ref == kInvalidObjectRef) { + // Didn't find an object to share. Create new one. + shared_ref = shared_objects->size(); + Object shared_object = object; + shared_object.access = AccessType::READ_WRITE; + shared_object.object = shared_ref; + if (shared_object.object_type == ObjectType::BUFFER) { + // Make a buffer linear. + shared_object.size = NumElements(object.size); + } + shared_objects->push_back(std::move(shared_object)); + is_used_shared_object.push_back(false); + } else { + // Check chosen shared object and update it's size. + Object& shared_object = (*shared_objects)[shared_ref]; + switch (object.object_type) { + case ObjectType::BUFFER: + shared_object.size = std::max(NumElements(object.size), + NumElements(shared_object.size)); + break; + case ObjectType::TEXTURE: { + if (!FitSize(object.size, shared_object.size, + &shared_object.size)) { + return InternalError( + "Already assigned shared texture does not fit an object"); + } + break; + } + default: + return InternalError("Unexpected shared object type"); + } + } + } + + // Mark shared object as used and map internal object to it. + is_used_shared_object[shared_ref] = true; + global_ref_to_shared_ref[ref] = shared_ref; + object.object = shared_ref; + + // At this point we want to return unused object, but it should be + // returned later to avoid re-using the same object in this operation + // for a different purpose. + if (usage.second == i) { + object_refs_to_return.push_back(shared_ref); + } + } + + // Mark all returned objects from this program as unused. + for (size_t ref : object_refs_to_return) { + is_used_shared_object[ref] = false; + } + } + return OkStatus(); +} + +Status Runtime::Execute() { + for (const auto& descriptor : programs_) { + for (auto& b : descriptor.bindings) { + RETURN_IF_ERROR(b()); + } + RETURN_IF_ERROR(command_queue_->Dispatch(descriptor.program, + descriptor.num_workgroups)); + } + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.h b/tensorflow/lite/delegates/gpu/gl/runtime.h new file mode 100644 index 00000000000..6761d730628 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/runtime.h @@ -0,0 +1,111 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/command_queue.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/object_manager.h" +#include "tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/runtime_options.h" +#include "tensorflow/lite/delegates/gpu/gl/stats.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Runtime compiles code and executes it once all code is compiled. It creates +// intermediate objects and destroys them when runtime is destroyed. +class Runtime { + public: + Runtime(const RuntimeOptions& options, const GpuInfo& gpu_info, + CommandQueue* command_queue, const ObjectManager* external_objects); + + // Takes parameters and objects and prepares GL program. + Status AddProgram(const GlShader& shader, + const std::vector& parameters, + const std::vector& objects, + const uint3& num_workgroups); + + // Needs to be called once all programs and shaders has been added to runtime. + Status PrepareForExecution(); + + // Executes all compiled programs. + // TODO(akulik): add more controls over execution. Execution policy? + Status Execute(); + + // Gets access to objects created while executing generated code. + const ObjectManager* internal_objects() const { return &internal_objects_; } + + RuntimeStats stats() const { + RuntimeStats stats; + stats.const_objects = const_objects_.stats(); + stats.internal_objects = internal_objects_.stats(); + if (external_objects_) { + stats.external_objects = external_objects_->stats(); + } + return stats; + } + + private: + Status AllocateInternalObject(const Object& object); + + Status AllocateConstObject(const Object& object, uint32_t* id); + + // Goes over objects in programs and decides how to allocate them to + // minimize total allocated memory. Returns a collection of objects to be + // allocated and shared by internal objects. + Status AssignInternalObjects(std::vector* objects); + + const RuntimeOptions options_; + const GpuInfo gpu_info_; + const ObjectManager* external_objects_; + CommandQueue* command_queue_; + + ObjectManager internal_objects_; + ObjectManager const_objects_; + uint32_t next_const_id_ = 0; // id for const objects + + std::unique_ptr shared_readonly_buffer_; + + using BindFunc = std::function; + + // Encapsulates a program and all object to bind before dispatch. 
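+  // refs keeps objects that still need to be allocated; it is emptied by
+  // PrepareForExecution once internal objects are allocated and bound.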
+ struct CompiledProgramDescriptor { + GlProgram program; + uint3 num_workgroups; + + std::vector bindings; + std::vector refs; + }; + + std::vector programs_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/runtime/BUILD b/tensorflow/lite/delegates/gpu/gl/runtime/BUILD new file mode 100644 index 00000000000..98d03f77adb --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/runtime/BUILD @@ -0,0 +1,16 @@ +package(default_visibility = ["//tensorflow/lite/delegates/gpu/gl:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "shared_buffer", + hdrs = ["shared_buffer.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "//tensorflow/lite/delegates/gpu/gl:gl_call", + "//tensorflow/lite/delegates/gpu/gl:object", + "//tensorflow/lite/delegates/gpu/gl:portable", + ], +) diff --git a/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h b/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h new file mode 100644 index 00000000000..d4f49d1952c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/runtime/shared_buffer.h @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_SHARED_BUFFER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_SHARED_BUFFER_H_ + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Class accumulates readonly data and creates a single buffer out of it. +// User should call Add one or more times and complete shared buffer creation +// with CreateSharedBuffer() call. +class SharedBufferData { + public: + SharedBufferData() { + glGetIntegerv(GL_SHADER_STORAGE_BUFFER_OFFSET_ALIGNMENT, &alignment_); + } + + // @return true if data was added to the shared buffer. + bool Add(const ObjectData& data, GlBuffer* buffer) { + // TODO(akulik): Does it make sense to bundle even big buffers > 1MB? + + // align buffer's data. + shared_data_.resize(AlignByN(shared_data_.size(), alignment_), 0); + // Accumulate readonly data in a single shared buffer buffer. + *buffer = GlBuffer(GL_SHADER_STORAGE_BUFFER, buffer_id_.id(), data.size(), + shared_data_.size(), /*has_ownership=*/false); + std::copy(data.begin(), data.end(), std::back_inserter(shared_data_)); + return true; + } + + bool empty() const { return shared_data_.empty(); } + + // Returns a single GlBuffer that owns entire shared data. 
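+  // Expected to be called once, after all Add() calls are done.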
+ Status CreateSharedGlBuffer(GlBuffer* gl_buffer) { + // Upload data to a buffer + gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, + buffer_id_.id()); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glBufferData, GL_SHADER_STORAGE_BUFFER, + shared_data_.size(), shared_data_.data(), + GL_STATIC_READ)); + *gl_buffer = GlBuffer(GL_SHADER_STORAGE_BUFFER, buffer_id_.Release(), + shared_data_.size(), 0, /*has_ownership=*/true); + return OkStatus(); + } + + private: + GLint alignment_ = 256; + gl_buffer_internal::BufferId buffer_id_; + ObjectData shared_data_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_SHARED_BUFFER_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/runtime_options.h b/tensorflow/lite/delegates/gpu/gl/runtime_options.h new file mode 100644 index 00000000000..44e054ec0a3 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/runtime_options.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_OPTIONS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_OPTIONS_H_ + +namespace tflite { +namespace gpu { +namespace gl { + +struct RuntimeOptions { + RuntimeOptions() + : reuse_internal_objects(true), bundle_readonly_objects(true) {} + + // If enabled triggers greedy algorithm to re-use internal buffers when + // possible. + // Keep this false when, for example, one need to analyze intermediate + // results for debugging purposes. + bool reuse_internal_objects; + + // If enabled all readonly objects will be bundled to create as few buffers or + // textures as possible. + bool bundle_readonly_objects; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_RUNTIME_OPTIONS_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/serialization.cc b/tensorflow/lite/delegates/gpu/gl/serialization.cc new file mode 100644 index 00000000000..0b950884239 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/serialization.cc @@ -0,0 +1,573 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/serialization.h" + +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { + +using flatbuffers::Offset; +using flatbuffers::Vector; + +namespace { + +struct ParameterValueGetter { + Offset operator()(int32_t value) { + auto offset = builder->CreateVector(std::vector{value}); + data::DataInt32Builder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + Offset operator()(const int2& value) { + auto offset = builder->CreateVector(std::vector{value.x, value.y}); + data::DataInt32Builder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + Offset operator()(const int4& value) { + auto offset = builder->CreateVector( + std::vector{value.x, value.y, value.z, value.w}); + data::DataInt32Builder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + Offset operator()(const std::vector& value) { + std::vector d(value.size() * 2); + for (size_t i = 0; i < value.size(); ++i) { + d[i * 2] = value[i].x; + d[i * 2 + 1] = value[i].y; + } + auto offset = builder->CreateVector(d); + data::DataInt32Builder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + Offset operator()(uint32_t value) { + auto offset = builder->CreateVector(std::vector{value}); + data::DataUint32Builder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + Offset operator()(const uint4& value) { + auto offset = builder->CreateVector( + std::vector{value.x, value.y, value.z, value.w}); + data::DataUint32Builder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + Offset operator()(float value) { + auto offset = builder->CreateVector(std::vector{value}); + data::DataFloatBuilder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + Offset operator()(const float2& value) { + auto offset = builder->CreateVector(std::vector{value.x, value.y}); + data::DataFloatBuilder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + Offset operator()(const float4& value) { + auto offset = builder->CreateVector( + std::vector{value.x, value.y, value.z, value.w}); + data::DataFloatBuilder data(*builder); + data.add_data(offset); + return data.Finish().Union(); + } + + ::flatbuffers::FlatBufferBuilder* builder; +}; + +struct DataVariantTypeGetter { + data::DataVariant operator()(int32_t) const { + return data::DataVariant::DataInt32; + } + data::DataVariant operator()(const int2&) const { + return data::DataVariant::DataInt32; + } + data::DataVariant operator()(const int4&) const { + return data::DataVariant::DataInt32; + } + data::DataVariant operator()(const std::vector&) const { + return data::DataVariant::DataInt32; + } + data::DataVariant operator()(uint32_t) const { + return data::DataVariant::DataUint32; + } + data::DataVariant operator()(const uint4&) const { + return data::DataVariant::DataUint32; + } + data::DataVariant operator()(float) const { + return data::DataVariant::DataFloat; + } + data::DataVariant operator()(const float2&) const { + return data::DataVariant::DataFloat; + } + data::DataVariant operator()(const float4&) const { + return data::DataVariant::DataFloat; + } +}; + +struct ParameterTypeGetter { + 
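+  // Maps a uniform parameter value type to its FlatBuffer ParameterType tag;
+  // a vector of int2 values gets the dedicated INT32_2 tag.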
data::ParameterType operator()(int32_t) const { + return data::ParameterType::INT32; + } + data::ParameterType operator()(const int2&) const { + return data::ParameterType::INT32; + } + data::ParameterType operator()(const int4&) const { + return data::ParameterType::INT32; + } + data::ParameterType operator()(const std::vector&) const { + return data::ParameterType::INT32_2; + } + data::ParameterType operator()(uint32_t) const { + return data::ParameterType::UINT32; + } + data::ParameterType operator()(const uint4&) const { + return data::ParameterType::UINT32; + } + data::ParameterType operator()(float) const { + return data::ParameterType::FLOAT32; + } + data::ParameterType operator()(const float2&) const { + return data::ParameterType::FLOAT32; + } + data::ParameterType operator()(const float4&) const { + return data::ParameterType::FLOAT32; + } +}; + +data::DataType ToFB(DataType type) { + switch (type) { + case DataType::INT16: + return data::DataType::INT16; + case DataType::INT32: + return data::DataType::INT32; + case DataType::FLOAT16: + return data::DataType::FLOAT16; + case DataType::FLOAT32: + return data::DataType::FLOAT32; + default: + return data::DataType::UNKNOWN; + } +} + +data::ObjectType ToFB(ObjectType type) { + switch (type) { + case ObjectType::TEXTURE: + return data::ObjectType::TEXTURE; + case ObjectType::BUFFER: + return data::ObjectType::BUFFER; + default: + return data::ObjectType::UNKNOWN; + } +} + +struct ObjectSizeGetter { + Offset operator()(const uint3& shape) { + data::Uint3Builder shape_builder(*builder); + shape_builder.add_x(shape.x); + shape_builder.add_y(shape.y); + shape_builder.add_z(shape.z); + return shape_builder.Finish().Union(); + } + Offset operator()(const uint2& shape) { + data::Uint2Builder shape_builder(*builder); + shape_builder.add_x(shape.x); + shape_builder.add_y(shape.y); + return shape_builder.Finish().Union(); + } + Offset operator()(uint32_t shape) { + data::Uint1Builder shape_builder(*builder); + shape_builder.add_x(shape); + return shape_builder.Finish().Union(); + } + + ::flatbuffers::FlatBufferBuilder* builder; +}; + +struct ObjectSizeTypeGetter { + data::ObjectSize operator()(const uint3&) const { + return data::ObjectSize::Uint3; + } + data::ObjectSize operator()(const uint2&) const { + return data::ObjectSize::Uint2; + } + data::ObjectSize operator()(const uint32_t&) const { + return data::ObjectSize::Uint1; + } +}; + +struct ObjectGetter { + Offset operator()(const ObjectData& data) { + auto fb_data = builder->CreateVector(data); + data::ObjectDataBuilder data_builder(*builder); + data_builder.add_data(fb_data); + return data_builder.Finish().Union(); + } + Offset operator()(ObjectRef ref) { + data::ObjectRefBuilder ref_builder(*builder); + ref_builder.add_global_id(ref); + return ref_builder.Finish().Union(); + } + + ::flatbuffers::FlatBufferBuilder* builder; +}; + +struct ObjectTypeGetter { + data::ObjectVariant operator()(const ObjectData&) const { + return data::ObjectVariant::ObjectData; + } + data::ObjectVariant operator()(const ObjectRef&) const { + return data::ObjectVariant::ObjectRef; + } +}; + +data::AccessType ToFB(AccessType type) { + switch (type) { + case AccessType::READ: + return data::AccessType::READ; + case AccessType::WRITE: + return data::AccessType::WRITE; + case AccessType::READ_WRITE: + return data::AccessType::READ_WRITE; + } +} + +Offset Encode(const uint3& v, + ::flatbuffers::FlatBufferBuilder* builder) { + data::Uint3Builder uint3_builder(*builder); + uint3_builder.add_x(v.x); + 
uint3_builder.add_y(v.y); + uint3_builder.add_z(v.z); + return uint3_builder.Finish(); +} + +Offset Encode(const CompiledModelOptions& options, + ::flatbuffers::FlatBufferBuilder* builder) { + data::ParametersBuilder params_builder(*builder); + params_builder.add_dynamic_batch(options.dynamic_batch); + return params_builder.Finish(); +} + +} // namespace + +void SerializedCompiledModelBuilder::AddShader(const std::string& shader_src) { + shaders_.push_back(builder_.CreateString(shader_src)); +} + +void SerializedCompiledModelBuilder::AddProgram( + const std::vector& parameters, + const std::vector& objects, const uint3& workgroup_size, + const uint3& num_workgroups, size_t shader_index) { + Offset fb_workgroups = Encode(num_workgroups, &builder_); + Offset fb_workgroup_size = Encode(workgroup_size, &builder_); + + Offset>> fb_params; + { + std::vector> offsets; + for (const UniformParameter& param : parameters) { + auto name = builder_.CreateString(param.name); + auto data = absl::visit(ParameterValueGetter{&builder_}, param.value); + data::UniformParameterBuilder builder(builder_); + builder.add_name(name); + builder.add_data_type(absl::visit(DataVariantTypeGetter{}, param.value)); + builder.add_data(data); + builder.add_type(absl::visit(ParameterTypeGetter{}, param.value)); + offsets.push_back(builder.Finish()); + } + fb_params = builder_.CreateVector(offsets); + } + + Offset>> fb_objects; + { + std::vector> offsets; + for (const Object& object : objects) { + auto object_variant = absl::visit(ObjectGetter{&builder_}, object.object); + auto size = absl::visit(ObjectSizeGetter{&builder_}, object.size); + + data::ObjectBuilder builder(builder_); + builder.add_access(ToFB(object.access)); + builder.add_binding(object.binding); + builder.add_type(ToFB(object.object_type)); + builder.add_data_type(ToFB(object.data_type)); + builder.add_size_type(absl::visit(ObjectSizeTypeGetter{}, object.size)); + builder.add_size(size); + builder.add_object_type(absl::visit(ObjectTypeGetter{}, object.object)); + builder.add_object(object_variant); + offsets.push_back(builder.Finish()); + } + fb_objects = builder_.CreateVector(offsets); + } + + data::ProgramBuilder program_builder(builder_); + program_builder.add_number_workgroups(fb_workgroups); + program_builder.add_workgroup_size(fb_workgroup_size); + program_builder.add_parameters(fb_params); + program_builder.add_objects(fb_objects); + program_builder.add_shader_index(shader_index); + programs_.push_back(program_builder.Finish()); +} + +absl::Span SerializedCompiledModelBuilder::Finalize( + const CompiledModelOptions& options) { + auto shaders = builder_.CreateVector(shaders_); + auto programs = builder_.CreateVector(programs_); + auto parameters = Encode(options, &builder_); + data::CompiledModelBuilder model_builder(builder_); + model_builder.add_shaders(shaders); + model_builder.add_programs(programs); + model_builder.add_parameters(parameters); + data::FinishCompiledModelBuffer(builder_, model_builder.Finish()); + return absl::MakeConstSpan(builder_.GetBufferPointer(), builder_.GetSize()); +} + +namespace { + +Status ParseParameter(const data::UniformParameter& fb_parameter, + UniformParameter* parameter) { + parameter->name = fb_parameter.name()->str(); + switch (fb_parameter.type()) { + case data::ParameterType::INT32: { + auto* ptr = fb_parameter.data_as_DataInt32(); + if (ptr == nullptr) { + return InvalidArgumentError("Unexpected data type '" + parameter->name + + "'"); + } + switch (ptr->data()->size()) { + case 1: + parameter->value = 
(*ptr->data())[0]; + break; + case 2: + parameter->value = int2((*ptr->data())[0], (*ptr->data())[1]); + break; + case 4: + parameter->value = int4((*ptr->data())[0], (*ptr->data())[1], + (*ptr->data())[2], (*ptr->data())[3]); + break; + default: + return InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); + } + break; + } + case data::ParameterType::UINT32: { + auto* ptr = fb_parameter.data_as_DataUint32(); + if (ptr == nullptr) { + return InvalidArgumentError("Unexpected data type '" + parameter->name + + "'"); + } + switch (ptr->data()->size()) { + case 1: + parameter->value = (*ptr->data())[0]; + break; + case 4: + parameter->value = uint4((*ptr->data())[0], (*ptr->data())[1], + (*ptr->data())[2], (*ptr->data())[3]); + break; + default: + return InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); + } + break; + } + case data::ParameterType::FLOAT32: { + auto* ptr = fb_parameter.data_as_DataFloat(); + if (ptr == nullptr) { + return InvalidArgumentError("Unexpected data type '" + parameter->name + + "'"); + } + switch (ptr->data()->size()) { + case 1: + parameter->value = (*ptr->data())[0]; + break; + case 2: + parameter->value = float2((*ptr->data())[0], (*ptr->data())[1]); + break; + case 4: + parameter->value = float4((*ptr->data())[0], (*ptr->data())[1], + (*ptr->data())[2], (*ptr->data())[3]); + break; + default: + return InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); + } + break; + } + case data::ParameterType::INT32_2: { + auto* ptr = fb_parameter.data_as_DataInt32(); + if (ptr == nullptr) { + return InvalidArgumentError("Unexpected data type '" + parameter->name + + "'"); + } + + if (ptr->data()->size() % 2 != 0) { + return InvalidArgumentError("Unexpected size for parameter '" + + parameter->name + "'"); + } + + std::vector values(ptr->data()->size() / 2); + for (int i = 0; i < values.size(); ++i) { + values[i] = int2((*ptr->data())[i * 2], (*ptr->data())[i * 2 + 1]); + } + parameter->value = values; + break; + } + } + return OkStatus(); +} + +DataType ToEnum(data::DataType type) { + switch (type) { + case data::DataType::INT16: + return DataType::INT16; + case data::DataType::INT32: + return DataType::INT32; + case data::DataType::FLOAT16: + return DataType::FLOAT16; + case data::DataType::FLOAT32: + return DataType::FLOAT32; + default: + return DataType::UNKNOWN; + } +} + +ObjectType ToEnum(data::ObjectType type) { + switch (type) { + case data::ObjectType::TEXTURE: + return ObjectType::TEXTURE; + case data::ObjectType::BUFFER: + return ObjectType::BUFFER; + default: + return ObjectType::UNKNOWN; + } +} + +AccessType ToEnum(data::AccessType type) { + switch (type) { + case data::AccessType::READ: + return AccessType::READ; + case data::AccessType::WRITE: + return AccessType::WRITE; + case data::AccessType::READ_WRITE: + return AccessType::READ_WRITE; + } +} + +Status ParseObject(const data::Object& fb_object, Object* object) { + object->access = ToEnum(fb_object.access()); + object->binding = fb_object.binding(); + object->object_type = ToEnum(fb_object.type()); + object->data_type = ToEnum(fb_object.data_type()); + + switch (fb_object.size_type()) { + case data::ObjectSize::Uint3: { + auto* size = fb_object.size_as_Uint3(); + object->size = uint3(size->x(), size->y(), size->z()); + break; + } + case data::ObjectSize::Uint2: { + auto* size = fb_object.size_as_Uint2(); + object->size = uint2(size->x(), size->y()); + break; + } + case data::ObjectSize::Uint1: { + auto* size = 
fb_object.size_as_Uint1(); + object->size = size->x(); + break; + } + case data::ObjectSize::NONE: + return InvalidArgumentError("Texture size is not set"); + } + + switch (fb_object.object_type()) { + case data::ObjectVariant::ObjectData: { + auto* fb_data = fb_object.object_as_ObjectData(); + object->object = std::vector( + fb_data->data()->data(), + fb_data->data()->data() + fb_data->data()->size()); + break; + } + case data::ObjectVariant::ObjectRef: { + auto* fb_ref = fb_object.object_as_ObjectRef(); + object->object = fb_ref->global_id(); + break; + } + case data::ObjectVariant::NONE: { + return InvalidArgumentError("Object is not set"); + } + } + return OkStatus(); +} + +CompiledModelOptions ParseParameters(const data::Parameters& fb_parameters) { + CompiledModelOptions options; + options.dynamic_batch = fb_parameters.dynamic_batch(); + return options; +} + +} // namespace + +Status DeserializeCompiledModel(absl::Span serialized, + DeserializationHandler* handler) { + flatbuffers::Verifier verifier(serialized.data(), serialized.size()); + if (!data::VerifyCompiledModelBuffer(verifier)) { + return InvalidArgumentError("Serialized model is corrupted."); + } + + auto model = data::GetCompiledModel(serialized.data()); + for (auto shader : *model->shaders()) { + RETURN_IF_ERROR( + handler->OnShader(absl::MakeSpan(shader->c_str(), shader->size()))); + } + std::vector parameters; + std::vector objects; + for (auto program : *model->programs()) { + parameters.clear(); + objects.clear(); + for (auto fb_parameter : *program->parameters()) { + UniformParameter parameter; + RETURN_IF_ERROR(ParseParameter(*fb_parameter, ¶meter)); + parameters.push_back(std::move(parameter)); + } + for (auto fb_object : *program->objects()) { + Object object; + RETURN_IF_ERROR(ParseObject(*fb_object, &object)); + objects.push_back(std::move(object)); + } + uint3 workgroup_size(program->workgroup_size()->x(), + program->workgroup_size()->y(), + program->workgroup_size()->z()); + uint3 num_workgroups(program->number_workgroups()->x(), + program->number_workgroups()->y(), + program->number_workgroups()->z()); + RETURN_IF_ERROR(handler->OnProgram(parameters, objects, workgroup_size, + num_workgroups, + program->shader_index())); + } + handler->OnOptions(ParseParameters(*model->parameters())); + return OkStatus(); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/serialization.h b/tensorflow/lite/delegates/gpu/gl/serialization.h new file mode 100644 index 00000000000..5c981731ae2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/serialization.h @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_ + +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiled_model_generated.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { + +struct CompiledModelOptions { + // If true, a model was compiled with dynamic batch size and therefore, + // a user may change BATCH dimension at runtime. + bool dynamic_batch = false; +}; + +// Accumulates shaders and programs and stores it in FlatBuffer format. +class SerializedCompiledModelBuilder { + public: + SerializedCompiledModelBuilder() : builder_(32 * 1024) {} + + void AddShader(const std::string& shader_src); + + void AddProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, const uint3& num_workgroups, + size_t shader_index); + + // Returns serialized data that will stay valid until this object is + // destroyed. + absl::Span Finalize(const CompiledModelOptions& options); + + private: + std::vector> shaders_; + std::vector> programs_; + ::flatbuffers::FlatBufferBuilder builder_; +}; + +// Handles deserialization events. it is guaranteed that shaders will be called +// first in the appropriate order and programs come next. +class DeserializationHandler { + public: + virtual ~DeserializationHandler() = default; + + virtual Status OnShader(absl::Span shader_src) = 0; + + virtual Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, + const uint3& num_workgroups, + size_t shader_index) = 0; + + virtual void OnOptions(const CompiledModelOptions& options) = 0; +}; + +Status DeserializeCompiledModel(absl::Span serialized, + DeserializationHandler* handler); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/serialization_test.cc b/tensorflow/lite/delegates/gpu/gl/serialization_test.cc new file mode 100644 index 00000000000..6256d970f29 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/serialization_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/serialization.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/uniform_parameter.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +struct ProgramDesc { + std::vector parameters; + std::vector objects; + uint3 workgroup_size; + uint3 num_workgroups; + size_t shader_index; +}; + +struct Handler : public DeserializationHandler { + Status OnShader(absl::Span shader_src) final { + shaders.push_back(std::string(shader_src.data(), shader_src.size())); + return OkStatus(); + } + + Status OnProgram(const std::vector& parameters, + const std::vector& objects, + const uint3& workgroup_size, const uint3& num_workgroups, + size_t shader_index) final { + programs.push_back( + {parameters, objects, workgroup_size, num_workgroups, shader_index}); + return OkStatus(); + } + + void OnOptions(const CompiledModelOptions& o) final { options = o; } + + std::vector shaders; + std::vector programs; + CompiledModelOptions options; +}; + +struct ParameterComparator { + bool operator()(int32_t value) const { + return value == absl::get(a.value); + } + bool operator()(const int2& value) const { + auto v = absl::get(a.value); + return value.x == v.x && value.y == v.y; + } + bool operator()(const int4& value) const { + auto v = absl::get(a.value); + return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w; + } + bool operator()(const std::vector& value) const { + auto v = absl::get>(a.value); + if (v.size() != value.size()) { + return false; + } + for (int i = 0; i < v.size(); ++i) { + if (v[i].x != value[i].x || v[i].y != value[i].y) { + return false; + } + } + return true; + } + bool operator()(uint32_t value) const { + return value == absl::get(a.value); + } + bool operator()(const uint4& value) const { + auto v = absl::get(a.value); + return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w; + } + bool operator()(float value) const { + return value == absl::get(a.value); + } + bool operator()(float2 value) const { + auto v = absl::get(a.value); + return value.x == v.x && value.y == v.y; + } + bool operator()(const float4& value) const { + auto v = absl::get(a.value); + return value.x == v.x && value.y == v.y && value.z == v.z && value.w == v.w; + } + UniformParameter a; +}; + +bool Eq(const UniformParameter& a, const UniformParameter& b) { + return a.name == b.name && absl::visit(ParameterComparator{a}, b.value); +} + +struct ObjectComparator { + bool operator()(const ObjectData& data) const { + return absl::get(a.object) == data; + } + bool operator()(const ObjectRef& ref) const { + return absl::get(a.object) == ref; + } + + Object a; +}; + +bool Eq(const Object& a, const Object& b) { + return a.access == b.access && a.binding == b.binding && + absl::visit(ObjectComparator{a}, b.object); +} + +TEST(Smoke, Read) { + std::string shader1 = "A"; + std::string shader2 = "B"; + + SerializedCompiledModelBuilder builder; + builder.AddShader(shader1); + builder.AddShader(shader2); + + std::vector parameters; + parameters.push_back(UniformParameter{"1", int32_t(1)}); + parameters.push_back(UniformParameter{"2", 
int2(1, 2)}); + parameters.push_back(UniformParameter{"3", int4(1, 2, 3, 4)}); + parameters.push_back(UniformParameter{"4", uint32_t(10)}); + parameters.push_back(UniformParameter{"5", uint4(10, 20, 30, 40)}); + parameters.push_back(UniformParameter{"6", -2.0f}); + parameters.push_back(UniformParameter{"7", float2(1, -1)}); + parameters.push_back(UniformParameter{"8", float4(1, -1, 2, -2)}); + parameters.push_back(UniformParameter{ + "9", std::vector{int2(1, 2), int2(3, 4), int2(5, 6)}}); + + std::vector objects; + objects.push_back(MakeReadonlyBuffer(std::vector{1, 2, 3, 4})); + objects.push_back(Object{AccessType::WRITE, DataType::FLOAT32, + ObjectType::TEXTURE, 5, uint3(1, 2, 3), 100}); + objects.push_back(Object{AccessType::READ_WRITE, DataType::INT8, + ObjectType::BUFFER, 6, uint2(2, 1), + std::vector{7, 9}}); + uint3 num_workgroups(10, 20, 30); + uint3 workgroup_size(1, 2, 3); + builder.AddProgram(parameters, objects, workgroup_size, num_workgroups, 1); + + Handler handler; + CompiledModelOptions options; + options.dynamic_batch = true; + ASSERT_TRUE( + DeserializeCompiledModel(builder.Finalize(options), &handler).ok()); + EXPECT_EQ(num_workgroups.data_, handler.programs[0].num_workgroups.data_); + EXPECT_EQ(workgroup_size.data_, handler.programs[0].workgroup_size.data_); + EXPECT_THAT(handler.shaders, ::testing::ElementsAre(shader1, shader2)); + EXPECT_EQ(handler.programs[0].parameters.size(), parameters.size()); + for (int i = 0; i < parameters.size(); ++i) { + EXPECT_TRUE(Eq(parameters[i], handler.programs[0].parameters[i])) << i; + } + EXPECT_EQ(handler.programs[0].objects.size(), objects.size()); + for (int i = 0; i < objects.size(); ++i) { + EXPECT_TRUE(Eq(objects[i], handler.programs[0].objects[i])) << i; + } + EXPECT_TRUE(handler.options.dynamic_batch); +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/stats.h b/tensorflow/lite/delegates/gpu/gl/stats.h new file mode 100644 index 00000000000..198f9ed6929 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/stats.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_STATS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_STATS_H_ + +#include + +#include "absl/strings/str_cat.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// A collection of compile-time stats exposed via API. +struct CompilerStats {}; + +struct ObjectStats { + // Number of allocated objects. + int32_t count = 0; + + // Total bytes allocated. + int64_t total_bytes = 0; +}; + +struct ObjectsStats { + ObjectStats buffers; + + ObjectStats textures; +}; + +// A collection of runtime-time stats exposed via API. 
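+//
+// A minimal sketch (not part of the API) of how these stats might be
+// aggregated and logged, assuming a populated RuntimeStats instance and the
+// ToString() helper defined below:
+//
+//   RuntimeStats stats;
+//   stats.internal_objects.buffers.count = 3;
+//   stats.internal_objects.buffers.total_bytes = 12 * 1024;
+//   std::string line = absl::StrCat("internal buffers: ",
+//                                   ToString(stats.internal_objects.buffers));
+//   // line == "internal buffers: count = 3, total bytes = 12288"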
+struct RuntimeStats { + ObjectsStats internal_objects; + + ObjectsStats const_objects; + + ObjectsStats external_objects; +}; + +inline std::string ToString(const ObjectStats& stats) { + return absl::StrCat("count = ", stats.count, + ", total bytes = ", stats.total_bytes); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_STATS_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/uniform_parameter.h b/tensorflow/lite/delegates/gpu/gl/uniform_parameter.h new file mode 100644 index 00000000000..90e2c237f90 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/uniform_parameter.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_UNIFORM_PARAMETER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_UNIFORM_PARAMETER_H_ + +#include +#include +#include + +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace gl { + +struct UniformParameter { + using ValueType = absl::variant>; + + std::string name; + ValueType value; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_UNIFORM_PARAMETER_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups.fbs b/tensorflow/lite/delegates/gpu/gl/workgroups.fbs new file mode 100644 index 00000000000..6609afef494 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups.fbs @@ -0,0 +1,29 @@ +include "common.fbs"; + +namespace tflite.gpu.gl.data; + +file_identifier "AFWS"; + +// Workgroup size that applies only to a specific shader that covers predefined +// collection of nodes. +table HardcodedWorkgroup { + // Defines the size of a workgroup. + size:Uint3; + + // Shader has to cover exactly these nodes to have workgroup size applied. + node_indices:[uint32]; +} + +// A collection of matchers to override default workgroup sizes in shaders. +table HardcodedWorkgroups { + // if set, workgroups are applied only if mobile gpu info matches. 
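+  // Matching is an exact string comparison against GpuInfo::renderer_name
+  // (see workgroups/calculator_from_metadata.cc).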
+ gpu_info:string; + + workgroups:[HardcodedWorkgroup]; +} + +table CustomWorkgroups { + hardcoded_workgroups:[HardcodedWorkgroups]; +} + +root_type CustomWorkgroups; diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/BUILD b/tensorflow/lite/delegates/gpu/gl/workgroups/BUILD new file mode 100644 index 00000000000..101852058a7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/BUILD @@ -0,0 +1,74 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "calculator", + srcs = ["calculator.cc"], + hdrs = ["calculator.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:gpu_info", + "//tensorflow/lite/delegates/gpu/gl/compiler:shader_code", + ], +) + +cc_library( + name = "default_calculator", + srcs = ["default_calculator.cc"], + hdrs = ["default_calculator.h"], + deps = [ + ":calculator", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:gpu_info", + ], +) + +cc_library( + name = "calculator_from_metadata", + srcs = ["calculator_from_metadata.cc"], + hdrs = ["calculator_from_metadata.h"], + deps = select({ + "//tensorflow/lite/delegates/gpu:tflite_gpu_binary_release": [], + "//conditions:default": [ + ":default_calculator", + "//tensorflow/lite/delegates/gpu/gl:common_cc_fbs", + "//tensorflow/lite/delegates/gpu/gl:workgroups_cc_fbs", + "//tensorflow/lite/delegates/gpu/gl:gpu_info", + "//tensorflow/lite/delegates/gpu/gl:metadata_cc_fbs", + ":calculator", + "@com_google_absl//absl/memory", + "@flatbuffers", + "//tensorflow/lite/delegates/gpu/common:types", + ], + }), +) + +cc_library( + name = "best_effort_calculator", + srcs = ["best_effort_calculator.cc"], + hdrs = ["best_effort_calculator.h"], + deps = [ + ":calculator", + ":default_calculator", + "//tensorflow/lite/delegates/gpu/gl:gpu_info", + ] + select({ + "//tensorflow/lite/delegates/gpu:tflite_gpu_binary_release": [], + "//conditions:default": [ + ":calculator_from_metadata", + ], + }), +) + +cc_library( + name = "ideal_workgroup_picker", + srcs = ["ideal_workgroup_picker.cc"], + hdrs = ["ideal_workgroup_picker.h"], + deps = [ + ":calculator", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:gpu_info", + ], +) diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc new file mode 100644 index 00000000000..f0a1c4fbd40 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.cc @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h" + +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h" + +#ifndef TFLITE_GPU_BINARY_RELEASE +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h" +#endif + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr BestEffortWorkgroupsCalculator( + const uint8_t* metadata, const GpuInfo& gpu_info) { +#ifndef TFLITE_GPU_BINARY_RELEASE + std::unique_ptr calculator_from_metadata = + NewWorkgroupsCalculatorFromMetadata(metadata, gpu_info); + if (calculator_from_metadata) { + return calculator_from_metadata; + } +#endif + return NewDefaultWorkgroupsCalculator(gpu_info); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h new file mode 100644 index 00000000000..56d192d55cc --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_BEST_EFFORT_CALCULATOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_BEST_EFFORT_CALCULATOR_H_ + +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" + +namespace tflite { +namespace gpu { +namespace gl { + +std::unique_ptr BestEffortWorkgroupsCalculator( + const uint8_t* metadata, const GpuInfo& gpu_info); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_BEST_EFFORT_CALCULATOR_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc new file mode 100644 index 00000000000..82ddf006555 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.cc @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" + +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +uint64_t CalculateProduct(const uint3& value) { + return static_cast(value.x) * value.y * value.z; +} + +void MaybeShrinkWorkgroup(const GpuInfo& gpu_info, uint3* wg) { + while (wg->x > gpu_info.max_work_group_size[0]) { + wg->x /= 2; + } + + while (wg->y > gpu_info.max_work_group_size[1]) { + wg->y /= 2; + } + + while (wg->z > gpu_info.max_work_group_size[2]) { + wg->z /= 2; + } + + // Code below decreases amount of invocations per workgroup in a balanced way. + // As example, workgroup size is x=16, y=8, z=8 (16x8x8 = 1024), but + // max_work_group_invocations = 512. We need to fit this limit and we can + // reduce workgroup size in different ways, but we want to use the most + // balanced way. So code below will find the maximal of three dimensions and + // reduce it, so the whole workgroup is kept balanced by all dimensions. And + // the final reduced workgroup will be x=8, y=8, z=8 for the given example. + while (CalculateProduct(*wg) > gpu_info.max_work_group_invocations) { + unsigned int* max = &wg->x; + if (wg->y > *max) max = &wg->y; + if (wg->z > *max) max = &wg->z; + *max = *max /= 2; + } +} + +} // namespace + +WorkgroupsCalculator::WorkgroupsCalculator(const GpuInfo& gpu_info) + : gpu_info_{gpu_info} {} + +uint3 WorkgroupsCalculator::Calculate(const ShaderCode& shader_code) const { + uint3 workgroup_size = shader_code.recommended_workgroup; + if (workgroup_size == kEmptyWorkgroupSize) { + workgroup_size = CalculateInternal(shader_code); + } + MaybeShrinkWorkgroup(gpu_info_, &workgroup_size); + return workgroup_size; +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h new file mode 100644 index 00000000000..c59a9433ffd --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h @@ -0,0 +1,58 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" + +namespace tflite { +namespace gpu { +namespace gl { + +constexpr uint3 kEmptyWorkgroupSize(0, 0, 0); + +// Calculates workgroup size for the given shader code in a model graph. 
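+// (Callers go through Calculate(), which prefers the shader's
+// recommended_workgroup when it is set and clamps the result to the GPU's
+// workgroup limits; concrete strategies only override CalculateInternal().)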
+// +// Potentially there are multiple implementations possible: +// - per-operation type hard-coded constants +// - statistic-based calculator that uses aggregated stats for all operations +class WorkgroupsCalculator { + public: + explicit WorkgroupsCalculator(const GpuInfo& gpu_info); + + virtual ~WorkgroupsCalculator() = default; + + // Uses shader code recommended work group size if available and doesn't + // exceed max work group invocations num, otherwise work group size from + // passed calculator. + uint3 Calculate(const ShaderCode& shader_code) const; + + protected: + virtual uint3 CalculateInternal(const ShaderCode& shader_code) const = 0; + + private: + GpuInfo gpu_info_; +}; + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc new file mode 100644 index 00000000000..673eedc3273 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.cc @@ -0,0 +1,105 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h" + +#ifndef TFLITE_GPU_BINARY_RELEASE + +#include +#include + +#include "tensorflow/lite/delegates/gpu/gl/metadata_generated.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups_generated.h" + +#include "absl/memory/memory.h" +#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/delegates/gpu/common/types.h" + +#endif // TFLITE_GPU_BINARY_RELEASE + +namespace tflite { +namespace gpu { +namespace gl { + +#ifndef TFLITE_GPU_BINARY_RELEASE +namespace { +class WorkgroupsCalculatorFromMetadata : public WorkgroupsCalculator { + public: + WorkgroupsCalculatorFromMetadata(const data::HardcodedWorkgroups& workgroups, + const GpuInfo& gpu_info) + : WorkgroupsCalculator(gpu_info), + default_calculator_(NewDefaultWorkgroupsCalculator(gpu_info)) { + for (const auto* workgroup : *workgroups.workgroups()) { + uint3 size(workgroup->size()->x(), workgroup->size()->y(), + workgroup->size()->z()); + // Class implementation relies on the fact that it uses unique graph + // representation where each node id appears in a single workgroup. + for (auto node_id : *workgroup->node_indices()) { + workgroups_.insert({node_id, size}); + } + } + } + + uint3 CalculateInternal(const ShaderCode& shader_code) const final { + auto it = workgroups_.find(shader_code.node_indices[0]); + return it != workgroups_.end() + ? 
it->second + : default_calculator_->Calculate(shader_code); + } + + private: + std::unordered_map workgroups_; + std::unique_ptr default_calculator_; +}; + +const data::HardcodedWorkgroups* FindWorkgroups( + const data::CustomWorkgroups& workgroups, const GpuInfo& gpu_info) { + for (auto workgroup : *workgroups.hardcoded_workgroups()) { + if (workgroup->gpu_info()->c_str() == gpu_info.renderer_name) { + return workgroup; + } + } + return nullptr; +} + +} // namespace + +std::unique_ptr NewWorkgroupsCalculatorFromMetadata( + const uint8_t* metadata, const GpuInfo& gpu_info) { + if (!metadata) return nullptr; + const auto* flow_metadata = + flatbuffers::GetRoot(metadata); + if (!flow_metadata || !flow_metadata->workgroups()) return nullptr; + const data::HardcodedWorkgroups* workgroups = + FindWorkgroups(*flow_metadata->workgroups(), gpu_info); + if (!workgroups) return nullptr; + return absl::make_unique(*workgroups, + gpu_info); +} + +#else // TFLITE_GPU_BINARY_RELEASE + +std::unique_ptr NewWorkgroupsCalculatorFromMetadata( + const uint8_t* metadata, const GpuInfo& gpu_info) { + return nullptr; +} + +#endif // TFLITE_GPU_BINARY_RELEASE + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h new file mode 100644 index 00000000000..cca859f8795 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_FROM_METADATA_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_FROM_METADATA_H_ + +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Creates new workgroup calculator that uses extra information serialized in +// metadata. +std::unique_ptr NewWorkgroupsCalculatorFromMetadata( + const uint8_t* metadata, const GpuInfo& gpu_info); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_CALCULATOR_FROM_METADATA_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.cc new file mode 100644 index 00000000000..ebfba146d93 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.cc @@ -0,0 +1,92 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h" + +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +class DefaultWorkgroupsCalculator : public WorkgroupsCalculator { + public: + explicit DefaultWorkgroupsCalculator(const GpuInfo& gpu_info) + : WorkgroupsCalculator(gpu_info) {} + uint3 CalculateInternal(const ShaderCode& shader_code) const final { + const auto& workload = shader_code.workload; + if (workload.z >= 64) { + return uint3(4, 4, 64); + } + if (workload.z >= 32) { + return uint3(8, 4, 32); + } + if (workload.z >= 16) { + return uint3(8, 8, 16); + } + if (workload.z >= 8) { + return uint3(16, 8, 8); + } + if (workload.z >= 4) { + return uint3(16, 16, 4); + } + if (workload.z >= 2) { + return uint3(32, 16, 2); + } + return uint3(32, 32, 1); + } +}; + +class WorkgroupsCalculatorForMali : public WorkgroupsCalculator { + public: + explicit WorkgroupsCalculatorForMali(const GpuInfo& gpu_info) + : WorkgroupsCalculator(gpu_info) {} + uint3 CalculateInternal(const ShaderCode& shader_code) const final { + const auto& workload = shader_code.workload; + if (workload.z >= 32) { + return uint3(2, 2, 32); + } + if (workload.z >= 16) { + return uint3(4, 2, 16); + } + if (workload.z >= 8) { + return uint3(4, 4, 8); + } + if (workload.z >= 4) { + return uint3(8, 4, 4); + } + if (workload.z >= 2) { + return uint3(8, 8, 2); + } + return uint3(16, 8, 1); + } +}; + +} // namespace + +std::unique_ptr NewDefaultWorkgroupsCalculator( + const GpuInfo& gpu_info) { + if (gpu_info.type == GpuType::MALI) { + return absl::make_unique(gpu_info); + } else { + return absl::make_unique(gpu_info); + } +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h new file mode 100644 index 00000000000..c8840abf4e5 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_DEFAULT_CALCULATOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_DEFAULT_CALCULATOR_H_ + +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Creates new workgroups calculator for the general case or specificly for Mali +std::unique_ptr NewDefaultWorkgroupsCalculator( + const GpuInfo& gpu_info); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_DEFAULT_CALCULATOR_H_ diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc new file mode 100644 index 00000000000..07dffa306a1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.cc @@ -0,0 +1,200 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h" + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h" + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +// This code employs the results the workgroup performance reseach +// (b/117291356). + +// Describes the ideal convolution for the specific operation case +// Case here means specific "kernel + strides" conbination for specific +// operatoins type, not sizes of input and output tensors, they can be any. +struct IdealByCase { + bool ParamsAccepted(OperationType in_op_type, HW in_kernel, + HW in_strides) const { + return operation_type == in_op_type && kernel == in_kernel && + strides == in_strides; + } + OperationType operation_type; + HW kernel; + HW strides; + uint3 ideal_workgroup; +}; + +// Describes the ideal convolution for the type of operations. It means that +// any configuration of operation of this type will be working with top 10% +// performance with the particular GPU. +struct IdealByType { + bool ParamsAccepted(OperationType in_op_type) const { + return operation_type == in_op_type; + } + OperationType operation_type; + uint3 ideal_workgroup; +}; + +// Describes ideal workgroups for the particular GPU model. +struct IdealWorkgroups { + std::vector by_type; + std::vector by_case; +}; + +// List of Ideal workgroups which is received after the research mentioned +// above. + +// Ideal workgroups for Adreno 630. 
+std::vector* kIdealByTypeAdreno630Ptr = + new std::vector{ + {OperationType::CONVOLUTION_2D, uint3(4, 8, 4)}, + {OperationType::DEPTHWISE_CONVOLUTION, uint3(4, 4, 8)}, + }; + +std::vector* kIdealByCaseAdreno630Ptr = + new std::vector{ + {OperationType::CONVOLUTION_2D, HW(1, 1), HW(1, 1), uint3(4, 8, 4)}, + {OperationType::CONVOLUTION_2D, HW(3, 3), HW(2, 2), uint3(8, 4, 4)}, + {OperationType::DEPTHWISE_CONVOLUTION, HW(1, 1), HW(1, 1), + uint3(8, 4, 4)}, + {OperationType::DEPTHWISE_CONVOLUTION, HW(3, 3), HW(2, 2), + uint3(4, 4, 4)}, + }; + +// Ideal workgroups for Adreno 540. +std::vector* kIdealByTypeAdreno540Ptr = + new std::vector{ + {OperationType::CONVOLUTION_2D, uint3(8, 2, 2)}, + {OperationType::DEPTHWISE_CONVOLUTION, uint3(8, 8, 2)}, + }; + +std::vector* kIdealByCaseAdreno540Ptr = + new std::vector{ + {OperationType::CONVOLUTION_2D, HW(1, 1), HW(1, 1), uint3(4, 2, 8)}, + {OperationType::CONVOLUTION_2D, HW(3, 3), HW(2, 2), uint3(8, 2, 8)}, + {OperationType::DEPTHWISE_CONVOLUTION, HW(1, 1), HW(1, 1), + uint3(8, 4, 8)}, + {OperationType::DEPTHWISE_CONVOLUTION, HW(3, 3), HW(2, 2), + uint3(4, 4, 8)}, + }; + +// Ideal workgroups for Adreno 510. +std::vector* kIdealByTypeAdreno510Ptr = + new std::vector{ + {OperationType::CONVOLUTION_2D, uint3(8, 4, 4)}, + {OperationType::DEPTHWISE_CONVOLUTION, uint3(8, 4, 4)}, + }; + +std::vector* kIdealByCaseAdreno510Ptr = + new std::vector{ + {OperationType::CONVOLUTION_2D, HW(1, 1), HW(1, 1), uint3(4, 2, 8)}, + {OperationType::CONVOLUTION_2D, HW(3, 3), HW(2, 2), uint3(8, 2, 8)}, + {OperationType::DEPTHWISE_CONVOLUTION, HW(1, 1), HW(1, 1), + uint3(8, 4, 8)}, + {OperationType::DEPTHWISE_CONVOLUTION, HW(3, 3), HW(2, 2), + uint3(4, 4, 8)}, + }; + +// Ideal workgroups for Adreno 509. +std::vector* kIdealByTypeAdreno509Ptr = + new std::vector{ + {OperationType::CONVOLUTION_2D, uint3(8, 4, 8)}, + {OperationType::DEPTHWISE_CONVOLUTION, uint3(8, 8, 2)}, + }; + +// Ideal workgroups for Adreno 508, 506, 505, 418, 405 +std::vector* kIdealByTypeAdreno508Ptr = + new std::vector{ + {OperationType::CONVOLUTION_2D, uint3(8, 4, 8)}, + {OperationType::DEPTHWISE_CONVOLUTION, uint3(8, 4, 8)}, + }; +std::vector* kIdealByTypeAdreno506Ptr = kIdealByTypeAdreno508Ptr; +std::vector* kIdealByTypeAdreno505Ptr = kIdealByTypeAdreno508Ptr; +std::vector* kIdealByTypeAdreno418Ptr = kIdealByTypeAdreno508Ptr; +std::vector* kIdealByTypeAdreno405Ptr = kIdealByTypeAdreno508Ptr; + +// Put all ideal workgroups from the list together. 
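+// GetIdealWorkgroupIfPossible() below consults this table in two passes:
+// specific "kernel + strides" cases first, then the per-operation-type
+// entries; if the GPU is not listed or nothing matches, the caller-provided
+// default workgroup is returned unchanged.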
+const std::map* kIdealWorkgroupsInfoPtr = + new std::map{ + {GpuModel::ADRENO630, + {*kIdealByTypeAdreno630Ptr, *kIdealByCaseAdreno630Ptr}}, + {GpuModel::ADRENO540, {*kIdealByTypeAdreno540Ptr, {}}}, + {GpuModel::ADRENO510, + {*kIdealByTypeAdreno510Ptr, *kIdealByCaseAdreno510Ptr}}, + {GpuModel::ADRENO509, {*kIdealByTypeAdreno509Ptr, {}}}, + {GpuModel::ADRENO508, {*kIdealByTypeAdreno508Ptr, {}}}, + {GpuModel::ADRENO506, {*kIdealByTypeAdreno506Ptr, {}}}, + {GpuModel::ADRENO505, {*kIdealByTypeAdreno505Ptr, {}}}, + {GpuModel::ADRENO418, {*kIdealByTypeAdreno418Ptr, {}}}, + {GpuModel::ADRENO405, {*kIdealByTypeAdreno405Ptr, {}}}, + }; + +} // namespace + +uint3 GetIdealWorkgroupIfPossible(GpuModel gpu_model, OperationType op_type, + HW kernel, HW strides, uint3 default_wg, + OHWI workload) { + // Research showed that ideal workgroup approach doesn't work well with + // convolutions, which have small amount of output channels or output + // height/width dimensions + if (workload.o < 32 || workload.h <= 5 || workload.w <= 5) return default_wg; + + // If GPU was investigated + if (!kIdealWorkgroupsInfoPtr->count(gpu_model)) { + return default_wg; + } + + // Try to find the ideal workgroup by the specific operation case, cause they + // are expected to be better tuned than default "by type" cases + for (const auto& specific_case : + kIdealWorkgroupsInfoPtr->at(gpu_model).by_case) { + if (specific_case.ParamsAccepted(op_type, kernel, strides)) { + return specific_case.ideal_workgroup; + } + } + + // Try to find the ideal workgroup by the operation type + for (const auto& default_case : + kIdealWorkgroupsInfoPtr->at(gpu_model).by_type) { + if (default_case.ParamsAccepted(op_type)) { + return default_case.ideal_workgroup; + } + } + + // If no ideal workgroup is found, use the default workgroup suggested by each + // operation. + return default_wg; +} + +uint3 GetIdealWorkgroupIfPossible(GpuModel gpu_model, OperationType op_type, + HW kernel, HW strides, OHWI workload) { + return GetIdealWorkgroupIfPossible(gpu_model, op_type, kernel, strides, + kEmptyWorkgroupSize, workload); +} + +} // namespace gl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h new file mode 100644 index 00000000000..34461bdab50 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl/workgroups/ideal_workgroup_picker.h @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_IDEAL_WORKGROUP_PICKER_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_IDEAL_WORKGROUP_PICKER_H_ + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/gpu_info.h" + +namespace tflite { +namespace gpu { +namespace gl { + +// Picks up the ideal workgroup size for the given convolution case. +// Ideal workgroup gives top 10% of the possible performance for the given case. +// They are received after the workgroup performance research (b/117291356). +uint3 GetIdealWorkgroupIfPossible(GpuModel gpu_model, OperationType op_type, + HW kernel, HW strides, OHWI workload); + +// Does the same as the function above. Use this one if your operation can +// suggest some reasonable workgroup size. It's expected to give better +// performance than the default workgroup calculator. +uint3 GetIdealWorkgroupIfPossible(GpuModel gpu_model, OperationType op_type, + HW kernel, HW strides, uint3 default_wg, + OHWI workload); + +} // namespace gl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_WORKGROUPS_IDEAL_WORKGROUP_PICKER_H_ diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc new file mode 100644 index 00000000000..a3ac65d6213 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc @@ -0,0 +1,498 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/types/span.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_builder.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/gpu/gl/api.h" +#include "tensorflow/lite/delegates/gpu/gl/command_queue.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler.h" +#include "tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h" +#include "tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h" +#include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/kernels/registry.h" +#include "tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h" +#include "tensorflow/lite/minimal_logging.h" + +#ifndef TFLITE_GPU_BINARY_RELEASE +#include "flatbuffers/flatbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/delegates/gpu/gl/metadata_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" +#endif // TFLITE_GPU_BINARY_RELEASE + +namespace tflite { +namespace gpu { +namespace gl { +namespace { + +// Forward declarations. +TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate); +TfLiteStatus DelegateCopyFromBufferHandle( + TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, // ValueId + TfLiteTensor* tensor); +TfLiteStatus DelegateCopyToBufferHandle( + TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, // ValueId + TfLiteTensor* tensor); + +inline bool IsPHWC4(const BHWC& shape) { + return shape.c == 4 || (shape.h == 1 && shape.w == 1 && shape.c % 4 == 0); +} + +class Delegate { + struct ValueRef { + BHWC shape; + int tensor_index; + }; + + public: + explicit Delegate(const TfLiteGpuDelegateOptions* options) { + if (options) { + options_ = *options; + } else { + // Default options. 
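+      // With these defaults precision loss is disallowed, the fastest GL
+      // object type is picked, and dynamic batch is off. They apply when the
+      // caller passes no options, e.g. in a minimal (illustrative) setup:
+      //
+      //   TfLiteDelegate* delegate = TfLiteGpuDelegateCreate(nullptr);
+      //   interpreter->ModifyGraphWithDelegate(delegate);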
+ options_.metadata = nullptr; + options_.compile_options.precision_loss_allowed = 0; + options_.compile_options.preferred_gl_object_type = + TFLITE_GL_OBJECT_TYPE_FASTEST; + options_.compile_options.dynamic_batch_enabled = 0; + } + } + + Status CopyFromBufferHandle(TfLiteBufferHandle handle, TfLiteTensor* tensor) { + ValueRef ref; + RETURN_IF_ERROR(FindObject(handle, &ref)); + auto buffer = phwc4_objects_.FindBuffer(handle); + return buffer->MappedRead([&](absl::Span data) { + tensor->data_is_stale = false; + return ConvertFromPHWC4( + data, ref.shape, + absl::MakeSpan(tensor->data.f, tensor->bytes / sizeof(float))); + }); + } + + Status CopyToBufferHandle(TfLiteBufferHandle handle, + TfLiteTensor* tensor) const { + ValueRef ref; + RETURN_IF_ERROR(FindObject(handle, &ref)); + auto buffer = phwc4_objects_.FindBuffer(handle); + return buffer->MappedWrite([&](absl::Span data) { + return ConvertToPHWC4( + absl::MakeConstSpan(tensor->data.f, tensor->bytes / sizeof(float)), + ref.shape, data); + }); + } + + Status BindBufferToTensor(GLuint ssbo, int tensor_index) { + int64_t bytes_size; + { + gl_buffer_internal::BufferBinder binder(GL_SHADER_STORAGE_BUFFER, ssbo); + RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetBufferParameteri64v, + GL_SHADER_STORAGE_BUFFER, + GL_BUFFER_SIZE, &bytes_size)); + } + return bhwc_objects_.RegisterBuffer( + tensor_index, GlBuffer(GL_SHADER_STORAGE_BUFFER, ssbo, bytes_size, + /* offset = */ 0, + /* has_ownership = */ false)); + } + + Status Prepare(TfLiteContext* context, + const TfLiteDelegateParams* delegate_params) { + // Extract TFLite delegate execution plan from the context and convert it + // into FlowGraph32. + GraphFloat32 graph; + RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph)); + + // Apply general transformations on the graph. + NullTransformationReporter reporter; + ModelTransformer transformer(&graph, &reporter); + if (!ApplyGeneralTransformations(&transformer)) { + return InternalError("Graph general transformations failed"); + } + + if (!env_) RETURN_IF_ERROR(EglEnvironment::NewEglEnvironment(&env_)); + + // TODO(impjdi): Remove code duplication. + auto values = graph.values(); + auto find_value = [&](int tensor_index) -> Value* { + for (auto value : values) { + if (value->tensor.ref == tensor_index) return value; + } + return nullptr; + }; + tensors_.reserve(values.back()->id + 1); + for (auto value : values) { + if (tensors_.size() <= value->id) { + tensors_.resize(value->id + 1); + } + tensors_[value->id] = {value->tensor.shape, 0}; + } + + // Prepare graph inputs. + // + // Note that graph.inputs() cannot be used directly, as the notion of + // graph input has a different meaning in public API and GPU-internal API. + { + inputs_.reserve(delegate_params->input_tensors->size); + for (int i = 0; i < delegate_params->input_tensors->size; ++i) { + const int tensor_index = delegate_params->input_tensors->data[i]; + auto* tensor = context->tensors + tensor_index; + if (tensor->allocation_type == TfLiteAllocationType::kTfLiteMmapRo) { + continue; + } + const auto* input = find_value(tensor_index); + if (!input || tensor->type != TfLiteType::kTfLiteFloat32) { + return NotFoundError("Input tensor is not found in the graph."); + } + + inputs_.push_back(input->id); + tensor->buffer_handle = input->id; + tensor->delegate = &delegate_; + tensors_[input->id].tensor_index = tensor_index; + + // Create phwc4 input buffer. + // Check whether there is externally provided object is already in + // PHWC4. If yes, we may skip conversion step. 
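+        // For illustration (see IsPHWC4() above; BHWC is (b, h, w, c)):
+        //   IsPHWC4(BHWC(1, 8, 8, 4));  // true:  c == 4, no conversion
+        //   IsPHWC4(BHWC(1, 8, 8, 3));  // false: conversion required
+        //   IsPHWC4(BHWC(1, 1, 1, 8));  // true:  1x1 spatial, c % 4 == 0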
+ // We need to keep same buffer in bhwc_objects_ to indicate there is + // externally provided buffer. + auto external_buffer = bhwc_objects_.FindBuffer(tensor_index); + GlBuffer buffer; + if (IsPHWC4(input->tensor.shape) && external_buffer) { + buffer = external_buffer->MakeRef(); + } else { + RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( + GetElementsSizeForPHWC4(input->tensor.shape), &buffer)); + } + RETURN_IF_ERROR( + phwc4_objects_.RegisterBuffer(input->id, std::move(buffer))); + } + } + + // Prepare graph outputs. + // + // Note that graph.outputs() cannot be used directly, as the notion of + // graph output has a different meaning in public API and GPU-internal API. + { + outputs_.reserve(delegate_params->output_tensors->size); + for (int i = 0; i < delegate_params->output_tensors->size; ++i) { + const int tensor_index = delegate_params->output_tensors->data[i]; + auto* tensor = context->tensors + tensor_index; + const auto* output = find_value(tensor_index); + if (!output || tensor->type != TfLiteType::kTfLiteFloat32) { + return NotFoundError("Output tensor is not found in the graph."); + } + + outputs_.push_back(output->id); + tensor->buffer_handle = output->id; + tensor->delegate = &delegate_; + tensors_[output->id].tensor_index = tensor_index; + + // Create phwc4 output buffer. + // Check whether there is externally provided object is already in + // PHWC4. If yes, we may skip conversion step. + auto external_buffer = bhwc_objects_.FindBuffer(tensor_index); + GlBuffer buffer; + if (IsPHWC4(output->tensor.shape) && external_buffer) { + buffer = external_buffer->MakeRef(); + } else { + RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( + GetElementsSizeForPHWC4(output->tensor.shape), &buffer)); + } + RETURN_IF_ERROR( + phwc4_objects_.RegisterBuffer(output->id, std::move(buffer))); + } + } + + // Create shaders to convert from/to phwc4. + RETURN_IF_ERROR(ConverterBhwcToPhwc4::Create(&bhwc_to_phwc4_)); + RETURN_IF_ERROR(ConverterPhwc4ToBhwc::Create(&phwc4_to_bhwc_)); + + // Compile model. + CompilationOptions compile_options; + compile_options.allow_precision_loss = + static_cast(options_.compile_options.precision_loss_allowed); + compile_options.preferred_obj_type = static_cast( + options_.compile_options.preferred_gl_object_type); + compile_options.ref_obj_type = static_cast( + options_.compile_options.preferred_gl_object_type); + compile_options.dynamic_batch = + static_cast(options_.compile_options.dynamic_batch_enabled); + auto shaders = NewNodeShaderRegistry(); + GpuInfo gpu_info; + RETURN_IF_ERROR(RequestGpuInfo(&gpu_info)); + command_queue_ = NewCommandQueue(gpu_info); + auto workgroups_calculator = + BestEffortWorkgroupsCalculator(options_.metadata, gpu_info); + std::unique_ptr compiled_model; + RETURN_IF_ERROR(Compile(compile_options, graph, *shaders, + *workgroups_calculator, &compiled_model)); + + // Create inference context. + const RuntimeOptions runtime_options; + RETURN_IF_ERROR(compiled_model->NewRun(runtime_options, &phwc4_objects_, + command_queue_.get(), + &inference_context_)); + return OkStatus(); + } + + Status Invoke(TfLiteContext* context) { + const EGLContext egl_context_at_delegate_init = env_->context().context(); + const EGLContext egl_context_at_delegate_invoke = eglGetCurrentContext(); + if (egl_context_at_delegate_init != egl_context_at_delegate_invoke) { + return FailedPreconditionError( + "Delegate should run on the same thread where it was initialized."); + } + + // Push input data from a tensor to GPU. 
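+    // Inputs registered through BindBufferToTensor() stay on the GPU (only a
+    // layout conversion to PHWC4 may be needed); all other inputs are copied
+    // from their CPU tensors via CopyToBufferHandle().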
+ for (ValueId id : inputs_) { + const ValueRef& ref = tensors_[id]; + auto external_object = bhwc_objects_.FindBuffer(ref.tensor_index); + if (external_object) { + // Use input from GPU. + // Conversion is needed only when external object is not phwc4. + if (!IsPHWC4(tensors_[id].shape)) { + RETURN_IF_ERROR(bhwc_to_phwc4_.Convert( + ref.shape, *external_object, command_queue_.get(), + phwc4_objects_.FindBuffer(id))); + } + } else { + // Copy from CPU to GPU + TfLiteTensor& tensor = context->tensors[ref.tensor_index]; + RETURN_IF_ERROR(CopyToBufferHandle(id, &tensor)); + } + } + + // Run inference. + RETURN_IF_ERROR(inference_context_->Reset()); + RETURN_IF_ERROR(inference_context_->Execute()); + + // Push output data from GPU to a tensor. + bool finished_gpu_processing = false; + for (ValueId id : outputs_) { + const ValueRef& ref = tensors_[id]; + auto external_object = bhwc_objects_.FindBuffer(ref.tensor_index); + if (external_object) { + // Convert data from PHWC4 to BHWC and leave it in GPU object. + // Conversion is needed only when external object is not phwc4. + if (!IsPHWC4(tensors_[id].shape)) { + RETURN_IF_ERROR( + phwc4_to_bhwc_.Convert(ref.shape, *phwc4_objects_.FindBuffer(id), + command_queue_.get(), external_object)); + } + } else { + // Wait until all GPU command are completed. This call leads to a lower + // processing latency because a buffer reading below will not stall if + // data is not yet ready. + if (!finished_gpu_processing) { + RETURN_IF_ERROR(command_queue_->WaitForCompletion()); + finished_gpu_processing = true; + } + // Copy from GPU to CPU. + TfLiteTensor& tensor = context->tensors[ref.tensor_index]; + RETURN_IF_ERROR(CopyFromBufferHandle(id, &tensor)); + } + } + return OkStatus(); + } + + TfLiteDelegate* tflite_delegate() { return &delegate_; } + + private: + Status FindObject(ValueId id, ValueRef* ref) const { + if (id >= tensors_.size()) { + return InvalidArgumentError("Invalid buffer id"); + } + *ref = tensors_[id]; + return OkStatus(); + } + + TfLiteDelegate delegate_ = { + reinterpret_cast(this), // .data_ + DelegatePrepare, // .Prepare + DelegateCopyFromBufferHandle, // .CopyFromBufferHandle + DelegateCopyToBufferHandle, // .CopyToBufferHandle + nullptr, // .FreeBufferHandle + kTfLiteDelegateFlagsNone, // .flags + }; + + TfLiteGpuDelegateOptions options_; + + std::unique_ptr env_; + std::vector tensors_; // indexed by ValueId + std::vector inputs_; + std::vector outputs_; + ObjectManager phwc4_objects_; + ObjectManager bhwc_objects_; // key is tensor_index + ConverterPhwc4ToBhwc phwc4_to_bhwc_; + ConverterBhwcToPhwc4 bhwc_to_phwc4_; + std::unique_ptr command_queue_; + std::unique_ptr inference_context_; +}; + +// TODO(impjdi): Merge with MetalDelegate. 
+bool IsAllFloatTensors(const TfLiteContext* context, + const TfLiteIntArray* array) { + for (int i = 0; i < array->size; ++i) { + const TfLiteTensor* t = context->tensors + array->data[i]; + if (t->allocation_type == kTfLiteArenaRw && t->type != kTfLiteFloat32) { + return false; + } + } + return true; +} + +inline Delegate* GetGpuDelegate(TfLiteNode* node) { + return reinterpret_cast(node->user_data); +} + +inline Delegate* GetGpuDelegate(TfLiteDelegate* delegate) { + return reinterpret_cast(delegate->data_); +} + +TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { + const TfLiteRegistration kRegistration = { + // .init + [](TfLiteContext* context, const char* buffer, size_t) -> void* { + const auto* params = + reinterpret_cast(buffer); + auto* gpu_delegate = GetGpuDelegate(params->delegate); + // Everything below should happen in prepare function call, but TFLite + // for whatever reason forbids that. + const auto status = gpu_delegate->Prepare(context, params); + if (status.ok()) return gpu_delegate; + context->ReportError(context, "TfLiteGpuDelegate Prepare: %s", + status.error_message().c_str()); + return nullptr; + }, + // .free + [](TfLiteContext*, void* buffer) -> void {}, + // .prepare + [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { + return node->user_data ? kTfLiteOk : kTfLiteError; + }, + // .invoke + [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { + const auto status = GetGpuDelegate(node)->Invoke(context); + if (status.ok()) return kTfLiteOk; + context->ReportError(context, "TfLiteGpuDelegate Invoke: %s", + status.error_message().c_str()); + return kTfLiteError; + }, + nullptr, // .profiling_string + 0, // .builtin_code + "TfLiteGpuDelegate", // .custom_name + 1, // .version + }; + TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); + const auto status = context->ReplaceNodeSubsetsWithDelegateKernels( + context, kRegistration, ops_to_replace, delegate); + TfLiteIntArrayFree(ops_to_replace); + return status; +} + +TfLiteStatus DelegateCopyFromBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) { + auto* gpu_delegate = GetGpuDelegate(delegate); + if (!gpu_delegate) return kTfLiteError; + const auto status = gpu_delegate->CopyFromBufferHandle(buffer_handle, tensor); + if (status.ok()) return kTfLiteOk; + context->ReportError(context, "TfLiteGpuDelegate CopyFromBufferHandle: %s", + status.error_message().c_str()); + return kTfLiteError; +} + +TfLiteStatus DelegateCopyToBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) { + auto* gpu_delegate = GetGpuDelegate(delegate); + if (!gpu_delegate) return kTfLiteError; + const auto status = gpu_delegate->CopyToBufferHandle(buffer_handle, tensor); + if (status.ok()) return kTfLiteOk; + context->ReportError(context, "TfLiteGpuDelegate CopyToBufferHandle: %s", + status.error_message().c_str()); + return kTfLiteError; +} + +} // namespace +} // namespace gl +} // namespace gpu +} // namespace tflite + +TfLiteDelegate* TfLiteGpuDelegateCreate( + const TfLiteGpuDelegateOptions* options) { + TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, + "Created TensorFlow Lite delegate for GPU."); + auto* gpu_delegate = new tflite::gpu::gl::Delegate(options); + return gpu_delegate ? 
gpu_delegate->tflite_delegate() : nullptr; +} + +void TfLiteGpuDelegateDelete(TfLiteDelegate* delegate) { + delete tflite::gpu::gl::GetGpuDelegate(delegate); +} + +TfLiteStatus TfLiteGpuDelegateBindBufferToTensor(TfLiteDelegate* delegate, + GLuint buffer, + int tensor_index) { + auto* gpu_delegate = tflite::gpu::gl::GetGpuDelegate(delegate); + return gpu_delegate && + gpu_delegate->BindBufferToTensor(buffer, tensor_index).ok() + ? kTfLiteOk + : kTfLiteError; +} + +#ifndef TFLITE_GPU_BINARY_RELEASE +const uint8_t* TfLiteGpuDelegateGetModelMetadata(const void* tflite_model) { + const auto* model = reinterpret_cast(tflite_model); + if (!model || !model->metadata_buffer() || !model->buffers()) return nullptr; + for (int32_t buffer_index : *model->metadata_buffer()) { + if (buffer_index < 0 && buffer_index >= model->buffers()->size()) continue; + const tflite::Buffer* buffer = model->buffers()->Get(buffer_index); + if (!buffer) continue; + const uint8_t* data = buffer->data()->data(); + if (!flatbuffers::BufferHasIdentifier( + data, tflite::gpu::gl::data::FlowMetadataIdentifier())) { + continue; + } + flatbuffers::Verifier verifier(data, buffer->data()->size()); + return tflite::gpu::gl::data::VerifyFlowMetadataBuffer(verifier) ? data + : nullptr; + } + return nullptr; +} +#endif // TFLITE_GPU_BINARY_RELEASE diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.h b/tensorflow/lite/delegates/gpu/gl_delegate.h new file mode 100644 index 00000000000..aa78e1b9804 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/gl_delegate.h @@ -0,0 +1,114 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_DELEGATE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_GL_DELEGATE_H_ + +#include + +#include +#include "tensorflow/lite/c/c_api_internal.h" + +#ifdef SWIG +#define TFL_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TFL_CAPI_EXPORT __declspec(dllexport) +#else +#define TFL_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TFL_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// LINT.IfChange +enum TfLiteGlObjectType { + TFLITE_GL_OBJECT_TYPE_FASTEST = 0, + TFLITE_GL_OBJECT_TYPE_TEXTURE = 1, + TFLITE_GL_OBJECT_TYPE_BUFFER = 2, +}; + +// Shader compilation options. +// TODO(impjdi): Unify with opengl::CompilationOptions. +struct TFL_CAPI_EXPORT TfLiteGlCompileOptions { + // When set to zero, computations are carried out in 32-bit floating point. + // Otherwise, the GPU may quantify tensors, downcast values, process in FP16 + // (recommended). + int32_t precision_loss_allowed; + + // User's preferred GL object to represent tensors. When set to: + // * `TFLITE_GL_OBJECT_TYPE_FASTEST`, the delegate chooses a GL object type + // automatically that will perform fastest (recommended). 
+ // * `TFLITE_GL_OBJECT_TYPE_TEXTURE`: GL textures are used to represent + // tensors which often work faster on Adreno-based devices, but may use more + // memory. + // * `TFLITE_GL_OBJECT_TYPE_BUFFER`: GL shader storage buffer objects are used + // to represent tensors. + int32_t preferred_gl_object_type; + + // When set to zero, dynamic batching is disabled and input/output tensors + // must have a batch size of 1 (probably what you unless you use LSTMs). + // Otherwise, enables dynamic batching and input/output tensor can have a + // batch size greater than 1. + int32_t dynamic_batch_enabled; +}; + +struct TFL_CAPI_EXPORT TfLiteGpuDelegateOptions { + const uint8_t* metadata; // Internal. + TfLiteGlCompileOptions compile_options; +}; +// LINT.ThenChange(//tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java) + +// Creates a new delegate instance that need to be destroyed with +// TfLiteGpuDelegateDelete when delegate is no longer used by TFLite. +// When `options` is set to `nullptr`, the following default values are used: +// .metadata = nullptr, +// .compile_options = { +// .precision_loss_allowed = false, +// .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST, +// .dynamic_batch_enabled = false, +// }, +TFL_CAPI_EXPORT TfLiteDelegate* TfLiteGpuDelegateCreate( + const TfLiteGpuDelegateOptions* options); + +// Destroys a delegate created with `TfLiteGpuDelegateCreate` call. +TFL_CAPI_EXPORT void TfLiteGpuDelegateDelete(TfLiteDelegate* delegate); + +// Binds GL shader storage object to an input or an output tensor in the +// initialized delegate. Bound buffer should have sufficient storage to +// accommodate all elements of a tensor. +// +// *** Must be called *before* `Interpreter::ModifyGraphWithDelegate`. *** +TFL_CAPI_EXPORT TfLiteStatus TfLiteGpuDelegateBindBufferToTensor( + TfLiteDelegate* delegate, GLuint buffer, int tensor_index); + +#ifndef TFLITE_GPU_BINARY_RELEASE +// Returns the metadata of `tflite_model` if it has one, or `nullptr` otherwise. +// Designed to be used with `TfLiteGpuDelegateOptions.metadata`. +TFL_CAPI_EXPORT const uint8_t* TfLiteGpuDelegateGetModelMetadata( + const void* tflite_model); +#endif // TFLITE_GPU_BINARY_RELEASE + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_DELEGATE_H_ diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD new file mode 100644 index 00000000000..f1739169055 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/BUILD @@ -0,0 +1,7 @@ +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "gpu_delegate", + srcs = ["GpuDelegate.java"], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java new file mode 100644 index 00000000000..b19dc346fcb --- /dev/null +++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java @@ -0,0 +1,165 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.gpu; + +import java.io.Closeable; +import org.tensorflow.lite.Delegate; +import org.tensorflow.lite.Tensor; + +/** {@link Delegate} for GPU inference. */ +public class GpuDelegate implements Delegate, Closeable { + + private static final long INVALID_DELEGATE_HANDLE = 0; + private static final String TFLITE_GPU_LIB = "tensorflowlite_gpu_jni"; + + private long delegateHandle; + + /** Shader compilation options. */ + public static final class CompileOptions { + public CompileOptions() {} + + /** Delegate chooses fastest GL object type to represent tensors (default). */ + public static final int GL_OBJECT_TYPE_FASTEST = 0; + /** + * Delegate uses GL textures to represent tensors, which works faster on Adreno-based devices, + * but may use more memory. + */ + public static final int GL_OBJECT_TYPE_TEXTURE = 1; + /** Delegate uses GL shader storage buffer objects to represent tensors. */ + public static final int GL_OBJECT_TYPE_BUFFER = 2; + + /** + * Sets whether precision loss is allowed. + * + * @param precisionLossAllowed When `true` (default), the GPU may quantify tensors, downcast + * values, process in FP16. When `false`, computations are carried out in 32-bit floating + * point. + */ + public CompileOptions setPrecisionLossAllowed(boolean precisionLossAllowed) { + this.precisionLossAllowed = precisionLossAllowed; + return this; + } + + /** + * Sets whether dynamic batch is enabled. + * + * @param dynamicBatchEnabled When `false` (default), dynamic batching is disabled and + * input/output tensors must have a batch size of 1 (probably what you want, unless you use + * LSTMs). When `true`, enables dynamic batching and input/output tensor can have a batch + * size greater than 1. + */ + public CompileOptions setDynamicBatchEnabled(boolean dynamicBatchEnabled) { + this.dynamicBatchEnabled = dynamicBatchEnabled; + return this; + } + + /** + * Sets the preferred GL object type for tensor representation + * + * @param preferredGlObjectType One of `GL_OBJECT_TYPE_FASTEST` (default), + * `GL_OBJECT_TYPE_TEXTURE`, `GL_OBJECT_TYPE_BUFFER`. + */ + public CompileOptions setPreferredGlObjectType(int preferredGlObjectType) { + this.preferredGlObjectType = preferredGlObjectType; + return this; + } + + boolean precisionLossAllowed = true; + boolean dynamicBatchEnabled = false; + int preferredGlObjectType = GL_OBJECT_TYPE_FASTEST; + } + + /** Delegate options. */ + public static final class Options { + public Options() {} + + private static final CompileOptions DEFAULT_COMPILE_OPTIONS = new CompileOptions(); + + /** + * Sets the shader compilation options to be used by the delegate. + * + * @param compileOptions the {@link CompileOptions} to use. + */ + public Options setCompileOptions(CompileOptions compileOptions) { + this.compileOptions = compileOptions != null ? 
compileOptions : DEFAULT_COMPILE_OPTIONS; + return this; + } + + CompileOptions compileOptions = DEFAULT_COMPILE_OPTIONS; + } + + public GpuDelegate(Options options) { + delegateHandle = + createDelegate( + options.compileOptions.precisionLossAllowed, + options.compileOptions.dynamicBatchEnabled, + options.compileOptions.preferredGlObjectType); + } + + public GpuDelegate() { + this(new Options()); + } + + /** + * Advanced: Binds a GL SSBO to an input or an output tensor in the initialized delegate. + * + *

The bound buffer should have sufficient storage to accommodate all elements of the tensor. + * + *

Note: This method must be called *before* the delegate instance is installed in the + * {@link Interpreter}. + * + *

WARNING: This is an experimental API and subject to change. + * + * @param tensor The input or output {@link Tensor} to bind to the buffer object. + * @param ssbo The GL buffer object to bind to the tensor. See also {@link + * Interpreter.Options#setAllowBufferHandleOutput()} for details on allowing zero-copy output + * when GL textures are bound to output tensors. + * @return Whether the operation succeeded. + */ + public boolean bindGlBufferToTensor(Tensor tensor, int ssbo) { + return bindGlBufferToTensor(delegateHandle, tensor.index(), ssbo); + } + + @Override + public long getNativeHandle() { + return delegateHandle; + } + + /** + * Frees TFLite resources in C runtime. + * + *

User is expected to call this method explicitly. + */ + @Override + public void close() { + if (delegateHandle != INVALID_DELEGATE_HANDLE) { + deleteDelegate(delegateHandle); + delegateHandle = INVALID_DELEGATE_HANDLE; + } + } + + static { + System.loadLibrary(TFLITE_GPU_LIB); + } + + private static native long createDelegate( + boolean precisionLossAllowed, boolean dynamicBatchEnabled, int preferredGlObjectType); + + private static native void deleteDelegate(long delegateHandle); + + private static native boolean bindGlBufferToTensor( + long delegateHandle, int tensorIndex, int ssbo); +} diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD b/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD new file mode 100644 index 00000000000..7f8162275f3 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/java/src/main/native/BUILD @@ -0,0 +1,31 @@ +# Description: +# Java Native Interface (JNI) library intended for implementing the +# TensorFlow Lite GPU delegate Java API using the TensorFlow Lite CC library. + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow/lite:build_def.bzl", "tflite_copts") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "native", + srcs = ["gpu_delegate_jni.cc"], + copts = tflite_copts(), + linkopts = select({ + "//tensorflow:android": [ + "-lGLESv3", + "-lEGL", + ], + "//conditions:default": [], + }), + tags = [ + "manual", + "notap", + ], + deps = [ + "//tensorflow/lite/delegates/gpu:gl_delegate", + "//tensorflow/lite/java/jni", + ], + alwayslink = 1, +) diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc new file mode 100644 index 00000000000..51e3ce130a8 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc @@ -0,0 +1,55 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate( + JNIEnv* env, jclass clazz, jboolean precision_loss_allowed, + jboolean dynamic_batch_enabled, jint preferred_gl_object_type) { + TfLiteGpuDelegateOptions options; + options.metadata = nullptr; + options.compile_options.precision_loss_allowed = + precision_loss_allowed == JNI_TRUE ? 1 : 0; + options.compile_options.preferred_gl_object_type = + static_cast(preferred_gl_object_type); + options.compile_options.dynamic_batch_enabled = + dynamic_batch_enabled == JNI_TRUE ? 
1 : 0; + return reinterpret_cast(TfLiteGpuDelegateCreate(&options)); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_deleteDelegate( + JNIEnv* env, jclass clazz, jlong delegate) { + TfLiteGpuDelegateDelete(reinterpret_cast(delegate)); +} + +JNIEXPORT jboolean JNICALL +Java_org_tensorflow_lite_gpu_GpuDelegate_bindGlBufferToTensor( + JNIEnv* env, jclass clazz, jlong delegate, jint tensor_index, jint ssbo) { + return TfLiteGpuDelegateBindBufferToTensor( + reinterpret_cast(delegate), + static_cast(ssbo), static_cast(tensor_index)) + ? JNI_TRUE + : JNI_FALSE; +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD new file mode 100644 index 00000000000..855107bacdb --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/BUILD @@ -0,0 +1,139 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +DEFAULT_COPTS = [ + "-std=c++11", + "-Wno-shorten-64-to-32", +] + +cc_library( + name = "api", + srcs = ["api.cc"], + hdrs = ["api.h"], + deps = [ + ":compiled_model", + ":compute_task_descriptor", + ":environment", + ":runtime_options", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/metal/kernels", + ], +) + +objc_library( + name = "buffer_convert", + srcs = ["buffer_convert.mm"], + hdrs = ["buffer_convert.h"], + copts = DEFAULT_COPTS, + sdk_frameworks = [ + "Metal", + ], + deps = [ + ":common", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:util", + ], +) + +objc_library( + name = "common", + srcs = ["common.mm"], + hdrs = ["common.h"], + copts = DEFAULT_COPTS, + sdk_frameworks = [ + "Metal", + "UIKit", + ], + deps = [ + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + +cc_library( + name = "compiled_model", + srcs = ["compiled_model.cc"], + hdrs = ["compiled_model.h"], + deps = [ + ":compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/strings", + ], +) + +objc_library( + name = "compute_task", + srcs = ["compute_task.mm"], + hdrs = ["compute_task.h"], + copts = DEFAULT_COPTS, + sdk_frameworks = ["Metal"], + deps = [ + ":common", + ":compute_task_descriptor", + ":runtime_options", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + ], +) + +objc_library( + name = "compute_task_descriptor", + srcs = ["compute_task_descriptor.cc"], + hdrs = ["compute_task_descriptor.h"], + copts = DEFAULT_COPTS, + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@FP16", + ], +) + +objc_library( + name = "environment", + srcs = ["environment.mm"], + hdrs = ["environment.h"], + copts = DEFAULT_COPTS, + sdk_frameworks = [ + "Metal", + "UIKit", + ], + deps = [ + ":common", + ], +) + +objc_library( + name = 
"inference_context", + srcs = ["inference_context.mm"], + hdrs = ["inference_context.h"], + copts = DEFAULT_COPTS, + sdk_frameworks = ["Metal"], + deps = [ + ":compute_task", + ":compute_task_descriptor", + ":runtime_options", + "//tensorflow/lite/delegates/gpu/common:memory_management", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "runtime_options", + hdrs = ["runtime_options.h"], +) diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc new file mode 100644 index 00000000000..3588cd97169 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -0,0 +1,270 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/api.h" + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/concat.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/relu.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/reshape.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/slice.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/upsample.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +std::vector SelectConvolution( + const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id, + const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) { + // Special precise version, in case we cover dst_shape poorly 
with standard + // work group size. + const auto dst_shape = graph.FindOutputs(id)[0]->tensor.shape; + if (GetThreadsRatioUsualToPreciseConvolution(dst_shape) >= 1.2f) { + // Special version for PowerVR >= IPhone6S/SE + // Metal has bad driver for PowerVR in IPhone6, so for Iphone6 we should use + // default kernel with shared memory. + if ((GetAppleSocVersion() == 9 || GetAppleSocVersion() == 10) && + CheckConvolutionPrecise1x1Support(attr)) { + return ConvolutionPrecise1x1PowerVR(id, input_id, output_id, attr, + options); + } + if (GetAppleSocVersion() >= 11 && + GetThreadsRatioUsualToPreciseConvolution(dst_shape) >= 1.2f) { + return ConvolutionPrecise(id, input_id, output_id, attr, options); + } + } + if (GetAppleSocVersion() >= 11) { + if (CheckConvolution1x1Support(attr)) { + return Convolution1x1(id, input_id, output_id, attr, options); + } else { + return ConvolutionGeneric(id, input_id, output_id, attr, options); + } + } else { + return Convolution(id, input_id, output_id, attr, options); + } +} + +std::vector SelectDepthWiseConv( + int id, ValueId input_id, ValueId output_id, + const DepthwiseConvolution2DAttributes& attr, + const metal::RuntimeOptions& options) { + if (CheckDepthWiseConv3x3Stride1x1Support(attr)) { + return DepthWiseConv3x3Stride1x1(id, input_id, output_id, attr, options); + } else if (CheckDepthWiseConv3x3Stride2Support(attr)) { + return DepthWiseConv3x3Stride2(id, input_id, output_id, attr, options); + } else { + return DepthWiseConvolution(id, input_id, output_id, attr, options); + } +} + +std::vector SelectReshape( + const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id, + const ReshapeAttributes& attr) { + const auto src_shape = graph.FindInputs(id)[0]->tensor.shape; + if (src_shape.c % 4 == 0 && attr.new_shape.c % 4 == 0) { + return Reshapex4(id, input_id, output_id, attr); + } else { + return Reshape(id, input_id, output_id, attr); + } +} + +std::vector SelectSoftmax(const GraphFloat32& graph, + int id, ValueId input_id, + ValueId output_id) { + const auto src_shape = graph.FindInputs(id)[0]->tensor.shape; + if (src_shape.w == 1 && src_shape.h == 1) { + return Softmax1x1(id, input_id, output_id, src_shape.c); + } else { + return Softmax(id, input_id, output_id, src_shape.c); + } +} + +} // namespace + +Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, + CompiledModel* compiled_model) { + for (const auto& node : graph.nodes()) { + int node_id = static_cast(node->id); + std::vector inputs; + for (auto& input : graph.FindInputs(node->id)) { + inputs.push_back(static_cast(input->id)); + } + std::vector outputs; + for (auto& output : graph.FindOutputs(node->id)) { + outputs.push_back(static_cast(output->id)); + } + + std::vector tasks; + auto op_type = OperationTypeFromString(node->operation.type); + switch (op_type) { + case OperationType::ADD: + tasks = AddTable(node_id, inputs, outputs[0]); + break; + case OperationType::CONCAT: { + std::vector input_shapes; + for (auto& input : graph.FindInputs(node->id)) { + input_shapes.push_back(input->tensor.shape); + } + tasks = + Concat(node_id, inputs, outputs[0], + absl::any_cast(node->operation.attributes), + input_shapes); + break; + } + case OperationType::CONVOLUTION_2D: + tasks = SelectConvolution( + graph, node_id, inputs[0], outputs[0], + absl::any_cast(node->operation.attributes), + options); + break; + case OperationType::CONVOLUTION_TRANSPOSED: + tasks = ConvolutionTransposed( + node_id, inputs[0], outputs[0], + absl::any_cast( + node->operation.attributes), + 
options); + break; + case OperationType::DEPTHWISE_CONVOLUTION: + tasks = SelectDepthWiseConv( + node_id, inputs[0], outputs[0], + absl::any_cast( + node->operation.attributes), + options); + break; + case OperationType::FULLY_CONNECTED: + tasks = FullyConnected(node_id, inputs[0], outputs[0], + absl::any_cast( + node->operation.attributes), + options); + break; + case OperationType::MAX_UNPOOLING_2D: + tasks = MaxUnpooling(node_id, inputs[0], inputs[1], outputs[0], + absl::any_cast( + node->operation.attributes)); + break; + case OperationType::MULTIPLY_SCALAR: + tasks = Multiply(node_id, inputs[0], outputs[0], + absl::any_cast( + node->operation.attributes), + options); + break; + case OperationType::PAD: + tasks = + Padding(node_id, inputs[0], outputs[0], + absl::any_cast(node->operation.attributes)); + break; + case OperationType::POOLING_2D: + tasks = Pooling( + node_id, inputs[0], outputs, + absl::any_cast(node->operation.attributes)); + break; + case OperationType::PRELU: + tasks = + PReLU(node_id, inputs[0], outputs[0], + absl::any_cast(node->operation.attributes), + options); + break; + case OperationType::RELU: + tasks = + ReLU(node_id, inputs[0], outputs[0], + absl::any_cast(node->operation.attributes)); + break; + case OperationType::RESHAPE: + tasks = SelectReshape( + graph, node_id, inputs[0], outputs[0], + absl::any_cast(node->operation.attributes)); + break; + case OperationType::SLICE: + tasks = + Slice(node_id, inputs[0], outputs[0], + absl::any_cast(node->operation.attributes)); + break; + case OperationType::SOFT_MAX: { + auto attr = + absl::any_cast(node->operation.attributes); + if (attr.axis != Axis::CHANNELS) { + return UnimplementedError("Softmax supports only CHANNELS dimension"); + } + tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]); + break; + } + case OperationType::UPSAMPLE_2D: + tasks = Upsample( + node_id, inputs[0], outputs[0], + absl::any_cast(node->operation.attributes)); + break; + + case OperationType::ABS: + case OperationType::COS: + case OperationType::LOG: + case OperationType::RSQRT: + case OperationType::SIGMOID: + case OperationType::SIN: + case OperationType::SQRT: + case OperationType::SQUARE: + case OperationType::TANH: + tasks = + ElementwiseWithOneInput(node_id, inputs[0], outputs[0], op_type); + break; + + case OperationType::SUB: + case OperationType::DIV: + case OperationType::POW: + case OperationType::SQUARED_DIFF: + tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type); + break; + + case OperationType::APPLY_MASK: + case OperationType::BATCH_NORMALIZATION: + case OperationType::BATCH_TO_SPACE: + case OperationType::CONST: + case OperationType::LSTM: + case OperationType::MUL: + case OperationType::RESIZE: + case OperationType::SPACE_TO_BATCH: + case OperationType::UNKNOWN: + return UnimplementedError("Unsupported op: " + node->operation.type); + } + compiled_model->insert(compiled_model->end(), tasks.begin(), tasks.end()); + } + return OkStatus(); +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/api.h b/tensorflow/lite/delegates/gpu/metal/api.h new file mode 100644 index 00000000000..dd3c423a612 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/api.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_API_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_API_H_ + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +// Builds CompiledModel out of GraphFloat32 graph using provided RuntimeOptions. +Status Compile(const GraphFloat32& graph, const RuntimeOptions& options, + CompiledModel* compiled_model); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_API_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/buffer_convert.h b/tensorflow/lite/delegates/gpu/metal/buffer_convert.h new file mode 100644 index 00000000000..52738a9ebde --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/buffer_convert.h @@ -0,0 +1,41 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_CONVERT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_CONVERT_H_ + +#import + +#include "tensorflow/lite/delegates/gpu/common/shape.h" + +@interface TFLBufferConvert : NSObject + +/// Constructs converter from/to BHWC <-> BPHWC4 +/// @param isFloat16 the BPHWC4 buffer is in float16 format. +/// @param convertToPBHWC4 convert BHWC -> BPHWC4 if true or BPHWC4 -> BHWC instead. +- (id)initWithDevice:(id)device + isFloat16:(bool)isFloat16 + convertToPBHWC4:(bool)convertToPBHWC4; + +/// Converts from/to BHWC <-> BPHWC4 +/// @param shape shape of BHWC tensor. +- (void)convertWithEncoder:(id)encoder + shape:(const ::tflite::gpu::BHWC&)shape + sourceBuffer:(id)sourceBuffer + convertedBuffer:(id)convertedBuffer; + +@end + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_CONVERT_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/buffer_convert.mm b/tensorflow/lite/delegates/gpu/metal/buffer_convert.mm new file mode 100644 index 00000000000..8ddf78eac41 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/buffer_convert.mm @@ -0,0 +1,113 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#import "tensorflow/lite/delegates/gpu/metal/buffer_convert.h" + +#import + +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/common.h" + +using ::tflite::gpu::IntegralDivideRoundUp; +using ::tflite::gpu::BHWC; +using ::tflite::gpu::metal::CreateComputeProgram; + +@implementation TFLBufferConvert { + id _program; +} + +- (id)initWithDevice:(id)device + isFloat16:(bool)isFloat16 + convertToPBHWC4:(bool)convertToPBHWC4 { + if (self = [super init]) { + std::string shaderSource; + if (convertToPBHWC4) { + shaderSource = R"( + #include + using namespace metal; + kernel void ComputeFunction(device float* const input_buffer [[buffer(0)]], + device FLT4* output_buffer [[buffer(1)]], + constant int4& size [[buffer(2)]], + uint3 gid[[thread_position_in_grid]]) { + if (int(gid.x) >= size.x || int(gid.y) >= size.y) { + return; + } + FLT4 value = FLT4(0.0); + for (int i = 0; i < 4; i++) { + int channel = gid.z * 4 + i; + if (channel >= size.z) break; + const int bhwc_index = (gid.y * size.x + gid.x) * size.z + channel; + value[i] = input_buffer[bhwc_index]; + } + const int bphwc4_index = (gid.z * size.y + gid.y) * size.x + gid.x; + output_buffer[bphwc4_index] = value; + } + )"; + } else { + shaderSource = R"( + #include + using namespace metal; + kernel void ComputeFunction(device FLT4* const input_buffer [[buffer(0)]], + device float* output_buffer [[buffer(1)]], + constant int4& size [[buffer(2)]], + uint3 gid[[thread_position_in_grid]]) { + if (int(gid.x) >= size.x || int(gid.y) >= size.y) { + return; + } + const int bphwc4_index = (gid.z * size.y + gid.y) * size.x + gid.x; + FLT4 value = input_buffer[bphwc4_index]; + for (int i = 0; i < 4; i++) { + int channel = gid.z * 4 + i; + if (channel >= size.z) break; + const int bhwc_index = (gid.y * size.x + gid.x) * size.z + channel; + output_buffer[bhwc_index] = value[i]; + } + } + )"; + } + NSDictionary* macros = @{@"FLT4" : (isFloat16 ? 
@"half4" : @"float4")}; + NSString* code = [NSString stringWithCString:shaderSource.c_str() + encoding:[NSString defaultCStringEncoding]]; + id program; + if (CreateComputeProgram(device, code, @"ComputeFunction", macros, &program).ok()) { + _program = program; + return self; + } + } + return nil; +} + +- (void)convertWithEncoder:(id)encoder + shape:(const BHWC&)shape + sourceBuffer:(id)sourceBuffer + convertedBuffer:(id)convertedBuffer { + [encoder setComputePipelineState:_program]; + [encoder setBuffer:sourceBuffer offset:0 atIndex:0]; + [encoder setBuffer:convertedBuffer offset:0 atIndex:1]; + + std::vector uniforms = {shape.w, shape.h, shape.c, shape.b}; + [encoder setBytes:uniforms.data() length:uniforms.size() * sizeof(int) atIndex:2]; + + MTLSize group_size = MTLSizeMake(16, 16, 1); + int layers = IntegralDivideRoundUp(shape.c, 4); + int groups_x = IntegralDivideRoundUp(shape.w, group_size.width); + int groups_y = IntegralDivideRoundUp(shape.h, group_size.height); + int groups_z = IntegralDivideRoundUp(layers, group_size.depth); + MTLSize groups_count = MTLSizeMake(groups_x, groups_y, groups_z); + [encoder dispatchThreadgroups:groups_count threadsPerThreadgroup:group_size]; +} + +@end diff --git a/tensorflow/lite/delegates/gpu/metal/common.h b/tensorflow/lite/delegates/gpu/metal/common.h new file mode 100644 index 00000000000..6d5d4d2c28e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/common.h @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMMON_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMMON_H_ + +#import + +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace metal { + +/// Returns system default device on iOS or Intel GPU on macOS. +id GetBestSupportedMetalDevice(); + +/// Returns version of the GPU that supports Metal. +/// @param device Used as a parameter because mac can contain multiple devices. +/// @discussion Refer to Apple docs for MTLFeatureSet_macOS_GPUFamily1_v1 for details. +/// 1 - Intel integrated GPU the only device that is supported +int GetMacOsGpuVersion(id device); + +/// Metal compute shader compilation +/// @param device The device on which that shader program will be stored. +/// @param code Shader source. +/// @param functionName The name of the main shader function. +/// @param macros Compile-time definitions. +/// @param program A non-nil pointer to the program object that will be filled. +/// @return Returns a valid program pointer or error string. At least one pointer is valid but not +/// both. +/// @discussion The function autoselects the maximum shader language version supported by the target +/// OS. FastMath is enabled. 
+::tflite::gpu::Status CreateComputeProgram(id device, NSString* code, + NSString* functionName, + NSDictionary* macros, + id* program); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMMON_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/common.mm b/tensorflow/lite/delegates/gpu/metal/common.mm new file mode 100644 index 00000000000..c48a43db78e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/common.mm @@ -0,0 +1,123 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/common.h" + +#import + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" + +// Compile-time message: print define name and value. +#define VALUE_TO_STRING(x) #x +#define VALUE(x) VALUE_TO_STRING(x) +#define VAR_NAME_VALUE(var) #var "=" VALUE(var) + +namespace tflite { +namespace gpu { +namespace metal { + +id GetBestSupportedMetalDevice() { return MTLCreateSystemDefaultDevice(); } + +int GetMacOsGpuVersion(id device) { + const std::vector> features = { +#if defined(__MAC_10_11) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_11 + {MTLFeatureSet_macOS_GPUFamily1_v1, 1}, +#endif +#if defined(__MAC_10_12) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_12 + {MTLFeatureSet_macOS_GPUFamily1_v2, 1}, +#endif +#if defined(__MAC_10_13) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_13 + {MTLFeatureSet_macOS_GPUFamily1_v3, 1}, +#endif +#if defined(__MAC_10_14) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_14 + {MTLFeatureSet_macOS_GPUFamily1_v4, 1}, + {MTLFeatureSet_macOS_GPUFamily2_v1, 2}, +#endif + }; + for (const auto& type : features) { + if ([device supportsFeatureSet:type.first]) { + return type.second; + } + } + return 0; +} + +Status CreateComputeProgram(id device, NSString* code, NSString* functionName, + NSDictionary* macros, + id* program) { + MTLCompileOptions* options = [[MTLCompileOptions alloc] init]; + +#if (defined(__MAC_10_14) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_14) || \ + (defined(__IPHONE_12_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_12_0) || \ + (defined(__TVOS_12_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_12_0) + [options setLanguageVersion:MTLLanguageVersion2_1]; +#elif (defined(__MAC_10_13) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_13) || \ + (defined(__IPHONE_11_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_11_0) || \ + (defined(__TVOS_11_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_11_0) + [options setLanguageVersion:MTLLanguageVersion2_0]; +#elif (defined(__MAC_10_12) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_12) || \ + (defined(__IPHONE_10_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_10_0) || \ + (defined(__TVOS_10_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_10_0) + [options setLanguageVersion:MTLLanguageVersion1_2]; +#elif (defined(__MAC_10_11) && __MAC_OS_X_VERSION_MIN_REQUIRED 
>= __MAC_10_11) || \ + (defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0) || \ + (defined(__TVOS_9_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_9_0) + [options setLanguageVersion:MTLLanguageVersion1_1]; +#else +#pragma message(VAR_NAME_VALUE(__MAC_OS_X_VERSION_MIN_REQUIRED)) +#pragma message(VAR_NAME_VALUE(__IPHONE_OS_VERSION_MIN_REQUIRED)) +#pragma message(VAR_NAME_VALUE(__TV_OS_VERSION_MIN_REQUIRED)) +#if !defined(TARGET_OS_SIMULATOR) || TARGET_OS_SIMULATOR == 0 +// NOLINTBEGIN +#error \ + "The Metal delegate is not supported on current target SDK. Minimum supported os: iOS/tvOS 9.0, macOS 10.11" +// NOLINTEND +#endif +#endif + + [options setFastMathEnabled:YES]; + [options setPreprocessorMacros:macros]; + NSError* error = nil; + id library = [device newLibraryWithSource:code options:options error:&error]; + if (!library) { + NSString* errorString = + [NSString stringWithFormat:@"newLibraryWithSource: %@", [error localizedDescription]]; + return InternalError([errorString UTF8String]); + } + + id function = [library newFunctionWithName:functionName]; + if (!function) { + NSString* errorString = + [NSString stringWithFormat:@"newFunctionWithName: %@", [error localizedDescription]]; + return InternalError([errorString UTF8String]); + } + + *program = [device newComputePipelineStateWithFunction:function error:&error]; + if (!program) { + NSString* errorString = + [NSString stringWithFormat:@"newComputePipelineStateWithFunction error: %@", + [error localizedDescription]]; + return InternalError([errorString UTF8String]); + } + return OkStatus(); +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc new file mode 100644 index 00000000000..97f8dfc6fa2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc @@ -0,0 +1,584 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +// Allows to get result about the graph compilation to validate graph. This +// information helps to find a cause of performance degradation, like misfusing. +struct OptimizationInfo { + // Initial operations count before compilation. + int operations_count; + // GPU tasks count after fusion and splitting complex operations into few GPU + // subtasks. 
+ int gpu_tasks_count; + // Some operations are not used due to dependencies of the graph. + std::vector unused_operations; + // Used inputs. + std::vector input_buffer_ids; + // Unused inputs. Requested outputs do not require this inputs to be used. + std::vector unused_input_buffer_ids; + // The outputs are deducted by the graph but not requested by user. + std::vector extra_output_buffer_ids; + // Outputs that are requested but can't be calculated by the graph. + std::vector missing_output_buffer_ids; +}; + +using FusionSequence = std::vector; + +bool Contains(const std::vector& container, ValueId value) { + return std::find(container.begin(), container.end(), value) != + container.end(); +} + +template +bool Contains(const std::vector& container, ValueId value) { + for (const auto& buffer : container) { + if (buffer.id == value) { + return true; + } + } + return false; +} + +// Checks if all elements of the narrow vector exist in the wide vector. Vectors +// are expected to be unsorted. +bool Contains(const std::vector& wide, + const std::vector& narrow) { + if (narrow.empty() || narrow.size() > wide.size()) { + return false; + } + std::set wide_sorted; + wide_sorted.insert(wide.begin(), wide.end()); + for (auto element : narrow) { + if (std::find(wide.begin(), wide.end(), element) == wide.end()) { + return false; + } + } + return true; +} + +// Checks if all elements of the narrow vector exist in the wide vector. Vectors +// are expected to be unsorted. +bool Contains( + const std::vector& wide, + const std::vector& buffers) { + if (buffers.empty() || buffers.size() > wide.size()) { + return false; + } + std::set wide_sorted(wide.begin(), wide.end()); + for (const auto& buffer : buffers) { + if (!std::binary_search(wide_sorted.begin(), wide_sorted.end(), + buffer.id)) { + return false; + } + } + return true; +} + +// Examines if the second operation can be linked to the first one. Linking may +// be skipped in the situation when conflic may happen: if first operation's +// output is used by more than 1 other operation. +bool CanFuseOperations(const ComputeTaskDescriptorPtr first, + const ComputeTaskDescriptorPtr second, + const std::vector& output_ids, + const std::list& descriptors) { + int use_count = 0; + if (second->is_linkable && !Contains(output_ids, first->output_buffer.id)) { + for (auto& desc : descriptors) { + if (Contains(desc->input_buffers, first->output_buffer.id)) { + use_count++; + } + } + } + return (use_count == 1); +} + +// Takes an unsorted list of task descriptors, builds a list of chains. Each +// chain is a list of task descriptors that can be fused into a single GPU task. +// Building is started from the input IDs and building statistic is filled. +void BuildFusableChains(const std::vector& input_ids, + const std::vector& output_ids, + std::list* descriptors, + std::list* chains, + std::vector* unused_ids) { + // Proxy tasks for inputs - only output is valid on this elements. + for (auto input_id : input_ids) { + auto desc = std::make_shared(); + desc->id = 0; + desc->is_linkable = true; + desc->output_buffer = {input_id}; + chains->push_back({desc}); + } + + if (descriptors->empty()) return; + // Get all possible operations - grow-up chains. + bool added; + do { + // At least one element must be added to any chain at this step. + added = false; + for (auto it = descriptors->begin(); it != descriptors->end();) { + const ComputeTaskDescriptorPtr task_descriptor = *it; + + // Gather all outputs of all chains to check with. 
+ std::vector ready_buffer_ids; + ready_buffer_ids.reserve(chains->size()); + for (const auto& chain : *chains) { + ready_buffer_ids.push_back(chain.back()->output_buffer.id); + } + + // Check if all inputs of this operation are ready. + if (Contains(ready_buffer_ids, task_descriptor->input_buffers)) { + // Now find a chain to fuse with. + for (auto& chain : *chains) { + // We can fuse only single output for now. + if (Contains(task_descriptor->input_buffers, + chain.back()->output_buffer.id)) { + if (CanFuseOperations(chain.back(), task_descriptor, output_ids, + *descriptors)) { + chain.push_back(task_descriptor); + } else { + // Start new chain. + chains->push_back({task_descriptor}); + } + break; + } + } + + // Remove operation from original list and start from the beginning. + descriptors->erase(it); + added = true; + break; + } else { + ++it; + } + } + } while (!descriptors->empty() && added); + + unused_ids->reserve(descriptors->size()); + for (const auto& desc : *descriptors) { + unused_ids->push_back(desc->id); + } +} + +// Accepts unsorted list of chains and returns sorted list with the order of GPU +// task execution. +std::list SortChains( + const std::vector& graph_input_ids, + std::list* chains) { + std::list sorted_chains; + while (!chains->empty()) { + // Collect ready buffers. + std::vector ready_buffer_ids; + ready_buffer_ids.reserve(graph_input_ids.size() + sorted_chains.size()); + ready_buffer_ids.insert(ready_buffer_ids.begin(), graph_input_ids.begin(), + graph_input_ids.end()); + for (auto& chain : sorted_chains) { + ready_buffer_ids.push_back(chain.back()->output_buffer.id); + } + + for (auto it = chains->begin(); it != chains->end();) { + const FusionSequence& chain = *it; + + // If the input is also is the output in the same chain - eliminate + // because it used internally inside this chain only. + std::vector elements_output_buffer_ids; + elements_output_buffer_ids.reserve(chain.size()); + for (const ComputeTaskDescriptorPtr& element : chain) { + elements_output_buffer_ids.push_back(element->output_buffer.id); + } + + // Collect all inputs also for linked operations. + std::vector elements_input_buffer_ids; + for (auto element : chain) { + for (const auto& buffer : element->input_buffers) { + if (!Contains(elements_output_buffer_ids, buffer.id)) { + elements_input_buffer_ids.push_back(buffer.id); + } + } + } + + if (Contains(ready_buffer_ids, elements_input_buffer_ids)) { + // All input buffers for all elements of this chain are ready. + sorted_chains.push_back(chain); + it = chains->erase(it); + } else { + ++it; + } + } + } + return sorted_chains; +} + +// If a graph structure contains unused outputs then it can lead to unused +// operations and unused input buffers. It's not an error but some sort of +// warning. +std::vector GetUsedInputBufferIds( + const std::list& sorted_chains) { + // Match requested outputs with all outputs and intermediate buffers. 
+ std::vector output_and_intermediate_ids; + output_and_intermediate_ids.reserve(sorted_chains.size()); + std::set input_and_intermediate_ids; + for (auto it = sorted_chains.begin(); it != sorted_chains.end(); ++it) { + output_and_intermediate_ids.push_back(it->back()->output_buffer.id); + for (const auto& buffer : it->front()->input_buffers) { + input_and_intermediate_ids.insert(buffer.id); + } + } + std::vector input_ids; + for (ValueId id : input_and_intermediate_ids) { + if (!Contains(output_and_intermediate_ids, id)) { + input_ids.push_back(id); + } + } + return input_ids; +} + +// If a buffer is requested as output from the graph but the graph structure +// can't provide this buffer by output (can't deduct), that means the graph +// structure is incorrect. +std::vector GetMissingOutputBufferIds( + const std::vector& output_ids, + const std::list& sorted_chains) { + // Match requested outputs with all output and intermediate buffers. + std::vector output_and_intermediate_ids; + output_and_intermediate_ids.reserve(sorted_chains.size()); + for (auto it = sorted_chains.begin(); it != sorted_chains.end(); ++it) { + output_and_intermediate_ids.push_back(it->back()->output_buffer.id); + } + std::vector missing_output_ids; + for (ValueId id : output_ids) { + if (!Contains(output_and_intermediate_ids, id)) { + missing_output_ids.push_back(id); + } + } + return missing_output_ids; +} + +// Graph may contain leafs with outputs that are not requested. It wastes GPU +// computations. +std::vector DeductOutputBufferIds( + const std::vector& output_ids, + const std::list& sorted_chains) { + std::vector extra_output_ids; + // Detect all unused output buffers - all outputs. + for (auto it1 = sorted_chains.begin(); it1 != sorted_chains.end(); ++it1) { + bool found_as_input = false; + for (auto it2 = sorted_chains.begin(); it2 != sorted_chains.end(); ++it2) { + if (it1 != it2) { + std::vector input_ids; + for (auto element : *it2) { + for (const auto& buffer : element->input_buffers) { + input_ids.push_back(buffer.id); + } + } + if (Contains(input_ids, it1->back()->output_buffer.id)) { + found_as_input = true; + break; + } + } + } + if (!found_as_input) { + if (!Contains(output_ids, it1->back()->output_buffer.id)) { + extra_output_ids.push_back(it1->back()->output_buffer.id); + } + } + } + return extra_output_ids; +} + +// Delete all unused task descriptors that have non-requested outputs. +// TODO(chirkov): delete not the whole chain but only the last element, then +// others. +std::vector DeleteUnusedTasks(const std::vector& output_ids, + std::list* chains) { + std::vector unused_operations; + for (auto it1 = chains->rbegin(); it1 != chains->rend();) { + // Don't delete if output is requested. + if (Contains(output_ids, it1->back()->output_buffer.id)) { + ++it1; + continue; + } + + // Don't delete if some operation uses the output. + bool output_used = false; + for (auto it2 = chains->rbegin(); it2 != chains->rend(); ++it2) { + std::vector input_ids; + for (auto element : *it2) { + for (const auto& buffer : element->input_buffers) { + input_ids.push_back(buffer.id); + } + } + if (Contains(input_ids, it1->back()->output_buffer.id)) { + output_used = true; + break; + } + } + if (output_used) { + ++it1; + continue; + } + // Delete if not used. + unused_operations.push_back(it1->back()->id); + it1 = decltype(it1){chains->erase(std::next(it1).base())}; + } + return unused_operations; +} + +// Returns unused input buffer IDs. 
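The diagnostic helpers above reduce to set algebra over three collections: what the sorted chains can produce, what other chains consume, and what the caller requested. Requested-but-unproducible IDs become missing outputs; produced-but-unconsumed-and-unrequested IDs become extra outputs. A toy illustration with plain integer IDs (the values are made up for the example):

```cpp
#include <algorithm>
#include <iostream>
#include <set>
#include <vector>

int main() {
  // Everything the sorted chains can produce (outputs and intermediates).
  const std::set<int> producible = {10, 11, 12};
  // Buffers consumed by other chains.
  const std::set<int> consumed = {10};
  // Buffers the caller requested as graph outputs.
  const std::vector<int> requested = {12, 13};

  std::vector<int> missing, extra;
  for (int id : requested) {
    // 13 cannot be produced by any chain -> reported as missing.
    if (!producible.count(id)) missing.push_back(id);
  }
  for (int id : producible) {
    // 11 is produced but never consumed and never requested -> wasted work.
    if (!consumed.count(id) &&
        std::find(requested.begin(), requested.end(), id) == requested.end()) {
      extra.push_back(id);
    }
  }
  std::cout << "missing: " << missing.size() << ", extra: " << extra.size()
            << "\n";  // missing: 1, extra: 1
}
```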
+void RemoveInputProxies(std::list* chains) { + // Remove input proxy and sort items. + for (auto it = chains->begin(); it != chains->end();) { + auto& chain = *it; + // Remove input proxy-operations. + if (chain.front()->input_buffers.empty()) { + chain.erase(chain.begin()); + } + if (chain.empty()) { + // Input proxy operation has been deleted and the chain is empty due to + // unused input buffer. + it = chains->erase(it); + } else { + ++it; + } + } +} + +ComputeTaskDescriptorPtr NonLinkableStub(int operation_id, ValueId input_id, + ValueId output_id) { + auto desc = std::make_shared(); + desc->id = operation_id; + desc->is_linkable = false; + desc->shader_source = R"( + #include + using namespace metal; + $0 + kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + if (int(gid.x) >= size.x || int(gid.y) >= size.y) { + return; + } + const int linear_index = (gid.z * size.y + gid.y) * size.x + gid.x; + FLT4 value = input_buffer[linear_index]; + $2 + output_buffer[linear_index] = value; + } + )"; + + desc->input_buffers = { + {input_id, "device FLT4* const input_buffer"}, + }; + + desc->output_buffer = {output_id, "device FLT4* output_buffer", + [input_id](const std::map& buffers) { + return buffers.find(input_id)->second; + }}; + + desc->uniform_buffers = { + {"constant int2& size", + [input_id](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + return VectorToUint8Vector(std::vector{dimension.w, dimension.h}); + }}, + }; + + desc->resize_function = [input_id](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + uint3 groups_size{16, 16, 1}; + uint3 groups_count{AlignByN(dimension.w, groups_size.x), + AlignByN(dimension.h, groups_size.y), + AlignByN(dimension.c, 4)}; + return std::make_pair(groups_size, groups_count); + }; + + return {desc}; +} + +ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { + auto fused_desciptor = std::make_shared(); + // The id of fused descriptor is the id of the first descriptor in the list. + fused_desciptor->id = chain.front()->id; + FusionSequence sequence; + if (chain.front()->is_linkable) { + // The first task is linkable so it contains only linkable code. Insert + // unlinkable meta-task with remaining shader code. + sequence.push_back(NonLinkableStub(-1, chain.front()->input_buffers[0].id, + chain.front()->input_buffers[0].id)); + } + sequence.insert(sequence.end(), chain.begin(), chain.end()); + + // Count buffers to calculate proper indices then. + int num_outputs = 1; + int num_inputs = 0; + int num_immutables = 0; + bool invalid_id = true; + ValueId fused_id; + for (const auto& desc : sequence) { + for (const auto& buffer : desc->input_buffers) { + if (invalid_id || buffer.id != fused_id) { + num_inputs++; + } + } + fused_id = desc->output_buffer.id; + invalid_id = false; + num_immutables += desc->immutable_buffers.size(); + } + + int output_index = 0; + int input_index = num_outputs; + int immutable_index = num_outputs + num_inputs; + int uniform_index = num_outputs + num_inputs + num_immutables; + + int function_index = 0; + std::string function_code; + std::string buffer_declarations; + std::string call_code; + invalid_id = true; + for (const auto& desc : sequence) { + if (desc->is_linkable) { + function_code += + absl::Substitute(desc->shader_source, function_index) + "\n"; + } else { + // Declare output buffer only for the first unlinkable task. 
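For reference, the index bookkeeping above fixes the Metal buffer binding order of a fused task: the single output at index 0, then the external inputs, then the immutable (weights) buffers, then the uniform buffers. A small sketch of the same arithmetic, with assumed example counts:

```cpp
#include <cstdio>

// First Metal buffer index of each section in a fused task, mirroring the
// output/input/immutable/uniform ordering used during fusion.
struct BindingLayout {
  int output_index;     // always 0: the single fused output
  int input_index;      // external inputs follow the output
  int immutable_index;  // weights and other immutable data
  int uniform_index;    // uniforms come last
};

BindingLayout MakeLayout(int num_outputs, int num_inputs, int num_immutables) {
  return {0, num_outputs, num_outputs + num_inputs,
          num_outputs + num_inputs + num_immutables};
}

int main() {
  // Assumed example: 1 output, 2 external inputs, 1 weights buffer.
  const BindingLayout layout = MakeLayout(1, 2, 1);
  std::printf("output at %d, inputs from %d, immutables from %d, uniforms from %d\n",
              layout.output_index, layout.input_index, layout.immutable_index,
              layout.uniform_index);
}
```

This is the same order that `encodeWithEncoder` in compute_task.mm later reproduces by incrementing `bindIndex` over outputs, inputs, immutables, and uniforms.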
+ buffer_declarations += + desc->output_buffer.declaration + "[[buffer(0)]],\n"; + output_index++; + } + + std::string call_arguments; + for (const auto& buffer : desc->input_buffers) { + if (invalid_id || buffer.id != fused_id) { + std::string index = std::to_string(input_index); + std::string name = (desc->is_linkable ? (" buffer" + index) : ""); + buffer_declarations += + buffer.declaration + name + "[[buffer(" + index + ")]],\n"; + call_arguments += ", buffer" + index; + input_index++; + fused_desciptor->input_buffers.push_back({buffer.id, ""}); + } + } + // We have an output id that is the input for the next task. + fused_id = desc->output_buffer.id; + invalid_id = false; + + for (auto buffer : desc->immutable_buffers) { + std::string index = std::to_string(immutable_index); + std::string name = (desc->is_linkable ? (" buffer" + index) : ""); + buffer_declarations += + buffer.declaration + name + "[[buffer(" + index + ")]],\n"; + call_arguments += ", buffer" + index; + immutable_index++; + fused_desciptor->immutable_buffers.push_back(buffer); + } + + for (auto buffer : desc->uniform_buffers) { + std::string index = std::to_string(uniform_index); + std::string name = (desc->is_linkable ? (" buffer" + index) : ""); + buffer_declarations += + buffer.declaration + name + "[[buffer(" + index + ")]],\n"; + call_arguments += ", buffer" + index; + uniform_index++; + fused_desciptor->uniform_buffers.push_back({"", buffer.data_function}); + } + + if (desc->is_linkable) { + call_code += + absl::Substitute("value = linkable$0(value, linear_index, gid$1);\n", + function_index, call_arguments); + function_index++; + } + } + + ComputeTaskDescriptorPtr non_linkable = sequence.front(); + fused_desciptor->shader_source = + absl::Substitute(non_linkable->shader_source, function_code, + buffer_declarations, call_code); + std::vector alias; + alias.reserve(chain.size() - 1); + for (int i = 0; i < chain.size() - 1; i++) { + alias.push_back(chain[i]->output_buffer.id); + } + fused_desciptor->output_buffer = { + fused_id, "", non_linkable->output_buffer.dimensions_function, alias}; + fused_desciptor->resize_function = non_linkable->resize_function; + return fused_desciptor; +} + +} // namespace + +Status ValidateOptimizeModel(const std::vector& input_buffers, + const std::vector& output_buffers, + const CompiledModel& input_vector, + CompiledModel* output) { + std::list input; + input.insert(input.end(), input_vector.begin(), input_vector.end()); + OptimizationInfo info; + info.operations_count = static_cast(input.size()); + + // A chain is a sequence of fusable operations. All internal outputs are + // consumed with the next element of the chain. The last element of each chain + // contains outputs which are ready to be used as inputs. if a chain can't be + // extended with linkable element then new chain is created. 
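To make FuseChain's three-way substitution concrete: `$0` in the non-linkable stub receives the concatenated linkable function definitions, `$1` the accumulated buffer and uniform declarations, and `$2` the chained calls that thread `value` through every linkable operation. A minimal sketch with `absl::Substitute` and one hypothetical linkable function (the shader strings below are illustrative, not taken from a real operation):

```cpp
#include <iostream>
#include <string>

#include "absl/strings/substitute.h"

int main() {
  // Skeleton with the same placeholders as the non-linkable stub.
  const std::string stub = R"(
    $0
    kernel void ComputeFunction($1 uint3 gid[[thread_position_in_grid]]) {
      FLT4 value = input_buffer[linear_index];
      $2
      output_buffer[linear_index] = value;
    })";

  // One hypothetical linkable op, already substituted with function index 0.
  const std::string function_code =
      "FLT4 linkable0(FLT4 value, int linear_index, uint3 gid, "
      "device FLT4* const add_buf) { return value + add_buf[gid.z]; }\n";
  const std::string buffer_declarations =
      "device FLT4* output_buffer[[buffer(0)]],\n"
      "device FLT4* const input_buffer[[buffer(1)]],\n"
      "device FLT4* const buffer2[[buffer(2)]],\n";
  const std::string call_code =
      "value = linkable0(value, linear_index, gid, buffer2);\n";

  std::cout << absl::Substitute(stub, function_code, buffer_declarations,
                                call_code);
}
```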
+ std::list unsorted_chains; + BuildFusableChains(input_buffers, output_buffers, &input, &unsorted_chains, + &info.unused_operations); + + RemoveInputProxies(&unsorted_chains); + std::list sorted_chains = + SortChains(input_buffers, &unsorted_chains); + + info.extra_output_buffer_ids = + DeductOutputBufferIds(output_buffers, sorted_chains); + info.unused_operations = DeleteUnusedTasks(output_buffers, &sorted_chains); + info.input_buffer_ids = GetUsedInputBufferIds(sorted_chains); + // find provided input buffers that has not being used + for (ValueId id : input_buffers) { + if (!Contains(info.input_buffer_ids, id)) { + info.unused_input_buffer_ids.push_back(id); + } + } + info.missing_output_buffer_ids = + GetMissingOutputBufferIds(output_buffers, sorted_chains); + info.gpu_tasks_count = static_cast(sorted_chains.size()); + if (sorted_chains.empty()) { + return InternalError("Empty chains"); + } + for (const auto& chain : sorted_chains) output->push_back(FuseChain(chain)); + return OkStatus(); +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.h b/tensorflow/lite/delegates/gpu/metal/compiled_model.h new file mode 100644 index 00000000000..5f9982d0a66 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPILED_MODEL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPILED_MODEL_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +using CompiledModel = std::vector; + +// Receives input CompiledModel, validates, optimizes it and returns output +// CompiledModel. No shader compilation or memory allocation happen here, this +// function just does high-level operations fusion. +Status ValidateOptimizeModel(const std::vector& input_buffers, + const std::vector& output_buffers, + const CompiledModel& input, CompiledModel* output); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPILED_MODEL_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.h b/tensorflow/lite/delegates/gpu/metal/compute_task.h new file mode 100644 index 00000000000..611185b8fc1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.h @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_ + +#import + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +@interface TFLComputeTask : NSObject + +/// Returns empty string or error if shader can't be compiled. +- (::tflite::gpu::Status)compileWithDevice:(id)device + taskDescriptor:(::tflite::gpu::metal::ComputeTaskDescriptorPtr)desc + runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; + +/// Updates dimensions for inputs/outputs/intermediate tensors +- (::tflite::gpu::Status) + setInputDimensionsWithDevice:(id)device + dimensions:(std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions; + +/// Updates buffers for intermediate tensors only. Returns error if out of memory or a buffer is +/// larger than MTLDevice can support. +/// @param buffers is a map from intermediate tensors' ValueId to metal handles with corresponding +/// buffers. +/// @param outputIDs must match the output of added operations. +/// @param usageRecordIds is a map from intermediate tensors' ValueId to corresponding tensor usage +/// records ids. +/// @param sharedBufferIds contain shared buffer id for each tensor usage record id. +/// @param sharedBuffers contain metal handles to the allocated buffers for each shared buffer id. +/// TODO(ypisarchyk): probably we can decrease the number of parameters here +- (::tflite::gpu::Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers + outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds + usageRecordIds: + (const std::map<::tflite::gpu::ValueId, size_t>&)usageRecordIds + sharedBufferIds:(const std::vector&)sharedBufferIds + sharedBuffers:(const std::vector>&)sharedBuffers; + +- (void)encodeWithEncoder:(id)encoder + inputOutputBuffers: + (const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers; + +@end + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.mm b/tensorflow/lite/delegates/gpu/metal/compute_task.mm new file mode 100644 index 00000000000..2c2926d69b0 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.mm @@ -0,0 +1,255 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/compute_task.h" + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/common.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +using ::tflite::gpu::AlignByN; +using ::tflite::gpu::BHWC; +using ::tflite::gpu::InternalError; +using ::tflite::gpu::InvalidArgumentError; +using ::tflite::gpu::HalfBits; +using ::tflite::gpu::metal::ComputeTaskDescriptorPtr; +using ::tflite::gpu::metal::CreateComputeProgram; +using ::tflite::gpu::metal::DispatchParamsFunction; +using ::tflite::gpu::metal::OutputDimensions; +using ::tflite::gpu::metal::RuntimeOptions; +using ::tflite::gpu::metal::UniformsFunction; +using ::tflite::gpu::OkStatus; +using ::tflite::gpu::Status; +using ::tflite::gpu::uint3; +using ::tflite::gpu::ValueId; + +@implementation TFLComputeTask { + struct InputBuffer { + ValueId uid; + id metalHandle; + }; + struct OutputBuffer { + ValueId uid; + id metalHandle; + OutputDimensions dimensionsFunction; + std::vector alias; + }; + struct UniformBuffer { + std::vector data; + UniformsFunction dataFunction; + }; + + id _program; + std::vector _inputBuffers; + std::vector _outputBuffers; + std::vector> _immutableBuffers; + std::vector _uniformBuffers; + uint3 _groupsSize; + uint3 _groupsCount; + DispatchParamsFunction _resizeFunction; +} + +- (Status)compileWithDevice:(id)device + taskDescriptor:(ComputeTaskDescriptorPtr)desc + runtimeOptions:(const RuntimeOptions&)options { +#if (defined(__MAC_10_13) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_13) || \ + (defined(__IPHONE_10_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_10_0) || \ + (defined(__TVOS_10_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_10_0) + NSString* barrier = @"simdgroup_barrier"; +#else + NSString* barrier = @"threadgroup_barrier"; +#endif + NSString* storageType; + NSString* accumulatorType; + NSString* toAccumulatorType = @""; + NSString* toAccumulatorType2 = @""; + NSString* toAccumulatorType3 = @""; + NSString* toAccumulatorType4 = @""; + if (options.storage_precision == RuntimeOptions::Precision::FP32) { + storageType = @"float"; + accumulatorType = @"float"; + } else { + // FP16 + storageType = @"half"; + if (options.accumulator_precision == RuntimeOptions::Precision::FP32) { + accumulatorType = @"float"; + toAccumulatorType = @"float"; + toAccumulatorType2 = @"float2"; + toAccumulatorType3 = @"float3"; + toAccumulatorType4 = @"float4"; + } else { + accumulatorType = @"half"; + } + } + NSDictionary* macros = @{ + @"FLT" : storageType, + @"FLT2" : [NSString stringWithFormat:@"%@2", storageType], + @"FLT3" : [NSString stringWithFormat:@"%@3", storageType], + @"FLT4" : [NSString stringWithFormat:@"%@4", storageType], + @"ACCUM_FLT" : accumulatorType, + @"ACCUM_FLT2" : [NSString stringWithFormat:@"%@2", accumulatorType], + @"ACCUM_FLT3" : [NSString stringWithFormat:@"%@3", accumulatorType], + @"ACCUM_FLT4" : [NSString stringWithFormat:@"%@4", accumulatorType], + @"TO_ACCUM_TYPE" : toAccumulatorType, + @"TO_ACCUM2_TYPE" : toAccumulatorType2, + @"TO_ACCUM3_TYPE" : 
toAccumulatorType3, + @"TO_ACCUM4_TYPE" : toAccumulatorType4, + @"BARRIER" : barrier, + }; + + NSString* code = [NSString stringWithCString:desc->shader_source.c_str() + encoding:[NSString defaultCStringEncoding]]; + id program; + RETURN_IF_ERROR(CreateComputeProgram(device, code, @"ComputeFunction", macros, &program)); + if (!program) { + return InternalError("Unknown shader compilation error"); + } + for (auto& buffer : desc->input_buffers) { + _inputBuffers.emplace_back(InputBuffer{buffer.id, nil}); + } + for (auto& uniform : desc->uniform_buffers) { + _uniformBuffers.emplace_back(UniformBuffer{{}, uniform.data_function}); + } + _outputBuffers.emplace_back(OutputBuffer{desc->output_buffer.id, nil, + desc->output_buffer.dimensions_function, + desc->output_buffer.alias}); + for (auto& immutable : desc->immutable_buffers) { + int padding = + 4 * (options.storage_precision == RuntimeOptions::Precision::FP32 ? sizeof(float) + : sizeof(HalfBits)); + int paddedSize = AlignByN(immutable.data.size(), padding); + immutable.data.resize(paddedSize); + id metalBuffer = [device newBufferWithBytes:immutable.data.data() + length:immutable.data.size() + options:MTLResourceStorageModeShared]; + _immutableBuffers.emplace_back(metalBuffer); + } + _resizeFunction = desc->resize_function; + _program = program; + return OkStatus(); +} + +- (Status)setInputDimensionsWithDevice:(id)device + dimensions: + (std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)dimensions { + // Re-calculate output buffers dimensions + for (auto& buffer : _outputBuffers) { + auto outputDimensions = buffer.dimensionsFunction(*dimensions); + for (ValueId duplicate : buffer.alias) { + (*dimensions)[duplicate] = outputDimensions; + } + // Store buffer dimensions + (*dimensions)[buffer.uid] = outputDimensions; + } + + for (auto& uniform : _uniformBuffers) { + uniform.data = uniform.dataFunction(*dimensions); + } + + // Dispatch parameters re-calculation + auto workGroups = _resizeFunction(*dimensions); + _groupsSize = workGroups.first; + MTLSize threadsPerGroup = [device maxThreadsPerThreadgroup]; + if (_groupsSize.x > threadsPerGroup.width || _groupsSize.y > threadsPerGroup.height || + _groupsSize.z > threadsPerGroup.depth) { + std::string error("Threads per working group: "); + error += std::to_string(_groupsSize.x) + ", " + std::to_string(_groupsSize.y) + ", " + + std::to_string(_groupsSize.z); + error += "is larger than the MTLDevice can support: "; + error += std::to_string(threadsPerGroup.width) + ", " + std::to_string(threadsPerGroup.height) + + ", " + std::to_string(threadsPerGroup.depth); + return InvalidArgumentError(error); + } + _groupsCount = workGroups.second; + return OkStatus(); +} + +- (Status)assignBuffers:(std::map<::tflite::gpu::ValueId, id>*)buffers + outputIds:(const std::vector<::tflite::gpu::ValueId>&)outputIds + usageRecordIds:(const std::map&)usageRecordIds + sharedBufferIds:(const std::vector&)sharedBufferIds + sharedBuffers:(const std::vector>&)sharedBuffers { + for (auto& buffer : _outputBuffers) { + // If the buffer is intermediate: set its metalHandle from sharedBuffers + if (std::find(outputIds.begin(), outputIds.end(), buffer.uid) == outputIds.end()) { + auto usageRecordIt = usageRecordIds.find(buffer.uid); + if (usageRecordIt == usageRecordIds.end()) { + return InternalError("TensorUsageRecord for intermediate tensor is not found."); + } + buffer.metalHandle = sharedBuffers.at(sharedBufferIds.at(usageRecordIt->second)); + (*buffers)[buffer.uid] = buffer.metalHandle; + } + } + + // Re-assign input 
buffers + for (auto& buffer : _inputBuffers) { + buffer.metalHandle = (*buffers)[buffer.uid]; + } + return OkStatus(); +} + +- (void)encodeWithEncoder:(id)encoder + inputOutputBuffers:(const std::map>&)inputOutputBuffers { + // The dispatch call is intended to be skipped. + if (_groupsCount.x * _groupsCount.y * _groupsCount.z == 0) { + return; + } + + [encoder setComputePipelineState:_program]; + + int bindIndex = 0; + for (auto& buffer : _outputBuffers) { + const auto externalBuffer = inputOutputBuffers.find(buffer.uid); + if (externalBuffer == inputOutputBuffers.end()) { + [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex]; + } else { + // the buffer is input or output + [encoder setBuffer:externalBuffer->second offset:0 atIndex:bindIndex]; + } + bindIndex++; + } + for (auto& buffer : _inputBuffers) { + const auto externalBuffer = inputOutputBuffers.find(buffer.uid); + if (externalBuffer == inputOutputBuffers.end()) { + [encoder setBuffer:buffer.metalHandle offset:0 atIndex:bindIndex]; + } else { + // the buffer is input or output + [encoder setBuffer:externalBuffer->second offset:0 atIndex:bindIndex]; + } + bindIndex++; + } + for (auto& immutable : _immutableBuffers) { + [encoder setBuffer:immutable offset:0 atIndex:bindIndex]; + bindIndex++; + } + for (auto& uniform : _uniformBuffers) { + [encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex]; + bindIndex++; + } + + MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z); + MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z); + [encoder dispatchThreadgroups:groupsCount threadsPerThreadgroup:groupsSize]; +} + +@end diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc new file mode 100644 index 00000000000..84a15a705ba --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.cc @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +#include +#include + +#include +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace metal { + +/// Helper function to convert buffer's content into stream of bytes +std::vector VectorFloatToHalf(const std::vector& input_vector) { + std::vector result; + result.reserve(input_vector.size()); + for (const float v : input_vector) { + result.push_back(fp16_ieee_from_fp32_value(v)); + } + return VectorToUint8Vector(result); +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h new file mode 100644 index 00000000000..6d528db951e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h @@ -0,0 +1,133 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_DESCRIPTOR_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_DESCRIPTOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace metal { + +using OutputDimensions = + std::function& buffers)>; +using UniformsFunction = + std::function(const std::map& buffers)>; +using DispatchParamsFunction = std::function( + const std::map& buffers)>; + +// Compute task descriptor contains a linkable shader code or a code for +// complete shader to which other linkable can be attached or not. An operation +// can produce one or more descriptors and graph compiler uses descriptors as +// building blocks. All required data like immutable operation parameters +// (weights etc.) is attached to the descriptor. +struct ComputeTaskDescriptor { + struct InputBufferDescriptor { + ValueId id; + // The declaration is inserted into the compute function arguments list. + // Example for non-linkable task: "device FLT4* const input_buffer" + // Example for linkable: "device FLT4* const" + std::string declaration; + }; + struct OutputBufferDescriptor { + ValueId id; + // The declaration is inserted into the compute function arguments list. 
+ // Example for non-linkable task: "device FLT4* output_buffer" + // Example for linkable: "device FLT4*" + std::string declaration; + // Multiple outputs are allowed from a linkable operation so after fusion + // each buffer's dimensions are calculated separately from different + // operations. + OutputDimensions dimensions_function; + // Fusion absorbs intermediate tensors. Keep this ids to properly store + // output dimensions. + std::vector alias; + }; + struct ImmutableBufferDescriptor { + std::string declaration; + std::vector data; + }; + // Uniforms are recalculated at any setInputDimensions call. + struct UniformBufferDescriptor { + // The declaration is inserted into the compute function arguments list. + // Example: "constant uint4& some_uniforms" + std::string declaration; + // This function re-calculates uniforms for specific input dimensions. + UniformsFunction data_function; + }; + + // Unique ID to match the graph compilation errors. + int id; + bool is_linkable; + // A linkable function or a full shader source with 3 parameters $ for + // substitute function. Example of linkable: "(FLT4 linkable$0(FLT4 value, int + // linear_index) { return value; })" Example of non-linkable function: + // #include + // using namespace metal; + // $0 + // kernel void ComputeFunction( + // $1 + // uint3 gid[[thread_position_in_grid]]) { + // if (int(gid.x) >= size.x || int(gid.y) >= size.y) { + // return; + // } + // const int linear_index = (gid.z * size.y + gid.y) * size.x + gid.x; + // FLT4 value = input_buffer[linear_index] + 1.0f; + // $2 + // output_buffer[linear_index] = value; + // } + std::string shader_source; + std::vector input_buffers; + // A single per-operation output is supported now. + OutputBufferDescriptor output_buffer; + std::vector immutable_buffers; + std::vector uniform_buffers; + // Dynamic resizing of input tensor is supported. User-defined functions to + // calculate new parameters for GPU compute task dispatching. A leading + // unlinkable task must provide this. + DispatchParamsFunction resize_function; +}; + +using ComputeTaskDescriptorPtr = std::shared_ptr; + +/// Helper function to convert buffer's content into stream of bytes +template +std::vector VectorToUint8Vector(const std::vector& input_vector) { + std::vector result; + result.insert(result.begin(), + reinterpret_cast(input_vector.data()), + reinterpret_cast(input_vector.data()) + + input_vector.size() * sizeof(*input_vector.data())); + return result; +} + +/// Helper function to convert FP32 to FP16 and into stream of bytes +std::vector VectorFloatToHalf(const std::vector& input_vector); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_COMPUTE_TASK_DESCRIPTOR_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/environment.h b/tensorflow/lite/delegates/gpu/metal/environment.h new file mode 100644 index 00000000000..90da1f3403e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/environment.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_ + +namespace tflite { +namespace gpu { +namespace metal { + +// Returns runtime operation system version. Example 10.1 +float GetiOsSystemVersion(); + +// Returns Apple SoC generation number. The list of Apple SoC that support Metal +// API: +// 7 - A7 iPhone 5s, iPad Air, iPad Mini 2, iPad Mini 3. +// 8 - A8 iPhone 6, A8X iPad Air 2, iPad Mini 4. +// 9 - A9 iPhone 6s, iPad (2017), A9X iPad Pro (1st generation). +// 10 - A10 iPhone 7, iPad (2018), A10X iPad Pro (2nd generation). +// 11 - A11 iPhone 8/X. +// 12 - A12 iPhone Xs. +int GetAppleSocVersion(); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/environment.mm b/tensorflow/lite/delegates/gpu/metal/environment.mm new file mode 100644 index 00000000000..6a189c25b68 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/environment.mm @@ -0,0 +1,73 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/environment.h" + +#import +#import + +#include +#include + +#include "tensorflow/lite/delegates/gpu/metal/common.h" + +namespace tflite { +namespace gpu { +namespace metal { + +float GetiOsSystemVersion() { return [[[UIDevice currentDevice] systemVersion] floatValue]; } + +int GetAppleSocVersion() { + std::vector> features = { +#if defined(__IPHONE_8_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_8_0 + {MTLFeatureSet_iOS_GPUFamily1_v1, 7}, + {MTLFeatureSet_iOS_GPUFamily2_v1, 8}, +#endif +#if defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0 + {MTLFeatureSet_iOS_GPUFamily1_v2, 7}, + {MTLFeatureSet_iOS_GPUFamily2_v2, 8}, + {MTLFeatureSet_iOS_GPUFamily3_v1, 9}, +#endif +#if defined(__IPHONE_10_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_10_0 + {MTLFeatureSet_iOS_GPUFamily1_v3, 7}, + {MTLFeatureSet_iOS_GPUFamily2_v3, 8}, + {MTLFeatureSet_iOS_GPUFamily3_v2, 9}, +#endif +#if defined(__IPHONE_11_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_11_0 + {MTLFeatureSet_iOS_GPUFamily2_v4, 8}, + {MTLFeatureSet_iOS_GPUFamily3_v3, 9}, + {MTLFeatureSet_iOS_GPUFamily4_v1, 11}, +#endif +#if defined(__IPHONE_12_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_12_0 + {MTLFeatureSet_iOS_GPUFamily1_v5, 7}, + {MTLFeatureSet_iOS_GPUFamily2_v5, 8}, + {MTLFeatureSet_iOS_GPUFamily3_v4, 9}, + {MTLFeatureSet_iOS_GPUFamily4_v2, 11}, + {MTLFeatureSet_iOS_GPUFamily5_v1, 12}, +#endif + }; + id device = GetBestSupportedMetalDevice(); + int max_feature_set = 0; + for (auto &type : features) { + if ([device supportsFeatureSet:type.first]) { + max_feature_set = std::max(max_feature_set, type.second); + } + } + return max_feature_set; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.h b/tensorflow/lite/delegates/gpu/metal/inference_context.h new file mode 100644 index 00000000000..536b87f780c --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.h @@ -0,0 +1,87 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_ + +#import + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +/// Stages of model preprocessing: +/// 1. Operations' initialization. All operations are initialized and added into +/// model. Every operation is represented as a vector of +/// ComputeTaskDescriptors. +/// 2. Model compilation. 
Global list of ComputeTaskDescriptors is transformed +/// into the sorted list of sets of descriptors. A set can be transformed +/// later into a single GPU task. +/// 3. GPU compute tasks generation. Shader code generation happes here. +/// 4. Intermediate resource allocation. +/// Inference. +@interface TFLInferenceContext : NSObject + +/// Compiles model: groups operations to be fused; validates model structure. +/// @param device Used to create resources: shaders, buffers. Also the device is used in +/// consecutive call setInputDimensions(). +/// @param taskDescriptors The ordered vector of shader programs ready to be compiled for GPU and +/// with all supplementary buffers data. +/// @param outputBufferIDs IDs must match the output of added operations. +/// @param runtimeOptions Options are used to specify data/calculations precision. +/// @return Status signals whether model is compiled successfully or not. +/// @discussion Previously added operations are distilled into sorted list of sets of +/// ComputeTaskDescriptors, which can be fused into a single GPU task. +- (::tflite::gpu::Status) + compileModelWithDevice:(id)device + taskDescriptors: + (const std::vector<::tflite::gpu::metal::ComputeTaskDescriptorPtr>&)taskDescriptors + outputBufferIDs:(const std::vector<::tflite::gpu::ValueId>&)outputBufferIDs + runtimeOptions:(const ::tflite::gpu::metal::RuntimeOptions&)options; + +/// Creates intermediate buffers. The model is ready to be used after this call. +/// @param inputDimensions Used to create resources: shaders, buffers. +/// @param outputDimensions Will be initialized during this call. +/// @return Status signals whether intermediate buffers are successfully created or not. +/// @discussion The operation is intended to be lightweight with minimum overhead. A preceding call +/// compileModelWithDevice() must be made with the proper device parameter set. +- (::tflite::gpu::Status) + setInputDimensions:(const std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>&)inputDimensions + outputDimensions:(std::map<::tflite::gpu::ValueId, ::tflite::gpu::BHWC>*)outputDimensions + taskDescriptors: + (const std::vector<::tflite::gpu::metal::ComputeTaskDescriptorPtr>&)taskDescriptors; + +/// Inserts all GPU compute tasks into the command encoder. +/// @param inputOutputBuffers Must be created and passed into the method with pairs ID:buffer +/// @param encoderBlock User-defined block to take control over command encoder. Can be nil. +/// The block can be used, for example, for fine-graned benchmarking where end encoding +/// is performed and command buffer is committed with completion block. A new command +/// buffer must be created and new command encoder must be returned by the block. +/// The block is called after every dispatch encoding. +/// @discussion No GPU sychronization functions are used inside. All GPU resources must be created +/// with the same device which has been used in compileModelWithDevice() method. +- (void)encodeWithEncoder:(id)commandEncoder + inputOutputBuffers:(const std::map<::tflite::gpu::ValueId, id>&)inputOutputBuffers + encoderBlock:(id (^)(bool isLast))encoderBlock; + +@end + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/inference_context.mm b/tensorflow/lite/delegates/gpu/metal/inference_context.mm new file mode 100644 index 00000000000..720872ad8a6 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/inference_context.mm @@ -0,0 +1,171 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/inference_context.h" + +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/memory_management.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +using ::tflite::gpu::BHWC; +using ::tflite::gpu::metal::ComputeTaskDescriptorPtr; +using ::tflite::gpu::metal::RuntimeOptions; +using ::tflite::gpu::InternalError; +using ::tflite::gpu::OkStatus; +using ::tflite::gpu::Status; +using ::tflite::gpu::ValueId; +using ::tflite::gpu::AlignByN; +using ::tflite::gpu::HalfBits; +using ::tflite::gpu::MemoryStrategy; +using ::tflite::gpu::TensorUsageRecord; + +@implementation TFLInferenceContext { + std::vector _computeTasks; + std::vector _outputIds; + id _device; + RuntimeOptions _options; +} + +- (Status)compileModelWithDevice:(id)device + taskDescriptors:(const std::vector&)taskDescriptors + outputBufferIDs:(const std::vector&)requestedOutputBufferIDs + runtimeOptions:(const RuntimeOptions&)options { + _device = device; + _outputIds = requestedOutputBufferIDs; + _options = options; + // Metal resources are created here. + for (const auto& node : taskDescriptors) { + TFLComputeTask* task = [[TFLComputeTask alloc] init]; + RETURN_IF_ERROR([task compileWithDevice:_device taskDescriptor:node runtimeOptions:_options]); + _computeTasks.emplace_back(task); + } + return OkStatus(); +} + +- (Status)setInputDimensions:(const std::map&)inputDimensions + outputDimensions:(std::map*)outputDimensions + taskDescriptors:(const std::vector&)taskDescriptors { + // These maps contain all input/output/intermediate buffers shared across model. + std::map dimensions = inputDimensions; + std::map> buffers; + std::set preallocatedIds; + // Insert uninitialized input buffers. This buffers will be set externally. + for (auto dimension : dimensions) { + buffers[dimension.first] = nil; + preallocatedIds.insert(dimension.first); + } + for (const auto& outputId : _outputIds) { + preallocatedIds.insert(outputId); + } + for (auto& task : _computeTasks) { + // The same device must be used here as well as on shader compilation stage. 
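The intermediate-buffer allocation performed a few lines below is driven by per-tensor lifetimes: every intermediate tensor gets a usage record holding its size plus the first and last task that touch it, and tensors whose task intervals do not overlap may be backed by the same MTLBuffer. A deliberately naive sketch of that interval reasoning (this is not the GREEDY strategy from memory_management.h, just the underlying idea):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Simplified stand-in for a tensor usage record: byte size plus the index of
// the first and last GPU task that touches the tensor.
struct Usage {
  size_t size;
  int first_task;
  int last_task;
};

int main() {
  const std::vector<Usage> records = {
      {1024, 0, 1},  // intermediate A, alive for tasks 0..1
      {2048, 1, 2},  // intermediate B, overlaps A -> needs its own buffer
      {512, 2, 3},   // intermediate C, A is dead by task 2 -> can reuse A's buffer
  };
  std::vector<std::pair<size_t, int>> buffers;  // {size, last task using it}
  for (const auto& r : records) {
    bool reused = false;
    for (auto& b : buffers) {
      if (b.second < r.first_task) {  // previous user finished before we start
        b.first = std::max(b.first, r.size);
        b.second = r.last_task;
        reused = true;
        break;
      }
    }
    if (!reused) buffers.push_back({r.size, r.last_task});
  }
  std::cout << "allocations needed: " << buffers.size() << "\n";  // 2, not 3
}
```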
+ RETURN_IF_ERROR([task setInputDimensionsWithDevice:_device dimensions:&dimensions]); + } + for (auto id : _outputIds) { + (*outputDimensions)[id] = dimensions[id]; + } + + // TODO(ypisarchyk): it make sense to move it to separate function + // Generate usage records for each intermediate tensor in order of their first_task + std::vector usageRecords; + std::map usageRecordIds; + for (uint32_t i = 0; i < taskDescriptors.size(); ++i) { + auto outputId = taskDescriptors[i]->output_buffer.id; + if (!preallocatedIds.count(outputId)) { + if (!usageRecordIds.count(outputId)) { + const auto it = dimensions.find(outputId); + if (it == dimensions.end()) { + return InternalError("Dimensions for intermediate tensor not found."); + } + usageRecordIds[outputId] = usageRecords.size(); + usageRecords.emplace_back(it->second.w * it->second.h * AlignByN(it->second.c, 4), i, i); + } else { + usageRecords[usageRecordIds[outputId]].last_task = i; + } + } + for (auto& buffer : taskDescriptors[i]->input_buffers) { + if (!preallocatedIds.count(buffer.id)) { + usageRecords[usageRecordIds[buffer.id]].last_task = i; + } + } + } + + tflite::gpu::ObjectsAssignment assignment; + RETURN_IF_ERROR(AssignObjectsToTensors(usageRecords, MemoryStrategy::GREEDY, &assignment)); + auto objectsCount = assignment.object_sizes.size(); + std::vector> sharedBuffers(objectsCount); + size_t dataTypeSize = _options.storage_precision == RuntimeOptions::Precision::FP32 + ? sizeof(float) + : sizeof(HalfBits); + + // allocate buffers for each shared object + for (size_t i = 0; i < objectsCount; ++i) { + // Initialize metal buffer + NSUInteger bufferSize = dataTypeSize * assignment.object_sizes[i]; + +#if (defined(__MAC_10_14) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_14) || \ + (defined(__IPHONE_12_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_12_0) || \ + (defined(__TVOS_12_0) && __TV_OS_VERSION_MIN_REQUIRED >= __TVOS_12_0) + if (bufferSize > [_device maxBufferLength]) { + std::string error("Tensor id: "); + error += std::to_string(assignment.object_ids[i]) + + " with size: " + std::to_string(bufferSize) + + " exceeds MTLDevice maxBufferLength: " + std::to_string([_device maxBufferLength]); + return ::tflite::gpu::ResourceExhaustedError(error); + } +#endif +#if defined(__MAC_10_12) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_12 + if ([_device currentAllocatedSize] + bufferSize > [_device recommendedMaxWorkingSetSize]) { + std::string error("Out of memory in MTLBuffer allocation. 
Currently allocated: "); + error += std::to_string([_device currentAllocatedSize]); + return ::tflite::gpu::ResourceExhaustedError(error); + } +#endif + + sharedBuffers[i] = [_device newBufferWithLength:bufferSize + options:MTLResourceStorageModeShared]; + } + for (auto& task : _computeTasks) { + RETURN_IF_ERROR([task assignBuffers:&buffers + outputIds:_outputIds + usageRecordIds:usageRecordIds + sharedBufferIds:assignment.object_ids + sharedBuffers:sharedBuffers]); + } + return OkStatus(); +} + +- (void)encodeWithEncoder:(id)commandEncoder + inputOutputBuffers:(const std::map>&)inputOutputBuffers + encoderBlock:(id (^)(bool isLast))encoderBlock { + for (int i = 0; i < _computeTasks.size(); ++i) { + auto& task = _computeTasks[i]; + [task encodeWithEncoder:commandEncoder inputOutputBuffers:inputOutputBuffers]; + if (encoderBlock != nil) { + commandEncoder = encoderBlock(i == _computeTasks.size() - 1); + } + } +} + +@end diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD new file mode 100644 index 00000000000..b61324cabb4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD @@ -0,0 +1,304 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "kernels", + deps = [ + ":add", + ":concat", + ":conv", + ":depthwise_conv", + ":elementwise", + ":fully_connected", + ":max_unpooling", + ":mul", + ":padding", + ":pooling", + ":prelu", + ":relu", + ":reshape", + ":slice", + ":softmax", + ":transpose_conv", + ":upsample", + ], +) + +cc_library( + name = "add", + srcs = ["add.cc"], + hdrs = ["add.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "concat", + srcs = ["concat.cc"], + hdrs = ["concat.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + ], +) + +cc_library( + name = "conv", + srcs = ["conv.cc"], + hdrs = ["conv.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "depthwise_conv", + srcs = ["depthwise_conv.cc"], + hdrs = ["depthwise_conv.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + 
"@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "elementwise", + srcs = ["elementwise.cc"], + hdrs = ["elementwise.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:environment", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "fully_connected", + srcs = ["fully_connected.cc"], + hdrs = ["fully_connected.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:environment", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "max_unpooling", + srcs = ["max_unpooling.cc"], + hdrs = ["max_unpooling.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "mul", + srcs = ["mul.cc"], + hdrs = ["mul.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "padding", + srcs = ["padding.cc"], + hdrs = ["padding.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + ":util", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "prelu", + srcs = ["prelu.cc"], + hdrs = ["prelu.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + 
"//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "relu", + srcs = ["relu.cc"], + hdrs = ["relu.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "reshape", + srcs = ["reshape.cc"], + hdrs = ["reshape.h"], + deps = [ + ":util", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "slice", + srcs = ["slice.cc"], + hdrs = ["slice.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "softmax", + srcs = ["softmax.cc"], + hdrs = ["softmax.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "transpose_conv", + srcs = ["transpose_conv.cc"], + hdrs = ["transpose_conv.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "//tensorflow/lite/delegates/gpu/metal:environment", + "//tensorflow/lite/delegates/gpu/metal:runtime_options", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "upsample", + srcs = ["upsample.cc"], + hdrs = ["upsample.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:types", + ], +) diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add.cc b/tensorflow/lite/delegates/gpu/metal/kernels/add.cc new file mode 100644 index 00000000000..b7b6fe3312b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/add.cc @@ -0,0 +1,139 @@ +/* Copyright 2019 The 
TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/kernels/add.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/types/variant.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+namespace {
+
+std::string GetAddTableCode(int src_count) {
+  std::string code = R"(
+    #include <metal_stdlib>
+    using namespace metal;
+
+    struct uniforms {
+      int4 src_size;
+    };
+
+    $0
+    kernel void ComputeFunction(
+                                $1
+                                uint3 gid[[thread_position_in_grid]]) {
+      if (static_cast<int>(gid.x) >= params.src_size.x ||
+          static_cast<int>(gid.y) >= params.src_size.y) {
+        return;
+      }
+
+      FLT4 value = FLT4(0.0f);
+      int linear_index = (int(gid.z) * params.src_size.y + int(gid.y)) *
+        params.src_size.x + int(gid.x);
+  )";
+  for (int i = 0; i < src_count; ++i) {
+    code += "    value += src_buffer" + std::to_string(i) + "[linear_index];\n";
+  }
+  code += "    $2\n";
+  code += "    dst_buffer[linear_index] = value;\n";
+  code += "}\n";
+  return code;
+}
+}  // namespace
+
+std::vector<ComputeTaskDescriptorPtr> Add(int id, ValueId input_id,
+                                          ValueId output_id,
+                                          const AddAttributes& attr,
+                                          const RuntimeOptions& options) {
+  auto add_buffer =
+      absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
+  if (!add_buffer) {
+    return {};
+  }
+  auto desc = std::make_shared<ComputeTaskDescriptor>();
+  desc->id = id;
+  desc->is_linkable = true;
+  desc->shader_source =
+      R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid,
+                         device FLT4* const add_buf) {
+      return value + add_buf[gid.z];
+    })";
+  desc->input_buffers = {{input_id}};
+  desc->output_buffer = {output_id};
+  auto coeffs = options.storage_precision == RuntimeOptions::Precision::FP32
+                    ? VectorToUint8Vector(add_buffer->data)
+                    : VectorFloatToHalf(add_buffer->data);
+  desc->immutable_buffers = {
+      {"device FLT4* const", coeffs},
+  };
+  return {desc};
+}
+
+std::vector<ComputeTaskDescriptorPtr> AddTable(int id,
+                                               std::vector<ValueId> input_ids,
+                                               ValueId output_id) {
+  auto desc = std::make_shared<ComputeTaskDescriptor>();
+  desc->id = id;
+  desc->is_linkable = false;
+  desc->shader_source = GetAddTableCode(input_ids.size());
+
+  for (int i = 0; i < input_ids.size(); ++i) {
+    const std::string buffer_name =
+        "device FLT4* const src_buffer" + std::to_string(i);
+    desc->input_buffers.push_back({input_ids[i], buffer_name});
+  }
+
+  desc->output_buffer = {output_id, "device FLT4* dst_buffer",
+                         [input_ids](const std::map<ValueId, BHWC>& buffers) {
+                           return buffers.find(input_ids[0])->second;
+                         }};
+
+  desc->uniform_buffers = {
+      {"constant uniforms& params",
+       [input_ids](const std::map<ValueId, BHWC>& buffers) {
+         const auto& dimension = buffers.find(input_ids[0])->second;
+         std::vector<int> uniform_params = {dimension.w, dimension.h, 0, 0};
+         return VectorToUint8Vector(uniform_params);
+       }},
+  };
+
+  desc->resize_function = [input_ids](const std::map<ValueId, BHWC>& buffers) {
+    const auto& src_dim = buffers.find(input_ids[0])->second;
+    const uint3 groups_size{16, 16, 1};
+    int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x);
+    int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y);
+    const int dst_layers = IntegralDivideRoundUp(src_dim.c, 4);
+    int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z);
+    return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
+  };
+  return {desc};
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/add.h b/tensorflow/lite/delegates/gpu/metal/kernels/add.h
new file mode 100644
index 00000000000..6a826e6b100
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/add.h
@@ -0,0 +1,45 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_ADD_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_ADD_H_
+
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+// Add with broadcast.
+std::vector<ComputeTaskDescriptorPtr> Add(int id, ValueId input_id,
+                                          ValueId output_id,
+                                          const AddAttributes& attr,
+                                          const RuntimeOptions& options);
+
+// Add tensors.
+std::vector<ComputeTaskDescriptorPtr> AddTable(int id,
+                                               std::vector<ValueId> input_ids,
+                                               ValueId output_id);
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_ADD_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/concat.cc b/tensorflow/lite/delegates/gpu/metal/kernels/concat.cc
new file mode 100644
index 00000000000..7d82979629c
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/concat.cc
@@ -0,0 +1,363 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/kernels/concat.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+namespace {
+
+std::string GetConcatZCode(const std::vector<int> channels) {
+  const std::string postfix[] = {".x", ".y", ".z", ".w"};
+  const std::string postfix_2[] = {".x", ".xy", ".xyz", ""};
+  const std::string types[] = {"FLT", "FLT2", "FLT3", "FLT4"};
+  std::string code = R"(
+    #include <metal_stdlib>
+    using namespace metal;
+    struct uniforms {
+      int4 src_size;
+    };
+
+    $0
+    kernel void ComputeFunction(
+                                $1
+                                uint2 ugid[[thread_position_in_grid]]) {
+      if (static_cast<int>(ugid.x) >= params.src_size.x ||
+          static_cast<int>(ugid.y) >= params.src_size.y) {
+        return;
+      }
+
+      FLT4 value = FLT4(0.0f);
+      const int xy_offset = int(ugid.y) * params.src_size.x + int(ugid.x);
+      int linear_index = xy_offset;
+  )";
+
+  int out_channel = 0;
+  int read_index = 0;
+  int dst_z = 0;
+  for (int i = 0; i < channels.size(); ++i) {
+    const int depth = IntegralDivideRoundUp(channels[i], 4);
+    code += "  {\n";
+    code += "    int src_address = xy_offset;\n";
+    for (int d = 0; d < depth; ++d) {
+      const int channels_in_group = std::min(4, channels[i] - d * 4);
+      const std::string temp_name = "t" + std::to_string(read_index);
+      code += "    " + types[channels_in_group - 1] + " " + temp_name + " = " +
+              "src_buffer" + std::to_string(i) + "[src_address]" +
+              postfix_2[channels_in_group - 1] + ";\n";
+      code += "    src_address += params.src_size.w;\n";
+      for (int c = 0; c < channels_in_group; ++c) {
+        if (channels_in_group == 1) {
+          code += "    value" + postfix[out_channel] + " = " + temp_name + ";\n";
+        } else {
+          code += "    value" + postfix[out_channel] + " = " + temp_name +
+                  postfix[c] += ";\n";
+        }
+        out_channel++;
+        if (out_channel == 4) {
+          out_channel = 0;
+          code += "    {\n";
+          code += "      uint3 gid = uint3(ugid.x, ugid.y, " +
+                  std::to_string(dst_z) + ");\n";
+          code += "      $2\n";
+          code += "      dst_buffer[linear_index] = value;\n";
+          code += "      linear_index += params.src_size.w;\n";
+          code
+= " }\n"; + dst_z++; + } + } + read_index++; + } + code += " }\n"; + } + if (out_channel != 0) { + code += " {\n"; + code += " uint3 gid = uint3(ugid.x, ugid.y, " + std::to_string(dst_z) + + ");\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += " }\n"; + } + code += "}\n"; + return code; +} +} // namespace + +std::vector ConcatZ( + int id, std::vector input_ids, ValueId output_id, + const ConcatAttributes& attr, const std::vector& input_shapes) { + std::vector channels; + channels.reserve(input_shapes.size()); + for (const auto& shape : input_shapes) { + channels.push_back(shape.c); + } + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetConcatZCode(channels); + + for (int i = 0; i < input_ids.size(); ++i) { + const std::string buffer_name = + "device FLT4* const src_buffer" + std::to_string(i); + desc->input_buffers.push_back({input_ids[i], buffer_name}); + } + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_ids, attr](const std::map& buffers) { + std::vector src_shapes(input_ids.size()); + for (int i = 0; i < input_ids.size(); ++i) { + src_shapes[i] = buffers.find(input_ids[i])->second; + } + BHWC dst_shape; + CalculateOutputShape(src_shapes, attr, &dst_shape).IgnoreError(); + return dst_shape; + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_ids](const std::map& buffers) { + const auto& dimension = buffers.find(input_ids[0])->second; + std::vector uniform_params{ + dimension.w, + dimension.h, + 0, + dimension.w * dimension.h, + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [input_ids](const std::map& buffers) { + const auto& src_dim = buffers.find(input_ids[0])->second; + const uint3 groups_size{16, 16, 1}; + int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y); + int groups_z = 1; + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +std::vector ConcatX( + int id, std::vector input_ids, ValueId output_id, + const ConcatAttributes& attr, const std::vector& input_shapes) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + + std::string code = R"( + #include + using namespace metal; + $0 + kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + if (int(gid.x) >= size.x || int(gid.y) >= size.y) { + return; + } + FLT4 value; + )"; + int output_width = 0; + for (int buffer_index = 0; buffer_index < input_shapes.size(); + buffer_index++) { + const auto& dims = input_shapes[buffer_index]; + output_width += dims.w; + + // Generated shader example: + // if (gid.x < 10) value = src_buffer0[(gid.y + gid.z * 3) * 4 + gid.x - 3]; + // else + if (buffer_index < input_shapes.size() - 1) { + code += "if (gid.x < " + std::to_string(output_width) + ")"; + } + code += "value = src_buffer" + std::to_string(buffer_index) + + "[(gid.y + gid.z * " + std::to_string(dims.h) + ") * " + + std::to_string(dims.w) + " + gid.x - " + + std::to_string(output_width - dims.w) + "];\n"; + if (buffer_index < input_shapes.size() - 1) { + code += "else "; + } + } + code += "const int linear_index = (gid.y + gid.z * " + + std::to_string(input_shapes[0].h) + ") * " + + std::to_string(output_width) + " + gid.x;"; + code += R"( + $2 + dst_buffer[linear_index] = value; + } + )"; + desc->shader_source = code; + + for (int i = 0; i < input_ids.size(); ++i) { + const 
std::string buffer_name = + "device FLT4* const src_buffer" + std::to_string(i); + desc->input_buffers.push_back({input_ids[i], buffer_name}); + } + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_ids, attr](const std::map& buffers) { + std::vector src_shapes(input_ids.size()); + for (int i = 0; i < input_ids.size(); ++i) { + src_shapes[i] = buffers.find(input_ids[i])->second; + } + BHWC dst_shape; + CalculateOutputShape(src_shapes, attr, &dst_shape).IgnoreError(); + return dst_shape; + }}; + + desc->uniform_buffers = { + {"constant int3& size", + [output_id](const std::map& buffers) { + const auto& dimension = buffers.find(output_id)->second; + std::vector uniform_params{dimension.w, dimension.h, + IntegralDivideRoundUp(dimension.c, 4), + /*padding=*/0}; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + const auto& output_dims = buffers.find(output_id)->second; + const uint3 groups_size{1, 1, 1}; + int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y); + int groups_z = IntegralDivideRoundUp(output_dims.c, 4); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +std::vector ConcatY( + int id, std::vector input_ids, ValueId output_id, + const ConcatAttributes& attr, const std::vector& input_shapes) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + + std::string code = R"( + #include + using namespace metal; + $0 + kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + if (int(gid.x) >= size.x || int(gid.y) >= size.y) { + return; + } + FLT4 value; + )"; + int output_height = 0; + for (int buffer_index = 0; buffer_index < input_shapes.size(); + buffer_index++) { + const auto& dims = input_shapes[buffer_index]; + output_height += dims.h; + + // Generated shader example: + // if (gid.y < 10) value = src_buffer0[(gid.y - 3 + gid.z * 5) * 4 + gid.x]; + // else + if (buffer_index < input_shapes.size() - 1) { + code += "if (gid.y < " + std::to_string(output_height) + ")"; + } + code += "value = src_buffer" + std::to_string(buffer_index) + "[(gid.y - " + + std::to_string(output_height - dims.h) + " + gid.z * " + + std::to_string(dims.h) + ") * " + std::to_string(dims.w) + + " + gid.x];\n"; + if (buffer_index < input_shapes.size() - 1) { + code += "else "; + } + } + const auto& dims = input_shapes[0]; + code += "const int linear_index = (gid.y + gid.z * " + + std::to_string(output_height) + ") * " + std::to_string(dims.w) + + " + gid.x;"; + code += R"( + $2 + dst_buffer[linear_index] = value; + } + )"; + desc->shader_source = code; + + for (int i = 0; i < input_ids.size(); ++i) { + const std::string buffer_name = + "device FLT4* const src_buffer" + std::to_string(i); + desc->input_buffers.push_back({input_ids[i], buffer_name}); + } + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_ids, attr](const std::map& buffers) { + std::vector src_shapes(input_ids.size()); + for (int i = 0; i < input_ids.size(); ++i) { + src_shapes[i] = buffers.find(input_ids[i])->second; + } + BHWC dst_shape; + CalculateOutputShape(src_shapes, attr, &dst_shape).IgnoreError(); + return dst_shape; + }}; + + desc->uniform_buffers = { + {"constant int3& size", + [output_id](const std::map& buffers) { + const auto& dimension = buffers.find(output_id)->second; + std::vector uniform_params{dimension.w, dimension.h, + 
IntegralDivideRoundUp(dimension.c, 4)}; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + const auto& output_dims = buffers.find(output_id)->second; + const uint3 groups_size{1, 1, 1}; + int groups_x = IntegralDivideRoundUp(output_dims.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(output_dims.h, groups_size.y); + int groups_z = IntegralDivideRoundUp(output_dims.c, 4); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +std::vector Concat( + int id, std::vector input_ids, ValueId output_id, + const ConcatAttributes& attr, const std::vector& input_shapes) { + if (attr.axis == Axis::CHANNELS) { + return ConcatZ(id, input_ids, output_id, attr, input_shapes); + } else if (attr.axis == Axis::WIDTH) { + return ConcatX(id, input_ids, output_id, attr, input_shapes); + } else { + return ConcatY(id, input_ids, output_id, attr, input_shapes); + } +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/concat.h b/tensorflow/lite/delegates/gpu/metal/kernels/concat.h new file mode 100644 index 00000000000..9fec8a3bf6e --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/concat.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONCAT_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONCAT_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector Concat( + int id, std::vector input_ids, ValueId output_id, + const ConcatAttributes& attr, const std::vector& input_shapes); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONCAT_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc new file mode 100644 index 00000000000..19be9d4902a --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc @@ -0,0 +1,1216 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +int GetNumOutputSlices(int dst_channels) { + const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); + if (dst_depth % 4 == 0) { + return 4; + } else if (dst_depth % 2 == 0) { + return 2; + } else { + return 1; + } +} + +int GetSrcBatchSize(int dst_channels) { + const int dst_depth = IntegralDivideRoundUp(dst_channels, 4); + if (dst_depth % 4 == 0) { + return 2; + } else if (dst_depth % 2 == 0) { + return 4; + } else { + return 8; + } +} + +std::string GetValuesDeclarationPart(int num_output_slices, bool is_1x1) { + std::string code; + for (int d = 0; d < num_output_slices; ++d) { + code += absl::Substitute(R"( + float4 sum$0 = float4(0.0f, 0.0f, 0.0f, 0.0f); + )", + d); + } + if (is_1x1) { + code += absl::Substitute(R"( + threadgroup FLT4 temp[32]; + device FLT4* f_offseted = weights + (gid.z + params.z_offset.x) * $0 * src_offset; + )", + num_output_slices * 4); + } else { + code += absl::Substitute(R"( + threadgroup FLT4 temp[32]; + device FLT4* f_offseted = weights + (gid.z + params.z_offset.x) * $0 * src_offset * + kernel_y * kernel_x; + )", + num_output_slices * 4); + } + return code; +} + +std::string GetLocalMemoryUploadPart() { + std::string code = R"( + BARRIER(mem_flags::mem_none); + temp[tid] = f_offseted[tid]; + f_offseted += 32; + BARRIER(mem_flags::mem_threadgroup); + )"; + return code; +} + +std::string GetSummationPart(int num_output_slices, int index) { + std::string code = R"( + { + const FLT4 src = src_buffer[src_address]; + src_address += params.dillation_layer_offsets.z; + )"; + for (int d = 0; d < num_output_slices; ++d) { + code += absl::Substitute(R"( + sum$6.x += dot(temp[$0 * $1 + $2], src) * multiplier; + sum$6.y += dot(temp[$0 * $1 + $3], src) * multiplier; + sum$6.z += dot(temp[$0 * $1 + $4], src) * multiplier; + sum$6.w += dot(temp[$0 * $1 + $5], src) * multiplier; + )", + index, num_output_slices * 4, d * 4 + 0, d * 4 + 1, + d * 4 + 2, d * 4 + 3, d); + } + code += "}"; + return code; +} + +std::string GetBiasReadingPart(int num_output_slices) { + std::string code = absl::Substitute(R"( + { + gid.z = (gid.z + params.z_offset.x) * $0; + BARRIER(mem_flags::mem_none); + if (tid < $0) { + temp[tid] = biases[gid.z + tid]; + } + BARRIER(mem_flags::mem_threadgroup); + if (outside) { + return; + } + })", + num_output_slices); + return code; +} + +std::string GetWritingPart(int num_output_slices) { + std::string code; + for (int d = 0; d < num_output_slices; ++d) { + code += absl::Substitute(R"( + { + int dst_address = int(gid.y) * params.size.z + int(gid.x); + FLT4 value = FLT4(sum$0) + temp[$0]; + const int linear_index = gid.z * params.dillation_layer_offsets.w + dst_address; + $$2 + dst_buffer[linear_index + 
params.z_offset.y] = value; + gid.z += 1; + })", + d); + } + return code; +} + +std::string GetKernelForConv(const Convolution2DAttributes& params) { + const int num_output_slices = GetNumOutputSlices(params.weights.shape.o); + std::string code; + code.reserve(16 * 1024); // Reserve large enough buffer. + const bool is_1x1 = + params.weights.shape.w == 1 && params.weights.shape.h == 1; + const bool is_strided = params.strides.w > 1 || params.strides.h > 1; + const int src_group_size = GetSrcBatchSize(params.weights.shape.o); + + const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); + const int src_groups = src_depth / src_group_size; + const int src_depth_aligned = AlignByN(src_depth, src_group_size); + const int reminder_src_depth = src_depth - src_groups * src_group_size; + + code = absl::Substitute(R"( + #include + using namespace metal; + constant int src_depth_groups = $0; + constant int src_offset = $1; + constant int kernel_x = $2; + constant int kernel_y = $3; + struct uniforms { + int4 stride_padding; + int4 dillation_layer_offsets; + int4 size; + int4 z_offset; + }; + $$0 + kernel void ComputeFunction( + $$1 + uint tid[[thread_index_in_threadgroup]], + uint3 gid[[thread_position_in_grid]]) + { + const bool outside = static_cast(gid.x) >= params.size.z || + static_cast(gid.y) >= params.size.w; + )", + src_groups, src_depth_aligned, params.weights.shape.w, + params.weights.shape.h); + code += GetValuesDeclarationPart(num_output_slices, is_1x1); + + if (!is_1x1) { + code += R"( + for(int ky = 0; ky < kernel_y; ++ky) { + for(int kx = 0; kx < kernel_x; ++kx) { + int2 coords = int2(gid.xy) * params.stride_padding.xy + int2(kx, ky) * + params.dillation_layer_offsets.xy - params.stride_padding.zw; + const bool el_outside = coords.x < 0 || coords.y < 0 || coords.x >= params.size.x || + coords.y >= params.size.y; + const FLT multiplier = el_outside ? 
0.0f : 1.0f; + )"; + } else { + code += "const FLT multiplier = 1.0f;\n"; + code += "int2 coords = int2(gid.xy)"; + if (is_strided) { + code += " * params.stride_padding.xy"; + } + code += ";\n"; + } + code += R"( + coords = clamp(coords, int2(0, 0), int2(params.size.x - 1, params.size.y - 1)); + int src_address = coords.y * params.size.x + coords.x; + for(int s = 0; s < src_depth_groups; ++s) { + )"; + code += GetLocalMemoryUploadPart(); + for (int sub_s = 0; sub_s < src_group_size; ++sub_s) { + code += GetSummationPart(num_output_slices, sub_s); + } + code += R"( + } + )"; + if (reminder_src_depth != 0) { + code += GetLocalMemoryUploadPart(); + for (int sub_s = 0; sub_s < reminder_src_depth; ++sub_s) { + code += GetSummationPart(num_output_slices, sub_s); + } + } + if (!is_1x1) { + code += R"( + } + } + )"; + } + code += GetBiasReadingPart(num_output_slices); + code += GetWritingPart(num_output_slices); + code += " }"; + return code; +} + +// Reorder weights to make the weights memory access pattern cache friendly for +// GPU +std::vector ReorderWeightsForConvShared( + const Convolution2DAttributes& params) { + const int dst_batch_size = GetNumOutputSlices(params.weights.shape.o) * 4; + const int src_batch_size = GetSrcBatchSize(params.weights.shape.o); + BHWC input_dimensions{params.weights.shape.o, params.weights.shape.h, + params.weights.shape.w, params.weights.shape.i}; + const int gpu_simd_size = dst_batch_size * src_batch_size; + const int weights_width = AlignByN(input_dimensions.c, gpu_simd_size); + const int weights_height = AlignByN(input_dimensions.b, dst_batch_size); + const int weights_channels = params.weights.shape.w * params.weights.shape.h; + const int weights_aligned_size = + weights_width * weights_height * weights_channels; + std::vector weights_reordered(weights_aligned_size); + float* destination = weights_reordered.data(); + const int dst_groups = + IntegralDivideRoundUp(input_dimensions.b, dst_batch_size); + const int src_sub_groups = + IntegralDivideRoundUp(input_dimensions.c, 4 * src_batch_size); + for (int group = 0; group < dst_groups; ++group) { + for (int y = 0; y < params.weights.shape.h; ++y) { + for (int x = 0; x < params.weights.shape.w; ++x) { + for (int sub_group = 0; sub_group < src_sub_groups; ++sub_group) { + for (int s = 0; s < src_batch_size; ++s) { + for (int d = 0; d < dst_batch_size; ++d) { + int output_index = group * dst_batch_size + d; + for (int i = 0; i < 4; ++i) { + int input_index = (sub_group * src_batch_size + s) * 4 + i; + if (input_index >= input_dimensions.c || + output_index >= input_dimensions.b) { + // Padding with zero + *destination++ = 0.0f; + } else { + int linear_index = + input_index + + input_dimensions.c * + (x + input_dimensions.w * + (y + input_dimensions.h * output_index)); + *destination++ = params.weights.data[linear_index]; + } + } + } + } + } + } + } + } + return weights_reordered; +} + +std::vector GetUniformBufferForConvShared( + const BHWC& input_dimensions, const BHWC& output_dimensions, + const Convolution2DAttributes& params) { + std::vector uniform_params = { + params.strides.w, + params.strides.h, + params.padding.prepended.w, + params.padding.prepended.h, + params.dilations.w, + params.dilations.h, + input_dimensions.w * input_dimensions.h, + output_dimensions.w * output_dimensions.h, + input_dimensions.w, + input_dimensions.h, + output_dimensions.w, + output_dimensions.h, + // TODO(chirkov): use z_offset for concat table optimization + /*z_offset.x=*/0, + /*z_offset.y=*/0, + /*z_offset.z=*/0, + 
/*z_offset.w=*/0, + }; + return VectorToUint8Vector(uniform_params); +} + +std::string GetKernelForConv1x1(const Convolution2DAttributes& params, + int z_out) { + std::string code; + code.reserve(16 * 1024); // Reserve large enough buffer. + std::string channels[4] = {"x", "y", "z", "w"}; + code += R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int4 stride_padding; + int4 kernel_dilation; + uint4 work_group_size; +}; +$0 + +kernel void ComputeFunction( + $1 + uint3 group_id[[threadgroup_position_in_grid]], + uint3 tid3d[[thread_position_in_threadgroup]]) +{ + int gid_x = group_id.y * params.work_group_size.x + tid3d.x; + int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) << 1u; + )"; + code += " int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " + + std::to_string(z_out) + "u;\n"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } + code += R"( + device FLT4* tmp = filters + gid_z * 4 * params.src_size.w; + + int y0 = clamp(gid_y, 0, params.src_size.y - 1); + int y1 = clamp(gid_y + 1, 0, params.src_size.y - 1); + int x0 = clamp(gid_x, 0, params.src_size.x - 1); + + int s = 0; + + device FLT4* src_loc_0 = src_buffer + y0 * params.src_size.x + x0; + device FLT4* src_loc_1 = src_buffer + y1 * params.src_size.x + x0; + do { + FLT4 src_0 = *src_loc_0; + FLT4 src_1 = *src_loc_1; + src_loc_0 += params.src_size.z; + src_loc_1 += params.src_size.z; + )"; + for (int i = 0; i < z_out * 4; ++i) { + const std::string s_i = std::to_string(i); + code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + + " += dot(tmp[" + s_i + "], src_0);\n"; + code += " l" + std::to_string(i / 4) + "." 
+ channels[i % 4] + + " += dot(tmp[" + s_i + "], src_1);\n"; + } + + code += " tmp += " + std::to_string(z_out * 4) + ";\n"; + code += R"( + s += 1; + } while (s < params.src_size.w); + const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x; + const int offset_1 = offset_0 + params.dst_size.x; + bool y0_in = gid_y < params.dst_size.y; + bool y1_in = gid_y + 1 < params.dst_size.y; + + device FLT4* bias_loc = biases + gid_z; + )"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; + code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; + } + code += R"( + if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) { + return; + } + )"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; + code += " if (y0_in) {\n"; + code += " FLT4 value = FLT4(r" + s_i + ");\n"; + code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + + ";\n"; + code += " uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += " }\n"; + code += " if (y1_in) {\n"; + code += " FLT4 value = FLT4(l" + s_i + ");\n"; + code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + + ";\n"; + code += " uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += " }\n"; + code += " }\n"; + } + code += " }\n"; + return code; +} + +std::string GetKernelForConvGeneric(const Convolution2DAttributes& params, + int z_out) { + std::string code; + code.reserve(16 * 1024); // Reserve large enough buffer. 
+ std::string channels[4] = {"x", "y", "z", "w"}; + code += R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int4 stride_padding; + int4 kernel_dilation; + uint4 work_group_size; +}; +$0 + +kernel void ComputeFunction( + $1 + uint3 group_id[[threadgroup_position_in_grid]], + uint3 tid3d[[thread_position_in_threadgroup]]) +{ + int gid_x = group_id.y * params.work_group_size.x + tid3d.x; + int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) * 2; + )"; + code += " int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " + + std::to_string(z_out) + "u;\n"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } + code += R"( + device FLT4* tmp = filters + gid_z * 4 * params.src_size.w * params.kernel_dilation.x * params.kernel_dilation.y; + + int y0 = gid_y * params.stride_padding.y + params.stride_padding.w; + int y1 = (gid_y + 1) * params.stride_padding.y + params.stride_padding.w; + int x0 = gid_x * params.stride_padding.x + params.stride_padding.z; + + int y = 0; + do { + int coord_y0 = y * params.kernel_dilation.w + y0; + int coord_y1 = y * params.kernel_dilation.w + y1; + bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y; + bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y; + coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1); + coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1); + int x = 0; + do { + int coord_x0 = x * params.kernel_dilation.z + x0; + bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x; + coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1); + FLT m0 = !(y0_out || x0_out); + FLT m1 = !(y1_out || x0_out); + int s = 0; + device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0; + device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x0; + do { + FLT4 src_0 = *src_loc_0 * m0; + FLT4 src_1 = *src_loc_1 * m1; + src_loc_0 += params.src_size.z; + src_loc_1 += params.src_size.z; + )"; + for (int i = 0; i < z_out * 4; ++i) { + const std::string s_i = std::to_string(i); + code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + + " += dot(tmp[" + s_i + "], src_0);\n"; + code += " l" + std::to_string(i / 4) + "." 
+ channels[i % 4] + + " += dot(tmp[" + s_i + "], src_1);\n"; + } + + code += " tmp += " + std::to_string(z_out * 4) + ";\n"; + code += R"( + s += 1; + } while (s < params.src_size.w); + x++; + } while (x < params.kernel_dilation.x); + y++; + } while (y < params.kernel_dilation.y); + const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x; + const int offset_1 = offset_0 + params.dst_size.x; + bool p0_in = gid_x < params.dst_size.x && gid_y < params.dst_size.y; + bool p1_in = gid_x < params.dst_size.x && gid_y + 1 < params.dst_size.y; + + device FLT4* bias_loc = biases + gid_z; + )"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; + code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; + } + code += R"( + if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) { + return; + } + )"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; + code += " if (p0_in) {\n"; + code += " FLT4 value = FLT4(r" + s_i + ");\n"; + code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + + ";\n"; + code += " uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += " }\n"; + code += " if (p1_in) {\n"; + code += " FLT4 value = FLT4(l" + s_i + ");\n"; + code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + + ";\n"; + code += " uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += " }\n"; + code += " }\n"; + } + code += " }\n"; + return code; +} + +std::string GetKernelForConvPrecise(int z_out) { + std::string channels[4] = {"x", "y", "z", "w"}; + std::string code; + code.reserve(16 * 1024); // Reserve large enough buffer. 
+ code += R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int4 stride_padding; + int4 kernel_dilation; + int4 slices; +}; +$0 + +kernel void ComputeFunction( + $1 + uint3 ugid[[thread_position_in_grid]]) +{ + int linear_id = ugid.x; + int gid_z = linear_id / params.slices.y; + int linear_xy = (linear_id - gid_z * params.slices.y) << 1; + )"; + code += " gid_z *= " + std::to_string(z_out) + ";\n"; + code += R"( + int gid_y0 = linear_xy / params.slices.x; + int gid_x0 = linear_xy - gid_y0 * params.slices.x; + linear_xy += 1; + int gid_y1 = linear_xy / params.slices.x; + int gid_x1 = linear_xy - gid_y1 * params.slices.x; + + if (gid_z >= params.dst_size.w) return; + )"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } + code += R"( + device FLT4* tmp = filters + gid_z * 4 * params.src_size.w * + params.kernel_dilation.x * params.kernel_dilation.y; + + int y0 = gid_y0 * params.stride_padding.y + params.stride_padding.w; + int y1 = gid_y1 * params.stride_padding.y + params.stride_padding.w; + int x0 = gid_x0 * params.stride_padding.x + params.stride_padding.z; + int x1 = gid_x1 * params.stride_padding.x + params.stride_padding.z; +)"; + code += R"( + int y = 0; + do { + int coord_y0 = y * params.kernel_dilation.w + y0; + int coord_y1 = y * params.kernel_dilation.w + y1; + bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y; + bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y; + coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1); + coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1); + int x = 0; + do { + int coord_x0 = x * params.kernel_dilation.z + x0; + int coord_x1 = x * params.kernel_dilation.z + x1; + bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x; + bool x1_out = coord_x1 < 0 || coord_x1 >= params.src_size.x; + coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1); + coord_x1 = clamp(coord_x1, 0, params.src_size.x - 1); + FLT m0 = !(y0_out || x0_out); + FLT m1 = !(y1_out || x1_out); + device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0; + device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x1; + int s = 0; + do { + FLT4 src_0 = *src_loc_0 * m0; + FLT4 src_1 = *src_loc_1 * m1; + src_loc_0 += params.src_size.z; + src_loc_1 += params.src_size.z; +)"; + for (int i = 0; i < z_out * 4; ++i) { + const std::string s_i = std::to_string(i); + code += " r" + std::to_string(i / 4) + "." + channels[i % 4] + + " += dot(tmp[" + s_i + "], src_0);\n"; + code += " l" + std::to_string(i / 4) + "." 
+ channels[i % 4] + + " += dot(tmp[" + s_i + "], src_1);\n"; + } + + code += " tmp += " + std::to_string(z_out * 4) + ";\n"; + code += R"( + s += 1; + } while (s < params.src_size.w); + x++; + } while (x < params.kernel_dilation.x); + y++; + } while (y < params.kernel_dilation.y); + const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0; + const int offset_1 = gid_z * params.dst_size.z + gid_y1 * params.dst_size.x + gid_x1; + bool p0_in = gid_x0 < params.dst_size.x && gid_y0 < params.dst_size.y; + bool p1_in = gid_x1 < params.dst_size.x && gid_y1 < params.dst_size.y; + + device FLT4* bias_loc = biases + gid_z; + )"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; + code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n"; + } + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; + code += " if (p0_in) {\n"; + code += " FLT4 value = FLT4(r" + s_i + ");\n"; + code += " int linear_index = offset_0 + params.dst_size.z * " + s_i + + ";\n"; + code += " uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += " }\n"; + code += " if (p1_in) {\n"; + code += " FLT4 value = FLT4(l" + s_i + ");\n"; + code += " int linear_index = offset_1 + params.dst_size.z * " + s_i + + ";\n"; + code += " uint3 gid = uint3(gid_x1, gid_y1, gid_z + " + s_i + ");\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += " }\n"; + code += " }\n"; + } + code += " }\n"; + return code; +} + +std::string GetKernelForConvPrecise1x1PowerVR(int z_out) { + std::string channels[4] = {"x", "y", "z", "w"}; + std::string code; + code.reserve(16 * 1024); // Reserve large enough buffer. + code += R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int4 slices; + int4 dummy0; +}; +$0 + +kernel void ComputeFunction( + $1 + uint3 ugid[[thread_position_in_grid]]) +{ + int linear_id = ugid.x; + int gid_z = linear_id / params.slices.y; + int linear_xy = linear_id - gid_z * params.slices.y; +)"; + code += " gid_z *= " + std::to_string(z_out) + ";\n"; + code += R"( + int gid_y0 = linear_xy / params.slices.x; + int gid_x0 = linear_xy - gid_y0 * params.slices.x; + + if (gid_z >= params.dst_size.w) return; +)"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " float4 r" + s_i + " = float4(0.0f, 0.0f, 0.0f, 0.0f);\n"; + } + code += R"( + device FLT4* tmp = filters + gid_z * 4 * params.src_size.w; + + device FLT4* src_loc_0 = src_buffer + gid_y0 * params.src_size.x + gid_x0; + int s = 0; + do { + FLT4 src_0 = *src_loc_0; + src_loc_0 += params.src_size.z; +)"; + for (int i = 0; i < z_out * 4; ++i) { + const std::string s_i = std::to_string(i); + code += " r" + std::to_string(i / 4) + "." 
+ channels[i % 4] + + " += dot(tmp[" + s_i + "], src_0);\n"; + } + + code += " tmp += " + std::to_string(z_out * 4) + ";\n"; + code += R"( + s += 1; + } while (s < params.src_size.w); + const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0; + + device FLT4* bias_loc = biases + gid_z; + )"; + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " r" + s_i + " += float4(bias_loc[" + s_i + "]);\n"; + } + for (int i = 0; i < z_out; ++i) { + const std::string s_i = std::to_string(i); + code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n"; + code += " FLT4 value = FLT4(r" + s_i + ");\n"; + code += + " int linear_index = offset_0 + params.dst_size.z * " + s_i + ";\n"; + code += " uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += " }\n"; + } + code += " }\n"; + return code; +} + +// Reorder weights to make the weights memory access pattern cache friendly for +// Convolution1x1/ConvolutionGeneric +std::vector ReorderWeightsForConv(const Convolution2DAttributes& params, + int z_out) { + const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4); + const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); + std::vector weights_reordered(params.weights.shape.w * + params.weights.shape.h * dst_depth * 4 * + src_depth * 4); + int counter = 0; + for (int d = 0; d < IntegralDivideRoundUp(dst_depth, z_out); ++d) { + for (int y = 0; y < params.weights.shape.h; ++y) { + for (int x = 0; x < params.weights.shape.w; ++x) { + for (int s = 0; s < src_depth; ++s) { + for (int k = 0; k < z_out; ++k) { + for (int j = 0; j < 4; ++j) { + for (int i = 0; i < 4; ++i) { + int src_ch = s * 4 + i; + int dst_ch = (d * z_out + k) * 4 + j; + if (src_ch >= params.weights.shape.i || + dst_ch >= params.weights.shape.o) { + weights_reordered[counter++] = 0.0f; + } else { + const int f_index = + params.weights.shape.LinearIndex({dst_ch, y, x, src_ch}); + weights_reordered[counter++] = params.weights.data[f_index]; + } + } + } + } + } + } + } + } + return weights_reordered; +} + +uint3 GetWorkGroupForConv() { return {8, 4, 1}; } +uint3 GetWorkGroupForConvPrecise() { return {32, 1, 1}; } + +std::vector GetUniformBufferForConv( + const BHWC& src_size, const BHWC& dst_size, + const Convolution2DAttributes& params) { + const int3 group_size = GetWorkGroupForConv(); + std::vector uniform_params = { + src_size.w, + src_size.h, + src_size.w * src_size.h, + IntegralDivideRoundUp(src_size.c, 4), + dst_size.w, + dst_size.h, + dst_size.w * dst_size.h, + IntegralDivideRoundUp(dst_size.c, 4), + params.strides.w, + params.strides.h, + -params.padding.prepended.w, + -params.padding.prepended.h, + params.weights.shape.w, + params.weights.shape.h, + params.dilations.w, + params.dilations.h, + group_size.x, + group_size.y, + group_size.z, + 1u, // dummy, for alignment + }; + return VectorToUint8Vector(uniform_params); +} + +std::vector GetUniformBufferForConvPrecise( + const BHWC& src_size, const BHWC& dst_size, + const Convolution2DAttributes& params) { + std::vector uniform_params = { + src_size.w, + src_size.h, + src_size.w * src_size.h, + IntegralDivideRoundUp(src_size.c, 4), + dst_size.w, + dst_size.h, + dst_size.w * dst_size.h, + IntegralDivideRoundUp(dst_size.c, 4), + params.strides.w, + params.strides.h, + -params.padding.prepended.w, + -params.padding.prepended.h, + params.weights.shape.w, + params.weights.shape.h, + params.dilations.w, + 
params.dilations.h, + dst_size.w, + IntegralDivideRoundUp(dst_size.w * dst_size.h, 2), + 0u, // dummy, for alignment + 0u, // dummy, for alignment + }; + return VectorToUint8Vector(uniform_params); +} + +std::vector GetUniformBufferForConvPrecise1x1( + const BHWC& src_size, const BHWC& dst_size, + const Convolution2DAttributes& params) { + std::vector uniform_params = { + src_size.w, + src_size.h, + src_size.w * src_size.h, + IntegralDivideRoundUp(src_size.c, 4), + dst_size.w, + dst_size.h, + dst_size.w * dst_size.h, + IntegralDivideRoundUp(dst_size.c, 4), + dst_size.w, + IntegralDivideRoundUp(dst_size.w * dst_size.h, 1), + 0u, // dummy, for alignment + 0u, // dummy, for alignment + 0u, // dummy, for alignment + 0u, // dummy, for alignment + 0u, // dummy, for alignment + 0u, // dummy, for alignment + }; + return VectorToUint8Vector(uniform_params); +} + +uint3 GetGroupsCountForConv(const uint3& group_size, const BHWC& dst_shape) { + const int dst_depth = IntegralDivideRoundUp(dst_shape.c, 4); + int groups_x = IntegralDivideRoundUp(dst_shape.w, group_size.x); + int groups_y = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_shape.h, 2), + group_size.y); + const int z_out = GetNumOutputSlices(dst_shape.c); + int groups_z = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_depth, z_out), + group_size.z); + return {groups_x, groups_y, groups_z}; +} + +uint3 GetGroupsCountForConvPrecise(const uint3& group_size, + const BHWC& dst_shape, int xy_pixels) { + const int z_out = GetNumOutputSlices(dst_shape.c); + const int dst_depth = IntegralDivideRoundUp(dst_shape.c, 4); + int xy_size = IntegralDivideRoundUp(dst_shape.w * dst_shape.h, xy_pixels); + int z_size = IntegralDivideRoundUp(dst_depth, z_out); + int task_size = xy_size * z_size; + return {IntegralDivideRoundUp(task_size, group_size.x), 1, 1}; +} + +int GetConvolutionThreadsCount(const BHWC& dst_shape) { + const uint3 group_size = GetWorkGroupForConv(); + const uint3 groups_count = GetGroupsCountForConv(group_size, dst_shape); + return groups_count.x * groups_count.y * groups_count.z * group_size.x * + group_size.y * group_size.z; +} + +int GetConvolutionPreciseThreadsCount(const BHWC& dst_shape, int xy_pixels) { + const uint3 group_size = GetWorkGroupForConvPrecise(); + const uint3 groups_count = + GetGroupsCountForConvPrecise(group_size, dst_shape, xy_pixels); + return groups_count.x * groups_count.y * groups_count.z * group_size.x * + group_size.y * group_size.z; +} + +bool IsConv1x1(const Convolution2DAttributes& attr) { + return attr.weights.shape.h == 1 && attr.weights.shape.w == 1 && + attr.strides.h == 1 && attr.strides.w == 1 && attr.dilations.h == 1 && + attr.dilations.w == 1 && attr.padding.prepended.h == 0 && + attr.padding.prepended.w == 0 && attr.padding.appended.h == 0 && + attr.padding.appended.w == 0; +} + +} // namespace + +std::vector Convolution( + int id, ValueId input_id, ValueId output_id, + const Convolution2DAttributes& params, const RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetKernelForConv(params); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, params](const std::map& buffers) { + return CalculateOutputShape(buffers.find(input_id)->second, params); + }}; + + auto weights_reordered = ReorderWeightsForConvShared(params); + auto weights = options.storage_precision == RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(weights_reordered) + : VectorFloatToHalf(weights_reordered); + auto biases = options.storage_precision == RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(params.bias.data) + : VectorFloatToHalf(params.bias.data); + desc->immutable_buffers = { + {"device FLT4* const weights", weights}, + {"device FLT4* const biases", biases}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& input_dimensions = buffers.find(input_id)->second; + const auto& output_dimensions = buffers.find(output_id)->second; + return GetUniformBufferForConvShared(input_dimensions, + output_dimensions, params); + }}, + }; + + desc->resize_function = [output_id, + params](const std::map& buffers) { + const auto& output_dims = buffers.find(output_id)->second; + const int num_output_slices = GetNumOutputSlices(params.weights.shape.o); + const uint3 group_size{8, 4, 1}; + int groups_x = IntegralDivideRoundUp(output_dims.w, group_size.x); + int groups_y = IntegralDivideRoundUp(output_dims.h, group_size.y); + const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4); + int groups_z = IntegralDivideRoundUp(dst_depth, num_output_slices); + return std::make_pair(group_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +std::vector Convolution1x1( + int id, ValueId input_id, ValueId output_id, + const Convolution2DAttributes& params, + const metal::RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + const int z_out = GetNumOutputSlices(params.weights.shape.o); + desc->shader_source = GetKernelForConv1x1(params, z_out); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, params](const std::map& buffers) { + auto out_shape = + CalculateOutputShape(buffers.find(input_id)->second, params); + return out_shape; + }}; + + auto weights_reordered = ReorderWeightsForConv(params, z_out); + auto weights = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(weights_reordered) + : VectorFloatToHalf(weights_reordered); + auto biases = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(params.bias.data) + : VectorFloatToHalf(params.bias.data); + desc->immutable_buffers = { + {"device FLT4* const filters", weights}, + {"device FLT4* const biases", biases}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& input_dimensions = buffers.find(input_id)->second; + const auto& output_dimensions = buffers.find(output_id)->second; + return GetUniformBufferForConv(input_dimensions, output_dimensions, + params); + }}, + }; + + desc->resize_function = [output_id, + params](const std::map& buffers) { + const auto& output_dims = buffers.find(output_id)->second; + const uint3 group_size = GetWorkGroupForConv(); + const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims); + return std::make_pair( + group_size, uint3{groups_count.z, groups_count.x, groups_count.y}); + }; + + return {desc}; +} + +bool CheckConvolution1x1Support(const Convolution2DAttributes& attr) { + return IsConv1x1(attr); +} + +std::vector ConvolutionGeneric( + int id, ValueId input_id, ValueId output_id, + const Convolution2DAttributes& params, + const metal::RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + const int z_out = GetNumOutputSlices(params.weights.shape.o); + desc->shader_source = GetKernelForConvGeneric(params, z_out); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, params](const std::map& buffers) { + auto out_shape = + CalculateOutputShape(buffers.find(input_id)->second, params); + return out_shape; + }}; + + auto weights_reordered = ReorderWeightsForConv(params, z_out); + auto weights = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(weights_reordered) + : VectorFloatToHalf(weights_reordered); + auto biases = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(params.bias.data) + : VectorFloatToHalf(params.bias.data); + desc->immutable_buffers = { + {"device FLT4* const filters", weights}, + {"device FLT4* const biases", biases}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& input_dimensions = buffers.find(input_id)->second; + const auto& output_dimensions = buffers.find(output_id)->second; + return GetUniformBufferForConv(input_dimensions, output_dimensions, + params); + }}, + }; + + desc->resize_function = [output_id, + params](const std::map& buffers) { + const auto& output_dims = buffers.find(output_id)->second; + const uint3 group_size = GetWorkGroupForConv(); + const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims); + return std::make_pair( + group_size, uint3{groups_count.z, groups_count.x, groups_count.y}); + }; + + return {desc}; +} + +std::vector ConvolutionPrecise( + int id, ValueId input_id, ValueId output_id, + const Convolution2DAttributes& params, + const metal::RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + const int z_out = GetNumOutputSlices(params.weights.shape.o); + desc->shader_source = GetKernelForConvPrecise(z_out); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, params](const std::map& buffers) { + auto out_shape = + CalculateOutputShape(buffers.find(input_id)->second, params); + return out_shape; + }}; + + auto weights_reordered = ReorderWeightsForConv(params, z_out); + auto weights = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(weights_reordered) + : VectorFloatToHalf(weights_reordered); + auto biases = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(params.bias.data) + : VectorFloatToHalf(params.bias.data); + desc->immutable_buffers = { + {"device FLT4* const filters", weights}, + {"device FLT4* const biases", biases}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& input_dimensions = buffers.find(input_id)->second; + const auto& output_dimensions = buffers.find(output_id)->second; + return GetUniformBufferForConvPrecise(input_dimensions, + output_dimensions, params); + }}, + }; + + desc->resize_function = [output_id, + params](const std::map& buffers) { + const auto& output_dims = buffers.find(output_id)->second; + const uint3 group_size = GetWorkGroupForConvPrecise(); + const uint3 groups_count = + GetGroupsCountForConvPrecise(group_size, output_dims, 2); + return std::make_pair(group_size, groups_count); + }; + + return {desc}; +} + +float GetThreadsRatioUsualToPreciseConvolution(const BHWC& dst_shape) { + return static_cast(GetConvolutionThreadsCount(dst_shape)) / + static_cast(GetConvolutionPreciseThreadsCount(dst_shape, 2)); +} + +std::vector ConvolutionPrecise1x1PowerVR( + int id, ValueId input_id, ValueId output_id, + const Convolution2DAttributes& params, const RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + const int z_out = GetNumOutputSlices(params.weights.shape.o); + desc->shader_source = GetKernelForConvPrecise1x1PowerVR(z_out); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, params](const std::map& buffers) { + auto out_shape = + CalculateOutputShape(buffers.find(input_id)->second, params); + return out_shape; + }}; + + auto weights_reordered = ReorderWeightsForConv(params, z_out); + auto weights = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(weights_reordered) + : VectorFloatToHalf(weights_reordered); + auto biases = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(params.bias.data) + : VectorFloatToHalf(params.bias.data); + desc->immutable_buffers = { + {"device FLT4* const filters", weights}, + {"device FLT4* const biases", biases}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& input_dimensions = buffers.find(input_id)->second; + const auto& output_dimensions = buffers.find(output_id)->second; + return GetUniformBufferForConvPrecise1x1(input_dimensions, + output_dimensions, params); + }}, + }; + + desc->resize_function = [output_id, + params](const std::map& buffers) { + const auto& output_dims = buffers.find(output_id)->second; + const uint3 group_size = GetWorkGroupForConvPrecise(); + const uint3 groups_count = + GetGroupsCountForConvPrecise(group_size, output_dims, 1); + return std::make_pair(group_size, groups_count); + }; + + return {desc}; +} + +bool CheckConvolutionPrecise1x1Support(const Convolution2DAttributes& attr) { + return IsConv1x1(attr); +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h new file mode 100644 index 00000000000..692145678cb --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/conv.h @@ -0,0 +1,95 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONV_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONV_H_
+
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+std::vector<ComputeTaskDescriptorPtr> Convolution(
+    int id, ValueId input_id, ValueId output_id,
+    const Convolution2DAttributes& params,
+    const metal::RuntimeOptions& options);
+
+// Convolution for a 1x1 kernel.
+// Requires:
+//   kernel_size = 1x1;
+//   prepended and appended padding = 0x0;
+//   dilation = 1x1;
+//   stride = 1x1.
+// Works very well on A12 (iPhone XS, etc.).
+// Works well on A9/A10/A11 (iPhone 6S, iPhone 7, iPhone X, etc.).
+// Works poorly on A7/A8 (iPhone 5S, iPhone 6, etc.).
+std::vector<ComputeTaskDescriptorPtr> Convolution1x1(
+    int id, ValueId input_id, ValueId output_id,
+    const Convolution2DAttributes& params, const RuntimeOptions& options);
+
+// TODO(impjdi): Move it inside module.
+bool CheckConvolution1x1Support(const Convolution2DAttributes& attr);
+
+// This convolution passes all convolution parameters (besides output_channels)
+// to the kernel as dynamic arguments in a uniform buffer.
+// Depending on output_channels, different kernels can be generated: a kernel
+// can process 4/8/12/16 output channels per thread.
+// A 16-channel output is the fastest but the least flexible.
+std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
+    int id, ValueId input_id, ValueId output_id,
+    const Convolution2DAttributes& params, const RuntimeOptions& options);
+
+// This convolution maps threads onto output elements more precisely.
+// For example, for a 12x7 output tensor and an 8x4 work group, the usual
+// mapping needs 4 work groups to cover the tensor. But there are only
+// 84 elements (12*7), which 3 work groups of size 32 can cover, and this
+// version of the convolution uses that precise mapping.
+// Due to hardware limitations it is not always faster; in general it works
+// well on A12.
+// Each thread processes 2 pixels in the XY dimensions and a variable number
+// of pixels in the Z dimension (depending on dst_channels).
+std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise(
+    int id, ValueId input_id, ValueId output_id,
+    const Convolution2DAttributes& params, const RuntimeOptions& options);
+
+// As above, but specialized for 1x1 kernels; each thread processes 1 pixel in
+// the XY dimensions.
+// This convolution targets PowerVR in FP16 mode with an FP32 accumulator.
+// It also works in other modes, but with lower performance.
+std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise1x1PowerVR(
+    int id, ValueId input_id, ValueId output_id,
+    const Convolution2DAttributes& params, const RuntimeOptions& options);
+
+// TODO(impjdi): Move it inside module.
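+
+// Worked example of the thread-count arithmetic behind the precise variants
+// (an illustration expanding the comments above, not code from this header):
+// for a 12x7 output slice with an 8x4 work group, the usual mapping launches
+// ceil(12/8) * ceil(7/4) = 4 work groups, i.e. 128 threads, while the precise
+// mapping launches ceil(84/32) = 3 work groups of 32, i.e. 96 threads, so the
+// usual-to-precise ratio for that slice is roughly 128/96 ~= 1.33 (ignoring
+// the Z dimension and per-thread pixel packing). A rough selection sketch,
+// where the 1.2f threshold and the caller-side names `tasks`, `params`, and
+// `dst_shape` (a BHWC output shape) are illustrative only:
+//
+//   if (GetThreadsRatioUsualToPreciseConvolution(dst_shape) > 1.2f) {
+//     tasks = ConvolutionPrecise(id, input_id, output_id, params, options);
+//   } else {
+//     tasks = ConvolutionGeneric(id, input_id, output_id, params, options);
+//   }
+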
+bool CheckConvolutionPrecise1x1Support(const Convolution2DAttributes& attr); + +// This function calculates amount of threads that should be launched for +// ConvolutionGeneric or Convolution1x1 (threads_count1) and amount of threads +// that should be launched for ConvolutionPrecise (threads_count2) and returns +// threads_count1 / threads_count2. +float GetThreadsRatioUsualToPreciseConvolution(const BHWC& dst_shape); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_CONV_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc new file mode 100644 index 00000000000..15b46541562 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc @@ -0,0 +1,756 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h" + +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +std::string GetKernelDepthWiseConv3x3Stride1x1() { + std::string code = R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int2 padding; + int2 dummy0; // for alignment + int4 dummy1; // for alignment +}; +$0 + +kernel void ComputeFunction( + $1 + uint3 ugid[[thread_position_in_grid]]) +{ + int gid_x = ugid.x * 2; + int gid_y = ugid.y * 2; + int gid_z = ugid.z; + + if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) { + return; + } + + ACCUM_FLT4 r0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f); + ACCUM_FLT4 l0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f); + ACCUM_FLT4 t0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f); + ACCUM_FLT4 b0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f); + + int x0 = gid_x + params.padding.x; + int x1 = gid_x + params.padding.x + 1; + int x2 = gid_x + params.padding.x + 2; + int x3 = gid_x + params.padding.x + 3; + int y0 = gid_y + params.padding.y; + int y1 = gid_y + params.padding.y + 1; + int y2 = gid_y + params.padding.y + 2; + int y3 = gid_y + params.padding.y + 3; + + bool x0_out = x0 < 0 || x0 >= params.src_size.x; + bool x1_out = x1 < 0 || x1 >= params.src_size.x; + bool x2_out = x2 < 0 || x2 >= params.src_size.x; + bool x3_out = x3 < 0 || x3 >= params.src_size.x; + bool y0_out = y0 < 0 || y0 >= params.src_size.y; + bool y1_out = y1 < 0 || y1 >= params.src_size.y; + 
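+  // The *_out flags and the clamps below implement branch-free boundary
+  // handling: every tap coordinate is clamped into the source tensor, and the
+  // value loaded from the clamped address is multiplied by
+  // FLT(!(x_out || y_out)), so out-of-range taps contribute zero without
+  // divergent branches.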
bool y2_out = y2 < 0 || y2 >= params.src_size.y; + bool y3_out = y3 < 0 || y3 >= params.src_size.y; + + x0 = clamp(x0, 0, params.src_size.x - 1); + x1 = clamp(x1, 0, params.src_size.x - 1); + x2 = clamp(x2, 0, params.src_size.x - 1); + x3 = clamp(x3, 0, params.src_size.x - 1); + y0 = clamp(y0, 0, params.src_size.y - 1); + y1 = clamp(y1, 0, params.src_size.y - 1); + y2 = clamp(y2, 0, params.src_size.y - 1); + y3 = clamp(y3, 0, params.src_size.y - 1); + + device FLT4* src_loc = src_buffer + gid_z * params.src_size.z; + device FLT4* filters_loc = filters + gid_z * 10; + + FLT4 s0 = src_loc[y0 * params.src_size.x + x0] * FLT(!(x0_out || y0_out)); + FLT4 s1 = src_loc[y1 * params.src_size.x + x0] * FLT(!(x0_out || y1_out)); + FLT4 s2 = src_loc[y2 * params.src_size.x + x0] * FLT(!(x0_out || y2_out)); + FLT4 s3 = src_loc[y3 * params.src_size.x + x0] * FLT(!(x0_out || y3_out)); + + r0 += TO_ACCUM4_TYPE(s0 * filters_loc[0]); + r0 += TO_ACCUM4_TYPE(s1 * filters_loc[1]); + r0 += TO_ACCUM4_TYPE(s2 * filters_loc[2]); + l0 += TO_ACCUM4_TYPE(s1 * filters_loc[0]); + l0 += TO_ACCUM4_TYPE(s2 * filters_loc[1]); + l0 += TO_ACCUM4_TYPE(s3 * filters_loc[2]); + + s0 = src_loc[y0 * params.src_size.x + x1] * FLT(!(x1_out || y0_out)); + s1 = src_loc[y1 * params.src_size.x + x1] * FLT(!(x1_out || y1_out)); + s2 = src_loc[y2 * params.src_size.x + x1] * FLT(!(x1_out || y2_out)); + s3 = src_loc[y3 * params.src_size.x + x1] * FLT(!(x1_out || y3_out)); + + r0 += TO_ACCUM4_TYPE(s0 * filters_loc[3]); + r0 += TO_ACCUM4_TYPE(s1 * filters_loc[4]); + r0 += TO_ACCUM4_TYPE(s2 * filters_loc[5]); + l0 += TO_ACCUM4_TYPE(s1 * filters_loc[3]); + l0 += TO_ACCUM4_TYPE(s2 * filters_loc[4]); + l0 += TO_ACCUM4_TYPE(s3 * filters_loc[5]); + t0 += TO_ACCUM4_TYPE(s0 * filters_loc[0]); + t0 += TO_ACCUM4_TYPE(s1 * filters_loc[1]); + t0 += TO_ACCUM4_TYPE(s2 * filters_loc[2]); + b0 += TO_ACCUM4_TYPE(s1 * filters_loc[0]); + b0 += TO_ACCUM4_TYPE(s2 * filters_loc[1]); + b0 += TO_ACCUM4_TYPE(s3 * filters_loc[2]); + + s0 = src_loc[y0 * params.src_size.x + x2] * FLT(!(x2_out || y0_out)); + s1 = src_loc[y1 * params.src_size.x + x2] * FLT(!(x2_out || y1_out)); + s2 = src_loc[y2 * params.src_size.x + x2] * FLT(!(x2_out || y2_out)); + s3 = src_loc[y3 * params.src_size.x + x2] * FLT(!(x2_out || y3_out)); + + r0 += TO_ACCUM4_TYPE(s0 * filters_loc[6]); + r0 += TO_ACCUM4_TYPE(s1 * filters_loc[7]); + r0 += TO_ACCUM4_TYPE(s2 * filters_loc[8]); + l0 += TO_ACCUM4_TYPE(s1 * filters_loc[6]); + l0 += TO_ACCUM4_TYPE(s2 * filters_loc[7]); + l0 += TO_ACCUM4_TYPE(s3 * filters_loc[8]); + t0 += TO_ACCUM4_TYPE(s0 * filters_loc[3]); + t0 += TO_ACCUM4_TYPE(s1 * filters_loc[4]); + t0 += TO_ACCUM4_TYPE(s2 * filters_loc[5]); + b0 += TO_ACCUM4_TYPE(s1 * filters_loc[3]); + b0 += TO_ACCUM4_TYPE(s2 * filters_loc[4]); + b0 += TO_ACCUM4_TYPE(s3 * filters_loc[5]); + + s0 = src_loc[y0 * params.src_size.x + x3] * FLT(!(x3_out || y0_out)); + s1 = src_loc[y1 * params.src_size.x + x3] * FLT(!(x3_out || y1_out)); + s2 = src_loc[y2 * params.src_size.x + x3] * FLT(!(x3_out || y2_out)); + s3 = src_loc[y3 * params.src_size.x + x3] * FLT(!(x3_out || y3_out)); + + t0 += TO_ACCUM4_TYPE(s0 * filters_loc[6]); + t0 += TO_ACCUM4_TYPE(s1 * filters_loc[7]); + t0 += TO_ACCUM4_TYPE(s2 * filters_loc[8]); + b0 += TO_ACCUM4_TYPE(s1 * filters_loc[6]); + b0 += TO_ACCUM4_TYPE(s2 * filters_loc[7]); + b0 += TO_ACCUM4_TYPE(s3 * filters_loc[8]); + + r0 += TO_ACCUM4_TYPE(filters_loc[9]); + l0 += TO_ACCUM4_TYPE(filters_loc[9]); + t0 += TO_ACCUM4_TYPE(filters_loc[9]); + b0 += TO_ACCUM4_TYPE(filters_loc[9]); + + const 
int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x; + const int offset_1 = offset_0 + params.dst_size.x; + const int offset_2 = offset_0 + 1; + const int offset_3 = offset_0 + params.dst_size.x + 1; + bool x0_in = gid_x < params.dst_size.x; + bool x1_in = gid_x + 1 < params.dst_size.x; + bool y0_in = gid_y < params.dst_size.y; + bool y1_in = gid_y + 1 < params.dst_size.y; + + if (y0_in && x0_in) { + int linear_index = offset_0; + FLT4 value = FLT4(r0); + uint3 gid = uint3(gid_x, gid_y, gid_z); + $2 + dst_buffer[linear_index] = value; + } + if (y1_in && x0_in) { + int linear_index = offset_1; + FLT4 value = FLT4(l0); + uint3 gid = uint3(gid_x, gid_y + 1, gid_z); + $2 + dst_buffer[linear_index] = value; + } + if (y0_in && x1_in) { + int linear_index = offset_2; + FLT4 value = FLT4(t0); + uint3 gid = uint3(gid_x + 1, gid_y, gid_z); + $2 + dst_buffer[linear_index] = value; + } + if (y1_in && x1_in) { + int linear_index = offset_3; + FLT4 value = FLT4(b0); + uint3 gid = uint3(gid_x + 1, gid_y + 1, gid_z); + $2 + dst_buffer[linear_index] = value; + } +} + )"; + + return code; +} + +// Reorder weights to make the weights memory access pattern cache friendly for +// DepthWiseConv3x3Stride1x1 +std::vector ReorderWeightsDepthWiseConv3x3Stride1x1( + const DepthwiseConvolution2DAttributes& attr) { + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int kernel_x = 3; + const int kernel_y = 3; + std::vector weights_reordered((kernel_x * kernel_y + 1) * src_depth * + 4); + + int counter = 0; + for (int s = 0; s < src_depth; ++s) { + for (int x = 0; x < kernel_x; ++x) { + for (int y = 0; y < kernel_y; ++y) { + for (int i = 0; i < 4; ++i) { + const int s_ch = s * 4 + i; + if (s_ch < attr.weights.shape.i) { + const int f_index = attr.weights.shape.LinearIndex({0, y, x, s_ch}); + weights_reordered[counter++] = attr.weights.data[f_index]; + } else { + weights_reordered[counter++] = 0.0f; + } + } + } + } + + for (int i = 0; i < 4; ++i) { + const int dst_ch = s * 4 + i; + if (dst_ch < attr.bias.shape.v) { + weights_reordered[counter++] = attr.bias.data[dst_ch]; + } else { + weights_reordered[counter++] = 0.0f; + } + } + } + + return weights_reordered; +} + +static std::vector GetUniformBufferDepthWiseConv3x3Stride1x1( + const BHWC& src_size, const BHWC& dst_size, + const DepthwiseConvolution2DAttributes& params) { + std::vector uniform_params = { + src_size.w, + src_size.h, + src_size.w * src_size.h, + IntegralDivideRoundUp(src_size.c, 4), + dst_size.w, + dst_size.h, + dst_size.w * dst_size.h, + IntegralDivideRoundUp(dst_size.c, 4), + -params.padding.prepended.w, + -params.padding.prepended.h, + 0, // dummy, for alignment + 0, // dummy, for alignment + 0, // dummy, for alignment + 0, // dummy, for alignment + 0, // dummy, for alignment + 0, // dummy, for alignment + }; + return VectorToUint8Vector(uniform_params); +} + +std::string GetKernelDepthWiseConv3x3Stride2() { + std::string code = R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int2 padding; + int2 stride; + int2 dilation; + int2 dummy0; // for alignment +}; +$0 + +kernel void ComputeFunction( + $1 + uint3 ugid[[thread_position_in_grid]]) +{ + int gid_x = ugid.x; + int gid_y = ugid.y * 2; + int gid_z = ugid.z; + + if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) { + return; + } + + device FLT4* src_loc = src_buffer + gid_z * params.src_size.z; + device FLT4* filters_loc = filters + gid_z * 10; + + ACCUM_FLT4 r0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 
0.0f); + ACCUM_FLT4 l0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f); + + int x0 = gid_x * params.stride.x + params.padding.x; + int x1 = gid_x * params.stride.x + params.padding.x + params.dilation.x; + int x2 = gid_x * params.stride.x + params.padding.x + 2 * params.dilation.x; + int y0 = gid_y * 2 + params.padding.y; + int y1 = gid_y * 2 + params.padding.y + 1; + int y2 = gid_y * 2 + params.padding.y + 2; + int y3 = gid_y * 2 + params.padding.y + 3; + int y4 = gid_y * 2 + params.padding.y + 4; + + bool x0_out = x0 < 0 || x0 >= params.src_size.x; + bool x1_out = x1 < 0 || x1 >= params.src_size.x; + bool x2_out = x2 < 0 || x2 >= params.src_size.x; + bool y0_out = y0 < 0 || y0 >= params.src_size.y; + bool y1_out = y1 < 0 || y1 >= params.src_size.y; + bool y2_out = y2 < 0 || y2 >= params.src_size.y; + bool y3_out = y3 < 0 || y3 >= params.src_size.y; + bool y4_out = y4 < 0 || y4 >= params.src_size.y; + + x0 = clamp(x0, 0, params.src_size.x - 1); + x1 = clamp(x1, 0, params.src_size.x - 1); + x2 = clamp(x2, 0, params.src_size.x - 1); + y0 = clamp(y0, 0, params.src_size.y - 1); + y1 = clamp(y1, 0, params.src_size.y - 1); + y2 = clamp(y2, 0, params.src_size.y - 1); + y3 = clamp(y3, 0, params.src_size.y - 1); + y4 = clamp(y4, 0, params.src_size.y - 1); + + FLT4 s0 = src_loc[y0 * params.src_size.x + x0] * FLT(!(x0_out || y0_out)); + FLT4 s1 = src_loc[y0 * params.src_size.x + x1] * FLT(!(x1_out || y0_out)); + FLT4 s2 = src_loc[y0 * params.src_size.x + x2] * FLT(!(x2_out || y0_out)); + + r0 += TO_ACCUM4_TYPE(s0 * filters_loc[0]); + r0 += TO_ACCUM4_TYPE(s1 * filters_loc[1]); + r0 += TO_ACCUM4_TYPE(s2 * filters_loc[2]); + + s0 = src_loc[y1 * params.src_size.x + x0] * FLT(!(x0_out || y1_out)); + s1 = src_loc[y1 * params.src_size.x + x1] * FLT(!(x1_out || y1_out)); + s2 = src_loc[y1 * params.src_size.x + x2] * FLT(!(x2_out || y1_out)); + + r0 += TO_ACCUM4_TYPE(s0 * filters_loc[3]); + r0 += TO_ACCUM4_TYPE(s1 * filters_loc[4]); + r0 += TO_ACCUM4_TYPE(s2 * filters_loc[5]); + + s0 = src_loc[y2 * params.src_size.x + x0] * FLT(!(x0_out || y2_out)); + s1 = src_loc[y2 * params.src_size.x + x1] * FLT(!(x1_out || y2_out)); + s2 = src_loc[y2 * params.src_size.x + x2] * FLT(!(x2_out || y2_out)); + + r0 += TO_ACCUM4_TYPE(s0 * filters_loc[6]); + r0 += TO_ACCUM4_TYPE(s1 * filters_loc[7]); + r0 += TO_ACCUM4_TYPE(s2 * filters_loc[8]); + l0 += TO_ACCUM4_TYPE(s0 * filters_loc[0]); + l0 += TO_ACCUM4_TYPE(s1 * filters_loc[1]); + l0 += TO_ACCUM4_TYPE(s2 * filters_loc[2]); + + s0 = src_loc[y3 * params.src_size.x + x0] * FLT(!(x0_out || y3_out)); + s1 = src_loc[y3 * params.src_size.x + x1] * FLT(!(x1_out || y3_out)); + s2 = src_loc[y3 * params.src_size.x + x2] * FLT(!(x2_out || y3_out)); + + l0 += TO_ACCUM4_TYPE(s0 * filters_loc[3]); + l0 += TO_ACCUM4_TYPE(s1 * filters_loc[4]); + l0 += TO_ACCUM4_TYPE(s2 * filters_loc[5]); + + s0 = src_loc[y4 * params.src_size.x + x0] * FLT(!(x0_out || y4_out)); + s1 = src_loc[y4 * params.src_size.x + x1] * FLT(!(x1_out || y4_out)); + s2 = src_loc[y4 * params.src_size.x + x2] * FLT(!(x2_out || y4_out)); + + l0 += TO_ACCUM4_TYPE(s0 * filters_loc[6]); + l0 += TO_ACCUM4_TYPE(s1 * filters_loc[7]); + l0 += TO_ACCUM4_TYPE(s2 * filters_loc[8]); + + r0 += TO_ACCUM4_TYPE(filters_loc[9]); + l0 += TO_ACCUM4_TYPE(filters_loc[9]); + + const int offset_0 = gid_z * params.dst_size.z + + gid_y * params.dst_size.x + gid_x; + const int offset_1 = offset_0 + params.dst_size.x; + bool y0_in = gid_y < params.dst_size.y; + bool y1_in = gid_y + 1 < params.dst_size.y; + + if (y0_in) { + int linear_index = offset_0; + FLT4 
value = FLT4(r0); + uint3 gid = uint3(gid_x, gid_y, gid_z); + $2 + dst_buffer[linear_index] = value; + } + if (y1_in) { + int linear_index = offset_1; + FLT4 value = FLT4(l0); + uint3 gid = uint3(gid_x, gid_y, gid_z); + $2 + dst_buffer[linear_index] = value; + } +} + )"; + + return code; +} + +// Reorder weights to make the weights memory access pattern cache friendly for +// DepthWiseConv3x3Stride2 +std::vector ReorderWeightsDepthWiseConv3x3Stride2( + const DepthwiseConvolution2DAttributes& attr) { + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int kernel_x = 3; + const int kernel_y = 3; + std::vector weights_reordered((kernel_x * kernel_y + 1) * src_depth * + 4); + + int counter = 0; + for (int s = 0; s < src_depth; ++s) { + for (int y = 0; y < kernel_y; ++y) { + for (int x = 0; x < kernel_x; ++x) { + for (int i = 0; i < 4; ++i) { + const int s_ch = s * 4 + i; + if (s_ch < attr.weights.shape.i) { + const int f_index = attr.weights.shape.LinearIndex({0, y, x, s_ch}); + weights_reordered[counter++] = attr.weights.data[f_index]; + } else { + weights_reordered[counter++] = 0.0f; + } + } + } + } + + for (int i = 0; i < 4; ++i) { + const int dst_ch = s * 4 + i; + if (dst_ch < attr.bias.shape.v) { + weights_reordered[counter++] = attr.bias.data[dst_ch]; + } else { + weights_reordered[counter++] = 0.0f; + } + } + } + + return weights_reordered; +} + +static std::vector GetUniformBufferDepthWiseConv3x3Stride2( + const BHWC& src_size, const BHWC& dst_size, + const DepthwiseConvolution2DAttributes& attr) { + std::vector uniform_params = { + src_size.w, + src_size.h, + src_size.w * src_size.h, + IntegralDivideRoundUp(src_size.c, 4), + dst_size.w, + dst_size.h, + dst_size.w * dst_size.h, + IntegralDivideRoundUp(dst_size.c, 4), + -attr.padding.prepended.w, + -attr.padding.prepended.h, + attr.strides.w, + attr.strides.h, + attr.dilations.w, + attr.dilations.h, + 0, // dummy, for alignment + 0, // dummy, for alignment + }; + return VectorToUint8Vector(uniform_params); +} + +} // namespace + +std::vector DepthWiseConvolution( + int id, ValueId input_id, ValueId output_id, + const DepthwiseConvolution2DAttributes& attr, + const RuntimeOptions& options) { + int channels_multiplier = attr.weights.shape.o; + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + std::string shader_source = R"( + #include + using namespace metal; + constant int kernel_x = $0; + constant int kernel_y = $1; + struct uniforms { + int4 stride; + int4 padding; + int4 dillation; + int4 size; + int4 channel_multiplier; + }; + $$0 + kernel void ComputeFunction( + $$1 + uint tid[[thread_index_in_threadgroup]], + uint3 gid[[thread_position_in_grid]]) { + const bool outside = static_cast(gid.x) >= params.size.z || + static_cast(gid.y) >= params.size.w; + if (outside) { + return; + } + device FLT4* temp = filters + gid.z * kernel_y * kernel_x; + float4 sum0 = float4(0.0f, 0.0f, 0.0f, 0.0f); + + for(int ky = 0; ky < kernel_y; ++ky) { + for(int kx = 0; kx < kernel_x; ++kx) { + int2 coords = int2(gid.xy) * params.stride.xy + int2(kx, ky) * params.dillation.xy - + params.padding.xy; + const bool outside = coords.x < 0 || coords.y < 0 || + coords.x >= params.size.x || coords.y >= params.size.y; + if (outside) continue; +)"; + if (channels_multiplier == 1) { + shader_source += R"( + const int src_layer = gid.z; + const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; + const FLT4 src_modified = src_buffer[src_index]; +)"; + } else if 
(channels_multiplier == 2) { + shader_source += R"( + const int src_layer = gid.z / 2; + const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; + const FLT4 src = src_buffer[src_index]; + const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw; + const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y); +)"; + } else if (channels_multiplier == 4) { + shader_source += R"( + const int src_layer = gid.z / 4; + const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; + const FLT4 src = src_buffer[src_index]; + const FLT t0 = src[gid.z % 4]; + const FLT4 src_modified = FLT4(t0, t0, t0, t0); +)"; + } else { + shader_source += R"( + const int src_layer = gid.z / params.channel_multiplier.x; + const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x; + const FLT4 src = src_buffer[src_index]; + FLT4 src_modified; + const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4; + src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x]; + src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x]; + src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x]; + src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x]; +)"; + } + shader_source += R"( + sum0 += float4(src_modified * temp[ky * kernel_x + kx]); + } + } + FLT4 res = FLT4(sum0 + float4(biases[gid.z])); + const int linear_index = (gid.z * params.size.w + int(gid.y)) * params.size.z + int(gid.x); + FLT4 value = res; + $$2 + output_buffer[linear_index] = value; + } + )"; + desc->shader_source = absl::Substitute(shader_source, attr.weights.shape.w, + attr.weights.shape.h); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* output_buffer", + [input_id, attr](const std::map& buffers) { + auto out_shape = + CalculateOutputShape(buffers.find(input_id)->second, attr); + return out_shape; + }}; + + std::vector filters_reordered = ConvertToPIOHW4(attr.weights); + auto filters = options.storage_precision == RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(filters_reordered) + : VectorFloatToHalf(filters_reordered); + auto biases = options.storage_precision == RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(attr.bias.data) + : VectorFloatToHalf(attr.bias.data); + desc->immutable_buffers = { + {"device FLT4* const filters", filters}, + {"device FLT4* const biases", biases}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, attr](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + const auto& output_dimension = buffers.find(output_id)->second; + std::vector uniform_params{ + attr.strides.w, + attr.strides.h, + 1, + 1, + attr.padding.prepended.w, + attr.padding.prepended.h, + 1, + 1, + attr.dilations.w, + attr.dilations.h, + 1, + 1, + dimension.w, + dimension.h, + output_dimension.w, + output_dimension.h, + attr.weights.shape.o, + 0, + 0, + 0, + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + const auto& dimension = buffers.find(output_id)->second; + uint3 groups_size{8, 4, 1}; + uint3 groups_count{IntegralDivideRoundUp(dimension.w, groups_size.x), + IntegralDivideRoundUp(dimension.h, groups_size.y), + IntegralDivideRoundUp(dimension.c, 4)}; + return std::make_pair(groups_size, groups_count); + }; + + return {desc}; +} + +std::vector DepthWiseConv3x3Stride1x1( + int id, ValueId input_id, ValueId output_id, + const DepthwiseConvolution2DAttributes& attr, + const RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetKernelDepthWiseConv3x3Stride1x1(); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + auto out_shape = + CalculateOutputShape(buffers.find(input_id)->second, attr); + return out_shape; + }}; + + // For this operation we keep weights and biases in one buffer + auto weights_reordered = ReorderWeightsDepthWiseConv3x3Stride1x1(attr); + auto weights = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(weights_reordered) + : VectorFloatToHalf(weights_reordered); + desc->immutable_buffers = { + {"device FLT4* const filters", weights}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, attr](const std::map& buffers) { + const auto& input_dimensions = buffers.find(input_id)->second; + const auto& output_dimensions = buffers.find(output_id)->second; + return GetUniformBufferDepthWiseConv3x3Stride1x1( + input_dimensions, output_dimensions, attr); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + const auto& dimension = buffers.find(output_id)->second; + const int grid_x = IntegralDivideRoundUp(dimension.w, 2); + const int grid_y = IntegralDivideRoundUp(dimension.h, 2); + const int grid_z = IntegralDivideRoundUp(dimension.c, 4); + uint3 group_size{8, 4, 1}; + if (grid_x <= 4) { + group_size.x = 4; + group_size.z = grid_z % 2 == 0 ? 
2 : 1; + } + const int groups_x = IntegralDivideRoundUp(grid_x, group_size.x); + const int groups_y = IntegralDivideRoundUp(grid_y, group_size.y); + const int groups_z = IntegralDivideRoundUp(grid_z, group_size.z); + return std::make_pair(group_size, uint3(groups_x, groups_y, groups_z)); + }; + + return {desc}; +} + +bool CheckDepthWiseConv3x3Stride1x1Support( + const DepthwiseConvolution2DAttributes& attr) { + return attr.weights.shape.o == 1 && attr.weights.shape.h == 3 && + attr.weights.shape.w == 3 && attr.strides.h == 1 && + attr.strides.w == 1 && attr.dilations.h == 1 && attr.dilations.w == 1; +} + +std::vector DepthWiseConv3x3Stride2( + int id, ValueId input_id, ValueId output_id, + const DepthwiseConvolution2DAttributes& attr, + const RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetKernelDepthWiseConv3x3Stride2(); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + auto out_shape = + CalculateOutputShape(buffers.find(input_id)->second, attr); + return out_shape; + }}; + + // For this operation we keep weights and biases in one buffer + auto weights_reordered = ReorderWeightsDepthWiseConv3x3Stride2(attr); + auto weights = + options.storage_precision == metal::RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(weights_reordered) + : VectorFloatToHalf(weights_reordered); + desc->immutable_buffers = { + {"device FLT4* const filters", weights}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, attr](const std::map& buffers) { + const auto& input_dimensions = buffers.find(input_id)->second; + const auto& output_dimensions = buffers.find(output_id)->second; + return GetUniformBufferDepthWiseConv3x3Stride2( + input_dimensions, output_dimensions, attr); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + const auto& dimension = buffers.find(output_id)->second; + const int grid_x = dimension.w; + const int grid_y = IntegralDivideRoundUp(dimension.h, 2); + const int grid_z = IntegralDivideRoundUp(dimension.c, 4); + const uint3 group_size{8, 4, 1}; + const int groups_x = IntegralDivideRoundUp(grid_x, group_size.x); + const int groups_y = IntegralDivideRoundUp(grid_y, group_size.y); + const int groups_z = IntegralDivideRoundUp(grid_z, group_size.z); + return std::make_pair(group_size, uint3(groups_x, groups_y, groups_z)); + }; + + return {desc}; +} + +bool CheckDepthWiseConv3x3Stride2Support( + const DepthwiseConvolution2DAttributes& attr) { + return attr.weights.shape.o == 1 && attr.weights.shape.h == 3 && + attr.weights.shape.w == 3 && attr.strides.h == 2 && + attr.dilations.h == 1; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h new file mode 100644 index 00000000000..488b883a099 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h @@ -0,0 +1,69 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_DEPTHWISE_CONV_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_DEPTHWISE_CONV_H_
+
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
+    int id, ValueId input_id, ValueId output_id,
+    const DepthwiseConvolution2DAttributes& attr,
+    const RuntimeOptions& options);
+
+// Depthwise convolution for a 3x3 kernel.
+// Requires:
+//   channels_multiplier = 1;
+//   kernel_size = 3x3;
+//   dilation = 1x1;
+//   stride = 1x1.
+std::vector<ComputeTaskDescriptorPtr> DepthWiseConv3x3Stride1x1(
+    int id, ValueId input_id, ValueId output_id,
+    const DepthwiseConvolution2DAttributes& attr,
+    const RuntimeOptions& options);
+
+// TODO(impjdi): Move it inside module.
+bool CheckDepthWiseConv3x3Stride1x1Support(
+    const DepthwiseConvolution2DAttributes& attr);
+
+// Depthwise convolution for a 3x3 kernel.
+// Requires:
+//   channels_multiplier = 1;
+//   kernel_size = 3x3;
+//   dilation.y = 1;
+//   stride.y = 2.
+std::vector<ComputeTaskDescriptorPtr> DepthWiseConv3x3Stride2(
+    int id, ValueId input_id, ValueId output_id,
+    const DepthwiseConvolution2DAttributes& attr,
+    const RuntimeOptions& options);
+
+// TODO(impjdi): Move it inside module.
+bool CheckDepthWiseConv3x3Stride2Support(
+    const DepthwiseConvolution2DAttributes& attr);
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_DEPTHWISE_CONV_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc
new file mode 100644
index 00000000000..3f873eddb5e
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.cc
@@ -0,0 +1,169 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h" + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace metal { + +namespace { + +std::string GetElementwiseWithTwoInputsCode(int src_count, + OperationType op_type) { + std::string code = R"( + #include + using namespace metal; + + struct uniforms { + int4 src_size; + }; + + $0 + kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + if (static_cast(gid.x) >= params.src_size.x || + static_cast(gid.y) >= params.src_size.y) { + return; + } + + int linear_index = (int(gid.z) * params.src_size.y + int(gid.y)) * + params.src_size.x + int(gid.x); + )"; + + switch (op_type) { + case OperationType::SUB: { + code += + " FLT4 value = src_buffer0[linear_index] - " + "src_buffer1[linear_index];"; + break; + } + case OperationType::DIV: { + code += + " FLT4 value = src_buffer0[linear_index] / " + "src_buffer1[linear_index];"; + break; + } + case OperationType::POW: { + code += + " FLT4 value = pow(src_buffer0[linear_index], " + "src_buffer1[linear_index]);"; + break; + } + case OperationType::SQUARED_DIFF: { + code += R"( + FLT4 src_0 = src_buffer0[linear_index]; + FLT4 src_1 = src_buffer1[linear_index]; + FLT4 value = (src_0 - src_1) * (src_0 - src_1); + )"; + break; + } + default: { + return ""; + } + } + code += R"( + $2 + dst_buffer[linear_index] = value; + })"; + return code; +} +} // namespace + +std::vector ElementwiseWithTwoInputs( + int id, std::vector input_ids, ValueId output_id, + OperationType op_type) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = + GetElementwiseWithTwoInputsCode(input_ids.size(), op_type); + + for (int i = 0; i < input_ids.size(); ++i) { + const std::string buffer_name = + "device FLT4* const src_buffer" + std::to_string(i); + desc->input_buffers.push_back({input_ids[i], buffer_name}); + } + + desc->output_buffer = {output_id, "device FLT4* dst_buffer", + [input_ids](const std::map& buffers) { + return buffers.find(input_ids[0])->second; + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_ids](const std::map& buffers) { + const auto& dimension = buffers.find(input_ids[0])->second; + std::vector uniform_params = {dimension.w, dimension.h, 0, 0}; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [input_ids](const std::map& buffers) { + const auto& src_dim = buffers.find(input_ids[0])->second; + const uint3 groups_size{16, 16, 1}; + int groups_x = IntegralDivideRoundUp(src_dim.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(src_dim.h, groups_size.y); + const int dst_layers = IntegralDivideRoundUp(src_dim.c, 4); + int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + return {desc}; +} + +std::vector ElementwiseWithOneInput( + int id, ValueId input_id, ValueId output_id, OperationType op_type) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = true; + + const std::unordered_map functors{ + {OperationType::ABS, "abs(value)"}, + {OperationType::SIN, "sin(value)"}, + {OperationType::COS, "cos(value)"}, + {OperationType::LOG, "log(value)"}, + 
{OperationType::SQRT, "sqrt(value)"}, + {OperationType::RSQRT, "1.0 / sqrt(value)"}, + {OperationType::SQUARE, "value * value"}, + {OperationType::SIGMOID, "1.0 / (1.0 + exp(-1.0 * value))"}, + {OperationType::TANH, "tanh(value)"}, + }; + + if (functors.count(op_type) == 0) { + return {}; + } + + desc->shader_source = + "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid) {\n"; + desc->shader_source += " return " + functors.at(op_type) + ";\n"; + desc->shader_source += " }"; + + desc->input_buffers = {{input_id}}; + desc->output_buffer = {output_id}; + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h new file mode 100644 index 00000000000..c8cee339d1b --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/elementwise.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_ELEMENTWISE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_ELEMENTWISE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector ElementwiseWithTwoInputs( + int id, std::vector input_ids, ValueId output_id, + OperationType op_type); + +std::vector ElementwiseWithOneInput( + int id, ValueId input_id, ValueId output_id, OperationType op_type); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_ELEMENTWISE_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc new file mode 100644 index 00000000000..b1bc0287496 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc @@ -0,0 +1,199 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +std::string GetFullyConnectedCode(bool shared_memory, int src_channels, + int dst_channels) { + const int src_depth = IntegralDivideRoundUp(src_channels, 4); + std::stringstream code; + code << R"( + #include + using namespace metal; + + struct uniforms { + uint src_depth; + uint dst_channels; + uint out_channels; + uint dummy; + }; + + $$0 + kernel void ComputeFunction( + $$1 + uint3 tid[[thread_position_in_threadgroup]], + uint tid_index[[thread_index_in_threadgroup]], + uint3 ugid[[thread_position_in_grid]]) { + +)"; + if (shared_memory) { + code << R"( + float summa = 0.0f; + threadgroup FLT4 local_vector[32]; + for (int j = 0; j < $0; ++j) { + local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ? + FLT4(0.0f) : vector[j * 32 + tid_index]; + BARRIER(mem_flags::mem_threadgroup); + for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) { + summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]); + } + BARRIER(mem_flags::mem_none); + } + )"; + } else { + code << R"( + float summa = 0.0f; + uint counter = ugid.y * $0; + for (uint i = 0; i < $0; ++i, ++counter) { + )"; + if (src_depth % 4 != 0) { + code << " if (counter >= params.src_depth) continue;" << std::endl; + } + code << " summa += dot(vector[counter], matrix[counter * " + "params.dst_channels + ugid.x]);" + << std::endl; + code << " }" << std::endl; + } + code << R"( + + threadgroup float temp[8][4]; + temp[tid.x][tid.y] = summa; + BARRIER(mem_flags::mem_threadgroup); + if (tid.y == 0) { + summa += temp[tid.x][1]; + summa += temp[tid.x][2]; + summa += temp[tid.x][3]; + temp[tid.x][0] = summa; + } + BARRIER(mem_flags::mem_threadgroup); + if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) { + const int linear_index = ugid.x / 4; + FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) + + biases[linear_index]; + uint3 gid = uint3(1u, 1u, uint(linear_index)); + $$2 + result[linear_index] = value; + } +} + )"; + const int src_depth_sub_groups = shared_memory + ? 
IntegralDivideRoundUp(src_depth, 32) + : IntegralDivideRoundUp(src_depth, 4); + return absl::Substitute(code.str(), src_depth_sub_groups); +} +} // namespace + +std::vector FullyConnected( + int id, ValueId input_id, ValueId output_id, + const FullyConnectedAttributes& attr, const RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + int gpu_type = GetAppleSocVersion(); + bool shared = gpu_type == 7 || gpu_type == 8; + desc->shader_source = + GetFullyConnectedCode(shared, attr.weights.shape.i, attr.weights.shape.o); + + desc->input_buffers = { + {input_id, "device FLT4* const vector"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* result", + [input_id, attr](const std::map& buffers) { + return CalculateOutputShape(buffers.find(input_id)->second, attr); + }}; + + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int src_depth_aligned = AlignByN(src_depth, shared ? 32 : 4); + const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8); + + int counter = 0; + std::vector filters_reordered(dst_channels_aligned * + src_depth_aligned * 4); + for (int j = 0; j < src_depth_aligned; ++j) { + for (int i = 0; i < dst_channels_aligned; ++i) { + for (int k = 0; k < 4; ++k) { + if (j * 4 + k >= attr.weights.shape.i || i >= attr.weights.shape.o) { + filters_reordered[counter++] = 0.0f; + } else { + const int f_index = + attr.weights.shape.LinearIndex({i, 0, 0, j * 4 + k}); + filters_reordered[counter++] = attr.weights.data[f_index]; + } + } + } + } + + auto filters = options.storage_precision == RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(filters_reordered) + : VectorFloatToHalf(filters_reordered); + auto biases = options.storage_precision == RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(attr.bias.data) + : VectorFloatToHalf(attr.bias.data); + desc->immutable_buffers = { + {"device FLT4* const matrix", filters}, + {"device FLT4* const biases", biases}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [attr](const std::map& buffers) { + std::vector uniform_params{ + static_cast( + IntegralDivideRoundUp(attr.weights.shape.i, 4)), + static_cast(AlignByN(attr.weights.shape.o, 8)), + static_cast(attr.weights.shape.o), + static_cast(0), + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [attr](const std::map& buffers) { + const uint3 groups_size{8, 4, 1}; + const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8); + int groups_x = IntegralDivideRoundUp(dst_channels_aligned, groups_size.x); + return std::make_pair(groups_size, uint3{groups_x, 1, 1}); + }; + + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h new file mode 100644 index 00000000000..00d73fdf944 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_
+
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+// Creates a TaskDescriptor for FullyConnected.
+// FullyConnected is equivalent to a matrix-vector multiplication.
+// The operation could also be expressed as a 1x1 convolution, but that would
+// be inefficient.
+std::vector<ComputeTaskDescriptorPtr> FullyConnected(
+    int id, ValueId input_id, ValueId output_id,
+    const FullyConnectedAttributes& attr, const RuntimeOptions& options);
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.cc
new file mode 100644
index 00000000000..c2b439113e2
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.cc
@@ -0,0 +1,145 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +std::string GetMaxUnpoolingCode(const HW& kernel_size) { + std::string shader_source = R"( + #include + using namespace metal; + constant int window_w = $0; + struct uniforms { + int2 src_size; + int2 dst_size; + int2 stride; + int2 offset; + }; + + $$0 + kernel void ComputeFunction( + $$1 + uint3 gid[[thread_position_in_grid]]) { + int X = static_cast(gid.x); + int Y = static_cast(gid.y); + if (X >= params.dst_size.x || Y >= params.dst_size.y) { + return; + } + + int src_x = (X + params.offset.x) / params.stride.x; + int src_y = (Y + params.offset.y) / params.stride.y; + + bool outside = src_x < 0 || src_y < 0 || + src_x >= params.src_size.x || src_y >= params.src_size.y; + + int src_index = (gid.z * params.src_size.y + src_y) * params.src_size.x + src_x; + int linear_index = (gid.z * params.dst_size.y + Y) * params.dst_size.x + X; + + int4 indexes = outside ? int4(0) : int4(src_indices_buffer[src_index]); + FLT4 src_color = outside ? FLT4(0.0f) : src_buffer[src_index]; + + int t_x = X - (src_x * params.stride.x - params.offset.x); + int t_y = Y - (src_y * params.stride.y - params.offset.y); + int t_index = t_y * window_w + t_x; + + FLT4 value; + value.x = t_index == indexes.x ? src_color.x : 0.0; + value.y = t_index == indexes.y ? src_color.y : 0.0; + value.z = t_index == indexes.z ? src_color.z : 0.0; + value.w = t_index == indexes.w ? 
src_color.w : 0.0; + + $$2 + output_buffer[linear_index] = value; + } + )"; + return absl::Substitute(shader_source, kernel_size.w); +} +} // namespace + +std::vector MaxUnpooling( + int id, ValueId input_id, ValueId input_indices_id, ValueId output_id, + const MaxUnpooling2DAttributes& params) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetMaxUnpoolingCode(params.kernel); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + {input_indices_id, "device FLT4* const src_indices_buffer"}, + }; + + desc->output_buffer = {output_id, "device FLT4* output_buffer", + [input_id, input_indices_id, + params](const std::map& buffers) { + return CalculateOutputShape( + buffers.find(input_id)->second, params); + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, input_indices_id, output_id, + params](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + const auto& output_dimension = buffers.find(output_id)->second; + std::vector uniform_params{ + dimension.w, + dimension.h, + output_dimension.w, + output_dimension.h, + params.strides.w, + params.strides.h, + params.padding.prepended.w, + params.padding.prepended.h, + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [input_id, input_indices_id, + params](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + BHWC dst_shape = CalculateOutputShape(src_shape, params); + const uint3 groups_size{16, 16, 1}; + int groups_x = IntegralDivideRoundUp(dst_shape.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(dst_shape.h, groups_size.y); + int groups_z = IntegralDivideRoundUp(dst_shape.c, 4); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h new file mode 100644 index 00000000000..6cf5865e799 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/max_unpooling.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MAX_UNPOOLING_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MAX_UNPOOLING_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector MaxUnpooling( + int id, ValueId input_id, ValueId input_indices_id, ValueId output_id, + const MaxUnpooling2DAttributes& params); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MAX_UNPOOLING_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc b/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc new file mode 100644 index 00000000000..745d0183ebb --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul.cc @@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/mul.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector Multiply( + int id, ValueId input_id, ValueId output_id, + const MultiplyScalarAttributes& attr, const RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = true; + auto multiplier = absl::get_if(&attr.param); + auto mul_buffer = + absl::get_if>(&attr.param); + const bool scalar = multiplier != nullptr; + const std::string param_desc = + scalar ? "float multiplier" : "device FLT4* const mul_buf"; + std::string code = + "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, "; + code += param_desc + ") {\n"; + if (scalar) { + code += "return value * multiplier;\n"; + } else { + code += "return value * mul_buf[gid.z];\n"; + } + code += "}\n"; + desc->shader_source = code; + desc->input_buffers = {{input_id}}; + desc->output_buffer = {output_id}; + if (scalar) { + std::vector multiplier_bits = + VectorToUint8Vector(std::vector{*multiplier}); + desc->uniform_buffers = { + {"constant float&", + [multiplier_bits](const std::map& buffers) { + return multiplier_bits; + }}, + }; + } else { + auto coeffs = options.storage_precision == RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(mul_buffer->data) + : VectorFloatToHalf(mul_buffer->data); + desc->immutable_buffers = { + {"device FLT4* const", coeffs}, + }; + } + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/mul.h b/tensorflow/lite/delegates/gpu/metal/kernels/mul.h new file mode 100644 index 00000000000..60d52163af0 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/mul.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_ + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +// Multiply operation, supports scalar and vector broadcast. +std::vector Multiply( + int id, ValueId input_id, ValueId output_id, + const MultiplyScalarAttributes& attr, const RuntimeOptions& options); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_MUL_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/padding.cc b/tensorflow/lite/delegates/gpu/metal/kernels/padding.cc new file mode 100644 index 00000000000..7a9dc1d22b8 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/padding.cc @@ -0,0 +1,154 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/padding.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +std::string GetPaddingCode() { + const std::string channels[] = {".x", ".y", ".z", ".w"}; + std::string code = R"( + #include + using namespace metal; + + struct uniforms { + int4 src_size; + int4 dst_size; + int4 padding; + }; + + $0 + kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + if (static_cast(gid.x) >= params.dst_size.x || + static_cast(gid.y) >= params.dst_size.y) { + return; + } + + FLT4 value = FLT4(0.0f); + int s_x = static_cast(gid.x) - params.padding.x; + int s_y = static_cast(gid.y) - params.padding.y; + bool inside_x = s_x >= 0 && s_x < params.src_size.x; + bool inside_y = s_y >= 0 && s_y < params.src_size.y; + if (inside_x && inside_y) { + int start_channel = static_cast(gid.z) * 4; + )"; + for (int i = 0; i < 4; ++i) { + const auto& s = channels[i]; + code += " {\n"; + code += " int channel = start_channel + " + std::to_string(i) + ";\n"; + code += " int s_z = channel - params.padding.z;\n"; + code += " if (s_z >= 0 && s_z < params.src_size.z) {\n"; + code += + " int buffer_index = ((s_z / 4) * params.src_size.y + s_y) * " + "params.src_size.x + " + "s_x;\n"; + code += " FLT4 t = src_buffer[buffer_index];\n"; + code += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n"; + code += " value" + s + " = t_ar[s_z % 4];\n"; + code += " }\n"; + code += " }\n"; + } + code += " }\n"; + code += + " int linear_index = (gid.z * params.dst_size.y + int(gid.y)) * " + "params.dst_size.x + " + "int(gid.x);\n"; + code += " $2\n"; + code += " dst_buffer[linear_index] = value;\n"; + code += "}\n"; + return code; +} +} // namespace + +std::vector Padding(int id, ValueId input_id, + ValueId output_id, + const PadAttributes& attr) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetPaddingCode(); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + return CalculateOutputShape(buffers.find(input_id)->second, attr); + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, attr](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + const auto& output_dimension = buffers.find(output_id)->second; + std::vector uniform_params{ + // int4 src_size + dimension.w, + dimension.h, + dimension.c, + IntegralDivideRoundUp(dimension.c, 4), + // int4 dst_size + output_dimension.w, + output_dimension.h, + output_dimension.c, + IntegralDivideRoundUp(output_dimension.c, 4), + // int3 prepended padding + alignment to int4 + attr.prepended.w, + attr.prepended.h, + attr.prepended.c, + 0, + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [input_id, + attr](const std::map& buffers) { + const uint3 groups_size{16, 16, 1}; + const auto& src_shape = buffers.find(input_id)->second; + BHWC 
dst_shape = CalculateOutputShape(src_shape, attr); + const int dst_layers = IntegralDivideRoundUp(dst_shape.c, 4); + int groups_x = IntegralDivideRoundUp(dst_shape.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(dst_shape.h, groups_size.y); + int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/padding.h b/tensorflow/lite/delegates/gpu/metal/kernels/padding.h new file mode 100644 index 00000000000..177cc4055c2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/padding.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_PADDING_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_PADDING_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +// creates TaskDescriptor for Padding operation +std::vector Padding(int id, ValueId input_id, + ValueId output_id, + const PadAttributes& attr); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_PADDING_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/metal/kernels/pooling.cc new file mode 100644 index 00000000000..5e876228bd3 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/pooling.cc @@ -0,0 +1,272 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
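The `Padding` kernel above maps each output pixel back to the source via `gid - padding`, zero-fills taps that fall outside the source, and dispatches one thread per output pixel per FLT4 channel slice using fixed 16x16x1 threadgroups. A minimal sketch of that grid arithmetic, assuming `IntegralDivideRoundUp` (declared in `common/util.h`, not part of this patch) is ordinary ceiling division:

```cpp
#include <cstdio>

// Ceiling division, as used throughout these kernels (assumed semantics).
int IntegralDivideRoundUp(int n, int divisor) {
  return (n + divisor - 1) / divisor;
}

int main() {
  // Example output tensor: 65x33 pixels, 10 channels -> 3 slices of FLT4.
  const int dst_w = 65, dst_h = 33, dst_c = 10;
  const int dst_layers = IntegralDivideRoundUp(dst_c, 4);  // 3

  // Threadgroup size is fixed at 16x16x1 in Padding's resize_function.
  const int groups_x = IntegralDivideRoundUp(dst_w, 16);      // 5
  const int groups_y = IntegralDivideRoundUp(dst_h, 16);      // 3
  const int groups_z = IntegralDivideRoundUp(dst_layers, 1);  // 3

  std::printf("dispatch %dx%dx%d groups of 16x16x1 threads\n",
              groups_x, groups_y, groups_z);
}
```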
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/pooling.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +std::string GetMaxPoolingCode(const HW& kernel_size) { + std::string shader_source = R"( + #include + using namespace metal; + constant int window_w = $0; + constant int window_h = $1; + struct uniforms { + int4 src_size; + int4 dst_size; + int2 stride; + int2 offset; + }; + + $$0 + kernel void ComputeFunction( + $$1 + uint3 gid[[thread_position_in_grid]]) { + if (static_cast(gid.x) >= params.dst_size.x || + static_cast(gid.y) >= params.dst_size.y || + static_cast(gid.z) >= params.dst_size.z) { + return; + } + + FLT4 maximum = FLT4(-10000.0); + for (int a = 0; a < window_h; ++a) { + for (int b = 0; b < window_w; ++b) { + const int2 coords = int2(gid.xy) * params.stride - params.offset + int2(b, a); + bool outside = coords.x < 0 || coords.y < 0 || coords.x >= params.src_size.x || + coords.y >= params.src_size.y; + const int buffer_index = (gid.z * params.src_size.y + coords.y) * + params.src_size.x + coords.x; + FLT4 src_color = outside ? FLT4(-10000.0) : src_buffer[buffer_index]; + maximum = max(maximum, src_color); + } + } + const int linear_index = (gid.z * params.dst_size.y + int(gid.y)) * params.dst_size.x + + int(gid.x); + FLT4 value = maximum; + $$2 + output_buffer[linear_index] = value; + } + )"; + return absl::Substitute(shader_source, kernel_size.w, kernel_size.h); +} + +std::string GetMaxPoolingIndicesCode(const HW& kernel_size) { + std::string shader_source = R"( + #include + using namespace metal; + constant int window_w = $0; + constant int window_h = $1; + struct uniforms { + int4 src_size; + int4 dst_size; + int2 stride; + int2 offset; + }; + + $$0 + kernel void ComputeFunction( + $$1 + uint3 gid[[thread_position_in_grid]]) { + if (static_cast(gid.x) >= params.dst_size.x || + static_cast(gid.y) >= params.dst_size.y || + static_cast(gid.z) >= params.dst_size.z) { + return; + } + + FLT4 maximum = FLT4(-10000.0); + ushort4 indexes = ushort4(0); + ushort index_counter = 0; + for (int a = 0; a < window_h; ++a) { + for (int b = 0; b < window_w; ++b) { + const int2 coords = int2(gid.xy) * params.stride - params.offset + int2(b, a); + bool outside = coords.x < 0 || coords.y < 0 || coords.x >= params.src_size.x || + coords.y >= params.src_size.y; + const int buffer_index = (gid.z * params.src_size.y + coords.y) * + params.src_size.x + coords.x; + FLT4 src_color = outside ? 
FLT4(-10000.0) : src_buffer[buffer_index]; + if (src_color.x > maximum.x) { + indexes.x = index_counter; + maximum.x = src_color.x; + } + if (src_color.y > maximum.y) { + indexes.y = index_counter; + maximum.y = src_color.y; + } + if (src_color.z > maximum.z) { + indexes.z = index_counter; + maximum.z = src_color.z; + } + if (src_color.w > maximum.w) { + indexes.w = index_counter; + maximum.w = src_color.w; + } + index_counter++; + } + } + const int linear_index = (gid.z * params.dst_size.y + int(gid.y)) * params.dst_size.x + + int(gid.x); + FLT4 value = static_cast(indexes) + FLT4(0.1); + $$2 + output_buffer[linear_index] = value; + } + )"; + return absl::Substitute(shader_source, kernel_size.w, kernel_size.h); +} + +std::string GetAveragePoolingCode(const HW& kernel_size) { + std::string shader_source = R"( + #include + using namespace metal; + constant int window_w = $0; + constant int window_h = $1; + constant float multiplier = $2; + struct uniforms { + int4 src_size; + int4 dst_size; + int2 stride; + int2 offset; + }; + $$0 + kernel void ComputeFunction( + $$1 + uint tid[[thread_index_in_threadgroup]], + uint3 gid[[thread_position_in_grid]]) { + if (static_cast(gid.x) >= params.dst_size.x || + static_cast(gid.y) >= params.dst_size.y || + static_cast(gid.z) >= params.dst_size.z) { + return; + } + + float4 sum = float4(0.0f); + for (int a = 0; a < window_h; ++a) { + for (int b = 0; b < window_w; ++b) { + const int2 coords = int2(gid.xy) * params.stride - params.offset + int2(b, a); + bool outside = coords.x < 0 || coords.y < 0 || coords.x >= params.src_size.x || + coords.y >= params.src_size.y; + const int buffer_index = (gid.z * params.src_size.y + coords.y) * + params.src_size.x + coords.x; + const float4 src_color = outside ? float4(0.0f) : float4(src_buffer[buffer_index]); + sum += src_color; + } + } + const int linear_index = (gid.z * params.dst_size.y + int(gid.y)) * params.dst_size.x + + int(gid.x); + FLT4 value = FLT4(sum * multiplier); + $$2 + output_buffer[linear_index] = value; + } +)"; + float multiplier = 1.0f / static_cast(kernel_size.w * kernel_size.h); + return absl::Substitute(shader_source, kernel_size.w, kernel_size.h, + multiplier); +} + +ComputeTaskDescriptorPtr PoolingInternal(int id, ValueId input_id, + ValueId output_id, + const Pooling2DAttributes& params, + bool generate_indices) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + if (params.type == PoolingType::MAX) { + desc->shader_source = generate_indices + ? 
GetMaxPoolingIndicesCode(params.kernel) + : GetMaxPoolingCode(params.kernel); + } else if (params.type == PoolingType::AVERAGE) { + desc->shader_source = GetAveragePoolingCode(params.kernel); + } + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* output_buffer", + [input_id, params](const std::map& buffers) { + return CalculateOutputShape(buffers.find(input_id)->second, params); + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + const auto& output_dimension = buffers.find(output_id)->second; + std::vector uniform_params = { + dimension.w, + dimension.h, + IntegralDivideRoundUp(dimension.c, 4), + dimension.w * dimension.h, + output_dimension.w, + output_dimension.h, + IntegralDivideRoundUp(dimension.c, 4), + output_dimension.w * output_dimension.h, + params.strides.w, + params.strides.h, + params.padding.prepended.w, + params.padding.prepended.h, + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + BHWC dst_shape = buffers.find(output_id)->second; + const uint3 grid = + uint3(dst_shape.w, dst_shape.h, IntegralDivideRoundUp(dst_shape.c, 4)); + const uint3 groups_size = GetWorkGroupSizeForGrid(grid); + int groups_x = IntegralDivideRoundUp(grid.x, groups_size.x); + int groups_y = IntegralDivideRoundUp(grid.y, groups_size.y); + int groups_z = IntegralDivideRoundUp(grid.z, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return desc; +} + +} // namespace + +std::vector Pooling( + int id, ValueId input_id, const std::vector& output_ids, + const Pooling2DAttributes& params) { + std::vector descriptors; + descriptors.push_back( + PoolingInternal(id, input_id, output_ids[0], params, false)); + if (params.type == PoolingType::MAX && params.output_indices) { + descriptors.push_back( + PoolingInternal(id, input_id, output_ids[1], params, true)); + } + return descriptors; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/pooling.h b/tensorflow/lite/delegates/gpu/metal/kernels/pooling.h new file mode 100644 index 00000000000..c2b3ff7e5c2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/pooling.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
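`PoolingInternal` above chooses between three shaders: plain max pooling, max pooling that additionally records per-lane window indices (stored as `float(index) + 0.1`, presumably so truncation recovers the integer reliably in FP16), and average pooling that accumulates in `float4` and scales by `1 / (window_w * window_h)`. A CPU reference for a single output element of the max-with-index variant, using the same out-of-bounds convention (taps outside the input contribute -10000):

```cpp
#include <cstdio>
#include <vector>

// Reference max pooling for one output position of a single-channel tensor,
// mirroring the shader: out-of-bounds taps contribute -10000 and the flat
// window index of the winner is recorded (as in the "indices" variant).
void MaxPoolAt(const std::vector<float>& src, int src_w, int src_h,
               int out_x, int out_y, int stride, int pad, int window,
               float* max_value, int* max_index) {
  *max_value = -10000.0f;
  *max_index = 0;
  int counter = 0;
  for (int ky = 0; ky < window; ++ky) {
    for (int kx = 0; kx < window; ++kx) {
      const int x = out_x * stride - pad + kx;
      const int y = out_y * stride - pad + ky;
      const bool outside = x < 0 || y < 0 || x >= src_w || y >= src_h;
      const float v = outside ? -10000.0f : src[y * src_w + x];
      if (v > *max_value) {
        *max_value = v;
        *max_index = counter;
      }
      ++counter;
    }
  }
}

int main() {
  const std::vector<float> src = {1, 2, 3,
                                  4, 9, 6,
                                  7, 8, 5};
  float value;
  int index;
  MaxPoolAt(src, 3, 3, /*out_x=*/0, /*out_y=*/0, /*stride=*/2, /*pad=*/0,
            /*window=*/2, &value, &index);
  std::printf("max=%.1f at window index %d\n", value, index);  // max=9.0 at 3
}
```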
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_POOLING_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_POOLING_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector Pooling( + int id, ValueId input_id, const std::vector& output_id, + const Pooling2DAttributes& params); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_POOLING_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/metal/kernels/prelu.cc new file mode 100644 index 00000000000..617ab6ddf06 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/prelu.cc @@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/prelu.h" + +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector PReLU(int id, ValueId input_id, + ValueId output_id, + const PReLUAttributes& attr, + const RuntimeOptions& options) { + auto alpha_buffer = + absl::get_if>(&attr.alpha); + if (!alpha_buffer) { + return {}; + } + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = true; + if (attr.clip != 0) { + desc->shader_source = + R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, + device FLT4* const alphas, float clip) { + return FLT4(clamp(value, FLT4(0.0f), FLT4(clip)) + alphas[gid.z] * min(FLT4(0.0f), value)); + })"; + } else { + desc->shader_source = + R"(FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, + device FLT4* const alphas) { + return FLT4(max(FLT4(0.0f), value) + alphas[gid.z] * min(FLT4(0.0f), value)); + })"; + } + desc->input_buffers = {{input_id}}; + desc->output_buffer = {output_id}; + auto alphas = options.storage_precision == RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(alpha_buffer->data) + : VectorFloatToHalf(alpha_buffer->data); + desc->immutable_buffers = { + {"device FLT4* const", alphas}, + }; + if (attr.clip != 0) { + desc->uniform_buffers = { + {"constant float&", + [attr](const std::map& buffers) { + std::vector attr_clip = + VectorToUint8Vector(std::vector{attr.clip}); + return attr_clip; + }}, + }; + } + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/prelu.h b/tensorflow/lite/delegates/gpu/metal/kernels/prelu.h new file mode 100644 index 00000000000..b29387f0003 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/prelu.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_PRELU_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_PRELU_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +// Parametric Rectified Linear Unit. +std::vector PReLU(int id, ValueId input_id, + ValueId output_id, + const PReLUAttributes& attr, + const RuntimeOptions& options); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_PRELU_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/relu.cc b/tensorflow/lite/delegates/gpu/metal/kernels/relu.cc new file mode 100644 index 00000000000..ccd45a15ea5 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/relu.cc @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
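The `PReLU` task above keeps the per-channel `alpha` tensor in an immutable FLT4 buffer (FP32 or FP16 depending on `options.storage_precision`) and, when `attr.clip` is non-zero, also clamps the positive part. Element-wise it computes `clamp(v, 0, clip) + alpha * min(0, v)`, or `max(0, v) + alpha * min(0, v)` without a clip. A scalar reference of that formula:

```cpp
#include <algorithm>
#include <cstdio>

// Scalar reference of the PReLU shader above. alpha is the per-channel slope;
// clip == 0 means "no upper clamp", matching the attr.clip != 0 check.
float PReLU(float v, float alpha, float clip) {
  const float positive =
      clip != 0.0f ? std::min(std::max(v, 0.0f), clip) : std::max(v, 0.0f);
  return positive + alpha * std::min(v, 0.0f);
}

int main() {
  std::printf("%.2f %.2f %.2f\n",
              PReLU(-2.0f, 0.25f, 0.0f),   // -0.50
              PReLU(3.0f, 0.25f, 0.0f),    //  3.00
              PReLU(9.0f, 0.25f, 6.0f));   //  6.00
}
```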
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/kernels/relu.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/substitute.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+std::vector<ComputeTaskDescriptorPtr> ReLU(int id, ValueId input_id,
+                                           ValueId output_id,
+                                           const ReLUAttributes& attr) {
+  auto desc = std::make_shared<ComputeTaskDescriptor>();
+  desc->id = id;
+  desc->is_linkable = true;
+  const std::string min_func =
+      attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * params.x, 0.0f)";
+  const std::string parameters =
+      "FLT4 linkable$0(FLT4 value, int linear_index, uint3 gid, float2 params) "
+      "{\n";
+  if (attr.clip != 0.0) {
+    desc->shader_source = parameters + " return FLT4(clamp(value, " +
+                          min_func + ", FLT4(params.y)));\n}";
+  } else {
+    desc->shader_source =
+        parameters + " return FLT4(max(value, " + min_func + "));\n}";
+  }
+  desc->input_buffers = {{input_id}};
+  desc->output_buffer = {output_id};
+  desc->uniform_buffers = {
+      {"constant float2&",
+       [attr](const std::map<ValueId, BHWC>& buffers) {
+         return VectorToUint8Vector(std::vector<float>{attr.alpha, attr.clip});
+       }},
+  };
+  return {desc};
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/relu.h b/tensorflow/lite/delegates/gpu/metal/kernels/relu.h
new file mode 100644
index 00000000000..a6b8dfaef69
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/relu.h
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RELU_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RELU_H_
+
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+// Rectified Linear Unit, with optional leaky slope (alpha) and upper clip.
+std::vector<ComputeTaskDescriptorPtr> ReLU(int id, ValueId input_id,
+                                           ValueId output_id,
+                                           const ReLUAttributes& attr);
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RELU_H_
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/reshape.cc b/tensorflow/lite/delegates/gpu/metal/kernels/reshape.cc
new file mode 100644
index 00000000000..16783035b65
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/reshape.cc
@@ -0,0 +1,234 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/reshape.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { +std::string GetReshapeCode() { + std::string code = R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; +}; + +$0 +kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + const int3 igid = int3(gid); + + if (igid.x >= params.dst_size.x || igid.y >= params.dst_size.y || + igid.z * 4 >= params.dst_size.z) return; + + FLT4 value; + + for (int i = 0; i < 4; ++i) { + const int dst_channel = igid.z * 4 + i; + if (dst_channel < params.dst_size.z) { + int p = dst_channel + params.dst_size.z * igid.x + params.dst_size.w * igid.y; + int src_y = p / params.src_size.w; + int t0 = p - src_y * params.src_size.w; // p % params.src_size.w; + int src_x = t0 / params.src_size.z; + int src_z = t0 - src_x * params.src_size.z; // t0 % params.src_size.z; + int src_layer = src_z >> 2; + int src_channel = src_z & 3; + int src_linear_id = (src_layer * params.src_size.y + src_y) * params.src_size.x + src_x; + value[i] = src_buffer[src_linear_id][src_channel]; + } + } + + int linear_index = (igid.z * params.dst_size.y + igid.y) * params.dst_size.x + igid.x; + $2 + dst_buffer[linear_index] = value; +})"; + return code; +} + +std::string GetReshapex4Code() { + std::string code = R"( +#include +using namespace metal; + +struct uniforms { + int4 src_size; + int4 dst_size; + int2 plane_xz; + int2 dummy0; // dummy, for alignment + int4 dummy1; // dummy, for alignment +}; + +$0 +kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + int X = gid.x; + int Y = gid.y; + int Z = gid.z; + + if (X >= params.dst_size.x || Y >= params.dst_size.y || Z >= params.dst_size.z) return; + + int p = Z + params.dst_size.z * X + params.plane_xz.y * Y; + int src_y = p / params.plane_xz.x; + int t0 = p - src_y * params.plane_xz.x; // p % params.plane_xz.x; + int src_x = t0 / params.src_size.z; + int src_z = t0 - src_x * params.src_size.z; // t0 % params.src_size.z; + + int src_index = src_z * params.src_size.w + src_y * params.src_size.x + src_x; + int linear_index = Z * params.dst_size.w + Y * params.dst_size.x + X; + FLT4 value = src_buffer[src_index]; + $2 + dst_buffer[linear_index] = value; +})"; + return code; +} + +} // namespace + +std::vector Reshape(int id, ValueId input_id, + ValueId output_id, + const ReshapeAttributes& attr) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetReshapeCode(); + + desc->input_buffers = { + {input_id, 
"device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + int batch = buffers.find(input_id)->second.b; + return BHWC{batch, attr.new_shape.h, attr.new_shape.w, + attr.new_shape.c}; + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id](const std::map& buffers) { + const auto& src_dim = buffers.find(input_id)->second; + const auto& dst_dim = buffers.find(output_id)->second; + std::vector uniform_params{ + // int4 src_size + src_dim.w, + src_dim.h, + src_dim.c, + src_dim.c * src_dim.w, + // int4 dst_size + dst_dim.w, + dst_dim.h, + dst_dim.c, + dst_dim.c * dst_dim.w, + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [attr](const std::map& buffers) { + const uint3 grid = uint3(attr.new_shape.w, attr.new_shape.h, + IntegralDivideRoundUp(attr.new_shape.c, 4)); + const uint3 groups_size = GetWorkGroupSizeForGrid(grid); + int groups_x = IntegralDivideRoundUp(grid.x, groups_size.x); + int groups_y = IntegralDivideRoundUp(grid.y, groups_size.y); + int groups_z = IntegralDivideRoundUp(grid.z, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +std::vector Reshapex4(int id, ValueId input_id, + ValueId output_id, + const ReshapeAttributes& attr) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetReshapex4Code(); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + int batch = buffers.find(input_id)->second.b; + return BHWC{batch, attr.new_shape.h, attr.new_shape.w, + attr.new_shape.c}; + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id](const std::map& buffers) { + const auto& src_dim = buffers.find(input_id)->second; + const auto& dst_dim = buffers.find(output_id)->second; + std::vector uniform_params{ + // int4 src_size + src_dim.w, src_dim.h, IntegralDivideRoundUp(src_dim.c, 4), + src_dim.w * src_dim.h, + // int4 dst_size + dst_dim.w, dst_dim.h, IntegralDivideRoundUp(dst_dim.c, 4), + dst_dim.w * dst_dim.h, + // int2 plane_xz + src_dim.w * IntegralDivideRoundUp(src_dim.c, 4), + dst_dim.w * IntegralDivideRoundUp(dst_dim.c, 4), + 0, // dummy, for alignment + 0, // dummy, for alignment + 0, // dummy, for alignment + 0, // dummy, for alignment + 0, // dummy, for alignment + 0 // dummy, for alignment + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [attr](const std::map& buffers) { + const uint3 grid = uint3(attr.new_shape.w, attr.new_shape.h, + IntegralDivideRoundUp(attr.new_shape.c, 4)); + const uint3 groups_size = GetWorkGroupSizeForGrid(grid); + int groups_x = IntegralDivideRoundUp(grid.x, groups_size.x); + int groups_y = IntegralDivideRoundUp(grid.y, groups_size.y); + int groups_z = IntegralDivideRoundUp(grid.z, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/reshape.h b/tensorflow/lite/delegates/gpu/metal/kernels/reshape.h new file mode 100644 index 00000000000..650cfc1d2a4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/reshape.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RESHAPE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RESHAPE_H_ + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +// Reshapes a tensor. +// Given tensor, this operation returns a tensor that has the same values +// as tensor with shape dst_shape. +std::vector Reshape(int id, ValueId input_id, + ValueId output_id, + const ReshapeAttributes& attr); + +// This specialization performs faster for the case +// src_channels % 4 == 0 and dst_channels % 4 == 0 +std::vector Reshapex4(int id, ValueId input_id, + ValueId output_id, + const ReshapeAttributes& attr); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_RESHAPE_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/slice.cc b/tensorflow/lite/delegates/gpu/metal/kernels/slice.cc new file mode 100644 index 00000000000..47716b4fa34 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/slice.cc @@ -0,0 +1,187 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
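`GetReshapeCode` above treats both tensors as flat HWC sequences: for each destination element it forms the flat position `p = c + C_dst * x + (C_dst * W_dst) * y`, decomposes the same `p` against the source strides to recover `(src_y, src_x, src_c)`, and then picks the matching lane out of the source FLT4 slice (`Reshapex4` skips the lane shuffle when both channel counts are multiples of 4). A CPU sketch of that flat-index round trip, without the FLT4 packing:

```cpp
#include <cstdio>

// Decomposes a flat HWC index against (width, channels), mirroring the
// arithmetic in GetReshapeCode: p -> (y, x, c) with y outermost.
struct HWCIndex {
  int y, x, c;
};

HWCIndex Decompose(int p, int width, int channels) {
  HWCIndex i;
  i.y = p / (width * channels);
  const int t = p - i.y * (width * channels);  // p % (width * channels)
  i.x = t / channels;
  i.c = t - i.x * channels;                    // t % channels
  return i;
}

int main() {
  // Reshape a 2x3x4 (HxWxC) tensor to 4x3x2: destination element
  // (y=1, x=2, c=1) maps to the same flat position in the source layout.
  const int dst_w = 3, dst_c = 2, src_w = 3, src_c = 4;
  const int dst_y = 1, dst_x = 2, dst_ch = 1;
  const int p = dst_ch + dst_c * dst_x + dst_c * dst_w * dst_y;
  const HWCIndex src = Decompose(p, src_w, src_c);
  std::printf("dst(1,2,1) -> flat %d -> src(y=%d, x=%d, c=%d)\n",
              p, src.y, src.x, src.c);  // flat 11 -> src(0, 2, 3)
}
```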
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/slice.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +std::string GetSliceCode(const SliceAttributes& attr) { + std::stringstream code; + + code << R"( + #include + using namespace metal; + + struct uniforms { + int4 src_size; + int4 dst_size; + }; + + constant int4 width = int4($0, $1, $2, 0); + constant int4 height = int4($3, $4, $5, 0); + constant int4 channels = int4($6, $7, $8, 0); + constant FLT4 null_vec = FLT4(0.0f, 0.0f, 0.0f, 0.0f); + + $$0 + kernel void ComputeFunction( + $$1 + uint3 gid[[thread_position_in_grid]]) { + if (static_cast(gid.x) >= params.dst_size.x || + static_cast(gid.y) >= params.dst_size.y) { + return; + } + + FLT4 value; + short2 offset; + )"; + if (attr.strides.w > 0) { + code << " offset.x = width.x;" << std::endl; + } else { + if (attr.ends.w > 0) { + code << " offset.x = width.z;" << std::endl; + } else { + code << " offset.x = params.src_size.x + width.z;" << std::endl; + } + } + if (attr.strides.h > 0) { + code << " offset.y = height.x;" << std::endl; + } else { + if (attr.ends.h > 0) { + code << " offset.y = height.z;" << std::endl; + } else { + code << " offset.y = params.src_size.y + height.z;" << std::endl; + } + } + code << std::endl; + code << " short2 stride = short2(width.y, height.y);" << std::endl; + + code << " const short2 s_c = offset + short2(gid.xy) * stride;" + << std::endl; + code << " bool outside = false;" << std::endl; + code << " int step = gid.z * 4;" << std::endl; + code << " FLT4 tmp;" << std::endl; + code << " int buffer_index = 0;" << std::endl; + code << " int addr = 0;" << std::endl; + code << std::endl; + for (int i = 0; i < 4; i++) { + code << " addr = step * channels.y;" << std::endl; + if (attr.strides.c > 0) { + code << " addr += channels.x;" << std::endl; + } else { + if (attr.ends.c > 0) { + code << " addr += channels.z;" << std::endl; + } else { + code << " addr += params.src_size.z + channels.z;" << std::endl; + } + } + code << " buffer_index = ((addr / 4) * params.src_size.y + s_c.y) * " + "params.src_size.x + " + "s_c.x;" + << std::endl; + code << " outside = step >= params.dst_size.z;" << std::endl; + code << " tmp = outside ? 
null_vec : src_buffer[buffer_index];" + << std::endl; + code << " value[" << i << "] = tmp[addr % 4];" << std::endl; + if (i != 3) { + code << " step++;" << std::endl; + code << std::endl; + } + } + code << R"( + int linear_index = (gid.z * params.dst_size.y + int(gid.y)) * params.dst_size.x + int(gid.x); + $$2 + dst_buffer[linear_index] = value; + })"; + return absl::Substitute( + code.str(), attr.starts.w, attr.strides.w, attr.ends.w, attr.starts.h, + attr.strides.h, attr.ends.h, attr.starts.c, attr.strides.c, attr.ends.c); +} +} // namespace + +std::vector Slice(int id, ValueId input_id, + ValueId output_id, + const SliceAttributes& attr) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetSliceCode(attr); + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, attr](const std::map& buffers) { + return CalculateOutputShape(buffers.find(input_id)->second, attr); + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + const auto& output_dimension = buffers.find(output_id)->second; + std::vector uniform_params{ + // int4 src_size + dimension.w, + dimension.h, + dimension.c, + IntegralDivideRoundUp(dimension.c, 4), + // int4 dst_size + output_dimension.w, + output_dimension.h, + output_dimension.c, + IntegralDivideRoundUp(output_dimension.c, 4), + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [input_id, + attr](const std::map& buffers) { + const uint3 groups_size{16, 16, 1}; + const auto& src_shape = buffers.find(input_id)->second; + BHWC dst_shape = CalculateOutputShape(src_shape, attr); + int groups_x = IntegralDivideRoundUp(dst_shape.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(dst_shape.h, groups_size.y); + const int dst_layers = IntegralDivideRoundUp(dst_shape.c, 4); + int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/slice.h b/tensorflow/lite/delegates/gpu/metal/kernels/slice.h new file mode 100644 index 00000000000..494170082b1 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/slice.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
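The `Slice` kernel above resolves, per axis, an absolute start offset (branching on the sign of the stride and on whether `ends` is given as an absolute index or relative to the source size) and then samples `offset + gid * stride`; the output extent itself comes from `CalculateOutputShape` in the common code. A 1-D reference of the sampling rule, under the assumption of a positive stride and an already-normalized exclusive `end`:

```cpp
#include <cstdio>
#include <vector>

// 1-D reference of a strided slice with positive stride and a normalized,
// exclusive `end` (multi-axis normalization is handled by the common
// CalculateOutputShape helper, which is not part of this patch).
std::vector<int> Slice1D(const std::vector<int>& src, int start, int end,
                         int stride) {
  std::vector<int> out;
  for (int i = start; i < end; i += stride) out.push_back(src[i]);
  return out;
}

int main() {
  const std::vector<int> src = {0, 1, 2, 3, 4, 5, 6, 7};
  // Output length is ceil((end - start) / stride) = ceil(6 / 2) = 3.
  for (int v : Slice1D(src, /*start=*/1, /*end=*/7, /*stride=*/2)) {
    std::printf("%d ", v);  // 1 3 5
  }
  std::printf("\n");
}
```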
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SLICE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SLICE_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +// Extracts a strided slice of a tensor +std::vector Slice(int id, ValueId input_id, + ValueId output_id, + const SliceAttributes& attr); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SLICE_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc new file mode 100644 index 00000000000..b6e8e2073f7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc @@ -0,0 +1,218 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h" + +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { +std::string GetSoftmax1x1Code() { + std::string code = R"( +#include +using namespace metal; + +struct uniforms { + int4 size; + float4 mask; +}; + +$0 + +kernel void ComputeFunction($1 + uint tid[[thread_index_in_threadgroup]], + uint3 ugid[[thread_position_in_grid]]) +{ + int offset = 0; + float sum = 0.0f; + int s = 0; + do { + if (offset + tid < params.size.x) { + float4 mask_temp = offset + tid == params.size.x - 1 ? 
params.mask : float4(1.0h); + float4 src = float4(src_buffer[offset + tid]); + sum += dot(mask_temp, exp(src)); + offset += 32; + } + s++; + } while (s < params.size.y); + + threadgroup float4 tmp[8]; + threadgroup float* tmpx1 = (threadgroup float*)tmp; + tmpx1[tid] = sum; + BARRIER(mem_flags::mem_threadgroup); + if (tid == 0) { + sum = dot(float4(1.0f), tmp[0]); + sum += dot(float4(1.0f), tmp[1]); + sum += dot(float4(1.0f), tmp[2]); + sum += dot(float4(1.0f), tmp[3]); + sum += dot(float4(1.0f), tmp[4]); + sum += dot(float4(1.0f), tmp[5]); + sum += dot(float4(1.0f), tmp[6]); + sum += dot(float4(1.0f), tmp[7]); + tmpx1[0] = 1.0 / sum; + } + BARRIER(mem_flags::mem_threadgroup); + sum = tmpx1[0]; + + offset = 0; + s = 0; + do { + if (offset + tid < params.size.x) { + int linear_index = offset + tid; + FLT4 value = FLT4(exp(float4(src_buffer[linear_index])) * sum); + uint3 gid = uint3(0, 0, linear_index); + $2 + dst_buffer[linear_index] = value; + offset += 32; + } + s++; + } while (s < params.size.y); +})"; + return code; +} +} // namespace + +std::vector Softmax(int id, ValueId input_id, + ValueId output_id, + int channels_count) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = R"( + #include + using namespace metal; + constant int src_channels = )"; + desc->shader_source += std::to_string(channels_count); + desc->shader_source += R"(; + $0 + kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + if (int(gid.x) >= size.x || int(gid.y) >= size.y) { + return; + } + float shift = 0.0f; + int remaining_channels = src_channels % 4; + + float sum = 0.0f; + for (int d = 0; d < src_channels / 4; ++d) { + int buffer_index = (d * size.y + gid.y) * size.x + gid.x; + sum += dot(float4(1.0f), exp(float4(input_buffer[buffer_index]) - shift)); + } + if (remaining_channels > 0) { + int buffer_index = ((src_channels / 4) * size.y + gid.y) * size.x + gid.x; + float4 last_element = float4(input_buffer[buffer_index]); + sum += exp(last_element.x - shift); + if (remaining_channels > 1) sum += exp(last_element.y - shift); + if (remaining_channels == 3) sum += exp(last_element.z - shift); + } + + for (int d = 0; d < (src_channels + 3) / 4; ++d) { + const int linear_index = (d * size.y + gid.y) * size.x + gid.x; + FLT4 value = FLT4(exp(float4(input_buffer[linear_index]) - shift) / sum); + $2 + output_buffer[linear_index] = value; + } + } + )"; + + desc->input_buffers = { + {input_id, "device FLT4* const input_buffer"}, + }; + + desc->output_buffer = {output_id, "device FLT4* output_buffer", + [input_id](const std::map& buffers) { + return buffers.find(input_id)->second; + }}; + + desc->uniform_buffers = { + {"constant int2& size", + [output_id](const std::map& buffers) { + const auto& dimension = buffers.find(output_id)->second; + std::vector sizes{dimension.w, dimension.h}; + return VectorToUint8Vector(sizes); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + uint3 groups_size{8, 4, 1}; + const auto& dimension = buffers.find(output_id)->second; + uint3 groups_count{IntegralDivideRoundUp(dimension.w, groups_size.x), + IntegralDivideRoundUp(dimension.h, groups_size.y), 1}; + return std::make_pair(groups_size, groups_count); + }; + + return {desc}; +} + +std::vector Softmax1x1(int id, ValueId input_id, + ValueId output_id, + int channels_count) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + desc->shader_source = GetSoftmax1x1Code(); + + desc->input_buffers = { + 
{input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = {output_id, "device FLT4* dst_buffer", + [input_id](const std::map& buffers) { + return buffers.find(input_id)->second; + }}; + + desc->uniform_buffers = { + {"constant uniforms& params", + [channels_count](const std::map& buffers) { + const int src_depth = IntegralDivideRoundUp(channels_count, 4); + struct uniforms { + int4 size; + float4 mask; + }; + uniforms params; + params.size = {src_depth, IntegralDivideRoundUp(src_depth, 32), 1, 1}; + params.mask = {0.0f, 0.0f, 0.0f, 0.0f}; + const int reminder = channels_count % 4 == 0 ? 4 : channels_count % 4; + for (int i = 0; i < reminder; ++i) { + params.mask[i] = 1.0f; + } + const uint8_t* ptr = reinterpret_cast(¶ms); + return std::vector(ptr, ptr + sizeof(uniforms)); + }}, + }; + + desc->resize_function = [](const std::map& buffers) { + return std::make_pair(uint3{32u, 1u, 1u}, uint3{1u, 1u, 1u}); + }; + + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h new file mode 100644 index 00000000000..24fa38e8f57 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SOFTMAX_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SOFTMAX_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector Softmax(int id, ValueId input_id, + ValueId output_id, + int channels_count); + +// Softmax for case when width = height = 1 and AXIS = CHANNELS +// We have this case in MobilenetV1/V2. +std::vector Softmax1x1(int id, ValueId input_id, + ValueId output_id, + int channels_count); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_SOFTMAX_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc new file mode 100644 index 00000000000..43d2b8fd1c7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.cc @@ -0,0 +1,1191 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/environment.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +const int kThreadGroupWidth = 16; +const int kThreadGroupHeight = 4; + +std::string GetDeconvolution(const ConvolutionTransposedAttributes& attr) { + std::string constant_args = R"( + constant short2 padding = {$0, $1}; + constant short2 stride = {$2, $3}; + constant short2 kernel_size = {$4, $5}; + constant short2 inner_size = {$6, $7}; + constant short2 kernel_offset = {$8, $9}; + )"; + std::string shader_source = R"( + #include + using namespace metal; + + struct FilterStripe { + FLT4 vals[$0]; + }; + + constant int src_depth = $1; + constant int dst_depth = $2; + constant int dst_channels = $3; + constant int dst_channels_aligned = $4; + + $5 + + struct uniforms { + int2 src_size; + int2 dst_size; + }; + + $$0 + kernel void ComputeFunction( + $$1 + uint2 ugid[[thread_position_in_grid]]) { + if (static_cast(ugid.x) >= params.dst_size.x || + static_cast(ugid.y) >= params.dst_size.y) { + return; + } + + float out[$4]; + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + short2 offset = (short2(ugid) + padding - kernel_offset); + offset.x = offset.x % stride.x; + offset.y = offset.y % stride.y; + offset += stride; + offset.x = offset.x % stride.x; + offset.y = offset.y % stride.y; + short2 f_offset; + f_offset.x = offset.x == 0 ? 0 : (stride.x - offset.x); + f_offset.y = offset.y == 0 ? 
0 : (stride.y - offset.y); + for (int ky = 0; ky < inner_size.y; ++ky) { + for (int kx = 0; kx < inner_size.x; ++kx) { + short2 index = short2(kx, ky) * stride + f_offset; + bool inside_kernel = index.x < kernel_size.x && index.y < kernel_size.y; + const short2 src_coord = (short2(ugid) + index + padding - kernel_offset) / stride; + index = kernel_size - short2(1, 1) - index; + bool outside = src_coord.x < 0 || src_coord.y < 0 || + src_coord.x >= params.src_size.x || src_coord.y >= params.src_size.y; + const int kernel_index = index.y * kernel_size.x + index.x; + bool belong = inside_kernel && !outside; + if (belong) { + for (int l = 0; l < src_depth; ++l) { + const int src_index = (l * params.src_size.y + src_coord.y) + * params.src_size.x + src_coord.x; + FLT4 srcColor = src_buffer[src_index]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor, filters[kernel_index].vals[l * dst_channels_aligned + k]); + } + } + } + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(ugid.y)) + * params.dst_size.x + int(ugid.x); + uint3 gid = uint3(ugid.x, ugid.y, uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + } + )"; + const int kernel_x = attr.weights.shape.w; + const int kernel_y = attr.weights.shape.h; + const int inner_size_x = (kernel_x - 1) / attr.stride.w + 1; + const int inner_size_y = (kernel_y - 1) / attr.stride.h + 1; + std::string constant_args_inplaced = absl::Substitute( + constant_args, attr.padding.prepended.w, attr.padding.prepended.h, + attr.stride.w, attr.stride.h, kernel_x, kernel_y, inner_size_x, + inner_size_y, kernel_x - 1, kernel_y - 1); + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); + const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4); + return absl::Substitute(shader_source, src_depth * dst_channels_aligned, + src_depth, dst_depth, attr.weights.shape.o, + dst_channels_aligned, constant_args_inplaced); +} + +std::string GetDeconvolutionShared(const ConvolutionTransposedAttributes& attr, + int workgroup_x, int workgroup_y) { + std::string constant_args = R"( + constant short2 padding = {$0, $1}; + constant short2 stride = {$2, $3}; + constant short2 kernel_size = {$4, $5}; + constant short2 inner_size = {$6, $7}; + constant short2 kernel_offset = {$8, $9}; + )"; + std::string shader_source = R"( + #include + using namespace metal; + + struct FilterStripe { + FLT4 vals[$0]; + }; + + constant int src_depth = $1; + constant int dst_depth = $2; + constant int dst_channels = $3; + constant int dst_channels_aligned = $4; + + $5 + + constant short2 src_local_size = {$6, $7}; + + struct uniforms { + int2 src_size; + int2 dst_size; + }; + + $$0 + kernel void ComputeFunction( + $$1 + uint2 tid[[thread_position_in_threadgroup]], + uint2 ugid[[thread_position_in_grid]]) { + float out[$4]; + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + short2 offset = (short2(ugid) + padding - kernel_offset); + offset.x = offset.x % stride.x; + offset.y = offset.y % stride.y; + offset += stride; + offset.x = offset.x % stride.x; + offset.y = offset.y % stride.y; + short2 f_offset; + f_offset.x = offset.x == 0 ? 0 : stride.x - offset.x; + f_offset.y = offset.y == 0 ? 
0 : stride.y - offset.y; + + short2 first_gid = short2((ugid.x / $8) * $8, (ugid.y / $9) * $9); + + short2 shared_offset = (first_gid + padding - kernel_offset); + shared_offset.x = shared_offset.x % stride.x; + shared_offset.y = shared_offset.y % stride.y; + shared_offset += stride; + shared_offset.x = shared_offset.x % stride.x; + shared_offset.y = shared_offset.y % stride.y; + short2 shared_f_offset; + shared_f_offset.x = shared_offset.x == 0 ? 0 : (stride.x - shared_offset.x); + shared_f_offset.y = shared_offset.y == 0 ? 0 : (stride.y - shared_offset.y); + + short2 first_index = short2(0, 0) * stride + shared_f_offset; + const short2 first_src_coord = (first_gid + first_index + padding - kernel_offset) / stride; + threadgroup FLT4 src_shared[$6][$7][$1]; + if (static_cast(tid.x) < src_local_size.x && + static_cast(tid.y) < src_local_size.y) { + for (int z = 0; z < src_depth; ++z) { + const short2 src_coord = first_src_coord + short2(tid); + bool outside = src_coord.x < 0 || src_coord.y < 0 || + src_coord.x >= params.src_size.x || src_coord.y >= params.src_size.y; + const int src_index = (z * params.src_size.y + src_coord.y) + * params.src_size.x + src_coord.x; + FLT4 src = !outside ? src_buffer[src_index] : FLT4(0.0f); + src_shared[tid.x][tid.y][z] = src; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (static_cast(ugid.x) >= params.dst_size.x || + static_cast(ugid.y) >= params.dst_size.y) { + return; + } + + for (int ky = 0; ky < inner_size.y; ++ky) { + for (int kx = 0; kx < inner_size.x; ++kx) { + short2 index = short2(kx, ky) * stride + f_offset; + bool inside_kernel = index.x < kernel_size.x && index.y < kernel_size.y; + const short2 src_coord = (short2(ugid) + index + padding - kernel_offset) / stride; + index = kernel_size - short2(1, 1) - index; + bool outside = src_coord.x < 0 || src_coord.y < 0 || + src_coord.x >= params.src_size.x || src_coord.y >= params.src_size.y; + const int kernel_index = index.y * kernel_size.x + index.x; + bool belong = inside_kernel && !outside; + if (belong) { + for (int k = 0; k < dst_channels; ++k) { + for (int l = 0; l < src_depth; ++l) { + short2 src_index = src_coord - first_src_coord; + out[k] += dot(src_shared[src_index.x][src_index.y][l], + filters[kernel_index].vals[l * dst_channels_aligned + k]); + } + } + } + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(ugid.y)) + * params.dst_size.x + int(ugid.x); + uint3 gid = uint3(ugid.x, ugid.y, uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + } + )"; + const int kernel_x = attr.weights.shape.w; + const int kernel_y = attr.weights.shape.h; + const int inner_size_x = (kernel_x - 1) / attr.stride.w + 1; + const int inner_size_y = (kernel_y - 1) / attr.stride.h + 1; + std::string constant_args_inplaced = absl::Substitute( + constant_args, attr.padding.prepended.w, attr.padding.prepended.h, + attr.stride.w, attr.stride.h, kernel_x, kernel_y, inner_size_x, + inner_size_y, kernel_x - 1, kernel_y - 1); + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); + const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4); + const int src_local_size_x = (workgroup_x + kernel_x) / attr.stride.w; + const int src_local_size_y = (workgroup_y + kernel_y) / attr.stride.h; + return absl::Substitute( + shader_source, src_depth * 
dst_channels_aligned, src_depth, dst_depth, + attr.weights.shape.o, dst_channels_aligned, constant_args_inplaced, + src_local_size_x, src_local_size_y, workgroup_x, workgroup_y); +} + +struct GridParams { + uint rect_offsets[4]; + uint widths[4]; + short2 origins[4]; + uint elements_count; +}; + +struct Params3x3 { + short2 inner_size; + short2 src_offset; + short2 dst_offset; +}; + +void Init3x3(const ConvolutionTransposedAttributes& attr, const int2& src_size, + const int2& dst_size, GridParams* grid_params, + Params3x3* params3x3) { + short2 src_size_scaled; + src_size_scaled.x = (src_size.x - 1) * 2; + src_size_scaled.y = (src_size.y - 1) * 2; + short2 top_left_src, bottom_right_src; + top_left_src.x = 1 - attr.padding.prepended.w; + top_left_src.y = 1 - attr.padding.prepended.h; + bottom_right_src.x = top_left_src.x + src_size_scaled.x; + bottom_right_src.y = top_left_src.y + src_size_scaled.y; + short2 top_left_inner, bottom_right_inner; + if (top_left_src.x >= 0) { + top_left_inner.x = top_left_src.x; + } else { + top_left_inner.x = std::abs(top_left_src.x % 2); + } + if (top_left_src.y >= 0) { + top_left_inner.y = top_left_src.y; + } else { + top_left_inner.y = std::abs(top_left_src.y % 2); + } + + if (bottom_right_src.x <= dst_size.x) { + bottom_right_inner.x = bottom_right_src.x; + } else { + bottom_right_inner.x = dst_size.x; + } + if (top_left_src.x % 2 == 0) { + bottom_right_inner.x -= bottom_right_inner.x % 2; + } else { + if (bottom_right_inner.x % 2 == 0) { + bottom_right_inner.x -= 1; + } + } + bottom_right_inner.x -= 1; + + if (bottom_right_src.y <= dst_size.y) { + bottom_right_inner.y = bottom_right_src.y; + } else { + bottom_right_inner.y = dst_size.y; + } + if (top_left_src.y % 2 == 0) { + bottom_right_inner.y -= bottom_right_inner.y % 2; + } else { + if (bottom_right_inner.y % 2 == 0) { + bottom_right_inner.y -= 1; + } + } + bottom_right_inner.y -= 1; + + params3x3->dst_offset = top_left_inner; + params3x3->src_offset.x = (top_left_inner.x - top_left_src.x) / 2; + params3x3->src_offset.y = (top_left_inner.y - top_left_src.y) / 2; + params3x3->inner_size.x = + std::max(0, bottom_right_inner.x - top_left_inner.x + 1) / 2; + params3x3->inner_size.y = + std::max(0, bottom_right_inner.y - top_left_inner.y + 1) / 2; + + short2 top_rect, bottom_rect, left_rect, right_rect; + + top_rect.x = dst_size.x; + top_rect.y = top_left_inner.y; + + bottom_rect.x = dst_size.x; + bottom_rect.y = dst_size.y - bottom_right_inner.y - 1; + + left_rect.x = top_left_inner.x; + left_rect.y = dst_size.y - top_rect.y - bottom_rect.y; + + right_rect.x = dst_size.x - bottom_right_inner.x - 1; + right_rect.y = left_rect.y; + + grid_params->widths[0] = top_rect.x; + grid_params->widths[1] = left_rect.x; + grid_params->widths[2] = right_rect.x; + grid_params->widths[3] = bottom_rect.x; + + grid_params->rect_offsets[0] = 0; + grid_params->rect_offsets[1] = + grid_params->rect_offsets[0] + top_rect.x * top_rect.y; + grid_params->rect_offsets[2] = + grid_params->rect_offsets[1] + left_rect.x * left_rect.y; + grid_params->rect_offsets[3] = + grid_params->rect_offsets[2] + right_rect.x * right_rect.y; + grid_params->elements_count = + grid_params->rect_offsets[3] + bottom_rect.x * bottom_rect.y; + + grid_params->origins[0] = short2(0, 0); + grid_params->origins[1] = short2(int16_t(0), int16_t(top_rect.y)); + grid_params->origins[2] = + short2(int16_t(dst_size.x - right_rect.x), int16_t(top_rect.y)); + grid_params->origins[3] = short2(0, dst_size.y - bottom_rect.y); +} + +std::string GetDeconvolutionBorder( 
+ const ConvolutionTransposedAttributes& attr) { + std::string constant_args = R"( + constant short2 padding = {$0, $1}; + constant short2 stride = {$2, $3}; + constant short2 kernel_size = {$4, $5}; + constant short2 inner_size = {$6, $7}; + constant short2 kernel_offset = {$8, $9}; + )"; + std::string shader_source = R"( + #include + using namespace metal; + + struct FilterStripe { + FLT4 vals[$0]; + }; + + constant int src_depth = $1; + constant int dst_depth = $2; + constant int dst_channels = $3; + constant int dst_channels_aligned = $4; + + $5 + + struct uniforms { + int2 src_size; + int2 dst_size; + uint rect_offsets[4]; + uint widths[4]; + short2 origins[4]; + uint elements_count; + }; + + short2 GetGridIdByLinearId(uint linear_id, constant uniforms& params); + + short2 GetGridIdByLinearId(uint linear_id, constant uniforms& params) { + int index = 0; + index = linear_id >= params.rect_offsets[0] ? 0 : index; + index = linear_id >= params.rect_offsets[1] ? 1 : index; + index = linear_id >= params.rect_offsets[2] ? 2 : index; + index = linear_id >= params.rect_offsets[3] ? 3 : index; + + const uint rect_index = linear_id - params.rect_offsets[index]; + + const uint rect_width = params.widths[index]; + const short2 offset = short2(rect_index % rect_width, rect_index / rect_width); + return params.origins[index] + offset; + } + + $$0 + kernel void ComputeFunction( + $$1 + uint linear_id[[thread_position_in_grid]]) { + if (linear_id >= params.elements_count) { + return; + } + short2 gid_sh = GetGridIdByLinearId(linear_id, params); + + float out[$4]; + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + short2 offset = gid_sh + padding - kernel_offset; + offset.x = offset.x % stride.x; + offset.y = offset.y % stride.y; + offset += stride; + offset.x = offset.x % stride.x; + offset.y = offset.y % stride.y; + short2 f_offset; + f_offset.x = offset.x == 0 ? 0 : stride.x - offset.x; + f_offset.y = offset.y == 0 ? 
0 : stride.y - offset.y; + for (int ky = 0; ky < inner_size.y; ++ky) { + for (int kx = 0; kx < inner_size.x; ++kx) { + short2 index = short2(kx, ky) * stride + f_offset; + bool inside_kernel = index.x < kernel_size.x && index.y < kernel_size.y; + const short2 src_coord = (gid_sh + index + padding - kernel_offset) / stride; + index = kernel_size - short2(1, 1) - index; + bool outside = src_coord.x < 0 || src_coord.y < 0 || + src_coord.x >= params.src_size.x || src_coord.y >= params.src_size.y; + const int kernel_index = index.y * kernel_size.x + index.x; + bool belong = inside_kernel && !outside; + if (belong) { + for (int l = 0; l < src_depth; ++l) { + const int src_index = (l * params.src_size.y + src_coord.y) * + params.src_size.x + src_coord.x; + FLT4 srcColor = src_buffer[src_index]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor, filters[kernel_index].vals[l * dst_channels_aligned + k]); + } + } + } + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(gid_sh.y)) * + params.dst_size.x + int(gid_sh.x); + uint3 gid = uint3(uint(gid_sh.x), uint(gid_sh.y), uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + } + )"; + const int kernel_x = attr.weights.shape.w; + const int kernel_y = attr.weights.shape.h; + const int inner_size_x = (kernel_x - 1) / attr.stride.w + 1; + const int inner_size_y = (kernel_y - 1) / attr.stride.h + 1; + std::string constant_args_inplaced = absl::Substitute( + constant_args, attr.padding.prepended.w, attr.padding.prepended.h, + attr.stride.w, attr.stride.h, kernel_x, kernel_y, inner_size_x, + inner_size_y, kernel_x - 1, kernel_y - 1); + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); + const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4); + return absl::Substitute(shader_source, src_depth * dst_channels_aligned, + src_depth, dst_depth, attr.weights.shape.o, + dst_channels_aligned, constant_args_inplaced); +} + +std::string GetDeconvolution3x3(const ConvolutionTransposedAttributes& attr) { + std::string shader_source = R"( + #include + using namespace metal; + + struct FilterStripe { + FLT4 vals[$0]; + }; + + constant int src_depth = $1; + constant int dst_depth = $2; + constant int dst_channels = $3; + constant int dst_channels_aligned = $4; + + struct uniforms { + int2 src_size; + int2 dst_size; + short2 inner_size; + short2 src_offset; + short2 dst_offset; + }; + + $$0 + kernel void ComputeFunction( + $$1 + uint tid[[thread_index_in_threadgroup]], + uint2 ugid[[thread_position_in_grid]]) { + if (static_cast(ugid.x) >= params.inner_size.x || + static_cast(ugid.y) >= params.inner_size.y) { + return; + } + + float out[$4]; + short2 src_coord_0 = short2(ugid) + params.src_offset; + short2 dst_coord = short2(ugid) * 2 + params.dst_offset; + + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + for (int l = 0; l < src_depth; ++l) { + const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * + params.src_size.x + src_coord_0.x; + FLT4 srcColor_0 = src_buffer[src_index_0]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor_0, filters[4].vals[l * dst_channels_aligned + k]); + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l 
* params.dst_size.y + int(dst_coord.y)) * + params.dst_size.x + int(dst_coord.x); + uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + + short2 src_coord_1 = src_coord_0 + short2(1, 0); + dst_coord += short2(1, 0); + + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + for (int l = 0; l < src_depth; ++l) { + const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * + params.src_size.x + src_coord_0.x; + const int src_index_1 = (l * params.src_size.y + src_coord_1.y) * + params.src_size.x + src_coord_1.x; + FLT4 srcColor_0 = src_buffer[src_index_0]; + FLT4 srcColor_1 = src_buffer[src_index_1]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor_0, filters[5].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_1, filters[3].vals[l * dst_channels_aligned + k]); + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * + params.dst_size.x + int(dst_coord.x); + uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + + short2 src_coord_2 = src_coord_0 + short2(0, 1); + dst_coord += short2(-1, 1); + + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + for (int l = 0; l < src_depth; ++l) { + const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * + params.src_size.x + src_coord_0.x; + const int src_index_2 = (l * params.src_size.y + src_coord_2.y) * + params.src_size.x + src_coord_2.x; + FLT4 srcColor_0 = src_buffer[src_index_0]; + FLT4 srcColor_2 = src_buffer[src_index_2]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor_0, filters[7].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_2, filters[1].vals[l * dst_channels_aligned + k]); + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * + params.dst_size.x + int(dst_coord.x); + uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + + short2 src_coord_3 = src_coord_0 + short2(1, 1); + dst_coord += short2(1, 0); + + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + for (int l = 0; l < src_depth; ++l) { + const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * + params.src_size.x + src_coord_0.x; + const int src_index_1 = (l * params.src_size.y + src_coord_1.y) * + params.src_size.x + src_coord_1.x; + const int src_index_2 = (l * params.src_size.y + src_coord_2.y) * + params.src_size.x + src_coord_2.x; + const int src_index_3 = (l * params.src_size.y + src_coord_3.y) * + params.src_size.x + src_coord_3.x; + FLT4 srcColor_0 = src_buffer[src_index_0]; + FLT4 srcColor_1 = src_buffer[src_index_1]; + FLT4 srcColor_2 = src_buffer[src_index_2]; + FLT4 srcColor_3 = src_buffer[src_index_3]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor_0, filters[8].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_1, filters[6].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_2, filters[2].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_3, filters[0].vals[l * dst_channels_aligned + k]); + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l 
* 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * + params.dst_size.x + int(dst_coord.x); + uint3 gid = uint3(uint(dst_coord.x), uint(dst_coord.y), uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + } + )"; + + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); + const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4); + return absl::Substitute(shader_source, src_depth * dst_channels_aligned, + src_depth, dst_depth, attr.weights.shape.o, + dst_channels_aligned); +} + +std::string GetDeconvolutionShared3x3( + const ConvolutionTransposedAttributes& attr) { + std::string shader_source = R"( + #include + using namespace metal; + + struct FilterStripe { + FLT4 vals[$0]; + }; + + constant int src_depth = $1; + constant int dst_depth = $2; + constant int dst_channels = $3; + constant int dst_channels_aligned = $4; + + struct uniforms { + int2 src_size; + int2 dst_size; + short2 inner_size; + short2 src_offset; + short2 dst_offset; + }; + + $$0 + kernel void ComputeFunction( + $$1 + uint tid[[thread_index_in_threadgroup]], + uint2 ugid[[thread_position_in_grid]]) { + + float out[$4]; + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + threadgroup FilterStripe stripes[4]; + threadgroup_barrier(mem_flags::mem_none); + if (tid < dst_channels) { + for (int l = 0; l < src_depth; ++l) { + stripes[0].vals[l * dst_channels_aligned + tid] + = filters[4].vals[l * dst_channels_aligned + tid]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + bool inside_grid = (static_cast(ugid.x) < params.inner_size.x) + && (static_cast(ugid.y) < params.inner_size.y); + + short2 src_coord_0 = short2(ugid) + params.src_offset; + short2 dst_coord = short2(ugid) * 2 + params.dst_offset; + + if (inside_grid) { + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + for (int l = 0; l < src_depth; ++l) { + const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * + params.src_size.x + src_coord_0.x; + FLT4 srcColor_0 = src_buffer[src_index_0]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]); + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * + params.dst_size.x + int(dst_coord.x); + uint3 gid = uint3(ugid.x, ugid.y, uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + } + + short2 src_coord_1 = src_coord_0 + short2(1, 0); + dst_coord += short2(1, 0); + + threadgroup_barrier(mem_flags::mem_none); + if (tid < dst_channels) { + for (int l = 0; l < src_depth; ++l) { + stripes[0].vals[l * dst_channels_aligned + tid] + = filters[5].vals[l * dst_channels_aligned + tid]; + stripes[1].vals[l * dst_channels_aligned + tid] + = filters[3].vals[l * dst_channels_aligned + tid]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (inside_grid) { + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + for (int l = 0; l < src_depth; ++l) { + const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * + params.src_size.x + src_coord_0.x; + const int src_index_1 = (l * params.src_size.y + src_coord_1.y) * + params.src_size.x + src_coord_1.x; + FLT4 srcColor_0 = src_buffer[src_index_0]; + FLT4 srcColor_1 = 
src_buffer[src_index_1]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_1, stripes[1].vals[l * dst_channels_aligned + k]); + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * + params.dst_size.x + int(dst_coord.x); + uint3 gid = uint3(ugid.x, ugid.y, uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + } + + short2 src_coord_2 = src_coord_0 + short2(0, 1); + dst_coord += short2(-1, 1); + + threadgroup_barrier(mem_flags::mem_none); + if (tid < dst_channels) { + for (int l = 0; l < src_depth; ++l) { + stripes[0].vals[l * dst_channels_aligned + tid] + = filters[7].vals[l * dst_channels_aligned + tid]; + stripes[1].vals[l * dst_channels_aligned + tid] + = filters[1].vals[l * dst_channels_aligned + tid]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (inside_grid) { + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + for (int l = 0; l < src_depth; ++l) { + const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * + params.src_size.x + src_coord_0.x; + const int src_index_2 = (l * params.src_size.y + src_coord_2.y) * + params.src_size.x + src_coord_2.x; + FLT4 srcColor_0 = src_buffer[src_index_0]; + FLT4 srcColor_2 = src_buffer[src_index_2]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_2, stripes[1].vals[l * dst_channels_aligned + k]); + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * + params.dst_size.x + int(dst_coord.x); + uint3 gid = uint3(ugid.x, ugid.y, uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + } + + short2 src_coord_3 = src_coord_0 + short2(1, 1); + dst_coord += short2(1, 0); + + threadgroup_barrier(mem_flags::mem_none); + if (tid < dst_channels) { + for (int l = 0; l < src_depth; ++l) { + stripes[0].vals[l * dst_channels_aligned + tid] + = filters[8].vals[l * dst_channels_aligned + tid]; + stripes[1].vals[l * dst_channels_aligned + tid] + = filters[6].vals[l * dst_channels_aligned + tid]; + stripes[2].vals[l * dst_channels_aligned + tid] + = filters[2].vals[l * dst_channels_aligned + tid]; + stripes[3].vals[l * dst_channels_aligned + tid] + = filters[0].vals[l * dst_channels_aligned + tid]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (inside_grid) { + for (short l = 0; l < dst_depth * 4; ++l) { + out[l] = float(0.0f); + } + + for (int l = 0; l < src_depth; ++l) { + const int src_index_0 = (l * params.src_size.y + src_coord_0.y) * + params.src_size.x + src_coord_0.x; + const int src_index_1 = (l * params.src_size.y + src_coord_1.y) * + params.src_size.x + src_coord_1.x; + const int src_index_2 = (l * params.src_size.y + src_coord_2.y) * + params.src_size.x + src_coord_2.x; + const int src_index_3 = (l * params.src_size.y + src_coord_3.y) * + params.src_size.x + src_coord_3.x; + FLT4 srcColor_0 = src_buffer[src_index_0]; + FLT4 srcColor_1 = src_buffer[src_index_1]; + FLT4 srcColor_2 = src_buffer[src_index_2]; + FLT4 srcColor_3 = src_buffer[src_index_3]; + for (int k = 0; k < dst_channels; ++k) { + out[k] += dot(srcColor_0, stripes[0].vals[l * dst_channels_aligned + k]); + 
out[k] += dot(srcColor_1, stripes[1].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_2, stripes[2].vals[l * dst_channels_aligned + k]); + out[k] += dot(srcColor_3, stripes[3].vals[l * dst_channels_aligned + k]); + } + } + + for (short l = 0; l < dst_depth; ++l) { + FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + biases[l]; + const int linear_index = (l * params.dst_size.y + int(dst_coord.y)) * + params.dst_size.x + int(dst_coord.x); + uint3 gid = uint3(ugid.x, ugid.y, uint(l)); + $$2 + dst_buffer[linear_index] = value; + } + } + } + )"; + const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); + const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); + const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 4); + return absl::Substitute(shader_source, src_depth * dst_channels_aligned, + src_depth, dst_depth, attr.weights.shape.o, + dst_channels_aligned); +} + +} // namespace + +std::vector ConvolutionTransposed( + int id, ValueId input_id, ValueId output_id, + const ConvolutionTransposedAttributes& params, + const RuntimeOptions& options) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + + const int src_local_size_x = + (kThreadGroupWidth + params.weights.shape.w) / params.stride.w; + const int src_local_size_y = + (kThreadGroupHeight + params.weights.shape.h) / params.stride.h; + const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); + const int shared_size = + sizeof(float) * 4 * src_depth * src_local_size_x * src_local_size_y; + int gpu_type = GetAppleSocVersion(); + if (shared_size < 1000 * 16 && (gpu_type == 7 || gpu_type == 8)) { + desc->shader_source = + GetDeconvolutionShared(params, kThreadGroupWidth, kThreadGroupHeight); + } else { + desc->shader_source = GetDeconvolution(params); + } + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, params](const std::map& buffers) { + return CalculateOutputShape(buffers.find(input_id)->second, params); + }}; + + const int src_ch_aligned = AlignByN(params.weights.shape.i, 4); + const int dst_ch_aligned = AlignByN(params.weights.shape.o, 4); + const int kernel_x = params.weights.shape.w; + const int kernel_y = params.weights.shape.h; + const int filters_aligned_size = + src_ch_aligned * dst_ch_aligned * kernel_x * kernel_y; + std::vector filters_reordered(filters_aligned_size); + + int counter = 0; + for (int y = 0; y < kernel_y; ++y) { + for (int x = 0; x < kernel_x; ++x) { + for (int ch = 0; ch < src_depth; ++ch) { + for (int f = 0; f < dst_ch_aligned; ++f) { + for (int i = 0; i < 4; ++i) { + if (ch * 4 + i >= params.weights.shape.i || + f >= params.weights.shape.o) { + filters_reordered[counter++] = 0.0f; + } else { + const int f_index = + params.weights.shape.LinearIndex({f, y, x, ch * 4 + i}); + filters_reordered[counter++] = params.weights.data[f_index]; + } + } + } + } + } + } + + auto filters = options.storage_precision == RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(filters_reordered) + : VectorFloatToHalf(filters_reordered); + auto biases = options.storage_precision == RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(params.bias.data) + : VectorFloatToHalf(params.bias.data); + desc->immutable_buffers = { + {"device FilterStripe* const filters", filters}, + {"constant FLT4* const biases", biases}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + const auto& output_dimension = buffers.find(output_id)->second; + std::vector uniform_params{ + dimension.w, + dimension.h, + output_dimension.w, + output_dimension.h, + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + desc->resize_function = [input_id, + params](const std::map& buffers) { + const uint3 groups_size{kThreadGroupWidth, kThreadGroupHeight, 1}; + BHWC dst_shape = + CalculateOutputShape(buffers.find(input_id)->second, params); + int groups_x = IntegralDivideRoundUp(dst_shape.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(dst_shape.h, groups_size.y); + int groups_z = 1; + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {desc}; +} + +std::vector ConvolutionTransposed3x3( + int id, ValueId input_id, ValueId output_id, + const ConvolutionTransposedAttributes& params, + const RuntimeOptions& options) { + const int kThreadGroupWidth = 16; + const int kThreadGroupHeight = 4; + + auto border_desc = std::make_shared(); + border_desc->id = id; + border_desc->is_linkable = false; + + border_desc->shader_source = GetDeconvolutionBorder(params); + + border_desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + border_desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, params](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + BHWC dst_shape = CalculateOutputShape(src_shape, params); + return BHWC{src_shape.b, dst_shape.h, dst_shape.w, dst_shape.c}; + }}; + + const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4); + const int src_ch_aligned = AlignByN(params.weights.shape.i, 4); + const int dst_ch_aligned = AlignByN(params.weights.shape.o, 4); + const int kernel_x = params.weights.shape.w; + const int kernel_y = params.weights.shape.h; + const int filters_aligned_size = + src_ch_aligned * dst_ch_aligned * kernel_x * kernel_y; + std::vector filters_reordered(filters_aligned_size); + + int counter = 0; + for (int y = 0; y < kernel_y; ++y) { + for (int x = 0; x < kernel_x; ++x) { + for (int ch = 0; ch < src_depth; ++ch) { + for (int f = 0; f < dst_ch_aligned; ++f) { + for (int i = 0; i < 4; ++i) { + if (ch * 4 + i >= params.weights.shape.i || + f >= params.weights.shape.o) { + filters_reordered[counter++] = 0.0f; + } else { + const int f_index = + params.weights.shape.LinearIndex({f, y, x, ch * 4 + i}); + filters_reordered[counter++] = params.weights.data[f_index]; + } + } + } + } + } + } + + auto filters = options.storage_precision == RuntimeOptions::Precision::FP32 + ? VectorToUint8Vector(filters_reordered) + : VectorFloatToHalf(filters_reordered); + auto biases = options.storage_precision == RuntimeOptions::Precision::FP32 + ? 
VectorToUint8Vector(params.bias.data) + : VectorFloatToHalf(params.bias.data); + border_desc->immutable_buffers = { + {"device FilterStripe* const filters", filters}, + {"constant FLT4* const biases", biases}, + }; + + border_desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& src_dim = buffers.find(input_id)->second; + const auto& dst_dim = buffers.find(output_id)->second; + GridParams grid_params; + Params3x3 params3x3; + Init3x3(params, int2(src_dim.w, src_dim.h), int2(dst_dim.w, dst_dim.h), + &grid_params, ¶ms3x3); + int* ptr = reinterpret_cast(&grid_params); + std::vector uniform_params{ + src_dim.w, + src_dim.h, + dst_dim.w, + dst_dim.h, + /*uint GridParams.rect_offsets[4]*/ + ptr[0], + ptr[1], + ptr[2], + ptr[3], + /*uint GridParams.widths[4]*/ + ptr[4], + ptr[5], + ptr[6], + ptr[7], + /*short2 GridParams.origins[4]*/ + ptr[8], + ptr[9], + ptr[10], + ptr[11], + /*uint GridParams.elements_count*/ + ptr[12], + }; + return VectorToUint8Vector(uniform_params); + }}, + }; + + border_desc->resize_function = + [input_id, params](const std::map& buffers) { + const uint3 groups_size{kThreadGroupWidth * kThreadGroupHeight, 1, 1}; + const auto& src_shape = buffers.find(input_id)->second; + BHWC dst_shape = CalculateOutputShape(src_shape, params); + GridParams grid_params; + Params3x3 params3x3; + Init3x3(params, int2(src_shape.w, src_shape.h), + int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3); + if (grid_params.elements_count == 0) { + return std::make_pair(groups_size, uint3{0, 0, 0}); + } + int groups_x = + IntegralDivideRoundUp(grid_params.elements_count, groups_size.x); + int groups_y = 1; + int groups_z = 1; + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + + const int shared_size = sizeof(float) * 4 * src_depth * dst_ch_aligned * 4; + int gpu_type = GetAppleSocVersion(); + if (shared_size < (1024 * 16 - 32) && (gpu_type == 7 || gpu_type == 8) && + dst_ch_aligned <= kThreadGroupWidth * kThreadGroupHeight) { + desc->shader_source = GetDeconvolutionShared3x3(params); + } else { + desc->shader_source = GetDeconvolution3x3(params); + } + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* dst_buffer", + [input_id, params](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + BHWC dst_shape = CalculateOutputShape(src_shape, params); + return BHWC{src_shape.b, dst_shape.h, dst_shape.w, dst_shape.c}; + }}; + + desc->immutable_buffers = { + {"device FilterStripe* const filters", + VectorToUint8Vector(filters_reordered)}, + {"constant FLT4* const biases", VectorToUint8Vector(params.bias.data)}, + }; + + desc->uniform_buffers = { + {"constant uniforms& params", + [input_id, output_id, params](const std::map& buffers) { + const auto& src_shape = buffers.find(input_id)->second; + const auto& dst_shape = buffers.find(output_id)->second; + GridParams grid_params; + Params3x3 params3x3; + Init3x3(params, int2(src_shape.w, src_shape.h), + int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3); + int* ptr = reinterpret_cast(¶ms3x3); + std::vector uniform_params{ + src_shape.w, + src_shape.h, + dst_shape.w, + dst_shape.h, + /*short2 Params3x3.inner_size*/ ptr[0], + /*short2 Params3x3.src_offset*/ ptr[1], + /*short2 Params3x3.dst_offset*/ ptr[2], + }; + return VectorToUint8Vector(uniform_params); + 
}}, + }; + + desc->resize_function = [input_id, + params](const std::map& buffers) { + const uint3 groups_size{kThreadGroupWidth, kThreadGroupHeight, 1}; + const auto& src_shape = buffers.find(input_id)->second; + BHWC dst_shape = CalculateOutputShape(src_shape, params); + GridParams grid_params; + Params3x3 params3x3; + Init3x3(params, int2(src_shape.w, src_shape.h), + int2(dst_shape.w, dst_shape.h), &grid_params, ¶ms3x3); + if (params3x3.inner_size.x * params3x3.inner_size.y == 0) { + return std::make_pair(groups_size, uint3{0, 0, 0}); + } + int groups_x = IntegralDivideRoundUp(params3x3.inner_size.x, groups_size.x); + int groups_y = IntegralDivideRoundUp(params3x3.inner_size.y, groups_size.y); + int groups_z = 1; + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + + return {border_desc, desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h new file mode 100644 index 00000000000..d74cc2f17a4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_TRANSPOSE_CONV_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_TRANSPOSE_CONV_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector ConvolutionTransposed( + int id, ValueId input_id, ValueId output_id, + const ConvolutionTransposedAttributes& params, + const RuntimeOptions& options); + +std::vector ConvolutionTransposed3x3( + int id, ValueId input_id, ValueId output_id, + const ConvolutionTransposedAttributes& params, + const RuntimeOptions& options); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_TRANSPOSE_CONV_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/upsample.cc b/tensorflow/lite/delegates/gpu/metal/kernels/upsample.cc new file mode 100644 index 00000000000..69d88876efb --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/upsample.cc @@ -0,0 +1,125 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
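
A usage note on the two factory functions declared in transpose_conv.h above: a caller in the Metal backend would normally dispatch between the generic transposed-convolution path and the specialized 3x3/stride-2 path. The sketch below is illustrative only; the selection helper, its placement, and the ComputeTaskDescriptorPtr spelling of the return type are assumptions, not part of this change.

  #include <vector>

  #include "tensorflow/lite/delegates/gpu/common/model.h"
  #include "tensorflow/lite/delegates/gpu/common/operations.h"
  #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
  #include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
  #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"

  namespace {

  // Hypothetical predicate: GetDeconvolution3x3/Init3x3 above assume a 3x3
  // kernel that upsamples by a factor of two in each spatial dimension.
  bool IsConvolutionTransposed3x3(
      const tflite::gpu::ConvolutionTransposedAttributes& attr) {
    return attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
           attr.stride.w == 2 && attr.stride.h == 2;
  }

  }  // namespace

  std::vector<tflite::gpu::metal::ComputeTaskDescriptorPtr>
  SelectConvolutionTransposed(
      int id, tflite::gpu::ValueId input_id, tflite::gpu::ValueId output_id,
      const tflite::gpu::ConvolutionTransposedAttributes& attr,
      const tflite::gpu::metal::RuntimeOptions& options) {
    if (IsConvolutionTransposed3x3(attr)) {
      // Produces two descriptors: a border task plus the fast inner 3x3 task.
      return tflite::gpu::metal::ConvolutionTransposed3x3(id, input_id,
                                                          output_id, attr,
                                                          options);
    }
    // Single descriptor; the factory itself picks the threadgroup-shared
    // variant when the device and filter sizes allow it.
    return tflite::gpu::metal::ConvolutionTransposed(id, input_id, output_id,
                                                     attr, options);
  }
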
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/kernels/upsample.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector Upsample( + int id, ValueId input_id, ValueId output_id, + const Upsample2DAttributes& attr) { + auto desc = std::make_shared(); + desc->id = id; + desc->is_linkable = false; + + desc->shader_source = R"( + #include + using namespace metal; + $0 + kernel void ComputeFunction( + $1 + uint3 gid[[thread_position_in_grid]]) { + if (int(gid.x) >= size.z || int(gid.y) >= size.w) { + return; + } + const float2 tex_coord = float2(gid.xy) * scale; + int4 st; + const int2 borders = size.xy - int2(1, 1); + st.xy = clamp(int2(tex_coord), int2(0, 0), borders); + st.zw = min(st.xy + int2(1, 1), borders); + const float2 t = tex_coord - float2(st.xy); //interpolating factors + const int src_index0 = (gid.z * size.y + st.y) * size.x + st.x; + const int src_index1 = (gid.z * size.y + st.y) * size.x + st.z; + const int src_index2 = (gid.z * size.y + st.w) * size.x + st.x; + const int src_index3 = (gid.z * size.y + st.w) * size.x + st.z; + FLT4 tex11 = src_buffer[src_index0]; + FLT4 tex21 = src_buffer[src_index1]; + FLT4 tex12 = src_buffer[src_index2]; + FLT4 tex22 = src_buffer[src_index3]; + // bilinear interpolation + FLT4 value = mix(mix(tex11, tex21, static_cast(t.x)), + mix(tex12, tex22, static_cast(t.x)), static_cast(t.y)); + const int linear_index = (gid.z * size.w + gid.y) * size.z + gid.x; + $2 + output_buffer[linear_index] = value; + } + )"; + + desc->input_buffers = { + {input_id, "device FLT4* const src_buffer"}, + }; + + desc->output_buffer = { + output_id, "device FLT4* output_buffer", + [input_id, attr](const std::map& buffers) { + return CalculateOutputShape(buffers.find(input_id)->second, attr); + }}; + + desc->uniform_buffers = { + {"constant int4& size", + [input_id, output_id](const std::map& buffers) { + const auto& dimension = buffers.find(input_id)->second; + const auto& output_dimension = buffers.find(output_id)->second; + std::vector sizes = { + dimension.w, + dimension.h, + output_dimension.w, + output_dimension.h, + }; + return VectorToUint8Vector(sizes); + }}, + {"constant float2& scale", + [input_id, output_id, attr](const std::map& buffers) { + const auto& input_dimensions = buffers.find(input_id)->second; + const auto& output_dimensions = buffers.find(output_id)->second; + std::vector sizes = { + CalculateResizeScale(input_dimensions.w, output_dimensions.w, + attr), + CalculateResizeScale(input_dimensions.h, output_dimensions.h, + attr), + }; + return VectorToUint8Vector(sizes); + }}, + }; + + desc->resize_function = [output_id](const std::map& buffers) { + const uint3 groups_size{16, 16, 1}; + const auto& dst_dim = 
buffers.find(output_id)->second; + int groups_x = IntegralDivideRoundUp(dst_dim.w, groups_size.x); + int groups_y = IntegralDivideRoundUp(dst_dim.h, groups_size.y); + const int dst_layers = IntegralDivideRoundUp(dst_dim.c, 4); + int groups_z = IntegralDivideRoundUp(dst_layers, groups_size.z); + return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z}); + }; + return {desc}; +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/upsample.h b/tensorflow/lite/delegates/gpu/metal/kernels/upsample.h new file mode 100644 index 00000000000..54d10a0a5a2 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/upsample.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UPSAMPLE_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UPSAMPLE_H_ + +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h" + +namespace tflite { +namespace gpu { +namespace metal { + +std::vector Upsample( + int id, ValueId input_id, ValueId output_id, + const Upsample2DAttributes& attr); + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UPSAMPLE_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/util.cc b/tensorflow/lite/delegates/gpu/metal/kernels/util.cc new file mode 100644 index 00000000000..5e4466661cd --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/kernels/util.cc @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
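
For readers tracing the Upsample shader above: its bilinear path can be mirrored on the CPU as follows. This is a sketch for a single-channel HWC plane (the Metal kernel operates on FLT4 channel slices); the function name and layout are illustrative assumptions.

  #include <algorithm>
  #include <vector>

  // Reference bilinear sample matching the shader: clamp the base texel, take
  // its right/bottom neighbors, and blend with the fractional offsets.
  float BilinearSample(const std::vector<float>& src, int src_w, int src_h,
                       float scale_x, float scale_y, int dst_x, int dst_y) {
    const float tex_x = dst_x * scale_x;
    const float tex_y = dst_y * scale_y;
    const int x0 = std::clamp(static_cast<int>(tex_x), 0, src_w - 1);
    const int y0 = std::clamp(static_cast<int>(tex_y), 0, src_h - 1);
    const int x1 = std::min(x0 + 1, src_w - 1);
    const int y1 = std::min(y0 + 1, src_h - 1);
    const float tx = tex_x - x0;  // interpolation factors, as in the shader
    const float ty = tex_y - y0;
    const float top =
        src[y0 * src_w + x0] * (1.0f - tx) + src[y0 * src_w + x1] * tx;
    const float bottom =
        src[y1 * src_w + x0] * (1.0f - tx) + src[y1 * src_w + x1] * tx;
    return top * (1.0f - ty) + bottom * ty;
  }
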
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+namespace {
+
+unsigned int GetOptimalSize(unsigned int grid_size) {
+  if (grid_size % 8 == 0 || grid_size % 8 >= 4 || grid_size >= 16) {
+    return 8;
+  }
+  if (grid_size % 4 == 0 || grid_size % 4 >= 2 || grid_size >= 8) {
+    return 4;
+  }
+  if (grid_size % 2 == 0 || grid_size >= 4) {
+    return 2;
+  }
+  return 1;
+}
+
+}  // namespace
+
+uint3 GetWorkGroupSizeForGrid(const uint3& grid_size) {
+  unsigned int x_size = GetOptimalSize(grid_size.x);
+  unsigned int y_size = GetOptimalSize(grid_size.y);
+  unsigned int z_size = std::max(1u, 32u / (x_size * y_size));
+  return {x_size, y_size, z_size};
+}
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/util.h b/tensorflow/lite/delegates/gpu/metal/kernels/util.h
new file mode 100644
index 00000000000..a1028ee25ea
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/util.h
@@ -0,0 +1,34 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UTIL_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UTIL_H_
+
+#include "tensorflow/lite/delegates/gpu/common/types.h"
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+// Returns a work-group size that tries to cover the grid optimally.
+// If you use a work-group size generated by this method, you MUST check
+// all three thread dimensions against the grid bounds in your kernel.
+uint3 GetWorkGroupSizeForGrid(const uint3& grid_size);
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_UTIL_H_
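
The contract documented in util.h can be illustrated with a hypothetical resize_function. The kernels in this change still use fixed 8x4 and 16x16 work groups, so this is only a sketch, assuming uint3, BHWC, and IntegralDivideRoundUp behave as they are used elsewhere in this change:

  #include <utility>

  #include "tensorflow/lite/delegates/gpu/common/shape.h"
  #include "tensorflow/lite/delegates/gpu/common/types.h"
  #include "tensorflow/lite/delegates/gpu/common/util.h"
  #include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"

  std::pair<tflite::gpu::uint3, tflite::gpu::uint3> GroupsForShape(
      const tflite::gpu::BHWC& dst_shape) {
    using tflite::gpu::IntegralDivideRoundUp;
    using tflite::gpu::uint3;
    const int dst_layers = IntegralDivideRoundUp(dst_shape.c, 4);
    const uint3 grid{dst_shape.w, dst_shape.h, dst_layers};
    const uint3 groups_size = tflite::gpu::metal::GetWorkGroupSizeForGrid(grid);
    // The generated group rarely divides the grid exactly, which is why the
    // kernel must bounds-check all three components of thread_position_in_grid.
    int groups_x = IntegralDivideRoundUp(grid.x, groups_size.x);
    int groups_y = IntegralDivideRoundUp(grid.y, groups_size.y);
    int groups_z = IntegralDivideRoundUp(grid.z, groups_size.z);
    return std::make_pair(groups_size, uint3{groups_x, groups_y, groups_z});
  }
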
diff --git a/tensorflow/lite/delegates/gpu/metal/runtime_options.h b/tensorflow/lite/delegates/gpu/metal/runtime_options.h
new file mode 100644
index 00000000000..d8e8fe3dd92
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal/runtime_options.h
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
+
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+struct RuntimeOptions {
+  enum class Precision {
+    FP16,
+    FP32,
+  };
+  // Buffer storage format. If FP32, then the accumulator must also be FP32.
+  Precision storage_precision = Precision::FP32;
+  // Accumulator precision. Defines the precision used inside convolutions.
+  Precision accumulator_precision = Precision::FP32;
+};
+
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_RUNTIME_OPTIONS_H_
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.h b/tensorflow/lite/delegates/gpu/metal_delegate.h
new file mode 100644
index 00000000000..d38e73a4a19
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/metal_delegate.h
@@ -0,0 +1,72 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_H_
+
+#import <Metal/Metal.h>
+
+#include <functional>
+
+#include "tensorflow/lite/c/c_api_internal.h"
+
+// Options for a GPU delegate created with `NewGpuDelegate`.
+struct GpuDelegateOptions {
+  // Allows quantizing tensors, downcasting values, and processing in float16.
+  bool allow_precision_loss;
+
+  enum class WaitType {
+    // waitUntilCompleted
+    kPassive,
+    // Minimizes latency. Uses active spinning instead of blocking on a mutex
+    // and consumes additional CPU resources.
+    kActive,
+    // Useful when the output feeds another GPU pipeline stage or when an
+    // external command encoder is set.
+    kDoNotWait,
+    // Tries to avoid GPU sleep mode.
+    kAggressive,
+  };
+  WaitType wait_type;
+};
+
+// Creates a new delegate instance that needs to be destroyed with
+// `DeleteGpuDelegate` when the delegate is no longer used by TFLite.
+// When `options` is set to `nullptr`, the following default values are used:
+//   .allow_precision_loss = false,
+//   .wait_type = kPassive,
+TfLiteDelegate* NewGpuDelegate(const GpuDelegateOptions* options);
+
+// Destroys a delegate created with the `NewGpuDelegate` call.
+void DeleteGpuDelegate(TfLiteDelegate* delegate);
+
+// Binds a Metal buffer to an input or an output tensor in the initialized
+// delegate. The bound buffer should have sufficient storage to accommodate
+// all elements of the tensor. Returns true on success, false otherwise.
+//
+// *** Must be called *before* `Interpreter::ModifyGraphWithDelegate`. ***
+bool BindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index,
+                             id<MTLBuffer> metal_buffer);
+
+// Binds a user-defined MTLComputeCommandEncoder. The delegate puts all GPU
+// tasks into this encoder instead of the internal encoder.
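
As an aside, a minimal caller for the delegate API declared above. This is a hypothetical sketch with error handling omitted; `interpreter` is assumed to be an already-built tflite::Interpreter.

  #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
  #include "tensorflow/lite/interpreter.h"

  void RunWithGpuDelegate(tflite::Interpreter* interpreter) {
    GpuDelegateOptions options;
    options.allow_precision_loss = true;  // FP16 storage and accumulation.
    options.wait_type = GpuDelegateOptions::WaitType::kPassive;

    TfLiteDelegate* delegate = NewGpuDelegate(&options);
    if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
      // Fall back to the default CPU path here.
    }
    interpreter->Invoke();
    // ... read outputs ...
    // Call DeleteGpuDelegate(delegate) only after the interpreter that uses
    // the delegate has been destroyed.
  }
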
+// The callback is a user-defined function to take control over encoder and +// command buffer. Can be nullptr. +bool TFLSetCommandEncoder( + TfLiteDelegate* delegate, id encoder, + std::function(bool is_last)> control_encoder); + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_H_ diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm new file mode 100644 index 00000000000..d62d10a7e39 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm @@ -0,0 +1,636 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#import "tensorflow/lite/delegates/gpu/metal_delegate.h" + +#import + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_builder.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/metal/api.h" +#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h" +#include "tensorflow/lite/delegates/gpu/metal/common.h" +#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h" +#include "tensorflow/lite/delegates/gpu/metal/inference_context.h" +#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" +#include "tensorflow/lite/minimal_logging.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { + +// Multi-thread safe alarm clock for preventing GPU sleeping. It spawns lightweight compute tasks +// until no inference is performing on a device. It's reduces the CPU-to-CPU inference latency. +// The class is used only for kAggressive wait type. +class GpuAlarmClock { + public: + explicit GpuAlarmClock(id command_queue) { + auto device = [command_queue device]; + std::lock_guard lock(alarms_mutex_); + if (!alarms_) alarms_ = new std::map, GpuAlarmClockInternal*>(); + auto it = alarms_->find(device); + if (it == alarms_->end()) { + internal_ = new GpuAlarmClockInternal(command_queue); + (*alarms_)[device] = internal_; + } else { + internal_ = it->second; + internal_->total_alarms_++; + } + } + ~GpuAlarmClock() { + std::lock_guard lock(alarms_mutex_); + if (--internal_->total_alarms_ > 0) return; + Stop(); + delete internal_; + // Remove the alarm from the container to free-up device handle. 
+ for (auto it = alarms_->begin(); it != alarms_->end(); ++it) { + if (it->second == internal_) { + alarms_->erase(it); + break; + } + } + if (alarms_->empty()) { + delete alarms_; + alarms_ = nullptr; + } + } + void Start() { + if (started_) return; + started_ = true; + internal_->active_alarms_++; + } + void Stop() { + if (!started_) return; + started_ = false; + internal_->active_alarms_--; + } + + private: + class GpuAlarmClockInternal { + public: + id stub_program_; + id stub_buffer_; + explicit GpuAlarmClockInternal(id command_queue) { + command_queue_ = command_queue; + device_ = [command_queue_ device]; + total_alarms_ = 1; + NSString* error; + id program; + CreateComputeProgram(device_, + @"kernel void ComputeFunction(device int* output_buffer [[buffer(0)]]) " + @"{ output_buffer[0] = 0; }", + @"ComputeFunction", nullptr, &program); + stub_program_ = program; + stub_buffer_ = [device_ newBufferWithLength:sizeof(int) * 4 + options:MTLResourceHazardTrackingModeUntracked]; + alarm_thread_ = std::thread([this]() { + id prev_command_buffer; + while (!release_thread_) { + if (active_alarms_ == total_alarms_) { + id command_buffer = [command_queue_ commandBuffer]; + id encoder = [command_buffer computeCommandEncoder]; + [encoder setComputePipelineState:stub_program_]; + [encoder setBuffer:stub_buffer_ offset:0 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) + threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder endEncoding]; + [command_buffer commit]; + if (prev_command_buffer != nil) [prev_command_buffer waitUntilScheduled]; + prev_command_buffer = command_buffer; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + }); + } + ~GpuAlarmClockInternal() { + release_thread_ = true; + alarm_thread_.join(); + } + + private: + friend class GpuAlarmClock; + std::atomic active_alarms_; + std::thread alarm_thread_; + id command_queue_; + id device_; + volatile bool release_thread_ = false; + int total_alarms_ = 0; + }; + static std::map, GpuAlarmClockInternal*>* alarms_; + std::mutex alarms_mutex_; + GpuAlarmClockInternal* internal_; + bool started_ = false; +}; +std::map, GpuAlarmClock::GpuAlarmClockInternal*>* GpuAlarmClock::alarms_ = nullptr; + +// Forward declaration. +TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate); + +class Delegate { + struct ValueRef { + BHWC shape; + int64_t tensor_id; + }; + + public: + explicit Delegate(const GpuDelegateOptions* options) { + if (options) { + options_ = *options; + } else { + // Default options. 
+ options_.allow_precision_loss = false; + options_.wait_type = GpuDelegateOptions::WaitType::kPassive; + } + metal_device_ = MTLCreateSystemDefaultDevice(); + command_queue_ = [metal_device_ newCommandQueue]; + if (options_.wait_type == GpuDelegateOptions::WaitType::kAggressive) { + gpu_alarm_clock_ = std::unique_ptr(new GpuAlarmClock(command_queue_)); + NSString* code = @R"( + kernel void ComputeFunction(device int* output_buffer [[buffer(0)]], + constant int& value [[buffer(1)]]) { + output_buffer[0] = value; + } + )"; + NSString* error; + id signal_program; + CreateComputeProgram(metal_device_, code, @"ComputeFunction", nullptr, &signal_program); + signal_program_ = signal_program; + signal_buffer_ = [metal_device_ newBufferWithLength:sizeof(int) * 4 + options:MTLResourceStorageModeShared | + MTLResourceHazardTrackingModeUntracked]; + } + } + + Status BindBufferToTensor(id buffer, int tensor_index) { + for (auto& input : graph_inputs_) { + if (input.tensor_id == tensor_index) { + input_output_buffers_[input.id] = buffer; + bphwc4_buffers_[input.id] = buffer; + input.set_externally = true; + return OkStatus(); + } + } + for (auto& output : graph_outputs_) { + if (output.tensor_id == tensor_index) { + input_output_buffers_[output.id] = buffer; + bphwc4_buffers_[output.id] = buffer; + output.set_externally = true; + return OkStatus(); + } + } + return NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index)); + } + + void SetCommandEncoder( + id encoder, + std::function(bool is_last)> control_encoder) { + control_encoder_ = control_encoder; + external_command_encoder_ = encoder; + } + + Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) { + // Extract TFLite delegate execution plan from the context and convert it into FlowGraph32. + GraphFloat32 graph; + RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph)); + + // Apply general transformations on the graph. + NullTransformationReporter reporter; + ModelTransformer transformer(&graph, &reporter); + if (!ApplyGeneralTransformations(&transformer)) { + return InternalError("Graph general transformations failed"); + } + + // TODO(impjdi): Remove code duplication. + auto values = graph.values(); + auto find_value = [&](int tensor_index) -> Value* { + for (auto value : values) { + if (value->tensor.ref == tensor_index) return value; + } + return nullptr; + }; + tensors_.reserve(values.back()->id + 1); + for (const auto* value : values) { + if (tensors_.size() <= value->id) tensors_.resize(value->id + 1); + tensors_[value->id] = { + value->tensor.shape, // .shape + value->tensor.ref, // .tensor_id + }; + } + + // Prepare graph inputs. + // + // Note that graph.inputs() cannot be used directly, as the notion of graph input has a + // different meaning in public API and GPU-internal API. + inputs_.reserve(delegate_params->input_tensors->size); + for (int i = 0; i < delegate_params->input_tensors->size; ++i) { + const int tensor_index = delegate_params->input_tensors->data[i]; + auto* tensor = context->tensors + tensor_index; + if (tensor->allocation_type == TfLiteAllocationType::kTfLiteMmapRo) continue; + const auto* input = find_value(tensor_index); + if (!input || tensor->type != TfLiteType::kTfLiteFloat32) { + return NotFoundError("Input tensor is not found in the graph."); + } + + inputs_.push_back(input->id); + tensor->buffer_handle = input->id; + tensor->delegate = &delegate_; + } + + // Prepare graph outputs. 
+ // + // Note that graph.outputs() cannot be used directly, as the notion of graph output has a + // different meaning in public API and GPU-internal API. + outputs_.reserve(delegate_params->output_tensors->size); + for (int i = 0; i < delegate_params->output_tensors->size; ++i) { + const int tensor_index = delegate_params->output_tensors->data[i]; + auto* tensor = context->tensors + tensor_index; + const auto* output = find_value(tensor_index); + if (!output || tensor->type != TfLiteType::kTfLiteFloat32) { + return NotFoundError("Output tensor is not found in the graph."); + } + + outputs_.push_back(output->id); + tensor->buffer_handle = output->id; + tensor->delegate = &delegate_; + } + + size_t storage_type_size; + RuntimeOptions runtime_options; + if (options_.allow_precision_loss) { + storage_type_size = sizeof(HalfBits); + runtime_options.storage_precision = RuntimeOptions::Precision::FP16; + runtime_options.accumulator_precision = RuntimeOptions::Precision::FP16; + } else { + storage_type_size = sizeof(float); + runtime_options.storage_precision = RuntimeOptions::Precision::FP32; + runtime_options.accumulator_precision = RuntimeOptions::Precision::FP32; + } + + // TODO(impjdi): Merge logic with above. + // Pre-allocate input and output metal buffers + std::vector<::tflite::gpu::ValueId> input_ids; + input_ids.reserve(inputs_.size()); + std::map<::tflite::gpu::ValueId, BHWC> input_dimensions; + graph_inputs_.reserve(inputs_.size()); + for (const ValueId input : inputs_) { + const auto& input_tensor = tensors_[input]; + const auto tensor_id = input_tensor.tensor_id; + input_ids.push_back(input); + if (input_tensor.shape.b != 1) return UnimplementedError("Batching is not supported yet."); + input_dimensions[input] = input_tensor.shape; + graph_inputs_.push_back({ + input, // .id + tensor_id, // .tensor_id + input_tensor.shape, // .shape + false, // .set_externally + }); + int bhwc_length = static_cast(sizeof(float) * input_tensor.shape.DimensionsProduct()); + int bphwc4_length = + static_cast(storage_type_size * GetElementsSizeForPHWC4(input_tensor.shape)); + id buffer = [metal_device_ newBufferWithLength:bhwc_length + options:MTLResourceStorageModeShared]; + input_output_buffers_[input] = buffer; + if (options_.allow_precision_loss || input_tensor.shape.c != 4) { + bphwc4_buffers_[input] = [metal_device_ newBufferWithLength:bphwc4_length + options:MTLResourceStorageModeShared]; + if (converter_to_BPHWC4_ == nil) { + converter_to_BPHWC4_ = + [[TFLBufferConvert alloc] initWithDevice:metal_device_ + isFloat16:options_.allow_precision_loss + convertToPBHWC4:true]; + if (converter_to_BPHWC4_ == nil) { + return InternalError("Error initialization of input buffer converter"); + } + } + } else { + bphwc4_buffers_[input] = buffer; + } + } + + std::vector<::tflite::gpu::ValueId> output_ids; + output_ids.reserve(outputs_.size()); + graph_outputs_.reserve(outputs_.size()); + for (const ValueId output : outputs_) { + const auto& output_tensor = tensors_[output]; + const auto tensor_id = output_tensor.tensor_id; + output_ids.push_back(output); + graph_outputs_.push_back({ + output, // .id + tensor_id, // .tensor_id + output_tensor.shape, // .shape + false, // .set_externally + }); + // Create BHWC buffer + int bhwc_length = static_cast(sizeof(float) * output_tensor.shape.DimensionsProduct()); + int bphwc4_length = + static_cast(storage_type_size * GetElementsSizeForPHWC4(output_tensor.shape)); + id buffer = [metal_device_ newBufferWithLength:bhwc_length + options:MTLResourceStorageModeShared]; + 
input_output_buffers_[output] = buffer; + if (options_.allow_precision_loss || output_tensor.shape.c != 4) { + bphwc4_buffers_[output] = [metal_device_ newBufferWithLength:bphwc4_length + options:MTLResourceStorageModeShared]; + if (converter_from_BPHWC4_ == nil) { + converter_from_BPHWC4_ = + [[TFLBufferConvert alloc] initWithDevice:metal_device_ + isFloat16:options_.allow_precision_loss + convertToPBHWC4:false]; + if (converter_from_BPHWC4_ == nil) { + return InternalError("Error initialization of output buffer converter"); + } + } + } else { + bphwc4_buffers_[output] = buffer; + } + } + + // TODO(impjdi): Merge these. + CompiledModel compiled_model; + RETURN_IF_ERROR(Compile(graph, runtime_options, &compiled_model)); + CompiledModel optimized_model; + RETURN_IF_ERROR(ValidateOptimizeModel(input_ids, output_ids, compiled_model, &optimized_model)); + + inference_context_ = [[TFLInferenceContext alloc] init]; + RETURN_IF_ERROR([inference_context_ compileModelWithDevice:metal_device_ + taskDescriptors:optimized_model + outputBufferIDs:output_ids + runtimeOptions:runtime_options]); + std::map<::tflite::gpu::ValueId, BHWC> output_dimensions; + RETURN_IF_ERROR([inference_context_ setInputDimensions:input_dimensions + outputDimensions:&output_dimensions + taskDescriptors:optimized_model]); + return OkStatus(); + } + + Status Invoke(TfLiteContext* context) { + if (options_.wait_type == GpuDelegateOptions::WaitType::kAggressive) gpu_alarm_clock_->Stop(); + // We need only synchronization so volatile works better than atomic which reads from global + // memory each time. + __block volatile bool buffer_completed = false; + __block id command_buffer; + __block id encoder = external_command_encoder_; + if (external_command_encoder_ == nil) { + command_buffer = [command_queue_ commandBuffer]; + encoder = [command_buffer computeCommandEncoder]; + } + + // CPU HWC input data conversion to PHWC4 and fill the GPU buffer + for (const auto& input : graph_inputs_) { + if (input.set_externally) continue; + // A user provides data on CPU memory for this buffer - need to copy to MTLBuffer + + TfLiteTensor* tensor = context->tensors + input.tensor_id; + void* gpu_ptr = [input_output_buffers_[input.id] contents]; + std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float)); + if (input_output_buffers_[input.id] == bphwc4_buffers_[input.id]) continue; + [converter_to_BPHWC4_ convertWithEncoder:encoder + shape:input.shape + sourceBuffer:input_output_buffers_[input.id] + convertedBuffer:bphwc4_buffers_[input.id]]; + if (external_command_encoder_ == nil) { + [encoder endEncoding]; + [command_buffer commit]; + command_buffer = [command_queue_ commandBuffer]; + encoder = [command_buffer computeCommandEncoder]; + } + } + + [inference_context_ encodeWithEncoder:encoder + inputOutputBuffers:bphwc4_buffers_ + encoderBlock:^(bool isLast) { + if (control_encoder_ != nullptr) { + return control_encoder_(isLast); + } + if (external_command_encoder_ != nil || + options_.wait_type == GpuDelegateOptions::WaitType::kPassive) { + return encoder; + } + if (isLast) { + if (options_.wait_type == GpuDelegateOptions::WaitType::kActive) { + [command_buffer addCompletedHandler:^(id) { + buffer_completed = true; + }]; + } + } else { + [encoder endEncoding]; + [command_buffer commit]; + command_buffer = [command_queue_ commandBuffer]; + encoder = [command_buffer computeCommandEncoder]; + } + return encoder; + }]; + for (const auto& output : graph_outputs_) { + if (output.set_externally) continue; + if 
(bphwc4_buffers_[output.id] == input_output_buffers_[output.id]) continue; + [converter_from_BPHWC4_ convertWithEncoder:encoder + shape:output.shape + sourceBuffer:bphwc4_buffers_[output.id] + convertedBuffer:input_output_buffers_[output.id]]; + } + + if (external_command_encoder_ == nil) { + [encoder endEncoding]; + [command_buffer commit]; + if (options_.wait_type == GpuDelegateOptions::WaitType::kActive) { + while (!buffer_completed) { + // Busy wait. Use local variable. Volatile uses RAM access all the time. + for (volatile int i = 0; i < 100; i++) { + } + } + } else if (options_.wait_type == GpuDelegateOptions::WaitType::kPassive) { + // passive wait: this thread sleeps until GPU finishes. + [command_buffer waitUntilCompleted]; + } else if (options_.wait_type == GpuDelegateOptions::WaitType::kAggressive) { + command_buffer = [command_queue_ commandBuffer]; + encoder = [command_buffer computeCommandEncoder]; + [encoder setComputePipelineState:signal_program_]; + [encoder setBuffer:signal_buffer_ offset:0 atIndex:0]; + signal_value_++; + [encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1]; + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) + threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder endEncoding]; + [command_buffer commit]; + gpu_alarm_clock_->Start(); + const int* signal_ptr = reinterpret_cast([signal_buffer_ contents]); + while (signal_ptr[0] != signal_value_) { + // Busy wait. Spinning with local variable to avoid RAM pressure. + for (volatile int i = 0; i < 100; i++) { + } + } + } + } else { + // External command encoder must be set before every invoke call. + external_command_encoder_ = nil; + // External command encoder is assigned so all output buffers are controlled by a user. + for (const auto& output : graph_outputs_) { + if (!output.set_externally) { + return InternalError( + "External command encoder is used, but not all output buffers are bound."); + } + } + return OkStatus(); + } + + // Retrieve data from GPU and convert from PHWC4 to HWC. + for (const auto& output : graph_outputs_) { + if (output.set_externally) continue; + // A user retrieves data on CPU memory for this buffer - need to copy from MTLBuffer. 
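The copy back into the user's CPU tensor follows right after this note. The staging buffers it reads from hold plain BHWC float data, while the GPU-side buffers allocated in Prepare() use the BPHWC4 layout. The sketch below illustrates the two sizes computed there, under the assumption that GetElementsSizeForPHWC4() rounds the channel count up to a multiple of four (an assumption about that helper, not something spelled out in this patch):

```c++
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical 1x8x8x3 tensor; PHWC4 pads channels up to a multiple of 4.
  const int b = 1, h = 8, w = 8, c = 3;
  const int aligned_c = ((c + 3) / 4) * 4;                            // 4
  const size_t bhwc_bytes = sizeof(float) * b * h * w * c;            // 768, CPU-facing buffer
  const size_t fp16_storage = 2;  // sizeof(HalfBits) when allow_precision_loss is set
  const size_t bphwc4_bytes = fp16_storage * b * h * w * aligned_c;   // 512, GPU-facing buffer
  std::printf("BHWC: %zu bytes, BPHWC4: %zu bytes\n", bhwc_bytes, bphwc4_bytes);
  return 0;
}
```

This is also why a channel count of exactly four, at full precision, lets the delegate alias a single buffer for both layouts and skip the converters, which is what the `shape.c != 4` checks above are for.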
+ TfLiteTensor* tensor = context->tensors + output.tensor_id; + const void* gpu_ptr = [input_output_buffers_[output.id] contents]; + std::memcpy(tensor->data.f, gpu_ptr, output.shape.DimensionsProduct() * sizeof(float)); + } + return OkStatus(); + } + + TfLiteDelegate* tflite_delegate() { return &delegate_; } + + private: + TfLiteDelegate delegate_ = { + reinterpret_cast(this), // .data_ + DelegatePrepare, // .Prepare + nullptr, // .CopyFromBufferHandle + nullptr, // .CopyToBufferHandle + nullptr, // .FreeBufferHandle + kTfLiteDelegateFlagsNone, // .flags + }; + + GpuDelegateOptions options_; + + id metal_device_; + + std::vector tensors_; // indexed by ValueId + std::vector inputs_; + std::vector outputs_; + + TFLInferenceContext* inference_context_; + // input and output buffers are passed into Metal inference engine + std::map<::tflite::gpu::ValueId, id> input_output_buffers_; + std::map<::tflite::gpu::ValueId, id> bphwc4_buffers_; + TFLBufferConvert* converter_to_BPHWC4_ = nil; + TFLBufferConvert* converter_from_BPHWC4_ = nil; + + struct BufferDescriptor { + ValueId id; + int64_t tensor_id; + BHWC shape; + bool set_externally; // a user fills/retrieves data on this MTLBuffer buffer + }; + std::vector graph_inputs_; + std::vector graph_outputs_; + + id external_command_encoder_; + std::function(bool is_last)> control_encoder_; + id command_queue_; + std::unique_ptr gpu_alarm_clock_; + id signal_program_; + id signal_buffer_; + int signal_value_ = 0; +}; + +Delegate* GetMetalDelegate(TfLiteNode* node) { + return reinterpret_cast(node->user_data); +} + +Delegate* GetMetalDelegate(TfLiteDelegate* delegate) { + return reinterpret_cast(delegate->data_); +} + +TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) { + const TfLiteRegistration kRegistration = { + // .init + [](TfLiteContext* context, const char* buffer, size_t) -> void* { + const auto* params = reinterpret_cast(buffer); + auto* metal_delegate = GetMetalDelegate(params->delegate); + // Everything below should happen in prepare function call, but TFLite for whatever reason + // forbids that. + const auto status = metal_delegate->Prepare(context, params); + if (status.ok()) return metal_delegate; + context->ReportError(context, "TfLiteGpuDelegate Prepare: %s", status.message().data()); + return nullptr; + }, + // .free + [](TfLiteContext*, void* buffer) -> void {}, + // .prepare + [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { + return node->user_data ? kTfLiteOk : kTfLiteError; + }, + // .invoke + [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { + const auto status = GetMetalDelegate(node)->Invoke(context); + if (status.ok()) return kTfLiteOk; + context->ReportError(context, "TfLiteMetalDelegate Invoke: %s", status.message().data()); + return kTfLiteError; + }, + nullptr, // .profiling_string + 0, // .builtin_code + "TfLiteMetalDelegate", // .custom_name + 1, // .version + }; + TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); + const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(context, kRegistration, + ops_to_replace, delegate); + TfLiteIntArrayFree(ops_to_replace); + return status; +} + +} // namespace +} // namespace metal +} // namespace gpu +} // namespace tflite + +TfLiteDelegate* NewGpuDelegate(const GpuDelegateOptions* options) { + TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Created TensorFlow Lite delegate for Metal."); + auto* metal_delegate = new ::tflite::gpu::metal::Delegate(options); + return metal_delegate ? 
metal_delegate->tflite_delegate() : nullptr; +} + +void DeleteGpuDelegate(TfLiteDelegate* delegate) { + delete ::tflite::gpu::metal::GetMetalDelegate(delegate); +} + +bool BindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index, id buffer) { + auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate); + return metal_delegate && metal_delegate->BindBufferToTensor(buffer, tensor_index).ok(); +} + +bool TFLSetCommandEncoder( + TfLiteDelegate* delegate, id encoder, + std::function(bool is_last)> control_encoder) { + auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate); + if (!metal_delegate) return false; + metal_delegate->SetCommandEncoder(encoder, control_encoder); + return true; +} diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD index 99cd6d3f859..f8f3c03ea25 100644 --- a/tensorflow/lite/delegates/nnapi/BUILD +++ b/tensorflow/lite/delegates/nnapi/BUILD @@ -2,17 +2,25 @@ package(default_visibility = [ "//visibility:public", ]) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") licenses(["notice"]) # Apache 2.0 cc_library( name = "nnapi_delegate", - srcs = ["nnapi_delegate.cc"], + srcs = select({ + "//tensorflow:ios": [ + "nnapi_delegate_disabled.cc", + ], + "//tensorflow:windows": [ + "nnapi_delegate_disabled.cc", + ], + "//conditions:default": [ + "nnapi_delegate.cc", + ], + }), hdrs = ["nnapi_delegate.h"], deps = [ - "//tensorflow/lite:framework", "//tensorflow/lite:kernel_api", "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/kernels:kernel_util", @@ -20,13 +28,12 @@ cc_library( ], ) -tf_cc_test( +cc_test( name = "nnapi_delegate_test", size = "small", srcs = ["nnapi_delegate_test.cc"], tags = [ - # TODO(b/122987564): Enable on Android after resolving API 27 failures. - "tflite_not_portable_android", + "no_windows", "tflite_not_portable_ios", ], deps = [ diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/BUILD new file mode 100644 index 00000000000..17a238980d0 --- /dev/null +++ b/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/BUILD @@ -0,0 +1,7 @@ +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "nnapi_delegate_src", + srcs = ["NnApiDelegate.java"], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java b/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java new file mode 100644 index 00000000000..3e680162452 --- /dev/null +++ b/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.nnapi; + +import org.tensorflow.lite.Delegate; + +/** {@link Delegate} for NNAPI inference. */ +public class NnApiDelegate implements Delegate, AutoCloseable { + + private static final long INVALID_DELEGATE_HANDLE = 0; + + private long delegateHandle; + + public NnApiDelegate() { + delegateHandle = createDelegate(); + } + + @Override + public long getNativeHandle() { + return delegateHandle; + } + + /** + * The NNAPI delegate is singleton. Nothing to delete for now, so mark the handle invalid only. + */ + @Override + public void close() { + if (delegateHandle != INVALID_DELEGATE_HANDLE) { + delegateHandle = INVALID_DELEGATE_HANDLE; + } + } + + private static native long createDelegate(); +} diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/native/BUILD b/tensorflow/lite/delegates/nnapi/java/src/main/native/BUILD new file mode 100644 index 00000000000..4c12ef344d5 --- /dev/null +++ b/tensorflow/lite/delegates/nnapi/java/src/main/native/BUILD @@ -0,0 +1,24 @@ +# Description: +# Java Native Interface (JNI) library intended for implementing the +# TensorFlow Lite GPU delegate Java API using the TensorFlow Lite CC library. + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow/lite:build_def.bzl", "tflite_copts") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "native", + srcs = ["nnapi_delegate_jni.cc"], + copts = tflite_copts(), + tags = [ + "manual", + "notap", + ], + deps = [ + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + "//tensorflow/lite/java/jni", + ], + alwayslink = 1, +) diff --git a/tensorflow/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc similarity index 60% rename from tensorflow/lite/java/src/main/native/tensorflow_lite_jni.h rename to tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc index de3e703110c..d68ff5efac1 100644 --- a/tensorflow/lite/java/src/main/native/tensorflow_lite_jni.h +++ b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,24 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ -#define TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ - #include +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus -/* - * Class: org_tensorflow_lite_TensorFlowLite - * Method: version - * Signature: ()Ljava/lang/String; - */ -JNIEXPORT jstring JNICALL -Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass); +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate(JNIEnv* env, + jclass clazz) { + return reinterpret_cast(tflite::NnApiDelegate()); +} #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // TENSORFLOW_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_ diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index 2f5ed54db8e..6f6b450575f 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -12,18 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" + #include #include #include #include #include -#include "tensorflow/lite/allocation.h" #include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/context_util.h" -#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h" @@ -65,7 +65,19 @@ bool IsQuantized(TfLiteType type) { switch (type) { case kTfLiteUInt8: case kTfLiteInt8: - case kTfLiteInt16: + return true; + default: + // kTfLiteInt16 isn't supported as quantized type yet. + return false; + } +} + +bool IsScalarInputSupported(int builtin_code) { + switch (builtin_code) { + case kTfLiteBuiltinAdd: + case kTfLiteBuiltinMul: + case kTfLiteBuiltinSub: + case kTfLiteBuiltinDiv: return true; default: return false; @@ -83,15 +95,71 @@ bool IsHybridOperator(const TfLiteContext* context, int builtin_code, const TfLiteType filter_type = context->tensors[filter_id].type; return IsFloat(input_type) && IsQuantized(filter_type); } + case kTfLiteBuiltinLstm: { + const int input_id = node->inputs->data[0]; + // Input #1 is optional so use #2 to determine if hybrid. + const int weights_id = node->inputs->data[2]; + const TfLiteType input_type = context->tensors[input_id].type; + const TfLiteType weights_type = context->tensors[weights_id].type; + return IsFloat(input_type) && IsQuantized(weights_type); + } default: return false; } } +// When using NN API version 1.0 or 1.1, the condition below must be true for +// quantized versions of the following ops: +// * CONV_2D +// * DEPTHWISE_CONV_2D +// * FULLY_CONNECTED (where filter actually stands for weights) +// The condition is relaxed and no longer required since version 1.2. 
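`IsRestrictedScalesCompliant()`, defined immediately below, encodes that requirement as `input_scale * filter_scale < output_scale`. A worked example with made-up quantization parameters:

```c++
#include <cstdio>

// Illustrative scale values for a uint8 convolution; they are not taken from
// any real model.
int main() {
  const float input_scale = 0.5f;
  const float filter_scale = 0.25f;
  const float output_scale = 0.2f;
  // NNAPI 1.0/1.1 require input_scale * filter_scale < output_scale.
  const bool compliant = input_scale * filter_scale < output_scale;  // 0.125 < 0.2 -> true
  std::printf("restricted-scales compliant: %s\n", compliant ? "yes" : "no");
  return 0;
}
```

The new `Map()` checks later in this file use the helper to keep non-compliant uint8 CONV_2D, DEPTHWISE_CONV_2D, and FULLY_CONNECTED nodes on the CPU when running against NNAPI 1.0/1.1.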
+bool IsRestrictedScalesCompliant(const TfLiteContext* context, + const TfLiteNode* node) { + const int input_id = node->inputs->data[0]; + const int filter_id = node->inputs->data[1]; + const int output_id = node->outputs->data[0]; + const float input_scale = context->tensors[input_id].params.scale; + const float filter_scale = context->tensors[filter_id].params.scale; + const float output_scale = context->tensors[output_id].params.scale; + return input_scale * filter_scale < output_scale; +} + constexpr int32_t kMinSdkVersionForNNAPI = 27; constexpr int32_t kMinSdkVersionForNNAPI11 = 28; constexpr int32_t kMinSdkVersionForNNAPI12 = 29; +constexpr size_t kDefaultByteAlignmentForNNAPI = 16; +static size_t getNumPaddingBytes(size_t byte_size) { + size_t num_padding_bytes = 0; + if (byte_size % kDefaultByteAlignmentForNNAPI) { + num_padding_bytes = kDefaultByteAlignmentForNNAPI - + (byte_size % kDefaultByteAlignmentForNNAPI); + } + return num_padding_bytes; +} + +// Return NNAPI device handle with the provided null-terminated device name. If +// no matching device could be found, nullptr will be returned. +ANeuralNetworksDevice* GetDeviceHandle(const char* device_name_ptr) { + if (!device_name_ptr) return nullptr; + ANeuralNetworksDevice* device_handle = nullptr; + std::string device_name(device_name_ptr); + uint32_t numDevices = 0; + NnApiImplementation()->ANeuralNetworks_getDeviceCount(&numDevices); + + for (uint32_t i = 0; i < numDevices; i++) { + ANeuralNetworksDevice* device = nullptr; + const char* buffer = nullptr; + NnApiImplementation()->ANeuralNetworks_getDevice(i, &device); + NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer); + if (device_name == buffer) { + device_handle = device; + break; + } + } + return device_handle; +} } // namespace // RAII NN API Model Destructor for use with std::unique_ptr @@ -107,6 +175,13 @@ struct NNFreeCompilation { } }; +// RAII NN API Execution Destructor for use with std::unique_ptr +struct NNFreeExecution { + void operator()(ANeuralNetworksExecution* execution) { + NnApiImplementation()->ANeuralNetworksExecution_free(execution); + } +}; + // Manage NNAPI shared memory handle class NNMemory { public: @@ -255,8 +330,10 @@ class NNAPIOpBuilder { return kTfLiteOk; } - TfLiteStatus AddTensorInput(int tensor_index, bool hybrid_op) { - return AddTensor(tensor_index, hybrid_op, &augmented_inputs_); + TfLiteStatus AddTensorInput(int tensor_index, bool hybrid_op, + bool scalar_as_tensor = false) { + return AddTensor(tensor_index, hybrid_op, &augmented_inputs_, + scalar_as_tensor); } TfLiteStatus AddTensorOutput(int tensor_index) { @@ -386,7 +463,8 @@ class NNAPIOpBuilder { // If another caller previously created a NN API tensor for `tensor_index` // then the existing one is returned. TfLiteStatus AddTensor(int tensor_index, bool hybrid_op, - std::vector* indices) { + std::vector* indices, + bool scalar_as_tensor = false) { int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index); if (ann_tensor_index != -1) { indices->push_back(ann_tensor_index); @@ -437,10 +515,16 @@ class NNAPIOpBuilder { context_->ReportError(context_, "Logic error in NN API Delegate.\n"); return kTfLiteError; } + uint32_t tensor_rank = static_cast(tensor->dims->size); + uint32_t* tensor_dims = reinterpret_cast(tensor->dims->data); + if (scalar_as_tensor && tensor_rank == 0) { + // Use rank 1, shape {1} operand for TFLite scalar tensors. 
+ tensor_rank = 1; + tensor_dims = &tensor_rank; + } - ANeuralNetworksOperandType operand_type{ - nn_type, static_cast(tensor->dims->size), - reinterpret_cast(tensor->dims->data), scale, zeroPoint}; + ANeuralNetworksOperandType operand_type{nn_type, tensor_rank, tensor_dims, + scale, zeroPoint}; RETURN_TFLITE_ERROR_IF_NN_ERROR( context_, nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); @@ -572,6 +656,12 @@ class NNAPIDelegateKernel { // Hybrid operators not supported before NNAPI 1.2. return nullptr; } + const auto input_type = context->tensors[node->inputs->data[0]].type; + if (android_sdk_version < kMinSdkVersionForNNAPI12 && + input_type == kTfLiteUInt8 && + !IsRestrictedScalesCompliant(context, node)) { + return nullptr; + } auto builtin = reinterpret_cast(node->builtin_data); if (builtin->dilation_width_factor != 1 || @@ -593,6 +683,12 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinDepthwiseConv2d: if (version == 1) { + const auto input_type = context->tensors[node->inputs->data[0]].type; + if (android_sdk_version < kMinSdkVersionForNNAPI12 && + input_type == kTfLiteUInt8 && + !IsRestrictedScalesCompliant(context, node)) { + return nullptr; + } return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( @@ -614,6 +710,12 @@ class NNAPIDelegateKernel { // Hybrid operators not supported before NNAPI 1.2. return nullptr; } + const auto input_type = context->tensors[node->inputs->data[0]].type; + if (android_sdk_version < kMinSdkVersionForNNAPI12 && + input_type == kTfLiteUInt8 && + !IsRestrictedScalesCompliant(context, node)) { + return nullptr; + } return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( @@ -625,11 +727,23 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinSoftmax: if (version == 1) { + const auto& input = context->tensors[node->outputs->data[0]]; + if (input.type != kTfLiteFloat32 && input.type != kTfLiteUInt8) { + return nullptr; + } + const int input_rank = input.dims->size; + if (input_rank > 4) return nullptr; + // Before API level 29 only 2D and 4D input tensors were supported. + if (android_sdk_version < kMinSdkVersionForNNAPI12) { + if (input_rank != 2 && input_rank != 4) return nullptr; + } return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( mapping_args.node->builtin_data); mapping_args.builder->AddScalarFloat32Operand(builtin->beta); + // Optional scalar specifying the dimension the activation would be + // performed on is not added. Default to -1. return ANEURALNETWORKS_SOFTMAX; }; } @@ -639,6 +753,31 @@ class NNAPIDelegateKernel { return BasicMappingFn; } break; + case kTfLiteBuiltinResizeBilinear: + if (version == 1) { + const auto& input = context->tensors[node->inputs->data[0]]; + if (input.dims->size != 4) return nullptr; + if (input.type != kTfLiteFloat32 && input.type != kTfLiteUInt8) { + return nullptr; + } + if (android_sdk_version < kMinSdkVersionForNNAPI12 && + input.type != kTfLiteFloat32) { + // NNAPI 1.0 & 11 only supports float input. + return nullptr; + } + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + const int output_id = mapping_args.node->outputs->data[0]; + auto& output = mapping_args.context->tensors[output_id]; + const int output_height = output.dims->data[1]; + const int output_width = output.dims->data[2]; + // TfLiteResizeBilinearParams's |align_corners| is ignored. 
+ mapping_args.builder->AddScalarInt32Operand(output_width); + mapping_args.builder->AddScalarInt32Operand(output_height); + return ANEURALNETWORKS_RESIZE_BILINEAR; + }; + } + break; case kTfLiteBuiltinSqueeze: if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) { return [](const NNAPIOpMappingArgs& mapping_args) @@ -792,12 +931,17 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinPad: - if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 && - node->inputs->size == 2 && - context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + if (version == 1 && node->inputs->size == 2 && + (android_sdk_version >= kMinSdkVersionForNNAPI11) && + (context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 || + (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8 && + context->tensors[node->inputs->data[0]].params.zero_point == 0) || + android_sdk_version >= kMinSdkVersionForNNAPI12)) { // NNAPI does not support specifying the padding value. - // NNAPI pads physical zero for quantized tensors, so only delegate - // float pad to NNAPI. + // Before 1.2, NNAPI pads physical zero for quantized tensors, so only + // delegate pad with float input or quantized input with zero_point == + // 0 to NNAPI. NNAPI 1.2 onwards pads with zero-point, so delegate + // other quantized pad as well. return BasicMappingFn; } break; @@ -832,6 +976,36 @@ class NNAPIDelegateKernel { return BasicMappingFn; } break; + case kTfLiteBuiltinAbs: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) { + return BasicMappingFn; + } + break; + case kTfLiteBuiltinExp: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) { + return BasicMappingFn; + } + break; + case kTfLiteBuiltinLog: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) { + return BasicMappingFn; + } + break; + case kTfLiteBuiltinRsqrt: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) { + return BasicMappingFn; + } + break; + case kTfLiteBuiltinSin: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) { + return BasicMappingFn; + } + break; + case kTfLiteBuiltinSqrt: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) { + return BasicMappingFn; + } + break; case kTfLiteBuiltinRnn: // NNAPI only support float32 weights. if (version == 1 && node->inputs->size == 5 && @@ -883,15 +1057,35 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinLstm: - // NNAPI only support float32 weights. - // Only delegate to NNAPI 1.1, as 1.0 has a bug for optional tensors - // which would affect LSTM. // TODO(miaowang): add loggings to indicate why the op is rejected. - if (version == 1 && node->inputs->size == 20 && - android_sdk_version >= kMinSdkVersionForNNAPI11 && - context->tensors[node->inputs - ->data[/*kInputToOutputWeightsTensor*/ 4]] - .type == kTfLiteFloat32) { + if (version == 1) { + if (android_sdk_version < kMinSdkVersionForNNAPI11) { + // Only delegate to NNAPI 1.1+, as 1.0 has a bug for optional + // tensors which would affect LSTM. + return nullptr; + } + if (android_sdk_version < kMinSdkVersionForNNAPI12 && + IsHybridOperator(context, builtin_code, node)) { + // Hybrid operators not supported before NNAPI 1.2. + return nullptr; + } + // TODO(levp): name the constants for number of inputs in LSTM kernel. 
+ if (node->inputs->size != 20 && node->inputs->size != 24) { + return nullptr; + } + if (node->inputs->size == 24 && + android_sdk_version < kMinSdkVersionForNNAPI12) { + // LSTM with layer norm introduced in API level 29 + return nullptr; + } + const TfLiteType weight_type = + context + ->tensors[node->inputs + ->data[/*kInputToOutputWeightsTensor*/ 4]] + .type; + if (weight_type != kTfLiteFloat32 && weight_type != kTfLiteUInt8) { + return nullptr; + } return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( @@ -900,7 +1094,7 @@ class NNAPIDelegateKernel { mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip); mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip); - // Current NNAPI implementation requires the sratch_buffer as + // Current NNAPI implementation requires the scratch_buffer as // output. mapping_args.builder->AddAdditionalFloat32OutputTensor(2); @@ -922,6 +1116,20 @@ class NNAPIDelegateKernel { mapping_args.model_state_tfl_inputs->push_back( mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19]); + const bool hybrid_op = IsHybridOperator( + mapping_args.context, kTfLiteBuiltinLstm, mapping_args.node); + + if (mapping_args.node->inputs->size == 24) { + for (int i = 20; i < 24; ++i) { + const auto input_index = mapping_args.node->inputs->data[i]; + if (input_index != kOptionalTensor) { + mapping_args.builder->AddTensorInput(input_index, hybrid_op); + } else { + mapping_args.builder->AddVectorFloat32Operand(nullptr, 0); + } + } + } + return ANEURALNETWORKS_LSTM; }; } @@ -956,6 +1164,11 @@ class NNAPIDelegateKernel { return BasicMappingFn; } break; + case kTfLiteBuiltinPrelu: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12) { + return BasicMappingFn; + } + break; default: // All other operators are not mapped. return nullptr; @@ -970,8 +1183,22 @@ class NNAPIDelegateKernel { nodes_.push_back(node_index); } + const char* device_name_ptr = + StatefulNnApiDelegate::GetOptions(params->delegate).accelerator_name; + // user specified an acclelerator to use. + if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI12 && + device_name_ptr != nullptr) { + nnapi_device_ = GetDeviceHandle(device_name_ptr); + if (nnapi_device_ == nullptr) { + context->ReportError(context, + "Could not find the specified accelerator: %s.", + device_name_ptr); + return kTfLiteError; + } + } + if (!nn_model_) { - ANeuralNetworksModel* model; + ANeuralNetworksModel* model = nullptr; RETURN_TFLITE_ERROR_IF_NN_ERROR( context, nnapi_->ANeuralNetworksModel_create(&model)); nn_model_.reset(model); @@ -981,12 +1208,39 @@ class NNAPIDelegateKernel { } if (!nn_compilation_) { - ANeuralNetworksCompilation* compilation; - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, nnapi_->ANeuralNetworksCompilation_create(nn_model_.get(), - &compilation)); - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, nnapi_->ANeuralNetworksCompilation_finish(compilation)); + ANeuralNetworksCompilation* compilation = nullptr; + if (nnapi_device_ != nullptr) { + // Compile for the selected accelerator. 
+ RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, nnapi_->ANeuralNetworksCompilation_createForDevices( + nn_model_.get(), &nnapi_device_, 1, &compilation)); + } else { + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, nnapi_->ANeuralNetworksCompilation_create(nn_model_.get(), + &compilation)); + } + + auto preference = StatefulNnApiDelegate::GetOptions(params->delegate) + .execution_preference; + if (preference != + StatefulNnApiDelegate::Options::ExecutionPreference::kUndefined) { + const int preference_result = + nnapi_->ANeuralNetworksCompilation_setPreference(compilation, + preference); + if (preference_result != ANEURALNETWORKS_NO_ERROR) { + nnapi_->ANeuralNetworksCompilation_free(compilation); + compilation = nullptr; + } + RETURN_TFLITE_ERROR_IF_NN_ERROR(context, preference_result); + } + + const int finish_result = + nnapi_->ANeuralNetworksCompilation_finish(compilation); + if (finish_result != ANEURALNETWORKS_NO_ERROR) { + nnapi_->ANeuralNetworksCompilation_free(compilation); + compilation = nullptr; + } + RETURN_TFLITE_ERROR_IF_NN_ERROR(context, finish_result); nn_compilation_.reset(compilation); } return kTfLiteOk; @@ -997,6 +1251,8 @@ class NNAPIDelegateKernel { RETURN_TFLITE_ERROR_IF_NN_ERROR( context, nnapi_->ANeuralNetworksExecution_create(nn_compilation_.get(), &execution)); + std::unique_ptr + execution_unique_ptr(execution); // Set the input tensor buffers. Note: we access tflite tensors using // absolute indices but NN api indices inputs by relative indices. @@ -1020,6 +1276,7 @@ class NNAPIDelegateKernel { execution, relative_input_index, nullptr, nn_input_memory_->get_handle(), input_offset, tensor->bytes)); input_offset += tensor->bytes; + input_offset += getNumPaddingBytes(tensor->bytes); relative_input_index++; } } @@ -1035,6 +1292,7 @@ class NNAPIDelegateKernel { execution, relative_output_index, nullptr, nn_output_memory_->get_handle(), output_offset, tensor->bytes)); output_offset += tensor->bytes; + output_offset += getNumPaddingBytes(tensor->bytes); relative_output_index++; } @@ -1053,14 +1311,19 @@ class NNAPIDelegateKernel { relative_output_index++; } // Invoke ANN in blocking fashion. - ANeuralNetworksEvent* event = nullptr; - RETURN_TFLITE_ERROR_IF_NN_ERROR( - context, - nnapi_->ANeuralNetworksExecution_startCompute(execution, &event)); - RETURN_TFLITE_ERROR_IF_NN_ERROR(context, - nnapi_->ANeuralNetworksEvent_wait(event)); - nnapi_->ANeuralNetworksEvent_free(event); - nnapi_->ANeuralNetworksExecution_free(execution); + if (nnapi_->android_sdk_version < kMinSdkVersionForNNAPI12) { + ANeuralNetworksEvent* event = nullptr; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, + nnapi_->ANeuralNetworksExecution_startCompute(execution, &event)); + const int wait_result = nnapi_->ANeuralNetworksEvent_wait(event); + nnapi_->ANeuralNetworksEvent_free(event); + RETURN_TFLITE_ERROR_IF_NN_ERROR(context, wait_result); + } else { + // Use synchronous execution for NNAPI 1.2+. + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, nnapi_->ANeuralNetworksExecution_compute(execution)); + } // copy results from shared memory to the destination. output_offset = 0; @@ -1069,6 +1332,7 @@ class NNAPIDelegateKernel { memcpy(tensor->data.raw, nn_output_memory_->get_data_ptr() + output_offset, tensor->bytes); output_offset += tensor->bytes; + output_offset += getNumPaddingBytes(tensor->bytes); } return kTfLiteOk; @@ -1077,6 +1341,8 @@ class NNAPIDelegateKernel { private: // Access to NNApi. const NnApi* nnapi_; + // ANN device handle. + ANeuralNetworksDevice* nnapi_device_ = nullptr; // ANN API state. 
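Both the shared-memory pool sizes computed when the model's inputs and outputs are gathered and the offsets advanced in Invoke() above add getNumPaddingBytes() after every tensor, so each region starts on a kDefaultByteAlignmentForNNAPI (16-byte) boundary. A quick illustration of that arithmetic, with hypothetical tensor sizes (the remaining private members of the kernel continue after the sketch):

```c++
#include <cstddef>
#include <cstdio>

constexpr size_t kDefaultByteAlignmentForNNAPI = 16;

// Mirrors getNumPaddingBytes() from this patch: round each tensor's byte size
// up to the next multiple of 16 when packing tensors into shared memory.
size_t GetNumPaddingBytes(size_t byte_size) {
  const size_t rem = byte_size % kDefaultByteAlignmentForNNAPI;
  return rem == 0 ? 0 : kDefaultByteAlignmentForNNAPI - rem;
}

int main() {
  const size_t tensor_bytes[] = {6, 32, 100};  // hypothetical tensor sizes
  size_t offset = 0;
  for (size_t bytes : tensor_bytes) {
    std::printf("tensor of %3zu bytes placed at offset %zu\n", bytes, offset);
    offset += bytes + GetNumPaddingBytes(bytes);  // 6 -> 16, 32 -> 32, 100 -> 112
  }
  std::printf("total pool size: %zu bytes\n", offset);  // 160
  return 0;
}
```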
std::unique_ptr nn_model_; std::unique_ptr @@ -1116,6 +1382,13 @@ class NNAPIDelegateKernel { inputs_to_potentially_dequantize = {1, 2}; break; } + case kTfLiteBuiltinLstm: { + input_tensor_index = 0; + inputs_to_potentially_dequantize = {1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 20, 21, 22, 23}; + break; + } default: return; } @@ -1127,6 +1400,7 @@ class NNAPIDelegateKernel { if (!IsFloat(context->tensors[tensor_id].type)) return; for (int i : inputs_to_potentially_dequantize) { + if (i < 0 || i >= node->inputs->size) continue; // Ignore invalid index. tensor_id = node->inputs->data[i]; if (tensor_id < 0) continue; // Ignore optional input. @@ -1155,9 +1429,18 @@ class NNAPIDelegateKernel { context->GetNodeAndRegistration(context, node_index, &node, ®)); const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node); + const bool scalar_as_tensor = IsScalarInputSupported(reg->builtin_code); // Map inputs to NN API tensor indices. + int num_added_inputs = 0; for (auto input_index : TfLiteIntArrayView(node->inputs)) { + if (reg->builtin_code == kTfLiteBuiltinLstm && num_added_inputs >= 20) { + // Skip layer normalization weights. They are added in the Map + // function (after all the other inputs added there) since layer + // normalization weights are the last four inputs of the LSTM op in + // NNAPI. + continue; + } if (input_index == kOptionalTensor && (reg->builtin_code == kTfLiteBuiltinLstm || reg->builtin_code == kTfLiteBuiltinSvdf)) { @@ -1166,9 +1449,20 @@ class NNAPIDelegateKernel { // TODO(miaowang): make sure this is also able to handle quantized // tensor when supported by NNAPI. TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0)); + } else if (reg->builtin_code == kTfLiteBuiltinResizeBilinear) { + if (num_added_inputs == 0) { + // Only the first input tensor is added. The second one, specifying + // the output height and width, is not added and instead the height + // and width will be added individually as scalars by the mapping + // function returned by Map(). + TF_LITE_ENSURE_STATUS( + builder.AddTensorInput(input_index, hybrid_op)); + } } else { - TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op)); + TF_LITE_ENSURE_STATUS( + builder.AddTensorInput(input_index, hybrid_op, scalar_as_tensor)); } + ++num_added_inputs; } // Get op type and operands int nn_op_type = Map( @@ -1209,6 +1503,7 @@ class NNAPIDelegateKernel { context->tensors[i].allocation_type != kTfLiteMmapRo) { inputs.push_back(operand_mapping_.lite_index_to_ann(i)); total_input_byte_size += context->tensors[i].bytes; + total_input_byte_size += getNumPaddingBytes(context->tensors[i].bytes); } } @@ -1216,6 +1511,7 @@ class NNAPIDelegateKernel { for (int i : TfLiteIntArrayView(output_tensors)) { outputs.push_back(operand_mapping_.lite_index_to_ann(i)); total_output_byte_size += context->tensors[i].bytes; + total_output_byte_size += getNumPaddingBytes(context->tensors[i].bytes); } // Add state output tensors as model outputs. @@ -1253,88 +1549,131 @@ class NNAPIDelegateKernel { } // namespace -// Return a NN API Delegate struct that can check for support of ops. 
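The definitions that follow replace the old function-local singleton with the StatefulNnApiDelegate declared in nnapi_delegate.h further down in this patch. A minimal sketch of the intended call pattern; the interpreter is assumed to exist already, ApplyNnApiDelegate is a made-up helper name, and "nnapi-reference" matches the device name used by the new StatefulDelegateWithAcceleratorName test:

```c++
#include <memory>

#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/interpreter.h"

// Sketch only: builds the new stateful delegate and applies it to an existing
// interpreter. The delegate must stay alive for as long as the interpreter
// uses it, hence it is returned to the caller.
std::unique_ptr<tflite::StatefulNnApiDelegate> ApplyNnApiDelegate(
    tflite::Interpreter* interpreter) {
  tflite::StatefulNnApiDelegate::Options options;
  options.execution_preference =
      tflite::StatefulNnApiDelegate::Options::ExecutionPreference::kSustainedSpeed;
  // Optional: pin execution to a single accelerator. Leaving this nullptr
  // lets NNAPI choose among all available devices.
  options.accelerator_name = "nnapi-reference";
  std::unique_ptr<tflite::StatefulNnApiDelegate> delegate(
      new tflite::StatefulNnApiDelegate(options));
  if (interpreter->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
    return nullptr;
  }
  return delegate;
}
```

Because each instance carries its own Options in delegate_data_, two interpreters can now target different execution preferences or accelerators, which the old singleton could not express.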
+StatefulNnApiDelegate::StatefulNnApiDelegate(Options options) + : TfLiteDelegate(TfLiteDelegateCreate()), + delegate_data_( + Data{.execution_preference = options.execution_preference}) { + if (options.accelerator_name) { + delegate_data_.accelerator_name = options.accelerator_name; + } + Prepare = DoPrepare; + data_ = &delegate_data_; +} + +StatefulNnApiDelegate::StatefulNnApiDelegate() + : StatefulNnApiDelegate(Options()) {} + +const StatefulNnApiDelegate::Options StatefulNnApiDelegate::GetOptions( + TfLiteDelegate* delegate) { + auto delegate_data = reinterpret_cast(delegate->data_); + StatefulNnApiDelegate::Options options; + options.execution_preference = delegate_data->execution_preference; + options.accelerator_name = delegate_data->accelerator_name.empty() + ? nullptr + : delegate_data->accelerator_name.c_str(); + return options; +} + +TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context, + TfLiteDelegate* delegate) { + // Do not check nodes_ if NN API is unavailable. + const NnApi* nnapi = NnApiImplementation(); + if (nnapi->android_sdk_version < kMinSdkVersionForNNAPI || + !nnapi->nnapi_exists) { + return kTfLiteOk; + } + // For NNAPI 1.2+, check if there is any accelerator available. + // If not, don't delegate to NNAPI's CPU reference implementation. + if (nnapi->android_sdk_version >= kMinSdkVersionForNNAPI12) { + uint32_t device_count = 0; + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, nnapi->ANeuralNetworks_getDeviceCount(&device_count)); + // Any available accelerator will make the device_count larger than 1. + // More sophisticated check and whitelisting can be added later. + if (device_count <= 1) { + return kTfLiteOk; + } + // Check if user specified an acclelerator to use. + const char* device_name_ptr = GetOptions(delegate).accelerator_name; + if (device_name_ptr && !GetDeviceHandle(device_name_ptr)) { + // If the selected accelerator cannot be found, NNAPI will not be used. + context->ReportError(context, + "Could not find the specified accelerator: %s.", + device_name_ptr); + return kTfLiteOk; + } + } + // Allocate one element in vector already since TensorFlow Lite uses + // the first value as the number of nodes. The actual value will be set + // later, after the vector has been filled. + std::vector supported_nodes(1); + // We don't care about all nodes_, we only care about ones in the + // current plan. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + + int android_sdk_version = NnApiImplementation()->android_sdk_version; + // Check for every node if it is supported + // TODO(b/80625235): Fix this to do more careful checking of versioning. + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + if (NNAPIDelegateKernel::Map(context, registration->builtin_code, + registration->version, android_sdk_version, + node)) { + supported_nodes.push_back(node_index); + } + } + // First element in vector must be the number of actual nodes. 
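The count mentioned in the comment above is written by the assignment that immediately follows. The vector is deliberately laid out as {count, node indices...} so that its data() pointer can later be handed to ReplaceNodeSubsetsWithDelegateKernels() as a TfLiteIntArray. A small illustration with hypothetical node indices:

```c++
#include <cstdio>
#include <vector>

int main() {
  // Same layout DoPrepare() builds: slot 0 is reserved for the count and the
  // supported node indices (made-up values here) are appended after it.
  std::vector<int> supported_nodes(1);
  const int plan[] = {2, 5, 7};  // hypothetical execution-plan indices that passed Map()
  for (int node_index : plan) supported_nodes.push_back(node_index);
  supported_nodes[0] = static_cast<int>(supported_nodes.size()) - 1;
  // supported_nodes is now {3, 2, 5, 7}, i.e. the {size, data...} shape that
  // matches TfLiteIntArray when the buffer is reinterpreted later.
  for (int v : supported_nodes) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}
```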
+ supported_nodes[0] = supported_nodes.size() - 1; + + // NN API Delegate Registration (the pseudo kernel that will invoke NN + // API node sub sets) + static const TfLiteRegistration nnapi_delegate_kernel = { + .init = [](TfLiteContext* context, const char* buffer, + size_t length) -> void* { + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel; + kernel_state->Init(context, params); + return kernel_state; + }, + + .free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }, + + .prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { + // Since the underlying resize happened ahead of delegation + // worked. This does nothing. + return kTfLiteOk; + }, + + .invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { + NNAPIDelegateKernel* state = + reinterpret_cast(node->user_data); + return state->Invoke(context, node); + }, + + .profiling_string = nullptr, + .builtin_code = kTfLiteBuiltinDelegate, + .custom_name = "TfLiteNnapiDelegate", + .version = 1, + }; + + // Request TFLite to partition the graph and make kernels + // for each independent node sub set a new nnapi_delegate_kernel. + return context->ReplaceNodeSubsetsWithDelegateKernels( + context, nnapi_delegate_kernel, + reinterpret_cast(supported_nodes.data()), delegate); +} + +// Returns a singleton NNAPI Delegate that can check for support of ops. TfLiteDelegate* NnApiDelegate() { - static TfLiteDelegate delegate = { - .data_ = nullptr, - .flags = kTfLiteDelegateFlagsNone, - .Prepare = [](TfLiteContext* context, - TfLiteDelegate* delegate) -> TfLiteStatus { - // Do not check nodes_ if NN API is unavailable. - const NnApi* nnapi = NnApiImplementation(); - if (nnapi->android_sdk_version < kMinSdkVersionForNNAPI || - !nnapi->nnapi_exists) { - return kTfLiteOk; - } - - // Allocate one element in vector already since TensorFlow Lite uses - // the first value as the number of nodes. The actual value will be set - // later, after the vector has been filled. - std::vector supported_nodes(1); - // We don't care about all nodes_, we only care about ones in the - // current plan. - TfLiteIntArray* plan; - TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); - - int android_sdk_version = NnApiImplementation()->android_sdk_version; - // Check for every node if it is supported - // TODO(b/80625235): Fix this to do more careful checking of versioning. - for (int node_index : TfLiteIntArrayView(plan)) { - TfLiteNode* node; - TfLiteRegistration* registration; - TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( - context, node_index, &node, ®istration)); - if (NNAPIDelegateKernel::Map(context, registration->builtin_code, - registration->version, - android_sdk_version, node)) { - supported_nodes.push_back(node_index); - } - } - // First element in vector must be the number of actual nodes. 
- supported_nodes[0] = supported_nodes.size() - 1; - - // NN API Delegate Registration (the pseudo kernel that will invoke NN - // API node sub sets) - static const TfLiteRegistration nnapi_delegate_kernel = { - .init = [](TfLiteContext* context, const char* buffer, - size_t length) -> void* { - const TfLiteDelegateParams* params = - reinterpret_cast(buffer); - NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel; - kernel_state->Init(context, params); - return kernel_state; - }, - - .free = [](TfLiteContext* context, void* buffer) -> void { - delete reinterpret_cast(buffer); - }, - - .prepare = [](TfLiteContext* context, - TfLiteNode* node) -> TfLiteStatus { - // Since the underlying resize happened ahead of delegation - // worked. This does nothing. - return kTfLiteOk; - }, - - .invoke = [](TfLiteContext* context, - TfLiteNode* node) -> TfLiteStatus { - NNAPIDelegateKernel* state = - reinterpret_cast(node->user_data); - return state->Invoke(context, node); - }, - - .builtin_code = kTfLiteBuiltinDelegate, - }; - - // Request TFLite to partition the graph and make kernels - // for each independent node sub set a new nnapi_delegate_kernel. - return context->ReplaceNodeSubsetsWithDelegateKernels( - context, nnapi_delegate_kernel, - reinterpret_cast(supported_nodes.data()), - delegate); - }}; - - return &delegate; + static StatefulNnApiDelegate* delegate = new StatefulNnApiDelegate(); + return delegate; } } // namespace tflite diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h index 099fb724292..782744efb10 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h @@ -15,17 +15,79 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ #define TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#include +#include + #include "tensorflow/lite/c/c_api_internal.h" namespace tflite { -// Return a delegate that can be used to use the NN API. +// TFliteDelegate to interface with NNAPI. +class StatefulNnApiDelegate : public TfLiteDelegate { + public: + // Encapsulates all options that are specific to NNAPI delegate. + struct Options { + // Preferred Power/perf trade-off. For more details please see + // ANeuralNetworksCompilation_setPreference documentation in : + // https://developer.android.com/ndk/reference/group/neural-networks.html + enum ExecutionPreference { + kUndefined = -1, + kLowPower = 0, + kFastSingleAnswer = 1, + kSustainedSpeed = 2, + }; + + // Preferred Power/perf trade-off. + ExecutionPreference execution_preference = kUndefined; + + // Selected NNAPI accelerator with nul-terminated name. + // Default to nullptr, which implies the NNAPI default behavior: NNAPI + // runtime is allowed to use all available accelerators. If the selected + // accelerator cannot be found, NNAPI will not be used. + // It is the caller's responsibility to ensure the string is valid for the + // duration of the Options object lifetime. + const char* accelerator_name = nullptr; + }; + + // Uses default options. + StatefulNnApiDelegate(); + + // The constructor that accepts options from user. + explicit StatefulNnApiDelegate(Options options); + + ~StatefulNnApiDelegate() = default; + + // Returns the delegate options. + static const Options GetOptions(TfLiteDelegate* delegate); + + private: + // Encapsulates all delegate data. + struct Data { + // Preferred Power/perf trade-off. 
+ Options::ExecutionPreference execution_preference; + // Selected NNAPI accelerator name. + std::string accelerator_name; + }; + + // Implements TfLiteDelegate::Prepare. Please refer to TFLiteDelegate + // documentation for more info. + static TfLiteStatus DoPrepare(TfLiteContext* context, + TfLiteDelegate* delegate); + + // Delegate data presented through TfLiteDelegate::data_. + Data delegate_data_; +}; + +// DEPRECATED: Please use StatefulNnApiDelegate class instead. +// +// Returns a singleton delegate that can be used to use the NN API. // e.g. // NnApiDelegate* delegate = NnApiDelegate(); // interpreter->ModifyGraphWithDelegate(&delegate); // NnApiDelegate() returns a singleton, so you should not free this // pointer or worry about its lifetime. TfLiteDelegate* NnApiDelegate(); + } // namespace tflite #endif // TENSORFLOW_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ diff --git a/tensorflow/core/kernels/bitcast_op.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc similarity index 50% rename from tensorflow/core/kernels/bitcast_op.h rename to tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc index 1f3659f3033..1eb783af179 100644 --- a/tensorflow/core/kernels/bitcast_op.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// See docs in ../ops/array_ops.cc. +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" -#ifndef TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ -#define TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ +namespace tflite { -#include // for memcpy +// Return a non-functional NN API Delegate struct. +TfLiteDelegate* NnApiDelegate() { + static TfLiteDelegate delegate = [] { + TfLiteDelegate delegate = TfLiteDelegateCreate(); + delegate.Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // Silently succeed without modifying the graph. + return kTfLiteOk; + }; + return delegate; + }(); -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" + return &delegate; +} -#endif // TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_ +} // namespace tflite diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc index 4a8a3a0da82..cf20d10e485 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" + #include #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/test_util.h" @@ -45,13 +46,22 @@ class SingleOpModelWithNNAPI : public SingleOpModel { }); } + explicit SingleOpModelWithNNAPI( + const StatefulNnApiDelegate::Options& options) { + stateful_delegate_.reset(new StatefulNnApiDelegate(options)); + auto* delegate = stateful_delegate_.get(); + this->SetApplyDelegate([delegate](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(delegate); + }); + } + TfLiteStatus ResizeInputTensor(int tensor_index, const std::vector& dims) { return interpreter_->ResizeInputTensor(tensor_index, dims); } protected: - void SetData(int index, TensorType type, std::initializer_list data) { + void SetData(int index, TensorType type, const std::vector& data) { switch (type) { case TensorType_FLOAT32: PopulateTensor(index, data); @@ -70,6 +80,26 @@ class SingleOpModelWithNNAPI : public SingleOpModel { break; } } + + void GetData(int index, TensorType type, std::vector* output) { + switch (type) { + case TensorType_FLOAT32: + *output = ExtractVector(index); + break; + case TensorType_UINT8: + *output = Dequantize(ExtractVector(index), + GetScale(index), GetZeroPoint(index)); + break; + default: + FAIL() << "Type not supported: " << type; + break; + } + } + + private: + // Stateful NNAPI delegate. This is valid only if the state-ful constructor is + // used. + std::unique_ptr stateful_delegate_; }; class FloatAddOpModel : public SingleOpModelWithNNAPI { @@ -78,13 +108,16 @@ class FloatAddOpModel : public SingleOpModelWithNNAPI { const TensorData& output, ActivationFunctionType activation_type, bool allow_fp32_relax_to_fp16 = false) { - input1_ = AddInput(input1); - input2_ = AddInput(input2); - output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, - CreateAddOptions(builder_, activation_type).Union()); - BuildInterpreter({GetShape(input1_), GetShape(input2_)}, - allow_fp32_relax_to_fp16); + Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16); + } + + FloatAddOpModel(const StatefulNnApiDelegate::Options& options, + const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type, + bool allow_fp32_relax_to_fp16 = false) + : SingleOpModelWithNNAPI(options) { + Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16); } int input1() { return input1_; } @@ -96,6 +129,20 @@ class FloatAddOpModel : public SingleOpModelWithNNAPI { int input1_; int input2_; int output_; + + private: + // Performs initialization logic shared across all constructors. + void Init(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type, + bool allow_fp32_relax_to_fp16 = false) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}, + allow_fp32_relax_to_fp16); + } }; // Do a test with the NN API using no activation. @@ -109,6 +156,17 @@ TEST(NNAPIDelegate, AddWithNoActivation) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); } +// Do a test with scalar input using no activation. 
+TEST(NNAPIDelegate, AddScalarWithNoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.7}); + m.PopulateTensor(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.3, 0.8, 0.8})); +} + // Do a test with the NN API using no activation. // The test allows computing FP32 with FP16 precision. In this particular case, // calculating in FP32 or FP16 should produce the same results. @@ -144,6 +202,38 @@ TEST(NNAPIDelegate, ResizeFails) { EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 3, 3, 1}), kTfLiteError); } +// Sanity check for the state-ful NNAPI delegate. +TEST(NNAPIDelegate, StatefulDelegate) { + StatefulNnApiDelegate::Options options; + options.execution_preference = + StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower; + + FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +// Sanity check for the state-ful NNAPI delegate with accelerator_name +// specified. +TEST(NNAPIDelegate, StatefulDelegateWithAcceleratorName) { + StatefulNnApiDelegate::Options options; + options.execution_preference = + StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower; + options.accelerator_name = "nnapi-reference"; + + FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + class FloatMulOpModel : public SingleOpModelWithNNAPI { public: FloatMulOpModel(const TensorData& input1, const TensorData& input2, @@ -272,13 +362,6 @@ class ConvolutionOpModel : public SingleOpModelWithNNAPI { output_ = AddOutput(output); - if (input_type_ != TensorType_FLOAT32) { - // The following is required by quantized inference. It is the unittest's - // responsibility to make sure the output scale falls into the correct - // range. 
- CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); - } - SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, CreateConv2DOptions( builder_, padding, stride_width, stride_height, activation, @@ -430,10 +513,48 @@ TEST(ConvolutionOpTest, NoActivation) { })); } +TEST(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) { + // output_multiplier = 1.0118 + ConvolutionOpModel quant_op({TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128}, + {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128}, + {TensorType_UINT8, {}, -127, 128}); + ConvolutionOpModel float_op({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + std::initializer_list input = { + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }; + std::initializer_list filter = { + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }; + std::initializer_list bias = {1, 2, 3}; + + quant_op.SetInput(input); + quant_op.SetFilter(filter); + quant_op.SetBias(bias); + quant_op.Invoke(); + + float_op.SetInput(input); + float_op.SetFilter(filter); + float_op.SetBias(bias); + float_op.Invoke(); + + EXPECT_THAT(quant_op.GetOutput(), + ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); +} + class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI { public: DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, - const TensorData& output) { + const TensorData& output) + : input_type_(input.type) { input_ = AddInput(input); filter_ = AddInput(filter); @@ -465,21 +586,36 @@ class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI { BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); } - void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } - - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + SetData(input_, input_type_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + void SetFilter(std::initializer_list data) { + SetData(filter_, input_type_, data); + } + + void SetBias(std::initializer_list data) { + const auto bias_type = + (input_type_ == TensorType_FLOAT32) ? 
input_type_ : TensorType_INT32; + SetData(bias_, bias_type, data); + } + + std::vector GetOutput() { + if (input_type_ == TensorType_FLOAT32) { + return ExtractVector(output_); + } else { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + } protected: int input_; int filter_; int bias_; int output_; + + const TensorType input_type_; }; TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { @@ -508,6 +644,42 @@ TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { })); } +TEST(QuantizedDepthwiseConv2DTest, FilterMultiplierGreaterThan1) { + DepthwiseConvolutionOpModel quant_op( + {TensorType_UINT8, {1, 3, 2, 2}, -128.5, 128}, + {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128}, + {TensorType_UINT8, {}, -127, 128}); + DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + std::initializer_list input = { + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }; + std::initializer_list filter = { + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }; + std::initializer_list bias = {1, 2, 3, 4}; + + quant_op.SetInput(input); + quant_op.SetFilter(filter); + quant_op.SetBias(bias); + quant_op.Invoke(); + + float_op.SetInput(input); + float_op.SetFilter(filter); + float_op.SetBias(bias); + float_op.Invoke(); + + EXPECT_THAT(quant_op.GetOutput(), + ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); +} + class FullyConnectedOpModel : public SingleOpModelWithNNAPI { public: FullyConnectedOpModel( @@ -552,7 +724,14 @@ class FullyConnectedOpModel : public SingleOpModelWithNNAPI { SetData(bias_, bias_type, data); } - std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutput() { + if (input_type_ == TensorType_FLOAT32) { + return ExtractVector(output_); + } else { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + } protected: int input_; @@ -607,15 +786,41 @@ TEST(FullyConnectedOpTest, FloatInputQuantizedWeights) { ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}, 1.3))); } +TEST(FullyConnectedOpTest, QuantizedOutputMultiplierGreaterThan1) { + // real_multiplier = 2. 
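  // [Editorial note -- not part of the diff.] Rough arithmetic behind the
  // "real_multiplier = 2" figure above, assuming the usual uint8 mapping
  // scale = (max - min) / 255 for the (min, max) ranges passed below:
  //   input scale  = (128 - (-127)) / 255 = 1.0
  //   weight scale = (128 - (-127)) / 255 = 1.0
  //   output scale = (64 - (-63.5)) / 255 = 0.5
  //   real_multiplier = input_scale * weight_scale / output_scale = 2.0
  // This is standard affine-quantization bookkeeping, not a claim about the
  // kernel's exact implementation.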
+ FullyConnectedOpModel m( + /*input=*/{TensorType_UINT8, {2, 10}, -127, 128}, + /*weights=*/{TensorType_UINT8, {3, 10}, -127, 128}, + /*output=*/{TensorType_UINT8, {}, -63.5, 64}); + + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // first batch + 58, 59, 60, // second batch + }))); +} + class SoftmaxOpModel : public SingleOpModelWithNNAPI { public: - SoftmaxOpModel(int batches, int size, float beta) - : batches_(batches), input_size_(size), beta_(beta) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); + SoftmaxOpModel(const TensorData& input, float beta) { + input_ = AddInput(input); + output_ = AddOutput(input); SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, - CreateSoftmaxOptions(builder_, beta_).Union()); - BuildInterpreter({{batches_, input_size_}}); + CreateSoftmaxOptions(builder_, beta).Union()); + BuildInterpreter({GetShape(input_)}); } void SetInput(std::initializer_list data) { @@ -631,17 +836,13 @@ class SoftmaxOpModel : public SingleOpModelWithNNAPI { private: int input_; int output_; - - int batches_; - int input_size_; - float beta_; }; -TEST(NNAPIDelegate, SoftmaxSimpleTest) { - SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); +TEST(SoftmaxOpTest, SimpleTest) { + SoftmaxOpModel m({TensorType_FLOAT32, {2, 5}}, /*beta=*/1.0); m.SetInput({ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 - -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 }); m.Invoke(); @@ -654,6 +855,63 @@ TEST(NNAPIDelegate, SoftmaxSimpleTest) { 1e-6))); } +TEST(SoftmaxOpTest, Beta2) { + SoftmaxOpModel m({TensorType_FLOAT32, {1, 5}}, /*beta=*/2.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.000290076, 0.002143387, 0.015837606, 0.117024957, 0.864703974}, + 1e-6))); +} + +TEST(SoftmaxOpTest, 3dInput) { + SoftmaxOpModel m({TensorType_FLOAT32, {2, 2, 5}}, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + 5.0, 1.0, 2.0, 3.0, 4.0, // b = 1 + -5.0, -1.0, -2.0, -3.0, -4.0, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231, + 0.636408647, 0.011656231, 0.031684921, 0.086128544, 0.234121657, + 0.011656231, 0.636408647, 0.234121657, 0.086128544, 0.031684921}, + 1e-6))); +} + +TEST(SoftmaxOpTest, 4dInput) { + SoftmaxOpModel m({TensorType_FLOAT32, {2, 2, 1, 5}}, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + 5.0, 1.0, 2.0, 3.0, 4.0, // b = 1 + -5.0, -1.0, -2.0, -3.0, -4.0, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231, + 0.636408647, 0.011656231, 0.031684921, 0.086128544, 0.234121657, + 0.011656231, 0.636408647, 0.234121657, 0.086128544, 0.031684921}, + 1e-6))); +} + class ReshapeOpModel : public SingleOpModelWithNNAPI { public: 
ReshapeOpModel(std::initializer_list input_shape, @@ -714,7 +972,8 @@ class SqueezeOpModel : public SingleOpModelWithNNAPI { int output_; }; -TEST(NNAPIDelegate, SqueezeSimpleTest) { +// TODO(b/215935381): Enable after resolving issues with flakiness. +TEST(NNAPIDelegate, DISABLED_SqueezeSimpleTest) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -818,6 +1077,87 @@ TEST(NNAPIDelegate, TransposeSimpleTest) { 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23})); } +class ElementwiseOpBaseModel : public SingleOpModelWithNNAPI { + public: + int input() const { return input_; } + int output() const { return output_; } + + protected: + int input_; + int output_; +}; + +class ElementwiseOpFloatModel : public ElementwiseOpBaseModel { + public: + ElementwiseOpFloatModel(BuiltinOperator op, + std::initializer_list input_shape) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); + BuildInterpreter({input_shape}); + } +}; + +TEST(Elementwise, Abs) { + ElementwiseOpFloatModel m(BuiltinOperator_ABS, {1, 2, 4, 1}); + m.PopulateTensor(m.input(), { + 0.f, -6.2f, 2.f, 4.f, // + 3.f, -2.f, 10.f, 1.f, // + }); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), ElementsAreArray({ + 0.f, 6.2f, 2.f, 4.f, // + 3.f, 2.f, 10.f, 1.f, // + })); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 4, 1})); +} + +TEST(Elementwise, Exp) { + ElementwiseOpFloatModel m(BuiltinOperator_EXP, {3, 1, 2}); + m.PopulateTensor(m.input(), {1.0, 0.0, -1.0, 1.0, 1.0, -1.0}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear( + {2.71828, 1, 0.367879, 2.71828, 2.71828, 0.367879}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({3, 1, 2})); +} + +TEST(Elementwise, Log) { + ElementwiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(Elementwise, Rsqrt) { + ElementwiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 2, 4, 9}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({1, 0.7071, 0.5, 0.33333}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(Elementwise, Sin) { + ElementwiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 0, 0, 0.84147}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(Elementwise, Sqrt) { + ElementwiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {0, 1, 2, 4}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1, 1.41421, 2}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + class FloatSubOpModel : public SingleOpModelWithNNAPI { public: FloatSubOpModel(const TensorData& input1, const TensorData& input2, @@ -1417,11 +1757,10 @@ TEST(NNAPIDelegate, LogisticQuantized) { {128, 1, 227, 251, 244, 32, 255, 188})); } -#if 0 class ResizeBilinearOpModel : public 
SingleOpModelWithNNAPI { public: ResizeBilinearOpModel(const TensorData& input, - std::initializer_list size_data = {}) { + std::initializer_list size_data) { bool const_size = size_data.size() != 0; input_ = AddInput(input); if (const_size) { @@ -1457,14 +1796,16 @@ class ResizeBilinearOpModel : public SingleOpModelWithNNAPI { int output_; }; -TEST(NNAPIDelegate, ResizeBilinearHorizontal) { - ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); +TEST(ResizeBilinear, Horizontal) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}, {}); m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} +TEST(ResizeBilinear, HorizontalConstant) { ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); const_m.SetInput({3, 6}); const_m.Invoke(); @@ -1472,14 +1813,16 @@ TEST(NNAPIDelegate, ResizeBilinearHorizontal) { ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } -TEST(NNAPIDelegate, ResizeBilinearVertical) { - ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); +TEST(ResizeBilinear, Vertical) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}, {}); m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} +TEST(ResizeBilinear, VerticalConstant) { ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); const_m.SetInput({3, 9}); const_m.Invoke(); @@ -1487,8 +1830,8 @@ TEST(NNAPIDelegate, ResizeBilinearVertical) { ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } -TEST(NNAPIDelegate, ResizeBilinearTwoDimensional) { - ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); +TEST(ResizeBilinear, TwoDimensional) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {}); m.SetInput({ 3, 6, // 9, 12 // @@ -1500,7 +1843,9 @@ TEST(NNAPIDelegate, ResizeBilinearTwoDimensional) { 7, 9, 10, // 9, 11, 12, // }))); +} +TEST(ResizeBilinear, TwoDimensionalConstant) { ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); const_m.SetInput({ 3, 6, // @@ -1513,7 +1858,6 @@ TEST(NNAPIDelegate, ResizeBilinearTwoDimensional) { 9, 11, 12, // }))); } -#endif template class PadOpModel : public SingleOpModelWithNNAPI { @@ -1887,8 +2231,8 @@ static std::initializer_list rnn_bias = { class RNNOpModel : public SingleOpModelWithNNAPI { public: RNNOpModel(int batches, int units, int size, - const TensorType& weights = TensorType_FLOAT32, - const TensorType& recurrent_weights = TensorType_FLOAT32) + const TensorType weights = TensorType_FLOAT32, + const TensorType recurrent_weights = TensorType_FLOAT32) : batches_(batches), units_(units), input_size_(size) { input_ = AddInput(TensorType_FLOAT32); weights_ = AddInput(weights); @@ -2246,11 +2590,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - const TensorType& weight_type = TensorType_FLOAT32) + const TensorType weight_type) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), - n_output_(n_output) { + n_output_(n_output), + weight_type_(weight_type) { input_ = AddInput(TensorType_FLOAT32); if (use_cifg) { @@ -2309,10 +2654,30 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { } // Adding the 2 input state tensors. 
- input_activation_state_ = - AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true); - input_cell_state_ = - AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true); + input_activation_state_ = AddInput(TensorType_FLOAT32, true); + input_cell_state_ = AddInput(TensorType_FLOAT32, true); + + const bool use_layer_norm = input_shapes.size() > 20; + // Layer norm weights. + if (use_layer_norm) { + const int kInputLayerNormCoeffsIndex = 20; + const int kForgetLayerNormCoeffsIndex = 21; + const int kCellLayerNormCoeffsIndex = 22; + const int kOutputLayerNormCoeffsIndex = 23; + + if (use_cifg) { + input_layer_norm_coefficients_ = AddNullInput(); + } else { + input_layer_norm_coefficients_ = + AddLayerNormCoeffsTensor(kInputLayerNormCoeffsIndex, input_shapes); + } + forget_layer_norm_coefficients_ = + AddLayerNormCoeffsTensor(kForgetLayerNormCoeffsIndex, input_shapes); + cell_layer_norm_coefficients_ = + AddLayerNormCoeffsTensor(kCellLayerNormCoeffsIndex, input_shapes); + output_layer_norm_coefficients_ = + AddLayerNormCoeffsTensor(kOutputLayerNormCoeffsIndex, input_shapes); + } output_ = AddOutput(TensorType_FLOAT32); @@ -2323,72 +2688,90 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { BuildInterpreter(input_shapes); } - void SetInputToInputWeights(std::vector f) { - PopulateTensor(input_to_input_weights_, f); + void SetInputToInputWeights(const std::vector& f) { + SetData(input_to_input_weights_, weight_type_, f); } - void SetInputToForgetWeights(std::vector f) { - PopulateTensor(input_to_forget_weights_, f); + void SetInputToForgetWeights(const std::vector& f) { + SetData(input_to_forget_weights_, weight_type_, f); } - void SetInputToCellWeights(std::vector f) { - PopulateTensor(input_to_cell_weights_, f); + void SetInputToCellWeights(const std::vector& f) { + SetData(input_to_cell_weights_, weight_type_, f); } - void SetInputToOutputWeights(std::vector f) { - PopulateTensor(input_to_output_weights_, f); + void SetInputToOutputWeights(const std::vector& f) { + SetData(input_to_output_weights_, weight_type_, f); } - void SetRecurrentToInputWeights(std::vector f) { - PopulateTensor(recurrent_to_input_weights_, f); + void SetRecurrentToInputWeights(const std::vector& f) { + SetData(recurrent_to_input_weights_, weight_type_, f); } - void SetRecurrentToForgetWeights(std::vector f) { - PopulateTensor(recurrent_to_forget_weights_, f); + void SetRecurrentToForgetWeights(const std::vector& f) { + SetData(recurrent_to_forget_weights_, weight_type_, f); } - void SetRecurrentToCellWeights(std::vector f) { - PopulateTensor(recurrent_to_cell_weights_, f); + void SetRecurrentToCellWeights(const std::vector& f) { + SetData(recurrent_to_cell_weights_, weight_type_, f); } - void SetRecurrentToOutputWeights(std::vector f) { - PopulateTensor(recurrent_to_output_weights_, f); + void SetRecurrentToOutputWeights(const std::vector& f) { + SetData(recurrent_to_output_weights_, weight_type_, f); } - void SetCellToInputWeights(std::vector f) { - PopulateTensor(cell_to_input_weights_, f); + void SetCellToInputWeights(const std::vector& f) { + SetData(cell_to_input_weights_, weight_type_, f); } - void SetCellToForgetWeights(std::vector f) { - PopulateTensor(cell_to_forget_weights_, f); + void SetCellToForgetWeights(const std::vector& f) { + SetData(cell_to_forget_weights_, weight_type_, f); } - void SetCellToOutputWeights(std::vector f) { - PopulateTensor(cell_to_output_weights_, f); + void SetCellToOutputWeights(const std::vector& f) { + SetData(cell_to_output_weights_, weight_type_, f); 
} - void SetInputGateBias(std::vector f) { + void SetInputGateBias(const std::vector& f) { PopulateTensor(input_gate_bias_, f); } - void SetForgetGateBias(std::vector f) { + void SetForgetGateBias(const std::vector& f) { PopulateTensor(forget_gate_bias_, f); } - void SetCellBias(std::vector f) { PopulateTensor(cell_bias_, f); } + void SetCellBias(const std::vector& f) { + PopulateTensor(cell_bias_, f); + } - void SetOutputGateBias(std::vector f) { + void SetOutputGateBias(const std::vector& f) { PopulateTensor(output_gate_bias_, f); } - void SetProjectionWeights(std::vector f) { - PopulateTensor(projection_weights_, f); + void SetProjectionWeights(const std::vector& f) { + SetData(projection_weights_, weight_type_, f); } - void SetProjectionBias(std::vector f) { + void SetProjectionBias(const std::vector& f) { PopulateTensor(projection_bias_, f); } + void SetInputLayerNormCoefficients(std::vector f) { + PopulateTensor(input_layer_norm_coefficients_, f); + } + + void SetForgetLayerNormCoefficients(std::vector f) { + PopulateTensor(forget_layer_norm_coefficients_, f); + } + + void SetCellLayerNormCoefficients(std::vector f) { + PopulateTensor(cell_layer_norm_coefficients_, f); + } + + void SetOutputLayerNormCoefficients(std::vector f) { + PopulateTensor(output_layer_norm_coefficients_, f); + } + void SetInput(int offset, const float* begin, const float* end) { PopulateTensor(input_, offset, const_cast(begin), const_cast(end)); @@ -2427,6 +2810,11 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { int input_activation_state_; int input_cell_state_; + int input_layer_norm_coefficients_; + int forget_layer_norm_coefficients_; + int cell_layer_norm_coefficients_; + int output_layer_norm_coefficients_; + int output_; int output_state_; int cell_state_; @@ -2435,6 +2823,18 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { int n_input_; int n_cell_; int n_output_; + + private: + const TensorType weight_type_; + + int AddLayerNormCoeffsTensor( + int tensor_index, const std::vector>& input_shapes) { + if (input_shapes[tensor_index][0] != 0) { + return AddInput(TensorType_FLOAT32); + } else { + return AddNullInput(); + } + } }; class BaseLstmTest : public ::testing::Test { @@ -2456,6 +2856,10 @@ class BaseLstmTest : public ::testing::Test { std::vector cell_to_forget_weights_; std::vector cell_to_output_weights_; std::vector projection_weights_; + std::vector input_layer_norm_coefficients_; + std::vector forget_layer_norm_coefficients_; + std::vector cell_layer_norm_coefficients_; + std::vector output_layer_norm_coefficients_; // LSTM input is stored as num_batch x num_inputs vector. 
std::vector> lstm_input_; @@ -2580,7 +2984,80 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { {0, 0}, // projection_weight tensor {0}, // projection_bias tensor - }); + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + }, + /*weight_type=*/TensorType_FLOAT32); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest + : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest, + LstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {0}, // input_layer_norm_coefficient tensor + {0}, // forget_layer_norm_coefficient tensor + {0}, // cell_layer_norm_coefficient tensor + {0}, // output_layer_norm_coefficient tensor + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -2683,7 +3160,11 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { {0, 0}, // projection_weight tensor {0}, // projection_bias tensor - }); + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_); @@ -3337,7 +3818,11 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { {n_output, n_cell}, // projection_weight tensor {0}, // projection_bias tensor - }); + + {n_batch, n_output}, 
// activation_state tensor + {n_batch, n_cell}, // cell_state tensor + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -3363,12 +3848,310 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } +class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest + : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, + 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5, + -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; + + input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, + -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, + -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + + input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, + -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, + -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + + input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, + -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, + -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + input_gate_bias_ = {0.03, 0.15, 0.22, 0.38}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + + cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, + -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; + + recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, + -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + + recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, + 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + + recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, + -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5}; + forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5}; + + projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, + 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + lstm_input_ = { + {// Batch0: 3 (input_sequence_size) * 5 (n_input) + 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 + 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 + 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 + + {// Batch1: 3 (input_sequence_size) * 5 (n_input) + 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 + 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 + 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2 + }; + } +}; + +TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, + LayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + LSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + 
{n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_cell}, // input_layer_norm_coefficient tensor + {n_cell}, // forget_layer_norm_coefficient tensor + {n_cell}, // cell_layer_norm_coefficient tensor + {n_cell}, // output_layer_norm_coefficient tensor + }, + /*weight_type=*/TensorType_FLOAT32); + + layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetInputGateBias(input_gate_bias_); + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_); + layer_norm_lstm.SetForgetLayerNormCoefficients( + forget_layer_norm_coefficients_); + layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_); + layer_norm_lstm.SetOutputLayerNormCoefficients( + output_layer_norm_coefficients_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + // Verify the final output. 
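  // [Editorial note -- not part of the diff.] As a reading aid: the
  // *_layer_norm_coefficients_ vectors set above play the role of the
  // per-gate scale factors (gamma) in a standard layer normalization,
  //   y_i = gamma_i * (x_i - mean(x)) / sqrt(var(x) + eps),
  // applied to each gate's pre-activation. That is the textbook form of
  // layer normalization, not a statement about the exact kernel code.
  // The expected (golden) outputs for this layer-norm LSTM follow.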
+ const std::vector> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0244077, 0.128027, -0.00170918, // seq 0 + 0.0137642, 0.140751, 0.0395835, // seq 1 + -0.00459231, 0.155278, 0.0837377, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.00692428, 0.0848741, 0.063445, // seq 0 + -0.00403912, 0.139963, 0.072681, // seq 1 + 0.00752706, 0.161903, 0.0561371, // seq 2 + }}; + + VerifyGoldens(lstm_input_, layer_norm_lstm_golden_output, &layer_norm_lstm); +} + +class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, + -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, + -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, + -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, + -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, + -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, + -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, + -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, + 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, + -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5}; + projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, + 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + lstm_input_ = { + {// Batch0: 3 (input_sequence_size) * 5 (n_input) + 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 + 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 + 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 + + {// Batch1: 3 (input_sequence_size) * 5 (n_input) + 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 + 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 + 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2 + }; + } +}; + +TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, + LayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + LSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // 
projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {0}, // input_layer_norm_coefficient tensor + {n_cell}, // forget_layer_norm_coefficient tensor + {n_cell}, // cell_layer_norm_coefficient tensor + {n_cell}, // output_layer_norm_coefficient tensor + }, + /*weight_type=*/TensorType_FLOAT32); + + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetForgetLayerNormCoefficients( + forget_layer_norm_coefficients_); + layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_); + layer_norm_lstm.SetOutputLayerNormCoefficients( + output_layer_norm_coefficients_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + // Verify the final output. + const std::vector> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.02129706, 0.140816242, 0.0112733059, // seq 0 + 0.0132302344, 0.152308047, 0.0346313119, // seq 1 + -0.0123688057, 0.165790111, 0.0893077999, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.0226350538, 0.0916948169, 0.0769175813, // seq 0 + -0.0269966982, 0.149707705, 0.094149217, // seq 1 + -0.0103429332, 0.173016444, 0.0720508844, // seq 2 + }}; + + VerifyGoldens(lstm_input_, layer_norm_lstm_golden_output, &layer_norm_lstm); +} + class BaseReduceOpModel : public SingleOpModelWithNNAPI { public: void SetAxis(const std::vector& data) { PopulateTensor(axis_, data); } template - void SetInput(std::vector data) { + void SetInput(const std::vector& data) { PopulateTensor(input_, data); } @@ -3605,6 +4388,85 @@ TEST(NNAPIDelegate, HashtableLookupTest1DInput) { 1, })); } + +// A base class of PRelu op model. It provides the constructor for +// FloatPReluOpModel and QuantizedPReluOpModel. 
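// [Editorial sketch -- not part of the diff.] For readers of the PRelu tests
// below: PReLU computes f(x) = x for x >= 0 and f(x) = alpha * x for x < 0,
// with alpha broadcast over the input (here a per-channel alpha of shape
// {1, 1, 3} against a {1, 2, 2, 3} input). A minimal standalone illustration
// with purely illustrative names, assuming only <vector> (already available
// in this test file):
static std::vector<float> PReluSketch(const std::vector<float>& input,
                                      const std::vector<float>& alpha) {
  std::vector<float> out(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    const float a = alpha[i % alpha.size()];  // Broadcast alpha per channel.
    out[i] = input[i] >= 0.f ? input[i] : a * input[i];
  }
  return out;
}
// E.g. an input row {-1, -1, -1} with alpha {0, 1, 2} gives {0, -1, -2},
// matching the expectations in NNAPIDelegate.PReluFloat below.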
+class PReluOpModel : public SingleOpModelWithNNAPI { + public: + PReluOpModel(const TensorData& input, const TensorData& alpha) + : input_type_(input.type) { + input_ = AddInput(input); + alpha_ = AddInput(alpha); + output_ = AddOutput({input.type, input.shape, input.min, input.max}); + SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_), GetShape(alpha_)}); + } + + void SetInput(std::initializer_list data) { + SetData(input_, input_type_, data); + } + + void SetAlpha(std::initializer_list data) { + SetData(alpha_, input_type_, data); + } + + std::vector GetOutput() { + std::vector output; + GetData(output_, input_type_, &output); + return output; + } + + protected: + int input_; + int alpha_; + int output_; + + const TensorType input_type_; +}; + +TEST(NNAPIDelegate, PReluFloat) { + PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}}, + {TensorType_FLOAT32, {1, 1, 3}}); + + m.SetInput({ + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 1.0f, 1.0f, 1.0f, // Row 1, Column 2 + -1.0f, -1.0f, -1.0f, // Row 2, Column 1 + -2.0f, -2.0f, -2.0f, // Row 1, Column 2 + }); + m.SetAlpha({0.0f, 1.0f, 2.0f}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 1.0f, 1.0f, 1.0f, // Row 1, Column 2 + 0.0f, -1.0f, -2.0f, // Row 2, Column 1 + 0.0f, -2.0f, -4.0f, // Row 1, Column 2 + })); +} + +TEST(NNAPIDelegate, PReluQuantized) { + const float kMin = -1; + const float kMax = 127.f / 128.f; + PReluOpModel m({TensorType_UINT8, {1, 2, 2, 3}, kMin, kMax}, + {TensorType_UINT8, {1, 1, 3}, kMin, kMax}); + m.SetInput({ + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 0.5f, 0.5f, 0.5f, // Row 1, Column 2 + -1.0f, -1.0f, -1.0f, // Row 2, Column 1 + -0.25f, -0.25f, -0.25f, // Row 1, Column 2 + }); + m.SetAlpha({0.0f, 0.5f, -0.5f}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 0.5f, 0.5f, 0.5f, // Row 1, Column 2 + 0.0f, -0.5f, 0.5f, // Row 2, Column 1 + 0.0f, -0.125f, 0.125f, // Row 1, Column 2 + }, + kQuantizedTolerance))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/examples/android/BUILD b/tensorflow/lite/examples/android/BUILD deleted file mode 100644 index 80cefd415a5..00000000000 --- a/tensorflow/lite/examples/android/BUILD +++ /dev/null @@ -1,61 +0,0 @@ -# Description: -# TensorFlow camera demo app for Android. - -load("@build_bazel_rules_android//android:rules.bzl", "android_binary") - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -# Build the demo native demo lib from the original directory to reduce code -# reuse. Note that the Java counterparts (ObjectTracker.java and -# ImageUtils.java) are still duplicated. -cc_library( - name = "tensorflow_native_libs", - srcs = [ - "//tensorflow/examples/android:libtensorflow_demo.so", - ], - tags = [ - "manual", - "notap", - ], -) - -android_binary( - name = "tflite_demo", - srcs = glob([ - "app/src/main/java/**/*.java", - ]), - aapt_version = "aapt", - # Package assets from assets dir as well as all model targets. - # Remove undesired models (and corresponding Activities in source) - # to reduce APK size. 
- assets = [ - "//tensorflow/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", - "@tflite_mobilenet_quant//:mobilenet_v1_1.0_224_quant.tflite", - "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", - "//tensorflow/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", - "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", - "@tflite_mobilenet_ssd_quant//:detect.tflite", - "//tensorflow/lite/examples/android/app/src/main/assets:box_priors.txt", - "//tensorflow/lite/examples/android/app/src/main/assets:coco_labels_list.txt", - ], - assets_dir = "", - custom_package = "org.tensorflow.lite.demo", - inline_constants = 1, - manifest = "app/src/main/AndroidManifest.xml", - nocompress_extensions = [ - ".tflite", - ], - resource_files = glob(["app/src/main/res/**"]), - tags = [ - "manual", - "notap", - ], - deps = [ - ":tensorflow_native_libs", - "//tensorflow/lite/java:tensorflowlite", - ], -) diff --git a/tensorflow/lite/examples/android/android.iml b/tensorflow/lite/examples/android/android.iml deleted file mode 100644 index f0a5ac2bf4c..00000000000 --- a/tensorflow/lite/examples/android/android.iml +++ /dev/null @@ -1,19 +0,0 @@ - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/tensorflow/lite/examples/android/app/README.md b/tensorflow/lite/examples/android/app/README.md index e2b1b2691bb..0da8d13fcce 100644 --- a/tensorflow/lite/examples/android/app/README.md +++ b/tensorflow/lite/examples/android/app/README.md @@ -1,54 +1,9 @@ -# TF Lite Android App Example +# TF Lite Android Example (Deprecated) -A simple Android example that demonstrates image classification and object -detection using the camera, as well as speech recognition using the microphone. +This example has been moved to the new +[TensorFlow examples repo](https://github.com/tensorflow/examples), and split +into several distinct examples: -## Building in Android Studio with TensorFlow Lite AAR from JCenter. -The build.gradle is configured to use TensorFlow Lite's nightly build. - -If you see a build error related to compatibility with Tensorflow Lite's Java -API (example: method X is undefined for type Interpreter), there has likely been -a backwards compatible change to the API. You will need to pull new app code -that's compatible with the nightly build and may need to first wait a few days -for our external and internal code to merge. - -## Building from Source with Bazel - -1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): - - 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites). - It's easiest with Android Studio. - - - You'll need at least SDK version 23. - - Make sure to install the latest version of Bazel. Some distributions - ship with Bazel 0.5.4, which is too old. - - Bazel requires Android Build Tools `26.0.1` or higher. - - You also need to install the Android Support Repository, available - through Android Studio under `Android SDK Manager -> SDK Tools -> - Android Support Repository`. - - 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace) - to add SDK and NDK targets. - - NOTE: As long as you have the SDK and NDK installed, the `./configure` - script will create these rules for you. Answer "Yes" when the script asks - to automatically configure the `./WORKSPACE`. 
- - - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that - you have installed. - - By default, Android Studio will install the SDK to `~/Android/Sdk` and - the NDK to `~/Android/Sdk/ndk-bundle`. - -2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device: - - ```shell - bazel build -c opt --cxxopt='--std=c++11' --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ - //tensorflow/lite/examples/android:tflite_demo - ``` - -3. Install the demo on a - [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install): - - ```shell - adb install bazel-bin/tensorflow/lite/examples/android/tflite_demo.apk - ``` +* [Image Classification](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android) +* [Object Detection](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android) +* [Speech Commands](https://github.com/tensorflow/examples/tree/master/lite/examples/speech_commands/android) diff --git a/tensorflow/lite/examples/android/app/build.gradle b/tensorflow/lite/examples/android/app/build.gradle deleted file mode 100644 index d2bc9846af5..00000000000 --- a/tensorflow/lite/examples/android/app/build.gradle +++ /dev/null @@ -1,50 +0,0 @@ -apply plugin: 'com.android.application' - -// import DownloadModels task -project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' -project.ext.TMP_DIR = project.buildDir.toString() + '/downloads' - -// Download default models; if you wish to use your own models then -// place them in the "assets" directory and comment out this line. 
-apply from: "download-models.gradle" - -android { - compileSdkVersion 26 - buildToolsVersion '28.0.3' - defaultConfig { - applicationId "org.tensorflow.lite.demo" - minSdkVersion 15 - targetSdkVersion 26 - versionCode 1 - versionName "1.0" - - } - lintOptions { - abortOnError false - } - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' - } - } - aaptOptions { - noCompress "tflite" - } - - compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 - } -} - -repositories { - maven { - url 'https://google.bintray.com/tensorflow' - } -} - -dependencies { - implementation fileTree(dir: 'libs', include: ['*.jar']) - implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly' -} diff --git a/tensorflow/lite/examples/android/app/download-models.gradle b/tensorflow/lite/examples/android/app/download-models.gradle deleted file mode 100644 index 514eeb01350..00000000000 --- a/tensorflow/lite/examples/android/app/download-models.gradle +++ /dev/null @@ -1,78 +0,0 @@ -/* - * download-models.gradle - * Downloads model files from ${MODEL_URL} into application's asset folder - * Input: - * project.ext.TMP_DIR: absolute path to hold downloaded zip files - * project.ext.ASSET_DIR: absolute path to save unzipped model files - * Output: - * 3 model files will be downloaded into given folder of ext.ASSET_DIR - */ -// hard coded model files - -def models = ['https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip', - 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip', - 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip', - 'http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz', - 'http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz'] - -// Root URL for model archives -def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite' - -buildscript { - repositories { - jcenter() - } - dependencies { - classpath 'de.undercouch:gradle-download-task:3.2.0' - } -} - -import de.undercouch.gradle.tasks.download.Download -task downloadFile(type: Download){ - for (modelUrl in models) { - def localFile = modelUrl.split("/")[-1] - println "Downloading ${localFile} from ${modelUrl}" - src modelUrl - } - - dest new File(project.ext.TMP_DIR) - overwrite true -} - -task extractModels(type: Copy) { - for (f in models) { - def localFile = f.split("/")[-1] - def localExt = localFile.split("[.]")[-1] - if (localExt == "tgz") { - from tarTree(project.ext.TMP_DIR + '/' + localFile) - } else { - from zipTree(project.ext.TMP_DIR + '/' + localFile) - } - } - - into file(project.ext.ASSET_DIR) - fileMode 0644 - exclude '**/LICENSE' - - def needDownload = false - for (f in models) { - def localFile = f.split("/")[-1] - if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) { - needDownload = true - } - } - - if (needDownload) { - dependsOn downloadFile - } -} - -tasks.whenTaskAdded { task -> - if (task.name == 'assembleDebug') { - task.dependsOn 'extractModels' - } - if (task.name == 'assembleRelease') { - task.dependsOn 'extractModels' - } -} - diff --git a/tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml b/tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml deleted file mode 100644 index d4c98c61cca..00000000000 --- 
a/tensorflow/lite/examples/android/app/src/main/AndroidManifest.xml +++ /dev/null @@ -1,60 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/lite/examples/android/app/src/main/assets/coco_labels_list.txt b/tensorflow/lite/examples/android/app/src/main/assets/coco_labels_list.txt deleted file mode 100644 index 5a70ff82aa7..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/assets/coco_labels_list.txt +++ /dev/null @@ -1,91 +0,0 @@ -??? -person -bicycle -car -motorcycle -airplane -bus -train -truck -boat -traffic light -fire hydrant -??? -stop sign -parking meter -bench -bird -cat -dog -horse -sheep -cow -elephant -bear -zebra -giraffe -??? -backpack -umbrella -??? -??? -handbag -tie -suitcase -frisbee -skis -snowboard -sports ball -kite -baseball bat -baseball glove -skateboard -surfboard -tennis racket -bottle -??? -wine glass -cup -fork -knife -spoon -bowl -banana -apple -sandwich -orange -broccoli -carrot -hot dog -pizza -donut -cake -chair -couch -potted plant -bed -??? -dining table -??? -??? -toilet -??? -tv -laptop -mouse -remote -keyboard -cell phone -microwave -oven -toaster -sink -refrigerator -??? -book -clock -vase -scissors -teddy bear -hair drier -toothbrush diff --git a/tensorflow/lite/examples/android/app/src/main/assets/conv_actions_labels.txt b/tensorflow/lite/examples/android/app/src/main/assets/conv_actions_labels.txt deleted file mode 100644 index ba416458b01..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/assets/conv_actions_labels.txt +++ /dev/null @@ -1,12 +0,0 @@ -_silence_ -_unknown_ -yes -no -up -down -left -right -on -off -stop -go \ No newline at end of file diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java deleted file mode 100644 index eff24afdba4..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.demo; - -import android.content.Context; -import android.util.AttributeSet; -import android.view.TextureView; - -/** - * A {@link TextureView} that can be adjusted to a specified aspect ratio. - */ -public class AutoFitTextureView extends TextureView { - private int ratioWidth = 0; - private int ratioHeight = 0; - - public AutoFitTextureView(final Context context) { - this(context, null); - } - - public AutoFitTextureView(final Context context, final AttributeSet attrs) { - this(context, attrs, 0); - } - - public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) { - super(context, attrs, defStyle); - } - - /** - * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio - * calculated from the parameters. 
Note that the actual sizes of parameters don't matter, that - * is, calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. - * - * @param width Relative horizontal size - * @param height Relative vertical size - */ - public void setAspectRatio(final int width, final int height) { - if (width < 0 || height < 0) { - throw new IllegalArgumentException("Size cannot be negative."); - } - ratioWidth = width; - ratioHeight = height; - requestLayout(); - } - - @Override - protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { - super.onMeasure(widthMeasureSpec, heightMeasureSpec); - final int width = MeasureSpec.getSize(widthMeasureSpec); - final int height = MeasureSpec.getSize(heightMeasureSpec); - if (0 == ratioWidth || 0 == ratioHeight) { - setMeasuredDimension(width, height); - } else { - if (width < height * ratioWidth / ratioHeight) { - setMeasuredDimension(width, width * ratioHeight / ratioWidth); - } else { - setMeasuredDimension(height * ratioWidth / ratioHeight, height); - } - } - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java deleted file mode 100644 index 15d5456f027..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java +++ /dev/null @@ -1,450 +0,0 @@ -/* - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.demo; - -import android.Manifest; -import android.app.Activity; -import android.app.Fragment; -import android.content.Context; -import android.content.pm.PackageManager; -import android.hardware.Camera; -import android.hardware.camera2.CameraAccessException; -import android.hardware.camera2.CameraCharacteristics; -import android.hardware.camera2.CameraManager; -import android.hardware.camera2.params.StreamConfigurationMap; -import android.media.Image; -import android.media.Image.Plane; -import android.media.ImageReader; -import android.media.ImageReader.OnImageAvailableListener; -import android.os.Build; -import android.os.Bundle; -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Trace; -import android.util.Size; -import android.view.KeyEvent; -import android.view.Surface; -import android.view.WindowManager; -import android.widget.Toast; -import java.nio.ByteBuffer; -import org.tensorflow.demo.env.ImageUtils; -import org.tensorflow.demo.env.Logger; -import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds. 
- -public abstract class CameraActivity extends Activity - implements OnImageAvailableListener, Camera.PreviewCallback { - private static final Logger LOGGER = new Logger(); - - private static final int PERMISSIONS_REQUEST = 1; - - private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA; - private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE; - - private boolean debug = false; - - private Handler handler; - private HandlerThread handlerThread; - private boolean useCamera2API; - private boolean isProcessingFrame = false; - private byte[][] yuvBytes = new byte[3][]; - private int[] rgbBytes = null; - private int yRowStride; - - protected int previewWidth = 0; - protected int previewHeight = 0; - - private Runnable postInferenceCallback; - private Runnable imageConverter; - - @Override - protected void onCreate(final Bundle savedInstanceState) { - LOGGER.d("onCreate " + this); - super.onCreate(null); - getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON); - - setContentView(R.layout.activity_camera); - - if (hasPermission()) { - setFragment(); - } else { - requestPermission(); - } - } - - - protected int[] getRgbBytes() { - imageConverter.run(); - return rgbBytes; - } - - protected int getLuminanceStride() { - return yRowStride; - } - - protected byte[] getLuminance() { - return yuvBytes[0]; - } - - /** - * Callback for android.hardware.Camera API - */ - @Override - public void onPreviewFrame(final byte[] bytes, final Camera camera) { - if (isProcessingFrame) { - LOGGER.w("Dropping frame!"); - return; - } - - try { - // Initialize the storage bitmaps once when the resolution is known. - if (rgbBytes == null) { - Camera.Size previewSize = camera.getParameters().getPreviewSize(); - previewHeight = previewSize.height; - previewWidth = previewSize.width; - rgbBytes = new int[previewWidth * previewHeight]; - onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90); - } - } catch (final Exception e) { - LOGGER.e(e, "Exception!"); - return; - } - - isProcessingFrame = true; - yuvBytes[0] = bytes; - yRowStride = previewWidth; - - imageConverter = - new Runnable() { - @Override - public void run() { - ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes); - } - }; - - postInferenceCallback = - new Runnable() { - @Override - public void run() { - camera.addCallbackBuffer(bytes); - isProcessingFrame = false; - } - }; - processImage(); - } - - /** - * Callback for Camera2 API - */ - @Override - public void onImageAvailable(final ImageReader reader) { - //We need wait until we have some size from onPreviewSizeChosen - if (previewWidth == 0 || previewHeight == 0) { - return; - } - if (rgbBytes == null) { - rgbBytes = new int[previewWidth * previewHeight]; - } - try { - final Image image = reader.acquireLatestImage(); - - if (image == null) { - return; - } - - if (isProcessingFrame) { - image.close(); - return; - } - isProcessingFrame = true; - Trace.beginSection("imageAvailable"); - final Plane[] planes = image.getPlanes(); - fillBytes(planes, yuvBytes); - yRowStride = planes[0].getRowStride(); - final int uvRowStride = planes[1].getRowStride(); - final int uvPixelStride = planes[1].getPixelStride(); - - imageConverter = - new Runnable() { - @Override - public void run() { - ImageUtils.convertYUV420ToARGB8888( - yuvBytes[0], - yuvBytes[1], - yuvBytes[2], - previewWidth, - previewHeight, - yRowStride, - uvRowStride, - uvPixelStride, - rgbBytes); - } - }; - - postInferenceCallback = - new 
Runnable() { - @Override - public void run() { - image.close(); - isProcessingFrame = false; - } - }; - - processImage(); - } catch (final Exception e) { - LOGGER.e(e, "Exception!"); - Trace.endSection(); - return; - } - Trace.endSection(); - } - - @Override - public synchronized void onStart() { - LOGGER.d("onStart " + this); - super.onStart(); - } - - @Override - public synchronized void onResume() { - LOGGER.d("onResume " + this); - super.onResume(); - - handlerThread = new HandlerThread("inference"); - handlerThread.start(); - handler = new Handler(handlerThread.getLooper()); - } - - @Override - public synchronized void onPause() { - LOGGER.d("onPause " + this); - - if (!isFinishing()) { - LOGGER.d("Requesting finish"); - finish(); - } - - handlerThread.quitSafely(); - try { - handlerThread.join(); - handlerThread = null; - handler = null; - } catch (final InterruptedException e) { - LOGGER.e(e, "Exception!"); - } - - super.onPause(); - } - - @Override - public synchronized void onStop() { - LOGGER.d("onStop " + this); - super.onStop(); - } - - @Override - public synchronized void onDestroy() { - LOGGER.d("onDestroy " + this); - super.onDestroy(); - } - - protected synchronized void runInBackground(final Runnable r) { - if (handler != null) { - handler.post(r); - } - } - - @Override - public void onRequestPermissionsResult( - final int requestCode, final String[] permissions, final int[] grantResults) { - if (requestCode == PERMISSIONS_REQUEST) { - if (grantResults.length > 0 - && grantResults[0] == PackageManager.PERMISSION_GRANTED - && grantResults[1] == PackageManager.PERMISSION_GRANTED) { - setFragment(); - } else { - requestPermission(); - } - } - } - - private boolean hasPermission() { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { - return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED && - checkSelfPermission(PERMISSION_STORAGE) == PackageManager.PERMISSION_GRANTED; - } else { - return true; - } - } - - private void requestPermission() { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { - if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA) || - shouldShowRequestPermissionRationale(PERMISSION_STORAGE)) { - Toast.makeText(CameraActivity.this, - "Camera AND storage permission are required for this demo", Toast.LENGTH_LONG).show(); - } - requestPermissions(new String[] {PERMISSION_CAMERA, PERMISSION_STORAGE}, PERMISSIONS_REQUEST); - } - } - - // Returns true if the device supports the required hardware level, or better. - private boolean isHardwareLevelSupported( - CameraCharacteristics characteristics, int requiredLevel) { - int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL); - if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) { - return requiredLevel == deviceLevel; - } - // deviceLevel is not LEGACY, can use numerical sort - return requiredLevel <= deviceLevel; - } - - private String chooseCamera() { - final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE); - try { - for (final String cameraId : manager.getCameraIdList()) { - final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); - - // We don't use a front facing camera in this sample. 
- final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); - if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { - continue; - } - - final StreamConfigurationMap map = - characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); - - if (map == null) { - continue; - } - - // Fallback to camera1 API for internal cameras that don't have full support. - // This should help with legacy situations where using the camera2 API causes - // distorted or otherwise broken previews. - useCamera2API = (facing == CameraCharacteristics.LENS_FACING_EXTERNAL) - || isHardwareLevelSupported(characteristics, - CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL); - LOGGER.i("Camera API lv2?: %s", useCamera2API); - return cameraId; - } - } catch (CameraAccessException e) { - LOGGER.e(e, "Not allowed to access camera"); - } - - return null; - } - - protected void setFragment() { - String cameraId = chooseCamera(); - - Fragment fragment; - if (useCamera2API) { - CameraConnectionFragment camera2Fragment = - CameraConnectionFragment.newInstance( - new CameraConnectionFragment.ConnectionCallback() { - @Override - public void onPreviewSizeChosen(final Size size, final int rotation) { - previewHeight = size.getHeight(); - previewWidth = size.getWidth(); - CameraActivity.this.onPreviewSizeChosen(size, rotation); - } - }, - this, - getLayoutId(), - getDesiredPreviewFrameSize()); - - camera2Fragment.setCamera(cameraId); - fragment = camera2Fragment; - } else { - fragment = - new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize()); - } - - getFragmentManager() - .beginTransaction() - .replace(R.id.container, fragment) - .commit(); - } - - protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) { - // Because of the variable row stride it's not possible to know in - // advance the actual necessary dimensions of the yuv planes. 
- for (int i = 0; i < planes.length; ++i) { - final ByteBuffer buffer = planes[i].getBuffer(); - if (yuvBytes[i] == null) { - LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity()); - yuvBytes[i] = new byte[buffer.capacity()]; - } - buffer.get(yuvBytes[i]); - } - } - - public boolean isDebug() { - return debug; - } - - public void requestRender() { - final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay); - if (overlay != null) { - overlay.postInvalidate(); - } - } - - public void addCallback(final OverlayView.DrawCallback callback) { - final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay); - if (overlay != null) { - overlay.addCallback(callback); - } - } - - public void onSetDebug(final boolean debug) {} - - @Override - public boolean onKeyDown(final int keyCode, final KeyEvent event) { - if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP) { - debug = !debug; - requestRender(); - onSetDebug(debug); - return true; - } - return super.onKeyDown(keyCode, event); - } - - protected void readyForNextImage() { - if (postInferenceCallback != null) { - postInferenceCallback.run(); - } - } - - protected int getScreenOrientation() { - switch (getWindowManager().getDefaultDisplay().getRotation()) { - case Surface.ROTATION_270: - return 270; - case Surface.ROTATION_180: - return 180; - case Surface.ROTATION_90: - return 90; - default: - return 0; - } - } - - protected abstract void processImage(); - - protected abstract void onPreviewSizeChosen(final Size size, final int rotation); - protected abstract int getLayoutId(); - protected abstract Size getDesiredPreviewFrameSize(); -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java deleted file mode 100644 index 51a1adb538e..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java +++ /dev/null @@ -1,634 +0,0 @@ -/* - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.tensorflow.demo; - -import android.app.Activity; -import android.app.AlertDialog; -import android.app.Dialog; -import android.app.DialogFragment; -import android.app.Fragment; -import android.content.Context; -import android.content.DialogInterface; -import android.content.res.Configuration; -import android.graphics.ImageFormat; -import android.graphics.Matrix; -import android.graphics.RectF; -import android.graphics.SurfaceTexture; -import android.hardware.camera2.CameraAccessException; -import android.hardware.camera2.CameraCaptureSession; -import android.hardware.camera2.CameraCharacteristics; -import android.hardware.camera2.CameraDevice; -import android.hardware.camera2.CameraManager; -import android.hardware.camera2.CaptureRequest; -import android.hardware.camera2.CaptureResult; -import android.hardware.camera2.TotalCaptureResult; -import android.hardware.camera2.params.StreamConfigurationMap; -import android.media.ImageReader; -import android.media.ImageReader.OnImageAvailableListener; -import android.os.Bundle; -import android.os.Handler; -import android.os.HandlerThread; -import android.text.TextUtils; -import android.util.Size; -import android.util.SparseIntArray; -import android.view.LayoutInflater; -import android.view.Surface; -import android.view.TextureView; -import android.view.View; -import android.view.ViewGroup; -import android.widget.Toast; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; -import org.tensorflow.demo.env.Logger; -import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds. - -public class CameraConnectionFragment extends Fragment { - private static final Logger LOGGER = new Logger(); - - /** - * The camera preview size will be chosen to be the smallest frame by pixel size capable of - * containing a DESIRED_SIZE x DESIRED_SIZE square. - */ - private static final int MINIMUM_PREVIEW_SIZE = 320; - - /** - * Conversion from screen rotation to JPEG orientation. - */ - private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); - private static final String FRAGMENT_DIALOG = "dialog"; - - static { - ORIENTATIONS.append(Surface.ROTATION_0, 90); - ORIENTATIONS.append(Surface.ROTATION_90, 0); - ORIENTATIONS.append(Surface.ROTATION_180, 270); - ORIENTATIONS.append(Surface.ROTATION_270, 180); - } - - /** - * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a - * {@link TextureView}. - */ - private final TextureView.SurfaceTextureListener surfaceTextureListener = - new TextureView.SurfaceTextureListener() { - @Override - public void onSurfaceTextureAvailable( - final SurfaceTexture texture, final int width, final int height) { - openCamera(width, height); - } - - @Override - public void onSurfaceTextureSizeChanged( - final SurfaceTexture texture, final int width, final int height) { - configureTransform(width, height); - } - - @Override - public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { - return true; - } - - @Override - public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} - }; - - /** - * Callback for Activities to use to initialize their data once the - * selected preview size is known. - */ - public interface ConnectionCallback { - void onPreviewSizeChosen(Size size, int cameraRotation); - } - - /** - * ID of the current {@link CameraDevice}. 
- */ - private String cameraId; - - /** - * An {@link AutoFitTextureView} for camera preview. - */ - private AutoFitTextureView textureView; - - /** - * A {@link CameraCaptureSession } for camera preview. - */ - private CameraCaptureSession captureSession; - - /** - * A reference to the opened {@link CameraDevice}. - */ - private CameraDevice cameraDevice; - - /** - * The rotation in degrees of the camera sensor from the display. - */ - private Integer sensorOrientation; - - /** - * The {@link android.util.Size} of camera preview. - */ - private Size previewSize; - - /** - * {@link android.hardware.camera2.CameraDevice.StateCallback} - * is called when {@link CameraDevice} changes its state. - */ - private final CameraDevice.StateCallback stateCallback = - new CameraDevice.StateCallback() { - @Override - public void onOpened(final CameraDevice cd) { - // This method is called when the camera is opened. We start camera preview here. - cameraOpenCloseLock.release(); - cameraDevice = cd; - createCameraPreviewSession(); - } - - @Override - public void onDisconnected(final CameraDevice cd) { - cameraOpenCloseLock.release(); - cd.close(); - cameraDevice = null; - } - - @Override - public void onError(final CameraDevice cd, final int error) { - cameraOpenCloseLock.release(); - cd.close(); - cameraDevice = null; - final Activity activity = getActivity(); - if (null != activity) { - activity.finish(); - } - } - }; - - /** - * An additional thread for running tasks that shouldn't block the UI. - */ - private HandlerThread backgroundThread; - - /** - * A {@link Handler} for running tasks in the background. - */ - private Handler backgroundHandler; - - /** - * An {@link ImageReader} that handles preview frame capture. - */ - private ImageReader previewReader; - - /** - * {@link android.hardware.camera2.CaptureRequest.Builder} for the camera preview - */ - private CaptureRequest.Builder previewRequestBuilder; - - /** - * {@link CaptureRequest} generated by {@link #previewRequestBuilder} - */ - private CaptureRequest previewRequest; - - /** - * A {@link Semaphore} to prevent the app from exiting before closing the camera. - */ - private final Semaphore cameraOpenCloseLock = new Semaphore(1); - - /** - * A {@link OnImageAvailableListener} to receive frames as they are available. - */ - private final OnImageAvailableListener imageListener; - - /** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */ - private final Size inputSize; - - /** - * The layout identifier to inflate for this Fragment. - */ - private final int layout; - - - private final ConnectionCallback cameraConnectionCallback; - - private CameraConnectionFragment( - final ConnectionCallback connectionCallback, - final OnImageAvailableListener imageListener, - final int layout, - final Size inputSize) { - this.cameraConnectionCallback = connectionCallback; - this.imageListener = imageListener; - this.layout = layout; - this.inputSize = inputSize; - } - - /** - * Shows a {@link Toast} on the UI thread. 
- * - * @param text The message to show - */ - private void showToast(final String text) { - final Activity activity = getActivity(); - if (activity != null) { - activity.runOnUiThread( - new Runnable() { - @Override - public void run() { - Toast.makeText(activity, text, Toast.LENGTH_SHORT).show(); - } - }); - } - } - - /** - * Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose - * width and height are at least as large as the minimum of both, or an exact match if possible. - * - * @param choices The list of sizes that the camera supports for the intended output class - * @param width The minimum desired width - * @param height The minimum desired height - * @return The optimal {@code Size}, or an arbitrary one if none were big enough - */ - protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) { - final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE); - final Size desiredSize = new Size(width, height); - - // Collect the supported resolutions that are at least as big as the preview Surface - boolean exactSizeFound = false; - final List bigEnough = new ArrayList(); - final List tooSmall = new ArrayList(); - for (final Size option : choices) { - if (option.equals(desiredSize)) { - // Set the size but don't return yet so that remaining sizes will still be logged. - exactSizeFound = true; - } - - if (option.getHeight() >= minSize && option.getWidth() >= minSize) { - bigEnough.add(option); - } else { - tooSmall.add(option); - } - } - - LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize); - LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]"); - LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]"); - - if (exactSizeFound) { - LOGGER.i("Exact size match found."); - return desiredSize; - } - - // Pick the smallest of those, assuming we found any - if (bigEnough.size() > 0) { - final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea()); - LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight()); - return chosenSize; - } else { - LOGGER.e("Couldn't find any suitable preview size"); - return choices[0]; - } - } - - public static CameraConnectionFragment newInstance( - final ConnectionCallback callback, - final OnImageAvailableListener imageListener, - final int layout, - final Size inputSize) { - return new CameraConnectionFragment(callback, imageListener, layout, inputSize); - } - - @Override - public View onCreateView( - final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { - return inflater.inflate(layout, container, false); - } - - @Override - public void onViewCreated(final View view, final Bundle savedInstanceState) { - textureView = (AutoFitTextureView) view.findViewById(R.id.texture); - } - - @Override - public void onActivityCreated(final Bundle savedInstanceState) { - super.onActivityCreated(savedInstanceState); - } - - @Override - public void onResume() { - super.onResume(); - startBackgroundThread(); - - // When the screen is turned off and turned back on, the SurfaceTexture is already - // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open - // a camera and start preview from here (otherwise, we wait until the surface is ready in - // the SurfaceTextureListener). 
- if (textureView.isAvailable()) { - openCamera(textureView.getWidth(), textureView.getHeight()); - } else { - textureView.setSurfaceTextureListener(surfaceTextureListener); - } - } - - @Override - public void onPause() { - closeCamera(); - stopBackgroundThread(); - super.onPause(); - } - - public void setCamera(String cameraId) { - this.cameraId = cameraId; - } - - /** - * Sets up member variables related to camera. - */ - private void setUpCameraOutputs() { - final Activity activity = getActivity(); - final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); - try { - final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); - - final StreamConfigurationMap map = - characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); - - sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); - - // Danger, W.R.! Attempting to use too large a preview size could exceed the camera - // bus' bandwidth limitation, resulting in gorgeous previews but the storage of - // garbage capture data. - previewSize = - chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class), - inputSize.getWidth(), - inputSize.getHeight()); - - // We fit the aspect ratio of TextureView to the size of preview we picked. - final int orientation = getResources().getConfiguration().orientation; - if (orientation == Configuration.ORIENTATION_LANDSCAPE) { - textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); - } else { - textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); - } - } catch (final CameraAccessException e) { - LOGGER.e(e, "Exception!"); - } catch (final NullPointerException e) { - // Currently an NPE is thrown when the Camera2API is used but not supported on the - // device this code runs. - // TODO(andrewharp): abstract ErrorDialog/RuntimeException handling out into new method and - // reuse throughout app. - ErrorDialog.newInstance(getString(R.string.camera_error)) - .show(getChildFragmentManager(), FRAGMENT_DIALOG); - throw new RuntimeException(getString(R.string.camera_error)); - } - - cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation); - } - - /** - * Opens the camera specified by {@link CameraConnectionFragment#cameraId}. - */ - private void openCamera(final int width, final int height) { - setUpCameraOutputs(); - configureTransform(width, height); - final Activity activity = getActivity(); - final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); - try { - if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { - throw new RuntimeException("Time out waiting to lock camera opening."); - } - manager.openCamera(cameraId, stateCallback, backgroundHandler); - } catch (final CameraAccessException e) { - LOGGER.e(e, "Exception!"); - } catch (final InterruptedException e) { - throw new RuntimeException("Interrupted while trying to lock camera opening.", e); - } - } - - /** - * Closes the current {@link CameraDevice}. 
- */ - private void closeCamera() { - try { - cameraOpenCloseLock.acquire(); - if (null != captureSession) { - captureSession.close(); - captureSession = null; - } - if (null != cameraDevice) { - cameraDevice.close(); - cameraDevice = null; - } - if (null != previewReader) { - previewReader.close(); - previewReader = null; - } - } catch (final InterruptedException e) { - throw new RuntimeException("Interrupted while trying to lock camera closing.", e); - } finally { - cameraOpenCloseLock.release(); - } - } - - /** - * Starts a background thread and its {@link Handler}. - */ - private void startBackgroundThread() { - backgroundThread = new HandlerThread("ImageListener"); - backgroundThread.start(); - backgroundHandler = new Handler(backgroundThread.getLooper()); - } - - /** - * Stops the background thread and its {@link Handler}. - */ - private void stopBackgroundThread() { - backgroundThread.quitSafely(); - try { - backgroundThread.join(); - backgroundThread = null; - backgroundHandler = null; - } catch (final InterruptedException e) { - LOGGER.e(e, "Exception!"); - } - } - - private final CameraCaptureSession.CaptureCallback captureCallback = - new CameraCaptureSession.CaptureCallback() { - @Override - public void onCaptureProgressed( - final CameraCaptureSession session, - final CaptureRequest request, - final CaptureResult partialResult) {} - - @Override - public void onCaptureCompleted( - final CameraCaptureSession session, - final CaptureRequest request, - final TotalCaptureResult result) {} - }; - - /** - * Creates a new {@link CameraCaptureSession} for camera preview. - */ - private void createCameraPreviewSession() { - try { - final SurfaceTexture texture = textureView.getSurfaceTexture(); - assert texture != null; - - // We configure the size of default buffer to be the size of camera preview we want. - texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); - - // This is the output Surface we need to start preview. - final Surface surface = new Surface(texture); - - // We set up a CaptureRequest.Builder with the output Surface. - previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); - previewRequestBuilder.addTarget(surface); - - LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight()); - - // Create the reader for the preview frames. - previewReader = - ImageReader.newInstance( - previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2); - - previewReader.setOnImageAvailableListener(imageListener, backgroundHandler); - previewRequestBuilder.addTarget(previewReader.getSurface()); - - // Here, we create a CameraCaptureSession for camera preview. - cameraDevice.createCaptureSession( - Arrays.asList(surface, previewReader.getSurface()), - new CameraCaptureSession.StateCallback() { - - @Override - public void onConfigured(final CameraCaptureSession cameraCaptureSession) { - // The camera is already closed - if (null == cameraDevice) { - return; - } - - // When the session is ready, we start displaying the preview. - captureSession = cameraCaptureSession; - try { - // Auto focus should be continuous for camera preview. - previewRequestBuilder.set( - CaptureRequest.CONTROL_AF_MODE, - CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); - // Flash is automatically enabled when necessary. - previewRequestBuilder.set( - CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH); - - // Finally, we start displaying the camera preview. 
- previewRequest = previewRequestBuilder.build(); - captureSession.setRepeatingRequest( - previewRequest, captureCallback, backgroundHandler); - } catch (final CameraAccessException e) { - LOGGER.e(e, "Exception!"); - } - } - - @Override - public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) { - showToast("Failed"); - } - }, - null); - } catch (final CameraAccessException e) { - LOGGER.e(e, "Exception!"); - } - } - - /** - * Configures the necessary {@link android.graphics.Matrix} transformation to `mTextureView`. - * This method should be called after the camera preview size is determined in - * setUpCameraOutputs and also the size of `mTextureView` is fixed. - * - * @param viewWidth The width of `mTextureView` - * @param viewHeight The height of `mTextureView` - */ - private void configureTransform(final int viewWidth, final int viewHeight) { - final Activity activity = getActivity(); - if (null == textureView || null == previewSize || null == activity) { - return; - } - final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); - final Matrix matrix = new Matrix(); - final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); - final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); - final float centerX = viewRect.centerX(); - final float centerY = viewRect.centerY(); - if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { - bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); - matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); - final float scale = - Math.max( - (float) viewHeight / previewSize.getHeight(), - (float) viewWidth / previewSize.getWidth()); - matrix.postScale(scale, scale, centerX, centerY); - matrix.postRotate(90 * (rotation - 2), centerX, centerY); - } else if (Surface.ROTATION_180 == rotation) { - matrix.postRotate(180, centerX, centerY); - } - textureView.setTransform(matrix); - } - - /** - * Compares two {@code Size}s based on their areas. - */ - static class CompareSizesByArea implements Comparator { - @Override - public int compare(final Size lhs, final Size rhs) { - // We cast here to ensure the multiplications won't overflow - return Long.signum( - (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); - } - } - - /** - * Shows an error message dialog. 
- */ - public static class ErrorDialog extends DialogFragment { - private static final String ARG_MESSAGE = "message"; - - public static ErrorDialog newInstance(final String message) { - final ErrorDialog dialog = new ErrorDialog(); - final Bundle args = new Bundle(); - args.putString(ARG_MESSAGE, message); - dialog.setArguments(args); - return dialog; - } - - @Override - public Dialog onCreateDialog(final Bundle savedInstanceState) { - final Activity activity = getActivity(); - return new AlertDialog.Builder(activity) - .setMessage(getArguments().getString(ARG_MESSAGE)) - .setPositiveButton( - android.R.string.ok, - new DialogInterface.OnClickListener() { - @Override - public void onClick(final DialogInterface dialogInterface, final int i) { - activity.finish(); - } - }) - .create(); - } - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java deleted file mode 100644 index 07995febaf5..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo; - -import android.graphics.Bitmap; -import android.graphics.RectF; -import java.util.List; - -/** - * Generic interface for interacting with different recognition engines. - */ -public interface Classifier { - /** - * An immutable result returned by a Classifier describing what was recognized. - */ - public class Recognition { - /** - * A unique identifier for what has been recognized. Specific to the class, not the instance of - * the object. - */ - private final String id; - - /** - * Display name for the recognition. - */ - private final String title; - - /** - * A sortable score for how good the recognition is relative to others. Higher should be better. - */ - private final Float confidence; - - /** Optional location within the source image for the location of the recognized object. 
*/ - private RectF location; - - public Recognition( - final String id, final String title, final Float confidence, final RectF location) { - this.id = id; - this.title = title; - this.confidence = confidence; - this.location = location; - } - - public String getId() { - return id; - } - - public String getTitle() { - return title; - } - - public Float getConfidence() { - return confidence; - } - - public RectF getLocation() { - return new RectF(location); - } - - public void setLocation(RectF location) { - this.location = location; - } - - @Override - public String toString() { - String resultString = ""; - if (id != null) { - resultString += "[" + id + "] "; - } - - if (title != null) { - resultString += title + " "; - } - - if (confidence != null) { - resultString += String.format("(%.1f%%) ", confidence * 100.0f); - } - - if (location != null) { - resultString += location + " "; - } - - return resultString.trim(); - } - } - - List recognizeImage(Bitmap bitmap); - - void enableStatLogging(final boolean debug); - - String getStatString(); - - void close(); -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java deleted file mode 100644 index 698251d8b4a..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.demo; - -import android.graphics.Bitmap; -import android.graphics.Bitmap.Config; -import android.graphics.Canvas; -import android.graphics.Matrix; -import android.graphics.Paint; -import android.graphics.Typeface; -import android.media.ImageReader.OnImageAvailableListener; -import android.os.SystemClock; -import android.util.Size; -import android.util.TypedValue; -import java.util.List; -import java.util.Vector; -import org.tensorflow.demo.OverlayView.DrawCallback; -import org.tensorflow.demo.env.BorderedText; -import org.tensorflow.demo.env.ImageUtils; -import org.tensorflow.demo.env.Logger; -import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds. - -public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener { - private static final Logger LOGGER = new Logger(); - - protected static final boolean SAVE_PREVIEW_BITMAP = false; - - private ResultsView resultsView; - - private Bitmap rgbFrameBitmap = null; - private Bitmap croppedBitmap = null; - private Bitmap cropCopyBitmap = null; - - private long lastProcessingTimeMs; - - // These are the settings for the original v1 Inception model. If you want to - // use a model that's been produced from the TensorFlow for Poets codelab, - // you'll need to set IMAGE_SIZE = 299, IMAGE_MEAN = 128, IMAGE_STD = 128, - // INPUT_NAME = "Mul", and OUTPUT_NAME = "final_result". 
- // You'll also need to update the MODEL_FILE and LABEL_FILE paths to point to - // the ones you produced. - // - // To use v3 Inception model, strip the DecodeJpeg Op from your retrained - // model first: - // - // python strip_unused.py \ - // --input_graph= \ - // --output_graph= \ - // --input_node_names="Mul" \ - // --output_node_names="final_result" \ - // --input_binary=true - private static final int INPUT_SIZE = 224; - - private static final String MODEL_FILE = "mobilenet_v1_1.0_224_quant.tflite"; - private static final String LABEL_FILE = "labels_mobilenet_quant_v1_224.txt"; - - private static final boolean MAINTAIN_ASPECT = true; - - private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); - - - private Integer sensorOrientation; - private Classifier classifier; - private Matrix frameToCropTransform; - private Matrix cropToFrameTransform; - - private BorderedText borderedText; - - @Override - protected int getLayoutId() { - return R.layout.camera_connection_fragment; - } - - @Override - protected Size getDesiredPreviewFrameSize() { - return DESIRED_PREVIEW_SIZE; - } - - private static final float TEXT_SIZE_DIP = 10; - - @Override - public void onPreviewSizeChosen(final Size size, final int rotation) { - final float textSizePx = TypedValue.applyDimension( - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); - borderedText = new BorderedText(textSizePx); - borderedText.setTypeface(Typeface.MONOSPACE); - - classifier = TFLiteImageClassifier.create(getAssets(), MODEL_FILE, LABEL_FILE, INPUT_SIZE); - - previewWidth = size.getWidth(); - previewHeight = size.getHeight(); - - sensorOrientation = rotation - getScreenOrientation(); - LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); - - LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); - rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); - croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888); - - frameToCropTransform = ImageUtils.getTransformationMatrix( - previewWidth, previewHeight, - INPUT_SIZE, INPUT_SIZE, - sensorOrientation, MAINTAIN_ASPECT); - - cropToFrameTransform = new Matrix(); - frameToCropTransform.invert(cropToFrameTransform); - - addCallback( - new DrawCallback() { - @Override - public void drawCallback(final Canvas canvas) { - renderDebug(canvas); - } - }); - } - - @Override - protected void processImage() { - rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); - final Canvas canvas = new Canvas(croppedBitmap); - canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); - - // For examining the actual TF input. 
- if (SAVE_PREVIEW_BITMAP) { - ImageUtils.saveBitmap(croppedBitmap); - } - runInBackground( - new Runnable() { - @Override - public void run() { - final long startTime = SystemClock.uptimeMillis(); - final List results = classifier.recognizeImage(croppedBitmap); - lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; - LOGGER.i("Detect: %s", results); - cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); - if (resultsView == null) { - resultsView = (ResultsView) findViewById(R.id.results); - } - resultsView.setResults(results); - requestRender(); - readyForNextImage(); - } - }); - } - - @Override - public void onSetDebug(boolean debug) { - classifier.enableStatLogging(debug); - } - - private void renderDebug(final Canvas canvas) { - if (!isDebug()) { - return; - } - final Bitmap copy = cropCopyBitmap; - if (copy != null) { - final Matrix matrix = new Matrix(); - final float scaleFactor = 2; - matrix.postScale(scaleFactor, scaleFactor); - matrix.postTranslate( - canvas.getWidth() - copy.getWidth() * scaleFactor, - canvas.getHeight() - copy.getHeight() * scaleFactor); - canvas.drawBitmap(copy, matrix, new Paint()); - - final Vector lines = new Vector(); - if (classifier != null) { - String statString = classifier.getStatString(); - String[] statLines = statString.split("\n"); - for (String line : statLines) { - lines.add(line); - } - } - - lines.add("Frame: " + previewWidth + "x" + previewHeight); - lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); - lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); - lines.add("Rotation: " + sensorOrientation); - lines.add("Inference time: " + lastProcessingTimeMs + "ms"); - - borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); - } - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java deleted file mode 100644 index 2feca79e888..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java +++ /dev/null @@ -1,301 +0,0 @@ -/* - * Copyright 2018 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.tensorflow.demo; - -import android.graphics.Bitmap; -import android.graphics.Bitmap.Config; -import android.graphics.Canvas; -import android.graphics.Color; -import android.graphics.Matrix; -import android.graphics.Paint; -import android.graphics.Paint.Style; -import android.graphics.RectF; -import android.graphics.Typeface; -import android.media.ImageReader.OnImageAvailableListener; -import android.os.SystemClock; -import android.util.Size; -import android.util.TypedValue; -import android.widget.Toast; -import java.io.IOException; -import java.util.LinkedList; -import java.util.List; -import java.util.Vector; -import org.tensorflow.demo.OverlayView.DrawCallback; -import org.tensorflow.demo.env.BorderedText; -import org.tensorflow.demo.env.ImageUtils; -import org.tensorflow.demo.env.Logger; -import org.tensorflow.demo.tracking.MultiBoxTracker; -import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds. - -/** - * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track - * objects. - */ -public class DetectorActivity extends CameraActivity implements OnImageAvailableListener { - private static final Logger LOGGER = new Logger(); - - // Configuration values for the prepackaged SSD model. - private static final int TF_OD_API_INPUT_SIZE = 300; - private static final boolean TF_OD_API_IS_QUANTIZED = true; - private static final String TF_OD_API_MODEL_FILE = "detect.tflite"; - private static final String TF_OD_API_LABELS_FILE = "coco_labels_list.txt"; - - // Which detection model to use: by default uses Tensorflow Object Detection API frozen - // checkpoints. - private enum DetectorMode { - TF_OD_API; - } - - private static final DetectorMode MODE = DetectorMode.TF_OD_API; - - // Minimum detection confidence to track a detection. 
- private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f; - - private static final boolean MAINTAIN_ASPECT = false; - - private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); - - private static final boolean SAVE_PREVIEW_BITMAP = false; - private static final float TEXT_SIZE_DIP = 10; - - private Integer sensorOrientation; - - private Classifier detector; - - private long lastProcessingTimeMs; - private Bitmap rgbFrameBitmap = null; - private Bitmap croppedBitmap = null; - private Bitmap cropCopyBitmap = null; - - private boolean computingDetection = false; - - private long timestamp = 0; - - private Matrix frameToCropTransform; - private Matrix cropToFrameTransform; - - private MultiBoxTracker tracker; - - private byte[] luminanceCopy; - - private BorderedText borderedText; - @Override - public void onPreviewSizeChosen(final Size size, final int rotation) { - final float textSizePx = - TypedValue.applyDimension( - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); - borderedText = new BorderedText(textSizePx); - borderedText.setTypeface(Typeface.MONOSPACE); - - tracker = new MultiBoxTracker(this); - - int cropSize = TF_OD_API_INPUT_SIZE; - - try { - detector = - TFLiteObjectDetectionAPIModel.create( - getAssets(), - TF_OD_API_MODEL_FILE, - TF_OD_API_LABELS_FILE, - TF_OD_API_INPUT_SIZE, - TF_OD_API_IS_QUANTIZED); - cropSize = TF_OD_API_INPUT_SIZE; - } catch (final IOException e) { - LOGGER.e("Exception initializing classifier!", e); - Toast toast = - Toast.makeText( - getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT); - toast.show(); - finish(); - } - - - previewWidth = size.getWidth(); - previewHeight = size.getHeight(); - - sensorOrientation = rotation - getScreenOrientation(); - LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); - - LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); - rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); - croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888); - - frameToCropTransform = - ImageUtils.getTransformationMatrix( - previewWidth, previewHeight, - cropSize, cropSize, - sensorOrientation, MAINTAIN_ASPECT); - - cropToFrameTransform = new Matrix(); - frameToCropTransform.invert(cropToFrameTransform); - - trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay); - trackingOverlay.addCallback( - new DrawCallback() { - @Override - public void drawCallback(final Canvas canvas) { - tracker.draw(canvas); - if (isDebug()) { - tracker.drawDebug(canvas); - } - } - }); - - addCallback( - new DrawCallback() { - @Override - public void drawCallback(final Canvas canvas) { - if (!isDebug()) { - return; - } - final Bitmap copy = cropCopyBitmap; - if (copy == null) { - return; - } - - final int backgroundColor = Color.argb(100, 0, 0, 0); - canvas.drawColor(backgroundColor); - - final Matrix matrix = new Matrix(); - final float scaleFactor = 2; - matrix.postScale(scaleFactor, scaleFactor); - matrix.postTranslate( - canvas.getWidth() - copy.getWidth() * scaleFactor, - canvas.getHeight() - copy.getHeight() * scaleFactor); - canvas.drawBitmap(copy, matrix, new Paint()); - - final Vector lines = new Vector(); - if (detector != null) { - final String statString = detector.getStatString(); - final String[] statLines = statString.split("\n"); - for (final String line : statLines) { - lines.add(line); - } - } - lines.add(""); - - lines.add("Frame: " + previewWidth + "x" + 
previewHeight); - lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); - lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); - lines.add("Rotation: " + sensorOrientation); - lines.add("Inference time: " + lastProcessingTimeMs + "ms"); - - borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); - } - }); - } - - OverlayView trackingOverlay; - - @Override - protected void processImage() { - ++timestamp; - final long currTimestamp = timestamp; - byte[] originalLuminance = getLuminance(); - tracker.onFrame( - previewWidth, - previewHeight, - getLuminanceStride(), - sensorOrientation, - originalLuminance, - timestamp); - trackingOverlay.postInvalidate(); - - // No mutex needed as this method is not reentrant. - if (computingDetection) { - readyForNextImage(); - return; - } - computingDetection = true; - LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread."); - - rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); - - if (luminanceCopy == null) { - luminanceCopy = new byte[originalLuminance.length]; - } - System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length); - readyForNextImage(); - - final Canvas canvas = new Canvas(croppedBitmap); - canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); - // For examining the actual TF input. - if (SAVE_PREVIEW_BITMAP) { - ImageUtils.saveBitmap(croppedBitmap); - } - - runInBackground( - new Runnable() { - @Override - public void run() { - LOGGER.i("Running detection on image " + currTimestamp); - final long startTime = SystemClock.uptimeMillis(); - final List results = detector.recognizeImage(croppedBitmap); - lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; - - cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); - final Canvas canvas = new Canvas(cropCopyBitmap); - final Paint paint = new Paint(); - paint.setColor(Color.RED); - paint.setStyle(Style.STROKE); - paint.setStrokeWidth(2.0f); - - float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; - switch (MODE) { - case TF_OD_API: - minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; - break; - } - - final List mappedRecognitions = - new LinkedList(); - - for (final Classifier.Recognition result : results) { - final RectF location = result.getLocation(); - if (location != null && result.getConfidence() >= minimumConfidence) { - canvas.drawRect(location, paint); - - cropToFrameTransform.mapRect(location); - result.setLocation(location); - mappedRecognitions.add(result); - } - } - - tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp); - trackingOverlay.postInvalidate(); - - requestRender(); - computingDetection = false; - } - }); - } - - @Override - protected int getLayoutId() { - return R.layout.camera_connection_fragment_tracking; - } - - @Override - protected Size getDesiredPreviewFrameSize() { - return DESIRED_PREVIEW_SIZE; - } - - @Override - public void onSetDebug(final boolean debug) { - detector.enableStatLogging(debug); - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java deleted file mode 100644 index fd830297533..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java +++ /dev/null @@ -1,216 +0,0 @@ -package org.tensorflow.demo; - -/* - * Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import android.app.Fragment; -import android.graphics.SurfaceTexture; -import android.hardware.Camera; -import android.hardware.Camera.CameraInfo; -import android.os.Bundle; -import android.os.Handler; -import android.os.HandlerThread; -import android.util.Size; -import android.util.SparseIntArray; -import android.view.LayoutInflater; -import android.view.Surface; -import android.view.TextureView; -import android.view.View; -import android.view.ViewGroup; -import java.io.IOException; -import java.util.List; -import org.tensorflow.demo.env.ImageUtils; -import org.tensorflow.demo.env.Logger; -import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds. - -public class LegacyCameraConnectionFragment extends Fragment { - private Camera camera; - private static final Logger LOGGER = new Logger(); - private Camera.PreviewCallback imageListener; - private Size desiredSize; - - /** - * The layout identifier to inflate for this Fragment. - */ - private int layout; - - public LegacyCameraConnectionFragment( - final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) { - this.imageListener = imageListener; - this.layout = layout; - this.desiredSize = desiredSize; - } - - /** - * Conversion from screen rotation to JPEG orientation. - */ - private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); - - static { - ORIENTATIONS.append(Surface.ROTATION_0, 90); - ORIENTATIONS.append(Surface.ROTATION_90, 0); - ORIENTATIONS.append(Surface.ROTATION_180, 270); - ORIENTATIONS.append(Surface.ROTATION_270, 180); - } - - /** - * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a - * {@link TextureView}. 
- */ - private final TextureView.SurfaceTextureListener surfaceTextureListener = - new TextureView.SurfaceTextureListener() { - @Override - public void onSurfaceTextureAvailable( - final SurfaceTexture texture, final int width, final int height) { - - int index = getCameraId(); - camera = Camera.open(index); - - try { - Camera.Parameters parameters = camera.getParameters(); - List focusModes = parameters.getSupportedFocusModes(); - if (focusModes != null - && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) { - parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); - } - List cameraSizes = parameters.getSupportedPreviewSizes(); - Size[] sizes = new Size[cameraSizes.size()]; - int i = 0; - for (Camera.Size size : cameraSizes) { - sizes[i++] = new Size(size.width, size.height); - } - Size previewSize = - CameraConnectionFragment.chooseOptimalSize( - sizes, desiredSize.getWidth(), desiredSize.getHeight()); - parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight()); - camera.setDisplayOrientation(90); - camera.setParameters(parameters); - camera.setPreviewTexture(texture); - } catch (IOException exception) { - camera.release(); - } - - camera.setPreviewCallbackWithBuffer(imageListener); - Camera.Size s = camera.getParameters().getPreviewSize(); - camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]); - - textureView.setAspectRatio(s.height, s.width); - - camera.startPreview(); - } - - @Override - public void onSurfaceTextureSizeChanged( - final SurfaceTexture texture, final int width, final int height) {} - - @Override - public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { - return true; - } - - @Override - public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} - }; - - /** - * An {@link AutoFitTextureView} for camera preview. - */ - private AutoFitTextureView textureView; - - /** - * An additional thread for running tasks that shouldn't block the UI. - */ - private HandlerThread backgroundThread; - - @Override - public View onCreateView( - final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { - return inflater.inflate(layout, container, false); - } - - @Override - public void onViewCreated(final View view, final Bundle savedInstanceState) { - textureView = (AutoFitTextureView) view.findViewById(R.id.texture); - } - - @Override - public void onActivityCreated(final Bundle savedInstanceState) { - super.onActivityCreated(savedInstanceState); - } - - @Override - public void onResume() { - super.onResume(); - startBackgroundThread(); - // When the screen is turned off and turned back on, the SurfaceTexture is already - // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open - // a camera and start preview from here (otherwise, we wait until the surface is ready in - // the SurfaceTextureListener). - - if (textureView.isAvailable()) { - camera.startPreview(); - } else { - textureView.setSurfaceTextureListener(surfaceTextureListener); - } - } - - @Override - public void onPause() { - stopCamera(); - stopBackgroundThread(); - super.onPause(); - } - - /** - * Starts a background thread and its {@link Handler}. - */ - private void startBackgroundThread() { - backgroundThread = new HandlerThread("CameraBackground"); - backgroundThread.start(); - } - - /** - * Stops the background thread and its {@link Handler}. 
- */ - private void stopBackgroundThread() { - backgroundThread.quitSafely(); - try { - backgroundThread.join(); - backgroundThread = null; - } catch (final InterruptedException e) { - LOGGER.e(e, "Exception!"); - } - } - - protected void stopCamera() { - if (camera != null) { - camera.stopPreview(); - camera.setPreviewCallback(null); - camera.release(); - camera = null; - } - } - - private int getCameraId() { - CameraInfo ci = new CameraInfo(); - for (int i = 0; i < Camera.getNumberOfCameras(); i++) { - Camera.getCameraInfo(i, ci); - if (ci.facing == CameraInfo.CAMERA_FACING_BACK) - return i; - } - return -1; // No camera found - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java deleted file mode 100644 index 0f8d109fb46..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo; - -import android.content.Context; -import android.graphics.Canvas; -import android.util.AttributeSet; -import android.view.View; -import java.util.LinkedList; -import java.util.List; - -/** - * A simple View providing a render callback to other classes. - */ -public class OverlayView extends View { - private final List callbacks = new LinkedList(); - - public OverlayView(final Context context, final AttributeSet attrs) { - super(context, attrs); - } - - /** - * Interface defining the callback for client classes. - */ - public interface DrawCallback { - public void drawCallback(final Canvas canvas); - } - - public void addCallback(final DrawCallback callback) { - callbacks.add(callback); - } - - @Override - public synchronized void draw(final Canvas canvas) { - for (final DrawCallback callback : callbacks) { - callback.drawCallback(canvas); - } - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java deleted file mode 100644 index 31a4b07c838..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo; - -import android.content.Context; -import android.graphics.Canvas; -import android.graphics.Paint; -import android.util.AttributeSet; -import android.util.TypedValue; -import android.view.View; -import java.util.List; -import org.tensorflow.demo.Classifier.Recognition; - -public class RecognitionScoreView extends View implements ResultsView { - private static final float TEXT_SIZE_DIP = 24; - private List results; - private final float textSizePx; - private final Paint fgPaint; - private final Paint bgPaint; - - public RecognitionScoreView(final Context context, final AttributeSet set) { - super(context, set); - - textSizePx = - TypedValue.applyDimension( - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); - fgPaint = new Paint(); - fgPaint.setTextSize(textSizePx); - - bgPaint = new Paint(); - bgPaint.setColor(0xcc4285f4); - } - - @Override - public void setResults(final List results) { - this.results = results; - postInvalidate(); - } - - @Override - public void onDraw(final Canvas canvas) { - final int x = 10; - int y = (int) (fgPaint.getTextSize() * 1.5f); - - canvas.drawPaint(bgPaint); - - if (results != null) { - for (final Recognition recog : results) { - canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint); - y += (int) (fgPaint.getTextSize() * 1.5f); - } - } - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java deleted file mode 100644 index 9e91aea7efc..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Copyright 2017 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.demo; - -import android.util.Log; -import android.util.Pair; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Deque; -import java.util.List; - -/** Reads in results from an instantaneous audio recognition model and smoothes them over time. */ -public class RecognizeCommands { - // Configuration settings. - private List labels = new ArrayList(); - private long averageWindowDurationMs; - private float detectionThreshold; - private int suppressionMs; - private int minimumCount; - private long minimumTimeBetweenSamplesMs; - - // Working variables. 
- private Deque> previousResults = new ArrayDeque>(); - private String previousTopLabel; - private int labelsCount; - private long previousTopLabelTime; - private float previousTopLabelScore; - - private static final String SILENCE_LABEL = "_silence_"; - private static final long MINIMUM_TIME_FRACTION = 4; - - public RecognizeCommands( - List inLabels, - long inAverageWindowDurationMs, - float inDetectionThreshold, - int inSuppressionMS, - int inMinimumCount, - long inMinimumTimeBetweenSamplesMS) { - labels = inLabels; - averageWindowDurationMs = inAverageWindowDurationMs; - detectionThreshold = inDetectionThreshold; - suppressionMs = inSuppressionMS; - minimumCount = inMinimumCount; - labelsCount = inLabels.size(); - previousTopLabel = SILENCE_LABEL; - previousTopLabelTime = Long.MIN_VALUE; - previousTopLabelScore = 0.0f; - minimumTimeBetweenSamplesMs = inMinimumTimeBetweenSamplesMS; - } - - /** Holds information about what's been recognized. */ - public static class RecognitionResult { - public final String foundCommand; - public final float score; - public final boolean isNewCommand; - - public RecognitionResult(String inFoundCommand, float inScore, boolean inIsNewCommand) { - foundCommand = inFoundCommand; - score = inScore; - isNewCommand = inIsNewCommand; - } - } - - private static class ScoreForSorting implements Comparable { - public final float score; - public final int index; - - public ScoreForSorting(float inScore, int inIndex) { - score = inScore; - index = inIndex; - } - - @Override - public int compareTo(ScoreForSorting other) { - if (this.score > other.score) { - return -1; - } else if (this.score < other.score) { - return 1; - } else { - return 0; - } - } - } - - public RecognitionResult processLatestResults(float[] currentResults, long currentTimeMS) { - if (currentResults.length != labelsCount) { - throw new RuntimeException( - "The results for recognition should contain " - + labelsCount - + " elements, but there are " - + currentResults.length); - } - - if ((!previousResults.isEmpty()) && (currentTimeMS < previousResults.getFirst().first)) { - throw new RuntimeException( - "You must feed results in increasing time order, but received a timestamp of " - + currentTimeMS - + " that was earlier than the previous one of " - + previousResults.getFirst().first); - } - - final int howManyResults = previousResults.size(); - // Ignore any results that are coming in too frequently. - if (howManyResults > 1) { - final long timeSinceMostRecent = currentTimeMS - previousResults.getLast().first; - if (timeSinceMostRecent < minimumTimeBetweenSamplesMs) { - return new RecognitionResult(previousTopLabel, previousTopLabelScore, false); - } - } - - // Add the latest results to the head of the queue. - previousResults.addLast(new Pair(currentTimeMS, currentResults)); - - // Prune any earlier results that are too old for the averaging window. - final long timeLimit = currentTimeMS - averageWindowDurationMs; - while (previousResults.getFirst().first < timeLimit) { - previousResults.removeFirst(); - } - - // If there are too few results, assume the result will be unreliable and - // bail. 
- final long earliestTime = previousResults.getFirst().first; - final long samplesDuration = currentTimeMS - earliestTime; - if ((howManyResults < minimumCount) - || (samplesDuration < (averageWindowDurationMs / MINIMUM_TIME_FRACTION))) { - Log.v("RecognizeResult", "Too few results"); - return new RecognitionResult(previousTopLabel, 0.0f, false); - } - - // Calculate the average score across all the results in the window. - float[] averageScores = new float[labelsCount]; - for (Pair previousResult : previousResults) { - final float[] scoresTensor = previousResult.second; - int i = 0; - while (i < scoresTensor.length) { - averageScores[i] += scoresTensor[i] / howManyResults; - ++i; - } - } - - // Sort the averaged results in descending score order. - ScoreForSorting[] sortedAverageScores = new ScoreForSorting[labelsCount]; - for (int i = 0; i < labelsCount; ++i) { - sortedAverageScores[i] = new ScoreForSorting(averageScores[i], i); - } - Arrays.sort(sortedAverageScores); - - // See if the latest top score is enough to trigger a detection. - final int currentTopIndex = sortedAverageScores[0].index; - final String currentTopLabel = labels.get(currentTopIndex); - final float currentTopScore = sortedAverageScores[0].score; - // If we've recently had another label trigger, assume one that occurs too - // soon afterwards is a bad result. - long timeSinceLastTop; - if (previousTopLabel.equals(SILENCE_LABEL) || (previousTopLabelTime == Long.MIN_VALUE)) { - timeSinceLastTop = Long.MAX_VALUE; - } else { - timeSinceLastTop = currentTimeMS - previousTopLabelTime; - } - boolean isNewCommand; - if ((currentTopScore > detectionThreshold) && (timeSinceLastTop > suppressionMs)) { - previousTopLabel = currentTopLabel; - previousTopLabelTime = currentTimeMS; - previousTopLabelScore = currentTopScore; - isNewCommand = true; - } else { - isNewCommand = false; - } - return new RecognitionResult(currentTopLabel, currentTopScore, isNewCommand); - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java deleted file mode 100644 index 9c9c30bc098..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java +++ /dev/null @@ -1,381 +0,0 @@ -/* - * Copyright 2017 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Demonstrates how to run an audio recognition model in Android. - -This example loads a simple speech recognition model trained by the tutorial at -https://www.tensorflow.org/tutorials/audio_training - -The model files should be downloaded automatically from the TensorFlow website, -but if you have a custom model you can update the LABEL_FILENAME and -MODEL_FILENAME constants to point to your own files. 
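The RecognizeCommands class removed above smooths per-frame model scores by averaging them over a time window, requiring a minimum number of frames, and suppressing repeat detections for a while after a trigger. A minimal sketch of how it can be driven, using a made-up four-entry label list and synthetic scores; the window, threshold, suppression, and minimum-count values mirror the constants in SpeechActivity below. Because the class uses android.util.Log and android.util.Pair internally, this needs an Android runtime rather than a plain JVM.

```java
package org.tensorflow.demo;

import android.util.Log;
import java.util.Arrays;
import java.util.List;

/** Hypothetical driver for the removed RecognizeCommands smoother; not part of this diff. */
public class RecognizeCommandsDemo {
  public static void printSmoothedResults() {
    // Made-up labels; a real label list comes from the model's label file.
    List<String> labels = Arrays.asList("_silence_", "_unknown_", "yes", "no");
    RecognizeCommands smoother =
        new RecognizeCommands(
            labels,
            /* averageWindowDurationMs= */ 500,
            /* detectionThreshold= */ 0.70f,
            /* suppressionMs= */ 1500,
            /* minimumCount= */ 3,
            /* minimumTimeBetweenSamplesMs= */ 30);

    long start = System.currentTimeMillis();
    // Feed five frames, 100 ms apart, in which "yes" clearly dominates.
    for (int frame = 0; frame < 5; ++frame) {
      float[] scores = {0.05f, 0.05f, 0.85f, 0.05f};
      RecognizeCommands.RecognitionResult result =
          smoother.processLatestResults(scores, start + frame * 100L);
      Log.i(
          "RecognizeCommandsDemo",
          result.foundCommand + " score=" + result.score + " new=" + result.isNewCommand);
    }
  }
}
```

With these numbers the first few frames are reported as unreliable (too few results in the window); once at least three frames span a quarter of the 500 ms window, the averaged "yes" score clears the 0.70 threshold and isNewCommand becomes true exactly once, after which the 1500 ms suppression interval keeps it false.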
- -The example application displays a list view with all of the known audio labels, -and highlights each one when it thinks it has detected one through the -microphone. The averaging of results to give a more reliable signal happens in -the RecognizeCommands helper class. -*/ - -package org.tensorflow.demo; - -import android.animation.ValueAnimator; -import android.app.Activity; -import android.content.pm.PackageManager; -import android.content.res.AssetFileDescriptor; -import android.content.res.AssetManager; -import android.media.AudioFormat; -import android.media.AudioRecord; -import android.media.MediaRecorder; -import android.os.Build; -import android.os.Bundle; -import android.util.Log; -import android.view.View; -import android.widget.ArrayAdapter; -import android.widget.Button; -import android.widget.ListView; -import java.io.BufferedReader; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.locks.ReentrantLock; -import org.tensorflow.lite.Interpreter; -import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds. - -/** - * An activity that listens for audio and then uses a TensorFlow model to detect particular classes, - * by default a small set of action words. - */ -public class SpeechActivity extends Activity { - - // Constants that control the behavior of the recognition code and model - // settings. See the audio recognition tutorial for a detailed explanation of - // all these, but you should customize them to match your training settings if - // you are running your own model. - private static final int SAMPLE_RATE = 16000; - private static final int SAMPLE_DURATION_MS = 1000; - private static final int RECORDING_LENGTH = (int) (SAMPLE_RATE * SAMPLE_DURATION_MS / 1000); - private static final long AVERAGE_WINDOW_DURATION_MS = 500; - private static final float DETECTION_THRESHOLD = 0.70f; - private static final int SUPPRESSION_MS = 1500; - private static final int MINIMUM_COUNT = 3; - private static final long MINIMUM_TIME_BETWEEN_SAMPLES_MS = 30; - private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_labels.txt"; - private static final String MODEL_FILENAME = "file:///android_asset/conv_actions_frozen.tflite"; - - // UI elements. - private static final int REQUEST_RECORD_AUDIO = 13; - private Button quitButton; - private ListView labelsListView; - private static final String LOG_TAG = SpeechActivity.class.getSimpleName(); - - // Working variables. - short[] recordingBuffer = new short[RECORDING_LENGTH]; - int recordingOffset = 0; - boolean shouldContinue = true; - private Thread recordingThread; - boolean shouldContinueRecognition = true; - private Thread recognitionThread; - private final ReentrantLock recordingBufferLock = new ReentrantLock(); - - private List labels = new ArrayList(); - private List displayedLabels = new ArrayList<>(); - private RecognizeCommands recognizeCommands = null; - - private Interpreter tfLite; - - /** Memory-map the model file in Assets. 
*/ - private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename) - throws IOException { - AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename); - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); - FileChannel fileChannel = inputStream.getChannel(); - long startOffset = fileDescriptor.getStartOffset(); - long declaredLength = fileDescriptor.getDeclaredLength(); - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - } - - @Override - protected void onCreate(Bundle savedInstanceState) { - // Set up the UI. - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_speech); - quitButton = (Button) findViewById(R.id.quit); - quitButton.setOnClickListener( - new View.OnClickListener() { - @Override - public void onClick(View view) { - moveTaskToBack(true); - android.os.Process.killProcess(android.os.Process.myPid()); - System.exit(1); - } - }); - labelsListView = (ListView) findViewById(R.id.list_view); - - // Load the labels for the model, but only display those that don't start - // with an underscore. - String actualLabelFilename = LABEL_FILENAME.split("file:///android_asset/", -1)[1]; - Log.i(LOG_TAG, "Reading labels from: " + actualLabelFilename); - BufferedReader br = null; - try { - br = new BufferedReader(new InputStreamReader(getAssets().open(actualLabelFilename))); - String line; - while ((line = br.readLine()) != null) { - labels.add(line); - if (line.charAt(0) != '_') { - displayedLabels.add(line.substring(0, 1).toUpperCase() + line.substring(1)); - } - } - br.close(); - } catch (IOException e) { - throw new RuntimeException("Problem reading label file!", e); - } - - // Build a list view based on these labels. - ArrayAdapter arrayAdapter = - new ArrayAdapter(this, R.layout.list_text_item, displayedLabels); - labelsListView.setAdapter(arrayAdapter); - - // Set up an object to smooth recognition results to increase accuracy. - recognizeCommands = - new RecognizeCommands( - labels, - AVERAGE_WINDOW_DURATION_MS, - DETECTION_THRESHOLD, - SUPPRESSION_MS, - MINIMUM_COUNT, - MINIMUM_TIME_BETWEEN_SAMPLES_MS); - - String actualModelFilename = MODEL_FILENAME.split("file:///android_asset/", -1)[1]; - try { - tfLite = new Interpreter(loadModelFile(getAssets(), actualModelFilename)); - } catch (Exception e) { - throw new RuntimeException(e); - } - - tfLite.resizeInput(0, new int[] {RECORDING_LENGTH, 1}); - tfLite.resizeInput(1, new int[] {1}); - - // Start the recording and recognition threads. 
- requestMicrophonePermission(); - startRecording(); - startRecognition(); - } - - private void requestMicrophonePermission() { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { - requestPermissions( - new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO); - } - } - - @Override - public void onRequestPermissionsResult( - int requestCode, String[] permissions, int[] grantResults) { - if (requestCode == REQUEST_RECORD_AUDIO - && grantResults.length > 0 - && grantResults[0] == PackageManager.PERMISSION_GRANTED) { - startRecording(); - startRecognition(); - } - } - - public synchronized void startRecording() { - if (recordingThread != null) { - return; - } - shouldContinue = true; - recordingThread = - new Thread( - new Runnable() { - @Override - public void run() { - record(); - } - }); - recordingThread.start(); - } - - public synchronized void stopRecording() { - if (recordingThread == null) { - return; - } - shouldContinue = false; - recordingThread = null; - } - - private void record() { - android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO); - - // Estimate the buffer size we'll need for this device. - int bufferSize = - AudioRecord.getMinBufferSize( - SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT); - if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) { - bufferSize = SAMPLE_RATE * 2; - } - short[] audioBuffer = new short[bufferSize / 2]; - - AudioRecord record = - new AudioRecord( - MediaRecorder.AudioSource.DEFAULT, - SAMPLE_RATE, - AudioFormat.CHANNEL_IN_MONO, - AudioFormat.ENCODING_PCM_16BIT, - bufferSize); - - if (record.getState() != AudioRecord.STATE_INITIALIZED) { - Log.e(LOG_TAG, "Audio Record can't initialize!"); - return; - } - - record.startRecording(); - - Log.v(LOG_TAG, "Start recording"); - - // Loop, gathering audio data and copying it to a round-robin buffer. - while (shouldContinue) { - int numberRead = record.read(audioBuffer, 0, audioBuffer.length); - int maxLength = recordingBuffer.length; - int newRecordingOffset = recordingOffset + numberRead; - int secondCopyLength = Math.max(0, newRecordingOffset - maxLength); - int firstCopyLength = numberRead - secondCopyLength; - // We store off all the data for the recognition thread to access. The ML - // thread will copy out of this buffer into its own, while holding the - // lock, so this should be thread safe. 
- recordingBufferLock.lock(); - try { - System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength); - System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength); - recordingOffset = newRecordingOffset % maxLength; - } finally { - recordingBufferLock.unlock(); - } - } - - record.stop(); - record.release(); - } - - public synchronized void startRecognition() { - if (recognitionThread != null) { - return; - } - shouldContinueRecognition = true; - recognitionThread = - new Thread( - new Runnable() { - @Override - public void run() { - recognize(); - } - }); - recognitionThread.start(); - } - - public synchronized void stopRecognition() { - if (recognitionThread == null) { - return; - } - shouldContinueRecognition = false; - recognitionThread = null; - } - - private void recognize() { - Log.v(LOG_TAG, "Start recognition"); - - short[] inputBuffer = new short[RECORDING_LENGTH]; - float[][] floatInputBuffer = new float[RECORDING_LENGTH][1]; - float[][] outputScores = new float[1][labels.size()]; - int[] sampleRateList = new int[] {SAMPLE_RATE}; - - // Loop, grabbing recorded data and running the recognition model on it. - while (shouldContinueRecognition) { - // The recording thread places data in this round-robin buffer, so lock to - // make sure there's no writing happening and then copy it to our own - // local version. - recordingBufferLock.lock(); - try { - int maxLength = recordingBuffer.length; - int firstCopyLength = maxLength - recordingOffset; - int secondCopyLength = recordingOffset; - System.arraycopy(recordingBuffer, recordingOffset, inputBuffer, 0, firstCopyLength); - System.arraycopy(recordingBuffer, 0, inputBuffer, firstCopyLength, secondCopyLength); - } finally { - recordingBufferLock.unlock(); - } - - // We need to feed in float values between -1.0f and 1.0f, so divide the - // signed 16-bit inputs. - for (int i = 0; i < RECORDING_LENGTH; ++i) { - floatInputBuffer[i][0] = inputBuffer[i] / 32767.0f; - } - - Object[] inputArray = {floatInputBuffer, sampleRateList}; - Map outputMap = new HashMap<>(); - outputMap.put(0, outputScores); - - // Run the model. - tfLite.runForMultipleInputsOutputs(inputArray, outputMap); - - // Use the smoother to figure out if we've had a real recognition event. - long currentTime = System.currentTimeMillis(); - final RecognizeCommands.RecognitionResult result = - recognizeCommands.processLatestResults(outputScores[0], currentTime); - - runOnUiThread( - new Runnable() { - @Override - public void run() { - // If we do have a new command, highlight the right list entry. - if (!result.foundCommand.startsWith("_") && result.isNewCommand) { - int labelIndex = -1; - for (int i = 0; i < labels.size(); ++i) { - if (labels.get(i).equals(result.foundCommand)) { - labelIndex = i; - } - } - final View labelView = (View) labelsListView.getChildAt(labelIndex - 2); - ValueAnimator colorAnimation = - ValueAnimator.ofArgb(0x00b3ccff, 0xffb3ccff, 0x00b3ccff); - colorAnimation.setDuration(750); - colorAnimation.addUpdateListener( - new ValueAnimator.AnimatorUpdateListener() { - @Override - public void onAnimationUpdate(ValueAnimator animator) { - labelView.setBackgroundColor((int) animator.getAnimatedValue()); - } - }); - colorAnimation.start(); - } - } - }); - try { - // We don't need to run too frequently, so snooze for a bit. 
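The record() and recognize() loops above share recordingBuffer as a ring buffer: the recording thread writes each captured chunk at recordingOffset, splitting the copy when it runs past the end, and the recognition thread later unrolls the ring into a linear inputBuffer starting at that offset. A small plain-Java sketch of the same index arithmetic, using a tiny 8-sample buffer and made-up values so the wrap-around is easy to follow; the variable names mirror the fields above.

```java
import java.util.Arrays;

public class RoundRobinCopyDemo {
  public static void main(String[] args) {
    short[] recordingBuffer = new short[8]; // stands in for RECORDING_LENGTH samples
    int recordingOffset = 6;                // pretend 6 samples were written earlier
    short[] audioBuffer = {10, 11, 12, 13}; // 4 freshly captured samples

    // Write side, as in record(): split the new chunk around the end of the buffer.
    int numberRead = audioBuffer.length;
    int maxLength = recordingBuffer.length;
    int newRecordingOffset = recordingOffset + numberRead;              // 10
    int secondCopyLength = Math.max(0, newRecordingOffset - maxLength); // 2 samples wrap to the front
    int firstCopyLength = numberRead - secondCopyLength;                // 2 samples fit at the end
    System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength);
    System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength);
    recordingOffset = newRecordingOffset % maxLength;                   // next write starts at index 2

    // Read side, as in recognize(): unroll the ring into a linear buffer, oldest samples first.
    short[] inputBuffer = new short[maxLength];
    System.arraycopy(recordingBuffer, recordingOffset, inputBuffer, 0, maxLength - recordingOffset);
    System.arraycopy(recordingBuffer, 0, inputBuffer, maxLength - recordingOffset, recordingOffset);

    System.out.println(Arrays.toString(recordingBuffer)); // [12, 13, 0, 0, 0, 0, 10, 11]
    System.out.println(Arrays.toString(inputBuffer));     // [0, 0, 0, 0, 10, 11, 12, 13]
  }
}
```

Because the read side copies from recordingOffset around to the front, inputBuffer always ends with the most recently captured audio, which is what gets normalized and fed to the interpreter.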
- Thread.sleep(MINIMUM_TIME_BETWEEN_SAMPLES_MS); - } catch (InterruptedException e) { - // Ignore - } - } - - Log.v(LOG_TAG, "End recognition"); - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java deleted file mode 100644 index d75c3ceadab..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java +++ /dev/null @@ -1,209 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo; - -import android.content.res.AssetFileDescriptor; -import android.content.res.AssetManager; -import android.graphics.Bitmap; -import android.os.SystemClock; -import android.os.Trace; -import android.util.Log; -import java.io.BufferedReader; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; -import java.util.PriorityQueue; -import java.util.Vector; -import org.tensorflow.lite.Interpreter; - -/** A classifier specialized to label images using TensorFlow. */ -public class TFLiteImageClassifier implements Classifier { - private static final String TAG = "TFLiteImageClassifier"; - - // Only return this many results with at least this confidence. - private static final int MAX_RESULTS = 3; - - private Interpreter tfLite; - - /** Dimensions of inputs. */ - private static final int DIM_BATCH_SIZE = 1; - - private static final int DIM_PIXEL_SIZE = 3; - - private static final int DIM_IMG_SIZE_X = 224; - private static final int DIM_IMG_SIZE_Y = 224; - - byte[][] labelProb; - - // Pre-allocated buffers. - private Vector labels = new Vector(); - private int[] intValues; - private ByteBuffer imgData = null; - - private TFLiteImageClassifier() {} - - /** Memory-map the model file in Assets. */ - private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename) - throws IOException { - AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename); - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); - FileChannel fileChannel = inputStream.getChannel(); - long startOffset = fileDescriptor.getStartOffset(); - long declaredLength = fileDescriptor.getDeclaredLength(); - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - } - - /** - * Initializes a native TensorFlow session for classifying images. - * - * @param assetManager The asset manager to be used to load assets. - * @param modelFilename The filepath of the model GraphDef protocol buffer. - * @param labelFilename The filepath of label file for classes. 
- * @param inputSize The input size. A square image of inputSize x inputSize is assumed. - * @throws IOException - */ - public static Classifier create( - AssetManager assetManager, String modelFilename, String labelFilename, int inputSize) { - TFLiteImageClassifier c = new TFLiteImageClassifier(); - - // Read the label names into memory. - // TODO(andrewharp): make this handle non-assets. - Log.i(TAG, "Reading labels from: " + labelFilename); - BufferedReader br = null; - try { - br = new BufferedReader(new InputStreamReader(assetManager.open(labelFilename))); - String line; - while ((line = br.readLine()) != null) { - c.labels.add(line); - } - br.close(); - } catch (IOException e) { - throw new RuntimeException("Problem reading label file!" , e); - } - - c.imgData = - ByteBuffer.allocateDirect( - DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); - - c.imgData.order(ByteOrder.nativeOrder()); - try { - c.tfLite = new Interpreter(loadModelFile(assetManager, modelFilename)); - } catch (Exception e) { - throw new RuntimeException(e); - } - - // The shape of the output is [N, NUM_CLASSES], where N is the batch size. - Log.i(TAG, "Read " + c.labels.size() + " labels"); - - // Pre-allocate buffers. - c.intValues = new int[inputSize * inputSize]; - - c.labelProb = new byte[1][c.labels.size()]; - - return c; - } - - /** Writes Image data into a {@code ByteBuffer}. */ - private void convertBitmapToByteBuffer(Bitmap bitmap) { - if (imgData == null) { - return; - } - imgData.rewind(); - bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); - // Convert the image to floating point. - int pixel = 0; - long startTime = SystemClock.uptimeMillis(); - for (int i = 0; i < DIM_IMG_SIZE_X; ++i) { - for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) { - final int val = intValues[pixel++]; - imgData.put((byte) ((val >> 16) & 0xFF)); - imgData.put((byte) ((val >> 8) & 0xFF)); - imgData.put((byte) (val & 0xFF)); - } - } - long endTime = SystemClock.uptimeMillis(); - Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime)); - } - - @Override - public List recognizeImage(final Bitmap bitmap) { - // Log this method so that it can be analyzed with systrace. - Trace.beginSection("recognizeImage"); - - Trace.beginSection("preprocessBitmap"); - - long startTime; - long endTime; - startTime = SystemClock.uptimeMillis(); - - convertBitmapToByteBuffer(bitmap); - - // Run the inference call. - Trace.beginSection("run"); - startTime = SystemClock.uptimeMillis(); - tfLite.run(imgData, labelProb); - endTime = SystemClock.uptimeMillis(); - Log.i(TAG, "Inf time: " + (endTime - startTime)); - Trace.endSection(); - - // Find the best classifications. - PriorityQueue pq = - new PriorityQueue( - 3, - new Comparator() { - @Override - public int compare(Recognition lhs, Recognition rhs) { - // Intentionally reversed to put high confidence at the head of the queue. - return Float.compare(rhs.getConfidence(), lhs.getConfidence()); - } - }); - for (int i = 0; i < labels.size(); ++i) { - pq.add( - new Recognition( - "" + i, - labels.size() > i ? 
labels.get(i) : "unknown", - (float) labelProb[0][i], - null)); - } - final ArrayList recognitions = new ArrayList(); - int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); - for (int i = 0; i < recognitionsSize; ++i) { - recognitions.add(pq.poll()); - } - Trace.endSection(); // "recognizeImage" - return recognitions; - } - - @Override - public void enableStatLogging(boolean logStats) { - } - - @Override - public String getStatString() { - return ""; - } - - @Override - public void close() { - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java deleted file mode 100644 index afbf3178314..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java +++ /dev/null @@ -1,233 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo; - -import android.content.res.AssetFileDescriptor; -import android.content.res.AssetManager; -import android.graphics.Bitmap; -import android.graphics.RectF; -import android.os.Trace; -import java.io.BufferedReader; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Vector; -import org.tensorflow.demo.env.Logger; -import org.tensorflow.lite.Interpreter; - -/** - * Wrapper for frozen detection models trained using the Tensorflow Object Detection API: - * github.com/tensorflow/models/tree/master/research/object_detection - */ -public class TFLiteObjectDetectionAPIModel implements Classifier { - private static final Logger LOGGER = new Logger(); - - // Only return this many results. - private static final int NUM_DETECTIONS = 10; - private boolean isModelQuantized; - // Float model - private static final float IMAGE_MEAN = 128.0f; - private static final float IMAGE_STD = 128.0f; - // Number of threads in the java app - private static final int NUM_THREADS = 4; - // Config values. - private int inputSize; - // Pre-allocated buffers. 
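The TFLiteImageClassifier removed above wires a TFLite Interpreter to a quantized 224x224 image model loaded from assets. A rough sketch of how a caller would use it; the class and method come from this diff, but the asset file names and the helper class are placeholders, not files or code this change guarantees to exist.

```java
package org.tensorflow.demo;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import java.util.List;

/** Hypothetical caller for the removed TFLiteImageClassifier; names are illustrative. */
public final class ClassifierExample {
  private ClassifierExample() {}

  public static List<Classifier.Recognition> classify(AssetManager assets, Bitmap frame) {
    // Placeholder asset names; substitute whatever model and label files ship with the app.
    Classifier classifier =
        TFLiteImageClassifier.create(assets, "mobilenet_quant_v1_224.tflite", "labels.txt", 224);

    // The classifier reads a 224x224 bitmap (the DIM_IMG_SIZE_X/Y constants above).
    Bitmap input = Bitmap.createScaledBitmap(frame, 224, 224, false);
    List<Classifier.Recognition> results = classifier.recognizeImage(input);
    classifier.close();
    return results;
  }
}
```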
- private Vector labels = new Vector(); - private int[] intValues; - // outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4] - // contains the location of detected boxes - private float[][][] outputLocations; - // outputClasses: array of shape [Batchsize, NUM_DETECTIONS] - // contains the classes of detected boxes - private float[][] outputClasses; - // outputScores: array of shape [Batchsize, NUM_DETECTIONS] - // contains the scores of detected boxes - private float[][] outputScores; - // numDetections: array of shape [Batchsize] - // contains the number of detected boxes - private float[] numDetections; - - private ByteBuffer imgData; - - private Interpreter tfLite; - - - /** Memory-map the model file in Assets. */ - private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename) - throws IOException { - AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename); - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); - FileChannel fileChannel = inputStream.getChannel(); - long startOffset = fileDescriptor.getStartOffset(); - long declaredLength = fileDescriptor.getDeclaredLength(); - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - } - - /** - * Initializes a native TensorFlow session for classifying images. - * - * @param assetManager The asset manager to be used to load assets. - * @param modelFilename The filepath of the model GraphDef protocol buffer. - * @param labelFilename The filepath of label file for classes. - * @param inputSize The size of image input - * @param isQuantized Boolean representing model is quantized or not - */ - public static Classifier create( - final AssetManager assetManager, - final String modelFilename, - final String labelFilename, - final int inputSize, - final boolean isQuantized) - throws IOException { - final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel(); - - InputStream labelsInput = null; - labelsInput = assetManager.open(labelFilename); - BufferedReader br = null; - br = new BufferedReader(new InputStreamReader(labelsInput)); - String line; - while ((line = br.readLine()) != null) { - LOGGER.w(line); - d.labels.add(line); - } - br.close(); - - d.inputSize = inputSize; - - try { - d.tfLite = new Interpreter(loadModelFile(assetManager, modelFilename)); - } catch (Exception e) { - throw new RuntimeException(e); - } - - d.isModelQuantized = isQuantized; - // Pre-allocate buffers. - int numBytesPerChannel; - if (isQuantized) { - numBytesPerChannel = 1; // Quantized - } else { - numBytesPerChannel = 4; // Floating point - } - d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel); - d.imgData.order(ByteOrder.nativeOrder()); - d.intValues = new int[d.inputSize * d.inputSize]; - - d.tfLite.setNumThreads(NUM_THREADS); - d.outputLocations = new float[1][NUM_DETECTIONS][4]; - d.outputClasses = new float[1][NUM_DETECTIONS]; - d.outputScores = new float[1][NUM_DETECTIONS]; - d.numDetections = new float[1]; - return d; - } - - private TFLiteObjectDetectionAPIModel() {} - - @Override - public List recognizeImage(final Bitmap bitmap) { - // Log this method so that it can be analyzed with systrace. - Trace.beginSection("recognizeImage"); - - Trace.beginSection("preprocessBitmap"); - // Preprocess the image data from 0-255 int to normalized float based - // on the provided parameters. 
- bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); - - imgData.rewind(); - for (int i = 0; i < inputSize; ++i) { - for (int j = 0; j < inputSize; ++j) { - int pixelValue = intValues[i * inputSize + j]; - if (isModelQuantized) { - // Quantized model - imgData.put((byte) ((pixelValue >> 16) & 0xFF)); - imgData.put((byte) ((pixelValue >> 8) & 0xFF)); - imgData.put((byte) (pixelValue & 0xFF)); - } else { // Float model - imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); - imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); - imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); - } - } - } - Trace.endSection(); // preprocessBitmap - - // Copy the input data into TensorFlow. - Trace.beginSection("feed"); - outputLocations = new float[1][NUM_DETECTIONS][4]; - outputClasses = new float[1][NUM_DETECTIONS]; - outputScores = new float[1][NUM_DETECTIONS]; - numDetections = new float[1]; - - Object[] inputArray = {imgData}; - Map outputMap = new HashMap<>(); - outputMap.put(0, outputLocations); - outputMap.put(1, outputClasses); - outputMap.put(2, outputScores); - outputMap.put(3, numDetections); - Trace.endSection(); - - // Run the inference call. - Trace.beginSection("run"); - tfLite.runForMultipleInputsOutputs(inputArray, outputMap); - Trace.endSection(); - - // Show the best detections. - // after scaling them back to the input size. - final ArrayList recognitions = new ArrayList<>(NUM_DETECTIONS); - for (int i = 0; i < NUM_DETECTIONS; ++i) { - final RectF detection = - new RectF( - outputLocations[0][i][1] * inputSize, - outputLocations[0][i][0] * inputSize, - outputLocations[0][i][3] * inputSize, - outputLocations[0][i][2] * inputSize); - // SSD Mobilenet V1 Model assumes class 0 is background class - // in label file and class labels start from 1 to number_of_classes+1, - // while outputClasses correspond to class index from 0 to number_of_classes - int labelOffset = 1; - recognitions.add( - new Recognition( - "" + i, - labels.get((int) outputClasses[0][i] + labelOffset), - outputScores[0][i], - detection)); - } - Trace.endSection(); // "recognizeImage" - return recognitions; - } - - @Override - public void enableStatLogging(final boolean logStats) { - } - - @Override - public String getStatString() { - return ""; - } - - @Override - public void close() { - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java deleted file mode 100644 index c50efdf8891..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -package org.tensorflow.demo.env; - -import android.content.Context; -import android.content.res.AssetManager; -import android.util.Log; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; - -/** Utilities for dealing with assets. */ -public class AssetUtils { - - private static final String TAG = AssetUtils.class.getSimpleName(); - - private static final int BYTE_BUF_SIZE = 2048; - - /** - * Copies a file from assets. - * - * @param context application context used to discover assets. - * @param assetName the relative file name within assets. - * @param targetName the target file name, always over write the existing file. - * @throws IOException if operation fails. - */ - public static void copy(Context context, String assetName, String targetName) throws IOException { - - Log.d(TAG, "creating file " + targetName + " from " + assetName); - - File targetFile = null; - InputStream inputStream = null; - FileOutputStream outputStream = null; - - try { - AssetManager assets = context.getAssets(); - targetFile = new File(targetName); - inputStream = assets.open(assetName); - // TODO(kanlig): refactor log messages to make them more useful. - Log.d(TAG, "Creating outputstream"); - outputStream = new FileOutputStream(targetFile, false /* append */); - copy(inputStream, outputStream); - } finally { - if (outputStream != null) { - outputStream.close(); - } - if (inputStream != null) { - inputStream.close(); - } - } - } - - private static void copy(InputStream from, OutputStream to) throws IOException { - byte[] buf = new byte[BYTE_BUF_SIZE]; - while (true) { - int r = from.read(buf); - if (r == -1) { - break; - } - to.write(buf, 0, r); - } - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java deleted file mode 100644 index decfc3d8793..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo.env; - -import android.graphics.Canvas; -import android.graphics.Color; -import android.graphics.Paint; -import android.graphics.Paint.Align; -import android.graphics.Paint.Style; -import android.graphics.Rect; -import android.graphics.Typeface; -import java.util.Vector; - -/** - * A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas. 
- */ -public class BorderedText { - private final Paint interiorPaint; - private final Paint exteriorPaint; - - private final float textSize; - - /** - * Creates a left-aligned bordered text object with a white interior, and a black exterior with - * the specified text size. - * - * @param textSize text size in pixels - */ - public BorderedText(final float textSize) { - this(Color.WHITE, Color.BLACK, textSize); - } - - /** - * Create a bordered text object with the specified interior and exterior colors, text size and - * alignment. - * - * @param interiorColor the interior text color - * @param exteriorColor the exterior text color - * @param textSize text size in pixels - */ - public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) { - interiorPaint = new Paint(); - interiorPaint.setTextSize(textSize); - interiorPaint.setColor(interiorColor); - interiorPaint.setStyle(Style.FILL); - interiorPaint.setAntiAlias(false); - interiorPaint.setAlpha(255); - - exteriorPaint = new Paint(); - exteriorPaint.setTextSize(textSize); - exteriorPaint.setColor(exteriorColor); - exteriorPaint.setStyle(Style.FILL_AND_STROKE); - exteriorPaint.setStrokeWidth(textSize / 8); - exteriorPaint.setAntiAlias(false); - exteriorPaint.setAlpha(255); - - this.textSize = textSize; - } - - public void setTypeface(Typeface typeface) { - interiorPaint.setTypeface(typeface); - exteriorPaint.setTypeface(typeface); - } - - public void drawText(final Canvas canvas, final float posX, final float posY, final String text) { - canvas.drawText(text, posX, posY, exteriorPaint); - canvas.drawText(text, posX, posY, interiorPaint); - } - - public void drawLines(Canvas canvas, final float posX, final float posY, Vector lines) { - int lineNum = 0; - for (final String line : lines) { - drawText(canvas, posX, posY - getTextSize() * (lines.size() - lineNum - 1), line); - ++lineNum; - } - } - - public void setInteriorColor(final int color) { - interiorPaint.setColor(color); - } - - public void setExteriorColor(final int color) { - exteriorPaint.setColor(color); - } - - public float getTextSize() { - return textSize; - } - - public void setAlpha(final int alpha) { - interiorPaint.setAlpha(alpha); - exteriorPaint.setAlpha(alpha); - } - - public void getTextBounds( - final String line, final int index, final int count, final Rect lineBounds) { - interiorPaint.getTextBounds(line, index, count, lineBounds); - } - - public void setTextAlign(final Align align) { - interiorPaint.setTextAlign(align); - exteriorPaint.setTextAlign(align); - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java deleted file mode 100644 index e02c6559176..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java +++ /dev/null @@ -1,344 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -package org.tensorflow.demo.env; - -import android.graphics.Bitmap; -import android.graphics.Matrix; -import android.os.Environment; -import java.io.File; -import java.io.FileOutputStream; - -/** - * Utility class for manipulating images. - **/ -public class ImageUtils { - @SuppressWarnings("unused") - private static final Logger LOGGER = new Logger(); - - static { - try { - System.loadLibrary("tensorflow_demo"); - } catch (UnsatisfiedLinkError e) { - LOGGER.w("Native library not found, native RGB -> YUV conversion may be unavailable."); - } - } - - /** - * Utility method to compute the allocated size in bytes of a YUV420SP image - * of the given dimensions. - */ - public static int getYUVByteSize(final int width, final int height) { - // The luminance plane requires 1 byte per pixel. - final int ySize = width * height; - - // The UV plane works on 2x2 blocks, so dimensions with odd size must be rounded up. - // Each 2x2 block takes 2 bytes to encode, one each for U and V. - final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2; - - return ySize + uvSize; - } - - /** - * Saves a Bitmap object to disk for analysis. - * - * @param bitmap The bitmap to save. - */ - public static void saveBitmap(final Bitmap bitmap) { - saveBitmap(bitmap, "preview.png"); - } - - /** - * Saves a Bitmap object to disk for analysis. - * - * @param bitmap The bitmap to save. - * @param filename The location to save the bitmap to. - */ - public static void saveBitmap(final Bitmap bitmap, final String filename) { - final String root = - Environment.getExternalStorageDirectory().getAbsolutePath() + File.separator + "tensorflow"; - LOGGER.i("Saving %dx%d bitmap to %s.", bitmap.getWidth(), bitmap.getHeight(), root); - final File myDir = new File(root); - - if (!myDir.mkdirs()) { - LOGGER.i("Make dir failed"); - } - - final String fname = filename; - final File file = new File(myDir, fname); - if (file.exists()) { - file.delete(); - } - try { - final FileOutputStream out = new FileOutputStream(file); - bitmap.compress(Bitmap.CompressFormat.PNG, 99, out); - out.flush(); - out.close(); - } catch (final Exception e) { - LOGGER.e(e, "Exception!"); - } - } - - // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges - // are normalized to eight bits. - static final int kMaxChannelValue = 262143; - - // Always prefer the native implementation if available. - private static boolean useNativeConversion = false; - - public static void convertYUV420SPToARGB8888( - byte[] input, - int width, - int height, - int[] output) { - if (useNativeConversion) { - try { - ImageUtils.convertYUV420SPToARGB8888(input, output, width, height, false); - return; - } catch (UnsatisfiedLinkError e) { - LOGGER.w( - "Native YUV420SP -> RGB implementation not found, falling back to Java implementation"); - useNativeConversion = false; - } - } - - // Java implementation of YUV420SP to ARGB8888 converting - final int frameSize = width * height; - for (int j = 0, yp = 0; j < height; j++) { - int uvp = frameSize + (j >> 1) * width; - int u = 0; - int v = 0; - - for (int i = 0; i < width; i++, yp++) { - int y = 0xff & input[yp]; - if ((i & 1) == 0) { - v = 0xff & input[uvp++]; - u = 0xff & input[uvp++]; - } - - output[yp] = YUV2RGB(y, u, v); - } - } - } - - private static int YUV2RGB(int y, int u, int v) { - // Adjust and check YUV values - y = (y - 16) < 0 ? 
0 : (y - 16); - u -= 128; - v -= 128; - - // This is the floating point equivalent. We do the conversion in integer - // because some Android devices do not have floating point in hardware. - // nR = (int)(1.164 * nY + 2.018 * nU); - // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); - // nB = (int)(1.164 * nY + 1.596 * nV); - int y1192 = 1192 * y; - int r = (y1192 + 1634 * v); - int g = (y1192 - 833 * v - 400 * u); - int b = (y1192 + 2066 * u); - - // Clipping RGB values to be inside boundaries [ 0 , kMaxChannelValue ] - r = r > kMaxChannelValue ? kMaxChannelValue : (r < 0 ? 0 : r); - g = g > kMaxChannelValue ? kMaxChannelValue : (g < 0 ? 0 : g); - b = b > kMaxChannelValue ? kMaxChannelValue : (b < 0 ? 0 : b); - - return 0xff000000 | ((r << 6) & 0xff0000) | ((g >> 2) & 0xff00) | ((b >> 10) & 0xff); - } - - - public static void convertYUV420ToARGB8888( - byte[] yData, - byte[] uData, - byte[] vData, - int width, - int height, - int yRowStride, - int uvRowStride, - int uvPixelStride, - int[] out) { - if (useNativeConversion) { - try { - convertYUV420ToARGB8888( - yData, uData, vData, out, width, height, yRowStride, uvRowStride, uvPixelStride, false); - return; - } catch (UnsatisfiedLinkError e) { - LOGGER.w( - "Native YUV420 -> RGB implementation not found, falling back to Java implementation"); - useNativeConversion = false; - } - } - - int yp = 0; - for (int j = 0; j < height; j++) { - int pY = yRowStride * j; - int pUV = uvRowStride * (j >> 1); - - for (int i = 0; i < width; i++) { - int uv_offset = pUV + (i >> 1) * uvPixelStride; - - out[yp++] = YUV2RGB( - 0xff & yData[pY + i], - 0xff & uData[uv_offset], - 0xff & vData[uv_offset]); - } - } - } - - - /** - * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width and height. The - * input and output must already be allocated and non-null. For efficiency, no error checking is - * performed. - * - * @param input The array of YUV 4:2:0 input data. - * @param output A pre-allocated array for the ARGB 8:8:8:8 output data. - * @param width The width of the input image. - * @param height The height of the input image. - * @param halfSize If true, downsample to 50% in each dimension, otherwise not. - */ - private static native void convertYUV420SPToARGB8888( - byte[] input, int[] output, int width, int height, boolean halfSize); - - /** - * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width - * and height. The input and output must already be allocated and non-null. - * For efficiency, no error checking is performed. - * - * @param y - * @param u - * @param v - * @param uvPixelStride - * @param width The width of the input image. - * @param height The height of the input image. - * @param halfSize If true, downsample to 50% in each dimension, otherwise not. - * @param output A pre-allocated array for the ARGB 8:8:8:8 output data. - */ - private static native void convertYUV420ToARGB8888( - byte[] y, - byte[] u, - byte[] v, - int[] output, - int width, - int height, - int yRowStride, - int uvRowStride, - int uvPixelStride, - boolean halfSize); - - /** - * Converts YUV420 semi-planar data to RGB 565 data using the supplied width - * and height. The input and output must already be allocated and non-null. - * For efficiency, no error checking is performed. - * - * @param input The array of YUV 4:2:0 input data. - * @param output A pre-allocated array for the RGB 5:6:5 output data. - * @param width The width of the input image. - * @param height The height of the input image. 
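The integer YUV2RGB() conversion above implements the floating-point formula quoted in its own comment in fixed point: the coefficients are scaled by 1024 (1192 ≈ 1.164 * 1024, 1634 ≈ 1.596 * 1024, and so on), results are clamped to kMaxChannelValue, and the final shifts drop the scale while packing ARGB. A quick plain-Java check of one sample value, with the conversion restated from the method above using Math.min/Math.max clamps:

```java
public class Yuv2RgbCheck {
  // 2^18 - 1, the same clamp as ImageUtils.kMaxChannelValue above.
  static final int kMaxChannelValue = 262143;

  // The removed integer conversion, lightly restructured but numerically identical.
  static int yuv2rgb(int y, int u, int v) {
    y = Math.max(y - 16, 0);
    u -= 128;
    v -= 128;
    int y1192 = 1192 * y;
    int r = y1192 + 1634 * v;
    int g = y1192 - 833 * v - 400 * u;
    int b = y1192 + 2066 * u;
    r = Math.min(Math.max(r, 0), kMaxChannelValue);
    g = Math.min(Math.max(g, 0), kMaxChannelValue);
    b = Math.min(Math.max(b, 0), kMaxChannelValue);
    return 0xff000000 | ((r << 6) & 0xff0000) | ((g >> 2) & 0xff00) | ((b >> 10) & 0xff);
  }

  public static void main(String[] args) {
    // Mid-gray sample: Y = 128, U = V = 128 (no chroma).
    int argb = yuv2rgb(128, 128, 128);
    int red = (argb >> 16) & 0xff; // 130
    // Floating-point reference from the comment above: R = 1.164 * (Y - 16) = 130.4.
    System.out.printf("integer red=%d, float red=%.1f%n", red, 1.164 * (128 - 16));
  }
}
```

Relatedly, getYUVByteSize() above sizes the camera callback buffer as width*height + 2*ceil(width/2)*ceil(height/2); for a 640x480 preview that is 307200 + 153600 = 460800 bytes.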
- */ - private static native void convertYUV420SPToRGB565( - byte[] input, byte[] output, int width, int height); - - /** - * Converts 32-bit ARGB8888 image data to YUV420SP data. This is useful, for - * instance, in creating data to feed the classes that rely on raw camera - * preview frames. - * - * @param input An array of input pixels in ARGB8888 format. - * @param output A pre-allocated array for the YUV420SP output data. - * @param width The width of the input image. - * @param height The height of the input image. - */ - private static native void convertARGB8888ToYUV420SP( - int[] input, byte[] output, int width, int height); - - /** - * Converts 16-bit RGB565 image data to YUV420SP data. This is useful, for - * instance, in creating data to feed the classes that rely on raw camera - * preview frames. - * - * @param input An array of input pixels in RGB565 format. - * @param output A pre-allocated array for the YUV420SP output data. - * @param width The width of the input image. - * @param height The height of the input image. - */ - private static native void convertRGB565ToYUV420SP( - byte[] input, byte[] output, int width, int height); - - /** - * Returns a transformation matrix from one reference frame into another. - * Handles cropping (if maintaining aspect ratio is desired) and rotation. - * - * @param srcWidth Width of source frame. - * @param srcHeight Height of source frame. - * @param dstWidth Width of destination frame. - * @param dstHeight Height of destination frame. - * @param applyRotation Amount of rotation to apply from one frame to another. - * Must be a multiple of 90. - * @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant, - * cropping the image if necessary. - * @return The transformation fulfilling the desired requirements. - */ - public static Matrix getTransformationMatrix( - final int srcWidth, - final int srcHeight, - final int dstWidth, - final int dstHeight, - final int applyRotation, - final boolean maintainAspectRatio) { - final Matrix matrix = new Matrix(); - - if (applyRotation != 0) { - if (applyRotation % 90 != 0) { - LOGGER.w("Rotation of %d % 90 != 0", applyRotation); - } - - // Translate so center of image is at origin. - matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f); - - // Rotate around origin. - matrix.postRotate(applyRotation); - } - - // Account for the already applied rotation, if any, and then determine how - // much scaling is needed for each axis. - final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0; - - final int inWidth = transpose ? srcHeight : srcWidth; - final int inHeight = transpose ? srcWidth : srcHeight; - - // Apply scaling if necessary. - if (inWidth != dstWidth || inHeight != dstHeight) { - final float scaleFactorX = dstWidth / (float) inWidth; - final float scaleFactorY = dstHeight / (float) inHeight; - - if (maintainAspectRatio) { - // Scale by minimum factor so that dst is filled completely while - // maintaining the aspect ratio. Some image may fall off the edge. - final float scaleFactor = Math.max(scaleFactorX, scaleFactorY); - matrix.postScale(scaleFactor, scaleFactor); - } else { - // Scale exactly to fill dst from src. - matrix.postScale(scaleFactorX, scaleFactorY); - } - } - - if (applyRotation != 0) { - // Translate back from origin centered reference to destination frame. 
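getTransformationMatrix() above only builds the Matrix; a caller typically applies it by drawing the source frame through the matrix onto a destination bitmap sized for the model input. A sketch under assumed sizes: the 300x300 crop, the 90-degree rotation, and the helper class name are illustrative choices, not values taken from this diff.

```java
package org.tensorflow.demo;

import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import org.tensorflow.demo.env.ImageUtils;

/** Hypothetical helper showing one way to apply the removed getTransformationMatrix(). */
public final class FrameCropper {
  private FrameCropper() {}

  /** Rotates and crops a camera frame (e.g. a 640x480 preview) into a 300x300 model input. */
  public static Bitmap toModelInput(Bitmap cameraFrame) {
    Matrix frameToCrop =
        ImageUtils.getTransformationMatrix(
            cameraFrame.getWidth(),
            cameraFrame.getHeight(),
            /* dstWidth= */ 300,
            /* dstHeight= */ 300,
            /* applyRotation= */ 90,
            /* maintainAspectRatio= */ true);

    Bitmap modelInput = Bitmap.createBitmap(300, 300, Bitmap.Config.ARGB_8888);
    new Canvas(modelInput).drawBitmap(cameraFrame, frameToCrop, null);
    return modelInput;
  }
}
```

With maintainAspectRatio set to true the frame is scaled by the larger of the two axis factors, so the 300x300 crop is filled completely and any excess falls off the edges, as the javadoc above describes.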
- matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f); - } - - return matrix; - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java deleted file mode 100644 index 0d984096a08..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo.env; - -import android.util.Log; -import java.util.HashSet; -import java.util.Set; - -/** - * Wrapper for the platform log function, allows convenient message prefixing and log disabling. - */ -public final class Logger { - private static final String DEFAULT_TAG = "tensorflow"; - private static final int DEFAULT_MIN_LOG_LEVEL = Log.DEBUG; - - // Classes to be ignored when examining the stack trace - private static final Set IGNORED_CLASS_NAMES; - - static { - IGNORED_CLASS_NAMES = new HashSet(3); - IGNORED_CLASS_NAMES.add("dalvik.system.VMStack"); - IGNORED_CLASS_NAMES.add("java.lang.Thread"); - IGNORED_CLASS_NAMES.add(Logger.class.getCanonicalName()); - } - - private final String tag; - private final String messagePrefix; - private int minLogLevel = DEFAULT_MIN_LOG_LEVEL; - - /** - * Creates a Logger using the class name as the message prefix. - * - * @param clazz the simple name of this class is used as the message prefix. - */ - public Logger(final Class clazz) { - this(clazz.getSimpleName()); - } - - /** - * Creates a Logger using the specified message prefix. - * - * @param messagePrefix is prepended to the text of every message. - */ - public Logger(final String messagePrefix) { - this(DEFAULT_TAG, messagePrefix); - } - - /** - * Creates a Logger with a custom tag and a custom message prefix. If the message prefix - * is set to
null
, the caller's class name is used as the prefix. - * - * @param tag identifies the source of a log message. - * @param messagePrefix prepended to every message if non-null. If null, the name of the caller is - * being used - */ - public Logger(final String tag, final String messagePrefix) { - this.tag = tag; - final String prefix = messagePrefix == null ? getCallerSimpleName() : messagePrefix; - this.messagePrefix = (prefix.length() > 0) ? prefix + ": " : prefix; - } - - /** - * Creates a Logger using the caller's class name as the message prefix. - */ - public Logger() { - this(DEFAULT_TAG, null); - } - - /** - * Creates a Logger using the caller's class name as the message prefix. - */ - public Logger(final int minLogLevel) { - this(DEFAULT_TAG, null); - this.minLogLevel = minLogLevel; - } - - public void setMinLogLevel(final int minLogLevel) { - this.minLogLevel = minLogLevel; - } - - public boolean isLoggable(final int logLevel) { - return logLevel >= minLogLevel || Log.isLoggable(tag, logLevel); - } - - /** - * Return caller's simple name. - * - * Android getStackTrace() returns an array that looks like this: - * stackTrace[0]: dalvik.system.VMStack - * stackTrace[1]: java.lang.Thread - * stackTrace[2]: com.google.android.apps.unveil.env.UnveilLogger - * stackTrace[3]: com.google.android.apps.unveil.BaseApplication - * - * This function returns the simple version of the first non-filtered name. - * - * @return caller's simple name - */ - private static String getCallerSimpleName() { - // Get the current callstack so we can pull the class of the caller off of it. - final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); - - for (final StackTraceElement elem : stackTrace) { - final String className = elem.getClassName(); - if (!IGNORED_CLASS_NAMES.contains(className)) { - // We're only interested in the simple name of the class, not the complete package. - final String[] classParts = className.split("\\."); - return classParts[classParts.length - 1]; - } - } - - return Logger.class.getSimpleName(); - } - - private String toMessage(final String format, final Object... args) { - return messagePrefix + (args.length > 0 ? String.format(format, args) : format); - } - - public void v(final String format, final Object... args) { - if (isLoggable(Log.VERBOSE)) { - Log.v(tag, toMessage(format, args)); - } - } - - public void v(final Throwable t, final String format, final Object... args) { - if (isLoggable(Log.VERBOSE)) { - Log.v(tag, toMessage(format, args), t); - } - } - - public void d(final String format, final Object... args) { - if (isLoggable(Log.DEBUG)) { - Log.d(tag, toMessage(format, args)); - } - } - - public void d(final Throwable t, final String format, final Object... args) { - if (isLoggable(Log.DEBUG)) { - Log.d(tag, toMessage(format, args), t); - } - } - - public void i(final String format, final Object... args) { - if (isLoggable(Log.INFO)) { - Log.i(tag, toMessage(format, args)); - } - } - - public void i(final Throwable t, final String format, final Object... args) { - if (isLoggable(Log.INFO)) { - Log.i(tag, toMessage(format, args), t); - } - } - - public void w(final String format, final Object... args) { - if (isLoggable(Log.WARN)) { - Log.w(tag, toMessage(format, args)); - } - } - - public void w(final Throwable t, final String format, final Object... args) { - if (isLoggable(Log.WARN)) { - Log.w(tag, toMessage(format, args), t); - } - } - - public void e(final String format, final Object... 
args) { - if (isLoggable(Log.ERROR)) { - Log.e(tag, toMessage(format, args)); - } - } - - public void e(final Throwable t, final String format, final Object... args) { - if (isLoggable(Log.ERROR)) { - Log.e(tag, toMessage(format, args), t); - } - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java deleted file mode 100644 index ef15d14daa8..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo.env; - -import android.graphics.Bitmap; -import android.text.TextUtils; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -/** - * Size class independent of a Camera object. - */ -public class Size implements Comparable, Serializable { - - // 1.4 went out with this UID so we'll need to maintain it to preserve pending queries when - // upgrading. - public static final long serialVersionUID = 7689808733290872361L; - - public final int width; - public final int height; - - public Size(final int width, final int height) { - this.width = width; - this.height = height; - } - - public Size(final Bitmap bmp) { - this.width = bmp.getWidth(); - this.height = bmp.getHeight(); - } - - /** - * Rotate a size by the given number of degrees. - * @param size Size to rotate. - * @param rotation Degrees {0, 90, 180, 270} to rotate the size. - * @return Rotated size. - */ - public static Size getRotatedSize(final Size size, final int rotation) { - if (rotation % 180 != 0) { - // The phone is portrait, therefore the camera is sideways and frame should be rotated. - return new Size(size.height, size.width); - } - return size; - } - - public static Size parseFromString(String sizeString) { - if (TextUtils.isEmpty(sizeString)) { - return null; - } - - sizeString = sizeString.trim(); - - // The expected format is "x". 
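A small sketch of how the string helpers in this class fit together, assuming the usual generic signatures; the literal size strings are placeholders:

```java
import java.util.List;
import org.tensorflow.demo.env.Size;

final class SizeSketch {
  static void demo() {
    // parseFromString expects "<width>x<height>" and returns null on malformed input.
    Size preview = Size.parseFromString("640x480");
    // Comma-separated lists round-trip through the two list helpers.
    List<Size> sizes = Size.sizeStringToList("640x480,1280x720");
    String joined = Size.sizeListToString(sizes);  // "640x480,1280x720"
  }
}
```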
- final String[] components = sizeString.split("x"); - if (components.length == 2) { - try { - final int width = Integer.parseInt(components[0]); - final int height = Integer.parseInt(components[1]); - return new Size(width, height); - } catch (final NumberFormatException e) { - return null; - } - } else { - return null; - } - } - - public static List sizeStringToList(final String sizes) { - final List sizeList = new ArrayList(); - if (sizes != null) { - final String[] pairs = sizes.split(","); - for (final String pair : pairs) { - final Size size = Size.parseFromString(pair); - if (size != null) { - sizeList.add(size); - } - } - } - return sizeList; - } - - public static String sizeListToString(final List sizes) { - String sizesString = ""; - if (sizes != null && sizes.size() > 0) { - sizesString = sizes.get(0).toString(); - for (int i = 1; i < sizes.size(); i++) { - sizesString += "," + sizes.get(i).toString(); - } - } - return sizesString; - } - - public final float aspectRatio() { - return (float) width / (float) height; - } - - @Override - public int compareTo(final Size other) { - return width * height - other.width * other.height; - } - - @Override - public boolean equals(final Object other) { - if (other == null) { - return false; - } - - if (!(other instanceof Size)) { - return false; - } - - final Size otherSize = (Size) other; - return (width == otherSize.width && height == otherSize.height); - } - - @Override - public int hashCode() { - return width * 32713 + height; - } - - @Override - public String toString() { - return dimensionsAsString(width, height); - } - - public static final String dimensionsAsString(final int width, final int height) { - return width + "x" + height; - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java deleted file mode 100644 index 459b0a0d4db..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo.env; - -import android.os.SystemClock; - -/** - * A simple utility timer for measuring CPU time and wall-clock splits. 
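A minimal sketch of how this timer is driven; the work being timed and the split names are placeholders:

```java
import org.tensorflow.demo.env.SplitTimer;

final class TimerSketch {
  static void timeWork(Runnable inference) {
    SplitTimer timer = new SplitTimer("detector");  // the constructor starts the first split
    inference.run();                                // placeholder for the timed work
    timer.endSplit("inference");                    // logs CPU and wall time since the last split
    // Each subsequent split measures from the previous endSplit() call.
    timer.endSplit("post-processing");
  }
}
```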
- */ -public class SplitTimer { - private final Logger logger; - - private long lastWallTime; - private long lastCpuTime; - - public SplitTimer(final String name) { - logger = new Logger(name); - newSplit(); - } - - public void newSplit() { - lastWallTime = SystemClock.uptimeMillis(); - lastCpuTime = SystemClock.currentThreadTimeMillis(); - } - - public void endSplit(final String splitName) { - final long currWallTime = SystemClock.uptimeMillis(); - final long currCpuTime = SystemClock.currentThreadTimeMillis(); - - logger.i( - "%s: cpu=%dms wall=%dms", - splitName, currCpuTime - lastCpuTime, currWallTime - lastWallTime); - - lastWallTime = currWallTime; - lastCpuTime = currCpuTime; - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java deleted file mode 100644 index af6af2bc8f5..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java +++ /dev/null @@ -1,421 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.demo.tracking; - -import android.content.Context; -import android.graphics.Canvas; -import android.graphics.Color; -import android.graphics.Matrix; -import android.graphics.Paint; -import android.graphics.Paint.Cap; -import android.graphics.Paint.Join; -import android.graphics.Paint.Style; -import android.graphics.RectF; -import android.text.TextUtils; -import android.util.Pair; -import android.util.TypedValue; -import android.widget.Toast; -import java.util.LinkedList; -import java.util.List; -import java.util.Queue; -import org.tensorflow.demo.Classifier.Recognition; -import org.tensorflow.demo.env.BorderedText; -import org.tensorflow.demo.env.ImageUtils; -import org.tensorflow.demo.env.Logger; - -/** - * A tracker wrapping ObjectTracker that also handles non-max suppression and matching existing - * objects to new detections. - */ -public class MultiBoxTracker { - private final Logger logger = new Logger(); - - private static final float TEXT_SIZE_DIP = 18; - - // Maximum percentage of a box that can be overlapped by another box at detection time. Otherwise - // the lower scored box (new or old) will be removed. - private static final float MAX_OVERLAP = 0.2f; - - private static final float MIN_SIZE = 16.0f; - - // Allow replacement of the tracked box with new results if - // correlation has dropped below this level. - private static final float MARGINAL_CORRELATION = 0.75f; - - // Consider object to be lost if correlation falls below this threshold. 
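To make the two correlation thresholds concrete, a small sketch of the policy they encode; the helper class and method names are assumptions:

```java
final class CorrelationPolicySketch {
  // Values mirror the MARGINAL_CORRELATION and MIN_CORRELATION constants in MultiBoxTracker.
  static final float MARGINAL_CORRELATION = 0.75f;
  static final float MIN_CORRELATION = 0.3f;

  /** A weakly correlated track may be replaced by a fresh detection. */
  static boolean canBeReplaced(float correlation) {
    return correlation < MARGINAL_CORRELATION;
  }

  /** A track below the minimum correlation is considered lost and dropped. */
  static boolean isLost(float correlation) {
    return correlation < MIN_CORRELATION;
  }
}
```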
- private static final float MIN_CORRELATION = 0.3f; - - private static final int[] COLORS = { - Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE, - Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"), - Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"), - Color.parseColor("#AA33AA"), Color.parseColor("#0D0068") - }; - - private final Queue availableColors = new LinkedList(); - - public ObjectTracker objectTracker; - - final List> screenRects = new LinkedList>(); - - private static class TrackedRecognition { - ObjectTracker.TrackedObject trackedObject; - RectF location; - float detectionConfidence; - int color; - String title; - } - - private final List trackedObjects = new LinkedList(); - - private final Paint boxPaint = new Paint(); - - private final float textSizePx; - private final BorderedText borderedText; - - private Matrix frameToCanvasMatrix; - - private int frameWidth; - private int frameHeight; - - private int sensorOrientation; - private Context context; - - public MultiBoxTracker(final Context context) { - this.context = context; - for (final int color : COLORS) { - availableColors.add(color); - } - - boxPaint.setColor(Color.RED); - boxPaint.setStyle(Style.STROKE); - boxPaint.setStrokeWidth(12.0f); - boxPaint.setStrokeCap(Cap.ROUND); - boxPaint.setStrokeJoin(Join.ROUND); - boxPaint.setStrokeMiter(100); - - textSizePx = - TypedValue.applyDimension( - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics()); - borderedText = new BorderedText(textSizePx); - } - - private Matrix getFrameToCanvasMatrix() { - return frameToCanvasMatrix; - } - - public synchronized void drawDebug(final Canvas canvas) { - final Paint textPaint = new Paint(); - textPaint.setColor(Color.WHITE); - textPaint.setTextSize(60.0f); - - final Paint boxPaint = new Paint(); - boxPaint.setColor(Color.RED); - boxPaint.setAlpha(200); - boxPaint.setStyle(Style.STROKE); - - for (final Pair detection : screenRects) { - final RectF rect = detection.second; - canvas.drawRect(rect, boxPaint); - canvas.drawText("" + detection.first, rect.left, rect.top, textPaint); - borderedText.drawText(canvas, rect.centerX(), rect.centerY(), "" + detection.first); - } - - if (objectTracker == null) { - return; - } - - // Draw correlations. - for (final TrackedRecognition recognition : trackedObjects) { - final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; - - final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame(); - - if (getFrameToCanvasMatrix().mapRect(trackedPos)) { - final String labelString = String.format("%.2f", trackedObject.getCurrentCorrelation()); - borderedText.drawText(canvas, trackedPos.right, trackedPos.bottom, labelString); - } - } - - final Matrix matrix = getFrameToCanvasMatrix(); - objectTracker.drawDebug(canvas, matrix); - } - - public synchronized void trackResults( - final List results, final byte[] frame, final long timestamp) { - logger.i("Processing %d results from %d", results.size(), timestamp); - processResults(timestamp, results, frame); - } - - public synchronized void draw(final Canvas canvas) { - final boolean rotated = sensorOrientation % 180 == 90; - final float multiplier = - Math.min(canvas.getHeight() / (float) (rotated ? frameWidth : frameHeight), - canvas.getWidth() / (float) (rotated ? 
frameHeight : frameWidth)); - frameToCanvasMatrix = - ImageUtils.getTransformationMatrix( - frameWidth, - frameHeight, - (int) (multiplier * (rotated ? frameHeight : frameWidth)), - (int) (multiplier * (rotated ? frameWidth : frameHeight)), - sensorOrientation, - false); - for (final TrackedRecognition recognition : trackedObjects) { - final RectF trackedPos = - (objectTracker != null) - ? recognition.trackedObject.getTrackedPositionInPreviewFrame() - : new RectF(recognition.location); - - getFrameToCanvasMatrix().mapRect(trackedPos); - boxPaint.setColor(recognition.color); - - final float cornerSize = Math.min(trackedPos.width(), trackedPos.height()) / 8.0f; - canvas.drawRoundRect(trackedPos, cornerSize, cornerSize, boxPaint); - - final String labelString = - !TextUtils.isEmpty(recognition.title) - ? String.format("%s %.2f", recognition.title, recognition.detectionConfidence) - : String.format("%.2f", recognition.detectionConfidence); - borderedText.drawText(canvas, trackedPos.left + cornerSize, trackedPos.bottom, labelString); - } - } - - private boolean initialized = false; - - public synchronized void onFrame( - final int w, - final int h, - final int rowStride, - final int sensorOrientation, - final byte[] frame, - final long timestamp) { - if (objectTracker == null && !initialized) { - ObjectTracker.clearInstance(); - - logger.i("Initializing ObjectTracker: %dx%d", w, h); - objectTracker = ObjectTracker.getInstance(w, h, rowStride, true); - frameWidth = w; - frameHeight = h; - this.sensorOrientation = sensorOrientation; - initialized = true; - - if (objectTracker == null) { - String message = - "Object tracking support not found. " - + "See tensorflow/examples/android/README.md for details."; - Toast.makeText(context, message, Toast.LENGTH_LONG).show(); - logger.e(message); - } - } - - if (objectTracker == null) { - return; - } - - objectTracker.nextFrame(frame, null, timestamp, null, true); - - // Clean up any objects not worth tracking any more. - final LinkedList copyList = - new LinkedList(trackedObjects); - for (final TrackedRecognition recognition : copyList) { - final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; - final float correlation = trackedObject.getCurrentCorrelation(); - if (correlation < MIN_CORRELATION) { - logger.v("Removing tracked object %s because NCC is %.2f", trackedObject, correlation); - trackedObject.stopTracking(); - trackedObjects.remove(recognition); - - availableColors.add(recognition.color); - } - } - } - - private void processResults( - final long timestamp, final List results, final byte[] originalFrame) { - final List> rectsToTrack = new LinkedList>(); - - screenRects.clear(); - final Matrix rgbFrameToScreen = new Matrix(getFrameToCanvasMatrix()); - - for (final Recognition result : results) { - if (result.getLocation() == null) { - continue; - } - final RectF detectionFrameRect = new RectF(result.getLocation()); - - final RectF detectionScreenRect = new RectF(); - rgbFrameToScreen.mapRect(detectionScreenRect, detectionFrameRect); - - logger.v( - "Result! Frame: " + result.getLocation() + " mapped to screen:" + detectionScreenRect); - - screenRects.add(new Pair(result.getConfidence(), detectionScreenRect)); - - if (detectionFrameRect.width() < MIN_SIZE || detectionFrameRect.height() < MIN_SIZE) { - logger.w("Degenerate rectangle! 
" + detectionFrameRect); - continue; - } - - rectsToTrack.add(new Pair(result.getConfidence(), result)); - } - - if (rectsToTrack.isEmpty()) { - logger.v("Nothing to track, aborting."); - return; - } - - if (objectTracker == null) { - trackedObjects.clear(); - for (final Pair potential : rectsToTrack) { - final TrackedRecognition trackedRecognition = new TrackedRecognition(); - trackedRecognition.detectionConfidence = potential.first; - trackedRecognition.location = new RectF(potential.second.getLocation()); - trackedRecognition.trackedObject = null; - trackedRecognition.title = potential.second.getTitle(); - trackedRecognition.color = COLORS[trackedObjects.size()]; - trackedObjects.add(trackedRecognition); - - if (trackedObjects.size() >= COLORS.length) { - break; - } - } - return; - } - - logger.i("%d rects to track", rectsToTrack.size()); - for (final Pair potential : rectsToTrack) { - handleDetection(originalFrame, timestamp, potential); - } - } - - private void handleDetection( - final byte[] frameCopy, final long timestamp, final Pair potential) { - final ObjectTracker.TrackedObject potentialObject = - objectTracker.trackObject(potential.second.getLocation(), timestamp, frameCopy); - - final float potentialCorrelation = potentialObject.getCurrentCorrelation(); - logger.v( - "Tracked object went from %s to %s with correlation %.2f", - potential.second, potentialObject.getTrackedPositionInPreviewFrame(), potentialCorrelation); - - if (potentialCorrelation < MARGINAL_CORRELATION) { - logger.v("Correlation too low to begin tracking %s.", potentialObject); - potentialObject.stopTracking(); - return; - } - - final List removeList = new LinkedList(); - - float maxIntersect = 0.0f; - - // This is the current tracked object whose color we will take. If left null we'll take the - // first one from the color queue. - TrackedRecognition recogToReplace = null; - - // Look for intersections that will be overridden by this object or an intersection that would - // prevent this one from being placed. - for (final TrackedRecognition trackedRecognition : trackedObjects) { - final RectF a = trackedRecognition.trackedObject.getTrackedPositionInPreviewFrame(); - final RectF b = potentialObject.getTrackedPositionInPreviewFrame(); - final RectF intersection = new RectF(); - final boolean intersects = intersection.setIntersect(a, b); - - final float intersectArea = intersection.width() * intersection.height(); - final float totalArea = a.width() * a.height() + b.width() * b.height() - intersectArea; - final float intersectOverUnion = intersectArea / totalArea; - - // If there is an intersection with this currently tracked box above the maximum overlap - // percentage allowed, either the new recognition needs to be dismissed or the old - // recognition needs to be removed and possibly replaced with the new one. - if (intersects && intersectOverUnion > MAX_OVERLAP) { - if (potential.first < trackedRecognition.detectionConfidence - && trackedRecognition.trackedObject.getCurrentCorrelation() > MARGINAL_CORRELATION) { - // If track for the existing object is still going strong and the detection score was - // good, reject this new object. - potentialObject.stopTracking(); - return; - } else { - removeList.add(trackedRecognition); - - // Let the previously tracked object with max intersection amount donate its color to - // the new object. 
- if (intersectOverUnion > maxIntersect) { - maxIntersect = intersectOverUnion; - recogToReplace = trackedRecognition; - } - } - } - } - - // If we're already tracking the max object and no intersections were found to bump off, - // pick the worst current tracked object to remove, if it's also worse than this candidate - // object. - if (availableColors.isEmpty() && removeList.isEmpty()) { - for (final TrackedRecognition candidate : trackedObjects) { - if (candidate.detectionConfidence < potential.first) { - if (recogToReplace == null - || candidate.detectionConfidence < recogToReplace.detectionConfidence) { - // Save it so that we use this color for the new object. - recogToReplace = candidate; - } - } - } - if (recogToReplace != null) { - logger.v("Found non-intersecting object to remove."); - removeList.add(recogToReplace); - } else { - logger.v("No non-intersecting object found to remove"); - } - } - - // Remove everything that got intersected. - for (final TrackedRecognition trackedRecognition : removeList) { - logger.v( - "Removing tracked object %s with detection confidence %.2f, correlation %.2f", - trackedRecognition.trackedObject, - trackedRecognition.detectionConfidence, - trackedRecognition.trackedObject.getCurrentCorrelation()); - trackedRecognition.trackedObject.stopTracking(); - trackedObjects.remove(trackedRecognition); - if (trackedRecognition != recogToReplace) { - availableColors.add(trackedRecognition.color); - } - } - - if (recogToReplace == null && availableColors.isEmpty()) { - logger.e("No room to track this object, aborting."); - potentialObject.stopTracking(); - return; - } - - // Finally safe to say we can track this object. - logger.v( - "Tracking object %s (%s) with detection confidence %.2f at position %s", - potentialObject, - potential.second.getTitle(), - potential.first, - potential.second.getLocation()); - final TrackedRecognition trackedRecognition = new TrackedRecognition(); - trackedRecognition.detectionConfidence = potential.first; - trackedRecognition.trackedObject = potentialObject; - trackedRecognition.title = potential.second.getTitle(); - - // Use the color from a replaced object before taking one from the color queue. - trackedRecognition.color = - recogToReplace != null ? recogToReplace.color : availableColors.poll(); - trackedObjects.add(trackedRecognition); - } -} diff --git a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java deleted file mode 100644 index 8b4248d8fbc..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java +++ /dev/null @@ -1,661 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -package org.tensorflow.demo.tracking; - -import android.graphics.Canvas; -import android.graphics.Color; -import android.graphics.Matrix; -import android.graphics.Paint; -import android.graphics.PointF; -import android.graphics.RectF; -import android.graphics.Typeface; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Vector; -import javax.microedition.khronos.opengles.GL10; -import org.tensorflow.demo.env.Logger; -import org.tensorflow.demo.env.Size; - -/** - * True object detector/tracker class that tracks objects across consecutive preview frames. - * It provides a simplified Java interface to the analogous native object defined by - * jni/client_vision/tracking/object_tracker.*. - * - * Currently, the ObjectTracker is a singleton due to native code restrictions, and so must - * be allocated by ObjectTracker.getInstance(). In addition, release() should be called - * as soon as the ObjectTracker is no longer needed, and before a new one is created. - * - * nextFrame() should be called as new frames become available, preferably as often as possible. - * - * After allocation, new TrackedObjects may be instantiated via trackObject(). TrackedObjects - * are associated with the ObjectTracker that created them, and are only valid while that - * ObjectTracker still exists. - */ -public class ObjectTracker { - private static final Logger LOGGER = new Logger(); - - private static boolean libraryFound = false; - - static { - try { - System.loadLibrary("tensorflow_demo"); - libraryFound = true; - } catch (UnsatisfiedLinkError e) { - LOGGER.e("libtensorflow_demo.so not found, tracking unavailable"); - } - } - - private static final boolean DRAW_TEXT = false; - - /** - * How many history points to keep track of and draw in the red history line. - */ - private static final int MAX_DEBUG_HISTORY_SIZE = 30; - - /** - * How many frames of optical flow deltas to record. - * TODO(andrewharp): Push this down to the native level so it can be polled - * efficiently into a an array for upload, instead of keeping a duplicate - * copy in Java. - */ - private static final int MAX_FRAME_HISTORY_SIZE = 200; - - private static final int DOWNSAMPLE_FACTOR = 2; - - private final byte[] downsampledFrame; - - protected static ObjectTracker instance; - - private final Map trackedObjects; - - private long lastTimestamp; - - private FrameChange lastKeypoints; - - private final Vector debugHistory; - - private final LinkedList timestampedDeltas; - - protected final int frameWidth; - protected final int frameHeight; - private final int rowStride; - protected final boolean alwaysTrack; - - private static class TimestampedDeltas { - final long timestamp; - final byte[] deltas; - - public TimestampedDeltas(final long timestamp, final byte[] deltas) { - this.timestamp = timestamp; - this.deltas = deltas; - } - } - - /** - * A simple class that records keypoint information, which includes - * local location, score and type. This will be used in calculating - * FrameChange. 
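The lifecycle described in the class comment above, getInstance(), nextFrame(), trackObject(), stopTracking(), release(), looks roughly like this in a caller; the frame buffer, sizes, and box coordinates are placeholders:

```java
import android.graphics.RectF;
import org.tensorflow.demo.tracking.ObjectTracker;

final class TrackerLifecycleSketch {
  static void run(byte[] lumaFrame, int width, int height, int rowStride, long timestamp) {
    // Singleton allocation; returns null if the native tracking library is unavailable.
    ObjectTracker tracker = ObjectTracker.getInstance(width, height, rowStride, true);
    if (tracker == null) {
      return;
    }
    // Feed each preview frame as it arrives.
    tracker.nextFrame(lumaFrame, null, timestamp, null, false);
    // Begin tracking a detection; the rectangle here is a placeholder.
    ObjectTracker.TrackedObject tracked =
        tracker.trackObject(new RectF(10, 10, 100, 100), timestamp, lumaFrame);
    // ... consume tracked.getTrackedPositionInPreviewFrame() on later frames ...
    tracked.stopTracking();
    tracker.release();  // release before a new tracker may be created
  }
}
```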
- */ - public static class Keypoint { - public final float x; - public final float y; - public final float score; - public final int type; - - public Keypoint(final float x, final float y) { - this.x = x; - this.y = y; - this.score = 0; - this.type = -1; - } - - public Keypoint(final float x, final float y, final float score, final int type) { - this.x = x; - this.y = y; - this.score = score; - this.type = type; - } - - Keypoint delta(final Keypoint other) { - return new Keypoint(this.x - other.x, this.y - other.y); - } - } - - /** - * A simple class that could calculate Keypoint delta. - * This class will be used in calculating frame translation delta - * for optical flow. - */ - public static class PointChange { - public final Keypoint keypointA; - public final Keypoint keypointB; - Keypoint pointDelta; - private final boolean wasFound; - - public PointChange(final float x1, final float y1, - final float x2, final float y2, - final float score, final int type, - final boolean wasFound) { - this.wasFound = wasFound; - - keypointA = new Keypoint(x1, y1, score, type); - keypointB = new Keypoint(x2, y2); - } - - public Keypoint getDelta() { - if (pointDelta == null) { - pointDelta = keypointB.delta(keypointA); - } - return pointDelta; - } - } - - /** A class that records a timestamped frame translation delta for optical flow. */ - public static class FrameChange { - public static final int KEYPOINT_STEP = 7; - - public final Vector pointDeltas; - - private final float minScore; - private final float maxScore; - - public FrameChange(final float[] framePoints) { - float minScore = 100.0f; - float maxScore = -100.0f; - - pointDeltas = new Vector(framePoints.length / KEYPOINT_STEP); - - for (int i = 0; i < framePoints.length; i += KEYPOINT_STEP) { - final float x1 = framePoints[i + 0] * DOWNSAMPLE_FACTOR; - final float y1 = framePoints[i + 1] * DOWNSAMPLE_FACTOR; - - final boolean wasFound = framePoints[i + 2] > 0.0f; - - final float x2 = framePoints[i + 3] * DOWNSAMPLE_FACTOR; - final float y2 = framePoints[i + 4] * DOWNSAMPLE_FACTOR; - final float score = framePoints[i + 5]; - final int type = (int) framePoints[i + 6]; - - minScore = Math.min(minScore, score); - maxScore = Math.max(maxScore, score); - - pointDeltas.add(new PointChange(x1, y1, x2, y2, score, type, wasFound)); - } - - this.minScore = minScore; - this.maxScore = maxScore; - } - } - - public static synchronized ObjectTracker getInstance( - final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) { - if (!libraryFound) { - LOGGER.e( - "Native object tracking support not found. 
" - + "See tensorflow/examples/android/README.md for details."); - return null; - } - - if (instance == null) { - instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack); - instance.init(); - } else { - throw new RuntimeException( - "Tried to create a new objectracker before releasing the old one!"); - } - return instance; - } - - public static synchronized void clearInstance() { - if (instance != null) { - instance.release(); - } - } - - protected ObjectTracker( - final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) { - this.frameWidth = frameWidth; - this.frameHeight = frameHeight; - this.rowStride = rowStride; - this.alwaysTrack = alwaysTrack; - this.timestampedDeltas = new LinkedList(); - - trackedObjects = new HashMap(); - - debugHistory = new Vector(MAX_DEBUG_HISTORY_SIZE); - - downsampledFrame = - new byte - [(frameWidth + DOWNSAMPLE_FACTOR - 1) - / DOWNSAMPLE_FACTOR - * (frameWidth + DOWNSAMPLE_FACTOR - 1) - / DOWNSAMPLE_FACTOR]; - } - - protected void init() { - // The native tracker never sees the full frame, so pre-scale dimensions - // by the downsample factor. - initNative(frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, alwaysTrack); - } - - private final float[] matrixValues = new float[9]; - - private long downsampledTimestamp; - - @SuppressWarnings("unused") - public synchronized void drawOverlay(final GL10 gl, - final Size cameraViewSize, final Matrix matrix) { - final Matrix tempMatrix = new Matrix(matrix); - tempMatrix.preScale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR); - tempMatrix.getValues(matrixValues); - drawNative(cameraViewSize.width, cameraViewSize.height, matrixValues); - } - - public synchronized void nextFrame( - final byte[] frameData, final byte[] uvData, - final long timestamp, final float[] transformationMatrix, - final boolean updateDebugInfo) { - if (downsampledTimestamp != timestamp) { - ObjectTracker.downsampleImageNative( - frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame); - downsampledTimestamp = timestamp; - } - - // Do Lucas Kanade using the fullframe initializer. - nextFrameNative(downsampledFrame, uvData, timestamp, transformationMatrix); - - timestampedDeltas.add(new TimestampedDeltas(timestamp, getKeypointsPacked(DOWNSAMPLE_FACTOR))); - while (timestampedDeltas.size() > MAX_FRAME_HISTORY_SIZE) { - timestampedDeltas.removeFirst(); - } - - for (final TrackedObject trackedObject : trackedObjects.values()) { - trackedObject.updateTrackedPosition(); - } - - if (updateDebugInfo) { - updateDebugHistory(); - } - - lastTimestamp = timestamp; - } - - public synchronized void release() { - releaseMemoryNative(); - synchronized (ObjectTracker.class) { - instance = null; - } - } - - private void drawHistoryDebug(final Canvas canvas) { - drawHistoryPoint( - canvas, frameWidth * DOWNSAMPLE_FACTOR / 2, frameHeight * DOWNSAMPLE_FACTOR / 2); - } - - private void drawHistoryPoint(final Canvas canvas, final float startX, final float startY) { - final Paint p = new Paint(); - p.setAntiAlias(false); - p.setTypeface(Typeface.SERIF); - - p.setColor(Color.RED); - p.setStrokeWidth(2.0f); - - // Draw the center circle. - p.setColor(Color.GREEN); - canvas.drawCircle(startX, startY, 3.0f, p); - - p.setColor(Color.RED); - - // Iterate through in backwards order. 
- synchronized (debugHistory) { - final int numPoints = debugHistory.size(); - float lastX = startX; - float lastY = startY; - for (int keypointNum = 0; keypointNum < numPoints; ++keypointNum) { - final PointF delta = debugHistory.get(numPoints - keypointNum - 1); - final float newX = lastX + delta.x; - final float newY = lastY + delta.y; - canvas.drawLine(lastX, lastY, newX, newY, p); - lastX = newX; - lastY = newY; - } - } - } - - private static int floatToChar(final float value) { - return Math.max(0, Math.min((int) (value * 255.999f), 255)); - } - - private void drawKeypointsDebug(final Canvas canvas) { - final Paint p = new Paint(); - if (lastKeypoints == null) { - return; - } - final int keypointSize = 3; - - final float minScore = lastKeypoints.minScore; - final float maxScore = lastKeypoints.maxScore; - - for (final PointChange keypoint : lastKeypoints.pointDeltas) { - if (keypoint.wasFound) { - final int r = - floatToChar((keypoint.keypointA.score - minScore) / (maxScore - minScore)); - final int b = - floatToChar(1.0f - (keypoint.keypointA.score - minScore) / (maxScore - minScore)); - - final int color = 0xFF000000 | (r << 16) | b; - p.setColor(color); - - final float[] screenPoints = {keypoint.keypointA.x, keypoint.keypointA.y, - keypoint.keypointB.x, keypoint.keypointB.y}; - canvas.drawRect(screenPoints[2] - keypointSize, - screenPoints[3] - keypointSize, - screenPoints[2] + keypointSize, - screenPoints[3] + keypointSize, p); - p.setColor(Color.CYAN); - canvas.drawLine(screenPoints[2], screenPoints[3], - screenPoints[0], screenPoints[1], p); - - if (DRAW_TEXT) { - p.setColor(Color.WHITE); - canvas.drawText(keypoint.keypointA.type + ": " + keypoint.keypointA.score, - keypoint.keypointA.x, keypoint.keypointA.y, p); - } - } else { - p.setColor(Color.YELLOW); - final float[] screenPoint = {keypoint.keypointA.x, keypoint.keypointA.y}; - canvas.drawCircle(screenPoint[0], screenPoint[1], 5.0f, p); - } - } - } - - private synchronized PointF getAccumulatedDelta(final long timestamp, final float positionX, - final float positionY, final float radius) { - final RectF currPosition = getCurrentPosition(timestamp, - new RectF(positionX - radius, positionY - radius, positionX + radius, positionY + radius)); - return new PointF(currPosition.centerX() - positionX, currPosition.centerY() - positionY); - } - - private synchronized RectF getCurrentPosition(final long timestamp, final RectF - oldPosition) { - final RectF downscaledFrameRect = downscaleRect(oldPosition); - - final float[] delta = new float[4]; - getCurrentPositionNative(timestamp, downscaledFrameRect.left, downscaledFrameRect.top, - downscaledFrameRect.right, downscaledFrameRect.bottom, delta); - - final RectF newPosition = new RectF(delta[0], delta[1], delta[2], delta[3]); - - return upscaleRect(newPosition); - } - - private void updateDebugHistory() { - lastKeypoints = new FrameChange(getKeypointsNative(false)); - - if (lastTimestamp == 0) { - return; - } - - final PointF delta = - getAccumulatedDelta( - lastTimestamp, frameWidth / DOWNSAMPLE_FACTOR, frameHeight / DOWNSAMPLE_FACTOR, 100); - - synchronized (debugHistory) { - debugHistory.add(delta); - - while (debugHistory.size() > MAX_DEBUG_HISTORY_SIZE) { - debugHistory.remove(0); - } - } - } - - public synchronized void drawDebug(final Canvas canvas, final Matrix frameToCanvas) { - canvas.save(); - canvas.setMatrix(frameToCanvas); - - drawHistoryDebug(canvas); - drawKeypointsDebug(canvas); - - canvas.restore(); - } - - public Vector getDebugText() { - final Vector lines = new 
Vector(); - - if (lastKeypoints != null) { - lines.add("Num keypoints " + lastKeypoints.pointDeltas.size()); - lines.add("Min score: " + lastKeypoints.minScore); - lines.add("Max score: " + lastKeypoints.maxScore); - } - - return lines; - } - - public synchronized List pollAccumulatedFlowData(final long endFrameTime) { - final List frameDeltas = new ArrayList(); - while (timestampedDeltas.size() > 0) { - final TimestampedDeltas currentDeltas = timestampedDeltas.peek(); - if (currentDeltas.timestamp <= endFrameTime) { - frameDeltas.add(currentDeltas.deltas); - timestampedDeltas.removeFirst(); - } else { - break; - } - } - - return frameDeltas; - } - - private RectF downscaleRect(final RectF fullFrameRect) { - return new RectF( - fullFrameRect.left / DOWNSAMPLE_FACTOR, - fullFrameRect.top / DOWNSAMPLE_FACTOR, - fullFrameRect.right / DOWNSAMPLE_FACTOR, - fullFrameRect.bottom / DOWNSAMPLE_FACTOR); - } - - private RectF upscaleRect(final RectF downsampledFrameRect) { - return new RectF( - downsampledFrameRect.left * DOWNSAMPLE_FACTOR, - downsampledFrameRect.top * DOWNSAMPLE_FACTOR, - downsampledFrameRect.right * DOWNSAMPLE_FACTOR, - downsampledFrameRect.bottom * DOWNSAMPLE_FACTOR); - } - - /** - * A TrackedObject represents a native TrackedObject, and provides access to the - * relevant native tracking information available after every frame update. They may - * be safely passed around and accessed externally, but will become invalid after - * stopTracking() is called or the related creating ObjectTracker is deactivated. - * - * @author andrewharp@google.com (Andrew Harp) - */ - public class TrackedObject { - private final String id; - - private long lastExternalPositionTime; - - private RectF lastTrackedPosition; - private boolean visibleInLastFrame; - - private boolean isDead; - - TrackedObject(final RectF position, final long timestamp, final byte[] data) { - isDead = false; - - id = Integer.toString(this.hashCode()); - - lastExternalPositionTime = timestamp; - - synchronized (ObjectTracker.this) { - registerInitialAppearance(position, data); - setPreviousPosition(position, timestamp); - trackedObjects.put(id, this); - } - } - - public void stopTracking() { - checkValidObject(); - - synchronized (ObjectTracker.this) { - isDead = true; - forgetNative(id); - trackedObjects.remove(id); - } - } - - public float getCurrentCorrelation() { - checkValidObject(); - return ObjectTracker.this.getCurrentCorrelation(id); - } - - void registerInitialAppearance(final RectF position, final byte[] data) { - final RectF externalPosition = downscaleRect(position); - registerNewObjectWithAppearanceNative(id, - externalPosition.left, externalPosition.top, - externalPosition.right, externalPosition.bottom, - data); - } - - synchronized void setPreviousPosition(final RectF position, final long timestamp) { - checkValidObject(); - synchronized (ObjectTracker.this) { - if (lastExternalPositionTime > timestamp) { - LOGGER.w("Tried to use older position time!"); - return; - } - final RectF externalPosition = downscaleRect(position); - lastExternalPositionTime = timestamp; - - setPreviousPositionNative(id, - externalPosition.left, externalPosition.top, - externalPosition.right, externalPosition.bottom, - lastExternalPositionTime); - - updateTrackedPosition(); - } - } - - void setCurrentPosition(final RectF position) { - checkValidObject(); - final RectF downsampledPosition = downscaleRect(position); - synchronized (ObjectTracker.this) { - setCurrentPositionNative(id, - downsampledPosition.left, downsampledPosition.top, 
- downsampledPosition.right, downsampledPosition.bottom); - } - } - - private synchronized void updateTrackedPosition() { - checkValidObject(); - - final float[] delta = new float[4]; - getTrackedPositionNative(id, delta); - lastTrackedPosition = new RectF(delta[0], delta[1], delta[2], delta[3]); - - visibleInLastFrame = isObjectVisible(id); - } - - public synchronized RectF getTrackedPositionInPreviewFrame() { - checkValidObject(); - - if (lastTrackedPosition == null) { - return null; - } - return upscaleRect(lastTrackedPosition); - } - - synchronized long getLastExternalPositionTime() { - return lastExternalPositionTime; - } - - public synchronized boolean visibleInLastPreviewFrame() { - return visibleInLastFrame; - } - - private void checkValidObject() { - if (isDead) { - throw new RuntimeException("TrackedObject already removed from tracking!"); - } else if (ObjectTracker.this != instance) { - throw new RuntimeException("TrackedObject created with another ObjectTracker!"); - } - } - } - - public synchronized TrackedObject trackObject( - final RectF position, final long timestamp, final byte[] frameData) { - if (downsampledTimestamp != timestamp) { - ObjectTracker.downsampleImageNative( - frameWidth, frameHeight, rowStride, frameData, DOWNSAMPLE_FACTOR, downsampledFrame); - downsampledTimestamp = timestamp; - } - return new TrackedObject(position, timestamp, downsampledFrame); - } - - public synchronized TrackedObject trackObject(final RectF position, final byte[] frameData) { - return new TrackedObject(position, lastTimestamp, frameData); - } - - /** ********************* NATIVE CODE ************************************ */ - - /** This will contain an opaque pointer to the native ObjectTracker */ - private long nativeObjectTracker; - - private native void initNative(int imageWidth, int imageHeight, boolean alwaysTrack); - - protected native void registerNewObjectWithAppearanceNative( - String objectId, float x1, float y1, float x2, float y2, byte[] data); - - protected native void setPreviousPositionNative( - String objectId, float x1, float y1, float x2, float y2, long timestamp); - - protected native void setCurrentPositionNative( - String objectId, float x1, float y1, float x2, float y2); - - protected native void forgetNative(String key); - - protected native String getModelIdNative(String key); - - protected native boolean haveObject(String key); - protected native boolean isObjectVisible(String key); - protected native float getCurrentCorrelation(String key); - - protected native float getMatchScore(String key); - - protected native void getTrackedPositionNative(String key, float[] points); - - protected native void nextFrameNative( - byte[] frameData, byte[] uvData, long timestamp, float[] frameAlignMatrix); - - protected native void releaseMemoryNative(); - - protected native void getCurrentPositionNative(long timestamp, - final float positionX1, final float positionY1, - final float positionX2, final float positionY2, - final float[] delta); - - protected native byte[] getKeypointsPacked(float scaleFactor); - - protected native float[] getKeypointsNative(boolean onlyReturnCorrespondingKeypoints); - - protected native void drawNative(int viewWidth, int viewHeight, float[] frameToCanvas); - - protected static native void downsampleImageNative( - int width, int height, int rowStride, byte[] input, int factor, byte[] output); -} diff --git a/tensorflow/lite/examples/android/app/src/main/res/animator/color_animation.xml 
b/tensorflow/lite/examples/android/app/src/main/res/animator/color_animation.xml deleted file mode 100644 index 891d8cc1d4f..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/res/animator/color_animation.xml +++ /dev/null @@ -1,30 +0,0 @@ - - - - - diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png deleted file mode 100644 index 32bd1aabcab..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png deleted file mode 100644 index b3113cd15c3..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png deleted file mode 100644 index 135862883e2..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png deleted file mode 100644 index 8efbbf8b3c4..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png deleted file mode 100644 index 51f87ee6507..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png deleted file mode 100644 index ba143ea7a80..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png deleted file mode 100644 index 6361d792dac..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png deleted file mode 100644 index 394eb7e5349..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png deleted file mode 100644 index 2e27bec9785..00000000000 Binary files a/tensorflow/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png and /dev/null differ diff --git a/tensorflow/lite/examples/android/app/src/main/res/drawable/border.xml 
b/tensorflow/lite/examples/android/app/src/main/res/drawable/border.xml deleted file mode 100644 index dd1d64d1d61..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/res/drawable/border.xml +++ /dev/null @@ -1,19 +0,0 @@ - - - - - diff --git a/tensorflow/lite/examples/android/app/src/main/res/layout/activity_camera.xml b/tensorflow/lite/examples/android/app/src/main/res/layout/activity_camera.xml deleted file mode 100644 index 1a22d4b33eb..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/res/layout/activity_camera.xml +++ /dev/null @@ -1,22 +0,0 @@ - - diff --git a/tensorflow/lite/examples/android/app/src/main/res/layout/activity_speech.xml b/tensorflow/lite/examples/android/app/src/main/res/layout/activity_speech.xml deleted file mode 100644 index 2fe1338da57..00000000000 --- a/tensorflow/lite/examples/android/app/src/main/res/layout/activity_speech.xml +++ /dev/null @@ -1,55 +0,0 @@ - - - - - - - -